Skip to content

Commit

Permalink
[struct_pack][feat] support deserialize derived class (#459)
Browse files Browse the repository at this point in the history
  • Loading branch information
poor-circle authored Sep 18, 2023
1 parent 9e8d841 commit ad3abef
Show file tree
Hide file tree
Showing 6 changed files with 394 additions and 37 deletions.
58 changes: 58 additions & 0 deletions include/ylt/struct_pack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#pragma once

#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>

Expand Down Expand Up @@ -512,4 +513,61 @@ template <typename T, size_t I, typename Reader,
}
return ret;
}
#if __cpp_concepts >= 201907L
template <typename BaseClass, typename... DerivedClasses,
struct_pack::reader_t Reader>
#else
template <typename BaseClass, typename... DerivedClasses, typename Reader,
typename = std::enable_if_t<struct_pack::reader_t<Reader>>>
#endif
[[nodiscard]] STRUCT_PACK_INLINE
struct_pack::expected<std::unique_ptr<BaseClass>, struct_pack::errc>
deserialize_derived_class(Reader &reader) {
static_assert(sizeof...(DerivedClasses) > 0,
"There must have a least one derived class");
static_assert(
struct_pack::detail::public_base_class_checker<
BaseClass, std::tuple<DerivedClasses...>>::value,
"the First type should be the base class of all derived class ");
constexpr auto has_hash_collision =
struct_pack::detail::MD5_set<DerivedClasses...>::has_hash_collision;
if constexpr (has_hash_collision != 0) {
static_assert(!sizeof(std::tuple_element_t<has_hash_collision,
std::tuple<DerivedClasses...>>),
"hash collision happened, consider add member `static "
"constexpr uint64_t struct_pack_id` for collision type. ");
}
else {
struct_pack::expected<std::unique_ptr<BaseClass>, struct_pack::errc> ret;
auto ec = struct_pack::detail::deserialize_derived_class<BaseClass,
DerivedClasses...>(
ret.value(), reader);
if SP_UNLIKELY (ec != struct_pack::errc{}) {
ret = unexpected<struct_pack::errc>{ec};
}
return ret;
}
}
#if __cpp_concepts >= 201907L
template <typename BaseClass, typename... DerivedClasses,
detail::deserialize_view View>
#else
template <
typename BaseClass, typename... DerivedClasses, typename View,
typename = std::enable_if_t<struct_pack::detail::deserialize_view<View>>>
#endif
[[nodiscard]] STRUCT_PACK_INLINE
struct_pack::expected<std::unique_ptr<BaseClass>, struct_pack::errc>
deserialize_derived_class(const View &v) {
detail::memory_reader reader{v.data(), v.data() + v.size()};
return deserialize_derived_class<BaseClass, DerivedClasses...>(reader);
}
template <typename BaseClass, typename... DerivedClasses>
[[nodiscard]] STRUCT_PACK_INLINE
struct_pack::expected<std::unique_ptr<BaseClass>, struct_pack::errc>
deserialize_derived_class(const char *data, size_t size) {
detail::memory_reader reader{data, data + size};
return deserialize_derived_class<BaseClass, DerivedClasses...>(reader);
}

} // namespace struct_pack
74 changes: 52 additions & 22 deletions include/ylt/struct_pack/reflection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,28 +89,6 @@ constexpr std::size_t alignment_v = 0;

#if __cpp_concepts >= 201907L

template <typename T>
concept has_user_defined_id = requires {
typename std::integral_constant<std::size_t, T::struct_pack_id>;
};

#else

template <typename T, typename = void>
struct has_user_defined_id_impl : std::false_type {};

template <typename T>
struct has_user_defined_id_impl<
T, std::void_t<std::integral_constant<std::size_t, T::struct_pack_id>>>
: std::true_type {};

template <typename T>
constexpr bool has_user_defined_id = has_user_defined_id_impl<T>::value;

#endif

#if __cpp_concepts >= 201907L

