Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving shard handling and testing #3

Merged
merged 15 commits into from
Oct 11, 2023
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ Repository = "https://github.com/CQCL/pytket-phir.git"
where = ["."]

[tool.pytest.ini_options]
addopts = "-s -vv"
pythonpath = [
"."
]
log_cli = true
filterwarnings = ["ignore:::lark.s*"]
65 changes: 58 additions & 7 deletions pytket/phir/sharding/shard.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,74 @@
from dataclasses import dataclass
import io
from dataclasses import dataclass, field
from itertools import count

from pytket.circuit import Command
from pytket.unit_id import UnitID
from pytket.unit_id import Bit, Qubit, UnitID


@dataclass
@dataclass(unsafe_hash=False)
neal-erickson marked this conversation as resolved.
Show resolved Hide resolved
class Shard:
"""
A shard is a logical grouping of operations that represents the unit by which
we actually do placement of qubits
"""

# The schedulable command of the shard
# The "schedulable" command of the shard
primary_command: Command

# The other commands related to the primary schedulable command, stored
# as a map of bit-handle (unitID) -> list[Command]
sub_commands: dict[UnitID, list[Command]]

# A set of the other shards this particular shard depends upon, and thus
# must be scheduled after
depends_upon: set["Shard"]
# A set of the identifiers of other shards this particular shard depends upon
depends_upon: set[int]

# All qubits used by the primary and sub commands
qubits_used: set[Qubit] = field(init=False)

# Set of all classical bits written to by the primary and sub commands
bits_written: set[Bit] = field(init=False)

# Set of all classical bits read by the primary and sub commands
bits_read: set[Bit] = field(init=False)

# The unique identifier of the shard
ID: int = field(default_factory=count().__next__, init=False)

def __post_init__(self) -> None:
self.qubits_used = set(self.primary_command.qubits)
self.bits_written = set(self.primary_command.bits)
self.bits_read = set()

all_sub_commands: list[Command] = []
for sub_commands in self.sub_commands.values():
all_sub_commands.extend(sub_commands)

for sub_command in all_sub_commands:
self.bits_written.update(sub_command.bits)
self.bits_read.update(
set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore # noqa: PGH003
neal-erickson marked this conversation as resolved.
Show resolved Hide resolved
)

def pretty_print(self) -> str:
output = io.StringIO()
output.write(f"Shard {self.ID}:")
output.write(f"\n Command: {self.primary_command}")
output.write(
f'\n Qubits used: [{", ".join(repr(x) for x in self.qubits_used)}]',
)
output.write("\n Sub commands: ")
if not self.sub_commands:
output.write("none")
for sub in self.sub_commands:
output.write(f"\n {sub}: {self.sub_commands[sub]}")
output.write(f"\n Qubits used: {self.qubits_used}")
output.write(f"\n Bits written: {self.bits_written}")
output.write(f"\n Bits read: {self.bits_read}")
output.write("\n Depends upon shards: ")
if not self.depends_upon:
output.write("none")
output.write(", ".join(map(repr, self.depends_upon)))
content = output.getvalue()
output.close()
return content
59 changes: 48 additions & 11 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ class Sharder:

def __init__(self, circuit: Circuit) -> None:
self._circuit = circuit
print(f"Sharder created for circuit {self._circuit}")
self._pending_commands: dict[UnitID, list[Command]] = {}
self._shards: list[Shard] = []
print(f"Sharder created for circuit {self._circuit}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want prints vs. some other form of logging? I'm ok with it for first cut merging PR in, but I imagine we won't want to keep these prints as is.


def shard(self) -> list[Shard]:
"""
Expand All @@ -28,6 +28,11 @@ def shard(self) -> list[Shard]:
# https://cqcl.github.io/tket/pytket/api/circuit.html#pytket.circuit.Command
for command in self._circuit.get_commands():
self._process_command(command)
self._cleanup_remaining_commands()

print("Shard output:")
for shard in self._shards:
print(shard.pretty_print())
return self._shards

def _process_command(self, command: Command) -> None:
Expand All @@ -44,38 +49,70 @@ def _process_command(self, command: Command) -> None:
print(f"Building shard for command: {command}")
self._build_shard(command)
else:
self._add_pending_command(command)
self._add_pending_sub_command(command)

def _build_shard(self, command: Command) -> None:
"""
Creates a Shard object given the extant sharding context and the schedulable
Command object passed in, and appends it to the Shard list
"""
shard = Shard(command, self._pending_commands, set())
# TODO: Dependencies!
self._pending_commands = {}
# Resolve any sub commands (SQ gates) that interact with the same qubits
sub_commands: dict[UnitID, list[Command]] = {}
for key in (
key for key in list(self._pending_commands) if key in command.qubits
):
sub_commands[key] = self._pending_commands.pop(key)

# Handle dependency calculations
depends_upon: set[int] = set()
for shard in self._shards:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering if dependency creation should be a self contained method. The overall _build_shard method isn't too long but this could be a natural part to break down.

# Check qubit dependencies (R/W implicitly) since all commands on a given qubit
# need to be ordered as the circuit dictated
neal-erickson marked this conversation as resolved.
Show resolved Hide resolved
if not shard.qubits_used.isdisjoint(command.qubits):
depends_upon.add(shard.ID)
# Check classical dependencies, which depend on writing and reading
# hazards: RAW, WAW, WAR
# TODO: Do it!

shard = Shard(command, sub_commands, depends_upon)
self._shards.append(shard)
print("Appended shard:", shard)

def _add_pending_command(self, command: Command) -> None:
def _cleanup_remaining_commands(self) -> None:
"""
Checks for any remaining "unsharded" commands, and if found, adds them
to Barrier op shards for each qubit
"""
remaining_qubits = [k for k, v in self._pending_commands.items() if len(v) > 0]
neal-erickson marked this conversation as resolved.
Show resolved Hide resolved
for qubit in remaining_qubits:
self._circuit.add_barrier([qubit])
# Easiest way to get to a command, since there's no constructor. Could
# create an entire orphan circuit with the matching qubits and the barrier
# instead if this has unintended consequences
barrier_command = self._circuit.get_commands()[-1]
self._build_shard(barrier_command)

def _add_pending_sub_command(self, command: Command) -> None:
"""
Adds a pending sub command to the buffer to be flushed when a schedulable
operation creates a Shard.
"""
# TODO: Need to make sure 'args[0]' is the right key to use.
if command.args[0] not in self._pending_commands:
self._pending_commands[command.args[0]] = []
self._pending_commands[command.args[0]].append(command)
key = command.qubits[0]
if key not in self._pending_commands:
self._pending_commands[key] = []
self._pending_commands[key].append(command)
print(f"Adding pending command {command}")

@staticmethod
def should_op_create_shard(op: Op) -> bool:
"""
Returns `True` if the operation is one that should result in shard creation.
This includes non-gate operations like measure/reset as well as 2-qubit gates.
TODO: This is almost certainly inadequate right now
"""
# TODO: This is almost certainly inadequate right now
return (
op.type == OpType.Measure
or op.type == OpType.Reset
or op.type == OpType.Barrier
or (op.is_gate() and op.n_qubits > 1)
)
neal-erickson marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
target-version = "py310"

line-length = 88
line-length = 120
neal-erickson marked this conversation as resolved.
Show resolved Hide resolved

select = [
"E", # pycodestyle Errors
Expand Down
15 changes: 15 additions & 0 deletions tests/data/qasm/baby_with_rollup.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[2];
creg c[2];

h q[0];
h q[1];
CX q[0], q[1];

measure q->c;

h q[0];

h q[1];
11 changes: 11 additions & 0 deletions tests/data/qasm/simple_cond.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[1];
creg c[1];

h q;
measure q->c;
reset q;
if (c==1) h q;
measure q->c;
2 changes: 2 additions & 0 deletions tests/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ class QasmFiles(Enum):
cond_1 = 2
bv_n10 = 3
baby = 4
baby_with_rollup = 5
simple_cond = 6


def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_shard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from pytket.circuit import Circuit
from pytket.phir.sharding.shard import Shard

EMPTY_INT_SET: set[int] = set()


class TestShard:
def test_shard_ctor(self) -> None:
circ = Circuit(4) # qubits are numbered 0-3
circ.X(0) # first apply an X gate to qubit 0
circ.CX(1, 3) # and apply a CX gate with control qubit 1 and target qubit 3
circ.Z(3) # then apply a Z gate to qubit 3
commands = circ.get_commands()

shard = Shard(
commands[1],
{commands[0].qubits[0]: [commands[0]]},
EMPTY_INT_SET,
)

assert shard.primary_command == commands[1]
assert shard.depends_upon == EMPTY_INT_SET
sub_command_key, sub_command_value = next(iter(shard.sub_commands.items()))
assert sub_command_key == commands[0].qubits[0]
assert sub_command_value[0] == commands[0]

def test_shard_ctor_conditional(self) -> None:
circuit = Circuit(4, 4)
circuit.H(0)
circuit.Measure(0, 0)
circuit.X(1, condition_bits=[0], condition_value=1) # type: ignore # noqa: PGH003
neal-erickson marked this conversation as resolved.
Show resolved Hide resolved
circuit.Measure(1, 1) # The command we'll build the shard from
commands = circuit.get_commands()

shard = Shard(
commands[3],
{
circuit.qubits[0]: [commands[2]],
},
EMPTY_INT_SET,
)

assert len(shard.sub_commands.items())
assert shard.qubits_used == {circuit.qubits[1]}
assert shard.bits_read == {circuit.bits[0]}
assert shard.bits_written == {circuit.bits[1]}
103 changes: 98 additions & 5 deletions tests/test_sharder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,106 @@
from typing import cast

from pytket.circuit import Conditional, Op, OpType
from pytket.phir.sharding.sharder import Sharder

from .sample_data import QasmFiles, get_qasm_as_circuit


class TestSharder:
def test_ctor(self) -> None:
sharder = Sharder(get_qasm_as_circuit(QasmFiles.baby))
assert sharder is not None
def test_should_op_create_shard(self) -> None:
expected_true: list[Op] = [
Op.create(OpType.Measure), # type: ignore # noqa: PGH003
Op.create(OpType.Reset), # type: ignore # noqa: PGH003
Op.create(OpType.CX), # type: ignore # noqa: PGH003
Op.create(OpType.Barrier), # type: ignore # noqa: PGH003
]
expected_false: list[Op] = [
Op.create(OpType.U1, 0.32), # type: ignore # noqa: PGH003
Op.create(OpType.H), # type: ignore # noqa: PGH003
Op.create(OpType.Z), # type: ignore # noqa: PGH003
neal-erickson marked this conversation as resolved.
Show resolved Hide resolved
]

for op in expected_true:
assert Sharder.should_op_create_shard(op)
for op in expected_false:
assert not Sharder.should_op_create_shard(op)

def test_with_baby_circuit(self) -> None:
circuit = get_qasm_as_circuit(QasmFiles.baby)
sharder = Sharder(circuit)
shards = sharder.shard()

assert len(shards) == 3

assert shards[0].primary_command.op.type == OpType.CX
assert len(shards[0].primary_command.qubits) == 2
assert not shards[0].primary_command.bits
assert len(shards[0].sub_commands) == 2
sub_commands = list(shards[0].sub_commands.items())
print(sub_commands)
assert sub_commands[0][1][0].op.type == OpType.H
assert len(shards[0].depends_upon) == 0

assert shards[1].primary_command.op.type == OpType.Measure
assert len(shards[1].sub_commands) == 0
assert shards[1].depends_upon == {shards[0].ID}

assert shards[2].primary_command.op.type == OpType.Measure
assert len(shards[2].sub_commands) == 0
assert shards[2].depends_upon == {shards[0].ID}

def test_rollup_behavior(self) -> None:
circuit = get_qasm_as_circuit(QasmFiles.baby_with_rollup)
sharder = Sharder(circuit)
shards = sharder.shard()

assert len(shards) == 5

assert shards[0].primary_command.op.type == OpType.CX
assert len(shards[0].primary_command.qubits) == 2
assert not shards[0].primary_command.bits
assert len(shards[0].sub_commands) == 2
sub_commands = list(shards[0].sub_commands.items())
print(sub_commands)
assert sub_commands[0][1][0].op.type == OpType.H
assert len(shards[0].depends_upon) == 0

assert shards[1].primary_command.op.type == OpType.Measure
assert len(shards[1].sub_commands) == 0
assert shards[1].depends_upon == {shards[0].ID}

assert shards[2].primary_command.op.type == OpType.Measure
assert len(shards[2].sub_commands) == 0
assert shards[2].depends_upon == {shards[0].ID}

assert shards[3].primary_command.op.type == OpType.Barrier
assert len(shards[3].sub_commands) == 1
assert shards[3].depends_upon == {shards[0].ID, shards[1].ID}

assert shards[4].primary_command.op.type == OpType.Barrier
assert len(shards[4].sub_commands) == 1
assert shards[4].depends_upon == {shards[0].ID, shards[2].ID}

def test_simple_conditional(self) -> None:
circuit = get_qasm_as_circuit(QasmFiles.simple_cond)
sharder = Sharder(circuit)
shards = sharder.shard()

assert len(shards) == 3

assert shards[0].primary_command.op.type == OpType.Measure
assert len(shards[0].sub_commands.items()) == 1
s0_qubit, s0_sub_cmds = next(iter(shards[0].sub_commands.items()))
assert s0_qubit == circuit.qubits[0]
assert s0_sub_cmds[0].op.type == OpType.H

output = sharder.shard()
assert shards[1].primary_command.op.type == OpType.Reset
assert len(shards[1].sub_commands.items()) == 0

assert len(output) == 3
assert shards[2].primary_command.op.type == OpType.Measure
assert len(shards[2].sub_commands.items()) == 1
s2_qubit, s2_sub_cmds = next(iter(shards[2].sub_commands.items()))
assert s2_qubit == circuit.qubits[0]
assert s2_sub_cmds[0].op.type == OpType.Conditional
assert cast(Conditional, s2_sub_cmds[0].op).op.type == OpType.H
assert s2_sub_cmds[0].qubits == [circuit.qubits[0]]