Skip to content

Commit

Permalink
Improve internal implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
fantasy-peak committed Apr 17, 2024
1 parent 397e233 commit 065c6ff
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 155 deletions.
97 changes: 45 additions & 52 deletions out/bi_web/include/frpc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,23 +748,9 @@ class HelloWorldClient final {
m_cb.emplace(req_id, std::move(cb));
m_timeout_cb.emplace(req_id, std::move(timeout_cb));
}
m_channel->send(std::move(snd_bufs),
timeout,
[this, req_id]() mutable {
std::unique_lock lk(m_mtx);
#if __cplusplus >= 202302L
if (!m_cb.contains(req_id) || !m_timeout_cb.contains(req_id))
return;
#else
if (m_cb.find(req_id) == m_cb.end() || m_timeout_cb.find(req_id) == m_timeout_cb.end())
return;
#endif
auto cb = std::move(m_timeout_cb[req_id]);
m_timeout_cb.erase(req_id);
m_cb.erase(req_id);
lk.unlock();
cb();
});
m_channel->send(std::move(snd_bufs), timeout, [this, req_id] {
callTimeoutCallback(req_id);
});
}
#ifdef __cpp_impl_coroutine
template <asio::completion_token_for<void(std::string, Info, uint64_t, std::optional<std::string>)> CompletionToken>
Expand Down Expand Up @@ -817,28 +803,34 @@ class HelloWorldClient final {
}

private:
void callTimeoutCallback(uint64_t req_id) {
std::unique_lock lk(m_mtx);
if (m_timeout_cb.find(req_id) == m_timeout_cb.end())
return;
auto cb = std::move(m_timeout_cb[req_id]);
m_timeout_cb.erase(req_id);
m_cb.erase(req_id);
lk.unlock();
cb();
}

