diff --git a/include/rtc/channel.hpp b/include/rtc/channel.hpp index 93affbb..fc953f8 100644 --- a/include/rtc/channel.hpp +++ b/include/rtc/channel.hpp @@ -60,6 +60,8 @@ protected: virtual void triggerAvailable(size_t count); virtual void triggerBufferedAmount(size_t amount); + void resetCallbacks(); + private: synchronized_callback<> mOpenCallback; synchronized_callback<> mClosedCallback; diff --git a/include/rtc/datachannel.hpp b/include/rtc/datachannel.hpp index 9abdd3c..c209048 100644 --- a/include/rtc/datachannel.hpp +++ b/include/rtc/datachannel.hpp @@ -40,7 +40,7 @@ class DataChannel : public std::enable_shared_from_this, public Cha public: DataChannel(std::weak_ptr pc, unsigned int stream, string label, string protocol, Reliability reliability); - DataChannel(std::weak_ptr pc, std::shared_ptr transport, + DataChannel(std::weak_ptr pc, std::weak_ptr transport, unsigned int stream); ~DataChannel(); @@ -65,13 +65,13 @@ public: private: void remoteClose(); - void open(std::shared_ptr sctpTransport); + void open(std::shared_ptr transport); bool outgoing(mutable_message_ptr message); void incoming(message_ptr message); void processOpenMessage(message_ptr message); const std::weak_ptr mPeerConnection; - std::shared_ptr mSctpTransport; + std::weak_ptr mSctpTransport; unsigned int mStream; string mLabel; diff --git a/include/rtc/peerconnection.hpp b/include/rtc/peerconnection.hpp index acf9e38..0abcdbc 100644 --- a/include/rtc/peerconnection.hpp +++ b/include/rtc/peerconnection.hpp @@ -52,14 +52,13 @@ public: Connected = RTC_CONNECTED, Disconnected = RTC_DISCONNECTED, Failed = RTC_FAILED, - Closed = RTC_CLOSED, - Destroying = RTC_DESTROYING + Closed = RTC_CLOSED }; enum class GatheringState : int { New = RTC_GATHERING_NEW, InProgress = RTC_GATHERING_INPROGRESS, - Complete = RTC_GATHERING_COMPLETE, + Complete = RTC_GATHERING_COMPLETE }; PeerConnection(void); @@ -94,6 +93,7 @@ private: std::shared_ptr initIceTransport(Description::Role role); std::shared_ptr initDtlsTransport(); std::shared_ptr initSctpTransport(); + void closeTransports(); void endLocalCandidates(); bool checkFingerprint(const std::string &fingerprint) const; @@ -112,8 +112,10 @@ private: void processLocalDescription(Description description); void processLocalCandidate(Candidate candidate); void triggerDataChannel(std::weak_ptr weakDataChannel); - void changeState(State state); - void changeGatheringState(GatheringState state); + bool changeState(State state); + bool changeGatheringState(GatheringState state); + + void resetCallbacks(); const Configuration mConfig; const std::shared_ptr mCertificate; diff --git a/include/rtc/rtc.h b/include/rtc/rtc.h index f4a3d59..f610ec7 100644 --- a/include/rtc/rtc.h +++ b/include/rtc/rtc.h @@ -31,8 +31,7 @@ typedef enum { RTC_CONNECTED = 2, RTC_DISCONNECTED = 3, RTC_FAILED = 4, - RTC_CLOSED = 5, - RTC_DESTROYING = 6 // internal + RTC_CLOSED = 5 } rtcState; typedef enum { diff --git a/src/channel.cpp b/src/channel.cpp index e7a402b..a7cb422 100644 --- a/src/channel.cpp +++ b/src/channel.cpp @@ -88,5 +88,14 @@ void Channel::triggerBufferedAmount(size_t amount) { mBufferedAmountLowCallback(); } +void Channel::resetCallbacks() { + mOpenCallback = nullptr; + mClosedCallback = nullptr; + mErrorCallback = nullptr; + mMessageCallback = nullptr; + mAvailableCallback = nullptr; + mBufferedAmountLowCallback = nullptr; +} + } // namespace rtc diff --git a/src/datachannel.cpp b/src/datachannel.cpp index d4ae583..1ae435f 100644 --- a/src/datachannel.cpp +++ b/src/datachannel.cpp @@ -74,7 +74,7 @@ DataChannel::DataChannel(weak_ptr pc, unsigned int stream, strin mReliability(std::make_shared(std::move(reliability))), mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {} -DataChannel::DataChannel(weak_ptr pc, shared_ptr transport, +DataChannel::DataChannel(weak_ptr pc, weak_ptr transport, unsigned int stream) : mPeerConnection(pc), mSctpTransport(transport), mStream(stream), mReliability(std::make_shared()), @@ -93,10 +93,13 @@ string DataChannel::protocol() const { return mProtocol; } Reliability DataChannel::reliability() const { return *mReliability; } void DataChannel::close() { - if (mIsOpen.exchange(false) && mSctpTransport) - mSctpTransport->reset(mStream); + if (mIsOpen.exchange(false)) + if (auto transport = mSctpTransport.lock()) + transport->reset(mStream); mIsClosed = true; mSctpTransport.reset(); + + resetCallbacks(); } void DataChannel::remoteClose() { @@ -158,8 +161,8 @@ size_t DataChannel::maxMessageSize() const { size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); } -void DataChannel::open(shared_ptr sctpTransport) { - mSctpTransport = sctpTransport; +void DataChannel::open(shared_ptr transport) { + mSctpTransport = transport; uint8_t channelType = static_cast(mReliability->type); if (mReliability->unordered) @@ -186,20 +189,24 @@ void DataChannel::open(shared_ptr sctpTransport) { std::copy(mLabel.begin(), mLabel.end(), end); std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size()); - mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); + transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); } bool DataChannel::outgoing(mutable_message_ptr message) { - if (mIsClosed || !mSctpTransport) + if (mIsClosed) throw std::runtime_error("DataChannel is closed"); if (message->size() > maxMessageSize()) throw std::runtime_error("Message size exceeds limit"); + auto transport = mSctpTransport.lock(); + if (!transport) + throw std::runtime_error("DataChannel has no transport"); + // Before the ACK has been received on a DataChannel, all messages must be sent ordered message->reliability = mIsOpen ? mReliability : nullptr; message->stream = mStream; - return mSctpTransport->send(message); + return transport->send(message); } void DataChannel::incoming(message_ptr message) { @@ -238,6 +245,10 @@ void DataChannel::incoming(message_ptr message) { } void DataChannel::processOpenMessage(message_ptr message) { + auto transport = mSctpTransport.lock(); + if (!transport) + throw std::runtime_error("DataChannel has no transport"); + if (message->size() < sizeof(OpenMessage)) throw std::invalid_argument("DataChannel open message too small"); @@ -274,7 +285,7 @@ void DataChannel::processOpenMessage(message_ptr message) { auto &ack = *reinterpret_cast(buffer.data()); ack.type = MESSAGE_ACK; - mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); + transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); mIsOpen = true; triggerOpen(); diff --git a/src/peerconnection.cpp b/src/peerconnection.cpp index 8bed3bd..8613871 100644 --- a/src/peerconnection.cpp +++ b/src/peerconnection.cpp @@ -24,6 +24,7 @@ #include "sctptransport.hpp" #include +#include namespace rtc { @@ -38,28 +39,11 @@ PeerConnection::PeerConnection() : PeerConnection(Configuration()) { PeerConnection::PeerConnection(const Configuration &config) : mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {} -PeerConnection::~PeerConnection() { - changeState(State::Destroying); - close(); - mSctpTransport.reset(); - mDtlsTransport.reset(); - mIceTransport.reset(); -} +PeerConnection::~PeerConnection() { close(); } void PeerConnection::close() { - // Close DataChannels closeDataChannels(); - - // Close Transports - for (int i = 0; i < 2; ++i) { // Make sure a transport wasn't spawn behind our back - if (auto transport = std::atomic_load(&mSctpTransport)) - transport->stop(); - if (auto transport = std::atomic_load(&mDtlsTransport)) - transport->stop(); - if (auto transport = std::atomic_load(&mIceTransport)) - transport->stop(); - } - changeState(State::Closed); + closeTransports(); } const Configuration *PeerConnection::config() const { return &mConfig; } @@ -241,8 +225,15 @@ shared_ptr PeerConnection::initIceTransport(Description::Role role break; } }); + std::atomic_store(&mIceTransport, transport); + if (mState == State::Closed) { + mIceTransport.reset(); + transport->stop(); + throw std::runtime_error("Connection is closed"); + } return transport; + } catch (const std::exception &e) { PLOG_ERROR << e.what(); changeState(State::Failed); @@ -274,8 +265,15 @@ shared_ptr PeerConnection::initDtlsTransport() { break; } }); + std::atomic_store(&mDtlsTransport, transport); + if (mState == State::Closed) { + mDtlsTransport.reset(); + transport->stop(); + throw std::runtime_error("Connection is closed"); + } return transport; + } catch (const std::exception &e) { PLOG_ERROR << e.what(); changeState(State::Failed); @@ -312,8 +310,15 @@ shared_ptr PeerConnection::initSctpTransport() { break; } }); + std::atomic_store(&mSctpTransport, transport); + if (mState == State::Closed) { + mSctpTransport.reset(); + transport->stop(); + throw std::runtime_error("Connection is closed"); + } return transport; + } catch (const std::exception &e) { PLOG_ERROR << e.what(); changeState(State::Failed); @@ -321,6 +326,34 @@ shared_ptr PeerConnection::initSctpTransport() { } } +void PeerConnection::closeTransports() { + // Change state to sink state Closed to block init methods + changeState(State::Closed); + + // Reset callbacks now that state is changed + resetCallbacks(); + + // Pass the references to a thread, allowing to terminate a transport from its own thread + auto sctp = std::atomic_exchange(&mSctpTransport, decltype(mSctpTransport)(nullptr)); + auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr)); + auto ice = std::atomic_exchange(&mIceTransport, decltype(mIceTransport)(nullptr)); + if (sctp || dtls || ice) { + std::thread t([sctp, dtls, ice]() mutable { + if (sctp) + sctp->stop(); + if (dtls) + dtls->stop(); + if (ice) + ice->stop(); + + sctp.reset(); + dtls.reset(); + ice.reset(); + }); + t.detach(); + } +} + void PeerConnection::endLocalCandidates() { std::lock_guard lock(mLocalDescriptionMutex); if (mLocalDescription) @@ -467,21 +500,34 @@ void PeerConnection::triggerDataChannel(weak_ptr weakDataChannel) { mDataChannelCallback(dataChannel); } -void PeerConnection::changeState(State state) { +bool PeerConnection::changeState(State state) { State current; do { current = mState.load(); - if (current == state || current == State::Destroying) - return; + if (current == state) + return true; + if (current == State::Closed) + return false; + } while (!mState.compare_exchange_weak(current, state)); - if (state != State::Destroying) - mStateChangeCallback(state); + mStateChangeCallback(state); + return true; } -void PeerConnection::changeGatheringState(GatheringState state) { +bool PeerConnection::changeGatheringState(GatheringState state) { if (mGatheringState.exchange(state) != state) mGatheringStateChangeCallback(state); + return true; +} + +void PeerConnection::resetCallbacks() { + // Unregister all callbacks + mDataChannelCallback = nullptr; + mLocalDescriptionCallback = nullptr; + mLocalCandidateCallback = nullptr; + mStateChangeCallback = nullptr; + mGatheringStateChangeCallback = nullptr; } } // namespace rtc @@ -508,9 +554,6 @@ std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::State &st case State::Closed: str = "closed"; break; - case State::Destroying: - str = "destroying"; - break; default: str = "unknown"; break; diff --git a/src/rtc.cpp b/src/rtc.cpp index b243045..4099c6d 100644 --- a/src/rtc.cpp +++ b/src/rtc.cpp @@ -53,6 +53,14 @@ void *getUserPointer(int id) { return it != userPointerMap.end() ? it->second : nullptr; } +void setUserPointer(int i, void *ptr) { + std::lock_guard lock(mutex); + if (ptr) + userPointerMap.insert(std::make_pair(i, ptr)); + else + userPointerMap.erase(i); +} + shared_ptr getPeerConnection(int id) { std::lock_guard lock(mutex); auto it = peerConnectionMap.find(id); @@ -99,12 +107,7 @@ bool eraseDataChannel(int dc) { void rtcInitLogger(rtcLogLevel level) { InitLogger(static_cast(level)); } -void rtcSetUserPointer(int i, void *ptr) { - if (ptr) - userPointerMap.insert(std::make_pair(i, ptr)); - else - userPointerMap.erase(i); -} +void rtcSetUserPointer(int i, void *ptr) { setUserPointer(i, ptr); } int rtcCreatePeerConnection(const rtcConfiguration *config) { Configuration c; diff --git a/test/capi.cpp b/test/capi.cpp index 6027d13..7caef76 100644 --- a/test/capi.cpp +++ b/test/capi.cpp @@ -157,6 +157,16 @@ int test_capi_main() { sleep(3); + if (peer1->state != RTC_CONNECTED || peer2->state != RTC_CONNECTED) { + fprintf(stderr, "PeerConnection is not connected\n"); + goto error; + } + + if (!peer1->connected || !peer2->connected) { + fprintf(stderr, "DataChannel is not connected\n"); + goto error; + } + char buffer[256]; if (rtcGetLocalAddress(peer1->pc, buffer, 256) >= 0) printf("Local address 1: %s\n", buffer); @@ -167,13 +177,12 @@ int test_capi_main() { if (rtcGetRemoteAddress(peer2->pc, buffer, 256) >= 0) printf("Remote address 2: %s\n", buffer); - if (peer1->connected && peer2->connected) { - deletePeer(peer1); - deletePeer(peer2); - sleep(1); - printf("Success\n"); - return 0; - } + deletePeer(peer1); + sleep(1); + deletePeer(peer2); + + printf("Success\n"); + return 0; error: deletePeer(peer1); diff --git a/test/connectivity.cpp b/test/connectivity.cpp index 3c9d559..7c83d5b 100644 --- a/test/connectivity.cpp +++ b/test/connectivity.cpp @@ -108,6 +108,13 @@ void test_connectivity() { this_thread::sleep_for(3s); + if (pc1->state() != PeerConnection::State::Connected && + pc2->state() != PeerConnection::State::Connected) + throw runtime_error("PeerConnection is not connected"); + + if (!dc1->isOpen() || !dc2->isOpen()) + throw runtime_error("DataChannel is not open"); + if (auto addr = pc1->localAddress()) cout << "Local address 1: " << *addr << endl; if (auto addr = pc1->remoteAddress()) @@ -117,13 +124,10 @@ void test_connectivity() { if (auto addr = pc2->remoteAddress()) cout << "Remote address 2: " << *addr << endl; - if (!dc1->isOpen() || !dc2->isOpen()) - throw runtime_error("DataChannel is not open"); - + // Delay close of peer 2 to check closing works properly pc1->close(); - pc2->close(); - this_thread::sleep_for(1s); + pc2->close(); cout << "Success" << endl; }