Skip to content

Commit

Permalink
Improving shard handling and testing (#3)
Browse files Browse the repository at this point in the history
* Improving shard handling and testing

* Apply suggestions from code review

Co-authored-by: Kartik Singhal <[email protected]>

* in progress

* test progress

* fixing ruff

* removing commented

* linting feedback

* mypy

* Adding test for partial barriers

* updating test

* cleanup comments

* another classical test

* big gate test

* hash method for shards to allow sets

---------

Co-authored-by: Neal Erickson <[email protected]>
Co-authored-by: Kartik Singhal <[email protected]>
  • Loading branch information
3 people authored Oct 11, 2023
1 parent a77f979 commit 90565f1
Show file tree
Hide file tree
Showing 12 changed files with 557 additions and 31 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ Repository = "https://github.com/CQCL/pytket-phir.git"
where = ["."]

[tool.pytest.ini_options]
addopts = "-s -vv"
pythonpath = [
"."
]
log_cli = true
filterwarnings = ["ignore:::lark.s*"]
49 changes: 42 additions & 7 deletions pytket/phir/sharding/shard.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,59 @@
from dataclasses import dataclass
import io
from dataclasses import dataclass, field
from itertools import count

from pytket.circuit import Command
from pytket.unit_id import UnitID
from pytket.unit_id import Bit, Qubit, UnitID


@dataclass
@dataclass(frozen=True)
class Shard:
"""The Shard class.
A shard is a logical grouping of operations that represents the unit by which
we actually do placement of qubits.
"""

# The schedulable command of the shard
# The unique identifier of the shard
ID: int = field(default_factory=count().__next__, init=False)

# The "schedulable" command of the shard
primary_command: Command

# The other commands related to the primary schedulable command, stored
# as a map of bit-handle (unitID) -> list[Command]
sub_commands: dict[UnitID, list[Command]]

# A set of the other shards this particular shard depends upon, and thus
# must be scheduled after
depends_upon: set["Shard"]
# All qubits used by the primary and sub commands
qubits_used: set[Qubit]

# Set of all classical bits written to by the primary and sub commands
bits_written: set[Bit]

# Set of all classical bits read by the primary and sub commands
bits_read: set[Bit]

# A set of the identifiers of other shards this particular shard depends upon
depends_upon: set[int]

def __hash__(self) -> int:
"""Hashing for shards is done only by its autogen unique int ID."""
return self.ID

def pretty_print(self) -> str:
"""Returns the shard in a human-friendly format."""
output = io.StringIO()
output.write(f"Shard {self.ID}:")
output.write(f"\n Command: {self.primary_command}")
output.write("\n Sub commands: ")
if not self.sub_commands:
output.write("[]")
for sub in self.sub_commands:
output.write(f"\n {sub}: {self.sub_commands[sub]}")
output.write(f"\n Qubits used: {self.qubits_used}")
output.write(f"\n Bits written: {self.bits_written}")
output.write(f"\n Bits read: {self.bits_read}")
output.write(f"\n Depends upon: {self.depends_upon}")
content = output.getvalue()
output.close()
return content
121 changes: 104 additions & 17 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from pytket.circuit import Circuit, Command, Op, OpType
from pytket.unit_id import UnitID
from typing import cast

from pytket.circuit import Circuit, Command, Conditional, Op, OpType
from pytket.unit_id import Bit, UnitID

from .shard import Shard

NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox, OpType.WASM]

SHARD_TRIGGER_OP_TYPES = [
OpType.Measure,
OpType.Reset,
OpType.Barrier,
OpType.SetBits,
OpType.ClassicalExpBox, # some classical operations are rolled up into a box
OpType.RangePredicate,
]


class Sharder:
"""The Sharder class.
Expand All @@ -22,9 +33,9 @@ def __init__(self, circuit: Circuit) -> None:
circuit: tket Circuit
"""
self._circuit = circuit
print(f"Sharder created for circuit {self._circuit}")
self._pending_commands: dict[UnitID, list[Command]] = {}
self._shards: list[Shard] = []
print(f"Sharder created for circuit {self._circuit}")

