From e5e09f15a8da0a34056e93b2135d6e7471f36d41 Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Sun, 1 Oct 2023 19:45:18 -0600 Subject: [PATCH] Improving shard handling and testing --- pyproject.toml | 3 + pytket/phir/sharding/shard.py | 65 ++++++++++++++-- pytket/phir/sharding/sharder.py | 59 ++++++++++++--- ruff.toml | 2 +- tests/data/qasm/baby_with_rollup.qasm | 15 ++++ tests/data/qasm/simple_cond.qasm | 11 +++ tests/sample_data.py | 2 + tests/test_shard.py | 46 ++++++++++++ tests/test_sharder.py | 103 ++++++++++++++++++++++++-- 9 files changed, 282 insertions(+), 24 deletions(-) create mode 100644 tests/data/qasm/baby_with_rollup.qasm create mode 100644 tests/data/qasm/simple_cond.qasm create mode 100644 tests/test_shard.py diff --git a/pyproject.toml b/pyproject.toml index ca5665b..fa40d33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ Repository = "https://github.com/CQCL/pytket-phir.git" where = ["."] [tool.pytest.ini_options] +addopts = "-s -vv" pythonpath = [ "." ] +log_cli = true +filterwarnings = ["ignore:::lark.s*"] diff --git a/pytket/phir/sharding/shard.py b/pytket/phir/sharding/shard.py index 0ff4cb8..facf98d 100644 --- a/pytket/phir/sharding/shard.py +++ b/pytket/phir/sharding/shard.py @@ -1,23 +1,74 @@ -from dataclasses import dataclass +import io +from dataclasses import dataclass, field +from itertools import count from pytket.circuit import Command -from pytket.unit_id import UnitID +from pytket.unit_id import Bit, Qubit, UnitID -@dataclass +@dataclass(unsafe_hash=False) class Shard: """ A shard is a logical grouping of operations that represents the unit by which we actually do placement of qubits """ - # The schedulable command of the shard + # The "schedulable" command of the shard primary_command: Command # The other commands related to the primary schedulable command, stored # as a map of bit-handle (unitID) -> list[Command] sub_commands: dict[UnitID, list[Command]] - # A set of the other shards this particular shard depends upon, and thus - # must be scheduled after - depends_upon: set["Shard"] + # A set of the identifiers of other shards this particular shard depends upon + depends_upon: set[int] + + # All qubits used by the primary and sub commands + qubits_used: set[Qubit] = field(init=False) + + # Set of all classical bits written to by the primary and sub commands + bits_written: set[Bit] = field(init=False) + + # Set of all classical bits read by the primary and sub commands + bits_read: set[Bit] = field(init=False) + + # The unique identifier of the shard + ID: int = field(default_factory=count().__next__, init=False) + + def __post_init__(self) -> None: + self.qubits_used = set(self.primary_command.qubits) + self.bits_written = set(self.primary_command.bits) + self.bits_read = set() + + all_sub_commands: list[Command] = [] + for sub_commands in self.sub_commands.values(): + all_sub_commands.extend(sub_commands) + + for sub_command in all_sub_commands: + self.bits_written.update(sub_command.bits) + self.bits_read.update( + set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore # noqa: PGH003 + ) + + def pretty_print(self) -> str: + output = io.StringIO() + output.write(f"Shard {self.ID}:") + output.write(f"\n Command: {self.primary_command}") + output.write( + f'\n Qubits used: [{", ".join(repr(x) for x in self.qubits_used)}]', + ) + output.write("\n Sub commands: ") + if not self.sub_commands: + output.write("none") + for sub in self.sub_commands: + output.write(f"\n {sub}: {self.sub_commands[sub]}") + output.write(f"\n Qubits used: {self.qubits_used}") + output.write(f"\n Bits written: {self.bits_written}") + output.write(f"\n Bits read: {self.bits_read}") + output.write("\n Depends upon shards: ") + if not self.depends_upon: + output.write("none") + output.write(", ".join(map(repr, self.depends_upon))) + content = output.getvalue() + output.close() + return content diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index b3e9cf1..e40d3b3 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -15,9 +15,9 @@ class Sharder: def __init__(self, circuit: Circuit) -> None: self._circuit = circuit - print(f"Sharder created for circuit {self._circuit}") self._pending_commands: dict[UnitID, list[Command]] = {} self._shards: list[Shard] = [] + print(f"Sharder created for circuit {self._circuit}") def shard(self) -> list[Shard]: """ @@ -28,6 +28,11 @@ def shard(self) -> list[Shard]: # https://cqcl.github.io/tket/pytket/api/circuit.html#pytket.circuit.Command for command in self._circuit.get_commands(): self._process_command(command) + self._cleanup_remaining_commands() + + print("Shard output:") + for shard in self._shards: + print(shard.pretty_print()) return self._shards def _process_command(self, command: Command) -> None: @@ -44,38 +49,70 @@ def _process_command(self, command: Command) -> None: print(f"Building shard for command: {command}") self._build_shard(command) else: - self._add_pending_command(command) + self._add_pending_sub_command(command) def _build_shard(self, command: Command) -> None: """ Creates a Shard object given the extant sharding context and the schedulable Command object passed in, and appends it to the Shard list """ - shard = Shard(command, self._pending_commands, set()) - # TODO: Dependencies! - self._pending_commands = {} + # Resolve any sub commands (SQ gates) that interact with the same qubits + sub_commands: dict[UnitID, list[Command]] = {} + for key in ( + key for key in list(self._pending_commands) if key in command.qubits + ): + sub_commands[key] = self._pending_commands.pop(key) + + # Handle dependency calculations + depends_upon: set[int] = set() + for shard in self._shards: + # Check qubit dependencies (R/W implicitly) since all commands on a given qubit + # need to be ordered as the circuit dictated + if not shard.qubits_used.isdisjoint(command.qubits): + depends_upon.add(shard.ID) + # Check classical dependencies, which depend on writing and reading + # hazards: RAW, WAW, WAR + # TODO: Do it! + + shard = Shard(command, sub_commands, depends_upon) self._shards.append(shard) print("Appended shard:", shard) - def _add_pending_command(self, command: Command) -> None: + def _cleanup_remaining_commands(self) -> None: + """ + Checks for any remaining "unsharded" commands, and if found, adds them + to Barrier op shards for each qubit + """ + remaining_qubits = [k for k, v in self._pending_commands.items() if len(v) > 0] + for qubit in remaining_qubits: + self._circuit.add_barrier([qubit]) + # Easiest way to get to a command, since there's no constructor. Could + # create an entire orphan circuit with the matching qubits and the barrier + # instead if this has unintended consequences + barrier_command = self._circuit.get_commands()[-1] + self._build_shard(barrier_command) + + def _add_pending_sub_command(self, command: Command) -> None: """ Adds a pending sub command to the buffer to be flushed when a schedulable operation creates a Shard. """ - # TODO: Need to make sure 'args[0]' is the right key to use. - if command.args[0] not in self._pending_commands: - self._pending_commands[command.args[0]] = [] - self._pending_commands[command.args[0]].append(command) + key = command.qubits[0] + if key not in self._pending_commands: + self._pending_commands[key] = [] + self._pending_commands[key].append(command) + print(f"Adding pending command {command}") @staticmethod def should_op_create_shard(op: Op) -> bool: """ Returns `True` if the operation is one that should result in shard creation. This includes non-gate operations like measure/reset as well as 2-qubit gates. + TODO: This is almost certainly inadequate right now """ - # TODO: This is almost certainly inadequate right now return ( op.type == OpType.Measure or op.type == OpType.Reset + or op.type == OpType.Barrier or (op.is_gate() and op.n_qubits > 1) ) diff --git a/ruff.toml b/ruff.toml index c66ae97..27cb084 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,6 +1,6 @@ target-version = "py310" -line-length = 88 +line-length = 120 select = [ "E", # pycodestyle Errors diff --git a/tests/data/qasm/baby_with_rollup.qasm b/tests/data/qasm/baby_with_rollup.qasm new file mode 100644 index 0000000..b2a81fc --- /dev/null +++ b/tests/data/qasm/baby_with_rollup.qasm @@ -0,0 +1,15 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[2]; +creg c[2]; + +h q[0]; +h q[1]; +CX q[0], q[1]; + +measure q->c; + +h q[0]; + +h q[1]; diff --git a/tests/data/qasm/simple_cond.qasm b/tests/data/qasm/simple_cond.qasm new file mode 100644 index 0000000..d590bb6 --- /dev/null +++ b/tests/data/qasm/simple_cond.qasm @@ -0,0 +1,11 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[1]; +creg c[1]; + +h q; +measure q->c; +reset q; +if (c==1) h q; +measure q->c; diff --git a/tests/sample_data.py b/tests/sample_data.py index 73487d1..4e7d3c5 100644 --- a/tests/sample_data.py +++ b/tests/sample_data.py @@ -9,6 +9,8 @@ class QasmFiles(Enum): cond_1 = 2 bv_n10 = 3 baby = 4 + baby_with_rollup = 5 + simple_cond = 6 def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit: diff --git a/tests/test_shard.py b/tests/test_shard.py new file mode 100644 index 0000000..11008b2 --- /dev/null +++ b/tests/test_shard.py @@ -0,0 +1,46 @@ +from pytket.circuit import Circuit +from pytket.phir.sharding.shard import Shard + +EMPTY_INT_SET: set[int] = set() + + +class TestShard: + def test_shard_ctor(self) -> None: + circ = Circuit(4) # qubits are numbered 0-3 + circ.X(0) # first apply an X gate to qubit 0 + circ.CX(1, 3) # and apply a CX gate with control qubit 1 and target qubit 3 + circ.Z(3) # then apply a Z gate to qubit 3 + commands = circ.get_commands() + + shard = Shard( + commands[1], + {commands[0].qubits[0]: [commands[0]]}, + EMPTY_INT_SET, + ) + + assert shard.primary_command == commands[1] + assert shard.depends_upon == EMPTY_INT_SET + sub_command_key, sub_command_value = next(iter(shard.sub_commands.items())) + assert sub_command_key == commands[0].qubits[0] + assert sub_command_value[0] == commands[0] + + def test_shard_ctor_conditional(self) -> None: + circuit = Circuit(4, 4) + circuit.H(0) + circuit.Measure(0, 0) + circuit.X(1, condition_bits=[0], condition_value=1) # type: ignore # noqa: PGH003 + circuit.Measure(1, 1) # The command we'll build the shard from + commands = circuit.get_commands() + + shard = Shard( + commands[3], + { + circuit.qubits[0]: [commands[2]], + }, + EMPTY_INT_SET, + ) + + assert len(shard.sub_commands.items()) + assert shard.qubits_used == {circuit.qubits[1]} + assert shard.bits_read == {circuit.bits[0]} + assert shard.bits_written == {circuit.bits[1]} diff --git a/tests/test_sharder.py b/tests/test_sharder.py index 7a6c533..82011d0 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -1,13 +1,106 @@ +from typing import cast + +from pytket.circuit import Conditional, Op, OpType from pytket.phir.sharding.sharder import Sharder from .sample_data import QasmFiles, get_qasm_as_circuit class TestSharder: - def test_ctor(self) -> None: - sharder = Sharder(get_qasm_as_circuit(QasmFiles.baby)) - assert sharder is not None + def test_should_op_create_shard(self) -> None: + expected_true: list[Op] = [ + Op.create(OpType.Measure), # type: ignore # noqa: PGH003 + Op.create(OpType.Reset), # type: ignore # noqa: PGH003 + Op.create(OpType.CX), # type: ignore # noqa: PGH003 + Op.create(OpType.Barrier), # type: ignore # noqa: PGH003 + ] + expected_false: list[Op] = [ + Op.create(OpType.U1, 0.32), # type: ignore # noqa: PGH003 + Op.create(OpType.H), # type: ignore # noqa: PGH003 + Op.create(OpType.Z), # type: ignore # noqa: PGH003 + ] + + for op in expected_true: + assert Sharder.should_op_create_shard(op) + for op in expected_false: + assert not Sharder.should_op_create_shard(op) + + def test_with_baby_circuit(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.baby) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 3 + + assert shards[0].primary_command.op.type == OpType.CX + assert len(shards[0].primary_command.qubits) == 2 + assert not shards[0].primary_command.bits + assert len(shards[0].sub_commands) == 2 + sub_commands = list(shards[0].sub_commands.items()) + print(sub_commands) + assert sub_commands[0][1][0].op.type == OpType.H + assert len(shards[0].depends_upon) == 0 + + assert shards[1].primary_command.op.type == OpType.Measure + assert len(shards[1].sub_commands) == 0 + assert shards[1].depends_upon == {shards[0].ID} + + assert shards[2].primary_command.op.type == OpType.Measure + assert len(shards[2].sub_commands) == 0 + assert shards[2].depends_upon == {shards[0].ID} + + def test_rollup_behavior(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.baby_with_rollup) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 5 + + assert shards[0].primary_command.op.type == OpType.CX + assert len(shards[0].primary_command.qubits) == 2 + assert not shards[0].primary_command.bits + assert len(shards[0].sub_commands) == 2 + sub_commands = list(shards[0].sub_commands.items()) + print(sub_commands) + assert sub_commands[0][1][0].op.type == OpType.H + assert len(shards[0].depends_upon) == 0 + + assert shards[1].primary_command.op.type == OpType.Measure + assert len(shards[1].sub_commands) == 0 + assert shards[1].depends_upon == {shards[0].ID} + + assert shards[2].primary_command.op.type == OpType.Measure + assert len(shards[2].sub_commands) == 0 + assert shards[2].depends_upon == {shards[0].ID} + + assert shards[3].primary_command.op.type == OpType.Barrier + assert len(shards[3].sub_commands) == 1 + assert shards[3].depends_upon == {shards[0].ID, shards[1].ID} + + assert shards[4].primary_command.op.type == OpType.Barrier + assert len(shards[4].sub_commands) == 1 + assert shards[4].depends_upon == {shards[0].ID, shards[2].ID} + + def test_simple_conditional(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.simple_cond) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 3 + + assert shards[0].primary_command.op.type == OpType.Measure + assert len(shards[0].sub_commands.items()) == 1 + s0_qubit, s0_sub_cmds = next(iter(shards[0].sub_commands.items())) + assert s0_qubit == circuit.qubits[0] + assert s0_sub_cmds[0].op.type == OpType.H - output = sharder.shard() + assert shards[1].primary_command.op.type == OpType.Reset + assert len(shards[1].sub_commands.items()) == 0 - assert len(output) == 3 + assert shards[2].primary_command.op.type == OpType.Measure + assert len(shards[2].sub_commands.items()) == 1 + s2_qubit, s2_sub_cmds = next(iter(shards[2].sub_commands.items())) + assert s2_qubit == circuit.qubits[0] + assert s2_sub_cmds[0].op.type == OpType.Conditional + assert cast(Conditional, s2_sub_cmds[0].op).op.type == OpType.H + assert s2_sub_cmds[0].qubits == [circuit.qubits[0]]