diff --git a/lib/solvers/SameNetTraceMergeSolver/SameNetTraceMergeSolver.ts b/lib/solvers/SameNetTraceMergeSolver/SameNetTraceMergeSolver.ts new file mode 100644 index 00000000..dac21bf0 --- /dev/null +++ b/lib/solvers/SameNetTraceMergeSolver/SameNetTraceMergeSolver.ts @@ -0,0 +1,250 @@ +/** + * SameNetTraceMergeSolver + * + * A pipeline phase that merges same-net trace segments that are close together + * along the same axis (horizontal segments at nearly the same Y, or vertical + * segments at nearly the same X). This eliminates redundant parallel wires that + * are electrically identical, producing a cleaner schematic. + * + * Implements: https://github.com/tscircuit/schematic-trace-solver/issues/34 + */ + +import { BaseSolver } from "lib/solvers/BaseSolver/BaseSolver" +import type { SolvedTracePath } from "lib/solvers/SchematicTraceLinesSolver/SchematicTraceLinesSolver" +import type { GraphicsObject } from "graphics-debug" + +/** Segments closer than this (in schematic units) on the same axis get merged */ +const GAP_THRESHOLD = 0.19 + +/** Floating-point tolerance for axis-alignment checks */ +const AXIS_TOL = 1e-9 + +type Segment = { + traceIdx: number + segIdx: number // index of first point of segment in tracePath + x1: number + y1: number + x2: number + y2: number +} + +export interface SameNetTraceMergeSolverInput { + allTraces: SolvedTracePath[] +} + +/** + * Simplify a path by removing collinear intermediate points and zero-length + * segments while preserving the first and last points. + */ +function simplifyPath( + path: Array<{ x: number; y: number }>, +): Array<{ x: number; y: number }> { + if (path.length < 3) return path + + const result: Array<{ x: number; y: number }> = [path[0]!] + + for (let i = 1; i < path.length - 1; i++) { + const prev = result[result.length - 1]! + const curr = path[i]! + const next = path[i + 1]! + + // Skip zero-length segments + if ( + Math.abs(prev.x - curr.x) < AXIS_TOL && + Math.abs(prev.y - curr.y) < AXIS_TOL + ) { + continue + } + + // Skip collinear points (both horizontal or both vertical) + const prevCurrHoriz = Math.abs(prev.y - curr.y) < AXIS_TOL + const currNextHoriz = Math.abs(curr.y - next.y) < AXIS_TOL + const prevCurrVert = Math.abs(prev.x - curr.x) < AXIS_TOL + const currNextVert = Math.abs(curr.x - next.x) < AXIS_TOL + + if ((prevCurrHoriz && currNextHoriz) || (prevCurrVert && currNextVert)) { + continue + } + + result.push(curr) + } + + const last = path[path.length - 1]! + const secondLast = result[result.length - 1]! + + // Only push the last point if it's not identical to the current last + if ( + Math.abs(secondLast.x - last.x) > AXIS_TOL || + Math.abs(secondLast.y - last.y) > AXIS_TOL + ) { + result.push(last) + } + + return result +} + +export class SameNetTraceMergeSolver extends BaseSolver { + private inputTraces: SolvedTracePath[] + outputTraces: SolvedTracePath[] + + constructor(input: SameNetTraceMergeSolverInput) { + super() + this.inputTraces = input.allTraces + // Deep-clone paths so we don't mutate the upstream solver's data + this.outputTraces = input.allTraces.map((t) => ({ + ...t, + tracePath: t.tracePath.map((p) => ({ ...p })), + })) + } + + override _step() { + // Group trace indices by globalConnNetId + const netGroups = new Map() + for (let i = 0; i < this.outputTraces.length; i++) { + const netId = this.outputTraces[i]!.globalConnNetId + if (!netGroups.has(netId)) netGroups.set(netId, []) + netGroups.get(netId)!.push(i) + } + + for (const traceIndices of netGroups.values()) { + if (traceIndices.length < 2) continue + this._mergeTracesInGroup(traceIndices) + } + + // Simplify all paths after merging + for (const trace of this.outputTraces) { + trace.tracePath = simplifyPath(trace.tracePath) + } + + this.solved = true + } + + private _mergeTracesInGroup(traceIndices: number[]) { + // Collect all segments across all traces in this net group + const allSegments: Segment[] = [] + for (const traceIdx of traceIndices) { + const path = this.outputTraces[traceIdx]!.tracePath + for (let si = 0; si < path.length - 1; si++) { + allSegments.push({ + traceIdx, + segIdx: si, + x1: path[si]!.x, + y1: path[si]!.y, + x2: path[si + 1]!.x, + y2: path[si + 1]!.y, + }) + } + } + + // Check every pair of segments from different traces + for (let i = 0; i < allSegments.length; i++) { + for (let j = i + 1; j < allSegments.length; j++) { + const a = allSegments[i]! + const b = allSegments[j]! + + // Only consider cross-trace pairs + if (a.traceIdx === b.traceIdx) continue + + const aHoriz = + Math.abs(a.y1 - a.y2) < AXIS_TOL && Math.abs(a.x1 - a.x2) > AXIS_TOL + const bHoriz = + Math.abs(b.y1 - b.y2) < AXIS_TOL && Math.abs(b.x1 - b.x2) > AXIS_TOL + const aVert = + Math.abs(a.x1 - a.x2) < AXIS_TOL && Math.abs(a.y1 - a.y2) > AXIS_TOL + const bVert = + Math.abs(b.x1 - b.x2) < AXIS_TOL && Math.abs(b.y1 - b.y2) > AXIS_TOL + + if (aHoriz && bHoriz) { + this._tryMergeHorizontal(a, b) + } else if (aVert && bVert) { + this._tryMergeVertical(a, b) + } + } + } + } + + /** + * If two horizontal segments on the same net are close in Y and overlap in X, + * snap segment b's Y to segment a's Y. + */ + private _tryMergeHorizontal(a: Segment, b: Segment) { + const yDiff = Math.abs(a.y1 - b.y1) + if (yDiff < AXIS_TOL || yDiff > GAP_THRESHOLD) return + + // Check that their X ranges overlap + const aXMin = Math.min(a.x1, a.x2) + const aXMax = Math.max(a.x1, a.x2) + const bXMin = Math.min(b.x1, b.x2) + const bXMax = Math.max(b.x1, b.x2) + + const overlapStart = Math.max(aXMin, bXMin) + const overlapEnd = Math.min(aXMax, bXMax) + if (overlapEnd <= overlapStart) return + + // Snap b's segment to a's Y coordinate + const targetY = a.y1 + const path = this.outputTraces[b.traceIdx]!.tracePath + path[b.segIdx]!.y = targetY + path[b.segIdx + 1]!.y = targetY + + // Update our in-loop segment metadata so later iterations use the new coords + b.y1 = targetY + b.y2 = targetY + } + + /** + * If two vertical segments on the same net are close in X and overlap in Y, + * snap segment b's X to segment a's X. + */ + private _tryMergeVertical(a: Segment, b: Segment) { + const xDiff = Math.abs(a.x1 - b.x1) + if (xDiff < AXIS_TOL || xDiff > GAP_THRESHOLD) return + + // Check that their Y ranges overlap + const aYMin = Math.min(a.y1, a.y2) + const aYMax = Math.max(a.y1, a.y2) + const bYMin = Math.min(b.y1, b.y2) + const bYMax = Math.max(b.y1, b.y2) + + const overlapStart = Math.max(aYMin, bYMin) + const overlapEnd = Math.min(aYMax, bYMax) + if (overlapEnd <= overlapStart) return + + // Snap b's segment to a's X coordinate + const targetX = a.x1 + const path = this.outputTraces[b.traceIdx]!.tracePath + path[b.segIdx]!.x = targetX + path[b.segIdx + 1]!.x = targetX + + b.x1 = targetX + b.x2 = targetX + } + + getOutput() { + return { + traces: this.outputTraces, + } + } + + override visualize(): GraphicsObject { + return { + lines: this.outputTraces.flatMap((trace) => { + const segs = [] + for (let i = 0; i < trace.tracePath.length - 1; i++) { + segs.push({ + x1: trace.tracePath[i]!.x, + y1: trace.tracePath[i]!.y, + x2: trace.tracePath[i + 1]!.x, + y2: trace.tracePath[i + 1]!.y, + strokeColor: "blue", + points: [], + }) + } + return segs + }), + points: [], + rects: [], + circles: [], + } + } +} diff --git a/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts b/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts index c9d5a995..70d2007b 100644 --- a/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts +++ b/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts @@ -20,6 +20,7 @@ import { expandChipsToFitPins } from "./expandChipsToFitPins" import { LongDistancePairSolver } from "../LongDistancePairSolver/LongDistancePairSolver" import { MergedNetLabelObstacleSolver } from "../TraceLabelOverlapAvoidanceSolver/sub-solvers/LabelMergingSolver/LabelMergingSolver" import { TraceCleanupSolver } from "../TraceCleanupSolver/TraceCleanupSolver" +import { SameNetTraceMergeSolver } from "../SameNetTraceMergeSolver/SameNetTraceMergeSolver" type PipelineStep BaseSolver> = { solverName: string @@ -69,6 +70,7 @@ export class SchematicTracePipelineSolver extends BaseSolver { labelMergingSolver?: MergedNetLabelObstacleSolver traceLabelOverlapAvoidanceSolver?: TraceLabelOverlapAvoidanceSolver traceCleanupSolver?: TraceCleanupSolver + sameNetTraceMergeSolver?: SameNetTraceMergeSolver startTimeOfPhase: Record endTimeOfPhase: Record @@ -206,11 +208,22 @@ export class SchematicTracePipelineSolver extends BaseSolver { }, ] }), + definePipelineStep( + "sameNetTraceMergeSolver", + SameNetTraceMergeSolver, + (instance) => { + const traces = + instance.traceCleanupSolver?.getOutput().traces ?? + instance.traceLabelOverlapAvoidanceSolver!.getOutput().traces + return [{ allTraces: traces }] + }, + ), definePipelineStep( "netLabelPlacementSolver", NetLabelPlacementSolver, (instance) => { const traces = + instance.sameNetTraceMergeSolver?.getOutput().traces ?? instance.traceCleanupSolver?.getOutput().traces ?? instance.traceLabelOverlapAvoidanceSolver!.getOutput().traces diff --git a/tests/examples/__snapshots__/example31.snap.svg b/tests/examples/__snapshots__/example31.snap.svg new file mode 100644 index 00000000..4e5c7102 --- /dev/null +++ b/tests/examples/__snapshots__/example31.snap.svg @@ -0,0 +1,121 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/examples/example31.test.ts b/tests/examples/example31.test.ts new file mode 100644 index 00000000..59a1d86e --- /dev/null +++ b/tests/examples/example31.test.ts @@ -0,0 +1,274 @@ +/** + * Tests for issue #34: same-net trace segments that run parallel and close + * together on the same axis should be merged into a single line by the + * SameNetTraceMergeSolver pipeline phase. + */ +import { test, expect } from "bun:test" +import type { InputProblem } from "lib/index" +import { SchematicTracePipelineSolver } from "lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver" +import { SameNetTraceMergeSolver } from "lib/solvers/SameNetTraceMergeSolver/SameNetTraceMergeSolver" +import "tests/fixtures/matcher" + +// --------------------------------------------------------------------------- +// Unit tests for SameNetTraceMergeSolver directly +// --------------------------------------------------------------------------- + +test("SameNetTraceMergeSolver merges close horizontal segments on the same net", () => { + const traceA = { + mspPairId: "pair-a" as any, + dcConnNetId: "net0", + globalConnNetId: "net0", + pins: [] as any, + mspConnectionPairIds: [] as any, + pinIds: [] as any, + tracePath: [ + { x: 0, y: 0 }, + { x: 0, y: 1.0 }, + { x: 4, y: 1.0 }, + { x: 4, y: 0 }, + ], + } + + const traceB = { + mspPairId: "pair-b" as any, + dcConnNetId: "net0", + globalConnNetId: "net0", + pins: [] as any, + mspConnectionPairIds: [] as any, + pinIds: [] as any, + // Horizontal segment at y=1.1 — within GAP_THRESHOLD of traceA's y=1.0 + tracePath: [ + { x: 1, y: 0 }, + { x: 1, y: 1.1 }, + { x: 3, y: 1.1 }, + { x: 3, y: 0 }, + ], + } + + const solver = new SameNetTraceMergeSolver({ allTraces: [traceA, traceB] }) + solver.solve() + expect(solver.solved).toBe(true) + + const outTraces = solver.getOutput().traces + const outB = outTraces.find((t) => t.mspPairId === "pair-b")! + + // Locate the horizontal segment in outB + let mergedY: number | undefined + for (let i = 0; i < outB.tracePath.length - 1; i++) { + const p1 = outB.tracePath[i]! + const p2 = outB.tracePath[i + 1]! + if (Math.abs(p1.y - p2.y) < 1e-9 && Math.abs(p1.x - p2.x) > 0.1) { + mergedY = p1.y + break + } + } + + expect(mergedY).toBeDefined() + // Should be snapped to traceA's y=1.0 + expect(Math.abs(mergedY! - 1.0)).toBeLessThan(1e-9) +}) + +test("SameNetTraceMergeSolver merges close vertical segments on the same net", () => { + const traceA = { + mspPairId: "pair-a" as any, + dcConnNetId: "net1", + globalConnNetId: "net1", + pins: [] as any, + mspConnectionPairIds: [] as any, + pinIds: [] as any, + tracePath: [ + { x: 0, y: 0 }, + { x: 2.0, y: 0 }, + { x: 2.0, y: 4 }, + { x: 0, y: 4 }, + ], + } + + const traceB = { + mspPairId: "pair-b" as any, + dcConnNetId: "net1", + globalConnNetId: "net1", + pins: [] as any, + mspConnectionPairIds: [] as any, + pinIds: [] as any, + // Vertical segment at x=2.1 — within GAP_THRESHOLD of traceA's x=2.0 + tracePath: [ + { x: 0, y: 1 }, + { x: 2.1, y: 1 }, + { x: 2.1, y: 3 }, + { x: 0, y: 3 }, + ], + } + + const solver = new SameNetTraceMergeSolver({ allTraces: [traceA, traceB] }) + solver.solve() + expect(solver.solved).toBe(true) + + const outTraces = solver.getOutput().traces + const outB = outTraces.find((t) => t.mspPairId === "pair-b")! + + // Locate the vertical segment in outB + let mergedX: number | undefined + for (let i = 0; i < outB.tracePath.length - 1; i++) { + const p1 = outB.tracePath[i]! + const p2 = outB.tracePath[i + 1]! + if (Math.abs(p1.x - p2.x) < 1e-9 && Math.abs(p1.y - p2.y) > 0.1) { + mergedX = p1.x + break + } + } + + expect(mergedX).toBeDefined() + // Should be snapped to traceA's x=2.0 + expect(Math.abs(mergedX! - 2.0)).toBeLessThan(1e-9) +}) + +test("SameNetTraceMergeSolver does NOT merge traces on different nets", () => { + const traceA = { + mspPairId: "pair-a" as any, + dcConnNetId: "net0", + globalConnNetId: "net0", + pins: [] as any, + mspConnectionPairIds: [] as any, + pinIds: [] as any, + tracePath: [ + { x: 0, y: 0 }, + { x: 0, y: 1.0 }, + { x: 4, y: 1.0 }, + { x: 4, y: 0 }, + ], + } + + const traceB = { + mspPairId: "pair-b" as any, + dcConnNetId: "net999", + globalConnNetId: "net999", // different net! + pins: [] as any, + mspConnectionPairIds: [] as any, + pinIds: [] as any, + tracePath: [ + { x: 1, y: 0 }, + { x: 1, y: 1.05 }, + { x: 3, y: 1.05 }, + { x: 3, y: 0 }, + ], + } + + const solver = new SameNetTraceMergeSolver({ allTraces: [traceA, traceB] }) + solver.solve() + + const outTraces = solver.getOutput().traces + const outB = outTraces.find((t) => t.mspPairId === "pair-b")! + + // traceB should remain at y=1.05, not snapped to y=1.0 + const horizSeg = outB.tracePath.find( + (_, i) => + i < outB.tracePath.length - 1 && + Math.abs(outB.tracePath[i]!.y - outB.tracePath[i + 1]!.y) < 1e-9 && + Math.abs(outB.tracePath[i]!.x - outB.tracePath[i + 1]!.x) > 0.1, + )! + expect(Math.abs(horizSeg.y - 1.05)).toBeLessThan(1e-9) +}) + +test("SameNetTraceMergeSolver does NOT merge segments farther than GAP_THRESHOLD", () => { + const traceA = { + mspPairId: "pair-a" as any, + dcConnNetId: "net0", + globalConnNetId: "net0", + pins: [] as any, + mspConnectionPairIds: [] as any, + pinIds: [] as any, + tracePath: [ + { x: 0, y: 1.0 }, + { x: 4, y: 1.0 }, + ], + } + + const traceB = { + mspPairId: "pair-b" as any, + dcConnNetId: "net0", + globalConnNetId: "net0", + pins: [] as any, + mspConnectionPairIds: [] as any, + pinIds: [] as any, + // y=1.5 is more than GAP_THRESHOLD=0.19 away — should NOT be merged + tracePath: [ + { x: 1, y: 1.5 }, + { x: 3, y: 1.5 }, + ], + } + + const solver = new SameNetTraceMergeSolver({ allTraces: [traceA, traceB] }) + solver.solve() + + const outB = solver.getOutput().traces.find((t) => t.mspPairId === "pair-b")! + // y should remain 1.5 + expect(Math.abs(outB.tracePath[0]!.y - 1.5)).toBeLessThan(1e-9) +}) + +// --------------------------------------------------------------------------- +// Integration test: full pipeline includes the merge phase +// --------------------------------------------------------------------------- + +test("example31 - pipeline snapshot includes SameNetTraceMergeSolver", () => { + const inputProblem: InputProblem = { + chips: [ + { + chipId: "U1", + center: { x: 0, y: 0 }, + width: 1.0, + height: 1.2, + pins: [ + { pinId: "U1.1", x: -0.8, y: 0.3 }, + { pinId: "U1.2", x: -0.8, y: -0.3 }, + ], + }, + { + chipId: "R1", + center: { x: -3.5, y: 0.3 }, + width: 0.6, + height: 0.4, + pins: [{ pinId: "R1.1", x: -2.9, y: 0.3 }], + }, + { + chipId: "R2", + center: { x: -3.5, y: -0.28 }, + width: 0.6, + height: 0.4, + pins: [{ pinId: "R2.1", x: -2.9, y: -0.28 }], + }, + ], + netConnections: [ + { + netId: "VDD", + pinIds: ["U1.1", "R1.1", "U1.2", "R2.1"], + netLabelWidth: 0.3, + }, + ], + directConnections: [], + availableNetLabelOrientations: { + VDD: ["y+"], + }, + } + + const solver = new SchematicTracePipelineSolver(inputProblem) + solver.solve() + + // The pipeline should have run the merge phase + expect(solver.sameNetTraceMergeSolver).toBeDefined() + expect(solver.sameNetTraceMergeSolver!.solved).toBe(true) + + // All output segments must remain orthogonal + const mergedTraces = solver.sameNetTraceMergeSolver!.getOutput().traces + for (const trace of mergedTraces) { + for (let i = 0; i < trace.tracePath.length - 1; i++) { + const p1 = trace.tracePath[i]! + const p2 = trace.tracePath[i + 1]! + const isHoriz = Math.abs(p1.y - p2.y) < 1e-9 + const isVert = Math.abs(p1.x - p2.x) < 1e-9 + expect(isHoriz || isVert).toBe(true) + } + } + + expect(solver).toMatchSolverSnapshot(import.meta.path) +})