diff --git a/include/rtc/channel.hpp b/include/rtc/channel.hpp index d4810a9..f944316 100644 --- a/include/rtc/channel.hpp +++ b/include/rtc/channel.hpp @@ -21,8 +21,8 @@ #include "include.hpp" +#include #include -#include #include namespace rtc { @@ -30,7 +30,7 @@ namespace rtc { class Channel { public: virtual void close() = 0; - virtual void send(const std::variant &data) = 0; + virtual bool send(const std::variant &data) = 0; virtual std::optional> receive() = 0; virtual bool isOpen() const = 0; virtual bool isClosed() const = 0; @@ -66,8 +66,8 @@ private: synchronized_callback<> mAvailableCallback; synchronized_callback<> mBufferedAmountLowCallback; - size_t mBufferedAmount = 0; - size_t mBufferedAmountLowThreshold = 0; + std::atomic mBufferedAmount = 0; + std::atomic mBufferedAmountLowThreshold = 0; }; } // namespace rtc diff --git a/include/rtc/datachannel.hpp b/include/rtc/datachannel.hpp index 80672aa..be07b4c 100644 --- a/include/rtc/datachannel.hpp +++ b/include/rtc/datachannel.hpp @@ -45,13 +45,13 @@ public: ~DataChannel(); void close(void); - void send(const std::variant &data); - void send(const byte *data, size_t size); + bool send(const std::variant &data); + bool send(const byte *data, size_t size); std::optional> receive(); // Directly send a buffer to avoid a copy - template void sendBuffer(const Buffer &buf); - template void sendBuffer(Iterator first, Iterator last); + template bool sendBuffer(const Buffer &buf); + template bool sendBuffer(Iterator first, Iterator last); bool isOpen(void) const; bool isClosed(void) const; @@ -65,7 +65,7 @@ public: private: void open(std::shared_ptr sctpTransport); - void outgoing(mutable_message_ptr message); + bool outgoing(mutable_message_ptr message); void incoming(message_ptr message); void processOpenMessage(message_ptr message); @@ -93,14 +93,14 @@ template std::pair to_bytes(const Buffer buf.size() * sizeof(E)); } -template void DataChannel::sendBuffer(const Buffer &buf) { +template bool DataChannel::sendBuffer(const Buffer &buf) { auto [bytes, size] = to_bytes(buf); auto message = std::make_shared(size); std::copy(bytes, bytes + size, message->data()); - outgoing(message); + return outgoing(message); } -template void DataChannel::sendBuffer(Iterator first, Iterator last) { +template bool DataChannel::sendBuffer(Iterator first, Iterator last) { size_t size = 0; for (Iterator it = first; it != last; ++it) size += it->size(); @@ -111,7 +111,7 @@ template void DataChannel::sendBuffer(Iterator first, Iterat auto [bytes, size] = to_bytes(*it); pos = std::copy(bytes, bytes + size, pos); } - outgoing(message); + return outgoing(message); } } // namespace rtc diff --git a/src/channel.cpp b/src/channel.cpp index 61f0441..56a0dfd 100644 --- a/src/channel.cpp +++ b/src/channel.cpp @@ -80,10 +80,9 @@ void Channel::triggerAvailable(size_t count) { } void Channel::triggerBufferedAmount(size_t amount) { - bool lowThresholdCrossed = - mBufferedAmount > mBufferedAmountLowThreshold && amount <= mBufferedAmountLowThreshold; - mBufferedAmount = amount; - if (lowThresholdCrossed) + size_t previous = mBufferedAmount.exchange(amount); + size_t threshold = mBufferedAmountLowThreshold.load(); + if (previous > threshold && amount <= threshold) mBufferedAmountLowCallback(); } diff --git a/src/datachannel.cpp b/src/datachannel.cpp index c3324f0..ecdbb5a 100644 --- a/src/datachannel.cpp +++ b/src/datachannel.cpp @@ -83,19 +83,19 @@ void DataChannel::close() { } } -void DataChannel::send(const std::variant &data) { - std::visit( +bool DataChannel::send(const std::variant &data) { + return std::visit( [&](const auto &d) { using T = std::decay_t; constexpr auto type = std::is_same_v ? Message::String : Message::Binary; auto *b = reinterpret_cast(d.data()); - outgoing(std::make_shared(b, b + d.size(), type)); + return outgoing(std::make_shared(b, b + d.size(), type)); }, data); } -void DataChannel::send(const byte *data, size_t size) { - outgoing(std::make_shared(data, data + size, Message::Binary)); +bool DataChannel::send(const byte *data, size_t size) { + return outgoing(std::make_shared(data, data + size, Message::Binary)); } std::optional> DataChannel::receive() { @@ -177,7 +177,7 @@ void DataChannel::open(shared_ptr sctpTransport) { mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); } -void DataChannel::outgoing(mutable_message_ptr message) { +bool DataChannel::outgoing(mutable_message_ptr message) { if (mIsClosed || !mSctpTransport) throw std::runtime_error("DataChannel is closed"); @@ -187,7 +187,7 @@ void DataChannel::outgoing(mutable_message_ptr message) { // Before the ACK has been received on a DataChannel, all messages must be sent ordered message->reliability = mIsOpen ? mReliability : nullptr; message->stream = mStream; - mSctpTransport->send(message); + return mSctpTransport->send(message); } void DataChannel::incoming(message_ptr message) { diff --git a/src/sctptransport.cpp b/src/sctptransport.cpp index 75b2b53..cba5d51 100644 --- a/src/sctptransport.cpp +++ b/src/sctptransport.cpp @@ -188,13 +188,18 @@ void SctpTransport::connect() { SctpTransport::State SctpTransport::state() const { return mState; } bool SctpTransport::send(message_ptr message) { - if (!message) - return false; + std::lock_guard lock(mSendMutex); + + if (!message) + return mSendQueue.empty(); + + // If nothing is pending, try to send directly + if (mSendQueue.empty() && trySendMessage(message)) + return true; - updateBufferedAmount(message->stream, message->size()); mSendQueue.push(message); - trySendAll(); - return true; + updateBufferedAmount(message->stream, message_size_func(message)); + return false; } void SctpTransport::reset(unsigned int stream) { @@ -231,25 +236,21 @@ void SctpTransport::changeState(State state) { mStateChangeCallback(state); } -bool SctpTransport::trySendAll() { - std::unique_lock lock(mSendMutex, std::try_to_lock); - if (!lock.owns_lock()) - return false; - +bool SctpTransport::trySendQueue() { + // Requires mSendMutex to be locked while (auto next = mSendQueue.peek()) { auto message = *next; - if (!trySend(message)) + if (!trySendMessage(message)) return false; - updateBufferedAmount(message->stream, -message->size()); mSendQueue.pop(); + updateBufferedAmount(message->stream, -message_size_func(message)); } return true; } -bool SctpTransport::trySend(message_ptr message) { - if (!message) - return false; - +bool SctpTransport::trySendMessage(message_ptr message) { + // Requires mSendMutex to be locked + // // TODO: Implement SCTP ndata specification draft when supported everywhere // See https://tools.ietf.org/html/draft-ietf-tsvwg-sctp-ndata-08 @@ -316,19 +317,13 @@ bool SctpTransport::trySend(message_ptr message) { } void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) { - if (delta == 0) - return; - std::lock_guard lock(mBufferedAmountMutex); + // Requires mSendMutex to be locked auto it = mBufferedAmount.insert(std::make_pair(streamId, 0)).first; - if (delta > 0) - it->second += size_t(delta); - else if (it->second > size_t(-delta)) - it->second -= size_t(-delta); - else - it->second = 0; - mBufferedAmountCallback(streamId, it->second); - if (it->second == 0) + size_t amount = it->second; + amount = size_t(std::max(long(amount) + delta, long(0))); + if (amount == 0) mBufferedAmount.erase(it); + mBufferedAmountCallback(streamId, amount); } int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data, @@ -364,7 +359,8 @@ int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, co int SctpTransport::handleSend(size_t free) { try { - trySendAll(); + std::lock_guard lock(mSendMutex); + trySendQueue(); } catch (const std::exception &e) { std::cerr << "SCTP send: " << e.what() << std::endl; return -1; @@ -470,7 +466,8 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s case SCTP_SENDER_DRY_EVENT: { // It not should be necessary since the send callback should have been called already, // but to be sure, let's try to send now. - trySendAll(); + std::lock_guard lock(mSendMutex); + trySendQueue(); } case SCTP_STREAM_RESET_EVENT: { const struct sctp_stream_reset_event &reset_event = notify->sn_strreset_event; diff --git a/src/sctptransport.hpp b/src/sctptransport.hpp index 6cb132d..95e9d58 100644 --- a/src/sctptransport.hpp +++ b/src/sctptransport.hpp @@ -69,8 +69,9 @@ private: void connect(); void incoming(message_ptr message); void changeState(State state); - bool trySendAll(); - bool trySend(message_ptr message); + + bool trySendQueue(); + bool trySendMessage(message_ptr message); void updateBufferedAmount(uint16_t streamId, long delta); int handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data, size_t len, @@ -86,8 +87,6 @@ private: std::mutex mSendMutex; Queue mSendQueue; - - std::mutex mBufferedAmountMutex; std::map mBufferedAmount; amount_callback mBufferedAmountCallback;