From 73325707c044d1dd59d70d0afbe1d02ab7f1a1f5 Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Tue, 17 Oct 2023 14:37:10 -0600 Subject: [PATCH] routing integration --- pytket/phir/api.py | 16 +++++++-- pytket/phir/{machine_class.py => machine.py} | 0 pytket/phir/place_and_route.py | 13 ++------ pytket/phir/qtm_machine.py | 15 +++++++++ tests/e2e_test.py | 6 ++-- tests/sample_data.py | 34 ++++++++++---------- tests/test_api.py | 14 ++++---- tests/test_placement.py | 2 +- tests/test_rebaser.py | 4 +-- tests/test_sharder.py | 16 ++++----- 10 files changed, 70 insertions(+), 50 deletions(-) rename pytket/phir/{machine_class.py => machine.py} (100%) diff --git a/pytket/phir/api.py b/pytket/phir/api.py index f4bbb6e..c60ca14 100644 --- a/pytket/phir/api.py +++ b/pytket/phir/api.py @@ -1,7 +1,9 @@ import logging from pytket.circuit import Circuit -from pytket.phir.qtm_machine import QtmMachine +from pytket.phir.machine import Machine +from pytket.phir.place_and_route import place_and_route +from pytket.phir.qtm_machine import QTM_MACHINES_MAP, QtmMachine from pytket.phir.rebasing.rebaser import rebase_to_qtm_machine from pytket.phir.sharding.sharder import Sharder @@ -24,14 +26,24 @@ def pytket_to_phir( PHIR JSON as a str """ logger.info(f"Starting phir conversion process for circuit {circuit}") + machine: Machine | None = None if qtm_machine: logger.info(f"Rebasing to machine {qtm_machine}") circuit = rebase_to_qtm_machine(circuit, qtm_machine.value) + machine = QTM_MACHINES_MAP.get(qtm_machine) + else: + msg = "Machine parameter is currently required" + raise NotImplementedError(msg) + logger.debug("Sharding input circuit...") sharder = Sharder(circuit) shards = sharder.shard() - phir_output = str(shards) # Just returning fake string for now + logger.debug("Performing placement and routing...") + placed = place_and_route(machine, shards) # type: ignore [misc] + + phir_output = str(placed) # type: ignore [misc] + # TODO: Pass shards[] into placement, routing, etc # TODO: Convert to PHIR JSON spec and return logger.info("Output: %s", phir_output) diff --git a/pytket/phir/machine_class.py b/pytket/phir/machine.py similarity index 100% rename from pytket/phir/machine_class.py rename to pytket/phir/machine.py diff --git a/pytket/phir/place_and_route.py b/pytket/phir/place_and_route.py index 434e4e7..2642bdd 100644 --- a/pytket/phir/place_and_route.py +++ b/pytket/phir/place_and_route.py @@ -1,22 +1,15 @@ import typing -from pytket.phir.machine_class import Machine +from pytket.phir.machine import Machine from pytket.phir.placement import optimized_place from pytket.phir.routing import transport_cost -from pytket.phir.sharding.sharder import Sharder +from pytket.phir.sharding.shard import Shard from pytket.phir.sharding.shards2ops import parse_shards_naive -from tests.sample_data import ( # type: ignore [attr-defined] - Circuit, - get_qasm_as_circuit, -) @typing.no_type_check -def place_and_route(machine: Machine, qasm: Circuit): +def place_and_route(machine: Machine, shards: list[Shard]): """Get all the routing info needed for PHIR generation.""" - circuit = get_qasm_as_circuit(qasm) - sharder = Sharder(circuit) - shards = sharder.shard() shard_set = set(shards) circuit_rep, shard_layers = parse_shards_naive(shard_set) initial_order = list(range(machine.size)) diff --git a/pytket/phir/qtm_machine.py b/pytket/phir/qtm_machine.py index e8ebd74..e6cf83e 100644 --- a/pytket/phir/qtm_machine.py +++ b/pytket/phir/qtm_machine.py @@ -1,5 +1,7 @@ from enum import Enum +from pytket.phir.machine import Machine + class QtmMachine(Enum): """Available machine architectures.""" @@ -7,3 +9,16 @@ class QtmMachine(Enum): H1_1 = "H1-1" H1_2 = "H1-2" H2_1 = "H2-1" + + +QTM_MACHINES_MAP = { + QtmMachine.H1_1: Machine( + size=20, + tq_options={0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, + # need to get better timing values for below + # but will have to look them up in hqcompiler + tq_time=3.0, + sq_time=1.0, + qb_swap_time=2.0, + ), +} diff --git a/tests/e2e_test.py b/tests/e2e_test.py index b90b36a..ef53833 100644 --- a/tests/e2e_test.py +++ b/tests/e2e_test.py @@ -1,7 +1,7 @@ -from pytket.phir.machine_class import Machine +from pytket.phir.machine import Machine from pytket.phir.place_and_route import place_and_route from pytket.phir.placement import placement_check -from tests.sample_data import QasmFiles +from tests.sample_data import QasmFile if __name__ == "__main__": machine = Machine( @@ -16,7 +16,7 @@ machine.sq_options = {0, 1, 2} # The type: ignores in this file are because mypy doesn't like the return type of place and route, # noqa: E501 # list[triple(list[int], list[Shard], float)] - output = place_and_route(machine, QasmFiles.eztest) # type: ignore [misc] + output = place_and_route(machine, QasmFile.eztest) # type: ignore [misc] # print(output) ez_ops_0 = [[0, 2], [1]] ez_ops_1 = [[0], [2]] diff --git a/tests/sample_data.py b/tests/sample_data.py index 2b803b3..2909e0f 100644 --- a/tests/sample_data.py +++ b/tests/sample_data.py @@ -1,28 +1,28 @@ import os -from enum import Enum +from enum import Enum, auto from pytket.circuit import Circuit from pytket.qasm.qasm import circuit_from_qasm -class QasmFiles(Enum): - simple = 1 - cond_1 = 2 - bv_n10 = 3 - baby = 4 - baby_with_rollup = 5 - simple_cond = 6 - cond_classical = 7 - barrier_complex = 8 - classical_hazards = 9 - big_gate = 10 - n10_test = 11 - qv20_0 = 13 - oned_brickwork_circuit_n20 = 14 - eztest = 15 +class QasmFile(Enum): + simple = auto() + cond_1 = auto() + bv_n10 = auto() + baby = auto() + baby_with_rollup = auto() + simple_cond = auto() + cond_classical = auto() + barrier_complex = auto() + classical_hazards = auto() + big_gate = auto() + n10_test = auto() + qv20_0 = auto() + oned_brickwork_circuit_n20 = auto() + eztest = auto() -def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit: +def get_qasm_as_circuit(qasm_file: QasmFile) -> Circuit: """Utility function to convert a QASM file to Circuit. Args: diff --git a/tests/test_api.py b/tests/test_api.py index 460b09d..f39f552 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,20 +1,20 @@ +import pytest + from pytket.phir.api import pytket_to_phir from pytket.phir.qtm_machine import QtmMachine -from .sample_data import QasmFiles, get_qasm_as_circuit +from .sample_data import QasmFile, get_qasm_as_circuit class TestApi: def test_pytket_to_phir_no_machine(self) -> None: - circuit = get_qasm_as_circuit(QasmFiles.baby) - - phir = pytket_to_phir(circuit) + circuit = get_qasm_as_circuit(QasmFile.baby) - # TODO: Make this test more valuable once PHIR is actually returned - assert len(phir) > 0 + with pytest.raises(NotImplementedError): + pytket_to_phir(circuit) def test_pytket_to_phir_h1_1(self) -> None: - circuit = get_qasm_as_circuit(QasmFiles.baby) + circuit = get_qasm_as_circuit(QasmFile.baby) phir = pytket_to_phir(circuit, QtmMachine.H1_1) diff --git a/tests/test_placement.py b/tests/test_placement.py index 329ea3c..f671975 100644 --- a/tests/test_placement.py +++ b/tests/test_placement.py @@ -3,7 +3,7 @@ import pytest -from pytket.phir.machine_class import Machine +from pytket.phir.machine import Machine from pytket.phir.placement import ( GateOpportunitiesError, InvalidParallelOpsError, diff --git a/tests/test_rebaser.py b/tests/test_rebaser.py index 0d4f01b..ceeca49 100644 --- a/tests/test_rebaser.py +++ b/tests/test_rebaser.py @@ -3,7 +3,7 @@ from pytket.circuit import Circuit, OpType from pytket.phir.rebasing.rebaser import rebase_to_qtm_machine -from .sample_data import QasmFiles, get_qasm_as_circuit +from .sample_data import QasmFile, get_qasm_as_circuit EXPECTED_GATES = [ OpType.Measure, @@ -18,7 +18,7 @@ class TestRebaser: def test_rebaser_happy_path_arc1a(self) -> None: - circ = get_qasm_as_circuit(QasmFiles.baby) + circ = get_qasm_as_circuit(QasmFile.baby) rebased: Circuit = rebase_to_qtm_machine(circ, "H1-1") logger.info(rebased) diff --git a/tests/test_sharder.py b/tests/test_sharder.py index 2fd6715..b96e6e7 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -3,12 +3,12 @@ from pytket.circuit import Conditional, Op, OpType from pytket.phir.sharding.sharder import Sharder -from .sample_data import QasmFiles, get_qasm_as_circuit +from .sample_data import QasmFile, get_qasm_as_circuit class TestSharder: def test_shard_hashing(self) -> None: - circuit = get_qasm_as_circuit(QasmFiles.baby) + circuit = get_qasm_as_circuit(QasmFile.baby) sharder = Sharder(circuit) shards = sharder.shard() @@ -38,7 +38,7 @@ def test_should_op_create_shard(self) -> None: assert not Sharder.should_op_create_shard(op) def test_with_baby_circuit(self) -> None: - circuit = get_qasm_as_circuit(QasmFiles.baby) + circuit = get_qasm_as_circuit(QasmFile.baby) sharder = Sharder(circuit) shards = sharder.shard() @@ -61,7 +61,7 @@ def test_with_baby_circuit(self) -> None: assert shards[2].depends_upon == {shards[0].ID} def test_rollup_behavior(self) -> None: - circuit = get_qasm_as_circuit(QasmFiles.baby_with_rollup) + circuit = get_qasm_as_circuit(QasmFile.baby_with_rollup) sharder = Sharder(circuit) shards = sharder.shard() @@ -92,7 +92,7 @@ def test_rollup_behavior(self) -> None: assert shards[4].depends_upon == {shards[0].ID, shards[2].ID} def test_simple_conditional(self) -> None: - circuit = get_qasm_as_circuit(QasmFiles.simple_cond) + circuit = get_qasm_as_circuit(QasmFile.simple_cond) sharder = Sharder(circuit) shards = sharder.shard() @@ -139,7 +139,7 @@ def test_simple_conditional(self) -> None: assert s2_sub_cmds[0].qubits == [circuit.qubits[0]] def test_complex_barriers(self) -> None: - circuit = get_qasm_as_circuit(QasmFiles.barrier_complex) + circuit = get_qasm_as_circuit(QasmFile.barrier_complex) sharder = Sharder(circuit) shards = sharder.shard() @@ -224,7 +224,7 @@ def test_complex_barriers(self) -> None: assert shards[6].depends_upon == {shards[5].ID, shards[4].ID} def test_classical_hazards(self) -> None: - circuit = get_qasm_as_circuit(QasmFiles.classical_hazards) + circuit = get_qasm_as_circuit(QasmFile.classical_hazards) sharder = Sharder(circuit) shards = sharder.shard() @@ -272,7 +272,7 @@ def test_classical_hazards(self) -> None: assert shards[4].depends_upon == {shards[1].ID, shards[0].ID, shards[3].ID} def test_with_big_gate(self) -> None: - circuit = get_qasm_as_circuit(QasmFiles.big_gate) + circuit = get_qasm_as_circuit(QasmFile.big_gate) sharder = Sharder(circuit) shards = sharder.shard()