Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 49 additions & 21 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from __future__ import annotations

import contextlib
import dataclasses
import functools
import time
Expand All @@ -32,7 +33,6 @@
config,
embedded as next_embedded,
errors,
metrics,
utils,
)
from gt4py.next.embedded import operators as embedded_operators
Expand All @@ -46,6 +46,7 @@
type_specifications as ts_ffront,
)
from gt4py.next.ffront.gtcallable import GTCallable
from gt4py.next.instrumentation import _hook_machinery, metrics
from gt4py.next.iterator import ir as itir
from gt4py.next.otf import arguments, compiled_program, stages, toolchain
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
Expand All @@ -54,6 +55,29 @@
DEFAULT_BACKEND: next_backend.Backend | None = None


@_hook_machinery.ContextHook
def program_call_hook( # type: ignore[empty-body]
program: Program,
args: tuple[Any, ...],
offset_provider: common.OffsetProvider,
enable_jit: bool,
kwargs: dict[str, Any],
) -> contextlib.AbstractContextManager:
"""Hook called at the beginning and end of a program call."""
...


@_hook_machinery.ContextHook
def embedded_program_call_hook( # type: ignore[empty-body]
program: Program,
args: tuple[Any, ...],
offset_provider: common.OffsetProvider,
kwargs: dict[str, Any],
) -> contextlib.AbstractContextManager:
"""Hook called at the beginning and end of an embedded program call."""
...


# TODO(tehrengruber): Decide if and how programs can call other programs. As a
# result Program could become a GTCallable.
@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -275,27 +299,31 @@ def __call__(
kwarg_types={k: type_translation.from_value(v) for k, v in kwargs.items()},
)

