Skip to content

Commit 7d6ac19

Browse files
s-perronKeenuts
authored andcommitted
[opt] Fix uses of type manager in fix storage class (KhronosGroup#5740)
This removes some uses of the type manager. One use could not be removed. Instead I had to update GenCopy to not use the type manager, and be able to copy pointers. Part of KhronosGroup#5691
1 parent 9cfa559 commit 7d6ac19

File tree

3 files changed

+108
-62
lines changed

3 files changed

+108
-62
lines changed

source/opt/fix_storage_class.cpp

+30-14
Original file line numberDiff line numberDiff line change
@@ -141,22 +141,26 @@ bool FixStorageClass::IsPointerResultType(Instruction* inst) {
141141
if (inst->type_id() == 0) {
142142
return false;
143143
}
144-
const analysis::Type* ret_type =
145-
context()->get_type_mgr()->GetType(inst->type_id());
146-
return ret_type->AsPointer() != nullptr;
144+
145+
Instruction* type_def = get_def_use_mgr()->GetDef(inst->type_id());
146+
return type_def->opcode() == spv::Op::OpTypePointer;
147147
}
148148

149149
bool FixStorageClass::IsPointerToStorageClass(Instruction* inst,
150150
spv::StorageClass storage_class) {
151-
analysis::TypeManager* type_mgr = context()->get_type_mgr();
152-
analysis::Type* pType = type_mgr->GetType(inst->type_id());
153-
const analysis::Pointer* result_type = pType->AsPointer();
151+
if (inst->type_id() == 0) {
152+
return false;
153+
}
154154

155-
if (result_type == nullptr) {
155+
Instruction* type_def = get_def_use_mgr()->GetDef(inst->type_id());
156+
if (type_def->opcode() != spv::Op::OpTypePointer) {
156157
return false;
157158
}
158159

159-
return (result_type->storage_class() == storage_class);
160+
const uint32_t kPointerTypeStorageClassIndex = 0;
161+
spv::StorageClass pointer_storage_class = static_cast<spv::StorageClass>(
162+
type_def->GetSingleWordInOperand(kPointerTypeStorageClassIndex));
163+
return pointer_storage_class == storage_class;
160164
}
161165

162166
bool FixStorageClass::ChangeResultType(Instruction* inst,
@@ -301,9 +305,11 @@ uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
301305
break;
302306
}
303307

304-
Instruction* orig_type_inst = get_def_use_mgr()->GetDef(id);
305-
assert(orig_type_inst->opcode() == spv::Op::OpTypePointer);
306-
id = orig_type_inst->GetSingleWordInOperand(1);
308+
Instruction* id_type_inst = get_def_use_mgr()->GetDef(id);
309+
assert(id_type_inst->opcode() == spv::Op::OpTypePointer);
310+
id = id_type_inst->GetSingleWordInOperand(1);
311+
spv::StorageClass input_storage_class =
312+
static_cast<spv::StorageClass>(id_type_inst->GetSingleWordInOperand(0));
307313

308314
for (uint32_t i = start_idx; i < inst->NumInOperands(); ++i) {
309315
Instruction* type_inst = get_def_use_mgr()->GetDef(id);
@@ -336,9 +342,19 @@ uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
336342
"Tried to extract from an object where it cannot be done.");
337343
}
338344

339-
return context()->get_type_mgr()->FindPointerToType(
340-
id, static_cast<spv::StorageClass>(
341-
orig_type_inst->GetSingleWordInOperand(0)));
345+
Instruction* orig_type_inst = get_def_use_mgr()->GetDef(inst->type_id());
346+
spv::StorageClass orig_storage_class =
347+
static_cast<spv::StorageClass>(orig_type_inst->GetSingleWordInOperand(0));
348+
assert(orig_type_inst->opcode() == spv::Op::OpTypePointer);
349+
if (orig_type_inst->GetSingleWordInOperand(1) == id &&
350+
input_storage_class == orig_storage_class) {
351+
// The existing type is correct. Avoid the search for the type. Note that if
352+
// there is a duplicate type, the search below could return a different type
353+
// forcing more changes to the code than necessary.
354+
return inst->type_id();
355+
}
356+
357+
return context()->get_type_mgr()->FindPointerToType(id, input_storage_class);
342358
}
343359

