From e5e09f15a8da0a34056e93b2135d6e7471f36d41 Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Sun, 1 Oct 2023 19:45:18 -0600 Subject: [PATCH 01/14] Improving shard handling and testing --- pyproject.toml | 3 + pytket/phir/sharding/shard.py | 65 ++++++++++++++-- pytket/phir/sharding/sharder.py | 59 ++++++++++++--- ruff.toml | 2 +- tests/data/qasm/baby_with_rollup.qasm | 15 ++++ tests/data/qasm/simple_cond.qasm | 11 +++ tests/sample_data.py | 2 + tests/test_shard.py | 46 ++++++++++++ tests/test_sharder.py | 103 ++++++++++++++++++++++++-- 9 files changed, 282 insertions(+), 24 deletions(-) create mode 100644 tests/data/qasm/baby_with_rollup.qasm create mode 100644 tests/data/qasm/simple_cond.qasm create mode 100644 tests/test_shard.py diff --git a/pyproject.toml b/pyproject.toml index ca5665b..fa40d33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ Repository = "https://github.com/CQCL/pytket-phir.git" where = ["."] [tool.pytest.ini_options] +addopts = "-s -vv" pythonpath = [ "." ] +log_cli = true +filterwarnings = ["ignore:::lark.s*"] diff --git a/pytket/phir/sharding/shard.py b/pytket/phir/sharding/shard.py index 0ff4cb8..facf98d 100644 --- a/pytket/phir/sharding/shard.py +++ b/pytket/phir/sharding/shard.py @@ -1,23 +1,74 @@ -from dataclasses import dataclass +import io +from dataclasses import dataclass, field +from itertools import count from pytket.circuit import Command -from pytket.unit_id import UnitID +from pytket.unit_id import Bit, Qubit, UnitID -@dataclass +@dataclass(unsafe_hash=False) class Shard: """ A shard is a logical grouping of operations that represents the unit by which we actually do placement of qubits """ - # The schedulable command of the shard + # The "schedulable" command of the shard primary_command: Command # The other commands related to the primary schedulable command, stored # as a map of bit-handle (unitID) -> list[Command] sub_commands: dict[UnitID, list[Command]] - # A set of the other shards this particular shard depends upon, and thus - # must be scheduled after - depends_upon: set["Shard"] + # A set of the identifiers of other shards this particular shard depends upon + depends_upon: set[int] + + # All qubits used by the primary and sub commands + qubits_used: set[Qubit] = field(init=False) + + # Set of all classical bits written to by the primary and sub commands + bits_written: set[Bit] = field(init=False) + + # Set of all classical bits read by the primary and sub commands + bits_read: set[Bit] = field(init=False) + + # The unique identifier of the shard + ID: int = field(default_factory=count().__next__, init=False) + + def __post_init__(self) -> None: + self.qubits_used = set(self.primary_command.qubits) + self.bits_written = set(self.primary_command.bits) + self.bits_read = set() + + all_sub_commands: list[Command] = [] + for sub_commands in self.sub_commands.values(): + all_sub_commands.extend(sub_commands) + + for sub_command in all_sub_commands: + self.bits_written.update(sub_command.bits) + self.bits_read.update( + set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore # noqa: PGH003 + ) + + def pretty_print(self) -> str: + output = io.StringIO() + output.write(f"Shard {self.ID}:") + output.write(f"\n Command: {self.primary_command}") + output.write( + f'\n Qubits used: [{", ".join(repr(x) for x in self.qubits_used)}]', + ) + output.write("\n Sub commands: ") + if not self.sub_commands: + output.write("none") + for sub in self.sub_commands: + output.write(f"\n {sub}: {self.sub_commands[sub]}") + output.write(f"\n Qubits used: {self.qubits_used}") + output.write(f"\n Bits written: {self.bits_written}") + output.write(f"\n Bits read: {self.bits_read}") + output.write("\n Depends upon shards: ") + if not self.depends_upon: + output.write("none") + output.write(", ".join(map(repr, self.depends_upon))) + content = output.getvalue() + output.close() + return content diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index b3e9cf1..e40d3b3 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -15,9 +15,9 @@ class Sharder: def __init__(self, circuit: Circuit) -> None: self._circuit = circuit - print(f"Sharder created for circuit {self._circuit}") self._pending_commands: dict[UnitID, list[Command]] = {} self._shards: list[Shard] = [] + print(f"Sharder created for circuit {self._circuit}") def shard(self) -> list[Shard]: """ @@ -28,6 +28,11 @@ def shard(self) -> list[Shard]: # https://cqcl.github.io/tket/pytket/api/circuit.html#pytket.circuit.Command for command in self._circuit.get_commands(): self._process_command(command) + self._cleanup_remaining_commands() + + print("Shard output:") + for shard in self._shards: + print(shard.pretty_print()) return self._shards def _process_command(self, command: Command) -> None: @@ -44,38 +49,70 @@ def _process_command(self, command: Command) -> None: print(f"Building shard for command: {command}") self._build_shard(command) else: - self._add_pending_command(command) + self._add_pending_sub_command(command) def _build_shard(self, command: Command) -> None: """ Creates a Shard object given the extant sharding context and the schedulable Command object passed in, and appends it to the Shard list """ - shard = Shard(command, self._pending_commands, set()) - # TODO: Dependencies! - self._pending_commands = {} + # Resolve any sub commands (SQ gates) that interact with the same qubits + sub_commands: dict[UnitID, list[Command]] = {} + for key in ( + key for key in list(self._pending_commands) if key in command.qubits + ): + sub_commands[key] = self._pending_commands.pop(key) + + # Handle dependency calculations + depends_upon: set[int] = set() + for shard in self._shards: + # Check qubit dependencies (R/W implicitly) since all commands on a given qubit + # need to be ordered as the circuit dictated + if not shard.qubits_used.isdisjoint(command.qubits): + depends_upon.add(shard.ID) + # Check classical dependencies, which depend on writing and reading + # hazards: RAW, WAW, WAR + # TODO: Do it! + + shard = Shard(command, sub_commands, depends_upon) self._shards.append(shard) print("Appended shard:", shard) - def _add_pending_command(self, command: Command) -> None: + def _cleanup_remaining_commands(self) -> None: + """ + Checks for any remaining "unsharded" commands, and if found, adds them + to Barrier op shards for each qubit + """ + remaining_qubits = [k for k, v in self._pending_commands.items() if len(v) > 0] + for qubit in remaining_qubits: + self._circuit.add_barrier([qubit]) + # Easiest way to get to a command, since there's no constructor. Could + # create an entire orphan circuit with the matching qubits and the barrier + # instead if this has unintended consequences + barrier_command = self._circuit.get_commands()[-1] + self._build_shard(barrier_command) + + def _add_pending_sub_command(self, command: Command) -> None: """ Adds a pending sub command to the buffer to be flushed when a schedulable operation creates a Shard. """ - # TODO: Need to make sure 'args[0]' is the right key to use. - if command.args[0] not in self._pending_commands: - self._pending_commands[command.args[0]] = [] - self._pending_commands[command.args[0]].append(command) + key = command.qubits[0] + if key not in self._pending_commands: + self._pending_commands[key] = [] + self._pending_commands[key].append(command) + print(f"Adding pending command {command}") @staticmethod def should_op_create_shard(op: Op) -> bool: """ Returns `True` if the operation is one that should result in shard creation. This includes non-gate operations like measure/reset as well as 2-qubit gates. + TODO: This is almost certainly inadequate right now """ - # TODO: This is almost certainly inadequate right now return ( op.type == OpType.Measure or op.type == OpType.Reset + or op.type == OpType.Barrier or (op.is_gate() and op.n_qubits > 1) ) diff --git a/ruff.toml b/ruff.toml index c66ae97..27cb084 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,6 +1,6 @@ target-version = "py310" -line-length = 88 +line-length = 120 select = [ "E", # pycodestyle Errors diff --git a/tests/data/qasm/baby_with_rollup.qasm b/tests/data/qasm/baby_with_rollup.qasm new file mode 100644 index 0000000..b2a81fc --- /dev/null +++ b/tests/data/qasm/baby_with_rollup.qasm @@ -0,0 +1,15 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[2]; +creg c[2]; + +h q[0]; +h q[1]; +CX q[0], q[1]; + +measure q->c; + +h q[0]; + +h q[1]; diff --git a/tests/data/qasm/simple_cond.qasm b/tests/data/qasm/simple_cond.qasm new file mode 100644 index 0000000..d590bb6 --- /dev/null +++ b/tests/data/qasm/simple_cond.qasm @@ -0,0 +1,11 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[1]; +creg c[1]; + +h q; +measure q->c; +reset q; +if (c==1) h q; +measure q->c; diff --git a/tests/sample_data.py b/tests/sample_data.py index 73487d1..4e7d3c5 100644 --- a/tests/sample_data.py +++ b/tests/sample_data.py @@ -9,6 +9,8 @@ class QasmFiles(Enum): cond_1 = 2 bv_n10 = 3 baby = 4 + baby_with_rollup = 5 + simple_cond = 6 def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit: diff --git a/tests/test_shard.py b/tests/test_shard.py new file mode 100644 index 0000000..11008b2 --- /dev/null +++ b/tests/test_shard.py @@ -0,0 +1,46 @@ +from pytket.circuit import Circuit +from pytket.phir.sharding.shard import Shard + +EMPTY_INT_SET: set[int] = set() + + +class TestShard: + def test_shard_ctor(self) -> None: + circ = Circuit(4) # qubits are numbered 0-3 + circ.X(0) # first apply an X gate to qubit 0 + circ.CX(1, 3) # and apply a CX gate with control qubit 1 and target qubit 3 + circ.Z(3) # then apply a Z gate to qubit 3 + commands = circ.get_commands() + + shard = Shard( + commands[1], + {commands[0].qubits[0]: [commands[0]]}, + EMPTY_INT_SET, + ) + + assert shard.primary_command == commands[1] + assert shard.depends_upon == EMPTY_INT_SET + sub_command_key, sub_command_value = next(iter(shard.sub_commands.items())) + assert sub_command_key == commands[0].qubits[0] + assert sub_command_value[0] == commands[0] + + def test_shard_ctor_conditional(self) -> None: + circuit = Circuit(4, 4) + circuit.H(0) + circuit.Measure(0, 0) + circuit.X(1, condition_bits=[0], condition_value=1) # type: ignore # noqa: PGH003 + circuit.Measure(1, 1) # The command we'll build the shard from + commands = circuit.get_commands() + + shard = Shard( + commands[3], + { + circuit.qubits[0]: [commands[2]], + }, + EMPTY_INT_SET, + ) + + assert len(shard.sub_commands.items()) + assert shard.qubits_used == {circuit.qubits[1]} + assert shard.bits_read == {circuit.bits[0]} + assert shard.bits_written == {circuit.bits[1]} diff --git a/tests/test_sharder.py b/tests/test_sharder.py index 7a6c533..82011d0 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -1,13 +1,106 @@ +from typing import cast + +from pytket.circuit import Conditional, Op, OpType from pytket.phir.sharding.sharder import Sharder from .sample_data import QasmFiles, get_qasm_as_circuit class TestSharder: - def test_ctor(self) -> None: - sharder = Sharder(get_qasm_as_circuit(QasmFiles.baby)) - assert sharder is not None + def test_should_op_create_shard(self) -> None: + expected_true: list[Op] = [ + Op.create(OpType.Measure), # type: ignore # noqa: PGH003 + Op.create(OpType.Reset), # type: ignore # noqa: PGH003 + Op.create(OpType.CX), # type: ignore # noqa: PGH003 + Op.create(OpType.Barrier), # type: ignore # noqa: PGH003 + ] + expected_false: list[Op] = [ + Op.create(OpType.U1, 0.32), # type: ignore # noqa: PGH003 + Op.create(OpType.H), # type: ignore # noqa: PGH003 + Op.create(OpType.Z), # type: ignore # noqa: PGH003 + ] + + for op in expected_true: + assert Sharder.should_op_create_shard(op) + for op in expected_false: + assert not Sharder.should_op_create_shard(op) + + def test_with_baby_circuit(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.baby) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 3 + + assert shards[0].primary_command.op.type == OpType.CX + assert len(shards[0].primary_command.qubits) == 2 + assert not shards[0].primary_command.bits + assert len(shards[0].sub_commands) == 2 + sub_commands = list(shards[0].sub_commands.items()) + print(sub_commands) + assert sub_commands[0][1][0].op.type == OpType.H + assert len(shards[0].depends_upon) == 0 + + assert shards[1].primary_command.op.type == OpType.Measure + assert len(shards[1].sub_commands) == 0 + assert shards[1].depends_upon == {shards[0].ID} + + assert shards[2].primary_command.op.type == OpType.Measure + assert len(shards[2].sub_commands) == 0 + assert shards[2].depends_upon == {shards[0].ID} + + def test_rollup_behavior(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.baby_with_rollup) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 5 + + assert shards[0].primary_command.op.type == OpType.CX + assert len(shards[0].primary_command.qubits) == 2 + assert not shards[0].primary_command.bits + assert len(shards[0].sub_commands) == 2 + sub_commands = list(shards[0].sub_commands.items()) + print(sub_commands) + assert sub_commands[0][1][0].op.type == OpType.H + assert len(shards[0].depends_upon) == 0 + + assert shards[1].primary_command.op.type == OpType.Measure + assert len(shards[1].sub_commands) == 0 + assert shards[1].depends_upon == {shards[0].ID} + + assert shards[2].primary_command.op.type == OpType.Measure + assert len(shards[2].sub_commands) == 0 + assert shards[2].depends_upon == {shards[0].ID} + + assert shards[3].primary_command.op.type == OpType.Barrier + assert len(shards[3].sub_commands) == 1 + assert shards[3].depends_upon == {shards[0].ID, shards[1].ID} + + assert shards[4].primary_command.op.type == OpType.Barrier + assert len(shards[4].sub_commands) == 1 + assert shards[4].depends_upon == {shards[0].ID, shards[2].ID} + + def test_simple_conditional(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.simple_cond) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 3 + + assert shards[0].primary_command.op.type == OpType.Measure + assert len(shards[0].sub_commands.items()) == 1 + s0_qubit, s0_sub_cmds = next(iter(shards[0].sub_commands.items())) + assert s0_qubit == circuit.qubits[0] + assert s0_sub_cmds[0].op.type == OpType.H - output = sharder.shard() + assert shards[1].primary_command.op.type == OpType.Reset + assert len(shards[1].sub_commands.items()) == 0 - assert len(output) == 3 + assert shards[2].primary_command.op.type == OpType.Measure + assert len(shards[2].sub_commands.items()) == 1 + s2_qubit, s2_sub_cmds = next(iter(shards[2].sub_commands.items())) + assert s2_qubit == circuit.qubits[0] + assert s2_sub_cmds[0].op.type == OpType.Conditional + assert cast(Conditional, s2_sub_cmds[0].op).op.type == OpType.H + assert s2_sub_cmds[0].qubits == [circuit.qubits[0]] From 3faaa2738a0ebdd7c51af989f577179e8b88e4b2 Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Mon, 2 Oct 2023 11:18:31 -0600 Subject: [PATCH 02/14] Apply suggestions from code review Co-authored-by: Kartik Singhal <130700862+qartik@users.noreply.github.com> --- pytket/phir/sharding/shard.py | 4 ++-- pytket/phir/sharding/sharder.py | 13 +++++-------- tests/test_shard.py | 2 +- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/pytket/phir/sharding/shard.py b/pytket/phir/sharding/shard.py index facf98d..f73531e 100644 --- a/pytket/phir/sharding/shard.py +++ b/pytket/phir/sharding/shard.py @@ -6,7 +6,7 @@ from pytket.unit_id import Bit, Qubit, UnitID -@dataclass(unsafe_hash=False) +@dataclass class Shard: """ A shard is a logical grouping of operations that represents the unit by which @@ -47,7 +47,7 @@ def __post_init__(self) -> None: for sub_command in all_sub_commands: self.bits_written.update(sub_command.bits) self.bits_read.update( - set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore # noqa: PGH003 + set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc,arg-type] # noqa: E501 ) def pretty_print(self) -> str: diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index e40d3b3..1455495 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -66,8 +66,8 @@ def _build_shard(self, command: Command) -> None: # Handle dependency calculations depends_upon: set[int] = set() for shard in self._shards: - # Check qubit dependencies (R/W implicitly) since all commands on a given qubit - # need to be ordered as the circuit dictated + # Check qubit dependencies (R/W implicitly) since all commands + # on a given qubit need to be ordered as the circuit dictated if not shard.qubits_used.isdisjoint(command.qubits): depends_upon.add(shard.ID) # Check classical dependencies, which depend on writing and reading @@ -83,7 +83,7 @@ def _cleanup_remaining_commands(self) -> None: Checks for any remaining "unsharded" commands, and if found, adds them to Barrier op shards for each qubit """ - remaining_qubits = [k for k, v in self._pending_commands.items() if len(v) > 0] + remaining_qubits = [k for k, v in self._pending_commands.items() if v] for qubit in remaining_qubits: self._circuit.add_barrier([qubit]) # Easiest way to get to a command, since there's no constructor. Could @@ -110,9 +110,6 @@ def should_op_create_shard(op: Op) -> bool: This includes non-gate operations like measure/reset as well as 2-qubit gates. TODO: This is almost certainly inadequate right now """ - return ( - op.type == OpType.Measure - or op.type == OpType.Reset - or op.type == OpType.Barrier - or (op.is_gate() and op.n_qubits > 1) + return op.type in (OpType.Measure, OpType.Reset, OpType.Barrier) or ( + op.is_gate() and op.n_qubits > 1 ) diff --git a/tests/test_shard.py b/tests/test_shard.py index 11008b2..32968f0 100644 --- a/tests/test_shard.py +++ b/tests/test_shard.py @@ -28,7 +28,7 @@ def test_shard_ctor_conditional(self) -> None: circuit = Circuit(4, 4) circuit.H(0) circuit.Measure(0, 0) - circuit.X(1, condition_bits=[0], condition_value=1) # type: ignore # noqa: PGH003 + circuit.X(1, condition_bits=[0], condition_value=1) # type: ignore [misc] circuit.Measure(1, 1) # The command we'll build the shard from commands = circuit.get_commands() From 3d85f40b42cfaf2ed451d7c8ecd836236bdb19ce Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Tue, 3 Oct 2023 11:37:58 -0600 Subject: [PATCH 03/14] in progress --- pytket/phir/sharding/shard.py | 40 ++++++++--------- pytket/phir/sharding/sharder.py | 53 ++++++++++++++++++---- tests/data/qasm/cond_classical.qasm | 24 ++++++++++ tests/data/qasm/simple_cond.qasm | 2 + tests/sample_data.py | 1 + tests/test_sharder.py | 70 +++++++++++++++++++++++++++-- 6 files changed, 158 insertions(+), 32 deletions(-) create mode 100644 tests/data/qasm/cond_classical.qasm diff --git a/pytket/phir/sharding/shard.py b/pytket/phir/sharding/shard.py index f73531e..99799ef 100644 --- a/pytket/phir/sharding/shard.py +++ b/pytket/phir/sharding/shard.py @@ -13,6 +13,9 @@ class Shard: we actually do placement of qubits """ + # The unique identifier of the shard + ID: int = field(default_factory=count().__next__, init=False) + # The "schedulable" command of the shard primary_command: Command @@ -20,35 +23,32 @@ class Shard: # as a map of bit-handle (unitID) -> list[Command] sub_commands: dict[UnitID, list[Command]] - # A set of the identifiers of other shards this particular shard depends upon - depends_upon: set[int] - # All qubits used by the primary and sub commands - qubits_used: set[Qubit] = field(init=False) + qubits_used: set[Qubit] # = field(init=False) # Set of all classical bits written to by the primary and sub commands - bits_written: set[Bit] = field(init=False) + bits_written: set[Bit] # = field(init=False) # Set of all classical bits read by the primary and sub commands - bits_read: set[Bit] = field(init=False) + bits_read: set[Bit] # = field(init=False) - # The unique identifier of the shard - ID: int = field(default_factory=count().__next__, init=False) + # A set of the identifiers of other shards this particular shard depends upon + depends_upon: set[int] - def __post_init__(self) -> None: - self.qubits_used = set(self.primary_command.qubits) - self.bits_written = set(self.primary_command.bits) - self.bits_read = set() + # def __post_init__(self) -> None: + # self.qubits_used = set(self.primary_command.qubits) + # self.bits_written = set(self.primary_command.bits) + # self.bits_read = set() - all_sub_commands: list[Command] = [] - for sub_commands in self.sub_commands.values(): - all_sub_commands.extend(sub_commands) + # all_sub_commands: list[Command] = [] + # for sub_commands in self.sub_commands.values(): + # all_sub_commands.extend(sub_commands) - for sub_command in all_sub_commands: - self.bits_written.update(sub_command.bits) - self.bits_read.update( - set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc,arg-type] # noqa: E501 - ) + # for sub_command in all_sub_commands: + # self.bits_written.update(sub_command.bits) + # self.bits_read.update( + # set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc,arg-type] # noqa: E501 + # ) def pretty_print(self) -> str: output = io.StringIO() diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index 1455495..a7b7317 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -1,10 +1,19 @@ from pytket.circuit import Circuit, Command, Op, OpType -from pytket.unit_id import UnitID +from pytket.unit_id import Bit, UnitID from .shard import Shard NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox, OpType.WASM] +SHARD_TRIGGER_OP_TYPES = [ + OpType.Measure, + OpType.Reset, + OpType.Barrier, + OpType.SetBits, + OpType.ClassicalExpBox, # some classical operations are rolled up into a box + OpType.RangePredicate, +] + class Sharder: """ @@ -46,7 +55,9 @@ def _process_command(self, command: Command) -> None: raise NotImplementedError(msg) if self.should_op_create_shard(command.op): - print(f"Building shard for command: {command}") + print( + f"Building shard for command: {command} args:{command.args} bits:{command.bits}", + ) self._build_shard(command) else: self._add_pending_sub_command(command) @@ -63,6 +74,22 @@ def _build_shard(self, command: Command) -> None: ): sub_commands[key] = self._pending_commands.pop(key) + all_commands = [command] + for sub_command in sub_commands.values(): + all_commands.extend(sub_command) + + qubits_used = set(command.qubits) + + bits_written = set(command.bits) + + bits_read = set() + + for sub_command in all_commands: + bits_written.update(sub_command.bits) + bits_read.update( + set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc,arg-type] # noqa: E501 + ) + # Handle dependency calculations depends_upon: set[int] = set() for shard in self._shards: @@ -72,9 +99,14 @@ def _build_shard(self, command: Command) -> None: depends_upon.add(shard.ID) # Check classical dependencies, which depend on writing and reading # hazards: RAW, WAW, WAR - # TODO: Do it! + elif not shard.bits_written.isdisjoint(bits_written): + depends_upon.add(shard.ID) + elif not shard.bits_read.isdisjoint(bits_written): + depends_upon.add(shard.ID) - shard = Shard(command, sub_commands, depends_upon) + shard = Shard( + command, sub_commands, qubits_used, bits_written, bits_read, depends_upon, + ) self._shards.append(shard) print("Appended shard:", shard) @@ -101,15 +133,20 @@ def _add_pending_sub_command(self, command: Command) -> None: if key not in self._pending_commands: self._pending_commands[key] = [] self._pending_commands[key].append(command) - print(f"Adding pending command {command}") + print( + f"Adding pending command {command} args: {command.args} bits: {command.bits}", + ) @staticmethod def should_op_create_shard(op: Op) -> bool: """ Returns `True` if the operation is one that should result in shard creation. This includes non-gate operations like measure/reset as well as 2-qubit gates. - TODO: This is almost certainly inadequate right now """ - return op.type in (OpType.Measure, OpType.Reset, OpType.Barrier) or ( - op.is_gate() and op.n_qubits > 1 + return ( + op.type in (SHARD_TRIGGER_OP_TYPES) + or ( + op.type == OpType.Conditional and op.op.type in (SHARD_TRIGGER_OP_TYPES) + ) + or (op.is_gate() and op.n_qubits > 1) ) diff --git a/tests/data/qasm/cond_classical.qasm b/tests/data/qasm/cond_classical.qasm new file mode 100644 index 0000000..2b3cb60 --- /dev/null +++ b/tests/data/qasm/cond_classical.qasm @@ -0,0 +1,24 @@ +OPENQASM 2.0; +include "hqslib1_dev.inc"; +qreg q[1]; +creg a[10]; +creg b[10]; +creg c[4]; +// classical assignment of registers +a[0] = 1; +a = 3; +// classical bitwise functions +a = 1; +b = 3; +c = a ^ b; // XOR +// evaluating a beyond creg == int +a = 1; +b = 2; +if(a[0]==1) x q[0]; +if(a!=1) x q[0]; +if(a>1) x q[0]; +if(a<1) x q[0]; +if(a>=1) x q[0]; +if(a<=1) x q[0]; +if (a==10) b=1; +measure q[0] -> c[0]; diff --git a/tests/data/qasm/simple_cond.qasm b/tests/data/qasm/simple_cond.qasm index d590bb6..30a1b95 100644 --- a/tests/data/qasm/simple_cond.qasm +++ b/tests/data/qasm/simple_cond.qasm @@ -3,9 +3,11 @@ include "hqslib1.inc"; qreg q[1]; creg c[1]; +creg z[1]; h q; measure q->c; reset q; if (c==1) h q; +if (c==1) z=3; measure q->c; diff --git a/tests/sample_data.py b/tests/sample_data.py index 4e7d3c5..a6b15db 100644 --- a/tests/sample_data.py +++ b/tests/sample_data.py @@ -11,6 +11,7 @@ class QasmFiles(Enum): baby = 4 baby_with_rollup = 5 simple_cond = 6 + cond_classical = 7 def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit: diff --git a/tests/test_sharder.py b/tests/test_sharder.py index 82011d0..57ee256 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -13,6 +13,7 @@ def test_should_op_create_shard(self) -> None: Op.create(OpType.Reset), # type: ignore # noqa: PGH003 Op.create(OpType.CX), # type: ignore # noqa: PGH003 Op.create(OpType.Barrier), # type: ignore # noqa: PGH003 + # Op.create(OpType.SetBits, [3, 1]), ] expected_false: list[Op] = [ Op.create(OpType.U1, 0.32), # type: ignore # noqa: PGH003 @@ -86,21 +87,82 @@ def test_simple_conditional(self) -> None: sharder = Sharder(circuit) shards = sharder.shard() - assert len(shards) == 3 + assert len(shards) == 4 + # shard 0: h q; measure q->c; assert shards[0].primary_command.op.type == OpType.Measure assert len(shards[0].sub_commands.items()) == 1 s0_qubit, s0_sub_cmds = next(iter(shards[0].sub_commands.items())) assert s0_qubit == circuit.qubits[0] assert s0_sub_cmds[0].op.type == OpType.H + assert shards[0].depends_upon == set() + # shard 1: reset q; assert shards[1].primary_command.op.type == OpType.Reset assert len(shards[1].sub_commands.items()) == 0 + assert shards[1].depends_upon == {shards[0].ID} - assert shards[2].primary_command.op.type == OpType.Measure - assert len(shards[2].sub_commands.items()) == 1 - s2_qubit, s2_sub_cmds = next(iter(shards[2].sub_commands.items())) + # shard 2: if (c==1) z=3; + assert shards[2].primary_command.op.type == OpType.Conditional + assert cast(Conditional, shards[2].primary_command).op.op.type == OpType.SetBits + assert len(shards[2].sub_commands.keys()) == 0 + assert shards[2].depends_upon == {shards[0].ID} + + # shard 3: if (c==1) h q; measure q->c; + assert shards[3].primary_command.op.type == OpType.Measure + assert len(shards[3].sub_commands.items()) == 1 + s2_qubit, s2_sub_cmds = next(iter(shards[3].sub_commands.items())) assert s2_qubit == circuit.qubits[0] assert s2_sub_cmds[0].op.type == OpType.Conditional assert cast(Conditional, s2_sub_cmds[0].op).op.type == OpType.H assert s2_sub_cmds[0].qubits == [circuit.qubits[0]] + + def test_classical_with_conditionals(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.cond_classical) + sharder = Sharder(circuit) + shards = sharder.shard() + + circuit.get_commands() + + # assert len(shards) == 10 # TODO: fix with correct value + + # shard 0: a[0] = 1; + assert shards[0].primary_command.op.type == OpType.SetBits + assert len(shards[0].sub_commands.keys()) == 0 + assert shards[0].depends_upon == set() + assert shards[0].qubits_used == set() + assert shards[0].bits_read == set() + assert shards[0].bits_written == {circuit.bits[0]} + + # shard 1: a = 3; + assert shards[1].primary_command.op.type == OpType.SetBits + assert len(shards[1].sub_commands.keys()) == 0 + assert shards[1].depends_upon == {shards[0].ID} # WAW for shard 0 + assert shards[1].qubits_used == set() + assert shards[1].bits_read == set() + assert len(shards[1].bits_written) == 10 # TODO: Check for a[0-9] + + # shard 2: a = 1; + assert shards[2].primary_command.op.type == OpType.SetBits + assert len(shards[2].sub_commands.keys()) == 0 + assert shards[2].depends_upon == { + shards[0].ID, + shards[1].ID, + } # WAW for shard 0, 1 + assert shards[2].qubits_used == set() + assert shards[2].bits_read == set() + assert len(shards[2].bits_written) == 10 # TODO: Check for a[0-9] + + # shard 3: b = 3; + assert shards[3].primary_command.op.type == OpType.SetBits + assert len(shards[3].sub_commands.keys()) == 0 + assert shards[3].depends_upon == set() + assert shards[3].qubits_used == set() + assert shards[3].bits_read == set() + assert len(shards[3].bits_written) == 10 # TODO: Check for b[0-9] + + # shard 4: c = a ^ b; // XOR + assert shards[4].primary_command.op.type == OpType.ClassicalExpBox + assert len(shards[3].sub_commands.keys()) == 0 + assert shards[3].depends_upon == set() + assert shards[4].qubits_used == set() From e9add68cfa1b8787908c025e813eb6e09cd0d3a9 Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Wed, 4 Oct 2023 10:23:26 -0600 Subject: [PATCH 04/14] test progress --- pytket/phir/sharding/shard.py | 10 +--- pytket/phir/sharding/sharder.py | 46 ++++++++++++---- tests/data/qasm/cond_classical.qasm | 1 + tests/data/qasm/simple_cond.qasm | 2 +- tests/test_shard.py | 82 ++++++++++++++--------------- tests/test_sharder.py | 72 ++++++------------------- 6 files changed, 96 insertions(+), 117 deletions(-) diff --git a/pytket/phir/sharding/shard.py b/pytket/phir/sharding/shard.py index 99799ef..00af2ad 100644 --- a/pytket/phir/sharding/shard.py +++ b/pytket/phir/sharding/shard.py @@ -54,21 +54,15 @@ def pretty_print(self) -> str: output = io.StringIO() output.write(f"Shard {self.ID}:") output.write(f"\n Command: {self.primary_command}") - output.write( - f'\n Qubits used: [{", ".join(repr(x) for x in self.qubits_used)}]', - ) output.write("\n Sub commands: ") if not self.sub_commands: - output.write("none") + output.write("[]") for sub in self.sub_commands: output.write(f"\n {sub}: {self.sub_commands[sub]}") output.write(f"\n Qubits used: {self.qubits_used}") output.write(f"\n Bits written: {self.bits_written}") output.write(f"\n Bits read: {self.bits_read}") - output.write("\n Depends upon shards: ") - if not self.depends_upon: - output.write("none") - output.write(", ".join(map(repr, self.depends_upon))) + output.write(f"\n Depends upon: {self.depends_upon}") content = output.getvalue() output.close() return content diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index a7b7317..64477d6 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -1,4 +1,6 @@ -from pytket.circuit import Circuit, Command, Op, OpType +from typing import cast + +from pytket.circuit import Circuit, Command, Conditional, Op, OpType from pytket.unit_id import Bit, UnitID from .shard import Shard @@ -39,6 +41,7 @@ def shard(self) -> list[Shard]: self._process_command(command) self._cleanup_remaining_commands() + print("--------------------------------------------") print("Shard output:") for shard in self._shards: print(shard.pretty_print()) @@ -67,7 +70,7 @@ def _build_shard(self, command: Command) -> None: Creates a Shard object given the extant sharding context and the schedulable Command object passed in, and appends it to the Shard list """ - # Resolve any sub commands (SQ gates) that interact with the same qubits + # Rollup any sub commands (SQ gates) that interact with the same qubits sub_commands: dict[UnitID, list[Command]] = {} for key in ( key for key in list(self._pending_commands) if key in command.qubits @@ -75,19 +78,17 @@ def _build_shard(self, command: Command) -> None: sub_commands[key] = self._pending_commands.pop(key) all_commands = [command] - for sub_command in sub_commands.values(): - all_commands.extend(sub_command) + for sub_command_list in sub_commands.values(): + all_commands.extend(sub_command_list) qubits_used = set(command.qubits) - bits_written = set(command.bits) - - bits_read = set() + bits_read: set[Bit] = set() for sub_command in all_commands: bits_written.update(sub_command.bits) bits_read.update( - set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc,arg-type] # noqa: E501 + set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc, arg-type] ) # Handle dependency calculations @@ -96,16 +97,38 @@ def _build_shard(self, command: Command) -> None: # Check qubit dependencies (R/W implicitly) since all commands # on a given qubit need to be ordered as the circuit dictated if not shard.qubits_used.isdisjoint(command.qubits): + print(f"...adding shard dep {shard.ID} -> qubit overlap") depends_upon.add(shard.ID) # Check classical dependencies, which depend on writing and reading # hazards: RAW, WAW, WAR + # NOTE: bits_read will include bits_written in the current impl + + # Check for write-after-write (changing order would change final value) + # by looking at overlap of bits_written elif not shard.bits_written.isdisjoint(bits_written): + print(f"...adding shard dep {shard.ID} -> WAW") depends_upon.add(shard.ID) - elif not shard.bits_read.isdisjoint(bits_written): + + # Check for read-after-write (value seen would change if reordered) + # elif not shard.bits_read.isdisjoint(bits_written): + # print(f'...adding shard dep {shard.ID} -> ') + # depends_upon.add(shard.ID) + elif not shard.bits_written.isdisjoint(bits_read): + print(f"...adding shard dep {shard.ID} -> RAW") + depends_upon.add(shard.ID) + + # Check for write-after-read (no reordering or read is changed) + elif not shard.bits_written.isdisjoint(bits_read): + print(f"...adding shard dep {shard.ID} -> WAR") depends_upon.add(shard.ID) shard = Shard( - command, sub_commands, qubits_used, bits_written, bits_read, depends_upon, + command, + sub_commands, + qubits_used, + bits_written, + bits_read, + depends_upon, ) self._shards.append(shard) print("Appended shard:", shard) @@ -146,7 +169,8 @@ def should_op_create_shard(op: Op) -> bool: return ( op.type in (SHARD_TRIGGER_OP_TYPES) or ( - op.type == OpType.Conditional and op.op.type in (SHARD_TRIGGER_OP_TYPES) + op.type == OpType.Conditional + and cast(Conditional, op).op.type in (SHARD_TRIGGER_OP_TYPES) ) or (op.is_gate() and op.n_qubits > 1) ) diff --git a/tests/data/qasm/cond_classical.qasm b/tests/data/qasm/cond_classical.qasm index 2b3cb60..6d7ca18 100644 --- a/tests/data/qasm/cond_classical.qasm +++ b/tests/data/qasm/cond_classical.qasm @@ -4,6 +4,7 @@ qreg q[1]; creg a[10]; creg b[10]; creg c[4]; + // classical assignment of registers a[0] = 1; a = 3; diff --git a/tests/data/qasm/simple_cond.qasm b/tests/data/qasm/simple_cond.qasm index 30a1b95..5334a94 100644 --- a/tests/data/qasm/simple_cond.qasm +++ b/tests/data/qasm/simple_cond.qasm @@ -9,5 +9,5 @@ h q; measure q->c; reset q; if (c==1) h q; -if (c==1) z=3; +if (c==1) z=1; measure q->c; diff --git a/tests/test_shard.py b/tests/test_shard.py index 32968f0..e2f03e7 100644 --- a/tests/test_shard.py +++ b/tests/test_shard.py @@ -1,46 +1,44 @@ -from pytket.circuit import Circuit -from pytket.phir.sharding.shard import Shard - EMPTY_INT_SET: set[int] = set() class TestShard: - def test_shard_ctor(self) -> None: - circ = Circuit(4) # qubits are numbered 0-3 - circ.X(0) # first apply an X gate to qubit 0 - circ.CX(1, 3) # and apply a CX gate with control qubit 1 and target qubit 3 - circ.Z(3) # then apply a Z gate to qubit 3 - commands = circ.get_commands() - - shard = Shard( - commands[1], - {commands[0].qubits[0]: [commands[0]]}, - EMPTY_INT_SET, - ) - - assert shard.primary_command == commands[1] - assert shard.depends_upon == EMPTY_INT_SET - sub_command_key, sub_command_value = next(iter(shard.sub_commands.items())) - assert sub_command_key == commands[0].qubits[0] - assert sub_command_value[0] == commands[0] - - def test_shard_ctor_conditional(self) -> None: - circuit = Circuit(4, 4) - circuit.H(0) - circuit.Measure(0, 0) - circuit.X(1, condition_bits=[0], condition_value=1) # type: ignore [misc] - circuit.Measure(1, 1) # The command we'll build the shard from - commands = circuit.get_commands() - - shard = Shard( - commands[3], - { - circuit.qubits[0]: [commands[2]], - }, - EMPTY_INT_SET, - ) - - assert len(shard.sub_commands.items()) - assert shard.qubits_used == {circuit.qubits[1]} - assert shard.bits_read == {circuit.bits[0]} - assert shard.bits_written == {circuit.bits[1]} + pass + # def test_shard_ctor(self) -> None: + # circ = Circuit(4) # qubits are numbered 0-3 + # circ.X(0) # first apply an X gate to qubit 0 + # circ.CX(1, 3) # and apply a CX gate with control qubit 1 and target qubit 3 + # circ.Z(3) # then apply a Z gate to qubit 3 + # commands = circ.get_commands() + + # shard = Shard( + # commands[1], + # {commands[0].qubits[0]: [commands[0]]}, + # EMPTY_INT_SET, + # ) + + # assert shard.primary_command == commands[1] + # assert shard.depends_upon == EMPTY_INT_SET + # sub_command_key, sub_command_value = next(iter(shard.sub_commands.items())) + # assert sub_command_key == commands[0].qubits[0] + # assert sub_command_value[0] == commands[0] + + # def test_shard_ctor_conditional(self) -> None: + # circuit = Circuit(4, 4) + # circuit.H(0) + # circuit.Measure(0, 0) + # circuit.X(1, condition_bits=[0], condition_value=1) # type: ignore [misc] + # circuit.Measure(1, 1) # The command we'll build the shard from + # commands = circuit.get_commands() + + # shard = Shard( + # commands[3], + # { + # circuit.qubits[0]: [commands[2]], + # }, + # EMPTY_INT_SET, + # ) + + # assert len(shard.sub_commands.items()) + # assert shard.qubits_used == {circuit.qubits[1]} + # assert shard.bits_read == {circuit.bits[0]} + # assert shard.bits_written == {circuit.bits[1]} diff --git a/tests/test_sharder.py b/tests/test_sharder.py index 57ee256..34ad51d 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -89,80 +89,42 @@ def test_simple_conditional(self) -> None: assert len(shards) == 4 - # shard 0: h q; measure q->c; + # shard 0: [h q;] measure q->c; assert shards[0].primary_command.op.type == OpType.Measure + assert shards[0].qubits_used == {circuit.qubits[0]} + assert shards[0].bits_written == {circuit.bits[0]} + assert shards[0].depends_upon == set() assert len(shards[0].sub_commands.items()) == 1 s0_qubit, s0_sub_cmds = next(iter(shards[0].sub_commands.items())) assert s0_qubit == circuit.qubits[0] assert s0_sub_cmds[0].op.type == OpType.H - assert shards[0].depends_upon == set() # shard 1: reset q; assert shards[1].primary_command.op.type == OpType.Reset assert len(shards[1].sub_commands.items()) == 0 + assert shards[1].qubits_used == {circuit.qubits[0]} assert shards[1].depends_upon == {shards[0].ID} + assert shards[1].bits_written == set() + assert shards[1].bits_read == set() - # shard 2: if (c==1) z=3; + # shard 2: if (c==1) z=1; assert shards[2].primary_command.op.type == OpType.Conditional - assert cast(Conditional, shards[2].primary_command).op.op.type == OpType.SetBits + assert cast(Conditional, shards[2].primary_command.op).op.type == OpType.SetBits assert len(shards[2].sub_commands.keys()) == 0 + assert shards[2].qubits_used == set() + assert shards[2].bits_written == {circuit.bits[1]} + # assert shards[2].bits_read == {circuit.bits[0]} assert shards[2].depends_upon == {shards[0].ID} - # shard 3: if (c==1) h q; measure q->c; + # shard 3: [if (c==1) h q;] measure q->c; assert shards[3].primary_command.op.type == OpType.Measure + assert shards[3].qubits_used == {circuit.qubits[0]} + assert shards[3].bits_written == {circuit.bits[0]} + assert shards[3].bits_read == {circuit.bits[0]} + assert shards[3].depends_upon == {shards[0].ID, shards[1].ID} assert len(shards[3].sub_commands.items()) == 1 s2_qubit, s2_sub_cmds = next(iter(shards[3].sub_commands.items())) assert s2_qubit == circuit.qubits[0] assert s2_sub_cmds[0].op.type == OpType.Conditional assert cast(Conditional, s2_sub_cmds[0].op).op.type == OpType.H assert s2_sub_cmds[0].qubits == [circuit.qubits[0]] - - def test_classical_with_conditionals(self) -> None: - circuit = get_qasm_as_circuit(QasmFiles.cond_classical) - sharder = Sharder(circuit) - shards = sharder.shard() - - circuit.get_commands() - - # assert len(shards) == 10 # TODO: fix with correct value - - # shard 0: a[0] = 1; - assert shards[0].primary_command.op.type == OpType.SetBits - assert len(shards[0].sub_commands.keys()) == 0 - assert shards[0].depends_upon == set() - assert shards[0].qubits_used == set() - assert shards[0].bits_read == set() - assert shards[0].bits_written == {circuit.bits[0]} - - # shard 1: a = 3; - assert shards[1].primary_command.op.type == OpType.SetBits - assert len(shards[1].sub_commands.keys()) == 0 - assert shards[1].depends_upon == {shards[0].ID} # WAW for shard 0 - assert shards[1].qubits_used == set() - assert shards[1].bits_read == set() - assert len(shards[1].bits_written) == 10 # TODO: Check for a[0-9] - - # shard 2: a = 1; - assert shards[2].primary_command.op.type == OpType.SetBits - assert len(shards[2].sub_commands.keys()) == 0 - assert shards[2].depends_upon == { - shards[0].ID, - shards[1].ID, - } # WAW for shard 0, 1 - assert shards[2].qubits_used == set() - assert shards[2].bits_read == set() - assert len(shards[2].bits_written) == 10 # TODO: Check for a[0-9] - - # shard 3: b = 3; - assert shards[3].primary_command.op.type == OpType.SetBits - assert len(shards[3].sub_commands.keys()) == 0 - assert shards[3].depends_upon == set() - assert shards[3].qubits_used == set() - assert shards[3].bits_read == set() - assert len(shards[3].bits_written) == 10 # TODO: Check for b[0-9] - - # shard 4: c = a ^ b; // XOR - assert shards[4].primary_command.op.type == OpType.ClassicalExpBox - assert len(shards[3].sub_commands.keys()) == 0 - assert shards[3].depends_upon == set() - assert shards[4].qubits_used == set() From 4af6e369d33b8c4b83cfab8cf18bc241664d1c7f Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Wed, 4 Oct 2023 10:24:13 -0600 Subject: [PATCH 05/14] fixing ruff --- ruff.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ruff.toml b/ruff.toml index 27cb084..c66ae97 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,6 +1,6 @@ target-version = "py310" -line-length = 120 +line-length = 88 select = [ "E", # pycodestyle Errors From 4cfd4233fb4ff7e26a6afacceac46d636d19e013 Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Wed, 4 Oct 2023 10:27:24 -0600 Subject: [PATCH 06/14] removing commented --- pytket/phir/sharding/shard.py | 15 ------------ tests/test_shard.py | 44 ----------------------------------- tests/test_sharder.py | 3 +-- 3 files changed, 1 insertion(+), 61 deletions(-) delete mode 100644 tests/test_shard.py diff --git a/pytket/phir/sharding/shard.py b/pytket/phir/sharding/shard.py index 00af2ad..41b41ee 100644 --- a/pytket/phir/sharding/shard.py +++ b/pytket/phir/sharding/shard.py @@ -35,21 +35,6 @@ class Shard: # A set of the identifiers of other shards this particular shard depends upon depends_upon: set[int] - # def __post_init__(self) -> None: - # self.qubits_used = set(self.primary_command.qubits) - # self.bits_written = set(self.primary_command.bits) - # self.bits_read = set() - - # all_sub_commands: list[Command] = [] - # for sub_commands in self.sub_commands.values(): - # all_sub_commands.extend(sub_commands) - - # for sub_command in all_sub_commands: - # self.bits_written.update(sub_command.bits) - # self.bits_read.update( - # set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc,arg-type] # noqa: E501 - # ) - def pretty_print(self) -> str: output = io.StringIO() output.write(f"Shard {self.ID}:") diff --git a/tests/test_shard.py b/tests/test_shard.py deleted file mode 100644 index e2f03e7..0000000 --- a/tests/test_shard.py +++ /dev/null @@ -1,44 +0,0 @@ -EMPTY_INT_SET: set[int] = set() - - -class TestShard: - pass - # def test_shard_ctor(self) -> None: - # circ = Circuit(4) # qubits are numbered 0-3 - # circ.X(0) # first apply an X gate to qubit 0 - # circ.CX(1, 3) # and apply a CX gate with control qubit 1 and target qubit 3 - # circ.Z(3) # then apply a Z gate to qubit 3 - # commands = circ.get_commands() - - # shard = Shard( - # commands[1], - # {commands[0].qubits[0]: [commands[0]]}, - # EMPTY_INT_SET, - # ) - - # assert shard.primary_command == commands[1] - # assert shard.depends_upon == EMPTY_INT_SET - # sub_command_key, sub_command_value = next(iter(shard.sub_commands.items())) - # assert sub_command_key == commands[0].qubits[0] - # assert sub_command_value[0] == commands[0] - - # def test_shard_ctor_conditional(self) -> None: - # circuit = Circuit(4, 4) - # circuit.H(0) - # circuit.Measure(0, 0) - # circuit.X(1, condition_bits=[0], condition_value=1) # type: ignore [misc] - # circuit.Measure(1, 1) # The command we'll build the shard from - # commands = circuit.get_commands() - - # shard = Shard( - # commands[3], - # { - # circuit.qubits[0]: [commands[2]], - # }, - # EMPTY_INT_SET, - # ) - - # assert len(shard.sub_commands.items()) - # assert shard.qubits_used == {circuit.qubits[1]} - # assert shard.bits_read == {circuit.bits[0]} - # assert shard.bits_written == {circuit.bits[1]} diff --git a/tests/test_sharder.py b/tests/test_sharder.py index 34ad51d..aa1d962 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -13,7 +13,6 @@ def test_should_op_create_shard(self) -> None: Op.create(OpType.Reset), # type: ignore # noqa: PGH003 Op.create(OpType.CX), # type: ignore # noqa: PGH003 Op.create(OpType.Barrier), # type: ignore # noqa: PGH003 - # Op.create(OpType.SetBits, [3, 1]), ] expected_false: list[Op] = [ Op.create(OpType.U1, 0.32), # type: ignore # noqa: PGH003 @@ -113,7 +112,7 @@ def test_simple_conditional(self) -> None: assert len(shards[2].sub_commands.keys()) == 0 assert shards[2].qubits_used == set() assert shards[2].bits_written == {circuit.bits[1]} - # assert shards[2].bits_read == {circuit.bits[0]} + assert shards[2].bits_read == {circuit.bits[0], circuit.bits[1]} assert shards[2].depends_upon == {shards[0].ID} # shard 3: [if (c==1) h q;] measure q->c; From 0b7a5dd208092f9e961cc2902010511457189f61 Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Wed, 4 Oct 2023 10:53:16 -0600 Subject: [PATCH 07/14] linting feedback --- pytket/phir/sharding/sharder.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index 64477d6..7ed763a 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -59,7 +59,7 @@ def _process_command(self, command: Command) -> None: if self.should_op_create_shard(command.op): print( - f"Building shard for command: {command} args:{command.args} bits:{command.bits}", + f"Building shard for command: {command}", ) self._build_shard(command) else: @@ -67,7 +67,7 @@ def _process_command(self, command: Command) -> None: def _build_shard(self, command: Command) -> None: """ - Creates a Shard object given the extant sharding context and the schedulable + Creates a Shard object given the extant sharding context and the primary Command object passed in, and appends it to the Shard list """ # Rollup any sub commands (SQ gates) that interact with the same qubits @@ -85,10 +85,11 @@ def _build_shard(self, command: Command) -> None: bits_written = set(command.bits) bits_read: set[Bit] = set() + # def filter_to_bits: Callable[[UnitId], bool] = lambda x: isinstance(x, Bit) for sub_command in all_commands: bits_written.update(sub_command.bits) bits_read.update( - set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc, arg-type] + set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc, arg-type] # noqa: E501 ) # Handle dependency calculations @@ -157,7 +158,7 @@ def _add_pending_sub_command(self, command: Command) -> None: self._pending_commands[key] = [] self._pending_commands[key].append(command) print( - f"Adding pending command {command} args: {command.args} bits: {command.bits}", + f"Adding pending command {command}", ) @staticmethod From 73366346af5d8b6604486e5b37af559bd82c67c1 Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Wed, 4 Oct 2023 12:16:30 -0600 Subject: [PATCH 08/14] mypy --- tests/test_sharder.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_sharder.py b/tests/test_sharder.py index aa1d962..caada75 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -9,15 +9,15 @@ class TestSharder: def test_should_op_create_shard(self) -> None: expected_true: list[Op] = [ - Op.create(OpType.Measure), # type: ignore # noqa: PGH003 - Op.create(OpType.Reset), # type: ignore # noqa: PGH003 - Op.create(OpType.CX), # type: ignore # noqa: PGH003 - Op.create(OpType.Barrier), # type: ignore # noqa: PGH003 + Op.create(OpType.Measure), # type: ignore [misc] + Op.create(OpType.Reset), # type: ignore [misc] + Op.create(OpType.CX), # type: ignore [misc] + Op.create(OpType.Barrier), # type: ignore [misc] ] expected_false: list[Op] = [ - Op.create(OpType.U1, 0.32), # type: ignore # noqa: PGH003 - Op.create(OpType.H), # type: ignore # noqa: PGH003 - Op.create(OpType.Z), # type: ignore # noqa: PGH003 + Op.create(OpType.U1, 0.32), # type: ignore [misc] + Op.create(OpType.H), # type: ignore [misc] + Op.create(OpType.Z), # type: ignore [misc] ] for op in expected_true: From 7fd37fd1c7bfe87b621c4c1a72876f969a2d18ae Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Tue, 10 Oct 2023 10:37:53 -0600 Subject: [PATCH 09/14] Adding test for partial barriers --- tests/data/qasm/barrier_complex.qasm | 27 ++++++++++++ tests/sample_data.py | 1 + tests/test_sharder.py | 64 ++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+) create mode 100644 tests/data/qasm/barrier_complex.qasm diff --git a/tests/data/qasm/barrier_complex.qasm b/tests/data/qasm/barrier_complex.qasm new file mode 100644 index 0000000..bee4d0b --- /dev/null +++ b/tests/data/qasm/barrier_complex.qasm @@ -0,0 +1,27 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[4]; +creg c[4]; + +h q[0]; +h q[1]; +h q[2]; +h q[3]; +c[3] = 1; + +barrier q[0], q[1], c[3]; + +CX q[0], q[1]; + +measure q[0]->c[0]; + +h q[2]; +h q[3]; +x q[3]; + +barrier q[2], q[3]; + +CX q[2], q[3]; + +measure q[2]->c[2]; diff --git a/tests/sample_data.py b/tests/sample_data.py index 56300f1..983e8be 100644 --- a/tests/sample_data.py +++ b/tests/sample_data.py @@ -12,6 +12,7 @@ class QasmFiles(Enum): baby_with_rollup = 5 simple_cond = 6 cond_classical = 7 + barrier_complex = 8 def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit: diff --git a/tests/test_sharder.py b/tests/test_sharder.py index caada75..de8bdf8 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -127,3 +127,67 @@ def test_simple_conditional(self) -> None: assert s2_sub_cmds[0].op.type == OpType.Conditional assert cast(Conditional, s2_sub_cmds[0].op).op.type == OpType.H assert s2_sub_cmds[0].qubits == [circuit.qubits[0]] + + def test_complex_barriers(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.barrier_complex) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 7 + + # shard 0: [], c[3] = 1 + assert shards[0].primary_command.op.type == OpType.SetBits + assert len(shards[0].sub_commands.items()) == 0 + assert shards[0].qubits_used == set() + assert shards[0].bits_written == {circuit.bits[3]} + assert shards[0].bits_read == {circuit.bits[3]} # bits written are always read + assert shards[0].depends_upon == set() + + # shard 1: [h q[0]; h q[1];] barrier q[0], q[1], c[3]; + assert shards[1].primary_command.op.type == OpType.Barrier + assert len(shards[1].sub_commands.items()) == 2 + # TODO: sub commands + assert shards[1].qubits_used == {circuit.qubits[0], circuit.qubits[1]} + assert shards[1].bits_written == {circuit.bits[3]} + assert shards[1].bits_read == {circuit.bits[3]} + assert shards[1].depends_upon == {shards[0].ID} + + # shard 2: [] CX q[0], q[1]; + assert shards[2].primary_command.op.type == OpType.CX + assert len(shards[2].sub_commands.items()) == 0 + assert shards[2].qubits_used == {circuit.qubits[0], circuit.qubits[1]} + assert shards[2].bits_written == set() + assert shards[2].bits_read == set() + assert shards[2].depends_upon == {shards[1].ID} + + # shard 3: measure q[0]->c[0]; + assert shards[3].primary_command.op.type == OpType.Measure + assert len(shards[3].sub_commands.items()) == 0 + assert shards[3].qubits_used == {circuit.qubits[0]} + assert shards[3].bits_written == {circuit.bits[0]} + assert shards[3].bits_read == {circuit.bits[0]} + assert shards[3].depends_upon == {shards[2].ID, shards[1].ID} + + # shard 4: [] barrier q[2], q[3]; + assert shards[4].primary_command.op.type == OpType.Barrier + assert len(shards[4].sub_commands.items()) == 2 + assert shards[4].qubits_used == {circuit.qubits[2], circuit.qubits[3]} + assert shards[4].bits_written == set() + assert shards[4].bits_read == set() + assert shards[4].depends_upon == set() + + # shard 5: [] CX q[2], q[3]; + assert shards[5].primary_command.op.type == OpType.CX + assert len(shards[5].sub_commands.items()) == 0 + assert shards[5].qubits_used == {circuit.qubits[2], circuit.qubits[3]} + assert shards[5].bits_written == set() + assert shards[5].bits_read == set() + assert shards[5].depends_upon == {shards[4].ID} + + # shard 6: measure q[2]->c[2]; + assert shards[6].primary_command.op.type == OpType.Measure + assert len(shards[6].sub_commands.items()) == 0 + assert shards[6].qubits_used == {circuit.qubits[2]} + assert shards[6].bits_written == {circuit.bits[2]} + assert shards[6].bits_read == {circuit.bits[2]} + assert shards[6].depends_upon == {shards[5].ID, shards[4].ID} From c8c634fc61f50d73b6255b3460170832043b4172 Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Tue, 10 Oct 2023 11:16:02 -0600 Subject: [PATCH 10/14] updating test --- ruff.toml | 3 ++- tests/test_sharder.py | 25 +++++++++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/ruff.toml b/ruff.toml index 0552189..2350fd0 100644 --- a/ruff.toml +++ b/ruff.toml @@ -63,7 +63,8 @@ ignore = [ "D101", # Missing docstring in public class "D102", # Missing docstring in public method "S101", # Use of `assert` detected - "PLR2004" # Magic constants + "PLR2004", # Magic constants + "PLR0915", # Too many statements in function ] [pydocstyle] diff --git a/tests/test_sharder.py b/tests/test_sharder.py index de8bdf8..7d421f0 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -146,7 +146,14 @@ def test_complex_barriers(self) -> None: # shard 1: [h q[0]; h q[1];] barrier q[0], q[1], c[3]; assert shards[1].primary_command.op.type == OpType.Barrier assert len(shards[1].sub_commands.items()) == 2 - # TODO: sub commands + shard_1_q0_cmds = shards[1].sub_commands[circuit.qubits[0]] + assert len(shard_1_q0_cmds) == 1 + assert shard_1_q0_cmds[0].op.type == OpType.H + assert shard_1_q0_cmds[0].qubits == [circuit.qubits[0]] + shard_1_q1_cmds = shards[1].sub_commands[circuit.qubits[1]] + assert len(shard_1_q1_cmds) == 1 + assert shard_1_q1_cmds[0].op.type == OpType.H + assert shard_1_q1_cmds[0].qubits == [circuit.qubits[1]] assert shards[1].qubits_used == {circuit.qubits[0], circuit.qubits[1]} assert shards[1].bits_written == {circuit.bits[3]} assert shards[1].bits_read == {circuit.bits[3]} @@ -168,9 +175,23 @@ def test_complex_barriers(self) -> None: assert shards[3].bits_read == {circuit.bits[0]} assert shards[3].depends_upon == {shards[2].ID, shards[1].ID} - # shard 4: [] barrier q[2], q[3]; + # shard 4: [H q[3];, H q[3];, X q[3];]] barrier q[2], q[3]; assert shards[4].primary_command.op.type == OpType.Barrier assert len(shards[4].sub_commands.items()) == 2 + shard_4_q2_cmds = shards[4].sub_commands[circuit.qubits[2]] + assert len(shard_4_q2_cmds) == 2 + assert shard_4_q2_cmds[0].op.type == OpType.H + assert shard_4_q2_cmds[0].qubits == [circuit.qubits[2]] + assert shard_4_q2_cmds[1].op.type == OpType.H + assert shard_4_q2_cmds[1].qubits == [circuit.qubits[2]] + shard_4_q3_cmds = shards[4].sub_commands[circuit.qubits[3]] + assert len(shard_4_q3_cmds) == 3 + assert shard_4_q3_cmds[0].op.type == OpType.H + assert shard_4_q3_cmds[0].qubits == [circuit.qubits[3]] + assert shard_4_q3_cmds[1].op.type == OpType.H + assert shard_4_q3_cmds[1].qubits == [circuit.qubits[3]] + assert shard_4_q3_cmds[2].op.type == OpType.X + assert shard_4_q3_cmds[2].qubits == [circuit.qubits[3]] assert shards[4].qubits_used == {circuit.qubits[2], circuit.qubits[3]} assert shards[4].bits_written == set() assert shards[4].bits_read == set() From e04dc692f82841cac5fff296f5a5f9c475d6253b Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Tue, 10 Oct 2023 11:24:14 -0600 Subject: [PATCH 11/14] cleanup comments --- pytket/phir/sharding/sharder.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index 0639713..8ec5a92 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -45,7 +45,6 @@ def shard(self) -> list[Shard]: list of Shards needed to schedule """ print("Sharding begins....") - # https://cqcl.github.io/tket/pytket/api/circuit.html#pytket.circuit.Command for command in self._circuit.get_commands(): self._process_command(command) self._cleanup_remaining_commands() @@ -99,7 +98,6 @@ def _build_shard(self, command: Command) -> None: bits_written = set(command.bits) bits_read: set[Bit] = set() - # def filter_to_bits: Callable[[UnitId], bool] = lambda x: isinstance(x, Bit) for sub_command in all_commands: bits_written.update(sub_command.bits) bits_read.update( @@ -125,9 +123,6 @@ def _build_shard(self, command: Command) -> None: depends_upon.add(shard.ID) # Check for read-after-write (value seen would change if reordered) - # elif not shard.bits_read.isdisjoint(bits_written): - # print(f'...adding shard dep {shard.ID} -> ') - # depends_upon.add(shard.ID) elif not shard.bits_written.isdisjoint(bits_read): print(f"...adding shard dep {shard.ID} -> RAW") depends_upon.add(shard.ID) From fca911006c5f5be2f1bc8c69c1844beb3e5fd8b6 Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Tue, 10 Oct 2023 13:19:21 -0600 Subject: [PATCH 12/14] another classical test --- tests/data/qasm/classical_hazards.qasm | 19 ++++++++++ tests/sample_data.py | 1 + tests/test_sharder.py | 48 ++++++++++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 tests/data/qasm/classical_hazards.qasm diff --git a/tests/data/qasm/classical_hazards.qasm b/tests/data/qasm/classical_hazards.qasm new file mode 100644 index 0000000..b38fd29 --- /dev/null +++ b/tests/data/qasm/classical_hazards.qasm @@ -0,0 +1,19 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[2]; +creg c[3]; + +h q[0]; + +measure q[0]->c[0]; + +if(c[0]==1) c[1]=1; // RAW + +c[0]=0; // WAR + +h q[1]; + +measure q[1]->c[2]; + +if(c[2]==1) c[0]=1; // WAW diff --git a/tests/sample_data.py b/tests/sample_data.py index 983e8be..63e632f 100644 --- a/tests/sample_data.py +++ b/tests/sample_data.py @@ -13,6 +13,7 @@ class QasmFiles(Enum): simple_cond = 6 cond_classical = 7 barrier_complex = 8 + classical_hazards = 9 def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit: diff --git a/tests/test_sharder.py b/tests/test_sharder.py index 7d421f0..8e36b76 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -212,3 +212,51 @@ def test_complex_barriers(self) -> None: assert shards[6].bits_written == {circuit.bits[2]} assert shards[6].bits_read == {circuit.bits[2]} assert shards[6].depends_upon == {shards[5].ID, shards[4].ID} + + def test_classical_hazards(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.classical_hazards) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 5 + + # shard 0: [h q[0];] measure q[0]->c[0]; + assert shards[0].primary_command.op.type == OpType.Measure + assert len(shards[0].sub_commands.items()) == 1 + assert shards[0].qubits_used == {circuit.qubits[0]} + assert shards[0].bits_written == {circuit.bits[0]} + assert shards[0].bits_read == {circuit.bits[0]} + assert shards[0].depends_upon == set() + + # shard 1: [H q[1];] measure q[1]->c[2]; + # NOTE: pytket reorganizes circuits to be efficiently ordered + assert shards[1].primary_command.op.type == OpType.Measure + assert len(shards[1].sub_commands.items()) == 1 + assert shards[1].qubits_used == {circuit.qubits[1]} + assert shards[1].bits_written == {circuit.bits[2]} + assert shards[1].bits_read == {circuit.bits[2]} + assert shards[1].depends_upon == set() + + # shard 2: [] if(c[0]==1) c[1]=1; + assert shards[2].primary_command.op.type == OpType.Conditional + assert len(shards[2].sub_commands) == 0 + assert shards[2].qubits_used == set() + assert shards[2].bits_written == {circuit.bits[1]} + assert shards[2].bits_read == {circuit.bits[1], circuit.bits[0]} + assert shards[2].depends_upon == {shards[0].ID} + + # shard 3: [] c[0]=0; + assert shards[3].primary_command.op.type == OpType.SetBits + assert len(shards[2].sub_commands) == 0 + assert shards[3].qubits_used == set() + assert shards[3].bits_written == {circuit.bits[0]} + assert shards[3].bits_read == {circuit.bits[0]} + assert shards[3].depends_upon == {shards[0].ID} + + # shard 4: [] if(c[2]==1) c[0]=1; + assert shards[4].primary_command.op.type == OpType.Conditional + assert len(shards[4].sub_commands) == 0 + assert shards[4].qubits_used == set() + assert shards[4].bits_written == {circuit.bits[0]} + assert shards[4].bits_read == {circuit.bits[0], circuit.bits[2]} + assert shards[4].depends_upon == {shards[1].ID, shards[0].ID, shards[3].ID} From 029bbd3fe0bcf3d0cbc6928e44a6bbdac086123e Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Tue, 10 Oct 2023 15:28:48 -0600 Subject: [PATCH 13/14] big gate test --- tests/data/qasm/big_gate.qasm | 13 +++++++++++++ tests/sample_data.py | 1 + tests/test_sharder.py | 24 ++++++++++++++++++++++++ 3 files changed, 38 insertions(+) create mode 100644 tests/data/qasm/big_gate.qasm diff --git a/tests/data/qasm/big_gate.qasm b/tests/data/qasm/big_gate.qasm new file mode 100644 index 0000000..745122c --- /dev/null +++ b/tests/data/qasm/big_gate.qasm @@ -0,0 +1,13 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +creg c[1]; +qreg q[4]; + +h q[0]; +h q[1]; +h q[2]; + +c4x q[0],q[1],q[2],q[3]; + +measure q[3] -> c[0]; diff --git a/tests/sample_data.py b/tests/sample_data.py index 63e632f..a9f26b6 100644 --- a/tests/sample_data.py +++ b/tests/sample_data.py @@ -14,6 +14,7 @@ class QasmFiles(Enum): cond_classical = 7 barrier_complex = 8 classical_hazards = 9 + big_gate = 10 def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit: diff --git a/tests/test_sharder.py b/tests/test_sharder.py index 8e36b76..8c80efc 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -260,3 +260,27 @@ def test_classical_hazards(self) -> None: assert shards[4].bits_written == {circuit.bits[0]} assert shards[4].bits_read == {circuit.bits[0], circuit.bits[2]} assert shards[4].depends_upon == {shards[1].ID, shards[0].ID, shards[3].ID} + + def test_with_big_gate(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.big_gate) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 2 + + # shard 0: [h q[0]; h q[1]; h q[2];] c4x q[0],q[1],q[2],q[3]; + assert shards[0].primary_command.op.type == OpType.CnX + assert len(shards[0].sub_commands) == 3 + assert shards[0].qubits_used == { + circuit.qubits[0], + circuit.qubits[1], + circuit.qubits[2], + circuit.qubits[3], + } + assert shards[0].bits_written == set() + + # shard 1: [] measure q[3]->[c0] + assert shards[1].primary_command.op.type == OpType.Measure + assert len(shards[1].sub_commands) == 0 + assert shards[1].qubits_used == {circuit.qubits[3]} + assert shards[1].bits_written == {circuit.bits[0]} From 1face99906e19fbf4e68c8bc07e1ecf79e7d2ad5 Mon Sep 17 00:00:00 2001 From: Neal Erickson Date: Wed, 11 Oct 2023 12:01:10 -0600 Subject: [PATCH 14/14] hash method for shards to allow sets --- pytket/phir/sharding/shard.py | 12 ++++++++---- tests/test_sharder.py | 9 +++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pytket/phir/sharding/shard.py b/pytket/phir/sharding/shard.py index c0e8291..8c1ee10 100644 --- a/pytket/phir/sharding/shard.py +++ b/pytket/phir/sharding/shard.py @@ -6,7 +6,7 @@ from pytket.unit_id import Bit, Qubit, UnitID -@dataclass +@dataclass(frozen=True) class Shard: """The Shard class. @@ -25,17 +25,21 @@ class Shard: sub_commands: dict[UnitID, list[Command]] # All qubits used by the primary and sub commands - qubits_used: set[Qubit] # = field(init=False) + qubits_used: set[Qubit] # Set of all classical bits written to by the primary and sub commands - bits_written: set[Bit] # = field(init=False) + bits_written: set[Bit] # Set of all classical bits read by the primary and sub commands - bits_read: set[Bit] # = field(init=False) + bits_read: set[Bit] # A set of the identifiers of other shards this particular shard depends upon depends_upon: set[int] + def __hash__(self) -> int: + """Hashing for shards is done only by its autogen unique int ID.""" + return self.ID + def pretty_print(self) -> str: """Returns the shard in a human-friendly format.""" output = io.StringIO() diff --git a/tests/test_sharder.py b/tests/test_sharder.py index 8c80efc..d65b51e 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -7,6 +7,15 @@ class TestSharder: + def test_shard_hashing(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.baby) + sharder = Sharder(circuit) + shards = sharder.shard() + + shard_set = set(shards) + + assert len(shard_set) > 0 + def test_should_op_create_shard(self) -> None: expected_true: list[Op] = [ Op.create(OpType.Measure), # type: ignore [misc]