Skip to content

Commit

Permalink
in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
nealerickson-qtm committed Oct 3, 2023
1 parent 3faaa27 commit 3d85f40
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 32 deletions.
40 changes: 20 additions & 20 deletions pytket/phir/sharding/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,42 +13,42 @@ class Shard:
we actually do placement of qubits
"""

# The unique identifier of the shard
ID: int = field(default_factory=count().__next__, init=False)

# The "schedulable" command of the shard
primary_command: Command

# The other commands related to the primary schedulable command, stored
# as a map of bit-handle (unitID) -> list[Command]
sub_commands: dict[UnitID, list[Command]]

# A set of the identifiers of other shards this particular shard depends upon
depends_upon: set[int]

# All qubits used by the primary and sub commands
qubits_used: set[Qubit] = field(init=False)
qubits_used: set[Qubit] # = field(init=False)

# Set of all classical bits written to by the primary and sub commands
bits_written: set[Bit] = field(init=False)
bits_written: set[Bit] # = field(init=False)

# Set of all classical bits read by the primary and sub commands
bits_read: set[Bit] = field(init=False)
bits_read: set[Bit] # = field(init=False)

# The unique identifier of the shard
ID: int = field(default_factory=count().__next__, init=False)
# A set of the identifiers of other shards this particular shard depends upon
depends_upon: set[int]

def __post_init__(self) -> None:
self.qubits_used = set(self.primary_command.qubits)
self.bits_written = set(self.primary_command.bits)
self.bits_read = set()
# def __post_init__(self) -> None:
# self.qubits_used = set(self.primary_command.qubits)
# self.bits_written = set(self.primary_command.bits)
# self.bits_read = set()

all_sub_commands: list[Command] = []
for sub_commands in self.sub_commands.values():
all_sub_commands.extend(sub_commands)
# all_sub_commands: list[Command] = []
# for sub_commands in self.sub_commands.values():
# all_sub_commands.extend(sub_commands)

for sub_command in all_sub_commands:
self.bits_written.update(sub_command.bits)
self.bits_read.update(
set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc,arg-type] # noqa: E501
)
# for sub_command in all_sub_commands:
# self.bits_written.update(sub_command.bits)
# self.bits_read.update(
# set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc,arg-type] # noqa: E501
# )

def pretty_print(self) -> str:
output = io.StringIO()
Expand Down
53 changes: 45 additions & 8 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from pytket.circuit import Circuit, Command, Op, OpType
from pytket.unit_id import UnitID
from pytket.unit_id import Bit, UnitID

from .shard import Shard

NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox, OpType.WASM]

SHARD_TRIGGER_OP_TYPES = [
OpType.Measure,
OpType.Reset,
OpType.Barrier,
OpType.SetBits,
OpType.ClassicalExpBox, # some classical operations are rolled up into a box
OpType.RangePredicate,
]


class Sharder:
"""
Expand Down Expand Up @@ -46,7 +55,9 @@ def _process_command(self, command: Command) -> None:
raise NotImplementedError(msg)

if self.should_op_create_shard(command.op):
print(f"Building shard for command: {command}")
print(
f"Building shard for command: {command} args:{command.args} bits:{command.bits}",
)
self._build_shard(command)
else:
self._add_pending_sub_command(command)
Expand All @@ -63,6 +74,22 @@ def _build_shard(self, command: Command) -> None:
):
sub_commands[key] = self._pending_commands.pop(key)

all_commands = [command]
for sub_command in sub_commands.values():
all_commands.extend(sub_command)

qubits_used = set(command.qubits)

bits_written = set(command.bits)

bits_read = set()

for sub_command in all_commands:
bits_written.update(sub_command.bits)
bits_read.update(
set(filter(lambda x: isinstance(x, Bit), sub_command.args)), # type: ignore [misc,arg-type] # noqa: E501
)

# Handle dependency calculations
depends_upon: set[int] = set()
for shard in self._shards:
Expand All @@ -72,9 +99,14 @@ def _build_shard(self, command: Command) -> None:
depends_upon.add(shard.ID)
# Check classical dependencies, which depend on writing and reading
# hazards: RAW, WAW, WAR
# TODO: Do it!
elif not shard.bits_written.isdisjoint(bits_written):
depends_upon.add(shard.ID)
elif not shard.bits_read.isdisjoint(bits_written):
depends_upon.add(shard.ID)

shard = Shard(command, sub_commands, depends_upon)
shard = Shard(
command, sub_commands, qubits_used, bits_written, bits_read, depends_upon,
)
self._shards.append(shard)
print("Appended shard:", shard)

Expand All @@ -101,15 +133,20 @@ def _add_pending_sub_command(self, command: Command) -> None:
if key not in self._pending_commands:
self._pending_commands[key] = []
self._pending_commands[key].append(command)
print(f"Adding pending command {command}")
print(
f"Adding pending command {command} args: {command.args} bits: {command.bits}",
)

