diff --git a/CMakeLists.txt b/CMakeLists.txt index 803b4dd..0358942 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,7 @@ set(LIBDATACHANNEL_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp ) set(LIBDATACHANNEL_HEADERS diff --git a/include/rtc/message.hpp b/include/rtc/message.hpp index 465b68a..e1396d1 100644 --- a/include/rtc/message.hpp +++ b/include/rtc/message.hpp @@ -30,6 +30,7 @@ namespace rtc { struct Message : binary { enum Type { Binary, String, Control, Reset }; + Message(const Message &message) = default; Message(size_t size, Type type_ = Binary) : binary(size), type(type_) {} template diff --git a/src/base64.cpp b/src/base64.cpp new file mode 100644 index 0000000..1779346 --- /dev/null +++ b/src/base64.cpp @@ -0,0 +1,65 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#if ENABLE_WEBSOCKET + +#include "base64.hpp" + +namespace rtc { + +using std::to_integer; + +string to_base64(const binary &data) { + static const char tab[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + string out; + out.reserve(3 * ((data.size() + 3) / 4)); + int i = 0; + while (data.size() - i >= 3) { + auto d0 = to_integer(data[i]); + auto d1 = to_integer(data[i + 1]); + auto d2 = to_integer(data[i + 2]); + out += tab[d0 >> 2]; + out += tab[((d0 & 3) << 4) | (d1 >> 4)]; + out += tab[((d1 & 0x0F) << 2) | (d2 >> 6)]; + out += tab[d2 & 0x3F]; + i += 3; + } + + int left = data.size() - i; + if (left) { + auto d0 = to_integer(data[i]); + out += tab[d0 >> 2]; + if (left == 1) { + out += tab[(d0 & 3) << 4]; + out += '='; + } else { // left == 2 + auto d1 = to_integer(data[i + 1]); + out += tab[((d0 & 3) << 4) | (d1 >> 4)]; + out += tab[(d1 & 0x0F) << 2]; + } + out += '='; + } + + return out; +} + +} // namespace rtc + +#endif diff --git a/src/base64.hpp b/src/base64.hpp new file mode 100644 index 0000000..41a06ed --- /dev/null +++ b/src/base64.hpp @@ -0,0 +1,34 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_BASE64_H +#define RTC_BASE64_H + +#if ENABLE_WEBSOCKET + +#include "include.hpp" + +namespace rtc { + +string to_base64(const binary &data); + +} + +#endif + +#endif diff --git a/src/tcptransport.cpp b/src/tcptransport.cpp index 5ebe6d3..f33df3d 100644 --- a/src/tcptransport.cpp +++ b/src/tcptransport.cpp @@ -22,6 +22,8 @@ namespace rtc { +using std::to_string; + TcpTransport::TcpTransport(const string &hostname, const string &service) : mHostname(hostname), mService(service) { mThread = std::thread(&TcpTransport::runLoop, this); @@ -146,7 +148,7 @@ bool TcpTransport::trySendMessage(message_ptr &message) { message = make_message(message->data() + len, message->data() + size); return false; } else { - throw std::runtime_error("Connection lost, errno=" + std::to_string(sockerrno)); + throw std::runtime_error("Connection lost, errno=" + to_string(sockerrno)); } } message = nullptr; @@ -172,6 +174,7 @@ void TcpTransport::runLoop() { FD_ZERO(&readfds); FD_ZERO(&writefds); FD_SET(mSock, &readfds); + // TODO if (!mSendQueue.empty()) FD_SET(mSock, &writefds); int ret = ::select(SOCKET_TO_INT(mSock) + 1, &readfds, &writefds, NULL, NULL); @@ -182,7 +185,7 @@ void TcpTransport::runLoop() { char buffer[bufferSize]; int len = ::recv(mSock, buffer, bufferSize, 0); if (len < 0) - throw std::runtime_error("Connection lost, errno=" + std::to_string(sockerrno)); + throw std::runtime_error("Connection lost, errno=" + to_string(sockerrno)); if (len == 0) break; // clean close diff --git a/src/websocket.cpp b/src/websocket.cpp new file mode 100644 index 0000000..fc3106b --- /dev/null +++ b/src/websocket.cpp @@ -0,0 +1,100 @@ +/************************************************************************* + * Copyright (C) 2017-2018 by Paul-Louis Ageneau * + * paul-louis (at) ageneau (dot) org * + * * + * This file is part of Plateform. * + * * + * Plateform is free software: you can redistribute it and/or modify * + * it under the terms of the GNU Affero General Public License as * + * published by the Free Software Foundation, either version 3 of * + * the License, or (at your option) any later version. * + * * + * Plateform is distributed in the hope that it will be useful, but * + * WITHOUT ANY WARRANTY; without even the implied warranty of * + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * + * GNU Affero General Public License for more details. * + * * + * You should have received a copy of the GNU Affero General Public * + * License along with Plateform. * + * If not, see . * + *************************************************************************/ + +#include "net/websocket.hpp" + +#include +#include + +const size_t DEFAULT_MAX_PAYLOAD_SIZE = 16384; // 16 KB + +namespace net { + +WebSocket::WebSocket(void) : mMaxPayloadSize(DEFAULT_MAX_PAYLOAD_SIZE) {} + +WebSocket::WebSocket(const string &url) : WebSocket() { open(url); } + +WebSocket::~WebSocket(void) {} + +void WebSocket::open(const string &url) { + close(); + + mUrl = url; + mThread = std::thread(&WebSocket::run, this); +} + +void WebSocket::close(void) { + mWebSocket.close(); + if (mThread.joinable()) + mThread.join(); + mConnected = false; +} + +bool WebSocket::isOpen(void) const { return mConnected; } + +bool WebSocket::isClosed(void) const { return !mThread.joinable(); } + +void WebSocket::setMaxPayloadSize(size_t size) { mMaxPayloadSize = size; } + +bool WebSocket::send(const std::variant &data) { + if (!std::holds_alternative(data)) + throw std::runtime_error("WebSocket string messages are not supported"); + + mWebSocket.write(std::get(data)); + return true; +} + +std::optional> WebSocket::receive() { + if (!mQueue.empty()) + return mQueue.pop(); + else + return std::nullopt; +} + +void WebSocket::run(void) { + if (mUrl.empty()) + return; + + try { + mWebSocket.connect(mUrl); + + mConnected = true; + triggerOpen(); + + while (true) { + binary payload; + if (!mWebSocket.read(payload, mMaxPayloadSize)) + break; + mQueue.push(std::move(payload)); + triggerAvailable(mQueue.size()); + } + } catch (const std::exception &e) { + triggerError(e.what()); + } + + mWebSocket.close(); + + if (mConnected) + triggerClosed(); + mConnected = false; +} + +} // namespace net diff --git a/src/wstransport.cpp b/src/wstransport.cpp new file mode 100644 index 0000000..e57013f --- /dev/null +++ b/src/wstransport.cpp @@ -0,0 +1,332 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#if ENABLE_WEBSOCKET + +#include "wstransport.hpp" +#include "tcptransport.hpp" +#include "tlstransport.hpp" + +#include "base64.hpp" + +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +#ifndef htonll +#define htonll(x) \ + ((uint64_t)htonl(((uint64_t)(x)&0xFFFFFFFF) << 32) | (uint64_t)htonl((uint64_t)(x) >> 32)) +#endif +#ifndef ntohll +#define ntohll(x) htonll(x) +#endif + +namespace rtc { + +using namespace std::chrono; +using std::to_integer; +using std::to_string; + +using random_bytes_engine = + std::independent_bits_engine; + +WsTransport::WsTransport(std::shared_ptr lower, string host, string path) + : Transport(lower), mHost(std::move(host)), mPath(std::move(path)) {} + +WsTransport::WsTransport(std::shared_ptr lower, string host, string path) + : Transport(lower), mHost(std::move(host)), mPath(std::move(path)) {} + +WsTransport::~WsTransport() {} + +void WsTransport::stop() {} + +bool WsTransport::send(message_ptr message) { + if (!message) + return false; + + // Call the mutable message overload with a copy + return send(std::make_shared(*message)); +} + +bool WsTransport::send(mutable_message_ptr message) { + if (!message) + return false; + + return sendFrame({message->type == Message::String ? TEXT_FRAME : BINARY_FRAME, message->data(), + message->size(), true, true}); +} + +void WsTransport::incoming(message_ptr message) { + mBuffer.insert(mBuffer.end(), message->begin(), message->end()); + + if (!mHandshakeDone) { + if (size_t len = readHttpResponse(mBuffer.data(), mBuffer.size())) { + mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len); + mHandshakeDone = true; + } + } + + if (mHandshakeDone) { + Frame frame = {}; + while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) { + mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len); + recvFrame(frame); + } + } +} + +void WsTransport::connect() { sendHttpRequest(); } + +void WsTransport::close() { + if (mHandshakeDone) + sendFrame({CLOSE, NULL, 0, true, true}); +} + +bool WsTransport::sendHttpRequest() { + auto seed = system_clock::now().time_since_epoch().count(); + random_bytes_engine generator(seed); + + binary key(16); + std::generate(reinterpret_cast(key.data()), + reinterpret_cast(key.data() + key.size()), generator); + + const string request = "GET " + mPath + + " HTTP/1.1\r\n" + "Host: " + + mHost + + "\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Sec-WebSocket-Key: " + + to_base64(key) + + "\r\n" + "\r\n"; + + auto data = reinterpret_cast(request.data()); + auto size = request.size(); + return outgoing(make_message(data, data + size)); +} + +size_t WsTransport::readHttpResponse(const byte *buffer, size_t size) { + + std::list lines; + auto begin = reinterpret_cast(buffer); + auto end = begin + size; + auto cur = begin; + while ((cur = std::find(cur, end, '\n')) != end) { + string line(begin, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++); + if (line.empty()) + break; + lines.emplace_back(std::move(line)); + } + size_t length = cur - begin; + + string status = std::move(lines.front()); + lines.pop_front(); + + std::istringstream ss(status); + string protocol; + unsigned int code = 0; + ss >> protocol >> code; + if (code != 101) + throw std::runtime_error("Unexpected response code for WebSocket: " + to_string(code)); + + std::multimap headers; + for (const auto &line : lines) { + if (size_t pos = line.find_first_of(':'); pos != string::npos) { + string key = line.substr(0, pos); + string value = line.substr(line.find_first_not_of(' ', pos + 1)); + std::transform(key.begin(), key.end(), key.begin(), + [](char c) { return std::tolower(c); }); + headers.emplace(std::move(key), std::move(value)); + } else { + headers.emplace(line, ""); + } + } + + auto h = headers.find("upgrade"); + if (h == headers.end() || h->second != "websocket") + throw std::runtime_error("WebSocket update header missing or mismatching"); + + h = headers.find("sec-websocket-accept"); + throw std::runtime_error("WebSocket accept header missing"); + + // TODO: Verify Sec-WebSocket-Accept + + return length; +} + +// http://tools.ietf.org/html/rfc6455#section-5.2 Base Framing Protocol +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-------+-+-------------+-------------------------------+ +// |F|R|R|R| opcode|M| Payload len | Extended payload length | +// |I|S|S|S| (4) |A| (7) | (16/64) | +// |N|V|V|V| |S| | (if payload len==126/127) | +// | |1|2|3| |K| | | +// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + +// | Extended payload length continued, if payload len == 127 | +// + - - - - - - - - - - - - - - - +-------------------------------+ +// | | Masking-key, if MASK set to 1 | +// +-------------------------------+-------------------------------+ +// | Masking-key (continued) | Payload Data | +// +-------------------------------+ - - - - - - - - - - - - - - - + +// : Payload Data continued ... : +// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +// | Payload Data continued ... | +// +---------------------------------------------------------------+ + +size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) { + const byte *end = buffer + size; + if (end - buffer < 2) + return 0; + + byte *cur = buffer; + auto b1 = to_integer(*cur++); + auto b2 = to_integer(*cur++); + + frame.fin = (b1 & 0x80) != 0; + frame.mask = (b2 & 0x80) != 0; + frame.opcode = static_cast(b1 & 0x0F); + frame.length = b2 & 0x7F; + + if (frame.length == 0x7E) { + if (end - cur < 2) + return 0; + frame.length = ntohs(*reinterpret_cast(cur)); + cur += 2; + } else if (frame.length == 0x7F) { + if (end - cur < 8) + return false; + frame.length = ntohll(*reinterpret_cast(cur)); + cur += 8; + } + + const byte *maskingKey = nullptr; + if (frame.mask) { + if (end - cur < 4) + return 0; + maskingKey = cur; + cur += 4; + } + + if (end - cur < frame.length) + return false; + + frame.payload = cur; + if (maskingKey) + for (size_t i = 0; i < frame.length; ++i) + frame.payload[i] ^= maskingKey[i % 4]; + + return end - buffer; +} + +void WsTransport::recvFrame(const Frame &frame) { + switch (frame.opcode) { + case TEXT_FRAME: + case BINARY_FRAME: { + if (!mPartial.empty()) { + auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary; + recv(make_message(mPartial.begin(), mPartial.end(), type)); + mPartial.clear(); + } + if (frame.fin) { + auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary; + recv(make_message(frame.payload, frame.payload + frame.length)); + } else { + mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length); + mPartialOpcode = frame.opcode; + } + break; + } + case CONTINUATION: { + mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length); + if (frame.fin) { + auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary; + recv(make_message(mPartial.begin(), mPartial.end())); + mPartial.clear(); + } + break; + } + case PING: { + sendFrame({PONG, frame.payload, frame.length, true, true}); + break; + } + case PONG: { + // TODO + break; + } + case CLOSE: { + close(); + break; + } + default: { + close(); + throw std::invalid_argument("Unknown WebSocket opcode: " + to_string(frame.opcode)); + } + } +} + +bool WsTransport::sendFrame(const Frame &frame) { + byte buffer[14]; + byte *cur = buffer; + + *cur++ = byte((frame.opcode & 0x0F) | (frame.fin ? 0x80 : 0)); + + if (frame.length < 0x7E) { + *cur++ = byte((frame.length & 0x7F) | (frame.mask ? 0x80 : 0)); + } else if (frame.length <= 0xFF) { + *cur++ = byte(0x7E | (frame.mask ? 0x80 : 0)); + *reinterpret_cast(cur) = uint16_t(frame.length); + cur += 2; + } else { + *cur++ = byte(0x7F | (frame.mask ? 0x80 : 0)); + *reinterpret_cast(cur) = uint64_t(frame.length); + cur += 8; + } + + if (frame.mask) { + auto seed = system_clock::now().time_since_epoch().count(); + random_bytes_engine generator(seed); + + auto *maskingKey = cur; + std::generate(reinterpret_cast(maskingKey), + reinterpret_cast(maskingKey + 4), generator); + cur += 4; + + for (size_t i = 0; i < frame.length; ++i) + frame.payload[i] ^= maskingKey[i % 4]; + } + + outgoing(make_message(buffer, cur)); // header + return outgoing(make_message(frame.payload, frame.payload + frame.length)); // payload +} + +} // namespace rtc + +#endif diff --git a/src/wstransport.hpp b/src/wstransport.hpp new file mode 100644 index 0000000..fd5ff01 --- /dev/null +++ b/src/wstransport.hpp @@ -0,0 +1,85 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_WS_TRANSPORT_H +#define RTC_WS_TRANSPORT_H + +#if ENABLE_WEBSOCKET + +#include "include.hpp" +#include "transport.hpp" + +namespace rtc { + +class TcpTransport; +class TlsTransport; + +class WsTransport : public Transport { +public: + WsTransport(std::shared_ptr lower, string host, string path); + WsTransport(std::shared_ptr lower, string host, string path); + ~WsTransport(); + + void stop() override; + bool send(message_ptr message) override; + bool send(mutable_message_ptr message); + + void incoming(message_ptr message) override; + +private: + enum Opcode : uint8_t { + CONTINUATION = 0, + TEXT_FRAME = 1, + BINARY_FRAME = 2, + CLOSE = 8, + PING = 9, + PONG = 10, + }; + + struct Frame { + Opcode opcode = BINARY_FRAME; + byte *payload = nullptr; + size_t length = 0; + bool fin = true; + bool mask = true; + }; + + void connect(); + void close(); + + bool sendHttpRequest(); + size_t readHttpResponse(const byte *buffer, size_t size); + + size_t readFrame(byte *buffer, size_t size, Frame &frame); + void recvFrame(const Frame &frame); + bool sendFrame(const Frame &frame); + + const string mHost; + const string mPath; + + bool mHandshakeDone = false; + binary mBuffer; + binary mPartial; + Opcode mPartialOpcode; + }; + +} // namespace rtc + +#endif + +#endif