Skip to content

Commit

Permalink
Skeleton of Sharding Approach (#1)
Browse files Browse the repository at this point in the history
First pass at structure.
---------

Co-authored-by: Neal Erickson <[email protected]>
Co-authored-by: Kartik Singhal <[email protected]>
  • Loading branch information
3 people authored Sep 28, 2023
1 parent 0637f35 commit 72363eb
Show file tree
Hide file tree
Showing 18 changed files with 263 additions and 34 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ repos:
hooks:
- id: mypy
pass_filenames: false
args: [.]
args: [--package=pytket.phir, --package=tests]
additional_dependencies: [
pytest,
pytket==1.20.1,
types-setuptools,
pytket,
]
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.PHONY: tests, lint

tests:
pytest -s -x -vv tests/test*.py

lint:
pre-commit run --all-files
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ where = ["."]
pythonpath = [
"."
]
filterwarnings = ["ignore:::lark.s*"]
31 changes: 25 additions & 6 deletions pytket/phir/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,29 @@
"""This is the main module of the pytket-phir package."""
"""
NOTE: Just a placeholder to allow convenient testing of the flows
"""

from pytket.circuit import Circuit
from pytket.qasm.qasm import circuit_from_qasm

def hello_world() -> str:
"""Print 'Hello, world!' to the console."""
hw = "Hello, World!"
return hw
from .sharding.sharder import Sharder

# Load a qasm circuit and parse
# ,=""=,
# c , _,{
# /\ @ ) __
# / ^~~^\ <=.,__/ '}=
# (_/ ,, ,,) \_ _>_/~
# ~\_(/-\)'-,_,_,_,-'(_)-(_)
circuit: Circuit = circuit_from_qasm("tests/data/qasm/baby.qasm")

print(hello_world()) # noqa: T201
# https://cqcl.github.io/tket/pytket/api/circuit_class.html

# Just a little debugging fun
print("Input circuit:")
print(circuit)
print()

sharding_output = Sharder(circuit).shard()

print("Sharding output:")
print(sharding_output)
File renamed without changes.
23 changes: 23 additions & 0 deletions pytket/phir/sharding/shard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from dataclasses import dataclass

from pytket.circuit import Command
from pytket.unit_id import UnitID


@dataclass
class Shard:
"""
A shard is a logical grouping of operations that represents the unit by which
we actually do placement of qubits
"""

# The schedulable command of the shard
primary_command: Command

# The other commands related to the primary schedulable command, stored
# as a map of bit-handle (unitID) -> list[Command]
sub_commands: dict[UnitID, list[Command]]

# A set of the other shards this particular shard depends upon, and thus
# must be scheduled after
depends_upon: set["Shard"]
81 changes: 81 additions & 0 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from pytket.circuit import Circuit, Command, Op, OpType
from pytket.unit_id import UnitID

from .shard import Shard

NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox, OpType.WASM]


class Sharder:
"""
The sharder class is responsible for taking in a circuit in TKET representation
and converting it into shards that can be subsequently handled in the
compilation pipeline.
"""

def __init__(self, circuit: Circuit) -> None:
self._circuit = circuit
print(f"Sharder created for circuit {self._circuit}")
self._pending_commands: dict[UnitID, list[Command]] = {}
self._shards: list[Shard] = []

def shard(self) -> list[Shard]:
"""
Performs the sharding algorithm on the circuit the Sharder was initialized
with, returning the list of Shards needed to schedule
"""
print("Sharding begins....")
# https://cqcl.github.io/tket/pytket/api/circuit.html#pytket.circuit.Command
for command in self._circuit.get_commands():
self._process_command(command)
return self._shards

def _process_command(self, command: Command) -> None:
"""
Handles a given TKET command (operation, bits, etc) according to the type
and the extant context within the Sharder
"""
print("Processing command: ", command.op, command.op.type, command.args)
if command.op.type in NOT_IMPLEMENTED_OP_TYPES:
msg = f"OpType {command.op.type} not supported!"
raise NotImplementedError(msg)

if self.should_op_create_shard(command.op):
print(f"Building shard for command: {command}")
self._build_shard(command)
else:
self._add_pending_command(command)

def _build_shard(self, command: Command) -> None:
"""
Creates a Shard object given the extant sharding context and the schedulable
Command object passed in, and appends it to the Shard list
"""
shard = Shard(command, self._pending_commands, set())
# TODO: Dependencies!
self._pending_commands = {}
self._shards.append(shard)
print("Appended shard:", shard)

def _add_pending_command(self, command: Command) -> None:
"""
Adds a pending sub command to the buffer to be flushed when a schedulable
operation creates a Shard.
"""
# TODO: Need to make sure 'args[0]' is the right key to use.
if command.args[0] not in self._pending_commands:
self._pending_commands[command.args[0]] = []
self._pending_commands[command.args[0]].append(command)

