Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions lib/NeuraDialect/Transforms/GenerateCodePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,11 +572,15 @@ struct GenerateCodePass
// Emits router hops for multi-hop paths (from the second hop onwards). CTRL_MOV emits CTRL_MOV hops.
template<bool IsCtrl>
void generateIntermediateHops(const SmallVector<LinkStep, 8> &links, const Topology &topo,
int base_mov_id, size_t &hop_counter) {
for (size_t i = 1; i < links.size(); ++i) {
int base_mov_id, size_t &hop_counter,
bool starts_with_register = false) {
// Hops start from links[1]; when the path begins with a register, timestamp hops with the previous link.
size_t begin = 1;
for (size_t i = begin; i < links.size(); ++i) {
int prev_link = links[i - 1].link_id;
int cur_link = links[i].link_id;
int ts = links[i].ts;
// If path starts with register, align hop ts to the previous link (value arrival).
int ts = starts_with_register ? links[i - 1].ts : links[i].ts;

int mid_tile = topo.srcTileOfLink(cur_link);
StringRef in = topo.invertDir(topo.dirFromLink(prev_link));
Expand Down Expand Up @@ -690,10 +694,13 @@ struct GenerateCodePass
StringRef producer_direction = directions.first;
StringRef consumer_direction = directions.second;

// Detects if the path starts with a register (same-tile register staging).
bool starts_with_register = !regs.empty();

// Producer endpoints & intermediate hops.
setProducerDestination(producer, producer_direction, regs);
size_t hop_counter = 1;
generateIntermediateHops<IsCtrl>(links, topo, mov_dfg_id, hop_counter);
generateIntermediateHops<IsCtrl>(links, topo, mov_dfg_id, hop_counter, starts_with_register);

// Gather consumers.
SmallVector<std::pair<Operation*, Value>, 2> consumers;
Expand Down Expand Up @@ -790,6 +797,18 @@ struct GenerateCodePass
return it->second;
}

// Finds the time_step of the materialized instruction whose id equals `id`.
// Scans every entry of tile_time_instructions and every bucket nested under it,
// returning the time_step of the first instruction that matches; yields
// std::nullopt when no instruction carries that id.
std::optional<int> getInstructionTimeById(int id) const {
  for (const auto &per_tile : tile_time_instructions) {
    const auto &buckets = per_tile.second;
    for (const auto &bucket : buckets) {
      const auto &instructions = bucket.second;
      for (const Instruction &candidate : instructions) {
        if (candidate.id != id) continue;
        return candidate.time_step;
      }
    }
  }
  return std::nullopt;
}

// Gets instruction ID for a materialized operation.
int getInstructionId(Operation *op) const {
auto it = operation_to_instruction_reference.find(op);
Expand Down Expand Up @@ -881,6 +900,21 @@ struct GenerateCodePass
return info;
}

// Reports whether the first element of the "mapping_locs" array attribute on
// `op` is a dictionary whose "resource" string is "register" or "reg".
// Returns false when the attribute is absent or empty, when the first element
// is not a dictionary, or when "resource" is missing or not a string.
static bool pathStartsWithRegister(Operation *op) {
  auto mapping_locs = op->getAttrOfType<ArrayAttr>("mapping_locs");
  if (!mapping_locs || mapping_locs.empty()) return false;

  auto first_loc = dyn_cast<DictionaryAttr>(mapping_locs[0]);
  if (!first_loc) return false;

  // dyn_cast_or_null tolerates a dictionary that has no "resource" key.
  auto resource = dyn_cast_or_null<StringAttr>(first_loc.get("resource"));
  if (!resource) return false;

  StringRef kind = resource.getValue();
  return kind == "register" || kind == "reg";
}


struct DfgNodeInfo {
std::string opcode;
int tile_x = -1;
Expand Down Expand Up @@ -966,7 +1000,8 @@ struct GenerateCodePass
TileLocation producer_loc,
const Topology &topology,
SmallVector<std::pair<int, int>, 8> &out_tiles,
SmallVector<int, 8> &out_time_steps) const {
SmallVector<int, 8> &out_time_steps,
bool starts_with_register = false) const {
out_tiles.clear();
out_time_steps.clear();
if (link_steps.empty()) return;
Expand All @@ -975,7 +1010,8 @@ struct GenerateCodePass
? topology.tileIdAt(producer_loc.col_idx, producer_loc.row_idx) : -1;
int consumer_tile_id = topology.dstTileOfLink(link_steps.back().link_id);

for (size_t i = 0; i < link_steps.size(); ++i) {
size_t begin = starts_with_register ? 1 : 0;
for (size_t i = begin; i < link_steps.size(); ++i) {
int middle_tile_id = topology.srcTileOfLink(link_steps[i].link_id);
if (middle_tile_id == producer_tile_id || middle_tile_id == consumer_tile_id) continue;
auto coord = topology.tile_location.lookup(middle_tile_id);
Expand All @@ -984,6 +1020,7 @@ struct GenerateCodePass
continue; // Skips duplicates.
}
out_tiles.push_back(coord);
// Uses the hop's own outgoing link timestep.
out_time_steps.push_back(link_steps[i].ts);
}
}
Expand Down Expand Up @@ -1032,13 +1069,17 @@ struct GenerateCodePass
SmallVector<LinkStep, 8> link_steps = collectLinkSteps(operation);
SmallVector<std::pair<int,int>, 8> hop_tiles;
SmallVector<int, 8> hop_time_steps;
bool starts_with_register = pathStartsWithRegister(operation);
// Build hop tiles directly from link steps (mirrors router hop emission).
if (link_steps.size() > 1) {
for (size_t i = 1; i < link_steps.size(); ++i) {
size_t begin = starts_with_register ? 1 : 1; // hops are from link[1] onward
for (size_t i = begin; i < link_steps.size(); ++i) {
int middle_tile_id = topology.srcTileOfLink(link_steps[i].link_id);
auto coord = topology.tile_location.lookup(middle_tile_id);
hop_tiles.push_back(coord);
hop_time_steps.push_back(link_steps[i].ts);
// Aligns hop ts with the link's own timestep (or previous if starts with register).
int hop_ts = starts_with_register ? link_steps[i - 1].ts : link_steps[i].ts;
hop_time_steps.push_back(hop_ts);
}
}

Expand All @@ -1056,7 +1097,14 @@ struct GenerateCodePass
hop_node.opcode = isCtrlMov(operation) ? "CTRL_MOV" : "DATA_MOV";
hop_node.tile_x = hop_tiles[i].first;
hop_node.tile_y = hop_tiles[i].second;
hop_node.time_step = (i < hop_time_steps.size()) ? hop_time_steps[i] : -1;
// Prefers the materialized instruction time if available; fallback to link-based,
// but ensure it is not earlier than the computed hop_ts when present.
if (auto ts = getInstructionTimeById(node_id)) {
int link_ts = (i < hop_time_steps.size()) ? hop_time_steps[i] : *ts;
hop_node.time_step = std::max(*ts, link_ts);
} else {
hop_node.time_step = (i < hop_time_steps.size()) ? hop_time_steps[i] : -1;
}
nodes[node_id] = hop_node;
}

Expand Down
58 changes: 54 additions & 4 deletions test/code_gen/test_code_generate.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -144,22 +144,20 @@ func.func @loop_test() -> f32 {
// ASM: PE(1,1):
// ASM-NEXT: {
// ASM-NEXT: GRANT_PREDICATE, [WEST, RED], [EAST, RED] -> [SOUTH, RED] (t=7, inv_iters=1)
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [WEST, RED] (t=7, inv_iters=1)
// ASM-NEXT: DATA_MOV, [NORTH, RED] -> [$1] (t=7, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=2)
// ASM-NEXT: {
// ASM-NEXT: GRANT_PREDICATE, [$0], [$1] -> [$0] (t=8, inv_iters=1)
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [WEST, RED] (t=8, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=3)
// ASM-NEXT: {
// ASM-NEXT: PHI_START, [WEST, RED], [$0] -> [WEST, RED], [$0] (t=4, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=4)
// ASM: PE(2,1):
// ASM-NEXT: {
// ASM-NEXT: NOT, [NORTH, RED] -> [WEST, RED] (t=6, inv_iters=1)
// ASM-NEXT: DATA_MOV, [NORTH, RED] -> [WEST, RED] (t=6, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=1)
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [NORTH, RED] -> [WEST, RED] (t=7, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=2)
// ASM: PE(0,2):
// ASM-NEXT: {
// ASM-NEXT: PHI_START, [$0], [SOUTH, RED] -> [SOUTH, RED] (t=5, inv_iters=1)
Expand All @@ -176,3 +174,55 @@ func.func @loop_test() -> f32 {
// ASM-NEXT: {
// ASM-NEXT: GRANT_ONCE, [$1] -> [$0] (t=4, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=4)
// ASM: PE(1,2):
// ASM-NEXT: {
// ASM-NEXT: CONSTANT, [#1] -> [$0] (t=0, inv_iters=0)
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [SOUTH, RED] (t=5, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=0)
// ASM-NEXT: {
// ASM-NEXT: GRANT_ONCE, [$0] -> [$0] (t=1, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=1)
// ASM-NEXT: {
// ASM-NEXT: PHI_START, [$0], [EAST, RED] -> [$0], [EAST, RED] (t=2, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=2)
// ASM-NEXT: {
// ASM-NEXT: ADD, [EAST, RED], [$0] -> [EAST, RED], [NORTH, RED] (t=3, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=3)
// ASM-NEXT: {
// ASM-NEXT: PHI_START, [WEST, RED], [EAST, RED] -> [EAST, RED] (t=4, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=4)
// ASM: PE(2,2):
// ASM-NEXT: {
// ASM-NEXT: ICMP_SLT, [$0], [WEST, RED] -> [NORTH, RED], [SOUTH, RED], [WEST, RED], [$0], [$1] (t=5, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=0)
// ASM-NEXT: {
// ASM-NEXT: GRANT_PREDICATE, [$2], [$0] -> [WEST, RED] (t=6, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=1)
// ASM-NEXT: {
// ASM-NEXT: PHI_START, [EAST, RED], [NORTH, RED] -> [WEST, RED] (t=2, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=2)
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [WEST, RED] -> [$2] (t=3, inv_iters=0)
// ASM-NEXT: GRANT_PREDICATE, [$0], [$1] -> [WEST, RED] (t=8, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=3)
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [WEST, RED] -> [$0] (t=4, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=4)
// ASM: PE(3,2):
// ASM-NEXT: {
// ASM-NEXT: CONSTANT, [#0] -> [$0] (t=0, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=0)
// ASM-NEXT: {
// ASM-NEXT: GRANT_ONCE, [$0] -> [WEST, RED] (t=1, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=1)
// ASM: PE(1,3):
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [SOUTH, RED] -> [EAST, RED] (t=3, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=3)
// ASM: PE(2,3):
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [WEST, RED] -> [$0] (t=5, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=0)
// ASM-NEXT: {
// ASM-NEXT: GRANT_PREDICATE, [$0], [SOUTH, RED] -> [SOUTH, RED] (t=6, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=1)
10 changes: 4 additions & 6 deletions test/e2e/bicg/bicg_kernel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,9 @@
// ASM-NEXT: } (idx_per_ii=1)
// ASM-NEXT: {
// ASM-NEXT: GRANT_ONCE, [arg1] -> [NORTH, RED], [$0] (t=2, inv_iters=0)
// ASM-NEXT: DATA_MOV, [NORTH, RED] -> [EAST, RED] (t=2, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=2)
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [NORTH, RED] -> [EAST, RED] (t=3, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=3)
// ASM-NEXT: {
// ASM-NEXT: GRANT_PREDICATE, [$0], [NORTH, RED] -> [$0], [$1] (t=10, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=10)
// ASM-NEXT: {
Expand All @@ -238,14 +236,14 @@
// ASM-NEXT: GRANT_PREDICATE, [$0], [$1] -> [EAST, RED] (t=15, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=2)
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [WEST, RED] -> [EAST, RED] (t=4, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=4)
// ASM-NEXT: DATA_MOV, [WEST, RED] -> [EAST, RED] (t=3, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=3)
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [WEST, RED] -> [$1] (t=11, inv_iters=0)
// ASM-NEXT: DATA_MOV, [NORTH, RED] -> [EAST, RED] (t=11, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=11)
// ASM-NEXT: {
// ASM-NEXT: NOT, [WEST, RED] -> [$0], [$1] (t=12, inv_iters=0)
// ASM-NEXT: DATA_MOV, [NORTH, RED] -> [EAST, RED] (t=12, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=12)

// RUN: mlir-neura-opt %t-kernel.mlir --view-op-graph 2>&1 | sed -n '/^digraph G {/,/^}$/p' > bicg_kernel_original.dot
Expand Down
36 changes: 32 additions & 4 deletions test/e2e/fir/fir_kernel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,6 @@
// ASM-NEXT: } (idx_per_ii=2)
// ASM: PE(1,2):
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [SOUTH, RED] (t=5, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=0)
// ASM-NEXT: {
// ASM-NEXT: ADD, [NORTH, RED], [SOUTH, RED] -> [SOUTH, RED], [$0] (t=6, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=1)
// ASM-NEXT: {
Expand All @@ -155,10 +152,12 @@
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [$1] (t=3, inv_iters=0)
// ASM-NEXT: RETURN_VALUE, [$0] (t=8, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=3)
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [SOUTH, RED] (t=4, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=4)
// ASM: PE(2,2):
// ASM-NEXT: {
// ASM-NEXT: GRANT_PREDICATE, [$0], [$1] -> [EAST, RED] (t=5, inv_iters=1)
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [WEST, RED] (t=5, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=0)
// ASM-NEXT: {
// ASM-NEXT: GEP, [EAST, RED] -> [$0] (t=2, inv_iters=0)
Expand All @@ -169,7 +168,36 @@
// ASM-NEXT: } (idx_per_ii=3)
// ASM-NEXT: {
// ASM-NEXT: NOT, [EAST, RED] -> [$1], [WEST, RED] (t=4, inv_iters=0)
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [WEST, RED] (t=4, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=4)
// ASM: PE(3,2):
// ASM-NEXT: {
// ASM-NEXT: GRANT_ONCE, [#0] -> [$0] (t=0, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=0)
// ASM-NEXT: {
// ASM-NEXT: PHI_START, [$0], [WEST, RED] -> [NORTH, RED], [WEST, RED], [$0] (t=1, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=1)
// ASM-NEXT: {
// ASM-NEXT: ADD, [$0], [#1] -> [$0], [WEST, RED] (t=2, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=2)
// ASM-NEXT: {
// ASM-NEXT: ICMP_EQ, [$0], [#32] -> [WEST, RED] (t=3, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=3)
// ASM: PE(1,3):
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [SOUTH, RED] (t=5, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=0)
// ASM: PE(2,3):
// ASM-NEXT: {
// ASM-NEXT: MUL, [SOUTH, RED], [EAST, RED] -> [WEST, RED] (t=4, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=4)
// ASM: PE(3,3):
// ASM-NEXT: {
// ASM-NEXT: GEP, [SOUTH, RED] -> [$0] (t=2, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=2)
// ASM-NEXT: {
// ASM-NEXT: LOAD, [$0] -> [WEST, RED] (t=3, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=3)

// RUN: mlir-neura-opt %t-kernel.mlir --view-op-graph 2>&1 | sed -n '/^digraph G {/,/^}$/p' > fir_kernel_original.dot
// RUN: dot -Tpng fir_kernel_original.dot -o fir_kernel_original.png
Expand Down
6 changes: 2 additions & 4 deletions test/e2e/fir/fir_kernel_vec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,6 @@
// ASM-NEXT: } (idx_per_ii=3)
// ASM: PE(1,2):
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [SOUTH, RED] (t=5, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=0)
// ASM-NEXT: {
// ASM-NEXT: VADD, [NORTH, RED], [SOUTH, RED] -> [SOUTH, RED], [$0] (t=6, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=1)
// ASM-NEXT: {
Expand All @@ -151,12 +148,12 @@
// ASM-NEXT: VECTOR.REDUCE.ADD, [$0] -> [$0] (t=8, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=3)
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [SOUTH, RED] (t=4, inv_iters=0)
// ASM-NEXT: RETURN_VALUE, [$0] (t=9, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=4)
// ASM: PE(2,2):
// ASM-NEXT: {
// ASM-NEXT: GRANT_PREDICATE, [$0], [$1] -> [EAST, RED] (t=5, inv_iters=1)
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [WEST, RED] (t=5, inv_iters=1)
// ASM-NEXT: } (idx_per_ii=0)
// ASM-NEXT: {
// ASM-NEXT: GEP, [EAST, RED] -> [$0] (t=2, inv_iters=0)
Expand All @@ -167,4 +164,5 @@
// ASM-NEXT: } (idx_per_ii=3)
// ASM-NEXT: {
// ASM-NEXT: NOT, [EAST, RED] -> [$1], [WEST, RED] (t=4, inv_iters=0)
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [WEST, RED] (t=4, inv_iters=0)
// ASM-NEXT: } (idx_per_ii=4)
Loading