From c1f91b2fff48cad10ca82deaba0d25daa0524bab Mon Sep 17 00:00:00 2001 From: Paul-Louis Ageneau Date: Sun, 15 Mar 2020 19:30:46 +0100 Subject: [PATCH] Added WebSocket --- CMakeLists.txt | 3 +- include/rtc/peerconnection.hpp | 4 +- include/rtc/websocket.hpp | 87 ++++++++++ src/dtlstransport.cpp | 39 ++--- src/dtlstransport.hpp | 8 - src/icetransport.cpp | 66 ++++--- src/icetransport.hpp | 28 +-- src/sctptransport.cpp | 20 +-- src/sctptransport.hpp | 11 +- src/tcptransport.cpp | 4 +- src/tcptransport.hpp | 2 +- src/tlstransport.cpp | 8 +- src/tlstransport.hpp | 2 +- src/transport.hpp | 16 +- src/websocket.cpp | 302 ++++++++++++++++++++++++--------- src/wstransport.cpp | 9 +- src/wstransport.hpp | 4 +- 17 files changed, 406 insertions(+), 207 deletions(-) create mode 100644 include/rtc/websocket.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 0358942..1209958 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,10 +36,11 @@ set(LIBDATACHANNEL_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/src/rtc.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp ) set(LIBDATACHANNEL_HEADERS diff --git a/include/rtc/peerconnection.hpp b/include/rtc/peerconnection.hpp index 334d63a..552cfb1 100644 --- a/include/rtc/peerconnection.hpp +++ b/include/rtc/peerconnection.hpp @@ -96,8 +96,6 @@ public: std::optional rtt(); private: - init_token mInitToken = Init::Token(); - std::shared_ptr initIceTransport(Description::Role role); std::shared_ptr initDtlsTransport(); std::shared_ptr initSctpTransport(); @@ -128,6 +126,8 @@ private: const Configuration mConfig; const std::shared_ptr mCertificate; + init_token mInitToken = Init::Token(); + std::optional mLocalDescription, mRemoteDescription; mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex; diff --git a/include/rtc/websocket.hpp b/include/rtc/websocket.hpp new file mode 100644 index 0000000..7c9fb00 --- /dev/null +++ b/include/rtc/websocket.hpp @@ -0,0 +1,87 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_WEBSOCKET_H +#define RTC_WEBSOCKET_H + +#if ENABLE_WEBSOCKET + +#include "channel.hpp" +#include "include.hpp" +#include "init.hpp" +#include "message.hpp" +#include "queue.hpp" + +#include +#include +#include +#include + +namespace rtc { + +class TcpTransport; +class TlsTransport; +class WsTransport; + +class WebSocket final : public Channel, public std::enable_shared_from_this { +public: + WebSocket(); + WebSocket(const string &url); + ~WebSocket(); + + void open(const string &url); + void close() override; + bool send(const std::variant &data) override; + + bool isOpen() const override; + bool isClosed() const override; + size_t maxMessageSize() const override; + + // Extended API + std::optional> receive() override; + size_t availableAmount() const override; // total size available to receive + +private: + void remoteClose(); + bool outgoing(mutable_message_ptr message); + + std::shared_ptr initTcpTransport(); + std::shared_ptr initTlsTransport(); + std::shared_ptr initWsTransport(); + void closeTransports(); + + init_token mInitToken = Init::Token(); + + std::shared_ptr mTcpTransport; + std::shared_ptr mTlsTransport; + std::shared_ptr mWsTransport; + std::recursive_mutex mInitMutex; + + string mScheme, mHost, mHostname, mService, mPath; + + std::atomic mIsOpen = false; + std::atomic mIsClosed = false; + + Queue mRecvQueue; + std::atomic mRecvAmount = 0; +}; +} // namespace rtc + +#endif + +#endif // NET_WEBSOCKET_H diff --git a/src/dtlstransport.cpp b/src/dtlstransport.cpp index cb57d51..8a22e84 100644 --- a/src/dtlstransport.cpp +++ b/src/dtlstransport.cpp @@ -62,11 +62,9 @@ void DtlsTransport::Cleanup() { } DtlsTransport::DtlsTransport(shared_ptr lower, shared_ptr certificate, - verifier_callback verifierCallback, - state_callback stateChangeCallback) - : Transport(lower), mCertificate(certificate), mState(State::Disconnected), - mVerifierCallback(std::move(verifierCallback)), - mStateChangeCallback(std::move(stateChangeCallback)) { + verifier_callback verifierCallback, state_callback stateChangeCallback) + : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate), + mVerifierCallback(std::move(verifierCallback)) { PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)"; @@ -113,8 +111,6 @@ DtlsTransport::~DtlsTransport() { gnutls_deinit(mSession); } -DtlsTransport::State DtlsTransport::state() const { return mState; } - bool DtlsTransport::stop() { if (!Transport::stop()) return false; @@ -126,7 +122,7 @@ bool DtlsTransport::stop() { } bool DtlsTransport::send(message_ptr message) { - if (!message || mState != State::Connected) + if (!message || state() != State::Connected) return false; PLOG_VERBOSE << "Send size=" << message->size(); @@ -152,11 +148,6 @@ void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); } -void DtlsTransport::changeState(State state) { - if (mState.exchange(state) != state) - mStateChangeCallback(state); -} - void DtlsTransport::runRecvLoop() { const size_t maxMtu = 4096; @@ -362,9 +353,8 @@ void DtlsTransport::Cleanup() { DtlsTransport::DtlsTransport(shared_ptr lower, shared_ptr certificate, verifier_callback verifierCallback, state_callback stateChangeCallback) - : Transport(lower), mCertificate(certificate), mState(State::Disconnected), - mVerifierCallback(std::move(verifierCallback)), - mStateChangeCallback(std::move(stateChangeCallback)) { + : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate), + mVerifierCallback(std::move(verifierCallback)) { PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)"; @@ -445,10 +435,8 @@ bool DtlsTransport::stop() { return true; } -DtlsTransport::State DtlsTransport::state() const { return mState; } - bool DtlsTransport::send(message_ptr message) { - if (!message || mState != State::Connected) + if (!message || state() != State::Connected) return false; PLOG_VERBOSE << "Send size=" << message->size(); @@ -467,11 +455,6 @@ void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); } -void DtlsTransport::changeState(State state) { - if (mState.exchange(state) != state) - mStateChangeCallback(state); -} - void DtlsTransport::runRecvLoop() { const size_t maxMtu = 4096; try { @@ -490,7 +473,7 @@ void DtlsTransport::runRecvLoop() { auto message = *mIncomingQueue.pop(); BIO_write(mInBio, message->data(), message->size()); - if (mState == State::Connecting) { + if (state() == State::Connecting) { // Continue the handshake int ret = SSL_do_handshake(mSsl); if (!check_openssl_ret(mSsl, ret, "Handshake failed")) @@ -515,7 +498,7 @@ void DtlsTransport::runRecvLoop() { // No more messages pending, retransmit and rearm timeout if connecting std::optional duration; - if (mState == State::Connecting) { + if (state() == State::Connecting) { // Warning: This function breaks the usual return value convention int ret = DTLSv1_handle_timeout(mSsl); if (ret < 0) { @@ -525,7 +508,7 @@ void DtlsTransport::runRecvLoop() { } struct timeval timeout = {}; - if (mState == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) { + if (state() == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) { duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000); // Also handle handshake timeout manually because OpenSSL actually doesn't... // OpenSSL backs off exponentially in base 2 starting from the recommended 1s @@ -546,7 +529,7 @@ void DtlsTransport::runRecvLoop() { PLOG_ERROR << "DTLS recv: " << e.what(); } - if (mState == State::Connected) { + if (state() == State::Connected) { PLOG_INFO << "DTLS disconnected"; changeState(State::Disconnected); recv(nullptr); diff --git a/src/dtlstransport.hpp b/src/dtlstransport.hpp index 54a1310..9389069 100644 --- a/src/dtlstransport.hpp +++ b/src/dtlstransport.hpp @@ -46,33 +46,25 @@ public: static void Init(); static void Cleanup(); - enum class State { Disconnected, Connecting, Connected, Failed }; - using verifier_callback = std::function; - using state_callback = std::function; DtlsTransport(std::shared_ptr lower, std::shared_ptr certificate, verifier_callback verifierCallback, state_callback stateChangeCallback); ~DtlsTransport(); - State state() const; - bool stop() override; bool send(message_ptr message) override; // false if dropped private: void incoming(message_ptr message) override; - void changeState(State state); void runRecvLoop(); const std::shared_ptr mCertificate; Queue mIncomingQueue; - std::atomic mState; std::thread mRecvThread; verifier_callback mVerifierCallback; - state_callback mStateChangeCallback; #if USE_GNUTLS gnutls_session_t mSession; diff --git a/src/icetransport.cpp b/src/icetransport.cpp index 22ef719..b55b51d 100644 --- a/src/icetransport.cpp +++ b/src/icetransport.cpp @@ -48,9 +48,8 @@ namespace rtc { IceTransport::IceTransport(const Configuration &config, Description::Role role, candidate_callback candidateCallback, state_callback stateChangeCallback, gathering_state_callback gatheringStateChangeCallback) - : mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New), - mCandidateCallback(std::move(candidateCallback)), - mStateChangeCallback(std::move(stateChangeCallback)), + : Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"), + mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)), mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)), mAgent(nullptr, nullptr) { @@ -108,8 +107,6 @@ bool IceTransport::stop() { Description::Role IceTransport::role() const { return mRole; } -IceTransport::State IceTransport::state() const { return mState; } - Description IceTransport::getLocalDescription(Description::Type type) const { char sdp[JUICE_MAX_SDP_STRING_LEN]; if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0) @@ -161,7 +158,8 @@ std::optional IceTransport::getRemoteAddress() const { } bool IceTransport::send(message_ptr message) { - if (!message || (mState != State::Connected && mState != State::Completed)) + auto s = state(); + if (!message || (s != State::Connected && s != State::Completed)) return false; PLOG_VERBOSE << "Send size=" << message->size(); @@ -173,18 +171,29 @@ bool IceTransport::outgoing(message_ptr message) { message->size()) >= 0; } -void IceTransport::changeState(State state) { - if (mState.exchange(state) != state) - mStateChangeCallback(mState); -} - void IceTransport::changeGatheringState(GatheringState state) { if (mGatheringState.exchange(state) != state) mGatheringStateChangeCallback(mGatheringState); } void IceTransport::processStateChange(unsigned int state) { - changeState(static_cast(state)); + switch (state) { + case JUICE_STATE_DISCONNECTED: + changeState(State::Disconnected); + break; + case JUICE_STATE_CONNECTING: + changeState(State::Connecting); + break; + case JUICE_STATE_CONNECTED: + changeState(State::Connected); + break; + case JUICE_STATE_COMPLETED: + changeState(State::Completed); + break; + case JUICE_STATE_FAILED: + changeState(State::Failed); + break; + }; } void IceTransport::processCandidate(const string &candidate) { @@ -263,9 +272,8 @@ namespace rtc { IceTransport::IceTransport(const Configuration &config, Description::Role role, candidate_callback candidateCallback, state_callback stateChangeCallback, gathering_state_callback gatheringStateChangeCallback) - : mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New), - mCandidateCallback(std::move(candidateCallback)), - mStateChangeCallback(std::move(stateChangeCallback)), + : Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"), + mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)), mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)), mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr) { @@ -457,8 +465,6 @@ bool IceTransport::stop() { Description::Role IceTransport::role() const { return mRole; } -IceTransport::State IceTransport::state() const { return mState; } - Description IceTransport::getLocalDescription(Description::Type type) const { // RFC 8445: The initiating agent that started the ICE processing MUST take the controlling // role, and the other MUST take the controlled role. @@ -529,7 +535,8 @@ std::optional IceTransport::getRemoteAddress() const { } bool IceTransport::send(message_ptr message) { - if (!message || (mState != State::Connected && mState != State::Completed)) + auto s = state(); + if (!message || (s != State::Connected && s != State::Completed)) return false; PLOG_VERBOSE << "Send size=" << message->size(); @@ -541,11 +548,6 @@ bool IceTransport::outgoing(message_ptr message) { reinterpret_cast(message->data())) >= 0; } -void IceTransport::changeState(State state) { - if (mState.exchange(state) != state) - mStateChangeCallback(mState); -} - void IceTransport::changeGatheringState(GatheringState state) { if (mGatheringState.exchange(state) != state) mGatheringStateChangeCallback(mGatheringState); @@ -576,7 +578,23 @@ void IceTransport::processStateChange(unsigned int state) { mTimeoutId = 0; } - changeState(static_cast(state)); + switch (state) { + case NICE_COMPONENT_STATE_DISCONNECTED: + changeState(State::Disconnected); + break; + case NICE_COMPONENT_STATE_CONNECTING: + changeState(State::Connecting); + break; + case NICE_COMPONENT_STATE_CONNECTED: + changeState(State::Connected); + break; + case NICE_COMPONENT_STATE_READY: + changeState(State::Completed); + break; + case NICE_COMPONENT_STATE_FAILED: + changeState(State::Failed); + break; + }; } string IceTransport::AddressToString(const NiceAddress &addr) { diff --git a/src/icetransport.hpp b/src/icetransport.hpp index 46375be..ce0d3a4 100644 --- a/src/icetransport.hpp +++ b/src/icetransport.hpp @@ -40,29 +40,9 @@ namespace rtc { class IceTransport : public Transport { public: -#if USE_JUICE - enum class State : unsigned int{ - Disconnected = JUICE_STATE_DISCONNECTED, - Connecting = JUICE_STATE_CONNECTING, - Connected = JUICE_STATE_CONNECTED, - Completed = JUICE_STATE_COMPLETED, - Failed = JUICE_STATE_FAILED, - }; -#else - enum class State : unsigned int { - Disconnected = NICE_COMPONENT_STATE_DISCONNECTED, - Connecting = NICE_COMPONENT_STATE_CONNECTING, - Connected = NICE_COMPONENT_STATE_CONNECTED, - Completed = NICE_COMPONENT_STATE_READY, - Failed = NICE_COMPONENT_STATE_FAILED, - }; - - bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote); -#endif enum class GatheringState { New = 0, InProgress = 1, Complete = 2 }; using candidate_callback = std::function; - using state_callback = std::function; using gathering_state_callback = std::function; IceTransport(const Configuration &config, Description::Role role, @@ -71,7 +51,6 @@ public: ~IceTransport(); Description::Role role() const; - State state() const; GatheringState gatheringState() const; Description getLocalDescription(Description::Type type) const; void setRemoteDescription(const Description &description); @@ -84,10 +63,13 @@ public: bool stop() override; bool send(message_ptr message) override; // false if dropped +#if !USE_JUICE + bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote); +#endif + private: bool outgoing(message_ptr message) override; - void changeState(State state); void changeGatheringState(GatheringState state); void processStateChange(unsigned int state); @@ -98,11 +80,9 @@ private: Description::Role mRole; string mMid; std::chrono::milliseconds mTrickleTimeout; - std::atomic mState; std::atomic mGatheringState; candidate_callback mCandidateCallback; - state_callback mStateChangeCallback; gathering_state_callback mGatheringStateChangeCallback; #if USE_JUICE diff --git a/src/sctptransport.cpp b/src/sctptransport.cpp index a19d5b1..fd95386 100644 --- a/src/sctptransport.cpp +++ b/src/sctptransport.cpp @@ -67,9 +67,8 @@ void SctpTransport::Cleanup() { usrsctp_finish(); } SctpTransport::SctpTransport(std::shared_ptr lower, uint16_t port, message_callback recvCallback, amount_callback bufferedAmountCallback, state_callback stateChangeCallback) - : Transport(lower), mPort(port), mSendQueue(0, message_size_func), - mBufferedAmountCallback(std::move(bufferedAmountCallback)), - mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) { + : Transport(lower, std::move(stateChangeCallback)), mPort(port), + mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) { onRecv(recvCallback); PLOG_DEBUG << "Initializing SCTP transport"; @@ -176,8 +175,6 @@ SctpTransport::~SctpTransport() { usrsctp_deregister_address(this); } -SctpTransport::State SctpTransport::state() const { return mState; } - bool SctpTransport::stop() { if (!Transport::stop()) return false; @@ -265,7 +262,7 @@ void SctpTransport::incoming(message_ptr message) { // to be sent on our side (i.e. the local INIT) before proceeding. { std::unique_lock lock(mWriteMutex); - mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || mState != State::Connected; }); + mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || state() != State::Connected; }); } if (!message) { @@ -279,11 +276,6 @@ void SctpTransport::incoming(message_ptr message) { usrsctp_conninput(this, message->data(), message->size(), 0); } -void SctpTransport::changeState(State state) { - if (mState.exchange(state) != state) - mStateChangeCallback(state); -} - bool SctpTransport::trySendQueue() { // Requires mSendMutex to be locked while (auto next = mSendQueue.peek()) { @@ -298,7 +290,7 @@ bool SctpTransport::trySendQueue() { bool SctpTransport::trySendMessage(message_ptr message) { // Requires mSendMutex to be locked - if (!mSock || mState != State::Connected) + if (!mSock || state() != State::Connected) return false; uint32_t ppid; @@ -410,7 +402,7 @@ void SctpTransport::sendReset(uint16_t streamId) { if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) { std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp... mWrittenCondition.wait_for(lock, 1000ms, - [&]() { return mWritten || mState != State::Connected; }); + [&]() { return mWritten || state() != State::Connected; }); } else if (errno == EINVAL) { PLOG_VERBOSE << "SCTP stream " << streamId << " already reset"; } else { @@ -567,7 +559,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s PLOG_INFO << "SCTP connected"; changeState(State::Connected); } else { - if (mState == State::Connecting) { + if (state() == State::Connecting) { PLOG_ERROR << "SCTP connection failed"; changeState(State::Failed); } else { diff --git a/src/sctptransport.hpp b/src/sctptransport.hpp index e751277..c22e9be 100644 --- a/src/sctptransport.hpp +++ b/src/sctptransport.hpp @@ -38,17 +38,12 @@ public: static void Init(); static void Cleanup(); - enum class State { Disconnected, Connecting, Connected, Failed }; - using amount_callback = std::function; - using state_callback = std::function; SctpTransport(std::shared_ptr lower, uint16_t port, message_callback recvCallback, amount_callback bufferedAmountCallback, state_callback stateChangeCallback); ~SctpTransport(); - State state() const; - bool stop() override; bool send(message_ptr message) override; // false if buffered void close(unsigned int stream); @@ -76,7 +71,6 @@ private: void connect(); void shutdown(); void incoming(message_ptr message) override; - void changeState(State state); bool trySendQueue(); bool trySendMessage(message_ptr message); @@ -105,14 +99,11 @@ private: std::atomic mWritten = false; // written outside lock bool mWrittenOnce = false; - state_callback mStateChangeCallback; - std::atomic mState; + binary mPartialRecv, mPartialStringData, mPartialBinaryData; // Stats std::atomic mBytesSent = 0, mBytesReceived = 0; - binary mPartialRecv, mPartialStringData, mPartialBinaryData; - static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len, struct sctp_rcvinfo recv_info, int flags, void *user_data); static int SendCallback(struct socket *sock, uint32_t sb_free); diff --git a/src/tcptransport.cpp b/src/tcptransport.cpp index f33df3d..2dd8877 100644 --- a/src/tcptransport.cpp +++ b/src/tcptransport.cpp @@ -24,8 +24,8 @@ namespace rtc { using std::to_string; -TcpTransport::TcpTransport(const string &hostname, const string &service) - : mHostname(hostname), mService(service) { +TcpTransport::TcpTransport(const string &hostname, const string &service, state_callback callback) + : Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) { mThread = std::thread(&TcpTransport::runLoop, this); } diff --git a/src/tcptransport.hpp b/src/tcptransport.hpp index 7a2cecf..fdb9afd 100644 --- a/src/tcptransport.hpp +++ b/src/tcptransport.hpp @@ -34,7 +34,7 @@ namespace rtc { class TcpTransport : public Transport { public: - TcpTransport(const string &hostname, const string &service); + TcpTransport(const string &hostname, const string &service, state_callback callback); ~TcpTransport(); bool stop() override; diff --git a/src/tlstransport.cpp b/src/tlstransport.cpp index 4237534..8d73f54 100644 --- a/src/tlstransport.cpp +++ b/src/tlstransport.cpp @@ -61,7 +61,8 @@ void TlsTransport::Cleanup() { // Nothing to do } -TlsTransport::TlsTransport(shared_ptr lower, string host) : Transport(lower) { +TlsTransport::TlsTransport(shared_ptr lower, string host, state_callback callback) + : Transport(lower, std::move(callback)) { PLOG_DEBUG << "Initializing TLS transport (GnuTLS)"; @@ -82,6 +83,7 @@ TlsTransport::TlsTransport(shared_ptr lower, string host) : Transp gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size()); mRecvThread = std::thread(&TlsTransport::runRecvLoop, this); + registerIncoming(); } catch (...) { @@ -271,10 +273,10 @@ void TlsTransport::Cleanup() { // Nothing to do } -TlsTransport::TlsTransport(shared_ptr lower, string host) : Transport(lower) { +TlsTransport::TlsTransport(shared_ptr lower, string host, state_callback callback) + : Transport(lower, std::move(callback)) { PLOG_DEBUG << "Initializing TLS transport (OpenSSL)"; - GlobalInit(); if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible throw std::runtime_error("Failed to create SSL context"); diff --git a/src/tlstransport.hpp b/src/tlstransport.hpp index 6652953..75b889d 100644 --- a/src/tlstransport.hpp +++ b/src/tlstransport.hpp @@ -41,7 +41,7 @@ class TcpTransport; class TlsTransport : public Transport { public: - TlsTransport(std::shared_ptr lower, string host); + TlsTransport(std::shared_ptr lower, string host, state_callback callback); ~TlsTransport(); bool stop() override; diff --git a/src/transport.hpp b/src/transport.hpp index fd3acaa..a19e6a6 100644 --- a/src/transport.hpp +++ b/src/transport.hpp @@ -32,7 +32,13 @@ using namespace std::placeholders; class Transport { public: - Transport(std::shared_ptr lower = nullptr) : mLower(std::move(lower)) {} + enum class State { Disconnected, Connecting, Connected, Completed, Failed }; + using state_callback = std::function; + + Transport(std::shared_ptr lower = nullptr, state_callback callback = nullptr) + : mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) { + } + virtual ~Transport() { stop(); if (mLower) @@ -49,11 +55,16 @@ public: } void onRecv(message_callback callback) { mRecvCallback = std::move(callback); } + State state() const { return mState; } virtual bool send(message_ptr message) { return outgoing(message); } protected: void recv(message_ptr message) { mRecvCallback(message); } + void changeState(State state) { + if (mState.exchange(state) != state) + mStateChangeCallback(state); + } virtual void incoming(message_ptr message) { recv(message); } virtual bool outgoing(message_ptr message) { @@ -65,7 +76,10 @@ protected: private: std::shared_ptr mLower; + synchronized_callback mStateChangeCallback; synchronized_callback mRecvCallback; + + std::atomic mState = State::Disconnected; std::atomic mShutdown = false; }; diff --git a/src/websocket.cpp b/src/websocket.cpp index fc3106b..94f309e 100644 --- a/src/websocket.cpp +++ b/src/websocket.cpp @@ -1,100 +1,238 @@ -/************************************************************************* - * Copyright (C) 2017-2018 by Paul-Louis Ageneau * - * paul-louis (at) ageneau (dot) org * - * * - * This file is part of Plateform. * - * * - * Plateform is free software: you can redistribute it and/or modify * - * it under the terms of the GNU Affero General Public License as * - * published by the Free Software Foundation, either version 3 of * - * the License, or (at your option) any later version. * - * * - * Plateform is distributed in the hope that it will be useful, but * - * WITHOUT ANY WARRANTY; without even the implied warranty of * - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * - * GNU Affero General Public License for more details. * - * * - * You should have received a copy of the GNU Affero General Public * - * License along with Plateform. * - * If not, see . * - *************************************************************************/ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ -#include "net/websocket.hpp" +#if ENABLE_WEBSOCKET -#include -#include +#include "include.hpp" +#include "websocket.hpp" -const size_t DEFAULT_MAX_PAYLOAD_SIZE = 16384; // 16 KB +#include "tcptransport.hpp" +#include "tlstransport.hpp" +#include "wstransport.hpp" -namespace net { +#include -WebSocket::WebSocket(void) : mMaxPayloadSize(DEFAULT_MAX_PAYLOAD_SIZE) {} +namespace rtc { + +WebSocket::WebSocket() {} WebSocket::WebSocket(const string &url) : WebSocket() { open(url); } -WebSocket::~WebSocket(void) {} +WebSocket::~WebSocket() { close(); } void WebSocket::open(const string &url) { - close(); + static const char *rs = R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)"; + static std::regex regex(rs, std::regex::extended); - mUrl = url; - mThread = std::thread(&WebSocket::run, this); -} + std::smatch match; + if (!std::regex_match(url, match, regex)) + throw std::invalid_argument("Malformed WebSocket URL: " + url); -void WebSocket::close(void) { - mWebSocket.close(); - if (mThread.joinable()) - mThread.join(); - mConnected = false; -} + mScheme = match[2]; + if (mScheme != "ws" && mScheme != "wss") + throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme); -bool WebSocket::isOpen(void) const { return mConnected; } - -bool WebSocket::isClosed(void) const { return !mThread.joinable(); } - -void WebSocket::setMaxPayloadSize(size_t size) { mMaxPayloadSize = size; } - -bool WebSocket::send(const std::variant &data) { - if (!std::holds_alternative(data)) - throw std::runtime_error("WebSocket string messages are not supported"); - - mWebSocket.write(std::get(data)); - return true; -} - -std::optional> WebSocket::receive() { - if (!mQueue.empty()) - return mQueue.pop(); - else - return std::nullopt; -} - -void WebSocket::run(void) { - if (mUrl.empty()) - return; - - try { - mWebSocket.connect(mUrl); - - mConnected = true; - triggerOpen(); - - while (true) { - binary payload; - if (!mWebSocket.read(payload, mMaxPayloadSize)) - break; - mQueue.push(std::move(payload)); - triggerAvailable(mQueue.size()); - } - } catch (const std::exception &e) { - triggerError(e.what()); + mHost = match[4]; + if (auto pos = mHost.find(':'); pos != string::npos) { + mHostname = mHost.substr(0, pos); + mService = mHost.substr(pos + 1); + } else { + mHostname = mHost; + mService = mScheme == "ws" ? "80" : "443"; } - mWebSocket.close(); + mPath = match[5]; + if (string query = match[7]; !query.empty()) + mPath += "?" + query; - if (mConnected) - triggerClosed(); - mConnected = false; + initTcpTransport(); } -} // namespace net +void WebSocket::close() { + resetCallbacks(); + closeTransports(); +} + +void WebSocket::remoteClose() { + mIsOpen = false; + if (!mIsClosed.exchange(true)) + triggerClosed(); +} + +bool WebSocket::send(const std::variant &data) { + return std::visit( + [&](const auto &d) { + using T = std::decay_t; + constexpr auto type = std::is_same_v ? Message::String : Message::Binary; + auto *b = reinterpret_cast(d.data()); + return outgoing(std::make_shared(b, b + d.size(), type)); + }, + data); +} + +bool WebSocket::isOpen() const { return mIsOpen; } + +bool WebSocket::isClosed() const { return mIsClosed; } + +size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; } + +std::optional> WebSocket::receive() { return nullopt; } + +size_t WebSocket::availableAmount() const { return 0; } + +bool WebSocket::outgoing(mutable_message_ptr message) { + if (mIsClosed || !mWsTransport) + throw std::runtime_error("WebSocket is closed"); + + if (message->size() > maxMessageSize()) + throw std::runtime_error("Message size exceeds limit"); + + return mWsTransport->send(message); +} + +std::shared_ptr WebSocket::initTcpTransport() { + using State = TcpTransport::State; + try { + std::lock_guard lock(mInitMutex); + if (auto transport = std::atomic_load(&mTcpTransport)) + return transport; + + auto transport = std::make_shared(mHostname, mService, [this](State state) { + switch (state) { + case State::Connected: + if (mScheme == "ws") + initWsTransport(); + else + initTlsTransport(); + break; + case State::Failed: + // TODO + break; + case State::Disconnected: + // TODO + break; + default: + // Ignore + break; + } + }); + std::atomic_store(&mTcpTransport, transport); + return transport; + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + // TODO + throw std::runtime_error("TCP transport initialization failed"); + } +} + +std::shared_ptr WebSocket::initTlsTransport() { + using State = TlsTransport::State; + try { + std::lock_guard lock(mInitMutex); + if (auto transport = std::atomic_load(&mTlsTransport)) + return transport; + + auto lower = std::atomic_load(&mTcpTransport); + auto transport = std::make_shared(lower, mHost, [this](State state) { + switch (state) { + case State::Connected: + initWsTransport(); + break; + case State::Failed: + // TODO + break; + case State::Disconnected: + // TODO + break; + default: + // Ignore + break; + } + }); + std::atomic_store(&mTlsTransport, transport); + return transport; + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + // TODO + throw std::runtime_error("TLS transport initialization failed"); + } +} + +std::shared_ptr WebSocket::initWsTransport() { + using State = WsTransport::State; + try { + std::lock_guard lock(mInitMutex); + if (auto transport = std::atomic_load(&mWsTransport)) + return transport; + + std::shared_ptr lower = std::atomic_load(&mTlsTransport); + if (!lower) + lower = std::atomic_load(&mTcpTransport); + auto transport = std::make_shared(lower, mHost, mPath, [this](State state) { + switch (state) { + case State::Connected: + triggerOpen(); + break; + case State::Failed: + // TODO + break; + case State::Disconnected: + // TODO + break; + default: + // Ignore + break; + } + }); + std::atomic_store(&mWsTransport, transport); + return transport; + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + // TODO + throw std::runtime_error("WebSocket transport initialization failed"); + } +} + +void closeTransports() { + mIsOpen = false; + mIsClosed = true; + + // Pass the references to a thread, allowing to terminate a transport from its own thread + auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr)); + auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr)); + auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr)); + if (ws || dtls || tcp) { + std::thread t([ws, dtls, tcp]() mutable { + if (ws) + ws->stop(); + if (dtls) + dtls->stop(); + if (tcp) + tcp->stop(); + + ws.reset(); + dtls.reset(); + tcp.reset(); + }); + t.detach(); + } +} + +} // namespace rtc + +#endif diff --git a/src/wstransport.cpp b/src/wstransport.cpp index e57013f..1962968 100644 --- a/src/wstransport.cpp +++ b/src/wstransport.cpp @@ -53,11 +53,12 @@ using std::to_string; using random_bytes_engine = std::independent_bits_engine; -WsTransport::WsTransport(std::shared_ptr lower, string host, string path) - : Transport(lower), mHost(std::move(host)), mPath(std::move(path)) {} +WsTransport::WsTransport(std::shared_ptr lower, string host, string path, + state_callback callback) + : Transport(lower, std::move(callback)), mHost(std::move(host)), mPath(std::move(path)) { -WsTransport::WsTransport(std::shared_ptr lower, string host, string path) - : Transport(lower), mHost(std::move(host)), mPath(std::move(path)) {} + registerIncoming(); +} WsTransport::~WsTransport() {} diff --git a/src/wstransport.hpp b/src/wstransport.hpp index fd5ff01..21b48a0 100644 --- a/src/wstransport.hpp +++ b/src/wstransport.hpp @@ -31,8 +31,8 @@ class TlsTransport; class WsTransport : public Transport { public: - WsTransport(std::shared_ptr lower, string host, string path); - WsTransport(std::shared_ptr lower, string host, string path); + WsTransport(std::shared_ptr lower, string host, string path, + state_callback callback); ~WsTransport(); void stop() override;