Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[coro_rpc] add attachment/close/get_connection_id for coro_rpc::context #521

Merged
merged 2 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 53 additions & 3 deletions include/ylt/coro_rpc/impl/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#include <async_simple/coro/Lazy.h>

#include <any>
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
Expand Down Expand Up @@ -89,7 +89,8 @@ class context_base {
std::visit(
[&]<typename serialize_proto>(const serialize_proto &) {
self_->conn_->template response_msg<rpc_protocol>(
serialize_proto::serialize(), self_->req_head_,
serialize_proto::serialize(),
std::move(self_->resp_attachment_), self_->req_head_,
self_->is_delay_);
},
*rpc_protocol::get_serialize_protocol(self_->req_head_));
Expand All @@ -116,13 +117,17 @@ class context_base {
std::visit(
[&]<typename serialize_proto>(const serialize_proto &) {
self_->conn_->template response_msg<rpc_protocol>(
serialize_proto::serialize(ret), self_->req_head_,
serialize_proto::serialize(ret),
std::move(self_->resp_attachment_), self_->req_head_,
self_->is_delay_);
},
*rpc_protocol::get_serialize_protocol(self_->req_head_));

// response_handler_(std::move(conn_), std::move(ret));
}
self_->resp_attachment_ = [] {
return std::string_view{};
};
}

/*!
Expand All @@ -132,6 +137,51 @@ class context_base {
*/
bool has_closed() const { return self_->conn_->has_closed(); }

/*!
* Close connection
*/
void close() { return self_->conn_->async_close(); }

/*!
* Get the unique connection ID
* @return connection id
*/
uint64_t get_connection_id() { return self_->conn_->conn_id_; }

/*!
* Set the response_attachment
* @return a ref of response_attachment
*/
void set_response_attachment(std::string attachment) {
set_response_attachment([attachment = std::move(attachment)] {
return std::string_view{attachment};
});
}

/*!
* Set the response_attachment
* @return a ref of response_attachment
*/
void set_response_attachment(std::function<std::string_view()> attachment) {
self_->resp_attachment_ = std::move(attachment);
}

/*!
* Get the request attachment
* @return connection id
*/
std::string_view get_request_attachment() const {
return self_->req_attachment_;
}

/*!
* Release the attachment
* @return connection id
*/
std::string release_request_attachment() {
return std::move(self_->req_attachment_);
}