template <typename T>
concept writer_t = requires(T t) {
t.write((const char *)nullptr, std::size_t{});
Expand Down Expand Up @@ -189,6 +167,58 @@ struct compatible;
// clang-format off
namespace detail {

#if __cpp_concepts >= 201907L

template <typename T>
concept has_user_defined_id = requires {
typename std::integral_constant<std::size_t, T::struct_pack_id>;
};

template <typename T>
concept has_user_defined_id_ADL = requires {
typename std::integral_constant<std::size_t,
struct_pack_id((T*)nullptr)>;
};

#else

template <typename T, typename = void>
struct has_user_defined_id_impl : std::false_type {};

template <typename T>
struct has_user_defined_id_impl<
T, std::void_t<std::integral_constant<std::size_t, T::struct_pack_id>>>
: std::true_type {};

template <typename T>
constexpr bool has_user_defined_id = has_user_defined_id_impl<T>::value;

template <std::size_t sz>
struct constant_checker{};

template <typename T, typename = void>
struct has_user_defined_id_ADL_impl : std::false_type {};

#ifdef _MSC_VER
// FIXME: we can't check if it's compile-time calculated in msvc with C++17
template <typename T>
struct has_user_defined_id_ADL_impl<
T, std::void_t<decltype(struct_pack_id((T*)nullptr))>>
: std::true_type {};
#else

template <typename T>
struct has_user_defined_id_ADL_impl<
T, std::void_t<constant_checker<struct_pack_id((T*)nullptr)>>>
: std::true_type {};

#endif

template <typename T>
constexpr bool has_user_defined_id_ADL = has_user_defined_id_ADL_impl<T>::value;

#endif

#if __cpp_concepts >= 201907L
template <typename Type>
concept deserialize_view = requires(Type container) {
Expand Down
169 changes: 156 additions & 13 deletions include/ylt/struct_pack/struct_pack_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "reflection.hpp"
#include "trivial_view.hpp"
#include "varint.hpp"
#include "ylt/struct_pack.hpp"

#if __cplusplus >= 202002L
#include "tuple.hpp"
Expand Down Expand Up @@ -919,6 +920,11 @@ constexpr decltype(auto) get_type_end_flag() {
{static_cast<char>(type_id::type_end_flag_with_id)}} +
get_size_literal<Arg::struct_pack_id>();
}
else if constexpr (has_user_defined_id_ADL<Arg>) {
return string_literal<char, 1>{
{static_cast<char>(type_id::type_end_flag_with_id)}} +
get_size_literal<struct_pack_id((Arg *)nullptr)>();
}
else {
return string_literal<char, 1>{{static_cast<char>(type_id::type_end_flag)}};
}
Expand Down Expand Up @@ -1180,7 +1186,7 @@ constexpr bool check_if_compatible_element_exist_impl_helper() {
template <typename T, typename... Args>
constexpr uint32_t get_types_code_impl() {
constexpr auto str = get_types_literal<T, remove_cvref_t<Args>...>();
return MD5::MD5Hash32Constexpr(str.data(), str.size());
return MD5::MD5Hash32Constexpr(str.data(), str.size()) & 0xFFFFFFFE;
}

template <typename T, typename Tuple, size_t... I>
Expand Down Expand Up @@ -1491,9 +1497,8 @@ constexpr void get_compatible_version_numbers(Buffer &buffer, std::size_t &sz) {
}
}

template <std::size_t sz>
constexpr void STRUCT_PACK_INLINE
compile_time_sort(std::array<uint64_t, sz> &array) {
template <typename T, std::size_t sz>
constexpr void STRUCT_PACK_INLINE compile_time_sort(std::array<T, sz> &array) {
// FIXME: use faster compile-time sort
for (std::size_t i = 0; i < array.size(); ++i) {
for (std::size_t j = i + 1; j < array.size(); ++j) {
Expand Down Expand Up @@ -1700,10 +1705,10 @@ class packer {
constexpr uint32_t raw_types_code = calculate_raw_hash<T, Args...>();
if constexpr (serialize_static_config<serialize_type>::has_compatible ||
check_if_add_type_literal<conf, serialize_type>()) {
return raw_types_code - raw_types_code % 2 + 1;
return raw_types_code + 1;
}
else { // default case, only has hash_code
return raw_types_code - raw_types_code % 2;
return raw_types_code;
}
}
template <uint64_t conf, bool is_default_size_type, typename T,
Expand Down Expand Up @@ -2036,6 +2041,16 @@ struct memory_reader {
}
std::size_t tellg() { return (std::size_t)now; }
};

template <typename Reader>
struct MD5_reader_wrapper;

template <typename T>
constexpr bool is_MD5_reader_wrapper = false;

template <typename T>
constexpr bool is_MD5_reader_wrapper<MD5_reader_wrapper<T>> = true;

#if __cpp_concepts >= 201907L
template <reader_t Reader>
#else
Expand Down Expand Up @@ -2384,6 +2399,7 @@ class unpacker {
static STRUCT_PACK_INLINE constexpr void run(unpack &unpacker,
variant_t &v) {
if constexpr (index >= std::variant_size_v<variant_t>) {
unreachable();
return;
}
else {
Expand Down Expand Up @@ -2434,14 +2450,19 @@ class unpacker {
STRUCT_PACK_INLINE std::pair<struct_pack::errc, std::uint64_t>
deserialize_metainfo() {
uint32_t current_types_code;
if SP_UNLIKELY (!reader_.read((char *)&current_types_code,
sizeof(uint32_t))) {
return {struct_pack::errc::no_buffer_space, 0};
if constexpr (is_MD5_reader_wrapper<Reader>) {
reader_.read_head((char *)&current_types_code);
}
constexpr uint32_t types_code =
get_types_code<T, decltype(get_types<T>())>();
if SP_UNLIKELY ((current_types_code / 2) != (types_code / 2)) {
return {struct_pack::errc::invalid_buffer, 0};
else {
if SP_UNLIKELY (!reader_.read((char *)&current_types_code,
sizeof(uint32_t))) {
return {struct_pack::errc::no_buffer_space, 0};
}
constexpr uint32_t types_code =
get_types_code<T, decltype(get_types<T>())>();
if SP_UNLIKELY ((current_types_code / 2) != (types_code / 2)) {
return {struct_pack::errc::invalid_buffer, 0};
}
}
if SP_LIKELY (current_types_code % 2 == 0) // unexist extended metainfo
{
Expand Down Expand Up @@ -2974,5 +2995,127 @@ class unpacker {
unsigned char size_type_;
};

struct MD5_pair {
uint32_t md5;
uint32_t index;
constexpr friend bool operator<(const MD5_pair &l, const MD5_pair &r) {
return l.md5 < r.md5;
}
constexpr friend bool operator>(const MD5_pair &l, const MD5_pair &r) {
return l.md5 > r.md5;
}
constexpr friend bool operator==(const MD5_pair &l, const MD5_pair &r) {
return l.md5 == r.md5;
}
};

template <typename... DerivedClasses>
struct MD5_set {
static constexpr int size = sizeof...(DerivedClasses);
static_assert(size <= 256);

private:
template <std::size_t... Index>
static constexpr std::array<MD5_pair, size> calculate_md5(
std::index_sequence<Index...>) {
std::array<MD5_pair, size> md5{};
((md5[Index] =
MD5_pair{get_types_code<DerivedClasses,
decltype(get_types<DerivedClasses>())>() &
0xFFFFFFFE,
Index}),
...);
compile_time_sort(md5);
return md5;
}
static constexpr std::size_t has_hash_collision_impl() {
for (std::size_t i = 1; i < size; ++i) {
if (value[i - 1] == value[i]) {
return value[i].index;
}
}
return 0;
}

public:
static constexpr std::array<MD5_pair, size> value =
calculate_md5(std::make_index_sequence<size>());
static constexpr std::size_t has_hash_collision = has_hash_collision_impl();
};

template <typename BaseClass, typename DerivedClasses>
struct public_base_class_checker {
static_assert(std::tuple_size_v<DerivedClasses> <= 256);

private:
template <std::size_t... Index>
static constexpr bool calculate_md5(std::index_sequence<Index...>) {
return (std::is_base_of_v<BaseClass,
std::tuple_element_t<Index, DerivedClasses>> &&
...);
}

public:
static constexpr bool value = public_base_class_checker::calculate_md5(
std::make_index_sequence<std::tuple_size_v<DerivedClasses>>());
};

template <typename DerivedClasses>
struct deserialize_derived_class_helper {
template <size_t index, typename BaseClass, typename unpack>
static STRUCT_PACK_INLINE constexpr struct_pack::errc run(
std::unique_ptr<BaseClass> &base, unpack &unpacker) {
if constexpr (index >= std::tuple_size_v<DerivedClasses>) {
unreachable();
return struct_pack::errc{};
}
else {
using derived_class = std::tuple_element_t<index, DerivedClasses>;
base = std::make_unique<derived_class>();
return unpacker.deserialize(*(derived_class *)base.get());
}
}
};

template <typename Reader>
struct MD5_reader_wrapper : public Reader {
MD5_reader_wrapper(Reader &&reader) : Reader(std::move(reader)) {
is_failed = !Reader::read((char *)&head_chars, sizeof(head_chars));
}
bool read_head(char *target) {
memcpy(target, &head_chars, sizeof(head_chars));
return true;
}
Reader &&release_reader() { return std::move(*(Reader *)this); }
bool is_failed;
uint32_t get_md5() { return head_chars & 0xFFFFFFFE; }

private:
std::uint32_t head_chars;
std::size_t read_pos;
};

template <typename BaseClass, typename... DerivedClasses, typename Reader>
[[nodiscard]] STRUCT_PACK_INLINE struct_pack::errc deserialize_derived_class(
std::unique_ptr<BaseClass> &base, Reader &reader) {
MD5_reader_wrapper wrapper{std::move(reader)};
if (wrapper.is_failed) {
return struct_pack::errc::no_buffer_space;
}
unpacker<MD5_reader_wrapper<Reader>> unpack{wrapper};
constexpr auto &MD5s = MD5_set<DerivedClasses...>::value;
static_assert(MD5s.size() == sizeof...(DerivedClasses));
MD5_pair md5_pair{wrapper.get_md5(), 0};
auto result = std::lower_bound(MD5s.begin(), MD5s.end(), md5_pair);
if (result == MD5s.end() || result->md5 != md5_pair.md5) {
return struct_pack::errc::invalid_buffer;
}
auto ret = template_switch<
deserialize_derived_class_helper<std::tuple<DerivedClasses...>>>(
result->index, base, unpack);
reader = std::move(wrapper.release_reader());
return ret;
}
} // namespace detail

} // namespace struct_pack
Loading

0 comments on commit ad3abef

Please sign in to comment.