Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
6cbbcc6
Add basic instrumentation subpackage
egparedes Jan 12, 2026
df46627
Add machinery and tests and several refactors
egparedes Jan 12, 2026
becb7ba
Run precommit
egparedes Jan 14, 2026
68d2283
Remove boilerplate from unit tests
egparedes Jan 15, 2026
a1963e9
Add integration tests
egparedes Jan 15, 2026
9c5729a
Run pre-commit
egparedes Jan 15, 2026
86f4e01
Address copilot review comments (mostly cosmetic)
egparedes Jan 15, 2026
2944cb7
Merge branch 'main' into instrumentation
egparedes Jan 16, 2026
c38939e
WIP metrics refactoring
egparedes Jan 19, 2026
2e41eea
Remove hooks changes
egparedes Jan 19, 2026
ffefed3
More metrics refactorings and cleanups
egparedes Jan 20, 2026
14a568d
Merge branch 'main' into refactor-metrics
egparedes Jan 20, 2026
edbe49b
Fixes
egparedes Jan 20, 2026
332ba9b
Typing fixes and renames
egparedes Jan 21, 2026
ca7d62a
Adding metrics collector maker
egparedes Jan 21, 2026
f6fc427
Fixes
egparedes Jan 21, 2026
7d5cb5c
Fix typing due to mypy bug and minor cleanups
egparedes Jan 21, 2026
7b4b3ce
Clean up docs
egparedes Jan 21, 2026
54e8910
More docs cleanups
egparedes Jan 21, 2026
fe6db62
Merge branch 'main' into refactor-metrics
egparedes Jan 21, 2026
7e1a290
Final cleanups and more tests
egparedes Jan 21, 2026
1dbb1ee
Fix review comments
egparedes Jan 21, 2026
0916e39
Improve docs
egparedes Jan 22, 2026
801cb6b
Merge branch 'main' into refactor-metrics
egparedes Jan 26, 2026
26ac37a
Restore instrumentation hook files
egparedes Jan 26, 2026
01a8d38
Use hooks for basic metrics
egparedes Jan 30, 2026
7155fef
Merge branch 'main' into instrumentation-merge
egparedes Jan 30, 2026
8e37777
pre-commit
egparedes Jan 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from __future__ import annotations

import abc
import contextlib
import dataclasses
import functools
import types
Expand Down Expand Up @@ -44,7 +45,7 @@
type_specifications as ts_ffront,
)
from gt4py.next.ffront.gtcallable import GTCallable
from gt4py.next.instrumentation import metrics
from gt4py.next.instrumentation import hook_machinery, metrics
from gt4py.next.iterator import ir as itir
from gt4py.next.otf import arguments, compiled_program, options, toolchain
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
Expand Down Expand Up @@ -161,6 +162,37 @@ def compile(
)


ProgramCallMetricsCollector = metrics.make_collector(
level=metrics.MINIMAL, metric_name=metrics.TOTAL_METRIC
)


@hook_machinery.ContextHook
def program_call_context(
program: Program,
args: tuple[Any, ...],
offset_provider: common.OffsetProvider,
enable_jit: bool,
kwargs: dict[str, Any],
) -> contextlib.AbstractContextManager:
"""Hook called at the beginning and end of a program call."""
return ProgramCallMetricsCollector()


@hook_machinery.EventHook
def embedded_program_call_hook(
program: Program,
args: tuple[Any, ...],
offset_provider: common.OffsetProvider,
kwargs: dict[str, Any],
) -> None:
"""Hook called at the beginning and end of an embedded program call."""
# Metrics source key needs to be set here. Embedded programs
# don't have variants so there's no other place to do it.
if metrics.is_level_enabled(metrics.MINIMAL):
metrics.set_current_source_key(f"{program.__name__}<'<embedded>')>")


