Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/circt/Dialect/OM/OMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ class IntegerBinaryArithmeticOp<string mnemonic, list<Trait> traits = []> :
let results = (outs OMIntegerType:$result);

let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
let hasFolder = 1;
}

def IntegerAddOp : IntegerBinaryArithmeticOp<"integer.add", [Commutative]> {
Expand Down
27 changes: 7 additions & 20 deletions lib/Dialect/OM/Evaluator/Evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,28 +491,15 @@ circt::om::Evaluator::evaluateIntegerBinaryArithmetic(
assert(lhs && rhs &&
"expected om::IntegerAttr for IntegerBinaryArithmeticOp operands");

// Extend values if necessary to match bitwidth. Most interesting arithmetic
// on APSInt asserts that both operands are the same bitwidth, but the
// IntegerAttrs we are working with may have used the smallest necessary
// bitwidth to represent the number they hold, and won't necessarily match.
APSInt lhsVal = lhs.getValue().getAPSInt();
APSInt rhsVal = rhs.getValue().getAPSInt();
if (lhsVal.getBitWidth() > rhsVal.getBitWidth())
rhsVal = rhsVal.extend(lhsVal.getBitWidth());
else if (rhsVal.getBitWidth() > lhsVal.getBitWidth())
lhsVal = lhsVal.extend(rhsVal.getBitWidth());

// Perform arbitrary precision signed integer binary arithmetic.
FailureOr<APSInt> result = op.evaluateIntegerOperation(lhsVal, rhsVal);

if (failed(result))
std::array<Attribute, 2> operandAttrs = {lhs, rhs};
SmallVector<mlir::OpFoldResult, 1> results;
om::IntegerAttr resultAttr;
auto foldResult = op->fold(operandAttrs, results);
if (failed(foldResult) || results.size() != 1 ||
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forget how this works. If the folder doesn't apply, then that returns {}. That won't trigger the failure here, right?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If folder doesn't apply, it returns LogicalResult::failure so it returns an error below (results will be empty)

!(resultAttr = llvm::dyn_cast_or_null<om::IntegerAttr>(
results[0].dyn_cast<Attribute>())))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these checks necessary/reachable given the constraints of the fold function?

Copy link
Copy Markdown
Member Author

@uenoku uenoku Apr 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the current implementation of folders that implement IntegerBinaryArithmetic, yes it's certainly not reachable.

But I don't think this check is unnecessary since folder could return a value (even when all operands are attributes), and without this check we would get assertion failure or nullptr deference (e.g. string_concat didn't return an attribute when it had single operand without this PR).

This part https://github.com/llvm/circt/pull/10265/changes#r3105827722 is more generic version of this with a slightly better error message.

return op->emitError("failed to evaluate integer operation");

// Package the result as a new om::IntegerAttr.
MLIRContext *ctx = op->getContext();
auto resultAttr =
om::IntegerAttr::get(ctx, mlir::IntegerAttr::get(ctx, result.value()));

// Finalize the op result value.
auto *handleValue = cast<evaluator::AttributeValue>(handle.value().get());
auto resultStatus = handleValue->setAttr(resultAttr);
Expand Down
56 changes: 55 additions & 1 deletion lib/Dialect/OM/OMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "circt/Dialect/OM/OMOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/OM/OMOpInterfaces.h"
#include "circt/Dialect/OM/OMUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
Expand Down Expand Up @@ -639,6 +640,39 @@ PathCreateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}

//===----------------------------------------------------------------------===//
// IntegerBinaryArithmeticOp
//===----------------------------------------------------------------------===//

