diff --git a/guppylang-internals/src/guppylang_internals/decorator.py b/guppylang-internals/src/guppylang_internals/decorator.py index 2e76b107e..9081a5506 100644 --- a/guppylang-internals/src/guppylang_internals/decorator.py +++ b/guppylang-internals/src/guppylang_internals/decorator.py @@ -3,9 +3,11 @@ import inspect from typing import TYPE_CHECKING, ParamSpec, TypeVar +from hugr import ext as he from hugr import ops from hugr import tys as ht +from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition from guppylang_internals.compiler.core import ( CompilerContext, GlobalConstId, @@ -20,6 +22,8 @@ OpCompiler, RawCustomFunctionDef, ) +from guppylang_internals.definition.function import RawFunctionDef +from guppylang_internals.definition.lowerable import RawLowerableFunctionDef from guppylang_internals.definition.ty import OpaqueTypeDef, TypeDef from guppylang_internals.definition.wasm import RawWasmFunctionDef from guppylang_internals.dummy_decorator import _dummy_custom_decorator, sphinx_running @@ -46,7 +50,6 @@ from collections.abc import Callable, Sequence from types import FrameType - from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition from guppylang_internals.tys.arg import Argument from guppylang_internals.tys.param import Parameter from guppylang_internals.tys.subst import Inst @@ -121,6 +124,73 @@ def hugr_op( return custom_function(OpCompiler(op), checker, higher_order_value, name, signature) +def lowerable_op( + hugr_ext: he.Extension, + checker: CustomCallChecker | None = None, + higher_order_value: bool = True, +) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: + """Decorator to automatically generate a hugr OpDef and add to the user-provided + hugr extension. + + Args: + hugr_ext: Hugr extension for the hugr OpDef to be added + """ + + def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: + defn = RawFunctionDef(DefId.fresh(), f.__name__, None, f) + DEF_STORE.register_def(defn, get_calling_frame()) + + defn = GuppyDefinition(defn) + + compiled_defn = defn.compile() + + try: + func_op = next( + data.op + for _, data in compiled_defn.modules[0].nodes() + if isinstance(data.op, ops.FuncDefn) and data.op.f_name == f.__name__ + ) + except StopIteration as e: + raise NameError( + f"Function definition ({f.__name__}) not found in hugr." + ) from e + + op_def = he.OpDef( + name=f.__name__, + description=f.__doc__ or "", + signature=he.OpDefSig(poly_func=func_op.signature), + lower_funcs=[ + he.FixedHugr( + ht.ExtensionSet([ext.name for ext in compiled_defn.extensions]), + module, + ) + for module in compiled_defn.modules + ], + ) + + hugr_ext.add_op_def(op_def) + + def op(ty: ht.FunctionType, inst: Inst, ctx: CompilerContext) -> ops.DataflowOp: + return ops.ExtOp(op_def, ty, [arg.to_hugr(ctx) for arg in inst]) + + call_checker = checker or DefaultCallChecker() + + func = RawLowerableFunctionDef( + DefId.fresh(), + f.__name__, + None, + f, + call_checker, + OpCompiler(op), + higher_order_value, + None, + ) + DEF_STORE.register_def(func, get_calling_frame()) + return GuppyFunctionDefinition(func) + + return dec + + def extend_type(defn: TypeDef) -> Callable[[type], type]: """Decorator to add new instance functions to a type.""" from guppylang.defs import GuppyDefinition diff --git a/guppylang-internals/src/guppylang_internals/definition/lowerable.py b/guppylang-internals/src/guppylang_internals/definition/lowerable.py new file mode 100644 index 000000000..94cf68379 --- /dev/null +++ b/guppylang-internals/src/guppylang_internals/definition/lowerable.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from guppylang_internals.compiler.core import ( + GlobalConstId, +) +from guppylang_internals.definition.custom import ( + CustomFunctionDef, + RawCustomFunctionDef, +) +from guppylang_internals.span import SourceMap +from guppylang_internals.tys.ty import ( + FunctionType, + NoneType, +) + +if TYPE_CHECKING: + from guppylang_internals.checker.core import Globals + + +@dataclass(frozen=True) +class RawLowerableFunctionDef(RawCustomFunctionDef): + """A raw custom function definition provided by the user. + + Custom functions provide their own checking and compilation logic using a + `CustomCallChecker` and a `CustomCallCompiler`. + + The raw definition stores exactly what the user has written (i.e. the AST together + with the provided checker and compiler), without inspecting the signature. + + Args: + id: The unique definition identifier. + name: The name of the definition. + defined_at: The AST node where the definition was defined. + call_checker: The custom call checker. + call_compiler: The custom call compiler. + higher_order_value: Whether the function may be used as a higher-order value. + signature: Optional User-provided signature. + """ + + def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef": + """Parses and checks the signature of the lowerable function. + + The signature is optional if custom type checking logic is provided by the user. + However, a signature *must* be provided to use the function as a higher-order + value (either by annotation or as an argument). If a signature is provided as an + argument, this will override any annotation. + + If no signature is provided, we fill in the dummy signature `() -> ()`. This + type will never be inspected, since we rely on the provided custom checking + code. The only information we need to access is that it's a function type and + that there are no unsolved existential vars. + """ + from guppylang_internals.definition.function import parse_py_func + + func_ast, _ = parse_py_func(self.python_func, sources) + sig = self.signature or self._get_signature(func_ast, globals) + ty = sig or FunctionType([], NoneType()) + return CustomFunctionDef( + self.id, + self.name, + func_ast, + ty, + self.call_checker, + self.call_compiler, + self.higher_order_value, + GlobalConstId.fresh(self.name), + sig is not None, + ) diff --git a/guppylang/src/guppylang/decorator.py b/guppylang/src/guppylang/decorator.py index 98a76fab0..7963f978f 100644 --- a/guppylang/src/guppylang/decorator.py +++ b/guppylang/src/guppylang/decorator.py @@ -24,9 +24,7 @@ ) from guppylang_internals.definition.declaration import RawFunctionDecl from guppylang_internals.definition.extern import RawExternDef -from guppylang_internals.definition.function import ( - RawFunctionDef, -) +from guppylang_internals.definition.function import RawFunctionDef from guppylang_internals.definition.overloaded import OverloadedFunctionDef from guppylang_internals.definition.parameter import ( ConstVarDef, diff --git a/tests/integration/test_lower_op.py b/tests/integration/test_lower_op.py new file mode 100644 index 000000000..508c741e1 --- /dev/null +++ b/tests/integration/test_lower_op.py @@ -0,0 +1,47 @@ +from guppylang import guppy +from guppylang_internals.decorator import lowerable_op +from guppylang.std.quantum import qubit, h, cx, measure + +import hugr.ext as he + +from pydantic_extra_types.semantic_version import SemanticVersion + + +def test_auto_hugr_lowering(validate): + test_hugr_ext = he.Extension("test_hugr_ext", SemanticVersion(0, 1, 0)) + + @lowerable_op(test_hugr_ext) + def entangle(q0: qubit, q1: qubit) -> None: + h(q0) + cx(q0, q1) + + @guppy + def main() -> None: + q0 = qubit() + q1 = qubit() + q2 = qubit() + entangle(q0, q1) + entangle(q1, q2) + measure(q0) + measure(q1) + measure(q2) + + hugr = main.compile() + + hugr.extensions.append(test_hugr_ext) + + validate(hugr) + + +def test_lower_funcs_hugr(validate): + test_hugr_ext = he.Extension("test_hugr_ext", SemanticVersion(0, 1, 0)) + + @lowerable_op(test_hugr_ext) + def entangle(q0: qubit, q1: qubit) -> None: + h(q0) + cx(q0, q1) + + op = test_hugr_ext.get_op("entangle") + + for funcs in op.lower_funcs: + validate(funcs.hugr)