Skip to content

Commit 298055b

Browse files
authored
opt: add StorageImageWriteWithoutFormat to trimm pass (#5860)
* opt: add StorageImageWriteWithoutFormat to trimm pass --------- Signed-off-by: Nathan Gauër <[email protected]>
1 parent 895bb9f commit 298055b

File tree

3 files changed

+114
-1
lines changed

3 files changed

+114
-1
lines changed

source/opt/trim_capabilities_pass.cpp

+25-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ constexpr uint32_t kOpTypeImageMSIndex = kOpTypeImageArrayedIndex + 1;
4949
constexpr uint32_t kOpTypeImageSampledIndex = kOpTypeImageMSIndex + 1;
5050
constexpr uint32_t kOpTypeImageFormatIndex = kOpTypeImageSampledIndex + 1;
5151
constexpr uint32_t kOpImageReadImageIndex = 0;
52+
constexpr uint32_t kOpImageWriteImageIndex = 0;
5253
constexpr uint32_t kOpImageSparseReadImageIndex = 0;
5354
constexpr uint32_t kOpExtInstSetInIndex = 0;
5455
constexpr uint32_t kOpExtInstInstructionInIndex = 1;
@@ -338,6 +339,8 @@ Handler_OpImageRead_StorageImageReadWithoutFormat(
338339
const uint32_t dim = type->GetSingleWordInOperand(kOpTypeImageDimIndex);
339340
const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
340341

342+
// If the Image Format is Unknown and Dim is SubpassData,
343+
// StorageImageReadWithoutFormat is required.
341344
const bool is_unknown = spv::ImageFormat(format) == spv::ImageFormat::Unknown;
342345
const bool requires_capability_for_unknown =
343346
spv::Dim(dim) != spv::Dim::SubpassData;
@@ -346,6 +349,26 @@ Handler_OpImageRead_StorageImageReadWithoutFormat(
346349
: std::nullopt;
347350
}
348351

352+
static std::optional<spv::Capability>
353+
Handler_OpImageWrite_StorageImageWriteWithoutFormat(
354+
const Instruction* instruction) {
355+
assert(instruction->opcode() == spv::Op::OpImageWrite &&
356+
"This handler only support OpImageWrite opcodes.");
357+
const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
358+
359+
const uint32_t image_index =
360+
instruction->GetSingleWordInOperand(kOpImageWriteImageIndex);
361+
const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
362+
363+
// If the Image Format is Unknown, StorageImageWriteWithoutFormat is required.
364+
const Instruction* type = def_use_mgr->GetDef(type_index);
365+
const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
366+
const bool is_unknown = spv::ImageFormat(format) == spv::ImageFormat::Unknown;
367+
return is_unknown
368+
? std::optional(spv::Capability::StorageImageWriteWithoutFormat)
369+
: std::nullopt;
370+
}
371+
349372
static std::optional<spv::Capability>
350373
Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
351374
const Instruction* instruction) {
@@ -365,9 +388,10 @@ Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
365388
}
366389

367390
// Opcode of interest to determine capabilities requirements.
368-
constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 12> kOpcodeHandlers{{
391+
constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 13> kOpcodeHandlers{{
369392
// clang-format off
370393
{spv::Op::OpImageRead, Handler_OpImageRead_StorageImageReadWithoutFormat},
394+
{spv::Op::OpImageWrite, Handler_OpImageWrite_StorageImageWriteWithoutFormat},
371395
{spv::Op::OpImageSparseRead, Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
372396
{spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float16 },
373397
{spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float64 },

source/opt/trim_capabilities_pass.h

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ class TrimCapabilitiesPass : public Pass {
100100
spv::Capability::Shader,
101101
spv::Capability::ShaderClockKHR,
102102
spv::Capability::StorageImageReadWithoutFormat,
103+
spv::Capability::StorageImageWriteWithoutFormat,
103104
spv::Capability::StorageInputOutput16,
104105
spv::Capability::StoragePushConstant16,
105106
spv::Capability::StorageUniform16,

test/opt/trim_capabilities_pass_test.cpp

+88
Original file line numberDiff line numberDiff line change
@@ -2366,6 +2366,94 @@ TEST_F(TrimCapabilitiesPassTest,
23662366
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange);
23672367
}
23682368

2369+
TEST_F(TrimCapabilitiesPassTest,
2370+
StorageImageWriteWithoutFormat_RemainsWhenRequiredWithWrite) {
2371+
const std::string kTest = R"(
2372+
OpCapability StorageImageWriteWithoutFormat
2373+
; CHECK: OpCapability StorageImageWriteWithoutFormat
2374+
OpCapability Shader
2375+
OpCapability StorageImageExtendedFormats
2376+
OpMemoryModel Logical GLSL450
2377+
OpEntryPoint GLCompute %main "main" %id %img
2378+
OpExecutionMode %main LocalSize 8 8 8
2379+
OpSource HLSL 670
2380+
OpName %type_image "type.3d.image"
2381+
OpName %img "img"
2382+
OpName %main "main"
2383+
OpDecorate %id BuiltIn GlobalInvocationId
2384+
OpDecorate %img DescriptorSet 0
2385+
OpDecorate %img Binding 0
2386+
%float = OpTypeFloat 32
2387+
%float_4 = OpConstant %float 4
2388+
%float_5 = OpConstant %float 5
2389+
%v2float = OpTypeVector %float 2
2390+
%9 = OpConstantComposite %v2float %float_4 %float_5
2391+
%type_image = OpTypeImage %float 3D 2 0 0 2 Unknown
2392+
%ptr_img = OpTypePointer UniformConstant %type_image
2393+
%uint = OpTypeInt 32 0
2394+
%v3uint = OpTypeVector %uint 3
2395+
%ptr_input = OpTypePointer Input %v3uint
2396+
%void = OpTypeVoid
2397+
%15 = OpTypeFunction %void
2398+
%img = OpVariable %ptr_img UniformConstant
2399+
%id = OpVariable %ptr_input Input
2400+
%main = OpFunction %void None %15
2401+
%16 = OpLabel
2402+
%17 = OpLoad %v3uint %id
2403+
%18 = OpLoad %type_image %img
2404+
OpImageWrite %18 %17 %9 None
2405+
OpReturn
2406+
OpFunctionEnd
2407+
)";
2408+
const auto result =
2409+
SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
2410+
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
2411+
}
2412+
2413+
TEST_F(TrimCapabilitiesPassTest,
2414+
StorageImageWriteWithoutFormat_RemovedWithWriteOnKnownFormat) {
2415+
const std::string kTest = R"(
2416+
OpCapability StorageImageWriteWithoutFormat
2417+
; CHECK-NOT: OpCapability StorageImageWriteWithoutFormat
2418+
OpCapability Shader
2419+
OpCapability StorageImageExtendedFormats
2420+
OpMemoryModel Logical GLSL450
2421+
OpEntryPoint GLCompute %main "main" %id %img
2422+
OpExecutionMode %main LocalSize 8 8 8
2423+
OpSource HLSL 670
2424+
OpName %type_image "type.3d.image"
2425+
OpName %img "img"
2426+
OpName %main "main"
2427+
OpDecorate %id BuiltIn GlobalInvocationId
2428+
OpDecorate %img DescriptorSet 0
2429+
OpDecorate %img Binding 0
2430+
%float = OpTypeFloat 32
2431+
%float_4 = OpConstant %float 4
2432+
%float_5 = OpConstant %float 5
2433+
%v2float = OpTypeVector %float 2
2434+
%9 = OpConstantComposite %v2float %float_4 %float_5
2435+
%type_image = OpTypeImage %float 3D 2 0 0 2 Rg32f
2436+
%ptr_img = OpTypePointer UniformConstant %type_image
2437+
%uint = OpTypeInt 32 0
2438+
%v3uint = OpTypeVector %uint 3
2439+
%ptr_input = OpTypePointer Input %v3uint
2440+
%void = OpTypeVoid
2441+
%15 = OpTypeFunction %void
2442+
%img = OpVariable %ptr_img UniformConstant
2443+
%id = OpVariable %ptr_input Input
2444+
%main = OpFunction %void None %15
2445+
%16 = OpLabel
2446+
%17 = OpLoad %v3uint %id
2447+
%18 = OpLoad %type_image %img
2448+
OpImageWrite %18 %17 %9 None
2449+
OpReturn
2450+
OpFunctionEnd
2451+
)";
2452+
const auto result =
2453+
SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
2454+
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange);
2455+
}
2456+
23692457
TEST_F(TrimCapabilitiesPassTest, PhysicalStorageBuffer_RemovedWhenUnused) {
23702458
const std::string kTest = R"(
23712459
OpCapability PhysicalStorageBufferAddresses

0 commit comments

Comments
 (0)