diff --git a/CMakeLists.txt b/CMakeLists.txt index ff4c075..bbdf2b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,7 @@ option(NO_WEBSOCKET "Disable WebSocket support" OFF) option(NO_EXAMPLES "Disable examples" OFF) option(NO_TESTS "Disable tests build" OFF) option(WARNINGS_AS_ERRORS "Treat warnings as errors" OFF) +option(RSA_KEY_BITS_2048 "Use 2048-bit RSA key instead of 3072-bit" OFF) option(CAPI_STDCALL "Set calling convention of C API callbacks stdcall" OFF) # Option USE_SRTP defaults to AUTO (enabled if libSRTP is found, else disabled) set(USE_SRTP AUTO CACHE STRING "Use libSRTP and enable media support") @@ -229,6 +230,11 @@ else() target_link_libraries(datachannel-static PRIVATE LibJuice::LibJuiceStatic) endif() +if(RSA_KEY_BITS_2048) + target_compile_definitions(datachannel PUBLIC RSA_KEY_BITS_2048) + target_compile_definitions(datachannel-static PUBLIC RSA_KEY_BITS_2048) +endif() + if(CAPI_STDCALL) target_compile_definitions(datachannel PUBLIC CAPI_STDCALL) target_compile_definitions(datachannel-static PUBLIC CAPI_STDCALL) diff --git a/include/rtc/description.hpp b/include/rtc/description.hpp index 3fec1bc..d7da416 100644 --- a/include/rtc/description.hpp +++ b/include/rtc/description.hpp @@ -34,8 +34,8 @@ namespace rtc { class Description { public: - enum class Type { Unspec = 0, Offer = 1, Answer = 2 }; - enum class Role { ActPass = 0, Passive = 1, Active = 2 }; + enum class Type { Unspec, Offer, Answer, Pranswer, Rollback }; + enum class Role { ActPass, Passive, Active }; enum class Direction { SendOnly, RecvOnly, SendRecv, Inactive, Unknown }; Description(const string &sdp, const string &typeString = ""); @@ -45,10 +45,9 @@ public: Type type() const; string typeString() const; Role role() const; - string roleString() const; string bundleMid() const; - string iceUfrag() const; - string icePwd() const; + std::optional iceUfrag() const; + std::optional icePwd() const; std::optional fingerprint() const; bool ended() const; @@ -56,6 +55,7 @@ public: void setFingerprint(string fingerprint); void addCandidate(Candidate candidate); + void addCandidates(std::vector candidates); void endCandidates(); std::vector extractCandidates(); @@ -94,8 +94,7 @@ public: struct Application : public Entry { public: Application(string mid = "data"); - Application(const Application &other) = default; - Application(Application &&other) = default; + virtual ~Application() = default; string description() const override; Application reciprocate() const; @@ -121,8 +120,6 @@ public: public: Media(const string &sdp); Media(const string &mline, string mid, Direction dir = Direction::SendOnly); - Media(const Media &other) = default; - Media(Media &&other) = default; virtual ~Media() = default; string description() const override; @@ -180,6 +177,7 @@ public: bool hasApplication() const; bool hasAudioOrVideo() const; + bool hasMid(string_view mid) const; int addMedia(Media media); int addMedia(Application application); @@ -193,6 +191,9 @@ public: Application *application(); + static Type stringToType(const string &typeString); + static string typeToString(Type type); + private: std::optional defaultCandidate() const; std::shared_ptr createEntry(string mline, string mid, Direction dir); @@ -204,7 +205,7 @@ private: Role mRole; string mUsername; string mSessionId; - string mIceUfrag, mIcePwd; + std::optional mIceUfrag, mIcePwd; std::optional mFingerprint; // Entries @@ -214,14 +215,12 @@ private: // Candidates std::vector mCandidates; bool mEnded = false; - - static Type stringToType(const string &typeString); - static string typeToString(Type type); - static string roleToString(Role role); }; } // namespace rtc std::ostream &operator<<(std::ostream &out, const rtc::Description &description); +std::ostream &operator<<(std::ostream &out, rtc::Description::Type type); +std::ostream &operator<<(std::ostream &out, rtc::Description::Role role); #endif diff --git a/include/rtc/include.hpp b/include/rtc/include.hpp index 4362521..a9c0195 100644 --- a/include/rtc/include.hpp +++ b/include/rtc/include.hpp @@ -62,7 +62,7 @@ using std::uint8_t; const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length const size_t MAX_NUMERICSERV_LEN = 6; // Max port string representation length -const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default +const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not specified in SDP const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size @@ -72,7 +72,7 @@ const int THREADPOOL_SIZE = 4; // Number of threads in the global thread pool // overloaded helper template struct overloaded : Ts... { using Ts::operator()...; }; -template overloaded(Ts...)->overloaded; +template overloaded(Ts...) -> overloaded; // weak_ptr bind helper template auto weak_bind(F &&f, T *t, Args &&... _args) { @@ -85,6 +85,23 @@ template auto weak_bind(F &&f, T *t, }; } +// scope_guard helper +class scope_guard { +public: + scope_guard(std::function func) : function(std::move(func)) {} + scope_guard(scope_guard &&other) = delete; + scope_guard(const scope_guard &) = delete; + void operator=(const scope_guard &) = delete; + + ~scope_guard() { + if (function) + function(); + } + +private: + std::function function; +}; + template class synchronized_callback { public: synchronized_callback() = default; @@ -127,6 +144,6 @@ private: std::function callback; mutable std::recursive_mutex mutex; }; -} +} // namespace rtc #endif diff --git a/include/rtc/peerconnection.hpp b/include/rtc/peerconnection.hpp index b0af009..0f39890 100644 --- a/include/rtc/peerconnection.hpp +++ b/include/rtc/peerconnection.hpp @@ -67,6 +67,14 @@ public: Complete = RTC_GATHERING_COMPLETE }; + enum class SignalingState : int { + Stable = RTC_SIGNALING_STABLE, + HaveLocalOffer = RTC_SIGNALING_HAVE_LOCAL_OFFER, + HaveRemoteOffer = RTC_SIGNALING_HAVE_REMOTE_OFFER, + HaveLocalPranswer = RTC_SIGNALING_HAVE_LOCAL_PRANSWER, + HaveRemotePranswer = RTC_SIGNALING_HAVE_REMOTE_PRANSWER, + } rtcSignalingState; + PeerConnection(void); PeerConnection(const Configuration &config); ~PeerConnection(); @@ -76,6 +84,7 @@ public: const Configuration *config() const; State state() const; GatheringState gatheringState() const; + SignalingState signalingState() const; bool hasLocalDescription() const; bool hasRemoteDescription() const; bool hasMedia() const; @@ -83,8 +92,9 @@ public: std::optional remoteDescription() const; std::optional localAddress() const; std::optional remoteAddress() const; + bool getSelectedCandidatePair(Candidate *local, Candidate *remote); - void setLocalDescription(); + void setLocalDescription(Description::Type type = Description::Type::Unspec); void setRemoteDescription(Description description); void addRemoteCandidate(Candidate candidate); @@ -100,6 +110,7 @@ public: void onLocalCandidate(std::function callback); void onStateChange(std::function callback); void onGatheringStateChange(std::function callback); + void onSignalingStateChange(std::function callback); // Stats void clearStats(); @@ -111,9 +122,6 @@ public: std::shared_ptr addTrack(Description::Media description); void onTrack(std::function track)> callback); - // libnice only - bool getSelectedCandidatePair(Candidate *local, Candidate *remote); - private: std::shared_ptr initIceTransport(Description::Role role); std::shared_ptr initDtlsTransport(); @@ -137,12 +145,16 @@ private: void incomingTrack(Description::Media description); void openTracks(); + void validateRemoteDescription(const Description &description); void processLocalDescription(Description description); void processLocalCandidate(Candidate candidate); + void processRemoteDescription(Description description); + void processRemoteCandidate(Candidate candidate); void triggerDataChannel(std::weak_ptr weakDataChannel); void triggerTrack(std::shared_ptr track); bool changeState(State state); bool changeGatheringState(GatheringState state); + bool changeSignalingState(SignalingState state); void resetCallbacks(); @@ -154,6 +166,7 @@ private: const std::unique_ptr mProcessor; std::optional mLocalDescription, mRemoteDescription; + std::optional mCurrentLocalDescription; mutable std::mutex mLocalDescriptionMutex, mRemoteDescriptionMutex; std::shared_ptr mIceTransport; @@ -168,18 +181,22 @@ private: std::atomic mState; std::atomic mGatheringState; + std::atomic mSignalingState; + std::atomic mNegotiationNeeded; synchronized_callback> mDataChannelCallback; synchronized_callback mLocalDescriptionCallback; synchronized_callback mLocalCandidateCallback; synchronized_callback mStateChangeCallback; synchronized_callback mGatheringStateChangeCallback; + synchronized_callback mSignalingStateChangeCallback; synchronized_callback> mTrackCallback; }; } // namespace rtc -std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::State &state); -std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::GatheringState &state); +std::ostream &operator<<(std::ostream &out, rtc::PeerConnection::State state); +std::ostream &operator<<(std::ostream &out, rtc::PeerConnection::GatheringState state); +std::ostream &operator<<(std::ostream &out, rtc::PeerConnection::SignalingState state); #endif diff --git a/include/rtc/rtc.h b/include/rtc/rtc.h index 24a330b..2367db8 100644 --- a/include/rtc/rtc.h +++ b/include/rtc/rtc.h @@ -59,6 +59,14 @@ typedef enum { RTC_GATHERING_COMPLETE = 2 } rtcGatheringState; +typedef enum { + RTC_SIGNALING_STABLE = 0, + RTC_SIGNALING_HAVE_LOCAL_OFFER = 1, + RTC_SIGNALING_HAVE_REMOTE_OFFER = 2, + RTC_SIGNALING_HAVE_LOCAL_PRANSWER = 3, + RTC_SIGNALING_HAVE_REMOTE_PRANSWER = 4, +} rtcSignalingState; + typedef enum { // Don't change, it must match plog severity RTC_LOG_NONE = 0, RTC_LOG_FATAL = 1, @@ -92,6 +100,7 @@ typedef void (RTC_API *rtcDescriptionCallbackFunc)(int pc, const char *sdp, cons typedef void (RTC_API *rtcCandidateCallbackFunc)(int pc, const char *cand, const char *mid, void *ptr); typedef void (RTC_API *rtcStateChangeCallbackFunc)(int pc, rtcState state, void *ptr); typedef void (RTC_API *rtcGatheringStateCallbackFunc)(int pc, rtcGatheringState state, void *ptr); +typedef void (RTC_API *rtcSignalingStateCallbackFunc)(int pc, rtcSignalingState state, void *ptr); typedef void (RTC_API *rtcDataChannelCallbackFunc)(int pc, int dc, void *ptr); typedef void (RTC_API *rtcTrackCallbackFunc)(int pc, int tr, void *ptr); typedef void (RTC_API *rtcOpenCallbackFunc)(int id, void *ptr); @@ -116,8 +125,9 @@ RTC_EXPORT int rtcSetLocalDescriptionCallback(int pc, rtcDescriptionCallbackFunc RTC_EXPORT int rtcSetLocalCandidateCallback(int pc, rtcCandidateCallbackFunc cb); RTC_EXPORT int rtcSetStateChangeCallback(int pc, rtcStateChangeCallbackFunc cb); RTC_EXPORT int rtcSetGatheringStateChangeCallback(int pc, rtcGatheringStateCallbackFunc cb); +RTC_EXPORT int rtcSetSignalingStateChangeCallback(int pc, rtcSignalingStateCallbackFunc cb); -RTC_EXPORT int rtcSetLocalDescription(int pc); +RTC_EXPORT int rtcSetLocalDescription(int pc, const char *type); RTC_EXPORT int rtcSetRemoteDescription(int pc, const char *sdp, const char *type); RTC_EXPORT int rtcAddRemoteCandidate(int pc, const char *cand, const char *mid); diff --git a/include/rtc/track.hpp b/include/rtc/track.hpp index 5c55b57..b43e91c 100644 --- a/include/rtc/track.hpp +++ b/include/rtc/track.hpp @@ -43,6 +43,8 @@ public: string mid() const; Description::Media description() const; + void setDescription(Description::Media description); + void close(void) override; bool send(message_variant data) override; bool send(const byte *data, size_t size); diff --git a/src/capi.cpp b/src/capi.cpp index c2cae6f..671ff4c 100644 --- a/src/capi.cpp +++ b/src/capi.cpp @@ -317,7 +317,7 @@ int rtcCreateDataChannel(int pc, const char *label) { int rtcCreateDataChannelExt(int pc, const char *label, const char *protocol, const rtcReliability *reliability) { int dc = rtcAddDataChannelExt(pc, label, protocol, reliability); - rtcSetLocalDescription(pc); + rtcSetLocalDescription(pc, NULL); return dc; } @@ -468,6 +468,19 @@ int rtcSetGatheringStateChangeCallback(int pc, rtcGatheringStateCallbackFunc cb) }); } +int rtcSetSignalingStateChangeCallback(int pc, rtcSignalingStateCallbackFunc cb) { + return WRAP({ + auto peerConnection = getPeerConnection(pc); + if (cb) + peerConnection->onSignalingStateChange([pc, cb](PeerConnection::SignalingState state) { + if (auto ptr = getUserPointer(pc)) + cb(pc, static_cast(state), *ptr); + }); + else + peerConnection->onGatheringStateChange(nullptr); + }); +} + int rtcSetDataChannelCallback(int pc, rtcDataChannelCallbackFunc cb) { return WRAP({ auto peerConnection = getPeerConnection(pc); @@ -500,10 +513,11 @@ int rtcSetTrackCallback(int pc, rtcTrackCallbackFunc cb) { }); } -int rtcSetLocalDescription(int pc) { +int rtcSetLocalDescription(int pc, const char *type) { return WRAP({ auto peerConnection = getPeerConnection(pc); - peerConnection->setLocalDescription(); + peerConnection->setLocalDescription(type ? Description::stringToType(type) + : Description::Type::Unspec); }); } diff --git a/src/certificate.cpp b/src/certificate.cpp index 927cab8..f13e8ec 100644 --- a/src/certificate.cpp +++ b/src/certificate.cpp @@ -99,7 +99,11 @@ certificate_ptr make_certificate_impl(string commonName) { unique_ptr crt(new_crt(), free_crt); unique_ptr privkey(new_privkey(), free_privkey); +#ifdef RSA_KEY_BITS_2048 + const unsigned int bits = 2048; +#else const unsigned int bits = gnutls_sec_param_to_pk_bits(GNUTLS_PK_RSA, GNUTLS_SEC_PARAM_HIGH); +#endif gnutls::check(gnutls_x509_privkey_generate(*privkey, GNUTLS_PK_RSA, bits, 0), "Unable to generate key pair"); @@ -190,7 +194,11 @@ certificate_ptr make_certificate_impl(string commonName) { if (!x509 || !pkey || !rsa || !exponent || !serial_number || !name) throw std::runtime_error("Unable allocate structures for certificate generation"); - const int bits = 4096; +#ifdef RSA_KEY_BITS_2048 + const int bits = 2048; +#else + const int bits = 3072; +#endif const unsigned int e = 65537; // 2^16 + 1 if (!pkey || !rsa || !exponent || !BN_set_word(exponent.get(), e) || diff --git a/src/description.cpp b/src/description.cpp index 2d50654..e1eb75e 100644 --- a/src/description.cpp +++ b/src/description.cpp @@ -26,6 +26,7 @@ #include #include #include +#include using std::shared_ptr; using std::size_t; @@ -129,12 +130,6 @@ Description::Description(const string &sdp, Type type, Role role) } } - if (mIceUfrag.empty()) - throw std::invalid_argument("Missing ice-ufrag parameter in SDP description"); - - if (mIcePwd.empty()) - throw std::invalid_argument("Missing ice-pwd parameter in SDP description"); - if (mUsername.empty()) mUsername = "rtc"; @@ -152,16 +147,14 @@ string Description::typeString() const { return typeToString(mType); } Description::Role Description::role() const { return mRole; } -string Description::roleString() const { return roleToString(mRole); } - string Description::bundleMid() const { // Get the mid of the first media return !mEntries.empty() ? mEntries[0]->mid() : "0"; } -string Description::iceUfrag() const { return mIceUfrag; } +std::optional Description::iceUfrag() const { return mIceUfrag; } -string Description::icePwd() const { return mIcePwd; } +std::optional Description::icePwd() const { return mIcePwd; } std::optional Description::fingerprint() const { return mFingerprint; } @@ -183,6 +176,11 @@ void Description::addCandidate(Candidate candidate) { mCandidates.emplace_back(std::move(candidate)); } +void Description::addCandidates(std::vector candidates) { + for(auto candidate : candidates) + mCandidates.emplace_back(std::move(candidate)); +} + void Description::endCandidates() { mEnded = true; } std::vector Description::extractCandidates() { @@ -222,13 +220,14 @@ string Description::generateSdp(string_view eol) const { // Session-level attributes sdp << "a=msid-semantic:WMS *" << eol; - sdp << "a=setup:" << roleToString(mRole) << eol; - sdp << "a=ice-ufrag:" << mIceUfrag << eol; - sdp << "a=ice-pwd:" << mIcePwd << eol; + sdp << "a=setup:" << mRole << eol; + if (mIceUfrag) + sdp << "a=ice-ufrag:" << *mIceUfrag << eol; + if (mIcePwd) + sdp << "a=ice-pwd:" << *mIcePwd << eol; if (!mEnded) sdp << "a=ice-options:trickle" << eol; - if (mFingerprint) sdp << "a=fingerprint:sha-256 " << *mFingerprint << eol; @@ -281,13 +280,14 @@ string Description::generateApplicationSdp(string_view eol) const { // Session-level attributes sdp << "a=msid-semantic:WMS *" << eol; - sdp << "a=setup:" << roleToString(mRole) << eol; - sdp << "a=ice-ufrag:" << mIceUfrag << eol; - sdp << "a=ice-pwd:" << mIcePwd << eol; + sdp << "a=setup:" << mRole << eol; + if (mIceUfrag) + sdp << "a=ice-ufrag:" << *mIceUfrag << eol; + if (mIcePwd) + sdp << "a=ice-pwd:" << *mIcePwd << eol; if (!mEnded) sdp << "a=ice-options:trickle" << eol; - if (mFingerprint) sdp << "a=fingerprint:sha-256 " << *mFingerprint << eol; @@ -351,6 +351,14 @@ bool Description::hasAudioOrVideo() const { return false; } +bool Description::hasMid(string_view mid) const { + for (const auto &entry : mEntries) + if (entry->mid() == mid) + return true; + + return false; +} + int Description::addMedia(Media media) { mEntries.emplace_back(std::make_shared(std::move(media))); return int(mEntries.size()) - 1; @@ -767,33 +775,30 @@ Description::Video::Video(string mid, Direction dir) : Media("video 9 UDP/TLS/RTP/SAVPF", std::move(mid), dir) {} Description::Type Description::stringToType(const string &typeString) { - if (typeString == "offer") - return Type::Offer; - else if (typeString == "answer") - return Type::Answer; - else - return Type::Unspec; + using TypeMap_t = std::unordered_map; + static const TypeMap_t TypeMap = {{"unspec", Type::Unspec}, + {"offer", Type::Offer}, + {"answer", Type::Pranswer}, + {"pranswer", Type::Pranswer}, + {"rollback", Type::Rollback}}; + auto it = TypeMap.find(typeString); + return it != TypeMap.end() ? it->second : Type::Unspec; } string Description::typeToString(Type type) { switch (type) { + case Type::Unspec: + return "unspec"; case Type::Offer: return "offer"; case Type::Answer: return "answer"; + case Type::Pranswer: + return "pranswer"; + case Type::Rollback: + return "rollback"; default: - return ""; - } -} - -string Description::roleToString(Role role) { - switch (role) { - case Role::Active: - return "active"; - case Role::Passive: - return "passive"; - default: - return "actpass"; + return "unknown"; } } @@ -802,3 +807,25 @@ string Description::roleToString(Role role) { std::ostream &operator<<(std::ostream &out, const rtc::Description &description) { return out << std::string(description); } + +std::ostream &operator<<(std::ostream &out, rtc::Description::Type type) { + return out << rtc::Description::typeToString(type); +} + +std::ostream &operator<<(std::ostream &out, rtc::Description::Role role) { + using Role = rtc::Description::Role; + const char *str; + // Used for SDP generation, do not change + switch (role) { + case Role::Active: + str = "active"; + break; + case Role::Passive: + str = "passive"; + break; + default: + str = "actpass"; + break; + } + return out << str; +} diff --git a/src/peerconnection.cpp b/src/peerconnection.cpp index 7707a05..049ae6d 100644 --- a/src/peerconnection.cpp +++ b/src/peerconnection.cpp @@ -44,7 +44,8 @@ PeerConnection::PeerConnection() : PeerConnection(Configuration()) {} PeerConnection::PeerConnection(const Configuration &config) : mConfig(config), mCertificate(make_certificate()), mProcessor(std::make_unique()), - mState(State::New), mGatheringState(GatheringState::New) { + mState(State::New), mGatheringState(GatheringState::New), + mSignalingState(SignalingState::Stable), mNegotiationNeeded(false) { PLOG_VERBOSE << "Creating PeerConnection"; if (config.portRangeEnd && config.portRangeBegin > config.portRangeEnd) @@ -60,6 +61,8 @@ PeerConnection::~PeerConnection() { void PeerConnection::close() { PLOG_VERBOSE << "Closing PeerConnection"; + mNegotiationNeeded = false; + // Close data channels asynchronously mProcessor->enqueue(std::bind(&PeerConnection::closeDataChannels, this)); @@ -72,6 +75,8 @@ PeerConnection::State PeerConnection::state() const { return mState; } PeerConnection::GatheringState PeerConnection::gatheringState() const { return mGatheringState; } +PeerConnection::SignalingState PeerConnection::signalingState() const { return mSignalingState; } + std::optional PeerConnection::localDescription() const { std::lock_guard lock(mLocalDescriptionMutex); return mLocalDescription; @@ -97,88 +102,178 @@ bool PeerConnection::hasMedia() const { return local && local->hasAudioOrVideo(); } -void PeerConnection::setLocalDescription() { - PLOG_VERBOSE << "Setting local description"; +void PeerConnection::setLocalDescription(Description::Type type) { + PLOG_VERBOSE << "Setting local description, type=" << Description::typeToString(type); - if (std::atomic_load(&mIceTransport)) { - PLOG_DEBUG << "Local description is already set, ignoring"; + SignalingState signalingState = mSignalingState.load(); + if (type == Description::Type::Rollback) { + if (signalingState == SignalingState::HaveLocalOffer || + signalingState == SignalingState::HaveLocalPranswer) { + PLOG_DEBUG << "Rolling back pending local description"; + + std::unique_lock lock(mLocalDescriptionMutex); + if (mCurrentLocalDescription) { + std::vector existingCandidates; + if (mLocalDescription) + existingCandidates = mLocalDescription->extractCandidates(); + + mLocalDescription.emplace(std::move(*mCurrentLocalDescription)); + mLocalDescription->addCandidates(std::move(existingCandidates)); + mCurrentLocalDescription.reset(); + } + lock.unlock(); + + changeSignalingState(SignalingState::Stable); + } return; } - // RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of - // setup:actpass. - // See https://tools.ietf.org/html/rfc5763#section-5 - auto iceTransport = initIceTransport(Description::Role::ActPass); - Description localDescription = iceTransport->getLocalDescription(Description::Type::Offer); - processLocalDescription(localDescription); - iceTransport->gatherLocalCandidates(); + // Guess the description type if unspecified + if (type == Description::Type::Unspec) { + if (mSignalingState == SignalingState::HaveRemoteOffer) + type = Description::Type::Answer; + else + type = Description::Type::Offer; + } + + // Only a local offer resets the negotiation needed flag + if (type == Description::Type::Offer && !mNegotiationNeeded.exchange(false)) { + PLOG_DEBUG << "No negotiation needed"; + return; + } + + // Get the new signaling state + SignalingState newSignalingState; + switch (signalingState) { + case SignalingState::Stable: + if (type != Description::Type::Offer) { + std::ostringstream oss; + oss << "Unexpected local desciption type " << type << " in signaling state " + << signalingState; + throw std::logic_error(oss.str()); + } + newSignalingState = SignalingState::HaveLocalOffer; + break; + + case SignalingState::HaveRemoteOffer: + case SignalingState::HaveLocalPranswer: + if (type != Description::Type::Answer && type != Description::Type::Pranswer) { + std::ostringstream oss; + oss << "Unexpected local description type " << type + << " description in signaling state " << signalingState; + throw std::logic_error(oss.str()); + } + newSignalingState = SignalingState::Stable; + break; + + default: { + std::ostringstream oss; + oss << "Unexpected local description in signaling state " << signalingState << ", ignoring"; + LOG_WARNING << oss.str(); + return; + } + } + + auto iceTransport = std::atomic_load(&mIceTransport); + if (!iceTransport) { + // RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of + // setup:actpass. + // See https://tools.ietf.org/html/rfc5763#section-5 + iceTransport = initIceTransport(Description::Role::ActPass); + } + + Description localDescription = iceTransport->getLocalDescription(type); + processLocalDescription(std::move(localDescription)); + + changeSignalingState(newSignalingState); + + if (mGatheringState == GatheringState::New) + iceTransport->gatherLocalCandidates(); } void PeerConnection::setRemoteDescription(Description description) { PLOG_VERBOSE << "Setting remote description: " << string(description); - if (hasRemoteDescription()) - throw std::logic_error("Remote description is already set"); + if (description.type() == Description::Type::Rollback) { + // This is mostly useless because we accept any offer + PLOG_VERBOSE << "Rolling back pending remote description"; + changeSignalingState(SignalingState::Stable); + return; + } - if (description.mediaCount() == 0) - throw std::invalid_argument("Remote description has no media line"); + validateRemoteDescription(description); - int activeMediaCount = 0; - for (int i = 0; i < description.mediaCount(); ++i) - std::visit( // reciprocate each media - rtc::overloaded{[&](Description::Application *) { ++activeMediaCount; }, - [&](Description::Media *media) { - if (media->direction() != Description::Direction::Inactive) - ++activeMediaCount; - }}, - description.media(i)); - - if (activeMediaCount == 0) - throw std::invalid_argument("Remote description has no active media"); - - if (!description.fingerprint()) - throw std::invalid_argument("Remote description has no fingerprint"); - - description.hintType(hasLocalDescription() ? Description::Type::Answer - : Description::Type::Offer); - - if (description.type() == Description::Type::Offer) { - if (hasLocalDescription()) { - PLOG_ERROR << "Got a remote offer description while an answer was expected"; - throw std::logic_error("Got an unexpected remote offer description"); + // Get the new signaling state + SignalingState signalingState = mSignalingState.load(); + SignalingState newSignalingState; + switch (signalingState) { + case SignalingState::Stable: + description.hintType(Description::Type::Offer); + if (description.type() != Description::Type::Offer) { + std::ostringstream oss; + oss << "Unexpected remote " << description.type() << " description in signaling state " + << signalingState; + throw std::logic_error(oss.str()); } - } else { // Answer - if (auto local = localDescription()) { - if (description.iceUfrag() == local->iceUfrag() && - description.icePwd() == local->icePwd()) - throw std::logic_error("Got the local description as remote description"); - } else { - PLOG_ERROR << "Got a remote answer description while an offer was expected"; - throw std::logic_error("Got an unexpected remote answer description"); + newSignalingState = SignalingState::HaveRemoteOffer; + break; + + case SignalingState::HaveLocalOffer: + description.hintType(Description::Type::Answer); + if (description.type() == Description::Type::Offer) { + // The ICE agent will automatically initiate a rollback when a peer that had previously + // created an offer receives an offer from the remote peer + setLocalDescription(Description::Type::Rollback); + newSignalingState = SignalingState::HaveRemoteOffer; + break; } + if (description.type() != Description::Type::Answer && + description.type() != Description::Type::Pranswer) { + std::ostringstream oss; + oss << "Unexpected remote " << description.type() << " description in signaling state " + << signalingState; + throw std::logic_error(oss.str()); + } + newSignalingState = SignalingState::Stable; + break; + + case SignalingState::HaveRemotePranswer: + description.hintType(Description::Type::Answer); + if (description.type() != Description::Type::Answer && + description.type() != Description::Type::Pranswer) { + std::ostringstream oss; + oss << "Unexpected remote " << description.type() << " description in signaling state " + << signalingState; + throw std::logic_error(oss.str()); + } + newSignalingState = SignalingState::Stable; + break; + + default: { + std::ostringstream oss; + oss << "Unexpected remote description in signaling state " << signalingState; + throw std::logic_error(oss.str()); + } } // Candidates will be added at the end, extract them for now auto remoteCandidates = description.extractCandidates(); + auto type = description.type(); auto iceTransport = std::atomic_load(&mIceTransport); if (!iceTransport) iceTransport = initIceTransport(Description::Role::ActPass); + iceTransport->setRemoteDescription(description); + processRemoteDescription(std::move(description)); - { - // Set as remote description - std::lock_guard lock(mRemoteDescriptionMutex); - mRemoteDescription.emplace(std::move(description)); - } + changeSignalingState(newSignalingState); - if (description.type() == Description::Type::Offer) { - // This is an offer and we are the answerer. - Description localDescription = iceTransport->getLocalDescription(Description::Type::Answer); - processLocalDescription(localDescription); - iceTransport->gatherLocalCandidates(); + if (type == Description::Type::Offer) { + // This is an offer, we need to answer + setLocalDescription(Description::Type::Answer); } else { - // This is an answer and we are the offerer. + // This is an answer auto sctpTransport = std::atomic_load(&mSctpTransport); if (!sctpTransport && iceTransport->role() == Description::Role::Active) { // Since we assumed passive role during DataChannel creation, we need to shift the @@ -203,27 +298,7 @@ void PeerConnection::setRemoteDescription(Description description) { void PeerConnection::addRemoteCandidate(Candidate candidate) { PLOG_VERBOSE << "Adding remote candidate: " << string(candidate); - - auto iceTransport = std::atomic_load(&mIceTransport); - if (!mRemoteDescription || !iceTransport) - throw std::logic_error("Remote candidate set without remote description"); - - if (candidate.resolve(Candidate::ResolveMode::Simple)) { - iceTransport->addRemoteCandidate(candidate); - } else { - // OK, we might need a lookup, do it asynchronously - // We don't use the thread pool because we have no control on the timeout - weak_ptr weakIceTransport{iceTransport}; - std::thread t([weakIceTransport, candidate]() mutable { - if (candidate.resolve(Candidate::ResolveMode::Lookup)) - if (auto iceTransport = weakIceTransport.lock()) - iceTransport->addRemoteCandidate(candidate); - }); - t.detach(); - } - - std::lock_guard lock(mRemoteDescriptionMutex); - mRemoteDescription->addCandidate(candidate); + processRemoteCandidate(std::move(candidate)); } std::optional PeerConnection::localAddress() const { @@ -238,11 +313,6 @@ std::optional PeerConnection::remoteAddress() const { shared_ptr PeerConnection::addDataChannel(string label, string protocol, Reliability reliability) { - if (auto local = localDescription(); local && !local->hasApplication()) { - PLOG_ERROR << "The PeerConnection was negociated without DataChannel support."; - throw std::runtime_error("No DataChannel support on the PeerConnection"); - } - // RFC 5763: The answerer MUST use either a setup attribute value of setup:active or // setup:passive. [...] Thus, setup:active is RECOMMENDED. // See https://tools.ietf.org/html/rfc5763#section-5 @@ -257,6 +327,11 @@ shared_ptr PeerConnection::addDataChannel(string label, string prot if (transport->state() == SctpTransport::State::Connected) channel->open(transport); + // Renegotiation is needed iff the current local description does not have application + std::lock_guard lock(mLocalDescriptionMutex); + if (!mLocalDescription || !mLocalDescription->hasApplication()) + mNegotiationNeeded = true; + return channel; } @@ -288,21 +363,30 @@ void PeerConnection::onGatheringStateChange(std::function callback) { + mSignalingStateChangeCallback = callback; +} + std::shared_ptr PeerConnection::addTrack(Description::Media description) { - if (hasLocalDescription()) - throw std::logic_error("Tracks must be created before local description"); - - if (auto it = mTracks.find(description.mid()); it != mTracks.end()) - if (auto track = it->second.lock()) - return track; - #if !RTC_ENABLE_MEDIA if (mTracks.empty()) { PLOG_WARNING << "Tracks will be inative (not compiled with SRTP support)"; } #endif - auto track = std::make_shared(std::move(description)); - mTracks.emplace(std::make_pair(track->mid(), track)); + + std::shared_ptr track; + if (auto it = mTracks.find(description.mid()); it != mTracks.end()) + if (track = it->second.lock(); track) + track->setDescription(std::move(description)); + + if (!track) { + track = std::make_shared(std::move(description)); + mTracks.emplace(std::make_pair(track->mid(), track)); + } + + // Renegotiation is needed for the new or updated track + mNegotiationNeeded = true; + return track; } @@ -311,6 +395,7 @@ void PeerConnection::onTrack(std::function)> callbac } shared_ptr PeerConnection::initIceTransport(Description::Role role) { + PLOG_VERBOSE << "Starting ICE transport"; try { if (auto transport = std::atomic_load(&mIceTransport)) return transport; @@ -373,6 +458,7 @@ shared_ptr PeerConnection::initIceTransport(Description::Role role } shared_ptr PeerConnection::initDtlsTransport() { + PLOG_VERBOSE << "Starting DTLS transport"; try { if (auto transport = std::atomic_load(&mDtlsTransport)) return transport; @@ -388,12 +474,12 @@ shared_ptr PeerConnection::initDtlsTransport() { switch (state) { case DtlsTransport::State::Connected: - if (auto local = localDescription(); local && local->hasApplication()) + if (auto remote = remoteDescription(); remote && remote->hasApplication()) initSctpTransport(); else changeState(State::Connected); - openTracks(); + mProcessor->enqueue(std::bind(&PeerConnection::openTracks, this)); break; case DtlsTransport::State::Failed: changeState(State::Failed); @@ -443,42 +529,43 @@ shared_ptr PeerConnection::initDtlsTransport() { } shared_ptr PeerConnection::initSctpTransport() { + PLOG_VERBOSE << "Starting SCTP transport"; try { if (auto transport = std::atomic_load(&mSctpTransport)) return transport; auto remote = remoteDescription(); if (!remote || !remote->application()) - throw std::logic_error("Initializing SCTP transport without application description"); + throw std::logic_error("Starting SCTP transport without application description"); uint16_t sctpPort = remote->application()->sctpPort().value_or(DEFAULT_SCTP_PORT); auto lower = std::atomic_load(&mDtlsTransport); auto transport = std::make_shared( - lower, sctpPort, weak_bind(&PeerConnection::forwardMessage, this, _1), - weak_bind(&PeerConnection::forwardBufferedAmount, this, _1, _2), - [this, weak_this = weak_from_this()](SctpTransport::State state) { - auto shared_this = weak_this.lock(); - if (!shared_this) - return; - switch (state) { - case SctpTransport::State::Connected: - changeState(State::Connected); - mProcessor->enqueue(std::bind(&PeerConnection::openDataChannels, this)); - break; - case SctpTransport::State::Failed: - LOG_WARNING << "SCTP transport failed"; - changeState(State::Failed); - mProcessor->enqueue(std::bind(&PeerConnection::remoteCloseDataChannels, this)); - break; - case SctpTransport::State::Disconnected: - changeState(State::Disconnected); - mProcessor->enqueue(std::bind(&PeerConnection::remoteCloseDataChannels, this)); - break; - default: - // Ignore - break; - } - }); + lower, sctpPort, weak_bind(&PeerConnection::forwardMessage, this, _1), + weak_bind(&PeerConnection::forwardBufferedAmount, this, _1, _2), + [this, weak_this = weak_from_this()](SctpTransport::State state) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (state) { + case SctpTransport::State::Connected: + changeState(State::Connected); + mProcessor->enqueue(std::bind(&PeerConnection::openDataChannels, this)); + break; + case SctpTransport::State::Failed: + LOG_WARNING << "SCTP transport failed"; + changeState(State::Failed); + mProcessor->enqueue(std::bind(&PeerConnection::remoteCloseDataChannels, this)); + break; + case SctpTransport::State::Disconnected: + changeState(State::Disconnected); + mProcessor->enqueue(std::bind(&PeerConnection::remoteCloseDataChannels, this)); + break; + default: + // Ignore + break; + } + }); std::atomic_store(&mSctpTransport, transport); if (mState == State::Closed) { @@ -499,7 +586,8 @@ void PeerConnection::closeTransports() { PLOG_VERBOSE << "Closing transports"; // Change state to sink state Closed - changeState(State::Closed); + if (!changeState(State::Closed)) + return; // already closed // Reset callbacks now that state is changed resetCallbacks(); @@ -723,40 +811,105 @@ void PeerConnection::openTracks() { std::shared_lock lock(mTracksMutex); // read-only for (auto it = mTracks.begin(); it != mTracks.end(); ++it) if (auto track = it->second.lock()) - track->open(srtpTransport); + if (!track->isOpen()) + track->open(srtpTransport); } #endif } +void PeerConnection::validateRemoteDescription(const Description &description) { + if (!description.iceUfrag()) + throw std::invalid_argument("Remote description has no ICE user fragment"); + + if (!description.icePwd()) + throw std::invalid_argument("Remote description has no ICE password"); + + if (!description.fingerprint()) + throw std::invalid_argument("Remote description has no fingerprint"); + + if (description.mediaCount() == 0) + throw std::invalid_argument("Remote description has no media line"); + + int activeMediaCount = 0; + for (int i = 0; i < description.mediaCount(); ++i) + std::visit(rtc::overloaded{[&](const Description::Application *) { ++activeMediaCount; }, + [&](const Description::Media *media) { + if (media->direction() != Description::Direction::Inactive) + ++activeMediaCount; + }}, + description.media(i)); + + if (activeMediaCount == 0) + throw std::invalid_argument("Remote description has no active media"); + + if (auto local = localDescription(); local && local->iceUfrag() && local->icePwd()) + if (*description.iceUfrag() == *local->iceUfrag() && + *description.icePwd() == *local->icePwd()) + throw std::logic_error("Got the local description as remote description"); + + PLOG_VERBOSE << "Remote description looks valid"; +} void PeerConnection::processLocalDescription(Description description) { - int activeMediaCount = 0; - - if (hasLocalDescription()) - throw std::logic_error("Local description is already set"); - if (auto remote = remoteDescription()) { // Reciprocate remote description for (int i = 0; i < remote->mediaCount(); ++i) std::visit( // reciprocate each media rtc::overloaded{ - [&](Description::Application *app) { - auto reciprocated = app->reciprocate(); + [&](Description::Application *remoteApp) { + std::shared_lock lock(mDataChannelsMutex); + if (!mDataChannels.empty()) { + // Prefer local description + Description::Application app(remoteApp->mid()); + app.setSctpPort(DEFAULT_SCTP_PORT); + app.setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE); + + PLOG_DEBUG << "Adding application to local description, mid=\"" + << app.mid() << "\""; + + description.addMedia(std::move(app)); + return; + } + + auto reciprocated = remoteApp->reciprocate(); reciprocated.hintSctpPort(DEFAULT_SCTP_PORT); reciprocated.setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE); - ++activeMediaCount; PLOG_DEBUG << "Reciprocating application in local description, mid=\"" << reciprocated.mid() << "\""; description.addMedia(std::move(reciprocated)); }, - [&](Description::Media *media) { - auto reciprocated = media->reciprocate(); -#if RTC_ENABLE_MEDIA - if (reciprocated.direction() != Description::Direction::Inactive) - ++activeMediaCount; -#else + [&](Description::Media *remoteMedia) { + std::shared_lock lock(mTracksMutex); + if (auto it = mTracks.find(remoteMedia->mid()); it != mTracks.end()) { + // Prefer local description + if (auto track = it->second.lock()) { + auto media = track->description(); +#if !RTC_ENABLE_MEDIA + // No media support, mark as inactive + media.setDirection(Description::Direction::Inactive); +#endif + PLOG_DEBUG + << "Adding media to local description, mid=\"" << media.mid() + << "\", active=" << std::boolalpha + << (media.direction() != Description::Direction::Inactive); + + description.addMedia(std::move(media)); + } else { + auto reciprocated = remoteMedia->reciprocate(); + reciprocated.setDirection(Description::Direction::Inactive); + + PLOG_DEBUG << "Adding inactive media to local description, mid=\"" + << reciprocated.mid() << "\""; + + description.addMedia(std::move(reciprocated)); + } + return; + } + + auto reciprocated = remoteMedia->reciprocate(); +#if !RTC_ENABLE_MEDIA // No media support, mark as inactive reciprocated.setDirection(Description::Direction::Inactive); #endif @@ -771,15 +924,17 @@ void PeerConnection::processLocalDescription(Description description) { }, }, remote->media(i)); - } else { + } + + if (description.type() == Description::Type::Offer) { + // This is an offer, add locally created data channels and tracks // Add application for data channels - { + if (!description.hasApplication()) { std::shared_lock lock(mDataChannelsMutex); if (!mDataChannels.empty()) { Description::Application app("data"); app.setSctpPort(DEFAULT_SCTP_PORT); app.setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE); - ++activeMediaCount; PLOG_DEBUG << "Adding application to local description, mid=\"" << app.mid() << "\""; @@ -789,45 +944,52 @@ void PeerConnection::processLocalDescription(Description description) { } // Add media for local tracks - { - std::shared_lock lock(mTracksMutex); - for (auto it = mTracks.begin(); it != mTracks.end(); ++it) { - if (auto track = it->second.lock()) { - auto media = track->description(); -#if RTC_ENABLE_MEDIA - if (media.direction() != Description::Direction::Inactive) - ++activeMediaCount; -#else - // No media support, mark as inactive - media.setDirection(Description::Direction::Inactive); -#endif - PLOG_DEBUG << "Adding media to local description, mid=\"" << media.mid() - << "\", active=" << std::boolalpha - << (media.direction() != Description::Direction::Inactive); + std::shared_lock lock(mTracksMutex); + for (auto it = mTracks.begin(); it != mTracks.end(); ++it) { + if (description.hasMid(it->first)) + continue; - description.addMedia(std::move(media)); - } + if (auto track = it->second.lock()) { + auto media = track->description(); +#if !RTC_ENABLE_MEDIA + // No media support, mark as inactive + media.setDirection(Description::Direction::Inactive); +#endif + PLOG_DEBUG << "Adding media to local description, mid=\"" << media.mid() + << "\", active=" << std::boolalpha + << (media.direction() != Description::Direction::Inactive); + + description.addMedia(std::move(media)); } } } - // There must be at least one active media to negociate - if (activeMediaCount == 0) - throw std::runtime_error("Nothing to negociate"); - // Set local fingerprint (wait for certificate if necessary) description.setFingerprint(mCertificate.get()->fingerprint()); { // Set as local description std::lock_guard lock(mLocalDescriptionMutex); + + std::vector existingCandidates; + if (mLocalDescription) { + existingCandidates = mLocalDescription->extractCandidates(); + mCurrentLocalDescription.emplace(std::move(*mLocalDescription)); + } + mLocalDescription.emplace(std::move(description)); + mLocalDescription->addCandidates(std::move(existingCandidates)); } mProcessor->enqueue([this, description = *mLocalDescription]() { PLOG_VERBOSE << "Issuing local description: " << description; mLocalDescriptionCallback(std::move(description)); }); + + // Reciprocated tracks might need to be open + if (auto dtlsTransport = std::atomic_load(&mDtlsTransport); + dtlsTransport && dtlsTransport->state() == Transport::State::Connected) + mProcessor->enqueue(std::bind(&PeerConnection::openTracks, this)); } void PeerConnection::processLocalCandidate(Candidate candidate) { @@ -844,6 +1006,56 @@ void PeerConnection::processLocalCandidate(Candidate candidate) { }); } +void PeerConnection::processRemoteDescription(Description description) { + { + // Set as remote description + std::lock_guard lock(mRemoteDescriptionMutex); + + std::vector existingCandidates; + if (mRemoteDescription) + existingCandidates = mRemoteDescription->extractCandidates(); + + mRemoteDescription.emplace(std::move(description)); + mRemoteDescription->addCandidates(std::move(existingCandidates)); + } + + if (description.hasApplication()) { + auto dtlsTransport = std::atomic_load(&mDtlsTransport); + auto sctpTransport = std::atomic_load(&mSctpTransport); + if (!sctpTransport && dtlsTransport && + dtlsTransport->state() == Transport::State::Connected) + initSctpTransport(); + } +} + +void PeerConnection::processRemoteCandidate(Candidate candidate) { + auto iceTransport = std::atomic_load(&mIceTransport); + if (!iceTransport) + throw std::logic_error("Remote candidate set without remote description"); + + if (candidate.resolve(Candidate::ResolveMode::Simple)) { + iceTransport->addRemoteCandidate(candidate); + } else { + // OK, we might need a lookup, do it asynchronously + // We don't use the thread pool because we have no control on the timeout + weak_ptr weakIceTransport{iceTransport}; + std::thread t([weakIceTransport, candidate]() mutable { + if (candidate.resolve(Candidate::ResolveMode::Lookup)) + if (auto iceTransport = weakIceTransport.lock()) + iceTransport->addRemoteCandidate(candidate); + }); + t.detach(); + } + + { + std::lock_guard lock(mRemoteDescriptionMutex); + if (!mRemoteDescription) + throw std::logic_error("Got a remote candidate without remote description"); + + mRemoteDescription->addCandidate(candidate); + } +} + void PeerConnection::triggerDataChannel(weak_ptr weakDataChannel) { auto dataChannel = weakDataChannel.lock(); if (!dataChannel) @@ -861,10 +1073,10 @@ bool PeerConnection::changeState(State state) { State current; do { current = mState.load(); - if (current == state) - return true; if (current == State::Closed) return false; + if (current == state) + return false; } while (!mState.compare_exchange_weak(current, state)); @@ -882,12 +1094,24 @@ bool PeerConnection::changeState(State state) { } bool PeerConnection::changeGatheringState(GatheringState state) { - if (mGatheringState.exchange(state) != state) { - std::ostringstream s; - s << state; - PLOG_INFO << "Changed gathering state to " << s.str(); - mProcessor->enqueue([this, state] { mGatheringStateChangeCallback(state); }); - } + if (mGatheringState.exchange(state) == state) + return false; + + std::ostringstream s; + s << state; + PLOG_INFO << "Changed gathering state to " << s.str(); + mProcessor->enqueue([this, state] { mGatheringStateChangeCallback(state); }); + return true; +} + +bool PeerConnection::changeSignalingState(SignalingState state) { + if (mSignalingState.exchange(state) == state) + return false; + + std::ostringstream s; + s << state; + PLOG_INFO << "Changed signaling state to " << s.str(); + mProcessor->enqueue([this, state] { mSignalingStateChangeCallback(state); }); return true; } @@ -930,15 +1154,14 @@ std::optional PeerConnection::rtt() { auto sctpTransport = std::atomic_load(&mSctpTransport); if (sctpTransport) return sctpTransport->rtt(); - PLOG_WARNING << "Could not load sctpTransport"; return std::nullopt; } } // namespace rtc -std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::State &state) { +std::ostream &operator<<(std::ostream &out, rtc::PeerConnection::State state) { using State = rtc::PeerConnection::State; - std::string str; + const char *str; switch (state) { case State::New: str = "new"; @@ -965,15 +1188,15 @@ std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::State &st return out << str; } -std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::GatheringState &state) { +std::ostream &operator<<(std::ostream &out, rtc::PeerConnection::GatheringState state) { using GatheringState = rtc::PeerConnection::GatheringState; - std::string str; + const char *str; switch (state) { case GatheringState::New: str = "new"; break; case GatheringState::InProgress: - str = "in_progress"; + str = "in-progress"; break; case GatheringState::Complete: str = "complete"; @@ -984,3 +1207,29 @@ std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::Gathering } return out << str; } + +std::ostream &operator<<(std::ostream &out, rtc::PeerConnection::SignalingState state) { + using SignalingState = rtc::PeerConnection::SignalingState; + const char *str; + switch (state) { + case SignalingState::Stable: + str = "stable"; + break; + case SignalingState::HaveLocalOffer: + str = "have-local-offer"; + break; + case SignalingState::HaveRemoteOffer: + str = "have-remote-offer"; + break; + case SignalingState::HaveLocalPranswer: + str = "have-local-pranswer"; + break; + case SignalingState::HaveRemotePranswer: + str = "have-remote-pranswer"; + break; + default: + str = "unknown"; + break; + } + return out << str; +} diff --git a/src/processor.hpp b/src/processor.hpp index 69253b4..466f7b1 100644 --- a/src/processor.hpp +++ b/src/processor.hpp @@ -45,7 +45,7 @@ public: void join(); template - auto enqueue(F &&f, Args &&... args) -> invoke_future_t; + void enqueue(F &&f, Args &&... args); protected: void schedule(); @@ -60,31 +60,20 @@ protected: std::condition_variable mCondition; }; -template -auto Processor::enqueue(F &&f, Args &&... args) -> invoke_future_t { +template void Processor::enqueue(F &&f, Args &&... args) { std::unique_lock lock(mMutex); - using R = std::invoke_result_t, std::decay_t...>; - auto task = std::make_shared>( - std::bind(std::forward(f), std::forward(args)...)); - std::future result = task->get_future(); - - auto bundle = [this, task = std::move(task)]() { - try { - (*task)(); - } catch (const std::exception &e) { - PLOG_WARNING << "Unhandled exception in task: " << e.what(); - } - schedule(); // chain the next task + auto bound = std::bind(std::forward(f), std::forward(args)...); + auto task = [this, bound = std::move(bound)]() mutable { + scope_guard guard(std::bind(&Processor::schedule, this)); // chain the next task + return bound(); }; if (!mPending) { - ThreadPool::Instance().enqueue(std::move(bundle)); + ThreadPool::Instance().enqueue(std::move(task)); mPending = true; } else { - mTasks.emplace(std::move(bundle)); + mTasks.emplace(std::move(task)); } - - return result; } } // namespace rtc diff --git a/src/threadpool.cpp b/src/threadpool.cpp index df8cb08..0d5c6b5 100644 --- a/src/threadpool.cpp +++ b/src/threadpool.cpp @@ -58,11 +58,7 @@ void ThreadPool::run() { bool ThreadPool::runOne() { if (auto task = dequeue()) { - try { - task(); - } catch (const std::exception &e) { - PLOG_WARNING << "Unhandled exception in task: " << e.what(); - } + task(); return true; } return false; diff --git a/src/threadpool.hpp b/src/threadpool.hpp index 2d1a426..3ca1464 100644 --- a/src/threadpool.hpp +++ b/src/threadpool.hpp @@ -73,8 +73,15 @@ template auto ThreadPool::enqueue(F &&f, Args &&... args) -> invoke_future_t { std::unique_lock lock(mMutex); using R = std::invoke_result_t, std::decay_t...>; - auto task = std::make_shared>( - std::bind(std::forward(f), std::forward(args)...)); + auto bound = std::bind(std::forward(f), std::forward(args)...); + auto task = std::make_shared>([bound = std::move(bound)]() mutable { + try { + return bound(); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + throw; + } + }); std::future result = task->get_future(); mTasks.emplace([task = std::move(task), token = Init::Token()]() { return (*task)(); }); diff --git a/src/track.cpp b/src/track.cpp index d6a8465..05b8e04 100644 --- a/src/track.cpp +++ b/src/track.cpp @@ -32,6 +32,13 @@ string Track::mid() const { return mMediaDescription.mid(); } Description::Media Track::description() const { return mMediaDescription; } +void Track::setDescription(Description::Media description) { + if(description.mid() != mMediaDescription.mid()) + throw std::logic_error("Media description mid does not match track mid"); + + mMediaDescription = std::move(description); +} + void Track::close() { mIsClosed = true; resetCallbacks(); diff --git a/src/websocket.cpp b/src/websocket.cpp index 354de5f..bab5e74 100644 --- a/src/websocket.cpp +++ b/src/websocket.cpp @@ -159,6 +159,7 @@ void WebSocket::incoming(message_ptr message) { } shared_ptr WebSocket::initTcpTransport() { + PLOG_VERBOSE << "Starting TCP transport"; using State = TcpTransport::State; try { std::lock_guard lock(mInitMutex); @@ -205,6 +206,7 @@ shared_ptr WebSocket::initTcpTransport() { } shared_ptr WebSocket::initTlsTransport() { + PLOG_VERBOSE << "Starting TLS transport"; using State = TlsTransport::State; try { std::lock_guard lock(mInitMutex); @@ -262,6 +264,7 @@ shared_ptr WebSocket::initTlsTransport() { } shared_ptr WebSocket::initWsTransport() { + PLOG_VERBOSE << "Starting WebSocket transport"; using State = WsTransport::State; try { std::lock_guard lock(mInitMutex); @@ -340,6 +343,6 @@ void WebSocket::closeTransports() { }); } - } // namespace rtc +} // namespace rtc #endif diff --git a/test/capi_connectivity.cpp b/test/capi_connectivity.cpp index 4c4204e..a1bed27 100644 --- a/test/capi_connectivity.cpp +++ b/test/capi_connectivity.cpp @@ -34,6 +34,7 @@ static void sleep(unsigned int secs) { Sleep(secs * 1000); } typedef struct { rtcState state; rtcGatheringState gatheringState; + rtcSignalingState signalingState; int pc; int dc; bool connected; @@ -68,6 +69,12 @@ static void RTC_API gatheringStateCallback(int pc, rtcGatheringState state, void printf("Gathering state %d: %d\n", peer == peer1 ? 1 : 2, (int)state); } +static void RTC_API signalingStateCallback(int pc, rtcSignalingState state, void *ptr) { + Peer *peer = (Peer *)ptr; + peer->signalingState = state; + printf("Signaling state %d: %d\n", peer == peer1 ? 1 : 2, (int)state); +} + static void RTC_API openCallback(int id, void *ptr) { Peer *peer = (Peer *)ptr; peer->connected = true; @@ -180,6 +187,12 @@ int test_capi_connectivity_main() { goto error; } + if (peer1->signalingState != RTC_SIGNALING_STABLE || + peer2->signalingState != RTC_SIGNALING_STABLE) { + fprintf(stderr, "Signaling state is not stable\n"); + goto error; + } + if (!peer1->connected || !peer2->connected) { fprintf(stderr, "DataChannel is not connected\n"); goto error; @@ -236,7 +249,6 @@ int test_capi_connectivity_main() { } printf("Remote address 2: %s\n", buffer); - if (rtcGetSelectedCandidatePair(peer1->pc, buffer, BUFFER_SIZE, buffer2, BUFFER_SIZE) < 0) { fprintf(stderr, "rtcGetSelectedCandidatePair failed\n"); goto error; @@ -251,7 +263,6 @@ int test_capi_connectivity_main() { printf("Local candidate 2: %s\n", buffer); printf("Remote candidate 2: %s\n", buffer2); - deletePeer(peer1); sleep(1); deletePeer(peer2); diff --git a/test/capi_track.cpp b/test/capi_track.cpp index c83a756..3696ab7 100644 --- a/test/capi_track.cpp +++ b/test/capi_track.cpp @@ -156,7 +156,7 @@ int test_capi_track_main() { rtcSetClosedCallback(peer1->tr, closedCallback); // Initiate the handshake - rtcSetLocalDescription(peer1->pc); + rtcSetLocalDescription(peer1->pc, NULL); attempts = 10; while ((!peer2->connected || !peer1->connected) && attempts--) diff --git a/test/connectivity.cpp b/test/connectivity.cpp index e42f23e..78f65c2 100644 --- a/test/connectivity.cpp +++ b/test/connectivity.cpp @@ -69,6 +69,10 @@ void test_connectivity() { cout << "Gathering state 1: " << state << endl; }); + pc1->onSignalingStateChange([](PeerConnection::SignalingState state) { + cout << "Signaling state 1: " << state << endl; + }); + pc2->onLocalDescription([wpc1 = make_weak_ptr(pc1)](Description sdp) { auto pc1 = wpc1.lock(); if (!pc1) @@ -91,6 +95,10 @@ void test_connectivity() { cout << "Gathering state 2: " << state << endl; }); + pc2->onSignalingStateChange([](PeerConnection::SignalingState state) { + cout << "Signaling state 2: " << state << endl; + }); + shared_ptr dc2; pc2->onDataChannel([&dc2](shared_ptr dc) { cout << "DataChannel 2: Received with label \"" << dc->label() << "\"" << endl; diff --git a/test/track.cpp b/test/track.cpp index 3c8000c..3bebafe 100644 --- a/test/track.cpp +++ b/test/track.cpp @@ -92,9 +92,10 @@ void test_track() { }); shared_ptr t2; - pc2->onTrack([&t2](shared_ptr t) { + string newTrackMid; + pc2->onTrack([&t2, &newTrackMid](shared_ptr t) { cout << "Track 2: Received with mid \"" << t->mid() << "\"" << endl; - if (t->mid() != "test") { + if (t->mid() != newTrackMid) { cerr << "Wrong track mid" << endl; return; } @@ -102,7 +103,9 @@ void test_track() { std::atomic_store(&t2, t); }); - auto t1 = pc1->addTrack(Description::Video("test")); + // Test opening a track + newTrackMid = "test"; + auto t1 = pc1->addTrack(Description::Video(newTrackMid)); pc1->setLocalDescription(); @@ -118,6 +121,20 @@ void test_track() { if (!at2 || !at2->isOpen() || !t1->isOpen()) throw runtime_error("Track is not open"); + // Test renegotiation + newTrackMid = "added"; + t1 = pc1->addTrack(Description::Video(newTrackMid)); + + pc1->setLocalDescription(); + + attempts = 10; + t2.reset(); + while ((!(at2 = std::atomic_load(&t2)) || !at2->isOpen() || !t1->isOpen()) && attempts--) + this_thread::sleep_for(1s); + + if (!at2 || !at2->isOpen() || !t1->isOpen()) + throw runtime_error("Renegociated track is not open"); + // TODO: Test sending RTP packets in track // Delay close of peer 2 to check closing works properly