From 33886884a1e7827a3a3ec0f404b3e91ecf203798 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Mon, 13 Jan 2025 12:49:50 +0200 Subject: [PATCH] fix: echo_server works with multishot Fix echo_server with newer api --- examples/echo_server.cc | 206 ++++++++++++++++++++++-------------- util/fiber_socket_base.h | 4 +- util/fibers/epoll_socket.cc | 6 +- util/fibers/uring_socket.cc | 34 +++--- 4 files changed, 147 insertions(+), 103 deletions(-) diff --git a/examples/echo_server.cc b/examples/echo_server.cc index 893b84f5..aa6fe3c0 100644 --- a/examples/echo_server.cc +++ b/examples/echo_server.cc @@ -25,8 +25,8 @@ #include "util/varz.h" #ifdef __linux__ -#include "util/fibers/uring_socket.h" #include "util/fibers/uring_proactor.h" +#include "util/fibers/uring_socket.h" #endif using namespace util; @@ -56,7 +56,7 @@ ABSL_FLAG(uint32, max_clients, 1 << 16, ""); ABSL_FLAG(bool, raw, true, "If true, does not send/receive size parameter during " "the connection handshake"); -ABSL_FLAG(bool, tcp_nodelay, false, "if true - use tcp_nodelay option for server sockets"); +ABSL_FLAG(bool, tcp_nodelay, true, "use tcp_nodelay option for server sockets"); ABSL_FLAG(bool, multishot, false, "If true, iouring sockets use multishot receives"); ABSL_FLAG(uint16_t, bufring_size, 256, "Size of the buffer ring for iouring sockets"); ABSL_FLAG(bool, use_incoming_cpu, false, @@ -66,6 +66,7 @@ VarzQps ping_qps("ping-qps"); VarzCount connections("connections"); const char kMaxConnectionsError[] = "max connections reached\r\n"; +constexpr size_t kBufLen = 64; class EchoConnection : public Connection { public: @@ -73,35 +74,45 @@ class EchoConnection : public Connection { } private: + struct SendMsgState { + bool is_raw = false; + error_code ec; + size_t cur_sendmsg_len = 0; + vector vec; + vector kept_buffers; + FiberSocketBase::ProvidedBuffer pbuf; + unsigned buf_len; + }; + void HandleRequests() final; - std::error_code ReadMsg(); + error_code ReadMsg(); - std::queue prov_buffers_; - size_t pending_read_bytes_ = 0, first_buf_offset_ = 0; + // Returns true if we still need the buffer because it's referenced by iovec. + bool ProcessSingleBuffer(SendMsgState* state); + void ProcessFully(SendMsgState* state); + + error_code Send(bool is_raw, const vector& vec); + + queue prov_buffers_; + uint64_t replies_ = 0; size_t req_len_ = 0; }; std::error_code EchoConnection::ReadMsg() { FiberSocketBase::ProvidedBuffer pb[8]; + VLOG(2) << "Waiting for socket read"; + unsigned num_bufs = socket_->RecvProvided(8, pb); - while (pending_read_bytes_ < req_len_) { - unsigned num_bufs = socket_->RecvProvided(8, pb); - - for (unsigned i = 0; i < num_bufs; ++i) { - if (pb[i].res_len > 0) { - prov_buffers_.push(pb[i]); - pending_read_bytes_ += pb[i].res_len; - } else { - DCHECK_EQ(i, 0u); - return error_code(-pb[i].res_len, system_category()); - } - } - if (pending_read_bytes_ > req_len_) { - DVLOG(1) << "Waited for " << req_len_ << " but got " << pending_read_bytes_; + for (unsigned i = 0; i < num_bufs; ++i) { + if (pb[i].res_len > 0) { + prov_buffers_.push(pb[i]); + } else { + DCHECK_EQ(i, 0u); + return error_code(-pb[i].res_len, system_category()); } } - + VLOG(2) << "Received " << num_bufs << " buffers"; return {}; } @@ -110,10 +121,7 @@ static thread_local base::Histogram send_hist; void EchoConnection::HandleRequests() { ThisFiber::SetName("HandleRequests"); - std::error_code ec; - vector vec; uint8_t buf[8]; - vec.resize(2); int yes = 1; if (GetFlag(FLAGS_tcp_nodelay)) { @@ -129,16 +137,18 @@ void EchoConnection::HandleRequests() { #endif connections.IncBy(1); - vec[0].iov_base = buf; - vec[0].iov_len = 8; - auto ep = socket_->RemoteEndpoint(); VLOG(1) << "New connection from " << ep; - bool is_raw = GetFlag(FLAGS_raw); + SendMsgState state; + state.vec.resize(1); + state.vec[0].iov_base = buf; + state.vec[0].iov_len = 8; - if (is_raw) { + state.is_raw = GetFlag(FLAGS_raw); + + if (state.is_raw) { req_len_ = GetFlag(FLAGS_size); } else { VLOG(1) << "Waiting for size header from " << ep; @@ -159,66 +169,30 @@ void EchoConnection::HandleRequests() { } CHECK_LE(req_len_, 1UL << 26); - vector returned_buffers; // after the handshake. - uint64_t replies = 0; while (true) { - ec = ReadMsg(); - if (FiberSocketBase::IsConnClosed(ec)) { - VLOG(1) << "Closing " << ep << " after " << replies << " replies"; + state.ec = ReadMsg(); + if (FiberSocketBase::IsConnClosed(state.ec)) { + VLOG(1) << "Closing " << ep << " after " << replies_ << " replies"; break; } - CHECK(!ec) << ec; + CHECK(!state.ec) << state.ec; ping_qps.Inc(); - vec[0].iov_base = buf; - vec[0].iov_len = 4; + state.vec[0].iov_base = buf; + state.vec[0].iov_len = 4; absl::little_endian::Store32(buf, req_len_); - vec.resize(1); - - size_t prepare_len = 0; - DCHECK(returned_buffers.empty()); - - while (prepare_len < req_len_) { - DCHECK(!prov_buffers_.empty()); - size_t needed = req_len_ - prepare_len; - const auto& pbuf = prov_buffers_.front(); - size_t bytes_count = pbuf.res_len - first_buf_offset_; - DCHECK_GT(!pbuf.res_len, 0); - - if (bytes_count <= needed) { - vec.push_back({const_cast(pbuf.start) + first_buf_offset_, bytes_count}); - prepare_len += bytes_count; - DCHECK_GE(pending_read_bytes_, bytes_count); - pending_read_bytes_ -= bytes_count; - returned_buffers.push_back(pbuf); - prov_buffers_.pop(); - first_buf_offset_ = 0; - } else { - vec.push_back({const_cast(pbuf.start) + first_buf_offset_, needed}); - first_buf_offset_ += needed; - prepare_len += needed; - DCHECK_GE(pending_read_bytes_, needed); - pending_read_bytes_ -= needed; - } - } - if (is_raw) { - auto prev = absl::GetCurrentTimeNanos(); - // send(sock->native_handle(), work_buf_.get(), sz, 0); - ec = socket_->Write(vec.data() + 1, vec.size() - 1); - auto now = absl::GetCurrentTimeNanos(); - send_hist.Add((now - prev) / 1000); - } else { - ec = socket_->Write(vec.data(), vec.size()); - } - for (const auto& pb : returned_buffers) { - socket_->ReturnProvided(pb); + while (!prov_buffers_.empty()) { + state.pbuf = prov_buffers_.front(); + state.buf_len = state.pbuf.res_len; + prov_buffers_.pop(); + ProcessFully(&state); + if (state.ec) + break; } - returned_buffers.clear(); - ++replies; - if (ec) + if (state.ec) break; } @@ -226,6 +200,80 @@ void EchoConnection::HandleRequests() { connections.IncBy(-1); } +bool EchoConnection::ProcessSingleBuffer(SendMsgState* state) { + auto GetBufferStart = [](const FiberSocketBase::ProvidedBuffer& pb) -> uint8_t* { + if (pb.type == FiberSocketBase::kHeapType) + return pb.start; +#ifdef __linux__ + return static_cast(ProactorBase::me())->GetBufRingPtr(0, pb.bid); +#endif + return nullptr; + }; + + size_t len = std::min(state->buf_len, kBufLen); + size_t buf_offset = 0; + uint8_t* start = GetBufferStart(state->pbuf); + while (state->cur_sendmsg_len + len >= req_len_) { + unsigned needed = req_len_ - state->cur_sendmsg_len; + state->vec.push_back({start + buf_offset, needed}); + state->ec = Send(state->is_raw, state->vec); + state->vec.resize(1); + for (auto& pb : state->kept_buffers) { + VLOG(2) << "Return buffer id " << pb.bid << " " << pb.res_len; + socket_->ReturnProvided(pb); + } + state->kept_buffers.clear(); + state->cur_sendmsg_len = 0; + buf_offset += needed; + len -= needed; + if (state->ec) + return false; + } + + if (len) { // consume the whole buffer and can not send yet. + state->vec.push_back({start + buf_offset, len}); + state->cur_sendmsg_len += len; + return true; + } + return true; +} + +void EchoConnection::ProcessFully(SendMsgState* state) { +#ifdef __linux__ + fb2::UringProactor* up = static_cast(socket_->proactor()); + while (state->buf_len > kBufLen) { // process bundle + ProcessSingleBuffer(state); + state->buf_len -= kBufLen; + state->pbuf.bid = up->GetNextBufRingBid(0, state->pbuf.bid); + if (state->ec) + return; + } +#endif + if (ProcessSingleBuffer(state)) { + state->kept_buffers.push_back(state->pbuf); + } else { + VLOG(1) << "Return buffer id " << state->pbuf.bid << " " << state->pbuf.res_len; + socket_->ReturnProvided(state->pbuf); + } +} + +error_code EchoConnection::Send(bool is_raw, const vector& vec) { + VLOG(1) << "Send response " << replies_; + + error_code ec; + if (is_raw) { + auto prev = absl::GetCurrentTimeNanos(); + ec = socket_->Write(vec.data() + 1, vec.size() - 1); + auto now = absl::GetCurrentTimeNanos(); + send_hist.Add((now - prev) / 1000); + } else { + ec = socket_->Write(vec.data(), vec.size()); + } + + ++replies_; + return ec; +} + class EchoListener : public ListenerInterface { public: EchoListener() { @@ -548,7 +596,7 @@ int main(int argc, char* argv[]) { if (!absl::GetFlag(FLAGS_epoll)) { pp->AwaitBrief([](unsigned, auto* pb) { fb2::UringProactor* up = static_cast(pb); - up->RegisterBufferRing(0, absl::GetFlag(FLAGS_bufring_size), 64); + up->RegisterBufferRing(0, absl::GetFlag(FLAGS_bufring_size), kBufLen); }); } #endif diff --git a/util/fiber_socket_base.h b/util/fiber_socket_base.h index 3ae5e048..d5a84ce4 100644 --- a/util/fiber_socket_base.h +++ b/util/fiber_socket_base.h @@ -60,6 +60,8 @@ class FiberSocketBase : public io::Sink, virtual ::io::Result Recv(const io::MutableBytes& mb, int flags = 0) = 0; + enum ProvidedType : uint8_t { kHeapType = 1, kBufRingType = 2}; + struct ProvidedBuffer { union { uint8_t* start; @@ -68,7 +70,7 @@ class FiberSocketBase : public io::Sink, int res_len; // positive len, negative errno. uint32_t allocated; - uint8_t cookie; // Used by the socket to identify the buffer source. + ProvidedType type; // Buffer type. void SetError(uint16_t err) { res_len = -int(err); diff --git a/util/fibers/epoll_socket.cc b/util/fibers/epoll_socket.cc index 07aab2b0..e9536fc1 100644 --- a/util/fibers/epoll_socket.cc +++ b/util/fibers/epoll_socket.cc @@ -455,7 +455,7 @@ unsigned EpollSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) { ssize_t res = recv(fd, buf.data(), buf.size(), 0); if (res > 0) { // if res is 0, that means a peer closed the socket. size_t ures = res; - dest[0].cookie = 1; + dest[0].type = kHeapType; dest[0].start = nullptr; // Handle buffer shrinkage. @@ -497,7 +497,7 @@ unsigned EpollSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) { dest[num_bufs].start = buf.data(); dest[num_bufs].res_len = ures; dest[num_bufs].allocated = buf.size(); - dest[num_bufs].cookie = 1; + dest[num_bufs].type = kHeapType; ++num_bufs; } @@ -527,7 +527,7 @@ unsigned EpollSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) { } void EpollSocket::ReturnProvided(const ProvidedBuffer& pbuf) { - DCHECK_EQ(pbuf.cookie, 1); + DCHECK_EQ(pbuf.type, kHeapType); DCHECK_GT(pbuf.res_len, 0); proactor()->DeallocateBuffer( diff --git a/util/fibers/uring_socket.cc b/util/fibers/uring_socket.cc index 688d8d64..7663e75d 100644 --- a/util/fibers/uring_socket.cc +++ b/util/fibers/uring_socket.cc @@ -38,9 +38,6 @@ auto Unexpected(std::errc e) { return make_unexpected(make_error_code(e)); } -constexpr uint8_t kHeapType = 1; -constexpr uint8_t kBufRingType = 2; - } // namespace bool UringSocket::MultiShot::DecRef() { @@ -65,26 +62,23 @@ void UringSocket::MultiShot::Activate(int fd, uint16_t bufring_id, uint8_t flags DVLOG(2) << "Multishot completion " << res << " flags: " << flags; UringProactor* proactor = static_cast(ProactorBase::me()); - if ((flags & IORING_CQE_F_MORE) == 0) { - if (DecRef()) // Last reference. - return; - - // Assumption. - CHECK_EQ(flags & IORING_CQE_F_BUFFER, 0u); - CHECK_LE(res, 0); - err_no = -res; - error_raised = 1; - } else { - CHECK(flags & IORING_CQE_F_BUFFER); + if (flags & IORING_CQE_F_BUFFER) { CHECK_GT(res, 0); this->tail = proactor->EnqueueMultishotCompletion(bufring_id, res, flags, this->tail); if (this->head == UringProactor::kMultiShotUndef) this->head = this->tail; - - DVLOG(1) << "Multishot tail " << tail << " " << flags; - DCHECK_NE(tail, UringProactor::kMultiShotUndef); + } else { + CHECK_LE(res, 0); + DCHECK_EQ(0u, flags & IORING_CQE_F_MORE); + err_no = -res; + error_raised = 1; + } + + if ((flags & IORING_CQE_F_MORE) == 0) { + if (DecRef()) // Last reference. + return; } if (poll_pending) { @@ -550,7 +544,7 @@ unsigned UringSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) { pbuf.bid = result.bid; pbuf.allocated = 0; pbuf.res_len = result.res; - pbuf.cookie = kBufRingType; + pbuf.type = kBufRingType; if (res == buf_len) { return res; } @@ -578,7 +572,7 @@ unsigned UringSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) { } ssize_t res = fc.Get(); - dest[0].cookie = kBufRingType; + dest[0].type = kBufRingType; dest[0].allocated = 0; if (res > 0) { @@ -608,7 +602,7 @@ unsigned UringSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) { void UringSocket::ReturnProvided(const ProvidedBuffer& pbuf) { CHECK_GT(pbuf.res_len, 0); - CHECK_EQ(pbuf.cookie, kBufRingType); // kHeapType is not supported yet. + CHECK_EQ(pbuf.type, kBufRingType); // kHeapType is not supported yet. Proactor* p = GetProactor();