Skip to content

Commit

Permalink
[CIR][CUDA] Added NVPTX64 ABI info
Browse files Browse the repository at this point in the history
  • Loading branch information
AdUhTkJm committed Jan 29, 2025
1 parent d329c96 commit 8107643
Showing 1 changed file with 153 additions and 0 deletions.
153 changes: 153 additions & 0 deletions clang/lib/CIR/CodeGen/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,36 @@ class SPIRVTargetCIRGenInfo : public CommonSPIRTargetCIRGenInfo {

} // namespace


//===----------------------------------------------------------------------===//
// NVPTX ABI Implementation
//===----------------------------------------------------------------------===//

namespace {

class NVPTXABIInfo : public ABIInfo {
mlir::MLIRContext &MLCtx;
public:
NVPTXABIInfo(CIRGenTypes &CGT, mlir::MLIRContext &MLCtx)
: ABIInfo(CGT), MLCtx(MLCtx) {}

bool isUnsupportedType(QualType T) const;
cir::ABIArgInfo coerceToIntArrayWithLimit(QualType Ty, unsigned MaxSize) const;

cir::ABIArgInfo classifyReturnType(QualType RetTy) const;
cir::ABIArgInfo classifyArgumentType(QualType Ty) const;

void computeInfo(CIRGenFunctionInfo &FI) const override;
};

class NVPTXTargetCIRGenInfo : public TargetCIRGenInfo {
public:
NVPTXTargetCIRGenInfo(CIRGenTypes &CGT, mlir::MLIRContext &MLCtx)
: TargetCIRGenInfo(std::make_unique<NVPTXABIInfo>(CGT, MLCtx)) {}
};

}

// TODO(cir): remove the attribute once this gets used.
LLVM_ATTRIBUTE_UNUSED
static bool classifyReturnType(const CIRGenCXXABI &CXXABI,
Expand Down Expand Up @@ -443,6 +473,125 @@ cir::ABIArgInfo X86_64ABIInfo::classifyArgumentType(QualType Ty,
return cir::ABIArgInfo::getDirect(ResType);
}

/// Checks if the type is unsupported directly by the current target.
bool NVPTXABIInfo::isUnsupportedType(QualType T) const {
ASTContext &Context = getContext();
if (!Context.getTargetInfo().hasFloat16Type() && T->isFloat16Type())
return true;

if (!Context.getTargetInfo().hasFloat128Type() &&
(T->isFloat128Type() ||
(T->isRealFloatingType() && Context.getTypeSize(T) == 128)))
return true;

if (const auto *EIT = T->getAs<BitIntType>())
return EIT->getNumBits() >
(Context.getTargetInfo().hasInt128Type() ? 128U : 64U);

if (!Context.getTargetInfo().hasInt128Type() && T->isIntegerType() &&
Context.getTypeSize(T) > 64U)
return true;

if (const auto *AT = T->getAsArrayTypeUnsafe())
return isUnsupportedType(AT->getElementType());

const auto *RT = T->getAs<RecordType>();
if (!RT)
return false;

const RecordDecl *RD = RT->getDecl();

// If this is a C++ record, check the bases first.
if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD))
for (const CXXBaseSpecifier &I : CXXRD->bases())
if (isUnsupportedType(I.getType()))
return true;

for (const FieldDecl *I : RD->fields())
if (isUnsupportedType(I->getType()))
return true;
return false;
}

/// Coerce the given type into an array with maximum allowed size of elements.
cir::ABIArgInfo NVPTXABIInfo::coerceToIntArrayWithLimit(QualType Ty,
unsigned MaxSize) const {
// Alignment and Size are measured in bits.
const uint64_t Size = getContext().getTypeSize(Ty);
const uint64_t Alignment = getContext().getTypeAlign(Ty);
const unsigned Div = std::min<unsigned>(MaxSize, Alignment);
cir::IntType IntType = cir::IntType::get(&MLCtx, Div, false);
const uint64_t NumElements = (Size + Div - 1) / Div;
return cir::ABIArgInfo::getDirect(cir::ArrayType::get(&MLCtx, IntType, NumElements));
}

cir::ABIArgInfo NVPTXABIInfo::classifyReturnType(QualType RetTy) const {
if (RetTy->isVoidType())
return cir::ABIArgInfo::getIgnore();

if (getContext().getLangOpts().OpenMP &&
getContext().getLangOpts().OpenMPIsTargetDevice &&
isUnsupportedType(RetTy))
return coerceToIntArrayWithLimit(RetTy, 64);

// note: this is different from default ABI
if (!RetTy->isScalarType())
return cir::ABIArgInfo::getDirect();

// Treat an enum type as its underlying type.
if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
RetTy = EnumTy->getDecl()->getIntegerType();

return (isPromotableIntegerTypeForABI(RetTy) ? cir::ABIArgInfo::getExtend(RetTy)
: cir::ABIArgInfo::getDirect());
}

cir::ABIArgInfo NVPTXABIInfo::classifyArgumentType(QualType Ty) const {
// Treat an enum type as its underlying type.
if (const EnumType *EnumTy = Ty->getAs<EnumType>())
Ty = EnumTy->getDecl()->getIntegerType();

// Return aggregate type as indirect by value
if (isAggregateTypeForABI(Ty)) {
// Under CUDA device compilation, tex/surf builtin types are replaced with
// object types and passed directly.
if (getContext().getLangOpts().CUDAIsDevice) {
if (Ty->isCUDADeviceBuiltinSurfaceType() ||
Ty->isCUDADeviceBuiltinTextureType())

// On the device side, both surface and texture reference
// is represented as an object handle in 64-bit integer.
return cir::ABIArgInfo::getDirect(cir::IntType::get(&MLCtx, 64, false));

}

clang::CharUnits Alignment = getContext().getTypeAlignInChars(Ty);
return cir::ABIArgInfo::getIndirect(Alignment.getQuantity());
}

if (const auto *EIT = Ty->getAs<BitIntType>()) {
if ((EIT->getNumBits() > 128) ||
(!getContext().getTargetInfo().hasInt128Type() &&
EIT->getNumBits() > 64)) {
clang::CharUnits Alignment = getContext().getTypeAlignInChars(Ty);
return cir::ABIArgInfo::getIndirect(Alignment.getQuantity());
}
}

return (isPromotableIntegerTypeForABI(Ty) ? cir::ABIArgInfo::getExtend(Ty)
: cir::ABIArgInfo::getDirect());
}

void NVPTXABIInfo::computeInfo(CIRGenFunctionInfo &FI) const {
if (!getCXXABI().classifyReturnType(FI))
FI.getReturnInfo() = classifyReturnType(FI.getReturnType());

for (auto &&[ArgumentsCount, I] : llvm::enumerate(FI.arguments()))
I.info = ArgumentsCount < FI.getNumRequiredArgs()
? classifyArgumentType(I.type)
: cir::ABIArgInfo::getDirect();
}

ABIInfo::~ABIInfo() {}

bool ABIInfo::isPromotableIntegerTypeForABI(QualType Ty) const {
Expand Down Expand Up @@ -634,5 +783,9 @@ const TargetCIRGenInfo &CIRGenModule::getTargetCIRGenInfo() {
case llvm::Triple::spirv64: {
return SetCIRGenInfo(new SPIRVTargetCIRGenInfo(genTypes));
}

case llvm::Triple::nvptx64: {
return SetCIRGenInfo(new NVPTXTargetCIRGenInfo(genTypes, getMLIRContext()));
}
}
}

0 comments on commit 8107643

Please sign in to comment.