Skip to content

Commit

Permalink
a try on improving
Browse files Browse the repository at this point in the history
  • Loading branch information
nealerickson-qtm committed Nov 10, 2023
1 parent 177d1bb commit 30a63c6
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 35 deletions.
111 changes: 83 additions & 28 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
from collections import OrderedDict
from typing import cast

from pytket.circuit import Circuit, Command, Conditional, Op, OpType
from pytket.unit_id import Bit, UnitID
from pytket.unit_id import Bit, Qubit, UnitID

from .shard import Shard

Expand All @@ -21,13 +22,6 @@
logger = logging.getLogger(__name__)


def _is_command_global_phase(command: Command) -> bool:
return command.op.type == OpType.Phase or (
command.op.type == OpType.Conditional
and cast(Conditional, command.op).op.type == OpType.Phase
)


class Sharder:
"""The Sharder class.
Expand All @@ -45,7 +39,7 @@ def __init__(self, circuit: Circuit) -> None:
"""
self._circuit = circuit
self._pending_commands: dict[UnitID, list[Command]] = {}
self._shards: list[Shard] = []
self._shards: dict[int, Shard] = OrderedDict()
logger.debug("Sharder created for circuit %s", self._circuit)

def shard(self) -> list[Shard]:
Expand All @@ -70,7 +64,7 @@ def shard(self) -> list[Shard]:
logger.debug("Shard output:")
for shard in self._shards:
logger.debug(shard)
return self._shards
return list(self._shards.values())

