mirror of
https://github.com/mii443/libdatachannel.git
synced 2025-08-31 03:19:29 +00:00
Added WebSocket
This commit is contained in:
@ -36,10 +36,11 @@ set(LIBDATACHANNEL_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/rtc.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp
|
||||
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp
|
||||
)
|
||||
|
||||
set(LIBDATACHANNEL_HEADERS
|
||||
|
@ -96,8 +96,6 @@ public:
|
||||
std::optional<std::chrono::milliseconds> rtt();
|
||||
|
||||
private:
|
||||
init_token mInitToken = Init::Token();
|
||||
|
||||
std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
|
||||
std::shared_ptr<DtlsTransport> initDtlsTransport();
|
||||
std::shared_ptr<SctpTransport> initSctpTransport();
|
||||
@ -128,6 +126,8 @@ private:
|
||||
const Configuration mConfig;
|
||||
const std::shared_ptr<Certificate> mCertificate;
|
||||
|
||||
init_token mInitToken = Init::Token();
|
||||
|
||||
std::optional<Description> mLocalDescription, mRemoteDescription;
|
||||
mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;
|
||||
|
||||
|
87
include/rtc/websocket.hpp
Normal file
87
include/rtc/websocket.hpp
Normal file
@ -0,0 +1,87 @@
|
||||
/**
|
||||
* Copyright (c) 2020 Paul-Louis Ageneau
|
||||
*
|
||||
* This library is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU Lesser General Public
|
||||
* License as published by the Free Software Foundation; either
|
||||
* version 2.1 of the License, or (at your option) any later version.
|
||||
*
|
||||
* This library is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
* Lesser General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Lesser General Public
|
||||
* License along with this library; if not, write to the Free Software
|
||||
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
*/
|
||||
|
||||
#ifndef RTC_WEBSOCKET_H
|
||||
#define RTC_WEBSOCKET_H
|
||||
|
||||
#if ENABLE_WEBSOCKET
|
||||
|
||||
#include "channel.hpp"
|
||||
#include "include.hpp"
|
||||
#include "init.hpp"
|
||||
#include "message.hpp"
|
||||
#include "queue.hpp"
|
||||
|
||||
#include <atomic>
|
||||
#include <optional>
|
||||
#include <thread>
|
||||
#include <variant>
|
||||
|
||||
namespace rtc {
|
||||
|
||||
class TcpTransport;
|
||||
class TlsTransport;
|
||||
class WsTransport;
|
||||
|
||||
class WebSocket final : public Channel, public std::enable_shared_from_this<WebSocket> {
|
||||
public:
|
||||
WebSocket();
|
||||
WebSocket(const string &url);
|
||||
~WebSocket();
|
||||
|
||||
void open(const string &url);
|
||||
void close() override;
|
||||
bool send(const std::variant<binary, string> &data) override;
|
||||
|
||||
bool isOpen() const override;
|
||||
bool isClosed() const override;
|
||||
size_t maxMessageSize() const override;
|
||||
|
||||
// Extended API
|
||||
std::optional<std::variant<binary, string>> receive() override;
|
||||
size_t availableAmount() const override; // total size available to receive
|
||||
|
||||
private:
|
||||
void remoteClose();
|
||||
bool outgoing(mutable_message_ptr message);
|
||||
|
||||
std::shared_ptr<TcpTransport> initTcpTransport();
|
||||
std::shared_ptr<TlsTransport> initTlsTransport();
|
||||
std::shared_ptr<WsTransport> initWsTransport();
|
||||
void closeTransports();
|
||||
|
||||
init_token mInitToken = Init::Token();
|
||||
|
||||
std::shared_ptr<TcpTransport> mTcpTransport;
|
||||
std::shared_ptr<TlsTransport> mTlsTransport;
|
||||
std::shared_ptr<WsTransport> mWsTransport;
|
||||
std::recursive_mutex mInitMutex;
|
||||
|
||||
string mScheme, mHost, mHostname, mService, mPath;
|
||||
|
||||
std::atomic<bool> mIsOpen = false;
|
||||
std::atomic<bool> mIsClosed = false;
|
||||
|
||||
Queue<message_ptr> mRecvQueue;
|
||||
std::atomic<size_t> mRecvAmount = 0;
|
||||
};
|
||||
} // namespace rtc
|
||||
|
||||
#endif
|
||||
|
||||
#endif // NET_WEBSOCKET_H
|
@ -62,11 +62,9 @@ void DtlsTransport::Cleanup() {
|
||||
}
|
||||
|
||||
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
|
||||
verifier_callback verifierCallback,
|
||||
state_callback stateChangeCallback)
|
||||
: Transport(lower), mCertificate(certificate), mState(State::Disconnected),
|
||||
mVerifierCallback(std::move(verifierCallback)),
|
||||
mStateChangeCallback(std::move(stateChangeCallback)) {
|
||||
verifier_callback verifierCallback, state_callback stateChangeCallback)
|
||||
: Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
|
||||
mVerifierCallback(std::move(verifierCallback)) {
|
||||
|
||||
PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
|
||||
|
||||
@ -113,8 +111,6 @@ DtlsTransport::~DtlsTransport() {
|
||||
gnutls_deinit(mSession);
|
||||
}
|
||||
|
||||
DtlsTransport::State DtlsTransport::state() const { return mState; }
|
||||
|
||||
bool DtlsTransport::stop() {
|
||||
if (!Transport::stop())
|
||||
return false;
|
||||
@ -126,7 +122,7 @@ bool DtlsTransport::stop() {
|
||||
}
|
||||
|
||||
bool DtlsTransport::send(message_ptr message) {
|
||||
if (!message || mState != State::Connected)
|
||||
if (!message || state() != State::Connected)
|
||||
return false;
|
||||
|
||||
PLOG_VERBOSE << "Send size=" << message->size();
|
||||
@ -152,11 +148,6 @@ void DtlsTransport::incoming(message_ptr message) {
|
||||
mIncomingQueue.push(message);
|
||||
}
|
||||
|
||||
void DtlsTransport::changeState(State state) {
|
||||
if (mState.exchange(state) != state)
|
||||
mStateChangeCallback(state);
|
||||
}
|
||||
|
||||
void DtlsTransport::runRecvLoop() {
|
||||
const size_t maxMtu = 4096;
|
||||
|
||||
@ -362,9 +353,8 @@ void DtlsTransport::Cleanup() {
|
||||
|
||||
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
|
||||
verifier_callback verifierCallback, state_callback stateChangeCallback)
|
||||
: Transport(lower), mCertificate(certificate), mState(State::Disconnected),
|
||||
mVerifierCallback(std::move(verifierCallback)),
|
||||
mStateChangeCallback(std::move(stateChangeCallback)) {
|
||||
: Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
|
||||
mVerifierCallback(std::move(verifierCallback)) {
|
||||
|
||||
PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
|
||||
|
||||
@ -445,10 +435,8 @@ bool DtlsTransport::stop() {
|
||||
return true;
|
||||
}
|
||||
|
||||
DtlsTransport::State DtlsTransport::state() const { return mState; }
|
||||
|
||||
bool DtlsTransport::send(message_ptr message) {
|
||||
if (!message || mState != State::Connected)
|
||||
if (!message || state() != State::Connected)
|
||||
return false;
|
||||
|
||||
PLOG_VERBOSE << "Send size=" << message->size();
|
||||
@ -467,11 +455,6 @@ void DtlsTransport::incoming(message_ptr message) {
|
||||
mIncomingQueue.push(message);
|
||||
}
|
||||
|
||||
void DtlsTransport::changeState(State state) {
|
||||
if (mState.exchange(state) != state)
|
||||
mStateChangeCallback(state);
|
||||
}
|
||||
|
||||
void DtlsTransport::runRecvLoop() {
|
||||
const size_t maxMtu = 4096;
|
||||
try {
|
||||
@ -490,7 +473,7 @@ void DtlsTransport::runRecvLoop() {
|
||||
auto message = *mIncomingQueue.pop();
|
||||
BIO_write(mInBio, message->data(), message->size());
|
||||
|
||||
if (mState == State::Connecting) {
|
||||
if (state() == State::Connecting) {
|
||||
// Continue the handshake
|
||||
int ret = SSL_do_handshake(mSsl);
|
||||
if (!check_openssl_ret(mSsl, ret, "Handshake failed"))
|
||||
@ -515,7 +498,7 @@ void DtlsTransport::runRecvLoop() {
|
||||
|
||||
// No more messages pending, retransmit and rearm timeout if connecting
|
||||
std::optional<milliseconds> duration;
|
||||
if (mState == State::Connecting) {
|
||||
if (state() == State::Connecting) {
|
||||
// Warning: This function breaks the usual return value convention
|
||||
int ret = DTLSv1_handle_timeout(mSsl);
|
||||
if (ret < 0) {
|
||||
@ -525,7 +508,7 @@ void DtlsTransport::runRecvLoop() {
|
||||
}
|
||||
|
||||
struct timeval timeout = {};
|
||||
if (mState == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
|
||||
if (state() == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
|
||||
duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
|
||||
// Also handle handshake timeout manually because OpenSSL actually doesn't...
|
||||
// OpenSSL backs off exponentially in base 2 starting from the recommended 1s
|
||||
@ -546,7 +529,7 @@ void DtlsTransport::runRecvLoop() {
|
||||
PLOG_ERROR << "DTLS recv: " << e.what();
|
||||
}
|
||||
|
||||
if (mState == State::Connected) {
|
||||
if (state() == State::Connected) {
|
||||
PLOG_INFO << "DTLS disconnected";
|
||||
changeState(State::Disconnected);
|
||||
recv(nullptr);
|
||||
|
@ -46,33 +46,25 @@ public:
|
||||
static void Init();
|
||||
static void Cleanup();
|
||||
|
||||
enum class State { Disconnected, Connecting, Connected, Failed };
|
||||
|
||||
using verifier_callback = std::function<bool(const std::string &fingerprint)>;
|
||||
using state_callback = std::function<void(State state)>;
|
||||
|
||||
DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
|
||||
verifier_callback verifierCallback, state_callback stateChangeCallback);
|
||||
~DtlsTransport();
|
||||
|
||||
State state() const;
|
||||
|
||||
bool stop() override;
|
||||
bool send(message_ptr message) override; // false if dropped
|
||||
|
||||
private:
|
||||
void incoming(message_ptr message) override;
|
||||
void changeState(State state);
|
||||
void runRecvLoop();
|
||||
|
||||
const std::shared_ptr<Certificate> mCertificate;
|
||||
|
||||
Queue<message_ptr> mIncomingQueue;
|
||||
std::atomic<State> mState;
|
||||
std::thread mRecvThread;
|
||||
|
||||
verifier_callback mVerifierCallback;
|
||||
state_callback mStateChangeCallback;
|
||||
|
||||
#if USE_GNUTLS
|
||||
gnutls_session_t mSession;
|
||||
|
@ -48,9 +48,8 @@ namespace rtc {
|
||||
IceTransport::IceTransport(const Configuration &config, Description::Role role,
|
||||
candidate_callback candidateCallback, state_callback stateChangeCallback,
|
||||
gathering_state_callback gatheringStateChangeCallback)
|
||||
: mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
|
||||
mCandidateCallback(std::move(candidateCallback)),
|
||||
mStateChangeCallback(std::move(stateChangeCallback)),
|
||||
: Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
|
||||
mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
|
||||
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
|
||||
mAgent(nullptr, nullptr) {
|
||||
|
||||
@ -108,8 +107,6 @@ bool IceTransport::stop() {
|
||||
|
||||
Description::Role IceTransport::role() const { return mRole; }
|
||||
|
||||
IceTransport::State IceTransport::state() const { return mState; }
|
||||
|
||||
Description IceTransport::getLocalDescription(Description::Type type) const {
|
||||
char sdp[JUICE_MAX_SDP_STRING_LEN];
|
||||
if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0)
|
||||
@ -161,7 +158,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
|
||||
}
|
||||
|
||||
bool IceTransport::send(message_ptr message) {
|
||||
if (!message || (mState != State::Connected && mState != State::Completed))
|
||||
auto s = state();
|
||||
if (!message || (s != State::Connected && s != State::Completed))
|
||||
return false;
|
||||
|
||||
PLOG_VERBOSE << "Send size=" << message->size();
|
||||
@ -173,18 +171,29 @@ bool IceTransport::outgoing(message_ptr message) {
|
||||
message->size()) >= 0;
|
||||
}
|
||||
|
||||
void IceTransport::changeState(State state) {
|
||||
if (mState.exchange(state) != state)
|
||||
mStateChangeCallback(mState);
|
||||
}
|
||||
|
||||
void IceTransport::changeGatheringState(GatheringState state) {
|
||||
if (mGatheringState.exchange(state) != state)
|
||||
mGatheringStateChangeCallback(mGatheringState);
|
||||
}
|
||||
|
||||
void IceTransport::processStateChange(unsigned int state) {
|
||||
changeState(static_cast<State>(state));
|
||||
switch (state) {
|
||||
case JUICE_STATE_DISCONNECTED:
|
||||
changeState(State::Disconnected);
|
||||
break;
|
||||
case JUICE_STATE_CONNECTING:
|
||||
changeState(State::Connecting);
|
||||
break;
|
||||
case JUICE_STATE_CONNECTED:
|
||||
changeState(State::Connected);
|
||||
break;
|
||||
case JUICE_STATE_COMPLETED:
|
||||
changeState(State::Completed);
|
||||
break;
|
||||
case JUICE_STATE_FAILED:
|
||||
changeState(State::Failed);
|
||||
break;
|
||||
};
|
||||
}
|
||||
|
||||
void IceTransport::processCandidate(const string &candidate) {
|
||||
@ -263,9 +272,8 @@ namespace rtc {
|
||||
IceTransport::IceTransport(const Configuration &config, Description::Role role,
|
||||
candidate_callback candidateCallback, state_callback stateChangeCallback,
|
||||
gathering_state_callback gatheringStateChangeCallback)
|
||||
: mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
|
||||
mCandidateCallback(std::move(candidateCallback)),
|
||||
mStateChangeCallback(std::move(stateChangeCallback)),
|
||||
: Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
|
||||
mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
|
||||
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
|
||||
mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr) {
|
||||
|
||||
@ -457,8 +465,6 @@ bool IceTransport::stop() {
|
||||
|
||||
Description::Role IceTransport::role() const { return mRole; }
|
||||
|
||||
IceTransport::State IceTransport::state() const { return mState; }
|
||||
|
||||
Description IceTransport::getLocalDescription(Description::Type type) const {
|
||||
// RFC 8445: The initiating agent that started the ICE processing MUST take the controlling
|
||||
// role, and the other MUST take the controlled role.
|
||||
@ -529,7 +535,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
|
||||
}
|
||||
|
||||
bool IceTransport::send(message_ptr message) {
|
||||
if (!message || (mState != State::Connected && mState != State::Completed))
|
||||
auto s = state();
|
||||
if (!message || (s != State::Connected && s != State::Completed))
|
||||
return false;
|
||||
|
||||
PLOG_VERBOSE << "Send size=" << message->size();
|
||||
@ -541,11 +548,6 @@ bool IceTransport::outgoing(message_ptr message) {
|
||||
reinterpret_cast<const char *>(message->data())) >= 0;
|
||||
}
|
||||
|
||||
void IceTransport::changeState(State state) {
|
||||
if (mState.exchange(state) != state)
|
||||
mStateChangeCallback(mState);
|
||||
}
|
||||
|
||||
void IceTransport::changeGatheringState(GatheringState state) {
|
||||
if (mGatheringState.exchange(state) != state)
|
||||
mGatheringStateChangeCallback(mGatheringState);
|
||||
@ -576,7 +578,23 @@ void IceTransport::processStateChange(unsigned int state) {
|
||||
mTimeoutId = 0;
|
||||
}
|
||||
|
||||
changeState(static_cast<State>(state));
|
||||
switch (state) {
|
||||
case NICE_COMPONENT_STATE_DISCONNECTED:
|
||||
changeState(State::Disconnected);
|
||||
break;
|
||||
case NICE_COMPONENT_STATE_CONNECTING:
|
||||
changeState(State::Connecting);
|
||||
break;
|
||||
case NICE_COMPONENT_STATE_CONNECTED:
|
||||
changeState(State::Connected);
|
||||
break;
|
||||
case NICE_COMPONENT_STATE_READY:
|
||||
changeState(State::Completed);
|
||||
break;
|
||||
case NICE_COMPONENT_STATE_FAILED:
|
||||
changeState(State::Failed);
|
||||
break;
|
||||
};
|
||||
}
|
||||
|
||||
string IceTransport::AddressToString(const NiceAddress &addr) {
|
||||
|
@ -40,29 +40,9 @@ namespace rtc {
|
||||
|
||||
class IceTransport : public Transport {
|
||||
public:
|
||||
#if USE_JUICE
|
||||
enum class State : unsigned int{
|
||||
Disconnected = JUICE_STATE_DISCONNECTED,
|
||||
Connecting = JUICE_STATE_CONNECTING,
|
||||
Connected = JUICE_STATE_CONNECTED,
|
||||
Completed = JUICE_STATE_COMPLETED,
|
||||
Failed = JUICE_STATE_FAILED,
|
||||
};
|
||||
#else
|
||||
enum class State : unsigned int {
|
||||
Disconnected = NICE_COMPONENT_STATE_DISCONNECTED,
|
||||
Connecting = NICE_COMPONENT_STATE_CONNECTING,
|
||||
Connected = NICE_COMPONENT_STATE_CONNECTED,
|
||||
Completed = NICE_COMPONENT_STATE_READY,
|
||||
Failed = NICE_COMPONENT_STATE_FAILED,
|
||||
};
|
||||
|
||||
bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
|
||||
#endif
|
||||
enum class GatheringState { New = 0, InProgress = 1, Complete = 2 };
|
||||
|
||||
using candidate_callback = std::function<void(const Candidate &candidate)>;
|
||||
using state_callback = std::function<void(State state)>;
|
||||
using gathering_state_callback = std::function<void(GatheringState state)>;
|
||||
|
||||
IceTransport(const Configuration &config, Description::Role role,
|
||||
@ -71,7 +51,6 @@ public:
|
||||
~IceTransport();
|
||||
|
||||
Description::Role role() const;
|
||||
State state() const;
|
||||
GatheringState gatheringState() const;
|
||||
Description getLocalDescription(Description::Type type) const;
|
||||
void setRemoteDescription(const Description &description);
|
||||
@ -84,10 +63,13 @@ public:
|
||||
bool stop() override;
|
||||
bool send(message_ptr message) override; // false if dropped
|
||||
|
||||
#if !USE_JUICE
|
||||
bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
|
||||
#endif
|
||||
|
||||
private:
|
||||
bool outgoing(message_ptr message) override;
|
||||
|
||||
void changeState(State state);
|
||||
void changeGatheringState(GatheringState state);
|
||||
|
||||
void processStateChange(unsigned int state);
|
||||
@ -98,11 +80,9 @@ private:
|
||||
Description::Role mRole;
|
||||
string mMid;
|
||||
std::chrono::milliseconds mTrickleTimeout;
|
||||
std::atomic<State> mState;
|
||||
std::atomic<GatheringState> mGatheringState;
|
||||
|
||||
candidate_callback mCandidateCallback;
|
||||
state_callback mStateChangeCallback;
|
||||
gathering_state_callback mGatheringStateChangeCallback;
|
||||
|
||||
#if USE_JUICE
|
||||
|
@ -67,9 +67,8 @@ void SctpTransport::Cleanup() { usrsctp_finish(); }
|
||||
SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
|
||||
message_callback recvCallback, amount_callback bufferedAmountCallback,
|
||||
state_callback stateChangeCallback)
|
||||
: Transport(lower), mPort(port), mSendQueue(0, message_size_func),
|
||||
mBufferedAmountCallback(std::move(bufferedAmountCallback)),
|
||||
mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) {
|
||||
: Transport(lower, std::move(stateChangeCallback)), mPort(port),
|
||||
mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
|
||||
onRecv(recvCallback);
|
||||
|
||||
PLOG_DEBUG << "Initializing SCTP transport";
|
||||
@ -176,8 +175,6 @@ SctpTransport::~SctpTransport() {
|
||||
usrsctp_deregister_address(this);
|
||||
}
|
||||
|
||||
SctpTransport::State SctpTransport::state() const { return mState; }
|
||||
|
||||
bool SctpTransport::stop() {
|
||||
if (!Transport::stop())
|
||||
return false;
|
||||
@ -265,7 +262,7 @@ void SctpTransport::incoming(message_ptr message) {
|
||||
// to be sent on our side (i.e. the local INIT) before proceeding.
|
||||
{
|
||||
std::unique_lock lock(mWriteMutex);
|
||||
mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || mState != State::Connected; });
|
||||
mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || state() != State::Connected; });
|
||||
}
|
||||
|
||||
if (!message) {
|
||||
@ -279,11 +276,6 @@ void SctpTransport::incoming(message_ptr message) {
|
||||
usrsctp_conninput(this, message->data(), message->size(), 0);
|
||||
}
|
||||
|
||||
void SctpTransport::changeState(State state) {
|
||||
if (mState.exchange(state) != state)
|
||||
mStateChangeCallback(state);
|
||||
}
|
||||
|
||||
bool SctpTransport::trySendQueue() {
|
||||
// Requires mSendMutex to be locked
|
||||
while (auto next = mSendQueue.peek()) {
|
||||
@ -298,7 +290,7 @@ bool SctpTransport::trySendQueue() {
|
||||
|
||||
bool SctpTransport::trySendMessage(message_ptr message) {
|
||||
// Requires mSendMutex to be locked
|
||||
if (!mSock || mState != State::Connected)
|
||||
if (!mSock || state() != State::Connected)
|
||||
return false;
|
||||
|
||||
uint32_t ppid;
|
||||
@ -410,7 +402,7 @@ void SctpTransport::sendReset(uint16_t streamId) {
|
||||
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) {
|
||||
std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp...
|
||||
mWrittenCondition.wait_for(lock, 1000ms,
|
||||
[&]() { return mWritten || mState != State::Connected; });
|
||||
[&]() { return mWritten || state() != State::Connected; });
|
||||
} else if (errno == EINVAL) {
|
||||
PLOG_VERBOSE << "SCTP stream " << streamId << " already reset";
|
||||
} else {
|
||||
@ -567,7 +559,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
|
||||
PLOG_INFO << "SCTP connected";
|
||||
changeState(State::Connected);
|
||||
} else {
|
||||
if (mState == State::Connecting) {
|
||||
if (state() == State::Connecting) {
|
||||
PLOG_ERROR << "SCTP connection failed";
|
||||
changeState(State::Failed);
|
||||
} else {
|
||||
|
@ -38,17 +38,12 @@ public:
|
||||
static void Init();
|
||||
static void Cleanup();
|
||||
|
||||
enum class State { Disconnected, Connecting, Connected, Failed };
|
||||
|
||||
using amount_callback = std::function<void(uint16_t streamId, size_t amount)>;
|
||||
using state_callback = std::function<void(State state)>;
|
||||
|
||||
SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recvCallback,
|
||||
amount_callback bufferedAmountCallback, state_callback stateChangeCallback);
|
||||
~SctpTransport();
|
||||
|
||||
State state() const;
|
||||
|
||||
bool stop() override;
|
||||
bool send(message_ptr message) override; // false if buffered
|
||||
void close(unsigned int stream);
|
||||
@ -76,7 +71,6 @@ private:
|
||||
void connect();
|
||||
void shutdown();
|
||||
void incoming(message_ptr message) override;
|
||||
void changeState(State state);
|
||||
|
||||
bool trySendQueue();
|
||||
bool trySendMessage(message_ptr message);
|
||||
@ -105,14 +99,11 @@ private:
|
||||
std::atomic<bool> mWritten = false; // written outside lock
|
||||
bool mWrittenOnce = false;
|
||||
|
||||
state_callback mStateChangeCallback;
|
||||
std::atomic<State> mState;
|
||||
binary mPartialRecv, mPartialStringData, mPartialBinaryData;
|
||||
|
||||
// Stats
|
||||
std::atomic<size_t> mBytesSent = 0, mBytesReceived = 0;
|
||||
|
||||
binary mPartialRecv, mPartialStringData, mPartialBinaryData;
|
||||
|
||||
static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
|
||||
struct sctp_rcvinfo recv_info, int flags, void *user_data);
|
||||
static int SendCallback(struct socket *sock, uint32_t sb_free);
|
||||
|
@ -24,8 +24,8 @@ namespace rtc {
|
||||
|
||||
using std::to_string;
|
||||
|
||||
TcpTransport::TcpTransport(const string &hostname, const string &service)
|
||||
: mHostname(hostname), mService(service) {
|
||||
TcpTransport::TcpTransport(const string &hostname, const string &service, state_callback callback)
|
||||
: Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) {
|
||||
mThread = std::thread(&TcpTransport::runLoop, this);
|
||||
}
|
||||
|
||||
|
@ -34,7 +34,7 @@ namespace rtc {
|
||||
|
||||
class TcpTransport : public Transport {
|
||||
public:
|
||||
TcpTransport(const string &hostname, const string &service);
|
||||
TcpTransport(const string &hostname, const string &service, state_callback callback);
|
||||
~TcpTransport();
|
||||
|
||||
bool stop() override;
|
||||
|
@ -61,7 +61,8 @@ void TlsTransport::Cleanup() {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transport(lower) {
|
||||
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
|
||||
: Transport(lower, std::move(callback)) {
|
||||
|
||||
PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
|
||||
|
||||
@ -82,6 +83,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transp
|
||||
gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size());
|
||||
|
||||
mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
|
||||
registerIncoming();
|
||||
|
||||
} catch (...) {
|
||||
|
||||
@ -271,10 +273,10 @@ void TlsTransport::Cleanup() {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transport(lower) {
|
||||
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
|
||||
: Transport(lower, std::move(callback)) {
|
||||
|
||||
PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
|
||||
GlobalInit();
|
||||
|
||||
if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
|
||||
throw std::runtime_error("Failed to create SSL context");
|
||||
|
@ -41,7 +41,7 @@ class TcpTransport;
|
||||
|
||||
class TlsTransport : public Transport {
|
||||
public:
|
||||
TlsTransport(std::shared_ptr<TcpTransport> lower, string host);
|
||||
TlsTransport(std::shared_ptr<TcpTransport> lower, string host, state_callback callback);
|
||||
~TlsTransport();
|
||||
|
||||
bool stop() override;
|
||||
|
@ -32,7 +32,13 @@ using namespace std::placeholders;
|
||||
|
||||
class Transport {
|
||||
public:
|
||||
Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {}
|
||||
enum class State { Disconnected, Connecting, Connected, Completed, Failed };
|
||||
using state_callback = std::function<void(State state)>;
|
||||
|
||||
Transport(std::shared_ptr<Transport> lower = nullptr, state_callback callback = nullptr)
|
||||
: mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) {
|
||||
}
|
||||
|
||||
virtual ~Transport() {
|
||||
stop();
|
||||
if (mLower)
|
||||
@ -49,11 +55,16 @@ public:
|
||||
}
|
||||
|
||||
void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
|
||||
State state() const { return mState; }
|
||||
|
||||
virtual bool send(message_ptr message) { return outgoing(message); }
|
||||
|
||||
protected:
|
||||
void recv(message_ptr message) { mRecvCallback(message); }
|
||||
void changeState(State state) {
|
||||
if (mState.exchange(state) != state)
|
||||
mStateChangeCallback(state);
|
||||
}
|
||||
|
||||
virtual void incoming(message_ptr message) { recv(message); }
|
||||
virtual bool outgoing(message_ptr message) {
|
||||
@ -65,7 +76,10 @@ protected:
|
||||
|
||||
private:
|
||||
std::shared_ptr<Transport> mLower;
|
||||
synchronized_callback<State> mStateChangeCallback;
|
||||
synchronized_callback<message_ptr> mRecvCallback;
|
||||
|
||||
std::atomic<State> mState = State::Disconnected;
|
||||
std::atomic<bool> mShutdown = false;
|
||||
};
|
||||
|
||||
|
@ -1,100 +1,238 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) 2017-2018 by Paul-Louis Ageneau *
|
||||
* paul-louis (at) ageneau (dot) org *
|
||||
* *
|
||||
* This file is part of Plateform. *
|
||||
* *
|
||||
* Plateform is free software: you can redistribute it and/or modify *
|
||||
* it under the terms of the GNU Affero General Public License as *
|
||||
* published by the Free Software Foundation, either version 3 of *
|
||||
* the License, or (at your option) any later version. *
|
||||
* *
|
||||
* Plateform is distributed in the hope that it will be useful, but *
|
||||
* WITHOUT ANY WARRANTY; without even the implied warranty of *
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
|
||||
* GNU Affero General Public License for more details. *
|
||||
* *
|
||||
* You should have received a copy of the GNU Affero General Public *
|
||||
* License along with Plateform. *
|
||||
* If not, see <http://www.gnu.org/licenses/>. *
|
||||
*************************************************************************/
|
||||
/**
|
||||
* Copyright (c) 2020 Paul-Louis Ageneau
|
||||
*
|
||||
* This library is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU Lesser General Public
|
||||
* License as published by the Free Software Foundation; either
|
||||
* version 2.1 of the License, or (at your option) any later version.
|
||||
*
|
||||
* This library is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
* Lesser General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Lesser General Public
|
||||
* License along with this library; if not, write to the Free Software
|
||||
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
*/
|
||||
|
||||
#include "net/websocket.hpp"
|
||||
#if ENABLE_WEBSOCKET
|
||||
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include "include.hpp"
|
||||
#include "websocket.hpp"
|
||||
|
||||
const size_t DEFAULT_MAX_PAYLOAD_SIZE = 16384; // 16 KB
|
||||
#include "tcptransport.hpp"
|
||||
#include "tlstransport.hpp"
|
||||
#include "wstransport.hpp"
|
||||
|
||||
namespace net {
|
||||
#include <regex>
|
||||
|
||||
WebSocket::WebSocket(void) : mMaxPayloadSize(DEFAULT_MAX_PAYLOAD_SIZE) {}
|
||||
namespace rtc {
|
||||
|
||||
WebSocket::WebSocket() {}
|
||||
|
||||
WebSocket::WebSocket(const string &url) : WebSocket() { open(url); }
|
||||
|
||||
WebSocket::~WebSocket(void) {}
|
||||
WebSocket::~WebSocket() { close(); }
|
||||
|
||||
void WebSocket::open(const string &url) {
|
||||
close();
|
||||
static const char *rs = R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)";
|
||||
static std::regex regex(rs, std::regex::extended);
|
||||
|
||||
mUrl = url;
|
||||
mThread = std::thread(&WebSocket::run, this);
|
||||
std::smatch match;
|
||||
if (!std::regex_match(url, match, regex))
|
||||
throw std::invalid_argument("Malformed WebSocket URL: " + url);
|
||||
|
||||
mScheme = match[2];
|
||||
if (mScheme != "ws" && mScheme != "wss")
|
||||
throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme);
|
||||
|
||||
mHost = match[4];
|
||||
if (auto pos = mHost.find(':'); pos != string::npos) {
|
||||
mHostname = mHost.substr(0, pos);
|
||||
mService = mHost.substr(pos + 1);
|
||||
} else {
|
||||
mHostname = mHost;
|
||||
mService = mScheme == "ws" ? "80" : "443";
|
||||
}
|
||||
|
||||
void WebSocket::close(void) {
|
||||
mWebSocket.close();
|
||||
if (mThread.joinable())
|
||||
mThread.join();
|
||||
mConnected = false;
|
||||
mPath = match[5];
|
||||
if (string query = match[7]; !query.empty())
|
||||
mPath += "?" + query;
|
||||
|
||||
initTcpTransport();
|
||||
}
|
||||
|
||||
bool WebSocket::isOpen(void) const { return mConnected; }
|
||||
void WebSocket::close() {
|
||||
resetCallbacks();
|
||||
closeTransports();
|
||||
}
|
||||
|
||||
bool WebSocket::isClosed(void) const { return !mThread.joinable(); }
|
||||
|
||||
void WebSocket::setMaxPayloadSize(size_t size) { mMaxPayloadSize = size; }
|
||||
void WebSocket::remoteClose() {
|
||||
mIsOpen = false;
|
||||
if (!mIsClosed.exchange(true))
|
||||
triggerClosed();
|
||||
}
|
||||
|
||||
bool WebSocket::send(const std::variant<binary, string> &data) {
|
||||
if (!std::holds_alternative<binary>(data))
|
||||
throw std::runtime_error("WebSocket string messages are not supported");
|
||||
|
||||
mWebSocket.write(std::get<binary>(data));
|
||||
return true;
|
||||
return std::visit(
|
||||
[&](const auto &d) {
|
||||
using T = std::decay_t<decltype(d)>;
|
||||
constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
|
||||
auto *b = reinterpret_cast<const byte *>(d.data());
|
||||
return outgoing(std::make_shared<Message>(b, b + d.size(), type));
|
||||
},
|
||||
data);
|
||||
}
|
||||
|
||||
std::optional<std::variant<binary, string>> WebSocket::receive() {
|
||||
if (!mQueue.empty())
|
||||
return mQueue.pop();
|
||||
else
|
||||
return std::nullopt;
|
||||
bool WebSocket::isOpen() const { return mIsOpen; }
|
||||
|
||||
bool WebSocket::isClosed() const { return mIsClosed; }
|
||||
|
||||
size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
|
||||
|
||||
std::optional<std::variant<binary, string>> WebSocket::receive() { return nullopt; }
|
||||
|
||||
size_t WebSocket::availableAmount() const { return 0; }
|
||||
|
||||
bool WebSocket::outgoing(mutable_message_ptr message) {
|
||||
if (mIsClosed || !mWsTransport)
|
||||
throw std::runtime_error("WebSocket is closed");
|
||||
|
||||
if (message->size() > maxMessageSize())
|
||||
throw std::runtime_error("Message size exceeds limit");
|
||||
|
||||
return mWsTransport->send(message);
|
||||
}
|
||||
|
||||
void WebSocket::run(void) {
|
||||
if (mUrl.empty())
|
||||
return;
|
||||
|
||||
std::shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
|
||||
using State = TcpTransport::State;
|
||||
try {
|
||||
mWebSocket.connect(mUrl);
|
||||
std::lock_guard lock(mInitMutex);
|
||||
if (auto transport = std::atomic_load(&mTcpTransport))
|
||||
return transport;
|
||||
|
||||
mConnected = true;
|
||||
triggerOpen();
|
||||
|
||||
while (true) {
|
||||
binary payload;
|
||||
if (!mWebSocket.read(payload, mMaxPayloadSize))
|
||||
auto transport = std::make_shared<TcpTransport>(mHostname, mService, [this](State state) {
|
||||
switch (state) {
|
||||
case State::Connected:
|
||||
if (mScheme == "ws")
|
||||
initWsTransport();
|
||||
else
|
||||
initTlsTransport();
|
||||
break;
|
||||
case State::Failed:
|
||||
// TODO
|
||||
break;
|
||||
case State::Disconnected:
|
||||
// TODO
|
||||
break;
|
||||
default:
|
||||
// Ignore
|
||||
break;
|
||||
mQueue.push(std::move(payload));
|
||||
triggerAvailable(mQueue.size());
|
||||
}
|
||||
});
|
||||
std::atomic_store(&mTcpTransport, transport);
|
||||
return transport;
|
||||
} catch (const std::exception &e) {
|
||||
triggerError(e.what());
|
||||
PLOG_ERROR << e.what();
|
||||
// TODO
|
||||
throw std::runtime_error("TCP transport initialization failed");
|
||||
}
|
||||
}
|
||||
|
||||
mWebSocket.close();
|
||||
std::shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
|
||||
using State = TlsTransport::State;
|
||||
try {
|
||||
std::lock_guard lock(mInitMutex);
|
||||
if (auto transport = std::atomic_load(&mTlsTransport))
|
||||
return transport;
|
||||
|
||||
if (mConnected)
|
||||
triggerClosed();
|
||||
mConnected = false;
|
||||
auto lower = std::atomic_load(&mTcpTransport);
|
||||
auto transport = std::make_shared<TlsTransport>(lower, mHost, [this](State state) {
|
||||
switch (state) {
|
||||
case State::Connected:
|
||||
initWsTransport();
|
||||
break;
|
||||
case State::Failed:
|
||||
// TODO
|
||||
break;
|
||||
case State::Disconnected:
|
||||
// TODO
|
||||
break;
|
||||
default:
|
||||
// Ignore
|
||||
break;
|
||||
}
|
||||
});
|
||||
std::atomic_store(&mTlsTransport, transport);
|
||||
return transport;
|
||||
} catch (const std::exception &e) {
|
||||
PLOG_ERROR << e.what();
|
||||
// TODO
|
||||
throw std::runtime_error("TLS transport initialization failed");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace net
|
||||
std::shared_ptr<WsTransport> WebSocket::initWsTransport() {
|
||||
using State = WsTransport::State;
|
||||
try {
|
||||
std::lock_guard lock(mInitMutex);
|
||||
if (auto transport = std::atomic_load(&mWsTransport))
|
||||
return transport;
|
||||
|
||||
std::shared_ptr<Transport> lower = std::atomic_load(&mTlsTransport);
|
||||
if (!lower)
|
||||
lower = std::atomic_load(&mTcpTransport);
|
||||
auto transport = std::make_shared<WsTransport>(lower, mHost, mPath, [this](State state) {
|
||||
switch (state) {
|
||||
case State::Connected:
|
||||
triggerOpen();
|
||||
break;
|
||||
case State::Failed:
|
||||
// TODO
|
||||
break;
|
||||
case State::Disconnected:
|
||||
// TODO
|
||||
break;
|
||||
default:
|
||||
// Ignore
|
||||
break;
|
||||
}
|
||||
});
|
||||
std::atomic_store(&mWsTransport, transport);
|
||||
return transport;
|
||||
} catch (const std::exception &e) {
|
||||
PLOG_ERROR << e.what();
|
||||
// TODO
|
||||
throw std::runtime_error("WebSocket transport initialization failed");
|
||||
}
|
||||
}
|
||||
|
||||
void closeTransports() {
|
||||
mIsOpen = false;
|
||||
mIsClosed = true;
|
||||
|
||||
// Pass the references to a thread, allowing to terminate a transport from its own thread
|
||||
auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
|
||||
auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
|
||||
auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
|
||||
if (ws || dtls || tcp) {
|
||||
std::thread t([ws, dtls, tcp]() mutable {
|
||||
if (ws)
|
||||
ws->stop();
|
||||
if (dtls)
|
||||
dtls->stop();
|
||||
if (tcp)
|
||||
tcp->stop();
|
||||
|
||||
ws.reset();
|
||||
dtls.reset();
|
||||
tcp.reset();
|
||||
});
|
||||
t.detach();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
#endif
|
||||
|
@ -53,11 +53,12 @@ using std::to_string;
|
||||
using random_bytes_engine =
|
||||
std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>;
|
||||
|
||||
WsTransport::WsTransport(std::shared_ptr<TcpTransport> lower, string host, string path)
|
||||
: Transport(lower), mHost(std::move(host)), mPath(std::move(path)) {}
|
||||
WsTransport::WsTransport(std::shared_ptr<Transport> lower, string host, string path,
|
||||
state_callback callback)
|
||||
: Transport(lower, std::move(callback)), mHost(std::move(host)), mPath(std::move(path)) {
|
||||
|
||||
WsTransport::WsTransport(std::shared_ptr<TlsTransport> lower, string host, string path)
|
||||
: Transport(lower), mHost(std::move(host)), mPath(std::move(path)) {}
|
||||
registerIncoming();
|
||||
}
|
||||
|
||||
WsTransport::~WsTransport() {}
|
||||
|
||||
|
@ -31,8 +31,8 @@ class TlsTransport;
|
||||
|
||||
class WsTransport : public Transport {
|
||||
public:
|
||||
WsTransport(std::shared_ptr<TcpTransport> lower, string host, string path);
|
||||
WsTransport(std::shared_ptr<TlsTransport> lower, string host, string path);
|
||||
WsTransport(std::shared_ptr<Transport> lower, string host, string path,
|
||||
state_callback callback);
|
||||
~WsTransport();
|
||||
|
||||
void stop() override;
|
||||
|
Reference in New Issue
Block a user