diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index c23fff9a9a..a3fc33ddcc 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -12,6 +12,7 @@ from __future__ import annotations +import contextlib import dataclasses import functools import time @@ -32,7 +33,6 @@ config, embedded as next_embedded, errors, - metrics, utils, ) from gt4py.next.embedded import operators as embedded_operators @@ -46,6 +46,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.gtcallable import GTCallable +from gt4py.next.instrumentation import _hook_machinery, metrics from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, compiled_program, stages, toolchain from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -54,6 +55,29 @@ DEFAULT_BACKEND: next_backend.Backend | None = None +@_hook_machinery.ContextHook +def program_call_hook( # type: ignore[empty-body] + program: Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + enable_jit: bool, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + """Hook called at the beginning and end of a program call.""" + ... + + +@_hook_machinery.ContextHook +def embedded_program_call_hook( # type: ignore[empty-body] + program: Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + """Hook called at the beginning and end of an embedded program call.""" + ... + + # TODO(tehrengruber): Decide if and how programs can call other programs. As a # result Program could become a GTCallable. @dataclasses.dataclass(frozen=True) @@ -275,27 +299,31 @@ def __call__( kwarg_types={k: type_translation.from_value(v) for k, v in kwargs.items()}, ) - if self.backend is not None: - self._compiled_programs( - *args, **kwargs, offset_provider=offset_provider, enable_jit=enable_jit - ) - else: - # Embedded execution. - # Metrics source key needs to be setup here, since embedded programs - # don't have variants and thus there's no other place we could do this. - if config.COLLECT_METRICS_LEVEL: - assert metrics_source is not None - metrics_source.key = ( - f"{self.__name__}<{getattr(self.backend, 'name', '')}>" + with program_call_hook(self, args, offset_provider, enable_jit, kwargs): + if self.backend is not None: + self._compiled_programs( + *args, **kwargs, offset_provider=offset_provider, enable_jit=enable_jit ) - warnings.warn( - UserWarning( - f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend." - ), - stacklevel=2, - ) - with next_embedded.context.update(offset_provider=offset_provider): - self.definition_stage.definition(*args, **kwargs) + else: + # Embedded execution. + warnings.warn( + UserWarning( + f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend." + ), + stacklevel=2, + ) + + # Metrics source key needs to be setup here, since embedded programs + # don't have variants so there's no other place to do it. + if config.COLLECT_METRICS_LEVEL: + assert metrics_source is not None + metrics.set_current_source_key( + f"{self.__name__}<{getattr(self.backend, 'name', '')}>" + ) + + with next_embedded.context.update(offset_provider=offset_provider): + with embedded_program_call_hook(self, args, offset_provider, kwargs): + self.definition_stage.definition(*args, **kwargs) if collect_info_metrics: assert metrics_source is not None diff --git a/src/gt4py/next/instrumentation/__init__.py b/src/gt4py/next/instrumentation/__init__.py new file mode 100644 index 0000000000..abf4c3e24c --- /dev/null +++ b/src/gt4py/next/instrumentation/__init__.py @@ -0,0 +1,8 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + diff --git a/src/gt4py/next/instrumentation/_hook_machinery.py b/src/gt4py/next/instrumentation/_hook_machinery.py new file mode 100644 index 0000000000..e1eb394a26 --- /dev/null +++ b/src/gt4py/next/instrumentation/_hook_machinery.py @@ -0,0 +1,188 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import ast +import collections.abc +import contextlib +import dataclasses +import inspect +import textwrap +import types +import typing +import warnings +from collections.abc import Callable +from typing import Generic, ParamSpec, TypeVar + + +P = ParamSpec("P") +T = TypeVar("T") + + +def _get_unique_name(func: Callable) -> str: + """Generate a unique name for a callable object.""" + return ( + f"{func.__module__}.{getattr(func, '__qualname__', func.__class__.__qualname__)}#{id(func)}" + ) + + +def _is_empty_function(func: Callable) -> bool: + """Check if a callable object is empty (i.e., contains no statements).""" + try: + assert callable(func) + callable_src = ( + inspect.getsource(func) + if isinstance(func, types.FunctionType) + else inspect.getsource(func.__call__) # type: ignore[operator] # asserted above + ) + callable_ast = ast.parse(textwrap.dedent(callable_src)) + return all( + isinstance(st, ast.Pass) + or (isinstance(st, ast.Expr) and isinstance(st.value, ast.Constant)) + for st in typing.cast(ast.FunctionDef, callable_ast.body[0]).body + ) + except Exception: + return False + + +@dataclasses.dataclass(slots=True) +class _BaseHook(Generic[T, P]): + """Base class to define callback registration functionality for all hook types.""" + + definition: Callable[P, T] + registry: dict[str, Callable[P, T]] = dataclasses.field(default_factory=dict, kw_only=True) + callbacks: tuple[Callable[P, T], ...] = dataclasses.field(default=(), init=False) + + def __post_init__(self) -> None: + # As an optimization to avoid an empty function call if no callbacks are + # registered, we only add the original definitions to the list of callables + # if it contains a non-empty definition. + if not _is_empty_function(self.definition): + self.callbacks = (self.definition,) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + raise NotImplementedError("This method should be implemented by subclasses.") + + def register( + self, callback: Callable[P, T], *, name: str | None = None, index: int | None = None + ) -> None: + """ + Register a callback to the hook. + + Args: + callback: The callable to register. + name: An optional name for the callback. If not provided, a unique name will be generated. + index: An optional index at which to insert the callback (not counting the original + definition). If not provided, the callback will be appended to the end of the list. + """ + + callable_signature = inspect.signature(callback) + hook_signature = inspect.signature(self.definition) + + signature_mismatch = len(callable_signature.parameters) != len( + hook_signature.parameters + ) or any( + # Remove the annotation before comparison to avoid false mismatches + actual_param.replace(annotation="") != expected_param.replace(annotation="") + for actual_param, expected_param in zip( + callable_signature.parameters.values(), hook_signature.parameters.values() + ) + ) + if signature_mismatch: + raise ValueError( + f"Callback signature {callable_signature} does not match hook signature {hook_signature}" + ) + try: + callable_typing = typing.get_type_hints(callback) + hook_typing = typing.get_type_hints(self.definition) + if not all( + callable_typing[arg_key] == arg_typing + for arg_key, arg_typing in hook_typing.items() + ): + warnings.warn( + f"Callback annotations {callable_typing} does not match expected hook annotations {hook_typing}", + stacklevel=2, + ) + except Exception: + # Ignore issues while checking type hints (e.g., forward references + # or missing imports); failure here should not prevent hook registration. + pass + + name = name or _get_unique_name(callback) + + if index is None: + self.callbacks += (callback,) + else: + if self.callbacks and self.callbacks[0] is self.definition: + index += 1 # The original definition should always go first + self.callbacks = (*self.callbacks[:index], callback, *self.callbacks[index:]) + + self.registry[name] = callback + + def remove(self, callback: str | Callable[P, T]) -> None: + """ + Remove a registered callback from the hook. + + Args: + callback: The callable object to remove or its registered name. + """ + if isinstance(callback, str): + name = callback + if name not in self.registry: + raise KeyError(f"No callback registered under the name '{name}'") + else: + name = _get_unique_name(callback) + if name not in self.registry: + raise KeyError(f"Callback object {callback} not found in registry") + + callback = self.registry.pop(name) + assert callback in self.callbacks + self.callbacks = tuple(cb for cb in self.callbacks if cb is not callback) + + +@dataclasses.dataclass(slots=True) +class EventHook(_BaseHook[None, P]): + """Event hook specification.""" + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: + for func in self.callbacks: + func(*args, **kwargs) + + +@dataclasses.dataclass(slots=True) +class ContextHook( + contextlib.AbstractContextManager, _BaseHook[contextlib.AbstractContextManager, P] +): + """ + Context hook specification. + + This hook type is used to define context managers that can be stacked together. + """ + + ctx_managers: collections.abc.Sequence[contextlib.AbstractContextManager] = dataclasses.field( + default=(), init=False + ) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> contextlib.AbstractContextManager: + self.ctx_managers = [func(*args, **kwargs) for func in self.callbacks] + return self + + def __enter__(self) -> None: + for ctx_manager in self.ctx_managers: + ctx_manager.__enter__() + + def __exit__( + self, + type_: type[BaseException] | None, + value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + for ctx_manager in reversed(self.ctx_managers): + ctx_manager.__exit__(type_, value, traceback) + self.ctx_managers = () diff --git a/src/gt4py/next/instrumentation/hooks.py b/src/gt4py/next/instrumentation/hooks.py new file mode 100644 index 0000000000..62fa26a7bd --- /dev/null +++ b/src/gt4py/next/instrumentation/hooks.py @@ -0,0 +1,15 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from gt4py.next.ffront.decorator import ( + embedded_program_call_hook as embedded_program_call_hook, + program_call_hook as program_call_hook, +) +from gt4py.next.otf.compiled_program import compile_variant_hook as compile_variant_hook diff --git a/src/gt4py/next/metrics.py b/src/gt4py/next/instrumentation/metrics.py similarity index 71% rename from src/gt4py/next/metrics.py rename to src/gt4py/next/instrumentation/metrics.py index b6a4beaeef..7043d44703 100644 --- a/src/gt4py/next/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -12,7 +12,6 @@ import contextlib import contextvars import dataclasses -import functools import itertools import json import numbers @@ -116,93 +115,40 @@ class Source: metadata: dict[str, Any] = dataclasses.field(default_factory=dict) metrics: MetricsCollection = dataclasses.field(default_factory=MetricsCollection) + assigned_key: str | None = dataclasses.field(default=None, init=False) + #: Global store for all measurements. sources: collections.defaultdict[str, Source] = collections.defaultdict(Source) - -class SourceHandler: - """ - A handler to manage addition of metrics sources to the global store. - - This object is used to collect metrics for a specific source (e.g., a program) - before a final key is assigned to it. The key is typically set when the program - is first executed or compiled, and it uniquely identifies the source in the - global metrics store. - """ - - def __init__(self, source: Source | None = None) -> None: - if source is not None: - self.source = source - - @property - def key(self) -> str | None: - return self.__dict__.get("_key", None) - - @key.setter - def key(self, value: str) -> None: - # The key can only be set once, and if it matches an existing source - # in the global store, it must be the same object. - if self.key is not None and self.key != value: - raise RuntimeError("Metrics source key is already set.") - - if value not in sources: - sources[value] = self.source - else: - source_in_store = sources[value] - if self.__dict__.setdefault("source", source_in_store) is not source_in_store: - raise RuntimeError("Conflicting metrics source data found in the global store.") - - self._key = value - - # The following attributes are implemented as `cached_properties` - # for efficiency and to be able to initialize them lazily when needed, - # even if the key is not set. - @functools.cached_property - def source(self) -> Source: - return Source() - - @functools.cached_property - def metrics(self) -> MetricsCollection: - return self.source.metrics - - @functools.cached_property - def metadata(self) -> dict[str, Any]: - return self.source.metadata +# Context variable storing the active collection context. +_source_cvar: contextvars.ContextVar[Source | None] = contextvars.ContextVar("source", default=None) -# Context variable storing the active collection context. -_source_cvar: contextvars.ContextVar[SourceHandler | None] = contextvars.ContextVar( - "source", default=None -) +def get_current_source() -> Source: + """Retrieve the active metrics collection source.""" + metrics_source = _source_cvar.get() + assert metrics_source is not None + return metrics_source -def in_collection_mode() -> bool: +def is_current_source_set() -> bool: """Check if there is an on-going metrics collection.""" return _source_cvar.get() is not None -def get_current_source() -> SourceHandler: - """Retrieve the active metrics collection source.""" - source_handler = _source_cvar.get() - assert source_handler is not None - return source_handler - +def set_current_source_key(key: str) -> Source: + if not is_current_source_set(): + raise RuntimeError("No active metrics collection to assign source to.") -def get_source(key: str, *, assign_current: bool = True) -> Source: - """ - Retrieve a metrics source by its key, optionally associating it to the current context. - """ - if in_collection_mode() and assign_current: - metrics_source_handler = get_current_source() - # Set the key if not already set, which will also add the - # source to the global store. Note that if the key is already set, - # this will only succeed if the same object. - metrics_source_handler.key = key - metrics_source = metrics_source_handler.source - else: - metrics_source = sources[key] + metrics_source = get_current_source() + if key in sources and metrics_source is not sources[key]: + # The key can only be set once, and if it matches an existing entry + # in the global store, then it must be exactly the same source object. + raise RuntimeError("Conflicting metrics source data found in the global store.") + sources[key] = metrics_source + metrics_source.assigned_key = key return metrics_source @@ -221,19 +167,19 @@ class CollectorContextManager(contextlib.AbstractContextManager): of renewing the generator inside `contextlib.contextmanager`. """ - __slots__ = ("previous_collector_token", "source_handler") + __slots__ = ("previous_collector_token", "source") - source_handler: SourceHandler | None + source: Source | None previous_collector_token: contextvars.Token | None - def __enter__(self) -> SourceHandler | None: + def __enter__(self) -> Source | None: if config.COLLECT_METRICS_LEVEL > 0: assert _source_cvar.get() is None - self.source_handler = SourceHandler() - self.previous_collector_token = _source_cvar.set(self.source_handler) - return self.source_handler + self.source = new_source = Source() + self.previous_collector_token = _source_cvar.set(new_source) + return new_source else: - self.source_handler = self.previous_collector_token = None + self.source = self.previous_collector_token = None return None def __exit__( @@ -244,13 +190,11 @@ def __exit__( ) -> None: if self.previous_collector_token is not None: _source_cvar.reset(self.previous_collector_token) - if type_ is None: - if self.source_handler is not None and self.source_handler.key is None: - raise RuntimeError("Metrics source key was not set during collection.") + if type_ is None and self.source is not None and self.source.assigned_key is None: + raise RuntimeError("Metrics source key was not set during collection.") -def collect() -> CollectorContextManager: - return CollectorContextManager() +collect = CollectorContextManager def dumps(metric_sources: Mapping[str, Source] | None = None) -> str: diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index d873e7da17..bbc0f17fb1 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -17,8 +17,9 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping, utils as eve_utils -from gt4py.next import backend as gtx_backend, common, config, errors, metrics, utils as gtx_utils +from gt4py.next import backend as gtx_backend, common, config, errors, utils as gtx_utils from gt4py.next.ffront import stages as ffront_stages, type_specifications as ts_ffront +from gt4py.next.instrumentation import _hook_machinery, metrics from gt4py.next.otf import arguments, stages from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -26,6 +27,19 @@ T = TypeVar("T") +ScalarOrTupleOfScalars: TypeAlias = xtyping.MaybeNestedInTuple[core_defs.Scalar] +CompiledProgramsKey: TypeAlias = tuple[tuple[Hashable, ...], int] +ArgumentDescriptors: TypeAlias = dict[ + type[arguments.ArgStaticDescriptor], dict[str, arguments.ArgStaticDescriptor] +] +ArgumentDescriptorContext: TypeAlias = dict[ + str, xtyping.MaybeNestedInTuple[arguments.ArgStaticDescriptor | None] +] +ArgumentDescriptorContexts: TypeAlias = dict[ + type[arguments.ArgStaticDescriptor], + ArgumentDescriptorContext, +] + # TODO(havogt): We would like this to be a ProcessPoolExecutor, which requires (to decide what) to pickle. _async_compilation_pool: concurrent.futures.Executor | None = None @@ -40,18 +54,16 @@ def _init_async_compilation_pool() -> None: _init_async_compilation_pool() -ScalarOrTupleOfScalars: TypeAlias = xtyping.MaybeNestedInTuple[core_defs.Scalar] -CompiledProgramsKey: TypeAlias = tuple[tuple[Hashable, ...], int] -ArgumentDescriptors: TypeAlias = dict[ - type[arguments.ArgStaticDescriptor], dict[str, arguments.ArgStaticDescriptor] -] -ArgumentDescriptorContext: TypeAlias = dict[ - str, xtyping.MaybeNestedInTuple[arguments.ArgStaticDescriptor | None] -] -ArgumentDescriptorContexts: TypeAlias = dict[ - type[arguments.ArgStaticDescriptor], - ArgumentDescriptorContext, -] + +@_hook_machinery.EventHook +def compile_variant_hook( + key: CompiledProgramsKey, + backend: gtx_backend.Backend, + program_definition: ffront_stages.ProgramDefinition, + compile_time_args: arguments.CompileTimeArgs, +) -> None: + """Callback hook invoked before compiling a program variant.""" + ... def wait_for_compilation() -> None: @@ -263,8 +275,7 @@ def __call__( try: program = self.compiled_programs[key] if config.COLLECT_METRICS_LEVEL: - metrics_source = metrics.get_current_source() - metrics_source.key = self._metrics_key_from_pool_key(key) + metrics.set_current_source_key(self._metrics_key_from_pool_key(key)) program(*args, **kwargs, offset_provider=offset_provider) # type: ignore[operator] # the Future case is handled below @@ -420,7 +431,7 @@ def _compile_variant( # If we are collecting metrics, create a new metrics entity for this compiled program if config.COLLECT_METRICS_LEVEL: - metrics_source = metrics.get_source(self._metrics_key_from_pool_key(key)) + metrics_source = metrics.set_current_source_key(self._metrics_key_from_pool_key(key)) metrics_source.metadata |= dict( name=self.definition_stage.definition.__name__, backend=self.backend.name, @@ -442,6 +453,7 @@ def _compile_variant( compile_call = functools.partial( self.backend.compile, self.definition_stage, compile_time_args=compile_time_args ) + compile_variant_hook(key, self.backend, self.definition_stage, compile_time_args) if _async_compilation_pool is None: # synchronous compilation self.compiled_programs[key] = compile_call() diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 996ba7a095..de3d25aaa5 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -14,7 +14,8 @@ import numpy as np from gt4py._core import definitions as core_defs -from gt4py.next import common as gtx_common, config, metrics, utils as gtx_utils +from gt4py.next import common as gtx_common, config, utils as gtx_utils +from gt4py.next.instrumentation import metrics from gt4py.next.otf import stages from gt4py.next.program_processors.runners.dace import sdfg_callable from gt4py.next.program_processors.runners.dace.workflow import ( diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index b98e0504fd..6e6b6b904b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -15,7 +15,8 @@ import factory from gt4py._core import definitions as core_defs -from gt4py.next import common, config, metrics +from gt4py.next import common, config +from gt4py.next.instrumentation import metrics from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import interface diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 7bd796785d..27363ea2a8 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -15,8 +15,9 @@ import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators from gt4py._core import filecache -from gt4py.next import backend, common, config, field_utils, metrics +from gt4py.next import backend, common, config, field_utils from gt4py.next.embedded import nd_array_field +from gt4py.next.instrumentation import metrics from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index f32867be7d..a6b6d69a5e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -13,8 +13,9 @@ import pytest from gt4py import next as gtx +from gt4py.next.instrumentation import metrics from gt4py.next.iterator import ir as itir -from gt4py.next import common as gtx_common, metrics +from gt4py.next import common as gtx_common from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -92,7 +93,9 @@ def testee(a: cases.IField, out: cases.IField): with ( mock.patch("gt4py.next.config.COLLECT_METRICS_LEVEL", metrics_level), - mock.patch("gt4py.next.metrics.sources", collections.defaultdict(metrics.Source)), + mock.patch( + "gt4py.next.instrumentation.metrics.sources", collections.defaultdict(metrics.Source) + ), ): testee = testee.with_backend(cartesian_case.backend).with_grid_type( cartesian_case.grid_type diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py new file mode 100644 index 0000000000..abf4c3e24c --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py @@ -0,0 +1,8 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py new file mode 100644 index 0000000000..e5a25b3cd3 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py @@ -0,0 +1,167 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import contextlib +from typing import Any + +import pytest + +import gt4py.next as gtx +from gt4py.next import common, Dims, gtfn_cpu, typing as gtx_typing +from gt4py.next.instrumentation import hooks + +try: + from gt4py.next.program_processors.runners import dace as dace_backends + + BACKENDS = [None, gtfn_cpu, dace_backends.run_dace_cpu_cached] +except ImportError: + BACKENDS = [None, gtfn_cpu] + + +callback_results = [] +embedded_callback_results = [] + + +@contextlib.contextmanager +def custom_program_callback( + program: gtx_typing.Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + enable_jit: bool, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + callback_results.append(("enter", None)) + + yield + + callback_results.append( + ( + "custom_program_callback", + { + "program": program.__name__, + "args": args, + "offset_provider": offset_provider.keys(), + "enable_jit": enable_jit, + "kwargs": kwargs.keys(), + }, + ) + ) + + +@contextlib.contextmanager +def custom_embedded_program_callback( + program: gtx_typing.Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + embedded_callback_results.append(("enter", None)) + + yield + + embedded_callback_results.append( + ( + "custom_embedded_program_callback", + { + "program": program.__name__, + "args": args, + "offset_provider": offset_provider.keys(), + "kwargs": kwargs.keys(), + }, + ) + ) + + +# @hooks.compile_variant_hook.register +# def compile_variant_hook( +# key: tuple[tuple[Hashable, ...], int], +# backend: gtx_backend.Backend, +# program_definition: "ffront_stages.ProgramDefinition", +# compile_time_args: "arguments.CompileTimeArgs", +# ) -> None: +# """Callback hook invoked before compiling a program variant.""" +# ... + + +Cell = gtx.Dimension("Cell") +IDim = gtx.Dimension("IDim") + + +@gtx.field_operator +def identity_fop( + in_field: gtx.Field[Dims[IDim], gtx.float64], +) -> gtx.Field[Dims[IDim], gtx.float64]: + return in_field + + +@gtx.program +def copy_program( + in_field: gtx.Field[Dims[IDim], gtx.float64], out: gtx.Field[Dims[IDim], gtx.float64] +): + identity_fop(in_field, out=out) + + +@pytest.mark.parametrize("backend", BACKENDS, ids=lambda b: getattr(b, "name", str(b))) +def test_program_call_hooks(backend: gtx_typing.Backend): + size = 10 + in_field = gtx.full([(IDim, size)], 1, dtype=gtx.float64) + out_field = gtx.empty([(IDim, size)], dtype=gtx.float64) + + test_program = copy_program.with_backend(backend) + + # Run the program without hooks + callback_results.clear() + embedded_callback_results.clear() + test_program(in_field, out=out_field) + + # Callbacks should not have been called + assert callback_results == [] + callback_results.clear() + assert embedded_callback_results == [] + embedded_callback_results.clear() + + # Add hooks and run the program again + hooks.program_call_hook.register(custom_program_callback) + hooks.embedded_program_call_hook.register(custom_embedded_program_callback) + test_program(in_field, out=out_field) + + # Check that the callbacks were called + assert len(callback_results) == 2 + assert callback_results[0] == ("enter", None) + + hook_name, hook_call_info = callback_results[1] + assert hook_name == "custom_program_callback" + assert hook_call_info["program"] == test_program.__name__ + + # The embedded program call hook should have also been called + # with the embedded backend + if backend is None: + assert len(embedded_callback_results) == 2 + assert embedded_callback_results[0] == ("enter", None) + + hook_name, hook_call_info = embedded_callback_results[1] + assert hook_name == "custom_embedded_program_callback" + assert hook_call_info["program"] == copy_program.__name__ + else: + assert len(embedded_callback_results) == 0 + + callback_results.clear() + embedded_callback_results.clear() + + # Remove hooks and call the program again + hooks.program_call_hook.remove(custom_program_callback) + hooks.embedded_program_call_hook.remove(custom_embedded_program_callback) + test_program(in_field, out=out_field) + + # Callbacks should not have been called + assert callback_results == [] + callback_results.clear() + assert embedded_callback_results == [] + embedded_callback_results.clear() diff --git a/tests/next_tests/unit_tests/instrumentation_tests/__init__.py b/tests/next_tests/unit_tests/instrumentation_tests/__init__.py new file mode 100644 index 0000000000..abf4c3e24c --- /dev/null +++ b/tests/next_tests/unit_tests/instrumentation_tests/__init__.py @@ -0,0 +1,8 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py new file mode 100644 index 0000000000..e8711e3913 --- /dev/null +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py @@ -0,0 +1,266 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import contextlib + +import pytest + +from gt4py.next.instrumentation._hook_machinery import ( + EventHook, + ContextHook, + _get_unique_name, + _is_empty_function, +) + + +def test_get_unique_name(): + def func1(): + pass + + def func2(): + pass + + assert _get_unique_name(func1) != _get_unique_name(func2) + + class A: + def __call__(self): ... + + assert _get_unique_name(A) == _get_unique_name(A) + + a1, a2 = A(), A() + + assert (a1_name := _get_unique_name(a1)) != (a2_name := _get_unique_name(a2)) + assert _get_unique_name(a1) == a1_name + assert _get_unique_name(a2) == a2_name + + +def test_empty_function(): + def empty(): + pass + + assert _is_empty_function(empty) is True + + def non_empty(): + return 1 + + assert _is_empty_function(non_empty) is False + + def with_docstring(): + """This is a docstring.""" + + assert _is_empty_function(with_docstring) is True + + def with_ellipsis(): ... + + assert _is_empty_function(with_ellipsis) is True + + class A: + def __call__(self): ... + + assert _is_empty_function(A()) is True + + +class TestEventHook: + def test_event_hook_call_with_no_callbacks(self): + @EventHook + def hook(x: int) -> None: + pass + + hook(42) # Should not raise + + def test_event_hook_call_with_callbacks(self): + results = [] + + @EventHook + def hook(x: int) -> None: + pass + + def callback1(x: int) -> None: + results.append(x) + + def callback2(x: int) -> None: + results.append(x * 2) + + hook.register(callback1) + hook.register(callback2) + hook(5) + + assert results == [5, 10] + + def test_event_hook_register_with_signature_mismatch(self): + @EventHook + def hook(x: int) -> None: + pass + + def bad_callback(x: int, y: int) -> None: + pass + + with pytest.raises(ValueError, match="Callback signature"): + hook.register(bad_callback) + + def test_event_hook_register_with_annotation_mismatch(self): + @EventHook + def hook(x: int) -> None: + pass + + def weird_callback(x: str) -> None: + pass + + with pytest.warns(UserWarning, match="Callback annotations"): + hook.register(weird_callback) + + def test_event_hook_register_with_name(self): + @EventHook + def hook(x: int) -> None: + pass + + def callback(x: int) -> None: + pass + + hook.register(callback, name="my_callback") + + assert "my_callback" in hook.registry + + def test_event_hook_register_with_index(self): + results = [] + + @EventHook + def hook(x: int) -> None: + pass + + def callback1(x: int) -> None: + results.append(1) + + def callback2(x: int) -> None: + results.append(2) + + hook.register(callback1) + hook.register(callback2, index=0) + hook(0) + + assert results == [2, 1] + + def test_event_hook_remove_by_name(self): + results = [] + + @EventHook + def hook(x: int) -> None: + pass + + def callback(x: int) -> None: + results.append(x) + + hook.register(callback, name="test_cb") + hook(42) + assert results == [42] + + hook.remove("test_cb") + results = [] + hook(42) + + assert results == [] + + def test_event_hook_remove_by_callback(self): + results = [] + + @EventHook + def hook(x: int) -> None: + pass + + def callback(x: int) -> None: + results.append(x) + + hook.register(callback) + hook(42) + assert results == [42] + + hook.remove(callback) + results = [] + hook(42) + + assert results == [] + + def test_event_hook_remove_nonexistent_raises(self): + @EventHook + def hook(x: int) -> None: + pass + + with pytest.raises(KeyError): + hook.remove("nonexistent") + + +class TestContextHook: + def test_context_hook_basic(self): + enter_called = [] + exit_called = [] + + @ContextHook + def hook() -> contextlib.AbstractContextManager: + pass + + @contextlib.contextmanager + def callback(): + enter_called.append(True) + yield + exit_called.append(True) + + hook.register(callback) + + with hook(): + assert len(enter_called) == 1 + + assert len(exit_called) == 1 + + def test_context_hook_multiple_callbacks(self): + order = [] + + @ContextHook + def hook() -> contextlib.AbstractContextManager: + pass + + @contextlib.contextmanager + def callback1(): + order.append("enter1") + yield + order.append("exit1") + + @contextlib.contextmanager + def callback2(): + order.append("enter2") + yield + order.append("exit2") + + hook.register(callback1) + hook.register(callback2) + + with hook(): + pass + + # Entry in order, but exit in reverse + assert order == ["enter1", "enter2", "exit2", "exit1"] + + def test_context_hook_with_arguments(self): + results = [] + + @ContextHook + def hook(x: int) -> contextlib.AbstractContextManager: + pass + + @contextlib.contextmanager + def callback(x: int): + results.append(x) + yield + + hook.register(callback) + + with hook(42): + pass + + assert results == [42] diff --git a/tests/next_tests/unit_tests/test_metrics.py b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py similarity index 99% rename from tests/next_tests/unit_tests/test_metrics.py rename to tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py index 384435a091..c75c97b00f 100644 --- a/tests/next_tests/unit_tests/test_metrics.py +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py @@ -14,7 +14,7 @@ import numpy as np import pytest -from gt4py.next import metrics +from gt4py.next.instrumentation import metrics from gt4py.next.otf import arguments