Fixed state callback and revised synchronization and deletion

This commit is contained in:
Paul-Louis Ageneau
2020-03-31 14:59:50 +02:00
parent 577d048844
commit e04113f3f1
10 changed files with 147 additions and 65 deletions

View File

@ -60,6 +60,8 @@ protected:
virtual void triggerAvailable(size_t count); virtual void triggerAvailable(size_t count);
virtual void triggerBufferedAmount(size_t amount); virtual void triggerBufferedAmount(size_t amount);
void resetCallbacks();
private: private:
synchronized_callback<> mOpenCallback; synchronized_callback<> mOpenCallback;
synchronized_callback<> mClosedCallback; synchronized_callback<> mClosedCallback;

View File

@ -40,7 +40,7 @@ class DataChannel : public std::enable_shared_from_this<DataChannel>, public Cha
public: public:
DataChannel(std::weak_ptr<PeerConnection> pc, unsigned int stream, string label, DataChannel(std::weak_ptr<PeerConnection> pc, unsigned int stream, string label,
string protocol, Reliability reliability); string protocol, Reliability reliability);
DataChannel(std::weak_ptr<PeerConnection> pc, std::shared_ptr<SctpTransport> transport, DataChannel(std::weak_ptr<PeerConnection> pc, std::weak_ptr<SctpTransport> transport,
unsigned int stream); unsigned int stream);
~DataChannel(); ~DataChannel();
@ -65,13 +65,13 @@ public:
private: private:
void remoteClose(); void remoteClose();
void open(std::shared_ptr<SctpTransport> sctpTransport); void open(std::shared_ptr<SctpTransport> transport);
bool outgoing(mutable_message_ptr message); bool outgoing(mutable_message_ptr message);
void incoming(message_ptr message); void incoming(message_ptr message);
void processOpenMessage(message_ptr message); void processOpenMessage(message_ptr message);
const std::weak_ptr<PeerConnection> mPeerConnection; const std::weak_ptr<PeerConnection> mPeerConnection;
std::shared_ptr<SctpTransport> mSctpTransport; std::weak_ptr<SctpTransport> mSctpTransport;
unsigned int mStream; unsigned int mStream;
string mLabel; string mLabel;

View File

@ -52,14 +52,13 @@ public:
Connected = RTC_CONNECTED, Connected = RTC_CONNECTED,
Disconnected = RTC_DISCONNECTED, Disconnected = RTC_DISCONNECTED,
Failed = RTC_FAILED, Failed = RTC_FAILED,
Closed = RTC_CLOSED, Closed = RTC_CLOSED
Destroying = RTC_DESTROYING
}; };
enum class GatheringState : int { enum class GatheringState : int {
New = RTC_GATHERING_NEW, New = RTC_GATHERING_NEW,
InProgress = RTC_GATHERING_INPROGRESS, InProgress = RTC_GATHERING_INPROGRESS,
Complete = RTC_GATHERING_COMPLETE, Complete = RTC_GATHERING_COMPLETE
}; };
PeerConnection(void); PeerConnection(void);
@ -94,6 +93,7 @@ private:
std::shared_ptr<IceTransport> initIceTransport(Description::Role role); std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
std::shared_ptr<DtlsTransport> initDtlsTransport(); std::shared_ptr<DtlsTransport> initDtlsTransport();
std::shared_ptr<SctpTransport> initSctpTransport(); std::shared_ptr<SctpTransport> initSctpTransport();
void closeTransports();
void endLocalCandidates(); void endLocalCandidates();
bool checkFingerprint(const std::string &fingerprint) const; bool checkFingerprint(const std::string &fingerprint) const;
@ -112,8 +112,10 @@ private:
void processLocalDescription(Description description); void processLocalDescription(Description description);
void processLocalCandidate(Candidate candidate); void processLocalCandidate(Candidate candidate);
void triggerDataChannel(std::weak_ptr<DataChannel> weakDataChannel); void triggerDataChannel(std::weak_ptr<DataChannel> weakDataChannel);
void changeState(State state); bool changeState(State state);
void changeGatheringState(GatheringState state); bool changeGatheringState(GatheringState state);
void resetCallbacks();
const Configuration mConfig; const Configuration mConfig;
const std::shared_ptr<Certificate> mCertificate; const std::shared_ptr<Certificate> mCertificate;

