diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 99f7ad3..4210b23 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,9 +33,9 @@ repos: hooks: - id: mypy pass_filenames: false - args: [.] + args: [--package=pytket.phir, --package=tests] additional_dependencies: [ pytest, + pytket==1.20.1, types-setuptools, - pytket, ] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..98b123e --- /dev/null +++ b/Makefile @@ -0,0 +1,7 @@ +.PHONY: tests, lint + +tests: + pytest -s -x -vv tests/test*.py + +lint: + pre-commit run --all-files diff --git a/pyproject.toml b/pyproject.toml index aafb143..c60e2c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,3 +16,4 @@ where = ["."] pythonpath = [ "." ] +filterwarnings = ["ignore:::lark.s*"] diff --git a/pytket/phir/main.py b/pytket/phir/main.py index bca21dd..a8d37c1 100644 --- a/pytket/phir/main.py +++ b/pytket/phir/main.py @@ -1,10 +1,29 @@ -"""This is the main module of the pytket-phir package.""" +""" +NOTE: Just a placeholder to allow convenient testing of the flows +""" +from pytket.circuit import Circuit +from pytket.qasm.qasm import circuit_from_qasm -def hello_world() -> str: - """Print 'Hello, world!' to the console.""" - hw = "Hello, World!" - return hw +from .sharding.sharder import Sharder +# Load a qasm circuit and parse +# ,=""=, +# c , _,{ +# /\ @ ) __ +# / ^~~^\ <=.,__/ '}= +# (_/ ,, ,,) \_ _>_/~ +# ~\_(/-\)'-,_,_,_,-'(_)-(_) +circuit: Circuit = circuit_from_qasm("tests/data/qasm/baby.qasm") -print(hello_world()) # noqa: T201 +# https://cqcl.github.io/tket/pytket/api/circuit_class.html + +# Just a little debugging fun +print("Input circuit:") +print(circuit) +print() + +sharding_output = Sharder(circuit).shard() + +print("Sharding output:") +print(sharding_output) diff --git a/pytket/__init__.py b/pytket/phir/sharding/__init__.py similarity index 100% rename from pytket/__init__.py rename to pytket/phir/sharding/__init__.py diff --git a/pytket/phir/sharding/shard.py b/pytket/phir/sharding/shard.py new file mode 100644 index 0000000..0ff4cb8 --- /dev/null +++ b/pytket/phir/sharding/shard.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass + +from pytket.circuit import Command +from pytket.unit_id import UnitID + + +@dataclass +class Shard: + """ + A shard is a logical grouping of operations that represents the unit by which + we actually do placement of qubits + """ + + # The schedulable command of the shard + primary_command: Command + + # The other commands related to the primary schedulable command, stored + # as a map of bit-handle (unitID) -> list[Command] + sub_commands: dict[UnitID, list[Command]] + + # A set of the other shards this particular shard depends upon, and thus + # must be scheduled after + depends_upon: set["Shard"] diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py new file mode 100644 index 0000000..b3e9cf1 --- /dev/null +++ b/pytket/phir/sharding/sharder.py @@ -0,0 +1,81 @@ +from pytket.circuit import Circuit, Command, Op, OpType +from pytket.unit_id import UnitID + +from .shard import Shard + +NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox, OpType.WASM] + + +class Sharder: + """ + The sharder class is responsible for taking in a circuit in TKET representation + and converting it into shards that can be subsequently handled in the + compilation pipeline. + """ + + def __init__(self, circuit: Circuit) -> None: + self._circuit = circuit + print(f"Sharder created for circuit {self._circuit}") + self._pending_commands: dict[UnitID, list[Command]] = {} + self._shards: list[Shard] = [] + + def shard(self) -> list[Shard]: + """ + Performs the sharding algorithm on the circuit the Sharder was initialized + with, returning the list of Shards needed to schedule + """ + print("Sharding begins....") + # https://cqcl.github.io/tket/pytket/api/circuit.html#pytket.circuit.Command + for command in self._circuit.get_commands(): + self._process_command(command) + return self._shards + + def _process_command(self, command: Command) -> None: + """ + Handles a given TKET command (operation, bits, etc) according to the type + and the extant context within the Sharder + """ + print("Processing command: ", command.op, command.op.type, command.args) + if command.op.type in NOT_IMPLEMENTED_OP_TYPES: + msg = f"OpType {command.op.type} not supported!" + raise NotImplementedError(msg) + + if self.should_op_create_shard(command.op): + print(f"Building shard for command: {command}") + self._build_shard(command) + else: + self._add_pending_command(command) + + def _build_shard(self, command: Command) -> None: + """ + Creates a Shard object given the extant sharding context and the schedulable + Command object passed in, and appends it to the Shard list + """ + shard = Shard(command, self._pending_commands, set()) + # TODO: Dependencies! + self._pending_commands = {} + self._shards.append(shard) + print("Appended shard:", shard) + + def _add_pending_command(self, command: Command) -> None: + """ + Adds a pending sub command to the buffer to be flushed when a schedulable + operation creates a Shard. + """ + # TODO: Need to make sure 'args[0]' is the right key to use. + if command.args[0] not in self._pending_commands: + self._pending_commands[command.args[0]] = [] + self._pending_commands[command.args[0]].append(command) + + @staticmethod + def should_op_create_shard(op: Op) -> bool: + """ + Returns `True` if the operation is one that should result in shard creation. + This includes non-gate operations like measure/reset as well as 2-qubit gates. + """ + # TODO: This is almost certainly inadequate right now + return ( + op.type == OpType.Measure + or op.type == OpType.Reset + or (op.is_gate() and op.n_qubits > 1) + ) diff --git a/pytket/phir/utils.py b/pytket/phir/utils.py deleted file mode 100644 index 4114ecb..0000000 --- a/pytket/phir/utils.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Utility functions for the pytket-phir package.""" - - -def add_numbers(a: int, b: int) -> int: - """Add two numbers and returns the result.""" - return a + b diff --git a/requirements.txt b/requirements.txt index 214e0e8..c602e38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ pytest==7.4.2 ruff==0.0.291 sphinx==7.2.6 wheel==0.41.2 +pytket==1.20.1 diff --git a/ruff.toml b/ruff.toml index 7dc820d..c66ae97 100644 --- a/ruff.toml +++ b/ruff.toml @@ -42,11 +42,19 @@ select = [ "YTT", # flake8-2020 ] -ignore = [] +ignore = [ + "T201", # no prints flake-8 + "FIX002", # Allow todos + "TD002", # Allow no author todos + "TD003", # allow todos with no issues +] [per-file-ignores] "__init__.py" = ["F401"] # module imported but unused -"tests/*" = ["S101"] # Use of `assert` detected +"tests/*" = [ + "S101", # Use of `assert` detected + "PLR2004" # Magic constants + ] [pydocstyle] convention = "google" diff --git a/tests/data/qasm/baby.qasm b/tests/data/qasm/baby.qasm new file mode 100644 index 0000000..72a521b --- /dev/null +++ b/tests/data/qasm/baby.qasm @@ -0,0 +1,11 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[2]; +creg c[2]; + +h q[0]; +h q[1]; +CX q[0], q[1]; + +measure q->c; diff --git a/tests/data/qasm/bv_n10.qasm b/tests/data/qasm/bv_n10.qasm new file mode 100644 index 0000000..63775cc --- /dev/null +++ b/tests/data/qasm/bv_n10.qasm @@ -0,0 +1,47 @@ +//@author Raymond Harry Rudy rudyhar@jp.ibm.com +//Bernstein-Vazirani with 10 qubits. +//Hidden string is 111111111 +OPENQASM 2.0; +include "qelib1.inc"; +qreg qr[10]; +creg cr[9]; +h qr[0]; +h qr[1]; +h qr[2]; +h qr[3]; +h qr[4]; +h qr[5]; +h qr[6]; +h qr[7]; +h qr[8]; +x qr[9]; +h qr[9]; +barrier qr[0],qr[1],qr[2],qr[3],qr[4],qr[5],qr[6],qr[7],qr[8],qr[9]; +cx qr[0],qr[9]; +cx qr[1],qr[9]; +cx qr[2],qr[9]; +cx qr[3],qr[9]; +cx qr[4],qr[9]; +cx qr[5],qr[9]; +cx qr[6],qr[9]; +cx qr[7],qr[9]; +cx qr[8],qr[9]; +barrier qr[0],qr[1],qr[2],qr[3],qr[4],qr[5],qr[6],qr[7],qr[8],qr[9]; +h qr[0]; +h qr[1]; +h qr[2]; +h qr[3]; +h qr[4]; +h qr[5]; +h qr[6]; +h qr[7]; +h qr[8]; +measure qr[0] -> cr[0]; +measure qr[1] -> cr[1]; +measure qr[2] -> cr[2]; +measure qr[3] -> cr[3]; +measure qr[4] -> cr[4]; +measure qr[5] -> cr[5]; +measure qr[6] -> cr[6]; +measure qr[7] -> cr[7]; +measure qr[8] -> cr[8]; diff --git a/tests/data/qasm/cond_1.qasm b/tests/data/qasm/cond_1.qasm new file mode 100644 index 0000000..6dfe4a5 --- /dev/null +++ b/tests/data/qasm/cond_1.qasm @@ -0,0 +1,16 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[1]; +creg c[2]; + +h q; +measure q[0]->c[0]; +reset q; +if (c==1) h q; +if (c<1) h q; +if (c>1) h q; +if (c<=1) h q; +if (c>=1) h q; +if (c!=1) h q; +measure q[0]->c[0]; diff --git a/tests/data/qasm/simple.qasm b/tests/data/qasm/simple.qasm new file mode 100644 index 0000000..9ff7b72 --- /dev/null +++ b/tests/data/qasm/simple.qasm @@ -0,0 +1,11 @@ +OPENQASM 2.0; +include "qelib1.inc"; +qreg q[1]; +creg c[1]; + +h q[0]; +h q[0]; +h q[0]; +h q[0]; +h q[0]; +measure q->c; diff --git a/tests/sample_data.py b/tests/sample_data.py new file mode 100644 index 0000000..73487d1 --- /dev/null +++ b/tests/sample_data.py @@ -0,0 +1,15 @@ +from enum import Enum + +from pytket.circuit import Circuit +from pytket.qasm.qasm import circuit_from_qasm + + +class QasmFiles(Enum): + simple = 1 + cond_1 = 2 + bv_n10 = 3 + baby = 4 + + +def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit: + return circuit_from_qasm(f"tests/data/qasm/{qasm_file.name}.qasm") diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 9d673cb..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Tests for pytket-phir.main module.""" - -from pytket.phir.main import hello_world - - -def test_hello_world(): - """Test the hello_world function.""" - assert hello_world() == "Hello, World!" diff --git a/tests/test_sharder.py b/tests/test_sharder.py new file mode 100644 index 0000000..7a6c533 --- /dev/null +++ b/tests/test_sharder.py @@ -0,0 +1,13 @@ +from pytket.phir.sharding.sharder import Sharder + +from .sample_data import QasmFiles, get_qasm_as_circuit + + +class TestSharder: + def test_ctor(self) -> None: + sharder = Sharder(get_qasm_as_circuit(QasmFiles.baby)) + assert sharder is not None + + output = sharder.shard() + + assert len(output) == 3 diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index d181a80..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Tests for pytket-phir.utils.""" - -from pytket.phir.utils import add_numbers - - -def test_add(): - """Test the add function.""" - assert add_numbers(2, 3) == 5 # noqa: PLR2004 - assert add_numbers(0, 0) == 0 - assert add_numbers(-1, 1) == 0