Skip to content

Commit

Permalink
Merge pull request #236 from hvdijk/manual-type-legalization
Browse files Browse the repository at this point in the history
Manual type legalization.
  • Loading branch information
hvdijk authored Dec 7, 2023
2 parents d0f09c8 + b3ad276 commit 379f43d
Show file tree
Hide file tree
Showing 13 changed files with 263 additions and 25 deletions.
14 changes: 14 additions & 0 deletions doc/modules/compiler/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,20 @@ about some of the above types, such as the type of images passed to any of the
it may be required to skip other passes such as the
``compiler::ImageArgumentSubstitutionPass``.

ManualTypeLegalizationPass
--------------------------

The ``ManualTypeLegalizationPass`` replaces binary ``half`` operations with
``float`` operations, inserting conversions as needed. It does this to work
around LLVM issue 73805, where LLVM's own legalization replaces whole chains of
operations rather than each operation individually, thus leaving out rounding
operations implied by the LLVM IR.

This replacement is only done on targets that promote ``half`` to ``float``
during type legalization. On targets where ``half`` is a native type, or where
``half`` is known to be promoted using "soft-promotion" rules, LLVM is presumed
to translate ``half`` correctly.

Metadata Utilities
------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <compiler/utils/define_mux_dma_pass.h>
#include <compiler/utils/encode_kernel_metadata_pass.h>
#include <compiler/utils/link_builtins_pass.h>
#include <compiler/utils/manual_type_legalization_pass.h>
#include <compiler/utils/metadata_analysis.h>
#include <compiler/utils/replace_address_space_qualifier_functions_pass.h>
#include <compiler/utils/replace_mem_intrinsics_pass.h>
Expand Down Expand Up @@ -156,6 +157,9 @@ llvm::ModulePassManager RefSiG1PassMachinery::getLateTargetPasses() {

addLLVMDefaultPerModulePipeline(PM, getPB(), options);

PM.addPass(llvm::createModuleToFunctionPassAdaptor(
compiler::utils::ManualTypeLegalizationPass()));

if (env_debug_prefix) {
// With all passes scheduled, add a callback pass to view the
// assembly/object file, if requested.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <compiler/utils/cl_builtin_info.h>
#include <compiler/utils/encode_kernel_metadata_pass.h>
#include <compiler/utils/link_builtins_pass.h>
#include <compiler/utils/manual_type_legalization_pass.h>
#include <compiler/utils/metadata_analysis.h>
#include <compiler/utils/replace_address_space_qualifier_functions_pass.h>
#include <compiler/utils/replace_local_module_scope_variables_pass.h>
Expand Down Expand Up @@ -159,6 +160,9 @@ llvm::ModulePassManager RefSiM1PassMachinery::getLateTargetPasses() {

addLLVMDefaultPerModulePipeline(PM, getPB(), options);

PM.addPass(llvm::createModuleToFunctionPassAdaptor(
compiler::utils::ManualTypeLegalizationPass()));

if (env_debug_prefix) {
// With all passes scheduled, add a callback pass to view the
// assembly/object file, if requested.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <compiler/utils/attributes.h>
#include <compiler/utils/encode_kernel_metadata_pass.h>
#include <compiler/utils/link_builtins_pass.h>
#include <compiler/utils/manual_type_legalization_pass.h>
#include <compiler/utils/metadata.h>
#include <compiler/utils/metadata_analysis.h>
#include <compiler/utils/replace_local_module_scope_variables_pass.h>
Expand Down Expand Up @@ -222,6 +223,9 @@ llvm::ModulePassManager {{cookiecutter.target_name.capitalize()}}PassMachinery::

addLLVMDefaultPerModulePipeline(PM, getPB(), options);

PM.addPass(llvm::createModuleToFunctionPassAdaptor(
compiler::utils::ManualTypeLegalizationPass()));

if (env_debug_prefix) {
// With all passes scheduled, add a callback pass to view the
// assembly/object file, if requested.
Expand Down
4 changes: 4 additions & 0 deletions modules/compiler/riscv/source/riscv_pass_machinery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <compiler/utils/attributes.h>
#include <compiler/utils/encode_kernel_metadata_pass.h>
#include <compiler/utils/link_builtins_pass.h>
#include <compiler/utils/manual_type_legalization_pass.h>
#include <compiler/utils/metadata.h>
#include <compiler/utils/metadata_analysis.h>
#include <compiler/utils/replace_address_space_qualifier_functions_pass.h>
Expand Down Expand Up @@ -255,6 +256,9 @@ llvm::ModulePassManager RiscvPassMachinery::getLateTargetPasses() {

addLLVMDefaultPerModulePipeline(PM, getPB(), options);

PM.addPass(llvm::createModuleToFunctionPassAdaptor(
compiler::utils::ManualTypeLegalizationPass()));

if (env_debug_prefix) {
// With all passes scheduled, add a callback pass to view the
// assembly/object file, if requested.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <compiler/utils/link_builtins_pass.h>
#include <compiler/utils/lower_to_mux_builtins_pass.h>
#include <compiler/utils/make_function_name_unique_pass.h>
#include <compiler/utils/manual_type_legalization_pass.h>
#include <compiler/utils/metadata_analysis.h>
#include <compiler/utils/optimal_builtin_replacement_pass.h>
#include <compiler/utils/pipeline_parse_helpers.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ FUNCTION_PASS("bit-shift-fixup", compiler::BitShiftFixupPass())
FUNCTION_PASS("ca-mem2reg", compiler::MemToRegPass())
FUNCTION_PASS("check-unsupported-types", compiler::CheckForUnsupportedTypesPass())
FUNCTION_PASS("combine-fpext-fptrunc", compiler::CombineFPExtFPTruncPass())
FUNCTION_PASS("manual-type-legalization", compiler::utils::ManualTypeLegalizationPass())
FUNCTION_PASS("software-div", compiler::SoftwareDivisionPass())
FUNCTION_PASS("replace-addrspace-fns", compiler::utils::ReplaceAddressSpaceQualifierFunctionsPass())
FUNCTION_PASS("remove-lifetime", compiler::utils::RemoveLifetimeIntrinsicsPass())
Expand Down
4 changes: 4 additions & 0 deletions modules/compiler/targets/host/source/HostPassMachinery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <compiler/utils/compute_local_memory_usage_pass.h>
#include <compiler/utils/define_mux_builtins_pass.h>
#include <compiler/utils/make_function_name_unique_pass.h>
#include <compiler/utils/manual_type_legalization_pass.h>
#include <compiler/utils/metadata.h>
#include <compiler/utils/metadata_analysis.h>
#include <compiler/utils/pipeline_parse_helpers.h>
Expand Down Expand Up @@ -330,6 +331,9 @@ llvm::ModulePassManager HostPassMachinery::getKernelFinalizationPasses(
compiler::utils::VectorizeMetadataAnalysis,
handler::VectorizeInfoMetadataHandler>());

PM.addPass(llvm::createModuleToFunctionPassAdaptor(
compiler::utils::ManualTypeLegalizationPass()));

return PM;
}

Expand Down
38 changes: 38 additions & 0 deletions modules/compiler/test/lit/passes/manual_type_legalization.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
; Copyright (C) Codeplay Software Limited
;
; Licensed under the Apache License, Version 2.0 (the "License") with LLVM
; Exceptions; you may not use this file except in compliance with the License.
; You may obtain a copy of the License at
;
; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
;
; Unless required by applicable law or agreed to in writing, software
; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
; License for the specific language governing permissions and limitations
; under the License.
;
; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

; RUN: muxc --passes manual-type-legalization,verify -S %s | FileCheck %s

; Make sure we use a triple that does not have half as a legal type.
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "spir64-unknown-unknown"

; CHECK-LABEL: define half @f
; CHECK-DAG: [[AEXT:%.*]] = fpext half %a to float
; CHECK-DAG: [[BEXT:%.*]] = fpext half %b to float
; CHECK-DAG: [[CEXT:%.*]] = fpext half %c to float
; CHECK-DAG: [[DADD:%.*]] = fadd float [[AEXT]], [[BEXT]]
; CHECK-DAG: [[DTRUNC:%.*]] = fptrunc float [[DADD]] to half
; CHECK-DAG: [[DEXT:%.*]] = fpext half [[DTRUNC]] to float
; CHECK-DAG: [[EADD:%.*]] = fadd float [[DEXT]], [[CEXT]]
; CHECK-DAG: [[ETRUNC:%.*]] = fptrunc float [[EADD]] to half
; CHECK: ret half [[ETRUNC]]
; Two dependent half additions: the second consumes the result of the first.
; The expected output above requires the intermediate value %d to be truncated
; to half and re-extended to float before it feeds the second addition, i.e.
; each operation is promoted individually rather than the chain as a whole.
define half @f(half %a, half %b, half %c) {
entry:
%d = fadd half %a, %b
%e = fadd half %d, %c
ret half %e
}
2 changes: 2 additions & 0 deletions modules/compiler/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ add_ca_library(compiler-utils STATIC
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/llvm_global_mutex.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/lower_to_mux_builtins_pass.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/make_function_name_unique_pass.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/manual_type_legalization_pass.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/mangling.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/memory_buffer.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/metadata.h
Expand Down Expand Up @@ -92,6 +93,7 @@ add_ca_library(compiler-utils STATIC
${CMAKE_CURRENT_SOURCE_DIR}/source/lower_to_mux_builtins_pass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/make_function_name_unique_pass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/mangling.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/manual_type_legalization_pass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/metadata.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/metadata_analysis.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/metadata_hooks.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (C) Codeplay Software Limited
//
// Licensed under the Apache License, Version 2.0 (the "License") with LLVM
// Exceptions; you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef COMPILER_UTILS_MANUAL_TYPE_LEGALIZATION_PASS_H_INCLUDED
#define COMPILER_UTILS_MANUAL_TYPE_LEGALIZATION_PASS_H_INCLUDED

#include <llvm/IR/PassManager.h>

namespace compiler {
namespace utils {

/// Manual type legalization pass.
///
/// On targets that do not natively support \c half, promote operations on \c
/// half to \c float instead.
///
/// When LLVM encounters floating point operations in a type it does not support
/// natively, it extends its operands to an extended precision floating point
/// type, performs the operation in that extended type, and rounds the result
/// back to the original type. However, when it extends its operands to an
/// extended precision floating point type, if an operand itself was a floating
/// point operation that was also so extended, its rounding and re-extension are
/// skipped. This causes issues for code that relies on exact rounding of
/// intermediate results, which we avoid by manually doing this promotion
/// ourselves.
///
/// Simply performing operations in a wider floating point type and rounding
/// back to the narrow floating point type is not, in general, correct, due to
/// double rounding. For addition, subtraction, and multiplication, \c float
/// provides enough additional precision that double rounding is known not to be
/// an issue. For other operations, this pass may generate incorrect results,
/// but this should only happen in cases where letting the operation pass
/// through to LLVM would result in the same incorrect results.
struct ManualTypeLegalizationPass final
: llvm::PassInfoMixin<ManualTypeLegalizationPass> {
/// Runs the pass over a single function.
///
/// \param F The function whose \c half operations are to be promoted.
/// \param FAM The function analysis manager used to query target information.
/// \return The set of analyses that remain valid after the pass has run.
llvm::PreservedAnalyses run(llvm::Function &F,
llvm::FunctionAnalysisManager &FAM);
};

} // namespace utils
} // namespace compiler

#endif // COMPILER_UTILS_MANUAL_TYPE_LEGALIZATION_PASS_H_INCLUDED
129 changes: 129 additions & 0 deletions modules/compiler/utils/source/manual_type_legalization_pass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright (C) Codeplay Software Limited
//
// Licensed under the Apache License, Version 2.0 (the "License") with LLVM
// Exceptions; you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <compiler/utils/manual_type_legalization_pass.h>
#include <llvm/ADT/DenseMap.h>
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/InstrTypes.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Type.h>
#include <llvm/Support/Casting.h>
#include <llvm/TargetParser/Triple.h>
#include <multi_llvm/llvm_version.h>

using namespace llvm;

/// Promotes every scalar or vector \c half binary operator in \p F to
/// \c float.
///
/// Each matching operation is rewritten as: \c fpext both operands to the
/// corresponding \c float type, perform the same opcode on the extended
/// operands (carrying over the original fast-math flags), and \c fptrunc the
/// result back to the original type. Extensions of a given operand are cached
/// so that a value used by several operations is only extended once.
///
/// \param F The function to transform in place.
/// \param FAM Analysis manager, used to obtain \c TargetIRAnalysis for \p F.
/// \return \c PreservedAnalyses::all() if nothing was changed (including when
/// the target is skipped); otherwise only the CFG analyses are preserved.
PreservedAnalyses compiler::utils::ManualTypeLegalizationPass::run(
    Function &F, FunctionAnalysisManager &FAM) {
  auto &TTI = FAM.getResult<TargetIRAnalysis>(F);

  auto *HalfT = Type::getHalfTy(F.getContext());
  auto *FloatT = Type::getFloatTy(F.getContext());

  // Targets where half is a legal type do not need this pass. Targets where
  // half is promoted using "soft promotion" rules also do not need this pass.
  // We cannot reliably determine which targets these are, but that is okay: on
  // targets where this pass is not needed it does no harm, it merely wastes
  // time. X86 and RISC-V are special-cased below (presumably as known
  // soft-promotion targets).
  llvm::Triple TT(F.getParent()->getTargetTriple());
  if (TTI.isTypeLegal(HalfT) || TT.isX86() || TT.isRISCV()) {
    return PreservedAnalyses::all();
  }

  // Cache of already-created extensions, keyed on the value being extended.
  DenseMap<Value *, Value *> FPExtVals;
  IRBuilder<> B(F.getContext());

  // Returns the extension of V to ExtTy, creating it on first request. The
  // extension is placed immediately after the definition of V (or at the top
  // of the function, past the allocas, for non-instruction values) so that it
  // is available to every user of V.
  auto CreateFPExt = [&](Value *V, Type *ExtTy) {
    auto *&FPExt = FPExtVals[V];
    if (!FPExt) {
      if (auto *I = dyn_cast<Instruction>(V)) {
#if LLVM_VERSION_GREATER_EQUAL(18, 0)
        std::optional<BasicBlock::iterator> IPAD;
        IPAD = I->getInsertionPointAfterDef();
#else
        std::optional<Instruction *> IPAD;
        if (auto *IPADRaw = I->getInsertionPointAfterDef()) {
          IPAD = IPADRaw;
        }
#endif
        assert(IPAD &&
               "getInsertionPointAfterDef() should return an insertion point "
               "for all FP16 instructions");
        B.SetInsertPoint(*IPAD);
      } else {
        B.SetInsertPointPastAllocas(&F);
      }
      FPExt = B.CreateFPExt(V, ExtTy, V->getName() + ".fpext");
    }
    return FPExt;
  };

  bool Changed = false;

  for (auto &BB : F) {
    // Early-increment iteration: the current instruction is erased below.
    for (auto &I : make_early_inc_range(BB)) {
      auto *BO = dyn_cast<BinaryOperator>(&I);
      if (!BO) continue;

      // Only half and vector-of-half operations are of interest.
      auto *T = BO->getType();
      auto *VecT = dyn_cast<VectorType>(T);
      auto *ElT = VecT ? VecT->getElementType() : T;

      if (ElT != HalfT) continue;

      auto *LHS = BO->getOperand(0);
      auto *RHS = BO->getOperand(1);
      assert(LHS->getType() == T &&
             "Expected matching types for floating point operation");
      assert(RHS->getType() == T &&
             "Expected matching types for floating point operation");

      // Build the promoted type, keeping the element count for vectors.
      auto *ExtElT = FloatT;
      auto *ExtT =
          VecT ? VectorType::get(ExtElT, VecT->getElementCount()) : ExtElT;

      auto *LHSExt = CreateFPExt(LHS, ExtT);
      auto *RHSExt = CreateFPExt(RHS, ExtT);

      B.SetInsertPoint(BO);

      // Apply the original operation's fast-math flags to the promoted
      // operation only; they are cleared again before creating the truncation.
      B.setFastMathFlags(BO->getFastMathFlags());
      auto *OpExt = B.CreateBinOp(BO->getOpcode(), LHSExt, RHSExt,
                                  BO->getName() + ".fpext");
      B.clearFastMathFlags();

      auto *Trunc = B.CreateFPTrunc(OpExt, T);
      Trunc->takeName(BO);

      BO->replaceAllUsesWith(Trunc);

      // If a user of BO was visited before BO itself (block layout need not
      // follow dominance order), BO is already a key in the extension cache.
      // Remove that entry before destroying BO: the map is keyed on raw
      // pointers, so a stale key could otherwise alias a later-allocated
      // instruction and hand back the wrong extension.
      FPExtVals.erase(BO);
      BO->eraseFromParent();

      Changed = true;
    }
  }

  // Instructions are only added/removed within blocks, so the CFG is intact.
  PreservedAnalyses PA;
  if (Changed) {
    PA = PreservedAnalyses::none();
    PA.preserveSet<CFGAnalyses>();
  } else {
    PA = PreservedAnalyses::all();
  }
  return PA;
}
Loading

0 comments on commit 379f43d

Please sign in to comment.