Skip to content

Commit

Permalink
fix: wait ws coroutine quit (qicosmos#380)
Browse files Browse the repository at this point in the history
  • Loading branch information
qicosmos authored Jul 31, 2023
1 parent 29769e5 commit a4a2e16
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 38 deletions.
22 changes: 11 additions & 11 deletions include/cinatra/coro_http_client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class coro_http_client {

coro_http_client(asio::io_context::executor_type executor)
: socket_(std::make_shared<socket_t>(executor)),
read_buf_(socket_->read_buf_),
executor_wrapper_(executor),
timer_(&executor_wrapper_) {}

Expand Down Expand Up @@ -1021,6 +1022,7 @@ class coro_http_client {
struct socket_t {
asio::ip::tcp::socket impl_;
std::atomic<bool> has_closed_ = true;
asio::streambuf read_buf_;
template <typename ioc_t>
socket_t(ioc_t &&ioc) : impl_(std::forward<ioc_t>(ioc)) {}
};
Expand Down Expand Up @@ -1514,24 +1516,22 @@ class coro_http_client {

read_buf_.consume(read_buf_.size());
size_t header_size = 2;

std::shared_ptr sock = socket_;
auto on_ws_msg = std::move(on_ws_msg_);
websocket ws{};
while (true) {
std::weak_ptr socket = socket_;
if (auto [ec, _] = co_await async_read(read_buf_, header_size); ec) {
data.net_err = ec;
data.status = 404;
auto sock = socket.lock();
if (!sock) {

if (sock->has_closed_) {
co_return;
}
if (!sock->has_closed_) {
close_socket(*sock);
}

if (on_ws_msg_)
on_ws_msg_(data);
close_socket(*sock);

if (on_ws_msg)
on_ws_msg(data);
co_return;
}

Expand Down Expand Up @@ -1673,10 +1673,9 @@ class coro_http_client {
}

coro_io::ExecutorWrapper<> executor_wrapper_;
std::unique_ptr<asio::io_context::work> work_;
coro_io::period_timer timer_;
std::shared_ptr<socket_t> socket_;
asio::streambuf read_buf_;
asio::streambuf &read_buf_;
simple_buffer body_{};

std::unordered_map<std::string, std::string> req_headers_;
Expand Down Expand Up @@ -1714,6 +1713,7 @@ class coro_http_client {
std::chrono::steady_clock::duration req_timeout_duration_ =
std::chrono::seconds(60);
std::string resp_chunk_str_;

#ifdef BENCHMARK_TEST
std::string req_str_;
bool stop_bench_ = false;
Expand Down
33 changes: 6 additions & 27 deletions tests/test_cinatra_websocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,14 @@ TEST_CASE("test wss client") {
}
#endif

async_simple::coro::Lazy<void> test_websocket(coro_http_client &client,
std::promise<void> &promise) {
async_simple::coro::Lazy<void> test_websocket(coro_http_client &client) {
client.on_ws_close([](std::string_view reason) {
std::cout << "web socket close " << reason << std::endl;
CHECK(reason == "ws close");
});
client.on_ws_msg([&](resp_data data) {
if (data.net_err) {
std::cout << data.net_err.message() << "\n";
promise.set_value();
return;
}

Expand Down Expand Up @@ -137,11 +135,9 @@ TEST_CASE("test websocket") {
coro_http_client client;
client.set_ws_sec_key("s//GYHa/XO7Hd2F2eOGfyA==");

std::promise<void> promise;
async_simple::coro::syncAwait(test_websocket(client, promise));
async_simple::coro::syncAwait(test_websocket(client));

client.async_close();
promise.get_future().wait();

std::this_thread::sleep_for(std::chrono::milliseconds(300));

Expand Down Expand Up @@ -176,39 +172,27 @@ void test_websocket_content(size_t len) {
REQUIRE(async_simple::coro::syncAwait(
client.async_ws_connect("ws://localhost:8090")));

std::pair<std::promise<void>, bool> msg_pair_promise{};

std::string send_str(len, 'a');

std::promise<void> quit_promise{};

client.on_ws_msg([&, send_str](resp_data data) {
if (data.net_err) {
std::cout << "ws_msg net error " << data.net_err.message() << "\n";
quit_promise.set_value();
if (!msg_pair_promise.second) {
msg_pair_promise.first.set_value();
}

return;
}

std::cout << "ws msg len: " << data.resp_body.size() << std::endl;
REQUIRE(data.resp_body.size() == send_str.size());
CHECK(data.resp_body == send_str);
msg_pair_promise.first.set_value();
msg_pair_promise.second = true;
});

async_simple::coro::syncAwait(client.async_send_ws(send_str));
msg_pair_promise.first.get_future().wait();

std::this_thread::sleep_for(std::chrono::milliseconds(300));

server.stop();
server_thread.join();

client.async_close();

quit_promise.get_future().wait();
}

TEST_CASE("test websocket content lt 126") {
Expand Down Expand Up @@ -243,12 +227,8 @@ TEST_CASE("test send after server stop") {
REQUIRE(async_simple::coro::syncAwait(
client->async_ws_connect("ws://localhost:8090")));

std::promise<void> promise;
client->on_ws_msg([&client, &promise](resp_data data) {
if (data.net_err) {
client->async_close();
}
promise.set_value();
client->on_ws_msg([](resp_data data) {
std::cout << data.net_err.message() << "\n";
});

server.stop();
Expand All @@ -259,5 +239,4 @@ TEST_CASE("test send after server stop") {
CHECK(result.net_err);

server_thread.join();
promise.get_future().wait();
}

0 comments on commit a4a2e16

Please sign in to comment.