Skip to content

Commit

Permalink
Merge pull request #16 from CQCL/no-strict-validation
Browse files Browse the repository at this point in the history
  • Loading branch information
qartik authored Oct 23, 2023
2 parents 4e963f1 + 739e91b commit 33adde6
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: black

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.0
rev: v0.1.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [{name = "Quantinuum"}]

dependencies = ["pytket"]
dependencies = ["phir>=0.1.5", "pytket"]

[project.optional-dependencies]
tests = ["pytest"]
Expand Down
2 changes: 1 addition & 1 deletion pytket/phir/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ def pytket_to_phir(
phir_json = genphir(placed)

if logger.getEffectiveLevel() <= logging.INFO:
print(PHIRModel.model_validate_json(phir_json, strict=True)) # type: ignore[misc]
print(PHIRModel.model_validate_json(phir_json)) # type: ignore[misc]
return phir_json
22 changes: 11 additions & 11 deletions pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from phir.model import PHIRModel
from pytket.circuit import Command
from pytket.phir.sharding.shard import Shard
from pytket.phir.sharding.shard import Cost, Layer, Ordering


def write_cmd(cmd: Command, ops: list[dict[str, Any]]) -> None:
Expand All @@ -14,13 +14,13 @@ def write_cmd(cmd: Command, ops: list[dict[str, Any]]) -> None:
ops: the list of ops to append to
"""
gate = cmd.op.get_name().split("(", 1)[0]
metadata, angles = (
({"angle_multiplier": "π"}, cmd.op.params)
if gate != "Measure" and cmd.op.params
else (None, None)
angles = (
(cmd.op.params, "pi")
if gate not in ("Measure", "Barrier", "SetBits") and cmd.op.params
else None
)

qop: dict[str, Any] = {
"metadata": metadata,
"angles": angles,
"qop": gate,
"args": [],
Expand All @@ -36,7 +36,7 @@ def write_cmd(cmd: Command, ops: list[dict[str, Any]]) -> None:
ops.extend(({"//": str(cmd)}, qop))


def genphir(inp: list[tuple[list[int], list[Shard], float]]) -> str:
def genphir(inp: list[tuple[Ordering, Layer, Cost]]) -> str:
"""Convert a list of shards to the equivalent PHIR.
Args:
Expand All @@ -51,8 +51,8 @@ def genphir(inp: list[tuple[list[int], list[Shard], float]]) -> str:

qbits = set()
cbits = set()
for _orders, shard_layers, layer_costs in inp:
for shard in shard_layers:
for _orders, shard_layer, layer_cost in inp:
for shard in shard_layer:
qbits |= shard.qubits_used
cbits |= shard.bits_read | shard.bits_written
for sub_commands in shard.sub_commands.values():
Expand All @@ -62,7 +62,7 @@ def genphir(inp: list[tuple[list[int], list[Shard], float]]) -> str:
ops.append(
{
"mop": "Transport",
"metadata": {"duration": layer_costs / 1000000}, # microseconds to secs
"duration": (layer_cost, "ms"),
},
)

Expand Down Expand Up @@ -97,5 +97,5 @@ def genphir(inp: list[tuple[list[int], list[Shard], float]]) -> str:
]

phir["ops"] = decls + ops
PHIRModel.model_validate(phir, strict=True)
PHIRModel.model_validate(phir)
return json.dumps(phir)
8 changes: 4 additions & 4 deletions pytket/phir/place_and_route.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from pytket.phir.machine import Machine
from pytket.phir.placement import optimized_place
from pytket.phir.routing import transport_cost
from pytket.phir.sharding.shard import Shard
from pytket.phir.sharding.shard import Cost, Layer, Ordering, Shard
from pytket.phir.sharding.shards2ops import parse_shards_naive


def place_and_route(
machine: Machine,
shards: list[Shard],
) -> list[tuple[list[int], list[Shard], float]]:
) -> list[tuple[Ordering, Layer, Cost]]:
"""Get all the routing info needed for PHIR generation."""
shard_set = set(shards)
circuit_rep, shard_layers = parse_shards_naive(shard_set)
initial_order = list(range(machine.size))
layer_num = 0
orders: list[list[int]] = []
layer_costs: list[float] = []
orders: list[Ordering] = []
layer_costs: list[Cost] = []
net_cost: float = 0.0
for layer in circuit_rep:
order = optimized_place(
Expand Down
6 changes: 6 additions & 0 deletions pytket/phir/sharding/shard.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
from dataclasses import dataclass, field
from itertools import count
from typing import TypeAlias

from pytket.circuit import Command
from pytket.unit_id import Bit, Qubit, UnitID
Expand Down Expand Up @@ -57,3 +58,8 @@ def pretty_print(self) -> str:
content = output.getvalue()
output.close()
return content


Cost: TypeAlias = float
Layer: TypeAlias = list[Shard]
Ordering: TypeAlias = list[int]
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
black==23.10.0
build==1.0.3
mypy==1.6.1
phir==0.1.3
phir==0.1.5
pre-commit==3.5.0
pytest==7.4.2
pytket-quantinuum==0.25.0
pytket==1.21.0
ruff==0.1.0
ruff==0.1.1
wheel==0.41.2
2 changes: 1 addition & 1 deletion tests/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@

phir_json = genphir(output)

print(PHIRModel.model_validate_json(phir_json, strict=True)) # type: ignore[misc]
print(PHIRModel.model_validate_json(phir_json)) # type: ignore[misc]

0 comments on commit 33adde6

Please sign in to comment.