From b0c2639dbb8ea49b651c96ee18c04f5e3b047844 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anarthal=20=28Rub=C3=A9n=20P=C3=A9rez=29?= <34971811+anarthal@users.noreply.github.com> Date: Fri, 26 Jul 2024 11:46:32 +0200 Subject: [PATCH] Writes exceeding max buffer size no longer resize the buffer before failing close #297 --- .../protocol/impl/binary_protocol.hpp | 10 +- .../internal/protocol/impl/protocol_types.hpp | 54 +-- .../protocol/impl/serialization_context.hpp | 143 ++++---- .../impl/internal/protocol/serialization.hpp | 63 +++- .../impl/internal/sansio/close_statement.hpp | 13 +- .../internal/sansio/connection_state_data.hpp | 5 +- .../boost/mysql/impl/internal/sansio/ping.hpp | 15 +- .../impl/internal/sansio/reset_connection.hpp | 6 +- .../impl/internal/sansio/top_level_algo.hpp | 11 - include/boost/mysql/impl/pipeline.ipp | 12 +- test/unit/include/test_unit/algo_test.hpp | 33 +- .../include/test_unit/create_coldef_frame.hpp | 36 +- test/unit/include/test_unit/create_err.hpp | 22 +- .../include/test_unit/create_ok_frame.hpp | 31 +- .../create_prepare_statement_response.hpp | 23 +- .../include/test_unit/create_query_frame.hpp | 10 +- .../include/test_unit/create_row_message.hpp | 34 +- .../include/test_unit/serialize_to_vector.hpp | 37 ++ test/unit/test/protocol/binary_protocol.cpp | 38 -- test/unit/test/protocol/serialization.cpp | 34 +- .../test/protocol/serialization_context.cpp | 341 ++++++++++++------ .../unit/test/protocol/serialization_test.hpp | 15 +- test/unit/test/sansio/prepare_statement.cpp | 36 +- test/unit/test/sansio/set_character_set.cpp | 13 + test/unit/test/sansio/start_execution.cpp | 28 +- test/unit/test/sansio/top_level_algo.cpp | 31 -- 26 files changed, 642 insertions(+), 452 deletions(-) create mode 100644 test/unit/include/test_unit/serialize_to_vector.hpp diff --git a/include/boost/mysql/impl/internal/protocol/impl/binary_protocol.hpp b/include/boost/mysql/impl/internal/protocol/impl/binary_protocol.hpp index aee115135..bcd7f050d 100644 --- a/include/boost/mysql/impl/internal/protocol/impl/binary_protocol.hpp +++ b/include/boost/mysql/impl/internal/protocol/impl/binary_protocol.hpp @@ -326,7 +326,7 @@ inline void serialize_binary_float(serialization_context& ctx, T input) inline void serialize_binary_date(serialization_context& ctx, const date& input) { - ctx.serialize( + ctx.serialize_fixed( int1{static_cast(binc::date_sz)}, int2{input.year()}, int1{input.month()}, @@ -336,7 +336,7 @@ inline void serialize_binary_date(serialization_context& ctx, const date& input) inline void serialize_binary_datetime(serialization_context& ctx, const datetime& input) { - ctx.serialize( + ctx.serialize_fixed( int1{static_cast(binc::datetime_dhmsu_sz)}, int2{input.year()}, int1{input.month()}, @@ -367,7 +367,7 @@ inline void serialize_binary_time(serialization_context& ctx, const boost::mysql auto is_negative = (input.count() < 0) ? 1 : 0; // Serialize - ctx.serialize( + ctx.serialize_fixed( int1{static_cast(time_dhmsu_sz)}, int1{static_cast(is_negative)}, int4{static_cast(std::abs(num_days.count()))}, @@ -431,8 +431,8 @@ void boost::mysql::detail::serialize_binary_field(serialization_context& ctx, fi case field_kind::null: break; case field_kind::int64: sint8{input.get_int64()}.serialize(ctx); break; case field_kind::uint64: int8{input.get_uint64()}.serialize(ctx); break; - case field_kind::string: string_lenenc{input.get_string()}.serialize_checked(ctx); break; - case field_kind::blob: string_lenenc{to_string(input.get_blob())}.serialize_checked(ctx); break; + case field_kind::string: string_lenenc{input.get_string()}.serialize(ctx); break; + case field_kind::blob: string_lenenc{to_string(input.get_blob())}.serialize(ctx); break; case field_kind::float_: serialize_binary_float(ctx, input.get_float()); break; case field_kind::double_: serialize_binary_float(ctx, input.get_double()); break; case field_kind::date: serialize_binary_date(ctx, input.get_date()); break; diff --git a/include/boost/mysql/impl/internal/protocol/impl/protocol_types.hpp b/include/boost/mysql/impl/internal/protocol/impl/protocol_types.hpp index 5e403688b..9e6213b92 100644 --- a/include/boost/mysql/impl/internal/protocol/impl/protocol_types.hpp +++ b/include/boost/mysql/impl/internal/protocol/impl/protocol_types.hpp @@ -31,13 +31,16 @@ struct int_holder { IntType value; - void serialize(serialization_context& ctx) const + // This is a fixed-size type + static constexpr std::size_t size = sizeof(IntType); + + void serialize_fixed(std::uint8_t* to) const { - std::array buffer{}; - endian::endian_store(buffer.data(), value); - ctx.add(buffer); + endian::endian_store(to, value); } + void serialize(serialization_context& ctx) const { ctx.serialize_fixed(*this); } + deserialize_errc deserialize(deserialization_context& ctx) { constexpr std::size_t sz = sizeof(IntType); @@ -61,12 +64,12 @@ struct int3 { std::uint32_t value; - void serialize(serialization_context& ctx) const - { - std::array buffer; - endian::store_little_u24(buffer.data(), value); - ctx.add(buffer); - } + // This is a fixed-size type + static constexpr std::size_t size = 3u; + + void serialize_fixed(std::uint8_t* to) const { endian::store_little_u24(to, value); } + + void serialize(serialization_context& ctx) const { ctx.serialize_fixed(*this); } deserialize_errc deserialize(deserialization_context& ctx) { @@ -90,18 +93,21 @@ struct int_lenenc } else if (value < 0x10000) { - ctx.add(static_cast(0xfc)); - int2{static_cast(value)}.serialize(ctx); + ctx.serialize_fixed( + int1{static_cast(0xfc)}, + int2{static_cast(value)} + ); } else if (value < 0x1000000) { - ctx.add(static_cast(0xfd)); - int3{static_cast(value)}.serialize(ctx); + ctx.serialize_fixed( + int1{static_cast(0xfd)}, + int3{static_cast(value)} + ); } else { - ctx.add(static_cast(0xfe)); - int8{value}.serialize(ctx); + ctx.serialize_fixed(int1{static_cast(0xfe)}, int8{value}); } } @@ -178,7 +184,6 @@ struct string_eof } void serialize(serialization_context& ctx) const { ctx.add(to_span(value)); } - void serialize_checked(serialization_context& ctx) const { ctx.add_checked(to_span(value)); } }; struct string_lenenc @@ -213,11 +218,6 @@ struct string_lenenc ctx.serialize(int_lenenc{value.size()}); ctx.add(to_span(value)); } - void serialize_checked(serialization_context& ctx) const - { - ctx.serialize(int_lenenc{value.size()}); - ctx.add_checked(to_span(value)); - } }; template @@ -225,10 +225,12 @@ struct string_fixed { std::array value; - void serialize(serialization_context& ctx) const - { - ctx.add({reinterpret_cast(value.data()), N}); - } + // This is a fixed size type + static constexpr std::size_t size = N; + + void serialize_fixed(std::uint8_t* to) const { std::memcpy(to, value.data(), N); } + + void serialize(serialization_context& ctx) const { ctx.serialize_fixed(*this); } deserialize_errc deserialize(deserialization_context& ctx) { diff --git a/include/boost/mysql/impl/internal/protocol/impl/serialization_context.hpp b/include/boost/mysql/impl/internal/protocol/impl/serialization_context.hpp index 1d1bf6fcd..edf583a03 100644 --- a/include/boost/mysql/impl/internal/protocol/impl/serialization_context.hpp +++ b/include/boost/mysql/impl/internal/protocol/impl/serialization_context.hpp @@ -8,10 +8,14 @@ #ifndef BOOST_MYSQL_IMPL_INTERNAL_PROTOCOL_IMPL_SERIALIZATION_CONTEXT_HPP #define BOOST_MYSQL_IMPL_INTERNAL_PROTOCOL_IMPL_SERIALIZATION_CONTEXT_HPP +#include +#include + #include #include #include +#include #include #include @@ -31,25 +35,44 @@ BOOST_INLINE_CONSTEXPR std::size_t disable_framing = static_cast(-1 // of frame headers in serialization functions creates messages ready to send. // We require the entire message to be created before it's sent, so we don't lose any functionality. // -// This class knows the offset of the next frame header. Adding data with add_checked will correctly -// insert space for headers as required while copying the data. Other functions may result in -// "overruns" (writing past the offset of the next header). Overruns are fixed by add_frame_headers, -// which will memmove data as required. The distinction is made for efficiency. +// This class knows the offset of the next frame header. Adding data will correctly +// insert space for headers as required while copying the data. +// +// Like format_context_base, contains an error that can be set if a serialization +// function helps (e.g. because it would overrun the buffer size limit). +// Once set, serializing is a no-op. This pattern allows us to check for errors just once. class serialization_context { std::vector& buffer_; - std::size_t initial_offset_; + std::size_t max_buffer_size_; std::size_t max_frame_size_; - std::size_t next_header_offset_{}; + std::size_t next_header_offset_; + error_code err_; // max_frame_size_ == -1 can be used to disable framing. Used for testing bool framing_enabled() const { return max_frame_size_ != disable_framing; } - void add_checked_impl(span content) + void append_to_buffer(span contents) { - // Add any required frame headers we didn't add until now - add_frame_headers(); + // Do nothing if we previously encountered an error + if (err_) + return; + + // Check if the buffer has space for the given contents + if (buffer_.size() + contents.size() > max_buffer_size_) + { + err_ = client_errc::max_buffer_size_exceeded; + return; + } + // Copy + buffer_.insert(buffer_.end(), contents.begin(), contents.end()); + } + + void append_header() { append_to_buffer(std::array{}); } + + void add_impl(span content) + { // Add the content in chunks, inserting space for headers where required std::size_t content_offset = 0; while (content_offset < content.size()) @@ -59,88 +82,73 @@ class serialization_context auto remaining_content = static_cast(content.size() - content_offset); auto remaining_frame = static_cast(next_header_offset_ - buffer_.size()); auto size_to_write = (std::min)(remaining_content, remaining_frame); - buffer_.insert( - buffer_.end(), - content.data() + content_offset, - content.data() + content_offset + size_to_write - ); + append_to_buffer(content.subspan(content_offset, size_to_write)); content_offset += size_to_write; // Insert space for a frame header if required if (buffer_.size() == next_header_offset_) { - buffer_.resize(buffer_.size() + 4); + append_header(); next_header_offset_ += (max_frame_size_ + frame_header_size); } } } -public: - serialization_context(std::vector& buff, std::size_t max_frame_size = max_packet_size) - : buffer_(buff), initial_offset_(buffer_.size()), max_frame_size_(max_frame_size) + template + static constexpr std::size_t fixed_total_size(Serializable, Rest... rest) { - // Add space for the initial header - if (framing_enabled()) - { - buffer_.resize(buffer_.size() + frame_header_size); - next_header_offset_ = initial_offset_ + max_frame_size_ + frame_header_size; - } + return Serializable::size + fixed_total_size(rest...); } - // Exposed for testing - std::size_t next_header_offset() const { return next_header_offset_; } - - // To be called by serialize() functions. - // Appends a single byte to the buffer. Doesn't take framing into account. - void add(std::uint8_t value) { buffer_.push_back(value); } + static constexpr std::size_t fixed_total_size() { return 0u; } - // To be called by serialize() functions. Appends bytes to the buffer. - // Doesn't take framing into account - use for payloads with bound size. - void add(span content) + template + static void serialize_fixed_impl(std::uint8_t* it, Serializable serializable, Rest... rest) { - buffer_.insert(buffer_.end(), content.begin(), content.end()); + serializable.serialize_fixed(it); + serialize_fixed_impl(it + Serializable::size, rest...); } - // Like add, but takes framing into account. Use for potentially long payloads. - // If the payload is very long, space for frame headers will be added as required, - // avoiding expensive memmove's when calling add_frame_headers - void add_checked(span content) + static void serialize_fixed_impl(std::uint8_t*) {} + +public: + serialization_context( + std::vector& buff, + std::size_t max_buffer_size = static_cast(-1), + std::size_t max_frame_size = max_packet_size + ) + : buffer_(buff), + max_buffer_size_(max_buffer_size), + max_frame_size_(max_frame_size), + next_header_offset_( + framing_enabled() ? buffer_.size() + max_frame_size_ + frame_header_size + : static_cast(-1) + ) { + // Add space for the initial header if (framing_enabled()) - { - add_checked_impl(content); - } - else - { - add(content); - } + append_header(); } - // Inserts any missing space for frame headers, moving data as required. // Exposed for testing - void add_frame_headers() - { - while (next_header_offset_ <= buffer_.size()) - { - // Insert space for the frame header where needed - const std::array placeholder{}; - buffer_.insert(buffer_.begin() + next_header_offset_, placeholder.begin(), placeholder.end()); + std::size_t next_header_offset() const { return next_header_offset_; } - // Update the next frame header offset - next_header_offset_ += (max_frame_size_ + frame_header_size); - } - } + void add(std::uint8_t value) { add_impl({&value, 1}); } + + // To be called by serialize() functions. Appends bytes to the buffer. + void add(span content) { add_impl(content); } + + error_code error() const { return err_; } // Write frame headers to an already serialized message with space for them - std::uint8_t write_frame_headers(std::uint8_t seqnum) + std::uint8_t write_frame_headers(std::uint8_t seqnum, std::size_t initial_offset) { BOOST_ASSERT(framing_enabled()); - - // Add any missing space for headers - add_frame_headers(); + BOOST_ASSERT(!err_); + BOOST_ASSERT(initial_offset < buffer_.size()); // Actually write the headers - std::size_t offset = initial_offset_; + std::size_t offset = initial_offset; while (offset < buffer_.size()) { // Calculate the current frame size @@ -165,6 +173,17 @@ class serialization_context return seqnum; } + // Optimization for fixed size types. We serialize them to an + // intermediate, stack-based buffer, then copy them to the actual buffer. + // This saves reallocations and space checks + template + void serialize_fixed(Serializable... s) + { + std::array buff; + serialize_fixed_impl(buff.data(), s...); + add(buff); + } + // Allow chaining template void serialize(Serializable... s) diff --git a/include/boost/mysql/impl/internal/protocol/serialization.hpp b/include/boost/mysql/impl/internal/protocol/serialization.hpp index afab0e2d5..1fc5734e9 100644 --- a/include/boost/mysql/impl/internal/protocol/serialization.hpp +++ b/include/boost/mysql/impl/internal/protocol/serialization.hpp @@ -8,6 +8,7 @@ #ifndef BOOST_MYSQL_IMPL_INTERNAL_PROTOCOL_SERIALIZATION_HPP #define BOOST_MYSQL_IMPL_INTERNAL_PROTOCOL_SERIALIZATION_HPP +#include #include #include @@ -21,6 +22,7 @@ #include +#include #include namespace boost { @@ -53,7 +55,7 @@ struct query_command void serialize(serialization_context& ctx) const { ctx.add(0x03); - string_eof{query}.serialize_checked(ctx); + string_eof{query}.serialize(ctx); } }; @@ -82,11 +84,7 @@ struct execute_stmt_command struct close_stmt_command { std::uint32_t statement_id; - void serialize(serialization_context& ctx) const - { - ctx.add(0x19); - int4{statement_id}.serialize(ctx); - } + void serialize(serialization_context& ctx) const { ctx.serialize_fixed(int1{0x19}, int4{statement_id}); } }; // Login request @@ -121,18 +119,48 @@ struct auth_switch_response void serialize(serialization_context& ctx) const { ctx.add(auth_plugin_data); } }; -// Serialize a complete message +// The result of serialize_top_level (similar to system::result, +// doesn't track source locations) +struct serialize_top_level_result +{ + error_code err; + std::uint8_t seqnum{}; + + constexpr serialize_top_level_result(error_code ec) noexcept : err(ec) {} + constexpr serialize_top_level_result(std::uint8_t seqnum) noexcept : seqnum(seqnum) {} +}; + +// Serialize a complete message. May fail template -inline std::uint8_t serialize_top_level( +inline serialize_top_level_result serialize_top_level( const Serializable& input, std::vector& to, std::uint8_t seqnum = 0, - std::size_t frame_size = max_packet_size + std::size_t max_buffer_size = static_cast(-1), + std::size_t max_frame_size = max_packet_size ) { - serialization_context ctx(to, frame_size); + std::size_t initial_offset = to.size(); + serialization_context ctx(to, max_buffer_size, max_frame_size); input.serialize(ctx); - return ctx.write_frame_headers(seqnum); + auto err = ctx.error(); + if (err) + return err; + return ctx.write_frame_headers(seqnum, initial_offset); +} + +// Same, but for cases that can't fail. Does not enforce any limit on buffer size +template +inline std::uint8_t serialize_top_level_checked( + const Serializable& input, + std::vector& to, + std::uint8_t seqnum = 0, + std::size_t max_frame_size = max_packet_size +) +{ + auto res = serialize_top_level(input, to, seqnum, static_cast(-1), max_frame_size); + BOOST_ASSERT(res.err == error_code()); + return res.seqnum; } } // namespace detail @@ -197,7 +225,7 @@ void boost::mysql::detail::execute_stmt_command::serialize(serialization_context constexpr int1 new_params_bind_flag{1}; // header - ctx.serialize(command_id, int4{statement_id}, flags, iteration_count); + ctx.serialize_fixed(command_id, int4{statement_id}, flags, iteration_count); // Number of parameters auto num_params = params.size(); @@ -218,8 +246,7 @@ void boost::mysql::detail::execute_stmt_command::serialize(serialization_context field_kind kind = param.kind(); protocol_field_type type = to_protocol_field_type(kind); std::uint8_t unsigned_flag = kind == field_kind::uint64 ? std::uint8_t(0x80) : std::uint8_t(0); - ctx.add(static_cast(type)); - ctx.add(unsigned_flag); + ctx.serialize_fixed(int1{static_cast(type)}, int1{unsigned_flag}); } // actual values @@ -232,11 +259,13 @@ void boost::mysql::detail::execute_stmt_command::serialize(serialization_context void boost::mysql::detail::login_request::serialize(serialization_context& ctx) const { - ctx.serialize( + ctx.serialize_fixed( int4{negotiated_capabilities.get()}, // client_flag int4{max_packet_size}, // max_packet_size int1{get_collation_first_byte(collation_id)}, // character_set - string_fixed<23>{}, // filler (all zeros) + string_fixed<23>{} // filler (all zeros) + ); + ctx.serialize( string_null{username}, string_lenenc{to_string(auth_response)} // we require CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA ); @@ -249,7 +278,7 @@ void boost::mysql::detail::login_request::serialize(serialization_context& ctx) void boost::mysql::detail::ssl_request::serialize(serialization_context& ctx) const { - ctx.serialize( + ctx.serialize_fixed( int4{negotiated_capabilities.get()}, // client_flag int4{max_packet_size}, // max_packet_size int1{get_collation_first_byte(collation_id)}, // character_set, diff --git a/include/boost/mysql/impl/internal/sansio/close_statement.hpp b/include/boost/mysql/impl/internal/sansio/close_statement.hpp index 2ed026bab..4300f2d91 100644 --- a/include/boost/mysql/impl/internal/sansio/close_statement.hpp +++ b/include/boost/mysql/impl/internal/sansio/close_statement.hpp @@ -9,6 +9,7 @@ #define BOOST_MYSQL_IMPL_INTERNAL_SANSIO_CLOSE_STATEMENT_HPP #include +#include #include #include @@ -22,11 +23,17 @@ inline run_pipeline_algo_params setup_close_statement_pipeline( close_statement_algo_params params ) { + // Pipeline a ping with the close statement, to avoid delays on old connections + // that don't set tcp_nodelay. Both requests are small and fixed size, so + // we don't enforce any buffer limits here. st.write_buffer.clear(); - auto seqnum1 = serialize_top_level(close_stmt_command{params.stmt_id}, st.write_buffer); - auto seqnum2 = serialize_top_level(ping_command{}, st.write_buffer); + auto seqnum1 = serialize_top_level_checked(close_stmt_command{params.stmt_id}, st.write_buffer); + auto seqnum2 = serialize_top_level_checked(ping_command{}, st.write_buffer); st.shared_pipeline_stages = { - {{pipeline_stage_kind::close_statement, seqnum1, {}}, {pipeline_stage_kind::ping, seqnum2, {}}} + { + {pipeline_stage_kind::close_statement, seqnum1, {}}, + {pipeline_stage_kind::ping, seqnum2, {}}, + } }; return {st.write_buffer, st.shared_pipeline_stages, nullptr}; } diff --git a/include/boost/mysql/impl/internal/sansio/connection_state_data.hpp b/include/boost/mysql/impl/internal/sansio/connection_state_data.hpp index 2aec241bb..b1aaa065a 100644 --- a/include/boost/mysql/impl/internal/sansio/connection_state_data.hpp +++ b/include/boost/mysql/impl/internal/sansio/connection_state_data.hpp @@ -124,7 +124,10 @@ struct connection_state_data { // use_ssl is attached by top_level_algo write_buffer.clear(); - seqnum = serialize_top_level(msg, write_buffer, seqnum); + auto res = serialize_top_level(msg, write_buffer, seqnum, max_buffer_size()); + if (res.err) + return res.err; + seqnum = res.seqnum; return next_action::write({write_buffer, false}); } }; diff --git a/include/boost/mysql/impl/internal/sansio/ping.hpp b/include/boost/mysql/impl/internal/sansio/ping.hpp index 1d0cfb735..bdb43239f 100644 --- a/include/boost/mysql/impl/internal/sansio/ping.hpp +++ b/include/boost/mysql/impl/internal/sansio/ping.hpp @@ -40,18 +40,23 @@ class read_ping_response_algo if (ec) return ec; - // Process the OK packet - return st.deserialize_ok(*diag_); + // Process the OK packet and done + ec = st.deserialize_ok(*diag_); } - return next_action(); + + return ec; } }; inline run_pipeline_algo_params setup_ping_pipeline(connection_state_data& st) { + // The ping request is fixed size and small. No buffer limit is enforced on it. st.write_buffer.clear(); - auto seqnum = serialize_top_level(ping_command{}, st.write_buffer); - st.shared_pipeline_stages[0] = {pipeline_stage_kind::ping, seqnum, {}}; + st.shared_pipeline_stages[0] = { + pipeline_stage_kind::ping, + serialize_top_level_checked(ping_command{}, st.write_buffer), + {} + }; return { st.write_buffer, {st.shared_pipeline_stages.data(), 1}, diff --git a/include/boost/mysql/impl/internal/sansio/reset_connection.hpp b/include/boost/mysql/impl/internal/sansio/reset_connection.hpp index 1627ce6de..41be65054 100644 --- a/include/boost/mysql/impl/internal/sansio/reset_connection.hpp +++ b/include/boost/mysql/impl/internal/sansio/reset_connection.hpp @@ -56,19 +56,19 @@ class read_reset_connection_response_algo } // Done - return ec; } - return next_action(); + return ec; } }; inline run_pipeline_algo_params setup_reset_connection_pipeline(connection_state_data& st) { + // reset_connection request is fixed size and small, so we don't enforce any buffer limit st.write_buffer.clear(); st.shared_pipeline_stages[0] = { pipeline_stage_kind::reset_connection, - serialize_top_level(reset_connection_command{}, st.write_buffer), + serialize_top_level_checked(reset_connection_command{}, st.write_buffer), {} }; return { diff --git a/include/boost/mysql/impl/internal/sansio/top_level_algo.hpp b/include/boost/mysql/impl/internal/sansio/top_level_algo.hpp index 4ab19f3ff..5a0d6194a 100644 --- a/include/boost/mysql/impl/internal/sansio/top_level_algo.hpp +++ b/include/boost/mysql/impl/internal/sansio/top_level_algo.hpp @@ -107,17 +107,6 @@ class top_level_algo // Write until a complete message was written bytes_to_write_ = act.write_args().buffer; - // Check buffer size. We should check this before - // resizing the buffer, but requires non-trivial changes. - // For now, this yields the right user-facing behavior. - // https://github.com/boostorg/mysql/issues/297 - // https://github.com/boostorg/mysql/issues/279 - if (bytes_to_write_.size() > st_->max_buffer_size()) - { - ec = client_errc::max_buffer_size_exceeded; - continue; - } - while (!bytes_to_write_.empty() && !ec) { BOOST_MYSQL_YIELD( diff --git a/include/boost/mysql/impl/pipeline.ipp b/include/boost/mysql/impl/pipeline.ipp index f831c1f9b..2b7fb6230 100644 --- a/include/boost/mysql/impl/pipeline.ipp +++ b/include/boost/mysql/impl/pipeline.ipp @@ -31,7 +31,7 @@ boost::mysql::pipeline_request& boost::mysql::pipeline_request::add_execute(stri impl_.stages_.reserve(impl_.stages_.size() + 1); // strong guarantee impl_.stages_.push_back({ detail::pipeline_stage_kind::execute, - detail::serialize_top_level(detail::query_command{query}, impl_.buffer_), + detail::serialize_top_level_checked(detail::query_command{query}, impl_.buffer_), detail::resultset_encoding::text, }); return *this; @@ -51,7 +51,7 @@ boost::mysql::pipeline_request& boost::mysql::pipeline_request::add_execute_rang impl_.stages_.reserve(impl_.stages_.size() + 1); // strong guarantee impl_.stages_.push_back({ detail::pipeline_stage_kind::execute, - detail::serialize_top_level(detail::execute_stmt_command{stmt.id(), params}, impl_.buffer_), + detail::serialize_top_level_checked(detail::execute_stmt_command{stmt.id(), params}, impl_.buffer_), detail::resultset_encoding::binary, }); return *this; @@ -62,7 +62,7 @@ boost::mysql::pipeline_request& boost::mysql::pipeline_request::add_prepare_stat impl_.stages_.reserve(impl_.stages_.size() + 1); // strong guarantee impl_.stages_.push_back({ detail::pipeline_stage_kind::prepare_statement, - detail::serialize_top_level(detail::prepare_stmt_command{stmt_sql}, impl_.buffer_), + detail::serialize_top_level_checked(detail::prepare_stmt_command{stmt_sql}, impl_.buffer_), {}, }); return *this; @@ -73,7 +73,7 @@ boost::mysql::pipeline_request& boost::mysql::pipeline_request::add_close_statem impl_.stages_.reserve(impl_.stages_.size() + 1); // strong guarantee impl_.stages_.push_back({ detail::pipeline_stage_kind::close_statement, - detail::serialize_top_level(detail::close_stmt_command{stmt.id()}, impl_.buffer_), + detail::serialize_top_level_checked(detail::close_stmt_command{stmt.id()}, impl_.buffer_), {}, }); return *this; @@ -84,7 +84,7 @@ boost::mysql::pipeline_request& boost::mysql::pipeline_request::add_reset_connec impl_.stages_.reserve(impl_.stages_.size() + 1); // strong guarantee impl_.stages_.push_back({ detail::pipeline_stage_kind::reset_connection, - detail::serialize_top_level(detail::reset_connection_command{}, impl_.buffer_), + detail::serialize_top_level_checked(detail::reset_connection_command{}, impl_.buffer_), {}, }); return *this; @@ -100,7 +100,7 @@ boost::mysql::pipeline_request& boost::mysql::pipeline_request::add_set_characte impl_.stages_.reserve(impl_.stages_.size() + 1); // strong guarantee impl_.stages_.push_back({ detail::pipeline_stage_kind::set_character_set, - detail::serialize_top_level(detail::query_command{*q}, impl_.buffer_), + detail::serialize_top_level_checked(detail::query_command{*q}, impl_.buffer_), charset, }); return *this; diff --git a/test/unit/include/test_unit/algo_test.hpp b/test/unit/include/test_unit/algo_test.hpp index 2c4ba3b5a..40ff9831a 100644 --- a/test/unit/include/test_unit/algo_test.hpp +++ b/test/unit/include/test_unit/algo_test.hpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -28,6 +29,7 @@ #include "test_common/assert_buffer_equals.hpp" #include "test_common/create_diagnostics.hpp" #include "test_common/printing.hpp" +#include "test_common/source_location.hpp" #include "test_unit/printing.hpp" namespace boost { @@ -202,18 +204,26 @@ class BOOST_ATTRIBUTE_NODISCARD algo_test } template - void check(AlgoFixture& fix, error_code expected_ec = {}, const diagnostics& expected_diag = {}) const + void check( + AlgoFixture& fix, + error_code expected_ec = {}, + const diagnostics& expected_diag = {}, + source_location loc = BOOST_MYSQL_CURRENT_LOCATION + ) const { - check_impl(fix.st, fix.algo, expected_ec); - BOOST_TEST(fix.diag == expected_diag); + BOOST_TEST_CONTEXT("Called from " << loc) + { + check_impl(fix.st, fix.algo, expected_ec); + BOOST_TEST(fix.diag == expected_diag); + } } template - void check_network_errors() const + void check_network_errors(source_location loc = BOOST_MYSQL_CURRENT_LOCATION) const { for (std::size_t i = 0; i < num_steps(); ++i) { - BOOST_TEST_CONTEXT("check_network_errors erroring at step " << i) + BOOST_TEST_CONTEXT("Called from " << loc << " at step " << i) { AlgoFixture fix; check_network_errors_impl(fix.st, fix.algo, i); @@ -225,14 +235,21 @@ class BOOST_ATTRIBUTE_NODISCARD algo_test struct algo_fixture_base { - detail::connection_state_data st{512}; + static constexpr std::size_t default_max_buffsize = 1024u; + + detail::connection_state_data st; diagnostics diag; - algo_fixture_base(diagnostics initial_diag = create_server_diag("Diagnostics not cleared")) - : diag(std::move(initial_diag)) + algo_fixture_base( + diagnostics initial_diag = create_server_diag("Diagnostics not cleared"), + std::size_t max_buffer_size = default_max_buffsize + ) + : st(max_buffer_size, max_buffer_size), diag(std::move(initial_diag)) { st.write_buffer.push_back(0xff); // Check that we clear the write buffer at each step } + + algo_fixture_base(std::size_t max_buffer_size) : algo_fixture_base(diagnostics(), max_buffer_size) {} }; } // namespace test diff --git a/test/unit/include/test_unit/create_coldef_frame.hpp b/test/unit/include/test_unit/create_coldef_frame.hpp index df368fe5f..2184c3a2f 100644 --- a/test/unit/include/test_unit/create_coldef_frame.hpp +++ b/test/unit/include/test_unit/create_coldef_frame.hpp @@ -17,6 +17,7 @@ #include #include "test_unit/create_frame.hpp" +#include "test_unit/serialize_to_vector.hpp" namespace boost { namespace mysql { @@ -58,24 +59,23 @@ inline std::vector create_coldef_body(const detail::coldef_view& p } }; - std::vector buff; - detail::serialization_context ctx(buff, detail::disable_framing); - ctx.serialize( - detail::string_lenenc{"def"}, - detail::string_lenenc{pack.database}, - detail::string_lenenc{pack.table}, - detail::string_lenenc{pack.org_table}, - detail::string_lenenc{pack.name}, - detail::string_lenenc{pack.org_name}, - detail::int_lenenc{0x0c}, // length of fixed fields - detail::int2{pack.collation_id}, - detail::int4{pack.column_length}, - detail::int1{static_cast(to_protocol_type(pack.type))}, - detail::int2{pack.flags}, - detail::int1{pack.decimals}, - detail::int2{0} // padding - ); - return buff; + return serialize_to_vector([=](detail::serialization_context& ctx) { + ctx.serialize( + detail::string_lenenc{"def"}, + detail::string_lenenc{pack.database}, + detail::string_lenenc{pack.table}, + detail::string_lenenc{pack.org_table}, + detail::string_lenenc{pack.name}, + detail::string_lenenc{pack.org_name}, + detail::int_lenenc{0x0c}, // length of fixed fields + detail::int2{pack.collation_id}, + detail::int4{pack.column_length}, + detail::int1{static_cast(to_protocol_type(pack.type))}, + detail::int2{pack.flags}, + detail::int1{pack.decimals}, + detail::int2{0} // padding + ); + }); } inline std::vector create_coldef_frame(std::uint8_t seqnum, const detail::coldef_view& coldef) diff --git a/test/unit/include/test_unit/create_err.hpp b/test/unit/include/test_unit/create_err.hpp index 437d31ae2..c4ef7c07f 100644 --- a/test/unit/include/test_unit/create_err.hpp +++ b/test/unit/include/test_unit/create_err.hpp @@ -16,6 +16,7 @@ #include #include "test_unit/create_frame.hpp" +#include "test_unit/serialize_to_vector.hpp" namespace boost { namespace mysql { @@ -23,17 +24,16 @@ namespace test { inline std::vector serialize_err_impl(detail::err_view pack, bool with_header) { - std::vector buff; - detail::serialization_context ctx(buff, detail::disable_framing); - if (with_header) - ctx.add(0xff); // header - ctx.serialize( - detail::int2{pack.error_code}, - detail::string_fixed<1>{}, // SQL state marker - detail::string_fixed<5>{}, // SQL state - detail::string_eof{pack.error_message} - ); - return buff; + return serialize_to_vector([=](detail::serialization_context& ctx) { + if (with_header) + ctx.add(0xff); // header + ctx.serialize( + detail::int2{pack.error_code}, + detail::string_fixed<1>{}, // SQL state marker + detail::string_fixed<5>{}, // SQL state + detail::string_eof{pack.error_message} + ); + }); } class err_builder diff --git a/test/unit/include/test_unit/create_ok_frame.hpp b/test/unit/include/test_unit/create_ok_frame.hpp index 37ca1e7e8..c4a937db3 100644 --- a/test/unit/include/test_unit/create_ok_frame.hpp +++ b/test/unit/include/test_unit/create_ok_frame.hpp @@ -14,6 +14,7 @@ #include #include "test_unit/create_frame.hpp" +#include "test_unit/serialize_to_vector.hpp" namespace boost { namespace mysql { @@ -21,22 +22,20 @@ namespace test { inline std::vector serialize_ok_impl(const detail::ok_view& pack, std::uint8_t header) { - std::vector buff; - detail::serialization_context ctx(buff, detail::disable_framing); - - ctx.serialize( - detail::int1{header}, - detail::int_lenenc{pack.affected_rows}, - detail::int_lenenc{pack.last_insert_id}, - detail::int2{pack.status_flags}, - detail::int2{pack.warnings} - ); - // When info is empty, it's actually omitted in the ok_packet - if (!pack.info.empty()) - { - detail::string_lenenc{pack.info}.serialize(ctx); - } - return buff; + return serialize_to_vector([=](detail::serialization_context& ctx) { + ctx.serialize( + detail::int1{header}, + detail::int_lenenc{pack.affected_rows}, + detail::int_lenenc{pack.last_insert_id}, + detail::int2{pack.status_flags}, + detail::int2{pack.warnings} + ); + // When info is empty, it's actually omitted in the ok_packet + if (!pack.info.empty()) + { + detail::string_lenenc{pack.info}.serialize(ctx); + } + }); } inline std::vector create_ok_body(const detail::ok_view& ok) diff --git a/test/unit/include/test_unit/create_prepare_statement_response.hpp b/test/unit/include/test_unit/create_prepare_statement_response.hpp index 11d5c8572..c9cd1222b 100644 --- a/test/unit/include/test_unit/create_prepare_statement_response.hpp +++ b/test/unit/include/test_unit/create_prepare_statement_response.hpp @@ -15,6 +15,7 @@ #include #include "test_unit/create_frame.hpp" +#include "test_unit/serialize_to_vector.hpp" namespace boost { namespace mysql { @@ -56,17 +57,17 @@ class prepare_stmt_response_builder std::vector build() const { - std::vector res; - detail::serialization_context ctx(res, detail::disable_framing); - ctx.serialize( - detail::int1{0u}, // OK header - detail::int4{statement_id_}, // statement_id - detail::int2{num_columns_}, // num columns - detail::int2{num_params_}, // num_params - detail::int1{0u}, // reserved - detail::int2{90u} // warning_count - ); - return create_frame(seqnum_, res); + auto body = serialize_to_vector([this](detail::serialization_context& ctx) { + ctx.serialize( + detail::int1{0u}, // OK header + detail::int4{statement_id_}, // statement_id + detail::int2{num_columns_}, // num columns + detail::int2{num_params_}, // num_params + detail::int1{0u}, // reserved + detail::int2{90u} // warning_count + ); + }); + return create_frame(seqnum_, body); } }; diff --git a/test/unit/include/test_unit/create_query_frame.hpp b/test/unit/include/test_unit/create_query_frame.hpp index 23cf488d0..312909bd4 100644 --- a/test/unit/include/test_unit/create_query_frame.hpp +++ b/test/unit/include/test_unit/create_query_frame.hpp @@ -17,6 +17,7 @@ #include #include "test_unit/create_frame.hpp" +#include "test_unit/serialize_to_vector.hpp" namespace boost { namespace mysql { @@ -24,11 +25,10 @@ namespace test { inline std::vector create_query_body_impl(std::uint8_t command_id, string_view sql) { - std::vector buff; - detail::serialization_context ctx(buff, detail::disable_framing); - ctx.add(command_id); - ctx.add(detail::to_span(sql)); - return buff; + return serialize_to_vector([=](detail::serialization_context& ctx) { + ctx.add(command_id); + ctx.add(detail::to_span(sql)); + }); } inline std::vector create_query_frame(std::uint8_t seqnum, string_view sql) diff --git a/test/unit/include/test_unit/create_row_message.hpp b/test/unit/include/test_unit/create_row_message.hpp index 953dac595..6dea92865 100644 --- a/test/unit/include/test_unit/create_row_message.hpp +++ b/test/unit/include/test_unit/create_row_message.hpp @@ -13,6 +13,7 @@ #include "test_common/create_basic.hpp" #include "test_unit/create_frame.hpp" +#include "test_unit/serialize_to_vector.hpp" namespace boost { namespace mysql { @@ -20,25 +21,24 @@ namespace test { inline std::vector serialize_text_row_impl(span fields) { - std::vector buff; - detail::serialization_context ctx(buff, detail::disable_framing); - for (field_view f : fields) - { - std::string s; - switch (f.kind()) + return serialize_to_vector([=](detail::serialization_context& ctx) { + for (field_view f : fields) { - case field_kind::int64: s = std::to_string(f.get_int64()); break; - case field_kind::uint64: s = std::to_string(f.get_uint64()); break; - case field_kind::float_: s = std::to_string(f.get_float()); break; - case field_kind::double_: s = std::to_string(f.get_double()); break; - case field_kind::string: s = f.get_string(); break; - case field_kind::blob: s.assign(f.get_blob().begin(), f.get_blob().end()); break; - case field_kind::null: ctx.add(std::uint8_t(0xfb)); continue; - default: throw std::runtime_error("create_text_row_message: type not implemented"); + std::string s; + switch (f.kind()) + { + case field_kind::int64: s = std::to_string(f.get_int64()); break; + case field_kind::uint64: s = std::to_string(f.get_uint64()); break; + case field_kind::float_: s = std::to_string(f.get_float()); break; + case field_kind::double_: s = std::to_string(f.get_double()); break; + case field_kind::string: s = f.get_string(); break; + case field_kind::blob: s.assign(f.get_blob().begin(), f.get_blob().end()); break; + case field_kind::null: ctx.add(std::uint8_t(0xfb)); continue; + default: throw std::runtime_error("create_text_row_message: type not implemented"); + } + detail::string_lenenc{s}.serialize(ctx); } - detail::string_lenenc{s}.serialize(ctx); - } - return buff; + }); } template diff --git a/test/unit/include/test_unit/serialize_to_vector.hpp b/test/unit/include/test_unit/serialize_to_vector.hpp new file mode 100644 index 000000000..b69c70ab9 --- /dev/null +++ b/test/unit/include/test_unit/serialize_to_vector.hpp @@ -0,0 +1,37 @@ +// +// Copyright (c) 2019-2024 Ruben Perez Hidalgo (rubenperez038 at gmail dot com) +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// + +#ifndef BOOST_MYSQL_TEST_UNIT_INCLUDE_TEST_UNIT_SERIALIZE_TO_VECTOR_HPP +#define BOOST_MYSQL_TEST_UNIT_INCLUDE_TEST_UNIT_SERIALIZE_TO_VECTOR_HPP + +#include + +#include +#include + +#include +#include + +namespace boost { +namespace mysql { +namespace test { + +template +std::vector serialize_to_vector(const Fn& serialize_fn) +{ + std::vector buff; + detail::serialization_context ctx(buff, static_cast(-1), detail::disable_framing); + serialize_fn(ctx); + BOOST_TEST(ctx.error() == error_code()); + return buff; +} + +} // namespace test +} // namespace mysql +} // namespace boost + +#endif diff --git a/test/unit/test/protocol/binary_protocol.cpp b/test/unit/test/protocol/binary_protocol.cpp index 3885bd6dc..8f3308061 100644 --- a/test/unit/test/protocol/binary_protocol.cpp +++ b/test/unit/test/protocol/binary_protocol.cpp @@ -117,44 +117,6 @@ BOOST_AUTO_TEST_CASE(serialize) } } -// String and blob parameters may be large, so we take framing -// into account when serializing them -BOOST_AUTO_TEST_CASE(serialize_framing) -{ - constexpr std::uint8_t blob_buffer[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - - struct - { - const char* name; - field_view param; - std::vector serialized; - } test_cases[] = { - // clang-format off - { - "string", - field_view("abcdefghijk"), - { - 0, 0, 0, 0, 11, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, - 0, 0, 0, 0, 0x6a, 0x6b - } - }, - { - "blob", - field_view(blob_buffer), - { - 0, 0, 0, 0, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, - 0, 0, 0, 0, 10 - } - }, - // clang-format on - }; - - for (const auto& tc : test_cases) - { - BOOST_TEST_CONTEXT(tc.name) { do_serialize_test(field_view_adaptor{tc.param}, tc.serialized, 10u); } - } -} - BOOST_AUTO_TEST_SUITE(deserialize_success) struct success_sample diff --git a/test/unit/test/protocol/serialization.cpp b/test/unit/test/protocol/serialization.cpp index 2cca2d867..b200b9918 100644 --- a/test/unit/test/protocol/serialization.cpp +++ b/test/unit/test/protocol/serialization.cpp @@ -5,7 +5,9 @@ // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // +#include #include +#include #include #include #include @@ -22,6 +24,7 @@ #include "serialization_test.hpp" #include "test_common/assert_buffer_equals.hpp" #include "test_common/create_basic.hpp" +#include "test_common/printing.hpp" #include "test_unit/mock_message.hpp" using namespace boost::mysql::detail; @@ -30,13 +33,14 @@ namespace collations = boost::mysql::mysql_collations; using boost::span; using boost::mysql::date; using boost::mysql::datetime; +using boost::mysql::error_code; using boost::mysql::field_view; using boost::mysql::string_view; BOOST_AUTO_TEST_SUITE(test_serialization) // spotcheck: multi-frame messages handled correctly by serialize_top_level -BOOST_AUTO_TEST_CASE(serialize_top_level_) +BOOST_AUTO_TEST_CASE(serialize_top_level_multiframe) { constexpr std::size_t frame_size = 8u; const std::array payload{ @@ -46,11 +50,24 @@ BOOST_AUTO_TEST_CASE(serialize_top_level_) 4, 5, 6, 7, 8, 3, 0, 0, 43, 9, 10, 11}; std::vector buff{80, 81, 82, 83, 85}; - std::uint8_t seqnum = serialize_top_level(mock_message{payload}, buff, 42, frame_size); - BOOST_TEST(seqnum == 44u); + auto result = serialize_top_level(mock_message{payload}, buff, 42, 0xffff, frame_size); + BOOST_TEST(result.err == error_code()); + BOOST_TEST(result.seqnum == 44u); BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); } +// spotcheck: max size correctly propagated +BOOST_AUTO_TEST_CASE(serialize_top_level_error_max_size) +{ + const std::array payload{ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} + }; + std::vector buff; + auto result = serialize_top_level(mock_message{payload}, buff, 42, 8u); + BOOST_TEST(result.err == boost::mysql::client_errc::max_buffer_size_exceeded); + BOOST_TEST(result.seqnum == 0u); +} + BOOST_AUTO_TEST_CASE(quit) { quit_command cmd; @@ -80,17 +97,6 @@ BOOST_AUTO_TEST_CASE(query) do_serialize_test(cmd, serialized); } -// Query strings may be large. We consider framing when serializing them -BOOST_AUTO_TEST_CASE(query_framing) -{ - query_command cmd{"show databases"}; - const std::uint8_t serialized[] = { - 0, 0, 0, 0, 0x03, 0x73, 0x68, 0x6f, 0x77, 0x20, 0x64, 0x61, // frame 1 - 0, 0, 0, 0, 0x74, 0x61, 0x62, 0x61, 0x73, 0x65, 0x73 // frame 2 - }; - do_serialize_test(cmd, serialized, 8u); -} - BOOST_AUTO_TEST_CASE(prepare_statement) { prepare_stmt_command cmd{"SELECT * from three_rows_table WHERE id = ?"}; diff --git a/test/unit/test/protocol/serialization_context.cpp b/test/unit/test/protocol/serialization_context.cpp index e0707525d..0983d63b4 100644 --- a/test/unit/test/protocol/serialization_context.cpp +++ b/test/unit/test/protocol/serialization_context.cpp @@ -5,6 +5,8 @@ // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // +#include +#include #include #include @@ -18,9 +20,12 @@ #include "test_common/assert_buffer_equals.hpp" #include "test_common/buffer_concat.hpp" +#include "test_common/printing.hpp" using namespace boost::mysql; +namespace { + BOOST_AUTO_TEST_SUITE(test_serialization_context) struct framing_test_case @@ -57,7 +62,7 @@ std::vector make_test_cases() }; } -BOOST_AUTO_TEST_CASE(add_frame_headers) +BOOST_AUTO_TEST_CASE(add) { constexpr std::size_t fs = 8u; // frame size const std::vector initial_buffer{0xaa, 0xbb, 0xcc, 0xdd, 0xee}; @@ -68,164 +73,107 @@ BOOST_AUTO_TEST_CASE(add_frame_headers) { // Setup std::vector buff{initial_buffer}; - detail::serialization_context ctx(buff, fs); + detail::serialization_context ctx(buff, 0xffff, fs); - // Add payload and set headers + // Add the payload ctx.add(tc.payload); - ctx.add_frame_headers(); // Check auto expected = test::concat_copy(initial_buffer, tc.expected_buffer); BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); BOOST_TEST(ctx.next_header_offset() == tc.expected_next_frame_offset + initial_buffer.size()); + BOOST_TEST(ctx.error() == error_code()); } } } // Spotcheck: if the initial buffer is empty, everything works fine -BOOST_AUTO_TEST_CASE(add_frame_headers_initial_buffer_empty) +BOOST_AUTO_TEST_CASE(add_initial_buffer_empty) { // Setup std::vector buff; - detail::serialization_context ctx(buff, 8); + detail::serialization_context ctx(buff, 0xffff, 8); // Add data const std::array payload{ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10} }; ctx.add(payload); - ctx.add_frame_headers(); // Check const std::vector expected{0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 9, 10}; BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); BOOST_TEST(ctx.next_header_offset() == 24u); + BOOST_TEST(ctx.error() == error_code()); } // Spotcheck: adding single bytes or in chunks also works fine -BOOST_AUTO_TEST_CASE(add_frame_headers_chunks) +BOOST_AUTO_TEST_CASE(chunks) { // Setup std::vector buff; - detail::serialization_context ctx(buff, 8); - - // Add data + detail::serialization_context ctx(buff, 0xffff, 8); const std::array payload1{ {1, 2, 3, 4} }; const std::array payload2{ {5, 6, 7, 8, 9} }; - ctx.add(0xff); - ctx.add(payload1); - ctx.add(0xfe); - ctx.add(payload2); - ctx.add(0xfc); - ctx.add_frame_headers(); - // Check - const std::vector expected{0, 0, 0, 0, 0xff, 1, 2, 3, 4, 0xfe, - 5, 6, 0, 0, 0, 0, 7, 8, 9, 0xfc}; + // Add byte + ctx.add(0xff); + std::vector expected{0, 0, 0, 0, 0xff}; BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); - BOOST_TEST(ctx.next_header_offset() == 24u); -} - -// add_checked should behave like add + write_frame_headers -BOOST_AUTO_TEST_CASE(add_checked) -{ - constexpr std::size_t fs = 8u; // frame size - const std::vector initial_buffer{0xaa, 0xbb, 0xcc, 0xdd, 0xee}; - - for (const auto& tc : make_test_cases()) - { - BOOST_TEST_CONTEXT(tc.name) - { - // Setup - std::vector buff{initial_buffer}; - detail::serialization_context ctx(buff, fs); - - // Add payload and set headers - ctx.add_checked(tc.payload); + BOOST_TEST(ctx.error() == error_code()); - // Check - auto expected = test::concat_copy(initial_buffer, tc.expected_buffer); - BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); - BOOST_TEST(ctx.next_header_offset() == tc.expected_next_frame_offset + initial_buffer.size()); - } - } -} - -// Spotcheck: add_checked should work fine if the initial buffer is empty -BOOST_AUTO_TEST_CASE(add_checked_initial_buffer_empty) -{ - // Setup - std::vector buff; - detail::serialization_context ctx(buff, 8); + // Add buffer + ctx.add(payload1); + expected = {0, 0, 0, 0, 0xff, 1, 2, 3, 4}; + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); + BOOST_TEST(ctx.error() == error_code()); - // Add payload and set headers - const std::array payload{ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - }; - ctx.add_checked(payload); + // Add byte + ctx.add(0xfe); + expected = {0, 0, 0, 0, 0xff, 1, 2, 3, 4, 0xfe}; + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); + BOOST_TEST(ctx.error() == error_code()); - // Check - const std::vector expected{0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 9, 10}; + // Add buffer + ctx.add(payload2); + expected = {0, 0, 0, 0, 0xff, 1, 2, 3, 4, 0xfe, 5, 6, 0, 0, 0, 0, 7, 8, 9}; BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); + BOOST_TEST(ctx.error() == error_code()); + + // Add byte + ctx.add(0xfc); + expected = {0, 0, 0, 0, 0xff, 1, 2, 3, 4, 0xfe, 5, 6, 0, 0, 0, 0, 7, 8, 9, 0xfc}; BOOST_TEST(ctx.next_header_offset() == 24u); + BOOST_TEST(ctx.error() == error_code()); } -// If there are any missing frame headers when add_checked is called, -// they are inserted -BOOST_AUTO_TEST_CASE(add_checked_missing_frame_headers) +// Spotcheck: adding a single byte that causes a frame header to be written works +BOOST_AUTO_TEST_CASE(add_byte_fills_frame) { // Setup std::vector buff; - detail::serialization_context ctx(buff, 8); - - // Create some missing frame headers by using unchecked add - const std::array payload1{ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} + detail::serialization_context ctx(buff, 0xffff, 8); + const std::array payload{ + {1, 2, 3, 4, 5, 6, 7} }; - ctx.add(payload1); - - // Add (checked) some data - const std::array payload2{ - {21, 22, 23, 24, 25, 26, 27, 28, 29, 30} - }; - ctx.add_checked(payload2); - // Check - const std::vector expected{0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, - 9, 10, 11, 12, 13, 14, 15, 16, 0, 0, 0, 0, 17, 18, 19, 20, - 21, 22, 23, 24, 0, 0, 0, 0, 25, 26, 27, 28, 29, 30}; + // Add payload + ctx.add(payload); + std::vector expected{0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7}; BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); - BOOST_TEST(ctx.next_header_offset() == 48u); -} - -// Same as above, but what we insert via add_checked is not enough to fill a frame -BOOST_AUTO_TEST_CASE(add_checked_missing_frame_headers_small_payload) -{ - // Setup - std::vector buff; - detail::serialization_context ctx(buff, 8); - - // Create some missing frame headers by using unchecked add - const std::array payload1{ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} - }; - ctx.add(payload1); - - // Add (checked) some data - const std::array payload2{ - {21, 22} - }; - ctx.add_checked(payload2); + BOOST_TEST(ctx.next_header_offset() == 12u); + BOOST_TEST(ctx.error() == error_code()); - // Check - const std::vector expected{0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 9, - 10, 11, 12, 13, 14, 15, 16, 0, 0, 0, 0, 17, 18, 19, 20, 21, 22}; + // Add byte + ctx.add(0xab); + expected = {0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 0xab, 0, 0, 0, 0}; BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); - BOOST_TEST(ctx.next_header_offset() == 36u); + BOOST_TEST(ctx.next_header_offset() == 24u); + BOOST_TEST(ctx.error() == error_code()); } BOOST_AUTO_TEST_CASE(write_frame_headers) @@ -269,14 +217,15 @@ BOOST_AUTO_TEST_CASE(write_frame_headers) { // Setup std::vector buff{initial_buffer}; - detail::serialization_context ctx(buff, 8); - ctx.add_checked(tc.payload); + detail::serialization_context ctx(buff, 0xffff, 8); + ctx.add(tc.payload); // Call and check - auto seqnum = ctx.write_frame_headers(42); + auto seqnum = ctx.write_frame_headers(42, initial_buffer.size()); const auto expected = test::concat_copy(initial_buffer, tc.expected); BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); BOOST_TEST(seqnum == tc.expected_seqnum); + BOOST_TEST(ctx.error() == error_code()); } } } @@ -286,7 +235,7 @@ BOOST_AUTO_TEST_CASE(write_frame_headers_seqnum_wrap) { // Setup std::vector buff; - detail::serialization_context ctx(buff, 8); + detail::serialization_context ctx(buff, 0xffff, 8); for (std::uint8_t i = 1; i <= 20; ++i) ctx.add(i); @@ -296,9 +245,10 @@ BOOST_AUTO_TEST_CASE(write_frame_headers_seqnum_wrap) 8, 0, 0, 0xff, 9, 10, 11, 12, 13, 14, 15, 16, // frame 2 4, 0, 0, 0, 17, 18, 19, 20 // frame 3 }; - auto seqnum = ctx.write_frame_headers(0xfe); + auto seqnum = ctx.write_frame_headers(0xfe, 0); BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); BOOST_TEST(seqnum == 1u); + BOOST_TEST(ctx.error() == error_code()); } // Spotcheck: disable framing works @@ -306,7 +256,7 @@ BOOST_AUTO_TEST_CASE(disable_framing) { // Setup std::vector buff; - detail::serialization_context ctx(buff, detail::disable_framing); + detail::serialization_context ctx(buff, 0xffff, detail::disable_framing); // Add data using the several functions available const std::array payload1{ @@ -317,11 +267,180 @@ BOOST_AUTO_TEST_CASE(disable_framing) }; ctx.add(42); ctx.add(payload1); - ctx.add_checked(payload2); + ctx.add(payload2); // We didn't add any framing const std::vector expected{42, 1, 2, 3, 4, 5, 6, 7, 8, 9}; BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); + BOOST_TEST(ctx.error() == error_code()); +} + +BOOST_AUTO_TEST_SUITE(max_buffer_size_error) + +BOOST_AUTO_TEST_CASE(header_exceeds_maxsize) +{ + // Setup + std::vector buff; + detail::serialization_context ctx(buff, 3u); + + // Buffer can't hold the header + BOOST_TEST(ctx.error() == client_errc::max_buffer_size_exceeded); +} + +BOOST_AUTO_TEST_CASE(contents_exceed_maxsize) +{ + // Setup + std::vector buff; + detail::serialization_context ctx(buff, 8u); + + // Header plus content would exceed max size + ctx.add(std::vector{1, 2, 3, 4, 5, 6}); + BOOST_TEST(ctx.error() == client_errc::max_buffer_size_exceeded); + + // Only header written + std::array expected{}; + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); +} + +BOOST_AUTO_TEST_CASE(subsequent_header_exceeds_maxsize) +{ + // Setup + std::vector buff; + detail::serialization_context ctx(buff, 13u, 8u); + + // Successfully add some data + ctx.add(std::vector{1, 2, 3, 4, 5, 6}); + std::vector expected{0, 0, 0, 0, 1, 2, 3, 4, 5, 6}; + BOOST_TEST(ctx.error() == error_code()); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); + + // Add data triggering a header that can't fit + ctx.add(std::vector{7, 8}); + BOOST_TEST(ctx.error() == client_errc::max_buffer_size_exceeded); +} + +BOOST_AUTO_TEST_CASE(maxsize_zero) +{ + // Setup + std::vector buff; + detail::serialization_context ctx(buff, 0u); + + // Buffer can't hold the header + BOOST_TEST(ctx.error() == client_errc::max_buffer_size_exceeded); +} + +BOOST_AUTO_TEST_CASE(error_by_one_byte) +{ + // Setup + std::vector buff; + detail::serialization_context ctx(buff, 12u); + + // Successfully add data until max size + ctx.add(std::vector{1, 2, 3, 4, 5, 6, 7, 8}); + std::vector expected{0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + BOOST_TEST(ctx.error() == error_code()); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); + + // Adding more data fails. No data is written to the buffer + ctx.add(std::vector{1}); + BOOST_TEST(ctx.error() == client_errc::max_buffer_size_exceeded); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); +} + +// Spotcheck: adding a single byte triggers the same behavior +BOOST_AUTO_TEST_CASE(error_add_u8) +{ + // Setup + std::vector buff; + detail::serialization_context ctx(buff, 12u); + + // Successfully add data until max size + ctx.add(std::vector{1, 2, 3, 4, 5, 6, 7, 8}); + std::vector expected{0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + BOOST_TEST(ctx.error() == error_code()); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); + + // Adding more data fails. No data is written to the buffer + ctx.add(42); + BOOST_TEST(ctx.error() == client_errc::max_buffer_size_exceeded); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); +} + +// Edge case: if the input buffer exceeded the max size, we fail +BOOST_AUTO_TEST_CASE(buffer_exceeds_max_size) +{ + std::vector buff(48u, 0); + detail::serialization_context ctx(buff, 12u); + BOOST_TEST(ctx.error() == client_errc::max_buffer_size_exceeded); +} + +// Previous contents are taken into account for size checks +BOOST_AUTO_TEST_CASE(buffer_with_previous_contents) +{ + // Setup + std::vector buff{1, 2, 3}; + detail::serialization_context ctx(buff, 8u); + + // Just max size + ctx.add(42); + std::vector expected{1, 2, 3, 0, 0, 0, 0, 42}; + BOOST_TEST(ctx.error() == error_code()); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); + + // Past max size + ctx.add(80); + BOOST_TEST(ctx.error() == client_errc::max_buffer_size_exceeded); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); +} + +BOOST_AUTO_TEST_CASE(several_errors) +{ + // Setup + std::vector buff; + detail::serialization_context ctx(buff, 12u); + + // Successfully add some data + ctx.add(std::vector{1, 2, 3, 4, 5}); + std::vector expected{0, 0, 0, 0, 1, 2, 3, 4, 5}; + BOOST_TEST(ctx.error() == error_code()); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); + + // Adding more data fails. No data is written to the buffer + ctx.add(std::vector{6, 7, 8, 9}); + BOOST_TEST(ctx.error() == client_errc::max_buffer_size_exceeded); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); + + // Adding more data again does nothing + ctx.add(std::vector{10, 11, 12}); + BOOST_TEST(ctx.error() == client_errc::max_buffer_size_exceeded); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); } +BOOST_AUTO_TEST_CASE(success_after_error) +{ + // Setup + std::vector buff; + detail::serialization_context ctx(buff, 12u); + + // Successfully add some data + ctx.add(std::vector{1, 2, 3, 4, 5}); + std::vector expected{0, 0, 0, 0, 1, 2, 3, 4, 5}; + BOOST_TEST(ctx.error() == error_code()); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); + + // Adding more data fails. No data is written to the buffer + ctx.add(std::vector{6, 7, 8, 9}); + BOOST_TEST(ctx.error() == client_errc::max_buffer_size_exceeded); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); + + // Adding more data again does nothing, even if the data would fit + ctx.add(1); + BOOST_TEST(ctx.error() == client_errc::max_buffer_size_exceeded); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(buff, expected); +} + +BOOST_AUTO_TEST_SUITE_END() + BOOST_AUTO_TEST_SUITE_END() + +} // namespace \ No newline at end of file diff --git a/test/unit/test/protocol/serialization_test.hpp b/test/unit/test/protocol/serialization_test.hpp index 6c0b3ef5e..105f1aa78 100644 --- a/test/unit/test/protocol/serialization_test.hpp +++ b/test/unit/test/protocol/serialization_test.hpp @@ -20,6 +20,7 @@ #include #include "test_common/assert_buffer_equals.hpp" +#include "test_unit/serialize_to_vector.hpp" namespace boost { namespace mysql { @@ -63,21 +64,13 @@ class deserialization_buffer }; template -void do_serialize_test( - T value, - span expected, - std::size_t frame_size = detail::disable_framing -) +void do_serialize_test(T value, span expected) { - // Setup - std::vector buffer; - detail::serialization_context ctx(buffer, frame_size); - // Serialize - value.serialize(ctx); + auto actual = serialize_to_vector([&](detail::serialization_context& ctx) { value.serialize(ctx); }); // Check - BOOST_MYSQL_ASSERT_BUFFER_EQUALS(expected, buffer); + BOOST_MYSQL_ASSERT_BUFFER_EQUALS(expected, actual); } template diff --git a/test/unit/test/sansio/prepare_statement.cpp b/test/unit/test/sansio/prepare_statement.cpp index 5cb2b81ee..32744bb1a 100644 --- a/test/unit/test/sansio/prepare_statement.cpp +++ b/test/unit/test/sansio/prepare_statement.cpp @@ -13,13 +13,9 @@ #include -#include -#include - #include "test_unit/algo_test.hpp" #include "test_unit/create_coldef_frame.hpp" #include "test_unit/create_err.hpp" -#include "test_unit/create_frame.hpp" #include "test_unit/create_meta.hpp" #include "test_unit/create_prepare_statement_response.hpp" #include "test_unit/create_query_frame.hpp" @@ -162,6 +158,9 @@ struct prepare_fixture : algo_fixture_base { detail::prepare_statement_algo algo{diag, {"SELECT 1"}}; + prepare_fixture() = default; + prepare_fixture(std::size_t max_bufsize) : algo_fixture_base(max_bufsize) {} + statement result() const { return algo.result(st); } }; @@ -184,16 +183,6 @@ BOOST_AUTO_TEST_CASE(prepare_success) BOOST_TEST(stmt.num_params() == 2u); } -BOOST_AUTO_TEST_CASE(prepare_network_error) -{ - // This covers errors in the request and the response - algo_test() - .expect_write(create_prepare_statement_frame(0, "SELECT 1")) - .expect_read(prepare_stmt_response_builder().seqnum(1).id(29).num_columns(0).num_params(1).build()) - .expect_read(create_coldef_frame(2, meta_builder().name("abc").build_coldef())) - .check_network_errors(); -} - // Spotcheck: an error while reading the response is propagated correctly BOOST_AUTO_TEST_CASE(prepare_error_packet) { @@ -211,4 +200,23 @@ BOOST_AUTO_TEST_CASE(prepare_error_packet) .check(fix, common_server_errc::er_bad_db_error, create_server_diag("my_message")); } +BOOST_AUTO_TEST_CASE(prepare_network_error) +{ + // This covers errors in the request and the response + algo_test() + .expect_write(create_prepare_statement_frame(0, "SELECT 1")) + .expect_read(prepare_stmt_response_builder().seqnum(1).id(29).num_columns(0).num_params(1).build()) + .expect_read(create_coldef_frame(2, meta_builder().name("abc").build_coldef())) + .check_network_errors(); +} + +BOOST_AUTO_TEST_CASE(prepare_error_max_buffer_size) +{ + // Setup + prepare_fixture fix(10u); + + // Run the algo + algo_test().check(fix, client_errc::max_buffer_size_exceeded); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/test/unit/test/sansio/set_character_set.cpp b/test/unit/test/sansio/set_character_set.cpp index b49349091..7a4a5586a 100644 --- a/test/unit/test/sansio/set_character_set.cpp +++ b/test/unit/test/sansio/set_character_set.cpp @@ -140,6 +140,10 @@ struct set_charset_fixture : algo_fixture_base detail::set_character_set_algo algo; set_charset_fixture(character_set charset = utf8mb4_charset) : algo(diag, {charset}) {} + set_charset_fixture(std::size_t max_bufsize) + : algo_fixture_base(max_bufsize), algo(diag, {utf8mb4_charset}) + { + } }; BOOST_AUTO_TEST_CASE(set_charset_success) @@ -192,4 +196,13 @@ BOOST_AUTO_TEST_CASE(set_charset_error_network) .check_network_errors(); } +BOOST_AUTO_TEST_CASE(set_charset_error_max_buffer_size) +{ + // Setup + set_charset_fixture fix(16u); + + // Run the algo + algo_test().check(fix, client_errc::max_buffer_size_exceeded); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/test/unit/test/sansio/start_execution.cpp b/test/unit/test/sansio/start_execution.cpp index c22a651bf..c4d2d0612 100644 --- a/test/unit/test/sansio/start_execution.cpp +++ b/test/unit/test/sansio/start_execution.cpp @@ -19,6 +19,8 @@ #include +#include + #include "test_common/check_meta.hpp" #include "test_common/create_basic.hpp" #include "test_unit/algo_test.hpp" @@ -41,7 +43,13 @@ struct fixture : algo_fixture_base mock_execution_processor proc; detail::start_execution_algo algo; - fixture(any_execution_request req) : algo(diag, {req, &proc}) {} + fixture( + any_execution_request req = any_execution_request("SELECT 1"), + std::size_t max_bufsize = default_max_buffsize + ) + : algo_fixture_base(max_bufsize), algo(diag, {req, &proc}) + { + } }; BOOST_AUTO_TEST_CASE(text_query) @@ -109,19 +117,23 @@ BOOST_AUTO_TEST_CASE(error_num_params) // This covers errors in both writing the request and calling read_resultset_head BOOST_AUTO_TEST_CASE(error_network_error) { - // check_network_errors() requires F to be default-constructible - struct query_fixture : fixture - { - query_fixture() : fixture(any_execution_request("SELECT 1")) {} - }; - // This will test for errors writing the execution request // and reading the response and metadata (thus, calling read_resultset_head) algo_test() .expect_write(create_frame(0, {0x03, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x31})) .expect_read(create_frame(1, {0x01})) .expect_read(create_coldef_frame(2, meta_builder().type(column_type::varchar).build_coldef())) - .check_network_errors(); + .check_network_errors(); +} + +BOOST_AUTO_TEST_CASE(error_max_buffer_size) +{ + // Setup + std::string query(512, 'a'); + fixture fix(any_execution_request(query), 512u); + + // Run the algo + algo_test().check(fix, client_errc::max_buffer_size_exceeded); } BOOST_AUTO_TEST_SUITE_END() diff --git a/test/unit/test/sansio/top_level_algo.cpp b/test/unit/test/sansio/top_level_algo.cpp index 512c326db..1dcdc3a50 100644 --- a/test/unit/test/sansio/top_level_algo.cpp +++ b/test/unit/test/sansio/top_level_algo.cpp @@ -379,37 +379,6 @@ BOOST_AUTO_TEST_CASE(write_max_buffer_size_exact) BOOST_TEST(act.success()); } -BOOST_AUTO_TEST_CASE(write_max_buffer_size_exceeded) -{ - struct mock_algo - { - coroutine coro; - std::uint8_t seqnum{}; - const std::array long_msg{}; - - next_action resume(connection_state_data& st, error_code ec) - { - BOOST_ASIO_CORO_REENTER(coro) - { - BOOST_TEST(ec == error_code()); - BOOST_ASIO_CORO_YIELD return st.write(mock_message{long_msg}, seqnum); - BOOST_TEST(ec == client_errc::max_buffer_size_exceeded); - } - return next_action(); - } - }; - - connection_state_data st(32, 64); - top_level_algo algo(st); - - // Initial run yields a write request that exceeds the max buffer size. - // We never get to see such request, it generates an immediate failure. - auto act = algo.resume(error_code(), 0); - - // Done - BOOST_TEST(act.success()); -} - BOOST_AUTO_TEST_CASE(write_ssl_active) { struct mock_algo