# TODO(tehrengruber): Decide if and how programs can call other programs. As a
# result Program could become a GTCallable.
@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -327,7 +359,13 @@ def __call__(
offset_provider = {}
enable_jit = self.compilation_options.enable_jit if enable_jit is None else enable_jit

with program_call_metrics_collector():
with program_call_context(
program=self,
args=args,
offset_provider=offset_provider,
enable_jit=enable_jit,
kwargs=kwargs,
):
if __debug__:
# TODO: remove or make dependency on self.past_stage optional
past_process_args._validate_args(
Expand All @@ -349,14 +387,8 @@ def __call__(
stacklevel=2,
)

# Metrics source key needs to be set here. Embedded programs
# don't have variants so there's no other place to do it.
if metrics.is_level_enabled(metrics.MINIMAL):
metrics.set_current_source_key(
f"{self.__name__}<{getattr(self.backend, 'name', '<embedded>')}>"
)

with next_embedded.context.update(offset_provider=offset_provider):
embedded_program_call_hook(self, args, offset_provider, kwargs)
self.definition_stage.definition(*args, **kwargs)


Expand Down
192 changes: 192 additions & 0 deletions src/gt4py/next/instrumentation/hook_machinery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import ast
import collections.abc
import contextlib
import dataclasses
import inspect
import textwrap
import types
import typing
import warnings
from collections.abc import Callable
from typing import Generic, ParamSpec, TypeVar


P = ParamSpec("P")
T = TypeVar("T")


def _get_unique_name(func: Callable) -> str:
"""Generate a unique name for a callable object."""
return (
f"{func.__module__}.{getattr(func, '__qualname__', func.__class__.__qualname__)}#{id(func)}"
)


def _is_empty_function(func: Callable) -> bool:
"""Check if a callable object is empty (i.e., contains no statements)."""
try:
assert callable(func)
callable_src = (
inspect.getsource(func)
if isinstance(func, types.FunctionType)
else inspect.getsource(func.__call__) # type: ignore[operator] # asserted above
)
callable_ast = ast.parse(textwrap.dedent(callable_src))
return all(
isinstance(st, ast.Pass)
or (isinstance(st, ast.Expr) and isinstance(st.value, ast.Constant))
for st in typing.cast(ast.FunctionDef, callable_ast.body[0]).body
)
except Exception:
return False


@dataclasses.dataclass(slots=True)
class _BaseHook(Generic[T, P]):
"""Base class to define callback registration functionality for all hook types."""

definition: Callable[P, T]
registry: dict[str, Callable[P, T]] = dataclasses.field(default_factory=dict, kw_only=True)
callbacks: tuple[Callable[P, T], ...] = dataclasses.field(default=(), init=False)

@property
def __doc__(self) -> str | None: # type: ignore[override]
return self.definition.__doc__

def __post_init__(self) -> None:
# As an optimization to avoid an empty function call if no callbacks are
# registered, we only add the original definitions to the list of callables
# if it contains a non-empty definition.
if not _is_empty_function(self.definition):
self.callbacks = (self.definition,)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
raise NotImplementedError("This method should be implemented by subclasses.")

def register(
self, callback: Callable[P, T], *, name: str | None = None, index: int | None = None
) -> None:
"""
Register a callback to the hook.

Args:
callback: The callable to register.
name: An optional name for the callback. If not provided, a unique name will be generated.
index: An optional index at which to insert the callback (not counting the original
definition). If not provided, the callback will be appended to the end of the list.
"""

callable_signature = inspect.signature(callback)
hook_signature = inspect.signature(self.definition)

signature_mismatch = len(callable_signature.parameters) != len(
hook_signature.parameters
) or any(
# Remove the annotation before comparison to avoid false mismatches
actual_param.replace(annotation="") != expected_param.replace(annotation="")
for actual_param, expected_param in zip(
callable_signature.parameters.values(), hook_signature.parameters.values()
)
)
if signature_mismatch:
raise ValueError(
f"Callback signature {callable_signature} does not match hook signature {hook_signature}"
)
try:
callable_typing = typing.get_type_hints(callback)
hook_typing = typing.get_type_hints(self.definition)
if not all(
callable_typing[arg_key] == arg_typing
for arg_key, arg_typing in hook_typing.items()
):
warnings.warn(
f"Callback annotations {callable_typing} does not match expected hook annotations {hook_typing}",
stacklevel=2,
)
except Exception:
# Ignore issues while checking type hints (e.g., forward references
# or missing imports); failure here should not prevent hook registration.
pass

name = name or _get_unique_name(callback)

if index is None:
self.callbacks += (callback,)
else:
if self.callbacks and self.callbacks[0] is self.definition:
index += 1 # The original definition should always go first
self.callbacks = (*self.callbacks[:index], callback, *self.callbacks[index:])

self.registry[name] = callback

def remove(self, callback: str | Callable[P, T]) -> None:
"""
Remove a registered callback from the hook.

Args:
callback: The callable object to remove or its registered name.
"""
if isinstance(callback, str):
name = callback
if name not in self.registry:
raise KeyError(f"No callback registered under the name '{name}'")
else:
name = _get_unique_name(callback)
if name not in self.registry:
raise KeyError(f"Callback object {callback} not found in registry")

callback = self.registry.pop(name)
assert callback in self.callbacks
self.callbacks = tuple(cb for cb in self.callbacks if cb is not callback)


@dataclasses.dataclass(slots=True)
class EventHook(_BaseHook[None, P]):
"""Event hook specification."""

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None:
for func in self.callbacks:
func(*args, **kwargs)


@dataclasses.dataclass(slots=True)
class ContextHook(
contextlib.AbstractContextManager, _BaseHook[contextlib.AbstractContextManager, P]
):
"""
Context hook specification.

This hook type is used to define context managers that can be stacked together.
"""

ctx_managers: collections.abc.Sequence[contextlib.AbstractContextManager] = dataclasses.field(
default=(), init=False
)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> contextlib.AbstractContextManager:
self.ctx_managers = [func(*args, **kwargs) for func in self.callbacks]
return self

def __enter__(self) -> None:
for ctx_manager in self.ctx_managers:
ctx_manager.__enter__()

def __exit__(
self,
type_: type[BaseException] | None,
value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
for ctx_manager in reversed(self.ctx_managers):
ctx_manager.__exit__(type_, value, traceback)
self.ctx_managers = ()
18 changes: 18 additions & 0 deletions src/gt4py/next/instrumentation/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from gt4py.next.ffront.decorator import (
embedded_program_call_hook as embedded_program_call_hook,
program_call_context as program_call_context,
)
from gt4py.next.otf.compiled_program import (
compile_variant_hook as compile_variant_hook,
compiled_program_call_hook as compiled_program_call_hook,
)
4 changes: 2 additions & 2 deletions src/gt4py/next/instrumentation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def __enter__(self) -> None:
def __exit__(
self,
exc_type_: type[BaseException] | None,
value: BaseException | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
if self.previous_cvar_token is not None:
Expand Down Expand Up @@ -280,7 +280,7 @@ def __enter__(self) -> None:
def __exit__(
self,
exc_type_: type[BaseException] | None,
value: BaseException | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
if self.previous_cvar_token is not None:
Expand Down
Loading
Loading