Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving shard handling and testing #3

Merged
merged 15 commits into from
Oct 11, 2023
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ Repository = "https://github.com/CQCL/pytket-phir.git"
where = ["."]

[tool.pytest.ini_options]
addopts = "-s -vv"
pythonpath = [
"."
]
log_cli = true
filterwarnings = ["ignore:::lark.s*"]
49 changes: 42 additions & 7 deletions pytket/phir/sharding/shard.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,59 @@
from dataclasses import dataclass
import io
from dataclasses import dataclass, field
from itertools import count

from pytket.circuit import Command
from pytket.unit_id import UnitID
from pytket.unit_id import Bit, Qubit, UnitID


@dataclass
@dataclass(frozen=True)
class Shard:
"""The Shard class.

A shard is a logical grouping of operations that represents the unit by which
we actually do placement of qubits.
"""

# The schedulable command of the shard
# The unique identifier of the shard
ID: int = field(default_factory=count().__next__, init=False)

# The "schedulable" command of the shard
primary_command: Command

# The other commands related to the primary schedulable command, stored
# as a map of bit-handle (unitID) -> list[Command]
sub_commands: dict[UnitID, list[Command]]

# A set of the other shards this particular shard depends upon, and thus
# must be scheduled after
depends_upon: set["Shard"]
# All qubits used by the primary and sub commands
qubits_used: set[Qubit]

# Set of all classical bits written to by the primary and sub commands
bits_written: set[Bit]

# Set of all classical bits read by the primary and sub commands
bits_read: set[Bit]

# A set of the identifiers of other shards this particular shard depends upon
depends_upon: set[int]

def __hash__(self) -> int:
"""Hashing for shards is done only by its autogen unique int ID."""
return self.ID

def pretty_print(self) -> str:
"""Returns the shard in a human-friendly format."""
output = io.StringIO()
output.write(f"Shard {self.ID}:")
output.write(f"\n Command: {self.primary_command}")
output.write("\n Sub commands: ")
if not self.sub_commands:
output.write("[]")
for sub in self.sub_commands:
output.write(f"\n {sub}: {self.sub_commands[sub]}")
output.write(f"\n Qubits used: {self.qubits_used}")
output.write(f"\n Bits written: {self.bits_written}")
output.write(f"\n Bits read: {self.bits_read}")
output.write(f"\n Depends upon: {self.depends_upon}")
content = output.getvalue()
output.close()
return content
121 changes: 104 additions & 17 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from pytket.circuit import Circuit, Command, Op, OpType
from pytket.unit_id import UnitID
from typing import cast

from pytket.circuit import Circuit, Command, Conditional, Op, OpType
from pytket.unit_id import Bit, UnitID

from .shard import Shard

NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox, OpType.WASM]

SHARD_TRIGGER_OP_TYPES = [
OpType.Measure,
OpType.Reset,
OpType.Barrier,
OpType.SetBits,
OpType.ClassicalExpBox, # some classical operations are rolled up into a box
OpType.RangePredicate,
]


class Sharder:
"""The Sharder class.
Expand All @@ -22,9 +33,9 @@ def __init__(self, circuit: Circuit) -> None:
circuit: tket Circuit
"""
self._circuit = circuit
print(f"Sharder created for circuit {self._circuit}")
self._pending_commands: dict[UnitID, list[Command]] = {}
self._shards: list[Shard] = []
print(f"Sharder created for circuit {self._circuit}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want prints vs. some other form of logging? I'm ok with it for first cut merging PR in, but I imagine we won't want to keep these prints as is.


def shard(self) -> list[Shard]:
"""Performs sharding algorithm on the circuit the Sharder was initialized with.
Expand All @@ -34,9 +45,14 @@ def shard(self) -> list[Shard]:
list of Shards needed to schedule
"""
print("Sharding begins....")
# https://cqcl.github.io/tket/pytket/api/circuit.html#pytket.circuit.Command
for command in self._circuit.get_commands():
self._process_command(command)
self._cleanup_remaining_commands()

print("--------------------------------------------")
print("Shard output:")
for shard in self._shards:
print(shard.pretty_print())
return self._shards

def _process_command(self, command: Command) -> None:
Expand All @@ -51,10 +67,12 @@ def _process_command(self, command: Command) -> None:
raise NotImplementedError(msg)

if self.should_op_create_shard(command.op):
print(f"Building shard for command: {command}")
print(
f"Building shard for command: {command}",
)
self._build_shard(command)
else:
self._add_pending_command(command)
self._add_pending_sub_command(command)

def _build_shard(self, command: Command) -> None:
"""Builds a shard.
Expand All @@ -65,13 +83,77 @@ def _build_shard(self, command: Command) -> None:
Args:
command: tket command (operation, bits, etc)
"""
shard = Shard(command, self._pending_commands, set())
# TODO: Dependencies!
self._pending_commands = {}
# Rollup any sub commands (SQ gates) that interact with the same qubits
sub_commands: dict[UnitID, list[Command]] = {}
for key in (
key for key in list(self._pending_commands) if key in command.qubits
):
sub_commands[key] = self._pending_commands.pop(key)

all_commands = [command]
for sub_command_list in sub_commands.values():
all_commands.extend(sub_command_list)

qubits_used = set(command.qubits)
bits_written = set(command.bits)
bits_read: set[Bit] = set()

for sub_command in all_commands:
bits_written.update(sub_command.bits)
bits_read.update(
set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc, arg-type]
)

# Handle dependency calculations
depends_upon: set[int] = set()
for shard in self._shards:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering if dependency creation should be a self contained method. The overall _build_shard method isn't too long but this could be a natural part to break down.

