Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 71 additions & 1 deletion guppylang-internals/src/guppylang_internals/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import inspect
from typing import TYPE_CHECKING, ParamSpec, TypeVar

from hugr import ext as he
from hugr import ops
from hugr import tys as ht

from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
from guppylang_internals.compiler.core import (
CompilerContext,
GlobalConstId,
Expand All @@ -20,6 +22,8 @@
OpCompiler,
RawCustomFunctionDef,
)
from guppylang_internals.definition.function import RawFunctionDef
from guppylang_internals.definition.lowerable import RawLowerableFunctionDef
from guppylang_internals.definition.ty import OpaqueTypeDef, TypeDef
from guppylang_internals.definition.wasm import RawWasmFunctionDef
from guppylang_internals.dummy_decorator import _dummy_custom_decorator, sphinx_running
Expand All @@ -46,7 +50,6 @@
from collections.abc import Callable, Sequence
from types import FrameType

from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
from guppylang_internals.tys.arg import Argument
from guppylang_internals.tys.param import Parameter
from guppylang_internals.tys.subst import Inst
Expand Down Expand Up @@ -121,6 +124,73 @@ def hugr_op(
return custom_function(OpCompiler(op), checker, higher_order_value, name, signature)


def lowerable_op(
hugr_ext: he.Extension,
checker: CustomCallChecker | None = None,
higher_order_value: bool = True,
) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
"""Decorator to automatically generate a hugr OpDef and add to the user-provided
hugr extension.

Args:
hugr_ext: Hugr extension for the hugr OpDef to be added
"""

def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
defn = RawFunctionDef(DefId.fresh(), f.__name__, None, f)
DEF_STORE.register_def(defn, get_calling_frame())

defn = GuppyDefinition(defn)

compiled_defn = defn.compile()

try:
func_op = next(
data.op
for _, data in compiled_defn.modules[0].nodes()
if isinstance(data.op, ops.FuncDefn) and data.op.f_name == f.__name__
)
except StopIteration as e:
raise NameError(
f"Function definition ({f.__name__}) not found in hugr."
) from e

op_def = he.OpDef(
name=f.__name__,
description=f.__doc__ or "",
signature=he.OpDefSig(poly_func=func_op.signature),
lower_funcs=[
he.FixedHugr(
ht.ExtensionSet([ext.name for ext in compiled_defn.extensions]),
module,
)
for module in compiled_defn.modules
],
)

hugr_ext.add_op_def(op_def)

def op(ty: ht.FunctionType, inst: Inst, ctx: CompilerContext) -> ops.DataflowOp:
return ops.ExtOp(op_def, ty, [arg.to_hugr(ctx) for arg in inst])

call_checker = checker or DefaultCallChecker()

func = RawLowerableFunctionDef(
DefId.fresh(),
f.__name__,
None,
f,
call_checker,
OpCompiler(op),
higher_order_value,
None,
)
DEF_STORE.register_def(func, get_calling_frame())
return GuppyFunctionDefinition(func)

return dec


def extend_type(defn: TypeDef) -> Callable[[type], type]:
"""Decorator to add new instance functions to a type."""
from guppylang.defs import GuppyDefinition
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from guppylang_internals.compiler.core import (
GlobalConstId,
)
from guppylang_internals.definition.custom import (
CustomFunctionDef,
RawCustomFunctionDef,
)
from guppylang_internals.span import SourceMap
from guppylang_internals.tys.ty import (
FunctionType,
NoneType,
)

if TYPE_CHECKING:
from guppylang_internals.checker.core import Globals


@dataclass(frozen=True)
class RawLowerableFunctionDef(RawCustomFunctionDef):
"""A raw custom function definition provided by the user.

Custom functions provide their own checking and compilation logic using a
`CustomCallChecker` and a `CustomCallCompiler`.

The raw definition stores exactly what the user has written (i.e. the AST together
with the provided checker and compiler), without inspecting the signature.

Args:
id: The unique definition identifier.
name: The name of the definition.
defined_at: The AST node where the definition was defined.
call_checker: The custom call checker.
call_compiler: The custom call compiler.
higher_order_value: Whether the function may be used as a higher-order value.
signature: Optional User-provided signature.
"""

def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
"""Parses and checks the signature of the lowerable function.

The signature is optional if custom type checking logic is provided by the user.
However, a signature *must* be provided to use the function as a higher-order
value (either by annotation or as an argument). If a signature is provided as an
argument, this will override any annotation.

If no signature is provided, we fill in the dummy signature `() -> ()`. This
type will never be inspected, since we rely on the provided custom checking
code. The only information we need to access is that it's a function type and
that there are no unsolved existential vars.
"""
from guppylang_internals.definition.function import parse_py_func
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmmm, ok. So the differences against the RawCustomFunctionDef.parse that we're overriding here are:

  • we suppress the error if there is a body
  • we don't look at self.signature before we call self._get_signature(func_ast, globals)....is that right?
  • we don't assign the local variable docstring that AFAICS in the original/overridden method is defined but not used(?!)

There is definitely an opportunity here for both commoning-up, and possibly cleaning up....do you feel like taking it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, these are the difference between the two methods. As suppressing the error if there is no body is required for the lowering case, it was suggested that these implementations were kept separate. This does come with a small side affect that self.signature is not used.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, agreed that you need suppress the warning, and since you are now using self.signature, looks good to me. You could possibly add a private helper that does the sig = .....; ty = .....; return CustomFunctionDef(self...., self....) etc. (taking just self and func_ast as parameters) but fine as it stands.


func_ast, _ = parse_py_func(self.python_func, sources)
sig = self.signature or self._get_signature(func_ast, globals)
ty = sig or FunctionType([], NoneType())
return CustomFunctionDef(
self.id,
self.name,
func_ast,
ty,
self.call_checker,
self.call_compiler,
self.higher_order_value,
GlobalConstId.fresh(self.name),
sig is not None,
)
4 changes: 1 addition & 3 deletions guppylang/src/guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
)
from guppylang_internals.definition.declaration import RawFunctionDecl
from guppylang_internals.definition.extern import RawExternDef
from guppylang_internals.definition.function import (
RawFunctionDef,
)
from guppylang_internals.definition.function import RawFunctionDef
from guppylang_internals.definition.overloaded import OverloadedFunctionDef
from guppylang_internals.definition.parameter import (
ConstVarDef,
Expand Down
47 changes: 47 additions & 0 deletions tests/integration/test_lower_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from guppylang import guppy
from guppylang_internals.decorator import lowerable_op
from guppylang.std.quantum import qubit, h, cx, measure

import hugr.ext as he

from pydantic_extra_types.semantic_version import SemanticVersion


def test_auto_hugr_lowering(validate):
test_hugr_ext = he.Extension("test_hugr_ext", SemanticVersion(0, 1, 0))

@lowerable_op(test_hugr_ext)
def entangle(q0: qubit, q1: qubit) -> None:
h(q0)
cx(q0, q1)

@guppy
def main() -> None:
q0 = qubit()
q1 = qubit()
q2 = qubit()
entangle(q0, q1)
entangle(q1, q2)
measure(q0)
measure(q1)
measure(q2)

hugr = main.compile()

hugr.extensions.append(test_hugr_ext)

validate(hugr)


def test_lower_funcs_hugr(validate):
test_hugr_ext = he.Extension("test_hugr_ext", SemanticVersion(0, 1, 0))

@lowerable_op(test_hugr_ext)
def entangle(q0: qubit, q1: qubit) -> None:
h(q0)
cx(q0, q1)

op = test_hugr_ext.get_op("entangle")

for funcs in op.lower_funcs:
validate(funcs.hugr)
Loading