Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[do not review] feat: TlsSocket AsyncWriteSome and AsyncReadSome #376

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions util/fibers/epoll_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ void EpollSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgress
async_write_pending_ = 1;
}

// TODO implement async functionality
void EpollSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

epoll AsyncRead is synchronous

auto res = ReadSome(v, len);
cb(res);
Expand Down
1 change: 1 addition & 0 deletions util/tls/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ add_library(tls_lib tls_engine.cc tls_socket.cc)

cxx_link(tls_lib fibers2 OpenSSL::SSL)
cxx_test(tls_engine_test tls_lib LABELS CI)
cxx_test(tls_socket_test tls_lib LABELS CI)
249 changes: 241 additions & 8 deletions util/tls/tls_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,23 +329,217 @@ io::Result<TlsSocket::PushResult> TlsSocket::PushToEngine(const iovec* ptr, uint
return res;
}

// TODO: to implement async functionality.
void TlsSocket::HandleOpAsync(int op_val) {
switch (op_val) {
case Engine::EOF_STREAM:
VLOG(1) << "EOF_STREAM received " << next_sock_->native_handle();
async_write_req_->caller_completion_cb(
make_unexpected(make_error_code(errc::connection_aborted)));
break;
case Engine::NEED_READ_AND_MAYBE_WRITE:
HandleUpstreamAsyncRead();
break;
case Engine::NEED_WRITE:
MaybeSendOutputAsync();
break;
default:
LOG(DFATAL) << "Unsupported " << op_val;
}
}

void TlsSocket::AsyncWriteReq::Run() {
if (state == AsyncWriteReq::PushToEngine) {
io::Result<PushResult> push_res = owner->PushToEngine(vec, len);
if (!push_res) {
caller_completion_cb(make_unexpected(push_res.error()));
return;
}
last_push = *push_res;
state = AsyncWriteReq::HandleOpAsync;
}

if (state == AsyncWriteReq::HandleOpAsync) {
state = AsyncWriteReq::MaybeSendOutputAsync;
if (last_push.engine_opcode < 0) {
owner->HandleOpAsync(last_push.engine_opcode);
}
}

if (state == AsyncWriteReq::MaybeSendOutputAsync) {
state = AsyncWriteReq::PushToEngine;
if (last_push.written > 0) {
DCHECK(!continuation);
continuation = [this]() {
state = AsyncWriteReq::Done;
caller_completion_cb(last_push.written);
};
owner->MaybeSendOutputAsync();
}
}
}

void TlsSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
io::Result<size_t> res = WriteSome(v, len);
cb(res);
CHECK(!async_write_req_.has_value());
async_write_req_.emplace(AsyncWriteReq(this, std::move(cb), v, len));
async_write_req_->Run();
}

void TlsSocket::AsyncReadReq::Run() {
DCHECK_GT(len, 0u);

while (true) {
DCHECK(!dest.empty());

size_t read_len = std::min(dest.size(), size_t(INT_MAX));

Engine::OpResult op_result = owner->engine_->Read(dest.data(), read_len);

int op_val = op_result;

DVLOG(2) << "Engine::Read " << dest.size() << " bytes, got " << op_val;

if (op_val > 0) {
read_total += op_val;

// I do not understand this code and what the hell I meant to do here. Seems to work
// though.
if (size_t(op_val) == read_len) {
if (size_t(op_val) < dest.size()) {
dest.remove_prefix(op_val);
} else {
++vec;
--len;
if (len == 0) {
// We are done. Call completion callback.
caller_completion_cb(read_total);
return;
}
dest = Engine::MutableBuffer{reinterpret_cast<uint8_t*>(vec->iov_base), vec->iov_len};
}
// We read everything we asked for but there are still buffers left to fill.
continue;
}
break;
}

// Will automatically call Run()
owner->HandleOpAsync(op_val);
}

// We are done. Call completion callback.
caller_completion_cb(read_total);

// clean up so we can queue more reads
}

