mirror of
https://github.com/mii443/libdatachannel.git
synced 2025-08-22 23:25:33 +00:00
373 lines
11 KiB
C++
373 lines
11 KiB
C++
/**
|
|
* Copyright (c) 2020 Paul-Louis Ageneau
|
|
*
|
|
* This library is free software; you can redistribute it and/or
|
|
* modify it under the terms of the GNU Lesser General Public
|
|
* License as published by the Free Software Foundation; either
|
|
* version 2.1 of the License, or (at your option) any later version.
|
|
*
|
|
* This library is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
* Lesser General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU Lesser General Public
|
|
* License along with this library; if not, write to the Free Software
|
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
|
*/
|
|
|
|
#if RTC_ENABLE_WEBSOCKET
|
|
|
|
#include "wstransport.hpp"
|
|
#include "tcptransport.hpp"
|
|
#include "tlstransport.hpp"
|
|
|
|
#include "base64.hpp"
|
|
|
|
#include <chrono>
|
|
#include <list>
|
|
#include <map>
|
|
#include <random>
|
|
#include <regex>
|
|
|
|
#ifdef _WIN32
|
|
#include <winsock2.h>
|
|
#else
|
|
#include <arpa/inet.h>
|
|
#endif
|
|
|
|
// 64-bit host <-> network byte order conversion, for platforms that do not provide it.
// The low 32-bit word is byte-swapped and moved to the high half, and the high word is
// byte-swapped and moved to the low half. On big-endian hosts (where htonl(1) == 1) the
// value is already in network order and is returned unchanged.
// NOTE: x is evaluated more than once; pass side-effect-free expressions only.
#ifndef htonll
#define htonll(x)                                                                                  \
	(htonl(1) == 1                                                                                 \
	     ? static_cast<uint64_t>(x)                                                                \
	     : ((static_cast<uint64_t>(htonl(static_cast<uint32_t>(x))) << 32) |                      \
	        static_cast<uint64_t>(htonl(static_cast<uint32_t>(static_cast<uint64_t>(x) >> 32)))))