static OpFoldResult foldIntegerBinaryArithmetic(IntegerBinaryArithmeticOp op,
Attribute lhsAttr,
Attribute rhsAttr) {
auto lhs = dyn_cast_or_null<circt::om::IntegerAttr>(lhsAttr);
auto rhs = dyn_cast_or_null<circt::om::IntegerAttr>(rhsAttr);
if (!lhs || !rhs)
return {};
// Extend values if necessary to match bitwidth. Most interesting arithmetic
// on APSInt asserts that both operands are the same bitwidth, but the
// IntegerAttrs we are working with may have used the smallest necessary
// bitwidth to represent the number they hold, and won't necessarily match.
APSInt lhsVal = lhs.getValue().getAPSInt();
APSInt rhsVal = rhs.getValue().getAPSInt();
if (lhsVal.getBitWidth() > rhsVal.getBitWidth())
rhsVal = rhsVal.extend(lhsVal.getBitWidth());
else if (rhsVal.getBitWidth() > lhsVal.getBitWidth())
lhsVal = lhsVal.extend(rhsVal.getBitWidth());

// Perform arbitrary precision signed integer binary arithmetic.
auto result = op.evaluateIntegerOperation(lhsVal, rhsVal);
if (failed(result))
return {};

auto *ctx = op.getContext();
// Return the result as a new om::IntegerAttr.
return circt::om::IntegerAttr::get(
ctx, mlir::IntegerAttr::get(ctx, result.value()));
}

//===----------------------------------------------------------------------===//
// IntegerAddOp
//===----------------------------------------------------------------------===//
Expand All @@ -649,6 +683,10 @@ IntegerAddOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
return success(lhs + rhs);
}

OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
return foldIntegerBinaryArithmetic(*this, adaptor.getLhs(), adaptor.getRhs());
}

//===----------------------------------------------------------------------===//
// IntegerMulOp
//===----------------------------------------------------------------------===//
Expand All @@ -659,6 +697,10 @@ IntegerMulOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
return success(lhs * rhs);
}

OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
return foldIntegerBinaryArithmetic(*this, adaptor.getLhs(), adaptor.getRhs());
}

//===----------------------------------------------------------------------===//
// IntegerShrOp
//===----------------------------------------------------------------------===//
Expand All @@ -675,6 +717,10 @@ IntegerShrOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
return success(lhs >> rhs.getExtValue());
}

OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
return foldIntegerBinaryArithmetic(*this, adaptor.getLhs(), adaptor.getRhs());
}

//===----------------------------------------------------------------------===//
// IntegerShlOp
//===----------------------------------------------------------------------===//
Expand All @@ -691,14 +737,22 @@ IntegerShlOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
return success(lhs << rhs.getExtValue());
}

OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
return foldIntegerBinaryArithmetic(*this, adaptor.getLhs(), adaptor.getRhs());
}

//===----------------------------------------------------------------------===//
// StringConcatOp
//===----------------------------------------------------------------------===//

OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
// Fold single-operand concat to just the operand.
if (getStrings().size() == 1)
if (getStrings().size() == 1) {
if (auto strAttr = adaptor.getStrings()[0])
return strAttr;
Comment thread
seldridge marked this conversation as resolved.

return getStrings()[0];
}

