diff --git a/include/rtc/websocket.hpp b/include/rtc/websocket.hpp index 7dd58a7..0e24deb 100644 --- a/include/rtc/websocket.hpp +++ b/include/rtc/websocket.hpp @@ -49,6 +49,7 @@ public: struct Configuration { bool disableTlsVerification = false; // if true, don't verify the TLS certificate + std::optional> protocols = std::nullopt; }; WebSocket(std::optional config = nullopt); diff --git a/src/websocket.cpp b/src/websocket.cpp index 75ae9e9..a12dd91 100644 --- a/src/websocket.cpp +++ b/src/websocket.cpp @@ -291,7 +291,14 @@ shared_ptr WebSocket::initWsTransport() { shared_ptr lower = std::atomic_load(&mTlsTransport); if (!lower) lower = std::atomic_load(&mTcpTransport); + + auto wsConfig = WsTransport::Configuration(); + if(mConfig.protocols) { + wsConfig.protocols = *mConfig.protocols; + } + auto transport = std::make_shared( + wsConfig, lower, mHost, mPath, weak_bind(&WebSocket::incoming, this, _1), [this, weak_this = weak_from_this()](State state) { auto shared_this = weak_this.lock(); diff --git a/src/wstransport.cpp b/src/wstransport.cpp index 8c6d9ff..5228f41 100644 --- a/src/wstransport.cpp +++ b/src/wstransport.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #ifdef _WIN32 #include @@ -53,9 +54,9 @@ using std::to_string; using random_bytes_engine = std::independent_bits_engine; -WsTransport::WsTransport(std::shared_ptr lower, string host, string path, +WsTransport::WsTransport(std::optional config, std::shared_ptr lower, string host, string path, message_callback recvCallback, state_callback stateCallback) - : Transport(lower, std::move(stateCallback)), mHost(std::move(host)), mPath(std::move(path)) { + : Transport(lower, std::move(stateCallback)), mHost(std::move(host)), mPath(std::move(path)), mConfig(config ? std::move(*config) : Configuration()) { onRecv(recvCallback); PLOG_DEBUG << "Initializing WebSocket transport"; @@ -164,6 +165,15 @@ bool WsTransport::sendHttpRequest() { auto k = reinterpret_cast(key.data()); std::generate(k, k + key.size(), [&]() { return uint8_t(generator()); }); + string appendHeader = ""; + if(mConfig.protocols.size() > 0) { + appendHeader += "Sec-WebSocket-Protocol: " + + std::accumulate(mConfig.protocols.begin(), mConfig.protocols.end(), string(), [](const string& a, const string& b) -> string { + return a + (a.length() > 0 ? "," : "") + b; + }) + + "\r\n"; + } + const string request = "GET " + mPath + " HTTP/1.1\r\n" "Host: " + @@ -174,8 +184,9 @@ bool WsTransport::sendHttpRequest() { "Sec-WebSocket-Version: 13\r\n" "Sec-WebSocket-Key: " + to_base64(key) + - "\r\n" - "\r\n"; + "\r\n" + + std::move(appendHeader) + + "\r\n"; auto data = reinterpret_cast(request.data()); auto size = request.size(); diff --git a/src/wstransport.hpp b/src/wstransport.hpp index 69ebed1..226df07 100644 --- a/src/wstransport.hpp +++ b/src/wstransport.hpp @@ -31,7 +31,11 @@ class TlsTransport; class WsTransport : public Transport { public: - WsTransport(std::shared_ptr lower, string host, string path, + struct Configuration { + std::vector protocols; + }; + + WsTransport(std::optional config, std::shared_ptr lower, string host, string path, message_callback recvCallback, state_callback stateCallback); ~WsTransport(); @@ -74,6 +78,8 @@ private: binary mBuffer; binary mPartial; Opcode mPartialOpcode; + + const Configuration mConfig; }; } // namespace rtc