Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Hugr generation for modifiers."""

from hugr import Wire, ops
from hugr import tys as ht

from guppylang_internals.ast_util import get_type
from guppylang_internals.checker.modifier_checker import non_copyable_front_others_back
from guppylang_internals.compiler.cfg_compiler import compile_cfg
from guppylang_internals.compiler.core import CompilerContext, DFContainer
from guppylang_internals.compiler.expr_compiler import ExprCompiler
from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
from guppylang_internals.std._internal.compiler.array import (
array_new,
array_to_std_array,
standard_array_type,
std_array_to_array,
unpack_array,
)
from guppylang_internals.std._internal.compiler.tket_exts import MODIFIER_EXTENSION
from guppylang_internals.tys.builtin import int_type, is_array_type
from guppylang_internals.tys.ty import InputFlags


def compile_modified_block(
modified_block: CheckedModifiedBlock,
dfg: DFContainer,
ctx: CompilerContext,
expr_compiler: ExprCompiler,
) -> Wire:
DAGGER_OP_NAME = "DaggerModifier"
CONTROL_OP_NAME = "ControlModifier"
POWER_OP_NAME = "PowerModifier"

dagger_op_def = MODIFIER_EXTENSION.get_op(DAGGER_OP_NAME)
control_op_def = MODIFIER_EXTENSION.get_op(CONTROL_OP_NAME)
power_op_def = MODIFIER_EXTENSION.get_op(POWER_OP_NAME)

body_ty = modified_block.ty
# TODO: Shouldn't this be `to_hugr_poly` since it can contain
# a variable with a generic type?
hugr_ty = body_ty.to_hugr(ctx)
Comment on lines +39 to +41
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this will most likely break, but afaik nested functions that reference type variables from the parent function also aren't handled correctly.

I guess the best solution for now would be to detect this case and emit an error telling the user that this isn't supported yet.

To properly handle this, we'll need to turn the modifier into a generic function and then instantiate it with the variables from the parent function...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it okay to merge this PR without such a proper error handling?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think it's fine. We just have to wait for the tket release

in_out_ht = [
fn_inp.ty.to_hugr(ctx)
for fn_inp in body_ty.inputs
if InputFlags.Inout in fn_inp.flags and InputFlags.Comptime not in fn_inp.flags
]
other_in_ht = [
fn_inp.ty.to_hugr(ctx)
for fn_inp in body_ty.inputs
if InputFlags.Inout not in fn_inp.flags
and InputFlags.Comptime not in fn_inp.flags
]
in_out_arg = ht.ListArg([t.type_arg() for t in in_out_ht])
other_in_arg = ht.ListArg([t.type_arg() for t in other_in_ht])

func_builder = dfg.builder.module_root_builder().define_function(
str(modified_block), hugr_ty.input, hugr_ty.output
)

# compile body
cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx)
func_builder.set_outputs(*cfg)

# LoadFunc
call = dfg.builder.load_function(func_builder, hugr_ty)

# Function inputs
captured = [v for v, _ in modified_block.captured.values()]
captured = non_copyable_front_others_back(captured)
args = [dfg[v] for v in captured]

# Apply modifiers
if modified_block.has_dagger():
dagger_ty = ht.FunctionType([hugr_ty], [hugr_ty])
call = dfg.builder.add_op(
ops.ExtOp(
dagger_op_def,
dagger_ty,
[in_out_arg, other_in_arg],
),
call,
)
if modified_block.has_power():
power_ty = ht.FunctionType([hugr_ty, int_type().to_hugr(ctx)], [hugr_ty])
for power in modified_block.power:
num = expr_compiler.compile(power.iter, dfg)
call = dfg.builder.add_op(
ops.ExtOp(
power_op_def,
power_ty,
[in_out_arg, other_in_arg],
),
call,
num,
)
qubit_num_args = []
if modified_block.has_control():
for control in modified_block.control:
assert control.qubit_num is not None
qubit_num: ht.TypeArg
if isinstance(control.qubit_num, int):
qubit_num = ht.BoundedNatArg(control.qubit_num)
else:
qubit_num = control.qubit_num.to_arg().to_hugr(ctx)
qubit_num_args.append(qubit_num)
std_array = standard_array_type(ht.Qubit, qubit_num)

# control operator
input_fn_ty = hugr_ty
output_fn_ty = ht.FunctionType(
[std_array, *hugr_ty.input], [std_array, *hugr_ty.output]
)
op = ops.ExtOp(
control_op_def,
ht.FunctionType([input_fn_ty], [output_fn_ty]),
[qubit_num, in_out_arg, other_in_arg],
)
call = dfg.builder.add_op(op, call)
# update types
in_out_arg = ht.ListArg([std_array.type_arg(), *in_out_arg.elems])
hugr_ty = output_fn_ty

# Prepare control arguments
ctrl_args: list[Wire] = []
for i, control in enumerate(modified_block.control):
if is_array_type(get_type(control.ctrl[0])):
control_array = expr_compiler.compile(control.ctrl[0], dfg)
control_array = dfg.builder.add_op(
array_to_std_array(ht.Qubit, qubit_num_args[i]), control_array
)
ctrl_args.append(control_array)
else:
cs = [expr_compiler.compile(c, dfg) for c in control.ctrl]
control_array = dfg.builder.add_op(
array_new(ht.Qubit, len(control.ctrl)), *cs
)
control_array = dfg.builder.add_op(
array_to_std_array(ht.Qubit, qubit_num_args[i]), *control_array
)
ctrl_args.append(control_array)

# Call
call = dfg.builder.add_op(
ops.CallIndirect(),
call,
*ctrl_args,
*args,
)
outports = iter(call)

# Unpack controls
for i, control in enumerate(modified_block.control):
outport = next(outports)
if is_array_type(get_type(control.ctrl[0])):
control_array = dfg.builder.add_op(
std_array_to_array(ht.Qubit, qubit_num_args[i]), outport
)
c = control.ctrl[0]
assert isinstance(c, PlaceNode)
dfg[c.place] = control_array
else:
control_array = dfg.builder.add_op(
std_array_to_array(ht.Qubit, qubit_num_args[i]), outport
)
unpacked = unpack_array(dfg.builder, control_array)
for c, new_c in zip(control.ctrl, unpacked, strict=False):
assert isinstance(c, PlaceNode)
dfg[c.place] = new_c

for arg in captured:
if InputFlags.Inout in arg.flags:
dfg[arg] = next(outports)

return call
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from guppylang_internals.error import InternalGuppyError
from guppylang_internals.nodes import (
ArrayUnpack,
CheckedModifiedBlock,
CheckedNestedFunctionDef,
IterableUnpack,
PlaceNode,
Expand Down Expand Up @@ -220,3 +221,10 @@ def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None
var = Variable(node.name, node.ty, node)
loaded_func = compile_local_func_def(node, self.dfg, self.ctx)
self.dfg[var] = loaded_func

def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None:
from guppylang_internals.compiler.modifier_compiler import (
compile_modified_block,
)

compile_modified_block(node, self.dfg, self.ctx, self.expr_compiler)
4 changes: 2 additions & 2 deletions guppylang-internals/src/guppylang_internals/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,11 +538,11 @@ def __str__(self) -> str:
# generate a function name from the def_id
return f"__WithBlock__({self.def_id})"

def is_dagger(self) -> bool:
def has_dagger(self) -> bool:
return len(self.dagger) % 2 == 1

def has_control(self) -> bool:
return any(len(c.ctrl) > 0 for c in self.control)

def is_power(self) -> bool:
def has_power(self) -> bool:
return len(self.power) > 0
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from tket_exts import (
debug,
futures,
global_phase,
guppy,
modifier,
qsystem,
qsystem_random,
qsystem_utils,
Expand All @@ -19,13 +21,16 @@
DEBUG_EXTENSION = debug()
FUTURES_EXTENSION = futures()
GUPPY_EXTENSION = guppy()
MODIFIER_EXTENSION = modifier()
QSYSTEM_EXTENSION = qsystem()
QSYSTEM_RANDOM_EXTENSION = qsystem_random()
QSYSTEM_UTILS_EXTENSION = qsystem_utils()
QUANTUM_EXTENSION = quantum()
RESULT_EXTENSION = result()
ROTATION_EXTENSION = rotation()
WASM_EXTENSION = wasm()
MODIFIER_EXTENSION = modifier()
GLOBAL_PHASE_EXTENSION = global_phase()

TKET_EXTENSIONS = [
BOOL_EXTENSION,
Expand All @@ -39,6 +44,8 @@
RESULT_EXTENSION,
ROTATION_EXTENSION,
WASM_EXTENSION,
MODIFIER_EXTENSION,
GLOBAL_PHASE_EXTENSION,
]


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,5 @@ exclude_also = [
addopts = "--benchmark-skip" # benchmarks run explicitly with `just bench`
filterwarnings = [
"ignore::DeprecationWarning", # TODO remove after removing guppy.compile()
"ignore::SyntaxWarning", # Python 3.14 complains about guppy tuple callables
"ignore::SyntaxWarning", # Python 3.14 complains about guppy tuple callables
]
135 changes: 135 additions & 0 deletions tests/integration/test_modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from guppylang.decorator import guppy
from guppylang.std.quantum import qubit
from guppylang.std.num import nat
from guppylang.std.builtins import owned
from guppylang.std.array import array

# Dummy variables to suppress Undefined name
# TODO: `ruff` fails when without these, which need to be fixed
dagger = object()
control = object()
power = object()
Comment on lines +7 to +11
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be good to add these to guppylang.std.lang and then reexport in guppylang.std.builtins.

But I'm also happy if you prefer to do this in a separate PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forgot to mention, but I made an issue for this (#1290)
I would rather think it should be left as a future feature to be handled in a better manner.



def test_dagger_simple(validate):
@guppy
def bar() -> None:
with dagger:
pass

validate(bar.compile_function())


def test_dagger_call_simple(validate):
@guppy
def bar() -> None:
with dagger():
pass

validate(bar.compile_function())


def test_control_simple(validate):
@guppy
def bar(q: qubit) -> None:
with control(q):
pass

validate(bar.compile_function())


def test_control_multiple(validate):
@guppy
def bar(q1: qubit, q2: qubit) -> None:
with control(q1, q2):
pass

validate(bar.compile_function())


def test_control_array(validate):
@guppy
def bar(q: array[qubit, 3]) -> None:
with control(q):
pass

validate(bar.compile_function())


def test_power_simple(validate):
@guppy
def bar(n: nat) -> None:
with power(n):
pass

validate(bar.compile_function())


def test_call_in_modifier(validate):
@guppy
def foo() -> None:
pass

@guppy
def bar() -> None:
with dagger:
foo()

validate(bar.compile_function())


def test_combined_modifiers(validate):
@guppy
def bar(q: qubit) -> None:
with control(q), power(2), dagger:
pass

validate(bar.compile_function())


def test_nested_modifiers(validate):
@guppy
def bar(q: qubit) -> None:
with control(q):
with power(2):
with dagger:
pass

validate(bar.compile_function())


def test_free_linear_variable_in_modifier(validate):
T = guppy.type_var("T", copyable=False, droppable=False)

@guppy.declare
def use(a: T) -> None: ...

@guppy.declare
def discard(a: T @ owned) -> None: ...

@guppy
def bar(q: qubit) -> None:
a = array(qubit())
with control(q):
use(a)
discard(a)

validate(bar.compile_function())


def test_free_copyable_variable_in_modifier(validate):
T = guppy.type_var("T", copyable=True, droppable=True)

@guppy.declare
def use(a: T) -> None: ...

@guppy.declare
def discard(a: T @ owned) -> None: ...

@guppy
def bar(q: array[qubit, 3]) -> None:
a = 3
with control(q):
use(a)

validate(bar.compile_function())
Loading