Add validation for SPV_NV_tensor_addressing and SPV_NV_cooperative_matrix2
jeffbolznv committed Oct 23, 2024
1 parent 895bb9f commit 6548915
Showing 29 changed files with 2,242 additions and 37 deletions.
2 changes: 1 addition & 1 deletion DEPS
@@ -14,7 +14,7 @@ vars = {

're2_revision': '6dcd83d60f7944926bfd308cc13979fc53dd69ca',

'spirv_headers_revision': '50bc4debdc3eec5045edbeb8ce164090e29b91f3',
'spirv_headers_revision': '22c4d1b1e9d1c7d9aa5086c93e6491f21080019b',
}

deps = {
6 changes: 6 additions & 0 deletions include/spirv-tools/libspirv.h
@@ -314,6 +314,12 @@ typedef enum spv_operand_type_t {
SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS,
// Optional enum type from SPV_NV_raw_access_chains
SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS,
// Enum type from SPV_NV_tensor_addressing
SPV_OPERAND_TYPE_TENSOR_CLAMP_MODE,
// Enum type from SPV_NV_cooperative_matrix2
SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE,
// Enum type from SPV_NV_cooperative_matrix2
SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS,

// This is a sentinel value, and does not represent an operand type.
// It should come last.
1 change: 1 addition & 0 deletions source/CMakeLists.txt
@@ -334,6 +334,7 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_ray_tracing_reorder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_scopes.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_small_type_uses.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_tensor_layout.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_type.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/decoration.h
${CMAKE_CURRENT_SOURCE_DIR}/val/basic_block.cpp
5 changes: 4 additions & 1 deletion source/binary.cpp
@@ -717,13 +717,16 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
case SPV_OPERAND_TYPE_LOOP_CONTROL:
case SPV_OPERAND_TYPE_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
case SPV_OPERAND_TYPE_MEMORY_ACCESS:
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
case SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS:
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE:
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS: {
// This operand is a mask.

// Map an optional operand type to its corresponding concrete type.
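The two new enum operand types join the parser's mask handling: the word is read as a bitmask, and each set bit implies follow-on operands defined by the grammar (a TensorView <id> and/or a DecodeFunc <id> for TensorAddressingOperands). A minimal sketch of composing such a mask, assuming the enum-class SPIR-V header from SPIRV-Headers (the include path may differ per build); the helper name is illustrative, not part of this change:

#include <cstdint>

#include "spirv/unified1/spirv.hpp11"

// Sketch: a TensorAddressingOperands word with both optional bits set.
// Each set bit implies one extra <id> operand after the mask word
// (the TensorView id, then the DecodeFunc id), which is why the parser
// treats this operand type like MemoryAccess and the other masks above.
uint32_t MakeTensorAddressingOperandsMask() {
  return uint32_t(spv::TensorAddressingOperandsMask::TensorView) |
         uint32_t(spv::TensorAddressingOperandsMask::DecodeFunc);
}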
2 changes: 2 additions & 0 deletions source/opcode.cpp
@@ -382,6 +382,8 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
case spv::Op::OpTypeRayQueryKHR:
case spv::Op::OpTypeHitObjectNV:
case spv::Op::OpTypeUntypedPointerKHR:
case spv::Op::OpTypeTensorLayoutNV:
case spv::Op::OpTypeTensorViewNV:
return true;
default:
// In particular, OpTypeForwardPointer does not generate a type,
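As a quick sanity check of the two new cases, a hedged sketch (spvOpcodeGeneratesType is declared in the internal source/opcode.h; the wrapper function is ours, not part of the commit):

#include <cassert>

#include "source/opcode.h"

void CheckTensorTypeOpcodes() {
  // The two new type-declaring opcodes now report that they generate a type.
  assert(spvOpcodeGeneratesType(spv::Op::OpTypeTensorLayoutNV));
  assert(spvOpcodeGeneratesType(spv::Op::OpTypeTensorViewNV));
  // Per the comment in the switch, OpTypeForwardPointer still does not.
  assert(!spvOpcodeGeneratesType(spv::Op::OpTypeForwardPointer));
}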
18 changes: 18 additions & 0 deletions source/operand.cpp
@@ -231,6 +231,12 @@ const char* spvOperandTypeStr(spv_operand_type_t type) {
return "cooperative matrix layout";
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
return "cooperative matrix use";
case SPV_OPERAND_TYPE_TENSOR_CLAMP_MODE:
return "tensor clamp mode";
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE:
return "cooperative matrix reduce";
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
return "tensor addressing operands";
case SPV_OPERAND_TYPE_INITIALIZATION_MODE_QUALIFIER:
return "initialization mode qualifier";
case SPV_OPERAND_TYPE_HOST_ACCESS_QUALIFIER:
@@ -409,6 +415,8 @@ bool spvOperandIsConcreteMask(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE:
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
return true;
default:
break;
@@ -598,6 +606,16 @@ std::function<bool(unsigned)> spvOperandCanBeForwardDeclaredFunction(
case spv::Op::OpTypeArray:
out = [](unsigned index) { return index == 1; };
break;
case spv::Op::OpCooperativeMatrixPerElementOpNV:
out = [](unsigned index) { return index == 3; };
break;
case spv::Op::OpCooperativeMatrixReduceNV:
out = [](unsigned index) { return index == 4; };
break;
case spv::Op::OpCooperativeMatrixLoadTensorNV:
// approximate, due to variable operands
out = [](unsigned index) { return index > 6; };
break;
default:
out = [](unsigned) { return false; };
break;
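For context on the forward-declaration predicate, a usage sketch (spvOperandCanBeForwardDeclaredFunction is declared in the internal source/operand.h; operand indices count the result type and result id, matching the cases above; the wrapper function is illustrative only):

#include <cassert>
#include <functional>

#include "source/operand.h"

void ForwardDeclarationSketch() {
  // Operand index 4 of OpCooperativeMatrixReduceNV is the CombineFunc <id>;
  // it may reference a function defined later in the module.
  std::function<bool(unsigned)> reduce_pred =
      spvOperandCanBeForwardDeclaredFunction(
          spv::Op::OpCooperativeMatrixReduceNV);
  assert(reduce_pred(4));

  // For OpCooperativeMatrixLoadTensorNV the rule is approximate: anything
  // past index 6 may be forward declared, because the DecodeFunc position
  // depends on the optional memory and tensor addressing operands before it.
  std::function<bool(unsigned)> load_pred =
      spvOperandCanBeForwardDeclaredFunction(
          spv::Op::OpCooperativeMatrixLoadTensorNV);
  assert(load_pred(7));
}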
6 changes: 6 additions & 0 deletions source/opt/aggressive_dead_code_elim_pass.cpp
@@ -43,6 +43,7 @@ constexpr uint32_t kGlobalVariableVariableIndex = 12;
constexpr uint32_t kExtInstSetInIdx = 0;
constexpr uint32_t kExtInstOpInIdx = 1;
constexpr uint32_t kInterpolantInIdx = 2;
constexpr uint32_t kCooperativeMatrixLoadSourceAddrInIdx = 0;

// Sorting functor to present annotation instructions in an easy-to-process
// order. The functor orders by opcode first and falls back on unique id
@@ -438,6 +439,11 @@ uint32_t AggressiveDCEPass::GetLoadedVariableFromNonFunctionCalls(
}
break;
}
case spv::Op::OpCooperativeMatrixLoadNV:
case spv::Op::OpCooperativeMatrixLoadKHR:
case spv::Op::OpCooperativeMatrixLoadTensorNV:
return GetVariableId(
inst->GetSingleWordInOperand(kCooperativeMatrixLoadSourceAddrInIdx));
default:
break;
}
15 changes: 15 additions & 0 deletions source/opt/eliminate_dead_members_pass.cpp
@@ -70,6 +70,11 @@ void EliminateDeadMembersPass::FindLiveMembers() {
MarkPointeeTypeAsFullUsed(inst.type_id());
break;
}
} else if (inst.opcode() == spv::Op::OpTypePointer) {
uint32_t storage_class = inst.GetSingleWordInOperand(0);
if (storage_class == uint32_t(spv::StorageClass::PhysicalStorageBuffer)) {
MarkTypeAsFullyUsed(inst.GetSingleWordInOperand(1));
}
}
}

@@ -200,6 +205,8 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForExtract(
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
@@ -246,6 +253,8 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForAccessChain(
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
@@ -505,6 +514,8 @@ bool EliminateDeadMembersPass::UpdateAccessChain(Instruction* inst) {
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
new_operands.emplace_back(inst->GetInOperand(i));
type_id = type_inst->GetSingleWordInOperand(0);
break;
@@ -578,6 +589,8 @@ bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) {
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
@@ -639,6 +652,8 @@ bool EliminateDeadMembersPass::UpdateCompositeInsert(Instruction* inst) {
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
28 changes: 27 additions & 1 deletion source/opt/ir_context.cpp
@@ -926,9 +926,35 @@ uint32_t IRContext::GetBuiltinInputVarId(uint32_t builtin) {

void IRContext::AddCalls(const Function* func, std::queue<uint32_t>* todo) {
for (auto bi = func->begin(); bi != func->end(); ++bi)
for (auto ii = bi->begin(); ii != bi->end(); ++ii)
for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
if (ii->opcode() == spv::Op::OpFunctionCall)
todo->push(ii->GetSingleWordInOperand(0));
if (ii->opcode() == spv::Op::OpCooperativeMatrixPerElementOpNV)
todo->push(ii->GetSingleWordInOperand(1));
if (ii->opcode() == spv::Op::OpCooperativeMatrixReduceNV)
todo->push(ii->GetSingleWordInOperand(2));
if (ii->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) {
const auto memory_operands_index = 3;
auto mask = ii->GetSingleWordInOperand(memory_operands_index);

uint32_t count = 1;
if (mask & uint32_t(spv::MemoryAccessMask::Aligned)) ++count;
if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR))
++count;
if (mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR))
++count;

const auto tensor_operands_index = memory_operands_index + count;
mask = ii->GetSingleWordInOperand(tensor_operands_index);
count = 1;
if (mask & uint32_t(spv::TensorAddressingOperandsMask::TensorView))
++count;

if (mask & uint32_t(spv::TensorAddressingOperandsMask::DecodeFunc)) {
todo->push(ii->GetSingleWordInOperand(tensor_operands_index + count));
}
}
}
}

bool IRContext::ProcessEntryPointCallTree(ProcessFunction& pfn) {
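The operand counting above is easiest to see with a concrete walk-through. A standalone sketch of the same arithmetic, assuming the enum-class SPIR-V header; the function name and signature are ours, not part of the pass:

#include <cstdint>

#include "spirv/unified1/spirv.hpp11"

// Given the Memory Operands and Tensor Addressing Operands mask words of an
// OpCooperativeMatrixLoadTensorNV, return the in-operand index holding the
// DecodeFunc <id>, or -1 when DecodeFunc is absent. Mirrors the loop above.
int DecodeFuncInOperandIndex(uint32_t memory_access_mask,
                             uint32_t tensor_addressing_mask) {
  const int memory_operands_index = 3;  // in-operand index of the memory mask
  int count = 1;                        // the mask word itself
  if (memory_access_mask & uint32_t(spv::MemoryAccessMask::Aligned)) ++count;
  if (memory_access_mask &
      uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR))
    ++count;
  if (memory_access_mask &
      uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR))
    ++count;

  const int tensor_operands_index = memory_operands_index + count;
  count = 1;
  if (tensor_addressing_mask &
      uint32_t(spv::TensorAddressingOperandsMask::TensorView))
    ++count;
  if (tensor_addressing_mask &
      uint32_t(spv::TensorAddressingOperandsMask::DecodeFunc))
    return tensor_operands_index + count;
  return -1;
}

For example, with only Aligned set in the memory mask and only DecodeFunc set in the tensor addressing mask, the decode function id sits at in-operand 3 + 2 + 1 = 6, which is the index the loop above pushes onto the call-tree work queue.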
48 changes: 48 additions & 0 deletions source/opt/type_manager.cpp
@@ -441,6 +441,28 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
{SPV_OPERAND_TYPE_ID, {coop_mat->use_id()}}});
break;
}
case Type::kTensorLayoutNV: {
auto tensor_layout = type->AsTensorLayoutNV();
typeInst = MakeUnique<Instruction>(
context(), spv::Op::OpTypeTensorLayoutNV, 0, id,
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_ID, {tensor_layout->dim_id()}},
{SPV_OPERAND_TYPE_ID, {tensor_layout->clamp_mode_id()}}});
break;
}
case Type::kTensorViewNV: {
auto tensor_view = type->AsTensorViewNV();
std::vector<Operand> operands;
operands.push_back(Operand{SPV_OPERAND_TYPE_ID, {tensor_view->dim_id()}});
operands.push_back(
Operand{SPV_OPERAND_TYPE_ID, {tensor_view->has_dimensions_id()}});
for (auto p : tensor_view->perm()) {
operands.push_back(Operand{SPV_OPERAND_TYPE_ID, {p}});
}
typeInst = MakeUnique<Instruction>(context(), spv::Op::OpTypeTensorViewNV,
0, id, operands);
break;
}
default:
assert(false && "Unexpected type");
break;
@@ -667,6 +689,18 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
cm_type->use_id());
break;
}
case Type::kTensorLayoutNV: {
const TensorLayoutNV* tl_type = type.AsTensorLayoutNV();
rebuilt_ty = MakeUnique<TensorLayoutNV>(tl_type->dim_id(),
tl_type->clamp_mode_id());
break;
}
case Type::kTensorViewNV: {
const TensorViewNV* tv_type = type.AsTensorViewNV();
rebuilt_ty = MakeUnique<TensorViewNV>(
tv_type->dim_id(), tv_type->has_dimensions_id(), tv_type->perm());
break;
}
default:
assert(false && "Unhandled type");
return nullptr;
@@ -914,6 +948,20 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
case spv::Op::OpTypeHitObjectNV:
type = new HitObjectNV();
break;
case spv::Op::OpTypeTensorLayoutNV:
type = new TensorLayoutNV(inst.GetSingleWordInOperand(0),
inst.GetSingleWordInOperand(1));
break;
case spv::Op::OpTypeTensorViewNV: {
const auto count = inst.NumOperands();
std::vector<uint32_t> perm;
for (uint32_t i = 2; i < count; ++i) {
perm.push_back(inst.GetSingleWordOperand(i));
}
type = new TensorViewNV(inst.GetSingleWordInOperand(0),
inst.GetSingleWordInOperand(1), perm);
break;
}
default:
assert(false && "Type not handled by the type manager.");
break;
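A hedged usage sketch of the new type-manager support; the constant ids passed in are hypothetical (they would already exist in the module), and the helper is illustrative rather than part of this change:

#include <cstdint>
#include <vector>

#include "source/opt/ir_context.h"
#include "source/opt/types.h"

// Sketch: build the analysis types and let the type manager emit (or reuse)
// the corresponding OpTypeTensorLayoutNV / OpTypeTensorViewNV instructions.
uint32_t EmitTensorTypes(spvtools::opt::IRContext* context, uint32_t dim_id,
                         uint32_t clamp_mode_id, uint32_t has_dimensions_id,
                         const std::vector<uint32_t>& perm_ids) {
  spvtools::opt::analysis::TensorLayoutNV layout(dim_id, clamp_mode_id);
  spvtools::opt::analysis::TensorViewNV view(dim_id, has_dimensions_id,
                                             perm_ids);

  auto* type_mgr = context->get_type_mgr();
  uint32_t layout_type_id = type_mgr->GetTypeInstruction(&layout);
  uint32_t view_type_id = type_mgr->GetTypeInstruction(&view);
  (void)layout_type_id;
  return view_type_id;
}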
54 changes: 53 additions & 1 deletion source/opt/types.cpp
@@ -179,6 +179,8 @@ bool Type::operator==(const Type& other) const {
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
DeclareKindCase(TensorLayoutNV);
DeclareKindCase(TensorViewNV);
#undef DeclareKindCase
default:
assert(false && "Unhandled type");
@@ -235,6 +237,8 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
DeclareKindCase(TensorLayoutNV);
DeclareKindCase(TensorViewNV);
#undef DeclareKindCase
default:
assert(false && "Unhandled type");
@@ -747,7 +751,55 @@ bool CooperativeMatrixKHR::IsSameImpl(const Type* that,
if (!mt) return false;
return component_type_->IsSameImpl(mt->component_type_, seen) &&
scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ &&
columns_id_ == mt->columns_id_ && HasSameDecorations(that);
columns_id_ == mt->columns_id_ && use_id_ == mt->use_id_ &&
HasSameDecorations(that);
}

TensorLayoutNV::TensorLayoutNV(const uint32_t dim, const uint32_t clamp_mode)
: Type(kTensorLayoutNV), dim_id_(dim), clamp_mode_id_(clamp_mode) {}

std::string TensorLayoutNV::str() const {
std::ostringstream oss;
oss << "<" << dim_id_ << ", " << clamp_mode_id_ << ">";
return oss.str();
}

size_t TensorLayoutNV::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
return hash_combine(hash, dim_id_, clamp_mode_id_);
}

bool TensorLayoutNV::IsSameImpl(const Type* that, IsSameCache*) const {
const TensorLayoutNV* tl = that->AsTensorLayoutNV();
if (!tl) return false;
return dim_id_ == tl->dim_id_ && clamp_mode_id_ == tl->clamp_mode_id_;
}

TensorViewNV::TensorViewNV(const uint32_t dim, const uint32_t has_dimensions,
const std::vector<uint32_t>& perm)
: Type(kTensorViewNV),
dim_id_(dim),
has_dimensions_id_(has_dimensions),
perm_(perm) {}

std::string TensorViewNV::str() const {
std::ostringstream oss;
oss << "<" << dim_id_ << ", " << has_dimensions_id_;
for (auto p : perm_) {
oss << ", " << p;
}
oss << ">";
return oss.str();
}

size_t TensorViewNV::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
return hash_combine(hash, dim_id_, has_dimensions_id_, perm_);
}

bool TensorViewNV::IsSameImpl(const Type* that, IsSameCache*) const {
const TensorViewNV* tv = that->AsTensorViewNV();
if (!tv) return false;
return dim_id_ == tv->dim_id_ &&
has_dimensions_id_ == tv->has_dimensions_id_ && perm_ == tv->perm_;
}

} // namespace analysis
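To make the new str() and equality behavior concrete, a small sketch (ids are arbitrary; the include path and wrapper are assumptions, not part of the commit):

#include <cassert>

#include "source/opt/types.h"

void TensorTypeSketch() {
  using spvtools::opt::analysis::TensorLayoutNV;
  using spvtools::opt::analysis::TensorViewNV;

  // str() prints the operand ids inside angle brackets, per the code above.
  TensorLayoutNV layout(/*dim=*/5, /*clamp_mode=*/6);
  assert(layout.str() == "<5, 6>");

  TensorViewNV view(/*dim=*/5, /*has_dimensions=*/6, /*perm=*/{7, 8});
  assert(view.str() == "<5, 6, 7, 8>");

  // Equality compares every operand id, including the permutation order.
  TensorViewNV other(/*dim=*/5, /*has_dimensions=*/6, /*perm=*/{8, 7});
  assert(!(view == other));
}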
(The remaining changed files are not shown.)
