Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions include/circt/Dialect/OM/Evaluator/Evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ class EvaluatorValue : public std::enable_shared_from_this<EvaluatorValue> {
void markFullyEvaluated() {
assert(!fullyEvaluated && "should not mark twice");
fullyEvaluated = true;
// Increment the counter if one is set.
if (fullyEvaluatedCounter)
++(*fullyEvaluatedCounter);
}

/// Set a counter to increment when this value becomes fully evaluated.
void setFullyEvaluatedCounter(uint64_t *counter) {
fullyEvaluatedCounter = counter;
}

/// Return true if the value is unknown (has unknown in its fan-in).
Expand Down Expand Up @@ -104,6 +112,7 @@ class EvaluatorValue : public std::enable_shared_from_this<EvaluatorValue> {
bool fullyEvaluated = false;
bool finalized = false;
bool unknown = false;
uint64_t *fullyEvaluatedCounter = nullptr;
};

/// Values which can be used as pointers to different values.
Expand Down Expand Up @@ -393,6 +402,9 @@ class Evaluator {

using ObjectKey = std::pair<Value, ActualParameters>;

/// Get the number of fully evaluated nodes tracked by this evaluator.
uint64_t getFullyEvaluatedCount() const { return fullyEvaluatedCount; }

private:
bool isFullyEvaluated(Value value, ActualParameters key) {
return isFullyEvaluated({value, key});
Expand All @@ -403,6 +415,12 @@ class Evaluator {
return val && val->isFullyEvaluated();
}

/// Attach the evaluation counter to a newly created value.
void attachCounter(evaluator::EvaluatorValuePtr &value) {
if (value && !value->isFullyEvaluated())
value->setFullyEvaluatedCounter(&fullyEvaluatedCount);
}

FailureOr<EvaluatorValuePtr>
getOrCreateValue(Value value, ActualParameters actualParams, Location loc);
FailureOr<EvaluatorValuePtr>
Expand Down Expand Up @@ -478,8 +496,11 @@ class Evaluator {
std::unique_ptr<SmallVector<std::shared_ptr<evaluator::EvaluatorValue>>>>
actualParametersBuffers;

/// A worklist that tracks values which needs to be fully evaluated.
std::queue<ObjectKey> worklist;
/// Worklists that track values which need to be fully evaluated.
/// We use two worklists to detect cycles: process all items from one,
/// and if any become fully evaluated, swap and continue.
std::vector<ObjectKey> worklist;
std::vector<ObjectKey> nextWorklist;

/// A queue of pending property assertions to be evaluated after the worklist
/// is fully drained. Each entry is a (PropertyAssertOp, ActualParameters)
Expand All @@ -492,6 +513,9 @@ class Evaluator {
/// instantiation context (a pair of Value and parameters).
DenseMap<ObjectKey, std::shared_ptr<evaluator::EvaluatorValue>> objects;

/// Counter for fully evaluated nodes.
uint64_t fullyEvaluatedCount = 0;

#ifndef NDEBUG
/// Current nesting depth for debug output indentation.
unsigned debugNesting = 0;
Expand Down
120 changes: 82 additions & 38 deletions lib/Dialect/OM/Evaluator/Evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,32 +70,38 @@ FailureOr<evaluator::EvaluatorValuePtr>
circt::om::Evaluator::getPartiallyEvaluatedValue(Type type, Location loc) {
using namespace circt::om::evaluator;

return TypeSwitch<mlir::Type, FailureOr<evaluator::EvaluatorValuePtr>>(type)
.Case([&](circt::om::ListType type) {
evaluator::EvaluatorValuePtr result =
std::make_shared<evaluator::ListValue>(type, loc);
return success(result);
})
.Case([&](circt::om::ClassType type)
-> FailureOr<evaluator::EvaluatorValuePtr> {
auto classDef =
symbolTable.lookup<ClassLike>(type.getClassName().getValue());
if (!classDef)
return symbolTable.getOp()->emitError("unknown class name ")
<< type.getClassName();

// Create an ObjectValue for both ClassOp and ClassExternOp
evaluator::EvaluatorValuePtr result =
std::make_shared<evaluator::ObjectValue>(classDef, loc);

return success(result);
})
.Case([&](circt::om::StringType type) {
evaluator::EvaluatorValuePtr result =
evaluator::AttributeValue::get(type, loc);
return success(result);
})
.Default([&](auto type) { return failure(); });
auto result =
TypeSwitch<mlir::Type, FailureOr<evaluator::EvaluatorValuePtr>>(type)
.Case([&](circt::om::ListType type) {
evaluator::EvaluatorValuePtr result =
std::make_shared<evaluator::ListValue>(type, loc);
return success(result);
})
.Case([&](circt::om::ClassType type)
-> FailureOr<evaluator::EvaluatorValuePtr> {
auto classDef =
symbolTable.lookup<ClassLike>(type.getClassName().getValue());
if (!classDef)
return symbolTable.getOp()->emitError("unknown class name ")
<< type.getClassName();

// Create an ObjectValue for both ClassOp and ClassExternOp
evaluator::EvaluatorValuePtr result =
std::make_shared<evaluator::ObjectValue>(classDef, loc);

return success(result);
})
.Case([&](circt::om::StringType type) {
evaluator::EvaluatorValuePtr result =
evaluator::AttributeValue::get(type, loc);
return success(result);
})
.Default([&](auto type) { return failure(); });

if (succeeded(result))
attachCounter(result.value());
Comment on lines +101 to +102
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the checking in attachCounter, can this blindly call attachCounter(result.value())? I guess not given that the return type is FailureOr<...>? result.value() must then auto-unpack the FailureOr for you?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think result.value() causes assertion failure for FailureOr::failuare(= std::nullopt), so we need to add a guard for this.


return result;
}

FailureOr<evaluator::EvaluatorValuePtr> circt::om::Evaluator::getOrCreateValue(
Expand Down Expand Up @@ -186,6 +192,8 @@ FailureOr<evaluator::EvaluatorValuePtr> circt::om::Evaluator::getOrCreateValue(
if (failed(result))
return result;

// Attach listener to newly created values
attachCounter(result.value());
objects[{value, actualParams}] = result.value();
return result;
}
Expand All @@ -212,6 +220,7 @@ circt::om::Evaluator::evaluateObjectInstance(StringAttr className,
if (isa<ClassExternOp>(classDef)) {
evaluator::EvaluatorValuePtr result =
std::make_shared<evaluator::ObjectValue>(classDef, loc);
attachCounter(result);
result->markUnknown();
LLVM_DEBUG(dbgs(1) << "extern: <unknown-value>\n");
return result;
Expand Down Expand Up @@ -283,7 +292,7 @@ circt::om::Evaluator::evaluateObjectInstance(StringAttr className,
UnknownLoc::get(context))))
return failure();
// Add to the worklist.
worklist.push({result, actualParams});
worklist.push_back({result, actualParams});
}
}

Expand Down Expand Up @@ -329,6 +338,9 @@ circt::om::Evaluator::evaluateObjectInstance(StringAttr className,
// If it's external call, just allocate new ObjectValue.
evaluator::EvaluatorValuePtr result =
std::make_shared<evaluator::ObjectValue>(cls, fields, loc);
// Object is already fully evaluated when created with fields.
assert(result->isFullyEvaluated() &&
"object with fields should be fully evaluated");
return result;
}

Expand All @@ -355,6 +367,7 @@ circt::om::Evaluator::instantiate(
evaluator::EvaluatorValuePtr result =
std::make_shared<evaluator::ObjectValue>(
classDef, UnknownLoc::get(classDef.getContext()));
attachCounter(result);
result->markUnknown();
LLVM_DEBUG(dbgs(1) << "result: <unknown extern>\n");
return result;
Expand All @@ -380,18 +393,46 @@ circt::om::Evaluator::instantiate(
// `evaluateObjectInstance` has populated the worklist. Continue evaluations
// unless there is a partially evaluated value.
LLVM_DEBUG(dbgs() << "worklist:\n");
while (!worklist.empty()) {
auto [value, args] = worklist.front();
worklist.pop();

auto result = evaluateValue(value, args, loc);

if (failed(result))
return failure();
// Use two-worklist approach: process all items from current worklist, and if
// at least one becomes fully evaluated, swap and continue. If a full pass
// completes with no progress, we have a cycle.
while (!worklist.empty()) {
uint64_t countBeforePass = fullyEvaluatedCount;
LLVM_DEBUG(dbgs() << "- processing " << worklist.size()
<< " items (fully evaluated count: "
<< fullyEvaluatedCount << ")\n");

// Process all items in the current worklist.
while (!worklist.empty()) {
auto [value, args] = worklist.back();
worklist.pop_back();
auto result = evaluateValue(value, args, loc);

if (failed(result))
return failure();

// If not fully evaluated, add to next worklist for retry.
if (!result.value()->isFullyEvaluated())
nextWorklist.push_back({value, args});
}

// It's possible that the value is not fully evaluated.
if (!result.value()->isFullyEvaluated())
worklist.push({value, args});
// Check if we made progress.
uint64_t evaluatedThisPass = fullyEvaluatedCount - countBeforePass;
LLVM_DEBUG(dbgs() << "- evaluated " << evaluatedThisPass
<< " nodes this pass\n");

// If nothing became fully evaluated in this pass, we have a cycle.
if (evaluatedThisPass == 0 && !nextWorklist.empty())
return cls.emitError()
<< "cycle detected: " << nextWorklist.size()
<< " values remain partially evaluated after full pass with no "
"progress (total fully evaluated: "
<< fullyEvaluatedCount << ")";

// Swap worklists for next iteration.
worklist = std::move(nextWorklist);
nextWorklist.clear();
}

// Now that all values are fully resolved, evaluate the deferred property
Expand Down Expand Up @@ -732,6 +773,9 @@ circt::om::Evaluator::evaluateObjectField(ObjectFieldOp op,
currentObject = nextObject;
}

if (!finalField->isFullyEvaluated())
return objectFieldValue;

// Update the reference.
llvm::cast<evaluator::ReferenceValue>(objectFieldValue.get())
->setValue(finalField);
Expand Down Expand Up @@ -1044,7 +1088,7 @@ circt::om::Evaluator::createUnknownValue(Type type, Location loc) {
return success(AttributeValue::get(type, LocationAttr(loc)));
});

// Mark the result as unknown if successful
// Mark the result as unknown if successful.
if (succeeded(result))
result->get()->markUnknown();

Expand Down
49 changes: 47 additions & 2 deletions unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,9 @@ om.class @ReferenceEachOther() -> (field: !ty){
context.getOrLoadDialect<OMDialect>();

context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
ASSERT_EQ(diag.str(), "failed to finalize evaluation. Probably the class "
"contains a dataflow cycle");
ASSERT_EQ(diag.str(),
"cycle detected: 1 values remain partially evaluated after full "
"pass with no progress (total fully evaluated: 1)");
});

OwningOpRef<ModuleOp> owning =
Expand All @@ -669,6 +670,50 @@ om.class @ReferenceEachOther() -> (field: !ty){
ASSERT_TRUE(failed(result));
}

// Test nested object field references.
// https://github.com/llvm/circt/issues/10264
TEST(EvaluatorTests, Issue10264NestedFieldReferences) {
StringRef mod = R"MLIR(
om.class @Domain(%in: !om.string) -> (out: !om.string) {
om.class.fields %in : !om.string
}

om.class @Top() -> (test: i1) {
%0 = om.constant "A" : !om.string
%1 = om.object @Domain(%0) : (!om.string) -> !om.class.type<@Domain>
%2 = om.object.field %1, [@out] : (!om.class.type<@Domain>) -> !om.string
%3 = om.object @Domain(%2) : (!om.string) -> !om.class.type<@Domain>
%4 = om.object.field %3, [@out] : (!om.class.type<@Domain>) -> !om.string
%5 = om.constant "B" : !om.string
%6 = om.prop.eq %4, %5 : !om.string
om.class.fields %6 : i1
}
)MLIR";

DialectRegistry registry;
registry.insert<OMDialect>();

MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();

OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));

Evaluator evaluator(owning.release());

auto result = evaluator.instantiate(StringAttr::get(&context, "Top"), {});

ASSERT_TRUE(succeeded(result));

// Verify the result is correct (false since "A" != "B")
auto fieldValue = llvm::cast<evaluator::ObjectValue>(result.value().get())
->getField("test")
.value();
auto boolValue = llvm::cast<evaluator::AttributeValue>(fieldValue.get())
->getAs<BoolAttr>();
ASSERT_FALSE(boolValue.getValue());
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very happy to see this working. Nice!


TEST(EvaluatorTests, IntegerBinaryArithmeticAdd) {
StringRef mod = R"MLIR(
om.class @IntegerBinaryArithmeticAdd() -> (result: !om.integer) {
Expand Down
Loading