Skip to content
67 changes: 45 additions & 22 deletions lib/NeuraDialect/Transforms/GenerateCodePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ struct GenerateCodePass
if (isReserve(op)) return;

// materialize all other ops placed on tiles (compute/phi/const/etc.).
auto placement = operation_placements[op];
TileLocation placement = operation_placements[op];
if (!placement.has_tile) return;

std::string opcode = getOpcode(op);
Expand Down Expand Up @@ -505,7 +505,7 @@ struct GenerateCodePass
// Try register-based rewiring. If cross-tile, emit deposits [incoming_dir]->[$reg] at earliest reg ts.
// Returns true if rewiring to $reg was applied to consumers.
template<bool IsCtrl>
bool handleRegisterRewiring(Operation *consOp, Value atVal, const SmallVector<RegStep, 4> &regs,
bool handleRegisterRewiring(Operation *consumer_operation, Value value_at_consumer, const SmallVector<RegStep, 4> &regs,
const SmallVector<LinkStep, 8> &links, const Topology &topo) {
if (regs.empty()) return false;

Expand All @@ -515,27 +515,46 @@ struct GenerateCodePass
if (!links.empty()) {
// Cross-tile: deposit on destination tile at earliest register ts.
int dst_tile = topo.dstTileOfLink(links.back().link_id);
StringRef incoming_dir = topo.dirFromLink(links.back().link_id);
// Computes incoming direction from destination tile's perspective.
StringRef incoming_dir = topo.invertDir(topo.dirFromLink(links.back().link_id));
placeDstDeposit(topo, dst_tile, timestep_0, incoming_dir, register_id, /*asCtrlMov=*/IsCtrl);

auto cp = operation_placements.lookup(consOp);
if (cp.has_tile && cp.time_step > timestep_0) {
setConsumerSourceExact(consOp, atVal, "$" + std::to_string(register_id));
TileLocation consumer_placement = operation_placements.lookup(consumer_operation);
if (consumer_placement.has_tile && consumer_placement.time_step > timestep_0) {
setConsumerSourceExact(consumer_operation, value_at_consumer, "$" + std::to_string(register_id));
return true;
}
} else {
// Same-tile: must go via register.
setConsumerSourceExact(consOp, atVal, "$" + std::to_string(register_id));
setConsumerSourceExact(consumer_operation, value_at_consumer, "$" + std::to_string(register_id));
return true;
}
return false;
}

template<bool IsCtrl>
void handleDirectionRewiring(Operation *consOp, Value atVal, StringRef consumer_direction,
const SmallVector<LinkStep, 8> &links, Operation *forwarder) {
void handleDirectionRewiring(Operation *consumer_operation, Value value_at_consumer, StringRef consumer_direction,
const SmallVector<LinkStep, 8> &links, const Topology &topo,
Operation *forwarder) {
if (!links.empty()) {
setConsumerSourceExact(consOp, atVal, consumer_direction.str());
// Computes the direction from the link destination tile to the consumer tile.
TileLocation consumer_placement = operation_placements.lookup(consumer_operation);
if (consumer_placement.has_tile) {
int dst_tile_id = topo.dstTileOfLink(links.back().link_id);
int consumer_tile_id = topo.tileIdAt(consumer_placement.col_idx, consumer_placement.row_idx);

// If consumer is on the link destination tile, use the incoming direction.
if (consumer_tile_id == dst_tile_id) {
setConsumerSourceExact(consumer_operation, value_at_consumer, consumer_direction.str());
} else {
// Computes direction from link destination tile to consumer tile.
StringRef actual_dir = topo.invertDir(topo.getDirBetween(dst_tile_id, consumer_tile_id));
setConsumerSourceExact(consumer_operation, value_at_consumer, actual_dir.str());
}
} else {
// Falls back to consumer_direction if consumer placement is unknown.
setConsumerSourceExact(consumer_operation, value_at_consumer, consumer_direction.str());
}
} else {
forwarder->emitError(IsCtrl
? "same-tile ctrl_mov without register mapping is illegal. Provide a register in mapping_locs."
Expand All @@ -552,9 +571,11 @@ struct GenerateCodePass
// Basic info from forwarders.
Value source = forwarder->getOperand(0);
Operation *producer = source.getDefiningOp();
auto links = getLinkChain(forwarder);
auto regs = getRegisterSteps(forwarder);
auto [producer_direction, consumer_direction] = computeDirections(links, topo);
SmallVector<LinkStep, 8> links = getLinkChain(forwarder);
SmallVector<RegStep, 4> regs = getRegisterSteps(forwarder);
std::pair<StringRef, StringRef> directions = computeDirections(links, topo);
StringRef producer_direction = directions.first;
StringRef consumer_direction = directions.second;

// Producer endpoints & intermediate hops.
setProducerDestination(producer, producer_direction, regs);
Expand All @@ -570,9 +591,11 @@ struct GenerateCodePass
}

// Wires each consumer: prefer register rewiring; fallback to direction rewiring.
for (auto &[consOp, atVal] : consumers) {
if (!handleRegisterRewiring<IsCtrl>(consOp, atVal, regs, links, topo))
handleDirectionRewiring<IsCtrl>(consOp, atVal, consumer_direction, links, forwarder);
for (std::pair<Operation*, Value> &consumer_pair : consumers) {
Operation *consumer_operation = consumer_pair.first;
Value value_at_consumer = consumer_pair.second;
if (!handleRegisterRewiring<IsCtrl>(consumer_operation, value_at_consumer, regs, links, topo))
handleDirectionRewiring<IsCtrl>(consumer_operation, value_at_consumer, consumer_direction, links, topo, forwarder);
}
}

Expand Down Expand Up @@ -750,15 +773,15 @@ struct GenerateCodePass

// Endpoint deposits: on destination tiles at earliest reg ts, move [incoming_dir] -> [$reg].
// CTRL_MOV paths emit CTRL_MOV deposits; DATA_MOV paths emit DATA_MOV deposits.
void placeDstDeposit(const Topology &topo, int dstTileId, int ts,
StringRef incomingDir, int regId, bool asCtrlMov = false) {
uint64_t signature = (uint64_t)dstTileId << 32 ^ (uint64_t)ts << 16 ^ (uint64_t)regId;
void placeDstDeposit(const Topology &topo, int dst_tile_id, int ts,
StringRef incoming_dir, int reg_id, bool asCtrlMov = false) {
uint64_t signature = (uint64_t)dst_tile_id << 32 ^ (uint64_t)ts << 16 ^ (uint64_t)reg_id;
if (!deposit_signatures.insert(signature).second) return; // already placed.
auto [tile_x, tile_y] = topo.tile_location.lookup(dstTileId);
auto [tile_x, tile_y] = topo.tile_location.lookup(dst_tile_id);
Instruction inst(asCtrlMov ? "CTRL_MOV" : "DATA_MOV");
inst.time_step = ts;
inst.src_operands.emplace_back(incomingDir.str(), "RED");
inst.dst_operands.emplace_back("$" + std::to_string(regId), "RED");
inst.src_operands.emplace_back(incoming_dir.str(), "RED");
inst.dst_operands.emplace_back("$" + std::to_string(reg_id), "RED");
tile_time_instructions[{tile_x, tile_y}][ts].push_back(std::move(inst));
}

Expand Down
79 changes: 43 additions & 36 deletions test/e2e/bicg/bicg_kernel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -130,40 +130,40 @@
// AFTER_CANONICALIZE-NEXT: ^bb4(%28: i32, %29: i32, %30: i64): // pred: ^bb1
// AFTER_CANONICALIZE-NEXT: %31 = neura.zext %28 : i32 -> i64
// AFTER_CANONICALIZE-NEXT: %32 = neura.zext %29 : i32 -> i64
// AFTER_CANONICALIZE-NEXT: neura.br %30, %30, %31 : i64, i64, i64 to ^bb5
// AFTER_CANONICALIZE-NEXT: ^bb5(%33: i64, %34: i64, %35: i64): // 2 preds: ^bb4, ^bb7
// AFTER_CANONICALIZE-NEXT: %36 = "neura.gep"(%33) <{operandSegmentSizes = array<i32: 0, 1>}> {lhs_value = "%arg4"} : (i64) -> !llvm.ptr
// AFTER_CANONICALIZE-NEXT: "neura.store"(%36) {lhs_value = 0.000000e+00 : f64} : (!llvm.ptr) -> ()
// AFTER_CANONICALIZE-NEXT: %37 = "neura.gep"(%33) <{operandSegmentSizes = array<i32: 0, 1>}> {lhs_value = "%arg6"} : (i64) -> !llvm.ptr
// AFTER_CANONICALIZE-NEXT: neura.br %34, %33, %35, %34 : i64, i64, i64, i64 to ^bb6
// AFTER_CANONICALIZE-NEXT: ^bb6(%38: i64, %39: i64, %40: i64, %41: i64): // 2 preds: ^bb5, ^bb6
// AFTER_CANONICALIZE-NEXT: %42 = "neura.gep"(%38) <{operandSegmentSizes = array<i32: 0, 1>}> {lhs_value = "%arg3"} : (i64) -> !llvm.ptr
// AFTER_CANONICALIZE-NEXT: %43 = "neura.load"(%42) : (!llvm.ptr) -> f64
// AFTER_CANONICALIZE-NEXT: %44 = "neura.load"(%37) : (!llvm.ptr) -> f64
// AFTER_CANONICALIZE-NEXT: %45 = "neura.gep"(%39, %38) <{operandSegmentSizes = array<i32: 0, 2>}> {lhs_value = "%arg2"} : (i64, i64) -> !llvm.ptr
// AFTER_CANONICALIZE-NEXT: %46 = "neura.load"(%45) : (!llvm.ptr) -> f64
// AFTER_CANONICALIZE-NEXT: %47 = "neura.fmul_fadd"(%44, %46, %43) : (f64, f64, f64) -> f64
// AFTER_CANONICALIZE-NEXT: "neura.store"(%47, %42) : (f64, !llvm.ptr) -> ()
// AFTER_CANONICALIZE-NEXT: %48 = "neura.load"(%36) : (!llvm.ptr) -> f64
// AFTER_CANONICALIZE-NEXT: %49 = "neura.load"(%45) : (!llvm.ptr) -> f64
// AFTER_CANONICALIZE-NEXT: %50 = "neura.gep"(%38) <{operandSegmentSizes = array<i32: 0, 1>}> {lhs_value = "%arg5"} : (i64) -> !llvm.ptr
// AFTER_CANONICALIZE-NEXT: %51 = "neura.load"(%50) : (!llvm.ptr) -> f64
// AFTER_CANONICALIZE-NEXT: %52 = "neura.fmul_fadd"(%49, %51, %48) : (f64, f64, f64) -> f64
// AFTER_CANONICALIZE-NEXT: "neura.store"(%52, %36) : (f64, !llvm.ptr) -> ()
// AFTER_CANONICALIZE-NEXT: %53 = "neura.add"(%38) {rhs_value = 1 : i64} : (i64) -> i64
// AFTER_CANONICALIZE-NEXT: %54 = "neura.icmp"(%53, %32) <{cmpType = "eq"}> : (i64, i64) -> i1
// AFTER_CANONICALIZE-NEXT: neura.cond_br %54 : i1 then %39, %40, %41 : i64, i64, i64 to ^bb7 else %53, %39, %40, %41 : i64, i64, i64, i64 to ^bb6
// AFTER_CANONICALIZE-NEXT: ^bb7(%55: i64, %56: i64, %57: i64): // pred: ^bb6
// AFTER_CANONICALIZE-NEXT: %58 = "neura.add"(%55) {rhs_value = 1 : i64} : (i64) -> i64
// AFTER_CANONICALIZE-NEXT: %59 = "neura.icmp"(%58, %56) <{cmpType = "eq"}> : (i64, i64) -> i1
// AFTER_CANONICALIZE-NEXT: neura.cond_br %59 : i1 then to ^bb8 else %58, %57, %56 : i64, i64, i64 to ^bb5
// AFTER_CANONICALIZE-NEXT: neura.br %30, %30, %32, %31 : i64, i64, i64, i64 to ^bb5
// AFTER_CANONICALIZE-NEXT: ^bb5(%33: i64, %34: i64, %35: i64, %36: i64): // 2 preds: ^bb4, ^bb7
// AFTER_CANONICALIZE-NEXT: %37 = "neura.gep"(%33) <{operandSegmentSizes = array<i32: 0, 1>}> {lhs_value = "%arg4"} : (i64) -> !llvm.ptr
// AFTER_CANONICALIZE-NEXT: "neura.store"(%37) {lhs_value = 0.000000e+00 : f64} : (!llvm.ptr) -> ()
// AFTER_CANONICALIZE-NEXT: %38 = "neura.gep"(%33) <{operandSegmentSizes = array<i32: 0, 1>}> {lhs_value = "%arg6"} : (i64) -> !llvm.ptr
// AFTER_CANONICALIZE-NEXT: neura.br %34, %38, %33, %37, %35, %36, %34 : i64, !llvm.ptr, i64, !llvm.ptr, i64, i64, i64 to ^bb6
// AFTER_CANONICALIZE-NEXT: ^bb6(%39: i64, %40: !llvm.ptr, %41: i64, %42: !llvm.ptr, %43: i64, %44: i64, %45: i64): // 2 preds: ^bb5, ^bb6
// AFTER_CANONICALIZE-NEXT: %46 = "neura.gep"(%39) <{operandSegmentSizes = array<i32: 0, 1>}> {lhs_value = "%arg3"} : (i64) -> !llvm.ptr
// AFTER_CANONICALIZE-NEXT: %47 = "neura.load"(%46) : (!llvm.ptr) -> f64
// AFTER_CANONICALIZE-NEXT: %48 = "neura.load"(%40) : (!llvm.ptr) -> f64
// AFTER_CANONICALIZE-NEXT: %49 = "neura.gep"(%41, %39) <{operandSegmentSizes = array<i32: 0, 2>}> {lhs_value = "%arg2"} : (i64, i64) -> !llvm.ptr
// AFTER_CANONICALIZE-NEXT: %50 = "neura.load"(%49) : (!llvm.ptr) -> f64
// AFTER_CANONICALIZE-NEXT: %51 = "neura.fmul_fadd"(%48, %50, %47) : (f64, f64, f64) -> f64
// AFTER_CANONICALIZE-NEXT: "neura.store"(%51, %46) : (f64, !llvm.ptr) -> ()
// AFTER_CANONICALIZE-NEXT: %52 = "neura.load"(%42) : (!llvm.ptr) -> f64
// AFTER_CANONICALIZE-NEXT: %53 = "neura.load"(%49) : (!llvm.ptr) -> f64
// AFTER_CANONICALIZE-NEXT: %54 = "neura.gep"(%39) <{operandSegmentSizes = array<i32: 0, 1>}> {lhs_value = "%arg5"} : (i64) -> !llvm.ptr
// AFTER_CANONICALIZE-NEXT: %55 = "neura.load"(%54) : (!llvm.ptr) -> f64
// AFTER_CANONICALIZE-NEXT: %56 = "neura.fmul_fadd"(%53, %55, %52) : (f64, f64, f64) -> f64
// AFTER_CANONICALIZE-NEXT: "neura.store"(%56, %42) : (f64, !llvm.ptr) -> ()
// AFTER_CANONICALIZE-NEXT: %57 = "neura.add"(%39) {rhs_value = 1 : i64} : (i64) -> i64
// AFTER_CANONICALIZE-NEXT: %58 = "neura.icmp"(%57, %43) <{cmpType = "eq"}> : (i64, i64) -> i1
// AFTER_CANONICALIZE-NEXT: neura.cond_br %58 : i1 then %41, %44, %45, %43 : i64, i64, i64, i64 to ^bb7 else %57, %40, %41, %42, %43, %44, %45 : i64, !llvm.ptr, i64, !llvm.ptr, i64, i64, i64 to ^bb6
// AFTER_CANONICALIZE-NEXT: ^bb7(%59: i64, %60: i64, %61: i64, %62: i64): // pred: ^bb6
// AFTER_CANONICALIZE-NEXT: %63 = "neura.add"(%59) {rhs_value = 1 : i64} : (i64) -> i64
// AFTER_CANONICALIZE-NEXT: %64 = "neura.icmp"(%63, %60) <{cmpType = "eq"}> : (i64, i64) -> i1
// AFTER_CANONICALIZE-NEXT: neura.cond_br %64 : i1 then to ^bb8 else %63, %61, %62, %60 : i64, i64, i64, i64 to ^bb5
// AFTER_CANONICALIZE-NEXT: ^bb8: // 4 preds: ^bb1, ^bb2, ^bb3, ^bb7
// AFTER_CANONICALIZE-NEXT: "neura.return"() : () -> ()
// AFTER_CANONICALIZE-NEXT: }

//MAPPING: func.func @kernel
//MAPPING-SAME: accelerator = "neura", dataflow_mode = "predicate"
//MAPPING-SAME: mapping_info = {compiled_ii = 12 : i32, mapping_mode = "spatial-temporal", mapping_strategy = "heuristic", rec_mii = 9 : i32, res_mii = 5 : i32, x_tiles = 4 : i32, y_tiles = 4 : i32}
//MAPPING-SAME: mapping_info = {compiled_ii = 12 : i32, mapping_mode = "spatial-temporal", mapping_strategy = "heuristic", rec_mii = 9 : i32, res_mii = 6 : i32, x_tiles = 4 : i32, y_tiles = 4 : i32}

// YAML: array_config:
// YAML-NEXT: columns: 4
Expand Down Expand Up @@ -194,17 +194,24 @@
// ASM-NEXT: GRANT_ONCE, [] -> [EAST, RED], [NORTH, RED]
// ASM-NEXT: } (t=2)
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [WEST, RED] -> [$0]
// ASM-NEXT: } (t=4)
// ASM-NEXT: GRANT_ONCE, [#0] -> [NORTH, RED]
// ASM-NEXT: } (t=3)
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [WEST, RED] -> [$1]
// ASM-NEXT: } (t=5)
// ASM-NEXT: ICMP_SGT, [EAST, RED], [#0] -> [NORTH, RED], [EAST, RED], [$0]
// ASM-NEXT: } (t=4)
// ASM-NEXT: {
// ASM-NEXT: GRANT_PREDICATE, [$0], [$1] -> [$0]
// ASM-NEXT: } (t=7)
// ASM-NEXT: GRANT_PREDICATE, [NORTH, RED], [$0] -> [NORTH, RED], [$0]
// ASM-NEXT: } (t=6)
// ASM-NEXT: {
// ASM-NEXT: ZEXT, [$0] -> [NORTH, RED]
// ASM-NEXT: PHI, [EAST, RED], [$0] -> [$1], [NORTH, RED], [$0]
// ASM-NEXT: } (t=8)
// ASM-NEXT: {
// ASM-NEXT: GRANT_ONCE, [] -> [NORTH, RED]
// ASM-NEXT: PHI, [$2], [$0] -> [NORTH, RED], [$2], [EAST, RED]
// ASM-NEXT: } (t=9)
// ASM-NEXT: {
// ASM-NEXT: GEP, [$1] -> [NORTH, RED], [EAST, RED]
// ASM-NEXT: } (t=10)
// ASM-NEXT: {
// ASM-NEXT: GRANT_ONCE, [] -> [EAST, RED]
// ASM-NEXT: } (t=11)

2 changes: 1 addition & 1 deletion test/e2e/fir/fir_kernel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
// ASM-NEXT: } (t=3)
// ASM: PE(2,2):
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [WEST, RED] -> [$0]
// ASM-NEXT: DATA_MOV, [EAST, RED] -> [$0]
// ASM-NEXT: } (t=2)
// ASM-NEXT: {
// ASM-NEXT: ICMP_EQ, [EAST, RED], [#32] -> [$0], [WEST, RED]
Expand Down
2 changes: 1 addition & 1 deletion test/e2e/relu/relu_kernel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,5 @@

// ASM: PE(2,1):
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [SOUTH, RED] -> [$1]
// ASM-NEXT: DATA_MOV, [NORTH, RED] -> [$1]
// ASM-NEXT: } (t=5)
Loading