Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A try on improving sharding dependency resolution efficiency #32

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 83 additions & 28 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
from collections import OrderedDict
from typing import cast

from pytket.circuit import Circuit, Command, Conditional, Op, OpType
from pytket.unit_id import Bit, UnitID
from pytket.unit_id import Bit, Qubit, UnitID

from .shard import Shard

Expand All @@ -21,13 +22,6 @@
logger = logging.getLogger(__name__)


def _is_command_global_phase(command: Command) -> bool:
return command.op.type == OpType.Phase or (
command.op.type == OpType.Conditional
and cast(Conditional, command.op).op.type == OpType.Phase
)


class Sharder:
"""The Sharder class.

Expand All @@ -45,7 +39,7 @@ def __init__(self, circuit: Circuit) -> None:
"""
self._circuit = circuit
self._pending_commands: dict[UnitID, list[Command]] = {}
self._shards: list[Shard] = []
self._shards: dict[int, Shard] = OrderedDict()
logger.debug("Sharder created for circuit %s", self._circuit)

def shard(self) -> list[Shard]:
Expand All @@ -70,7 +64,7 @@ def shard(self) -> list[Shard]:
logger.debug("Shard output:")
for shard in self._shards:
logger.debug(shard)
return self._shards
return list(self._shards.values())

def _process_command(self, command: Command) -> None:
"""Handles a command per the type and the extant context within the Sharder.
Expand All @@ -88,7 +82,7 @@ def _process_command(self, command: Command) -> None:
msg = f"OpType {command.op.type} not supported!"
raise NotImplementedError(msg)

if _is_command_global_phase(command):
if self._is_command_global_phase(command):
logger.debug("Ignoring global Phase gate")
return

Expand Down Expand Up @@ -128,14 +122,56 @@ def _build_shard(self, command: Command) -> None:
set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc, arg-type]
)

# Handle dependency calculations
depends_upon = self._resolve_shard_dependencies(
qubits_used, bits_written, bits_read
)

shard = Shard(
command,
sub_commands,
qubits_used,
bits_written,
bits_read,
depends_upon,
)
self._shards[shard.ID] = shard
logger.debug("Appended shard: %s", shard)

def _resolve_shard_dependencies(
self, qubits: set[Qubit], bits_written: set[Bit], bits_read: set[Bit]
) -> set[int]:
"""Finds the dependent shards for a given shard.

This involves checking for qubit interaction and classical hazards of
various types.

Args:
shard: Shard to run dependency calculation on
qubits: Set of all qubits interacted with in the command/sub-commands
bits_written: Classical bits the command/sub-commands write to
bits_read: Classical bits the command/sub-commands read from
"""
logger.debug(
"Resolving shard dependencies with qubits=%s bits_written=%s bits_read=%s",
qubits,
bits_written,
bits_read,
)

depends_upon: set[int] = set()
for shard in self._shards:
shards_to_check: list[Shard] = list(reversed(self._shards.values()))
shards_to_skip: set[int] = set()

for shard in shards_to_check:
if shard.ID in shards_to_skip:
continue
add_dependency = False

# Check qubit dependencies (R/W implicitly) since all commands
# on a given qubit need to be ordered as the circuit dictated
if not shard.qubits_used.isdisjoint(command.qubits):
if not shard.qubits_used.isdisjoint(qubits):
logger.debug("...adding shard dep %s -> qubit overlap", shard.ID)
depends_upon.add(shard.ID)
add_dependency = True
# Check classical dependencies, which depend on writing and reading
# hazards: RAW, WAW, WAR
# NOTE: bits_read will include bits_written in the current impl
Expand All @@ -144,28 +180,35 @@ def _build_shard(self, command: Command) -> None:
# by looking at overlap of bits_written
elif not shard.bits_written.isdisjoint(bits_written):
logger.debug("...adding shard dep %s -> WAW", shard.ID)
depends_upon.add(shard.ID)
add_dependency = True

# Check for read-after-write (value seen would change if reordered)
elif not shard.bits_written.isdisjoint(bits_read):
logger.debug("...adding shard dep %s -> RAW", shard.ID)
depends_upon.add(shard.ID)
add_dependency = True

# Check for write-after-read (no reordering or read is changed)
elif not shard.bits_written.isdisjoint(bits_read):
elif not shard.bits_read.isdisjoint(bits_written):
logger.debug("...adding shard dep %s -> WAR", shard.ID)
add_dependency = True

if add_dependency:
depends_upon.add(shard.ID)
# Avoid recalculating for anything previous
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems this avoiding recalculation is the main thing to help performance?

shards_to_skip.update(self._get_dependencies(shard))

shard = Shard(
command,
sub_commands,
qubits_used,
bits_written,
bits_read,
depends_upon,
)
self._shards.append(shard)
logger.debug("Appended shard: %s", shard)
return depends_upon

def _get_dependencies(self, shard: Shard) -> set[int]:
dependencies: set[int] = set()
self._get_dependencies_recursive(dependencies, shard)
return dependencies

def _get_dependencies_recursive(self, accumulator: set[int], shard: Shard) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I'm overthinking this, but if we're trying to improve performance--then is there any way to replace this with a data-structure like a stack/queue and some sort of while-loop? Or is the overhead of the recursion here minimal?

Copy link
Member

@qartik qartik Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, my comment above may have been confusing, currently the changes in this PR lead to significant overhead and hence it is not ready to be merged. make tests is taking way longer than usual (perhaps because of the issue with recursion that you just identified).

Compare the build times with another recent PR:
Screenshot 2023-11-10 at 4 08 03 PM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not 100% sure but since this is clearly not a slam dunk I'm going to close it and try a more significant refactor that should reduce to linear time.

accumulator.update(shard.depends_upon)
for shard_id in shard.depends_upon:
dependency = self._shards[shard_id]
self._get_dependencies_recursive(accumulator, dependency)

def _cleanup_remaining_commands(self) -> None:
remaining_qubits = [k for k, v in self._pending_commands.items() if v]
Expand Down Expand Up @@ -212,3 +255,15 @@ def should_op_create_shard(op: Op) -> bool:
)
or (op.is_gate() and op.n_qubits > 1)
)

@staticmethod
def _is_command_global_phase(command: Command) -> bool:
"""Check if an operation related to global phase.