View File

@ -31,8 +31,7 @@ typedef enum {
RTC_CONNECTED = 2, RTC_CONNECTED = 2,
RTC_DISCONNECTED = 3, RTC_DISCONNECTED = 3,
RTC_FAILED = 4, RTC_FAILED = 4,
RTC_CLOSED = 5, RTC_CLOSED = 5
RTC_DESTROYING = 6 // internal
} rtcState; } rtcState;
typedef enum { typedef enum {

View File

@ -88,5 +88,14 @@ void Channel::triggerBufferedAmount(size_t amount) {
mBufferedAmountLowCallback(); mBufferedAmountLowCallback();
} }
void Channel::resetCallbacks() {
mOpenCallback = nullptr;
mClosedCallback = nullptr;
mErrorCallback = nullptr;
mMessageCallback = nullptr;
mAvailableCallback = nullptr;
mBufferedAmountLowCallback = nullptr;
}
} // namespace rtc } // namespace rtc

View File

@ -74,7 +74,7 @@ DataChannel::DataChannel(weak_ptr<PeerConnection> pc, unsigned int stream, strin
mReliability(std::make_shared<Reliability>(std::move(reliability))), mReliability(std::make_shared<Reliability>(std::move(reliability))),
mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {} mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
DataChannel::DataChannel(weak_ptr<PeerConnection> pc, shared_ptr<SctpTransport> transport, DataChannel::DataChannel(weak_ptr<PeerConnection> pc, weak_ptr<SctpTransport> transport,
unsigned int stream) unsigned int stream)
: mPeerConnection(pc), mSctpTransport(transport), mStream(stream), : mPeerConnection(pc), mSctpTransport(transport), mStream(stream),
mReliability(std::make_shared<Reliability>()), mReliability(std::make_shared<Reliability>()),
@ -93,10 +93,13 @@ string DataChannel::protocol() const { return mProtocol; }
Reliability DataChannel::reliability() const { return *mReliability; } Reliability DataChannel::reliability() const { return *mReliability; }
void DataChannel::close() { void DataChannel::close() {
if (mIsOpen.exchange(false) && mSctpTransport) if (mIsOpen.exchange(false))
mSctpTransport->reset(mStream); if (auto transport = mSctpTransport.lock())
transport->reset(mStream);
mIsClosed = true; mIsClosed = true;
mSctpTransport.reset(); mSctpTransport.reset();
resetCallbacks();
} }
void DataChannel::remoteClose() { void DataChannel::remoteClose() {
@ -158,8 +161,8 @@ size_t DataChannel::maxMessageSize() const {
size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); } size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
void DataChannel::open(shared_ptr<SctpTransport> sctpTransport) { void DataChannel::open(shared_ptr<SctpTransport> transport) {
mSctpTransport = sctpTransport; mSctpTransport = transport;
uint8_t channelType = static_cast<uint8_t>(mReliability->type); uint8_t channelType = static_cast<uint8_t>(mReliability->type);
if (mReliability->unordered) if (mReliability->unordered)
@ -186,20 +189,24 @@ void DataChannel::open(shared_ptr<SctpTransport> sctpTransport) {
std::copy(mLabel.begin(), mLabel.end(), end); std::copy(mLabel.begin(), mLabel.end(), end);
std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size()); std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size());
mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
} }
bool DataChannel::outgoing(mutable_message_ptr message) { bool DataChannel::outgoing(mutable_message_ptr message) {
if (mIsClosed || !mSctpTransport) if (mIsClosed)
throw std::runtime_error("DataChannel is closed"); throw std::runtime_error("DataChannel is closed");
if (message->size() > maxMessageSize()) if (message->size() > maxMessageSize())
throw std::runtime_error("Message size exceeds limit"); throw std::runtime_error("Message size exceeds limit");
auto transport = mSctpTransport.lock();
if (!transport)
throw std::runtime_error("DataChannel has no transport");
// Before the ACK has been received on a DataChannel, all messages must be sent ordered // Before the ACK has been received on a DataChannel, all messages must be sent ordered
message->reliability = mIsOpen ? mReliability : nullptr; message->reliability = mIsOpen ? mReliability : nullptr;
message->stream = mStream; message->stream = mStream;
return mSctpTransport->send(message); return transport->send(message);
} }
void DataChannel::incoming(message_ptr message) { void DataChannel::incoming(message_ptr message) {
@ -238,6 +245,10 @@ void DataChannel::incoming(message_ptr message) {
} }
void DataChannel::processOpenMessage(message_ptr message) { void DataChannel::processOpenMessage(message_ptr message) {
auto transport = mSctpTransport.lock();
if (!transport)
throw std::runtime_error("DataChannel has no transport");
if (message->size() < sizeof(OpenMessage)) if (message->size() < sizeof(OpenMessage))
throw std::invalid_argument("DataChannel open message too small"); throw std::invalid_argument("DataChannel open message too small");
@ -274,7 +285,7 @@ void DataChannel::processOpenMessage(message_ptr message) {
auto &ack = *reinterpret_cast<AckMessage *>(buffer.data()); auto &ack = *reinterpret_cast<AckMessage *>(buffer.data());
ack.type = MESSAGE_ACK; ack.type = MESSAGE_ACK;
mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
mIsOpen = true; mIsOpen = true;
triggerOpen(); triggerOpen();

View File

@ -24,6 +24,7 @@
#include "sctptransport.hpp" #include "sctptransport.hpp"
#include <iostream> #include <iostream>
#include <thread>
namespace rtc { namespace rtc {
@ -38,28 +39,11 @@ PeerConnection::PeerConnection() : PeerConnection(Configuration()) {
PeerConnection::PeerConnection(const Configuration &config) PeerConnection::PeerConnection(const Configuration &config)
: mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {} : mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {}
PeerConnection::~PeerConnection() { PeerConnection::~PeerConnection() { close(); }
changeState(State::Destroying);
close();
mSctpTransport.reset();
mDtlsTransport.reset();
mIceTransport.reset();
}
void PeerConnection::close() { void PeerConnection::close() {
// Close DataChannels
closeDataChannels(); closeDataChannels();
closeTransports();
// Close Transports
for (int i = 0; i < 2; ++i) { // Make sure a transport wasn't spawn behind our back
if (auto transport = std::atomic_load(&mSctpTransport))
transport->stop();
if (auto transport = std::atomic_load(&mDtlsTransport))
transport->stop();
if (auto transport = std::atomic_load(&mIceTransport))
transport->stop();
}
changeState(State::Closed);
} }
const Configuration *PeerConnection::config() const { return &mConfig; } const Configuration *PeerConnection::config() const { return &mConfig; }
@ -241,8 +225,15 @@ shared_ptr<IceTransport> PeerConnection::initIceTransport(Description::Role role
break; break;
} }
}); });
std::atomic_store(&mIceTransport, transport); std::atomic_store(&mIceTransport, transport);
if (mState == State::Closed) {
mIceTransport.reset();
transport->stop();
throw std::runtime_error("Connection is closed");
}
return transport; return transport;
} catch (const std::exception &e) { } catch (const std::exception &e) {
PLOG_ERROR << e.what(); PLOG_ERROR << e.what();
changeState(State::Failed); changeState(State::Failed);
@ -274,8 +265,15 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
break; break;
} }
}); });
std::atomic_store(&mDtlsTransport, transport); std::atomic_store(&mDtlsTransport, transport);
if (mState == State::Closed) {
mDtlsTransport.reset();
transport->stop();
throw std::runtime_error("Connection is closed");
}
return transport; return transport;
} catch (const std::exception &e) { } catch (const std::exception &e) {
PLOG_ERROR << e.what(); PLOG_ERROR << e.what();
changeState(State::Failed); changeState(State::Failed);
@ -312,8 +310,15 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
break; break;
} }
}); });
std::atomic_store(&mSctpTransport, transport); std::atomic_store(&mSctpTransport, transport);
if (mState == State::Closed) {
mSctpTransport.reset();
transport->stop();
throw std::runtime_error("Connection is closed");
}
return transport; return transport;
} catch (const std::exception &e) { } catch (const std::exception &e) {
PLOG_ERROR << e.what(); PLOG_ERROR << e.what();
changeState(State::Failed); changeState(State::Failed);
@ -321,6 +326,34 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
} }
} }
void PeerConnection::closeTransports() {
// Change state to sink state Closed to block init methods
changeState(State::Closed);
// Reset callbacks now that state is changed
resetCallbacks();
// Pass the references to a thread, allowing to terminate a transport from its own thread
auto sctp = std::atomic_exchange(&mSctpTransport, decltype(mSctpTransport)(nullptr));
auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
auto ice = std::atomic_exchange(&mIceTransport, decltype(mIceTransport)(nullptr));
if (sctp || dtls || ice) {
std::thread t([sctp, dtls, ice]() mutable {
if (sctp)
sctp->stop();
if (dtls)
dtls->stop();
if (ice)
ice->stop();
sctp.reset();
dtls.reset();
ice.reset();
});
t.detach();
}
}
void PeerConnection::endLocalCandidates() { void PeerConnection::endLocalCandidates() {
std::lock_guard lock(mLocalDescriptionMutex); std::lock_guard lock(mLocalDescriptionMutex);
if (mLocalDescription) if (mLocalDescription)
@ -467,21 +500,34 @@ void PeerConnection::triggerDataChannel(weak_ptr<DataChannel> weakDataChannel) {
mDataChannelCallback(dataChannel); mDataChannelCallback(dataChannel);
} }
void PeerConnection::changeState(State state) { bool PeerConnection::changeState(State state) {
State current; State current;
do { do {
current = mState.load(); current = mState.load();
if (current == state || current == State::Destroying) if (current == state)
return; return true;
if (current == State::Closed)
return false;
} while (!mState.compare_exchange_weak(current, state)); } while (!mState.compare_exchange_weak(current, state));
if (state != State::Destroying)
mStateChangeCallback(state); mStateChangeCallback(state);
return true;
} }
void PeerConnection::changeGatheringState(GatheringState state) { bool PeerConnection::changeGatheringState(GatheringState state) {
if (mGatheringState.exchange(state) != state) if (mGatheringState.exchange(state) != state)
mGatheringStateChangeCallback(state); mGatheringStateChangeCallback(state);
return true;
}
void PeerConnection::resetCallbacks() {
// Unregister all callbacks
mDataChannelCallback = nullptr;
mLocalDescriptionCallback = nullptr;
mLocalCandidateCallback = nullptr;
mStateChangeCallback = nullptr;
mGatheringStateChangeCallback = nullptr;
} }
} // namespace rtc } // namespace rtc
@ -508,9 +554,6 @@ std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::State &st
case State::Closed: case State::Closed:
str = "closed"; str = "closed";
break; break;
case State::Destroying:
str = "destroying";
break;
default: default:
str = "unknown"; str = "unknown";
break; break;