def _process_command(self, command: Command) -> None:
"""Handles a command per the type and the extant context within the Sharder.
Expand All @@ -88,7 +82,7 @@ def _process_command(self, command: Command) -> None:
msg = f"OpType {command.op.type} not supported!"
raise NotImplementedError(msg)

if _is_command_global_phase(command):
if self._is_command_global_phase(command):
logger.debug("Ignoring global Phase gate")
return

Expand Down Expand Up @@ -128,14 +122,56 @@ def _build_shard(self, command: Command) -> None:
set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc, arg-type]
)

# Handle dependency calculations
depends_upon = self._resolve_shard_dependencies(
qubits_used, bits_written, bits_read
)

shard = Shard(
command,
sub_commands,
qubits_used,
bits_written,
bits_read,
depends_upon,
)
self._shards[shard.ID] = shard
logger.debug("Appended shard: %s", shard)

def _resolve_shard_dependencies(
self, qubits: set[Qubit], bits_written: set[Bit], bits_read: set[Bit]
) -> set[int]:
"""Finds the dependent shards for a given shard.
This involves checking for qubit interaction and classical hazards of
various types.
Args:
shard: Shard to run dependency calculation on
qubits: Set of all qubits interacted with in the command/sub-commands
bits_written: Classical bits the command/sub-commands write to
bits_read: Classical bits the command/sub-commands read from
"""
logger.debug(
"Resolving shard dependencies with qubits=%s bits_written=%s bits_read=%s",
qubits,
bits_written,
bits_read,
)

depends_upon: set[int] = set()
for shard in self._shards:
shards_to_check: list[Shard] = list(reversed(self._shards.values()))
shards_to_skip: set[int] = set()

for shard in shards_to_check:
if shard.ID in shards_to_skip:
continue
add_dependency = False

# Check qubit dependencies (R/W implicitly) since all commands
# on a given qubit need to be ordered as the circuit dictated
if not shard.qubits_used.isdisjoint(command.qubits):
if not shard.qubits_used.isdisjoint(qubits):
logger.debug("...adding shard dep %s -> qubit overlap", shard.ID)
depends_upon.add(shard.ID)
add_dependency = True
# Check classical dependencies, which depend on writing and reading
# hazards: RAW, WAW, WAR
# NOTE: bits_read will include bits_written in the current impl
Expand All @@ -144,28 +180,35 @@ def _build_shard(self, command: Command) -> None:
# by looking at overlap of bits_written
elif not shard.bits_written.isdisjoint(bits_written):
logger.debug("...adding shard dep %s -> WAW", shard.ID)
depends_upon.add(shard.ID)
add_dependency = True

# Check for read-after-write (value seen would change if reordered)
elif not shard.bits_written.isdisjoint(bits_read):
logger.debug("...adding shard dep %s -> RAW", shard.ID)
depends_upon.add(shard.ID)
add_dependency = True

# Check for write-after-read (no reordering or read is changed)
elif not shard.bits_written.isdisjoint(bits_read):
elif not shard.bits_read.isdisjoint(bits_written):
logger.debug("...adding shard dep %s -> WAR", shard.ID)
add_dependency = True

if add_dependency:
depends_upon.add(shard.ID)
# Avoid recalculating for anything previous
shards_to_skip.update(self._get_dependencies(shard))

shard = Shard(
command,
sub_commands,
qubits_used,
bits_written,
bits_read,
depends_upon,
)
self._shards.append(shard)
logger.debug("Appended shard: %s", shard)
return depends_upon

def _get_dependencies(self, shard: Shard) -> set[int]:
dependencies: set[int] = set()
self._get_dependencies_recursive(dependencies, shard)
return dependencies

def _get_dependencies_recursive(self, accumulator: set[int], shard: Shard) -> None:
accumulator.update(shard.depends_upon)
for shard_id in shard.depends_upon:
dependency = self._shards[shard_id]
self._get_dependencies_recursive(accumulator, dependency)

def _cleanup_remaining_commands(self) -> None:
remaining_qubits = [k for k, v in self._pending_commands.items() if v]
Expand Down Expand Up @@ -212,3 +255,15 @@ def should_op_create_shard(op: Op) -> bool:
)
or (op.is_gate() and op.n_qubits > 1)
)

@staticmethod
def _is_command_global_phase(command: Command) -> bool:
"""Check if an operation related to global phase.
Args:
command: Command to evaluate
"""
return command.op.type == OpType.Phase or (
command.op.type == OpType.Conditional
and cast(Conditional, command.op).op.type == OpType.Phase
)
14 changes: 7 additions & 7 deletions tests/test_sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ def test_rollup_behavior(self) -> None:

assert shards[3].primary_command.op.type == OpType.Barrier
assert len(shards[3].sub_commands) == 1
assert shards[3].depends_upon == {shards[0].ID, shards[1].ID}
assert shards[3].depends_upon == {shards[1].ID}

assert shards[4].primary_command.op.type == OpType.Barrier
assert len(shards[4].sub_commands) == 1
assert shards[4].depends_upon == {shards[0].ID, shards[2].ID}
assert shards[4].depends_upon == {shards[2].ID}

def test_simple_conditional(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.simple_cond)
Expand Down Expand Up @@ -130,7 +130,7 @@ def test_simple_conditional(self) -> None:
assert shards[3].qubits_used == {circuit.qubits[0]}
assert shards[3].bits_written == {circuit.bits[0]}
assert shards[3].bits_read == {circuit.bits[0]}
assert shards[3].depends_upon == {shards[0].ID, shards[1].ID}
assert shards[3].depends_upon == {shards[1].ID, shards[2].ID}
assert len(shards[3].sub_commands.items()) == 1
s2_qubit, s2_sub_cmds = next(iter(shards[3].sub_commands.items()))
assert s2_qubit == circuit.qubits[0]
Expand Down Expand Up @@ -183,7 +183,7 @@ def test_complex_barriers(self) -> None: # noqa: PLR0915
assert shards[3].qubits_used == {circuit.qubits[0]}
assert shards[3].bits_written == {circuit.bits[0]}
assert shards[3].bits_read == {circuit.bits[0]}
assert shards[3].depends_upon == {shards[2].ID, shards[1].ID}
assert shards[3].depends_upon == {shards[2].ID}

# shard 4: [H q[3];, H q[3];, X q[3];]] barrier q[2], q[3];
assert shards[4].primary_command.op.type == OpType.Barrier
Expand Down Expand Up @@ -221,7 +221,7 @@ def test_complex_barriers(self) -> None: # noqa: PLR0915
assert shards[6].qubits_used == {circuit.qubits[2]}
assert shards[6].bits_written == {circuit.bits[2]}
assert shards[6].bits_read == {circuit.bits[2]}
assert shards[6].depends_upon == {shards[5].ID, shards[4].ID}
assert shards[6].depends_upon == {shards[5].ID}

def test_classical_hazards(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.classical_hazards)
Expand Down Expand Up @@ -261,15 +261,15 @@ def test_classical_hazards(self) -> None:
assert shards[3].qubits_used == set()
assert shards[3].bits_written == {circuit.bits[0]}
assert shards[3].bits_read == {circuit.bits[0]}
assert shards[3].depends_upon == {shards[0].ID}
assert shards[3].depends_upon == {shards[2].ID}

# shard 4: [] if(c[2]==1) c[0]=1;
assert shards[4].primary_command.op.type == OpType.Conditional
assert len(shards[4].sub_commands) == 0
assert shards[4].qubits_used == set()
assert shards[4].bits_written == {circuit.bits[0]}
assert shards[4].bits_read == {circuit.bits[0], circuit.bits[2]}
assert shards[4].depends_upon == {shards[1].ID, shards[0].ID, shards[3].ID}
assert shards[4].depends_upon == {shards[1].ID, shards[3].ID}

def test_with_big_gate(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.big_gate)
Expand Down

0 comments on commit 30a63c6

Please sign in to comment.