Revised synchronization

This commit is contained in:
Paul-Louis Ageneau
2019-12-16 10:45:00 +01:00
parent 5a8725dac1
commit e5a19f85ed
11 changed files with 138 additions and 95 deletions

View File

@ -57,13 +57,13 @@ public:
~synchronized_callback() { *this = nullptr; } ~synchronized_callback() { *this = nullptr; }
synchronized_callback &operator=(std::function<void(P...)> func) { synchronized_callback &operator=(std::function<void(P...)> func) {
std::lock_guard<std::recursive_mutex> lock(mutex); std::lock_guard lock(mutex);
callback = func; callback = func;
return *this; return *this;
} }
void operator()(P... args) const { void operator()(P... args) const {
std::lock_guard<std::recursive_mutex> lock(mutex); std::lock_guard lock(mutex);
if (callback) if (callback)
callback(args...); callback(args...);
} }

View File

@ -31,6 +31,7 @@
#include <atomic> #include <atomic>
#include <functional> #include <functional>
#include <list> #include <list>
#include <mutex>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
@ -83,9 +84,9 @@ public:
void onGatheringStateChange(std::function<void(GatheringState state)> callback); void onGatheringStateChange(std::function<void(GatheringState state)> callback);
private: private:
void initIceTransport(Description::Role role); std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
void initDtlsTransport(); std::shared_ptr<DtlsTransport> initDtlsTransport();
void initSctpTransport(); std::shared_ptr<SctpTransport> initSctpTransport();
bool checkFingerprint(const std::string &fingerprint) const; bool checkFingerprint(const std::string &fingerprint) const;
void forwardMessage(message_ptr message); void forwardMessage(message_ptr message);
@ -103,8 +104,8 @@ private:
const Configuration mConfig; const Configuration mConfig;
const std::shared_ptr<Certificate> mCertificate; const std::shared_ptr<Certificate> mCertificate;
std::optional<Description> mLocalDescription; std::optional<Description> mLocalDescription, mRemoteDescription;
std::optional<Description> mRemoteDescription; mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;
std::shared_ptr<IceTransport> mIceTransport; std::shared_ptr<IceTransport> mIceTransport;
std::shared_ptr<DtlsTransport> mDtlsTransport; std::shared_ptr<DtlsTransport> mDtlsTransport;

View File

