diff --git a/src/impl/sctptransport.cpp b/src/impl/sctptransport.cpp index c3d85a2..dc2961c 100644 --- a/src/impl/sctptransport.cpp +++ b/src/impl/sctptransport.cpp @@ -85,8 +85,30 @@ static LogCounter static LogCounter COUNTER_BAD_SCTP_STATUS(plog::warning, "Number of SCTP packets received with a bad status"); -std::unordered_set SctpTransport::Instances; -std::shared_mutex SctpTransport::InstancesMutex; +class SctpTransport::InstancesSet { +public: + void insert(SctpTransport *instance) { + std::unique_lock lock(mMutex); + mSet.insert(instance); + } + + void erase(SctpTransport *instance) { + std::unique_lock lock(mMutex); + mSet.erase(instance); + } + + using shared_lock = std::shared_lock; + optional lock(SctpTransport *instance) { + shared_lock lock(mMutex); + return mSet.find(instance) != mSet.end() ? std::make_optional(std::move(lock)) : nullopt; + } + +private: + std::unordered_set mSet; + std::shared_mutex mMutex; +}; + +SctpTransport::InstancesSet *SctpTransport::Instances = new InstancesSet; void SctpTransport::Init() { usrsctp_init(0, &SctpTransport::WriteCallback, nullptr); @@ -143,10 +165,7 @@ SctpTransport::SctpTransport(shared_ptr lower, const Configuration &c PLOG_DEBUG << "Initializing SCTP transport"; usrsctp_register_address(this); - { - std::unique_lock lock(InstancesMutex); - Instances.insert(this); - } + Instances->insert(this); mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, nullptr, nullptr, 0, nullptr); if (!mSock) @@ -285,10 +304,7 @@ SctpTransport::~SctpTransport() { close(); usrsctp_deregister_address(this); - { - std::unique_lock lock(InstancesMutex); - Instances.erase(this); - } + Instances->erase(this); } void SctpTransport::start() { @@ -848,11 +864,8 @@ optional SctpTransport::rtt() { void SctpTransport::UpcallCallback(struct socket *, void *arg, int /* flags */) { auto *transport = static_cast(arg); - std::shared_lock lock(InstancesMutex); - if (Instances.find(transport) == Instances.end()) - return; - - transport->handleUpcall(); + if(auto lock = Instances->lock(transport)) + transport->handleUpcall(); } int SctpTransport::WriteCallback(void *ptr, void *data, size_t len, uint8_t tos, uint8_t set_df) { @@ -860,11 +873,10 @@ int SctpTransport::WriteCallback(void *ptr, void *data, size_t len, uint8_t tos, // Workaround for sctplab/usrsctp#405: Send callback is invoked on already closed socket // https://github.com/sctplab/usrsctp/issues/405 - std::shared_lock lock(InstancesMutex); - if (Instances.find(transport) == Instances.end()) + if(auto lock = Instances->lock(transport)) + return transport->handleWrite(static_cast(data), len, tos, set_df); + else return -1; - - return transport->handleWrite(static_cast(data), len, tos, set_df); } } // namespace rtc::impl diff --git a/src/impl/sctptransport.hpp b/src/impl/sctptransport.hpp index a7cf7b2..654099b 100644 --- a/src/impl/sctptransport.hpp +++ b/src/impl/sctptransport.hpp @@ -120,8 +120,8 @@ private: static void UpcallCallback(struct socket *sock, void *arg, int flags); static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df); - static std::unordered_set Instances; - static std::shared_mutex InstancesMutex; + class InstancesSet; + static InstancesSet *Instances; }; } // namespace rtc::impl