Skip to content

Commit

Permalink
[coro_rpc] allow delay rpc call, context now share the ownership of b…
Browse files Browse the repository at this point in the history
…ody's buffer with connection (#244)

* [coro_rpc] allow delay response

* [coro_rpc] let callback own correct context
  • Loading branch information
poor-circle authored Mar 21, 2023
1 parent 8579a9c commit 884acd3
Show file tree
Hide file tree
Showing 12 changed files with 254 additions and 254 deletions.
14 changes: 9 additions & 5 deletions include/asio_util/asio_coro_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class callback_awaitor_base {
constexpr bool await_ready() const noexcept { return false; }
void await_suspend(std::coroutine_handle<> handle) noexcept {
awaitor.coro_ = handle;
op(awaitor_handler{awaitor});
op(awaitor_handler{&awaitor});
}
auto coAwait(async_simple::Executor *executor) const noexcept {
return *this;
Expand All @@ -104,7 +104,11 @@ class callback_awaitor_base {
public:
class awaitor_handler {
public:
awaitor_handler(Derived &obj) : obj(obj) {}
awaitor_handler(Derived *obj) : obj(obj) {}
awaitor_handler(awaitor_handler &&) = default;
awaitor_handler(const awaitor_handler &) = default;
awaitor_handler &operator=(const awaitor_handler &) = default;
awaitor_handler &operator=(awaitor_handler &&) = default;
template <typename... Args>
void set_value_then_resume(Args &&...args) const {
set_value(std::forward<Args>(args)...);
Expand All @@ -113,13 +117,13 @@ class callback_awaitor_base {
template <typename... Args>
void set_value(Args &&...args) const {
if constexpr (!std::is_void_v<Arg>) {
obj.arg_ = {std::forward<Args>(args)...};
obj->arg_ = {std::forward<Args>(args)...};
}
}
void resume() const { obj.coro_.resume(); }
void resume() const { obj->coro_.resume(); }

private:
Derived &obj;
Derived *obj;
};
template <typename Op>
callback_awaitor_impl<Op> await_resume(const Op &op) noexcept {
Expand Down
101 changes: 41 additions & 60 deletions include/coro_rpc/coro_rpc/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@
#include "util/type_traits.h"

namespace coro_rpc {
template <typename T, typename Conn>
concept has_get_reserve_size = requires(Conn &&conn) {
T::get_reserve_size(conn);
};
/*!
*
* @tparam return_msg_type
Expand All @@ -41,37 +37,23 @@ concept has_get_reserve_size = requires(Conn &&conn) {
template <typename return_msg_type, typename rpc_protocol>
class context_base {
protected:
std::shared_ptr<coro_connection> conn_;
std::unique_ptr<std::atomic<bool>> has_response_;
typename rpc_protocol::req_header req_head_;
std::shared_ptr<context_info_t<rpc_protocol>> self_;
typename rpc_protocol::req_header &get_req_head() { return self_->req_head_; }

public:
/*!
* Construct a context by a share pointer of context Concept
* instance
* @param a share pointer for coro_connection
*/
context_base(std::shared_ptr<coro_connection> &&conn,
typename rpc_protocol::req_header &&req_head)
: conn_(std::move(conn)),
has_response_(std::make_unique<std::atomic<bool>>(false)),
req_head_(std::move(req_head)) {
if (conn_) {
conn_->set_delay(true);
context_base(std::shared_ptr<context_info_t<rpc_protocol>> context_info)
: self_(std::move(context_info)) {
if (self_->conn_) {
self_->conn_->set_rpc_call_type(
coro_connection::rpc_call_type::callback_started);
}
};
context_base() = delete;
context_base(const context_base &conn) = delete;
context_base(context_base &&conn) = default;
~context_base() {
if (has_response_ && conn_ && !*has_response_)
AS_UNLIKELY {
ELOGV(ERROR,
"We must send reply to client by call response_msg method"
"before coro_rpc::context<T> destruction!");
conn_->async_close();
}
}
context_base() = default;

using return_type = return_msg_type;

Expand All @@ -91,11 +73,12 @@ class context_base {
if constexpr (std::is_same_v<return_msg_type, void>) {
static_assert(sizeof...(args) == 0, "illegal args");

auto old_flag = has_response_->exchange(true);
if (old_flag != false) {
ELOGV(ERROR, "response message more than one time");
return;
}
auto old_flag = self_->has_response_.exchange(true);
if (old_flag != false)
AS_UNLIKELY {
ELOGV(ERROR, "response message more than one time");
return;
}

if (has_closed())
AS_UNLIKELY {
Expand All @@ -104,47 +87,39 @@ class context_base {
}
std::visit(
[&]<typename serialize_proto>(const serialize_proto &) {
conn_->response_msg<rpc_protocol>(serialize_proto::serialize(),
req_head_);
self_->conn_->template response_msg<rpc_protocol>(
serialize_proto::serialize(), self_->req_head_,
self_->is_delay_);
},
*rpc_protocol::get_serialize_protocol(req_head_));
*rpc_protocol::get_serialize_protocol(self_->req_head_));
}
else {
static_assert(
requires { return_msg_type{std::forward<Args>(args)...}; },
"constructed return_msg_type failed by illegal args");
return_msg_type ret{std::forward<Args>(args)...};

auto old_flag = has_response_->exchange(true);
if (old_flag != false) {
ELOGV(ERROR, "response message more than one time");
return;
}
auto old_flag = self_->has_response_.exchange(true);
if (old_flag != false)
AS_UNLIKELY {
ELOGV(ERROR, "response message more than one time");
return;
}

if (has_closed())
AS_UNLIKELY {
ELOGV(DEBUG, "response_msg failed: connection has been closed");
return;
}

if constexpr (has_get_reserve_size<rpc_protocol, rpc_conn>) {
std::visit(
[&]<typename serialize_proto>(const serialize_proto &) {
conn_->response_msg<rpc_protocol>(
serialize_proto::serialize(
ret, rpc_protocol::get_reserve_size(conn_)),
req_head_);
},
*rpc_protocol::get_serialize_protocol(req_head_));
}
else {
std::visit(
[&]<typename serialize_proto>(const serialize_proto &) {
conn_->response_msg<rpc_protocol>(serialize_proto::serialize(ret),
req_head_);
},
*rpc_protocol::get_serialize_protocol(req_head_));
}
std::visit(
[&]<typename serialize_proto>(const serialize_proto &) {
self_->conn_->template response_msg<rpc_protocol>(
serialize_proto::serialize(ret), self_->req_head_,
self_->is_delay_);
},
*rpc_protocol::get_serialize_protocol(self_->req_head_));

// response_handler_(std::move(conn_), std::move(ret));
}
}
Expand All @@ -154,14 +129,20 @@ class context_base {
*
* @return true if closed, otherwise false
*/
bool has_closed() const { return conn_->has_closed(); }
bool has_closed() const { return self_->conn_->has_closed(); }

void set_delay() {
self_->is_delay_ = true;
self_->conn_->set_rpc_call_type(
coro_connection::rpc_call_type::callback_with_delay);
}

template <typename T>
void set_tag(T &&tag) {
conn_->set_tag(std::forward<T>(tag));
self_->conn_->set_tag(std::forward<T>(tag));
}

std::any get_tag() { return conn_->get_tag(); }
std::any get_tag() { return self_->conn_->get_tag(); }
};

template <typename T>
Expand Down
Loading

0 comments on commit 884acd3

Please sign in to comment.