@staticmethod
def should_op_create_shard(op: Op) -> bool:
"""
Returns `True` if the operation is one that should result in shard creation.
This includes non-gate operations like measure/reset as well as 2-qubit gates.
"""
# TODO: This is almost certainly inadequate right now
return (
op.type == OpType.Measure
or op.type == OpType.Reset
or (op.is_gate() and op.n_qubits > 1)
)
6 changes: 0 additions & 6 deletions pytket/phir/utils.py

This file was deleted.

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ pytest==7.4.2
ruff==0.0.291
sphinx==7.2.6
wheel==0.41.2
pytket==1.20.1
12 changes: 10 additions & 2 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,19 @@ select = [
"YTT", # flake8-2020
]

ignore = []
ignore = [
"T201", # no prints flake-8
"FIX002", # Allow todos
"TD002", # Allow no author todos
"TD003", # allow todos with no issues
]

[per-file-ignores]
"__init__.py" = ["F401"] # module imported but unused
"tests/*" = ["S101"] # Use of `assert` detected
"tests/*" = [
"S101", # Use of `assert` detected
"PLR2004" # Magic constants
]

[pydocstyle]
convention = "google"
11 changes: 11 additions & 0 deletions tests/data/qasm/baby.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[2];
creg c[2];

h q[0];
h q[1];
CX q[0], q[1];

measure q->c;
47 changes: 47 additions & 0 deletions tests/data/qasm/bv_n10.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//@author Raymond Harry Rudy [email protected]
//Bernstein-Vazirani with 10 qubits.
//Hidden string is 111111111
OPENQASM 2.0;
include "qelib1.inc";
qreg qr[10];
creg cr[9];
h qr[0];
h qr[1];
h qr[2];
h qr[3];
h qr[4];
h qr[5];
h qr[6];
h qr[7];
h qr[8];
x qr[9];
h qr[9];
barrier qr[0],qr[1],qr[2],qr[3],qr[4],qr[5],qr[6],qr[7],qr[8],qr[9];
cx qr[0],qr[9];
cx qr[1],qr[9];
cx qr[2],qr[9];
cx qr[3],qr[9];
cx qr[4],qr[9];
cx qr[5],qr[9];
cx qr[6],qr[9];
cx qr[7],qr[9];
cx qr[8],qr[9];
barrier qr[0],qr[1],qr[2],qr[3],qr[4],qr[5],qr[6],qr[7],qr[8],qr[9];
h qr[0];
h qr[1];
h qr[2];
h qr[3];
h qr[4];
h qr[5];
h qr[6];
h qr[7];
h qr[8];
measure qr[0] -> cr[0];
measure qr[1] -> cr[1];
measure qr[2] -> cr[2];
measure qr[3] -> cr[3];
measure qr[4] -> cr[4];
measure qr[5] -> cr[5];
measure qr[6] -> cr[6];
measure qr[7] -> cr[7];
measure qr[8] -> cr[8];
16 changes: 16 additions & 0 deletions tests/data/qasm/cond_1.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[1];
creg c[2];

h q;
measure q[0]->c[0];
reset q;
if (c==1) h q;
if (c<1) h q;
if (c>1) h q;
if (c<=1) h q;
if (c>=1) h q;
if (c!=1) h q;
measure q[0]->c[0];
11 changes: 11 additions & 0 deletions tests/data/qasm/simple.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
OPENQASM 2.0;
include "qelib1.inc";
qreg q[1];
creg c[1];

h q[0];
h q[0];
h q[0];
h q[0];
h q[0];
measure q->c;
15 changes: 15 additions & 0 deletions tests/sample_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from enum import Enum

from pytket.circuit import Circuit
from pytket.qasm.qasm import circuit_from_qasm


class QasmFiles(Enum):
simple = 1
cond_1 = 2
bv_n10 = 3
baby = 4


def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit:
return circuit_from_qasm(f"tests/data/qasm/{qasm_file.name}.qasm")
8 changes: 0 additions & 8 deletions tests/test_main.py

This file was deleted.

13 changes: 13 additions & 0 deletions tests/test_sharder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pytket.phir.sharding.sharder import Sharder

from .sample_data import QasmFiles, get_qasm_as_circuit


class TestSharder:
def test_ctor(self) -> None:
sharder = Sharder(get_qasm_as_circuit(QasmFiles.baby))
assert sharder is not None

output = sharder.shard()

assert len(output) == 3
10 changes: 0 additions & 10 deletions tests/test_utils.py

This file was deleted.

0 comments on commit 72363eb

Please sign in to comment.