Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic pow -> log/exp transformation #454

Merged
merged 6 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ Changelog
New
~~~

- Non-number exponents for the ``pow()`` function
are now supported in Taylor integrators
(`#454 <https://github.com/bluescarni/heyoka/pull/454>`__).
- It is now possible to initialise a Taylor integrator
with an empty initial state vector. This will result
in zero-initialization of the state vector
Expand Down
49 changes: 25 additions & 24 deletions src/math/pow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,14 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
// Fetch the pow eval algo.
const auto pea = get_pow_eval_algo(f);

// Codegen the exponent.
auto *expo = taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size);

// Fetch the internal vector type.
auto *vec_t = make_vector_type(fp_t, batch_size);

if (order == 0u) {
return pea.eval_f(
s, {taylor_fetch_diff(arr, u_idx, 0, n_uvars), taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size)});
return pea.eval_f(s, {taylor_fetch_diff(arr, u_idx, 0, n_uvars), expo});
}

// Special case for sqrt().
Expand All @@ -514,7 +519,6 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
}

// The general case.
auto &builder = s.builder();

// NOTE: iteration in the [0, order) range
// (i.e., order *not* included).
Expand All @@ -524,27 +528,14 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
auto *v1 = taylor_fetch_diff(arr, idx, j, n_uvars);

// Compute the scalar factor: order * num - j * (num + 1).
auto scal_f = [&]() -> llvm::Value * {
if constexpr (std::is_same_v<U, number>) {
return vector_splat(
builder,
llvm_codegen(s, fp_t,
number_like(s, fp_t, static_cast<double>(order)) * num
- number_like(s, fp_t, static_cast<double>(j)) * (num + number_like(s, fp_t, 1.))),
batch_size);
} else {
auto pc = taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size);
auto *jvec = vector_splat(builder, llvm_codegen(s, fp_t, number(static_cast<double>(j))), batch_size);
auto *ordvec
= vector_splat(builder, llvm_codegen(s, fp_t, number(static_cast<double>(order))), batch_size);
auto *onevec = vector_splat(builder, llvm_codegen(s, fp_t, number(1.)), batch_size);
auto *jvec = llvm_codegen(s, vec_t, number(static_cast<double>(j)));
auto *ordvec = llvm_codegen(s, vec_t, number(static_cast<double>(order)));
auto *onevec = llvm_codegen(s, vec_t, number(1.));

auto tmp1 = llvm_fmul(s, ordvec, pc);
auto tmp2 = llvm_fmul(s, jvec, llvm_fadd(s, pc, onevec));
auto *tmp1 = llvm_fmul(s, ordvec, expo);
auto *tmp2 = llvm_fmul(s, jvec, llvm_fadd(s, expo, onevec));

return llvm_fsub(s, tmp1, tmp2);
}
}();
auto *scal_f = llvm_fsub(s, tmp1, tmp2);

// Add scal_f*v0*v1 to the sum.
sum.push_back(llvm_fmul(s, scal_f, llvm_fmul(s, v0, v1)));
Expand All @@ -554,14 +545,16 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
auto *ret_acc = pairwise_sum(s, sum);

// Compute the final divisor: order * (zero-th derivative of u_idx).
auto *ord_f = vector_splat(builder, llvm_codegen(s, fp_t, number(static_cast<double>(order))), batch_size);
auto *ord_f = llvm_codegen(s, vec_t, number(static_cast<double>(order)));
auto *b0 = taylor_fetch_diff(arr, u_idx, 0, n_uvars);
auto *div = llvm_fmul(s, ord_f, b0);

// Compute and return the result: ret_acc / div.
return llvm_fdiv(s, ret_acc, div);
}

// LCOV_EXCL_START

// All the other cases.
template <typename U1, typename U2, std::enable_if_t<!std::conjunction_v<is_num_param<U1>, is_num_param<U2>>, int> = 0>
llvm::Value *taylor_diff_pow_impl(llvm_state &, llvm::Type *, const pow_impl &, const U1 &, const U2 &,
Expand All @@ -572,19 +565,23 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &, llvm::Type *, const pow_impl &,
"An invalid argument type was encountered while trying to build the Taylor derivative of a pow()");
}

// LCOV_EXCL_STOP

llvm::Value *taylor_diff_pow(llvm_state &s, llvm::Type *fp_t, const pow_impl &f, const std::vector<std::uint32_t> &deps,
const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
{
assert(f.args().size() == 2u);

// LCOV_EXCL_START
if (!deps.empty()) {
throw std::invalid_argument(
fmt::format("An empty hidden dependency vector is expected in order to compute the Taylor "
"derivative of the exponentiation, but a vector of size {} was passed "
"instead",
deps.size()));
}
// LCOV_EXCL_STOP

return std::visit(
[&](const auto &v1, const auto &v2) {
Expand Down Expand Up @@ -898,7 +895,7 @@ llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &s, llvm::Type *fp_t, con
auto *ft = llvm::FunctionType::get(val_t, fargs, false);
// Create the function
f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &md);
assert(f != nullptr);
assert(f != nullptr); // LCOV_EXCL_LINE

// Fetch the necessary function arguments.
auto ord = f->args().begin();
Expand Down Expand Up @@ -969,6 +966,8 @@ llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &s, llvm::Type *fp_t, con
return f;
}

// LCOV_EXCL_START

// All the other cases.
template <typename U1, typename U2, std::enable_if_t<!std::conjunction_v<is_num_param<U1>, is_num_param<U2>>, int> = 0>
llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &, llvm::Type *, const pow_impl &, const U1 &, const U2 &,
Expand All @@ -978,6 +977,8 @@ llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &, llvm::Type *, const po
"of a pow() in compact mode");
}

