diff --git a/src/tlstransport.cpp b/src/tlstransport.cpp index 188187c..913fd72 100644 --- a/src/tlstransport.cpp +++ b/src/tlstransport.cpp @@ -62,31 +62,38 @@ void TlsTransport::Cleanup() { } TlsTransport::TlsTransport(shared_ptr lower, string host, state_callback callback) - : Transport(lower, std::move(callback)) { + : Transport(lower, std::move(callback)), mHost(std::move(host)) { PLOG_DEBUG << "Initializing TLS transport (GnuTLS)"; + check_gnutls(gnutls_certificate_allocate_credentials(&mCreds)); check_gnutls(gnutls_init(&mSession, GNUTLS_CLIENT)); try { + check_gnutls(gnutls_certificate_set_x509_system_trust(mCreds)); + check_gnutls(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCreds)); + gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0); + const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128"; const char *err_pos = NULL; check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos), "Failed to set TLS priorities"); - gnutls_session_set_ptr(mSession, this); + PLOG_VERBOSE << "Server Name Indication: " << mHost; + gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost.data(), mHost.size()); + + gnutls_session_set_ptr(mSession, this); gnutls_transport_set_ptr(mSession, this); gnutls_transport_set_push_function(mSession, WriteCallback); gnutls_transport_set_pull_function(mSession, ReadCallback); gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback); - gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size()); - - mRecvThread = std::thread(&TlsTransport::runRecvLoop, this); + mRecvThread = std::thread(&TlsTransport::runRecvLoop, this); registerIncoming(); } catch (...) { gnutls_deinit(mSession); + gnutls_certificate_free_credentials(mCreds); throw; } } @@ -94,6 +101,7 @@ TlsTransport::TlsTransport(shared_ptr lower, string host, state_ca TlsTransport::~TlsTransport() { stop(); gnutls_deinit(mSession); + gnutls_certificate_free_credentials(mCreds); } bool TlsTransport::stop() { @@ -111,6 +119,9 @@ bool TlsTransport::send(message_ptr message) { return false; PLOG_VERBOSE << "Send size=" << message->size(); + if(message->size() == 0) + return true; + ssize_t ret; do { ret = gnutls_record_send(mSession, message->data(), message->size()); @@ -196,20 +207,37 @@ ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) { TlsTransport *t = static_cast(ptr); - while (auto next = t->mIncomingQueue.pop()) { - auto message = *next; - if (message->size() > 0) { - ssize_t len = std::min(maxlen, message->size()); - std::memcpy(data, message->data(), len); - gnutls_transport_set_errno(t->mSession, 0); - return len; - } - t->recv(message); // Pass zero-sized messages through + message_ptr &message = t->mIncomingMessage; + size_t &position = t->mIncomingMessagePosition; + + if(message && position >= message->size()) + message.reset(); + + if(!message) { + position = 0; + while (auto next = t->mIncomingQueue.pop()) { + message = *next; + if (message->size() > 0) + break; + + t->recv(message); // Pass zero-sized messages through + } + } + + if(message) { + size_t available = message->size() - position; + ssize_t len = std::min(maxlen, available); + std::memcpy(data, message->data() + position, len); + position+= len; + gnutls_transport_set_errno(t->mSession, 0); + return len; + } + else { + // Closed + gnutls_transport_set_errno(t->mSession, 0); + return 0; } - // Closed - gnutls_transport_set_errno(t->mSession, 0); - return 0; } int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) { @@ -308,6 +336,8 @@ TlsTransport::TlsTransport(shared_ptr lower, string host, state_ca throw std::runtime_error("Failed to create SSL instance"); SSL_set_ex_data(mSsl, TransportExIndex, this); + + PLOG_VERBOSE << "Server Name Indication: " << host; SSL_set_tlsext_host_name(mSsl, host.c_str()); SSL_set_connect_state(mSsl); diff --git a/src/tlstransport.hpp b/src/tlstransport.hpp index 6f68b23..4d193db 100644 --- a/src/tlstransport.hpp +++ b/src/tlstransport.hpp @@ -56,10 +56,14 @@ protected: void runRecvLoop(); Queue mIncomingQueue; + message_ptr mIncomingMessage; + size_t mIncomingMessagePosition = 0; std::thread mRecvThread; #if USE_GNUTLS gnutls_session_t mSession; + gnutls_certificate_credentials_t mCreds; + string mHost; static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len); static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);