Compare commits

...

7 Commits

14 changed files with 465 additions and 132 deletions

View File

@ -46,12 +46,8 @@ pc->onLocalDescription([](const rtc::Description &sdp) {
MY_SEND_DESCRIPTION_TO_REMOTE(string(sdp)); MY_SEND_DESCRIPTION_TO_REMOTE(string(sdp));
}); });
pc->onLocalCandidate([](const optional<rtc::Candidate> &candidate) { pc->onLocalCandidate([](const rtc::Candidate &candidate) {
if (candidate) { MY_SEND_CANDIDATE_TO_REMOTE(candidate.candidate(), candidate.mid());
MY_SEND_CANDIDATE_TO_REMOTE(candidate->candidate(), candidate->mid());
} else {
// Gathering finished
}
}); });
MY_ON_RECV_DESCRIPTION_FROM_REMOTE([pc](string sdp) { MY_ON_RECV_DESCRIPTION_FROM_REMOTE([pc](string sdp) {
@ -63,6 +59,19 @@ MY_ON_RECV_CANDIDATE_FROM_REMOTE([pc](string candidate, string mid) {
}); });
``` ```
### Observe the PeerConnection state
```cpp
pc->onStateChanged([](PeerConnection::State state) {
cout << "State: " << state << endl;
});
pc->onGatheringStateChanged([](PeerConnection::GatheringState state) {
cout << "Gathering state: " << state << endl;
});
```
### Create a DataChannel ### Create a DataChannel
```cpp ```cpp

View File

@ -47,7 +47,8 @@ public:
void setFingerprint(string fingerprint); void setFingerprint(string fingerprint);
void setSctpPort(uint16_t port); void setSctpPort(uint16_t port);
void addCandidate(std::optional<Candidate> candidate); void addCandidate(Candidate candidate);
void endCandidates();
operator string() const; operator string() const;

View File

@ -20,12 +20,13 @@
#define RTC_PEER_CONNECTION_H #define RTC_PEER_CONNECTION_H
#include "candidate.hpp" #include "candidate.hpp"
#include "configuration.hpp"
#include "datachannel.hpp" #include "datachannel.hpp"
#include "description.hpp" #include "description.hpp"
#include "configuration.hpp"
#include "include.hpp" #include "include.hpp"
#include "message.hpp" #include "message.hpp"
#include "reliability.hpp" #include "reliability.hpp"
#include "rtc.hpp"
#include <atomic> #include <atomic>
#include <functional> #include <functional>
@ -40,12 +41,28 @@ class SctpTransport;
class PeerConnection { class PeerConnection {
public: public:
enum class State : int {
New = RTC_NEW,
Connecting = RTC_CONNECTING,
Connected = RTC_CONNECTED,
Disconnected = RTC_DISCONNECTED,
Failed = RTC_FAILED,
Closed = RTC_CLOSED
};
enum class GatheringState : int {
New = RTC_GATHERING_NEW,
InProgress = RTC_GATHERING_INPROGRESS,
Complete = RTC_GATHERING_COMPLETE,
};
PeerConnection(void); PeerConnection(void);
PeerConnection(const Configuration &config); PeerConnection(const Configuration &config);
~PeerConnection(); ~PeerConnection();
const Configuration *config() const; const Configuration *config() const;
State state() const;
GatheringState gatheringState() const;
std::optional<Description> localDescription() const; std::optional<Description> localDescription() const;
std::optional<Description> remoteDescription() const; std::optional<Description> remoteDescription() const;
@ -57,7 +74,9 @@ public:
void onDataChannel(std::function<void(std::shared_ptr<DataChannel> dataChannel)> callback); void onDataChannel(std::function<void(std::shared_ptr<DataChannel> dataChannel)> callback);
void onLocalDescription(std::function<void(const Description &description)> callback); void onLocalDescription(std::function<void(const Description &description)> callback);
void onLocalCandidate(std::function<void(const std::optional<Candidate> &candidate)> callback); void onLocalCandidate(std::function<void(const Candidate &candidate)> callback);
void onStateChange(std::function<void(State state)> callback);
void onGatheringStateChange(std::function<void(GatheringState state)> callback);
private: private:
void initIceTransport(Description::Role role); void initIceTransport(Description::Role role);
@ -66,11 +85,15 @@ private:
bool checkFingerprint(const std::string &fingerprint) const; bool checkFingerprint(const std::string &fingerprint) const;
void forwardMessage(message_ptr message); void forwardMessage(message_ptr message);
void openDataChannels(void); void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
void openDataChannels();
void closeDataChannels();
void processLocalDescription(Description description); void processLocalDescription(Description description);
void processLocalCandidate(std::optional<Candidate> candidate); void processLocalCandidate(Candidate candidate);
void triggerDataChannel(std::shared_ptr<DataChannel> dataChannel); void triggerDataChannel(std::shared_ptr<DataChannel> dataChannel);
void changeState(State state);
void changeGatheringState(GatheringState state);
const Configuration mConfig; const Configuration mConfig;
const std::shared_ptr<Certificate> mCertificate; const std::shared_ptr<Certificate> mCertificate;
@ -84,11 +107,19 @@ private:
std::unordered_map<unsigned int, std::weak_ptr<DataChannel>> mDataChannels; std::unordered_map<unsigned int, std::weak_ptr<DataChannel>> mDataChannels;
std::atomic<State> mState;
std::atomic<GatheringState> mGatheringState;
std::function<void(std::shared_ptr<DataChannel> dataChannel)> mDataChannelCallback; std::function<void(std::shared_ptr<DataChannel> dataChannel)> mDataChannelCallback;
std::function<void(const Description &description)> mLocalDescriptionCallback; std::function<void(const Description &description)> mLocalDescriptionCallback;
std::function<void(const std::optional<Candidate> &candidate)> mLocalCandidateCallback; std::function<void(const Candidate &candidate)> mLocalCandidateCallback;
std::function<void(State state)> mStateChangeCallback;
std::function<void(GatheringState state)> mGatheringStateChangeCallback;
}; };
} // namespace rtc } // namespace rtc
std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::State &state);
std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::GatheringState &state);
#endif #endif

