Skip to content

Commit 64bc42d

Browse files
[SYCL][SPIRVBE] Reland support for SPV_KHR_float_controls2 (#20612)
When it originally landed through pulldown, it caused some problems with SYCL E2E tests. It should be passing now. --------- Co-authored-by: Dmitry Sidorov <[email protected]>
1 parent ca1636c commit 64bc42d

23 files changed

+1327
-55
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,3 +593,31 @@ Group and Subgroup Operations
593593
For workgroup and subgroup operations, LLVM uses function calls to represent SPIR-V's
594594
group-based instructions. These builtins facilitate group synchronization, data sharing,
595595
and collective operations essential for efficient parallel computation.
596+
597+
SPIR-V Instructions Mapped to LLVM Metadata
598+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
599+
Some SPIR-V instructions don't have a direct equivalent in the LLVM IR language. To
600+
address this, the SPIR-V Target uses different specific LLVM named metadata to convey
601+
the necessary information. The SPIR-V specification allows multiple module-scope
602+
instructions, where as LLVM named metadata must be unique. Therefore, the encoding of
603+
such instructions has the following format:
604+
605+
.. code-block:: llvm
606+
607+
!spirv.<OpCodeName> = !{!<InstructionMetadata1>, !<InstructionMetadata2>, ..}
608+
!<InstructionMetadata1> = !{<Operand1>, <Operand2>, ..}
609+
!<InstructionMetadata2> = !{<Operand1>, <Operand2>, ..}
610+
611+
Below, you will find the mappings between SPIR-V instruction and their corresponding
612+
LLVM IR representations.
613+
614+
+--------------------+---------------------------------------------------------+
615+
| SPIR-V instruction | LLVM IR |
616+
+====================+=========================================================+
617+
| OpExecutionMode | .. code-block:: llvm |
618+
| | |
619+
| | !spirv.ExecutionMode = !{!0} |
620+
| | !0 = !{void @worker, i32 30, i32 262149} |
621+
| | ; Set execution mode with id 30 (VecTypeHint) and |
622+
| | ; literal `262149` operand. |
623+
+--------------------+---------------------------------------------------------+

llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp

Lines changed: 194 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class SPIRVAsmPrinter : public AsmPrinter {
8383
void outputExecutionMode(const Module &M);
8484
void outputAnnotations(const Module &M);
8585
void outputModuleSections();
86+
void outputFPFastMathDefaultInfo();
8687
bool isHidden() {
8788
return MF->getFunction()
8889
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
@@ -514,11 +515,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
514515
NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
515516
if (Node) {
516517
for (unsigned i = 0; i < Node->getNumOperands(); i++) {
518+
// If SPV_KHR_float_controls2 is enabled and we find any of
519+
// FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution
520+
// modes, skip it, it'll be done somewhere else.
521+
if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
522+
const auto EM =
523+
cast<ConstantInt>(
524+
cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1))
525+
->getValue())
526+
->getZExtValue();
527+
if (EM == SPIRV::ExecutionMode::FPFastMathDefault ||
528+
EM == SPIRV::ExecutionMode::ContractionOff ||
529+
EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve)
530+
continue;
531+
}
532+
517533
MCInst Inst;
518534
Inst.setOpcode(SPIRV::OpExecutionMode);
519535
addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
520536
outputMCInst(Inst);
521537
}
538+
outputFPFastMathDefaultInfo();
522539
}
523540
for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
524541
const Function &F = *FI;
@@ -572,12 +589,84 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
572589
}
573590
if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") &&
574591
!M.getNamedMetadata("opencl.enable.FP_CONTRACT")) {
575-
MCInst Inst;
576-
Inst.setOpcode(SPIRV::OpExecutionMode);
577-
Inst.addOperand(MCOperand::createReg(FReg));
578-
unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
579-
Inst.addOperand(MCOperand::createImm(EM));
580-
outputMCInst(Inst);
592+
if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
593+
// When SPV_KHR_float_controls2 is enabled, ContractionOff is
594+
// deprecated. We need to use FPFastMathDefault with the appropriate
595+
// flags instead. Since FPFastMathDefault takes a target type, we need
596+
// to emit it for each floating-point type that exists in the module
597+
// to match the effect of ContractionOff. As of now, there are 3 FP
598+
// types: fp16, fp32 and fp64.
599+
600+
// We only end up here because there is no "spirv.ExecutionMode"
601+
// metadata, so that means no FPFastMathDefault. Therefore, we only
602+
// need to make sure AllowContract is set to 0, as the rest of flags.
603+
// We still need to emit the OpExecutionMode instruction, otherwise
604+
// it's up to the client API to define the flags. Therefore, we need
605+
// to find the constant with 0 value.
606+
607+
// Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of
608+
// type int32 with 0 value to represent the FP Fast Math Mode.
609+
std::vector<const MachineInstr *> SPIRVFloatTypes;
610+
const MachineInstr *ConstZero = nullptr;
611+
for (const MachineInstr *MI :
612+
MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
613+
// Skip if the instruction is not OpTypeFloat or OpConstant.
614+
unsigned OpCode = MI->getOpcode();
615+
if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantNull)
616+
continue;
617+
618+
// Collect the SPIRV type if it's a float.
619+
if (OpCode == SPIRV::OpTypeFloat) {
620+
// Skip if the target type is not fp16, fp32, fp64.
621+
const unsigned OpTypeFloatSize = MI->getOperand(1).getImm();
622+
if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 &&
623+
OpTypeFloatSize != 64) {
624+
continue;
625+
}
626+
SPIRVFloatTypes.push_back(MI);
627+
} else {
628+
// Check if the constant is int32, if not skip it.
629+
const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
630+
MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
631+
if (!TypeMI || TypeMI->getOperand(1).getImm() != 32)
632+
continue;
633+
634+
ConstZero = MI;
635+
}
636+
}
637+
638+
// When SPV_KHR_float_controls2 is enabled, ContractionOff is
639+
// deprecated. We need to use FPFastMathDefault with the appropriate
640+
// flags instead. Since FPFastMathDefault takes a target type, we need
641+
// to emit it for each floating-point type that exists in the module
642+
// to match the effect of ContractionOff. As of now, there are 3 FP
643+
// types: fp16, fp32 and fp64.
644+
for (const MachineInstr *MI : SPIRVFloatTypes) {
645+
MCInst Inst;
646+
Inst.setOpcode(SPIRV::OpExecutionModeId);
647+
Inst.addOperand(MCOperand::createReg(FReg));
648+
unsigned EM =
649+
static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
650+
Inst.addOperand(MCOperand::createImm(EM));
651+
const MachineFunction *MF = MI->getMF();
652+
MCRegister TypeReg =
653+
MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
654+
Inst.addOperand(MCOperand::createReg(TypeReg));
655+
assert(ConstZero && "There should be a constant zero.");
656+
MCRegister ConstReg = MAI->getRegisterAlias(
657+
ConstZero->getMF(), ConstZero->getOperand(0).getReg());
658+
Inst.addOperand(MCOperand::createReg(ConstReg));
659+
outputMCInst(Inst);
660+
}
661+
} else {
662+
MCInst Inst;
663+
Inst.setOpcode(SPIRV::OpExecutionMode);
664+
Inst.addOperand(MCOperand::createReg(FReg));
665+
unsigned EM =
666+
static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
667+
Inst.addOperand(MCOperand::createImm(EM));
668+
outputMCInst(Inst);
669+
}
581670
}
582671
}
583672
}
@@ -626,6 +715,101 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
626715
}
627716
}
628717