// LCOV_EXCL_STOP

llvm::Function *taylor_c_diff_func_pow(llvm_state &s, llvm::Type *fp_t, const pow_impl &fn, std::uint32_t n_uvars,
std::uint32_t batch_size)
{
Expand Down
78 changes: 78 additions & 0 deletions src/taylor_01.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

#include <algorithm>
#include <cassert>
#include <concepts>
#include <cstdint>
#include <deque>
#include <exception>
#include <iterator>
#include <limits>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
Expand Down Expand Up @@ -46,6 +48,7 @@

#include <heyoka/config.hpp>
#include <heyoka/detail/cm_utils.hpp>
#include <heyoka/detail/func_cache.hpp>
#include <heyoka/detail/llvm_func_create.hpp>
#include <heyoka/detail/llvm_helpers.hpp>
#include <heyoka/detail/logging_impl.hpp>
Expand All @@ -54,7 +57,11 @@
#include <heyoka/detail/type_traits.hpp>
#include <heyoka/detail/visibility.hpp>
#include <heyoka/expression.hpp>
#include <heyoka/func.hpp>
#include <heyoka/llvm_state.hpp>
#include <heyoka/math/exp.hpp>
#include <heyoka/math/log.hpp>
#include <heyoka/math/pow.hpp>
#include <heyoka/math/prod.hpp>
#include <heyoka/math/sum.hpp>
#include <heyoka/number.hpp>
Expand Down Expand Up @@ -769,6 +776,74 @@ void taylor_decompose_replace_numbers(taylor_dc_t &dc, std::vector<expression>::
}
}

// NOLINTNEXTLINE(misc-no-recursion)
expression pow_to_explog(funcptr_map<expression> &func_map, const expression &ex)
{
return std::visit(
// NOLINTNEXTLINE(misc-no-recursion)
[&]<typename T>(const T &v) {
if constexpr (std::same_as<T, func>) {
const auto *f_id = v.get_ptr();

// Check if we already performed the transformation on ex.
if (const auto it = func_map.find(f_id); it != func_map.end()) {
return it->second;
}

// Perform the transformation on the function arguments.
std::vector<expression> new_args;
new_args.reserve(v.args().size());
for (const auto &orig_arg : v.args()) {
new_args.push_back(pow_to_explog(func_map, orig_arg));
}

// Prepare the return value.
std::optional<expression> retval;

if (v.template extract<detail::pow_impl>() != nullptr
&& !std::holds_alternative<number>(new_args[1].value())) {
// The function is a pow() and the exponent is not a number: transform x**y -> exp(y*log(x)).
//
// NOTE: do not call directly log(new_args[0]) in order to avoid constant folding when the base
// is a number. For instance, if we have pow(2_dbl, par[0]), then we would end up computing
// log(2) in double precision. This would result in an inaccurate result if the fp type
// or precision in use during integration is higher than double.
// NOTE: because the exponent is not a number, no other constant folding should take
// place here.
retval.emplace(exp(new_args[1] * expression{func{detail::log_impl(new_args[0])}}));
} else {
// Create a copy of v with the new arguments.
retval.emplace(v.copy(std::move(new_args)));
}

// Put the return value into the cache.
[[maybe_unused]] const auto [_, flag] = func_map.emplace(f_id, *retval);
// NOTE: an expression cannot contain itself.
assert(flag); // LCOV_EXCL_LINE

return std::move(*retval);
} else {
return ex;
}
},
ex.value());
}

// Helper to transform x**y -> exp(y*log(x)), if y is not a number.
std::vector<expression> pow_to_explog(const std::vector<expression> &v_ex)
{
funcptr_map<expression> func_map;

std::vector<expression> retval;
retval.reserve(v_ex.size());

for (const auto &e : v_ex) {
retval.push_back(pow_to_explog(func_map, e));
}

return retval;
}

} // namespace

} // namespace detail
Expand Down Expand Up @@ -798,6 +873,9 @@ taylor_decompose_sys(const std::vector<std::pair<expression, expression>> &sys_,
std::ranges::transform(sys_, std::back_inserter(all_ex), &std::pair<expression, expression>::second);
all_ex.insert(all_ex.end(), sv_funcs_.begin(), sv_funcs_.end());

// Transform x**y -> exp(y*log(x)), if y is not a number.
all_ex = detail::pow_to_explog(all_ex);

// Transform sums into subs.
all_ex = detail::sum_to_sub(all_ex);

Expand Down
16 changes: 16 additions & 0 deletions test/llvm_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3397,6 +3397,20 @@ TEST_CASE("is_finite scalar mp")

#endif

// NOTE: is_natural() appears not to be working on ppc, giving the following error:
//
// LLVM ERROR: Cannot select: 0x63b95812620: v1i128 = bitcast 0x63b957f9700
// 0x63b957f9700: f128,ch = load<(load (s128) from %ir.2 + 16, align 1)> 0x63b95acca30, 0x63b957f9690, undef:i64
// 0x63b957f9690: i64 = add nuw 0x63b957f9380, Constant:i64<16>
// 0x63b957f9380: i64,ch = CopyFromReg 0x63b95acca30, Register:i64 %1
// 0x63b957f9310: i64 = Register %1
// 0x63b957f9620: i64 = Constant<16>
// 0x63b957f9460: i64 = undef
// In function: hey_is_natural
//
// This seems like an instruction selection problem specific to the ppc backend.
#if !defined(HEYOKA_ARCH_PPC)

TEST_CASE("is_natural scalar")
{
using detail::llvm_is_natural;
Expand Down Expand Up @@ -3578,3 +3592,5 @@ TEST_CASE("is_natural scalar mp")
}

#endif

#endif
Loading
Loading