Skip to content

Commit 690ec7c

Browse files
authored
[SPIR-V] Fix mesh payload global variable for VK_EXT_mesh_shaders (#6526)
The existing logic from `VK_NV_mesh_shader` was incorrectly adapted for the `VK_EXT_mesh_shader` implementation when it comes to the handling of payloads as in/out variables. Because `TaskPayloadWorkgroupEXT` must be applied to a single global `OpVariable` for each task/mesh shader, the struct should not be flattened. Further, Location assignment is not necessary for these input and output variables, so the usual reason for flattening structs does not apply. This change now removes the inner struct member global variables and ensures the parent payload is decorated with `TaskPayloadWorkgroupEXT`. Note that for amplification/task shaders, the payload variable is created with the `groupshared` decl, and then its storage class needs to be updated when that variable is used as a parameter to the `DispatchMesh` call, as described in: https://docs.vulkan.org/spec/latest/proposals/proposals/VK_EXT_mesh_shader.html#_hlsl_changes Tested with new validation checks from: KhronosGroup/SPIRV-Tools#5640 Fixes #5981
1 parent 3546bde commit 690ec7c

File tree

3 files changed

+57
-35
lines changed

3 files changed

+57
-35
lines changed

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

+52-23
Original file line numberDiff line numberDiff line change
@@ -3582,32 +3582,61 @@ bool DeclResultIdMapper::createPayloadStageVars(
35823582
}
35833583

35843584
const auto loc = decl->getLocation();
3585-
if (!type->isStructureType()) {
3586-
StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
3587-
getLocationAndComponentCount(astContext, type));
3588-
const auto name = namePrefix.str() + "." + decl->getNameAsString();
3589-
SpirvVariable *varInstr = spvBuilder.addStageIOVar(
3590-
type, sc, name, /*isPrecise=*/false, /*isNointerp=*/false, loc);
3591-
3592-
if (!varInstr)
3593-
return false;
35943585

3595-
// Even though these as user defined IO stage variables, set them as SPIR-V
3596-
// builtins in order to bypass any semantic string checks and location
3597-
// assignment.
3598-
stageVar.setIsSpirvBuiltin();
3599-
stageVar.setSpirvInstr(varInstr);
3600-
if (stageVar.getStorageClass() == spv::StorageClass::Input ||
3601-
stageVar.getStorageClass() == spv::StorageClass::Output) {
3602-
stageVar.setEntryPoint(entryFunction);
3586+
// Most struct type stage vars must be flattened, but for EXT_mesh_shaders the
3587+
// mesh payload struct should be decorated with TaskPayloadWorkgroupEXT and
3588+
// used directly as the OpEntryPoint variable.
3589+
if (!type->isStructureType() ||
3590+
featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
3591+
3592+
SpirvVariable *varInstr = nullptr;
3593+
3594+
// Check whether a mesh payload module variable has already been added, as
3595+
// is the case for the groupshared payload variable parameter of
3596+
// DispatchMesh. In this case, change the storage class from Workgroup to
3597+
// TaskPayloadWorkgroupEXT.
3598+
if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
3599+
for (SpirvVariable *moduleVar : spvBuilder.getModule()->getVariables()) {
3600+
if (moduleVar->getAstResultType() == type) {
3601+
moduleVar->setStorageClass(
3602+
spv::StorageClass::TaskPayloadWorkgroupEXT);
3603+
varInstr = moduleVar;
3604+
}
3605+
}
36033606
}
3604-
stageVars.push_back(stageVar);
36053607

3606-
if (!featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
3607-
// Decorate with PerTaskNV for mesh/amplification shader payload
3608-
// variables.
3609-
spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset,
3610-
varInstr->getSourceLocation());
3608+
// If necessary, create new stage variable for mesh payload.
3609+
if (!varInstr) {
3610+
LocationAndComponent locationAndComponentCount =
3611+
type->isStructureType()
3612+
? LocationAndComponent({0, 0, false})
3613+
: getLocationAndComponentCount(astContext, type);
3614+
StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr,
3615+
type, locationAndComponentCount);
3616+
const auto name = namePrefix.str() + "." + decl->getNameAsString();
3617+
varInstr = spvBuilder.addStageIOVar(type, sc, name, /*isPrecise=*/false,
3618+
/*isNointerp=*/false, loc);
3619+
3620+
if (!varInstr)
3621+
return false;
3622+
3623+
// Even though these as user defined IO stage variables, set them as
3624+
// SPIR-V builtins in order to bypass any semantic string checks and
3625+
// location assignment.
3626+
stageVar.setIsSpirvBuiltin();
3627+
stageVar.setSpirvInstr(varInstr);
3628+
if (stageVar.getStorageClass() == spv::StorageClass::Input ||
3629+
stageVar.getStorageClass() == spv::StorageClass::Output) {
3630+
stageVar.setEntryPoint(entryFunction);
3631+
}
3632+
stageVars.push_back(stageVar);
3633+
3634+
if (!featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
3635+
// Decorate with PerTaskNV for mesh/amplification shader payload
3636+
// variables.
3637+
spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset,
3638+
varInstr->getSourceLocation());
3639+
}
36113640
}
36123641

