Merge branch 'dev' of github.com:paullouisageneau/libdatachannel into dev

This commit is contained in:
Paul-Louis Ageneau
2020-05-25 14:03:11 +02:00
32 changed files with 2294 additions and 293 deletions

View File

@ -6,6 +6,7 @@ project (libdatachannel
option(USE_GNUTLS "Use GnuTLS instead of OpenSSL" OFF)
option(USE_JUICE "Use libjuice instead of libnice" OFF)
option(RTC_ENABLE_WEBSOCKET "Build WebSocket support" ON)
if(USE_GNUTLS)
option(USE_NETTLE "Use Nettle instead of OpenSSL in libjuice" ON)
@ -39,6 +40,14 @@ set(LIBDATACHANNEL_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp
)
set(LIBDATACHANNEL_WEBSOCKET_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp
)
set(LIBDATACHANNEL_HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/candidate.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/channel.hpp
@ -55,6 +64,7 @@ set(LIBDATACHANNEL_HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/reliability.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.h
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/websocket.hpp
)
set(TESTS_SOURCES
@ -89,26 +99,42 @@ endif()
add_library(Usrsctp::Usrsctp ALIAS usrsctp)
add_library(Usrsctp::UsrsctpStatic ALIAS usrsctp-static)
add_library(datachannel SHARED ${LIBDATACHANNEL_SOURCES})
if (RTC_ENABLE_WEBSOCKET)
add_library(datachannel SHARED
${LIBDATACHANNEL_SOURCES}
${LIBDATACHANNEL_WEBSOCKET_SOURCES})
add_library(datachannel-static STATIC EXCLUDE_FROM_ALL
${LIBDATACHANNEL_SOURCES}
${LIBDATACHANNEL_WEBSOCKET_SOURCES})
target_compile_definitions(datachannel PUBLIC RTC_ENABLE_WEBSOCKET=1)
target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_WEBSOCKET=1)
else()
add_library(datachannel SHARED
${LIBDATACHANNEL_SOURCES})
add_library(datachannel-static STATIC EXCLUDE_FROM_ALL
${LIBDATACHANNEL_SOURCES})
target_compile_definitions(datachannel PUBLIC RTC_ENABLE_WEBSOCKET=0)
target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_WEBSOCKET=0)
endif()
set_target_properties(datachannel PROPERTIES
VERSION ${PROJECT_VERSION}
CXX_STANDARD 17)
set_target_properties(datachannel-static PROPERTIES
VERSION ${PROJECT_VERSION}
CXX_STANDARD 17)
target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include)
target_link_libraries(datachannel Threads::Threads Usrsctp::UsrsctpStatic)
add_library(datachannel-static STATIC EXCLUDE_FROM_ALL ${LIBDATACHANNEL_SOURCES})
set_target_properties(datachannel-static PROPERTIES
VERSION ${PROJECT_VERSION}
CXX_STANDARD 17)
target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include)
target_link_libraries(datachannel Threads::Threads Usrsctp::UsrsctpStatic)
target_link_libraries(datachannel-static Threads::Threads Usrsctp::UsrsctpStatic)
if(WIN32)

View File

@ -10,6 +10,7 @@ lib libdatachannel
<cxxstd>17
<include>./include/rtc
<define>USE_JUICE=1
<define>RTC_ENABLE_WEBSOCKET=0
<library>/libdatachannel//usrsctp
<library>/libdatachannel//juice
<library>/libdatachannel//plog

View File

@ -38,6 +38,14 @@ else
LIBS+=glib-2.0 gobject-2.0 nice
endif
RTC_ENABLE_WEBSOCKET ?= 1
ifneq ($(RTC_ENABLE_WEBSOCKET), 0)
CPPFLAGS+=-DRTC_ENABLE_WEBSOCKET=1
else
CPPFLAGS+=-DRTC_ENABLE_WEBSOCKET=0
endif
INCLUDES+=$(shell pkg-config --cflags $(LIBS))
LDLIBS+=$(LOCALLIBS) $(shell pkg-config --libs $(LIBS))

View File

@ -82,7 +82,6 @@ private:
std::atomic<bool> mIsClosed = false;
Queue<message_ptr> mRecvQueue;
std::atomic<size_t> mRecvAmount = 0;
friend class PeerConnection;
};

View File

