SRTP has to be aware of every RTP stream. This commit makes that work.

This commit is contained in:
Staz M
2020-10-15 23:31:49 -04:00
parent e4057c48f6
commit 0a46aa2c6d
6 changed files with 84 additions and 31 deletions

View File

@ -131,6 +131,7 @@ public:
void addSSRC(uint32_t ssrc, std::string name);
bool hasSSRC(uint32_t ssrc);
std::vector<uint32_t> getSSRCs();
void setBitrate(int bitrate);
int getBitrate() const;

View File

@ -706,6 +706,16 @@ void Description::Media::addRTPMap(const Description::Media::RTPMap& map) {
mRtpMap.emplace(map.pt, map);
}
std::vector<uint32_t> Description::Media::getSSRCs() {
std::vector<uint32_t> vec;
for (auto &val : mAttributes) {
PLOG_DEBUG << val;
if (val.find("ssrc:") == 0) {
vec.emplace_back(std::stoul((std::string)val.substr(5, val.find(" "))));
}
}
return vec;
}
Description::Media::RTPMap::RTPMap(string_view mline) {

View File

@ -262,41 +262,67 @@ void DtlsSrtpTransport::postHandshake() {
serverSalt = clientSalt + SRTP_SALT_LEN;
#endif
unsigned char clientSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
std::memcpy(clientSessionKey, clientKey, SRTP_AES_128_KEY_LEN);
std::memcpy(clientSessionKey + SRTP_AES_128_KEY_LEN, clientSalt, SRTP_SALT_LEN);
unsigned char serverSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
std::memcpy(serverSessionKey, serverKey, SRTP_AES_128_KEY_LEN);
std::memcpy(serverSessionKey + SRTP_AES_128_KEY_LEN, serverSalt, SRTP_SALT_LEN);
srtp_policy_t inbound = {};
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
inbound.ssrc.type = ssrc_any_inbound;
inbound.ssrc.value = 0;
inbound.key = mIsClient ? serverSessionKey : clientSessionKey;
inbound.next = nullptr;
if (srtp_err_status_t err = srtp_add_stream(mSrtpIn, &inbound))
throw std::runtime_error("SRTP add inbound stream failed, status=" +
to_string(static_cast<int>(err)));
srtp_policy_t outbound = {};
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
outbound.ssrc.type = ssrc_any_outbound;
outbound.ssrc.value = 0;
outbound.key = mIsClient ? clientSessionKey : serverSessionKey;
outbound.next = nullptr;
if (srtp_err_status_t err = srtp_add_stream(mSrtpOut, &outbound))
throw std::runtime_error("SRTP add outbound stream failed, status=" +
to_string(static_cast<int>(err)));
// srtp_policy_t inbound = {};
// srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
// srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
// inbound.ssrc.type = ssrc_any_inbound;
// inbound.ssrc.value = 0;
// inbound.key = mIsClient ? serverSessionKey : clientSessionKey;
// inbound.next = nullptr;
//
// if (srtp_err_status_t err = srtp_add_stream(mSrtpIn, &inbound))
// throw std::runtime_error("SRTP add inbound stream failed, status=" +
// to_string(static_cast<int>(err)));
//
// srtp_policy_t outbound = {};
// srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
// srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
// outbound.ssrc.type = ssrc_any_outbound;
// outbound.ssrc.value = 0;
// outbound.key = mIsClient ? clientSessionKey : serverSessionKey;
// outbound.next = nullptr;
//
// if (srtp_err_status_t err = srtp_add_stream(mSrtpOut, &outbound))
// throw std::runtime_error("SRTP add outbound stream failed, status=" +
// to_string(static_cast<int>(err)));
mInitDone = true;
}
void DtlsSrtpTransport::addSSRC(uint32_t ssrc) {
srtp_policy_t inbound = {};
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
inbound.ssrc.type = ssrc_specific;
inbound.ssrc.value = ssrc;
inbound.key = mIsClient ? serverSessionKey : clientSessionKey;
inbound.next = nullptr;
if (srtp_err_status_t err = srtp_add_stream(mSrtpIn, &inbound))
throw std::runtime_error("SRTP add inbound stream failed, status=" +
to_string(static_cast<int>(err)));
srtp_policy_t outbound = {};
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
outbound.ssrc.type = ssrc_specific;
outbound.ssrc.value = ssrc;
outbound.key = mIsClient ? clientSessionKey : serverSessionKey;
outbound.next = nullptr;
if (srtp_err_status_t err = srtp_add_stream(mSrtpOut, &outbound))
throw std::runtime_error("SRTP add outbound stream failed, status=" +
to_string(static_cast<int>(err)));
}
} // namespace rtc
#endif

View File

@ -39,6 +39,7 @@ public:
~DtlsSrtpTransport();
bool sendMedia(message_ptr message);
void addSSRC(uint32_t ssrc);
private:
void incoming(message_ptr message) override;
@ -48,6 +49,9 @@ private:
srtp_t mSrtpIn, mSrtpOut;
bool mInitDone = false;
unsigned char clientSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
unsigned char serverSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
};
} // namespace rtc

View File

@ -177,8 +177,8 @@ void DtlsTransport::runRecvLoop() {
// Receive loop
try {
PLOG_INFO << "DTLS handshake finished";
postHandshake();
changeState(State::Connected);
postHandshake();
const size_t bufferSize = maxMtu;
char buffer[bufferSize];

View File

@ -162,8 +162,8 @@ void PeerConnection::setRemoteDescription(Description description) {
for (const auto &candidate : remoteCandidates)
addRemoteCandidate(candidate);
if (std::atomic_load(&mIceTransport)) {
openTracks();
if (auto transport = std::atomic_load(&mDtlsTransport); transport && transport->state() == rtc::DtlsTransport::State::Connected) {
openTracks();
}
}
@ -694,11 +694,23 @@ void PeerConnection::openTracks() {
if (auto transport = std::atomic_load(&mDtlsTransport)) {
auto srtpTransport = std::reinterpret_pointer_cast<DtlsSrtpTransport>(transport);
std::shared_lock lock(mTracksMutex); // read-only
for (auto it = mTracks.begin(); it != mTracks.end(); ++it)
if (auto track = it->second.lock()) {
if (!track->isOpen())
// for (auto it = mTracks.begin(); it != mTracks.end(); ++it)
for (unsigned int i = 0; i < mTrackLines.size(); i++) {
if (auto track = mTrackLines[i].lock()) {
if (!track->isOpen()) {
// if (track->description().direction() == rtc::Description::Direction::RecvOnly || track->description().direction() == rtc::Description::Direction::SendRecv)
// srtpTransport->addInboundSSRC(0);
// if (track->description().direction() == rtc::Description::Direction::SendOnly || track->description().direction() == rtc::Description::Direction::SendRecv)
for (auto ssrc : track->description().getSSRCs())
srtpTransport->addSSRC(ssrc);
for (auto ssrc : std::get<rtc::Description::Media *>(remoteDescription()->media(i))->getSSRCs())
srtpTransport->addSSRC(ssrc);
track->open(srtpTransport);
}
}
}
}
#endif
}