Skip to content

Commit

Permalink
serialization now uses new fail functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
anarthal committed Jul 26, 2024
1 parent c2844bc commit 6eb384a
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,13 @@ class serialization_context

void append_to_buffer(span<const std::uint8_t> contents)
{
// Do nothing if we previously encountered an error
if (err_)
return;

// Check if the buffer has space for the given contents
if (buffer_.size() + contents.size() > max_buffer_size_)
{
err_ = client_errc::max_buffer_size_exceeded;
return;
}
add_error(client_errc::max_buffer_size_exceeded);

// Copy
buffer_.insert(buffer_.end(), contents.begin(), contents.end());
// Copy if there was no error
if (!err_)
buffer_.insert(buffer_.end(), contents.begin(), contents.end());
}

void append_header() { append_to_buffer(std::array<std::uint8_t, frame_header_size>{}); }
Expand Down Expand Up @@ -138,6 +132,13 @@ class serialization_context
// To be called by serialize() functions. Appends bytes to the buffer.
void add(span<const std::uint8_t> content) { add_impl(content); }

// Sets the error state. TODO: unit test
void add_error(error_code ec)
{
if (!err_)
err_ = ec;
}

error_code error() const { return err_; }

// Write frame headers to an already serialized message with space for them
Expand Down
43 changes: 17 additions & 26 deletions include/boost/mysql/impl/internal/protocol/query_with_params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <boost/mysql/impl/internal/protocol/impl/serialization_context.hpp>

#include <boost/core/span.hpp>
#include <boost/system/result.hpp>

#include <cstdint>

Expand All @@ -30,9 +29,7 @@ class external_format_context : public format_context_base
{
static void do_append(void* obj, const char* data, std::size_t size)
{
static_cast<serialization_context*>(obj)->add_checked(
{reinterpret_cast<const std::uint8_t*>(data), size}
);
static_cast<serialization_context*>(obj)->add({reinterpret_cast<const std::uint8_t*>(data), size});
}

public:
Expand All @@ -42,33 +39,27 @@ class external_format_context : public format_context_base
}
};

// TODO: can we make serialization able to fail in the general case?
inline system::result<std::uint8_t> serialize_query_with_params(
std::vector<std::uint8_t>& to,
constant_string_view query,
span<const format_arg> args,
format_options opts,
std::size_t frame_size = max_packet_size
)
struct query_with_params
{
// Create a serialization context
serialization_context ctx(to, frame_size);
constant_string_view query;
span<const format_arg> args;
format_options opts;

// Serialize the query header
ctx.add(0x03);
void serialize(serialization_context& ctx) const
{
// Create a format context
external_format_context fmt_ctx(ctx, opts);

// Create a format context and serialize the actual query
external_format_context fmt_ctx(ctx, opts);
vformat_sql_to(fmt_ctx, query, args);
// Serialize the query header
ctx.add(0x03);

// Check for errors
auto err = fmt_ctx.error_state();
if (err)
return err;
// Serialize the actual query
vformat_sql_to(fmt_ctx, query, args);

// Write frame headers
return ctx.write_frame_headers(0);
}
// Check for errors
ctx.add_error(fmt_ctx.error_state());
}
};

} // namespace detail
} // namespace mysql
Expand Down
23 changes: 4 additions & 19 deletions include/boost/mysql/impl/internal/sansio/start_execution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#ifndef BOOST_MYSQL_IMPL_INTERNAL_SANSIO_START_EXECUTION_HPP
#define BOOST_MYSQL_IMPL_INTERNAL_SANSIO_START_EXECUTION_HPP

#include <boost/mysql/character_set.hpp>
#include <boost/mysql/client_errc.hpp>
#include <boost/mysql/diagnostics.hpp>
#include <boost/mysql/error_code.hpp>
Expand Down Expand Up @@ -55,31 +56,15 @@ class start_execution_algo
any_execution_request::data_t::query_with_params_t data
)
{
// TODO: this should be expressible in terms of st.write()
// Determine format options
if (st.current_charset.name == nullptr)
{
return error_code(client_errc::unknown_character_set);
}
format_options opts{st.current_charset, st.backslash_escapes};

// Clear the write buffer
st.write_buffer.clear();

// Serialize
auto res = serialize_query_with_params(
st.write_buffer,
data.query,
data.args,
format_options{st.current_charset, st.backslash_escapes}
);

// Check for errors
if (res.has_error())
return res.error();

// Done
seqnum() = *res;
return next_action::write({st.write_buffer, false});
// Write the request
return st.write(query_with_params{data.query, data.args, opts}, seqnum());
}

next_action write_stmt(connection_state_data& st, any_execution_request::data_t::stmt_t data)
Expand Down

0 comments on commit 6eb384a

Please sign in to comment.