From 0457b679410b036e9616a20412c2127245d4ed10 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Fri, 12 Jul 2024 22:49:14 +0300 Subject: [PATCH] chore: revisit tls_socket code (#296) Signed-off-by: Roman Gershman --- util/tls/tls_engine.cc | 48 ++++++-------- util/tls/tls_engine.h | 9 ++- util/tls/tls_socket.cc | 139 +++++++++++++++++++---------------------- util/tls/tls_socket.h | 1 + 4 files changed, 94 insertions(+), 103 deletions(-) diff --git a/util/tls/tls_engine.cc b/util/tls/tls_engine.cc index c4d9653e..890e0c36 100644 --- a/util/tls/tls_engine.cc +++ b/util/tls/tls_engine.cc @@ -41,36 +41,28 @@ static Engine::OpResult ToOpResult(const SSL* ssl, int result, const char* locat return nonstd::make_unexpected(error); } - int want = SSL_want(ssl); - - if (want == SSL_NOTHING) { - int ssl_error = SSL_get_error(ssl, result); - int io_err = errno; - - switch (ssl_error) { - case SSL_ERROR_ZERO_RETURN: - break; - case SSL_ERROR_SYSCALL: - LOG(WARNING) << "SSL syscall error " << io_err << ":" << result << " " << location; - break; - case SSL_ERROR_SSL: - LOG(WARNING) << "SSL protocol error " << io_err << ":" << result << " " << location; - break; - default: - LOG(WARNING) << "Unexpected SSL error " << io_err << ":" << result << " " << location; - break; - } - - return Engine::EOF_STREAM; + int ssl_error = SSL_get_error(ssl, result); + int io_err = errno; + + switch (ssl_error) { + case SSL_ERROR_ZERO_RETURN: + break; + case SSL_ERROR_WANT_READ: + return Engine::NEED_READ_AND_MAYBE_WRITE; + case SSL_ERROR_WANT_WRITE: + VLOG(1) << "SSL_ERROR_WANT_WRITE " << location; + return Engine::NEED_WRITE; + case SSL_ERROR_SYSCALL: + LOG(WARNING) << "SSL syscall error " << io_err << ":" << result << " " << location; + break; + case SSL_ERROR_SSL: + LOG(WARNING) << "SSL protocol error " << io_err << ":" << result << " " << location; + break; + default: + LOG(WARNING) << "Unexpected SSL error " << io_err << ":" << result << " " << location; + break; } - if (SSL_WRITING == want) 
- return Engine::NEED_WRITE; - if (SSL_READING == want) - return Engine::NEED_READ_AND_MAYBE_WRITE; - - LOG(ERROR) << "Unsupported want value " << want << ", ssl_error: " << SSL_get_error(ssl, result); - return Engine::EOF_STREAM; } diff --git a/util/tls/tls_engine.h b/util/tls/tls_engine.h index c45a85f9..b5af8697 100644 --- a/util/tls/tls_engine.h +++ b/util/tls/tls_engine.h @@ -17,6 +17,13 @@ class Engine { enum HandshakeType { CLIENT = 1, SERVER = 2 }; enum OpCode { EOF_STREAM = -1, + + // We use BIO buffers, therefore any SSL operation can end up writing to the internal BIO + // and result in success, even though the data has not been flushed to the underlying socket. + // See https://www.openssl.org/docs/man1.0.2/man3/BIO_new_bio_pair.html + // As a result, we must flush output buffer (if OutputPending() > 0) before we do any + // Socket reads. We could flush after each SSL operation but that would result in fragmented + // Socket writes which we want to avoid. NEED_READ_AND_MAYBE_WRITE = -2, NEED_WRITE = -3, }; @@ -89,7 +96,7 @@ class Engine { void CommitInput(unsigned sz); // Returns size of pending data that needs to be flushed out from SSL to I/O. - // See https://www.openssl.org/docs/man1.1.0/man3/BIO_new_bio_pair.html + // See https://www.openssl.org/docs/man1.0.2/man3/BIO_new_bio_pair.html // Specifically, warning that says: "An application must not rely on the error value of // SSL_operation() but must assure that the write buffer is always flushed first". size_t OutputPending() const { diff --git a/util/tls/tls_socket.cc b/util/tls/tls_socket.cc index df5cb39b..d2c81b4d 100644 --- a/util/tls/tls_socket.cc +++ b/util/tls/tls_socket.cc @@ -91,7 +91,7 @@ auto TlsSocket::Shutdown(int how) -> error_code { Engine::OpResult op_result = engine_->Shutdown(); if (op_result) { // engine_ could send notification messages to the peer. 
- MaybeSendOutput(); + std::ignore = MaybeSendOutput(); } // In any case we should also shutdown the underlying TCP socket without relying on the @@ -132,14 +132,10 @@ auto TlsSocket::Accept() -> AcceptResult { if (op_val >= 0) { // Shutdown or empty read/write may return 0. break; } - if (op_val == Engine::EOF_STREAM) { - return make_unexpected(make_error_code(errc::connection_reset)); - } - if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) { - ec = HandleSocketRead(); - if (ec) - return make_unexpected(ec); - } + + ec = HandleOp(op_val); + if (ec) + return make_unexpected(ec); } return nullptr; @@ -162,28 +158,14 @@ error_code TlsSocket::Connect(const endpoint_type& endpoint, // Flush the ssl data to the socket and run the loop that ensures handshaking converges. int op_val = *op_result; - error_code ec; // it should guide us to write and then read. DCHECK_EQ(op_val, Engine::NEED_READ_AND_MAYBE_WRITE); while (op_val < 0) { - if (op_val == Engine::EOF_STREAM) { - return make_error_code(errc::connection_reset); - } + error_code ec = HandleOp(op_val); + if (ec) + return ec; - if (op_val == Engine::NEED_WRITE) { - ec = HandleSocketWrite(); - if (ec) - return ec; - } else if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) { - ec = HandleSocketWrite(); - if (ec) - return ec; - - ec = HandleSocketRead(); - if (ec) - return ec; - } op_result = engine_->Handshake(Engine::HandshakeType::CLIENT); if (!op_result) { return std::error_code(op_result.error(), std::system_category()); @@ -191,7 +173,11 @@ error_code TlsSocket::Connect(const endpoint_type& endpoint, op_val = *op_result; } - return ec; + const auto* cipher = SSL_get_current_cipher(engine_->native_handle()); + VLOG(1) << "SSL handshake success, chosen " << SSL_CIPHER_get_name(cipher) << "/" + << SSL_CIPHER_get_version(cipher); + + return {}; } auto TlsSocket::Close() -> error_code { @@ -245,11 +231,6 @@ io::Result TlsSocket::RecvMsg(const msghdr& msg, int flags) { return make_unexpected(SSL2Error(op_result.error())); 
} - error_code ec = MaybeSendOutput(); - if (ec) { - return make_unexpected(ec); - } - int op_val = *op_result; if (spin_count.Check(op_val <= 0)) { // Once every 30 seconds. @@ -267,26 +248,18 @@ io::Result TlsSocket::RecvMsg(const msghdr& msg, int flags) { ++io; --io_len; if (io_len == 0) - break; + break; // Finished reading everything. dest = Engine::MutableBuffer{reinterpret_cast(io->iov_base), io->iov_len}; } - continue; // We read everything we asked for - lets retry. + // We read everything we asked for but there are still buffers left to fill. + continue; } break; } - if (read_total) // if we read something lets return it before we handle other states. - break; - - if (op_val == Engine::EOF_STREAM) { - return make_unexpected(make_error_code(errc::connection_reset)); - } - - if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) { - ec = HandleSocketRead(); - if (ec) - return make_unexpected(ec); - } + error_code ec = HandleOp(op_val); + if (ec) + return make_unexpected(ec); } return read_total; } @@ -307,12 +280,12 @@ io::Result TlsSocket::WriteSome(const iovec* ptr, uint32_t len) { // Chosen to be sufficiently smaller than the usual MTU (1500) and a multiple of 16. // IP - max 24 bytes. TCP - max 60 bytes. TLS - max 21 bytes. 
constexpr size_t kBufferSize = 1392; - io::Result ec; + io::Result res; size_t total_sent = 0; while (len) { if (ptr->iov_len > kBufferSize || len == 1) { - ec = SendBuffer(Engine::Buffer{reinterpret_cast(ptr->iov_base), ptr->iov_len}); + res = SendBuffer(Engine::Buffer{reinterpret_cast(ptr->iov_base), ptr->iov_len}); ptr++; len--; } else { @@ -324,18 +297,18 @@ io::Result TlsSocket::WriteSome(const iovec* ptr, uint32_t len) { ptr++; len--; } - ec = SendBuffer({scratch, buffered_size}); + res = SendBuffer({scratch, buffered_size}); } - if (!ec.has_value()) { - return ec; - } else { - total_sent += ec.value(); + if (!res) { + return res; } + total_sent += *res; } return total_sent; } io::Result TlsSocket::SendBuffer(Engine::Buffer buf) { + // Sending buffer into ssl. DCHECK(engine_); DCHECK_GT(buf.size(), 0u); @@ -348,17 +321,7 @@ io::Result TlsSocket::SendBuffer(Engine::Buffer buf) { return make_unexpected(SSL2Error(op_result.error())); } - error_code ec = MaybeSendOutput(); - if (ec) { - return make_unexpected(ec); - } - int op_val = *op_result; - if (spin_count.Check(op_val <= 0)) { - // Once every 30 seconds. - LOG_EVERY_T(WARNING, 30) << "IO loop spin limit reached. Limit: " << spin_count.Limit() - << " Spins: " << spin_count.Spins(); - } if (op_val > 0) { send_total += op_val; @@ -370,15 +333,15 @@ io::Result TlsSocket::SendBuffer(Engine::Buffer buf) { } } - if (op_val == Engine::EOF_STREAM) { - return make_unexpected(make_error_code(errc::connection_reset)); + if (spin_count.Check(op_val <= 0)) { + // Once every 30 seconds. + LOG_EVERY_T(WARNING, 30) << "IO loop spin limit reached. 
Limit: " << spin_count.Limit() + << " Spins: " << spin_count.Spins(); } - if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) { - ec = HandleSocketRead(); - if (ec) - return make_unexpected(ec); - } + error_code ec = HandleOp(op_val); + if (ec) + return make_unexpected(ec); } return send_total; @@ -395,6 +358,9 @@ SSL* TlsSocket::ssl_handle() { } auto TlsSocket::MaybeSendOutput() -> error_code { + if (engine_->OutputPending() == 0) + return {}; + // This function is present in both read and write paths. // meaning that both of them can be called concurrently from differrent fibers and then // race over flushing the output buffer. We use state_ to prevent that. @@ -419,6 +385,10 @@ auto TlsSocket::MaybeSendOutput() -> error_code { } auto TlsSocket::HandleSocketRead() -> error_code { + error_code ec = MaybeSendOutput(); + if (ec) + return ec; + if (state_ & READ_IN_PROGRESS) { // We need to Yield because otherwise we might end up in an infinite loop. // See also comments in MaybeSendOutput. @@ -434,6 +404,8 @@ auto TlsSocket::HandleSocketRead() -> error_code { return esz.error(); } + DVLOG(1) << "TlsSocket:Read " << *esz << " bytes"; + engine_->CommitInput(*esz); return error_code{}; @@ -441,26 +413,45 @@ auto TlsSocket::HandleSocketRead() -> error_code { error_code TlsSocket::HandleSocketWrite() { Engine::Buffer buffer = engine_->PeekOutputBuf(); + DCHECK(!buffer.empty()); + + if (buffer.empty()) + return {}; + // we do not allow concurrent writes from multiple fibers. + state_ |= WRITE_IN_PROGRESS; while (!buffer.empty()) { - // we do not allow concurrent writes from multiple fibers. - state_ |= WRITE_IN_PROGRESS; io::Result write_result = next_sock_->WriteSome(buffer); - // Safe to clear here since the code below is atomic fiber-wise. 
- state_ &= ~WRITE_IN_PROGRESS; DCHECK(engine_); if (!write_result) { + state_ &= ~WRITE_IN_PROGRESS; + return write_result.error(); } CHECK_GT(*write_result, 0u); engine_->ConsumeOutputBuf(*write_result); buffer.remove_prefix(*write_result); } + DCHECK_EQ(engine_->OutputPending(), 0u); + + state_ &= ~WRITE_IN_PROGRESS; return error_code{}; } +error_code TlsSocket::HandleOp(int op_val) { + switch (op_val) { + case Engine::EOF_STREAM: + return make_error_code(errc::connection_reset); + case Engine::NEED_READ_AND_MAYBE_WRITE: + return HandleSocketRead(); + default: + LOG(DFATAL) << "Unsupported " << op_val; + } + return {}; +} + TlsSocket::endpoint_type TlsSocket::LocalEndpoint() const { return next_sock_->LocalEndpoint(); } diff --git a/util/tls/tls_socket.h b/util/tls/tls_socket.h index 57d7b087..7479ae4f 100644 --- a/util/tls/tls_socket.h +++ b/util/tls/tls_socket.h @@ -95,6 +95,7 @@ class TlsSocket final : public FiberSocketBase { error_code HandleSocketRead(); error_code HandleSocketWrite(); + error_code HandleOp(int op); std::unique_ptr next_sock_; std::unique_ptr engine_;