Skip to content

Commit

Permalink
show transport time in PHIR (#147)
Browse files Browse the repository at this point in the history
* show transport time in PHIR

* consolidate machine objects, add warning for gate types not assigned a transport time
  • Loading branch information
Asa-Kosto-QTM authored Mar 15, 2024
1 parent c940257 commit 9cd29b4
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 38 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Install additional dependencies needed for the CLI using `pip install pytket-phi

```sh
❯ phirc -h
usage: phirc [-h] [-w WASM_FILE] [-m {H1-1,H1-2}] [-v] [--version] qasm_files [qasm_files ...]
usage: phirc [-h] [-w WASM_FILE] [-m {H1}] [-v] [--version] qasm_files [qasm_files ...]

Emulates QASM program execution via PECOS

Expand All @@ -37,8 +37,8 @@ options:
-h, --help show this help message and exit
-w WASM_FILE, --wasm-file WASM_FILE
Optional WASM file for use by the QASM programs
-m {H1-1,H1-2}, --machine {H1-1,H1-2}
Machine name, H1-1 by default
-m {H1}, --machine {H1}
Machine name, H1 by default
-v, --verbose
--version show program's version number and exit
```
Expand Down
14 changes: 7 additions & 7 deletions pytket/phir/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def main() -> None:
parser.add_argument(
"-m",
"--machine",
choices=["H1-1", "H1-2"],
default="H1-1",
help="Machine name, H1-1 by default",
choices=["H1"],
default="H1",
help="Machine name, H1 by default",
)
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument(
Expand All @@ -66,10 +66,10 @@ def main() -> None:
circuit = circuit_from_qasm(file)

match args.machine:
case "H1-1":
machine = QtmMachine.H1_1
case "H1-2":
machine = QtmMachine.H1_2
case "H1":
machine = QtmMachine.H1
case _:
raise NotImplementedError

if args.verbose:
logging.basicConfig(level=logging.INFO)
Expand Down
3 changes: 3 additions & 0 deletions pytket/phir/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ class MachineTimings:
tq_time: time for a two qubit gate
sq_time: time for a single qubit gate
qb_swap_time: time it takes to swap to qubits
meas_prep_time: time to arrange qubits for measurement
"""

tq_time: float
sq_time: float
qb_swap_time: float
meas_prep_time: float


class Machine:
Expand Down Expand Up @@ -52,6 +54,7 @@ def __init__(
self.tq_time = timings.tq_time
self.sq_time = timings.sq_time
self.qb_swap_time = timings.qb_swap_time
self.meas_prep_time = timings.meas_prep_time

for i in self.tq_options:
self.sq_options.add(i)
Expand Down
38 changes: 36 additions & 2 deletions pytket/phir/phirgen_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,39 @@ def format_and_add_primary_commands(
groups2qops(fmt_g2q, ops)


def adjust_phir_transport_time(ops: list["JsonDict"], machine: "Machine") -> None:
"""Analyze the generated phir and adjust the transport time."""
adjustment = 0.0
for op in ops:
if "qop" in op:
match op["qop"]:
case "RZ" | "R1XY":
adjustment += machine.sq_time
case "RZZ":
adjustment += machine.tq_time
case "Measure":
adjustment += machine.meas_prep_time
case _:
logger.warning(
"Gate type %s not assigned a transport duration", op["qop"]
)
if "block" in op and op["block"] == "qparallel":
first_op = op["ops"][0]["qop"]
match first_op:
case "RZ" | "R1XY":
adjustment += machine.sq_time
case "RZZ":
adjustment += machine.tq_time
case _:
logger.warning(
"Gate type %s not assigned a transport duration", first_op
)
if "mop" in op and op["mop"] == "Transport":
cost, units = op["duration"]
op["duration"] = cost + adjustment, units
adjustment = 0.0


def genphir_parallel(
inp: list[tuple["Ordering", "ShardLayer", "Cost"]], machine: "Machine"
) -> str:
Expand All @@ -289,8 +322,8 @@ def genphir_parallel(
inp: list of shards
machine: a QTM machine on which to simulate the circuit
"""
max_parallel_tq_gates = len(machine.tq_options)
max_parallel_sq_gates = len(machine.sq_options)
max_parallel_tq_gates = len(machine.tq_options) // 2
max_parallel_sq_gates = len(machine.sq_options) // 2

phir = PHIR_HEADER
phir["metadata"]["strict_parallelism"] = True
Expand Down Expand Up @@ -322,6 +355,7 @@ def genphir_parallel(
"duration": (layer_cost, "ms"),
},
)
adjust_phir_transport_time(ops, machine)

decls = get_decls(qbits, cbits)

Expand Down
24 changes: 9 additions & 15 deletions pytket/phir/qtm_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,21 @@
class QtmMachine(Enum):
"""Available machine architectures."""

H1_1 = "H1-1"
H1_2 = "H1-2"
H1 = "H1"


QTM_DEFAULT_GATESET = {OpType.Rz, OpType.PhasedX, OpType.ZZPhase}

QTM_MACHINES_MAP = {
QtmMachine.H1_1: Machine(
QtmMachine.H1: Machine(
size=20,
gateset=QTM_DEFAULT_GATESET,
tq_options={0, 2, 4, 6, 8, 10, 12, 14, 16, 18},
# need to get better timing values for below
# but will have to look them up in hqcompiler
timings=MachineTimings(tq_time=3.0, sq_time=1.0, qb_swap_time=2.0),
),
QtmMachine.H1_2: Machine(
size=12,
gateset=QTM_DEFAULT_GATESET,
tq_options={0, 2, 4, 6, 8, 10},
# need to get better timing values for below
# but will have to look them up in hqcompiler
timings=MachineTimings(tq_time=3.0, sq_time=1.0, qb_swap_time=2.0),
),
timings=MachineTimings(
tq_time=0.04,
sq_time=0.03,
qb_swap_time=0.9,
meas_prep_time=0.05,
),
)
}
4 changes: 2 additions & 2 deletions tests/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from tests.test_utils import QasmFile, get_qasm_as_circuit

if __name__ == "__main__":
machine = Machine(3, QTM_DEFAULT_GATESET, {1}, MachineTimings(3.0, 1.0, 2.0))
machine = Machine(3, QTM_DEFAULT_GATESET, {1}, MachineTimings(3.0, 1.0, 2.0, 2.0))
# force machine options for this test
# machines normally don't like odd numbers of qubits
machine.sq_options = {0, 1, 2}

h11 = QTM_MACHINES_MAP[QtmMachine.H1_1]
h11 = QTM_MACHINES_MAP[QtmMachine.H1]

circuit = get_qasm_as_circuit(QasmFile.classical_hazards)
sharder = Sharder(circuit)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_pytket_to_phir_h1_1_all(self, test_file: QasmFile) -> None:
"""Standard case."""
circuit = get_qasm_as_circuit(test_file)

assert pytket_to_phir(circuit, QtmMachine.H1_1)
assert pytket_to_phir(circuit, QtmMachine.H1)

def test_pytket_classical_only(self) -> None:
c = Circuit(1)
Expand Down Expand Up @@ -85,4 +85,4 @@ def test_qasm_to_phir(self) -> None:
measure q[1]->cr[0];
"""

assert qasm_to_phir(qasm, QtmMachine.H1_1)
assert qasm_to_phir(qasm, QtmMachine.H1)
6 changes: 3 additions & 3 deletions tests/test_placement.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
)
from pytket.phir.qtm_machine import QTM_DEFAULT_GATESET

m = Machine(4, QTM_DEFAULT_GATESET, {1}, MachineTimings(10, 2, 2))
m2 = Machine(6, QTM_DEFAULT_GATESET, {1, 3}, MachineTimings(10, 2, 2))
m3 = Machine(8, QTM_DEFAULT_GATESET, {0, 6}, MachineTimings(10, 2, 2))
m = Machine(4, QTM_DEFAULT_GATESET, {1}, MachineTimings(10, 2, 2, 1))
m2 = Machine(6, QTM_DEFAULT_GATESET, {1, 3}, MachineTimings(10, 2, 2, 1))
m3 = Machine(8, QTM_DEFAULT_GATESET, {0, 6}, MachineTimings(10, 2, 2, 1))


def test_placement_check() -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_rebaser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
class TestRebaser:
def test_rebaser_happy_path_arc1a(self) -> None:
circ = get_qasm_as_circuit(QasmFile.baby)
rebased: Circuit = rebase_to_qtm_machine(circ, QtmMachine.H1_1)
rebased: Circuit = rebase_to_qtm_machine(circ, QtmMachine.H1)

logger.info(rebased)
for command in rebased.get_commands():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_qasm_as_circuit(qasm_file: QasmFile) -> "Circuit":

def get_phir_json(qasmfile: QasmFile, *, rebase: bool) -> "JsonDict":
"""Get the QASM file for the specified circuit."""
qtm_machine = QtmMachine.H1_1
qtm_machine = QtmMachine.H1
circuit = get_qasm_as_circuit(qasmfile)
if rebase:
circuit = rebase_to_qtm_machine(circuit, qtm_machine)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_wasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_qasm_to_phir_with_wasm(self) -> None:

wasm_uid = hashlib.sha256(base64.b64encode(wasm_bytes)).hexdigest()

phir_str = qasm_to_phir(qasm, QtmMachine.H1_1, wasm_bytes=wasm_bytes)
phir_str = qasm_to_phir(qasm, QtmMachine.H1, wasm_bytes=wasm_bytes)
phir = json.loads(phir_str)

expected_metadata = {"ff_object": (f"WASM module uid: {wasm_uid}")}
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_pytket_with_wasm(self) -> None:

c.add_wasm_to_reg("add_one", w, [c0], [c0], condition=c1[0])

phir_str = pytket_to_phir(c, QtmMachine.H1_1)
phir_str = pytket_to_phir(c, QtmMachine.H1)
finally:
Path.unlink(Path(wasm_file.name))

Expand Down

0 comments on commit 9cd29b4

Please sign in to comment.