Skip to content

Commit

Permalink
PHIR generation (#12)
Browse files Browse the repository at this point in the history
First implementation to generate PHIR from sample QASM circuits.
--------------------------------------------------------------------------

* Update packages

* routing integration

* Fix linting and typing issues

* Update todos

* Add first pass to generate PHIR

* Bit more cleanup

---------

Co-authored-by: Neal Erickson <[email protected]>
  • Loading branch information
qartik and nealerickson-qtm authored Oct 18, 2023
1 parent c0605d8 commit ac3c955
Show file tree
Hide file tree
Showing 17 changed files with 226 additions and 104 deletions.
15 changes: 8 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
Expand All @@ -13,30 +13,31 @@ repos:
- id: debug-statements

- repo: https://github.com/crate-ci/typos
rev: v1.16.17
rev: v1.16.20
hooks:
- id: typos

- repo: https://github.com/psf/black
rev: 23.9.1
rev: 23.10.0
hooks:
- id: black

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.292
rev: v0.1.0
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.5.1'
rev: 'v1.6.0'
hooks:
- id: mypy
pass_filenames: false
args: [--package=pytket.phir, --package=tests]
additional_dependencies: [
pytest,
pytket==1.20.1,
pytket-quantinuum==0.23.0,
pytket==1.21.0,
pytket-quantinuum==0.25.0,
types-setuptools,
"git+https://github.com/CQCL/phir#egg=phir",
]
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import sys

sys.path.insert(0, pathlib.Path("../../pytket").resolve().as_posix())
# print(sys.path)

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "pytket-phir"
version = "0.0.1"
description = "A python library"
description = "A circuit analyzer and translator from pytket to PHIR"
readme = "README.md"
requires-python = ">=3.10"
license = {file = "LICENSE"}
Expand All @@ -31,7 +31,7 @@ pythonpath = [
"."
]
log_cli = true
log_cli_level = "DEBUG"
log_cli_level = "INFO"
filterwarnings = ["ignore:::lark.s*"]
log_format = "%(asctime)s.%(msecs)03d %(levelname)s %(message)s"
log_date_format = "%Y-%m-%d %H:%M:%S"
16 changes: 11 additions & 5 deletions pytket/phir/api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import logging
from typing import TYPE_CHECKING

from pytket.circuit import Circuit
from pytket.phir.machine import Machine
from pytket.phir.phirgen import genphir
from pytket.phir.place_and_route import place_and_route
from pytket.phir.qtm_machine import QTM_MACHINES_MAP, QtmMachine
from pytket.phir.rebasing.rebaser import rebase_to_qtm_machine
from pytket.phir.sharding.sharder import Sharder

if TYPE_CHECKING:
from pytket.phir.machine import Machine

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -40,11 +44,13 @@ def pytket_to_phir(
shards = sharder.shard()

logger.debug("Performing placement and routing...")
placed = place_and_route(machine, shards) # type: ignore [misc]
if machine:
placed = place_and_route(machine, shards)
else:
msg = "no machine found"
raise ValueError(msg)

phir_output = str(placed) # type: ignore [misc]
phir_output = genphir(placed)

# TODO: Pass shards[] into placement, routing, etc
# TODO: Convert to PHIR JSON spec and return
logger.info("Output: %s", phir_output)
return phir_output
2 changes: 1 addition & 1 deletion pytket/phir/machine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
class Machine:
"""A machine info class for testing."""

def __init__( # noqa: PLR0913
def __init__(
self,
size: int,
tq_options: set[int],
Expand Down
106 changes: 106 additions & 0 deletions pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# mypy: disable-error-code="misc,no-any-unimported"

import json

from phir.model import ( # type: ignore [import-untyped]
Cmd,
DataMgmt,
OpType,
PHIRModel,
QOp,
)
from pytket.circuit import Command
from pytket.phir.sharding.shard import Shard


def write_cmd(cmd: Command, ops: list[Cmd]) -> None:
"""Write a pytket command to PHIR qop.
Args:
cmd: pytket command obtained from pytket-phir
ops: the list of ops to append to
"""
gate = cmd.op.get_name().split("(", 1)[0]
metadata, angles = (
({"angle_multiplier": "π"}, cmd.op.params)
if gate != "Measure"
else (None, None)
)
qop: QOp = {
"metadata": metadata,
"angles": angles,
"qop": gate,
"args": [],
}
for qbit in cmd.args:
qop["args"].append([qbit.reg_name, qbit.index[0]])
if cmd.bits:
qop["returns"] = []
for cbit in cmd.bits:
qop["returns"].append([cbit.reg_name, cbit.index[0]])
ops.extend(({"//": str(cmd)}, qop))


def genphir(inp: list[tuple[list[int], list[Shard], float]]) -> str:
"""Convert a list of shards to the equivalent PHIR.
Args:
inp: list of shards
"""
phir = {
"format": "PHIR/JSON",
"version": "0.1.0",
"metadata": {"source": "pytket-phir"},
}
ops: OpType = []

qbits = set()
cbits = set()
for _orders, shard_layers, layer_costs in inp:
for shard in shard_layers:
qbits |= shard.qubits_used
cbits |= shard.bits_read | shard.bits_written
for sub_commands in shard.sub_commands.values():
for sc in sub_commands:
write_cmd(sc, ops)
write_cmd(shard.primary_command, ops)
ops.append(
{
"mop": "Transport",
"metadata": {"duration": layer_costs / 1000000}, # microseconds to secs
},
)

# TODO(kartik): this may not always be accurate
qvar_dim: dict[str, int] = {}
for qbit in qbits:
qvar_dim.setdefault(qbit.reg_name, 0)
qvar_dim[qbit.reg_name] += 1

cvar_dim: dict[str, int] = {}
for cbit in cbits:
cvar_dim.setdefault(cbit.reg_name, 0)
cvar_dim[cbit.reg_name] += 1

decls: list[DataMgmt] = [
{
"data": "qvar_define",
"data_type": "qubits",
"variable": q,
"size": d,
}
for q, d in qvar_dim.items()
]

decls += [
{
"data": "cvar_define",
"variable": c,
"size": d,
}
for c, d in cvar_dim.items()
]

phir["ops"] = decls + ops
PHIRModel.model_validate(phir)
return json.dumps(phir)
14 changes: 6 additions & 8 deletions pytket/phir/place_and_route.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import typing

from pytket.phir.machine import Machine
from pytket.phir.placement import optimized_place
from pytket.phir.routing import transport_cost
from pytket.phir.sharding.shard import Shard
from pytket.phir.sharding.shards2ops import parse_shards_naive


@typing.no_type_check
def place_and_route(machine: Machine, shards: list[Shard]):
def place_and_route(
machine: Machine,
shards: list[Shard],
) -> list[tuple[list[int], list[Shard], float]]:
"""Get all the routing info needed for PHIR generation."""
shard_set = set(shards)
circuit_rep, shard_layers = parse_shards_naive(shard_set)
initial_order = list(range(machine.size))
layer_num = 0
orders: list[list[int]] = []
layer_costs: list[int] = []
layer_costs: list[float] = []
net_cost: float = 0.0
for layer in circuit_rep:
order = optimized_place(
Expand All @@ -33,6 +33,4 @@ def place_and_route(machine: Machine, shards: list[Shard]):
net_cost += cost

# don't need a custom error for this, "strict" parameter will throw error if needed
info = list(zip(orders, shard_layers, layer_costs, strict=True))

return info
return list(zip(orders, shard_layers, layer_costs, strict=True))
27 changes: 11 additions & 16 deletions pytket/phir/placement.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,20 @@ class GateOpportunitiesError(Exception):
"""Exception raised when gating zones cannot accommodate all operations."""

def __init__(self) -> None:
"""Exception""" # noqa: D415
super().__init__("Not enough gating opportunities for all ops in this layer")


class InvalidParallelOpsError(Exception):
"""Exception raised when a layer attempts to gate the same qubit more than once in parallel.""" # noqa: E501
"""Raised when a layer tries to gate the same qubit more than once in parallel."""

def __init__(self, q: int) -> None:
"""Exception
Args: q: a qubit
""" # noqa: D205, D415
super().__init__(f"Cannot gate qubit {q} more than once in the same layer")


class PlacementCheckError(Exception):
"""Exception raised when placement check fails."""

def __init__(self) -> None:
"""Exception""" # noqa: D415
super().__init__("Placement Check Failed")


Expand Down Expand Up @@ -69,9 +64,9 @@ def nearest(zone: int, options: set[int]) -> int:
elif ind == len(lst):
nearest_zone = lst[-1]
else:
l = lst[ind - 1] # noqa: E741
r = lst[ind]
nearest_zone = l if r - zone > zone - l else r
lft = lst[ind - 1]
rgt = lst[ind]
nearest_zone = lft if rgt - zone > zone - lft else rgt

return nearest_zone

Expand Down Expand Up @@ -131,7 +126,7 @@ def place( # noqa: PLR0912
sq_ops.append(op)

# sort the tq_ops by distance apart [[furthest] -> [closest]]
tq_ops_sorted = sorted(tq_ops, key=lambda x: abs(x[0] - x[1]), reverse=True) # type: ignore [misc] # noqa: E501, RUF100
tq_ops_sorted = sorted(tq_ops, key=lambda x: abs(x[0] - x[1]), reverse=True) # type: ignore [misc]

# check to make sure that there are zones available for all ops
if len(tq_ops) > len(tq_zones):
Expand Down Expand Up @@ -168,11 +163,11 @@ def place( # noqa: PLR0912

if placement_check(ops, tq_options, sq_options, order):
return order
else:
raise PlacementCheckError

raise PlacementCheckError


def optimized_place( # noqa: PLR0912
def optimized_place(
ops: list[list[int]],
tq_options: set[int],
sq_options: set[int],
Expand All @@ -198,7 +193,7 @@ def optimized_place( # noqa: PLR0912
sq_ops.append(op)

# sort the tq_ops by distance apart [[furthest] -> [closest]]
tq_ops_sorted = sorted(tq_ops, key=lambda x: abs(x[0] - x[1]), reverse=True) # type: ignore [misc] # noqa: E501, RUF100
tq_ops_sorted = sorted(tq_ops, key=lambda x: abs(x[0] - x[1]), reverse=True) # type: ignore [misc]

# check to make sure that there are zones available for all ops
if len(tq_ops) > len(tq_zones):
Expand Down Expand Up @@ -244,5 +239,5 @@ def optimized_place( # noqa: PLR0912

if placement_check(ops, tq_options, sq_options, order):
return order
else:
raise PlacementCheckError

raise PlacementCheckError
12 changes: 8 additions & 4 deletions pytket/phir/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,22 @@
class TransportError(Exception):
"""Error raised by inverse() util function."""

def __init__(self, a: list[int], b: list[int]): # noqa: D107
def __init__(self, a: list[int], b: list[int]):
super().__init__(f"Traps different sizes: {len(a)} vs. {len(b)}")


class PermutationError(Exception):
"""Error raised by inverse() util function."""

def __init__(self, lst: list[int]): # noqa: D107
def __init__(self, lst: list[int]):
super().__init__(f"List {lst} is not a permutation of range({len(lst)})")


def inverse(lst: list[int]) -> list[int]:
"""Inverse of a permutation list. If a[i] = x, then inverse(a)[x] = i.""" # noqa: D402
"""Inverse of a permutation list.
If a[i] = x, then inverse(a)[x] = i.
"""
inv = [-1] * len(lst)

for i, elem in enumerate(lst):
Expand All @@ -29,9 +32,10 @@ def inverse(lst: list[int]) -> list[int]:

def transport_cost(init: list[int], goal: list[int], swap_cost: float) -> float:
"""Cost of transport from init to goal.
This is based on the number of parallel swaps performed by Odd-Even
Transposition Sort, which is the maximum distance that any qubit travels.
""" # noqa: D205
"""
if len(init) != len(goal):
raise TransportError(init, goal)

Expand Down
Loading

0 comments on commit ac3c955

Please sign in to comment.