@ -19,6 +19,10 @@
#ifndef RTC_INCLUDE_H
#define RTC_INCLUDE_H
#ifndef RTC_ENABLE_WEBSOCKET
#define RTC_ENABLE_WEBSOCKET 1
#endif
#ifdef _WIN32
#ifndef _WIN32_WINNT
#define _WIN32_WINNT 0x0602
@ -56,10 +60,21 @@ const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not specified in SDP
const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size
// overloaded helper
template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;
// weak_ptr bind helper
template <typename F, typename T, typename... Args> auto weak_bind(F &&f, T *t, Args &&... _args) {
return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&... args) {
using result_type = typename decltype(bound)::result_type;
if (auto shared_this = weak_this.lock())
return bound(args...);
else
return (result_type) false;
};
}
template <typename... P> class synchronized_callback {
public:
synchronized_callback() = default;

View File

@ -30,6 +30,7 @@ namespace rtc {
struct Message : binary {
enum Type { Binary, String, Control, Reset };
Message(const Message &message) = default;
Message(size_t size, Type type_ = Binary) : binary(size), type(type_) {}
template <typename Iterator>

View File

@ -98,8 +98,6 @@ public:
std::string connectionInfo;
private:
init_token mInitToken = Init::Token();
std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
std::shared_ptr<DtlsTransport> initDtlsTransport();
std::shared_ptr<SctpTransport> initSctpTransport();
@ -130,6 +128,8 @@ private:
const Configuration mConfig;
const std::shared_ptr<Certificate> mCertificate;
init_token mInitToken = Init::Token();
std::optional<Description> mLocalDescription, mRemoteDescription;
mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;

View File

@ -44,6 +44,7 @@ public:
void push(T element);
std::optional<T> pop();
std::optional<T> peek();
std::optional<T> exchange(T element);
bool wait(const std::optional<std::chrono::milliseconds> &duration = nullopt);
private:
@ -118,6 +119,16 @@ template <typename T> std::optional<T> Queue<T>::peek() {
}
}
template <typename T> std::optional<T> Queue<T>::exchange(T element) {
std::unique_lock lock(mMutex);
if (!mQueue.empty()) {
std::swap(mQueue.front(), element);
return std::optional<T>{element};
} else {
return nullopt;
}
}
template <typename T>
bool Queue<T>::wait(const std::optional<std::chrono::milliseconds> &duration) {
std::unique_lock lock(mMutex);

View File

@ -27,6 +27,10 @@ extern "C" {
// libdatachannel C API
#ifndef RTC_ENABLE_WEBSOCKET
#define RTC_ENABLE_WEBSOCKET 1
#endif
typedef enum {
RTC_NEW = 0,
RTC_CONNECTING = 1,
@ -42,8 +46,7 @@ typedef enum {
RTC_GATHERING_COMPLETE = 2
} rtcGatheringState;
// Don't change, it must match plog severity
typedef enum {
typedef enum { // Don't change, it must match plog severity
RTC_LOG_NONE = 0,
RTC_LOG_FATAL = 1,
RTC_LOG_ERROR = 2,
@ -76,10 +79,10 @@ typedef void (*availableCallbackFunc)(void *ptr);
void rtcInitLogger(rtcLogLevel level);
// User pointer
void rtcSetUserPointer(int i, void *ptr);
void rtcSetUserPointer(int id, void *ptr);
// PeerConnection
int rtcCreatePeerConnection(const rtcConfiguration *config);
int rtcCreatePeerConnection(const rtcConfiguration *config); // returns pc id
int rtcDeletePeerConnection(int pc);
int rtcSetDataChannelCallback(int pc, dataChannelCallbackFunc cb);
@ -95,24 +98,32 @@ int rtcGetLocalAddress(int pc, char *buffer, int size);
int rtcGetRemoteAddress(int pc, char *buffer, int size);
// DataChannel
int rtcCreateDataChannel(int pc, const char *label);
int rtcCreateDataChannel(int pc, const char *label); // returns dc id
int rtcDeleteDataChannel(int dc);
int rtcGetDataChannelLabel(int dc, char *buffer, int size);
int rtcSetOpenCallback(int dc, openCallbackFunc cb);
int rtcSetClosedCallback(int dc, closedCallbackFunc cb);
int rtcSetErrorCallback(int dc, errorCallbackFunc cb);
int rtcSetMessageCallback(int dc, messageCallbackFunc cb);
int rtcSendMessage(int dc, const char *data, int size);
int rtcGetBufferedAmount(int dc); // total size buffered to send
int rtcSetBufferedAmountLowThreshold(int dc, int amount);
int rtcSetBufferedAmountLowCallback(int dc, bufferedAmountLowCallbackFunc cb);
// WebSocket
#if RTC_ENABLE_WEBSOCKET
int rtcCreateWebSocket(const char *url); // returns ws id
int rtcDeleteWebsocket(int ws);
#endif
// DataChannel extended API
int rtcGetAvailableAmount(int dc); // total size available to receive
int rtcSetAvailableCallback(int dc, availableCallbackFunc cb);
int rtcReceiveMessage(int dc, char *buffer, int *size);
// DataChannel and WebSocket common API
int rtcSetOpenCallback(int id, openCallbackFunc cb);
int rtcSetClosedCallback(int id, closedCallbackFunc cb);
int rtcSetErrorCallback(int id, errorCallbackFunc cb);
int rtcSetMessageCallback(int id, messageCallbackFunc cb);
int rtcSendMessage(int id, const char *data, int size);
int rtcGetBufferedAmount(int id); // total size buffered to send
int rtcSetBufferedAmountLowThreshold(int id, int amount);
int rtcSetBufferedAmountLowCallback(int id, bufferedAmountLowCallbackFunc cb);
// DataChannel and WebSocket common extended API
int rtcGetAvailableAmount(int id); // total size available to receive
int rtcSetAvailableCallback(int id, availableCallbackFunc cb);
int rtcReceiveMessage(int id, char *buffer, int *size);
// Cleanup
void rtcCleanup();

View File

@ -23,6 +23,7 @@
//
#include "datachannel.hpp"
#include "peerconnection.hpp"
#include "websocket.hpp"
// C API
#include "rtc.h"

95
include/rtc/websocket.hpp Normal file
View File

@ -0,0 +1,95 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#ifndef RTC_WEBSOCKET_H
#define RTC_WEBSOCKET_H
#if RTC_ENABLE_WEBSOCKET
#include "channel.hpp"
#include "include.hpp"
#include "init.hpp"
#include "message.hpp"
#include "queue.hpp"
#include <atomic>
#include <optional>
#include <thread>
#include <variant>
namespace rtc {
class TcpTransport;
class TlsTransport;
class WsTransport;
class WebSocket final : public Channel, public std::enable_shared_from_this<WebSocket> {
public:
enum class State : int {
Connecting = 0,
Open = 1,
Closing = 2,
Closed = 3,
};
WebSocket();
WebSocket(const string &url);
~WebSocket();
State readyState() const;
void open(const string &url);
void close() override;
bool send(const std::variant<binary, string> &data) override;
bool isOpen() const override;
bool isClosed() const override;
size_t maxMessageSize() const override;
// Extended API
std::optional<std::variant<binary, string>> receive() override;
size_t availableAmount() const override; // total size available to receive
private:
bool changeState(State state);
void remoteClose();
bool outgoing(mutable_message_ptr message);
void incoming(message_ptr message);
std::shared_ptr<TcpTransport> initTcpTransport();
std::shared_ptr<TlsTransport> initTlsTransport();
std::shared_ptr<WsTransport> initWsTransport();
void closeTransports();
init_token mInitToken = Init::Token();
std::shared_ptr<TcpTransport> mTcpTransport;
std::shared_ptr<TlsTransport> mTlsTransport;
std::shared_ptr<WsTransport> mWsTransport;
std::recursive_mutex mInitMutex;
string mScheme, mHost, mHostname, mService, mPath;
std::atomic<State> mState = State::Closed;
Queue<message_ptr> mRecvQueue;
};
} // namespace rtc
#endif
#endif // RTC_WEBSOCKET_H

65
src/base64.cpp Normal file
View File

@ -0,0 +1,65 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#if RTC_ENABLE_WEBSOCKET
#include "base64.hpp"
namespace rtc {
using std::to_integer;
string to_base64(const binary &data) {
static const char tab[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
string out;
out.reserve(3 * ((data.size() + 3) / 4));
int i = 0;
while (data.size() - i >= 3) {
auto d0 = to_integer<uint8_t>(data[i]);
auto d1 = to_integer<uint8_t>(data[i + 1]);
auto d2 = to_integer<uint8_t>(data[i + 2]);
out += tab[d0 >> 2];
out += tab[((d0 & 3) << 4) | (d1 >> 4)];
out += tab[((d1 & 0x0F) << 2) | (d2 >> 6)];
out += tab[d2 & 0x3F];
i += 3;
}
int left = data.size() - i;
if (left) {
auto d0 = to_integer<uint8_t>(data[i]);
out += tab[d0 >> 2];
if (left == 1) {
out += tab[(d0 & 3) << 4];
out += '=';
} else { // left == 2
auto d1 = to_integer<uint8_t>(data[i + 1]);
out += tab[((d0 & 3) << 4) | (d1 >> 4)];
out += tab[(d1 & 0x0F) << 2];
}
out += '=';
}
return out;
}
} // namespace rtc
#endif

34
src/base64.hpp Normal file
View File

@ -0,0 +1,34 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#ifndef RTC_BASE64_H
#define RTC_BASE64_H
#if RTC_ENABLE_WEBSOCKET
#include "include.hpp"
namespace rtc {
string to_base64(const binary &data);
}
#endif
#endif

View File

@ -214,6 +214,9 @@ bool DataChannel::outgoing(mutable_message_ptr message) {
}
void DataChannel::incoming(message_ptr message) {
if (!message)
return;
switch (message->type) {
case Message::Control: {
auto raw = reinterpret_cast<const uint8_t *>(message->data());

View File

@ -18,9 +18,7 @@
#include "dtlstransport.hpp"
#include "icetransport.hpp"
#include "message.hpp"
#include <cassert>
#include <chrono>
#include <cstring>
#include <exception>
@ -64,11 +62,9 @@ void DtlsTransport::Cleanup() {
}
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
verifier_callback verifierCallback,
state_callback stateChangeCallback)
: Transport(lower), mCertificate(certificate), mState(State::Disconnected),
mVerifierCallback(std::move(verifierCallback)),
mStateChangeCallback(std::move(stateChangeCallback)) {
verifier_callback verifierCallback, state_callback stateChangeCallback)
: Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
mVerifierCallback(std::move(verifierCallback)) {
PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
@ -76,13 +72,14 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
unsigned int flags = GNUTLS_DATAGRAM | (active ? GNUTLS_CLIENT : GNUTLS_SERVER);
check_gnutls(gnutls_init(&mSession, flags));
try {
// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
// Therefore, the DTLS layer MUST NOT use any compression algorithm.
// See https://tools.ietf.org/html/rfc8261#section-5
const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
const char *err_pos = NULL;
check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
"Unable to set TLS priorities");
"Failed to set TLS priorities");
gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback);
check_gnutls(
@ -101,6 +98,11 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
registerIncoming();
} catch (...) {
gnutls_deinit(mSession);
throw;
}
}
DtlsTransport::~DtlsTransport() {
@ -109,8 +111,6 @@ DtlsTransport::~DtlsTransport() {
gnutls_deinit(mSession);
}
DtlsTransport::State DtlsTransport::state() const { return mState; }
bool DtlsTransport::stop() {
if (!Transport::stop())
return false;
@ -122,7 +122,7 @@ bool DtlsTransport::stop() {
}
bool DtlsTransport::send(message_ptr message) {
if (!message || mState != State::Connected)
if (!message || state() != State::Connected)
return false;
PLOG_VERBOSE << "Send size=" << message->size();
@ -148,11 +148,6 @@ void DtlsTransport::incoming(message_ptr message) {
mIncomingQueue.push(message);
}
void DtlsTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
void DtlsTransport::runRecvLoop() {
const size_t maxMtu = 4096;
@ -169,7 +164,7 @@ void DtlsTransport::runRecvLoop() {
throw std::runtime_error("MTU is too low");
} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
!check_gnutls(ret, "TLS handshake failed"));
!check_gnutls(ret, "DTLS handshake failed"));
// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
// See https://tools.ietf.org/html/rfc8261#section-5
@ -183,7 +178,7 @@ void DtlsTransport::runRecvLoop() {
// Receive loop
try {
PLOG_INFO << "DTLS handshake done";
PLOG_INFO << "DTLS handshake finished";
changeState(State::Connected);
const size_t bufferSize = maxMtu;
@ -218,7 +213,7 @@ void DtlsTransport::runRecvLoop() {
gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
PLOG_INFO << "DTLS disconnected";
PLOG_INFO << "DTLS closed";
changeState(State::Disconnected);
recv(nullptr);
}
@ -341,7 +336,7 @@ void DtlsTransport::Init() {
if (!BioMethods) {
BioMethods = BIO_meth_new(BIO_TYPE_BIO, "DTLS writer");
if (!BioMethods)
throw std::runtime_error("Unable to BIO methods for DTLS writer");
throw std::runtime_error("Failed to create BIO methods for DTLS writer");
BIO_meth_set_create(BioMethods, BioMethodNew);
BIO_meth_set_destroy(BioMethods, BioMethodFree);
BIO_meth_set_write(BioMethods, BioMethodWrite);
@ -358,17 +353,17 @@ void DtlsTransport::Cleanup() {
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
verifier_callback verifierCallback, state_callback stateChangeCallback)
: Transport(lower), mCertificate(certificate), mState(State::Disconnected),
mVerifierCallback(std::move(verifierCallback)),
mStateChangeCallback(std::move(stateChangeCallback)) {
: Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
mVerifierCallback(std::move(verifierCallback)) {
PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
try {
if (!(mCtx = SSL_CTX_new(DTLS_method())))
throw std::runtime_error("Unable to create SSL context");
throw std::runtime_error("Failed to create SSL context");
check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
"Unable to set SSL priorities");
"Failed to set SSL priorities");
// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
// Therefore, the DTLS layer MUST NOT use any compression algorithm.
@ -389,7 +384,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
check_openssl(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed");
if (!(mSsl = SSL_new(mCtx)))
throw std::runtime_error("Unable to create SSL instance");
throw std::runtime_error("Failed to create SSL instance");
SSL_set_ex_data(mSsl, TransportExIndex, this);
@ -399,7 +394,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
SSL_set_accept_state(mSsl);
if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BioMethods)))
throw std::runtime_error("Unable to create BIO");
throw std::runtime_error("Failed to create BIO");
BIO_set_mem_eof_return(mInBio, BIO_EOF);
BIO_set_data(mOutBio, this);
@ -412,6 +407,14 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
registerIncoming();
} catch (...) {
if (mSsl)
SSL_free(mSsl);
if (mCtx)
SSL_CTX_free(mCtx);
throw;
}
}
DtlsTransport::~DtlsTransport() {
@ -432,18 +435,14 @@ bool DtlsTransport::stop() {
return true;
}
DtlsTransport::State DtlsTransport::state() const { return mState; }
bool DtlsTransport::send(message_ptr message) {
if (!message || mState != State::Connected)
if (!message || state() != State::Connected)
return false;
PLOG_VERBOSE << "Send size=" << message->size();
int ret = SSL_write(mSsl, message->data(), message->size());
if (!check_openssl_ret(mSsl, ret))
return false;
return true;
return check_openssl_ret(mSsl, ret);
}
void DtlsTransport::incoming(message_ptr message) {
@ -456,11 +455,6 @@ void DtlsTransport::incoming(message_ptr message) {
mIncomingQueue.push(message);
}
void DtlsTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
void DtlsTransport::runRecvLoop() {
const size_t maxMtu = 4096;
try {
@ -479,7 +473,7 @@ void DtlsTransport::runRecvLoop() {
auto message = *mIncomingQueue.pop();
BIO_write(mInBio, message->data(), message->size());
if (mState == State::Connecting) {
if (state() == State::Connecting) {
// Continue the handshake
int ret = SSL_do_handshake(mSsl);
if (!check_openssl_ret(mSsl, ret, "Handshake failed"))
@ -490,7 +484,7 @@ void DtlsTransport::runRecvLoop() {
// MTU See https://tools.ietf.org/html/rfc8261#section-5
SSL_set_mtu(mSsl, maxMtu + 1);
PLOG_INFO << "DTLS handshake done";
PLOG_INFO << "DTLS handshake finished";
changeState(State::Connected);
}
} else {
@ -504,7 +498,7 @@ void DtlsTransport::runRecvLoop() {
// No more messages pending, retransmit and rearm timeout if connecting
std::optional<milliseconds> duration;
if (mState == State::Connecting) {
if (state() == State::Connecting) {
// Warning: This function breaks the usual return value convention
int ret = DTLSv1_handle_timeout(mSsl);
if (ret < 0) {
@ -514,7 +508,7 @@ void DtlsTransport::runRecvLoop() {
}
struct timeval timeout = {};
if (mState == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
if (state() == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
// Also handle handshake timeout manually because OpenSSL actually doesn't...
// OpenSSL backs off exponentially in base 2 starting from the recommended 1s
@ -535,8 +529,8 @@ void DtlsTransport::runRecvLoop() {
PLOG_ERROR << "DTLS recv: " << e.what();
}
if (mState == State::Connected) {
PLOG_INFO << "DTLS disconnected";
if (state() == State::Connected) {
PLOG_INFO << "DTLS closed";
changeState(State::Disconnected);
recv(nullptr);
} else {

View File

@ -46,33 +46,25 @@ public:
static void Init();
static void Cleanup();
enum class State { Disconnected, Connecting, Connected, Failed };
using verifier_callback = std::function<bool(const std::string &fingerprint)>;
using state_callback = std::function<void(State state)>;
DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
verifier_callback verifierCallback, state_callback stateChangeCallback);
~DtlsTransport();
State state() const;
bool stop() override;
bool send(message_ptr message) override; // false if dropped
private:
void incoming(message_ptr message) override;
void changeState(State state);
void runRecvLoop();
const std::shared_ptr<Certificate> mCertificate;
Queue<message_ptr> mIncomingQueue;
std::atomic<State> mState;
std::thread mRecvThread;
verifier_callback mVerifierCallback;
state_callback mStateChangeCallback;
#if USE_GNUTLS
gnutls_session_t mSession;
@ -82,8 +74,8 @@ private:
static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
#else
SSL_CTX *mCtx;
SSL *mSsl;
SSL_CTX *mCtx = NULL;
SSL *mSsl = NULL;
BIO *mInBio, *mOutBio;
static BIO_METHOD *BioMethods;

View File

@ -48,9 +48,8 @@ namespace rtc {
IceTransport::IceTransport(const Configuration &config, Description::Role role,
candidate_callback candidateCallback, state_callback stateChangeCallback,
gathering_state_callback gatheringStateChangeCallback)
: mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
mCandidateCallback(std::move(candidateCallback)),
mStateChangeCallback(std::move(stateChangeCallback)),
: Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
mAgent(nullptr, nullptr) {
@ -84,6 +83,7 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
mStunService = server.service;
jconfig.stun_server_host = mStunHostname.c_str();
jconfig.stun_server_port = std::stoul(mStunService);
break;
}
}
@ -108,8 +108,6 @@ bool IceTransport::stop() {
Description::Role IceTransport::role() const { return mRole; }
IceTransport::State IceTransport::state() const { return mState; }
Description IceTransport::getLocalDescription(Description::Type type) const {
char sdp[JUICE_MAX_SDP_STRING_LEN];
if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0)
@ -161,7 +159,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
}
bool IceTransport::send(message_ptr message) {
if (!message || (mState != State::Connected && mState != State::Completed))
auto s = state();
if (!message || (s != State::Connected && s != State::Completed))
return false;
PLOG_VERBOSE << "Send size=" << message->size();
@ -173,18 +172,29 @@ bool IceTransport::outgoing(message_ptr message) {
message->size()) >= 0;
}
void IceTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(mState);
}
void IceTransport::changeGatheringState(GatheringState state) {
if (mGatheringState.exchange(state) != state)
mGatheringStateChangeCallback(mGatheringState);
}
void IceTransport::processStateChange(unsigned int state) {
changeState(static_cast<State>(state));
switch (state) {
case JUICE_STATE_DISCONNECTED:
changeState(State::Disconnected);
break;
case JUICE_STATE_CONNECTING:
changeState(State::Connecting);
break;
case JUICE_STATE_CONNECTED:
changeState(State::Connected);
break;
case JUICE_STATE_COMPLETED:
changeState(State::Completed);
break;
case JUICE_STATE_FAILED:
changeState(State::Failed);
break;
};
}
void IceTransport::processCandidate(const string &candidate) {
@ -263,9 +273,8 @@ namespace rtc {
IceTransport::IceTransport(const Configuration &config, Description::Role role,
candidate_callback candidateCallback, state_callback stateChangeCallback,
gathering_state_callback gatheringStateChangeCallback)
: mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
mCandidateCallback(std::move(candidateCallback)),
mStateChangeCallback(std::move(stateChangeCallback)),
: Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr) {
@ -457,8 +466,6 @@ bool IceTransport::stop() {
Description::Role IceTransport::role() const { return mRole; }
IceTransport::State IceTransport::state() const { return mState; }
Description IceTransport::getLocalDescription(Description::Type type) const {
// RFC 8445: The initiating agent that started the ICE processing MUST take the controlling
// role, and the other MUST take the controlled role.
@ -529,7 +536,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
}
bool IceTransport::send(message_ptr message) {
if (!message || (mState != State::Connected && mState != State::Completed))
auto s = state();
if (!message || (s != State::Connected && s != State::Completed))
return false;
PLOG_VERBOSE << "Send size=" << message->size();
@ -541,11 +549,6 @@ bool IceTransport::outgoing(message_ptr message) {
reinterpret_cast<const char *>(message->data())) >= 0;
}
void IceTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(mState);
}
void IceTransport::changeGatheringState(GatheringState state) {
if (mGatheringState.exchange(state) != state)
mGatheringStateChangeCallback(mGatheringState);
@ -576,7 +579,23 @@ void IceTransport::processStateChange(unsigned int state) {
mTimeoutId = 0;
}
changeState(static_cast<State>(state));
switch (state) {
case NICE_COMPONENT_STATE_DISCONNECTED:
changeState(State::Disconnected);
break;
case NICE_COMPONENT_STATE_CONNECTING:
changeState(State::Connecting);
break;
case NICE_COMPONENT_STATE_CONNECTED:
changeState(State::Connected);
break;
case NICE_COMPONENT_STATE_READY:
changeState(State::Completed);
break;
case NICE_COMPONENT_STATE_FAILED:
changeState(State::Failed);
break;
};
}
string IceTransport::AddressToString(const NiceAddress &addr) {

View File

@ -40,29 +40,9 @@ namespace rtc {
class IceTransport : public Transport {
public:
#if USE_JUICE
enum class State : unsigned int{
Disconnected = JUICE_STATE_DISCONNECTED,
Connecting = JUICE_STATE_CONNECTING,
Connected = JUICE_STATE_CONNECTED,
Completed = JUICE_STATE_COMPLETED,
Failed = JUICE_STATE_FAILED,
};
#else
enum class State : unsigned int {
Disconnected = NICE_COMPONENT_STATE_DISCONNECTED,
Connecting = NICE_COMPONENT_STATE_CONNECTING,
Connected = NICE_COMPONENT_STATE_CONNECTED,
Completed = NICE_COMPONENT_STATE_READY,
Failed = NICE_COMPONENT_STATE_FAILED,
};
bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
#endif
enum class GatheringState { New = 0, InProgress = 1, Complete = 2 };
using candidate_callback = std::function<void(const Candidate &candidate)>;
using state_callback = std::function<void(State state)>;
using gathering_state_callback = std::function<void(GatheringState state)>;
IceTransport(const Configuration &config, Description::Role role,
@ -71,7 +51,6 @@ public:
~IceTransport();
Description::Role role() const;
State state() const;
GatheringState gatheringState() const;
Description getLocalDescription(Description::Type type) const;
void setRemoteDescription(const Description &description);
@ -84,10 +63,13 @@ public:
bool stop() override;
bool send(message_ptr message) override; // false if dropped
#if !USE_JUICE
bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
#endif
private:
bool outgoing(message_ptr message) override;
void changeState(State state);
void changeGatheringState(GatheringState state);
void processStateChange(unsigned int state);
@ -98,11 +80,9 @@ private:
Description::Role mRole;
string mMid;
std::chrono::milliseconds mTrickleTimeout;
std::atomic<State> mState;
std::atomic<GatheringState> mGatheringState;
candidate_callback mCandidateCallback;
state_callback mStateChangeCallback;
gathering_state_callback mGatheringStateChangeCallback;
#if USE_JUICE

View File

@ -21,6 +21,10 @@
#include "dtlstransport.hpp"
#include "sctptransport.hpp"
#if RTC_ENABLE_WEBSOCKET
#include "tlstransport.hpp"
#endif
#ifdef _WIN32
#include <winsock2.h>
#endif
@ -69,13 +73,19 @@ Init::Init() {
ERR_load_crypto_strings();
#endif
DtlsTransport::Init();
SctpTransport::Init();
DtlsTransport::Init();
#if RTC_ENABLE_WEBSOCKET
TlsTransport::Init();
#endif
}
Init::~Init() {
DtlsTransport::Cleanup();
SctpTransport::Cleanup();
DtlsTransport::Cleanup();
#if RTC_ENABLE_WEBSOCKET
TlsTransport::Cleanup();
#endif
#ifdef _WIN32
WSACleanup();

View File

@ -23,7 +23,6 @@
#include "include.hpp"
#include "sctptransport.hpp"
#include <iostream>
#include <thread>
namespace rtc {
@ -33,23 +32,6 @@ using namespace std::placeholders;
using std::shared_ptr;
using std::weak_ptr;
template <typename F, typename T, typename... Args> auto weak_bind(F &&f, T *t, Args &&... _args) {
return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&... args) {
if (auto shared_this = weak_this.lock())
bound(args...);
};
}
template <typename F, typename T, typename... Args>
auto weak_bind_verifier(F &&f, T *t, Args &&... _args) {
return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&... args) {
if (auto shared_this = weak_this.lock())
return bound(args...);
else
return false;
};
}
PeerConnection::PeerConnection() : PeerConnection(Configuration()) {}
PeerConnection::PeerConnection(const Configuration &config)
@ -271,7 +253,7 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
auto lower = std::atomic_load(&mIceTransport);
auto transport = std::make_shared<DtlsTransport>(
lower, mCertificate, weak_bind_verifier(&PeerConnection::checkFingerprint, this, _1),
lower, mCertificate, weak_bind(&PeerConnection::checkFingerprint, this, _1),
[this, weak_this = weak_from_this()](DtlsTransport::State state) {
auto shared_this = weak_this.lock();
if (!shared_this)

View File

@ -16,10 +16,15 @@
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#include "datachannel.hpp"
#include "include.hpp"
#include "datachannel.hpp"
#include "peerconnection.hpp"
#if RTC_ENABLE_WEBSOCKET
#include "websocket.hpp"
#endif
#include <rtc.h>
#include <exception>
@ -43,6 +48,9 @@ namespace {
std::unordered_map<int, shared_ptr<PeerConnection>> peerConnectionMap;
std::unordered_map<int, shared_ptr<DataChannel>> dataChannelMap;
#if RTC_ENABLE_WEBSOCKET
std::unordered_map<int, shared_ptr<WebSocket>> webSocketMap;
#endif
std::unordered_map<int, void *> userPointerMap;
std::mutex mutex;
int lastId = 0;
@ -103,6 +111,40 @@ bool eraseDataChannel(int dc) {
return true;
}
#if RTC_ENABLE_WEBSOCKET
shared_ptr<WebSocket> getWebSocket(int id) {
std::lock_guard lock(mutex);
auto it = webSocketMap.find(id);
return it != webSocketMap.end() ? it->second : nullptr;
}
int emplaceWebSocket(shared_ptr<WebSocket> ptr) {
std::lock_guard lock(mutex);
int ws = ++lastId;
webSocketMap.emplace(std::make_pair(ws, ptr));
return ws;
}
bool eraseWebSocket(int ws) {
std::lock_guard lock(mutex);
if (webSocketMap.erase(ws) == 0)
return false;
userPointerMap.erase(ws);
return true;
}
#endif
shared_ptr<Channel> getChannel(int id) {
std::lock_guard lock(mutex);
if (auto it = dataChannelMap.find(id); it != dataChannelMap.end())
return it->second;
#if RTC_ENABLE_WEBSOCKET
if (auto it = webSocketMap.find(id); it != webSocketMap.end())
return it->second;
#endif
return nullptr;
}
} // namespace
void rtcInitLogger(rtcLogLevel level) { InitLogger(static_cast<LogLevel>(level)); }
@ -164,6 +206,29 @@ int rtcDeleteDataChannel(int dc) {
return 0;
}
#if RTC_ENABLE_WEBSOCKET
int rtcCreateWebSocket(const char *url) {
return emplaceWebSocket(std::make_shared<WebSocket>(url));
}
int rtcDeleteWebsocket(int ws) {
auto webSocket = getWebSocket(ws);
if (!webSocket)
return -1;
webSocket->onOpen(nullptr);
webSocket->onClosed(nullptr);
webSocket->onError(nullptr);
webSocket->onMessage(nullptr);
webSocket->onBufferedAmountLow(nullptr);
webSocket->onAvailable(nullptr);
eraseWebSocket(ws);
return 0;
}
#endif
int rtcSetDataChannelCallback(int pc, dataChannelCallbackFunc cb) {
auto peerConnection = getPeerConnection(pc);
if (!peerConnection)
@ -298,135 +363,135 @@ int rtcGetDataChannelLabel(int dc, char *buffer, int size) {
return size + 1;
}
int rtcSetOpenCallback(int dc, openCallbackFunc cb) {
auto dataChannel = getDataChannel(dc);
if (!dataChannel)
int rtcSetOpenCallback(int id, openCallbackFunc cb) {
auto channel = getChannel(id);
if (!channel)
return -1;
if (cb)
dataChannel->onOpen([dc, cb]() { cb(getUserPointer(dc)); });
channel->onOpen([id, cb]() { cb(getUserPointer(id)); });
else
dataChannel->onOpen(nullptr);
channel->onOpen(nullptr);
return 0;
}
int rtcSetClosedCallback(int dc, closedCallbackFunc cb) {
auto dataChannel = getDataChannel(dc);
if (!dataChannel)
int rtcSetClosedCallback(int id, closedCallbackFunc cb) {
auto channel = getChannel(id);
if (!channel)
return -1;
if (cb)
dataChannel->onClosed([dc, cb]() { cb(getUserPointer(dc)); });
channel->onClosed([id, cb]() { cb(getUserPointer(id)); });
else
dataChannel->onClosed(nullptr);
channel->onClosed(nullptr);
return 0;
}
int rtcSetErrorCallback(int dc, errorCallbackFunc cb) {
auto dataChannel = getDataChannel(dc);
if (!dataChannel)
int rtcSetErrorCallback(int id, errorCallbackFunc cb) {
auto channel = getChannel(id);
if (!channel)
return -1;
if (cb)
dataChannel->onError(
[dc, cb](const string &error) { cb(error.c_str(), getUserPointer(dc)); });
channel->onError([id, cb](const string &error) { cb(error.c_str(), getUserPointer(id)); });
else
dataChannel->onError(nullptr);
channel->onError(nullptr);
return 0;
}
int rtcSetMessageCallback(int dc, messageCallbackFunc cb) {
auto dataChannel = getDataChannel(dc);
if (!dataChannel)
int rtcSetMessageCallback(int id, messageCallbackFunc cb) {
auto channel = getChannel(id);
if (!channel)
return -1;
if (cb)
dataChannel->onMessage(
[dc, cb](const binary &b) {
cb(reinterpret_cast<const char *>(b.data()), b.size(), getUserPointer(dc));
channel->onMessage(
[id, cb](const binary &b) {
cb(reinterpret_cast<const char *>(b.data()), b.size(), getUserPointer(id));
},
[dc, cb](const string &s) { cb(s.c_str(), -1, getUserPointer(dc)); });
[id, cb](const string &s) { cb(s.c_str(), -1, getUserPointer(id)); });
else
dataChannel->onMessage(nullptr);
channel->onMessage(nullptr);
return 0;
}
int rtcSendMessage(int dc, const char *data, int size) {
auto dataChannel = getDataChannel(dc);
if (!dataChannel)
int rtcSendMessage(int id, const char *data, int size) {
auto channel = getChannel(id);
if (!channel)
return -1;
if (size >= 0) {
auto b = reinterpret_cast<const byte *>(data);
CATCH(dataChannel->send(b, size));
CATCH(channel->send(binary(b, b + size)));
return size;
} else {
string s(data);
CATCH(dataChannel->send(s));
return s.size();
string str(data);
int len = str.size();
CATCH(channel->send(std::move(str)));
return len;
}
}
int rtcGetBufferedAmount(int dc) {
auto dataChannel = getDataChannel(dc);
if (!dataChannel)
int rtcGetBufferedAmount(int id) {
auto channel = getChannel(id);
if (!channel)
return -1;
CATCH(return int(dataChannel->bufferedAmount()));
CATCH(return int(channel->bufferedAmount()));
}
int rtcSetBufferedAmountLowThreshold(int dc, int amount) {
auto dataChannel = getDataChannel(dc);
if (!dataChannel)
int rtcSetBufferedAmountLowThreshold(int id, int amount) {
auto channel = getChannel(id);
if (!channel)
return -1;
CATCH(dataChannel->setBufferedAmountLowThreshold(size_t(amount)));
CATCH(channel->setBufferedAmountLowThreshold(size_t(amount)));
return 0;
}
int rtcSetBufferedAmountLowCallback(int dc, bufferedAmountLowCallbackFunc cb) {
auto dataChannel = getDataChannel(dc);
if (!dataChannel)
int rtcSetBufferedAmountLowCallback(int id, bufferedAmountLowCallbackFunc cb) {
auto channel = getChannel(id);
if (!channel)
return -1;
if (cb)
dataChannel->onBufferedAmountLow([dc, cb]() { cb(getUserPointer(dc)); });
channel->onBufferedAmountLow([id, cb]() { cb(getUserPointer(id)); });
else
dataChannel->onBufferedAmountLow(nullptr);
channel->onBufferedAmountLow(nullptr);
return 0;
}
int rtcGetAvailableAmount(int dc) {
auto dataChannel = getDataChannel(dc);
if (!dataChannel)
int rtcGetAvailableAmount(int id) {
auto channel = getChannel(id);
if (!channel)
return -1;
CATCH(return int(dataChannel->availableAmount()));
CATCH(return int(channel->availableAmount()));
}
int rtcSetAvailableCallback(int dc, availableCallbackFunc cb) {
auto dataChannel = getDataChannel(dc);
if (!dataChannel)
int rtcSetAvailableCallback(int id, availableCallbackFunc cb) {
auto channel = getChannel(id);
if (!channel)
return -1;
if (cb)
dataChannel->onOpen([dc, cb]() { cb(getUserPointer(dc)); });
channel->onOpen([id, cb]() { cb(getUserPointer(id)); });
else
dataChannel->onOpen(nullptr);
channel->onOpen(nullptr);
return 0;
}
int rtcReceiveMessage(int dc, char *buffer, int *size) {
auto dataChannel = getDataChannel(dc);
if (!dataChannel)
int rtcReceiveMessage(int id, char *buffer, int *size) {
auto channel = getChannel(id);
if (!channel)
return -1;
if (!size)
return -1;
CATCH({
auto message = dataChannel->receive();
auto message = channel->receive();
if (!message)
return 0;

View File

@ -71,9 +71,8 @@ void SctpTransport::Cleanup() {
SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
message_callback recvCallback, amount_callback bufferedAmountCallback,
state_callback stateChangeCallback)
: Transport(lower), mPort(port), mSendQueue(0, message_size_func),
mBufferedAmountCallback(std::move(bufferedAmountCallback)),
mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) {
: Transport(lower, std::move(stateChangeCallback)), mPort(port),
mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
onRecv(recvCallback);
PLOG_DEBUG << "Initializing SCTP transport";
@ -180,8 +179,6 @@ SctpTransport::~SctpTransport() {
usrsctp_deregister_address(this);
}
SctpTransport::State SctpTransport::state() const { return mState; }
bool SctpTransport::stop() {
if (!Transport::stop())
return false;
@ -240,6 +237,7 @@ void SctpTransport::shutdown() {
bool SctpTransport::send(message_ptr message) {
std::lock_guard lock(mSendMutex);
if (!message)
return mSendQueue.empty();
@ -269,7 +267,7 @@ void SctpTransport::incoming(message_ptr message) {
// to be sent on our side (i.e. the local INIT) before proceeding.
{
std::unique_lock lock(mWriteMutex);
mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || mState != State::Connected; });
mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || state() != State::Connected; });
}
if (!message) {
@ -283,11 +281,6 @@ void SctpTransport::incoming(message_ptr message) {
usrsctp_conninput(this, message->data(), message->size(), 0);
}
void SctpTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
bool SctpTransport::trySendQueue() {
// Requires mSendMutex to be locked
while (auto next = mSendQueue.peek()) {
@ -302,7 +295,7 @@ bool SctpTransport::trySendQueue() {
bool SctpTransport::trySendMessage(message_ptr message) {
// Requires mSendMutex to be locked
if (!mSock || mState != State::Connected)
if (!mSock || state() != State::Connected)
return false;
uint32_t ppid;
@ -414,7 +407,7 @@ void SctpTransport::sendReset(uint16_t streamId) {
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) {
std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp...
mWrittenCondition.wait_for(lock, 1000ms,
[&]() { return mWritten || mState != State::Connected; });
[&]() { return mWritten || state() != State::Connected; });
} else if (errno == EINVAL) {
PLOG_VERBOSE << "SCTP stream " << streamId << " already reset";
} else {
@ -571,7 +564,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
PLOG_INFO << "SCTP connected";
changeState(State::Connected);
} else {
if (mState == State::Connecting) {
if (state() == State::Connecting) {
PLOG_ERROR << "SCTP connection failed";
changeState(State::Failed);
} else {

View File

@ -38,17 +38,12 @@ public:
static void Init();
static void Cleanup();
enum class State { Disconnected, Connecting, Connected, Failed };
using amount_callback = std::function<void(uint16_t streamId, size_t amount)>;
using state_callback = std::function<void(State state)>;
SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recvCallback,
amount_callback bufferedAmountCallback, state_callback stateChangeCallback);
~SctpTransport();
State state() const;
bool stop() override;
bool send(message_ptr message) override; // false if buffered
void close(unsigned int stream);
@ -76,7 +71,6 @@ private:
void connect();
void shutdown();
void incoming(message_ptr message) override;
void changeState(State state);
bool trySendQueue();
bool trySendMessage(message_ptr message);
@ -105,14 +99,11 @@ private:
std::atomic<bool> mWritten = false; // written outside lock
bool mWrittenOnce = false;
state_callback mStateChangeCallback;
std::atomic<State> mState;
binary mPartialRecv, mPartialStringData, mPartialBinaryData;
// Stats
std::atomic<size_t> mBytesSent = 0, mBytesReceived = 0;
binary mPartialRecv, mPartialStringData, mPartialBinaryData;
static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
struct sctp_rcvinfo recv_info, int flags, void *user_data);
static int SendCallback(struct socket *sock, uint32_t sb_free);

320
src/tcptransport.cpp Normal file
View File

@ -0,0 +1,320 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#if RTC_ENABLE_WEBSOCKET
#include "tcptransport.hpp"
#include <exception>
#ifndef _WIN32
#include <fcntl.h>
#include <unistd.h>
#endif
namespace rtc {
using std::to_string;
SelectInterrupter::SelectInterrupter() {
#ifndef _WIN32
int pipefd[2];
if (::pipe(pipefd) != 0)
throw std::runtime_error("Failed to create pipe");
::fcntl(pipefd[0], F_SETFL, O_NONBLOCK);
::fcntl(pipefd[1], F_SETFL, O_NONBLOCK);
mPipeOut = pipefd[0]; // read
mPipeIn = pipefd[1]; // write
#endif
}
SelectInterrupter::~SelectInterrupter() {
std::lock_guard lock(mMutex);
#ifdef _WIN32
if (mDummySock != INVALID_SOCKET)
::closesocket(mDummySock);
#else
::close(mPipeIn);
::close(mPipeOut);
#endif
}
int SelectInterrupter::prepare(fd_set &readfds, fd_set &writefds) {
std::lock_guard lock(mMutex);
#ifdef _WIN32
if (mDummySock == INVALID_SOCKET)
mDummySock = ::socket(AF_INET, SOCK_DGRAM, 0);
FD_SET(mDummySock, &readfds);
return SOCK_TO_INT(mDummySock) + 1;
#else
int ret;
do {
char dummy;
ret = ::read(mPipeIn, &dummy, 1);
} while (ret > 0);
FD_SET(mPipeIn, &readfds);
return mPipeIn + 1;
#endif
}
void SelectInterrupter::interrupt() {
std::lock_guard lock(mMutex);
#ifdef _WIN32
if (mDummySock != INVALID_SOCKET) {
::closesocket(mDummySock);
mDummySock = INVALID_SOCKET;
}
#else
char dummy = 0;
::write(mPipeOut, &dummy, 1);
#endif
}
TcpTransport::TcpTransport(const string &hostname, const string &service, state_callback callback)
: Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) {
PLOG_DEBUG << "Initializing TCP transport";
mThread = std::thread(&TcpTransport::runLoop, this);
}
TcpTransport::~TcpTransport() {
stop();
}
bool TcpTransport::stop() {
if (!Transport::stop())
return false;
PLOG_DEBUG << "Waiting TCP recv thread";
close();
mThread.join();
return true;
}
bool TcpTransport::send(message_ptr message) {
if (!message)
return mSendQueue.empty();
PLOG_VERBOSE << "Send size=" << (message ? message->size() : 0);
return outgoing(message);
}
void TcpTransport::incoming(message_ptr message) { recv(message); }
bool TcpTransport::outgoing(message_ptr message) {
// If nothing is pending, try to send directly
// It's safe because if the queue is empty, the thread is not sending
if (mSendQueue.empty() && trySendMessage(message))
return true;
mSendQueue.push(message);
interruptSelect(); // so the thread waits for writability
return false;
}
void TcpTransport::connect(const string &hostname, const string &service) {
PLOG_DEBUG << "Connecting to " << hostname << ":" << service;
struct addrinfo hints = {};
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
hints.ai_flags = AI_ADDRCONFIG;
struct addrinfo *result = nullptr;
if (getaddrinfo(hostname.c_str(), service.c_str(), &hints, &result))
throw std::runtime_error("Resolution failed for \"" + hostname + ":" + service + "\"");
for (auto p = result; p; p = p->ai_next)
try {
connect(p->ai_addr, p->ai_addrlen);
freeaddrinfo(result);
return;
} catch (const std::runtime_error &e) {
PLOG_WARNING << e.what();
}
freeaddrinfo(result);
throw std::runtime_error("Connection failed to \"" + hostname + ":" + service + "\"");
}
void TcpTransport::connect(const sockaddr *addr, socklen_t addrlen) {
try {
PLOG_DEBUG << "Creating TCP socket";
// Create socket
mSock = ::socket(addr->sa_family, SOCK_STREAM, IPPROTO_TCP);
if (mSock == INVALID_SOCKET)
throw std::runtime_error("TCP socket creation failed");
ctl_t b = 1;
if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
throw std::runtime_error("Failed to set socket non-blocking mode");
IF_PLOG(plog::debug) {
char node[MAX_NUMERICNODE_LEN];
char serv[MAX_NUMERICSERV_LEN];
if (getnameinfo(addr, addrlen, node, MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
NI_NUMERICHOST | NI_NUMERICSERV) == 0) {
PLOG_DEBUG << "Trying address " << node << ":" << serv;
}
}
// Initiate connection
::connect(mSock, addr, addrlen);
fd_set writefds;
FD_ZERO(&writefds);
FD_SET(mSock, &writefds);
struct timeval tv;
tv.tv_sec = 10; // TODO
tv.tv_usec = 0;
int ret = ::select(SOCKET_TO_INT(mSock) + 1, NULL, &writefds, NULL, &tv);
if (ret < 0)
throw std::runtime_error("Failed to wait for socket connection");
if (ret == 0 || ::send(mSock, NULL, 0, MSG_NOSIGNAL) != 0)
throw std::runtime_error("Connection failed");
} catch (...) {
if (mSock != INVALID_SOCKET) {
::closesocket(mSock);
mSock = INVALID_SOCKET;
}
throw;
}
}
void TcpTransport::close() {
if (mSock != INVALID_SOCKET) {
PLOG_DEBUG << "Closing TCP socket";
::closesocket(mSock);
mSock = INVALID_SOCKET;
}
changeState(State::Disconnected);
}
bool TcpTransport::trySendQueue() {
while (auto next = mSendQueue.peek()) {
auto message = *next;
if (!trySendMessage(message)) {
mSendQueue.exchange(message);
return false;
}
mSendQueue.pop();
}
return true;
}
bool TcpTransport::trySendMessage(message_ptr &message) {
auto data = reinterpret_cast<const char *>(message->data());
auto size = message->size();
while (size) {
int len = ::send(mSock, data, size, MSG_NOSIGNAL);
if (len < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
message = make_message(message->end() - size, message->end());
return false;
} else {
throw std::runtime_error("Connection lost, errno=" + to_string(sockerrno));
}
}
data += len;
size -= len;
}
message = nullptr;
return true;
}
void TcpTransport::runLoop() {
const size_t bufferSize = 4096;
// Connect
try {
changeState(State::Connecting);
connect(mHostname, mService);
} catch (const std::exception &e) {
PLOG_ERROR << "TCP connect: " << e.what();
changeState(State::Failed);
return;
}
// Receive loop
try {
PLOG_INFO << "TCP connected";
changeState(State::Connected);
while (true) {
fd_set readfds, writefds;
int n = prepareSelect(readfds, writefds);
int ret = ::select(n, &readfds, &writefds, NULL, NULL);
if (ret < 0)
throw std::runtime_error("Failed to wait on socket");
if (FD_ISSET(mSock, &writefds))
trySendQueue();
if (FD_ISSET(mSock, &readfds)) {
char buffer[bufferSize];
int len = ::recv(mSock, buffer, bufferSize, 0);
if (len < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
continue;
} else {
throw std::runtime_error("Connection lost, errno=" + to_string(sockerrno));
}
}
if (len == 0)
break; // clean close
auto *b = reinterpret_cast<byte *>(buffer);
incoming(make_message(b, b + len));
}
}
} catch (const std::exception &e) {
PLOG_ERROR << "TCP recv: " << e.what();
}
PLOG_INFO << "TCP disconnected";
changeState(State::Disconnected);
recv(nullptr);
}
int TcpTransport::prepareSelect(fd_set &readfds, fd_set &writefds) {
FD_ZERO(&readfds);
FD_ZERO(&writefds);
FD_SET(mSock, &readfds);
if (!mSendQueue.empty())
FD_SET(mSock, &writefds);
int n = SOCKET_TO_INT(mSock) + 1;
int m = mInterrupter.prepare(readfds, writefds);
return std::max(n, m);
}
void TcpTransport::interruptSelect() { mInterrupter.interrupt(); }
} // namespace rtc
#endif

90
src/tcptransport.hpp Normal file
View File

@ -0,0 +1,90 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#ifndef RTC_TCP_TRANSPORT_H
#define RTC_TCP_TRANSPORT_H
#if RTC_ENABLE_WEBSOCKET
#include "include.hpp"
#include "queue.hpp"
#include "transport.hpp"
#include <mutex>
#include <thread>
// Use the socket defines from libjuice
#include "../deps/libjuice/src/socket.h"
namespace rtc {
// Utility class to interrupt select()
class SelectInterrupter {
public:
SelectInterrupter();
~SelectInterrupter();
int prepare(fd_set &readfds, fd_set &writefds);
void interrupt();
private:
std::mutex mMutex;
#ifdef _WIN32
socket_t mDummySock = INVALID_SOCKET;
#else // assume POSIX
int mPipeIn, mPipeOut;
#endif
};
class TcpTransport : public Transport {
public:
TcpTransport(const string &hostname, const string &service, state_callback callback);
~TcpTransport();
bool stop() override;
bool send(message_ptr message) override;
void incoming(message_ptr message) override;
bool outgoing(message_ptr message) override;
private:
void connect(const string &hostname, const string &service);
void connect(const sockaddr *addr, socklen_t addrlen);
void close();
bool trySendQueue();
bool trySendMessage(message_ptr &message);
void runLoop();
int prepareSelect(fd_set &readfds, fd_set &writefds);
void interruptSelect();
string mHostname, mService;
socket_t mSock = INVALID_SOCKET;
std::thread mThread;
SelectInterrupter mInterrupter;
Queue<message_ptr> mSendQueue;
};
} // namespace rtc
#endif
#endif

432
src/tlstransport.cpp Normal file
View File

@ -0,0 +1,432 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#if RTC_ENABLE_WEBSOCKET
#include "tlstransport.hpp"
#include "tcptransport.hpp"
#include <chrono>
#include <cstring>
#include <exception>
#include <iostream>
using namespace std::chrono;
using std::shared_ptr;
using std::string;
using std::unique_ptr;
using std::weak_ptr;
#if USE_GNUTLS
namespace {
static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
if (ret < 0) {
if (!gnutls_error_is_fatal(ret)) {
PLOG_INFO << gnutls_strerror(ret);
return false;
}
PLOG_ERROR << message << ": " << gnutls_strerror(ret);
throw std::runtime_error(message + ": " + gnutls_strerror(ret));
}
return true;
}
} // namespace
namespace rtc {
void TlsTransport::Init() {
// Nothing to do
}
void TlsTransport::Cleanup() {
// Nothing to do
}
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
: Transport(lower, std::move(callback)) {
PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
check_gnutls(gnutls_init(&mSession, GNUTLS_CLIENT));
try {
const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
const char *err_pos = NULL;
check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
"Failed to set TLS priorities");
gnutls_session_set_ptr(mSession, this);
gnutls_transport_set_ptr(mSession, this);
gnutls_transport_set_push_function(mSession, WriteCallback);
gnutls_transport_set_pull_function(mSession, ReadCallback);
gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size());
mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
registerIncoming();
} catch (...) {
gnutls_deinit(mSession);
throw;
}
}
TlsTransport::~TlsTransport() {
stop();
gnutls_deinit(mSession);
}
bool TlsTransport::stop() {
if (!Transport::stop())
return false;
PLOG_DEBUG << "Stopping TLS recv thread";
mIncomingQueue.stop();
mRecvThread.join();
return true;
}
bool TlsTransport::send(message_ptr message) {
if (!message)
return false;
PLOG_VERBOSE << "Send size=" << message->size();
ssize_t ret;
do {
ret = gnutls_record_send(mSession, message->data(), message->size());
} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
return check_gnutls(ret);
}
void TlsTransport::incoming(message_ptr message) {
if (message)
mIncomingQueue.push(message);
else
mIncomingQueue.stop();
}
void TlsTransport::runRecvLoop() {
const size_t bufferSize = 4096;
char buffer[bufferSize];
// Handshake loop
try {
changeState(State::Connecting);
int ret;
do {
ret = gnutls_handshake(mSession);
} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
!check_gnutls(ret, "TLS handshake failed"));
} catch (const std::exception &e) {
PLOG_ERROR << "TLS handshake: " << e.what();
changeState(State::Failed);
return;
}
// Receive loop
try {
PLOG_INFO << "TLS handshake finished";
changeState(State::Connected);
while (true) {
ssize_t ret;
do {
ret = gnutls_record_recv(mSession, buffer, bufferSize);
} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
// Consider premature termination as remote closing
if (ret == GNUTLS_E_PREMATURE_TERMINATION) {
PLOG_DEBUG << "TLS connection terminated";
break;
}
if (check_gnutls(ret)) {
if (ret == 0) {
// Closed
PLOG_DEBUG << "TLS connection cleanly closed";
break;
}
auto *b = reinterpret_cast<byte *>(buffer);
recv(make_message(b, b + ret));
}
}
} catch (const std::exception &e) {
PLOG_ERROR << "TLS recv: " << e.what();
}
gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
PLOG_INFO << "TLS closed";
changeState(State::Disconnected);
recv(nullptr);
}
ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
TlsTransport *t = static_cast<TlsTransport *>(ptr);
if (len > 0) {
auto b = reinterpret_cast<const byte *>(data);
t->outgoing(make_message(b, b + len));
}
gnutls_transport_set_errno(t->mSession, 0);
return ssize_t(len);
}
ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
TlsTransport *t = static_cast<TlsTransport *>(ptr);
if (auto next = t->mIncomingQueue.pop()) {
auto message = *next;
ssize_t len = std::min(maxlen, message->size());
std::memcpy(data, message->data(), len);
gnutls_transport_set_errno(t->mSession, 0);
return len;
}
// Closed
gnutls_transport_set_errno(t->mSession, 0);
return 0;
}
int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
TlsTransport *t = static_cast<TlsTransport *>(ptr);
if (ms != GNUTLS_INDEFINITE_TIMEOUT)
t->mIncomingQueue.wait(milliseconds(ms));
else
t->mIncomingQueue.wait();
return !t->mIncomingQueue.empty() ? 1 : 0;
}
} // namespace rtc
#else // USE_GNUTLS==0
#include <openssl/bio.h>
#include <openssl/ec.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
namespace {
const int BIO_EOF = -1;
string openssl_error_string(unsigned long err) {
const size_t bufferSize = 256;
char buffer[bufferSize];
ERR_error_string_n(err, buffer, bufferSize);
return string(buffer);
}
bool check_openssl(int success, const string &message = "OpenSSL error") {
if (success)
return true;
string str = openssl_error_string(ERR_get_error());
PLOG_ERROR << message << ": " << str;
throw std::runtime_error(message + ": " + str);
}
bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error") {
if (ret == BIO_EOF)
return true;
unsigned long err = SSL_get_error(ssl, ret);
if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
return true;
}
if (err == SSL_ERROR_ZERO_RETURN) {
PLOG_DEBUG << "TLS connection cleanly closed";
return false;
}
string str = openssl_error_string(err);
PLOG_ERROR << str;
throw std::runtime_error(message + ": " + str);
}
} // namespace
namespace rtc {
int TlsTransport::TransportExIndex = -1;
void TlsTransport::Init() {
if (TransportExIndex < 0) {
TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
}
}
void TlsTransport::Cleanup() {
// Nothing to do
}
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
: Transport(lower, std::move(callback)) {
PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
throw std::runtime_error("Failed to create SSL context");
check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
"Failed to set SSL priorities");
SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION);
SSL_CTX_set_read_ahead(mCtx, 1);
SSL_CTX_set_quiet_shutdown(mCtx, 1);
SSL_CTX_set_info_callback(mCtx, InfoCallback);
SSL_CTX_set_default_verify_paths(mCtx);
SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER, NULL);
SSL_CTX_set_verify_depth(mCtx, 4);
if (!(mSsl = SSL_new(mCtx)))
throw std::runtime_error("Failed to create SSL instance");
SSL_set_ex_data(mSsl, TransportExIndex, this);
SSL_set_tlsext_host_name(mSsl, host.c_str());
SSL_set_connect_state(mSsl);
if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem())))
throw std::runtime_error("Failed to create BIO");
BIO_set_mem_eof_return(mInBio, BIO_EOF);
BIO_set_mem_eof_return(mOutBio, BIO_EOF);
SSL_set_bio(mSsl, mInBio, mOutBio);
auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
SSL_set_tmp_ecdh(mSsl, ecdh.get());
mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
}
TlsTransport::~TlsTransport() {
stop();
SSL_free(mSsl);
SSL_CTX_free(mCtx);
}
bool TlsTransport::stop() {
if (!Transport::stop())
return false;
PLOG_DEBUG << "Stopping TLS recv thread";
mIncomingQueue.stop();
mRecvThread.join();
SSL_shutdown(mSsl);
return true;
}
bool TlsTransport::send(message_ptr message) {
if (!message)
return false;
int ret = SSL_write(mSsl, message->data(), message->size());
if (!check_openssl_ret(mSsl, ret))
return false;
const size_t bufferSize = 4096;
byte buffer[bufferSize];
while (int len = BIO_read(mOutBio, buffer, bufferSize))
outgoing(make_message(buffer, buffer + len));
return true;
}
void TlsTransport::incoming(message_ptr message) {
if (message)
mIncomingQueue.push(message);
else
mIncomingQueue.stop();
}
void TlsTransport::runRecvLoop() {
const size_t bufferSize = 4096;
byte buffer[bufferSize];
try {
changeState(State::Connecting);
SSL_do_handshake(mSsl);
while (int len = BIO_read(mOutBio, buffer, bufferSize))
outgoing(make_message(buffer, buffer + len));
while (auto next = mIncomingQueue.pop()) {
message_ptr message = *next;
message_ptr decrypted;
BIO_write(mInBio, message->data(), message->size());
int ret = SSL_read(mSsl, buffer, bufferSize);
if (!check_openssl_ret(mSsl, ret))
break;
if (ret > 0)
decrypted = make_message(buffer, buffer + ret);
while (int len = BIO_read(mOutBio, buffer, bufferSize))
outgoing(make_message(buffer, buffer + len));
if (state() == State::Connecting && SSL_is_init_finished(mSsl)) {
PLOG_INFO << "TLS handshake finished";
changeState(State::Connected);
}
if (decrypted)
recv(decrypted);
}
} catch (const std::exception &e) {
PLOG_ERROR << "TLS recv: " << e.what();
}
if (state() == State::Connected) {
PLOG_INFO << "TLS closed";
recv(nullptr);
} else {
PLOG_ERROR << "TLS handshake failed";
}
}
void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
TlsTransport *t =
static_cast<TlsTransport *>(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex));
if (where & SSL_CB_ALERT) {
if (ret != 256) { // Close Notify
PLOG_ERROR << "TLS alert: " << SSL_alert_desc_string_long(ret);
}
t->mIncomingQueue.stop(); // Close the connection
}
}
} // namespace rtc
#endif
#endif

83
src/tlstransport.hpp Normal file
View File

@ -0,0 +1,83 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#ifndef RTC_TLS_TRANSPORT_H
#define RTC_TLS_TRANSPORT_H
#if RTC_ENABLE_WEBSOCKET
#include "include.hpp"
#include "queue.hpp"
#include "transport.hpp"
#include <memory>
#include <mutex>
#include <thread>
#if USE_GNUTLS
#include <gnutls/gnutls.h>
#else
#include <openssl/ssl.h>
#endif
namespace rtc {
class TcpTransport;
class TlsTransport : public Transport {
public:
static void Init();
static void Cleanup();
TlsTransport(std::shared_ptr<TcpTransport> lower, string host, state_callback callback);
~TlsTransport();
bool stop() override;
bool send(message_ptr message) override;
void incoming(message_ptr message) override;
protected:
void runRecvLoop();
Queue<message_ptr> mIncomingQueue;
std::thread mRecvThread;
#if USE_GNUTLS
gnutls_session_t mSession;
static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
#else
SSL_CTX *mCtx;
SSL *mSsl;
BIO *mInBio, *mOutBio;
static int TransportExIndex;
static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
static void InfoCallback(const SSL *ssl, int where, int ret);
#endif
};
} // namespace rtc
#endif
#endif

View File

@ -32,7 +32,13 @@ using namespace std::placeholders;
class Transport {
public:
Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {}
enum class State { Disconnected, Connecting, Connected, Completed, Failed };
using state_callback = std::function<void(State state)>;
Transport(std::shared_ptr<Transport> lower = nullptr, state_callback callback = nullptr)
: mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) {
}
virtual ~Transport() {
stop();
if (mLower)
@ -49,11 +55,16 @@ public:
}
void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
State state() const { return mState; }
virtual bool send(message_ptr message) { return outgoing(message); }
protected:
void recv(message_ptr message) { mRecvCallback(message); }
void changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
virtual void incoming(message_ptr message) { recv(message); }
virtual bool outgoing(message_ptr message) {
@ -65,7 +76,10 @@ protected:
private:
std::shared_ptr<Transport> mLower;
synchronized_callback<State> mStateChangeCallback;
synchronized_callback<message_ptr> mRecvCallback;
std::atomic<State> mState = State::Disconnected;
std::atomic<bool> mShutdown = false;
};

311
src/websocket.cpp Normal file
View File

@ -0,0 +1,311 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#if RTC_ENABLE_WEBSOCKET
#include "include.hpp"
#include "websocket.hpp"
#include "tcptransport.hpp"
#include "tlstransport.hpp"
#include "wstransport.hpp"
#include <regex>
#ifdef _WIN32
#include <winsock2.h>
#endif
namespace rtc {
WebSocket::WebSocket() {}
WebSocket::WebSocket(const string &url) : WebSocket() { open(url); }
WebSocket::~WebSocket() { remoteClose(); }
WebSocket::State WebSocket::readyState() const { return mState; }
void WebSocket::open(const string &url) {
if (mState != State::Closed)
throw std::runtime_error("WebSocket must be closed before opening");
static const char *rs = R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)";
static std::regex regex(rs, std::regex::extended);
std::smatch match;
if (!std::regex_match(url, match, regex))
throw std::invalid_argument("Malformed WebSocket URL: " + url);
mScheme = match[2];
if (mScheme != "ws" && mScheme != "wss")
throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme);
mHost = match[4];
if (auto pos = mHost.find(':'); pos != string::npos) {
mHostname = mHost.substr(0, pos);
mService = mHost.substr(pos + 1);
} else {
mHostname = mHost;
mService = mScheme == "ws" ? "80" : "443";
}
mPath = match[5];
if (string query = match[7]; !query.empty())
mPath += "?" + query;
changeState(State::Connecting);
initTcpTransport();
}
void WebSocket::close() {
auto state = mState.load();
if (state == State::Connecting || state == State::Open) {
changeState(State::Closing);
if (auto transport = std::atomic_load(&mWsTransport))
transport->close();
else
changeState(State::Closed);
}
}
void WebSocket::remoteClose() {
close();
closeTransports();
}
bool WebSocket::send(const std::variant<binary, string> &data) {
return std::visit(
[&](const auto &d) {
using T = std::decay_t<decltype(d)>;
constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
auto *b = reinterpret_cast<const byte *>(d.data());
return outgoing(std::make_shared<Message>(b, b + d.size(), type));
},
data);
}
bool WebSocket::isOpen() const { return mState == State::Open; }
bool WebSocket::isClosed() const { return mState == State::Closed; }
size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
std::optional<std::variant<binary, string>> WebSocket::receive() {
while (!mRecvQueue.empty()) {
auto message = *mRecvQueue.pop();
switch (message->type) {
case Message::String:
return std::make_optional(
string(reinterpret_cast<const char *>(message->data()), message->size()));
case Message::Binary:
return std::make_optional(std::move(*message));
default:
// Ignore
break;
}
}
return nullopt;
}
size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); }
bool WebSocket::changeState(State state) { return mState.exchange(state) != state; }
bool WebSocket::outgoing(mutable_message_ptr message) {
if (mState != State::Open || !mWsTransport)
throw std::runtime_error("WebSocket is not open");
if (message->size() > maxMessageSize())
throw std::runtime_error("Message size exceeds limit");
return mWsTransport->send(message);
}
void WebSocket::incoming(message_ptr message) {
if (message->type == Message::String || message->type == Message::Binary) {
mRecvQueue.push(message);
triggerAvailable(mRecvQueue.size());
}
}
std::shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
using State = TcpTransport::State;
try {
std::lock_guard lock(mInitMutex);
if (auto transport = std::atomic_load(&mTcpTransport))
return transport;
auto transport = std::make_shared<TcpTransport>(
mHostname, mService, [this, weak_this = weak_from_this()](State state) {
auto shared_this = weak_this.lock();
if (!shared_this)
return;
switch (state) {
case State::Connected:
if (mScheme == "ws")
initWsTransport();
else
initTlsTransport();
break;
case State::Failed:
triggerError("TCP connection failed");
remoteClose();
break;
case State::Disconnected:
remoteClose();
break;
default:
// Ignore
break;
}
});
std::atomic_store(&mTcpTransport, transport);
if (mState == WebSocket::State::Closed) {
mTcpTransport.reset();
transport->stop();
throw std::runtime_error("Connection is closed");
}
return transport;
} catch (const std::exception &e) {
PLOG_ERROR << e.what();
remoteClose();
throw std::runtime_error("TCP transport initialization failed");
}
}
std::shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
using State = TlsTransport::State;
try {
std::lock_guard lock(mInitMutex);
if (auto transport = std::atomic_load(&mTlsTransport))
return transport;
auto lower = std::atomic_load(&mTcpTransport);
auto transport = std::make_shared<TlsTransport>(
lower, mHost, [this, weak_this = weak_from_this()](State state) {
auto shared_this = weak_this.lock();
if (!shared_this)
return;
switch (state) {
case State::Connected:
initWsTransport();
break;
case State::Failed:
triggerError("TCP connection failed");
remoteClose();
break;
case State::Disconnected:
remoteClose();
break;
default:
// Ignore
break;
}
});
std::atomic_store(&mTlsTransport, transport);
if (mState == WebSocket::State::Closed) {
mTlsTransport.reset();
transport->stop();
throw std::runtime_error("Connection is closed");
}
return transport;
} catch (const std::exception &e) {
PLOG_ERROR << e.what();
remoteClose();
throw std::runtime_error("TLS transport initialization failed");
}
}
std::shared_ptr<WsTransport> WebSocket::initWsTransport() {
using State = WsTransport::State;
try {
std::lock_guard lock(mInitMutex);
if (auto transport = std::atomic_load(&mWsTransport))
return transport;
std::shared_ptr<Transport> lower = std::atomic_load(&mTlsTransport);
if (!lower)
lower = std::atomic_load(&mTcpTransport);
auto transport = std::make_shared<WsTransport>(
lower, mHost, mPath, weak_bind(&WebSocket::incoming, this, _1),
[this, weak_this = weak_from_this()](State state) {
auto shared_this = weak_this.lock();
if (!shared_this)
return;
switch (state) {
case State::Connected:
if (mState == WebSocket::State::Connecting) {
PLOG_DEBUG << "WebSocket open";
changeState(WebSocket::State::Open);
triggerOpen();
}
break;
case State::Failed:
triggerError("WebSocket connection failed");
remoteClose();
break;
case State::Disconnected:
remoteClose();
break;
default:
// Ignore
break;
}
});
std::atomic_store(&mWsTransport, transport);
if (mState == WebSocket::State::Closed) {
mWsTransport.reset();
transport->stop();
throw std::runtime_error("Connection is closed");
}
return transport;
} catch (const std::exception &e) {
PLOG_ERROR << e.what();
remoteClose();
throw std::runtime_error("WebSocket transport initialization failed");
}
}
void WebSocket::closeTransports() {
changeState(State::Closed);
// Pass the references to a thread, allowing to terminate a transport from its own thread
auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
if (ws || tls || tcp) {
std::thread t([ws, tls, tcp]() mutable {
if (ws)
ws->stop();
if (tls)
tls->stop();
if (tcp)
tcp->stop();
ws.reset();
tls.reset();
tcp.reset();
});
t.detach();
}
}
} // namespace rtc
#endif

372
src/wstransport.cpp Normal file
View File

@ -0,0 +1,372 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#if RTC_ENABLE_WEBSOCKET
#include "wstransport.hpp"
#include "tcptransport.hpp"
#include "tlstransport.hpp"
#include "base64.hpp"
#include <chrono>
#include <list>
#include <map>
#include <random>
#include <regex>
#ifdef _WIN32
#include <winsock2.h>
#else
#include <arpa/inet.h>
#endif
#ifndef htonll
#define htonll(x) \
((uint64_t)htonl(((uint64_t)(x)&0xFFFFFFFF) << 32) | (uint64_t)htonl((uint64_t)(x) >> 32))
#endif
#ifndef ntohll
#define ntohll(x) htonll(x)
#endif
namespace rtc {
using namespace std::chrono;
using std::to_integer;
using std::to_string;
using random_bytes_engine =
std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>;
WsTransport::WsTransport(std::shared_ptr<Transport> lower, string host, string path,
message_callback recvCallback, state_callback stateCallback)
: Transport(lower, std::move(stateCallback)), mHost(std::move(host)), mPath(std::move(path)) {
onRecv(recvCallback);
PLOG_DEBUG << "Initializing WebSocket transport";
registerIncoming();
sendHttpRequest();
}
WsTransport::~WsTransport() { stop(); }
bool WsTransport::stop() {
if (!Transport::stop())
return false;
close();
return true;
}
bool WsTransport::send(message_ptr message) {
if (!message)
return false;
// Call the mutable message overload with a copy
return send(std::make_shared<Message>(*message));
}
bool WsTransport::send(mutable_message_ptr message) {
if (!message || state() != State::Connected)
return false;
PLOG_VERBOSE << "Send size=" << message->size();
return sendFrame({message->type == Message::String ? TEXT_FRAME : BINARY_FRAME, message->data(),
message->size(), true, true});
}
void WsTransport::incoming(message_ptr message) {
try {
mBuffer.insert(mBuffer.end(), message->begin(), message->end());
if (state() == State::Connecting) {
if (size_t len = readHttpResponse(mBuffer.data(), mBuffer.size())) {
mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
PLOG_INFO << "WebSocket open";
changeState(State::Connected);
}
}
if (state() == State::Connected) {
Frame frame = {};
while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
recvFrame(frame);
}
}
} catch (const std::exception &e) {
PLOG_ERROR << e.what();
}
if (state() == State::Connected) {
PLOG_INFO << "WebSocket disconnected";
changeState(State::Disconnected);
recv(nullptr);
} else {
PLOG_ERROR << "WebSocket handshake failed";
changeState(State::Failed);
}
}
void WsTransport::close() {
if (state() == State::Connected) {
sendFrame({CLOSE, NULL, 0, true, true});
PLOG_INFO << "WebSocket closing";
changeState(State::Completed);
}
}
bool WsTransport::sendHttpRequest() {
changeState(State::Connecting);
auto seed = system_clock::now().time_since_epoch().count();
random_bytes_engine generator(seed);
binary key(16);
std::generate(reinterpret_cast<uint8_t *>(key.data()),
reinterpret_cast<uint8_t *>(key.data() + key.size()), generator);
const string request = "GET " + mPath +
" HTTP/1.1\r\n"
"Host: " +
mHost +
"\r\n"
"Connection: Upgrade\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Version: 13\r\n"
"Sec-WebSocket-Key: " +
to_base64(key) +
"\r\n"
"\r\n";
auto data = reinterpret_cast<const byte *>(request.data());
auto size = request.size();
return outgoing(make_message(data, data + size));
}
size_t WsTransport::readHttpResponse(const byte *buffer, size_t size) {
std::list<string> lines;
auto begin = reinterpret_cast<const char *>(buffer);
auto end = begin + size;
auto cur = begin;
while (true) {
auto last = cur;
cur = std::find(cur, end, '\n');
if (cur == end)
return 0;
string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
if (line.empty())
break;
lines.emplace_back(std::move(line));
}
size_t length = cur - begin;
if (lines.empty())
throw std::runtime_error("Invalid HTTP response for WebSocket");
string status = std::move(lines.front());
lines.pop_front();
std::istringstream ss(status);
string protocol;
unsigned int code = 0;
ss >> protocol >> code;
PLOG_DEBUG << "WebSocket response code: " << code;
if (code != 101)
throw std::runtime_error("Unexpected response code for WebSocket: " + to_string(code));
std::multimap<string, string> headers;
for (const auto &line : lines) {
if (size_t pos = line.find_first_of(':'); pos != string::npos) {
string key = line.substr(0, pos);
string value = line.substr(line.find_first_not_of(' ', pos + 1));
std::transform(key.begin(), key.end(), key.begin(),
[](char c) { return std::tolower(c); });
headers.emplace(std::move(key), std::move(value));
} else {
headers.emplace(line, "");
}
}
auto h = headers.find("upgrade");
if (h == headers.end() || h->second != "websocket")
throw std::runtime_error("WebSocket update header missing or mismatching");
h = headers.find("sec-websocket-accept");
if (h == headers.end())
throw std::runtime_error("WebSocket accept header missing");
// TODO: Verify Sec-WebSocket-Accept
return length;
}
// http://tools.ietf.org/html/rfc6455#section-5.2 Base Framing Protocol
//
// 0 1 2 3
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-------+-+-------------+-------------------------------+
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
// |I|S|S|S| (4) |A| (7) | (16/64) |
// |N|V|V|V| |S| | (if payload len==126/127) |
// | |1|2|3| |K| | |
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
// | Extended payload length continued, if payload len == 127 |
// + - - - - - - - - - - - - - - - +-------------------------------+
// | | Masking-key, if MASK set to 1 |
// +-------------------------------+-------------------------------+
// | Masking-key (continued) | Payload Data |
// +-------------------------------+ - - - - - - - - - - - - - - - +
// : Payload Data continued ... :
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
// | Payload Data continued ... |
// +---------------------------------------------------------------+
size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
const byte *end = buffer + size;
if (end - buffer < 2)
return 0;
byte *cur = buffer;
auto b1 = to_integer<uint8_t>(*cur++);
auto b2 = to_integer<uint8_t>(*cur++);
frame.fin = (b1 & 0x80) != 0;
frame.mask = (b2 & 0x80) != 0;
frame.opcode = static_cast<Opcode>(b1 & 0x0F);
frame.length = b2 & 0x7F;
if (frame.length == 0x7E) {
if (end - cur < 2)
return 0;
frame.length = ntohs(*reinterpret_cast<const uint16_t *>(cur));
cur += 2;
} else if (frame.length == 0x7F) {
if (end - cur < 8)
return false;
frame.length = ntohll(*reinterpret_cast<const uint64_t *>(cur));
cur += 8;
}
const byte *maskingKey = nullptr;
if (frame.mask) {
if (end - cur < 4)
return 0;
maskingKey = cur;
cur += 4;
}
if (end - cur < frame.length)
return false;
frame.payload = cur;
if (maskingKey)
for (size_t i = 0; i < frame.length; ++i)
frame.payload[i] ^= maskingKey[i % 4];
return end - buffer;
}
void WsTransport::recvFrame(const Frame &frame) {
switch (frame.opcode) {
case TEXT_FRAME:
case BINARY_FRAME: {
if (!mPartial.empty()) {
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
recv(make_message(mPartial.begin(), mPartial.end(), type));
mPartial.clear();
}
if (frame.fin) {
auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary;
recv(make_message(frame.payload, frame.payload + frame.length));
} else {
mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
mPartialOpcode = frame.opcode;
}
break;
}
case CONTINUATION: {
mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
if (frame.fin) {
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
recv(make_message(mPartial.begin(), mPartial.end()));
mPartial.clear();
}
break;
}
case PING: {
sendFrame({PONG, frame.payload, frame.length, true, true});
break;
}
case PONG: {
// TODO
break;
}
case CLOSE: {
close();
PLOG_INFO << "WebSocket closed";
changeState(State::Disconnected);
break;
}
default: {
close();
throw std::invalid_argument("Unknown WebSocket opcode: " + to_string(frame.opcode));
}
}
}
bool WsTransport::sendFrame(const Frame &frame) {
byte buffer[14];
byte *cur = buffer;
*cur++ = byte((frame.opcode & 0x0F) | (frame.fin ? 0x80 : 0));
if (frame.length < 0x7E) {
*cur++ = byte((frame.length & 0x7F) | (frame.mask ? 0x80 : 0));
} else if (frame.length <= 0xFF) {
*cur++ = byte(0x7E | (frame.mask ? 0x80 : 0));
*reinterpret_cast<uint16_t *>(cur) = uint16_t(frame.length);
cur += 2;
} else {
*cur++ = byte(0x7F | (frame.mask ? 0x80 : 0));
*reinterpret_cast<uint64_t *>(cur) = uint64_t(frame.length);
cur += 8;
}
if (frame.mask) {
auto seed = system_clock::now().time_since_epoch().count();
random_bytes_engine generator(seed);
auto *maskingKey = cur;
std::generate(reinterpret_cast<uint8_t *>(maskingKey),
reinterpret_cast<uint8_t *>(maskingKey + 4), generator);
cur += 4;
for (size_t i = 0; i < frame.length; ++i)
frame.payload[i] ^= maskingKey[i % 4];
}
outgoing(make_message(buffer, cur)); // header
return outgoing(make_message(frame.payload, frame.payload + frame.length)); // payload
}
} // namespace rtc
#endif

83
src/wstransport.hpp Normal file
View File

@ -0,0 +1,83 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#ifndef RTC_WS_TRANSPORT_H
#define RTC_WS_TRANSPORT_H
#if RTC_ENABLE_WEBSOCKET
#include "include.hpp"
#include "transport.hpp"
namespace rtc {
class TcpTransport;
class TlsTransport;
class WsTransport : public Transport {
public:
WsTransport(std::shared_ptr<Transport> lower, string host, string path,
message_callback recvCallback, state_callback stateCallback);
~WsTransport();
bool stop() override;
bool send(message_ptr message) override;
bool send(mutable_message_ptr message);
void incoming(message_ptr message) override;
void close();
private:
enum Opcode : uint8_t {
CONTINUATION = 0,
TEXT_FRAME = 1,
BINARY_FRAME = 2,
CLOSE = 8,
PING = 9,
PONG = 10,
};
struct Frame {
Opcode opcode = BINARY_FRAME;
byte *payload = nullptr;
size_t length = 0;
bool fin = true;
bool mask = true;
};
bool sendHttpRequest();
size_t readHttpResponse(const byte *buffer, size_t size);
size_t readFrame(byte *buffer, size_t size, Frame &frame);
void recvFrame(const Frame &frame);
bool sendFrame(const Frame &frame);
const string mHost;
const string mPath;
binary mBuffer;
binary mPartial;
Opcode mPartialOpcode;
};
} // namespace rtc
#endif
#endif

View File

@ -25,19 +25,19 @@ void test_capi();
int main(int argc, char **argv) {
try {
std::cout << "*** Running connectivity test..." << std::endl;
cout << endl << "*** Running connectivity test..." << endl;
test_connectivity();
std::cout << "*** Finished connectivity test" << std::endl;
cout << "*** Finished connectivity test" << endl;
} catch (const exception &e) {
std::cerr << "Connectivity test failed: " << e.what() << endl;
cerr << "Connectivity test failed: " << e.what() << endl;
return -1;
}
try {
std::cout << "*** Running C API test..." << std::endl;
cout << endl << "*** Running C API test..." << endl;
test_capi();
std::cout << "*** Finished C API test" << std::endl;
cout << "*** Finished C API test" << endl;
} catch (const exception &e) {
std::cerr << "C API test failed: " << e.what() << endl;
cerr << "C API test failed: " << e.what() << endl;
return -1;
}
return 0;