diff --git a/CMakeLists.txt b/CMakeLists.txt index 118c0c0..3f9505b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -59,6 +59,7 @@ set(LIBDATACHANNEL_WEBSOCKET_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/verifiedtlstransport.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp ) diff --git a/include/rtc/websocket.hpp b/include/rtc/websocket.hpp index 37aef7c..10ba564 100644 --- a/include/rtc/websocket.hpp +++ b/include/rtc/websocket.hpp @@ -47,7 +47,11 @@ public: Closed = 3, }; - WebSocket(); + struct Configuration { + bool disableTlsVerification = false; + }; + + WebSocket(std::optional config = nullopt); ~WebSocket(); State readyState() const; @@ -82,6 +86,7 @@ private: std::shared_ptr mWsTransport; std::recursive_mutex mInitMutex; + const Configuration mConfig; string mScheme, mHost, mHostname, mService, mPath; std::atomic mState = State::Closed; diff --git a/src/tlstransport.cpp b/src/tlstransport.cpp index 6820118..36b29c7 100644 --- a/src/tlstransport.cpp +++ b/src/tlstransport.cpp @@ -56,7 +56,6 @@ TlsTransport::TlsTransport(shared_ptr lower, string host, state_ca try { gnutls::check(gnutls_certificate_set_x509_system_trust(mCreds)); gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCreds)); - gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0); const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128"; const char *err_pos = NULL; @@ -72,7 +71,9 @@ TlsTransport::TlsTransport(shared_ptr lower, string host, state_ca gnutls_transport_set_pull_function(mSession, ReadCallback); gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback); - mRecvThread = std::thread(&TlsTransport::runRecvLoop, this); + postCreation(); + + mRecvThread = std::thread(&TlsTransport::runRecvLoop, this); registerIncoming(); } catch (...) { @@ -123,6 +124,14 @@ void TlsTransport::incoming(message_ptr message) { mIncomingQueue.stop(); } +void TlsTransport::postCreation() { + // Dummy +} + +void TlsTransport::postHandshake() { + // Dummy +} + void TlsTransport::runRecvLoop() { const size_t bufferSize = 4096; char buffer[bufferSize]; @@ -147,6 +156,7 @@ void TlsTransport::runRecvLoop() { try { PLOG_INFO << "TLS handshake finished"; changeState(State::Connected); + postHandshake(); while (true) { ssize_t ret; @@ -263,25 +273,16 @@ TlsTransport::TlsTransport(shared_ptr lower, string host, state_ca openssl::check(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"), "Failed to set SSL priorities"); + if (!SSL_CTX_set_default_verify_paths(mCtx)) { + PLOG_WARNING << "SSL root CA certificates unavailable"; + } + SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3); SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION); SSL_CTX_set_read_ahead(mCtx, 1); SSL_CTX_set_quiet_shutdown(mCtx, 1); SSL_CTX_set_info_callback(mCtx, InfoCallback); - - // SSL_CTX_set_default_verify_paths() does nothing on Windows -#ifndef _WIN32 - if (SSL_CTX_set_default_verify_paths(mCtx)) { -#else - if (false) { -#endif - PLOG_INFO << "SSL root CA certificates available, server verification enabled"; - SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER, NULL); - SSL_CTX_set_verify_depth(mCtx, 4); - } else { - PLOG_WARNING << "SSL root CA certificates unavailable, server verification disabled"; - SSL_CTX_set_verify(mCtx, SSL_VERIFY_NONE, NULL); - } + SSL_CTX_set_verify(mCtx, SSL_VERIFY_NONE, NULL); if (!(mSsl = SSL_new(mCtx))) throw std::runtime_error("Failed to create SSL instance"); @@ -308,6 +309,8 @@ TlsTransport::TlsTransport(shared_ptr lower, string host, state_ca SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE); SSL_set_tmp_ecdh(mSsl, ecdh.get()); + postCreation(); + mRecvThread = std::thread(&TlsTransport::runRecvLoop, this); registerIncoming(); @@ -366,6 +369,14 @@ void TlsTransport::incoming(message_ptr message) { mIncomingQueue.stop(); } +void TlsTransport::postCreation() { + // Dummy +} + +void TlsTransport::postHandshake() { + // Dummy +} + void TlsTransport::runRecvLoop() { const size_t bufferSize = 4096; byte buffer[bufferSize]; @@ -387,6 +398,7 @@ void TlsTransport::runRecvLoop() { if (SSL_is_init_finished(mSsl)) { PLOG_INFO << "TLS handshake finished"; changeState(State::Connected); + postHandshake(); } } else { int ret = SSL_read(mSsl, buffer, bufferSize); diff --git a/src/tlstransport.hpp b/src/tlstransport.hpp index d5192c7..8899ca2 100644 --- a/src/tlstransport.hpp +++ b/src/tlstransport.hpp @@ -38,14 +38,15 @@ public: static void Cleanup(); TlsTransport(std::shared_ptr lower, string host, state_callback callback); - ~TlsTransport(); + virtual ~TlsTransport(); bool stop() override; bool send(message_ptr message) override; - void incoming(message_ptr message) override; - protected: + virtual void incoming(message_ptr message) override; + virtual void postCreation(); + virtual void postHandshake(); void runRecvLoop(); string mHost; diff --git a/src/verifiedtlstransport.cpp b/src/verifiedtlstransport.cpp new file mode 100644 index 0000000..b367372 --- /dev/null +++ b/src/verifiedtlstransport.cpp @@ -0,0 +1,68 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#include "verifiedtlstransport.hpp" +#include "include.hpp" + +#if RTC_ENABLE_WEBSOCKET + +using std::shared_ptr; +using std::string; +using std::unique_ptr; +using std::weak_ptr; + +namespace rtc { + +#if USE_GNUTLS + +VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr lower, string host, + state_callback callback) + : TlsTransport(std::move(lower), std::move(host), std::move(callback)) {} + +VerifiedTlsTransport::~VerifiedTlsTransport() {} + +void VerifiedTlsTransport::postCreation() { + gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0); +} + +void VerifiedTlsTransport::postHandshake() { + // Nothing to do +} + +#else // USE_GNUTLS==0 + +VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr lower, string host, + state_callback callback) + : TlsTransport(std::move(lower), std::move(host), std::move(callback)) {} + +VerifiedTlsTransport::~VerifiedTlsTransport() {} + +void VerifiedTlsTransport::postCreation() { + SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL); + SSL_set_verify_depth(mSsl, 4); +} + +void VerifiedTlsTransport::postHandshake() { + // Nothing to do +} + +#endif + +} // namespace rtc + +#endif diff --git a/src/verifiedtlstransport.hpp b/src/verifiedtlstransport.hpp new file mode 100644 index 0000000..af9d712 --- /dev/null +++ b/src/verifiedtlstransport.hpp @@ -0,0 +1,42 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#ifndef RTC_VERIFIED_TLS_TRANSPORT_H +#define RTC_VERIFIED_TLS_TRANSPORT_H + +#include "tlstransport.hpp" + +#if RTC_ENABLE_WEBSOCKET + +namespace rtc { + +class VerifiedTlsTransport final : public TlsTransport { +public: + VerifiedTlsTransport(std::shared_ptr lower, string host, state_callback callback); + ~VerifiedTlsTransport(); + +protected: + void postCreation() override; + void postHandshake() override; +}; + +} // namespace rtc + +#endif + +#endif diff --git a/src/websocket.cpp b/src/websocket.cpp index 2b6171e..711032c 100644 --- a/src/websocket.cpp +++ b/src/websocket.cpp @@ -24,6 +24,7 @@ #include "tcptransport.hpp" #include "tlstransport.hpp" +#include "verifiedtlstransport.hpp" #include "wstransport.hpp" #include @@ -34,7 +35,12 @@ namespace rtc { -WebSocket::WebSocket() { PLOG_VERBOSE << "Creating WebSocket"; } +using std::shared_ptr; + +WebSocket::WebSocket(std::optional config) + : mConfig(config ? std::move(*config) : Configuration()) { + PLOG_VERBOSE << "Creating WebSocket"; +} WebSocket::~WebSocket() { PLOG_VERBOSE << "Destroying WebSocket"; @@ -149,7 +155,7 @@ void WebSocket::incoming(message_ptr message) { } } -std::shared_ptr WebSocket::initTcpTransport() { +shared_ptr WebSocket::initTcpTransport() { using State = TcpTransport::State; try { std::lock_guard lock(mInitMutex); @@ -194,7 +200,7 @@ std::shared_ptr WebSocket::initTcpTransport() { } } -std::shared_ptr WebSocket::initTlsTransport() { +shared_ptr WebSocket::initTlsTransport() { using State = TlsTransport::State; try { std::lock_guard lock(mInitMutex); @@ -202,27 +208,40 @@ std::shared_ptr WebSocket::initTlsTransport() { return transport; auto lower = std::atomic_load(&mTcpTransport); - auto transport = std::make_shared( - lower, mHost, [this, weak_this = weak_from_this()](State state) { - auto shared_this = weak_this.lock(); - if (!shared_this) - return; - switch (state) { - case State::Connected: - initWsTransport(); - break; - case State::Failed: - triggerError("TCP connection failed"); - remoteClose(); - break; - case State::Disconnected: - remoteClose(); - break; - default: - // Ignore - break; - } - }); + auto stateChangeCallback = [this, weak_this = weak_from_this()](State state) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (state) { + case State::Connected: + initWsTransport(); + break; + case State::Failed: + triggerError("TCP connection failed"); + remoteClose(); + break; + case State::Disconnected: + remoteClose(); + break; + default: + // Ignore + break; + } + }; + + shared_ptr transport; +#ifdef _WIN32 + if (!mConfig.disableTlsVerification) { + PLOG_WARNING << "TLS certificate verification with root CA is not supported on Windows"; + } + transport = std::make_shared(lower, mHost, stateChangeCallback); +#else + if (mConfig.disableTlsVerification) + transport = std::make_shared(lower, mHost, stateChangeCallback); + else + transport = std::make_shared(lower, mHost, stateChangeCallback); +#endif + std::atomic_store(&mTlsTransport, transport); if (mState == WebSocket::State::Closed) { mTlsTransport.reset(); @@ -237,14 +256,14 @@ std::shared_ptr WebSocket::initTlsTransport() { } } -std::shared_ptr WebSocket::initWsTransport() { +shared_ptr WebSocket::initWsTransport() { using State = WsTransport::State; try { std::lock_guard lock(mInitMutex); if (auto transport = std::atomic_load(&mWsTransport)) return transport; - std::shared_ptr lower = std::atomic_load(&mTlsTransport); + shared_ptr lower = std::atomic_load(&mTlsTransport); if (!lower) lower = std::atomic_load(&mTcpTransport); auto transport = std::make_shared(