void set_delay() {
self_->is_delay_ = true;
self_->conn_->set_rpc_call_type(
Expand Down
109 changes: 62 additions & 47 deletions include/ylt/coro_rpc/impl/coro_connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <asio/buffer.hpp>
#include <atomic>
#include <cstdint>
#include <functional>
#include <future>
#include <memory>
#include <string_view>
Expand All @@ -37,12 +38,6 @@
#endif
namespace coro_rpc {

template <typename T>
concept apply_user_buf = requires() {
requires std::is_same_v<std::string_view,
typename std::remove_cvref_t<T>::buffer_type>;
};

class coro_connection;
using rpc_conn = std::shared_ptr<coro_connection>;

Expand All @@ -51,7 +46,11 @@ struct context_info_t {
constexpr static size_t body_default_size_ = 256;
std::shared_ptr<coro_connection> conn_;
typename rpc_protocol::req_header req_head_;
std::vector<char> body_;
std::string body_;
std::string req_attachment_;
std::function<std::string_view()> resp_attachment_ = [] {
return std::string_view{};
};
std::atomic<bool> has_response_ = false;
bool is_delay_ = false;
context_info_t(std::shared_ptr<coro_connection> &&conn)
Expand Down Expand Up @@ -158,6 +157,7 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {
while (true) {
auto &req_head = context_info->req_head_;
auto &body = context_info->body_;
auto &req_attachment = context_info->req_attachment_;
reset_timer();
auto ec = co_await rpc_protocol::read_head(socket, req_head);
cancel_timer();
Expand Down Expand Up @@ -195,16 +195,12 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {
break;
}

std::string_view payload{};
std::string_view payload;
// rpc_protocol::buffer_type maybe from user, default from framework.
constexpr bool apply_user_buf_v = apply_user_buf<rpc_protocol>;
if constexpr (apply_user_buf_v) {
ec = co_await rpc_protocol::read_payload(socket, req_head, payload);
}
else {
ec = co_await rpc_protocol::read_payload(socket, req_head, body);
payload = std::string_view{body.data(), body.size()};
}

ec = co_await rpc_protocol::read_payload(socket, req_head, body,
req_attachment);
payload = std::string_view{body};

if (ec)
AS_UNLIKELY {
Expand Down Expand Up @@ -256,7 +252,7 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {
if (resp_err != std::errc{})
AS_UNLIKELY { std::swap(resp_buf, resp_error_msg); }
std::string header_buf = rpc_protocol::prepare_response(
resp_buf, req_head, resp_err, resp_error_msg);
resp_buf, req_head, 0, resp_err, resp_error_msg);

#ifdef UNIT_TEST_INJECT
if (g_action == inject_action::close_socket_after_send_length) {
Expand All @@ -280,7 +276,10 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {
AS_LIKELY {
if (resp_err != std::errc{})
AS_UNLIKELY { resp_err_ = resp_err; }
write_queue_.emplace_back(std::move(header_buf), std::move(resp_buf));
write_queue_.emplace_back(std::move(header_buf), std::move(resp_buf),
[] {
return std::string_view{};
});
if (write_queue_.size() == 1) {
send_data().start([self = shared_from_this()](auto &&) {
});
Expand All @@ -299,11 +298,13 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {

template <typename rpc_protocol>
void response_msg(std::string &&body_buf,
std::function<std::string_view()> &&resp_attachment,
const typename rpc_protocol::req_header &req_head,
bool is_delay) {
std::string header_buf = rpc_protocol::prepare_response(body_buf, req_head);
response(std::move(header_buf), std::move(body_buf), shared_from_this(),
is_delay)
std::string header_buf = rpc_protocol::prepare_response(
body_buf, req_head, resp_attachment().size());
response(std::move(header_buf), std::move(body_buf),
std::move(resp_attachment), shared_from_this(), is_delay)
.via(executor_)
.detach();
}
Expand Down Expand Up @@ -351,9 +352,10 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {
auto &get_executor() { return *executor_; }

private:
async_simple::coro::Lazy<void> response(std::string header_buf,
std::string body_buf, rpc_conn self,
bool is_delay) noexcept {
async_simple::coro::Lazy<void> response(
std::string header_buf, std::string body_buf,
std::function<std::string_view()> resp_attachment, rpc_conn self,
bool is_delay) noexcept {
if (has_closed())
AS_UNLIKELY {
ELOGV(DEBUG, "response_msg failed: connection has been closed");
Expand All @@ -365,7 +367,8 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {
body_buf.clear();
}
#endif
write_queue_.emplace_back(std::move(header_buf), std::move(body_buf));
write_queue_.emplace_back(std::move(header_buf), std::move(body_buf),
std::move(resp_attachment));
if (is_delay) {
--delay_resp_cnt;
assert(delay_resp_cnt >= 0);
Expand Down Expand Up @@ -402,19 +405,38 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {
co_return;
}
#endif
std::array<asio::const_buffer, 2> buffers{asio::buffer(msg.first),
asio::buffer(msg.second)};
auto attachment = std::get<2>(msg)();
if (attachment.empty()) {
std::array<asio::const_buffer, 2> buffers{
asio::buffer(std::get<0>(msg)), asio::buffer(std::get<1>(msg))};
#ifdef YLT_ENABLE_SSL
if (use_ssl_) {
assert(ssl_stream_);
ret = co_await coro_io::async_write(*ssl_stream_, buffers);
}
else {
#endif
ret = co_await coro_io::async_write(socket_, buffers);
#ifdef YLT_ENABLE_SSL
if (use_ssl_) {
assert(ssl_stream_);
ret = co_await coro_io::async_write(*ssl_stream_, buffers);
}
#endif
}
else {
std::array<asio::const_buffer, 3> buffers{
asio::buffer(std::get<0>(msg)), asio::buffer(std::get<1>(msg)),
asio::buffer(attachment)};
#ifdef YLT_ENABLE_SSL
if (use_ssl_) {
assert(ssl_stream_);
ret = co_await coro_io::async_write(*ssl_stream_, buffers);
}
else {
#endif
ret = co_await coro_io::async_write(socket_, buffers);
ret = co_await coro_io::async_write(socket_, buffers);
#ifdef YLT_ENABLE_SSL
}
}
#endif
}
if (ret.first)
AS_UNLIKELY {
ELOGV(ERROR, "%s, %s", ret.first.message().data(),
Expand Down Expand Up @@ -449,11 +471,13 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {
if (has_closed_) {
return;
}
close_socket();
has_closed_ = true;
asio::error_code ignored_ec;
socket_.shutdown(asio::ip::tcp::socket::shutdown_both, ignored_ec);
socket_.close(ignored_ec);
if (quit_callback_) {
quit_callback_(conn_id_);
}
has_closed_ = true;
}

void reset_timer() {
Expand All @@ -472,22 +496,11 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {
ELOGV(INFO, "close timeout client conn_id %d", conn_id_);
#endif

close_socket();
close();
}
});
}

void close_socket() {
if (has_closed_) {
return;
}

asio::error_code ignored_ec;
socket_.shutdown(asio::ip::tcp::socket::shutdown_both, ignored_ec);
socket_.close(ignored_ec);
has_closed_ = true;
}

void cancel_timer() {
if (!enable_check_timeout_) {
return;
Expand All @@ -502,7 +515,9 @@ class coro_connection : public std::enable_shared_from_this<coro_connection> {
async_simple::Executor *executor_;
asio::ip::tcp::socket socket_;
// FIXME: queue's performance can be imporved.
std::deque<std::pair<std::string, std::string>> write_queue_;
std::deque<
std::tuple<std::string, std::string, std::function<std::string_view()>>>
write_queue_;
std::errc resp_err_;
rpc_call_type rpc_call_type_{non_callback};

Expand Down
Loading
Loading