diff --git a/lib/NeuraDialect/Transforms/GenerateCodePass.cpp b/lib/NeuraDialect/Transforms/GenerateCodePass.cpp index 341b2cff..a768b2f2 100644 --- a/lib/NeuraDialect/Transforms/GenerateCodePass.cpp +++ b/lib/NeuraDialect/Transforms/GenerateCodePass.cpp @@ -382,7 +382,7 @@ struct GenerateCodePass if (isReserve(op)) return; // materialize all other ops placed on tiles (compute/phi/const/etc.). - auto placement = operation_placements[op]; + TileLocation placement = operation_placements[op]; if (!placement.has_tile) return; std::string opcode = getOpcode(op); @@ -505,7 +505,7 @@ struct GenerateCodePass // Try register-based rewiring. If cross-tile, emit deposits [incoming_dir]->[$reg] at earliest reg ts. // Returns true if rewiring to $reg was applied to consumers. template - bool handleRegisterRewiring(Operation *consOp, Value atVal, const SmallVector ®s, + bool handleRegisterRewiring(Operation *consumer_operation, Value value_at_consumer, const SmallVector ®s, const SmallVector &links, const Topology &topo) { if (regs.empty()) return false; @@ -515,27 +515,46 @@ struct GenerateCodePass if (!links.empty()) { // Cross-tile: deposit on destination tile at earliest register ts. int dst_tile = topo.dstTileOfLink(links.back().link_id); - StringRef incoming_dir = topo.dirFromLink(links.back().link_id); + // Computes incoming direction from destination tile's perspective. + StringRef incoming_dir = topo.invertDir(topo.dirFromLink(links.back().link_id)); placeDstDeposit(topo, dst_tile, timestep_0, incoming_dir, register_id, /*asCtrlMov=*/IsCtrl); - auto cp = operation_placements.lookup(consOp); - if (cp.has_tile && cp.time_step > timestep_0) { - setConsumerSourceExact(consOp, atVal, "$" + std::to_string(register_id)); + TileLocation consumer_placement = operation_placements.lookup(consumer_operation); + if (consumer_placement.has_tile && consumer_placement.time_step > timestep_0) { + setConsumerSourceExact(consumer_operation, value_at_consumer, "$" + std::to_string(register_id)); return true; } } else { // Same-tile: must go via register. - setConsumerSourceExact(consOp, atVal, "$" + std::to_string(register_id)); + setConsumerSourceExact(consumer_operation, value_at_consumer, "$" + std::to_string(register_id)); return true; } return false; } template - void handleDirectionRewiring(Operation *consOp, Value atVal, StringRef consumer_direction, - const SmallVector &links, Operation *forwarder) { + void handleDirectionRewiring(Operation *consumer_operation, Value value_at_consumer, StringRef consumer_direction, + const SmallVector &links, const Topology &topo, + Operation *forwarder) { if (!links.empty()) { - setConsumerSourceExact(consOp, atVal, consumer_direction.str()); + // Computes the direction from the link destination tile to the consumer tile. + TileLocation consumer_placement = operation_placements.lookup(consumer_operation); + if (consumer_placement.has_tile) { + int dst_tile_id = topo.dstTileOfLink(links.back().link_id); + int consumer_tile_id = topo.tileIdAt(consumer_placement.col_idx, consumer_placement.row_idx); + + // If consumer is on the link destination tile, use the incoming direction. + if (consumer_tile_id == dst_tile_id) { + setConsumerSourceExact(consumer_operation, value_at_consumer, consumer_direction.str()); + } else { + // Computes direction from link destination tile to consumer tile. + StringRef actual_dir = topo.invertDir(topo.getDirBetween(dst_tile_id, consumer_tile_id)); + setConsumerSourceExact(consumer_operation, value_at_consumer, actual_dir.str()); + } + } else { + // Falls back to consumer_direction if consumer placement is unknown. + setConsumerSourceExact(consumer_operation, value_at_consumer, consumer_direction.str()); + } } else { forwarder->emitError(IsCtrl ? "same-tile ctrl_mov without register mapping is illegal. Provide a register in mapping_locs." @@ -552,9 +571,11 @@ struct GenerateCodePass // Basic info from forwarders. Value source = forwarder->getOperand(0); Operation *producer = source.getDefiningOp(); - auto links = getLinkChain(forwarder); - auto regs = getRegisterSteps(forwarder); - auto [producer_direction, consumer_direction] = computeDirections(links, topo); + SmallVector links = getLinkChain(forwarder); + SmallVector regs = getRegisterSteps(forwarder); + std::pair directions = computeDirections(links, topo); + StringRef producer_direction = directions.first; + StringRef consumer_direction = directions.second; // Producer endpoints & intermediate hops. setProducerDestination(producer, producer_direction, regs); @@ -570,9 +591,11 @@ struct GenerateCodePass } // Wires each consumer: prefer register rewiring; fallback to direction rewiring. - for (auto &[consOp, atVal] : consumers) { - if (!handleRegisterRewiring(consOp, atVal, regs, links, topo)) - handleDirectionRewiring(consOp, atVal, consumer_direction, links, forwarder); + for (std::pair &consumer_pair : consumers) { + Operation *consumer_operation = consumer_pair.first; + Value value_at_consumer = consumer_pair.second; + if (!handleRegisterRewiring(consumer_operation, value_at_consumer, regs, links, topo)) + handleDirectionRewiring(consumer_operation, value_at_consumer, consumer_direction, links, topo, forwarder); } } @@ -750,15 +773,15 @@ struct GenerateCodePass // Endpoint deposits: on destination tiles at earliest reg ts, move [incoming_dir] -> [$reg]. // CTRL_MOV paths emit CTRL_MOV deposits; DATA_MOV paths emit DATA_MOV deposits. - void placeDstDeposit(const Topology &topo, int dstTileId, int ts, - StringRef incomingDir, int regId, bool asCtrlMov = false) { - uint64_t signature = (uint64_t)dstTileId << 32 ^ (uint64_t)ts << 16 ^ (uint64_t)regId; + void placeDstDeposit(const Topology &topo, int dst_tile_id, int ts, + StringRef incoming_dir, int reg_id, bool asCtrlMov = false) { + uint64_t signature = (uint64_t)dst_tile_id << 32 ^ (uint64_t)ts << 16 ^ (uint64_t)reg_id; if (!deposit_signatures.insert(signature).second) return; // already placed. - auto [tile_x, tile_y] = topo.tile_location.lookup(dstTileId); + auto [tile_x, tile_y] = topo.tile_location.lookup(dst_tile_id); Instruction inst(asCtrlMov ? "CTRL_MOV" : "DATA_MOV"); inst.time_step = ts; - inst.src_operands.emplace_back(incomingDir.str(), "RED"); - inst.dst_operands.emplace_back("$" + std::to_string(regId), "RED"); + inst.src_operands.emplace_back(incoming_dir.str(), "RED"); + inst.dst_operands.emplace_back("$" + std::to_string(reg_id), "RED"); tile_time_instructions[{tile_x, tile_y}][ts].push_back(std::move(inst)); } diff --git a/test/e2e/bicg/bicg_kernel.mlir b/test/e2e/bicg/bicg_kernel.mlir index 921458ab..4cdba7d6 100644 --- a/test/e2e/bicg/bicg_kernel.mlir +++ b/test/e2e/bicg/bicg_kernel.mlir @@ -186,6 +186,7 @@ // YAML-NEXT: - operand: "EAST" // YAML-NEXT: color: "RED" +// ASM: # Compiled II: 12 // ASM: PE(0,0): // ASM-NEXT: { // ASM-NEXT: CONSTANT, [#0] -> [EAST, RED] @@ -194,10 +195,10 @@ // ASM-NEXT: GRANT_ONCE, [] -> [EAST, RED], [NORTH, RED] // ASM-NEXT: } (t=2) // ASM-NEXT: { -// ASM-NEXT: DATA_MOV, [WEST, RED] -> [$0] +// ASM-NEXT: DATA_MOV, [EAST, RED] -> [$0] // ASM-NEXT: } (t=4) // ASM-NEXT: { -// ASM-NEXT: DATA_MOV, [WEST, RED] -> [$1] +// ASM-NEXT: DATA_MOV, [EAST, RED] -> [$1] // ASM-NEXT: } (t=5) // ASM-NEXT: { // ASM-NEXT: GRANT_PREDICATE, [$0], [$1] -> [$0] diff --git a/test/e2e/fir/fir_kernel.mlir b/test/e2e/fir/fir_kernel.mlir index 490d4175..c2b79d3c 100644 --- a/test/e2e/fir/fir_kernel.mlir +++ b/test/e2e/fir/fir_kernel.mlir @@ -83,7 +83,7 @@ // ASM-NEXT: } (t=3) // ASM: PE(2,2): // ASM-NEXT: { -// ASM-NEXT: DATA_MOV, [WEST, RED] -> [$0] +// ASM-NEXT: DATA_MOV, [EAST, RED] -> [$0] // ASM-NEXT: } (t=2) // ASM-NEXT: { // ASM-NEXT: ICMP_EQ, [EAST, RED], [#32] -> [$0], [WEST, RED] diff --git a/test/e2e/relu/relu_kernel.mlir b/test/e2e/relu/relu_kernel.mlir index 049ffbbe..698ae7c3 100644 --- a/test/e2e/relu/relu_kernel.mlir +++ b/test/e2e/relu/relu_kernel.mlir @@ -99,5 +99,5 @@ // ASM: PE(2,1): // ASM-NEXT: { -// ASM-NEXT: DATA_MOV, [SOUTH, RED] -> [$1] +// ASM-NEXT: DATA_MOV, [NORTH, RED] -> [$1] // ASM-NEXT: } (t=5) \ No newline at end of file