718+
void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() {
719+
// Collect the SPIRVTypes that are OpTypeFloat and the constants of type
720+
// int32, that might be used as FP Fast Math Mode.
721+
std::vector<const MachineInstr *> SPIRVFloatTypes;
722+
// Hashtable to associate immediate values with the constant holding them.
723+
std::unordered_map<int, const MachineInstr *> ConstMap;
724+
for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
725+
// Skip if the instruction is not OpTypeFloat or OpConstant.
726+
unsigned OpCode = MI->getOpcode();
727+
if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantI &&
728+
OpCode != SPIRV::OpConstantNull)
729+
continue;
730+
731+
// Collect the SPIRV type if it's a float.
732+
if (OpCode == SPIRV::OpTypeFloat) {
733+
SPIRVFloatTypes.push_back(MI);
734+
} else {
735+
// Check if the constant is int32, if not skip it.
736+
const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
737+
MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
738+
if (!TypeMI || TypeMI->getOpcode() != SPIRV::OpTypeInt ||
739+
TypeMI->getOperand(1).getImm() != 32)
740+
continue;
741+
742+
if (OpCode == SPIRV::OpConstantI)
743+
ConstMap[MI->getOperand(2).getImm()] = MI;
744+
else
745+
ConstMap[0] = MI;
746+
}
747+
}
748+
749+
for (const auto &[Func, FPFastMathDefaultInfoVec] :
750+
MAI->FPFastMathDefaultInfoMap) {
751+
if (FPFastMathDefaultInfoVec.empty())
752+
continue;
753+
754+
for (const MachineInstr *MI : SPIRVFloatTypes) {
755+
unsigned OpTypeFloatSize = MI->getOperand(1).getImm();
756+
unsigned Index = SPIRV::FPFastMathDefaultInfoVector::
757+
computeFPFastMathDefaultInfoVecIndex(OpTypeFloatSize);
758+
assert(Index < FPFastMathDefaultInfoVec.size() &&
759+
"Index out of bounds for FPFastMathDefaultInfoVec");
760+
const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index];
761+
assert(FPFastMathDefaultInfo.Ty &&
762+
"Expected target type for FPFastMathDefaultInfo");
763+
assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() ==
764+
OpTypeFloatSize &&
765+
"Mismatched float type size");
766+
MCInst Inst;
767+
Inst.setOpcode(SPIRV::OpExecutionModeId);
768+
MCRegister FuncReg = MAI->getFuncReg(Func);
769+
assert(FuncReg.isValid());
770+
Inst.addOperand(MCOperand::createReg(FuncReg));
771+
Inst.addOperand(
772+
MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault));
773+
MCRegister TypeReg =
774+
MAI->getRegisterAlias(MI->getMF(), MI->getOperand(0).getReg());
775+
Inst.addOperand(MCOperand::createReg(TypeReg));
776+
unsigned Flags = FPFastMathDefaultInfo.FastMathFlags;
777+
if (FPFastMathDefaultInfo.ContractionOff &&
778+
(Flags & SPIRV::FPFastMathMode::AllowContract))
779+
report_fatal_error(
780+
"Conflicting FPFastMathFlags: ContractionOff and AllowContract");
781+
782+
if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
783+
!(Flags &
784+
(SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
785+
SPIRV::FPFastMathMode::NSZ))) {
786+
if (FPFastMathDefaultInfo.FPFastMathDefault)
787+
report_fatal_error("Conflicting FPFastMathFlags: "
788+
"SignedZeroInfNanPreserve but at least one of "
789+
"NotNaN/NotInf/NSZ is enabled.");
790+
}
791+
792+
// Don't emit if none of the execution modes was used.
793+
if (Flags == SPIRV::FPFastMathMode::None &&
794+
!FPFastMathDefaultInfo.ContractionOff &&
795+
!FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
796+
!FPFastMathDefaultInfo.FPFastMathDefault)
797+
continue;
798+
799+
// Retrieve the constant instruction for the immediate value.
800+
auto It = ConstMap.find(Flags);
801+
if (It == ConstMap.end())
802+
report_fatal_error("Expected constant instruction for FP Fast Math "
803+
"Mode operand of FPFastMathDefault execution mode.");
804+
const MachineInstr *ConstMI = It->second;
805+
MCRegister ConstReg = MAI->getRegisterAlias(
806+
ConstMI->getMF(), ConstMI->getOperand(0).getReg());
807+
Inst.addOperand(MCOperand::createReg(ConstReg));
808+
outputMCInst(Inst);
809+
}
810+
}
811+
}
812+
629813
void SPIRVAsmPrinter::outputModuleSections() {
630814
const Module *M = MMI->getModule();
631815
// Get the global subtarget to output module-level info.
@@ -634,15 +818,17 @@ void SPIRVAsmPrinter::outputModuleSections() {
634818
MAI = &SPIRVModuleAnalysis::MAI;
635819
assert(ST && TII && MAI && M && "Module analysis is required");
636820
// Output instructions according to the Logical Layout of a Module:
637-
// 1,2. All OpCapability instructions, then optional OpExtension instructions.
821+
// 1,2. All OpCapability instructions, then optional OpExtension
822+
// instructions.
638823
outputGlobalRequirements();
639824
// 3. Optional OpExtInstImport instructions.
640825
outputOpExtInstImports(*M);
641826
// 4. The single required OpMemoryModel instruction.
642827
outputOpMemoryModel();
643828
// 5. All entry point declarations, using OpEntryPoint.
644829
outputEntryPoints();
645-
// 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId.
830+
// 6. Execution-mode declarations, using OpExecutionMode or
831+
// OpExecutionModeId.
646832
outputExecutionMode(*M);
647833
// 7a. Debug: all OpString, OpSourceExtension, OpSource, and
648834
// OpSourceContinued, without forward references.

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,11 +1163,24 @@ static unsigned getNumSizeComponents(SPIRVType *imgType) {
11631163

11641164
static bool generateExtInst(const SPIRV::IncomingCall *Call,
11651165
MachineIRBuilder &MIRBuilder,
1166-
SPIRVGlobalRegistry *GR) {
1166+
SPIRVGlobalRegistry *GR, const CallBase &CB) {
11671167
// Lookup the extended instruction number in the TableGen records.
11681168
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
11691169
uint32_t Number =
11701170
SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number;
1171+
// fmin_common and fmax_common are now deprecated, and we should use fmin and
1172+
// fmax with NotInf and NotNaN flags instead. Keep original number to add
1173+
// later the NoNans and NoInfs flags.
1174+
uint32_t OrigNumber = Number;
1175+
const SPIRVSubtarget &ST =
1176+
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
1177+
if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2) &&
1178+
(Number == SPIRV::OpenCLExtInst::fmin_common ||
1179+
Number == SPIRV::OpenCLExtInst::fmax_common)) {
1180+
Number = (Number == SPIRV::OpenCLExtInst::fmin_common)
1181+
? SPIRV::OpenCLExtInst::fmin
1182+
: SPIRV::OpenCLExtInst::fmax;
1183+
}
11711184

11721185
// Build extended instruction.
11731186
auto MIB =
@@ -1179,6 +1192,13 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call,
11791192

11801193
for (auto Argument : Call->Arguments)
11811194
MIB.addUse(Argument);
1195+
MIB.getInstr()->copyIRFlags(CB);
1196+
if (OrigNumber == SPIRV::OpenCLExtInst::fmin_common ||
1197+
OrigNumber == SPIRV::OpenCLExtInst::fmax_common) {
1198+
// Add NoNans and NoInfs flags to fmin/fmax instruction.
1199+
MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoNans);
1200+
MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoInfs);
1201+
}
11821202
return true;
11831203
}
11841204