View File

@ -53,6 +53,14 @@ void *getUserPointer(int id) {
return it != userPointerMap.end() ? it->second : nullptr; return it != userPointerMap.end() ? it->second : nullptr;
} }
void setUserPointer(int i, void *ptr) {
std::lock_guard lock(mutex);
if (ptr)
userPointerMap.insert(std::make_pair(i, ptr));
else
userPointerMap.erase(i);
}
shared_ptr<PeerConnection> getPeerConnection(int id) { shared_ptr<PeerConnection> getPeerConnection(int id) {
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
auto it = peerConnectionMap.find(id); auto it = peerConnectionMap.find(id);
@ -99,12 +107,7 @@ bool eraseDataChannel(int dc) {
void rtcInitLogger(rtcLogLevel level) { InitLogger(static_cast<LogLevel>(level)); } void rtcInitLogger(rtcLogLevel level) { InitLogger(static_cast<LogLevel>(level)); }
void rtcSetUserPointer(int i, void *ptr) { void rtcSetUserPointer(int i, void *ptr) { setUserPointer(i, ptr); }
if (ptr)
userPointerMap.insert(std::make_pair(i, ptr));
else
userPointerMap.erase(i);
}
int rtcCreatePeerConnection(const rtcConfiguration *config) { int rtcCreatePeerConnection(const rtcConfiguration *config) {
Configuration c; Configuration c;

View File

@ -157,6 +157,16 @@ int test_capi_main() {
sleep(3); sleep(3);
if (peer1->state != RTC_CONNECTED || peer2->state != RTC_CONNECTED) {
fprintf(stderr, "PeerConnection is not connected\n");
goto error;
}
if (!peer1->connected || !peer2->connected) {
fprintf(stderr, "DataChannel is not connected\n");
goto error;
}
char buffer[256]; char buffer[256];
if (rtcGetLocalAddress(peer1->pc, buffer, 256) >= 0) if (rtcGetLocalAddress(peer1->pc, buffer, 256) >= 0)
printf("Local address 1: %s\n", buffer); printf("Local address 1: %s\n", buffer);
@ -167,13 +177,12 @@ int test_capi_main() {
if (rtcGetRemoteAddress(peer2->pc, buffer, 256) >= 0) if (rtcGetRemoteAddress(peer2->pc, buffer, 256) >= 0)
printf("Remote address 2: %s\n", buffer); printf("Remote address 2: %s\n", buffer);
if (peer1->connected && peer2->connected) {
deletePeer(peer1); deletePeer(peer1);
deletePeer(peer2);
sleep(1); sleep(1);
deletePeer(peer2);
printf("Success\n"); printf("Success\n");
return 0; return 0;
}
error: error:
deletePeer(peer1); deletePeer(peer1);

View File

@ -108,6 +108,13 @@ void test_connectivity() {
this_thread::sleep_for(3s); this_thread::sleep_for(3s);
if (pc1->state() != PeerConnection::State::Connected &&
pc2->state() != PeerConnection::State::Connected)
throw runtime_error("PeerConnection is not connected");
if (!dc1->isOpen() || !dc2->isOpen())
throw runtime_error("DataChannel is not open");
if (auto addr = pc1->localAddress()) if (auto addr = pc1->localAddress())
cout << "Local address 1: " << *addr << endl; cout << "Local address 1: " << *addr << endl;
if (auto addr = pc1->remoteAddress()) if (auto addr = pc1->remoteAddress())
@ -117,13 +124,10 @@ void test_connectivity() {
if (auto addr = pc2->remoteAddress()) if (auto addr = pc2->remoteAddress())
cout << "Remote address 2: " << *addr << endl; cout << "Remote address 2: " << *addr << endl;
if (!dc1->isOpen() || !dc2->isOpen()) // Delay close of peer 2 to check closing works properly
throw runtime_error("DataChannel is not open");
pc1->close(); pc1->close();
pc2->close();
this_thread::sleep_for(1s); this_thread::sleep_for(1s);
pc2->close();
cout << "Success" << endl; cout << "Success" << endl;
} }