344360
// namespace opt

source/opt/pass.cpp

+44-48
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ uint32_t Pass::GetNullId(uint32_t type_id) {
8383

8484
uint32_t Pass::GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id,
8585
Instruction* insertion_position) {
86-
analysis::TypeManager* type_mgr = context()->get_type_mgr();
8786
analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
8887

8988
uint32_t original_type_id = object_to_copy->type_id();
@@ -95,55 +94,52 @@ uint32_t Pass::GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id,
9594
context(), insertion_position,
9695
IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse);
9796

98-
analysis::Type* original_type = type_mgr->GetType(original_type_id);
99-
analysis::Type* new_type = type_mgr->GetType(new_type_id);
100-
101-
if (const analysis::Array* original_array_type = original_type->AsArray()) {
102-
uint32_t original_element_type_id =
103-
type_mgr->GetId(original_array_type->element_type());
104-
105-
analysis::Array* new_array_type = new_type->AsArray();
106-
assert(new_array_type != nullptr && "Can't copy an array to a non-array.");
107-
uint32_t new_element_type_id =
108-
type_mgr->GetId(new_array_type->element_type());
109-
110-
std::vector<uint32_t> element_ids;
111-
const analysis::Constant* length_const =
112-
const_mgr->FindDeclaredConstant(original_array_type->LengthId());
113-
assert(length_const->AsIntConstant());
114-
uint32_t array_length = length_const->AsIntConstant()->GetU32();
115-
for (uint32_t i = 0; i < array_length; i++) {
116-
Instruction* extract = ir_builder.AddCompositeExtract(
117-
original_element_type_id, object_to_copy->result_id(), {i});
118-
element_ids.push_back(
119-
GenerateCopy(extract, new_element_type_id, insertion_position));
97+
Instruction* original_type = get_def_use_mgr()->GetDef(original_type_id);
98+
Instruction* new_type = get_def_use_mgr()->GetDef(new_type_id);
99+
assert(new_type->opcode() == original_type->opcode() &&
100+
"Can't copy an aggragate type unless the type correspond.");
101+
102+
switch (original_type->opcode()) {
103+
case spv::Op::OpTypeArray: {
104+
uint32_t original_element_type_id =
105+
original_type->GetSingleWordInOperand(0);
106+
uint32_t new_element_type_id = new_type->GetSingleWordInOperand(0);
107+
108+
std::vector<uint32_t> element_ids;
109+
uint32_t length_id = original_type->GetSingleWordInOperand(1);
110+
const analysis::Constant* length_const =
111+
const_mgr->FindDeclaredConstant(length_id);
112+
assert(length_const->AsIntConstant());
113+
uint32_t array_length = length_const->AsIntConstant()->GetU32();
114+
for (uint32_t i = 0; i < array_length; i++) {
115+
Instruction* extract = ir_builder.AddCompositeExtract(
116+
original_element_type_id, object_to_copy->result_id(), {i});
117+
element_ids.push_back(
118+
GenerateCopy(extract, new_element_type_id, insertion_position));
119+
}
120+
121+
return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
122+
->result_id();
120123
}
121-
122-
return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
123-
->result_id();
124-
} else if (const analysis::Struct* original_struct_type =
125-
original_type->AsStruct()) {
126-
analysis::Struct* new_struct_type = new_type->AsStruct();
127-
128-
const std::vector<const analysis::Type*>& original_types =
129-
original_struct_type->element_types();
130-
const std::vector<const analysis::Type*>& new_types =
131-
new_struct_type->element_types();
132-
std::vector<uint32_t> element_ids;
133-
for (uint32_t i = 0; i < original_types.size(); i++) {
134-
Instruction* extract = ir_builder.AddCompositeExtract(
135-
type_mgr->GetId(original_types[i]), object_to_copy->result_id(), {i});
136-
element_ids.push_back(GenerateCopy(extract, type_mgr->GetId(new_types[i]),
137-
insertion_position));
124+
case spv::Op::OpTypeStruct: {
125+
std::vector<uint32_t> element_ids;
126+
for (uint32_t i = 0; i < original_type->NumInOperands(); i++) {
127+
uint32_t orig_member_type_id = original_type->GetSingleWordInOperand(i);
128+
uint32_t new_member_type_id = new_type->GetSingleWordInOperand(i);
129+
Instruction* extract = ir_builder.AddCompositeExtract(
130+
orig_member_type_id, object_to_copy->result_id(), {i});
131+
element_ids.push_back(
132+
GenerateCopy(extract, new_member_type_id, insertion_position));
133+
}
134+
return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
135+
->result_id();
138136
}
139-
return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
140-
->result_id();
141-
} else {
142-
// If we do not have an aggregate type, then we have a problem. Either we
143-
// found multiple instances of the same type, or we are copying to an
144-
// incompatible type. Either way the code is illegal.
145-
assert(false &&
146-
"Don't know how to copy this type. Code is likely illegal.");
137+
default:
138+
// If we do not have an aggregate type, then we have a problem. Either we
139+
// found multiple instances of the same type, or we are copying to an
140+
// incompatible type. Either way the code is illegal.
141+
assert(false &&
142+
"Don't know how to copy this type. Code is likely illegal.");
147143
}
148144
return 0;
149145
}

test/opt/fix_storage_class_test.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,40 @@ OpFunctionEnd
953953
SinglePassRunAndCheck<FixStorageClass>(text, text, false, false);
954954
}
955955

956+
// Tests that the pass is not confused when there are multiple definitions
957+
// of a pointer type to the same type with the same storage class.
958+
TEST_F(FixStorageClassTest, DuplicatePointerType) {
959+
const std::string text = R"(OpCapability Shader
960+
OpMemoryModel Logical GLSL450
961+
OpEntryPoint GLCompute %1 "main"
962+
OpExecutionMode %1 LocalSize 64 1 1
963+
OpSource HLSL 600
964+
%uint = OpTypeInt 32 0
965+
%uint_0 = OpConstant %uint 0
966+
%uint_3 = OpConstant %uint 3
967+
%_arr_uint_uint_3 = OpTypeArray %uint %uint_3
968+
%void = OpTypeVoid
969+
%7 = OpTypeFunction %void
970+
%_struct_8 = OpTypeStruct %_arr_uint_uint_3
971+
%_ptr_Function__struct_8 = OpTypePointer Function %_struct_8
972+
%_ptr_Function_uint = OpTypePointer Function %uint
973+
%_ptr_Function__arr_uint_uint_3 = OpTypePointer Function %_arr_uint_uint_3
974+
%_ptr_Function_uint_0 = OpTypePointer Function %uint
975+
%_ptr_Function__ptr_Function_uint_0 = OpTypePointer Function %_ptr_Function_uint_0
976+
%1 = OpFunction %void None %7
977+
%14 = OpLabel
978+
%15 = OpVariable %_ptr_Function__ptr_Function_uint_0 Function
979+
%16 = OpVariable %_ptr_Function__struct_8 Function
980+
%17 = OpAccessChain %_ptr_Function__arr_uint_uint_3 %16 %uint_0
981+
%18 = OpAccessChain %_ptr_Function_uint_0 %17 %uint_0
982+
OpStore %15 %18
983+
OpReturn
984+
OpFunctionEnd
985+
)";
986+
987+
SinglePassRunAndCheck<FixStorageClass>(text, text, false);
988+
}
989+
956990
} // namespace
957991
} // namespace opt
958992
} // namespace spvtools

0 commit comments

Comments
 (0)