generated from CQCL/pytemplate
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improving shard handling and testing
- Loading branch information
1 parent
4e89fc5
commit e5e09f1
Showing
9 changed files
with
282 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,74 @@ | ||
from dataclasses import dataclass | ||
import io | ||
from dataclasses import dataclass, field | ||
from itertools import count | ||
|
||
from pytket.circuit import Command | ||
from pytket.unit_id import UnitID | ||
from pytket.unit_id import Bit, Qubit, UnitID | ||
|
||
|
||
@dataclass | ||
@dataclass(unsafe_hash=False) | ||
class Shard: | ||
""" | ||
A shard is a logical grouping of operations that represents the unit by which | ||
we actually do placement of qubits | ||
""" | ||
|
||
# The schedulable command of the shard | ||
# The "schedulable" command of the shard | ||
primary_command: Command | ||
|
||
# The other commands related to the primary schedulable command, stored | ||
# as a map of bit-handle (unitID) -> list[Command] | ||
sub_commands: dict[UnitID, list[Command]] | ||
|
||
# A set of the other shards this particular shard depends upon, and thus | ||
# must be scheduled after | ||
depends_upon: set["Shard"] | ||
# A set of the identifiers of other shards this particular shard depends upon | ||
depends_upon: set[int] | ||
|
||
# All qubits used by the primary and sub commands | ||
qubits_used: set[Qubit] = field(init=False) | ||
|
||
# Set of all classical bits written to by the primary and sub commands | ||
bits_written: set[Bit] = field(init=False) | ||
|
||
# Set of all classical bits read by the primary and sub commands | ||
bits_read: set[Bit] = field(init=False) | ||
|
||
# The unique identifier of the shard | ||
ID: int = field(default_factory=count().__next__, init=False) | ||
|
||
def __post_init__(self) -> None: | ||
self.qubits_used = set(self.primary_command.qubits) | ||
self.bits_written = set(self.primary_command.bits) | ||
self.bits_read = set() | ||
|
||
all_sub_commands: list[Command] = [] | ||
for sub_commands in self.sub_commands.values(): | ||
all_sub_commands.extend(sub_commands) | ||
|
||
for sub_command in all_sub_commands: | ||
self.bits_written.update(sub_command.bits) | ||
self.bits_read.update( | ||
set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore # noqa: PGH003 | ||
) | ||
|
||
def pretty_print(self) -> str: | ||
output = io.StringIO() | ||
output.write(f"Shard {self.ID}:") | ||
output.write(f"\n Command: {self.primary_command}") | ||
output.write( | ||
f'\n Qubits used: [{", ".join(repr(x) for x in self.qubits_used)}]', | ||
) | ||
output.write("\n Sub commands: ") | ||
if not self.sub_commands: | ||
output.write("none") | ||
for sub in self.sub_commands: | ||
output.write(f"\n {sub}: {self.sub_commands[sub]}") | ||
output.write(f"\n Qubits used: {self.qubits_used}") | ||
output.write(f"\n Bits written: {self.bits_written}") | ||
output.write(f"\n Bits read: {self.bits_read}") | ||
output.write("\n Depends upon shards: ") | ||
if not self.depends_upon: | ||
output.write("none") | ||
output.write(", ".join(map(repr, self.depends_upon))) | ||
content = output.getvalue() | ||
output.close() | ||
return content |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
target-version = "py310" | ||
|
||
line-length = 88 | ||
line-length = 120 | ||
|
||
select = [ | ||
"E", # pycodestyle Errors | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
OPENQASM 2.0; | ||
include "hqslib1.inc"; | ||
|
||
qreg q[2]; | ||
creg c[2]; | ||
|
||
h q[0]; | ||
h q[1]; | ||
CX q[0], q[1]; | ||
|
||
measure q->c; | ||
|
||
h q[0]; | ||
|
||
h q[1]; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
OPENQASM 2.0; | ||
include "hqslib1.inc"; | ||
|
||
qreg q[1]; | ||
creg c[1]; | ||
|
||
h q; | ||
measure q->c; | ||
reset q; | ||
if (c==1) h q; | ||
measure q->c; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from pytket.circuit import Circuit | ||
from pytket.phir.sharding.shard import Shard | ||
|
||
EMPTY_INT_SET: set[int] = set() | ||
|
||
|
||
class TestShard: | ||
def test_shard_ctor(self) -> None: | ||
circ = Circuit(4) # qubits are numbered 0-3 | ||
circ.X(0) # first apply an X gate to qubit 0 | ||
circ.CX(1, 3) # and apply a CX gate with control qubit 1 and target qubit 3 | ||
circ.Z(3) # then apply a Z gate to qubit 3 | ||
commands = circ.get_commands() | ||
|
||
shard = Shard( | ||
commands[1], | ||
{commands[0].qubits[0]: [commands[0]]}, | ||
EMPTY_INT_SET, | ||
) | ||
|
||
assert shard.primary_command == commands[1] | ||
assert shard.depends_upon == EMPTY_INT_SET | ||
sub_command_key, sub_command_value = next(iter(shard.sub_commands.items())) | ||
assert sub_command_key == commands[0].qubits[0] | ||
assert sub_command_value[0] == commands[0] | ||
|
||
def test_shard_ctor_conditional(self) -> None: | ||
circuit = Circuit(4, 4) | ||
circuit.H(0) | ||
circuit.Measure(0, 0) | ||
circuit.X(1, condition_bits=[0], condition_value=1) # type: ignore # noqa: PGH003 | ||
circuit.Measure(1, 1) # The command we'll build the shard from | ||
commands = circuit.get_commands() | ||
|
||
shard = Shard( | ||
commands[3], | ||
{ | ||
circuit.qubits[0]: [commands[2]], | ||
}, | ||
EMPTY_INT_SET, | ||
) | ||
|
||
assert len(shard.sub_commands.items()) | ||
assert shard.qubits_used == {circuit.qubits[1]} | ||
assert shard.bits_read == {circuit.bits[0]} | ||
assert shard.bits_written == {circuit.bits[1]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,106 @@ | ||
from typing import cast | ||
|
||
from pytket.circuit import Conditional, Op, OpType | ||
from pytket.phir.sharding.sharder import Sharder | ||
|
||
from .sample_data import QasmFiles, get_qasm_as_circuit | ||
|
||
|
||
class TestSharder: | ||
def test_ctor(self) -> None: | ||
sharder = Sharder(get_qasm_as_circuit(QasmFiles.baby)) | ||
assert sharder is not None | ||
def test_should_op_create_shard(self) -> None: | ||
expected_true: list[Op] = [ | ||
Op.create(OpType.Measure), # type: ignore # noqa: PGH003 | ||
Op.create(OpType.Reset), # type: ignore # noqa: PGH003 | ||
Op.create(OpType.CX), # type: ignore # noqa: PGH003 | ||
Op.create(OpType.Barrier), # type: ignore # noqa: PGH003 | ||
] | ||
expected_false: list[Op] = [ | ||
Op.create(OpType.U1, 0.32), # type: ignore # noqa: PGH003 | ||
Op.create(OpType.H), # type: ignore # noqa: PGH003 | ||
Op.create(OpType.Z), # type: ignore # noqa: PGH003 | ||
] | ||
|
||
for op in expected_true: | ||
assert Sharder.should_op_create_shard(op) | ||
for op in expected_false: | ||
assert not Sharder.should_op_create_shard(op) | ||
|
||
def test_with_baby_circuit(self) -> None: | ||
circuit = get_qasm_as_circuit(QasmFiles.baby) | ||
sharder = Sharder(circuit) | ||
shards = sharder.shard() | ||
|
||
assert len(shards) == 3 | ||
|
||
assert shards[0].primary_command.op.type == OpType.CX | ||
assert len(shards[0].primary_command.qubits) == 2 | ||
assert not shards[0].primary_command.bits | ||
assert len(shards[0].sub_commands) == 2 | ||
sub_commands = list(shards[0].sub_commands.items()) | ||
print(sub_commands) | ||
assert sub_commands[0][1][0].op.type == OpType.H | ||
assert len(shards[0].depends_upon) == 0 | ||
|
||
assert shards[1].primary_command.op.type == OpType.Measure | ||
assert len(shards[1].sub_commands) == 0 | ||
assert shards[1].depends_upon == {shards[0].ID} | ||
|
||
assert shards[2].primary_command.op.type == OpType.Measure | ||
assert len(shards[2].sub_commands) == 0 | ||
assert shards[2].depends_upon == {shards[0].ID} | ||
|
||
def test_rollup_behavior(self) -> None: | ||
circuit = get_qasm_as_circuit(QasmFiles.baby_with_rollup) | ||
sharder = Sharder(circuit) | ||
shards = sharder.shard() | ||
|
||
assert len(shards) == 5 | ||
|
||
assert shards[0].primary_command.op.type == OpType.CX | ||
assert len(shards[0].primary_command.qubits) == 2 | ||
assert not shards[0].primary_command.bits | ||
assert len(shards[0].sub_commands) == 2 | ||
sub_commands = list(shards[0].sub_commands.items()) | ||
print(sub_commands) | ||
assert sub_commands[0][1][0].op.type == OpType.H | ||
assert len(shards[0].depends_upon) == 0 | ||
|
||
assert shards[1].primary_command.op.type == OpType.Measure | ||
assert len(shards[1].sub_commands) == 0 | ||
assert shards[1].depends_upon == {shards[0].ID} | ||
|
||
assert shards[2].primary_command.op.type == OpType.Measure | ||
assert len(shards[2].sub_commands) == 0 | ||
assert shards[2].depends_upon == {shards[0].ID} | ||
|
||
assert shards[3].primary_command.op.type == OpType.Barrier | ||
assert len(shards[3].sub_commands) == 1 | ||
assert shards[3].depends_upon == {shards[0].ID, shards[1].ID} | ||
|
||
assert shards[4].primary_command.op.type == OpType.Barrier | ||
assert len(shards[4].sub_commands) == 1 | ||
assert shards[4].depends_upon == {shards[0].ID, shards[2].ID} | ||
|
||
def test_simple_conditional(self) -> None: | ||
circuit = get_qasm_as_circuit(QasmFiles.simple_cond) | ||
sharder = Sharder(circuit) | ||
shards = sharder.shard() | ||
|
||
assert len(shards) == 3 | ||
|
||
assert shards[0].primary_command.op.type == OpType.Measure | ||
assert len(shards[0].sub_commands.items()) == 1 | ||
s0_qubit, s0_sub_cmds = next(iter(shards[0].sub_commands.items())) | ||
assert s0_qubit == circuit.qubits[0] | ||
assert s0_sub_cmds[0].op.type == OpType.H | ||
|
||
output = sharder.shard() | ||
assert shards[1].primary_command.op.type == OpType.Reset | ||
assert len(shards[1].sub_commands.items()) == 0 | ||
|
||
assert len(output) == 3 | ||
assert shards[2].primary_command.op.type == OpType.Measure | ||
assert len(shards[2].sub_commands.items()) == 1 | ||
s2_qubit, s2_sub_cmds = next(iter(shards[2].sub_commands.items())) | ||
assert s2_qubit == circuit.qubits[0] | ||
assert s2_sub_cmds[0].op.type == OpType.Conditional | ||
assert cast(Conditional, s2_sub_cmds[0].op).op.type == OpType.H | ||
assert s2_sub_cmds[0].qubits == [circuit.qubits[0]] |