View File

@ -16,11 +16,30 @@
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/ */
#ifndef RTC_C_API
#define RTC_C_API
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
// libdatachannel rtc C API // libdatachannel rtc C API
typedef enum {
RTC_NEW = 0,
RTC_CONNECTING = 1,
RTC_CONNECTED = 2,
RTC_DISCONNECTED = 3,
RTC_FAILED = 4,
RTC_CLOSED = 5
} rtc_state_t;
typedef enum {
RTC_GATHERING_NEW = 0,
RTC_GATHERING_INPROGRESS = 1,
RTC_GATHERING_COMPLETE = 2
} rtc_gathering_state_t;
int rtcCreatePeerConnection(const char **iceServers, int iceServersCount); int rtcCreatePeerConnection(const char **iceServers, int iceServersCount);
void rtcDeletePeerConnection(int pc); void rtcDeletePeerConnection(int pc);
int rtcCreateDataChannel(int pc, const char *label); int rtcCreateDataChannel(int pc, const char *label);
@ -30,6 +49,10 @@ void rtcSetLocalDescriptionCallback(int pc, void (*descriptionCallback)(const ch
void *)); void *));
void rtcSetLocalCandidateCallback(int pc, void rtcSetLocalCandidateCallback(int pc,
void (*candidateCallback)(const char *, const char *, void *)); void (*candidateCallback)(const char *, const char *, void *));
void rtcSetStateChangeCallback(int pc, void (*stateCallback)(rtc_state_t state, void *));
void rtcSetGatheringStateChangeCallback(int pc,
void (*gatheringStateCallback)(rtc_gathering_state_t state,
void *));
void rtcSetRemoteDescription(int pc, const char *sdp, const char *type); void rtcSetRemoteDescription(int pc, const char *sdp, const char *type);
void rtcAddRemoteCandidate(int pc, const char *candidate, const char *mid); void rtcAddRemoteCandidate(int pc, const char *candidate, const char *mid);
int rtcGetDataChannelLabel(int dc, char *data, int size); int rtcGetDataChannelLabel(int dc, char *data, int size);
@ -43,3 +66,5 @@ void rtcSetUserPointer(int i, void *ptr);
} // extern "C" } // extern "C"
#endif #endif
#endif

View File

@ -109,13 +109,12 @@ void Description::setFingerprint(string fingerprint) {
void Description::setSctpPort(uint16_t port) { mSctpPort.emplace(port); } void Description::setSctpPort(uint16_t port) { mSctpPort.emplace(port); }
void Description::addCandidate(std::optional<Candidate> candidate) { void Description::addCandidate(Candidate candidate) {
if (candidate) mCandidates.emplace_back(std::move(candidate));
mCandidates.emplace_back(std::move(*candidate));
else
mTrickle = false;
} }
void Description::endCandidates() { mTrickle = false; }
Description::operator string() const { Description::operator string() const {
if (!mFingerprint) if (!mFingerprint)
throw std::logic_error("Fingerprint must be set to generate a SDP"); throw std::logic_error("Fingerprint must be set to generate a SDP");

View File

@ -22,6 +22,7 @@
#include <cassert> #include <cassert>
#include <cstring> #include <cstring>
#include <exception> #include <exception>
#include <iostream>
#include <gnutls/dtls.h> #include <gnutls/dtls.h>
@ -46,9 +47,11 @@ namespace rtc {
using std::shared_ptr; using std::shared_ptr;
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate, DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
verifier_callback verifier, ready_callback ready) verifier_callback verifierCallback,
: Transport(lower), mCertificate(certificate), mVerifierCallback(std::move(verifier)), state_callback stateChangeCallback)
mReadyCallback(std::move(ready)) { : Transport(lower), mCertificate(certificate), mState(State::Disconnected),
mVerifierCallback(std::move(verifierCallback)),
mStateChangeCallback(std::move(stateChangeCallback)) {
gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback); gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback);
bool active = lower->role() == Description::Role::Active; bool active = lower->role() == Description::Role::Active;
@ -81,7 +84,12 @@ DtlsTransport::~DtlsTransport() {
gnutls_deinit(mSession); gnutls_deinit(mSession);
} }
DtlsTransport::State DtlsTransport::state() const { return mState; }
bool DtlsTransport::send(message_ptr message) { bool DtlsTransport::send(message_ptr message) {
if (!message)
return false;
while (true) { while (true) {
ssize_t ret = gnutls_record_send(mSession, message->data(), message->size()); ssize_t ret = gnutls_record_send(mSession, message->data(), message->size());
if (check_gnutls(ret)) { if (check_gnutls(ret)) {
@ -92,10 +100,25 @@ bool DtlsTransport::send(message_ptr message) {
void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); } void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); }
void DtlsTransport::runRecvLoop() { void DtlsTransport::changeState(State state) {
while (!check_gnutls(gnutls_handshake(mSession), "TLS handshake failed")) {} mState = state;
mStateChangeCallback(state);
}
mReadyCallback(); void DtlsTransport::runRecvLoop() {
try {
changeState(State::Connecting);
while (!check_gnutls(gnutls_handshake(mSession), "TLS handshake failed")) {
}
} catch (const std::exception &e) {
std::cerr << "DTLS handshake: " << e.what() << std::endl;
changeState(State::Failed);
return;
}
try {
changeState(State::Connected);
const size_t bufferSize = 2048; const size_t bufferSize = 2048;
char buffer[bufferSize]; char buffer[bufferSize];
@ -111,6 +134,13 @@ void DtlsTransport::runRecvLoop() {
recv(make_message(b, b + ret)); recv(make_message(b, b + ret));
} }
} }
} catch (const std::exception &e) {
std::cerr << "DTLS recv: " << e.what() << std::endl;
}
changeState(State::Disconnected);
recv(nullptr);
} }
int DtlsTransport::CertificateCallback(gnutls_session_t session) { int DtlsTransport::CertificateCallback(gnutls_session_t session) {
@ -120,7 +150,6 @@ int DtlsTransport::CertificateCallback(gnutls_session_t session) {
return GNUTLS_E_CERTIFICATE_ERROR; return GNUTLS_E_CERTIFICATE_ERROR;
} }
// Get peer's certificate
unsigned int count = 0; unsigned int count = 0;
const gnutls_datum_t *array = gnutls_certificate_get_peers(session, &count); const gnutls_datum_t *array = gnutls_certificate_get_peers(session, &count);
if (!array || count == 0) { if (!array || count == 0) {
@ -155,13 +184,13 @@ ssize_t DtlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *dat
ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) { ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
DtlsTransport *t = static_cast<DtlsTransport *>(ptr); DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
auto next = t->mIncomingQueue.pop(); auto next = t->mIncomingQueue.pop();
if (!next) { auto message = next ? *next : nullptr;
if (!message) {
// Closed // Closed
gnutls_transport_set_errno(t->mSession, 0); gnutls_transport_set_errno(t->mSession, 0);
return 0; return 0;
} }
auto message = *next;
ssize_t len = std::min(maxlen, message->size()); ssize_t len = std::min(maxlen, message->size());
std::memcpy(data, message->data(), len); std::memcpy(data, message->data(), len);
gnutls_transport_set_errno(t->mSession, 0); gnutls_transport_set_errno(t->mSession, 0);

View File

@ -25,6 +25,7 @@
#include "queue.hpp" #include "queue.hpp"
#include "transport.hpp" #include "transport.hpp"
#include <atomic>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <thread> #include <thread>
@ -37,26 +38,33 @@ class IceTransport;
class DtlsTransport : public Transport { class DtlsTransport : public Transport {
public: public:
enum class State { Disconnected, Connecting, Connected, Failed };
using verifier_callback = std::function<bool(const std::string &fingerprint)>; using verifier_callback = std::function<bool(const std::string &fingerprint)>;
using ready_callback = std::function<void(void)>; using state_callback = std::function<void(State state)>;
DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate, DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
verifier_callback verifier, ready_callback ready); verifier_callback verifierCallback, state_callback stateChangeCallback);
~DtlsTransport(); ~DtlsTransport();
State state() const;
bool send(message_ptr message); bool send(message_ptr message);
private: private:
void incoming(message_ptr message); void incoming(message_ptr message);
void changeState(State state);
void runRecvLoop(); void runRecvLoop();
const std::shared_ptr<Certificate> mCertificate; const std::shared_ptr<Certificate> mCertificate;
gnutls_session_t mSession; gnutls_session_t mSession;
Queue<message_ptr> mIncomingQueue; Queue<message_ptr> mIncomingQueue;
std::atomic<State> mState;
std::thread mRecvThread; std::thread mRecvThread;
verifier_callback mVerifierCallback; verifier_callback mVerifierCallback;
ready_callback mReadyCallback; state_callback mStateChangeCallback;
static int CertificateCallback(gnutls_session_t session); static int CertificateCallback(gnutls_session_t session);
static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len); static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);

View File

@ -34,10 +34,14 @@ using std::shared_ptr;
using std::weak_ptr; using std::weak_ptr;
IceTransport::IceTransport(const Configuration &config, Description::Role role, IceTransport::IceTransport(const Configuration &config, Description::Role role,
candidate_callback candidateCallback, ready_callback ready) candidate_callback candidateCallback,
: mRole(role), mMid("0"), mState(State::Disconnected), mNiceAgent(nullptr, nullptr), state_callback stateChangeCallback,
mMainLoop(nullptr, nullptr), mCandidateCallback(std::move(candidateCallback)), gathering_state_callback gatheringStateChangeCallback)
mReadyCallback(ready) { : mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr),
mCandidateCallback(std::move(candidateCallback)),
mStateChangeCallback(std::move(stateChangeCallback)),
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)) {
auto logLevelFlags = GLogLevelFlags(G_LOG_LEVEL_MASK | G_LOG_FLAG_FATAL | G_LOG_FLAG_RECURSION); auto logLevelFlags = GLogLevelFlags(G_LOG_LEVEL_MASK | G_LOG_FLAG_FATAL | G_LOG_FLAG_RECURSION);
g_log_set_handler(nullptr, logLevelFlags, LogCallback, this); g_log_set_handler(nullptr, logLevelFlags, LogCallback, this);
@ -102,7 +106,7 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
} }
g_signal_connect(G_OBJECT(mNiceAgent.get()), "component-state-changed", g_signal_connect(G_OBJECT(mNiceAgent.get()), "component-state-changed",
G_CALLBACK(StateChangedCallback), this); G_CALLBACK(StateChangeCallback), this);
g_signal_connect(G_OBJECT(mNiceAgent.get()), "new-candidate-full", g_signal_connect(G_OBJECT(mNiceAgent.get()), "new-candidate-full",
G_CALLBACK(CandidateCallback), this); G_CALLBACK(CandidateCallback), this);
g_signal_connect(G_OBJECT(mNiceAgent.get()), "candidate-gathering-done", g_signal_connect(G_OBJECT(mNiceAgent.get()), "candidate-gathering-done",
@ -146,9 +150,13 @@ void IceTransport::setRemoteDescription(const Description &description) {
} }
void IceTransport::gatherLocalCandidates() { void IceTransport::gatherLocalCandidates() {
if (!nice_agent_gather_candidates(mNiceAgent.get(), mStreamId)) // Change state now as candidates calls can be synchronous
changeGatheringState(GatheringState::InProgress);
if (!nice_agent_gather_candidates(mNiceAgent.get(), mStreamId)) {
throw std::runtime_error("Failed to gather local ICE candidates"); throw std::runtime_error("Failed to gather local ICE candidates");
} }
}
bool IceTransport::addRemoteCandidate(const Candidate &candidate) { bool IceTransport::addRemoteCandidate(const Candidate &candidate) {
// Warning: the candidate string must start with "a=candidate:" and it must not end with a // Warning: the candidate string must start with "a=candidate:" and it must not end with a
@ -167,7 +175,7 @@ bool IceTransport::addRemoteCandidate(const Candidate &candidate) {
} }
bool IceTransport::send(message_ptr message) { bool IceTransport::send(message_ptr message) {
if (!mStreamId) if (!message || !mStreamId)
return false; return false;
outgoing(message); outgoing(message);
@ -185,42 +193,66 @@ void IceTransport::outgoing(message_ptr message) {
reinterpret_cast<const char *>(message->data())); reinterpret_cast<const char *>(message->data()));
} }
void IceTransport::changeState(State state) {
mState = state;
mStateChangeCallback(mState);
}
void IceTransport::changeGatheringState(GatheringState state) {
mGatheringState = state;
mGatheringStateChangeCallback(mGatheringState);
}
void IceTransport::processCandidate(const string &candidate) { void IceTransport::processCandidate(const string &candidate) {
mCandidateCallback(Candidate(candidate, mMid)); mCandidateCallback(Candidate(candidate, mMid));
} }
void IceTransport::processGatheringDone() { mCandidateCallback(nullopt); } void IceTransport::processGatheringDone() { changeGatheringState(GatheringState::Complete); }
void IceTransport::changeState(uint32_t state) { void IceTransport::processStateChange(uint32_t state) {
mState = static_cast<State>(state); if (state != NICE_COMPONENT_STATE_GATHERING)
if (mState == State::Ready) { changeState(static_cast<State>(state));
mReadyCallback();
}
} }
void IceTransport::CandidateCallback(NiceAgent *agent, NiceCandidate *candidate, void IceTransport::CandidateCallback(NiceAgent *agent, NiceCandidate *candidate,
gpointer userData) { gpointer userData) {
auto iceTransport = static_cast<rtc::IceTransport *>(userData); auto iceTransport = static_cast<rtc::IceTransport *>(userData);
gchar *cand = nice_agent_generate_local_candidate_sdp(agent, candidate); gchar *cand = nice_agent_generate_local_candidate_sdp(agent, candidate);
try {
iceTransport->processCandidate(cand); iceTransport->processCandidate(cand);
} catch (const std::exception &e) {
std::cerr << "ICE candidate: " << e.what() << std::endl;
}
g_free(cand); g_free(cand);
} }
void IceTransport::GatheringDoneCallback(NiceAgent *agent, guint streamId, gpointer userData) { void IceTransport::GatheringDoneCallback(NiceAgent *agent, guint streamId, gpointer userData) {
auto iceTransport = static_cast<rtc::IceTransport *>(userData); auto iceTransport = static_cast<rtc::IceTransport *>(userData);
try {
iceTransport->processGatheringDone(); iceTransport->processGatheringDone();
} catch (const std::exception &e) {
std::cerr << "ICE gathering done: " << e.what() << std::endl;
}
} }
void IceTransport::StateChangedCallback(NiceAgent *agent, guint streamId, guint componentId, void IceTransport::StateChangeCallback(NiceAgent *agent, guint streamId, guint componentId,
guint state, gpointer userData) { guint state, gpointer userData) {
auto iceTransport = static_cast<rtc::IceTransport *>(userData); auto iceTransport = static_cast<rtc::IceTransport *>(userData);
iceTransport->changeState(state); try {
iceTransport->processStateChange(state);
} catch (const std::exception &e) {
std::cerr << "ICE change state: " << e.what() << std::endl;
}
} }
void IceTransport::RecvCallback(NiceAgent *agent, guint streamId, guint componentId, guint len, void IceTransport::RecvCallback(NiceAgent *agent, guint streamId, guint componentId, guint len,
gchar *buf, gpointer userData) { gchar *buf, gpointer userData) {
auto iceTransport = static_cast<rtc::IceTransport *>(userData); auto iceTransport = static_cast<rtc::IceTransport *>(userData);
try {
iceTransport->incoming(reinterpret_cast<byte *>(buf), len); iceTransport->incoming(reinterpret_cast<byte *>(buf), len);
} catch (const std::exception &e) {
std::cerr << "ICE incoming: " << e.what() << std::endl;
}
} }
void IceTransport::LogCallback(const gchar *logDomain, GLogLevelFlags logLevel, void IceTransport::LogCallback(const gchar *logDomain, GLogLevelFlags logLevel,

View File

@ -31,7 +31,6 @@ extern "C" {
} }
#include <atomic> #include <atomic>
#include <optional>
#include <thread> #include <thread>
namespace rtc { namespace rtc {
@ -40,22 +39,26 @@ class IceTransport : public Transport {
public: public:
enum class State : uint32_t { enum class State : uint32_t {
Disconnected = NICE_COMPONENT_STATE_DISCONNECTED, Disconnected = NICE_COMPONENT_STATE_DISCONNECTED,
Gathering = NICE_COMPONENT_STATE_GATHERING,
Connecting = NICE_COMPONENT_STATE_CONNECTING, Connecting = NICE_COMPONENT_STATE_CONNECTING,
Connected = NICE_COMPONENT_STATE_CONNECTED, Connected = NICE_COMPONENT_STATE_CONNECTED,
Ready = NICE_COMPONENT_STATE_READY, Ready = NICE_COMPONENT_STATE_READY,
Failed = NICE_COMPONENT_STATE_FAILED Failed = NICE_COMPONENT_STATE_FAILED
}; };
using candidate_callback = std::function<void(const std::optional<Candidate> &candidate)>; enum class GatheringState { New = 0, InProgress = 1, Complete = 2 };
using ready_callback = std::function<void(void)>;
using candidate_callback = std::function<void(const Candidate &candidate)>;
using state_callback = std::function<void(State state)>;
using gathering_state_callback = std::function<void(GatheringState state)>;
IceTransport(const Configuration &config, Description::Role role, IceTransport(const Configuration &config, Description::Role role,
candidate_callback candidateCallback, ready_callback ready); candidate_callback candidateCallback, state_callback stateChangeCallback,
gathering_state_callback gatheringStateChangeCallback);
~IceTransport(); ~IceTransport();
Description::Role role() const; Description::Role role() const;
State state() const; State state() const;
GatheringState gyyatheringState() const;
Description getLocalDescription(Description::Type type) const; Description getLocalDescription(Description::Type type) const;
void setRemoteDescription(const Description &description); void setRemoteDescription(const Description &description);
void gatherLocalCandidates(); void gatherLocalCandidates();
@ -68,13 +71,17 @@ private:
void incoming(const byte *data, int size); void incoming(const byte *data, int size);
void outgoing(message_ptr message); void outgoing(message_ptr message);
void changeState(uint32_t state); void changeState(State state);
void changeGatheringState(GatheringState state);
void processCandidate(const string &candidate); void processCandidate(const string &candidate);
void processGatheringDone(); void processGatheringDone();
void processStateChange(uint32_t state);
Description::Role mRole; Description::Role mRole;
string mMid; string mMid;
State mState; std::atomic<State> mState;
std::atomic<GatheringState> mGatheringState;
uint32_t mStreamId = 0; uint32_t mStreamId = 0;
std::unique_ptr<NiceAgent, void (*)(gpointer)> mNiceAgent; std::unique_ptr<NiceAgent, void (*)(gpointer)> mNiceAgent;
@ -82,11 +89,12 @@ private:
std::thread mMainLoopThread; std::thread mMainLoopThread;
candidate_callback mCandidateCallback; candidate_callback mCandidateCallback;
ready_callback mReadyCallback; state_callback mStateChangeCallback;
gathering_state_callback mGatheringStateChangeCallback;
static void CandidateCallback(NiceAgent *agent, NiceCandidate *candidate, gpointer userData); static void CandidateCallback(NiceAgent *agent, NiceCandidate *candidate, gpointer userData);
static void GatheringDoneCallback(NiceAgent *agent, guint streamId, gpointer userData); static void GatheringDoneCallback(NiceAgent *agent, guint streamId, gpointer userData);
static void StateChangedCallback(NiceAgent *agent, guint streamId, guint componentId, static void StateChangeCallback(NiceAgent *agent, guint streamId, guint componentId,
guint state, gpointer userData); guint state, gpointer userData);
static void RecvCallback(NiceAgent *agent, guint stream_id, guint component_id, guint len, static void RecvCallback(NiceAgent *agent, guint stream_id, guint component_id, guint len,
gchar *buf, gpointer userData); gchar *buf, gpointer userData);

View File

@ -34,12 +34,16 @@ using std::shared_ptr;
PeerConnection::PeerConnection() : PeerConnection(Configuration()) {} PeerConnection::PeerConnection() : PeerConnection(Configuration()) {}
PeerConnection::PeerConnection(const Configuration &config) PeerConnection::PeerConnection(const Configuration &config)
: mConfig(config), mCertificate(make_certificate("libdatachannel")) {} : mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {}
PeerConnection::~PeerConnection() {} PeerConnection::~PeerConnection() {}
const Configuration *PeerConnection::config() const { return &mConfig; } const Configuration *PeerConnection::config() const { return &mConfig; }
PeerConnection::State PeerConnection::state() const { return mState; }
PeerConnection::GatheringState PeerConnection::gatheringState() const { return mGatheringState; }
std::optional<Description> PeerConnection::localDescription() const { return mLocalDescription; } std::optional<Description> PeerConnection::localDescription() const { return mLocalDescription; }
std::optional<Description> PeerConnection::remoteDescription() const { return mRemoteDescription; } std::optional<Description> PeerConnection::remoteDescription() const { return mRemoteDescription; }
@ -62,7 +66,7 @@ void PeerConnection::addRemoteCandidate(Candidate candidate) {
throw std::logic_error("Remote candidate set without remote description"); throw std::logic_error("Remote candidate set without remote description");
if (mIceTransport->addRemoteCandidate(candidate)) if (mIceTransport->addRemoteCandidate(candidate))
mRemoteDescription->addCandidate(std::make_optional(std::move(candidate))); mRemoteDescription->addCandidate(std::move(candidate));
} }
shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label, shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
@ -86,7 +90,7 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
initIceTransport(Description::Role::Active); initIceTransport(Description::Role::Active);
processLocalDescription(mIceTransport->getLocalDescription(Description::Type::Offer)); processLocalDescription(mIceTransport->getLocalDescription(Description::Type::Offer));
mIceTransport->gatherLocalCandidates(); mIceTransport->gatherLocalCandidates();
} else if (mSctpTransport && mSctpTransport->isReady()) { } else if (mSctpTransport && mSctpTransport->state() == SctpTransport::State::Connected) {
channel->open(mSctpTransport); channel->open(mSctpTransport);
} }
return channel; return channel;
@ -102,28 +106,93 @@ void PeerConnection::onLocalDescription(
mLocalDescriptionCallback = callback; mLocalDescriptionCallback = callback;
} }
void PeerConnection::onLocalCandidate( void PeerConnection::onLocalCandidate(std::function<void(const Candidate &candidate)> callback) {
std::function<void(const std::optional<Candidate> &candidate)> callback) {
mLocalCandidateCallback = callback; mLocalCandidateCallback = callback;
} }
void PeerConnection::onStateChange(std::function<void(State state)> callback) {
mStateChangeCallback = callback;
}
void PeerConnection::onGatheringStateChange(std::function<void(GatheringState state)> callback) {
mGatheringStateChangeCallback = callback;
}
void PeerConnection::initIceTransport(Description::Role role) { void PeerConnection::initIceTransport(Description::Role role) {
mIceTransport = std::make_shared<IceTransport>( mIceTransport = std::make_shared<IceTransport>(
mConfig, role, std::bind(&PeerConnection::processLocalCandidate, this, _1), mConfig, role, std::bind(&PeerConnection::processLocalCandidate, this, _1),
std::bind(&PeerConnection::initDtlsTransport, this)); [this](IceTransport::State state) {
switch (state) {
case IceTransport::State::Connecting:
changeState(State::Connecting);
break;
case IceTransport::State::Failed:
changeState(State::Failed);
break;
case IceTransport::State::Ready:
initDtlsTransport();
break;
default:
// Ignore
break;
}
},
[this](IceTransport::GatheringState state) {
switch (state) {
case IceTransport::GatheringState::InProgress:
changeGatheringState(GatheringState::InProgress);
break;
case IceTransport::GatheringState::Complete:
if (mLocalDescription)
mLocalDescription->endCandidates();
changeGatheringState(GatheringState::Complete);
break;
default:
// Ignore
break;
}
});
} }
void PeerConnection::initDtlsTransport() { void PeerConnection::initDtlsTransport() {
mDtlsTransport = std::make_shared<DtlsTransport>( mDtlsTransport = std::make_shared<DtlsTransport>(
mIceTransport, mCertificate, std::bind(&PeerConnection::checkFingerprint, this, _1), mIceTransport, mCertificate, std::bind(&PeerConnection::checkFingerprint, this, _1),
std::bind(&PeerConnection::initSctpTransport, this)); [this](DtlsTransport::State state) {
switch (state) {
case DtlsTransport::State::Connected:
initSctpTransport();
break;
case DtlsTransport::State::Failed:
changeState(State::Failed);
break;
default:
// Ignore
break;
}
});
} }
void PeerConnection::initSctpTransport() { void PeerConnection::initSctpTransport() {
uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT); uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT);
mSctpTransport = std::make_shared<SctpTransport>( mSctpTransport = std::make_shared<SctpTransport>(
mDtlsTransport, sctpPort, std::bind(&PeerConnection::openDataChannels, this), mDtlsTransport, sctpPort, std::bind(&PeerConnection::forwardMessage, this, _1),
std::bind(&PeerConnection::forwardMessage, this, _1)); [this](SctpTransport::State state) {
switch (state) {
case SctpTransport::State::Connected:
changeState(State::Connected);
openDataChannels();
break;
case SctpTransport::State::Failed:
changeState(State::Failed);
break;
case SctpTransport::State::Disconnected:
changeState(State::Disconnected);
break;
default:
// Ignore
break;
}
});
} }
bool PeerConnection::checkFingerprint(const std::string &fingerprint) const { bool PeerConnection::checkFingerprint(const std::string &fingerprint) const {
@ -138,6 +207,11 @@ void PeerConnection::forwardMessage(message_ptr message) {
if (!mIceTransport || !mSctpTransport) if (!mIceTransport || !mSctpTransport)
throw std::logic_error("Got a DataChannel message without transport"); throw std::logic_error("Got a DataChannel message without transport");
if (!message) {
closeDataChannels();
return;
}
shared_ptr<DataChannel> channel; shared_ptr<DataChannel> channel;
if (auto it = mDataChannels.find(message->stream); it != mDataChannels.end()) { if (auto it = mDataChannels.find(message->stream); it != mDataChannels.end()) {
channel = it->second.lock(); channel = it->second.lock();
@ -165,7 +239,8 @@ void PeerConnection::forwardMessage(message_ptr message) {
channel->incoming(message); channel->incoming(message);
} }
void PeerConnection::openDataChannels(void) { void PeerConnection::iterateDataChannels(
std::function<void(shared_ptr<DataChannel> channel)> func) {
auto it = mDataChannels.begin(); auto it = mDataChannels.begin();
while (it != mDataChannels.end()) { while (it != mDataChannels.end()) {
auto channel = it->second.lock(); auto channel = it->second.lock();
@ -173,11 +248,19 @@ void PeerConnection::openDataChannels(void) {
it = mDataChannels.erase(it); it = mDataChannels.erase(it);
continue; continue;
} }
channel->open(mSctpTransport); func(channel);
++it; ++it;
} }
} }
void PeerConnection::openDataChannels() {
iterateDataChannels([this](shared_ptr<DataChannel> channel) { channel->open(mSctpTransport); });
}
void PeerConnection::closeDataChannels() {
iterateDataChannels([](shared_ptr<DataChannel> channel) { channel->close(); });
}
void PeerConnection::processLocalDescription(Description description) { void PeerConnection::processLocalDescription(Description description) {
auto remoteSctpPort = mRemoteDescription ? mRemoteDescription->sctpPort() : nullopt; auto remoteSctpPort = mRemoteDescription ? mRemoteDescription->sctpPort() : nullopt;
@ -189,7 +272,7 @@ void PeerConnection::processLocalDescription(Description description) {
mLocalDescriptionCallback(*mLocalDescription); mLocalDescriptionCallback(*mLocalDescription);
} }
void PeerConnection::processLocalCandidate(std::optional<Candidate> candidate) { void PeerConnection::processLocalCandidate(Candidate candidate) {
if (!mLocalDescription) if (!mLocalDescription)
throw std::logic_error("Got a local candidate without local description"); throw std::logic_error("Got a local candidate without local description");
@ -204,4 +287,63 @@ void PeerConnection::triggerDataChannel(std::shared_ptr<DataChannel> dataChannel
mDataChannelCallback(dataChannel); mDataChannelCallback(dataChannel);
} }
void PeerConnection::changeState(State state) {
mState = state;
if (mStateChangeCallback)
mStateChangeCallback(state);
}
void PeerConnection::changeGatheringState(GatheringState state) {
mGatheringState = state;
if (mGatheringStateChangeCallback)
mGatheringStateChangeCallback(state);
}
} // namespace rtc } // namespace rtc
std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::State &state) {
using State = rtc::PeerConnection::State;
std::string str;
switch (state) {
case State::New:
str = "new";
break;
case State::Connecting:
str = "connecting";
break;
case State::Connected:
str = "connected";
break;
case State::Disconnected:
str = "disconnected";
break;
case State::Failed:
str = "failed";
break;
default:
str = "unknown";
break;
}
return out << str;
}
std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::GatheringState &state) {
using GatheringState = rtc::PeerConnection::GatheringState;
std::string str;
switch (state) {
case GatheringState::New:
str = "new";
break;
case GatheringState::InProgress:
str = "in_progress";
break;
case GatheringState::Complete:
str = "complete";
break;
default:
str = "unknown";
break;
}
return out << str;
}

View File

@ -95,14 +95,32 @@ void rtcSetLocalCandidateCallback(int pc,
if (it == peerConnectionMap.end()) if (it == peerConnectionMap.end())
return; return;
it->second->onLocalCandidate( it->second->onLocalCandidate([pc, candidateCallback](const Candidate &candidate) {
[pc, candidateCallback](const std::optional<Candidate> &candidate) { candidateCallback(candidate.candidate().c_str(), candidate.mid().c_str(),
if (candidate) {
candidateCallback(string(*candidate).c_str(), candidate->mid().c_str(),
getUserPointer(pc)); getUserPointer(pc));
} else { });
candidateCallback(nullptr, nullptr, getUserPointer(pc));
} }
void rtcSetStateChangeCallback(int pc, void (*stateCallback)(rtc_state_t state, void *)) {
auto it = peerConnectionMap.find(pc);
if (it == peerConnectionMap.end())
return;
it->second->onStateChange([pc, stateCallback](PeerConnection::State state) {
stateCallback(static_cast<rtc_state_t>(state), getUserPointer(pc));
});
}
void rtcSetGatheringStateChangeCallback(int pc,
void (*gatheringStateCallback)(rtc_gathering_state_t state,
void *)) {
auto it = peerConnectionMap.find(pc);
if (it == peerConnectionMap.end())
return;
it->second->onGatheringStateChange(
[pc, gatheringStateCallback](PeerConnection::GatheringState state) {
gatheringStateCallback(static_cast<rtc_gathering_state_t>(state), getUserPointer(pc));
}); });
} }

View File

@ -47,9 +47,10 @@ void SctpTransport::GlobalCleanup() {
} }
} }
SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, ready_callback ready, SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv,
message_callback recv) state_callback stateChangeCallback)
: Transport(lower), mReadyCallback(std::move(ready)), mPort(port) { : Transport(lower), mPort(port), mState(State::Disconnected),
mStateChangeCallback(std::move(stateChangeCallback)) {
onRecv(recv); onRecv(recv);
@ -120,6 +121,7 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, re
SctpTransport::~SctpTransport() { SctpTransport::~SctpTransport() {
mStopping = true; mStopping = true;
mConnectCondition.notify_all();
if (mConnectThread.joinable()) if (mConnectThread.joinable())
mConnectThread.join(); mConnectThread.join();
@ -132,9 +134,12 @@ SctpTransport::~SctpTransport() {
GlobalCleanup(); GlobalCleanup();
} }
bool SctpTransport::isReady() const { return mIsReady; } SctpTransport::State SctpTransport::state() const { return mState; }
bool SctpTransport::send(message_ptr message) { bool SctpTransport::send(message_ptr message) {
if (!message)
return false;
const Reliability reliability = message->reliability ? *message->reliability : Reliability(); const Reliability reliability = message->reliability ? *message->reliability : Reliability();
struct sctp_sendv_spa spa = {}; struct sctp_sendv_spa spa = {};
@ -201,6 +206,12 @@ void SctpTransport::reset(unsigned int stream) {
} }
void SctpTransport::incoming(message_ptr message) { void SctpTransport::incoming(message_ptr message) {
if (!message) {
changeState(State::Disconnected);
recv(nullptr);
return;
}
// There could be a race condition here where we receive the remote INIT before the thread in // There could be a race condition here where we receive the remote INIT before the thread in
// usrsctp_connect sends the local one, which would result in the connection being aborted. // usrsctp_connect sends the local one, which would result in the connection being aborted.
// Therefore, we need to wait for data to be sent on our side (i.e. the local INIT) before // Therefore, we need to wait for data to be sent on our side (i.e. the local INIT) before
@ -214,7 +225,15 @@ void SctpTransport::incoming(message_ptr message) {
usrsctp_conninput(this, message->data(), message->size(), 0); usrsctp_conninput(this, message->data(), message->size(), 0);
} }
void SctpTransport::changeState(State state) {
mState = state;
mStateChangeCallback(state);
}
void SctpTransport::runConnect() { void SctpTransport::runConnect() {
try {
changeState(State::Connecting);
struct sockaddr_conn sconn = {}; struct sockaddr_conn sconn = {};
sconn.sconn_family = AF_CONN; sconn.sconn_family = AF_CONN;
sconn.sconn_port = htons(mPort); sconn.sconn_port = htons(mPort);
@ -226,15 +245,19 @@ void SctpTransport::runConnect() {
// According to the IETF draft, both endpoints must initiate the SCTP association, in a // According to the IETF draft, both endpoints must initiate the SCTP association, in a
// simultaneous-open manner, irrelevent to the SDP setup role. // simultaneous-open manner, irrelevent to the SDP setup role.
// See https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-26#section-9.3 // See https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-26#section-9.3
if (usrsctp_connect(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)) != 0) { if (usrsctp_connect(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)) !=
0) {
std::cerr << "SCTP connection failed, errno=" << errno << std::endl; std::cerr << "SCTP connection failed, errno=" << errno << std::endl;
changeState(State::Failed);
mStopping = true; mStopping = true;
return; return;
} }
if (!mStopping) { if (!mStopping)
mIsReady = true; changeState(State::Connected);
mReadyCallback();
} catch (const std::exception &e) {
std::cerr << "SCTP connect: " << e.what() << std::endl;
} }
} }
@ -251,12 +274,11 @@ int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_
} }
int SctpTransport::process(struct socket *sock, union sctp_sockstore addr, void *data, size_t len, int SctpTransport::process(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
struct sctp_rcvinfo recv_info, int flags) { struct sctp_rcvinfo info, int flags) {
if (flags & MSG_NOTIFICATION) { if (flags & MSG_NOTIFICATION) {
processNotification((union sctp_notification *)data, len); processNotification((union sctp_notification *)data, len);
} else { } else {
processData((const byte *)data, len, recv_info.rcv_sid, processData((const byte *)data, len, info.rcv_sid, PayloadId(htonl(info.rcv_ppid)));
PayloadId(htonl(recv_info.rcv_ppid)));
} }
free(data); free(data);
return 0; return 0;

View File

@ -36,13 +36,15 @@ namespace rtc {
class SctpTransport : public Transport { class SctpTransport : public Transport {
public: public:
using ready_callback = std::function<void(void)>; enum class State { Disconnected, Connecting, Connected, Failed };
SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, ready_callback ready, using state_callback = std::function<void(State state)>;
message_callback recv);
SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv,
state_callback stateChangeCallback);
~SctpTransport(); ~SctpTransport();
bool isReady() const; State state() const;
bool send(message_ptr message); bool send(message_ptr message);
void reset(unsigned int stream); void reset(unsigned int stream);
@ -57,6 +59,7 @@ private:
}; };
void incoming(message_ptr message); void incoming(message_ptr message);
void changeState(State state);
void runConnect(); void runConnect();
int handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df); int handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df);
@ -67,18 +70,18 @@ private:
void processData(const byte *data, size_t len, uint16_t streamId, PayloadId ppid); void processData(const byte *data, size_t len, uint16_t streamId, PayloadId ppid);
void processNotification(const union sctp_notification *notify, size_t len); void processNotification(const union sctp_notification *notify, size_t len);
ready_callback mReadyCallback;
struct socket *mSock; struct socket *mSock;
uint16_t mPort; uint16_t mPort;
std::thread mConnectThread; std::thread mConnectThread;
std::atomic<bool> mStopping = false;
std::atomic<bool> mIsReady = false;
std::mutex mConnectMutex; std::mutex mConnectMutex;
std::condition_variable mConnectCondition; std::condition_variable mConnectCondition;
std::atomic<bool> mConnectDataSent = false; std::atomic<bool> mConnectDataSent = false;
std::atomic<bool> mStopping = false;
std::atomic<State> mState;
state_callback mStateChangeCallback;
static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df); static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);
static int ReadCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len, static int ReadCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,

