diff --git a/src/impl/sctptransport.cpp b/src/impl/sctptransport.cpp index ae7feef..306b755 100644 --- a/src/impl/sctptransport.cpp +++ b/src/impl/sctptransport.cpp @@ -85,8 +85,8 @@ void SctpTransport::Init() { // Change congestion control from the default TCP Reno (RFC 2581) to H-TCP usrsctp_sysctl_set_sctp_default_cc_module(SCTP_CC_HTCP); - // Enable Non-Renegable Selective Acknowledgments (NR-SACKs) - usrsctp_sysctl_set_sctp_nrsack_enable(1); + // Enable Partial Reliability Extension (RFC 3758) + usrsctp_sysctl_set_sctp_pr_enable(1); // Increase the initial window size to 10 MTUs (RFC 6928) usrsctp_sysctl_set_sctp_initial_cwnd(10); @@ -104,7 +104,7 @@ SctpTransport::SctpTransport(shared_ptr lower, uint16_t port, optional mtu, message_callback recvCallback, amount_callback bufferedAmountCallback, state_callback stateChangeCallback) - : Transport(lower, std::move(stateChangeCallback)), mPort(port), mPendingRecvCount(0), + : Transport(lower, std::move(stateChangeCallback)), mPort(port), mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) { onRecv(recvCallback); @@ -260,7 +260,7 @@ bool SctpTransport::stop() { return false; mSendQueue.stop(); - safeFlush(); + flush(); shutdown(); onRecv(nullptr); return true; @@ -334,13 +334,20 @@ bool SctpTransport::send(message_ptr message) { return false; } -void SctpTransport::closeStream(unsigned int stream) { - send(make_message(0, Message::Reset, uint16_t(stream))); +bool SctpTransport::flush() { + try { + std::lock_guard lock(mSendMutex); + trySendQueue(); + return true; + + } catch (const std::exception &e) { + PLOG_WARNING << "SCTP flush: " << e.what(); + return false; + } } -void SctpTransport::flush() { - std::lock_guard lock(mSendMutex); - trySendQueue(); +void SctpTransport::closeStream(unsigned int stream) { + send(make_message(0, Message::Reset, uint16_t(stream))); } void SctpTransport::incoming(message_ptr message) { @@ -428,6 +435,16 @@ void SctpTransport::doRecv() { } } +void SctpTransport::doFlush() { + std::lock_guard lock(mSendMutex); + --mPendingFlushCount; + try { + trySendQueue(); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + bool SctpTransport::trySendQueue() { // Requires mSendMutex to be locked while (auto next = mSendQueue.peek()) { @@ -573,17 +590,6 @@ void SctpTransport::sendReset(uint16_t streamId) { } } -bool SctpTransport::safeFlush() { - try { - flush(); - return true; - - } catch (const std::exception &e) { - PLOG_WARNING << "SCTP flush: " << e.what(); - return false; - } -} - void SctpTransport::handleUpcall() { if (!mSock) return; @@ -597,8 +603,10 @@ void SctpTransport::handleUpcall() { mProcessor.enqueue(&SctpTransport::doRecv, this); } - if (events & SCTP_EVENT_WRITE) - mProcessor.enqueue(&SctpTransport::safeFlush, this); + if (events & SCTP_EVENT_WRITE && mPendingFlushCount == 0) { + ++mPendingFlushCount; + mProcessor.enqueue(&SctpTransport::doFlush, this); + } } int SctpTransport::handleWrite(byte *data, size_t len, uint8_t /*tos*/, uint8_t /*set_df*/) { @@ -713,7 +721,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s PLOG_VERBOSE << "SCTP dry event"; // It should not be necessary since the send callback should have been called already, // but to be sure, let's try to send now. - safeFlush(); + flush(); break; } diff --git a/src/impl/sctptransport.hpp b/src/impl/sctptransport.hpp index 249f835..dce3c17 100644 --- a/src/impl/sctptransport.hpp +++ b/src/impl/sctptransport.hpp @@ -50,8 +50,8 @@ public: void start() override; bool stop() override; bool send(message_ptr message) override; // false if buffered + bool flush(); void closeStream(unsigned int stream); - void flush(); // Stats void clearStats(); @@ -79,12 +79,12 @@ private: bool outgoing(message_ptr message) override; void doRecv(); + void doFlush(); bool trySendQueue(); bool trySendMessage(message_ptr message); void updateBufferedAmount(uint16_t streamId, long delta); void triggerBufferedAmount(uint16_t streamId, size_t amount); void sendReset(uint16_t streamId); - bool safeFlush(); void handleUpcall(); int handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df); @@ -96,7 +96,8 @@ private: struct socket *mSock; Processor mProcessor; - std::atomic mPendingRecvCount; + std::atomic mPendingRecvCount = 0; + std::atomic mPendingFlushCount = 0; std::mutex mRecvMutex; std::recursive_mutex mSendMutex; // buffered amount callback is synchronous Queue mSendQueue; diff --git a/src/impl/threadpool.cpp b/src/impl/threadpool.cpp index 690d985..5072427 100644 --- a/src/impl/threadpool.cpp +++ b/src/impl/threadpool.cpp @@ -51,7 +51,7 @@ void ThreadPool::spawn(int count) { void ThreadPool::join() { { std::unique_lock lock(mMutex); - mWaitingCondition.wait(lock, [&]() { return mWaitingWorkers == int(mWorkers.size()); }); + mWaitingCondition.wait(lock, [&]() { return mBusyWorkers == 0; }); mJoining = true; mTasksCondition.notify_all(); } @@ -66,6 +66,8 @@ void ThreadPool::join() { } void ThreadPool::run() { + ++mBusyWorkers; + scope_guard([&]() { --mBusyWorkers; }); while (runOne()) { } } @@ -81,24 +83,23 @@ bool ThreadPool::runOne() { std::function ThreadPool::dequeue() { std::unique_lock lock(mMutex); while (!mJoining) { + std::optional time; if (!mTasks.empty()) { - if (mTasks.top().time <= clock::now()) { + time = mTasks.top().time; + if (*time <= clock::now()) { auto func = std::move(mTasks.top().func); mTasks.pop(); return func; } - - ++mWaitingWorkers; - mWaitingCondition.notify_all(); - mTasksCondition.wait_until(lock, mTasks.top().time); - - } else { - ++mWaitingWorkers; - mWaitingCondition.notify_all(); - mTasksCondition.wait(lock); } - --mWaitingWorkers; + --mBusyWorkers; + scope_guard([&]() { ++mBusyWorkers; }); + mWaitingCondition.notify_all(); + if(time) + mTasksCondition.wait_until(lock, *time); + else + mTasksCondition.wait(lock); } return nullptr; } diff --git a/src/impl/threadpool.hpp b/src/impl/threadpool.hpp index d312b1a..f9e0f09 100644 --- a/src/impl/threadpool.hpp +++ b/src/impl/threadpool.hpp @@ -72,7 +72,7 @@ protected: std::function dequeue(); // returns null function if joining std::vector mWorkers; - int mWaitingWorkers = 0; + int mBusyWorkers = 0; std::atomic mJoining = false; struct Task { diff --git a/test/benchmark.cpp b/test/benchmark.cpp index 0787489..2969549 100644 --- a/test/benchmark.cpp +++ b/test/benchmark.cpp @@ -115,8 +115,12 @@ size_t benchmark(milliseconds duration) { openTime = steady_clock::now(); cout << "DataChannel open, sending data..." << endl; - while (dc1->bufferedAmount() == 0) { - dc1->send(messageData); + try { + while (dc1->bufferedAmount() == 0) { + dc1->send(messageData); + } + } catch (const std::exception &e) { + std::cout << "Send failed: " << e.what() << std::endl; } // When sent data is buffered in the DataChannel, @@ -129,8 +133,12 @@ size_t benchmark(milliseconds duration) { return; // Continue sending - while (dc1->bufferedAmount() == 0) { - dc1->send(messageData); + try { + while (dc1->isOpen() && dc1->bufferedAmount() == 0) { + dc1->send(messageData); + } + } catch (const std::exception &e) { + std::cout << "Send failed: " << e.what() << std::endl; } });