Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 85 additions & 41 deletions lib/NeuraDialect/Transforms/GenerateCodePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ static bool isCtrlMov(Operation *op) { return dyn_cast<CtrlMovOp>(op) != nullptr
static bool isPhi(Operation *op) { return dyn_cast<PhiOp>(op) != nullptr; }
static bool isReserve(Operation *op) { return dyn_cast<ReserveOp>(op) != nullptr; }
static bool isConstant(Operation *op) { return dyn_cast<ConstantOp>(op) != nullptr; }
static bool isGrantOnce(Operation *op) { return dyn_cast<GrantOnceOp>(op) != nullptr; }

// ----- placement helpers -----.
static TileLocation getTileLocation(Operation *op) {
Expand Down Expand Up @@ -120,20 +121,36 @@ static std::string getOpcode(Operation *op) {
std::string opcode = op->getName().getStringRef().str();
if (opcode.rfind("neura.", 0) == 0) opcode = opcode.substr(6);
if (isConstant(op)) return "CONSTANT";
if (isGrantOnce(op)) return "GRANT_ONCE";
std::transform(opcode.begin(), opcode.end(), opcode.begin(), ::toupper);
return opcode;
}

// Literals for CONSTANT's sources, e.g. "#10" / "#0" / "#3.0".
// Literals for CONSTANT and GRANT_ONCE's constant values, e.g. "#10" / "#0" / "#3.0".
static std::string getConstantLiteral(Operation *op) {
if (!isConstant(op)) return "";
if (auto value_attr = op->getAttr("value")) {
if (auto integer_attr = dyn_cast<IntegerAttr>(value_attr))
return "#" + std::to_string(integer_attr.getInt());
if (auto float_attr = dyn_cast<FloatAttr>(value_attr))
return "#" + std::to_string(float_attr.getValueAsDouble());
}
return "#0";
if (isConstant(op)) {
if (auto value_attr = op->getAttr("value")) {
if (auto integer_attr = dyn_cast<IntegerAttr>(value_attr))
return "#" + std::to_string(integer_attr.getInt());
if (auto float_attr = dyn_cast<FloatAttr>(value_attr))
return "#" + std::to_string(float_attr.getValueAsDouble());
}
return "#0";
}

if (isGrantOnce(op)) {
if (auto grant_once_op = dyn_cast<GrantOnceOp>(op)) {
if (auto constant_value = grant_once_op.getConstantValue()) {
if (auto integer_attr = dyn_cast<IntegerAttr>(*constant_value))
return "#" + std::to_string(integer_attr.getInt());
if (auto float_attr = dyn_cast<FloatAttr>(*constant_value))
return "#" + std::to_string(float_attr.getValueAsDouble());
}
}
return "#0";
}

return "";
}

// ----- Topology from Architecture -----.
Expand Down Expand Up @@ -295,8 +312,8 @@ struct GenerateCodePass
// - collect DATA_MOV and CTRL_MOV ops.
// - collect reserve_to_phi_maps (PHI's operand#0 is the reserve).
void indexIR(func::FuncOp function,
SmallVector<Operation*> &dataMovs,
SmallVector<Operation*> &ctrlMovs,
SmallVector<Operation*> &data_movs,
SmallVector<Operation*> &ctrl_movs,
DenseMap<Value, Operation*> &reserve_to_phi_map) {
function.walk([&](Operation *op) {
// placement for every op (even for mov/reserve).
Expand All @@ -308,8 +325,8 @@ struct GenerateCodePass
}

// collect forwarders.
if (isDataMov(op)) { dataMovs.push_back(op); return; }
if (isCtrlMov(op)) { ctrlMovs.push_back(op); return; }
if (isDataMov(op)) { data_movs.push_back(op); return; }
if (isCtrlMov(op)) { ctrl_movs.push_back(op); return; }

// skip Reserve from materialization.
if (isReserve(op)) return;
Expand All @@ -322,7 +339,7 @@ struct GenerateCodePass
Instruction inst(opcode);
inst.time_step = placement.time_step;

if (isConstant(op)) {
if (isConstant(op) || isGrantOnce(op)) {
inst.src_operands.emplace_back(getConstantLiteral(op), "RED");
} else {
SmallVector<Value> operands; operands.reserve(op->getNumOperands());
Expand Down Expand Up @@ -571,21 +588,33 @@ struct GenerateCodePass
for (const Tile &core : config.cores) {
yaml_out << " - column: " << core.col_idx << "\n row: " << core.row_idx
<< "\n core_id: \"" << core.core_id << "\"\n entries:\n";
int entry_id = 0;

// Group instructions by timestep
std::map<int, std::vector<const Instruction*>> timestep_groups;
for (const Instruction &inst : core.entry.instructions) {
yaml_out << " - entry_id: \"entry" << entry_id++ << "\"\n instructions:\n"
<< " - opcode: \"" << inst.opcode << "\"\n timestep: " << inst.time_step << "\n";
// sources.
if (!inst.src_operands.empty()) {
yaml_out << " src_operands:\n";
for (const Operand &opnd : inst.src_operands)
yaml_out << " - operand: \"" << opnd.operand << "\"\n color: \"" << opnd.color << "\"\n";
}
// destinations.
if (!inst.dst_operands.empty()) {
yaml_out << " dst_operands:\n";
for (const Operand &opnd : inst.dst_operands)
yaml_out << " - operand: \"" << opnd.operand << "\"\n color: \"" << opnd.color << "\"\n";
timestep_groups[inst.time_step].push_back(&inst);
}

yaml_out << " - entry_id: \"entry0\"\n instructions:\n";
for (const auto &timestep_pair : timestep_groups) {
int timestep = timestep_pair.first;
const auto &operations = timestep_pair.second;

yaml_out << " - timestep: " << timestep << "\n operations:\n";
for (const Instruction *inst : operations) {
yaml_out << " - opcode: \"" << inst->opcode << "\"\n";
// sources.
if (!inst->src_operands.empty()) {
yaml_out << " src_operands:\n";
for (const Operand &opnd : inst->src_operands)
yaml_out << " - operand: \"" << opnd.operand << "\"\n color: \"" << opnd.color << "\"\n";
}
// destinations.
if (!inst->dst_operands.empty()) {
yaml_out << " dst_operands:\n";
for (const Operand &opnd : inst->dst_operands)
yaml_out << " - operand: \"" << opnd.operand << "\"\n color: \"" << opnd.color << "\"\n";
}
}
}
}
Expand Down Expand Up @@ -614,17 +643,32 @@ struct GenerateCodePass

for (const Tile &core : config.cores) {
asm_out << "PE(" << core.col_idx << "," << core.row_idx << "):\n";

// Group instructions by timestep
std::map<int, std::vector<const Instruction*>> timestep_groups;
for (const Instruction &inst : core.entry.instructions) {
asm_out << "{\n " << inst.opcode;
for (const Operand &operand : inst.src_operands) asm_out << ", " << formatOperand(operand);
if (!inst.dst_operands.empty()) {
asm_out << " -> ";
for (size_t i = 0; i < inst.dst_operands.size(); ++i) {
if (i > 0) asm_out << ", ";
asm_out << formatOperand(inst.dst_operands[i]);
timestep_groups[inst.time_step].push_back(&inst);
}

for (const auto &timestep_pair : timestep_groups) {
int timestep = timestep_pair.first;
const auto &instructions = timestep_pair.second;

asm_out << "{\n";
for (size_t i = 0; i < instructions.size(); ++i) {
const Instruction *inst = instructions[i];
asm_out << " " << inst->opcode;
for (const Operand &operand : inst->src_operands) asm_out << ", " << formatOperand(operand);
if (!inst->dst_operands.empty()) {
asm_out << " -> ";
for (size_t j = 0; j < inst->dst_operands.size(); ++j) {
if (j > 0) asm_out << ", ";
asm_out << formatOperand(inst->dst_operands[j]);
}
}
asm_out << "\n";
}
asm_out << " (t=" << inst.time_step << ")\n}\n";
asm_out << "} (t=" << timestep << ")\n";
}
asm_out << "\n";
}
Expand Down Expand Up @@ -693,15 +737,15 @@ struct GenerateCodePass
clearState();

// Single function-level walks: index + materialize + collect.
SmallVector<Operation*> dataMovs;
SmallVector<Operation*> ctrlMovs;
SmallVector<Operation*> data_movs;
SmallVector<Operation*> ctrl_movs;
DenseMap<Value, Operation*> reserve_to_phi_map;
indexIR(func, dataMovs, ctrlMovs, reserve_to_phi_map);
indexIR(func, data_movs, ctrl_movs, reserve_to_phi_map);

// Expand forwarders without re-walking IR.
for (Operation *op : dataMovs)
for (Operation *op : data_movs)
expandMovImpl<false>(op, topo, /*unused*/reserve_to_phi_map);
for (Operation *op : ctrlMovs)
for (Operation *op : ctrl_movs)
expandMovImpl<true>(op, topo, reserve_to_phi_map);

// Debug unresolveds, then dump outputs.
Expand Down
Loading