Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 88 additions & 48 deletions lib/NeuraDialect/Transforms/GenerateCodePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,27 @@ static std::string getOpcode(Operation *op) {
return opcode;
}

// Literals for CONSTANT's sources, e.g. "#10" / "#0" / "#3.0".
// Literals for CONSTANT operations, e.g. "#10" / "#0" / "#3.0".
static std::string getConstantLiteral(Operation *op) {
if (!isConstant(op)) return "";
if (auto value_attr = op->getAttr("value")) {
if (auto integer_attr = dyn_cast<IntegerAttr>(value_attr))
if (isConstant(op)) {
if (auto value_attr = op->getAttr("value")) {
if (auto integer_attr = dyn_cast<IntegerAttr>(value_attr))
return "#" + std::to_string(integer_attr.getInt());
if (auto float_attr = dyn_cast<FloatAttr>(value_attr))
return "#" + std::to_string(float_attr.getValueAsDouble());
}
return "#0";
}

// Checks for constant_value attribute in non-CONSTANT operations.
if (auto constant_value_attr = op->getAttr("constant_value")) {
if (auto integer_attr = dyn_cast<IntegerAttr>(constant_value_attr))
return "#" + std::to_string(integer_attr.getInt());
if (auto float_attr = dyn_cast<FloatAttr>(value_attr))
if (auto float_attr = dyn_cast<FloatAttr>(constant_value_attr))
return "#" + std::to_string(float_attr.getValueAsDouble());
}
return "#0";

return "";
}

// ----- Topology from Architecture -----.
Expand Down Expand Up @@ -295,8 +306,8 @@ struct GenerateCodePass
// - collect DATA_MOV and CTRL_MOV ops.
// - collect reserve_to_phi_maps (PHI's operand#0 is the reserve).
void indexIR(func::FuncOp function,
SmallVector<Operation*> &dataMovs,
SmallVector<Operation*> &ctrlMovs,
SmallVector<Operation*> &data_movs,
SmallVector<Operation*> &ctrl_movs,
DenseMap<Value, Operation*> &reserve_to_phi_map) {
function.walk([&](Operation *op) {
// placement for every op (even for mov/reserve).
Expand All @@ -308,8 +319,8 @@ struct GenerateCodePass
}

// collect forwarders.
if (isDataMov(op)) { dataMovs.push_back(op); return; }
if (isCtrlMov(op)) { ctrlMovs.push_back(op); return; }
if (isDataMov(op)) { data_movs.push_back(op); return; }
if (isCtrlMov(op)) { ctrl_movs.push_back(op); return; }

// skip Reserve from materialization.
if (isReserve(op)) return;
Expand All @@ -324,7 +335,11 @@ struct GenerateCodePass

if (isConstant(op)) {
inst.src_operands.emplace_back(getConstantLiteral(op), "RED");
} else if (op->getAttr("constant_value")) {
// Checks if operation has constant_value attribute (for non-CONSTANT operations).
inst.src_operands.emplace_back(getConstantLiteral(op), "RED");
} else {
// Handles normal operands.
SmallVector<Value> operands; operands.reserve(op->getNumOperands());
for (Value v : op->getOperands()) {
operands.push_back(v);
Expand All @@ -348,7 +363,7 @@ struct GenerateCodePass
static SmallVector<LinkStep, 8> getLinkChain(Operation *forwarder) { return collectLinkSteps(forwarder); }
static SmallVector<RegStep, 4> getRegisterSteps(Operation *forwarder) { return collectRegSteps(forwarder); }

// Validate forwarder op arities: DATA_MOV: at least 1 in/1 out; CTRL_MOV: at least 2 inputs (src,reserve).
// Validates forwarder op arities: DATA_MOV: at least 1 in/1 out; CTRL_MOV: at least 2 inputs (src,reserve).
template<bool IsCtrl>
bool validateForwarderShape(Operation *forwarder) {
if constexpr (!IsCtrl) {
Expand All @@ -358,7 +373,7 @@ struct GenerateCodePass
}
}

// Compute producer first-hop directions and consumer last-hop directions (or LOCAL if link-less).
// Computes producer first-hop directions and consumer last-hop directions (or LOCAL if link-less).
std::pair<StringRef, StringRef> computeDirections(const SmallVector<LinkStep, 8> &links, const Topology &topo) {
StringRef producer_direction("LOCAL");
StringRef consumer_direction("LOCAL");
Expand All @@ -369,7 +384,7 @@ struct GenerateCodePass
return {producer_direction, consumer_direction};
}

// Add producer endpoints (first-hop directions or local $reg when using same-tile register paths).
// Adds producer endpoints (first-hop directions or local $reg when using same-tile register paths).
void setProducerDestination(Operation *producer, StringRef producer_direction, const SmallVector<RegStep, 4> &regs) {
if (auto *pi = getInstructionPointer(producer)) {
if (!producer_direction.empty() && producer_direction != "LOCAL") {
Expand All @@ -380,7 +395,7 @@ struct GenerateCodePass
}
}

// Emit router hops for multi-hop paths (from the second hop onwards). CTRL_MOV emits CTRL_MOV hops.
// Emits router hops for multi-hop paths (from the second hop onwards). CTRL_MOV emits CTRL_MOV hops.
template<bool IsCtrl>
void generateIntermediateHops(const SmallVector<LinkStep, 8> &links, const Topology &topo) {
for (size_t i = 1; i < links.size(); ++i) {
Expand Down Expand Up @@ -487,7 +502,7 @@ struct GenerateCodePass
consumers = collectDataMovConsumers(forwarder);
}

// Wire each consumer: prefer register rewiring; fallback to direction rewiring.
// Wires each consumer: prefer register rewiring; fallback to direction rewiring.
for (auto &[consOp, atVal] : consumers) {
if (!handleRegisterRewiring<IsCtrl>(consOp, atVal, regs, links, topo))
handleDirectionRewiring<IsCtrl>(consOp, atVal, consumer_direction, links, forwarder);
Expand Down Expand Up @@ -542,7 +557,7 @@ struct GenerateCodePass
ArrayConfig config{columns, rows, {}};
std::map<std::pair<int,int>, std::vector<Instruction>> tile_insts;

// Flatten and sort by timesteps.
// Flattens and sorts by timesteps.
for (auto &[tile_key, timestep_map] : tile_time_instructions) {
auto &flat = tile_insts[tile_key];
for (auto &[timestep, instruction_vec] : timestep_map) for (Instruction &inst : instruction_vec) flat.push_back(inst);
Expand Down Expand Up @@ -571,21 +586,33 @@ struct GenerateCodePass
for (const Tile &core : config.cores) {
yaml_out << " - column: " << core.col_idx << "\n row: " << core.row_idx
<< "\n core_id: \"" << core.core_id << "\"\n entries:\n";
int entry_id = 0;

// Groups instructions by timestep.
std::map<int, std::vector<const Instruction*>> timestep_groups;
for (const Instruction &inst : core.entry.instructions) {
yaml_out << " - entry_id: \"entry" << entry_id++ << "\"\n instructions:\n"
<< " - opcode: \"" << inst.opcode << "\"\n timestep: " << inst.time_step << "\n";
// sources.
if (!inst.src_operands.empty()) {
yaml_out << " src_operands:\n";
for (const Operand &opnd : inst.src_operands)
yaml_out << " - operand: \"" << opnd.operand << "\"\n color: \"" << opnd.color << "\"\n";
}
// destinations.
if (!inst.dst_operands.empty()) {
yaml_out << " dst_operands:\n";
for (const Operand &opnd : inst.dst_operands)
yaml_out << " - operand: \"" << opnd.operand << "\"\n color: \"" << opnd.color << "\"\n";
timestep_groups[inst.time_step].push_back(&inst);
}

yaml_out << " - entry_id: \"entry0\"\n instructions:\n";
for (const auto &timestep_pair : timestep_groups) {
int timestep = timestep_pair.first;
const auto &operations = timestep_pair.second;

yaml_out << " - timestep: " << timestep << "\n operations:\n";
for (const Instruction *inst : operations) {
yaml_out << " - opcode: \"" << inst->opcode << "\"\n";
// sources.
if (!inst->src_operands.empty()) {
yaml_out << " src_operands:\n";
for (const Operand &opnd : inst->src_operands)
yaml_out << " - operand: \"" << opnd.operand << "\"\n color: \"" << opnd.color << "\"\n";
}
// destinations.
if (!inst->dst_operands.empty()) {
yaml_out << " dst_operands:\n";
for (const Operand &opnd : inst->dst_operands)
yaml_out << " - operand: \"" << opnd.operand << "\"\n color: \"" << opnd.color << "\"\n";
}
}
}
}
Expand Down Expand Up @@ -614,17 +641,32 @@ struct GenerateCodePass

for (const Tile &core : config.cores) {
asm_out << "PE(" << core.col_idx << "," << core.row_idx << "):\n";

// Groups instructions by timestep.
std::map<int, std::vector<const Instruction*>> timestep_groups;
for (const Instruction &inst : core.entry.instructions) {
asm_out << "{\n " << inst.opcode;
for (const Operand &operand : inst.src_operands) asm_out << ", " << formatOperand(operand);
if (!inst.dst_operands.empty()) {
asm_out << " -> ";
for (size_t i = 0; i < inst.dst_operands.size(); ++i) {
if (i > 0) asm_out << ", ";
asm_out << formatOperand(inst.dst_operands[i]);
timestep_groups[inst.time_step].push_back(&inst);
}

for (const auto &timestep_pair : timestep_groups) {
int timestep = timestep_pair.first;
const auto &instructions = timestep_pair.second;

asm_out << "{\n";
for (size_t i = 0; i < instructions.size(); ++i) {
const Instruction *inst = instructions[i];
asm_out << " " << inst->opcode;
for (const Operand &operand : inst->src_operands) asm_out << ", " << formatOperand(operand);
if (!inst->dst_operands.empty()) {
asm_out << " -> ";
for (size_t j = 0; j < inst->dst_operands.size(); ++j) {
if (j > 0) asm_out << ", ";
asm_out << formatOperand(inst->dst_operands[j]);
}
}
asm_out << "\n";
}
asm_out << " (t=" << inst.time_step << ")\n}\n";
asm_out << "} (t=" << timestep << ")\n";
}
asm_out << "\n";
}
Expand Down Expand Up @@ -658,8 +700,8 @@ struct GenerateCodePass
return &vec[idx];
}

// Replace the exact source slots in consumers that correspond to `value_at_consumer`,
// or fill the first UNRESOLVED placeholder if a 1:1 match wasn't found.
// Replaces the exact source slots in consumers that correspond to `value_at_consumer`,
// or fills the first UNRESOLVED placeholder if a 1:1 match wasn't found.
void setConsumerSourceExact(Operation *consumer, Value value_at_consumer, const std::string &text) {
Instruction *ci = getInstructionPointer(consumer);
if (!ci) return;
Expand Down Expand Up @@ -693,18 +735,16 @@ struct GenerateCodePass
clearState();

// Single function-level walks: index + materialize + collect.
SmallVector<Operation*> dataMovs;
SmallVector<Operation*> ctrlMovs;
SmallVector<Operation*> data_movs;
SmallVector<Operation*> ctrl_movs;
DenseMap<Value, Operation*> reserve_to_phi_map;
indexIR(func, dataMovs, ctrlMovs, reserve_to_phi_map);
indexIR(func, data_movs, ctrl_movs, reserve_to_phi_map);

// Expand forwarders without re-walking IR.
for (Operation *op : dataMovs)
// Expands forwarders without re-walking IR.
for (Operation *op : data_movs)
expandMovImpl<false>(op, topo, /*unused*/reserve_to_phi_map);
for (Operation *op : ctrlMovs)
for (Operation *op : ctrl_movs)
expandMovImpl<true>(op, topo, reserve_to_phi_map);

// Debug unresolveds, then dump outputs.
logUnresolvedOperands();

ArrayConfig config = buildArrayConfig(columns, rows);
Expand Down
Loading