@staticmethod
def should_op_create_shard(op: Op) -> bool:
"""
Returns `True` if the operation is one that should result in shard creation.
This includes non-gate operations like measure/reset as well as 2-qubit gates.
TODO: This is almost certainly inadequate right now
"""
return op.type in (OpType.Measure, OpType.Reset, OpType.Barrier) or (
op.is_gate() and op.n_qubits > 1
return (
op.type in (SHARD_TRIGGER_OP_TYPES)
or (
op.type == OpType.Conditional and op.op.type in (SHARD_TRIGGER_OP_TYPES)
)
or (op.is_gate() and op.n_qubits > 1)
)
24 changes: 24 additions & 0 deletions tests/data/qasm/cond_classical.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
OPENQASM 2.0;
include "hqslib1_dev.inc";
qreg q[1];
creg a[10];
creg b[10];
creg c[4];
// classical assignment of registers
a[0] = 1;
a = 3;
// classical bitwise functions
a = 1;
b = 3;
c = a ^ b; // XOR
// evaluating a beyond creg == int
a = 1;
b = 2;
if(a[0]==1) x q[0];
if(a!=1) x q[0];
if(a>1) x q[0];
if(a<1) x q[0];
if(a>=1) x q[0];
if(a<=1) x q[0];
if (a==10) b=1;
measure q[0] -> c[0];
2 changes: 2 additions & 0 deletions tests/data/qasm/simple_cond.qasm
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ include "hqslib1.inc";

qreg q[1];
creg c[1];
creg z[1];

h q;
measure q->c;
reset q;
if (c==1) h q;
if (c==1) z=3;
measure q->c;
1 change: 1 addition & 0 deletions tests/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class QasmFiles(Enum):
baby = 4
baby_with_rollup = 5
simple_cond = 6
cond_classical = 7


def get_qasm_as_circuit(qasm_file: QasmFiles) -> Circuit:
Expand Down
70 changes: 66 additions & 4 deletions tests/test_sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_should_op_create_shard(self) -> None:
Op.create(OpType.Reset), # type: ignore # noqa: PGH003
Op.create(OpType.CX), # type: ignore # noqa: PGH003
Op.create(OpType.Barrier), # type: ignore # noqa: PGH003
# Op.create(OpType.SetBits, [3, 1]),
]
expected_false: list[Op] = [
Op.create(OpType.U1, 0.32), # type: ignore # noqa: PGH003
Expand Down Expand Up @@ -86,21 +87,82 @@ def test_simple_conditional(self) -> None:
sharder = Sharder(circuit)
shards = sharder.shard()

assert len(shards) == 3
assert len(shards) == 4

# shard 0: h q; measure q->c;
assert shards[0].primary_command.op.type == OpType.Measure
assert len(shards[0].sub_commands.items()) == 1
s0_qubit, s0_sub_cmds = next(iter(shards[0].sub_commands.items()))
assert s0_qubit == circuit.qubits[0]
assert s0_sub_cmds[0].op.type == OpType.H
assert shards[0].depends_upon == set()

# shard 1: reset q;
assert shards[1].primary_command.op.type == OpType.Reset
assert len(shards[1].sub_commands.items()) == 0
assert shards[1].depends_upon == {shards[0].ID}

assert shards[2].primary_command.op.type == OpType.Measure
assert len(shards[2].sub_commands.items()) == 1
s2_qubit, s2_sub_cmds = next(iter(shards[2].sub_commands.items()))
# shard 2: if (c==1) z=3;
assert shards[2].primary_command.op.type == OpType.Conditional
assert cast(Conditional, shards[2].primary_command).op.op.type == OpType.SetBits
assert len(shards[2].sub_commands.keys()) == 0
assert shards[2].depends_upon == {shards[0].ID}

# shard 3: if (c==1) h q; measure q->c;
assert shards[3].primary_command.op.type == OpType.Measure
assert len(shards[3].sub_commands.items()) == 1
s2_qubit, s2_sub_cmds = next(iter(shards[3].sub_commands.items()))
assert s2_qubit == circuit.qubits[0]
assert s2_sub_cmds[0].op.type == OpType.Conditional
assert cast(Conditional, s2_sub_cmds[0].op).op.type == OpType.H
assert s2_sub_cmds[0].qubits == [circuit.qubits[0]]

def test_classical_with_conditionals(self) -> None:
circuit = get_qasm_as_circuit(QasmFiles.cond_classical)
sharder = Sharder(circuit)
shards = sharder.shard()

circuit.get_commands()

# assert len(shards) == 10 # TODO: fix with correct value

# shard 0: a[0] = 1;
assert shards[0].primary_command.op.type == OpType.SetBits
assert len(shards[0].sub_commands.keys()) == 0
assert shards[0].depends_upon == set()
assert shards[0].qubits_used == set()
assert shards[0].bits_read == set()
assert shards[0].bits_written == {circuit.bits[0]}

# shard 1: a = 3;
assert shards[1].primary_command.op.type == OpType.SetBits
assert len(shards[1].sub_commands.keys()) == 0
assert shards[1].depends_upon == {shards[0].ID} # WAW for shard 0
assert shards[1].qubits_used == set()
assert shards[1].bits_read == set()
assert len(shards[1].bits_written) == 10 # TODO: Check for a[0-9]

# shard 2: a = 1;
assert shards[2].primary_command.op.type == OpType.SetBits
assert len(shards[2].sub_commands.keys()) == 0
assert shards[2].depends_upon == {
shards[0].ID,
shards[1].ID,
} # WAW for shard 0, 1
assert shards[2].qubits_used == set()
assert shards[2].bits_read == set()
assert len(shards[2].bits_written) == 10 # TODO: Check for a[0-9]

# shard 3: b = 3;
assert shards[3].primary_command.op.type == OpType.SetBits
assert len(shards[3].sub_commands.keys()) == 0
assert shards[3].depends_upon == set()
assert shards[3].qubits_used == set()
assert shards[3].bits_read == set()
assert len(shards[3].bits_written) == 10 # TODO: Check for b[0-9]

# shard 4: c = a ^ b; // XOR
assert shards[4].primary_command.op.type == OpType.ClassicalExpBox
assert len(shards[3].sub_commands.keys()) == 0
assert shards[3].depends_upon == set()
assert shards[4].qubits_used == set()

0 comments on commit 3d85f40

Please sign in to comment.