@ -67,31 +67,31 @@ Queue<T>::Queue(size_t limit, amount_function func) : mLimit(limit), mAmount(0)
template <typename T> Queue<T>::~Queue() { stop(); } template <typename T> Queue<T>::~Queue() { stop(); }
template <typename T> void Queue<T>::stop() { template <typename T> void Queue<T>::stop() {
std::lock_guard<std::mutex> lock(mMutex); std::lock_guard lock(mMutex);
mStopping = true; mStopping = true;
mPopCondition.notify_all(); mPopCondition.notify_all();
mPushCondition.notify_all(); mPushCondition.notify_all();
} }
template <typename T> bool Queue<T>::empty() const { template <typename T> bool Queue<T>::empty() const {
std::lock_guard<std::mutex> lock(mMutex); std::lock_guard lock(mMutex);
return mQueue.empty(); return mQueue.empty();
} }
template <typename T> size_t Queue<T>::size() const { template <typename T> size_t Queue<T>::size() const {
std::lock_guard<std::mutex> lock(mMutex); std::lock_guard lock(mMutex);
return mQueue.size(); return mQueue.size();
} }
template <typename T> size_t Queue<T>::amount() const { template <typename T> size_t Queue<T>::amount() const {
std::lock_guard<std::mutex> lock(mMutex); std::lock_guard lock(mMutex);
return mAmount; return mAmount;
} }
template <typename T> void Queue<T>::push(const T &element) { push(T{element}); } template <typename T> void Queue<T>::push(const T &element) { push(T{element}); }
template <typename T> void Queue<T>::push(T &&element) { template <typename T> void Queue<T>::push(T &&element) {
std::unique_lock<std::mutex> lock(mMutex); std::unique_lock lock(mMutex);
mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; }); mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; });
if (!mStopping) { if (!mStopping) {
mAmount += mAmountFunction(element); mAmount += mAmountFunction(element);
@ -101,7 +101,7 @@ template <typename T> void Queue<T>::push(T &&element) {
} }
template <typename T> std::optional<T> Queue<T>::pop() { template <typename T> std::optional<T> Queue<T>::pop() {
std::unique_lock<std::mutex> lock(mMutex); std::unique_lock lock(mMutex);
mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; }); mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
if (!mQueue.empty()) { if (!mQueue.empty()) {
mAmount -= mAmountFunction(mQueue.front()); mAmount -= mAmountFunction(mQueue.front());
@ -114,7 +114,7 @@ template <typename T> std::optional<T> Queue<T>::pop() {
} }
template <typename T> std::optional<T> Queue<T>::peek() { template <typename T> std::optional<T> Queue<T>::peek() {
std::unique_lock<std::mutex> lock(mMutex); std::unique_lock lock(mMutex);
if (!mQueue.empty()) { if (!mQueue.empty()) {
return std::optional<T>{mQueue.front()}; return std::optional<T>{mQueue.front()};
} else { } else {
@ -123,12 +123,12 @@ template <typename T> std::optional<T> Queue<T>::peek() {
} }
template <typename T> void Queue<T>::wait() { template <typename T> void Queue<T>::wait() {
std::unique_lock<std::mutex> lock(mMutex); std::unique_lock lock(mMutex);
mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; }); mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
} }
template <typename T> void Queue<T>::wait(const std::chrono::milliseconds &duration) { template <typename T> void Queue<T>::wait(const std::chrono::milliseconds &duration) {
std::unique_lock<std::mutex> lock(mMutex); std::unique_lock lock(mMutex);
mPopCondition.wait_for(lock, duration, [this]() { return !mQueue.empty() || mStopping; }); mPopCondition.wait_for(lock, duration, [this]() { return !mQueue.empty() || mStopping; });
} }

View File

@ -145,7 +145,7 @@ shared_ptr<Certificate> make_certificate(const string &commonName) {
static std::unordered_map<string, shared_ptr<Certificate>> cache; static std::unordered_map<string, shared_ptr<Certificate>> cache;
static std::mutex cacheMutex; static std::mutex cacheMutex;
std::lock_guard<std::mutex> lock(cacheMutex); std::lock_guard lock(cacheMutex);
if (auto it = cache.find(commonName); it != cache.end()) if (auto it = cache.find(commonName); it != cache.end())
return it->second; return it->second;
@ -241,7 +241,7 @@ shared_ptr<Certificate> make_certificate(const string &commonName) {
static std::unordered_map<string, shared_ptr<Certificate>> cache; static std::unordered_map<string, shared_ptr<Certificate>> cache;
static std::mutex cacheMutex; static std::mutex cacheMutex;
std::lock_guard<std::mutex> lock(cacheMutex); std::lock_guard lock(cacheMutex);
if (auto it = cache.find(commonName); it != cache.end()) if (auto it = cache.find(commonName); it != cache.end())
return it->second; return it->second;

View File

@ -85,6 +85,8 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
} }
DtlsTransport::~DtlsTransport() { DtlsTransport::~DtlsTransport() {
stop();
gnutls_bye(mSession, GNUTLS_SHUT_RDWR); gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
gnutls_deinit(mSession); gnutls_deinit(mSession);
} }
@ -94,8 +96,10 @@ DtlsTransport::State DtlsTransport::state() const { return mState; }
void DtlsTransport::stop() { void DtlsTransport::stop() {
Transport::stop(); Transport::stop();
mIncomingQueue.stop(); if (mRecvThread.joinable()) {
mRecvThread.join(); mIncomingQueue.stop();
mRecvThread.join();
}
} }
bool DtlsTransport::send(message_ptr message) { bool DtlsTransport::send(message_ptr message) {
@ -293,7 +297,7 @@ int DtlsTransport::TransportExIndex = -1;
std::mutex DtlsTransport::GlobalMutex; std::mutex DtlsTransport::GlobalMutex;
void DtlsTransport::GlobalInit() { void DtlsTransport::GlobalInit() {
std::lock_guard<std::mutex> lock(GlobalMutex); std::lock_guard lock(GlobalMutex);
if (TransportExIndex < 0) { if (TransportExIndex < 0) {
TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL); TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
} }
@ -358,6 +362,8 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
} }
DtlsTransport::~DtlsTransport() { DtlsTransport::~DtlsTransport() {
stop();
SSL_shutdown(mSsl); SSL_shutdown(mSsl);
SSL_free(mSsl); SSL_free(mSsl);
SSL_CTX_free(mCtx); SSL_CTX_free(mCtx);
@ -366,8 +372,10 @@ DtlsTransport::~DtlsTransport() {
void DtlsTransport::stop() { void DtlsTransport::stop() {
Transport::stop(); Transport::stop();
mIncomingQueue.stop(); if (mRecvThread.joinable()) {
mRecvThread.join(); mIncomingQueue.stop();
mRecvThread.join();
}
} }
DtlsTransport::State DtlsTransport::state() const { return mState; } DtlsTransport::State DtlsTransport::state() const { return mState; }

View File

@ -55,10 +55,10 @@ public:
State state() const; State state() const;
void stop() override; void stop() override;
bool send(message_ptr message); // false if dropped bool send(message_ptr message) override; // false if dropped
private: private:
void incoming(message_ptr message); void incoming(message_ptr message) override;
void changeState(State state); void changeState(State state);
void runRecvLoop(); void runRecvLoop();

View File

@ -130,11 +130,13 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
RecvCallback, this); RecvCallback, this);
} }
IceTransport::~IceTransport() {} IceTransport::~IceTransport() { stop(); }
void IceTransport::stop() { void IceTransport::stop() {
g_main_loop_quit(mMainLoop.get()); if (mMainLoopThread.joinable()) {
mMainLoopThread.join(); g_main_loop_quit(mMainLoop.get());
mMainLoopThread.join();
}
} }
Description::Role IceTransport::role() const { return mRole; } Description::Role IceTransport::role() const { return mRole; }

View File

@ -71,9 +71,9 @@ public:
bool send(message_ptr message) override; // false if dropped bool send(message_ptr message) override; // false if dropped
private: private:
void incoming(message_ptr message); void incoming(message_ptr message) override;
void incoming(const byte *data, int size); void incoming(const byte *data, int size);
void outgoing(message_ptr message); void outgoing(message_ptr message) override;
void changeState(State state); void changeState(State state);
void changeGatheringState(GatheringState state); void changeGatheringState(GatheringState state);

View File

@ -20,6 +20,7 @@
#include "certificate.hpp" #include "certificate.hpp"
#include "dtlstransport.hpp" #include "dtlstransport.hpp"
#include "icetransport.hpp" #include "icetransport.hpp"
#include "include.hpp"
#include "sctptransport.hpp" #include "sctptransport.hpp"
#include <iostream> #include <iostream>
@ -37,12 +38,12 @@ PeerConnection::PeerConnection(const Configuration &config)
: mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {} : mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {}
PeerConnection::~PeerConnection() { PeerConnection::~PeerConnection() {
if (mIceTransport) if (auto transport = std::atomic_load(&mIceTransport))
mIceTransport->stop(); transport->stop();
if (mDtlsTransport) if (auto transport = std::atomic_load(&mDtlsTransport))
mDtlsTransport->stop(); transport->stop();
if (mSctpTransport) if (auto transport = std::atomic_load(&mSctpTransport))
mSctpTransport->stop(); transport->stop();
mSctpTransport.reset(); mSctpTransport.reset();
mDtlsTransport.reset(); mDtlsTransport.reset();
@ -55,26 +56,36 @@ PeerConnection::State PeerConnection::state() const { return mState; }
PeerConnection::GatheringState PeerConnection::gatheringState() const { return mGatheringState; } PeerConnection::GatheringState PeerConnection::gatheringState() const { return mGatheringState; }
std::optional<Description> PeerConnection::localDescription() const { return mLocalDescription; } std::optional<Description> PeerConnection::localDescription() const {
std::lock_guard lock(mLocalDescriptionMutex);
return mLocalDescription;
}
std::optional<Description> PeerConnection::remoteDescription() const { return mRemoteDescription; } std::optional<Description> PeerConnection::remoteDescription() const {
std::lock_guard lock(mRemoteDescriptionMutex);
return mRemoteDescription;
}
void PeerConnection::setRemoteDescription(Description description) { void PeerConnection::setRemoteDescription(Description description) {
std::lock_guard lock(mRemoteDescriptionMutex);
auto remoteCandidates = description.extractCandidates(); auto remoteCandidates = description.extractCandidates();
mRemoteDescription.emplace(std::move(description)); mRemoteDescription.emplace(std::move(description));
if (!mIceTransport) auto iceTransport = std::atomic_load(&mIceTransport);
initIceTransport(Description::Role::ActPass); if (!iceTransport)
iceTransport = initIceTransport(Description::Role::ActPass);
mIceTransport->setRemoteDescription(*mRemoteDescription); iceTransport->setRemoteDescription(*mRemoteDescription);
if (mRemoteDescription->type() == Description::Type::Offer) { if (mRemoteDescription->type() == Description::Type::Offer) {
// This is an offer and we are the answerer. // This is an offer and we are the answerer.
processLocalDescription(mIceTransport->getLocalDescription(Description::Type::Answer)); processLocalDescription(iceTransport->getLocalDescription(Description::Type::Answer));
mIceTransport->gatherLocalCandidates(); iceTransport->gatherLocalCandidates();
} else { } else {
// This is an answer and we are the offerer. // This is an answer and we are the offerer.
if (!mSctpTransport && mIceTransport->role() == Description::Role::Active) { auto sctpTransport = std::atomic_load(&mSctpTransport);
if (!sctpTransport && iceTransport->role() == Description::Role::Active) {
// Since we assumed passive role during DataChannel creation, we need to shift the // Since we assumed passive role during DataChannel creation, we need to shift the
// stream numbers by one to shift them from odd to even. // stream numbers by one to shift them from odd to even.
decltype(mDataChannels) newDataChannels; decltype(mDataChannels) newDataChannels;
@ -92,16 +103,19 @@ void PeerConnection::setRemoteDescription(Description description) {
} }
void PeerConnection::addRemoteCandidate(Candidate candidate) { void PeerConnection::addRemoteCandidate(Candidate candidate) {
if (!mRemoteDescription || !mIceTransport) std::lock_guard lock(mRemoteDescriptionMutex);
auto iceTransport = std::atomic_load(&mIceTransport);
if (!mRemoteDescription || !iceTransport)
throw std::logic_error("Remote candidate set without remote description"); throw std::logic_error("Remote candidate set without remote description");
mRemoteDescription->addCandidate(candidate); mRemoteDescription->addCandidate(candidate);
if (candidate.resolve(Candidate::ResolveMode::Simple)) { if (candidate.resolve(Candidate::ResolveMode::Simple)) {
mIceTransport->addRemoteCandidate(candidate); iceTransport->addRemoteCandidate(candidate);
} else { } else {
// OK, we might need a lookup, do it asynchronously // OK, we might need a lookup, do it asynchronously
weak_ptr<IceTransport> weakIceTransport{mIceTransport}; weak_ptr<IceTransport> weakIceTransport{iceTransport};
std::thread t([weakIceTransport, candidate]() mutable { std::thread t([weakIceTransport, candidate]() mutable {
if (candidate.resolve(Candidate::ResolveMode::Lookup)) if (candidate.resolve(Candidate::ResolveMode::Lookup))
if (auto iceTransport = weakIceTransport.lock()) if (auto iceTransport = weakIceTransport.lock())
@ -112,11 +126,13 @@ void PeerConnection::addRemoteCandidate(Candidate candidate) {
} }
std::optional<string> PeerConnection::localAddress() const { std::optional<string> PeerConnection::localAddress() const {
return mIceTransport ? mIceTransport->getLocalAddress() : nullopt; auto iceTransport = std::atomic_load(&mIceTransport);
return iceTransport ? iceTransport->getLocalAddress() : nullopt;
} }
std::optional<string> PeerConnection::remoteAddress() const { std::optional<string> PeerConnection::remoteAddress() const {
return mIceTransport ? mIceTransport->getRemoteAddress() : nullopt; auto iceTransport = std::atomic_load(&mIceTransport);
return iceTransport ? iceTransport->getRemoteAddress() : nullopt;
} }
shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label, shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
@ -126,7 +142,8 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
// setup:passive. [...] Thus, setup:active is RECOMMENDED. // setup:passive. [...] Thus, setup:active is RECOMMENDED.
// See https://tools.ietf.org/html/rfc5763#section-5 // See https://tools.ietf.org/html/rfc5763#section-5
// Therefore, we assume passive role when we are the offerer. // Therefore, we assume passive role when we are the offerer.
auto role = mIceTransport ? mIceTransport->role() : Description::Role::Passive; auto iceTransport = std::atomic_load(&mIceTransport);
auto role = iceTransport ? iceTransport->role() : Description::Role::Passive;
// The active side must use streams with even identifiers, whereas the passive side must use // The active side must use streams with even identifiers, whereas the passive side must use
// streams with odd identifiers. // streams with odd identifiers.
@ -142,15 +159,17 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
std::make_shared<DataChannel>(shared_from_this(), stream, label, protocol, reliability); std::make_shared<DataChannel>(shared_from_this(), stream, label, protocol, reliability);
mDataChannels.insert(std::make_pair(stream, channel)); mDataChannels.insert(std::make_pair(stream, channel));
if (!mIceTransport) { if (!iceTransport) {
// RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of // RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of
// setup:actpass. // setup:actpass.
// See https://tools.ietf.org/html/rfc5763#section-5 // See https://tools.ietf.org/html/rfc5763#section-5
initIceTransport(Description::Role::ActPass); iceTransport = initIceTransport(Description::Role::ActPass);
processLocalDescription(mIceTransport->getLocalDescription(Description::Type::Offer)); processLocalDescription(iceTransport->getLocalDescription(Description::Type::Offer));
mIceTransport->gatherLocalCandidates(); iceTransport->gatherLocalCandidates();
} else if (mSctpTransport && mSctpTransport->state() == SctpTransport::State::Connected) { } else {
channel->open(mSctpTransport); if (auto transport = std::atomic_load(&mSctpTransport))
if (transport->state() == SctpTransport::State::Connected)
channel->open(transport);
} }
return channel; return channel;
} }
@ -177,8 +196,8 @@ void PeerConnection::onGatheringStateChange(std::function<void(GatheringState st
mGatheringStateChangeCallback = callback; mGatheringStateChangeCallback = callback;
} }
void PeerConnection::initIceTransport(Description::Role role) { shared_ptr<IceTransport> PeerConnection::initIceTransport(Description::Role role) {
mIceTransport = std::make_shared<IceTransport>( auto transport = std::make_shared<IceTransport>(
mConfig, role, std::bind(&PeerConnection::processLocalCandidate, this, _1), mConfig, role, std::bind(&PeerConnection::processLocalCandidate, this, _1),
[this](IceTransport::State state) { [this](IceTransport::State state) {
switch (state) { switch (state) {
@ -211,11 +230,14 @@ void PeerConnection::initIceTransport(Description::Role role) {
break; break;
} }
}); });
std::atomic_store(&mIceTransport, transport);
return transport;
} }
void PeerConnection::initDtlsTransport() { shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
mDtlsTransport = std::make_shared<DtlsTransport>( auto lower = std::atomic_load(&mIceTransport);
mIceTransport, mCertificate, std::bind(&PeerConnection::checkFingerprint, this, _1), auto transport = std::make_shared<DtlsTransport>(
lower, mCertificate, std::bind(&PeerConnection::checkFingerprint, this, _1),
[this](DtlsTransport::State state) { [this](DtlsTransport::State state) {
switch (state) { switch (state) {
case DtlsTransport::State::Connected: case DtlsTransport::State::Connected:
@ -229,12 +251,15 @@ void PeerConnection::initDtlsTransport() {
break; break;
} }
}); });
std::atomic_store(&mDtlsTransport, transport);
return transport;
} }
void PeerConnection::initSctpTransport() { shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT); uint16_t sctpPort = remoteDescription()->sctpPort().value_or(DEFAULT_SCTP_PORT);
mSctpTransport = std::make_shared<SctpTransport>( auto lower = std::atomic_load(&mDtlsTransport);
mDtlsTransport, sctpPort, std::bind(&PeerConnection::forwardMessage, this, _1), auto transport = std::make_shared<SctpTransport>(
lower, sctpPort, std::bind(&PeerConnection::forwardMessage, this, _1),
std::bind(&PeerConnection::forwardBufferedAmount, this, _1, _2), std::bind(&PeerConnection::forwardBufferedAmount, this, _1, _2),
[this](SctpTransport::State state) { [this](SctpTransport::State state) {
switch (state) { switch (state) {
@ -253,9 +278,12 @@ void PeerConnection::initSctpTransport() {
break; break;
} }
}); });
std::atomic_store(&mSctpTransport, transport);
return transport;
} }
bool PeerConnection::checkFingerprint(const std::string &fingerprint) const { bool PeerConnection::checkFingerprint(const std::string &fingerprint) const {
std::lock_guard lock(mRemoteDescriptionMutex);
if (auto expectedFingerprint = if (auto expectedFingerprint =
mRemoteDescription ? mRemoteDescription->fingerprint() : nullopt) { mRemoteDescription ? mRemoteDescription->fingerprint() : nullopt) {
return *expectedFingerprint == fingerprint; return *expectedFingerprint == fingerprint;
@ -264,9 +292,6 @@ bool PeerConnection::checkFingerprint(const std::string &fingerprint) const {
} }
void PeerConnection::forwardMessage(message_ptr message) { void PeerConnection::forwardMessage(message_ptr message) {
if (!mIceTransport || !mSctpTransport)
throw std::logic_error("Got a DataChannel message without transport");
if (!message) { if (!message) {
closeDataChannels(); closeDataChannels();
return; return;
@ -281,19 +306,24 @@ void PeerConnection::forwardMessage(message_ptr message) {
} }
} }
auto iceTransport = std::atomic_load(&mIceTransport);
auto sctpTransport = std::atomic_load(&mSctpTransport);
if (!iceTransport || !sctpTransport)
return;
if (!channel) { if (!channel) {
const byte dataChannelOpenMessage{0x03}; const byte dataChannelOpenMessage{0x03};
unsigned int remoteParity = (mIceTransport->role() == Description::Role::Active) ? 1 : 0; unsigned int remoteParity = (iceTransport->role() == Description::Role::Active) ? 1 : 0;
if (message->type == Message::Control && *message->data() == dataChannelOpenMessage && if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
message->stream % 2 == remoteParity) { message->stream % 2 == remoteParity) {
channel = channel =
std::make_shared<DataChannel>(shared_from_this(), mSctpTransport, message->stream); std::make_shared<DataChannel>(shared_from_this(), sctpTransport, message->stream);
channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this, channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this,
weak_ptr<DataChannel>{channel})); weak_ptr<DataChannel>{channel}));
mDataChannels.insert(std::make_pair(message->stream, channel)); mDataChannels.insert(std::make_pair(message->stream, channel));
} else { } else {
// Invalid, close the DataChannel by resetting the stream // Invalid, close the DataChannel by resetting the stream
mSctpTransport->reset(message->stream); sctpTransport->reset(message->stream);
return; return;
} }
} }
@ -330,16 +360,20 @@ void PeerConnection::iterateDataChannels(
} }
void PeerConnection::openDataChannels() { void PeerConnection::openDataChannels() {
iterateDataChannels([this](shared_ptr<DataChannel> channel) { channel->open(mSctpTransport); }); if (auto transport = std::atomic_load(&mSctpTransport))
iterateDataChannels([&](shared_ptr<DataChannel> channel) { channel->open(transport); });
} }
void PeerConnection::closeDataChannels() { void PeerConnection::closeDataChannels() {
iterateDataChannels([](shared_ptr<DataChannel> channel) { channel->close(); }); iterateDataChannels([&](shared_ptr<DataChannel> channel) { channel->close(); });
} }
void PeerConnection::processLocalDescription(Description description) { void PeerConnection::processLocalDescription(Description description) {
auto remoteSctpPort = mRemoteDescription ? mRemoteDescription->sctpPort() : nullopt; std::optional<uint16_t> remoteSctpPort;
if (auto remote = remoteDescription())
remoteSctpPort = remote->sctpPort();
std::lock_guard lock(mLocalDescriptionMutex);
mLocalDescription.emplace(std::move(description)); mLocalDescription.emplace(std::move(description));
mLocalDescription->setFingerprint(mCertificate->fingerprint()); mLocalDescription->setFingerprint(mCertificate->fingerprint());
mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT)); mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
@ -349,6 +383,7 @@ void PeerConnection::processLocalDescription(Description description) {
} }
void PeerConnection::processLocalCandidate(Candidate candidate) { void PeerConnection::processLocalCandidate(Candidate candidate) {
std::lock_guard lock(mLocalDescriptionMutex);
if (!mLocalDescription) if (!mLocalDescription)
throw std::logic_error("Got a local candidate without local description"); throw std::logic_error("Got a local candidate without local description");

View File

@ -33,7 +33,7 @@ std::mutex SctpTransport::GlobalMutex;
int SctpTransport::InstancesCount = 0; int SctpTransport::InstancesCount = 0;
void SctpTransport::GlobalInit() { void SctpTransport::GlobalInit() {
std::unique_lock<std::mutex> lock(GlobalMutex); std::lock_guard lock(GlobalMutex);
if (InstancesCount++ == 0) { if (InstancesCount++ == 0) {
usrsctp_init(0, &SctpTransport::WriteCallback, nullptr); usrsctp_init(0, &SctpTransport::WriteCallback, nullptr);
usrsctp_sysctl_set_sctp_ecn_enable(0); usrsctp_sysctl_set_sctp_ecn_enable(0);
@ -41,7 +41,7 @@ void SctpTransport::GlobalInit() {
} }
void SctpTransport::GlobalCleanup() { void SctpTransport::GlobalCleanup() {
std::unique_lock<std::mutex> lock(GlobalMutex); std::lock_guard lock(GlobalMutex);
if (--InstancesCount == 0) { if (--InstancesCount == 0) {
usrsctp_finish(); usrsctp_finish();
} }
@ -143,6 +143,8 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
} }
SctpTransport::~SctpTransport() { SctpTransport::~SctpTransport() {
stop();
if (mSock) { if (mSock) {
usrsctp_shutdown(mSock, SHUT_RDWR); usrsctp_shutdown(mSock, SHUT_RDWR);
usrsctp_close(mSock); usrsctp_close(mSock);
@ -156,15 +158,14 @@ SctpTransport::State SctpTransport::state() const { return mState; }
void SctpTransport::stop() { void SctpTransport::stop() {
Transport::stop(); Transport::stop();
onRecv(nullptr);
mSendQueue.stop(); mSendQueue.stop();
// Unblock incoming // Unblock incoming
if (!mConnectDataSent) { std::unique_lock<std::mutex> lock(mConnectMutex);
std::unique_lock<std::mutex> lock(mConnectMutex); mConnectDataSent = true;
mConnectDataSent = true; mConnectCondition.notify_all();
mConnectCondition.notify_all();
}
} }
void SctpTransport::connect() { void SctpTransport::connect() {
@ -190,7 +191,7 @@ void SctpTransport::connect() {
} }
bool SctpTransport::send(message_ptr message) { bool SctpTransport::send(message_ptr message) {
std::lock_guard<std::mutex> lock(mSendMutex); std::lock_guard lock(mSendMutex);
if (!message) if (!message)
return mSendQueue.empty(); return mSendQueue.empty();
@ -225,8 +226,8 @@ void SctpTransport::incoming(message_ptr message) {
// There could be a race condition here where we receive the remote INIT before the local one is // There could be a race condition here where we receive the remote INIT before the local one is
// sent, which would result in the connection being aborted. Therefore, we need to wait for data // sent, which would result in the connection being aborted. Therefore, we need to wait for data
// to be sent on our side (i.e. the local INIT) before proceeding. // to be sent on our side (i.e. the local INIT) before proceeding.
if (!mConnectDataSent) { {
std::unique_lock<std::mutex> lock(mConnectMutex); std::unique_lock lock(mConnectMutex);
mConnectCondition.wait(lock, [this]() -> bool { return mConnectDataSent; }); mConnectCondition.wait(lock, [this]() -> bool { return mConnectDataSent; });
} }
@ -361,7 +362,7 @@ int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, co
int SctpTransport::handleSend(size_t free) { int SctpTransport::handleSend(size_t free) {
try { try {
std::lock_guard<std::mutex> lock(mSendMutex); std::lock_guard lock(mSendMutex);
trySendQueue(); trySendQueue();
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::cerr << "SCTP send: " << e.what() << std::endl; std::cerr << "SCTP send: " << e.what() << std::endl;
@ -374,11 +375,9 @@ int SctpTransport::handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_
try { try {
outgoing(make_message(data, data + len)); outgoing(make_message(data, data + len));
if (!mConnectDataSent) { std::unique_lock lock(mConnectMutex);
std::unique_lock<std::mutex> lock(mConnectMutex); mConnectDataSent = true;
mConnectDataSent = true; mConnectCondition.notify_all();
mConnectCondition.notify_all();
}
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::cerr << "SCTP write: " << e.what() << std::endl; std::cerr << "SCTP write: " << e.what() << std::endl;
return -1; return -1;
@ -453,7 +452,6 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
switch (notify->sn_header.sn_type) { switch (notify->sn_header.sn_type) {
case SCTP_ASSOC_CHANGE: { case SCTP_ASSOC_CHANGE: {
const struct sctp_assoc_change &assoc_change = notify->sn_assoc_change; const struct sctp_assoc_change &assoc_change = notify->sn_assoc_change;
std::unique_lock<std::mutex> lock(mConnectMutex);
if (assoc_change.sac_state == SCTP_COMM_UP) { if (assoc_change.sac_state == SCTP_COMM_UP) {
changeState(State::Connected); changeState(State::Connected);
} else { } else {
@ -468,7 +466,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
case SCTP_SENDER_DRY_EVENT: { case SCTP_SENDER_DRY_EVENT: {
// It not should be necessary since the send callback should have been called already, // It not should be necessary since the send callback should have been called already,
// but to be sure, let's try to send now. // but to be sure, let's try to send now.
std::lock_guard<std::mutex> lock(mSendMutex); std::lock_guard lock(mSendMutex);
trySendQueue(); trySendQueue();
} }
case SCTP_STREAM_RESET_EVENT: { case SCTP_STREAM_RESET_EVENT: {

View File

@ -68,7 +68,7 @@ private:
}; };
void connect(); void connect();
void incoming(message_ptr message); void incoming(message_ptr message) override;
void changeState(State state); void changeState(State state);
bool trySendQueue(); bool trySendQueue();
@ -93,8 +93,7 @@ private:
std::mutex mConnectMutex; std::mutex mConnectMutex;
std::condition_variable mConnectCondition; std::condition_variable mConnectCondition;
std::atomic<bool> mConnectDataSent = false; bool mConnectDataSent = false;
std::atomic<bool> mStopping = false;
state_callback mStateChangeCallback; state_callback mStateChangeCallback;
std::atomic<State> mState; std::atomic<State> mState;