Fixed TLS layer with GnuTLS

This commit is contained in:
Paul-Louis Ageneau
2020-06-01 00:43:41 +02:00
parent a47ddc5838
commit 755b3e9dac
2 changed files with 51 additions and 17 deletions

View File

@ -62,31 +62,38 @@ void TlsTransport::Cleanup() {
} }
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback) TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
: Transport(lower, std::move(callback)) { : Transport(lower, std::move(callback)), mHost(std::move(host)) {
PLOG_DEBUG << "Initializing TLS transport (GnuTLS)"; PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
check_gnutls(gnutls_certificate_allocate_credentials(&mCreds));
check_gnutls(gnutls_init(&mSession, GNUTLS_CLIENT)); check_gnutls(gnutls_init(&mSession, GNUTLS_CLIENT));
try { try {
check_gnutls(gnutls_certificate_set_x509_system_trust(mCreds));
check_gnutls(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCreds));
gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0);
const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128"; const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
const char *err_pos = NULL; const char *err_pos = NULL;
check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos), check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
"Failed to set TLS priorities"); "Failed to set TLS priorities");
gnutls_session_set_ptr(mSession, this); PLOG_VERBOSE << "Server Name Indication: " << mHost;
gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost.data(), mHost.size());
gnutls_session_set_ptr(mSession, this);
gnutls_transport_set_ptr(mSession, this); gnutls_transport_set_ptr(mSession, this);
gnutls_transport_set_push_function(mSession, WriteCallback); gnutls_transport_set_push_function(mSession, WriteCallback);
gnutls_transport_set_pull_function(mSession, ReadCallback); gnutls_transport_set_pull_function(mSession, ReadCallback);
gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback); gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size()); mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
registerIncoming(); registerIncoming();
} catch (...) { } catch (...) {
gnutls_deinit(mSession); gnutls_deinit(mSession);
gnutls_certificate_free_credentials(mCreds);
throw; throw;
} }
} }
@ -94,6 +101,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
TlsTransport::~TlsTransport() { TlsTransport::~TlsTransport() {
stop(); stop();
gnutls_deinit(mSession); gnutls_deinit(mSession);
gnutls_certificate_free_credentials(mCreds);
} }
bool TlsTransport::stop() { bool TlsTransport::stop() {
@ -111,6 +119,9 @@ bool TlsTransport::send(message_ptr message) {
return false; return false;
PLOG_VERBOSE << "Send size=" << message->size(); PLOG_VERBOSE << "Send size=" << message->size();
if(message->size() == 0)
return true;
ssize_t ret; ssize_t ret;
do { do {
ret = gnutls_record_send(mSession, message->data(), message->size()); ret = gnutls_record_send(mSession, message->data(), message->size());
@ -196,20 +207,37 @@ ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data
ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) { ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
TlsTransport *t = static_cast<TlsTransport *>(ptr); TlsTransport *t = static_cast<TlsTransport *>(ptr);
while (auto next = t->mIncomingQueue.pop()) {
auto message = *next;
if (message->size() > 0) {
ssize_t len = std::min(maxlen, message->size());
std::memcpy(data, message->data(), len);
gnutls_transport_set_errno(t->mSession, 0);
return len;
}
t->recv(message); // Pass zero-sized messages through message_ptr &message = t->mIncomingMessage;
size_t &position = t->mIncomingMessagePosition;
if(message && position >= message->size())
message.reset();
if(!message) {
position = 0;
while (auto next = t->mIncomingQueue.pop()) {
message = *next;
if (message->size() > 0)
break;
t->recv(message); // Pass zero-sized messages through
}
}
if(message) {
size_t available = message->size() - position;
ssize_t len = std::min(maxlen, available);
std::memcpy(data, message->data() + position, len);
position+= len;
gnutls_transport_set_errno(t->mSession, 0);
return len;
}
else {
// Closed
gnutls_transport_set_errno(t->mSession, 0);
return 0;
} }
// Closed
gnutls_transport_set_errno(t->mSession, 0);
return 0;
} }
int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) { int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
@ -308,6 +336,8 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
throw std::runtime_error("Failed to create SSL instance"); throw std::runtime_error("Failed to create SSL instance");
SSL_set_ex_data(mSsl, TransportExIndex, this); SSL_set_ex_data(mSsl, TransportExIndex, this);
PLOG_VERBOSE << "Server Name Indication: " << host;
SSL_set_tlsext_host_name(mSsl, host.c_str()); SSL_set_tlsext_host_name(mSsl, host.c_str());
SSL_set_connect_state(mSsl); SSL_set_connect_state(mSsl);

View File

@ -56,10 +56,14 @@ protected:
void runRecvLoop(); void runRecvLoop();
Queue<message_ptr> mIncomingQueue; Queue<message_ptr> mIncomingQueue;
message_ptr mIncomingMessage;
size_t mIncomingMessagePosition = 0;
std::thread mRecvThread; std::thread mRecvThread;
#if USE_GNUTLS #if USE_GNUTLS
gnutls_session_t mSession; gnutls_session_t mSession;
gnutls_certificate_credentials_t mCreds;
string mHost;
static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len); static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen); static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);