From bac76d692afe921169786b709dd677fa41af3c8f Mon Sep 17 00:00:00 2001 From: Asa-Kosto-QTM <108833721+Asa-Kosto-QTM@users.noreply.github.com> Date: Wed, 10 Apr 2024 09:41:46 -0600 Subject: [PATCH] Support for `ExplicitPredicate`, `ExplicitModifier, `MultiBitOp` (#162) * explicit predicates for common logical ops and multi-bit ops supported + unit tests * feat: add support for ExplicitPredicate/Modifier & MultiBitOp * fix(phirgen): handle ops inside a conditional correctly * feat(phirgen): add support for irregular multibit ops --------- Co-authored-by: Kartik Singhal Co-authored-by: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- pytket/phir/phirgen.py | 223 +++++++++++++++++------- pytket/phir/sharding/sharder.py | 2 + tests/test_api.py | 23 --- tests/test_phirgen.py | 292 ++++++++++++++++++++++++++++---- 5 files changed, 430 insertions(+), 112 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 60c8cd98..edbfe907 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: - id: debug-statements - repo: https://github.com/crate-ci/typos - rev: v1.20.4 + rev: v1.20.7 hooks: - id: typos diff --git a/pytket/phir/phirgen.py b/pytket/phir/phirgen.py index 7dfa9585..0a9941a3 100644 --- a/pytket/phir/phirgen.py +++ b/pytket/phir/phirgen.py @@ -10,6 +10,7 @@ import json import logging +from copy import deepcopy from importlib.metadata import version from typing import TYPE_CHECKING, Any, TypeAlias @@ -218,6 +219,139 @@ def convert_gate(op: tk.Op, cmd: tk.Command) -> JsonDict | None: return qop +def cop_from_op_name(op_name: str) -> str: + """Get PHIR classical op name from pytket op name.""" + match op_name: + case "AND": + cop = "&" + case "OR": + cop = "|" + case "XOR": + cop = "^" + case "NOT": + cop = "~" + case name: + raise NotImplementedError(name) + return cop + + +def convert_classicalevalop(op: tk.ClassicalEvalOp, cmd: tk.Command) -> JsonDict | None: + """Return PHIR dict for a pytket ClassicalEvalOp.""" + # Exclude conditional bits from args + args = cmd.args[cmd.op.width :] if isinstance(cmd.op, tk.Conditional) else cmd.args + out: JsonDict | None = None + match op: + case tk.CopyBitsOp(): + if len(cmd.bits) != len(args) // 2: + msg = "LHS and RHS lengths mismatch for CopyBits" + raise TypeError(msg) + out = assign_cop( + [arg_to_bit(bit) for bit in cmd.bits], + [arg_to_bit(args[i]) for i in range(len(args) // 2)], + ) + case tk.SetBitsOp(): + if len(cmd.bits) != len(op.values): + logger.error("LHS and RHS lengths mismatch for classical assignment") + raise ValueError + out = assign_cop( + [arg_to_bit(bit) for bit in cmd.bits], list(map(int, op.values)) + ) + case tk.RangePredicateOp(): # where the condition is a range + cond: JsonDict + match op.lower, op.upper: + case l, u if l == u: + cond = { + "cop": "==", + "args": [args[0].reg_name, u], + } + case l, u if u == UINTMAX: + cond = { + "cop": ">=", + "args": [args[0].reg_name, l], + } + case 0, u: + cond = { + "cop": "<=", + "args": [args[0].reg_name, u], + } + out = { + "block": "if", + "condition": cond, + "true_branch": [assign_cop([arg_to_bit(cmd.bits[0])], [1])], + } + case tk.MultiBitOp(): + if len(args) % len(cmd.bits) != 0: + msg = "Input bit- and output bit lengths mismatch." + raise TypeError(msg) + + cop = cop_from_op_name(op.basic_op.get_name()) + is_explicit = op.basic_op.type == tk.OpType.ExplicitPredicate + + # determine number of register operands involved in the operation + operand_count = len(args) // len(cmd.bits) - is_explicit + + iters = [iter(args)] * (operand_count + is_explicit) + iter2 = deepcopy(iters) + + # Columns of expressions, e.g., + # AND (*2) a[0], b[0], c[0] + # , a[1], b[1], c[1] + # would be [(a[0], a[1]), (b[0], b[1]), (c[0], c[1])] + # and AND (*2) a[0], a[1], b[0] + # , b[1], c[0], c[1] + # would be [(a[0], b[1]), (a[1], c[0]), (b[0], c[1])] + cols = zip(*zip(*iters, strict=True), strict=True) + + if all( + all(col[0].reg_name == bit.reg_name for bit in col) for col in cols + ): # expression can be applied register-wise + out = assign_cop( + [cmd.bits[0].reg_name], + [ + { + "cop": cop, + "args": [arg.reg_name for arg in args[:operand_count]], + } + ], + ) + else: # apply a sequence of bit-wise ops + exps = zip(*iter2, strict=True) + out = { + "block": "sequence", + "ops": [ + assign_cop( + [arg_to_bit(bit)], + [ + { + "cop": cop, + "args": [ + arg_to_bit(arg) for arg in exp[:operand_count] + ], + } + ], + ) + for bit, exp in zip(cmd.bits, exps, strict=True) + ], + } + case _: + raise NotImplementedError(op) + + return out + + +def multi_bit_condition(args: "list[UnitID]", value: int) -> JsonDict: + """Construct bitwise condition.""" + return { + "cop": "&", + "args": [ + {"cop": "==", "args": [arg_to_bit(arg), bval]} + for (arg, bval) in zip( + args[::-1], map(int, f"{value:0{len(args)}b}"), strict=True + ) + ], + } + + def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None: """Return PHIR dict given a tket op and its arguments.""" if op.is_gate(): @@ -225,6 +359,15 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None: out: JsonDict | None = None match op: # non-quantum op + case tk.Conditional(): + out = { + "block": "if", + "condition": {"cop": "==", "args": [arg_to_bit(cmd.args[0]), op.value]} + if op.width == 1 + else multi_bit_condition(cmd.args[: op.width], op.value), + "true_branch": [convert_subcmd(op.op, cmd)], + } + case tk.BarrierOp(): if op.data: # See https://github.com/CQCL/tket/blob/0ec603986821d994caa3a0fb9c4640e5bc6c0a24/pytket/pytket/qasm/qasm.py#L419-L459 @@ -246,45 +389,6 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None: "args": [arg_to_bit(qbit) for qbit in cmd.qubits], } - case tk.Conditional(): # where the condition is equality check - out = { - "block": "if", - "condition": { - "cop": "==", - "args": [ - arg_to_bit(cmd.args[0]) - if op.width == 1 - else cmd.args[0].reg_name, - op.value, - ], - }, - "true_branch": [convert_subcmd(op.op, cmd)], - } - - case tk.RangePredicateOp(): # where the condition is a range - cond: JsonDict - match op.lower, op.upper: - case l, u if l == u: - cond = { - "cop": "==", - "args": [cmd.args[0].reg_name, u], - } - case l, u if u == UINTMAX: - cond = { - "cop": ">=", - "args": [cmd.args[0].reg_name, l], - } - case 0, u: - cond = { - "cop": "<=", - "args": [cmd.args[0].reg_name, u], - } - out = { - "block": "if", - "condition": cond, - "true_branch": [assign_cop([arg_to_bit(cmd.bits[0])], [1])], - } - case tk.ClassicalExpBox(): exp = op.get_exp() match exp: @@ -295,29 +399,34 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None: rhs = [classical_op(exp)] out = assign_cop([cmd.bits[0].reg_name], rhs) - case tk.SetBitsOp(): - if len(cmd.bits) != len(op.values): - logger.error("LHS and RHS lengths mismatch for classical assignment") - raise ValueError - out = assign_cop( - [arg_to_bit(bit) for bit in cmd.bits], list(map(int, op.values)) - ) - - case tk.CopyBitsOp(): - if len(cmd.bits) != len(cmd.args) // 2: - logger.warning("LHS and RHS lengths mismatch for CopyBits") - out = assign_cop( - [arg_to_bit(bit) for bit in cmd.bits], - [arg_to_bit(cmd.args[i]) for i in range(len(cmd.args) // 2)], - ) + case tk.ClassicalEvalOp(): + return convert_classicalevalop(op, cmd) case tk.WASMOp(): return create_wasm_op(cmd, op) case _: - # TODO(kartik): NYI - # https://github.com/CQCL/pytket-phir/issues/25 - raise NotImplementedError + # Exclude conditional bits from args + args = ( + cmd.args[cmd.op.width :] + if isinstance(cmd.op, tk.Conditional) + else cmd.args + ) + match op.type: + case tk.OpType.ExplicitPredicate | tk.OpType.ExplicitModifier: + # exclude output bit when not modifying in place + args = args[:-1] if op.type == tk.OpType.ExplicitPredicate else args + out = assign_cop( + [arg_to_bit(cmd.bits[0])], + [ + { + "cop": cop_from_op_name(op.get_name()), + "args": [arg_to_bit(arg) for arg in args], + } + ], + ) + case _: + raise NotImplementedError(op.type) return out diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index 6e31004c..5f872f61 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -23,6 +23,8 @@ OpType.ClassicalExpBox, # some classical operations are rolled up into a box OpType.RangePredicate, OpType.ExplicitPredicate, + OpType.ExplicitModifier, + OpType.MultiBit, OpType.CopyBits, OpType.WASM, ] diff --git a/tests/test_api.py b/tests/test_api.py index 84e65af6..f0372ce2 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,12 +8,10 @@ # mypy: disable-error-code="misc" -import json import logging import pytest -from pytket.circuit import Bit, Circuit from pytket.phir.api import pytket_to_phir, qasm_to_phir from pytket.phir.qtm_machine import QtmMachine @@ -50,27 +48,6 @@ def test_pytket_to_phir_h1_1_all(self, test_file: QasmFile) -> None: assert pytket_to_phir(circuit, QtmMachine.H1) - def test_pytket_classical_only(self) -> None: - c = Circuit(1) - a = c.add_c_register("a", 2) - b = c.add_c_register("b", 3) - - c.add_c_copyreg(a, b) - c.add_c_copybits([Bit("b", 2), Bit("a", 1)], [Bit("a", 0), Bit("b", 0)]) - - phir = json.loads(pytket_to_phir(c)) - - assert phir["ops"][3] == { - "cop": "=", - "returns": [["b", 0], ["b", 1]], - "args": [["a", 0], ["a", 1]], - } - assert phir["ops"][5] == { - "cop": "=", - "returns": [["a", 0], ["b", 0]], - "args": [["b", 2], ["a", 1]], - } - def test_qasm_to_phir(self) -> None: """Test the qasm string entrypoint works.""" qasm = """ diff --git a/tests/test_phirgen.py b/tests/test_phirgen.py index 904c1345..9ec06697 100644 --- a/tests/test_phirgen.py +++ b/tests/test_phirgen.py @@ -10,13 +10,98 @@ import json -from pytket.circuit import Circuit +from pytket.circuit import Bit, Circuit from pytket.phir.api import pytket_to_phir from pytket.qasm.qasm import circuit_from_qasm_str +from pytket.unit_id import BitRegister from .test_utils import QasmFile, get_qasm_as_circuit +def test_multiple_sleep() -> None: + """Ensure multiple sleep ops get converted correctly.""" + qasm = """ + OPENQASM 2.0; + include "hqslib1_dev.inc"; + + qreg q[2]; + + sleep(1) q[0]; + sleep(2) q[1]; + """ + circ = circuit_from_qasm_str(qasm) + phir = json.loads(pytket_to_phir(circ)) + assert phir["ops"][2] == {"mop": "Idle", "args": [["q", 0]], "duration": [1.0, "s"]} + assert phir["ops"][4] == {"mop": "Idle", "args": [["q", 1]], "duration": [2.0, "s"]} + + +def test_simple_cond_classical() -> None: + """Ensure conditional classical operation are correctly generated.""" + circ = get_qasm_as_circuit(QasmFile.simple_cond) + phir = json.loads(pytket_to_phir(circ)) + assert phir["ops"][-6] == {"//": "IF ([c[0]] == 1) THEN SetBits(1) z[0];"} + assert phir["ops"][-5] == { + "block": "if", + "condition": {"cop": "==", "args": [["c", 0], 1]}, + "true_branch": [{"cop": "=", "returns": [["z", 0]], "args": [1]}], + } + + +def test_pytket_classical_only() -> None: + """From https://github.com/CQCL/pytket-phir/issues/61 .""" + c = Circuit(1) + a = c.add_c_register("a", 2) + b = c.add_c_register("b", 3) + + c.add_c_copyreg(a, b) + c.add_c_copybits([Bit("b", 2), Bit("a", 1)], [Bit("a", 0), Bit("b", 0)]) + c.add_c_copybits( + [Bit("b", 2), Bit("a", 1)], [Bit("a", 0), Bit("b", 0)], condition=Bit("b", 1) + ) + c.add_c_copybits( + [Bit("a", 0), Bit("a", 1)], # type: ignore[list-item] # overloaded function + [Bit("b", 0), Bit("b", 1)], # type: ignore[list-item] # overloaded function + condition_bits=[Bit("b", 1), Bit("b", 2)], + condition_value=2, + ) + + phir = json.loads(pytket_to_phir(c)) + + assert phir["ops"][3] == { + "cop": "=", + "returns": [["b", 0], ["b", 1]], + "args": [["a", 0], ["a", 1]], + } + assert phir["ops"][5] == { + "cop": "=", + "returns": [["a", 0], ["b", 0]], + "args": [["b", 2], ["a", 1]], + } + assert phir["ops"][7] == { + "block": "if", + "condition": {"cop": "==", "args": [["b", 1], 1]}, + "true_branch": [ + {"cop": "=", "returns": [["a", 0], ["b", 0]], "args": [["b", 2], ["a", 1]]} + ], + } + assert phir["ops"][8] == { + "//": "IF ([b[1], b[2]] == 2) THEN CopyBits a[0], a[1], b[0], b[1];" + } + assert phir["ops"][9] == { + "block": "if", + "condition": { + "cop": "&", + "args": [ + {"cop": "==", "args": [["b", 2], 1]}, + {"cop": "==", "args": [["b", 1], 0]}, + ], + }, + "true_branch": [ + {"cop": "=", "returns": [["b", 0], ["b", 1]], "args": [["a", 0], ["a", 1]]} + ], + } + + def test_classicalexpbox() -> None: """From https://github.com/CQCL/pytket-phir/issues/86 .""" circ = Circuit(1) @@ -87,23 +172,17 @@ def test_conditional_barrier() -> None: assert phir["ops"][4] == {"//": "IF ([m[0], m[1]] == 0) THEN Barrier q[0], q[1];"} assert phir["ops"][5] == { "block": "if", - "condition": {"cop": "==", "args": ["m", 0]}, + "condition": { + "cop": "&", + "args": [ + {"cop": "==", "args": [["m", 1], 0]}, + {"cop": "==", "args": [["m", 0], 0]}, + ], + }, "true_branch": [{"meta": "barrier", "args": [["q", 0], ["q", 1]]}], } -def test_simple_cond_classical() -> None: - """Ensure conditional classical operation are correctly generated.""" - circ = get_qasm_as_circuit(QasmFile.simple_cond) - phir = json.loads(pytket_to_phir(circ)) - assert phir["ops"][-6] == {"//": "IF ([c[0]] == 1) THEN SetBits(1) z[0];"} - assert phir["ops"][-5] == { - "block": "if", - "condition": {"cop": "==", "args": [["c", 0], 1]}, - "true_branch": [{"cop": "=", "returns": [["z", 0]], "args": [1]}], - } - - def test_nested_bitwise_op() -> None: """From https://github.com/CQCL/pytket-phir/issues/133 .""" circ = Circuit(4) @@ -137,23 +216,6 @@ def test_sleep_idle() -> None: assert phir["ops"][7] == {"mop": "Idle", "args": [["q", 0]], "duration": [1.0, "s"]} -def test_multiple_sleep() -> None: - """Ensure multiple sleep ops get converted correctly.""" - qasm = """ - OPENQASM 2.0; - include "hqslib1_dev.inc"; - - qreg q[2]; - - sleep(1) q[0]; - sleep(2) q[1]; - """ - circ = circuit_from_qasm_str(qasm) - phir = json.loads(pytket_to_phir(circ)) - assert phir["ops"][2] == {"mop": "Idle", "args": [["q", 0]], "duration": [1.0, "s"]} - assert phir["ops"][4] == {"mop": "Idle", "args": [["q", 1]], "duration": [2.0, "s"]} - - def test_reordering_classical_conditional() -> None: """From https://github.com/CQCL/pytket-phir/issues/150 .""" circuit = Circuit(1) @@ -201,3 +263,171 @@ def test_conditional_measure() -> None: "condition": {"cop": "==", "args": [["c", 0], 1]}, "true_branch": [{"qop": "Measure", "returns": [["c", 1]], "args": [["q", 1]]}], } + + +def test_conditional_classical_not() -> None: + """From https://github.com/CQCL/pytket-phir/issues/159 .""" + circuit = Circuit() + target_reg = circuit.add_c_register(BitRegister(name="target_reg", size=1)) + control_reg = circuit.add_c_register(BitRegister(name="control_reg", size=1)) + + circuit.add_c_not( + arg_in=target_reg[0], arg_out=target_reg[0], condition=control_reg[0] + ) + + phir = json.loads(pytket_to_phir(circuit)) + assert phir["ops"][-1] == { + "block": "if", + "condition": {"cop": "==", "args": [["control_reg", 0], 1]}, + "true_branch": [ + { + "cop": "=", + "returns": [["target_reg", 0]], + "args": [{"cop": "~", "args": [["target_reg", 0]]}], + } + ], + } + + +def test_explicit_classical_ops() -> None: + """Test explicit predicates and modifiers.""" + # From https://github.com/CQCL/tket/blob/a2f6fab8a57da8787dfae94764b7c3a8e5779024/pytket/tests/classical_test.py#L97-L101 + c = Circuit(0, 4) + # predicates + c.add_c_and(1, 2, 3) + c.add_c_not(0, 1) + c.add_c_xor(1, 2, 3) + # modifiers + c.add_c_and(2, 3, 3) + c.add_c_or(0, 3, 0) + phir = json.loads(pytket_to_phir(c)) + assert phir["ops"][1] == {"//": "AND c[1], c[2], c[3];"} + assert phir["ops"][2] == { + "cop": "=", + "returns": [["c", 3]], + "args": [{"cop": "&", "args": [["c", 1], ["c", 2]]}], + } + assert phir["ops"][3] == {"//": "NOT c[0], c[1];"} + assert phir["ops"][4] == { + "cop": "=", + "returns": [["c", 1]], + "args": [{"cop": "~", "args": [["c", 0]]}], + } + assert phir["ops"][5] == {"//": "XOR c[1], c[2], c[3];"} + assert phir["ops"][6] == { + "cop": "=", + "returns": [["c", 3]], + "args": [{"cop": "^", "args": [["c", 1], ["c", 2]]}], + } + assert phir["ops"][7] == {"//": "AND c[2], c[3];"} + assert phir["ops"][8] == { + "cop": "=", + "returns": [["c", 3]], + "args": [{"cop": "&", "args": [["c", 2], ["c", 3]]}], + } + assert phir["ops"][9] == {"//": "OR c[3], c[0];"} + assert phir["ops"][10] == { + "cop": "=", + "returns": [["c", 0]], + "args": [{"cop": "|", "args": [["c", 3], ["c", 0]]}], + } + + +def test_multi_bit_ops() -> None: + """Test classical ops added to the circuit via tket multi-bit ops.""" + # Test from https://github.com/CQCL/tket/blob/a2f6fab8a57da8787dfae94764b7c3a8e5779024/pytket/tests/classical_test.py#L107-L112 + c = Circuit(0, 4) + c0 = c.add_c_register("c0", 3) + c1 = c.add_c_register("c1", 4) + c2 = c.add_c_register("c2", 5) + # predicates + c.add_c_and_to_registers(c0, c1, c2) + c.add_c_not_to_registers(c1, c2) + c.add_c_or_to_registers(c0, c1, c2) + # modifier + c.add_c_xor_to_registers(c2, c1, c2) + # conditionals + c.add_c_not_to_registers(c1, c2, condition=Bit("c0", 0)) + c.add_c_not_to_registers(c1, c1, condition=Bit("c0", 0)) + phir = json.loads(pytket_to_phir(c)) + assert phir["ops"][3] == { + "//": "AND (*3) c0[0], c1[0], c2[0], c0[1], c1[1], c2[1], c0[2], c1[2], c2[2];" + } + assert phir["ops"][4] == { + "cop": "=", + "returns": ["c2"], + "args": [{"cop": "&", "args": ["c0", "c1"]}], + } + assert phir["ops"][5] == { + "//": "NOT (*4) c1[0], c2[0], c1[1], c2[1], c1[2], c2[2], c1[3], c2[3];" + } + assert phir["ops"][6] == { + "cop": "=", + "returns": ["c2"], + "args": [{"cop": "~", "args": ["c1"]}], + } + assert phir["ops"][7] == { + "//": "OR (*3) c0[0], c1[0], c2[0], c0[1], c1[1], c2[1], c0[2], c1[2], c2[2];" + } + assert phir["ops"][8] == { + "cop": "=", + "returns": ["c2"], + "args": [{"cop": "|", "args": ["c0", "c1"]}], + } + assert phir["ops"][9] == { + "//": "XOR (*4) c1[0], c2[0], c1[1], c2[1], c1[2], c2[2], c1[3], c2[3];" + } + assert phir["ops"][10] == { + "cop": "=", + "returns": ["c2"], + "args": [{"cop": "^", "args": ["c1", "c2"]}], + } + assert phir["ops"][12] == { + "block": "if", + "condition": {"cop": "==", "args": [["c0", 0], 1]}, + "true_branch": [ + {"cop": "=", "returns": ["c2"], "args": [{"cop": "~", "args": ["c1"]}]} + ], + } + assert phir["ops"][14] == { + "block": "if", + "condition": {"cop": "==", "args": [["c0", 0], 1]}, + "true_branch": [ + {"cop": "=", "returns": ["c1"], "args": [{"cop": "~", "args": ["c1"]}]} + ], + } + + +def test_irregular_multibit_ops() -> None: + """From https://github.com/CQCL/pytket-phir/pull/162#discussion_r1555807863 .""" + c = Circuit() + areg = c.add_c_register("a", 2) + breg = c.add_c_register("b", 2) + creg = c.add_c_register("c", 2) + c.add_c_and_to_registers(areg, breg, creg) + mbop = c.get_commands()[0].op + c.add_gate(mbop, [areg[0], areg[1], breg[0], breg[1], creg[0], creg[1]]) + + phir = json.loads(pytket_to_phir(c)) + assert phir["ops"][3] == {"//": "AND (*2) a[0], b[0], c[0], a[1], b[1], c[1];"} + assert phir["ops"][4] == { + "cop": "=", + "returns": ["c"], + "args": [{"cop": "&", "args": ["a", "b"]}], + } + assert phir["ops"][5] == {"//": "AND (*2) a[0], a[1], b[0], b[1], c[0], c[1];"} + assert phir["ops"][6] == { + "block": "sequence", + "ops": [ + { + "cop": "=", + "returns": [["b", 0]], + "args": [{"cop": "&", "args": [["a", 0], ["a", 1]]}], + }, + { + "cop": "=", + "returns": [["c", 1]], + "args": [{"cop": "&", "args": [["b", 1], ["c", 0]]}], + }, + ], + }