@@ -83,6 +83,7 @@ class SPIRVAsmPrinter : public AsmPrinter {
8383 void outputExecutionMode (const Module &M);
8484 void outputAnnotations (const Module &M);
8585 void outputModuleSections ();
86+ void outputFPFastMathDefaultInfo ();
8687 bool isHidden () {
8788 return MF->getFunction ()
8889 .getFnAttribute (SPIRV_BACKEND_SERVICE_FUN_NAME)
@@ -514,11 +515,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
514515 NamedMDNode *Node = M.getNamedMetadata (" spirv.ExecutionMode" );
515516 if (Node) {
516517 for (unsigned i = 0 ; i < Node->getNumOperands (); i++) {
518+ // If SPV_KHR_float_controls2 is enabled and we find any of
519+ // FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution
520+ // modes, skip it, it'll be done somewhere else.
521+ if (ST->canUseExtension (SPIRV::Extension::SPV_KHR_float_controls2)) {
522+ const auto EM =
523+ cast<ConstantInt>(
524+ cast<ConstantAsMetadata>((Node->getOperand (i))->getOperand (1 ))
525+ ->getValue ())
526+ ->getZExtValue ();
527+ if (EM == SPIRV::ExecutionMode::FPFastMathDefault ||
528+ EM == SPIRV::ExecutionMode::ContractionOff ||
529+ EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve)
530+ continue ;
531+ }
532+
517533 MCInst Inst;
518534 Inst.setOpcode (SPIRV::OpExecutionMode);
519535 addOpsFromMDNode (cast<MDNode>(Node->getOperand (i)), Inst, MAI);
520536 outputMCInst (Inst);
521537 }
538+ outputFPFastMathDefaultInfo ();
522539 }
523540 for (auto FI = M.begin (), E = M.end (); FI != E; ++FI) {
524541 const Function &F = *FI;
@@ -572,12 +589,84 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
572589 }
573590 if (ST->isKernel () && !M.getNamedMetadata (" spirv.ExecutionMode" ) &&
574591 !M.getNamedMetadata (" opencl.enable.FP_CONTRACT" )) {
575- MCInst Inst;
576- Inst.setOpcode (SPIRV::OpExecutionMode);
577- Inst.addOperand (MCOperand::createReg (FReg));
578- unsigned EM = static_cast <unsigned >(SPIRV::ExecutionMode::ContractionOff);
579- Inst.addOperand (MCOperand::createImm (EM));
580- outputMCInst (Inst);
592+ if (ST->canUseExtension (SPIRV::Extension::SPV_KHR_float_controls2)) {
593+ // When SPV_KHR_float_controls2 is enabled, ContractionOff is
594+ // deprecated. We need to use FPFastMathDefault with the appropriate
595+ // flags instead. Since FPFastMathDefault takes a target type, we need
596+ // to emit it for each floating-point type that exists in the module
597+ // to match the effect of ContractionOff. As of now, there are 3 FP
598+ // types: fp16, fp32 and fp64.
599+
600+ // We only end up here because there is no "spirv.ExecutionMode"
601+ // metadata, so that means no FPFastMathDefault. Therefore, we only
602+ // need to make sure AllowContract is set to 0, as the rest of flags.
603+ // We still need to emit the OpExecutionMode instruction, otherwise
604+ // it's up to the client API to define the flags. Therefore, we need
605+ // to find the constant with 0 value.
606+
607+ // Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of
608+ // type int32 with 0 value to represent the FP Fast Math Mode.
609+ std::vector<const MachineInstr *> SPIRVFloatTypes;
610+ const MachineInstr *ConstZero = nullptr ;
611+ for (const MachineInstr *MI :
612+ MAI->getMSInstrs (SPIRV::MB_TypeConstVars)) {
613+ // Skip if the instruction is not OpTypeFloat or OpConstant.
614+ unsigned OpCode = MI->getOpcode ();
615+ if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantNull)
616+ continue ;
617+
618+ // Collect the SPIRV type if it's a float.
619+ if (OpCode == SPIRV::OpTypeFloat) {
620+ // Skip if the target type is not fp16, fp32, fp64.
621+ const unsigned OpTypeFloatSize = MI->getOperand (1 ).getImm ();
622+ if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 &&
623+ OpTypeFloatSize != 64 ) {
624+ continue ;
625+ }
626+ SPIRVFloatTypes.push_back (MI);
627+ } else {
628+ // Check if the constant is int32, if not skip it.
629+ const MachineRegisterInfo &MRI = MI->getMF ()->getRegInfo ();
630+ MachineInstr *TypeMI = MRI.getVRegDef (MI->getOperand (1 ).getReg ());
631+ if (!TypeMI || TypeMI->getOperand (1 ).getImm () != 32 )
632+ continue ;
633+
634+ ConstZero = MI;
635+ }
636+ }
637+
638+ // When SPV_KHR_float_controls2 is enabled, ContractionOff is
639+ // deprecated. We need to use FPFastMathDefault with the appropriate
640+ // flags instead. Since FPFastMathDefault takes a target type, we need
641+ // to emit it for each floating-point type that exists in the module
642+ // to match the effect of ContractionOff. As of now, there are 3 FP
643+ // types: fp16, fp32 and fp64.
644+ for (const MachineInstr *MI : SPIRVFloatTypes) {
645+ MCInst Inst;
646+ Inst.setOpcode (SPIRV::OpExecutionModeId);
647+ Inst.addOperand (MCOperand::createReg (FReg));
648+ unsigned EM =
649+ static_cast <unsigned >(SPIRV::ExecutionMode::FPFastMathDefault);
650+ Inst.addOperand (MCOperand::createImm (EM));
651+ const MachineFunction *MF = MI->getMF ();
652+ MCRegister TypeReg =
653+ MAI->getRegisterAlias (MF, MI->getOperand (0 ).getReg ());
654+ Inst.addOperand (MCOperand::createReg (TypeReg));
655+ assert (ConstZero && " There should be a constant zero." );
656+ MCRegister ConstReg = MAI->getRegisterAlias (
657+ ConstZero->getMF (), ConstZero->getOperand (0 ).getReg ());
658+ Inst.addOperand (MCOperand::createReg (ConstReg));
659+ outputMCInst (Inst);
660+ }
661+ } else {
662+ MCInst Inst;
663+ Inst.setOpcode (SPIRV::OpExecutionMode);
664+ Inst.addOperand (MCOperand::createReg (FReg));
665+ unsigned EM =
666+ static_cast <unsigned >(SPIRV::ExecutionMode::ContractionOff);
667+ Inst.addOperand (MCOperand::createImm (EM));
668+ outputMCInst (Inst);
669+ }
581670 }
582671 }
583672}
@@ -626,6 +715,101 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
626715 }
627716}
628717
718+ void SPIRVAsmPrinter::outputFPFastMathDefaultInfo () {
719+ // Collect the SPIRVTypes that are OpTypeFloat and the constants of type
720+ // int32, that might be used as FP Fast Math Mode.
721+ std::vector<const MachineInstr *> SPIRVFloatTypes;
722+ // Hashtable to associate immediate values with the constant holding them.
723+ std::unordered_map<int , const MachineInstr *> ConstMap;
724+ for (const MachineInstr *MI : MAI->getMSInstrs (SPIRV::MB_TypeConstVars)) {
725+ // Skip if the instruction is not OpTypeFloat or OpConstant.
726+ unsigned OpCode = MI->getOpcode ();
727+ if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantI &&
728+ OpCode != SPIRV::OpConstantNull)
729+ continue ;
730+
731+ // Collect the SPIRV type if it's a float.
732+ if (OpCode == SPIRV::OpTypeFloat) {
733+ SPIRVFloatTypes.push_back (MI);
734+ } else {
735+ // Check if the constant is int32, if not skip it.
736+ const MachineRegisterInfo &MRI = MI->getMF ()->getRegInfo ();
737+ MachineInstr *TypeMI = MRI.getVRegDef (MI->getOperand (1 ).getReg ());
738+ if (!TypeMI || TypeMI->getOpcode () != SPIRV::OpTypeInt ||
739+ TypeMI->getOperand (1 ).getImm () != 32 )
740+ continue ;
741+
742+ if (OpCode == SPIRV::OpConstantI)
743+ ConstMap[MI->getOperand (2 ).getImm ()] = MI;
744+ else
745+ ConstMap[0 ] = MI;
746+ }
747+ }
748+
749+ for (const auto &[Func, FPFastMathDefaultInfoVec] :
750+ MAI->FPFastMathDefaultInfoMap ) {
751+ if (FPFastMathDefaultInfoVec.empty ())
752+ continue ;
753+
754+ for (const MachineInstr *MI : SPIRVFloatTypes) {
755+ unsigned OpTypeFloatSize = MI->getOperand (1 ).getImm ();
756+ unsigned Index = SPIRV::FPFastMathDefaultInfoVector::
757+ computeFPFastMathDefaultInfoVecIndex (OpTypeFloatSize);
758+ assert (Index < FPFastMathDefaultInfoVec.size () &&
759+ " Index out of bounds for FPFastMathDefaultInfoVec" );
760+ const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index];
761+ assert (FPFastMathDefaultInfo.Ty &&
762+ " Expected target type for FPFastMathDefaultInfo" );
763+ assert (FPFastMathDefaultInfo.Ty ->getScalarSizeInBits () ==
764+ OpTypeFloatSize &&
765+ " Mismatched float type size" );
766+ MCInst Inst;
767+ Inst.setOpcode (SPIRV::OpExecutionModeId);
768+ MCRegister FuncReg = MAI->getFuncReg (Func);
769+ assert (FuncReg.isValid ());
770+ Inst.addOperand (MCOperand::createReg (FuncReg));
771+ Inst.addOperand (
772+ MCOperand::createImm (SPIRV::ExecutionMode::FPFastMathDefault));
773+ MCRegister TypeReg =
774+ MAI->getRegisterAlias (MI->getMF (), MI->getOperand (0 ).getReg ());
775+ Inst.addOperand (MCOperand::createReg (TypeReg));
776+ unsigned Flags = FPFastMathDefaultInfo.FastMathFlags ;
777+ if (FPFastMathDefaultInfo.ContractionOff &&
778+ (Flags & SPIRV::FPFastMathMode::AllowContract))
779+ report_fatal_error (
780+ " Conflicting FPFastMathFlags: ContractionOff and AllowContract" );
781+
782+ if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
783+ !(Flags &
784+ (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
785+ SPIRV::FPFastMathMode::NSZ))) {
786+ if (FPFastMathDefaultInfo.FPFastMathDefault )
787+ report_fatal_error (" Conflicting FPFastMathFlags: "
788+ " SignedZeroInfNanPreserve but at least one of "
789+ " NotNaN/NotInf/NSZ is enabled." );
790+ }
791+
792+ // Don't emit if none of the execution modes was used.
793+ if (Flags == SPIRV::FPFastMathMode::None &&
794+ !FPFastMathDefaultInfo.ContractionOff &&
795+ !FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
796+ !FPFastMathDefaultInfo.FPFastMathDefault )
797+ continue ;
798+
799+ // Retrieve the constant instruction for the immediate value.
800+ auto It = ConstMap.find (Flags);
801+ if (It == ConstMap.end ())
802+ report_fatal_error (" Expected constant instruction for FP Fast Math "
803+ " Mode operand of FPFastMathDefault execution mode." );
804+ const MachineInstr *ConstMI = It->second ;
805+ MCRegister ConstReg = MAI->getRegisterAlias (
806+ ConstMI->getMF (), ConstMI->getOperand (0 ).getReg ());
807+ Inst.addOperand (MCOperand::createReg (ConstReg));
808+ outputMCInst (Inst);
809+ }
810+ }
811+ }
812+
629813void SPIRVAsmPrinter::outputModuleSections () {
630814 const Module *M = MMI->getModule ();
631815 // Get the global subtarget to output module-level info.
@@ -634,15 +818,17 @@ void SPIRVAsmPrinter::outputModuleSections() {
634818 MAI = &SPIRVModuleAnalysis::MAI;
635819 assert (ST && TII && MAI && M && " Module analysis is required" );
636820 // Output instructions according to the Logical Layout of a Module:
637- // 1,2. All OpCapability instructions, then optional OpExtension instructions.
821+ // 1,2. All OpCapability instructions, then optional OpExtension
822+ // instructions.
638823 outputGlobalRequirements ();
639824 // 3. Optional OpExtInstImport instructions.
640825 outputOpExtInstImports (*M);
641826 // 4. The single required OpMemoryModel instruction.
642827 outputOpMemoryModel ();
643828 // 5. All entry point declarations, using OpEntryPoint.
644829 outputEntryPoints ();
645- // 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId.
830+ // 6. Execution-mode declarations, using OpExecutionMode or
831+ // OpExecutionModeId.
646832 outputExecutionMode (*M);
647833 // 7a. Debug: all OpString, OpSourceExtension, OpSource, and
648834 // OpSourceContinued, without forward references.
0 commit comments