mirror of
https://github.com/mii443/libdatachannel.git
synced 2025-12-04 19:48:25 +00:00
Changed buffer amount low behavior to prevent deadlock situations
This commit is contained in:
@@ -21,8 +21,8 @@
|
||||
|
||||
#include "include.hpp"
|
||||
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <variant>
|
||||
|
||||
namespace rtc {
|
||||
@@ -30,7 +30,7 @@ namespace rtc {
|
||||
class Channel {
|
||||
public:
|
||||
virtual void close() = 0;
|
||||
virtual void send(const std::variant<binary, string> &data) = 0;
|
||||
virtual bool send(const std::variant<binary, string> &data) = 0;
|
||||
virtual std::optional<std::variant<binary, string>> receive() = 0;
|
||||
virtual bool isOpen() const = 0;
|
||||
virtual bool isClosed() const = 0;
|
||||
@@ -66,8 +66,8 @@ private:
|
||||
synchronized_callback<> mAvailableCallback;
|
||||
synchronized_callback<> mBufferedAmountLowCallback;
|
||||
|
||||
size_t mBufferedAmount = 0;
|
||||
size_t mBufferedAmountLowThreshold = 0;
|
||||
std::atomic<size_t> mBufferedAmount = 0;
|
||||
std::atomic<size_t> mBufferedAmountLowThreshold = 0;
|
||||
};
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
@@ -45,13 +45,13 @@ public:
|
||||
~DataChannel();
|
||||
|
||||
void close(void);
|
||||
void send(const std::variant<binary, string> &data);
|
||||
void send(const byte *data, size_t size);
|
||||
bool send(const std::variant<binary, string> &data);
|
||||
bool send(const byte *data, size_t size);
|
||||
std::optional<std::variant<binary, string>> receive();
|
||||
|
||||
// Directly send a buffer to avoid a copy
|
||||
template <typename Buffer> void sendBuffer(const Buffer &buf);
|
||||
template <typename Iterator> void sendBuffer(Iterator first, Iterator last);
|
||||
template <typename Buffer> bool sendBuffer(const Buffer &buf);
|
||||
template <typename Iterator> bool sendBuffer(Iterator first, Iterator last);
|
||||
|
||||
bool isOpen(void) const;
|
||||
bool isClosed(void) const;
|
||||
@@ -65,7 +65,7 @@ public:
|
||||
|
||||
private:
|
||||
void open(std::shared_ptr<SctpTransport> sctpTransport);
|
||||
void outgoing(mutable_message_ptr message);
|
||||
bool outgoing(mutable_message_ptr message);
|
||||
void incoming(message_ptr message);
|
||||
void processOpenMessage(message_ptr message);
|
||||
|
||||
@@ -93,14 +93,14 @@ template <typename Buffer> std::pair<const byte *, size_t> to_bytes(const Buffer
|
||||
buf.size() * sizeof(E));
|
||||
}
|
||||
|
||||
template <typename Buffer> void DataChannel::sendBuffer(const Buffer &buf) {
|
||||
template <typename Buffer> bool DataChannel::sendBuffer(const Buffer &buf) {
|
||||
auto [bytes, size] = to_bytes(buf);
|
||||
auto message = std::make_shared<Message>(size);
|
||||
std::copy(bytes, bytes + size, message->data());
|
||||
outgoing(message);
|
||||
return outgoing(message);
|
||||
}
|
||||
|
||||
template <typename Iterator> void DataChannel::sendBuffer(Iterator first, Iterator last) {
|
||||
template <typename Iterator> bool DataChannel::sendBuffer(Iterator first, Iterator last) {
|
||||
size_t size = 0;
|
||||
for (Iterator it = first; it != last; ++it)
|
||||
size += it->size();
|
||||
@@ -111,7 +111,7 @@ template <typename Iterator> void DataChannel::sendBuffer(Iterator first, Iterat
|
||||
auto [bytes, size] = to_bytes(*it);
|
||||
pos = std::copy(bytes, bytes + size, pos);
|
||||
}
|
||||
outgoing(message);
|
||||
return outgoing(message);
|
||||
}
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
@@ -80,10 +80,9 @@ void Channel::triggerAvailable(size_t count) {
|
||||
}
|
||||
|
||||
void Channel::triggerBufferedAmount(size_t amount) {
|
||||
bool lowThresholdCrossed =
|
||||
mBufferedAmount > mBufferedAmountLowThreshold && amount <= mBufferedAmountLowThreshold;
|
||||
mBufferedAmount = amount;
|
||||
if (lowThresholdCrossed)
|
||||
size_t previous = mBufferedAmount.exchange(amount);
|
||||
size_t threshold = mBufferedAmountLowThreshold.load();
|
||||
if (previous > threshold && amount <= threshold)
|
||||
mBufferedAmountLowCallback();
|
||||
}
|
||||
|
||||
|
||||
@@ -83,19 +83,19 @@ void DataChannel::close() {
|
||||
}
|
||||
}
|
||||
|
||||
void DataChannel::send(const std::variant<binary, string> &data) {
|
||||
std::visit(
|
||||
bool DataChannel::send(const std::variant<binary, string> &data) {
|
||||
return std::visit(
|
||||
[&](const auto &d) {
|
||||
using T = std::decay_t<decltype(d)>;
|
||||
constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
|
||||
auto *b = reinterpret_cast<const byte *>(d.data());
|
||||
outgoing(std::make_shared<Message>(b, b + d.size(), type));
|
||||
return outgoing(std::make_shared<Message>(b, b + d.size(), type));
|
||||
},
|
||||
data);
|
||||
}
|
||||
|
||||
void DataChannel::send(const byte *data, size_t size) {
|
||||
outgoing(std::make_shared<Message>(data, data + size, Message::Binary));
|
||||
bool DataChannel::send(const byte *data, size_t size) {
|
||||
return outgoing(std::make_shared<Message>(data, data + size, Message::Binary));
|
||||
}
|
||||
|
||||
std::optional<std::variant<binary, string>> DataChannel::receive() {
|
||||
@@ -177,7 +177,7 @@ void DataChannel::open(shared_ptr<SctpTransport> sctpTransport) {
|
||||
mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
|
||||
}
|
||||
|
||||
void DataChannel::outgoing(mutable_message_ptr message) {
|
||||
bool DataChannel::outgoing(mutable_message_ptr message) {
|
||||
if (mIsClosed || !mSctpTransport)
|
||||
throw std::runtime_error("DataChannel is closed");
|
||||
|
||||
@@ -187,7 +187,7 @@ void DataChannel::outgoing(mutable_message_ptr message) {
|
||||
// Before the ACK has been received on a DataChannel, all messages must be sent ordered
|
||||
message->reliability = mIsOpen ? mReliability : nullptr;
|
||||
message->stream = mStream;
|
||||
mSctpTransport->send(message);
|
||||
return mSctpTransport->send(message);
|
||||
}
|
||||
|
||||
void DataChannel::incoming(message_ptr message) {
|
||||
|
||||
@@ -188,13 +188,18 @@ void SctpTransport::connect() {
|
||||
SctpTransport::State SctpTransport::state() const { return mState; }
|
||||
|
||||
bool SctpTransport::send(message_ptr message) {
|
||||
if (!message)
|
||||
return false;
|
||||
std::lock_guard<std::mutex> lock(mSendMutex);
|
||||
|
||||
updateBufferedAmount(message->stream, message->size());
|
||||
mSendQueue.push(message);
|
||||
trySendAll();
|
||||
if (!message)
|
||||
return mSendQueue.empty();
|
||||
|
||||
// If nothing is pending, try to send directly
|
||||
if (mSendQueue.empty() && trySendMessage(message))
|
||||
return true;
|
||||
|
||||
mSendQueue.push(message);
|
||||
updateBufferedAmount(message->stream, message_size_func(message));
|
||||
return false;
|
||||
}
|
||||
|
||||
void SctpTransport::reset(unsigned int stream) {
|
||||
@@ -231,25 +236,21 @@ void SctpTransport::changeState(State state) {
|
||||
mStateChangeCallback(state);
|
||||
}
|
||||
|
||||
bool SctpTransport::trySendAll() {
|
||||
std::unique_lock<std::mutex> lock(mSendMutex, std::try_to_lock);
|
||||
if (!lock.owns_lock())
|
||||
return false;
|
||||
|
||||
bool SctpTransport::trySendQueue() {
|
||||
// Requires mSendMutex to be locked
|
||||
while (auto next = mSendQueue.peek()) {
|
||||
auto message = *next;
|
||||
if (!trySend(message))
|
||||
if (!trySendMessage(message))
|
||||
return false;
|
||||
updateBufferedAmount(message->stream, -message->size());
|
||||
mSendQueue.pop();
|
||||
updateBufferedAmount(message->stream, -message_size_func(message));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SctpTransport::trySend(message_ptr message) {
|
||||
if (!message)
|
||||
return false;
|
||||
|
||||
bool SctpTransport::trySendMessage(message_ptr message) {
|
||||
// Requires mSendMutex to be locked
|
||||
//
|
||||
// TODO: Implement SCTP ndata specification draft when supported everywhere
|
||||
// See https://tools.ietf.org/html/draft-ietf-tsvwg-sctp-ndata-08
|
||||
|
||||
@@ -316,19 +317,13 @@ bool SctpTransport::trySend(message_ptr message) {
|
||||
}
|
||||
|
||||
void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
|
||||
if (delta == 0)
|
||||
return;
|
||||
std::lock_guard<std::mutex> lock(mBufferedAmountMutex);
|
||||
// Requires mSendMutex to be locked
|
||||
auto it = mBufferedAmount.insert(std::make_pair(streamId, 0)).first;
|
||||
if (delta > 0)
|
||||
it->second += size_t(delta);
|
||||
else if (it->second > size_t(-delta))
|
||||
it->second -= size_t(-delta);
|
||||
else
|
||||
it->second = 0;
|
||||
mBufferedAmountCallback(streamId, it->second);
|
||||
if (it->second == 0)
|
||||
size_t amount = it->second;
|
||||
amount = size_t(std::max(long(amount) + delta, long(0)));
|
||||
if (amount == 0)
|
||||
mBufferedAmount.erase(it);
|
||||
mBufferedAmountCallback(streamId, amount);
|
||||
}
|
||||
|
||||
int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data,
|
||||
@@ -364,7 +359,8 @@ int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, co
|
||||
|
||||
int SctpTransport::handleSend(size_t free) {
|
||||
try {
|
||||
trySendAll();
|
||||
std::lock_guard<std::mutex> lock(mSendMutex);
|
||||
trySendQueue();
|
||||
} catch (const std::exception &e) {
|
||||
std::cerr << "SCTP send: " << e.what() << std::endl;
|
||||
return -1;
|
||||
@@ -470,7 +466,8 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
|
||||
case SCTP_SENDER_DRY_EVENT: {
|
||||
// It not should be necessary since the send callback should have been called already,
|
||||
// but to be sure, let's try to send now.
|
||||
trySendAll();
|
||||
std::lock_guard<std::mutex> lock(mSendMutex);
|
||||
trySendQueue();
|
||||
}
|
||||
case SCTP_STREAM_RESET_EVENT: {
|
||||
const struct sctp_stream_reset_event &reset_event = notify->sn_strreset_event;
|
||||
|
||||
@@ -69,8 +69,9 @@ private:
|
||||
void connect();
|
||||
void incoming(message_ptr message);
|
||||
void changeState(State state);
|
||||
bool trySendAll();
|
||||
bool trySend(message_ptr message);
|
||||
|
||||
bool trySendQueue();
|
||||
bool trySendMessage(message_ptr message);
|
||||
void updateBufferedAmount(uint16_t streamId, long delta);
|
||||
|
||||
int handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data, size_t len,
|
||||
@@ -86,8 +87,6 @@ private:
|
||||
|
||||
std::mutex mSendMutex;
|
||||
Queue<message_ptr> mSendQueue;
|
||||
|
||||
std::mutex mBufferedAmountMutex;
|
||||
std::map<uint16_t, size_t> mBufferedAmount;
|
||||
amount_callback mBufferedAmountCallback;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user