mirror of
https://github.com/mii443/libdatachannel.git
synced 2025-08-22 23:25:33 +00:00
Merge pull request #64 from paullouisageneau/websocket
Add optional WebSocket with the same API
This commit is contained in:
@ -6,6 +6,7 @@ project (libdatachannel
|
|||||||
|
|
||||||
option(USE_GNUTLS "Use GnuTLS instead of OpenSSL" OFF)
|
option(USE_GNUTLS "Use GnuTLS instead of OpenSSL" OFF)
|
||||||
option(USE_JUICE "Use libjuice instead of libnice" OFF)
|
option(USE_JUICE "Use libjuice instead of libnice" OFF)
|
||||||
|
option(RTC_ENABLE_WEBSOCKET "Build WebSocket support" ON)
|
||||||
|
|
||||||
if(USE_GNUTLS)
|
if(USE_GNUTLS)
|
||||||
option(USE_NETTLE "Use Nettle instead of OpenSSL in libjuice" ON)
|
option(USE_NETTLE "Use Nettle instead of OpenSSL in libjuice" ON)
|
||||||
@ -39,6 +40,14 @@ set(LIBDATACHANNEL_SOURCES
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
set(LIBDATACHANNEL_WEBSOCKET_SOURCES
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp
|
||||||
|
)
|
||||||
|
|
||||||
set(LIBDATACHANNEL_HEADERS
|
set(LIBDATACHANNEL_HEADERS
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/candidate.hpp
|
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/candidate.hpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/channel.hpp
|
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/channel.hpp
|
||||||
@ -55,6 +64,7 @@ set(LIBDATACHANNEL_HEADERS
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/reliability.hpp
|
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/reliability.hpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.h
|
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.hpp
|
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.hpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/websocket.hpp
|
||||||
)
|
)
|
||||||
|
|
||||||
set(TESTS_SOURCES
|
set(TESTS_SOURCES
|
||||||
@ -89,26 +99,42 @@ endif()
|
|||||||
add_library(Usrsctp::Usrsctp ALIAS usrsctp)
|
add_library(Usrsctp::Usrsctp ALIAS usrsctp)
|
||||||
add_library(Usrsctp::UsrsctpStatic ALIAS usrsctp-static)
|
add_library(Usrsctp::UsrsctpStatic ALIAS usrsctp-static)
|
||||||
|
|
||||||
add_library(datachannel SHARED ${LIBDATACHANNEL_SOURCES})
|
if (RTC_ENABLE_WEBSOCKET)
|
||||||
|
add_library(datachannel SHARED
|
||||||
|
${LIBDATACHANNEL_SOURCES}
|
||||||
|
${LIBDATACHANNEL_WEBSOCKET_SOURCES})
|
||||||
|
add_library(datachannel-static STATIC EXCLUDE_FROM_ALL
|
||||||
|
${LIBDATACHANNEL_SOURCES}
|
||||||
|
${LIBDATACHANNEL_WEBSOCKET_SOURCES})
|
||||||
|
target_compile_definitions(datachannel PUBLIC RTC_ENABLE_WEBSOCKET=1)
|
||||||
|
target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_WEBSOCKET=1)
|
||||||
|
else()
|
||||||
|
add_library(datachannel SHARED
|
||||||
|
${LIBDATACHANNEL_SOURCES})
|
||||||
|
add_library(datachannel-static STATIC EXCLUDE_FROM_ALL
|
||||||
|
${LIBDATACHANNEL_SOURCES})
|
||||||
|
target_compile_definitions(datachannel PUBLIC RTC_ENABLE_WEBSOCKET=0)
|
||||||
|
target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_WEBSOCKET=0)
|
||||||
|
endif()
|
||||||
|
|
||||||
set_target_properties(datachannel PROPERTIES
|
set_target_properties(datachannel PROPERTIES
|
||||||
VERSION ${PROJECT_VERSION}
|
VERSION ${PROJECT_VERSION}
|
||||||
CXX_STANDARD 17)
|
CXX_STANDARD 17)
|
||||||
|
set_target_properties(datachannel-static PROPERTIES
|
||||||
|
VERSION ${PROJECT_VERSION}
|
||||||
|
CXX_STANDARD 17)
|
||||||
|
|
||||||
target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
|
target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||||
target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
|
target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
|
||||||
target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
|
target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
|
||||||
target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include)
|
target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include)
|
||||||
target_link_libraries(datachannel Threads::Threads Usrsctp::UsrsctpStatic)
|
|
||||||
|
|
||||||
add_library(datachannel-static STATIC EXCLUDE_FROM_ALL ${LIBDATACHANNEL_SOURCES})
|
|
||||||
set_target_properties(datachannel-static PROPERTIES
|
|
||||||
VERSION ${PROJECT_VERSION}
|
|
||||||
CXX_STANDARD 17)
|
|
||||||
|
|
||||||
target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
|
target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||||
target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
|
target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
|
||||||
target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
|
target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
|
||||||
target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include)
|
target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include)
|
||||||
|
|
||||||
|
target_link_libraries(datachannel Threads::Threads Usrsctp::UsrsctpStatic)
|
||||||
target_link_libraries(datachannel-static Threads::Threads Usrsctp::UsrsctpStatic)
|
target_link_libraries(datachannel-static Threads::Threads Usrsctp::UsrsctpStatic)
|
||||||
|
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
|
1
Jamfile
1
Jamfile
@ -10,6 +10,7 @@ lib libdatachannel
|
|||||||
<cxxstd>17
|
<cxxstd>17
|
||||||
<include>./include/rtc
|
<include>./include/rtc
|
||||||
<define>USE_JUICE=1
|
<define>USE_JUICE=1
|
||||||
|
<define>RTC_ENABLE_WEBSOCKET=0
|
||||||
<library>/libdatachannel//usrsctp
|
<library>/libdatachannel//usrsctp
|
||||||
<library>/libdatachannel//juice
|
<library>/libdatachannel//juice
|
||||||
<library>/libdatachannel//plog
|
<library>/libdatachannel//plog
|
||||||
|
8
Makefile
8
Makefile
@ -38,6 +38,14 @@ else
|
|||||||
LIBS+=glib-2.0 gobject-2.0 nice
|
LIBS+=glib-2.0 gobject-2.0 nice
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
RTC_ENABLE_WEBSOCKET ?= 1
|
||||||
|
ifneq ($(RTC_ENABLE_WEBSOCKET), 0)
|
||||||
|
CPPFLAGS+=-DRTC_ENABLE_WEBSOCKET=1
|
||||||
|
else
|
||||||
|
CPPFLAGS+=-DRTC_ENABLE_WEBSOCKET=0
|
||||||
|
endif
|
||||||
|
|
||||||
|
|
||||||
INCLUDES+=$(shell pkg-config --cflags $(LIBS))
|
INCLUDES+=$(shell pkg-config --cflags $(LIBS))
|
||||||
LDLIBS+=$(LOCALLIBS) $(shell pkg-config --libs $(LIBS))
|
LDLIBS+=$(LOCALLIBS) $(shell pkg-config --libs $(LIBS))
|
||||||
|
|
||||||
|
@ -82,7 +82,6 @@ private:
|
|||||||
std::atomic<bool> mIsClosed = false;
|
std::atomic<bool> mIsClosed = false;
|
||||||
|
|
||||||
Queue<message_ptr> mRecvQueue;
|
Queue<message_ptr> mRecvQueue;
|
||||||
std::atomic<size_t> mRecvAmount = 0;
|
|
||||||
|
|
||||||
friend class PeerConnection;
|
friend class PeerConnection;
|
||||||
};
|
};
|
||||||
|
@ -19,6 +19,10 @@
|
|||||||
#ifndef RTC_INCLUDE_H
|
#ifndef RTC_INCLUDE_H
|
||||||
#define RTC_INCLUDE_H
|
#define RTC_INCLUDE_H
|
||||||
|
|
||||||
|
#ifndef RTC_ENABLE_WEBSOCKET
|
||||||
|
#define RTC_ENABLE_WEBSOCKET 1
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
#ifndef _WIN32_WINNT
|
#ifndef _WIN32_WINNT
|
||||||
#define _WIN32_WINNT 0x0602
|
#define _WIN32_WINNT 0x0602
|
||||||
@ -56,10 +60,21 @@ const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
|
|||||||
const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not specified in SDP
|
const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not specified in SDP
|
||||||
const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size
|
const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size
|
||||||
|
|
||||||
|
// overloaded helper
|
||||||
template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
|
template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
|
||||||
template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;
|
template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;
|
||||||
|
|
||||||
|
// weak_ptr bind helper
|
||||||
|
template <typename F, typename T, typename... Args> auto weak_bind(F &&f, T *t, Args &&... _args) {
|
||||||
|
return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&... args) {
|
||||||
|
using result_type = typename decltype(bound)::result_type;
|
||||||
|
if (auto shared_this = weak_this.lock())
|
||||||
|
return bound(args...);
|
||||||
|
else
|
||||||
|
return (result_type) false;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
template <typename... P> class synchronized_callback {
|
template <typename... P> class synchronized_callback {
|
||||||
public:
|
public:
|
||||||
synchronized_callback() = default;
|
synchronized_callback() = default;
|
||||||
|
@ -30,6 +30,7 @@ namespace rtc {
|
|||||||
struct Message : binary {
|
struct Message : binary {
|
||||||
enum Type { Binary, String, Control, Reset };
|
enum Type { Binary, String, Control, Reset };
|
||||||
|
|
||||||
|
Message(const Message &message) = default;
|
||||||
Message(size_t size, Type type_ = Binary) : binary(size), type(type_) {}
|
Message(size_t size, Type type_ = Binary) : binary(size), type(type_) {}
|
||||||
|
|
||||||
template <typename Iterator>
|
template <typename Iterator>
|
||||||
|
@ -98,8 +98,6 @@ public:
|
|||||||
std::string connectionInfo;
|
std::string connectionInfo;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
init_token mInitToken = Init::Token();
|
|
||||||
|
|
||||||
std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
|
std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
|
||||||
std::shared_ptr<DtlsTransport> initDtlsTransport();
|
std::shared_ptr<DtlsTransport> initDtlsTransport();
|
||||||
std::shared_ptr<SctpTransport> initSctpTransport();
|
std::shared_ptr<SctpTransport> initSctpTransport();
|
||||||
@ -130,6 +128,8 @@ private:
|
|||||||
const Configuration mConfig;
|
const Configuration mConfig;
|
||||||
const std::shared_ptr<Certificate> mCertificate;
|
const std::shared_ptr<Certificate> mCertificate;
|
||||||
|
|
||||||
|
init_token mInitToken = Init::Token();
|
||||||
|
|
||||||
std::optional<Description> mLocalDescription, mRemoteDescription;
|
std::optional<Description> mLocalDescription, mRemoteDescription;
|
||||||
mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;
|
mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;
|
||||||
|
|
||||||
|
@ -44,6 +44,7 @@ public:
|
|||||||
void push(T element);
|
void push(T element);
|
||||||
std::optional<T> pop();
|
std::optional<T> pop();
|
||||||
std::optional<T> peek();
|
std::optional<T> peek();
|
||||||
|
std::optional<T> exchange(T element);
|
||||||
bool wait(const std::optional<std::chrono::milliseconds> &duration = nullopt);
|
bool wait(const std::optional<std::chrono::milliseconds> &duration = nullopt);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -118,6 +119,16 @@ template <typename T> std::optional<T> Queue<T>::peek() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T> std::optional<T> Queue<T>::exchange(T element) {
|
||||||
|
std::unique_lock lock(mMutex);
|
||||||
|
if (!mQueue.empty()) {
|
||||||
|
std::swap(mQueue.front(), element);
|
||||||
|
return std::optional<T>{element};
|
||||||
|
} else {
|
||||||
|
return nullopt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool Queue<T>::wait(const std::optional<std::chrono::milliseconds> &duration) {
|
bool Queue<T>::wait(const std::optional<std::chrono::milliseconds> &duration) {
|
||||||
std::unique_lock lock(mMutex);
|
std::unique_lock lock(mMutex);
|
||||||
|
@ -27,6 +27,10 @@ extern "C" {
|
|||||||
|
|
||||||
// libdatachannel C API
|
// libdatachannel C API
|
||||||
|
|
||||||
|
#ifndef RTC_ENABLE_WEBSOCKET
|
||||||
|
#define RTC_ENABLE_WEBSOCKET 1
|
||||||
|
#endif
|
||||||
|
|
||||||
typedef enum {
|
typedef enum {
|
||||||
RTC_NEW = 0,
|
RTC_NEW = 0,
|
||||||
RTC_CONNECTING = 1,
|
RTC_CONNECTING = 1,
|
||||||
@ -42,8 +46,7 @@ typedef enum {
|
|||||||
RTC_GATHERING_COMPLETE = 2
|
RTC_GATHERING_COMPLETE = 2
|
||||||
} rtcGatheringState;
|
} rtcGatheringState;
|
||||||
|
|
||||||
// Don't change, it must match plog severity
|
typedef enum { // Don't change, it must match plog severity
|
||||||
typedef enum {
|
|
||||||
RTC_LOG_NONE = 0,
|
RTC_LOG_NONE = 0,
|
||||||
RTC_LOG_FATAL = 1,
|
RTC_LOG_FATAL = 1,
|
||||||
RTC_LOG_ERROR = 2,
|
RTC_LOG_ERROR = 2,
|
||||||
@ -76,10 +79,10 @@ typedef void (*availableCallbackFunc)(void *ptr);
|
|||||||
void rtcInitLogger(rtcLogLevel level);
|
void rtcInitLogger(rtcLogLevel level);
|
||||||
|
|
||||||
// User pointer
|
// User pointer
|
||||||
void rtcSetUserPointer(int i, void *ptr);
|
void rtcSetUserPointer(int id, void *ptr);
|
||||||
|
|
||||||
// PeerConnection
|
// PeerConnection
|
||||||
int rtcCreatePeerConnection(const rtcConfiguration *config);
|
int rtcCreatePeerConnection(const rtcConfiguration *config); // returns pc id
|
||||||
int rtcDeletePeerConnection(int pc);
|
int rtcDeletePeerConnection(int pc);
|
||||||
|
|
||||||
int rtcSetDataChannelCallback(int pc, dataChannelCallbackFunc cb);
|
int rtcSetDataChannelCallback(int pc, dataChannelCallbackFunc cb);
|
||||||
@ -95,24 +98,32 @@ int rtcGetLocalAddress(int pc, char *buffer, int size);
|
|||||||
int rtcGetRemoteAddress(int pc, char *buffer, int size);
|
int rtcGetRemoteAddress(int pc, char *buffer, int size);
|
||||||
|
|
||||||
// DataChannel
|
// DataChannel
|
||||||
int rtcCreateDataChannel(int pc, const char *label);
|
int rtcCreateDataChannel(int pc, const char *label); // returns dc id
|
||||||
int rtcDeleteDataChannel(int dc);
|
int rtcDeleteDataChannel(int dc);
|
||||||
|
|
||||||
int rtcGetDataChannelLabel(int dc, char *buffer, int size);
|
int rtcGetDataChannelLabel(int dc, char *buffer, int size);
|
||||||
int rtcSetOpenCallback(int dc, openCallbackFunc cb);
|
|
||||||
int rtcSetClosedCallback(int dc, closedCallbackFunc cb);
|
|
||||||
int rtcSetErrorCallback(int dc, errorCallbackFunc cb);
|
|
||||||
int rtcSetMessageCallback(int dc, messageCallbackFunc cb);
|
|
||||||
int rtcSendMessage(int dc, const char *data, int size);
|
|
||||||
|
|
||||||
int rtcGetBufferedAmount(int dc); // total size buffered to send
|
// WebSocket
|
||||||
int rtcSetBufferedAmountLowThreshold(int dc, int amount);
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
int rtcSetBufferedAmountLowCallback(int dc, bufferedAmountLowCallbackFunc cb);
|
int rtcCreateWebSocket(const char *url); // returns ws id
|
||||||
|
int rtcDeleteWebsocket(int ws);
|
||||||
|
#endif
|
||||||
|
|
||||||
// DataChannel extended API
|
// DataChannel and WebSocket common API
|
||||||
int rtcGetAvailableAmount(int dc); // total size available to receive
|
int rtcSetOpenCallback(int id, openCallbackFunc cb);
|
||||||
int rtcSetAvailableCallback(int dc, availableCallbackFunc cb);
|
int rtcSetClosedCallback(int id, closedCallbackFunc cb);
|
||||||
int rtcReceiveMessage(int dc, char *buffer, int *size);
|
int rtcSetErrorCallback(int id, errorCallbackFunc cb);
|
||||||
|
int rtcSetMessageCallback(int id, messageCallbackFunc cb);
|
||||||
|
int rtcSendMessage(int id, const char *data, int size);
|
||||||
|
|
||||||
|
int rtcGetBufferedAmount(int id); // total size buffered to send
|
||||||
|
int rtcSetBufferedAmountLowThreshold(int id, int amount);
|
||||||
|
int rtcSetBufferedAmountLowCallback(int id, bufferedAmountLowCallbackFunc cb);
|
||||||
|
|
||||||
|
// DataChannel and WebSocket common extended API
|
||||||
|
int rtcGetAvailableAmount(int id); // total size available to receive
|
||||||
|
int rtcSetAvailableCallback(int id, availableCallbackFunc cb);
|
||||||
|
int rtcReceiveMessage(int id, char *buffer, int *size);
|
||||||
|
|
||||||
// Cleanup
|
// Cleanup
|
||||||
void rtcCleanup();
|
void rtcCleanup();
|
||||||
|
@ -23,6 +23,7 @@
|
|||||||
//
|
//
|
||||||
#include "datachannel.hpp"
|
#include "datachannel.hpp"
|
||||||
#include "peerconnection.hpp"
|
#include "peerconnection.hpp"
|
||||||
|
#include "websocket.hpp"
|
||||||
|
|
||||||
// C API
|
// C API
|
||||||
#include "rtc.h"
|
#include "rtc.h"
|
||||||
|
95
include/rtc/websocket.hpp
Normal file
95
include/rtc/websocket.hpp
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2020 Paul-Louis Ageneau
|
||||||
|
*
|
||||||
|
* This library is free software; you can redistribute it and/or
|
||||||
|
* modify it under the terms of the GNU Lesser General Public
|
||||||
|
* License as published by the Free Software Foundation; either
|
||||||
|
* version 2.1 of the License, or (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This library is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||||
|
* Lesser General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU Lesser General Public
|
||||||
|
* License along with this library; if not, write to the Free Software
|
||||||
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef RTC_WEBSOCKET_H
|
||||||
|
#define RTC_WEBSOCKET_H
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
|
||||||
|
#include "channel.hpp"
|
||||||
|
#include "include.hpp"
|
||||||
|
#include "init.hpp"
|
||||||
|
#include "message.hpp"
|
||||||
|
#include "queue.hpp"
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <optional>
|
||||||
|
#include <thread>
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
|
namespace rtc {
|
||||||
|
|
||||||
|
class TcpTransport;
|
||||||
|
class TlsTransport;
|
||||||
|
class WsTransport;
|
||||||
|
|
||||||
|
class WebSocket final : public Channel, public std::enable_shared_from_this<WebSocket> {
|
||||||
|
public:
|
||||||
|
enum class State : int {
|
||||||
|
Connecting = 0,
|
||||||
|
Open = 1,
|
||||||
|
Closing = 2,
|
||||||
|
Closed = 3,
|
||||||
|
};
|
||||||
|
|
||||||
|
WebSocket();
|
||||||
|
WebSocket(const string &url);
|
||||||
|
~WebSocket();
|
||||||
|
|
||||||
|
State readyState() const;
|
||||||
|
|
||||||
|
void open(const string &url);
|
||||||
|
void close() override;
|
||||||
|
bool send(const std::variant<binary, string> &data) override;
|
||||||
|
|
||||||
|
bool isOpen() const override;
|
||||||
|
bool isClosed() const override;
|
||||||
|
size_t maxMessageSize() const override;
|
||||||
|
|
||||||
|
// Extended API
|
||||||
|
std::optional<std::variant<binary, string>> receive() override;
|
||||||
|
size_t availableAmount() const override; // total size available to receive
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool changeState(State state);
|
||||||
|
void remoteClose();
|
||||||
|
bool outgoing(mutable_message_ptr message);
|
||||||
|
void incoming(message_ptr message);
|
||||||
|
|
||||||
|
std::shared_ptr<TcpTransport> initTcpTransport();
|
||||||
|
std::shared_ptr<TlsTransport> initTlsTransport();
|
||||||
|
std::shared_ptr<WsTransport> initWsTransport();
|
||||||
|
void closeTransports();
|
||||||
|
|
||||||
|
init_token mInitToken = Init::Token();
|
||||||
|
|
||||||
|
std::shared_ptr<TcpTransport> mTcpTransport;
|
||||||
|
std::shared_ptr<TlsTransport> mTlsTransport;
|
||||||
|
std::shared_ptr<WsTransport> mWsTransport;
|
||||||
|
std::recursive_mutex mInitMutex;
|
||||||
|
|
||||||
|
string mScheme, mHost, mHostname, mService, mPath;
|
||||||
|
std::atomic<State> mState = State::Closed;
|
||||||
|
|
||||||
|
Queue<message_ptr> mRecvQueue;
|
||||||
|
};
|
||||||
|
} // namespace rtc
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // RTC_WEBSOCKET_H
|
65
src/base64.cpp
Normal file
65
src/base64.cpp
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2020 Paul-Louis Ageneau
|
||||||
|
*
|
||||||
|
* This library is free software; you can redistribute it and/or
|
||||||
|
* modify it under the terms of the GNU Lesser General Public
|
||||||
|
* License as published by the Free Software Foundation; either
|
||||||
|
* version 2.1 of the License, or (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This library is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||||
|
* Lesser General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU Lesser General Public
|
||||||
|
* License along with this library; if not, write to the Free Software
|
||||||
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
|
*/
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
|
||||||
|
#include "base64.hpp"
|
||||||
|
|
||||||
|
namespace rtc {
|
||||||
|
|
||||||
|
using std::to_integer;
|
||||||
|
|
||||||
|
string to_base64(const binary &data) {
|
||||||
|
static const char tab[] =
|
||||||
|
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||||
|
|
||||||
|
string out;
|
||||||
|
out.reserve(3 * ((data.size() + 3) / 4));
|
||||||
|
int i = 0;
|
||||||
|
while (data.size() - i >= 3) {
|
||||||
|
auto d0 = to_integer<uint8_t>(data[i]);
|
||||||
|
auto d1 = to_integer<uint8_t>(data[i + 1]);
|
||||||
|
auto d2 = to_integer<uint8_t>(data[i + 2]);
|
||||||
|
out += tab[d0 >> 2];
|
||||||
|
out += tab[((d0 & 3) << 4) | (d1 >> 4)];
|
||||||
|
out += tab[((d1 & 0x0F) << 2) | (d2 >> 6)];
|
||||||
|
out += tab[d2 & 0x3F];
|
||||||
|
i += 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
int left = data.size() - i;
|
||||||
|
if (left) {
|
||||||
|
auto d0 = to_integer<uint8_t>(data[i]);
|
||||||
|
out += tab[d0 >> 2];
|
||||||
|
if (left == 1) {
|
||||||
|
out += tab[(d0 & 3) << 4];
|
||||||
|
out += '=';
|
||||||
|
} else { // left == 2
|
||||||
|
auto d1 = to_integer<uint8_t>(data[i + 1]);
|
||||||
|
out += tab[((d0 & 3) << 4) | (d1 >> 4)];
|
||||||
|
out += tab[(d1 & 0x0F) << 2];
|
||||||
|
}
|
||||||
|
out += '=';
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace rtc
|
||||||
|
|
||||||
|
#endif
|
34
src/base64.hpp
Normal file
34
src/base64.hpp
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2020 Paul-Louis Ageneau
|
||||||
|
*
|
||||||
|
* This library is free software; you can redistribute it and/or
|
||||||
|
* modify it under the terms of the GNU Lesser General Public
|
||||||
|
* License as published by the Free Software Foundation; either
|
||||||
|
* version 2.1 of the License, or (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This library is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||||
|
* Lesser General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU Lesser General Public
|
||||||
|
* License along with this library; if not, write to the Free Software
|
||||||
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef RTC_BASE64_H
|
||||||
|
#define RTC_BASE64_H
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
|
||||||
|
#include "include.hpp"
|
||||||
|
|
||||||
|
namespace rtc {
|
||||||
|
|
||||||
|
string to_base64(const binary &data);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
@ -214,6 +214,9 @@ bool DataChannel::outgoing(mutable_message_ptr message) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void DataChannel::incoming(message_ptr message) {
|
void DataChannel::incoming(message_ptr message) {
|
||||||
|
if (!message)
|
||||||
|
return;
|
||||||
|
|
||||||
switch (message->type) {
|
switch (message->type) {
|
||||||
case Message::Control: {
|
case Message::Control: {
|
||||||
auto raw = reinterpret_cast<const uint8_t *>(message->data());
|
auto raw = reinterpret_cast<const uint8_t *>(message->data());
|
||||||
|
@ -18,9 +18,7 @@
|
|||||||
|
|
||||||
#include "dtlstransport.hpp"
|
#include "dtlstransport.hpp"
|
||||||
#include "icetransport.hpp"
|
#include "icetransport.hpp"
|
||||||
#include "message.hpp"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <exception>
|
#include <exception>
|
||||||
@ -64,11 +62,9 @@ void DtlsTransport::Cleanup() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
|
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
|
||||||
verifier_callback verifierCallback,
|
verifier_callback verifierCallback, state_callback stateChangeCallback)
|
||||||
state_callback stateChangeCallback)
|
: Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
|
||||||
: Transport(lower), mCertificate(certificate), mState(State::Disconnected),
|
mVerifierCallback(std::move(verifierCallback)) {
|
||||||
mVerifierCallback(std::move(verifierCallback)),
|
|
||||||
mStateChangeCallback(std::move(stateChangeCallback)) {
|
|
||||||
|
|
||||||
PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
|
PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
|
||||||
|
|
||||||
@ -76,13 +72,14 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
|
|||||||
unsigned int flags = GNUTLS_DATAGRAM | (active ? GNUTLS_CLIENT : GNUTLS_SERVER);
|
unsigned int flags = GNUTLS_DATAGRAM | (active ? GNUTLS_CLIENT : GNUTLS_SERVER);
|
||||||
check_gnutls(gnutls_init(&mSession, flags));
|
check_gnutls(gnutls_init(&mSession, flags));
|
||||||
|
|
||||||
|
try {
|
||||||
// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
|
// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
|
||||||
// Therefore, the DTLS layer MUST NOT use any compression algorithm.
|
// Therefore, the DTLS layer MUST NOT use any compression algorithm.
|
||||||
// See https://tools.ietf.org/html/rfc8261#section-5
|
// See https://tools.ietf.org/html/rfc8261#section-5
|
||||||
const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
|
const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
|
||||||
const char *err_pos = NULL;
|
const char *err_pos = NULL;
|
||||||
check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
|
check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
|
||||||
"Unable to set TLS priorities");
|
"Failed to set TLS priorities");
|
||||||
|
|
||||||
gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback);
|
gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback);
|
||||||
check_gnutls(
|
check_gnutls(
|
||||||
@ -101,6 +98,11 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
|
|||||||
|
|
||||||
mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
|
mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
|
||||||
registerIncoming();
|
registerIncoming();
|
||||||
|
|
||||||
|
} catch (...) {
|
||||||
|
gnutls_deinit(mSession);
|
||||||
|
throw;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
DtlsTransport::~DtlsTransport() {
|
DtlsTransport::~DtlsTransport() {
|
||||||
@ -109,8 +111,6 @@ DtlsTransport::~DtlsTransport() {
|
|||||||
gnutls_deinit(mSession);
|
gnutls_deinit(mSession);
|
||||||
}
|
}
|
||||||
|
|
||||||
DtlsTransport::State DtlsTransport::state() const { return mState; }
|
|
||||||
|
|
||||||
bool DtlsTransport::stop() {
|
bool DtlsTransport::stop() {
|
||||||
if (!Transport::stop())
|
if (!Transport::stop())
|
||||||
return false;
|
return false;
|
||||||
@ -122,7 +122,7 @@ bool DtlsTransport::stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool DtlsTransport::send(message_ptr message) {
|
bool DtlsTransport::send(message_ptr message) {
|
||||||
if (!message || mState != State::Connected)
|
if (!message || state() != State::Connected)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
PLOG_VERBOSE << "Send size=" << message->size();
|
PLOG_VERBOSE << "Send size=" << message->size();
|
||||||
@ -148,11 +148,6 @@ void DtlsTransport::incoming(message_ptr message) {
|
|||||||
mIncomingQueue.push(message);
|
mIncomingQueue.push(message);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DtlsTransport::changeState(State state) {
|
|
||||||
if (mState.exchange(state) != state)
|
|
||||||
mStateChangeCallback(state);
|
|
||||||
}
|
|
||||||
|
|
||||||
void DtlsTransport::runRecvLoop() {
|
void DtlsTransport::runRecvLoop() {
|
||||||
const size_t maxMtu = 4096;
|
const size_t maxMtu = 4096;
|
||||||
|
|
||||||
@ -169,7 +164,7 @@ void DtlsTransport::runRecvLoop() {
|
|||||||
throw std::runtime_error("MTU is too low");
|
throw std::runtime_error("MTU is too low");
|
||||||
|
|
||||||
} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
|
} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
|
||||||
!check_gnutls(ret, "TLS handshake failed"));
|
!check_gnutls(ret, "DTLS handshake failed"));
|
||||||
|
|
||||||
// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
|
// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
|
||||||
// See https://tools.ietf.org/html/rfc8261#section-5
|
// See https://tools.ietf.org/html/rfc8261#section-5
|
||||||
@ -183,7 +178,7 @@ void DtlsTransport::runRecvLoop() {
|
|||||||
|
|
||||||
// Receive loop
|
// Receive loop
|
||||||
try {
|
try {
|
||||||
PLOG_INFO << "DTLS handshake done";
|
PLOG_INFO << "DTLS handshake finished";
|
||||||
changeState(State::Connected);
|
changeState(State::Connected);
|
||||||
|
|
||||||
const size_t bufferSize = maxMtu;
|
const size_t bufferSize = maxMtu;
|
||||||
@ -218,7 +213,7 @@ void DtlsTransport::runRecvLoop() {
|
|||||||
|
|
||||||
gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
|
gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
|
||||||
|
|
||||||
PLOG_INFO << "DTLS disconnected";
|
PLOG_INFO << "DTLS closed";
|
||||||
changeState(State::Disconnected);
|
changeState(State::Disconnected);
|
||||||
recv(nullptr);
|
recv(nullptr);
|
||||||
}
|
}
|
||||||
@ -341,7 +336,7 @@ void DtlsTransport::Init() {
|
|||||||
if (!BioMethods) {
|
if (!BioMethods) {
|
||||||
BioMethods = BIO_meth_new(BIO_TYPE_BIO, "DTLS writer");
|
BioMethods = BIO_meth_new(BIO_TYPE_BIO, "DTLS writer");
|
||||||
if (!BioMethods)
|
if (!BioMethods)
|
||||||
throw std::runtime_error("Unable to BIO methods for DTLS writer");
|
throw std::runtime_error("Failed to create BIO methods for DTLS writer");
|
||||||
BIO_meth_set_create(BioMethods, BioMethodNew);
|
BIO_meth_set_create(BioMethods, BioMethodNew);
|
||||||
BIO_meth_set_destroy(BioMethods, BioMethodFree);
|
BIO_meth_set_destroy(BioMethods, BioMethodFree);
|
||||||
BIO_meth_set_write(BioMethods, BioMethodWrite);
|
BIO_meth_set_write(BioMethods, BioMethodWrite);
|
||||||
@ -358,17 +353,17 @@ void DtlsTransport::Cleanup() {
|
|||||||
|
|
||||||
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
|
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
|
||||||
verifier_callback verifierCallback, state_callback stateChangeCallback)
|
verifier_callback verifierCallback, state_callback stateChangeCallback)
|
||||||
: Transport(lower), mCertificate(certificate), mState(State::Disconnected),
|
: Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
|
||||||
mVerifierCallback(std::move(verifierCallback)),
|
mVerifierCallback(std::move(verifierCallback)) {
|
||||||
mStateChangeCallback(std::move(stateChangeCallback)) {
|
|
||||||
|
|
||||||
PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
|
PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
|
||||||
|
|
||||||
|
try {
|
||||||
if (!(mCtx = SSL_CTX_new(DTLS_method())))
|
if (!(mCtx = SSL_CTX_new(DTLS_method())))
|
||||||
throw std::runtime_error("Unable to create SSL context");
|
throw std::runtime_error("Failed to create SSL context");
|
||||||
|
|
||||||
check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
|
check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
|
||||||
"Unable to set SSL priorities");
|
"Failed to set SSL priorities");
|
||||||
|
|
||||||
// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
|
// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
|
||||||
// Therefore, the DTLS layer MUST NOT use any compression algorithm.
|
// Therefore, the DTLS layer MUST NOT use any compression algorithm.
|
||||||
@ -389,7 +384,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
|
|||||||
check_openssl(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed");
|
check_openssl(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed");
|
||||||
|
|
||||||
if (!(mSsl = SSL_new(mCtx)))
|
if (!(mSsl = SSL_new(mCtx)))
|
||||||
throw std::runtime_error("Unable to create SSL instance");
|
throw std::runtime_error("Failed to create SSL instance");
|
||||||
|
|
||||||
SSL_set_ex_data(mSsl, TransportExIndex, this);
|
SSL_set_ex_data(mSsl, TransportExIndex, this);
|
||||||
|
|
||||||
@ -399,7 +394,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
|
|||||||
SSL_set_accept_state(mSsl);
|
SSL_set_accept_state(mSsl);
|
||||||
|
|
||||||
if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BioMethods)))
|
if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BioMethods)))
|
||||||
throw std::runtime_error("Unable to create BIO");
|
throw std::runtime_error("Failed to create BIO");
|
||||||
|
|
||||||
BIO_set_mem_eof_return(mInBio, BIO_EOF);
|
BIO_set_mem_eof_return(mInBio, BIO_EOF);
|
||||||
BIO_set_data(mOutBio, this);
|
BIO_set_data(mOutBio, this);
|
||||||
@ -412,6 +407,14 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
|
|||||||
|
|
||||||
mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
|
mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
|
||||||
registerIncoming();
|
registerIncoming();
|
||||||
|
|
||||||
|
} catch (...) {
|
||||||
|
if (mSsl)
|
||||||
|
SSL_free(mSsl);
|
||||||
|
if (mCtx)
|
||||||
|
SSL_CTX_free(mCtx);
|
||||||
|
throw;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
DtlsTransport::~DtlsTransport() {
|
DtlsTransport::~DtlsTransport() {
|
||||||
@ -432,18 +435,14 @@ bool DtlsTransport::stop() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
DtlsTransport::State DtlsTransport::state() const { return mState; }
|
|
||||||
|
|
||||||
bool DtlsTransport::send(message_ptr message) {
|
bool DtlsTransport::send(message_ptr message) {
|
||||||
if (!message || mState != State::Connected)
|
if (!message || state() != State::Connected)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
PLOG_VERBOSE << "Send size=" << message->size();
|
PLOG_VERBOSE << "Send size=" << message->size();
|
||||||
|
|
||||||
int ret = SSL_write(mSsl, message->data(), message->size());
|
int ret = SSL_write(mSsl, message->data(), message->size());
|
||||||
if (!check_openssl_ret(mSsl, ret))
|
return check_openssl_ret(mSsl, ret);
|
||||||
return false;
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void DtlsTransport::incoming(message_ptr message) {
|
void DtlsTransport::incoming(message_ptr message) {
|
||||||
@ -456,11 +455,6 @@ void DtlsTransport::incoming(message_ptr message) {
|
|||||||
mIncomingQueue.push(message);
|
mIncomingQueue.push(message);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DtlsTransport::changeState(State state) {
|
|
||||||
if (mState.exchange(state) != state)
|
|
||||||
mStateChangeCallback(state);
|
|
||||||
}
|
|
||||||
|
|
||||||
void DtlsTransport::runRecvLoop() {
|
void DtlsTransport::runRecvLoop() {
|
||||||
const size_t maxMtu = 4096;
|
const size_t maxMtu = 4096;
|
||||||
try {
|
try {
|
||||||
@ -479,7 +473,7 @@ void DtlsTransport::runRecvLoop() {
|
|||||||
auto message = *mIncomingQueue.pop();
|
auto message = *mIncomingQueue.pop();
|
||||||
BIO_write(mInBio, message->data(), message->size());
|
BIO_write(mInBio, message->data(), message->size());
|
||||||
|
|
||||||
if (mState == State::Connecting) {
|
if (state() == State::Connecting) {
|
||||||
// Continue the handshake
|
// Continue the handshake
|
||||||
int ret = SSL_do_handshake(mSsl);
|
int ret = SSL_do_handshake(mSsl);
|
||||||
if (!check_openssl_ret(mSsl, ret, "Handshake failed"))
|
if (!check_openssl_ret(mSsl, ret, "Handshake failed"))
|
||||||
@ -490,7 +484,7 @@ void DtlsTransport::runRecvLoop() {
|
|||||||
// MTU See https://tools.ietf.org/html/rfc8261#section-5
|
// MTU See https://tools.ietf.org/html/rfc8261#section-5
|
||||||
SSL_set_mtu(mSsl, maxMtu + 1);
|
SSL_set_mtu(mSsl, maxMtu + 1);
|
||||||
|
|
||||||
PLOG_INFO << "DTLS handshake done";
|
PLOG_INFO << "DTLS handshake finished";
|
||||||
changeState(State::Connected);
|
changeState(State::Connected);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -504,7 +498,7 @@ void DtlsTransport::runRecvLoop() {
|
|||||||
|
|
||||||
// No more messages pending, retransmit and rearm timeout if connecting
|
// No more messages pending, retransmit and rearm timeout if connecting
|
||||||
std::optional<milliseconds> duration;
|
std::optional<milliseconds> duration;
|
||||||
if (mState == State::Connecting) {
|
if (state() == State::Connecting) {
|
||||||
// Warning: This function breaks the usual return value convention
|
// Warning: This function breaks the usual return value convention
|
||||||
int ret = DTLSv1_handle_timeout(mSsl);
|
int ret = DTLSv1_handle_timeout(mSsl);
|
||||||
if (ret < 0) {
|
if (ret < 0) {
|
||||||
@ -514,7 +508,7 @@ void DtlsTransport::runRecvLoop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct timeval timeout = {};
|
struct timeval timeout = {};
|
||||||
if (mState == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
|
if (state() == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
|
||||||
duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
|
duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
|
||||||
// Also handle handshake timeout manually because OpenSSL actually doesn't...
|
// Also handle handshake timeout manually because OpenSSL actually doesn't...
|
||||||
// OpenSSL backs off exponentially in base 2 starting from the recommended 1s
|
// OpenSSL backs off exponentially in base 2 starting from the recommended 1s
|
||||||
@ -535,8 +529,8 @@ void DtlsTransport::runRecvLoop() {
|
|||||||
PLOG_ERROR << "DTLS recv: " << e.what();
|
PLOG_ERROR << "DTLS recv: " << e.what();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mState == State::Connected) {
|
if (state() == State::Connected) {
|
||||||
PLOG_INFO << "DTLS disconnected";
|
PLOG_INFO << "DTLS closed";
|
||||||
changeState(State::Disconnected);
|
changeState(State::Disconnected);
|
||||||
recv(nullptr);
|
recv(nullptr);
|
||||||
} else {
|
} else {
|
||||||
|
@ -46,33 +46,25 @@ public:
|
|||||||
static void Init();
|
static void Init();
|
||||||
static void Cleanup();
|
static void Cleanup();
|
||||||
|
|
||||||
enum class State { Disconnected, Connecting, Connected, Failed };
|
|
||||||
|
|
||||||
using verifier_callback = std::function<bool(const std::string &fingerprint)>;
|
using verifier_callback = std::function<bool(const std::string &fingerprint)>;
|
||||||
using state_callback = std::function<void(State state)>;
|
|
||||||
|
|
||||||
DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
|
DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
|
||||||
verifier_callback verifierCallback, state_callback stateChangeCallback);
|
verifier_callback verifierCallback, state_callback stateChangeCallback);
|
||||||
~DtlsTransport();
|
~DtlsTransport();
|
||||||
|
|
||||||
State state() const;
|
|
||||||
|
|
||||||
bool stop() override;
|
bool stop() override;
|
||||||
bool send(message_ptr message) override; // false if dropped
|
bool send(message_ptr message) override; // false if dropped
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void incoming(message_ptr message) override;
|
void incoming(message_ptr message) override;
|
||||||
void changeState(State state);
|
|
||||||
void runRecvLoop();
|
void runRecvLoop();
|
||||||
|
|
||||||
const std::shared_ptr<Certificate> mCertificate;
|
const std::shared_ptr<Certificate> mCertificate;
|
||||||
|
|
||||||
Queue<message_ptr> mIncomingQueue;
|
Queue<message_ptr> mIncomingQueue;
|
||||||
std::atomic<State> mState;
|
|
||||||
std::thread mRecvThread;
|
std::thread mRecvThread;
|
||||||
|
|
||||||
verifier_callback mVerifierCallback;
|
verifier_callback mVerifierCallback;
|
||||||
state_callback mStateChangeCallback;
|
|
||||||
|
|
||||||
#if USE_GNUTLS
|
#if USE_GNUTLS
|
||||||
gnutls_session_t mSession;
|
gnutls_session_t mSession;
|
||||||
@ -82,8 +74,8 @@ private:
|
|||||||
static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
|
static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
|
||||||
static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
|
static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
|
||||||
#else
|
#else
|
||||||
SSL_CTX *mCtx;
|
SSL_CTX *mCtx = NULL;
|
||||||
SSL *mSsl;
|
SSL *mSsl = NULL;
|
||||||
BIO *mInBio, *mOutBio;
|
BIO *mInBio, *mOutBio;
|
||||||
|
|
||||||
static BIO_METHOD *BioMethods;
|
static BIO_METHOD *BioMethods;
|
||||||
|
@ -48,9 +48,8 @@ namespace rtc {
|
|||||||
IceTransport::IceTransport(const Configuration &config, Description::Role role,
|
IceTransport::IceTransport(const Configuration &config, Description::Role role,
|
||||||
candidate_callback candidateCallback, state_callback stateChangeCallback,
|
candidate_callback candidateCallback, state_callback stateChangeCallback,
|
||||||
gathering_state_callback gatheringStateChangeCallback)
|
gathering_state_callback gatheringStateChangeCallback)
|
||||||
: mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
|
: Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
|
||||||
mCandidateCallback(std::move(candidateCallback)),
|
mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
|
||||||
mStateChangeCallback(std::move(stateChangeCallback)),
|
|
||||||
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
|
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
|
||||||
mAgent(nullptr, nullptr) {
|
mAgent(nullptr, nullptr) {
|
||||||
|
|
||||||
@ -84,6 +83,7 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
|
|||||||
mStunService = server.service;
|
mStunService = server.service;
|
||||||
jconfig.stun_server_host = mStunHostname.c_str();
|
jconfig.stun_server_host = mStunHostname.c_str();
|
||||||
jconfig.stun_server_port = std::stoul(mStunService);
|
jconfig.stun_server_port = std::stoul(mStunService);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -108,8 +108,6 @@ bool IceTransport::stop() {
|
|||||||
|
|
||||||
Description::Role IceTransport::role() const { return mRole; }
|
Description::Role IceTransport::role() const { return mRole; }
|
||||||
|
|
||||||
IceTransport::State IceTransport::state() const { return mState; }
|
|
||||||
|
|
||||||
Description IceTransport::getLocalDescription(Description::Type type) const {
|
Description IceTransport::getLocalDescription(Description::Type type) const {
|
||||||
char sdp[JUICE_MAX_SDP_STRING_LEN];
|
char sdp[JUICE_MAX_SDP_STRING_LEN];
|
||||||
if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0)
|
if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0)
|
||||||
@ -161,7 +159,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool IceTransport::send(message_ptr message) {
|
bool IceTransport::send(message_ptr message) {
|
||||||
if (!message || (mState != State::Connected && mState != State::Completed))
|
auto s = state();
|
||||||
|
if (!message || (s != State::Connected && s != State::Completed))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
PLOG_VERBOSE << "Send size=" << message->size();
|
PLOG_VERBOSE << "Send size=" << message->size();
|
||||||
@ -173,18 +172,29 @@ bool IceTransport::outgoing(message_ptr message) {
|
|||||||
message->size()) >= 0;
|
message->size()) >= 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void IceTransport::changeState(State state) {
|
|
||||||
if (mState.exchange(state) != state)
|
|
||||||
mStateChangeCallback(mState);
|
|
||||||
}
|
|
||||||
|
|
||||||
void IceTransport::changeGatheringState(GatheringState state) {
|
void IceTransport::changeGatheringState(GatheringState state) {
|
||||||
if (mGatheringState.exchange(state) != state)
|
if (mGatheringState.exchange(state) != state)
|
||||||
mGatheringStateChangeCallback(mGatheringState);
|
mGatheringStateChangeCallback(mGatheringState);
|
||||||
}
|
}
|
||||||
|
|
||||||
void IceTransport::processStateChange(unsigned int state) {
|
void IceTransport::processStateChange(unsigned int state) {
|
||||||
changeState(static_cast<State>(state));
|
switch (state) {
|
||||||
|
case JUICE_STATE_DISCONNECTED:
|
||||||
|
changeState(State::Disconnected);
|
||||||
|
break;
|
||||||
|
case JUICE_STATE_CONNECTING:
|
||||||
|
changeState(State::Connecting);
|
||||||
|
break;
|
||||||
|
case JUICE_STATE_CONNECTED:
|
||||||
|
changeState(State::Connected);
|
||||||
|
break;
|
||||||
|
case JUICE_STATE_COMPLETED:
|
||||||
|
changeState(State::Completed);
|
||||||
|
break;
|
||||||
|
case JUICE_STATE_FAILED:
|
||||||
|
changeState(State::Failed);
|
||||||
|
break;
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
void IceTransport::processCandidate(const string &candidate) {
|
void IceTransport::processCandidate(const string &candidate) {
|
||||||
@ -263,9 +273,8 @@ namespace rtc {
|
|||||||
IceTransport::IceTransport(const Configuration &config, Description::Role role,
|
IceTransport::IceTransport(const Configuration &config, Description::Role role,
|
||||||
candidate_callback candidateCallback, state_callback stateChangeCallback,
|
candidate_callback candidateCallback, state_callback stateChangeCallback,
|
||||||
gathering_state_callback gatheringStateChangeCallback)
|
gathering_state_callback gatheringStateChangeCallback)
|
||||||
: mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
|
: Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
|
||||||
mCandidateCallback(std::move(candidateCallback)),
|
mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
|
||||||
mStateChangeCallback(std::move(stateChangeCallback)),
|
|
||||||
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
|
mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
|
||||||
mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr) {
|
mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr) {
|
||||||
|
|
||||||
@ -457,8 +466,6 @@ bool IceTransport::stop() {
|
|||||||
|
|
||||||
Description::Role IceTransport::role() const { return mRole; }
|
Description::Role IceTransport::role() const { return mRole; }
|
||||||
|
|
||||||
IceTransport::State IceTransport::state() const { return mState; }
|
|
||||||
|
|
||||||
Description IceTransport::getLocalDescription(Description::Type type) const {
|
Description IceTransport::getLocalDescription(Description::Type type) const {
|
||||||
// RFC 8445: The initiating agent that started the ICE processing MUST take the controlling
|
// RFC 8445: The initiating agent that started the ICE processing MUST take the controlling
|
||||||
// role, and the other MUST take the controlled role.
|
// role, and the other MUST take the controlled role.
|
||||||
@ -529,7 +536,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool IceTransport::send(message_ptr message) {
|
bool IceTransport::send(message_ptr message) {
|
||||||
if (!message || (mState != State::Connected && mState != State::Completed))
|
auto s = state();
|
||||||
|
if (!message || (s != State::Connected && s != State::Completed))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
PLOG_VERBOSE << "Send size=" << message->size();
|
PLOG_VERBOSE << "Send size=" << message->size();
|
||||||
@ -541,11 +549,6 @@ bool IceTransport::outgoing(message_ptr message) {
|
|||||||
reinterpret_cast<const char *>(message->data())) >= 0;
|
reinterpret_cast<const char *>(message->data())) >= 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void IceTransport::changeState(State state) {
|
|
||||||
if (mState.exchange(state) != state)
|
|
||||||
mStateChangeCallback(mState);
|
|
||||||
}
|
|
||||||
|
|
||||||
void IceTransport::changeGatheringState(GatheringState state) {
|
void IceTransport::changeGatheringState(GatheringState state) {
|
||||||
if (mGatheringState.exchange(state) != state)
|
if (mGatheringState.exchange(state) != state)
|
||||||
mGatheringStateChangeCallback(mGatheringState);
|
mGatheringStateChangeCallback(mGatheringState);
|
||||||
@ -576,7 +579,23 @@ void IceTransport::processStateChange(unsigned int state) {
|
|||||||
mTimeoutId = 0;
|
mTimeoutId = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
changeState(static_cast<State>(state));
|
switch (state) {
|
||||||
|
case NICE_COMPONENT_STATE_DISCONNECTED:
|
||||||
|
changeState(State::Disconnected);
|
||||||
|
break;
|
||||||
|
case NICE_COMPONENT_STATE_CONNECTING:
|
||||||
|
changeState(State::Connecting);
|
||||||
|
break;
|
||||||
|
case NICE_COMPONENT_STATE_CONNECTED:
|
||||||
|
changeState(State::Connected);
|
||||||
|
break;
|
||||||
|
case NICE_COMPONENT_STATE_READY:
|
||||||
|
changeState(State::Completed);
|
||||||
|
break;
|
||||||
|
case NICE_COMPONENT_STATE_FAILED:
|
||||||
|
changeState(State::Failed);
|
||||||
|
break;
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
string IceTransport::AddressToString(const NiceAddress &addr) {
|
string IceTransport::AddressToString(const NiceAddress &addr) {
|
||||||
|
@ -40,29 +40,9 @@ namespace rtc {
|
|||||||
|
|
||||||
class IceTransport : public Transport {
|
class IceTransport : public Transport {
|
||||||
public:
|
public:
|
||||||
#if USE_JUICE
|
|
||||||
enum class State : unsigned int{
|
|
||||||
Disconnected = JUICE_STATE_DISCONNECTED,
|
|
||||||
Connecting = JUICE_STATE_CONNECTING,
|
|
||||||
Connected = JUICE_STATE_CONNECTED,
|
|
||||||
Completed = JUICE_STATE_COMPLETED,
|
|
||||||
Failed = JUICE_STATE_FAILED,
|
|
||||||
};
|
|
||||||
#else
|
|
||||||
enum class State : unsigned int {
|
|
||||||
Disconnected = NICE_COMPONENT_STATE_DISCONNECTED,
|
|
||||||
Connecting = NICE_COMPONENT_STATE_CONNECTING,
|
|
||||||
Connected = NICE_COMPONENT_STATE_CONNECTED,
|
|
||||||
Completed = NICE_COMPONENT_STATE_READY,
|
|
||||||
Failed = NICE_COMPONENT_STATE_FAILED,
|
|
||||||
};
|
|
||||||
|
|
||||||
bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
|
|
||||||
#endif
|
|
||||||
enum class GatheringState { New = 0, InProgress = 1, Complete = 2 };
|
enum class GatheringState { New = 0, InProgress = 1, Complete = 2 };
|
||||||
|
|
||||||
using candidate_callback = std::function<void(const Candidate &candidate)>;
|
using candidate_callback = std::function<void(const Candidate &candidate)>;
|
||||||
using state_callback = std::function<void(State state)>;
|
|
||||||
using gathering_state_callback = std::function<void(GatheringState state)>;
|
using gathering_state_callback = std::function<void(GatheringState state)>;
|
||||||
|
|
||||||
IceTransport(const Configuration &config, Description::Role role,
|
IceTransport(const Configuration &config, Description::Role role,
|
||||||
@ -71,7 +51,6 @@ public:
|
|||||||
~IceTransport();
|
~IceTransport();
|
||||||
|
|
||||||
Description::Role role() const;
|
Description::Role role() const;
|
||||||
State state() const;
|
|
||||||
GatheringState gatheringState() const;
|
GatheringState gatheringState() const;
|
||||||
Description getLocalDescription(Description::Type type) const;
|
Description getLocalDescription(Description::Type type) const;
|
||||||
void setRemoteDescription(const Description &description);
|
void setRemoteDescription(const Description &description);
|
||||||
@ -84,10 +63,13 @@ public:
|
|||||||
bool stop() override;
|
bool stop() override;
|
||||||
bool send(message_ptr message) override; // false if dropped
|
bool send(message_ptr message) override; // false if dropped
|
||||||
|
|
||||||
|
#if !USE_JUICE
|
||||||
|
bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool outgoing(message_ptr message) override;
|
bool outgoing(message_ptr message) override;
|
||||||
|
|
||||||
void changeState(State state);
|
|
||||||
void changeGatheringState(GatheringState state);
|
void changeGatheringState(GatheringState state);
|
||||||
|
|
||||||
void processStateChange(unsigned int state);
|
void processStateChange(unsigned int state);
|
||||||
@ -98,11 +80,9 @@ private:
|
|||||||
Description::Role mRole;
|
Description::Role mRole;
|
||||||
string mMid;
|
string mMid;
|
||||||
std::chrono::milliseconds mTrickleTimeout;
|
std::chrono::milliseconds mTrickleTimeout;
|
||||||
std::atomic<State> mState;
|
|
||||||
std::atomic<GatheringState> mGatheringState;
|
std::atomic<GatheringState> mGatheringState;
|
||||||
|
|
||||||
candidate_callback mCandidateCallback;
|
candidate_callback mCandidateCallback;
|
||||||
state_callback mStateChangeCallback;
|
|
||||||
gathering_state_callback mGatheringStateChangeCallback;
|
gathering_state_callback mGatheringStateChangeCallback;
|
||||||
|
|
||||||
#if USE_JUICE
|
#if USE_JUICE
|
||||||
|
14
src/init.cpp
14
src/init.cpp
@ -21,6 +21,10 @@
|
|||||||
#include "dtlstransport.hpp"
|
#include "dtlstransport.hpp"
|
||||||
#include "sctptransport.hpp"
|
#include "sctptransport.hpp"
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
#include "tlstransport.hpp"
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
#include <winsock2.h>
|
#include <winsock2.h>
|
||||||
#endif
|
#endif
|
||||||
@ -69,13 +73,19 @@ Init::Init() {
|
|||||||
ERR_load_crypto_strings();
|
ERR_load_crypto_strings();
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
DtlsTransport::Init();
|
|
||||||
SctpTransport::Init();
|
SctpTransport::Init();
|
||||||
|
DtlsTransport::Init();
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
TlsTransport::Init();
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
Init::~Init() {
|
Init::~Init() {
|
||||||
DtlsTransport::Cleanup();
|
|
||||||
SctpTransport::Cleanup();
|
SctpTransport::Cleanup();
|
||||||
|
DtlsTransport::Cleanup();
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
TlsTransport::Cleanup();
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
WSACleanup();
|
WSACleanup();
|
||||||
|
@ -23,7 +23,6 @@
|
|||||||
#include "include.hpp"
|
#include "include.hpp"
|
||||||
#include "sctptransport.hpp"
|
#include "sctptransport.hpp"
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
|
||||||
namespace rtc {
|
namespace rtc {
|
||||||
@ -33,23 +32,6 @@ using namespace std::placeholders;
|
|||||||
using std::shared_ptr;
|
using std::shared_ptr;
|
||||||
using std::weak_ptr;
|
using std::weak_ptr;
|
||||||
|
|
||||||
template <typename F, typename T, typename... Args> auto weak_bind(F &&f, T *t, Args &&... _args) {
|
|
||||||
return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&... args) {
|
|
||||||
if (auto shared_this = weak_this.lock())
|
|
||||||
bound(args...);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename F, typename T, typename... Args>
|
|
||||||
auto weak_bind_verifier(F &&f, T *t, Args &&... _args) {
|
|
||||||
return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&... args) {
|
|
||||||
if (auto shared_this = weak_this.lock())
|
|
||||||
return bound(args...);
|
|
||||||
else
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
PeerConnection::PeerConnection() : PeerConnection(Configuration()) {}
|
PeerConnection::PeerConnection() : PeerConnection(Configuration()) {}
|
||||||
|
|
||||||
PeerConnection::PeerConnection(const Configuration &config)
|
PeerConnection::PeerConnection(const Configuration &config)
|
||||||
@ -271,7 +253,7 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
|
|||||||
|
|
||||||
auto lower = std::atomic_load(&mIceTransport);
|
auto lower = std::atomic_load(&mIceTransport);
|
||||||
auto transport = std::make_shared<DtlsTransport>(
|
auto transport = std::make_shared<DtlsTransport>(
|
||||||
lower, mCertificate, weak_bind_verifier(&PeerConnection::checkFingerprint, this, _1),
|
lower, mCertificate, weak_bind(&PeerConnection::checkFingerprint, this, _1),
|
||||||
[this, weak_this = weak_from_this()](DtlsTransport::State state) {
|
[this, weak_this = weak_from_this()](DtlsTransport::State state) {
|
||||||
auto shared_this = weak_this.lock();
|
auto shared_this = weak_this.lock();
|
||||||
if (!shared_this)
|
if (!shared_this)
|
||||||
|
181
src/rtc.cpp
181
src/rtc.cpp
@ -16,10 +16,15 @@
|
|||||||
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "datachannel.hpp"
|
|
||||||
#include "include.hpp"
|
#include "include.hpp"
|
||||||
|
|
||||||
|
#include "datachannel.hpp"
|
||||||
#include "peerconnection.hpp"
|
#include "peerconnection.hpp"
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
#include "websocket.hpp"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <rtc.h>
|
#include <rtc.h>
|
||||||
|
|
||||||
#include <exception>
|
#include <exception>
|
||||||
@ -43,6 +48,9 @@ namespace {
|
|||||||
|
|
||||||
std::unordered_map<int, shared_ptr<PeerConnection>> peerConnectionMap;
|
std::unordered_map<int, shared_ptr<PeerConnection>> peerConnectionMap;
|
||||||
std::unordered_map<int, shared_ptr<DataChannel>> dataChannelMap;
|
std::unordered_map<int, shared_ptr<DataChannel>> dataChannelMap;
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
std::unordered_map<int, shared_ptr<WebSocket>> webSocketMap;
|
||||||
|
#endif
|
||||||
std::unordered_map<int, void *> userPointerMap;
|
std::unordered_map<int, void *> userPointerMap;
|
||||||
std::mutex mutex;
|
std::mutex mutex;
|
||||||
int lastId = 0;
|
int lastId = 0;
|
||||||
@ -103,6 +111,40 @@ bool eraseDataChannel(int dc) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
shared_ptr<WebSocket> getWebSocket(int id) {
|
||||||
|
std::lock_guard lock(mutex);
|
||||||
|
auto it = webSocketMap.find(id);
|
||||||
|
return it != webSocketMap.end() ? it->second : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
int emplaceWebSocket(shared_ptr<WebSocket> ptr) {
|
||||||
|
std::lock_guard lock(mutex);
|
||||||
|
int ws = ++lastId;
|
||||||
|
webSocketMap.emplace(std::make_pair(ws, ptr));
|
||||||
|
return ws;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool eraseWebSocket(int ws) {
|
||||||
|
std::lock_guard lock(mutex);
|
||||||
|
if (webSocketMap.erase(ws) == 0)
|
||||||
|
return false;
|
||||||
|
userPointerMap.erase(ws);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
shared_ptr<Channel> getChannel(int id) {
|
||||||
|
std::lock_guard lock(mutex);
|
||||||
|
if (auto it = dataChannelMap.find(id); it != dataChannelMap.end())
|
||||||
|
return it->second;
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
if (auto it = webSocketMap.find(id); it != webSocketMap.end())
|
||||||
|
return it->second;
|
||||||
|
#endif
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void rtcInitLogger(rtcLogLevel level) { InitLogger(static_cast<LogLevel>(level)); }
|
void rtcInitLogger(rtcLogLevel level) { InitLogger(static_cast<LogLevel>(level)); }
|
||||||
@ -164,6 +206,29 @@ int rtcDeleteDataChannel(int dc) {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
int rtcCreateWebSocket(const char *url) {
|
||||||
|
return emplaceWebSocket(std::make_shared<WebSocket>(url));
|
||||||
|
}
|
||||||
|
|
||||||
|
int rtcDeleteWebsocket(int ws) {
|
||||||
|
auto webSocket = getWebSocket(ws);
|
||||||
|
if (!webSocket)
|
||||||
|
return -1;
|
||||||
|
|
||||||
|
webSocket->onOpen(nullptr);
|
||||||
|
webSocket->onClosed(nullptr);
|
||||||
|
webSocket->onError(nullptr);
|
||||||
|
webSocket->onMessage(nullptr);
|
||||||
|
webSocket->onBufferedAmountLow(nullptr);
|
||||||
|
webSocket->onAvailable(nullptr);
|
||||||
|
|
||||||
|
eraseWebSocket(ws);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
int rtcSetDataChannelCallback(int pc, dataChannelCallbackFunc cb) {
|
int rtcSetDataChannelCallback(int pc, dataChannelCallbackFunc cb) {
|
||||||
auto peerConnection = getPeerConnection(pc);
|
auto peerConnection = getPeerConnection(pc);
|
||||||
if (!peerConnection)
|
if (!peerConnection)
|
||||||
@ -298,135 +363,135 @@ int rtcGetDataChannelLabel(int dc, char *buffer, int size) {
|
|||||||
return size + 1;
|
return size + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int rtcSetOpenCallback(int dc, openCallbackFunc cb) {
|
int rtcSetOpenCallback(int id, openCallbackFunc cb) {
|
||||||
auto dataChannel = getDataChannel(dc);
|
auto channel = getChannel(id);
|
||||||
if (!dataChannel)
|
if (!channel)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
if (cb)
|
if (cb)
|
||||||
dataChannel->onOpen([dc, cb]() { cb(getUserPointer(dc)); });
|
channel->onOpen([id, cb]() { cb(getUserPointer(id)); });
|
||||||
else
|
else
|
||||||
dataChannel->onOpen(nullptr);
|
channel->onOpen(nullptr);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int rtcSetClosedCallback(int dc, closedCallbackFunc cb) {
|
int rtcSetClosedCallback(int id, closedCallbackFunc cb) {
|
||||||
auto dataChannel = getDataChannel(dc);
|
auto channel = getChannel(id);
|
||||||
if (!dataChannel)
|
if (!channel)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
if (cb)
|
if (cb)
|
||||||
dataChannel->onClosed([dc, cb]() { cb(getUserPointer(dc)); });
|
channel->onClosed([id, cb]() { cb(getUserPointer(id)); });
|
||||||
else
|
else
|
||||||
dataChannel->onClosed(nullptr);
|
channel->onClosed(nullptr);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int rtcSetErrorCallback(int dc, errorCallbackFunc cb) {
|
int rtcSetErrorCallback(int id, errorCallbackFunc cb) {
|
||||||
auto dataChannel = getDataChannel(dc);
|
auto channel = getChannel(id);
|
||||||
if (!dataChannel)
|
if (!channel)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
if (cb)
|
if (cb)
|
||||||
dataChannel->onError(
|
channel->onError([id, cb](const string &error) { cb(error.c_str(), getUserPointer(id)); });
|
||||||
[dc, cb](const string &error) { cb(error.c_str(), getUserPointer(dc)); });
|
|
||||||
else
|
else
|
||||||
dataChannel->onError(nullptr);
|
channel->onError(nullptr);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int rtcSetMessageCallback(int dc, messageCallbackFunc cb) {
|
int rtcSetMessageCallback(int id, messageCallbackFunc cb) {
|
||||||
auto dataChannel = getDataChannel(dc);
|
auto channel = getChannel(id);
|
||||||
if (!dataChannel)
|
if (!channel)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
if (cb)
|
if (cb)
|
||||||
dataChannel->onMessage(
|
channel->onMessage(
|
||||||
[dc, cb](const binary &b) {
|
[id, cb](const binary &b) {
|
||||||
cb(reinterpret_cast<const char *>(b.data()), b.size(), getUserPointer(dc));
|
cb(reinterpret_cast<const char *>(b.data()), b.size(), getUserPointer(id));
|
||||||
},
|
},
|
||||||
[dc, cb](const string &s) { cb(s.c_str(), -1, getUserPointer(dc)); });
|
[id, cb](const string &s) { cb(s.c_str(), -1, getUserPointer(id)); });
|
||||||
else
|
else
|
||||||
dataChannel->onMessage(nullptr);
|
channel->onMessage(nullptr);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int rtcSendMessage(int dc, const char *data, int size) {
|
int rtcSendMessage(int id, const char *data, int size) {
|
||||||
auto dataChannel = getDataChannel(dc);
|
auto channel = getChannel(id);
|
||||||
if (!dataChannel)
|
if (!channel)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
if (size >= 0) {
|
if (size >= 0) {
|
||||||
auto b = reinterpret_cast<const byte *>(data);
|
auto b = reinterpret_cast<const byte *>(data);
|
||||||
CATCH(dataChannel->send(b, size));
|
CATCH(channel->send(binary(b, b + size)));
|
||||||
return size;
|
return size;
|
||||||
} else {
|
} else {
|
||||||
string s(data);
|
string str(data);
|
||||||
CATCH(dataChannel->send(s));
|
int len = str.size();
|
||||||
return s.size();
|
CATCH(channel->send(std::move(str)));
|
||||||
|
return len;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int rtcGetBufferedAmount(int dc) {
|
int rtcGetBufferedAmount(int id) {
|
||||||
auto dataChannel = getDataChannel(dc);
|
auto channel = getChannel(id);
|
||||||
if (!dataChannel)
|
if (!channel)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
CATCH(return int(dataChannel->bufferedAmount()));
|
CATCH(return int(channel->bufferedAmount()));
|
||||||
}
|
}
|
||||||
|
|
||||||
int rtcSetBufferedAmountLowThreshold(int dc, int amount) {
|
int rtcSetBufferedAmountLowThreshold(int id, int amount) {
|
||||||
auto dataChannel = getDataChannel(dc);
|
auto channel = getChannel(id);
|
||||||
if (!dataChannel)
|
if (!channel)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
CATCH(dataChannel->setBufferedAmountLowThreshold(size_t(amount)));
|
CATCH(channel->setBufferedAmountLowThreshold(size_t(amount)));
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int rtcSetBufferedAmountLowCallback(int dc, bufferedAmountLowCallbackFunc cb) {
|
int rtcSetBufferedAmountLowCallback(int id, bufferedAmountLowCallbackFunc cb) {
|
||||||
auto dataChannel = getDataChannel(dc);
|
auto channel = getChannel(id);
|
||||||
if (!dataChannel)
|
if (!channel)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
if (cb)
|
if (cb)
|
||||||
dataChannel->onBufferedAmountLow([dc, cb]() { cb(getUserPointer(dc)); });
|
channel->onBufferedAmountLow([id, cb]() { cb(getUserPointer(id)); });
|
||||||
else
|
else
|
||||||
dataChannel->onBufferedAmountLow(nullptr);
|
channel->onBufferedAmountLow(nullptr);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int rtcGetAvailableAmount(int dc) {
|
int rtcGetAvailableAmount(int id) {
|
||||||
auto dataChannel = getDataChannel(dc);
|
auto channel = getChannel(id);
|
||||||
if (!dataChannel)
|
if (!channel)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
CATCH(return int(dataChannel->availableAmount()));
|
CATCH(return int(channel->availableAmount()));
|
||||||
}
|
}
|
||||||
|
|
||||||
int rtcSetAvailableCallback(int dc, availableCallbackFunc cb) {
|
int rtcSetAvailableCallback(int id, availableCallbackFunc cb) {
|
||||||
auto dataChannel = getDataChannel(dc);
|
auto channel = getChannel(id);
|
||||||
if (!dataChannel)
|
if (!channel)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
if (cb)
|
if (cb)
|
||||||
dataChannel->onOpen([dc, cb]() { cb(getUserPointer(dc)); });
|
channel->onOpen([id, cb]() { cb(getUserPointer(id)); });
|
||||||
else
|
else
|
||||||
dataChannel->onOpen(nullptr);
|
channel->onOpen(nullptr);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int rtcReceiveMessage(int dc, char *buffer, int *size) {
|
int rtcReceiveMessage(int id, char *buffer, int *size) {
|
||||||
auto dataChannel = getDataChannel(dc);
|
auto channel = getChannel(id);
|
||||||
if (!dataChannel)
|
if (!channel)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
if (!size)
|
if (!size)
|
||||||
return -1;
|
return -1;
|
||||||
|
|
||||||
CATCH({
|
CATCH({
|
||||||
auto message = dataChannel->receive();
|
auto message = channel->receive();
|
||||||
if (!message)
|
if (!message)
|
||||||
return 0;
|
return 0;
|
||||||
|
|
||||||
|
@ -71,9 +71,8 @@ void SctpTransport::Cleanup() {
|
|||||||
SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
|
SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
|
||||||
message_callback recvCallback, amount_callback bufferedAmountCallback,
|
message_callback recvCallback, amount_callback bufferedAmountCallback,
|
||||||
state_callback stateChangeCallback)
|
state_callback stateChangeCallback)
|
||||||
: Transport(lower), mPort(port), mSendQueue(0, message_size_func),
|
: Transport(lower, std::move(stateChangeCallback)), mPort(port),
|
||||||
mBufferedAmountCallback(std::move(bufferedAmountCallback)),
|
mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
|
||||||
mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) {
|
|
||||||
onRecv(recvCallback);
|
onRecv(recvCallback);
|
||||||
|
|
||||||
PLOG_DEBUG << "Initializing SCTP transport";
|
PLOG_DEBUG << "Initializing SCTP transport";
|
||||||
@ -180,8 +179,6 @@ SctpTransport::~SctpTransport() {
|
|||||||
usrsctp_deregister_address(this);
|
usrsctp_deregister_address(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
SctpTransport::State SctpTransport::state() const { return mState; }
|
|
||||||
|
|
||||||
bool SctpTransport::stop() {
|
bool SctpTransport::stop() {
|
||||||
if (!Transport::stop())
|
if (!Transport::stop())
|
||||||
return false;
|
return false;
|
||||||
@ -240,6 +237,7 @@ void SctpTransport::shutdown() {
|
|||||||
|
|
||||||
bool SctpTransport::send(message_ptr message) {
|
bool SctpTransport::send(message_ptr message) {
|
||||||
std::lock_guard lock(mSendMutex);
|
std::lock_guard lock(mSendMutex);
|
||||||
|
|
||||||
if (!message)
|
if (!message)
|
||||||
return mSendQueue.empty();
|
return mSendQueue.empty();
|
||||||
|
|
||||||
@ -269,7 +267,7 @@ void SctpTransport::incoming(message_ptr message) {
|
|||||||
// to be sent on our side (i.e. the local INIT) before proceeding.
|
// to be sent on our side (i.e. the local INIT) before proceeding.
|
||||||
{
|
{
|
||||||
std::unique_lock lock(mWriteMutex);
|
std::unique_lock lock(mWriteMutex);
|
||||||
mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || mState != State::Connected; });
|
mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || state() != State::Connected; });
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!message) {
|
if (!message) {
|
||||||
@ -283,11 +281,6 @@ void SctpTransport::incoming(message_ptr message) {
|
|||||||
usrsctp_conninput(this, message->data(), message->size(), 0);
|
usrsctp_conninput(this, message->data(), message->size(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SctpTransport::changeState(State state) {
|
|
||||||
if (mState.exchange(state) != state)
|
|
||||||
mStateChangeCallback(state);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool SctpTransport::trySendQueue() {
|
bool SctpTransport::trySendQueue() {
|
||||||
// Requires mSendMutex to be locked
|
// Requires mSendMutex to be locked
|
||||||
while (auto next = mSendQueue.peek()) {
|
while (auto next = mSendQueue.peek()) {
|
||||||
@ -302,7 +295,7 @@ bool SctpTransport::trySendQueue() {
|
|||||||
|
|
||||||
bool SctpTransport::trySendMessage(message_ptr message) {
|
bool SctpTransport::trySendMessage(message_ptr message) {
|
||||||
// Requires mSendMutex to be locked
|
// Requires mSendMutex to be locked
|
||||||
if (!mSock || mState != State::Connected)
|
if (!mSock || state() != State::Connected)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
uint32_t ppid;
|
uint32_t ppid;
|
||||||
@ -414,7 +407,7 @@ void SctpTransport::sendReset(uint16_t streamId) {
|
|||||||
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) {
|
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) {
|
||||||
std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp...
|
std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp...
|
||||||
mWrittenCondition.wait_for(lock, 1000ms,
|
mWrittenCondition.wait_for(lock, 1000ms,
|
||||||
[&]() { return mWritten || mState != State::Connected; });
|
[&]() { return mWritten || state() != State::Connected; });
|
||||||
} else if (errno == EINVAL) {
|
} else if (errno == EINVAL) {
|
||||||
PLOG_VERBOSE << "SCTP stream " << streamId << " already reset";
|
PLOG_VERBOSE << "SCTP stream " << streamId << " already reset";
|
||||||
} else {
|
} else {
|
||||||
@ -571,7 +564,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
|
|||||||
PLOG_INFO << "SCTP connected";
|
PLOG_INFO << "SCTP connected";
|
||||||
changeState(State::Connected);
|
changeState(State::Connected);
|
||||||
} else {
|
} else {
|
||||||
if (mState == State::Connecting) {
|
if (state() == State::Connecting) {
|
||||||
PLOG_ERROR << "SCTP connection failed";
|
PLOG_ERROR << "SCTP connection failed";
|
||||||
changeState(State::Failed);
|
changeState(State::Failed);
|
||||||
} else {
|
} else {
|
||||||
|
@ -38,17 +38,12 @@ public:
|
|||||||
static void Init();
|
static void Init();
|
||||||
static void Cleanup();
|
static void Cleanup();
|
||||||
|
|
||||||
enum class State { Disconnected, Connecting, Connected, Failed };
|
|
||||||
|
|
||||||
using amount_callback = std::function<void(uint16_t streamId, size_t amount)>;
|
using amount_callback = std::function<void(uint16_t streamId, size_t amount)>;
|
||||||
using state_callback = std::function<void(State state)>;
|
|
||||||
|
|
||||||
SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recvCallback,
|
SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recvCallback,
|
||||||
amount_callback bufferedAmountCallback, state_callback stateChangeCallback);
|
amount_callback bufferedAmountCallback, state_callback stateChangeCallback);
|
||||||
~SctpTransport();
|
~SctpTransport();
|
||||||
|
|
||||||
State state() const;
|
|
||||||
|
|
||||||
bool stop() override;
|
bool stop() override;
|
||||||
bool send(message_ptr message) override; // false if buffered
|
bool send(message_ptr message) override; // false if buffered
|
||||||
void close(unsigned int stream);
|
void close(unsigned int stream);
|
||||||
@ -76,7 +71,6 @@ private:
|
|||||||
void connect();
|
void connect();
|
||||||
void shutdown();
|
void shutdown();
|
||||||
void incoming(message_ptr message) override;
|
void incoming(message_ptr message) override;
|
||||||
void changeState(State state);
|
|
||||||
|
|
||||||
bool trySendQueue();
|
bool trySendQueue();
|
||||||
bool trySendMessage(message_ptr message);
|
bool trySendMessage(message_ptr message);
|
||||||
@ -105,14 +99,11 @@ private:
|
|||||||
std::atomic<bool> mWritten = false; // written outside lock
|
std::atomic<bool> mWritten = false; // written outside lock
|
||||||
bool mWrittenOnce = false;
|
bool mWrittenOnce = false;
|
||||||
|
|
||||||
state_callback mStateChangeCallback;
|
binary mPartialRecv, mPartialStringData, mPartialBinaryData;
|
||||||
std::atomic<State> mState;
|
|
||||||
|
|
||||||
// Stats
|
// Stats
|
||||||
std::atomic<size_t> mBytesSent = 0, mBytesReceived = 0;
|
std::atomic<size_t> mBytesSent = 0, mBytesReceived = 0;
|
||||||
|
|
||||||
binary mPartialRecv, mPartialStringData, mPartialBinaryData;
|
|
||||||
|
|
||||||
static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
|
static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
|
||||||
struct sctp_rcvinfo recv_info, int flags, void *user_data);
|
struct sctp_rcvinfo recv_info, int flags, void *user_data);
|
||||||
static int SendCallback(struct socket *sock, uint32_t sb_free);
|
static int SendCallback(struct socket *sock, uint32_t sb_free);
|
||||||
|
320
src/tcptransport.cpp
Normal file
320
src/tcptransport.cpp
Normal file
@ -0,0 +1,320 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2020 Paul-Louis Ageneau
|
||||||
|
*
|
||||||
|
* This library is free software; you can redistribute it and/or
|
||||||
|
* modify it under the terms of the GNU Lesser General Public
|
||||||
|
* License as published by the Free Software Foundation; either
|
||||||
|
* version 2.1 of the License, or (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This library is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||||
|
* Lesser General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU Lesser General Public
|
||||||
|
* License along with this library; if not, write to the Free Software
|
||||||
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
|
*/
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
|
||||||
|
#include "tcptransport.hpp"
|
||||||
|
|
||||||
|
#include <exception>
|
||||||
|
#ifndef _WIN32
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace rtc {
|
||||||
|
|
||||||
|
using std::to_string;
|
||||||
|
|
||||||
|
SelectInterrupter::SelectInterrupter() {
|
||||||
|
#ifndef _WIN32
|
||||||
|
int pipefd[2];
|
||||||
|
if (::pipe(pipefd) != 0)
|
||||||
|
throw std::runtime_error("Failed to create pipe");
|
||||||
|
::fcntl(pipefd[0], F_SETFL, O_NONBLOCK);
|
||||||
|
::fcntl(pipefd[1], F_SETFL, O_NONBLOCK);
|
||||||
|
mPipeOut = pipefd[0]; // read
|
||||||
|
mPipeIn = pipefd[1]; // write
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
SelectInterrupter::~SelectInterrupter() {
|
||||||
|
std::lock_guard lock(mMutex);
|
||||||
|
#ifdef _WIN32
|
||||||
|
if (mDummySock != INVALID_SOCKET)
|
||||||
|
::closesocket(mDummySock);
|
||||||
|
#else
|
||||||
|
::close(mPipeIn);
|
||||||
|
::close(mPipeOut);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
int SelectInterrupter::prepare(fd_set &readfds, fd_set &writefds) {
|
||||||
|
std::lock_guard lock(mMutex);
|
||||||
|
#ifdef _WIN32
|
||||||
|
if (mDummySock == INVALID_SOCKET)
|
||||||
|
mDummySock = ::socket(AF_INET, SOCK_DGRAM, 0);
|
||||||
|
FD_SET(mDummySock, &readfds);
|
||||||
|
return SOCK_TO_INT(mDummySock) + 1;
|
||||||
|
#else
|
||||||
|
int ret;
|
||||||
|
do {
|
||||||
|
char dummy;
|
||||||
|
ret = ::read(mPipeIn, &dummy, 1);
|
||||||
|
} while (ret > 0);
|
||||||
|
FD_SET(mPipeIn, &readfds);
|
||||||
|
return mPipeIn + 1;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void SelectInterrupter::interrupt() {
|
||||||
|
std::lock_guard lock(mMutex);
|
||||||
|
#ifdef _WIN32
|
||||||
|
if (mDummySock != INVALID_SOCKET) {
|
||||||
|
::closesocket(mDummySock);
|
||||||
|
mDummySock = INVALID_SOCKET;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
char dummy = 0;
|
||||||
|
::write(mPipeOut, &dummy, 1);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
TcpTransport::TcpTransport(const string &hostname, const string &service, state_callback callback)
|
||||||
|
: Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) {
|
||||||
|
|
||||||
|
PLOG_DEBUG << "Initializing TCP transport";
|
||||||
|
mThread = std::thread(&TcpTransport::runLoop, this);
|
||||||
|
}
|
||||||
|
|
||||||
|
TcpTransport::~TcpTransport() {
|
||||||
|
stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TcpTransport::stop() {
|
||||||
|
if (!Transport::stop())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
PLOG_DEBUG << "Waiting TCP recv thread";
|
||||||
|
close();
|
||||||
|
mThread.join();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TcpTransport::send(message_ptr message) {
|
||||||
|
if (!message)
|
||||||
|
return mSendQueue.empty();
|
||||||
|
|
||||||
|
PLOG_VERBOSE << "Send size=" << (message ? message->size() : 0);
|
||||||
|
|
||||||
|
return outgoing(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpTransport::incoming(message_ptr message) { recv(message); }
|
||||||
|
|
||||||
|
bool TcpTransport::outgoing(message_ptr message) {
|
||||||
|
// If nothing is pending, try to send directly
|
||||||
|
// It's safe because if the queue is empty, the thread is not sending
|
||||||
|
if (mSendQueue.empty() && trySendMessage(message))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
mSendQueue.push(message);
|
||||||
|
interruptSelect(); // so the thread waits for writability
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpTransport::connect(const string &hostname, const string &service) {
|
||||||
|
PLOG_DEBUG << "Connecting to " << hostname << ":" << service;
|
||||||
|
|
||||||
|
struct addrinfo hints = {};
|
||||||
|
hints.ai_family = AF_UNSPEC;
|
||||||
|
hints.ai_socktype = SOCK_STREAM;
|
||||||
|
hints.ai_protocol = IPPROTO_TCP;
|
||||||
|
hints.ai_flags = AI_ADDRCONFIG;
|
||||||
|
|
||||||
|
struct addrinfo *result = nullptr;
|
||||||
|
if (getaddrinfo(hostname.c_str(), service.c_str(), &hints, &result))
|
||||||
|
throw std::runtime_error("Resolution failed for \"" + hostname + ":" + service + "\"");
|
||||||
|
|
||||||
|
for (auto p = result; p; p = p->ai_next)
|
||||||
|
try {
|
||||||
|
connect(p->ai_addr, p->ai_addrlen);
|
||||||
|
freeaddrinfo(result);
|
||||||
|
return;
|
||||||
|
} catch (const std::runtime_error &e) {
|
||||||
|
PLOG_WARNING << e.what();
|
||||||
|
}
|
||||||
|
|
||||||
|
freeaddrinfo(result);
|
||||||
|
throw std::runtime_error("Connection failed to \"" + hostname + ":" + service + "\"");
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpTransport::connect(const sockaddr *addr, socklen_t addrlen) {
|
||||||
|
try {
|
||||||
|
PLOG_DEBUG << "Creating TCP socket";
|
||||||
|
|
||||||
|
// Create socket
|
||||||
|
mSock = ::socket(addr->sa_family, SOCK_STREAM, IPPROTO_TCP);
|
||||||
|
if (mSock == INVALID_SOCKET)
|
||||||
|
throw std::runtime_error("TCP socket creation failed");
|
||||||
|
|
||||||
|
ctl_t b = 1;
|
||||||
|
if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
|
||||||
|
throw std::runtime_error("Failed to set socket non-blocking mode");
|
||||||
|
|
||||||
|
IF_PLOG(plog::debug) {
|
||||||
|
char node[MAX_NUMERICNODE_LEN];
|
||||||
|
char serv[MAX_NUMERICSERV_LEN];
|
||||||
|
if (getnameinfo(addr, addrlen, node, MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
|
||||||
|
NI_NUMERICHOST | NI_NUMERICSERV) == 0) {
|
||||||
|
PLOG_DEBUG << "Trying address " << node << ":" << serv;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initiate connection
|
||||||
|
::connect(mSock, addr, addrlen);
|
||||||
|
|
||||||
|
fd_set writefds;
|
||||||
|
FD_ZERO(&writefds);
|
||||||
|
FD_SET(mSock, &writefds);
|
||||||
|
struct timeval tv;
|
||||||
|
tv.tv_sec = 10; // TODO
|
||||||
|
tv.tv_usec = 0;
|
||||||
|
int ret = ::select(SOCKET_TO_INT(mSock) + 1, NULL, &writefds, NULL, &tv);
|
||||||
|
|
||||||
|
if (ret < 0)
|
||||||
|
throw std::runtime_error("Failed to wait for socket connection");
|
||||||
|
|
||||||
|
if (ret == 0 || ::send(mSock, NULL, 0, MSG_NOSIGNAL) != 0)
|
||||||
|
throw std::runtime_error("Connection failed");
|
||||||
|
|
||||||
|
} catch (...) {
|
||||||
|
if (mSock != INVALID_SOCKET) {
|
||||||
|
::closesocket(mSock);
|
||||||
|
mSock = INVALID_SOCKET;
|
||||||
|
}
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpTransport::close() {
|
||||||
|
if (mSock != INVALID_SOCKET) {
|
||||||
|
PLOG_DEBUG << "Closing TCP socket";
|
||||||
|
::closesocket(mSock);
|
||||||
|
mSock = INVALID_SOCKET;
|
||||||
|
}
|
||||||
|
changeState(State::Disconnected);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TcpTransport::trySendQueue() {
|
||||||
|
while (auto next = mSendQueue.peek()) {
|
||||||
|
auto message = *next;
|
||||||
|
if (!trySendMessage(message)) {
|
||||||
|
mSendQueue.exchange(message);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
mSendQueue.pop();
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TcpTransport::trySendMessage(message_ptr &message) {
|
||||||
|
auto data = reinterpret_cast<const char *>(message->data());
|
||||||
|
auto size = message->size();
|
||||||
|
while (size) {
|
||||||
|
int len = ::send(mSock, data, size, MSG_NOSIGNAL);
|
||||||
|
if (len < 0) {
|
||||||
|
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||||
|
message = make_message(message->end() - size, message->end());
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Connection lost, errno=" + to_string(sockerrno));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data += len;
|
||||||
|
size -= len;
|
||||||
|
}
|
||||||
|
message = nullptr;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpTransport::runLoop() {
|
||||||
|
const size_t bufferSize = 4096;
|
||||||
|
|
||||||
|
// Connect
|
||||||
|
try {
|
||||||
|
changeState(State::Connecting);
|
||||||
|
connect(mHostname, mService);
|
||||||
|
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
PLOG_ERROR << "TCP connect: " << e.what();
|
||||||
|
changeState(State::Failed);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Receive loop
|
||||||
|
try {
|
||||||
|
PLOG_INFO << "TCP connected";
|
||||||
|
changeState(State::Connected);
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
fd_set readfds, writefds;
|
||||||
|
int n = prepareSelect(readfds, writefds);
|
||||||
|
int ret = ::select(n, &readfds, &writefds, NULL, NULL);
|
||||||
|
if (ret < 0)
|
||||||
|
throw std::runtime_error("Failed to wait on socket");
|
||||||
|
|
||||||
|
if (FD_ISSET(mSock, &writefds))
|
||||||
|
trySendQueue();
|
||||||
|
|
||||||
|
if (FD_ISSET(mSock, &readfds)) {
|
||||||
|
char buffer[bufferSize];
|
||||||
|
int len = ::recv(mSock, buffer, bufferSize, 0);
|
||||||
|
if (len < 0) {
|
||||||
|
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Connection lost, errno=" + to_string(sockerrno));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (len == 0)
|
||||||
|
break; // clean close
|
||||||
|
|
||||||
|
auto *b = reinterpret_cast<byte *>(buffer);
|
||||||
|
incoming(make_message(b, b + len));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
PLOG_ERROR << "TCP recv: " << e.what();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLOG_INFO << "TCP disconnected";
|
||||||
|
changeState(State::Disconnected);
|
||||||
|
recv(nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
int TcpTransport::prepareSelect(fd_set &readfds, fd_set &writefds) {
|
||||||
|
FD_ZERO(&readfds);
|
||||||
|
FD_ZERO(&writefds);
|
||||||
|
FD_SET(mSock, &readfds);
|
||||||
|
|
||||||
|
if (!mSendQueue.empty())
|
||||||
|
FD_SET(mSock, &writefds);
|
||||||
|
|
||||||
|
int n = SOCKET_TO_INT(mSock) + 1;
|
||||||
|
int m = mInterrupter.prepare(readfds, writefds);
|
||||||
|
return std::max(n, m);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TcpTransport::interruptSelect() { mInterrupter.interrupt(); }
|
||||||
|
|
||||||
|
} // namespace rtc
|
||||||
|
|
||||||
|
#endif
|
90
src/tcptransport.hpp
Normal file
90
src/tcptransport.hpp
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2020 Paul-Louis Ageneau
|
||||||
|
*
|
||||||
|
* This library is free software; you can redistribute it and/or
|
||||||
|
* modify it under the terms of the GNU Lesser General Public
|
||||||
|
* License as published by the Free Software Foundation; either
|
||||||
|
* version 2.1 of the License, or (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This library is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||||
|
* Lesser General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU Lesser General Public
|
||||||
|
* License along with this library; if not, write to the Free Software
|
||||||
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef RTC_TCP_TRANSPORT_H
|
||||||
|
#define RTC_TCP_TRANSPORT_H
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
|
||||||
|
#include "include.hpp"
|
||||||
|
#include "queue.hpp"
|
||||||
|
#include "transport.hpp"
|
||||||
|
|
||||||
|
#include <mutex>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
// Use the socket defines from libjuice
|
||||||
|
#include "../deps/libjuice/src/socket.h"
|
||||||
|
|
||||||
|
namespace rtc {
|
||||||
|
|
||||||
|
// Utility class to interrupt select()
|
||||||
|
class SelectInterrupter {
|
||||||
|
public:
|
||||||
|
SelectInterrupter();
|
||||||
|
~SelectInterrupter();
|
||||||
|
|
||||||
|
int prepare(fd_set &readfds, fd_set &writefds);
|
||||||
|
void interrupt();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mutex mMutex;
|
||||||
|
#ifdef _WIN32
|
||||||
|
socket_t mDummySock = INVALID_SOCKET;
|
||||||
|
#else // assume POSIX
|
||||||
|
int mPipeIn, mPipeOut;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
class TcpTransport : public Transport {
|
||||||
|
public:
|
||||||
|
TcpTransport(const string &hostname, const string &service, state_callback callback);
|
||||||
|
~TcpTransport();
|
||||||
|
|
||||||
|
bool stop() override;
|
||||||
|
bool send(message_ptr message) override;
|
||||||
|
|
||||||
|
void incoming(message_ptr message) override;
|
||||||
|
bool outgoing(message_ptr message) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void connect(const string &hostname, const string &service);
|
||||||
|
void connect(const sockaddr *addr, socklen_t addrlen);
|
||||||
|
void close();
|
||||||
|
|
||||||
|
bool trySendQueue();
|
||||||
|
bool trySendMessage(message_ptr &message);
|
||||||
|
|
||||||
|
void runLoop();
|
||||||
|
|
||||||
|
int prepareSelect(fd_set &readfds, fd_set &writefds);
|
||||||
|
void interruptSelect();
|
||||||
|
|
||||||
|
string mHostname, mService;
|
||||||
|
|
||||||
|
socket_t mSock = INVALID_SOCKET;
|
||||||
|
std::thread mThread;
|
||||||
|
SelectInterrupter mInterrupter;
|
||||||
|
Queue<message_ptr> mSendQueue;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace rtc
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
432
src/tlstransport.cpp
Normal file
432
src/tlstransport.cpp
Normal file
@ -0,0 +1,432 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2020 Paul-Louis Ageneau
|
||||||
|
*
|
||||||
|
* This library is free software; you can redistribute it and/or
|
||||||
|
* modify it under the terms of the GNU Lesser General Public
|
||||||
|
* License as published by the Free Software Foundation; either
|
||||||
|
* version 2.1 of the License, or (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This library is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||||
|
* Lesser General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU Lesser General Public
|
||||||
|
* License along with this library; if not, write to the Free Software
|
||||||
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
|
*/
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
|
||||||
|
#include "tlstransport.hpp"
|
||||||
|
#include "tcptransport.hpp"
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
#include <cstring>
|
||||||
|
#include <exception>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
using namespace std::chrono;
|
||||||
|
|
||||||
|
using std::shared_ptr;
|
||||||
|
using std::string;
|
||||||
|
using std::unique_ptr;
|
||||||
|
using std::weak_ptr;
|
||||||
|
|
||||||
|
#if USE_GNUTLS
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
|
||||||
|
if (ret < 0) {
|
||||||
|
if (!gnutls_error_is_fatal(ret)) {
|
||||||
|
PLOG_INFO << gnutls_strerror(ret);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
PLOG_ERROR << message << ": " << gnutls_strerror(ret);
|
||||||
|
throw std::runtime_error(message + ": " + gnutls_strerror(ret));
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace rtc {
|
||||||
|
|
||||||
|
void TlsTransport::Init() {
|
||||||
|
// Nothing to do
|
||||||
|
}
|
||||||
|
|
||||||
|
void TlsTransport::Cleanup() {
|
||||||
|
// Nothing to do
|
||||||
|
}
|
||||||
|
|
||||||
|
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
|
||||||
|
: Transport(lower, std::move(callback)) {
|
||||||
|
|
||||||
|
PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
|
||||||
|
|
||||||
|
check_gnutls(gnutls_init(&mSession, GNUTLS_CLIENT));
|
||||||
|
|
||||||
|
try {
|
||||||
|
const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
|
||||||
|
const char *err_pos = NULL;
|
||||||
|
check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
|
||||||
|
"Failed to set TLS priorities");
|
||||||
|
|
||||||
|
gnutls_session_set_ptr(mSession, this);
|
||||||
|
gnutls_transport_set_ptr(mSession, this);
|
||||||
|
gnutls_transport_set_push_function(mSession, WriteCallback);
|
||||||
|
gnutls_transport_set_pull_function(mSession, ReadCallback);
|
||||||
|
gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
|
||||||
|
|
||||||
|
gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size());
|
||||||
|
|
||||||
|
mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
|
||||||
|
registerIncoming();
|
||||||
|
|
||||||
|
} catch (...) {
|
||||||
|
|
||||||
|
gnutls_deinit(mSession);
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TlsTransport::~TlsTransport() {
|
||||||
|
stop();
|
||||||
|
gnutls_deinit(mSession);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TlsTransport::stop() {
|
||||||
|
if (!Transport::stop())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
PLOG_DEBUG << "Stopping TLS recv thread";
|
||||||
|
mIncomingQueue.stop();
|
||||||
|
mRecvThread.join();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TlsTransport::send(message_ptr message) {
|
||||||
|
if (!message)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
PLOG_VERBOSE << "Send size=" << message->size();
|
||||||
|
|
||||||
|
ssize_t ret;
|
||||||
|
do {
|
||||||
|
ret = gnutls_record_send(mSession, message->data(), message->size());
|
||||||
|
} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
|
||||||
|
|
||||||
|
return check_gnutls(ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TlsTransport::incoming(message_ptr message) {
|
||||||
|
if (message)
|
||||||
|
mIncomingQueue.push(message);
|
||||||
|
else
|
||||||
|
mIncomingQueue.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TlsTransport::runRecvLoop() {
|
||||||
|
const size_t bufferSize = 4096;
|
||||||
|
char buffer[bufferSize];
|
||||||
|
|
||||||
|
// Handshake loop
|
||||||
|
try {
|
||||||
|
changeState(State::Connecting);
|
||||||
|
|
||||||
|
int ret;
|
||||||
|
do {
|
||||||
|
ret = gnutls_handshake(mSession);
|
||||||
|
} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
|
||||||
|
!check_gnutls(ret, "TLS handshake failed"));
|
||||||
|
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
PLOG_ERROR << "TLS handshake: " << e.what();
|
||||||
|
changeState(State::Failed);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive loop
|
||||||
|
try {
|
||||||
|
PLOG_INFO << "TLS handshake finished";
|
||||||
|
changeState(State::Connected);
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
ssize_t ret;
|
||||||
|
do {
|
||||||
|
ret = gnutls_record_recv(mSession, buffer, bufferSize);
|
||||||
|
} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
|
||||||
|
|
||||||
|
// Consider premature termination as remote closing
|
||||||
|
if (ret == GNUTLS_E_PREMATURE_TERMINATION) {
|
||||||
|
PLOG_DEBUG << "TLS connection terminated";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (check_gnutls(ret)) {
|
||||||
|
if (ret == 0) {
|
||||||
|
// Closed
|
||||||
|
PLOG_DEBUG << "TLS connection cleanly closed";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
auto *b = reinterpret_cast<byte *>(buffer);
|
||||||
|
recv(make_message(b, b + ret));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
PLOG_ERROR << "TLS recv: " << e.what();
|
||||||
|
}
|
||||||
|
|
||||||
|
gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
|
||||||
|
|
||||||
|
PLOG_INFO << "TLS closed";
|
||||||
|
changeState(State::Disconnected);
|
||||||
|
recv(nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
|
||||||
|
TlsTransport *t = static_cast<TlsTransport *>(ptr);
|
||||||
|
if (len > 0) {
|
||||||
|
auto b = reinterpret_cast<const byte *>(data);
|
||||||
|
t->outgoing(make_message(b, b + len));
|
||||||
|
}
|
||||||
|
gnutls_transport_set_errno(t->mSession, 0);
|
||||||
|
return ssize_t(len);
|
||||||
|
}
|
||||||
|
|
||||||
|
ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
|
||||||
|
TlsTransport *t = static_cast<TlsTransport *>(ptr);
|
||||||
|
if (auto next = t->mIncomingQueue.pop()) {
|
||||||
|
auto message = *next;
|
||||||
|
ssize_t len = std::min(maxlen, message->size());
|
||||||
|
std::memcpy(data, message->data(), len);
|
||||||
|
gnutls_transport_set_errno(t->mSession, 0);
|
||||||
|
return len;
|
||||||
|
}
|
||||||
|
// Closed
|
||||||
|
gnutls_transport_set_errno(t->mSession, 0);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
|
||||||
|
TlsTransport *t = static_cast<TlsTransport *>(ptr);
|
||||||
|
if (ms != GNUTLS_INDEFINITE_TIMEOUT)
|
||||||
|
t->mIncomingQueue.wait(milliseconds(ms));
|
||||||
|
else
|
||||||
|
t->mIncomingQueue.wait();
|
||||||
|
return !t->mIncomingQueue.empty() ? 1 : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace rtc
|
||||||
|
|
||||||
|
#else // USE_GNUTLS==0
|
||||||
|
|
||||||
|
#include <openssl/bio.h>
|
||||||
|
#include <openssl/ec.h>
|
||||||
|
#include <openssl/err.h>
|
||||||
|
#include <openssl/ssl.h>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
const int BIO_EOF = -1;
|
||||||
|
|
||||||
|
string openssl_error_string(unsigned long err) {
|
||||||
|
const size_t bufferSize = 256;
|
||||||
|
char buffer[bufferSize];
|
||||||
|
ERR_error_string_n(err, buffer, bufferSize);
|
||||||
|
return string(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool check_openssl(int success, const string &message = "OpenSSL error") {
|
||||||
|
if (success)
|
||||||
|
return true;
|
||||||
|
|
||||||
|
string str = openssl_error_string(ERR_get_error());
|
||||||
|
PLOG_ERROR << message << ": " << str;
|
||||||
|
throw std::runtime_error(message + ": " + str);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error") {
|
||||||
|
if (ret == BIO_EOF)
|
||||||
|
return true;
|
||||||
|
|
||||||
|
unsigned long err = SSL_get_error(ssl, ret);
|
||||||
|
if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (err == SSL_ERROR_ZERO_RETURN) {
|
||||||
|
PLOG_DEBUG << "TLS connection cleanly closed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
string str = openssl_error_string(err);
|
||||||
|
PLOG_ERROR << str;
|
||||||
|
throw std::runtime_error(message + ": " + str);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace rtc {
|
||||||
|
|
||||||
|
int TlsTransport::TransportExIndex = -1;
|
||||||
|
|
||||||
|
void TlsTransport::Init() {
|
||||||
|
if (TransportExIndex < 0) {
|
||||||
|
TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TlsTransport::Cleanup() {
|
||||||
|
// Nothing to do
|
||||||
|
}
|
||||||
|
|
||||||
|
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
|
||||||
|
: Transport(lower, std::move(callback)) {
|
||||||
|
|
||||||
|
PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
|
||||||
|
|
||||||
|
if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
|
||||||
|
throw std::runtime_error("Failed to create SSL context");
|
||||||
|
|
||||||
|
check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
|
||||||
|
"Failed to set SSL priorities");
|
||||||
|
|
||||||
|
SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
|
||||||
|
SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION);
|
||||||
|
SSL_CTX_set_read_ahead(mCtx, 1);
|
||||||
|
SSL_CTX_set_quiet_shutdown(mCtx, 1);
|
||||||
|
SSL_CTX_set_info_callback(mCtx, InfoCallback);
|
||||||
|
|
||||||
|
SSL_CTX_set_default_verify_paths(mCtx);
|
||||||
|
SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER, NULL);
|
||||||
|
SSL_CTX_set_verify_depth(mCtx, 4);
|
||||||
|
|
||||||
|
if (!(mSsl = SSL_new(mCtx)))
|
||||||
|
throw std::runtime_error("Failed to create SSL instance");
|
||||||
|
|
||||||
|
SSL_set_ex_data(mSsl, TransportExIndex, this);
|
||||||
|
SSL_set_tlsext_host_name(mSsl, host.c_str());
|
||||||
|
|
||||||
|
SSL_set_connect_state(mSsl);
|
||||||
|
|
||||||
|
if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem())))
|
||||||
|
throw std::runtime_error("Failed to create BIO");
|
||||||
|
|
||||||
|
BIO_set_mem_eof_return(mInBio, BIO_EOF);
|
||||||
|
BIO_set_mem_eof_return(mOutBio, BIO_EOF);
|
||||||
|
SSL_set_bio(mSsl, mInBio, mOutBio);
|
||||||
|
|
||||||
|
auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
|
||||||
|
EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
|
||||||
|
SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
|
||||||
|
SSL_set_tmp_ecdh(mSsl, ecdh.get());
|
||||||
|
|
||||||
|
mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
|
||||||
|
}
|
||||||
|
|
||||||
|
TlsTransport::~TlsTransport() {
|
||||||
|
stop();
|
||||||
|
|
||||||
|
SSL_free(mSsl);
|
||||||
|
SSL_CTX_free(mCtx);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TlsTransport::stop() {
|
||||||
|
if (!Transport::stop())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
PLOG_DEBUG << "Stopping TLS recv thread";
|
||||||
|
mIncomingQueue.stop();
|
||||||
|
mRecvThread.join();
|
||||||
|
SSL_shutdown(mSsl);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TlsTransport::send(message_ptr message) {
|
||||||
|
if (!message)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
int ret = SSL_write(mSsl, message->data(), message->size());
|
||||||
|
if (!check_openssl_ret(mSsl, ret))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
const size_t bufferSize = 4096;
|
||||||
|
byte buffer[bufferSize];
|
||||||
|
while (int len = BIO_read(mOutBio, buffer, bufferSize))
|
||||||
|
outgoing(make_message(buffer, buffer + len));
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TlsTransport::incoming(message_ptr message) {
|
||||||
|
if (message)
|
||||||
|
mIncomingQueue.push(message);
|
||||||
|
else
|
||||||
|
mIncomingQueue.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TlsTransport::runRecvLoop() {
|
||||||
|
const size_t bufferSize = 4096;
|
||||||
|
byte buffer[bufferSize];
|
||||||
|
|
||||||
|
try {
|
||||||
|
changeState(State::Connecting);
|
||||||
|
|
||||||
|
SSL_do_handshake(mSsl);
|
||||||
|
while (int len = BIO_read(mOutBio, buffer, bufferSize))
|
||||||
|
outgoing(make_message(buffer, buffer + len));
|
||||||
|
|
||||||
|
while (auto next = mIncomingQueue.pop()) {
|
||||||
|
message_ptr message = *next;
|
||||||
|
message_ptr decrypted;
|
||||||
|
|
||||||
|
BIO_write(mInBio, message->data(), message->size());
|
||||||
|
|
||||||
|
int ret = SSL_read(mSsl, buffer, bufferSize);
|
||||||
|
if (!check_openssl_ret(mSsl, ret))
|
||||||
|
break;
|
||||||
|
|
||||||
|
if (ret > 0)
|
||||||
|
decrypted = make_message(buffer, buffer + ret);
|
||||||
|
|
||||||
|
while (int len = BIO_read(mOutBio, buffer, bufferSize))
|
||||||
|
outgoing(make_message(buffer, buffer + len));
|
||||||
|
|
||||||
|
if (state() == State::Connecting && SSL_is_init_finished(mSsl)) {
|
||||||
|
PLOG_INFO << "TLS handshake finished";
|
||||||
|
changeState(State::Connected);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (decrypted)
|
||||||
|
recv(decrypted);
|
||||||
|
}
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
PLOG_ERROR << "TLS recv: " << e.what();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state() == State::Connected) {
|
||||||
|
PLOG_INFO << "TLS closed";
|
||||||
|
recv(nullptr);
|
||||||
|
} else {
|
||||||
|
PLOG_ERROR << "TLS handshake failed";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
|
||||||
|
TlsTransport *t =
|
||||||
|
static_cast<TlsTransport *>(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex));
|
||||||
|
|
||||||
|
if (where & SSL_CB_ALERT) {
|
||||||
|
if (ret != 256) { // Close Notify
|
||||||
|
PLOG_ERROR << "TLS alert: " << SSL_alert_desc_string_long(ret);
|
||||||
|
}
|
||||||
|
t->mIncomingQueue.stop(); // Close the connection
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace rtc
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
83
src/tlstransport.hpp
Normal file
83
src/tlstransport.hpp
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2020 Paul-Louis Ageneau
|
||||||
|
*
|
||||||
|
* This library is free software; you can redistribute it and/or
|
||||||
|
* modify it under the terms of the GNU Lesser General Public
|
||||||
|
* License as published by the Free Software Foundation; either
|
||||||
|
* version 2.1 of the License, or (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This library is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||||
|
* Lesser General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU Lesser General Public
|
||||||
|
* License along with this library; if not, write to the Free Software
|
||||||
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef RTC_TLS_TRANSPORT_H
|
||||||
|
#define RTC_TLS_TRANSPORT_H
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
|
||||||
|
#include "include.hpp"
|
||||||
|
#include "queue.hpp"
|
||||||
|
#include "transport.hpp"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#if USE_GNUTLS
|
||||||
|
#include <gnutls/gnutls.h>
|
||||||
|
#else
|
||||||
|
#include <openssl/ssl.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace rtc {
|
||||||
|
|
||||||
|
class TcpTransport;
|
||||||
|
|
||||||
|
class TlsTransport : public Transport {
|
||||||
|
public:
|
||||||
|
static void Init();
|
||||||
|
static void Cleanup();
|
||||||
|
|
||||||
|
TlsTransport(std::shared_ptr<TcpTransport> lower, string host, state_callback callback);
|
||||||
|
~TlsTransport();
|
||||||
|
|
||||||
|
bool stop() override;
|
||||||
|
bool send(message_ptr message) override;
|
||||||
|
|
||||||
|
void incoming(message_ptr message) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void runRecvLoop();
|
||||||
|
|
||||||
|
Queue<message_ptr> mIncomingQueue;
|
||||||
|
std::thread mRecvThread;
|
||||||
|
|
||||||
|
#if USE_GNUTLS
|
||||||
|
gnutls_session_t mSession;
|
||||||
|
|
||||||
|
static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
|
||||||
|
static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
|
||||||
|
static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
|
||||||
|
#else
|
||||||
|
SSL_CTX *mCtx;
|
||||||
|
SSL *mSsl;
|
||||||
|
BIO *mInBio, *mOutBio;
|
||||||
|
|
||||||
|
static int TransportExIndex;
|
||||||
|
|
||||||
|
static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
|
||||||
|
static void InfoCallback(const SSL *ssl, int where, int ret);
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace rtc
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
@ -32,7 +32,13 @@ using namespace std::placeholders;
|
|||||||
|
|
||||||
class Transport {
|
class Transport {
|
||||||
public:
|
public:
|
||||||
Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {}
|
enum class State { Disconnected, Connecting, Connected, Completed, Failed };
|
||||||
|
using state_callback = std::function<void(State state)>;
|
||||||
|
|
||||||
|
Transport(std::shared_ptr<Transport> lower = nullptr, state_callback callback = nullptr)
|
||||||
|
: mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) {
|
||||||
|
}
|
||||||
|
|
||||||
virtual ~Transport() {
|
virtual ~Transport() {
|
||||||
stop();
|
stop();
|
||||||
if (mLower)
|
if (mLower)
|
||||||
@ -49,11 +55,16 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
|
void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
|
||||||
|
State state() const { return mState; }
|
||||||
|
|
||||||
virtual bool send(message_ptr message) { return outgoing(message); }
|
virtual bool send(message_ptr message) { return outgoing(message); }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void recv(message_ptr message) { mRecvCallback(message); }
|
void recv(message_ptr message) { mRecvCallback(message); }
|
||||||
|
void changeState(State state) {
|
||||||
|
if (mState.exchange(state) != state)
|
||||||
|
mStateChangeCallback(state);
|
||||||
|
}
|
||||||
|
|
||||||
virtual void incoming(message_ptr message) { recv(message); }
|
virtual void incoming(message_ptr message) { recv(message); }
|
||||||
virtual bool outgoing(message_ptr message) {
|
virtual bool outgoing(message_ptr message) {
|
||||||
@ -65,7 +76,10 @@ protected:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Transport> mLower;
|
std::shared_ptr<Transport> mLower;
|
||||||
|
synchronized_callback<State> mStateChangeCallback;
|
||||||
synchronized_callback<message_ptr> mRecvCallback;
|
synchronized_callback<message_ptr> mRecvCallback;
|
||||||
|
|
||||||
|
std::atomic<State> mState = State::Disconnected;
|
||||||
std::atomic<bool> mShutdown = false;
|
std::atomic<bool> mShutdown = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
311
src/websocket.cpp
Normal file
311
src/websocket.cpp
Normal file
@ -0,0 +1,311 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2020 Paul-Louis Ageneau
|
||||||
|
*
|
||||||
|
* This library is free software; you can redistribute it and/or
|
||||||
|
* modify it under the terms of the GNU Lesser General Public
|
||||||
|
* License as published by the Free Software Foundation; either
|
||||||
|
* version 2.1 of the License, or (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This library is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||||
|
* Lesser General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU Lesser General Public
|
||||||
|
* License along with this library; if not, write to the Free Software
|
||||||
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
|
*/
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
|
||||||
|
#include "include.hpp"
|
||||||
|
#include "websocket.hpp"
|
||||||
|
|
||||||
|
#include "tcptransport.hpp"
|
||||||
|
#include "tlstransport.hpp"
|
||||||
|
#include "wstransport.hpp"
|
||||||
|
|
||||||
|
#include <regex>
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#include <winsock2.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace rtc {
|
||||||
|
|
||||||
|
WebSocket::WebSocket() {}
|
||||||
|
|
||||||
|
WebSocket::WebSocket(const string &url) : WebSocket() { open(url); }
|
||||||
|
|
||||||
|
WebSocket::~WebSocket() { remoteClose(); }
|
||||||
|
|
||||||
|
WebSocket::State WebSocket::readyState() const { return mState; }
|
||||||
|
|
||||||
|
void WebSocket::open(const string &url) {
|
||||||
|
if (mState != State::Closed)
|
||||||
|
throw std::runtime_error("WebSocket must be closed before opening");
|
||||||
|
|
||||||
|
static const char *rs = R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)";
|
||||||
|
static std::regex regex(rs, std::regex::extended);
|
||||||
|
|
||||||
|
std::smatch match;
|
||||||
|
if (!std::regex_match(url, match, regex))
|
||||||
|
throw std::invalid_argument("Malformed WebSocket URL: " + url);
|
||||||
|
|
||||||
|
mScheme = match[2];
|
||||||
|
if (mScheme != "ws" && mScheme != "wss")
|
||||||
|
throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme);
|
||||||
|
|
||||||
|
mHost = match[4];
|
||||||
|
if (auto pos = mHost.find(':'); pos != string::npos) {
|
||||||
|
mHostname = mHost.substr(0, pos);
|
||||||
|
mService = mHost.substr(pos + 1);
|
||||||
|
} else {
|
||||||
|
mHostname = mHost;
|
||||||
|
mService = mScheme == "ws" ? "80" : "443";
|
||||||
|
}
|
||||||
|
|
||||||
|
mPath = match[5];
|
||||||
|
if (string query = match[7]; !query.empty())
|
||||||
|
mPath += "?" + query;
|
||||||
|
|
||||||
|
changeState(State::Connecting);
|
||||||
|
initTcpTransport();
|
||||||
|
}
|
||||||
|
|
||||||
|
void WebSocket::close() {
|
||||||
|
auto state = mState.load();
|
||||||
|
if (state == State::Connecting || state == State::Open) {
|
||||||
|
changeState(State::Closing);
|
||||||
|
if (auto transport = std::atomic_load(&mWsTransport))
|
||||||
|
transport->close();
|
||||||
|
else
|
||||||
|
changeState(State::Closed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void WebSocket::remoteClose() {
|
||||||
|
close();
|
||||||
|
closeTransports();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WebSocket::send(const std::variant<binary, string> &data) {
|
||||||
|
return std::visit(
|
||||||
|
[&](const auto &d) {
|
||||||
|
using T = std::decay_t<decltype(d)>;
|
||||||
|
constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
|
||||||
|
auto *b = reinterpret_cast<const byte *>(d.data());
|
||||||
|
return outgoing(std::make_shared<Message>(b, b + d.size(), type));
|
||||||
|
},
|
||||||
|
data);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WebSocket::isOpen() const { return mState == State::Open; }
|
||||||
|
|
||||||
|
bool WebSocket::isClosed() const { return mState == State::Closed; }
|
||||||
|
|
||||||
|
size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
|
||||||
|
|
||||||
|
std::optional<std::variant<binary, string>> WebSocket::receive() {
|
||||||
|
while (!mRecvQueue.empty()) {
|
||||||
|
auto message = *mRecvQueue.pop();
|
||||||
|
switch (message->type) {
|
||||||
|
case Message::String:
|
||||||
|
return std::make_optional(
|
||||||
|
string(reinterpret_cast<const char *>(message->data()), message->size()));
|
||||||
|
case Message::Binary:
|
||||||
|
return std::make_optional(std::move(*message));
|
||||||
|
default:
|
||||||
|
// Ignore
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); }
|
||||||
|
|
||||||
|
bool WebSocket::changeState(State state) { return mState.exchange(state) != state; }
|
||||||
|
|
||||||
|
bool WebSocket::outgoing(mutable_message_ptr message) {
|
||||||
|
if (mState != State::Open || !mWsTransport)
|
||||||
|
throw std::runtime_error("WebSocket is not open");
|
||||||
|
|
||||||
|
if (message->size() > maxMessageSize())
|
||||||
|
throw std::runtime_error("Message size exceeds limit");
|
||||||
|
|
||||||
|
return mWsTransport->send(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
void WebSocket::incoming(message_ptr message) {
|
||||||
|
if (message->type == Message::String || message->type == Message::Binary) {
|
||||||
|
mRecvQueue.push(message);
|
||||||
|
triggerAvailable(mRecvQueue.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
|
||||||
|
using State = TcpTransport::State;
|
||||||
|
try {
|
||||||
|
std::lock_guard lock(mInitMutex);
|
||||||
|
if (auto transport = std::atomic_load(&mTcpTransport))
|
||||||
|
return transport;
|
||||||
|
|
||||||
|
auto transport = std::make_shared<TcpTransport>(
|
||||||
|
mHostname, mService, [this, weak_this = weak_from_this()](State state) {
|
||||||
|
auto shared_this = weak_this.lock();
|
||||||
|
if (!shared_this)
|
||||||
|
return;
|
||||||
|
switch (state) {
|
||||||
|
case State::Connected:
|
||||||
|
if (mScheme == "ws")
|
||||||
|
initWsTransport();
|
||||||
|
else
|
||||||
|
initTlsTransport();
|
||||||
|
break;
|
||||||
|
case State::Failed:
|
||||||
|
triggerError("TCP connection failed");
|
||||||
|
remoteClose();
|
||||||
|
break;
|
||||||
|
case State::Disconnected:
|
||||||
|
remoteClose();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
// Ignore
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
std::atomic_store(&mTcpTransport, transport);
|
||||||
|
if (mState == WebSocket::State::Closed) {
|
||||||
|
mTcpTransport.reset();
|
||||||
|
transport->stop();
|
||||||
|
throw std::runtime_error("Connection is closed");
|
||||||
|
}
|
||||||
|
return transport;
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
PLOG_ERROR << e.what();
|
||||||
|
remoteClose();
|
||||||
|
throw std::runtime_error("TCP transport initialization failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
|
||||||
|
using State = TlsTransport::State;
|
||||||
|
try {
|
||||||
|
std::lock_guard lock(mInitMutex);
|
||||||
|
if (auto transport = std::atomic_load(&mTlsTransport))
|
||||||
|
return transport;
|
||||||
|
|
||||||
|
auto lower = std::atomic_load(&mTcpTransport);
|
||||||
|
auto transport = std::make_shared<TlsTransport>(
|
||||||
|
lower, mHost, [this, weak_this = weak_from_this()](State state) {
|
||||||
|
auto shared_this = weak_this.lock();
|
||||||
|
if (!shared_this)
|
||||||
|
return;
|
||||||
|
switch (state) {
|
||||||
|
case State::Connected:
|
||||||
|
initWsTransport();
|
||||||
|
break;
|
||||||
|
case State::Failed:
|
||||||
|
triggerError("TCP connection failed");
|
||||||
|
remoteClose();
|
||||||
|
break;
|
||||||
|
case State::Disconnected:
|
||||||
|
remoteClose();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
// Ignore
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
std::atomic_store(&mTlsTransport, transport);
|
||||||
|
if (mState == WebSocket::State::Closed) {
|
||||||
|
mTlsTransport.reset();
|
||||||
|
transport->stop();
|
||||||
|
throw std::runtime_error("Connection is closed");
|
||||||
|
}
|
||||||
|
return transport;
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
PLOG_ERROR << e.what();
|
||||||
|
remoteClose();
|
||||||
|
throw std::runtime_error("TLS transport initialization failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<WsTransport> WebSocket::initWsTransport() {
|
||||||
|
using State = WsTransport::State;
|
||||||
|
try {
|
||||||
|
std::lock_guard lock(mInitMutex);
|
||||||
|
if (auto transport = std::atomic_load(&mWsTransport))
|
||||||
|
return transport;
|
||||||
|
|
||||||
|
std::shared_ptr<Transport> lower = std::atomic_load(&mTlsTransport);
|
||||||
|
if (!lower)
|
||||||
|
lower = std::atomic_load(&mTcpTransport);
|
||||||
|
auto transport = std::make_shared<WsTransport>(
|
||||||
|
lower, mHost, mPath, weak_bind(&WebSocket::incoming, this, _1),
|
||||||
|
[this, weak_this = weak_from_this()](State state) {
|
||||||
|
auto shared_this = weak_this.lock();
|
||||||
|
if (!shared_this)
|
||||||
|
return;
|
||||||
|
switch (state) {
|
||||||
|
case State::Connected:
|
||||||
|
if (mState == WebSocket::State::Connecting) {
|
||||||
|
PLOG_DEBUG << "WebSocket open";
|
||||||
|
changeState(WebSocket::State::Open);
|
||||||
|
triggerOpen();
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case State::Failed:
|
||||||
|
triggerError("WebSocket connection failed");
|
||||||
|
remoteClose();
|
||||||
|
break;
|
||||||
|
case State::Disconnected:
|
||||||
|
remoteClose();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
// Ignore
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
std::atomic_store(&mWsTransport, transport);
|
||||||
|
if (mState == WebSocket::State::Closed) {
|
||||||
|
mWsTransport.reset();
|
||||||
|
transport->stop();
|
||||||
|
throw std::runtime_error("Connection is closed");
|
||||||
|
}
|
||||||
|
return transport;
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
PLOG_ERROR << e.what();
|
||||||
|
remoteClose();
|
||||||
|
throw std::runtime_error("WebSocket transport initialization failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void WebSocket::closeTransports() {
|
||||||
|
changeState(State::Closed);
|
||||||
|
|
||||||
|
// Pass the references to a thread, allowing to terminate a transport from its own thread
|
||||||
|
auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
|
||||||
|
auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
|
||||||
|
auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
|
||||||
|
if (ws || tls || tcp) {
|
||||||
|
std::thread t([ws, tls, tcp]() mutable {
|
||||||
|
if (ws)
|
||||||
|
ws->stop();
|
||||||
|
if (tls)
|
||||||
|
tls->stop();
|
||||||
|
if (tcp)
|
||||||
|
tcp->stop();
|
||||||
|
|
||||||
|
ws.reset();
|
||||||
|
tls.reset();
|
||||||
|
tcp.reset();
|
||||||
|
});
|
||||||
|
t.detach();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace rtc
|
||||||
|
|
||||||
|
#endif
|
372
src/wstransport.cpp
Normal file
372
src/wstransport.cpp
Normal file
@ -0,0 +1,372 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2020 Paul-Louis Ageneau
|
||||||
|
*
|
||||||
|
* This library is free software; you can redistribute it and/or
|
||||||
|
* modify it under the terms of the GNU Lesser General Public
|
||||||
|
* License as published by the Free Software Foundation; either
|
||||||
|
* version 2.1 of the License, or (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This library is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||||
|
* Lesser General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU Lesser General Public
|
||||||
|
* License along with this library; if not, write to the Free Software
|
||||||
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
|
*/
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
|
||||||
|
#include "wstransport.hpp"
|
||||||
|
#include "tcptransport.hpp"
|
||||||
|
#include "tlstransport.hpp"
|
||||||
|
|
||||||
|
#include "base64.hpp"
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
#include <random>
|
||||||
|
#include <regex>
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#include <winsock2.h>
|
||||||
|
#else
|
||||||
|
#include <arpa/inet.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef htonll
|
||||||
|
#define htonll(x) \
|
||||||
|
((uint64_t)htonl(((uint64_t)(x)&0xFFFFFFFF) << 32) | (uint64_t)htonl((uint64_t)(x) >> 32))
|
||||||
|
#endif
|
||||||
|
#ifndef ntohll
|
||||||
|
#define ntohll(x) htonll(x)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace rtc {
|
||||||
|
|
||||||
|
using namespace std::chrono;
|
||||||
|
using std::to_integer;
|
||||||
|
using std::to_string;
|
||||||
|
|
||||||
|
using random_bytes_engine =
|
||||||
|
std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>;
|
||||||
|
|
||||||
|
WsTransport::WsTransport(std::shared_ptr<Transport> lower, string host, string path,
|
||||||
|
message_callback recvCallback, state_callback stateCallback)
|
||||||
|
: Transport(lower, std::move(stateCallback)), mHost(std::move(host)), mPath(std::move(path)) {
|
||||||
|
onRecv(recvCallback);
|
||||||
|
|
||||||
|
PLOG_DEBUG << "Initializing WebSocket transport";
|
||||||
|
|
||||||
|
registerIncoming();
|
||||||
|
sendHttpRequest();
|
||||||
|
}
|
||||||
|
|
||||||
|
WsTransport::~WsTransport() { stop(); }
|
||||||
|
|
||||||
|
bool WsTransport::stop() {
|
||||||
|
if (!Transport::stop())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
close();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WsTransport::send(message_ptr message) {
|
||||||
|
if (!message)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// Call the mutable message overload with a copy
|
||||||
|
return send(std::make_shared<Message>(*message));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WsTransport::send(mutable_message_ptr message) {
|
||||||
|
if (!message || state() != State::Connected)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
PLOG_VERBOSE << "Send size=" << message->size();
|
||||||
|
|
||||||
|
return sendFrame({message->type == Message::String ? TEXT_FRAME : BINARY_FRAME, message->data(),
|
||||||
|
message->size(), true, true});
|
||||||
|
}
|
||||||
|
|
||||||
|
void WsTransport::incoming(message_ptr message) {
|
||||||
|
try {
|
||||||
|
mBuffer.insert(mBuffer.end(), message->begin(), message->end());
|
||||||
|
|
||||||
|
if (state() == State::Connecting) {
|
||||||
|
if (size_t len = readHttpResponse(mBuffer.data(), mBuffer.size())) {
|
||||||
|
mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
|
||||||
|
PLOG_INFO << "WebSocket open";
|
||||||
|
changeState(State::Connected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state() == State::Connected) {
|
||||||
|
Frame frame = {};
|
||||||
|
while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
|
||||||
|
mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
|
||||||
|
recvFrame(frame);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
PLOG_ERROR << e.what();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state() == State::Connected) {
|
||||||
|
PLOG_INFO << "WebSocket disconnected";
|
||||||
|
changeState(State::Disconnected);
|
||||||
|
recv(nullptr);
|
||||||
|
} else {
|
||||||
|
PLOG_ERROR << "WebSocket handshake failed";
|
||||||
|
changeState(State::Failed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void WsTransport::close() {
|
||||||
|
if (state() == State::Connected) {
|
||||||
|
sendFrame({CLOSE, NULL, 0, true, true});
|
||||||
|
PLOG_INFO << "WebSocket closing";
|
||||||
|
changeState(State::Completed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WsTransport::sendHttpRequest() {
|
||||||
|
changeState(State::Connecting);
|
||||||
|
|
||||||
|
auto seed = system_clock::now().time_since_epoch().count();
|
||||||
|
random_bytes_engine generator(seed);
|
||||||
|
|
||||||
|
binary key(16);
|
||||||
|
std::generate(reinterpret_cast<uint8_t *>(key.data()),
|
||||||
|
reinterpret_cast<uint8_t *>(key.data() + key.size()), generator);
|
||||||
|
|
||||||
|
const string request = "GET " + mPath +
|
||||||
|
" HTTP/1.1\r\n"
|
||||||
|
"Host: " +
|
||||||
|
mHost +
|
||||||
|
"\r\n"
|
||||||
|
"Connection: Upgrade\r\n"
|
||||||
|
"Upgrade: websocket\r\n"
|
||||||
|
"Sec-WebSocket-Version: 13\r\n"
|
||||||
|
"Sec-WebSocket-Key: " +
|
||||||
|
to_base64(key) +
|
||||||
|
"\r\n"
|
||||||
|
"\r\n";
|
||||||
|
|
||||||
|
auto data = reinterpret_cast<const byte *>(request.data());
|
||||||
|
auto size = request.size();
|
||||||
|
return outgoing(make_message(data, data + size));
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t WsTransport::readHttpResponse(const byte *buffer, size_t size) {
|
||||||
|
std::list<string> lines;
|
||||||
|
auto begin = reinterpret_cast<const char *>(buffer);
|
||||||
|
auto end = begin + size;
|
||||||
|
auto cur = begin;
|
||||||
|
while (true) {
|
||||||
|
auto last = cur;
|
||||||
|
cur = std::find(cur, end, '\n');
|
||||||
|
if (cur == end)
|
||||||
|
return 0;
|
||||||
|
string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
|
||||||
|
if (line.empty())
|
||||||
|
break;
|
||||||
|
lines.emplace_back(std::move(line));
|
||||||
|
}
|
||||||
|
size_t length = cur - begin;
|
||||||
|
|
||||||
|
if (lines.empty())
|
||||||
|
throw std::runtime_error("Invalid HTTP response for WebSocket");
|
||||||
|
|
||||||
|
string status = std::move(lines.front());
|
||||||
|
lines.pop_front();
|
||||||
|
|
||||||
|
std::istringstream ss(status);
|
||||||
|
string protocol;
|
||||||
|
unsigned int code = 0;
|
||||||
|
ss >> protocol >> code;
|
||||||
|
PLOG_DEBUG << "WebSocket response code: " << code;
|
||||||
|
if (code != 101)
|
||||||
|
throw std::runtime_error("Unexpected response code for WebSocket: " + to_string(code));
|
||||||
|
|
||||||
|
std::multimap<string, string> headers;
|
||||||
|
for (const auto &line : lines) {
|
||||||
|
if (size_t pos = line.find_first_of(':'); pos != string::npos) {
|
||||||
|
string key = line.substr(0, pos);
|
||||||
|
string value = line.substr(line.find_first_not_of(' ', pos + 1));
|
||||||
|
std::transform(key.begin(), key.end(), key.begin(),
|
||||||
|
[](char c) { return std::tolower(c); });
|
||||||
|
headers.emplace(std::move(key), std::move(value));
|
||||||
|
} else {
|
||||||
|
headers.emplace(line, "");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto h = headers.find("upgrade");
|
||||||
|
if (h == headers.end() || h->second != "websocket")
|
||||||
|
throw std::runtime_error("WebSocket update header missing or mismatching");
|
||||||
|
|
||||||
|
h = headers.find("sec-websocket-accept");
|
||||||
|
if (h == headers.end())
|
||||||
|
throw std::runtime_error("WebSocket accept header missing");
|
||||||
|
|
||||||
|
// TODO: Verify Sec-WebSocket-Accept
|
||||||
|
|
||||||
|
return length;
|
||||||
|
}
|
||||||
|
|
||||||
|
// http://tools.ietf.org/html/rfc6455#section-5.2 Base Framing Protocol
|
||||||
|
//
|
||||||
|
// 0 1 2 3
|
||||||
|
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||||
|
// +-+-+-+-+-------+-+-------------+-------------------------------+
|
||||||
|
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
|
||||||
|
// |I|S|S|S| (4) |A| (7) | (16/64) |
|
||||||
|
// |N|V|V|V| |S| | (if payload len==126/127) |
|
||||||
|
// | |1|2|3| |K| | |
|
||||||
|
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|
||||||
|
// | Extended payload length continued, if payload len == 127 |
|
||||||
|
// + - - - - - - - - - - - - - - - +-------------------------------+
|
||||||
|
// | | Masking-key, if MASK set to 1 |
|
||||||
|
// +-------------------------------+-------------------------------+
|
||||||
|
// | Masking-key (continued) | Payload Data |
|
||||||
|
// +-------------------------------+ - - - - - - - - - - - - - - - +
|
||||||
|
// : Payload Data continued ... :
|
||||||
|
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|
||||||
|
// | Payload Data continued ... |
|
||||||
|
// +---------------------------------------------------------------+
|
||||||
|
|
||||||
|
size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
|
||||||
|
const byte *end = buffer + size;
|
||||||
|
if (end - buffer < 2)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
byte *cur = buffer;
|
||||||
|
auto b1 = to_integer<uint8_t>(*cur++);
|
||||||
|
auto b2 = to_integer<uint8_t>(*cur++);
|
||||||
|
|
||||||
|
frame.fin = (b1 & 0x80) != 0;
|
||||||
|
frame.mask = (b2 & 0x80) != 0;
|
||||||
|
frame.opcode = static_cast<Opcode>(b1 & 0x0F);
|
||||||
|
frame.length = b2 & 0x7F;
|
||||||
|
|
||||||
|
if (frame.length == 0x7E) {
|
||||||
|
if (end - cur < 2)
|
||||||
|
return 0;
|
||||||
|
frame.length = ntohs(*reinterpret_cast<const uint16_t *>(cur));
|
||||||
|
cur += 2;
|
||||||
|
} else if (frame.length == 0x7F) {
|
||||||
|
if (end - cur < 8)
|
||||||
|
return false;
|
||||||
|
frame.length = ntohll(*reinterpret_cast<const uint64_t *>(cur));
|
||||||
|
cur += 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
const byte *maskingKey = nullptr;
|
||||||
|
if (frame.mask) {
|
||||||
|
if (end - cur < 4)
|
||||||
|
return 0;
|
||||||
|
maskingKey = cur;
|
||||||
|
cur += 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (end - cur < frame.length)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
frame.payload = cur;
|
||||||
|
if (maskingKey)
|
||||||
|
for (size_t i = 0; i < frame.length; ++i)
|
||||||
|
frame.payload[i] ^= maskingKey[i % 4];
|
||||||
|
|
||||||
|
return end - buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
void WsTransport::recvFrame(const Frame &frame) {
|
||||||
|
switch (frame.opcode) {
|
||||||
|
case TEXT_FRAME:
|
||||||
|
case BINARY_FRAME: {
|
||||||
|
if (!mPartial.empty()) {
|
||||||
|
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
|
||||||
|
recv(make_message(mPartial.begin(), mPartial.end(), type));
|
||||||
|
mPartial.clear();
|
||||||
|
}
|
||||||
|
if (frame.fin) {
|
||||||
|
auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary;
|
||||||
|
recv(make_message(frame.payload, frame.payload + frame.length));
|
||||||
|
} else {
|
||||||
|
mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
|
||||||
|
mPartialOpcode = frame.opcode;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case CONTINUATION: {
|
||||||
|
mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
|
||||||
|
if (frame.fin) {
|
||||||
|
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
|
||||||
|
recv(make_message(mPartial.begin(), mPartial.end()));
|
||||||
|
mPartial.clear();
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case PING: {
|
||||||
|
sendFrame({PONG, frame.payload, frame.length, true, true});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case PONG: {
|
||||||
|
// TODO
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case CLOSE: {
|
||||||
|
close();
|
||||||
|
PLOG_INFO << "WebSocket closed";
|
||||||
|
changeState(State::Disconnected);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
close();
|
||||||
|
throw std::invalid_argument("Unknown WebSocket opcode: " + to_string(frame.opcode));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WsTransport::sendFrame(const Frame &frame) {
|
||||||
|
byte buffer[14];
|
||||||
|
byte *cur = buffer;
|
||||||
|
|
||||||
|
*cur++ = byte((frame.opcode & 0x0F) | (frame.fin ? 0x80 : 0));
|
||||||
|
|
||||||
|
if (frame.length < 0x7E) {
|
||||||
|
*cur++ = byte((frame.length & 0x7F) | (frame.mask ? 0x80 : 0));
|
||||||
|
} else if (frame.length <= 0xFF) {
|
||||||
|
*cur++ = byte(0x7E | (frame.mask ? 0x80 : 0));
|
||||||
|
*reinterpret_cast<uint16_t *>(cur) = uint16_t(frame.length);
|
||||||
|
cur += 2;
|
||||||
|
} else {
|
||||||
|
*cur++ = byte(0x7F | (frame.mask ? 0x80 : 0));
|
||||||
|
*reinterpret_cast<uint64_t *>(cur) = uint64_t(frame.length);
|
||||||
|
cur += 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (frame.mask) {
|
||||||
|
auto seed = system_clock::now().time_since_epoch().count();
|
||||||
|
random_bytes_engine generator(seed);
|
||||||
|
|
||||||
|
auto *maskingKey = cur;
|
||||||
|
std::generate(reinterpret_cast<uint8_t *>(maskingKey),
|
||||||
|
reinterpret_cast<uint8_t *>(maskingKey + 4), generator);
|
||||||
|
cur += 4;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < frame.length; ++i)
|
||||||
|
frame.payload[i] ^= maskingKey[i % 4];
|
||||||
|
}
|
||||||
|
|
||||||
|
outgoing(make_message(buffer, cur)); // header
|
||||||
|
return outgoing(make_message(frame.payload, frame.payload + frame.length)); // payload
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace rtc
|
||||||
|
|
||||||
|
#endif
|
83
src/wstransport.hpp
Normal file
83
src/wstransport.hpp
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2020 Paul-Louis Ageneau
|
||||||
|
*
|
||||||
|
* This library is free software; you can redistribute it and/or
|
||||||
|
* modify it under the terms of the GNU Lesser General Public
|
||||||
|
* License as published by the Free Software Foundation; either
|
||||||
|
* version 2.1 of the License, or (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This library is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||||
|
* Lesser General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU Lesser General Public
|
||||||
|
* License along with this library; if not, write to the Free Software
|
||||||
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef RTC_WS_TRANSPORT_H
|
||||||
|
#define RTC_WS_TRANSPORT_H
|
||||||
|
|
||||||
|
#if RTC_ENABLE_WEBSOCKET
|
||||||
|
|
||||||
|
#include "include.hpp"
|
||||||
|
#include "transport.hpp"
|
||||||
|
|
||||||
|
namespace rtc {
|
||||||
|
|
||||||
|
class TcpTransport;
|
||||||
|
class TlsTransport;
|
||||||
|
|
||||||
|
class WsTransport : public Transport {
|
||||||
|
public:
|
||||||
|
WsTransport(std::shared_ptr<Transport> lower, string host, string path,
|
||||||
|
message_callback recvCallback, state_callback stateCallback);
|
||||||
|
~WsTransport();
|
||||||
|
|
||||||
|
bool stop() override;
|
||||||
|
bool send(message_ptr message) override;
|
||||||
|
bool send(mutable_message_ptr message);
|
||||||
|
|
||||||
|
void incoming(message_ptr message) override;
|
||||||
|
|
||||||
|
void close();
|
||||||
|
|
||||||
|
private:
|
||||||
|
enum Opcode : uint8_t {
|
||||||
|
CONTINUATION = 0,
|
||||||
|
TEXT_FRAME = 1,
|
||||||
|
BINARY_FRAME = 2,
|
||||||
|
CLOSE = 8,
|
||||||
|
PING = 9,
|
||||||
|
PONG = 10,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Frame {
|
||||||
|
Opcode opcode = BINARY_FRAME;
|
||||||
|
byte *payload = nullptr;
|
||||||
|
size_t length = 0;
|
||||||
|
bool fin = true;
|
||||||
|
bool mask = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool sendHttpRequest();
|
||||||
|
size_t readHttpResponse(const byte *buffer, size_t size);
|
||||||
|
|
||||||
|
size_t readFrame(byte *buffer, size_t size, Frame &frame);
|
||||||
|
void recvFrame(const Frame &frame);
|
||||||
|
bool sendFrame(const Frame &frame);
|
||||||
|
|
||||||
|
const string mHost;
|
||||||
|
const string mPath;
|
||||||
|
|
||||||
|
binary mBuffer;
|
||||||
|
binary mPartial;
|
||||||
|
Opcode mPartialOpcode;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace rtc
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
@ -25,19 +25,19 @@ void test_capi();
|
|||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
try {
|
try {
|
||||||
std::cout << "*** Running connectivity test..." << std::endl;
|
cout << endl << "*** Running connectivity test..." << endl;
|
||||||
test_connectivity();
|
test_connectivity();
|
||||||
std::cout << "*** Finished connectivity test" << std::endl;
|
cout << "*** Finished connectivity test" << endl;
|
||||||
} catch (const exception &e) {
|
} catch (const exception &e) {
|
||||||
std::cerr << "Connectivity test failed: " << e.what() << endl;
|
cerr << "Connectivity test failed: " << e.what() << endl;
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
std::cout << "*** Running C API test..." << std::endl;
|
cout << endl << "*** Running C API test..." << endl;
|
||||||
test_capi();
|
test_capi();
|
||||||
std::cout << "*** Finished C API test" << std::endl;
|
cout << "*** Finished C API test" << endl;
|
||||||
} catch (const exception &e) {
|
} catch (const exception &e) {
|
||||||
std::cerr << "C API test failed: " << e.what() << endl;
|
cerr << "C API test failed: " << e.what() << endl;
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
|
Reference in New Issue
Block a user