diff --git a/util/fibers/epoll_socket.cc b/util/fibers/epoll_socket.cc
index e9536fc1..1aaca377 100644
--- a/util/fibers/epoll_socket.cc
+++ b/util/fibers/epoll_socket.cc
@@ -388,6 +388,7 @@ void EpollSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgress
   async_write_pending_ = 1;
 }
 
+// TODO: implement true async functionality.
 void EpollSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
   auto res = ReadSome(v, len);
   cb(res);
diff --git a/util/tls/CMakeLists.txt b/util/tls/CMakeLists.txt
index 21fcba63..fb2fd27f 100644
--- a/util/tls/CMakeLists.txt
+++ b/util/tls/CMakeLists.txt
@@ -4,3 +4,4 @@ add_library(tls_lib tls_engine.cc tls_socket.cc)
 cxx_link(tls_lib fibers2 OpenSSL::SSL)
 
 cxx_test(tls_engine_test tls_lib LABELS CI)
+cxx_test(tls_socket_test tls_lib LABELS CI)
diff --git a/util/tls/tls_socket.cc b/util/tls/tls_socket.cc
index 67229d57..5d2bb2e4 100644
--- a/util/tls/tls_socket.cc
+++ b/util/tls/tls_socket.cc
@@ -329,23 +329,217 @@ io::Result<TlsSocket::PushResult> TlsSocket::PushToEngine(const iovec* ptr, uint
   return res;
 }
 
-// TODO: to implement async functionality.
+void TlsSocket::HandleOpAsync(int op_val) {
+  switch (op_val) {
+    case Engine::EOF_STREAM:
+      VLOG(1) << "EOF_STREAM received " << next_sock_->native_handle();
+      async_write_req_->caller_completion_cb(
+          make_unexpected(make_error_code(errc::connection_aborted)));
+      break;
+    case Engine::NEED_READ_AND_MAYBE_WRITE:
+      HandleUpstreamAsyncRead();
+      break;
+    case Engine::NEED_WRITE:
+      MaybeSendOutputAsync();
+      break;
+    default:
+      LOG(DFATAL) << "Unsupported " << op_val;
+  }
+}
+
+void TlsSocket::AsyncWriteReq::Run() {
+  if (state == AsyncWriteReq::PushToEngine) {
+    io::Result<PushResult> push_res = owner->PushToEngine(vec, len);
+    if (!push_res) {
+      caller_completion_cb(make_unexpected(push_res.error()));
+      return;
+    }
+    last_push = *push_res;
+    state = AsyncWriteReq::HandleOpAsync;
+  }
+
+  if (state == AsyncWriteReq::HandleOpAsync) {
+    state = AsyncWriteReq::MaybeSendOutputAsync;
+    if (last_push.engine_opcode < 0) {
+      owner->HandleOpAsync(last_push.engine_opcode);
+    }
+  }
+
+  if (state == AsyncWriteReq::MaybeSendOutputAsync) {
+    state = AsyncWriteReq::PushToEngine;
+    if (last_push.written > 0) {
+      DCHECK(!continuation);
+      continuation = [this]() {
+        state = AsyncWriteReq::Done;
+        caller_completion_cb(last_push.written);
+      };
+      owner->MaybeSendOutputAsync();
+    }
+  }
+}
+
 void TlsSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
-  io::Result<size_t> res = WriteSome(v, len);
-  cb(res);
+  CHECK(!async_write_req_.has_value());
+  async_write_req_.emplace(AsyncWriteReq(this, std::move(cb), v, len));
+  async_write_req_->Run();
 }
 
+void TlsSocket::AsyncReadReq::Run() {
+  DCHECK_GT(len, 0u);
+
+  while (true) {
+    DCHECK(!dest.empty());
+
+    size_t read_len = std::min(dest.size(), size_t(INT_MAX));
+
+    Engine::OpResult op_result = owner->engine_->Read(dest.data(), read_len);
+
+    int op_val = op_result;
+
+    DVLOG(2) << "Engine::Read " << dest.size() << " bytes, got " << op_val;
+
+    if (op_val > 0) {
+      read_total += op_val;
+
+      // The engine returned exactly what we asked for: either keep filling the current buffer
+      // (the request was capped at INT_MAX) or advance to the next iovec entry.
+      if (size_t(op_val) == read_len) {
+        if (size_t(op_val) < dest.size()) {
+          dest.remove_prefix(op_val);
+        } else {
+          ++vec;
+          --len;
+          if (len == 0) {
+            // We are done. Call the completion callback.
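+            // read_total already includes the bytes placed into this last iovec entry.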
+            caller_completion_cb(read_total);
+            return;
+          }
+          dest = Engine::MutableBuffer{reinterpret_cast<uint8_t*>(vec->iov_base), vec->iov_len};
+        }
+        // We read everything we asked for but there are still buffers left to fill.
+        continue;
+      }
+      break;
+    }
+
+    // HandleOpAsync will eventually call Run() again, so bail out here.
+    return owner->HandleOpAsync(op_val);
+  }
+
+  // We are done. Call the completion callback.
+  caller_completion_cb(read_total);
+
+  // TODO: clean up the request so that another read can be queued.
+}
+
-// TODO: to implement async functionality.
 void TlsSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
-  io::Result<size_t> res = ReadSome(v, len);
-  cb(res);
+  CHECK(!async_read_req_.has_value());
+  auto req = AsyncReadReq(this, std::move(cb), v, len);
+  req.dest = {reinterpret_cast<uint8_t*>(v->iov_base), v->iov_len};
+  async_read_req_.emplace(std::move(req));
+  async_read_req_->Run();
 }
 
 SSL* TlsSocket::ssl_handle() {
   return engine_ ? engine_->native_handle() : nullptr;
 }
 
+void TlsSocket::HandleUpstreamAsyncWrite(io::Result<size_t> write_result, Engine::Buffer buffer) {
+  if (!write_result) {
+    state_ &= ~WRITE_IN_PROGRESS;
+
+    // broken_pipe happens when the other side closes the connection; do not log it.
+    if (write_result.error() != errc::broken_pipe) {
+      VSOCK(1) << "HandleUpstreamWrite failed " << write_result.error();
+    }
+
+    // We are done. Erroneous exit.
+    async_write_req_->caller_completion_cb(write_result);
+    return;
+  }
+
+  CHECK_GT(*write_result, 0u);
+  upstream_write_ += *write_result;
+  engine_->ConsumeOutputBuf(*write_result);
+  buffer.remove_prefix(*write_result);
+
+  // We are not done. Re-arm the async write until we drive it to completion or error.
+  if (!buffer.empty()) {
+    auto& scratch = async_write_req_->scratch_iovec;
+    scratch.iov_base = const_cast<uint8_t*>(buffer.data());
+    scratch.iov_len = buffer.size();
+    next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) {
+      HandleUpstreamAsyncWrite(write_result, buffer);
+    });
+    return;
+  }
+
+  if (engine_->OutputPending() > 0) {
+    LOG(INFO) << "ssl buffer is not empty with " << engine_->OutputPending()
+              << " bytes. short write detected";
+  }
+
+  state_ &= ~WRITE_IN_PROGRESS;
+
+  // If there is a continuation, run it and let it yield back to the main loop.
+  if (async_write_req_->continuation) {
+    auto cont = std::exchange(async_write_req_->continuation, std::function<void()>{});
+    cont();
+    return;
+  }
+
+  // Yield back to the main loop.
+  return async_write_req_->Run();
+}
+
+void TlsSocket::StartUpstreamWrite() {
+  Engine::Buffer buffer = engine_->PeekOutputBuf();
+  DCHECK(!buffer.empty());
+  DCHECK((state_ & WRITE_IN_PROGRESS) == 0);
+
+  if (buffer.empty()) {
+    // We are done.
+    return;
+  }
+
+  DVLOG(2) << "HandleUpstreamWrite " << buffer.size();
+  // We do not allow concurrent writes from multiple fibers.
+  state_ |= WRITE_IN_PROGRESS;
+
+  auto& scratch = async_write_req_->scratch_iovec;
+  scratch.iov_base = const_cast<uint8_t*>(buffer.data());
+  scratch.iov_len = buffer.size();
+
+  next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) {
+    HandleUpstreamAsyncWrite(write_result, buffer);
+  });
+}
+
+void TlsSocket::MaybeSendOutputAsync() {
+  if (engine_->OutputPending() == 0) {
+    if (async_write_req_->continuation) {
+      auto cont = std::exchange(async_write_req_->continuation, std::function<void()>{});
+      cont();
+      return;
+    }
+    async_write_req_->Run();
+    return;
+  }
+
+  // This function is present in both read and write paths, meaning that both of them can be
+  // called concurrently from different fibers and then race over flushing the output buffer.
+  // We use state_ to prevent that.
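+  // If an upstream write is already in flight and a continuation is queued, we resume the
+  // continuation here; the in-flight completion keeps draining the output buffer.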
+  if (state_ & WRITE_IN_PROGRESS) {
+    if (async_write_req_->continuation) {
+      // TODO: we must "yield", i.e. subscribe as a continuation to the write request, because
+      // otherwise we might deadlock. See the sync version of HandleOp for more info.
+      auto cont = std::exchange(async_write_req_->continuation, std::function<void()>{});
+      cont();
+      return;
+    }
+  }
+
+  StartUpstreamWrite();
+}
+
 auto TlsSocket::MaybeSendOutput() -> error_code {
   if (engine_->OutputPending() == 0)
     return {};
@@ -373,6 +567,46 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
   return HandleUpstreamWrite();
 }
 
+void TlsSocket::StartUpstreamRead() {
+  auto buffer = engine_->PeekInputBuf();
+  state_ |= READ_IN_PROGRESS;
+
+  auto& scratch = async_write_req_->scratch_iovec;
+  scratch.iov_base = const_cast<uint8_t*>(buffer.data());
+  scratch.iov_len = buffer.size();
+
+  next_sock_->AsyncReadSome(&scratch, 1, [this](auto read_result) {
+    state_ &= ~READ_IN_PROGRESS;
+    if (!read_result) {
+      // Log any errors as well as situations where we have unflushed output.
+      if (read_result.error() != errc::connection_aborted || engine_->OutputPending() > 0) {
+        VSOCK(1) << "HandleUpstreamRead failed " << read_result.error();
+      }
+      // Erroneous path. Apply the completion callback and exit.
+      async_write_req_->caller_completion_cb(read_result);
+      return;
+    }
+
+    DVLOG(1) << "HandleUpstreamRead " << *read_result << " bytes";
+    engine_->CommitInput(*read_result);
+    // We are not done. Give back control to the main loop.
+    async_write_req_->Run();
+  });
+}
+
+void TlsSocket::HandleUpstreamAsyncRead() {
+  auto on_success = [this]() {
+    if (state_ & READ_IN_PROGRESS) {
+      return async_write_req_->Run();
+    }
+
+    StartUpstreamRead();
+  };
+
+  async_write_req_->continuation = on_success;
+  MaybeSendOutputAsync();
+}
+
 auto TlsSocket::HandleUpstreamRead() -> error_code {
   RETURN_ON_ERROR(MaybeSendOutput());
 
@@ -481,8 +715,7 @@ unsigned TlsSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) {
 }
 
 void TlsSocket::ReturnProvided(const ProvidedBuffer& pbuf) {
-  proactor()->DeallocateBuffer(
-      io::MutableBytes{const_cast<uint8_t*>(pbuf.start), pbuf.allocated});
+  proactor()->DeallocateBuffer(io::MutableBytes{const_cast<uint8_t*>(pbuf.start), pbuf.allocated});
 }
 
 bool TlsSocket::IsUDS() const {
diff --git a/util/tls/tls_socket.h b/util/tls/tls_socket.h
index 83a470b0..4fbeddbe 100644
--- a/util/tls/tls_socket.h
+++ b/util/tls/tls_socket.h
@@ -7,6 +7,7 @@
 #include <openssl/ssl.h>
 
 #include <memory>
+#include <optional>
 
 #include "util/fiber_socket_base.h"
 #include "util/tls/tls_engine.h"
@@ -90,7 +91,6 @@ class TlsSocket final : public FiberSocketBase {
   virtual void SetProactor(ProactorBase* p) override;
 
  private:
-
   struct PushResult {
     size_t written = 0;
     int engine_opcode = 0;  // Engine::OpCode
@@ -114,6 +114,65 @@ class TlsSocket final : public FiberSocketBase {
   std::unique_ptr<Engine> engine_;
   size_t upstream_write_ = 0;
 
+  struct AsyncReqBase {
+    AsyncReqBase(TlsSocket* owner, io::AsyncProgressCb caller_cb, const iovec* vec, uint32_t len)
+        : owner(owner), caller_completion_cb(std::move(caller_cb)), vec(vec), len(len) {
+    }
+
+    TlsSocket* owner;
+    // Callback passed from the user.
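+    // Invoked exactly once with either the number of transferred bytes or an error.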
+    io::AsyncProgressCb caller_completion_cb;
+
+    const iovec* vec;
+    uint32_t len;
+
+    std::function<void()> continuation;
+  };
+
+  struct AsyncWriteReq : AsyncReqBase {
+    using AsyncReqBase::AsyncReqBase;
+
+    iovec scratch_iovec;
+    // TODO: simplify state transitions.
+    // TODO: handle async yields to avoid deadlocks (see HandleOp).
+    enum State { PushToEngine, HandleOpAsync, MaybeSendOutputAsync, Done };
+    State state = PushToEngine;
+    PushResult last_push;
+
+    // Main loop.
+    void Run();
+  };
+
+  friend AsyncWriteReq;
+
+  struct AsyncReadReq : AsyncReqBase {
+    using AsyncReqBase::AsyncReqBase;
+
+    Engine::MutableBuffer dest;
+    size_t read_total = 0;
+
+    // Main loop.
+    void Run();
+  };
+
+  friend AsyncReadReq;
+
+  // Asynchronous helpers.
+  void MaybeSendOutputAsync();
+
+  void HandleUpstreamAsyncWrite(io::Result<size_t> write_result, Engine::Buffer buffer);
+  void HandleUpstreamAsyncRead();
+
+  void HandleOpAsync(int op_val);
+
+  void StartUpstreamWrite();
+  void StartUpstreamRead();
+
+  // TODO: clean up the optional before we yield so that the progress callback can dispatch
+  // another async operation.
+  std::optional<AsyncWriteReq> async_write_req_;
+  std::optional<AsyncReadReq> async_read_req_;
+
   enum { WRITE_IN_PROGRESS = 1, READ_IN_PROGRESS = 2, SHUTDOWN_IN_PROGRESS = 4, SHUTDOWN_DONE = 8 };
   uint8_t state_{0};
 };
diff --git a/util/tls/tls_socket_test.cc b/util/tls/tls_socket_test.cc
new file mode 100644
index 00000000..eae010dc
--- /dev/null
+++ b/util/tls/tls_socket_test.cc
@@ -0,0 +1,246 @@
+// Copyright 2021, Beeri 15. All rights reserved.
+// Author: Roman Gershman (romange@gmail.com)
+//
+
+#include "util/tls/tls_socket.h"
+
+#include <openssl/ssl.h>
+
+#include <thread>
+
+#include "base/gtest.h"
+#include "base/logging.h"
+#include "util/fiber_socket_base.h"
+#include "util/fibers/fibers.h"
+#include "util/fibers/synchronization.h"
+
+#ifdef __linux__
+#include "util/fibers/uring_proactor.h"
+#include "util/fibers/uring_socket.h"
+#endif
+#include "util/fibers/epoll_proactor.h"
+
+namespace util {
+namespace fb2 {
+
+constexpr uint32_t kRingDepth = 8;
+using namespace testing;
+
+#ifdef __linux__
+void InitProactor(ProactorBase* p) {
+  if (p->GetKind() == ProactorBase::IOURING) {
+    static_cast<UringProactor*>(p)->Init(0, kRingDepth);
+  } else {
+    static_cast<EpollProactor*>(p)->Init(0);
+  }
+}
+#else
+void InitProactor(ProactorBase* p) {
+  static_cast<EpollProactor*>(p)->Init(0);
+}
+#endif
+
+using namespace std;
+
+enum TlsContextRole { SERVER, CLIENT };
+
+SSL_CTX* CreateSslCntx(TlsContextRole role) {
+  std::string tls_key_file;
+  std::string tls_key_cert;
+  std::string tls_ca_cert_file;
+  SSL_CTX* ctx;
+
+  if (role == TlsContextRole::SERVER) {
+    ctx = SSL_CTX_new(TLS_server_method());
+    // TODO: initialize these paths so the test builds and runs on CI.
+  } else {
+    ctx = SSL_CTX_new(TLS_client_method());
+  }
+  unsigned mask = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
+
+  bool res = SSL_CTX_use_PrivateKey_file(ctx, tls_key_file.c_str(), SSL_FILETYPE_PEM) != 1;
+  EXPECT_FALSE(res);
+  res = SSL_CTX_use_certificate_chain_file(ctx, tls_key_cert.c_str()) != 1;
+  EXPECT_FALSE(res);
+  res = SSL_CTX_load_verify_locations(ctx, tls_ca_cert_file.data(), nullptr) != 1;
+  EXPECT_FALSE(res);
+  res = 1 == SSL_CTX_set_cipher_list(ctx, "DEFAULT");
+  EXPECT_TRUE(res);
+  SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION);
+  SSL_CTX_set_options(ctx, SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS);
+  SSL_CTX_set_verify(ctx, mask, NULL);
+  SSL_CTX_set_dh_auto(ctx, 1);
+  return ctx;
+}
+
+class TlsFiberSocketTest : public testing::TestWithParam<string> {
+ protected:
+  void SetUp() final;
+  void TearDown() final;
+
+  static void SetUpTestCase() {
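+    // The fixture spawns a proactor thread, so prefer the threadsafe death-test style.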
+    testing::FLAGS_gtest_death_test_style = "threadsafe";
+  }
+
+  using IoResult = int;
+
+  // TODO: clean this up.
+  virtual void HandleRequest() {
+    tls_socket_ = std::make_unique<tls::TlsSocket>(conn_socket_.release());
+    tls_socket_->InitSSL(CreateSslCntx(SERVER));
+    tls_socket_->Accept();
+
+    uint8_t buf[16];
+    auto res = tls_socket_->Recv(buf);
+    EXPECT_TRUE(res.has_value());
+    EXPECT_TRUE(res.value() == 16);
+
+    auto write_res = tls_socket_->Write(buf);
+    EXPECT_FALSE(write_res);
+  }
+
+  unique_ptr<ProactorBase> proactor_;
+  thread proactor_thread_;
+  unique_ptr<FiberSocketBase> listen_socket_;
+  unique_ptr<FiberSocketBase> conn_socket_;
+  unique_ptr<tls::TlsSocket> tls_socket_;
+
+  uint16_t listen_port_ = 0;
+  Fiber accept_fb_;
+  Fiber conn_fb_;
+  std::error_code accept_ec_;
+  FiberSocketBase::endpoint_type listen_ep_;
+  uint32_t conn_sock_err_mask_ = 0;
+};
+
+INSTANTIATE_TEST_SUITE_P(Engines, TlsFiberSocketTest,
+                         testing::Values("epoll"
+#ifdef __linux__
+                                         ,
+                                         "uring"
+#endif
+                                         ),
+                         [](const auto& info) { return string(info.param); });
+
+void TlsFiberSocketTest::SetUp() {
+#if __linux__
+  bool use_uring = GetParam() == "uring";
+  ProactorBase* proactor = nullptr;
+  if (use_uring)
+    proactor = new UringProactor;
+  else
+    proactor = new EpollProactor;
+#else
+  ProactorBase* proactor = new EpollProactor;
+#endif
+
+  atomic_bool init_done{false};
+
+  proactor_thread_ = thread{[proactor, &init_done] {
+    InitProactor(proactor);
+    init_done.store(true, memory_order_release);
+    proactor->Run();
+  }};
+
+  proactor_.reset(proactor);
+
+  error_code ec = proactor_->AwaitBrief([&] {
+    listen_socket_.reset(proactor_->CreateSocket());
+    return listen_socket_->Listen(0, 0);
+  });
+
+  CHECK(!ec);
+  listen_port_ = listen_socket_->LocalEndpoint().port();
+  DCHECK_GT(listen_port_, 0);
+
+  auto address = boost::asio::ip::make_address("127.0.0.1");
+  listen_ep_ = FiberSocketBase::endpoint_type{address, listen_port_};
+
+  accept_fb_ = proactor_->LaunchFiber("AcceptFb", [this] {
+    auto accept_res = listen_socket_->Accept();
+    VLOG_IF(1, !accept_res) << "Accept res: " << accept_res.error();
+
+    if (accept_res) {
+      VLOG(1) << "Accepted connection " << *accept_res;
+      FiberSocketBase* sock = *accept_res;
+      conn_socket_.reset(sock);
+      conn_socket_->SetProactor(proactor_.get());
+      conn_socket_->RegisterOnErrorCb([this](uint32_t mask) {
+        LOG(INFO) << "Error mask: " << mask;
+        conn_sock_err_mask_ = mask;
+      });
+      conn_fb_ = proactor_->LaunchFiber([this]() { HandleRequest(); });
+    } else {
+      accept_ec_ = accept_res.error();
+    }
+  });
+}
+
+void TlsFiberSocketTest::TearDown() {
+  VLOG(1) << "TearDown";
+
+  proactor_->Await([&] {
+    std::ignore = listen_socket_->Shutdown(SHUT_RDWR);
+    if (conn_socket_) {
+      std::ignore = conn_socket_->Close();
+    } else {
+      std::ignore = tls_socket_->Close();
+    }
+  });
+
+  conn_fb_.JoinIfNeeded();
+  accept_fb_.JoinIfNeeded();
+
+  // We close here because we need to wake up the listening socket.
+  proactor_->Await([&] { std::ignore = listen_socket_->Close(); });
+
+  proactor_->Stop();
+  proactor_thread_.join();
+  proactor_.reset();
+}
+
+TEST_P(TlsFiberSocketTest, Basic) {
+  unique_ptr<tls::TlsSocket> tls_sock = std::make_unique<tls::TlsSocket>(proactor_->CreateSocket());
+  tls_sock->InitSSL(CreateSslCntx(CLIENT));
+
+  LOG(INFO) << "before wait ";
+  proactor_->Await([&] {
+    ThisFiber::SetName("ConnectFb");
+
+    LOG(INFO) << "Connecting to " << listen_ep_;
+    error_code ec = tls_sock->Connect(listen_ep_);
+    uint8_t buf[16] = {120};
+    VLOG(1) << "Before writesome";
+
+    Done done;
+    iovec v{.iov_base = &buf, .iov_len = 16};
+
+    tls_sock->AsyncWriteSome(&v, 1, [done](auto result) mutable {
+      EXPECT_TRUE(result.has_value());
+      EXPECT_EQ(*result, 16);
+      done.Notify();
+    });
+
+    done.Wait();
+
+    // TODO: with iouring this maxes out the memory and crashes.
+    // TODO: investigate why.
+    tls_sock->AsyncReadSome(&v, 1, [done](auto result) mutable {
+      EXPECT_TRUE(result.has_value());
+      EXPECT_EQ(*result, 16);
+      done.Notify();
+    });
+
+    done.Wait();
+
+    VLOG(1) << "closing client sock " << tls_sock->native_handle();
+    std::ignore = tls_sock->Close();
+    accept_fb_.Join();
+    VLOG(1) << "After join";
+    ASSERT_FALSE(ec) << ec.message();
+    ASSERT_FALSE(accept_ec_);
+  });
+}
+
+}  // namespace fb2
+}  // namespace util