Skip to content

Commit

Permalink
chore: Add provided buffers socket interface (#320)
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Gershman <[email protected]>
Signed-off-by: Roman Gershman <[email protected]>
  • Loading branch information
romange authored Oct 12, 2024
1 parent 9dd5659 commit 6765a7a
Show file tree
Hide file tree
Showing 19 changed files with 316 additions and 52 deletions.
2 changes: 1 addition & 1 deletion examples/echo_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ size_t Driver::Run(base::Histogram* dest) {
break;
}

socket_->Shutdown(SHUT_RDWR);
std::ignore = socket_->Shutdown(SHUT_RDWR);
dest->Merge(hist);

return i;
Expand Down
2 changes: 1 addition & 1 deletion util/accept_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class AcceptServerTest : public testing::Test {
void SetUp() override;

void TearDown() override {
client_sock_->proactor()->Await([&] { client_sock_->Close(); });
client_sock_->proactor()->Await([&] { std::ignore = client_sock_->Close(); });
as_->Stop(true);
watchdog_done_.Notify();
watchdog_fiber_.Join();
Expand Down
3 changes: 2 additions & 1 deletion util/cloud/gcp/gcp_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ class DynamicBodyRequestImpl : public HttpRequestBase {
DynamicBodyRequest req_;

public:
DynamicBodyRequestImpl(DynamicBodyRequestImpl&&) = default;
DynamicBodyRequestImpl(DynamicBodyRequestImpl&& other) : req_(std::move(other.req_)) {
}

explicit DynamicBodyRequestImpl(std::string_view url)
: req_(boost::beast::http::verb::post, boost::string_view{url.data(), url.size()}, 11) {
Expand Down
13 changes: 13 additions & 0 deletions util/fiber_socket_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ class FiberSocketBase : public io::Sink, public io::AsyncSink, public io::Source

virtual ::io::Result<size_t> Recv(const io::MutableBytes& mb, int flags = 0) = 0;

struct ProvidedBuffer {
io::Bytes buffer;
uint32_t allocated;
uint8_t cookie; // Used by the socket to identify the buffer.
};

// Unlike Recv/ReadSome, this method returns a buffer managed by the socket.
// They should be returned back to the socket after the data is read.
// small is an optional buffer that can be used for small messages.
virtual ::io::Result<unsigned> RecvProvided(unsigned buf_len, ProvidedBuffer* dest) = 0;

virtual void ReturnProvided(const ProvidedBuffer& pbuf) = 0;

static bool IsConnClosed(const error_code& ec) {
return (ec == std::errc::connection_aborted) || (ec == std::errc::connection_reset) ||
(ec == std::errc::broken_pipe);
Expand Down
104 changes: 104 additions & 0 deletions util/fibers/epoll_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,110 @@ auto EpollSocket::RecvMsg(const msghdr& msg, int flags) -> Result<size_t> {
return nonstd::make_unexpected(std::move(ec));
}

io::Result<unsigned> EpollSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) {
DCHECK_GT(buf_len, 0u);

int fd = native_handle();
read_context_ = detail::FiberActive();
absl::Cleanup clean = [this]() { read_context_ = nullptr; };

ssize_t res;
error_code ec;
while (true) {
if (fd_ & IS_SHUTDOWN) {
res = EPIPE;
break;
}

io::MutableBytes buf = proactor()->AllocateBuffer(bufreq_sz_);
res = recv(fd, buf.data(), buf.size(), 0);
if (res > 0) { // if res is 0, that means a peer closed the socket.
size_t ures = res;
dest[0].cookie = 1;

// Handle buffer shrinkage.
if (bufreq_sz_ > kMinBufSize && ures < bufreq_sz_ / 2) {
bufreq_sz_ = absl::bit_ceil(ures);
io::MutableBytes buf2 = proactor()->AllocateBuffer(ures);
DCHECK_GE(buf2.size(), ures);

memcpy(buf2.data(), buf.data(), ures);
proactor()->ReturnBuffer(buf);
dest[0].buffer = {buf2.data(), ures};
dest[0].allocated = buf2.size();
return 1;
}

dest[0].buffer = {buf.data(), ures};
dest[0].allocated = buf.size();

// Handle buffer expansion.
unsigned num_bufs = 1;
while (buf.size() == bufreq_sz_) {
if (bufreq_sz_ < kMaxBufSize) {
bufreq_sz_ *= 2;
}

if (num_bufs == buf_len)
break;

buf = proactor()->AllocateBuffer(bufreq_sz_);
res = recv(fd, buf.data(), buf.size(), 0);
if (res <= 0) {
proactor()->ReturnBuffer(buf);
break;
}
ures = res;
dest[num_bufs].buffer = {buf.data(), ures};
dest[num_bufs].allocated = buf.size();
dest[num_bufs].cookie = 1;
++num_bufs;
}

return num_bufs;
} // res > 0

proactor()->ReturnBuffer(buf);

if (res == 0 || errno != EAGAIN) {
break;
}

if (SuspendMyself(read_context_, &ec) && ec) {
return nonstd::make_unexpected(std::move(ec));
}
}

// Error handling - finale part.
if (res == -1) {
res = errno;
} else if (res == 0) {
res = ECONNABORTED;
}

DVSOCK(1) << "Got " << res;

// ETIMEDOUT can happen if a socket does not have keepalive enabled or for some reason
// TCP connection did indeed stopped getting tcp keep alive packets.
if (!base::_in(res, {ECONNABORTED, EPIPE, ECONNRESET, ETIMEDOUT})) {
LOG(ERROR) << "sock[" << fd << "] Unexpected error " << res << "/" << strerror(res) << " "
<< RemoteEndpoint();
}

ec = std::error_code(res, std::system_category());
VSOCK(1) << "Error on " << RemoteEndpoint() << ": " << ec.message();

return nonstd::make_unexpected(std::move(ec));
}

void EpollSocket::ReturnProvided(const ProvidedBuffer& pbuf) {
DCHECK_EQ(pbuf.cookie, 1);
DCHECK(!pbuf.buffer.empty());

proactor()->ReturnBuffer(
io::MutableBytes{const_cast<uint8_t*>(pbuf.buffer.data()), pbuf.allocated});
}

io::Result<size_t> EpollSocket::Recv(const io::MutableBytes& mb, int flags) {
msghdr msg;
memset(&msg, 0, sizeof(msg));
Expand Down
7 changes: 7 additions & 0 deletions util/fibers/epoll_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class EpollSocket : public LinuxSocketBase {

error_code Shutdown(int how) override;

io::Result<unsigned> RecvProvided(unsigned buf_len, ProvidedBuffer* dest) final;
void ReturnProvided(const ProvidedBuffer& pbuf) final;

void RegisterOnErrorCb(std::function<void(uint32_t)> cb) final;
void CancelOnErrorCb() final;

Expand All @@ -57,6 +60,10 @@ class EpollSocket : public LinuxSocketBase {
uint16_t epoll_mask_ = 0;
uint16_t kev_error_ = 0;

static constexpr uint32_t kMaxBufSize = 1 << 16;
static constexpr uint32_t kMinBufSize = 1 << 4;
uint32_t bufreq_sz_ = kMinBufSize;

std::function<void(uint32_t)> error_cb_;
};

Expand Down
66 changes: 62 additions & 4 deletions util/fibers/fiber_socket_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ INSTANTIATE_TEST_SUITE_P(Engines, FiberSocketTest,
,
"uring"
#endif
));
),
[](const auto& info) { return string(info.param); });

void FiberSocketTest::SetUp() {
#if __linux__
Expand Down Expand Up @@ -231,9 +232,7 @@ TEST_P(FiberSocketTest, Poll) {
accept_fb_.Join();

LOG(INFO) << "Before close";
proactor_->Await([&] {
std::ignore = sock->Close();
});
proactor_->Await([&] { std::ignore = sock->Close(); });
usleep(1000);

// POLLRDHUP is linux specific
Expand Down Expand Up @@ -347,6 +346,65 @@ TEST_P(FiberSocketTest, UDS) {
LOG(INFO) << "Finished";
}

TEST_P(FiberSocketTest, RecvProvided) {
constexpr unsigned kBufLen = 40;
#ifdef __linux__
bool use_uring = GetParam() == "uring";

UringProactor* up = static_cast<UringProactor*>(proactor_.get());
if (use_uring) {
up->Await([up] { UringSocket::InitProvidedBuffers(4, kBufLen, up); });
}
#endif

unique_ptr<FiberSocketBase> sock;
error_code ec;
proactor_->Await([&] {
sock.reset(proactor_->CreateSocket());
ec = sock->Connect(listen_ep_);
});
ASSERT_FALSE(ec);

io::Result<unsigned> res;
FiberSocketBase::ProvidedBuffer pbuf[8];

auto recv_fb = proactor_->LaunchFiber([&] {
res = conn_socket_->RecvProvided(8, pbuf);
#ifdef __linux__
if (use_uring) {
bool has_more = static_cast<UringSocket*>(conn_socket_.get())->HasRecvData();
EXPECT_TRUE(has_more);
}
#endif
});

uint8_t buf[128];
memset(buf, 'x', sizeof(buf));

proactor_->Await([&] {
auto wrt_ec = sock->Write(io::Bytes(buf));
ASSERT_FALSE(wrt_ec);
});

recv_fb.Join();
proactor_->Await([&] { std::ignore = sock->Close(); });
ASSERT_TRUE(res);

ASSERT_TRUE(*res > 0 && *res < 8);
size_t total_size = 0;
for (unsigned i = 0; i < *res; ++i) {
total_size += pbuf[i].buffer.size();
}

ASSERT_LE(total_size, sizeof(buf));

proactor_->Await([&] {
for (unsigned i = 0; i < *res; ++i) {
conn_socket_->ReturnProvided(pbuf[i]);
}
});
}

#ifdef __linux__
TEST_P(FiberSocketTest, NotEmpty) {
bool use_uring = GetParam() == "uring";
Expand Down
2 changes: 1 addition & 1 deletion util/fibers/fibers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ TEST_P(ProactorTest, NotifyRemote2) {

for (unsigned i = 0; i < kNumThreads; ++i) {
for (unsigned j = 0; j < 20; ++j) {
fbs.push_back(ths[i]->proactor->LaunchFiber(StrCat("test", i, "/", j), [i, j, &ths] {
fbs.push_back(ths[i]->proactor->LaunchFiber(StrCat("test", i, "/", j), [i, &ths] {
for (unsigned iter = 0; iter < 1000; ++iter) {
unsigned idx = (i + iter) % kNumThreads;

Expand Down
9 changes: 9 additions & 0 deletions util/fibers/proactor_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ bool ProactorBase::InMyThread() const {
return pthread_self() == thread_id_;
}

io::MutableBytes ProactorBase::AllocateBuffer(size_t hint_sz) {
uint8_t* res = new uint8_t[hint_sz];
return io::MutableBytes{res, hint_sz};
}

void ProactorBase::ReturnBuffer(io::MutableBytes buf) {
operator delete[](buf.data());
}

uint32_t ProactorBase::AddOnIdleTask(OnIdleTask f) {
DCHECK(InMyThread());

Expand Down
4 changes: 4 additions & 0 deletions util/fibers/proactor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ class ProactorBase {
return fb;
}

// Returns a buffer of size at least min_size.
io::MutableBytes AllocateBuffer(size_t min_size);
void ReturnBuffer(io::MutableBytes buf);

using OnIdleTask = std::function<uint32_t()>;
using PeriodicTask = std::function<void()>;

Expand Down
2 changes: 1 addition & 1 deletion util/fibers/uring_file_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class UringFileTest : public testing::Test {

TEST_F(UringFileTest, Basic) {
string path = base::GetTestTempPath("1.log");
proactor_->Await([this, path] {
proactor_->Await([path] {
auto res = OpenLinux(path, O_RDWR | O_CREAT | O_TRUNC, 0666);
ASSERT_TRUE(res);
LinuxFile* wf = (*res).get();
Expand Down
Loading

0 comments on commit 6765a7a

Please sign in to comment.