Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving shard handling and testing #3

Merged
merged 15 commits into from
Oct 11, 2023
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ Repository = "https://github.com/CQCL/pytket-phir.git"
where = ["."]

[tool.pytest.ini_options]
addopts = "-s -vv"
pythonpath = [
"."
]
log_cli = true
filterwarnings = ["ignore:::lark.s*"]
42 changes: 36 additions & 6 deletions pytket/phir/sharding/shard.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from dataclasses import dataclass
import io
from dataclasses import dataclass, field
from itertools import count

from pytket.circuit import Command
from pytket.unit_id import UnitID
from pytket.unit_id import Bit, Qubit, UnitID


@dataclass
Expand All @@ -11,13 +13,41 @@ class Shard:
we actually do placement of qubits
"""

# The schedulable command of the shard
# The unique identifier of the shard
ID: int = field(default_factory=count().__next__, init=False)

# The "schedulable" command of the shard
primary_command: Command

# The other commands related to the primary schedulable command, stored
# as a map of bit-handle (unitID) -> list[Command]
sub_commands: dict[UnitID, list[Command]]

# A set of the other shards this particular shard depends upon, and thus
# must be scheduled after
depends_upon: set["Shard"]
# All qubits used by the primary and sub commands
qubits_used: set[Qubit] # = field(init=False)

# Set of all classical bits written to by the primary and sub commands
bits_written: set[Bit] # = field(init=False)

# Set of all classical bits read by the primary and sub commands
bits_read: set[Bit] # = field(init=False)

# A set of the identifiers of other shards this particular shard depends upon
depends_upon: set[int]

def pretty_print(self) -> str:
output = io.StringIO()
output.write(f"Shard {self.ID}:")
output.write(f"\n Command: {self.primary_command}")
output.write("\n Sub commands: ")
if not self.sub_commands:
output.write("[]")
for sub in self.sub_commands:
output.write(f"\n {sub}: {self.sub_commands[sub]}")
output.write(f"\n Qubits used: {self.qubits_used}")
output.write(f"\n Bits written: {self.bits_written}")
output.write(f"\n Bits read: {self.bits_read}")
output.write(f"\n Depends upon: {self.depends_upon}")
content = output.getvalue()
output.close()
return content
130 changes: 113 additions & 17 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from pytket.circuit import Circuit, Command, Op, OpType
from pytket.unit_id import UnitID
from typing import cast

from pytket.circuit import Circuit, Command, Conditional, Op, OpType
from pytket.unit_id import Bit, UnitID

from .shard import Shard

NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox, OpType.WASM]

SHARD_TRIGGER_OP_TYPES = [
OpType.Measure,
OpType.Reset,
OpType.Barrier,
OpType.SetBits,
OpType.ClassicalExpBox, # some classical operations are rolled up into a box
OpType.RangePredicate,
]


class Sharder:
"""
Expand All @@ -15,9 +26,9 @@ class Sharder:

