diff --git a/src/sctptransport.cpp b/src/sctptransport.cpp index 73b2aa1..0823c10 100644 --- a/src/sctptransport.cpp +++ b/src/sctptransport.cpp @@ -84,6 +84,9 @@ void SctpTransport::Init() { // Change congestion control from the default TCP Reno (RFC 2581) to H-TCP usrsctp_sysctl_set_sctp_default_cc_module(SCTP_CC_HTCP); + // Enable Partial Reliability Extension (RFC 3758) + usrsctp_sysctl_set_sctp_pr_enable(1); + // Enable Non-Renegable Selective Acknowledgments (NR-SACKs) usrsctp_sysctl_set_sctp_nrsack_enable(1); @@ -103,7 +106,7 @@ SctpTransport::SctpTransport(std::shared_ptr lower, uint16_t port, std::optional mtu, message_callback recvCallback, amount_callback bufferedAmountCallback, state_callback stateChangeCallback) - : Transport(lower, std::move(stateChangeCallback)), mPort(port), mPendingRecvCount(0), + : Transport(lower, std::move(stateChangeCallback)), mPort(port), mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) { onRecv(recvCallback); @@ -259,7 +262,7 @@ bool SctpTransport::stop() { return false; mSendQueue.stop(); - safeFlush(); + flush(); shutdown(); onRecv(nullptr); return true; @@ -333,13 +336,20 @@ bool SctpTransport::send(message_ptr message) { return false; } -void SctpTransport::closeStream(unsigned int stream) { - send(make_message(0, Message::Reset, uint16_t(stream))); +bool SctpTransport::flush() { + try { + std::lock_guard lock(mSendMutex); + trySendQueue(); + return true; + + } catch (const std::exception &e) { + PLOG_WARNING << "SCTP flush: " << e.what(); + return false; + } } -void SctpTransport::flush() { - std::lock_guard lock(mSendMutex); - trySendQueue(); +void SctpTransport::closeStream(unsigned int stream) { + send(make_message(0, Message::Reset, uint16_t(stream))); } void SctpTransport::incoming(message_ptr message) { @@ -427,6 +437,16 @@ void SctpTransport::doRecv() { } } +void SctpTransport::doFlush() { + std::lock_guard lock(mSendMutex); + --mPendingFlushCount; + try { + trySendQueue(); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + bool SctpTransport::trySendQueue() { // Requires mSendMutex to be locked while (auto next = mSendQueue.peek()) { @@ -572,17 +592,6 @@ void SctpTransport::sendReset(uint16_t streamId) { } } -bool SctpTransport::safeFlush() { - try { - flush(); - return true; - - } catch (const std::exception &e) { - PLOG_WARNING << "SCTP flush: " << e.what(); - return false; - } -} - void SctpTransport::handleUpcall() { if (!mSock) return; @@ -596,8 +605,10 @@ void SctpTransport::handleUpcall() { mProcessor.enqueue(&SctpTransport::doRecv, this); } - if (events & SCTP_EVENT_WRITE) - mProcessor.enqueue(&SctpTransport::safeFlush, this); + if (events & SCTP_EVENT_WRITE && mPendingFlushCount == 0) { + ++mPendingFlushCount; + mProcessor.enqueue(&SctpTransport::doFlush, this); + } } int SctpTransport::handleWrite(byte *data, size_t len, uint8_t /*tos*/, uint8_t /*set_df*/) { @@ -712,7 +723,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s PLOG_VERBOSE << "SCTP dry event"; // It should not be necessary since the send callback should have been called already, // but to be sure, let's try to send now. - safeFlush(); + flush(); break; } diff --git a/src/sctptransport.hpp b/src/sctptransport.hpp index a9a78a1..eb738d8 100644 --- a/src/sctptransport.hpp +++ b/src/sctptransport.hpp @@ -51,8 +51,8 @@ public: void start() override; bool stop() override; bool send(message_ptr message) override; // false if buffered + bool flush(); void closeStream(unsigned int stream); - void flush(); // Stats void clearStats(); @@ -80,12 +80,12 @@ private: bool outgoing(message_ptr message) override; void doRecv(); + void doFlush(); bool trySendQueue(); bool trySendMessage(message_ptr message); void updateBufferedAmount(uint16_t streamId, long delta); void triggerBufferedAmount(uint16_t streamId, size_t amount); void sendReset(uint16_t streamId); - bool safeFlush(); void handleUpcall(); int handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df); @@ -97,7 +97,8 @@ private: struct socket *mSock; Processor mProcessor; - std::atomic mPendingRecvCount; + std::atomic mPendingRecvCount = 0; + std::atomic mPendingFlushCount = 0; std::mutex mRecvMutex; std::recursive_mutex mSendMutex; // buffered amount callback is synchronous Queue mSendQueue; diff --git a/test/benchmark.cpp b/test/benchmark.cpp index fca21f5..0a8e960 100644 --- a/test/benchmark.cpp +++ b/test/benchmark.cpp @@ -127,8 +127,12 @@ size_t benchmark(milliseconds duration) { openTime = steady_clock::now(); cout << "DataChannel open, sending data..." << endl; - while (dc1->bufferedAmount() == 0) { - dc1->send(messageData); + try { + while (dc1->bufferedAmount() == 0) { + dc1->send(messageData); + } + } catch (const std::exception &e) { + std::cout << "Send failed: " << e.what() << std::endl; } // When sent data is buffered in the DataChannel, @@ -141,8 +145,12 @@ size_t benchmark(milliseconds duration) { return; // Continue sending - while (dc1->bufferedAmount() == 0) { - dc1->send(messageData); + try { + while (dc1->isOpen() && dc1->bufferedAmount() == 0) { + dc1->send(messageData); + } + } catch (const std::exception &e) { + std::cout << "Send failed: " << e.what() << std::endl; } });