void dispatch(std::vector<zmq::message_t>& recv_bufs) {
if (recv_bufs.size() != 2) {
m_error(FRPC_ERROR_FORMAT("Illegal response packet"));
return;
}
try {
auto [req_id, req_type] = unpack<std::tuple<uint64_t, HelloWorldClientHelloWorldServer>>(recv_bufs[0].data(), recv_bufs[0].size());
std::unique_lock lk(m_mtx);
if (m_cb.find(req_id) == m_cb.end())
return;
auto cb = std::move(m_cb[req_id]);
m_cb.erase(req_id);
m_timeout_cb.erase(req_id);
lk.unlock();
switch (req_type) {
case HelloWorldClientHelloWorldServer::hello_world: {
auto [reply, info, count, date] = unpack<std::tuple<std::string, Info, uint64_t, std::optional<std::string>>>(recv_bufs[1].data(), recv_bufs[1].size());
std::unique_lock lk(m_mtx);
#if __cplusplus >= 202302L
if (!m_cb.contains(req_id))
break;
#else
if (m_cb.find(req_id) == m_cb.end())
break;
#endif
auto cb = std::move(m_cb[req_id]);
m_cb.erase(req_id);
m_timeout_cb.erase(req_id);
lk.unlock();
auto callback = std::any_cast<std::function<void(std::string, Info, uint64_t, std::optional<std::string>)>>(cb);
callback(std::move(reply), std::move(info), count, std::move(date));
break;
Expand Down Expand Up @@ -1457,29 +1449,30 @@ class StreamServer final {
recv_bufs[2] = zmq::message_t(is_close_buffer.data(), is_close_buffer.size());
auto ptr = std::make_shared<std::vector<zmq::message_t>>(std::move(recv_bufs));

auto out = std::make_shared<Stream<void(std::string)>>([ptr, this](std::string reply) mutable {
auto& recv_bufs = *ptr;
auto packet = pack<std::tuple<std::string>>(std::make_tuple(std::move(reply)));
auto close = pack<bool>(false);
std::vector<zmq::message_t> snd_bufs;
snd_bufs.emplace_back(zmq::message_t(recv_bufs[0].data(), recv_bufs[0].size()));
snd_bufs.emplace_back(zmq::message_t(recv_bufs[1].data(), recv_bufs[1].size()));
snd_bufs.emplace_back(zmq::message_t(packet.data(), packet.size()));
m_channel->send(snd_bufs);
},
[ptr, this, req_id, channel_ptr]() mutable {
auto& recv_bufs = *ptr;
std::vector<zmq::message_t> snd_bufs;
snd_bufs.emplace_back(zmq::message_t(recv_bufs[0].data(), recv_bufs[0].size()));
snd_bufs.emplace_back(zmq::message_t(recv_bufs[2].data(), recv_bufs[2].size()));
snd_bufs.emplace_back(zmq::message_t("C", 1));
m_channel->send(snd_bufs);
channel_ptr->close();
{
std::lock_guard lk(m_mtx);
m_channel_mapping.erase(req_id);
}
});
auto out = std::make_shared<Stream<void(std::string)>>(
[ptr, this](std::string reply) mutable {
auto& recv_bufs = *ptr;
auto packet = pack<std::tuple<std::string>>(std::make_tuple(std::move(reply)));
auto close = pack<bool>(false);
std::vector<zmq::message_t> snd_bufs;
snd_bufs.emplace_back(zmq::message_t(recv_bufs[0].data(), recv_bufs[0].size()));
snd_bufs.emplace_back(zmq::message_t(recv_bufs[1].data(), recv_bufs[1].size()));
snd_bufs.emplace_back(zmq::message_t(packet.data(), packet.size()));
m_channel->send(std::move(snd_bufs));
},
[ptr, this, req_id, channel_ptr]() mutable {
auto& recv_bufs = *ptr;
std::vector<zmq::message_t> snd_bufs;
snd_bufs.emplace_back(zmq::message_t(recv_bufs[0].data(), recv_bufs[0].size()));
snd_bufs.emplace_back(zmq::message_t(recv_bufs[2].data(), recv_bufs[2].size()));
snd_bufs.emplace_back(zmq::message_t("C", 1));
m_channel->send(std::move(snd_bufs));
channel_ptr->close();
{
std::lock_guard lk(m_mtx);
m_channel_mapping.erase(req_id);
}
});
std::visit([&](auto&& arg) mutable {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::shared_ptr<StreamServerHandler>>) {
Expand Down
97 changes: 45 additions & 52 deletions out/frpc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,23 +748,9 @@ class HelloWorldClient final {
m_cb.emplace(req_id, std::move(cb));
m_timeout_cb.emplace(req_id, std::move(timeout_cb));
}
m_channel->send(std::move(snd_bufs),
timeout,
[this, req_id]() mutable {
std::unique_lock lk(m_mtx);
#if __cplusplus >= 202302L
if (!m_cb.contains(req_id) || !m_timeout_cb.contains(req_id))
return;
#else
if (m_cb.find(req_id) == m_cb.end() || m_timeout_cb.find(req_id) == m_timeout_cb.end())
return;
#endif
auto cb = std::move(m_timeout_cb[req_id]);
m_timeout_cb.erase(req_id);
m_cb.erase(req_id);
lk.unlock();
cb();
});
m_channel->send(std::move(snd_bufs), timeout, [this, req_id] {
callTimeoutCallback(req_id);
});
}
#ifdef __cpp_impl_coroutine
template <asio::completion_token_for<void(std::string, Info, uint64_t, std::optional<std::string>)> CompletionToken>
Expand Down Expand Up @@ -817,28 +803,34 @@ class HelloWorldClient final {
}

private:
void callTimeoutCallback(uint64_t req_id) {
std::unique_lock lk(m_mtx);
if (m_timeout_cb.find(req_id) == m_timeout_cb.end())
return;
auto cb = std::move(m_timeout_cb[req_id]);
m_timeout_cb.erase(req_id);
m_cb.erase(req_id);
lk.unlock();
cb();
}

void dispatch(std::vector<zmq::message_t>& recv_bufs) {
if (recv_bufs.size() != 2) {
m_error(FRPC_ERROR_FORMAT("Illegal response packet"));
return;
}
try {
auto [req_id, req_type] = unpack<std::tuple<uint64_t, HelloWorldClientHelloWorldServer>>(recv_bufs[0].data(), recv_bufs[0].size());
std::unique_lock lk(m_mtx);
if (m_cb.find(req_id) == m_cb.end())
return;
auto cb = std::move(m_cb[req_id]);
m_cb.erase(req_id);
m_timeout_cb.erase(req_id);
lk.unlock();
switch (req_type) {
case HelloWorldClientHelloWorldServer::hello_world: {
auto [reply, info, count, date] = unpack<std::tuple<std::string, Info, uint64_t, std::optional<std::string>>>(recv_bufs[1].data(), recv_bufs[1].size());
std::unique_lock lk(m_mtx);
#if __cplusplus >= 202302L
if (!m_cb.contains(req_id))
break;
#else
if (m_cb.find(req_id) == m_cb.end())
break;
#endif
auto cb = std::move(m_cb[req_id]);
m_cb.erase(req_id);
m_timeout_cb.erase(req_id);
lk.unlock();
auto callback = std::any_cast<std::function<void(std::string, Info, uint64_t, std::optional<std::string>)>>(cb);
callback(std::move(reply), std::move(info), count, std::move(date));
break;
Expand Down Expand Up @@ -1457,29 +1449,30 @@ class StreamServer final {
recv_bufs[2] = zmq::message_t(is_close_buffer.data(), is_close_buffer.size());
auto ptr = std::make_shared<std::vector<zmq::message_t>>(std::move(recv_bufs));

auto out = std::make_shared<Stream<void(std::string)>>([ptr, this](std::string reply) mutable {
auto& recv_bufs = *ptr;
auto packet = pack<std::tuple<std::string>>(std::make_tuple(std::move(reply)));
auto close = pack<bool>(false);
std::vector<zmq::message_t> snd_bufs;
snd_bufs.emplace_back(zmq::message_t(recv_bufs[0].data(), recv_bufs[0].size()));
snd_bufs.emplace_back(zmq::message_t(recv_bufs[1].data(), recv_bufs[1].size()));
snd_bufs.emplace_back(zmq::message_t(packet.data(), packet.size()));
m_channel->send(snd_bufs);
},
[ptr, this, req_id, channel_ptr]() mutable {
auto& recv_bufs = *ptr;
std::vector<zmq::message_t> snd_bufs;
snd_bufs.emplace_back(zmq::message_t(recv_bufs[0].data(), recv_bufs[0].size()));
snd_bufs.emplace_back(zmq::message_t(recv_bufs[2].data(), recv_bufs[2].size()));
snd_bufs.emplace_back(zmq::message_t("C", 1));
m_channel->send(snd_bufs);
channel_ptr->close();
{
std::lock_guard lk(m_mtx);
m_channel_mapping.erase(req_id);
}
});
auto out = std::make_shared<Stream<void(std::string)>>(
[ptr, this](std::string reply) mutable {
auto& recv_bufs = *ptr;
auto packet = pack<std::tuple<std::string>>(std::make_tuple(std::move(reply)));
auto close = pack<bool>(false);
std::vector<zmq::message_t> snd_bufs;
snd_bufs.emplace_back(zmq::message_t(recv_bufs[0].data(), recv_bufs[0].size()));
snd_bufs.emplace_back(zmq::message_t(recv_bufs[1].data(), recv_bufs[1].size()));
snd_bufs.emplace_back(zmq::message_t(packet.data(), packet.size()));
m_channel->send(std::move(snd_bufs));
},
[ptr, this, req_id, channel_ptr]() mutable {
auto& recv_bufs = *ptr;
std::vector<zmq::message_t> snd_bufs;
snd_bufs.emplace_back(zmq::message_t(recv_bufs[0].data(), recv_bufs[0].size()));
snd_bufs.emplace_back(zmq::message_t(recv_bufs[2].data(), recv_bufs[2].size()));
snd_bufs.emplace_back(zmq::message_t("C", 1));
m_channel->send(std::move(snd_bufs));
channel_ptr->close();
{
std::lock_guard lk(m_mtx);
m_channel_mapping.erase(req_id);
}
});
std::visit([&](auto&& arg) mutable {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::shared_ptr<StreamServerHandler>>) {
Expand Down
48 changes: 19 additions & 29 deletions template/cpp/bi.inja
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,7 @@ public:
m_cb.emplace(req_id, std::move(cb));
m_timeout_cb.emplace(req_id, std::move(timeout_cb));
}
m_channel->send(std::move(snd_bufs),
timeout,
[this, req_id]() mutable {
std::unique_lock lk(m_mtx);
#if __cplusplus >= 202302L
if (!m_cb.contains(req_id) || !m_timeout_cb.contains(req_id))
return;
#else
if (m_cb.find(req_id) == m_cb.end() || m_timeout_cb.find(req_id) == m_timeout_cb.end())
return;
#endif
auto cb = std::move(m_timeout_cb[req_id]);
m_timeout_cb.erase(req_id);
m_cb.erase(req_id);
lk.unlock();
cb();
});
m_channel->send(std::move(snd_bufs), timeout, [this, req_id] { callTimeoutCallback(req_id); });
}
#ifdef __cpp_impl_coroutine
template <asio::completion_token_for<void({{_format_args_type(func.outputs)}})> CompletionToken>
Expand Down Expand Up @@ -140,29 +124,35 @@ public:
}

private:
void callTimeoutCallback(uint64_t req_id) {
std::unique_lock lk(m_mtx);
if (m_timeout_cb.find(req_id) == m_timeout_cb.end())
return;
auto cb = std::move(m_timeout_cb[req_id]);
m_timeout_cb.erase(req_id);
m_cb.erase(req_id);
lk.unlock();
cb();
}

void dispatch(std::vector<zmq::message_t>& recv_bufs) {
if (recv_bufs.size() != 2) {
m_error(FRPC_ERROR_FORMAT("Illegal response packet"));
return;
}
try {
auto [req_id, req_type] = unpack<std::tuple<uint64_t, {{value.caller}}{{value.callee}}>>(recv_bufs[0].data(), recv_bufs[0].size());
std::unique_lock lk(m_mtx);
if (m_cb.find(req_id) == m_cb.end())
return;
auto cb = std::move(m_cb[req_id]);
m_cb.erase(req_id);
m_timeout_cb.erase(req_id);
lk.unlock();
switch(req_type) {
{% for func in value.definitions %}
case {{value.caller}}{{value.callee}}::{{func.func_name}}: {
auto [{{_format_args_name(func.outputs)}}] = unpack<std::tuple<{{_format_args_type(func.outputs)}}>>(recv_bufs[1].data(), recv_bufs[1].size());
std::unique_lock lk(m_mtx);
#if __cplusplus >= 202302L
if (!m_cb.contains(req_id))
break;
#else
if (m_cb.find(req_id) == m_cb.end())
break;
#endif
auto cb = std::move(m_cb[req_id]);
m_cb.erase(req_id);
m_timeout_cb.erase(req_id);
lk.unlock();
auto callback = std::any_cast<std::function<void({{_format_args_type(func.outputs)}})>>(cb);
callback({{_format_args_name_and_move(func.outputs)}});
break;
Expand Down
Loading

0 comments on commit 065c6ff

Please sign in to comment.