Revised synchronization

This commit is contained in:
Paul-Louis Ageneau
2019-12-16 10:45:00 +01:00
parent 5a8725dac1
commit e5a19f85ed
11 changed files with 138 additions and 95 deletions

View File

@ -57,13 +57,13 @@ public:
~synchronized_callback() { *this = nullptr; }
synchronized_callback &operator=(std::function<void(P...)> func) {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::lock_guard lock(mutex);
callback = func;
return *this;
}
void operator()(P... args) const {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::lock_guard lock(mutex);
if (callback)
callback(args...);
}

View File

@ -31,6 +31,7 @@
#include <atomic>
#include <functional>
#include <list>
#include <mutex>
#include <thread>
#include <unordered_map>
@ -83,9 +84,9 @@ public:
void onGatheringStateChange(std::function<void(GatheringState state)> callback);
private:
void initIceTransport(Description::Role role);
void initDtlsTransport();
void initSctpTransport();
std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
std::shared_ptr<DtlsTransport> initDtlsTransport();
std::shared_ptr<SctpTransport> initSctpTransport();
bool checkFingerprint(const std::string &fingerprint) const;
void forwardMessage(message_ptr message);
@ -103,8 +104,8 @@ private:
const Configuration mConfig;
const std::shared_ptr<Certificate> mCertificate;
std::optional<Description> mLocalDescription;
std::optional<Description> mRemoteDescription;
std::optional<Description> mLocalDescription, mRemoteDescription;
mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;
std::shared_ptr<IceTransport> mIceTransport;
std::shared_ptr<DtlsTransport> mDtlsTransport;

View File