Args:
command: Command to evaluate
"""
return command.op.type == OpType.Phase or (
command.op.type == OpType.Conditional
and cast(Conditional, command.op).op.type == OpType.Phase
)
14 changes: 7 additions & 7 deletions tests/test_sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ def test_rollup_behavior(self) -> None:

assert shards[3].primary_command.op.type == OpType.Barrier
assert len(shards[3].sub_commands) == 1
assert shards[3].depends_upon == {shards[0].ID, shards[1].ID}
assert shards[3].depends_upon == {shards[1].ID}

assert shards[4].primary_command.op.type == OpType.Barrier
assert len(shards[4].sub_commands) == 1
assert shards[4].depends_upon == {shards[0].ID, shards[2].ID}
assert shards[4].depends_upon == {shards[2].ID}

def test_simple_conditional(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.simple_cond)
Expand Down Expand Up @@ -130,7 +130,7 @@ def test_simple_conditional(self) -> None:
assert shards[3].qubits_used == {circuit.qubits[0]}
assert shards[3].bits_written == {circuit.bits[0]}
assert shards[3].bits_read == {circuit.bits[0]}
assert shards[3].depends_upon == {shards[0].ID, shards[1].ID}
assert shards[3].depends_upon == {shards[1].ID, shards[2].ID}
assert len(shards[3].sub_commands.items()) == 1
s2_qubit, s2_sub_cmds = next(iter(shards[3].sub_commands.items()))
assert s2_qubit == circuit.qubits[0]
Expand Down Expand Up @@ -183,7 +183,7 @@ def test_complex_barriers(self) -> None: # noqa: PLR0915
assert shards[3].qubits_used == {circuit.qubits[0]}
assert shards[3].bits_written == {circuit.bits[0]}
assert shards[3].bits_read == {circuit.bits[0]}
assert shards[3].depends_upon == {shards[2].ID, shards[1].ID}
assert shards[3].depends_upon == {shards[2].ID}

# shard 4: [H q[3];, H q[3];, X q[3];]] barrier q[2], q[3];
assert shards[4].primary_command.op.type == OpType.Barrier
Expand Down Expand Up @@ -221,7 +221,7 @@ def test_complex_barriers(self) -> None: # noqa: PLR0915
assert shards[6].qubits_used == {circuit.qubits[2]}
assert shards[6].bits_written == {circuit.bits[2]}
assert shards[6].bits_read == {circuit.bits[2]}
assert shards[6].depends_upon == {shards[5].ID, shards[4].ID}
assert shards[6].depends_upon == {shards[5].ID}

def test_classical_hazards(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.classical_hazards)
Expand Down Expand Up @@ -261,15 +261,15 @@ def test_classical_hazards(self) -> None:
assert shards[3].qubits_used == set()
assert shards[3].bits_written == {circuit.bits[0]}
assert shards[3].bits_read == {circuit.bits[0]}
assert shards[3].depends_upon == {shards[0].ID}
assert shards[3].depends_upon == {shards[2].ID}

# shard 4: [] if(c[2]==1) c[0]=1;
assert shards[4].primary_command.op.type == OpType.Conditional
assert len(shards[4].sub_commands) == 0
assert shards[4].qubits_used == set()
assert shards[4].bits_written == {circuit.bits[0]}
assert shards[4].bits_read == {circuit.bits[0], circuit.bits[2]}
assert shards[4].depends_upon == {shards[1].ID, shards[0].ID, shards[3].ID}
assert shards[4].depends_upon == {shards[1].ID, shards[3].ID}

def test_with_big_gate(self) -> None:
circuit = get_qasm_as_circuit(QasmFile.big_gate)
Expand Down