diff --git a/pytket/phir/routing.py b/pytket/phir/routing.py index eaabb46..f2d94f1 100644 --- a/pytket/phir/routing.py +++ b/pytket/phir/routing.py @@ -9,17 +9,26 @@ from __future__ import annotations +class TransportError(Exception): + def __init__(self, a: list[int], b: list[int]): + super().__init__(f"Traps different sizes: {len(a)} vs. {len(b)}") + + +class PermutationError(Exception): + def __init__(self, lst: list[int]): + super().__init__(f"List {lst} is not a permutation of range({len(lst)})") + + def inverse(lst: list[int]) -> list[int]: """Inverse of a permutation list. If a[i] = x, then inverse(a)[x] = i.""" inv = [-1] * len(lst) - for (i, elem) in enumerate(lst): - if not 0 <= elem < len(lst): - raise ValueError(f"List contains element not in range: {elem}") - if inv[elem] != -1: - raise ValueError(f"List contains duplicate elements: {lst}") + for i, elem in enumerate(lst): + if not 0 <= elem < len(lst) or inv[elem] != -1: + raise PermutationError(lst) inv[elem] = i return inv + def transport_cost(init: list[int], goal: list[int], swap_cost: float) -> float: """Cost of transport from init to goal. @@ -27,8 +36,8 @@ def transport_cost(init: list[int], goal: list[int], swap_cost: float) -> float: Transposition Sort, which is the maximum distance that any qubit travels. """ if len(init) != len(goal): - raise ValueError( - f"init and goal lists have different lengths: {len(init)} vs. {len(goal)}" - ) - n_swaps = max(abs(g - i) for (i, g) in zip(inverse(init), inverse(goal))) + raise TransportError(init, goal) + n_swaps = max( + abs(g - i) for (i, g) in zip(inverse(init), inverse(goal)) # noqa: B905 + ) return n_swaps * swap_cost