Skip to content

Commit

Permalink
Merge pull request #71 from CQCL/parallel-ops-testing
Browse files Browse the repository at this point in the history
More specific testing for parallel ops formatting
  • Loading branch information
Asa-Kosto-QTM authored Dec 15, 2023
2 parents d1b7f65 + 88432ee commit 63830bc
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ include "hqslib1.inc";
creg c[4];
qreg q[4];

Rxxyyzz(0.5, 0.5, 0.5) q[0],q[1];
Rxxyyzz(0.5, 0.5, 0.5) q[2],q[3];

Rxxyyzz(1.0, 1.0, 1.0) q[0],q[1];
Rxxyyzz(0.5, 0.5, 0.5) q[2],q[3];

Expand Down
10 changes: 10 additions & 0 deletions tests/data/qasm/tk2_same_angle.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
OPENQASM 2.0;
include "hqslib1.inc";

creg c[4];
qreg q[4];

Rxxyyzz(0.5, 0.5, 0.5) q[0],q[1];
Rxxyyzz(0.5, 0.5, 0.5) q[2],q[3];

measure q->c;
2 changes: 1 addition & 1 deletion tests/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pytket.phir.placement import placement_check
from pytket.phir.qtm_machine import QTM_MACHINES_MAP, QtmMachine
from pytket.phir.sharding.sharder import Sharder
from tests.sample_data import QasmFile, get_qasm_as_circuit
from tests.test_utils import QasmFile, get_qasm_as_circuit

