Compare commits

...

13 Commits

19 changed files with 371 additions and 148 deletions

View File

@ -1,7 +1,7 @@
cmake_minimum_required (VERSION 3.7)
project (libdatachannel
DESCRIPTION "WebRTC DataChannels Library"
VERSION 0.4.8
VERSION 0.4.9
LANGUAGES CXX)
option(USE_GNUTLS "Use GnuTLS instead of OpenSSL" OFF)
@ -31,11 +31,31 @@ set(LIBDATACHANNEL_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/src/description.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/dtlstransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/icetransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/init.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/log.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/peerconnection.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/rtc.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp
)
set(LIBDATACHANNEL_HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/candidate.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/channel.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/configuration.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/configuration.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/datachannel.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/description.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/include.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/init.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/log.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/message.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/peerconnection.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/queue.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/reliability.hpp
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.h
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.hpp
)
set(TESTS_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/test/main.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/connectivity.cpp
@ -54,6 +74,10 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
add_subdirectory(deps/usrsctp EXCLUDE_FROM_ALL)
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU")
target_compile_options(usrsctp PRIVATE -Wno-error=format-truncation)
target_compile_options(usrsctp-static PRIVATE -Wno-error=format-truncation)
endif()
add_library(Usrsctp::Usrsctp ALIAS usrsctp)
add_library(Usrsctp::UsrsctpStatic ALIAS usrsctp-static)
@ -123,6 +147,9 @@ endif()
add_library(LibDataChannel::LibDataChannel ALIAS datachannel)
add_library(LibDataChannel::LibDataChannelStatic ALIAS datachannel-static)
install(TARGETS datachannel LIBRARY DESTINATION lib)
install(FILES ${LIBDATACHANNEL_HEADERS} DESTINATION include/rtc)
# Main Test
add_executable(datachannel-tests ${TESTS_SOURCES})
set_target_properties(datachannel-tests PROPERTIES

View File

@ -5,7 +5,7 @@ CXX=$(CROSS)g++
AR=$(CROSS)ar
RM=rm -f
CXXFLAGS=-std=c++17
CPPFLAGS=-O2 -pthread -fPIC -Wall -Wno-address-of-packed-member
CPPFLAGS=-O2 -pthread -fPIC -Wall
LDFLAGS=-pthread
LIBS=
LOCALLIBS=libusrsctp.a
@ -86,7 +86,7 @@ dist-clean: clean
libusrsctp.a:
cd $(USRSCTP_DIR) && \
./bootstrap && \
./configure --enable-static --disable-debug CFLAGS="$(CPPFLAGS)" && \
./configure --enable-static --disable-debug CFLAGS="$(CPPFLAGS) -Wno-error=format-truncation" && \
make
cp $(USRSCTP_DIR)/usrsctplib/.libs/libusrsctp.a .

2
deps/libjuice vendored

View File

@ -38,9 +38,9 @@ class PeerConnection;
class DataChannel : public std::enable_shared_from_this<DataChannel>, public Channel {
public:
DataChannel(std::shared_ptr<PeerConnection> pc, unsigned int stream, string label,
DataChannel(std::weak_ptr<PeerConnection> pc, unsigned int stream, string label,
string protocol, Reliability reliability);
DataChannel(std::shared_ptr<PeerConnection> pc, std::shared_ptr<SctpTransport> transport,
DataChannel(std::weak_ptr<PeerConnection> pc, std::shared_ptr<SctpTransport> transport,
unsigned int stream);
~DataChannel();
@ -70,7 +70,7 @@ private:
void incoming(message_ptr message);
void processOpenMessage(message_ptr message);
const std::shared_ptr<PeerConnection> mPeerConnection;
const std::weak_ptr<PeerConnection> mPeerConnection;
std::shared_ptr<SctpTransport> mSctpTransport;
unsigned int mStream;

50
include/rtc/init.hpp Normal file
View File

@ -0,0 +1,50 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#ifndef RTC_INIT_H
#define RTC_INIT_H
#include "include.hpp"
#include <mutex>
namespace rtc {
class Init;
using init_token = std::shared_ptr<Init>;
class Init {
public:
static init_token Token();
static void Cleanup();
~Init();
private:
Init();
static std::weak_ptr<Init> Weak;
static init_token Global;
static std::mutex Mutex;
};
inline void Cleanup() { Init::Cleanup(); }
} // namespace rtc
#endif

View File

@ -19,9 +19,7 @@
#ifndef RTC_LOG_H
#define RTC_LOG_H
#include "plog/Appenders/ColorConsoleAppender.h"
#include "plog/Log.h"
#include "plog/Logger.h"
namespace rtc {
@ -35,21 +33,8 @@ enum class LogLevel { // Don't change, it must match plog severity
Verbose = 6
};
inline void InitLogger(plog::Severity severity, plog::IAppender *appender = nullptr) {
static plog::ColorConsoleAppender<plog::TxtFormatter> consoleAppender;
static plog::Logger<0> *logger = nullptr;
if (!logger) {
logger = &plog::init(severity, appender ? appender : &consoleAppender);
PLOG_DEBUG << "Logger initialized";
} else {
logger->setMaxSeverity(severity);
if (appender)
logger->addAppender(appender);
}
}
inline void InitLogger(LogLevel level) { InitLogger(static_cast<plog::Severity>(level)); }
void InitLogger(LogLevel level);
void InitLogger(plog::Severity severity, plog::IAppender *appender = nullptr);
}
#endif

View File

@ -24,6 +24,7 @@
#include "datachannel.hpp"
#include "description.hpp"
#include "include.hpp"
#include "init.hpp"
#include "message.hpp"
#include "reliability.hpp"
#include "rtc.hpp"
@ -88,6 +89,8 @@ public:
void onGatheringStateChange(std::function<void(GatheringState state)> callback);
private:
init_token mInitToken = Init::Token();
std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
std::shared_ptr<DtlsTransport> initDtlsTransport();
std::shared_ptr<SctpTransport> initSctpTransport();

View File

@ -41,12 +41,10 @@ public:
bool empty() const;
size_t size() const; // elements
size_t amount() const; // amount
void push(const T &element);
void push(T &&element);
void push(T element);
std::optional<T> pop();
std::optional<T> peek();
void wait();
void wait(const std::chrono::milliseconds &duration);
bool wait(const std::optional<std::chrono::milliseconds> &duration = nullopt);
private:
const size_t mLimit;
@ -88,9 +86,7 @@ template <typename T> size_t Queue<T>::amount() const {
return mAmount;
}
template <typename T> void Queue<T>::push(const T &element) { push(T{element}); }
template <typename T> void Queue<T>::push(T &&element) {
template <typename T> void Queue<T>::push(T element) {
std::unique_lock lock(mMutex);
mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; });
if (!mStopping) {
@ -122,14 +118,14 @@ template <typename T> std::optional<T> Queue<T>::peek() {
}
}
template <typename T> void Queue<T>::wait() {
template <typename T>
bool Queue<T>::wait(const std::optional<std::chrono::milliseconds> &duration) {
std::unique_lock lock(mMutex);
mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
}
template <typename T> void Queue<T>::wait(const std::chrono::milliseconds &duration) {
std::unique_lock lock(mMutex);
mPopCondition.wait_for(lock, duration, [this]() { return !mQueue.empty() || mStopping; });
if (duration)
mPopCondition.wait_for(lock, *duration, [this]() { return !mQueue.empty() || mStopping; });
else
mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
return !mStopping;
}
} // namespace rtc

View File

@ -17,8 +17,11 @@
*/
// C++ API
#include "datachannel.hpp"
#include "include.hpp"
#include "init.hpp" // for rtc::Cleanup()
#include "log.hpp"
//
#include "datachannel.hpp"
#include "peerconnection.hpp"
// C API

View File

@ -245,13 +245,6 @@ shared_ptr<Certificate> make_certificate(const string &commonName) {
if (auto it = cache.find(commonName); it != cache.end())
return it->second;
if (cache.empty()) {
// This is the first call to OpenSSL
OPENSSL_init_ssl(0, NULL);
SSL_load_error_strings();
ERR_load_crypto_strings();
}
shared_ptr<X509> x509(X509_new(), X509_free);
shared_ptr<EVP_PKEY> pkey(EVP_PKEY_new(), EVP_PKEY_free);

View File

@ -30,6 +30,7 @@
namespace rtc {
using std::shared_ptr;
using std::weak_ptr;
// Messages for the DataChannel establishment protocol
// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09
@ -66,16 +67,16 @@ struct CloseMessage {
const size_t RECV_QUEUE_LIMIT = 1024 * 1024; // 1 MiB
DataChannel::DataChannel(shared_ptr<PeerConnection> pc, unsigned int stream, string label,
DataChannel::DataChannel(weak_ptr<PeerConnection> pc, unsigned int stream, string label,
string protocol, Reliability reliability)
: mPeerConnection(std::move(pc)), mStream(stream), mLabel(std::move(label)),
: mPeerConnection(pc), mStream(stream), mLabel(std::move(label)),
mProtocol(std::move(protocol)),
mReliability(std::make_shared<Reliability>(std::move(reliability))),
mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
DataChannel::DataChannel(shared_ptr<PeerConnection> pc, shared_ptr<SctpTransport> transport,
DataChannel::DataChannel(weak_ptr<PeerConnection> pc, shared_ptr<SctpTransport> transport,
unsigned int stream)
: mPeerConnection(std::move(pc)), mSctpTransport(transport), mStream(stream),
: mPeerConnection(pc), mSctpTransport(transport), mStream(stream),
mReliability(std::make_shared<Reliability>()),
mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
@ -147,9 +148,10 @@ bool DataChannel::isClosed(void) const { return mIsClosed; }
size_t DataChannel::maxMessageSize() const {
size_t max = DEFAULT_MAX_MESSAGE_SIZE;
if (auto description = mPeerConnection->remoteDescription())
if (auto maxMessageSize = description->maxMessageSize())
return *maxMessageSize > 0 ? *maxMessageSize : LOCAL_MAX_MESSAGE_SIZE;
if (auto pc = mPeerConnection.lock())
if (auto description = pc->remoteDescription())
if (auto maxMessageSize = description->maxMessageSize())
return *maxMessageSize > 0 ? *maxMessageSize : LOCAL_MAX_MESSAGE_SIZE;
return std::min(max, LOCAL_MAX_MESSAGE_SIZE);
}

View File

@ -55,6 +55,14 @@ static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
namespace rtc {
void DtlsTransport::Init() {
// Nothing to do
}
void DtlsTransport::Cleanup() {
// Nothing to do
}
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
verifier_callback verifierCallback,
state_callback stateChangeCallback)
@ -131,10 +139,13 @@ bool DtlsTransport::send(message_ptr message) {
}
void DtlsTransport::incoming(message_ptr message) {
if (message)
mIncomingQueue.push(message);
else
if (!message) {
mIncomingQueue.stop();
return;
}
PLOG_VERBOSE << "Incoming size=" << message->size();
mIncomingQueue.push(message);
}
void DtlsTransport::changeState(State state) {
@ -262,10 +273,8 @@ ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size
int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
if (ms != GNUTLS_INDEFINITE_TIMEOUT)
t->mIncomingQueue.wait(milliseconds(ms));
else
t->mIncomingQueue.wait();
t->mIncomingQueue.wait(ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms))
: nullopt);
return !t->mIncomingQueue.empty() ? 1 : 0;
}
@ -323,7 +332,7 @@ BIO_METHOD *DtlsTransport::BioMethods = NULL;
int DtlsTransport::TransportExIndex = -1;
std::mutex DtlsTransport::GlobalMutex;
void DtlsTransport::GlobalInit() {
void DtlsTransport::Init() {
std::lock_guard lock(GlobalMutex);
if (!BioMethods) {
BioMethods = BIO_meth_new(BIO_TYPE_BIO, "DTLS writer");
@ -339,6 +348,10 @@ void DtlsTransport::GlobalInit() {
}
}
void DtlsTransport::Cleanup() {
// Nothing to do
}
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
verifier_callback verifierCallback, state_callback stateChangeCallback)
: Transport(lower), mCertificate(certificate), mState(State::Disconnected),
@ -346,7 +359,6 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
mStateChangeCallback(std::move(stateChangeCallback)) {
PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
GlobalInit();
if (!(mCtx = SSL_CTX_new(DTLS_method())))
throw std::runtime_error("Unable to create SSL context");
@ -432,10 +444,13 @@ bool DtlsTransport::send(message_ptr message) {
}
void DtlsTransport::incoming(message_ptr message) {
if (message)
mIncomingQueue.push(message);
else
if (!message) {
mIncomingQueue.stop();
return;
}
PLOG_VERBOSE << "Incoming size=" << message->size();
mIncomingQueue.push(message);
}
void DtlsTransport::changeState(State state) {
@ -448,29 +463,47 @@ void DtlsTransport::runRecvLoop() {
try {
changeState(State::Connecting);
SSL_do_handshake(mSsl);
int ret = SSL_do_handshake(mSsl);
check_openssl_ret(mSsl, ret, "Handshake failed");
const size_t bufferSize = maxMtu;
byte buffer[bufferSize];
while (auto next = mIncomingQueue.pop()) {
auto message = *next;
BIO_write(mInBio, message->data(), message->size());
int ret = SSL_read(mSsl, buffer, bufferSize);
if (!check_openssl_ret(mSsl, ret))
break;
while (true) {
std::optional<milliseconds> duration;
struct timeval timeout = {};
if (DTLSv1_get_timeout(mSsl, &timeout))
duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
auto decrypted = ret > 0 ? make_message(buffer, buffer + ret) : nullptr;
if (!mIncomingQueue.wait(duration))
break; // queue is stopped
message_ptr decrypted;
if (!mIncomingQueue.empty()) {
auto message = *mIncomingQueue.pop();
BIO_write(mInBio, message->data(), message->size());
int ret = SSL_read(mSsl, buffer, bufferSize);
if (!check_openssl_ret(mSsl, ret))
break;
if (ret > 0)
decrypted = make_message(buffer, buffer + ret);
}
if (mState == State::Connecting) {
if (unsigned long err = ERR_get_error())
throw std::runtime_error("handshake failed: " + openssl_error_string(err));
if (SSL_is_init_finished(mSsl)) {
changeState(State::Connected);
// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
// See https://tools.ietf.org/html/rfc8261#section-5
SSL_set_mtu(mSsl, maxMtu + 1);
} else {
// Continue the handshake
int ret = SSL_do_handshake(mSsl);
if (!check_openssl_ret(mSsl, ret, "Handshake failed"))
break;
DTLSv1_handle_timeout(mSsl);
}
}
@ -486,7 +519,7 @@ void DtlsTransport::runRecvLoop() {
changeState(State::Disconnected);
recv(nullptr);
} else {
PLOG_INFO << "DTLS handshake failed";
PLOG_ERROR << "DTLS handshake failed";
changeState(State::Failed);
}
}

View File

@ -43,6 +43,9 @@ class IceTransport;
class DtlsTransport : public Transport {
public:
static void Init();
static void Cleanup();
enum class State { Disconnected, Connecting, Connected, Failed };
using verifier_callback = std::function<bool(const std::string &fingerprint)>;
@ -87,7 +90,6 @@ private:
static int TransportExIndex;
static std::mutex GlobalMutex;
static void GlobalInit();
static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
static void InfoCallback(const SSL *ssl, int where, int ret);

View File

@ -162,7 +162,10 @@ bool IceTransport::send(message_ptr message) {
return outgoing(message);
}
void IceTransport::incoming(message_ptr message) { recv(message); }
void IceTransport::incoming(message_ptr message) {
PLOG_VERBOSE << "Incoming size=" << message->size();
recv(message);
}
void IceTransport::incoming(const byte *data, int size) {
incoming(make_message(data, data + size));
@ -516,7 +519,10 @@ bool IceTransport::send(message_ptr message) {
return outgoing(message);
}
void IceTransport::incoming(message_ptr message) { recv(message); }
void IceTransport::incoming(message_ptr message) {
PLOG_VERBOSE << "Incoming size=" << message->size();
recv(message);
}
void IceTransport::incoming(const byte *data, int size) {
incoming(make_message(data, data + size));

86
src/init.cpp Normal file
View File

@ -0,0 +1,86 @@
/**
* Copyright (c) 2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#include "init.hpp"
#include "dtlstransport.hpp"
#include "sctptransport.hpp"
#ifdef _WIN32
#include <winsock2.h>
#endif
#if USE_GNUTLS
// Nothing to do
#else
#include <openssl/err.h>
#include <openssl/ssl.h>
#endif
using std::shared_ptr;
namespace rtc {
std::weak_ptr<Init> Init::Weak;
init_token Init::Global;
std::mutex Init::Mutex;
init_token Init::Token() {
std::lock_guard lock(Mutex);
if (!Global) {
if (auto token = Weak.lock())
Global = token;
else
Global = shared_ptr<Init>(new Init());
}
return Global;
}
void Init::Cleanup() { Global.reset(); }
Init::Init() {
#ifdef _WIN32
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData))
throw std::runtime_error("WSAStartup failed, error=" + std::to_string(WSAGetLastError()));
#endif
#if USE_GNUTLS
// Nothing to do
#else
OPENSSL_init_ssl(0, NULL);
SSL_load_error_strings();
ERR_load_crypto_strings();
#endif
DtlsTransport::Init();
SctpTransport::Init();
}
Init::~Init() {
DtlsTransport::Cleanup();
SctpTransport::Cleanup();
#ifdef _WIN32
WSACleanup();
#endif
}
} // namespace rtc

42
src/log.cpp Normal file
View File

@ -0,0 +1,42 @@
/**
* Copyright (c) 2019-2020 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#include "log.hpp"
#include "plog/Appenders/ColorConsoleAppender.h"
#include "plog/Log.h"
#include "plog/Logger.h"
namespace rtc {
void InitLogger(LogLevel level) { InitLogger(static_cast<plog::Severity>(level)); }
void InitLogger(plog::Severity severity, plog::IAppender *appender) {
static plog::ColorConsoleAppender<plog::TxtFormatter> consoleAppender;
static plog::Logger<0> *logger = nullptr;
if (!logger) {
logger = &plog::init(severity, appender ? appender : &consoleAppender);
PLOG_DEBUG << "Logger initialized";
} else {
logger->setMaxSeverity(severity);
if (appender)
logger->addAppender(appender);
}
}
}

View File

@ -25,10 +25,6 @@
#include <iostream>
#ifdef _WIN32
#include <winsock2.h>
#endif
namespace rtc {
using namespace std::placeholders;
@ -37,11 +33,6 @@ using std::shared_ptr;
using std::weak_ptr;
PeerConnection::PeerConnection() : PeerConnection(Configuration()) {
#ifdef _WIN32
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData))
throw std::runtime_error("WSAStartup failed, error=" + std::to_string(WSAGetLastError()));
#endif
}
PeerConnection::PeerConnection(const Configuration &config)
@ -53,10 +44,6 @@ PeerConnection::~PeerConnection() {
mSctpTransport.reset();
mDtlsTransport.reset();
mIceTransport.reset();
#ifdef _WIN32
WSACleanup();
#endif
}
void PeerConnection::close() {
@ -415,10 +402,8 @@ shared_ptr<DataChannel> PeerConnection::findDataChannel(uint16_t stream) {
shared_ptr<DataChannel> channel;
if (auto it = mDataChannels.find(stream); it != mDataChannels.end()) {
channel = it->second.lock();
if (!channel || channel->isClosed()) {
if (!channel)
mDataChannels.erase(it);
channel.reset();
}
}
return channel;
}
@ -429,11 +414,13 @@ void PeerConnection::iterateDataChannels(
auto it = mDataChannels.begin();
while (it != mDataChannels.end()) {
auto channel = it->second.lock();
if (!channel || channel->isClosed()) {
if (!channel) {
it = mDataChannels.erase(it);
continue;
}
func(channel);
if (!channel->isClosed()) {
func(channel);
}
++it;
}
}

View File

@ -49,31 +49,20 @@ using std::shared_ptr;
namespace rtc {
std::mutex SctpTransport::GlobalMutex;
int SctpTransport::InstancesCount = 0;
void SctpTransport::GlobalInit() {
std::lock_guard lock(GlobalMutex);
if (InstancesCount++ == 0) {
usrsctp_init(0, &SctpTransport::WriteCallback, nullptr);
usrsctp_sysctl_set_sctp_ecn_enable(0);
usrsctp_sysctl_set_sctp_init_rtx_max_default(5);
usrsctp_sysctl_set_sctp_path_rtx_max_default(5);
usrsctp_sysctl_set_sctp_assoc_rtx_max_default(5); // single path
usrsctp_sysctl_set_sctp_rto_min_default(1 * 1000); // ms
usrsctp_sysctl_set_sctp_rto_max_default(10 * 1000); // ms
usrsctp_sysctl_set_sctp_rto_initial_default(1 * 1000); // ms
usrsctp_sysctl_set_sctp_init_rto_max_default(10 * 1000); // ms
usrsctp_sysctl_set_sctp_heartbeat_interval_default(10 * 1000); // ms
}
void SctpTransport::Init() {
usrsctp_init(0, &SctpTransport::WriteCallback, nullptr);
usrsctp_sysctl_set_sctp_ecn_enable(0);
usrsctp_sysctl_set_sctp_init_rtx_max_default(5);
usrsctp_sysctl_set_sctp_path_rtx_max_default(5);
usrsctp_sysctl_set_sctp_assoc_rtx_max_default(5); // single path
usrsctp_sysctl_set_sctp_rto_min_default(1 * 1000); // ms
usrsctp_sysctl_set_sctp_rto_max_default(10 * 1000); // ms
usrsctp_sysctl_set_sctp_rto_initial_default(1 * 1000); // ms
usrsctp_sysctl_set_sctp_init_rto_max_default(10 * 1000); // ms
usrsctp_sysctl_set_sctp_heartbeat_interval_default(10 * 1000); // ms
}
void SctpTransport::GlobalCleanup() {
std::lock_guard lock(GlobalMutex);
if (--InstancesCount == 0) {
usrsctp_finish();
}
}
void SctpTransport::Cleanup() { usrsctp_finish(); }
SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
message_callback recvCallback, amount_callback bufferedAmountCallback,
@ -84,7 +73,6 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
onRecv(recvCallback);
PLOG_DEBUG << "Initializing SCTP transport";
GlobalInit();
usrsctp_register_address(this);
mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, &SctpTransport::RecvCallback,
@ -175,8 +163,6 @@ SctpTransport::~SctpTransport() {
usrsctp_close(mSock);
usrsctp_deregister_address(this);
GlobalCleanup();
}
SctpTransport::State SctpTransport::state() const { return mState; }
@ -187,7 +173,7 @@ void SctpTransport::stop() {
if (!mShutdown.exchange(true)) {
mSendQueue.stop();
flush();
safeFlush();
shutdown();
}
}
@ -277,13 +263,15 @@ void SctpTransport::incoming(message_ptr message) {
mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || mState != State::Connected; });
}
if (message) {
usrsctp_conninput(this, message->data(), message->size(), 0);
} else {
if (!message) {
PLOG_INFO << "SCTP disconnected";
changeState(State::Disconnected);
recv(nullptr);
return;
}
PLOG_VERBOSE << "Incoming size=" << message->size();
usrsctp_conninput(this, message->data(), message->size(), 0);
}
void SctpTransport::changeState(State state) {
@ -381,18 +369,33 @@ bool SctpTransport::trySendMessage(message_ptr message) {
void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
// Requires mSendMutex to be locked
auto it = mBufferedAmount.insert(std::make_pair(streamId, 0)).first;
size_t amount = it->second;
amount = size_t(std::max(long(amount) + delta, long(0)));
size_t amount = size_t(std::max(long(it->second) + delta, long(0)));
if (amount == 0)
mBufferedAmount.erase(it);
else
it->second = amount;
mBufferedAmountCallback(streamId, amount);
}
bool SctpTransport::safeFlush() {
try {
flush();
return true;
} catch (const std::exception &e) {
PLOG_ERROR << "SCTP flush: " << e.what();
return false;
}
}
int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data,
size_t len, struct sctp_rcvinfo info, int flags) {
try {
PLOG_VERBOSE << "Handle recv, len=" << len;
if (!len)
return -1;
if (flags & MSG_EOR) {
if (!mPartialRecv.empty()) {
mPartialRecv.insert(mPartialRecv.end(), data, data + len);
@ -418,24 +421,21 @@ int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, co
}
int SctpTransport::handleSend(size_t free) {
try {
std::lock_guard lock(mSendMutex);
trySendQueue();
} catch (const std::exception &e) {
PLOG_ERROR << "SCTP send: " << e.what();
return -1;
}
return 0; // success
PLOG_VERBOSE << "Handle send, free=" << free;
return safeFlush() ? 0 : -1;
}
int SctpTransport::handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df) {
try {
PLOG_VERBOSE << "Handle write, len=" << len;
std::unique_lock lock(mWriteMutex);
if (!outgoing(make_message(data, data + len)))
return -1;
mWritten = true;
mWrittenOnce = true;
mWrittenCondition.notify_all();
} catch (const std::exception &e) {
PLOG_ERROR << "SCTP write: " << e.what();
return -1;
@ -444,6 +444,8 @@ int SctpTransport::handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_
}
void SctpTransport::processData(const byte *data, size_t len, uint16_t sid, PayloadId ppid) {
PLOG_VERBOSE << "Process data, len=" << len;
// The usage of the PPIDs "WebRTC String Partial" and "WebRTC Binary Partial" is deprecated.
// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.6
// We handle them at reception for compatibility reasons but should never send them.
@ -504,10 +506,15 @@ void SctpTransport::processData(const byte *data, size_t len, uint16_t sid, Payl
}
void SctpTransport::processNotification(const union sctp_notification *notify, size_t len) {
if (len != size_t(notify->sn_header.sn_length))
if (len != size_t(notify->sn_header.sn_length)) {
PLOG_WARNING << "Invalid notification length";
return;
}
switch (notify->sn_header.sn_type) {
auto type = notify->sn_header.sn_type;
PLOG_VERBOSE << "Process notification, type=" << type;
switch (type) {
case SCTP_ASSOC_CHANGE: {
const struct sctp_assoc_change &assoc_change = notify->sn_assoc_change;
if (assoc_change.sac_state == SCTP_COMM_UP) {
@ -523,13 +530,16 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
}
mWrittenCondition.notify_all();
}
break;
}
case SCTP_SENDER_DRY_EVENT: {
// It not should be necessary since the send callback should have been called already,
// but to be sure, let's try to send now.
std::lock_guard lock(mSendMutex);
trySendQueue();
safeFlush();
break;
}
case SCTP_STREAM_RESET_EVENT: {
const struct sctp_stream_reset_event &reset_event = notify->sn_strreset_event;
const int count = (reset_event.strreset_length - sizeof(reset_event)) / sizeof(uint16_t);

View File

@ -35,6 +35,9 @@ namespace rtc {
class SctpTransport : public Transport {
public:
static void Init();
static void Cleanup();
enum class State { Disconnected, Connecting, Connected, Failed };
using amount_callback = std::function<void(uint16_t streamId, size_t amount)>;
@ -72,6 +75,7 @@ private:
bool trySendQueue();
bool trySendMessage(message_ptr message);
void updateBufferedAmount(uint16_t streamId, long delta);
bool safeFlush();
int handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data, size_t len,
struct sctp_rcvinfo recv_info, int flags);
@ -105,12 +109,6 @@ private:
struct sctp_rcvinfo recv_info, int flags, void *user_data);
static int SendCallback(struct socket *sock, uint32_t sb_free);
static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);
void GlobalInit();
void GlobalCleanup();
static std::mutex GlobalMutex;
static int InstancesCount;
};
} // namespace rtc