// TODO: to implement async functionality.
void TlsSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
io::Result<size_t> res = ReadSome(v, len);
cb(res);
CHECK(!async_read_req_.has_value());
auto req = AsyncReadReq(this, std::move(cb), v, len);
req.dest = {reinterpret_cast<uint8_t*>(v->iov_base), v->iov_len};
async_read_req_.emplace(std::move(req));
async_read_req_->Run();
}

SSL* TlsSocket::ssl_handle() {
return engine_ ? engine_->native_handle() : nullptr;
}

void TlsSocket::HandleUpstreamAsyncWrite(io::Result<size_t> write_result, Engine::Buffer buffer) {
if (!write_result) {
state_ &= ~WRITE_IN_PROGRESS;

// broken_pipe - happens when the other side closes the connection. do not log this.
if (write_result.error() != errc::broken_pipe) {
VSOCK(1) << "HandleUpstreamWrite failed " << write_result.error();
}

// We are done. Errornous exit.
async_write_req_->caller_completion_cb(write_result);
return;
}

CHECK_GT(*write_result, 0u);
upstream_write_ += *write_result;
engine_->ConsumeOutputBuf(*write_result);
buffer.remove_prefix(*write_result);

// We are not done. Re-arm the async write until we drive it to completion or error.
if (!buffer.empty()) {
auto& scratch = async_write_req_->scratch_iovec;
scratch.iov_base = const_cast<uint8_t*>(buffer.data());
scratch.iov_len = buffer.size();
next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) {
HandleUpstreamAsyncWrite(write_result, buffer);
});
}

if (engine_->OutputPending() > 0) {
LOG(INFO) << "ssl buffer is not empty with " << engine_->OutputPending()
<< " bytes. short write detected";
}

state_ &= ~WRITE_IN_PROGRESS;

// If there is a continuation run it and let it yield back to the main loop
if (async_write_req_->continuation) {
auto cont = std::exchange(async_write_req_->continuation, std::function<void()>{});
cont();
return;
}

// Yield back to main loop
return async_write_req_->Run();
}

void TlsSocket::StartUpstreamWrite() {
Engine::Buffer buffer = engine_->PeekOutputBuf();
DCHECK(!buffer.empty());
DCHECK((state_ & WRITE_IN_PROGRESS) == 0);

if (buffer.empty()) {
// We are done
return;
}

DVLOG(2) << "HandleUpstreamWrite " << buffer.size();
// we do not allow concurrent writes from multiple fibers.
state_ |= WRITE_IN_PROGRESS;

auto& scratch = async_write_req_->scratch_iovec;
scratch.iov_base = const_cast<uint8_t*>(buffer.data());
scratch.iov_len = buffer.size();

next_sock_->AsyncWriteSome(&scratch, 1, [this, buffer](auto write_result) {
HandleUpstreamAsyncWrite(write_result, buffer);
});
}

void TlsSocket::MaybeSendOutputAsync() {
if (engine_->OutputPending() == 0) {
if (async_write_req_->continuation) {
auto cont = std::exchange(async_write_req_->continuation, std::function<void()>{});
cont();
return;
}
async_write_req_->Run();
}

// This function is present in both read and write paths.
// meaning that both of them can be called concurrently from differrent fibers and then
// race over flushing the output buffer. We use state_ to prevent that.
if (state_ & WRITE_IN_PROGRESS) {
if (async_write_req_->continuation) {
// TODO we must "yield" -> subscribe as a continuation to the write request cause otherwise
// we might deadlock. See the sync version of HandleOp for more info
auto cont = std::exchange(async_write_req_->continuation, std::function<void()>{});
cont();
return;
}
}

StartUpstreamWrite();
}

auto TlsSocket::MaybeSendOutput() -> error_code {
if (engine_->OutputPending() == 0)
return {};
Expand Down Expand Up @@ -373,6 +567,46 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
return HandleUpstreamWrite();
}

