Skip to content

Commit

Permalink
Improving shard handling and testing
Browse files Browse the repository at this point in the history
  • Loading branch information
nealerickson-qtm committed Oct 2, 2023
1 parent 4e89fc5 commit e5e09f1
Show file tree
Hide file tree
Showing 9 changed files with 282 additions and 24 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ Repository = "https://github.com/CQCL/pytket-phir.git"
where = ["."]

[tool.pytest.ini_options]
addopts = "-s -vv"
pythonpath = [
"."
]
log_cli = true
filterwarnings = ["ignore:::lark.s*"]
65 changes: 58 additions & 7 deletions pytket/phir/sharding/shard.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,74 @@
from dataclasses import dataclass
import io
from dataclasses import dataclass, field
from itertools import count

from pytket.circuit import Command
from pytket.unit_id import UnitID
from pytket.unit_id import Bit, Qubit, UnitID


@dataclass
@dataclass(unsafe_hash=False)
class Shard:
"""
A shard is a logical grouping of operations that represents the unit by which
we actually do placement of qubits
"""

# The schedulable command of the shard
# The "schedulable" command of the shard
primary_command: Command

# The other commands related to the primary schedulable command, stored
# as a map of bit-handle (unitID) -> list[Command]
sub_commands: dict[UnitID, list[Command]]

# A set of the other shards this particular shard depends upon, and thus
# must be scheduled after
depends_upon: set["Shard"]
# A set of the identifiers of other shards this particular shard depends upon
depends_upon: set[int]

# All qubits used by the primary and sub commands
qubits_used: set[Qubit] = field(init=False)

# Set of all classical bits written to by the primary and sub commands
bits_written: set[Bit] = field(init=False)

# Set of all classical bits read by the primary and sub commands
bits_read: set[Bit] = field(init=False)

# The unique identifier of the shard
ID: int = field(default_factory=count().__next__, init=False)

def __post_init__(self) -> None:
self.qubits_used = set(self.primary_command.qubits)
self.bits_written = set(self.primary_command.bits)
self.bits_read = set()

all_sub_commands: list[Command] = []
for sub_commands in self.sub_commands.values():
all_sub_commands.extend(sub_commands)

for sub_command in all_sub_commands:
self.bits_written.update(sub_command.bits)
self.bits_read.update(
set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore # noqa: PGH003
)

def pretty_print(self) -> str:
output = io.StringIO()
output.write(f"Shard {self.ID}:")
output.write(f"\n Command: {self.primary_command}")
output.write(
f'\n Qubits used: [{", ".join(repr(x) for x in self.qubits_used)}]',
)
output.write("\n Sub commands: ")
if not self.sub_commands:
output.write("none")
for sub in self.sub_commands:
output.write(f"\n {sub}: {self.sub_commands[sub]}")
output.write(f"\n Qubits used: {self.qubits_used}")
output.write(f"\n Bits written: {self.bits_written}")
output.write(f"\n Bits read: {self.bits_read}")
output.write("\n Depends upon shards: ")
if not self.depends_upon:
output.write("none")
output.write(", ".join(map(repr, self.depends_upon)))
content = output.getvalue()
output.close()
return content
59 changes: 48 additions & 11 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ class Sharder:

def __init__(self, circuit: Circuit) -> None:
self._circuit = circuit
print(f"Sharder created for circuit {self._circuit}")
self._pending_commands: dict[UnitID, list[Command]] = {}
self._shards: list[Shard] = []
print(f"Sharder created for circuit {self._circuit}")