// Check if all operands are constant strings before accumulating.
if (!llvm::all_of(adaptor.getStrings(), [](Attribute operand) {
Expand Down
41 changes: 38 additions & 3 deletions test/Dialect/OM/canonicalizers.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func.func @ObjectsMustDCE() {
return
}

om.class @StringConcatCanonicalization(%str1: !om.string, %str2: !om.string) -> (out1: !om.string, out2: !om.string, out3: !om.string, out4: !om.string, out5: !om.string, out6: !om.string, out7: !om.string) {
om.class @StringConcatCanonicalization(%str1: !om.string, %str2: !om.string) -> (out1: !om.string, out2: !om.string, out3: !om.string, out4: !om.string, out5: !om.string, out6: !om.string, out7: !om.string, out8: !om.string) {
Comment thread
uenoku marked this conversation as resolved.
Outdated
%s1 = om.constant "Hello" : !om.string
%s2 = om.constant "World" : !om.string
%s3 = om.constant "!" : !om.string
Expand All @@ -43,6 +43,9 @@ om.class @StringConcatCanonicalization(%str1: !om.string, %str2: !om.string) ->
// Single operand replaced with operand
%2 = om.string.concat %s1 : !om.string

// Single constant operand folds to the attribute.
%singleConst = om.string.concat %s3 : !om.string

// Empty concat
%3 = om.string.concat %empty, %empty : !om.string

Expand All @@ -57,8 +60,40 @@ om.class @StringConcatCanonicalization(%str1: !om.string, %str2: !om.string) ->
%nested = om.string.concat %str1, %str2 : !om.string
%concat1 = om.string.concat %nested, %s3 : !om.string

// CHECK: om.class.fields [[HELLOWORLD]], [[HELLO]], [[HELLO]], [[EMPTY]], [[HELLOWORLD]], [[CONCAT1]], [[NESTED]]
om.class.fields %0, %1, %2, %3, %5, %concat1, %nested : !om.string, !om.string, !om.string, !om.string, !om.string, !om.string, !om.string
// CHECK: om.class.fields [[HELLOWORLD]], [[HELLO]], [[HELLO]], [[CONST]], [[EMPTY]], [[HELLOWORLD]], [[CONCAT1]], [[NESTED]]
om.class.fields %0, %1, %2, %singleConst, %3, %5, %concat1, %nested : !om.string, !om.string, !om.string, !om.string, !om.string, !om.string, !om.string, !om.string
}

// CHECK-LABEL: @IntegerBinaryArithmeticFold
om.class @IntegerBinaryArithmeticFold(%x: !om.integer) -> (out1: !om.integer, out2: !om.integer,
out3: !om.integer, out4: !om.integer,
out5: !om.integer, out6: !om.integer) {
%i3 = om.constant #om.integer<3 : si4> : !om.integer
%i4 = om.constant #om.integer<4 : si4> : !om.integer
%i2 = om.constant #om.integer<2 : si4> : !om.integer
%neg1 = om.constant #om.integer<-1 : si4> : !om.integer
%i1 = om.constant #om.integer<1 : si4> : !om.integer
%wide = om.constant #om.integer<7 : si6> : !om.integer

// CHECK-DAG: [[ADD:%.+]] = om.constant #om.integer<7 : si4> : !om.integer
// CHECK-DAG: [[MUL:%.+]] = om.constant #om.integer<-4 : si4> : !om.integer
// CHECK-DAG: [[SHR:%.+]] = om.constant #om.integer<1 : si4> : !om.integer
// CHECK-DAG: [[SHL:%.+]] = om.constant #om.integer<-2 : si4> : !om.integer
// CHECK-DAG: [[WIDEADD:%.+]] = om.constant #om.integer<9 : si6> : !om.integer
// CHECK: [[DYN:%.+]] = om.integer.add %x, %{{.+}} : !om.integer
%0 = om.integer.add %i3, %i4 : !om.integer
%1 = om.integer.mul %i3, %i4 : !om.integer
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This folding to -4 is right in twos-complement, but may produce unexpected behavior if we're not careful as I think we view OM integers as arbitrary precision. I think the answer to this is just that this is how this works.

%2 = om.integer.shr %i4, %i2 : !om.integer
%3 = om.integer.shl %neg1, %i1 : !om.integer

// Mixed bit widths should still fold after extending operands.
%4 = om.integer.add %i2, %wide : !om.integer

// Non-constant operands should remain.
%5 = om.integer.add %x, %i1 : !om.integer

// CHECK: om.class.fields [[ADD]], [[MUL]], [[SHR]], [[SHL]], [[WIDEADD]], [[DYN]]
om.class.fields %0, %1, %2, %3, %4, %5 : !om.integer, !om.integer, !om.integer, !om.integer, !om.integer, !om.integer
}

// CHECK-LABEL: @PropEqFold
Expand Down
Loading