diff --git a/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py b/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py new file mode 100644 index 000000000..33ff67ff5 --- /dev/null +++ b/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py @@ -0,0 +1,174 @@ +"""Hugr generation for modifiers.""" + +from hugr import Wire, ops +from hugr import tys as ht + +from guppylang_internals.ast_util import get_type +from guppylang_internals.checker.modifier_checker import non_copyable_front_others_back +from guppylang_internals.compiler.cfg_compiler import compile_cfg +from guppylang_internals.compiler.core import CompilerContext, DFContainer +from guppylang_internals.compiler.expr_compiler import ExprCompiler +from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode +from guppylang_internals.std._internal.compiler.array import ( + array_new, + array_to_std_array, + standard_array_type, + std_array_to_array, + unpack_array, +) +from guppylang_internals.std._internal.compiler.tket_exts import MODIFIER_EXTENSION +from guppylang_internals.tys.builtin import int_type, is_array_type +from guppylang_internals.tys.ty import InputFlags + + +def compile_modified_block( + modified_block: CheckedModifiedBlock, + dfg: DFContainer, + ctx: CompilerContext, + expr_compiler: ExprCompiler, +) -> Wire: + DAGGER_OP_NAME = "DaggerModifier" + CONTROL_OP_NAME = "ControlModifier" + POWER_OP_NAME = "PowerModifier" + + dagger_op_def = MODIFIER_EXTENSION.get_op(DAGGER_OP_NAME) + control_op_def = MODIFIER_EXTENSION.get_op(CONTROL_OP_NAME) + power_op_def = MODIFIER_EXTENSION.get_op(POWER_OP_NAME) + + body_ty = modified_block.ty + # TODO: Shouldn't this be `to_hugr_poly` since it can contain + # a variable with a generic type? + hugr_ty = body_ty.to_hugr(ctx) + in_out_ht = [ + fn_inp.ty.to_hugr(ctx) + for fn_inp in body_ty.inputs + if InputFlags.Inout in fn_inp.flags and InputFlags.Comptime not in fn_inp.flags + ] + other_in_ht = [ + fn_inp.ty.to_hugr(ctx) + for fn_inp in body_ty.inputs + if InputFlags.Inout not in fn_inp.flags + and InputFlags.Comptime not in fn_inp.flags + ] + in_out_arg = ht.ListArg([t.type_arg() for t in in_out_ht]) + other_in_arg = ht.ListArg([t.type_arg() for t in other_in_ht]) + + func_builder = dfg.builder.module_root_builder().define_function( + str(modified_block), hugr_ty.input, hugr_ty.output + ) + + # compile body + cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx) + func_builder.set_outputs(*cfg) + + # LoadFunc + call = dfg.builder.load_function(func_builder, hugr_ty) + + # Function inputs + captured = [v for v, _ in modified_block.captured.values()] + captured = non_copyable_front_others_back(captured) + args = [dfg[v] for v in captured] + + # Apply modifiers + if modified_block.has_dagger(): + dagger_ty = ht.FunctionType([hugr_ty], [hugr_ty]) + call = dfg.builder.add_op( + ops.ExtOp( + dagger_op_def, + dagger_ty, + [in_out_arg, other_in_arg], + ), + call, + ) + if modified_block.has_power(): + power_ty = ht.FunctionType([hugr_ty, int_type().to_hugr(ctx)], [hugr_ty]) + for power in modified_block.power: + num = expr_compiler.compile(power.iter, dfg) + call = dfg.builder.add_op( + ops.ExtOp( + power_op_def, + power_ty, + [in_out_arg, other_in_arg], + ), + call, + num, + ) + qubit_num_args = [] + if modified_block.has_control(): + for control in modified_block.control: + assert control.qubit_num is not None + qubit_num: ht.TypeArg + if isinstance(control.qubit_num, int): + qubit_num = ht.BoundedNatArg(control.qubit_num) + else: + qubit_num = control.qubit_num.to_arg().to_hugr(ctx) + qubit_num_args.append(qubit_num) + std_array = standard_array_type(ht.Qubit, qubit_num) + + # control operator + input_fn_ty = hugr_ty + output_fn_ty = ht.FunctionType( + [std_array, *hugr_ty.input], [std_array, *hugr_ty.output] + ) + op = ops.ExtOp( + control_op_def, + ht.FunctionType([input_fn_ty], [output_fn_ty]), + [qubit_num, in_out_arg, other_in_arg], + ) + call = dfg.builder.add_op(op, call) + # update types + in_out_arg = ht.ListArg([std_array.type_arg(), *in_out_arg.elems]) + hugr_ty = output_fn_ty + + # Prepare control arguments + ctrl_args: list[Wire] = [] + for i, control in enumerate(modified_block.control): + if is_array_type(get_type(control.ctrl[0])): + control_array = expr_compiler.compile(control.ctrl[0], dfg) + control_array = dfg.builder.add_op( + array_to_std_array(ht.Qubit, qubit_num_args[i]), control_array + ) + ctrl_args.append(control_array) + else: + cs = [expr_compiler.compile(c, dfg) for c in control.ctrl] + control_array = dfg.builder.add_op( + array_new(ht.Qubit, len(control.ctrl)), *cs + ) + control_array = dfg.builder.add_op( + array_to_std_array(ht.Qubit, qubit_num_args[i]), *control_array + ) + ctrl_args.append(control_array) + + # Call + call = dfg.builder.add_op( + ops.CallIndirect(), + call, + *ctrl_args, + *args, + ) + outports = iter(call) + + # Unpack controls + for i, control in enumerate(modified_block.control): + outport = next(outports) + if is_array_type(get_type(control.ctrl[0])): + control_array = dfg.builder.add_op( + std_array_to_array(ht.Qubit, qubit_num_args[i]), outport + ) + c = control.ctrl[0] + assert isinstance(c, PlaceNode) + dfg[c.place] = control_array + else: + control_array = dfg.builder.add_op( + std_array_to_array(ht.Qubit, qubit_num_args[i]), outport + ) + unpacked = unpack_array(dfg.builder, control_array) + for c, new_c in zip(control.ctrl, unpacked, strict=False): + assert isinstance(c, PlaceNode) + dfg[c.place] = new_c + + for arg in captured: + if InputFlags.Inout in arg.flags: + dfg[arg] = next(outports) + + return call diff --git a/guppylang-internals/src/guppylang_internals/compiler/stmt_compiler.py b/guppylang-internals/src/guppylang_internals/compiler/stmt_compiler.py index f733dfad1..dcc31c7b4 100644 --- a/guppylang-internals/src/guppylang_internals/compiler/stmt_compiler.py +++ b/guppylang-internals/src/guppylang_internals/compiler/stmt_compiler.py @@ -17,6 +17,7 @@ from guppylang_internals.error import InternalGuppyError from guppylang_internals.nodes import ( ArrayUnpack, + CheckedModifiedBlock, CheckedNestedFunctionDef, IterableUnpack, PlaceNode, @@ -220,3 +221,10 @@ def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None var = Variable(node.name, node.ty, node) loaded_func = compile_local_func_def(node, self.dfg, self.ctx) self.dfg[var] = loaded_func + + def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None: + from guppylang_internals.compiler.modifier_compiler import ( + compile_modified_block, + ) + + compile_modified_block(node, self.dfg, self.ctx, self.expr_compiler) diff --git a/guppylang-internals/src/guppylang_internals/nodes.py b/guppylang-internals/src/guppylang_internals/nodes.py index e80e4a1fd..e5f1dc2fc 100644 --- a/guppylang-internals/src/guppylang_internals/nodes.py +++ b/guppylang-internals/src/guppylang_internals/nodes.py @@ -538,11 +538,11 @@ def __str__(self) -> str: # generate a function name from the def_id return f"__WithBlock__({self.def_id})" - def is_dagger(self) -> bool: + def has_dagger(self) -> bool: return len(self.dagger) % 2 == 1 def has_control(self) -> bool: return any(len(c.ctrl) > 0 for c in self.control) - def is_power(self) -> bool: + def has_power(self) -> bool: return len(self.power) > 0 diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/tket_exts.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/tket_exts.py index fe17ea0eb..fb475860b 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/tket_exts.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/tket_exts.py @@ -5,7 +5,9 @@ from tket_exts import ( debug, futures, + global_phase, guppy, + modifier, qsystem, qsystem_random, qsystem_utils, @@ -19,6 +21,7 @@ DEBUG_EXTENSION = debug() FUTURES_EXTENSION = futures() GUPPY_EXTENSION = guppy() +MODIFIER_EXTENSION = modifier() QSYSTEM_EXTENSION = qsystem() QSYSTEM_RANDOM_EXTENSION = qsystem_random() QSYSTEM_UTILS_EXTENSION = qsystem_utils() @@ -26,6 +29,8 @@ RESULT_EXTENSION = result() ROTATION_EXTENSION = rotation() WASM_EXTENSION = wasm() +MODIFIER_EXTENSION = modifier() +GLOBAL_PHASE_EXTENSION = global_phase() TKET_EXTENSIONS = [ BOOL_EXTENSION, @@ -39,6 +44,8 @@ RESULT_EXTENSION, ROTATION_EXTENSION, WASM_EXTENSION, + MODIFIER_EXTENSION, + GLOBAL_PHASE_EXTENSION, ] diff --git a/pyproject.toml b/pyproject.toml index e6381c5ab..d1f5f4b9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,5 +65,5 @@ exclude_also = [ addopts = "--benchmark-skip" # benchmarks run explicitly with `just bench` filterwarnings = [ "ignore::DeprecationWarning", # TODO remove after removing guppy.compile() - "ignore::SyntaxWarning", # Python 3.14 complains about guppy tuple callables + "ignore::SyntaxWarning", # Python 3.14 complains about guppy tuple callables ] diff --git a/tests/integration/test_modifier.py b/tests/integration/test_modifier.py new file mode 100644 index 000000000..2108d5a85 --- /dev/null +++ b/tests/integration/test_modifier.py @@ -0,0 +1,135 @@ +from guppylang.decorator import guppy +from guppylang.std.quantum import qubit +from guppylang.std.num import nat +from guppylang.std.builtins import owned +from guppylang.std.array import array + +# Dummy variables to suppress Undefined name +# TODO: `ruff` fails when without these, which need to be fixed +dagger = object() +control = object() +power = object() + + +def test_dagger_simple(validate): + @guppy + def bar() -> None: + with dagger: + pass + + validate(bar.compile_function()) + + +def test_dagger_call_simple(validate): + @guppy + def bar() -> None: + with dagger(): + pass + + validate(bar.compile_function()) + + +def test_control_simple(validate): + @guppy + def bar(q: qubit) -> None: + with control(q): + pass + + validate(bar.compile_function()) + + +def test_control_multiple(validate): + @guppy + def bar(q1: qubit, q2: qubit) -> None: + with control(q1, q2): + pass + + validate(bar.compile_function()) + + +def test_control_array(validate): + @guppy + def bar(q: array[qubit, 3]) -> None: + with control(q): + pass + + validate(bar.compile_function()) + + +def test_power_simple(validate): + @guppy + def bar(n: nat) -> None: + with power(n): + pass + + validate(bar.compile_function()) + + +def test_call_in_modifier(validate): + @guppy + def foo() -> None: + pass + + @guppy + def bar() -> None: + with dagger: + foo() + + validate(bar.compile_function()) + + +def test_combined_modifiers(validate): + @guppy + def bar(q: qubit) -> None: + with control(q), power(2), dagger: + pass + + validate(bar.compile_function()) + + +def test_nested_modifiers(validate): + @guppy + def bar(q: qubit) -> None: + with control(q): + with power(2): + with dagger: + pass + + validate(bar.compile_function()) + + +def test_free_linear_variable_in_modifier(validate): + T = guppy.type_var("T", copyable=False, droppable=False) + + @guppy.declare + def use(a: T) -> None: ... + + @guppy.declare + def discard(a: T @ owned) -> None: ... + + @guppy + def bar(q: qubit) -> None: + a = array(qubit()) + with control(q): + use(a) + discard(a) + + validate(bar.compile_function()) + + +def test_free_copyable_variable_in_modifier(validate): + T = guppy.type_var("T", copyable=True, droppable=True) + + @guppy.declare + def use(a: T) -> None: ... + + @guppy.declare + def discard(a: T @ owned) -> None: ... + + @guppy + def bar(q: array[qubit, 3]) -> None: + a = 3 + with control(q): + use(a) + + validate(bar.compile_function())