From b76ef702f475ac193780e90a9ffbf7a1044fcebf Mon Sep 17 00:00:00 2001 From: "Zezheng.Li" Date: Mon, 18 Sep 2023 17:04:27 +0800 Subject: [PATCH] fix --- include/ylt/struct_pack.hpp | 65 +++++++------ include/ylt/struct_pack/reflection.hpp | 61 +++++++----- include/ylt/struct_pack/struct_pack_impl.hpp | 76 +++++++++------ .../tests/test_compile_time_calculate.cpp | 14 ++- src/struct_pack/tests/test_derived.cpp | 94 ++++++++++++++++--- 5 files changed, 210 insertions(+), 100 deletions(-) diff --git a/include/ylt/struct_pack.hpp b/include/ylt/struct_pack.hpp index 76aa58eec1..0fc7eba105 100644 --- a/include/ylt/struct_pack.hpp +++ b/include/ylt/struct_pack.hpp @@ -16,6 +16,7 @@ #pragma once #include +#include #include #include @@ -512,48 +513,50 @@ template +[[nodiscard]] STRUCT_PACK_INLINE + struct_pack::expected, struct_pack::errc> + deserialize_derived_class(Reader &reader) { + static_assert(sizeof...(DerivedClasses) > 0, + "There must have a least one derived class"); + static_assert( + struct_pack::detail::public_base_class_checker< + BaseClass, std::tuple>::value, + "the First type should be the base class of all derived class "); + constexpr auto has_hash_collision = + struct_pack::detail::MD5_set::has_hash_collision; + if constexpr (has_hash_collision != 0) { + static_assert(!sizeof(std::tuple_element_t>), + "hash collision happened, consider add member `static " + "constexpr uint64_t struct_pack_id` for collision type. "); + } + else { + struct_pack::expected, struct_pack::errc> ret; + auto ec = struct_pack::detail::deserialize_derived_class( + ret.value(), reader); + if SP_UNLIKELY (ec != struct_pack::errc{}) { + ret = unexpected{ec}; + } + return ret; + } +} template [[nodiscard]] STRUCT_PACK_INLINE - struct_pack::expected, struct_pack::errc> + struct_pack::expected, struct_pack::errc> deserialize_derived_class(const View &v) { - static_assert(sizeof...(DerivedClasses) > 0); - static_assert(struct_pack::detail::public_base_class_checker< - BaseClass, std::tuple>::value, - "The First template argument is not the public"); - static_assert(!struct_pack::detail::MD5_set< - std::tuple>::has_hash_collision); - detail::memory_reader reader{v.data(), v.size()}; + detail::memory_reader reader{v.data(), v.data() + v.size()}; return deserialize_derived_class(reader); } template [[nodiscard]] STRUCT_PACK_INLINE - struct_pack::expected, struct_pack::errc> + struct_pack::expected, struct_pack::errc> deserialize_derived_class(const char *data, size_t size) { - static_assert(sizeof...(DerivedClasses) > 0); - static_assert(struct_pack::detail::public_base_class_checker< - BaseClass, std::tuple>::value); - static_assert(!struct_pack::detail::MD5_set< - std::tuple>::has_hash_collision); detail::memory_reader reader{data, data + size}; return deserialize_derived_class(reader); } -template -[[nodiscard]] STRUCT_PACK_INLINE - struct_pack::expected, struct_pack::errc> - deserialize_derived_class(Reader &reader) { - static_assert(sizeof...(DerivedClasses) > 0); - static_assert(struct_pack::detail::public_base_class_checker< - BaseClass, std::tuple>::value); - static_assert(!struct_pack::detail::MD5_set< - std::tuple>::has_hash_collision); - struct_pack::expected, struct_pack::errc> ret; - auto ec = struct_pack::detail::deserialize_derived_class(ret.value(), reader); - if SP_UNLIKELY (ec != struct_pack::errc{}) { - ret = unexpected{ec}; - } -} } // namespace struct_pack \ No newline at end of file diff --git a/include/ylt/struct_pack/reflection.hpp b/include/ylt/struct_pack/reflection.hpp index ff70612aab..c7b054c3da 100644 --- a/include/ylt/struct_pack/reflection.hpp +++ b/include/ylt/struct_pack/reflection.hpp @@ -89,28 +89,6 @@ constexpr std::size_t alignment_v = 0; #if __cpp_concepts >= 201907L -template -concept has_user_defined_id = requires { - typename std::integral_constant; -}; - -#else - -template -struct has_user_defined_id_impl : std::false_type {}; - -template -struct has_user_defined_id_impl< - T, std::void_t>> - : std::true_type {}; - -template -constexpr bool has_user_defined_id = has_user_defined_id_impl::value; - -#endif - -#if __cpp_concepts >= 201907L - template concept writer_t = requires(T t) { t.write((const char *)nullptr, std::size_t{}); @@ -189,6 +167,45 @@ struct compatible; // clang-format off namespace detail { +#if __cpp_concepts >= 201907L + +template +concept has_user_defined_id = requires { + typename std::integral_constant; +}; + +template +concept has_user_defined_id_ADL = requires { + typename std::integral_constant; +}; + +#else + +template +struct has_user_defined_id_impl : std::false_type {}; + +template +struct has_user_defined_id_impl< + T, std::void_t>> + : std::true_type {}; + +template +constexpr bool has_user_defined_id = has_user_defined_id_impl::value; + +template +struct has_user_defined_id_ADL_impl : std::false_type {}; + +template +struct has_user_defined_id_ADL_impl< + T, std::void_t>> + : std::true_type {}; + +template +constexpr bool has_user_defined_id_ADL = has_user_defined_id_ADL_impl::value; + +#endif + #if __cpp_concepts >= 201907L template concept deserialize_view = requires(Type container) { diff --git a/include/ylt/struct_pack/struct_pack_impl.hpp b/include/ylt/struct_pack/struct_pack_impl.hpp index 2272a5b22d..8c3ab3aba3 100644 --- a/include/ylt/struct_pack/struct_pack_impl.hpp +++ b/include/ylt/struct_pack/struct_pack_impl.hpp @@ -920,6 +920,11 @@ constexpr decltype(auto) get_type_end_flag() { {static_cast(type_id::type_end_flag_with_id)}} + get_size_literal(); } + else if constexpr (has_user_defined_id_ADL) { + return string_literal{ + {static_cast(type_id::type_end_flag_with_id)}} + + get_size_literal(); + } else { return string_literal{{static_cast(type_id::type_end_flag)}}; } @@ -1181,7 +1186,7 @@ constexpr bool check_if_compatible_element_exist_impl_helper() { template constexpr uint32_t get_types_code_impl() { constexpr auto str = get_types_literal...>(); - return MD5::MD5Hash32Constexpr(str.data(), str.size()); + return MD5::MD5Hash32Constexpr(str.data(), str.size()) & 0xFFFFFFFE; } template @@ -1492,9 +1497,8 @@ constexpr void get_compatible_version_numbers(Buffer &buffer, std::size_t &sz) { } } -template -constexpr void STRUCT_PACK_INLINE -compile_time_sort(std::array &array) { +template +constexpr void STRUCT_PACK_INLINE compile_time_sort(std::array &array) { // FIXME: use faster compile-time sort for (std::size_t i = 0; i < array.size(); ++i) { for (std::size_t j = i + 1; j < array.size(); ++j) { @@ -1701,10 +1705,10 @@ class packer { constexpr uint32_t raw_types_code = calculate_raw_hash(); if constexpr (serialize_static_config::has_compatible || check_if_add_type_literal()) { - return raw_types_code - raw_types_code % 2 + 1; + return raw_types_code + 1; } else { // default case, only has hash_code - return raw_types_code - raw_types_code % 2; + return raw_types_code; } } template = std::variant_size_v) { + unreachable(); return; } else { @@ -2445,7 +2450,7 @@ class unpacker { STRUCT_PACK_INLINE std::pair deserialize_metainfo() { uint32_t current_types_code; - if constexpr (is_MD5_reader_wrapper) { + if constexpr (is_MD5_reader_wrapper) { reader_.read_head((char *)¤t_types_code); } else { @@ -2996,6 +3001,9 @@ struct MD5_pair { constexpr friend bool operator<(const MD5_pair &l, const MD5_pair &r) { return l.md5 < r.md5; } + constexpr friend bool operator>(const MD5_pair &l, const MD5_pair &r) { + return l.md5 > r.md5; + } constexpr friend bool operator==(const MD5_pair &l, const MD5_pair &r) { return l.md5 == r.md5; } @@ -3005,36 +3013,39 @@ template struct MD5_set { static constexpr int size = sizeof...(DerivedClasses); static_assert(size <= 256); - static constexpr std::array value = - calculate_md5(std::make_index_sequence()); private: template static constexpr std::array calculate_md5( std::index_sequence) { - std::array md5{}; - ((md5[Index] = {get_type_code(), Index}), ...); + std::array md5{}; + ((md5[Index] = + MD5_pair{get_types_code())>() & + 0xFFFFFFFE, + Index}), + ...); compile_time_sort(md5); - return std::move(md5); + return md5; } - static constexpr bool has_hash_collision_impl() { - for (int i = 1; i < size; ++i) { + static constexpr std::size_t has_hash_collision_impl() { + for (std::size_t i = 1; i < size; ++i) { if (value[i - 1] == value[i]) { - return true; + return value[i].index; } } - return false; + return 0; } public: - static constexpr bool has_hash_collision = has_hash_collision_impl(); + static constexpr std::array value = + calculate_md5(std::make_index_sequence()); + static constexpr std::size_t has_hash_collision = has_hash_collision_impl(); }; template struct public_base_class_checker { static_assert(std::tuple_size_v <= 256); - static constexpr bool value = - calculate_md5(std::tuple_size_v); private: template @@ -3043,53 +3054,58 @@ struct public_base_class_checker { std::tuple_element_t> && ...); } + + public: + static constexpr bool value = public_base_class_checker::calculate_md5( + std::make_index_sequence>()); }; template struct deserialize_derived_class_helper { template - static STRUCT_PACK_INLINE constexpr std::errc run( - std::shared_ptr &base, unpack &unpacker) { + static STRUCT_PACK_INLINE constexpr struct_pack::errc run( + std::unique_ptr &base, unpack &unpacker) { if constexpr (index >= std::tuple_size_v) { - return std::errc{}; + unreachable(); + return struct_pack::errc{}; } else { using derived_class = std::tuple_element_t; - base = std::make_shared(); - return unpacker.template deserialize(base.get()); + base = std::make_unique(); + return unpacker.template deserialize(*(derived_class *)base.get()); } } }; template struct MD5_reader_wrapper : public Reader { - MD5_reader_wrapper(Reader &&reader) : reader(std::move(reader)) { - is_failed = reader.read(&head_chars, sizeof(head_chars)); + MD5_reader_wrapper(Reader &&reader) : Reader(std::move(reader)) { + is_failed = !Reader::read((char *)&head_chars, sizeof(head_chars)); } bool read_head(char *target) { memcpy(target, &head_chars, sizeof(head_chars)); return true; } - Reader &&release_reader() { return std::move(reader); } + Reader &&release_reader() { return std::move(*(Reader *)this); } bool is_failed; - uint32_t get_md5() { return head_chars; } + uint32_t get_md5() { return head_chars & 0xFFFFFFFE; } private: std::uint32_t head_chars; std::size_t read_pos; - Reader &reader; }; template [[nodiscard]] STRUCT_PACK_INLINE struct_pack::errc deserialize_derived_class( - std::shared_ptr &base, Reader &reader) { + std::unique_ptr &base, Reader &reader) { MD5_reader_wrapper wrapper{std::move(reader)}; if (wrapper.is_failed) { return struct_pack::errc::no_buffer_space; } unpacker> unpack{wrapper}; constexpr auto &MD5s = MD5_set::value; + static_assert(MD5s.size() == sizeof...(DerivedClasses)); MD5_pair md5_pair{wrapper.get_md5(), 0}; auto result = std::lower_bound(MD5s.begin(), MD5s.end(), md5_pair); if (result == MD5s.end() || result->md5 != md5_pair.md5) { diff --git a/src/struct_pack/tests/test_compile_time_calculate.cpp b/src/struct_pack/tests/test_compile_time_calculate.cpp index d908647734..6d8de1737a 100644 --- a/src/struct_pack/tests/test_compile_time_calculate.cpp +++ b/src/struct_pack/tests/test_compile_time_calculate.cpp @@ -405,9 +405,17 @@ struct bar_with_ID1 { constexpr static std::size_t struct_pack_id = 1; }; +struct bar_with_ID2 { + std::vector a; + std::vector b; +}; +constexpr int struct_pack_id(bar_with_ID2*) { return 11; } + TEST_CASE("test user defined ID") { + static_assert(struct_pack::detail::has_user_defined_id_ADL); { - static_assert(has_user_defined_id); + static_assert( + struct_pack::detail::has_user_defined_id); static_assert(struct_pack::get_type_literal() != struct_pack::get_type_literal()); static_assert(struct_pack::get_type_literal() != @@ -416,7 +424,7 @@ TEST_CASE("test user defined ID") { struct_pack::get_type_literal()); } { - static_assert(has_user_defined_id); + static_assert(struct_pack::detail::has_user_defined_id); static_assert(struct_pack::get_type_literal() != struct_pack::get_type_literal()); static_assert(struct_pack::get_type_literal() != @@ -427,5 +435,7 @@ TEST_CASE("test user defined ID") { struct_pack::get_type_literal()); static_assert(struct_pack::get_type_literal() != struct_pack::get_type_literal()); + static_assert(struct_pack::get_type_literal() != + struct_pack::get_type_literal()); } } diff --git a/src/struct_pack/tests/test_derived.cpp b/src/struct_pack/tests/test_derived.cpp index d5ffc98f2c..89c4aa7581 100644 --- a/src/struct_pack/tests/test_derived.cpp +++ b/src/struct_pack/tests/test_derived.cpp @@ -5,47 +5,111 @@ #include "doctest.h" #include "ylt/struct_pack.hpp" +#include "ylt/struct_pack/struct_pack_impl.hpp" struct Base { - static struct_pack::expected, struct_pack::errc> + Base(){}; + static struct_pack::expected, struct_pack::errc> deserialize(std::string_view sv); + virtual std::string get_name() const = 0; + virtual ~Base(){}; }; struct foo : public Base { - std::string hello; - std::string hi; + std::string hello = "1"; + std::string hi = "2"; + virtual std::string get_name() const override { return "foo"; } friend bool operator==(const foo& a, const foo& b) { return a.hello == b.hello && a.hi == b.hi; } }; +STRUCT_PACK_REFL(foo, hello, hi); +struct foo2 : public Base { + std::string hello = "1"; + std::string hi = "2"; + virtual std::string get_name() const override { return "foo2"; } + friend bool operator==(const foo2& a, const foo2& b) { + return a.hello == b.hello && a.hi == b.hi; + } + static constexpr uint64_t struct_pack_id = 114514; +}; +STRUCT_PACK_REFL(foo2, hello, hi); +struct foo3 : public Base { + std::string hello = "1"; + std::string hi = "2"; + virtual std::string get_name() const override { return "foo3"; } + friend bool operator==(const foo3& a, const foo3& b) { + return a.hello == b.hello && a.hi == b.hi; + } +}; +STRUCT_PACK_REFL(foo3, hello, hi); +struct foo4 : public Base { + std::string hello = "1"; + std::string hi = "2"; + virtual std::string get_name() const override { return "foo4"; } + friend bool operator==(const foo4& a, const foo4& b) { + return a.hello == b.hello && a.hi == b.hi; + } +}; +constexpr int struct_pack_id(foo4*) { return 112233211; } +STRUCT_PACK_REFL(foo4, hello, hi); struct bar : public Base { - int oh; - int no; + int oh = 1; + int no = 2; + virtual std::string get_name() const override { return "bar"; } friend bool operator==(const bar& a, const bar& b) { return a.oh == b.oh && a.no == b.no; } }; +STRUCT_PACK_REFL(bar, oh, no); struct gua : Base { - std::vector foos; - std::vector bars; + std::string a = "Hello"; + int b = 1; + virtual std::string get_name() const override { return "gua"; } friend bool operator==(const gua& a, const gua& b) { - return a.foos == b.foos && a.bars == b.bars; + return a.a == b.a && a.b == b.b; } }; -struct_pack::expected, struct_pack::errc> +STRUCT_PACK_REFL(gua, a, b); + +struct_pack::expected, struct_pack::errc> Base::deserialize(std::string_view sv) { - return struct_pack::deserialize_derived_class(sv); + return struct_pack::deserialize_derived_class(sv); } + TEST_CASE("testing derived") { - foo f{.hello = "1", .hi = "2"}; - bar b{.oh = 1, .no = 2}; - gua g{.foos = {{.hello = "23", .hi = "34"}, {.hello = "45", .hi = "67"}}, - .bars = {{.oh = 1, .no = 2}, {.oh = 3, .no = 4}}}; + using namespace std::string_literals; + static_assert(struct_pack::detail::has_user_defined_id_ADL); + foo f; + foo2 f2; + foo4 f4; + bar b; + gua g; std::vector vecs{struct_pack::serialize(f), struct_pack::serialize(b), - struct_pack::serialize(g)}; + struct_pack::serialize(g), + struct_pack::serialize(f2), + struct_pack::serialize(f4)}; auto f1 = Base::deserialize(vecs[0]); auto b1 = Base::deserialize(vecs[1]); auto g1 = Base::deserialize(vecs[2]); + auto f21 = Base::deserialize(vecs[3]); + auto f41 = Base::deserialize(vecs[4]); CHECK(*(foo*)(f1->get()) == f); + CHECK(*(foo2*)(f21->get()) == f2); CHECK(*(bar*)(b1->get()) == b); CHECK(*(gua*)(g1->get()) == g); + std::vector, std::string>> vec; + vec.emplace_back(std::move(f1.value()), "foo"); + vec.emplace_back(std::move(f21.value()), "foo2"); + vec.emplace_back(std::move(b1.value()), "bar"); + vec.emplace_back(std::move(g1.value()), "gua"); + vec.emplace_back(std::move(f41.value()), "foo4"); + for (const auto& e : vec) { + CHECK(e.first->get_name() == e.second); + } +} + +TEST_CASE("test hash collision") { + static_assert(struct_pack::detail::MD5_set::has_hash_collision != 0); } \ No newline at end of file