Skip to content

Commit

Permalink
Merge branch 'llvm:main' into features/hip-device
Browse files Browse the repository at this point in the history
  • Loading branch information
koparasy authored Feb 7, 2025
2 parents 396a9f7 + a7383c9 commit ae72923
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 98 deletions.
47 changes: 47 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//===- CIRAttrVisitor.h - Visitor for CIR attributes ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the CirAttrVisitor interface.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H

#include "clang/CIR/Dialect/IR/CIRAttrs.h"

namespace cir {

#define DISPATCH(NAME) return getImpl()->visitCir##NAME(cirAttr);

template <typename ImplClass, typename RetTy> class CirAttrVisitor {
public:
RetTy visit(mlir::Attribute attr) {
#define ATTRDEF(NAME) \
if (const auto cirAttr = mlir::dyn_cast<cir::NAME>(attr)) \
DISPATCH(NAME);
#include "clang/CIR/Dialect/IR/CIRAttrDefsList.inc"
llvm_unreachable("unhandled attribute type");
}

// If the implementation chooses not to implement a certain visit
// method, fall back to the parent.
#define ATTRDEF(NAME) \
RetTy visitCir##NAME(NAME cirAttr) { DISPATCH(Attr); }
#include "clang/CIR/Dialect/IR/CIRAttrDefsList.inc"

RetTy visitCirAttr(mlir::Attribute attr) { return RetTy(); }

ImplClass *getImpl() { return static_cast<ImplClass *>(this); }
};

#undef DISPATCH

} // namespace cir

#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ mlir_tablegen(CIROpsStructs.h.inc -gen-attrdef-decls)
mlir_tablegen(CIROpsStructs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(CIROpsAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(CIROpsAttributes.cpp.inc -gen-attrdef-defs)
mlir_tablegen(CIRAttrDefsList.inc -gen-attrdef-list)
add_public_tablegen_target(MLIRCIREnumsGen)

clang_tablegen(CIRBuiltinsLowering.inc -gen-cir-builtins-lowering
Expand Down
156 changes: 59 additions & 97 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "clang/CIR/Dialect/IR/CIRAttrVisitor.h"
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/LoweringHelpers.h"
#include "clang/CIR/MissingFeatures.h"
Expand Down Expand Up @@ -425,32 +426,52 @@ emitCirAttrToMemory(mlir::Operation *parentOp, mlir::Attribute attr,
}

/// Switches on the type of attribute and calls the appropriate conversion.
class CirAttrToValue : public CirAttrVisitor<CirAttrToValue, mlir::Value> {
public:
CirAttrToValue(mlir::Operation *parentOp,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout)
: parentOp(parentOp), rewriter(rewriter), converter(converter),
dataLayout(dataLayout) {}

mlir::Value visitCirIntAttr(cir::IntAttr attr);
mlir::Value visitCirFPAttr(cir::FPAttr attr);
mlir::Value visitCirConstPtrAttr(cir::ConstPtrAttr attr);
mlir::Value visitCirConstStructAttr(cir::ConstStructAttr attr);
mlir::Value visitCirConstArrayAttr(cir::ConstArrayAttr attr);
mlir::Value visitCirConstVectorAttr(cir::ConstVectorAttr attr);
mlir::Value visitCirBoolAttr(cir::BoolAttr attr);
mlir::Value visitCirZeroAttr(cir::ZeroAttr attr);
mlir::Value visitCirUndefAttr(cir::UndefAttr attr);
mlir::Value visitCirPoisonAttr(cir::PoisonAttr attr);
mlir::Value visitCirGlobalViewAttr(cir::GlobalViewAttr attr);
mlir::Value visitCirVTableAttr(cir::VTableAttr attr);
mlir::Value visitCirTypeInfoAttr(cir::TypeInfoAttr attr);

private:
mlir::Operation *parentOp;
mlir::ConversionPatternRewriter &rewriter;
const mlir::TypeConverter *converter;
mlir::DataLayout const &dataLayout;
};

/// IntAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::IntAttr intAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirIntAttr(cir::IntAttr intAttr) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(intAttr.getType()), intAttr.getValue());
}

/// BoolAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::BoolAttr boolAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirBoolAttr(cir::BoolAttr boolAttr) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(boolAttr.getType()), boolAttr.getValue());
}

/// ConstPtrAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstPtrAttr ptrAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirConstPtrAttr(cir::ConstPtrAttr ptrAttr) {
auto loc = parentOp->getLoc();
if (ptrAttr.isNullValue()) {
return rewriter.create<mlir::LLVM::ZeroOp>(
Expand All @@ -465,51 +486,36 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstPtrAttr ptrAttr,
}

/// FPAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::FPAttr fltAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirFPAttr(cir::FPAttr fltAttr) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
}

/// ZeroAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ZeroAttr zeroAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirZeroAttr(cir::ZeroAttr zeroAttr) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ZeroOp>(
loc, converter->convertType(zeroAttr.getType()));
}

/// UndefAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::UndefAttr undefAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirUndefAttr(cir::UndefAttr undefAttr) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::UndefOp>(
loc, converter->convertType(undefAttr.getType()));
}

/// PoisonAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::PoisonAttr poisonAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirPoisonAttr(cir::PoisonAttr poisonAttr) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::PoisonOp>(
loc, converter->convertType(poisonAttr.getType()));
}

/// ConstStruct visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout) {
mlir::Value
CirAttrToValue::visitCirConstStructAttr(cir::ConstStructAttr constStruct) {
auto llvmTy = converter->convertType(constStruct.getType());
auto loc = parentOp->getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
Expand All @@ -525,49 +531,37 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct,
}

// VTableAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::VTableAttr vtableArr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout) {
mlir::Value CirAttrToValue::visitCirVTableAttr(cir::VTableAttr vtableArr) {
auto llvmTy = converter->convertType(vtableArr.getType());
auto loc = parentOp->getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);

for (auto [idx, elt] : llvm::enumerate(vtableArr.getVtableData())) {
mlir::Value init =
lowerCirAttrAsValue(parentOp, elt, rewriter, converter, dataLayout);
mlir::Value init = visit(elt);
result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
}

return result;
}

// TypeInfoAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::TypeInfoAttr typeinfoArr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout) {
mlir::Value
CirAttrToValue::visitCirTypeInfoAttr(cir::TypeInfoAttr typeinfoArr) {
auto llvmTy = converter->convertType(typeinfoArr.getType());
auto loc = parentOp->getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);

for (auto [idx, elt] : llvm::enumerate(typeinfoArr.getData())) {
mlir::Value init =
lowerCirAttrAsValue(parentOp, elt, rewriter, converter, dataLayout);
mlir::Value init = visit(elt);
result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
}

return result;
}