36133642
if (asInput) {

tools/clang/test/CodeGenSPIRV/meshshading.ext.amplification.hlsl

+3-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: %dxc -T as_6_5 -fspv-target-env=vulkan1.1spirv1.4 -E main -fcgl %s -spirv | FileCheck %s
22
// CHECK: OpCapability MeshShadingEXT
33
// CHECK: OpExtension "SPV_EXT_mesh_shader"
4-
// CHECK: OpEntryPoint TaskEXT %main "main" [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %out_var_dummy %out_var_pos
4+
// CHECK: OpEntryPoint TaskEXT %main "main" [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %pld
55
// CHECK: OpExecutionMode %main LocalSize 128 1 1
66

77
// CHECK: OpDecorate [[drawid]] BuiltIn DrawIndex
@@ -11,14 +11,12 @@
1111
// CHECK: OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex
1212

1313

14-
// CHECK: %pld = OpVariable %_ptr_Workgroup_MeshPayload Workgroup
14+
// CHECK: %pld = OpVariable %_ptr_TaskPayloadWorkgroupEXT_MeshPayload TaskPayloadWorkgroupEXT
1515
// CHECK: [[drawid]] = OpVariable %_ptr_Input_int Input
1616
// CHECK: %gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input
1717
// CHECK: %gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input
1818
// CHECK: %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
1919
// CHECK: %gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
20-
// CHECK: %out_var_dummy = OpVariable %_ptr_TaskPayloadWorkgroupEXT__arr_float_uint_10 TaskPayloadWorkgroupEXT
21-
// CHECK: %out_var_pos = OpVariable %_ptr_TaskPayloadWorkgroupEXT_v4float TaskPayloadWorkgroupEXT
2220
struct MeshPayload {
2321
float dummy[10];
2422
float4 pos;
@@ -47,16 +45,12 @@ void main(
4745
// CHECK: %tid = OpFunctionParameter %_ptr_Function_uint
4846
// CHECK: %tig = OpFunctionParameter %_ptr_Function_uint
4947
//
50-
// CHECK: [[a:%[0-9]+]] = OpAccessChain %_ptr_Workgroup_v4float %pld %int_1
48+
// CHECK: [[a:%[0-9]+]] = OpAccessChain %_ptr_TaskPayloadWorkgroupEXT_v4float %pld %int_1
5149
// CHECK: OpStore [[a]] {{%[0-9]+}}
5250
pld.pos = float4(gtid.x, gid.y, tid, tig);
5351

5452
// CHECK: OpControlBarrier %uint_2 %uint_2 %uint_264
5553
// CHECK: [[e:%[0-9]+]] = OpLoad %MeshPayload %pld
56-
// CHECK: [[f:%[0-9]+]] = OpCompositeExtract %_arr_float_uint_10 [[e]] 0
57-
// CHECK: OpStore %out_var_dummy [[f]]
58-
// CHECK: [[g:%[0-9]+]] = OpCompositeExtract %v4float [[e]] 1
59-
// CHECK: OpStore %out_var_pos [[g]]
6054
// CHECK: [[h:%[0-9]+]] = OpLoad %int %drawId
6155
// CHECK: [[i:%[0-9]+]] = OpBitcast %uint [[h]]
6256
// CHECK: [[j:%[0-9]+]] = OpLoad %int %drawId

tools/clang/test/CodeGenSPIRV/meshshading.ext.triangle.mesh.hlsl

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: %dxc -T ms_6_5 -fspv-target-env=universal1.5 -E main -fcgl %s -spirv | FileCheck %s
22
// CHECK: OpCapability MeshShadingEXT
33
// CHECK: OpExtension "SPV_EXT_mesh_shader"
4-
// CHECK: OpEntryPoint MeshEXT %main "main" %gl_ClipDistance %gl_CullDistance %in_var_dummy %in_var_pos [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %gl_Position %gl_PointSize %out_var_USER %out_var_USER_ARR %out_var_USER_MAT [[primindices:%[0-9]+]] %gl_PrimitiveID %gl_Layer %gl_ViewportIndex [[cullprim:%[0-9]+]] [[primshadingrate:%[0-9]+]] %out_var_PRIM_USER %out_var_PRIM_USER_ARR
4+
// CHECK: OpEntryPoint MeshEXT %main "main" %gl_ClipDistance %gl_CullDistance %in_var_pld [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %gl_Position %gl_PointSize %out_var_USER %out_var_USER_ARR %out_var_USER_MAT [[primindices:%[0-9]+]] %gl_PrimitiveID %gl_Layer %gl_ViewportIndex [[cullprim:%[0-9]+]] [[primshadingrate:%[0-9]+]] %out_var_PRIM_USER %out_var_PRIM_USER_ARR
55
// CHECK: OpExecutionMode %main LocalSize 128 1 1
66
// CHECK: OpExecutionMode %main OutputTrianglesNV
77
// CHECK: OpExecutionMode %main OutputVertices 64
@@ -37,8 +37,7 @@
3737

3838
// CHECK: %gl_ClipDistance = OpVariable %_ptr_Output__arr__arr_float_uint_5_uint_64 Output
3939
// CHECK: %gl_CullDistance = OpVariable %_ptr_Output__arr__arr_float_uint_3_uint_64 Output
40-
// CHECK: %in_var_dummy = OpVariable %_ptr_TaskPayloadWorkgroupEXT__arr_float_uint_10 TaskPayloadWorkgroupEXT
41-
// CHECK: %in_var_pos = OpVariable %_ptr_TaskPayloadWorkgroupEXT_v4float TaskPayloadWorkgroupEXT
40+
// CHECK: %in_var_pld = OpVariable %_ptr_TaskPayloadWorkgroupEXT_MeshPayload TaskPayloadWorkgroupEXT
4241
// CHECK: %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
4342
// CHECK: %gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
4443
// CHECK: %gl_Position = OpVariable %_ptr_Output__arr_v4float_uint_64 Output

0 commit comments

Comments
 (0)