Skip to content

Commit

Permalink
Writes exceeding max buffer size no longer resize the buffer before f…
Browse files Browse the repository at this point in the history
…ailing

close #297
  • Loading branch information
anarthal authored Jul 26, 2024
1 parent 820e10e commit b0c2639
Show file tree
Hide file tree
Showing 26 changed files with 642 additions and 452 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ inline void serialize_binary_float(serialization_context& ctx, T input)

inline void serialize_binary_date(serialization_context& ctx, const date& input)
{
ctx.serialize(
ctx.serialize_fixed(
int1{static_cast<std::uint8_t>(binc::date_sz)},
int2{input.year()},
int1{input.month()},
Expand All @@ -336,7 +336,7 @@ inline void serialize_binary_date(serialization_context& ctx, const date& input)

inline void serialize_binary_datetime(serialization_context& ctx, const datetime& input)
{
ctx.serialize(
ctx.serialize_fixed(
int1{static_cast<std::uint8_t>(binc::datetime_dhmsu_sz)},
int2{input.year()},
int1{input.month()},
Expand Down Expand Up @@ -367,7 +367,7 @@ inline void serialize_binary_time(serialization_context& ctx, const boost::mysql
auto is_negative = (input.count() < 0) ? 1 : 0;

// Serialize
ctx.serialize(
ctx.serialize_fixed(
int1{static_cast<std::uint8_t>(time_dhmsu_sz)},
int1{static_cast<std::uint8_t>(is_negative)},
int4{static_cast<std::uint32_t>(std::abs(num_days.count()))},
Expand Down Expand Up @@ -431,8 +431,8 @@ void boost::mysql::detail::serialize_binary_field(serialization_context& ctx, fi
case field_kind::null: break;
case field_kind::int64: sint8{input.get_int64()}.serialize(ctx); break;
case field_kind::uint64: int8{input.get_uint64()}.serialize(ctx); break;
case field_kind::string: string_lenenc{input.get_string()}.serialize_checked(ctx); break;
case field_kind::blob: string_lenenc{to_string(input.get_blob())}.serialize_checked(ctx); break;
case field_kind::string: string_lenenc{input.get_string()}.serialize(ctx); break;
case field_kind::blob: string_lenenc{to_string(input.get_blob())}.serialize(ctx); break;
case field_kind::float_: serialize_binary_float(ctx, input.get_float()); break;
case field_kind::double_: serialize_binary_float(ctx, input.get_double()); break;
case field_kind::date: serialize_binary_date(ctx, input.get_date()); break;
Expand Down
54 changes: 28 additions & 26 deletions include/boost/mysql/impl/internal/protocol/impl/protocol_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@ struct int_holder
{
IntType value;

void serialize(serialization_context& ctx) const
// This is a fixed-size type
static constexpr std::size_t size = sizeof(IntType);

void serialize_fixed(std::uint8_t* to) const
{
std::array<std::uint8_t, sizeof(IntType)> buffer{};
endian::endian_store<IntType, sizeof(IntType), endian::order::little>(buffer.data(), value);
ctx.add(buffer);
endian::endian_store<IntType, sizeof(IntType), endian::order::little>(to, value);
}

void serialize(serialization_context& ctx) const { ctx.serialize_fixed(*this); }

deserialize_errc deserialize(deserialization_context& ctx)
{
constexpr std::size_t sz = sizeof(IntType);
Expand All @@ -61,12 +64,12 @@ struct int3
{
std::uint32_t value;

void serialize(serialization_context& ctx) const
{
std::array<std::uint8_t, 3> buffer;
endian::store_little_u24(buffer.data(), value);
ctx.add(buffer);
}
// This is a fixed-size type
static constexpr std::size_t size = 3u;

void serialize_fixed(std::uint8_t* to) const { endian::store_little_u24(to, value); }

void serialize(serialization_context& ctx) const { ctx.serialize_fixed(*this); }

deserialize_errc deserialize(deserialization_context& ctx)
{
Expand All @@ -90,18 +93,21 @@ struct int_lenenc
}
else if (value < 0x10000)
{
ctx.add(static_cast<std::uint8_t>(0xfc));
int2{static_cast<std::uint16_t>(value)}.serialize(ctx);
ctx.serialize_fixed(
int1{static_cast<std::uint8_t>(0xfc)},
int2{static_cast<std::uint16_t>(value)}
);
}
else if (value < 0x1000000)
{
ctx.add(static_cast<std::uint8_t>(0xfd));
int3{static_cast<std::uint32_t>(value)}.serialize(ctx);
ctx.serialize_fixed(
int1{static_cast<std::uint8_t>(0xfd)},
int3{static_cast<std::uint32_t>(value)}
);
}
else
{
ctx.add(static_cast<std::uint8_t>(0xfe));
int8{value}.serialize(ctx);
ctx.serialize_fixed(int1{static_cast<std::uint8_t>(0xfe)}, int8{value});
}
}

Expand Down Expand Up @@ -178,7 +184,6 @@ struct string_eof
}

void serialize(serialization_context& ctx) const { ctx.add(to_span(value)); }
void serialize_checked(serialization_context& ctx) const { ctx.add_checked(to_span(value)); }
};

struct string_lenenc
Expand Down Expand Up @@ -213,22 +218,19 @@ struct string_lenenc
ctx.serialize(int_lenenc{value.size()});
ctx.add(to_span(value));
}
void serialize_checked(serialization_context& ctx) const
{
ctx.serialize(int_lenenc{value.size()});
ctx.add_checked(to_span(value));
}
};

template <std::size_t N>
struct string_fixed
{
std::array<char, N> value;

void serialize(serialization_context& ctx) const
{
ctx.add({reinterpret_cast<const std::uint8_t*>(value.data()), N});
}
// This is a fixed size type
static constexpr std::size_t size = N;

void serialize_fixed(std::uint8_t* to) const { std::memcpy(to, value.data(), N); }

void serialize(serialization_context& ctx) const { ctx.serialize_fixed(*this); }

deserialize_errc deserialize(deserialization_context& ctx)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
#ifndef BOOST_MYSQL_IMPL_INTERNAL_PROTOCOL_IMPL_SERIALIZATION_CONTEXT_HPP
#define BOOST_MYSQL_IMPL_INTERNAL_PROTOCOL_IMPL_SERIALIZATION_CONTEXT_HPP

#include <boost/mysql/client_errc.hpp>
#include <boost/mysql/error_code.hpp>

#include <boost/mysql/impl/internal/protocol/frame_header.hpp>

#include <boost/assert.hpp>
#include <boost/core/ignore_unused.hpp>
#include <boost/core/span.hpp>
#include <boost/endian/conversion.hpp>

#include <algorithm>
Expand All @@ -31,25 +35,44 @@ BOOST_INLINE_CONSTEXPR std::size_t disable_framing = static_cast<std::size_t>(-1
// of frame headers in serialization functions creates messages ready to send.
// We require the entire message to be created before it's sent, so we don't lose any functionality.
//
// This class knows the offset of the next frame header. Adding data with add_checked will correctly
// insert space for headers as required while copying the data. Other functions may result in
// "overruns" (writing past the offset of the next header). Overruns are fixed by add_frame_headers,
// which will memmove data as required. The distinction is made for efficiency.
// This class knows the offset of the next frame header. Adding data will correctly
// insert space for headers as required while copying the data.
//
// Like format_context_base, contains an error that can be set if a serialization
// function fails (e.g. because it would overrun the buffer size limit).
// Once set, serializing is a no-op. This pattern allows us to check for errors just once.
class serialization_context
{
std::vector<std::uint8_t>& buffer_;
std::size_t initial_offset_;
std::size_t max_buffer_size_;
std::size_t max_frame_size_;
std::size_t next_header_offset_{};
std::size_t next_header_offset_;
error_code err_;

// Tells whether frame headers should be accounted for while serializing.
// Framing is switched off by setting max_frame_size_ to the disable_framing
// sentinel (i.e. -1), which is used for testing.
bool framing_enabled() const
{
    const bool framing_disabled = (max_frame_size_ == disable_framing);
    return !framing_disabled;
}

void add_checked_impl(span<const std::uint8_t> content)
void append_to_buffer(span<const std::uint8_t> contents)
{
// Add any required frame headers we didn't add until now
add_frame_headers();
// Do nothing if we previously encountered an error
if (err_)
return;

// Check if the buffer has space for the given contents
if (buffer_.size() + contents.size() > max_buffer_size_)
{
err_ = client_errc::max_buffer_size_exceeded;
return;
}

// Copy
buffer_.insert(buffer_.end(), contents.begin(), contents.end());
}

// Makes room for one frame header by appending frame_header_size
// zero-valued bytes through append_to_buffer.
void append_header()
{
    const std::array<std::uint8_t, frame_header_size> placeholder{};
    append_to_buffer(placeholder);
}

void add_impl(span<const std::uint8_t> content)
{
// Add the content in chunks, inserting space for headers where required
std::size_t content_offset = 0;
while (content_offset < content.size())
Expand All @@ -59,88 +82,73 @@ class serialization_context
auto remaining_content = static_cast<std::size_t>(content.size() - content_offset);
auto remaining_frame = static_cast<std::size_t>(next_header_offset_ - buffer_.size());
auto size_to_write = (std::min)(remaining_content, remaining_frame);
buffer_.insert(
buffer_.end(),
content.data() + content_offset,
content.data() + content_offset + size_to_write
);
append_to_buffer(content.subspan(content_offset, size_to_write));
content_offset += size_to_write;

// Insert space for a frame header if required
if (buffer_.size() == next_header_offset_)
{
buffer_.resize(buffer_.size() + 4);
append_header();
next_header_offset_ += (max_frame_size_ + frame_header_size);
}
}
}

public:
serialization_context(std::vector<std::uint8_t>& buff, std::size_t max_frame_size = max_packet_size)
: buffer_(buff), initial_offset_(buffer_.size()), max_frame_size_(max_frame_size)
template <class Serializable, class... Rest>
static constexpr std::size_t fixed_total_size(Serializable, Rest... rest)
{
// Add space for the initial header
if (framing_enabled())
{
buffer_.resize(buffer_.size() + frame_header_size);
next_header_offset_ = initial_offset_ + max_frame_size_ + frame_header_size;
}
return Serializable::size + fixed_total_size(rest...);
}

// Exposed for testing
std::size_t next_header_offset() const { return next_header_offset_; }

// To be called by serialize() functions.
// Appends a single byte to the buffer. Doesn't take framing into account.
void add(std::uint8_t value) { buffer_.push_back(value); }
static constexpr std::size_t fixed_total_size() { return 0u; }

// To be called by serialize() functions. Appends bytes to the buffer.
// Doesn't take framing into account - use for payloads with bound size.
void add(span<const std::uint8_t> content)
template <class Serializable, class... Rest>
static void serialize_fixed_impl(std::uint8_t* it, Serializable serializable, Rest... rest)
{
buffer_.insert(buffer_.end(), content.begin(), content.end());
serializable.serialize_fixed(it);
serialize_fixed_impl(it + Serializable::size, rest...);
}

// Like add, but takes framing into account. Use for potentially long payloads.
// If the payload is very long, space for frame headers will be added as required,
// avoiding expensive memmove's when calling add_frame_headers
void add_checked(span<const std::uint8_t> content)
static void serialize_fixed_impl(std::uint8_t*) {}

public:
serialization_context(
std::vector<std::uint8_t>& buff,
std::size_t max_buffer_size = static_cast<std::size_t>(-1),
std::size_t max_frame_size = max_packet_size
)
: buffer_(buff),
max_buffer_size_(max_buffer_size),
max_frame_size_(max_frame_size),
next_header_offset_(
framing_enabled() ? buffer_.size() + max_frame_size_ + frame_header_size
: static_cast<std::size_t>(-1)
)
{
// Add space for the initial header
if (framing_enabled())
{
add_checked_impl(content);
}
else
{
add(content);
}
append_header();
}

// Inserts any missing space for frame headers, moving data as required.
// Exposed for testing
void add_frame_headers()
{
while (next_header_offset_ <= buffer_.size())
{
// Insert space for the frame header where needed
const std::array<std::uint8_t, frame_header_size> placeholder{};
buffer_.insert(buffer_.begin() + next_header_offset_, placeholder.begin(), placeholder.end());
std::size_t next_header_offset() const { return next_header_offset_; }

// Update the next frame header offset
next_header_offset_ += (max_frame_size_ + frame_header_size);
}
}
// Appends a single byte, forwarding it to add_impl as a one-element range.
void add(std::uint8_t value)
{
    const std::uint8_t* first = &value;
    add_impl({first, 1});
}

// To be called by serialize() functions. Appends bytes to the buffer.
void add(span<const std::uint8_t> content) { add_impl(content); }

error_code error() const { return err_; }

// Write frame headers to an already serialized message with space for them
std::uint8_t write_frame_headers(std::uint8_t seqnum)
std::uint8_t write_frame_headers(std::uint8_t seqnum, std::size_t initial_offset)
{
BOOST_ASSERT(framing_enabled());

// Add any missing space for headers
add_frame_headers();
BOOST_ASSERT(!err_);
BOOST_ASSERT(initial_offset < buffer_.size());

// Actually write the headers
std::size_t offset = initial_offset_;
std::size_t offset = initial_offset;
while (offset < buffer_.size())
{
// Calculate the current frame size
Expand All @@ -165,6 +173,17 @@ class serialization_context
return seqnum;
}

// Fast path for types whose serialized size is known at compile time.
// All the passed objects are first written, back to back, into a local
// std::array sized to hold exactly their combined length; that array is
// then handed to add() in a single call. This saves reallocations and
// space checks compared to adding each object separately.
template <class... Serializable>
void serialize_fixed(Serializable... s)
{
    // Combined size of every argument, computed at compile time
    constexpr std::size_t total_size = fixed_total_size(Serializable{}...);

    // Intentionally left uninitialized: every byte is written below
    std::array<std::uint8_t, total_size> storage;
    serialize_fixed_impl(storage.data(), s...);

    add(storage);
}

// Allow chaining
template <class... Serializable>
void serialize(Serializable... s)
Expand Down
Loading

0 comments on commit b0c2639

Please sign in to comment.