From b349f643339dea492e814beaef1f9610922a1e63 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Mon, 5 Aug 2024 11:01:06 +0200 Subject: [PATCH] Add multi state ctor from input range. --- include/heyoka/llvm_state.hpp | 23 +++++++++ test/llvm_multi_state.cpp | 88 +++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/include/heyoka/llvm_state.hpp b/include/heyoka/llvm_state.hpp index b14e90c02..b1f74f966 100644 --- a/include/heyoka/llvm_state.hpp +++ b/include/heyoka/llvm_state.hpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -368,9 +369,31 @@ class HEYOKA_DLL_PUBLIC llvm_multi_state void load(boost::archive::binary_iarchive &, unsigned); BOOST_SERIALIZATION_SPLIT_MEMBER() + // Helper to turn an input range into a vector of llvm_state objects. + template + static std::vector rng_to_vector(R &&r) + { + std::vector retval; + if constexpr (std::ranges::sized_range) { + retval.reserve(static_cast(std::ranges::size(r))); + } + + for (auto &&s : r) { + retval.push_back(std::forward(s)); + } + + return retval; + } + public: llvm_multi_state(); explicit llvm_multi_state(std::vector); + template + requires std::ranges::input_range + && std::same_as>> + explicit llvm_multi_state(R &&rng) : llvm_multi_state(rng_to_vector(std::forward(rng))) + { + } llvm_multi_state(const llvm_multi_state &); llvm_multi_state(llvm_multi_state &&) noexcept; llvm_multi_state &operator=(const llvm_multi_state &); diff --git a/test/llvm_multi_state.cpp b/test/llvm_multi_state.cpp index 652ab5b56..40fd1578f 100644 --- a/test/llvm_multi_state.cpp +++ b/test/llvm_multi_state.cpp @@ -6,7 +6,10 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +#include #include +#include +#include #include #include @@ -576,3 +579,88 @@ TEST_CASE("vfabi double") #endif } } + +// Test for the range constructor. +TEST_CASE("range ctor") +{ + auto [x, y] = make_vars("x", "y"); + + { + std::list slist; + + slist.emplace_back(); + add_cfunc(slist.back(), "f1", {x * y}, {x, y}); + + slist.emplace_back(); + add_cfunc(slist.back(), "f2", {x / y}, {x, y}); + + slist.emplace_back(); + add_cfunc(slist.back(), "f3", {x + y}, {x, y}); + + slist.emplace_back(); + add_cfunc(slist.back(), "f4", {x - y}, {x, y}); + + llvm_state::clear_memcache(); + + llvm_multi_state ms{slist | std::views::transform([](auto &s) -> auto && { return std::move(s); })}; + ms.compile(); + + auto *cf1_ptr + = reinterpret_cast(ms.jit_lookup("f1")); + auto *cf2_ptr + = reinterpret_cast(ms.jit_lookup("f2")); + auto *cf3_ptr + = reinterpret_cast(ms.jit_lookup("f3")); + auto *cf4_ptr + = reinterpret_cast(ms.jit_lookup("f4")); + + const double ins[] = {2., 3.}; + double outs[4] = {}; + + cf1_ptr(outs, ins, nullptr, nullptr); + cf2_ptr(outs + 1, ins, nullptr, nullptr); + cf3_ptr(outs + 2, ins, nullptr, nullptr); + cf4_ptr(outs + 3, ins, nullptr, nullptr); + + REQUIRE(outs[0] == 6); + REQUIRE(outs[1] == 2. / 3.); + REQUIRE(outs[2] == 5); + REQUIRE(outs[3] == -1); + } + + { + std::array slist; + + add_cfunc(slist[0], "f1", {x * y}, {x, y}); + add_cfunc(slist[1], "f2", {x / y}, {x, y}); + add_cfunc(slist[2], "f3", {x + y}, {x, y}); + add_cfunc(slist[3], "f4", {x - y}, {x, y}); + + llvm_state::clear_memcache(); + + llvm_multi_state ms{slist | std::views::transform([](auto &s) -> auto && { return std::move(s); })}; + ms.compile(); + + auto *cf1_ptr + = reinterpret_cast(ms.jit_lookup("f1")); + auto *cf2_ptr + = reinterpret_cast(ms.jit_lookup("f2")); + auto *cf3_ptr + = reinterpret_cast(ms.jit_lookup("f3")); + auto *cf4_ptr + = reinterpret_cast(ms.jit_lookup("f4")); + + const double ins[] = {2., 3.}; + double outs[4] = {}; + + cf1_ptr(outs, ins, nullptr, nullptr); + cf2_ptr(outs + 1, ins, nullptr, nullptr); + cf3_ptr(outs + 2, ins, nullptr, nullptr); + cf4_ptr(outs + 3, ins, nullptr, nullptr); + + REQUIRE(outs[0] == 6); + REQUIRE(outs[1] == 2. / 3.); + REQUIRE(outs[2] == 5); + REQUIRE(outs[3] == -1); + } +}