#endif
#ifndef ntohll
// Byte swapping is its own inverse, so network -> host is the same operation.
#define ntohll(x) htonll(x)
#endif
|
|
|
|
namespace rtc {
|
|
|
|
using namespace std::chrono;
using std::to_integer;
using std::to_string;

// Engine producing one random byte (CHAR_BIT bits) per call.
// The result type must be one of unsigned short/int/long/long long ([rand.req.genl]);
// unsigned char is not a valid UIntType (undefined behavior, rejected by MSVC's STL).
// Generated values always lie in [0, 2^CHAR_BIT - 1], so they fit losslessly in a byte.
using random_bytes_engine =
    std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned short>;
|
|
|
|
// Creates a WebSocket client transport on top of `lower`.
//
// `host` and `path` are used to build the HTTP upgrade request, which is sent right away
// (after registerIncoming()); the transport is then Connecting until the server's response
// has been parsed by incoming().
WsTransport::WsTransport(std::shared_ptr<Transport> lower, string host, string path,
                         message_callback recvCallback, state_callback stateCallback)
    : Transport(std::move(lower), std::move(stateCallback)), mHost(std::move(host)),
      mPath(std::move(path)) {
	// Parameters are taken by value: move them rather than copy (avoids an extra atomic
	// reference-count update for the shared_ptr and a possible allocation for the callback).
	onRecv(std::move(recvCallback));

	PLOG_DEBUG << "Initializing WebSocket transport";

	registerIncoming();
	sendHttpRequest();
}
|
|
|
|
// Destroying the transport stops it first, which sends a CLOSE frame if still connected.
WsTransport::~WsTransport() {
	stop();
}
|
|
|
|
bool WsTransport::stop() {
|
|
if (!Transport::stop())
|
|
return false;
|
|
|
|
close();
|
|
return true;
|
|
}
|
|
|
|
// Sends an immutable message.
// sendFrame() masks the payload in place, so a private copy is handed to the
// mutable_message_ptr overload. Returns false for a null message.
bool WsTransport::send(message_ptr message) {
	if (message) {
		auto copy = std::make_shared<Message>(*message);
		return send(std::move(copy));
	}

	return false;
}
|
|
|
|
// Sends a complete message as a single WebSocket frame.
// String messages are sent as TEXT frames, everything else as BINARY frames. The payload is
// masked in place by sendFrame(), so the message buffer is modified.
// Returns false for a null message or when the transport is not connected.
bool WsTransport::send(mutable_message_ptr message) {
	const bool sendable = message && state() == State::Connected;
	if (!sendable)
		return false;

	PLOG_VERBOSE << "Send size=" << message->size();

	const auto opcode = message->type == Message::String ? TEXT_FRAME : BINARY_FRAME;
	return sendFrame({opcode, message->data(), message->size(), true, true});
}
|
|
|
|
// Handles data (or end of stream, signaled by a null message) from the lower transport.
//
// While Connecting, bytes are accumulated until the complete HTTP upgrade response has been
// parsed. Once Connected, every complete frame in the buffer is decoded and dispatched.
// A null message or a parsing error tears the transport down: an established connection
// becomes Disconnected, anything else becomes Failed.
void WsTransport::incoming(message_ptr message) {
	auto s = state();
	if (s != State::Connecting && s != State::Connected)
		return; // Drop data received after the transport was closed or failed

	if (message) {
		try {
			mBuffer.insert(mBuffer.end(), message->begin(), message->end());

			if (state() == State::Connecting) {
				if (size_t len = readHttpResponse(mBuffer.data(), mBuffer.size())) {
					mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
					PLOG_INFO << "WebSocket open";
					changeState(State::Connected);
				}
			}

			if (state() == State::Connected) {
				Frame frame = {};
				// Stop as soon as a frame (e.g. CLOSE) takes the transport out of Connected
				while (state() == State::Connected) {
					size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame);
					if (len == 0)
						break; // Incomplete frame, wait for more data

					// frame.payload points into mBuffer: dispatch it before erasing the bytes
					recvFrame(frame);
					mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
				}
			}

			return; // Data processed normally, the connection stays up
		} catch (const std::exception &e) {
			PLOG_ERROR << e.what();
		}
	}

