From b3ad27684f21bfe3ff00c6b954ce79389a958123 Mon Sep 17 00:00:00 2001 From: Harald van Dijk Date: Thu, 7 Dec 2023 14:39:42 +0000 Subject: [PATCH] Manual type legalization. LLVM's own type legalization promotes floating point operations without truncation of intermediate results. We rely on that truncation, so run our own pass to legalize manually before LLVM's legalization runs. --- doc/modules/compiler/utils.rst | 14 ++ .../source/refsi_pass_machinery.cpp | 4 + .../refsi_m1/source/refsi_pass_machinery.cpp | 4 + ...kiecutter.target_name}}_pass_machinery.cpp | 4 + .../riscv/source/riscv_pass_machinery.cpp | 4 + .../source/base_module_pass_machinery.cpp | 1 + .../base/source/base_module_pass_registry.def | 1 + .../targets/host/source/HostPassMachinery.cpp | 4 + .../lit/passes/manual_type_legalization.ll | 38 ++++++ modules/compiler/utils/CMakeLists.txt | 2 + .../utils/manual_type_legalization_pass.h | 56 ++++++++ .../source/manual_type_legalization_pass.cpp | 129 ++++++++++++++++++ .../cl/test/UnitCL/source/ktst_precision.cpp | 27 +--- 13 files changed, 263 insertions(+), 25 deletions(-) create mode 100644 modules/compiler/test/lit/passes/manual_type_legalization.ll create mode 100644 modules/compiler/utils/include/compiler/utils/manual_type_legalization_pass.h create mode 100644 modules/compiler/utils/source/manual_type_legalization_pass.cpp diff --git a/doc/modules/compiler/utils.rst b/doc/modules/compiler/utils.rst index dd28dfe10..465aa4219 100644 --- a/doc/modules/compiler/utils.rst +++ b/doc/modules/compiler/utils.rst @@ -1199,6 +1199,20 @@ about some of the above types, such as the type of images passed to any of the it may be required to skip other passes such as the ``compiler::ImageArgumentSubstitutionPass``. +ManualTypeLegalizationPass +-------------------------- + +The ``ManualTypeLegalizationPass`` pass replaces ``half`` operations with +``float`` operations, inserting conversions as needed. 
It does this to work +around LLVM issue 73805, where LLVM's own legalization replaces whole chains of +operations rather than each operation individually, thus leaving out rounding +operations implied by the LLVM IR. + +This replacement is only done on targets that promote ``half`` to ``float`` +during type legalization. On targets where ``half`` is a native type, or where +``half`` is known to be promoted using "soft-promotion" rules, LLVM is presumed +to translate ``half`` correctly. + Metadata Utilities ------------------ diff --git a/examples/refsi/refsi_g1_wi/compiler/refsi_g1_wi/source/refsi_pass_machinery.cpp b/examples/refsi/refsi_g1_wi/compiler/refsi_g1_wi/source/refsi_pass_machinery.cpp index 0f64c42f2..c58190f4f 100644 --- a/examples/refsi/refsi_g1_wi/compiler/refsi_g1_wi/source/refsi_pass_machinery.cpp +++ b/examples/refsi/refsi_g1_wi/compiler/refsi_g1_wi/source/refsi_pass_machinery.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -156,6 +157,9 @@ llvm::ModulePassManager RefSiG1PassMachinery::getLateTargetPasses() { addLLVMDefaultPerModulePipeline(PM, getPB(), options); + PM.addPass(llvm::createModuleToFunctionPassAdaptor( + compiler::utils::ManualTypeLegalizationPass())); + if (env_debug_prefix) { // With all passes scheduled, add a callback pass to view the // assembly/object file, if requested. 
diff --git a/examples/refsi/refsi_m1/compiler/refsi_m1/source/refsi_pass_machinery.cpp b/examples/refsi/refsi_m1/compiler/refsi_m1/source/refsi_pass_machinery.cpp index cadc63391..dd70156cb 100644 --- a/examples/refsi/refsi_m1/compiler/refsi_m1/source/refsi_pass_machinery.cpp +++ b/examples/refsi/refsi_m1/compiler/refsi_m1/source/refsi_pass_machinery.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -159,6 +160,9 @@ llvm::ModulePassManager RefSiM1PassMachinery::getLateTargetPasses() { addLLVMDefaultPerModulePipeline(PM, getPB(), options); + PM.addPass(llvm::createModuleToFunctionPassAdaptor( + compiler::utils::ManualTypeLegalizationPass())); + if (env_debug_prefix) { // With all passes scheduled, add a callback pass to view the // assembly/object file, if requested. diff --git a/modules/compiler/cookie/{{cookiecutter.target_name}}/source/{{cookiecutter.target_name}}_pass_machinery.cpp b/modules/compiler/cookie/{{cookiecutter.target_name}}/source/{{cookiecutter.target_name}}_pass_machinery.cpp index 0ea4f3960..5e8f565e3 100644 --- a/modules/compiler/cookie/{{cookiecutter.target_name}}/source/{{cookiecutter.target_name}}_pass_machinery.cpp +++ b/modules/compiler/cookie/{{cookiecutter.target_name}}/source/{{cookiecutter.target_name}}_pass_machinery.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -222,6 +223,9 @@ llvm::ModulePassManager {{cookiecutter.target_name.capitalize()}}PassMachinery:: addLLVMDefaultPerModulePipeline(PM, getPB(), options); + PM.addPass(llvm::createModuleToFunctionPassAdaptor( + compiler::utils::ManualTypeLegalizationPass())); + if (env_debug_prefix) { // With all passes scheduled, add a callback pass to view the // assembly/object file, if requested. 
diff --git a/modules/compiler/riscv/source/riscv_pass_machinery.cpp b/modules/compiler/riscv/source/riscv_pass_machinery.cpp index 9cfa3ecd0..73d0c55bb 100644 --- a/modules/compiler/riscv/source/riscv_pass_machinery.cpp +++ b/modules/compiler/riscv/source/riscv_pass_machinery.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -255,6 +256,9 @@ llvm::ModulePassManager RiscvPassMachinery::getLateTargetPasses() { addLLVMDefaultPerModulePipeline(PM, getPB(), options); + PM.addPass(llvm::createModuleToFunctionPassAdaptor( + compiler::utils::ManualTypeLegalizationPass())); + if (env_debug_prefix) { // With all passes scheduled, add a callback pass to view the // assembly/object file, if requested. diff --git a/modules/compiler/source/base/source/base_module_pass_machinery.cpp b/modules/compiler/source/base/source/base_module_pass_machinery.cpp index e0409a602..dede768e1 100644 --- a/modules/compiler/source/base/source/base_module_pass_machinery.cpp +++ b/modules/compiler/source/base/source/base_module_pass_machinery.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include diff --git a/modules/compiler/source/base/source/base_module_pass_registry.def b/modules/compiler/source/base/source/base_module_pass_registry.def index 605aee590..659f55c81 100644 --- a/modules/compiler/source/base/source/base_module_pass_registry.def +++ b/modules/compiler/source/base/source/base_module_pass_registry.def @@ -172,6 +172,7 @@ FUNCTION_PASS("bit-shift-fixup", compiler::BitShiftFixupPass()) FUNCTION_PASS("ca-mem2reg", compiler::MemToRegPass()) FUNCTION_PASS("check-unsupported-types", compiler::CheckForUnsupportedTypesPass()) FUNCTION_PASS("combine-fpext-fptrunc", compiler::CombineFPExtFPTruncPass()) +FUNCTION_PASS("manual-type-legalization", compiler::utils::ManualTypeLegalizationPass()) FUNCTION_PASS("software-div", compiler::SoftwareDivisionPass()) FUNCTION_PASS("replace-addrspace-fns", 
compiler::utils::ReplaceAddressSpaceQualifierFunctionsPass()) FUNCTION_PASS("remove-lifetime", compiler::utils::RemoveLifetimeIntrinsicsPass()) diff --git a/modules/compiler/targets/host/source/HostPassMachinery.cpp b/modules/compiler/targets/host/source/HostPassMachinery.cpp index f49107516..02eeaecb9 100644 --- a/modules/compiler/targets/host/source/HostPassMachinery.cpp +++ b/modules/compiler/targets/host/source/HostPassMachinery.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -330,6 +331,9 @@ llvm::ModulePassManager HostPassMachinery::getKernelFinalizationPasses( compiler::utils::VectorizeMetadataAnalysis, handler::VectorizeInfoMetadataHandler>()); + PM.addPass(llvm::createModuleToFunctionPassAdaptor( + compiler::utils::ManualTypeLegalizationPass())); + return PM; } diff --git a/modules/compiler/test/lit/passes/manual_type_legalization.ll b/modules/compiler/test/lit/passes/manual_type_legalization.ll new file mode 100644 index 000000000..e1cba700f --- /dev/null +++ b/modules/compiler/test/lit/passes/manual_type_legalization.ll @@ -0,0 +1,38 @@ +; Copyright (C) Codeplay Software Limited +; +; Licensed under the Apache License, Version 2.0 (the "License") with LLVM +; Exceptions; you may not use this file except in compliance with the License. +; You may obtain a copy of the License at +; +; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +; License for the specific language governing permissions and limitations +; under the License. +; +; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +; RUN: muxc --passes manual-type-legalization,verify -S %s | FileCheck %s + +; Make sure we use a triple that does not have half as a legal type. 
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "spir64-unknown-unknown" + +; CHECK-LABEL: define half @f +; CHECK-DAG: [[AEXT:%.*]] = fpext half %a to float +; CHECK-DAG: [[BEXT:%.*]] = fpext half %b to float +; CHECK-DAG: [[CEXT:%.*]] = fpext half %c to float +; CHECK-DAG: [[DADD:%.*]] = fadd float [[AEXT]], [[BEXT]] +; CHECK-DAG: [[DTRUNC:%.*]] = fptrunc float [[DADD]] to half +; CHECK-DAG: [[DEXT:%.*]] = fpext half [[DTRUNC]] to float +; CHECK-DAG: [[EADD:%.*]] = fadd float [[DEXT]], [[CEXT]] +; CHECK-DAG: [[ETRUNC:%.*]] = fptrunc float [[EADD]] to half +; CHECK: ret half [[ETRUNC]] +define half @f(half %a, half %b, half %c) { +entry: + %d = fadd half %a, %b + %e = fadd half %d, %c + ret half %e +} diff --git a/modules/compiler/utils/CMakeLists.txt b/modules/compiler/utils/CMakeLists.txt index 233d119a5..1666c5ec6 100644 --- a/modules/compiler/utils/CMakeLists.txt +++ b/modules/compiler/utils/CMakeLists.txt @@ -39,6 +39,7 @@ add_ca_library(compiler-utils STATIC ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/llvm_global_mutex.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/lower_to_mux_builtins_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/make_function_name_unique_pass.h + ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/manual_type_legalization_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/mangling.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/memory_buffer.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/metadata.h @@ -92,6 +93,7 @@ add_ca_library(compiler-utils STATIC ${CMAKE_CURRENT_SOURCE_DIR}/source/lower_to_mux_builtins_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/make_function_name_unique_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/mangling.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/source/manual_type_legalization_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/metadata.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/metadata_analysis.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/metadata_hooks.cpp diff 
--git a/modules/compiler/utils/include/compiler/utils/manual_type_legalization_pass.h b/modules/compiler/utils/include/compiler/utils/manual_type_legalization_pass.h new file mode 100644 index 000000000..977c6be5e --- /dev/null +++ b/modules/compiler/utils/include/compiler/utils/manual_type_legalization_pass.h @@ -0,0 +1,56 @@ +// Copyright (C) Codeplay Software Limited +// +// Licensed under the Apache License, Version 2.0 (the "License") with LLVM +// Exceptions; you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef COMPILER_UTILS_MANUAL_TYPE_LEGALIZATION_PASS_H_INCLUDED +#define COMPILER_UTILS_MANUAL_TYPE_LEGALIZATION_PASS_H_INCLUDED + +#include + +namespace compiler { +namespace utils { + +/// Manual type legalization pass. +/// +/// On targets that do not natively support \c half, promote operations on \c +/// half to \c float instead. +/// +/// When LLVM encounters floating point operations in a type it does not support +/// natively, it extends its operands to an extended precision floating point +/// type, performs the operation in that extended type, and rounds the result +/// back to the original type. However, when it extends its operands to an +/// extended precision floating point type, if an operand itself was a floating +/// point operation that was also so extended, its rounding and re-extension are +/// skipped. 
This causes issues for code that relies on exact rounding of +/// intermediate results, which we avoid by manually doing this promotion +/// ourselves. +/// +/// Simply performing operations in a wider floating point type and rounding +/// back to the narrow floating point type is not, in general, correct, due to +/// double rounding. For addition, subtraction, and multiplication, \c float +/// provides enough additional precision that double rounding is known not to be +/// an issue. For other operations, this pass may generate incorrect results, +/// but this should only happen in cases where letting the operation pass +/// through to LLVM would result in the same incorrect results. +struct ManualTypeLegalizationPass final + : llvm::PassInfoMixin { + llvm::PreservedAnalyses run(llvm::Function &F, + llvm::FunctionAnalysisManager &FAM); +}; + +} // namespace utils +} // namespace compiler + +#endif // COMPILER_UTILS_MANUAL_TYPE_LEGALIZATION_PASS_H_INCLUDED diff --git a/modules/compiler/utils/source/manual_type_legalization_pass.cpp b/modules/compiler/utils/source/manual_type_legalization_pass.cpp new file mode 100644 index 000000000..de1e1d244 --- /dev/null +++ b/modules/compiler/utils/source/manual_type_legalization_pass.cpp @@ -0,0 +1,129 @@ +// Copyright (C) Codeplay Software Limited +// +// Licensed under the Apache License, Version 2.0 (the "License") with LLVM +// Exceptions; you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace llvm; + +PreservedAnalyses compiler::utils::ManualTypeLegalizationPass::run( + Function &F, FunctionAnalysisManager &FAM) { + auto &TTI = FAM.getResult(F); + + auto *HalfT = Type::getHalfTy(F.getContext()); + auto *FloatT = Type::getFloatTy(F.getContext()); + + // Targets where half is a legal type do not need this pass. Targets where + // half is promoted using "soft promotion" rules also do not need this pass. + // We cannot reliably determine which targets these are, but that is okay, on + // targets where this pass is not needed it does no harm, it merely wastes + // time. + llvm::Triple TT(F.getParent()->getTargetTriple()); + if (TTI.isTypeLegal(HalfT) || TT.isX86() || TT.isRISCV()) { + return PreservedAnalyses::all(); + } + + DenseMap FPExtVals; + IRBuilder<> B(F.getContext()); + + auto CreateFPExt = [&](Value *V, Type *ExtTy) { + auto *&FPExt = FPExtVals[V]; + if (!FPExt) { + if (auto *I = dyn_cast(V)) { +#if LLVM_VERSION_GREATER_EQUAL(18, 0) + std::optional IPAD; + IPAD = I->getInsertionPointAfterDef(); +#else + std::optional IPAD; + if (auto *IPADRaw = I->getInsertionPointAfterDef()) { + IPAD = IPADRaw; + } +#endif + assert(IPAD && + "getInsertionPointAfterDef() should return an insertion point " + "for all FP16 instructions"); + B.SetInsertPoint(*IPAD); + } else { + B.SetInsertPointPastAllocas(&F); + } + FPExt = B.CreateFPExt(V, ExtTy, V->getName() + ".fpext"); + } + return FPExt; + }; + + bool Changed = false; + + for (auto &BB : F) { + for (auto &I : make_early_inc_range(BB)) { + auto *BO = dyn_cast(&I); + if (!BO) continue; + + auto *T = BO->getType(); + auto *VecT = dyn_cast(T); + auto *ElT = VecT ? 
VecT->getElementType() : T; + + if (ElT != HalfT) continue; + + auto *LHS = BO->getOperand(0); + auto *RHS = BO->getOperand(1); + assert(LHS->getType() == T && + "Expected matching types for floating point operation"); + assert(RHS->getType() == T && + "Expected matching types for floating point operation"); + + auto *ExtElT = FloatT; + auto *ExtT = + VecT ? VectorType::get(ExtElT, VecT->getElementCount()) : ExtElT; + + auto *LHSExt = CreateFPExt(LHS, ExtT); + auto *RHSExt = CreateFPExt(RHS, ExtT); + + B.SetInsertPoint(BO); + + B.setFastMathFlags(BO->getFastMathFlags()); + auto *OpExt = B.CreateBinOp(BO->getOpcode(), LHSExt, RHSExt, + BO->getName() + ".fpext"); + B.clearFastMathFlags(); + + auto *Trunc = B.CreateFPTrunc(OpExt, T); + Trunc->takeName(BO); + + BO->replaceAllUsesWith(Trunc); + BO->eraseFromParent(); + + Changed = true; + } + } + + PreservedAnalyses PA; + if (Changed) { + PA = PreservedAnalyses::none(); + PA.preserveSet(); + } else { + PA = PreservedAnalyses::all(); + } + return PA; +} diff --git a/source/cl/test/UnitCL/source/ktst_precision.cpp b/source/cl/test/UnitCL/source/ktst_precision.cpp index d3e1f72ae..8eb339e90 100644 --- a/source/cl/test/UnitCL/source/ktst_precision.cpp +++ b/source/cl/test/UnitCL/source/ktst_precision.cpp @@ -507,12 +507,7 @@ TEST_P(HalfMathBuiltins, Precision_16_Half_Ceil) { TestAgainstRef<0_ULP>(ceil_ref); } -// TODO: CA-2731 -#ifdef __arm__ -TEST_P(HalfMathBuiltins, DISABLED_Precision_17_Half_sqrt) { -#else TEST_P(HalfMathBuiltins, Precision_17_Half_sqrt) { -#endif if (!UCL::hasHalfSupport(device)) { GTEST_SKIP(); } @@ -756,12 +751,7 @@ TEST_P(HalfMathBuiltins, Precision_31_Half_Nan) { this->RunGeneric1D(N / vec_width); } -// TODO: CA-2731 -#ifdef __arm__ -TEST_P(HalfMathBuiltins, DISABLED_Precision_32_Half_Mad) { -#else TEST_P(HalfMathBuiltins, Precision_32_Half_Mad) { -#endif if (!UCL::hasHalfSupport(device)) { GTEST_SKIP(); } @@ -1744,12 +1734,7 @@ TEST_P(HalfMathBuiltins, Precision_75_Half_lgammar_private) { 
TestAgainstIntReferenceArgRef(lgammar_ref); } -// TODO: CA-2731 -#ifdef __arm__ -TEST_P(HalfMathBuiltins, DISABLED_Precision_76_Half_tgamma) { -#else TEST_P(HalfMathBuiltins, Precision_76_Half_tgamma) { -#endif if (!UCL::hasHalfSupport(device)) { GTEST_SKIP(); } @@ -1789,9 +1774,7 @@ TEST_P(HalfMathBuiltins, Precision_79_Half_tanh) { TestAgainstRef<2_ULP>(tanh_ref); } -// TODO: CA-2731 -// TODO: CA-4735 -TEST_P(HalfMathBuiltins, DISABLED_Precision_80_Half_pow) { +TEST_P(HalfMathBuiltins, Precision_80_Half_pow) { if (!UCL::hasHalfSupport(device)) { GTEST_SKIP(); } @@ -1813,8 +1796,7 @@ TEST_P(HalfMathBuiltins, DISABLED_Precision_80_Half_pow) { TestAgainstRef<4_ULP>(pow_ref); } -// TODO: CA-2731 -// TODO: CA-4735 +// TODO: OCK-523 TEST_P(HalfMathBuiltins, DISABLED_Precision_81_Half_powr) { if (!UCL::hasHalfSupport(device)) { GTEST_SKIP(); @@ -1851,12 +1833,7 @@ TEST_P(HalfMathBuiltins, DISABLED_Precision_81_Half_powr) { TestAgainstRef<4_ULP>(powr_ref); } -#ifdef __arm__ -// TODO: CA-2731 -TEST_P(HalfMathBuiltins, DISABLED_Precision_82_Half_pown) { -#else TEST_P(HalfMathBuiltins, Precision_82_Half_pown) { -#endif if (!UCL::hasHalfSupport(device)) { GTEST_SKIP(); }