Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR] Add attribute visitor for lowering globals #1318

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//===- CIRAttrVisitor.h - Visitor for CIR attributes ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the CirAttrVisitor interface.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H

#include "clang/CIR/Dialect/IR/CIRAttrs.h"

namespace cir {

#define DISPATCH(NAME) return getImpl()->visitCir##NAME(cirAttr);

template <typename ImplClass, typename RetTy> class CirAttrVisitor {
public:
RetTy visit(mlir::Attribute attr) {
#define ATTRDEF(NAME) \
if (const auto cirAttr = mlir::dyn_cast<cir::NAME>(attr)) \
DISPATCH(NAME);
#include "clang/CIR/Dialect/IR/CIRAttrDefsList.inc"
llvm_unreachable("unhandled attribute type");
}

// If the implementation chooses not to implement a certain visit
// method, fall back to the parent.
#define ATTRDEF(NAME) \
RetTy visitCir##NAME(NAME cirAttr) { DISPATCH(Attr); }
#include "clang/CIR/Dialect/IR/CIRAttrDefsList.inc"

RetTy visitCirAttr(mlir::Attribute attr) { return RetTy(); }

ImplClass *getImpl() { return static_cast<ImplClass *>(this); }
};

#undef DISPATCH

} // namespace cir

#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ mlir_tablegen(CIROpsStructs.h.inc -gen-attrdef-decls)
mlir_tablegen(CIROpsStructs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(CIROpsAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(CIROpsAttributes.cpp.inc -gen-attrdef-defs)
mlir_tablegen(CIRAttrDefsList.inc -gen-attrdef-list)
add_public_tablegen_target(MLIRCIREnumsGen)

clang_tablegen(CIRBuiltinsLowering.inc -gen-cir-builtins-lowering
Expand Down
156 changes: 59 additions & 97 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "clang/CIR/Dialect/IR/CIRAttrVisitor.h"
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/LoweringHelpers.h"
#include "clang/CIR/MissingFeatures.h"
Expand Down Expand Up @@ -425,32 +426,52 @@ emitCirAttrToMemory(mlir::Operation *parentOp, mlir::Attribute attr,
}

/// Switches on the type of attribute and calls the appropriate conversion.
class CirAttrToValue : public CirAttrVisitor<CirAttrToValue, mlir::Value> {
public:
CirAttrToValue(mlir::Operation *parentOp,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout)
: parentOp(parentOp), rewriter(rewriter), converter(converter),
dataLayout(dataLayout) {}

mlir::Value visitCirIntAttr(cir::IntAttr attr);
mlir::Value visitCirFPAttr(cir::FPAttr attr);
mlir::Value visitCirConstPtrAttr(cir::ConstPtrAttr attr);
mlir::Value visitCirConstStructAttr(cir::ConstStructAttr attr);
mlir::Value visitCirConstArrayAttr(cir::ConstArrayAttr attr);
mlir::Value visitCirConstVectorAttr(cir::ConstVectorAttr attr);
mlir::Value visitCirBoolAttr(cir::BoolAttr attr);
mlir::Value visitCirZeroAttr(cir::ZeroAttr attr);
mlir::Value visitCirUndefAttr(cir::UndefAttr attr);
mlir::Value visitCirPoisonAttr(cir::PoisonAttr attr);
mlir::Value visitCirGlobalViewAttr(cir::GlobalViewAttr attr);
mlir::Value visitCirVTableAttr(cir::VTableAttr attr);
mlir::Value visitCirTypeInfoAttr(cir::TypeInfoAttr attr);

private:
mlir::Operation *parentOp;
mlir::ConversionPatternRewriter &rewriter;
const mlir::TypeConverter *converter;
mlir::DataLayout const &dataLayout;
};

/// IntAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::IntAttr intAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirIntAttr(cir::IntAttr intAttr) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(intAttr.getType()), intAttr.getValue());
}

/// BoolAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::BoolAttr boolAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirBoolAttr(cir::BoolAttr boolAttr) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(boolAttr.getType()), boolAttr.getValue());
}

/// ConstPtrAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstPtrAttr ptrAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirConstPtrAttr(cir::ConstPtrAttr ptrAttr) {
auto loc = parentOp->getLoc();
if (ptrAttr.isNullValue()) {
return rewriter.create<mlir::LLVM::ZeroOp>(
Expand All @@ -465,51 +486,36 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstPtrAttr ptrAttr,
}

/// FPAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::FPAttr fltAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirFPAttr(cir::FPAttr fltAttr) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
}

/// ZeroAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ZeroAttr zeroAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirZeroAttr(cir::ZeroAttr zeroAttr) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ZeroOp>(
loc, converter->convertType(zeroAttr.getType()));
}

/// UndefAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::UndefAttr undefAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirUndefAttr(cir::UndefAttr undefAttr) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::UndefOp>(
loc, converter->convertType(undefAttr.getType()));
}

/// PoisonAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::PoisonAttr poisonAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value CirAttrToValue::visitCirPoisonAttr(cir::PoisonAttr poisonAttr) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::PoisonOp>(
loc, converter->convertType(poisonAttr.getType()));
}

/// ConstStruct visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout) {
mlir::Value
CirAttrToValue::visitCirConstStructAttr(cir::ConstStructAttr constStruct) {
auto llvmTy = converter->convertType(constStruct.getType());
auto loc = parentOp->getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
Expand All @@ -525,49 +531,37 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct,
}

// VTableAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::VTableAttr vtableArr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout) {
mlir::Value CirAttrToValue::visitCirVTableAttr(cir::VTableAttr vtableArr) {
auto llvmTy = converter->convertType(vtableArr.getType());
auto loc = parentOp->getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);

for (auto [idx, elt] : llvm::enumerate(vtableArr.getVtableData())) {
mlir::Value init =
lowerCirAttrAsValue(parentOp, elt, rewriter, converter, dataLayout);
mlir::Value init = visit(elt);
result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
}

return result;
}

// TypeInfoAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::TypeInfoAttr typeinfoArr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout) {
mlir::Value
CirAttrToValue::visitCirTypeInfoAttr(cir::TypeInfoAttr typeinfoArr) {
auto llvmTy = converter->convertType(typeinfoArr.getType());
auto loc = parentOp->getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);

for (auto [idx, elt] : llvm::enumerate(typeinfoArr.getData())) {
mlir::Value init =
lowerCirAttrAsValue(parentOp, elt, rewriter, converter, dataLayout);
mlir::Value init = visit(elt);
result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
}

return result;
}