# Check qubit dependencies (R/W implicitly) since all commands
# on a given qubit need to be ordered as the circuit dictated
if not shard.qubits_used.isdisjoint(command.qubits):
print(f"...adding shard dep {shard.ID} -> qubit overlap")
depends_upon.add(shard.ID)
# Check classical dependencies, which depend on writing and reading
# hazards: RAW, WAW, WAR
# NOTE: bits_read will include bits_written in the current impl

# Check for write-after-write (changing order would change final value)
# by looking at overlap of bits_written
elif not shard.bits_written.isdisjoint(bits_written):
print(f"...adding shard dep {shard.ID} -> WAW")
depends_upon.add(shard.ID)

# Check for read-after-write (value seen would change if reordered)
elif not shard.bits_written.isdisjoint(bits_read):
print(f"...adding shard dep {shard.ID} -> RAW")
depends_upon.add(shard.ID)

# Check for write-after-read (no reordering or read is changed)
elif not shard.bits_written.isdisjoint(bits_read):
print(f"...adding shard dep {shard.ID} -> WAR")
depends_upon.add(shard.ID)

shard = Shard(
command,
sub_commands,
qubits_used,
bits_written,
bits_read,
depends_upon,
)
self._shards.append(shard)
print("Appended shard:", shard)

def _add_pending_command(self, command: Command) -> None:
def _cleanup_remaining_commands(self) -> None:
remaining_qubits = [k for k, v in self._pending_commands.items() if v]
for qubit in remaining_qubits:
self._circuit.add_barrier([qubit])
# Easiest way to get to a command, since there's no constructor. Could
# create an entire orphan circuit with the matching qubits and the barrier
# instead if this has unintended consequences
barrier_command = self._circuit.get_commands()[-1]
self._build_shard(barrier_command)

def _add_pending_sub_command(self, command: Command) -> None:
"""Adds a pending command.

Adds a pending sub command to the buffer to be flushed when a schedulable
Expand All @@ -80,10 +162,13 @@ def _add_pending_command(self, command: Command) -> None:
Args:
command: tket command (operation, bits, etc)
"""
# TODO: Need to make sure 'args[0]' is the right key to use.
if command.args[0] not in self._pending_commands:
self._pending_commands[command.args[0]] = []
self._pending_commands[command.args[0]].append(command)
key = command.qubits[0]
if key not in self._pending_commands:
self._pending_commands[key] = []
self._pending_commands[key].append(command)
print(
f"Adding pending command {command}",
)

@staticmethod
def should_op_create_shard(op: Op) -> bool:
Expand All @@ -97,9 +182,11 @@ def should_op_create_shard(op: Op) -> bool:
Returns:
`True` if the operation is one that should result in shard creation
"""
# TODO: This is almost certainly inadequate right now
return (
op.type == OpType.Measure
or op.type == OpType.Reset
op.type in (SHARD_TRIGGER_OP_TYPES)
or (
op.type == OpType.Conditional
and cast(Conditional, op).op.type in (SHARD_TRIGGER_OP_TYPES)
)
or (op.is_gate() and op.n_qubits > 1)
)
4 changes: 3 additions & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ ignore = [
"tests/*" = [
"INP001",
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"S101", # Use of `assert` detected
"PLR2004" # Magic constants
"PLR2004", # Magic constants
"PLR0915", # Too many statements in function
]

[pydocstyle]
Expand Down
15 changes: 15 additions & 0 deletions tests/data/qasm/baby_with_rollup.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[2];
creg c[2];

h q[0];
h q[1];
CX q[0], q[1];

measure q->c;

h q[0];

h q[1];
27 changes: 27 additions & 0 deletions tests/data/qasm/barrier_complex.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[4];
creg c[4];

h q[0];
h q[1];
h q[2];
h q[3];
c[3] = 1;

barrier q[0], q[1], c[3];

CX q[0], q[1];

measure q[0]->c[0];

h q[2];
h q[3];
x q[3];

barrier q[2], q[3];

CX q[2], q[3];

measure q[2]->c[2];
13 changes: 13 additions & 0 deletions tests/data/qasm/big_gate.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
OPENQASM 2.0;
include "hqslib1.inc";

creg c[1];
qreg q[4];

h q[0];
h q[1];
h q[2];

c4x q[0],q[1],q[2],q[3];

measure q[3] -> c[0];
19 changes: 19 additions & 0 deletions tests/data/qasm/classical_hazards.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[2];
creg c[3];

h q[0];

measure q[0]->c[0];

if(c[0]==1) c[1]=1; // RAW

c[0]=0; // WAR

h q[1];

measure q[1]->c[2];

if(c[2]==1) c[0]=1; // WAW
25 changes: 25 additions & 0 deletions tests/data/qasm/cond_classical.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
OPENQASM 2.0;
include "hqslib1_dev.inc";
qreg q[1];
creg a[10];
creg b[10];
creg c[4];

// classical assignment of registers
a[0] = 1;
a = 3;
// classical bitwise functions
a = 1;
b = 3;
c = a ^ b; // XOR
// evaluating a beyond creg == int
a = 1;
b = 2;
if(a[0]==1) x q[0];
if(a!=1) x q[0];
if(a>1) x q[0];
if(a<1) x q[0];
if(a>=1) x q[0];
if(a<=1) x q[0];
if (a==10) b=1;
measure q[0] -> c[0];
13 changes: 13 additions & 0 deletions tests/data/qasm/simple_cond.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[1];
creg c[1];
creg z[1];

h q;
measure q->c;
reset q;
if (c==1) h q;
if (c==1) z=1;
measure q->c;
Loading