diff --git a/CMakeLists.txt b/CMakeLists.txt index d7bea16..af2c82e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,26 +46,17 @@ endif() set(LIBDATACHANNEL_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/src/candidate.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/certificate.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/channel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/configuration.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/datachannel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/description.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/dtlssrtptransport.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/dtlstransport.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/icetransport.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/init.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/log.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/message.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/peerconnection.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/logcounter.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/rtcpreceivingsession.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/threadpool.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/tls.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/track.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/processor.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/capi.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/rtppacketizationconfig.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/rtcpsrreporter.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/rtppacketizer.cpp @@ -78,36 +69,7 @@ set(LIBDATACHANNEL_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/src/mediahandlerelement.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/mediahandlerrootelement.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/rtcpnackresponder.cpp -) - -set(LIBDATACHANNEL_PRIVATE_HEADERS - ${CMAKE_CURRENT_SOURCE_DIR}/src/certificate.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/dtlssrtptransport.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/dtlstransport.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/icetransport.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/logcounter.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/threadpool.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/tls.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/processor.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/transport.hpp -) - -set(LIBDATACHANNEL_WEBSOCKET_SOURCES - ${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/verifiedtlstransport.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp -) - -set(LIBDATACHANNEL_WEBSOCKET_PRIVATE_HEADERS - ${CMAKE_CURRENT_SOURCE_DIR}/src/base64.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/verifiedtlstransport.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/capi.cpp ) set(LIBDATACHANNEL_HEADERS @@ -119,12 +81,11 @@ set(LIBDATACHANNEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/description.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/mediahandler.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtcpreceivingsession.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/include.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/common.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/init.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/log.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/message.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/peerconnection.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/queue.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/reliability.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.h ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.hpp @@ -145,6 +106,51 @@ set(LIBDATACHANNEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtcpnackresponder.hpp ) +set(LIBDATACHANNEL_IMPL_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/certificate.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/channel.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/datachannel.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/dtlssrtptransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/dtlstransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/icetransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/peerconnection.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/logcounter.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sctptransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/threadpool.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tls.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/track.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/processor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/base64.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/verifiedtlstransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocket.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/wstransport.cpp +) + +set(LIBDATACHANNEL_IMPL_HEADERS + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/certificate.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/channel.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/datachannel.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/dtlssrtptransport.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/dtlstransport.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/icetransport.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/peerconnection.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/queue.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/logcounter.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sctptransport.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/threadpool.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tls.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/track.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/processor.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/base64.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/verifiedtlstransport.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocket.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/impl/wstransport.hpp +) + set(TESTS_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test/main.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/connectivity.cpp @@ -163,7 +169,8 @@ set(TESTS_UWP_RESOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test/uwp/tests/SmallLogo44x44.png ${CMAKE_CURRENT_SOURCE_DIR}/test/uwp/tests/SplashScreen.png ${CMAKE_CURRENT_SOURCE_DIR}/test/uwp/tests/StoreLogo.png - ${CMAKE_CURRENT_SOURCE_DIR}/test/uwp/tests/Windows_TemporaryKey.pfx) + ${CMAKE_CURRENT_SOURCE_DIR}/test/uwp/tests/Windows_TemporaryKey.pfx +) set(BENCHMARK_UWP_RESOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test/uwp/benchmark/Logo.png @@ -172,7 +179,8 @@ set(BENCHMARK_UWP_RESOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test/uwp/benchmark/SmallLogo44x44.png ${CMAKE_CURRENT_SOURCE_DIR}/test/uwp/benchmark/SplashScreen.png ${CMAKE_CURRENT_SOURCE_DIR}/test/uwp/benchmark/StoreLogo.png - ${CMAKE_CURRENT_SOURCE_DIR}/test/uwp/benchmark/Windows_TemporaryKey.pfx) + ${CMAKE_CURRENT_SOURCE_DIR}/test/uwp/benchmark/Windows_TemporaryKey.pfx +) set(CMAKE_THREAD_PREFER_PTHREAD TRUE) set(THREADS_PREFER_PTHREAD_FLAG TRUE) @@ -192,33 +200,16 @@ if (CMAKE_CXX_COMPILER_ID MATCHES "GNU") endif() add_library(Usrsctp::Usrsctp ALIAS usrsctp) -if (NO_WEBSOCKET) - add_library(datachannel SHARED - ${LIBDATACHANNEL_SOURCES} - ${LIBDATACHANNEL_PRIVATE_HEADERS} - ${LIBDATACHANNEL_HEADERS}) - add_library(datachannel-static STATIC EXCLUDE_FROM_ALL - ${LIBDATACHANNEL_SOURCES} - ${LIBDATACHANNEL_PRIVATE_HEADERS} - ${LIBDATACHANNEL_HEADERS}) - target_compile_definitions(datachannel PUBLIC RTC_ENABLE_WEBSOCKET=0) - target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_WEBSOCKET=0) -else() - add_library(datachannel SHARED - ${LIBDATACHANNEL_SOURCES} - ${LIBDATACHANNEL_PRIVATE_HEADERS} - ${LIBDATACHANNEL_WEBSOCKET_SOURCES} - ${LIBDATACHANNEL_WEBSOCKET_PRIVATE_HEADERS} - ${LIBDATACHANNEL_HEADERS}) - add_library(datachannel-static STATIC EXCLUDE_FROM_ALL - ${LIBDATACHANNEL_SOURCES} - ${LIBDATACHANNEL_PRIVATE_HEADERS} - ${LIBDATACHANNEL_WEBSOCKET_SOURCES} - ${LIBDATACHANNEL_WEBSOCKET_PRIVATE_HEADERS} - ${LIBDATACHANNEL_HEADERS}) - target_compile_definitions(datachannel PUBLIC RTC_ENABLE_WEBSOCKET=1) - target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_WEBSOCKET=1) -endif() +add_library(datachannel SHARED + ${LIBDATACHANNEL_SOURCES} + ${LIBDATACHANNEL_HEADERS} + ${LIBDATACHANNEL_IMPL_SOURCES} + ${LIBDATACHANNEL_IMPL_HEADERS}) +add_library(datachannel-static STATIC EXCLUDE_FROM_ALL + ${LIBDATACHANNEL_SOURCES} + ${LIBDATACHANNEL_HEADERS} + ${LIBDATACHANNEL_IMPL_SOURCES} + ${LIBDATACHANNEL_IMPL_HEADERS}) set_target_properties(datachannel PROPERTIES VERSION ${PROJECT_VERSION} @@ -244,6 +235,14 @@ if(WIN32) target_link_libraries(datachannel-static PUBLIC ws2_32) # winsock2 endif() +if (NO_WEBSOCKET) + target_compile_definitions(datachannel PUBLIC RTC_ENABLE_WEBSOCKET=0) + target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_WEBSOCKET=0) +else() + target_compile_definitions(datachannel PUBLIC RTC_ENABLE_WEBSOCKET=1) + target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_WEBSOCKET=1) +endif() + if(NO_MEDIA) target_compile_definitions(datachannel PUBLIC RTC_ENABLE_MEDIA=0) target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_MEDIA=0) diff --git a/README.md b/README.md index 80d1c2f..c02e291 100644 --- a/README.md +++ b/README.md @@ -145,50 +145,49 @@ Additionnaly, you might want to have a look at the [C API](https://github.com/pa rtc::Configuration config; config.iceServers.emplace_back("mystunserver.org:3478"); -auto pc = make_shared(config); +rtc::PeerConection pc(config); -pc->onLocalDescription([](rtc::Description sdp) { +pc.onLocalDescription([](rtc::Description sdp) { // Send the SDP to the remote peer MY_SEND_DESCRIPTION_TO_REMOTE(string(sdp)); }); -pc->onLocalCandidate([](rtc::Candidate candidate) { +pc.onLocalCandidate([](rtc::Candidate candidate) { // Send the candidate to the remote peer MY_SEND_CANDIDATE_TO_REMOTE(candidate.candidate(), candidate.mid()); }); -MY_ON_RECV_DESCRIPTION_FROM_REMOTE([pc](string sdp) { - pc->setRemoteDescription(rtc::Description(sdp)); +MY_ON_RECV_DESCRIPTION_FROM_REMOTE([&pc](string sdp) { + pc.setRemoteDescription(rtc::Description(sdp)); }); -MY_ON_RECV_CANDIDATE_FROM_REMOTE([pc](string candidate, string mid) { - pc->addRemoteCandidate(rtc::Candidate(candidate, mid)); +MY_ON_RECV_CANDIDATE_FROM_REMOTE([&pc](string candidate, string mid) { + pc.addRemoteCandidate(rtc::Candidate(candidate, mid)); }); ``` ### Observe the PeerConnection state ```cpp -pc->onStateChange([](PeerConnection::State state) { +pc.onStateChange([](PeerConnection::State state) { cout << "State: " << state << endl; }); -pc->onGatheringStateChange([](PeerConnection::GatheringState state) { +pc.onGatheringStateChange([](PeerConnection::GatheringState state) { cout << "Gathering state: " << state << endl; }); - ``` ### Create a DataChannel ```cpp -auto dc = pc->createDataChannel("test"); +auto dc = pc.createDataChannel("test"); -dc->onOpen([]() { +dc.onOpen([]() { cout << "Open" << endl; }); -dc->onMessage([](variant message) { +dc.onMessage([](variant message) { if (holds_alternative(message)) { cout << "Received: " << get(message) << endl; } @@ -199,30 +198,28 @@ dc->onMessage([](variant message) { ```cpp shared_ptr dc; -pc->onDataChannel([&dc](shared_ptr incoming) { +pc.onDataChannel([&dc](shared_ptr incoming) { dc = incoming; dc->send("Hello world!"); }); - ``` ### Open a WebSocket ```cpp -auto ws = make_shared(); +rtc::WebSocket ws; -ws->onOpen([]() { +ws.onOpen([]() { cout << "WebSocket open" << endl; }); -ws->onMessage([](variant message) { +ws.onMessage([](variant message) { if (holds_alternative(message)) { cout << "WebSocket received: " << get(message) << endl; } }); -ws->open("wss://my.websocket/service"); - +ws.open("wss://my.websocket/service"); ``` ## External resources diff --git a/examples/client/main.cpp b/examples/client/main.cpp index 243cae6..78af56c 100644 --- a/examples/client/main.cpp +++ b/examples/client/main.cpp @@ -54,20 +54,20 @@ shared_ptr createPeerConnection(const Configuration &config, string randomId(size_t length); int main(int argc, char **argv) try { - auto params = std::make_unique(argc, argv); + Cmdline params(argc, argv); rtc::InitLogger(LogLevel::Info); Configuration config; string stunServer = ""; - if (params->noStun()) { + if (params.noStun()) { cout << "No STUN server is configured. Only local hosts and public IP addresses supported." << endl; } else { - if (params->stunServer().substr(0, 5).compare("stun:") != 0) { + if (params.stunServer().substr(0, 5).compare("stun:") != 0) { stunServer = "stun:"; } - stunServer += params->stunServer() + ":" + to_string(params->stunPort()); + stunServer += params.stunServer() + ":" + to_string(params.stunPort()); cout << "Stun server is " << stunServer << endl; config.iceServers.emplace_back(stunServer); } @@ -129,11 +129,11 @@ int main(int argc, char **argv) try { }); string wsPrefix = ""; - if (params->webSocketServer().substr(0, 5).compare("ws://") != 0) { + if (params.webSocketServer().substr(0, 5).compare("ws://") != 0) { wsPrefix = "ws://"; } - const string url = wsPrefix + params->webSocketServer() + ":" + - to_string(params->webSocketPort()) + "/" + localId; + const string url = wsPrefix + params.webSocketServer() + ":" + + to_string(params.webSocketPort()) + "/" + localId; cout << "Url is " << url << endl; ws->open(url); @@ -251,3 +251,4 @@ string randomId(size_t length) { generate(id.begin(), id.end(), [&]() { return characters.at(dist(rng)); }); return id; } + diff --git a/include/rtc/candidate.hpp b/include/rtc/candidate.hpp index 637717e..56b32a0 100644 --- a/include/rtc/candidate.hpp +++ b/include/rtc/candidate.hpp @@ -19,7 +19,7 @@ #ifndef RTC_CANDIDATE_H #define RTC_CANDIDATE_H -#include "include.hpp" +#include "common.hpp" #include diff --git a/include/rtc/channel.hpp b/include/rtc/channel.hpp index 52421ae..b3367f9 100644 --- a/include/rtc/channel.hpp +++ b/include/rtc/channel.hpp @@ -1,5 +1,5 @@ /** - * Copyright (c) 2019 Paul-Louis Ageneau + * Copyright (c) 2019-2021 Paul-Louis Ageneau * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -19,7 +19,7 @@ #ifndef RTC_CHANNEL_H #define RTC_CHANNEL_H -#include "include.hpp" +#include "common.hpp" #include "message.hpp" #include @@ -28,10 +28,13 @@ namespace rtc { -class RTC_CPP_EXPORT Channel { +namespace impl { +struct Channel; +} + +class RTC_CPP_EXPORT Channel : private CheshireCat { public: - Channel() = default; - virtual ~Channel() = default; + virtual ~Channel(); virtual void close() = 0; virtual bool send(message_variant data) = 0; // returns false if buffered @@ -54,30 +57,13 @@ public: void setBufferedAmountLowThreshold(size_t amount); // Extended API - virtual std::optional receive() = 0; // only if onMessage unset - virtual std::optional peek() = 0; // only if onMessage unset - virtual size_t availableAmount() const; // total size available to receive + std::optional receive(); // only if onMessage unset + std::optional peek(); // only if onMessage unset + size_t availableAmount() const; // total size available to receive void onAvailable(std::function callback); protected: - virtual void triggerOpen(); - virtual void triggerClosed(); - virtual void triggerError(string error); - virtual void triggerAvailable(size_t count); - virtual void triggerBufferedAmount(size_t amount); - - void resetCallbacks(); - -private: - synchronized_callback<> mOpenCallback; - synchronized_callback<> mClosedCallback; - synchronized_callback mErrorCallback; - synchronized_callback mMessageCallback; - synchronized_callback<> mAvailableCallback; - synchronized_callback<> mBufferedAmountLowCallback; - - std::atomic mBufferedAmount = 0; - std::atomic mBufferedAmountLowThreshold = 0; + Channel(impl_ptr impl); }; } // namespace rtc diff --git a/include/rtc/common.hpp b/include/rtc/common.hpp new file mode 100644 index 0000000..b96ac7d --- /dev/null +++ b/include/rtc/common.hpp @@ -0,0 +1,75 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_COMMON_H +#define RTC_COMMON_H + +#ifndef RTC_ENABLE_MEDIA +#define RTC_ENABLE_MEDIA 1 +#endif + +#ifndef RTC_ENABLE_WEBSOCKET +#define RTC_ENABLE_WEBSOCKET 1 +#endif + +#ifdef _WIN32 +#define RTC_CPP_EXPORT __declspec(dllexport) +#ifndef _WIN32_WINNT +#define _WIN32_WINNT 0x0602 // Windows 8 +#endif +#ifdef _MSC_VER +#pragma warning(disable : 4251) // disable "X needs to have dll-interface..." +#endif +#else +#define RTC_CPP_EXPORT +#endif + +#include "log.hpp" +#include "utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace rtc { + +using std::byte; +using std::nullopt; +using std::shared_ptr; +using std::string; +using std::string_view; +using std::unique_ptr; +using std::weak_ptr; + +using binary = std::vector; +using binary_ptr = std::shared_ptr; + +using std::size_t; +using std::uint16_t; +using std::uint32_t; +using std::uint64_t; +using std::uint8_t; + +} // namespace rtc + +#endif diff --git a/include/rtc/configuration.hpp b/include/rtc/configuration.hpp index d37b663..84acb84 100644 --- a/include/rtc/configuration.hpp +++ b/include/rtc/configuration.hpp @@ -19,7 +19,7 @@ #ifndef RTC_ICE_CONFIGURATION_H #define RTC_ICE_CONFIGURATION_H -#include "include.hpp" +#include "common.hpp" #include "message.hpp" #include diff --git a/include/rtc/datachannel.hpp b/include/rtc/datachannel.hpp index d9fb68a..a851b7a 100644 --- a/include/rtc/datachannel.hpp +++ b/include/rtc/datachannel.hpp @@ -20,28 +20,30 @@ #define RTC_DATA_CHANNEL_H #include "channel.hpp" -#include "include.hpp" +#include "common.hpp" #include "message.hpp" -#include "queue.hpp" #include "reliability.hpp" #include #include #include +#include #include #include #include namespace rtc { -class SctpTransport; -class PeerConnection; +namespace impl { -class RTC_CPP_EXPORT DataChannel : public std::enable_shared_from_this, - public Channel { +struct DataChannel; +struct PeerConnection; + +} // namespace impl + +class RTC_CPP_EXPORT DataChannel final : private CheshireCat, public Channel { public: - DataChannel(std::weak_ptr pc, uint16_t stream, string label, string protocol, - Reliability reliability); + DataChannel(impl_ptr impl); virtual ~DataChannel(); uint16_t stream() const; @@ -50,60 +52,18 @@ public: string protocol() const; Reliability reliability() const; + bool isOpen(void) const override; + bool isClosed(void) const override; + size_t maxMessageSize() const override; + void close(void) override; bool send(message_variant data) override; bool send(const byte *data, size_t size) override; template bool sendBuffer(const Buffer &buf); template bool sendBuffer(Iterator first, Iterator last); - bool isOpen(void) const override; - bool isClosed(void) const override; - size_t maxMessageSize() const override; - - // Extended API - size_t availableAmount() const override; - std::optional receive() override; - std::optional peek() override; - -protected: - virtual void open(std::shared_ptr transport); - virtual void processOpenMessage(message_ptr message); - void remoteClose(); - bool outgoing(message_ptr message); - void incoming(message_ptr message); - - const std::weak_ptr mPeerConnection; - std::weak_ptr mSctpTransport; - - uint16_t mStream; - string mLabel; - string mProtocol; - std::shared_ptr mReliability; - - mutable std::shared_mutex mMutex; - - std::atomic mIsOpen = false; - std::atomic mIsClosed = false; - private: - Queue mRecvQueue; - - friend class PeerConnection; -}; - -class RTC_CPP_EXPORT NegotiatedDataChannel final : public DataChannel { -public: - NegotiatedDataChannel(std::weak_ptr pc, uint16_t stream, string label, - string protocol, Reliability reliability); - NegotiatedDataChannel(std::weak_ptr pc, std::weak_ptr transport, - uint16_t stream); - ~NegotiatedDataChannel(); - -private: - void open(std::shared_ptr transport) override; - void processOpenMessage(message_ptr message) override; - - friend class PeerConnection; + using CheshireCat::impl; }; template std::pair to_bytes(const Buffer &buf) { @@ -115,9 +75,7 @@ template std::pair to_bytes(const Buffer template bool DataChannel::sendBuffer(const Buffer &buf) { auto [bytes, size] = to_bytes(buf); - auto message = std::make_shared(size); - std::copy(bytes, bytes + size, message->data()); - return outgoing(message); + return send(bytes, size); } template bool DataChannel::sendBuffer(Iterator first, Iterator last) { @@ -125,13 +83,13 @@ template bool DataChannel::sendBuffer(Iterator first, Iterat for (Iterator it = first; it != last; ++it) size += it->size(); - auto message = std::make_shared(size); - auto pos = message->begin(); + binary buffer(size); + byte *pos = buffer.data(); for (Iterator it = first; it != last; ++it) { auto [bytes, len] = to_bytes(*it); pos = std::copy(bytes, bytes + len, pos); } - return outgoing(message); + return send(std::move(buffer)); } } // namespace rtc diff --git a/include/rtc/description.hpp b/include/rtc/description.hpp index a09c91b..19f5b65 100644 --- a/include/rtc/description.hpp +++ b/include/rtc/description.hpp @@ -21,7 +21,7 @@ #define RTC_DESCRIPTION_H #include "candidate.hpp" -#include "include.hpp" +#include "common.hpp" #include #include diff --git a/include/rtc/init.hpp b/include/rtc/init.hpp index 2830d18..8325281 100644 --- a/include/rtc/init.hpp +++ b/include/rtc/init.hpp @@ -19,7 +19,7 @@ #ifndef RTC_INIT_H #define RTC_INIT_H -#include "include.hpp" +#include "common.hpp" #include diff --git a/include/rtc/log.hpp b/include/rtc/log.hpp index 56a2daf..e8b851f 100644 --- a/include/rtc/log.hpp +++ b/include/rtc/log.hpp @@ -35,7 +35,7 @@ #pragma warning(pop) #endif -#include "include.hpp" +#include "common.hpp" namespace rtc { diff --git a/include/rtc/mediahandler.hpp b/include/rtc/mediahandler.hpp index e9203f8..253c8ed 100644 --- a/include/rtc/mediahandler.hpp +++ b/include/rtc/mediahandler.hpp @@ -20,7 +20,7 @@ #ifndef RTC_MEDIA_HANDLER_H #define RTC_MEDIA_HANDLER_H -#include "include.hpp" +#include "common.hpp" #include "message.hpp" namespace rtc { diff --git a/include/rtc/mediahandlerelement.hpp b/include/rtc/mediahandlerelement.hpp index edf177c..2603150 100644 --- a/include/rtc/mediahandlerelement.hpp +++ b/include/rtc/mediahandlerelement.hpp @@ -20,7 +20,7 @@ #if RTC_ENABLE_MEDIA -#include "include.hpp" +#include "common.hpp" #include "message.hpp" #include "rtp.hpp" diff --git a/include/rtc/message.hpp b/include/rtc/message.hpp index 54fde3d..e1860b4 100644 --- a/include/rtc/message.hpp +++ b/include/rtc/message.hpp @@ -19,7 +19,7 @@ #ifndef RTC_MESSAGE_H #define RTC_MESSAGE_H -#include "include.hpp" +#include "common.hpp" #include "reliability.hpp" #include diff --git a/include/rtc/nalunit.hpp b/include/rtc/nalunit.hpp index 1012bf0..966b89c 100644 --- a/include/rtc/nalunit.hpp +++ b/include/rtc/nalunit.hpp @@ -21,7 +21,7 @@ #if RTC_ENABLE_MEDIA -#include "include.hpp" +#include "common.hpp" namespace rtc { diff --git a/include/rtc/peerconnection.hpp b/include/rtc/peerconnection.hpp index e80528f..2c61572 100644 --- a/include/rtc/peerconnection.hpp +++ b/include/rtc/peerconnection.hpp @@ -23,33 +23,23 @@ #include "configuration.hpp" #include "datachannel.hpp" #include "description.hpp" -#include "include.hpp" +#include "common.hpp" #include "init.hpp" #include "message.hpp" #include "reliability.hpp" #include "rtc.hpp" #include "track.hpp" -#include #include #include -#include -#include -#include -#include -#include -#include namespace rtc { -class Certificate; -class Processor; -class IceTransport; -class DtlsTransport; -class SctpTransport; +namespace impl { -using certificate_ptr = std::shared_ptr; -using future_certificate_ptr = std::shared_future; +struct PeerConnection; + +} struct RTC_CPP_EXPORT DataChannelInit { Reliability reliability = {}; @@ -58,7 +48,7 @@ struct RTC_CPP_EXPORT DataChannelInit { string protocol = ""; }; -class RTC_CPP_EXPORT PeerConnection final : public std::enable_shared_from_this { +class RTC_CPP_EXPORT PeerConnection final : CheshireCat { public: enum class State : int { New = RTC_NEW, @@ -84,7 +74,7 @@ public: } rtcSignalingState; PeerConnection(); - PeerConnection(const Configuration &config); + PeerConnection(Configuration config); ~PeerConnection(); void close(); @@ -128,80 +118,6 @@ public: // Track media support requires compiling with libSRTP std::shared_ptr addTrack(Description::Media description); void onTrack(std::function track)> callback); - -private: - std::shared_ptr initIceTransport(); - std::shared_ptr initDtlsTransport(); - std::shared_ptr initSctpTransport(); - void closeTransports(); - - void endLocalCandidates(); - bool checkFingerprint(const std::string &fingerprint) const; - void forwardMessage(message_ptr message); - void forwardMedia(message_ptr message); - void forwardBufferedAmount(uint16_t stream, size_t amount); - std::optional getMidFromSsrc(uint32_t ssrc); - - std::shared_ptr emplaceDataChannel(Description::Role role, string label, - DataChannelInit init); - std::shared_ptr findDataChannel(uint16_t stream); - void iterateDataChannels(std::function channel)> func); - void openDataChannels(); - void closeDataChannels(); - void remoteCloseDataChannels(); - - void incomingTrack(Description::Media description); - void openTracks(); - - void validateRemoteDescription(const Description &description); - void processLocalDescription(Description description); - void processLocalCandidate(Candidate candidate); - void processRemoteDescription(Description description); - void processRemoteCandidate(Candidate candidate); - string localBundleMid() const; - - void triggerDataChannel(std::weak_ptr weakDataChannel); - void triggerTrack(std::shared_ptr track); - bool changeState(State state); - bool changeGatheringState(GatheringState state); - bool changeSignalingState(SignalingState state); - - void resetCallbacks(); - - void outgoingMedia(message_ptr message); - - const init_token mInitToken = Init::Token(); - const Configuration mConfig; - const future_certificate_ptr mCertificate; - const std::unique_ptr mProcessor; - - std::optional mLocalDescription, mRemoteDescription; - std::optional mCurrentLocalDescription; - mutable std::mutex mLocalDescriptionMutex, mRemoteDescriptionMutex; - - std::shared_ptr mIceTransport; - std::shared_ptr mDtlsTransport; - std::shared_ptr mSctpTransport; - - std::unordered_map> mDataChannels; // by stream ID - std::unordered_map> mTracks; // by mid - std::vector> mTrackLines; // by SDP order - std::shared_mutex mDataChannelsMutex, mTracksMutex; - - std::unordered_map mMidFromSsrc; // cache - - std::atomic mState; - std::atomic mGatheringState; - std::atomic mSignalingState; - std::atomic mNegotiationNeeded; - - synchronized_callback> mDataChannelCallback; - synchronized_callback mLocalDescriptionCallback; - synchronized_callback mLocalCandidateCallback; - synchronized_callback mStateChangeCallback; - synchronized_callback mGatheringStateChangeCallback; - synchronized_callback mSignalingStateChangeCallback; - synchronized_callback> mTrackCallback; }; } // namespace rtc diff --git a/include/rtc/reliability.hpp b/include/rtc/reliability.hpp index fc03a76..80e3fa0 100644 --- a/include/rtc/reliability.hpp +++ b/include/rtc/reliability.hpp @@ -19,7 +19,7 @@ #ifndef RTC_RELIABILITY_H #define RTC_RELIABILITY_H -#include "include.hpp" +#include "common.hpp" #include #include diff --git a/include/rtc/rtc.hpp b/include/rtc/rtc.hpp index e09b965..ec8b286 100644 --- a/include/rtc/rtc.hpp +++ b/include/rtc/rtc.hpp @@ -17,7 +17,7 @@ */ // C++ API -#include "include.hpp" +#include "common.hpp" #include "init.hpp" // for rtc::Cleanup() #include "log.hpp" // diff --git a/include/rtc/rtcpreceivingsession.hpp b/include/rtc/rtcpreceivingsession.hpp index 5ad31f5..47091d2 100644 --- a/include/rtc/rtcpreceivingsession.hpp +++ b/include/rtc/rtcpreceivingsession.hpp @@ -22,7 +22,7 @@ #if RTC_ENABLE_MEDIA -#include "include.hpp" +#include "common.hpp" #include "mediahandler.hpp" #include "message.hpp" #include "rtp.hpp" diff --git a/include/rtc/track.hpp b/include/rtc/track.hpp index 93b5789..03edbd9 100644 --- a/include/rtc/track.hpp +++ b/include/rtc/track.hpp @@ -21,29 +21,30 @@ #include "channel.hpp" #include "description.hpp" -#include "include.hpp" -#include "message.hpp" -#include "queue.hpp" +#include "common.hpp" #include "mediahandler.hpp" +#include "message.hpp" #include -#include #include +#include namespace rtc { -#if RTC_ENABLE_MEDIA -class DtlsSrtpTransport; -#endif +namespace impl { -class RTC_CPP_EXPORT Track final : public std::enable_shared_from_this, public Channel { +class Track; + +} // namespace impl + +class RTC_CPP_EXPORT Track final : private CheshireCat, public Channel { public: - Track(Description::Media description); + Track(impl_ptr impl); ~Track() = default; string mid() const; - Description::Media description() const; Description::Direction direction() const; + Description::Media description() const; void setDescription(Description::Media description); @@ -55,11 +56,6 @@ public: bool isClosed(void) const override; size_t maxMessageSize() const override; - // Extended API - size_t availableAmount() const override; - std::optional receive() override; - std::optional peek() override; - bool requestKeyframe(); // RTCP handler @@ -67,24 +63,7 @@ public: std::shared_ptr getRtcpHandler(); private: -#if RTC_ENABLE_MEDIA - void open(std::shared_ptr transport); - std::weak_ptr mDtlsSrtpTransport; -#endif - - void incoming(message_ptr message); - bool outgoing(message_ptr message); - - Description::Media mMediaDescription; - std::shared_ptr mRtcpHandler; - - mutable std::shared_mutex mMutex; - - std::atomic mIsClosed = false; - - Queue mRecvQueue; - - friend class PeerConnection; + using CheshireCat::impl; }; } // namespace rtc diff --git a/include/rtc/include.hpp b/include/rtc/utils.hpp similarity index 65% rename from include/rtc/include.hpp rename to include/rtc/utils.hpp index cf72f40..890dc7d 100644 --- a/include/rtc/include.hpp +++ b/include/rtc/utils.hpp @@ -1,5 +1,5 @@ /** - * Copyright (c) 2019 Paul-Louis Ageneau + * Copyright (c) 2019-2021 Paul-Louis Ageneau * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -16,69 +16,15 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_INCLUDE_H -#define RTC_INCLUDE_H +#ifndef RTC_UTILS_H +#define RTC_UTILS_H -#ifndef RTC_ENABLE_MEDIA -#define RTC_ENABLE_MEDIA 1 -#endif - -#ifndef RTC_ENABLE_WEBSOCKET -#define RTC_ENABLE_WEBSOCKET 1 -#endif - -#ifdef _WIN32 -#define RTC_CPP_EXPORT __declspec(dllexport) -#ifndef _WIN32_WINNT -#define _WIN32_WINNT 0x0602 // Windows 8 -#endif -#ifdef _MSC_VER -#pragma warning(disable : 4251) // disable "X needs to have dll-interface..." -#endif -#else -#define RTC_CPP_EXPORT -#endif - -#include "log.hpp" - -#include #include #include #include -#include -#include -#include -#include namespace rtc { -using std::byte; -using std::string; -using std::string_view; -using binary = std::vector; -using binary_ptr = std::shared_ptr; - -using std::nullopt; - -using std::size_t; -using std::uint16_t; -using std::uint32_t; -using std::uint64_t; -using std::uint8_t; - -const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length -const size_t MAX_NUMERICSERV_LEN = 6; // Max port string representation length - -const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default -const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not specified in SDP -const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size - -const size_t RECV_QUEUE_LIMIT = 1024 * 1024; // Max per-channel queue size - -const int THREADPOOL_SIZE = 4; // Number of threads in the global thread pool - -const size_t DEFAULT_IPV4_MTU = 1200; // IPv4 safe MTU value recommended by RFC 8261 - // overloaded helper template struct overloaded : Ts... { using Ts::operator()...; }; template overloaded(Ts...) -> overloaded; @@ -95,7 +41,7 @@ template auto weak_bind(F &&f, T *t, } // scope_guard helper -class scope_guard { +class scope_guard final { public: scope_guard(std::function func) : function(std::move(func)) {} scope_guard(scope_guard &&other) = delete; @@ -111,7 +57,8 @@ private: std::function function; }; -template class synchronized_callback { +// callback with built-in synchronization +template class synchronized_callback final { public: synchronized_callback() = default; synchronized_callback(synchronized_callback &&cb) { *this = std::move(cb); } @@ -157,6 +104,33 @@ private: std::function callback; mutable std::recursive_mutex mutex; }; + +// pimpl base class +template using impl_ptr = std::shared_ptr; +template class CheshireCat { +public: + CheshireCat(impl_ptr impl) : mImpl(std::move(impl)) {} + template + CheshireCat(Args... args) : mImpl(std::make_shared(std::move(args)...)) {} + CheshireCat(CheshireCat &&cc) { *this = std::move(cc); } + CheshireCat(const CheshireCat &) = delete; + + virtual ~CheshireCat() = default; + + CheshireCat &operator=(CheshireCat &&cc) { + mImpl = std::move(cc.mImpl); + return *this; + }; + CheshireCat &operator=(const CheshireCat &) = delete; + +protected: + impl_ptr impl() { return mImpl; } + impl_ptr impl() const { return mImpl; } + +private: + impl_ptr mImpl; +}; + } // namespace rtc #endif diff --git a/include/rtc/websocket.hpp b/include/rtc/websocket.hpp index 328bafd..052182b 100644 --- a/include/rtc/websocket.hpp +++ b/include/rtc/websocket.hpp @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020 Paul-Louis Ageneau + * Copyright (c) 2020-2021 Paul-Louis Ageneau * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -22,24 +22,18 @@ #if RTC_ENABLE_WEBSOCKET #include "channel.hpp" -#include "include.hpp" -#include "init.hpp" +#include "common.hpp" #include "message.hpp" -#include "queue.hpp" - -#include -#include -#include -#include namespace rtc { -class TcpTransport; -class TlsTransport; -class WsTransport; +namespace impl { -class RTC_CPP_EXPORT WebSocket final : public Channel, - public std::enable_shared_from_this { +struct WebSocket; + +} + +class RTC_CPP_EXPORT WebSocket final : private CheshireCat, public Channel { public: enum class State : int { Connecting = 0, @@ -53,49 +47,25 @@ public: std::vector protocols; }; - WebSocket(std::optional config = nullopt); + WebSocket(); + WebSocket(Configuration config); ~WebSocket(); State readyState() const; + bool isOpen() const override; + bool isClosed() const override; + size_t maxMessageSize() const override; + void open(const string &url); void close() override; bool send(const message_variant data) override; bool send(const byte *data, size_t size) override; - bool isOpen() const override; - bool isClosed() const override; - size_t maxMessageSize() const override; - - // Extended API - std::optional receive() override; - std::optional peek() override; - size_t availableAmount() const override; // total size available to receive - private: - bool changeState(State state); - void remoteClose(); - bool outgoing(message_ptr message); - void incoming(message_ptr message); - - std::shared_ptr initTcpTransport(); - std::shared_ptr initTlsTransport(); - std::shared_ptr initWsTransport(); - void closeTransports(); - - init_token mInitToken = Init::Token(); - - std::shared_ptr mTcpTransport; - std::shared_ptr mTlsTransport; - std::shared_ptr mWsTransport; - std::recursive_mutex mInitMutex; - - const Configuration mConfig; - string mScheme, mHost, mHostname, mService, mPath; - std::atomic mState = State::Closed; - - Queue mRecvQueue; + using CheshireCat::impl; }; + } // namespace rtc #endif diff --git a/src/candidate.cpp b/src/candidate.cpp index eef2c4b..dc52fd5 100644 --- a/src/candidate.cpp +++ b/src/candidate.cpp @@ -17,6 +17,7 @@ */ #include "candidate.hpp" +#include "globals.hpp" #include #include diff --git a/src/capi.cpp b/src/capi.cpp index 7d93f2e..616344d 100644 --- a/src/capi.cpp +++ b/src/capi.cpp @@ -16,7 +16,7 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#include "include.hpp" +#include "common.hpp" #include "rtc.h" @@ -43,8 +43,6 @@ using namespace rtc; using std::optional; -using std::shared_ptr; -using std::string; using std::chrono::milliseconds; namespace { diff --git a/src/channel.cpp b/src/channel.cpp index b5d15fe..6782822 100644 --- a/src/channel.cpp +++ b/src/channel.cpp @@ -1,5 +1,5 @@ /** - * Copyright (c) 2019 Paul-Louis Ageneau + * Copyright (c) 2019-2021 Paul-Louis Ageneau * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -17,27 +17,36 @@ */ #include "channel.hpp" +#include "globals.hpp" + +#include "impl/channel.hpp" namespace rtc { +Channel::~Channel() { + impl()->resetCallbacks(); +} + +Channel::Channel(impl_ptr impl) : CheshireCat(std::move(impl)) {} + size_t Channel::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; } -size_t Channel::bufferedAmount() const { return mBufferedAmount; } +size_t Channel::bufferedAmount() const { return impl()->bufferedAmount; } -size_t Channel::availableAmount() const { return 0; } +void Channel::onOpen(std::function callback) { impl()->openCallback = callback; } -void Channel::onOpen(std::function callback) { mOpenCallback = callback; } +void Channel::onClosed(std::function callback) { impl()->closedCallback = callback; } -void Channel::onClosed(std::function callback) { mClosedCallback = callback; } - -void Channel::onError(std::function callback) { mErrorCallback = callback; } +void Channel::onError(std::function callback) { + impl()->errorCallback = callback; +} void Channel::onMessage(std::function callback) { - mMessageCallback = callback; + impl()->messageCallback = callback; // Pass pending messages while (auto message = receive()) - mMessageCallback(*message); + impl()->messageCallback(*message); } void Channel::onMessage(std::function binaryCallback, @@ -48,45 +57,25 @@ void Channel::onMessage(std::function binaryCallback, } void Channel::onBufferedAmountLow(std::function callback) { - mBufferedAmountLowCallback = callback; + impl()->bufferedAmountLowCallback = callback; } -void Channel::setBufferedAmountLowThreshold(size_t amount) { mBufferedAmountLowThreshold = amount; } - -void Channel::onAvailable(std::function callback) { mAvailableCallback = callback; } - -void Channel::triggerOpen() { mOpenCallback(); } - -void Channel::triggerClosed() { mClosedCallback(); } - -void Channel::triggerError(string error) { mErrorCallback(error); } - -void Channel::triggerAvailable(size_t count) { - if (count == 1) - mAvailableCallback(); - - while (mMessageCallback && count--) { - auto message = receive(); - if (!message) - break; - mMessageCallback(*message); - } +void Channel::setBufferedAmountLowThreshold(size_t amount) { + impl()->bufferedAmountLowThreshold = amount; } -void Channel::triggerBufferedAmount(size_t amount) { - size_t previous = mBufferedAmount.exchange(amount); - size_t threshold = mBufferedAmountLowThreshold.load(); - if (previous > threshold && amount <= threshold) - mBufferedAmountLowCallback(); +std::optional Channel::receive() { + return impl()->receive(); } -void Channel::resetCallbacks() { - mOpenCallback = nullptr; - mClosedCallback = nullptr; - mErrorCallback = nullptr; - mMessageCallback = nullptr; - mAvailableCallback = nullptr; - mBufferedAmountLowCallback = nullptr; +std::optional Channel::peek() { + return impl()->peek(); } +size_t Channel::availableAmount() const { + return impl()->availableAmount(); +} + +void Channel::onAvailable(std::function callback) { impl()->availableCallback = callback; } + } // namespace rtc diff --git a/src/datachannel.cpp b/src/datachannel.cpp index 3d46f5b..c48d10d 100644 --- a/src/datachannel.cpp +++ b/src/datachannel.cpp @@ -1,5 +1,5 @@ /** - * Copyright (c) 2019 Paul-Louis Ageneau + * Copyright (c) 2019-2021 Paul-Louis Ageneau * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -17,10 +17,12 @@ */ #include "datachannel.hpp" -#include "include.hpp" -#include "logcounter.hpp" +#include "globals.hpp" +#include "common.hpp" #include "peerconnection.hpp" -#include "sctptransport.hpp" + +#include "impl/datachannel.hpp" +#include "impl/peerconnection.hpp" #ifdef _WIN32 #include @@ -30,334 +32,36 @@ namespace rtc { -LogCounter COUNTER_USERNEG_OPEN_MESSAGE( - plog::warning, "Number of open messages for a user-negotiated DataChannel received"); - -using std::shared_ptr; -using std::weak_ptr; -using std::chrono::milliseconds; - -// Messages for the DataChannel establishment protocol -// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09 - -enum MessageType : uint8_t { - MESSAGE_OPEN_REQUEST = 0x00, - MESSAGE_OPEN_RESPONSE = 0x01, - MESSAGE_ACK = 0x02, - MESSAGE_OPEN = 0x03, - MESSAGE_CLOSE = 0x04 -}; - -enum ChannelType : uint8_t { - CHANNEL_RELIABLE = 0x00, - CHANNEL_PARTIAL_RELIABLE_REXMIT = 0x01, - CHANNEL_PARTIAL_RELIABLE_TIMED = 0x02 -}; - -#pragma pack(push, 1) -struct OpenMessage { - uint8_t type = MESSAGE_OPEN; - uint8_t channelType; - uint16_t priority; - uint32_t reliabilityParameter; - uint16_t labelLength; - uint16_t protocolLength; - // The following fields are: - // uint8_t[labelLength] label - // uint8_t[protocolLength] protocol -}; - -struct AckMessage { - uint8_t type = MESSAGE_ACK; -}; - -struct CloseMessage { - uint8_t type = MESSAGE_CLOSE; -}; -#pragma pack(pop) - -DataChannel::DataChannel(weak_ptr pc, uint16_t stream, string label, - string protocol, Reliability reliability) - : mPeerConnection(pc), mStream(stream), mLabel(std::move(label)), - mProtocol(std::move(protocol)), - mReliability(std::make_shared(std::move(reliability))), - mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {} +DataChannel::DataChannel(impl_ptr impl) + : CheshireCat(impl), + Channel(std::dynamic_pointer_cast(impl)) {} DataChannel::~DataChannel() { close(); } -uint16_t DataChannel::stream() const { return mStream; } +void DataChannel::close() { return impl()->close(); } -uint16_t DataChannel::id() const { return mStream; } +uint16_t DataChannel::stream() const { return impl()->stream(); } -string DataChannel::label() const { - std::shared_lock lock(mMutex); - return mLabel; +uint16_t DataChannel::id() const { return impl()->stream(); } + +string DataChannel::label() const { return impl()->label(); } + +string DataChannel::protocol() const { return impl()->protocol(); } + +Reliability DataChannel::reliability() const { return impl()->reliability(); } + +bool DataChannel::isOpen(void) const { return impl()->isOpen(); } + +bool DataChannel::isClosed(void) const { return impl()->isClosed(); } + +size_t DataChannel::maxMessageSize() const { return impl()->maxMessageSize(); } + +bool DataChannel::send(message_variant data) { + return impl()->outgoing(make_message(std::move(data))); } -string DataChannel::protocol() const { - std::shared_lock lock(mMutex); - return mProtocol; -} - -Reliability DataChannel::reliability() const { - std::shared_lock lock(mMutex); - return *mReliability; -} - -void DataChannel::close() { - std::shared_ptr transport; - { - std::shared_lock lock(mMutex); - transport = mSctpTransport.lock(); - } - - mIsClosed = true; - if (mIsOpen.exchange(false) && transport) - transport->closeStream(mStream); - - resetCallbacks(); -} - -void DataChannel::remoteClose() { - if (!mIsClosed.exchange(true)) - triggerClosed(); - - mIsOpen = false; -} - -bool DataChannel::send(message_variant data) { return outgoing(make_message(std::move(data))); } - bool DataChannel::send(const byte *data, size_t size) { - return outgoing(std::make_shared(data, data + size, Message::Binary)); -} - -std::optional DataChannel::receive() { - while (auto next = mRecvQueue.tryPop()) { - message_ptr message = *next; - if (message->type != Message::Control) - return to_variant(std::move(*message)); - - auto raw = reinterpret_cast(message->data()); - if (!message->empty() && raw[0] == MESSAGE_CLOSE) - remoteClose(); - } - - return nullopt; -} - -std::optional DataChannel::peek() { - while (auto next = mRecvQueue.peek()) { - message_ptr message = *next; - if (message->type != Message::Control) - return to_variant(std::move(*message)); - - auto raw = reinterpret_cast(message->data()); - if (!message->empty() && raw[0] == MESSAGE_CLOSE) - remoteClose(); - - mRecvQueue.tryPop(); - } - - return nullopt; -} - -bool DataChannel::isOpen(void) const { return mIsOpen; } - -bool DataChannel::isClosed(void) const { return mIsClosed; } - -size_t DataChannel::maxMessageSize() const { - size_t remoteMax = DEFAULT_MAX_MESSAGE_SIZE; - if (auto pc = mPeerConnection.lock()) - if (auto description = pc->remoteDescription()) - if (auto *application = description->application()) - if (auto maxMessageSize = application->maxMessageSize()) - remoteMax = *maxMessageSize > 0 ? *maxMessageSize : LOCAL_MAX_MESSAGE_SIZE; - - return std::min(remoteMax, LOCAL_MAX_MESSAGE_SIZE); -} - -size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); } - -void DataChannel::open(shared_ptr transport) { - { - std::unique_lock lock(mMutex); - mSctpTransport = transport; - } - - if (!mIsOpen.exchange(true)) - triggerOpen(); -} - -void DataChannel::processOpenMessage(message_ptr) { - PLOG_DEBUG << "Received an open message for a user-negotiated DataChannel, ignoring"; - COUNTER_USERNEG_OPEN_MESSAGE++; -} - -bool DataChannel::outgoing(message_ptr message) { - std::shared_ptr transport; - { - std::shared_lock lock(mMutex); - transport = mSctpTransport.lock(); - - if (!transport || mIsClosed) - throw std::runtime_error("DataChannel is closed"); - - if (message->size() > maxMessageSize()) - throw std::runtime_error("Message size exceeds limit"); - - // Before the ACK has been received on a DataChannel, all messages must be sent ordered - message->reliability = mIsOpen ? mReliability : nullptr; - message->stream = mStream; - } - - return transport->send(message); -} - -void DataChannel::incoming(message_ptr message) { - if (!message) - return; - - switch (message->type) { - case Message::Control: { - if (message->size() == 0) - break; // Ignore - auto raw = reinterpret_cast(message->data()); - switch (raw[0]) { - case MESSAGE_OPEN: - processOpenMessage(message); - break; - case MESSAGE_ACK: - if (!mIsOpen.exchange(true)) { - triggerOpen(); - } - break; - case MESSAGE_CLOSE: - // The close message will be processed in-order in receive() - mRecvQueue.push(message); - triggerAvailable(mRecvQueue.size()); - break; - default: - // Ignore - break; - } - break; - } - case Message::String: - case Message::Binary: - mRecvQueue.push(message); - triggerAvailable(mRecvQueue.size()); - break; - default: - // Ignore - break; - } -} - -NegotiatedDataChannel::NegotiatedDataChannel(std::weak_ptr pc, uint16_t stream, - string label, string protocol, Reliability reliability) - : DataChannel(pc, stream, std::move(label), std::move(protocol), std::move(reliability)) {} - -NegotiatedDataChannel::NegotiatedDataChannel(std::weak_ptr pc, - std::weak_ptr transport, - uint16_t stream) - : DataChannel(pc, stream, "", "", {}) { - mSctpTransport = transport; -} - -NegotiatedDataChannel::~NegotiatedDataChannel() {} - -void NegotiatedDataChannel::open(shared_ptr transport) { - std::unique_lock lock(mMutex); - mSctpTransport = transport; - - uint8_t channelType; - uint32_t reliabilityParameter; - switch (mReliability->type) { - case Reliability::Type::Rexmit: - channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT; - reliabilityParameter = uint32_t(std::get(mReliability->rexmit)); - break; - - case Reliability::Type::Timed: - channelType = CHANNEL_PARTIAL_RELIABLE_TIMED; - reliabilityParameter = uint32_t(std::get(mReliability->rexmit).count()); - break; - - default: - channelType = CHANNEL_RELIABLE; - reliabilityParameter = 0; - break; - } - - if (mReliability->unordered) - channelType |= 0x80; - - const size_t len = sizeof(OpenMessage) + mLabel.size() + mProtocol.size(); - binary buffer(len, byte(0)); - auto &open = *reinterpret_cast(buffer.data()); - open.type = MESSAGE_OPEN; - open.channelType = channelType; - open.priority = htons(0); - open.reliabilityParameter = htonl(reliabilityParameter); - open.labelLength = htons(uint16_t(mLabel.size())); - open.protocolLength = htons(uint16_t(mProtocol.size())); - - auto end = reinterpret_cast(buffer.data() + sizeof(OpenMessage)); - std::copy(mLabel.begin(), mLabel.end(), end); - std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size()); - - lock.unlock(); - - transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); -} - -void NegotiatedDataChannel::processOpenMessage(message_ptr message) { - std::unique_lock lock(mMutex); - auto transport = mSctpTransport.lock(); - if (!transport) - throw std::runtime_error("DataChannel has no transport"); - - if (message->size() < sizeof(OpenMessage)) - throw std::invalid_argument("DataChannel open message too small"); - - OpenMessage open = *reinterpret_cast(message->data()); - open.priority = ntohs(open.priority); - open.reliabilityParameter = ntohl(open.reliabilityParameter); - open.labelLength = ntohs(open.labelLength); - open.protocolLength = ntohs(open.protocolLength); - - if (message->size() < sizeof(OpenMessage) + size_t(open.labelLength + open.protocolLength)) - throw std::invalid_argument("DataChannel open message truncated"); - - auto end = reinterpret_cast(message->data() + sizeof(OpenMessage)); - mLabel.assign(end, open.labelLength); - mProtocol.assign(end + open.labelLength, open.protocolLength); - - mReliability->unordered = (open.channelType & 0x80) != 0; - switch (open.channelType & 0x7F) { - case CHANNEL_PARTIAL_RELIABLE_REXMIT: - mReliability->type = Reliability::Type::Rexmit; - mReliability->rexmit = int(open.reliabilityParameter); - break; - case CHANNEL_PARTIAL_RELIABLE_TIMED: - mReliability->type = Reliability::Type::Timed; - mReliability->rexmit = milliseconds(open.reliabilityParameter); - break; - default: - mReliability->type = Reliability::Type::Reliable; - mReliability->rexmit = int(0); - } - - lock.unlock(); - - binary buffer(sizeof(AckMessage), byte(0)); - auto &ack = *reinterpret_cast(buffer.data()); - ack.type = MESSAGE_ACK; - - transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); - - if (!mIsOpen.exchange(true)) - triggerOpen(); + return impl()->outgoing(std::make_shared(data, data + size, Message::Binary)); } } // namespace rtc diff --git a/src/description.cpp b/src/description.cpp index 7e67ae2..4ae7864 100644 --- a/src/description.cpp +++ b/src/description.cpp @@ -28,7 +28,6 @@ #include #include -using std::shared_ptr; using std::chrono::system_clock; namespace { diff --git a/src/globals.hpp b/src/globals.hpp new file mode 100644 index 0000000..8dbf030 --- /dev/null +++ b/src/globals.hpp @@ -0,0 +1,41 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_GLOBALS_H +#define RTC_GLOBALS_H + +#include "common.hpp" + +namespace rtc { + +const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length +const size_t MAX_NUMERICSERV_LEN = 6; // Max port string representation length + +const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default +const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not specified in SDP +const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size + +const size_t RECV_QUEUE_LIMIT = 1024 * 1024; // Max per-channel queue size + +const int THREADPOOL_SIZE = 4; // Number of threads in the global thread pool (>= 2) + +const size_t DEFAULT_IPV4_MTU = 1200; // IPv4 safe MTU value recommended by RFC 8261 + +} // namespace rtc + +#endif diff --git a/src/h264rtppacketizer.cpp b/src/h264rtppacketizer.cpp index 51a141c..6f18baf 100644 --- a/src/h264rtppacketizer.cpp +++ b/src/h264rtppacketizer.cpp @@ -22,9 +22,6 @@ namespace rtc { -using std::make_shared; -using std::shared_ptr; - typedef enum { NUSM_noMatch, NUSM_firstZero, @@ -74,7 +71,7 @@ NalUnitStartSequenceMatch StartSequenceMatchSucc(NalUnitStartSequenceMatch match } shared_ptr H264RtpPacketizer::splitMessage(binary_ptr message) { - auto nalus = make_shared(); + auto nalus = std::make_shared(); if (separator == Separator::Length) { unsigned long long index = 0; while (index < message->size()) { diff --git a/src/base64.cpp b/src/impl/base64.cpp similarity index 97% rename from src/base64.cpp rename to src/impl/base64.cpp index 508a898..faec9f8 100644 --- a/src/base64.cpp +++ b/src/impl/base64.cpp @@ -20,7 +20,7 @@ #include "base64.hpp" -namespace rtc { +namespace rtc::impl { using std::to_integer; @@ -59,6 +59,6 @@ string to_base64(const binary &data) { return out; } -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/base64.hpp b/src/impl/base64.hpp similarity index 90% rename from src/base64.hpp rename to src/impl/base64.hpp index 4c0b8e0..fd3f1a6 100644 --- a/src/base64.hpp +++ b/src/impl/base64.hpp @@ -16,14 +16,14 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_BASE64_H -#define RTC_BASE64_H +#ifndef RTC_IMPL_BASE64_H +#define RTC_IMPL_BASE64_H #if RTC_ENABLE_WEBSOCKET -#include "include.hpp" +#include "common.hpp" -namespace rtc { +namespace rtc::impl { string to_base64(const binary &data); diff --git a/src/certificate.cpp b/src/impl/certificate.cpp similarity index 98% rename from src/certificate.cpp rename to src/impl/certificate.cpp index 9618b57..1807e5d 100644 --- a/src/certificate.cpp +++ b/src/impl/certificate.cpp @@ -26,14 +26,10 @@ #include #include -using std::shared_ptr; -using std::string; -using std::unique_ptr; +namespace rtc::impl { #if USE_GNUTLS -namespace rtc { - Certificate::Certificate(string crt_pem, string key_pem) : mCredentials(gnutls::new_credentials(), gnutls::free_credentials) { @@ -129,12 +125,8 @@ certificate_ptr make_certificate_impl(string commonName) { } // namespace -} // namespace rtc - #else // USE_GNUTLS==0 -namespace rtc { - Certificate::Certificate(string crt_pem, string key_pem) { BIO *bio = BIO_new(BIO_s_mem()); BIO_write(bio, crt_pem.data(), int(crt_pem.size())); @@ -230,14 +222,10 @@ certificate_ptr make_certificate_impl(string commonName) { } // namespace -} // namespace rtc - #endif // Common for GnuTLS and OpenSSL -namespace rtc { - namespace { static std::unordered_map CertificateCache; @@ -262,4 +250,4 @@ void CleanupCertificateCache() { CertificateCache.clear(); } -} // namespace rtc +} // namespace rtc::impl diff --git a/src/certificate.hpp b/src/impl/certificate.hpp similarity index 93% rename from src/certificate.hpp rename to src/impl/certificate.hpp index 4c19dfc..631e8b9 100644 --- a/src/certificate.hpp +++ b/src/impl/certificate.hpp @@ -16,16 +16,16 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_CERTIFICATE_H -#define RTC_CERTIFICATE_H +#ifndef RTC_IMPL_CERTIFICATE_H +#define RTC_IMPL_CERTIFICATE_H -#include "include.hpp" +#include "common.hpp" #include "tls.hpp" #include #include -namespace rtc { +namespace rtc::impl { class Certificate { public: @@ -65,6 +65,6 @@ future_certificate_ptr make_certificate(string commonName = "libdatachannel"); / void CleanupCertificateCache(); -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/impl/channel.cpp b/src/impl/channel.cpp new file mode 100644 index 0000000..d56c30f --- /dev/null +++ b/src/impl/channel.cpp @@ -0,0 +1,57 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#include "channel.hpp" + +namespace rtc::impl { + +void Channel::triggerOpen() { openCallback(); } + +void Channel::triggerClosed() { closedCallback(); } + +void Channel::triggerError(string error) { errorCallback(error); } + +void Channel::triggerAvailable(size_t count) { + if (count == 1) + availableCallback(); + + while (messageCallback && count--) { + auto message = receive(); + if (!message) + break; + messageCallback(*message); + } +} + +void Channel::triggerBufferedAmount(size_t amount) { + size_t previous = bufferedAmount.exchange(amount); + size_t threshold = bufferedAmountLowThreshold.load(); + if (previous > threshold && amount <= threshold) + bufferedAmountLowCallback(); +} + +void Channel::resetCallbacks() { + openCallback = nullptr; + closedCallback = nullptr; + errorCallback = nullptr; + messageCallback = nullptr; + availableCallback = nullptr; + bufferedAmountLowCallback = nullptr; +} + +} // namespace rtc::impl diff --git a/src/impl/channel.hpp b/src/impl/channel.hpp new file mode 100644 index 0000000..87ca7e0 --- /dev/null +++ b/src/impl/channel.hpp @@ -0,0 +1,57 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_IMPL_CHANNEL_H +#define RTC_IMPL_CHANNEL_H + +#include "common.hpp" +#include "message.hpp" + +#include +#include +#include + +namespace rtc::impl { + +struct Channel { + virtual std::optional receive() = 0; + virtual std::optional peek() = 0; + virtual size_t availableAmount() const = 0; + + virtual void triggerOpen(); + virtual void triggerClosed(); + virtual void triggerError(string error); + virtual void triggerAvailable(size_t count); + virtual void triggerBufferedAmount(size_t amount); + + virtual void resetCallbacks(); + + synchronized_callback<> openCallback; + synchronized_callback<> closedCallback; + synchronized_callback errorCallback; + synchronized_callback messageCallback; + synchronized_callback<> availableCallback; + synchronized_callback<> bufferedAmountLowCallback; + + std::atomic bufferedAmount = 0; + std::atomic bufferedAmountLowThreshold = 0; +}; + +} // namespace rtc::impl + +#endif diff --git a/src/impl/datachannel.cpp b/src/impl/datachannel.cpp new file mode 100644 index 0000000..409c16a --- /dev/null +++ b/src/impl/datachannel.cpp @@ -0,0 +1,366 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#include "datachannel.hpp" +#include "globals.hpp" +#include "common.hpp" +#include "logcounter.hpp" +#include "peerconnection.hpp" +#include "sctptransport.hpp" + +#include "rtc/datachannel.hpp" +#include "rtc/track.hpp" + +#ifdef _WIN32 +#include +#else +#include +#endif + +using std::chrono::milliseconds; + +namespace rtc::impl { + +// Messages for the DataChannel establishment protocol +// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09 + +enum MessageType : uint8_t { + MESSAGE_OPEN_REQUEST = 0x00, + MESSAGE_OPEN_RESPONSE = 0x01, + MESSAGE_ACK = 0x02, + MESSAGE_OPEN = 0x03, + MESSAGE_CLOSE = 0x04 +}; + +enum ChannelType : uint8_t { + CHANNEL_RELIABLE = 0x00, + CHANNEL_PARTIAL_RELIABLE_REXMIT = 0x01, + CHANNEL_PARTIAL_RELIABLE_TIMED = 0x02 +}; + +#pragma pack(push, 1) +struct OpenMessage { + uint8_t type = MESSAGE_OPEN; + uint8_t channelType; + uint16_t priority; + uint32_t reliabilityParameter; + uint16_t labelLength; + uint16_t protocolLength; + // The following fields are: + // uint8_t[labelLength] label + // uint8_t[protocolLength] protocol +}; + +struct AckMessage { + uint8_t type = MESSAGE_ACK; +}; + +struct CloseMessage { + uint8_t type = MESSAGE_CLOSE; +}; +#pragma pack(pop) + +LogCounter COUNTER_USERNEG_OPEN_MESSAGE( + plog::warning, "Number of open messages for a user-negotiated DataChannel received"); + +DataChannel::DataChannel(weak_ptr pc, uint16_t stream, string label, + string protocol, Reliability reliability) + : mPeerConnection(pc), mStream(stream), mLabel(std::move(label)), + mProtocol(std::move(protocol)), + mReliability(std::make_shared(std::move(reliability))), + mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {} + +DataChannel::~DataChannel() { close(); } + +void DataChannel::close() { + std::shared_ptr transport; + { + std::shared_lock lock(mMutex); + transport = mSctpTransport.lock(); + } + + mIsClosed = true; + if (mIsOpen.exchange(false) && transport) + transport->closeStream(mStream); + + resetCallbacks(); +} + +void DataChannel::remoteClose() { + if (!mIsClosed.exchange(true)) + triggerClosed(); + + mIsOpen = false; +} + +std::optional DataChannel::receive() { + while (auto next = mRecvQueue.tryPop()) { + message_ptr message = *next; + if (message->type != Message::Control) + return to_variant(std::move(*message)); + + auto raw = reinterpret_cast(message->data()); + if (!message->empty() && raw[0] == MESSAGE_CLOSE) + remoteClose(); + } + + return nullopt; +} + +std::optional DataChannel::peek() { + while (auto next = mRecvQueue.peek()) { + message_ptr message = *next; + if (message->type != Message::Control) + return to_variant(std::move(*message)); + + auto raw = reinterpret_cast(message->data()); + if (!message->empty() && raw[0] == MESSAGE_CLOSE) + remoteClose(); + + mRecvQueue.tryPop(); + } + + return nullopt; +} + +size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); } + +uint16_t DataChannel::stream() const { + std::shared_lock lock(mMutex); + return mStream; +} + +string DataChannel::label() const { + std::shared_lock lock(mMutex); + return mLabel; +} + +string DataChannel::protocol() const { + std::shared_lock lock(mMutex); + return mProtocol; +} + +Reliability DataChannel::reliability() const { + std::shared_lock lock(mMutex); + return *mReliability; +} + +bool DataChannel::isOpen(void) const { return mIsOpen; } + +bool DataChannel::isClosed(void) const { return mIsClosed; } + +size_t DataChannel::maxMessageSize() const { + size_t remoteMax = DEFAULT_MAX_MESSAGE_SIZE; + if (auto pc = mPeerConnection.lock()) + if (auto description = pc->remoteDescription()) + if (auto *application = description->application()) + if (auto maxMessageSize = application->maxMessageSize()) + remoteMax = *maxMessageSize > 0 ? *maxMessageSize : LOCAL_MAX_MESSAGE_SIZE; + + return std::min(remoteMax, LOCAL_MAX_MESSAGE_SIZE); +} + +void DataChannel::shiftStream() { + if (mStream % 2 == 1) + mStream -= 1; +} + +void DataChannel::open(shared_ptr transport) { + { + std::unique_lock lock(mMutex); + mSctpTransport = transport; + } + + if (!mIsOpen.exchange(true)) + triggerOpen(); +} + +void DataChannel::processOpenMessage(message_ptr) { + PLOG_DEBUG << "Received an open message for a user-negotiated DataChannel, ignoring"; + COUNTER_USERNEG_OPEN_MESSAGE++; +} + +bool DataChannel::outgoing(message_ptr message) { + std::shared_ptr transport; + { + std::shared_lock lock(mMutex); + transport = mSctpTransport.lock(); + + if (!transport || mIsClosed) + throw std::runtime_error("DataChannel is closed"); + + if (message->size() > maxMessageSize()) + throw std::runtime_error("Message size exceeds limit"); + + // Before the ACK has been received on a DataChannel, all messages must be sent ordered + message->reliability = mIsOpen ? mReliability : nullptr; + message->stream = mStream; + } + + return transport->send(message); +} + +void DataChannel::incoming(message_ptr message) { + if (!message) + return; + + switch (message->type) { + case Message::Control: { + if (message->size() == 0) + break; // Ignore + auto raw = reinterpret_cast(message->data()); + switch (raw[0]) { + case MESSAGE_OPEN: + processOpenMessage(message); + break; + case MESSAGE_ACK: + if (!mIsOpen.exchange(true)) { + triggerOpen(); + } + break; + case MESSAGE_CLOSE: + // The close message will be processed in-order in receive() + mRecvQueue.push(message); + triggerAvailable(mRecvQueue.size()); + break; + default: + // Ignore + break; + } + break; + } + case Message::String: + case Message::Binary: + mRecvQueue.push(message); + triggerAvailable(mRecvQueue.size()); + break; + default: + // Ignore + break; + } +} + +NegotiatedDataChannel::NegotiatedDataChannel(std::weak_ptr pc, + uint16_t stream, string label, string protocol, + Reliability reliability) + : DataChannel(pc, stream, std::move(label), std::move(protocol), std::move(reliability)) {} + +NegotiatedDataChannel::NegotiatedDataChannel(std::weak_ptr pc, + std::weak_ptr transport, + uint16_t stream) + : DataChannel(pc, stream, "", "", {}) { + mSctpTransport = transport; +} + +NegotiatedDataChannel::~NegotiatedDataChannel() {} + +void NegotiatedDataChannel::open(shared_ptr transport) { + std::unique_lock lock(mMutex); + mSctpTransport = transport; + + uint8_t channelType; + uint32_t reliabilityParameter; + switch (mReliability->type) { + case Reliability::Type::Rexmit: + channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT; + reliabilityParameter = uint32_t(std::get(mReliability->rexmit)); + break; + + case Reliability::Type::Timed: + channelType = CHANNEL_PARTIAL_RELIABLE_TIMED; + reliabilityParameter = uint32_t(std::get(mReliability->rexmit).count()); + break; + + default: + channelType = CHANNEL_RELIABLE; + reliabilityParameter = 0; + break; + } + + if (mReliability->unordered) + channelType |= 0x80; + + const size_t len = sizeof(OpenMessage) + mLabel.size() + mProtocol.size(); + binary buffer(len, byte(0)); + auto &open = *reinterpret_cast(buffer.data()); + open.type = MESSAGE_OPEN; + open.channelType = channelType; + open.priority = htons(0); + open.reliabilityParameter = htonl(reliabilityParameter); + open.labelLength = htons(uint16_t(mLabel.size())); + open.protocolLength = htons(uint16_t(mProtocol.size())); + + auto end = reinterpret_cast(buffer.data() + sizeof(OpenMessage)); + std::copy(mLabel.begin(), mLabel.end(), end); + std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size()); + + lock.unlock(); + + transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); +} + +void NegotiatedDataChannel::processOpenMessage(message_ptr message) { + std::unique_lock lock(mMutex); + auto transport = mSctpTransport.lock(); + if (!transport) + throw std::runtime_error("DataChannel has no transport"); + + if (message->size() < sizeof(OpenMessage)) + throw std::invalid_argument("DataChannel open message too small"); + + OpenMessage open = *reinterpret_cast(message->data()); + open.priority = ntohs(open.priority); + open.reliabilityParameter = ntohl(open.reliabilityParameter); + open.labelLength = ntohs(open.labelLength); + open.protocolLength = ntohs(open.protocolLength); + + if (message->size() < sizeof(OpenMessage) + size_t(open.labelLength + open.protocolLength)) + throw std::invalid_argument("DataChannel open message truncated"); + + auto end = reinterpret_cast(message->data() + sizeof(OpenMessage)); + mLabel.assign(end, open.labelLength); + mProtocol.assign(end + open.labelLength, open.protocolLength); + + mReliability->unordered = (open.channelType & 0x80) != 0; + switch (open.channelType & 0x7F) { + case CHANNEL_PARTIAL_RELIABLE_REXMIT: + mReliability->type = Reliability::Type::Rexmit; + mReliability->rexmit = int(open.reliabilityParameter); + break; + case CHANNEL_PARTIAL_RELIABLE_TIMED: + mReliability->type = Reliability::Type::Timed; + mReliability->rexmit = milliseconds(open.reliabilityParameter); + break; + default: + mReliability->type = Reliability::Type::Reliable; + mReliability->rexmit = int(0); + } + + lock.unlock(); + + binary buffer(sizeof(AckMessage), byte(0)); + auto &ack = *reinterpret_cast(buffer.data()); + ack.type = MESSAGE_ACK; + + transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); + + if (!mIsOpen.exchange(true)) + triggerOpen(); +} + +} // namespace rtc::impl diff --git a/src/impl/datachannel.hpp b/src/impl/datachannel.hpp new file mode 100644 index 0000000..4d8eb92 --- /dev/null +++ b/src/impl/datachannel.hpp @@ -0,0 +1,94 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_IMPL_DATA_CHANNEL_H +#define RTC_IMPL_DATA_CHANNEL_H + +#include "channel.hpp" +#include "common.hpp" +#include "message.hpp" +#include "peerconnection.hpp" +#include "queue.hpp" +#include "reliability.hpp" +#include "sctptransport.hpp" + +#include + +namespace rtc::impl { + +struct PeerConnection; + +struct DataChannel : Channel, std::enable_shared_from_this { + DataChannel(weak_ptr pc, uint16_t stream, string label, string protocol, + Reliability reliability); + ~DataChannel(); + + void close(); + void remoteClose(); + bool outgoing(message_ptr message); + void incoming(message_ptr message); + + std::optional receive() override; + std::optional peek() override; + size_t availableAmount() const override; + + uint16_t stream() const; + string label() const; + string protocol() const; + Reliability reliability() const; + + bool isOpen(void) const; + bool isClosed(void) const; + size_t maxMessageSize() const; + + void shiftStream(); + + virtual void open(shared_ptr transport); + virtual void processOpenMessage(message_ptr); + +protected: + const weak_ptr mPeerConnection; + weak_ptr mSctpTransport; + + uint16_t mStream; + string mLabel; + string mProtocol; + shared_ptr mReliability; + + mutable std::shared_mutex mMutex; + + Queue mRecvQueue; + + std::atomic mIsOpen = false; + std::atomic mIsClosed = false; +}; + +struct NegotiatedDataChannel final : public DataChannel { + NegotiatedDataChannel(weak_ptr pc, uint16_t stream, string label, + string protocol, Reliability reliability); + NegotiatedDataChannel(weak_ptr pc, + weak_ptr transport, uint16_t stream); + ~NegotiatedDataChannel(); + + void open(impl_ptr transport) override; + void processOpenMessage(message_ptr message) override; +}; + +} // namespace rtc::impl + +#endif diff --git a/src/dtlssrtptransport.cpp b/src/impl/dtlssrtptransport.cpp similarity index 99% rename from src/dtlssrtptransport.cpp rename to src/impl/dtlssrtptransport.cpp index df6292f..2704ef6 100644 --- a/src/dtlssrtptransport.cpp +++ b/src/impl/dtlssrtptransport.cpp @@ -19,17 +19,17 @@ #include "dtlssrtptransport.hpp" #include "logcounter.hpp" #include "tls.hpp" +#include "rtp.hpp" #if RTC_ENABLE_MEDIA #include #include -using std::shared_ptr; using std::to_integer; using std::to_string; -namespace rtc { +namespace rtc::impl { static LogCounter COUNTER_MEDIA_TRUNCATED(plog::warning, "Number of truncated SRT(C)P packets received"); @@ -48,7 +48,7 @@ static LogCounter COUNTER_SRTP_REPLAY(plog::warning, "Number of SRTP replay pack static LogCounter COUNTER_SRTP_AUTH_FAIL(plog::warning, "Number of SRTP packets received that failed authentication checks"); -static rtc::LogCounter +static LogCounter COUNTER_SRTP_FAIL(plog::warning, "Number of SRTP packets received that had an unknown libSRTP failure"); diff --git a/src/dtlssrtptransport.hpp b/src/impl/dtlssrtptransport.hpp similarity index 93% rename from src/dtlssrtptransport.hpp rename to src/impl/dtlssrtptransport.hpp index e93daca..631b91a 100644 --- a/src/dtlssrtptransport.hpp +++ b/src/impl/dtlssrtptransport.hpp @@ -16,11 +16,11 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_DTLS_SRTP_TRANSPORT_H -#define RTC_DTLS_SRTP_TRANSPORT_H +#ifndef RTC_IMPL_DTLS_SRTP_TRANSPORT_H +#define RTC_IMPL_DTLS_SRTP_TRANSPORT_H #include "dtlstransport.hpp" -#include "include.hpp" +#include "common.hpp" #if RTC_ENABLE_MEDIA @@ -32,7 +32,7 @@ #include -namespace rtc { +namespace rtc::impl { class DtlsSrtpTransport final : public DtlsTransport { public: diff --git a/src/dtlstransport.cpp b/src/impl/dtlstransport.cpp similarity index 99% rename from src/dtlstransport.cpp rename to src/impl/dtlstransport.cpp index fc0d7f8..bd64c79 100644 --- a/src/dtlstransport.cpp +++ b/src/impl/dtlstransport.cpp @@ -17,6 +17,7 @@ */ #include "dtlstransport.hpp" +#include "globals.hpp" #include "icetransport.hpp" #include @@ -34,12 +35,7 @@ using namespace std::chrono; -using std::shared_ptr; -using std::string; -using std::unique_ptr; -using std::weak_ptr; - -namespace rtc { +namespace rtc::impl { #if USE_GNUTLS @@ -597,4 +593,4 @@ long DtlsTransport::BioMethodCtrl(BIO * /*bio*/, int cmd, long /*num*/, void * / #endif -} // namespace rtc +} // namespace rtc::impl diff --git a/src/dtlstransport.hpp b/src/impl/dtlstransport.hpp similarity index 95% rename from src/dtlstransport.hpp rename to src/impl/dtlstransport.hpp index c31f58b..dd3dcd5 100644 --- a/src/dtlstransport.hpp +++ b/src/impl/dtlstransport.hpp @@ -16,12 +16,11 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_DTLS_TRANSPORT_H -#define RTC_DTLS_TRANSPORT_H +#ifndef RTC_IMPL_DTLS_TRANSPORT_H +#define RTC_IMPL_DTLS_TRANSPORT_H #include "certificate.hpp" -#include "include.hpp" -#include "peerconnection.hpp" +#include "common.hpp" #include "queue.hpp" #include "tls.hpp" #include "transport.hpp" @@ -32,7 +31,7 @@ #include #include -namespace rtc { +namespace rtc::impl { class IceTransport; diff --git a/src/icetransport.cpp b/src/impl/icetransport.cpp similarity index 97% rename from src/icetransport.cpp rename to src/impl/icetransport.cpp index 0037891..68d04dd 100644 --- a/src/icetransport.cpp +++ b/src/impl/icetransport.cpp @@ -18,6 +18,7 @@ #include "icetransport.hpp" #include "configuration.hpp" +#include "globals.hpp" #include "transport.hpp" #include @@ -37,17 +38,14 @@ #include using namespace std::chrono_literals; - -using std::shared_ptr; -using std::weak_ptr; using std::chrono::system_clock; +namespace rtc::impl { + #if !USE_NICE #define MAX_TURN_SERVERS_COUNT 2 -namespace rtc { - IceTransport::IceTransport(const Configuration &config, candidate_callback candidateCallback, state_callback stateChangeCallback, gathering_state_callback gatheringStateChangeCallback) @@ -282,7 +280,7 @@ void IceTransport::processCandidate(const string &candidate) { void IceTransport::processGatheringDone() { changeGatheringState(GatheringState::Complete); } void IceTransport::StateChangeCallback(juice_agent_t *, juice_state_t state, void *user_ptr) { - auto iceTransport = static_cast(user_ptr); + auto iceTransport = static_cast(user_ptr); try { iceTransport->processStateChange(static_cast(state)); } catch (const std::exception &e) { @@ -291,7 +289,7 @@ void IceTransport::StateChangeCallback(juice_agent_t *, juice_state_t state, voi } void IceTransport::CandidateCallback(juice_agent_t *, const char *sdp, void *user_ptr) { - auto iceTransport = static_cast(user_ptr); + auto iceTransport = static_cast(user_ptr); try { iceTransport->processCandidate(sdp); } catch (const std::exception &e) { @@ -300,7 +298,7 @@ void IceTransport::CandidateCallback(juice_agent_t *, const char *sdp, void *use } void IceTransport::GatheringDoneCallback(juice_agent_t *, void *user_ptr) { - auto iceTransport = static_cast(user_ptr); + auto iceTransport = static_cast(user_ptr); try { iceTransport->processGatheringDone(); } catch (const std::exception &e) { @@ -309,7 +307,7 @@ void IceTransport::GatheringDoneCallback(juice_agent_t *, void *user_ptr) { } void IceTransport::RecvCallback(juice_agent_t *, const char *data, size_t size, void *user_ptr) { - auto iceTransport = static_cast(user_ptr); + auto iceTransport = static_cast(user_ptr); try { PLOG_VERBOSE << "Incoming size=" << size; auto b = reinterpret_cast(data); @@ -341,12 +339,8 @@ void IceTransport::LogCallback(juice_log_level_t level, const char *message) { PLOG(severity) << "juice: " << message; } -} // namespace rtc - #else // USE_NICE == 1 -namespace rtc { - IceTransport::IceTransport(const Configuration &config, candidate_callback candidateCallback, state_callback stateChangeCallback, gathering_state_callback gatheringStateChangeCallback) @@ -717,7 +711,7 @@ string IceTransport::AddressToString(const NiceAddress &addr) { void IceTransport::CandidateCallback(NiceAgent *agent, NiceCandidate *candidate, gpointer userData) { - auto iceTransport = static_cast(userData); + auto iceTransport = static_cast(userData); gchar *cand = nice_agent_generate_local_candidate_sdp(agent, candidate); try { iceTransport->processCandidate(cand); @@ -729,7 +723,7 @@ void IceTransport::CandidateCallback(NiceAgent *agent, NiceCandidate *candidate, void IceTransport::GatheringDoneCallback(NiceAgent * /*agent*/, guint /*streamId*/, gpointer userData) { - auto iceTransport = static_cast(userData); + auto iceTransport = static_cast(userData); try { iceTransport->processGatheringDone(); } catch (const std::exception &e) { @@ -739,7 +733,7 @@ void IceTransport::GatheringDoneCallback(NiceAgent * /*agent*/, guint /*streamId void IceTransport::StateChangeCallback(NiceAgent * /*agent*/, guint /*streamId*/, guint /*componentId*/, guint state, gpointer userData) { - auto iceTransport = static_cast(userData); + auto iceTransport = static_cast(userData); try { iceTransport->processStateChange(state); } catch (const std::exception &e) { @@ -749,7 +743,7 @@ void IceTransport::StateChangeCallback(NiceAgent * /*agent*/, guint /*streamId*/ void IceTransport::RecvCallback(NiceAgent * /*agent*/, guint /*streamId*/, guint /*componentId*/, guint len, gchar *buf, gpointer userData) { - auto iceTransport = static_cast(userData); + auto iceTransport = static_cast(userData); try { PLOG_VERBOSE << "Incoming size=" << len; auto b = reinterpret_cast(buf); @@ -760,7 +754,7 @@ void IceTransport::RecvCallback(NiceAgent * /*agent*/, guint /*streamId*/, guint } gboolean IceTransport::TimeoutCallback(gpointer userData) { - auto iceTransport = static_cast(userData); + auto iceTransport = static_cast(userData); try { iceTransport->processTimeout(); } catch (const std::exception &e) { @@ -811,6 +805,6 @@ bool IceTransport::getSelectedCandidatePair(Candidate *local, Candidate *remote) return true; } -} // namespace rtc - #endif + +} // namespace rtc::impl diff --git a/src/icetransport.hpp b/src/impl/icetransport.hpp similarity index 97% rename from src/icetransport.hpp rename to src/impl/icetransport.hpp index 893b414..53fa784 100644 --- a/src/icetransport.hpp +++ b/src/impl/icetransport.hpp @@ -16,13 +16,13 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_ICE_TRANSPORT_H -#define RTC_ICE_TRANSPORT_H +#ifndef RTC_IMPL_ICE_TRANSPORT_H +#define RTC_IMPL_ICE_TRANSPORT_H #include "candidate.hpp" #include "configuration.hpp" #include "description.hpp" -#include "include.hpp" +#include "common.hpp" #include "peerconnection.hpp" #include "transport.hpp" @@ -37,7 +37,7 @@ #include #include -namespace rtc { +namespace rtc::impl { class IceTransport : public Transport { public: diff --git a/src/logcounter.cpp b/src/impl/logcounter.cpp similarity index 98% rename from src/logcounter.cpp rename to src/impl/logcounter.cpp index dd226f5..b55af83 100644 --- a/src/logcounter.cpp +++ b/src/impl/logcounter.cpp @@ -18,7 +18,7 @@ #include "logcounter.hpp" -namespace rtc { +namespace rtc::impl { LogCounter::LogCounter(plog::Severity severity, const std::string &text, std::chrono::seconds duration) { diff --git a/src/logcounter.hpp b/src/impl/logcounter.hpp similarity index 96% rename from src/logcounter.hpp rename to src/impl/logcounter.hpp index cd0e0bd..3b494b6 100644 --- a/src/logcounter.hpp +++ b/src/impl/logcounter.hpp @@ -19,13 +19,13 @@ #ifndef RTC_SERVER_LOGCOUNTER_HPP #define RTC_SERVER_LOGCOUNTER_HPP -#include "include.hpp" +#include "common.hpp" #include "threadpool.hpp" #include #include -namespace rtc { +namespace rtc::impl { class LogCounter { private: diff --git a/src/impl/peerconnection.cpp b/src/impl/peerconnection.cpp new file mode 100644 index 0000000..106e278 --- /dev/null +++ b/src/impl/peerconnection.cpp @@ -0,0 +1,1011 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#include "peerconnection.hpp" +#include "certificate.hpp" +#include "dtlstransport.hpp" +#include "globals.hpp" +#include "icetransport.hpp" +#include "common.hpp" +#include "logcounter.hpp" +#include "peerconnection.hpp" +#include "processor.hpp" +#include "rtp.hpp" +#include "sctptransport.hpp" +#include "threadpool.hpp" + +#if RTC_ENABLE_MEDIA +#include "dtlssrtptransport.hpp" +#endif + +#include +#include +#include + +using namespace std::placeholders; + +#if __clang__ && defined(__APPLE__) +namespace { +template +inline std::shared_ptr reinterpret_pointer_cast(std::shared_ptr const &ptr) noexcept { + return std::shared_ptr(ptr, reinterpret_cast(ptr.get())); +} +} // namespace +#else +using std::reinterpret_pointer_cast; +#endif + +namespace rtc::impl { + +static LogCounter COUNTER_MEDIA_TRUNCATED(plog::warning, + "Number of RTP packets truncated over past second"); +static LogCounter COUNTER_SRTP_DECRYPT_ERROR(plog::warning, + "Number of SRTP decryption errors over past second"); +static LogCounter COUNTER_SRTP_ENCRYPT_ERROR(plog::warning, + "Number of SRTP encryption errors over past second"); +static LogCounter + COUNTER_UNKNOWN_PACKET_TYPE(plog::warning, + "Number of unknown RTCP packet types over past second"); + +PeerConnection::PeerConnection(Configuration config_) + : config(std::move(config_)), mCertificate(make_certificate()), + mProcessor(std::make_unique()) { + PLOG_VERBOSE << "Creating PeerConnection"; + + if (config.portRangeEnd && config.portRangeBegin > config.portRangeEnd) + throw std::invalid_argument("Invalid port range"); + + if (config.mtu) { + if (*config.mtu < 576) // Min MTU for IPv4 + throw std::invalid_argument("Invalid MTU value"); + + if (*config.mtu > 1500) { // Standard Ethernet + PLOG_WARNING << "MTU set to " << *config.mtu; + } else { + PLOG_VERBOSE << "MTU set to " << *config.mtu; + } + } +} + +PeerConnection::~PeerConnection() { + PLOG_VERBOSE << "Destroying PeerConnection"; + mProcessor->join(); +} + +void PeerConnection::close() { + PLOG_VERBOSE << "Closing PeerConnection"; + + negotiationNeeded = false; + + // Close data channels asynchronously + mProcessor->enqueue(&PeerConnection::closeDataChannels, this); + + closeTransports(); +} + +std::optional PeerConnection::localDescription() const { + std::lock_guard lock(mLocalDescriptionMutex); + return mLocalDescription; +} + +std::optional PeerConnection::remoteDescription() const { + std::lock_guard lock(mRemoteDescriptionMutex); + return mRemoteDescription; +} + +shared_ptr PeerConnection::initIceTransport() { + try { + if (auto transport = std::atomic_load(&mIceTransport)) + return transport; + + PLOG_VERBOSE << "Starting ICE transport"; + + auto transport = std::make_shared( + config, weak_bind(&PeerConnection::processLocalCandidate, this, _1), + [this, weak_this = weak_from_this()](IceTransport::State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (transportState) { + case IceTransport::State::Connecting: + changeState(State::Connecting); + break; + case IceTransport::State::Failed: + changeState(State::Failed); + break; + case IceTransport::State::Connected: + initDtlsTransport(); + break; + case IceTransport::State::Disconnected: + changeState(State::Disconnected); + break; + default: + // Ignore + break; + } + }, + [this, weak_this = weak_from_this()](IceTransport::GatheringState gatheringState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (gatheringState) { + case IceTransport::GatheringState::InProgress: + changeGatheringState(GatheringState::InProgress); + break; + case IceTransport::GatheringState::Complete: + endLocalCandidates(); + changeGatheringState(GatheringState::Complete); + break; + default: + // Ignore + break; + } + }); + + std::atomic_store(&mIceTransport, transport); + if (state.load() == State::Closed) { + mIceTransport.reset(); + throw std::runtime_error("Connection is closed"); + } + transport->start(); + return transport; + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + changeState(State::Failed); + throw std::runtime_error("ICE transport initialization failed"); + } +} + +shared_ptr PeerConnection::initDtlsTransport() { + try { + if (auto transport = std::atomic_load(&mDtlsTransport)) + return transport; + + PLOG_VERBOSE << "Starting DTLS transport"; + + auto certificate = mCertificate.get(); + auto lower = std::atomic_load(&mIceTransport); + auto verifierCallback = weak_bind(&PeerConnection::checkFingerprint, this, _1); + auto stateChangeCallback = + [this, weak_this = weak_from_this()](DtlsTransport::State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + + switch (transportState) { + case DtlsTransport::State::Connected: + if (auto remote = remoteDescription(); remote && remote->hasApplication()) + initSctpTransport(); + else + changeState(State::Connected); + + mProcessor->enqueue(&PeerConnection::openTracks, this); + break; + case DtlsTransport::State::Failed: + changeState(State::Failed); + break; + case DtlsTransport::State::Disconnected: + changeState(State::Disconnected); + break; + default: + // Ignore + break; + } + }; + + shared_ptr transport; + if (auto local = localDescription(); local && local->hasAudioOrVideo()) { +#if RTC_ENABLE_MEDIA + PLOG_INFO << "This connection requires media support"; + + // DTLS-SRTP + transport = std::make_shared( + lower, certificate, config.mtu, verifierCallback, + weak_bind(&PeerConnection::forwardMedia, this, _1), stateChangeCallback); +#else + PLOG_WARNING << "Ignoring media support (not compiled with media support)"; +#endif + } + + if (!transport) { + // DTLS only + transport = std::make_shared(lower, certificate, config.mtu, + verifierCallback, stateChangeCallback); + } + + std::atomic_store(&mDtlsTransport, transport); + if (state.load() == State::Closed) { + mDtlsTransport.reset(); + throw std::runtime_error("Connection is closed"); + } + transport->start(); + return transport; + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + changeState(State::Failed); + throw std::runtime_error("DTLS transport initialization failed"); + } +} + +shared_ptr PeerConnection::initSctpTransport() { + try { + if (auto transport = std::atomic_load(&mSctpTransport)) + return transport; + + PLOG_VERBOSE << "Starting SCTP transport"; + + auto remote = remoteDescription(); + if (!remote || !remote->application()) + throw std::logic_error("Starting SCTP transport without application description"); + + uint16_t sctpPort = remote->application()->sctpPort().value_or(DEFAULT_SCTP_PORT); + auto lower = std::atomic_load(&mDtlsTransport); + auto transport = std::make_shared( + lower, sctpPort, config.mtu, weak_bind(&PeerConnection::forwardMessage, this, _1), + weak_bind(&PeerConnection::forwardBufferedAmount, this, _1, _2), + [this, weak_this = weak_from_this()](SctpTransport::State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (transportState) { + case SctpTransport::State::Connected: + changeState(State::Connected); + mProcessor->enqueue(&PeerConnection::openDataChannels, this); + break; + case SctpTransport::State::Failed: + LOG_WARNING << "SCTP transport failed"; + changeState(State::Failed); + mProcessor->enqueue(&PeerConnection::remoteCloseDataChannels, this); + break; + case SctpTransport::State::Disconnected: + changeState(State::Disconnected); + mProcessor->enqueue(&PeerConnection::remoteCloseDataChannels, this); + break; + default: + // Ignore + break; + } + }); + + std::atomic_store(&mSctpTransport, transport); + if (state.load() == State::Closed) { + mSctpTransport.reset(); + throw std::runtime_error("Connection is closed"); + } + transport->start(); + return transport; + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + changeState(State::Failed); + throw std::runtime_error("SCTP transport initialization failed"); + } +} + +std::shared_ptr PeerConnection::getIceTransport() const { + return std::atomic_load(&mIceTransport); +} + +std::shared_ptr PeerConnection::getDtlsTransport() const { + return std::atomic_load(&mDtlsTransport); +} + +std::shared_ptr PeerConnection::getSctpTransport() const { + return std::atomic_load(&mSctpTransport); +} + +void PeerConnection::closeTransports() { + PLOG_VERBOSE << "Closing transports"; + + // Change state to sink state Closed + if (!changeState(State::Closed)) + return; // already closed + + // Reset callbacks now that state is changed + resetCallbacks(); + + // Initiate transport stop on the processor after closing the data channels + mProcessor->enqueue([this]() { + // Pass the pointers to a thread + auto sctp = std::atomic_exchange(&mSctpTransport, decltype(mSctpTransport)(nullptr)); + auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr)); + auto ice = std::atomic_exchange(&mIceTransport, decltype(mIceTransport)(nullptr)); + ThreadPool::Instance().enqueue([sctp, dtls, ice]() mutable { + if (sctp) + sctp->stop(); + if (dtls) + dtls->stop(); + if (ice) + ice->stop(); + + sctp.reset(); + dtls.reset(); + ice.reset(); + }); + }); +} + +void PeerConnection::endLocalCandidates() { + std::lock_guard lock(mLocalDescriptionMutex); + if (mLocalDescription) + mLocalDescription->endCandidates(); +} + +void PeerConnection::rollbackLocalDescription() { + PLOG_DEBUG << "Rolling back pending local description"; + + std::unique_lock lock(mLocalDescriptionMutex); + if (mCurrentLocalDescription) { + std::vector existingCandidates; + if (mLocalDescription) + existingCandidates = mLocalDescription->extractCandidates(); + + mLocalDescription.emplace(std::move(*mCurrentLocalDescription)); + mLocalDescription->addCandidates(std::move(existingCandidates)); + mCurrentLocalDescription.reset(); + } +} + +bool PeerConnection::checkFingerprint(const std::string &fingerprint) const { + std::lock_guard lock(mRemoteDescriptionMutex); + if (auto expectedFingerprint = + mRemoteDescription ? mRemoteDescription->fingerprint() : nullopt) { + return *expectedFingerprint == fingerprint; + } + return false; +} + +void PeerConnection::forwardMessage(message_ptr message) { + if (!message) { + remoteCloseDataChannels(); + return; + } + + uint16_t stream = uint16_t(message->stream); + auto channel = findDataChannel(stream); + if (!channel) { + auto iceTransport = getIceTransport(); + auto sctpTransport = getSctpTransport(); + if (!iceTransport || !sctpTransport) + return; + + const byte dataChannelOpenMessage{0x03}; + uint16_t remoteParity = (iceTransport->role() == Description::Role::Active) ? 1 : 0; + if (message->type == Message::Control && *message->data() == dataChannelOpenMessage && + stream % 2 == remoteParity) { + + channel = + std::make_shared(shared_from_this(), sctpTransport, stream); + channel->openCallback = weak_bind(&PeerConnection::triggerDataChannel, this, + weak_ptr{channel}); + + std::unique_lock lock(mDataChannelsMutex); // we are going to emplace + mDataChannels.emplace(stream, channel); + } else { + // Invalid, close the DataChannel + sctpTransport->closeStream(message->stream); + return; + } + } + + channel->incoming(message); +} + +void PeerConnection::forwardMedia(message_ptr message) { + if (!message) + return; + + // Browsers like to compound their packets with a random SSRC. + // we have to do this monstrosity to distribute the report blocks + if (message->type == Message::Control) { + std::set ssrcs; + size_t offset = 0; + while ((sizeof(rtc::RTCP_HEADER) + offset) <= message->size()) { + auto header = reinterpret_cast(message->data() + offset); + if (header->lengthInBytes() > message->size() - offset) { + COUNTER_MEDIA_TRUNCATED++; + break; + } + offset += header->lengthInBytes(); + if (header->payloadType() == 205 || header->payloadType() == 206) { + auto rtcpfb = reinterpret_cast(header); + ssrcs.insert(rtcpfb->getPacketSenderSSRC()); + ssrcs.insert(rtcpfb->getMediaSourceSSRC()); + + } else if (header->payloadType() == 200 || header->payloadType() == 201) { + auto rtcpsr = reinterpret_cast(header); + ssrcs.insert(rtcpsr->senderSSRC()); + for (int i = 0; i < rtcpsr->header.reportCount(); ++i) + ssrcs.insert(rtcpsr->getReportBlock(i)->getSSRC()); + } else if (header->payloadType() == 202) { + auto sdes = reinterpret_cast(header); + if (!sdes->isValid()) { + PLOG_WARNING << "RTCP SDES packet is invalid"; + continue; + } + for (unsigned int i = 0; i < sdes->chunksCount(); i++) { + auto chunk = sdes->getChunk(i); + ssrcs.insert(chunk->ssrc()); + } + } else { + // PT=207 == Extended Report + if (header->payloadType() != 207) { + COUNTER_UNKNOWN_PACKET_TYPE++; + } + } + } + + if (!ssrcs.empty()) { + for (uint32_t ssrc : ssrcs) { + if (auto mid = getMidFromSsrc(ssrc)) { + std::shared_lock lock(mTracksMutex); // read-only + if (auto it = mTracks.find(*mid); it != mTracks.end()) + if (auto track = it->second.lock()) + track->incoming(message); + } + } + return; + } + } + + uint32_t ssrc = uint32_t(message->stream); + if (auto mid = getMidFromSsrc(ssrc)) { + std::shared_lock lock(mTracksMutex); // read-only + if (auto it = mTracks.find(*mid); it != mTracks.end()) + if (auto track = it->second.lock()) + track->incoming(message); + } else { + /* + * TODO: So the problem is that when stop sending streams, we stop getting report blocks for + * those streams Therefore when we get compound RTCP packets, they are empty, and we can't + * forward them. Therefore, it is expected that we don't know where to forward packets. Is + * this ideal? No! Do I know how to fix it? No! + */ + // PLOG_WARNING << "Track not found for SSRC " << ssrc << ", dropping"; + return; + } +} + +std::optional PeerConnection::getMidFromSsrc(uint32_t ssrc) { + if (auto it = mMidFromSsrc.find(ssrc); it != mMidFromSsrc.end()) + return it->second; + + { + std::lock_guard lock(mRemoteDescriptionMutex); + if (!mRemoteDescription) + return nullopt; + for (unsigned int i = 0; i < mRemoteDescription->mediaCount(); ++i) { + if (auto found = std::visit( + rtc::overloaded{[&](Description::Application *) -> std::optional { + return std::nullopt; + }, + [&](Description::Media *media) -> std::optional { + return media->hasSSRC(ssrc) + ? std::make_optional(media->mid()) + : nullopt; + }}, + mRemoteDescription->media(i))) { + + mMidFromSsrc.emplace(ssrc, *found); + return *found; + } + } + } + { + std::lock_guard lock(mLocalDescriptionMutex); + if (!mLocalDescription) + return nullopt; + for (unsigned int i = 0; i < mLocalDescription->mediaCount(); ++i) { + if (auto found = std::visit( + rtc::overloaded{[&](Description::Application *) -> std::optional { + return std::nullopt; + }, + [&](Description::Media *media) -> std::optional { + return media->hasSSRC(ssrc) + ? std::make_optional(media->mid()) + : nullopt; + }}, + mLocalDescription->media(i))) { + + mMidFromSsrc.emplace(ssrc, *found); + return *found; + } + } + } + + return nullopt; +} + +void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) { + if (auto channel = findDataChannel(stream)) + channel->triggerBufferedAmount(amount); +} + +shared_ptr PeerConnection::emplaceDataChannel(Description::Role role, string label, + DataChannelInit init) { + std::unique_lock lock(mDataChannelsMutex); // we are going to emplace + uint16_t stream; + if (init.id) { + stream = *init.id; + if (stream == 65535) + throw std::invalid_argument("Invalid DataChannel id"); + } else { + // The active side must use streams with even identifiers, whereas the passive side must use + // streams with odd identifiers. + // See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-6 + stream = (role == Description::Role::Active) ? 0 : 1; + while (mDataChannels.find(stream) != mDataChannels.end()) { + if (stream >= 65535 - 2) + throw std::runtime_error("Too many DataChannels"); + + stream += 2; + } + } + // If the DataChannel is user-negotiated, do not negociate it here + auto channel = + init.negotiated + ? std::make_shared(shared_from_this(), stream, std::move(label), + std::move(init.protocol), std::move(init.reliability)) + : std::make_shared(shared_from_this(), stream, std::move(label), + std::move(init.protocol), + std::move(init.reliability)); + mDataChannels.emplace(std::make_pair(stream, channel)); + return channel; +} + +shared_ptr PeerConnection::findDataChannel(uint16_t stream) { + std::shared_lock lock(mDataChannelsMutex); // read-only + if (auto it = mDataChannels.find(stream); it != mDataChannels.end()) + if (auto channel = it->second.lock()) + return channel; + + return nullptr; +} + +void PeerConnection::shiftDataChannels() { + auto iceTransport = std::atomic_load(&mIceTransport); + auto sctpTransport = std::atomic_load(&mSctpTransport); + if (!sctpTransport && iceTransport && iceTransport->role() == Description::Role::Active) { + std::unique_lock lock(mDataChannelsMutex); // we are going to swap the container + decltype(mDataChannels) newDataChannels; + auto it = mDataChannels.begin(); + while (it != mDataChannels.end()) { + auto channel = it->second.lock(); + channel->shiftStream(); + newDataChannels.emplace(channel->stream(), channel); + ++it; + } + std::swap(mDataChannels, newDataChannels); + } +} + +void PeerConnection::iterateDataChannels( + std::function channel)> func) { + // Iterate + { + std::shared_lock lock(mDataChannelsMutex); // read-only + auto it = mDataChannels.begin(); + while (it != mDataChannels.end()) { + auto channel = it->second.lock(); + if (channel && !channel->isClosed()) + func(channel); + + ++it; + } + } + + // Cleanup + { + std::unique_lock lock(mDataChannelsMutex); // we are going to erase + auto it = mDataChannels.begin(); + while (it != mDataChannels.end()) { + if (!it->second.lock()) { + it = mDataChannels.erase(it); + continue; + } + + ++it; + } + } +} + +void PeerConnection::openDataChannels() { + if (auto transport = std::atomic_load(&mSctpTransport)) + iterateDataChannels([&](shared_ptr channel) { channel->open(transport); }); +} + +void PeerConnection::closeDataChannels() { + iterateDataChannels([&](shared_ptr channel) { channel->close(); }); +} + +void PeerConnection::remoteCloseDataChannels() { + iterateDataChannels([&](shared_ptr channel) { channel->remoteClose(); }); +} + +shared_ptr PeerConnection::emplaceTrack(Description::Media description) { + std::shared_ptr track; + if (auto it = mTracks.find(description.mid()); it != mTracks.end()) + if (track = it->second.lock(); track) + track->setDescription(std::move(description)); + + if (!track) { + track = std::make_shared(std::move(description)); + mTracks.emplace(std::make_pair(track->mid(), track)); + mTrackLines.emplace_back(track); + } + + return track; +} + +void PeerConnection::incomingTrack(Description::Media description) { + std::unique_lock lock(mTracksMutex); // we are going to emplace +#if !RTC_ENABLE_MEDIA + if (mTracks.empty()) { + PLOG_WARNING << "Tracks will be inative (not compiled with media support)"; + } +#endif + if (mTracks.find(description.mid()) == mTracks.end()) { + auto track = std::make_shared(std::move(description)); + mTracks.emplace(std::make_pair(track->mid(), track)); + mTrackLines.emplace_back(track); + triggerTrack(track); + } +} + +void PeerConnection::openTracks() { +#if RTC_ENABLE_MEDIA + if (auto transport = std::atomic_load(&mDtlsTransport)) { + auto srtpTransport = reinterpret_pointer_cast(transport); + std::shared_lock lock(mTracksMutex); // read-only + for (auto it = mTracks.begin(); it != mTracks.end(); ++it) + if (auto track = it->second.lock()) + if (!track->isOpen()) + track->open(srtpTransport); + } +#endif +} + +void PeerConnection::validateRemoteDescription(const Description &description) { + if (!description.iceUfrag()) + throw std::invalid_argument("Remote description has no ICE user fragment"); + + if (!description.icePwd()) + throw std::invalid_argument("Remote description has no ICE password"); + + if (!description.fingerprint()) + throw std::invalid_argument("Remote description has no fingerprint"); + + if (description.mediaCount() == 0) + throw std::invalid_argument("Remote description has no media line"); + + int activeMediaCount = 0; + for (unsigned int i = 0; i < description.mediaCount(); ++i) + std::visit(rtc::overloaded{[&](const Description::Application *) { ++activeMediaCount; }, + [&](const Description::Media *media) { + if (media->direction() != Description::Direction::Inactive) + ++activeMediaCount; + }}, + description.media(i)); + + if (activeMediaCount == 0) + throw std::invalid_argument("Remote description has no active media"); + + if (auto local = localDescription(); local && local->iceUfrag() && local->icePwd()) + if (*description.iceUfrag() == *local->iceUfrag() && + *description.icePwd() == *local->icePwd()) + throw std::logic_error("Got the local description as remote description"); + + PLOG_VERBOSE << "Remote description looks valid"; +} + +void PeerConnection::processLocalDescription(Description description) { + + if (auto remote = remoteDescription()) { + // Reciprocate remote description + for (unsigned int i = 0; i < remote->mediaCount(); ++i) + std::visit( // reciprocate each media + rtc::overloaded{ + [&](Description::Application *remoteApp) { + std::shared_lock lock(mDataChannelsMutex); + if (!mDataChannels.empty()) { + // Prefer local description + Description::Application app(remoteApp->mid()); + app.setSctpPort(DEFAULT_SCTP_PORT); + app.setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE); + + PLOG_DEBUG << "Adding application to local description, mid=\"" + << app.mid() << "\""; + + description.addMedia(std::move(app)); + return; + } + + auto reciprocated = remoteApp->reciprocate(); + reciprocated.hintSctpPort(DEFAULT_SCTP_PORT); + reciprocated.setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE); + + PLOG_DEBUG << "Reciprocating application in local description, mid=\"" + << reciprocated.mid() << "\""; + + description.addMedia(std::move(reciprocated)); + }, + [&](Description::Media *remoteMedia) { + std::shared_lock lock(mTracksMutex); + if (auto it = mTracks.find(remoteMedia->mid()); it != mTracks.end()) { + // Prefer local description + if (auto track = it->second.lock()) { + auto media = track->description(); +#if !RTC_ENABLE_MEDIA + // No media support, mark as inactive + media.setDirection(Description::Direction::Inactive); +#endif + PLOG_DEBUG + << "Adding media to local description, mid=\"" << media.mid() + << "\", active=" << std::boolalpha + << (media.direction() != Description::Direction::Inactive); + + description.addMedia(std::move(media)); + } else { + auto reciprocated = remoteMedia->reciprocate(); + reciprocated.setDirection(Description::Direction::Inactive); + + PLOG_DEBUG << "Adding inactive media to local description, mid=\"" + << reciprocated.mid() << "\""; + + description.addMedia(std::move(reciprocated)); + } + return; + } + lock.unlock(); // we are going to call incomingTrack() + + auto reciprocated = remoteMedia->reciprocate(); +#if !RTC_ENABLE_MEDIA + // No media support, mark as inactive + reciprocated.setDirection(Description::Direction::Inactive); +#endif + incomingTrack(reciprocated); + + PLOG_DEBUG + << "Reciprocating media in local description, mid=\"" + << reciprocated.mid() << "\", active=" << std::boolalpha + << (reciprocated.direction() != Description::Direction::Inactive); + + description.addMedia(std::move(reciprocated)); + }, + }, + remote->media(i)); + } + + if (description.type() == Description::Type::Offer) { + // This is an offer, add locally created data channels and tracks + // Add application for data channels + if (!description.hasApplication()) { + std::shared_lock lock(mDataChannelsMutex); + if (!mDataChannels.empty()) { + unsigned int m = 0; + while (description.hasMid(std::to_string(m))) + ++m; + Description::Application app(std::to_string(m)); + app.setSctpPort(DEFAULT_SCTP_PORT); + app.setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE); + + PLOG_DEBUG << "Adding application to local description, mid=\"" << app.mid() + << "\""; + + description.addMedia(std::move(app)); + } + } + + // Add media for local tracks + std::shared_lock lock(mTracksMutex); + for (auto it = mTrackLines.begin(); it != mTrackLines.end(); ++it) { + if (auto track = it->lock()) { + if (description.hasMid(track->mid())) + continue; + + auto media = track->description(); +#if !RTC_ENABLE_MEDIA + // No media support, mark as inactive + media.setDirection(Description::Direction::Inactive); +#endif + PLOG_DEBUG << "Adding media to local description, mid=\"" << media.mid() + << "\", active=" << std::boolalpha + << (media.direction() != Description::Direction::Inactive); + + description.addMedia(std::move(media)); + } + } + } + + // Set local fingerprint (wait for certificate if necessary) + description.setFingerprint(mCertificate.get()->fingerprint()); + + { + // Set as local description + std::lock_guard lock(mLocalDescriptionMutex); + + std::vector existingCandidates; + if (mLocalDescription) { + existingCandidates = mLocalDescription->extractCandidates(); + mCurrentLocalDescription.emplace(std::move(*mLocalDescription)); + } + + mLocalDescription.emplace(description); + mLocalDescription->addCandidates(std::move(existingCandidates)); + } + + PLOG_VERBOSE << "Issuing local description: " << description; + mProcessor->enqueue(localDescriptionCallback.wrap(), std::move(description)); + + // Reciprocated tracks might need to be open + if (auto dtlsTransport = std::atomic_load(&mDtlsTransport); + dtlsTransport && dtlsTransport->state() == Transport::State::Connected) + mProcessor->enqueue(&PeerConnection::openTracks, this); +} + +void PeerConnection::processLocalCandidate(Candidate candidate) { + std::lock_guard lock(mLocalDescriptionMutex); + if (!mLocalDescription) + throw std::logic_error("Got a local candidate without local description"); + + candidate.resolve(Candidate::ResolveMode::Simple); + mLocalDescription->addCandidate(candidate); + + PLOG_VERBOSE << "Issuing local candidate: " << candidate; + mProcessor->enqueue(localCandidateCallback.wrap(), std::move(candidate)); +} + +void PeerConnection::processRemoteDescription(Description description) { + { + // Set as remote description + std::lock_guard lock(mRemoteDescriptionMutex); + + std::vector existingCandidates; + if (mRemoteDescription) + existingCandidates = mRemoteDescription->extractCandidates(); + + mRemoteDescription.emplace(description); + mRemoteDescription->addCandidates(std::move(existingCandidates)); + } + + auto iceTransport = initIceTransport(); + iceTransport->setRemoteDescription(std::move(description)); + + if (description.hasApplication()) { + auto dtlsTransport = std::atomic_load(&mDtlsTransport); + auto sctpTransport = std::atomic_load(&mSctpTransport); + if (!sctpTransport && dtlsTransport && + dtlsTransport->state() == Transport::State::Connected) + initSctpTransport(); + } +} + +void PeerConnection::processRemoteCandidate(Candidate candidate) { + auto iceTransport = std::atomic_load(&mIceTransport); + { + // Set as remote candidate + std::lock_guard lock(mRemoteDescriptionMutex); + if (!mRemoteDescription) + throw std::logic_error("Got a remote candidate without remote description"); + + if (!iceTransport) + throw std::logic_error("Got a remote candidate without ICE transport"); + + candidate.hintMid(mRemoteDescription->bundleMid()); + + if (mRemoteDescription->hasCandidate(candidate)) + return; // already in description, ignore + + candidate.resolve(Candidate::ResolveMode::Simple); + mRemoteDescription->addCandidate(candidate); + } + + if (candidate.isResolved()) { + iceTransport->addRemoteCandidate(std::move(candidate)); + } else { + // We might need a lookup, do it asynchronously + // We don't use the thread pool because we have no control on the timeout + if ((iceTransport = std::atomic_load(&mIceTransport))) { + weak_ptr weakIceTransport{iceTransport}; + std::thread t([weakIceTransport, candidate = std::move(candidate)]() mutable { + if (candidate.resolve(Candidate::ResolveMode::Lookup)) + if (auto iceTransport = weakIceTransport.lock()) + iceTransport->addRemoteCandidate(std::move(candidate)); + }); + t.detach(); + } + } +} + +string PeerConnection::localBundleMid() const { + std::lock_guard lock(mLocalDescriptionMutex); + return mLocalDescription ? mLocalDescription->bundleMid() : "0"; +} + +void PeerConnection::triggerDataChannel(weak_ptr weakDataChannel) { + auto dataChannel = weakDataChannel.lock(); + if (!dataChannel) + return; + + mProcessor->enqueue(dataChannelCallback.wrap(), + std::make_shared(std::move(dataChannel))); +} + +void PeerConnection::triggerTrack(std::shared_ptr track) { + mProcessor->enqueue(trackCallback.wrap(), std::make_shared(std::move(track))); +} + +bool PeerConnection::changeState(State newState) { + State current; + do { + current = state.load(); + if (current == State::Closed) + return false; + if (current == newState) + return false; + + } while (!state.compare_exchange_weak(current, newState)); + + std::ostringstream s; + s << newState; + PLOG_INFO << "Changed state to " << s.str(); + + if (newState == State::Closed) + // This is the last state change, so we may steal the callback + mProcessor->enqueue([cb = std::move(stateChangeCallback)]() { cb(State::Closed); }); + else + mProcessor->enqueue(stateChangeCallback.wrap(), newState); + + return true; +} + +bool PeerConnection::changeGatheringState(GatheringState newState) { + if (gatheringState.exchange(newState) == newState) + return false; + + std::ostringstream s; + s << newState; + PLOG_INFO << "Changed gathering state to " << s.str(); + mProcessor->enqueue(gatheringStateChangeCallback.wrap(), newState); + return true; +} + +bool PeerConnection::changeSignalingState(SignalingState newState) { + if (signalingState.exchange(newState) == newState) + return false; + + std::ostringstream s; + s << state; + PLOG_INFO << "Changed signaling state to " << s.str(); + mProcessor->enqueue(signalingStateChangeCallback.wrap(), newState); + return true; +} + +void PeerConnection::resetCallbacks() { + // Unregister all callbacks + dataChannelCallback = nullptr; + localDescriptionCallback = nullptr; + localCandidateCallback = nullptr; + stateChangeCallback = nullptr; + gatheringStateChangeCallback = nullptr; +} + +} // namespace rtc::impl diff --git a/src/impl/peerconnection.hpp b/src/impl/peerconnection.hpp new file mode 100644 index 0000000..a85a5cd --- /dev/null +++ b/src/impl/peerconnection.hpp @@ -0,0 +1,129 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_IMPL_PEER_CONNECTION_H +#define RTC_IMPL_PEER_CONNECTION_H + +#include "datachannel.hpp" +#include "dtlstransport.hpp" +#include "icetransport.hpp" +#include "common.hpp" +#include "sctptransport.hpp" +#include "track.hpp" + +#include "rtc/peerconnection.hpp" + +namespace rtc::impl { + +struct PeerConnection : std::enable_shared_from_this { + using State = rtc::PeerConnection::State; + using GatheringState = rtc::PeerConnection::GatheringState; + using SignalingState = rtc::PeerConnection::SignalingState; + + PeerConnection(Configuration config_); + ~PeerConnection(); + + void close(); + + std::optional localDescription() const; + std::optional remoteDescription() const; + + std::shared_ptr initIceTransport(); + std::shared_ptr initDtlsTransport(); + std::shared_ptr initSctpTransport(); + std::shared_ptr getIceTransport() const; + std::shared_ptr getDtlsTransport() const; + std::shared_ptr getSctpTransport() const; + void closeTransports(); + + void endLocalCandidates(); + void rollbackLocalDescription(); + bool checkFingerprint(const std::string &fingerprint) const; + void forwardMessage(message_ptr message); + void forwardMedia(message_ptr message); + void forwardBufferedAmount(uint16_t stream, size_t amount); + std::optional getMidFromSsrc(uint32_t ssrc); + + shared_ptr emplaceDataChannel(Description::Role role, string label, + DataChannelInit init); + shared_ptr findDataChannel(uint16_t stream); + void shiftDataChannels(); + void iterateDataChannels(std::function channel)> func); + void openDataChannels(); + void closeDataChannels(); + void remoteCloseDataChannels(); + + shared_ptr emplaceTrack(Description::Media description); + void incomingTrack(Description::Media description); + void openTracks(); + + void validateRemoteDescription(const Description &description); + void processLocalDescription(Description description); + void processLocalCandidate(Candidate candidate); + void processRemoteDescription(Description description); + void processRemoteCandidate(Candidate candidate); + string localBundleMid() const; + + void triggerDataChannel(std::weak_ptr weakDataChannel); + void triggerTrack(std::shared_ptr track); + bool changeState(State newState); + bool changeGatheringState(GatheringState newState); + bool changeSignalingState(SignalingState newState); + + void resetCallbacks(); + + void outgoingMedia(message_ptr message); + + const Configuration config; + std::atomic state = State::New; + std::atomic gatheringState = GatheringState::New; + std::atomic signalingState = SignalingState::Stable; + std::atomic negotiationNeeded = false; + + synchronized_callback> dataChannelCallback; + synchronized_callback localDescriptionCallback; + synchronized_callback localCandidateCallback; + synchronized_callback stateChangeCallback; + synchronized_callback gatheringStateChangeCallback; + synchronized_callback signalingStateChangeCallback; + synchronized_callback> trackCallback; + +private: + const init_token mInitToken = Init::Token(); + const future_certificate_ptr mCertificate; + const std::unique_ptr mProcessor; + + std::optional mLocalDescription, mRemoteDescription; + std::optional mCurrentLocalDescription; + mutable std::mutex mLocalDescriptionMutex, mRemoteDescriptionMutex; + + std::shared_ptr mIceTransport; + std::shared_ptr mDtlsTransport; + std::shared_ptr mSctpTransport; + + std::unordered_map> mDataChannels; // by stream ID + std::unordered_map> mTracks; // by mid + std::vector> mTrackLines; // by SDP order + std::shared_mutex mDataChannelsMutex, mTracksMutex; + + std::unordered_map mMidFromSsrc; // cache +}; + +} // namespace rtc::impl + +#endif diff --git a/src/processor.cpp b/src/impl/processor.cpp similarity index 96% rename from src/processor.cpp rename to src/impl/processor.cpp index 5e83e47..c2eeb6d 100644 --- a/src/processor.cpp +++ b/src/impl/processor.cpp @@ -18,7 +18,7 @@ #include "processor.hpp" -namespace rtc { +namespace rtc::impl { Processor::Processor(size_t limit) : mTasks(limit) {} @@ -40,4 +40,4 @@ void Processor::schedule() { } } -} // namespace rtc +} // namespace rtc::impl diff --git a/src/processor.hpp b/src/impl/processor.hpp similarity index 94% rename from src/processor.hpp rename to src/impl/processor.hpp index a63259b..778d708 100644 --- a/src/processor.hpp +++ b/src/impl/processor.hpp @@ -16,10 +16,10 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_PROCESSOR_H -#define RTC_PROCESSOR_H +#ifndef RTC_IMPL_PROCESSOR_H +#define RTC_IMPL_PROCESSOR_H -#include "include.hpp" +#include "common.hpp" #include "init.hpp" #include "queue.hpp" #include "threadpool.hpp" @@ -30,7 +30,7 @@ #include #include -namespace rtc { +namespace rtc::impl { // Processed tasks in order by delegating them to the thread pool class Processor final { @@ -76,6 +76,6 @@ template void Processor::enqueue(F &&f, Args &&...args) } } -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/include/rtc/queue.hpp b/src/impl/queue.hpp similarity index 97% rename from include/rtc/queue.hpp rename to src/impl/queue.hpp index 1ffc41d..fe9f98a 100644 --- a/include/rtc/queue.hpp +++ b/src/impl/queue.hpp @@ -16,10 +16,10 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_QUEUE_H -#define RTC_QUEUE_H +#ifndef RTC_IMPL_QUEUE_H +#define RTC_IMPL_QUEUE_H -#include "include.hpp" +#include "common.hpp" #include #include @@ -28,7 +28,7 @@ #include #include -namespace rtc { +namespace rtc::impl { template class Queue { public: @@ -167,6 +167,6 @@ template std::optional Queue::popImpl() { return element; } -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/sctptransport.cpp b/src/impl/sctptransport.cpp similarity index 99% rename from src/sctptransport.cpp rename to src/impl/sctptransport.cpp index 6ac988f..baf7ad7 100644 --- a/src/sctptransport.cpp +++ b/src/impl/sctptransport.cpp @@ -18,6 +18,7 @@ #include "sctptransport.hpp" #include "dtlstransport.hpp" +#include "globals.hpp" #include "logcounter.hpp" #include @@ -54,7 +55,7 @@ using namespace std::chrono_literals; using namespace std::chrono; -namespace rtc { +namespace rtc::impl { static LogCounter COUNTER_UNKNOWN_PPID(plog::warning, "Number of SCTP packets received with an unknown PPID"); @@ -808,4 +809,4 @@ int SctpTransport::WriteCallback(void *ptr, void *data, size_t len, uint8_t tos, return transport->handleWrite(static_cast(data), len, tos, set_df); } -} // namespace rtc +} // namespace rtc::impl diff --git a/src/sctptransport.hpp b/src/impl/sctptransport.hpp similarity index 94% rename from src/sctptransport.hpp rename to src/impl/sctptransport.hpp index 2bae169..e2ea74e 100644 --- a/src/sctptransport.hpp +++ b/src/impl/sctptransport.hpp @@ -1,5 +1,5 @@ /** - * Copyright (c) 2019 Paul-Louis Ageneau + * Copyright (c) 2019-2021 Paul-Louis Ageneau * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -16,11 +16,10 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_SCTP_TRANSPORT_H -#define RTC_SCTP_TRANSPORT_H +#ifndef RTC_IMPL_SCTP_TRANSPORT_H +#define RTC_IMPL_SCTP_TRANSPORT_H -#include "include.hpp" -#include "peerconnection.hpp" +#include "common.hpp" #include "processor.hpp" #include "queue.hpp" #include "transport.hpp" @@ -34,7 +33,7 @@ #include "usrsctp.h" -namespace rtc { +namespace rtc::impl { class SctpTransport final : public Transport { public: @@ -120,6 +119,6 @@ private: static std::shared_mutex InstancesMutex; }; -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/tcptransport.cpp b/src/impl/tcptransport.cpp similarity index 99% rename from src/tcptransport.cpp rename to src/impl/tcptransport.cpp index a4b8619..cedf634 100644 --- a/src/tcptransport.cpp +++ b/src/impl/tcptransport.cpp @@ -17,6 +17,7 @@ */ #include "tcptransport.hpp" +#include "globals.hpp" #if RTC_ENABLE_WEBSOCKET @@ -27,7 +28,7 @@ #include #endif -namespace rtc { +namespace rtc::impl { using std::to_string; @@ -398,6 +399,6 @@ int TcpTransport::prepareSelect(fd_set &readfds, fd_set &writefds) { void TcpTransport::interruptSelect() { mInterrupter.interrupt(); } -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/tcptransport.hpp b/src/impl/tcptransport.hpp similarity index 94% rename from src/tcptransport.hpp rename to src/impl/tcptransport.hpp index 1447835..f091fa2 100644 --- a/src/tcptransport.hpp +++ b/src/impl/tcptransport.hpp @@ -16,10 +16,10 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_TCP_TRANSPORT_H -#define RTC_TCP_TRANSPORT_H +#ifndef RTC_IMPL_TCP_TRANSPORT_H +#define RTC_IMPL_TCP_TRANSPORT_H -#include "include.hpp" +#include "common.hpp" #include "queue.hpp" #include "transport.hpp" @@ -31,7 +31,7 @@ // Use the socket defines from libjuice #include "../deps/libjuice/src/socket.h" -namespace rtc { +namespace rtc::impl { // Utility class to interrupt select() class SelectInterrupter { @@ -85,7 +85,7 @@ private: Queue mSendQueue; }; -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/threadpool.cpp b/src/impl/threadpool.cpp similarity index 94% rename from src/threadpool.cpp rename to src/impl/threadpool.cpp index 42f3693..690d985 100644 --- a/src/threadpool.cpp +++ b/src/impl/threadpool.cpp @@ -22,11 +22,11 @@ namespace { -void joinThreadPoolInstance() { rtc::ThreadPool::Instance().join(); } +void joinThreadPoolInstance() { rtc::impl::ThreadPool::Instance().join(); } } // namespace -namespace rtc { +namespace rtc::impl { ThreadPool &ThreadPool::Instance() { static ThreadPool *instance = new ThreadPool; @@ -103,4 +103,4 @@ std::function ThreadPool::dequeue() { return nullptr; } -} // namespace rtc +} // namespace rtc::impl diff --git a/src/threadpool.hpp b/src/impl/threadpool.hpp similarity index 96% rename from src/threadpool.hpp rename to src/impl/threadpool.hpp index 70885b5..d312b1a 100644 --- a/src/threadpool.hpp +++ b/src/impl/threadpool.hpp @@ -16,10 +16,10 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_THREADPOOL_H -#define RTC_THREADPOOL_H +#ifndef RTC_IMPL_THREADPOOL_H +#define RTC_IMPL_THREADPOOL_H -#include "include.hpp" +#include "common.hpp" #include "init.hpp" #include @@ -34,7 +34,7 @@ #include #include -namespace rtc { +namespace rtc::impl { template using invoke_future_t = std::future, std::decay_t...>>; @@ -119,6 +119,6 @@ auto ThreadPool::schedule(clock::time_point time, F &&f, Args &&...args) return result; } -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/tls.cpp b/src/impl/tls.cpp similarity index 100% rename from src/tls.cpp rename to src/impl/tls.cpp diff --git a/src/tls.hpp b/src/impl/tls.hpp similarity index 98% rename from src/tls.hpp rename to src/impl/tls.hpp index 8a2e150..da94e1a 100644 --- a/src/tls.hpp +++ b/src/impl/tls.hpp @@ -19,7 +19,7 @@ #ifndef RTC_TLS_H #define RTC_TLS_H -#include "include.hpp" +#include "common.hpp" #if USE_GNUTLS diff --git a/src/tlstransport.cpp b/src/impl/tlstransport.cpp similarity index 98% rename from src/tlstransport.cpp rename to src/impl/tlstransport.cpp index 246c3f1..256a885 100644 --- a/src/tlstransport.cpp +++ b/src/impl/tlstransport.cpp @@ -28,12 +28,7 @@ using namespace std::chrono; -using std::shared_ptr; -using std::string; -using std::unique_ptr; -using std::weak_ptr; - -namespace rtc { +namespace rtc::impl { #if USE_GNUTLS @@ -443,6 +438,6 @@ void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) { #endif -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/tlstransport.hpp b/src/impl/tlstransport.hpp similarity index 93% rename from src/tlstransport.hpp rename to src/impl/tlstransport.hpp index 820d4e9..8810572 100644 --- a/src/tlstransport.hpp +++ b/src/impl/tlstransport.hpp @@ -16,10 +16,10 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_TLS_TRANSPORT_H -#define RTC_TLS_TRANSPORT_H +#ifndef RTC_IMPL_TLS_TRANSPORT_H +#define RTC_IMPL_TLS_TRANSPORT_H -#include "include.hpp" +#include "common.hpp" #include "queue.hpp" #include "tls.hpp" #include "transport.hpp" @@ -28,7 +28,7 @@ #include -namespace rtc { +namespace rtc::impl { class TcpTransport; @@ -76,7 +76,7 @@ protected: #endif }; -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/impl/track.cpp b/src/impl/track.cpp new file mode 100644 index 0000000..11ed623 --- /dev/null +++ b/src/impl/track.cpp @@ -0,0 +1,182 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#include "track.hpp" +#include "globals.hpp" +#include "logcounter.hpp" + +namespace rtc::impl { + +static LogCounter COUNTER_MEDIA_BAD_DIRECTION(plog::warning, + "Number of media packets sent in invalid directions"); +static LogCounter COUNTER_QUEUE_FULL(plog::warning, + "Number of media packets dropped due to a full queue"); + +Track::Track(Description::Media description) + : mMediaDescription(std::move(description)), mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {} + +string Track::mid() const { + std::shared_lock lock(mMutex); + return mMediaDescription.mid(); +} + +Description::Direction Track::direction() const { + std::shared_lock lock(mMutex); + return mMediaDescription.direction(); +} + +Description::Media Track::description() const { + std::shared_lock lock(mMutex); + return mMediaDescription; +} + +void Track::setDescription(Description::Media description) { + std::unique_lock lock(mMutex); + if (description.mid() != mMediaDescription.mid()) + throw std::logic_error("Media description mid does not match track mid"); + + mMediaDescription = std::move(description); +} + +void Track::close() { + mIsClosed = true; + + setRtcpHandler(nullptr); + resetCallbacks(); +} + +std::optional Track::receive() { + if (auto next = mRecvQueue.tryPop()) + return to_variant(std::move(**next)); + + return nullopt; +} + +std::optional Track::peek() { + if (auto next = mRecvQueue.peek()) + return to_variant(std::move(**next)); + + return nullopt; +} + +bool Track::isOpen(void) const { +#if RTC_ENABLE_MEDIA + std::shared_lock lock(mMutex); + return !mIsClosed && mDtlsSrtpTransport.lock(); +#else + return !mIsClosed; +#endif +} + +bool Track::isClosed(void) const { return mIsClosed; } + +size_t Track::availableAmount() const { return mRecvQueue.amount(); } + +#if RTC_ENABLE_MEDIA +void Track::open(shared_ptr transport) { + { + std::lock_guard lock(mMutex); + mDtlsSrtpTransport = transport; + } + + triggerOpen(); +} +#endif + +void Track::incoming(message_ptr message) { + if (!message) + return; + + // TODO + auto dir = direction(); + if ((dir == Description::Direction::SendOnly || dir == Description::Direction::Inactive) && + message->type != Message::Control) { + COUNTER_MEDIA_BAD_DIRECTION++; + return; + } + + if (auto handler = getRtcpHandler()) { + message = handler->incoming(message); + if (!message) + return; + } + + // Tail drop if queue is full + if (mRecvQueue.full()) { + COUNTER_QUEUE_FULL++; + return; + } + + mRecvQueue.push(message); + triggerAvailable(mRecvQueue.size()); +} + +bool Track::outgoing([[maybe_unused]] message_ptr message) { + if (mIsClosed) + throw std::runtime_error("Track is closed"); + + auto dir = direction(); + if ((dir == Description::Direction::RecvOnly || dir == Description::Direction::Inactive)) { + COUNTER_MEDIA_BAD_DIRECTION++; + return false; + } + + if (auto handler = getRtcpHandler()) { + message = handler->outgoing(message); + if (!message) + return false; + } + +#if RTC_ENABLE_MEDIA + std::shared_ptr transport; + { + std::shared_lock lock(mMutex); + transport = mDtlsSrtpTransport.lock(); + if (!transport) + throw std::runtime_error("Track is closed"); + + // Set recommended medium-priority DSCP value + // See https://tools.ietf.org/html/draft-ietf-tsvwg-rtcweb-qos-18 + if (mMediaDescription.type() == "audio") + message->dscp = 46; // EF: Expedited Forwarding + else + message->dscp = 36; // AF42: Assured Forwarding class 4, medium drop probability + } + + return transport->sendMedia(message); +#else + PLOG_WARNING << "Ignoring track send (not compiled with media support)"; + return false; +#endif +} + +void Track::setRtcpHandler(std::shared_ptr handler) { + { + std::unique_lock lock(mMutex); + mRtcpHandler = handler; + } + + handler->onOutgoing(std::bind(&Track::outgoing, this, std::placeholders::_1)); +} + +std::shared_ptr Track::getRtcpHandler() { + std::shared_lock lock(mMutex); + return mRtcpHandler; +} + +} // namespace rtc::impl diff --git a/src/impl/track.hpp b/src/impl/track.hpp new file mode 100644 index 0000000..d57f2ae --- /dev/null +++ b/src/impl/track.hpp @@ -0,0 +1,83 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_IMPL_TRACK_H +#define RTC_IMPL_TRACK_H + +#include "channel.hpp" +#include "common.hpp" +#include "description.hpp" +#include "mediahandler.hpp" +#include "queue.hpp" + +#if RTC_ENABLE_MEDIA +#include "dtlssrtptransport.hpp" +#endif + +#include +#include +#include + +namespace rtc::impl { + +class Track final : public std::enable_shared_from_this, public Channel { +public: + Track(Description::Media description); + ~Track() = default; + + void close(); + void incoming(message_ptr message); + bool outgoing(message_ptr message); + + std::optional receive() override; + std::optional peek() override; + size_t availableAmount() const override; + + bool isOpen() const; + bool isClosed() const; + + string mid() const; + Description::Direction direction() const; + Description::Media description() const; + void setDescription(Description::Media description); + + std::shared_ptr getRtcpHandler(); + void setRtcpHandler(shared_ptr handler); + +#if RTC_ENABLE_MEDIA + void open(std::shared_ptr transport); +#endif + +private: +#if RTC_ENABLE_MEDIA + weak_ptr mDtlsSrtpTransport; +#endif + + Description::Media mMediaDescription; + shared_ptr mRtcpHandler; + + mutable std::shared_mutex mMutex; + + std::atomic mIsClosed = false; + + Queue mRecvQueue; +}; + +} // namespace rtc::impl + +#endif diff --git a/src/transport.hpp b/src/impl/transport.hpp similarity index 95% rename from src/transport.hpp rename to src/impl/transport.hpp index 04d106c..fb68d1e 100644 --- a/src/transport.hpp +++ b/src/impl/transport.hpp @@ -16,17 +16,17 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_TRANSPORT_H -#define RTC_TRANSPORT_H +#ifndef RTC_IMPL_TRANSPORT_H +#define RTC_IMPL_TRANSPORT_H -#include "include.hpp" +#include "common.hpp" #include "message.hpp" #include #include #include -namespace rtc { +namespace rtc::impl { class Transport { public: @@ -98,6 +98,6 @@ private: std::atomic mStopped = true; }; -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/verifiedtlstransport.cpp b/src/impl/verifiedtlstransport.cpp similarity index 90% rename from src/verifiedtlstransport.cpp rename to src/impl/verifiedtlstransport.cpp index f03ba65..b214822 100644 --- a/src/verifiedtlstransport.cpp +++ b/src/impl/verifiedtlstransport.cpp @@ -17,16 +17,11 @@ */ #include "verifiedtlstransport.hpp" -#include "include.hpp" +#include "common.hpp" #if RTC_ENABLE_WEBSOCKET -using std::shared_ptr; -using std::string; -using std::unique_ptr; -using std::weak_ptr; - -namespace rtc { +namespace rtc::impl { VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr lower, string host, state_callback callback) @@ -44,6 +39,6 @@ VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr lower, strin VerifiedTlsTransport::~VerifiedTlsTransport() {} -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/verifiedtlstransport.hpp b/src/impl/verifiedtlstransport.hpp similarity index 88% rename from src/verifiedtlstransport.hpp rename to src/impl/verifiedtlstransport.hpp index d70b1ca..8e69869 100644 --- a/src/verifiedtlstransport.hpp +++ b/src/impl/verifiedtlstransport.hpp @@ -16,14 +16,14 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_VERIFIED_TLS_TRANSPORT_H -#define RTC_VERIFIED_TLS_TRANSPORT_H +#ifndef RTC_IMPL_VERIFIED_TLS_TRANSPORT_H +#define RTC_IMPL_VERIFIED_TLS_TRANSPORT_H #include "tlstransport.hpp" #if RTC_ENABLE_WEBSOCKET -namespace rtc { +namespace rtc::impl { class VerifiedTlsTransport final : public TlsTransport { public: @@ -31,7 +31,7 @@ public: ~VerifiedTlsTransport(); }; -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/impl/websocket.cpp b/src/impl/websocket.cpp new file mode 100644 index 0000000..47d1b3d --- /dev/null +++ b/src/impl/websocket.cpp @@ -0,0 +1,371 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#if RTC_ENABLE_WEBSOCKET + +#include "websocket.hpp" +#include "globals.hpp" +#include "common.hpp" +#include "threadpool.hpp" + +#include "tcptransport.hpp" +#include "tlstransport.hpp" +#include "verifiedtlstransport.hpp" +#include "wstransport.hpp" + +#include + +#ifdef _WIN32 +#include +#endif + +namespace rtc::impl { + +using namespace std::placeholders; + +WebSocket::WebSocket(Configuration config_) + : config(std::move(config_)), mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) { + PLOG_VERBOSE << "Creating WebSocket"; +} + +WebSocket::~WebSocket() { + PLOG_VERBOSE << "Destroying WebSocket"; + remoteClose(); +} + +void WebSocket::parse(const string &url) { + PLOG_VERBOSE << "Opening WebSocket to URL: " << url; + + if (state != State::Closed) + throw std::logic_error("WebSocket must be closed before opening"); + + // Modified regex from RFC 3986, see https://tools.ietf.org/html/rfc3986#appendix-B + static const char *rs = + R"(^(([^:.@/?#]+):)?(/{0,2}((([^:@]*)(:([^@]*))?)@)?(([^:/?#]*)(:([^/?#]*))?))?([^?#]*)(\?([^#]*))?(#(.*))?)"; + + static const std::regex r(rs, std::regex::extended); + + std::smatch m; + if (!std::regex_match(url, m, r) || m[10].length() == 0) + throw std::invalid_argument("Invalid WebSocket URL: " + url); + + mScheme = m[2]; + if (mScheme.empty()) + mScheme = "ws"; + else if (mScheme != "ws" && mScheme != "wss") + throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme); + + mHostname = m[10]; + mService = m[12]; + if (mService.empty()) { + mService = mScheme == "ws" ? "80" : "443"; + mHost = mHostname; + } else { + mHost = mHostname + ':' + mService; + } + + while (!mHostname.empty() && mHostname.front() == '[') + mHostname.erase(mHostname.begin()); + while (!mHostname.empty() && mHostname.back() == ']') + mHostname.pop_back(); + + mPath = m[13]; + if (mPath.empty()) + mPath += '/'; + if (string query = m[15]; !query.empty()) + mPath += "?" + query; + + changeState(State::Connecting); + initTcpTransport(); +} + +void WebSocket::close() { + auto s = state.load(); + if (s == State::Connecting || s == State::Open) { + PLOG_VERBOSE << "Closing WebSocket"; + changeState(State::Closing); + if (auto transport = std::atomic_load(&mWsTransport)) + transport->close(); + else + changeState(State::Closed); + } +} + +void WebSocket::remoteClose() { + if (state.load() != State::Closed) { + close(); + closeTransports(); + } +} + +bool WebSocket::isOpen() const { return state == State::Open; } + +bool WebSocket::isClosed() const { return state == State::Closed; } + +size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; } + +std::optional WebSocket::receive() { + while (auto next = mRecvQueue.tryPop()) { + message_ptr message = *next; + if (message->type != Message::Control) + return to_variant(std::move(*message)); + } + return nullopt; +} + +std::optional WebSocket::peek() { + while (auto next = mRecvQueue.peek()) { + message_ptr message = *next; + if (message->type != Message::Control) + return to_variant(std::move(*message)); + + mRecvQueue.tryPop(); + } + return nullopt; +} + +size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); } + +bool WebSocket::changeState(State newState) { return state.exchange(newState) != newState; } + +bool WebSocket::outgoing(message_ptr message) { + if (state != State::Open || !mWsTransport) + throw std::runtime_error("WebSocket is not open"); + + if (message->size() > maxMessageSize()) + throw std::runtime_error("Message size exceeds limit"); + + return mWsTransport->send(message); +} + +void WebSocket::incoming(message_ptr message) { + if (!message) { + remoteClose(); + return; + } + + if (message->type == Message::String || message->type == Message::Binary) { + mRecvQueue.push(message); + triggerAvailable(mRecvQueue.size()); + } +} + +shared_ptr WebSocket::initTcpTransport() { + PLOG_VERBOSE << "Starting TCP transport"; + using State = TcpTransport::State; + try { + if (auto transport = std::atomic_load(&mTcpTransport)) + return transport; + + auto transport = std::make_shared( + mHostname, mService, [this, weak_this = weak_from_this()](State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (transportState) { + case State::Connected: + if (mScheme == "ws") + initWsTransport(); + else + initTlsTransport(); + break; + case State::Failed: + triggerError("TCP connection failed"); + remoteClose(); + break; + case State::Disconnected: + remoteClose(); + break; + default: + // Ignore + break; + } + }); + std::atomic_store(&mTcpTransport, transport); + if (state == WebSocket::State::Closed) { + mTcpTransport.reset(); + throw std::runtime_error("Connection is closed"); + } + transport->start(); + return transport; + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + remoteClose(); + throw std::runtime_error("TCP transport initialization failed"); + } +} + +shared_ptr WebSocket::initTlsTransport() { + PLOG_VERBOSE << "Starting TLS transport"; + using State = TlsTransport::State; + try { + if (auto transport = std::atomic_load(&mTlsTransport)) + return transport; + + auto lower = std::atomic_load(&mTcpTransport); + auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (transportState) { + case State::Connected: + initWsTransport(); + break; + case State::Failed: + triggerError("TCP connection failed"); + remoteClose(); + break; + case State::Disconnected: + remoteClose(); + break; + default: + // Ignore + break; + } + }; + + shared_ptr transport; +#ifdef _WIN32 + if (!config.disableTlsVerification) { + PLOG_WARNING << "TLS certificate verification with root CA is not supported on Windows"; + } + transport = std::make_shared(lower, mHostname, stateChangeCallback); +#else + if (config.disableTlsVerification) + transport = std::make_shared(lower, mHostname, stateChangeCallback); + else + transport = + std::make_shared(lower, mHostname, stateChangeCallback); +#endif + + std::atomic_store(&mTlsTransport, transport); + if (state == WebSocket::State::Closed) { + mTlsTransport.reset(); + throw std::runtime_error("Connection is closed"); + } + transport->start(); + return transport; + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + remoteClose(); + throw std::runtime_error("TLS transport initialization failed"); + } +} + +shared_ptr WebSocket::initWsTransport() { + PLOG_VERBOSE << "Starting WebSocket transport"; + using State = WsTransport::State; + try { + if (auto transport = std::atomic_load(&mWsTransport)) + return transport; + + shared_ptr lower = std::atomic_load(&mTlsTransport); + if (!lower) + lower = std::atomic_load(&mTcpTransport); + + WsTransport::Configuration wsConfig = {}; + wsConfig.host = mHost; + wsConfig.path = mPath; + wsConfig.protocols = config.protocols; + + auto transport = std::make_shared( + lower, wsConfig, weak_bind(&WebSocket::incoming, this, _1), + [this, weak_this = weak_from_this()](State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (transportState) { + case State::Connected: + if (state == WebSocket::State::Connecting) { + PLOG_DEBUG << "WebSocket open"; + changeState(WebSocket::State::Open); + triggerOpen(); + } + break; + case State::Failed: + triggerError("WebSocket connection failed"); + remoteClose(); + break; + case State::Disconnected: + remoteClose(); + break; + default: + // Ignore + break; + } + }); + std::atomic_store(&mWsTransport, transport); + if (state == WebSocket::State::Closed) { + mWsTransport.reset(); + throw std::runtime_error("Connection is closed"); + } + transport->start(); + return transport; + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + remoteClose(); + throw std::runtime_error("WebSocket transport initialization failed"); + } +} + +std::shared_ptr WebSocket::getTcpTransport() const { + return std::atomic_load(&mTcpTransport); +} + +std::shared_ptr WebSocket::getTlsTransport() const { + return std::atomic_load(&mTlsTransport); +} + +std::shared_ptr WebSocket::getWsTransport() const { + return std::atomic_load(&mWsTransport); +} + +void WebSocket::closeTransports() { + PLOG_VERBOSE << "Closing transports"; + + if (state.load() != State::Closed) { + changeState(State::Closed); + triggerClosed(); + } + + // Reset callbacks now that state is changed + resetCallbacks(); + + // Pass the pointers to a thread, allowing to terminate a transport from its own thread + auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr)); + auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr)); + auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr)); + ThreadPool::Instance().enqueue([ws, tls, tcp]() mutable { + if (ws) + ws->stop(); + if (tls) + tls->stop(); + if (tcp) + tcp->stop(); + + ws.reset(); + tls.reset(); + tcp.reset(); + }); +} + +} // namespace rtc::impl + +#endif diff --git a/src/impl/websocket.hpp b/src/impl/websocket.hpp new file mode 100644 index 0000000..66b63a9 --- /dev/null +++ b/src/impl/websocket.hpp @@ -0,0 +1,93 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_IMPL_WEBSOCKET_H +#define RTC_IMPL_WEBSOCKET_H + +#if RTC_ENABLE_WEBSOCKET + +#include "channel.hpp" +#include "common.hpp" +#include "init.hpp" +#include "message.hpp" +#include "queue.hpp" +#include "tcptransport.hpp" +#include "tlstransport.hpp" +#include "wstransport.hpp" + +#include "rtc/websocket.hpp" + +#include +#include +#include +#include + +namespace rtc::impl { + +struct WebSocket final : public Channel, public std::enable_shared_from_this { + using State = rtc::WebSocket::State; + using Configuration = rtc::WebSocket::Configuration; + + WebSocket(Configuration config_); + ~WebSocket(); + + void parse(const string &url); + void close(); + bool outgoing(message_ptr message); + void incoming(message_ptr message); + + std::optional receive() override; + std::optional peek() override; + size_t availableAmount() const override; + + bool isOpen() const; + bool isClosed() const; + size_t maxMessageSize() const; + + bool changeState(State state); + void remoteClose(); + + std::shared_ptr initTcpTransport(); + std::shared_ptr initTlsTransport(); + std::shared_ptr initWsTransport(); + std::shared_ptr getTcpTransport() const; + std::shared_ptr getTlsTransport() const; + std::shared_ptr getWsTransport() const; + + void closeTransports(); + + const Configuration config; + std::atomic state = State::Closed; + +private: + const init_token mInitToken = Init::Token(); + + std::shared_ptr mTcpTransport; + std::shared_ptr mTlsTransport; + std::shared_ptr mWsTransport; + + string mScheme, mHost, mHostname, mService, mPath; + + Queue mRecvQueue; +}; + +} // namespace rtc::impl + +#endif + +#endif // RTC_IMPL_WEBSOCKET_H diff --git a/src/wstransport.cpp b/src/impl/wstransport.cpp similarity index 99% rename from src/wstransport.cpp rename to src/impl/wstransport.cpp index 527dee9..34729cb 100644 --- a/src/wstransport.cpp +++ b/src/impl/wstransport.cpp @@ -45,7 +45,7 @@ #define ntohll(x) htonll(x) #endif -namespace rtc { +namespace rtc::impl { using namespace std::chrono; using std::to_integer; diff --git a/src/wstransport.hpp b/src/impl/wstransport.hpp similarity index 93% rename from src/wstransport.hpp rename to src/impl/wstransport.hpp index 0f80dee..c3e915a 100644 --- a/src/wstransport.hpp +++ b/src/impl/wstransport.hpp @@ -16,15 +16,15 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#ifndef RTC_WS_TRANSPORT_H -#define RTC_WS_TRANSPORT_H +#ifndef RTC_IMPL_WS_TRANSPORT_H +#define RTC_IMPL_WS_TRANSPORT_H -#include "include.hpp" +#include "common.hpp" #include "transport.hpp" #if RTC_ENABLE_WEBSOCKET -namespace rtc { +namespace rtc::impl { class TcpTransport; class TlsTransport; @@ -81,7 +81,7 @@ private: Opcode mPartialOpcode; }; -} // namespace rtc +} // namespace rtc::impl #endif diff --git a/src/init.cpp b/src/init.cpp index 8430d50..da3260f 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -17,27 +17,26 @@ */ #include "init.hpp" +#include "globals.hpp" -#include "certificate.hpp" -#include "dtlstransport.hpp" -#include "sctptransport.hpp" -#include "threadpool.hpp" -#include "tls.hpp" +#include "impl/certificate.hpp" +#include "impl/dtlstransport.hpp" +#include "impl/sctptransport.hpp" +#include "impl/threadpool.hpp" +#include "impl/tls.hpp" #if RTC_ENABLE_WEBSOCKET -#include "tlstransport.hpp" +#include "impl/tlstransport.hpp" #endif #if RTC_ENABLE_MEDIA -#include "dtlssrtptransport.hpp" +#include "impl/dtlssrtptransport.hpp" #endif #ifdef _WIN32 #include #endif -using std::shared_ptr; - namespace rtc { namespace { @@ -51,7 +50,7 @@ void doInit() { throw std::runtime_error("WSAStartup failed, error=" + std::to_string(WSAGetLastError())); #endif - ThreadPool::Instance().spawn(THREADPOOL_SIZE); + impl::ThreadPool::Instance().spawn(THREADPOOL_SIZE); #if USE_GNUTLS // Nothing to do @@ -59,29 +58,29 @@ void doInit() { openssl::init(); #endif - SctpTransport::Init(); - DtlsTransport::Init(); + impl::SctpTransport::Init(); + impl::DtlsTransport::Init(); #if RTC_ENABLE_WEBSOCKET - TlsTransport::Init(); + impl::TlsTransport::Init(); #endif #if RTC_ENABLE_MEDIA - DtlsSrtpTransport::Init(); + impl::DtlsSrtpTransport::Init(); #endif } void doCleanup() { PLOG_DEBUG << "Global cleanup"; - ThreadPool::Instance().join(); - CleanupCertificateCache(); + impl::ThreadPool::Instance().join(); + impl::CleanupCertificateCache(); - SctpTransport::Cleanup(); - DtlsTransport::Cleanup(); + impl::SctpTransport::Cleanup(); + impl::DtlsTransport::Cleanup(); #if RTC_ENABLE_WEBSOCKET - TlsTransport::Cleanup(); + impl::TlsTransport::Cleanup(); #endif #if RTC_ENABLE_MEDIA - DtlsSrtpTransport::Cleanup(); + impl::DtlsSrtpTransport::Cleanup(); #endif #ifdef _WIN32 @@ -114,7 +113,7 @@ void Init::Preload() { Global = new shared_ptr(token); PLOG_DEBUG << "Preloading certificate"; - make_certificate().wait(); + impl::make_certificate().wait(); } void Init::Cleanup() { diff --git a/src/peerconnection.cpp b/src/peerconnection.cpp index 07a719d..2f504ca 100644 --- a/src/peerconnection.cpp +++ b/src/peerconnection.cpp @@ -18,25 +18,28 @@ */ #include "peerconnection.hpp" -#include "certificate.hpp" -#include "include.hpp" -#include "logcounter.hpp" -#include "processor.hpp" +#include "common.hpp" #include "rtp.hpp" -#include "threadpool.hpp" -#include "dtlstransport.hpp" -#include "icetransport.hpp" -#include "sctptransport.hpp" +#include "impl/certificate.hpp" +#include "impl/dtlstransport.hpp" +#include "impl/icetransport.hpp" +#include "impl/peerconnection.hpp" +#include "impl/processor.hpp" +#include "impl/sctptransport.hpp" +#include "impl/threadpool.hpp" +#include "impl/track.hpp" #if RTC_ENABLE_MEDIA -#include "dtlssrtptransport.hpp" +#include "impl/dtlssrtptransport.hpp" #endif #include #include #include +using namespace std::placeholders; + #if __clang__ && defined(__APPLE__) namespace { template @@ -48,89 +51,35 @@ inline std::shared_ptr reinterpret_pointer_cast(std::shared_ptr const using std::reinterpret_pointer_cast; #endif -static rtc::LogCounter COUNTER_MEDIA_TRUNCATED(plog::warning, - "Number of RTP packets truncated over past second"); -static rtc::LogCounter - COUNTER_SRTP_DECRYPT_ERROR(plog::warning, "Number of SRTP decryption errors over past second"); -static rtc::LogCounter - COUNTER_SRTP_ENCRYPT_ERROR(plog::warning, "Number of SRTP encryption errors over past second"); -static rtc::LogCounter - COUNTER_UNKNOWN_PACKET_TYPE(plog::warning, - "Number of unknown RTCP packet types over past second"); - namespace rtc { -using namespace std::placeholders; - -using std::shared_ptr; -using std::weak_ptr; - PeerConnection::PeerConnection() : PeerConnection(Configuration()) {} -PeerConnection::PeerConnection(const Configuration &config) - : mConfig(config), mCertificate(make_certificate()), mProcessor(std::make_unique()), - mState(State::New), mGatheringState(GatheringState::New), - mSignalingState(SignalingState::Stable), mNegotiationNeeded(false) { - PLOG_VERBOSE << "Creating PeerConnection"; +PeerConnection::PeerConnection(Configuration config) + : CheshireCat(std::move(config)) {} - if (config.portRangeEnd && config.portRangeBegin > config.portRangeEnd) - throw std::invalid_argument("Invalid port range"); +PeerConnection::~PeerConnection() { close(); } - if (config.mtu) { - if (*config.mtu < 576) // Min MTU for IPv4 - throw std::invalid_argument("Invalid MTU value"); +void PeerConnection::close() { impl()->close(); } - if (*config.mtu > 1500) { // Standard Ethernet - PLOG_WARNING << "MTU set to " << *config.mtu; - } else { - PLOG_VERBOSE << "MTU set to " << *config.mtu; - } - } +const Configuration *PeerConnection::config() const { return &impl()->config; } + +PeerConnection::State PeerConnection::state() const { return impl()->state; } + +PeerConnection::GatheringState PeerConnection::gatheringState() const { + return impl()->gatheringState; } -PeerConnection::~PeerConnection() { - PLOG_VERBOSE << "Destroying PeerConnection"; - close(); - mProcessor->join(); +PeerConnection::SignalingState PeerConnection::signalingState() const { + return impl()->signalingState; } -void PeerConnection::close() { - PLOG_VERBOSE << "Closing PeerConnection"; - - mNegotiationNeeded = false; - - // Close data channels asynchronously - mProcessor->enqueue(&PeerConnection::closeDataChannels, this); - - closeTransports(); -} - -const Configuration *PeerConnection::config() const { return &mConfig; } - -PeerConnection::State PeerConnection::state() const { return mState; } - -PeerConnection::GatheringState PeerConnection::gatheringState() const { return mGatheringState; } - -PeerConnection::SignalingState PeerConnection::signalingState() const { return mSignalingState; } - std::optional PeerConnection::localDescription() const { - std::lock_guard lock(mLocalDescriptionMutex); - return mLocalDescription; + return impl()->localDescription(); } std::optional PeerConnection::remoteDescription() const { - std::lock_guard lock(mRemoteDescriptionMutex); - return mRemoteDescription; -} - -bool PeerConnection::hasLocalDescription() const { - std::lock_guard lock(mLocalDescriptionMutex); - return bool(mLocalDescription); -} - -bool PeerConnection::hasRemoteDescription() const { - std::lock_guard lock(mRemoteDescriptionMutex); - return bool(mRemoteDescription); + return impl()->remoteDescription(); } bool PeerConnection::hasMedia() const { @@ -141,39 +90,26 @@ bool PeerConnection::hasMedia() const { void PeerConnection::setLocalDescription(Description::Type type) { PLOG_VERBOSE << "Setting local description, type=" << Description::typeToString(type); - SignalingState signalingState = mSignalingState.load(); + SignalingState signalingState = impl()->signalingState.load(); if (type == Description::Type::Rollback) { if (signalingState == SignalingState::HaveLocalOffer || signalingState == SignalingState::HaveLocalPranswer) { - PLOG_DEBUG << "Rolling back pending local description"; - - std::unique_lock lock(mLocalDescriptionMutex); - if (mCurrentLocalDescription) { - std::vector existingCandidates; - if (mLocalDescription) - existingCandidates = mLocalDescription->extractCandidates(); - - mLocalDescription.emplace(std::move(*mCurrentLocalDescription)); - mLocalDescription->addCandidates(std::move(existingCandidates)); - mCurrentLocalDescription.reset(); - } - lock.unlock(); - - changeSignalingState(SignalingState::Stable); + impl()->rollbackLocalDescription(); + impl()->changeSignalingState(SignalingState::Stable); } return; } // Guess the description type if unspecified if (type == Description::Type::Unspec) { - if (mSignalingState == SignalingState::HaveRemoteOffer) + if (signalingState == SignalingState::HaveRemoteOffer) type = Description::Type::Answer; else type = Description::Type::Offer; } // Only a local offer resets the negotiation needed flag - if (type == Description::Type::Offer && !mNegotiationNeeded.exchange(false)) { + if (type == Description::Type::Offer && !impl()->negotiationNeeded.exchange(false)) { PLOG_DEBUG << "No negotiation needed"; return; } @@ -210,15 +146,15 @@ void PeerConnection::setLocalDescription(Description::Type type) { } } - auto iceTransport = initIceTransport(); + auto iceTransport = impl()->initIceTransport(); Description local = iceTransport->getLocalDescription(type); - processLocalDescription(std::move(local)); + impl()->processLocalDescription(std::move(local)); - changeSignalingState(newSignalingState); + impl()->changeSignalingState(newSignalingState); - if (mGatheringState == GatheringState::New) { - iceTransport->gatherLocalCandidates(localBundleMid()); + if (impl()->gatheringState == GatheringState::New) { + iceTransport->gatherLocalCandidates(impl()->localBundleMid()); } } @@ -228,14 +164,14 @@ void PeerConnection::setRemoteDescription(Description description) { if (description.type() == Description::Type::Rollback) { // This is mostly useless because we accept any offer PLOG_VERBOSE << "Rolling back pending remote description"; - changeSignalingState(SignalingState::Stable); + impl()->changeSignalingState(SignalingState::Stable); return; } - validateRemoteDescription(description); + impl()->validateRemoteDescription(description); // Get the new signaling state - SignalingState signalingState = mSignalingState.load(); + SignalingState signalingState = impl()->signalingState.load(); SignalingState newSignalingState; switch (signalingState) { case SignalingState::Stable: @@ -290,32 +226,18 @@ void PeerConnection::setRemoteDescription(Description description) { // Candidates will be added at the end, extract them for now auto remoteCandidates = description.extractCandidates(); auto type = description.type(); - processRemoteDescription(std::move(description)); + impl()->processRemoteDescription(std::move(description)); - changeSignalingState(newSignalingState); + impl()->changeSignalingState(newSignalingState); if (type == Description::Type::Offer) { // This is an offer, we need to answer setLocalDescription(Description::Type::Answer); } else { // This is an answer - auto iceTransport = std::atomic_load(&mIceTransport); - auto sctpTransport = std::atomic_load(&mSctpTransport); - if (!sctpTransport && iceTransport && iceTransport->role() == Description::Role::Active) { - // Since we assumed passive role during DataChannel creation, we need to shift the - // stream numbers by one to shift them from odd to even. - std::unique_lock lock(mDataChannelsMutex); // we are going to swap the container - decltype(mDataChannels) newDataChannels; - auto it = mDataChannels.begin(); - while (it != mDataChannels.end()) { - auto channel = it->second.lock(); - if (channel->stream() % 2 == 1) - channel->mStream -= 1; - newDataChannels.emplace(channel->stream(), channel); - ++it; - } - std::swap(mDataChannels, newDataChannels); - } + // Since we assumed passive role during DataChannel creation, we need to shift the + // stream numbers by one to shift them from odd to even. + impl()->shiftDataChannels(); } for (const auto &candidate : remoteCandidates) @@ -324,16 +246,16 @@ void PeerConnection::setRemoteDescription(Description description) { void PeerConnection::addRemoteCandidate(Candidate candidate) { PLOG_VERBOSE << "Adding remote candidate: " << string(candidate); - processRemoteCandidate(std::move(candidate)); + impl()->processRemoteCandidate(std::move(candidate)); } std::optional PeerConnection::localAddress() const { - auto iceTransport = std::atomic_load(&mIceTransport); + auto iceTransport = impl()->getIceTransport(); return iceTransport ? iceTransport->getLocalAddress() : nullopt; } std::optional PeerConnection::remoteAddress() const { - auto iceTransport = std::atomic_load(&mIceTransport); + auto iceTransport = impl()->getIceTransport(); return iceTransport ? iceTransport->getRemoteAddress() : nullopt; } @@ -342,21 +264,21 @@ shared_ptr PeerConnection::addDataChannel(string label, DataChannel // setup:passive. [...] Thus, setup:active is RECOMMENDED. // See https://tools.ietf.org/html/rfc5763#section-5 // Therefore, we assume passive role when we are the offerer. - auto iceTransport = std::atomic_load(&mIceTransport); + auto iceTransport = impl()->getIceTransport(); auto role = iceTransport ? iceTransport->role() : Description::Role::Passive; - auto channel = emplaceDataChannel(role, std::move(label), std::move(init)); + auto channelImpl = impl()->emplaceDataChannel(role, std::move(label), std::move(init)); - if (auto transport = std::atomic_load(&mSctpTransport)) - if (transport->state() == SctpTransport::State::Connected) - channel->open(transport); + if (auto transport = impl()->getSctpTransport()) + if (transport->state() == impl::SctpTransport::State::Connected) + channelImpl->open(transport); // Renegotiation is needed iff the current local description does not have application - std::lock_guard lock(mLocalDescriptionMutex); - if (!mLocalDescription || !mLocalDescription->hasApplication()) - mNegotiationNeeded = true; + auto local = impl()->localDescription(); + if (!local || !local->hasApplication()) + impl()->negotiationNeeded = true; - return channel; + return std::make_shared(channelImpl); } shared_ptr PeerConnection::createDataChannel(string label, DataChannelInit init) { @@ -367,927 +289,65 @@ shared_ptr PeerConnection::createDataChannel(string label, DataChan void PeerConnection::onDataChannel( std::function dataChannel)> callback) { - mDataChannelCallback = callback; + impl()->dataChannelCallback = callback; } void PeerConnection::onLocalDescription(std::function callback) { - mLocalDescriptionCallback = callback; + impl()->localDescriptionCallback = callback; } void PeerConnection::onLocalCandidate(std::function callback) { - mLocalCandidateCallback = callback; + impl()->localCandidateCallback = callback; } void PeerConnection::onStateChange(std::function callback) { - mStateChangeCallback = callback; + impl()->stateChangeCallback = callback; } void PeerConnection::onGatheringStateChange(std::function callback) { - mGatheringStateChangeCallback = callback; + impl()->gatheringStateChangeCallback = callback; } void PeerConnection::onSignalingStateChange(std::function callback) { - mSignalingStateChangeCallback = callback; + impl()->signalingStateChangeCallback = callback; } std::shared_ptr PeerConnection::addTrack(Description::Media description) { -#if !RTC_ENABLE_MEDIA - if (mTracks.empty()) { - PLOG_WARNING << "Tracks will be inative (not compiled with media support)"; - } -#endif - - std::shared_ptr track; - if (auto it = mTracks.find(description.mid()); it != mTracks.end()) - if (track = it->second.lock(); track) - track->setDescription(std::move(description)); - - if (!track) { - track = std::make_shared(std::move(description)); - mTracks.emplace(std::make_pair(track->mid(), track)); - mTrackLines.emplace_back(track); - } + auto trackImpl = impl()->emplaceTrack(std::move(description)); // Renegotiation is needed for the new or updated track - mNegotiationNeeded = true; + impl()->negotiationNeeded = true; - return track; + return std::make_shared(trackImpl); } void PeerConnection::onTrack(std::function)> callback) { - mTrackCallback = callback; + impl()->trackCallback = callback; } -shared_ptr PeerConnection::initIceTransport() { - try { - if (auto transport = std::atomic_load(&mIceTransport)) - return transport; - - PLOG_VERBOSE << "Starting ICE transport"; - - auto transport = std::make_shared( - mConfig, weak_bind(&PeerConnection::processLocalCandidate, this, _1), - [this, weak_this = weak_from_this()](IceTransport::State state) { - auto shared_this = weak_this.lock(); - if (!shared_this) - return; - switch (state) { - case IceTransport::State::Connecting: - changeState(State::Connecting); - break; - case IceTransport::State::Failed: - changeState(State::Failed); - break; - case IceTransport::State::Connected: - initDtlsTransport(); - break; - case IceTransport::State::Disconnected: - changeState(State::Disconnected); - break; - default: - // Ignore - break; - } - }, - [this, weak_this = weak_from_this()](IceTransport::GatheringState state) { - auto shared_this = weak_this.lock(); - if (!shared_this) - return; - switch (state) { - case IceTransport::GatheringState::InProgress: - changeGatheringState(GatheringState::InProgress); - break; - case IceTransport::GatheringState::Complete: - endLocalCandidates(); - changeGatheringState(GatheringState::Complete); - break; - default: - // Ignore - break; - } - }); - - std::atomic_store(&mIceTransport, transport); - if (mState == State::Closed) { - mIceTransport.reset(); - throw std::runtime_error("Connection is closed"); - } - transport->start(); - return transport; - - } catch (const std::exception &e) { - PLOG_ERROR << e.what(); - changeState(State::Failed); - throw std::runtime_error("ICE transport initialization failed"); - } -} - -shared_ptr PeerConnection::initDtlsTransport() { - try { - if (auto transport = std::atomic_load(&mDtlsTransport)) - return transport; - - PLOG_VERBOSE << "Starting DTLS transport"; - - auto certificate = mCertificate.get(); - auto lower = std::atomic_load(&mIceTransport); - auto verifierCallback = weak_bind(&PeerConnection::checkFingerprint, this, _1); - auto stateChangeCallback = [this, - weak_this = weak_from_this()](DtlsTransport::State state) { - auto shared_this = weak_this.lock(); - if (!shared_this) - return; - - switch (state) { - case DtlsTransport::State::Connected: - if (auto remote = remoteDescription(); remote && remote->hasApplication()) - initSctpTransport(); - else - changeState(State::Connected); - - mProcessor->enqueue(&PeerConnection::openTracks, this); - break; - case DtlsTransport::State::Failed: - changeState(State::Failed); - break; - case DtlsTransport::State::Disconnected: - changeState(State::Disconnected); - break; - default: - // Ignore - break; - } - }; - - shared_ptr transport; - if (hasMedia()) { -#if RTC_ENABLE_MEDIA - PLOG_INFO << "This connection requires media support"; - - // DTLS-SRTP - transport = std::make_shared( - lower, certificate, mConfig.mtu, verifierCallback, - weak_bind(&PeerConnection::forwardMedia, this, _1), stateChangeCallback); -#else - PLOG_WARNING << "Ignoring media support (not compiled with media support)"; -#endif - } - - if (!transport) { - // DTLS only - transport = std::make_shared(lower, certificate, mConfig.mtu, - verifierCallback, stateChangeCallback); - } - - std::atomic_store(&mDtlsTransport, transport); - if (mState == State::Closed) { - mDtlsTransport.reset(); - throw std::runtime_error("Connection is closed"); - } - transport->start(); - return transport; - - } catch (const std::exception &e) { - PLOG_ERROR << e.what(); - changeState(State::Failed); - throw std::runtime_error("DTLS transport initialization failed"); - } -} - -shared_ptr PeerConnection::initSctpTransport() { - try { - if (auto transport = std::atomic_load(&mSctpTransport)) - return transport; - - PLOG_VERBOSE << "Starting SCTP transport"; - - auto remote = remoteDescription(); - if (!remote || !remote->application()) - throw std::logic_error("Starting SCTP transport without application description"); - - uint16_t sctpPort = remote->application()->sctpPort().value_or(DEFAULT_SCTP_PORT); - auto lower = std::atomic_load(&mDtlsTransport); - auto transport = std::make_shared( - lower, sctpPort, mConfig.mtu, weak_bind(&PeerConnection::forwardMessage, this, _1), - weak_bind(&PeerConnection::forwardBufferedAmount, this, _1, _2), - [this, weak_this = weak_from_this()](SctpTransport::State state) { - auto shared_this = weak_this.lock(); - if (!shared_this) - return; - switch (state) { - case SctpTransport::State::Connected: - changeState(State::Connected); - mProcessor->enqueue(&PeerConnection::openDataChannels, this); - break; - case SctpTransport::State::Failed: - LOG_WARNING << "SCTP transport failed"; - changeState(State::Failed); - mProcessor->enqueue(&PeerConnection::remoteCloseDataChannels, this); - break; - case SctpTransport::State::Disconnected: - changeState(State::Disconnected); - mProcessor->enqueue(&PeerConnection::remoteCloseDataChannels, this); - break; - default: - // Ignore - break; - } - }); - - std::atomic_store(&mSctpTransport, transport); - if (mState == State::Closed) { - mSctpTransport.reset(); - throw std::runtime_error("Connection is closed"); - } - transport->start(); - return transport; - - } catch (const std::exception &e) { - PLOG_ERROR << e.what(); - changeState(State::Failed); - throw std::runtime_error("SCTP transport initialization failed"); - } -} - -void PeerConnection::closeTransports() { - PLOG_VERBOSE << "Closing transports"; - - // Change state to sink state Closed - if (!changeState(State::Closed)) - return; // already closed - - // Reset callbacks now that state is changed - resetCallbacks(); - - // Initiate transport stop on the processor after closing the data channels - mProcessor->enqueue([this]() { - // Pass the pointers to a thread - auto sctp = std::atomic_exchange(&mSctpTransport, decltype(mSctpTransport)(nullptr)); - auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr)); - auto ice = std::atomic_exchange(&mIceTransport, decltype(mIceTransport)(nullptr)); - ThreadPool::Instance().enqueue([sctp, dtls, ice]() mutable { - if (sctp) - sctp->stop(); - if (dtls) - dtls->stop(); - if (ice) - ice->stop(); - - sctp.reset(); - dtls.reset(); - ice.reset(); - }); - }); -} - -void PeerConnection::endLocalCandidates() { - std::lock_guard lock(mLocalDescriptionMutex); - if (mLocalDescription) - mLocalDescription->endCandidates(); -} - -bool PeerConnection::checkFingerprint(const std::string &fingerprint) const { - std::lock_guard lock(mRemoteDescriptionMutex); - if (auto expectedFingerprint = - mRemoteDescription ? mRemoteDescription->fingerprint() : nullopt) { - return *expectedFingerprint == fingerprint; - } - return false; -} - -void PeerConnection::forwardMessage(message_ptr message) { - if (!message) { - remoteCloseDataChannels(); - return; - } - - uint16_t stream = uint16_t(message->stream); - auto channel = findDataChannel(stream); - if (!channel) { - auto iceTransport = std::atomic_load(&mIceTransport); - auto sctpTransport = std::atomic_load(&mSctpTransport); - if (!iceTransport || !sctpTransport) - return; - - const byte dataChannelOpenMessage{0x03}; - uint16_t remoteParity = (iceTransport->role() == Description::Role::Active) ? 1 : 0; - if (message->type == Message::Control && *message->data() == dataChannelOpenMessage && - stream % 2 == remoteParity) { - - channel = - std::make_shared(shared_from_this(), sctpTransport, stream); - channel->onOpen(weak_bind(&PeerConnection::triggerDataChannel, this, - weak_ptr{channel})); - - std::unique_lock lock(mDataChannelsMutex); // we are going to emplace - mDataChannels.emplace(stream, channel); - } else { - // Invalid, close the DataChannel - sctpTransport->closeStream(message->stream); - return; - } - } - - channel->incoming(message); -} - -void PeerConnection::forwardMedia(message_ptr message) { - if (!message) - return; - - // Browsers like to compound their packets with a random SSRC. - // we have to do this monstrosity to distribute the report blocks - if (message->type == Message::Control) { - std::set ssrcs; - size_t offset = 0; - while ((sizeof(rtc::RTCP_HEADER) + offset) <= message->size()) { - auto header = reinterpret_cast(message->data() + offset); - if (header->lengthInBytes() > message->size() - offset) { - COUNTER_MEDIA_TRUNCATED++; - break; - } - offset += header->lengthInBytes(); - if (header->payloadType() == 205 || header->payloadType() == 206) { - auto rtcpfb = reinterpret_cast(header); - ssrcs.insert(rtcpfb->getPacketSenderSSRC()); - ssrcs.insert(rtcpfb->getMediaSourceSSRC()); - - } else if (header->payloadType() == 200 || header->payloadType() == 201) { - auto rtcpsr = reinterpret_cast(header); - ssrcs.insert(rtcpsr->senderSSRC()); - for (int i = 0; i < rtcpsr->header.reportCount(); ++i) - ssrcs.insert(rtcpsr->getReportBlock(i)->getSSRC()); - } else if (header->payloadType() == 202) { - auto sdes = reinterpret_cast(header); - if (!sdes->isValid()) { - PLOG_WARNING << "RTCP SDES packet is invalid"; - continue; - } - for (unsigned int i = 0; i < sdes->chunksCount(); i++) { - auto chunk = sdes->getChunk(i); - ssrcs.insert(chunk->ssrc()); - } - } else { - // PT=207 == Extended Report - if (header->payloadType() != 207) { - COUNTER_UNKNOWN_PACKET_TYPE++; - } - } - } - - if (!ssrcs.empty()) { - for (uint32_t ssrc : ssrcs) { - if (auto mid = getMidFromSsrc(ssrc)) { - std::shared_lock lock(mTracksMutex); // read-only - if (auto it = mTracks.find(*mid); it != mTracks.end()) - if (auto track = it->second.lock()) - track->incoming(message); - } - } - return; - } - } - - uint32_t ssrc = uint32_t(message->stream); - if (auto mid = getMidFromSsrc(ssrc)) { - std::shared_lock lock(mTracksMutex); // read-only - if (auto it = mTracks.find(*mid); it != mTracks.end()) - if (auto track = it->second.lock()) - track->incoming(message); - } else { - /* - * TODO: So the problem is that when stop sending streams, we stop getting report blocks for - * those streams Therefore when we get compound RTCP packets, they are empty, and we can't - * forward them. Therefore, it is expected that we don't know where to forward packets. Is - * this ideal? No! Do I know how to fix it? No! - */ - // PLOG_WARNING << "Track not found for SSRC " << ssrc << ", dropping"; - return; - } -} - -std::optional PeerConnection::getMidFromSsrc(uint32_t ssrc) { - if (auto it = mMidFromSsrc.find(ssrc); it != mMidFromSsrc.end()) - return it->second; - - { - std::lock_guard lock(mRemoteDescriptionMutex); - if (!mRemoteDescription) - return nullopt; - for (unsigned int i = 0; i < mRemoteDescription->mediaCount(); ++i) { - if (auto found = std::visit( - rtc::overloaded{[&](Description::Application *) -> std::optional { - return std::nullopt; - }, - [&](Description::Media *media) -> std::optional { - return media->hasSSRC(ssrc) - ? std::make_optional(media->mid()) - : nullopt; - }}, - mRemoteDescription->media(i))) { - - mMidFromSsrc.emplace(ssrc, *found); - return *found; - } - } - } - { - std::lock_guard lock(mLocalDescriptionMutex); - if (!mLocalDescription) - return nullopt; - for (unsigned int i = 0; i < mLocalDescription->mediaCount(); ++i) { - if (auto found = std::visit( - rtc::overloaded{[&](Description::Application *) -> std::optional { - return std::nullopt; - }, - [&](Description::Media *media) -> std::optional { - return media->hasSSRC(ssrc) - ? std::make_optional(media->mid()) - : nullopt; - }}, - mLocalDescription->media(i))) { - - mMidFromSsrc.emplace(ssrc, *found); - return *found; - } - } - } - - return nullopt; -} - -void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) { - if (auto channel = findDataChannel(stream)) - channel->triggerBufferedAmount(amount); -} - -shared_ptr PeerConnection::emplaceDataChannel(Description::Role role, string label, - DataChannelInit init) { - std::unique_lock lock(mDataChannelsMutex); // we are going to emplace - uint16_t stream; - if (init.id) { - stream = *init.id; - if (stream == 65535) - throw std::invalid_argument("Invalid DataChannel id"); - } else { - // The active side must use streams with even identifiers, whereas the passive side must use - // streams with odd identifiers. - // See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-6 - stream = (role == Description::Role::Active) ? 0 : 1; - while (mDataChannels.find(stream) != mDataChannels.end()) { - if (stream >= 65535 - 2) - throw std::runtime_error("Too many DataChannels"); - - stream += 2; - } - } - // If the DataChannel is user-negotiated, do not negociate it here - auto channel = - init.negotiated - ? std::make_shared(shared_from_this(), stream, std::move(label), - std::move(init.protocol), std::move(init.reliability)) - : std::make_shared(shared_from_this(), stream, std::move(label), - std::move(init.protocol), - std::move(init.reliability)); - mDataChannels.emplace(std::make_pair(stream, channel)); - return channel; -} - -shared_ptr PeerConnection::findDataChannel(uint16_t stream) { - std::shared_lock lock(mDataChannelsMutex); // read-only - if (auto it = mDataChannels.find(stream); it != mDataChannels.end()) - if (auto channel = it->second.lock()) - return channel; - - return nullptr; -} - -void PeerConnection::iterateDataChannels( - std::function channel)> func) { - // Iterate - { - std::shared_lock lock(mDataChannelsMutex); // read-only - auto it = mDataChannels.begin(); - while (it != mDataChannels.end()) { - auto channel = it->second.lock(); - if (channel && !channel->isClosed()) - func(channel); - - ++it; - } - } - - // Cleanup - { - std::unique_lock lock(mDataChannelsMutex); // we are going to erase - auto it = mDataChannels.begin(); - while (it != mDataChannels.end()) { - if (!it->second.lock()) { - it = mDataChannels.erase(it); - continue; - } - - ++it; - } - } -} - -void PeerConnection::openDataChannels() { - if (auto transport = std::atomic_load(&mSctpTransport)) - iterateDataChannels([&](shared_ptr channel) { channel->open(transport); }); -} - -void PeerConnection::closeDataChannels() { - iterateDataChannels([&](shared_ptr channel) { channel->close(); }); -} - -void PeerConnection::remoteCloseDataChannels() { - iterateDataChannels([&](shared_ptr channel) { channel->remoteClose(); }); -} - -void PeerConnection::incomingTrack(Description::Media description) { - std::unique_lock lock(mTracksMutex); // we are going to emplace -#if !RTC_ENABLE_MEDIA - if (mTracks.empty()) { - PLOG_WARNING << "Tracks will be inative (not compiled with media support)"; - } -#endif - if (mTracks.find(description.mid()) == mTracks.end()) { - auto track = std::make_shared(std::move(description)); - mTracks.emplace(std::make_pair(track->mid(), track)); - mTrackLines.emplace_back(track); - triggerTrack(track); - } -} - -void PeerConnection::openTracks() { -#if RTC_ENABLE_MEDIA - if (auto transport = std::atomic_load(&mDtlsTransport)) { - auto srtpTransport = reinterpret_pointer_cast(transport); - std::shared_lock lock(mTracksMutex); // read-only - for (auto it = mTracks.begin(); it != mTracks.end(); ++it) - if (auto track = it->second.lock()) - if (!track->isOpen()) - track->open(srtpTransport); - } -#endif -} - -void PeerConnection::validateRemoteDescription(const Description &description) { - if (!description.iceUfrag()) - throw std::invalid_argument("Remote description has no ICE user fragment"); - - if (!description.icePwd()) - throw std::invalid_argument("Remote description has no ICE password"); - - if (!description.fingerprint()) - throw std::invalid_argument("Remote description has no fingerprint"); - - if (description.mediaCount() == 0) - throw std::invalid_argument("Remote description has no media line"); - - int activeMediaCount = 0; - for (unsigned int i = 0; i < description.mediaCount(); ++i) - std::visit(rtc::overloaded{[&](const Description::Application *) { ++activeMediaCount; }, - [&](const Description::Media *media) { - if (media->direction() != Description::Direction::Inactive) - ++activeMediaCount; - }}, - description.media(i)); - - if (activeMediaCount == 0) - throw std::invalid_argument("Remote description has no active media"); - - if (auto local = localDescription(); local && local->iceUfrag() && local->icePwd()) - if (*description.iceUfrag() == *local->iceUfrag() && - *description.icePwd() == *local->icePwd()) - throw std::logic_error("Got the local description as remote description"); - - PLOG_VERBOSE << "Remote description looks valid"; -} - -void PeerConnection::processLocalDescription(Description description) { - - if (auto remote = remoteDescription()) { - // Reciprocate remote description - for (unsigned int i = 0; i < remote->mediaCount(); ++i) - std::visit( // reciprocate each media - rtc::overloaded{ - [&](Description::Application *remoteApp) { - std::shared_lock lock(mDataChannelsMutex); - if (!mDataChannels.empty()) { - // Prefer local description - Description::Application app(remoteApp->mid()); - app.setSctpPort(DEFAULT_SCTP_PORT); - app.setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE); - - PLOG_DEBUG << "Adding application to local description, mid=\"" - << app.mid() << "\""; - - description.addMedia(std::move(app)); - return; - } - - auto reciprocated = remoteApp->reciprocate(); - reciprocated.hintSctpPort(DEFAULT_SCTP_PORT); - reciprocated.setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE); - - PLOG_DEBUG << "Reciprocating application in local description, mid=\"" - << reciprocated.mid() << "\""; - - description.addMedia(std::move(reciprocated)); - }, - [&](Description::Media *remoteMedia) { - std::shared_lock lock(mTracksMutex); - if (auto it = mTracks.find(remoteMedia->mid()); it != mTracks.end()) { - // Prefer local description - if (auto track = it->second.lock()) { - auto media = track->description(); -#if !RTC_ENABLE_MEDIA - // No media support, mark as inactive - media.setDirection(Description::Direction::Inactive); -#endif - PLOG_DEBUG - << "Adding media to local description, mid=\"" << media.mid() - << "\", active=" << std::boolalpha - << (media.direction() != Description::Direction::Inactive); - - description.addMedia(std::move(media)); - } else { - auto reciprocated = remoteMedia->reciprocate(); - reciprocated.setDirection(Description::Direction::Inactive); - - PLOG_DEBUG << "Adding inactive media to local description, mid=\"" - << reciprocated.mid() << "\""; - - description.addMedia(std::move(reciprocated)); - } - return; - } - lock.unlock(); // we are going to call incomingTrack() - - auto reciprocated = remoteMedia->reciprocate(); -#if !RTC_ENABLE_MEDIA - // No media support, mark as inactive - reciprocated.setDirection(Description::Direction::Inactive); -#endif - incomingTrack(reciprocated); - - PLOG_DEBUG - << "Reciprocating media in local description, mid=\"" - << reciprocated.mid() << "\", active=" << std::boolalpha - << (reciprocated.direction() != Description::Direction::Inactive); - - description.addMedia(std::move(reciprocated)); - }, - }, - remote->media(i)); - } - - if (description.type() == Description::Type::Offer) { - // This is an offer, add locally created data channels and tracks - // Add application for data channels - if (!description.hasApplication()) { - std::shared_lock lock(mDataChannelsMutex); - if (!mDataChannels.empty()) { - unsigned int m = 0; - while (description.hasMid(std::to_string(m))) - ++m; - Description::Application app(std::to_string(m)); - app.setSctpPort(DEFAULT_SCTP_PORT); - app.setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE); - - PLOG_DEBUG << "Adding application to local description, mid=\"" << app.mid() - << "\""; - - description.addMedia(std::move(app)); - } - } - - // Add media for local tracks - std::shared_lock lock(mTracksMutex); - for (auto it = mTrackLines.begin(); it != mTrackLines.end(); ++it) { - if (auto track = it->lock()) { - if (description.hasMid(track->mid())) - continue; - - auto media = track->description(); -#if !RTC_ENABLE_MEDIA - // No media support, mark as inactive - media.setDirection(Description::Direction::Inactive); -#endif - PLOG_DEBUG << "Adding media to local description, mid=\"" << media.mid() - << "\", active=" << std::boolalpha - << (media.direction() != Description::Direction::Inactive); - - description.addMedia(std::move(media)); - } - } - } - - // Set local fingerprint (wait for certificate if necessary) - description.setFingerprint(mCertificate.get()->fingerprint()); - - { - // Set as local description - std::lock_guard lock(mLocalDescriptionMutex); - - std::vector existingCandidates; - if (mLocalDescription) { - existingCandidates = mLocalDescription->extractCandidates(); - mCurrentLocalDescription.emplace(std::move(*mLocalDescription)); - } - - mLocalDescription.emplace(description); - mLocalDescription->addCandidates(std::move(existingCandidates)); - } - - PLOG_VERBOSE << "Issuing local description: " << description; - mProcessor->enqueue(mLocalDescriptionCallback.wrap(), std::move(description)); - - // Reciprocated tracks might need to be open - if (auto dtlsTransport = std::atomic_load(&mDtlsTransport); - dtlsTransport && dtlsTransport->state() == Transport::State::Connected) - mProcessor->enqueue(&PeerConnection::openTracks, this); -} - -void PeerConnection::processLocalCandidate(Candidate candidate) { - std::lock_guard lock(mLocalDescriptionMutex); - if (!mLocalDescription) - throw std::logic_error("Got a local candidate without local description"); - - candidate.resolve(Candidate::ResolveMode::Simple); - mLocalDescription->addCandidate(candidate); - - PLOG_VERBOSE << "Issuing local candidate: " << candidate; - mProcessor->enqueue(mLocalCandidateCallback.wrap(), std::move(candidate)); -} - -void PeerConnection::processRemoteDescription(Description description) { - { - // Set as remote description - std::lock_guard lock(mRemoteDescriptionMutex); - - std::vector existingCandidates; - if (mRemoteDescription) - existingCandidates = mRemoteDescription->extractCandidates(); - - mRemoteDescription.emplace(description); - mRemoteDescription->addCandidates(std::move(existingCandidates)); - } - - auto iceTransport = initIceTransport(); - iceTransport->setRemoteDescription(std::move(description)); - - if (description.hasApplication()) { - auto dtlsTransport = std::atomic_load(&mDtlsTransport); - auto sctpTransport = std::atomic_load(&mSctpTransport); - if (!sctpTransport && dtlsTransport && - dtlsTransport->state() == Transport::State::Connected) - initSctpTransport(); - } -} - -void PeerConnection::processRemoteCandidate(Candidate candidate) { - auto iceTransport = std::atomic_load(&mIceTransport); - { - // Set as remote candidate - std::lock_guard lock(mRemoteDescriptionMutex); - if (!mRemoteDescription) - throw std::logic_error("Got a remote candidate without remote description"); - - if (!iceTransport) - throw std::logic_error("Got a remote candidate without ICE transport"); - - candidate.hintMid(mRemoteDescription->bundleMid()); - - if (mRemoteDescription->hasCandidate(candidate)) - return; // already in description, ignore - - candidate.resolve(Candidate::ResolveMode::Simple); - mRemoteDescription->addCandidate(candidate); - } - - if (candidate.isResolved()) { - iceTransport->addRemoteCandidate(std::move(candidate)); - } else { - // We might need a lookup, do it asynchronously - // We don't use the thread pool because we have no control on the timeout - if ((iceTransport = std::atomic_load(&mIceTransport))) { - weak_ptr weakIceTransport{iceTransport}; - std::thread t([weakIceTransport, candidate = std::move(candidate)]() mutable { - if (candidate.resolve(Candidate::ResolveMode::Lookup)) - if (auto iceTransport = weakIceTransport.lock()) - iceTransport->addRemoteCandidate(std::move(candidate)); - }); - t.detach(); - } - } -} - -string PeerConnection::localBundleMid() const { - std::lock_guard lock(mLocalDescriptionMutex); - return mLocalDescription ? mLocalDescription->bundleMid() : "0"; -} - -void PeerConnection::triggerDataChannel(weak_ptr weakDataChannel) { - auto dataChannel = weakDataChannel.lock(); - if (!dataChannel) - return; - - mProcessor->enqueue(mDataChannelCallback.wrap(), std::move(dataChannel)); -} - -void PeerConnection::triggerTrack(std::shared_ptr track) { - mProcessor->enqueue(mTrackCallback.wrap(), std::move(track)); -} - -bool PeerConnection::changeState(State state) { - State current; - do { - current = mState.load(); - if (current == State::Closed) - return false; - if (current == state) - return false; - - } while (!mState.compare_exchange_weak(current, state)); - - std::ostringstream s; - s << state; - PLOG_INFO << "Changed state to " << s.str(); - - if (state == State::Closed) - // This is the last state change, so we may steal the callback - mProcessor->enqueue([cb = std::move(mStateChangeCallback)]() { cb(State::Closed); }); - else - mProcessor->enqueue(mStateChangeCallback.wrap(), state); - - return true; -} - -bool PeerConnection::changeGatheringState(GatheringState state) { - if (mGatheringState.exchange(state) == state) - return false; - - std::ostringstream s; - s << state; - PLOG_INFO << "Changed gathering state to " << s.str(); - mProcessor->enqueue(mGatheringStateChangeCallback.wrap(), state); - return true; -} - -bool PeerConnection::changeSignalingState(SignalingState state) { - if (mSignalingState.exchange(state) == state) - return false; - - std::ostringstream s; - s << state; - PLOG_INFO << "Changed signaling state to " << s.str(); - mProcessor->enqueue(mSignalingStateChangeCallback.wrap(), state); - return true; -} - -void PeerConnection::resetCallbacks() { - // Unregister all callbacks - mDataChannelCallback = nullptr; - mLocalDescriptionCallback = nullptr; - mLocalCandidateCallback = nullptr; - mStateChangeCallback = nullptr; - mGatheringStateChangeCallback = nullptr; -} - -bool PeerConnection::getSelectedCandidatePair([[maybe_unused]] Candidate *local, - [[maybe_unused]] Candidate *remote) { - auto iceTransport = std::atomic_load(&mIceTransport); +bool PeerConnection::getSelectedCandidatePair(Candidate *local, Candidate *remote) { + auto iceTransport = impl()->getIceTransport(); return iceTransport ? iceTransport->getSelectedCandidatePair(local, remote) : false; } void PeerConnection::clearStats() { - auto sctpTransport = std::atomic_load(&mSctpTransport); - if (sctpTransport) + if (auto sctpTransport = impl()->getSctpTransport()) return sctpTransport->clearStats(); } size_t PeerConnection::bytesSent() { - auto sctpTransport = std::atomic_load(&mSctpTransport); - if (sctpTransport) - return sctpTransport->bytesSent(); - return 0; + auto sctpTransport = impl()->getSctpTransport(); + return sctpTransport ? sctpTransport->bytesSent() : 0; } size_t PeerConnection::bytesReceived() { - auto sctpTransport = std::atomic_load(&mSctpTransport); - if (sctpTransport) - return sctpTransport->bytesReceived(); - return 0; + auto sctpTransport = impl()->getSctpTransport(); + return sctpTransport ? sctpTransport->bytesReceived() : 0; } std::optional PeerConnection::rtt() { - auto sctpTransport = std::atomic_load(&mSctpTransport); - if (sctpTransport) - return sctpTransport->rtt(); - return std::nullopt; + auto sctpTransport = impl()->getSctpTransport(); + return sctpTransport ? sctpTransport->rtt() : nullopt; } } // namespace rtc diff --git a/src/rtcpreceivingsession.cpp b/src/rtcpreceivingsession.cpp index b4ae27f..c187c67 100644 --- a/src/rtcpreceivingsession.cpp +++ b/src/rtcpreceivingsession.cpp @@ -20,9 +20,10 @@ #if RTC_ENABLE_MEDIA #include "rtcpreceivingsession.hpp" -#include "logcounter.hpp" #include "track.hpp" +#include "impl/logcounter.hpp" + #include #include @@ -34,10 +35,10 @@ namespace rtc { -static LogCounter COUNTER_BAD_RTP_HEADER(plog::warning, "Number of malformed RTP headers"); -static LogCounter COUNTER_UNKNOWN_PPID(plog::warning, "Number of Unknown PPID messages"); -static LogCounter COUNTER_BAD_NOTIF_LEN(plog::warning, "Number of Bad-Lengthed notifications"); -static LogCounter COUNTER_BAD_SCTP_STATUS(plog::warning, "Number of unknown SCTP_STATUS errors"); +static impl::LogCounter COUNTER_BAD_RTP_HEADER(plog::warning, "Number of malformed RTP headers"); +static impl::LogCounter COUNTER_UNKNOWN_PPID(plog::warning, "Number of Unknown PPID messages"); +static impl::LogCounter COUNTER_BAD_NOTIF_LEN(plog::warning, "Number of Bad-Lengthed notifications"); +static impl::LogCounter COUNTER_BAD_SCTP_STATUS(plog::warning, "Number of unknown SCTP_STATUS errors"); message_ptr RtcpReceivingSession::outgoing(message_ptr ptr) { return ptr; } diff --git a/src/track.cpp b/src/track.cpp index 2aaa15f..f1dd2b5 100644 --- a/src/track.cpp +++ b/src/track.cpp @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020 Paul-Louis Ageneau + * Copyright (c) 2020-2021 Paul-Louis Ageneau * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -17,189 +17,50 @@ */ #include "track.hpp" -#include "dtlssrtptransport.hpp" -#include "include.hpp" -#include "logcounter.hpp" -static rtc::LogCounter - COUNTER_MEDIA_BAD_DIRECTION(plog::warning, - "Number of media packets sent in invalid directions"); -static rtc::LogCounter COUNTER_QUEUE_FULL(plog::warning, - "Number of media packets dropped due to a full queue"); +#include "impl/track.hpp" namespace rtc { -using std::shared_ptr; -using std::weak_ptr; +Track::Track(impl_ptr impl) + : CheshireCat(impl), Channel(std::dynamic_pointer_cast(impl)) {} -Track::Track(Description::Media description) - : mMediaDescription(std::move(description)), mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {} +string Track::mid() const { return impl()->mid(); } -string Track::mid() const { - std::shared_lock lock(mMutex); - return mMediaDescription.mid(); -} +Description::Direction Track::direction() const { return impl()->direction(); } -Description::Media Track::description() const { - std::shared_lock lock(mMutex); - return mMediaDescription; -} - -Description::Direction Track::direction() const { - std::shared_lock lock(mMutex); - return mMediaDescription.direction(); -} +Description::Media Track::description() const { return impl()->description(); } void Track::setDescription(Description::Media description) { - std::unique_lock lock(mMutex); - if (description.mid() != mMediaDescription.mid()) - throw std::logic_error("Media description mid does not match track mid"); - - mMediaDescription = std::move(description); + impl()->setDescription(std::move(description)); } -void Track::close() { - mIsClosed = true; +void Track::close() { impl()->close(); } - setRtcpHandler(nullptr); - resetCallbacks(); -} - -bool Track::send(message_variant data) { - if (mIsClosed) - throw std::runtime_error("Track is closed"); - - auto dir = direction(); - if ((dir == Description::Direction::RecvOnly || dir == Description::Direction::Inactive)) { - COUNTER_MEDIA_BAD_DIRECTION++; - return false; - } - - auto message = make_message(std::move(data)); - - if (auto handler = getRtcpHandler()) { - message = handler->outgoing(message); - if (!message) - return false; - } - - return outgoing(std::move(message)); -} +bool Track::send(message_variant data) { return impl()->outgoing(make_message(std::move(data))); } bool Track::send(const byte *data, size_t size) { return send(binary(data, data + size)); } -std::optional Track::receive() { - if (auto next = mRecvQueue.tryPop()) - return to_variant(std::move(**next)); +bool Track::isOpen(void) const { return impl()->isOpen(); } - return nullopt; -} - -std::optional Track::peek() { - if (auto next = mRecvQueue.peek()) - return to_variant(std::move(**next)); - - return nullopt; -} - -bool Track::isOpen(void) const { -#if RTC_ENABLE_MEDIA - std::shared_lock lock(mMutex); - return !mIsClosed && mDtlsSrtpTransport.lock(); -#else - return !mIsClosed; -#endif -} - -bool Track::isClosed(void) const { return mIsClosed; } +bool Track::isClosed(void) const { return impl()->isClosed(); } size_t Track::maxMessageSize() const { - return 65535 - 12 - 4; // SRTP/UDP -} - -size_t Track::availableAmount() const { return mRecvQueue.amount(); } - -#if RTC_ENABLE_MEDIA -void Track::open(shared_ptr transport) { - { - std::lock_guard lock(mMutex); - mDtlsSrtpTransport = transport; - } - - triggerOpen(); -} -#endif - -void Track::incoming(message_ptr message) { - if (!message) - return; - - auto dir = direction(); - if ((dir == Description::Direction::SendOnly || dir == Description::Direction::Inactive) && - message->type != Message::Control) { - COUNTER_MEDIA_BAD_DIRECTION++; - return; - } - - if (auto handler = getRtcpHandler()) { - message = handler->incoming(message); - if (!message) - return; - } - - // Tail drop if queue is full - if (mRecvQueue.full()) { - COUNTER_QUEUE_FULL++; - return; - } - - mRecvQueue.push(message); - triggerAvailable(mRecvQueue.size()); -} - -bool Track::outgoing([[maybe_unused]] message_ptr message) { -#if RTC_ENABLfiE_MEDIA - std::shared_ptr transport; - { - std::shared_lock lock(mMutex); - transport = mDtlsSrtpTransport.lock(); - if (!transport) - throw std::runtime_error("Track is closed"); - - // Set recommended medium-priority DSCP value - // See https://tools.ietf.org/html/draft-ietf-tsvwg-rtcweb-qos-18 - if (mMediaDescription.type() == "audio") - message->dscp = 46; // EF: Expedited Forwarding - else - message->dscp = 36; // AF42: Assured Forwarding class 4, medium drop probability - } - - return transport->sendMedia(message); -#else - PLOG_WARNING << "Ignoring track send (not compiled with media support)"; - return false; -#endif + // TODO + return 65535; } void Track::setRtcpHandler(std::shared_ptr handler) { - { - std::unique_lock lock(mMutex); - mRtcpHandler = handler; - } - - handler->onOutgoing(std::bind(&Track::outgoing, this, std::placeholders::_1)); + impl()->setRtcpHandler(std::move(handler)); } bool Track::requestKeyframe() { - if (auto handler = getRtcpHandler()) + if (auto handler = impl()->getRtcpHandler()) return handler->requestKeyframe(); return false; } -std::shared_ptr Track::getRtcpHandler() { - std::shared_lock lock(mMutex); - return mRtcpHandler; -} +std::shared_ptr Track::getRtcpHandler() { return impl()->getRtcpHandler(); } } // namespace rtc diff --git a/src/websocket.cpp b/src/websocket.cpp index 454a2b9..fc2936f 100644 --- a/src/websocket.cpp +++ b/src/websocket.cpp @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020 Paul-Louis Ageneau + * Copyright (c) 2020-2021 Paul-Louis Ageneau * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -19,13 +19,10 @@ #if RTC_ENABLE_WEBSOCKET #include "websocket.hpp" -#include "include.hpp" -#include "threadpool.hpp" +#include "globals.hpp" +#include "common.hpp" -#include "tcptransport.hpp" -#include "tlstransport.hpp" -#include "verifiedtlstransport.hpp" -#include "wstransport.hpp" +#include "impl/websocket.hpp" #include @@ -35,335 +32,50 @@ namespace rtc { -using std::shared_ptr; using namespace std::placeholders; -WebSocket::WebSocket(std::optional config) - : mConfig(config ? std::move(*config) : Configuration()), - mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) { - PLOG_VERBOSE << "Creating WebSocket"; -} +WebSocket::WebSocket() : WebSocket(Configuration()) {} -WebSocket::~WebSocket() { - PLOG_VERBOSE << "Destroying WebSocket"; - remoteClose(); -} +WebSocket::WebSocket(Configuration config) + : CheshireCat(std::move(config)), + Channel(std::dynamic_pointer_cast(CheshireCat::impl())) {} -WebSocket::State WebSocket::readyState() const { return mState; } +WebSocket::~WebSocket() { impl()->remoteClose(); } + +WebSocket::State WebSocket::readyState() const { return impl()->state; } + +bool WebSocket::isOpen() const { return impl()->state.load() == State::Open; } + +bool WebSocket::isClosed() const { return impl()->state.load() == State::Closed; } + +size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; } void WebSocket::open(const string &url) { PLOG_VERBOSE << "Opening WebSocket to URL: " << url; - if (mState != State::Closed) - throw std::logic_error("WebSocket must be closed before opening"); - - // Modified regex from RFC 3986, see https://tools.ietf.org/html/rfc3986#appendix-B - static const char *rs = - R"(^(([^:.@/?#]+):)?(/{0,2}((([^:@]*)(:([^@]*))?)@)?(([^:/?#]*)(:([^/?#]*))?))?([^?#]*)(\?([^#]*))?(#(.*))?)"; - - static const std::regex r(rs, std::regex::extended); - - std::smatch m; - if (!std::regex_match(url, m, r) || m[10].length() == 0) - throw std::invalid_argument("Invalid WebSocket URL: " + url); - - mScheme = m[2]; - if (mScheme.empty()) - mScheme = "ws"; - else if (mScheme != "ws" && mScheme != "wss") - throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme); - - mHostname = m[10]; - mService = m[12]; - if (mService.empty()) { - mService = mScheme == "ws" ? "80" : "443"; - mHost = mHostname; - } else { - mHost = mHostname + ':' + mService; - } - - while (!mHostname.empty() && mHostname.front() == '[') - mHostname.erase(mHostname.begin()); - while (!mHostname.empty() && mHostname.back() == ']') - mHostname.pop_back(); - - mPath = m[13]; - if (mPath.empty()) - mPath += '/'; - if (string query = m[15]; !query.empty()) - mPath += "?" + query; - - changeState(State::Connecting); - initTcpTransport(); + impl()->parse(url); + impl()->changeState(State::Connecting); + impl()->initTcpTransport(); } void WebSocket::close() { - auto state = mState.load(); + auto state = impl()->state.load(); if (state == State::Connecting || state == State::Open) { PLOG_VERBOSE << "Closing WebSocket"; - changeState(State::Closing); - if (auto transport = std::atomic_load(&mWsTransport)) + impl()->changeState(State::Closing); + if (auto transport = impl()->getWsTransport()) transport->close(); else - changeState(State::Closed); + impl()->changeState(State::Closed); } } -void WebSocket::remoteClose() { - if (mState.load() != State::Closed) { - close(); - closeTransports(); - } +bool WebSocket::send(message_variant data) { + return impl()->outgoing(make_message(std::move(data))); } -bool WebSocket::send(message_variant data) { return outgoing(make_message(std::move(data))); } - bool WebSocket::send(const byte *data, size_t size) { - return outgoing(make_message(data, data + size)); -} - -bool WebSocket::isOpen() const { return mState == State::Open; } - -bool WebSocket::isClosed() const { return mState == State::Closed; } - -size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; } - -std::optional WebSocket::receive() { - while (auto next = mRecvQueue.tryPop()) { - message_ptr message = *next; - if (message->type != Message::Control) - return to_variant(std::move(*message)); - } - return nullopt; -} - -std::optional WebSocket::peek() { - while (auto next = mRecvQueue.peek()) { - message_ptr message = *next; - if (message->type != Message::Control) - return to_variant(std::move(*message)); - - mRecvQueue.tryPop(); - } - return nullopt; -} - -size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); } - -bool WebSocket::changeState(State state) { return mState.exchange(state) != state; } - -bool WebSocket::outgoing(message_ptr message) { - if (mState != State::Open || !mWsTransport) - throw std::runtime_error("WebSocket is not open"); - - if (message->size() > maxMessageSize()) - throw std::runtime_error("Message size exceeds limit"); - - return mWsTransport->send(message); -} - -void WebSocket::incoming(message_ptr message) { - if (!message) { - remoteClose(); - return; - } - - if (message->type == Message::String || message->type == Message::Binary) { - mRecvQueue.push(message); - triggerAvailable(mRecvQueue.size()); - } -} - -shared_ptr WebSocket::initTcpTransport() { - PLOG_VERBOSE << "Starting TCP transport"; - using State = TcpTransport::State; - try { - std::lock_guard lock(mInitMutex); - if (auto transport = std::atomic_load(&mTcpTransport)) - return transport; - - auto transport = std::make_shared( - mHostname, mService, [this, weak_this = weak_from_this()](State state) { - auto shared_this = weak_this.lock(); - if (!shared_this) - return; - switch (state) { - case State::Connected: - if (mScheme == "ws") - initWsTransport(); - else - initTlsTransport(); - break; - case State::Failed: - triggerError("TCP connection failed"); - remoteClose(); - break; - case State::Disconnected: - remoteClose(); - break; - default: - // Ignore - break; - } - }); - std::atomic_store(&mTcpTransport, transport); - if (mState == WebSocket::State::Closed) { - mTcpTransport.reset(); - throw std::runtime_error("Connection is closed"); - } - transport->start(); - return transport; - - } catch (const std::exception &e) { - PLOG_ERROR << e.what(); - remoteClose(); - throw std::runtime_error("TCP transport initialization failed"); - } -} - -shared_ptr WebSocket::initTlsTransport() { - PLOG_VERBOSE << "Starting TLS transport"; - using State = TlsTransport::State; - try { - std::lock_guard lock(mInitMutex); - if (auto transport = std::atomic_load(&mTlsTransport)) - return transport; - - auto lower = std::atomic_load(&mTcpTransport); - auto stateChangeCallback = [this, weak_this = weak_from_this()](State state) { - auto shared_this = weak_this.lock(); - if (!shared_this) - return; - switch (state) { - case State::Connected: - initWsTransport(); - break; - case State::Failed: - triggerError("TCP connection failed"); - remoteClose(); - break; - case State::Disconnected: - remoteClose(); - break; - default: - // Ignore - break; - } - }; - - shared_ptr transport; -#ifdef _WIN32 - if (!mConfig.disableTlsVerification) { - PLOG_WARNING << "TLS certificate verification with root CA is not supported on Windows"; - } - transport = std::make_shared(lower, mHostname, stateChangeCallback); -#else - if (mConfig.disableTlsVerification) - transport = std::make_shared(lower, mHostname, stateChangeCallback); - else - transport = - std::make_shared(lower, mHostname, stateChangeCallback); -#endif - - std::atomic_store(&mTlsTransport, transport); - if (mState == WebSocket::State::Closed) { - mTlsTransport.reset(); - throw std::runtime_error("Connection is closed"); - } - transport->start(); - return transport; - - } catch (const std::exception &e) { - PLOG_ERROR << e.what(); - remoteClose(); - throw std::runtime_error("TLS transport initialization failed"); - } -} - -shared_ptr WebSocket::initWsTransport() { - PLOG_VERBOSE << "Starting WebSocket transport"; - using State = WsTransport::State; - try { - std::lock_guard lock(mInitMutex); - if (auto transport = std::atomic_load(&mWsTransport)) - return transport; - - shared_ptr lower = std::atomic_load(&mTlsTransport); - if (!lower) - lower = std::atomic_load(&mTcpTransport); - - WsTransport::Configuration wsConfig = {}; - wsConfig.host = mHost; - wsConfig.path = mPath; - wsConfig.protocols = mConfig.protocols; - - auto transport = std::make_shared( - lower, wsConfig, weak_bind(&WebSocket::incoming, this, _1), - [this, weak_this = weak_from_this()](State state) { - auto shared_this = weak_this.lock(); - if (!shared_this) - return; - switch (state) { - case State::Connected: - if (mState == WebSocket::State::Connecting) { - PLOG_DEBUG << "WebSocket open"; - changeState(WebSocket::State::Open); - triggerOpen(); - } - break; - case State::Failed: - triggerError("WebSocket connection failed"); - remoteClose(); - break; - case State::Disconnected: - remoteClose(); - break; - default: - // Ignore - break; - } - }); - std::atomic_store(&mWsTransport, transport); - if (mState == WebSocket::State::Closed) { - mWsTransport.reset(); - throw std::runtime_error("Connection is closed"); - } - transport->start(); - return transport; - } catch (const std::exception &e) { - PLOG_ERROR << e.what(); - remoteClose(); - throw std::runtime_error("WebSocket transport initialization failed"); - } -} - -void WebSocket::closeTransports() { - PLOG_VERBOSE << "Closing transports"; - - if (mState.load() != State::Closed) { - changeState(State::Closed); - triggerClosed(); - } - - // Reset callbacks now that state is changed - resetCallbacks(); - - // Pass the pointers to a thread, allowing to terminate a transport from its own thread - auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr)); - auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr)); - auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr)); - ThreadPool::Instance().enqueue([ws, tls, tcp]() mutable { - if (ws) - ws->stop(); - if (tls) - tls->stop(); - if (tcp) - tcp->stop(); - - ws.reset(); - tls.reset(); - tcp.reset(); - }); + return impl()->outgoing(make_message(data, data + size)); } } // namespace rtc diff --git a/test/benchmark.cpp b/test/benchmark.cpp index fca21f5..0787489 100644 --- a/test/benchmark.cpp +++ b/test/benchmark.cpp @@ -42,53 +42,41 @@ size_t benchmark(milliseconds duration) { // config1.iceServers.emplace_back("stun:stun.l.google.com:19302"); // config1.mtu = 1500; - auto pc1 = std::make_shared(config1); + PeerConnection pc1(config1); Configuration config2; // config2.iceServers.emplace_back("stun:stun.l.google.com:19302"); // config2.mtu = 1500; - auto pc2 = std::make_shared(config2); + PeerConnection pc2(config2); - pc1->onLocalDescription([wpc2 = make_weak_ptr(pc2)](Description sdp) { - auto pc2 = wpc2.lock(); - if (!pc2) - return; + pc1.onLocalDescription([&pc2](Description sdp) { cout << "Description 1: " << sdp << endl; - pc2->setRemoteDescription(std::move(sdp)); + pc2.setRemoteDescription(std::move(sdp)); }); - pc1->onLocalCandidate([wpc2 = make_weak_ptr(pc2)](Candidate candidate) { - auto pc2 = wpc2.lock(); - if (!pc2) - return; + pc1.onLocalCandidate([&pc2](Candidate candidate) { cout << "Candidate 1: " << candidate << endl; - pc2->addRemoteCandidate(std::move(candidate)); + pc2.addRemoteCandidate(std::move(candidate)); }); - pc1->onStateChange([](PeerConnection::State state) { cout << "State 1: " << state << endl; }); - pc1->onGatheringStateChange([](PeerConnection::GatheringState state) { + pc1.onStateChange([](PeerConnection::State state) { cout << "State 1: " << state << endl; }); + pc1.onGatheringStateChange([](PeerConnection::GatheringState state) { cout << "Gathering state 1: " << state << endl; }); - pc2->onLocalDescription([wpc1 = make_weak_ptr(pc1)](Description sdp) { - auto pc1 = wpc1.lock(); - if (!pc1) - return; + pc2.onLocalDescription([&pc1](Description sdp) { cout << "Description 2: " << sdp << endl; - pc1->setRemoteDescription(std::move(sdp)); + pc1.setRemoteDescription(std::move(sdp)); }); - pc2->onLocalCandidate([wpc1 = make_weak_ptr(pc1)](Candidate candidate) { - auto pc1 = wpc1.lock(); - if (!pc1) - return; + pc2.onLocalCandidate([&pc1](Candidate candidate) { cout << "Candidate 2: " << candidate << endl; - pc1->addRemoteCandidate(std::move(candidate)); + pc1.addRemoteCandidate(std::move(candidate)); }); - pc2->onStateChange([](PeerConnection::State state) { cout << "State 2: " << state << endl; }); - pc2->onGatheringStateChange([](PeerConnection::GatheringState state) { + pc2.onStateChange([](PeerConnection::State state) { cout << "State 2: " << state << endl; }); + pc2.onGatheringStateChange([](PeerConnection::GatheringState state) { cout << "Gathering state 2: " << state << endl; }); @@ -101,7 +89,7 @@ size_t benchmark(milliseconds duration) { steady_clock::time_point startTime, openTime, receivedTime, endTime; shared_ptr dc2; - pc2->onDataChannel([&dc2, &receivedSize, &receivedTime](shared_ptr dc) { + pc2.onDataChannel([&dc2, &receivedSize, &receivedTime](shared_ptr dc) { dc->onMessage([&receivedTime, &receivedSize](variant message) { if (holds_alternative(message)) { const auto &bin = get(message); @@ -117,7 +105,7 @@ size_t benchmark(milliseconds duration) { }); startTime = steady_clock::now(); - auto dc1 = pc1->createDataChannel("benchmark"); + auto dc1 = pc1.createDataChannel("benchmark"); dc1->onOpen([wdc1 = make_weak_ptr(dc1), &messageData, &openTime]() { auto dc1 = wdc1.lock(); @@ -169,8 +157,8 @@ size_t benchmark(milliseconds duration) { cout << "Goodput: " << goodput * 0.001 << " MB/s" << " (" << goodput * 0.001 * 8 << " Mbit/s)" << endl; - pc1->close(); - pc2->close(); + pc1.close(); + pc2.close(); rtc::Cleanup(); this_thread::sleep_for(1s); diff --git a/test/connectivity.cpp b/test/connectivity.cpp index 882e856..67a98b2 100644 --- a/test/connectivity.cpp +++ b/test/connectivity.cpp @@ -39,7 +39,7 @@ void test_connectivity() { // Custom MTU example config1.mtu = 1500; - auto pc1 = std::make_shared(config1); + PeerConnection pc1(config1); Configuration config2; // STUN server example (not necessary to connect locally) @@ -51,62 +51,50 @@ void test_connectivity() { config2.portRangeBegin = 5000; config2.portRangeEnd = 6000; - auto pc2 = std::make_shared(config2); + PeerConnection pc2(config2); - pc1->onLocalDescription([wpc2 = make_weak_ptr(pc2)](Description sdp) { - auto pc2 = wpc2.lock(); - if (!pc2) - return; + pc1.onLocalDescription([&pc2](Description sdp) { cout << "Description 1: " << sdp << endl; - pc2->setRemoteDescription(string(sdp)); + pc2.setRemoteDescription(string(sdp)); }); - pc1->onLocalCandidate([wpc2 = make_weak_ptr(pc2)](Candidate candidate) { - auto pc2 = wpc2.lock(); - if (!pc2) - return; + pc1.onLocalCandidate([&pc2](Candidate candidate) { cout << "Candidate 1: " << candidate << endl; - pc2->addRemoteCandidate(string(candidate)); + pc2.addRemoteCandidate(string(candidate)); }); - pc1->onStateChange([](PeerConnection::State state) { cout << "State 1: " << state << endl; }); + pc1.onStateChange([](PeerConnection::State state) { cout << "State 1: " << state << endl; }); - pc1->onGatheringStateChange([](PeerConnection::GatheringState state) { + pc1.onGatheringStateChange([](PeerConnection::GatheringState state) { cout << "Gathering state 1: " << state << endl; }); - pc1->onSignalingStateChange([](PeerConnection::SignalingState state) { + pc1.onSignalingStateChange([](PeerConnection::SignalingState state) { cout << "Signaling state 1: " << state << endl; }); - pc2->onLocalDescription([wpc1 = make_weak_ptr(pc1)](Description sdp) { - auto pc1 = wpc1.lock(); - if (!pc1) - return; + pc2.onLocalDescription([&pc1](Description sdp) { cout << "Description 2: " << sdp << endl; - pc1->setRemoteDescription(string(sdp)); + pc1.setRemoteDescription(string(sdp)); }); - pc2->onLocalCandidate([wpc1 = make_weak_ptr(pc1)](Candidate candidate) { - auto pc1 = wpc1.lock(); - if (!pc1) - return; + pc2.onLocalCandidate([&pc1](Candidate candidate) { cout << "Candidate 2: " << candidate << endl; - pc1->addRemoteCandidate(string(candidate)); + pc1.addRemoteCandidate(string(candidate)); }); - pc2->onStateChange([](PeerConnection::State state) { cout << "State 2: " << state << endl; }); + pc2.onStateChange([](PeerConnection::State state) { cout << "State 2: " << state << endl; }); - pc2->onGatheringStateChange([](PeerConnection::GatheringState state) { + pc2.onGatheringStateChange([](PeerConnection::GatheringState state) { cout << "Gathering state 2: " << state << endl; }); - pc2->onSignalingStateChange([](PeerConnection::SignalingState state) { + pc2.onSignalingStateChange([](PeerConnection::SignalingState state) { cout << "Signaling state 2: " << state << endl; }); shared_ptr dc2; - pc2->onDataChannel([&dc2](shared_ptr dc) { + pc2.onDataChannel([&dc2](shared_ptr dc) { cout << "DataChannel 2: Received with label \"" << dc->label() << "\"" << endl; if (dc->label() != "test") { cerr << "Wrong DataChannel label" << endl; @@ -124,7 +112,7 @@ void test_connectivity() { std::atomic_store(&dc2, dc); }); - auto dc1 = pc1->createDataChannel("test"); + auto dc1 = pc1.createDataChannel("test"); dc1->onOpen([wdc1 = make_weak_ptr(dc1)]() { auto dc1 = wdc1.lock(); if (!dc1) @@ -145,35 +133,35 @@ void test_connectivity() { while ((!(adc2 = std::atomic_load(&dc2)) || !adc2->isOpen() || !dc1->isOpen()) && attempts--) this_thread::sleep_for(1s); - if (pc1->state() != PeerConnection::State::Connected && - pc2->state() != PeerConnection::State::Connected) + if (pc1.state() != PeerConnection::State::Connected && + pc2.state() != PeerConnection::State::Connected) throw runtime_error("PeerConnection is not connected"); if (!adc2 || !adc2->isOpen() || !dc1->isOpen()) throw runtime_error("DataChannel is not open"); - if (auto addr = pc1->localAddress()) + if (auto addr = pc1.localAddress()) cout << "Local address 1: " << *addr << endl; - if (auto addr = pc1->remoteAddress()) + if (auto addr = pc1.remoteAddress()) cout << "Remote address 1: " << *addr << endl; - if (auto addr = pc2->localAddress()) + if (auto addr = pc2.localAddress()) cout << "Local address 2: " << *addr << endl; - if (auto addr = pc2->remoteAddress()) + if (auto addr = pc2.remoteAddress()) cout << "Remote address 2: " << *addr << endl; Candidate local, remote; - if (pc1->getSelectedCandidatePair(&local, &remote)) { + if (pc1.getSelectedCandidatePair(&local, &remote)) { cout << "Local candidate 1: " << local << endl; cout << "Remote candidate 1: " << remote << endl; } - if (pc2->getSelectedCandidatePair(&local, &remote)) { + if (pc2.getSelectedCandidatePair(&local, &remote)) { cout << "Local candidate 2: " << local << endl; cout << "Remote candidate 2: " << remote << endl; } // Try to open a second data channel with another label shared_ptr second2; - pc2->onDataChannel([&second2](shared_ptr dc) { + pc2.onDataChannel([&second2](shared_ptr dc) { cout << "Second DataChannel 2: Received with label \"" << dc->label() << "\"" << endl; if (dc->label() != "second") { cerr << "Wrong second DataChannel label" << endl; @@ -191,7 +179,7 @@ void test_connectivity() { std::atomic_store(&second2, dc); }); - auto second1 = pc1->createDataChannel("second"); + auto second1 = pc1.createDataChannel("second"); second1->onOpen([wsecond1 = make_weak_ptr(dc1)]() { auto second1 = wsecond1.lock(); if (!second1) @@ -221,8 +209,8 @@ void test_connectivity() { DataChannelInit init; init.negotiated = true; init.id = 42; - auto negotiated1 = pc1->createDataChannel("negotiated", init); - auto negotiated2 = pc2->createDataChannel("negoctated", init); + auto negotiated1 = pc1.createDataChannel("negotiated", init); + auto negotiated2 = pc2.createDataChannel("negoctated", init); if (!negotiated1->isOpen() || !negotiated2->isOpen()) throw runtime_error("Negotiated DataChannel is not open"); @@ -246,9 +234,9 @@ void test_connectivity() { throw runtime_error("Negotiated DataChannel failed"); // Delay close of peer 2 to check closing works properly - pc1->close(); + pc1.close(); this_thread::sleep_for(1s); - pc2->close(); + pc2.close(); this_thread::sleep_for(1s); // You may call rtc::Cleanup() when finished to free static resources diff --git a/test/track.cpp b/test/track.cpp index 9f1c7a5..372368c 100644 --- a/test/track.cpp +++ b/test/track.cpp @@ -36,7 +36,7 @@ void test_track() { // STUN server example // config1.iceServers.emplace_back("stun:stun.l.google.com:19302"); - auto pc1 = std::make_shared(config1); + PeerConnection pc1(config1); Configuration config2; // STUN server example @@ -45,55 +45,43 @@ void test_track() { config2.portRangeBegin = 5000; config2.portRangeEnd = 6000; - auto pc2 = std::make_shared(config2); + PeerConnection pc2(config2); - pc1->onLocalDescription([wpc2 = make_weak_ptr(pc2)](Description sdp) { - auto pc2 = wpc2.lock(); - if (!pc2) - return; + pc1.onLocalDescription([&pc2](Description sdp) { cout << "Description 1: " << sdp << endl; - pc2->setRemoteDescription(string(sdp)); + pc2.setRemoteDescription(string(sdp)); }); - pc1->onLocalCandidate([wpc2 = make_weak_ptr(pc2)](Candidate candidate) { - auto pc2 = wpc2.lock(); - if (!pc2) - return; + pc1.onLocalCandidate([&pc2](Candidate candidate) { cout << "Candidate 1: " << candidate << endl; - pc2->addRemoteCandidate(string(candidate)); + pc2.addRemoteCandidate(string(candidate)); }); - pc1->onStateChange([](PeerConnection::State state) { cout << "State 1: " << state << endl; }); + pc1.onStateChange([](PeerConnection::State state) { cout << "State 1: " << state << endl; }); - pc1->onGatheringStateChange([](PeerConnection::GatheringState state) { + pc1.onGatheringStateChange([](PeerConnection::GatheringState state) { cout << "Gathering state 1: " << state << endl; }); - pc2->onLocalDescription([wpc1 = make_weak_ptr(pc1)](Description sdp) { - auto pc1 = wpc1.lock(); - if (!pc1) - return; + pc2.onLocalDescription([&pc1](Description sdp) { cout << "Description 2: " << sdp << endl; - pc1->setRemoteDescription(string(sdp)); + pc1.setRemoteDescription(string(sdp)); }); - pc2->onLocalCandidate([wpc1 = make_weak_ptr(pc1)](Candidate candidate) { - auto pc1 = wpc1.lock(); - if (!pc1) - return; + pc2.onLocalCandidate([&pc1](Candidate candidate) { cout << "Candidate 2: " << candidate << endl; - pc1->addRemoteCandidate(string(candidate)); + pc1.addRemoteCandidate(string(candidate)); }); - pc2->onStateChange([](PeerConnection::State state) { cout << "State 2: " << state << endl; }); + pc2.onStateChange([](PeerConnection::State state) { cout << "State 2: " << state << endl; }); - pc2->onGatheringStateChange([](PeerConnection::GatheringState state) { + pc2.onGatheringStateChange([](PeerConnection::GatheringState state) { cout << "Gathering state 2: " << state << endl; }); shared_ptr t2; string newTrackMid; - pc2->onTrack([&t2, &newTrackMid](shared_ptr t) { + pc2.onTrack([&t2, &newTrackMid](shared_ptr t) { cout << "Track 2: Received with mid \"" << t->mid() << "\"" << endl; if (t->mid() != newTrackMid) { cerr << "Wrong track mid" << endl; @@ -105,17 +93,17 @@ void test_track() { // Test opening a track newTrackMid = "test"; - auto t1 = pc1->addTrack(Description::Video(newTrackMid)); + auto t1 = pc1.addTrack(Description::Video(newTrackMid)); - pc1->setLocalDescription(); + pc1.setLocalDescription(); int attempts = 10; shared_ptr at2; while ((!(at2 = std::atomic_load(&t2)) || !at2->isOpen() || !t1->isOpen()) && attempts--) this_thread::sleep_for(1s); - if (pc1->state() != PeerConnection::State::Connected && - pc2->state() != PeerConnection::State::Connected) + if (pc1.state() != PeerConnection::State::Connected && + pc2.state() != PeerConnection::State::Connected) throw runtime_error("PeerConnection is not connected"); if (!at2 || !at2->isOpen() || !t1->isOpen()) @@ -123,9 +111,9 @@ void test_track() { // Test renegotiation newTrackMid = "added"; - t1 = pc1->addTrack(Description::Video(newTrackMid)); + t1 = pc1.addTrack(Description::Video(newTrackMid)); - pc1->setLocalDescription(); + pc1.setLocalDescription(); attempts = 10; t2.reset(); @@ -138,9 +126,9 @@ void test_track() { // TODO: Test sending RTP packets in track // Delay close of peer 2 to check closing works properly - pc1->close(); + pc1.close(); this_thread::sleep_for(1s); - pc2->close(); + pc2.close(); this_thread::sleep_for(1s); // You may call rtc::Cleanup() when finished to free static resources diff --git a/test/turn_connectivity.cpp b/test/turn_connectivity.cpp index e80de78..48b3969 100644 --- a/test/turn_connectivity.cpp +++ b/test/turn_connectivity.cpp @@ -40,7 +40,7 @@ void test_turn_connectivity() { // Please do not use outside of libdatachannel tests config1.iceServers.emplace_back("turn:datachannel_test:14018314739877@stun.ageneau.net:3478"); - auto pc1 = std::make_shared(config1); + PeerConnection pc1(config1); Configuration config2; // STUN server example (not necessary, just here for testing) @@ -50,70 +50,58 @@ void test_turn_connectivity() { // Please do not use outside of libdatachannel tests config2.iceServers.emplace_back("turn:datachannel_test:14018314739877@stun.ageneau.net:3478"); - auto pc2 = std::make_shared(config2); + PeerConnection pc2(config2); - pc1->onLocalDescription([wpc2 = make_weak_ptr(pc2)](Description sdp) { - auto pc2 = wpc2.lock(); - if (!pc2) - return; + pc1.onLocalDescription([&pc2](Description sdp) { cout << "Description 1: " << sdp << endl; - pc2->setRemoteDescription(string(sdp)); + pc2.setRemoteDescription(string(sdp)); }); - pc1->onLocalCandidate([wpc2 = make_weak_ptr(pc2)](Candidate candidate) { - auto pc2 = wpc2.lock(); - if (!pc2) - return; + pc1.onLocalCandidate([&pc2](Candidate candidate) { // For this test, filter out non-relay candidates to force TURN string str(candidate); if (str.find("relay") != string::npos) { cout << "Candidate 1: " << str << endl; - pc2->addRemoteCandidate(str); + pc2.addRemoteCandidate(str); } }); - pc1->onStateChange([](PeerConnection::State state) { cout << "State 1: " << state << endl; }); + pc1.onStateChange([](PeerConnection::State state) { cout << "State 1: " << state << endl; }); - pc1->onGatheringStateChange([](PeerConnection::GatheringState state) { + pc1.onGatheringStateChange([](PeerConnection::GatheringState state) { cout << "Gathering state 1: " << state << endl; }); - pc1->onSignalingStateChange([](PeerConnection::SignalingState state) { + pc1.onSignalingStateChange([](PeerConnection::SignalingState state) { cout << "Signaling state 1: " << state << endl; }); - pc2->onLocalDescription([wpc1 = make_weak_ptr(pc1)](Description sdp) { - auto pc1 = wpc1.lock(); - if (!pc1) - return; + pc2.onLocalDescription([&pc1](Description sdp) { cout << "Description 2: " << sdp << endl; - pc1->setRemoteDescription(string(sdp)); + pc1.setRemoteDescription(string(sdp)); }); - pc2->onLocalCandidate([wpc1 = make_weak_ptr(pc1)](Candidate candidate) { - auto pc1 = wpc1.lock(); - if (!pc1) - return; + pc2.onLocalCandidate([&pc1](Candidate candidate) { // For this test, filter out non-relay candidates to force TURN string str(candidate); if (str.find("relay") != string::npos) { cout << "Candidate 1: " << str << endl; - pc1->addRemoteCandidate(str); + pc1.addRemoteCandidate(str); } }); - pc2->onStateChange([](PeerConnection::State state) { cout << "State 2: " << state << endl; }); + pc2.onStateChange([](PeerConnection::State state) { cout << "State 2: " << state << endl; }); - pc2->onGatheringStateChange([](PeerConnection::GatheringState state) { + pc2.onGatheringStateChange([](PeerConnection::GatheringState state) { cout << "Gathering state 2: " << state << endl; }); - pc2->onSignalingStateChange([](PeerConnection::SignalingState state) { + pc2.onSignalingStateChange([](PeerConnection::SignalingState state) { cout << "Signaling state 2: " << state << endl; }); shared_ptr dc2; - pc2->onDataChannel([&dc2](shared_ptr dc) { + pc2.onDataChannel([&dc2](shared_ptr dc) { cout << "DataChannel 2: Received with label \"" << dc->label() << "\"" << endl; if (dc->label() != "test") { cerr << "Wrong DataChannel label" << endl; @@ -131,7 +119,7 @@ void test_turn_connectivity() { std::atomic_store(&dc2, dc); }); - auto dc1 = pc1->createDataChannel("test"); + auto dc1 = pc1.createDataChannel("test"); dc1->onOpen([wdc1 = make_weak_ptr(dc1)]() { auto dc1 = wdc1.lock(); if (!dc1) @@ -152,35 +140,35 @@ void test_turn_connectivity() { while ((!(adc2 = std::atomic_load(&dc2)) || !adc2->isOpen() || !dc1->isOpen()) && attempts--) this_thread::sleep_for(1s); - if (pc1->state() != PeerConnection::State::Connected && - pc2->state() != PeerConnection::State::Connected) + if (pc1.state() != PeerConnection::State::Connected && + pc2.state() != PeerConnection::State::Connected) throw runtime_error("PeerConnection is not connected"); if (!adc2 || !adc2->isOpen() || !dc1->isOpen()) throw runtime_error("DataChannel is not open"); - if (auto addr = pc1->localAddress()) + if (auto addr = pc1.localAddress()) cout << "Local address 1: " << *addr << endl; - if (auto addr = pc1->remoteAddress()) + if (auto addr = pc1.remoteAddress()) cout << "Remote address 1: " << *addr << endl; - if (auto addr = pc2->localAddress()) + if (auto addr = pc2.localAddress()) cout << "Local address 2: " << *addr << endl; - if (auto addr = pc2->remoteAddress()) + if (auto addr = pc2.remoteAddress()) cout << "Remote address 2: " << *addr << endl; Candidate local, remote; - if (pc1->getSelectedCandidatePair(&local, &remote)) { + if (pc1.getSelectedCandidatePair(&local, &remote)) { cout << "Local candidate 1: " << local << endl; cout << "Remote candidate 1: " << remote << endl; } - if (pc2->getSelectedCandidatePair(&local, &remote)) { + if (pc2.getSelectedCandidatePair(&local, &remote)) { cout << "Local candidate 2: " << local << endl; cout << "Remote candidate 2: " << remote << endl; } // Try to open a second data channel with another label shared_ptr second2; - pc2->onDataChannel([&second2](shared_ptr dc) { + pc2.onDataChannel([&second2](shared_ptr dc) { cout << "Second DataChannel 2: Received with label \"" << dc->label() << "\"" << endl; if (dc->label() != "second") { cerr << "Wrong second DataChannel label" << endl; @@ -198,7 +186,7 @@ void test_turn_connectivity() { std::atomic_store(&second2, dc); }); - auto second1 = pc1->createDataChannel("second"); + auto second1 = pc1.createDataChannel("second"); second1->onOpen([wsecond1 = make_weak_ptr(dc1)]() { auto second1 = wsecond1.lock(); if (!second1) @@ -228,8 +216,8 @@ void test_turn_connectivity() { DataChannelInit init; init.negotiated = true; init.id = 42; - auto negotiated1 = pc1->createDataChannel("negotiated", init); - auto negotiated2 = pc2->createDataChannel("negoctated", init); + auto negotiated1 = pc1.createDataChannel("negotiated", init); + auto negotiated2 = pc2.createDataChannel("negoctated", init); if (!negotiated1->isOpen() || !negotiated2->isOpen()) throw runtime_error("Negotiated DataChannel is not open"); @@ -253,9 +241,9 @@ void test_turn_connectivity() { throw runtime_error("Negotiated DataChannel failed"); // Delay close of peer 2 to check closing works properly - pc1->close(); + pc1.close(); this_thread::sleep_for(1s); - pc2->close(); + pc2.close(); this_thread::sleep_for(1s); // You may call rtc::Cleanup() when finished to free static resources diff --git a/test/websocket.cpp b/test/websocket.cpp index 60892f7..52af229 100644 --- a/test/websocket.cpp +++ b/test/websocket.cpp @@ -36,24 +36,20 @@ void test_websocket() { const string myMessage = "Hello world from libdatachannel"; - auto ws = std::make_shared(); + WebSocket ws; // Certificate verification can be disabled - // auto ws = std::make_shared(WebSocket::Configuration{.disableTlsVerification = - // true}); + // WebSocket ws(WebSocket::Configuration{.disableTlsVerification = true}); - ws->onOpen([wws = make_weak_ptr(ws), &myMessage]() { - auto ws = wws.lock(); - if (!ws) - return; + ws.onOpen([&ws, &myMessage]() { cout << "WebSocket: Open" << endl; - ws->send(myMessage); + ws.send(myMessage); }); - ws->onClosed([]() { cout << "WebSocket: Closed" << endl; }); + ws.onClosed([]() { cout << "WebSocket: Closed" << endl; }); std::atomic received = false; - ws->onMessage([&received, &myMessage](variant message) { + ws.onMessage([&received, &myMessage](variant message) { if (holds_alternative(message)) { string str = std::move(get(message)); if ((received = (str == myMessage))) @@ -63,19 +59,19 @@ void test_websocket() { } }); - ws->open("wss://echo.websocket.org:443/"); + ws.open("wss://echo.websocket.org:443/"); int attempts = 10; - while ((!ws->isOpen() || !received) && attempts--) + while ((!ws.isOpen() || !received) && attempts--) this_thread::sleep_for(1s); - if (!ws->isOpen()) + if (!ws.isOpen()) throw runtime_error("WebSocket is not open"); if (!received) throw runtime_error("Expected message not received"); - ws->close(); + ws.close(); this_thread::sleep_for(1s); // You may call rtc::Cleanup() when finished to free static resources