@@ -2918,7 +2938,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
29182938
MachineIRBuilder &MIRBuilder,
29192939
const Register OrigRet, const Type *OrigRetTy,
29202940
const SmallVectorImpl<Register> &Args,
2921-
SPIRVGlobalRegistry *GR) {
2941+
SPIRVGlobalRegistry *GR, const CallBase &CB) {
29222942
LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
29232943

29242944
// Lookup the builtin in the TableGen records.
@@ -2941,7 +2961,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
29412961
// Match the builtin with implementation based on the grouping.
29422962
switch (Call->Builtin->Group) {
29432963
case SPIRV::Extended:
2944-
return generateExtInst(Call.get(), MIRBuilder, GR);
2964+
return generateExtInst(Call.get(), MIRBuilder, GR, CB);
29452965
case SPIRV::Relational:
29462966
return generateRelationalInst(Call.get(), MIRBuilder, GR);
29472967
case SPIRV::Group:

llvm/lib/Target/SPIRV/SPIRVBuiltins.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
3939
MachineIRBuilder &MIRBuilder,
4040
const Register OrigRet, const Type *OrigRetTy,
4141
const SmallVectorImpl<Register> &Args,
42-
SPIRVGlobalRegistry *GR);
42+
SPIRVGlobalRegistry *GR, const CallBase &CB);
4343

4444
/// Helper function for finding a builtin function attributes
4545
/// by a demangled function name. Defined in SPIRVBuiltins.cpp.

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -631,9 +631,9 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
631631
GR->getPointerSize()));
632632
}
633633
}
634-
if (auto Res =
635-
SPIRV::lowerBuiltin(DemangledName, ST->getPreferredInstructionSet(),
636-
MIRBuilder, ResVReg, OrigRetTy, ArgVRegs, GR))
634+
if (auto Res = SPIRV::lowerBuiltin(
635+
DemangledName, ST->getPreferredInstructionSet(), MIRBuilder,
636+
ResVReg, OrigRetTy, ArgVRegs, GR, *Info.CB))
637637
return *Res;
638638
}
639639

0 commit comments

Comments
 (0)