@ -67,31 +67,31 @@ Queue<T>::Queue(size_t limit, amount_function func) : mLimit(limit), mAmount(0)
template <typename T> Queue<T>::~Queue() { stop(); }
template <typename T> void Queue<T>::stop() {
std::lock_guard<std::mutex> lock(mMutex);
std::lock_guard lock(mMutex);
mStopping = true;
mPopCondition.notify_all();
mPushCondition.notify_all();
}
template <typename T> bool Queue<T>::empty() const {
std::lock_guard<std::mutex> lock(mMutex);
std::lock_guard lock(mMutex);
return mQueue.empty();
}
template <typename T> size_t Queue<T>::size() const {
std::lock_guard<std::mutex> lock(mMutex);
std::lock_guard lock(mMutex);
return mQueue.size();
}
template <typename T> size_t Queue<T>::amount() const {
std::lock_guard<std::mutex> lock(mMutex);
std::lock_guard lock(mMutex);
return mAmount;
}
template <typename T> void Queue<T>::push(const T &element) { push(T{element}); }
template <typename T> void Queue<T>::push(T &&element) {
std::unique_lock<std::mutex> lock(mMutex);
std::unique_lock lock(mMutex);
mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; });
if (!mStopping) {
mAmount += mAmountFunction(element);
@ -101,7 +101,7 @@ template <typename T> void Queue<T>::push(T &&element) {
}
template <typename T> std::optional<T> Queue<T>::pop() {
std::unique_lock<std::mutex> lock(mMutex);
std::unique_lock lock(mMutex);
mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
if (!mQueue.empty()) {
mAmount -= mAmountFunction(mQueue.front());
@ -114,7 +114,7 @@ template <typename T> std::optional<T> Queue<T>::pop() {
}
template <typename T> std::optional<T> Queue<T>::peek() {
std::unique_lock<std::mutex> lock(mMutex);
std::unique_lock lock(mMutex);
if (!mQueue.empty()) {
return std::optional<T>{mQueue.front()};
} else {
@ -123,12 +123,12 @@ template <typename T> std::optional<T> Queue<T>::peek() {
}
template <typename T> void Queue<T>::wait() {
std::unique_lock<std::mutex> lock(mMutex);
std::unique_lock lock(mMutex);
mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
}
template <typename T> void Queue<T>::wait(const std::chrono::milliseconds &duration) {
std::unique_lock<std::mutex> lock(mMutex);
std::unique_lock lock(mMutex);
mPopCondition.wait_for(lock, duration, [this]() { return !mQueue.empty() || mStopping; });
}

View File

@ -145,7 +145,7 @@ shared_ptr<Certificate> make_certificate(const string &commonName) {
static std::unordered_map<string, shared_ptr<Certificate>> cache;
static std::mutex cacheMutex;
std::lock_guard<std::mutex> lock(cacheMutex);
std::lock_guard lock(cacheMutex);
if (auto it = cache.find(commonName); it != cache.end())
return it->second;
@ -241,7 +241,7 @@ shared_ptr<Certificate> make_certificate(const string &commonName) {
static std::unordered_map<string, shared_ptr<Certificate>> cache;
static std::mutex cacheMutex;
std::lock_guard<std::mutex> lock(cacheMutex);
std::lock_guard lock(cacheMutex);
if (auto it = cache.find(commonName); it != cache.end())
return it->second;

View File

@ -85,6 +85,8 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
}
DtlsTransport::~DtlsTransport() {
stop();
gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
gnutls_deinit(mSession);
}
@ -94,8 +96,10 @@ DtlsTransport::State DtlsTransport::state() const { return mState; }
void DtlsTransport::stop() {
Transport::stop();
if (mRecvThread.joinable()) {
mIncomingQueue.stop();
mRecvThread.join();
}
}
bool DtlsTransport::send(message_ptr message) {
@ -293,7 +297,7 @@ int DtlsTransport::TransportExIndex = -1;
std::mutex DtlsTransport::GlobalMutex;
void DtlsTransport::GlobalInit() {
std::lock_guard<std::mutex> lock(GlobalMutex);
std::lock_guard lock(GlobalMutex);
if (TransportExIndex < 0) {
TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
}
@ -358,6 +362,8 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
}
DtlsTransport::~DtlsTransport() {
stop();
SSL_shutdown(mSsl);
SSL_free(mSsl);
SSL_CTX_free(mCtx);
@ -366,8 +372,10 @@ DtlsTransport::~DtlsTransport() {
void DtlsTransport::stop() {
Transport::stop();
if (mRecvThread.joinable()) {
mIncomingQueue.stop();
mRecvThread.join();
}
}
DtlsTransport::State DtlsTransport::state() const { return mState; }

View File

@ -55,10 +55,10 @@ public:
State state() const;
void stop() override;
bool send(message_ptr message); // false if dropped
bool send(message_ptr message) override; // false if dropped
private:
void incoming(message_ptr message);
void incoming(message_ptr message) override;
void changeState(State state);
void runRecvLoop();

View File

@ -130,11 +130,13 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
RecvCallback, this);
}
IceTransport::~IceTransport() {}
IceTransport::~IceTransport() { stop(); }
void IceTransport::stop() {
if (mMainLoopThread.joinable()) {
g_main_loop_quit(mMainLoop.get());
mMainLoopThread.join();
}
}
Description::Role IceTransport::role() const { return mRole; }

View File

@ -71,9 +71,9 @@ public:
bool send(message_ptr message) override; // false if dropped
private:
void incoming(message_ptr message);
void incoming(message_ptr message) override;
void incoming(const byte *data, int size);
void outgoing(message_ptr message);
void outgoing(message_ptr message) override;
void changeState(State state);
void changeGatheringState(GatheringState state);

View File

@ -20,6 +20,7 @@
#include "certificate.hpp"
#include "dtlstransport.hpp"
#include "icetransport.hpp"
#include "include.hpp"
#include "sctptransport.hpp"
#include <iostream>
@ -37,12 +38,12 @@ PeerConnection::PeerConnection(const Configuration &config)
: mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {}
PeerConnection::~PeerConnection() {
if (mIceTransport)
mIceTransport->stop();
if (mDtlsTransport)
mDtlsTransport->stop();
if (mSctpTransport)
mSctpTransport->stop();
if (auto transport = std::atomic_load(&mIceTransport))
transport->stop();
if (auto transport = std::atomic_load(&mDtlsTransport))
transport->stop();
if (auto transport = std::atomic_load(&mSctpTransport))
transport->stop();
mSctpTransport.reset();
mDtlsTransport.reset();
@ -55,26 +56,36 @@ PeerConnection::State PeerConnection::state() const { return mState; }
PeerConnection::GatheringState PeerConnection::gatheringState() const { return mGatheringState; }
std::optional<Description> PeerConnection::localDescription() const { return mLocalDescription; }
std::optional<Description> PeerConnection::localDescription() const {
std::lock_guard lock(mLocalDescriptionMutex);
return mLocalDescription;
}
std::optional<Description> PeerConnection::remoteDescription() const { return mRemoteDescription; }
std::optional<Description> PeerConnection::remoteDescription() const {
std::lock_guard lock(mRemoteDescriptionMutex);
return mRemoteDescription;
}
void PeerConnection::setRemoteDescription(Description description) {
std::lock_guard lock(mRemoteDescriptionMutex);
auto remoteCandidates = description.extractCandidates();
mRemoteDescription.emplace(std::move(description));
if (!mIceTransport)
initIceTransport(Description::Role::ActPass);
auto iceTransport = std::atomic_load(&mIceTransport);
if (!iceTransport)
iceTransport = initIceTransport(Description::Role::ActPass);
mIceTransport->setRemoteDescription(*mRemoteDescription);
iceTransport->setRemoteDescription(*mRemoteDescription);
if (mRemoteDescription->type() == Description::Type::Offer) {
// This is an offer and we are the answerer.
processLocalDescription(mIceTransport->getLocalDescription(Description::Type::Answer));
mIceTransport->gatherLocalCandidates();
processLocalDescription(iceTransport->getLocalDescription(Description::Type::Answer));
iceTransport->gatherLocalCandidates();
} else {
// This is an answer and we are the offerer.
if (!mSctpTransport && mIceTransport->role() == Description::Role::Active) {
auto sctpTransport = std::atomic_load(&mSctpTransport);
if (!sctpTransport && iceTransport->role() == Description::Role::Active) {
// Since we assumed passive role during DataChannel creation, we need to shift the
// stream numbers by one to shift them from odd to even.
decltype(mDataChannels) newDataChannels;
@ -92,16 +103,19 @@ void PeerConnection::setRemoteDescription(Description description) {
}
void PeerConnection::addRemoteCandidate(Candidate candidate) {
if (!mRemoteDescription || !mIceTransport)
std::lock_guard lock(mRemoteDescriptionMutex);
auto iceTransport = std::atomic_load(&mIceTransport);
if (!mRemoteDescription || !iceTransport)
throw std::logic_error("Remote candidate set without remote description");
mRemoteDescription->addCandidate(candidate);
if (candidate.resolve(Candidate::ResolveMode::Simple)) {
mIceTransport->addRemoteCandidate(candidate);
iceTransport->addRemoteCandidate(candidate);
} else {
// OK, we might need a lookup, do it asynchronously
weak_ptr<IceTransport> weakIceTransport{mIceTransport};
weak_ptr<IceTransport> weakIceTransport{iceTransport};
std::thread t([weakIceTransport, candidate]() mutable {
if (candidate.resolve(Candidate::ResolveMode::Lookup))
if (auto iceTransport = weakIceTransport.lock())
@ -112,11 +126,13 @@ void PeerConnection::addRemoteCandidate(Candidate candidate) {
}
std::optional<string> PeerConnection::localAddress() const {
return mIceTransport ? mIceTransport->getLocalAddress() : nullopt;
auto iceTransport = std::atomic_load(&mIceTransport);
return iceTransport ? iceTransport->getLocalAddress() : nullopt;
}
std::optional<string> PeerConnection::remoteAddress() const {
return mIceTransport ? mIceTransport->getRemoteAddress() : nullopt;
auto iceTransport = std::atomic_load(&mIceTransport);
return iceTransport ? iceTransport->getRemoteAddress() : nullopt;
}
shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
@ -126,7 +142,8 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
// setup:passive. [...] Thus, setup:active is RECOMMENDED.
// See https://tools.ietf.org/html/rfc5763#section-5
// Therefore, we assume passive role when we are the offerer.
auto role = mIceTransport ? mIceTransport->role() : Description::Role::Passive;
auto iceTransport = std::atomic_load(&mIceTransport);
auto role = iceTransport ? iceTransport->role() : Description::Role::Passive;
// The active side must use streams with even identifiers, whereas the passive side must use
// streams with odd identifiers.
@ -142,15 +159,17 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
std::make_shared<DataChannel>(shared_from_this(), stream, label, protocol, reliability);
mDataChannels.insert(std::make_pair(stream, channel));
if (!mIceTransport) {
if (!iceTransport) {
// RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of
// setup:actpass.
// See https://tools.ietf.org/html/rfc5763#section-5
initIceTransport(Description::Role::ActPass);
processLocalDescription(mIceTransport->getLocalDescription(Description::Type::Offer));
mIceTransport->gatherLocalCandidates();
} else if (mSctpTransport && mSctpTransport->state() == SctpTransport::State::Connected) {
channel->open(mSctpTransport);
iceTransport = initIceTransport(Description::Role::ActPass);
processLocalDescription(iceTransport->getLocalDescription(Description::Type::Offer));
iceTransport->gatherLocalCandidates();
} else {
if (auto transport = std::atomic_load(&mSctpTransport))
if (transport->state() == SctpTransport::State::Connected)
channel->open(transport);
}
return channel;
}
@ -177,8 +196,8 @@ void PeerConnection::onGatheringStateChange(std::function<void(GatheringState st
mGatheringStateChangeCallback = callback;
}
void PeerConnection::initIceTransport(Description::Role role) {
mIceTransport = std::make_shared<IceTransport>(
shared_ptr<IceTransport> PeerConnection::initIceTransport(Description::Role role) {
auto transport = std::make_shared<IceTransport>(
mConfig, role, std::bind(&PeerConnection::processLocalCandidate, this, _1),
[this](IceTransport::State state) {
switch (state) {
@ -211,11 +230,14 @@ void PeerConnection::initIceTransport(Description::Role role) {
break;
}
});
std::atomic_store(&mIceTransport, transport);
return transport;
}
void PeerConnection::initDtlsTransport() {
mDtlsTransport = std::make_shared<DtlsTransport>(
mIceTransport, mCertificate, std::bind(&PeerConnection::checkFingerprint, this, _1),
shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
auto lower = std::atomic_load(&mIceTransport);
auto transport = std::make_shared<DtlsTransport>(
lower, mCertificate, std::bind(&PeerConnection::checkFingerprint, this, _1),
[this](DtlsTransport::State state) {
switch (state) {
case DtlsTransport::State::Connected:
@ -229,12 +251,15 @@ void PeerConnection::initDtlsTransport() {
break;
}
});
std::atomic_store(&mDtlsTransport, transport);
return transport;
}
void PeerConnection::initSctpTransport() {
uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT);
mSctpTransport = std::make_shared<SctpTransport>(
mDtlsTransport, sctpPort, std::bind(&PeerConnection::forwardMessage, this, _1),
shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
uint16_t sctpPort = remoteDescription()->sctpPort().value_or(DEFAULT_SCTP_PORT);
auto lower = std::atomic_load(&mDtlsTransport);
auto transport = std::make_shared<SctpTransport>(
lower, sctpPort, std::bind(&PeerConnection::forwardMessage, this, _1),
std::bind(&PeerConnection::forwardBufferedAmount, this, _1, _2),
[this](SctpTransport::State state) {
switch (state) {
@ -253,9 +278,12 @@ void PeerConnection::initSctpTransport() {
break;
}
});
std::atomic_store(&mSctpTransport, transport);
return transport;
}
bool PeerConnection::checkFingerprint(const std::string &fingerprint) const {
std::lock_guard lock(mRemoteDescriptionMutex);
if (auto expectedFingerprint =
mRemoteDescription ? mRemoteDescription->fingerprint() : nullopt) {
return *expectedFingerprint == fingerprint;
@ -264,9 +292,6 @@ bool PeerConnection::checkFingerprint(const std::string &fingerprint) const {
}
void PeerConnection::forwardMessage(message_ptr message) {
if (!mIceTransport || !mSctpTransport)
throw std::logic_error("Got a DataChannel message without transport");
if (!message) {
closeDataChannels();
return;
@ -281,19 +306,24 @@ void PeerConnection::forwardMessage(message_ptr message) {
}
}
auto iceTransport = std::atomic_load(&mIceTransport);
auto sctpTransport = std::atomic_load(&mSctpTransport);
if (!iceTransport || !sctpTransport)
return;
if (!channel) {
const byte dataChannelOpenMessage{0x03};
unsigned int remoteParity = (mIceTransport->role() == Description::Role::Active) ? 1 : 0;
unsigned int remoteParity = (iceTransport->role() == Description::Role::Active) ? 1 : 0;
if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
message->stream % 2 == remoteParity) {
channel =
std::make_shared<DataChannel>(shared_from_this(), mSctpTransport, message->stream);
std::make_shared<DataChannel>(shared_from_this(), sctpTransport, message->stream);
channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this,
weak_ptr<DataChannel>{channel}));
mDataChannels.insert(std::make_pair(message->stream, channel));
} else {
// Invalid, close the DataChannel by resetting the stream
mSctpTransport->reset(message->stream);
sctpTransport->reset(message->stream);
return;
}
}
@ -330,16 +360,20 @@ void PeerConnection::iterateDataChannels(
}
void PeerConnection::openDataChannels() {
iterateDataChannels([this](shared_ptr<DataChannel> channel) { channel->open(mSctpTransport); });
if (auto transport = std::atomic_load(&mSctpTransport))
iterateDataChannels([&](shared_ptr<DataChannel> channel) { channel->open(transport); });
}
void PeerConnection::closeDataChannels() {
iterateDataChannels([](shared_ptr<DataChannel> channel) { channel->close(); });
iterateDataChannels([&](shared_ptr<DataChannel> channel) { channel->close(); });
}
void PeerConnection::processLocalDescription(Description description) {
auto remoteSctpPort = mRemoteDescription ? mRemoteDescription->sctpPort() : nullopt;
std::optional<uint16_t> remoteSctpPort;
if (auto remote = remoteDescription())
remoteSctpPort = remote->sctpPort();
std::lock_guard lock(mLocalDescriptionMutex);
mLocalDescription.emplace(std::move(description));
mLocalDescription->setFingerprint(mCertificate->fingerprint());
mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
@ -349,6 +383,7 @@ void PeerConnection::processLocalDescription(Description description) {
}
void PeerConnection::processLocalCandidate(Candidate candidate) {
std::lock_guard lock(mLocalDescriptionMutex);
if (!mLocalDescription)
throw std::logic_error("Got a local candidate without local description");

View File

@ -33,7 +33,7 @@ std::mutex SctpTransport::GlobalMutex;
int SctpTransport::InstancesCount = 0;
void SctpTransport::GlobalInit() {
std::unique_lock<std::mutex> lock(GlobalMutex);
std::lock_guard lock(GlobalMutex);
if (InstancesCount++ == 0) {
usrsctp_init(0, &SctpTransport::WriteCallback, nullptr);
usrsctp_sysctl_set_sctp_ecn_enable(0);
@ -41,7 +41,7 @@ void SctpTransport::GlobalInit() {
}
void SctpTransport::GlobalCleanup() {
std::unique_lock<std::mutex> lock(GlobalMutex);
std::lock_guard lock(GlobalMutex);
if (--InstancesCount == 0) {
usrsctp_finish();
}
@ -143,6 +143,8 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
}
SctpTransport::~SctpTransport() {
stop();
if (mSock) {
usrsctp_shutdown(mSock, SHUT_RDWR);
usrsctp_close(mSock);
@ -156,15 +158,14 @@ SctpTransport::State SctpTransport::state() const { return mState; }
void SctpTransport::stop() {
Transport::stop();
onRecv(nullptr);
mSendQueue.stop();
// Unblock incoming
if (!mConnectDataSent) {
std::unique_lock<std::mutex> lock(mConnectMutex);
mConnectDataSent = true;
mConnectCondition.notify_all();
}
}
void SctpTransport::connect() {
@ -190,7 +191,7 @@ void SctpTransport::connect() {
}
bool SctpTransport::send(message_ptr message) {
std::lock_guard<std::mutex> lock(mSendMutex);
std::lock_guard lock(mSendMutex);
if (!message)
return mSendQueue.empty();
@ -225,8 +226,8 @@ void SctpTransport::incoming(message_ptr message) {
// There could be a race condition here where we receive the remote INIT before the local one is
// sent, which would result in the connection being aborted. Therefore, we need to wait for data
// to be sent on our side (i.e. the local INIT) before proceeding.
if (!mConnectDataSent) {
std::unique_lock<std::mutex> lock(mConnectMutex);
{
std::unique_lock lock(mConnectMutex);
mConnectCondition.wait(lock, [this]() -> bool { return mConnectDataSent; });
}
@ -361,7 +362,7 @@ int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, co
int SctpTransport::handleSend(size_t free) {
try {
std::lock_guard<std::mutex> lock(mSendMutex);
std::lock_guard lock(mSendMutex);
trySendQueue();
} catch (const std::exception &e) {
std::cerr << "SCTP send: " << e.what() << std::endl;
@ -374,11 +375,9 @@ int SctpTransport::handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_
try {
outgoing(make_message(data, data + len));
if (!mConnectDataSent) {
std::unique_lock<std::mutex> lock(mConnectMutex);
std::unique_lock lock(mConnectMutex);
mConnectDataSent = true;
mConnectCondition.notify_all();
}
} catch (const std::exception &e) {
std::cerr << "SCTP write: " << e.what() << std::endl;
return -1;
@ -453,7 +452,6 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
switch (notify->sn_header.sn_type) {
case SCTP_ASSOC_CHANGE: {
const struct sctp_assoc_change &assoc_change = notify->sn_assoc_change;
std::unique_lock<std::mutex> lock(mConnectMutex);
if (assoc_change.sac_state == SCTP_COMM_UP) {
changeState(State::Connected);
} else {
@ -468,7 +466,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
case SCTP_SENDER_DRY_EVENT: {
// It not should be necessary since the send callback should have been called already,
// but to be sure, let's try to send now.
std::lock_guard<std::mutex> lock(mSendMutex);
std::lock_guard lock(mSendMutex);
trySendQueue();
}
case SCTP_STREAM_RESET_EVENT: {

View File

@ -68,7 +68,7 @@ private:
};
void connect();
void incoming(message_ptr message);
void incoming(message_ptr message) override;
void changeState(State state);
bool trySendQueue();
@ -93,8 +93,7 @@ private:
std::mutex mConnectMutex;
std::condition_variable mConnectCondition;
std::atomic<bool> mConnectDataSent = false;
std::atomic<bool> mStopping = false;
bool mConnectDataSent = false;
state_callback mStateChangeCallback;
std::atomic<State> mState;