// ConstArrayAttr visitor
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstArrayAttr constArr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout) {
mlir::Value
CirAttrToValue::visitCirConstArrayAttr(cir::ConstArrayAttr constArr) {
auto llvmTy = converter->convertType(constArr.getType());
auto loc = parentOp->getLoc();
mlir::Value result;
Expand Down Expand Up @@ -610,10 +604,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstArrayAttr constArr,
}

// ConstVectorAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstVectorAttr constVec,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value
CirAttrToValue::visitCirConstVectorAttr(cir::ConstVectorAttr constVec) {
auto llvmTy = converter->convertType(constVec.getType());
auto loc = parentOp->getLoc();
SmallVector<mlir::Attribute> mlirValues;
Expand All @@ -638,11 +630,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstVectorAttr constVec,
}

// GlobalViewAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout) {
mlir::Value
CirAttrToValue::visitCirGlobalViewAttr(cir::GlobalViewAttr globalAttr) {
auto module = parentOp->getParentOfType<mlir::ModuleOp>();
mlir::Type sourceType;
unsigned sourceAddrSpace = 0;
Expand Down Expand Up @@ -716,43 +705,16 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr,
}

/// Switches on the type of attribute and calls the appropriate conversion.
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
const mlir::Attribute attr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout) {
if (const auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr))
return lowerCirAttrAsValue(parentOp, intAttr, rewriter, converter);
if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(attr))
return lowerCirAttrAsValue(parentOp, fltAttr, rewriter, converter);
if (const auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(attr))
return lowerCirAttrAsValue(parentOp, ptrAttr, rewriter, converter);
if (const auto constStruct = mlir::dyn_cast<cir::ConstStructAttr>(attr))
return lowerCirAttrAsValue(parentOp, constStruct, rewriter, converter,
dataLayout);
if (const auto constArr = mlir::dyn_cast<cir::ConstArrayAttr>(attr))
return lowerCirAttrAsValue(parentOp, constArr, rewriter, converter,
dataLayout);
if (const auto constVec = mlir::dyn_cast<cir::ConstVectorAttr>(attr))
return lowerCirAttrAsValue(parentOp, constVec, rewriter, converter);
if (const auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr))
return lowerCirAttrAsValue(parentOp, boolAttr, rewriter, converter);
if (const auto zeroAttr = mlir::dyn_cast<cir::ZeroAttr>(attr))
return lowerCirAttrAsValue(parentOp, zeroAttr, rewriter, converter);
if (const auto undefAttr = mlir::dyn_cast<cir::UndefAttr>(attr))
return lowerCirAttrAsValue(parentOp, undefAttr, rewriter, converter);
if (const auto poisonAttr = mlir::dyn_cast<cir::PoisonAttr>(attr))
return lowerCirAttrAsValue(parentOp, poisonAttr, rewriter, converter);
if (const auto globalAttr = mlir::dyn_cast<cir::GlobalViewAttr>(attr))
return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter,
dataLayout);
if (const auto vtableAttr = mlir::dyn_cast<cir::VTableAttr>(attr))
return lowerCirAttrAsValue(parentOp, vtableAttr, rewriter, converter,
dataLayout);
if (const auto typeinfoAttr = mlir::dyn_cast<cir::TypeInfoAttr>(attr))
return lowerCirAttrAsValue(parentOp, typeinfoAttr, rewriter, converter,
dataLayout);

llvm_unreachable("unhandled attribute type");
CirAttrToValue valueConverter(parentOp, rewriter, converter, dataLayout);
auto value = valueConverter.visit(attr);
if (!value)
llvm_unreachable("unhandled attribute type");
return value;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1734,8 +1696,8 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
// Regardless of the type, we should lower the constant of poison value
// into PoisonOp.
if (auto poisonAttr = mlir::dyn_cast<cir::PoisonAttr>(attr)) {
rewriter.replaceOp(
op, lowerCirAttrAsValue(op, poisonAttr, rewriter, getTypeConverter()));
rewriter.replaceOp(op, lowerCirAttrAsValue(op, poisonAttr, rewriter,
getTypeConverter(), dataLayout));
return mlir::success();
}

Expand Down
3 changes: 2 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ namespace direct {

/// Convert a CIR attribute to an LLVM attribute. May use the datalayout for
/// lowering attributes to-be-stored in memory.
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
const mlir::Attribute attr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout);
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/mlir-tblgen/attrdefs.td
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: mlir-tblgen -gen-attrdef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
// RUN: mlir-tblgen -gen-attrdef-list -I %S/../../include %s | FileCheck %s --check-prefix=LIST

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
Expand All @@ -19,6 +20,13 @@ include "mlir/IR/OpBase.td"
// DEF: ::test::CompoundAAttr,
// DEF: ::test::SingleParameterAttr

// LIST: ATTRDEF(IndexAttr)
// LIST: ATTRDEF(SimpleAAttr)
// LIST: ATTRDEF(CompoundAAttr)
// LIST: ATTRDEF(SingleParameterAttr)

// LIST: #undef ATTRDEF

// DEF-LABEL: ::mlir::OptionalParseResult generatedAttributeParser(
// DEF-SAME: ::mlir::AsmParser &parser,
// DEF-SAME: ::llvm::StringRef *mnemonic, ::mlir::Type type,
Expand Down
Loading

0 comments on commit ae72923

Please sign in to comment.