	if (state() == State::Connected) {
		PLOG_INFO << "WebSocket disconnected";
		changeState(State::Disconnected);
		recv(nullptr);
	} else {
		PLOG_ERROR << "WebSocket handshake failed";
		changeState(State::Failed);
	}
}
|
|
|
|
void WsTransport::close() {
|
|
if (state() == State::Connected) {
|
|
sendFrame({CLOSE, NULL, 0, true, true});
|
|
PLOG_INFO << "WebSocket closing";
|
|
changeState(State::Completed);
|
|
}
|
|
}
|
|
|
|
bool WsTransport::sendHttpRequest() {
|
|
changeState(State::Connecting);
|
|
|
|
auto seed = system_clock::now().time_since_epoch().count();
|
|
random_bytes_engine generator(seed);
|
|
|
|
binary key(16);
|
|
std::generate(reinterpret_cast<uint8_t *>(key.data()),
|
|
reinterpret_cast<uint8_t *>(key.data() + key.size()), generator);
|
|
|
|
const string request = "GET " + mPath +
|
|
" HTTP/1.1\r\n"
|
|
"Host: " +
|
|
mHost +
|
|
"\r\n"
|
|
"Connection: Upgrade\r\n"
|
|
"Upgrade: websocket\r\n"
|
|
"Sec-WebSocket-Version: 13\r\n"
|
|
"Sec-WebSocket-Key: " +
|
|
to_base64(key) +
|
|
"\r\n"
|
|
"\r\n";
|
|
|
|
auto data = reinterpret_cast<const byte *>(request.data());
|
|
auto size = request.size();
|
|
return outgoing(make_message(data, data + size));
|
|
}
|
|
|
|
size_t WsTransport::readHttpResponse(const byte *buffer, size_t size) {
|
|
std::list<string> lines;
|
|
auto begin = reinterpret_cast<const char *>(buffer);
|
|
auto end = begin + size;
|
|
auto cur = begin;
|
|
while (true) {
|
|
auto last = cur;
|
|
cur = std::find(cur, end, '\n');
|
|
if (cur == end)
|
|
return 0;
|
|
string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
|
|
if (line.empty())
|
|
break;
|
|
lines.emplace_back(std::move(line));
|
|
}
|
|
size_t length = cur - begin;
|
|
|
|
if (lines.empty())
|
|
throw std::runtime_error("Invalid HTTP response for WebSocket");
|
|
|
|
string status = std::move(lines.front());
|
|
lines.pop_front();
|
|
|
|
std::istringstream ss(status);
|
|
string protocol;
|
|
unsigned int code = 0;
|
|
ss >> protocol >> code;
|
|
PLOG_DEBUG << "WebSocket response code: " << code;
|
|
if (code != 101)
|
|
throw std::runtime_error("Unexpected response code for WebSocket: " + to_string(code));
|
|
|
|
std::multimap<string, string> headers;
|
|
for (const auto &line : lines) {
|
|
if (size_t pos = line.find_first_of(':'); pos != string::npos) {
|
|
string key = line.substr(0, pos);
|
|
string value = line.substr(line.find_first_not_of(' ', pos + 1));
|
|
std::transform(key.begin(), key.end(), key.begin(),
|
|
[](char c) { return std::tolower(c); });
|
|
headers.emplace(std::move(key), std::move(value));
|
|
} else {
|
|
headers.emplace(line, "");
|
|
}
|
|
}
|
|
|
|
auto h = headers.find("upgrade");
|
|
if (h == headers.end() || h->second != "websocket")
|
|
throw std::runtime_error("WebSocket update header missing or mismatching");
|
|
|
|
h = headers.find("sec-websocket-accept");
|
|
if (h == headers.end())
|
|
throw std::runtime_error("WebSocket accept header missing");
|
|
|
|
// TODO: Verify Sec-WebSocket-Accept
|
|
|
|
return length;
|
|
}
|
|
|
|
// http://tools.ietf.org/html/rfc6455#section-5.2 Base Framing Protocol
|
|
//
|
|
// 0 1 2 3
|
|
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
|
// +-+-+-+-+-------+-+-------------+-------------------------------+
|
|
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
|
|
// |I|S|S|S| (4) |A| (7) | (16/64) |
|
|
// |N|V|V|V| |S| | (if payload len==126/127) |
|
|
// | |1|2|3| |K| | |
|
|
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|
|
// | Extended payload length continued, if payload len == 127 |
|
|
// + - - - - - - - - - - - - - - - +-------------------------------+
|
|
// | | Masking-key, if MASK set to 1 |
|
|
// +-------------------------------+-------------------------------+
|
|
// | Masking-key (continued) | Payload Data |
|
|
// +-------------------------------+ - - - - - - - - - - - - - - - +
|
|
// : Payload Data continued ... :
|
|
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|
|
// | Payload Data continued ... |
|
|
// +---------------------------------------------------------------+
|
|
|
|
// Decodes one frame from the start of `buffer` into `frame`.
//
// Returns 0 if the buffer does not yet hold a complete frame, otherwise the number of bytes
// the frame occupies (header, masking key and payload), so that following frames remain in
// the buffer. frame.payload points into `buffer`; a masked payload is unmasked in place.
size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
	const byte *end = buffer + size;
	if (end - buffer < 2)
		return 0;

	byte *cur = buffer;
	auto b1 = to_integer<uint8_t>(*cur++);
	auto b2 = to_integer<uint8_t>(*cur++);

	frame.fin = (b1 & 0x80) != 0;
	frame.mask = (b2 & 0x80) != 0;
	frame.opcode = static_cast<Opcode>(b1 & 0x0F);

	// Reads a big-endian integer byte by byte: no alignment requirement on `cur` and no
	// dependency on the host byte order
	auto readBigEndian = [&cur](size_t count) {
		uint64_t value = 0;
		for (size_t i = 0; i < count; ++i)
			value = (value << 8) | to_integer<uint8_t>(*cur++);
		return value;
	};

	uint64_t length = b2 & 0x7F;
	if (length == 0x7E) {
		if (end - cur < 2)
			return 0;
		length = readBigEndian(2);
	} else if (length == 0x7F) {
		if (end - cur < 8)
			return 0;
		length = readBigEndian(8);
	}

	const byte *maskingKey = nullptr;
	if (frame.mask) {
		if (end - cur < 4)
			return 0;
		maskingKey = cur;
		cur += 4;
	}

	// Wait for the whole payload; the 64-bit comparison keeps a huge declared length from
	// being truncated before the check
	if (static_cast<uint64_t>(end - cur) < length)
		return 0;

	frame.length = static_cast<size_t>(length);
	frame.payload = cur;
	cur += frame.length;
	if (maskingKey)
		for (size_t i = 0; i < frame.length; ++i)
			frame.payload[i] ^= maskingKey[i % 4];

	// Only the bytes of this frame are consumed
	return static_cast<size_t>(cur - buffer);
}
|
|
|
|
void WsTransport::recvFrame(const Frame &frame) {
|
|
switch (frame.opcode) {
|
|
case TEXT_FRAME:
|
|
case BINARY_FRAME: {
|
|
if (!mPartial.empty()) {
|
|
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
|
|
recv(make_message(mPartial.begin(), mPartial.end(), type));
|
|
mPartial.clear();
|
|
}
|
|
if (frame.fin) {
|
|
auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary;
|
|
recv(make_message(frame.payload, frame.payload + frame.length));
|
|
} else {
|
|
mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
|
|
mPartialOpcode = frame.opcode;
|
|
}
|
|
break;
|
|
}
|
|
case CONTINUATION: {
|
|
mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
|
|
if (frame.fin) {
|
|
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
|
|
recv(make_message(mPartial.begin(), mPartial.end()));
|
|
mPartial.clear();
|
|
}
|
|
break;
|
|
}
|
|
case PING: {
|
|
sendFrame({PONG, frame.payload, frame.length, true, true});
|
|
break;
|
|
}
|
|
case PONG: {
|
|
// TODO
|
|
break;
|
|
}
|
|
case CLOSE: {
|
|
close();
|
|
PLOG_INFO << "WebSocket closed";
|
|
changeState(State::Disconnected);
|
|
break;
|
|
}
|
|
default: {
|
|
close();
|
|
throw std::invalid_argument("Unknown WebSocket opcode: " + to_string(frame.opcode));
|
|
}
|
|
}
|
|
}
|
|
|
|
bool WsTransport::sendFrame(const Frame &frame) {
|
|
byte buffer[14];
|
|
byte *cur = buffer;
|
|
|
|
*cur++ = byte((frame.opcode & 0x0F) | (frame.fin ? 0x80 : 0));
|
|
|
|
if (frame.length < 0x7E) {
|
|
*cur++ = byte((frame.length & 0x7F) | (frame.mask ? 0x80 : 0));
|
|
} else if (frame.length <= 0xFF) {
|
|
*cur++ = byte(0x7E | (frame.mask ? 0x80 : 0));
|
|
*reinterpret_cast<uint16_t *>(cur) = uint16_t(frame.length);
|
|
cur += 2;
|
|
} else {
|
|
*cur++ = byte(0x7F | (frame.mask ? 0x80 : 0));
|
|
*reinterpret_cast<uint64_t *>(cur) = uint64_t(frame.length);
|
|
cur += 8;
|
|
}
|
|
|
|
if (frame.mask) {
|
|
auto seed = system_clock::now().time_since_epoch().count();
|
|
random_bytes_engine generator(seed);
|
|
|
|
auto *maskingKey = cur;
|
|
std::generate(reinterpret_cast<uint8_t *>(maskingKey),
|
|
reinterpret_cast<uint8_t *>(maskingKey + 4), generator);
|
|
cur += 4;
|
|
|
|
for (size_t i = 0; i < frame.length; ++i)
|
|
frame.payload[i] ^= maskingKey[i % 4];
|
|
}
|
|
|
|
outgoing(make_message(buffer, cur)); // header
|
|
return outgoing(make_message(frame.payload, frame.payload + frame.length)); // payload
|
|
}
|
|
|
|
} // namespace rtc
|
|
|
|
#endif
|