-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
LLVM's own type legalization promotes floating point operations without truncation of intermediate results. We rely on that truncation, so run our own pass to legalize manually before LLVM's legalization runs.
- Loading branch information
Showing
13 changed files
with
263 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
38 changes: 38 additions & 0 deletions
38
modules/compiler/test/lit/passes/manual_type_legalization.ll
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
; Copyright (C) Codeplay Software Limited | ||
; | ||
; Licensed under the Apache License, Version 2.0 (the "License") with LLVM | ||
; Exceptions; you may not use this file except in compliance with the License. | ||
; You may obtain a copy of the License at | ||
; | ||
; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt | ||
; | ||
; Unless required by applicable law or agreed to in writing, software | ||
; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
; License for the specific language governing permissions and limitations | ||
; under the License. | ||
; | ||
; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
; RUN: muxc --passes manual-type-legalization,verify -S %s | FileCheck %s | ||
|
||
; Make sure we use a triple that does not have half as a legal type. | ||
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" | ||
target triple = "spir64-unknown-unknown" | ||
|
||
; CHECK-LABEL: define half @f | ||
; CHECK-DAG: [[AEXT:%.*]] = fpext half %a to float | ||
; CHECK-DAG: [[BEXT:%.*]] = fpext half %b to float | ||
; CHECK-DAG: [[CEXT:%.*]] = fpext half %c to float | ||
; CHECK-DAG: [[DADD:%.*]] = fadd float [[AEXT]], [[BEXT]] | ||
; CHECK-DAG: [[DTRUNC:%.*]] = fptrunc float [[DADD]] to half | ||
; CHECK-DAG: [[DEXT:%.*]] = fpext half [[DTRUNC]] to float | ||
; CHECK-DAG: [[EADD:%.*]] = fadd float [[DEXT]], [[CEXT]] | ||
; CHECK-DAG: [[ETRUNC:%.*]] = fptrunc float [[EADD]] to half | ||
; CHECK: ret half [[ETRUNC]] | ||
define half @f(half %a, half %b, half %c) { | ||
entry: | ||
%d = fadd half %a, %b | ||
%e = fadd half %d, %c | ||
ret half %e | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
56 changes: 56 additions & 0 deletions
56
modules/compiler/utils/include/compiler/utils/manual_type_legalization_pass.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
// Copyright (C) Codeplay Software Limited | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License") with LLVM | ||
// Exceptions; you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
// License for the specific language governing permissions and limitations | ||
// under the License. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#ifndef COMPILER_UTILS_MANUAL_TYPE_LEGALIZATION_PASS_H_INCLUDED | ||
#define COMPILER_UTILS_MANUAL_TYPE_LEGALIZATION_PASS_H_INCLUDED | ||
|
||
#include <llvm/IR/PassManager.h> | ||
|
||
namespace compiler { | ||
namespace utils { | ||
|
||
/// Manual type legalization pass. | ||
/// | ||
/// On targets that do not natively support \c half, promote operations on \c | ||
/// half to \c float instead. | ||
/// | ||
/// When LLVM encounters floating point operations in a type it does not support | ||
/// natively, it extends its operands to an extended precision floating point | ||
/// type, performs the operation in that extended type, and rounds the result | ||
/// back to the original type. However, when it extends its operands to an | ||
/// extended precision floating point type, if an operand itself was a floating | ||
/// point operation that was also so extended, its rounding and re-extension are | ||
/// skipped. This causes issues for code that relies on exact rounding of | ||
/// intermediate results, which we avoid by manually doing this promition | ||
/// ourselves. | ||
/// | ||
/// Simply performing operations in a wider floating point type and rounding | ||
/// back to the narrow floating point type is not, in general, correct, due to | ||
/// double rounding. For addition, subtraction, and multiplications, \c float | ||
/// provides enough additional precision that double rounding is known not to be | ||
/// an issue. For other operations, this pass may generate incorrect results, | ||
/// but this should only happen in cases where letting the operation pass | ||
/// through to LLVM would result in the same incorrect results. | ||
struct ManualTypeLegalizationPass final | ||
: llvm::PassInfoMixin<ManualTypeLegalizationPass> { | ||
llvm::PreservedAnalyses run(llvm::Function &F, | ||
llvm::FunctionAnalysisManager &FAM); | ||
}; | ||
|
||
} // namespace utils | ||
} // namespace compiler | ||
|
||
#endif // COMPILER_UTILS_MANUAL_TYPE_LEGALIZATION_PASS_H_INCLUDED |
129 changes: 129 additions & 0 deletions
129
modules/compiler/utils/source/manual_type_legalization_pass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
// Copyright (C) Codeplay Software Limited | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License") with LLVM | ||
// Exceptions; you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
// License for the specific language governing permissions and limitations | ||
// under the License. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include <compiler/utils/manual_type_legalization_pass.h> | ||
#include <llvm/ADT/DenseMap.h> | ||
#include <llvm/Analysis/TargetTransformInfo.h> | ||
#include <llvm/IR/IRBuilder.h> | ||
#include <llvm/IR/InstrTypes.h> | ||
#include <llvm/IR/Instructions.h> | ||
#include <llvm/IR/Module.h> | ||
#include <llvm/IR/PassManager.h> | ||
#include <llvm/IR/Type.h> | ||
#include <llvm/Support/Casting.h> | ||
#include <llvm/TargetParser/Triple.h> | ||
#include <multi_llvm/llvm_version.h> | ||
|
||
using namespace llvm; | ||
|
||
PreservedAnalyses compiler::utils::ManualTypeLegalizationPass::run( | ||
Function &F, FunctionAnalysisManager &FAM) { | ||
auto &TTI = FAM.getResult<TargetIRAnalysis>(F); | ||
|
||
auto *HalfT = Type::getHalfTy(F.getContext()); | ||
auto *FloatT = Type::getFloatTy(F.getContext()); | ||
|
||
// Targets where half is a legal type do not need this pass. Targets where | ||
// half is promoted using "soft promotion" rules also do not need this pass. | ||
// We cannot reliably determine which targets these are, but that is okay, on | ||
// targets where this pass is not needed it does no harm, it merely wastes | ||
// time. | ||
llvm::Triple TT(F.getParent()->getTargetTriple()); | ||
if (TTI.isTypeLegal(HalfT) || TT.isX86() || TT.isRISCV()) { | ||
return PreservedAnalyses::all(); | ||
} | ||
|
||
DenseMap<Value *, Value *> FPExtVals; | ||
IRBuilder<> B(F.getContext()); | ||
|
||
auto CreateFPExt = [&](Value *V, Type *ExtTy) { | ||
auto *&FPExt = FPExtVals[V]; | ||
if (!FPExt) { | ||
if (auto *I = dyn_cast<Instruction>(V)) { | ||
#if LLVM_VERSION_GREATER_EQUAL(18, 0) | ||
std::optional<BasicBlock::iterator> IPAD; | ||
IPAD = I->getInsertionPointAfterDef(); | ||
#else | ||
std::optional<Instruction *> IPAD; | ||
if (auto *IPADRaw = I->getInsertionPointAfterDef()) { | ||
IPAD = IPADRaw; | ||
} | ||
#endif | ||
assert(IPAD && | ||
"getInsertionPointAfterDef() should return an insertion point " | ||
"for all FP16 instructions"); | ||
B.SetInsertPoint(*IPAD); | ||
} else { | ||
B.SetInsertPointPastAllocas(&F); | ||
} | ||
FPExt = B.CreateFPExt(V, ExtTy, V->getName() + ".fpext"); | ||
} | ||
return FPExt; | ||
}; | ||
|
||
bool Changed = false; | ||
|
||
for (auto &BB : F) { | ||
for (auto &I : make_early_inc_range(BB)) { | ||
auto *BO = dyn_cast<BinaryOperator>(&I); | ||
if (!BO) continue; | ||
|
||
auto *T = BO->getType(); | ||
auto *VecT = dyn_cast<VectorType>(T); | ||
auto *ElT = VecT ? VecT->getElementType() : T; | ||
|
||
if (ElT != HalfT) continue; | ||
|
||
auto *LHS = BO->getOperand(0); | ||
auto *RHS = BO->getOperand(1); | ||
assert(LHS->getType() == T && | ||
"Expected matching types for floating point operation"); | ||
assert(RHS->getType() == T && | ||
"Expected matching types for floating point operation"); | ||
|
||
auto *ExtElT = FloatT; | ||
auto *ExtT = | ||
VecT ? VectorType::get(ExtElT, VecT->getElementCount()) : ExtElT; | ||
|
||
auto *LHSExt = CreateFPExt(LHS, ExtT); | ||
auto *RHSExt = CreateFPExt(RHS, ExtT); | ||
|
||
B.SetInsertPoint(BO); | ||
|
||
B.setFastMathFlags(BO->getFastMathFlags()); | ||
auto *OpExt = B.CreateBinOp(BO->getOpcode(), LHSExt, RHSExt, | ||
BO->getName() + ".fpext"); | ||
B.clearFastMathFlags(); | ||
|
||
auto *Trunc = B.CreateFPTrunc(OpExt, T); | ||
Trunc->takeName(BO); | ||
|
||
BO->replaceAllUsesWith(Trunc); | ||
BO->eraseFromParent(); | ||
|
||
Changed = true; | ||
} | ||
} | ||
|
||
PreservedAnalyses PA; | ||
if (Changed) { | ||
PA = PreservedAnalyses::none(); | ||
PA.preserveSet<CFGAnalyses>(); | ||
} else { | ||
PA = PreservedAnalyses::all(); | ||
} | ||
return PA; | ||
} |
Oops, something went wrong.