From ad3958c2d8eac8f00ba498b1c629150fb0da3429 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Fri, 8 Dec 2023 10:34:31 -0800 Subject: [PATCH] Refactor stage var function (#6099) The next step in the refactoring is to reduce the number of variables passed to various functions. This should make the logic easier to follow. This is done by grouping variables into a struct, and we pass the struct along. --------- Co-authored-by: Cassandra Beckley --- tools/clang/lib/SPIRV/DeclResultIdMapper.cpp | 523 ++++++++++--------- tools/clang/lib/SPIRV/DeclResultIdMapper.h | 241 ++++----- 2 files changed, 389 insertions(+), 375 deletions(-) diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp index bb3baecf19..5db014ca95 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp @@ -880,9 +880,11 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl, const bool noWriteBack = storedValue == nullptr || spvContext.isGS() || spvContext.isMS(); - return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize, - "out.var", llvm::None, &storedValue, noWriteBack, - &inheritSemantic); + StageVarDataBundle stageVarData = { + decl, &inheritSemantic, false, sigPoint, + type, arraySize, "out.var", llvm::None}; + return createStageVars(stageVarData, /*asInput=*/false, &storedValue, + noWriteBack); } bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl, @@ -898,9 +900,11 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl, SemanticInfo inheritSemantic = {}; - return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize, - "out.var", invocationId, &storedValue, - /*noWriteBack=*/false, &inheritSemantic); + StageVarDataBundle stageVarData = { + decl, &inheritSemantic, false, sigPoint, + type, arraySize, "out.var", invocationId}; + return createStageVars(stageVarData, /*asInput=*/false, &storedValue, + /*noWriteBack=*/false); } bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl, @@ -938,9 +942,11 @@ bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl, return createPayloadStageVars(sigPoint, sc, paramDecl, /*asInput=*/true, type, "in.var", loadedValue); } else { - return createStageVars(sigPoint, paramDecl, /*asInput=*/true, type, - arraySize, "in.var", llvm::None, loadedValue, - /*noWriteBack=*/false, &inheritSemantic); + StageVarDataBundle stageVarData = { + paramDecl, &inheritSemantic, false, sigPoint, + type, arraySize, "in.var", llvm::None}; + return createStageVars(stageVarData, /*asInput=*/true, loadedValue, + /*noWriteBack=*/false); } } @@ -2602,22 +2608,22 @@ bool DeclResultIdMapper::decorateResourceCoherent() { } bool DeclResultIdMapper::createStructOutputVar( - QualType type, const hlsl::SigPoint *sigPoint, uint32_t arraySize, - llvm::Optional invocationId, - const llvm::StringRef namePrefix, SemanticInfo *semantic, bool noWriteBack, - bool asNoInterp, SpirvInstruction *value, SourceLocation loc) { + const StageVarDataBundle &stageVarData, SpirvInstruction *value, + bool noWriteBack) { // If we have base classes, we need to handle them first. - if (const auto *cxxDecl = type->getAsCXXRecordDecl()) { + if (const auto *cxxDecl = stageVarData.type->getAsCXXRecordDecl()) { uint32_t baseIndex = 0; for (auto base : cxxDecl->bases()) { SpirvInstruction *subValue = nullptr; if (!noWriteBack) - subValue = spvBuilder.createCompositeExtract(base.getType(), value, - {baseIndex++}, loc); - - if (!createStageVars(sigPoint, base.getType()->getAsCXXRecordDecl(), - false, base.getType(), arraySize, namePrefix, - invocationId, &subValue, noWriteBack, semantic)) + subValue = spvBuilder.createCompositeExtract( + base.getType(), value, {baseIndex++}, + stageVarData.decl->getLocation()); + + StageVarDataBundle memberVarData = stageVarData; + memberVarData.decl = base.getType()->getAsCXXRecordDecl(); + memberVarData.type = base.getType(); + if (!createStageVars(memberVarData, false, &subValue, noWriteBack)) return false; } } @@ -2636,64 +2642,65 @@ bool DeclResultIdMapper::createStructOutputVar( // // The interesting shader stage is HS. We need the InvocationID to write // out the value to the correct array element. - const auto *structDecl = type->getAs()->getDecl(); + const auto *structDecl = stageVarData.type->getAs()->getDecl(); for (const auto *field : structDecl->fields()) { const auto fieldType = field->getType(); SpirvInstruction *subValue = nullptr; if (!noWriteBack) { subValue = spvBuilder.createCompositeExtract( - fieldType, value, {getNumBaseClasses(type) + field->getFieldIndex()}, - loc); + fieldType, value, + {getNumBaseClasses(stageVarData.type) + field->getFieldIndex()}, + stageVarData.decl->getLocation()); if (field->hasAttr() || structDecl->hasAttr()) subValue->setNoninterpolated(); } - if (!createStageVars( - sigPoint, field, false, field->getType(), arraySize, namePrefix, - invocationId, &subValue, noWriteBack, semantic, - asNoInterp || field->hasAttr())) + StageVarDataBundle memberVarData = stageVarData; + memberVarData.decl = field; + memberVarData.type = field->getType(); + memberVarData.asNoInterp |= field->hasAttr(); + if (!createStageVars(memberVarData, false, &subValue, noWriteBack)) return false; } return true; } -SpirvInstruction *DeclResultIdMapper::createStructInputVar( - QualType type, const hlsl::SigPoint *sigPoint, uint32_t arraySize, - const llvm::StringRef namePrefix, bool asNoInterp, bool noWriteBack, - SemanticInfo *semanticToUse, SourceLocation loc) { +SpirvInstruction * +DeclResultIdMapper::createStructInputVar(const StageVarDataBundle &stageVarData, + bool noWriteBack) { // If this decl translates into multiple stage input variables, we need to // load their values into a composite. llvm::SmallVector subValues; // If we have base classes, we need to handle them first. - if (const auto *cxxDecl = type->getAsCXXRecordDecl()) { + if (const auto *cxxDecl = stageVarData.type->getAsCXXRecordDecl()) { for (auto base : cxxDecl->bases()) { SpirvInstruction *subValue = nullptr; - if (!createStageVars(sigPoint, base.getType()->getAsCXXRecordDecl(), true, - base.getType(), arraySize, namePrefix, - llvm::Optional(), &subValue, - noWriteBack, semanticToUse)) + StageVarDataBundle memberVarData = stageVarData; + memberVarData.decl = base.getType()->getAsCXXRecordDecl(); + memberVarData.type = base.getType(); + if (!createStageVars(memberVarData, true, &subValue, noWriteBack)) return nullptr; subValues.push_back(subValue); } } - const auto *structDecl = type->getAs()->getDecl(); + const auto *structDecl = stageVarData.type->getAs()->getDecl(); for (const auto *field : structDecl->fields()) { SpirvInstruction *subValue = nullptr; - if (!createStageVars(sigPoint, field, true, field->getType(), arraySize, - namePrefix, llvm::Optional(), - &subValue, noWriteBack, semanticToUse, - asNoInterp || - field->hasAttr())) + StageVarDataBundle memberVarData = stageVarData; + memberVarData.decl = field; + memberVarData.type = field->getType(); + memberVarData.asNoInterp |= field->hasAttr(); + if (!createStageVars(memberVarData, true, &subValue, noWriteBack)) return nullptr; subValues.push_back(subValue); } - if (arraySize == 0) { - SpirvInstruction *value = - spvBuilder.createCompositeConstruct(type, subValues, loc); + if (stageVarData.arraySize == 0) { + SpirvInstruction *value = spvBuilder.createCompositeConstruct( + stageVarData.type, subValues, stageVarData.decl->getLocation()); for (auto *subInstr : subValues) spvBuilder.addPerVertexStgInputFuncVarEntry(subInstr, value); return value; @@ -2705,22 +2712,25 @@ SpirvInstruction *DeclResultIdMapper::createStructInputVar( // from visiting all fields. So now we need to extract all the elements // at the same index of each field arrays and compose a new struct out // of them. - const auto structType = type; + const auto structType = stageVarData.type; const auto arrayType = astContext.getConstantArrayType( - structType, llvm::APInt(32, arraySize), clang::ArrayType::Normal, 0); + structType, llvm::APInt(32, stageVarData.arraySize), + clang::ArrayType::Normal, 0); llvm::SmallVector arrayElements; - for (uint32_t arrayIndex = 0; arrayIndex < arraySize; ++arrayIndex) { + for (uint32_t arrayIndex = 0; arrayIndex < stageVarData.arraySize; + ++arrayIndex) { llvm::SmallVector fields; // If we have base classes, we need to handle them first. - if (const auto *cxxDecl = type->getAsCXXRecordDecl()) { + if (const auto *cxxDecl = stageVarData.type->getAsCXXRecordDecl()) { uint32_t baseIndex = 0; for (auto base : cxxDecl->bases()) { const auto baseType = base.getType(); fields.push_back(spvBuilder.createCompositeExtract( - baseType, subValues[baseIndex++], {arrayIndex}, loc)); + baseType, subValues[baseIndex++], {arrayIndex}, + stageVarData.decl->getLocation())); } } @@ -2729,41 +2739,42 @@ SpirvInstruction *DeclResultIdMapper::createStructInputVar( const auto fieldType = field->getType(); fields.push_back(spvBuilder.createCompositeExtract( fieldType, - subValues[getNumBaseClasses(type) + field->getFieldIndex()], - {arrayIndex}, loc)); + subValues[getNumBaseClasses(stageVarData.type) + + field->getFieldIndex()], + {arrayIndex}, stageVarData.decl->getLocation())); } // Compose a new struct out of them - arrayElements.push_back( - spvBuilder.createCompositeConstruct(structType, fields, loc)); + arrayElements.push_back(spvBuilder.createCompositeConstruct( + structType, fields, stageVarData.decl->getLocation())); } - return spvBuilder.createCompositeConstruct(arrayType, arrayElements, loc); + return spvBuilder.createCompositeConstruct(arrayType, arrayElements, + stageVarData.decl->getLocation()); } void DeclResultIdMapper::storeToShaderOutputVariable( - SpirvVariable *varInstr, hlsl::Semantic::Kind semanticKind, QualType type, - SpirvInstruction *value, llvm::Optional invocationId, - hlsl::SigPoint::Kind sigPointKind, const NamedDecl *decl, - SourceLocation loc) { + SpirvVariable *varInstr, SpirvInstruction *value, + const StageVarDataBundle &stageVarData) { SpirvInstruction *ptr = varInstr; // Special handling of SV_TessFactor HS patch constant output. // TessLevelOuter is always an array of size 4 in SPIR-V, but // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the // relevant indexes must be written to. - if (semanticKind == hlsl::Semantic::Kind::TessFactor && - hlsl::GetArraySize(type) != 4) { - const auto tessFactorSize = hlsl::GetArraySize(type); + if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::TessFactor && + hlsl::GetArraySize(stageVarData.type) != 4) { + const auto tessFactorSize = hlsl::GetArraySize(stageVarData.type); for (uint32_t i = 0; i < tessFactorSize; ++i) { ptr = spvBuilder.createAccessChain( astContext.FloatTy, varInstr, {spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, i))}, - loc); - spvBuilder.createStore(ptr, - spvBuilder.createCompositeExtract( - astContext.FloatTy, value, {i}, loc), - loc); + stageVarData.decl->getLocation()); + spvBuilder.createStore( + ptr, + spvBuilder.createCompositeExtract(astContext.FloatTy, value, {i}, + stageVarData.decl->getLocation()), + stageVarData.decl->getLocation()); } } // Special handling of SV_InsideTessFactor HS patch constant output. @@ -2771,110 +2782,121 @@ void DeclResultIdMapper::storeToShaderOutputVariable( // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in // HLSL. If SV_InsideTessFactor is a scalar, only write to index 0 of // TessLevelInner. - else if (semanticKind == hlsl::Semantic::Kind::InsideTessFactor && + else if (stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::InsideTessFactor && // Some developers use float[1] instead of a scalar float. - (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) { + (!stageVarData.type->isArrayType() || + hlsl::GetArraySize(stageVarData.type) == 1)) { ptr = spvBuilder.createAccessChain( astContext.FloatTy, varInstr, spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)), - loc); - if (type->isArrayType()) // float[1] - value = spvBuilder.createCompositeExtract(astContext.FloatTy, value, {0}, - loc); - spvBuilder.createStore(ptr, value, loc); + stageVarData.decl->getLocation()); + if (stageVarData.type->isArrayType()) // float[1] + value = spvBuilder.createCompositeExtract( + astContext.FloatTy, value, {0}, stageVarData.decl->getLocation()); + spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); } // Special handling of SV_Coverage, which is an unit value. We need to // write it to the first element in the SampleMask builtin. - else if (semanticKind == hlsl::Semantic::Kind::Coverage) { + else if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::Coverage) { ptr = spvBuilder.createAccessChain( - type, varInstr, + stageVarData.type, varInstr, spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)), - loc); + stageVarData.decl->getLocation()); ptr->setStorageClass(spv::StorageClass::Output); - spvBuilder.createStore(ptr, value, loc); + spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); } // Special handling of HS ouput, for which we write to only one // element in the per-vertex data array: the one indexed by // SV_ControlPointID. - else if (invocationId.hasValue() && invocationId.getValue() != nullptr) { + else if (stageVarData.invocationId.hasValue() && + stageVarData.invocationId.getValue() != nullptr) { // Remove the arrayness to get the element type. assert(isa(varInstr->getAstResultType())); const auto elementType = astContext.getAsArrayType(varInstr->getAstResultType()) ->getElementType(); - auto index = invocationId.getValue(); - ptr = spvBuilder.createAccessChain(elementType, varInstr, index, loc); + auto index = stageVarData.invocationId.getValue(); + ptr = spvBuilder.createAccessChain(elementType, varInstr, index, + stageVarData.decl->getLocation()); ptr->setStorageClass(spv::StorageClass::Output); - spvBuilder.createStore(ptr, value, loc); + spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); } // Since boolean output stage variables are represented as unsigned // integers, we must cast the value to uint before storing. - else if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) { - value = - theEmitter.castToType(value, type, varInstr->getAstResultType(), loc); - spvBuilder.createStore(ptr, value, loc); + else if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type, + stageVarData.semantic->getKind(), + stageVarData.sigPoint->GetKind())) { + value = theEmitter.castToType(value, stageVarData.type, + varInstr->getAstResultType(), + stageVarData.decl->getLocation()); + spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); } // For all normal cases else { - spvBuilder.createStore(ptr, value, loc); + spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); } } SpirvInstruction *DeclResultIdMapper::loadShaderInputVariable( - SpirvVariable *varInstr, hlsl::Semantic::Kind semanticKind, - hlsl::SigPoint::Kind sigPointKind, QualType type, const NamedDecl *decl, - SourceLocation loc) { - SpirvInstruction *load = - spvBuilder.createLoad(varInstr->getAstResultType(), varInstr, loc); + SpirvVariable *varInstr, const StageVarDataBundle &stageVarData) { + SpirvInstruction *load = spvBuilder.createLoad( + varInstr->getAstResultType(), varInstr, stageVarData.decl->getLocation()); // Fix ups for corner cases // Special handling of SV_TessFactor DS patch constant input. // TessLevelOuter is always an array of size 4 in SPIR-V, but // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the // relevant indexes must be loaded. - if (semanticKind == hlsl::Semantic::Kind::TessFactor && - hlsl::GetArraySize(type) != 4) { + if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::TessFactor && + hlsl::GetArraySize(stageVarData.type) != 4) { llvm::SmallVector components; - const auto tessFactorSize = hlsl::GetArraySize(type); + const auto tessFactorSize = hlsl::GetArraySize(stageVarData.type); const auto arrType = astContext.getConstantArrayType( astContext.FloatTy, llvm::APInt(32, tessFactorSize), clang::ArrayType::Normal, 0); for (uint32_t i = 0; i < tessFactorSize; ++i) - components.push_back(spvBuilder.createCompositeExtract(astContext.FloatTy, - load, {i}, loc)); - load = spvBuilder.createCompositeConstruct(arrType, components, loc); + components.push_back(spvBuilder.createCompositeExtract( + astContext.FloatTy, load, {i}, stageVarData.decl->getLocation())); + load = spvBuilder.createCompositeConstruct( + arrType, components, stageVarData.decl->getLocation()); } // Special handling of SV_InsideTessFactor DS patch constant input. // TessLevelInner is always an array of size 2 in SPIR-V, but // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in // HLSL. If SV_InsideTessFactor is a scalar, only extract index 0 of // TessLevelInner. - else if (semanticKind == hlsl::Semantic::Kind::InsideTessFactor && + else if (stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::InsideTessFactor && // Some developers use float[1] instead of a scalar float. - (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) { - load = - spvBuilder.createCompositeExtract(astContext.FloatTy, load, {0}, loc); - if (type->isArrayType()) { // float[1] + (!stageVarData.type->isArrayType() || + hlsl::GetArraySize(stageVarData.type) == 1)) { + load = spvBuilder.createCompositeExtract(astContext.FloatTy, load, {0}, + stageVarData.decl->getLocation()); + if (stageVarData.type->isArrayType()) { // float[1] const auto arrType = astContext.getConstantArrayType( astContext.FloatTy, llvm::APInt(32, 1), clang::ArrayType::Normal, 0); - load = spvBuilder.createCompositeConstruct(arrType, {load}, loc); + load = spvBuilder.createCompositeConstruct( + arrType, {load}, stageVarData.decl->getLocation()); } } // SV_DomainLocation can refer to a float2 or a float3, whereas TessCoord // is always a float3. To ensure SPIR-V validity, a float3 stage variable // is created, and we must extract a float2 from it before passing it to // the main function. - else if (semanticKind == hlsl::Semantic::Kind::DomainLocation && - hlsl::GetHLSLVecSize(type) != 3) { - const auto domainLocSize = hlsl::GetHLSLVecSize(type); + else if (stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::DomainLocation && + hlsl::GetHLSLVecSize(stageVarData.type) != 3) { + const auto domainLocSize = hlsl::GetHLSLVecSize(stageVarData.type); load = spvBuilder.createVectorShuffle( astContext.getExtVectorType(astContext.FloatTy, domainLocSize), load, - load, {0, 1}, loc); + load, {0, 1}, stageVarData.decl->getLocation()); } // Special handling of SV_Coverage, which is an uint value. We need to // read SampleMask and extract its first element. - else if (semanticKind == hlsl::Semantic::Kind::Coverage) { - load = spvBuilder.createCompositeExtract(type, load, {0}, loc); + else if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::Coverage) { + load = spvBuilder.createCompositeExtract(stageVarData.type, load, {0}, + stageVarData.decl->getLocation()); } // Special handling of SV_InnerCoverage, which is an uint value. We need // to read FullyCoveredEXT, which is a boolean value, and convert it to an @@ -2887,85 +2909,96 @@ SpirvInstruction *DeclResultIdMapper::loadShaderInputVariable( // but are undefined when bit 0 is set to 1 (essentially, this bit-field // represents a Boolean value where false must be exactly 0, but true can // be any odd (i.e. bit 0 set) non-zero value)." - else if (semanticKind == hlsl::Semantic::Kind::InnerCoverage) { + else if (stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::InnerCoverage) { const auto constOne = spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1)); const auto constZero = spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)); load = spvBuilder.createSelect(astContext.UnsignedIntTy, load, constOne, - constZero, loc); + constZero, stageVarData.decl->getLocation()); } // Special handling of SV_Barycentrics, which is a float3, but the // The 3 values are NOT guaranteed to add up to floating-point 1.0 // exactly. Calculate the third element here. - else if (semanticKind == hlsl::Semantic::Kind::Barycentrics) { - const auto x = - spvBuilder.createCompositeExtract(astContext.FloatTy, load, {0}, loc); - const auto y = - spvBuilder.createCompositeExtract(astContext.FloatTy, load, {1}, loc); - const auto xy = spvBuilder.createBinaryOp(spv::Op::OpFAdd, - astContext.FloatTy, x, y, loc); + else if (stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::Barycentrics) { + const auto x = spvBuilder.createCompositeExtract( + astContext.FloatTy, load, {0}, stageVarData.decl->getLocation()); + const auto y = spvBuilder.createCompositeExtract( + astContext.FloatTy, load, {1}, stageVarData.decl->getLocation()); + const auto xy = + spvBuilder.createBinaryOp(spv::Op::OpFAdd, astContext.FloatTy, x, y, + stageVarData.decl->getLocation()); const auto z = spvBuilder.createBinaryOp( spv::Op::OpFSub, astContext.FloatTy, spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(1.0f)), - xy, loc); + xy, stageVarData.decl->getLocation()); load = spvBuilder.createCompositeConstruct( - astContext.getExtVectorType(astContext.FloatTy, 3), {x, y, z}, loc); + astContext.getExtVectorType(astContext.FloatTy, 3), {x, y, z}, + stageVarData.decl->getLocation()); } // Special handling of SV_DispatchThreadID and SV_GroupThreadID, which may // be a uint or uint2, but the underlying stage input variable is a uint3. // The last component(s) should be discarded in needed. - else if ((semanticKind == hlsl::Semantic::Kind::DispatchThreadID || - semanticKind == hlsl::Semantic::Kind::GroupThreadID || - semanticKind == hlsl::Semantic::Kind::GroupID) && - (!hlsl::IsHLSLVecType(type) || hlsl::GetHLSLVecSize(type) != 3)) { + else if ((stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::DispatchThreadID || + stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::GroupThreadID || + stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::GroupID) && + (!hlsl::IsHLSLVecType(stageVarData.type) || + hlsl::GetHLSLVecSize(stageVarData.type) != 3)) { const auto srcVecElemType = - hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecElementType(type) : type; - const auto vecSize = - hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecSize(type) : 1; + hlsl::IsHLSLVecType(stageVarData.type) + ? hlsl::GetHLSLVecElementType(stageVarData.type) + : stageVarData.type; + const auto vecSize = hlsl::IsHLSLVecType(stageVarData.type) + ? hlsl::GetHLSLVecSize(stageVarData.type) + : 1; if (vecSize == 1) - load = spvBuilder.createCompositeExtract(srcVecElemType, load, {0}, loc); + load = spvBuilder.createCompositeExtract( + srcVecElemType, load, {0}, stageVarData.decl->getLocation()); else if (vecSize == 2) load = spvBuilder.createVectorShuffle( astContext.getExtVectorType(srcVecElemType, 2), load, load, {0, 1}, - loc); + stageVarData.decl->getLocation()); } // Reciprocate SV_Position.w if requested - if (semanticKind == hlsl::Semantic::Kind::Position) - load = invertWIfRequested(load, loc); + if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::Position) + load = invertWIfRequested(load, stageVarData.decl->getLocation()); // Since boolean stage input variables are represented as unsigned // integers, after loading them, we should cast them to boolean. - if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) { - load = theEmitter.castToType(load, varInstr->getAstResultType(), type, loc); + if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type, + stageVarData.semantic->getKind(), + stageVarData.sigPoint->GetKind())) { + load = theEmitter.castToType(load, varInstr->getAstResultType(), + stageVarData.type, + stageVarData.decl->getLocation()); } return load; } -bool DeclResultIdMapper::validateShaderStageVar(SemanticInfo *semantic, - const hlsl::SigPoint *sigPoint, - const NamedDecl *decl, - QualType type) { - const auto semanticKind = semantic->getKind(); - const auto sigPointKind = sigPoint->GetKind(); - - if (!validateVKAttributes(decl)) +bool DeclResultIdMapper::validateShaderStageVar( + const StageVarDataBundle &stageVarData) { + if (!validateVKAttributes(stageVarData.decl)) return false; - if (!isValidSemanticInShaderModel(semanticKind, sigPointKind, decl)) { + if (!isValidSemanticInShaderModel(stageVarData)) { emitError("invalid usage of semantic '%0' in shader profile %1", - decl->getLocation()) - << semantic->str + stageVarData.decl->getLocation()) + << stageVarData.semantic->str << hlsl::ShaderModel::GetKindName( spvContext.getCurrentShaderModelKind()); return false; } - if (!validateVKBuiltins(decl, sigPoint)) + if (!validateVKBuiltins(stageVarData)) return false; - if (!validateShaderStageVarType(semanticKind, type, decl->getLocation())) + if (!validateShaderStageVarType(stageVarData)) return false; return true; } @@ -3006,17 +3039,18 @@ bool DeclResultIdMapper::validateVKAttributes(const NamedDecl *decl) { return success; } -bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, - const hlsl::SigPoint *sigPoint) { +bool DeclResultIdMapper::validateVKBuiltins( + const StageVarDataBundle &stageVarData) { bool success = true; - if (const auto *builtinAttr = decl->getAttr()) { + if (const auto *builtinAttr = stageVarData.decl->getAttr()) { // The front end parsing only allows vk::builtin to be attached to a // function/parameter/variable; all of them are DeclaratorDecls. - const auto declType = getTypeOrFnRetType(cast(decl)); + const auto declType = + getTypeOrFnRetType(cast(stageVarData.decl)); const auto loc = builtinAttr->getLocation(); - if (decl->hasAttr()) { + if (stageVarData.decl->hasAttr()) { emitError("cannot use vk::builtin and vk::location together", loc); success = false; } @@ -3029,7 +3063,7 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, success = false; } - if (sigPoint->GetKind() != hlsl::SigPoint::Kind::PSIn) { + if (stageVarData.sigPoint->GetKind() != hlsl::SigPoint::Kind::PSIn) { emitError( "HelperInvocation builtin can only be used as pixel shader input", loc); @@ -3041,7 +3075,7 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, success = false; } - switch (sigPoint->GetKind()) { + switch (stageVarData.sigPoint->GetKind()) { case hlsl::SigPoint::Kind::VSOut: case hlsl::SigPoint::Kind::HSCPIn: case hlsl::SigPoint::Kind::HSCPOut: @@ -3054,7 +3088,7 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, break; default: emitError("PointSize builtin cannot be used as %0", loc) - << sigPoint->GetName(); + << stageVarData.sigPoint->GetName(); success = false; } } else if (builtin == "BaseVertex" || builtin == "BaseInstance" || @@ -3066,24 +3100,25 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, success = false; } - switch (sigPoint->GetKind()) { + switch (stageVarData.sigPoint->GetKind()) { case hlsl::SigPoint::Kind::VSIn: break; case hlsl::SigPoint::Kind::MSIn: case hlsl::SigPoint::Kind::ASIn: if (builtin != "DrawIndex") { emitError("%0 builtin cannot be used as %1", loc) - << builtin << sigPoint->GetName(); + << builtin << stageVarData.sigPoint->GetName(); success = false; } break; default: emitError("%0 builtin cannot be used as %1", loc) - << builtin << sigPoint->GetName(); + << builtin << stageVarData.sigPoint->GetName(); success = false; } } else if (builtin == "DeviceIndex") { - if (getStorageClassForSigPoint(sigPoint) != spv::StorageClass::Input) { + if (getStorageClassForSigPoint(stageVarData.sigPoint) != + spv::StorageClass::Input) { emitError("%0 builtin can only be used as shader input", loc) << builtin; success = false; @@ -3095,7 +3130,7 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, success = false; } } else if (builtin == "ViewportMaskNV") { - if (sigPoint->GetKind() != hlsl::SigPoint::Kind::MSPOut) { + if (stageVarData.sigPoint->GetKind() != hlsl::SigPoint::Kind::MSPOut) { emitError("%0 builtin can only be used as 'primitives' output in MS", loc) << builtin; @@ -3115,13 +3150,13 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, } bool DeclResultIdMapper::validateShaderStageVarType( - hlsl::Semantic::Kind semanticKind, QualType type, - clang::SourceLocation loc) { + const StageVarDataBundle &stageVarData) { - switch (semanticKind) { + switch (stageVarData.semantic->getKind()) { case hlsl::Semantic::Kind::InnerCoverage: - if (!type->isSpecificBuiltinType(BuiltinType::UInt)) { - emitError("SV_InnerCoverage must be of uint type.", loc); + if (!stageVarData.type->isSpecificBuiltinType(BuiltinType::UInt)) { + emitError("SV_InnerCoverage must be of uint type.", + stageVarData.decl->getLocation()); return false; } break; @@ -3132,18 +3167,18 @@ bool DeclResultIdMapper::validateShaderStageVarType( } bool DeclResultIdMapper::isValidSemanticInShaderModel( - hlsl::Semantic::Kind semanticKind, hlsl::SigPoint::Kind sigPointKind, - const NamedDecl *decl) { + const StageVarDataBundle &stageVarData) { // Error out when the given semantic is invalid in this shader model - if (hlsl::SigPoint::GetInterpretation(semanticKind, sigPointKind, - spvContext.getMajorVersion(), - spvContext.getMinorVersion()) == + if (hlsl::SigPoint::GetInterpretation( + stageVarData.semantic->getKind(), stageVarData.sigPoint->GetKind(), + spvContext.getMajorVersion(), spvContext.getMinorVersion()) == hlsl::DXIL::SemanticInterpretationKind::NA) { // Special handle MSIn/ASIn allowing VK-only builtin "DrawIndex". - switch (sigPointKind) { + switch (stageVarData.sigPoint->GetKind()) { case hlsl::SigPoint::Kind::MSIn: case hlsl::SigPoint::Kind::ASIn: - if (const auto *builtinAttr = decl->getAttr()) { + if (const auto *builtinAttr = + stageVarData.decl->getAttr()) { const llvm::StringRef builtin = builtinAttr->getBuiltIn(); if (builtin == "DrawIndex") { break; @@ -3191,38 +3226,36 @@ DeclResultIdMapper::getBaseInstanceVariable(SemanticInfo *semantic, } SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( - const hlsl::SigPoint *sigPoint, SemanticInfo *semanticToUse, - const NamedDecl *decl, QualType type, uint32_t arraySize, bool asNoInterp, - const llvm::StringRef namePrefix) { + const StageVarDataBundle &stageVarData) { // The evalType will be the type of the interface variable in SPIR-V. // The type of the variable used in the body of the function will still be // `type`. - QualType evalType = getTypeForSpirvStageVariable( - type, semanticToUse->getKind(), sigPoint->GetKind(), decl, arraySize); + QualType evalType = getTypeForSpirvStageVariable(stageVarData); - const auto *builtinAttr = decl->getAttr(); + const auto *builtinAttr = stageVarData.decl->getAttr(); StageVar stageVar( - sigPoint, *semanticToUse, builtinAttr, evalType, + stageVarData.sigPoint, *stageVarData.semantic, builtinAttr, evalType, // For HS/DS/GS, we have already stripped the outmost arrayness on type. - getLocationAndComponentCount(astContext, type)); - const auto name = namePrefix.str() + "." + stageVar.getSemanticStr(); - SpirvVariable *varInstr = - createSpirvStageVar(&stageVar, decl, name, semanticToUse->loc); + getLocationAndComponentCount(astContext, stageVarData.type)); + const auto name = + stageVarData.namePrefix.str() + "." + stageVar.getSemanticStr(); + SpirvVariable *varInstr = createSpirvStageVar( + &stageVar, stageVarData.decl, name, stageVarData.semantic->loc); if (!varInstr) return nullptr; - if (asNoInterp) + if (stageVarData.asNoInterp) varInstr->setNoninterpolated(); stageVar.setSpirvInstr(varInstr); - stageVar.setLocationAttr(decl->getAttr()); - stageVar.setIndexAttr(decl->getAttr()); + stageVar.setLocationAttr(stageVarData.decl->getAttr()); + stageVar.setIndexAttr(stageVarData.decl->getAttr()); if (stageVar.getStorageClass() == spv::StorageClass::Input || stageVar.getStorageClass() == spv::StorageClass::Output) { stageVar.setEntryPoint(entryFunction); } - decorateStageVarWithIntrinsicAttrs(decl, &stageVar, varInstr); + decorateStageVarWithIntrinsicAttrs(stageVarData.decl, &stageVar, varInstr); stageVars.push_back(stageVar); // Emit OpDecorate* instructions to link this stage variable with the HLSL @@ -3230,9 +3263,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( spvBuilder.decorateHlslSemantic(varInstr, stageVar.getSemanticStr()); // TODO: the following may not be correct? - if (sigPoint->GetSignatureKind() == + if (stageVarData.sigPoint->GetSignatureKind() == hlsl::DXIL::SignatureKind::PatchConstOrPrim) { - if (sigPoint->GetKind() == hlsl::SigPoint::Kind::MSPOut) { + if (stageVarData.sigPoint->GetKind() == hlsl::SigPoint::Kind::MSPOut) { // Decorate with PerPrimitiveNV for per-primitive out variables. spvBuilder.decoratePerPrimitiveNV(varInstr, varInstr->getSourceLocation()); @@ -3243,9 +3276,10 @@ SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( // Decorate with interpolation modes for pixel shader input variables // or vertex shader output variables. - if ((spvContext.isPS() && sigPoint->IsInput()) || - (spvContext.isVS() && sigPoint->IsOutput())) - decorateInterpolationMode(decl, type, varInstr, *semanticToUse); + if ((spvContext.isPS() && stageVarData.sigPoint->IsInput()) || + (spvContext.isVS() && stageVarData.sigPoint->IsOutput())) + decorateInterpolationMode(stageVarData.decl, stageVarData.type, varInstr, + *stageVarData.semantic); // Special case: The DX12 SV_InstanceID always counts from 0, even if the // StartInstanceLocation parameter is non-zero. gl_InstanceIndex, however, @@ -3271,12 +3305,13 @@ SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( // drawing command or the firstInstance member of a structure consumed by // an indirect drawing command. if (spirvOptions.supportNonzeroBaseInstance && - semanticToUse->getKind() == hlsl::Semantic::Kind::InstanceID && - sigPoint->GetKind() == hlsl::SigPoint::Kind::VSIn) { + stageVarData.semantic->getKind() == hlsl::Semantic::Kind::InstanceID && + stageVarData.sigPoint->GetKind() == hlsl::SigPoint::Kind::VSIn) { // The above call to createSpirvStageVar creates the gl_InstanceIndex. // We should now manually create the gl_BaseInstance variable and do the // subtraction. - auto *baseInstanceVar = getBaseInstanceVariable(semanticToUse, sigPoint); + auto *baseInstanceVar = + getBaseInstanceVariable(stageVarData.semantic, stageVarData.sigPoint); // SPIR-V code for 'SV_InstanceID = gl_InstanceIndex - gl_BaseInstance' varInstr = getInstanceIdFromIndexAndBase(varInstr, baseInstanceVar); @@ -3284,17 +3319,15 @@ SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( // We have semantics attached to this decl, which means it must be a // function/parameter/variable. All are DeclaratorDecls. - stageVarInstructions[cast(decl)] = varInstr; + stageVarInstructions[cast(stageVarData.decl)] = varInstr; return varInstr; } QualType DeclResultIdMapper::getTypeForSpirvStageVariable( - QualType type, hlsl::Semantic::Kind semanticKind, - hlsl::SigPoint::Kind sigPointKind, const NamedDecl *decl, - uint32_t arraySize) { - QualType evalType = type; - switch (semanticKind) { + const StageVarDataBundle &stageVarData) { + QualType evalType = stageVarData.type; + switch (stageVarData.semantic->getKind()) { case hlsl::Semantic::Kind::DomainLocation: // SV_DomainLocation can refer to a float2, whereas TessCoord is a float3. // To ensure SPIR-V validity, we must create a float3 and extract a @@ -3338,7 +3371,9 @@ QualType DeclResultIdMapper::getTypeForSpirvStageVariable( // (GlobalInvocationId, LocalInvocationId, WorkgroupId) must be a uint3. // Keep the original integer signedness evalType = astContext.getExtVectorType( - hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecElementType(type) : type, + hlsl::IsHLSLVecType(stageVarData.type) + ? hlsl::GetHLSLVecElementType(stageVarData.type) + : stageVarData.type, 3); break; default: @@ -3348,33 +3383,34 @@ QualType DeclResultIdMapper::getTypeForSpirvStageVariable( // Boolean stage I/O variables must be represented as unsigned integers. // Boolean built-in variables are represented as bool. - if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) { - evalType = getUintTypeWithSourceComponents(astContext, type); + if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type, + stageVarData.semantic->getKind(), + stageVarData.sigPoint->GetKind())) { + evalType = getUintTypeWithSourceComponents(astContext, stageVarData.type); } // Handle the extra arrayness - if (arraySize != 0) { + if (stageVarData.arraySize != 0) { evalType = astContext.getConstantArrayType( - evalType, llvm::APInt(32, arraySize), clang::ArrayType::Normal, 0); + evalType, llvm::APInt(32, stageVarData.arraySize), + clang::ArrayType::Normal, 0); } return evalType; } -bool DeclResultIdMapper::createStageVars( - const hlsl::SigPoint *sigPoint, const NamedDecl *decl, bool asInput, - QualType type, uint32_t arraySize, const llvm::StringRef namePrefix, - llvm::Optional invocationId, SpirvInstruction **value, - bool noWriteBack, SemanticInfo *inheritSemantic, bool asNoInterp) { +bool DeclResultIdMapper::createStageVars(StageVarDataBundle &stageVarData, + bool asInput, SpirvInstruction **value, + bool noWriteBack) { assert(value); // invocationId should only be used for handling HS per-vertex output. - if (invocationId.hasValue()) { - assert(spvContext.isHS() && arraySize != 0 && !asInput); + if (stageVarData.invocationId.hasValue()) { + assert(spvContext.isHS() && stageVarData.arraySize != 0 && !asInput); } - assert(inheritSemantic); + assert(stageVarData.semantic); - if (type->isVoidType()) { + if (stageVarData.type->isVoidType()) { // No stage variables will be created for void type. return true; } @@ -3391,57 +3427,55 @@ bool DeclResultIdMapper::createStageVars( // should have semantics attached; // * If the current decl is not a struct, it should have semantic attached. - auto thisSemantic = getStageVarSemantic(decl); + auto thisSemantic = getStageVarSemantic(stageVarData.decl); // Which semantic we should use for this decl - auto *semanticToUse = &thisSemantic; - // Enclosing semantics override internal ones - if (inheritSemantic->isValid()) { + if (stageVarData.semantic->isValid()) { if (thisSemantic.isValid()) { emitWarning( "internal semantic '%0' overridden by enclosing semantic '%1'", thisSemantic.loc) - << thisSemantic.str << inheritSemantic->str; + << thisSemantic.str << stageVarData.semantic->str; } - semanticToUse = inheritSemantic; + } else { + stageVarData.semantic = &thisSemantic; } - const auto loc = decl->getLocation(); - if (semanticToUse->isValid() && + if (stageVarData.semantic->isValid() && // Structs with attached semantics will be handled later. - !type->isStructureType()) { + !stageVarData.type->isStructureType()) { // Found semantic attached directly to this Decl. This means we need to // map this decl to a single stage variable. - const auto semanticKind = semanticToUse->getKind(); - const auto sigPointKind = sigPoint->GetKind(); + const auto semanticKind = stageVarData.semantic->getKind(); + const auto sigPointKind = stageVarData.sigPoint->GetKind(); - if (!validateShaderStageVar(semanticToUse, sigPoint, decl, type)) { + if (!validateShaderStageVar(stageVarData)) { return false; } // Special handling of certain mappings between HLSL semantics and // SPIR-V builtins: // * SV_CullDistance/SV_ClipDistance are outsourced to GlPerVertex. - if (glPerVertex.tryToAccess(sigPointKind, semanticKind, - semanticToUse->index, invocationId, value, - noWriteBack, /*vecComponent=*/nullptr, loc)) + if (glPerVertex.tryToAccess( + sigPointKind, semanticKind, stageVarData.semantic->index, + stageVarData.invocationId, value, noWriteBack, + /*vecComponent=*/nullptr, stageVarData.decl->getLocation())) return true; - SpirvVariable *varInstr = createSpirvInterfaceVariable( - sigPoint, semanticToUse, decl, type, arraySize, asNoInterp, namePrefix); + SpirvVariable *varInstr = createSpirvInterfaceVariable(stageVarData); if (!varInstr) { return false; } // Mark that we have used one index for this semantic - ++semanticToUse->index; + ++stageVarData.semantic->index; if (asInput) { - *value = loadShaderInputVariable(varInstr, semanticKind, sigPointKind, - type, decl, loc); - if ((decl->hasAttr() || asNoInterp) && + *value = loadShaderInputVariable(varInstr, stageVarData); + if ((stageVarData.decl->hasAttr() || + stageVarData.asNoInterp) && sigPointKind == hlsl::SigPoint::Kind::PSIn) spvBuilder.addPerVertexStgInputFuncVarEntry(varInstr, *value); @@ -3451,8 +3485,7 @@ bool DeclResultIdMapper::createStageVars( // Negate SV_Position.y if requested if (semanticKind == hlsl::Semantic::Kind::Position) *value = invertYIfRequested(*value, thisSemantic.loc); - storeToShaderOutputVariable(varInstr, semanticKind, type, *value, - invocationId, sigPointKind, decl, loc); + storeToShaderOutputVariable(varInstr, *value, stageVarData); } return true; @@ -3461,22 +3494,20 @@ bool DeclResultIdMapper::createStageVars( // If the decl itself doesn't have semantic string attached and there is no // one to inherit, it should be a struct having all its fields with semantic // strings. - if (!semanticToUse->isValid() && !type->isStructureType()) { + if (!stageVarData.semantic->isValid() && + !stageVarData.type->isStructureType()) { emitError("semantic string missing for shader %select{output|input}0 " "variable '%1'", - loc) - << asInput << decl->getName(); + stageVarData.decl->getLocation()) + << asInput << stageVarData.decl->getName(); return false; } if (asInput) { - *value = createStructInputVar(type, sigPoint, arraySize, namePrefix, - asNoInterp, noWriteBack, semanticToUse, loc); + *value = createStructInputVar(stageVarData, noWriteBack); return (*value) != nullptr; } else { - return createStructOutputVar(type, sigPoint, arraySize, invocationId, - namePrefix, semanticToUse, noWriteBack, - asNoInterp, *value, loc); + return createStructOutputVar(stageVarData, *value, noWriteBack); } } diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.h b/tools/clang/lib/SPIRV/DeclResultIdMapper.h index fc239b84fa..45c54eeb41 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.h +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.h @@ -381,6 +381,46 @@ class DeclResultIdMapper { int indexInCTBuffer; }; + /// The struct containing the data needed to create the input and output + /// variables for the decl. + struct StageVarDataBundle { + public: + StageVarDataBundle() = default; + // The declaration of the variable for which we need to create the stage + // variables. + const NamedDecl *decl; + + // The HLSL semantic to apply to the variable. Note that this could be + // different than the semantic attached to decl because it could inherit + // the semantic from the parent declaration if this declaration is a member. + SemanticInfo *semantic; + + // True if the variable is not suppose to be interpolated. Note that we + // cannot just look at decl to determine this because the attribute might + // have been applied to a parent declaration. + bool asNoInterp; + + // The sigPoint is the shader stage that this variable should be added to, + // and whether it is an input or output. + const hlsl::SigPoint *sigPoint; + + // The type to use for the new variable. There are cases where the type + // might be different. See the call sites for createStageVars. + QualType type; + + // If the shader stage for the variable is HS, DS, or GS, the SPIR-V + // requires that the stage variable is an array of type. The arraySize gives + // the size for that array. + uint32_t arraySize; + + // A prefix to use for the name of the variable. + llvm::StringRef namePrefix; + + // If arraySize is not zero, invocationId gives the index to used when + // generating a write to the stage variable. + llvm::Optional invocationId; + }; + /// \brief Returns the SPIR-V information for the given decl. /// Returns nullptr if no such decl was previously registered. const DeclSpirvInfo *getDeclSpirvInfo(const ValueDecl *decl) const; @@ -621,113 +661,75 @@ class DeclResultIdMapper { const DeclContext *decl, int arraySize, ContextUsageKind usageKind, llvm::StringRef typeName, llvm::StringRef varName); - /// Creates all the stage variables mapped from semantics on the given decl. - /// Returns true on success. + /// Creates all of the stage variables that must be generated for the given + /// stage variable data. Returns true on success. /// - /// If decl is of struct type, this means flattening it and create stand- - /// alone variables for each field. If arraySize is not zero, the created - /// stage variables will be arrays of the original type and the given size. - /// This is for supporting HS/DS/GS, which takes in primitives containing - /// multiple vertices. asType should be the type we are treating decl as; - /// For HS/DS/GS, the outermost arrayness should be discarded and use - /// arraySize instead. + /// stageVarData: See the definition of StageVarDataBundle to see how that + /// data is used. /// - /// Also performs reading the stage variables and compose a temporary value - /// of the given type and writing into *value, if asInput is true. Otherwise, - /// Decomposes the *value according to type and writes back into the stage - /// output variables, unless noWriteBack is set to true. noWriteBack is used - /// by GS since in GS we manually control write back using .Append() method. + /// asInput: True if the stage variable is an input. /// - /// invocationId is only used for HS to indicate the index of the output - /// array element to write to. + /// TODO(s-perron): a variable that is an input or an output depending on + /// value of a flag is very hard to read. This function should be split up + /// and flag variables removed. /// - /// Assumes the decl has semantic attached to itself or to its fields. - /// If inheritSemantic is valid, it will override all semantics attached to - /// the children of this decl, and the children of this decl will be using - /// the semantic in inheritSemantic, with index increasing sequentially. - bool createStageVars(const hlsl::SigPoint *sigPoint, const NamedDecl *decl, - bool asInput, QualType asType, uint32_t arraySize, - const llvm::StringRef namePrefix, - llvm::Optional invocationId, - SpirvInstruction **value, bool noWriteBack, - SemanticInfo *inheritSemantic, bool asNoInterp = false); + /// [in/out] value: If `asInput` is true, this is an + /// output, and will be an instruction that loads the stage variable. If + /// `asInput` is false, then it is an input to createStageVars, and contains + /// the value to be stored in the new stage variable. + /// + /// noWriteBack: If true, the newly created stage variable will not be written + /// to. + bool createStageVars(StageVarDataBundle &stageVarData, bool asInput, + SpirvInstruction **value, bool noWriteBack); // Creates a variable to represent the output variable, which must be a // structure. If `noWriteBack` is false, then `value` will be written to the // new variable. Returns true if successful. // - // TODO(s-perron): All of the parameters associated with the HLSL - // representation of the variable should be grouped in a struct. This struct - // should be used in all of the function below to reduce the number of - // parameters. + // stageVarData: The data needed to create the stage variable. + // + // noWriteBack: A flag to indicate if the variable should be written or not. // - // type: The type of the variable in the HLSL. - // sigPoint: The signature point for the variable. - // arraySize: The size of the array when the SPIR-V variable must become an - // array of type, 0 otherwise. invocationId: when arraySize is not 0, - // invocationId is the index at which `value` should be written. namePrefix: a - // prefix to apply to the name of the variable. semantic: the HLSL semantic - // for this variable. noWriteBack: A flag to indicate if the variable should - // be written or not. asNoInterp: True if the variable is not to be - // interpolated. value: The value to be written to the newly create variable. - // loc: The location of the variable in the source code. - bool createStructOutputVar(QualType type, const hlsl::SigPoint *sigPoint, - uint32_t arraySize, - llvm::Optional invocationId, - const llvm::StringRef namePrefix, - SemanticInfo *semantic, bool noWriteBack, - bool asNoInterp, SpirvInstruction *value, - SourceLocation loc); + // value: The value to be written to the newly create variable. + bool createStructOutputVar(const StageVarDataBundle &stageVarData, + SpirvInstruction *value, bool noWriteBack); // Creates a variable to represent the input variable, which must be a // structure. The value is loaded and the instruction with the final value is // return. // - // type: The type of the variable in the HLSL. - // sigPoint: The signature point for the variable. - // arraySize: The size of the array when the SPIR-V variable must become an - // array of type, 0 otherwise. namePrefix: a prefix to apply to the name of - // the variable. asNoInterp: True if the variable is not to be interpolated. + // stageVarData: The data needed to create the stage variable. + // // noWriteBack: A flag to indicate if the variable should be written or not. - // semantic: the HLSL semantic for this variable. - // loc: The location of the variable in the source code. - SpirvInstruction * - createStructInputVar(QualType type, const hlsl::SigPoint *sigPoint, - uint32_t arraySize, const llvm::StringRef namePrefix, - bool asNoInterp, bool noWriteBack, - SemanticInfo *semantic, SourceLocation loc); + SpirvInstruction *createStructInputVar(const StageVarDataBundle &stageVarData, + bool noWriteBack); // Store `value` to the shader output variable `varInstr`. Since the type - // could be different, other data is used to know how to convert `value` into - // the correct type for `varInstr`. + // could be different, stageVarData is used to know how to convert `value` + // into the correct type for `varInstr`. + // + // varInstr: the output variable that corresponds to `stageVarData`. It must + // not be a struct. // - // varInstr: the output variable that corresponds to `decl`. It must not be a - // struct. semanticKind: the HLSL semantic for this variable. type: The type - // of the variable in the HLSL. value: The value to be written to the create - // variable. invocationId: when arraySize is not 0, invocationId is the index - // at which `value` should be written. sigPointKind: The signature point for - // the variable. decl: The original declaration of the variable. loc: The - // location of the variable in the source code. - void storeToShaderOutputVariable( - SpirvVariable *varInstr, hlsl::Semantic::Kind semanticKind, QualType type, - SpirvInstruction *value, llvm::Optional invocationId, - hlsl::SigPoint::Kind sigPointKind, const NamedDecl *decl, - SourceLocation loc); - - // Loads shader input variable `varInstr`, and modifies it to match `type`. - // Other data is used to know how to convert `varInstr` into the correct type. + // value: The value to be written to the create variable. // - // varInstr: the input variable that corresponds to `decl`. It must not be a - // struct. semanticKind: the HLSL semantic for this variable. sigPointKind: - // The signature point for the variable. type: The type of the variable in the - // HLSL. decl: The original declaration of the variable. loc: The location of - // the variable in the source code. - SpirvInstruction *loadShaderInputVariable(SpirvVariable *varInstr, - hlsl::Semantic::Kind semanticKind, - hlsl::SigPoint::Kind sigPointKind, - QualType type, - const NamedDecl *decl, - SourceLocation loc); + // stageVarData: The data that was used to create `varInstr`. + void storeToShaderOutputVariable(SpirvVariable *varInstr, + SpirvInstruction *value, + const StageVarDataBundle &stageVarData); + + // Loads shader input variable `varInstr`, and modifies the value to match the + // type in stageVarData. The struct stageVarData is used to know how to + // convert the value loaded from `varInstr` into the correct type. + // + // varInstr: the input variable that corresponds to `stageVarData`. It must + // not be a struct. + // + // stageVarData: The data that was used to create `varInstr`. + SpirvInstruction * + loadShaderInputVariable(SpirvVariable *varInstr, + const StageVarDataBundle &stageVarData); // Creates a function scope variable to represent the "SV_InstanceID" // semantic, which it not immediately available in SPIR-V. Its value will be @@ -735,8 +737,10 @@ class DeclResultIdMapper { // variables. // // instanceIndexVar: The SPIR-V input variable that decorated with - // InstanceIndex baseInstanceVar: The SPIR-V input variable that is decorated - // with BaseInstance. + // InstanceIndex. + // + // baseInstanceVar: The SPIR-V input variable that is decorated with + // BaseInstance. SpirvVariable *getInstanceIdFromIndexAndBase(SpirvVariable *instanceIndexVar, SpirvVariable *baseInstanceVar); @@ -750,6 +754,7 @@ class DeclResultIdMapper { // with reflection. // // semantic: the semantic to attach to this variable + // // sigPoint: the signature point identifying which shader stage the variable // will be used in. SpirvVariable *getBaseInstanceVariable(SemanticInfo *semantic, @@ -759,57 +764,34 @@ class DeclResultIdMapper { // The new variable with be add to `this->StageVars`. // // - // sigPoint: the signature point identifying which shader stage the variable - // will be used in. semantic: the HLSL semantic for the variable. decl: the - // declaration of the variable. type: the HLSL type of the variable arraySize: - // the size of the array if `type` should be turned into an array. 0 if `type` - // should not be turned into and array. - SpirvVariable *createSpirvInterfaceVariable(const hlsl::SigPoint *sigPoint, - SemanticInfo *semantic, - const NamedDecl *decl, - QualType type, uint32_t arraySize, - bool asNoInterp, - const llvm::StringRef namePrefix); + // stageVarData: the data needed to create the interface variable. See the + // declaration of StageVarDataBundle for the details. + SpirvVariable * + createSpirvInterfaceVariable(const StageVarDataBundle &stageVarData); // Returns the type that the SPIR-V input or output variable must have to // correspond to a variable with the given information. // - // type: The type of the variable in HLSL. - // semanticKind: The semantic attached to the variable. - // sigPointKind: The signature point (shader, input or output) of the - // variable. decl: The declaration of the variable. arraySize: The size of the - // array the variable should be converted to. 0 if it should not be turned - // into an array. - QualType getTypeForSpirvStageVariable(QualType type, - hlsl::Semantic::Kind semanticKind, - hlsl::SigPoint::Kind sigPointKind, - const NamedDecl *decl, - uint32_t arraySize); - - // Returns true if all of the information is consistent with a valid shader - // stage variable. Issues an error and returns false otherwise. - bool validateShaderStageVar(SemanticInfo *semantic, - const hlsl::SigPoint *sigPoint, - const NamedDecl *decl, QualType type); + // stageVarData: the data needed to create the interface variable. See the + // declaration of StageVarDataBundle for the details. + QualType getTypeForSpirvStageVariable(const StageVarDataBundle &stageVarData); + + // Returns true if all of the stage variable data is consistent with a valid + // shader stage variable. Issues an error and returns false otherwise. + bool validateShaderStageVar(const StageVarDataBundle &stageVarData); /// Returns true if all vk:: attributes usages are valid. bool validateVKAttributes(const NamedDecl *decl); /// Returns true if all vk::builtin usages are valid. - bool validateVKBuiltins(const NamedDecl *decl, - const hlsl::SigPoint *sigPoint); + bool validateVKBuiltins(const StageVarDataBundle &stageVarData); - // Returns true if the type and semantic kind are compatible. Issues an error - // and returns false otherwise. - bool validateShaderStageVarType(hlsl::Semantic::Kind semanticKind, - QualType type, clang::SourceLocation loc); + // Returns true if the type in stageVarData is compatible with the rest of the + // data. Issues an error and returns false otherwise. + bool validateShaderStageVarType(const StageVarDataBundle &stageVarData); - // Returns true if the semantic can be used with the given signature point. - // The declaration is used to get the builtin attribute to augment the - // semantic. - bool isValidSemanticInShaderModel(hlsl::Semantic::Kind semanticKind, - hlsl::SigPoint::Kind sigPointKind, - const NamedDecl *decl); + // Returns true if the semantic is consistent wit the rest of the given data. + bool isValidSemanticInShaderModel(const StageVarDataBundle &stageVarData); /// Creates the SPIR-V variable instruction for the given StageVar and returns /// the instruction. Also sets whether the StageVar is a SPIR-V builtin and @@ -824,6 +806,7 @@ class DeclResultIdMapper { /// Creates assoicated counter variables for all AssocCounter cases (see the /// comment of CounterVarFields). void createCounterVarForDecl(const DeclaratorDecl *decl); + /// Creates the associated counter variable for final RW/Append/Consume /// structured buffer. Handles AssocCounter#1 and AssocCounter#2 (see the /// comment of CounterVarFields).