def __init__(self, circuit: Circuit) -> None:
self._circuit = circuit
print(f"Sharder created for circuit {self._circuit}")
self._pending_commands: dict[UnitID, list[Command]] = {}
self._shards: list[Shard] = []
print(f"Sharder created for circuit {self._circuit}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want prints vs. some other form of logging? I'm ok with it for first cut merging PR in, but I imagine we won't want to keep these prints as is.


def shard(self) -> list[Shard]:
"""
Expand All @@ -28,6 +39,12 @@ def shard(self) -> list[Shard]:
# https://cqcl.github.io/tket/pytket/api/circuit.html#pytket.circuit.Command
for command in self._circuit.get_commands():
self._process_command(command)
self._cleanup_remaining_commands()

print("--------------------------------------------")
print("Shard output:")
for shard in self._shards:
print(shard.pretty_print())
return self._shards

def _process_command(self, command: Command) -> None:
Expand All @@ -41,41 +58,120 @@ def _process_command(self, command: Command) -> None:
raise NotImplementedError(msg)

if self.should_op_create_shard(command.op):
print(f"Building shard for command: {command}")
print(
f"Building shard for command: {command}",
)
self._build_shard(command)
else:
self._add_pending_command(command)
self._add_pending_sub_command(command)

def _build_shard(self, command: Command) -> None:
"""
Creates a Shard object given the extant sharding context and the schedulable
Creates a Shard object given the extant sharding context and the primary
Command object passed in, and appends it to the Shard list
"""
shard = Shard(command, self._pending_commands, set())
# TODO: Dependencies!
self._pending_commands = {}
# Rollup any sub commands (SQ gates) that interact with the same qubits
sub_commands: dict[UnitID, list[Command]] = {}
for key in (
key for key in list(self._pending_commands) if key in command.qubits
):
sub_commands[key] = self._pending_commands.pop(key)

all_commands = [command]
for sub_command_list in sub_commands.values():
all_commands.extend(sub_command_list)

qubits_used = set(command.qubits)
bits_written = set(command.bits)
bits_read: set[Bit] = set()

# def filter_to_bits: Callable[[UnitId], bool] = lambda x: isinstance(x, Bit)
for sub_command in all_commands:
bits_written.update(sub_command.bits)
bits_read.update(
set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc, arg-type] # noqa: E501
)

# Handle dependency calculations
depends_upon: set[int] = set()
for shard in self._shards:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering if dependency creation should be a self contained method. The overall _build_shard method isn't too long but this could be a natural part to break down.

# Check qubit dependencies (R/W implicitly) since all commands
# on a given qubit need to be ordered as the circuit dictated
if not shard.qubits_used.isdisjoint(command.qubits):
print(f"...adding shard dep {shard.ID} -> qubit overlap")
depends_upon.add(shard.ID)
# Check classical dependencies, which depend on writing and reading
# hazards: RAW, WAW, WAR
# NOTE: bits_read will include bits_written in the current impl

# Check for write-after-write (changing order would change final value)
# by looking at overlap of bits_written
elif not shard.bits_written.isdisjoint(bits_written):
print(f"...adding shard dep {shard.ID} -> WAW")
depends_upon.add(shard.ID)

# Check for read-after-write (value seen would change if reordered)
# elif not shard.bits_read.isdisjoint(bits_written):
# print(f'...adding shard dep {shard.ID} -> ')
# depends_upon.add(shard.ID)
elif not shard.bits_written.isdisjoint(bits_read):
print(f"...adding shard dep {shard.ID} -> RAW")
depends_upon.add(shard.ID)

# Check for write-after-read (no reordering or read is changed)
elif not shard.bits_written.isdisjoint(bits_read):
print(f"...adding shard dep {shard.ID} -> WAR")
depends_upon.add(shard.ID)

shard = Shard(
command,
sub_commands,
qubits_used,
bits_written,
bits_read,
depends_upon,
)
self._shards.append(shard)
print("Appended shard:", shard)

def _add_pending_command(self, command: Command) -> None:
def _cleanup_remaining_commands(self) -> None:
"""
Checks for any remaining "unsharded" commands, and if found, adds them
to Barrier op shards for each qubit
"""
remaining_qubits = [k for k, v in self._pending_commands.items() if v]
for qubit in remaining_qubits:
self._circuit.add_barrier([qubit])
# Easiest way to get to a command, since there's no constructor. Could
# create an entire orphan circuit with the matching qubits and the barrier
# instead if this has unintended consequences
barrier_command = self._circuit.get_commands()[-1]
self._build_shard(barrier_command)

def _add_pending_sub_command(self, command: Command) -> None:
"""
Adds a pending sub command to the buffer to be flushed when a schedulable
operation creates a Shard.
"""
# TODO: Need to make sure 'args[0]' is the right key to use.
if command.args[0] not in self._pending_commands:
self._pending_commands[command.args[0]] = []
self._pending_commands[command.args[0]].append(command)
key = command.qubits[0]
if key not in self._pending_commands:
self._pending_commands[key] = []
self._pending_commands[key].append(command)
print(
f"Adding pending command {command}",
)

@staticmethod
def should_op_create_shard(op: Op) -> bool:
"""
Returns `True` if the operation is one that should result in shard creation.
This includes non-gate operations like measure/reset as well as 2-qubit gates.
"""
# TODO: This is almost certainly inadequate right now
return (
op.type == OpType.Measure
or op.type == OpType.Reset
op.type in (SHARD_TRIGGER_OP_TYPES)
or (
op.type == OpType.Conditional
and cast(Conditional, op).op.type in (SHARD_TRIGGER_OP_TYPES)
)
or (op.is_gate() and op.n_qubits > 1)
)
15 changes: 15 additions & 0 deletions tests/data/qasm/baby_with_rollup.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[2];
creg c[2];

h q[0];
h q[1];
CX q[0], q[1];

measure q->c;

h q[0];

h q[1];
25 changes: 25 additions & 0 deletions tests/data/qasm/cond_classical.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
OPENQASM 2.0;
include "hqslib1_dev.inc";
qreg q[1];
creg a[10];
creg b[10];
creg c[4];

// classical assignment of registers
a[0] = 1;
a = 3;
// classical bitwise functions
a = 1;
b = 3;
c = a ^ b; // XOR
// evaluating a beyond creg == int
a = 1;
b = 2;
if(a[0]==1) x q[0];
if(a!=1) x q[0];
if(a>1) x q[0];
if(a<1) x q[0];
if(a>=1) x q[0];
if(a<=1) x q[0];
if (a==10) b=1;
measure q[0] -> c[0];
13 changes: 13 additions & 0 deletions tests/data/qasm/simple_cond.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[1];
creg c[1];
creg z[1];

h q;
measure q->c;
reset q;
if (c==1) h q;
if (c==1) z=1;
measure q->c;
3 changes: 3 additions & 0 deletions tests/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ class QasmFiles(Enum):
cond_1 = 2
bv_n10 = 3
baby = 4
baby_with_rollup = 5
simple_cond = 6
cond_classical = 7


def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit:
Expand Down
Loading