From f7b915147d50fc70a4d1e689a7c648a1a6354003 Mon Sep 17 00:00:00 2001 From: Andy Kaylor Date: Fri, 7 Feb 2025 13:15:21 -0800 Subject: [PATCH 1/2] [CIR] Add attribute visitor for lowering globals (#1318) This adds a new mlir-tablegen option to generate a .inc file with the complete set of attrdefs defined in a .td file and uses the file generated for CIR attrdefs to create an attr visitor. This visitor is used in the lowering of global variables directly to LLVM IR. The purpose of this change is to align the incubator lowering implementation with the recent upstream changes to make future upstreaming easier, while also fulfilling the upstream request to have the visitor be based on a tablegen created file. The new mlir-tablegen feature will be upstreamed after it is established here. No observable change is intended in the CIR code. --- .../clang/CIR/Dialect/IR/CIRAttrVisitor.h | 47 ++++++ .../clang/CIR/Dialect/IR/CMakeLists.txt | 1 + .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 156 +++++++----------- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 3 +- mlir/test/mlir-tblgen/attrdefs.td | 8 + mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 24 +++ 6 files changed, 141 insertions(+), 98 deletions(-) create mode 100644 clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h b/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h new file mode 100644 index 000000000000..106fb3d0ed17 --- /dev/null +++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h @@ -0,0 +1,47 @@ +//===- CIRAttrVisitor.h - Visitor for CIR attributes ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the CirAttrVisitor interface. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H +#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H + +#include "clang/CIR/Dialect/IR/CIRAttrs.h" + +namespace cir { + +#define DISPATCH(NAME) return getImpl()->visitCir##NAME(cirAttr); + +template class CirAttrVisitor { +public: + RetTy visit(mlir::Attribute attr) { +#define ATTRDEF(NAME) \ + if (const auto cirAttr = mlir::dyn_cast(attr)) \ + DISPATCH(NAME); +#include "clang/CIR/Dialect/IR/CIRAttrDefsList.inc" + llvm_unreachable("unhandled attribute type"); + } + + // If the implementation chooses not to implement a certain visit + // method, fall back to the parent. +#define ATTRDEF(NAME) \ + RetTy visitCir##NAME(NAME cirAttr) { DISPATCH(Attr); } +#include "clang/CIR/Dialect/IR/CIRAttrDefsList.inc" + + RetTy visitCirAttr(mlir::Attribute attr) { return RetTy(); } + + ImplClass *getImpl() { return static_cast(this); } +}; + +#undef DISPATCH + +} // namespace cir + +#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H diff --git a/clang/include/clang/CIR/Dialect/IR/CMakeLists.txt b/clang/include/clang/CIR/Dialect/IR/CMakeLists.txt index 3d43b06c6217..014bb3d9b03c 100644 --- a/clang/include/clang/CIR/Dialect/IR/CMakeLists.txt +++ b/clang/include/clang/CIR/Dialect/IR/CMakeLists.txt @@ -26,6 +26,7 @@ mlir_tablegen(CIROpsStructs.h.inc -gen-attrdef-decls) mlir_tablegen(CIROpsStructs.cpp.inc -gen-attrdef-defs) mlir_tablegen(CIROpsAttributes.h.inc -gen-attrdef-decls) mlir_tablegen(CIROpsAttributes.cpp.inc -gen-attrdef-defs) +mlir_tablegen(CIRAttrDefsList.inc -gen-attrdef-list) add_public_tablegen_target(MLIRCIREnumsGen) clang_tablegen(CIRBuiltinsLowering.inc -gen-cir-builtins-lowering diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 3817eb3960a5..833d256d0404 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -41,6 +41,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" +#include "clang/CIR/Dialect/IR/CIRAttrVisitor.h" #include "clang/CIR/Dialect/Passes.h" #include "clang/CIR/LoweringHelpers.h" #include "clang/CIR/MissingFeatures.h" @@ -425,32 +426,52 @@ emitCirAttrToMemory(mlir::Operation *parentOp, mlir::Attribute attr, } /// Switches on the type of attribute and calls the appropriate conversion. +class CirAttrToValue : public CirAttrVisitor { +public: + CirAttrToValue(mlir::Operation *parentOp, + mlir::ConversionPatternRewriter &rewriter, + const mlir::TypeConverter *converter, + mlir::DataLayout const &dataLayout) + : parentOp(parentOp), rewriter(rewriter), converter(converter), + dataLayout(dataLayout) {} + + mlir::Value visitCirIntAttr(cir::IntAttr attr); + mlir::Value visitCirFPAttr(cir::FPAttr attr); + mlir::Value visitCirConstPtrAttr(cir::ConstPtrAttr attr); + mlir::Value visitCirConstStructAttr(cir::ConstStructAttr attr); + mlir::Value visitCirConstArrayAttr(cir::ConstArrayAttr attr); + mlir::Value visitCirConstVectorAttr(cir::ConstVectorAttr attr); + mlir::Value visitCirBoolAttr(cir::BoolAttr attr); + mlir::Value visitCirZeroAttr(cir::ZeroAttr attr); + mlir::Value visitCirUndefAttr(cir::UndefAttr attr); + mlir::Value visitCirPoisonAttr(cir::PoisonAttr attr); + mlir::Value visitCirGlobalViewAttr(cir::GlobalViewAttr attr); + mlir::Value visitCirVTableAttr(cir::VTableAttr attr); + mlir::Value visitCirTypeInfoAttr(cir::TypeInfoAttr attr); + +private: + mlir::Operation *parentOp; + mlir::ConversionPatternRewriter &rewriter; + const mlir::TypeConverter *converter; + mlir::DataLayout const &dataLayout; +}; /// IntAttr visitor. -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::IntAttr intAttr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +mlir::Value CirAttrToValue::visitCirIntAttr(cir::IntAttr intAttr) { auto loc = parentOp->getLoc(); return rewriter.create( loc, converter->convertType(intAttr.getType()), intAttr.getValue()); } /// BoolAttr visitor. -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::BoolAttr boolAttr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +mlir::Value CirAttrToValue::visitCirBoolAttr(cir::BoolAttr boolAttr) { auto loc = parentOp->getLoc(); return rewriter.create( loc, converter->convertType(boolAttr.getType()), boolAttr.getValue()); } /// ConstPtrAttr visitor. -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstPtrAttr ptrAttr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +mlir::Value CirAttrToValue::visitCirConstPtrAttr(cir::ConstPtrAttr ptrAttr) { auto loc = parentOp->getLoc(); if (ptrAttr.isNullValue()) { return rewriter.create( @@ -465,51 +486,36 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstPtrAttr ptrAttr, } /// FPAttr visitor. -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::FPAttr fltAttr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +mlir::Value CirAttrToValue::visitCirFPAttr(cir::FPAttr fltAttr) { auto loc = parentOp->getLoc(); return rewriter.create( loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); } /// ZeroAttr visitor. -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ZeroAttr zeroAttr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +mlir::Value CirAttrToValue::visitCirZeroAttr(cir::ZeroAttr zeroAttr) { auto loc = parentOp->getLoc(); return rewriter.create( loc, converter->convertType(zeroAttr.getType())); } /// UndefAttr visitor. -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::UndefAttr undefAttr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +mlir::Value CirAttrToValue::visitCirUndefAttr(cir::UndefAttr undefAttr) { auto loc = parentOp->getLoc(); return rewriter.create( loc, converter->convertType(undefAttr.getType())); } /// PoisonAttr visitor. -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::PoisonAttr poisonAttr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +mlir::Value CirAttrToValue::visitCirPoisonAttr(cir::PoisonAttr poisonAttr) { auto loc = parentOp->getLoc(); return rewriter.create( loc, converter->convertType(poisonAttr.getType())); } /// ConstStruct visitor. -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter, - mlir::DataLayout const &dataLayout) { +mlir::Value +CirAttrToValue::visitCirConstStructAttr(cir::ConstStructAttr constStruct) { auto llvmTy = converter->convertType(constStruct.getType()); auto loc = parentOp->getLoc(); mlir::Value result = rewriter.create(loc, llvmTy); @@ -525,18 +531,13 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct, } // VTableAttr visitor. -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::VTableAttr vtableArr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter, - mlir::DataLayout const &dataLayout) { +mlir::Value CirAttrToValue::visitCirVTableAttr(cir::VTableAttr vtableArr) { auto llvmTy = converter->convertType(vtableArr.getType()); auto loc = parentOp->getLoc(); mlir::Value result = rewriter.create(loc, llvmTy); for (auto [idx, elt] : llvm::enumerate(vtableArr.getVtableData())) { - mlir::Value init = - lowerCirAttrAsValue(parentOp, elt, rewriter, converter, dataLayout); + mlir::Value init = visit(elt); result = rewriter.create(loc, result, init, idx); } @@ -544,18 +545,14 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::VTableAttr vtableArr, } // TypeInfoAttr visitor. -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::TypeInfoAttr typeinfoArr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter, - mlir::DataLayout const &dataLayout) { +mlir::Value +CirAttrToValue::visitCirTypeInfoAttr(cir::TypeInfoAttr typeinfoArr) { auto llvmTy = converter->convertType(typeinfoArr.getType()); auto loc = parentOp->getLoc(); mlir::Value result = rewriter.create(loc, llvmTy); for (auto [idx, elt] : llvm::enumerate(typeinfoArr.getData())) { - mlir::Value init = - lowerCirAttrAsValue(parentOp, elt, rewriter, converter, dataLayout); + mlir::Value init = visit(elt); result = rewriter.create(loc, result, init, idx); } @@ -563,11 +560,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::TypeInfoAttr typeinfoArr, } // ConstArrayAttr visitor -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstArrayAttr constArr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter, - mlir::DataLayout const &dataLayout) { +mlir::Value +CirAttrToValue::visitCirConstArrayAttr(cir::ConstArrayAttr constArr) { auto llvmTy = converter->convertType(constArr.getType()); auto loc = parentOp->getLoc(); mlir::Value result; @@ -610,10 +604,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstArrayAttr constArr, } // ConstVectorAttr visitor. -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstVectorAttr constVec, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +mlir::Value +CirAttrToValue::visitCirConstVectorAttr(cir::ConstVectorAttr constVec) { auto llvmTy = converter->convertType(constVec.getType()); auto loc = parentOp->getLoc(); SmallVector mlirValues; @@ -638,11 +630,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstVectorAttr constVec, } // GlobalViewAttr visitor. -static mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter, - mlir::DataLayout const &dataLayout) { +mlir::Value +CirAttrToValue::visitCirGlobalViewAttr(cir::GlobalViewAttr globalAttr) { auto module = parentOp->getParentOfType(); mlir::Type sourceType; unsigned sourceAddrSpace = 0; @@ -716,43 +705,16 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr, } /// Switches on the type of attribute and calls the appropriate conversion. -mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr, +mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, + const mlir::Attribute attr, mlir::ConversionPatternRewriter &rewriter, const mlir::TypeConverter *converter, mlir::DataLayout const &dataLayout) { - if (const auto intAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, intAttr, rewriter, converter); - if (const auto fltAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, fltAttr, rewriter, converter); - if (const auto ptrAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, ptrAttr, rewriter, converter); - if (const auto constStruct = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, constStruct, rewriter, converter, - dataLayout); - if (const auto constArr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, constArr, rewriter, converter, - dataLayout); - if (const auto constVec = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, constVec, rewriter, converter); - if (const auto boolAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, boolAttr, rewriter, converter); - if (const auto zeroAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, zeroAttr, rewriter, converter); - if (const auto undefAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, undefAttr, rewriter, converter); - if (const auto poisonAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, poisonAttr, rewriter, converter); - if (const auto globalAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter, - dataLayout); - if (const auto vtableAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, vtableAttr, rewriter, converter, - dataLayout); - if (const auto typeinfoAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, typeinfoAttr, rewriter, converter, - dataLayout); - - llvm_unreachable("unhandled attribute type"); + CirAttrToValue valueConverter(parentOp, rewriter, converter, dataLayout); + auto value = valueConverter.visit(attr); + if (!value) + llvm_unreachable("unhandled attribute type"); + return value; } //===----------------------------------------------------------------------===// @@ -1734,8 +1696,8 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( // Regardless of the type, we should lower the constant of poison value // into PoisonOp. if (auto poisonAttr = mlir::dyn_cast(attr)) { - rewriter.replaceOp( - op, lowerCirAttrAsValue(op, poisonAttr, rewriter, getTypeConverter())); + rewriter.replaceOp(op, lowerCirAttrAsValue(op, poisonAttr, rewriter, + getTypeConverter(), dataLayout)); return mlir::success(); } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 104ce3a0b105..bb0dcaf87efe 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -22,7 +22,8 @@ namespace direct { /// Convert a CIR attribute to an LLVM attribute. May use the datalayout for /// lowering attributes to-be-stored in memory. -mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr, +mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, + const mlir::Attribute attr, mlir::ConversionPatternRewriter &rewriter, const mlir::TypeConverter *converter, mlir::DataLayout const &dataLayout); diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td index 35d2c49619ee..e911f70e4358 100644 --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -1,5 +1,6 @@ // RUN: mlir-tblgen -gen-attrdef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL // RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF +// RUN: mlir-tblgen -gen-attrdef-list -I %S/../../include %s | FileCheck %s --check-prefix=LIST include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" @@ -19,6 +20,13 @@ include "mlir/IR/OpBase.td" // DEF: ::test::CompoundAAttr, // DEF: ::test::SingleParameterAttr +// LIST: ATTRDEF(IndexAttr) +// LIST: ATTRDEF(SimpleAAttr) +// LIST: ATTRDEF(CompoundAAttr) +// LIST: ATTRDEF(SingleParameterAttr) + +// LIST: #undef ATTRDEF + // DEF-LABEL: ::mlir::OptionalParseResult generatedAttributeParser( // DEF-SAME: ::mlir::AsmParser &parser, // DEF-SAME: ::llvm::StringRef *mnemonic, ::mlir::Type type, diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index 6a39424bd463..4f0100fa67cd 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -690,6 +690,7 @@ class DefGenerator { public: bool emitDecls(StringRef selectedDialect); bool emitDefs(StringRef selectedDialect); + bool emitList(StringRef selectedDialect); protected: DefGenerator(ArrayRef defs, raw_ostream &os, @@ -1025,6 +1026,23 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) { return false; } +bool DefGenerator::emitList(StringRef selectedDialect) { + emitSourceFileHeader(("List of " + defType + "Def Definitions").str(), os); + + SmallVector defs; + collectAllDefs(selectedDialect, defRecords, defs); + if (defs.empty()) + return false; + + auto interleaveFn = [&](const AttrOrTypeDef &def) { + os << defType.upper() << "DEF(" << def.getCppClassName() << ")"; + }; + llvm::interleave(defs, os, interleaveFn, "\n"); + os << "\n\n"; + os << "#undef " << defType.upper() << "DEF" << "\n"; + return false; +} + //===----------------------------------------------------------------------===// // Type Constraints //===----------------------------------------------------------------------===// @@ -1099,6 +1117,12 @@ static mlir::GenRegistration AttrDefGenerator generator(records, os); return generator.emitDecls(attrDialect); }); +static mlir::GenRegistration + genAttrList("gen-attrdef-list", "Generate an AttrDef list", + [](const RecordKeeper &records, raw_ostream &os) { + AttrDefGenerator generator(records, os); + return generator.emitList(attrDialect); + }); //===----------------------------------------------------------------------===// // TypeDef From a7383c9d05165d16edba857ddc86e5d29d94d2cc Mon Sep 17 00:00:00 2001 From: Konstantinos Parasyris Date: Fri, 7 Feb 2025 13:33:49 -0800 Subject: [PATCH 2/2] [CIR][HIP] Compile host code (#1319) Adds support for `__host__` and `__device__` functions when compiling for CUDA host. The PR follows the structure of #1309 --- clang/lib/CIR/CodeGen/CIRGenModule.cpp | 10 ++--- clang/lib/CIR/CodeGen/TargetInfo.cpp | 56 ++++++++++++++++++++++++++ clang/test/CIR/CodeGen/HIP/simple.cpp | 16 ++++++++ 3 files changed, 76 insertions(+), 6 deletions(-) create mode 100644 clang/test/CIR/CodeGen/HIP/simple.cpp diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.cpp b/clang/lib/CIR/CodeGen/CIRGenModule.cpp index c58d260e166a..412369ed07ef 100644 --- a/clang/lib/CIR/CodeGen/CIRGenModule.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenModule.cpp @@ -515,7 +515,8 @@ void CIRGenModule::emitGlobal(GlobalDecl GD) { assert(!Global->hasAttr() && "NYI"); assert(!Global->hasAttr() && "NYI"); - if (langOpts.CUDA) { + if (langOpts.CUDA || langOpts.HIP) { + // clang uses the same flag when building HIP code if (langOpts.CUDAIsDevice) { // This will implicitly mark templates and their // specializations as __host__ __device__. @@ -3217,8 +3218,7 @@ void CIRGenModule::Release() { if (astContext.getTargetInfo().getTriple().isWasm()) llvm_unreachable("NYI"); - if (getTriple().isAMDGPU() || - (getTriple().isSPIRV() && getTriple().getVendor() == llvm::Triple::AMD)) { + if (getTriple().isSPIRV() && getTriple().getVendor() == llvm::Triple::AMD) { llvm_unreachable("NYI"); } @@ -3229,9 +3229,7 @@ void CIRGenModule::Release() { if (!astContext.CUDAExternalDeviceDeclODRUsedByHost.empty()) { llvm_unreachable("NYI"); } - if (langOpts.HIP && !getLangOpts().OffloadingNewDriver) { - llvm_unreachable("NYI"); - } + assert(!MissingFeatures::emitLLVMUsed()); assert(!MissingFeatures::sanStats()); diff --git a/clang/lib/CIR/CodeGen/TargetInfo.cpp b/clang/lib/CIR/CodeGen/TargetInfo.cpp index 7669dad59eb8..07dca811985e 100644 --- a/clang/lib/CIR/CodeGen/TargetInfo.cpp +++ b/clang/lib/CIR/CodeGen/TargetInfo.cpp @@ -329,6 +329,30 @@ class NVPTXTargetCIRGenInfo : public TargetCIRGenInfo { } // namespace +//===----------------------------------------------------------------------===// +// AMDGPU ABI Implementation +//===----------------------------------------------------------------------===// + +namespace { + +class AMDGPUABIInfo : public ABIInfo { +public: + AMDGPUABIInfo(CIRGenTypes &cgt) : ABIInfo(cgt) {} + + cir::ABIArgInfo classifyReturnType(QualType retTy) const; + cir::ABIArgInfo classifyArgumentType(QualType ty) const; + + void computeInfo(CIRGenFunctionInfo &fnInfo) const override; +}; + +class AMDGPUTargetCIRGenInfo : public TargetCIRGenInfo { +public: + AMDGPUTargetCIRGenInfo(CIRGenTypes &cgt) + : TargetCIRGenInfo(std::make_unique(cgt)) {} +}; + +} // namespace + // TODO(cir): remove the attribute once this gets used. LLVM_ATTRIBUTE_UNUSED static bool classifyReturnType(const CIRGenCXXABI &CXXABI, @@ -495,6 +519,34 @@ void NVPTXABIInfo::computeInfo(CIRGenFunctionInfo &fnInfo) const { fnInfo.getReturnInfo() = cir::ABIArgInfo::getDirect(CGT.convertType(retTy)); } +// Skeleton only. Implement when used in TargetLower stage. +cir::ABIArgInfo AMDGPUABIInfo::classifyReturnType(QualType retTy) const { + llvm_unreachable("not yet implemented"); +} + +cir::ABIArgInfo AMDGPUABIInfo::classifyArgumentType(QualType ty) const { + llvm_unreachable("not yet implemented"); +} + +void AMDGPUABIInfo::computeInfo(CIRGenFunctionInfo &fnInfo) const { + // Top level CIR has unlimited arguments and return types. Lowering for ABI + // specific concerns should happen during a lowering phase. Assume everything + // is direct for now. + for (CIRGenFunctionInfo::arg_iterator it = fnInfo.arg_begin(), + ie = fnInfo.arg_end(); + it != ie; ++it) { + if (testIfIsVoidTy(it->type)) + it->info = cir::ABIArgInfo::getIgnore(); + else + it->info = cir::ABIArgInfo::getDirect(CGT.convertType(it->type)); + } + auto retTy = fnInfo.getReturnType(); + if (testIfIsVoidTy(retTy)) + fnInfo.getReturnInfo() = cir::ABIArgInfo::getIgnore(); + else + fnInfo.getReturnInfo() = cir::ABIArgInfo::getDirect(CGT.convertType(retTy)); +} + ABIInfo::~ABIInfo() {} bool ABIInfo::isPromotableIntegerTypeForABI(QualType Ty) const { @@ -690,5 +742,9 @@ const TargetCIRGenInfo &CIRGenModule::getTargetCIRGenInfo() { case llvm::Triple::nvptx64: { return SetCIRGenInfo(new NVPTXTargetCIRGenInfo(genTypes)); } + + case llvm::Triple::amdgcn: { + return SetCIRGenInfo(new AMDGPUTargetCIRGenInfo(genTypes)); + } } } diff --git a/clang/test/CIR/CodeGen/HIP/simple.cpp b/clang/test/CIR/CodeGen/HIP/simple.cpp new file mode 100644 index 000000000000..ec4110da10d7 --- /dev/null +++ b/clang/test/CIR/CodeGen/HIP/simple.cpp @@ -0,0 +1,16 @@ +#include "../Inputs/cuda.h" + +// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \ +// RUN: -emit-cir %s -o %t.cir +// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s + + +// This should emit as a normal C++ function. +__host__ void host_fn(int *a, int *b, int *c) {} + +// CIR: cir.func @_Z7host_fnPiS_S_ + +// This shouldn't emit. +__device__ void device_fn(int* a, double b, float c) {} + +// CHECK-NOT: cir.func @_Z9device_fnPidf