From 6123679448ef6f227df9f9e166a7aef51c0f7703 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Fri, 17 Apr 2026 18:05:18 -0700 Subject: [PATCH 1/3] [OM] Introduce explicit evaluator state and op patterns Refactor the OM evaluator around explicit Ready/Pending/Failure state tracking and a shared operation pattern framework. Centralize reference resolution, operand readiness/unknown propagation, and placeholder creation, while moving op-specific logic into typed patterns. Also add the supporting evaluator regressions for unknown nested fields and reference/value propagation. --- .../circt/Dialect/OM/Evaluator/Evaluator.h | 98 +- lib/Dialect/OM/Evaluator/Evaluator.cpp | 1266 ++++++++++------- .../Dialect/OM/Evaluator/EvaluatorTests.cpp | 108 ++ 3 files changed, 881 insertions(+), 591 deletions(-) diff --git a/include/circt/Dialect/OM/Evaluator/Evaluator.h b/include/circt/Dialect/OM/Evaluator/Evaluator.h index 06ee3057400f..b8f5d2a034a3 100644 --- a/include/circt/Dialect/OM/Evaluator/Evaluator.h +++ b/include/circt/Dialect/OM/Evaluator/Evaluator.h @@ -37,6 +37,36 @@ class EvaluatorValue; /// primitive Attribute. Further refinement is expected. using EvaluatorValuePtr = std::shared_ptr; +/// The evaluation state of a value handle. +enum class ResolutionState { + /// The handle can be used now. For references, this means the whole + /// reference chain leads to a fully evaluated value. + Ready, + /// Evaluation is not done yet. The handle itself may still be partial, or a + /// reference in the chain may still be missing. + Pending, + /// Evaluation hit a hard error, such as a reference cycle. + Failure +}; + +struct ResolvedValue { + /// `state` says whether `value` can be used. `value` keeps the original + /// handle so callers can keep passing it around even when it is still pending + /// or has already failed. + ResolutionState state; + EvaluatorValuePtr value; + + static ResolvedValue ready(EvaluatorValuePtr value) { + return {ResolutionState::Ready, std::move(value)}; + } + static ResolvedValue pending(EvaluatorValuePtr value = nullptr) { + return {ResolutionState::Pending, std::move(value)}; + } + static ResolvedValue failure(EvaluatorValuePtr value = nullptr) { + return {ResolutionState::Failure, std::move(value)}; + } +}; + /// The fields of a composite Object, currently represented as a map. Further /// refinement is expected. using ObjectFields = SmallDenseMap; @@ -54,7 +84,9 @@ class EvaluatorValue : public std::enable_shared_from_this { Kind getKind() const { return kind; } MLIRContext *getContext() const { return ctx; } - // Return true the value is fully evaluated. + // Return true if this value object has finished its own local work. + // This is not the same as semantic Ready/Pending state: for example, a + // ReferenceValue can be fully evaluated but still point to a pending value. // Unknown values are considered fully evaluated. bool isFullyEvaluated() const { return fullyEvaluated; } void markFullyEvaluated() { @@ -130,17 +162,7 @@ class ReferenceValue : public EvaluatorValue { LogicalResult finalizeImpl(); // Return the first non-reference value that is reachable from the reference. - FailureOr getStrippedValue() const { - llvm::SmallPtrSet visited; - auto currentValue = value; - while (auto *v = dyn_cast(currentValue.get())) { - // Detect a cycle. - if (!visited.insert(v).second) - return failure(); - currentValue = v->getValue(); - } - return success(currentValue); - } + FailureOr getStrippedValue() const; private: EvaluatorValuePtr value; @@ -410,53 +432,25 @@ class Evaluator { /// Evaluate a Value in a Class body according to the small expression grammar /// described in the rationale document. The actual parameters are the values /// supplied at the current instantiation of the Class being evaluated. - FailureOr + evaluator::ResolvedValue evaluateValue(Value value, ActualParameters actualParams, Location loc); /// Evaluator dispatch functions for the small expression grammar. - FailureOr evaluateParameter(BlockArgument formalParam, - ActualParameters actualParams, - Location loc); - - FailureOr - evaluateConstant(ConstantOp op, ActualParameters actualParams, Location loc); - - FailureOr - evaluateIntegerBinaryArithmetic(IntegerBinaryArithmeticOp op, - ActualParameters actualParams, Location loc); + evaluator::ResolvedValue evaluateParameter(BlockArgument formalParam, + ActualParameters actualParams, + Location loc); /// Instantiate an Object with its class name and actual parameters. FailureOr evaluateObjectInstance(StringAttr className, ActualParameters actualParams, Location loc, ObjectKey instanceKey = {}); - FailureOr + evaluator::ResolvedValue evaluateObjectInstance(ObjectOp op, ActualParameters actualParams); - FailureOr - evaluateObjectField(ObjectFieldOp op, ActualParameters actualParams, - Location loc); - FailureOr evaluateListCreate(ListCreateOp op, - ActualParameters actualParams, - Location loc); - FailureOr evaluateListConcat(ListConcatOp op, - ActualParameters actualParams, - Location loc); - FailureOr - evaluateStringConcat(StringConcatOp op, ActualParameters actualParams, - Location loc); - FailureOr - evaluateBinaryEquality(BinaryEqualityOp op, ActualParameters actualParams, - Location loc); - FailureOr - evaluateBasePathCreate(FrozenBasePathCreateOp op, - ActualParameters actualParams, Location loc); - FailureOr - evaluatePathCreate(FrozenPathCreateOp op, ActualParameters actualParams, - Location loc); - FailureOr - evaluateEmptyPath(FrozenEmptyPathOp op, ActualParameters actualParams, - Location loc); - FailureOr - evaluateUnknownValue(UnknownValueOp op, Location loc); + evaluator::ResolvedValue evaluateObjectField(ObjectFieldOp op, + ActualParameters actualParams, + Location loc); + evaluator::ResolvedValue evaluateUnknownValue(UnknownValueOp op, + Location loc); LogicalResult evaluatePropertyAssert(PropertyAssertOp op, ActualParameters actualParams); @@ -483,6 +477,10 @@ class Evaluator { /// Evaluator value storage. Return an evaluator value for the given /// instantiation context (a pair of Value and parameters). DenseMap> objects; + + /// Tracks object instantiations currently being evaluated so recursive + /// object graphs reuse the existing placeholder instead of recursing. + llvm::SmallDenseSet activeObjectInstances; }; /// Helper to enable printing objects in Diagnostics. diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index 916f968779da..a7a5b691a11b 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -14,15 +14,576 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/Location.h" #include "mlir/IR/SymbolTable.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include +#include #define DEBUG_TYPE "om-evaluator" using namespace mlir; using namespace circt::om; +namespace { + +using ResolutionState = evaluator::ResolutionState; +using ResolvedValue = evaluator::ResolvedValue; + +//===----------------------------------------------------------------------===// +// Resolved value helpers +//===----------------------------------------------------------------------===// + +/// Walk through reference values until we reach a non-reference value. +/// Return Pending if the chain ends at null. Return Failure if the chain loops. +static ResolvedValue +resolveReferenceValue(evaluator::EvaluatorValuePtr currentValue) { + llvm::SmallPtrSet visited; + if (!currentValue) + return ResolvedValue::pending(); + + while (auto *ref = + llvm::dyn_cast(currentValue.get())) { + if (!visited.insert(ref).second) + return ResolvedValue::failure(currentValue); + currentValue = ref->getValue(); + if (!currentValue) + return ResolvedValue::pending(); + } + + return ResolvedValue::ready(std::move(currentValue)); +} + +/// Say whether this handle is ready to use, still pending, or failed. +/// Keep the original handle in the result either way. +static ResolvedValue +resolveValueState(evaluator::EvaluatorValuePtr currentValue) { + if (!currentValue || !currentValue->isFullyEvaluated()) + return ResolvedValue::pending(std::move(currentValue)); + + auto resolved = resolveReferenceValue(currentValue); + if (resolved.state != ResolutionState::Ready) + return {resolved.state, std::move(currentValue)}; + if (!resolved.value->isFullyEvaluated()) + return ResolvedValue::pending(std::move(currentValue)); + + return ResolvedValue::ready(std::move(currentValue)); +} + +//===----------------------------------------------------------------------===// +// Ready value helpers +//===----------------------------------------------------------------------===// + +static evaluator::EvaluatorValue * +resolveReadyValue(evaluator::EvaluatorValuePtr value) { + assert(value); + auto resolved = resolveReferenceValue(value); + assert(resolved.state == ResolutionState::Ready); + assert(resolved.value && resolved.value->isFullyEvaluated()); + return resolved.value.get(); +} + +template +static ValueT *getReadyAs(evaluator::EvaluatorValuePtr value) { + auto *typedValue = + llvm::dyn_cast(resolveReadyValue(std::move(value))); + assert(typedValue); + return typedValue; +} + +template +static AttrT getAsAttr(evaluator::EvaluatorValuePtr value) { + return llvm::dyn_cast( + getReadyAs(std::move(value))->getAttr()); +} + +static bool isUnknownReadyValue(evaluator::EvaluatorValuePtr value) { + return (value && value->isUnknown()) || resolveReadyValue(value)->isUnknown(); +} + +//===----------------------------------------------------------------------===// +// Operand resolution helpers +//===----------------------------------------------------------------------===// + +/// If the operand is ready, write it to `readyValue`. Otherwise return the +/// pending or failure state right away. +static std::optional +requireReady(const ResolvedValue &resolved, + evaluator::EvaluatorValuePtr pendingValue, + llvm::function_ref emitFailure, + evaluator::EvaluatorValuePtr &readyValue) { + switch (resolved.state) { + case ResolutionState::Pending: + return ResolvedValue::pending(std::move(pendingValue)); + case ResolutionState::Failure: + emitFailure(); + return ResolvedValue::failure(); + case ResolutionState::Ready: + readyValue = resolved.value; + return std::nullopt; + } + llvm_unreachable("unknown resolution state"); +} + +/// Resolve every operand before running the pattern. Stop right away if any +/// operand is pending or failed. Also record whether any operand is unknown so +/// the shared ready-operands path can mark the result unknown without running +/// the op-specific code. +static std::optional requireAllOperandsReady( + ValueRange operands, evaluator::EvaluatorValuePtr pendingValue, + llvm::function_ref evaluateOperand, + llvm::function_ref emitFailure, + SmallVectorImpl &readyOperands, + bool &existsUnknown) { + readyOperands.clear(); + readyOperands.reserve(operands.size()); + existsUnknown = false; + + for (auto operand : operands) { + evaluator::EvaluatorValuePtr readyOperand; + if (auto early = requireReady(evaluateOperand(operand), pendingValue, + emitFailure, readyOperand)) + return *early; + existsUnknown |= isUnknownReadyValue(readyOperand); + readyOperands.push_back(std::move(readyOperand)); + } + + return std::nullopt; +} + +//===----------------------------------------------------------------------===// +// Result mutation helpers +//===----------------------------------------------------------------------===// + +static ResolvedValue markUnknownAndReturn(evaluator::EvaluatorValuePtr value) { + value->markUnknown(); + return resolveValueState(std::move(value)); +} + +static LogicalResult setAttrResult(evaluator::EvaluatorValuePtr resultValue, + Attribute attr) { + auto *attrValue = cast(resultValue.get()); + if (failed(attrValue->setAttr(attr)) || failed(attrValue->finalize())) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// Operation pattern infrastructure +//===----------------------------------------------------------------------===// + +/// Base class for one OM operation in the evaluator. +/// A pattern creates a placeholder for the result and later fills it in. +class OperationPattern { +public: + using CreatePartialValueFn = + llvm::function_ref(Type, + Location)>; + using GetValueHandleFn = + llvm::function_ref(Value, + Location)>; + + explicit OperationPattern(StringRef operationName) + : operationName(operationName) {} + virtual ~OperationPattern() = default; + + StringRef getOperationName() const { return operationName; } + virtual FailureOr + createPlaceholder(Operation *op, Value value, + CreatePartialValueFn createPartialValue, + GetValueHandleFn getValueHandle, Location loc) const { + return createPartialValue(value.getType(), loc); + } + + virtual ResolvedValue + evaluate(Operation *op, evaluator::EvaluatorValuePtr resultValue, + llvm::function_ref evaluateValue, + Location loc) const = 0; + +private: + StringRef operationName; +}; + +template +class OpPattern : public OperationPattern { +public: + using OperationPattern::OperationPattern; + + ResolvedValue evaluate(Operation *op, + evaluator::EvaluatorValuePtr resultValue, + llvm::function_ref evaluateValue, + Location loc) const final { + return evaluateTyped(cast(op), std::move(resultValue), evaluateValue, + loc); + } + + FailureOr createPlaceholder( + Operation *op, Value value, CreatePartialValueFn createPartialValue, + GetValueHandleFn getValueHandle, Location loc) const override { + return createTypedPlaceholder(cast(op), value, createPartialValue, + getValueHandle, loc); + } + +protected: + virtual FailureOr + createTypedPlaceholder(OpT op, Value value, + CreatePartialValueFn createPartialValue, + GetValueHandleFn getValueHandle, Location loc) const { + return OperationPattern::createPlaceholder(op, value, createPartialValue, + getValueHandle, loc); + } + + virtual ResolvedValue + evaluateTyped(OpT op, evaluator::EvaluatorValuePtr resultValue, + llvm::function_ref evaluateValue, + Location loc) const { + return resolveValueState(std::move(resultValue)); + } +}; + +/// Base class for operations that only run once all operands are ready. +/// This handles the shared ready/pending/failure/unknown logic so concrete +/// patterns only implement the successful case. +template +class OpReadyOperandsPattern : public OperationPattern { +public: + using OperationPattern::OperationPattern; + +private: + ResolvedValue evaluate(Operation *op, + evaluator::EvaluatorValuePtr resultValue, + llvm::function_ref evaluateValue, + Location loc) const final { + if (resultValue && resultValue->isFullyEvaluated()) + return resolveValueState(std::move(resultValue)); + + SmallVector readyOperands; + bool existsUnknown = false; + if (auto early = requireAllOperandsReady( + op->getOperands(), resultValue, evaluateValue, + [&] { + op->emitError() + << "failed to resolve " << getOperationName() << " operand"; + }, + readyOperands, existsUnknown)) + return *early; + // If any operand is unknown, the result is unknown too. + if (existsUnknown) + return markUnknownAndReturn(std::move(resultValue)); + + if (failed(evaluateTyped(cast(op), readyOperands, resultValue, loc))) + return ResolvedValue::failure(); + return resolveValueState(std::move(resultValue)); + } + + FailureOr + createPlaceholder(Operation *op, Value value, + CreatePartialValueFn createPartialValue, + GetValueHandleFn getValueHandle, Location loc) const final { + return createTypedPlaceholder(cast(op), value, createPartialValue, + getValueHandle, loc); + } + +protected: + virtual FailureOr + createTypedPlaceholder(OpT op, Value value, + CreatePartialValueFn createPartialValue, + GetValueHandleFn getValueHandle, Location loc) const { + return OperationPattern::createPlaceholder(op, value, createPartialValue, + getValueHandle, loc); + } + + virtual LogicalResult + evaluateTyped(OpT op, ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const = 0; +}; + +//===----------------------------------------------------------------------===// +// Operation patterns +//===----------------------------------------------------------------------===// + +class IntegerBinaryArithmeticPattern final + : public OpReadyOperandsPattern { +public: + using OpReadyOperandsPattern::OpReadyOperandsPattern; + +protected: + LogicalResult evaluateTyped(IntegerBinaryArithmeticOp op, + ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const override { + assert(operands.size() == 2 && "expected binary arithmetic operands"); + + circt::om::IntegerAttr lhs = getAsAttr(operands[0]); + circt::om::IntegerAttr rhs = getAsAttr(operands[1]); + assert(lhs && "expected om::IntegerAttr for IntegerBinaryArithmeticOp lhs"); + assert(rhs && "expected om::IntegerAttr for IntegerBinaryArithmeticOp rhs"); + + APSInt lhsVal = lhs.getValue().getAPSInt(); + APSInt rhsVal = rhs.getValue().getAPSInt(); + if (lhsVal.getBitWidth() > rhsVal.getBitWidth()) + rhsVal = rhsVal.extend(lhsVal.getBitWidth()); + else if (rhsVal.getBitWidth() > lhsVal.getBitWidth()) + lhsVal = lhsVal.extend(rhsVal.getBitWidth()); + + FailureOr result = op.evaluateIntegerOperation(lhsVal, rhsVal); + if (failed(result)) + return op->emitError("failed to evaluate integer operation"); + + MLIRContext *ctx = op.getContext(); + auto resultAttr = circt::om::IntegerAttr::get( + ctx, mlir::IntegerAttr::get(ctx, result.value())); + return setAttrResult(std::move(resultValue), resultAttr); + } +}; + +class ListCreatePattern final : public OpReadyOperandsPattern { +public: + using OpReadyOperandsPattern::OpReadyOperandsPattern; + +protected: + LogicalResult evaluateTyped(ListCreateOp op, + ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const override { + SmallVector values; + values.reserve(operands.size()); + for (auto operand : operands) + values.push_back(operand); + + cast(resultValue.get()) + ->setElements(std::move(values)); + return success(); + } +}; + +class ListConcatPattern final : public OpReadyOperandsPattern { +public: + using OpReadyOperandsPattern::OpReadyOperandsPattern; + +protected: + LogicalResult evaluateTyped(ListConcatOp op, + ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const override { + SmallVector values; + for (auto operand : operands) { + auto *subListValue = getReadyAs(operand); + llvm::append_range(values, subListValue->getElements()); + } + + cast(resultValue.get()) + ->setElements(std::move(values)); + return success(); + } +}; + +class StringConcatPattern final + : public OpReadyOperandsPattern { +public: + using OpReadyOperandsPattern::OpReadyOperandsPattern; + +protected: + LogicalResult evaluateTyped(StringConcatOp op, + ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const override { + std::string result; + for (auto operand : operands) { + auto attr = getAsAttr(operand); + assert(attr && "expected StringAttr for StringConcatOp operand"); + result += attr.getValue().str(); + } + + auto resultStr = StringAttr::get(result, op.getResult().getType()); + return setAttrResult(std::move(resultValue), resultStr); + } +}; + +class BinaryEqualityPattern final + : public OpReadyOperandsPattern { +public: + using OpReadyOperandsPattern::OpReadyOperandsPattern; + +protected: + LogicalResult evaluateTyped(BinaryEqualityOp op, + ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const override { + assert(operands.size() == 2 && "expected binary equality operands"); + + mlir::Attribute lhs = getAsAttr(operands[0]); + mlir::Attribute rhs = getAsAttr(operands[1]); + + FailureOr result = op.evaluateBinaryEquality(lhs, rhs); + if (failed(result)) + return op->emitError("failed to evaluate binary equality operation"); + return setAttrResult(std::move(resultValue), *result); + } +}; + +class FrozenBasePathCreatePattern final + : public OpReadyOperandsPattern { +public: + using OpReadyOperandsPattern::OpReadyOperandsPattern; + +protected: + FailureOr + createTypedPlaceholder(FrozenBasePathCreateOp op, Value value, + CreatePartialValueFn createPartialValue, + GetValueHandleFn getValueHandle, + Location loc) const override { + return success( + std::make_shared(op.getPathAttr(), loc)); + } + + LogicalResult evaluateTyped(FrozenBasePathCreateOp op, + ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const override { + assert(operands.size() == 1 && + "expected one operand for frozenbasepath_create"); + + auto *basePathValue = + getReadyAs(operands.front()); + cast(resultValue.get()) + ->setBasepath(*basePathValue); + return success(); + } +}; + +class FrozenPathCreatePattern final + : public OpReadyOperandsPattern { +public: + using OpReadyOperandsPattern::OpReadyOperandsPattern; + +protected: + FailureOr + createTypedPlaceholder(FrozenPathCreateOp pathOp, Value value, + CreatePartialValueFn createPartialValue, + GetValueHandleFn getValueHandle, + Location loc) const override { + return success(std::make_shared( + pathOp.getTargetKindAttr(), pathOp.getPathAttr(), + pathOp.getModuleAttr(), pathOp.getRefAttr(), pathOp.getFieldAttr(), + loc)); + } + + LogicalResult evaluateTyped(FrozenPathCreateOp op, + ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const override { + assert(operands.size() == 1 && + "expected one operand for frozenpath_create"); + + auto *basePathValue = + getReadyAs(operands.front()); + cast(resultValue.get())->setBasepath(*basePathValue); + return success(); + } +}; + +class ConstantPattern final : public OpPattern { +public: + using OpPattern::OpPattern; + +protected: + FailureOr createTypedPlaceholder( + ConstantOp op, Value value, CreatePartialValueFn createPartialValue, + GetValueHandleFn getValueHandle, Location loc) const override { + return success( + circt::om::evaluator::AttributeValue::get(op.getValue(), loc)); + } +}; + +class AnyCastPattern final : public OpPattern { +public: + using OpPattern::OpPattern; + +protected: + FailureOr createTypedPlaceholder( + AnyCastOp op, Value value, CreatePartialValueFn createPartialValue, + GetValueHandleFn getValueHandle, Location loc) const override { + return getValueHandle(op.getInput(), loc); + } + + ResolvedValue + evaluateTyped(AnyCastOp op, evaluator::EvaluatorValuePtr resultValue, + llvm::function_ref evaluateValue, + Location loc) const override { + if (resultValue && resultValue->isFullyEvaluated()) + return resolveValueState(std::move(resultValue)); + return evaluateValue(op.getInput()); + } +}; + +class FrozenEmptyPathPattern final : public OpPattern { +public: + using OpPattern::OpPattern; + +protected: + FailureOr + createTypedPlaceholder(FrozenEmptyPathOp op, Value value, + CreatePartialValueFn createPartialValue, + GetValueHandleFn getValueHandle, + Location loc) const override { + return success(std::make_shared( + evaluator::PathValue::getEmptyPath(loc))); + } +}; + +//===----------------------------------------------------------------------===// +// Operaton Pattern Registery +//===----------------------------------------------------------------------===// + +class OperationPatternRegistry { +public: + OperationPatternRegistry() { + addPattern(); + addPattern(); + addPattern(); + addPattern(); + addPattern(); + addPattern(); + addPattern(); + addPattern(); + addPattern(); + addPattern(); + addPattern(); + addPattern(); + addPattern(); + } + + const OperationPattern *lookup(Operation *op) const { + auto it = patternsByOpName.find(op->getName().getStringRef()); + return it == patternsByOpName.end() ? nullptr : it->second; + } + +private: + template + void addPattern() { + auto pattern = std::make_unique(OpT::getOperationName()); + const OperationPattern *patternPtr = pattern.get(); + patterns.push_back(std::move(pattern)); + patternsByOpName[OpT::getOperationName()] = patternPtr; + } + + SmallVector> patterns; + llvm::StringMap patternsByOpName; +}; + +static const OperationPatternRegistry &getOperationPatternRegistry() { + static const OperationPatternRegistry registry; + return registry; +} + +} // namespace + /// Construct an Evaluator with an IR module. circt::om::Evaluator::Evaluator(ModuleOp mod) : symbolTable(mod) {} @@ -95,7 +656,22 @@ circt::om::Evaluator::getPartiallyEvaluatedValue(Type type, Location loc) { evaluator::AttributeValue::get(type, loc); return success(result); }) - .Default([&](auto type) { return failure(); }); + .Case([&](FrozenBasePathType type) { + evaluator::EvaluatorValuePtr result = + std::make_shared(type.getContext()); + return success(result); + }) + .Case([&](FrozenPathType type) { + evaluator::EvaluatorValuePtr result = + std::make_shared( + evaluator::PathValue::getEmptyPath(loc)); + return success(result); + }) + .Default([&](auto type) { + evaluator::EvaluatorValuePtr result = + evaluator::AttributeValue::get(type, loc); + return success(result); + }); } FailureOr circt::om::Evaluator::getOrCreateValue( @@ -115,66 +691,33 @@ FailureOr circt::om::Evaluator::getOrCreateValue( return val; }) .Case([&](OpResult result) { + using namespace circt::om::evaluator; + Operation *op = result.getDefiningOp(); + + if (auto *pattern = getOperationPatternRegistry().lookup(op)) + return pattern->createPlaceholder( + op, value, + [&](Type type, Location placeholderLoc) { + return getPartiallyEvaluatedValue(type, placeholderLoc); + }, + [&](Value aliasedValue, Location placeholderLoc) { + return getOrCreateValue(aliasedValue, actualParams, + placeholderLoc); + }, + loc); + return TypeSwitch>( - result.getDefiningOp()) - .Case([&](ConstantOp op) { - return evaluateConstant(op, actualParams, loc); - }) - .Case([&](IntegerBinaryArithmeticOp op) { - // Create a partially evaluated AttributeValue of - // om::IntegerType in case we need to delay evaluation. - evaluator::EvaluatorValuePtr result = - evaluator::AttributeValue::get(op.getResult().getType(), - loc); - return success(result); - }) + FailureOr>(op) .Case([&](auto op) { - // Create a reference value since the value pointed by object - // field op is not created yet. - evaluator::EvaluatorValuePtr result = - std::make_shared( - value.getType(), loc); - return success(result); - }) - .Case([&](AnyCastOp op) { - return getOrCreateValue(op.getInput(), actualParams, loc); - }) - .Case([&](FrozenBasePathCreateOp op) { - evaluator::EvaluatorValuePtr result = - std::make_shared( - op.getPathAttr(), loc); - return success(result); - }) - .Case([&](FrozenPathCreateOp op) { - evaluator::EvaluatorValuePtr result = - std::make_shared( - op.getTargetKindAttr(), op.getPathAttr(), - op.getModuleAttr(), op.getRefAttr(), - op.getFieldAttr(), loc); - return success(result); - }) - .Case([&](FrozenEmptyPathOp op) { - evaluator::EvaluatorValuePtr result = - std::make_shared( - evaluator::PathValue::getEmptyPath(loc)); - return success(result); - }) - .Case([&](BinaryEqualityOp op) { - evaluator::EvaluatorValuePtr result = - evaluator::AttributeValue::get(op.getResult().getType(), - loc); - return success(result); - }) - .Case([&](auto op) { - return getPartiallyEvaluatedValue(op.getType(), loc); + return success( + std::make_shared(value.getType(), loc)); }) .Case([&](auto op) { return getPartiallyEvaluatedValue(op.getType(), op.getLoc()); }) - .Case( - [&](auto op) { return evaluateUnknownValue(op, loc); }) + .Case([&](auto op) { + return createUnknownValue(op.getType(), loc); + }) .Default([&](Operation *op) { auto error = op->emitError("unable to evaluate value"); error.attachNote() << "value: " << value; @@ -275,12 +818,11 @@ circt::om::Evaluator::evaluateObjectInstance(StringAttr className, auto name = fieldNames[i]; auto value = operands[i]; auto fieldLoc = cls.getFieldLocByIndex(i); - FailureOr result = - evaluateValue(value, actualParams, fieldLoc); - if (failed(result)) - return result; + auto result = evaluateValue(value, actualParams, fieldLoc); + if (result.state == ResolutionState::Failure) + return failure(); - fields[cast(name)] = result.value(); + fields[cast(name)] = result.value; } // Evaluate property assertions. @@ -344,11 +886,11 @@ circt::om::Evaluator::instantiate( auto result = evaluateValue(value, args, loc); - if (failed(result)) + if (result.state == ResolutionState::Failure) return failure(); // It's possible that the value is not fully evaluated. - if (!result.value()->isFullyEvaluated()) + if (result.state == ResolutionState::Pending) worklist.push({value, args}); } @@ -361,169 +903,54 @@ circt::om::Evaluator::instantiate( return object; } -FailureOr -circt::om::Evaluator::evaluateValue(Value value, ActualParameters actualParams, - Location loc) { - auto evaluatorValue = getOrCreateValue(value, actualParams, loc).value(); - - // Return if the value is already evaluated. - if (evaluatorValue->isFullyEvaluated()) - return evaluatorValue; +ResolvedValue circt::om::Evaluator::evaluateValue(Value value, + ActualParameters actualParams, + Location loc) { + auto evaluatorValue = getOrCreateValue(value, actualParams, loc); + if (failed(evaluatorValue)) + return ResolvedValue::failure(); - return llvm::TypeSwitch>(value) + return llvm::TypeSwitch(value) .Case([&](BlockArgument arg) { return evaluateParameter(arg, actualParams, loc); }) - .Case([&](OpResult result) { - return TypeSwitch>( - result.getDefiningOp()) - .Case([&](ConstantOp op) { - return evaluateConstant(op, actualParams, loc); - }) - .Case([&](IntegerBinaryArithmeticOp op) { - return evaluateIntegerBinaryArithmetic(op, actualParams, loc); - }) + .Case([&](OpResult result) { + if (auto *pattern = + getOperationPatternRegistry().lookup(result.getDefiningOp())) + return pattern->evaluate( + result.getDefiningOp(), evaluatorValue.value(), + [&](Value nestedValue) { + return evaluateValue(nestedValue, actualParams, loc); + }, + loc); + + if (evaluatorValue.value()->isFullyEvaluated()) + return resolveValueState(evaluatorValue.value()); + + return TypeSwitch(result.getDefiningOp()) .Case([&](ObjectOp op) { return evaluateObjectInstance(op, actualParams); }) .Case([&](ObjectFieldOp op) { return evaluateObjectField(op, actualParams, loc); }) - .Case([&](ListCreateOp op) { - return evaluateListCreate(op, actualParams, loc); - }) - .Case([&](ListConcatOp op) { - return evaluateListConcat(op, actualParams, loc); - }) - .Case([&](StringConcatOp op) { - return evaluateStringConcat(op, actualParams, loc); - }) - .Case([&](BinaryEqualityOp op) { - return evaluateBinaryEquality(op, actualParams, loc); - }) - .Case([&](AnyCastOp op) { - return evaluateValue(op.getInput(), actualParams, loc); - }) - .Case([&](FrozenBasePathCreateOp op) { - return evaluateBasePathCreate(op, actualParams, loc); - }) - .Case([&](FrozenPathCreateOp op) { - return evaluatePathCreate(op, actualParams, loc); - }) - .Case([&](FrozenEmptyPathOp op) { - return evaluateEmptyPath(op, actualParams, loc); - }) .Case([&](UnknownValueOp op) { return evaluateUnknownValue(op, loc); }) .Default([&](Operation *op) { auto error = op->emitError("unable to evaluate value"); error.attachNote() << "value: " << value; - return error; + return ResolvedValue::failure(); }); }); } /// Evaluator dispatch function for parameters. -FailureOr circt::om::Evaluator::evaluateParameter( +ResolvedValue circt::om::Evaluator::evaluateParameter( BlockArgument formalParam, ActualParameters actualParams, Location loc) { auto val = (*actualParams)[formalParam.getArgNumber()]; val->setLoc(loc); - return success(val); -} - -/// Evaluator dispatch function for constants. -FailureOr -circt::om::Evaluator::evaluateConstant(ConstantOp op, - ActualParameters actualParams, - Location loc) { - // For list constants, create ListValue. - return success(om::evaluator::AttributeValue::get(op.getValue(), loc)); -} - -// Evaluator dispatch function for integer binary arithmetic. -FailureOr -circt::om::Evaluator::evaluateIntegerBinaryArithmetic( - IntegerBinaryArithmeticOp op, ActualParameters actualParams, Location loc) { - // Get the op's EvaluatorValue handle, in case it hasn't been evaluated yet. - auto handle = getOrCreateValue(op.getResult(), actualParams, loc); - - // If it's fully evaluated, we can return it. - if (handle.value()->isFullyEvaluated()) - return handle; - - // Evaluate operands if necessary, and return the partially evaluated value if - // they aren't ready. - auto lhsResult = evaluateValue(op.getLhs(), actualParams, loc); - if (failed(lhsResult)) - return lhsResult; - if (!lhsResult.value()->isFullyEvaluated()) - return handle; - - auto rhsResult = evaluateValue(op.getRhs(), actualParams, loc); - if (failed(rhsResult)) - return rhsResult; - if (!rhsResult.value()->isFullyEvaluated()) - return handle; - - // Check if any operand is unknown and propagate the unknown flag. - if (lhsResult.value()->isUnknown() || rhsResult.value()->isUnknown()) { - handle.value()->markUnknown(); - return handle; - } - - // Extract the integer attributes. - auto extractAttr = [](evaluator::EvaluatorValue *value) { - return std::move( - llvm::TypeSwitch(value) - .Case([](evaluator::AttributeValue *val) { - return val->getAs(); - }) - .Case([](evaluator::ReferenceValue *val) { - return cast( - val->getStrippedValue()->get()) - ->getAs(); - })); - }; - - om::IntegerAttr lhs = extractAttr(lhsResult.value().get()); - om::IntegerAttr rhs = extractAttr(rhsResult.value().get()); - assert(lhs && rhs && - "expected om::IntegerAttr for IntegerBinaryArithmeticOp operands"); - - // Extend values if necessary to match bitwidth. Most interesting arithmetic - // on APSInt asserts that both operands are the same bitwidth, but the - // IntegerAttrs we are working with may have used the smallest necessary - // bitwidth to represent the number they hold, and won't necessarily match. - APSInt lhsVal = lhs.getValue().getAPSInt(); - APSInt rhsVal = rhs.getValue().getAPSInt(); - if (lhsVal.getBitWidth() > rhsVal.getBitWidth()) - rhsVal = rhsVal.extend(lhsVal.getBitWidth()); - else if (rhsVal.getBitWidth() > lhsVal.getBitWidth()) - lhsVal = lhsVal.extend(rhsVal.getBitWidth()); - - // Perform arbitrary precision signed integer binary arithmetic. - FailureOr result = op.evaluateIntegerOperation(lhsVal, rhsVal); - - if (failed(result)) - return op->emitError("failed to evaluate integer operation"); - - // Package the result as a new om::IntegerAttr. - MLIRContext *ctx = op->getContext(); - auto resultAttr = - om::IntegerAttr::get(ctx, mlir::IntegerAttr::get(ctx, result.value())); - - // Finalize the op result value. - auto *handleValue = cast(handle.value().get()); - auto resultStatus = handleValue->setAttr(resultAttr); - if (failed(resultStatus)) - return resultStatus; - - auto finalizeStatus = handleValue->finalize(); - if (failed(finalizeStatus)) - return finalizeStatus; - - return handle; + return resolveValueState(val); } /// Evaluator dispatch function for property assertions. @@ -531,39 +958,19 @@ LogicalResult circt::om::Evaluator::evaluatePropertyAssert(PropertyAssertOp op, ActualParameters actualParams) { auto loc = op.getLoc(); - - // Evaluate the condition, returning early if it isn't ready yet. - auto condResult = evaluateValue(op.getCondition(), actualParams, loc); - if (failed(condResult)) - return failure(); - if (!condResult.value()->isFullyEvaluated()) - return success(); - - // If the condition is unknown, skip silently (best-effort). - if (condResult.value()->isUnknown()) + evaluator::EvaluatorValuePtr readyCond; + if (auto early = requireReady( + evaluateValue(op.getCondition(), actualParams, loc), nullptr, + [&] { + op.emitError("failed to resolve property assertion condition"); + }, + readyCond)) + return early->state == ResolutionState::Pending ? success() : failure(); + + if (isUnknownReadyValue(readyCond)) return success(); - // Extract the attribute from the condition value, handling the case where - // the condition resolves through a ReferenceValue (e.g. an ObjectFieldOp or - // a parameter that participates in cycle resolution). - auto extractAttr = [](evaluator::EvaluatorValue *value) -> mlir::Attribute { - return llvm::TypeSwitch(value) - .Case([](evaluator::AttributeValue *val) { return val->getAttr(); }) - .Case([](evaluator::ReferenceValue *val) -> mlir::Attribute { - auto stripped = val->getStrippedValue(); - if (failed(stripped)) - return {}; - if (auto *attr = - dyn_cast(stripped.value().get())) - return attr->getAttr(); - return {}; - }) - .Default([](auto *) -> mlir::Attribute { return {}; }); - }; - - auto condAttr = extractAttr(condResult.value().get()); - if (!condAttr) - return success(); + auto condAttr = getAsAttr(readyCond); bool isFalse = false; if (auto boolAttr = dyn_cast(condAttr)) @@ -592,6 +999,7 @@ circt::om::Evaluator::createParametersFromOperands( auto inputResult = getOrCreateValue(input, actualParams, loc); if (failed(inputResult)) return failure(); + parameters->push_back(inputResult.value()); } @@ -600,66 +1008,91 @@ circt::om::Evaluator::createParametersFromOperands( } /// Evaluator dispatch function for Object instances. -FailureOr +ResolvedValue circt::om::Evaluator::evaluateObjectInstance(ObjectOp op, ActualParameters actualParams) { auto loc = op.getLoc(); - if (isFullyEvaluated({op, actualParams})) - return getOrCreateValue(op, actualParams, loc); + auto key = ObjectKey{op, actualParams}; + // Check if the instance is already fully evaluated or being evaluated. This + // can happen when there is a cycle in the object graph. In this case we + // should not attempt to evaluate the instance again, but just return the + // current state of the value, which might be pending or unknown. + if (isFullyEvaluated(key) || !activeObjectInstances.insert(key).second) + return resolveValueState(getOrCreateValue(op, actualParams, loc).value()); + auto clearActiveObject = + llvm::scope_exit([&] { activeObjectInstances.erase(key); }); auto params = createParametersFromOperands(op.getOperands(), actualParams, loc); if (failed(params)) - return failure(); - return evaluateObjectInstance(op.getClassNameAttr(), params.value(), loc, - {op, actualParams}); + return ResolvedValue::failure(); + auto result = evaluateObjectInstance(op.getClassNameAttr(), params.value(), + loc, {op, actualParams}); + if (failed(result)) + return ResolvedValue::failure(); + return resolveValueState(result.value()); } /// Evaluator dispatch function for Object fields. -FailureOr -circt::om::Evaluator::evaluateObjectField(ObjectFieldOp op, - ActualParameters actualParams, - Location loc) { - // Evaluate the Object itself, in case it hasn't been evaluated yet. - FailureOr currentObjectResult = - evaluateValue(op.getObject(), actualParams, loc); - if (failed(currentObjectResult)) - return currentObjectResult; - - auto result = currentObjectResult.value(); - +ResolvedValue circt::om::Evaluator::evaluateObjectField( + ObjectFieldOp op, ActualParameters actualParams, Location loc) { auto objectFieldValue = getOrCreateValue(op, actualParams, loc).value(); - if (result->isUnknown()) { - // If objectFieldValue is a ReferenceValue, set its value to a unknown value - // of the proper type + auto setUnknownFieldValue = [&]() -> ResolvedValue { + auto unknownField = createUnknownValue(op.getResult().getType(), loc); + if (failed(unknownField)) + return ResolvedValue::failure(); + if (auto *ref = - llvm::dyn_cast(objectFieldValue.get())) { - auto unknownField = createUnknownValue(op.getResult().getType(), loc); - if (failed(unknownField)) - return unknownField; + llvm::dyn_cast(objectFieldValue.get())) ref->setValue(unknownField.value()); - } - // markUnknown() also marks the value as fully evaluated + objectFieldValue->markUnknown(); - return objectFieldValue; - } + return ResolvedValue::ready(objectFieldValue); + }; + + evaluator::EvaluatorValuePtr readyObject; + if (auto early = requireReady( + evaluateValue(op.getObject(), actualParams, loc), objectFieldValue, + [&] { op.emitError("failed to resolve object field base"); }, + readyObject)) + return *early; - auto *currentObject = llvm::cast(result.get()); + if (isUnknownReadyValue(readyObject)) + return setUnknownFieldValue(); + + auto *currentObject = getReadyAs(readyObject); // Iteratively access nested fields through the path until we reach the final // field in the path. evaluator::EvaluatorValuePtr finalField; - for (auto field : op.getFieldPath().getAsRange()) { + auto fieldPath = op.getFieldPath().getAsRange(); + for (auto it = fieldPath.begin(), end = fieldPath.end(); it != end; ++it) { + auto field = *it; // `currentObject` might no be fully evaluated. if (!currentObject->getFields().contains(field.getAttr())) - return objectFieldValue; + return ResolvedValue::pending(objectFieldValue); auto currentField = currentObject->getField(field.getAttr()); finalField = currentField.value(); - if (auto *nextObject = - llvm::dyn_cast(finalField.get())) - currentObject = nextObject; + // Only the middle path elements need to be ready objects. The last element + // is the value we are returning, so it may have any type. + if (std::next(it) == end) + continue; + + evaluator::EvaluatorValuePtr nextObject; + if (auto early = requireReady( + resolveValueState(finalField), objectFieldValue, + [&] { + op.emitError("failed to resolve nested object field " + "path"); + }, + nextObject)) + return *early; + if (isUnknownReadyValue(nextObject)) + return setUnknownFieldValue(); + + currentObject = getReadyAs(nextObject); } // Update the reference. @@ -667,273 +1100,7 @@ circt::om::Evaluator::evaluateObjectField(ObjectFieldOp op, ->setValue(finalField); // Return the field being accessed. - return objectFieldValue; -} - -/// Evaluator dispatch function for List creation. -FailureOr -circt::om::Evaluator::evaluateListCreate(ListCreateOp op, - ActualParameters actualParams, - Location loc) { - // Evaluate the Object itself, in case it hasn't been evaluated yet. - SmallVector values; - auto list = getOrCreateValue(op, actualParams, loc); - bool hasUnknown = false; - for (auto operand : op.getOperands()) { - auto result = evaluateValue(operand, actualParams, loc); - if (failed(result)) - return result; - if (!result.value()->isFullyEvaluated()) - return list; - // Check if any operand is unknown. - if (result.value()->isUnknown()) - hasUnknown = true; - values.push_back(result.value()); - } - - // Set the list elements (this also marks the list as fully evaluated). - llvm::cast(list.value().get()) - ->setElements(std::move(values)); - - // If any operand is unknown, mark the list as unknown. - // markUnknown() checks if already fully evaluated before calling - // markFullyEvaluated(). - if (hasUnknown) - list.value()->markUnknown(); - - return list; -} - -/// Evaluator dispatch function for List concatenation. -FailureOr -circt::om::Evaluator::evaluateListConcat(ListConcatOp op, - ActualParameters actualParams, - Location loc) { - // Evaluate the List concat op itself, in case it hasn't been evaluated yet. - SmallVector values; - auto list = getOrCreateValue(op, actualParams, loc); - - // Extract the ListValue, either directly or through an object reference. - auto extractList = [](evaluator::EvaluatorValue *value) { - return std::move( - llvm::TypeSwitch( - value) - .Case([](evaluator::ListValue *val) { return val; }) - .Case([](evaluator::ReferenceValue *val) { - return cast(val->getStrippedValue()->get()); - })); - }; - - bool hasUnknown = false; - for (auto operand : op.getOperands()) { - auto result = evaluateValue(operand, actualParams, loc); - if (failed(result)) - return result; - if (!result.value()->isFullyEvaluated()) - return list; - // Check if any operand is unknown. - if (result.value()->isUnknown()) - hasUnknown = true; - - // Extract this sublist and ensure it's done evaluating. - evaluator::ListValue *subList = extractList(result.value().get()); - if (!subList->isFullyEvaluated()) - return list; - - // Append each EvaluatorValue from the sublist. - for (const auto &subValue : subList->getElements()) - values.push_back(subValue); - } - - // Return the concatenated list. - llvm::cast(list.value().get()) - ->setElements(std::move(values)); - - // If any operand is unknown, mark the result as unknown. - // markUnknown() checks if already fully evaluated before calling - // markFullyEvaluated(). - if (hasUnknown) - list.value()->markUnknown(); - - return list; -} - -/// Evaluator dispatch function for String concatenation. -FailureOr -circt::om::Evaluator::evaluateStringConcat(StringConcatOp op, - ActualParameters actualParams, - Location loc) { - // Get the op's EvaluatorValue handle, in case it hasn't been evaluated yet. - auto handle = getOrCreateValue(op.getResult(), actualParams, loc); - if (failed(handle)) - return handle; - - // If it's fully evaluated, we can return it. - if (handle.value()->isFullyEvaluated()) - return handle; - - // Extract the string attributes, handling both AttributeValue and - // ReferenceValue cases. - auto extractAttr = [](evaluator::EvaluatorValue *value) -> StringAttr { - return llvm::TypeSwitch(value) - .Case([](evaluator::AttributeValue *val) { - return val->getAs(); - }) - .Case([](evaluator::ReferenceValue *val) { - return cast(val->getStrippedValue()->get()) - ->getAs(); - }); - }; - - // Evaluate all operands and concatenate them. - std::string result; - for (auto operand : op.getOperands()) { - auto operandResult = evaluateValue(operand, actualParams, loc); - if (failed(operandResult)) - return operandResult; - if (!operandResult.value()->isFullyEvaluated()) - return handle; - - StringAttr str = extractAttr(operandResult.value().get()); - assert(str && "expected StringAttr for StringConcatOp operand"); - result += str.getValue().str(); - } - - // Create the concatenated string attribute. - auto resultStr = StringAttr::get(result, op.getResult().getType()); - - // Finalize the op result value. - auto *handleValue = cast(handle.value().get()); - auto resultStatus = handleValue->setAttr(resultStr); - if (failed(resultStatus)) - return resultStatus; - - auto finalizeStatus = handleValue->finalize(); - if (failed(finalizeStatus)) - return finalizeStatus; - - return handle; -} - -// Evaluator dispatch function for binary property equality operations. -FailureOr -circt::om::Evaluator::evaluateBinaryEquality(BinaryEqualityOp op, - ActualParameters actualParams, - Location loc) { - // Get the op's EvaluatorValue handle, in case it hasn't been evaluated yet. - auto handle = getOrCreateValue(op.getResult(), actualParams, loc); - if (failed(handle)) - return handle; - - // If it's fully evaluated, we can return it. - if (handle.value()->isFullyEvaluated()) - return handle; - - // Evaluate both operands, returning the partially evaluated handle if either - // isn't ready yet. - auto lhsResult = evaluateValue(op.getLhs(), actualParams, loc); - if (failed(lhsResult)) - return lhsResult; - if (!lhsResult.value()->isFullyEvaluated()) - return handle; - - auto rhsResult = evaluateValue(op.getRhs(), actualParams, loc); - if (failed(rhsResult)) - return rhsResult; - if (!rhsResult.value()->isFullyEvaluated()) - return handle; - - // Check if any operand is unknown and propagate the unknown flag. - if (lhsResult.value()->isUnknown() || rhsResult.value()->isUnknown()) { - handle.value()->markUnknown(); - return handle; - } - - // Extract the underlying attribute, handling both AttributeValue and - // ReferenceValue cases. - auto extractAttr = [](evaluator::EvaluatorValue *value) -> mlir::Attribute { - return llvm::TypeSwitch(value) - .Case([](evaluator::AttributeValue *val) { return val->getAttr(); }) - .Case([](evaluator::ReferenceValue *val) -> mlir::Attribute { - return cast(val->getStrippedValue()->get()) - ->getAttr(); - }); - }; - - mlir::Attribute lhs = extractAttr(lhsResult.value().get()); - mlir::Attribute rhs = extractAttr(rhsResult.value().get()); - assert(lhs && rhs && "expected attribute for BinaryEqualityOp operands"); - - // Perform the binary equality operation. - FailureOr result = op.evaluateBinaryEquality(lhs, rhs); - if (failed(result)) - return op->emitError("failed to evaluate binary equality operation"); - - // Finalize the op result value. - auto *handleValue = cast(handle.value().get()); - auto resultStatus = handleValue->setAttr(*result); - if (failed(resultStatus)) - return resultStatus; - - auto finalizeStatus = handleValue->finalize(); - if (failed(finalizeStatus)) - return finalizeStatus; - - return handle; -} - -FailureOr -circt::om::Evaluator::evaluateBasePathCreate(FrozenBasePathCreateOp op, - ActualParameters actualParams, - Location loc) { - // Evaluate the Object itself, in case it hasn't been evaluated yet. - auto valueResult = getOrCreateValue(op, actualParams, loc).value(); - auto *path = llvm::cast(valueResult.get()); - auto result = evaluateValue(op.getBasePath(), actualParams, loc); - if (failed(result)) - return result; - auto &value = result.value(); - if (!value->isFullyEvaluated()) - return valueResult; - - // If the base path is unknown, mark the result as unknown. - if (result.value()->isUnknown()) { - valueResult->markUnknown(); - return valueResult; - } - - path->setBasepath(*llvm::cast(value.get())); - return valueResult; -} - -FailureOr -circt::om::Evaluator::evaluatePathCreate(FrozenPathCreateOp op, - ActualParameters actualParams, - Location loc) { - // Evaluate the Object itself, in case it hasn't been evaluated yet. - auto valueResult = getOrCreateValue(op, actualParams, loc).value(); - auto *path = llvm::cast(valueResult.get()); - auto result = evaluateValue(op.getBasePath(), actualParams, loc); - if (failed(result)) - return result; - auto &value = result.value(); - if (!value->isFullyEvaluated()) - return valueResult; - - // If the base path is unknown, mark the result as unknown. - if (result.value()->isUnknown()) { - valueResult->markUnknown(); - return valueResult; - } - - path->setBasepath(*llvm::cast(value.get())); - return valueResult; -} - -FailureOr circt::om::Evaluator::evaluateEmptyPath( - FrozenEmptyPathOp op, ActualParameters actualParams, Location loc) { - auto valueResult = getOrCreateValue(op, actualParams, loc).value(); - return valueResult; + return resolveValueState(objectFieldValue); } /// Create an unknown value of the specified type @@ -982,9 +1149,12 @@ circt::om::Evaluator::createUnknownValue(Type type, Location loc) { } /// Evaluate an unknown value -FailureOr -circt::om::Evaluator::evaluateUnknownValue(UnknownValueOp op, Location loc) { - return createUnknownValue(op.getType(), loc); +ResolvedValue circt::om::Evaluator::evaluateUnknownValue(UnknownValueOp op, + Location loc) { + auto result = createUnknownValue(op.getType(), loc); + if (failed(result)) + return ResolvedValue::failure(); + return resolveValueState(result.value()); } //===----------------------------------------------------------------------===// @@ -1026,11 +1196,25 @@ LogicalResult circt::om::evaluator::ObjectValue::finalizeImpl() { // ReferenceValue //===----------------------------------------------------------------------===// +FailureOr +circt::om::evaluator::ReferenceValue::getStrippedValue() const { + auto resolved = resolveReferenceValue(value); + switch (resolved.state) { + case ResolutionState::Ready: + return success(resolved.value); + case ResolutionState::Pending: + return mlir::emitError(getLoc(), "reference value is not resolved"); + case ResolutionState::Failure: + return mlir::emitError(getLoc(), "reference value contains a cycle"); + } + llvm_unreachable("unknown resolution state"); +} + LogicalResult circt::om::evaluator::ReferenceValue::finalizeImpl() { - auto result = getStrippedValue(); - if (failed(result)) - return result; - value = std::move(result.value()); + auto resolved = resolveReferenceValue(value); + if (resolved.state != ResolutionState::Ready) + return failure(); + value = std::move(resolved.value); // the stripped value also needs to be finalized if (failed(finalizeEvaluatorValue(value))) return failure(); diff --git a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp index d798aed654d2..4dd77d45e097 100644 --- a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp +++ b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp @@ -1661,6 +1661,65 @@ om.class @Foo( ASSERT_TRUE(object->getField("b").value()->isUnknown()); } +TEST(EvaluatorTests, UnknownValuesNestedObjectFieldPath) { + StringRef mod = R"MLIR( +om.class @Leaf( + %value: !om.integer +) -> ( + value: !om.integer +) { + om.class.fields %value : !om.integer +} + +om.class @Outer( + %leaf: !om.class.type<@Leaf> +) -> ( + leaf: !om.class.type<@Leaf> +) { + om.class.fields %leaf : !om.class.type<@Leaf> +} + +om.class @Foo( + %unknown_outer: !om.class.type<@Outer> +) -> ( + value: !om.integer +) { + %0 = om.object.field %unknown_outer, [@leaf, @value] : (!om.class.type<@Outer>) -> !om.integer + om.class.fields %0 : !om.integer +} +)MLIR"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto unknownLoc = UnknownLoc::get(&context); + auto outerClassType = circt::om::ClassType::get( + &context, mlir::FlatSymbolRefAttr::get(&context, "Outer")); + auto unknownOuter = circt::om::evaluator::AttributeValue::get( + outerClassType, LocationAttr(unknownLoc)); + unknownOuter->markUnknown(); + + auto result = + evaluator.instantiate(StringAttr::get(&context, "Foo"), {unknownOuter}); + + ASSERT_TRUE(succeeded(result)); + + auto *object = llvm::cast(result.value().get()); + auto value = object->getField("value"); + ASSERT_TRUE(succeeded(value)); + ASSERT_TRUE(value->get()->isUnknown()); + ASSERT_EQ(value->get()->getType(), circt::om::OMIntegerType::get(&context)); + ASSERT_EQ(value->get()->getKind(), evaluator::EvaluatorValue::Kind::Attr); +} + TEST(EvaluatorTests, StringConcat) { const char *mod = R"MLIR( module { @@ -1955,4 +2014,53 @@ om.class @PropEqInteger(%n: !om.integer) -> (equal: i1, not_equal: i1, unknown: } } +TEST(EvaluatorTests, ReferenceValueBounceThroughObject) { + StringRef mod = R"MLIR( +om.class @Domain(%in: !om.string) -> (out: !om.string) { + om.class.fields %in : !om.string +} +om.class @Foo_Class(%basepath: !om.frozenbasepath) -> (test1: i1, test2: i1) { + %0 = om.constant "A" : !om.string + %1 = om.object @Domain(%0) : (!om.string) -> !om.class.type<@Domain> + %2 = om.object.field %1, [@out] : (!om.class.type<@Domain>) -> !om.string + %3 = om.object @Domain(%2) : (!om.string) -> !om.class.type<@Domain> + %4 = om.object.field %3, [@out] : (!om.class.type<@Domain>) -> !om.string + %5 = om.constant "A" : !om.string + %6 = om.constant "B" : !om.string + %7 = om.prop.eq %4, %5 : !om.string + %8 = om.prop.eq %4, %6 : !om.string + om.class.fields %7, %8 : i1, i1 +} +)MLIR"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = evaluator.instantiate( + StringAttr::get(&context, "Foo_Class"), + {std::make_shared(&context)}); + ASSERT_TRUE(succeeded(result)); + + auto *obj = llvm::cast(result.value().get()); + auto check = [&](StringRef fieldName, bool expected) { + auto *field = + obj->getField(StringAttr::get(&context, fieldName)).value().get(); + ASSERT_TRUE(llvm::isa(field)); + auto intAttr = llvm::cast(field) + ->getAs(); + ASSERT_TRUE(intAttr); + ASSERT_EQ(intAttr.getValue().getZExtValue(), expected ? 1u : 0u); + }; + check("test1", true); + check("test2", false); +} + } // namespace From 92256fe9d240a4433df29453f7a11136f1eec730 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Fri, 17 Apr 2026 23:56:01 -0700 Subject: [PATCH 2/3] Rename terminology, add a doc --- .../circt/Dialect/OM/Evaluator/Evaluator.h | 72 ++++++++++++------- lib/Dialect/OM/Evaluator/Evaluator.cpp | 49 +++++++------ 2 files changed, 72 insertions(+), 49 deletions(-) diff --git a/include/circt/Dialect/OM/Evaluator/Evaluator.h b/include/circt/Dialect/OM/Evaluator/Evaluator.h index b8f5d2a034a3..4846cfeec297 100644 --- a/include/circt/Dialect/OM/Evaluator/Evaluator.h +++ b/include/circt/Dialect/OM/Evaluator/Evaluator.h @@ -37,10 +37,34 @@ class EvaluatorValue; /// primitive Attribute. Further refinement is expected. using EvaluatorValuePtr = std::shared_ptr; +/// The evaluator tracks two different things: +/// +/// 1. Local state of one value object: `isSettled()` +/// +/// false -> this value object may still change +/// true -> this value object has finished its own local work +/// +/// 2. State of using a value handle: `ResolutionState` +/// +/// Pending -> the handle still cannot be used +/// Ready -> the handle can be used now +/// Failure -> evaluation hit a hard error +/// +/// These are not the same thing. A reference may itself be settled, but using +/// the handle may still be pending: +/// +/// ReferenceValue (settled = true) +/// | +/// v +/// pending target +/// +/// So `isSettled()` is about one value object, while `ResolvedValue` is about +/// whether the whole handle is usable. +/// /// The evaluation state of a value handle. enum class ResolutionState { /// The handle can be used now. For references, this means the whole - /// reference chain leads to a fully evaluated value. + /// reference chain leads to a settled value. Ready, /// Evaluation is not done yet. The handle itself may still be partial, or a /// reference in the chain may still be missing. @@ -86,12 +110,12 @@ class EvaluatorValue : public std::enable_shared_from_this { // Return true if this value object has finished its own local work. // This is not the same as semantic Ready/Pending state: for example, a - // ReferenceValue can be fully evaluated but still point to a pending value. - // Unknown values are considered fully evaluated. - bool isFullyEvaluated() const { return fullyEvaluated; } - void markFullyEvaluated() { - assert(!fullyEvaluated && "should not mark twice"); - fullyEvaluated = true; + // ReferenceValue can be settled but still point to a pending value. + // Unknown values are considered settled. + bool isSettled() const { return settled; } + void markSettled() { + assert(!settled && "should not mark twice"); + settled = true; } /// Return true if the value is unknown (has unknown in its fan-in). @@ -100,13 +124,13 @@ class EvaluatorValue : public std::enable_shared_from_this { bool isUnknown() const { return unknown; } /// Mark this value as unknown. - /// This also marks the value as fully evaluated if it isn't already, since - /// unknown values are considered fully evaluated. This maintains the - /// invariant that unknown implies fullyEvaluated. + /// This also marks the value as settled if it isn't already, since unknown + /// values are considered settled. This maintains the invariant that unknown + /// implies settled. void markUnknown() { unknown = true; - if (!fullyEvaluated) - markFullyEvaluated(); + if (!settled) + markSettled(); } /// Return the associated MLIR context. @@ -132,7 +156,7 @@ class EvaluatorValue : public std::enable_shared_from_this { const Kind kind; MLIRContext *ctx; Location loc; - bool fullyEvaluated = false; + bool settled = false; bool finalized = false; bool unknown = false; }; @@ -155,7 +179,7 @@ class ReferenceValue : public EvaluatorValue { EvaluatorValuePtr getValue() const { return value; } void setValue(EvaluatorValuePtr newValue) { value = std::move(newValue); - markFullyEvaluated(); + markSettled(); } // Finalize the value. @@ -202,7 +226,7 @@ class AttributeValue : public EvaluatorValue { AttributeValue(PrivateTag, Attribute attr, Location loc) : EvaluatorValue(attr.getContext(), Kind::Attr, loc), attr(attr), type(cast(attr).getType()) { - markFullyEvaluated(); + markSettled(); } // Constructor for partially evaluated AttributeValue @@ -237,12 +261,12 @@ class ListValue : public EvaluatorValue { Location loc) : EvaluatorValue(type.getContext(), Kind::List, loc), type(type), elements(std::move(elements)) { - markFullyEvaluated(); + markSettled(); } void setElements(SmallVector newElements) { elements = std::move(newElements); - markFullyEvaluated(); + markSettled(); } // Finalize the value. @@ -273,7 +297,7 @@ class ObjectValue : public EvaluatorValue { ObjectValue(om::ClassLike cls, ObjectFields fields, Location loc) : EvaluatorValue(cls.getContext(), Kind::Object, loc), cls(cls), fields(std::move(fields)) { - markFullyEvaluated(); + markSettled(); } // Partially evaluated value. @@ -285,7 +309,7 @@ class ObjectValue : public EvaluatorValue { void setFields(llvm::SmallDenseMap newFields) { fields = std::move(newFields); - markFullyEvaluated(); + markSettled(); } /// Return the type of the value, which is a ClassType. @@ -415,13 +439,13 @@ class Evaluator { using ObjectKey = std::pair; private: - bool isFullyEvaluated(Value value, ActualParameters key) { - return isFullyEvaluated({value, key}); + bool isSettled(Value value, ActualParameters key) { + return isSettled({value, key}); } - bool isFullyEvaluated(ObjectKey key) { + bool isSettled(ObjectKey key) { auto val = objects.lookup(key); - return val && val->isFullyEvaluated(); + return val && val->isSettled(); } FailureOr @@ -471,7 +495,7 @@ class Evaluator { std::unique_ptr>>> actualParametersBuffers; - /// A worklist that tracks values which needs to be fully evaluated. + /// A worklist that tracks values that still need more evaluation work. std::queue worklist; /// Evaluator value storage. Return an evaluator value for the given diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index a7a5b691a11b..094a00ef0ea7 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -62,13 +62,13 @@ resolveReferenceValue(evaluator::EvaluatorValuePtr currentValue) { /// Keep the original handle in the result either way. static ResolvedValue resolveValueState(evaluator::EvaluatorValuePtr currentValue) { - if (!currentValue || !currentValue->isFullyEvaluated()) + if (!currentValue || !currentValue->isSettled()) return ResolvedValue::pending(std::move(currentValue)); auto resolved = resolveReferenceValue(currentValue); if (resolved.state != ResolutionState::Ready) return {resolved.state, std::move(currentValue)}; - if (!resolved.value->isFullyEvaluated()) + if (!resolved.value->isSettled()) return ResolvedValue::pending(std::move(currentValue)); return ResolvedValue::ready(std::move(currentValue)); @@ -83,7 +83,7 @@ resolveReadyValue(evaluator::EvaluatorValuePtr value) { assert(value); auto resolved = resolveReferenceValue(value); assert(resolved.state == ResolutionState::Ready); - assert(resolved.value && resolved.value->isFullyEvaluated()); + assert(resolved.value && resolved.value->isSettled()); return resolved.value.get(); } @@ -258,7 +258,7 @@ class OpReadyOperandsPattern : public OperationPattern { evaluator::EvaluatorValuePtr resultValue, llvm::function_ref evaluateValue, Location loc) const final { - if (resultValue && resultValue->isFullyEvaluated()) + if (resultValue && resultValue->isSettled()) return resolveValueState(std::move(resultValue)); SmallVector readyOperands; @@ -516,7 +516,7 @@ class AnyCastPattern final : public OpPattern { evaluateTyped(AnyCastOp op, evaluator::EvaluatorValuePtr resultValue, llvm::function_ref evaluateValue, Location loc) const override { - if (resultValue && resultValue->isFullyEvaluated()) + if (resultValue && resultValue->isSettled()) return resolveValueState(std::move(resultValue)); return evaluateValue(op.getInput()); } @@ -609,7 +609,7 @@ LogicalResult circt::om::evaluator::EvaluatorValue::finalize() { return success(); // Enable the flag to avoid infinite recursions. finalized = true; - assert(isFullyEvaluated()); + assert(isSettled()); return llvm::TypeSwitch(this) .Case([](auto v) { return v->finalizeImpl(); }); @@ -889,7 +889,7 @@ circt::om::Evaluator::instantiate( if (result.state == ResolutionState::Failure) return failure(); - // It's possible that the value is not fully evaluated. + // The value may still be unsettled, so keep it on the worklist. if (result.state == ResolutionState::Pending) worklist.push({value, args}); } @@ -924,7 +924,7 @@ ResolvedValue circt::om::Evaluator::evaluateValue(Value value, }, loc); - if (evaluatorValue.value()->isFullyEvaluated()) + if (evaluatorValue.value()->isSettled()) return resolveValueState(evaluatorValue.value()); return TypeSwitch(result.getDefiningOp()) @@ -1013,11 +1013,11 @@ circt::om::Evaluator::evaluateObjectInstance(ObjectOp op, ActualParameters actualParams) { auto loc = op.getLoc(); auto key = ObjectKey{op, actualParams}; - // Check if the instance is already fully evaluated or being evaluated. This + // Check if the instance is already settled or being evaluated. This // can happen when there is a cycle in the object graph. In this case we // should not attempt to evaluate the instance again, but just return the // current state of the value, which might be pending or unknown. - if (isFullyEvaluated(key) || !activeObjectInstances.insert(key).second) + if (isSettled(key) || !activeObjectInstances.insert(key).second) return resolveValueState(getOrCreateValue(op, actualParams, loc).value()); auto clearActiveObject = llvm::scope_exit([&] { activeObjectInstances.erase(key); }); @@ -1069,7 +1069,7 @@ ResolvedValue circt::om::Evaluator::evaluateObjectField( auto fieldPath = op.getFieldPath().getAsRange(); for (auto it = fieldPath.begin(), end = fieldPath.end(); it != end; ++it) { auto field = *it; - // `currentObject` might no be fully evaluated. + // `currentObject` might not be settled yet. if (!currentObject->getFields().contains(field.getAttr())) return ResolvedValue::pending(objectFieldValue); @@ -1241,24 +1241,24 @@ LogicalResult circt::om::evaluator::ListValue::finalizeImpl() { evaluator::BasePathValue::BasePathValue(MLIRContext *context) : EvaluatorValue(context, Kind::BasePath, UnknownLoc::get(context)), path(PathAttr::get(context, {})) { - markFullyEvaluated(); + markSettled(); } evaluator::BasePathValue::BasePathValue(PathAttr path, Location loc) : EvaluatorValue(path.getContext(), Kind::BasePath, loc), path(path) {} PathAttr evaluator::BasePathValue::getPath() const { - assert(isFullyEvaluated()); + assert(isSettled()); return path; } void evaluator::BasePathValue::setBasepath(const BasePathValue &basepath) { - assert(!isFullyEvaluated()); + assert(!isSettled()); auto newPath = llvm::to_vector(basepath.path.getPath()); auto oldPath = path.getPath(); newPath.append(oldPath.begin(), oldPath.end()); path = PathAttr::get(path.getContext(), newPath); - markFullyEvaluated(); + markSettled(); } //===----------------------------------------------------------------------===// @@ -1273,7 +1273,7 @@ evaluator::PathValue::PathValue(TargetKindAttr targetKind, PathAttr path, evaluator::PathValue evaluator::PathValue::getEmptyPath(Location loc) { PathValue path(nullptr, nullptr, nullptr, nullptr, nullptr, loc); - path.markFullyEvaluated(); + path.markSettled(); return path; } @@ -1323,12 +1323,12 @@ StringAttr evaluator::PathValue::getAsString() const { } void evaluator::PathValue::setBasepath(const BasePathValue &basepath) { - assert(!isFullyEvaluated()); + assert(!isSettled()); auto newPath = llvm::to_vector(basepath.getPath().getPath()); auto oldPath = path.getPath(); newPath.append(oldPath.begin(), oldPath.end()); path = PathAttr::get(path.getContext(), newPath); - markFullyEvaluated(); + markSettled(); } //===----------------------------------------------------------------------===// @@ -1339,19 +1339,18 @@ LogicalResult circt::om::evaluator::AttributeValue::setAttr(Attribute attr) { if (cast(attr).getType() != this->type) return mlir::emitError(getLoc(), "cannot set AttributeValue of type ") << this->type << " to Attribute " << attr; - if (isFullyEvaluated()) - return mlir::emitError( - getLoc(), - "cannot set AttributeValue that has already been fully evaluated"); + if (isSettled()) + return mlir::emitError(getLoc(), + "cannot set AttributeValue that is already settled"); this->attr = attr; - markFullyEvaluated(); + markSettled(); return success(); } LogicalResult circt::om::evaluator::AttributeValue::finalizeImpl() { - if (!isFullyEvaluated()) + if (!isSettled()) return mlir::emitError( - getLoc(), "cannot finalize AttributeValue that is not fully evaluated"); + getLoc(), "cannot finalize AttributeValue that is not settled"); return success(); } From a82bb7b7fd4195c808a6eea19d2484dbf99c4045 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Sat, 18 Apr 2026 00:18:39 -0700 Subject: [PATCH 3/3] Use a folder directly --- include/circt/Dialect/OM/OMOps.td | 1 + lib/Dialect/OM/Evaluator/CMakeLists.txt | 1 + lib/Dialect/OM/Evaluator/Evaluator.cpp | 509 +++--------------- .../OM/Evaluator/EvaluatorPatterns.cpp | 174 ++++++ lib/Dialect/OM/Evaluator/EvaluatorPatterns.h | 273 ++++++++++ lib/Dialect/OM/OMOps.cpp | 46 +- .../Dialect/OM/Evaluator/EvaluatorTests.cpp | 43 +- 7 files changed, 593 insertions(+), 454 deletions(-) create mode 100644 lib/Dialect/OM/Evaluator/EvaluatorPatterns.cpp create mode 100644 lib/Dialect/OM/Evaluator/EvaluatorPatterns.h diff --git a/include/circt/Dialect/OM/OMOps.td b/include/circt/Dialect/OM/OMOps.td index 826f922e8020..01e4dd315edd 100644 --- a/include/circt/Dialect/OM/OMOps.td +++ b/include/circt/Dialect/OM/OMOps.td @@ -460,6 +460,7 @@ class IntegerBinaryArithmeticOp traits = []> : let results = (outs OMIntegerType:$result); let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; + let hasFolder = 1; } def IntegerAddOp : IntegerBinaryArithmeticOp<"integer.add", [Commutative]> { diff --git a/lib/Dialect/OM/Evaluator/CMakeLists.txt b/lib/Dialect/OM/Evaluator/CMakeLists.txt index d0ae4b70c8ab..883c31aa862c 100644 --- a/lib/Dialect/OM/Evaluator/CMakeLists.txt +++ b/lib/Dialect/OM/Evaluator/CMakeLists.txt @@ -1,5 +1,6 @@ add_circt_library(CIRCTOMEvaluator Evaluator.cpp + EvaluatorPatterns.cpp DEPENDS MLIROMIncGen diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index 094a00ef0ea7..26e47dc8794f 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "circt/Dialect/OM/Evaluator/Evaluator.h" +#include "EvaluatorPatterns.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/Location.h" #include "mlir/IR/SymbolTable.h" @@ -33,15 +34,14 @@ namespace { using ResolutionState = evaluator::ResolutionState; using ResolvedValue = evaluator::ResolvedValue; +} // namespace -//===----------------------------------------------------------------------===// -// Resolved value helpers -//===----------------------------------------------------------------------===// +namespace circt::om::detail { +namespace { /// Walk through reference values until we reach a non-reference value. /// Return Pending if the chain ends at null. Return Failure if the chain loops. -static ResolvedValue -resolveReferenceValue(evaluator::EvaluatorValuePtr currentValue) { +ResolvedValue resolveReferenceValue(evaluator::EvaluatorValuePtr currentValue) { llvm::SmallPtrSet visited; if (!currentValue) return ResolvedValue::pending(); @@ -58,10 +58,9 @@ resolveReferenceValue(evaluator::EvaluatorValuePtr currentValue) { return ResolvedValue::ready(std::move(currentValue)); } -/// Say whether this handle is ready to use, still pending, or failed. -/// Keep the original handle in the result either way. -static ResolvedValue -resolveValueState(evaluator::EvaluatorValuePtr currentValue) { +} // namespace + +ResolvedValue resolveValueState(evaluator::EvaluatorValuePtr currentValue) { if (!currentValue || !currentValue->isSettled()) return ResolvedValue::pending(std::move(currentValue)); @@ -74,11 +73,7 @@ resolveValueState(evaluator::EvaluatorValuePtr currentValue) { return ResolvedValue::ready(std::move(currentValue)); } -//===----------------------------------------------------------------------===// -// Ready value helpers -//===----------------------------------------------------------------------===// - -static evaluator::EvaluatorValue * +evaluator::EvaluatorValue * resolveReadyValue(evaluator::EvaluatorValuePtr value) { assert(value); auto resolved = resolveReferenceValue(value); @@ -87,31 +82,11 @@ resolveReadyValue(evaluator::EvaluatorValuePtr value) { return resolved.value.get(); } -template -static ValueT *getReadyAs(evaluator::EvaluatorValuePtr value) { - auto *typedValue = - llvm::dyn_cast(resolveReadyValue(std::move(value))); - assert(typedValue); - return typedValue; -} - -template -static AttrT getAsAttr(evaluator::EvaluatorValuePtr value) { - return llvm::dyn_cast( - getReadyAs(std::move(value))->getAttr()); -} - -static bool isUnknownReadyValue(evaluator::EvaluatorValuePtr value) { +bool isUnknownReadyValue(evaluator::EvaluatorValuePtr value) { return (value && value->isUnknown()) || resolveReadyValue(value)->isUnknown(); } -//===----------------------------------------------------------------------===// -// Operand resolution helpers -//===----------------------------------------------------------------------===// - -/// If the operand is ready, write it to `readyValue`. Otherwise return the -/// pending or failure state right away. -static std::optional +std::optional requireReady(const ResolvedValue &resolved, evaluator::EvaluatorValuePtr pendingValue, llvm::function_ref emitFailure, @@ -129,11 +104,7 @@ requireReady(const ResolvedValue &resolved, llvm_unreachable("unknown resolution state"); } -/// Resolve every operand before running the pattern. Stop right away if any -/// operand is pending or failed. Also record whether any operand is unknown so -/// the shared ready-operands path can mark the result unknown without running -/// the op-specific code. -static std::optional requireAllOperandsReady( +std::optional requireAllOperandsReady( ValueRange operands, evaluator::EvaluatorValuePtr pendingValue, llvm::function_ref evaluateOperand, llvm::function_ref emitFailure, @@ -155,430 +126,70 @@ static std::optional requireAllOperandsReady( return std::nullopt; } -//===----------------------------------------------------------------------===// -// Result mutation helpers -//===----------------------------------------------------------------------===// - -static ResolvedValue markUnknownAndReturn(evaluator::EvaluatorValuePtr value) { +ResolvedValue markUnknownAndReturn(evaluator::EvaluatorValuePtr value) { value->markUnknown(); return resolveValueState(std::move(value)); } -static LogicalResult setAttrResult(evaluator::EvaluatorValuePtr resultValue, - Attribute attr) { +LogicalResult setAttrResult(evaluator::EvaluatorValuePtr resultValue, + Attribute attr) { auto *attrValue = cast(resultValue.get()); if (failed(attrValue->setAttr(attr)) || failed(attrValue->finalize())) return failure(); return success(); } -//===----------------------------------------------------------------------===// -// Operation pattern infrastructure -//===----------------------------------------------------------------------===// - -/// Base class for one OM operation in the evaluator. -/// A pattern creates a placeholder for the result and later fills it in. -class OperationPattern { -public: - using CreatePartialValueFn = - llvm::function_ref(Type, - Location)>; - using GetValueHandleFn = - llvm::function_ref(Value, - Location)>; - - explicit OperationPattern(StringRef operationName) - : operationName(operationName) {} - virtual ~OperationPattern() = default; - - StringRef getOperationName() const { return operationName; } - virtual FailureOr - createPlaceholder(Operation *op, Value value, - CreatePartialValueFn createPartialValue, - GetValueHandleFn getValueHandle, Location loc) const { - return createPartialValue(value.getType(), loc); - } - - virtual ResolvedValue - evaluate(Operation *op, evaluator::EvaluatorValuePtr resultValue, - llvm::function_ref evaluateValue, - Location loc) const = 0; - -private: - StringRef operationName; -}; - -template -class OpPattern : public OperationPattern { -public: - using OperationPattern::OperationPattern; - - ResolvedValue evaluate(Operation *op, - evaluator::EvaluatorValuePtr resultValue, - llvm::function_ref evaluateValue, - Location loc) const final { - return evaluateTyped(cast(op), std::move(resultValue), evaluateValue, - loc); - } +LogicalResult foldSingleResultOperation( + Operation *op, ArrayRef readyOperands, + evaluator::EvaluatorValuePtr resultValue, StringRef failureMessage) { + assert(op->getNumResults() == 1 && "expected one-result op"); - FailureOr createPlaceholder( - Operation *op, Value value, CreatePartialValueFn createPartialValue, - GetValueHandleFn getValueHandle, Location loc) const override { - return createTypedPlaceholder(cast(op), value, createPartialValue, - getValueHandle, loc); - } + SmallVector operandAttrs; + operandAttrs.reserve(readyOperands.size()); + for (auto &operand : readyOperands) + operandAttrs.push_back(getAsAttr(operand)); -protected: - virtual FailureOr - createTypedPlaceholder(OpT op, Value value, - CreatePartialValueFn createPartialValue, - GetValueHandleFn getValueHandle, Location loc) const { - return OperationPattern::createPlaceholder(op, value, createPartialValue, - getValueHandle, loc); - } + SmallVector foldResults; + if (failed(op->fold(operandAttrs, foldResults))) + return op->emitError(failureMessage); + if (foldResults.size() != 1) + return op->emitError("expected folder to produce one result"); - virtual ResolvedValue - evaluateTyped(OpT op, evaluator::EvaluatorValuePtr resultValue, - llvm::function_ref evaluateValue, - Location loc) const { - return resolveValueState(std::move(resultValue)); - } -}; - -/// Base class for operations that only run once all operands are ready. -/// This handles the shared ready/pending/failure/unknown logic so concrete -/// patterns only implement the successful case. -template -class OpReadyOperandsPattern : public OperationPattern { -public: - using OperationPattern::OperationPattern; - -private: - ResolvedValue evaluate(Operation *op, - evaluator::EvaluatorValuePtr resultValue, - llvm::function_ref evaluateValue, - Location loc) const final { - if (resultValue && resultValue->isSettled()) - return resolveValueState(std::move(resultValue)); - - SmallVector readyOperands; - bool existsUnknown = false; - if (auto early = requireAllOperandsReady( - op->getOperands(), resultValue, evaluateValue, - [&] { - op->emitError() - << "failed to resolve " << getOperationName() << " operand"; - }, - readyOperands, existsUnknown)) - return *early; - // If any operand is unknown, the result is unknown too. - if (existsUnknown) - return markUnknownAndReturn(std::move(resultValue)); - - if (failed(evaluateTyped(cast(op), readyOperands, resultValue, loc))) - return ResolvedValue::failure(); - return resolveValueState(std::move(resultValue)); - } - - FailureOr - createPlaceholder(Operation *op, Value value, - CreatePartialValueFn createPartialValue, - GetValueHandleFn getValueHandle, Location loc) const final { - return createTypedPlaceholder(cast(op), value, createPartialValue, - getValueHandle, loc); - } - -protected: - virtual FailureOr - createTypedPlaceholder(OpT op, Value value, - CreatePartialValueFn createPartialValue, - GetValueHandleFn getValueHandle, Location loc) const { - return OperationPattern::createPlaceholder(op, value, createPartialValue, - getValueHandle, loc); - } + Attribute foldedAttr; + auto foldedResult = foldResults.front(); + if (auto attr = llvm::dyn_cast(foldedResult)) { + foldedAttr = attr; + } else + return op->emitError( + "folder returned operands even though all operands are constant, " + "consider enhance the folder or avoid using the folder for this op in " + "the evaluator"); - virtual LogicalResult - evaluateTyped(OpT op, ArrayRef operands, - evaluator::EvaluatorValuePtr resultValue, - Location loc) const = 0; -}; - -//===----------------------------------------------------------------------===// -// Operation patterns -//===----------------------------------------------------------------------===// - -class IntegerBinaryArithmeticPattern final - : public OpReadyOperandsPattern { -public: - using OpReadyOperandsPattern::OpReadyOperandsPattern; - -protected: - LogicalResult evaluateTyped(IntegerBinaryArithmeticOp op, - ArrayRef operands, - evaluator::EvaluatorValuePtr resultValue, - Location loc) const override { - assert(operands.size() == 2 && "expected binary arithmetic operands"); - - circt::om::IntegerAttr lhs = getAsAttr(operands[0]); - circt::om::IntegerAttr rhs = getAsAttr(operands[1]); - assert(lhs && "expected om::IntegerAttr for IntegerBinaryArithmeticOp lhs"); - assert(rhs && "expected om::IntegerAttr for IntegerBinaryArithmeticOp rhs"); - - APSInt lhsVal = lhs.getValue().getAPSInt(); - APSInt rhsVal = rhs.getValue().getAPSInt(); - if (lhsVal.getBitWidth() > rhsVal.getBitWidth()) - rhsVal = rhsVal.extend(lhsVal.getBitWidth()); - else if (rhsVal.getBitWidth() > lhsVal.getBitWidth()) - lhsVal = lhsVal.extend(rhsVal.getBitWidth()); - - FailureOr result = op.evaluateIntegerOperation(lhsVal, rhsVal); - if (failed(result)) - return op->emitError("failed to evaluate integer operation"); - - MLIRContext *ctx = op.getContext(); - auto resultAttr = circt::om::IntegerAttr::get( - ctx, mlir::IntegerAttr::get(ctx, result.value())); - return setAttrResult(std::move(resultValue), resultAttr); - } -}; - -class ListCreatePattern final : public OpReadyOperandsPattern { -public: - using OpReadyOperandsPattern::OpReadyOperandsPattern; - -protected: - LogicalResult evaluateTyped(ListCreateOp op, - ArrayRef operands, - evaluator::EvaluatorValuePtr resultValue, - Location loc) const override { - SmallVector values; - values.reserve(operands.size()); - for (auto operand : operands) - values.push_back(operand); - - cast(resultValue.get()) - ->setElements(std::move(values)); - return success(); - } -}; - -class ListConcatPattern final : public OpReadyOperandsPattern { -public: - using OpReadyOperandsPattern::OpReadyOperandsPattern; - -protected: - LogicalResult evaluateTyped(ListConcatOp op, - ArrayRef operands, - evaluator::EvaluatorValuePtr resultValue, - Location loc) const override { - SmallVector values; - for (auto operand : operands) { - auto *subListValue = getReadyAs(operand); - llvm::append_range(values, subListValue->getElements()); - } - - cast(resultValue.get()) - ->setElements(std::move(values)); - return success(); - } -}; - -class StringConcatPattern final - : public OpReadyOperandsPattern { -public: - using OpReadyOperandsPattern::OpReadyOperandsPattern; - -protected: - LogicalResult evaluateTyped(StringConcatOp op, - ArrayRef operands, - evaluator::EvaluatorValuePtr resultValue, - Location loc) const override { - std::string result; - for (auto operand : operands) { - auto attr = getAsAttr(operand); - assert(attr && "expected StringAttr for StringConcatOp operand"); - result += attr.getValue().str(); - } - - auto resultStr = StringAttr::get(result, op.getResult().getType()); - return setAttrResult(std::move(resultValue), resultStr); - } -}; - -class BinaryEqualityPattern final - : public OpReadyOperandsPattern { -public: - using OpReadyOperandsPattern::OpReadyOperandsPattern; - -protected: - LogicalResult evaluateTyped(BinaryEqualityOp op, - ArrayRef operands, - evaluator::EvaluatorValuePtr resultValue, - Location loc) const override { - assert(operands.size() == 2 && "expected binary equality operands"); - - mlir::Attribute lhs = getAsAttr(operands[0]); - mlir::Attribute rhs = getAsAttr(operands[1]); - - FailureOr result = op.evaluateBinaryEquality(lhs, rhs); - if (failed(result)) - return op->emitError("failed to evaluate binary equality operation"); - return setAttrResult(std::move(resultValue), *result); - } -}; - -class FrozenBasePathCreatePattern final - : public OpReadyOperandsPattern { -public: - using OpReadyOperandsPattern::OpReadyOperandsPattern; - -protected: - FailureOr - createTypedPlaceholder(FrozenBasePathCreateOp op, Value value, - CreatePartialValueFn createPartialValue, - GetValueHandleFn getValueHandle, - Location loc) const override { - return success( - std::make_shared(op.getPathAttr(), loc)); - } - - LogicalResult evaluateTyped(FrozenBasePathCreateOp op, - ArrayRef operands, - evaluator::EvaluatorValuePtr resultValue, - Location loc) const override { - assert(operands.size() == 1 && - "expected one operand for frozenbasepath_create"); - - auto *basePathValue = - getReadyAs(operands.front()); - cast(resultValue.get()) - ->setBasepath(*basePathValue); - return success(); - } -}; - -class FrozenPathCreatePattern final - : public OpReadyOperandsPattern { -public: - using OpReadyOperandsPattern::OpReadyOperandsPattern; - -protected: - FailureOr - createTypedPlaceholder(FrozenPathCreateOp pathOp, Value value, - CreatePartialValueFn createPartialValue, - GetValueHandleFn getValueHandle, - Location loc) const override { - return success(std::make_shared( - pathOp.getTargetKindAttr(), pathOp.getPathAttr(), - pathOp.getModuleAttr(), pathOp.getRefAttr(), pathOp.getFieldAttr(), - loc)); - } - - LogicalResult evaluateTyped(FrozenPathCreateOp op, - ArrayRef operands, - evaluator::EvaluatorValuePtr resultValue, - Location loc) const override { - assert(operands.size() == 1 && - "expected one operand for frozenpath_create"); - - auto *basePathValue = - getReadyAs(operands.front()); - cast(resultValue.get())->setBasepath(*basePathValue); - return success(); - } -}; - -class ConstantPattern final : public OpPattern { -public: - using OpPattern::OpPattern; - -protected: - FailureOr createTypedPlaceholder( - ConstantOp op, Value value, CreatePartialValueFn createPartialValue, - GetValueHandleFn getValueHandle, Location loc) const override { - return success( - circt::om::evaluator::AttributeValue::get(op.getValue(), loc)); - } -}; + return setAttrResult(std::move(resultValue), foldedAttr); +} -class AnyCastPattern final : public OpPattern { -public: - using OpPattern::OpPattern; +} // namespace circt::om::detail -protected: - FailureOr createTypedPlaceholder( - AnyCastOp op, Value value, CreatePartialValueFn createPartialValue, - GetValueHandleFn getValueHandle, Location loc) const override { - return getValueHandle(op.getInput(), loc); - } +using circt::om::detail::getAsAttr; +using circt::om::detail::getReadyAs; +using circt::om::detail::isUnknownReadyValue; +using circt::om::detail::requireReady; +using circt::om::detail::resolveReferenceValue; +using circt::om::detail::resolveValueState; - ResolvedValue - evaluateTyped(AnyCastOp op, evaluator::EvaluatorValuePtr resultValue, - llvm::function_ref evaluateValue, - Location loc) const override { - if (resultValue && resultValue->isSettled()) - return resolveValueState(std::move(resultValue)); - return evaluateValue(op.getInput()); - } -}; - -class FrozenEmptyPathPattern final : public OpPattern { -public: - using OpPattern::OpPattern; - -protected: - FailureOr - createTypedPlaceholder(FrozenEmptyPathOp op, Value value, - CreatePartialValueFn createPartialValue, - GetValueHandleFn getValueHandle, - Location loc) const override { - return success(std::make_shared( - evaluator::PathValue::getEmptyPath(loc))); - } -}; +namespace { //===----------------------------------------------------------------------===// // Operaton Pattern Registery //===----------------------------------------------------------------------===// -class OperationPatternRegistry { -public: - OperationPatternRegistry() { - addPattern(); - addPattern(); - addPattern(); - addPattern(); - addPattern(); - addPattern(); - addPattern(); - addPattern(); - addPattern(); - addPattern(); - addPattern(); - addPattern(); - addPattern(); - } - - const OperationPattern *lookup(Operation *op) const { - auto it = patternsByOpName.find(op->getName().getStringRef()); - return it == patternsByOpName.end() ? nullptr : it->second; - } - -private: - template - void addPattern() { - auto pattern = std::make_unique(OpT::getOperationName()); - const OperationPattern *patternPtr = pattern.get(); - patterns.push_back(std::move(pattern)); - patternsByOpName[OpT::getOperationName()] = patternPtr; - } - - SmallVector> patterns; - llvm::StringMap patternsByOpName; -}; - -static const OperationPatternRegistry &getOperationPatternRegistry() { - static const OperationPatternRegistry registry; +static const circt::om::detail::OperationPatternRegistry & +getOperationPatternRegistry() { + static const circt::om::detail::OperationPatternRegistry registry = [] { + circt::om::detail::OperationPatternRegistry registry; + circt::om::detail::registerOperationPatterns(registry); + return registry; + }(); return registry; } @@ -695,14 +306,14 @@ FailureOr circt::om::Evaluator::getOrCreateValue( Operation *op = result.getDefiningOp(); if (auto *pattern = getOperationPatternRegistry().lookup(op)) - return pattern->createPlaceholder( + return pattern->createInitialValue( op, value, - [&](Type type, Location placeholderLoc) { - return getPartiallyEvaluatedValue(type, placeholderLoc); + [&](Type type, Location initialValueLoc) { + return getPartiallyEvaluatedValue(type, initialValueLoc); }, - [&](Value aliasedValue, Location placeholderLoc) { + [&](Value aliasedValue, Location initialValueLoc) { return getOrCreateValue(aliasedValue, actualParams, - placeholderLoc); + initialValueLoc); }, loc); @@ -896,7 +507,7 @@ circt::om::Evaluator::instantiate( auto &object = result.value(); // Finalize the value. This will eliminate intermidiate ReferenceValue used as - // a placeholder in the initialization. + // an initial value during initialization. if (failed(object->finalize())) return cls.emitError() << "failed to finalize evaluation. Probably the " "class contains a dataflow cycle"; diff --git a/lib/Dialect/OM/Evaluator/EvaluatorPatterns.cpp b/lib/Dialect/OM/Evaluator/EvaluatorPatterns.cpp new file mode 100644 index 000000000000..4f78a6766868 --- /dev/null +++ b/lib/Dialect/OM/Evaluator/EvaluatorPatterns.cpp @@ -0,0 +1,174 @@ +//===- EvaluatorPatterns.cpp - OM evaluator concrete patterns -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the concrete OM evaluator operation patterns. +// +//===----------------------------------------------------------------------===// + +#include "EvaluatorPatterns.h" +#include + +using namespace mlir; +using namespace circt::om; + +namespace circt { +namespace om { +namespace detail { + +class ListCreatePattern final + : public OpWhenOperandsReadyPattern { +public: + using OpWhenOperandsReadyPattern::OpWhenOperandsReadyPattern; + +protected: + LogicalResult evaluateTyped(ListCreateOp op, + ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const override { + SmallVector values; + values.reserve(operands.size()); + for (auto operand : operands) + values.push_back(std::move(operand)); + + cast(resultValue.get()) + ->setElements(std::move(values)); + return success(); + } +}; + +class ListConcatPattern final + : public OpWhenOperandsReadyPattern { +public: + using OpWhenOperandsReadyPattern::OpWhenOperandsReadyPattern; + +protected: + LogicalResult evaluateTyped(ListConcatOp op, + ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const override { + SmallVector values; + for (const auto &operand : operands) { + auto *subListValue = getReadyAs(operand); + llvm::append_range(values, subListValue->getElements()); + } + + cast(resultValue.get()) + ->setElements(std::move(values)); + return success(); + } +}; + +class FrozenBasePathCreatePattern final + : public OpWhenOperandsReadyPattern { +public: + using OpWhenOperandsReadyPattern::OpWhenOperandsReadyPattern; + +protected: + FailureOr + createInitialValueFor(FrozenBasePathCreateOp op, Value value, + GetPartialValueForTypeFn getPartialValueForType, + GetValueForFn getValueFor, + Location loc) const override { + return success( + std::make_shared(op.getPathAttr(), loc)); + } + + LogicalResult evaluateTyped(FrozenBasePathCreateOp op, + ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const override { + assert(operands.size() == 1 && + "expected one operand for frozenbasepath_create"); + + auto *basePathValue = + getReadyAs(operands.front()); + cast(resultValue.get()) + ->setBasepath(*basePathValue); + return success(); + } +}; + +class FrozenPathCreatePattern final + : public OpWhenOperandsReadyPattern { +public: + using OpWhenOperandsReadyPattern::OpWhenOperandsReadyPattern; + +protected: + FailureOr + createInitialValueFor(FrozenPathCreateOp pathOp, Value value, + GetPartialValueForTypeFn getPartialValueForType, + GetValueForFn getValueFor, + Location loc) const override { + return success(std::make_shared( + pathOp.getTargetKindAttr(), pathOp.getPathAttr(), + pathOp.getModuleAttr(), pathOp.getRefAttr(), pathOp.getFieldAttr(), + loc)); + } + + LogicalResult evaluateTyped(FrozenPathCreateOp op, + ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const override { + assert(operands.size() == 1 && + "expected one operand for frozenpath_create"); + + auto *basePathValue = + getReadyAs(operands.front()); + cast(resultValue.get())->setBasepath(*basePathValue); + return success(); + } +}; + +class AnyCastPattern final : public OpPattern { +public: + using OpPattern::OpPattern; + +protected: + FailureOr + createInitialValueFor(AnyCastOp op, Value value, + GetPartialValueForTypeFn getPartialValueForType, + GetValueForFn getValueFor, + Location loc) const override { + return getValueFor(op.getInput(), loc); + } +}; + +class FrozenEmptyPathPattern final : public OpPattern { +public: + using OpPattern::OpPattern; + +protected: + FailureOr + createInitialValueFor(FrozenEmptyPathOp op, Value value, + GetPartialValueForTypeFn getPartialValueForType, + GetValueForFn getValueFor, + Location loc) const override { + return success(std::make_shared( + evaluator::PathValue::getEmptyPath(loc))); + } +}; + +void registerOperationPatterns(OperationPatternRegistry ®istry) { + registry.addPattern>(); + registry.addPattern(); + registry.addPattern(); + registry.addPattern>(); + registry.addPattern>(); + registry.addPattern>(); + registry.addPattern>(); + registry.addPattern(); + registry.addPattern(); + registry.addPattern>(); + registry.addPattern>(); + registry.addPattern(); + registry.addPattern(); +} + +} // namespace detail +} // namespace om +} // namespace circt diff --git a/lib/Dialect/OM/Evaluator/EvaluatorPatterns.h b/lib/Dialect/OM/Evaluator/EvaluatorPatterns.h new file mode 100644 index 000000000000..e7063a71db47 --- /dev/null +++ b/lib/Dialect/OM/Evaluator/EvaluatorPatterns.h @@ -0,0 +1,273 @@ +//===- EvaluatorPatterns.h - OM evaluator pattern details -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Stuff shared between Evaluator.cpp and EvaluatorPatterns.cpp. +// +//===----------------------------------------------------------------------===// + +// clang-tidy seems to expect the absolute path in the header guard on some +// systems, so just disable it. +// NOLINTNEXTLINE(llvm-header-guard) +#ifndef DIALECT_OM_EVALUATOR_EVALUATORPATTERNS_H +#define DIALECT_OM_EVALUATOR_EVALUATORPATTERNS_H + +#include "circt/Dialect/OM/Evaluator/Evaluator.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" +#include +#include +#include + +namespace circt { +namespace om { +namespace detail { + +using ResolutionState = evaluator::ResolutionState; +using ResolvedValue = evaluator::ResolvedValue; + +ResolvedValue resolveValueState(evaluator::EvaluatorValuePtr currentValue); +evaluator::EvaluatorValue * +resolveReadyValue(evaluator::EvaluatorValuePtr value); +bool isUnknownReadyValue(evaluator::EvaluatorValuePtr value); + +std::optional +requireReady(const ResolvedValue &resolved, + evaluator::EvaluatorValuePtr pendingValue, + llvm::function_ref emitFailure, + evaluator::EvaluatorValuePtr &readyValue); + +std::optional requireAllOperandsReady( + ValueRange operands, evaluator::EvaluatorValuePtr pendingValue, + llvm::function_ref evaluateOperand, + llvm::function_ref emitFailure, + SmallVectorImpl &readyOperands, + bool &existsUnknown); + +ResolvedValue markUnknownAndReturn(evaluator::EvaluatorValuePtr value); +LogicalResult setAttrResult(evaluator::EvaluatorValuePtr resultValue, + Attribute attr); +LogicalResult foldSingleResultOperation( + Operation *op, ArrayRef readyOperands, + evaluator::EvaluatorValuePtr resultValue, + StringRef failureMessage = "failed to evaluate operation"); + +template +ValueT *getReadyAs(evaluator::EvaluatorValuePtr value) { + auto *typedValue = + llvm::dyn_cast(resolveReadyValue(std::move(value))); + assert(typedValue); + return typedValue; +} + +template +AttrT getAsAttr(evaluator::EvaluatorValuePtr value) { + return llvm::dyn_cast( + getReadyAs(std::move(value))->getAttr()); +} + +/// Base class for one OM operation in the evaluator. +/// A pattern picks an initial value for the result and later fills it in. +class OperationPattern { +public: + using GetPartialValueForTypeFn = + llvm::function_ref(Type, + Location)>; + using GetValueForFn = + llvm::function_ref(Value, + Location)>; + + explicit OperationPattern(StringRef operationName) + : operationName(operationName) {} + virtual ~OperationPattern() = default; + + StringRef getOperationName() const { return operationName; } + FailureOr + createInitialValue(Operation *op, Value value, + GetPartialValueForTypeFn getPartialValueForType, + GetValueForFn getValueFor, Location loc) const { + return createInitialValueImpl(op, value, getPartialValueForType, getValueFor, + loc); + } + + virtual ResolvedValue + evaluate(Operation *op, evaluator::EvaluatorValuePtr resultValue, + llvm::function_ref evaluateValue, + Location loc) const = 0; + +protected: + static FailureOr + ccreateDefaultInitialValue(Value value, + GetPartialValueForTypeFn getPartialValueForType, + Location loc) { + return getPartialValueForType(value.getType(), loc); + } + +private: + virtual FailureOr + createInitialValueImpl(Operation *op, Value value, + GetPartialValueForTypeFn getPartialValueForType, + GetValueForFn getValueFor, Location loc) const { + return ccreateDefaultInitialValue(value, getPartialValueForType, loc); + } + + StringRef operationName; +}; + +template +class OpPattern : public OperationPattern { +public: + using OperationPattern::OperationPattern; + +private: + FailureOr + createInitialValueImpl(Operation *op, Value value, + GetPartialValueForTypeFn getPartialValueForType, + GetValueForFn getValueFor, Location loc) const final { + return createInitialValueFor(cast(op), value, getPartialValueForType, + getValueFor, loc); + } + + ResolvedValue evaluate(Operation *op, + evaluator::EvaluatorValuePtr resultValue, + llvm::function_ref evaluateValue, + Location loc) const override { + return evaluateTyped(cast(op), std::move(resultValue), evaluateValue, + loc); + } + +protected: + virtual FailureOr + createInitialValueFor(OpT op, Value value, + GetPartialValueForTypeFn getPartialValueForType, + GetValueForFn getValueFor, Location loc) const { + return ccreateDefaultInitialValue(value, getPartialValueForType, loc); + } + + virtual ResolvedValue + evaluateTyped(OpT op, evaluator::EvaluatorValuePtr resultValue, + llvm::function_ref evaluateValue, + Location loc) const { + return resolveValueState(std::move(resultValue)); + } +}; + +/// Base class for operations that only run once all operands are ready. +/// This handles the shared ready/pending/failure/unknown logic so concrete +/// patterns only implement the successful case. +template +class OpWhenOperandsReadyPattern : public OpPattern { +public: + using OpPattern::OpPattern; + using OpPattern::evaluateTyped; + +private: + ResolvedValue evaluate(Operation *op, + evaluator::EvaluatorValuePtr resultValue, + llvm::function_ref evaluateValue, + Location loc) const final { + if (resultValue && resultValue->isSettled()) + return resolveValueState(std::move(resultValue)); + + SmallVector readyOperands; + bool existsUnknown = false; + if (auto early = requireAllOperandsReady( + op->getOperands(), resultValue, evaluateValue, + [&] { + op->emitError() << "failed to resolve " + << this->getOperationName() << " operand"; + }, + readyOperands, existsUnknown)) + return *early; + // If any operand is unknown, the result is unknown too. + if (existsUnknown) + return markUnknownAndReturn(std::move(resultValue)); + + if (failed(evaluateTyped(cast(op), readyOperands, resultValue, loc))) + return ResolvedValue::failure(); + return resolveValueState(std::move(resultValue)); + } + +protected: + virtual LogicalResult + evaluateTyped(OpT op, ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const = 0; +}; + +/// Base class for single-result ops that can reuse the op's regular folder once +/// all operands are ready. +template +class OpFolderPattern : public OpWhenOperandsReadyPattern { +public: + using OpWhenOperandsReadyPattern::OpWhenOperandsReadyPattern; + +protected: + FailureOr + createInitialValueFor(OpT op, Value value, + typename OperationPattern::GetPartialValueForTypeFn + getPartialValueForType, + typename OperationPattern::GetValueForFn getValueFor, + Location loc) const override { + SmallVector operandAttrs(op->getNumOperands()); + SmallVector foldResults; + if (succeeded(op->fold(operandAttrs, foldResults)) && + foldResults.size() == 1) { + if (auto foldedAttr = llvm::dyn_cast(foldResults.front())) + return success( + circt::om::evaluator::AttributeValue::get(foldedAttr, loc)); + if (auto foldedValue = llvm::dyn_cast(foldResults.front())) + return getValueFor(foldedValue, loc); + } + return OpPattern::createInitialValueFor(op, value, getPartialValueForType, + getValueFor, loc); + } + + LogicalResult evaluateTyped(OpT op, + ArrayRef operands, + evaluator::EvaluatorValuePtr resultValue, + Location loc) const override { + return foldSingleResultOperation(op.getOperation(), operands, + std::move(resultValue)); + } +}; + +class OperationPatternRegistry { +public: + const OperationPattern *lookup(Operation *op) const { + auto it = patternsByOpName.find(op->getName().getStringRef()); + return it == patternsByOpName.end() ? nullptr : it->second; + } + + template + void addPattern() { + auto pattern = std::make_unique(OpT::getOperationName()); + const OperationPattern *patternPtr = pattern.get(); + patterns.push_back(std::move(pattern)); + patternsByOpName[OpT::getOperationName()] = patternPtr; + } + +private: + SmallVector> patterns; + llvm::StringMap patternsByOpName; +}; + +void registerOperationPatterns(OperationPatternRegistry ®istry); + +} // namespace detail +} // namespace om +} // namespace circt + +#endif // DIALECT_OM_EVALUATOR_EVALUATORPATTERNS_H diff --git a/lib/Dialect/OM/OMOps.cpp b/lib/Dialect/OM/OMOps.cpp index d92b34fbb2dc..148bc7332839 100644 --- a/lib/Dialect/OM/OMOps.cpp +++ b/lib/Dialect/OM/OMOps.cpp @@ -643,12 +643,40 @@ PathCreateOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // IntegerAddOp //===----------------------------------------------------------------------===// +template +static OpFoldResult +foldIntegerBinaryArithmetic(OpT op, typename OpT::FoldAdaptor adaptor) { + auto lhs = dyn_cast_or_null(adaptor.getLhs()); + auto rhs = dyn_cast_or_null(adaptor.getRhs()); + if (!lhs || !rhs) + return {}; + + APSInt lhsVal = lhs.getValue().getAPSInt(); + APSInt rhsVal = rhs.getValue().getAPSInt(); + if (lhsVal.getBitWidth() > rhsVal.getBitWidth()) + rhsVal = rhsVal.extend(lhsVal.getBitWidth()); + else if (rhsVal.getBitWidth() > lhsVal.getBitWidth()) + lhsVal = lhsVal.extend(rhsVal.getBitWidth()); + + auto result = op.evaluateIntegerOperation(lhsVal, rhsVal); + if (failed(result)) + return {}; + + auto *ctx = op.getContext(); + return circt::om::IntegerAttr::get( + ctx, mlir::IntegerAttr::get(ctx, result.value())); +} + FailureOr IntegerAddOp::evaluateIntegerOperation(const llvm::APSInt &lhs, const llvm::APSInt &rhs) { return success(lhs + rhs); } +OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) { + return foldIntegerBinaryArithmetic(*this, adaptor); +} + //===----------------------------------------------------------------------===// // IntegerMulOp //===----------------------------------------------------------------------===// @@ -659,6 +687,10 @@ IntegerMulOp::evaluateIntegerOperation(const llvm::APSInt &lhs, return success(lhs * rhs); } +OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) { + return foldIntegerBinaryArithmetic(*this, adaptor); +} + //===----------------------------------------------------------------------===// // IntegerShrOp //===----------------------------------------------------------------------===// @@ -675,6 +707,10 @@ IntegerShrOp::evaluateIntegerOperation(const llvm::APSInt &lhs, return success(lhs >> rhs.getExtValue()); } +OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) { + return foldIntegerBinaryArithmetic(*this, adaptor); +} + //===----------------------------------------------------------------------===// // IntegerShlOp //===----------------------------------------------------------------------===// @@ -691,14 +727,22 @@ IntegerShlOp::evaluateIntegerOperation(const llvm::APSInt &lhs, return success(lhs << rhs.getExtValue()); } +OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) { + return foldIntegerBinaryArithmetic(*this, adaptor); +} + //===----------------------------------------------------------------------===// // StringConcatOp //===----------------------------------------------------------------------===// OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) { // Fold single-operand concat to just the operand. - if (getStrings().size() == 1) + if (getStrings().size() == 1) { + if (auto strAttr = adaptor.getStrings()[0]) + return strAttr; + return getStrings()[0]; + } // Check if all operands are constant strings before accumulating. if (!llvm::all_of(adaptor.getStrings(), [](Attribute operand) { diff --git a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp index 4dd77d45e097..79243aeffa82 100644 --- a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp +++ b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp @@ -798,7 +798,7 @@ om.class @IntegerBinaryArithmeticShrNegative() -> (result: !om.integer){ ASSERT_EQ(diag.str(), "'om.integer.shr' op shift amount must be non-negative"); if (StringRef(diag.str()).starts_with("failed")) - ASSERT_EQ(diag.str(), "failed to evaluate integer operation"); + ASSERT_EQ(diag.str(), "failed to evaluate operation"); }); OwningOpRef owning = @@ -834,7 +834,7 @@ om.class @IntegerBinaryArithmeticShrTooLarge() -> (result: !om.integer){ diag.str(), "'om.integer.shr' op shift amount must be representable in 64 bits"); if (StringRef(diag.str()).starts_with("failed")) - ASSERT_EQ(diag.str(), "failed to evaluate integer operation"); + ASSERT_EQ(diag.str(), "failed to evaluate operation"); }); OwningOpRef owning = @@ -905,7 +905,7 @@ om.class @IntegerBinaryArithmeticShlNegative() -> (result: !om.integer) { ASSERT_EQ(diag.str(), "'om.integer.shl' op shift amount must be non-negative"); if (StringRef(diag.str()).starts_with("failed")) - ASSERT_EQ(diag.str(), "failed to evaluate integer operation"); + ASSERT_EQ(diag.str(), "failed to evaluate operation"); }); OwningOpRef owning = @@ -941,7 +941,7 @@ om.class @IntegerBinaryArithmeticShlTooLarge() -> (result: !om.integer) { diag.str(), "'om.integer.shl' op shift amount must be representable in 64 bits"); if (StringRef(diag.str()).starts_with("failed")) - ASSERT_EQ(diag.str(), "failed to evaluate integer operation"); + ASSERT_EQ(diag.str(), "failed to evaluate operation"); }); OwningOpRef owning = @@ -1761,6 +1761,41 @@ module { .getValue()); } +TEST(EvaluatorTests, StringConcatSingleOperand) { + const char *mod = R"MLIR( +module { + om.class @Test() -> (result: !om.string) { + %0 = om.constant "Hello" : !om.string + %1 = om.string.concat %0 : !om.string + om.class.fields %1 : !om.string + } +} +)MLIR"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = evaluator.instantiate(StringAttr::get(&context, "Test"), {}); + + ASSERT_TRUE(succeeded(result)); + + auto fieldValue = llvm::cast(result.value().get()) + ->getField("result") + .value(); + + ASSERT_EQ("Hello", llvm::cast(fieldValue.get()) + ->getAs() + .getValue()); +} + TEST(EvaluatorTests, UnknownObjectFieldTest) { StringRef mod = R"MLIR( om.class.extern @Dut_Class(%basepath: !om.frozenbasepath) -> (omirOut: !om.list) {