Added WebSocket

This commit is contained in:
Paul-Louis Ageneau
2020-03-15 19:30:46 +01:00
parent b06b33234b
commit c1f91b2fff
17 changed files with 406 additions and 207 deletions

View File

@ -36,10 +36,11 @@ set(LIBDATACHANNEL_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/src/rtc.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp
)
set(LIBDATACHANNEL_HEADERS

View File

@ -96,8 +96,6 @@ public:
std::optional<std::chrono::milliseconds> rtt();
private:
init_token mInitToken = Init::Token();
std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
std::shared_ptr<DtlsTransport> initDtlsTransport();
std::shared_ptr<SctpTransport> initSctpTransport();
@ -128,6 +126,8 @@ private:
const Configuration mConfig;
const std::shared_ptr<Certificate> mCertificate;
init_token mInitToken = Init::Token();
std::optional<Description> mLocalDescription, mRemoteDescription;
mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;

87
include/rtc/websocket.hpp Normal file
View File

@ -0,0 +1,87 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#ifndef RTC_WEBSOCKET_H
#define RTC_WEBSOCKET_H
#if ENABLE_WEBSOCKET
#include "channel.hpp"
#include "include.hpp"
#include "init.hpp"
#include "message.hpp"
#include "queue.hpp"
#include <atomic>
#include <optional>
#include <thread>
#include <variant>
namespace rtc {
class TcpTransport;
class TlsTransport;
class WsTransport;
class WebSocket final : public Channel, public std::enable_shared_from_this<WebSocket> {
public:
WebSocket();
WebSocket(const string &url);
~WebSocket();
void open(const string &url);
void close() override;
bool send(const std::variant<binary, string> &data) override;
bool isOpen() const override;
bool isClosed() const override;
size_t maxMessageSize() const override;
// Extended API
std::optional<std::variant<binary, string>> receive() override;
size_t availableAmount() const override; // total size available to receive
private:
void remoteClose();
bool outgoing(mutable_message_ptr message);
std::shared_ptr<TcpTransport> initTcpTransport();
std::shared_ptr<TlsTransport> initTlsTransport();
std::shared_ptr<WsTransport> initWsTransport();
void closeTransports();
init_token mInitToken = Init::Token();
std::shared_ptr<TcpTransport> mTcpTransport;
std::shared_ptr<TlsTransport> mTlsTransport;
std::shared_ptr<WsTransport> mWsTransport;
std::recursive_mutex mInitMutex;
string mScheme, mHost, mHostname, mService, mPath;
std::atomic<bool> mIsOpen = false;
std::atomic<bool> mIsClosed = false;
Queue<message_ptr> mRecvQueue;
std::atomic<size_t> mRecvAmount = 0;
};
} // namespace rtc
#endif
#endif // NET_WEBSOCKET_H

View File

@ -62,11 +62,9 @@ void DtlsTransport::Cleanup() {
}
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
verifier_callback verifierCallback,
state_callback stateChangeCallback)
: Transport(lower), mCertificate(certificate), mState(State::Disconnected),
mVerifierCallback(std::move(verifierCallback)),
mStateChangeCallback(std::move(stateChangeCallback)) {
verifier_callback verifierCallback, state_callback stateChangeCallback)
: Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
mVerifierCallback(std::move(verifierCallback)) {
PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
@ -113,8 +111,6 @@ DtlsTransport::~DtlsTransport() {
gnutls_deinit(mSession);
}
DtlsTransport::State DtlsTransport::state() const { return mState; }
bool DtlsTransport::stop() {
if (!Transport::stop())
return false;
@ -126,7 +122,7 @@ bool DtlsTransport::stop() {
}
bool DtlsTransport::send(message_ptr message) {
if (!message || mState != State::Connected)
if (!message || state() != State::Connected)
return false;
PLOG_VERBOSE << "Send size=" << message->size();
@ -152,11 +148,6 @@ void DtlsTransport::incoming(message_ptr message) {
mIncomingQueue.push(message);
}
void DtlsTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
void DtlsTransport::runRecvLoop() {
const size_t maxMtu = 4096;
@ -362,9 +353,8 @@ void DtlsTransport::Cleanup() {
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
verifier_callback verifierCallback, state_callback stateChangeCallback)
: Transport(lower), mCertificate(certificate), mState(State::Disconnected),
mVerifierCallback(std::move(verifierCallback)),
mStateChangeCallback(std::move(stateChangeCallback)) {
: Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
mVerifierCallback(std::move(verifierCallback)) {
PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
@ -445,10 +435,8 @@ bool DtlsTransport::stop() {
return true;
}
DtlsTransport::State DtlsTransport::state() const { return mState; }
bool DtlsTransport::send(message_ptr message) {
if (!message || mState != State::Connected)
if (!message || state() != State::Connected)
return false;
PLOG_VERBOSE << "Send size=" << message->size();
@ -467,11 +455,6 @@ void DtlsTransport::incoming(message_ptr message) {
mIncomingQueue.push(message);
}
void DtlsTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
void DtlsTransport::runRecvLoop() {
const size_t maxMtu = 4096;
try {
@ -490,7 +473,7 @@ void DtlsTransport::runRecvLoop() {
auto message = *mIncomingQueue.pop();
BIO_write(mInBio, message->data(), message->size());
if (mState == State::Connecting) {
if (state() == State::Connecting) {
// Continue the handshake
int ret = SSL_do_handshake(mSsl);
if (!check_openssl_ret(mSsl, ret, "Handshake failed"))
@ -515,7 +498,7 @@ void DtlsTransport::runRecvLoop() {
// No more messages pending, retransmit and rearm timeout if connecting
std::optional<milliseconds> duration;
if (mState == State::Connecting) {
if (state() == State::Connecting) {
// Warning: This function breaks the usual return value convention
int ret = DTLSv1_handle_timeout(mSsl);
if (ret < 0) {
@ -525,7 +508,7 @@ void DtlsTransport::runRecvLoop() {
}
struct timeval timeout = {};
if (mState == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
if (state() == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
// Also handle handshake timeout manually because OpenSSL actually doesn't...
// OpenSSL backs off exponentially in base 2 starting from the recommended 1s
@ -546,7 +529,7 @@ void DtlsTransport::runRecvLoop() {
PLOG_ERROR << "DTLS recv: " << e.what();
}
if (mState == State::Connected) {
if (state() == State::Connected) {
PLOG_INFO << "DTLS disconnected";
changeState(State::Disconnected);
recv(nullptr);

View File

@ -46,33 +46,25 @@ public:
static void Init();
static void Cleanup();
enum class State { Disconnected, Connecting, Connected, Failed };
using verifier_callback = std::function<bool(const std::string &fingerprint)>;
using state_callback = std::function<void(State state)>;
DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
verifier_callback verifierCallback, state_callback stateChangeCallback);
~DtlsTransport();
State state() const;
bool stop() override;
bool send(message_ptr message) override; // false if dropped
private:
void incoming(message_ptr message) override;
void changeState(State state);
void runRecvLoop();
const std::shared_ptr<Certificate> mCertificate;
Queue<message_ptr> mIncomingQueue;
std::atomic<State> mState;
std::thread mRecvThread;
verifier_callback mVerifierCallback;
state_callback mStateChangeCallback;
#if USE_GNUTLS
gnutls_session_t mSession;

View File

@ -48,9 +48,8 @@ namespace rtc {
IceTransport::IceTransport(const Configuration &config, Description::Role role,
candidate_callback candidateCallback, state_callback stateChangeCallback,
gathering_state_callback gatheringStateChangeCallback)
: mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
mCandidateCallback(std::move(candidateCallback)),
mStateChangeCallback(std::move(stateChangeCallback)),
: Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
mAgent(nullptr, nullptr) {
@ -108,8 +107,6 @@ bool IceTransport::stop() {
Description::Role IceTransport::role() const { return mRole; }
IceTransport::State IceTransport::state() const { return mState; }
Description IceTransport::getLocalDescription(Description::Type type) const {
char sdp[JUICE_MAX_SDP_STRING_LEN];
if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0)
@ -161,7 +158,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
}
bool IceTransport::send(message_ptr message) {
if (!message || (mState != State::Connected && mState != State::Completed))
auto s = state();
if (!message || (s != State::Connected && s != State::Completed))
return false;
PLOG_VERBOSE << "Send size=" << message->size();
@ -173,18 +171,29 @@ bool IceTransport::outgoing(message_ptr message) {
message->size()) >= 0;
}
void IceTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(mState);
}
void IceTransport::changeGatheringState(GatheringState state) {
if (mGatheringState.exchange(state) != state)
mGatheringStateChangeCallback(mGatheringState);
}
void IceTransport::processStateChange(unsigned int state) {
changeState(static_cast<State>(state));
switch (state) {
case JUICE_STATE_DISCONNECTED:
changeState(State::Disconnected);
break;
case JUICE_STATE_CONNECTING:
changeState(State::Connecting);
break;
case JUICE_STATE_CONNECTED:
changeState(State::Connected);
break;
case JUICE_STATE_COMPLETED:
changeState(State::Completed);
break;
case JUICE_STATE_FAILED:
changeState(State::Failed);
break;
};
}
void IceTransport::processCandidate(const string &candidate) {
@ -263,9 +272,8 @@ namespace rtc {
IceTransport::IceTransport(const Configuration &config, Description::Role role,
candidate_callback candidateCallback, state_callback stateChangeCallback,
gathering_state_callback gatheringStateChangeCallback)
: mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
mCandidateCallback(std::move(candidateCallback)),
mStateChangeCallback(std::move(stateChangeCallback)),
: Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr) {
@ -457,8 +465,6 @@ bool IceTransport::stop() {
Description::Role IceTransport::role() const { return mRole; }
IceTransport::State IceTransport::state() const { return mState; }
Description IceTransport::getLocalDescription(Description::Type type) const {
// RFC 8445: The initiating agent that started the ICE processing MUST take the controlling
// role, and the other MUST take the controlled role.
@ -529,7 +535,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
}
bool IceTransport::send(message_ptr message) {
if (!message || (mState != State::Connected && mState != State::Completed))
auto s = state();
if (!message || (s != State::Connected && s != State::Completed))
return false;
PLOG_VERBOSE << "Send size=" << message->size();
@ -541,11 +548,6 @@ bool IceTransport::outgoing(message_ptr message) {
reinterpret_cast<const char *>(message->data())) >= 0;
}
void IceTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(mState);
}
void IceTransport::changeGatheringState(GatheringState state) {
if (mGatheringState.exchange(state) != state)
mGatheringStateChangeCallback(mGatheringState);
@ -576,7 +578,23 @@ void IceTransport::processStateChange(unsigned int state) {
mTimeoutId = 0;
}
changeState(static_cast<State>(state));
switch (state) {
case NICE_COMPONENT_STATE_DISCONNECTED:
changeState(State::Disconnected);
break;
case NICE_COMPONENT_STATE_CONNECTING:
changeState(State::Connecting);
break;
case NICE_COMPONENT_STATE_CONNECTED:
changeState(State::Connected);
break;
case NICE_COMPONENT_STATE_READY:
changeState(State::Completed);
break;
case NICE_COMPONENT_STATE_FAILED:
changeState(State::Failed);
break;
};
}
string IceTransport::AddressToString(const NiceAddress &addr) {

View File

@ -40,29 +40,9 @@ namespace rtc {
class IceTransport : public Transport {
public:
#if USE_JUICE
enum class State : unsigned int{
Disconnected = JUICE_STATE_DISCONNECTED,
Connecting = JUICE_STATE_CONNECTING,
Connected = JUICE_STATE_CONNECTED,
Completed = JUICE_STATE_COMPLETED,
Failed = JUICE_STATE_FAILED,
};
#else
enum class State : unsigned int {
Disconnected = NICE_COMPONENT_STATE_DISCONNECTED,
Connecting = NICE_COMPONENT_STATE_CONNECTING,
Connected = NICE_COMPONENT_STATE_CONNECTED,
Completed = NICE_COMPONENT_STATE_READY,
Failed = NICE_COMPONENT_STATE_FAILED,
};
bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
#endif
enum class GatheringState { New = 0, InProgress = 1, Complete = 2 };
using candidate_callback = std::function<void(const Candidate &candidate)>;
using state_callback = std::function<void(State state)>;
using gathering_state_callback = std::function<void(GatheringState state)>;
IceTransport(const Configuration &config, Description::Role role,
@ -71,7 +51,6 @@ public:
~IceTransport();
Description::Role role() const;
State state() const;
GatheringState gatheringState() const;
Description getLocalDescription(Description::Type type) const;
void setRemoteDescription(const Description &description);
@ -84,10 +63,13 @@ public:
bool stop() override;
bool send(message_ptr message) override; // false if dropped
#if !USE_JUICE
bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
#endif
private:
bool outgoing(message_ptr message) override;
void changeState(State state);
void changeGatheringState(GatheringState state);
void processStateChange(unsigned int state);
@ -98,11 +80,9 @@ private:
Description::Role mRole;
string mMid;
std::chrono::milliseconds mTrickleTimeout;
std::atomic<State> mState;
std::atomic<GatheringState> mGatheringState;
candidate_callback mCandidateCallback;
state_callback mStateChangeCallback;
gathering_state_callback mGatheringStateChangeCallback;
#if USE_JUICE

View File

@ -67,9 +67,8 @@ void SctpTransport::Cleanup() { usrsctp_finish(); }
SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
message_callback recvCallback, amount_callback bufferedAmountCallback,
state_callback stateChangeCallback)
: Transport(lower), mPort(port), mSendQueue(0, message_size_func),
mBufferedAmountCallback(std::move(bufferedAmountCallback)),
mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) {
: Transport(lower, std::move(stateChangeCallback)), mPort(port),
mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
onRecv(recvCallback);
PLOG_DEBUG << "Initializing SCTP transport";
@ -176,8 +175,6 @@ SctpTransport::~SctpTransport() {
usrsctp_deregister_address(this);
}
SctpTransport::State SctpTransport::state() const { return mState; }
bool SctpTransport::stop() {
if (!Transport::stop())
return false;
@ -265,7 +262,7 @@ void SctpTransport::incoming(message_ptr message) {
// to be sent on our side (i.e. the local INIT) before proceeding.
{
std::unique_lock lock(mWriteMutex);
mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || mState != State::Connected; });
mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || state() != State::Connected; });
}
if (!message) {
@ -279,11 +276,6 @@ void SctpTransport::incoming(message_ptr message) {
usrsctp_conninput(this, message->data(), message->size(), 0);
}
void SctpTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
bool SctpTransport::trySendQueue() {
// Requires mSendMutex to be locked
while (auto next = mSendQueue.peek()) {
@ -298,7 +290,7 @@ bool SctpTransport::trySendQueue() {
bool SctpTransport::trySendMessage(message_ptr message) {
// Requires mSendMutex to be locked
if (!mSock || mState != State::Connected)
if (!mSock || state() != State::Connected)
return false;
uint32_t ppid;
@ -410,7 +402,7 @@ void SctpTransport::sendReset(uint16_t streamId) {
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) {
std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp...
mWrittenCondition.wait_for(lock, 1000ms,
[&]() { return mWritten || mState != State::Connected; });
[&]() { return mWritten || state() != State::Connected; });
} else if (errno == EINVAL) {
PLOG_VERBOSE << "SCTP stream " << streamId << " already reset";
} else {
@ -567,7 +559,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
PLOG_INFO << "SCTP connected";
changeState(State::Connected);
} else {
if (mState == State::Connecting) {
if (state() == State::Connecting) {
PLOG_ERROR << "SCTP connection failed";
changeState(State::Failed);
} else {

View File

@ -38,17 +38,12 @@ public:
static void Init();
static void Cleanup();
enum class State { Disconnected, Connecting, Connected, Failed };
using amount_callback = std::function<void(uint16_t streamId, size_t amount)>;
using state_callback = std::function<void(State state)>;
SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recvCallback,
amount_callback bufferedAmountCallback, state_callback stateChangeCallback);
~SctpTransport();
State state() const;
bool stop() override;
bool send(message_ptr message) override; // false if buffered
void close(unsigned int stream);
@ -76,7 +71,6 @@ private:
void connect();
void shutdown();
void incoming(message_ptr message) override;
void changeState(State state);
bool trySendQueue();
bool trySendMessage(message_ptr message);
@ -105,14 +99,11 @@ private:
std::atomic<bool> mWritten = false; // written outside lock
bool mWrittenOnce = false;
state_callback mStateChangeCallback;
std::atomic<State> mState;
binary mPartialRecv, mPartialStringData, mPartialBinaryData;
// Stats
std::atomic<size_t> mBytesSent = 0, mBytesReceived = 0;
binary mPartialRecv, mPartialStringData, mPartialBinaryData;
static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
struct sctp_rcvinfo recv_info, int flags, void *user_data);
static int SendCallback(struct socket *sock, uint32_t sb_free);

View File

@ -24,8 +24,8 @@ namespace rtc {
using std::to_string;
TcpTransport::TcpTransport(const string &hostname, const string &service)
: mHostname(hostname), mService(service) {
TcpTransport::TcpTransport(const string &hostname, const string &service, state_callback callback)
: Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) {
mThread = std::thread(&TcpTransport::runLoop, this);
}

View File

@ -34,7 +34,7 @@ namespace rtc {
class TcpTransport : public Transport {
public:
TcpTransport(const string &hostname, const string &service);
TcpTransport(const string &hostname, const string &service, state_callback callback);
~TcpTransport();
bool stop() override;

View File

@ -61,7 +61,8 @@ void TlsTransport::Cleanup() {
// Nothing to do
}
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transport(lower) {
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
: Transport(lower, std::move(callback)) {
PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
@ -82,6 +83,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transp
gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size());
mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
registerIncoming();
} catch (...) {
@ -271,10 +273,10 @@ void TlsTransport::Cleanup() {
// Nothing to do
}
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transport(lower) {
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
: Transport(lower, std::move(callback)) {
PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
GlobalInit();
if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
throw std::runtime_error("Failed to create SSL context");

View File

@ -41,7 +41,7 @@ class TcpTransport;
class TlsTransport : public Transport {
public:
TlsTransport(std::shared_ptr<TcpTransport> lower, string host);
TlsTransport(std::shared_ptr<TcpTransport> lower, string host, state_callback callback);
~TlsTransport();
bool stop() override;

View File

@ -32,7 +32,13 @@ using namespace std::placeholders;
class Transport {
public:
Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {}
enum class State { Disconnected, Connecting, Connected, Completed, Failed };
using state_callback = std::function<void(State state)>;
Transport(std::shared_ptr<Transport> lower = nullptr, state_callback callback = nullptr)
: mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) {
}
virtual ~Transport() {
stop();
if (mLower)
@ -49,11 +55,16 @@ public:
}
void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
State state() const { return mState; }
virtual bool send(message_ptr message) { return outgoing(message); }
protected:
void recv(message_ptr message) { mRecvCallback(message); }
void changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
virtual void incoming(message_ptr message) { recv(message); }
virtual bool outgoing(message_ptr message) {
@ -65,7 +76,10 @@ protected:
private:
std::shared_ptr<Transport> mLower;
synchronized_callback<State> mStateChangeCallback;
synchronized_callback<message_ptr> mRecvCallback;
std::atomic<State> mState = State::Disconnected;
std::atomic<bool> mShutdown = false;
};

View File

@ -1,100 +1,238 @@
/*************************************************************************
* Copyright (C) 2017-2018 by Paul-Louis Ageneau *
* paul-louis (at) ageneau (dot) org *
* *
* This file is part of Plateform. *
* *
* Plateform is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Affero General Public License as *
* published by the Free Software Foundation, either version 3 of *
* the License, or (at your option) any later version. *
* *
* Plateform is distributed in the hope that it will be useful, but *
* WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Affero General Public License for more details. *
* *
* You should have received a copy of the GNU Affero General Public *
* License along with Plateform. *
* If not, see <http://www.gnu.org/licenses/>. *
*************************************************************************/
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#include "net/websocket.hpp"
#if ENABLE_WEBSOCKET
#include <exception>
#include <iostream>
#include "include.hpp"
#include "websocket.hpp"
const size_t DEFAULT_MAX_PAYLOAD_SIZE = 16384; // 16 KB
#include "tcptransport.hpp"
#include "tlstransport.hpp"
#include "wstransport.hpp"
namespace net {
#include <regex>
WebSocket::WebSocket(void) : mMaxPayloadSize(DEFAULT_MAX_PAYLOAD_SIZE) {}
namespace rtc {
WebSocket::WebSocket() {}
WebSocket::WebSocket(const string &url) : WebSocket() { open(url); }
WebSocket::~WebSocket(void) {}
WebSocket::~WebSocket() { close(); }
void WebSocket::open(const string &url) {
close();
static const char *rs = R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)";
static std::regex regex(rs, std::regex::extended);
mUrl = url;
mThread = std::thread(&WebSocket::run, this);
std::smatch match;
if (!std::regex_match(url, match, regex))
throw std::invalid_argument("Malformed WebSocket URL: " + url);
mScheme = match[2];
if (mScheme != "ws" && mScheme != "wss")
throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme);
mHost = match[4];
if (auto pos = mHost.find(':'); pos != string::npos) {
mHostname = mHost.substr(0, pos);
mService = mHost.substr(pos + 1);
} else {
mHostname = mHost;
mService = mScheme == "ws" ? "80" : "443";
}
mPath = match[5];
if (string query = match[7]; !query.empty())
mPath += "?" + query;
initTcpTransport();
}
void WebSocket::close(void) {
mWebSocket.close();
if (mThread.joinable())
mThread.join();
mConnected = false;
void WebSocket::close() {
resetCallbacks();
closeTransports();
}
bool WebSocket::isOpen(void) const { return mConnected; }
bool WebSocket::isClosed(void) const { return !mThread.joinable(); }
void WebSocket::setMaxPayloadSize(size_t size) { mMaxPayloadSize = size; }
void WebSocket::remoteClose() {
mIsOpen = false;
if (!mIsClosed.exchange(true))
triggerClosed();
}
bool WebSocket::send(const std::variant<binary, string> &data) {
if (!std::holds_alternative<binary>(data))
throw std::runtime_error("WebSocket string messages are not supported");
mWebSocket.write(std::get<binary>(data));
return true;
return std::visit(
[&](const auto &d) {
using T = std::decay_t<decltype(d)>;
constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
auto *b = reinterpret_cast<const byte *>(d.data());
return outgoing(std::make_shared<Message>(b, b + d.size(), type));
},
data);
}
std::optional<std::variant<binary, string>> WebSocket::receive() {
if (!mQueue.empty())
return mQueue.pop();
else
return std::nullopt;
bool WebSocket::isOpen() const { return mIsOpen; }
bool WebSocket::isClosed() const { return mIsClosed; }
size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
std::optional<std::variant<binary, string>> WebSocket::receive() { return nullopt; }
size_t WebSocket::availableAmount() const { return 0; }
bool WebSocket::outgoing(mutable_message_ptr message) {
if (mIsClosed || !mWsTransport)
throw std::runtime_error("WebSocket is closed");
if (message->size() > maxMessageSize())
throw std::runtime_error("Message size exceeds limit");
return mWsTransport->send(message);
}
void WebSocket::run(void) {
if (mUrl.empty())
return;
std::shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
using State = TcpTransport::State;
try {
mWebSocket.connect(mUrl);
std::lock_guard lock(mInitMutex);
if (auto transport = std::atomic_load(&mTcpTransport))
return transport;
mConnected = true;
triggerOpen();
while (true) {
binary payload;
if (!mWebSocket.read(payload, mMaxPayloadSize))
auto transport = std::make_shared<TcpTransport>(mHostname, mService, [this](State state) {
switch (state) {
case State::Connected:
if (mScheme == "ws")
initWsTransport();
else
initTlsTransport();
break;
case State::Failed:
// TODO
break;
case State::Disconnected:
// TODO
break;
default:
// Ignore
break;
mQueue.push(std::move(payload));
triggerAvailable(mQueue.size());
}
});
std::atomic_store(&mTcpTransport, transport);
return transport;
} catch (const std::exception &e) {
triggerError(e.what());
PLOG_ERROR << e.what();
// TODO
throw std::runtime_error("TCP transport initialization failed");
}
mWebSocket.close();
if (mConnected)
triggerClosed();
mConnected = false;
}
} // namespace net
std::shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
using State = TlsTransport::State;
try {
std::lock_guard lock(mInitMutex);
if (auto transport = std::atomic_load(&mTlsTransport))
return transport;
auto lower = std::atomic_load(&mTcpTransport);
auto transport = std::make_shared<TlsTransport>(lower, mHost, [this](State state) {
switch (state) {
case State::Connected:
initWsTransport();
break;
case State::Failed:
// TODO
break;
case State::Disconnected:
// TODO
break;
default:
// Ignore
break;
}
});
std::atomic_store(&mTlsTransport, transport);
return transport;
} catch (const std::exception &e) {
PLOG_ERROR << e.what();
// TODO
throw std::runtime_error("TLS transport initialization failed");
}
}
std::shared_ptr<WsTransport> WebSocket::initWsTransport() {
using State = WsTransport::State;
try {
std::lock_guard lock(mInitMutex);
if (auto transport = std::atomic_load(&mWsTransport))
return transport;
std::shared_ptr<Transport> lower = std::atomic_load(&mTlsTransport);
if (!lower)
lower = std::atomic_load(&mTcpTransport);
auto transport = std::make_shared<WsTransport>(lower, mHost, mPath, [this](State state) {
switch (state) {
case State::Connected:
triggerOpen();
break;
case State::Failed:
// TODO
break;
case State::Disconnected:
// TODO
break;
default:
// Ignore
break;
}
});
std::atomic_store(&mWsTransport, transport);
return transport;
} catch (const std::exception &e) {
PLOG_ERROR << e.what();
// TODO
throw std::runtime_error("WebSocket transport initialization failed");
}
}
void closeTransports() {
mIsOpen = false;
mIsClosed = true;
// Pass the references to a thread, allowing to terminate a transport from its own thread
auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
if (ws || dtls || tcp) {
std::thread t([ws, dtls, tcp]() mutable {
if (ws)
ws->stop();
if (dtls)
dtls->stop();
if (tcp)
tcp->stop();
ws.reset();
dtls.reset();
tcp.reset();
});
t.detach();
}
}
} // namespace rtc
#endif

View File

@ -53,11 +53,12 @@ using std::to_string;
using random_bytes_engine =
std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>;
WsTransport::WsTransport(std::shared_ptr<TcpTransport> lower, string host, string path)
: Transport(lower), mHost(std::move(host)), mPath(std::move(path)) {}
WsTransport::WsTransport(std::shared_ptr<Transport> lower, string host, string path,
state_callback callback)
: Transport(lower, std::move(callback)), mHost(std::move(host)), mPath(std::move(path)) {
WsTransport::WsTransport(std::shared_ptr<TlsTransport> lower, string host, string path)
: Transport(lower), mHost(std::move(host)), mPath(std::move(path)) {}
registerIncoming();
}
WsTransport::~WsTransport() {}

View File

@ -31,8 +31,8 @@ class TlsTransport;
class WsTransport : public Transport {
public:
WsTransport(std::shared_ptr<TcpTransport> lower, string host, string path);
WsTransport(std::shared_ptr<TlsTransport> lower, string host, string path);
WsTransport(std::shared_ptr<Transport> lower, string host, string path,
state_callback callback);
~WsTransport();
void stop() override;