From b8a022fd5a0041be13d8e6bd7bc7195873d77c48 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Sat, 14 Sep 2024 11:06:57 +0200 Subject: [PATCH] Add new llvm primitives for trunc() and for testing finiteness/naturalness of FP values. --- include/heyoka/detail/llvm_helpers.hpp | 4 + include/heyoka/detail/real_helpers.hpp | 1 + src/detail/llvm_helpers.cpp | 63 +++ src/detail/real_helpers.cpp | 15 + src/detail/vector_math.cpp | 2 +- test/llvm_helpers.cpp | 510 +++++++++++++++++++++++++ 6 files changed, 594 insertions(+), 1 deletion(-) diff --git a/include/heyoka/detail/llvm_helpers.hpp b/include/heyoka/detail/llvm_helpers.hpp index cbc4e0395..8078d4d64 100644 --- a/include/heyoka/detail/llvm_helpers.hpp +++ b/include/heyoka/detail/llvm_helpers.hpp @@ -157,6 +157,10 @@ HEYOKA_DLL_PUBLIC llvm::Value *llvm_ui_to_fp(llvm_state &, llvm::Value *, llvm:: HEYOKA_DLL_PUBLIC llvm::Value *llvm_abs(llvm_state &, llvm::Value *); HEYOKA_DLL_PUBLIC llvm::Value *llvm_floor(llvm_state &, llvm::Value *); +HEYOKA_DLL_PUBLIC llvm::Value *llvm_trunc(llvm_state &, llvm::Value *); + +HEYOKA_DLL_PUBLIC llvm::Value *llvm_is_finite(llvm_state &, llvm::Value *); +HEYOKA_DLL_PUBLIC llvm::Value *llvm_is_natural(llvm_state &, llvm::Value *); HEYOKA_DLL_PUBLIC std::pair llvm_sincos(llvm_state &, llvm::Value *); HEYOKA_DLL_PUBLIC llvm::Value *llvm_atan2(llvm_state &, llvm::Value *, llvm::Value *); diff --git a/include/heyoka/detail/real_helpers.hpp b/include/heyoka/detail/real_helpers.hpp index c64a239de..990543578 100644 --- a/include/heyoka/detail/real_helpers.hpp +++ b/include/heyoka/detail/real_helpers.hpp @@ -44,6 +44,7 @@ llvm::Value *llvm_real_fcmp_one(llvm_state &, llvm::Value *, llvm::Value *); llvm::Value *llvm_real_fnz(llvm_state &, llvm::Value *); llvm::Value *llvm_real_ui_to_fp(llvm_state &, llvm::Value *, llvm::Type *); llvm::Value *llvm_real_sgn(llvm_state &, llvm::Value *); +llvm::Value *llvm_real_isfinite(llvm_state &, llvm::Value *); HEYOKA_DLL_PUBLIC mppp::real 
eps_from_prec(mpfr_prec_t); diff --git a/src/detail/llvm_helpers.cpp b/src/detail/llvm_helpers.cpp index bf81c3e93..f808f1e2e 100644 --- a/src/detail/llvm_helpers.cpp +++ b/src/detail/llvm_helpers.cpp @@ -2347,6 +2347,69 @@ llvm::Value *llvm_floor(llvm_state &s, llvm::Value *x) x); } +// Trunc. +llvm::Value *llvm_trunc(llvm_state &s, llvm::Value *x) +{ + return llvm_math_intr(s, "llvm.trunc", +#if defined(HEYOKA_HAVE_REAL128) + "truncq", +#endif +#if defined(HEYOKA_HAVE_REAL) + "mpfr_trunc", +#endif + x); +} + +// is_finite(). +llvm::Value *llvm_is_finite(llvm_state &s, llvm::Value *x) +{ + assert(x != nullptr); + + auto *x_t = x->getType(); + + if (x_t->getScalarType()->isFloatingPointTy()) { + // Codegen +- inf. + auto *pinf = llvm_codegen(s, x_t, number{std::numeric_limits::infinity()}); + auto *minf = llvm_codegen(s, x_t, number{-std::numeric_limits::infinity()}); + + // Check that x is not +- inf or NaN. + auto *x_not_pinf = llvm_fcmp_one(s, x, pinf); + auto *x_not_minf = llvm_fcmp_one(s, x, minf); + + // Put the conditions together and return. + return s.builder().CreateLogicalAnd(x_not_pinf, x_not_minf); +#if defined(HEYOKA_HAVE_REAL) + } else if (llvm_is_real(x_t) != 0) { + return llvm_real_isfinite(s, x); +#endif + // LCOV_EXCL_START + } else [[unlikely]] { + throw std::invalid_argument(fmt::format( + "Invalid type '{}' encountered in the LLVM implementation of is_finite()", llvm_type_name(x_t))); + } + // LCOV_EXCL_STOP +} + +// Check if the input floating-point value is a natural number. +llvm::Value *llvm_is_natural(llvm_state &s, llvm::Value *x) +{ + // Is x finite? + auto *x_finite = llvm_is_finite(s, x); + + // Is x>=0? + auto *x_ge_0 = llvm_fcmp_oge(s, x, llvm_codegen(s, x->getType(), number{0.})); + + // Is x an integral value? + auto *x_int = llvm_fcmp_oeq(s, x, llvm_trunc(s, x)); + + // Put the conditions together and return. 
+ auto &bld = s.builder(); + auto *ret = bld.CreateLogicalAnd(x_finite, x_ge_0); + ret = bld.CreateLogicalAnd(ret, x_int); + + return ret; +} + // Add a function to count the number of sign changes in the coefficients // of a polynomial of degree n. The coefficients are SIMD vectors of size batch_size // and scalar type scal_t. diff --git a/src/detail/real_helpers.cpp b/src/detail/real_helpers.cpp index 100684190..e290de060 100644 --- a/src/detail/real_helpers.cpp +++ b/src/detail/real_helpers.cpp @@ -611,6 +611,21 @@ llvm::Value *llvm_real_fnz(llvm_state &s, llvm::Value *x) return builder.CreateICmpEQ(ret, llvm::ConstantInt::getNullValue(ret->getType())); } +llvm::Value *llvm_real_isfinite(llvm_state &s, llvm::Value *x) +{ + // LCOV_EXCL_START + assert(x != nullptr); + // LCOV_EXCL_STOP + + auto &bld = s.builder(); + + // Check if x is an ordinary number. + auto *f = real_nary_cmp(s, x->getType(), "mpfr_number_p", 1u); + auto *ret = bld.CreateCall(f, x); + + return ret; +} + // Convert the input unsigned integral value n to the real type fp_t. llvm::Value *llvm_real_ui_to_fp(llvm_state &s, llvm::Value *n, llvm::Type *fp_t) { diff --git a/src/detail/vector_math.cpp b/src/detail/vector_math.cpp index 12405fab1..308968e62 100644 --- a/src/detail/vector_math.cpp +++ b/src/detail/vector_math.cpp @@ -126,7 +126,7 @@ auto make_vf_map() // by sleef, on the assumption that usually sqrt() is implemented directly in hardware // and thus there's no need to go through sleef. This is certainly true for x86, // but I am not 100% sure for the other archs. Let's keep this in mind. - // NOTE: the same holds for things like abs() and floor(). + // NOTE: the same holds for things like abs(), floor(), trunc(), etc. // Single-precision. 
add_vfinfo_sleef(retval, "llvm.sin.f32", "sin", "f"); diff --git a/test/llvm_helpers.cpp b/test/llvm_helpers.cpp index b96833ac9..ee7f06904 100644 --- a/test/llvm_helpers.cpp +++ b/test/llvm_helpers.cpp @@ -3068,3 +3068,513 @@ TEST_CASE("clone type") #endif } + +TEST_CASE("trunc scalar") +{ + using detail::llvm_trunc; + using detail::to_external_llvm_type; + + auto tester = [](auto fp_x) { + using fp_t = decltype(fp_x); + + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + auto &md = s.module(); + auto &builder = s.builder(); + auto &context = s.context(); + + auto val_t = to_external_llvm_type(context); + + std::vector fargs{val_t}; + auto *ft = llvm::FunctionType::get(val_t, fargs, false); + auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "hey_trunc", &md); + + auto x = f->args().begin(); + + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f)); + + builder.CreateRet(llvm_trunc(s, x)); + + // Compile. + s.compile(); + + // Fetch the function pointer. 
+ auto f_ptr = reinterpret_cast(s.jit_lookup("hey_trunc")); + + REQUIRE(f_ptr(fp_t(2) / 7) == 0); + REQUIRE(f_ptr(fp_t(8) / 7) == 1); + REQUIRE(f_ptr(fp_t(-2) / 7) == 0); + REQUIRE(f_ptr(fp_t(-8) / 7) == -1); + } + }; + + tuple_for_each(fp_types, tester); +} + +TEST_CASE("trunc batch") +{ + using detail::llvm_trunc; + using detail::to_external_llvm_type; + using std::trunc; + + auto tester = [](auto fp_x) { + using fp_t = decltype(fp_x); + + for (auto batch_size : {1u, 2u, 4u, 13u}) { + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + auto &md = s.module(); + auto &builder = s.builder(); + auto &context = s.context(); + + auto val_t = to_external_llvm_type(context); + + std::vector fargs(2u, llvm::PointerType::getUnqual(val_t)); + auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); + auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "hey_trunc", &md); + + auto ret_ptr = f->args().begin(); + auto x_ptr = f->args().begin() + 1; + + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f)); + + auto ret = llvm_trunc(s, detail::load_vector_from_memory(builder, val_t, x_ptr, batch_size)); + + detail::store_vector_to_memory(builder, ret_ptr, ret); + + builder.CreateRetVoid(); + + // Compile. + s.compile(); + + // Fetch the function pointer. + auto f_ptr = reinterpret_cast(s.jit_lookup("hey_trunc")); + + // Setup the argument and the output value. + std::vector ret_vec(batch_size), a_vec(ret_vec); + for (auto i = 0u; i < batch_size; ++i) { + a_vec[i] = fp_t(i + 1u) / 3 * (i % 2u == 0 ? 
1 : -1); + } + + f_ptr(ret_vec.data(), a_vec.data()); + + for (auto i = 0u; i < batch_size; ++i) { + REQUIRE(ret_vec[i] == trunc(a_vec[i])); + } + } + } + }; + + tuple_for_each(fp_types, tester); +} + +#if defined(HEYOKA_HAVE_REAL) + +TEST_CASE("trunc scalar mp") +{ + using detail::llvm_trunc; + using detail::to_external_llvm_type; + + using fp_t = mppp::real; + + const auto prec = 237u; + + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + auto &md = s.module(); + auto &builder = s.builder(); + auto &context = s.context(); + + auto *val_t = to_external_llvm_type(context); + auto *real_t = detail::internal_llvm_type_like(s, mppp::real{0, prec}); + + auto *ft = llvm::FunctionType::get( + builder.getVoidTy(), {llvm::PointerType::getUnqual(val_t), llvm::PointerType::getUnqual(val_t)}, false); + auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "trunc", &md); + + auto *ret_ptr = f->args().begin(); + auto *in_ptr = f->args().begin() + 1; + + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f)); + + auto *x = detail::ext_load_vector_from_memory(s, real_t, in_ptr, 1u); + + // Compute the result. + auto *ret = llvm_trunc(s, x); + + // Store it. + detail::ext_store_vector_to_memory(s, ret_ptr, ret); + + // Return. + builder.CreateRetVoid(); + + // Compile. + s.compile(); + + // Fetch the function pointer. 
+ auto f_ptr = reinterpret_cast(s.jit_lookup("trunc")); + + mppp::real arg{1.5, prec}, out = arg; + f_ptr(&out, &arg); + REQUIRE(out == 1); + arg.set(-1.1); + f_ptr(&out, &arg); + REQUIRE(out == -1); + } +} + +#endif + +TEST_CASE("is_finite scalar") +{ + using detail::llvm_is_finite; + using detail::to_external_llvm_type; + + auto tester = [](auto fp_x) { + using fp_t = decltype(fp_x); + + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + auto &md = s.module(); + auto &builder = s.builder(); + auto &context = s.context(); + + auto val_t = to_external_llvm_type(context); + + std::vector fargs{val_t}; + auto *ft = llvm::FunctionType::get(builder.getInt32Ty(), fargs, false); + auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "hey_is_finite", &md); + + auto x = f->args().begin(); + + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f)); + + builder.CreateRet(builder.CreateZExt(llvm_is_finite(s, x), builder.getInt32Ty())); + + // Compile. + s.compile(); + + // Fetch the function pointer. 
+ auto f_ptr = reinterpret_cast(s.jit_lookup("hey_is_finite")); + + REQUIRE(f_ptr(fp_t(123)) == 1); + REQUIRE(f_ptr(fp_t(-123)) == 1); + REQUIRE(f_ptr(std::numeric_limits::infinity()) == 0); + REQUIRE(f_ptr(-std::numeric_limits::infinity()) == 0); + REQUIRE(f_ptr(-std::numeric_limits::quiet_NaN()) == 0); + } + }; + + tuple_for_each(fp_types, tester); +} + +TEST_CASE("is_finite batch") +{ + using detail::llvm_is_finite; + using detail::to_external_llvm_type; + using std::isfinite; + + auto tester = [](auto fp_x) { + using fp_t = decltype(fp_x); + + for (auto batch_size : {1u, 2u, 4u, 13u}) { + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + auto &md = s.module(); + auto &builder = s.builder(); + auto &context = s.context(); + + auto val_t = to_external_llvm_type(context); + + std::vector fargs(2u, llvm::PointerType::getUnqual(context)); + auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); + auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "hey_is_finite", &md); + + auto ret_ptr = f->args().begin(); + auto x_ptr = f->args().begin() + 1; + + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f)); + + auto *ret = llvm_is_finite(s, detail::load_vector_from_memory(builder, val_t, x_ptr, batch_size)); + ret = builder.CreateZExt(ret, detail::make_vector_type(builder.getInt32Ty(), batch_size)); + + detail::store_vector_to_memory(builder, ret_ptr, ret); + + builder.CreateRetVoid(); + + // Compile. + s.compile(); + + // Fetch the function pointer. + auto f_ptr = reinterpret_cast(s.jit_lookup("hey_is_finite")); + + // Setup the argument and the output value. 
+ std::vector ret_vec(batch_size); + std::vector a_vec(batch_size); + for (auto i = 0u; i < batch_size; ++i) { + if (i == 0u) { + a_vec[i] = std::numeric_limits::quiet_NaN(); + } else if (i == 1u) { + a_vec[i] = -std::numeric_limits::infinity(); + } else { + a_vec[i] = static_cast(i); + } + } + + f_ptr(ret_vec.data(), a_vec.data()); + + for (auto i = 0u; i < batch_size; ++i) { + if (i == 0u || i == 1u) { + REQUIRE(ret_vec[i] == 0u); + } else { + REQUIRE(ret_vec[i] == 1u); + } + } + } + } + }; + + tuple_for_each(fp_types, tester); +} + +#if defined(HEYOKA_HAVE_REAL) + +TEST_CASE("is_finite scalar mp") +{ + using detail::llvm_is_finite; + using detail::to_external_llvm_type; + + using fp_t = mppp::real; + + const auto prec = 237u; + + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + auto &md = s.module(); + auto &builder = s.builder(); + auto &context = s.context(); + + auto *val_t = to_external_llvm_type(context); + auto *real_t = detail::internal_llvm_type_like(s, mppp::real{0, prec}); + + auto *ft = llvm::FunctionType::get(builder.getInt32Ty(), {llvm::PointerType::getUnqual(val_t)}, false); + auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "is_finite", &md); + + auto *in_ptr = f->args().begin(); + + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f)); + + auto *x = detail::ext_load_vector_from_memory(s, real_t, in_ptr, 1u); + + // Compute the result. + auto *ret = llvm_is_finite(s, x); + + // Return it. + builder.CreateRet(builder.CreateZExt(ret, builder.getInt32Ty())); + + // Compile. + s.compile(); + + // Fetch the function pointer. 
+ auto f_ptr = reinterpret_cast(s.jit_lookup("is_finite")); + + mppp::real arg{1.5, prec}; + REQUIRE(f_ptr(&arg) == 1u); + arg.set(std::numeric_limits::infinity()); + REQUIRE(f_ptr(&arg) == 0u); + arg.set(std::numeric_limits::quiet_NaN()); + REQUIRE(f_ptr(&arg) == 0u); + arg.set(-123); + REQUIRE(f_ptr(&arg) == 1u); + } +} + +#endif + +TEST_CASE("is_natural scalar") +{ + using detail::llvm_is_natural; + using detail::to_external_llvm_type; + + auto tester = [](auto fp_x) { + using fp_t = decltype(fp_x); + + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + auto &md = s.module(); + auto &builder = s.builder(); + auto &context = s.context(); + + auto val_t = to_external_llvm_type(context); + + std::vector fargs{val_t}; + auto *ft = llvm::FunctionType::get(builder.getInt32Ty(), fargs, false); + auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "hey_is_natural", &md); + + auto x = f->args().begin(); + + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f)); + + builder.CreateRet(builder.CreateZExt(llvm_is_natural(s, x), builder.getInt32Ty())); + + // Compile. + s.compile(); + + // Fetch the function pointer. 
+ auto f_ptr = reinterpret_cast(s.jit_lookup("hey_is_natural")); + + REQUIRE(f_ptr(fp_t(123.1)) == 0); + REQUIRE(f_ptr(fp_t(123)) == 1); + REQUIRE(f_ptr(fp_t(-123)) == 0); + REQUIRE(f_ptr(std::numeric_limits::infinity()) == 0); + REQUIRE(f_ptr(-std::numeric_limits::infinity()) == 0); + REQUIRE(f_ptr(-std::numeric_limits::quiet_NaN()) == 0); + } + }; + + tuple_for_each(fp_types, tester); +} + +TEST_CASE("is_natural batch") +{ + using detail::llvm_is_natural; + using detail::to_external_llvm_type; + using std::isfinite; + + auto tester = [](auto fp_x) { + using fp_t = decltype(fp_x); + + for (auto batch_size : {1u, 2u, 4u, 13u}) { + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + auto &md = s.module(); + auto &builder = s.builder(); + auto &context = s.context(); + + auto val_t = to_external_llvm_type(context); + + std::vector fargs(2u, llvm::PointerType::getUnqual(context)); + auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); + auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "hey_is_natural", &md); + + auto ret_ptr = f->args().begin(); + auto x_ptr = f->args().begin() + 1; + + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f)); + + auto *ret = llvm_is_natural(s, detail::load_vector_from_memory(builder, val_t, x_ptr, batch_size)); + ret = builder.CreateZExt(ret, detail::make_vector_type(builder.getInt32Ty(), batch_size)); + + detail::store_vector_to_memory(builder, ret_ptr, ret); + + builder.CreateRetVoid(); + + // Compile. + s.compile(); + + // Fetch the function pointer. + auto f_ptr = reinterpret_cast(s.jit_lookup("hey_is_natural")); + + // Setup the argument and the output value. 
+ std::vector ret_vec(batch_size); + std::vector a_vec(batch_size); + for (auto i = 0u; i < batch_size; ++i) { + if (i == 0u) { + a_vec[i] = std::numeric_limits::quiet_NaN(); + } else if (i == 1u) { + a_vec[i] = -std::numeric_limits::infinity(); + } else if (i == 2u) { + a_vec[i] = static_cast(i) + static_cast(.1); + } else if (i == 3u) { + a_vec[i] = -static_cast(i); + } else { + a_vec[i] = static_cast(i); + } + } + + f_ptr(ret_vec.data(), a_vec.data()); + + for (auto i = 0u; i < batch_size; ++i) { + if (i <= 3u) { + REQUIRE(ret_vec[i] == 0u); + } else { + REQUIRE(ret_vec[i] == 1u); + } + } + } + } + }; + + tuple_for_each(fp_types, tester); +} + +#if defined(HEYOKA_HAVE_REAL) + +TEST_CASE("is_natural scalar mp") +{ + using detail::llvm_is_natural; + using detail::to_external_llvm_type; + + using fp_t = mppp::real; + + const auto prec = 237u; + + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + auto &md = s.module(); + auto &builder = s.builder(); + auto &context = s.context(); + + auto *val_t = to_external_llvm_type(context); + auto *real_t = detail::internal_llvm_type_like(s, mppp::real{0, prec}); + + auto *ft = llvm::FunctionType::get(builder.getInt32Ty(), {llvm::PointerType::getUnqual(val_t)}, false); + auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "is_natural", &md); + + auto *in_ptr = f->args().begin(); + + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f)); + + auto *x = detail::ext_load_vector_from_memory(s, real_t, in_ptr, 1u); + + // Compute the result. + auto *ret = llvm_is_natural(s, x); + + // Return it. + builder.CreateRet(builder.CreateZExt(ret, builder.getInt32Ty())); + + // Compile. + s.compile(); + + // Fetch the function pointer. 
+ auto f_ptr = reinterpret_cast(s.jit_lookup("is_natural")); + + mppp::real arg{1.5, prec}; + REQUIRE(f_ptr(&arg) == 0u); + arg.set(-1); + REQUIRE(f_ptr(&arg) == 0u); + arg.set(123); + REQUIRE(f_ptr(&arg) == 1u); + arg.set(123.1); + REQUIRE(f_ptr(&arg) == 0u); + arg.set(-123.1); + REQUIRE(f_ptr(&arg) == 0u); + arg.set(std::numeric_limits::infinity()); + REQUIRE(f_ptr(&arg) == 0u); + arg.set(std::numeric_limits::quiet_NaN()); + REQUIRE(f_ptr(&arg) == 0u); + arg.set(-123); + REQUIRE(f_ptr(&arg) == 0u); + } +} + +#endif