diff --git a/include/cinatra/coro_http_client.hpp b/include/cinatra/coro_http_client.hpp index cd43ca03..7e4ae26a 100644 --- a/include/cinatra/coro_http_client.hpp +++ b/include/cinatra/coro_http_client.hpp @@ -23,6 +23,7 @@ #include "async_simple/coro/Lazy.h" #include "cinatra_log_wrapper.hpp" #include "http_parser.hpp" +#include "multipart.hpp" #include "picohttpparser.h" #include "response_cv.hpp" #include "string_resize.hpp" @@ -147,7 +148,7 @@ class coro_http_client : public std::enable_shared_from_this { : executor_wrapper_(executor), timer_(&executor_wrapper_), socket_(std::make_shared(executor)), - read_buf_(socket_->read_buf_), + head_buf_(socket_->head_buf_), chunked_buf_(socket_->chunked_buf_) {} coro_http_client( @@ -185,9 +186,9 @@ class coro_http_client : public std::enable_shared_from_this { return true; } - ~coro_http_client() { async_close(); } + ~coro_http_client() { close(); } - void async_close() { + void close() { if (socket_ == nullptr || socket_->has_closed_) return; @@ -505,7 +506,7 @@ class coro_http_client : public std::enable_shared_from_this { co_return data; } - std::tie(ec, size) = co_await async_read(read_buf_, total_len_); + std::tie(ec, size) = co_await async_read(head_buf_, total_len_); if (ec) { if (!stop_bench_) @@ -517,8 +518,8 @@ class coro_http_client : public std::enable_shared_from_this { } else { const char *data_ptr = - asio::buffer_cast(read_buf_.data()); - read_buf_.consume(total_len_); + asio::buffer_cast(head_buf_.data()); + head_buf_.consume(total_len_); // check status if (data_ptr[9] > '3') { data.status = 404; @@ -526,7 +527,7 @@ class coro_http_client : public std::enable_shared_from_this { } } - read_buf_.consume(total_len_); + head_buf_.consume(total_len_); data.status = 200; data.total = total_len_; @@ -1168,7 +1169,7 @@ class coro_http_client : public std::enable_shared_from_this { struct socket_t { asio::ip::tcp::socket impl_; std::atomic has_closed_ = true; - asio::streambuf read_buf_; + asio::streambuf head_buf_; asio::streambuf chunked_buf_; #ifdef CINATRA_ENABLE_SSL std::unique_ptr> ssl_stream_; @@ -1320,7 +1321,7 @@ class coro_http_client : public std::enable_shared_from_this { std::error_code handle_header(resp_data &data, http_parser &parser, size_t header_size) { // parse header - const char *data_ptr = asio::buffer_cast(read_buf_.data()); + const char *data_ptr = asio::buffer_cast(head_buf_.data()); int parse_ret = parser.parse_response(data_ptr, header_size, 0); #ifdef INJECT_FOR_HTTP_CLIENT_TEST @@ -1334,7 +1335,7 @@ class coro_http_client : public std::enable_shared_from_this { #endif return std::make_error_code(std::errc::protocol_error); } - read_buf_.consume(header_size); // header size + head_buf_.consume(header_size); // header size data.resp_headers = parser.get_headers(); data.status = parser.status(); return {}; @@ -1348,7 +1349,7 @@ class coro_http_client : public std::enable_shared_from_this { http_method method) { resp_data data{}; do { - if (std::tie(ec, size) = co_await async_read_until(read_buf_, TWO_CRCF); + if (std::tie(ec, size) = co_await async_read_until(head_buf_, TWO_CRCF); ec) { break; } @@ -1377,16 +1378,28 @@ class coro_http_client : public std::enable_shared_from_this { } if (parser_.is_chunked()) { is_keep_alive = true; - if (read_buf_.size() > 0) { + if (head_buf_.size() > 0) { const char *data_ptr = - asio::buffer_cast(read_buf_.data()); - chunked_buf_.sputn(data_ptr, read_buf_.size()); - read_buf_.consume(read_buf_.size()); + asio::buffer_cast(head_buf_.data()); + chunked_buf_.sputn(data_ptr, head_buf_.size()); + head_buf_.consume(head_buf_.size()); } ec = co_await handle_chunked(data, std::move(ctx)); break; } + if (parser_.is_multipart()) { + is_keep_alive = true; + if (head_buf_.size() > 0) { + const char *data_ptr = + asio::buffer_cast(head_buf_.data()); + chunked_buf_.sputn(data_ptr, head_buf_.size()); + head_buf_.consume(head_buf_.size()); + } + ec = co_await handle_multipart(data, std::move(ctx)); + break; + } + redirect_uri_.clear(); bool is_redirect = parser_.is_location(); if (is_redirect) @@ -1406,11 +1419,11 @@ class coro_http_client : public std::enable_shared_from_this { } } - if (content_len <= read_buf_.size()) { + if (content_len <= head_buf_.size()) { // Now get entire content, additional data will discard. // copy body. if (content_len > 0) { - auto data_ptr = asio::buffer_cast(read_buf_.data()); + auto data_ptr = asio::buffer_cast(head_buf_.data()); if (is_out_buf) { memcpy(out_buf_.data(), data_ptr, content_len); } @@ -1418,17 +1431,17 @@ class coro_http_client : public std::enable_shared_from_this { detail::resize(body_, content_len); memcpy(body_.data(), data_ptr, content_len); } - read_buf_.consume(read_buf_.size()); + head_buf_.consume(head_buf_.size()); } co_await handle_entire_content(data, content_len, is_ranges, ctx); break; } // read left part of content. - size_t part_size = read_buf_.size(); + size_t part_size = head_buf_.size(); size_t size_to_read = content_len - part_size; - auto data_ptr = asio::buffer_cast(read_buf_.data()); + auto data_ptr = asio::buffer_cast(head_buf_.data()); if (is_out_buf) { memcpy(out_buf_.data(), data_ptr, part_size); } @@ -1437,7 +1450,7 @@ class coro_http_client : public std::enable_shared_from_this { memcpy(body_.data(), data_ptr, part_size); } - read_buf_.consume(part_size); + head_buf_.consume(part_size); if (is_out_buf) { if (std::tie(ec, size) = co_await async_read( @@ -1474,7 +1487,7 @@ class coro_http_client : public std::enable_shared_from_this { auto &ctx) { if (content_len > 0) { const char *data_ptr; - if (read_buf_.size() == 0) { + if (head_buf_.size() == 0) { if (out_buf_.empty()) { data_ptr = body_.data(); } @@ -1483,7 +1496,7 @@ class coro_http_client : public std::enable_shared_from_this { } } else { - data_ptr = asio::buffer_cast(read_buf_.data()); + data_ptr = asio::buffer_cast(head_buf_.data()); } if (is_ranges) { @@ -1499,9 +1512,9 @@ class coro_http_client : public std::enable_shared_from_this { std::string_view reply(data_ptr, content_len); data.resp_body = reply; - read_buf_.consume(content_len); + head_buf_.consume(content_len); } - data.eof = (read_buf_.size() == 0); + data.eof = (head_buf_.size() == 0); } void handle_result(resp_data &data, std::error_code ec, bool is_keep_alive) { @@ -1522,6 +1535,39 @@ class coro_http_client : public std::enable_shared_from_this { } } + template + async_simple::coro::Lazy handle_multipart( + resp_data &data, req_context ctx) { + std::error_code ec{}; + std::string boundary = std::string{parser_.get_boundary()}; + multipart_reader_t multipart(this); + while (true) { + auto part_head = co_await multipart.read_part_head(); + if (part_head.ec) { + co_return part_head.ec; + } + + auto part_body = co_await multipart.read_part_body(boundary); + + if (ctx.stream) { + ec = co_await ctx.stream->async_write(part_body.data.data(), + part_body.data.size()); + } + else { + resp_chunk_str_.append(part_body.data.data(), part_body.data.size()); + } + + if (part_body.ec) { + co_return part_body.ec; + } + + if (part_body.eof) { + break; + } + } + co_return ec; + } + template async_simple::coro::Lazy handle_chunked( resp_data &data, req_context ctx) { @@ -1721,12 +1767,12 @@ class coro_http_client : public std::enable_shared_from_this { async_simple::coro::Lazy async_read_ws() { resp_data data{}; - read_buf_.consume(read_buf_.size()); + head_buf_.consume(head_buf_.size()); size_t header_size = 2; std::shared_ptr sock = socket_; auto on_ws_msg = std::move(on_ws_msg_); auto on_ws_close = std::move(on_ws_close_); - asio::streambuf &read_buf = sock->read_buf_; + asio::streambuf &read_buf = sock->head_buf_; bool has_init_ssl = false; #ifdef CINATRA_ENABLE_SSL has_init_ssl = has_init_ssl_; @@ -1927,11 +1973,12 @@ class coro_http_client : public std::enable_shared_from_this { return has_http_scheme; } + friend class multipart_reader_t; http_parser parser_; coro_io::ExecutorWrapper<> executor_wrapper_; coro_io::period_timer timer_; std::shared_ptr socket_; - asio::streambuf &read_buf_; + asio::streambuf &head_buf_; asio::streambuf &chunked_buf_; std::string body_; diff --git a/include/cinatra/coro_http_connection.hpp b/include/cinatra/coro_http_connection.hpp index 8f87ba89..688261a6 100644 --- a/include/cinatra/coro_http_connection.hpp +++ b/include/cinatra/coro_http_connection.hpp @@ -15,6 +15,7 @@ #include "coro_http_router.hpp" #include "define.h" #include "http_parser.hpp" +#include "multipart.hpp" #include "sha1.hpp" #include "string_resize.hpp" #include "websocket.hpp" @@ -22,18 +23,6 @@ #include "ylt/coro_io/coro_io.hpp" namespace cinatra { -struct chunked_result { - std::error_code ec; - bool eof = false; - std::string_view data; -}; - -struct part_head_t { - std::error_code ec; - std::string name; - std::string filename; -}; - struct websocket_result { std::error_code ec; ws_frame_type type; @@ -289,6 +278,52 @@ class coro_http_connection co_return co_await write_chunked("", true); } + async_simple::coro::Lazy begin_multipart( + std::string_view boundary = "", std::string_view content_type = "") { + response_.set_delay(true); + response_.set_status(status_type::ok); + if (boundary.empty()) { + boundary = BOUNDARY; + } + if (content_type.empty()) { + content_type = "multipart/form-data"; + } + + std::string str{content_type}; + str.append("; ").append("boundary=").append(boundary); + response_.add_header("Content-Type", str); + response_.set_boundary(boundary); + co_return co_await reply(); + } + + async_simple::coro::Lazy write_multipart( + std::string_view part_data, std::string_view content_type) { + response_.set_delay(true); + buffers_.clear(); + std::string part_head = "--"; + part_head.append(response_.get_boundary()).append(CRCF); + part_head.append("Content-Type: ").append(content_type).append(CRCF); + part_head.append("Content-Length: ") + .append(std::to_string(part_data.size())) + .append(TWO_CRCF); + + buffers_.push_back(asio::buffer(part_head)); + buffers_.push_back(asio::buffer(part_data)); + buffers_.push_back(asio::buffer(CRCF)); + + auto [ec, _] = co_await async_write(buffers_); + co_return !ec; + } + + async_simple::coro::Lazy end_multipart() { + response_.set_delay(true); + buffers_.clear(); + std::string multipart_end = "--"; + multipart_end.append(response_.get_boundary()).append("--").append(CRCF); + auto [ec, _] = co_await async_write(asio::buffer(multipart_end)); + co_return !ec; + } + async_simple::coro::Lazy read_chunked() { if (head_buf_.size() > 0) { const char *data_ptr = asio::buffer_cast(head_buf_.data()); @@ -347,99 +382,6 @@ class coro_http_connection co_return result; } - async_simple::coro::Lazy read_part_head() { - if (head_buf_.size() > 0) { - const char *data_ptr = asio::buffer_cast(head_buf_.data()); - chunked_buf_.sputn(data_ptr, head_buf_.size()); - head_buf_.consume(head_buf_.size()); - } - - part_head_t result{}; - std::error_code ec{}; - size_t last_size = chunked_buf_.size(); - size_t size; - - auto get_part_name = [](std::string_view data, std::string_view name, - size_t start) { - start += name.length(); - size_t end = data.find("\"", start); - return data.substr(start, end - start); - }; - - constexpr std::string_view name = "name=\""; - constexpr std::string_view filename = "filename=\""; - - while (true) { - if (std::tie(ec, size) = co_await async_read_until(chunked_buf_, CRCF); - ec) { - result.ec = ec; - close(); - co_return result; - } - - const char *data_ptr = - asio::buffer_cast(chunked_buf_.data()); - chunked_buf_.consume(size); - if (*data_ptr == '-') { - continue; - } - std::string_view data{data_ptr, size}; - if (size == 2) { // got the head end: \r\n\r\n - break; - } - - if (size_t pos = data.find("name"); pos != std::string_view::npos) { - result.name = get_part_name(data, name, pos); - - if (size_t pos = data.find("filename"); pos != std::string_view::npos) { - result.filename = get_part_name(data, filename, pos); - } - continue; - } - } - - co_return result; - } - - async_simple::coro::Lazy read_part_body( - std::string_view boundary) { - chunked_result result{}; - std::error_code ec{}; - size_t size = 0; - - if (std::tie(ec, size) = co_await async_read_until(chunked_buf_, boundary); - ec) { - result.ec = ec; - close(); - co_return result; - } - - const char *data_ptr = asio::buffer_cast(chunked_buf_.data()); - chunked_buf_.consume(size); - result.data = std::string_view{ - data_ptr, size - boundary.size() - 4}; //-- boundary \r\n - - if (std::tie(ec, size) = co_await async_read_until(chunked_buf_, CRCF); - ec) { - result = {}; - result.ec = ec; - close(); - co_return result; - } - - data_ptr = asio::buffer_cast(chunked_buf_.data()); - std::string data{data_ptr, size}; - if (size > 2) { - constexpr std::string_view complete_flag = "--\r\n"; - if (data == complete_flag) { - result.eof = true; - } - } - - chunked_buf_.consume(size); - co_return result; - } - async_simple::coro::Lazy write_websocket( std::string_view msg, opcode op = opcode::text) { auto header = ws_.format_header(msg.length(), op); @@ -683,6 +625,7 @@ class coro_http_connection } private: + friend class multipart_reader_t; async_simple::Executor *executor_; asio::ip::tcp::socket socket_; coro_http_router &router_; diff --git a/include/cinatra/coro_http_response.hpp b/include/cinatra/coro_http_response.hpp index eb720765..f3d0a1e9 100644 --- a/include/cinatra/coro_http_response.hpp +++ b/include/cinatra/coro_http_response.hpp @@ -57,6 +57,10 @@ class coro_http_response { void set_keepalive(bool r) { keepalive_ = r; } + void set_boundary(std::string_view boundary) { boundary_ = boundary; } + + std::string_view get_boundary() { return boundary_; } + void to_buffers(std::vector& buffers) { build_resp_head(); @@ -99,10 +103,18 @@ class coro_http_response { } void build_resp_head() { - if (std::find_if(resp_headers_.begin(), resp_headers_.end(), - [](resp_header& header) { - return header.key == "Host"; - }) == resp_headers_.end()) { + bool has_len = false; + bool has_host = false; + for (auto& [k, v] : resp_headers_) { + if (k == "Host") { + has_host = true; + } + if (k == "Content-Length") { + has_len = true; + } + } + + if (!has_host) { resp_headers_sv_.emplace_back(resp_header_sv{"Host", "cinatra"}); } @@ -122,7 +134,8 @@ class coro_http_response { std::string_view(buf_, std::distance(buf_, ptr))}); } else { - resp_headers_sv_.emplace_back(resp_header_sv{"Content-Length", "0"}); + if (!has_len && boundary_.empty()) + resp_headers_sv_.emplace_back(resp_header_sv{"Content-Length", "0"}); } } @@ -151,6 +164,7 @@ class coro_http_response { delay_ = false; status_ = status_type::init; fmt_type_ = format_type::normal; + boundary_.clear(); } void append_head(auto& headers) { @@ -173,5 +187,6 @@ class coro_http_response { std::vector resp_headers_; std::vector resp_headers_sv_; coro_http_connection* conn_; + std::string boundary_; }; } // namespace cinatra \ No newline at end of file diff --git a/include/cinatra/define.h b/include/cinatra/define.h index 011a3870..eadbf662 100644 --- a/include/cinatra/define.h +++ b/include/cinatra/define.h @@ -105,6 +105,18 @@ const static inline std::string TWO_CRCF = "\r\n\r\n"; const static inline std::string BOUNDARY = "--CinatraBoundary2B8FAF4A80EDB307"; const static inline std::string MULTIPART_END = CRCF + "--" + BOUNDARY + "--"; +struct chunked_result { + std::error_code ec; + bool eof = false; + std::string_view data; +}; + +struct part_head_t { + std::error_code ec; + std::string name; + std::string filename; +}; + inline std::unordered_map g_content_type_map = { {".css", "text/css"}, {".csv", "text/csv"}, diff --git a/include/cinatra/http_parser.hpp b/include/cinatra/http_parser.hpp index 55559a43..6c2fa898 100644 --- a/include/cinatra/http_parser.hpp +++ b/include/cinatra/http_parser.hpp @@ -126,6 +126,29 @@ class http_parser { return false; } + bool is_multipart() { + auto content_type = get_header_value("Content-Type"); + if (content_type.empty()) { + return false; + } + + if (content_type.find("multipart") == std::string_view::npos) { + return false; + } + + return true; + } + + std::string_view get_boundary() { + auto content_type = get_header_value("Content-Type"); + size_t pos = content_type.find("=--"); + if (pos == std::string_view::npos) { + return ""; + } + + return content_type.substr(pos + 1); + } + bool is_req_ranges() const { auto value = this->get_header_value("Range"sv); return !value.empty(); diff --git a/include/cinatra/multipart.hpp b/include/cinatra/multipart.hpp new file mode 100644 index 00000000..d40aaa71 --- /dev/null +++ b/include/cinatra/multipart.hpp @@ -0,0 +1,118 @@ +#pragma once +#include "define.h" + +namespace cinatra { + +template +class multipart_reader_t { + public: + multipart_reader_t(T *conn) + : conn_(conn), + head_buf_(conn_->head_buf_), + chunked_buf_(conn_->chunked_buf_) {} + + async_simple::coro::Lazy read_part_head() { + if (head_buf_.size() > 0) { + const char *data_ptr = asio::buffer_cast(head_buf_.data()); + chunked_buf_.sputn(data_ptr, head_buf_.size()); + head_buf_.consume(head_buf_.size()); + } + + part_head_t result{}; + std::error_code ec{}; + size_t last_size = chunked_buf_.size(); + size_t size; + + auto get_part_name = [](std::string_view data, std::string_view name, + size_t start) { + start += name.length(); + size_t end = data.find("\"", start); + return data.substr(start, end - start); + }; + + constexpr std::string_view name = "name=\""; + constexpr std::string_view filename = "filename=\""; + + while (true) { + if (std::tie(ec, size) = + co_await conn_->async_read_until(chunked_buf_, CRCF); + ec) { + result.ec = ec; + conn_->close(); + co_return result; + } + + const char *data_ptr = + asio::buffer_cast(chunked_buf_.data()); + chunked_buf_.consume(size); + if (*data_ptr == '-') { + continue; + } + std::string_view data{data_ptr, size}; + if (size == 2) { // got the head end: \r\n\r\n + break; + } + + if (size_t pos = data.find("name"); pos != std::string_view::npos) { + result.name = get_part_name(data, name, pos); + + if (size_t pos = data.find("filename"); pos != std::string_view::npos) { + result.filename = get_part_name(data, filename, pos); + } + continue; + } + } + + co_return result; + } + + async_simple::coro::Lazy read_part_body( + std::string_view boundary) { + chunked_result result{}; + std::error_code ec{}; + size_t size = 0; + + if (std::tie(ec, size) = + co_await conn_->async_read_until(chunked_buf_, boundary); + ec) { + result.ec = ec; + conn_->close(); + co_return result; + } + + const char *data_ptr = asio::buffer_cast(chunked_buf_.data()); + chunked_buf_.consume(size); + result.data = std::string_view{ + data_ptr, size - boundary.size() - 4}; //-- boundary \r\n + + if (std::tie(ec, size) = + co_await conn_->async_read_until(chunked_buf_, CRCF); + ec) { + result = {}; + result.ec = ec; + conn_->close(); + co_return result; + } + + data_ptr = asio::buffer_cast(chunked_buf_.data()); + std::string data{data_ptr, size}; + if (size > 2) { + constexpr std::string_view complete_flag = "--\r\n"; + if (data == complete_flag) { + result.eof = true; + } + } + + chunked_buf_.consume(size); + co_return result; + } + + private: + T *conn_; + asio::streambuf &head_buf_; + asio::streambuf &chunked_buf_; +}; + +template +multipart_reader_t(T *con) -> multipart_reader_t; +} // namespace cinatra \ No newline at end of file diff --git a/press_tool/main.cpp b/press_tool/main.cpp index 5adfb6b9..2dc7cfb7 100644 --- a/press_tool/main.cpp +++ b/press_tool/main.cpp @@ -245,7 +245,7 @@ int main(int argc, char* argv[]) { for (auto& counter : v) { for (auto& conn : counter.conns) { conn->set_bench_stop(); - conn->async_close(); + conn->close(); } } }); diff --git a/tests/test_cinatra.cpp b/tests/test_cinatra.cpp index 563ce38b..ecc78566 100644 --- a/tests/test_cinatra.cpp +++ b/tests/test_cinatra.cpp @@ -15,6 +15,7 @@ #include "cinatra/coro_http_client.hpp" #include "cinatra/coro_http_server.hpp" #include "cinatra/define.h" +#include "cinatra/multipart.hpp" #include "cinatra/string_resize.hpp" #include "cinatra/time_util.hpp" #include "doctest/doctest.h" @@ -447,7 +448,7 @@ TEST_CASE("test upload file") { "http//badurl.com", "test_not_exist_file", not_exist_file)); CHECK(result.status == 404); - client.async_close(); + client.close(); server.stop(); server_thread.join(); @@ -567,9 +568,9 @@ TEST_CASE("test coro_http_client multipart upload") { coro_http_response &resp) -> async_simple::coro::Lazy { assert(req.get_content_type() == content_type::multipart); auto boundary = req.get_boundary(); - + multipart_reader_t multipart(req.get_conn()); while (true) { - auto part_head = co_await req.get_conn()->read_part_head(); + auto part_head = co_await multipart.read_part_head(); if (part_head.ec) { co_return; } @@ -599,7 +600,7 @@ TEST_CASE("test coro_http_client multipart upload") { } } - auto part_body = co_await req.get_conn()->read_part_body(boundary); + auto part_body = co_await multipart.read_part_body(boundary); if (part_body.ec) { co_return; } diff --git a/tests/test_cinatra_websocket.cpp b/tests/test_cinatra_websocket.cpp index 6650206b..729117a9 100644 --- a/tests/test_cinatra_websocket.cpp +++ b/tests/test_cinatra_websocket.cpp @@ -59,7 +59,7 @@ TEST_CASE("test wss client") { promise.get_future().wait(); - client.async_close(); + client.close(); server.stop(); server_thread.join(); @@ -180,7 +180,7 @@ void test_websocket_content(size_t len) { server.stop(); server_thread.join(); - client.async_close(); + client.close(); } TEST_CASE("test websocket content lt 126") { diff --git a/tests/test_coro_http_server.cpp b/tests/test_coro_http_server.cpp index 943648b5..7a23ee72 100644 --- a/tests/test_coro_http_server.cpp +++ b/tests/test_coro_http_server.cpp @@ -141,6 +141,38 @@ bool create_file(std::string_view filename, size_t file_size = 1024) { return true; } +TEST_CASE("test multiple download") { + coro_http_server server(1, 9001); + server.set_http_handler( + "/", + [](coro_http_request &req, + coro_http_response &resp) -> async_simple::coro::Lazy { + multipart_reader_t multipart(resp.get_conn()); + bool ok; + if (ok = co_await resp.get_conn()->begin_multipart(); !ok) { + co_return; + } + + std::vector vec{"hello", " world", " ok"}; + + for (auto &str : vec) { + if (ok = co_await resp.get_conn()->write_multipart(str, "text/plain"); + !ok) { + co_return; + } + } + + ok = co_await resp.get_conn()->end_multipart(); + }); + + server.async_start(); + + coro_http_client client{}; + auto result = client.get("http://127.0.0.1:9001/"); + CHECK(result.status == 200); + CHECK(result.resp_body == "hello world ok"); +} + TEST_CASE("test range download") { create_file("range_test.txt", 64); std::cout << fs::current_path() << "\n";