def shard(self) -> list[Shard]:
"""
Expand All @@ -28,6 +28,11 @@ def shard(self) -> list[Shard]:
# https://cqcl.github.io/tket/pytket/api/circuit.html#pytket.circuit.Command
for command in self._circuit.get_commands():
self._process_command(command)
self._cleanup_remaining_commands()

print("Shard output:")
for shard in self._shards:
print(shard.pretty_print())
return self._shards

def _process_command(self, command: Command) -> None:
Expand All @@ -44,38 +49,70 @@ def _process_command(self, command: Command) -> None:
print(f"Building shard for command: {command}")
self._build_shard(command)
else:
self._add_pending_command(command)
self._add_pending_sub_command(command)

def _build_shard(self, command: Command) -> None:
"""
Creates a Shard object given the extant sharding context and the schedulable
Command object passed in, and appends it to the Shard list
"""
shard = Shard(command, self._pending_commands, set())
# TODO: Dependencies!
self._pending_commands = {}
# Resolve any sub commands (SQ gates) that interact with the same qubits
sub_commands: dict[UnitID, list[Command]] = {}
for key in (
key for key in list(self._pending_commands) if key in command.qubits
):
sub_commands[key] = self._pending_commands.pop(key)

# Handle dependency calculations
depends_upon: set[int] = set()
for shard in self._shards:
# Check qubit dependencies (R/W implicitly) since all commands on a given qubit
# need to be ordered as the circuit dictated
if not shard.qubits_used.isdisjoint(command.qubits):
depends_upon.add(shard.ID)
# Check classical dependencies, which depend on writing and reading
# hazards: RAW, WAW, WAR
# TODO: Do it!

shard = Shard(command, sub_commands, depends_upon)
self._shards.append(shard)
print("Appended shard:", shard)

def _add_pending_command(self, command: Command) -> None:
def _cleanup_remaining_commands(self) -> None:
"""
Checks for any remaining "unsharded" commands, and if found, adds them
to Barrier op shards for each qubit
"""
remaining_qubits = [k for k, v in self._pending_commands.items() if len(v) > 0]
for qubit in remaining_qubits:
self._circuit.add_barrier([qubit])
# Easiest way to get to a command, since there's no constructor. Could
# create an entire orphan circuit with the matching qubits and the barrier
# instead if this has unintended consequences
barrier_command = self._circuit.get_commands()[-1]
self._build_shard(barrier_command)

def _add_pending_sub_command(self, command: Command) -> None:
"""
Adds a pending sub command to the buffer to be flushed when a schedulable
operation creates a Shard.
"""
# TODO: Need to make sure 'args[0]' is the right key to use.
if command.args[0] not in self._pending_commands:
self._pending_commands[command.args[0]] = []
self._pending_commands[command.args[0]].append(command)
key = command.qubits[0]
if key not in self._pending_commands:
self._pending_commands[key] = []
self._pending_commands[key].append(command)
print(f"Adding pending command {command}")

@staticmethod
def should_op_create_shard(op: Op) -> bool:
"""
Returns `True` if the operation is one that should result in shard creation.
This includes non-gate operations like measure/reset as well as 2-qubit gates.
TODO: This is almost certainly inadequate right now
"""
# TODO: This is almost certainly inadequate right now
return (
op.type == OpType.Measure
or op.type == OpType.Reset
or op.type == OpType.Barrier
or (op.is_gate() and op.n_qubits > 1)
)
2 changes: 1 addition & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
target-version = "py310"

line-length = 88
line-length = 120

select = [
"E", # pycodestyle Errors
Expand Down
15 changes: 15 additions & 0 deletions tests/data/qasm/baby_with_rollup.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[2];
creg c[2];

h q[0];
h q[1];
CX q[0], q[1];

measure q->c;

h q[0];

h q[1];
11 changes: 11 additions & 0 deletions tests/data/qasm/simple_cond.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[1];
creg c[1];

h q;
measure q->c;
reset q;
if (c==1) h q;
measure q->c;
2 changes: 2 additions & 0 deletions tests/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ class QasmFiles(Enum):
cond_1 = 2
bv_n10 = 3
baby = 4
baby_with_rollup = 5
simple_cond = 6


def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_shard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from pytket.circuit import Circuit
from pytket.phir.sharding.shard import Shard

EMPTY_INT_SET: set[int] = set()


class TestShard:
def test_shard_ctor(self) -> None:
circ = Circuit(4) # qubits are numbered 0-3
circ.X(0) # first apply an X gate to qubit 0
circ.CX(1, 3) # and apply a CX gate with control qubit 1 and target qubit 3
circ.Z(3) # then apply a Z gate to qubit 3
commands = circ.get_commands()

shard = Shard(
commands[1],
{commands[0].qubits[0]: [commands[0]]},
EMPTY_INT_SET,
)

assert shard.primary_command == commands[1]
assert shard.depends_upon == EMPTY_INT_SET
sub_command_key, sub_command_value = next(iter(shard.sub_commands.items()))
assert sub_command_key == commands[0].qubits[0]
assert sub_command_value[0] == commands[0]

def test_shard_ctor_conditional(self) -> None:
circuit = Circuit(4, 4)
circuit.H(0)
circuit.Measure(0, 0)
circuit.X(1, condition_bits=[0], condition_value=1) # type: ignore # noqa: PGH003
circuit.Measure(1, 1) # The command we'll build the shard from
commands = circuit.get_commands()

shard = Shard(
commands[3],
{
circuit.qubits[0]: [commands[2]],
},
EMPTY_INT_SET,
)

assert len(shard.sub_commands.items())
assert shard.qubits_used == {circuit.qubits[1]}
assert shard.bits_read == {circuit.bits[0]}
assert shard.bits_written == {circuit.bits[1]}
103 changes: 98 additions & 5 deletions tests/test_sharder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,106 @@
from typing import cast

from pytket.circuit import Conditional, Op, OpType
from pytket.phir.sharding.sharder import Sharder

from .sample_data import QasmFiles, get_qasm_as_circuit


class TestSharder:
def test_ctor(self) -> None:
sharder = Sharder(get_qasm_as_circuit(QasmFiles.baby))
assert sharder is not None
def test_should_op_create_shard(self) -> None:
expected_true: list[Op] = [
Op.create(OpType.Measure), # type: ignore # noqa: PGH003
Op.create(OpType.Reset), # type: ignore # noqa: PGH003
Op.create(OpType.CX), # type: ignore # noqa: PGH003
Op.create(OpType.Barrier), # type: ignore # noqa: PGH003
]
expected_false: list[Op] = [
Op.create(OpType.U1, 0.32), # type: ignore # noqa: PGH003
Op.create(OpType.H), # type: ignore # noqa: PGH003
Op.create(OpType.Z), # type: ignore # noqa: PGH003
]

for op in expected_true:
assert Sharder.should_op_create_shard(op)
for op in expected_false:
assert not Sharder.should_op_create_shard(op)

def test_with_baby_circuit(self) -> None:
circuit = get_qasm_as_circuit(QasmFiles.baby)
sharder = Sharder(circuit)
shards = sharder.shard()

assert len(shards) == 3

assert shards[0].primary_command.op.type == OpType.CX
assert len(shards[0].primary_command.qubits) == 2
assert not shards[0].primary_command.bits
assert len(shards[0].sub_commands) == 2
sub_commands = list(shards[0].sub_commands.items())
print(sub_commands)
assert sub_commands[0][1][0].op.type == OpType.H
assert len(shards[0].depends_upon) == 0

assert shards[1].primary_command.op.type == OpType.Measure
assert len(shards[1].sub_commands) == 0
assert shards[1].depends_upon == {shards[0].ID}

assert shards[2].primary_command.op.type == OpType.Measure
assert len(shards[2].sub_commands) == 0
assert shards[2].depends_upon == {shards[0].ID}

def test_rollup_behavior(self) -> None:
circuit = get_qasm_as_circuit(QasmFiles.baby_with_rollup)
sharder = Sharder(circuit)
shards = sharder.shard()

assert len(shards) == 5

assert shards[0].primary_command.op.type == OpType.CX
assert len(shards[0].primary_command.qubits) == 2
assert not shards[0].primary_command.bits
assert len(shards[0].sub_commands) == 2
sub_commands = list(shards[0].sub_commands.items())
print(sub_commands)
assert sub_commands[0][1][0].op.type == OpType.H
assert len(shards[0].depends_upon) == 0

assert shards[1].primary_command.op.type == OpType.Measure
assert len(shards[1].sub_commands) == 0
assert shards[1].depends_upon == {shards[0].ID}

assert shards[2].primary_command.op.type == OpType.Measure
assert len(shards[2].sub_commands) == 0
assert shards[2].depends_upon == {shards[0].ID}

assert shards[3].primary_command.op.type == OpType.Barrier
assert len(shards[3].sub_commands) == 1
assert shards[3].depends_upon == {shards[0].ID, shards[1].ID}

assert shards[4].primary_command.op.type == OpType.Barrier
assert len(shards[4].sub_commands) == 1
assert shards[4].depends_upon == {shards[0].ID, shards[2].ID}

def test_simple_conditional(self) -> None:
circuit = get_qasm_as_circuit(QasmFiles.simple_cond)
sharder = Sharder(circuit)
shards = sharder.shard()

assert len(shards) == 3

assert shards[0].primary_command.op.type == OpType.Measure
assert len(shards[0].sub_commands.items()) == 1
s0_qubit, s0_sub_cmds = next(iter(shards[0].sub_commands.items()))
assert s0_qubit == circuit.qubits[0]
assert s0_sub_cmds[0].op.type == OpType.H

output = sharder.shard()
assert shards[1].primary_command.op.type == OpType.Reset
assert len(shards[1].sub_commands.items()) == 0

assert len(output) == 3
assert shards[2].primary_command.op.type == OpType.Measure
assert len(shards[2].sub_commands.items()) == 1
s2_qubit, s2_sub_cmds = next(iter(shards[2].sub_commands.items()))
assert s2_qubit == circuit.qubits[0]
assert s2_sub_cmds[0].op.type == OpType.Conditional
assert cast(Conditional, s2_sub_cmds[0].op).op.type == OpType.H
assert s2_sub_cmds[0].qubits == [circuit.qubits[0]]

0 comments on commit e5e09f1

Please sign in to comment.