View File

@ -28,7 +28,7 @@ using namespace std;
int main(int argc, char **argv) { int main(int argc, char **argv) {
rtc::Configuration config; rtc::Configuration config;
config.iceServers.emplace_back("stun.l.google.com:19302"); // config.iceServers.emplace_back("stun.l.google.com:19302");
auto pc1 = std::make_shared<PeerConnection>(config); auto pc1 = std::make_shared<PeerConnection>(config);
auto pc2 = std::make_shared<PeerConnection>(config); auto pc2 = std::make_shared<PeerConnection>(config);
@ -38,11 +38,14 @@ int main(int argc, char **argv) {
pc2->setRemoteDescription(sdp); pc2->setRemoteDescription(sdp);
}); });
pc1->onLocalCandidate([pc2](const optional<Candidate> &candidate) { pc1->onLocalCandidate([pc2](const Candidate &candidate) {
if (candidate) { cout << "Candidate 1: " << candidate << endl;
cout << "Candidate 1: " << *candidate << endl; pc2->addRemoteCandidate(candidate);
pc2->addRemoteCandidate(*candidate); });
}
pc1->onStateChange([](PeerConnection::State state) { cout << "State 1: " << state << endl; });
pc1->onGatheringStateChange([](PeerConnection::GatheringState state) {
cout << "Gathering state 1: " << state << endl;
}); });
pc2->onLocalDescription([pc1](const Description &sdp) { pc2->onLocalDescription([pc1](const Description &sdp) {
@ -50,11 +53,14 @@ int main(int argc, char **argv) {
pc1->setRemoteDescription(sdp); pc1->setRemoteDescription(sdp);
}); });
pc2->onLocalCandidate([pc1](const optional<Candidate> &candidate) { pc2->onLocalCandidate([pc1](const Candidate &candidate) {
if (candidate) { cout << "Candidate 2: " << candidate << endl;
cout << "Candidate 2: " << *candidate << endl; pc1->addRemoteCandidate(candidate);
pc1->addRemoteCandidate(*candidate); });
}
pc2->onStateChange([](PeerConnection::State state) { cout << "State 2: " << state << endl; });
pc2->onGatheringStateChange([](PeerConnection::GatheringState state) {
cout << "Gathering state 2: " << state << endl;
}); });
shared_ptr<DataChannel> dc2; shared_ptr<DataChannel> dc2;