Skip to content

Commit

Permalink
fix: echo_server works with multishot
Browse files Browse the repository at this point in the history
Fix echo_server with newer api
  • Loading branch information
romange committed Jan 13, 2025
1 parent e84cb42 commit 9170c42
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 89 deletions.
206 changes: 127 additions & 79 deletions examples/echo_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
#include "util/varz.h"

#ifdef __linux__
#include "util/fibers/uring_socket.h"
#include "util/fibers/uring_proactor.h"
#include "util/fibers/uring_socket.h"
#endif

using namespace util;
Expand Down Expand Up @@ -56,7 +56,7 @@ ABSL_FLAG(uint32, max_clients, 1 << 16, "");
ABSL_FLAG(bool, raw, true,
"If true, does not send/receive size parameter during "
"the connection handshake");
ABSL_FLAG(bool, tcp_nodelay, false, "if true - use tcp_nodelay option for server sockets");
ABSL_FLAG(bool, tcp_nodelay, true, "use tcp_nodelay option for server sockets");
ABSL_FLAG(bool, multishot, false, "If true, iouring sockets use multishot receives");
ABSL_FLAG(uint16_t, bufring_size, 256, "Size of the buffer ring for iouring sockets");
ABSL_FLAG(bool, use_incoming_cpu, false,
Expand All @@ -66,42 +66,53 @@ VarzQps ping_qps("ping-qps");
VarzCount connections("connections");

const char kMaxConnectionsError[] = "max connections reached\r\n";
constexpr size_t kBufLen = 64;

class EchoConnection : public Connection {
public:
EchoConnection() {
}

private:
struct SendMsgState {
bool is_raw = false;
error_code ec;
size_t cur_sendmsg_len = 0;
vector<iovec> vec;
vector<FiberSocketBase::ProvidedBuffer> kept_buffers;
FiberSocketBase::ProvidedBuffer pbuf;
unsigned buf_len;
};

void HandleRequests() final;

std::error_code ReadMsg();
error_code ReadMsg();

std::queue<FiberSocketBase::ProvidedBuffer> prov_buffers_;
size_t pending_read_bytes_ = 0, first_buf_offset_ = 0;
// Returns true if we still need the buffer because it's referenced by iovec.
bool ProcessSingleBuffer(SendMsgState* state);
void ProcessFully(SendMsgState* state);

error_code Send(bool is_raw, const vector<iovec>& vec);

queue<FiberSocketBase::ProvidedBuffer> prov_buffers_;
uint64_t replies_ = 0;
size_t req_len_ = 0;
};

std::error_code EchoConnection::ReadMsg() {
FiberSocketBase::ProvidedBuffer pb[8];
VLOG(2) << "Waiting for socket read";
unsigned num_bufs = socket_->RecvProvided(8, pb);

while (pending_read_bytes_ < req_len_) {
unsigned num_bufs = socket_->RecvProvided(8, pb);

for (unsigned i = 0; i < num_bufs; ++i) {
if (pb[i].res_len > 0) {
prov_buffers_.push(pb[i]);
pending_read_bytes_ += pb[i].res_len;
} else {
DCHECK_EQ(i, 0u);
return error_code(-pb[i].res_len, system_category());
}
}
if (pending_read_bytes_ > req_len_) {
DVLOG(1) << "Waited for " << req_len_ << " but got " << pending_read_bytes_;
for (unsigned i = 0; i < num_bufs; ++i) {
if (pb[i].res_len > 0) {
prov_buffers_.push(pb[i]);
} else {
DCHECK_EQ(i, 0u);
return error_code(-pb[i].res_len, system_category());
}
}

VLOG(2) << "Received " << num_bufs << " buffers";
return {};
}

Expand All @@ -110,10 +121,7 @@ static thread_local base::Histogram send_hist;
void EchoConnection::HandleRequests() {
ThisFiber::SetName("HandleRequests");

std::error_code ec;
vector<iovec> vec;
uint8_t buf[8];
vec.resize(2);

int yes = 1;
if (GetFlag(FLAGS_tcp_nodelay)) {
Expand All @@ -129,16 +137,18 @@ void EchoConnection::HandleRequests() {
#endif
connections.IncBy(1);

vec[0].iov_base = buf;
vec[0].iov_len = 8;

auto ep = socket_->RemoteEndpoint();

VLOG(1) << "New connection from " << ep;

bool is_raw = GetFlag(FLAGS_raw);
SendMsgState state;
state.vec.resize(1);
state.vec[0].iov_base = buf;
state.vec[0].iov_len = 8;

if (is_raw) {
state.is_raw = GetFlag(FLAGS_raw);

if (state.is_raw) {
req_len_ = GetFlag(FLAGS_size);
} else {
VLOG(1) << "Waiting for size header from " << ep;
Expand All @@ -159,73 +169,111 @@ void EchoConnection::HandleRequests() {
}

CHECK_LE(req_len_, 1UL << 26);
vector<FiberSocketBase::ProvidedBuffer> returned_buffers;

// after the handshake.
uint64_t replies = 0;
while (true) {
ec = ReadMsg();
if (FiberSocketBase::IsConnClosed(ec)) {
VLOG(1) << "Closing " << ep << " after " << replies << " replies";
state.ec = ReadMsg();
if (FiberSocketBase::IsConnClosed(state.ec)) {
VLOG(1) << "Closing " << ep << " after " << replies_ << " replies";
break;
}
CHECK(!ec) << ec;
CHECK(!state.ec) << state.ec;
ping_qps.Inc();

vec[0].iov_base = buf;
vec[0].iov_len = 4;
state.vec[0].iov_base = buf;
state.vec[0].iov_len = 4;
absl::little_endian::Store32(buf, req_len_);
vec.resize(1);

size_t prepare_len = 0;
DCHECK(returned_buffers.empty());

while (prepare_len < req_len_) {
DCHECK(!prov_buffers_.empty());
size_t needed = req_len_ - prepare_len;
const auto& pbuf = prov_buffers_.front();
size_t bytes_count = pbuf.res_len - first_buf_offset_;
DCHECK_GT(!pbuf.res_len, 0);

if (bytes_count <= needed) {
vec.push_back({const_cast<uint8_t*>(pbuf.start) + first_buf_offset_, bytes_count});
prepare_len += bytes_count;
DCHECK_GE(pending_read_bytes_, bytes_count);
pending_read_bytes_ -= bytes_count;
returned_buffers.push_back(pbuf);
prov_buffers_.pop();
first_buf_offset_ = 0;
} else {
vec.push_back({const_cast<uint8_t*>(pbuf.start) + first_buf_offset_, needed});
first_buf_offset_ += needed;
prepare_len += needed;
DCHECK_GE(pending_read_bytes_, needed);
pending_read_bytes_ -= needed;
}
}

if (is_raw) {
auto prev = absl::GetCurrentTimeNanos();
// send(sock->native_handle(), work_buf_.get(), sz, 0);
ec = socket_->Write(vec.data() + 1, vec.size() - 1);
auto now = absl::GetCurrentTimeNanos();
send_hist.Add((now - prev) / 1000);
} else {
ec = socket_->Write(vec.data(), vec.size());
}
for (const auto& pb : returned_buffers) {
socket_->ReturnProvided(pb);
while (!prov_buffers_.empty()) {
state.pbuf = prov_buffers_.front();
state.buf_len = state.pbuf.res_len;
prov_buffers_.pop();
ProcessFully(&state);
if (state.ec)
break;
}
returned_buffers.clear();
++replies;
if (ec)
if (state.ec)
break;
}

VLOG(1) << "Connection ended " << ep;
connections.IncBy(-1);
}

// Feeds one slice (at most kBufLen bytes) of state->pbuf into the reply that
// is currently being assembled in state->vec.
//
// Every time the accumulated payload reaches req_len_ bytes the reply is
// flushed via Send(), the buffers parked in state->kept_buffers are handed
// back to the socket and the iovec list is trimmed back to its first entry
// (the size header slot, which Send() skips in raw mode).
//
// Returns true if the buffer is still needed because an iovec in state->vec
// references its memory, i.e. a tail of the slice is queued but not sent yet.
// Returns false if a send failed (state->ec is set) or if the slice has been
// flushed completely, so the caller may return the buffer right away.
bool EchoConnection::ProcessSingleBuffer(SendMsgState* state) {
  // Resolves the address of the payload: heap buffers carry the pointer
  // themselves, any other type is looked up in buffer ring group 0 by its id.
  auto GetBufferStart = [](const FiberSocketBase::ProvidedBuffer& pb) -> uint8_t* {
    if (pb.type == FiberSocketBase::kHeapType)
      return pb.start;
#ifdef __linux__
    return static_cast<fb2::UringProactor*>(ProactorBase::me())->GetBufRingPtr(0, pb.bid);
#endif
    return nullptr;
  };

  size_t len = std::min<unsigned>(state->buf_len, kBufLen);
  size_t buf_offset = 0;
  uint8_t* start = GetBufferStart(state->pbuf);

  // NOTE(review): if req_len_ is 0 this loop never terminates (needed stays 0
  // and len is never reduced); the code appears to assume a positive request
  // size - confirm with the callers.
  while (state->cur_sendmsg_len + len >= req_len_) {
    // Enough bytes to complete a reply: append exactly the missing part and flush.
    unsigned needed = req_len_ - state->cur_sendmsg_len;
    state->vec.push_back({start + buf_offset, needed});
    state->ec = Send(state->is_raw, state->vec);
    state->vec.resize(1);

    // The buffers kept from earlier calls are no longer referenced by the iovec list.
    for (auto& pb : state->kept_buffers) {
      VLOG(2) << "Return buffer id " << pb.bid << " " << pb.res_len;
      socket_->ReturnProvided(pb);
    }
    state->kept_buffers.clear();
    state->cur_sendmsg_len = 0;
    buf_offset += needed;
    len -= needed;
    if (state->ec)
      return false;
  }

  if (len == 0) {
    // The slice was consumed by the sends above and state->vec no longer points
    // into it, so the buffer does not have to be kept.
    return false;
  }

  // A partial reply remains: queue the tail and keep the buffer alive until
  // the reply is completed by a later call.
  state->vec.push_back({start + buf_offset, len});
  state->cur_sendmsg_len += len;
  return true;
}

// Echoes the provided buffer held in state->pbuf and decides what happens to
// it afterwards: it is either parked in state->kept_buffers or handed back to
// the socket. Stops early, without touching buffer ownership, when a send
// fails while a bundle is being processed (state->ec is set).
void EchoConnection::ProcessFully(SendMsgState* state) {
#ifdef __linux__
  auto* uring = static_cast<fb2::UringProactor*>(socket_->proactor());

  // More than kBufLen bytes is treated as a bundle: it is processed one
  // kBufLen-sized slice at a time, moving to the next buffer id of ring
  // group 0 after each slice.
  while (state->buf_len > kBufLen) {
    ProcessSingleBuffer(state);

    // The cursor is advanced before the error check, so the state always
    // reflects the slice that has just been handled.
    state->buf_len -= kBufLen;
    state->pbuf.bid = uring->GetNextBufRingBid(0, state->pbuf.bid);
    if (state->ec)
      return;
  }
#endif

  // Last (or only) slice: its result determines who owns the buffer now.
  const bool still_referenced = ProcessSingleBuffer(state);
  if (!still_referenced) {
    VLOG(1) << "Return buffer id " << state->pbuf.bid << " " << state->pbuf.res_len;
    socket_->ReturnProvided(state->pbuf);
    return;
  }
  state->kept_buffers.push_back(state->pbuf);
}

// Writes one reply to the peer and counts it in replies_.
//
// vec[0] is the size header slot. In raw mode it is skipped, only the payload
// entries are written and the duration of the write (in microseconds) is
// recorded in send_hist. Otherwise the whole vector is written as is.
// Returns the error code reported by the socket write.
error_code EchoConnection::Send(bool is_raw, const vector<iovec>& vec) {
  VLOG(1) << "Send response " << replies_;

  error_code result;
  if (!is_raw) {
    result = socket_->Write(vec.data(), vec.size());
  } else {
    const auto start_ns = absl::GetCurrentTimeNanos();
    result = socket_->Write(vec.data() + 1, vec.size() - 1);
    const auto end_ns = absl::GetCurrentTimeNanos();
    send_hist.Add((end_ns - start_ns) / 1000);
  }

  // The counter advances regardless of the outcome of the write.
  ++replies_;
  return result;
}

class EchoListener : public ListenerInterface {
public:
EchoListener() {
Expand Down Expand Up @@ -548,7 +596,7 @@ int main(int argc, char* argv[]) {
if (!absl::GetFlag(FLAGS_epoll)) {
pp->AwaitBrief([](unsigned, auto* pb) {
fb2::UringProactor* up = static_cast<fb2::UringProactor*>(pb);
up->RegisterBufferRing(0, absl::GetFlag(FLAGS_bufring_size), 64);
up->RegisterBufferRing(0, absl::GetFlag(FLAGS_bufring_size), kBufLen);
});
}
#endif
Expand Down
4 changes: 3 additions & 1 deletion util/fiber_socket_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class FiberSocketBase : public io::Sink,

virtual ::io::Result<size_t> Recv(const io::MutableBytes& mb, int flags = 0) = 0;

enum ProvidedType : uint8_t { kHeapType = 1, kBufRingType = 2};

struct ProvidedBuffer {
union {
uint8_t* start;
Expand All @@ -68,7 +70,7 @@ class FiberSocketBase : public io::Sink,

int res_len; // positive len, negative errno.
uint32_t allocated;
uint8_t cookie; // Used by the socket to identify the buffer source.
ProvidedType type; // Buffer type.

void SetError(uint16_t err) {
res_len = -int(err);
Expand Down
6 changes: 3 additions & 3 deletions util/fibers/epoll_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ unsigned EpollSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) {
ssize_t res = recv(fd, buf.data(), buf.size(), 0);
if (res > 0) { // if res is 0, that means a peer closed the socket.
size_t ures = res;
dest[0].cookie = 1;
dest[0].type = kHeapType;
dest[0].start = nullptr;

// Handle buffer shrinkage.
Expand Down Expand Up @@ -497,7 +497,7 @@ unsigned EpollSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) {
dest[num_bufs].start = buf.data();
dest[num_bufs].res_len = ures;
dest[num_bufs].allocated = buf.size();
dest[num_bufs].cookie = 1;
dest[num_bufs].type = kHeapType;
++num_bufs;
}

Expand Down Expand Up @@ -527,7 +527,7 @@ unsigned EpollSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) {
}

void EpollSocket::ReturnProvided(const ProvidedBuffer& pbuf) {
DCHECK_EQ(pbuf.cookie, 1);
DCHECK_EQ(pbuf.type, kHeapType);
DCHECK_GT(pbuf.res_len, 0);

proactor()->DeallocateBuffer(
Expand Down
9 changes: 3 additions & 6 deletions util/fibers/uring_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ auto Unexpected(std::errc e) {
return make_unexpected(make_error_code(e));
}

constexpr uint8_t kHeapType = 1;
constexpr uint8_t kBufRingType = 2;

} // namespace

bool UringSocket::MultiShot::DecRef() {
Expand Down Expand Up @@ -550,7 +547,7 @@ unsigned UringSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) {
pbuf.bid = result.bid;
pbuf.allocated = 0;
pbuf.res_len = result.res;
pbuf.cookie = kBufRingType;
pbuf.type = kBufRingType;
if (res == buf_len) {
return res;
}
Expand Down Expand Up @@ -578,7 +575,7 @@ unsigned UringSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) {
}
ssize_t res = fc.Get();

dest[0].cookie = kBufRingType;
dest[0].type = kBufRingType;
dest[0].allocated = 0;

if (res > 0) {
Expand Down Expand Up @@ -608,7 +605,7 @@ unsigned UringSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) {

void UringSocket::ReturnProvided(const ProvidedBuffer& pbuf) {
CHECK_GT(pbuf.res_len, 0);
CHECK_EQ(pbuf.cookie, kBufRingType); // kHeapType is not supported yet.
CHECK_EQ(pbuf.type, kBufRingType); // kHeapType is not supported yet.

Proactor* p = GetProactor();

Expand Down

0 comments on commit 9170c42

Please sign in to comment.