if __name__ == "__main__":
machine = Machine(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pytket.phir.api import pytket_to_phir, qasm_to_phir
from pytket.phir.qtm_machine import QtmMachine

from .sample_data import QasmFile, get_qasm_as_circuit
from .test_utils import QasmFile, get_qasm_as_circuit

logger = logging.getLogger(__name__)

Expand Down
144 changes: 54 additions & 90 deletions tests/test_parallel_tk2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,103 +8,67 @@

# mypy: disable-error-code="misc"

import json
import logging
from typing import Any

from pytket.phir.phirgen_parallel import genphir_parallel
from pytket.phir.place_and_route import place_and_route
from pytket.phir.qtm_machine import QTM_MACHINES_MAP, QtmMachine
from pytket.phir.sharding.sharder import Sharder

from .sample_data import QasmFile, get_qasm_as_circuit
from .test_utils import QasmFile, get_phir_json

logger = logging.getLogger(__name__)


def get_phir_json_no_rebase(qasmfile: QasmFile) -> dict[str, Any]:
"""Get the QASM file for the specified circuit."""
qtm_machine = QtmMachine.H1_1
circuit = get_qasm_as_circuit(qasmfile)
machine = QTM_MACHINES_MAP.get(qtm_machine)
assert machine
sharder = Sharder(circuit)
shards = sharder.shard()
placed = place_and_route(shards, machine)
return json.loads(genphir_parallel(placed, machine)) # type: ignore[no-any-return]
def test_pll_tk2_same_angle() -> None:
"""Make sure the parallelization is correct for the tk2_same_angle circuit."""
phir = get_phir_json(QasmFile.tk2_same_angle, rebase=False)

# Check that the op is properly formatted
op = phir["ops"][3]
measure = phir["ops"][5]
assert op["qop"] == "R2XXYYZZ"

# Check that the args are properly formatted
assert len(op["args"]) == 2
assert len(op["args"][0]) == len(op["args"][1]) == 2
q01_fst = (["q", 0] in op["args"][0]) and (["q", 1] in op["args"][0])
q01_snd = (["q", 0] in op["args"][1]) and (["q", 1] in op["args"][1])
q23_fst = (["q", 2] in op["args"][0]) and (["q", 3] in op["args"][0])
q23_snd = (["q", 2] in op["args"][1]) and (["q", 3] in op["args"][1])
assert (q01_fst and q23_snd) != (q23_fst and q01_snd)

# Check that the measure op is properly formatted
measure_args = measure["args"]
measure_returns = measure["returns"]
assert len(measure_args) == len(measure_returns) == 4
assert measure_args.index(["q", 0]) == measure_returns.index(["c", 0])
assert measure_args.index(["q", 1]) == measure_returns.index(["c", 1])
assert measure_args.index(["q", 2]) == measure_returns.index(["c", 2])
assert measure_args.index(["q", 3]) == measure_returns.index(["c", 3])


def test_pll_tk2_diff_angles() -> None:
"""Make sure the parallelization is correct for the tk2_diff_angles circuit."""
phir = get_phir_json(QasmFile.tk2_diff_angles, rebase=False)

def test_pll_tk2() -> None:
"""Make sure the parallelization is happening properly for the tk2 circuit."""
# the first pair of gates have the same angle arguments
# to make sure that the qubit arguments get added to the
# same list and the comment is generated with the angle
# the second pair of gates have differing angle arguments
# to make sure the qops get added to a parallel block
actual = get_phir_json_no_rebase(QasmFile.tk2)
# DO NOT modify the expected json
# it is the correct output for the tk2.qasm file
# if you change the tk2.qasm file, you just re-generate the correct
# phir json and replace the expected or the test will fail
expected: dict[str, Any] = {
"ops": [
{"data": "qvar_define", "data_type": "qubits", "variable": "q", "size": 4},
{"data": "cvar_define", "data_type": "u32", "variable": "c", "size": 4},
{"//": "Parallel TK2(0.159155, 0.159155, 0.159155)"},
{
"qop": "R2XXYYZZ",
"angles": [
[0.15915494309189535, 0.15915494309189535, 0.15915494309189535],
"pi",
],
"args": [[["q", 0], ["q", 1]], [["q", 2], ["q", 3]]],
},
{"mop": "Transport", "duration": [0.0, "ms"]},
{"//": "Parallel R2XXYYZZ"},
{
"block": "qparallel",
"ops": [
{
"qop": "R2XXYYZZ",
"angles": [
[
0.3183098861837907,
0.3183098861837907,
0.3183098861837907,
],
"pi",
],
"args": [[["q", 0], ["q", 1]]],
},
{
"qop": "R2XXYYZZ",
"angles": [
[
0.15915494309189535,
0.15915494309189535,
0.15915494309189535,
],
"pi",
],
"args": [[["q", 2], ["q", 3]]],
},
],
},
{"mop": "Transport", "duration": [0.0, "ms"]},
{
"qop": "Measure",
"args": [["q", 0], ["q", 1], ["q", 2], ["q", 3]],
"returns": [["c", 0], ["c", 1], ["c", 2], ["c", 3]],
},
{"mop": "Transport", "duration": [0.0, "ms"]},
],
}
# Check that the qparallel block is properly formatted
block = phir["ops"][3]
measure = phir["ops"][5]
assert block["block"] == "qparallel"
assert len(block["ops"]) == 2

assert actual["ops"][6]["block"] == "qparallel"
for op in expected["ops"][6]["ops"]:
assert op in actual["ops"][6]["ops"]
# Check that the individual ops are properly formatted
qop0, qop1 = block["ops"]
assert qop0["qop"] == qop1["qop"] == "R2XXYYZZ"
assert len(qop0["args"][0]) == len(qop1["args"][0]) == 2
# Ensure the args for each op are invalid combinations, irrespecive of order
q01_fst = (["q", 0] in qop0["args"][0]) and (["q", 1] in qop0["args"][0])
q01_snd = (["q", 0] in qop1["args"][0]) and (["q", 1] in qop1["args"][0])
q23_fst = (["q", 2] in qop0["args"][0]) and (["q", 3] in qop0["args"][0])
q23_snd = (["q", 2] in qop1["args"][0]) and (["q", 3] in qop1["args"][0])
assert (q01_fst and q23_snd) != (q23_fst and q01_snd)

act_meas_op = actual["ops"][8]
assert act_meas_op["qop"] == "Measure"
assert sorted(act_meas_op["args"]) == expected["ops"][8]["args"]
assert sorted(act_meas_op["returns"]) == expected["ops"][8]["returns"]
# Check that the measure op is properly foramtted
measure_args = measure["args"]
measure_returns = measure["returns"]
assert len(measure_args) == len(measure_returns) == 4
assert measure_args.index(["q", 0]) == measure_returns.index(["c", 0])
assert measure_args.index(["q", 1]) == measure_returns.index(["c", 1])
assert measure_args.index(["q", 2]) == measure_returns.index(["c", 2])
assert measure_args.index(["q", 3]) == measure_returns.index(["c", 3])
114 changes: 40 additions & 74 deletions tests/test_parallelization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,84 +8,50 @@

# mypy: disable-error-code="misc"

import json
import logging
from typing import Any

from pytket.phir.phirgen_parallel import genphir_parallel
from pytket.phir.place_and_route import place_and_route
from pytket.phir.qtm_machine import QTM_MACHINES_MAP, QtmMachine
from pytket.phir.rebasing.rebaser import rebase_to_qtm_machine
from pytket.phir.sharding.sharder import Sharder

from .sample_data import QasmFile, get_qasm_as_circuit
from .test_utils import QasmFile, get_phir_json

logger = logging.getLogger(__name__)


def get_phir_json(qasmfile: QasmFile) -> dict[str, Any]:
"""Get the QASM file for the specified circuit."""
qtm_machine = QtmMachine.H1_1
circuit = get_qasm_as_circuit(qasmfile)
circuit = rebase_to_qtm_machine(circuit, qtm_machine.value, 0)
machine = QTM_MACHINES_MAP.get(qtm_machine)
assert machine
sharder = Sharder(circuit)
shards = sharder.shard()
placed = place_and_route(shards, machine)
return json.loads(genphir_parallel(placed, machine)) # type: ignore[no-any-return]


def test_bv_n10() -> None:
def test_parallelization() -> None:
"""Make sure the parallelization is happening properly for the test circuit."""
actual = get_phir_json(QasmFile.parallelization_test)
expected: dict[str, Any] = {
"ops": [
{"data": "qvar_define", "data_type": "qubits", "variable": "q", "size": 4},
{"data": "cvar_define", "data_type": "u32", "variable": "c", "size": 4},
{"//": "Parallel Rz(1)"},
{
"qop": "RZ",
"angles": [[1.0], "pi"],
"args": [["q", 0], ["q", 1], ["q", 2], ["q", 3]],
},
{"//": "Parallel PhasedX(0.5, 0.5)"},
{
"qop": "R1XY",
"angles": [[0.5, 0.5], "pi"],
"args": [["q", 0], ["q", 1], ["q", 2], ["q", 3]],
},
{"//": "Parallel RZZ"},
{
"block": "qparallel",
"ops": [
{
"qop": "RZZ",
"angles": [[0.125], "pi"],
"args": [[["q", 0], ["q", 1]]],
},
{
"qop": "RZZ",
"angles": [[1.0], "pi"],
"args": [[["q", 2], ["q", 3]]],
},
],
},
{"mop": "Transport", "duration": [0.0, "ms"]},
{
"qop": "Measure",
"args": [["q", 0], ["q", 1], ["q", 2], ["q", 3]],
"returns": [["c", 0], ["c", 1], ["c", 2], ["c", 3]],
},
{"mop": "Transport", "duration": [0.0, "ms"]},
],
}

assert actual["ops"][7]["block"] == "qparallel"
for op in expected["ops"][7]["ops"]:
assert op in actual["ops"][7]["ops"]

act_meas_op = actual["ops"][9]
assert act_meas_op["qop"] == "Measure"
assert sorted(act_meas_op["args"]) == expected["ops"][9]["args"]
assert sorted(act_meas_op["returns"]) == expected["ops"][9]["returns"]
phir = get_phir_json(QasmFile.parallelization_test, rebase=True)

# Make sure The parallel RZ and R1XY gates have the correct arguments
parallel_rz1 = phir["ops"][3]
assert parallel_rz1["qop"] == "RZ"
qubits = [["q", 0], ["q", 1], ["q", 2], ["q", 3]]
for qubit in qubits:
assert qubit in parallel_rz1["args"]
parallel_phasedx = phir["ops"][5]
assert parallel_phasedx["qop"] == "R1XY"
for qubit in qubits:
assert qubit in parallel_phasedx["args"]

# Make sure the parallel block is properly formatted
block = phir["ops"][7]
assert block["block"] == "qparallel"
assert len(block["ops"]) == 2
qop0 = block["ops"][0]
qop1 = block["ops"][1]
assert qop0["qop"] == qop1["qop"] == "RZZ"

# Make sure the ops within the parallel block have the correct arguments
assert len(qop0["args"][0]) == len(qop1["args"][0]) == 2
q01_fst = (["q", 0] in qop0["args"][0]) and (["q", 1] in qop0["args"][0])
q01_snd = (["q", 0] in qop1["args"][0]) and (["q", 1] in qop1["args"][0])
q23_fst = (["q", 2] in qop0["args"][0]) and (["q", 3] in qop0["args"][0])
q23_snd = (["q", 2] in qop1["args"][0]) and (["q", 3] in qop1["args"][0])
assert (q01_fst and q23_snd) != (q23_fst and q01_snd)

# Make sure the measure op is properly formatted
measure = phir["ops"][9]
measure_args = measure["args"]
measure_returns = measure["returns"]
assert len(measure_args) == len(measure_returns) == 4
assert measure_args.index(["q", 0]) == measure_returns.index(["c", 0])
assert measure_args.index(["q", 1]) == measure_returns.index(["c", 1])
assert measure_args.index(["q", 2]) == measure_returns.index(["c", 2])
assert measure_args.index(["q", 3]) == measure_returns.index(["c", 3])
2 changes: 1 addition & 1 deletion tests/test_rebaser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pytket.circuit import Circuit, OpType
from pytket.phir.rebasing.rebaser import rebase_to_qtm_machine

from .sample_data import QasmFile, get_qasm_as_circuit
from .test_utils import QasmFile, get_qasm_as_circuit

EXPECTED_GATES = [
OpType.Measure,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pytket.circuit import Conditional, Op, OpType
from pytket.phir.sharding.sharder import Sharder

from .sample_data import QasmFile, get_qasm_as_circuit
from .test_utils import QasmFile, get_qasm_as_circuit


class TestSharder:
Expand Down
25 changes: 23 additions & 2 deletions tests/sample_data.py → tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@
#
##############################################################################

import json
from enum import Enum, auto
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from pytket.phir.phirgen_parallel import genphir_parallel
from pytket.phir.place_and_route import place_and_route
from pytket.phir.qtm_machine import QTM_MACHINES_MAP, QtmMachine
from pytket.phir.rebasing.rebaser import rebase_to_qtm_machine
from pytket.phir.sharding.sharder import Sharder
from pytket.qasm.qasm import circuit_from_qasm

if TYPE_CHECKING:
Expand All @@ -32,7 +38,8 @@ class QasmFile(Enum):
oned_brickwork_circuit_n20 = auto()
qv20_0 = auto()
parallelization_test = auto()
tk2 = auto()
tk2_same_angle = auto()
tk2_diff_angles = auto()


def get_qasm_as_circuit(qasm_file: QasmFile) -> "Circuit":
Expand All @@ -46,3 +53,17 @@ def get_qasm_as_circuit(qasm_file: QasmFile) -> "Circuit":
"""
this_dir = Path(Path(__file__).resolve()).parent
return circuit_from_qasm(f"{this_dir}/data/qasm/{qasm_file.name}.qasm")


def get_phir_json(qasmfile: QasmFile, *, rebase: bool) -> dict[str, Any]: # type: ignore[misc]
"""Get the QASM file for the specified circuit."""
qtm_machine = QtmMachine.H1_1
circuit = get_qasm_as_circuit(qasmfile)
if rebase:
circuit = rebase_to_qtm_machine(circuit, qtm_machine.value, 0)
machine = QTM_MACHINES_MAP.get(qtm_machine)
assert machine
sharder = Sharder(circuit)
shards = sharder.shard()
placed = place_and_route(shards, machine)
return json.loads(genphir_parallel(placed, machine)) # type: ignore[misc, no-any-return]

0 comments on commit 63830bc

Please sign in to comment.