Skip to content

Commit

Permalink
Make machine argument optional (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
Asa-Kosto-QTM authored Oct 26, 2023
1 parent 62da8f1 commit 97c539b
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 29 deletions.
13 changes: 6 additions & 7 deletions pytket/phir/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,18 @@ def pytket_to_phir(
circuit = rebase_to_qtm_machine(circuit, qtm_machine.value)
machine = QTM_MACHINES_MAP.get(qtm_machine)
else:
msg = "Machine parameter is currently required"
raise NotImplementedError(msg)
machine = None

logger.debug("Sharding input circuit...")
sharder = Sharder(circuit)
shards = sharder.shard()

logger.debug("Performing placement and routing...")
if machine:
placed = place_and_route(machine, shards)
else:
msg = "no machine found"
raise ValueError(msg)
# Only print message if a machine object is passed
# Otherwise, placment and routing are functionally skipped
# The function is called, but the output is just filled with 0s
logger.debug("Performing placement and routing...")
placed = place_and_route(shards, machine)

phir_json = genphir(placed)

Expand Down
44 changes: 28 additions & 16 deletions pytket/phir/place_and_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,43 @@


def place_and_route(
machine: Machine,
shards: list[Shard],
machine: Machine | None = None,
) -> list[tuple[Ordering, Layer, Cost]]:
"""Get all the routing info needed for PHIR generation."""
shard_set = set(shards)
circuit_rep, shard_layers = parse_shards_naive(shard_set)
initial_order = list(range(machine.size))
if machine:
initial_order = list(range(machine.size))
layer_num = 0
orders: list[Ordering] = []
layer_costs: list[Cost] = []
net_cost: float = 0.0
for layer in circuit_rep:
order = optimized_place(
layer,
machine.tq_options,
machine.sq_options,
machine.size,
initial_order,
)
orders.append(order)
cost = transport_cost(initial_order, order, machine.qb_swap_time)
layer_num += 1
initial_order = order
layer_costs.append(cost)
net_cost += cost
if machine:
for layer in circuit_rep:
order = optimized_place(
layer,
machine.tq_options,
machine.sq_options,
machine.size,
initial_order,
)
orders.append(order)
cost = transport_cost(initial_order, order, machine.qb_swap_time)
layer_num += 1
initial_order = order
layer_costs.append(cost)
net_cost += cost
else:
# If no machine object specified,
# generic lists of qubits with no placement and no routing costs,
# only the shards

# If needed later, write a helper to find the number
# of qubits needed in the circuit
n = len(circuit_rep)
orders = [[]] * n
layer_costs = [0] * n

# don't need a custom error for this, "strict" parameter will throw error if needed
return list(zip(orders, shard_layers, layer_costs, strict=True))
9 changes: 7 additions & 2 deletions tests/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pytket.phir.phirgen import genphir
from pytket.phir.place_and_route import place_and_route
from pytket.phir.placement import placement_check
from pytket.phir.qtm_machine import QTM_MACHINES_MAP, QtmMachine
from pytket.phir.sharding.sharder import Sharder
from tests.sample_data import QasmFile, get_qasm_as_circuit

Expand All @@ -19,10 +20,14 @@
# force machine options for this test
# machines normally don't like odd numbers of qubits
machine.sq_options = {0, 1, 2}
circuit = get_qasm_as_circuit(QasmFile.eztest)

h11 = QTM_MACHINES_MAP[QtmMachine.H1_1]

circuit = get_qasm_as_circuit(QasmFile.classical_hazards)
sharder = Sharder(circuit)
shards = sharder.shard()
output = place_and_route(machine, shards)

output = place_and_route(shards, h11)
ez_ops_0 = [[0, 2], [1]]
ez_ops_1 = [[0], [2]]
state_0 = output[0][0]
Expand Down
7 changes: 3 additions & 4 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

from pytket.phir.api import pytket_to_phir
from pytket.phir.qtm_machine import QtmMachine

Expand All @@ -8,12 +6,13 @@

class TestApi:
def test_pytket_to_phir_no_machine(self) -> None:
"""Test case when no machine is present."""
circuit = get_qasm_as_circuit(QasmFile.baby)

with pytest.raises(NotImplementedError):
pytket_to_phir(circuit)
assert pytket_to_phir(circuit)

def test_pytket_to_phir_h1_1(self) -> None:
"""Standard case."""
circuit = get_qasm_as_circuit(QasmFile.baby)

# TODO(neal): Make this test more valuable once PHIR is actually returned
Expand Down

0 comments on commit 97c539b

Please sign in to comment.