Skip to content

Commit

Permalink
Add multi state ctor from input range.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Aug 5, 2024
1 parent f078fc3 commit b349f64
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
23 changes: 23 additions & 0 deletions include/heyoka/llvm_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <cstdint>
#include <memory>
#include <ostream>
#include <ranges>
#include <string>
#include <tuple>
#include <type_traits>
Expand Down Expand Up @@ -368,9 +369,31 @@ class HEYOKA_DLL_PUBLIC llvm_multi_state
void load(boost::archive::binary_iarchive &, unsigned);
BOOST_SERIALIZATION_SPLIT_MEMBER()

// Helper to turn an input range into a vector of llvm_state objects.
template <typename R>
static std::vector<llvm_state> rng_to_vector(R &&r)
{
std::vector<llvm_state> retval;
if constexpr (std::ranges::sized_range<R>) {
retval.reserve(static_cast<decltype(retval.size())>(std::ranges::size(r)));
}

for (auto &&s : r) {
retval.push_back(std::forward<decltype(s)>(s));
}

return retval;
}

public:
llvm_multi_state();
explicit llvm_multi_state(std::vector<llvm_state>);
template <typename R>
requires std::ranges::input_range<R>
&& std::same_as<llvm_state, std::remove_cvref_t<std::ranges::range_reference_t<R>>>
explicit llvm_multi_state(R &&rng) : llvm_multi_state(rng_to_vector(std::forward<R>(rng)))
{
}
llvm_multi_state(const llvm_multi_state &);
llvm_multi_state(llvm_multi_state &&) noexcept;
llvm_multi_state &operator=(const llvm_multi_state &);
Expand Down
88 changes: 88 additions & 0 deletions test/llvm_multi_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include <array>
#include <cmath>
#include <list>
#include <ranges>
#include <sstream>
#include <stdexcept>

Expand Down Expand Up @@ -576,3 +579,88 @@ TEST_CASE("vfabi double")
#endif
}
}

// Test for the range constructor.
TEST_CASE("range ctor")
{
auto [x, y] = make_vars("x", "y");

{
std::list<llvm_state> slist;

slist.emplace_back();
add_cfunc<double>(slist.back(), "f1", {x * y}, {x, y});

slist.emplace_back();
add_cfunc<double>(slist.back(), "f2", {x / y}, {x, y});

slist.emplace_back();
add_cfunc<double>(slist.back(), "f3", {x + y}, {x, y});

slist.emplace_back();
add_cfunc<double>(slist.back(), "f4", {x - y}, {x, y});

llvm_state::clear_memcache();

llvm_multi_state ms{slist | std::views::transform([](auto &s) -> auto && { return std::move(s); })};
ms.compile();

auto *cf1_ptr
= reinterpret_cast<void (*)(double *, const double *, const double *, const double *)>(ms.jit_lookup("f1"));
auto *cf2_ptr
= reinterpret_cast<void (*)(double *, const double *, const double *, const double *)>(ms.jit_lookup("f2"));
auto *cf3_ptr
= reinterpret_cast<void (*)(double *, const double *, const double *, const double *)>(ms.jit_lookup("f3"));
auto *cf4_ptr
= reinterpret_cast<void (*)(double *, const double *, const double *, const double *)>(ms.jit_lookup("f4"));

const double ins[] = {2., 3.};
double outs[4] = {};

cf1_ptr(outs, ins, nullptr, nullptr);
cf2_ptr(outs + 1, ins, nullptr, nullptr);
cf3_ptr(outs + 2, ins, nullptr, nullptr);
cf4_ptr(outs + 3, ins, nullptr, nullptr);

REQUIRE(outs[0] == 6);
REQUIRE(outs[1] == 2. / 3.);
REQUIRE(outs[2] == 5);
REQUIRE(outs[3] == -1);
}

{
std::array<llvm_state, 4> slist;

add_cfunc<double>(slist[0], "f1", {x * y}, {x, y});
add_cfunc<double>(slist[1], "f2", {x / y}, {x, y});
add_cfunc<double>(slist[2], "f3", {x + y}, {x, y});
add_cfunc<double>(slist[3], "f4", {x - y}, {x, y});

llvm_state::clear_memcache();

llvm_multi_state ms{slist | std::views::transform([](auto &s) -> auto && { return std::move(s); })};
ms.compile();

auto *cf1_ptr
= reinterpret_cast<void (*)(double *, const double *, const double *, const double *)>(ms.jit_lookup("f1"));
auto *cf2_ptr
= reinterpret_cast<void (*)(double *, const double *, const double *, const double *)>(ms.jit_lookup("f2"));
auto *cf3_ptr
= reinterpret_cast<void (*)(double *, const double *, const double *, const double *)>(ms.jit_lookup("f3"));
auto *cf4_ptr
= reinterpret_cast<void (*)(double *, const double *, const double *, const double *)>(ms.jit_lookup("f4"));

const double ins[] = {2., 3.};
double outs[4] = {};

cf1_ptr(outs, ins, nullptr, nullptr);
cf2_ptr(outs + 1, ins, nullptr, nullptr);
cf3_ptr(outs + 2, ins, nullptr, nullptr);
cf4_ptr(outs + 3, ins, nullptr, nullptr);

REQUIRE(outs[0] == 6);
REQUIRE(outs[1] == 2. / 3.);
REQUIRE(outs[2] == 5);
REQUIRE(outs[3] == -1);
}
}

0 comments on commit b349f64

Please sign in to comment.