Skip to content
Merged
46 changes: 35 additions & 11 deletions lib/NeuraDialect/Transforms/GenerateCodePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,21 +523,45 @@ struct GenerateCodePass
// Checks if operation has constant_value attribute (for non-CONSTANT operations).
inst.src_operands.emplace_back(getConstantLiteral(op), "RED");
} else {
// Handles normal operands, including operations with rhs_value attribute.
// Handles normal operands and folded constants (lhs_value/rhs_value).
SmallVector<Value> operands; operands.reserve(op->getNumOperands());

// Processes actual Value operands (if any).
for (Value v : op->getOperands()) {

auto appendLiteralSlot = [&](Attribute attr) -> bool {
std::string literal = extractConstantLiteralFromAttr(attr);
if (literal.empty()) return false;
inst.src_operands.emplace_back(literal, "RED");
// Keeps index alignment with operation_to_operands for rewiring.
operands.push_back(Value());
return true;
};
auto appendValueSlot = [&](Value v) {
operands.push_back(v);
inst.src_operands.emplace_back("UNRESOLVED", "RED");
}

// Handles cases where binary operations have the RHS constant stored as an attribute.
if (auto rhs_value_attr = op->getAttr(attr::kRhsValue)) {
std::string rhs_literal = extractConstantLiteralFromAttr(rhs_value_attr);
if (!rhs_literal.empty()) {
inst.src_operands.emplace_back(rhs_literal, "RED");
};

// StoreIndexed has operand order: value(lhs) -> base(rhs) -> indices.
// rhs_value must be inserted before indices (not appended at tail).
if (auto store_indexed_op = dyn_cast<StoreIndexedOp>(op)) {
bool lhs_folded = appendLiteralSlot(op->getAttr(attr::kLhsValue));
if (!lhs_folded) appendValueSlot(store_indexed_op.getValue());

bool rhs_folded = appendLiteralSlot(op->getAttr(attr::kRhsValue));
if (!rhs_folded) {
Value base = store_indexed_op.getBase();
if (base) appendValueSlot(base);
}

for (Value index : store_indexed_op.getIndices()) {
appendValueSlot(index);
}
} else {
// Generic handling:
// - lhs_value is the leading source slot.
// - remaining Value operands keep original order.
// - rhs_value is the trailing source slot.
appendLiteralSlot(op->getAttr(attr::kLhsValue));
for (Value v : op->getOperands()) appendValueSlot(v);
appendLiteralSlot(op->getAttr(attr::kRhsValue));
}

operation_to_operands[op] = std::move(operands);
Expand Down
Loading