diff --git a/pyproject.toml b/pyproject.toml index ca5665b..fa40d33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ Repository = "https://github.com/CQCL/pytket-phir.git" where = ["."] [tool.pytest.ini_options] +addopts = "-s -vv" pythonpath = [ "." ] +log_cli = true +filterwarnings = ["ignore:::lark.s*"] diff --git a/pytket/phir/sharding/shard.py b/pytket/phir/sharding/shard.py index 29ac0de..8c1ee10 100644 --- a/pytket/phir/sharding/shard.py +++ b/pytket/phir/sharding/shard.py @@ -1,10 +1,12 @@ -from dataclasses import dataclass +import io +from dataclasses import dataclass, field +from itertools import count from pytket.circuit import Command -from pytket.unit_id import UnitID +from pytket.unit_id import Bit, Qubit, UnitID -@dataclass +@dataclass(frozen=True) class Shard: """The Shard class. @@ -12,13 +14,46 @@ class Shard: we actually do placement of qubits. """ - # The schedulable command of the shard + # The unique identifier of the shard + ID: int = field(default_factory=count().__next__, init=False) + + # The "schedulable" command of the shard primary_command: Command # The other commands related to the primary schedulable command, stored # as a map of bit-handle (unitID) -> list[Command] sub_commands: dict[UnitID, list[Command]] - # A set of the other shards this particular shard depends upon, and thus - # must be scheduled after - depends_upon: set["Shard"] + # All qubits used by the primary and sub commands + qubits_used: set[Qubit] + + # Set of all classical bits written to by the primary and sub commands + bits_written: set[Bit] + + # Set of all classical bits read by the primary and sub commands + bits_read: set[Bit] + + # A set of the identifiers of other shards this particular shard depends upon + depends_upon: set[int] + + def __hash__(self) -> int: + """Hashing for shards is done only by its autogen unique int ID.""" + return self.ID + + def pretty_print(self) -> str: + """Returns the shard in a human-friendly format.""" + output = io.StringIO() + output.write(f"Shard {self.ID}:") + output.write(f"\n Command: {self.primary_command}") + output.write("\n Sub commands: ") + if not self.sub_commands: + output.write("[]") + for sub in self.sub_commands: + output.write(f"\n {sub}: {self.sub_commands[sub]}") + output.write(f"\n Qubits used: {self.qubits_used}") + output.write(f"\n Bits written: {self.bits_written}") + output.write(f"\n Bits read: {self.bits_read}") + output.write(f"\n Depends upon: {self.depends_upon}") + content = output.getvalue() + output.close() + return content diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index 3862563..8ec5a92 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -1,10 +1,21 @@ -from pytket.circuit import Circuit, Command, Op, OpType -from pytket.unit_id import UnitID +from typing import cast + +from pytket.circuit import Circuit, Command, Conditional, Op, OpType +from pytket.unit_id import Bit, UnitID from .shard import Shard NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox, OpType.WASM] +SHARD_TRIGGER_OP_TYPES = [ + OpType.Measure, + OpType.Reset, + OpType.Barrier, + OpType.SetBits, + OpType.ClassicalExpBox, # some classical operations are rolled up into a box + OpType.RangePredicate, +] + class Sharder: """The Sharder class. @@ -22,9 +33,9 @@ def __init__(self, circuit: Circuit) -> None: circuit: tket Circuit """ self._circuit = circuit - print(f"Sharder created for circuit {self._circuit}") self._pending_commands: dict[UnitID, list[Command]] = {} self._shards: list[Shard] = [] + print(f"Sharder created for circuit {self._circuit}") def shard(self) -> list[Shard]: """Performs sharding algorithm on the circuit the Sharder was initialized with. @@ -34,9 +45,14 @@ def shard(self) -> list[Shard]: list of Shards needed to schedule """ print("Sharding begins....") - # https://cqcl.github.io/tket/pytket/api/circuit.html#pytket.circuit.Command for command in self._circuit.get_commands(): self._process_command(command) + self._cleanup_remaining_commands() + + print("--------------------------------------------") + print("Shard output:") + for shard in self._shards: + print(shard.pretty_print()) return self._shards def _process_command(self, command: Command) -> None: @@ -51,10 +67,12 @@ def _process_command(self, command: Command) -> None: raise NotImplementedError(msg) if self.should_op_create_shard(command.op): - print(f"Building shard for command: {command}") + print( + f"Building shard for command: {command}", + ) self._build_shard(command) else: - self._add_pending_command(command) + self._add_pending_sub_command(command) def _build_shard(self, command: Command) -> None: """Builds a shard. @@ -65,13 +83,77 @@ def _build_shard(self, command: Command) -> None: Args: command: tket command (operation, bits, etc) """ - shard = Shard(command, self._pending_commands, set()) - # TODO: Dependencies! - self._pending_commands = {} + # Rollup any sub commands (SQ gates) that interact with the same qubits + sub_commands: dict[UnitID, list[Command]] = {} + for key in ( + key for key in list(self._pending_commands) if key in command.qubits + ): + sub_commands[key] = self._pending_commands.pop(key) + + all_commands = [command] + for sub_command_list in sub_commands.values(): + all_commands.extend(sub_command_list) + + qubits_used = set(command.qubits) + bits_written = set(command.bits) + bits_read: set[Bit] = set() + + for sub_command in all_commands: + bits_written.update(sub_command.bits) + bits_read.update( + set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc, arg-type] + ) + + # Handle dependency calculations + depends_upon: set[int] = set() + for shard in self._shards: + # Check qubit dependencies (R/W implicitly) since all commands + # on a given qubit need to be ordered as the circuit dictated + if not shard.qubits_used.isdisjoint(command.qubits): + print(f"...adding shard dep {shard.ID} -> qubit overlap") + depends_upon.add(shard.ID) + # Check classical dependencies, which depend on writing and reading + # hazards: RAW, WAW, WAR + # NOTE: bits_read will include bits_written in the current impl + + # Check for write-after-write (changing order would change final value) + # by looking at overlap of bits_written + elif not shard.bits_written.isdisjoint(bits_written): + print(f"...adding shard dep {shard.ID} -> WAW") + depends_upon.add(shard.ID) + + # Check for read-after-write (value seen would change if reordered) + elif not shard.bits_written.isdisjoint(bits_read): + print(f"...adding shard dep {shard.ID} -> RAW") + depends_upon.add(shard.ID) + + # Check for write-after-read (no reordering or read is changed) + elif not shard.bits_written.isdisjoint(bits_read): + print(f"...adding shard dep {shard.ID} -> WAR") + depends_upon.add(shard.ID) + + shard = Shard( + command, + sub_commands, + qubits_used, + bits_written, + bits_read, + depends_upon, + ) self._shards.append(shard) print("Appended shard:", shard) - def _add_pending_command(self, command: Command) -> None: + def _cleanup_remaining_commands(self) -> None: + remaining_qubits = [k for k, v in self._pending_commands.items() if v] + for qubit in remaining_qubits: + self._circuit.add_barrier([qubit]) + # Easiest way to get to a command, since there's no constructor. Could + # create an entire orphan circuit with the matching qubits and the barrier + # instead if this has unintended consequences + barrier_command = self._circuit.get_commands()[-1] + self._build_shard(barrier_command) + + def _add_pending_sub_command(self, command: Command) -> None: """Adds a pending command. Adds a pending sub command to the buffer to be flushed when a schedulable @@ -80,10 +162,13 @@ def _add_pending_command(self, command: Command) -> None: Args: command: tket command (operation, bits, etc) """ - # TODO: Need to make sure 'args[0]' is the right key to use. - if command.args[0] not in self._pending_commands: - self._pending_commands[command.args[0]] = [] - self._pending_commands[command.args[0]].append(command) + key = command.qubits[0] + if key not in self._pending_commands: + self._pending_commands[key] = [] + self._pending_commands[key].append(command) + print( + f"Adding pending command {command}", + ) @staticmethod def should_op_create_shard(op: Op) -> bool: @@ -97,9 +182,11 @@ def should_op_create_shard(op: Op) -> bool: Returns: `True` if the operation is one that should result in shard creation """ - # TODO: This is almost certainly inadequate right now return ( - op.type == OpType.Measure - or op.type == OpType.Reset + op.type in (SHARD_TRIGGER_OP_TYPES) + or ( + op.type == OpType.Conditional + and cast(Conditional, op).op.type in (SHARD_TRIGGER_OP_TYPES) + ) or (op.is_gate() and op.n_qubits > 1) ) diff --git a/ruff.toml b/ruff.toml index cf9ac21..2350fd0 100644 --- a/ruff.toml +++ b/ruff.toml @@ -61,8 +61,10 @@ ignore = [ "tests/*" = [ "INP001", "D101", # Missing docstring in public class + "D102", # Missing docstring in public method "S101", # Use of `assert` detected - "PLR2004" # Magic constants + "PLR2004", # Magic constants + "PLR0915", # Too many statements in function ] [pydocstyle] diff --git a/tests/data/qasm/baby_with_rollup.qasm b/tests/data/qasm/baby_with_rollup.qasm new file mode 100644 index 0000000..b2a81fc --- /dev/null +++ b/tests/data/qasm/baby_with_rollup.qasm @@ -0,0 +1,15 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[2]; +creg c[2]; + +h q[0]; +h q[1]; +CX q[0], q[1]; + +measure q->c; + +h q[0]; + +h q[1]; diff --git a/tests/data/qasm/barrier_complex.qasm b/tests/data/qasm/barrier_complex.qasm new file mode 100644 index 0000000..bee4d0b --- /dev/null +++ b/tests/data/qasm/barrier_complex.qasm @@ -0,0 +1,27 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[4]; +creg c[4]; + +h q[0]; +h q[1]; +h q[2]; +h q[3]; +c[3] = 1; + +barrier q[0], q[1], c[3]; + +CX q[0], q[1]; + +measure q[0]->c[0]; + +h q[2]; +h q[3]; +x q[3]; + +barrier q[2], q[3]; + +CX q[2], q[3]; + +measure q[2]->c[2]; diff --git a/tests/data/qasm/big_gate.qasm b/tests/data/qasm/big_gate.qasm new file mode 100644 index 0000000..745122c --- /dev/null +++ b/tests/data/qasm/big_gate.qasm @@ -0,0 +1,13 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +creg c[1]; +qreg q[4]; + +h q[0]; +h q[1]; +h q[2]; + +c4x q[0],q[1],q[2],q[3]; + +measure q[3] -> c[0]; diff --git a/tests/data/qasm/classical_hazards.qasm b/tests/data/qasm/classical_hazards.qasm new file mode 100644 index 0000000..b38fd29 --- /dev/null +++ b/tests/data/qasm/classical_hazards.qasm @@ -0,0 +1,19 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[2]; +creg c[3]; + +h q[0]; + +measure q[0]->c[0]; + +if(c[0]==1) c[1]=1; // RAW + +c[0]=0; // WAR + +h q[1]; + +measure q[1]->c[2]; + +if(c[2]==1) c[0]=1; // WAW diff --git a/tests/data/qasm/cond_classical.qasm b/tests/data/qasm/cond_classical.qasm new file mode 100644 index 0000000..6d7ca18 --- /dev/null +++ b/tests/data/qasm/cond_classical.qasm @@ -0,0 +1,25 @@ +OPENQASM 2.0; +include "hqslib1_dev.inc"; +qreg q[1]; +creg a[10]; +creg b[10]; +creg c[4]; + +// classical assignment of registers +a[0] = 1; +a = 3; +// classical bitwise functions +a = 1; +b = 3; +c = a ^ b; // XOR +// evaluating a beyond creg == int +a = 1; +b = 2; +if(a[0]==1) x q[0]; +if(a!=1) x q[0]; +if(a>1) x q[0]; +if(a<1) x q[0]; +if(a>=1) x q[0]; +if(a<=1) x q[0]; +if (a==10) b=1; +measure q[0] -> c[0]; diff --git a/tests/data/qasm/simple_cond.qasm b/tests/data/qasm/simple_cond.qasm new file mode 100644 index 0000000..5334a94 --- /dev/null +++ b/tests/data/qasm/simple_cond.qasm @@ -0,0 +1,13 @@ +OPENQASM 2.0; +include "hqslib1.inc"; + +qreg q[1]; +creg c[1]; +creg z[1]; + +h q; +measure q->c; +reset q; +if (c==1) h q; +if (c==1) z=1; +measure q->c; diff --git a/tests/sample_data.py b/tests/sample_data.py index 79c4490..a9f26b6 100644 --- a/tests/sample_data.py +++ b/tests/sample_data.py @@ -9,6 +9,12 @@ class QasmFiles(Enum): cond_1 = 2 bv_n10 = 3 baby = 4 + baby_with_rollup = 5 + simple_cond = 6 + cond_classical = 7 + barrier_complex = 8 + classical_hazards = 9 + big_gate = 10 def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit: diff --git a/tests/test_sharder.py b/tests/test_sharder.py index d72980a..d65b51e 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -1,14 +1,295 @@ +from typing import cast + +from pytket.circuit import Conditional, Op, OpType from pytket.phir.sharding.sharder import Sharder from .sample_data import QasmFiles, get_qasm_as_circuit class TestSharder: - def test_ctor(self) -> None: - """Simple test of Sharder construction.""" - sharder = Sharder(get_qasm_as_circuit(QasmFiles.baby)) - assert sharder is not None + def test_shard_hashing(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.baby) + sharder = Sharder(circuit) + shards = sharder.shard() + + shard_set = set(shards) + + assert len(shard_set) > 0 + + def test_should_op_create_shard(self) -> None: + expected_true: list[Op] = [ + Op.create(OpType.Measure), # type: ignore [misc] + Op.create(OpType.Reset), # type: ignore [misc] + Op.create(OpType.CX), # type: ignore [misc] + Op.create(OpType.Barrier), # type: ignore [misc] + ] + expected_false: list[Op] = [ + Op.create(OpType.U1, 0.32), # type: ignore [misc] + Op.create(OpType.H), # type: ignore [misc] + Op.create(OpType.Z), # type: ignore [misc] + ] + + for op in expected_true: + assert Sharder.should_op_create_shard(op) + for op in expected_false: + assert not Sharder.should_op_create_shard(op) + + def test_with_baby_circuit(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.baby) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 3 + + assert shards[0].primary_command.op.type == OpType.CX + assert len(shards[0].primary_command.qubits) == 2 + assert not shards[0].primary_command.bits + assert len(shards[0].sub_commands) == 2 + sub_commands = list(shards[0].sub_commands.items()) + print(sub_commands) + assert sub_commands[0][1][0].op.type == OpType.H + assert len(shards[0].depends_upon) == 0 + + assert shards[1].primary_command.op.type == OpType.Measure + assert len(shards[1].sub_commands) == 0 + assert shards[1].depends_upon == {shards[0].ID} + + assert shards[2].primary_command.op.type == OpType.Measure + assert len(shards[2].sub_commands) == 0 + assert shards[2].depends_upon == {shards[0].ID} + + def test_rollup_behavior(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.baby_with_rollup) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 5 + + assert shards[0].primary_command.op.type == OpType.CX + assert len(shards[0].primary_command.qubits) == 2 + assert not shards[0].primary_command.bits + assert len(shards[0].sub_commands) == 2 + sub_commands = list(shards[0].sub_commands.items()) + print(sub_commands) + assert sub_commands[0][1][0].op.type == OpType.H + assert len(shards[0].depends_upon) == 0 + + assert shards[1].primary_command.op.type == OpType.Measure + assert len(shards[1].sub_commands) == 0 + assert shards[1].depends_upon == {shards[0].ID} + + assert shards[2].primary_command.op.type == OpType.Measure + assert len(shards[2].sub_commands) == 0 + assert shards[2].depends_upon == {shards[0].ID} + + assert shards[3].primary_command.op.type == OpType.Barrier + assert len(shards[3].sub_commands) == 1 + assert shards[3].depends_upon == {shards[0].ID, shards[1].ID} + + assert shards[4].primary_command.op.type == OpType.Barrier + assert len(shards[4].sub_commands) == 1 + assert shards[4].depends_upon == {shards[0].ID, shards[2].ID} + + def test_simple_conditional(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.simple_cond) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 4 + + # shard 0: [h q;] measure q->c; + assert shards[0].primary_command.op.type == OpType.Measure + assert shards[0].qubits_used == {circuit.qubits[0]} + assert shards[0].bits_written == {circuit.bits[0]} + assert shards[0].depends_upon == set() + assert len(shards[0].sub_commands.items()) == 1 + s0_qubit, s0_sub_cmds = next(iter(shards[0].sub_commands.items())) + assert s0_qubit == circuit.qubits[0] + assert s0_sub_cmds[0].op.type == OpType.H + + # shard 1: reset q; + assert shards[1].primary_command.op.type == OpType.Reset + assert len(shards[1].sub_commands.items()) == 0 + assert shards[1].qubits_used == {circuit.qubits[0]} + assert shards[1].depends_upon == {shards[0].ID} + assert shards[1].bits_written == set() + assert shards[1].bits_read == set() + + # shard 2: if (c==1) z=1; + assert shards[2].primary_command.op.type == OpType.Conditional + assert cast(Conditional, shards[2].primary_command.op).op.type == OpType.SetBits + assert len(shards[2].sub_commands.keys()) == 0 + assert shards[2].qubits_used == set() + assert shards[2].bits_written == {circuit.bits[1]} + assert shards[2].bits_read == {circuit.bits[0], circuit.bits[1]} + assert shards[2].depends_upon == {shards[0].ID} + + # shard 3: [if (c==1) h q;] measure q->c; + assert shards[3].primary_command.op.type == OpType.Measure + assert shards[3].qubits_used == {circuit.qubits[0]} + assert shards[3].bits_written == {circuit.bits[0]} + assert shards[3].bits_read == {circuit.bits[0]} + assert shards[3].depends_upon == {shards[0].ID, shards[1].ID} + assert len(shards[3].sub_commands.items()) == 1 + s2_qubit, s2_sub_cmds = next(iter(shards[3].sub_commands.items())) + assert s2_qubit == circuit.qubits[0] + assert s2_sub_cmds[0].op.type == OpType.Conditional + assert cast(Conditional, s2_sub_cmds[0].op).op.type == OpType.H + assert s2_sub_cmds[0].qubits == [circuit.qubits[0]] + + def test_complex_barriers(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.barrier_complex) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 7 + + # shard 0: [], c[3] = 1 + assert shards[0].primary_command.op.type == OpType.SetBits + assert len(shards[0].sub_commands.items()) == 0 + assert shards[0].qubits_used == set() + assert shards[0].bits_written == {circuit.bits[3]} + assert shards[0].bits_read == {circuit.bits[3]} # bits written are always read + assert shards[0].depends_upon == set() + + # shard 1: [h q[0]; h q[1];] barrier q[0], q[1], c[3]; + assert shards[1].primary_command.op.type == OpType.Barrier + assert len(shards[1].sub_commands.items()) == 2 + shard_1_q0_cmds = shards[1].sub_commands[circuit.qubits[0]] + assert len(shard_1_q0_cmds) == 1 + assert shard_1_q0_cmds[0].op.type == OpType.H + assert shard_1_q0_cmds[0].qubits == [circuit.qubits[0]] + shard_1_q1_cmds = shards[1].sub_commands[circuit.qubits[1]] + assert len(shard_1_q1_cmds) == 1 + assert shard_1_q1_cmds[0].op.type == OpType.H + assert shard_1_q1_cmds[0].qubits == [circuit.qubits[1]] + assert shards[1].qubits_used == {circuit.qubits[0], circuit.qubits[1]} + assert shards[1].bits_written == {circuit.bits[3]} + assert shards[1].bits_read == {circuit.bits[3]} + assert shards[1].depends_upon == {shards[0].ID} + + # shard 2: [] CX q[0], q[1]; + assert shards[2].primary_command.op.type == OpType.CX + assert len(shards[2].sub_commands.items()) == 0 + assert shards[2].qubits_used == {circuit.qubits[0], circuit.qubits[1]} + assert shards[2].bits_written == set() + assert shards[2].bits_read == set() + assert shards[2].depends_upon == {shards[1].ID} + + # shard 3: measure q[0]->c[0]; + assert shards[3].primary_command.op.type == OpType.Measure + assert len(shards[3].sub_commands.items()) == 0 + assert shards[3].qubits_used == {circuit.qubits[0]} + assert shards[3].bits_written == {circuit.bits[0]} + assert shards[3].bits_read == {circuit.bits[0]} + assert shards[3].depends_upon == {shards[2].ID, shards[1].ID} + + # shard 4: [H q[3];, H q[3];, X q[3];]] barrier q[2], q[3]; + assert shards[4].primary_command.op.type == OpType.Barrier + assert len(shards[4].sub_commands.items()) == 2 + shard_4_q2_cmds = shards[4].sub_commands[circuit.qubits[2]] + assert len(shard_4_q2_cmds) == 2 + assert shard_4_q2_cmds[0].op.type == OpType.H + assert shard_4_q2_cmds[0].qubits == [circuit.qubits[2]] + assert shard_4_q2_cmds[1].op.type == OpType.H + assert shard_4_q2_cmds[1].qubits == [circuit.qubits[2]] + shard_4_q3_cmds = shards[4].sub_commands[circuit.qubits[3]] + assert len(shard_4_q3_cmds) == 3 + assert shard_4_q3_cmds[0].op.type == OpType.H + assert shard_4_q3_cmds[0].qubits == [circuit.qubits[3]] + assert shard_4_q3_cmds[1].op.type == OpType.H + assert shard_4_q3_cmds[1].qubits == [circuit.qubits[3]] + assert shard_4_q3_cmds[2].op.type == OpType.X + assert shard_4_q3_cmds[2].qubits == [circuit.qubits[3]] + assert shards[4].qubits_used == {circuit.qubits[2], circuit.qubits[3]} + assert shards[4].bits_written == set() + assert shards[4].bits_read == set() + assert shards[4].depends_upon == set() + + # shard 5: [] CX q[2], q[3]; + assert shards[5].primary_command.op.type == OpType.CX + assert len(shards[5].sub_commands.items()) == 0 + assert shards[5].qubits_used == {circuit.qubits[2], circuit.qubits[3]} + assert shards[5].bits_written == set() + assert shards[5].bits_read == set() + assert shards[5].depends_upon == {shards[4].ID} + + # shard 6: measure q[2]->c[2]; + assert shards[6].primary_command.op.type == OpType.Measure + assert len(shards[6].sub_commands.items()) == 0 + assert shards[6].qubits_used == {circuit.qubits[2]} + assert shards[6].bits_written == {circuit.bits[2]} + assert shards[6].bits_read == {circuit.bits[2]} + assert shards[6].depends_upon == {shards[5].ID, shards[4].ID} + + def test_classical_hazards(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.classical_hazards) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 5 + + # shard 0: [h q[0];] measure q[0]->c[0]; + assert shards[0].primary_command.op.type == OpType.Measure + assert len(shards[0].sub_commands.items()) == 1 + assert shards[0].qubits_used == {circuit.qubits[0]} + assert shards[0].bits_written == {circuit.bits[0]} + assert shards[0].bits_read == {circuit.bits[0]} + assert shards[0].depends_upon == set() + + # shard 1: [H q[1];] measure q[1]->c[2]; + # NOTE: pytket reorganizes circuits to be efficiently ordered + assert shards[1].primary_command.op.type == OpType.Measure + assert len(shards[1].sub_commands.items()) == 1 + assert shards[1].qubits_used == {circuit.qubits[1]} + assert shards[1].bits_written == {circuit.bits[2]} + assert shards[1].bits_read == {circuit.bits[2]} + assert shards[1].depends_upon == set() + + # shard 2: [] if(c[0]==1) c[1]=1; + assert shards[2].primary_command.op.type == OpType.Conditional + assert len(shards[2].sub_commands) == 0 + assert shards[2].qubits_used == set() + assert shards[2].bits_written == {circuit.bits[1]} + assert shards[2].bits_read == {circuit.bits[1], circuit.bits[0]} + assert shards[2].depends_upon == {shards[0].ID} + + # shard 3: [] c[0]=0; + assert shards[3].primary_command.op.type == OpType.SetBits + assert len(shards[2].sub_commands) == 0 + assert shards[3].qubits_used == set() + assert shards[3].bits_written == {circuit.bits[0]} + assert shards[3].bits_read == {circuit.bits[0]} + assert shards[3].depends_upon == {shards[0].ID} + + # shard 4: [] if(c[2]==1) c[0]=1; + assert shards[4].primary_command.op.type == OpType.Conditional + assert len(shards[4].sub_commands) == 0 + assert shards[4].qubits_used == set() + assert shards[4].bits_written == {circuit.bits[0]} + assert shards[4].bits_read == {circuit.bits[0], circuit.bits[2]} + assert shards[4].depends_upon == {shards[1].ID, shards[0].ID, shards[3].ID} + + def test_with_big_gate(self) -> None: + circuit = get_qasm_as_circuit(QasmFiles.big_gate) + sharder = Sharder(circuit) + shards = sharder.shard() + + assert len(shards) == 2 - output = sharder.shard() + # shard 0: [h q[0]; h q[1]; h q[2];] c4x q[0],q[1],q[2],q[3]; + assert shards[0].primary_command.op.type == OpType.CnX + assert len(shards[0].sub_commands) == 3 + assert shards[0].qubits_used == { + circuit.qubits[0], + circuit.qubits[1], + circuit.qubits[2], + circuit.qubits[3], + } + assert shards[0].bits_written == set() - assert len(output) == 3 + # shard 1: [] measure q[3]->[c0] + assert shards[1].primary_command.op.type == OpType.Measure + assert len(shards[1].sub_commands) == 0 + assert shards[1].qubits_used == {circuit.qubits[3]} + assert shards[1].bits_written == {circuit.bits[0]}