Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 79 additions & 37 deletions lib/NeuraDialect/Transforms/GenerateCodePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,27 @@ static std::string getOpcode(Operation *op) {
return opcode;
}

// Literals for CONSTANT's sources, e.g. "#10" / "#0" / "#3.0".
// Literals for CONSTANT operations, e.g. "#10" / "#0" / "#3.0".
static std::string getConstantLiteral(Operation *op) {
if (!isConstant(op)) return "";
if (auto value_attr = op->getAttr("value")) {
if (auto integer_attr = dyn_cast<IntegerAttr>(value_attr))
if (isConstant(op)) {
if (auto value_attr = op->getAttr("value")) {
if (auto integer_attr = dyn_cast<IntegerAttr>(value_attr))
return "#" + std::to_string(integer_attr.getInt());
if (auto float_attr = dyn_cast<FloatAttr>(value_attr))
return "#" + std::to_string(float_attr.getValueAsDouble());
}
return "#0";
}

// Check for constant_value attribute in non-CONSTANT operations
if (auto constant_value_attr = op->getAttr("constant_value")) {
if (auto integer_attr = dyn_cast<IntegerAttr>(constant_value_attr))
return "#" + std::to_string(integer_attr.getInt());
if (auto float_attr = dyn_cast<FloatAttr>(value_attr))
if (auto float_attr = dyn_cast<FloatAttr>(constant_value_attr))
return "#" + std::to_string(float_attr.getValueAsDouble());
}
return "#0";

return "";
}

// ----- Topology from Architecture -----.
Expand Down Expand Up @@ -295,8 +306,8 @@ struct GenerateCodePass
// - collect DATA_MOV and CTRL_MOV ops.
// - collect reserve_to_phi_maps (PHI's operand#0 is the reserve).
void indexIR(func::FuncOp function,
SmallVector<Operation*> &dataMovs,
SmallVector<Operation*> &ctrlMovs,
SmallVector<Operation*> &data_movs,
SmallVector<Operation*> &ctrl_movs,
DenseMap<Value, Operation*> &reserve_to_phi_map) {
function.walk([&](Operation *op) {
// placement for every op (even for mov/reserve).
Expand All @@ -308,8 +319,8 @@ struct GenerateCodePass
}

// collect forwarders.
if (isDataMov(op)) { dataMovs.push_back(op); return; }
if (isCtrlMov(op)) { ctrlMovs.push_back(op); return; }
if (isDataMov(op)) { data_movs.push_back(op); return; }
if (isCtrlMov(op)) { ctrl_movs.push_back(op); return; }

// skip Reserve from materialization.
if (isReserve(op)) return;
Expand All @@ -324,7 +335,11 @@ struct GenerateCodePass

if (isConstant(op)) {
inst.src_operands.emplace_back(getConstantLiteral(op), "RED");
} else if (op->getAttr("constant_value")) {
// Check if operation has constant_value attribute (for non-CONSTANT operations)
inst.src_operands.emplace_back(getConstantLiteral(op), "RED");
} else {
// Handle normal operands
SmallVector<Value> operands; operands.reserve(op->getNumOperands());
for (Value v : op->getOperands()) {
operands.push_back(v);
Expand Down Expand Up @@ -571,21 +586,33 @@ struct GenerateCodePass
for (const Tile &core : config.cores) {
yaml_out << " - column: " << core.col_idx << "\n row: " << core.row_idx
<< "\n core_id: \"" << core.core_id << "\"\n entries:\n";
int entry_id = 0;

// Group instructions by timestep
std::map<int, std::vector<const Instruction*>> timestep_groups;
for (const Instruction &inst : core.entry.instructions) {
yaml_out << " - entry_id: \"entry" << entry_id++ << "\"\n instructions:\n"
<< " - opcode: \"" << inst.opcode << "\"\n timestep: " << inst.time_step << "\n";
// sources.
if (!inst.src_operands.empty()) {
yaml_out << " src_operands:\n";
for (const Operand &opnd : inst.src_operands)
yaml_out << " - operand: \"" << opnd.operand << "\"\n color: \"" << opnd.color << "\"\n";
}
// destinations.
if (!inst.dst_operands.empty()) {
yaml_out << " dst_operands:\n";
for (const Operand &opnd : inst.dst_operands)
yaml_out << " - operand: \"" << opnd.operand << "\"\n color: \"" << opnd.color << "\"\n";
timestep_groups[inst.time_step].push_back(&inst);
}

yaml_out << " - entry_id: \"entry0\"\n instructions:\n";
for (const auto &timestep_pair : timestep_groups) {
int timestep = timestep_pair.first;
const auto &operations = timestep_pair.second;

yaml_out << " - timestep: " << timestep << "\n operations:\n";
for (const Instruction *inst : operations) {
yaml_out << " - opcode: \"" << inst->opcode << "\"\n";
// sources.
if (!inst->src_operands.empty()) {
yaml_out << " src_operands:\n";
for (const Operand &opnd : inst->src_operands)
yaml_out << " - operand: \"" << opnd.operand << "\"\n color: \"" << opnd.color << "\"\n";
}
// destinations.
if (!inst->dst_operands.empty()) {
yaml_out << " dst_operands:\n";
for (const Operand &opnd : inst->dst_operands)
yaml_out << " - operand: \"" << opnd.operand << "\"\n color: \"" << opnd.color << "\"\n";
}
}
}
}
Expand Down Expand Up @@ -614,17 +641,32 @@ struct GenerateCodePass

for (const Tile &core : config.cores) {
asm_out << "PE(" << core.col_idx << "," << core.row_idx << "):\n";

// Group instructions by timestep
std::map<int, std::vector<const Instruction*>> timestep_groups;
for (const Instruction &inst : core.entry.instructions) {
asm_out << "{\n " << inst.opcode;
for (const Operand &operand : inst.src_operands) asm_out << ", " << formatOperand(operand);
if (!inst.dst_operands.empty()) {
asm_out << " -> ";
for (size_t i = 0; i < inst.dst_operands.size(); ++i) {
if (i > 0) asm_out << ", ";
asm_out << formatOperand(inst.dst_operands[i]);
timestep_groups[inst.time_step].push_back(&inst);
}

for (const auto &timestep_pair : timestep_groups) {
int timestep = timestep_pair.first;
const auto &instructions = timestep_pair.second;

asm_out << "{\n";
for (size_t i = 0; i < instructions.size(); ++i) {
const Instruction *inst = instructions[i];
asm_out << " " << inst->opcode;
for (const Operand &operand : inst->src_operands) asm_out << ", " << formatOperand(operand);
if (!inst->dst_operands.empty()) {
asm_out << " -> ";
for (size_t j = 0; j < inst->dst_operands.size(); ++j) {
if (j > 0) asm_out << ", ";
asm_out << formatOperand(inst->dst_operands[j]);
}
}
asm_out << "\n";
}
asm_out << " (t=" << inst.time_step << ")\n}\n";
asm_out << "} (t=" << timestep << ")\n";
}
asm_out << "\n";
}
Expand Down Expand Up @@ -693,15 +735,15 @@ struct GenerateCodePass
clearState();

// Single function-level walks: index + materialize + collect.
SmallVector<Operation*> dataMovs;
SmallVector<Operation*> ctrlMovs;
SmallVector<Operation*> data_movs;
SmallVector<Operation*> ctrl_movs;
DenseMap<Value, Operation*> reserve_to_phi_map;
indexIR(func, dataMovs, ctrlMovs, reserve_to_phi_map);
indexIR(func, data_movs, ctrl_movs, reserve_to_phi_map);

// Expand forwarders without re-walking IR.
for (Operation *op : dataMovs)
for (Operation *op : data_movs)
expandMovImpl<false>(op, topo, /*unused*/reserve_to_phi_map);
for (Operation *op : ctrlMovs)
for (Operation *op : ctrl_movs)
expandMovImpl<true>(op, topo, reserve_to_phi_map);

// Debug unresolveds, then dump outputs.
Expand Down
Loading