From 30a63c63608abd9900e5dae3316b44644eacf028 Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Fri, 10 Nov 2023 10:48:17 -0700 Subject: [PATCH] a try on improving --- pytket/phir/sharding/sharder.py | 111 ++++++++++++++++++++++++-------- tests/test_sharder.py | 14 ++-- 2 files changed, 90 insertions(+), 35 deletions(-) diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index a208c84..f8aeb06 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -1,8 +1,9 @@ import logging +from collections import OrderedDict from typing import cast from pytket.circuit import Circuit, Command, Conditional, Op, OpType -from pytket.unit_id import Bit, UnitID +from pytket.unit_id import Bit, Qubit, UnitID from .shard import Shard @@ -21,13 +22,6 @@ logger = logging.getLogger(__name__) -def _is_command_global_phase(command: Command) -> bool: - return command.op.type == OpType.Phase or ( - command.op.type == OpType.Conditional - and cast(Conditional, command.op).op.type == OpType.Phase - ) - - class Sharder: """The Sharder class. @@ -45,7 +39,7 @@ def __init__(self, circuit: Circuit) -> None: """ self._circuit = circuit self._pending_commands: dict[UnitID, list[Command]] = {} - self._shards: list[Shard] = [] + self._shards: dict[int, Shard] = OrderedDict() logger.debug("Sharder created for circuit %s", self._circuit) def shard(self) -> list[Shard]: @@ -70,7 +64,7 @@ def shard(self) -> list[Shard]: logger.debug("Shard output:") for shard in self._shards: logger.debug(shard) - return self._shards + return list(self._shards.values()) def _process_command(self, command: Command) -> None: """Handles a command per the type and the extant context within the Sharder. @@ -88,7 +82,7 @@ def _process_command(self, command: Command) -> None: msg = f"OpType {command.op.type} not supported!" raise NotImplementedError(msg) - if _is_command_global_phase(command): + if self._is_command_global_phase(command): logger.debug("Ignoring global Phase gate") return @@ -128,14 +122,56 @@ def _build_shard(self, command: Command) -> None: set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc, arg-type] ) - # Handle dependency calculations + depends_upon = self._resolve_shard_dependencies( + qubits_used, bits_written, bits_read + ) + + shard = Shard( + command, + sub_commands, + qubits_used, + bits_written, + bits_read, + depends_upon, + ) + self._shards[shard.ID] = shard + logger.debug("Appended shard: %s", shard) + + def _resolve_shard_dependencies( + self, qubits: set[Qubit], bits_written: set[Bit], bits_read: set[Bit] + ) -> set[int]: + """Finds the dependent shards for a given shard. + + This involves checking for qubit interaction and classical hazards of + various types. + + Args: + shard: Shard to run dependency calculation on + qubits: Set of all qubits interacted with in the command/sub-commands + bits_written: Classical bits the command/sub-commands write to + bits_read: Classical bits the command/sub-commands read from + """ + logger.debug( + "Resolving shard dependencies with qubits=%s bits_written=%s bits_read=%s", + qubits, + bits_written, + bits_read, + ) + depends_upon: set[int] = set() - for shard in self._shards: + shards_to_check: list[Shard] = list(reversed(self._shards.values())) + shards_to_skip: set[int] = set() + + for shard in shards_to_check: + if shard.ID in shards_to_skip: + continue + add_dependency = False + # Check qubit dependencies (R/W implicitly) since all commands # on a given qubit need to be ordered as the circuit dictated - if not shard.qubits_used.isdisjoint(command.qubits): + if not shard.qubits_used.isdisjoint(qubits): logger.debug("...adding shard dep %s -> qubit overlap", shard.ID) - depends_upon.add(shard.ID) + add_dependency = True # Check classical dependencies, which depend on writing and reading # hazards: RAW, WAW, WAR # NOTE: bits_read will include bits_written in the current impl @@ -144,28 +180,35 @@ def _build_shard(self, command: Command) -> None: # by looking at overlap of bits_written elif not shard.bits_written.isdisjoint(bits_written): logger.debug("...adding shard dep %s -> WAW", shard.ID) - depends_upon.add(shard.ID) + add_dependency = True # Check for read-after-write (value seen would change if reordered) elif not shard.bits_written.isdisjoint(bits_read): logger.debug("...adding shard dep %s -> RAW", shard.ID) - depends_upon.add(shard.ID) + add_dependency = True # Check for write-after-read (no reordering or read is changed) - elif not shard.bits_written.isdisjoint(bits_read): + elif not shard.bits_read.isdisjoint(bits_written): logger.debug("...adding shard dep %s -> WAR", shard.ID) + add_dependency = True + + if add_dependency: depends_upon.add(shard.ID) + # Avoid recalculating for anything previous + shards_to_skip.update(self._get_dependencies(shard)) - shard = Shard( - command, - sub_commands, - qubits_used, - bits_written, - bits_read, - depends_upon, - ) - self._shards.append(shard) - logger.debug("Appended shard: %s", shard) + return depends_upon + + def _get_dependencies(self, shard: Shard) -> set[int]: + dependencies: set[int] = set() + self._get_dependencies_recursive(dependencies, shard) + return dependencies + + def _get_dependencies_recursive(self, accumulator: set[int], shard: Shard) -> None: + accumulator.update(shard.depends_upon) + for shard_id in shard.depends_upon: + dependency = self._shards[shard_id] + self._get_dependencies_recursive(accumulator, dependency) def _cleanup_remaining_commands(self) -> None: remaining_qubits = [k for k, v in self._pending_commands.items() if v] @@ -212,3 +255,15 @@ def should_op_create_shard(op: Op) -> bool: ) or (op.is_gate() and op.n_qubits > 1) ) + + @staticmethod + def _is_command_global_phase(command: Command) -> bool: + """Check if an operation related to global phase. + + Args: + command: Command to evaluate + """ + return command.op.type == OpType.Phase or ( + command.op.type == OpType.Conditional + and cast(Conditional, command.op).op.type == OpType.Phase + ) diff --git a/tests/test_sharder.py b/tests/test_sharder.py index 7c7c55d..b65e1cc 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -85,11 +85,11 @@ def test_rollup_behavior(self) -> None: assert shards[3].primary_command.op.type == OpType.Barrier assert len(shards[3].sub_commands) == 1 - assert shards[3].depends_upon == {shards[0].ID, shards[1].ID} + assert shards[3].depends_upon == {shards[1].ID} assert shards[4].primary_command.op.type == OpType.Barrier assert len(shards[4].sub_commands) == 1 - assert shards[4].depends_upon == {shards[0].ID, shards[2].ID} + assert shards[4].depends_upon == {shards[2].ID} def test_simple_conditional(self) -> None: circuit = get_qasm_as_circuit(QasmFile.simple_cond) @@ -130,7 +130,7 @@ def test_simple_conditional(self) -> None: assert shards[3].qubits_used == {circuit.qubits[0]} assert shards[3].bits_written == {circuit.bits[0]} assert shards[3].bits_read == {circuit.bits[0]} - assert shards[3].depends_upon == {shards[0].ID, shards[1].ID} + assert shards[3].depends_upon == {shards[1].ID, shards[2].ID} assert len(shards[3].sub_commands.items()) == 1 s2_qubit, s2_sub_cmds = next(iter(shards[3].sub_commands.items())) assert s2_qubit == circuit.qubits[0] @@ -183,7 +183,7 @@ def test_complex_barriers(self) -> None: # noqa: PLR0915 assert shards[3].qubits_used == {circuit.qubits[0]} assert shards[3].bits_written == {circuit.bits[0]} assert shards[3].bits_read == {circuit.bits[0]} - assert shards[3].depends_upon == {shards[2].ID, shards[1].ID} + assert shards[3].depends_upon == {shards[2].ID} # shard 4: [H q[3];, H q[3];, X q[3];]] barrier q[2], q[3]; assert shards[4].primary_command.op.type == OpType.Barrier @@ -221,7 +221,7 @@ def test_complex_barriers(self) -> None: # noqa: PLR0915 assert shards[6].qubits_used == {circuit.qubits[2]} assert shards[6].bits_written == {circuit.bits[2]} assert shards[6].bits_read == {circuit.bits[2]} - assert shards[6].depends_upon == {shards[5].ID, shards[4].ID} + assert shards[6].depends_upon == {shards[5].ID} def test_classical_hazards(self) -> None: circuit = get_qasm_as_circuit(QasmFile.classical_hazards) @@ -261,7 +261,7 @@ def test_classical_hazards(self) -> None: assert shards[3].qubits_used == set() assert shards[3].bits_written == {circuit.bits[0]} assert shards[3].bits_read == {circuit.bits[0]} - assert shards[3].depends_upon == {shards[0].ID} + assert shards[3].depends_upon == {shards[2].ID} # shard 4: [] if(c[2]==1) c[0]=1; assert shards[4].primary_command.op.type == OpType.Conditional @@ -269,7 +269,7 @@ def test_classical_hazards(self) -> None: assert shards[4].qubits_used == set() assert shards[4].bits_written == {circuit.bits[0]} assert shards[4].bits_read == {circuit.bits[0], circuit.bits[2]} - assert shards[4].depends_upon == {shards[1].ID, shards[0].ID, shards[3].ID} + assert shards[4].depends_upon == {shards[1].ID, shards[3].ID} def test_with_big_gate(self) -> None: circuit = get_qasm_as_circuit(QasmFile.big_gate)