From f90884aec4d4ab96aaa65146e9c626c439e458fa Mon Sep 17 00:00:00 2001 From: Kartik Singhal Date: Thu, 28 Mar 2024 10:47:19 -0500 Subject: [PATCH 1/6] fix(phirgen): handle both WASM and non-WASM conditional comments I had unknowingly removed the Conditional case in #152, it was needed for WASM conditional handling --- pytket/phir/phirgen.py | 11 +- requirements.txt | 2 +- tests/test_phirgen.py | 1 + tests/test_wasm.py | 222 +++++++++++++++++++++-------------------- 4 files changed, 125 insertions(+), 111 deletions(-) diff --git a/pytket/phir/phirgen.py b/pytket/phir/phirgen.py index de76eaa..992485c 100644 --- a/pytket/phir/phirgen.py +++ b/pytket/phir/phirgen.py @@ -375,9 +375,18 @@ def make_comment_text(cmd: tk.Command, op: tk.Op) -> str: """Converts a command + op to the PHIR comment spec.""" comment = str(cmd) match op: + case tk.Conditional(): + conditional_text = str(cmd) + cleaned = ( + conditional_text[: conditional_text.find("THEN") + 5] + if isinstance(op.op, tk.WASMOp) + else "" + ) + comment = f"{cleaned}{make_comment_text(cmd, op.op)}" + case tk.WASMOp(): args, returns = extract_wasm_args_and_returns(cmd, op) - comment = f"WASM function={op.func_name} args={args} returns={returns}" + comment = f"WASM_function={op.func_name} args={args} returns={returns};" case tk.BarrierOp(): comment = op.data + " " + str(cmd.args[0]) + ";" if op.data else str(cmd) diff --git a/requirements.txt b/requirements.txt index dbb78e3..a77ba00 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -build==1.1.1 +build==1.2.1 mypy==1.9.0 networkx==2.8.8 phir==0.3.2 diff --git a/tests/test_phirgen.py b/tests/test_phirgen.py index 6002a07..fd362d3 100644 --- a/tests/test_phirgen.py +++ b/tests/test_phirgen.py @@ -183,6 +183,7 @@ def test_conditional_measure() -> None: c.Measure(0, 0) c.Measure(1, 1, condition_bits=[0], condition_value=1) phir = json.loads(pytket_to_phir(c)) + assert phir["ops"][-2] == {"//": "IF ([c[0]] == 1) THEN Measure q[1] --> c[1];"} assert phir["ops"][-1] == { "block": "if", "condition": {"cop": "==", "args": [["c", 0], 1]}, diff --git a/tests/test_wasm.py b/tests/test_wasm.py index b826e14..b6d5685 100644 --- a/tests/test_wasm.py +++ b/tests/test_wasm.py @@ -27,112 +27,116 @@ logger = logging.getLogger(__name__) -class TestWASM: - def test_qasm_to_phir_with_wasm(self) -> None: - """Test the qasm string entrypoint works with WASM.""" - qasm = """ - OPENQASM 2.0; - include "qelib1.inc"; - - qreg q[2]; - h q; - ZZ q[1],q[0]; - creg cr[3]; - creg cs[3]; - creg co[3]; - measure q[0]->cr[0]; - measure q[1]->cr[1]; - - cs = cr; - co = add(cr, cs); - """ - - wasm_bytes = get_wat_as_wasm_bytes(WatFile.add) - - wasm_uid = hashlib.sha256(base64.b64encode(wasm_bytes)).hexdigest() - - phir_str = qasm_to_phir(qasm, QtmMachine.H1, wasm_bytes=wasm_bytes) - phir = json.loads(phir_str) - - expected_metadata = {"ff_object": (f"WASM module uid: {wasm_uid}")} - - assert phir["ops"][21] == { - "metadata": expected_metadata, - "cop": "ffcall", - "function": "add", - "args": ["cr", "cs"], - "returns": ["co"], - } - - @pytest.mark.order("first") - def test_pytket_with_wasm(self) -> None: - wasm_bytes = get_wat_as_wasm_bytes(WatFile.testfile) - phir_str: str - try: - wasm_file = NamedTemporaryFile(suffix=".wasm", delete=False) - wasm_file.write(wasm_bytes) - wasm_file.flush() - wasm_file.close() - - w = WasmFileHandler(wasm_file.name) - - c = Circuit(6, 6) - c0 = c.add_c_register("c0", 3) - c1 = c.add_c_register("c1", 4) - c2 = c.add_c_register("c2", 5) - - c.add_wasm_to_reg("multi", w, [c0, c1], [c2]) - c.add_wasm_to_reg("add_one", w, [c2], [c2]) - c.add_wasm_to_reg("no_return", w, [c2], []) - c.add_wasm_to_reg("no_parameters", w, [], [c2]) - - c.add_wasm_to_reg("add_one", w, [c0], [c0], condition=c1[0]) - - phir_str = pytket_to_phir(c, QtmMachine.H1) - finally: - Path.unlink(Path(wasm_file.name)) - - phir = json.loads(phir_str) - - expected_metadata = {"ff_object": (f"WASM module uid: {w!s}")} - - assert phir["ops"][4] == { - "metadata": expected_metadata, - "cop": "ffcall", - "function": "multi", - "args": ["c0", "c1"], - "returns": ["c2"], - } - assert phir["ops"][7] == { - "metadata": expected_metadata, - "cop": "ffcall", - "function": "add_one", - "args": ["c2"], - "returns": ["c2"], - } - assert phir["ops"][9] == { - "block": "if", - "condition": {"cop": "==", "args": [["c1", 0], 1]}, - "true_branch": [ - { - "metadata": expected_metadata, - "cop": "ffcall", - "returns": ["c0"], - "function": "add_one", - "args": ["c1", "c0"], - } - ], - } - assert phir["ops"][12] == { - "metadata": expected_metadata, - "cop": "ffcall", - "function": "no_return", - "args": ["c2"], - } - assert phir["ops"][15] == { - "metadata": expected_metadata, - "cop": "ffcall", - "function": "no_parameters", - "args": [], - "returns": ["c2"], - } +def test_qasm_to_phir_with_wasm() -> None: + """Test the qasm string entrypoint works with WASM.""" + qasm = """ + OPENQASM 2.0; + include "qelib1.inc"; + + qreg q[2]; + h q; + ZZ q[1],q[0]; + creg cr[3]; + creg cs[3]; + creg co[3]; + measure q[0]->cr[0]; + measure q[1]->cr[1]; + + cs = cr; + co = add(cr, cs); + """ + + wasm_bytes = get_wat_as_wasm_bytes(WatFile.add) + + wasm_uid = hashlib.sha256(base64.b64encode(wasm_bytes)).hexdigest() + + phir_str = qasm_to_phir(qasm, QtmMachine.H1, wasm_bytes=wasm_bytes) + phir = json.loads(phir_str) + + expected_metadata = {"ff_object": (f"WASM module uid: {wasm_uid}")} + + assert phir["ops"][21] == { + "metadata": expected_metadata, + "cop": "ffcall", + "function": "add", + "args": ["cr", "cs"], + "returns": ["co"], + } + + +@pytest.mark.order("first") +def test_pytket_with_wasm() -> None: + """Test whether pytket works with WASM.""" + wasm_bytes = get_wat_as_wasm_bytes(WatFile.testfile) + phir_str: str + try: + wasm_file = NamedTemporaryFile(suffix=".wasm", delete=False) + wasm_file.write(wasm_bytes) + wasm_file.flush() + wasm_file.close() + + w = WasmFileHandler(wasm_file.name) + + c = Circuit(6, 6) + c0 = c.add_c_register("c0", 3) + c1 = c.add_c_register("c1", 4) + c2 = c.add_c_register("c2", 5) + + c.add_wasm_to_reg("multi", w, [c0, c1], [c2]) + c.add_wasm_to_reg("add_one", w, [c2], [c2]) + c.add_wasm_to_reg("no_return", w, [c2], []) + c.add_wasm_to_reg("no_parameters", w, [], [c2]) + + c.add_wasm_to_reg("add_one", w, [c0], [c0], condition=c1[0]) + + phir_str = pytket_to_phir(c, QtmMachine.H1) + finally: + Path.unlink(Path(wasm_file.name)) + + phir = json.loads(phir_str) + + expected_metadata = {"ff_object": (f"WASM module uid: {w!s}")} + + assert phir["ops"][4] == { + "metadata": expected_metadata, + "cop": "ffcall", + "function": "multi", + "args": ["c0", "c1"], + "returns": ["c2"], + } + assert phir["ops"][7] == { + "metadata": expected_metadata, + "cop": "ffcall", + "function": "add_one", + "args": ["c2"], + "returns": ["c2"], + } + assert phir["ops"][8] == { + "//": "IF ([c1[0]] == 1) THEN WASM_function=add_one args=['c1', 'c0'] returns=['c0'];" # noqa: E501 + } + assert phir["ops"][9] == { + "block": "if", + "condition": {"cop": "==", "args": [["c1", 0], 1]}, + "true_branch": [ + { + "metadata": expected_metadata, + "cop": "ffcall", + "returns": ["c0"], + "function": "add_one", + "args": ["c1", "c0"], + } + ], + } + assert phir["ops"][12] == { + "metadata": expected_metadata, + "cop": "ffcall", + "function": "no_return", + "args": ["c2"], + } + assert phir["ops"][15] == { + "metadata": expected_metadata, + "cop": "ffcall", + "function": "no_parameters", + "args": [], + "returns": ["c2"], + } From be5ce7eeaee9bef0b13ad5360df667df8835f373 Mon Sep 17 00:00:00 2001 From: Kartik Singhal Date: Thu, 28 Mar 2024 11:12:59 -0500 Subject: [PATCH 2/6] test: fixup wasm conditional tests --- pytket/phir/phirgen.py | 2 +- tests/test_wasm.py | 43 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/pytket/phir/phirgen.py b/pytket/phir/phirgen.py index 992485c..f42737b 100644 --- a/pytket/phir/phirgen.py +++ b/pytket/phir/phirgen.py @@ -386,7 +386,7 @@ def make_comment_text(cmd: tk.Command, op: tk.Op) -> str: case tk.WASMOp(): args, returns = extract_wasm_args_and_returns(cmd, op) - comment = f"WASM_function={op.func_name} args={args} returns={returns};" + comment = f"WASM_function='{op.func_name}' args={args} returns={returns};" case tk.BarrierOp(): comment = op.data + " " + str(cmd.args[0]) + ";" if op.data else str(cmd) diff --git a/tests/test_wasm.py b/tests/test_wasm.py index b6d5685..87cf0a8 100644 --- a/tests/test_wasm.py +++ b/tests/test_wasm.py @@ -17,7 +17,7 @@ import pytest -from pytket.circuit import Circuit +from pytket.circuit import Circuit, Qubit from pytket.phir.api import pytket_to_phir, qasm_to_phir from pytket.phir.qtm_machine import QtmMachine from pytket.wasm.wasm import WasmFileHandler @@ -112,7 +112,7 @@ def test_pytket_with_wasm() -> None: "returns": ["c2"], } assert phir["ops"][8] == { - "//": "IF ([c1[0]] == 1) THEN WASM_function=add_one args=['c1', 'c0'] returns=['c0'];" # noqa: E501 + "//": "IF ([c1[0]] == 1) THEN WASM_function='add_one' args=['c0'] returns=['c0'];" # noqa: E501 } assert phir["ops"][9] == { "block": "if", @@ -123,7 +123,7 @@ def test_pytket_with_wasm() -> None: "cop": "ffcall", "returns": ["c0"], "function": "add_one", - "args": ["c1", "c0"], + "args": ["c0"], } ], } @@ -140,3 +140,40 @@ def test_pytket_with_wasm() -> None: "args": [], "returns": ["c2"], } + + +def test_conditional_wasm() -> None: + """From https://github.com/CQCL/pytket-phir/issues/156 .""" + wasm_bytes = get_wat_as_wasm_bytes(WatFile.testfile) + try: + wasm_file = NamedTemporaryFile(suffix=".wasm", delete=False) + wasm_file.write(wasm_bytes) + wasm_file.flush() + wasm_file.close() + + w = WasmFileHandler(wasm_file.name) + + c = Circuit(1) + areg = c.add_c_register("a", 2) + breg = c.add_c_register("b", 1) + c.H(0) + c.Measure(Qubit(0), breg[0]) + c.add_wasm( + funcname="add_one", + filehandler=w, + list_i=[1], + list_o=[1], + args=[areg[0], areg[1]], + args_wasm=[0], + condition_bits=[breg[0]], + condition_value=1, + ) + finally: + Path.unlink(Path(wasm_file.name)) + + phir = json.loads(pytket_to_phir(c)) + + assert phir["ops"][-2] == { + "//": "IF ([b[0]] == 1) THEN WASM_function='add_one' args=['a'] returns=['a'];" + } + assert phir["ops"][-1]["true_branch"][0]["args"] == ["a"] From 0b333110ecf7f6e2d0544585e1d3258f46c12b70 Mon Sep 17 00:00:00 2001 From: Kartik Singhal Date: Thu, 28 Mar 2024 11:41:06 -0500 Subject: [PATCH 3/6] fix(wasm): Eliminate conditional bits from the arg list --- pytket/phir/phirgen.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytket/phir/phirgen.py b/pytket/phir/phirgen.py index f42737b..0b5ab0f 100644 --- a/pytket/phir/phirgen.py +++ b/pytket/phir/phirgen.py @@ -357,11 +357,13 @@ def extract_wasm_args_and_returns( ) -> tuple[list[str], list[str]]: """Extract the wasm args and return values as whole register names.""" # This slice removes the extra `_w` cregs (wires) that are not part of the - # circuit, and the output args which are appended after the input args + # circuit and the output args, which are appended after the input args slice_index = op.num_w + sum(op.output_widths) only_args = command.args[:-slice_index] + # Eliminate conditional bits from the front of the args + input_args = only_args[len(only_args) - op.n_inputs :] return ( - dedupe_bits_to_registers(only_args), + dedupe_bits_to_registers(input_args), dedupe_bits_to_registers(command.bits), ) @@ -404,8 +406,6 @@ def make_comment_text(cmd: tk.Command, op: tk.Op) -> str: def get_decls(qbits: set["Qubit"], cbits: set[tkBit]) -> list[dict[str, str | int]]: """Format the qvar and cvar define PHIR elements.""" - # TODO(kartik): this may not always be accurate - # https://github.com/CQCL/pytket-phir/issues/24 qvar_dim: dict[str, int] = {} for qbit in qbits: qvar_dim.setdefault(qbit.reg_name, 0) From d33d74994d5d94940c69a3f6db08cdcca87f8cb2 Mon Sep 17 00:00:00 2001 From: Kartik Singhal Date: Mon, 1 Apr 2024 12:56:17 -0500 Subject: [PATCH 4/6] chore(deps): update ruff version --- .pre-commit-config.yaml | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 86b1dfc..4ceab40 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,7 @@ repos: - black==23.10.1 - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.4 + rev: v0.3.5 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/requirements.txt b/requirements.txt index a77ba00..462022e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ pydata_sphinx_theme==0.15.2 pytest==8.1.1 pytest-order==1.2.0 pytket==1.26.0 -ruff==0.3.4 +ruff==0.3.5 setuptools_scm==8.0.4 sphinx==7.2.6 wasmtime==19.0.0 From b9371fc18ccb7f8492ca85eb8a4e236b248d1d02 Mon Sep 17 00:00:00 2001 From: Kartik Singhal Date: Tue, 2 Apr 2024 06:57:54 -0500 Subject: [PATCH 5/6] chore(deps): update typos in pre-commit --- .pre-commit-config.yaml | 2 +- pyproject.toml | 3 +++ pytket/phir/phirgen.py | 4 ++-- tests/test_parallelization.py | 24 ++++++++++++------------ 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4ceab40..d218eff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: - id: debug-statements - repo: https://github.com/crate-ci/typos - rev: v1.19.0 + rev: v1.20.1 hooks: - id: typos diff --git a/pyproject.toml b/pyproject.toml index 7a08cbb..3ad5b96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,3 +65,6 @@ version_scheme = "python-simplified-semver" [tool.refurb] python_version = "3.10" + +[tool.typos] +default.extend-words = { lst = "lst" } diff --git a/pytket/phir/phirgen.py b/pytket/phir/phirgen.py index 0b5ab0f..7dfa958 100644 --- a/pytket/phir/phirgen.py +++ b/pytket/phir/phirgen.py @@ -230,11 +230,11 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None: # See https://github.com/CQCL/tket/blob/0ec603986821d994caa3a0fb9c4640e5bc6c0a24/pytket/pytket/qasm/qasm.py#L419-L459 match op.data[0:5]: case "sleep": - dur = op.data.removeprefix("sleep(").removesuffix(")") + duration = op.data.removeprefix("sleep(").removesuffix(")") out = { "mop": "Idle", "args": [arg_to_bit(qbit) for qbit in cmd.qubits], - "duration": (float(dur), "s"), + "duration": (float(duration), "s"), } case "order" | "group": raise NotImplementedError(op.data) diff --git a/tests/test_parallelization.py b/tests/test_parallelization.py index 3fe1143..13deb44 100644 --- a/tests/test_parallelization.py +++ b/tests/test_parallelization.py @@ -64,18 +64,18 @@ def test_parallel_subcommand_relative_ordering() -> None: phir = get_phir_json(QasmFile.rxrz, rebase=True) # make sure it is ordered like the qasm file ops = phir["ops"] - frst_sc = ops[3] - scnd_sc = ops[5] - thrd_sc = ops[7] - frth_sc = ops[9] - assert frst_sc["qop"] == "RZ" - assert frst_sc["angles"] == [[0.5], "pi"] - assert scnd_sc["qop"] == "R1XY" - assert scnd_sc["angles"] == [[3.5, 0.0], "pi"] - assert thrd_sc["qop"] == "R1XY" - assert thrd_sc["angles"] == [[0.5, 0.0], "pi"] - assert frth_sc["qop"] == "RZ" - assert frth_sc["angles"] == [[3.5], "pi"] + sc1 = ops[3] + sc2 = ops[5] + sc3 = ops[7] + sc4 = ops[9] + assert sc1["qop"] == "RZ" + assert sc1["angles"] == [[0.5], "pi"] + assert sc2["qop"] == "R1XY" + assert sc2["angles"] == [[3.5, 0.0], "pi"] + assert sc3["qop"] == "R1XY" + assert sc3["angles"] == [[0.5, 0.0], "pi"] + assert sc4["qop"] == "RZ" + assert sc4["angles"] == [[3.5], "pi"] def test_single_qubit_circuit_with_parallel() -> None: From ab17133bf922cb29448863cdac9f18dedff72ef5 Mon Sep 17 00:00:00 2001 From: Kartik Singhal Date: Tue, 2 Apr 2024 07:56:51 -0500 Subject: [PATCH 6/6] test: add conditional classical test --- tests/test_phirgen.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_phirgen.py b/tests/test_phirgen.py index fd362d3..904c134 100644 --- a/tests/test_phirgen.py +++ b/tests/test_phirgen.py @@ -92,6 +92,18 @@ def test_conditional_barrier() -> None: } +def test_simple_cond_classical() -> None: + """Ensure conditional classical operation are correctly generated.""" + circ = get_qasm_as_circuit(QasmFile.simple_cond) + phir = json.loads(pytket_to_phir(circ)) + assert phir["ops"][-6] == {"//": "IF ([c[0]] == 1) THEN SetBits(1) z[0];"} + assert phir["ops"][-5] == { + "block": "if", + "condition": {"cop": "==", "args": [["c", 0], 1]}, + "true_branch": [{"cop": "=", "returns": [["z", 0]], "args": [1]}], + } + + def test_nested_bitwise_op() -> None: """From https://github.com/CQCL/pytket-phir/issues/133 .""" circ = Circuit(4)