diff --git a/CMakeLists.txt b/CMakeLists.txt index 9a3d81f..6022b94 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,7 +71,8 @@ set(TESTS_ANSWERER_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test/p2p/answerer.cpp ) -set(THREADS_PREFER_PTHREAD_FLAG ON) +set(CMAKE_THREAD_PREFER_PTHREAD TRUE) +set(THREADS_PREFER_PTHREAD_FLAG TRUE) find_package(Threads REQUIRED) add_subdirectory(deps/usrsctp EXCLUDE_FROM_ALL) @@ -92,10 +93,11 @@ set_target_properties(datachannel PROPERTIES CXX_STANDARD 17) target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) +target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include) target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc) target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) -target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include) -target_link_libraries(datachannel Threads::Threads Usrsctp::UsrsctpStatic) +target_link_libraries(datachannel PUBLIC Threads::Threads) +target_link_libraries(datachannel PRIVATE Usrsctp::UsrsctpStatic) add_library(datachannel-static STATIC EXCLUDE_FROM_ALL ${LIBDATACHANNEL_SOURCES}) set_target_properties(datachannel-static PROPERTIES @@ -103,14 +105,15 @@ set_target_properties(datachannel-static PROPERTIES CXX_STANDARD 17) target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) +target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include) target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc) target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) -target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include) -target_link_libraries(datachannel-static Threads::Threads Usrsctp::UsrsctpStatic) +target_link_libraries(datachannel-static PUBLIC Threads::Threads) +target_link_libraries(datachannel-static PRIVATE Usrsctp::UsrsctpStatic) if(WIN32) - target_link_libraries(datachannel "wsock32" "ws2_32") # winsock2 - target_link_libraries(datachannel-static "wsock32" "ws2_32") # winsock2 + target_link_libraries(datachannel PRIVATE wsock32 ws2_32) # winsock2 + target_link_libraries(datachannel-static PRIVATE wsock32 ws2_32) # winsock2 endif() if (USE_GNUTLS) @@ -124,29 +127,29 @@ if (USE_GNUTLS) IMPORTED_LOCATION "${GNUTLS_LIBRARIES}") endif() target_compile_definitions(datachannel PRIVATE USE_GNUTLS=1) - target_link_libraries(datachannel GnuTLS::GnuTLS) + target_link_libraries(datachannel PRIVATE GnuTLS::GnuTLS) target_compile_definitions(datachannel-static PRIVATE USE_GNUTLS=1) - target_link_libraries(datachannel-static GnuTLS::GnuTLS) + target_link_libraries(datachannel-static PRIVATE GnuTLS::GnuTLS) else() find_package(OpenSSL REQUIRED) target_compile_definitions(datachannel PRIVATE USE_GNUTLS=0) - target_link_libraries(datachannel OpenSSL::SSL) + target_link_libraries(datachannel PRIVATE OpenSSL::SSL) target_compile_definitions(datachannel-static PRIVATE USE_GNUTLS=0) - target_link_libraries(datachannel-static OpenSSL::SSL) + target_link_libraries(datachannel-static PRIVATE OpenSSL::SSL) endif() if (USE_JUICE) add_subdirectory(deps/libjuice EXCLUDE_FROM_ALL) target_compile_definitions(datachannel PRIVATE USE_JUICE=1) - target_link_libraries(datachannel LibJuice::LibJuiceStatic) + target_link_libraries(datachannel PRIVATE LibJuice::LibJuiceStatic) target_compile_definitions(datachannel-static PRIVATE USE_JUICE=1) - target_link_libraries(datachannel-static LibJuice::LibJuiceStatic) + target_link_libraries(datachannel-static PRIVATE LibJuice::LibJuiceStatic) else() find_package(LibNice REQUIRED) target_compile_definitions(datachannel PRIVATE USE_JUICE=0) - target_link_libraries(datachannel LibNice::LibNice) + target_link_libraries(datachannel PRIVATE LibNice::LibNice) target_compile_definitions(datachannel-static PRIVATE USE_JUICE=0) - target_link_libraries(datachannel-static LibNice::LibNice) + target_link_libraries(datachannel-static PRIVATE LibNice::LibNice) endif() add_library(LibDataChannel::LibDataChannel ALIAS datachannel) diff --git a/include/rtc/peerconnection.hpp b/include/rtc/peerconnection.hpp index 334d63a..eea7930 100644 --- a/include/rtc/peerconnection.hpp +++ b/include/rtc/peerconnection.hpp @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -44,6 +45,9 @@ class IceTransport; class DtlsTransport; class SctpTransport; +using certificate_ptr = std::shared_ptr; +using future_certificate_ptr = std::shared_future; + class PeerConnection : public std::enable_shared_from_this { public: enum class State : int { @@ -126,7 +130,7 @@ private: void resetCallbacks(); const Configuration mConfig; - const std::shared_ptr mCertificate; + const future_certificate_ptr mCertificate; std::optional mLocalDescription, mRemoteDescription; mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex; diff --git a/src/certificate.cpp b/src/certificate.cpp index 52dcf76..ca35734 100644 --- a/src/certificate.cpp +++ b/src/certificate.cpp @@ -141,14 +141,9 @@ string make_fingerprint(gnutls_x509_crt_t crt) { return oss.str(); } -shared_ptr make_certificate(const string &commonName) { - static std::unordered_map> cache; - static std::mutex cacheMutex; - - std::lock_guard lock(cacheMutex); - if (auto it = cache.find(commonName); it != cache.end()) - return it->second; +namespace { +certificate_ptr make_certificate_impl(string commonName) { std::unique_ptr crt(create_crt(), delete_crt); std::unique_ptr privkey(create_privkey(), delete_privkey); @@ -174,11 +169,11 @@ shared_ptr make_certificate(const string &commonName) { check_gnutls(gnutls_x509_crt_sign2(*crt, *crt, *privkey, GNUTLS_DIG_SHA256, 0), "Unable to auto-sign certificate"); - auto certificate = std::make_shared(*crt, *privkey); - cache.emplace(std::make_pair(commonName, certificate)); - return certificate; + return std::make_shared(*crt, *privkey); } +} // namespace + } // namespace rtc #else @@ -236,15 +231,9 @@ string make_fingerprint(X509 *x509) { return oss.str(); } +namespace { -shared_ptr make_certificate(const string &commonName) { - static std::unordered_map> cache; - static std::mutex cacheMutex; - - std::lock_guard lock(cacheMutex); - if (auto it = cache.find(commonName); it != cache.end()) - return it->second; - +certificate_ptr make_certificate_impl(string commonName) { shared_ptr x509(X509_new(), X509_free); shared_ptr pkey(EVP_PKEY_new(), EVP_PKEY_free); @@ -281,12 +270,54 @@ shared_ptr make_certificate(const string &commonName) { if (!X509_sign(x509.get(), pkey.get(), EVP_sha256())) throw std::runtime_error("Unable to auto-sign certificate"); - auto certificate = std::make_shared(x509, pkey); - cache.emplace(std::make_pair(commonName, certificate)); - return certificate; + return std::make_shared(x509, pkey); } +} // namespace + } // namespace rtc #endif +// Common for GnuTLS and OpenSSL + +namespace rtc { + +namespace { + +// Helper function roughly equivalent to std::async with policy std::launch::async +// since std::async might be unreliable on some platforms (e.g. Mingw32 on Windows) +template +std::future(std::decay_t...)>> thread_call(F &&f, + Args &&... args) { + using R = std::result_of_t(std::decay_t...)>; + std::packaged_task task(std::bind(f, std::forward(args)...)); + std::future future = task.get_future(); + std::thread t(std::move(task)); + t.detach(); + return future; +} + +static std::unordered_map CertificateCache; +static std::mutex CertificateCacheMutex; + +} // namespace + +future_certificate_ptr make_certificate(string commonName) { + std::lock_guard lock(CertificateCacheMutex); + + if (auto it = CertificateCache.find(commonName); it != CertificateCache.end()) + return it->second; + + auto future = thread_call(make_certificate_impl, commonName); + auto shared = future.share(); + CertificateCache.emplace(std::move(commonName), shared); + return shared; +} + +void CleanupCertificateCache() { + std::lock_guard lock(CertificateCacheMutex); + CertificateCache.clear(); +} + +} // namespace rtc diff --git a/src/certificate.hpp b/src/certificate.hpp index cda2fb5..fd5affb 100644 --- a/src/certificate.hpp +++ b/src/certificate.hpp @@ -21,6 +21,7 @@ #include "include.hpp" +#include #include #if USE_GNUTLS @@ -62,7 +63,12 @@ string make_fingerprint(gnutls_x509_crt_t crt); string make_fingerprint(X509 *x509); #endif -std::shared_ptr make_certificate(const string &commonName); +using certificate_ptr = std::shared_ptr; +using future_certificate_ptr = std::shared_future; + +future_certificate_ptr make_certificate(string commonName); // cached + +void CleanupCertificateCache(); } // namespace rtc diff --git a/src/dtlstransport.cpp b/src/dtlstransport.cpp index 7723130..8d7d842 100644 --- a/src/dtlstransport.cpp +++ b/src/dtlstransport.cpp @@ -63,9 +63,8 @@ void DtlsTransport::Cleanup() { // Nothing to do } -DtlsTransport::DtlsTransport(shared_ptr lower, shared_ptr certificate, - verifier_callback verifierCallback, - state_callback stateChangeCallback) +DtlsTransport::DtlsTransport(shared_ptr lower, certificate_ptr certificate, + verifier_callback verifierCallback, state_callback stateChangeCallback) : Transport(lower), mCertificate(certificate), mState(State::Disconnected), mVerifierCallback(std::move(verifierCallback)), mStateChangeCallback(std::move(stateChangeCallback)) { diff --git a/src/dtlstransport.hpp b/src/dtlstransport.hpp index d0e6030..c8200b9 100644 --- a/src/dtlstransport.hpp +++ b/src/dtlstransport.hpp @@ -51,7 +51,7 @@ public: using verifier_callback = std::function; using state_callback = std::function; - DtlsTransport(std::shared_ptr lower, std::shared_ptr certificate, + DtlsTransport(std::shared_ptr lower, certificate_ptr certificate, verifier_callback verifierCallback, state_callback stateChangeCallback); ~DtlsTransport(); @@ -65,7 +65,7 @@ private: void changeState(State state); void runRecvLoop(); - const std::shared_ptr mCertificate; + const certificate_ptr mCertificate; Queue mIncomingQueue; std::atomic mState; diff --git a/src/init.cpp b/src/init.cpp index 38afc1a..2fbfcea 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -18,6 +18,7 @@ #include "init.hpp" +#include "certificate.hpp" #include "dtlstransport.hpp" #include "sctptransport.hpp" @@ -74,6 +75,7 @@ Init::Init() { } Init::~Init() { + CleanupCertificateCache(); DtlsTransport::Cleanup(); SctpTransport::Cleanup(); diff --git a/src/peerconnection.cpp b/src/peerconnection.cpp index 234b3eb..82794b6 100644 --- a/src/peerconnection.cpp +++ b/src/peerconnection.cpp @@ -269,9 +269,10 @@ shared_ptr PeerConnection::initDtlsTransport() { if (auto transport = std::atomic_load(&mDtlsTransport)) return transport; + auto certificate = mCertificate.get(); auto lower = std::atomic_load(&mIceTransport); auto transport = std::make_shared( - lower, mCertificate, weak_bind_verifier(&PeerConnection::checkFingerprint, this, _1), + lower, certificate, weak_bind_verifier(&PeerConnection::checkFingerprint, this, _1), [this, weak_this = weak_from_this()](DtlsTransport::State state) { auto shared_this = weak_this.lock(); if (!shared_this) @@ -513,9 +514,11 @@ void PeerConnection::processLocalDescription(Description description) { if (auto remote = remoteDescription()) remoteSctpPort = remote->sctpPort(); + auto certificate = mCertificate.get(); // wait for certificate if not ready + std::lock_guard lock(mLocalDescriptionMutex); mLocalDescription.emplace(std::move(description)); - mLocalDescription->setFingerprint(mCertificate->fingerprint()); + mLocalDescription->setFingerprint(certificate->fingerprint()); mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT)); mLocalDescription->setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE);