-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
A try on improving sharding dependency resolution efficiency #32
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
import logging | ||
from collections import OrderedDict | ||
from typing import cast | ||
|
||
from pytket.circuit import Circuit, Command, Conditional, Op, OpType | ||
from pytket.unit_id import Bit, UnitID | ||
from pytket.unit_id import Bit, Qubit, UnitID | ||
|
||
from .shard import Shard | ||
|
||
|
@@ -21,13 +22,6 @@ | |
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _is_command_global_phase(command: Command) -> bool: | ||
return command.op.type == OpType.Phase or ( | ||
command.op.type == OpType.Conditional | ||
and cast(Conditional, command.op).op.type == OpType.Phase | ||
) | ||
|
||
|
||
class Sharder: | ||
"""The Sharder class. | ||
|
||
|
@@ -45,7 +39,7 @@ def __init__(self, circuit: Circuit) -> None: | |
""" | ||
self._circuit = circuit | ||
self._pending_commands: dict[UnitID, list[Command]] = {} | ||
self._shards: list[Shard] = [] | ||
self._shards: dict[int, Shard] = OrderedDict() | ||
logger.debug("Sharder created for circuit %s", self._circuit) | ||
|
||
def shard(self) -> list[Shard]: | ||
|
@@ -70,7 +64,7 @@ def shard(self) -> list[Shard]: | |
logger.debug("Shard output:") | ||
for shard in self._shards: | ||
logger.debug(shard) | ||
return self._shards | ||
return list(self._shards.values()) | ||
|
||
def _process_command(self, command: Command) -> None: | ||
"""Handles a command per the type and the extant context within the Sharder. | ||
|
@@ -88,7 +82,7 @@ def _process_command(self, command: Command) -> None: | |
msg = f"OpType {command.op.type} not supported!" | ||
raise NotImplementedError(msg) | ||
|
||
if _is_command_global_phase(command): | ||
if self._is_command_global_phase(command): | ||
logger.debug("Ignoring global Phase gate") | ||
return | ||
|
||
|
@@ -128,14 +122,56 @@ def _build_shard(self, command: Command) -> None: | |
set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc, arg-type] | ||
) | ||
|
||
# Handle dependency calculations | ||
depends_upon = self._resolve_shard_dependencies( | ||
qubits_used, bits_written, bits_read | ||
) | ||
|
||
shard = Shard( | ||
command, | ||
sub_commands, | ||
qubits_used, | ||
bits_written, | ||
bits_read, | ||
depends_upon, | ||
) | ||
self._shards[shard.ID] = shard | ||
logger.debug("Appended shard: %s", shard) | ||
|
||
def _resolve_shard_dependencies( | ||
self, qubits: set[Qubit], bits_written: set[Bit], bits_read: set[Bit] | ||
) -> set[int]: | ||
"""Finds the dependent shards for a given shard. | ||
|
||
This involves checking for qubit interaction and classical hazards of | ||
various types. | ||
|
||
Args: | ||
shard: Shard to run dependency calculation on | ||
qubits: Set of all qubits interacted with in the command/sub-commands | ||
bits_written: Classical bits the command/sub-commands write to | ||
bits_read: Classical bits the command/sub-commands read from | ||
""" | ||
logger.debug( | ||
"Resolving shard dependencies with qubits=%s bits_written=%s bits_read=%s", | ||
qubits, | ||
bits_written, | ||
bits_read, | ||
) | ||
|
||
depends_upon: set[int] = set() | ||
for shard in self._shards: | ||
shards_to_check: list[Shard] = list(reversed(self._shards.values())) | ||
shards_to_skip: set[int] = set() | ||
|
||
for shard in shards_to_check: | ||
if shard.ID in shards_to_skip: | ||
continue | ||
add_dependency = False | ||
|
||
# Check qubit dependencies (R/W implicitly) since all commands | ||
# on a given qubit need to be ordered as the circuit dictated | ||
if not shard.qubits_used.isdisjoint(command.qubits): | ||
if not shard.qubits_used.isdisjoint(qubits): | ||
logger.debug("...adding shard dep %s -> qubit overlap", shard.ID) | ||
depends_upon.add(shard.ID) | ||
add_dependency = True | ||
# Check classical dependencies, which depend on writing and reading | ||
# hazards: RAW, WAW, WAR | ||
# NOTE: bits_read will include bits_written in the current impl | ||
|
@@ -144,28 +180,35 @@ def _build_shard(self, command: Command) -> None: | |
# by looking at overlap of bits_written | ||
elif not shard.bits_written.isdisjoint(bits_written): | ||
logger.debug("...adding shard dep %s -> WAW", shard.ID) | ||
depends_upon.add(shard.ID) | ||
add_dependency = True | ||
|
||
# Check for read-after-write (value seen would change if reordered) | ||
elif not shard.bits_written.isdisjoint(bits_read): | ||
logger.debug("...adding shard dep %s -> RAW", shard.ID) | ||
depends_upon.add(shard.ID) | ||
add_dependency = True | ||
|
||
# Check for write-after-read (no reordering or read is changed) | ||
elif not shard.bits_written.isdisjoint(bits_read): | ||
elif not shard.bits_read.isdisjoint(bits_written): | ||
logger.debug("...adding shard dep %s -> WAR", shard.ID) | ||
add_dependency = True | ||
|
||
if add_dependency: | ||
depends_upon.add(shard.ID) | ||
# Avoid recalculating for anything previous | ||
shards_to_skip.update(self._get_dependencies(shard)) | ||
|
||
shard = Shard( | ||
command, | ||
sub_commands, | ||
qubits_used, | ||
bits_written, | ||
bits_read, | ||
depends_upon, | ||
) | ||
self._shards.append(shard) | ||
logger.debug("Appended shard: %s", shard) | ||
return depends_upon | ||
|
||
def _get_dependencies(self, shard: Shard) -> set[int]: | ||
dependencies: set[int] = set() | ||
self._get_dependencies_recursive(dependencies, shard) | ||
return dependencies | ||
|
||
def _get_dependencies_recursive(self, accumulator: set[int], shard: Shard) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe I'm overthinking this, but if we're trying to improve performance--then is there any way to replace this with a data-structure like a stack/queue and some sort of while-loop? Or is the overhead of the recursion here minimal? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, my comment above may have been confusing, currently the changes in this PR lead to significant overhead and hence it is not ready to be merged. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not 100% sure but since this is clearly not a slam dunk I'm going to close it and try a more significant refactor that should reduce to linear time. |
||
accumulator.update(shard.depends_upon) | ||
for shard_id in shard.depends_upon: | ||
dependency = self._shards[shard_id] | ||
self._get_dependencies_recursive(accumulator, dependency) | ||
|
||
def _cleanup_remaining_commands(self) -> None: | ||
remaining_qubits = [k for k, v in self._pending_commands.items() if v] | ||
|
@@ -212,3 +255,15 @@ def should_op_create_shard(op: Op) -> bool: | |
) | ||
or (op.is_gate() and op.n_qubits > 1) | ||
) | ||
|
||
@staticmethod | ||
def _is_command_global_phase(command: Command) -> bool: | ||
"""Check if an operation related to global phase. | ||
|
||
Args: | ||
command: Command to evaluate | ||
""" | ||
return command.op.type == OpType.Phase or ( | ||
command.op.type == OpType.Conditional | ||
and cast(Conditional, command.op).op.type == OpType.Phase | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems this avoiding recalculation is the main thing to help performance?