Skip to content

Commit

Permalink
Support for ExplicitPredicate, ExplicitModifier, MultiBitOp` (#162)
Browse files Browse the repository at this point in the history
* explicit predicates for common logical ops and multi-bit ops supported + unit tests
* feat: add support for ExplicitPredicate/Modifier & MultiBitOp
* fix(phirgen): handle ops inside a conditional correctly
* feat(phirgen): add support for irregular multibit ops

---------

Co-authored-by: Kartik Singhal <[email protected]>
Co-authored-by: Alec Edgington <[email protected]>
  • Loading branch information
3 people authored Apr 10, 2024
1 parent 4f332a4 commit bac76d6
Show file tree
Hide file tree
Showing 5 changed files with 430 additions and 112 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
- id: debug-statements

- repo: https://github.com/crate-ci/typos
rev: v1.20.4
rev: v1.20.7
hooks:
- id: typos

Expand Down
223 changes: 166 additions & 57 deletions pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import json
import logging
from copy import deepcopy
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, TypeAlias

Expand Down Expand Up @@ -218,13 +219,155 @@ def convert_gate(op: tk.Op, cmd: tk.Command) -> JsonDict | None:
return qop


def cop_from_op_name(op_name: str) -> str:
"""Get PHIR classical op name from pytket op name."""
match op_name:
case "AND":
cop = "&"
case "OR":
cop = "|"
case "XOR":
cop = "^"
case "NOT":
cop = "~"
case name:
raise NotImplementedError(name)
return cop


def convert_classicalevalop(op: tk.ClassicalEvalOp, cmd: tk.Command) -> JsonDict | None:
"""Return PHIR dict for a pytket ClassicalEvalOp."""
# Exclude conditional bits from args
args = cmd.args[cmd.op.width :] if isinstance(cmd.op, tk.Conditional) else cmd.args
out: JsonDict | None = None
match op:
case tk.CopyBitsOp():
if len(cmd.bits) != len(args) // 2:
msg = "LHS and RHS lengths mismatch for CopyBits"
raise TypeError(msg)
out = assign_cop(
[arg_to_bit(bit) for bit in cmd.bits],
[arg_to_bit(args[i]) for i in range(len(args) // 2)],
)
case tk.SetBitsOp():
if len(cmd.bits) != len(op.values):
logger.error("LHS and RHS lengths mismatch for classical assignment")
raise ValueError
out = assign_cop(
[arg_to_bit(bit) for bit in cmd.bits], list(map(int, op.values))
)
case tk.RangePredicateOp(): # where the condition is a range
cond: JsonDict
match op.lower, op.upper:
case l, u if l == u:
cond = {
"cop": "==",
"args": [args[0].reg_name, u],
}
case l, u if u == UINTMAX:
cond = {
"cop": ">=",
"args": [args[0].reg_name, l],
}
case 0, u:
cond = {
"cop": "<=",
"args": [args[0].reg_name, u],
}
out = {
"block": "if",
"condition": cond,
"true_branch": [assign_cop([arg_to_bit(cmd.bits[0])], [1])],
}
case tk.MultiBitOp():
if len(args) % len(cmd.bits) != 0:
msg = "Input bit- and output bit lengths mismatch."
raise TypeError(msg)

cop = cop_from_op_name(op.basic_op.get_name())
is_explicit = op.basic_op.type == tk.OpType.ExplicitPredicate

# determine number of register operands involved in the operation
operand_count = len(args) // len(cmd.bits) - is_explicit

iters = [iter(args)] * (operand_count + is_explicit)
iter2 = deepcopy(iters)

# Columns of expressions, e.g.,
# AND (*2) a[0], b[0], c[0]
# , a[1], b[1], c[1]
# would be [(a[0], a[1]), (b[0], b[1]), (c[0], c[1])]
# and AND (*2) a[0], a[1], b[0]
# , b[1], c[0], c[1]
# would be [(a[0], b[1]), (a[1], c[0]), (b[0], c[1])]
cols = zip(*zip(*iters, strict=True), strict=True)

if all(
all(col[0].reg_name == bit.reg_name for bit in col) for col in cols
): # expression can be applied register-wise
out = assign_cop(
[cmd.bits[0].reg_name],
[
{
"cop": cop,
"args": [arg.reg_name for arg in args[:operand_count]],
}
],
)
else: # apply a sequence of bit-wise ops
exps = zip(*iter2, strict=True)
out = {
"block": "sequence",
"ops": [
assign_cop(
[arg_to_bit(bit)],
[
{
"cop": cop,
"args": [
arg_to_bit(arg) for arg in exp[:operand_count]
],
}
],
)
for bit, exp in zip(cmd.bits, exps, strict=True)
],
}
case _:
raise NotImplementedError(op)

return out


def multi_bit_condition(args: "list[UnitID]", value: int) -> JsonDict:
"""Construct bitwise condition."""
return {
"cop": "&",
"args": [
{"cop": "==", "args": [arg_to_bit(arg), bval]}
for (arg, bval) in zip(
args[::-1], map(int, f"{value:0{len(args)}b}"), strict=True
)
],
}


def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None:
"""Return PHIR dict given a tket op and its arguments."""
if op.is_gate():
return convert_gate(op, cmd)

out: JsonDict | None = None
match op: # non-quantum op
case tk.Conditional():
out = {
"block": "if",
"condition": {"cop": "==", "args": [arg_to_bit(cmd.args[0]), op.value]}
if op.width == 1
else multi_bit_condition(cmd.args[: op.width], op.value),
"true_branch": [convert_subcmd(op.op, cmd)],
}

case tk.BarrierOp():
if op.data:
# See https://github.com/CQCL/tket/blob/0ec603986821d994caa3a0fb9c4640e5bc6c0a24/pytket/pytket/qasm/qasm.py#L419-L459
Expand All @@ -246,45 +389,6 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None:
"args": [arg_to_bit(qbit) for qbit in cmd.qubits],
}

case tk.Conditional(): # where the condition is equality check
out = {
"block": "if",
"condition": {
"cop": "==",
"args": [
arg_to_bit(cmd.args[0])
if op.width == 1
else cmd.args[0].reg_name,
op.value,
],
},
"true_branch": [convert_subcmd(op.op, cmd)],
}

case tk.RangePredicateOp(): # where the condition is a range
cond: JsonDict
match op.lower, op.upper:
case l, u if l == u:
cond = {
"cop": "==",
"args": [cmd.args[0].reg_name, u],
}
case l, u if u == UINTMAX:
cond = {
"cop": ">=",
"args": [cmd.args[0].reg_name, l],
}
case 0, u:
cond = {
"cop": "<=",
"args": [cmd.args[0].reg_name, u],
}
out = {
"block": "if",
"condition": cond,
"true_branch": [assign_cop([arg_to_bit(cmd.bits[0])], [1])],
}

case tk.ClassicalExpBox():
exp = op.get_exp()
match exp:
Expand All @@ -295,29 +399,34 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None:
rhs = [classical_op(exp)]
out = assign_cop([cmd.bits[0].reg_name], rhs)

case tk.SetBitsOp():
if len(cmd.bits) != len(op.values):
logger.error("LHS and RHS lengths mismatch for classical assignment")
raise ValueError
out = assign_cop(
[arg_to_bit(bit) for bit in cmd.bits], list(map(int, op.values))
)

case tk.CopyBitsOp():
if len(cmd.bits) != len(cmd.args) // 2:
logger.warning("LHS and RHS lengths mismatch for CopyBits")
out = assign_cop(
[arg_to_bit(bit) for bit in cmd.bits],
[arg_to_bit(cmd.args[i]) for i in range(len(cmd.args) // 2)],
)
case tk.ClassicalEvalOp():
return convert_classicalevalop(op, cmd)

case tk.WASMOp():
return create_wasm_op(cmd, op)

case _:
# TODO(kartik): NYI
# https://github.com/CQCL/pytket-phir/issues/25
raise NotImplementedError
# Exclude conditional bits from args
args = (
cmd.args[cmd.op.width :]
if isinstance(cmd.op, tk.Conditional)
else cmd.args
)
match op.type:
case tk.OpType.ExplicitPredicate | tk.OpType.ExplicitModifier:
# exclude output bit when not modifying in place
args = args[:-1] if op.type == tk.OpType.ExplicitPredicate else args
out = assign_cop(
[arg_to_bit(cmd.bits[0])],
[
{
"cop": cop_from_op_name(op.get_name()),
"args": [arg_to_bit(arg) for arg in args],
}
],
)
case _:
raise NotImplementedError(op.type)

return out

Expand Down
2 changes: 2 additions & 0 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
OpType.ClassicalExpBox, # some classical operations are rolled up into a box
OpType.RangePredicate,
OpType.ExplicitPredicate,
OpType.ExplicitModifier,
OpType.MultiBit,
OpType.CopyBits,
OpType.WASM,
]
Expand Down
23 changes: 0 additions & 23 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@

# mypy: disable-error-code="misc"

import json
import logging

import pytest

from pytket.circuit import Bit, Circuit
from pytket.phir.api import pytket_to_phir, qasm_to_phir
from pytket.phir.qtm_machine import QtmMachine

Expand Down Expand Up @@ -50,27 +48,6 @@ def test_pytket_to_phir_h1_1_all(self, test_file: QasmFile) -> None:

assert pytket_to_phir(circuit, QtmMachine.H1)

def test_pytket_classical_only(self) -> None:
c = Circuit(1)
a = c.add_c_register("a", 2)
b = c.add_c_register("b", 3)

c.add_c_copyreg(a, b)
c.add_c_copybits([Bit("b", 2), Bit("a", 1)], [Bit("a", 0), Bit("b", 0)])

phir = json.loads(pytket_to_phir(c))

assert phir["ops"][3] == {
"cop": "=",
"returns": [["b", 0], ["b", 1]],
"args": [["a", 0], ["a", 1]],
}
assert phir["ops"][5] == {
"cop": "=",
"returns": [["a", 0], ["b", 0]],
"args": [["b", 2], ["a", 1]],
}

def test_qasm_to_phir(self) -> None:
"""Test the qasm string entrypoint works."""
qasm = """
Expand Down
Loading

0 comments on commit bac76d6

Please sign in to comment.