Skip to content
53 changes: 36 additions & 17 deletions lib/NeuraDialect/Transforms/GenerateCodePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ struct GenerateCodePass
// Try register-based rewiring. If cross-tile, emit deposits [incoming_dir]->[$reg] at earliest reg ts.
// Returns true if rewiring to $reg was applied to consumers.
template<bool IsCtrl>
bool handleRegisterRewiring(Operation *consOp, Value atVal, const SmallVector<RegStep, 4> &regs,
bool handleRegisterRewiring(Operation *cons_op, Value at_val, const SmallVector<RegStep, 4> &regs,
const SmallVector<LinkStep, 8> &links, const Topology &topo) {
if (regs.empty()) return false;

Expand All @@ -515,27 +515,46 @@ struct GenerateCodePass
if (!links.empty()) {
// Cross-tile: deposit on destination tile at earliest register ts.
int dst_tile = topo.dstTileOfLink(links.back().link_id);
StringRef incoming_dir = topo.dirFromLink(links.back().link_id);
// Computes incoming direction from destination tile's perspective.
StringRef incoming_dir = topo.invertDir(topo.dirFromLink(links.back().link_id));
placeDstDeposit(topo, dst_tile, timestep_0, incoming_dir, register_id, /*asCtrlMov=*/IsCtrl);

auto cp = operation_placements.lookup(consOp);
auto cp = operation_placements.lookup(cons_op);
if (cp.has_tile && cp.time_step > timestep_0) {
setConsumerSourceExact(consOp, atVal, "$" + std::to_string(register_id));
setConsumerSourceExact(cons_op, at_val, "$" + std::to_string(register_id));
return true;
}
} else {
// Same-tile: must go via register.
setConsumerSourceExact(consOp, atVal, "$" + std::to_string(register_id));
setConsumerSourceExact(cons_op, at_val, "$" + std::to_string(register_id));
return true;
}
return false;
}

template<bool IsCtrl>
void handleDirectionRewiring(Operation *consOp, Value atVal, StringRef consumer_direction,
const SmallVector<LinkStep, 8> &links, Operation *forwarder) {
void handleDirectionRewiring(Operation *cons_op, Value at_val, StringRef consumer_direction,
const SmallVector<LinkStep, 8> &links, const Topology &topo,
Operation *forwarder) {
if (!links.empty()) {
setConsumerSourceExact(consOp, atVal, consumer_direction.str());
// Computes the direction from the link destination tile to the consumer tile.
auto cp = operation_placements.lookup(cons_op);
if (cp.has_tile) {
int dst_tile_id = topo.dstTileOfLink(links.back().link_id);
int consumer_tile_id = topo.tileIdAt(cp.col_idx, cp.row_idx);

// If consumer is on the link destination tile, use the incoming direction.
if (consumer_tile_id == dst_tile_id) {
setConsumerSourceExact(cons_op, at_val, consumer_direction.str());
} else {
// Computes direction from link destination tile to consumer tile.
StringRef actual_dir = topo.invertDir(topo.getDirBetween(dst_tile_id, consumer_tile_id));
setConsumerSourceExact(cons_op, at_val, actual_dir.str());
}
} else {
// Falls back to consumer_direction if consumer placement is unknown.
setConsumerSourceExact(cons_op, at_val, consumer_direction.str());
}
} else {
forwarder->emitError(IsCtrl
? "same-tile ctrl_mov without register mapping is illegal. Provide a register in mapping_locs."
Expand Down Expand Up @@ -570,9 +589,9 @@ struct GenerateCodePass
}

// Wires each consumer: prefer register rewiring; fallback to direction rewiring.
for (auto &[consOp, atVal] : consumers) {
if (!handleRegisterRewiring<IsCtrl>(consOp, atVal, regs, links, topo))
handleDirectionRewiring<IsCtrl>(consOp, atVal, consumer_direction, links, forwarder);
for (auto &[cons_op, at_val] : consumers) {
if (!handleRegisterRewiring<IsCtrl>(cons_op, at_val, regs, links, topo))
handleDirectionRewiring<IsCtrl>(cons_op, at_val, consumer_direction, links, topo, forwarder);
}
}

Expand Down Expand Up @@ -750,15 +769,15 @@ struct GenerateCodePass

// Endpoint deposits: on destination tiles at earliest reg ts, move [incoming_dir] -> [$reg].
// CTRL_MOV paths emit CTRL_MOV deposits; DATA_MOV paths emit DATA_MOV deposits.
void placeDstDeposit(const Topology &topo, int dstTileId, int ts,
StringRef incomingDir, int regId, bool asCtrlMov = false) {
uint64_t signature = (uint64_t)dstTileId << 32 ^ (uint64_t)ts << 16 ^ (uint64_t)regId;
void placeDstDeposit(const Topology &topo, int dst_tile_id, int ts,
StringRef incoming_dir, int reg_id, bool asCtrlMov = false) {
uint64_t signature = (uint64_t)dst_tile_id << 32 ^ (uint64_t)ts << 16 ^ (uint64_t)reg_id;
if (!deposit_signatures.insert(signature).second) return; // already placed.
auto [tile_x, tile_y] = topo.tile_location.lookup(dstTileId);
auto [tile_x, tile_y] = topo.tile_location.lookup(dst_tile_id);
Instruction inst(asCtrlMov ? "CTRL_MOV" : "DATA_MOV");
inst.time_step = ts;
inst.src_operands.emplace_back(incomingDir.str(), "RED");
inst.dst_operands.emplace_back("$" + std::to_string(regId), "RED");
inst.src_operands.emplace_back(incoming_dir.str(), "RED");
inst.dst_operands.emplace_back("$" + std::to_string(reg_id), "RED");
tile_time_instructions[{tile_x, tile_y}][ts].push_back(std::move(inst));
}

Expand Down
Loading