def shard(self) -> list[Shard]:
"""Performs sharding algorithm on the circuit the Sharder was initialized with.
Expand All @@ -34,9 +45,14 @@ def shard(self) -> list[Shard]:
list of Shards needed to schedule
"""
print("Sharding begins....")
# https://cqcl.github.io/tket/pytket/api/circuit.html#pytket.circuit.Command
for command in self._circuit.get_commands():
self._process_command(command)
self._cleanup_remaining_commands()

print("--------------------------------------------")
print("Shard output:")
for shard in self._shards:
print(shard.pretty_print())
return self._shards

def _process_command(self, command: Command) -> None:
Expand All @@ -51,10 +67,12 @@ def _process_command(self, command: Command) -> None:
raise NotImplementedError(msg)

if self.should_op_create_shard(command.op):
print(f"Building shard for command: {command}")
print(
f"Building shard for command: {command}",
)
self._build_shard(command)
else:
self._add_pending_command(command)
self._add_pending_sub_command(command)

def _build_shard(self, command: Command) -> None:
"""Builds a shard.
Expand All @@ -65,13 +83,77 @@ def _build_shard(self, command: Command) -> None:
Args:
command: tket command (operation, bits, etc)
"""
shard = Shard(command, self._pending_commands, set())
# TODO: Dependencies!
self._pending_commands = {}
# Rollup any sub commands (SQ gates) that interact with the same qubits
sub_commands: dict[UnitID, list[Command]] = {}
for key in (
key for key in list(self._pending_commands) if key in command.qubits
):
sub_commands[key] = self._pending_commands.pop(key)

all_commands = [command]
for sub_command_list in sub_commands.values():
all_commands.extend(sub_command_list)

qubits_used = set(command.qubits)
bits_written = set(command.bits)
bits_read: set[Bit] = set()

for sub_command in all_commands:
bits_written.update(sub_command.bits)
bits_read.update(
set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc, arg-type]
)

# Handle dependency calculations
depends_upon: set[int] = set()
for shard in self._shards:
# Check qubit dependencies (R/W implicitly) since all commands
# on a given qubit need to be ordered as the circuit dictated
if not shard.qubits_used.isdisjoint(command.qubits):
print(f"...adding shard dep {shard.ID} -> qubit overlap")
depends_upon.add(shard.ID)
# Check classical dependencies, which depend on writing and reading
# hazards: RAW, WAW, WAR
# NOTE: bits_read will include bits_written in the current impl

# Check for write-after-write (changing order would change final value)
# by looking at overlap of bits_written
elif not shard.bits_written.isdisjoint(bits_written):
print(f"...adding shard dep {shard.ID} -> WAW")
depends_upon.add(shard.ID)

# Check for read-after-write (value seen would change if reordered)
elif not shard.bits_written.isdisjoint(bits_read):
print(f"...adding shard dep {shard.ID} -> RAW")
depends_upon.add(shard.ID)

# Check for write-after-read (no reordering or read is changed)
elif not shard.bits_written.isdisjoint(bits_read):
print(f"...adding shard dep {shard.ID} -> WAR")
depends_upon.add(shard.ID)

shard = Shard(
command,
sub_commands,
qubits_used,
bits_written,
bits_read,
depends_upon,
)
self._shards.append(shard)
print("Appended shard:", shard)

def _add_pending_command(self, command: Command) -> None:
def _cleanup_remaining_commands(self) -> None:
remaining_qubits = [k for k, v in self._pending_commands.items() if v]
for qubit in remaining_qubits:
self._circuit.add_barrier([qubit])
# Easiest way to get to a command, since there's no constructor. Could
# create an entire orphan circuit with the matching qubits and the barrier
# instead if this has unintended consequences
barrier_command = self._circuit.get_commands()[-1]
self._build_shard(barrier_command)

