Added WebSocket

This commit is contained in:
Paul-Louis Ageneau
2020-03-15 19:30:46 +01:00
parent b06b33234b
commit c1f91b2fff
17 changed files with 406 additions and 207 deletions

View File

@ -36,10 +36,11 @@ set(LIBDATACHANNEL_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/src/rtc.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/rtc.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp
) )
set(LIBDATACHANNEL_HEADERS set(LIBDATACHANNEL_HEADERS

View File

@ -96,8 +96,6 @@ public:
std::optional<std::chrono::milliseconds> rtt(); std::optional<std::chrono::milliseconds> rtt();
private: private:
init_token mInitToken = Init::Token();
std::shared_ptr<IceTransport> initIceTransport(Description::Role role); std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
std::shared_ptr<DtlsTransport> initDtlsTransport(); std::shared_ptr<DtlsTransport> initDtlsTransport();
std::shared_ptr<SctpTransport> initSctpTransport(); std::shared_ptr<SctpTransport> initSctpTransport();
@ -128,6 +126,8 @@ private:
const Configuration mConfig; const Configuration mConfig;
const std::shared_ptr<Certificate> mCertificate; const std::shared_ptr<Certificate> mCertificate;
init_token mInitToken = Init::Token();
std::optional<Description> mLocalDescription, mRemoteDescription; std::optional<Description> mLocalDescription, mRemoteDescription;
mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex; mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;

87
include/rtc/websocket.hpp Normal file
View File

@ -0,0 +1,87 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#ifndef RTC_WEBSOCKET_H
#define RTC_WEBSOCKET_H
#if ENABLE_WEBSOCKET
#include "channel.hpp"
#include "include.hpp"
#include "init.hpp"
#include "message.hpp"
#include "queue.hpp"
#include <atomic>
#include <optional>
#include <thread>
#include <variant>
namespace rtc {
class TcpTransport;
class TlsTransport;
class WsTransport;
class WebSocket final : public Channel, public std::enable_shared_from_this<WebSocket> {
public:
WebSocket();
WebSocket(const string &url);
~WebSocket();
void open(const string &url);
void close() override;
bool send(const std::variant<binary, string> &data) override;
bool isOpen() const override;
bool isClosed() const override;
size_t maxMessageSize() const override;
// Extended API
std::optional<std::variant<binary, string>> receive() override;
size_t availableAmount() const override; // total size available to receive
private:
void remoteClose();
bool outgoing(mutable_message_ptr message);
std::shared_ptr<TcpTransport> initTcpTransport();
std::shared_ptr<TlsTransport> initTlsTransport();
std::shared_ptr<WsTransport> initWsTransport();
void closeTransports();
init_token mInitToken = Init::Token();
std::shared_ptr<TcpTransport> mTcpTransport;
std::shared_ptr<TlsTransport> mTlsTransport;
std::shared_ptr<WsTransport> mWsTransport;
std::recursive_mutex mInitMutex;
string mScheme, mHost, mHostname, mService, mPath;
std::atomic<bool> mIsOpen = false;
std::atomic<bool> mIsClosed = false;
Queue<message_ptr> mRecvQueue;
std::atomic<size_t> mRecvAmount = 0;
};
} // namespace rtc
#endif
#endif // NET_WEBSOCKET_H

View File

@ -62,11 +62,9 @@ void DtlsTransport::Cleanup() {
} }
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate, DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
verifier_callback verifierCallback, verifier_callback verifierCallback, state_callback stateChangeCallback)
state_callback stateChangeCallback) : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
: Transport(lower), mCertificate(certificate), mState(State::Disconnected), mVerifierCallback(std::move(verifierCallback)) {
mVerifierCallback(std::move(verifierCallback)),
mStateChangeCallback(std::move(stateChangeCallback)) {
PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)"; PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
@ -113,8 +111,6 @@ DtlsTransport::~DtlsTransport() {
gnutls_deinit(mSession); gnutls_deinit(mSession);
} }
DtlsTransport::State DtlsTransport::state() const { return mState; }
bool DtlsTransport::stop() { bool DtlsTransport::stop() {
if (!Transport::stop()) if (!Transport::stop())
return false; return false;
@ -126,7 +122,7 @@ bool DtlsTransport::stop() {
} }
bool DtlsTransport::send(message_ptr message) { bool DtlsTransport::send(message_ptr message) {
if (!message || mState != State::Connected) if (!message || state() != State::Connected)
return false; return false;
PLOG_VERBOSE << "Send size=" << message->size(); PLOG_VERBOSE << "Send size=" << message->size();
@ -152,11 +148,6 @@ void DtlsTransport::incoming(message_ptr message) {
mIncomingQueue.push(message); mIncomingQueue.push(message);
} }
void DtlsTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
void DtlsTransport::runRecvLoop() { void DtlsTransport::runRecvLoop() {
const size_t maxMtu = 4096; const size_t maxMtu = 4096;
@ -362,9 +353,8 @@ void DtlsTransport::Cleanup() {
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate, DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
verifier_callback verifierCallback, state_callback stateChangeCallback) verifier_callback verifierCallback, state_callback stateChangeCallback)
: Transport(lower), mCertificate(certificate), mState(State::Disconnected), : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
mVerifierCallback(std::move(verifierCallback)), mVerifierCallback(std::move(verifierCallback)) {
mStateChangeCallback(std::move(stateChangeCallback)) {
PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)"; PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
@ -445,10 +435,8 @@ bool DtlsTransport::stop() {
return true; return true;
} }
DtlsTransport::State DtlsTransport::state() const { return mState; }
bool DtlsTransport::send(message_ptr message) { bool DtlsTransport::send(message_ptr message) {
if (!message || mState != State::Connected) if (!message || state() != State::Connected)
return false; return false;
PLOG_VERBOSE << "Send size=" << message->size(); PLOG_VERBOSE << "Send size=" << message->size();
@ -467,11 +455,6 @@ void DtlsTransport::incoming(message_ptr message) {
mIncomingQueue.push(message); mIncomingQueue.push(message);
} }
void DtlsTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
void DtlsTransport::runRecvLoop() { void DtlsTransport::runRecvLoop() {
const size_t maxMtu = 4096; const size_t maxMtu = 4096;
try { try {
@ -490,7 +473,7 @@ void DtlsTransport::runRecvLoop() {
auto message = *mIncomingQueue.pop(); auto message = *mIncomingQueue.pop();
BIO_write(mInBio, message->data(), message->size()); BIO_write(mInBio, message->data(), message->size());
if (mState == State::Connecting) { if (state() == State::Connecting) {
// Continue the handshake // Continue the handshake
int ret = SSL_do_handshake(mSsl); int ret = SSL_do_handshake(mSsl);
if (!check_openssl_ret(mSsl, ret, "Handshake failed")) if (!check_openssl_ret(mSsl, ret, "Handshake failed"))
@ -515,7 +498,7 @@ void DtlsTransport::runRecvLoop() {
// No more messages pending, retransmit and rearm timeout if connecting // No more messages pending, retransmit and rearm timeout if connecting
std::optional<milliseconds> duration; std::optional<milliseconds> duration;
if (mState == State::Connecting) { if (state() == State::Connecting) {
// Warning: This function breaks the usual return value convention // Warning: This function breaks the usual return value convention
int ret = DTLSv1_handle_timeout(mSsl); int ret = DTLSv1_handle_timeout(mSsl);
if (ret < 0) { if (ret < 0) {
@ -525,7 +508,7 @@ void DtlsTransport::runRecvLoop() {
} }
struct timeval timeout = {}; struct timeval timeout = {};
if (mState == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) { if (state() == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000); duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
// Also handle handshake timeout manually because OpenSSL actually doesn't... // Also handle handshake timeout manually because OpenSSL actually doesn't...
// OpenSSL backs off exponentially in base 2 starting from the recommended 1s // OpenSSL backs off exponentially in base 2 starting from the recommended 1s
@ -546,7 +529,7 @@ void DtlsTransport::runRecvLoop() {
PLOG_ERROR << "DTLS recv: " << e.what(); PLOG_ERROR << "DTLS recv: " << e.what();
} }
if (mState == State::Connected) { if (state() == State::Connected) {
PLOG_INFO << "DTLS disconnected"; PLOG_INFO << "DTLS disconnected";
changeState(State::Disconnected); changeState(State::Disconnected);
recv(nullptr); recv(nullptr);

View File

@ -46,33 +46,25 @@ public:
static void Init(); static void Init();
static void Cleanup(); static void Cleanup();
enum class State { Disconnected, Connecting, Connected, Failed };
using verifier_callback = std::function<bool(const std::string &fingerprint)>; using verifier_callback = std::function<bool(const std::string &fingerprint)>;
using state_callback = std::function<void(State state)>;
DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate, DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
verifier_callback verifierCallback, state_callback stateChangeCallback); verifier_callback verifierCallback, state_callback stateChangeCallback);
~DtlsTransport(); ~DtlsTransport();
State state() const;
bool stop() override; bool stop() override;
bool send(message_ptr message) override; // false if dropped bool send(message_ptr message) override; // false if dropped
private: private:
void incoming(message_ptr message) override; void incoming(message_ptr message) override;
void changeState(State state);
void runRecvLoop(); void runRecvLoop();
const std::shared_ptr<Certificate> mCertificate; const std::shared_ptr<Certificate> mCertificate;
Queue<message_ptr> mIncomingQueue; Queue<message_ptr> mIncomingQueue;
std::atomic<State> mState;
std::thread mRecvThread; std::thread mRecvThread;
verifier_callback mVerifierCallback; verifier_callback mVerifierCallback;
state_callback mStateChangeCallback;
#if USE_GNUTLS #if USE_GNUTLS
gnutls_session_t mSession; gnutls_session_t mSession;

View File

@ -48,9 +48,8 @@ namespace rtc {
IceTransport::IceTransport(const Configuration &config, Description::Role role, IceTransport::IceTransport(const Configuration &config, Description::Role role,
candidate_callback candidateCallback, state_callback stateChangeCallback, candidate_callback candidateCallback, state_callback stateChangeCallback,
gathering_state_callback gatheringStateChangeCallback) gathering_state_callback gatheringStateChangeCallback)
: mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New), : Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
mCandidateCallback(std::move(candidateCallback)), mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
mStateChangeCallback(std::move(stateChangeCallback)),
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)), mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
mAgent(nullptr, nullptr) { mAgent(nullptr, nullptr) {
@ -108,8 +107,6 @@ bool IceTransport::stop() {
Description::Role IceTransport::role() const { return mRole; } Description::Role IceTransport::role() const { return mRole; }
IceTransport::State IceTransport::state() const { return mState; }
Description IceTransport::getLocalDescription(Description::Type type) const { Description IceTransport::getLocalDescription(Description::Type type) const {
char sdp[JUICE_MAX_SDP_STRING_LEN]; char sdp[JUICE_MAX_SDP_STRING_LEN];
if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0) if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0)
@ -161,7 +158,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
} }
bool IceTransport::send(message_ptr message) { bool IceTransport::send(message_ptr message) {
if (!message || (mState != State::Connected && mState != State::Completed)) auto s = state();
if (!message || (s != State::Connected && s != State::Completed))
return false; return false;
PLOG_VERBOSE << "Send size=" << message->size(); PLOG_VERBOSE << "Send size=" << message->size();
@ -173,18 +171,29 @@ bool IceTransport::outgoing(message_ptr message) {
message->size()) >= 0; message->size()) >= 0;
} }
void IceTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(mState);
}
void IceTransport::changeGatheringState(GatheringState state) { void IceTransport::changeGatheringState(GatheringState state) {
if (mGatheringState.exchange(state) != state) if (mGatheringState.exchange(state) != state)
mGatheringStateChangeCallback(mGatheringState); mGatheringStateChangeCallback(mGatheringState);
} }
void IceTransport::processStateChange(unsigned int state) { void IceTransport::processStateChange(unsigned int state) {
changeState(static_cast<State>(state)); switch (state) {
case JUICE_STATE_DISCONNECTED:
changeState(State::Disconnected);
break;
case JUICE_STATE_CONNECTING:
changeState(State::Connecting);
break;
case JUICE_STATE_CONNECTED:
changeState(State::Connected);
break;
case JUICE_STATE_COMPLETED:
changeState(State::Completed);
break;
case JUICE_STATE_FAILED:
changeState(State::Failed);
break;
};
} }
void IceTransport::processCandidate(const string &candidate) { void IceTransport::processCandidate(const string &candidate) {
@ -263,9 +272,8 @@ namespace rtc {
IceTransport::IceTransport(const Configuration &config, Description::Role role, IceTransport::IceTransport(const Configuration &config, Description::Role role,
candidate_callback candidateCallback, state_callback stateChangeCallback, candidate_callback candidateCallback, state_callback stateChangeCallback,
gathering_state_callback gatheringStateChangeCallback) gathering_state_callback gatheringStateChangeCallback)
: mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New), : Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
mCandidateCallback(std::move(candidateCallback)), mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
mStateChangeCallback(std::move(stateChangeCallback)),
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)), mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr) { mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr) {
@ -457,8 +465,6 @@ bool IceTransport::stop() {
Description::Role IceTransport::role() const { return mRole; } Description::Role IceTransport::role() const { return mRole; }
IceTransport::State IceTransport::state() const { return mState; }
Description IceTransport::getLocalDescription(Description::Type type) const { Description IceTransport::getLocalDescription(Description::Type type) const {
// RFC 8445: The initiating agent that started the ICE processing MUST take the controlling // RFC 8445: The initiating agent that started the ICE processing MUST take the controlling
// role, and the other MUST take the controlled role. // role, and the other MUST take the controlled role.
@ -529,7 +535,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
} }
bool IceTransport::send(message_ptr message) { bool IceTransport::send(message_ptr message) {
if (!message || (mState != State::Connected && mState != State::Completed)) auto s = state();
if (!message || (s != State::Connected && s != State::Completed))
return false; return false;
PLOG_VERBOSE << "Send size=" << message->size(); PLOG_VERBOSE << "Send size=" << message->size();
@ -541,11 +548,6 @@ bool IceTransport::outgoing(message_ptr message) {
reinterpret_cast<const char *>(message->data())) >= 0; reinterpret_cast<const char *>(message->data())) >= 0;
} }
void IceTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(mState);
}
void IceTransport::changeGatheringState(GatheringState state) { void IceTransport::changeGatheringState(GatheringState state) {
if (mGatheringState.exchange(state) != state) if (mGatheringState.exchange(state) != state)
mGatheringStateChangeCallback(mGatheringState); mGatheringStateChangeCallback(mGatheringState);
@ -576,7 +578,23 @@ void IceTransport::processStateChange(unsigned int state) {
mTimeoutId = 0; mTimeoutId = 0;
} }
changeState(static_cast<State>(state)); switch (state) {
case NICE_COMPONENT_STATE_DISCONNECTED:
changeState(State::Disconnected);
break;
case NICE_COMPONENT_STATE_CONNECTING:
changeState(State::Connecting);
break;
case NICE_COMPONENT_STATE_CONNECTED:
changeState(State::Connected);
break;
case NICE_COMPONENT_STATE_READY:
changeState(State::Completed);
break;
case NICE_COMPONENT_STATE_FAILED:
changeState(State::Failed);
break;
};
} }
string IceTransport::AddressToString(const NiceAddress &addr) { string IceTransport::AddressToString(const NiceAddress &addr) {

View File

@ -40,29 +40,9 @@ namespace rtc {
class IceTransport : public Transport { class IceTransport : public Transport {
public: public:
#if USE_JUICE
enum class State : unsigned int{
Disconnected = JUICE_STATE_DISCONNECTED,
Connecting = JUICE_STATE_CONNECTING,
Connected = JUICE_STATE_CONNECTED,
Completed = JUICE_STATE_COMPLETED,
Failed = JUICE_STATE_FAILED,
};
#else
enum class State : unsigned int {
Disconnected = NICE_COMPONENT_STATE_DISCONNECTED,
Connecting = NICE_COMPONENT_STATE_CONNECTING,
Connected = NICE_COMPONENT_STATE_CONNECTED,
Completed = NICE_COMPONENT_STATE_READY,
Failed = NICE_COMPONENT_STATE_FAILED,
};
bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
#endif
enum class GatheringState { New = 0, InProgress = 1, Complete = 2 }; enum class GatheringState { New = 0, InProgress = 1, Complete = 2 };
using candidate_callback = std::function<void(const Candidate &candidate)>; using candidate_callback = std::function<void(const Candidate &candidate)>;
using state_callback = std::function<void(State state)>;
using gathering_state_callback = std::function<void(GatheringState state)>; using gathering_state_callback = std::function<void(GatheringState state)>;
IceTransport(const Configuration &config, Description::Role role, IceTransport(const Configuration &config, Description::Role role,
@ -71,7 +51,6 @@ public:
~IceTransport(); ~IceTransport();
Description::Role role() const; Description::Role role() const;
State state() const;
GatheringState gatheringState() const; GatheringState gatheringState() const;
Description getLocalDescription(Description::Type type) const; Description getLocalDescription(Description::Type type) const;
void setRemoteDescription(const Description &description); void setRemoteDescription(const Description &description);
@ -84,10 +63,13 @@ public:
bool stop() override; bool stop() override;
bool send(message_ptr message) override; // false if dropped bool send(message_ptr message) override; // false if dropped
#if !USE_JUICE
bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
#endif
private: private:
bool outgoing(message_ptr message) override; bool outgoing(message_ptr message) override;
void changeState(State state);
void changeGatheringState(GatheringState state); void changeGatheringState(GatheringState state);
void processStateChange(unsigned int state); void processStateChange(unsigned int state);
@ -98,11 +80,9 @@ private:
Description::Role mRole; Description::Role mRole;
string mMid; string mMid;
std::chrono::milliseconds mTrickleTimeout; std::chrono::milliseconds mTrickleTimeout;
std::atomic<State> mState;
std::atomic<GatheringState> mGatheringState; std::atomic<GatheringState> mGatheringState;
candidate_callback mCandidateCallback; candidate_callback mCandidateCallback;
state_callback mStateChangeCallback;
gathering_state_callback mGatheringStateChangeCallback; gathering_state_callback mGatheringStateChangeCallback;
#if USE_JUICE #if USE_JUICE

View File

@ -67,9 +67,8 @@ void SctpTransport::Cleanup() { usrsctp_finish(); }
SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
message_callback recvCallback, amount_callback bufferedAmountCallback, message_callback recvCallback, amount_callback bufferedAmountCallback,
state_callback stateChangeCallback) state_callback stateChangeCallback)
: Transport(lower), mPort(port), mSendQueue(0, message_size_func), : Transport(lower, std::move(stateChangeCallback)), mPort(port),
mBufferedAmountCallback(std::move(bufferedAmountCallback)), mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) {
onRecv(recvCallback); onRecv(recvCallback);
PLOG_DEBUG << "Initializing SCTP transport"; PLOG_DEBUG << "Initializing SCTP transport";
@ -176,8 +175,6 @@ SctpTransport::~SctpTransport() {
usrsctp_deregister_address(this); usrsctp_deregister_address(this);
} }
SctpTransport::State SctpTransport::state() const { return mState; }
bool SctpTransport::stop() { bool SctpTransport::stop() {
if (!Transport::stop()) if (!Transport::stop())
return false; return false;
@ -265,7 +262,7 @@ void SctpTransport::incoming(message_ptr message) {
// to be sent on our side (i.e. the local INIT) before proceeding. // to be sent on our side (i.e. the local INIT) before proceeding.
{ {
std::unique_lock lock(mWriteMutex); std::unique_lock lock(mWriteMutex);
mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || mState != State::Connected; }); mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || state() != State::Connected; });
} }
if (!message) { if (!message) {
@ -279,11 +276,6 @@ void SctpTransport::incoming(message_ptr message) {
usrsctp_conninput(this, message->data(), message->size(), 0); usrsctp_conninput(this, message->data(), message->size(), 0);
} }
void SctpTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
bool SctpTransport::trySendQueue() { bool SctpTransport::trySendQueue() {
// Requires mSendMutex to be locked // Requires mSendMutex to be locked
while (auto next = mSendQueue.peek()) { while (auto next = mSendQueue.peek()) {
@ -298,7 +290,7 @@ bool SctpTransport::trySendQueue() {
bool SctpTransport::trySendMessage(message_ptr message) { bool SctpTransport::trySendMessage(message_ptr message) {
// Requires mSendMutex to be locked // Requires mSendMutex to be locked
if (!mSock || mState != State::Connected) if (!mSock || state() != State::Connected)
return false; return false;
uint32_t ppid; uint32_t ppid;
@ -410,7 +402,7 @@ void SctpTransport::sendReset(uint16_t streamId) {
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) { if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) {
std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp... std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp...
mWrittenCondition.wait_for(lock, 1000ms, mWrittenCondition.wait_for(lock, 1000ms,
[&]() { return mWritten || mState != State::Connected; }); [&]() { return mWritten || state() != State::Connected; });
} else if (errno == EINVAL) { } else if (errno == EINVAL) {
PLOG_VERBOSE << "SCTP stream " << streamId << " already reset"; PLOG_VERBOSE << "SCTP stream " << streamId << " already reset";
} else { } else {
@ -567,7 +559,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
PLOG_INFO << "SCTP connected"; PLOG_INFO << "SCTP connected";
changeState(State::Connected); changeState(State::Connected);
} else { } else {
if (mState == State::Connecting) { if (state() == State::Connecting) {
PLOG_ERROR << "SCTP connection failed"; PLOG_ERROR << "SCTP connection failed";
changeState(State::Failed); changeState(State::Failed);
} else { } else {

View File

@ -38,17 +38,12 @@ public:
static void Init(); static void Init();
static void Cleanup(); static void Cleanup();
enum class State { Disconnected, Connecting, Connected, Failed };
using amount_callback = std::function<void(uint16_t streamId, size_t amount)>; using amount_callback = std::function<void(uint16_t streamId, size_t amount)>;
using state_callback = std::function<void(State state)>;
SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recvCallback, SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recvCallback,
amount_callback bufferedAmountCallback, state_callback stateChangeCallback); amount_callback bufferedAmountCallback, state_callback stateChangeCallback);
~SctpTransport(); ~SctpTransport();
State state() const;
bool stop() override; bool stop() override;
bool send(message_ptr message) override; // false if buffered bool send(message_ptr message) override; // false if buffered
void close(unsigned int stream); void close(unsigned int stream);
@ -76,7 +71,6 @@ private:
void connect(); void connect();
void shutdown(); void shutdown();
void incoming(message_ptr message) override; void incoming(message_ptr message) override;
void changeState(State state);
bool trySendQueue(); bool trySendQueue();
bool trySendMessage(message_ptr message); bool trySendMessage(message_ptr message);
@ -105,14 +99,11 @@ private:
std::atomic<bool> mWritten = false; // written outside lock std::atomic<bool> mWritten = false; // written outside lock
bool mWrittenOnce = false; bool mWrittenOnce = false;
state_callback mStateChangeCallback; binary mPartialRecv, mPartialStringData, mPartialBinaryData;
std::atomic<State> mState;
// Stats // Stats
std::atomic<size_t> mBytesSent = 0, mBytesReceived = 0; std::atomic<size_t> mBytesSent = 0, mBytesReceived = 0;
binary mPartialRecv, mPartialStringData, mPartialBinaryData;
static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len, static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
struct sctp_rcvinfo recv_info, int flags, void *user_data); struct sctp_rcvinfo recv_info, int flags, void *user_data);
static int SendCallback(struct socket *sock, uint32_t sb_free); static int SendCallback(struct socket *sock, uint32_t sb_free);

View File

@ -24,8 +24,8 @@ namespace rtc {
using std::to_string; using std::to_string;
TcpTransport::TcpTransport(const string &hostname, const string &service) TcpTransport::TcpTransport(const string &hostname, const string &service, state_callback callback)
: mHostname(hostname), mService(service) { : Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) {
mThread = std::thread(&TcpTransport::runLoop, this); mThread = std::thread(&TcpTransport::runLoop, this);
} }

View File

@ -34,7 +34,7 @@ namespace rtc {
class TcpTransport : public Transport { class TcpTransport : public Transport {
public: public:
TcpTransport(const string &hostname, const string &service); TcpTransport(const string &hostname, const string &service, state_callback callback);
~TcpTransport(); ~TcpTransport();
bool stop() override; bool stop() override;

View File

@ -61,7 +61,8 @@ void TlsTransport::Cleanup() {
// Nothing to do // Nothing to do
} }
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transport(lower) { TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
: Transport(lower, std::move(callback)) {
PLOG_DEBUG << "Initializing TLS transport (GnuTLS)"; PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
@ -82,6 +83,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transp
gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size()); gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size());
mRecvThread = std::thread(&TlsTransport::runRecvLoop, this); mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
registerIncoming();
} catch (...) { } catch (...) {
@ -271,10 +273,10 @@ void TlsTransport::Cleanup() {
// Nothing to do // Nothing to do
} }
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transport(lower) { TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
: Transport(lower, std::move(callback)) {
PLOG_DEBUG << "Initializing TLS transport (OpenSSL)"; PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
GlobalInit();
if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
throw std::runtime_error("Failed to create SSL context"); throw std::runtime_error("Failed to create SSL context");

View File

@ -41,7 +41,7 @@ class TcpTransport;
class TlsTransport : public Transport { class TlsTransport : public Transport {
public: public:
TlsTransport(std::shared_ptr<TcpTransport> lower, string host); TlsTransport(std::shared_ptr<TcpTransport> lower, string host, state_callback callback);
~TlsTransport(); ~TlsTransport();
bool stop() override; bool stop() override;

View File

@ -32,7 +32,13 @@ using namespace std::placeholders;
class Transport { class Transport {
public: public:
Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {} enum class State { Disconnected, Connecting, Connected, Completed, Failed };
using state_callback = std::function<void(State state)>;
Transport(std::shared_ptr<Transport> lower = nullptr, state_callback callback = nullptr)
: mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) {
}
virtual ~Transport() { virtual ~Transport() {
stop(); stop();
if (mLower) if (mLower)
@ -49,11 +55,16 @@ public:
} }
void onRecv(message_callback callback) { mRecvCallback = std::move(callback); } void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
State state() const { return mState; }
virtual bool send(message_ptr message) { return outgoing(message); } virtual bool send(message_ptr message) { return outgoing(message); }
protected: protected:
void recv(message_ptr message) { mRecvCallback(message); } void recv(message_ptr message) { mRecvCallback(message); }
void changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
virtual void incoming(message_ptr message) { recv(message); } virtual void incoming(message_ptr message) { recv(message); }
virtual bool outgoing(message_ptr message) { virtual bool outgoing(message_ptr message) {
@ -65,7 +76,10 @@ protected:
private: private:
std::shared_ptr<Transport> mLower; std::shared_ptr<Transport> mLower;
synchronized_callback<State> mStateChangeCallback;
synchronized_callback<message_ptr> mRecvCallback; synchronized_callback<message_ptr> mRecvCallback;
std::atomic<State> mState = State::Disconnected;
std::atomic<bool> mShutdown = false; std::atomic<bool> mShutdown = false;
}; };

View File

@ -1,100 +1,238 @@
/************************************************************************* /**
* Copyright (C) 2017-2018 by Paul-Louis Ageneau * * Copyright (c) 2020 Paul-Louis Ageneau
* paul-louis (at) ageneau (dot) org * *
* * * This library is free software; you can redistribute it and/or
* This file is part of Plateform. * * modify it under the terms of the GNU Lesser General Public
* * * License as published by the Free Software Foundation; either
* Plateform is free software: you can redistribute it and/or modify * * version 2.1 of the License, or (at your option) any later version.
* it under the terms of the GNU Affero General Public License as * *
* published by the Free Software Foundation, either version 3 of * * This library is distributed in the hope that it will be useful,
* the License, or (at your option) any later version. * * but WITHOUT ANY WARRANTY; without even the implied warranty of
* * * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Plateform is distributed in the hope that it will be useful, but * * Lesser General Public License for more details.
* WITHOUT ANY WARRANTY; without even the implied warranty of * *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * * You should have received a copy of the GNU Lesser General Public
* GNU Affero General Public License for more details. * * License along with this library; if not, write to the Free Software
* * * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
* You should have received a copy of the GNU Affero General Public * */
* License along with Plateform. *
* If not, see <http://www.gnu.org/licenses/>. *
*************************************************************************/
#include "net/websocket.hpp" #if ENABLE_WEBSOCKET
#include <exception> #include "include.hpp"
#include <iostream> #include "websocket.hpp"
const size_t DEFAULT_MAX_PAYLOAD_SIZE = 16384; // 16 KB #include "tcptransport.hpp"
#include "tlstransport.hpp"
#include "wstransport.hpp"
namespace net { #include <regex>
WebSocket::WebSocket(void) : mMaxPayloadSize(DEFAULT_MAX_PAYLOAD_SIZE) {} namespace rtc {
WebSocket::WebSocket() {}
WebSocket::WebSocket(const string &url) : WebSocket() { open(url); } WebSocket::WebSocket(const string &url) : WebSocket() { open(url); }
WebSocket::~WebSocket(void) {} WebSocket::~WebSocket() { close(); }
void WebSocket::open(const string &url) { void WebSocket::open(const string &url) {
close(); static const char *rs = R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)";
static std::regex regex(rs, std::regex::extended);
mUrl = url; std::smatch match;
mThread = std::thread(&WebSocket::run, this); if (!std::regex_match(url, match, regex))
} throw std::invalid_argument("Malformed WebSocket URL: " + url);
void WebSocket::close(void) { mScheme = match[2];
mWebSocket.close(); if (mScheme != "ws" && mScheme != "wss")
if (mThread.joinable()) throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme);
mThread.join();
mConnected = false;
}
bool WebSocket::isOpen(void) const { return mConnected; } mHost = match[4];
if (auto pos = mHost.find(':'); pos != string::npos) {
bool WebSocket::isClosed(void) const { return !mThread.joinable(); } mHostname = mHost.substr(0, pos);
mService = mHost.substr(pos + 1);
void WebSocket::setMaxPayloadSize(size_t size) { mMaxPayloadSize = size; } } else {
mHostname = mHost;
bool WebSocket::send(const std::variant<binary, string> &data) { mService = mScheme == "ws" ? "80" : "443";
if (!std::holds_alternative<binary>(data))
throw std::runtime_error("WebSocket string messages are not supported");
mWebSocket.write(std::get<binary>(data));
return true;
}
std::optional<std::variant<binary, string>> WebSocket::receive() {
if (!mQueue.empty())
return mQueue.pop();
else
return std::nullopt;
}
void WebSocket::run(void) {
if (mUrl.empty())
return;
try {
mWebSocket.connect(mUrl);
mConnected = true;
triggerOpen();
while (true) {
binary payload;
if (!mWebSocket.read(payload, mMaxPayloadSize))
break;
mQueue.push(std::move(payload));
triggerAvailable(mQueue.size());
}
} catch (const std::exception &e) {
triggerError(e.what());
} }
mWebSocket.close(); mPath = match[5];
if (string query = match[7]; !query.empty())
mPath += "?" + query;
if (mConnected) initTcpTransport();
triggerClosed();
mConnected = false;
} }
} // namespace net void WebSocket::close() {
resetCallbacks();
closeTransports();
}
void WebSocket::remoteClose() {
mIsOpen = false;
if (!mIsClosed.exchange(true))
triggerClosed();
}
bool WebSocket::send(const std::variant<binary, string> &data) {
return std::visit(
[&](const auto &d) {
using T = std::decay_t<decltype(d)>;
constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
auto *b = reinterpret_cast<const byte *>(d.data());
return outgoing(std::make_shared<Message>(b, b + d.size(), type));
},
data);
}
bool WebSocket::isOpen() const { return mIsOpen; }
bool WebSocket::isClosed() const { return mIsClosed; }
size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
std::optional<std::variant<binary, string>> WebSocket::receive() { return nullopt; }
size_t WebSocket::availableAmount() const { return 0; }
bool WebSocket::outgoing(mutable_message_ptr message) {
if (mIsClosed || !mWsTransport)
throw std::runtime_error("WebSocket is closed");
if (message->size() > maxMessageSize())
throw std::runtime_error("Message size exceeds limit");
return mWsTransport->send(message);
}
std::shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
using State = TcpTransport::State;
try {
std::lock_guard lock(mInitMutex);
if (auto transport = std::atomic_load(&mTcpTransport))
return transport;
auto transport = std::make_shared<TcpTransport>(mHostname, mService, [this](State state) {
switch (state) {
case State::Connected:
if (mScheme == "ws")
initWsTransport();
else
initTlsTransport();
break;
case State::Failed:
// TODO
break;
case State::Disconnected:
// TODO
break;
default:
// Ignore
break;
}
});
std::atomic_store(&mTcpTransport, transport);
return transport;
} catch (const std::exception &e) {
PLOG_ERROR << e.what();
// TODO
throw std::runtime_error("TCP transport initialization failed");
}
}
std::shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
using State = TlsTransport::State;
try {
std::lock_guard lock(mInitMutex);
if (auto transport = std::atomic_load(&mTlsTransport))
return transport;
auto lower = std::atomic_load(&mTcpTransport);
auto transport = std::make_shared<TlsTransport>(lower, mHost, [this](State state) {
switch (state) {
case State::Connected:
initWsTransport();
break;
case State::Failed:
// TODO
break;
case State::Disconnected:
// TODO
break;
default:
// Ignore
break;
}
});
std::atomic_store(&mTlsTransport, transport);
return transport;
} catch (const std::exception &e) {
PLOG_ERROR << e.what();
// TODO
throw std::runtime_error("TLS transport initialization failed");
}
}
std::shared_ptr<WsTransport> WebSocket::initWsTransport() {
using State = WsTransport::State;
try {
std::lock_guard lock(mInitMutex);
if (auto transport = std::atomic_load(&mWsTransport))
return transport;
std::shared_ptr<Transport> lower = std::atomic_load(&mTlsTransport);
if (!lower)
lower = std::atomic_load(&mTcpTransport);
auto transport = std::make_shared<WsTransport>(lower, mHost, mPath, [this](State state) {
switch (state) {
case State::Connected:
triggerOpen();
break;
case State::Failed:
// TODO
break;
case State::Disconnected:
// TODO
break;
default:
// Ignore
break;
}
});
std::atomic_store(&mWsTransport, transport);
return transport;
} catch (const std::exception &e) {
PLOG_ERROR << e.what();
// TODO
throw std::runtime_error("WebSocket transport initialization failed");
}
}
void closeTransports() {
mIsOpen = false;
mIsClosed = true;
// Pass the references to a thread, allowing to terminate a transport from its own thread
auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
if (ws || dtls || tcp) {
std::thread t([ws, dtls, tcp]() mutable {
if (ws)
ws->stop();
if (dtls)
dtls->stop();
if (tcp)
tcp->stop();
ws.reset();
dtls.reset();
tcp.reset();
});
t.detach();
}
}
} // namespace rtc
#endif

View File

@ -53,11 +53,12 @@ using std::to_string;
using random_bytes_engine = using random_bytes_engine =
std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>; std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>;
WsTransport::WsTransport(std::shared_ptr<TcpTransport> lower, string host, string path) WsTransport::WsTransport(std::shared_ptr<Transport> lower, string host, string path,
: Transport(lower), mHost(std::move(host)), mPath(std::move(path)) {} state_callback callback)
: Transport(lower, std::move(callback)), mHost(std::move(host)), mPath(std::move(path)) {
WsTransport::WsTransport(std::shared_ptr<TlsTransport> lower, string host, string path) registerIncoming();
: Transport(lower), mHost(std::move(host)), mPath(std::move(path)) {} }
WsTransport::~WsTransport() {} WsTransport::~WsTransport() {}

View File

@ -31,8 +31,8 @@ class TlsTransport;
class WsTransport : public Transport { class WsTransport : public Transport {
public: public:
WsTransport(std::shared_ptr<TcpTransport> lower, string host, string path); WsTransport(std::shared_ptr<Transport> lower, string host, string path,
WsTransport(std::shared_ptr<TlsTransport> lower, string host, string path); state_callback callback);
~WsTransport(); ~WsTransport();
void stop() override; void stop() override;