Skip to content

Commit f91e0bf

Browse files
[SPIRV] Add support for the SPIR-V extension SPV_KHR_bfloat16 (llvm#155645)
This PR adds support for the SPIR-V extension `SPV_KHR_bfloat16`. The extension extends the `OpTypeFloat` instruction so that bfloat16 types can be used with cooperative matrices and dot products. TODO: Per the `SPV_KHR_bfloat16` extension, only a limited set of instructions may use the bfloat16 type. For example, arithmetic instructions such as `FAdd` or `FMul` cannot operate on `bfloat16` values. A follow-up patch should therefore either emit an error or fall back to FP32 arithmetic wherever bfloat16 is not permitted. Reference Specification: https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_bfloat16.asciidoc
1 parent ac69f9d commit f91e0bf

File tree

8 files changed

+112
-11
lines changed

8 files changed

+112
-11
lines changed

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2765,7 +2765,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
27652765
}
27662766

27672767
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
2768-
if (containsBF16Type(U))
2768+
if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
27692769
return false;
27702770

27712771
const CallInst &CI = cast<CallInst>(U);

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
147147
{"SPV_KHR_float_controls2",
148148
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
149149
{"SPV_INTEL_tensor_float32_conversion",
150-
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
150+
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
151+
{"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
151152

152153
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
153154
StringRef ArgValue,

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,13 @@ static void AddDotProductRequirements(const MachineInstr &MI,
12221222
}
12231223
}
12241224

1225+
/// Returns true if \p TypeDef is a non-null OpTypeFloat instruction that has
/// exactly three operands, a width operand (operand 1) equal to 16, and an
/// FP-encoding operand (operand 2) equal to SPIRV::FPEncoding::BFloat16KHR.
///
/// The checks are chained with short-circuiting &&, so operand 2 is only read
/// after the operand count has been confirmed to be 3; a two-operand
/// OpTypeFloat (no encoding operand) therefore yields false.
static bool isBFloat16Type(const SPIRVType *TypeDef) {
1226+
return TypeDef && TypeDef->getNumOperands() == 3 &&
1227+
TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1228+
TypeDef->getOperand(1).getImm() == 16 &&
1229+
TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1230+
}
1231+
12251232
void addInstrRequirements(const MachineInstr &MI,
12261233
SPIRV::RequirementHandler &Reqs,
12271234
const SPIRVSubtarget &ST) {
@@ -1261,12 +1268,29 @@ void addInstrRequirements(const MachineInstr &MI,
12611268
Reqs.addCapability(SPIRV::Capability::Int8);
12621269
break;
12631270
}
1271+
case SPIRV::OpDot: {
1272+
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1273+
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1274+
if (isBFloat16Type(TypeDef))
1275+
Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
1276+
break;
1277+
}
12641278
case SPIRV::OpTypeFloat: {
12651279
unsigned BitWidth = MI.getOperand(1).getImm();
12661280
if (BitWidth == 64)
12671281
Reqs.addCapability(SPIRV::Capability::Float64);
1268-
else if (BitWidth == 16)
1269-
Reqs.addCapability(SPIRV::Capability::Float16);
1282+
else if (BitWidth == 16) {
1283+
if (isBFloat16Type(&MI)) {
1284+
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
1285+
report_fatal_error("OpTypeFloat type with bfloat requires the "
1286+
"following SPIR-V extension: SPV_KHR_bfloat16",
1287+
false);
1288+
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
1289+
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
1290+
} else {
1291+
Reqs.addCapability(SPIRV::Capability::Float16);
1292+
}
1293+
}
12701294
break;
12711295
}
12721296
case SPIRV::OpTypeVector: {
@@ -1286,8 +1310,9 @@ void addInstrRequirements(const MachineInstr &MI,
12861310
assert(MI.getOperand(2).isReg());
12871311
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
12881312
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
1289-
if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1290-
TypeDef->getOperand(1).getImm() == 16)
1313+
if ((TypeDef->getNumOperands() == 2) &&
1314+
(TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
1315+
(TypeDef->getOperand(1).getImm() == 16))
12911316
Reqs.addCapability(SPIRV::Capability::Float16Buffer);
12921317
break;
12931318
}
@@ -1593,15 +1618,20 @@ void addInstrRequirements(const MachineInstr &MI,
15931618
Reqs.addCapability(SPIRV::Capability::AsmINTEL);
15941619
}
15951620
break;
1596-
case SPIRV::OpTypeCooperativeMatrixKHR:
1621+
case SPIRV::OpTypeCooperativeMatrixKHR: {
15971622
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
15981623
report_fatal_error(
15991624
"OpTypeCooperativeMatrixKHR type requires the "
16001625
"following SPIR-V extension: SPV_KHR_cooperative_matrix",
16011626
false);
16021627
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
16031628
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1629+
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1630+
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1631+
if (isBFloat16Type(TypeDef))
1632+
Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
16041633
break;
1634+
}
16051635
case SPIRV::OpArithmeticFenceEXT:
16061636
if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
16071637
report_fatal_error("OpArithmeticFenceEXT requires the "

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
383383
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
384384
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
385385
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
386+
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;
386387

387388
//===----------------------------------------------------------------------===//
388389
// Multiclass used to define Capabilities enum values and at the same time
@@ -595,6 +596,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
595596
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
596597
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
597598
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
599+
defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
600+
defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
601+
defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
598602

599603
//===----------------------------------------------------------------------===//
600604
// Multiclass used to define SourceLanguage enum values and at the same time
@@ -2021,4 +2025,4 @@ multiclass FPEncodingOperand<bits<32> value, list<Extension> reqExtensions>{
20212025
reqExtensions, [], []>;
20222026
}
20232027

2024-
defm BFloat16KHR : FPEncodingOperand<0, []>;
2028+
defm BFloat16KHR : FPEncodingOperand<0, [SPV_KHR_bfloat16]>;

llvm/test/CodeGen/SPIRV/basic_float_types.ll

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
2-
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
3-
; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
1+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
2+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
3+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
44

55
define void @main() {
66
entry:
77

88
; CHECK-DAG: OpCapability Float16
99
; CHECK-DAG: OpCapability Float64
10+
; CHECK-DAG: OpCapability BFloat16TypeKHR
1011

1112
; CHECK-DAG: %[[#half:]] = OpTypeFloat 16{{$}}
1213
; CHECK-DAG: %[[#bfloat:]] = OpTypeFloat 16 0{{$}}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
3+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
4+
5+
; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16
6+
7+
; CHECK-DAG: OpCapability BFloat16TypeKHR
8+
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
9+
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
10+
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
11+
12+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
13+
target triple = "spir64-unknown-unknown"
14+
15+
define spir_kernel void @test() {
16+
entry:
17+
%addr1 = alloca bfloat
18+
%addr2 = alloca <2 x bfloat>
19+
%data1 = load bfloat, ptr %addr1
20+
%data2 = load <2 x bfloat>, ptr %addr2
21+
ret void
22+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: OpCapability BFloat16TypeKHR
5+
; CHECK-DAG: OpCapability CooperativeMatrixKHR
6+
; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
7+
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
8+
; CHECK-DAG: OpExtension "SPV_KHR_cooperative_matrix"
9+
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
10+
; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]] %[[#]] %[[#]] %[[#]] %[[#]]
11+
; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]
12+
13+
define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {
14+
entry:
15+
%addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
16+
%res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
17+
%m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat 1.0)
18+
store target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) %m1, ptr %addr1, align 4
19+
ret void
20+
}
21+
22+
declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: OpCapability BFloat16TypeKHR
5+
; CHECK-DAG: OpCapability BFloat16DotProductKHR
6+
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
7+
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
8+
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
9+
; CHECK: OpDot
10+
11+
declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)
12+
13+
define spir_kernel void @test() {
14+
entry:
15+
%addrA = alloca <2 x bfloat>
16+
%addrB = alloca <2 x bfloat>
17+
%dataA = load <2 x bfloat>, ptr %addrA
18+
%dataB = load <2 x bfloat>, ptr %addrB
19+
%call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
20+
ret void
21+
}

0 commit comments

Comments
 (0)