diff --git a/include/circt/Dialect/OM/Evaluator/Evaluator.h b/include/circt/Dialect/OM/Evaluator/Evaluator.h index c29c10872cae..593655bae5bc 100644 --- a/include/circt/Dialect/OM/Evaluator/Evaluator.h +++ b/include/circt/Dialect/OM/Evaluator/Evaluator.h @@ -61,6 +61,14 @@ class EvaluatorValue : public std::enable_shared_from_this { void markFullyEvaluated() { assert(!fullyEvaluated && "should not mark twice"); fullyEvaluated = true; + // Increment the counter if one is set. + if (fullyEvaluatedCounter) + ++(*fullyEvaluatedCounter); + } + + /// Set a counter to increment when this value becomes fully evaluated. + void setFullyEvaluatedCounter(uint64_t *counter) { + fullyEvaluatedCounter = counter; } /// Return true if the value is unknown (has unknown in its fan-in). @@ -104,6 +112,7 @@ class EvaluatorValue : public std::enable_shared_from_this { bool fullyEvaluated = false; bool finalized = false; bool unknown = false; + uint64_t *fullyEvaluatedCounter = nullptr; }; /// Values which can be used as pointers to different values. @@ -393,6 +402,9 @@ class Evaluator { using ObjectKey = std::pair; + /// Get the number of fully evaluated nodes tracked by this evaluator. + uint64_t getFullyEvaluatedCount() const { return fullyEvaluatedCount; } + private: bool isFullyEvaluated(Value value, ActualParameters key) { return isFullyEvaluated({value, key}); @@ -403,6 +415,12 @@ class Evaluator { return val && val->isFullyEvaluated(); } + /// Attach the evaluation counter to a newly created value. + void attachCounter(evaluator::EvaluatorValuePtr &value) { + if (value && !value->isFullyEvaluated()) + value->setFullyEvaluatedCounter(&fullyEvaluatedCount); + } + FailureOr getOrCreateValue(Value value, ActualParameters actualParams, Location loc); FailureOr @@ -478,8 +496,11 @@ class Evaluator { std::unique_ptr>>> actualParametersBuffers; - /// A worklist that tracks values which needs to be fully evaluated. - std::queue worklist; + /// Worklists that track values which need to be fully evaluated. + /// We use two worklists to detect cycles: process all items from one, + /// and if any become fully evaluated, swap and continue. + std::vector worklist; + std::vector nextWorklist; /// A queue of pending property assertions to be evaluated after the worklist /// is fully drained. Each entry is a (PropertyAssertOp, ActualParameters) @@ -492,6 +513,9 @@ class Evaluator { /// instantiation context (a pair of Value and parameters). DenseMap> objects; + /// Counter for fully evaluated nodes. + uint64_t fullyEvaluatedCount = 0; + #ifndef NDEBUG /// Current nesting depth for debug output indentation. unsigned debugNesting = 0; diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index 18ff9f98e5d2..1f722b78c3ad 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -70,32 +70,38 @@ FailureOr circt::om::Evaluator::getPartiallyEvaluatedValue(Type type, Location loc) { using namespace circt::om::evaluator; - return TypeSwitch>(type) - .Case([&](circt::om::ListType type) { - evaluator::EvaluatorValuePtr result = - std::make_shared(type, loc); - return success(result); - }) - .Case([&](circt::om::ClassType type) - -> FailureOr { - auto classDef = - symbolTable.lookup(type.getClassName().getValue()); - if (!classDef) - return symbolTable.getOp()->emitError("unknown class name ") - << type.getClassName(); - - // Create an ObjectValue for both ClassOp and ClassExternOp - evaluator::EvaluatorValuePtr result = - std::make_shared(classDef, loc); - - return success(result); - }) - .Case([&](circt::om::StringType type) { - evaluator::EvaluatorValuePtr result = - evaluator::AttributeValue::get(type, loc); - return success(result); - }) - .Default([&](auto type) { return failure(); }); + auto result = + TypeSwitch>(type) + .Case([&](circt::om::ListType type) { + evaluator::EvaluatorValuePtr result = + std::make_shared(type, loc); + return success(result); + }) + .Case([&](circt::om::ClassType type) + -> FailureOr { + auto classDef = + symbolTable.lookup(type.getClassName().getValue()); + if (!classDef) + return symbolTable.getOp()->emitError("unknown class name ") + << type.getClassName(); + + // Create an ObjectValue for both ClassOp and ClassExternOp + evaluator::EvaluatorValuePtr result = + std::make_shared(classDef, loc); + + return success(result); + }) + .Case([&](circt::om::StringType type) { + evaluator::EvaluatorValuePtr result = + evaluator::AttributeValue::get(type, loc); + return success(result); + }) + .Default([&](auto type) { return failure(); }); + + if (succeeded(result)) + attachCounter(result.value()); + + return result; } FailureOr circt::om::Evaluator::getOrCreateValue( @@ -186,6 +192,8 @@ FailureOr circt::om::Evaluator::getOrCreateValue( if (failed(result)) return result; + // Attach listener to newly created values + attachCounter(result.value()); objects[{value, actualParams}] = result.value(); return result; } @@ -212,6 +220,7 @@ circt::om::Evaluator::evaluateObjectInstance(StringAttr className, if (isa(classDef)) { evaluator::EvaluatorValuePtr result = std::make_shared(classDef, loc); + attachCounter(result); result->markUnknown(); LLVM_DEBUG(dbgs(1) << "extern: \n"); return result; @@ -283,7 +292,7 @@ circt::om::Evaluator::evaluateObjectInstance(StringAttr className, UnknownLoc::get(context)))) return failure(); // Add to the worklist. - worklist.push({result, actualParams}); + worklist.push_back({result, actualParams}); } } @@ -329,6 +338,9 @@ circt::om::Evaluator::evaluateObjectInstance(StringAttr className, // If it's external call, just allocate new ObjectValue. evaluator::EvaluatorValuePtr result = std::make_shared(cls, fields, loc); + // Object is already fully evaluated when created with fields. + assert(result->isFullyEvaluated() && + "object with fields should be fully evaluated"); return result; } @@ -355,6 +367,7 @@ circt::om::Evaluator::instantiate( evaluator::EvaluatorValuePtr result = std::make_shared( classDef, UnknownLoc::get(classDef.getContext())); + attachCounter(result); result->markUnknown(); LLVM_DEBUG(dbgs(1) << "result: \n"); return result; @@ -380,18 +393,46 @@ circt::om::Evaluator::instantiate( // `evaluateObjectInstance` has populated the worklist. Continue evaluations // unless there is a partially evaluated value. LLVM_DEBUG(dbgs() << "worklist:\n"); - while (!worklist.empty()) { - auto [value, args] = worklist.front(); - worklist.pop(); - - auto result = evaluateValue(value, args, loc); - if (failed(result)) - return failure(); + // Use two-worklist approach: process all items from current worklist, and if + // at least one becomes fully evaluated, swap and continue. If a full pass + // completes with no progress, we have a cycle. + while (!worklist.empty()) { + uint64_t countBeforePass = fullyEvaluatedCount; + LLVM_DEBUG(dbgs() << "- processing " << worklist.size() + << " items (fully evaluated count: " + << fullyEvaluatedCount << ")\n"); + + // Process all items in the current worklist. + while (!worklist.empty()) { + auto [value, args] = worklist.back(); + worklist.pop_back(); + auto result = evaluateValue(value, args, loc); + + if (failed(result)) + return failure(); + + // If not fully evaluated, add to next worklist for retry. + if (!result.value()->isFullyEvaluated()) + nextWorklist.push_back({value, args}); + } - // It's possible that the value is not fully evaluated. - if (!result.value()->isFullyEvaluated()) - worklist.push({value, args}); + // Check if we made progress. + uint64_t evaluatedThisPass = fullyEvaluatedCount - countBeforePass; + LLVM_DEBUG(dbgs() << "- evaluated " << evaluatedThisPass + << " nodes this pass\n"); + + // If nothing became fully evaluated in this pass, we have a cycle. + if (evaluatedThisPass == 0 && !nextWorklist.empty()) + return cls.emitError() + << "cycle detected: " << nextWorklist.size() + << " values remain partially evaluated after full pass with no " + "progress (total fully evaluated: " + << fullyEvaluatedCount << ")"; + + // Swap worklists for next iteration. + worklist = std::move(nextWorklist); + nextWorklist.clear(); } // Now that all values are fully resolved, evaluate the deferred property @@ -732,6 +773,9 @@ circt::om::Evaluator::evaluateObjectField(ObjectFieldOp op, currentObject = nextObject; } + if (!finalField->isFullyEvaluated()) + return objectFieldValue; + // Update the reference. llvm::cast(objectFieldValue.get()) ->setValue(finalField); @@ -1044,7 +1088,7 @@ circt::om::Evaluator::createUnknownValue(Type type, Location loc) { return success(AttributeValue::get(type, LocationAttr(loc))); }); - // Mark the result as unknown if successful + // Mark the result as unknown if successful. if (succeeded(result)) result->get()->markUnknown(); diff --git a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp index 398f9fd3b675..e9a02a5df2c8 100644 --- a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp +++ b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp @@ -654,8 +654,9 @@ om.class @ReferenceEachOther() -> (field: !ty){ context.getOrLoadDialect(); context.getDiagEngine().registerHandler([&](Diagnostic &diag) { - ASSERT_EQ(diag.str(), "failed to finalize evaluation. Probably the class " - "contains a dataflow cycle"); + ASSERT_EQ(diag.str(), + "cycle detected: 1 values remain partially evaluated after full " + "pass with no progress (total fully evaluated: 1)"); }); OwningOpRef owning = @@ -669,6 +670,50 @@ om.class @ReferenceEachOther() -> (field: !ty){ ASSERT_TRUE(failed(result)); } +// Test nested object field references. +// https://github.com/llvm/circt/issues/10264 +TEST(EvaluatorTests, Issue10264NestedFieldReferences) { + StringRef mod = R"MLIR( +om.class @Domain(%in: !om.string) -> (out: !om.string) { + om.class.fields %in : !om.string +} + +om.class @Top() -> (test: i1) { + %0 = om.constant "A" : !om.string + %1 = om.object @Domain(%0) : (!om.string) -> !om.class.type<@Domain> + %2 = om.object.field %1, [@out] : (!om.class.type<@Domain>) -> !om.string + %3 = om.object @Domain(%2) : (!om.string) -> !om.class.type<@Domain> + %4 = om.object.field %3, [@out] : (!om.class.type<@Domain>) -> !om.string + %5 = om.constant "B" : !om.string + %6 = om.prop.eq %4, %5 : !om.string + om.class.fields %6 : i1 +} +)MLIR"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = evaluator.instantiate(StringAttr::get(&context, "Top"), {}); + + ASSERT_TRUE(succeeded(result)); + + // Verify the result is correct (false since "A" != "B") + auto fieldValue = llvm::cast(result.value().get()) + ->getField("test") + .value(); + auto boolValue = llvm::cast(fieldValue.get()) + ->getAs(); + ASSERT_FALSE(boolValue.getValue()); +} + TEST(EvaluatorTests, IntegerBinaryArithmeticAdd) { StringRef mod = R"MLIR( om.class @IntegerBinaryArithmeticAdd() -> (result: !om.integer) {