diff --git a/CMakeLists.txt b/CMakeLists.txt index de2720c..9fc2691 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,7 @@ project (libdatachannel option(USE_GNUTLS "Use GnuTLS instead of OpenSSL" OFF) option(USE_JUICE "Use libjuice instead of libnice" OFF) +option(RTC_ENABLE_WEBSOCKET "Build WebSocket support" ON) if(USE_GNUTLS) option(USE_NETTLE "Use Nettle instead of OpenSSL in libjuice" ON) @@ -39,6 +40,14 @@ set(LIBDATACHANNEL_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp ) +set(LIBDATACHANNEL_WEBSOCKET_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp +) + set(LIBDATACHANNEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/candidate.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/channel.hpp @@ -55,6 +64,7 @@ set(LIBDATACHANNEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/reliability.hpp ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.h ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/websocket.hpp ) set(TESTS_SOURCES @@ -89,26 +99,42 @@ endif() add_library(Usrsctp::Usrsctp ALIAS usrsctp) add_library(Usrsctp::UsrsctpStatic ALIAS usrsctp-static) -add_library(datachannel SHARED ${LIBDATACHANNEL_SOURCES}) +if (RTC_ENABLE_WEBSOCKET) + add_library(datachannel SHARED + ${LIBDATACHANNEL_SOURCES} + ${LIBDATACHANNEL_WEBSOCKET_SOURCES}) + add_library(datachannel-static STATIC EXCLUDE_FROM_ALL + ${LIBDATACHANNEL_SOURCES} + ${LIBDATACHANNEL_WEBSOCKET_SOURCES}) + target_compile_definitions(datachannel PUBLIC RTC_ENABLE_WEBSOCKET=1) + target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_WEBSOCKET=1) +else() + add_library(datachannel SHARED + ${LIBDATACHANNEL_SOURCES}) + add_library(datachannel-static STATIC EXCLUDE_FROM_ALL + ${LIBDATACHANNEL_SOURCES}) + target_compile_definitions(datachannel PUBLIC RTC_ENABLE_WEBSOCKET=0) + target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_WEBSOCKET=0) +endif() + set_target_properties(datachannel PROPERTIES VERSION ${PROJECT_VERSION} CXX_STANDARD 17) +set_target_properties(datachannel-static PROPERTIES + VERSION ${PROJECT_VERSION} + CXX_STANDARD 17) target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc) target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include) -target_link_libraries(datachannel Threads::Threads Usrsctp::UsrsctpStatic) - -add_library(datachannel-static STATIC EXCLUDE_FROM_ALL ${LIBDATACHANNEL_SOURCES}) -set_target_properties(datachannel-static PROPERTIES - VERSION ${PROJECT_VERSION} - CXX_STANDARD 17) target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc) target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include) + +target_link_libraries(datachannel Threads::Threads Usrsctp::UsrsctpStatic) target_link_libraries(datachannel-static Threads::Threads Usrsctp::UsrsctpStatic) if(WIN32) diff --git a/Jamfile b/Jamfile index db3a563..6b905b1 100644 --- a/Jamfile +++ b/Jamfile @@ -10,6 +10,7 @@ lib libdatachannel 17 ./include/rtc USE_JUICE=1 + RTC_ENABLE_WEBSOCKET=0 /libdatachannel//usrsctp /libdatachannel//juice /libdatachannel//plog diff --git a/Makefile b/Makefile index aabd6cb..c0a83e9 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,14 @@ else LIBS+=glib-2.0 gobject-2.0 nice endif +RTC_ENABLE_WEBSOCKET ?= 1 +ifneq ($(RTC_ENABLE_WEBSOCKET), 0) + CPPFLAGS+=-DRTC_ENABLE_WEBSOCKET=1 +else + CPPFLAGS+=-DRTC_ENABLE_WEBSOCKET=0 +endif + + INCLUDES+=$(shell pkg-config --cflags $(LIBS)) LDLIBS+=$(LOCALLIBS) $(shell pkg-config --libs $(LIBS)) diff --git a/include/rtc/datachannel.hpp b/include/rtc/datachannel.hpp index c209048..c2efe0d 100644 --- a/include/rtc/datachannel.hpp +++ b/include/rtc/datachannel.hpp @@ -82,7 +82,6 @@ private: std::atomic mIsClosed = false; Queue mRecvQueue; - std::atomic mRecvAmount = 0; friend class PeerConnection; }; diff --git a/include/rtc/include.hpp b/include/rtc/include.hpp index 834e15f..d588ae7 100644 --- a/include/rtc/include.hpp +++ b/include/rtc/include.hpp @@ -19,6 +19,10 @@ #ifndef RTC_INCLUDE_H #define RTC_INCLUDE_H +#ifndef RTC_ENABLE_WEBSOCKET +#define RTC_ENABLE_WEBSOCKET 1 +#endif + #ifdef _WIN32 #ifndef _WIN32_WINNT #define _WIN32_WINNT 0x0602 @@ -56,10 +60,21 @@ const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not specified in SDP const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size - +// overloaded helper template struct overloaded : Ts... { using Ts::operator()...; }; template overloaded(Ts...)->overloaded; +// weak_ptr bind helper +template auto weak_bind(F &&f, T *t, Args &&... _args) { + return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&... args) { + using result_type = typename decltype(bound)::result_type; + if (auto shared_this = weak_this.lock()) + return bound(args...); + else + return (result_type) false; + }; +} + template class synchronized_callback { public: synchronized_callback() = default; diff --git a/include/rtc/message.hpp b/include/rtc/message.hpp index 465b68a..e1396d1 100644 --- a/include/rtc/message.hpp +++ b/include/rtc/message.hpp @@ -30,6 +30,7 @@ namespace rtc { struct Message : binary { enum Type { Binary, String, Control, Reset }; + Message(const Message &message) = default; Message(size_t size, Type type_ = Binary) : binary(size), type(type_) {} template diff --git a/include/rtc/peerconnection.hpp b/include/rtc/peerconnection.hpp index 9e5c98c..984076a 100644 --- a/include/rtc/peerconnection.hpp +++ b/include/rtc/peerconnection.hpp @@ -98,8 +98,6 @@ public: std::string connectionInfo; private: - init_token mInitToken = Init::Token(); - std::shared_ptr initIceTransport(Description::Role role); std::shared_ptr initDtlsTransport(); std::shared_ptr initSctpTransport(); @@ -130,6 +128,8 @@ private: const Configuration mConfig; const std::shared_ptr mCertificate; + init_token mInitToken = Init::Token(); + std::optional mLocalDescription, mRemoteDescription; mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex; diff --git a/include/rtc/queue.hpp b/include/rtc/queue.hpp index fa4468e..9f3aae4 100644 --- a/include/rtc/queue.hpp +++ b/include/rtc/queue.hpp @@ -44,6 +44,7 @@ public: void push(T element); std::optional pop(); std::optional peek(); + std::optional exchange(T element); bool wait(const std::optional &duration = nullopt); private: @@ -118,6 +119,16 @@ template std::optional Queue::peek() { } } +template std::optional Queue::exchange(T element) { + std::unique_lock lock(mMutex); + if (!mQueue.empty()) { + std::swap(mQueue.front(), element); + return std::optional{element}; + } else { + return nullopt; + } +} + template bool Queue::wait(const std::optional &duration) { std::unique_lock lock(mMutex); diff --git a/include/rtc/rtc.h b/include/rtc/rtc.h index 66c009d..91dbcd3 100644 --- a/include/rtc/rtc.h +++ b/include/rtc/rtc.h @@ -27,6 +27,10 @@ extern "C" { // libdatachannel C API +#ifndef RTC_ENABLE_WEBSOCKET +#define RTC_ENABLE_WEBSOCKET 1 +#endif + typedef enum { RTC_NEW = 0, RTC_CONNECTING = 1, @@ -42,8 +46,7 @@ typedef enum { RTC_GATHERING_COMPLETE = 2 } rtcGatheringState; -// Don't change, it must match plog severity -typedef enum { +typedef enum { // Don't change, it must match plog severity RTC_LOG_NONE = 0, RTC_LOG_FATAL = 1, RTC_LOG_ERROR = 2, @@ -76,10 +79,10 @@ typedef void (*availableCallbackFunc)(void *ptr); void rtcInitLogger(rtcLogLevel level); // User pointer -void rtcSetUserPointer(int i, void *ptr); +void rtcSetUserPointer(int id, void *ptr); // PeerConnection -int rtcCreatePeerConnection(const rtcConfiguration *config); +int rtcCreatePeerConnection(const rtcConfiguration *config); // returns pc id int rtcDeletePeerConnection(int pc); int rtcSetDataChannelCallback(int pc, dataChannelCallbackFunc cb); @@ -95,24 +98,32 @@ int rtcGetLocalAddress(int pc, char *buffer, int size); int rtcGetRemoteAddress(int pc, char *buffer, int size); // DataChannel -int rtcCreateDataChannel(int pc, const char *label); +int rtcCreateDataChannel(int pc, const char *label); // returns dc id int rtcDeleteDataChannel(int dc); int rtcGetDataChannelLabel(int dc, char *buffer, int size); -int rtcSetOpenCallback(int dc, openCallbackFunc cb); -int rtcSetClosedCallback(int dc, closedCallbackFunc cb); -int rtcSetErrorCallback(int dc, errorCallbackFunc cb); -int rtcSetMessageCallback(int dc, messageCallbackFunc cb); -int rtcSendMessage(int dc, const char *data, int size); -int rtcGetBufferedAmount(int dc); // total size buffered to send -int rtcSetBufferedAmountLowThreshold(int dc, int amount); -int rtcSetBufferedAmountLowCallback(int dc, bufferedAmountLowCallbackFunc cb); +// WebSocket +#if RTC_ENABLE_WEBSOCKET +int rtcCreateWebSocket(const char *url); // returns ws id +int rtcDeleteWebsocket(int ws); +#endif -// DataChannel extended API -int rtcGetAvailableAmount(int dc); // total size available to receive -int rtcSetAvailableCallback(int dc, availableCallbackFunc cb); -int rtcReceiveMessage(int dc, char *buffer, int *size); +// DataChannel and WebSocket common API +int rtcSetOpenCallback(int id, openCallbackFunc cb); +int rtcSetClosedCallback(int id, closedCallbackFunc cb); +int rtcSetErrorCallback(int id, errorCallbackFunc cb); +int rtcSetMessageCallback(int id, messageCallbackFunc cb); +int rtcSendMessage(int id, const char *data, int size); + +int rtcGetBufferedAmount(int id); // total size buffered to send +int rtcSetBufferedAmountLowThreshold(int id, int amount); +int rtcSetBufferedAmountLowCallback(int id, bufferedAmountLowCallbackFunc cb); + +// DataChannel and WebSocket common extended API +int rtcGetAvailableAmount(int id); // total size available to receive +int rtcSetAvailableCallback(int id, availableCallbackFunc cb); +int rtcReceiveMessage(int id, char *buffer, int *size); // Cleanup void rtcCleanup(); diff --git a/include/rtc/rtc.hpp b/include/rtc/rtc.hpp index 9a72714..b78502d 100644 --- a/include/rtc/rtc.hpp +++ b/include/rtc/rtc.hpp @@ -23,6 +23,7 @@ // #include "datachannel.hpp" #include "peerconnection.hpp" +#include "websocket.hpp" // C API #include "rtc.h" diff --git a/include/rtc/websocket.hpp b/include/rtc/websocket.hpp new file mode 100644 index 0000000..acb7d6e --- /dev/null +++ b/include/rtc/websocket.hpp @@ -0,0 +1,95 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_WEBSOCKET_H +#define RTC_WEBSOCKET_H + +#if RTC_ENABLE_WEBSOCKET + +#include "channel.hpp" +#include "include.hpp" +#include "init.hpp" +#include "message.hpp" +#include "queue.hpp" + +#include +#include +#include +#include + +namespace rtc { + +class TcpTransport; +class TlsTransport; +class WsTransport; + +class WebSocket final : public Channel, public std::enable_shared_from_this { +public: + enum class State : int { + Connecting = 0, + Open = 1, + Closing = 2, + Closed = 3, + }; + + WebSocket(); + WebSocket(const string &url); + ~WebSocket(); + + State readyState() const; + + void open(const string &url); + void close() override; + bool send(const std::variant &data) override; + + bool isOpen() const override; + bool isClosed() const override; + size_t maxMessageSize() const override; + + // Extended API + std::optional> receive() override; + size_t availableAmount() const override; // total size available to receive + +private: + bool changeState(State state); + void remoteClose(); + bool outgoing(mutable_message_ptr message); + void incoming(message_ptr message); + + std::shared_ptr initTcpTransport(); + std::shared_ptr initTlsTransport(); + std::shared_ptr initWsTransport(); + void closeTransports(); + + init_token mInitToken = Init::Token(); + + std::shared_ptr mTcpTransport; + std::shared_ptr mTlsTransport; + std::shared_ptr mWsTransport; + std::recursive_mutex mInitMutex; + + string mScheme, mHost, mHostname, mService, mPath; + std::atomic mState = State::Closed; + + Queue mRecvQueue; +}; +} // namespace rtc + +#endif + +#endif // RTC_WEBSOCKET_H diff --git a/src/base64.cpp b/src/base64.cpp new file mode 100644 index 0000000..6ed0075 --- /dev/null +++ b/src/base64.cpp @@ -0,0 +1,65 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#if RTC_ENABLE_WEBSOCKET + +#include "base64.hpp" + +namespace rtc { + +using std::to_integer; + +string to_base64(const binary &data) { + static const char tab[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + string out; + out.reserve(3 * ((data.size() + 3) / 4)); + int i = 0; + while (data.size() - i >= 3) { + auto d0 = to_integer(data[i]); + auto d1 = to_integer(data[i + 1]); + auto d2 = to_integer(data[i + 2]); + out += tab[d0 >> 2]; + out += tab[((d0 & 3) << 4) | (d1 >> 4)]; + out += tab[((d1 & 0x0F) << 2) | (d2 >> 6)]; + out += tab[d2 & 0x3F]; + i += 3; + } + + int left = data.size() - i; + if (left) { + auto d0 = to_integer(data[i]); + out += tab[d0 >> 2]; + if (left == 1) { + out += tab[(d0 & 3) << 4]; + out += '='; + } else { // left == 2 + auto d1 = to_integer(data[i + 1]); + out += tab[((d0 & 3) << 4) | (d1 >> 4)]; + out += tab[(d1 & 0x0F) << 2]; + } + out += '='; + } + + return out; +} + +} // namespace rtc + +#endif diff --git a/src/base64.hpp b/src/base64.hpp new file mode 100644 index 0000000..4c0b8e0 --- /dev/null +++ b/src/base64.hpp @@ -0,0 +1,34 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_BASE64_H +#define RTC_BASE64_H + +#if RTC_ENABLE_WEBSOCKET + +#include "include.hpp" + +namespace rtc { + +string to_base64(const binary &data); + +} + +#endif + +#endif diff --git a/src/datachannel.cpp b/src/datachannel.cpp index c72cb0c..c0eea20 100644 --- a/src/datachannel.cpp +++ b/src/datachannel.cpp @@ -214,6 +214,9 @@ bool DataChannel::outgoing(mutable_message_ptr message) { } void DataChannel::incoming(message_ptr message) { + if (!message) + return; + switch (message->type) { case Message::Control: { auto raw = reinterpret_cast(message->data()); diff --git a/src/dtlstransport.cpp b/src/dtlstransport.cpp index 7723130..b902151 100644 --- a/src/dtlstransport.cpp +++ b/src/dtlstransport.cpp @@ -18,9 +18,7 @@ #include "dtlstransport.hpp" #include "icetransport.hpp" -#include "message.hpp" -#include #include #include #include @@ -64,11 +62,9 @@ void DtlsTransport::Cleanup() { } DtlsTransport::DtlsTransport(shared_ptr lower, shared_ptr certificate, - verifier_callback verifierCallback, - state_callback stateChangeCallback) - : Transport(lower), mCertificate(certificate), mState(State::Disconnected), - mVerifierCallback(std::move(verifierCallback)), - mStateChangeCallback(std::move(stateChangeCallback)) { + verifier_callback verifierCallback, state_callback stateChangeCallback) + : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate), + mVerifierCallback(std::move(verifierCallback)) { PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)"; @@ -76,31 +72,37 @@ DtlsTransport::DtlsTransport(shared_ptr lower, shared_ptrcredentials(), CertificateCallback); - check_gnutls( - gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCertificate->credentials())); + gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback); + check_gnutls( + gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCertificate->credentials())); - gnutls_dtls_set_timeouts(mSession, - 1000, // 1s retransmission timeout recommended by RFC 6347 - 30000); // 30s total timeout - gnutls_handshake_set_timeout(mSession, 30000); + gnutls_dtls_set_timeouts(mSession, + 1000, // 1s retransmission timeout recommended by RFC 6347 + 30000); // 30s total timeout + gnutls_handshake_set_timeout(mSession, 30000); - gnutls_session_set_ptr(mSession, this); - gnutls_transport_set_ptr(mSession, this); - gnutls_transport_set_push_function(mSession, WriteCallback); - gnutls_transport_set_pull_function(mSession, ReadCallback); - gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback); + gnutls_session_set_ptr(mSession, this); + gnutls_transport_set_ptr(mSession, this); + gnutls_transport_set_push_function(mSession, WriteCallback); + gnutls_transport_set_pull_function(mSession, ReadCallback); + gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback); - mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this); - registerIncoming(); + mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this); + registerIncoming(); + + } catch (...) { + gnutls_deinit(mSession); + throw; + } } DtlsTransport::~DtlsTransport() { @@ -109,8 +111,6 @@ DtlsTransport::~DtlsTransport() { gnutls_deinit(mSession); } -DtlsTransport::State DtlsTransport::state() const { return mState; } - bool DtlsTransport::stop() { if (!Transport::stop()) return false; @@ -122,7 +122,7 @@ bool DtlsTransport::stop() { } bool DtlsTransport::send(message_ptr message) { - if (!message || mState != State::Connected) + if (!message || state() != State::Connected) return false; PLOG_VERBOSE << "Send size=" << message->size(); @@ -148,11 +148,6 @@ void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); } -void DtlsTransport::changeState(State state) { - if (mState.exchange(state) != state) - mStateChangeCallback(state); -} - void DtlsTransport::runRecvLoop() { const size_t maxMtu = 4096; @@ -169,7 +164,7 @@ void DtlsTransport::runRecvLoop() { throw std::runtime_error("MTU is too low"); } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN || - !check_gnutls(ret, "TLS handshake failed")); + !check_gnutls(ret, "DTLS handshake failed")); // RFC 8261: DTLS MUST support sending messages larger than the current path MTU // See https://tools.ietf.org/html/rfc8261#section-5 @@ -183,7 +178,7 @@ void DtlsTransport::runRecvLoop() { // Receive loop try { - PLOG_INFO << "DTLS handshake done"; + PLOG_INFO << "DTLS handshake finished"; changeState(State::Connected); const size_t bufferSize = maxMtu; @@ -218,7 +213,7 @@ void DtlsTransport::runRecvLoop() { gnutls_bye(mSession, GNUTLS_SHUT_RDWR); - PLOG_INFO << "DTLS disconnected"; + PLOG_INFO << "DTLS closed"; changeState(State::Disconnected); recv(nullptr); } @@ -341,7 +336,7 @@ void DtlsTransport::Init() { if (!BioMethods) { BioMethods = BIO_meth_new(BIO_TYPE_BIO, "DTLS writer"); if (!BioMethods) - throw std::runtime_error("Unable to BIO methods for DTLS writer"); + throw std::runtime_error("Failed to create BIO methods for DTLS writer"); BIO_meth_set_create(BioMethods, BioMethodNew); BIO_meth_set_destroy(BioMethods, BioMethodFree); BIO_meth_set_write(BioMethods, BioMethodWrite); @@ -358,60 +353,68 @@ void DtlsTransport::Cleanup() { DtlsTransport::DtlsTransport(shared_ptr lower, shared_ptr certificate, verifier_callback verifierCallback, state_callback stateChangeCallback) - : Transport(lower), mCertificate(certificate), mState(State::Disconnected), - mVerifierCallback(std::move(verifierCallback)), - mStateChangeCallback(std::move(stateChangeCallback)) { + : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate), + mVerifierCallback(std::move(verifierCallback)) { PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)"; - if (!(mCtx = SSL_CTX_new(DTLS_method()))) - throw std::runtime_error("Unable to create SSL context"); + try { + if (!(mCtx = SSL_CTX_new(DTLS_method()))) + throw std::runtime_error("Failed to create SSL context"); - check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"), - "Unable to set SSL priorities"); + check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"), + "Failed to set SSL priorities"); - // RFC 8261: SCTP performs segmentation and reassembly based on the path MTU. - // Therefore, the DTLS layer MUST NOT use any compression algorithm. - // See https://tools.ietf.org/html/rfc8261#section-5 - SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION | SSL_OP_NO_QUERY_MTU); - SSL_CTX_set_min_proto_version(mCtx, DTLS1_VERSION); - SSL_CTX_set_read_ahead(mCtx, 1); - SSL_CTX_set_quiet_shutdown(mCtx, 1); - SSL_CTX_set_info_callback(mCtx, InfoCallback); - SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, - CertificateCallback); - SSL_CTX_set_verify_depth(mCtx, 1); + // RFC 8261: SCTP performs segmentation and reassembly based on the path MTU. + // Therefore, the DTLS layer MUST NOT use any compression algorithm. + // See https://tools.ietf.org/html/rfc8261#section-5 + SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION | SSL_OP_NO_QUERY_MTU); + SSL_CTX_set_min_proto_version(mCtx, DTLS1_VERSION); + SSL_CTX_set_read_ahead(mCtx, 1); + SSL_CTX_set_quiet_shutdown(mCtx, 1); + SSL_CTX_set_info_callback(mCtx, InfoCallback); + SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + CertificateCallback); + SSL_CTX_set_verify_depth(mCtx, 1); - auto [x509, pkey] = mCertificate->credentials(); - SSL_CTX_use_certificate(mCtx, x509); - SSL_CTX_use_PrivateKey(mCtx, pkey); + auto [x509, pkey] = mCertificate->credentials(); + SSL_CTX_use_certificate(mCtx, x509); + SSL_CTX_use_PrivateKey(mCtx, pkey); - check_openssl(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed"); + check_openssl(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed"); - if (!(mSsl = SSL_new(mCtx))) - throw std::runtime_error("Unable to create SSL instance"); + if (!(mSsl = SSL_new(mCtx))) + throw std::runtime_error("Failed to create SSL instance"); - SSL_set_ex_data(mSsl, TransportExIndex, this); + SSL_set_ex_data(mSsl, TransportExIndex, this); - if (lower->role() == Description::Role::Active) - SSL_set_connect_state(mSsl); - else - SSL_set_accept_state(mSsl); + if (lower->role() == Description::Role::Active) + SSL_set_connect_state(mSsl); + else + SSL_set_accept_state(mSsl); - if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BioMethods))) - throw std::runtime_error("Unable to create BIO"); + if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BioMethods))) + throw std::runtime_error("Failed to create BIO"); - BIO_set_mem_eof_return(mInBio, BIO_EOF); - BIO_set_data(mOutBio, this); - SSL_set_bio(mSsl, mInBio, mOutBio); + BIO_set_mem_eof_return(mInBio, BIO_EOF); + BIO_set_data(mOutBio, this); + SSL_set_bio(mSsl, mInBio, mOutBio); - auto ecdh = unique_ptr( - EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free); - SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE); - SSL_set_tmp_ecdh(mSsl, ecdh.get()); + auto ecdh = unique_ptr( + EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free); + SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE); + SSL_set_tmp_ecdh(mSsl, ecdh.get()); - mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this); - registerIncoming(); + mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this); + registerIncoming(); + + } catch (...) { + if (mSsl) + SSL_free(mSsl); + if (mCtx) + SSL_CTX_free(mCtx); + throw; + } } DtlsTransport::~DtlsTransport() { @@ -432,18 +435,14 @@ bool DtlsTransport::stop() { return true; } -DtlsTransport::State DtlsTransport::state() const { return mState; } - bool DtlsTransport::send(message_ptr message) { - if (!message || mState != State::Connected) + if (!message || state() != State::Connected) return false; PLOG_VERBOSE << "Send size=" << message->size(); int ret = SSL_write(mSsl, message->data(), message->size()); - if (!check_openssl_ret(mSsl, ret)) - return false; - return true; + return check_openssl_ret(mSsl, ret); } void DtlsTransport::incoming(message_ptr message) { @@ -456,11 +455,6 @@ void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); } -void DtlsTransport::changeState(State state) { - if (mState.exchange(state) != state) - mStateChangeCallback(state); -} - void DtlsTransport::runRecvLoop() { const size_t maxMtu = 4096; try { @@ -479,7 +473,7 @@ void DtlsTransport::runRecvLoop() { auto message = *mIncomingQueue.pop(); BIO_write(mInBio, message->data(), message->size()); - if (mState == State::Connecting) { + if (state() == State::Connecting) { // Continue the handshake int ret = SSL_do_handshake(mSsl); if (!check_openssl_ret(mSsl, ret, "Handshake failed")) @@ -490,7 +484,7 @@ void DtlsTransport::runRecvLoop() { // MTU See https://tools.ietf.org/html/rfc8261#section-5 SSL_set_mtu(mSsl, maxMtu + 1); - PLOG_INFO << "DTLS handshake done"; + PLOG_INFO << "DTLS handshake finished"; changeState(State::Connected); } } else { @@ -504,7 +498,7 @@ void DtlsTransport::runRecvLoop() { // No more messages pending, retransmit and rearm timeout if connecting std::optional duration; - if (mState == State::Connecting) { + if (state() == State::Connecting) { // Warning: This function breaks the usual return value convention int ret = DTLSv1_handle_timeout(mSsl); if (ret < 0) { @@ -514,7 +508,7 @@ void DtlsTransport::runRecvLoop() { } struct timeval timeout = {}; - if (mState == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) { + if (state() == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) { duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000); // Also handle handshake timeout manually because OpenSSL actually doesn't... // OpenSSL backs off exponentially in base 2 starting from the recommended 1s @@ -535,8 +529,8 @@ void DtlsTransport::runRecvLoop() { PLOG_ERROR << "DTLS recv: " << e.what(); } - if (mState == State::Connected) { - PLOG_INFO << "DTLS disconnected"; + if (state() == State::Connected) { + PLOG_INFO << "DTLS closed"; changeState(State::Disconnected); recv(nullptr); } else { diff --git a/src/dtlstransport.hpp b/src/dtlstransport.hpp index d0e6030..9389069 100644 --- a/src/dtlstransport.hpp +++ b/src/dtlstransport.hpp @@ -46,33 +46,25 @@ public: static void Init(); static void Cleanup(); - enum class State { Disconnected, Connecting, Connected, Failed }; - using verifier_callback = std::function; - using state_callback = std::function; DtlsTransport(std::shared_ptr lower, std::shared_ptr certificate, verifier_callback verifierCallback, state_callback stateChangeCallback); ~DtlsTransport(); - State state() const; - bool stop() override; bool send(message_ptr message) override; // false if dropped private: void incoming(message_ptr message) override; - void changeState(State state); void runRecvLoop(); const std::shared_ptr mCertificate; Queue mIncomingQueue; - std::atomic mState; std::thread mRecvThread; verifier_callback mVerifierCallback; - state_callback mStateChangeCallback; #if USE_GNUTLS gnutls_session_t mSession; @@ -82,8 +74,8 @@ private: static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen); static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms); #else - SSL_CTX *mCtx; - SSL *mSsl; + SSL_CTX *mCtx = NULL; + SSL *mSsl = NULL; BIO *mInBio, *mOutBio; static BIO_METHOD *BioMethods; diff --git a/src/icetransport.cpp b/src/icetransport.cpp index 22ef719..d402fd5 100644 --- a/src/icetransport.cpp +++ b/src/icetransport.cpp @@ -48,9 +48,8 @@ namespace rtc { IceTransport::IceTransport(const Configuration &config, Description::Role role, candidate_callback candidateCallback, state_callback stateChangeCallback, gathering_state_callback gatheringStateChangeCallback) - : mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New), - mCandidateCallback(std::move(candidateCallback)), - mStateChangeCallback(std::move(stateChangeCallback)), + : Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"), + mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)), mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)), mAgent(nullptr, nullptr) { @@ -84,6 +83,7 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role, mStunService = server.service; jconfig.stun_server_host = mStunHostname.c_str(); jconfig.stun_server_port = std::stoul(mStunService); + break; } } @@ -108,8 +108,6 @@ bool IceTransport::stop() { Description::Role IceTransport::role() const { return mRole; } -IceTransport::State IceTransport::state() const { return mState; } - Description IceTransport::getLocalDescription(Description::Type type) const { char sdp[JUICE_MAX_SDP_STRING_LEN]; if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0) @@ -161,7 +159,8 @@ std::optional IceTransport::getRemoteAddress() const { } bool IceTransport::send(message_ptr message) { - if (!message || (mState != State::Connected && mState != State::Completed)) + auto s = state(); + if (!message || (s != State::Connected && s != State::Completed)) return false; PLOG_VERBOSE << "Send size=" << message->size(); @@ -173,18 +172,29 @@ bool IceTransport::outgoing(message_ptr message) { message->size()) >= 0; } -void IceTransport::changeState(State state) { - if (mState.exchange(state) != state) - mStateChangeCallback(mState); -} - void IceTransport::changeGatheringState(GatheringState state) { if (mGatheringState.exchange(state) != state) mGatheringStateChangeCallback(mGatheringState); } void IceTransport::processStateChange(unsigned int state) { - changeState(static_cast(state)); + switch (state) { + case JUICE_STATE_DISCONNECTED: + changeState(State::Disconnected); + break; + case JUICE_STATE_CONNECTING: + changeState(State::Connecting); + break; + case JUICE_STATE_CONNECTED: + changeState(State::Connected); + break; + case JUICE_STATE_COMPLETED: + changeState(State::Completed); + break; + case JUICE_STATE_FAILED: + changeState(State::Failed); + break; + }; } void IceTransport::processCandidate(const string &candidate) { @@ -263,9 +273,8 @@ namespace rtc { IceTransport::IceTransport(const Configuration &config, Description::Role role, candidate_callback candidateCallback, state_callback stateChangeCallback, gathering_state_callback gatheringStateChangeCallback) - : mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New), - mCandidateCallback(std::move(candidateCallback)), - mStateChangeCallback(std::move(stateChangeCallback)), + : Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"), + mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)), mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)), mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr) { @@ -457,8 +466,6 @@ bool IceTransport::stop() { Description::Role IceTransport::role() const { return mRole; } -IceTransport::State IceTransport::state() const { return mState; } - Description IceTransport::getLocalDescription(Description::Type type) const { // RFC 8445: The initiating agent that started the ICE processing MUST take the controlling // role, and the other MUST take the controlled role. @@ -529,7 +536,8 @@ std::optional IceTransport::getRemoteAddress() const { } bool IceTransport::send(message_ptr message) { - if (!message || (mState != State::Connected && mState != State::Completed)) + auto s = state(); + if (!message || (s != State::Connected && s != State::Completed)) return false; PLOG_VERBOSE << "Send size=" << message->size(); @@ -541,11 +549,6 @@ bool IceTransport::outgoing(message_ptr message) { reinterpret_cast(message->data())) >= 0; } -void IceTransport::changeState(State state) { - if (mState.exchange(state) != state) - mStateChangeCallback(mState); -} - void IceTransport::changeGatheringState(GatheringState state) { if (mGatheringState.exchange(state) != state) mGatheringStateChangeCallback(mGatheringState); @@ -576,7 +579,23 @@ void IceTransport::processStateChange(unsigned int state) { mTimeoutId = 0; } - changeState(static_cast(state)); + switch (state) { + case NICE_COMPONENT_STATE_DISCONNECTED: + changeState(State::Disconnected); + break; + case NICE_COMPONENT_STATE_CONNECTING: + changeState(State::Connecting); + break; + case NICE_COMPONENT_STATE_CONNECTED: + changeState(State::Connected); + break; + case NICE_COMPONENT_STATE_READY: + changeState(State::Completed); + break; + case NICE_COMPONENT_STATE_FAILED: + changeState(State::Failed); + break; + }; } string IceTransport::AddressToString(const NiceAddress &addr) { diff --git a/src/icetransport.hpp b/src/icetransport.hpp index 46375be..ce0d3a4 100644 --- a/src/icetransport.hpp +++ b/src/icetransport.hpp @@ -40,29 +40,9 @@ namespace rtc { class IceTransport : public Transport { public: -#if USE_JUICE - enum class State : unsigned int{ - Disconnected = JUICE_STATE_DISCONNECTED, - Connecting = JUICE_STATE_CONNECTING, - Connected = JUICE_STATE_CONNECTED, - Completed = JUICE_STATE_COMPLETED, - Failed = JUICE_STATE_FAILED, - }; -#else - enum class State : unsigned int { - Disconnected = NICE_COMPONENT_STATE_DISCONNECTED, - Connecting = NICE_COMPONENT_STATE_CONNECTING, - Connected = NICE_COMPONENT_STATE_CONNECTED, - Completed = NICE_COMPONENT_STATE_READY, - Failed = NICE_COMPONENT_STATE_FAILED, - }; - - bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote); -#endif enum class GatheringState { New = 0, InProgress = 1, Complete = 2 }; using candidate_callback = std::function; - using state_callback = std::function; using gathering_state_callback = std::function; IceTransport(const Configuration &config, Description::Role role, @@ -71,7 +51,6 @@ public: ~IceTransport(); Description::Role role() const; - State state() const; GatheringState gatheringState() const; Description getLocalDescription(Description::Type type) const; void setRemoteDescription(const Description &description); @@ -84,10 +63,13 @@ public: bool stop() override; bool send(message_ptr message) override; // false if dropped +#if !USE_JUICE + bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote); +#endif + private: bool outgoing(message_ptr message) override; - void changeState(State state); void changeGatheringState(GatheringState state); void processStateChange(unsigned int state); @@ -98,11 +80,9 @@ private: Description::Role mRole; string mMid; std::chrono::milliseconds mTrickleTimeout; - std::atomic mState; std::atomic mGatheringState; candidate_callback mCandidateCallback; - state_callback mStateChangeCallback; gathering_state_callback mGatheringStateChangeCallback; #if USE_JUICE diff --git a/src/init.cpp b/src/init.cpp index 38afc1a..425bab2 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -21,6 +21,10 @@ #include "dtlstransport.hpp" #include "sctptransport.hpp" +#if RTC_ENABLE_WEBSOCKET +#include "tlstransport.hpp" +#endif + #ifdef _WIN32 #include #endif @@ -69,13 +73,19 @@ Init::Init() { ERR_load_crypto_strings(); #endif - DtlsTransport::Init(); SctpTransport::Init(); + DtlsTransport::Init(); +#if RTC_ENABLE_WEBSOCKET + TlsTransport::Init(); +#endif } Init::~Init() { - DtlsTransport::Cleanup(); SctpTransport::Cleanup(); + DtlsTransport::Cleanup(); +#if RTC_ENABLE_WEBSOCKET + TlsTransport::Cleanup(); +#endif #ifdef _WIN32 WSACleanup(); diff --git a/src/peerconnection.cpp b/src/peerconnection.cpp index 234b3eb..3d3a1c4 100644 --- a/src/peerconnection.cpp +++ b/src/peerconnection.cpp @@ -23,7 +23,6 @@ #include "include.hpp" #include "sctptransport.hpp" -#include #include namespace rtc { @@ -33,23 +32,6 @@ using namespace std::placeholders; using std::shared_ptr; using std::weak_ptr; -template auto weak_bind(F &&f, T *t, Args &&... _args) { - return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&... args) { - if (auto shared_this = weak_this.lock()) - bound(args...); - }; -} - -template -auto weak_bind_verifier(F &&f, T *t, Args &&... _args) { - return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&... args) { - if (auto shared_this = weak_this.lock()) - return bound(args...); - else - return false; - }; -} - PeerConnection::PeerConnection() : PeerConnection(Configuration()) {} PeerConnection::PeerConnection(const Configuration &config) @@ -271,7 +253,7 @@ shared_ptr PeerConnection::initDtlsTransport() { auto lower = std::atomic_load(&mIceTransport); auto transport = std::make_shared( - lower, mCertificate, weak_bind_verifier(&PeerConnection::checkFingerprint, this, _1), + lower, mCertificate, weak_bind(&PeerConnection::checkFingerprint, this, _1), [this, weak_this = weak_from_this()](DtlsTransport::State state) { auto shared_this = weak_this.lock(); if (!shared_this) diff --git a/src/rtc.cpp b/src/rtc.cpp index 93ba20d..fb48d37 100644 --- a/src/rtc.cpp +++ b/src/rtc.cpp @@ -16,10 +16,15 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#include "datachannel.hpp" #include "include.hpp" + +#include "datachannel.hpp" #include "peerconnection.hpp" +#if RTC_ENABLE_WEBSOCKET +#include "websocket.hpp" +#endif + #include #include @@ -43,6 +48,9 @@ namespace { std::unordered_map> peerConnectionMap; std::unordered_map> dataChannelMap; +#if RTC_ENABLE_WEBSOCKET +std::unordered_map> webSocketMap; +#endif std::unordered_map userPointerMap; std::mutex mutex; int lastId = 0; @@ -103,6 +111,40 @@ bool eraseDataChannel(int dc) { return true; } +#if RTC_ENABLE_WEBSOCKET +shared_ptr getWebSocket(int id) { + std::lock_guard lock(mutex); + auto it = webSocketMap.find(id); + return it != webSocketMap.end() ? it->second : nullptr; +} + +int emplaceWebSocket(shared_ptr ptr) { + std::lock_guard lock(mutex); + int ws = ++lastId; + webSocketMap.emplace(std::make_pair(ws, ptr)); + return ws; +} + +bool eraseWebSocket(int ws) { + std::lock_guard lock(mutex); + if (webSocketMap.erase(ws) == 0) + return false; + userPointerMap.erase(ws); + return true; +} +#endif + +shared_ptr getChannel(int id) { + std::lock_guard lock(mutex); + if (auto it = dataChannelMap.find(id); it != dataChannelMap.end()) + return it->second; +#if RTC_ENABLE_WEBSOCKET + if (auto it = webSocketMap.find(id); it != webSocketMap.end()) + return it->second; +#endif + return nullptr; +} + } // namespace void rtcInitLogger(rtcLogLevel level) { InitLogger(static_cast(level)); } @@ -164,6 +206,29 @@ int rtcDeleteDataChannel(int dc) { return 0; } +#if RTC_ENABLE_WEBSOCKET +int rtcCreateWebSocket(const char *url) { + return emplaceWebSocket(std::make_shared(url)); +} + +int rtcDeleteWebsocket(int ws) { + auto webSocket = getWebSocket(ws); + if (!webSocket) + return -1; + + webSocket->onOpen(nullptr); + webSocket->onClosed(nullptr); + webSocket->onError(nullptr); + webSocket->onMessage(nullptr); + webSocket->onBufferedAmountLow(nullptr); + webSocket->onAvailable(nullptr); + + eraseWebSocket(ws); + return 0; +} + +#endif + int rtcSetDataChannelCallback(int pc, dataChannelCallbackFunc cb) { auto peerConnection = getPeerConnection(pc); if (!peerConnection) @@ -298,135 +363,135 @@ int rtcGetDataChannelLabel(int dc, char *buffer, int size) { return size + 1; } -int rtcSetOpenCallback(int dc, openCallbackFunc cb) { - auto dataChannel = getDataChannel(dc); - if (!dataChannel) +int rtcSetOpenCallback(int id, openCallbackFunc cb) { + auto channel = getChannel(id); + if (!channel) return -1; if (cb) - dataChannel->onOpen([dc, cb]() { cb(getUserPointer(dc)); }); + channel->onOpen([id, cb]() { cb(getUserPointer(id)); }); else - dataChannel->onOpen(nullptr); + channel->onOpen(nullptr); return 0; } -int rtcSetClosedCallback(int dc, closedCallbackFunc cb) { - auto dataChannel = getDataChannel(dc); - if (!dataChannel) +int rtcSetClosedCallback(int id, closedCallbackFunc cb) { + auto channel = getChannel(id); + if (!channel) return -1; if (cb) - dataChannel->onClosed([dc, cb]() { cb(getUserPointer(dc)); }); + channel->onClosed([id, cb]() { cb(getUserPointer(id)); }); else - dataChannel->onClosed(nullptr); + channel->onClosed(nullptr); return 0; } -int rtcSetErrorCallback(int dc, errorCallbackFunc cb) { - auto dataChannel = getDataChannel(dc); - if (!dataChannel) +int rtcSetErrorCallback(int id, errorCallbackFunc cb) { + auto channel = getChannel(id); + if (!channel) return -1; if (cb) - dataChannel->onError( - [dc, cb](const string &error) { cb(error.c_str(), getUserPointer(dc)); }); + channel->onError([id, cb](const string &error) { cb(error.c_str(), getUserPointer(id)); }); else - dataChannel->onError(nullptr); + channel->onError(nullptr); return 0; } -int rtcSetMessageCallback(int dc, messageCallbackFunc cb) { - auto dataChannel = getDataChannel(dc); - if (!dataChannel) +int rtcSetMessageCallback(int id, messageCallbackFunc cb) { + auto channel = getChannel(id); + if (!channel) return -1; if (cb) - dataChannel->onMessage( - [dc, cb](const binary &b) { - cb(reinterpret_cast(b.data()), b.size(), getUserPointer(dc)); + channel->onMessage( + [id, cb](const binary &b) { + cb(reinterpret_cast(b.data()), b.size(), getUserPointer(id)); }, - [dc, cb](const string &s) { cb(s.c_str(), -1, getUserPointer(dc)); }); + [id, cb](const string &s) { cb(s.c_str(), -1, getUserPointer(id)); }); else - dataChannel->onMessage(nullptr); + channel->onMessage(nullptr); return 0; } -int rtcSendMessage(int dc, const char *data, int size) { - auto dataChannel = getDataChannel(dc); - if (!dataChannel) +int rtcSendMessage(int id, const char *data, int size) { + auto channel = getChannel(id); + if (!channel) return -1; if (size >= 0) { auto b = reinterpret_cast(data); - CATCH(dataChannel->send(b, size)); + CATCH(channel->send(binary(b, b + size))); return size; } else { - string s(data); - CATCH(dataChannel->send(s)); - return s.size(); + string str(data); + int len = str.size(); + CATCH(channel->send(std::move(str))); + return len; } } -int rtcGetBufferedAmount(int dc) { - auto dataChannel = getDataChannel(dc); - if (!dataChannel) +int rtcGetBufferedAmount(int id) { + auto channel = getChannel(id); + if (!channel) return -1; - CATCH(return int(dataChannel->bufferedAmount())); + CATCH(return int(channel->bufferedAmount())); } -int rtcSetBufferedAmountLowThreshold(int dc, int amount) { - auto dataChannel = getDataChannel(dc); - if (!dataChannel) +int rtcSetBufferedAmountLowThreshold(int id, int amount) { + auto channel = getChannel(id); + if (!channel) return -1; - CATCH(dataChannel->setBufferedAmountLowThreshold(size_t(amount))); + CATCH(channel->setBufferedAmountLowThreshold(size_t(amount))); return 0; } -int rtcSetBufferedAmountLowCallback(int dc, bufferedAmountLowCallbackFunc cb) { - auto dataChannel = getDataChannel(dc); - if (!dataChannel) +int rtcSetBufferedAmountLowCallback(int id, bufferedAmountLowCallbackFunc cb) { + auto channel = getChannel(id); + if (!channel) return -1; if (cb) - dataChannel->onBufferedAmountLow([dc, cb]() { cb(getUserPointer(dc)); }); + channel->onBufferedAmountLow([id, cb]() { cb(getUserPointer(id)); }); else - dataChannel->onBufferedAmountLow(nullptr); + channel->onBufferedAmountLow(nullptr); return 0; } -int rtcGetAvailableAmount(int dc) { - auto dataChannel = getDataChannel(dc); - if (!dataChannel) +int rtcGetAvailableAmount(int id) { + auto channel = getChannel(id); + if (!channel) return -1; - CATCH(return int(dataChannel->availableAmount())); + CATCH(return int(channel->availableAmount())); } -int rtcSetAvailableCallback(int dc, availableCallbackFunc cb) { - auto dataChannel = getDataChannel(dc); - if (!dataChannel) +int rtcSetAvailableCallback(int id, availableCallbackFunc cb) { + auto channel = getChannel(id); + if (!channel) return -1; if (cb) - dataChannel->onOpen([dc, cb]() { cb(getUserPointer(dc)); }); + channel->onOpen([id, cb]() { cb(getUserPointer(id)); }); else - dataChannel->onOpen(nullptr); + channel->onOpen(nullptr); return 0; } -int rtcReceiveMessage(int dc, char *buffer, int *size) { - auto dataChannel = getDataChannel(dc); - if (!dataChannel) +int rtcReceiveMessage(int id, char *buffer, int *size) { + auto channel = getChannel(id); + if (!channel) return -1; if (!size) return -1; CATCH({ - auto message = dataChannel->receive(); + auto message = channel->receive(); if (!message) return 0; diff --git a/src/sctptransport.cpp b/src/sctptransport.cpp index af0b1c4..024d34d 100644 --- a/src/sctptransport.cpp +++ b/src/sctptransport.cpp @@ -71,9 +71,8 @@ void SctpTransport::Cleanup() { SctpTransport::SctpTransport(std::shared_ptr lower, uint16_t port, message_callback recvCallback, amount_callback bufferedAmountCallback, state_callback stateChangeCallback) - : Transport(lower), mPort(port), mSendQueue(0, message_size_func), - mBufferedAmountCallback(std::move(bufferedAmountCallback)), - mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) { + : Transport(lower, std::move(stateChangeCallback)), mPort(port), + mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) { onRecv(recvCallback); PLOG_DEBUG << "Initializing SCTP transport"; @@ -180,8 +179,6 @@ SctpTransport::~SctpTransport() { usrsctp_deregister_address(this); } -SctpTransport::State SctpTransport::state() const { return mState; } - bool SctpTransport::stop() { if (!Transport::stop()) return false; @@ -240,6 +237,7 @@ void SctpTransport::shutdown() { bool SctpTransport::send(message_ptr message) { std::lock_guard lock(mSendMutex); + if (!message) return mSendQueue.empty(); @@ -269,7 +267,7 @@ void SctpTransport::incoming(message_ptr message) { // to be sent on our side (i.e. the local INIT) before proceeding. { std::unique_lock lock(mWriteMutex); - mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || mState != State::Connected; }); + mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || state() != State::Connected; }); } if (!message) { @@ -283,11 +281,6 @@ void SctpTransport::incoming(message_ptr message) { usrsctp_conninput(this, message->data(), message->size(), 0); } -void SctpTransport::changeState(State state) { - if (mState.exchange(state) != state) - mStateChangeCallback(state); -} - bool SctpTransport::trySendQueue() { // Requires mSendMutex to be locked while (auto next = mSendQueue.peek()) { @@ -302,7 +295,7 @@ bool SctpTransport::trySendQueue() { bool SctpTransport::trySendMessage(message_ptr message) { // Requires mSendMutex to be locked - if (!mSock || mState != State::Connected) + if (!mSock || state() != State::Connected) return false; uint32_t ppid; @@ -414,7 +407,7 @@ void SctpTransport::sendReset(uint16_t streamId) { if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) { std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp... mWrittenCondition.wait_for(lock, 1000ms, - [&]() { return mWritten || mState != State::Connected; }); + [&]() { return mWritten || state() != State::Connected; }); } else if (errno == EINVAL) { PLOG_VERBOSE << "SCTP stream " << streamId << " already reset"; } else { @@ -571,7 +564,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s PLOG_INFO << "SCTP connected"; changeState(State::Connected); } else { - if (mState == State::Connecting) { + if (state() == State::Connecting) { PLOG_ERROR << "SCTP connection failed"; changeState(State::Failed); } else { diff --git a/src/sctptransport.hpp b/src/sctptransport.hpp index e751277..c22e9be 100644 --- a/src/sctptransport.hpp +++ b/src/sctptransport.hpp @@ -38,17 +38,12 @@ public: static void Init(); static void Cleanup(); - enum class State { Disconnected, Connecting, Connected, Failed }; - using amount_callback = std::function; - using state_callback = std::function; SctpTransport(std::shared_ptr lower, uint16_t port, message_callback recvCallback, amount_callback bufferedAmountCallback, state_callback stateChangeCallback); ~SctpTransport(); - State state() const; - bool stop() override; bool send(message_ptr message) override; // false if buffered void close(unsigned int stream); @@ -76,7 +71,6 @@ private: void connect(); void shutdown(); void incoming(message_ptr message) override; - void changeState(State state); bool trySendQueue(); bool trySendMessage(message_ptr message); @@ -105,14 +99,11 @@ private: std::atomic mWritten = false; // written outside lock bool mWrittenOnce = false; - state_callback mStateChangeCallback; - std::atomic mState; + binary mPartialRecv, mPartialStringData, mPartialBinaryData; // Stats std::atomic mBytesSent = 0, mBytesReceived = 0; - binary mPartialRecv, mPartialStringData, mPartialBinaryData; - static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len, struct sctp_rcvinfo recv_info, int flags, void *user_data); static int SendCallback(struct socket *sock, uint32_t sb_free); diff --git a/src/tcptransport.cpp b/src/tcptransport.cpp new file mode 100644 index 0000000..cc6defc --- /dev/null +++ b/src/tcptransport.cpp @@ -0,0 +1,320 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#if RTC_ENABLE_WEBSOCKET + +#include "tcptransport.hpp" + +#include +#ifndef _WIN32 +#include +#include +#endif + +namespace rtc { + +using std::to_string; + +SelectInterrupter::SelectInterrupter() { +#ifndef _WIN32 + int pipefd[2]; + if (::pipe(pipefd) != 0) + throw std::runtime_error("Failed to create pipe"); + ::fcntl(pipefd[0], F_SETFL, O_NONBLOCK); + ::fcntl(pipefd[1], F_SETFL, O_NONBLOCK); + mPipeOut = pipefd[0]; // read + mPipeIn = pipefd[1]; // write +#endif +} + +SelectInterrupter::~SelectInterrupter() { + std::lock_guard lock(mMutex); +#ifdef _WIN32 + if (mDummySock != INVALID_SOCKET) + ::closesocket(mDummySock); +#else + ::close(mPipeIn); + ::close(mPipeOut); +#endif +} + +int SelectInterrupter::prepare(fd_set &readfds, fd_set &writefds) { + std::lock_guard lock(mMutex); +#ifdef _WIN32 + if (mDummySock == INVALID_SOCKET) + mDummySock = ::socket(AF_INET, SOCK_DGRAM, 0); + FD_SET(mDummySock, &readfds); + return SOCK_TO_INT(mDummySock) + 1; +#else + int ret; + do { + char dummy; + ret = ::read(mPipeIn, &dummy, 1); + } while (ret > 0); + FD_SET(mPipeIn, &readfds); + return mPipeIn + 1; +#endif +} + +void SelectInterrupter::interrupt() { + std::lock_guard lock(mMutex); +#ifdef _WIN32 + if (mDummySock != INVALID_SOCKET) { + ::closesocket(mDummySock); + mDummySock = INVALID_SOCKET; + } +#else + char dummy = 0; + ::write(mPipeOut, &dummy, 1); +#endif +} + +TcpTransport::TcpTransport(const string &hostname, const string &service, state_callback callback) + : Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) { + + PLOG_DEBUG << "Initializing TCP transport"; + mThread = std::thread(&TcpTransport::runLoop, this); +} + +TcpTransport::~TcpTransport() { + stop(); +} + +bool TcpTransport::stop() { + if (!Transport::stop()) + return false; + + PLOG_DEBUG << "Waiting TCP recv thread"; + close(); + mThread.join(); + return true; +} + +bool TcpTransport::send(message_ptr message) { + if (!message) + return mSendQueue.empty(); + + PLOG_VERBOSE << "Send size=" << (message ? message->size() : 0); + + return outgoing(message); +} + +void TcpTransport::incoming(message_ptr message) { recv(message); } + +bool TcpTransport::outgoing(message_ptr message) { + // If nothing is pending, try to send directly + // It's safe because if the queue is empty, the thread is not sending + if (mSendQueue.empty() && trySendMessage(message)) + return true; + + mSendQueue.push(message); + interruptSelect(); // so the thread waits for writability + return false; +} + +void TcpTransport::connect(const string &hostname, const string &service) { + PLOG_DEBUG << "Connecting to " << hostname << ":" << service; + + struct addrinfo hints = {}; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + hints.ai_flags = AI_ADDRCONFIG; + + struct addrinfo *result = nullptr; + if (getaddrinfo(hostname.c_str(), service.c_str(), &hints, &result)) + throw std::runtime_error("Resolution failed for \"" + hostname + ":" + service + "\""); + + for (auto p = result; p; p = p->ai_next) + try { + connect(p->ai_addr, p->ai_addrlen); + freeaddrinfo(result); + return; + } catch (const std::runtime_error &e) { + PLOG_WARNING << e.what(); + } + + freeaddrinfo(result); + throw std::runtime_error("Connection failed to \"" + hostname + ":" + service + "\""); +} + +void TcpTransport::connect(const sockaddr *addr, socklen_t addrlen) { + try { + PLOG_DEBUG << "Creating TCP socket"; + + // Create socket + mSock = ::socket(addr->sa_family, SOCK_STREAM, IPPROTO_TCP); + if (mSock == INVALID_SOCKET) + throw std::runtime_error("TCP socket creation failed"); + + ctl_t b = 1; + if (::ioctlsocket(mSock, FIONBIO, &b) < 0) + throw std::runtime_error("Failed to set socket non-blocking mode"); + + IF_PLOG(plog::debug) { + char node[MAX_NUMERICNODE_LEN]; + char serv[MAX_NUMERICSERV_LEN]; + if (getnameinfo(addr, addrlen, node, MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN, + NI_NUMERICHOST | NI_NUMERICSERV) == 0) { + PLOG_DEBUG << "Trying address " << node << ":" << serv; + } + } + + // Initiate connection + ::connect(mSock, addr, addrlen); + + fd_set writefds; + FD_ZERO(&writefds); + FD_SET(mSock, &writefds); + struct timeval tv; + tv.tv_sec = 10; // TODO + tv.tv_usec = 0; + int ret = ::select(SOCKET_TO_INT(mSock) + 1, NULL, &writefds, NULL, &tv); + + if (ret < 0) + throw std::runtime_error("Failed to wait for socket connection"); + + if (ret == 0 || ::send(mSock, NULL, 0, MSG_NOSIGNAL) != 0) + throw std::runtime_error("Connection failed"); + + } catch (...) { + if (mSock != INVALID_SOCKET) { + ::closesocket(mSock); + mSock = INVALID_SOCKET; + } + throw; + } +} + +void TcpTransport::close() { + if (mSock != INVALID_SOCKET) { + PLOG_DEBUG << "Closing TCP socket"; + ::closesocket(mSock); + mSock = INVALID_SOCKET; + } + changeState(State::Disconnected); +} + +bool TcpTransport::trySendQueue() { + while (auto next = mSendQueue.peek()) { + auto message = *next; + if (!trySendMessage(message)) { + mSendQueue.exchange(message); + return false; + } + mSendQueue.pop(); + } + return true; +} + +bool TcpTransport::trySendMessage(message_ptr &message) { + auto data = reinterpret_cast(message->data()); + auto size = message->size(); + while (size) { + int len = ::send(mSock, data, size, MSG_NOSIGNAL); + if (len < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + message = make_message(message->end() - size, message->end()); + return false; + } else { + throw std::runtime_error("Connection lost, errno=" + to_string(sockerrno)); + } + } + + data += len; + size -= len; + } + message = nullptr; + return true; +} + +void TcpTransport::runLoop() { + const size_t bufferSize = 4096; + + // Connect + try { + changeState(State::Connecting); + connect(mHostname, mService); + + } catch (const std::exception &e) { + PLOG_ERROR << "TCP connect: " << e.what(); + changeState(State::Failed); + return; + } + + + // Receive loop + try { + PLOG_INFO << "TCP connected"; + changeState(State::Connected); + + while (true) { + fd_set readfds, writefds; + int n = prepareSelect(readfds, writefds); + int ret = ::select(n, &readfds, &writefds, NULL, NULL); + if (ret < 0) + throw std::runtime_error("Failed to wait on socket"); + + if (FD_ISSET(mSock, &writefds)) + trySendQueue(); + + if (FD_ISSET(mSock, &readfds)) { + char buffer[bufferSize]; + int len = ::recv(mSock, buffer, bufferSize, 0); + if (len < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + continue; + } else { + throw std::runtime_error("Connection lost, errno=" + to_string(sockerrno)); + } + } + + if (len == 0) + break; // clean close + + auto *b = reinterpret_cast(buffer); + incoming(make_message(b, b + len)); + } + } + } catch (const std::exception &e) { + PLOG_ERROR << "TCP recv: " << e.what(); + } + + PLOG_INFO << "TCP disconnected"; + changeState(State::Disconnected); + recv(nullptr); +} + +int TcpTransport::prepareSelect(fd_set &readfds, fd_set &writefds) { + FD_ZERO(&readfds); + FD_ZERO(&writefds); + FD_SET(mSock, &readfds); + + if (!mSendQueue.empty()) + FD_SET(mSock, &writefds); + + int n = SOCKET_TO_INT(mSock) + 1; + int m = mInterrupter.prepare(readfds, writefds); + return std::max(n, m); +} + +void TcpTransport::interruptSelect() { mInterrupter.interrupt(); } + +} // namespace rtc + +#endif diff --git a/src/tcptransport.hpp b/src/tcptransport.hpp new file mode 100644 index 0000000..565916a --- /dev/null +++ b/src/tcptransport.hpp @@ -0,0 +1,90 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_TCP_TRANSPORT_H +#define RTC_TCP_TRANSPORT_H + +#if RTC_ENABLE_WEBSOCKET + +#include "include.hpp" +#include "queue.hpp" +#include "transport.hpp" + +#include +#include + +// Use the socket defines from libjuice +#include "../deps/libjuice/src/socket.h" + +namespace rtc { + +// Utility class to interrupt select() +class SelectInterrupter { +public: + SelectInterrupter(); + ~SelectInterrupter(); + + int prepare(fd_set &readfds, fd_set &writefds); + void interrupt(); + +private: + std::mutex mMutex; +#ifdef _WIN32 + socket_t mDummySock = INVALID_SOCKET; +#else // assume POSIX + int mPipeIn, mPipeOut; +#endif +}; + +class TcpTransport : public Transport { +public: + TcpTransport(const string &hostname, const string &service, state_callback callback); + ~TcpTransport(); + + bool stop() override; + bool send(message_ptr message) override; + + void incoming(message_ptr message) override; + bool outgoing(message_ptr message) override; + +private: + void connect(const string &hostname, const string &service); + void connect(const sockaddr *addr, socklen_t addrlen); + void close(); + + bool trySendQueue(); + bool trySendMessage(message_ptr &message); + + void runLoop(); + + int prepareSelect(fd_set &readfds, fd_set &writefds); + void interruptSelect(); + + string mHostname, mService; + + socket_t mSock = INVALID_SOCKET; + std::thread mThread; + SelectInterrupter mInterrupter; + Queue mSendQueue; +}; + +} // namespace rtc + +#endif + +#endif diff --git a/src/tlstransport.cpp b/src/tlstransport.cpp new file mode 100644 index 0000000..89ef21f --- /dev/null +++ b/src/tlstransport.cpp @@ -0,0 +1,432 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#if RTC_ENABLE_WEBSOCKET + +#include "tlstransport.hpp" +#include "tcptransport.hpp" + +#include +#include +#include +#include + +using namespace std::chrono; + +using std::shared_ptr; +using std::string; +using std::unique_ptr; +using std::weak_ptr; + +#if USE_GNUTLS + +namespace { + +static bool check_gnutls(int ret, const string &message = "GnuTLS error") { + if (ret < 0) { + if (!gnutls_error_is_fatal(ret)) { + PLOG_INFO << gnutls_strerror(ret); + return false; + } + PLOG_ERROR << message << ": " << gnutls_strerror(ret); + throw std::runtime_error(message + ": " + gnutls_strerror(ret)); + } + return true; +} + +} // namespace + +namespace rtc { + +void TlsTransport::Init() { + // Nothing to do +} + +void TlsTransport::Cleanup() { + // Nothing to do +} + +TlsTransport::TlsTransport(shared_ptr lower, string host, state_callback callback) + : Transport(lower, std::move(callback)) { + + PLOG_DEBUG << "Initializing TLS transport (GnuTLS)"; + + check_gnutls(gnutls_init(&mSession, GNUTLS_CLIENT)); + + try { + const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128"; + const char *err_pos = NULL; + check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos), + "Failed to set TLS priorities"); + + gnutls_session_set_ptr(mSession, this); + gnutls_transport_set_ptr(mSession, this); + gnutls_transport_set_push_function(mSession, WriteCallback); + gnutls_transport_set_pull_function(mSession, ReadCallback); + gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback); + + gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size()); + + mRecvThread = std::thread(&TlsTransport::runRecvLoop, this); + registerIncoming(); + + } catch (...) { + + gnutls_deinit(mSession); + throw; + } +} + +TlsTransport::~TlsTransport() { + stop(); + gnutls_deinit(mSession); +} + +bool TlsTransport::stop() { + if (!Transport::stop()) + return false; + + PLOG_DEBUG << "Stopping TLS recv thread"; + mIncomingQueue.stop(); + mRecvThread.join(); + return true; +} + +bool TlsTransport::send(message_ptr message) { + if (!message) + return false; + + PLOG_VERBOSE << "Send size=" << message->size(); + + ssize_t ret; + do { + ret = gnutls_record_send(mSession, message->data(), message->size()); + } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN); + + return check_gnutls(ret); +} + +void TlsTransport::incoming(message_ptr message) { + if (message) + mIncomingQueue.push(message); + else + mIncomingQueue.stop(); +} + +void TlsTransport::runRecvLoop() { + const size_t bufferSize = 4096; + char buffer[bufferSize]; + + // Handshake loop + try { + changeState(State::Connecting); + + int ret; + do { + ret = gnutls_handshake(mSession); + } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN || + !check_gnutls(ret, "TLS handshake failed")); + + } catch (const std::exception &e) { + PLOG_ERROR << "TLS handshake: " << e.what(); + changeState(State::Failed); + return; + } + + // Receive loop + try { + PLOG_INFO << "TLS handshake finished"; + changeState(State::Connected); + + while (true) { + ssize_t ret; + do { + ret = gnutls_record_recv(mSession, buffer, bufferSize); + } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN); + + // Consider premature termination as remote closing + if (ret == GNUTLS_E_PREMATURE_TERMINATION) { + PLOG_DEBUG << "TLS connection terminated"; + break; + } + + if (check_gnutls(ret)) { + if (ret == 0) { + // Closed + PLOG_DEBUG << "TLS connection cleanly closed"; + break; + } + auto *b = reinterpret_cast(buffer); + recv(make_message(b, b + ret)); + } + } + } catch (const std::exception &e) { + PLOG_ERROR << "TLS recv: " << e.what(); + } + + gnutls_bye(mSession, GNUTLS_SHUT_RDWR); + + PLOG_INFO << "TLS closed"; + changeState(State::Disconnected); + recv(nullptr); +} + +ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) { + TlsTransport *t = static_cast(ptr); + if (len > 0) { + auto b = reinterpret_cast(data); + t->outgoing(make_message(b, b + len)); + } + gnutls_transport_set_errno(t->mSession, 0); + return ssize_t(len); +} + +ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) { + TlsTransport *t = static_cast(ptr); + if (auto next = t->mIncomingQueue.pop()) { + auto message = *next; + ssize_t len = std::min(maxlen, message->size()); + std::memcpy(data, message->data(), len); + gnutls_transport_set_errno(t->mSession, 0); + return len; + } + // Closed + gnutls_transport_set_errno(t->mSession, 0); + return 0; +} + +int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) { + TlsTransport *t = static_cast(ptr); + if (ms != GNUTLS_INDEFINITE_TIMEOUT) + t->mIncomingQueue.wait(milliseconds(ms)); + else + t->mIncomingQueue.wait(); + return !t->mIncomingQueue.empty() ? 1 : 0; +} + +} // namespace rtc + +#else // USE_GNUTLS==0 + +#include +#include +#include +#include + +namespace { + +const int BIO_EOF = -1; + +string openssl_error_string(unsigned long err) { + const size_t bufferSize = 256; + char buffer[bufferSize]; + ERR_error_string_n(err, buffer, bufferSize); + return string(buffer); +} + +bool check_openssl(int success, const string &message = "OpenSSL error") { + if (success) + return true; + + string str = openssl_error_string(ERR_get_error()); + PLOG_ERROR << message << ": " << str; + throw std::runtime_error(message + ": " + str); +} + +bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error") { + if (ret == BIO_EOF) + return true; + + unsigned long err = SSL_get_error(ssl, ret); + if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { + return true; + } + if (err == SSL_ERROR_ZERO_RETURN) { + PLOG_DEBUG << "TLS connection cleanly closed"; + return false; + } + string str = openssl_error_string(err); + PLOG_ERROR << str; + throw std::runtime_error(message + ": " + str); +} + +} // namespace + +namespace rtc { + +int TlsTransport::TransportExIndex = -1; + +void TlsTransport::Init() { + if (TransportExIndex < 0) { + TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL); + } +} + +void TlsTransport::Cleanup() { + // Nothing to do +} + +TlsTransport::TlsTransport(shared_ptr lower, string host, state_callback callback) + : Transport(lower, std::move(callback)) { + + PLOG_DEBUG << "Initializing TLS transport (OpenSSL)"; + + if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible + throw std::runtime_error("Failed to create SSL context"); + + check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"), + "Failed to set SSL priorities"); + + SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3); + SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION); + SSL_CTX_set_read_ahead(mCtx, 1); + SSL_CTX_set_quiet_shutdown(mCtx, 1); + SSL_CTX_set_info_callback(mCtx, InfoCallback); + + SSL_CTX_set_default_verify_paths(mCtx); + SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER, NULL); + SSL_CTX_set_verify_depth(mCtx, 4); + + if (!(mSsl = SSL_new(mCtx))) + throw std::runtime_error("Failed to create SSL instance"); + + SSL_set_ex_data(mSsl, TransportExIndex, this); + SSL_set_tlsext_host_name(mSsl, host.c_str()); + + SSL_set_connect_state(mSsl); + + if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem()))) + throw std::runtime_error("Failed to create BIO"); + + BIO_set_mem_eof_return(mInBio, BIO_EOF); + BIO_set_mem_eof_return(mOutBio, BIO_EOF); + SSL_set_bio(mSsl, mInBio, mOutBio); + + auto ecdh = unique_ptr( + EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free); + SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE); + SSL_set_tmp_ecdh(mSsl, ecdh.get()); + + mRecvThread = std::thread(&TlsTransport::runRecvLoop, this); +} + +TlsTransport::~TlsTransport() { + stop(); + + SSL_free(mSsl); + SSL_CTX_free(mCtx); +} + +bool TlsTransport::stop() { + if (!Transport::stop()) + return false; + + PLOG_DEBUG << "Stopping TLS recv thread"; + mIncomingQueue.stop(); + mRecvThread.join(); + SSL_shutdown(mSsl); + return true; +} + +bool TlsTransport::send(message_ptr message) { + if (!message) + return false; + + int ret = SSL_write(mSsl, message->data(), message->size()); + if (!check_openssl_ret(mSsl, ret)) + return false; + + const size_t bufferSize = 4096; + byte buffer[bufferSize]; + while (int len = BIO_read(mOutBio, buffer, bufferSize)) + outgoing(make_message(buffer, buffer + len)); + + return true; +} + +void TlsTransport::incoming(message_ptr message) { + if (message) + mIncomingQueue.push(message); + else + mIncomingQueue.stop(); +} + +void TlsTransport::runRecvLoop() { + const size_t bufferSize = 4096; + byte buffer[bufferSize]; + + try { + changeState(State::Connecting); + + SSL_do_handshake(mSsl); + while (int len = BIO_read(mOutBio, buffer, bufferSize)) + outgoing(make_message(buffer, buffer + len)); + + while (auto next = mIncomingQueue.pop()) { + message_ptr message = *next; + message_ptr decrypted; + + BIO_write(mInBio, message->data(), message->size()); + + int ret = SSL_read(mSsl, buffer, bufferSize); + if (!check_openssl_ret(mSsl, ret)) + break; + + if (ret > 0) + decrypted = make_message(buffer, buffer + ret); + + while (int len = BIO_read(mOutBio, buffer, bufferSize)) + outgoing(make_message(buffer, buffer + len)); + + if (state() == State::Connecting && SSL_is_init_finished(mSsl)) { + PLOG_INFO << "TLS handshake finished"; + changeState(State::Connected); + } + + if (decrypted) + recv(decrypted); + } + } catch (const std::exception &e) { + PLOG_ERROR << "TLS recv: " << e.what(); + } + + if (state() == State::Connected) { + PLOG_INFO << "TLS closed"; + recv(nullptr); + } else { + PLOG_ERROR << "TLS handshake failed"; + } +} + +void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) { + TlsTransport *t = + static_cast(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex)); + + if (where & SSL_CB_ALERT) { + if (ret != 256) { // Close Notify + PLOG_ERROR << "TLS alert: " << SSL_alert_desc_string_long(ret); + } + t->mIncomingQueue.stop(); // Close the connection + } +} + +} // namespace rtc + +#endif + +#endif diff --git a/src/tlstransport.hpp b/src/tlstransport.hpp new file mode 100644 index 0000000..6f68b23 --- /dev/null +++ b/src/tlstransport.hpp @@ -0,0 +1,83 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_TLS_TRANSPORT_H +#define RTC_TLS_TRANSPORT_H + +#if RTC_ENABLE_WEBSOCKET + +#include "include.hpp" +#include "queue.hpp" +#include "transport.hpp" + +#include +#include +#include + +#if USE_GNUTLS +#include +#else +#include +#endif + +namespace rtc { + +class TcpTransport; + +class TlsTransport : public Transport { +public: + static void Init(); + static void Cleanup(); + + TlsTransport(std::shared_ptr lower, string host, state_callback callback); + ~TlsTransport(); + + bool stop() override; + bool send(message_ptr message) override; + + void incoming(message_ptr message) override; + +protected: + void runRecvLoop(); + + Queue mIncomingQueue; + std::thread mRecvThread; + +#if USE_GNUTLS + gnutls_session_t mSession; + + static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len); + static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen); + static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms); +#else + SSL_CTX *mCtx; + SSL *mSsl; + BIO *mInBio, *mOutBio; + + static int TransportExIndex; + + static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx); + static void InfoCallback(const SSL *ssl, int where, int ret); +#endif +}; + +} // namespace rtc + +#endif + +#endif diff --git a/src/transport.hpp b/src/transport.hpp index fd3acaa..a19e6a6 100644 --- a/src/transport.hpp +++ b/src/transport.hpp @@ -32,7 +32,13 @@ using namespace std::placeholders; class Transport { public: - Transport(std::shared_ptr lower = nullptr) : mLower(std::move(lower)) {} + enum class State { Disconnected, Connecting, Connected, Completed, Failed }; + using state_callback = std::function; + + Transport(std::shared_ptr lower = nullptr, state_callback callback = nullptr) + : mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) { + } + virtual ~Transport() { stop(); if (mLower) @@ -49,11 +55,16 @@ public: } void onRecv(message_callback callback) { mRecvCallback = std::move(callback); } + State state() const { return mState; } virtual bool send(message_ptr message) { return outgoing(message); } protected: void recv(message_ptr message) { mRecvCallback(message); } + void changeState(State state) { + if (mState.exchange(state) != state) + mStateChangeCallback(state); + } virtual void incoming(message_ptr message) { recv(message); } virtual bool outgoing(message_ptr message) { @@ -65,7 +76,10 @@ protected: private: std::shared_ptr mLower; + synchronized_callback mStateChangeCallback; synchronized_callback mRecvCallback; + + std::atomic mState = State::Disconnected; std::atomic mShutdown = false; }; diff --git a/src/websocket.cpp b/src/websocket.cpp new file mode 100644 index 0000000..fc4ae70 --- /dev/null +++ b/src/websocket.cpp @@ -0,0 +1,311 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#if RTC_ENABLE_WEBSOCKET + +#include "include.hpp" +#include "websocket.hpp" + +#include "tcptransport.hpp" +#include "tlstransport.hpp" +#include "wstransport.hpp" + +#include + +#ifdef _WIN32 +#include +#endif + +namespace rtc { + +WebSocket::WebSocket() {} + +WebSocket::WebSocket(const string &url) : WebSocket() { open(url); } + +WebSocket::~WebSocket() { remoteClose(); } + +WebSocket::State WebSocket::readyState() const { return mState; } + +void WebSocket::open(const string &url) { + if (mState != State::Closed) + throw std::runtime_error("WebSocket must be closed before opening"); + + static const char *rs = R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)"; + static std::regex regex(rs, std::regex::extended); + + std::smatch match; + if (!std::regex_match(url, match, regex)) + throw std::invalid_argument("Malformed WebSocket URL: " + url); + + mScheme = match[2]; + if (mScheme != "ws" && mScheme != "wss") + throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme); + + mHost = match[4]; + if (auto pos = mHost.find(':'); pos != string::npos) { + mHostname = mHost.substr(0, pos); + mService = mHost.substr(pos + 1); + } else { + mHostname = mHost; + mService = mScheme == "ws" ? "80" : "443"; + } + + mPath = match[5]; + if (string query = match[7]; !query.empty()) + mPath += "?" + query; + + changeState(State::Connecting); + initTcpTransport(); +} + +void WebSocket::close() { + auto state = mState.load(); + if (state == State::Connecting || state == State::Open) { + changeState(State::Closing); + if (auto transport = std::atomic_load(&mWsTransport)) + transport->close(); + else + changeState(State::Closed); + } +} + +void WebSocket::remoteClose() { + close(); + closeTransports(); +} + +bool WebSocket::send(const std::variant &data) { + return std::visit( + [&](const auto &d) { + using T = std::decay_t; + constexpr auto type = std::is_same_v ? Message::String : Message::Binary; + auto *b = reinterpret_cast(d.data()); + return outgoing(std::make_shared(b, b + d.size(), type)); + }, + data); +} + +bool WebSocket::isOpen() const { return mState == State::Open; } + +bool WebSocket::isClosed() const { return mState == State::Closed; } + +size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; } + +std::optional> WebSocket::receive() { + while (!mRecvQueue.empty()) { + auto message = *mRecvQueue.pop(); + switch (message->type) { + case Message::String: + return std::make_optional( + string(reinterpret_cast(message->data()), message->size())); + case Message::Binary: + return std::make_optional(std::move(*message)); + default: + // Ignore + break; + } + } + return nullopt; +} + +size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); } + +bool WebSocket::changeState(State state) { return mState.exchange(state) != state; } + +bool WebSocket::outgoing(mutable_message_ptr message) { + if (mState != State::Open || !mWsTransport) + throw std::runtime_error("WebSocket is not open"); + + if (message->size() > maxMessageSize()) + throw std::runtime_error("Message size exceeds limit"); + + return mWsTransport->send(message); +} + +void WebSocket::incoming(message_ptr message) { + if (message->type == Message::String || message->type == Message::Binary) { + mRecvQueue.push(message); + triggerAvailable(mRecvQueue.size()); + } +} + +std::shared_ptr WebSocket::initTcpTransport() { + using State = TcpTransport::State; + try { + std::lock_guard lock(mInitMutex); + if (auto transport = std::atomic_load(&mTcpTransport)) + return transport; + + auto transport = std::make_shared( + mHostname, mService, [this, weak_this = weak_from_this()](State state) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (state) { + case State::Connected: + if (mScheme == "ws") + initWsTransport(); + else + initTlsTransport(); + break; + case State::Failed: + triggerError("TCP connection failed"); + remoteClose(); + break; + case State::Disconnected: + remoteClose(); + break; + default: + // Ignore + break; + } + }); + std::atomic_store(&mTcpTransport, transport); + if (mState == WebSocket::State::Closed) { + mTcpTransport.reset(); + transport->stop(); + throw std::runtime_error("Connection is closed"); + } + return transport; + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + remoteClose(); + throw std::runtime_error("TCP transport initialization failed"); + } +} + +std::shared_ptr WebSocket::initTlsTransport() { + using State = TlsTransport::State; + try { + std::lock_guard lock(mInitMutex); + if (auto transport = std::atomic_load(&mTlsTransport)) + return transport; + + auto lower = std::atomic_load(&mTcpTransport); + auto transport = std::make_shared( + lower, mHost, [this, weak_this = weak_from_this()](State state) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (state) { + case State::Connected: + initWsTransport(); + break; + case State::Failed: + triggerError("TCP connection failed"); + remoteClose(); + break; + case State::Disconnected: + remoteClose(); + break; + default: + // Ignore + break; + } + }); + std::atomic_store(&mTlsTransport, transport); + if (mState == WebSocket::State::Closed) { + mTlsTransport.reset(); + transport->stop(); + throw std::runtime_error("Connection is closed"); + } + return transport; + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + remoteClose(); + throw std::runtime_error("TLS transport initialization failed"); + } +} + +std::shared_ptr WebSocket::initWsTransport() { + using State = WsTransport::State; + try { + std::lock_guard lock(mInitMutex); + if (auto transport = std::atomic_load(&mWsTransport)) + return transport; + + std::shared_ptr lower = std::atomic_load(&mTlsTransport); + if (!lower) + lower = std::atomic_load(&mTcpTransport); + auto transport = std::make_shared( + lower, mHost, mPath, weak_bind(&WebSocket::incoming, this, _1), + [this, weak_this = weak_from_this()](State state) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (state) { + case State::Connected: + if (mState == WebSocket::State::Connecting) { + PLOG_DEBUG << "WebSocket open"; + changeState(WebSocket::State::Open); + triggerOpen(); + } + break; + case State::Failed: + triggerError("WebSocket connection failed"); + remoteClose(); + break; + case State::Disconnected: + remoteClose(); + break; + default: + // Ignore + break; + } + }); + std::atomic_store(&mWsTransport, transport); + if (mState == WebSocket::State::Closed) { + mWsTransport.reset(); + transport->stop(); + throw std::runtime_error("Connection is closed"); + } + return transport; + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + remoteClose(); + throw std::runtime_error("WebSocket transport initialization failed"); + } +} + +void WebSocket::closeTransports() { + changeState(State::Closed); + + // Pass the references to a thread, allowing to terminate a transport from its own thread + auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr)); + auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr)); + auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr)); + if (ws || tls || tcp) { + std::thread t([ws, tls, tcp]() mutable { + if (ws) + ws->stop(); + if (tls) + tls->stop(); + if (tcp) + tcp->stop(); + + ws.reset(); + tls.reset(); + tcp.reset(); + }); + t.detach(); + } +} + +} // namespace rtc + +#endif diff --git a/src/wstransport.cpp b/src/wstransport.cpp new file mode 100644 index 0000000..0537916 --- /dev/null +++ b/src/wstransport.cpp @@ -0,0 +1,372 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#if RTC_ENABLE_WEBSOCKET + +#include "wstransport.hpp" +#include "tcptransport.hpp" +#include "tlstransport.hpp" + +#include "base64.hpp" + +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +#ifndef htonll +#define htonll(x) \ + ((uint64_t)htonl(((uint64_t)(x)&0xFFFFFFFF) << 32) | (uint64_t)htonl((uint64_t)(x) >> 32)) +#endif +#ifndef ntohll +#define ntohll(x) htonll(x) +#endif + +namespace rtc { + +using namespace std::chrono; +using std::to_integer; +using std::to_string; + +using random_bytes_engine = + std::independent_bits_engine; + +WsTransport::WsTransport(std::shared_ptr lower, string host, string path, + message_callback recvCallback, state_callback stateCallback) + : Transport(lower, std::move(stateCallback)), mHost(std::move(host)), mPath(std::move(path)) { + onRecv(recvCallback); + + PLOG_DEBUG << "Initializing WebSocket transport"; + + registerIncoming(); + sendHttpRequest(); +} + +WsTransport::~WsTransport() { stop(); } + +bool WsTransport::stop() { + if (!Transport::stop()) + return false; + + close(); + return true; +} + +bool WsTransport::send(message_ptr message) { + if (!message) + return false; + + // Call the mutable message overload with a copy + return send(std::make_shared(*message)); +} + +bool WsTransport::send(mutable_message_ptr message) { + if (!message || state() != State::Connected) + return false; + + PLOG_VERBOSE << "Send size=" << message->size(); + + return sendFrame({message->type == Message::String ? TEXT_FRAME : BINARY_FRAME, message->data(), + message->size(), true, true}); +} + +void WsTransport::incoming(message_ptr message) { + try { + mBuffer.insert(mBuffer.end(), message->begin(), message->end()); + + if (state() == State::Connecting) { + if (size_t len = readHttpResponse(mBuffer.data(), mBuffer.size())) { + mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len); + PLOG_INFO << "WebSocket open"; + changeState(State::Connected); + } + } + + if (state() == State::Connected) { + Frame frame = {}; + while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) { + mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len); + recvFrame(frame); + } + } + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + } + + if (state() == State::Connected) { + PLOG_INFO << "WebSocket disconnected"; + changeState(State::Disconnected); + recv(nullptr); + } else { + PLOG_ERROR << "WebSocket handshake failed"; + changeState(State::Failed); + } +} + +void WsTransport::close() { + if (state() == State::Connected) { + sendFrame({CLOSE, NULL, 0, true, true}); + PLOG_INFO << "WebSocket closing"; + changeState(State::Completed); + } +} + +bool WsTransport::sendHttpRequest() { + changeState(State::Connecting); + + auto seed = system_clock::now().time_since_epoch().count(); + random_bytes_engine generator(seed); + + binary key(16); + std::generate(reinterpret_cast(key.data()), + reinterpret_cast(key.data() + key.size()), generator); + + const string request = "GET " + mPath + + " HTTP/1.1\r\n" + "Host: " + + mHost + + "\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Sec-WebSocket-Key: " + + to_base64(key) + + "\r\n" + "\r\n"; + + auto data = reinterpret_cast(request.data()); + auto size = request.size(); + return outgoing(make_message(data, data + size)); +} + +size_t WsTransport::readHttpResponse(const byte *buffer, size_t size) { + std::list lines; + auto begin = reinterpret_cast(buffer); + auto end = begin + size; + auto cur = begin; + while (true) { + auto last = cur; + cur = std::find(cur, end, '\n'); + if (cur == end) + return 0; + string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++); + if (line.empty()) + break; + lines.emplace_back(std::move(line)); + } + size_t length = cur - begin; + + if (lines.empty()) + throw std::runtime_error("Invalid HTTP response for WebSocket"); + + string status = std::move(lines.front()); + lines.pop_front(); + + std::istringstream ss(status); + string protocol; + unsigned int code = 0; + ss >> protocol >> code; + PLOG_DEBUG << "WebSocket response code: " << code; + if (code != 101) + throw std::runtime_error("Unexpected response code for WebSocket: " + to_string(code)); + + std::multimap headers; + for (const auto &line : lines) { + if (size_t pos = line.find_first_of(':'); pos != string::npos) { + string key = line.substr(0, pos); + string value = line.substr(line.find_first_not_of(' ', pos + 1)); + std::transform(key.begin(), key.end(), key.begin(), + [](char c) { return std::tolower(c); }); + headers.emplace(std::move(key), std::move(value)); + } else { + headers.emplace(line, ""); + } + } + + auto h = headers.find("upgrade"); + if (h == headers.end() || h->second != "websocket") + throw std::runtime_error("WebSocket update header missing or mismatching"); + + h = headers.find("sec-websocket-accept"); + if (h == headers.end()) + throw std::runtime_error("WebSocket accept header missing"); + + // TODO: Verify Sec-WebSocket-Accept + + return length; +} + +// http://tools.ietf.org/html/rfc6455#section-5.2 Base Framing Protocol +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-------+-+-------------+-------------------------------+ +// |F|R|R|R| opcode|M| Payload len | Extended payload length | +// |I|S|S|S| (4) |A| (7) | (16/64) | +// |N|V|V|V| |S| | (if payload len==126/127) | +// | |1|2|3| |K| | | +// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + +// | Extended payload length continued, if payload len == 127 | +// + - - - - - - - - - - - - - - - +-------------------------------+ +// | | Masking-key, if MASK set to 1 | +// +-------------------------------+-------------------------------+ +// | Masking-key (continued) | Payload Data | +// +-------------------------------+ - - - - - - - - - - - - - - - + +// : Payload Data continued ... : +// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +// | Payload Data continued ... | +// +---------------------------------------------------------------+ + +size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) { + const byte *end = buffer + size; + if (end - buffer < 2) + return 0; + + byte *cur = buffer; + auto b1 = to_integer(*cur++); + auto b2 = to_integer(*cur++); + + frame.fin = (b1 & 0x80) != 0; + frame.mask = (b2 & 0x80) != 0; + frame.opcode = static_cast(b1 & 0x0F); + frame.length = b2 & 0x7F; + + if (frame.length == 0x7E) { + if (end - cur < 2) + return 0; + frame.length = ntohs(*reinterpret_cast(cur)); + cur += 2; + } else if (frame.length == 0x7F) { + if (end - cur < 8) + return false; + frame.length = ntohll(*reinterpret_cast(cur)); + cur += 8; + } + + const byte *maskingKey = nullptr; + if (frame.mask) { + if (end - cur < 4) + return 0; + maskingKey = cur; + cur += 4; + } + + if (end - cur < frame.length) + return false; + + frame.payload = cur; + if (maskingKey) + for (size_t i = 0; i < frame.length; ++i) + frame.payload[i] ^= maskingKey[i % 4]; + + return end - buffer; +} + +void WsTransport::recvFrame(const Frame &frame) { + switch (frame.opcode) { + case TEXT_FRAME: + case BINARY_FRAME: { + if (!mPartial.empty()) { + auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary; + recv(make_message(mPartial.begin(), mPartial.end(), type)); + mPartial.clear(); + } + if (frame.fin) { + auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary; + recv(make_message(frame.payload, frame.payload + frame.length)); + } else { + mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length); + mPartialOpcode = frame.opcode; + } + break; + } + case CONTINUATION: { + mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length); + if (frame.fin) { + auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary; + recv(make_message(mPartial.begin(), mPartial.end())); + mPartial.clear(); + } + break; + } + case PING: { + sendFrame({PONG, frame.payload, frame.length, true, true}); + break; + } + case PONG: { + // TODO + break; + } + case CLOSE: { + close(); + PLOG_INFO << "WebSocket closed"; + changeState(State::Disconnected); + break; + } + default: { + close(); + throw std::invalid_argument("Unknown WebSocket opcode: " + to_string(frame.opcode)); + } + } +} + +bool WsTransport::sendFrame(const Frame &frame) { + byte buffer[14]; + byte *cur = buffer; + + *cur++ = byte((frame.opcode & 0x0F) | (frame.fin ? 0x80 : 0)); + + if (frame.length < 0x7E) { + *cur++ = byte((frame.length & 0x7F) | (frame.mask ? 0x80 : 0)); + } else if (frame.length <= 0xFF) { + *cur++ = byte(0x7E | (frame.mask ? 0x80 : 0)); + *reinterpret_cast(cur) = uint16_t(frame.length); + cur += 2; + } else { + *cur++ = byte(0x7F | (frame.mask ? 0x80 : 0)); + *reinterpret_cast(cur) = uint64_t(frame.length); + cur += 8; + } + + if (frame.mask) { + auto seed = system_clock::now().time_since_epoch().count(); + random_bytes_engine generator(seed); + + auto *maskingKey = cur; + std::generate(reinterpret_cast(maskingKey), + reinterpret_cast(maskingKey + 4), generator); + cur += 4; + + for (size_t i = 0; i < frame.length; ++i) + frame.payload[i] ^= maskingKey[i % 4]; + } + + outgoing(make_message(buffer, cur)); // header + return outgoing(make_message(frame.payload, frame.payload + frame.length)); // payload +} + +} // namespace rtc + +#endif diff --git a/src/wstransport.hpp b/src/wstransport.hpp new file mode 100644 index 0000000..82903d4 --- /dev/null +++ b/src/wstransport.hpp @@ -0,0 +1,83 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_WS_TRANSPORT_H +#define RTC_WS_TRANSPORT_H + +#if RTC_ENABLE_WEBSOCKET + +#include "include.hpp" +#include "transport.hpp" + +namespace rtc { + +class TcpTransport; +class TlsTransport; + +class WsTransport : public Transport { +public: + WsTransport(std::shared_ptr lower, string host, string path, + message_callback recvCallback, state_callback stateCallback); + ~WsTransport(); + + bool stop() override; + bool send(message_ptr message) override; + bool send(mutable_message_ptr message); + + void incoming(message_ptr message) override; + + void close(); + +private: + enum Opcode : uint8_t { + CONTINUATION = 0, + TEXT_FRAME = 1, + BINARY_FRAME = 2, + CLOSE = 8, + PING = 9, + PONG = 10, + }; + + struct Frame { + Opcode opcode = BINARY_FRAME; + byte *payload = nullptr; + size_t length = 0; + bool fin = true; + bool mask = true; + }; + + bool sendHttpRequest(); + size_t readHttpResponse(const byte *buffer, size_t size); + + size_t readFrame(byte *buffer, size_t size, Frame &frame); + void recvFrame(const Frame &frame); + bool sendFrame(const Frame &frame); + + const string mHost; + const string mPath; + + binary mBuffer; + binary mPartial; + Opcode mPartialOpcode; +}; + +} // namespace rtc + +#endif + +#endif diff --git a/test/main.cpp b/test/main.cpp index 738bb2d..f92037a 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -25,19 +25,19 @@ void test_capi(); int main(int argc, char **argv) { try { - std::cout << "*** Running connectivity test..." << std::endl; + cout << endl << "*** Running connectivity test..." << endl; test_connectivity(); - std::cout << "*** Finished connectivity test" << std::endl; + cout << "*** Finished connectivity test" << endl; } catch (const exception &e) { - std::cerr << "Connectivity test failed: " << e.what() << endl; + cerr << "Connectivity test failed: " << e.what() << endl; return -1; } try { - std::cout << "*** Running C API test..." << std::endl; + cout << endl << "*** Running C API test..." << endl; test_capi(); - std::cout << "*** Finished C API test" << std::endl; + cout << "*** Finished C API test" << endl; } catch (const exception &e) { - std::cerr << "C API test failed: " << e.what() << endl; + cerr << "C API test failed: " << e.what() << endl; return -1; } return 0;