From 81157ba33279cb4b091cbb92a937f63fd07d301f Mon Sep 17 00:00:00 2001 From: Harald van Dijk Date: Mon, 18 Dec 2023 17:31:54 +0000 Subject: [PATCH] Use @llvm.sqrt.* intrincs rather than rolling our own. All targets that we currently support have hardware floating point square root instructions that are more accurate than our own implementation. --- .../internal/check_surrounding_values.h | 71 ------ .../abacus/include/abacus/internal/sqrt.h | 29 +++ .../include/abacus/internal/sqrt_unsafe.h | 156 ------------- .../builtins/abacus/source/CMakeLists.txt | 87 +++++-- .../abacus/source/abacus_math/CMakeLists.txt | 7 +- .../abacus/source/abacus_math/acos.cpp | 8 +- .../abacus/source/abacus_math/acosh.cpp | 15 +- .../abacus/source/abacus_math/acospi.cpp | 8 +- .../abacus/source/abacus_math/asin.cpp | 12 +- .../abacus/source/abacus_math/asinpi.cpp | 7 +- .../abacus/source/abacus_math/half_sqrt.cpp | 67 +----- .../abacus/source/abacus_math/hypot.cpp | 4 +- .../source/abacus_math/inplace_sqrt.cpp | 58 +++++ .../source/abacus_math/inplace_sqrt.ll.in | 182 +++++++++++++++ .../abacus/source/abacus_math/sqrt.cpp | 213 +++++++----------- 15 files changed, 455 insertions(+), 469 deletions(-) delete mode 100644 modules/compiler/builtins/abacus/include/abacus/internal/check_surrounding_values.h create mode 100644 modules/compiler/builtins/abacus/include/abacus/internal/sqrt.h delete mode 100644 modules/compiler/builtins/abacus/include/abacus/internal/sqrt_unsafe.h create mode 100644 modules/compiler/builtins/abacus/source/abacus_math/inplace_sqrt.cpp create mode 100644 modules/compiler/builtins/abacus/source/abacus_math/inplace_sqrt.ll.in diff --git a/modules/compiler/builtins/abacus/include/abacus/internal/check_surrounding_values.h b/modules/compiler/builtins/abacus/include/abacus/internal/check_surrounding_values.h deleted file mode 100644 index 4aa0bdda2..000000000 --- a/modules/compiler/builtins/abacus/include/abacus/internal/check_surrounding_values.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (C) Codeplay Software Limited -// -// Licensed under the Apache License, Version 2.0 (the "License") with LLVM -// Exceptions; you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. -// -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef __ABACUS_CHECK_SURROUNDING_VALUES_H__ -#define __ABACUS_CHECK_SURROUNDING_VALUES_H__ - -#include -#include -#include -#include -#include - -namespace abacus { -namespace internal { -// Checks the 3 values surrounding sqrt_estimate and returns the one -// mathematically closest to sqrt(input) -template -inline T check_surrounding_values(const T &input, const T &sqrt_estimate) { - typedef typename TypeTraits::SignedType SignedType; - typedef typename TypeTraits::UnsignedType UnsignedType; - // We now need to check if the values on either side of sqrt_value are better - // approximations of sqrt(input) seems to be up to one bit off - const UnsignedType sqrtAs = - abacus::detail::cast::as(sqrt_estimate); - const T sqrt_value_lo = abacus::detail::cast::as(sqrtAs - 1); - const T sqrt_value_hi = abacus::detail::cast::as(sqrtAs + 1); - - // Calculate 0.5*(sqrt_estimate^2 - input) very accurately for each of the 3 - // possibilities: - T can1sq_lo; - const T can1sq_hi = abacus::internal::multiply_exact_unsafe( - sqrt_value_lo, sqrt_value_lo, &can1sq_lo); - const T term1 = ((can1sq_hi - input) + can1sq_lo) * 0.5; - - T can2sq_lo; - const T can2sq_hi = abacus::internal::multiply_exact_unsafe( - sqrt_estimate, sqrt_estimate, &can2sq_lo); - const T term2 = ((can2sq_hi - input) + can2sq_lo) * 0.5; - - T can3sq_lo; - const T can3sq_hi = abacus::internal::multiply_exact_unsafe( - sqrt_value_hi, sqrt_value_hi, &can3sq_lo); - const T term3 = ((can3sq_hi - input) + can3sq_lo) * 0.5; - - T result = sqrt_estimate; - - const SignedType c1 = term2 + term3 < input - sqrt_value_hi * sqrt_estimate; - result = __abacus_select(result, sqrt_value_hi, c1); - - const SignedType c2 = term1 + term2 >= input - sqrt_value_lo * sqrt_estimate; - result = __abacus_select(result, sqrt_value_lo, c2); - - return result; -} -} // namespace internal -} // namespace abacus - -#endif //__ABACUS_CHECK_SURROUNDING_VALUES_H__ diff --git a/modules/compiler/builtins/abacus/include/abacus/internal/sqrt.h b/modules/compiler/builtins/abacus/include/abacus/internal/sqrt.h new file mode 100644 index 000000000..e4865ec27 --- /dev/null +++ b/modules/compiler/builtins/abacus/include/abacus/internal/sqrt.h @@ -0,0 +1,29 @@ +// Copyright (C) Codeplay Software Limited +// +// Licensed under the Apache License, Version 2.0 (the "License") with LLVM +// Exceptions; you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef __ABACUS_INTERNAL_SQRT_H__ +#define __ABACUS_INTERNAL_SQRT_H__ + +#include + +namespace abacus { +namespace internal { +template +T sqrt(T); +} // namespace internal +} // namespace abacus + +#endif //__ABACUS_INTERNAL_SQRT_H__ diff --git a/modules/compiler/builtins/abacus/include/abacus/internal/sqrt_unsafe.h b/modules/compiler/builtins/abacus/include/abacus/internal/sqrt_unsafe.h deleted file mode 100644 index ef2fa466a..000000000 --- a/modules/compiler/builtins/abacus/include/abacus/internal/sqrt_unsafe.h +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (C) Codeplay Software Limited -// -// Licensed under the Apache License, Version 2.0 (the "License") with LLVM -// Exceptions; you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. -// -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef __ABACUS_INTERNAL_SQRT_UNSAFE_H__ -#define __ABACUS_INTERNAL_SQRT_UNSAFE_H__ - -#include -#include -#include -#include -#include -#include - -namespace abacus { -namespace internal { -template -struct sqrt_unsafe_helper; - -#ifdef __CA_BUILTINS_HALF_SUPPORT -template <> -struct sqrt_unsafe_helper { - template - static T _(const T &x) { - typedef typename TypeTraits::LargerType LargerType; - typedef typename TypeTraits::SignedType SignedType; - - // The following algorithm is more complex than it may appear on the - // surface. - // In essence there is a rather famous floating point bithack to get a - // surprisingly good approximation to 1/sqrt(x), and due to the nature of - // 1/sqrt(x) there also exists a very computer friendly way of computing - // extra bits of precision off that initial approximation, via - // Newton-Rhapson iteration. (a way of computing roots of polynomials) - // So we use this floating point hack to get a good initial guess, and then - // do several Newton-Rhapson iterations. Finally because we want sqrt(x) - // as opposed to 1/sqrt(x) at the end we just multiply by x. - // - // Because sqrt(x) is expected to be corectly rounded (aka return the exact - // answer as closly as possible), for the last iteration and subsequent - // multiplication by x we do it in 32 bit precision. You can do it in 16 bit - // but you're just faking 32 bit simulation, which though untested can only - // be slower - - // See more information on this algorithm at - // https://en.wikipedia.org/wiki/Fast_inverse_square_root, - // and a rather excellent derivation/discussion at - // http://h14s.p5r.org/2012/09/0x5f3759df.html?mwh=1 - - T result = rsqrt_initial_guess(x); - - // Newton-Raphson Method times 2 - // Approximate 1/sqrt(x) - result = 0.5f16 * result * (3.0f16 - result * (result * x)); - result = 0.5f16 * result * (3.0f16 - (result * result) * x); - - // Do one iteration in 32-bit precision as we need 0-ulp for this function: - LargerType result_f = abacus::detail::cast::convert(result); - LargerType x_f = abacus::detail::cast::convert(x); - - result_f = LargerType(0.5f) * result_f * - (LargerType(3.0f) - (result_f * result_f) * x_f); - - // 1/sqrt(x) -> sqrt(x) and convert back - result = abacus::detail::cast::convert(result_f * x_f); - - result = __abacus_select( - result, x, abacus::detail::cast::convert(x == 0.0f16)); - - // NAN returns: - result = - __abacus_select(result, FPShape::NaN(), - abacus::detail::cast::convert(x < 0.0f16)); - return result; - } -}; -#endif //__CA_BUILTINS_HALF_SUPPORT - -template <> -struct sqrt_unsafe_helper { - template - static T _(const T &x) { - T result = rsqrt_initial_guess(x); - - // Newton-Raphson Method times 3 - // Approximate 1/sqrt(x) - result = (T)0.5f * result * ((T)3.0f - result * (result * x)); - result = (T)0.5f * result * ((T)3.0f - result * (result * x)); - - // 1/sqrt(x) -> sqrt(x) - // This is rolled into the final Newton-Raphson iteration so it saves us a - // multiplication, and improves accuracy. - T rx = result * x; - result = (T)0.5f * rx * ((T)3.0f - result * rx); - - result = __abacus_select(result, x, x == 0.0f); - - // This selects NAN when x is negative (and non-zero) or NAN. - return __abacus_select(ABACUS_NAN, result, x >= 0.0f); - } -}; - -#ifdef __CA_BUILTINS_DOUBLE_SUPPORT -template <> -struct sqrt_unsafe_helper { - template - static T _(const T &x) { - T rsqrt_value = rsqrt_initial_guess(x); - - // Newton Rhapson iterations: - rsqrt_value = - (T)0.5 * rsqrt_value * ((T)3.0 - rsqrt_value * (rsqrt_value * x)); - rsqrt_value = - (T)0.5 * rsqrt_value * ((T)3.0 - rsqrt_value * (rsqrt_value * x)); - rsqrt_value = - (T)0.5 * rsqrt_value * ((T)3.0 - rsqrt_value * (rsqrt_value * x)); - rsqrt_value = - (T)0.5 * rsqrt_value * ((T)3.0 - rsqrt_value * (rsqrt_value * x)); - - // Todo. Maybe just do this exactly? - // We calculate the square root from the inverse square root by multiplying - // by `x` so we might as well absorb that into the final Newton Raphson - // iteration and pull it out as a common subexpression. Not only does this - // save us a multiplication, it gives us a little more numerical accuracy. - // Sadly it is still not quite accurate enough for the 0 ULPs we require, - // but it does get us to within 1 ULP. - T rx = rsqrt_value * x; - T sqrt_value = (T)0.5 * rx * ((T)3.0 - rsqrt_value * rx); - - T best_value = check_surrounding_values(x, sqrt_value); - - return best_value; - } -}; -#endif // __CA_BUILTINS_DOUBLE_SUPPORT - -template -inline T sqrt_unsafe(const T &x) { - typedef typename TypeTraits::ElementType ElementType; - return sqrt_unsafe_helper::_(x); -} -} // namespace internal -} // namespace abacus -#endif //__ABACUS_INTERNAL_SQRT_UNSAFE_H__ diff --git a/modules/compiler/builtins/abacus/source/CMakeLists.txt b/modules/compiler/builtins/abacus/source/CMakeLists.txt index 907b71f5f..56546f4a5 100644 --- a/modules/compiler/builtins/abacus/source/CMakeLists.txt +++ b/modules/compiler/builtins/abacus/source/CMakeLists.txt @@ -33,12 +33,16 @@ add_subdirectory(abacus_memory) add_subdirectory(abacus_misc) add_subdirectory(abacus_relational) -set(abacus_sources +set(abacus_sources_host ${abacus_cast_sources} ${abacus_common_sources} ${abacus_extra_sources} - ${abacus_geometric_sources} ${abacus_integer_sources} ${abacus_math_sources} + ${abacus_geometric_sources} ${abacus_integer_sources} ${abacus_math_sources_host} + ${abacus_memory_sources} ${abacus_misc_sources} ${abacus_relational_sources}) +set(abacus_sources_device + ${abacus_cast_sources} ${abacus_common_sources} ${abacus_extra_sources} + ${abacus_geometric_sources} ${abacus_integer_sources} ${abacus_math_sources_device} ${abacus_memory_sources} ${abacus_misc_sources} ${abacus_relational_sources}) -add_library(abacus_static STATIC ${abacus_sources}) +add_library(abacus_static STATIC ${abacus_sources_host}) add_dependencies(abacus_static abacus_generate) target_compile_definitions(abacus_static PRIVATE "ABACUS_ENABLE_OPENCL_1_2_BUILTINS" @@ -46,7 +50,7 @@ target_compile_definitions(abacus_static PRIVATE # If extra ComputeAorta commands exist, use them. if(COMMAND add_ca_tidy) - add_ca_tidy(abacus_static ${abacus_sources}) + add_ca_tidy(abacus_static ${abacus_sources_host}) if(TARGET tidy-abacus_static) add_dependencies(tidy-abacus_static abacus_generate) endif() @@ -157,15 +161,32 @@ if(${ABACUS_BUILD_WITH_RUNTIME_TOOLS}) set(cap_suf "${cap_suf}_fp16") endif() - foreach(SOURCE IN LISTS abacus_sources) - string(LENGTH "${CMAKE_CURRENT_SOURCE_DIR}/" LENGTH) - string(SUBSTRING "${SOURCE}" ${LENGTH} -1 SUB_SOURCE) - string(REGEX REPLACE "^.*\\.\(c[lp]*\)$" "\\1" SOURCE_TYPE "${SOURCE}") + foreach(SOURCE IN LISTS abacus_sources_device) + file(RELATIVE_PATH RELSOURCE "${CMAKE_CURRENT_SOURCE_DIR}" "${SOURCE}") + get_filename_component(SOURCE_DIR "${RELSOURCE}" DIRECTORY) + get_filename_component(SOURCE_NAME "${SOURCE}" NAME) + get_filename_component(SOURCE_TYPE "${SOURCE}" LAST_EXT) + if(SOURCE_TYPE STREQUAL ".in") + if("${triple}" STREQUAL spir64-unknown-unknown) + set(target "\ +target triple = \"spir64-unknown-unknown\"\n\ +target datalayout = \"e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024\"") + elseif("${triple}" STREQUAL spir-unknown-unknown) + set(target "\ +target triple = \"spir-unknown-unknown\"\n\ +target datalayout = \"e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024\"") + else() + message(FATAL_ERROR "Missing target definition for target '${triple}'") + endif() + get_filename_component(SOURCE_NAME "${SOURCE}" NAME_WLE) + configure_file(${SOURCE} "${CMAKE_CURRENT_BINARY_DIR}/bc/${triple}${cap_suf}/${SOURCE_DIR}/${SOURCE_NAME}" @ONLY) + set(SOURCE "${CMAKE_CURRENT_BINARY_DIR}/bc/${triple}${cap_suf}/${SOURCE_DIR}/${SOURCE_NAME}") + get_filename_component(SOURCE_TYPE "${SOURCE}" LAST_EXT) + endif() - set(XTYPE "${SOURCE_TYPE}") # Language for -x option. + set(XTYPE "") # Language for -x option. set(XOPTS "") # Extra language specific options. - set(OUTPUT - "${CMAKE_CURRENT_BINARY_DIR}/bc/${triple}${cap_suf}/${SUB_SOURCE}.bc") + set(OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/bc/${triple}${cap_suf}/${SOURCE_DIR}/${SOURCE_NAME}.bc") # The 'Unix Makefiles' and 'MinGW Makefiles' generators do not # automatically create the output directories the bc files are output to, @@ -176,28 +197,46 @@ if(${ABACUS_BUILD_WITH_RUNTIME_TOOLS}) file(MAKE_DIRECTORY ${OUTPUT_DIR}) endif() - if("${SOURCE_TYPE}" STREQUAL "cpp") + set(DEPFILE_ARGS -dependency-file "${OUTPUT}.d" -MT "${OUTPUT}" -sys-header-deps) + if(SOURCE_TYPE STREQUAL ".cpp") set(XTYPE "c++") set(XOPTS "-std=c++11") - elseif("${SOURCE_TYPE}" STREQUAL "cl") + elseif(SOURCE_TYPE STREQUAL ".cl") + set(XTYPE "cl") set(XOPTS "-cl-std=CL1.2") + elseif(SOURCE_TYPE STREQUAL ".ll") + set(XTYPE "ir") + set(DEPFILE_ARGS) + else() + message(FATAL_ERROR "Missing handling for source type '${SOURCE_TYPE}'") endif() # This is required to correctly expose all the builtins since some of their # defintions come from this file. set(ENABLE_OPENCL_BUILTINS "-DABACUS_ENABLE_OPENCL_1_2_BUILTINS -DABACUS_ENABLE_OPENCL_3_0_BUILTINS") - add_custom_command( - OUTPUT "${OUTPUT}" - COMMAND ${RUNTIME_COMPILER} -cc1 -x ${XTYPE} ${XOPTS} -triple ${triple} - ${ABACUS_RUNTIME_OPTIONS} ${BUILTINS_EXTRA_OPTIONS} ${ENABLE_OPENCL_BUILTINS} - -include "${RUNTIME_CLHEADER}" - -emit-llvm-bc -o "${OUTPUT}" - -dependency-file "${OUTPUT}.d" -MT "${OUTPUT}" -sys-header-deps - "${SOURCE}" - DEPENDS "${SOURCE}" "${RUNTIME_COMPILER}" "${RUNTIME_CLHEADER}" - abacus_generate ${ABACUS_GENERATED_FILES} - DEPFILE "${OUTPUT}.d") + if(DEFINED DEPFILE_ARGS) + add_custom_command( + OUTPUT "${OUTPUT}" + COMMAND ${RUNTIME_COMPILER} -cc1 -x ${XTYPE} ${XOPTS} -triple ${triple} + ${ABACUS_RUNTIME_OPTIONS} ${BUILTINS_EXTRA_OPTIONS} ${ENABLE_OPENCL_BUILTINS} + -include "${RUNTIME_CLHEADER}" + -emit-llvm-bc -o "${OUTPUT}" ${DEPFILE_ARGS} + "${SOURCE}" + DEPENDS "${SOURCE}" "${RUNTIME_COMPILER}" "${RUNTIME_CLHEADER}" + abacus_generate ${ABACUS_GENERATED_FILES} + DEPFILE "${OUTPUT}.d") + else() + add_custom_command( + OUTPUT "${OUTPUT}" + COMMAND ${RUNTIME_COMPILER} -cc1 -x ${XTYPE} ${XOPTS} -triple ${triple} + ${ABACUS_RUNTIME_OPTIONS} ${BUILTINS_EXTRA_OPTIONS} ${ENABLE_OPENCL_BUILTINS} + -include "${RUNTIME_CLHEADER}" + -emit-llvm-bc -o "${OUTPUT}" + "${SOURCE}" + DEPENDS "${SOURCE}" "${RUNTIME_COMPILER}" "${RUNTIME_CLHEADER}" + abacus_generate ${ABACUS_GENERATED_FILES}) + endif() set(ALL_BCS "${ALL_BCS};${OUTPUT}") endforeach() diff --git a/modules/compiler/builtins/abacus/source/abacus_math/CMakeLists.txt b/modules/compiler/builtins/abacus/source/abacus_math/CMakeLists.txt index 5289ce3db..b990bdf77 100644 --- a/modules/compiler/builtins/abacus/source/abacus_math/CMakeLists.txt +++ b/modules/compiler/builtins/abacus/source/abacus_math/CMakeLists.txt @@ -115,5 +115,10 @@ set(abacus_math_sources ${CMAKE_CURRENT_SOURCE_DIR}/tanh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tanpi.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tgamma.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/trunc.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/trunc.cpp) +set(abacus_math_sources_host ${abacus_math_sources} + ${CMAKE_CURRENT_SOURCE_DIR}/inplace_sqrt.cpp + PARENT_SCOPE) +set(abacus_math_sources_device ${abacus_math_sources} + ${CMAKE_CURRENT_SOURCE_DIR}/inplace_sqrt.ll.in PARENT_SCOPE) diff --git a/modules/compiler/builtins/abacus/source/abacus_math/acos.cpp b/modules/compiler/builtins/abacus/source/abacus_math/acos.cpp index ae72e834d..58840983f 100644 --- a/modules/compiler/builtins/abacus/source/abacus_math/acos.cpp +++ b/modules/compiler/builtins/abacus/source/abacus_math/acos.cpp @@ -24,7 +24,7 @@ #include #endif // __CA_BUILTINS_DOUBLE_SUPPORT #include -#include +#include #ifdef __CA_BUILTINS_HALF_SUPPORT #include #include @@ -134,7 +134,7 @@ abacus_float ABACUS_API __abacus_acos(abacus_float x) { // get acos: abacus_float result = ans; if (interval < 12 && interval != 8 && interval != 7) { - result = abacus::internal::sqrt_unsafe(ans); + result = abacus::internal::sqrt(ans); } return (x > 0.0f) ? result : ABACUS_PI_F - result; @@ -163,7 +163,7 @@ T acos(const T x) { } T result = - __abacus_select(ans, abacus::internal::sqrt_unsafe(ans), + __abacus_select(ans, abacus::internal::sqrt(ans), (interval < 12) & (interval != 8) & (interval != 7)); result = __abacus_select((T)ABACUS_PI_F - result, result, x > 0); @@ -219,7 +219,7 @@ T ABACUS_API acos_half(T x) { T ansBig = xAbs * abacus::internal::horner_polynomial(xAbs, __codeplay_acos_1); - ansBig = abacus::internal::sqrt_unsafe(ansBig); + ansBig = abacus::internal::sqrt(ansBig); ans = __abacus_select(ans, ansBig, xBig); ans = __abacus_select(ans, ABACUS_PI_H - ans, SignedType(x < 0.0f16)); diff --git a/modules/compiler/builtins/abacus/source/abacus_math/acosh.cpp b/modules/compiler/builtins/abacus/source/abacus_math/acosh.cpp index 07c4e60bb..94bda7748 100644 --- a/modules/compiler/builtins/abacus/source/abacus_math/acosh.cpp +++ b/modules/compiler/builtins/abacus/source/abacus_math/acosh.cpp @@ -19,7 +19,7 @@ #include #include #include -#include +#include namespace { template @@ -28,7 +28,7 @@ T acosh(const T x) { const T y = x - 1.0f; - T ex = y + abacus::internal::sqrt_unsafe(y * (y + 2.0f)); + T ex = y + abacus::internal::sqrt(y * (y + 2.0f)); // This can overflow so we check for large values const SignedType cond1 = y > 2.0e16f; @@ -61,9 +61,8 @@ T acosh_half(const T x) { // A small optimization for vectorized versions. Rather than call log multiple // times and select the right answer, we instead do a smaller branch to pick // the input value for log: - T log_input = - __abacus_select(x + abacus::internal::sqrt_unsafe(x * x - 1.0f16), x, - SignedType(x >= xBigBound)); + T log_input = __abacus_select(x + abacus::internal::sqrt(x * x - 1.0f16), x, + SignedType(x >= xBigBound)); log_input = __abacus_select(log_input, 2.0f16 * x, SignedType(x >= xBigBound && x < xOverflowBound)); @@ -75,7 +74,7 @@ T acosh_half(const T x) { __abacus_select(ans, ABACUS_LN2_H + ans, SignedType(x >= xOverflowBound)); const T small_return = - abacus::internal::sqrt_unsafe(x - 1.0f16) * + abacus::internal::sqrt(x - 1.0f16) * abacus::internal::horner_polynomial(x - T(1.0f16), _acoshH); ans = __abacus_select(ans, small_return, SignedType(x < T(2.0f16))); @@ -86,7 +85,7 @@ T acosh_half(const T x) { template <> abacus_half acosh_half(const abacus_half x) { if (x < 2.0f16) { - return abacus::internal::sqrt_unsafe(x - 1.0f16) * + return abacus::internal::sqrt(x - 1.0f16) * abacus::internal::horner_polynomial(x - 1.0f16, _acoshH); } @@ -104,7 +103,7 @@ abacus_half acosh_half(const abacus_half x) { return __abacus_log(2.0f16 * x); } - return __abacus_log(x + abacus::internal::sqrt_unsafe(x * x - 1.0f16)); + return __abacus_log(x + abacus::internal::sqrt(x * x - 1.0f16)); } abacus_half ABACUS_API __abacus_acosh(abacus_half x) { return acosh_half<>(x); } diff --git a/modules/compiler/builtins/abacus/source/abacus_math/acospi.cpp b/modules/compiler/builtins/abacus/source/abacus_math/acospi.cpp index 315fc9902..317daf41f 100644 --- a/modules/compiler/builtins/abacus/source/abacus_math/acospi.cpp +++ b/modules/compiler/builtins/abacus/source/abacus_math/acospi.cpp @@ -16,8 +16,9 @@ #include #include +#include #include -#include +#include namespace { template @@ -56,8 +57,7 @@ T acospi_half(const T x) { ans = xAbs * __abacus_select(poly1, poly2, cond1); - ans = - __abacus_select(ans + 0.5f16, abacus::internal::sqrt_unsafe(ans), cond1); + ans = __abacus_select(ans + 0.5f16, abacus::internal::sqrt(ans), cond1); ans = __abacus_select(ans, 1.0f16 - ans, x < 0.0f16); @@ -74,7 +74,7 @@ abacus_half acospi_half(const abacus_half x) { ans = xAbs * abacus::internal::horner_polynomial( xAbs, __codeplay_acospi_coeff_halfH1); - ans = abacus::internal::sqrt_unsafe(ans); + ans = abacus::internal::sqrt(ans); } else { const abacus_half x2 = x * x; ans = xAbs * abacus::internal::horner_polynomial( diff --git a/modules/compiler/builtins/abacus/source/abacus_math/asin.cpp b/modules/compiler/builtins/abacus/source/abacus_math/asin.cpp index 09cf1da45..a5fc8c6f0 100644 --- a/modules/compiler/builtins/abacus/source/abacus_math/asin.cpp +++ b/modules/compiler/builtins/abacus/source/abacus_math/asin.cpp @@ -24,7 +24,7 @@ #ifdef __CA_BUILTINS_DOUBLE_SUPPORT #include #endif // __CA_BUILTINS_DOUBLE_SUPPORT -#include +#include // see maple worksheet for how coefficients were derived. static ABACUS_CONSTANT abacus_float __codeplay_asin_coeff[80] = { @@ -136,7 +136,7 @@ abacus_float __abacus_asin(abacus_float x) { #endif if (interval < 9) { - ans = -abacus::internal::sqrt_unsafe(ans); + ans = -abacus::internal::sqrt(ans); ans += ABACUS_PI_2_F; } @@ -168,7 +168,7 @@ T asin_half(const T x) { T ansBig = xAbs * abacus::internal::horner_polynomial(xAbs, __codeplay_asin_1); - ansBig = -abacus::internal::sqrt_unsafe(ansBig) + ABACUS_PI_2_H; + ansBig = -abacus::internal::sqrt(ansBig) + ABACUS_PI_2_H; ansBig = __abacus_copysign(ansBig, x); ans = __abacus_select(ans, ansBig, xBig); @@ -200,7 +200,7 @@ abacus_half asin_half(const abacus_half x) { xAbs * abacus::internal::horner_polynomial( xAbs, __codeplay_asin_1); - ans = -abacus::internal::sqrt_unsafe(ans) + ABACUS_PI_2_H; + ans = -abacus::internal::sqrt(ans) + ABACUS_PI_2_H; return __abacus_copysign(ans, x); } @@ -235,8 +235,8 @@ T asin(const T x) { ans = __abacus_select(ans, poly, cond); } - T result = __abacus_select( - ans, -abacus::internal::sqrt_unsafe(ans) + ABACUS_PI_2_F, (interval < 9)); + T result = __abacus_select(ans, -abacus::internal::sqrt(ans) + ABACUS_PI_2_F, + (interval < 9)); result = __abacus_select(-result, result, x > 0); diff --git a/modules/compiler/builtins/abacus/source/abacus_math/asinpi.cpp b/modules/compiler/builtins/abacus/source/abacus_math/asinpi.cpp index 692681a1c..54ee13f9f 100644 --- a/modules/compiler/builtins/abacus/source/abacus_math/asinpi.cpp +++ b/modules/compiler/builtins/abacus/source/abacus_math/asinpi.cpp @@ -16,8 +16,9 @@ #include #include +#include #include -#include +#include namespace { template @@ -53,7 +54,7 @@ T asinpi_half(const T x) { T ansCond = xAbs * abacus::internal::horner_polynomial( xAbs, __codeplay_asinpi_coeff_halfH1); - ansCond = -abacus::internal::sqrt_unsafe(ansCond) + 0.5f16; + ansCond = -abacus::internal::sqrt(ansCond) + 0.5f16; ansCond = __abacus_copysign(ansCond, x); ans = __abacus_select(ans, ansCond, cond1); @@ -75,7 +76,7 @@ abacus_half asinpi_half(const abacus_half x) { xAbs * abacus::internal::horner_polynomial( xAbs, __codeplay_asinpi_coeff_halfH1); - ans = -abacus::internal::sqrt_unsafe(ans) + 0.5f16; + ans = -abacus::internal::sqrt(ans) + 0.5f16; return __abacus_copysign(ans, x); } diff --git a/modules/compiler/builtins/abacus/source/abacus_math/half_sqrt.cpp b/modules/compiler/builtins/abacus/source/abacus_math/half_sqrt.cpp index 3ee0d4599..5051b613d 100644 --- a/modules/compiler/builtins/abacus/source/abacus_math/half_sqrt.cpp +++ b/modules/compiler/builtins/abacus/source/abacus_math/half_sqrt.cpp @@ -14,77 +14,24 @@ // // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include #include -#include -#include -#include -#include -#include -#include -#include - -namespace { -template -T half_sqrt(const T x) { - typedef typename TypeTraits::SignedType SignedType; - typedef typename TypeTraits::UnsignedType UnsignedType; - - const UnsignedType xUint = abacus::detail::cast::as(x); - - // We use the exact bounds for rtz as it also works with ftz and the - // other rounding modes. - const SignedType xBig = (xUint >= 0x7e6eb3c0); - - const SignedType xSmall = abacus::internal::is_denorm(x); - - // xUint | F_HIDDEN_BIT sets the exponent to -126 - // 16777216 2^24 - // Multiplication exponent = -126 + 24 = -102 - // - // 0x0C800000 2^(-102) - // - // processedX (x * 2^24) + 2^-102 - 2^-102 - // (x * 2^24) - T processedX = __abacus_select( - x, - abacus::detail::cast::as(xUint | F_HIDDEN_BIT) * 16777216.0f - - __abacus_as_float(0x0C800000), - xSmall); - - processedX = __abacus_select(processedX, processedX * 0.0625f, xBig); - - // 1/sqrt(x) -> sqrt(x) - T ans = abacus::internal::rsqrt_unsafe(processedX, 2) * processedX; - - // note 0.000244140625 == (1 / 4096) - ans = __abacus_select(ans, ans * 0.000244140625f, xSmall); - - ans = __abacus_select(ans, ans * 4.0f, xBig); - - ans = __abacus_select(ans, ABACUS_INFINITY, __abacus_isinf(x)); - ans = __abacus_select(ans, ABACUS_NAN, __abacus_signbit(x)); - ans = __abacus_select(ans, x, (__abacus_fabs(x) == 0.0f) & ~xSmall); - - return ans; -} -} // namespace +#include abacus_float ABACUS_API __abacus_half_sqrt(abacus_float x) { - return half_sqrt<>(x); + return abacus::internal::sqrt(x); } abacus_float2 ABACUS_API __abacus_half_sqrt(abacus_float2 x) { - return half_sqrt<>(x); + return abacus::internal::sqrt(x); } abacus_float3 ABACUS_API __abacus_half_sqrt(abacus_float3 x) { - return half_sqrt<>(x); + return abacus::internal::sqrt(x); } abacus_float4 ABACUS_API __abacus_half_sqrt(abacus_float4 x) { - return half_sqrt<>(x); + return abacus::internal::sqrt(x); } abacus_float8 ABACUS_API __abacus_half_sqrt(abacus_float8 x) { - return half_sqrt<>(x); + return abacus::internal::sqrt(x); } abacus_float16 ABACUS_API __abacus_half_sqrt(abacus_float16 x) { - return half_sqrt<>(x); + return abacus::internal::sqrt(x); } diff --git a/modules/compiler/builtins/abacus/source/abacus_math/hypot.cpp b/modules/compiler/builtins/abacus/source/abacus_math/hypot.cpp index d34ea7cdb..6268450da 100644 --- a/modules/compiler/builtins/abacus/source/abacus_math/hypot.cpp +++ b/modules/compiler/builtins/abacus/source/abacus_math/hypot.cpp @@ -98,7 +98,7 @@ struct hypot_helper { ans = fast(xAbs, yAbs); } else { // Uses slower __abacus_sqrt() which is currently implemented with - // some 32-bit float operations in abacus::internal::sqrt_unsafe + // some 32-bit float operations in abacus::internal::sqrt ans = accurate(xAbs, yAbs); } @@ -197,7 +197,7 @@ struct hypot_helper { const T yReduced = yAbs * inverse_pow; // NOTE: This call uses 32-bit float instruction as part of - // abacus::internal::sqrt_unsafe + // abacus::internal::sqrt T ans = __abacus_sqrt(xReduced * xReduced + yReduced * yReduced); ans *= similar_pow; return ans; diff --git a/modules/compiler/builtins/abacus/source/abacus_math/inplace_sqrt.cpp b/modules/compiler/builtins/abacus/source/abacus_math/inplace_sqrt.cpp new file mode 100644 index 000000000..94beade03 --- /dev/null +++ b/modules/compiler/builtins/abacus/source/abacus_math/inplace_sqrt.cpp @@ -0,0 +1,58 @@ +// Copyright (C) Codeplay Software Limited +// +// Licensed under the Apache License, Version 2.0 (the "License") with LLVM +// Exceptions; you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include + +namespace abacus { +namespace detail { +template +void inplace_sqrt(T &t) { + using ET = typename TypeTraits::ElementType; + for (ET *p = reinterpret_cast(&t), *e = reinterpret_cast(&t + 1); + p != e; ++p) { + *p = std::sqrt(*p); + } +} +} // namespace detail +} // namespace abacus + +#ifdef __CA_BUILTINS_HALF_SUPPORT +template void abacus::detail::inplace_sqrt(abacus_half &); +template void abacus::detail::inplace_sqrt(abacus_half2 &); +template void abacus::detail::inplace_sqrt(abacus_half3 &); +template void abacus::detail::inplace_sqrt(abacus_half4 &); +template void abacus::detail::inplace_sqrt(abacus_half8 &); +template void abacus::detail::inplace_sqrt(abacus_half16 &); +#endif // __CA_BUILTINS_HALF_SUPPORT + +template void abacus::detail::inplace_sqrt(abacus_float &); +template void abacus::detail::inplace_sqrt(abacus_float2 &); +template void abacus::detail::inplace_sqrt(abacus_float3 &); +template void abacus::detail::inplace_sqrt(abacus_float4 &); +template void abacus::detail::inplace_sqrt(abacus_float8 &); +template void abacus::detail::inplace_sqrt(abacus_float16 &); + +#ifdef __CA_BUILTINS_DOUBLE_SUPPORT +template void abacus::detail::inplace_sqrt(abacus_double &); +template void abacus::detail::inplace_sqrt(abacus_double2 &); +template void abacus::detail::inplace_sqrt(abacus_double3 &); +template void abacus::detail::inplace_sqrt(abacus_double4 &); +template void abacus::detail::inplace_sqrt(abacus_double8 &); +template void abacus::detail::inplace_sqrt(abacus_double16 &); +#endif // __CA_BUILTINS_DOUBLE_SUPPORT diff --git a/modules/compiler/builtins/abacus/source/abacus_math/inplace_sqrt.ll.in b/modules/compiler/builtins/abacus/source/abacus_math/inplace_sqrt.ll.in new file mode 100644 index 000000000..caae121a3 --- /dev/null +++ b/modules/compiler/builtins/abacus/source/abacus_math/inplace_sqrt.ll.in @@ -0,0 +1,182 @@ +; Copyright (C) Codeplay Software Limited +; +; Licensed under the Apache License, Version 2.0 (the "License") with LLVM +; Exceptions; you may not use this file except in compliance with the License. +; You may obtain a copy of the License at +; +; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +; License for the specific language governing permissions and limitations +; under the License. +; +; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +@target@ + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDF16_EEvRT_(ptr dereferenceable(2) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load half, ptr %0 + %2 = call half @llvm.sqrt.f16(half %1) + store half %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv2_DF16_EEvRT_(ptr dereferenceable(4) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <2 x half>, ptr %0 + %2 = call <2 x half> @llvm.sqrt.v2f16(<2 x half> %1) + store <2 x half> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv3_DF16_EEvRT_(ptr dereferenceable(6) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <3 x half>, ptr %0 + %2 = call <3 x half> @llvm.sqrt.v3f16(<3 x half> %1) + store <3 x half> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv4_DF16_EEvRT_(ptr dereferenceable(8) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <4 x half>, ptr %0 + %2 = call <4 x half> @llvm.sqrt.v4f16(<4 x half> %1) + store <4 x half> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv8_DF16_EEvRT_(ptr dereferenceable(16) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <8 x half>, ptr %0 + %2 = call <8 x half> @llvm.sqrt.v8f16(<8 x half> %1) + store <8 x half> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv16_DF16_EEvRT_(ptr dereferenceable(32) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <16 x half>, ptr %0 + %2 = call <16 x half> @llvm.sqrt.v16f16(<16 x half> %1) + store <16 x half> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIfEEvRT_(ptr dereferenceable(4) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load float, ptr %0 + %2 = call float @llvm.sqrt.f32(float %1) + store float %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv2_fEEvRT_(ptr dereferenceable(8) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <2 x float>, ptr %0 + %2 = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %1) + store <2 x float> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv3_fEEvRT_(ptr dereferenceable(12) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <3 x float>, ptr %0 + %2 = call <3 x float> @llvm.sqrt.v3f32(<3 x float> %1) + store <3 x float> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv4_fEEvRT_(ptr dereferenceable(16) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <4 x float>, ptr %0 + %2 = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %1) + store <4 x float> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv8_fEEvRT_(ptr dereferenceable(32) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <8 x float>, ptr %0 + %2 = call <8 x float> @llvm.sqrt.v8f32(<8 x float> %1) + store <8 x float> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv16_fEEvRT_(ptr dereferenceable(64) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <16 x float>, ptr %0 + %2 = call <16 x float> @llvm.sqrt.v16f32(<16 x float> %1) + store <16 x float> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIdEEvRT_(ptr dereferenceable(8) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load double, ptr %0 + %2 = call double @llvm.sqrt.f64(double %1) + store double %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv2_dEEvRT_(ptr dereferenceable(16) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <2 x double>, ptr %0 + %2 = call <2 x double> @llvm.sqrt.v2f64(<2 x double> %1) + store <2 x double> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv3_dEEvRT_(ptr dereferenceable(24) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <3 x double>, ptr %0 + %2 = call <3 x double> @llvm.sqrt.v3f64(<3 x double> %1) + store <3 x double> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv4_dEEvRT_(ptr dereferenceable(32) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <4 x double>, ptr %0 + %2 = call <4 x double> @llvm.sqrt.v4f64(<4 x double> %1) + store <4 x double> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv8_dEEvRT_(ptr dereferenceable(64) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <8 x double>, ptr %0 + %2 = call <8 x double> @llvm.sqrt.v8f64(<8 x double> %1) + store <8 x double> %2, ptr %0 + ret void +} + +define hidden spir_func void @_ZN6abacus6detail12inplace_sqrtIDv16_dEEvRT_(ptr dereferenceable(128) %0) local_unnamed_addr mustprogress nounwind alwaysinline { +entry: + %1 = load <16 x double>, ptr %0 + %2 = call <16 x double> @llvm.sqrt.v16f64(<16 x double> %1) + store <16 x double> %2, ptr %0 + ret void +} + +declare half @llvm.sqrt.f16(half) +declare <2 x half> @llvm.sqrt.v2f16(<2 x half>) +declare <3 x half> @llvm.sqrt.v3f16(<3 x half>) +declare <4 x half> @llvm.sqrt.v4f16(<4 x half>) +declare <8 x half> @llvm.sqrt.v8f16(<8 x half>) +declare <16 x half> @llvm.sqrt.v16f16(<16 x half>) + +declare float @llvm.sqrt.f32(float) +declare <2 x float> @llvm.sqrt.v2f32(<2 x float>) +declare <3 x float> @llvm.sqrt.v3f32(<3 x float>) +declare <4 x float> @llvm.sqrt.v4f32(<4 x float>) +declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) +declare <16 x float> @llvm.sqrt.v16f32(<16 x float>) + +declare double @llvm.sqrt.f64(double) +declare <2 x double> @llvm.sqrt.v2f64(<2 x double>) +declare <3 x double> @llvm.sqrt.v3f64(<3 x double>) +declare <4 x double> @llvm.sqrt.v4f64(<4 x double>) +declare <8 x double> @llvm.sqrt.v8f64(<8 x double>) +declare <16 x double> @llvm.sqrt.v16f64(<16 x double>) diff --git a/modules/compiler/builtins/abacus/source/abacus_math/sqrt.cpp b/modules/compiler/builtins/abacus/source/abacus_math/sqrt.cpp index 9116fdec3..4ebc95747 100644 --- a/modules/compiler/builtins/abacus/source/abacus_math/sqrt.cpp +++ b/modules/compiler/builtins/abacus/source/abacus_math/sqrt.cpp @@ -14,148 +14,101 @@ // // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include #include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace { -template ::ElementType> -struct helper; - -#ifdef __CA_BUILTINS_HALF_SUPPORT +namespace abacus { +namespace detail { template -struct helper { - static T _(const T x) { - typedef typename TypeTraits::UnsignedType UnsignedType; - typedef typename TypeTraits::SignedType SignedType; - - T processedX = x; - - // Un-denormalise any small numbers so it passes through the algorithm - // correctly - // Do this by multiply by 2^10 if less than a certain threshold (0x0800), - // and multiply by 2^-5 at the end - // This insures intermediate results in sqrt_unsafe are non-denormal - SignedType xSmall = - abacus::detail::cast::as(x) < UnsignedType(0x0800); - - // TODO might need a bit more on hardware that doesn't support denormals - processedX = __abacus_select(x, x * 1024.0f16, xSmall); - - /* -------------- For devices without denormalsupport--------------------- - //To multiply a denorm by 2^10 the bithack way: - //OR in an exponent (In this case 0x2C00), for a denormal x this gives us -x*2^10 + 2^4 - UnsignedType denorm_hack = abacus::detail::cast::as(x) | -(UnsignedType) 0x2C00; +void inplace_sqrt(T &); +} // namespace detail - T denorm_scaled = abacus::detail::cast::as(denorm_hack) - (T)0.0625f16; -//2^-4 - - processedX = __abacus_select(processedX, denorm_scaled, -abacus::detail::cast::convert(abacus::detail::cast::as(x) -< 0x0400)); --------------------------------------------------------------------------------*/ - - T ans = abacus::internal::sqrt_unsafe(processedX); - - ans = __abacus_select(ans, ans * 0.03125f16, xSmall); - - // This fabs is used to prevent an earlier branch check for x == INFINITY, - // as it happens the sqrt_unsafe returns -INFINTY in this case. - return __abacus_fabs(ans); - } -}; - -#endif // __CA_BUILTINS_HALF_SUPPORT - -template -struct helper { - static T _(const T x) { - typedef typename TypeTraits::SignedType SignedType; - - // We pre-condition the input by scaling it up or down, to avoid overflows - // and underflows/subnormals. - const SignedType xSmall = x < 1.0f; - - T processedX = x * __abacus_select((T)0.0625f, (T)16777216.0f, xSmall); - T ans = abacus::internal::sqrt_unsafe(processedX); - ans = ans * __abacus_select((T)4.0f, (T)0.000244140625f, xSmall); - - // sqrt_unsafe already correctly deals with zeroes, negatives and NANs. - // Only infinity is left to worry about. - ans = __abacus_select(ans, x, (x == ABACUS_INFINITY)); - - return ans; - } -}; - -#ifdef __CA_BUILTINS_DOUBLE_SUPPORT +namespace internal { template -struct helper { - static T _(const T x) { - typedef typename TypeTraits::SignedType SignedType; - - const SignedType xSmall = - x < __abacus_as_double((abacus_long)0x3CD0000000000000); - - const T inter = - __abacus_select(x, x * 1267650600228229401496703205376.0, xSmall); - - T result = abacus::internal::sqrt_unsafe(inter); - - result = __abacus_select( - result, result * __abacus_as_double((abacus_long)0x3CD0000000000000), - xSmall); - - const SignedType cond1 = (x == 0.0) | (x == ABACUS_INFINITY); - result = __abacus_select(result, x, cond1); - - const SignedType cond2 = (x < 0.0) | __abacus_isnan(x); - result = __abacus_select(result, (T)ABACUS_NAN, cond2); - - return result; - } -}; -#endif // __CA_BUILTINS_DOUBLE_SUPPORT - -template -T sqrt(const T x) { - return helper::_(x); +T sqrt(T x) { + abacus::detail::inplace_sqrt(x); + return x; } -} // namespace +} // namespace internal +} // namespace abacus #ifdef __CA_BUILTINS_HALF_SUPPORT -abacus_half ABACUS_API __abacus_sqrt(abacus_half x) { return sqrt<>(x); } -abacus_half2 ABACUS_API __abacus_sqrt(abacus_half2 x) { return sqrt<>(x); } -abacus_half3 ABACUS_API __abacus_sqrt(abacus_half3 x) { return sqrt<>(x); } -abacus_half4 ABACUS_API __abacus_sqrt(abacus_half4 x) { return sqrt<>(x); } -abacus_half8 ABACUS_API __abacus_sqrt(abacus_half8 x) { return sqrt<>(x); } -abacus_half16 ABACUS_API __abacus_sqrt(abacus_half16 x) { return sqrt<>(x); } +template abacus_half abacus::internal::sqrt(abacus_half); +template abacus_half2 abacus::internal::sqrt(abacus_half2); +template abacus_half3 abacus::internal::sqrt(abacus_half3); +template abacus_half4 abacus::internal::sqrt(abacus_half4); +template abacus_half8 abacus::internal::sqrt(abacus_half8); +template abacus_half16 abacus::internal::sqrt(abacus_half16); + +abacus_half ABACUS_API __abacus_sqrt(abacus_half x) { + return abacus::internal::sqrt<>(x); +} +abacus_half2 ABACUS_API __abacus_sqrt(abacus_half2 x) { + return abacus::internal::sqrt<>(x); +} +abacus_half3 ABACUS_API __abacus_sqrt(abacus_half3 x) { + return abacus::internal::sqrt<>(x); +} +abacus_half4 ABACUS_API __abacus_sqrt(abacus_half4 x) { + return abacus::internal::sqrt<>(x); +} +abacus_half8 ABACUS_API __abacus_sqrt(abacus_half8 x) { + return abacus::internal::sqrt<>(x); +} +abacus_half16 ABACUS_API __abacus_sqrt(abacus_half16 x) { + return abacus::internal::sqrt<>(x); +} #endif // __CA_BUILTINS_HALF_SUPPORT -abacus_float ABACUS_API __abacus_sqrt(abacus_float x) { return sqrt<>(x); } -abacus_float2 ABACUS_API __abacus_sqrt(abacus_float2 x) { return sqrt<>(x); } -abacus_float3 ABACUS_API __abacus_sqrt(abacus_float3 x) { return sqrt<>(x); } -abacus_float4 ABACUS_API __abacus_sqrt(abacus_float4 x) { return sqrt<>(x); } -abacus_float8 ABACUS_API __abacus_sqrt(abacus_float8 x) { return sqrt<>(x); } -abacus_float16 ABACUS_API __abacus_sqrt(abacus_float16 x) { return sqrt<>(x); } +template abacus_float abacus::internal::sqrt(abacus_float); +template abacus_float2 abacus::internal::sqrt(abacus_float2); +template abacus_float3 abacus::internal::sqrt(abacus_float3); +template abacus_float4 abacus::internal::sqrt(abacus_float4); +template abacus_float8 abacus::internal::sqrt(abacus_float8); +template abacus_float16 abacus::internal::sqrt(abacus_float16); + +abacus_float ABACUS_API __abacus_sqrt(abacus_float x) { + return abacus::internal::sqrt<>(x); +} +abacus_float2 ABACUS_API __abacus_sqrt(abacus_float2 x) { + return abacus::internal::sqrt<>(x); +} +abacus_float3 ABACUS_API __abacus_sqrt(abacus_float3 x) { + return abacus::internal::sqrt<>(x); +} +abacus_float4 ABACUS_API __abacus_sqrt(abacus_float4 x) { + return abacus::internal::sqrt<>(x); +} +abacus_float8 ABACUS_API __abacus_sqrt(abacus_float8 x) { + return abacus::internal::sqrt<>(x); +} +abacus_float16 ABACUS_API __abacus_sqrt(abacus_float16 x) { + return abacus::internal::sqrt<>(x); +} #ifdef __CA_BUILTINS_DOUBLE_SUPPORT -abacus_double ABACUS_API __abacus_sqrt(abacus_double x) { return sqrt<>(x); } -abacus_double2 ABACUS_API __abacus_sqrt(abacus_double2 x) { return sqrt<>(x); } -abacus_double3 ABACUS_API __abacus_sqrt(abacus_double3 x) { return sqrt<>(x); } -abacus_double4 ABACUS_API __abacus_sqrt(abacus_double4 x) { return sqrt<>(x); } -abacus_double8 ABACUS_API __abacus_sqrt(abacus_double8 x) { return sqrt<>(x); } +template abacus_double abacus::internal::sqrt(abacus_double); +template abacus_double2 abacus::internal::sqrt(abacus_double2); +template abacus_double3 abacus::internal::sqrt(abacus_double3); +template abacus_double4 abacus::internal::sqrt(abacus_double4); +template abacus_double8 abacus::internal::sqrt(abacus_double8); +template abacus_double16 abacus::internal::sqrt(abacus_double16); + +abacus_double ABACUS_API __abacus_sqrt(abacus_double x) { + return abacus::internal::sqrt<>(x); +} +abacus_double2 ABACUS_API __abacus_sqrt(abacus_double2 x) { + return abacus::internal::sqrt<>(x); +} +abacus_double3 ABACUS_API __abacus_sqrt(abacus_double3 x) { + return abacus::internal::sqrt<>(x); +} +abacus_double4 ABACUS_API __abacus_sqrt(abacus_double4 x) { + return abacus::internal::sqrt<>(x); +} +abacus_double8 ABACUS_API __abacus_sqrt(abacus_double8 x) { + return abacus::internal::sqrt<>(x); +} abacus_double16 ABACUS_API __abacus_sqrt(abacus_double16 x) { - return sqrt<>(x); + return abacus::internal::sqrt<>(x); } #endif // __CA_BUILTINS_DOUBLE_SUPPORT