diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp index bb3baecf19..5db014ca95 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp @@ -880,9 +880,11 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl, const bool noWriteBack = storedValue == nullptr || spvContext.isGS() || spvContext.isMS(); - return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize, - "out.var", llvm::None, &storedValue, noWriteBack, - &inheritSemantic); + StageVarDataBundle stageVarData = { + decl, &inheritSemantic, false, sigPoint, + type, arraySize, "out.var", llvm::None}; + return createStageVars(stageVarData, /*asInput=*/false, &storedValue, + noWriteBack); } bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl, @@ -898,9 +900,11 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl, SemanticInfo inheritSemantic = {}; - return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize, - "out.var", invocationId, &storedValue, - /*noWriteBack=*/false, &inheritSemantic); + StageVarDataBundle stageVarData = { + decl, &inheritSemantic, false, sigPoint, + type, arraySize, "out.var", invocationId}; + return createStageVars(stageVarData, /*asInput=*/false, &storedValue, + /*noWriteBack=*/false); } bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl, @@ -938,9 +942,11 @@ bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl, return createPayloadStageVars(sigPoint, sc, paramDecl, /*asInput=*/true, type, "in.var", loadedValue); } else { - return createStageVars(sigPoint, paramDecl, /*asInput=*/true, type, - arraySize, "in.var", llvm::None, loadedValue, - /*noWriteBack=*/false, &inheritSemantic); + StageVarDataBundle stageVarData = { + paramDecl, &inheritSemantic, false, sigPoint, + type, arraySize, "in.var", llvm::None}; + return createStageVars(stageVarData, /*asInput=*/true, loadedValue, + /*noWriteBack=*/false); } } @@ -2602,22 +2608,22 @@ bool DeclResultIdMapper::decorateResourceCoherent() { } bool DeclResultIdMapper::createStructOutputVar( - QualType type, const hlsl::SigPoint *sigPoint, uint32_t arraySize, - llvm::Optional invocationId, - const llvm::StringRef namePrefix, SemanticInfo *semantic, bool noWriteBack, - bool asNoInterp, SpirvInstruction *value, SourceLocation loc) { + const StageVarDataBundle &stageVarData, SpirvInstruction *value, + bool noWriteBack) { // If we have base classes, we need to handle them first. - if (const auto *cxxDecl = type->getAsCXXRecordDecl()) { + if (const auto *cxxDecl = stageVarData.type->getAsCXXRecordDecl()) { uint32_t baseIndex = 0; for (auto base : cxxDecl->bases()) { SpirvInstruction *subValue = nullptr; if (!noWriteBack) - subValue = spvBuilder.createCompositeExtract(base.getType(), value, - {baseIndex++}, loc); - - if (!createStageVars(sigPoint, base.getType()->getAsCXXRecordDecl(), - false, base.getType(), arraySize, namePrefix, - invocationId, &subValue, noWriteBack, semantic)) + subValue = spvBuilder.createCompositeExtract( + base.getType(), value, {baseIndex++}, + stageVarData.decl->getLocation()); + + StageVarDataBundle memberVarData = stageVarData; + memberVarData.decl = base.getType()->getAsCXXRecordDecl(); + memberVarData.type = base.getType(); + if (!createStageVars(memberVarData, false, &subValue, noWriteBack)) return false; } } @@ -2636,64 +2642,65 @@ bool DeclResultIdMapper::createStructOutputVar( // // The interesting shader stage is HS. We need the InvocationID to write // out the value to the correct array element. - const auto *structDecl = type->getAs()->getDecl(); + const auto *structDecl = stageVarData.type->getAs()->getDecl(); for (const auto *field : structDecl->fields()) { const auto fieldType = field->getType(); SpirvInstruction *subValue = nullptr; if (!noWriteBack) { subValue = spvBuilder.createCompositeExtract( - fieldType, value, {getNumBaseClasses(type) + field->getFieldIndex()}, - loc); + fieldType, value, + {getNumBaseClasses(stageVarData.type) + field->getFieldIndex()}, + stageVarData.decl->getLocation()); if (field->hasAttr() || structDecl->hasAttr()) subValue->setNoninterpolated(); } - if (!createStageVars( - sigPoint, field, false, field->getType(), arraySize, namePrefix, - invocationId, &subValue, noWriteBack, semantic, - asNoInterp || field->hasAttr())) + StageVarDataBundle memberVarData = stageVarData; + memberVarData.decl = field; + memberVarData.type = field->getType(); + memberVarData.asNoInterp |= field->hasAttr(); + if (!createStageVars(memberVarData, false, &subValue, noWriteBack)) return false; } return true; } -SpirvInstruction *DeclResultIdMapper::createStructInputVar( - QualType type, const hlsl::SigPoint *sigPoint, uint32_t arraySize, - const llvm::StringRef namePrefix, bool asNoInterp, bool noWriteBack, - SemanticInfo *semanticToUse, SourceLocation loc) { +SpirvInstruction * +DeclResultIdMapper::createStructInputVar(const StageVarDataBundle &stageVarData, + bool noWriteBack) { // If this decl translates into multiple stage input variables, we need to // load their values into a composite. llvm::SmallVector subValues; // If we have base classes, we need to handle them first. - if (const auto *cxxDecl = type->getAsCXXRecordDecl()) { + if (const auto *cxxDecl = stageVarData.type->getAsCXXRecordDecl()) { for (auto base : cxxDecl->bases()) { SpirvInstruction *subValue = nullptr; - if (!createStageVars(sigPoint, base.getType()->getAsCXXRecordDecl(), true, - base.getType(), arraySize, namePrefix, - llvm::Optional(), &subValue, - noWriteBack, semanticToUse)) + StageVarDataBundle memberVarData = stageVarData; + memberVarData.decl = base.getType()->getAsCXXRecordDecl(); + memberVarData.type = base.getType(); + if (!createStageVars(memberVarData, true, &subValue, noWriteBack)) return nullptr; subValues.push_back(subValue); } } - const auto *structDecl = type->getAs()->getDecl(); + const auto *structDecl = stageVarData.type->getAs()->getDecl(); for (const auto *field : structDecl->fields()) { SpirvInstruction *subValue = nullptr; - if (!createStageVars(sigPoint, field, true, field->getType(), arraySize, - namePrefix, llvm::Optional(), - &subValue, noWriteBack, semanticToUse, - asNoInterp || - field->hasAttr())) + StageVarDataBundle memberVarData = stageVarData; + memberVarData.decl = field; + memberVarData.type = field->getType(); + memberVarData.asNoInterp |= field->hasAttr(); + if (!createStageVars(memberVarData, true, &subValue, noWriteBack)) return nullptr; subValues.push_back(subValue); } - if (arraySize == 0) { - SpirvInstruction *value = - spvBuilder.createCompositeConstruct(type, subValues, loc); + if (stageVarData.arraySize == 0) { + SpirvInstruction *value = spvBuilder.createCompositeConstruct( + stageVarData.type, subValues, stageVarData.decl->getLocation()); for (auto *subInstr : subValues) spvBuilder.addPerVertexStgInputFuncVarEntry(subInstr, value); return value; @@ -2705,22 +2712,25 @@ SpirvInstruction *DeclResultIdMapper::createStructInputVar( // from visiting all fields. So now we need to extract all the elements // at the same index of each field arrays and compose a new struct out // of them. - const auto structType = type; + const auto structType = stageVarData.type; const auto arrayType = astContext.getConstantArrayType( - structType, llvm::APInt(32, arraySize), clang::ArrayType::Normal, 0); + structType, llvm::APInt(32, stageVarData.arraySize), + clang::ArrayType::Normal, 0); llvm::SmallVector arrayElements; - for (uint32_t arrayIndex = 0; arrayIndex < arraySize; ++arrayIndex) { + for (uint32_t arrayIndex = 0; arrayIndex < stageVarData.arraySize; + ++arrayIndex) { llvm::SmallVector fields; // If we have base classes, we need to handle them first. - if (const auto *cxxDecl = type->getAsCXXRecordDecl()) { + if (const auto *cxxDecl = stageVarData.type->getAsCXXRecordDecl()) { uint32_t baseIndex = 0; for (auto base : cxxDecl->bases()) { const auto baseType = base.getType(); fields.push_back(spvBuilder.createCompositeExtract( - baseType, subValues[baseIndex++], {arrayIndex}, loc)); + baseType, subValues[baseIndex++], {arrayIndex}, + stageVarData.decl->getLocation())); } } @@ -2729,41 +2739,42 @@ SpirvInstruction *DeclResultIdMapper::createStructInputVar( const auto fieldType = field->getType(); fields.push_back(spvBuilder.createCompositeExtract( fieldType, - subValues[getNumBaseClasses(type) + field->getFieldIndex()], - {arrayIndex}, loc)); + subValues[getNumBaseClasses(stageVarData.type) + + field->getFieldIndex()], + {arrayIndex}, stageVarData.decl->getLocation())); } // Compose a new struct out of them - arrayElements.push_back( - spvBuilder.createCompositeConstruct(structType, fields, loc)); + arrayElements.push_back(spvBuilder.createCompositeConstruct( + structType, fields, stageVarData.decl->getLocation())); } - return spvBuilder.createCompositeConstruct(arrayType, arrayElements, loc); + return spvBuilder.createCompositeConstruct(arrayType, arrayElements, + stageVarData.decl->getLocation()); } void DeclResultIdMapper::storeToShaderOutputVariable( - SpirvVariable *varInstr, hlsl::Semantic::Kind semanticKind, QualType type, - SpirvInstruction *value, llvm::Optional invocationId, - hlsl::SigPoint::Kind sigPointKind, const NamedDecl *decl, - SourceLocation loc) { + SpirvVariable *varInstr, SpirvInstruction *value, + const StageVarDataBundle &stageVarData) { SpirvInstruction *ptr = varInstr; // Special handling of SV_TessFactor HS patch constant output. // TessLevelOuter is always an array of size 4 in SPIR-V, but // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the // relevant indexes must be written to. - if (semanticKind == hlsl::Semantic::Kind::TessFactor && - hlsl::GetArraySize(type) != 4) { - const auto tessFactorSize = hlsl::GetArraySize(type); + if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::TessFactor && + hlsl::GetArraySize(stageVarData.type) != 4) { + const auto tessFactorSize = hlsl::GetArraySize(stageVarData.type); for (uint32_t i = 0; i < tessFactorSize; ++i) { ptr = spvBuilder.createAccessChain( astContext.FloatTy, varInstr, {spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, i))}, - loc); - spvBuilder.createStore(ptr, - spvBuilder.createCompositeExtract( - astContext.FloatTy, value, {i}, loc), - loc); + stageVarData.decl->getLocation()); + spvBuilder.createStore( + ptr, + spvBuilder.createCompositeExtract(astContext.FloatTy, value, {i}, + stageVarData.decl->getLocation()), + stageVarData.decl->getLocation()); } } // Special handling of SV_InsideTessFactor HS patch constant output. @@ -2771,110 +2782,121 @@ void DeclResultIdMapper::storeToShaderOutputVariable( // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in // HLSL. If SV_InsideTessFactor is a scalar, only write to index 0 of // TessLevelInner. - else if (semanticKind == hlsl::Semantic::Kind::InsideTessFactor && + else if (stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::InsideTessFactor && // Some developers use float[1] instead of a scalar float. - (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) { + (!stageVarData.type->isArrayType() || + hlsl::GetArraySize(stageVarData.type) == 1)) { ptr = spvBuilder.createAccessChain( astContext.FloatTy, varInstr, spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)), - loc); - if (type->isArrayType()) // float[1] - value = spvBuilder.createCompositeExtract(astContext.FloatTy, value, {0}, - loc); - spvBuilder.createStore(ptr, value, loc); + stageVarData.decl->getLocation()); + if (stageVarData.type->isArrayType()) // float[1] + value = spvBuilder.createCompositeExtract( + astContext.FloatTy, value, {0}, stageVarData.decl->getLocation()); + spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); } // Special handling of SV_Coverage, which is an unit value. We need to // write it to the first element in the SampleMask builtin. - else if (semanticKind == hlsl::Semantic::Kind::Coverage) { + else if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::Coverage) { ptr = spvBuilder.createAccessChain( - type, varInstr, + stageVarData.type, varInstr, spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)), - loc); + stageVarData.decl->getLocation()); ptr->setStorageClass(spv::StorageClass::Output); - spvBuilder.createStore(ptr, value, loc); + spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); } // Special handling of HS ouput, for which we write to only one // element in the per-vertex data array: the one indexed by // SV_ControlPointID. - else if (invocationId.hasValue() && invocationId.getValue() != nullptr) { + else if (stageVarData.invocationId.hasValue() && + stageVarData.invocationId.getValue() != nullptr) { // Remove the arrayness to get the element type. assert(isa(varInstr->getAstResultType())); const auto elementType = astContext.getAsArrayType(varInstr->getAstResultType()) ->getElementType(); - auto index = invocationId.getValue(); - ptr = spvBuilder.createAccessChain(elementType, varInstr, index, loc); + auto index = stageVarData.invocationId.getValue(); + ptr = spvBuilder.createAccessChain(elementType, varInstr, index, + stageVarData.decl->getLocation()); ptr->setStorageClass(spv::StorageClass::Output); - spvBuilder.createStore(ptr, value, loc); + spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); } // Since boolean output stage variables are represented as unsigned // integers, we must cast the value to uint before storing. - else if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) { - value = - theEmitter.castToType(value, type, varInstr->getAstResultType(), loc); - spvBuilder.createStore(ptr, value, loc); + else if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type, + stageVarData.semantic->getKind(), + stageVarData.sigPoint->GetKind())) { + value = theEmitter.castToType(value, stageVarData.type, + varInstr->getAstResultType(), + stageVarData.decl->getLocation()); + spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); } // For all normal cases else { - spvBuilder.createStore(ptr, value, loc); + spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); } } SpirvInstruction *DeclResultIdMapper::loadShaderInputVariable( - SpirvVariable *varInstr, hlsl::Semantic::Kind semanticKind, - hlsl::SigPoint::Kind sigPointKind, QualType type, const NamedDecl *decl, - SourceLocation loc) { - SpirvInstruction *load = - spvBuilder.createLoad(varInstr->getAstResultType(), varInstr, loc); + SpirvVariable *varInstr, const StageVarDataBundle &stageVarData) { + SpirvInstruction *load = spvBuilder.createLoad( + varInstr->getAstResultType(), varInstr, stageVarData.decl->getLocation()); // Fix ups for corner cases // Special handling of SV_TessFactor DS patch constant input. // TessLevelOuter is always an array of size 4 in SPIR-V, but // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the // relevant indexes must be loaded. - if (semanticKind == hlsl::Semantic::Kind::TessFactor && - hlsl::GetArraySize(type) != 4) { + if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::TessFactor && + hlsl::GetArraySize(stageVarData.type) != 4) { llvm::SmallVector components; - const auto tessFactorSize = hlsl::GetArraySize(type); + const auto tessFactorSize = hlsl::GetArraySize(stageVarData.type); const auto arrType = astContext.getConstantArrayType( astContext.FloatTy, llvm::APInt(32, tessFactorSize), clang::ArrayType::Normal, 0); for (uint32_t i = 0; i < tessFactorSize; ++i) - components.push_back(spvBuilder.createCompositeExtract(astContext.FloatTy, - load, {i}, loc)); - load = spvBuilder.createCompositeConstruct(arrType, components, loc); + components.push_back(spvBuilder.createCompositeExtract( + astContext.FloatTy, load, {i}, stageVarData.decl->getLocation())); + load = spvBuilder.createCompositeConstruct( + arrType, components, stageVarData.decl->getLocation()); } // Special handling of SV_InsideTessFactor DS patch constant input. // TessLevelInner is always an array of size 2 in SPIR-V, but // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in // HLSL. If SV_InsideTessFactor is a scalar, only extract index 0 of // TessLevelInner. - else if (semanticKind == hlsl::Semantic::Kind::InsideTessFactor && + else if (stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::InsideTessFactor && // Some developers use float[1] instead of a scalar float. - (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) { - load = - spvBuilder.createCompositeExtract(astContext.FloatTy, load, {0}, loc); - if (type->isArrayType()) { // float[1] + (!stageVarData.type->isArrayType() || + hlsl::GetArraySize(stageVarData.type) == 1)) { + load = spvBuilder.createCompositeExtract(astContext.FloatTy, load, {0}, + stageVarData.decl->getLocation()); + if (stageVarData.type->isArrayType()) { // float[1] const auto arrType = astContext.getConstantArrayType( astContext.FloatTy, llvm::APInt(32, 1), clang::ArrayType::Normal, 0); - load = spvBuilder.createCompositeConstruct(arrType, {load}, loc); + load = spvBuilder.createCompositeConstruct( + arrType, {load}, stageVarData.decl->getLocation()); } } // SV_DomainLocation can refer to a float2 or a float3, whereas TessCoord // is always a float3. To ensure SPIR-V validity, a float3 stage variable // is created, and we must extract a float2 from it before passing it to // the main function. - else if (semanticKind == hlsl::Semantic::Kind::DomainLocation && - hlsl::GetHLSLVecSize(type) != 3) { - const auto domainLocSize = hlsl::GetHLSLVecSize(type); + else if (stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::DomainLocation && + hlsl::GetHLSLVecSize(stageVarData.type) != 3) { + const auto domainLocSize = hlsl::GetHLSLVecSize(stageVarData.type); load = spvBuilder.createVectorShuffle( astContext.getExtVectorType(astContext.FloatTy, domainLocSize), load, - load, {0, 1}, loc); + load, {0, 1}, stageVarData.decl->getLocation()); } // Special handling of SV_Coverage, which is an uint value. We need to // read SampleMask and extract its first element. - else if (semanticKind == hlsl::Semantic::Kind::Coverage) { - load = spvBuilder.createCompositeExtract(type, load, {0}, loc); + else if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::Coverage) { + load = spvBuilder.createCompositeExtract(stageVarData.type, load, {0}, + stageVarData.decl->getLocation()); } // Special handling of SV_InnerCoverage, which is an uint value. We need // to read FullyCoveredEXT, which is a boolean value, and convert it to an @@ -2887,85 +2909,96 @@ SpirvInstruction *DeclResultIdMapper::loadShaderInputVariable( // but are undefined when bit 0 is set to 1 (essentially, this bit-field // represents a Boolean value where false must be exactly 0, but true can // be any odd (i.e. bit 0 set) non-zero value)." - else if (semanticKind == hlsl::Semantic::Kind::InnerCoverage) { + else if (stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::InnerCoverage) { const auto constOne = spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1)); const auto constZero = spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)); load = spvBuilder.createSelect(astContext.UnsignedIntTy, load, constOne, - constZero, loc); + constZero, stageVarData.decl->getLocation()); } // Special handling of SV_Barycentrics, which is a float3, but the // The 3 values are NOT guaranteed to add up to floating-point 1.0 // exactly. Calculate the third element here. - else if (semanticKind == hlsl::Semantic::Kind::Barycentrics) { - const auto x = - spvBuilder.createCompositeExtract(astContext.FloatTy, load, {0}, loc); - const auto y = - spvBuilder.createCompositeExtract(astContext.FloatTy, load, {1}, loc); - const auto xy = spvBuilder.createBinaryOp(spv::Op::OpFAdd, - astContext.FloatTy, x, y, loc); + else if (stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::Barycentrics) { + const auto x = spvBuilder.createCompositeExtract( + astContext.FloatTy, load, {0}, stageVarData.decl->getLocation()); + const auto y = spvBuilder.createCompositeExtract( + astContext.FloatTy, load, {1}, stageVarData.decl->getLocation()); + const auto xy = + spvBuilder.createBinaryOp(spv::Op::OpFAdd, astContext.FloatTy, x, y, + stageVarData.decl->getLocation()); const auto z = spvBuilder.createBinaryOp( spv::Op::OpFSub, astContext.FloatTy, spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(1.0f)), - xy, loc); + xy, stageVarData.decl->getLocation()); load = spvBuilder.createCompositeConstruct( - astContext.getExtVectorType(astContext.FloatTy, 3), {x, y, z}, loc); + astContext.getExtVectorType(astContext.FloatTy, 3), {x, y, z}, + stageVarData.decl->getLocation()); } // Special handling of SV_DispatchThreadID and SV_GroupThreadID, which may // be a uint or uint2, but the underlying stage input variable is a uint3. // The last component(s) should be discarded in needed. - else if ((semanticKind == hlsl::Semantic::Kind::DispatchThreadID || - semanticKind == hlsl::Semantic::Kind::GroupThreadID || - semanticKind == hlsl::Semantic::Kind::GroupID) && - (!hlsl::IsHLSLVecType(type) || hlsl::GetHLSLVecSize(type) != 3)) { + else if ((stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::DispatchThreadID || + stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::GroupThreadID || + stageVarData.semantic->getKind() == + hlsl::Semantic::Kind::GroupID) && + (!hlsl::IsHLSLVecType(stageVarData.type) || + hlsl::GetHLSLVecSize(stageVarData.type) != 3)) { const auto srcVecElemType = - hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecElementType(type) : type; - const auto vecSize = - hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecSize(type) : 1; + hlsl::IsHLSLVecType(stageVarData.type) + ? hlsl::GetHLSLVecElementType(stageVarData.type) + : stageVarData.type; + const auto vecSize = hlsl::IsHLSLVecType(stageVarData.type) + ? hlsl::GetHLSLVecSize(stageVarData.type) + : 1; if (vecSize == 1) - load = spvBuilder.createCompositeExtract(srcVecElemType, load, {0}, loc); + load = spvBuilder.createCompositeExtract( + srcVecElemType, load, {0}, stageVarData.decl->getLocation()); else if (vecSize == 2) load = spvBuilder.createVectorShuffle( astContext.getExtVectorType(srcVecElemType, 2), load, load, {0, 1}, - loc); + stageVarData.decl->getLocation()); } // Reciprocate SV_Position.w if requested - if (semanticKind == hlsl::Semantic::Kind::Position) - load = invertWIfRequested(load, loc); + if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::Position) + load = invertWIfRequested(load, stageVarData.decl->getLocation()); // Since boolean stage input variables are represented as unsigned // integers, after loading them, we should cast them to boolean. - if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) { - load = theEmitter.castToType(load, varInstr->getAstResultType(), type, loc); + if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type, + stageVarData.semantic->getKind(), + stageVarData.sigPoint->GetKind())) { + load = theEmitter.castToType(load, varInstr->getAstResultType(), + stageVarData.type, + stageVarData.decl->getLocation()); } return load; } -bool DeclResultIdMapper::validateShaderStageVar(SemanticInfo *semantic, - const hlsl::SigPoint *sigPoint, - const NamedDecl *decl, - QualType type) { - const auto semanticKind = semantic->getKind(); - const auto sigPointKind = sigPoint->GetKind(); - - if (!validateVKAttributes(decl)) +bool DeclResultIdMapper::validateShaderStageVar( + const StageVarDataBundle &stageVarData) { + if (!validateVKAttributes(stageVarData.decl)) return false; - if (!isValidSemanticInShaderModel(semanticKind, sigPointKind, decl)) { + if (!isValidSemanticInShaderModel(stageVarData)) { emitError("invalid usage of semantic '%0' in shader profile %1", - decl->getLocation()) - << semantic->str + stageVarData.decl->getLocation()) + << stageVarData.semantic->str << hlsl::ShaderModel::GetKindName( spvContext.getCurrentShaderModelKind()); return false; } - if (!validateVKBuiltins(decl, sigPoint)) + if (!validateVKBuiltins(stageVarData)) return false; - if (!validateShaderStageVarType(semanticKind, type, decl->getLocation())) + if (!validateShaderStageVarType(stageVarData)) return false; return true; } @@ -3006,17 +3039,18 @@ bool DeclResultIdMapper::validateVKAttributes(const NamedDecl *decl) { return success; } -bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, - const hlsl::SigPoint *sigPoint) { +bool DeclResultIdMapper::validateVKBuiltins( + const StageVarDataBundle &stageVarData) { bool success = true; - if (const auto *builtinAttr = decl->getAttr()) { + if (const auto *builtinAttr = stageVarData.decl->getAttr()) { // The front end parsing only allows vk::builtin to be attached to a // function/parameter/variable; all of them are DeclaratorDecls. - const auto declType = getTypeOrFnRetType(cast(decl)); + const auto declType = + getTypeOrFnRetType(cast(stageVarData.decl)); const auto loc = builtinAttr->getLocation(); - if (decl->hasAttr()) { + if (stageVarData.decl->hasAttr()) { emitError("cannot use vk::builtin and vk::location together", loc); success = false; } @@ -3029,7 +3063,7 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, success = false; } - if (sigPoint->GetKind() != hlsl::SigPoint::Kind::PSIn) { + if (stageVarData.sigPoint->GetKind() != hlsl::SigPoint::Kind::PSIn) { emitError( "HelperInvocation builtin can only be used as pixel shader input", loc); @@ -3041,7 +3075,7 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, success = false; } - switch (sigPoint->GetKind()) { + switch (stageVarData.sigPoint->GetKind()) { case hlsl::SigPoint::Kind::VSOut: case hlsl::SigPoint::Kind::HSCPIn: case hlsl::SigPoint::Kind::HSCPOut: @@ -3054,7 +3088,7 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, break; default: emitError("PointSize builtin cannot be used as %0", loc) - << sigPoint->GetName(); + << stageVarData.sigPoint->GetName(); success = false; } } else if (builtin == "BaseVertex" || builtin == "BaseInstance" || @@ -3066,24 +3100,25 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, success = false; } - switch (sigPoint->GetKind()) { + switch (stageVarData.sigPoint->GetKind()) { case hlsl::SigPoint::Kind::VSIn: break; case hlsl::SigPoint::Kind::MSIn: case hlsl::SigPoint::Kind::ASIn: if (builtin != "DrawIndex") { emitError("%0 builtin cannot be used as %1", loc) - << builtin << sigPoint->GetName(); + << builtin << stageVarData.sigPoint->GetName(); success = false; } break; default: emitError("%0 builtin cannot be used as %1", loc) - << builtin << sigPoint->GetName(); + << builtin << stageVarData.sigPoint->GetName(); success = false; } } else if (builtin == "DeviceIndex") { - if (getStorageClassForSigPoint(sigPoint) != spv::StorageClass::Input) { + if (getStorageClassForSigPoint(stageVarData.sigPoint) != + spv::StorageClass::Input) { emitError("%0 builtin can only be used as shader input", loc) << builtin; success = false; @@ -3095,7 +3130,7 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, success = false; } } else if (builtin == "ViewportMaskNV") { - if (sigPoint->GetKind() != hlsl::SigPoint::Kind::MSPOut) { + if (stageVarData.sigPoint->GetKind() != hlsl::SigPoint::Kind::MSPOut) { emitError("%0 builtin can only be used as 'primitives' output in MS", loc) << builtin; @@ -3115,13 +3150,13 @@ bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl, } bool DeclResultIdMapper::validateShaderStageVarType( - hlsl::Semantic::Kind semanticKind, QualType type, - clang::SourceLocation loc) { + const StageVarDataBundle &stageVarData) { - switch (semanticKind) { + switch (stageVarData.semantic->getKind()) { case hlsl::Semantic::Kind::InnerCoverage: - if (!type->isSpecificBuiltinType(BuiltinType::UInt)) { - emitError("SV_InnerCoverage must be of uint type.", loc); + if (!stageVarData.type->isSpecificBuiltinType(BuiltinType::UInt)) { + emitError("SV_InnerCoverage must be of uint type.", + stageVarData.decl->getLocation()); return false; } break; @@ -3132,18 +3167,18 @@ bool DeclResultIdMapper::validateShaderStageVarType( } bool DeclResultIdMapper::isValidSemanticInShaderModel( - hlsl::Semantic::Kind semanticKind, hlsl::SigPoint::Kind sigPointKind, - const NamedDecl *decl) { + const StageVarDataBundle &stageVarData) { // Error out when the given semantic is invalid in this shader model - if (hlsl::SigPoint::GetInterpretation(semanticKind, sigPointKind, - spvContext.getMajorVersion(), - spvContext.getMinorVersion()) == + if (hlsl::SigPoint::GetInterpretation( + stageVarData.semantic->getKind(), stageVarData.sigPoint->GetKind(), + spvContext.getMajorVersion(), spvContext.getMinorVersion()) == hlsl::DXIL::SemanticInterpretationKind::NA) { // Special handle MSIn/ASIn allowing VK-only builtin "DrawIndex". - switch (sigPointKind) { + switch (stageVarData.sigPoint->GetKind()) { case hlsl::SigPoint::Kind::MSIn: case hlsl::SigPoint::Kind::ASIn: - if (const auto *builtinAttr = decl->getAttr()) { + if (const auto *builtinAttr = + stageVarData.decl->getAttr()) { const llvm::StringRef builtin = builtinAttr->getBuiltIn(); if (builtin == "DrawIndex") { break; @@ -3191,38 +3226,36 @@ DeclResultIdMapper::getBaseInstanceVariable(SemanticInfo *semantic, } SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( - const hlsl::SigPoint *sigPoint, SemanticInfo *semanticToUse, - const NamedDecl *decl, QualType type, uint32_t arraySize, bool asNoInterp, - const llvm::StringRef namePrefix) { + const StageVarDataBundle &stageVarData) { // The evalType will be the type of the interface variable in SPIR-V. // The type of the variable used in the body of the function will still be // `type`. - QualType evalType = getTypeForSpirvStageVariable( - type, semanticToUse->getKind(), sigPoint->GetKind(), decl, arraySize); + QualType evalType = getTypeForSpirvStageVariable(stageVarData); - const auto *builtinAttr = decl->getAttr(); + const auto *builtinAttr = stageVarData.decl->getAttr(); StageVar stageVar( - sigPoint, *semanticToUse, builtinAttr, evalType, + stageVarData.sigPoint, *stageVarData.semantic, builtinAttr, evalType, // For HS/DS/GS, we have already stripped the outmost arrayness on type. - getLocationAndComponentCount(astContext, type)); - const auto name = namePrefix.str() + "." + stageVar.getSemanticStr(); - SpirvVariable *varInstr = - createSpirvStageVar(&stageVar, decl, name, semanticToUse->loc); + getLocationAndComponentCount(astContext, stageVarData.type)); + const auto name = + stageVarData.namePrefix.str() + "." + stageVar.getSemanticStr(); + SpirvVariable *varInstr = createSpirvStageVar( + &stageVar, stageVarData.decl, name, stageVarData.semantic->loc); if (!varInstr) return nullptr; - if (asNoInterp) + if (stageVarData.asNoInterp) varInstr->setNoninterpolated(); stageVar.setSpirvInstr(varInstr); - stageVar.setLocationAttr(decl->getAttr()); - stageVar.setIndexAttr(decl->getAttr()); + stageVar.setLocationAttr(stageVarData.decl->getAttr()); + stageVar.setIndexAttr(stageVarData.decl->getAttr()); if (stageVar.getStorageClass() == spv::StorageClass::Input || stageVar.getStorageClass() == spv::StorageClass::Output) { stageVar.setEntryPoint(entryFunction); } - decorateStageVarWithIntrinsicAttrs(decl, &stageVar, varInstr); + decorateStageVarWithIntrinsicAttrs(stageVarData.decl, &stageVar, varInstr); stageVars.push_back(stageVar); // Emit OpDecorate* instructions to link this stage variable with the HLSL @@ -3230,9 +3263,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( spvBuilder.decorateHlslSemantic(varInstr, stageVar.getSemanticStr()); // TODO: the following may not be correct? - if (sigPoint->GetSignatureKind() == + if (stageVarData.sigPoint->GetSignatureKind() == hlsl::DXIL::SignatureKind::PatchConstOrPrim) { - if (sigPoint->GetKind() == hlsl::SigPoint::Kind::MSPOut) { + if (stageVarData.sigPoint->GetKind() == hlsl::SigPoint::Kind::MSPOut) { // Decorate with PerPrimitiveNV for per-primitive out variables. spvBuilder.decoratePerPrimitiveNV(varInstr, varInstr->getSourceLocation()); @@ -3243,9 +3276,10 @@ SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( // Decorate with interpolation modes for pixel shader input variables // or vertex shader output variables. - if ((spvContext.isPS() && sigPoint->IsInput()) || - (spvContext.isVS() && sigPoint->IsOutput())) - decorateInterpolationMode(decl, type, varInstr, *semanticToUse); + if ((spvContext.isPS() && stageVarData.sigPoint->IsInput()) || + (spvContext.isVS() && stageVarData.sigPoint->IsOutput())) + decorateInterpolationMode(stageVarData.decl, stageVarData.type, varInstr, + *stageVarData.semantic); // Special case: The DX12 SV_InstanceID always counts from 0, even if the // StartInstanceLocation parameter is non-zero. gl_InstanceIndex, however, @@ -3271,12 +3305,13 @@ SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( // drawing command or the firstInstance member of a structure consumed by // an indirect drawing command. if (spirvOptions.supportNonzeroBaseInstance && - semanticToUse->getKind() == hlsl::Semantic::Kind::InstanceID && - sigPoint->GetKind() == hlsl::SigPoint::Kind::VSIn) { + stageVarData.semantic->getKind() == hlsl::Semantic::Kind::InstanceID && + stageVarData.sigPoint->GetKind() == hlsl::SigPoint::Kind::VSIn) { // The above call to createSpirvStageVar creates the gl_InstanceIndex. // We should now manually create the gl_BaseInstance variable and do the // subtraction. - auto *baseInstanceVar = getBaseInstanceVariable(semanticToUse, sigPoint); + auto *baseInstanceVar = + getBaseInstanceVariable(stageVarData.semantic, stageVarData.sigPoint); // SPIR-V code for 'SV_InstanceID = gl_InstanceIndex - gl_BaseInstance' varInstr = getInstanceIdFromIndexAndBase(varInstr, baseInstanceVar); @@ -3284,17 +3319,15 @@ SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( // We have semantics attached to this decl, which means it must be a // function/parameter/variable. All are DeclaratorDecls. - stageVarInstructions[cast(decl)] = varInstr; + stageVarInstructions[cast(stageVarData.decl)] = varInstr; return varInstr; } QualType DeclResultIdMapper::getTypeForSpirvStageVariable( - QualType type, hlsl::Semantic::Kind semanticKind, - hlsl::SigPoint::Kind sigPointKind, const NamedDecl *decl, - uint32_t arraySize) { - QualType evalType = type; - switch (semanticKind) { + const StageVarDataBundle &stageVarData) { + QualType evalType = stageVarData.type; + switch (stageVarData.semantic->getKind()) { case hlsl::Semantic::Kind::DomainLocation: // SV_DomainLocation can refer to a float2, whereas TessCoord is a float3. // To ensure SPIR-V validity, we must create a float3 and extract a @@ -3338,7 +3371,9 @@ QualType DeclResultIdMapper::getTypeForSpirvStageVariable( // (GlobalInvocationId, LocalInvocationId, WorkgroupId) must be a uint3. // Keep the original integer signedness evalType = astContext.getExtVectorType( - hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecElementType(type) : type, + hlsl::IsHLSLVecType(stageVarData.type) + ? hlsl::GetHLSLVecElementType(stageVarData.type) + : stageVarData.type, 3); break; default: @@ -3348,33 +3383,34 @@ QualType DeclResultIdMapper::getTypeForSpirvStageVariable( // Boolean stage I/O variables must be represented as unsigned integers. // Boolean built-in variables are represented as bool. - if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) { - evalType = getUintTypeWithSourceComponents(astContext, type); + if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type, + stageVarData.semantic->getKind(), + stageVarData.sigPoint->GetKind())) { + evalType = getUintTypeWithSourceComponents(astContext, stageVarData.type); } // Handle the extra arrayness - if (arraySize != 0) { + if (stageVarData.arraySize != 0) { evalType = astContext.getConstantArrayType( - evalType, llvm::APInt(32, arraySize), clang::ArrayType::Normal, 0); + evalType, llvm::APInt(32, stageVarData.arraySize), + clang::ArrayType::Normal, 0); } return evalType; } -bool DeclResultIdMapper::createStageVars( - const hlsl::SigPoint *sigPoint, const NamedDecl *decl, bool asInput, - QualType type, uint32_t arraySize, const llvm::StringRef namePrefix, - llvm::Optional invocationId, SpirvInstruction **value, - bool noWriteBack, SemanticInfo *inheritSemantic, bool asNoInterp) { +bool DeclResultIdMapper::createStageVars(StageVarDataBundle &stageVarData, + bool asInput, SpirvInstruction **value, + bool noWriteBack) { assert(value); // invocationId should only be used for handling HS per-vertex output. - if (invocationId.hasValue()) { - assert(spvContext.isHS() && arraySize != 0 && !asInput); + if (stageVarData.invocationId.hasValue()) { + assert(spvContext.isHS() && stageVarData.arraySize != 0 && !asInput); } - assert(inheritSemantic); + assert(stageVarData.semantic); - if (type->isVoidType()) { + if (stageVarData.type->isVoidType()) { // No stage variables will be created for void type. return true; } @@ -3391,57 +3427,55 @@ bool DeclResultIdMapper::createStageVars( // should have semantics attached; // * If the current decl is not a struct, it should have semantic attached. - auto thisSemantic = getStageVarSemantic(decl); + auto thisSemantic = getStageVarSemantic(stageVarData.decl); // Which semantic we should use for this decl - auto *semanticToUse = &thisSemantic; - // Enclosing semantics override internal ones - if (inheritSemantic->isValid()) { + if (stageVarData.semantic->isValid()) { if (thisSemantic.isValid()) { emitWarning( "internal semantic '%0' overridden by enclosing semantic '%1'", thisSemantic.loc) - << thisSemantic.str << inheritSemantic->str; + << thisSemantic.str << stageVarData.semantic->str; } - semanticToUse = inheritSemantic; + } else { + stageVarData.semantic = &thisSemantic; } - const auto loc = decl->getLocation(); - if (semanticToUse->isValid() && + if (stageVarData.semantic->isValid() && // Structs with attached semantics will be handled later. - !type->isStructureType()) { + !stageVarData.type->isStructureType()) { // Found semantic attached directly to this Decl. This means we need to // map this decl to a single stage variable. - const auto semanticKind = semanticToUse->getKind(); - const auto sigPointKind = sigPoint->GetKind(); + const auto semanticKind = stageVarData.semantic->getKind(); + const auto sigPointKind = stageVarData.sigPoint->GetKind(); - if (!validateShaderStageVar(semanticToUse, sigPoint, decl, type)) { + if (!validateShaderStageVar(stageVarData)) { return false; } // Special handling of certain mappings between HLSL semantics and // SPIR-V builtins: // * SV_CullDistance/SV_ClipDistance are outsourced to GlPerVertex. - if (glPerVertex.tryToAccess(sigPointKind, semanticKind, - semanticToUse->index, invocationId, value, - noWriteBack, /*vecComponent=*/nullptr, loc)) + if (glPerVertex.tryToAccess( + sigPointKind, semanticKind, stageVarData.semantic->index, + stageVarData.invocationId, value, noWriteBack, + /*vecComponent=*/nullptr, stageVarData.decl->getLocation())) return true; - SpirvVariable *varInstr = createSpirvInterfaceVariable( - sigPoint, semanticToUse, decl, type, arraySize, asNoInterp, namePrefix); + SpirvVariable *varInstr = createSpirvInterfaceVariable(stageVarData); if (!varInstr) { return false; } // Mark that we have used one index for this semantic - ++semanticToUse->index; + ++stageVarData.semantic->index; if (asInput) { - *value = loadShaderInputVariable(varInstr, semanticKind, sigPointKind, - type, decl, loc); - if ((decl->hasAttr() || asNoInterp) && + *value = loadShaderInputVariable(varInstr, stageVarData); + if ((stageVarData.decl->hasAttr() || + stageVarData.asNoInterp) && sigPointKind == hlsl::SigPoint::Kind::PSIn) spvBuilder.addPerVertexStgInputFuncVarEntry(varInstr, *value); @@ -3451,8 +3485,7 @@ bool DeclResultIdMapper::createStageVars( // Negate SV_Position.y if requested if (semanticKind == hlsl::Semantic::Kind::Position) *value = invertYIfRequested(*value, thisSemantic.loc); - storeToShaderOutputVariable(varInstr, semanticKind, type, *value, - invocationId, sigPointKind, decl, loc); + storeToShaderOutputVariable(varInstr, *value, stageVarData); } return true; @@ -3461,22 +3494,20 @@ bool DeclResultIdMapper::createStageVars( // If the decl itself doesn't have semantic string attached and there is no // one to inherit, it should be a struct having all its fields with semantic // strings. - if (!semanticToUse->isValid() && !type->isStructureType()) { + if (!stageVarData.semantic->isValid() && + !stageVarData.type->isStructureType()) { emitError("semantic string missing for shader %select{output|input}0 " "variable '%1'", - loc) - << asInput << decl->getName(); + stageVarData.decl->getLocation()) + << asInput << stageVarData.decl->getName(); return false; } if (asInput) { - *value = createStructInputVar(type, sigPoint, arraySize, namePrefix, - asNoInterp, noWriteBack, semanticToUse, loc); + *value = createStructInputVar(stageVarData, noWriteBack); return (*value) != nullptr; } else { - return createStructOutputVar(type, sigPoint, arraySize, invocationId, - namePrefix, semanticToUse, noWriteBack, - asNoInterp, *value, loc); + return createStructOutputVar(stageVarData, *value, noWriteBack); } } diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.h b/tools/clang/lib/SPIRV/DeclResultIdMapper.h index fc239b84fa..45c54eeb41 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.h +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.h @@ -381,6 +381,46 @@ class DeclResultIdMapper { int indexInCTBuffer; }; + /// The struct containing the data needed to create the input and output + /// variables for the decl. + struct StageVarDataBundle { + public: + StageVarDataBundle() = default; + // The declaration of the variable for which we need to create the stage + // variables. + const NamedDecl *decl; + + // The HLSL semantic to apply to the variable. Note that this could be + // different than the semantic attached to decl because it could inherit + // the semantic from the parent declaration if this declaration is a member. + SemanticInfo *semantic; + + // True if the variable is not suppose to be interpolated. Note that we + // cannot just look at decl to determine this because the attribute might + // have been applied to a parent declaration. + bool asNoInterp; + + // The sigPoint is the shader stage that this variable should be added to, + // and whether it is an input or output. + const hlsl::SigPoint *sigPoint; + + // The type to use for the new variable. There are cases where the type + // might be different. See the call sites for createStageVars. + QualType type; + + // If the shader stage for the variable is HS, DS, or GS, the SPIR-V + // requires that the stage variable is an array of type. The arraySize gives + // the size for that array. + uint32_t arraySize; + + // A prefix to use for the name of the variable. + llvm::StringRef namePrefix; + + // If arraySize is not zero, invocationId gives the index to used when + // generating a write to the stage variable. + llvm::Optional invocationId; + }; + /// \brief Returns the SPIR-V information for the given decl. /// Returns nullptr if no such decl was previously registered. const DeclSpirvInfo *getDeclSpirvInfo(const ValueDecl *decl) const; @@ -621,113 +661,75 @@ class DeclResultIdMapper { const DeclContext *decl, int arraySize, ContextUsageKind usageKind, llvm::StringRef typeName, llvm::StringRef varName); - /// Creates all the stage variables mapped from semantics on the given decl. - /// Returns true on success. + /// Creates all of the stage variables that must be generated for the given + /// stage variable data. Returns true on success. /// - /// If decl is of struct type, this means flattening it and create stand- - /// alone variables for each field. If arraySize is not zero, the created - /// stage variables will be arrays of the original type and the given size. - /// This is for supporting HS/DS/GS, which takes in primitives containing - /// multiple vertices. asType should be the type we are treating decl as; - /// For HS/DS/GS, the outermost arrayness should be discarded and use - /// arraySize instead. + /// stageVarData: See the definition of StageVarDataBundle to see how that + /// data is used. /// - /// Also performs reading the stage variables and compose a temporary value - /// of the given type and writing into *value, if asInput is true. Otherwise, - /// Decomposes the *value according to type and writes back into the stage - /// output variables, unless noWriteBack is set to true. noWriteBack is used - /// by GS since in GS we manually control write back using .Append() method. + /// asInput: True if the stage variable is an input. /// - /// invocationId is only used for HS to indicate the index of the output - /// array element to write to. + /// TODO(s-perron): a variable that is an input or an output depending on + /// value of a flag is very hard to read. This function should be split up + /// and flag variables removed. /// - /// Assumes the decl has semantic attached to itself or to its fields. - /// If inheritSemantic is valid, it will override all semantics attached to - /// the children of this decl, and the children of this decl will be using - /// the semantic in inheritSemantic, with index increasing sequentially. - bool createStageVars(const hlsl::SigPoint *sigPoint, const NamedDecl *decl, - bool asInput, QualType asType, uint32_t arraySize, - const llvm::StringRef namePrefix, - llvm::Optional invocationId, - SpirvInstruction **value, bool noWriteBack, - SemanticInfo *inheritSemantic, bool asNoInterp = false); + /// [in/out] value: If `asInput` is true, this is an + /// output, and will be an instruction that loads the stage variable. If + /// `asInput` is false, then it is an input to createStageVars, and contains + /// the value to be stored in the new stage variable. + /// + /// noWriteBack: If true, the newly created stage variable will not be written + /// to. + bool createStageVars(StageVarDataBundle &stageVarData, bool asInput, + SpirvInstruction **value, bool noWriteBack); // Creates a variable to represent the output variable, which must be a // structure. If `noWriteBack` is false, then `value` will be written to the // new variable. Returns true if successful. // - // TODO(s-perron): All of the parameters associated with the HLSL - // representation of the variable should be grouped in a struct. This struct - // should be used in all of the function below to reduce the number of - // parameters. + // stageVarData: The data needed to create the stage variable. + // + // noWriteBack: A flag to indicate if the variable should be written or not. // - // type: The type of the variable in the HLSL. - // sigPoint: The signature point for the variable. - // arraySize: The size of the array when the SPIR-V variable must become an - // array of type, 0 otherwise. invocationId: when arraySize is not 0, - // invocationId is the index at which `value` should be written. namePrefix: a - // prefix to apply to the name of the variable. semantic: the HLSL semantic - // for this variable. noWriteBack: A flag to indicate if the variable should - // be written or not. asNoInterp: True if the variable is not to be - // interpolated. value: The value to be written to the newly create variable. - // loc: The location of the variable in the source code. - bool createStructOutputVar(QualType type, const hlsl::SigPoint *sigPoint, - uint32_t arraySize, - llvm::Optional invocationId, - const llvm::StringRef namePrefix, - SemanticInfo *semantic, bool noWriteBack, - bool asNoInterp, SpirvInstruction *value, - SourceLocation loc); + // value: The value to be written to the newly create variable. + bool createStructOutputVar(const StageVarDataBundle &stageVarData, + SpirvInstruction *value, bool noWriteBack); // Creates a variable to represent the input variable, which must be a // structure. The value is loaded and the instruction with the final value is // return. // - // type: The type of the variable in the HLSL. - // sigPoint: The signature point for the variable. - // arraySize: The size of the array when the SPIR-V variable must become an - // array of type, 0 otherwise. namePrefix: a prefix to apply to the name of - // the variable. asNoInterp: True if the variable is not to be interpolated. + // stageVarData: The data needed to create the stage variable. + // // noWriteBack: A flag to indicate if the variable should be written or not. - // semantic: the HLSL semantic for this variable. - // loc: The location of the variable in the source code. - SpirvInstruction * - createStructInputVar(QualType type, const hlsl::SigPoint *sigPoint, - uint32_t arraySize, const llvm::StringRef namePrefix, - bool asNoInterp, bool noWriteBack, - SemanticInfo *semantic, SourceLocation loc); + SpirvInstruction *createStructInputVar(const StageVarDataBundle &stageVarData, + bool noWriteBack); // Store `value` to the shader output variable `varInstr`. Since the type - // could be different, other data is used to know how to convert `value` into - // the correct type for `varInstr`. + // could be different, stageVarData is used to know how to convert `value` + // into the correct type for `varInstr`. + // + // varInstr: the output variable that corresponds to `stageVarData`. It must + // not be a struct. // - // varInstr: the output variable that corresponds to `decl`. It must not be a - // struct. semanticKind: the HLSL semantic for this variable. type: The type - // of the variable in the HLSL. value: The value to be written to the create - // variable. invocationId: when arraySize is not 0, invocationId is the index - // at which `value` should be written. sigPointKind: The signature point for - // the variable. decl: The original declaration of the variable. loc: The - // location of the variable in the source code. - void storeToShaderOutputVariable( - SpirvVariable *varInstr, hlsl::Semantic::Kind semanticKind, QualType type, - SpirvInstruction *value, llvm::Optional invocationId, - hlsl::SigPoint::Kind sigPointKind, const NamedDecl *decl, - SourceLocation loc); - - // Loads shader input variable `varInstr`, and modifies it to match `type`. - // Other data is used to know how to convert `varInstr` into the correct type. + // value: The value to be written to the create variable. // - // varInstr: the input variable that corresponds to `decl`. It must not be a - // struct. semanticKind: the HLSL semantic for this variable. sigPointKind: - // The signature point for the variable. type: The type of the variable in the - // HLSL. decl: The original declaration of the variable. loc: The location of - // the variable in the source code. - SpirvInstruction *loadShaderInputVariable(SpirvVariable *varInstr, - hlsl::Semantic::Kind semanticKind, - hlsl::SigPoint::Kind sigPointKind, - QualType type, - const NamedDecl *decl, - SourceLocation loc); + // stageVarData: The data that was used to create `varInstr`. + void storeToShaderOutputVariable(SpirvVariable *varInstr, + SpirvInstruction *value, + const StageVarDataBundle &stageVarData); + + // Loads shader input variable `varInstr`, and modifies the value to match the + // type in stageVarData. The struct stageVarData is used to know how to + // convert the value loaded from `varInstr` into the correct type. + // + // varInstr: the input variable that corresponds to `stageVarData`. It must + // not be a struct. + // + // stageVarData: The data that was used to create `varInstr`. + SpirvInstruction * + loadShaderInputVariable(SpirvVariable *varInstr, + const StageVarDataBundle &stageVarData); // Creates a function scope variable to represent the "SV_InstanceID" // semantic, which it not immediately available in SPIR-V. Its value will be @@ -735,8 +737,10 @@ class DeclResultIdMapper { // variables. // // instanceIndexVar: The SPIR-V input variable that decorated with - // InstanceIndex baseInstanceVar: The SPIR-V input variable that is decorated - // with BaseInstance. + // InstanceIndex. + // + // baseInstanceVar: The SPIR-V input variable that is decorated with + // BaseInstance. SpirvVariable *getInstanceIdFromIndexAndBase(SpirvVariable *instanceIndexVar, SpirvVariable *baseInstanceVar); @@ -750,6 +754,7 @@ class DeclResultIdMapper { // with reflection. // // semantic: the semantic to attach to this variable + // // sigPoint: the signature point identifying which shader stage the variable // will be used in. SpirvVariable *getBaseInstanceVariable(SemanticInfo *semantic, @@ -759,57 +764,34 @@ class DeclResultIdMapper { // The new variable with be add to `this->StageVars`. // // - // sigPoint: the signature point identifying which shader stage the variable - // will be used in. semantic: the HLSL semantic for the variable. decl: the - // declaration of the variable. type: the HLSL type of the variable arraySize: - // the size of the array if `type` should be turned into an array. 0 if `type` - // should not be turned into and array. - SpirvVariable *createSpirvInterfaceVariable(const hlsl::SigPoint *sigPoint, - SemanticInfo *semantic, - const NamedDecl *decl, - QualType type, uint32_t arraySize, - bool asNoInterp, - const llvm::StringRef namePrefix); + // stageVarData: the data needed to create the interface variable. See the + // declaration of StageVarDataBundle for the details. + SpirvVariable * + createSpirvInterfaceVariable(const StageVarDataBundle &stageVarData); // Returns the type that the SPIR-V input or output variable must have to // correspond to a variable with the given information. // - // type: The type of the variable in HLSL. - // semanticKind: The semantic attached to the variable. - // sigPointKind: The signature point (shader, input or output) of the - // variable. decl: The declaration of the variable. arraySize: The size of the - // array the variable should be converted to. 0 if it should not be turned - // into an array. - QualType getTypeForSpirvStageVariable(QualType type, - hlsl::Semantic::Kind semanticKind, - hlsl::SigPoint::Kind sigPointKind, - const NamedDecl *decl, - uint32_t arraySize); - - // Returns true if all of the information is consistent with a valid shader - // stage variable. Issues an error and returns false otherwise. - bool validateShaderStageVar(SemanticInfo *semantic, - const hlsl::SigPoint *sigPoint, - const NamedDecl *decl, QualType type); + // stageVarData: the data needed to create the interface variable. See the + // declaration of StageVarDataBundle for the details. + QualType getTypeForSpirvStageVariable(const StageVarDataBundle &stageVarData); + + // Returns true if all of the stage variable data is consistent with a valid + // shader stage variable. Issues an error and returns false otherwise. + bool validateShaderStageVar(const StageVarDataBundle &stageVarData); /// Returns true if all vk:: attributes usages are valid. bool validateVKAttributes(const NamedDecl *decl); /// Returns true if all vk::builtin usages are valid. - bool validateVKBuiltins(const NamedDecl *decl, - const hlsl::SigPoint *sigPoint); + bool validateVKBuiltins(const StageVarDataBundle &stageVarData); - // Returns true if the type and semantic kind are compatible. Issues an error - // and returns false otherwise. - bool validateShaderStageVarType(hlsl::Semantic::Kind semanticKind, - QualType type, clang::SourceLocation loc); + // Returns true if the type in stageVarData is compatible with the rest of the + // data. Issues an error and returns false otherwise. + bool validateShaderStageVarType(const StageVarDataBundle &stageVarData); - // Returns true if the semantic can be used with the given signature point. - // The declaration is used to get the builtin attribute to augment the - // semantic. - bool isValidSemanticInShaderModel(hlsl::Semantic::Kind semanticKind, - hlsl::SigPoint::Kind sigPointKind, - const NamedDecl *decl); + // Returns true if the semantic is consistent wit the rest of the given data. + bool isValidSemanticInShaderModel(const StageVarDataBundle &stageVarData); /// Creates the SPIR-V variable instruction for the given StageVar and returns /// the instruction. Also sets whether the StageVar is a SPIR-V builtin and @@ -824,6 +806,7 @@ class DeclResultIdMapper { /// Creates assoicated counter variables for all AssocCounter cases (see the /// comment of CounterVarFields). void createCounterVarForDecl(const DeclaratorDecl *decl); + /// Creates the associated counter variable for final RW/Append/Consume /// structured buffer. Handles AssocCounter#1 and AssocCounter#2 (see the /// comment of CounterVarFields).