if self.backend is not None:
self._compiled_programs(
*args, **kwargs, offset_provider=offset_provider, enable_jit=enable_jit
)
else:
# Embedded execution.
# Metrics source key needs to be setup here, since embedded programs
# don't have variants and thus there's no other place we could do this.
if config.COLLECT_METRICS_LEVEL:
assert metrics_source is not None
metrics_source.key = (
f"{self.__name__}<{getattr(self.backend, 'name', '<embedded>')}>"
with program_call_hook(self, args, offset_provider, enable_jit, kwargs):
if self.backend is not None:
self._compiled_programs(
*args, **kwargs, offset_provider=offset_provider, enable_jit=enable_jit
)
warnings.warn(
UserWarning(
f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend."
),
stacklevel=2,
)
with next_embedded.context.update(offset_provider=offset_provider):
self.definition_stage.definition(*args, **kwargs)
else:
# Embedded execution.
warnings.warn(
UserWarning(
f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend."
),
stacklevel=2,
)

# Metrics source key needs to be setup here, since embedded programs
# don't have variants so there's no other place to do it.
if config.COLLECT_METRICS_LEVEL:
assert metrics_source is not None
metrics.set_current_source_key(
f"{self.__name__}<{getattr(self.backend, 'name', '<embedded>')}>"
)

with next_embedded.context.update(offset_provider=offset_provider):
with embedded_program_call_hook(self, args, offset_provider, kwargs):
self.definition_stage.definition(*args, **kwargs)

if collect_info_metrics:
assert metrics_source is not None
Expand Down
8 changes: 8 additions & 0 deletions src/gt4py/next/instrumentation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

185 changes: 185 additions & 0 deletions src/gt4py/next/instrumentation/_hook_machinery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import ast
import collections.abc
import contextlib
import dataclasses
import inspect
import textwrap
import types
import typing
import warnings
from collections.abc import Callable
from typing import Generic, ParamSpec, TypeVar


P = ParamSpec("P")
T = TypeVar("T")


def _get_unique_name(func: Callable) -> str:
"""Generate a unique name for a callable object."""
return (
f"{func.__module__}.{getattr(func, '__qualname__', func.__class__.__qualname__)}#{id(func)}"
)


def _is_empty_function(func: Callable) -> bool:
"""Check if a callable object is empty (i.e., contains no statements)."""
try:
assert callable(func)
callable_src = (
inspect.getsource(func)
if isinstance(func, types.FunctionType)
else inspect.getsource(func.__call__) # type: ignore[operator] # asserted above
)
callable_ast = ast.parse(textwrap.dedent(callable_src))
return all(
isinstance(st, ast.Pass)
or (isinstance(st, ast.Expr) and isinstance(st.value, ast.Constant))
for st in typing.cast(ast.FunctionDef, callable_ast.body[0]).body
)
except Exception:
return False


@dataclasses.dataclass(slots=True)
class _BaseHook(Generic[T, P]):
"""Base class to define callback registration functionality for all hook types."""

definition: Callable[P, T]
registry: dict[str, Callable[P, T]] = dataclasses.field(default_factory=dict, kw_only=True)
callbacks: tuple[Callable[P, T], ...] = dataclasses.field(default=(), init=False)

def __post_init__(self) -> None:
# As an optimization to avoid an empty function call if no callbacks are
# registered, we only add the original definitions to the list of callables
# if it contains a non-empty definition.
if not _is_empty_function(self.definition):
self.callbacks = (self.definition,)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
raise NotImplementedError("This method should be implemented by subclasses.")

def register(
self, callback: Callable[P, T], *, name: str | None = None, index: int | None = None
) -> None:
"""
Register a callback to the hook.

Args:
callback: The callable to register.
name: An optional name for the callback. If not provided, a unique name will be generated.
index: An optional index at which to insert the callback (not counting the original
definition). If not provided, the callback will be appended to the end of the list.
"""

callable_signature = inspect.signature(callback)
hook_signature = inspect.signature(self.definition)

signature_mismatch = len(callable_signature.parameters) != len(
hook_signature.parameters
) or any(
# Remove the annotation before comparison to avoid false mismatches
actual_param.replace(annotation="") != expected_param.replace(annotation="")
for actual_param, expected_param in zip(
callable_signature.parameters.values(), hook_signature.parameters.values()
)
)
if signature_mismatch:
raise ValueError(
f"Callback signature {callable_signature} does not match hook signature {hook_signature}"
)
try:
callable_typing = typing.get_type_hints(callback)
hook_typing = typing.get_type_hints(self.definition)
if not all(
callable_typing[arg_key] == arg_typing
for arg_key, arg_typing in hook_typing.items()
):
warnings.warn(
f"Callback annotations {callable_typing} does not match expected hook annotations {hook_typing}",
stacklevel=2,
)
except Exception:
pass
name = name or _get_unique_name(callback)

if index is None:
self.callbacks += (callback,)
else:
if self.callbacks and self.callbacks[0] is self.definition:
index += 1 # The original definition should always go first
self.callbacks = (*self.callbacks[:index], callback, *self.callbacks[index:])

self.registry[name] = callback

def remove(self, callback: str | Callable[P, T]) -> None:
"""
Remove a registered callback from the hook.

Args:
callback: The callable object to remove or its registered name.
"""
if isinstance(callback, str):
name = callback
if name not in self.registry:
raise KeyError(f"No callback registered under the name '{name}'")
else:
name = _get_unique_name(callback)
if name not in self.registry:
raise KeyError(f"Callback object {callback} not found in registry")

callback = self.registry.pop(name)
assert callback in self.callbacks
self.callbacks = tuple(cb for cb in self.callbacks if cb is not callback)


@dataclasses.dataclass(slots=True)
class EventHook(_BaseHook[None, P]):
"""Event hook specification."""

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None:
for func in self.callbacks:
func(*args, **kwargs)


@dataclasses.dataclass(slots=True)
class ContextHook(
contextlib.AbstractContextManager, _BaseHook[contextlib.AbstractContextManager, P]
):
"""
Context hook specification.

This hook type is used to define context managers that can be stacked together.
"""

ctx_managers: collections.abc.Sequence[contextlib.AbstractContextManager] = dataclasses.field(
default=(), init=False
)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> contextlib.AbstractContextManager:
self.ctx_managers = [func(*args, **kwargs) for func in self.callbacks]
return self

def __enter__(self) -> None:
for ctx_manager in self.ctx_managers:
ctx_manager.__enter__()

def __exit__(
self,
type_: type[BaseException] | None,
value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
for ctx_manager in reversed(self.ctx_managers):
ctx_manager.__exit__(type_, value, traceback)
self.ctx_managers = ()
15 changes: 15 additions & 0 deletions src/gt4py/next/instrumentation/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from gt4py.next.ffront.decorator import (
embedded_program_call_hook as embedded_program_call_hook,
program_call_hook as program_call_hook,
)
from gt4py.next.otf.compiled_program import compile_variant_hook as compile_variant_hook
Loading
Loading