// ConstArrayAttr visitor
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstArrayAttr constArr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout) {
mlir::Value
CirAttrToValue::visitCirConstArrayAttr(cir::ConstArrayAttr constArr) {
auto llvmTy = converter->convertType(constArr.getType());
auto loc = parentOp->getLoc();
mlir::Value result;
Expand Down Expand Up @@ -610,10 +604,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstArrayAttr constArr,
}

// ConstVectorAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstVectorAttr constVec,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
mlir::Value
CirAttrToValue::visitCirConstVectorAttr(cir::ConstVectorAttr constVec) {
auto llvmTy = converter->convertType(constVec.getType());
auto loc = parentOp->getLoc();
SmallVector<mlir::Attribute> mlirValues;
Expand All @@ -638,11 +630,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstVectorAttr constVec,
}

// GlobalViewAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout) {
mlir::Value
CirAttrToValue::visitCirGlobalViewAttr(cir::GlobalViewAttr globalAttr) {
auto module = parentOp->getParentOfType<mlir::ModuleOp>();
mlir::Type sourceType;
unsigned sourceAddrSpace = 0;
Expand Down Expand Up @@ -716,43 +705,16 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr,
}

/// Switches on the type of attribute and calls the appropriate conversion.
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
const mlir::Attribute attr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout) {
if (const auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr))
return lowerCirAttrAsValue(parentOp, intAttr, rewriter, converter);
if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(attr))
return lowerCirAttrAsValue(parentOp, fltAttr, rewriter, converter);
if (const auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(attr))
return lowerCirAttrAsValue(parentOp, ptrAttr, rewriter, converter);
if (const auto constStruct = mlir::dyn_cast<cir::ConstStructAttr>(attr))
return lowerCirAttrAsValue(parentOp, constStruct, rewriter, converter,
dataLayout);
if (const auto constArr = mlir::dyn_cast<cir::ConstArrayAttr>(attr))
return lowerCirAttrAsValue(parentOp, constArr, rewriter, converter,
dataLayout);
if (const auto constVec = mlir::dyn_cast<cir::ConstVectorAttr>(attr))
return lowerCirAttrAsValue(parentOp, constVec, rewriter, converter);
if (const auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr))
return lowerCirAttrAsValue(parentOp, boolAttr, rewriter, converter);
if (const auto zeroAttr = mlir::dyn_cast<cir::ZeroAttr>(attr))
return lowerCirAttrAsValue(parentOp, zeroAttr, rewriter, converter);
if (const auto undefAttr = mlir::dyn_cast<cir::UndefAttr>(attr))
return lowerCirAttrAsValue(parentOp, undefAttr, rewriter, converter);
if (const auto poisonAttr = mlir::dyn_cast<cir::PoisonAttr>(attr))
return lowerCirAttrAsValue(parentOp, poisonAttr, rewriter, converter);
if (const auto globalAttr = mlir::dyn_cast<cir::GlobalViewAttr>(attr))
return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter,
dataLayout);
if (const auto vtableAttr = mlir::dyn_cast<cir::VTableAttr>(attr))
return lowerCirAttrAsValue(parentOp, vtableAttr, rewriter, converter,
dataLayout);
if (const auto typeinfoAttr = mlir::dyn_cast<cir::TypeInfoAttr>(attr))
return lowerCirAttrAsValue(parentOp, typeinfoAttr, rewriter, converter,
dataLayout);

llvm_unreachable("unhandled attribute type");
CirAttrToValue valueConverter(parentOp, rewriter, converter, dataLayout);
auto value = valueConverter.visit(attr);
if (!value)
llvm_unreachable("unhandled attribute type");
return value;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1734,8 +1696,8 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
// Regardless of the type, we should lower the constant of poison value
// into PoisonOp.
if (auto poisonAttr = mlir::dyn_cast<cir::PoisonAttr>(attr)) {
rewriter.replaceOp(
op, lowerCirAttrAsValue(op, poisonAttr, rewriter, getTypeConverter()));
rewriter.replaceOp(op, lowerCirAttrAsValue(op, poisonAttr, rewriter,
getTypeConverter(), dataLayout));
return mlir::success();
}

Expand Down
3 changes: 2 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ namespace direct {

/// Convert a CIR attribute to an LLVM attribute. May use the datalayout for
/// lowering attributes to-be-stored in memory.
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
const mlir::Attribute attr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout);
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/mlir-tblgen/attrdefs.td
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: mlir-tblgen -gen-attrdef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
// RUN: mlir-tblgen -gen-attrdef-list -I %S/../../include %s | FileCheck %s --check-prefix=LIST

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
Expand All @@ -19,6 +20,13 @@ include "mlir/IR/OpBase.td"
// DEF: ::test::CompoundAAttr,
// DEF: ::test::SingleParameterAttr

// LIST: ATTRDEF(IndexAttr)
// LIST: ATTRDEF(SimpleAAttr)
// LIST: ATTRDEF(CompoundAAttr)
// LIST: ATTRDEF(SingleParameterAttr)

// LIST: #undef ATTRDEF

// DEF-LABEL: ::mlir::OptionalParseResult generatedAttributeParser(
// DEF-SAME: ::mlir::AsmParser &parser,
// DEF-SAME: ::llvm::StringRef *mnemonic, ::mlir::Type type,
Expand Down
Loading