Skip to content

Commit

Permalink
Fix rebuilding types with circular references (#5623). (#5637)
Browse files Browse the repository at this point in the history
This fixes the problem reported in #5623 using the observation that if
we are re-building a type that already exists in the type pool, we
should just return that type.

This makes type rebuilding more efficient, and it also prevents the
type builder from getting itself into infinite recursion (as reported in
this issue).

In fixing this, I found a couple of other bugs in the type builder:

- When rebuilding an Array type, we were not re-building the element
  type. This caused stale type references in the rebuilt type.

- This bug had not been caught by the test, because the test itself had
  a bug in it: the test was rebuilding types on top of the same ID (the
  ID counter was never incremented).

Initially, the bug in the test caused a failure with the new logic in
the builder because we now return types from the pool directly, which
causes a failure when two incompatible types are registered under the
same ID.

Fixing that issue in the test exposed another bug in the rebuilder: we
were not re-building the element type for Array types. This was causing
a stale type reference inside Array types which was later caught by the
type removal logic in the test.
  • Loading branch information
dnovillo authored Apr 9, 2024
1 parent ade1f7c commit 3983d15
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 27 deletions.
64 changes: 40 additions & 24 deletions source/opt/type_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,13 +517,24 @@ void TypeManager::CreateDecoration(uint32_t target,
context()->get_def_use_mgr()->AnalyzeInstUse(inst);
}

Type* TypeManager::RebuildType(const Type& type) {
Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
assert(type_id != 0);

// The comparison and hash on the type pool will avoid inserting the rebuilt
// type if an equivalent type already exists. The rebuilt type will be deleted
// when it goes out of scope at the end of the function in that case. Repeated
// insertions of the same Type will, at most, keep one corresponding object in
// the type pool.
std::unique_ptr<Type> rebuilt_ty;

// If |type_id| is already present in the type pool, return the existing type.
// This saves extra work in the type builder and prevents running into
// circular issues (https://github.com/KhronosGroup/SPIRV-Tools/issues/5623).
Type* pool_ty = GetType(type_id);
if (pool_ty != nullptr) {
return pool_ty;
}

switch (type.kind()) {
#define DefineNoSubtypeCase(kind) \
case Type::k##kind: \
Expand All @@ -550,51 +561,54 @@ Type* TypeManager::RebuildType(const Type& type) {
case Type::kVector: {
const Vector* vec_ty = type.AsVector();
const Type* ele_ty = vec_ty->element_type();
rebuilt_ty =
MakeUnique<Vector>(RebuildType(*ele_ty), vec_ty->element_count());
rebuilt_ty = MakeUnique<Vector>(RebuildType(GetId(ele_ty), *ele_ty),
vec_ty->element_count());
break;
}
case Type::kMatrix: {
const Matrix* mat_ty = type.AsMatrix();
const Type* ele_ty = mat_ty->element_type();
rebuilt_ty =
MakeUnique<Matrix>(RebuildType(*ele_ty), mat_ty->element_count());
rebuilt_ty = MakeUnique<Matrix>(RebuildType(GetId(ele_ty), *ele_ty),
mat_ty->element_count());
break;
}
case Type::kImage: {
const Image* image_ty = type.AsImage();
const Type* ele_ty = image_ty->sampled_type();
rebuilt_ty =
MakeUnique<Image>(RebuildType(*ele_ty), image_ty->dim(),
image_ty->depth(), image_ty->is_arrayed(),
image_ty->is_multisampled(), image_ty->sampled(),
image_ty->format(), image_ty->access_qualifier());
rebuilt_ty = MakeUnique<Image>(
RebuildType(GetId(ele_ty), *ele_ty), image_ty->dim(),
image_ty->depth(), image_ty->is_arrayed(),
image_ty->is_multisampled(), image_ty->sampled(), image_ty->format(),
image_ty->access_qualifier());
break;
}
case Type::kSampledImage: {
const SampledImage* image_ty = type.AsSampledImage();
const Type* ele_ty = image_ty->image_type();
rebuilt_ty = MakeUnique<SampledImage>(RebuildType(*ele_ty));
rebuilt_ty =
MakeUnique<SampledImage>(RebuildType(GetId(ele_ty), *ele_ty));
break;
}
case Type::kArray: {
const Array* array_ty = type.AsArray();
rebuilt_ty =
MakeUnique<Array>(array_ty->element_type(), array_ty->length_info());
const Type* ele_ty = array_ty->element_type();
rebuilt_ty = MakeUnique<Array>(RebuildType(GetId(ele_ty), *ele_ty),
array_ty->length_info());
break;
}
case Type::kRuntimeArray: {
const RuntimeArray* array_ty = type.AsRuntimeArray();
const Type* ele_ty = array_ty->element_type();
rebuilt_ty = MakeUnique<RuntimeArray>(RebuildType(*ele_ty));
rebuilt_ty =
MakeUnique<RuntimeArray>(RebuildType(GetId(ele_ty), *ele_ty));
break;
}
case Type::kStruct: {
const Struct* struct_ty = type.AsStruct();
std::vector<const Type*> subtypes;
subtypes.reserve(struct_ty->element_types().size());
for (const auto* ele_ty : struct_ty->element_types()) {
subtypes.push_back(RebuildType(*ele_ty));
subtypes.push_back(RebuildType(GetId(ele_ty), *ele_ty));
}
rebuilt_ty = MakeUnique<Struct>(subtypes);
Struct* rebuilt_struct = rebuilt_ty->AsStruct();
Expand All @@ -611,7 +625,7 @@ Type* TypeManager::RebuildType(const Type& type) {
case Type::kPointer: {
const Pointer* pointer_ty = type.AsPointer();
const Type* ele_ty = pointer_ty->pointee_type();
rebuilt_ty = MakeUnique<Pointer>(RebuildType(*ele_ty),
rebuilt_ty = MakeUnique<Pointer>(RebuildType(GetId(ele_ty), *ele_ty),
pointer_ty->storage_class());
break;
}
Expand All @@ -621,9 +635,10 @@ Type* TypeManager::RebuildType(const Type& type) {
std::vector<const Type*> param_types;
param_types.reserve(function_ty->param_types().size());
for (const auto* param_ty : function_ty->param_types()) {
param_types.push_back(RebuildType(*param_ty));
param_types.push_back(RebuildType(GetId(param_ty), *param_ty));
}
rebuilt_ty = MakeUnique<Function>(RebuildType(*ret_ty), param_types);
rebuilt_ty = MakeUnique<Function>(RebuildType(GetId(ret_ty), *ret_ty),
param_types);
break;
}
case Type::kForwardPointer: {
Expand All @@ -633,24 +648,25 @@ Type* TypeManager::RebuildType(const Type& type) {
const Pointer* target_ptr = forward_ptr_ty->target_pointer();
if (target_ptr) {
rebuilt_ty->AsForwardPointer()->SetTargetPointer(
RebuildType(*target_ptr)->AsPointer());
RebuildType(GetId(target_ptr), *target_ptr)->AsPointer());
}
break;
}
case Type::kCooperativeMatrixNV: {
const CooperativeMatrixNV* cm_type = type.AsCooperativeMatrixNV();
const Type* component_type = cm_type->component_type();
rebuilt_ty = MakeUnique<CooperativeMatrixNV>(
RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(),
cm_type->columns_id());
RebuildType(GetId(component_type), *component_type),
cm_type->scope_id(), cm_type->rows_id(), cm_type->columns_id());
break;
}
case Type::kCooperativeMatrixKHR: {
const CooperativeMatrixKHR* cm_type = type.AsCooperativeMatrixKHR();
const Type* component_type = cm_type->component_type();
rebuilt_ty = MakeUnique<CooperativeMatrixKHR>(
RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(),
cm_type->columns_id(), cm_type->use_id());
RebuildType(GetId(component_type), *component_type),
cm_type->scope_id(), cm_type->rows_id(), cm_type->columns_id(),
cm_type->use_id());
break;
}
default:
Expand All @@ -669,7 +685,7 @@ Type* TypeManager::RebuildType(const Type& type) {
void TypeManager::RegisterType(uint32_t id, const Type& type) {
// Rebuild |type| so it and all its constituent types are owned by the type
// pool.
Type* rebuilt = RebuildType(type);
Type* rebuilt = RebuildType(id, type);
assert(rebuilt->IsSame(&type));
id_to_type_[id] = rebuilt;
if (GetId(rebuilt) == 0) {
Expand Down
4 changes: 3 additions & 1 deletion source/opt/type_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,9 @@ class TypeManager {
// Returns an equivalent pointer to |type| built in terms of pointers owned by
// |type_pool_|. For example, if |type| is a vec3 of bool, it will be rebuilt
// replacing the bool subtype with one owned by |type_pool_|.
Type* RebuildType(const Type& type);
//
// The re-built type will have ID |type_id|.
Type* RebuildType(uint32_t type_id, const Type& type);

// Completes the incomplete type |type|, by replaces all references to
// ForwardPointer by the defining Pointer.
Expand Down
38 changes: 36 additions & 2 deletions test/opt/type_manager_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,10 +942,11 @@ OpMemoryModel Logical GLSL450
EXPECT_NE(context, nullptr);

std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
uint32_t id = 1u;
uint32_t id = 0u;
for (auto& t : types) {
context->get_type_mgr()->RegisterType(id, *t);
context->get_type_mgr()->RegisterType(++id, *t);
EXPECT_EQ(*t, *context->get_type_mgr()->GetType(id));
EXPECT_EQ(id, context->get_type_mgr()->GetId(t.get()));
}
types.clear();

Expand Down Expand Up @@ -1199,6 +1200,39 @@ OpMemoryModel Logical GLSL450
Match(text, context.get());
}

// Structures containing circular type references
// (from https://github.com/KhronosGroup/SPIRV-Tools/issues/5623).
TEST(TypeManager, CircularPointerToStruct) {
const std::string text = R"(
OpCapability VariablePointers
OpCapability PhysicalStorageBufferAddresses
OpCapability Int64
OpCapability Shader
OpExtension "SPV_KHR_variable_pointers"
OpExtension "SPV_KHR_physical_storage_buffer"
OpMemoryModel PhysicalStorageBuffer64 GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpExecutionMode %1 DepthReplacing
OpDecorate %1200 ArrayStride 24
OpMemberDecorate %600 0 Offset 0
OpMemberDecorate %800 0 Offset 0
OpMemberDecorate %120 0 Offset 16
OpTypeForwardPointer %1200 PhysicalStorageBuffer
%600 = OpTypeStruct %1200
%800 = OpTypeStruct %1200
%120 = OpTypeStruct %800
%1200 = OpTypePointer PhysicalStorageBuffer %120
)";

std::unique_ptr<IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
TypeManager manager(nullptr, context.get());
uint32_t id = manager.FindPointerToType(600, spv::StorageClass::Function);
EXPECT_EQ(id, 1201);
}

} // namespace
} // namespace analysis
} // namespace opt
Expand Down

0 comments on commit 3983d15

Please sign in to comment.