Skip to content

Commit

Permalink
feat: add websocket compression for coro_client
Browse files Browse the repository at this point in the history
  • Loading branch information
helintongh committed Apr 17, 2024
1 parent f2c5363 commit 8cc8cab
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 50 deletions.
109 changes: 108 additions & 1 deletion include/cinatra/coro_http_client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
#include "async_simple/Unit.h"
#include "async_simple/coro/FutureAwaiter.h"
#include "async_simple/coro/Lazy.h"
#ifdef CINATRA_ENABLE_GZIP
#include "gzip.hpp"
#endif
#include "cinatra_log_wrapper.hpp"
#include "http_parser.hpp"
#include "multipart.hpp"
Expand Down Expand Up @@ -319,14 +322,16 @@ class coro_http_client : public std::enable_shared_from_this<coro_http_client> {

void set_ws_sec_key(std::string sec_key) { ws_sec_key_ = std::move(sec_key); }

async_simple::coro::Lazy<bool> async_ws_connect(std::string uri) {
async_simple::coro::Lazy<bool> async_ws_connect(std::string uri, bool enable_ws_deflate = false) {
resp_data data{};
auto [r, u] = handle_uri(data, uri);
if (!r) {
CINATRA_LOG_WARNING << "url error:";
co_return false;
}

enable_ws_deflate_ = enable_ws_deflate;

req_context<> ctx{};
if (u.is_websocket()) {
// build websocket http header
Expand All @@ -337,10 +342,33 @@ class coro_http_client : public std::enable_shared_from_this<coro_http_client> {
}
add_header("Sec-WebSocket-Key", ws_sec_key_);
add_header("Sec-WebSocket-Version", "13");
#ifdef CINATRA_ENABLE_GZIP
add_header("Sec-WebSocket-Extensions", "permessage-deflate; client_max_window_bits");
#endif
}

data = co_await async_request(std::move(uri), http_method::GET,
std::move(ctx));
#ifdef CINATRA_ENABLE_GZIP
if (enable_ws_deflate_) {
for (auto c : data.resp_headers) {
std::cout << c.name << " value is: " << c.value << std::endl;
if (c.name == "Sec-WebSocket-Extensions") {
std::cout << "have extensions\n";
if (c.value.find("permessage-deflate;") != std::string::npos) {
std::cout << "support deflate extensions\n";
is_server_support_ws_deflate_ = true;
}
else {
std::cout << "not support deflate extensions\n";
is_server_support_ws_deflate_ = false;
}
break;
}
}
}
#endif

async_read_ws().start([](auto &&) {
});
co_return !data.net_err;
Expand Down Expand Up @@ -376,6 +404,30 @@ class coro_http_client : public std::enable_shared_from_this<coro_http_client> {
}

if constexpr (is_span_v<Source>) {
#ifdef CINATRA_ENABLE_GZIP
if (enable_ws_deflate_ && is_server_support_ws_deflate_) {
std::string dest_buf;
if (cinatra::gzip_codec::deflate(std::string(source.begin(), source.end()), dest_buf)) {
std::span<char> msg(dest_buf.data(), dest_buf.size());
auto header = ws.encode_frame(msg, op, need_mask, true, true);
std::vector<asio::const_buffer> buffers;
buffers.push_back(asio::buffer(header));
buffers.push_back(asio::buffer(dest_buf));

auto [ec, sz] = co_await async_write(buffers);
if (ec) {
data.net_err = ec;
data.status = 404;
}
}
else {
CINATRA_LOG_ERROR << "compuress data error, data: " << std::string(source.begin(), source.end());
data.net_err = std::make_error_code(std::errc::protocol_error);
data.status = 404;
}
}
else {
#endif
std::string encode_header = ws.encode_frame(source, op, need_mask);
std::vector<asio::const_buffer> buffers{
asio::buffer(encode_header.data(), encode_header.size()),
Expand All @@ -386,11 +438,40 @@ class coro_http_client : public std::enable_shared_from_this<coro_http_client> {
data.net_err = ec;
data.status = 404;
}
#ifdef CINATRA_ENABLE_GZIP
}
#endif
}
else {
while (true) {
auto result = co_await source();

#ifdef CINATRA_ENABLE_GZIP
if (enable_ws_deflate_ && is_server_support_ws_deflate_) {
std::string dest_buf;
if (cinatra::gzip_codec::deflate(std::string(result), dest_buf)) {
std::span<char> msg(dest_buf.data(), dest_buf.size());
std::string header =
ws.encode_frame(msg, op, need_mask, result.eof, true);
std::vector<asio::const_buffer> buffers;
buffers.push_back(asio::buffer(header));
buffers.push_back(asio::buffer(dest_buf));

auto [ec, sz] = co_await async_write(buffers);
if (ec) {
data.net_err = ec;
data.status = 404;
}
}
else {
CINATRA_LOG_ERROR << "compuress data error, data: " << std::string(source.begin(), source.end());
data.net_err = std::make_error_code(std::errc::protocol_error);
data.status = 404;
}
}
else {
#endif

std::span<char> msg(result.buf.data(), result.buf.size());
std::string encode_header =
ws.encode_frame(msg, op, need_mask, result.eof);
Expand All @@ -409,6 +490,9 @@ class coro_http_client : public std::enable_shared_from_this<coro_http_client> {
break;
}
}
#ifdef CINATRA_ENABLE_GZIP
}
#endif
}

co_return data;
Expand Down Expand Up @@ -1848,9 +1932,27 @@ class coro_http_client : public std::enable_shared_from_this<coro_http_client> {
data_ptr += sizeof(uint16_t);
}
}
#ifdef CINATRA_ENABLE_GZIP
if (!is_close_frame && is_server_support_ws_deflate_ && enable_ws_deflate_) {
std::string out;
if (!cinatra::gzip_codec::inflate(std::string(data_ptr), out))
{
CINATRA_LOG_ERROR << "uncompuress data error";
data.status = 404;
data.net_err = std::make_error_code(std::errc::protocol_error);
break;
}

data.status = 200;
data.resp_body = {out.data(), out.size()};
}
else {
#endif
data.status = 200;
data.resp_body = {data_ptr, payload_len};
#ifdef CINATRA_ENABLE_GZIP
}
#endif

read_buf.consume(read_buf.size());
header_size = 2;
Expand Down Expand Up @@ -2042,6 +2144,11 @@ class coro_http_client : public std::enable_shared_from_this<coro_http_client> {
std::string resp_chunk_str_;
std::span<char> out_buf_;

bool enable_ws_deflate_ = false;
#ifdef CINATRA_ENABLE_GZIP
bool is_server_support_ws_deflate_ = false;
#endif

#ifdef BENCHMARK_TEST
std::string req_str_;
bool stop_bench_ = false;
Expand Down
4 changes: 2 additions & 2 deletions include/cinatra/coro_http_connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,13 +564,13 @@ class coro_http_connection
std::string_view msg, opcode op = opcode::text) {
#ifdef CINATRA_ENABLE_GZIP
std::string dest_buf;
if (is_client_ws_compressed_ && data_length > 0) {
if (is_client_ws_compressed_ && msg.size() > 0) {
if (!cinatra::gzip_codec::deflate(std::string(msg), dest_buf)) {
CINATRA_LOG_ERROR << "compuress data error, data: " << msg;
co_return std::make_error_code(std::errc::protocol_error);
}

auto header = ws_.format_compressed_header(dest_buf.length(), op);
auto header = ws_.format_header(dest_buf.length(), op, true);
std::vector<asio::const_buffer> buffers;
buffers.push_back(asio::buffer(header));
buffers.push_back(asio::buffer(dest_buf));
Expand Down
58 changes: 11 additions & 47 deletions include/cinatra/websocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,61 +121,22 @@ class websocket {
return ws_frame_type::WS_BINARY_FRAME;
}

std::string format_header(size_t length, opcode code) {
size_t header_length = encode_header(length, code);
std::string format_header(size_t length, opcode code, bool is_compressed = false) {
size_t header_length = encode_header(length, code, is_compressed);
return {msg_header_, header_length};
}

std::string format_compressed_header(size_t data_length, opcode code) {

std::string destbuf;
char first_two_bytes[2] = { 0 };
//FIN
first_two_bytes[0] |= 0x80;

first_two_bytes[0] |= code;

const char compress_flag = 0x40;
first_two_bytes[0] |= compress_flag;

//mask = 0;
std::string send_data;

if (data_length < 126)
{
first_two_bytes[1] = data_length;
send_data.append(first_two_bytes, 2);
}
else if (data_length <= UINT16_MAX)
{
first_two_bytes[1] = 126;
char extended_playload_length[2] = { 0 };
uint16_t tmp = htons(data_length);
memcpy(&extended_playload_length, &tmp, 2);
send_data.append(first_two_bytes, 2);
send_data.append(extended_playload_length, 2);
}
else
{
first_two_bytes[1] = 127;
char extended_playload_length[8] = {0};
uint64_t tmp = htobe64((uint64_t)data_length);
memcpy(&extended_playload_length, &tmp, 8);
send_data.append(first_two_bytes, 2);
send_data.append(extended_playload_length, 8);
}

return send_data;
}

std::string encode_frame(std::span<char> &data, opcode op, bool need_mask,
bool eof = true) {
bool eof = true, bool need_compression = false) {
std::string header;
/// Base header.
frame_header hdr{};
hdr.fin = eof;
hdr.rsv1 = 0;
hdr.rsv2 = 0;
if (need_compression)
hdr.rsv2 = 1;
else
hdr.rsv2 = 0;
hdr.rsv3 = 0;
hdr.opcode = static_cast<uint8_t>(op);
hdr.mask = 1;
Expand Down Expand Up @@ -269,7 +230,7 @@ class websocket {
opcode get_opcode() { return (opcode)msg_opcode_; }

private:
size_t encode_header(size_t length, opcode code) {
size_t encode_header(size_t length, opcode code, bool is_compressed = false) {
size_t header_length;

if (length < 126) {
Expand All @@ -293,6 +254,9 @@ class websocket {
msg_header_[0] |= code;
}

if (is_compressed)
msg_header_[0] |= 0x40;

return header_length;
}

Expand Down

0 comments on commit 8cc8cab

Please sign in to comment.