void TlsSocket::StartUpstreamRead() {
auto buffer = engine_->PeekInputBuf();
state_ |= READ_IN_PROGRESS;

auto& scratch = async_write_req_->scratch_iovec;
scratch.iov_base = const_cast<uint8_t*>(buffer.data());
scratch.iov_len = buffer.size();

next_sock_->AsyncReadSome(&scratch, 1, [this](auto read_result) {
state_ &= ~READ_IN_PROGRESS;
if (!read_result) {
// log any errors as well as situations where we have unflushed output.
if (read_result.error() != errc::connection_aborted || engine_->OutputPending() > 0) {
VSOCK(1) << "HandleUpstreamRead failed " << read_result.error();
}
// Erronous path. Apply the completion callback and exit.
async_write_req_->caller_completion_cb(read_result);
return;
}

DVLOG(1) << "HandleUpstreamRead " << *read_result << " bytes";
engine_->CommitInput(*read_result);
// We are not done. Give back control to the main loop.
async_write_req_->Run();
});
}

void TlsSocket::HandleUpstreamAsyncRead() {
auto on_success = [this]() {
if (state_ & READ_IN_PROGRESS) {
async_write_req_->Run();
}

StartUpstreamRead();
};

async_write_req_->continuation = on_success;
MaybeSendOutputAsync();
}

auto TlsSocket::HandleUpstreamRead() -> error_code {
RETURN_ON_ERROR(MaybeSendOutput());

Expand Down Expand Up @@ -481,8 +715,7 @@ unsigned TlsSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) {
}

void TlsSocket::ReturnProvided(const ProvidedBuffer& pbuf) {
proactor()->DeallocateBuffer(
io::MutableBytes{const_cast<uint8_t*>(pbuf.start), pbuf.allocated});
proactor()->DeallocateBuffer(io::MutableBytes{const_cast<uint8_t*>(pbuf.start), pbuf.allocated});
}

bool TlsSocket::IsUDS() const {
Expand Down
61 changes: 60 additions & 1 deletion util/tls/tls_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <openssl/ssl.h>

#include <memory>
#include <optional>

#include "util/fiber_socket_base.h"
#include "util/tls/tls_engine.h"
Expand Down Expand Up @@ -90,7 +91,6 @@ class TlsSocket final : public FiberSocketBase {
virtual void SetProactor(ProactorBase* p) override;

private:

struct PushResult {
size_t written = 0;
int engine_opcode = 0; // Engine::OpCode
Expand All @@ -114,6 +114,65 @@ class TlsSocket final : public FiberSocketBase {
std::unique_ptr<Engine> engine_;
size_t upstream_write_ = 0;

struct AsyncReqBase {
AsyncReqBase(TlsSocket* owner, io::AsyncProgressCb caller_cb, const iovec* vec, uint32_t len)
: owner(owner), caller_completion_cb(std::move(caller_cb)), vec(vec), len(len) {
}

TlsSocket* owner;
// Callback passed from the user.
io::AsyncProgressCb caller_completion_cb;

const iovec* vec;
uint32_t len;

std::function<void()> continuation;
};

struct AsyncWriteReq : AsyncReqBase {
using AsyncReqBase::AsyncReqBase;

iovec scratch_iovec;
// TODO simplify state transitions
// TODO handle async yields to avoid deadlocks (see HandleOp)
enum State { PushToEngine, HandleOpAsync, MaybeSendOutputAsync, Done };
State state = PushToEngine;
PushResult last_push;

// Main loop
void Run();
};

friend AsyncWriteReq;

struct AsyncReadReq : AsyncReqBase {
using AsyncReqBase::AsyncReqBase;

Engine::MutableBuffer dest;
size_t read_total = 0;

// Main loop
void Run();
};

friend AsyncReadReq;

// Asynchronous helpers
void MaybeSendOutputAsync();

void HandleUpstreamAsyncWrite(io::Result<size_t> write_result, Engine::Buffer buffer);
void HandleUpstreamAsyncRead();

void HandleOpAsync(int op_val);

void StartUpstreamWrite();
void StartUpstreamRead();

// TODO clean up the optional before we yield such that progress callback can dispatch another
// async operation
std::optional<AsyncWriteReq> async_write_req_;
std::optional<AsyncReadReq> async_read_req_;

enum { WRITE_IN_PROGRESS = 1, READ_IN_PROGRESS = 2, SHUTDOWN_IN_PROGRESS = 4, SHUTDOWN_DONE = 8 };
uint8_t state_{0};
};
Expand Down
Loading
Loading