Skip to content

Commit

Permalink
feat(phirgen): add support for ZERO/ONE nullary ops (#180)
Browse files Browse the repository at this point in the history
Closes: #178
  • Loading branch information
qartik authored May 28, 2024
1 parent 9d93553 commit fc136a1
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ repos:
- black==23.10.1

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
rev: v0.4.5
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
7 changes: 6 additions & 1 deletion pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,14 @@ def assign_cop(
}


def classical_op(exp: LogicExp, *, bitwise: bool = False) -> JsonDict:
def classical_op(exp: LogicExp, *, bitwise: bool = False) -> JsonDict | int:
"""PHIR for classical register operations."""
match exp.op:
# Nullary
case BitWiseOp.ZERO:
return 0
case BitWiseOp.ONE:
return 1
# Bitwise
case RegWiseOp.AND | BitWiseOp.AND:
cop = "&"
Expand Down
14 changes: 7 additions & 7 deletions pytket/phir/phirgen_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def process_sub_commands(
rz_group_number = -3 # set to 0 when first RZ gate is assigned (-3 + 3 = 0)
r1xy_group_number = -2 # set to 1 when first R1XY gate is assigned (-2 + 3 = 1)
other_group_number = -1 # set to 2 when first other gate is assigned (-1 + 3 = 2)
num_scs_per_qubit: dict["UnitID", int] = {}
num_scs_per_qubit: dict[UnitID, int] = {}
group_exec_order: list[int] = []

for qubit, cmds in sub_commands.items():
Expand Down Expand Up @@ -131,14 +131,14 @@ def process_sub_commands(
def groups2qops(groups: dict[int, list[tk.Command]], ops: list["JsonDict"]) -> None: # noqa: PLR0912
"""Convert the groups of parallel ops to properly formatted PHIR."""
for group in groups.values():
angles2qops: dict[tuple[sympy.Expr | float, ...], "JsonDict"] = {}
angles2qops: dict[tuple[sympy.Expr | float, ...], JsonDict] = {}
for qop in group:
if not qop.op.is_gate():
append_cmd(qop, ops)
else:
angles = qop.op.params
if tuple(angles) not in angles2qops:
fmt_qop: "JsonDict" = {
fmt_qop: JsonDict = {
"qop": tket_gate_to_phir[qop.op.type],
"angles": [angles, "pi"],
}
Expand All @@ -159,7 +159,7 @@ def groups2qops(groups: dict[int, list[tk.Command]], ops: list["JsonDict"]) -> N
# this branch is skipped because non-gate sub-commands
# are always the only member of their group (see process_sub_commands)
if len(angles2qops) > 1:
pll_block: "JsonDict" = {"block": "qparallel", "ops": []}
pll_block: JsonDict = {"block": "qparallel", "ops": []}
for phir_qop in angles2qops.values():
pll_block["ops"].append(phir_qop)
comment = {"//": f"Parallel {tket_gate_to_phir[qop.op.type]}"}
Expand Down Expand Up @@ -256,7 +256,7 @@ def format_and_add_primary_commands(
append_cmd(shard.primary_command, ops)
# for measure, format and include "returns"
elif gate_type == "Measure":
fmt_measure: "JsonDict" = {
fmt_measure: JsonDict = {
"qop": "Measure",
"args": [],
"returns": [],
Expand All @@ -268,7 +268,7 @@ def format_and_add_primary_commands(
ops.append(fmt_measure)
# all other gates, treat as standard qops
else:
fmt_qop: "JsonDict" = {"qop": gate_type, "args": []}
fmt_qop: JsonDict = {"qop": gate_type, "args": []}
for shard in group:
pc = shard.primary_command
fmt_qop["args"].append(arg_to_bit(pc.args[0]))
Expand Down Expand Up @@ -325,7 +325,7 @@ def genphir_parallel(

phir = PHIR_HEADER
phir["metadata"]["strict_parallelism"] = True
ops: list["JsonDict"] = []
ops: list[JsonDict] = []

qbits = set()
cbits = set()
Expand Down
2 changes: 1 addition & 1 deletion pytket/phir/sharding/shards2ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def parse_shards_naive(
scheduled: set[int] = set()
num_shards: int = len(shards)
qid_count: int = 0
qubits2ids: dict["UnitID", int] = {}
qubits2ids: dict[UnitID, int] = {}

while len(scheduled) < num_shards:
layer: Layer = []
Expand Down
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ networkx==2.8.8
phir==0.3.3
pre-commit==3.7.1
pydata_sphinx_theme==0.15.2
pytest==8.2.0
pytest==8.2.1
pytest-order==1.2.1
pytket==1.27.0
ruff==0.4.4
pytket==1.28.0
ruff==0.4.5
setuptools_scm==8.1.0
sphinx==7.3.7
wasmtime==20.0.0
wasmtime==21.0.0
wheel==0.43.0
31 changes: 31 additions & 0 deletions tests/test_phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import json

from pytket.circuit import Bit, Circuit
from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp
from pytket.phir.api import pytket_to_phir
from pytket.qasm.qasm import circuit_from_qasm_str
from pytket.unit_id import BitRegister
Expand Down Expand Up @@ -431,3 +432,33 @@ def test_irregular_multibit_ops() -> None:
},
],
}


def test_nullary_ops() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/178 ."""
c = Circuit(1, 1)
exp1 = create_bit_logic_exp(BitWiseOp.ONE, [])
c.H(0, condition=exp1)
exp0 = create_bit_logic_exp(BitWiseOp.ZERO, [])
c.H(0, condition=exp0)
c.measure_all()
phir = json.loads(pytket_to_phir(c))

assert phir["ops"][4] == {
"cop": "=",
"returns": [["tk_SCRATCH_BIT", 0]],
"args": [1],
}
assert phir["ops"][6] == {
"cop": "=",
"returns": [["tk_SCRATCH_BIT", 1]],
"args": [0],
}
assert phir["ops"][8]["condition"] == {
"cop": "==",
"args": [["tk_SCRATCH_BIT", 0], 1],
}
assert phir["ops"][10]["condition"] == {
"cop": "==",
"args": [["tk_SCRATCH_BIT", 1], 1], # evals to False
}

0 comments on commit fc136a1

Please sign in to comment.