def _add_pending_sub_command(self, command: Command) -> None:
"""Adds a pending command.
Adds a pending sub command to the buffer to be flushed when a schedulable
Expand All @@ -80,10 +162,13 @@ def _add_pending_command(self, command: Command) -> None:
Args:
command: tket command (operation, bits, etc)
"""
# TODO: Need to make sure 'args[0]' is the right key to use.
if command.args[0] not in self._pending_commands:
self._pending_commands[command.args[0]] = []
self._pending_commands[command.args[0]].append(command)
key = command.qubits[0]
if key not in self._pending_commands:
self._pending_commands[key] = []
self._pending_commands[key].append(command)
print(
f"Adding pending command {command}",
)

@staticmethod
def should_op_create_shard(op: Op) -> bool:
Expand All @@ -97,9 +182,11 @@ def should_op_create_shard(op: Op) -> bool:
Returns:
`True` if the operation is one that should result in shard creation
"""
# TODO: This is almost certainly inadequate right now
return (
op.type == OpType.Measure
or op.type == OpType.Reset
op.type in (SHARD_TRIGGER_OP_TYPES)
or (
op.type == OpType.Conditional
and cast(Conditional, op).op.type in (SHARD_TRIGGER_OP_TYPES)
)
or (op.is_gate() and op.n_qubits > 1)
)
4 changes: 3 additions & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ ignore = [
"tests/*" = [
"INP001",
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"S101", # Use of `assert` detected
"PLR2004" # Magic constants
"PLR2004", # Magic constants
"PLR0915", # Too many statements in function
]

[pydocstyle]
Expand Down
15 changes: 15 additions & 0 deletions tests/data/qasm/baby_with_rollup.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[2];
creg c[2];

h q[0];
h q[1];
CX q[0], q[1];

measure q->c;

h q[0];

h q[1];
27 changes: 27 additions & 0 deletions tests/data/qasm/barrier_complex.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[4];
creg c[4];

h q[0];
h q[1];
h q[2];
h q[3];
c[3] = 1;

barrier q[0], q[1], c[3];

CX q[0], q[1];

measure q[0]->c[0];

h q[2];
h q[3];
x q[3];

barrier q[2], q[3];

CX q[2], q[3];

measure q[2]->c[2];
13 changes: 13 additions & 0 deletions tests/data/qasm/big_gate.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
OPENQASM 2.0;
include "hqslib1.inc";

creg c[1];
qreg q[4];

h q[0];
h q[1];
h q[2];

c4x q[0],q[1],q[2],q[3];

measure q[3] -> c[0];
19 changes: 19 additions & 0 deletions tests/data/qasm/classical_hazards.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[2];
creg c[3];

h q[0];

measure q[0]->c[0];

if(c[0]==1) c[1]=1; // RAW

c[0]=0; // WAR

h q[1];

measure q[1]->c[2];

if(c[2]==1) c[0]=1; // WAW
25 changes: 25 additions & 0 deletions tests/data/qasm/cond_classical.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
OPENQASM 2.0;
include "hqslib1_dev.inc";
qreg q[1];
creg a[10];
creg b[10];
creg c[4];

// classical assignment of registers
a[0] = 1;
a = 3;
// classical bitwise functions
a = 1;
b = 3;
c = a ^ b; // XOR
// evaluating a beyond creg == int
a = 1;
b = 2;
if(a[0]==1) x q[0];
if(a!=1) x q[0];
if(a>1) x q[0];
if(a<1) x q[0];
if(a>=1) x q[0];
if(a<=1) x q[0];
if (a==10) b=1;
measure q[0] -> c[0];
13 changes: 13 additions & 0 deletions tests/data/qasm/simple_cond.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[1];
creg c[1];
creg z[1];

h q;
measure q->c;
reset q;
if (c==1) h q;
if (c==1) z=1;
measure q->c;
Loading

0 comments on commit 90565f1

Please sign in to comment.