From 6cbbcc62226cf455493058d9379455ac2bff13e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Mon, 12 Jan 2026 11:37:33 +0100 Subject: [PATCH 01/23] Add basic instrumentation subpackage --- src/gt4py/next/ffront/decorator.py | 2 +- src/gt4py/next/instrumentation/__init__.py | 8 ++++++++ src/gt4py/next/{ => instrumentation}/metrics.py | 0 src/gt4py/next/otf/compiled_program.py | 3 ++- .../runners/dace/workflow/decoration.py | 3 ++- .../runners/dace/workflow/translation.py | 3 ++- src/gt4py/next/program_processors/runners/gtfn.py | 3 ++- .../feature_tests/ffront_tests/test_decorator.py | 3 ++- tests/next_tests/unit_tests/test_metrics.py | 2 +- 9 files changed, 20 insertions(+), 7 deletions(-) create mode 100644 src/gt4py/next/instrumentation/__init__.py rename src/gt4py/next/{ => instrumentation}/metrics.py (100%) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index c23fff9a9a..b04cda7049 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -32,7 +32,6 @@ config, embedded as next_embedded, errors, - metrics, utils, ) from gt4py.next.embedded import operators as embedded_operators @@ -46,6 +45,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.gtcallable import GTCallable +from gt4py.next.instrumentation import metrics from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, compiled_program, stages, toolchain from gt4py.next.type_system import type_info, type_specifications as ts, type_translation diff --git a/src/gt4py/next/instrumentation/__init__.py b/src/gt4py/next/instrumentation/__init__.py new file mode 100644 index 0000000000..abf4c3e24c --- /dev/null +++ b/src/gt4py/next/instrumentation/__init__.py @@ -0,0 +1,8 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + diff --git a/src/gt4py/next/metrics.py b/src/gt4py/next/instrumentation/metrics.py similarity index 100% rename from src/gt4py/next/metrics.py rename to src/gt4py/next/instrumentation/metrics.py diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index d873e7da17..e80a217a8a 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -17,8 +17,9 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping, utils as eve_utils -from gt4py.next import backend as gtx_backend, common, config, errors, metrics, utils as gtx_utils +from gt4py.next import backend as gtx_backend, common, config, errors, utils as gtx_utils from gt4py.next.ffront import stages as ffront_stages, type_specifications as ts_ffront +from gt4py.next.instrumentation import metrics from gt4py.next.otf import arguments, stages from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 996ba7a095..de3d25aaa5 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -14,7 +14,8 @@ import numpy as np from gt4py._core import definitions as core_defs -from gt4py.next import common as gtx_common, config, metrics, utils as gtx_utils +from gt4py.next import common as gtx_common, config, utils as gtx_utils +from gt4py.next.instrumentation import metrics from gt4py.next.otf import stages from gt4py.next.program_processors.runners.dace import sdfg_callable from gt4py.next.program_processors.runners.dace.workflow import ( diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index b98e0504fd..6e6b6b904b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -15,7 +15,8 @@ import factory from gt4py._core import definitions as core_defs -from gt4py.next import common, config, metrics +from gt4py.next import common, config +from gt4py.next.instrumentation import metrics from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import interface diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 7bd796785d..27363ea2a8 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -15,8 +15,9 @@ import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators from gt4py._core import filecache -from gt4py.next import backend, common, config, field_utils, metrics +from gt4py.next import backend, common, config, field_utils from gt4py.next.embedded import nd_array_field +from gt4py.next.instrumentation import metrics from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index f32867be7d..ba620bce35 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -13,8 +13,9 @@ import pytest from gt4py import next as gtx +from gt4py.next.instrumentation import metrics from gt4py.next.iterator import ir as itir -from gt4py.next import common as gtx_common, metrics +from gt4py.next import common as gtx_common from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( diff --git a/tests/next_tests/unit_tests/test_metrics.py b/tests/next_tests/unit_tests/test_metrics.py index 384435a091..c75c97b00f 100644 --- a/tests/next_tests/unit_tests/test_metrics.py +++ b/tests/next_tests/unit_tests/test_metrics.py @@ -14,7 +14,7 @@ import numpy as np import pytest -from gt4py.next import metrics +from gt4py.next.instrumentation import metrics from gt4py.next.otf import arguments From df46627f9b371023362d8533d90911400405a2b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Mon, 12 Jan 2026 11:37:56 +0100 Subject: [PATCH 02/23] Add machinery and tests and several refactors --- src/gt4py/next/ffront/decorator.py | 69 +++-- .../next/instrumentation/_hook_machinery.py | 184 ++++++++++++ src/gt4py/next/instrumentation/hooks.py | 15 + src/gt4py/next/instrumentation/metrics.py | 153 +++++----- src/gt4py/next/otf/compiled_program.py | 43 +-- .../ffront_tests/test_decorator.py | 2 +- .../instrumentation_tests/__init__.py | 8 + .../test_hook_machinery.py | 267 ++++++++++++++++++ .../test_metrics.py | 0 9 files changed, 624 insertions(+), 117 deletions(-) create mode 100644 src/gt4py/next/instrumentation/_hook_machinery.py create mode 100644 src/gt4py/next/instrumentation/hooks.py create mode 100644 tests/next_tests/unit_tests/instrumentation_tests/__init__.py create mode 100644 tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py rename tests/next_tests/unit_tests/{ => instrumentation_tests}/test_metrics.py (100%) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index b04cda7049..e43b5eeabe 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -12,6 +12,7 @@ from __future__ import annotations +import contextlib import dataclasses import functools import time @@ -45,7 +46,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.gtcallable import GTCallable -from gt4py.next.instrumentation import metrics +from gt4py.next.instrumentation import metrics, _hook_machinery from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, compiled_program, stages, toolchain from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -54,6 +55,29 @@ DEFAULT_BACKEND: next_backend.Backend | None = None +@_hook_machinery.ContextHook +def program_call_hook( + program: Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + enable_jit: bool, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + """Hook called at the beginning and end of a program call.""" + ... + + +@_hook_machinery.ContextHook +def embedded_program_call_hook( + program: Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + """Hook called at the beginning and end of an embedded program call.""" + ... + + # TODO(tehrengruber): Decide if and how programs can call other programs. As a # result Program could become a GTCallable. @dataclasses.dataclass(frozen=True) @@ -275,27 +299,30 @@ def __call__( kwarg_types={k: type_translation.from_value(v) for k, v in kwargs.items()}, ) - if self.backend is not None: - self._compiled_programs( - *args, **kwargs, offset_provider=offset_provider, enable_jit=enable_jit - ) - else: - # Embedded execution. - # Metrics source key needs to be setup here, since embedded programs - # don't have variants and thus there's no other place we could do this. - if config.COLLECT_METRICS_LEVEL: - assert metrics_source is not None - metrics_source.key = ( - f"{self.__name__}<{getattr(self.backend, 'name', '')}>" + with program_call_hook(self, args, offset_provider, enable_jit, kwargs): + if self.backend is not None: + self._compiled_programs( + *args, **kwargs, offset_provider=offset_provider, enable_jit=enable_jit ) - warnings.warn( - UserWarning( - f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend." - ), - stacklevel=2, - ) - with next_embedded.context.update(offset_provider=offset_provider): - self.definition_stage.definition(*args, **kwargs) + else: + # Embedded execution. + warnings.warn( + UserWarning( + f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend." + ), + stacklevel=2, + ) + + # Metrics source key needs to be setup here, since embedded programs + # don't have variants so there's no other place to do it. + if config.COLLECT_METRICS_LEVEL: + assert metrics_source is not None + metrics.set_current_source_key( + f"{self.__name__}<{getattr(self.backend, 'name', '')}>" + ) + + with next_embedded.context.update(offset_provider=offset_provider): + self.definition_stage.definition(*args, **kwargs) if collect_info_metrics: assert metrics_source is not None diff --git a/src/gt4py/next/instrumentation/_hook_machinery.py b/src/gt4py/next/instrumentation/_hook_machinery.py new file mode 100644 index 0000000000..b897d6aae2 --- /dev/null +++ b/src/gt4py/next/instrumentation/_hook_machinery.py @@ -0,0 +1,184 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import ast +import collections.abc +import contextlib +import dataclasses +import inspect +import textwrap +import types +import typing +import warnings +from collections.abc import Callable +from typing import Generic, ParamSpec, TypeVar + + +P = ParamSpec("P") +T = TypeVar("T") + + +def _get_unique_name(func: Callable) -> str: + """Generate a unique name for a callable object.""" + return ( + f"{func.__module__}.{getattr(func, '__qualname__', func.__class__.__qualname__)}#{id(func)}" + ) + + +def _is_empty_function(func: Callable) -> bool: + """Check if a callable object is empty (i.e., contains no statements).""" + try: + callable_src = ( + inspect.getsource(func) + if isinstance(func, types.FunctionType) + else inspect.getsource(func.__call__) + ) + callable_ast = ast.parse(textwrap.dedent(callable_src)) + return all( + isinstance(st, ast.Pass) + or (isinstance(st, ast.Expr) and isinstance(st.value, ast.Constant)) + for st in callable_ast.body[0].body + ) + except Exception: + return False + + +@dataclasses.dataclass(slots=True) +class _BaseHook(Generic[T, P]): + """Base class to define callback registration functionality for all hook types.""" + + definition: Callable[P, T] + registry: dict[str, Callable[P, T]] = dataclasses.field(default_factory=dict, kw_only=True) + callbacks: tuple[Callable[P, T], ...] = dataclasses.field(default=(), init=False) + + def __post_init__(self) -> None: + # As an optimization to avoid an empty function call if no callbacks are + # registered, we only add the original definitions to the list of callables + # if it contains a non-empty definition. + if not _is_empty_function(self.definition): + self.callbacks = (self.definition,) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + raise NotImplementedError("This method should be implemented by subclasses.") + + def register( + self, callback: Callable[P, T], *, name: str | None = None, index: int | None = None + ) -> None: + """ + Register a callback to the hook. + + Args: + callback: The callable to register. + name: An optional name for the callback. If not provided, a unique name will be generated. + index: An optional index at which to insert the callback (not counting the original + definition). If not provided, the callback will be appended to the end of the list. + """ + + callable_signature = inspect.signature(callback) + hook_signature = inspect.signature(self.definition) + + signature_mismatch = len(callable_signature.parameters) != len( + hook_signature.parameters + ) or any( + # Remove the annotation before comparison to avoid false mismatches + actual_param.replace(annotation="") != expected_param.replace(annotation="") + for actual_param, expected_param in zip( + callable_signature.parameters.values(), hook_signature.parameters.values() + ) + ) + if signature_mismatch: + raise ValueError( + f"Callback signature {callable_signature} does not match hook signature {hook_signature}" + ) + try: + callable_typing = typing.get_type_hints(callback) + hook_typing = typing.get_type_hints(self.definition) + if not all( + callable_typing[arg_key] == arg_typing + for arg_key, arg_typing in hook_typing.items() + ): + warnings.warn( + f"Callback annotations {callable_typing} does not match expected hook annotations {hook_typing}", + stacklevel=2, + ) + except Exception: + pass + name = name or _get_unique_name(callback) + + if index is None: + self.callbacks += (callback,) + else: + if self.callbacks and self.callbacks[0] is self.definition: + index += 1 # The original definition should always go first + self.callbacks = (*self.callbacks[:index], callback, *self.callbacks[index:]) + + self.registry[name] = callback + + def remove(self, callback: str | Callable[P, T]) -> None: + """ + Remove a registered callback from the hook. + + Args: + callback: The callable object to remove or its registered name. + """ + if isinstance(callback, str): + name = callback + if name not in self.registry: + raise KeyError(f"No callback registered under the name '{name}'") + else: + name = _get_unique_name(callback) + if name not in self.registry: + raise KeyError(f"Callback object {callback} not found in registry") + + callback = self.registry.pop(name) + assert callback in self.callbacks + self.callbacks = tuple(cb for cb in self.callbacks if cb is not callback) + + +@dataclasses.dataclass(slots=True) +class EventHook(_BaseHook[None, P]): + """Event hook specification.""" + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: + for func in self.callbacks: + func(*args, **kwargs) + + +@dataclasses.dataclass(slots=True) +class ContextHook( + contextlib.AbstractContextManager, _BaseHook[contextlib.AbstractContextManager, P] +): + """ + Context hook specification. + + This hook type is used to define context managers that can be stacked together. + """ + + ctx_managers: collections.abc.Sequence[contextlib.AbstractContextManager] = dataclasses.field( + default=(), init=False + ) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> contextlib.AbstractContextManager: + self.ctx_managers = [func(*args, **kwargs) for func in self.callbacks] + return self + + def __enter__(self) -> None: + for ctx_manager in self.ctx_managers: + ctx_manager.__enter__() + + def __exit__( + self, + type_: type[BaseException] | None, + value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + for ctx_manager in reversed(self.ctx_managers): + ctx_manager.__exit__(type_, value, traceback) + self.ctx_managers = () diff --git a/src/gt4py/next/instrumentation/hooks.py b/src/gt4py/next/instrumentation/hooks.py new file mode 100644 index 0000000000..32e027615a --- /dev/null +++ b/src/gt4py/next/instrumentation/hooks.py @@ -0,0 +1,15 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from gt4py.next.ffront.decorator import ( + program_call_hook as program_call_hook, + embedded_program_call_hook as embedded_program_call_hook, +) +from gt4py.next.otf.compiled_program import compile_variant_hook as compile_variant_hook diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index b6a4beaeef..a5c871f360 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -116,94 +116,91 @@ class Source: metadata: dict[str, Any] = dataclasses.field(default_factory=dict) metrics: MetricsCollection = dataclasses.field(default_factory=MetricsCollection) + assigned_key: str | None = dataclasses.field(default=None, init=False) + #: Global store for all measurements. sources: collections.defaultdict[str, Source] = collections.defaultdict(Source) +# Context variable storing the active collection context. +_source_cvar: contextvars.ContextVar[Source | None] = contextvars.ContextVar("source", default=None) -class SourceHandler: - """ - A handler to manage addition of metrics sources to the global store. - This object is used to collect metrics for a specific source (e.g., a program) - before a final key is assigned to it. The key is typically set when the program - is first executed or compiled, and it uniquely identifies the source in the - global metrics store. - """ +def get_current_source() -> Source: + """Retrieve the active metrics collection source.""" + metrics_source = _source_cvar.get() + assert metrics_source is not None + return metrics_source - def __init__(self, source: Source | None = None) -> None: - if source is not None: - self.source = source - @property - def key(self) -> str | None: - return self.__dict__.get("_key", None) - - @key.setter - def key(self, value: str) -> None: - # The key can only be set once, and if it matches an existing source - # in the global store, it must be the same object. - if self.key is not None and self.key != value: - raise RuntimeError("Metrics source key is already set.") - - if value not in sources: - sources[value] = self.source - else: - source_in_store = sources[value] - if self.__dict__.setdefault("source", source_in_store) is not source_in_store: - raise RuntimeError("Conflicting metrics source data found in the global store.") +def is_current_source_set() -> bool: + """Check if there is an on-going metrics collection.""" + return _source_cvar.get() is not None - self._key = value - # The following attributes are implemented as `cached_properties` - # for efficiency and to be able to initialize them lazily when needed, - # even if the key is not set. - @functools.cached_property - def source(self) -> Source: - return Source() +def set_current_source_key(key: str) -> Source: + if not is_current_source_set(): + raise RuntimeError("No active metrics collection to assign source to.") - @functools.cached_property - def metrics(self) -> MetricsCollection: - return self.source.metrics + metrics_source = get_current_source() + if key in sources and metrics_source is not sources[key]: + # The key can only be set once, and if it matches an existing entry + # in the global store, then it must be exactly the same source object. + raise RuntimeError("Conflicting metrics source data found in the global store.") - @functools.cached_property - def metadata(self) -> dict[str, Any]: - return self.source.metadata + sources[key] = metrics_source + metrics_source.assigned_key = key + return metrics_source -# Context variable storing the active collection context. -_source_cvar: contextvars.ContextVar[SourceHandler | None] = contextvars.ContextVar( - "source", default=None -) +# class SourceHandler: +# """ +# A handler to manage addition of metrics sources to the global store. +# This object is used to collect metrics for a specific source (e.g., a program) +# before a final key is assigned to it. The key is typically set when the program +# is first executed or compiled, and it uniquely identifies the source in the +# global metrics store. +# """ -def in_collection_mode() -> bool: - """Check if there is an on-going metrics collection.""" - return _source_cvar.get() is not None +# def __init__(self, source: Source | None = None) -> None: +# if source is not None: +# self.source = source +# @property +# def key(self) -> str | None: +# return self.__dict__.get("_key", None) -def get_current_source() -> SourceHandler: - """Retrieve the active metrics collection source.""" - source_handler = _source_cvar.get() - assert source_handler is not None - return source_handler +# @key.setter +# def key(self, value: str) -> None: +# # The key can only be set once, and if it matches an existing source +# # in the global store, it must be the same object. +# if self.key is not None and self.key != value: +# raise RuntimeError("Metrics source key is already set.") +# if value not in sources: +# sources[value] = self.source +# else: +# source_in_store = sources[value] +# if self.__dict__.setdefault("source", source_in_store) is not source_in_store: +# raise RuntimeError("Conflicting metrics source data found in the global store.") -def get_source(key: str, *, assign_current: bool = True) -> Source: - """ - Retrieve a metrics source by its key, optionally associating it to the current context. - """ - if in_collection_mode() and assign_current: - metrics_source_handler = get_current_source() - # Set the key if not already set, which will also add the - # source to the global store. Note that if the key is already set, - # this will only succeed if the same object. - metrics_source_handler.key = key - metrics_source = metrics_source_handler.source - else: - metrics_source = sources[key] +# self._key = value - return metrics_source +# # The following attributes are implemented as `cached_properties` +# # for efficiency and to be able to initialize them lazily when needed, +# # even if the key is not set. +# @functools.cached_property +# def source(self) -> Source: +# return Source() + +# @functools.cached_property +# def metrics(self) -> MetricsCollection: +# return self.source.metrics + +# @functools.cached_property +# def metadata(self) -> dict[str, Any]: +# return self.source.metadata class CollectorContextManager(contextlib.AbstractContextManager): @@ -221,19 +218,19 @@ class CollectorContextManager(contextlib.AbstractContextManager): of renewing the generator inside `contextlib.contextmanager`. """ - __slots__ = ("previous_collector_token", "source_handler") + __slots__ = ("previous_collector_token", "source") - source_handler: SourceHandler | None + source: Source | None previous_collector_token: contextvars.Token | None - def __enter__(self) -> SourceHandler | None: + def __enter__(self) -> Source | None: if config.COLLECT_METRICS_LEVEL > 0: assert _source_cvar.get() is None - self.source_handler = SourceHandler() - self.previous_collector_token = _source_cvar.set(self.source_handler) - return self.source_handler + self.source = new_source = Source() + self.previous_collector_token = _source_cvar.set(new_source) + return new_source else: - self.source_handler = self.previous_collector_token = None + self.source = self.previous_collector_token = None return None def __exit__( @@ -244,13 +241,11 @@ def __exit__( ) -> None: if self.previous_collector_token is not None: _source_cvar.reset(self.previous_collector_token) - if type_ is None: - if self.source_handler is not None and self.source_handler.key is None: - raise RuntimeError("Metrics source key was not set during collection.") + if type_ is None and self.source is not None and self.source.assigned_key is None: + raise RuntimeError("Metrics source key was not set during collection.") -def collect() -> CollectorContextManager: - return CollectorContextManager() +collect = CollectorContextManager def dumps(metric_sources: Mapping[str, Source] | None = None) -> str: diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index e80a217a8a..6b72fa4435 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -19,7 +19,7 @@ from gt4py.eve import extended_typing as xtyping, utils as eve_utils from gt4py.next import backend as gtx_backend, common, config, errors, utils as gtx_utils from gt4py.next.ffront import stages as ffront_stages, type_specifications as ts_ffront -from gt4py.next.instrumentation import metrics +from gt4py.next.instrumentation import metrics, _hook_machinery from gt4py.next.otf import arguments, stages from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -27,6 +27,19 @@ T = TypeVar("T") +ScalarOrTupleOfScalars: TypeAlias = xtyping.MaybeNestedInTuple[core_defs.Scalar] +CompiledProgramsKey: TypeAlias = tuple[tuple[Hashable, ...], int] +ArgumentDescriptors: TypeAlias = dict[ + type[arguments.ArgStaticDescriptor], dict[str, arguments.ArgStaticDescriptor] +] +ArgumentDescriptorContext: TypeAlias = dict[ + str, xtyping.MaybeNestedInTuple[arguments.ArgStaticDescriptor | None] +] +ArgumentDescriptorContexts: TypeAlias = dict[ + type[arguments.ArgStaticDescriptor], + ArgumentDescriptorContext, +] + # TODO(havogt): We would like this to be a ProcessPoolExecutor, which requires (to decide what) to pickle. _async_compilation_pool: concurrent.futures.Executor | None = None @@ -41,18 +54,16 @@ def _init_async_compilation_pool() -> None: _init_async_compilation_pool() -ScalarOrTupleOfScalars: TypeAlias = xtyping.MaybeNestedInTuple[core_defs.Scalar] -CompiledProgramsKey: TypeAlias = tuple[tuple[Hashable, ...], int] -ArgumentDescriptors: TypeAlias = dict[ - type[arguments.ArgStaticDescriptor], dict[str, arguments.ArgStaticDescriptor] -] -ArgumentDescriptorContext: TypeAlias = dict[ - str, xtyping.MaybeNestedInTuple[arguments.ArgStaticDescriptor | None] -] -ArgumentDescriptorContexts: TypeAlias = dict[ - type[arguments.ArgStaticDescriptor], - ArgumentDescriptorContext, -] + +@_hook_machinery.EventHook +def compile_variant_hook( + key: CompiledProgramsKey, + backend: gtx_backend.Backend, + program_definition: ffront_stages.ProgramDefinition, + compile_time_args: arguments.CompileTimeArgs, +) -> None: + """Callback hook invoked before compiling a program variant.""" + ... def wait_for_compilation() -> None: @@ -264,8 +275,7 @@ def __call__( try: program = self.compiled_programs[key] if config.COLLECT_METRICS_LEVEL: - metrics_source = metrics.get_current_source() - metrics_source.key = self._metrics_key_from_pool_key(key) + metrics.set_current_source_key(self._metrics_key_from_pool_key(key)) program(*args, **kwargs, offset_provider=offset_provider) # type: ignore[operator] # the Future case is handled below @@ -421,7 +431,7 @@ def _compile_variant( # If we are collecting metrics, create a new metrics entity for this compiled program if config.COLLECT_METRICS_LEVEL: - metrics_source = metrics.get_source(self._metrics_key_from_pool_key(key)) + metrics_source = metrics.set_current_source_key(self._metrics_key_from_pool_key(key)) metrics_source.metadata |= dict( name=self.definition_stage.definition.__name__, backend=self.backend.name, @@ -443,6 +453,7 @@ def _compile_variant( compile_call = functools.partial( self.backend.compile, self.definition_stage, compile_time_args=compile_time_args ) + compile_variant_hook(key, self.backend, self.definition_stage, compile_time_args) if _async_compilation_pool is None: # synchronous compilation self.compiled_programs[key] = compile_call() diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index ba620bce35..b763647874 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -93,7 +93,7 @@ def testee(a: cases.IField, out: cases.IField): with ( mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics_level), - mock.patch("gt4py.next.metrics.sources", collections.defaultdict(metrics.Source)), + mock.patch("gt4py.next.instrumentation.metrics.sources", collections.defaultdict(metrics.Source)), ): testee = testee.with_backend(cartesian_case.backend).with_grid_type( cartesian_case.grid_type diff --git a/tests/next_tests/unit_tests/instrumentation_tests/__init__.py b/tests/next_tests/unit_tests/instrumentation_tests/__init__.py new file mode 100644 index 0000000000..abf4c3e24c --- /dev/null +++ b/tests/next_tests/unit_tests/instrumentation_tests/__init__.py @@ -0,0 +1,8 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py new file mode 100644 index 0000000000..5834fe266f --- /dev/null +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py @@ -0,0 +1,267 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import contextlib + +import pytest + +from gt4py.next.instrumentation._hook_machinery import ( + EventHook, + ContextHook, + _get_unique_name, + _is_empty_function, +) + + +def test_get_unique_name(): + def func1(): + pass + + def func2(): + pass + + assert _get_unique_name(func1) != _get_unique_name(func2) + + class A: + def __call__(self): ... + + assert _get_unique_name(A) == _get_unique_name(A) + + a1, a2 = A(), A() + + assert (a1_name := _get_unique_name(a1)) != (a2_name := _get_unique_name(a2)) + assert _get_unique_name(a1) == a1_name + assert _get_unique_name(a2) == a2_name + + +def test_empty_function(): + def empty(): + pass + + assert _is_empty_function(empty) is True + + def non_empty(): + x = 1 + + assert _is_empty_function(non_empty) is False + + def with_docstring(): + """This is a docstring.""" + + assert _is_empty_function(with_docstring) is True + + def with_ellipsis(): ... + + assert _is_empty_function(with_ellipsis) is True + + class A: + def __call__(self): ... + + assert _is_empty_function(A()) is True + + +class TestEventHook: + def test_event_hook_call_with_no_callbacks(self): + def hook_def(x: int) -> None: + pass + + hook = EventHook(definition=hook_def) + hook(42) # Should not raise + + def test_event_hook_call_with_callbacks(self): + results = [] + + def hook_def(x: int) -> None: + pass + + def callback1(x: int) -> None: + results.append(x) + + def callback2(x: int) -> None: + results.append(x * 2) + + hook = EventHook(definition=hook_def) + hook.register(callback1) + hook.register(callback2) + hook(5) + + assert results == [5, 10] + + def test_event_hook_register_with_signature_mismatch(self): + def hook_def(x: int) -> None: + pass + + def bad_callback(x: int, y: int) -> None: + pass + + hook = EventHook(definition=hook_def) + with pytest.raises(ValueError, match="Callback signature"): + hook.register(bad_callback) + + def test_event_hook_register_with_annotation_mismatch(self): + def hook_def(x: int) -> None: + pass + + def weird_callback(x: str) -> None: + pass + + hook = EventHook(definition=hook_def) + + with pytest.warns(UserWarning, match="Callback annotations"): + hook.register(weird_callback) + + def test_event_hook_register_with_name(self): + def hook_def(x: int) -> None: + pass + + def callback(x: int) -> None: + pass + + hook = EventHook(definition=hook_def) + hook.register(callback, name="my_callback") + + assert "my_callback" in hook.registry + + def test_event_hook_register_with_index(self): + results = [] + + def hook_def(x: int) -> None: + pass + + def callback1(x: int) -> None: + results.append(1) + + def callback2(x: int) -> None: + results.append(2) + + hook = EventHook(definition=hook_def) + hook.register(callback1) + hook.register(callback2, index=0) + hook(0) + + assert results == [2, 1] + + def test_event_hook_remove_by_name(self): + results = [] + + def hook_def(x: int) -> None: + pass + + def callback(x: int) -> None: + results.append(x) + + hook = EventHook(definition=hook_def) + hook.register(callback, name="test_cb") + hook(42) + assert results == [42] + + hook.remove("test_cb") + results = [] + hook(42) + + assert results == [] + + def test_event_hook_remove_by_callback(self): + results = [] + + def hook_def(x: int) -> None: + pass + + def callback(x: int) -> None: + results.append(x) + + hook = EventHook(definition=hook_def) + hook.register(callback) + hook(42) + assert results == [42] + + hook.remove(callback) + results = [] + hook(42) + + assert results == [] + + def test_event_hook_remove_nonexistent_raises(self): + def hook_def(x: int) -> None: + pass + + hook = EventHook(definition=hook_def) + with pytest.raises(KeyError): + hook.remove("nonexistent") + + +class TestContextHook: + def test_context_hook_basic(self): + enter_called = [] + exit_called = [] + + def hook_def() -> contextlib.AbstractContextManager: + pass + + @contextlib.contextmanager + def callback(): + enter_called.append(True) + yield + exit_called.append(True) + + hook = ContextHook(definition=hook_def) + hook.register(callback) + + with hook(): + assert len(enter_called) == 1 + + assert len(exit_called) == 1 + + def test_context_hook_multiple_callbacks(self): + order = [] + + def hook_def() -> contextlib.AbstractContextManager: + pass + + @contextlib.contextmanager + def callback1(): + order.append("enter1") + yield + order.append("exit1") + + @contextlib.contextmanager + def callback2(): + order.append("enter2") + yield + order.append("exit2") + + hook = ContextHook(definition=hook_def) + hook.register(callback1) + hook.register(callback2) + + with hook(): + pass + + # Entry in order, but exit in reverse + assert order == ["enter1", "enter2", "exit2", "exit1"] + + def test_context_hook_with_arguments(self): + results = [] + + def hook_def(x: int) -> contextlib.AbstractContextManager: + pass + + @contextlib.contextmanager + def callback(x: int): + results.append(x) + yield + + hook = ContextHook(definition=hook_def) + hook.register(callback) + + with hook(42): + pass + + assert results == [42] diff --git a/tests/next_tests/unit_tests/test_metrics.py b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py similarity index 100% rename from tests/next_tests/unit_tests/test_metrics.py rename to tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py From becb7ba5b1d730f6115cca9f3c3e85c191314639 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 14 Jan 2026 18:14:15 +0100 Subject: [PATCH 03/23] Run precommit --- src/gt4py/next/ffront/decorator.py | 6 +-- .../next/instrumentation/_hook_machinery.py | 7 +-- src/gt4py/next/instrumentation/hooks.py | 2 +- src/gt4py/next/instrumentation/metrics.py | 51 ------------------- src/gt4py/next/otf/compiled_program.py | 2 +- .../ffront_tests/test_decorator.py | 4 +- 6 files changed, 12 insertions(+), 60 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index e43b5eeabe..80d06217da 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -46,7 +46,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.gtcallable import GTCallable -from gt4py.next.instrumentation import metrics, _hook_machinery +from gt4py.next.instrumentation import _hook_machinery, metrics from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, compiled_program, stages, toolchain from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -56,7 +56,7 @@ @_hook_machinery.ContextHook -def program_call_hook( +def program_call_hook( # type: ignore[empty-body] program: Program, args: tuple[Any, ...], offset_provider: common.OffsetProvider, @@ -68,7 +68,7 @@ def program_call_hook( @_hook_machinery.ContextHook -def embedded_program_call_hook( +def embedded_program_call_hook( # type: ignore[empty-body] program: Program, args: tuple[Any, ...], offset_provider: common.OffsetProvider, diff --git a/src/gt4py/next/instrumentation/_hook_machinery.py b/src/gt4py/next/instrumentation/_hook_machinery.py index b897d6aae2..292e778232 100644 --- a/src/gt4py/next/instrumentation/_hook_machinery.py +++ b/src/gt4py/next/instrumentation/_hook_machinery.py @@ -35,16 +35,17 @@ def _get_unique_name(func: Callable) -> str: def _is_empty_function(func: Callable) -> bool: """Check if a callable object is empty (i.e., contains no statements).""" try: + assert callable(func) callable_src = ( inspect.getsource(func) if isinstance(func, types.FunctionType) - else inspect.getsource(func.__call__) + else inspect.getsource(func.__call__) # type: ignore[operator] # asserted above ) callable_ast = ast.parse(textwrap.dedent(callable_src)) return all( isinstance(st, ast.Pass) or (isinstance(st, ast.Expr) and isinstance(st.value, ast.Constant)) - for st in callable_ast.body[0].body + for st in typing.cast(ast.FunctionDef, callable_ast.body[0]).body ) except Exception: return False @@ -157,7 +158,7 @@ class ContextHook( ): """ Context hook specification. - + This hook type is used to define context managers that can be stacked together. """ diff --git a/src/gt4py/next/instrumentation/hooks.py b/src/gt4py/next/instrumentation/hooks.py index 32e027615a..62fa26a7bd 100644 --- a/src/gt4py/next/instrumentation/hooks.py +++ b/src/gt4py/next/instrumentation/hooks.py @@ -9,7 +9,7 @@ from __future__ import annotations from gt4py.next.ffront.decorator import ( - program_call_hook as program_call_hook, embedded_program_call_hook as embedded_program_call_hook, + program_call_hook as program_call_hook, ) from gt4py.next.otf.compiled_program import compile_variant_hook as compile_variant_hook diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index a5c871f360..7043d44703 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -12,7 +12,6 @@ import contextlib import contextvars import dataclasses -import functools import itertools import json import numbers @@ -153,56 +152,6 @@ def set_current_source_key(key: str) -> Source: return metrics_source -# class SourceHandler: -# """ -# A handler to manage addition of metrics sources to the global store. - -# This object is used to collect metrics for a specific source (e.g., a program) -# before a final key is assigned to it. The key is typically set when the program -# is first executed or compiled, and it uniquely identifies the source in the -# global metrics store. -# """ - -# def __init__(self, source: Source | None = None) -> None: -# if source is not None: -# self.source = source - -# @property -# def key(self) -> str | None: -# return self.__dict__.get("_key", None) - -# @key.setter -# def key(self, value: str) -> None: -# # The key can only be set once, and if it matches an existing source -# # in the global store, it must be the same object. -# if self.key is not None and self.key != value: -# raise RuntimeError("Metrics source key is already set.") - -# if value not in sources: -# sources[value] = self.source -# else: -# source_in_store = sources[value] -# if self.__dict__.setdefault("source", source_in_store) is not source_in_store: -# raise RuntimeError("Conflicting metrics source data found in the global store.") - -# self._key = value - -# # The following attributes are implemented as `cached_properties` -# # for efficiency and to be able to initialize them lazily when needed, -# # even if the key is not set. -# @functools.cached_property -# def source(self) -> Source: -# return Source() - -# @functools.cached_property -# def metrics(self) -> MetricsCollection: -# return self.source.metrics - -# @functools.cached_property -# def metadata(self) -> dict[str, Any]: -# return self.source.metadata - - class CollectorContextManager(contextlib.AbstractContextManager): """ A context manager to handle metrics collection. diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 6b72fa4435..bbc0f17fb1 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -19,7 +19,7 @@ from gt4py.eve import extended_typing as xtyping, utils as eve_utils from gt4py.next import backend as gtx_backend, common, config, errors, utils as gtx_utils from gt4py.next.ffront import stages as ffront_stages, type_specifications as ts_ffront -from gt4py.next.instrumentation import metrics, _hook_machinery +from gt4py.next.instrumentation import _hook_machinery, metrics from gt4py.next.otf import arguments, stages from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index b763647874..a6b6d69a5e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -93,7 +93,9 @@ def testee(a: cases.IField, out: cases.IField): with ( mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics_level), - mock.patch("gt4py.next.instrumentation.metrics.sources", collections.defaultdict(metrics.Source)), + mock.patch( + "gt4py.next.instrumentation.metrics.sources", collections.defaultdict(metrics.Source) + ), ): testee = testee.with_backend(cartesian_case.backend).with_grid_type( cartesian_case.grid_type From 68d22831da64e261fae74ee15b35d95a18fca06e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Thu, 15 Jan 2026 13:50:21 +0100 Subject: [PATCH 04/23] Remove boilerplate from unit tests --- .../test_hook_machinery.py | 49 +++++++++---------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py index 5834fe266f..ec323e54f3 100644 --- a/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py @@ -69,16 +69,17 @@ def __call__(self): ... class TestEventHook: def test_event_hook_call_with_no_callbacks(self): - def hook_def(x: int) -> None: + @EventHook + def hook(x: int) -> None: pass - hook = EventHook(definition=hook_def) hook(42) # Should not raise def test_event_hook_call_with_callbacks(self): results = [] - def hook_def(x: int) -> None: + @EventHook + def hook(x: int) -> None: pass def callback1(x: int) -> None: @@ -87,7 +88,6 @@ def callback1(x: int) -> None: def callback2(x: int) -> None: results.append(x * 2) - hook = EventHook(definition=hook_def) hook.register(callback1) hook.register(callback2) hook(5) @@ -95,36 +95,35 @@ def callback2(x: int) -> None: assert results == [5, 10] def test_event_hook_register_with_signature_mismatch(self): - def hook_def(x: int) -> None: + @EventHook + def hook(x: int) -> None: pass def bad_callback(x: int, y: int) -> None: pass - hook = EventHook(definition=hook_def) with pytest.raises(ValueError, match="Callback signature"): hook.register(bad_callback) def test_event_hook_register_with_annotation_mismatch(self): - def hook_def(x: int) -> None: + @EventHook + def hook(x: int) -> None: pass def weird_callback(x: str) -> None: pass - hook = EventHook(definition=hook_def) - with pytest.warns(UserWarning, match="Callback annotations"): hook.register(weird_callback) def test_event_hook_register_with_name(self): - def hook_def(x: int) -> None: + @EventHook + def hook(x: int) -> None: pass def callback(x: int) -> None: pass - hook = EventHook(definition=hook_def) hook.register(callback, name="my_callback") assert "my_callback" in hook.registry @@ -132,7 +131,8 @@ def callback(x: int) -> None: def test_event_hook_register_with_index(self): results = [] - def hook_def(x: int) -> None: + @EventHook + def hook(x: int) -> None: pass def callback1(x: int) -> None: @@ -141,7 +141,6 @@ def callback1(x: int) -> None: def callback2(x: int) -> None: results.append(2) - hook = EventHook(definition=hook_def) hook.register(callback1) hook.register(callback2, index=0) hook(0) @@ -151,13 +150,13 @@ def callback2(x: int) -> None: def test_event_hook_remove_by_name(self): results = [] - def hook_def(x: int) -> None: + @EventHook + def hook(x: int) -> None: pass def callback(x: int) -> None: results.append(x) - hook = EventHook(definition=hook_def) hook.register(callback, name="test_cb") hook(42) assert results == [42] @@ -171,13 +170,13 @@ def callback(x: int) -> None: def test_event_hook_remove_by_callback(self): results = [] - def hook_def(x: int) -> None: + @EventHook + def hook(x: int) -> None: pass def callback(x: int) -> None: results.append(x) - hook = EventHook(definition=hook_def) hook.register(callback) hook(42) assert results == [42] @@ -189,10 +188,10 @@ def callback(x: int) -> None: assert results == [] def test_event_hook_remove_nonexistent_raises(self): - def hook_def(x: int) -> None: + @EventHook + def hook(x: int) -> None: pass - hook = EventHook(definition=hook_def) with pytest.raises(KeyError): hook.remove("nonexistent") @@ -202,7 +201,8 @@ def test_context_hook_basic(self): enter_called = [] exit_called = [] - def hook_def() -> contextlib.AbstractContextManager: + @ContextHook + def hook() -> contextlib.AbstractContextManager: pass @contextlib.contextmanager @@ -211,7 +211,6 @@ def callback(): yield exit_called.append(True) - hook = ContextHook(definition=hook_def) hook.register(callback) with hook(): @@ -222,7 +221,8 @@ def callback(): def test_context_hook_multiple_callbacks(self): order = [] - def hook_def() -> contextlib.AbstractContextManager: + @ContextHook + def hook() -> contextlib.AbstractContextManager: pass @contextlib.contextmanager @@ -237,7 +237,6 @@ def callback2(): yield order.append("exit2") - hook = ContextHook(definition=hook_def) hook.register(callback1) hook.register(callback2) @@ -250,7 +249,8 @@ def callback2(): def test_context_hook_with_arguments(self): results = [] - def hook_def(x: int) -> contextlib.AbstractContextManager: + @ContextHook + def hook(x: int) -> contextlib.AbstractContextManager: pass @contextlib.contextmanager @@ -258,7 +258,6 @@ def callback(x: int): results.append(x) yield - hook = ContextHook(definition=hook_def) hook.register(callback) with hook(42): From a1963e9bb40beeea2dfeee54602dc203e33983fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Thu, 15 Jan 2026 17:37:17 +0100 Subject: [PATCH 05/23] Add integration tests --- src/gt4py/next/ffront/decorator.py | 3 +- .../instrumentation_tests/__init__.py | 8 + .../instrumentation_tests/test_hooks.py | 166 ++++++++++++++++++ 3 files changed, 176 insertions(+), 1 deletion(-) create mode 100644 tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py create mode 100644 tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 80d06217da..a3fc33ddcc 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -322,7 +322,8 @@ def __call__( ) with next_embedded.context.update(offset_provider=offset_provider): - self.definition_stage.definition(*args, **kwargs) + with embedded_program_call_hook(self, args, offset_provider, kwargs): + self.definition_stage.definition(*args, **kwargs) if collect_info_metrics: assert metrics_source is not None diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py new file mode 100644 index 0000000000..abf4c3e24c --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py @@ -0,0 +1,8 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py new file mode 100644 index 0000000000..f2771cd844 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py @@ -0,0 +1,166 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import contextlib +from collections.abc import Hashable +from typing import Any, Final + +import pytest + +from gt4py.next import Dims, gtfn_cpu, broadcast, typing as gtx_typing +import gt4py.next as gtx +from gt4py.next.instrumentation import hooks +from gt4py.next import common, backend as gtx_backend +from gt4py.next.typing import Program + +try: + from gt4py.next.program_processors.runners import dace as dace_backends + + BACKENDS = [None, gtfn_cpu, dace_backends.run_dace_cpu_cached] +except ImportError: + BACKENDS = [None, gtfn_cpu] + + +callback_results = [] +embedded_callback_results = [] + + +@contextlib.contextmanager +def custom_program_callback( + program: Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + enable_jit: bool, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + callback_results.append(("enter", None)) + + yield + + callback_results.append(( + "custom_program_callback", + { + "program": program.__name__, + "args": args, + "offset_provider": offset_provider.keys(), + "enable_jit": enable_jit, + "kwargs": kwargs.keys(), + }, + )) + + +@contextlib.contextmanager +def custom_embedded_program_callback( + program: Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + embedded_callback_results.append(("enter", None)) + + yield + + embedded_callback_results.append(( + "custom_embedded_program_callback", + { + "program": program.__name__, + "args": args, + "offset_provider": offset_provider.keys(), + "kwargs": kwargs.keys(), + }, + )) + + +# @hooks.compile_variant_hook.register +# def compile_variant_hook( +# key: tuple[tuple[Hashable, ...], int], +# backend: gtx_backend.Backend, +# program_definition: "ffront_stages.ProgramDefinition", +# compile_time_args: "arguments.CompileTimeArgs", +# ) -> None: +# """Callback hook invoked before compiling a program variant.""" +# ... + + +Cell = gtx.Dimension("Cell") +IDim = gtx.Dimension("IDim") + + +@gtx.field_operator +def identity_fop( + in_field: gtx.Field[Dims[IDim], gtx.float64], +) -> gtx.Field[Dims[IDim], gtx.float64]: + return in_field + + +@gtx.program +def copy_program( + in_field: gtx.Field[Dims[IDim], gtx.float64], out: gtx.Field[Dims[IDim], gtx.float64] +): + identity_fop(in_field, out=out) + + +@pytest.mark.parametrize("backend", BACKENDS, ids=lambda b: getattr(b, "name", str(b))) +def test_program_call_hooks(backend: gtx_typing.Backend): + size = 10 + in_field = gtx.full([(IDim, size)], 1, dtype=gtx.float64) + out_field = gtx.empty([(IDim, size)], dtype=gtx.float64) + + test_program = copy_program.with_backend(backend) + + # Run the program without hooks + callback_results.clear() + embedded_callback_results.clear() + test_program(in_field, out=out_field) + + # Callbacks should not have been called + assert callback_results == [] + callback_results.clear() + assert embedded_callback_results == [] + embedded_callback_results.clear() + + # Add hooks and run the program again + hooks.program_call_hook.register(custom_program_callback) + hooks.embedded_program_call_hook.register(custom_embedded_program_callback) + test_program(in_field, out=out_field) + + # Check that the callbacks were called + assert len(callback_results) == 2 + assert callback_results[0] == ("enter", None) + + hook_name, hook_call_info = callback_results[1] + assert hook_name == "custom_program_callback" + assert hook_call_info["program"] == test_program.__name__ + + # The embedded program call hook should have also been called + # with the embedded backend + if backend is None: + assert len(embedded_callback_results) == 2 + assert embedded_callback_results[0] == ("enter", None) + + hook_name, hook_call_info = embedded_callback_results[1] + assert hook_name == "custom_embedded_program_callback" + assert hook_call_info["program"] == copy_program.__name__ + else: + assert len(embedded_callback_results) == 0 + + callback_results.clear() + embedded_callback_results.clear() + + # Remove hooks and call the program again + hooks.program_call_hook.remove(custom_program_callback) + hooks.embedded_program_call_hook.remove(custom_embedded_program_callback) + test_program(in_field, out=out_field) + + # Callbacks should not have been called + assert callback_results == [] + callback_results.clear() + assert embedded_callback_results == [] + embedded_callback_results.clear() From 9c5729a5143a5b72732ab85ba1d9b00988585972 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Thu, 15 Jan 2026 17:39:53 +0100 Subject: [PATCH 06/23] Run pre-commit --- .../instrumentation_tests/test_hooks.py | 42 ++++++++++--------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py index f2771cd844..ad371469b0 100644 --- a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py +++ b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py @@ -44,16 +44,18 @@ def custom_program_callback( yield - callback_results.append(( - "custom_program_callback", - { - "program": program.__name__, - "args": args, - "offset_provider": offset_provider.keys(), - "enable_jit": enable_jit, - "kwargs": kwargs.keys(), - }, - )) + callback_results.append( + ( + "custom_program_callback", + { + "program": program.__name__, + "args": args, + "offset_provider": offset_provider.keys(), + "enable_jit": enable_jit, + "kwargs": kwargs.keys(), + }, + ) + ) @contextlib.contextmanager @@ -67,15 +69,17 @@ def custom_embedded_program_callback( yield - embedded_callback_results.append(( - "custom_embedded_program_callback", - { - "program": program.__name__, - "args": args, - "offset_provider": offset_provider.keys(), - "kwargs": kwargs.keys(), - }, - )) + embedded_callback_results.append( + ( + "custom_embedded_program_callback", + { + "program": program.__name__, + "args": args, + "offset_provider": offset_provider.keys(), + "kwargs": kwargs.keys(), + }, + ) + ) # @hooks.compile_variant_hook.register From 86f4e01cfc051beeb12132438fdde4d23186bad0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Thu, 15 Jan 2026 20:46:11 +0100 Subject: [PATCH 07/23] Address copilot review comments (mostly cosmetic) --- src/gt4py/next/instrumentation/_hook_machinery.py | 3 +++ .../feature_tests/instrumentation_tests/test_hooks.py | 11 ++++------- .../instrumentation_tests/test_hook_machinery.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/instrumentation/_hook_machinery.py b/src/gt4py/next/instrumentation/_hook_machinery.py index 292e778232..e1eb394a26 100644 --- a/src/gt4py/next/instrumentation/_hook_machinery.py +++ b/src/gt4py/next/instrumentation/_hook_machinery.py @@ -110,7 +110,10 @@ def register( stacklevel=2, ) except Exception: + # Ignore issues while checking type hints (e.g., forward references + # or missing imports); failure here should not prevent hook registration. pass + name = name or _get_unique_name(callback) if index is None: diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py index ad371469b0..e5a25b3cd3 100644 --- a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py +++ b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py @@ -9,16 +9,13 @@ from __future__ import annotations import contextlib -from collections.abc import Hashable -from typing import Any, Final +from typing import Any import pytest -from gt4py.next import Dims, gtfn_cpu, broadcast, typing as gtx_typing import gt4py.next as gtx +from gt4py.next import common, Dims, gtfn_cpu, typing as gtx_typing from gt4py.next.instrumentation import hooks -from gt4py.next import common, backend as gtx_backend -from gt4py.next.typing import Program try: from gt4py.next.program_processors.runners import dace as dace_backends @@ -34,7 +31,7 @@ @contextlib.contextmanager def custom_program_callback( - program: Program, + program: gtx_typing.Program, args: tuple[Any, ...], offset_provider: common.OffsetProvider, enable_jit: bool, @@ -60,7 +57,7 @@ def custom_program_callback( @contextlib.contextmanager def custom_embedded_program_callback( - program: Program, + program: gtx_typing.Program, args: tuple[Any, ...], offset_provider: common.OffsetProvider, kwargs: dict[str, Any], diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py index ec323e54f3..e8711e3913 100644 --- a/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py @@ -48,7 +48,7 @@ def empty(): assert _is_empty_function(empty) is True def non_empty(): - x = 1 + return 1 assert _is_empty_function(non_empty) is False From c38939e5fd291fae00e14064c1a0584c22ea6e91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Mon, 19 Jan 2026 16:14:02 +0100 Subject: [PATCH 08/23] WIP metrics refactoring --- src/gt4py/next/ffront/decorator.py | 70 +++++---- .../next/instrumentation/_hook_machinery.py | 6 + src/gt4py/next/instrumentation/hooks.py | 2 +- src/gt4py/next/instrumentation/metrics.py | 148 +++++++++++++----- .../ffront_tests/test_decorator.py | 3 +- .../instrumentation_tests/test_hooks.py | 4 +- 6 files changed, 153 insertions(+), 80 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index a3fc33ddcc..dd28cb3057 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -20,7 +20,7 @@ import typing import warnings from collections.abc import Callable, Sequence -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Final, Generic, Optional, TypeVar from gt4py import eve from gt4py._core import definitions as core_defs @@ -55,8 +55,18 @@ DEFAULT_BACKEND: next_backend.Backend | None = None +class ProgramCallMetricsCollector(metrics.AbstractCollectorContextManager[float]): + metric_name: Final[str] = metrics.TOTAL_METRIC + + def enter_collection_mode(self) -> float: + return time.perf_counter() + + def exit_collection_mode(self, enter_state: float) -> float: + return time.perf_counter() - enter_state + + @_hook_machinery.ContextHook -def program_call_hook( # type: ignore[empty-body] +def program_call_context( program: Program, args: tuple[Any, ...], offset_provider: common.OffsetProvider, @@ -64,7 +74,7 @@ def program_call_hook( # type: ignore[empty-body] kwargs: dict[str, Any], ) -> contextlib.AbstractContextManager: """Hook called at the beginning and end of a program call.""" - ... + return ProgramCallMetricsCollector() @_hook_machinery.ContextHook @@ -287,10 +297,7 @@ def __call__( self.enable_jit if self.enable_jit is not None else config.ENABLE_JIT_DEFAULT ) - with metrics.collect() as metrics_source: - if collect_info_metrics := (config.COLLECT_METRICS_LEVEL >= metrics.INFO): - start = time.perf_counter() - + with program_call_context(self, args, offset_provider, enable_jit, kwargs): if __debug__: # TODO: remove or make dependency on self.past_stage optional past_process_args._validate_args( @@ -299,35 +306,30 @@ def __call__( kwarg_types={k: type_translation.from_value(v) for k, v in kwargs.items()}, ) - with program_call_hook(self, args, offset_provider, enable_jit, kwargs): - if self.backend is not None: - self._compiled_programs( - *args, **kwargs, offset_provider=offset_provider, enable_jit=enable_jit - ) - else: - # Embedded execution. - warnings.warn( - UserWarning( - f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend." - ), - stacklevel=2, + if self.backend is not None: + self._compiled_programs( + *args, **kwargs, offset_provider=offset_provider, enable_jit=enable_jit + ) + else: + # Embedded execution. + warnings.warn( + UserWarning( + f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend." + ), + stacklevel=2, + ) + + # Metrics source key needs to be setup here, since embedded programs + # don't have variants so there's no other place to do it. + if config.COLLECT_METRICS_LEVEL: + #assert metrics_source is not None + metrics.set_current_source_key( + f"{self.__name__}<{getattr(self.backend, 'name', '')}>" ) - # Metrics source key needs to be setup here, since embedded programs - # don't have variants so there's no other place to do it. - if config.COLLECT_METRICS_LEVEL: - assert metrics_source is not None - metrics.set_current_source_key( - f"{self.__name__}<{getattr(self.backend, 'name', '')}>" - ) - - with next_embedded.context.update(offset_provider=offset_provider): - with embedded_program_call_hook(self, args, offset_provider, kwargs): - self.definition_stage.definition(*args, **kwargs) - - if collect_info_metrics: - assert metrics_source is not None - metrics_source.metrics[metrics.TOTAL_METRIC].add_sample(time.perf_counter() - start) + with next_embedded.context.update(offset_provider=offset_provider): + with embedded_program_call_hook(self, args, offset_provider, kwargs): + self.definition_stage.definition(*args, **kwargs) def compile( self, diff --git a/src/gt4py/next/instrumentation/_hook_machinery.py b/src/gt4py/next/instrumentation/_hook_machinery.py index e1eb394a26..0c23145212 100644 --- a/src/gt4py/next/instrumentation/_hook_machinery.py +++ b/src/gt4py/next/instrumentation/_hook_machinery.py @@ -59,6 +59,12 @@ class _BaseHook(Generic[T, P]): registry: dict[str, Callable[P, T]] = dataclasses.field(default_factory=dict, kw_only=True) callbacks: tuple[Callable[P, T], ...] = dataclasses.field(default=(), init=False) + if not typing.TYPE_CHECKING: + + @property + def __doc__(self) -> str | None: + return self.definition.__doc__ + def __post_init__(self) -> None: # As an optimization to avoid an empty function call if no callbacks are # registered, we only add the original definitions to the list of callables diff --git a/src/gt4py/next/instrumentation/hooks.py b/src/gt4py/next/instrumentation/hooks.py index 62fa26a7bd..b688e49ffd 100644 --- a/src/gt4py/next/instrumentation/hooks.py +++ b/src/gt4py/next/instrumentation/hooks.py @@ -10,6 +10,6 @@ from gt4py.next.ffront.decorator import ( embedded_program_call_hook as embedded_program_call_hook, - program_call_hook as program_call_hook, + program_call_context as program_call_context, ) from gt4py.next.otf.compiled_program import compile_variant_hook as compile_variant_hook diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 7043d44703..43a05ca440 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -8,10 +8,12 @@ from __future__ import annotations +import abc import collections import contextlib import contextvars import dataclasses +import functools import itertools import json import numbers @@ -19,7 +21,8 @@ import sys import types import typing -from collections.abc import Mapping +from collections.abc import Callable, Mapping +from typing import Generic, TypeVar, Protocol import numpy as np @@ -44,6 +47,16 @@ ALL: Final[int] = 100 +def is_level_enabled(level: int = MINIMAL) -> bool: + """Check if a given metrics collection level is enabled.""" + return config.COLLECT_METRICS_LEVEL >= level + + +def get_current_level() -> int: + """Retrieve the current metrics collection level from the configuration.""" + return config.COLLECT_METRICS_LEVEL + + @dataclasses.dataclass(frozen=True) class Metric: """ @@ -108,51 +121,90 @@ def value_factory(self, key: str) -> Metric: return Metric(name=key) -@dataclasses.dataclass +@dataclasses.dataclass(slots=True) class Source: """A source of metrics, typically associated with a program.""" metadata: dict[str, Any] = dataclasses.field(default_factory=dict) metrics: MetricsCollection = dataclasses.field(default_factory=MetricsCollection) - assigned_key: str | None = dataclasses.field(default=None, init=False) #: Global store for all measurements. sources: collections.defaultdict[str, Source] = collections.defaultdict(Source) -# Context variable storing the active collection context. -_source_cvar: contextvars.ContextVar[Source | None] = contextvars.ContextVar("source", default=None) +# Context variables storing the active source keys. +_source_key_cvar: contextvars.ContextVar[str] = contextvars.ContextVar("source_key") + + +def is_current_source_key_set() -> bool: + """Check if there is an on-going metrics collection.""" + return _source_key_cvar.get(None) is not None + + +def get_current_source_key() -> str: + """Retrieve the current source key for metrics collection.""" + return _source_key_cvar.get() + + +def set_current_source_key(key: str) -> Source: + """Set the current source key for metrics collection.""" + assert _source_key_cvar.get(None) is None, "A source key is already set." + _source_key_cvar.set(key) + return sources[key] def get_current_source() -> Source: """Retrieve the active metrics collection source.""" - metrics_source = _source_cvar.get() - assert metrics_source is not None - return metrics_source + return sources[_source_key_cvar.get()] -def is_current_source_set() -> bool: - """Check if there is an on-going metrics collection.""" - return _source_cvar.get() is not None +def add_sample_to_current_source(metric_name: str, sample: float) -> None: + """Add a sample to a metric in the current source.""" + return get_current_source().metrics[metric_name].add_sample(sample) -def set_current_source_key(key: str) -> Source: - if not is_current_source_set(): - raise RuntimeError("No active metrics collection to assign source to.") +@dataclasses.dataclass(slots=True) +class SourceKeyContextManager(contextlib.AbstractContextManager): + """ + A context manager to handle metrics collection sources. - metrics_source = get_current_source() - if key in sources and metrics_source is not sources[key]: - # The key can only be set once, and if it matches an existing entry - # in the global store, then it must be exactly the same source object. - raise RuntimeError("Conflicting metrics source data found in the global store.") + When entering this context manager, it sets up a new source key for collection + of metrics in a module contextvar. Upon exiting the context, it resets the + contextvar to its previous state. - sources[key] = metrics_source - metrics_source.assigned_key = key - return metrics_source + Note: + This is implemented as a context manager class instead of a generator + function with `@contextlib.contextmanager` to avoid the extra overhead + of renewing the generator inside `contextlib.contextmanager`. + """ + key: str | None = None + previous_cvar_token: contextvars.Token | None = dataclasses.field(default=None, init=False) -class CollectorContextManager(contextlib.AbstractContextManager): + def __enter__(self) -> None: + if is_level_enabled() and self.key is not None: + self.previous_cvar_token = _source_key_cvar.set(self.key) + else: + self.previous_cvar_token = None + + def __exit__( + self, + exc_type_: type[BaseException] | None, + value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + if self.previous_cvar_token is not None: + _source_key_cvar.reset(self.previous_cvar_token) + + +# collection_source = SourceContextManager + + +StateT = TypeVar("StateT") + + +class AbstractCollectorContextManager(contextlib.AbstractContextManager, Generic[StateT]): """ A context manager to handle metrics collection. @@ -167,34 +219,46 @@ class CollectorContextManager(contextlib.AbstractContextManager): of renewing the generator inside `contextlib.contextmanager`. """ - __slots__ = ("previous_collector_token", "source") + __slots__ = ("key", "previous_cvar_token", "enter_state") + + key: str | None + previous_cvar_token: contextvars.Token | None + enter_state: StateT | None + + @property + @abc.abstractmethod + def level(self) -> int: ... + + @property + @abc.abstractmethod + def metric_name(self) -> str: ... + + @abc.abstractmethod + def enter_collection_mode(self) -> StateT: ... - source: Source | None - previous_collector_token: contextvars.Token | None + @abc.abstractmethod + def exit_collection_mode(self, enter_state: StateT) -> float: ... - def __enter__(self) -> Source | None: - if config.COLLECT_METRICS_LEVEL > 0: - assert _source_cvar.get() is None - self.source = new_source = Source() - self.previous_collector_token = _source_cvar.set(new_source) - return new_source + def __enter__(self) -> None: + if is_level_enabled(self.level) and self.key is not None: + self.previous_cvar_token = _source_key_cvar.set(self.key) + self.enter_state = self.enter_collection_mode() else: - self.source = self.previous_collector_token = None - return None + self.previous_cvar_token = None def __exit__( self, - type_: type[BaseException] | None, + exc_type_: type[BaseException] | None, value: BaseException | None, traceback: types.TracebackType | None, ) -> None: - if self.previous_collector_token is not None: - _source_cvar.reset(self.previous_collector_token) - if type_ is None and self.source is not None and self.source.assigned_key is None: - raise RuntimeError("Metrics source key was not set during collection.") - - -collect = CollectorContextManager + if self.previous_cvar_token is not None: + assert is_current_source_key_set() is True + assert self.enter_state is not None + get_current_source().metrics[self.metric_name].add_sample( + self.exit_collection_mode(self.enter_state) + ) + _source_key_cvar.reset(self.previous_cvar_token) def dumps(metric_sources: Mapping[str, Source] | None = None) -> str: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index a6b6d69a5e..72736191b4 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -74,7 +74,8 @@ def testee(a: cases.IField, out: cases.IField): "metrics_level,expected_names", [ (metrics.DISABLED, ()), - (metrics.PERFORMANCE, ("compute",)), + (metrics.MINIMAL, ("total",)), + (metrics.PERFORMANCE, ("total", "compute")), (metrics.ALL, ("compute", "total")), ], ) diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py index e5a25b3cd3..5c793d18c4 100644 --- a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py +++ b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py @@ -128,7 +128,7 @@ def test_program_call_hooks(backend: gtx_typing.Backend): embedded_callback_results.clear() # Add hooks and run the program again - hooks.program_call_hook.register(custom_program_callback) + hooks.program_call_context.register(custom_program_callback) hooks.embedded_program_call_hook.register(custom_embedded_program_callback) test_program(in_field, out=out_field) @@ -156,7 +156,7 @@ def test_program_call_hooks(backend: gtx_typing.Backend): embedded_callback_results.clear() # Remove hooks and call the program again - hooks.program_call_hook.remove(custom_program_callback) + hooks.program_call_context.remove(custom_program_callback) hooks.embedded_program_call_hook.remove(custom_embedded_program_callback) test_program(in_field, out=out_field) From 2e41eea84f3f6c3903710f63bf36bbac397a3d7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Mon, 19 Jan 2026 16:19:53 +0100 Subject: [PATCH 09/23] Remove hooks changes --- .../next/instrumentation/_hook_machinery.py | 194 ------------- src/gt4py/next/instrumentation/hooks.py | 15 - .../instrumentation_tests/__init__.py | 8 - .../instrumentation_tests/test_hooks.py | 167 ----------- .../test_hook_machinery.py | 266 ------------------ 5 files changed, 650 deletions(-) delete mode 100644 src/gt4py/next/instrumentation/_hook_machinery.py delete mode 100644 src/gt4py/next/instrumentation/hooks.py delete mode 100644 tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py delete mode 100644 tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py delete mode 100644 tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py diff --git a/src/gt4py/next/instrumentation/_hook_machinery.py b/src/gt4py/next/instrumentation/_hook_machinery.py deleted file mode 100644 index 0c23145212..0000000000 --- a/src/gt4py/next/instrumentation/_hook_machinery.py +++ /dev/null @@ -1,194 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import ast -import collections.abc -import contextlib -import dataclasses -import inspect -import textwrap -import types -import typing -import warnings -from collections.abc import Callable -from typing import Generic, ParamSpec, TypeVar - - -P = ParamSpec("P") -T = TypeVar("T") - - -def _get_unique_name(func: Callable) -> str: - """Generate a unique name for a callable object.""" - return ( - f"{func.__module__}.{getattr(func, '__qualname__', func.__class__.__qualname__)}#{id(func)}" - ) - - -def _is_empty_function(func: Callable) -> bool: - """Check if a callable object is empty (i.e., contains no statements).""" - try: - assert callable(func) - callable_src = ( - inspect.getsource(func) - if isinstance(func, types.FunctionType) - else inspect.getsource(func.__call__) # type: ignore[operator] # asserted above - ) - callable_ast = ast.parse(textwrap.dedent(callable_src)) - return all( - isinstance(st, ast.Pass) - or (isinstance(st, ast.Expr) and isinstance(st.value, ast.Constant)) - for st in typing.cast(ast.FunctionDef, callable_ast.body[0]).body - ) - except Exception: - return False - - -@dataclasses.dataclass(slots=True) -class _BaseHook(Generic[T, P]): - """Base class to define callback registration functionality for all hook types.""" - - definition: Callable[P, T] - registry: dict[str, Callable[P, T]] = dataclasses.field(default_factory=dict, kw_only=True) - callbacks: tuple[Callable[P, T], ...] = dataclasses.field(default=(), init=False) - - if not typing.TYPE_CHECKING: - - @property - def __doc__(self) -> str | None: - return self.definition.__doc__ - - def __post_init__(self) -> None: - # As an optimization to avoid an empty function call if no callbacks are - # registered, we only add the original definitions to the list of callables - # if it contains a non-empty definition. - if not _is_empty_function(self.definition): - self.callbacks = (self.definition,) - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - raise NotImplementedError("This method should be implemented by subclasses.") - - def register( - self, callback: Callable[P, T], *, name: str | None = None, index: int | None = None - ) -> None: - """ - Register a callback to the hook. - - Args: - callback: The callable to register. - name: An optional name for the callback. If not provided, a unique name will be generated. - index: An optional index at which to insert the callback (not counting the original - definition). If not provided, the callback will be appended to the end of the list. - """ - - callable_signature = inspect.signature(callback) - hook_signature = inspect.signature(self.definition) - - signature_mismatch = len(callable_signature.parameters) != len( - hook_signature.parameters - ) or any( - # Remove the annotation before comparison to avoid false mismatches - actual_param.replace(annotation="") != expected_param.replace(annotation="") - for actual_param, expected_param in zip( - callable_signature.parameters.values(), hook_signature.parameters.values() - ) - ) - if signature_mismatch: - raise ValueError( - f"Callback signature {callable_signature} does not match hook signature {hook_signature}" - ) - try: - callable_typing = typing.get_type_hints(callback) - hook_typing = typing.get_type_hints(self.definition) - if not all( - callable_typing[arg_key] == arg_typing - for arg_key, arg_typing in hook_typing.items() - ): - warnings.warn( - f"Callback annotations {callable_typing} does not match expected hook annotations {hook_typing}", - stacklevel=2, - ) - except Exception: - # Ignore issues while checking type hints (e.g., forward references - # or missing imports); failure here should not prevent hook registration. - pass - - name = name or _get_unique_name(callback) - - if index is None: - self.callbacks += (callback,) - else: - if self.callbacks and self.callbacks[0] is self.definition: - index += 1 # The original definition should always go first - self.callbacks = (*self.callbacks[:index], callback, *self.callbacks[index:]) - - self.registry[name] = callback - - def remove(self, callback: str | Callable[P, T]) -> None: - """ - Remove a registered callback from the hook. - - Args: - callback: The callable object to remove or its registered name. - """ - if isinstance(callback, str): - name = callback - if name not in self.registry: - raise KeyError(f"No callback registered under the name '{name}'") - else: - name = _get_unique_name(callback) - if name not in self.registry: - raise KeyError(f"Callback object {callback} not found in registry") - - callback = self.registry.pop(name) - assert callback in self.callbacks - self.callbacks = tuple(cb for cb in self.callbacks if cb is not callback) - - -@dataclasses.dataclass(slots=True) -class EventHook(_BaseHook[None, P]): - """Event hook specification.""" - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: - for func in self.callbacks: - func(*args, **kwargs) - - -@dataclasses.dataclass(slots=True) -class ContextHook( - contextlib.AbstractContextManager, _BaseHook[contextlib.AbstractContextManager, P] -): - """ - Context hook specification. - - This hook type is used to define context managers that can be stacked together. - """ - - ctx_managers: collections.abc.Sequence[contextlib.AbstractContextManager] = dataclasses.field( - default=(), init=False - ) - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> contextlib.AbstractContextManager: - self.ctx_managers = [func(*args, **kwargs) for func in self.callbacks] - return self - - def __enter__(self) -> None: - for ctx_manager in self.ctx_managers: - ctx_manager.__enter__() - - def __exit__( - self, - type_: type[BaseException] | None, - value: BaseException | None, - traceback: types.TracebackType | None, - ) -> None: - for ctx_manager in reversed(self.ctx_managers): - ctx_manager.__exit__(type_, value, traceback) - self.ctx_managers = () diff --git a/src/gt4py/next/instrumentation/hooks.py b/src/gt4py/next/instrumentation/hooks.py deleted file mode 100644 index b688e49ffd..0000000000 --- a/src/gt4py/next/instrumentation/hooks.py +++ /dev/null @@ -1,15 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -from gt4py.next.ffront.decorator import ( - embedded_program_call_hook as embedded_program_call_hook, - program_call_context as program_call_context, -) -from gt4py.next.otf.compiled_program import compile_variant_hook as compile_variant_hook diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py deleted file mode 100644 index abf4c3e24c..0000000000 --- a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py deleted file mode 100644 index 5c793d18c4..0000000000 --- a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py +++ /dev/null @@ -1,167 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import contextlib -from typing import Any - -import pytest - -import gt4py.next as gtx -from gt4py.next import common, Dims, gtfn_cpu, typing as gtx_typing -from gt4py.next.instrumentation import hooks - -try: - from gt4py.next.program_processors.runners import dace as dace_backends - - BACKENDS = [None, gtfn_cpu, dace_backends.run_dace_cpu_cached] -except ImportError: - BACKENDS = [None, gtfn_cpu] - - -callback_results = [] -embedded_callback_results = [] - - -@contextlib.contextmanager -def custom_program_callback( - program: gtx_typing.Program, - args: tuple[Any, ...], - offset_provider: common.OffsetProvider, - enable_jit: bool, - kwargs: dict[str, Any], -) -> contextlib.AbstractContextManager: - callback_results.append(("enter", None)) - - yield - - callback_results.append( - ( - "custom_program_callback", - { - "program": program.__name__, - "args": args, - "offset_provider": offset_provider.keys(), - "enable_jit": enable_jit, - "kwargs": kwargs.keys(), - }, - ) - ) - - -@contextlib.contextmanager -def custom_embedded_program_callback( - program: gtx_typing.Program, - args: tuple[Any, ...], - offset_provider: common.OffsetProvider, - kwargs: dict[str, Any], -) -> contextlib.AbstractContextManager: - embedded_callback_results.append(("enter", None)) - - yield - - embedded_callback_results.append( - ( - "custom_embedded_program_callback", - { - "program": program.__name__, - "args": args, - "offset_provider": offset_provider.keys(), - "kwargs": kwargs.keys(), - }, - ) - ) - - -# @hooks.compile_variant_hook.register -# def compile_variant_hook( -# key: tuple[tuple[Hashable, ...], int], -# backend: gtx_backend.Backend, -# program_definition: "ffront_stages.ProgramDefinition", -# compile_time_args: "arguments.CompileTimeArgs", -# ) -> None: -# """Callback hook invoked before compiling a program variant.""" -# ... - - -Cell = gtx.Dimension("Cell") -IDim = gtx.Dimension("IDim") - - -@gtx.field_operator -def identity_fop( - in_field: gtx.Field[Dims[IDim], gtx.float64], -) -> gtx.Field[Dims[IDim], gtx.float64]: - return in_field - - -@gtx.program -def copy_program( - in_field: gtx.Field[Dims[IDim], gtx.float64], out: gtx.Field[Dims[IDim], gtx.float64] -): - identity_fop(in_field, out=out) - - -@pytest.mark.parametrize("backend", BACKENDS, ids=lambda b: getattr(b, "name", str(b))) -def test_program_call_hooks(backend: gtx_typing.Backend): - size = 10 - in_field = gtx.full([(IDim, size)], 1, dtype=gtx.float64) - out_field = gtx.empty([(IDim, size)], dtype=gtx.float64) - - test_program = copy_program.with_backend(backend) - - # Run the program without hooks - callback_results.clear() - embedded_callback_results.clear() - test_program(in_field, out=out_field) - - # Callbacks should not have been called - assert callback_results == [] - callback_results.clear() - assert embedded_callback_results == [] - embedded_callback_results.clear() - - # Add hooks and run the program again - hooks.program_call_context.register(custom_program_callback) - hooks.embedded_program_call_hook.register(custom_embedded_program_callback) - test_program(in_field, out=out_field) - - # Check that the callbacks were called - assert len(callback_results) == 2 - assert callback_results[0] == ("enter", None) - - hook_name, hook_call_info = callback_results[1] - assert hook_name == "custom_program_callback" - assert hook_call_info["program"] == test_program.__name__ - - # The embedded program call hook should have also been called - # with the embedded backend - if backend is None: - assert len(embedded_callback_results) == 2 - assert embedded_callback_results[0] == ("enter", None) - - hook_name, hook_call_info = embedded_callback_results[1] - assert hook_name == "custom_embedded_program_callback" - assert hook_call_info["program"] == copy_program.__name__ - else: - assert len(embedded_callback_results) == 0 - - callback_results.clear() - embedded_callback_results.clear() - - # Remove hooks and call the program again - hooks.program_call_context.remove(custom_program_callback) - hooks.embedded_program_call_hook.remove(custom_embedded_program_callback) - test_program(in_field, out=out_field) - - # Callbacks should not have been called - assert callback_results == [] - callback_results.clear() - assert embedded_callback_results == [] - embedded_callback_results.clear() diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py deleted file mode 100644 index e8711e3913..0000000000 --- a/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py +++ /dev/null @@ -1,266 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import contextlib - -import pytest - -from gt4py.next.instrumentation._hook_machinery import ( - EventHook, - ContextHook, - _get_unique_name, - _is_empty_function, -) - - -def test_get_unique_name(): - def func1(): - pass - - def func2(): - pass - - assert _get_unique_name(func1) != _get_unique_name(func2) - - class A: - def __call__(self): ... - - assert _get_unique_name(A) == _get_unique_name(A) - - a1, a2 = A(), A() - - assert (a1_name := _get_unique_name(a1)) != (a2_name := _get_unique_name(a2)) - assert _get_unique_name(a1) == a1_name - assert _get_unique_name(a2) == a2_name - - -def test_empty_function(): - def empty(): - pass - - assert _is_empty_function(empty) is True - - def non_empty(): - return 1 - - assert _is_empty_function(non_empty) is False - - def with_docstring(): - """This is a docstring.""" - - assert _is_empty_function(with_docstring) is True - - def with_ellipsis(): ... - - assert _is_empty_function(with_ellipsis) is True - - class A: - def __call__(self): ... - - assert _is_empty_function(A()) is True - - -class TestEventHook: - def test_event_hook_call_with_no_callbacks(self): - @EventHook - def hook(x: int) -> None: - pass - - hook(42) # Should not raise - - def test_event_hook_call_with_callbacks(self): - results = [] - - @EventHook - def hook(x: int) -> None: - pass - - def callback1(x: int) -> None: - results.append(x) - - def callback2(x: int) -> None: - results.append(x * 2) - - hook.register(callback1) - hook.register(callback2) - hook(5) - - assert results == [5, 10] - - def test_event_hook_register_with_signature_mismatch(self): - @EventHook - def hook(x: int) -> None: - pass - - def bad_callback(x: int, y: int) -> None: - pass - - with pytest.raises(ValueError, match="Callback signature"): - hook.register(bad_callback) - - def test_event_hook_register_with_annotation_mismatch(self): - @EventHook - def hook(x: int) -> None: - pass - - def weird_callback(x: str) -> None: - pass - - with pytest.warns(UserWarning, match="Callback annotations"): - hook.register(weird_callback) - - def test_event_hook_register_with_name(self): - @EventHook - def hook(x: int) -> None: - pass - - def callback(x: int) -> None: - pass - - hook.register(callback, name="my_callback") - - assert "my_callback" in hook.registry - - def test_event_hook_register_with_index(self): - results = [] - - @EventHook - def hook(x: int) -> None: - pass - - def callback1(x: int) -> None: - results.append(1) - - def callback2(x: int) -> None: - results.append(2) - - hook.register(callback1) - hook.register(callback2, index=0) - hook(0) - - assert results == [2, 1] - - def test_event_hook_remove_by_name(self): - results = [] - - @EventHook - def hook(x: int) -> None: - pass - - def callback(x: int) -> None: - results.append(x) - - hook.register(callback, name="test_cb") - hook(42) - assert results == [42] - - hook.remove("test_cb") - results = [] - hook(42) - - assert results == [] - - def test_event_hook_remove_by_callback(self): - results = [] - - @EventHook - def hook(x: int) -> None: - pass - - def callback(x: int) -> None: - results.append(x) - - hook.register(callback) - hook(42) - assert results == [42] - - hook.remove(callback) - results = [] - hook(42) - - assert results == [] - - def test_event_hook_remove_nonexistent_raises(self): - @EventHook - def hook(x: int) -> None: - pass - - with pytest.raises(KeyError): - hook.remove("nonexistent") - - -class TestContextHook: - def test_context_hook_basic(self): - enter_called = [] - exit_called = [] - - @ContextHook - def hook() -> contextlib.AbstractContextManager: - pass - - @contextlib.contextmanager - def callback(): - enter_called.append(True) - yield - exit_called.append(True) - - hook.register(callback) - - with hook(): - assert len(enter_called) == 1 - - assert len(exit_called) == 1 - - def test_context_hook_multiple_callbacks(self): - order = [] - - @ContextHook - def hook() -> contextlib.AbstractContextManager: - pass - - @contextlib.contextmanager - def callback1(): - order.append("enter1") - yield - order.append("exit1") - - @contextlib.contextmanager - def callback2(): - order.append("enter2") - yield - order.append("exit2") - - hook.register(callback1) - hook.register(callback2) - - with hook(): - pass - - # Entry in order, but exit in reverse - assert order == ["enter1", "enter2", "exit2", "exit1"] - - def test_context_hook_with_arguments(self): - results = [] - - @ContextHook - def hook(x: int) -> contextlib.AbstractContextManager: - pass - - @contextlib.contextmanager - def callback(x: int): - results.append(x) - yield - - hook.register(callback) - - with hook(42): - pass - - assert results == [42] From ffefed335b58135557d7fb4cfd195ea1778e42e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Tue, 20 Jan 2026 08:53:51 +0100 Subject: [PATCH 10/23] More metrics refactorings and cleanups --- src/gt4py/next/ffront/decorator.py | 45 ++----------- src/gt4py/next/instrumentation/metrics.py | 67 ++++++++++--------- src/gt4py/next/otf/compiled_program.py | 18 +---- .../runners/dace/workflow/decoration.py | 2 +- .../runners/dace/workflow/translation.py | 2 +- .../next/program_processors/runners/gtfn.py | 2 +- 6 files changed, 49 insertions(+), 87 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index dd28cb3057..8db995d860 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -12,10 +12,8 @@ from __future__ import annotations -import contextlib import dataclasses import functools -import time import types import typing import warnings @@ -46,7 +44,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.gtcallable import GTCallable -from gt4py.next.instrumentation import _hook_machinery, metrics +from gt4py.next.instrumentation import metrics from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, compiled_program, stages, toolchain from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -55,38 +53,10 @@ DEFAULT_BACKEND: next_backend.Backend | None = None -class ProgramCallMetricsCollector(metrics.AbstractCollectorContextManager[float]): +class ProgramCallMetricsCollector(metrics.AbstractCollectorContextManager): + level: Final[int] = metrics.MINIMAL metric_name: Final[str] = metrics.TOTAL_METRIC - def enter_collection_mode(self) -> float: - return time.perf_counter() - - def exit_collection_mode(self, enter_state: float) -> float: - return time.perf_counter() - enter_state - - -@_hook_machinery.ContextHook -def program_call_context( - program: Program, - args: tuple[Any, ...], - offset_provider: common.OffsetProvider, - enable_jit: bool, - kwargs: dict[str, Any], -) -> contextlib.AbstractContextManager: - """Hook called at the beginning and end of a program call.""" - return ProgramCallMetricsCollector() - - -@_hook_machinery.ContextHook -def embedded_program_call_hook( # type: ignore[empty-body] - program: Program, - args: tuple[Any, ...], - offset_provider: common.OffsetProvider, - kwargs: dict[str, Any], -) -> contextlib.AbstractContextManager: - """Hook called at the beginning and end of an embedded program call.""" - ... - # TODO(tehrengruber): Decide if and how programs can call other programs. As a # result Program could become a GTCallable. @@ -297,7 +267,7 @@ def __call__( self.enable_jit if self.enable_jit is not None else config.ENABLE_JIT_DEFAULT ) - with program_call_context(self, args, offset_provider, enable_jit, kwargs): + with ProgramCallMetricsCollector(): if __debug__: # TODO: remove or make dependency on self.past_stage optional past_process_args._validate_args( @@ -321,15 +291,14 @@ def __call__( # Metrics source key needs to be setup here, since embedded programs # don't have variants so there's no other place to do it. - if config.COLLECT_METRICS_LEVEL: - #assert metrics_source is not None + if metrics.is_level_enabled(metrics.MINIMAL): + # assert metrics_source is not None metrics.set_current_source_key( f"{self.__name__}<{getattr(self.backend, 'name', '')}>" ) with next_embedded.context.update(offset_provider=offset_provider): - with embedded_program_call_hook(self, args, offset_provider, kwargs): - self.definition_stage.definition(*args, **kwargs) + self.definition_stage.definition(*args, **kwargs) def compile( self, diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 43a05ca440..b4424a59ec 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -13,16 +13,16 @@ import contextlib import contextvars import dataclasses -import functools import itertools import json import numbers import pathlib import sys +import time import types import typing from collections.abc import Callable, Mapping -from typing import Generic, TypeVar, Protocol +from typing import ClassVar import numpy as np @@ -32,6 +32,9 @@ from gt4py.next.otf import arguments +_NO_KEY_SET_MARKER_: Final[str] = sys.intern("@_NO_KEY_SET_MARKER_@") + + # Common metric names COMPUTE_METRIC: Final[str] = sys.intern("compute") TOTAL_METRIC: Final[str] = sys.intern("total") @@ -47,7 +50,12 @@ ALL: Final[int] = 100 -def is_level_enabled(level: int = MINIMAL) -> bool: +def is_enabled() -> bool: + """Check if a given metrics collection level is enabled.""" + return config.COLLECT_METRICS_LEVEL > DISABLED + + +def is_level_enabled(level: int) -> bool: """Check if a given metrics collection level is enabled.""" return config.COLLECT_METRICS_LEVEL >= level @@ -139,7 +147,7 @@ class Source: def is_current_source_key_set() -> bool: """Check if there is an on-going metrics collection.""" - return _source_key_cvar.get(None) is not None + return _source_key_cvar.get(_NO_KEY_SET_MARKER_) is not _NO_KEY_SET_MARKER_ def get_current_source_key() -> str: @@ -149,7 +157,9 @@ def get_current_source_key() -> str: def set_current_source_key(key: str) -> Source: """Set the current source key for metrics collection.""" - assert _source_key_cvar.get(None) is None, "A source key is already set." + assert _source_key_cvar.get(_NO_KEY_SET_MARKER_) is _NO_KEY_SET_MARKER_, ( + "A source key is already set." + ) _source_key_cvar.set(key) return sources[key] @@ -180,11 +190,11 @@ class SourceKeyContextManager(contextlib.AbstractContextManager): """ key: str | None = None - previous_cvar_token: contextvars.Token | None = dataclasses.field(default=None, init=False) + previous_cvar_token: contextvars.Token | None = dataclasses.field(init=False) def __enter__(self) -> None: - if is_level_enabled() and self.key is not None: - self.previous_cvar_token = _source_key_cvar.set(self.key) + if is_enabled(): + self.previous_cvar_token = _source_key_cvar.set(self.key or _NO_KEY_SET_MARKER_) else: self.previous_cvar_token = None @@ -198,13 +208,11 @@ def __exit__( _source_key_cvar.reset(self.previous_cvar_token) -# collection_source = SourceContextManager - +metrics_context = SourceKeyContextManager -StateT = TypeVar("StateT") - -class AbstractCollectorContextManager(contextlib.AbstractContextManager, Generic[StateT]): +@dataclasses.dataclass(slots=True) +class AbstractCollectorContextManager(contextlib.AbstractContextManager): """ A context manager to handle metrics collection. @@ -219,12 +227,6 @@ class AbstractCollectorContextManager(contextlib.AbstractContextManager, Generic of renewing the generator inside `contextlib.contextmanager`. """ - __slots__ = ("key", "previous_cvar_token", "enter_state") - - key: str | None - previous_cvar_token: contextvars.Token | None - enter_state: StateT | None - @property @abc.abstractmethod def level(self) -> int: ... @@ -233,18 +235,22 @@ def level(self) -> int: ... @abc.abstractmethod def metric_name(self) -> str: ... - @abc.abstractmethod - def enter_collection_mode(self) -> StateT: ... + enter_collection_callback: ClassVar[Callable[[], float]] = staticmethod(time.perf_counter) - @abc.abstractmethod - def exit_collection_mode(self, enter_state: StateT) -> float: ... + exit_collection_callback: ClassVar[Callable[[float], float]] = staticmethod( + lambda enter_state: time.perf_counter() - enter_state + ) + + key: str | None = None + previous_cvar_token: contextvars.Token = dataclasses.field(init=False) + enter_state: float | None = dataclasses.field(init=False) def __enter__(self) -> None: - if is_level_enabled(self.level) and self.key is not None: - self.previous_cvar_token = _source_key_cvar.set(self.key) - self.enter_state = self.enter_collection_mode() + if is_level_enabled(self.level): + self.enter_state = self.enter_collection_callback() + self.previous_cvar_token = _source_key_cvar.set(self.key or _NO_KEY_SET_MARKER_) else: - self.previous_cvar_token = None + self.enter_state = None def __exit__( self, @@ -252,11 +258,10 @@ def __exit__( value: BaseException | None, traceback: types.TracebackType | None, ) -> None: - if self.previous_cvar_token is not None: + if self.enter_state is not None: assert is_current_source_key_set() is True - assert self.enter_state is not None - get_current_source().metrics[self.metric_name].add_sample( - self.exit_collection_mode(self.enter_state) + sources[_source_key_cvar.get()].metrics[self.metric_name].add_sample( + self.exit_collection_callback(self.enter_state) ) _source_key_cvar.reset(self.previous_cvar_token) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index bbc0f17fb1..4af678c607 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -19,7 +19,7 @@ from gt4py.eve import extended_typing as xtyping, utils as eve_utils from gt4py.next import backend as gtx_backend, common, config, errors, utils as gtx_utils from gt4py.next.ffront import stages as ffront_stages, type_specifications as ts_ffront -from gt4py.next.instrumentation import _hook_machinery, metrics +from gt4py.next.instrumentation import metrics from gt4py.next.otf import arguments, stages from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -55,17 +55,6 @@ def _init_async_compilation_pool() -> None: _init_async_compilation_pool() -@_hook_machinery.EventHook -def compile_variant_hook( - key: CompiledProgramsKey, - backend: gtx_backend.Backend, - program_definition: ffront_stages.ProgramDefinition, - compile_time_args: arguments.CompileTimeArgs, -) -> None: - """Callback hook invoked before compiling a program variant.""" - ... - - def wait_for_compilation() -> None: """ Waits for all ongoing compilations to finish. @@ -274,7 +263,7 @@ def __call__( try: program = self.compiled_programs[key] - if config.COLLECT_METRICS_LEVEL: + if metrics.is_level_enabled(metrics.MINIMAL): metrics.set_current_source_key(self._metrics_key_from_pool_key(key)) program(*args, **kwargs, offset_provider=offset_provider) # type: ignore[operator] # the Future case is handled below @@ -430,7 +419,7 @@ def _compile_variant( raise ValueError(f"Program with key {key} already exists.") # If we are collecting metrics, create a new metrics entity for this compiled program - if config.COLLECT_METRICS_LEVEL: + if metrics.is_level_enabled(metrics.MINIMAL): metrics_source = metrics.set_current_source_key(self._metrics_key_from_pool_key(key)) metrics_source.metadata |= dict( name=self.definition_stage.definition.__name__, @@ -453,7 +442,6 @@ def _compile_variant( compile_call = functools.partial( self.backend.compile, self.definition_stage, compile_time_args=compile_time_args ) - compile_variant_hook(key, self.backend, self.definition_stage, compile_time_args) if _async_compilation_pool is None: # synchronous compilation self.compiled_programs[key] = compile_call() diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index de3d25aaa5..a9bdc9eeab 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -29,7 +29,7 @@ def convert_args( device: core_defs.DeviceType = core_defs.DeviceType.CPU, ) -> stages.CompiledProgram: # Retieve metrics level from GT4Py environment variable. - collect_time = config.COLLECT_METRICS_LEVEL >= metrics.PERFORMANCE + collect_time = metrics.is_level_enabled(metrics.PERFORMANCE) collect_time_arg = np.array([1], dtype=np.float64) # We use the callback function provided by the compiled program to update the SDFG arglist. update_sdfg_call_args = functools.partial( diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 6e6b6b904b..696a30921f 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -268,7 +268,7 @@ def add_instrumentation(sdfg: dace.SDFG, gpu: bool) -> None: dace.Memlet(f"{output}[0]"), ) - if (config.COLLECT_METRICS_LEVEL == metrics.GPU_TX_MARKERS) and gpu: + if metrics.is_level_enabled(metrics.GPU_TX_MARKERS) and gpu: sdfg.instrument = dace.dtypes.InstrumentationType.GPU_TX_MARKERS for node, _ in sdfg.all_nodes_recursive(): if isinstance( diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 27363ea2a8..558ba38c16 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -58,7 +58,7 @@ def decorated_program( conn_args = extract_connectivity_args(offset_provider, device) opt_kwargs: dict[str, Any] = {} - if collect_metrics := (config.COLLECT_METRICS_LEVEL >= metrics.PERFORMANCE): + if collect_metrics := metrics.is_level_enabled(metrics.PERFORMANCE): # If we are collecting metrics, we need to add the `exec_info` argument # to the `inp` call, which will be used to collect performance metrics. exec_info: dict[str, float] = {} From edbe49b75785c1e600f9efebb7b34fa9144fed60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Tue, 20 Jan 2026 09:24:11 +0100 Subject: [PATCH 11/23] Fixes --- src/gt4py/next/instrumentation/metrics.py | 11 ++++++----- .../runners/dace/workflow/translation.py | 3 ++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index c4e41ff292..74627bb03a 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -156,8 +156,8 @@ def get_current_source_key() -> str: def set_current_source_key(key: str) -> Source: """Set the current source key for metrics collection.""" - assert _source_key_cvar.get(_NO_KEY_SET_MARKER_) is _NO_KEY_SET_MARKER_, ( - "A source key is already set." + assert _source_key_cvar.get(_NO_KEY_SET_MARKER_) in {key, _NO_KEY_SET_MARKER_}, ( + "A different source key has been already set." ) _source_key_cvar.set(key) return sources[key] @@ -241,7 +241,7 @@ def metric_name(self) -> str: ... ) key: str | None = None - previous_cvar_token: contextvars.Token = dataclasses.field(init=False) + previous_cvar_token: contextvars.Token | None = dataclasses.field(init=False) enter_state: float | None = dataclasses.field(init=False) def __enter__(self) -> None: @@ -249,7 +249,7 @@ def __enter__(self) -> None: self.enter_state = self.enter_collection_callback() self.previous_cvar_token = _source_key_cvar.set(self.key or _NO_KEY_SET_MARKER_) else: - self.enter_state = None + self.previous_cvar_token = None def __exit__( self, @@ -257,8 +257,9 @@ def __exit__( value: BaseException | None, traceback: types.TracebackType | None, ) -> None: - if self.enter_state is not None: + if self.previous_cvar_token is not None: assert is_current_source_key_set() is True + assert self.enter_state is not None sources[_source_key_cvar.get()].metrics[self.metric_name].add_sample( self.exit_collection_callback(self.enter_state) ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index da7f832fa3..9320ef3bdd 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -15,7 +15,8 @@ import factory from gt4py._core import definitions as core_defs -from gt4py.next import common, config, metrics +from gt4py.next import common, config +from gt4py.next.instrumentation import metrics from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import interface From 332ba9b70a2482469e3eeecb00fb5ca6ea5c8278 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 21 Jan 2026 10:54:01 +0100 Subject: [PATCH 12/23] Typing fixes and renames --- src/gt4py/next/ffront/decorator.py | 2 +- src/gt4py/next/instrumentation/metrics.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 8db995d860..ed0a8b6c60 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -53,7 +53,7 @@ DEFAULT_BACKEND: next_backend.Backend | None = None -class ProgramCallMetricsCollector(metrics.AbstractCollectorContextManager): +class ProgramCallMetricsCollector(metrics.AbstractMetricsCollector): level: Final[int] = metrics.MINIMAL metric_name: Final[str] = metrics.TOTAL_METRIC diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 74627bb03a..196f20c2fe 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -16,6 +16,7 @@ import itertools import json import numbers +import operator import pathlib import sys import time @@ -211,7 +212,7 @@ def __exit__( @dataclasses.dataclass(slots=True) -class AbstractCollectorContextManager(contextlib.AbstractContextManager): +class AbstractMetricsCollector(contextlib.AbstractContextManager): """ A context manager to handle metrics collection. @@ -234,7 +235,9 @@ def level(self) -> int: ... @abc.abstractmethod def metric_name(self) -> str: ... - enter_collection_callback: ClassVar[Callable[[], float]] = staticmethod(time.perf_counter) + collect_enter_counter: ClassVar[Callable[[], float]] = staticmethod(time.perf_counter) + collect_exit_counter: ClassVar[Callable[[], float]] = staticmethod(time.perf_counter) + compute_metric: ClassVar[Callable[[float, float], float]] = staticmethod(operator.sub) exit_collection_callback: ClassVar[Callable[[float], float]] = staticmethod( lambda enter_state: time.perf_counter() - enter_state @@ -242,11 +245,11 @@ def metric_name(self) -> str: ... key: str | None = None previous_cvar_token: contextvars.Token | None = dataclasses.field(init=False) - enter_state: float | None = dataclasses.field(init=False) + enter_counter: float | None = dataclasses.field(init=False) def __enter__(self) -> None: if is_level_enabled(self.level): - self.enter_state = self.enter_collection_callback() + self.enter_counter = self.collect_enter_counter() # type: ignore[misc] # mypy doesn't understand that this is a staticmethod self.previous_cvar_token = _source_key_cvar.set(self.key or _NO_KEY_SET_MARKER_) else: self.previous_cvar_token = None @@ -259,9 +262,9 @@ def __exit__( ) -> None: if self.previous_cvar_token is not None: assert is_current_source_key_set() is True - assert self.enter_state is not None + assert hasattr(self, "enter_state") and self.enter_counter is not None sources[_source_key_cvar.get()].metrics[self.metric_name].add_sample( - self.exit_collection_callback(self.enter_state) + self.compute_metric(self.collect_exit_counter(), self.enter_counter) # type: ignore[call-arg,misc] # mypy doesn't understand that this is a staticmethod ) _source_key_cvar.reset(self.previous_cvar_token) From ca7d62a1c5a464a54556a665e5482c091f993add Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 21 Jan 2026 11:27:03 +0100 Subject: [PATCH 13/23] Adding metrics collector maker --- src/gt4py/next/ffront/decorator.py | 10 ++-- src/gt4py/next/instrumentation/metrics.py | 62 +++++++++++++++++------ 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index ed0a8b6c60..359e88f343 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -18,7 +18,7 @@ import typing import warnings from collections.abc import Callable, Sequence -from typing import Any, Final, Generic, Optional, TypeVar +from typing import Any, Generic, Optional, TypeVar from gt4py import eve from gt4py._core import definitions as core_defs @@ -53,9 +53,9 @@ DEFAULT_BACKEND: next_backend.Backend | None = None -class ProgramCallMetricsCollector(metrics.AbstractMetricsCollector): - level: Final[int] = metrics.MINIMAL - metric_name: Final[str] = metrics.TOTAL_METRIC +program_call_metrics_collector = metrics.make_collector( + level=metrics.MINIMAL, metric_name=metrics.TOTAL_METRIC +) # TODO(tehrengruber): Decide if and how programs can call other programs. As a @@ -267,7 +267,7 @@ def __call__( self.enable_jit if self.enable_jit is not None else config.ENABLE_JIT_DEFAULT ) - with ProgramCallMetricsCollector(): + with program_call_metrics_collector(): if __debug__: # TODO: remove or make dependency on self.past_stage optional past_process_args._validate_args( diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 196f20c2fe..0a0ca75d6d 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -8,7 +8,6 @@ from __future__ import annotations -import abc import collections import contextlib import contextvars @@ -212,7 +211,7 @@ def __exit__( @dataclasses.dataclass(slots=True) -class AbstractMetricsCollector(contextlib.AbstractContextManager): +class BaseMetricsCollector(contextlib.AbstractContextManager): """ A context manager to handle metrics collection. @@ -227,30 +226,41 @@ class AbstractMetricsCollector(contextlib.AbstractContextManager): of renewing the generator inside `contextlib.contextmanager`. """ - @property - @abc.abstractmethod - def level(self) -> int: ... - - @property - @abc.abstractmethod - def metric_name(self) -> str: ... + def __init_subclass__( + cls, + *, + level: int, + metric_name: str, + collect_enter_counter: Callable[[], float] | None = None, + collect_exit_counter: Callable[[], float] | None = None, + compute_metric: Callable[[float, float], float] | None = None, + **kwargs, + ) -> types.NoneType: + super(BaseMetricsCollector, cls).__init_subclass__(**kwargs) + cls.level = level + cls.metric_name = sys.intern(metric_name) + if collect_enter_counter is not None: + cls.collect_enter_counter = staticmethod(collect_enter_counter) + if collect_exit_counter is not None: + cls.collect_exit_counter = staticmethod(collect_exit_counter) + if compute_metric is not None: + cls.compute_metric = staticmethod(compute_metric) + + level: ClassVar[int] + metric_name: ClassVar[str] collect_enter_counter: ClassVar[Callable[[], float]] = staticmethod(time.perf_counter) collect_exit_counter: ClassVar[Callable[[], float]] = staticmethod(time.perf_counter) compute_metric: ClassVar[Callable[[float, float], float]] = staticmethod(operator.sub) - exit_collection_callback: ClassVar[Callable[[float], float]] = staticmethod( - lambda enter_state: time.perf_counter() - enter_state - ) - key: str | None = None previous_cvar_token: contextvars.Token | None = dataclasses.field(init=False) enter_counter: float | None = dataclasses.field(init=False) def __enter__(self) -> None: if is_level_enabled(self.level): - self.enter_counter = self.collect_enter_counter() # type: ignore[misc] # mypy doesn't understand that this is a staticmethod self.previous_cvar_token = _source_key_cvar.set(self.key or _NO_KEY_SET_MARKER_) + self.enter_counter = self.collect_enter_counter() # type: ignore[misc] # mypy doesn't understand that this is a staticmethod else: self.previous_cvar_token = None @@ -262,13 +272,35 @@ def __exit__( ) -> None: if self.previous_cvar_token is not None: assert is_current_source_key_set() is True - assert hasattr(self, "enter_state") and self.enter_counter is not None + assert hasattr(self, "enter_counter") and self.enter_counter is not None sources[_source_key_cvar.get()].metrics[self.metric_name].add_sample( self.compute_metric(self.collect_exit_counter(), self.enter_counter) # type: ignore[call-arg,misc] # mypy doesn't understand that this is a staticmethod ) _source_key_cvar.reset(self.previous_cvar_token) +def make_collector( + level: int, + metric_name: str, + *, + collect_enter_counter: Callable[[], float] | None = None, + collect_exit_counter: Callable[[], float] | None = None, + compute_metric: Callable[[float, float], float] | None = None, +) -> type[BaseMetricsCollector]: + collector_kwds = dict(level=level, metric_name=metric_name) | { + key: value + for key in ["collect_enter_counter", "collect_exit_counter", "compute_metric"] + if (value := locals().get(key)) is not None + } + + return types.new_class( + f"AutoMetricsCollectorFor_{metric_name}", + bases=(BaseMetricsCollector,), + kwds=collector_kwds, + exec_body=None, + ) + + def dumps(metric_sources: Mapping[str, Source] | None = None) -> str: """ Format the metrics in the collection store as a string table. From f6fc4272dec2de6303d14d4b125835b129315bd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 21 Jan 2026 11:31:07 +0100 Subject: [PATCH 14/23] Fixes --- src/gt4py/next/instrumentation/metrics.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 0a0ca75d6d..bc14f96b93 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -134,7 +134,6 @@ class Source: metadata: dict[str, Any] = dataclasses.field(default_factory=dict) metrics: MetricsCollection = dataclasses.field(default_factory=MetricsCollection) - assigned_key: str | None = dataclasses.field(default=None, init=False) #: Global store for all measurements. @@ -227,15 +226,15 @@ class BaseMetricsCollector(contextlib.AbstractContextManager): """ def __init_subclass__( - cls, + cls: type[BaseMetricsCollector], *, level: int, metric_name: str, collect_enter_counter: Callable[[], float] | None = None, collect_exit_counter: Callable[[], float] | None = None, compute_metric: Callable[[float, float], float] | None = None, - **kwargs, - ) -> types.NoneType: + **kwargs: Any, + ) -> None: super(BaseMetricsCollector, cls).__init_subclass__(**kwargs) cls.level = level cls.metric_name = sys.intern(metric_name) From 7d5cb5c800572a3b5e433064e230d4848ea159e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 21 Jan 2026 11:47:02 +0100 Subject: [PATCH 15/23] Fix typing due to mypy bug and minor cleanups --- src/gt4py/eve/utils.py | 2 ++ src/gt4py/next/instrumentation/metrics.py | 10 ++++++---- .../runners/dace/workflow/decoration.py | 4 +--- src/gt4py/next/program_processors/runners/gtfn.py | 4 ++-- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 56bb520350..a0e48ae557 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -281,6 +281,8 @@ class CustomDefaultDictBase(collections.defaultdict[_K, _V]): """ + __slots__ = () + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index bc14f96b93..d0cb362f1c 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -123,13 +123,15 @@ class MetricsCollection(utils.CustomDefaultDictBase[str, Metric]): [0.1, 0.2] """ + __slots__ = () + def value_factory(self, key: str) -> Metric: assert isinstance(key, str) return Metric(name=key) @dataclasses.dataclass(slots=True) -class Source: +class Source: # type: ignore[misc] # Mypy bug fixed by: https://github.com/python/mypy/pull/20573 """A source of metrics, typically associated with a program.""" metadata: dict[str, Any] = dataclasses.field(default_factory=dict) @@ -169,11 +171,11 @@ def get_current_source() -> Source: def add_sample_to_current_source(metric_name: str, sample: float) -> None: """Add a sample to a metric in the current source.""" - return get_current_source().metrics[metric_name].add_sample(sample) + return sources[_source_key_cvar.get()].metrics[metric_name].add_sample(sample) @dataclasses.dataclass(slots=True) -class SourceKeyContextManager(contextlib.AbstractContextManager): +class SourceKeyContextManager(contextlib.AbstractContextManager): # type: ignore[misc] # Mypy bug fixed by: https://github.com/python/mypy/pull/20573 """ A context manager to handle metrics collection sources. @@ -210,7 +212,7 @@ def __exit__( @dataclasses.dataclass(slots=True) -class BaseMetricsCollector(contextlib.AbstractContextManager): +class BaseMetricsCollector(contextlib.AbstractContextManager): # type: ignore[misc] # Mypy bug fixed by: https://github.com/python/mypy/pull/20573 """ A context manager to handle metrics collection. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index a9bdc9eeab..fc2c4d78b6 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -73,8 +73,6 @@ def decorated_program( fun.fast_call() if collect_time: - metric_source = metrics.get_current_source() - assert metric_source is not None - metric_source.metrics[metrics.COMPUTE_METRIC].add_sample(collect_time_arg[0].item()) + metrics.add_sample_to_current_source(metrics.COMPUTE_METRIC, collect_time_arg[0].item()) return decorated_program diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 558ba38c16..038f2959b6 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -72,8 +72,8 @@ def decorated_program( ) if collect_metrics: - metrics.get_current_source().metrics[metrics.COMPUTE_METRIC].add_sample( - exec_info["run_cpp_duration"] + metrics.add_sample_to_current_source( + metrics.COMPUTE_METRIC, exec_info["run_cpp_duration"] ) return decorated_program From 7b4b3ce9f8a0803509552f9894aafa789ea0ec72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 21 Jan 2026 15:30:11 +0100 Subject: [PATCH 16/23] Clean up docs --- src/gt4py/next/instrumentation/metrics.py | 62 ++++++++++++++--------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index d0cb362f1c..fa03d9529d 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -49,8 +49,8 @@ ALL: Final[int] = 100 -def is_enabled() -> bool: - """Check if a given metrics collection level is enabled.""" +def is_any_level_enabled() -> bool: + """Check if any metrics collection level is enabled.""" return config.COLLECT_METRICS_LEVEL > DISABLED @@ -60,7 +60,7 @@ def is_level_enabled(level: int) -> bool: def get_current_level() -> int: - """Retrieve the current metrics collection level from the configuration.""" + """Retrieve the current metrics collection level (from the configuration module).""" return config.COLLECT_METRICS_LEVEL @@ -141,17 +141,17 @@ class Source: # type: ignore[misc] # Mypy bug fixed by: https://github.com/pyt #: Global store for all measurements. sources: collections.defaultdict[str, Source] = collections.defaultdict(Source) -# Context variables storing the active source keys. +# Context variable storing the active source key. _source_key_cvar: contextvars.ContextVar[str] = contextvars.ContextVar("source_key") def is_current_source_key_set() -> bool: - """Check if there is an on-going metrics collection.""" + """Check if there is a source key set for metrics collection.""" return _source_key_cvar.get(_NO_KEY_SET_MARKER_) is not _NO_KEY_SET_MARKER_ def get_current_source_key() -> str: - """Retrieve the current source key for metrics collection.""" + """Retrieve the current source key for metrics collection (it must be set).""" return _source_key_cvar.get() @@ -165,35 +165,33 @@ def set_current_source_key(key: str) -> Source: def get_current_source() -> Source: - """Retrieve the active metrics collection source.""" + """Retrieve the active metrics collection source (a valid source key must be set).""" return sources[_source_key_cvar.get()] def add_sample_to_current_source(metric_name: str, sample: float) -> None: - """Add a sample to a metric in the current source.""" + """Add a sample to a metric in the current source (a valid source key must be set).""" return sources[_source_key_cvar.get()].metrics[metric_name].add_sample(sample) @dataclasses.dataclass(slots=True) class SourceKeyContextManager(contextlib.AbstractContextManager): # type: ignore[misc] # Mypy bug fixed by: https://github.com/python/mypy/pull/20573 """ - A context manager to handle metrics collection sources. + A context manager to handle metrics collection source keys. When entering this context manager, it sets up a new source key for collection of metrics in a module contextvar. Upon exiting the context, it resets the contextvar to its previous state. - - Note: - This is implemented as a context manager class instead of a generator - function with `@contextlib.contextmanager` to avoid the extra overhead - of renewing the generator inside `contextlib.contextmanager`. """ key: str | None = None previous_cvar_token: contextvars.Token | None = dataclasses.field(init=False) + # This class is implemented as a context manager class instead of a generator + # function with `@contextlib.contextmanager` to avoid the extra overhead + # of renewing the generator inside `contextlib.contextmanager`. def __enter__(self) -> None: - if is_enabled(): + if is_any_level_enabled(): self.previous_cvar_token = _source_key_cvar.set(self.key or _NO_KEY_SET_MARKER_) else: self.previous_cvar_token = None @@ -216,15 +214,14 @@ class BaseMetricsCollector(contextlib.AbstractContextManager): # type: ignore[m """ A context manager to handle metrics collection. - This context manager sets up a new `SourceHandler` for collecting metrics - in a module contextvar when entering the context. Upon exiting the context, - it resets the contextvar to its previous state and checks that the collected - metrics have a proper form. + This is a base class for creating metrics collectors that measure + specific metrics during the execution of a code block. It provides + a convenient interface for managing the lifecycle of metrics collection. - Note: - This is implemented as a context manager class instead of a generator - function with `@contextlib.contextmanager` to avoid the extra overhead - of renewing the generator inside `contextlib.contextmanager`. + Subclasses need to define the `level` and `metric_name` attributes, and, + optionally override the methods for collecting counters and computing + the metric. This class offers a simple way to customize this class variables + accepting them as keyword arguments when creating the subclass. """ def __init_subclass__( @@ -247,13 +244,16 @@ def __init_subclass__( if compute_metric is not None: cls.compute_metric = staticmethod(compute_metric) + # Subclass must define these class variables level: ClassVar[int] metric_name: ClassVar[str] + # Default implementations for these methods can be overridden by subclasses collect_enter_counter: ClassVar[Callable[[], float]] = staticmethod(time.perf_counter) collect_exit_counter: ClassVar[Callable[[], float]] = staticmethod(time.perf_counter) compute_metric: ClassVar[Callable[[float, float], float]] = staticmethod(operator.sub) + # Instance state key: str | None = None previous_cvar_token: contextvars.Token | None = dataclasses.field(init=False) enter_counter: float | None = dataclasses.field(init=False) @@ -288,6 +288,22 @@ def make_collector( collect_exit_counter: Callable[[], float] | None = None, compute_metric: Callable[[float, float], float] | None = None, ) -> type[BaseMetricsCollector]: + """ + Create a custom metrics collector class. + + This function generates a new subclass of `BaseMetricsCollector` with + the specified configuration for metrics collection. + + Args: + level: The metrics collection level. + metric_name: The name of the metric to be collected. + collect_enter_counter: Optional function to collect the enter counter. + collect_exit_counter: Optional function to collect the exit counter. + compute_metric: Optional function to compute the metric from the counters. + + Returns: + A new subclass of `BaseMetricsCollector` configured with the provided parameters. + """ collector_kwds = dict(level=level, metric_name=metric_name) | { key: value for key in ["collect_enter_counter", "collect_exit_counter", "compute_metric"] From 54e8910b7e45def1bcf9b6da93db0692f329678d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 21 Jan 2026 15:36:23 +0100 Subject: [PATCH 17/23] More docs cleanups --- src/gt4py/next/instrumentation/metrics.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index fa03d9529d..2d00467646 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -156,7 +156,17 @@ def get_current_source_key() -> str: def set_current_source_key(key: str) -> Source: - """Set the current source key for metrics collection.""" + """ + Set the current source key for metrics collection. + + It must be called only when no source key is set (or the same key is already set). + + Args: + key: The source key to set. + + Returns: + The `Source` object associated with the given key. + """ assert _source_key_cvar.get(_NO_KEY_SET_MARKER_) in {key, _NO_KEY_SET_MARKER_}, ( "A different source key has been already set." ) From 7e1a290ed697e0eb703e59296fe789037c2aaf2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 21 Jan 2026 17:00:28 +0100 Subject: [PATCH 18/23] Final cleanups and more tests --- src/gt4py/next/ffront/decorator.py | 2 +- src/gt4py/next/instrumentation/metrics.py | 13 +- .../instrumentation_tests/test_metrics.py | 241 +++++++++++++++++- 3 files changed, 247 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 359e88f343..be23252f89 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -289,7 +289,7 @@ def __call__( stacklevel=2, ) - # Metrics source key needs to be setup here, since embedded programs + # Metrics source key needs to be set here. Embedded programs # don't have variants so there's no other place to do it. if metrics.is_level_enabled(metrics.MINIMAL): # assert metrics_source is not None diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 2d00467646..65201bd3e9 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -314,16 +314,17 @@ def make_collector( Returns: A new subclass of `BaseMetricsCollector` configured with the provided parameters. """ - collector_kwds = dict(level=level, metric_name=metric_name) | { - key: value - for key in ["collect_enter_counter", "collect_exit_counter", "compute_metric"] - if (value := locals().get(key)) is not None - } return types.new_class( f"AutoMetricsCollectorFor_{metric_name}", bases=(BaseMetricsCollector,), - kwds=collector_kwds, + kwds=dict( + level=level, + metric_name=metric_name, + collect_enter_counter=collect_enter_counter, + collect_exit_counter=collect_exit_counter, + compute_metric=compute_metric, + ), exec_body=None, ) diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py index c75c97b00f..b6736d4562 100644 --- a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py @@ -8,18 +8,255 @@ import json import pathlib +import unittest.mock from collections.abc import Mapping from typing import Any import numpy as np import pytest +from gt4py.next import config from gt4py.next.instrumentation import metrics from gt4py.next.otf import arguments -# TODO(egparedes): add tests for the logic around creating sources lazily -# (SourceHandler class and the context manager). +class TestSetCurrentSourceKey: + def test_set_current_source_key_basic(self): + """Test setting a source key when none is currently set.""" + # Reset context variable before test + metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) + + key = "test_source" + source = metrics.set_current_source_key(key) + + assert metrics.get_current_source_key() == key + assert metrics.sources[key] == source + assert isinstance(source, metrics.Source) + + def test_set_current_source_key_same_key_twice(self): + """Test setting the same source key twice.""" + # Reset context variable before test + metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) + + key = "test_source_same" + source1 = metrics.set_current_source_key(key) + source2 = metrics.set_current_source_key(key) + + assert source1 is source2 + assert metrics.get_current_source_key() == key + + def test_set_current_source_key_different_key_raises(self): + """Test that setting a different source key raises AssertionError.""" + # Reset context variable before test + metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) + + key1 = "test_source_1" + key2 = "test_source_2" + + metrics.set_current_source_key(key1) + + with pytest.raises(AssertionError, match="A different source key has been already set"): + metrics.set_current_source_key(key2) + + +class TestSourceKeyContextManager: + def test_context_manager_sets_and_resets_key(self): + with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.MINIMAL): + metrics._source_key_cvar.set( + metrics._NO_KEY_SET_MARKER_ + ) # Reset context variable before test + assert metrics.is_current_source_key_set() is False + + key = "context_test_key" + with metrics.metrics_context(key): + assert metrics.is_current_source_key_set() is True + assert metrics._source_key_cvar.get() == key + assert metrics.get_current_source_key() == key + + # After exit, should be reset to marker + assert metrics.is_current_source_key_set() is False + assert ( + metrics._source_key_cvar.get(metrics._NO_KEY_SET_MARKER_) + == metrics._NO_KEY_SET_MARKER_ + ) + + def test_context_manager_with_no_key(self): + with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.MINIMAL): + metrics._source_key_cvar.set("__BEFORE__MARKER__") # Reset context variable before test + + with metrics.SourceKeyContextManager(): + # Should set to marker if no key is provided + assert ( + metrics._source_key_cvar.get(metrics._NO_KEY_SET_MARKER_) + == metrics._NO_KEY_SET_MARKER_ + ) + + # After exit, should be the previous value + assert metrics._source_key_cvar.get(metrics._NO_KEY_SET_MARKER_) == "__BEFORE__MARKER__" + + def test_context_manager_nested(self): + with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.MINIMAL): + metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) + key1 = "outer_key" + key2 = "inner_key" + + with metrics.SourceKeyContextManager(key=key1): + assert metrics.get_current_source_key() == key1 + with metrics.SourceKeyContextManager(key=key2): + assert metrics.get_current_source_key() == key2 + + # After inner exit, should restore to outer key + assert metrics.get_current_source_key() == key1 + + # After outer exit, should be marker + assert ( + metrics._source_key_cvar.get(metrics._NO_KEY_SET_MARKER_) + == metrics._NO_KEY_SET_MARKER_ + ) + + +class TestBaseMetricsCollector: + def test_collector_basic_timers(self): + """Test basic metrics collection with timing.""" + + class TestCollector( + metrics.BaseMetricsCollector, level=metrics.MINIMAL, metric_name="test_metric" + ): ... + + metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) + with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.MINIMAL): + outer_key = "outer_key" + metrics.set_current_source_key("outer_key") + assert metrics.get_current_source_key() == outer_key + + key = "test_collector" + with TestCollector(key=key): + assert metrics.get_current_source_key() == key + + assert metrics.get_current_source_key() == outer_key + + assert key in metrics.sources + source = metrics.sources[key] + assert "test_metric" in source.metrics + assert len(source.metrics["test_metric"].samples) == 1 + assert source.metrics["test_metric"].samples[0] >= 0 + + key = "test_disabled" + metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) + with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.DISABLED): + metrics.set_current_source_key(key) + + with TestCollector(key=key): + pass + + assert key not in metrics.sources or "test_metric" not in metrics.sources[key].metrics + + def test_collector_with_custom_counters(self): + """Test collector with custom counter functions.""" + + class CustomCollector( + metrics.BaseMetricsCollector, + level=metrics.PERFORMANCE, + metric_name="custom_metric", + collect_enter_counter=(lambda: 10.0), + collect_exit_counter=(lambda: 15.0), + ): ... + + key = "test_custom" + metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) + with unittest.mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics.PERFORMANCE): + with CustomCollector(key=key): + pass + + assert key in metrics.sources + source = metrics.sources[key] + assert "custom_metric" in source.metrics + assert len(source.metrics["custom_metric"].samples) == 1 + print(source.metrics["custom_metric"].samples[0]) + assert source.metrics["custom_metric"].samples[0] == 5.0 + + +class TestMakeCollector: + def test_make_collector_creates_subclass(self): + """Test that make_collector creates a proper subclass.""" + CollectorClass = metrics.make_collector(level=metrics.INFO, metric_name="test_metric") + + assert issubclass(CollectorClass, metrics.BaseMetricsCollector) + assert CollectorClass.level == metrics.INFO + assert CollectorClass.metric_name == "test_metric" + + def test_make_collector_with_custom_compute(self): + """Test make_collector with custom compute function.""" + custom_compute = lambda exit, enter: exit - enter + 10 + CollectorClass = metrics.make_collector( + level=metrics.MINIMAL, metric_name="custom_compute", compute_metric=custom_compute + ) + + assert CollectorClass.compute_metric(20, 5) == 25 + + +class TestMetric: + def test_metric_mean(self): + """Test metric mean calculation.""" + metric = metrics.Metric(name="test") + metric.add_sample(1.0) + metric.add_sample(2.0) + metric.add_sample(3.0) + + assert float(metric.mean) == 2.0 + + def test_metric_std(self): + """Test metric standard deviation calculation.""" + metric = metrics.Metric(name="test") + metric.add_sample(1.0) + metric.add_sample(2.0) + metric.add_sample(3.0) + + assert np.isclose(float(metric.std), 1.0) + + def test_metric_mean_empty_raises(self): + """Test that mean of empty metric raises ValueError.""" + metric = metrics.Metric(name="test") + + with pytest.raises(ValueError, match="Cannot compute mean"): + _ = metric.mean + + def test_metric_std_empty_raises(self): + """Test that std of empty metric raises ValueError.""" + metric = metrics.Metric(name="test") + + with pytest.raises(ValueError, match="Cannot compute std"): + _ = metric.std + + def test_metric_str_representation(self): + """Test metric string representation.""" + metric = metrics.Metric(name="test") + metric.add_sample(1.0) + metric.add_sample(3.0) + + str_repr = str(metric) + assert "e+" in str_repr or "e-" in str_repr + assert "+/-" in str_repr + + +class TestMetricsCollection: + def test_metrics_collection_auto_creates_metric(self): + """Test that MetricsCollection auto-creates Metric instances.""" + collection = metrics.MetricsCollection() + + metric = collection["new_metric"] + + assert isinstance(metric, metrics.Metric) + assert metric.name == "new_metric" + + def test_metrics_collection_persists_values(self): + """Test that MetricsCollection persists added values.""" + collection = metrics.MetricsCollection() + + collection["metric1"].add_sample(1.0) + collection["metric1"].add_sample(2.0) + + assert collection["metric1"].samples == [1.0, 2.0] @pytest.fixture From 1dbb1ee2814a7b1e3c2424709029d05d85d1ab8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 21 Jan 2026 17:21:55 +0100 Subject: [PATCH 19/23] Fix review comments --- src/gt4py/next/ffront/decorator.py | 1 - .../next_tests/unit_tests/instrumentation_tests/test_metrics.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index be23252f89..21a6c5b1fd 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -292,7 +292,6 @@ def __call__( # Metrics source key needs to be set here. Embedded programs # don't have variants so there's no other place to do it. if metrics.is_level_enabled(metrics.MINIMAL): - # assert metrics_source is not None metrics.set_current_source_key( f"{self.__name__}<{getattr(self.backend, 'name', '')}>" ) diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py index b6736d4562..fdb610ba38 100644 --- a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py @@ -15,7 +15,6 @@ import numpy as np import pytest -from gt4py.next import config from gt4py.next.instrumentation import metrics from gt4py.next.otf import arguments @@ -172,7 +171,6 @@ class CustomCollector( source = metrics.sources[key] assert "custom_metric" in source.metrics assert len(source.metrics["custom_metric"].samples) == 1 - print(source.metrics["custom_metric"].samples[0]) assert source.metrics["custom_metric"].samples[0] == 5.0 From 0916e39a3d4a66c73d7e0e42372092fba04439c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Thu, 22 Jan 2026 17:02:26 +0100 Subject: [PATCH 20/23] Improve docs --- src/gt4py/next/instrumentation/metrics.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 65201bd3e9..36b0f5567b 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -189,9 +189,10 @@ class SourceKeyContextManager(contextlib.AbstractContextManager): # type: ignor """ A context manager to handle metrics collection source keys. - When entering this context manager, it sets up a new source key for collection - of metrics in a module contextvar. Upon exiting the context, it resets the - contextvar to its previous state. + When entering this context manager it saves the current source key + for metrics collection and sets the new source key if provided, or + a default marker indicating no key is set. Upon exiting the context, + it resets the source key to its previous state. """ key: str | None = None @@ -262,6 +263,7 @@ def __init_subclass__( collect_enter_counter: ClassVar[Callable[[], float]] = staticmethod(time.perf_counter) collect_exit_counter: ClassVar[Callable[[], float]] = staticmethod(time.perf_counter) compute_metric: ClassVar[Callable[[float, float], float]] = staticmethod(operator.sub) + #: compute_metric(exit_counter, enter_counter) -> metric_value # Instance state key: str | None = None From 26ac37ac4bd635d7d66005818a9af2d92bd0e233 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Mon, 26 Jan 2026 12:38:56 +0100 Subject: [PATCH 21/23] Restore instrumentation hook files --- .../next/instrumentation/_hook_machinery.py | 194 +++++++++++++ src/gt4py/next/instrumentation/hooks.py | 15 + .../instrumentation_tests/__init__.py | 8 + .../instrumentation_tests/test_hooks.py | 167 +++++++++++ .../test_hook_machinery.py | 266 ++++++++++++++++++ 5 files changed, 650 insertions(+) create mode 100644 src/gt4py/next/instrumentation/_hook_machinery.py create mode 100644 src/gt4py/next/instrumentation/hooks.py create mode 100644 tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py create mode 100644 tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py create mode 100644 tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py diff --git a/src/gt4py/next/instrumentation/_hook_machinery.py b/src/gt4py/next/instrumentation/_hook_machinery.py new file mode 100644 index 0000000000..0c23145212 --- /dev/null +++ b/src/gt4py/next/instrumentation/_hook_machinery.py @@ -0,0 +1,194 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import ast +import collections.abc +import contextlib +import dataclasses +import inspect +import textwrap +import types +import typing +import warnings +from collections.abc import Callable +from typing import Generic, ParamSpec, TypeVar + + +P = ParamSpec("P") +T = TypeVar("T") + + +def _get_unique_name(func: Callable) -> str: + """Generate a unique name for a callable object.""" + return ( + f"{func.__module__}.{getattr(func, '__qualname__', func.__class__.__qualname__)}#{id(func)}" + ) + + +def _is_empty_function(func: Callable) -> bool: + """Check if a callable object is empty (i.e., contains no statements).""" + try: + assert callable(func) + callable_src = ( + inspect.getsource(func) + if isinstance(func, types.FunctionType) + else inspect.getsource(func.__call__) # type: ignore[operator] # asserted above + ) + callable_ast = ast.parse(textwrap.dedent(callable_src)) + return all( + isinstance(st, ast.Pass) + or (isinstance(st, ast.Expr) and isinstance(st.value, ast.Constant)) + for st in typing.cast(ast.FunctionDef, callable_ast.body[0]).body + ) + except Exception: + return False + + +@dataclasses.dataclass(slots=True) +class _BaseHook(Generic[T, P]): + """Base class to define callback registration functionality for all hook types.""" + + definition: Callable[P, T] + registry: dict[str, Callable[P, T]] = dataclasses.field(default_factory=dict, kw_only=True) + callbacks: tuple[Callable[P, T], ...] = dataclasses.field(default=(), init=False) + + if not typing.TYPE_CHECKING: + + @property + def __doc__(self) -> str | None: + return self.definition.__doc__ + + def __post_init__(self) -> None: + # As an optimization to avoid an empty function call if no callbacks are + # registered, we only add the original definitions to the list of callables + # if it contains a non-empty definition. + if not _is_empty_function(self.definition): + self.callbacks = (self.definition,) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + raise NotImplementedError("This method should be implemented by subclasses.") + + def register( + self, callback: Callable[P, T], *, name: str | None = None, index: int | None = None + ) -> None: + """ + Register a callback to the hook. + + Args: + callback: The callable to register. + name: An optional name for the callback. If not provided, a unique name will be generated. + index: An optional index at which to insert the callback (not counting the original + definition). If not provided, the callback will be appended to the end of the list. + """ + + callable_signature = inspect.signature(callback) + hook_signature = inspect.signature(self.definition) + + signature_mismatch = len(callable_signature.parameters) != len( + hook_signature.parameters + ) or any( + # Remove the annotation before comparison to avoid false mismatches + actual_param.replace(annotation="") != expected_param.replace(annotation="") + for actual_param, expected_param in zip( + callable_signature.parameters.values(), hook_signature.parameters.values() + ) + ) + if signature_mismatch: + raise ValueError( + f"Callback signature {callable_signature} does not match hook signature {hook_signature}" + ) + try: + callable_typing = typing.get_type_hints(callback) + hook_typing = typing.get_type_hints(self.definition) + if not all( + callable_typing[arg_key] == arg_typing + for arg_key, arg_typing in hook_typing.items() + ): + warnings.warn( + f"Callback annotations {callable_typing} does not match expected hook annotations {hook_typing}", + stacklevel=2, + ) + except Exception: + # Ignore issues while checking type hints (e.g., forward references + # or missing imports); failure here should not prevent hook registration. + pass + + name = name or _get_unique_name(callback) + + if index is None: + self.callbacks += (callback,) + else: + if self.callbacks and self.callbacks[0] is self.definition: + index += 1 # The original definition should always go first + self.callbacks = (*self.callbacks[:index], callback, *self.callbacks[index:]) + + self.registry[name] = callback + + def remove(self, callback: str | Callable[P, T]) -> None: + """ + Remove a registered callback from the hook. + + Args: + callback: The callable object to remove or its registered name. + """ + if isinstance(callback, str): + name = callback + if name not in self.registry: + raise KeyError(f"No callback registered under the name '{name}'") + else: + name = _get_unique_name(callback) + if name not in self.registry: + raise KeyError(f"Callback object {callback} not found in registry") + + callback = self.registry.pop(name) + assert callback in self.callbacks + self.callbacks = tuple(cb for cb in self.callbacks if cb is not callback) + + +@dataclasses.dataclass(slots=True) +class EventHook(_BaseHook[None, P]): + """Event hook specification.""" + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: + for func in self.callbacks: + func(*args, **kwargs) + + +@dataclasses.dataclass(slots=True) +class ContextHook( + contextlib.AbstractContextManager, _BaseHook[contextlib.AbstractContextManager, P] +): + """ + Context hook specification. + + This hook type is used to define context managers that can be stacked together. + """ + + ctx_managers: collections.abc.Sequence[contextlib.AbstractContextManager] = dataclasses.field( + default=(), init=False + ) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> contextlib.AbstractContextManager: + self.ctx_managers = [func(*args, **kwargs) for func in self.callbacks] + return self + + def __enter__(self) -> None: + for ctx_manager in self.ctx_managers: + ctx_manager.__enter__() + + def __exit__( + self, + type_: type[BaseException] | None, + value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + for ctx_manager in reversed(self.ctx_managers): + ctx_manager.__exit__(type_, value, traceback) + self.ctx_managers = () diff --git a/src/gt4py/next/instrumentation/hooks.py b/src/gt4py/next/instrumentation/hooks.py new file mode 100644 index 0000000000..b688e49ffd --- /dev/null +++ b/src/gt4py/next/instrumentation/hooks.py @@ -0,0 +1,15 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from gt4py.next.ffront.decorator import ( + embedded_program_call_hook as embedded_program_call_hook, + program_call_context as program_call_context, +) +from gt4py.next.otf.compiled_program import compile_variant_hook as compile_variant_hook diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py new file mode 100644 index 0000000000..abf4c3e24c --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py @@ -0,0 +1,8 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py new file mode 100644 index 0000000000..5c793d18c4 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py @@ -0,0 +1,167 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import contextlib +from typing import Any + +import pytest + +import gt4py.next as gtx +from gt4py.next import common, Dims, gtfn_cpu, typing as gtx_typing +from gt4py.next.instrumentation import hooks + +try: + from gt4py.next.program_processors.runners import dace as dace_backends + + BACKENDS = [None, gtfn_cpu, dace_backends.run_dace_cpu_cached] +except ImportError: + BACKENDS = [None, gtfn_cpu] + + +callback_results = [] +embedded_callback_results = [] + + +@contextlib.contextmanager +def custom_program_callback( + program: gtx_typing.Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + enable_jit: bool, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + callback_results.append(("enter", None)) + + yield + + callback_results.append( + ( + "custom_program_callback", + { + "program": program.__name__, + "args": args, + "offset_provider": offset_provider.keys(), + "enable_jit": enable_jit, + "kwargs": kwargs.keys(), + }, + ) + ) + + +@contextlib.contextmanager +def custom_embedded_program_callback( + program: gtx_typing.Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + embedded_callback_results.append(("enter", None)) + + yield + + embedded_callback_results.append( + ( + "custom_embedded_program_callback", + { + "program": program.__name__, + "args": args, + "offset_provider": offset_provider.keys(), + "kwargs": kwargs.keys(), + }, + ) + ) + + +# @hooks.compile_variant_hook.register +# def compile_variant_hook( +# key: tuple[tuple[Hashable, ...], int], +# backend: gtx_backend.Backend, +# program_definition: "ffront_stages.ProgramDefinition", +# compile_time_args: "arguments.CompileTimeArgs", +# ) -> None: +# """Callback hook invoked before compiling a program variant.""" +# ... + + +Cell = gtx.Dimension("Cell") +IDim = gtx.Dimension("IDim") + + +@gtx.field_operator +def identity_fop( + in_field: gtx.Field[Dims[IDim], gtx.float64], +) -> gtx.Field[Dims[IDim], gtx.float64]: + return in_field + + +@gtx.program +def copy_program( + in_field: gtx.Field[Dims[IDim], gtx.float64], out: gtx.Field[Dims[IDim], gtx.float64] +): + identity_fop(in_field, out=out) + + +@pytest.mark.parametrize("backend", BACKENDS, ids=lambda b: getattr(b, "name", str(b))) +def test_program_call_hooks(backend: gtx_typing.Backend): + size = 10 + in_field = gtx.full([(IDim, size)], 1, dtype=gtx.float64) + out_field = gtx.empty([(IDim, size)], dtype=gtx.float64) + + test_program = copy_program.with_backend(backend) + + # Run the program without hooks + callback_results.clear() + embedded_callback_results.clear() + test_program(in_field, out=out_field) + + # Callbacks should not have been called + assert callback_results == [] + callback_results.clear() + assert embedded_callback_results == [] + embedded_callback_results.clear() + + # Add hooks and run the program again + hooks.program_call_context.register(custom_program_callback) + hooks.embedded_program_call_hook.register(custom_embedded_program_callback) + test_program(in_field, out=out_field) + + # Check that the callbacks were called + assert len(callback_results) == 2 + assert callback_results[0] == ("enter", None) + + hook_name, hook_call_info = callback_results[1] + assert hook_name == "custom_program_callback" + assert hook_call_info["program"] == test_program.__name__ + + # The embedded program call hook should have also been called + # with the embedded backend + if backend is None: + assert len(embedded_callback_results) == 2 + assert embedded_callback_results[0] == ("enter", None) + + hook_name, hook_call_info = embedded_callback_results[1] + assert hook_name == "custom_embedded_program_callback" + assert hook_call_info["program"] == copy_program.__name__ + else: + assert len(embedded_callback_results) == 0 + + callback_results.clear() + embedded_callback_results.clear() + + # Remove hooks and call the program again + hooks.program_call_context.remove(custom_program_callback) + hooks.embedded_program_call_hook.remove(custom_embedded_program_callback) + test_program(in_field, out=out_field) + + # Callbacks should not have been called + assert callback_results == [] + callback_results.clear() + assert embedded_callback_results == [] + embedded_callback_results.clear() diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py new file mode 100644 index 0000000000..e8711e3913 --- /dev/null +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py @@ -0,0 +1,266 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import contextlib + +import pytest + +from gt4py.next.instrumentation._hook_machinery import ( + EventHook, + ContextHook, + _get_unique_name, + _is_empty_function, +) + + +def test_get_unique_name(): + def func1(): + pass + + def func2(): + pass + + assert _get_unique_name(func1) != _get_unique_name(func2) + + class A: + def __call__(self): ... + + assert _get_unique_name(A) == _get_unique_name(A) + + a1, a2 = A(), A() + + assert (a1_name := _get_unique_name(a1)) != (a2_name := _get_unique_name(a2)) + assert _get_unique_name(a1) == a1_name + assert _get_unique_name(a2) == a2_name + + +def test_empty_function(): + def empty(): + pass + + assert _is_empty_function(empty) is True + + def non_empty(): + return 1 + + assert _is_empty_function(non_empty) is False + + def with_docstring(): + """This is a docstring.""" + + assert _is_empty_function(with_docstring) is True + + def with_ellipsis(): ... + + assert _is_empty_function(with_ellipsis) is True + + class A: + def __call__(self): ... + + assert _is_empty_function(A()) is True + + +class TestEventHook: + def test_event_hook_call_with_no_callbacks(self): + @EventHook + def hook(x: int) -> None: + pass + + hook(42) # Should not raise + + def test_event_hook_call_with_callbacks(self): + results = [] + + @EventHook + def hook(x: int) -> None: + pass + + def callback1(x: int) -> None: + results.append(x) + + def callback2(x: int) -> None: + results.append(x * 2) + + hook.register(callback1) + hook.register(callback2) + hook(5) + + assert results == [5, 10] + + def test_event_hook_register_with_signature_mismatch(self): + @EventHook + def hook(x: int) -> None: + pass + + def bad_callback(x: int, y: int) -> None: + pass + + with pytest.raises(ValueError, match="Callback signature"): + hook.register(bad_callback) + + def test_event_hook_register_with_annotation_mismatch(self): + @EventHook + def hook(x: int) -> None: + pass + + def weird_callback(x: str) -> None: + pass + + with pytest.warns(UserWarning, match="Callback annotations"): + hook.register(weird_callback) + + def test_event_hook_register_with_name(self): + @EventHook + def hook(x: int) -> None: + pass + + def callback(x: int) -> None: + pass + + hook.register(callback, name="my_callback") + + assert "my_callback" in hook.registry + + def test_event_hook_register_with_index(self): + results = [] + + @EventHook + def hook(x: int) -> None: + pass + + def callback1(x: int) -> None: + results.append(1) + + def callback2(x: int) -> None: + results.append(2) + + hook.register(callback1) + hook.register(callback2, index=0) + hook(0) + + assert results == [2, 1] + + def test_event_hook_remove_by_name(self): + results = [] + + @EventHook + def hook(x: int) -> None: + pass + + def callback(x: int) -> None: + results.append(x) + + hook.register(callback, name="test_cb") + hook(42) + assert results == [42] + + hook.remove("test_cb") + results = [] + hook(42) + + assert results == [] + + def test_event_hook_remove_by_callback(self): + results = [] + + @EventHook + def hook(x: int) -> None: + pass + + def callback(x: int) -> None: + results.append(x) + + hook.register(callback) + hook(42) + assert results == [42] + + hook.remove(callback) + results = [] + hook(42) + + assert results == [] + + def test_event_hook_remove_nonexistent_raises(self): + @EventHook + def hook(x: int) -> None: + pass + + with pytest.raises(KeyError): + hook.remove("nonexistent") + + +class TestContextHook: + def test_context_hook_basic(self): + enter_called = [] + exit_called = [] + + @ContextHook + def hook() -> contextlib.AbstractContextManager: + pass + + @contextlib.contextmanager + def callback(): + enter_called.append(True) + yield + exit_called.append(True) + + hook.register(callback) + + with hook(): + assert len(enter_called) == 1 + + assert len(exit_called) == 1 + + def test_context_hook_multiple_callbacks(self): + order = [] + + @ContextHook + def hook() -> contextlib.AbstractContextManager: + pass + + @contextlib.contextmanager + def callback1(): + order.append("enter1") + yield + order.append("exit1") + + @contextlib.contextmanager + def callback2(): + order.append("enter2") + yield + order.append("exit2") + + hook.register(callback1) + hook.register(callback2) + + with hook(): + pass + + # Entry in order, but exit in reverse + assert order == ["enter1", "enter2", "exit2", "exit1"] + + def test_context_hook_with_arguments(self): + results = [] + + @ContextHook + def hook(x: int) -> contextlib.AbstractContextManager: + pass + + @contextlib.contextmanager + def callback(x: int): + results.append(x) + yield + + hook.register(callback) + + with hook(42): + pass + + assert results == [42] From 01a8d38757f6310e072a791309f2fe37d8dc53f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 30 Jan 2026 17:13:38 +0100 Subject: [PATCH 22/23] Use hooks for basic metrics --- src/gt4py/next/ffront/decorator.py | 47 ++++++-- .../{_hook_machinery.py => hook_machinery.py} | 8 +- src/gt4py/next/instrumentation/hooks.py | 5 +- src/gt4py/next/instrumentation/metrics.py | 4 +- src/gt4py/next/otf/compiled_program.py | 114 ++++++++++++------ .../test_hook_machinery.py | 2 +- 6 files changed, 125 insertions(+), 55 deletions(-) rename src/gt4py/next/instrumentation/{_hook_machinery.py => hook_machinery.py} (98%) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 21a6c5b1fd..02346d01b8 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -12,6 +12,7 @@ from __future__ import annotations +import contextlib import dataclasses import functools import types @@ -44,7 +45,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.gtcallable import GTCallable -from gt4py.next.instrumentation import metrics +from gt4py.next.instrumentation import hook_machinery, metrics from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, compiled_program, stages, toolchain from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -53,11 +54,37 @@ DEFAULT_BACKEND: next_backend.Backend | None = None -program_call_metrics_collector = metrics.make_collector( +ProgramCallMetricsCollector = metrics.make_collector( level=metrics.MINIMAL, metric_name=metrics.TOTAL_METRIC ) +@hook_machinery.ContextHook +def program_call_context( + program: Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + enable_jit: bool, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + """Hook called at the beginning and end of a program call.""" + return ProgramCallMetricsCollector() + + +@hook_machinery.EventHook +def embedded_program_call_hook( + program: Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + kwargs: dict[str, Any], +) -> None: + """Hook called at the beginning and end of an embedded program call.""" + # Metrics source key needs to be set here. Embedded programs + # don't have variants so there's no other place to do it. + if metrics.is_level_enabled(metrics.MINIMAL): + metrics.set_current_source_key(f"{program.__name__}<'')>") + + # TODO(tehrengruber): Decide if and how programs can call other programs. As a # result Program could become a GTCallable. @dataclasses.dataclass(frozen=True) @@ -267,7 +294,13 @@ def __call__( self.enable_jit if self.enable_jit is not None else config.ENABLE_JIT_DEFAULT ) - with program_call_metrics_collector(): + with program_call_context( + program=self, + args=args, + offset_provider=offset_provider, + enable_jit=enable_jit, + kwargs=kwargs, + ): if __debug__: # TODO: remove or make dependency on self.past_stage optional past_process_args._validate_args( @@ -289,14 +322,8 @@ def __call__( stacklevel=2, ) - # Metrics source key needs to be set here. Embedded programs - # don't have variants so there's no other place to do it. - if metrics.is_level_enabled(metrics.MINIMAL): - metrics.set_current_source_key( - f"{self.__name__}<{getattr(self.backend, 'name', '')}>" - ) - with next_embedded.context.update(offset_provider=offset_provider): + embedded_program_call_hook(self, args, offset_provider, kwargs) self.definition_stage.definition(*args, **kwargs) def compile( diff --git a/src/gt4py/next/instrumentation/_hook_machinery.py b/src/gt4py/next/instrumentation/hook_machinery.py similarity index 98% rename from src/gt4py/next/instrumentation/_hook_machinery.py rename to src/gt4py/next/instrumentation/hook_machinery.py index 0c23145212..f9d3be0a5a 100644 --- a/src/gt4py/next/instrumentation/_hook_machinery.py +++ b/src/gt4py/next/instrumentation/hook_machinery.py @@ -59,11 +59,9 @@ class _BaseHook(Generic[T, P]): registry: dict[str, Callable[P, T]] = dataclasses.field(default_factory=dict, kw_only=True) callbacks: tuple[Callable[P, T], ...] = dataclasses.field(default=(), init=False) - if not typing.TYPE_CHECKING: - - @property - def __doc__(self) -> str | None: - return self.definition.__doc__ + @property + def __doc__(self) -> str | None: # type: ignore[override] + return self.definition.__doc__ def __post_init__(self) -> None: # As an optimization to avoid an empty function call if no callbacks are diff --git a/src/gt4py/next/instrumentation/hooks.py b/src/gt4py/next/instrumentation/hooks.py index b688e49ffd..cba47021db 100644 --- a/src/gt4py/next/instrumentation/hooks.py +++ b/src/gt4py/next/instrumentation/hooks.py @@ -12,4 +12,7 @@ embedded_program_call_hook as embedded_program_call_hook, program_call_context as program_call_context, ) -from gt4py.next.otf.compiled_program import compile_variant_hook as compile_variant_hook +from gt4py.next.otf.compiled_program import ( + compiled_program_call_hook as compiled_program_call_hook, + compile_variant_hook as compile_variant_hook, +) diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 36b0f5567b..9d6dc01e1e 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -210,7 +210,7 @@ def __enter__(self) -> None: def __exit__( self, exc_type_: type[BaseException] | None, - value: BaseException | None, + exc_value: BaseException | None, traceback: types.TracebackType | None, ) -> None: if self.previous_cvar_token is not None: @@ -280,7 +280,7 @@ def __enter__(self) -> None: def __exit__( self, exc_type_: type[BaseException] | None, - value: BaseException | None, + exc_value: BaseException | None, traceback: types.TracebackType | None, ) -> None: if self.previous_cvar_token is not None: diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 4af678c607..3f4b65a3ac 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -19,7 +19,7 @@ from gt4py.eve import extended_typing as xtyping, utils as eve_utils from gt4py.next import backend as gtx_backend, common, config, errors, utils as gtx_utils from gt4py.next.ffront import stages as ffront_stages, type_specifications as ts_ffront -from gt4py.next.instrumentation import metrics +from gt4py.next.instrumentation import hook_machinery, metrics from gt4py.next.otf import arguments, stages from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -40,6 +40,59 @@ ArgumentDescriptorContext, ] + +def _make_pool_root( + program_definition: ffront_stages.ProgramDefinition, backend: gtx_backend.Backend +) -> tuple[str, str]: + return (program_definition.definition.__name__, backend.name) + + +@functools.cache +def _metrics_prefix_from_pool_root(root: tuple[str, str]) -> str: + """Generate a metrics prefix from a compiled programs pool root.""" + return f"{root[0]}<{root[1]}>" + + +@hook_machinery.EventHook +def compiled_program_call_hook( + compiled_program: stages.CompiledProgram, + args: tuple[Any, ...], + kwargs: dict[str, Any], + offset_provider: common.OffsetProvider, + root: tuple[str, str], + key: CompiledProgramsKey, +) -> None: + """Callback hook invoked before compiling a program variant.""" + if metrics.is_level_enabled(metrics.MINIMAL): + metrics.set_current_source_key(f"{_metrics_prefix_from_pool_root(root)}[{hash(key)}]") + + +@hook_machinery.EventHook +def compile_variant_hook( + program_definition: ffront_stages.ProgramDefinition, + backend: gtx_backend.Backend, + offset_provider: common.OffsetProviderType | common.OffsetProvider, + argument_descriptors: ArgumentDescriptors, + key: CompiledProgramsKey, +) -> None: + """Callback hook invoked before compiling a program variant.""" + + if metrics.is_level_enabled(metrics.MINIMAL): + # Create a new metrics entity for this compiled program + metrics_source = metrics.set_current_source_key( + f"{_metrics_prefix_from_pool_root(_make_pool_root(program_definition, backend))}[{hash(key)}]" + ) + metrics_source.metadata |= dict( + name=program_definition.definition.__name__, + backend=backend.name, + compiled_program_pool_key=hash(key), + **{ + f"{eve_utils.CaseStyleConverter.convert(key.__name__, 'pascal', 'snake')}s": value + for key, value in argument_descriptors.items() + }, + ) + + # TODO(havogt): We would like this to be a ProcessPoolExecutor, which requires (to decide what) to pickle. _async_compilation_pool: concurrent.futures.Executor | None = None @@ -230,8 +283,8 @@ class CompiledProgramsPool: ] = dataclasses.field(default_factory=dict, init=False) @functools.cached_property - def _primitive_values_extractor(self) -> Callable | None: - return arguments.make_primitive_value_args_extractor(self.program_type.definition) + def root(self) -> tuple[str, str]: + return _make_pool_root(self.definition_stage, self.backend) def __post_init__(self) -> None: # TODO(havogt): We currently don't support pos_only or kw_only args at the program level. @@ -262,20 +315,16 @@ def __call__( key = (static_args_values, common.hash_offset_provider_items_by_id(offset_provider)) try: - program = self.compiled_programs[key] - if metrics.is_level_enabled(metrics.MINIMAL): - metrics.set_current_source_key(self._metrics_key_from_pool_key(key)) - - program(*args, **kwargs, offset_provider=offset_provider) # type: ignore[operator] # the Future case is handled below - - except TypeError as e: - if "program" in locals() and isinstance(program, concurrent.futures.Future): - # 'Future' objects are not callable so they will generate a TypeError. - # Here we resolve the future and call it again. - program = self._resolve_future(key) - program(*args, **kwargs, offset_provider=offset_provider) - else: - raise e + compiled_program = self.compiled_programs[key] + if not callable(compiled_program): + # 'Future' objects are not callable so we know they need to be resolved. + assert isinstance(compiled_program, concurrent.futures.Future) + compiled_program = self._resolve_future(key) + + compiled_program_call_hook( + compiled_program, args, kwargs, offset_provider, self.root, key + ) + compiled_program(*args, **kwargs, offset_provider=offset_provider) # type: ignore[operator] # the Future case is handled below except KeyError as e: if enable_jit: @@ -296,14 +345,12 @@ def __call__( raise RuntimeError("No program compiled for this set of static arguments.") from e @functools.cached_property - def _args_canonicalizer(self) -> Callable[..., tuple[tuple, dict[str, Any]]]: - return gtx_utils.make_args_canonicalizer_for_function(self.definition_stage.definition) + def _primitive_values_extractor(self) -> Callable | None: + return arguments.make_primitive_value_args_extractor(self.program_type.definition) @functools.cached_property - def _metrics_key_from_pool_key(self) -> Callable[[CompiledProgramsKey], str]: - prefix = f"{self.definition_stage.definition.__name__}<{self.backend.name}>" - - return lambda key: f"{prefix}[{hash(key)}]" + def _args_canonicalizer(self) -> Callable[..., tuple[tuple, dict[str, Any]]]: + return gtx_utils.make_args_canonicalizer_for_function(self.definition_stage.definition) @functools.cached_property def _argument_descriptor_cache_key_from_args( @@ -418,19 +465,6 @@ def _compile_variant( if key in self.compiled_programs: raise ValueError(f"Program with key {key} already exists.") - # If we are collecting metrics, create a new metrics entity for this compiled program - if metrics.is_level_enabled(metrics.MINIMAL): - metrics_source = metrics.set_current_source_key(self._metrics_key_from_pool_key(key)) - metrics_source.metadata |= dict( - name=self.definition_stage.definition.__name__, - backend=self.backend.name, - compiled_program_pool_key=hash(key), - **{ - f"{eve_utils.CaseStyleConverter.convert(key.__name__, 'pascal', 'snake')}s": value - for key, value in argument_descriptors.items() - }, - ) - compile_time_args = arguments.CompileTimeArgs( offset_provider=offset_provider, column_axis=None, # TODO(havogt): column_axis seems to a unused, even for programs with scans @@ -442,6 +476,14 @@ def _compile_variant( compile_call = functools.partial( self.backend.compile, self.definition_stage, compile_time_args=compile_time_args ) + compile_variant_hook( + self.definition_stage, + self.backend, + offset_provider=offset_provider, + argument_descriptors=argument_descriptors, + key=key, + ) + if _async_compilation_pool is None: # synchronous compilation self.compiled_programs[key] = compile_call() diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py index e8711e3913..4d0c605f20 100644 --- a/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py @@ -12,7 +12,7 @@ import pytest -from gt4py.next.instrumentation._hook_machinery import ( +from gt4py.next.instrumentation.hook_machinery import ( EventHook, ContextHook, _get_unique_name, From 8e3777739524ca3c5b131f674083980b72729301 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 30 Jan 2026 17:42:53 +0100 Subject: [PATCH 23/23] pre-commit --- src/gt4py/next/instrumentation/hook_machinery.py | 2 +- src/gt4py/next/instrumentation/hooks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/instrumentation/hook_machinery.py b/src/gt4py/next/instrumentation/hook_machinery.py index f9d3be0a5a..0e41da3545 100644 --- a/src/gt4py/next/instrumentation/hook_machinery.py +++ b/src/gt4py/next/instrumentation/hook_machinery.py @@ -60,7 +60,7 @@ class _BaseHook(Generic[T, P]): callbacks: tuple[Callable[P, T], ...] = dataclasses.field(default=(), init=False) @property - def __doc__(self) -> str | None: # type: ignore[override] + def __doc__(self) -> str | None: # type: ignore[override] return self.definition.__doc__ def __post_init__(self) -> None: diff --git a/src/gt4py/next/instrumentation/hooks.py b/src/gt4py/next/instrumentation/hooks.py index cba47021db..ac5a70ebc8 100644 --- a/src/gt4py/next/instrumentation/hooks.py +++ b/src/gt4py/next/instrumentation/hooks.py @@ -13,6 +13,6 @@ program_call_context as program_call_context, ) from gt4py.next.otf.compiled_program import ( - compiled_program_call_hook as compiled_program_call_hook, compile_variant_hook as compile_variant_hook, + compiled_program_call_hook as compiled_program_call_hook, )