diff --git a/docs/user/next/advanced/ToolchainWalkthrough.md b/docs/user/next/advanced/ToolchainWalkthrough.md
deleted file mode 100644
index d730eed37e..0000000000
--- a/docs/user/next/advanced/ToolchainWalkthrough.md
+++ /dev/null
@@ -1,400 +0,0 @@
-```python
-import dataclasses
-import inspect
-import pprint
-
-import gt4py.next as gtx
-from gt4py.next import backend
-
-import devtools
-```
-
-
-
-
-```python
-I = gtx.Dimension("I")
-Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,))
-OFFSET_PROVIDER = {"Ioff": I}
-```
-
-# Toolchain Overview
-
-```mermaid
-graph LR
-
-fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]")
-foast -->|foast_to_itir| itir_expr(itir.Expr)
-foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]")
-past -->|past_lint| past
-past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]")
-tapast -->|past_to_itir| pcall(AOTProgram)
-
-pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]")
-```
-
-# Walkthrough from Field Operator
-
-## Starting Out
-
-```python
-@gtx.field_operator
-def example_fo(a: gtx.Field[[I], gtx.float64]) -> gtx.Field[[I], gtx.float64]:
- return a + 1.0
-```
-
-```python
-start = example_fo.definition_stage
-```
-
-```python
-gtx.ffront.stages.FieldOperatorDefinition?
-```
-
-## DSL -> FOAST
-
-```mermaid
-graph LR
-
-fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]")
-foast -->|foast_to_itir| itir_expr(itir.Expr)
-foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]")
-past -->|past_lint| past
-past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]")
-tapast -->|past_to_itir| pcall(AOTProgram)
-
-pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]")
-
-style fdef fill:red
-style foast fill:red
-linkStyle 0 stroke:red,stroke-width:4px,color:pink
-```
-
-```python
-foast = backend.DEFAULT_TRANSFORMS.func_to_foast(
- gtx.otf.toolchain.CompilableProgram(start, gtx.otf.arguments.CompileTimeArgs.empty())
-)
-```
-
-```python
-foast.data.__class__?
-```
-
-## FOAST -> ITIR
-
-This also happens inside the `decorator.FieldOperator.__gt_itir__` method during the lowering from calling Programs to ITIR
-
-```mermaid
-graph LR
-
-fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]")
-foast -->|foast_to_itir| itir_expr(itir.Expr)
-foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]")
-past -->|past_lint| past
-past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]")
-tapast -->|past_to_itir| pcall(AOTProgram)
-
-pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]")
-
-style foast fill:red
-style itir_expr fill:red
-linkStyle 1 stroke:red,stroke-width:4px,color:pink
-```
-
-```python
-fitir = backend.DEFAULT_TRANSFORMS.foast_to_itir(foast)
-```
-
-```python
-fitir.__class__
-```
-
-## FOAST with args -> PAST with args
-
-This auto-generates a program for us, directly in PAST representation and forwards the call arguments to it
-
-```mermaid
-graph LR
-
-fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]")
-foast -->|foast_to_itir| itir_expr(itir.Expr)
-foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]")
-past -->|past_lint| past
-past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]")
-tapast -->|past_to_itir| pcall(AOTProgram)
-
-pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]")
-
-style foast fill:red
-style past fill:red
-linkStyle 2 stroke:red,stroke-width:4px,color:pink
-```
-
-So far we have gotten away with empty compile time arguments, now we need to supply actual types. The easiest way to do that is from concrete arguments.
-
-```python
-jit_args = gtx.otf.arguments.JITArgs.from_signature(
- gtx.ones(domain={I: 10}, dtype=gtx.float64),
- out=gtx.zeros(domain={I: 10}, dtype=gtx.float64),
- offset_provider=OFFSET_PROVIDER,
-)
-
-aot_args = gtx.otf.arguments.CompileTimeArgs.from_concrete(*jit_args.args, **jit_args.kwargs)
-```
-
-```python
-pclos = backend.DEFAULT_TRANSFORMS.field_view_op_to_prog(
- gtx.otf.toolchain.CompilableProgram(data=foast.data, args=aot_args)
-)
-```
-
-```python
-pclos.data.__class__?
-```
-
-## Lint ProgramAST
-
-This checks the generated (or manually passed) PAST node.
-
-```mermaid
-graph LR
-
-fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]")
-foast -->|foast_to_itir| itir_expr(itir.Expr)
-foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]")
-past -->|past_lint| past
-past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]")
-tapast -->|past_to_itir| pcall(AOTProgram)
-
-pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]")
-
-style past fill:red
-%%style tapast fill:red
-linkStyle 3 stroke:red,stroke-width:4px,color:pink
-```
-
-```python
-linted = backend.DEFAULT_TRANSFORMS.past_lint(pclos)
-```
-
-## Transform PAST closure arguments
-
-This turns data arguments (or rather, their compile-time standins) passed as keyword args (allowed in DSL programs) into positional args (the only way supported by all compiled programs). Included in this is the 'out' argument which is automatically added when generating a fieldview program from a fieldview operator.
-
-```mermaid
-graph LR
-
-fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]")
-foast -->|foast_to_itir| itir_expr(itir.Expr)
-foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]")
-past -->|past_lint| past
-past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]")
-tapast -->|past_to_itir| pcall(AOTProgram)
-
-pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]")
-
-style past fill:red
-style tapast fill:red
-linkStyle 4 stroke:red,stroke-width:4px,color:pink
-```
-
-```python
-pclost = backend.DEFAULT_TRANSFORMS.field_view_prog_args_transform(pclos)
-```
-
-```python
-pprint.pprint(pclos.args)
-```
-
-```python
-pprint.pprint(pclost.args)
-```
-
-## Lower PAST -> ITIR
-
-still forwarding the call arguments
-
-```mermaid
-graph LR
-
-fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]")
-foast -->|foast_to_itir| itir_expr(itir.Expr)
-foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]")
-past -->|past_lint| past
-past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]")
-tapast -->|past_to_itir| pcall(AOTProgram)
-
-pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]")
-
-style tapast fill:red
-style pcall fill:red
-linkStyle 5 stroke:red,stroke-width:4px,color:pink
-```
-
-```python
-pitir = backend.DEFAULT_TRANSFORMS.past_to_itir(pclost)
-```
-
-```python
-pitir.__class__?
-```
-
-## Executing The Result
-
-```python
-pprint.pprint(jit_args)
-```
-
-```python
-gtx.program_processors.runners.roundtrip.Roundtrip()(pitir)(*jit_args.args, **jit_args.kwargs)
-```
-
-```python
-pprint.pprint(jit_args)
-```
-
-## Full Field Operator Toolchain
-
-using the default step order
-
-```mermaid
-graph LR
-
-fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]")
-foast -->|foast_to_itir| itir_expr(itir.Expr)
-foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]")
-past -->|past_lint| past
-past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]")
-tapast -->|past_to_itir| pcall(AOTProgram)
-
-pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]")
-
-style fdef fill:red
-style foast fill:red
-style past fill:red
-style tapast fill:red
-style pcall fill:red
-linkStyle 0,2,3,4,5 stroke:red,stroke-width:4px,color:pink
-```
-
-### Starting from DSL
-
-```python
-pitir2 = backend.DEFAULT_TRANSFORMS(gtx.otf.toolchain.CompilableProgram(data=start, args=aot_args))
-assert pitir2 == pitir
-```
-
-#### Pass The result to the compile workflow and execute
-
-```python
-example_compiled = gtx.program_processors.runners.roundtrip.Roundtrip()(pitir2)
-```
-
-```python
-example_compiled(*jit_args.args, **jit_args.kwargs)
-```
-
-We can re-run with the output from the previous run as in- and output.
-
-```python
-example_compiled(jit_args.kwargs["out"], *jit_args.args[1:], **jit_args.kwargs)
-```
-
-```python
-pprint.pprint(jit_args)
-```
-
-### Starting from FOAST
-
-Note that it is the exact same call but with a different input stage
-
-```python
-pitir3 = backend.DEFAULT_TRANSFORMS(
- gtx.otf.toolchain.CompilableProgram(data=foast.data, args=aot_args)
-)
-assert pitir3 == pitir
-```
-
-# Walkthrough starting from Program
-
-## Starting Out
-
-```python
-@gtx.program
-def example_prog(a: gtx.Field[[I], gtx.float64], out: gtx.Field[[I], gtx.float64]) -> None:
- example_fo(a, out=out)
-```
-
-```python
-p_start = example_prog.definition_stage
-```
-
-```python
-p_start.__class__?
-```
-
-## DSL -> PAST
-
-```mermaid
-graph LR
-
-fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]")
-foast -->|foast_to_itir| itir_expr(itir.Expr)
-foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]")
-past -->|past_lint| past
-past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]")
-tapast -->|past_to_itir| pcall(AOTProgram)
-
-pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]")
-
-style pdef fill:red
-style past fill:red
-linkStyle 6 stroke:red,stroke-width:4px,color:pink
-```
-
-```python
-p_past = backend.DEFAULT_TRANSFORMS.func_to_past(
- gtx.otf.toolchain.CompilableProgram(
- data=p_start, args=gtx.otf.arguments.CompileTimeArgs.empty()
- )
-)
-```
-
-## Full Program Toolchain
-
-```mermaid
-graph LR
-
-fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]")
-foast -->|foast_to_itir| itir_expr(itir.Expr)
-foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]")
-past -->|past_lint| past
-past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]")
-tapast -->|past_to_itir| pcall(AOTProgram)
-
-pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]")
-
-style pdef fill:red
-style past fill:red
-style tapast fill:red
-style pcall fill:red
-linkStyle 3,4,5,6 stroke:red,stroke-width:4px,color:pink
-```
-
-### Starting from DSL
-
-```python
-p_itir1 = backend.DEFAULT_TRANSFORMS(
- gtx.otf.toolchain.CompilableProgram(data=p_start, args=jit_args)
-)
-```
-
-```python
-p_itir2 = backend.DEFAULT_TRANSFORMS(
- gtx.otf.toolchain.CompilableProgram(data=p_past.data, args=aot_args)
-)
-```
-
-```python
-assert p_itir1 == p_itir2
-```
diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py
index 1af766f539..33fe77c4a3 100644
--- a/src/gt4py/next/backend.py
+++ b/src/gt4py/next/backend.py
@@ -10,7 +10,7 @@
import dataclasses
import typing
-from typing import Any, Generic
+from typing import Generic
from gt4py._core import definitions as core_defs
from gt4py.next import allocators as next_allocators
@@ -21,7 +21,6 @@
func_to_past,
past_process_args,
past_to_itir,
- signature,
)
from gt4py.next.ffront.past_passes import linters as past_linters
from gt4py.next.ffront.stages import (
@@ -144,28 +143,6 @@ class Backend(Generic[core_defs.DeviceTypeT]):
allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]
transforms: workflow.Workflow[CompilableDefinition, stages.CompilableProgram]
- def __call__(
- self,
- program: IRDefinitionForm,
- *args: Any,
- **kwargs: Any,
- ) -> None:
- if not isinstance(program, itir.Program):
- args, kwargs = signature.convert_to_positional(program, *args, **kwargs)
- # TODO(egparedes): this extraction is not strictly correct, as we should only
- # extract values from the correct container types, not from ANY container,
- # but that would require a larger refactoring and anyway this Backend class
- # should be removed in the future.
- extracted_args = tuple(arguments.extract(a) for a in args)
- extracted_kwargs = {k: arguments.extract(v) for k, v in kwargs.items()}
- self.jit(program, *args, **kwargs)(*extracted_args, **extracted_kwargs)
-
- def jit(self, program: IRDefinitionForm, *args: Any, **kwargs: Any) -> stages.CompiledProgram:
- if not isinstance(program, itir.Program):
- args, kwargs = signature.convert_to_positional(program, *args, **kwargs)
- aot_args = arguments.CompileTimeArgs.from_concrete(*args, **kwargs)
- return self.compile(program, aot_args)
-
def compile(
self, program: IRDefinitionForm, compile_time_args: arguments.CompileTimeArgs
) -> stages.CompiledProgram:
diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py
index 21a6c5b1fd..48c5b92ed5 100644
--- a/src/gt4py/next/ffront/decorator.py
+++ b/src/gt4py/next/ffront/decorator.py
@@ -12,23 +12,23 @@
from __future__ import annotations
+import abc
import dataclasses
import functools
import types
import typing
import warnings
-from collections.abc import Callable, Sequence
+from collections.abc import Callable
from typing import Any, Generic, Optional, TypeVar
from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
-from gt4py.eve.extended_typing import Self, override
+from gt4py.eve.extended_typing import Self, Unpack, override
from gt4py.next import (
allocators as next_allocators,
backend as next_backend,
common,
- config,
embedded as next_embedded,
errors,
utils,
@@ -38,20 +38,123 @@
field_operator_ast as foast,
foast_to_gtir,
past_process_args,
- signature,
stages as ffront_stages,
transform_utils,
+ type_info as ffront_type_info,
type_specifications as ts_ffront,
)
from gt4py.next.ffront.gtcallable import GTCallable
from gt4py.next.instrumentation import metrics
from gt4py.next.iterator import ir as itir
-from gt4py.next.otf import arguments, compiled_program, stages, toolchain
+from gt4py.next.otf import arguments, compiled_program, options, toolchain
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
DEFAULT_BACKEND: next_backend.Backend | None = None
+ProgramLikeDefinitionT = TypeVar(
+ "ProgramLikeDefinitionT", ffront_stages.ProgramDefinition, ffront_stages.FieldOperatorDefinition
+)
+
+
+@dataclasses.dataclass(frozen=True)
+class _ProgramLikeMixin(Generic[ProgramLikeDefinitionT]):
+ """
+ Mixing used by program and program-like objects.
+
+ Contains functionality and configuration options common to all kinds of program-likes.
+ """
+
+ definition_stage: ProgramLikeDefinitionT
+ backend: Optional[next_backend.Backend]
+ compilation_options: options.CompilationOptions
+
+ @abc.abstractmethod
+ def __gt_type__(self) -> ts.CallableType: ...
+
+ def with_backend(self, backend: next_backend.Backend) -> Self:
+ return dataclasses.replace(self, backend=backend)
+
+ def with_connectivities(
+ self,
+ connectivities: common.OffsetProvider, # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
+ ) -> Self:
+ return dataclasses.replace(
+ self,
+ compilation_options=dataclasses.replace(
+ self.compilation_options, connectivities=connectivities
+ ),
+ )
+
+ @functools.cached_property
+ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool:
+ if self.backend is None or self.backend == eve.NOTHING:
+ raise RuntimeError("Cannot compile a program without backend.")
+
+ if self.compilation_options.static_params is None:
+ object.__setattr__(self.compilation_options, "static_params", ())
+
+ argument_descriptor_mapping = {
+ arguments.StaticArg: self.compilation_options.static_params,
+ }
+
+ program_type = ffront_type_info.type_in_program_context(self.__gt_type__())
+ assert isinstance(program_type, ts_ffront.ProgramType)
+
+ return compiled_program.CompiledProgramsPool(
+ backend=self.backend,
+ definition_stage=self.definition_stage,
+ program_type=program_type,
+ argument_descriptor_mapping=argument_descriptor_mapping, # type: ignore[arg-type] # covariant `type[T]` not possible
+ )
+
+ def compile(
+ self,
+ offset_provider: common.OffsetProviderType
+ | common.OffsetProvider
+ | list[common.OffsetProviderType | common.OffsetProvider]
+ | None = None,
+ enable_jit: bool | None = None,
+ **static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]],
+ ) -> Self:
+ """
+ Compiles the program or operator for the given combination of static arguments and offset
+ provider type.
+
+ Note: Unlike `with_...` methods, this method does not return a new instance of the program,
+ but adds the compiled variants to the current program instance.
+ """
+ # TODO(havogt): we should reconsider if we want to return a new program on `compile` (and
+ # rename to `with_static_args` or similar) once we have a better understanding of the
+ # use-cases.
+
+ if enable_jit is not None:
+ object.__setattr__(self.compilation_options, "enable_jit", enable_jit)
+ if self.compilation_options.static_params is None:
+ object.__setattr__(self.compilation_options, "static_params", tuple(static_args.keys()))
+ if self.compilation_options.connectivities is None and offset_provider is None:
+ raise ValueError(
+ "Cannot compile a program without connectivities / OffsetProviderType."
+ )
+ if not all(isinstance(v, list) for v in static_args.values()):
+ raise TypeError(
+ "Please provide the static arguments as lists."
+ ) # To avoid confusion with tuple args
+
+ offset_provider = (
+ self.compilation_options.connectivities if offset_provider is None else offset_provider
+ )
+ if not isinstance(offset_provider, list):
+ offset_provider = [offset_provider] # type: ignore[list-item] # cleanup offset_provider vs offset_provider_type
+
+ assert all(
+ common.is_offset_provider(op) or common.is_offset_provider_type(op)
+ for op in offset_provider
+ )
+
+ self._compiled_programs.compile(offset_providers=offset_provider, **static_args)
+ return self
+
program_call_metrics_collector = metrics.make_collector(
level=metrics.MINIMAL, metric_name=metrics.TOTAL_METRIC
@@ -61,7 +164,7 @@
# TODO(tehrengruber): Decide if and how programs can call other programs. As a
# result Program could become a GTCallable.
@dataclasses.dataclass(frozen=True)
-class Program:
+class Program(_ProgramLikeMixin[ffront_stages.ProgramDefinition]):
"""
Construct a program object from a PAST node.
@@ -81,37 +184,25 @@ class Program:
i.e. DaCe programs that call GT4Py Programs -SDFGConvertible interface-.
"""
- definition_stage: ffront_stages.ProgramDefinition
- backend: Optional[next_backend.Backend]
- connectivities: Optional[
- common.OffsetProvider
- ] # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
- enable_jit: bool | None
- static_params: (
- Sequence[str] | None
- ) # if the user requests static params, they will be used later to initialize CompiledPrograms
-
@classmethod
def from_function(
cls,
definition: types.FunctionType,
backend: next_backend.Backend | None,
grid_type: common.GridType | None = None,
- enable_jit: bool | None = None,
- static_params: Sequence[str] | None = None,
- connectivities: Optional[
- common.OffsetProvider
- ] = None, # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
+ **compilation_options: Unpack[options.CompilationOptionsArgs],
) -> Program:
program_def = ffront_stages.ProgramDefinition(definition=definition, grid_type=grid_type)
return cls(
definition_stage=program_def,
backend=backend,
- connectivities=connectivities,
- enable_jit=enable_jit,
- static_params=static_params,
+ compilation_options=options.CompilationOptions(**compilation_options),
)
+ def __gt_type__(self) -> ts_ffront.ProgramType:
+ assert isinstance(self.past_stage.past_node.type, ts_ffront.ProgramType)
+ return self.past_stage.past_node.type
+
# TODO(ricoh): linting should become optional, up to the backend.
def __post_init__(self) -> None:
no_args_past = toolchain.CompilableProgram(
@@ -169,36 +260,6 @@ def gtir(self) -> itir.Program:
)
return self._frontend_transforms.past_to_itir(no_args_past).data
- @functools.cached_property
- def _compiled_programs(self) -> compiled_program.CompiledProgramsPool:
- if self.backend is None or self.backend == eve.NOTHING:
- raise RuntimeError("Cannot compile a program without backend.")
-
- if self.static_params is None:
- object.__setattr__(self, "static_params", ())
-
- argument_descriptor_mapping = {
- arguments.StaticArg: self.static_params,
- }
-
- program_type = self.past_stage.past_node.type
- assert isinstance(program_type, ts_ffront.ProgramType)
- return compiled_program.CompiledProgramsPool(
- backend=self.backend,
- definition_stage=self.definition_stage,
- program_type=program_type,
- argument_descriptor_mapping=argument_descriptor_mapping, # type: ignore[arg-type] # covariant `type[T]` not possible
- )
-
- def with_backend(self, backend: next_backend.Backend) -> Program:
- return dataclasses.replace(self, backend=backend)
-
- def with_connectivities(
- self,
- connectivities: common.OffsetProvider, # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime informatio
- ) -> Program:
- return dataclasses.replace(self, connectivities=connectivities)
-
def with_grid_type(self, grid_type: common.GridType) -> Program:
return dataclasses.replace(
self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type)
@@ -212,7 +273,9 @@ def with_static_params(self, *static_params: str | None) -> Program:
_static_params = typing.cast(tuple[str], static_params)
return dataclasses.replace(
self,
- static_params=_static_params,
+ compilation_options=dataclasses.replace(
+ self.compilation_options, static_params=_static_params
+ ),
)
def with_bound_args(self, **kwargs: Any) -> ProgramWithBoundArgs:
@@ -262,10 +325,7 @@ def __call__(
) -> None:
if offset_provider is None:
offset_provider = {}
- if enable_jit is None:
- enable_jit = (
- self.enable_jit if self.enable_jit is not None else config.ENABLE_JIT_DEFAULT
- )
+ enable_jit = self.compilation_options.enable_jit if enable_jit is None else enable_jit
with program_call_metrics_collector():
if __debug__:
@@ -299,105 +359,6 @@ def __call__(
with next_embedded.context.update(offset_provider=offset_provider):
self.definition_stage.definition(*args, **kwargs)
- def compile(
- self,
- offset_provider: common.OffsetProviderType
- | common.OffsetProvider
- | list[common.OffsetProviderType | common.OffsetProvider]
- | None = None,
- enable_jit: bool | None = None,
- **static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]],
- ) -> Self:
- """
- Compiles the program for the given combination of static arguments and offset provider type.
-
- Note: Unlike `with_...` methods, this method does not return a new instance of the program,
- but adds the compiled variants to the current program instance.
- """
- # TODO(havogt): we should reconsider if we want to return a new program on `compile` (and
- # rename to `with_static_args` or similar) once we have a better understanding of the
- # use-cases.
-
- if enable_jit is not None:
- object.__setattr__(self, "enable_jit", enable_jit)
- if self.static_params is None:
- object.__setattr__(self, "static_params", tuple(static_args.keys()))
- if self.connectivities is None and offset_provider is None:
- raise ValueError(
- "Cannot compile a program without connectivities / OffsetProviderType."
- )
- if not all(isinstance(v, list) for v in static_args.values()):
- raise TypeError(
- "Please provide the static arguments as lists."
- ) # To avoid confusion with tuple args
-
- offset_provider = self.connectivities if offset_provider is None else offset_provider
- if not isinstance(offset_provider, list):
- offset_provider = [offset_provider] # type: ignore[list-item] # cleanup offset_provider vs offset_provider_type
-
- assert all(
- common.is_offset_provider(op) or common.is_offset_provider_type(op)
- for op in offset_provider
- )
-
- self._compiled_programs.compile(offset_providers=offset_provider, **static_args)
- return self
-
- def freeze(self) -> FrozenProgram:
- if self.backend is None:
- raise ValueError("Can not freeze a program without backend (embedded execution).")
- return FrozenProgram(
- self.definition_stage if self.definition_stage else self.past_stage,
- backend=self.backend,
- )
-
-
-@dataclasses.dataclass(frozen=True)
-class FrozenProgram:
- """
- Simplified program instance, which skips the whole toolchain after the first execution.
-
- Does not work in embedded execution.
- """
-
- program: ffront_stages.DSL_PRG | ffront_stages.PAST_PRG
- backend: next_backend.Backend
- _compiled_program: Optional[stages.CompiledProgram] = dataclasses.field(
- init=False, default=None
- )
-
- def __post_init__(self) -> None:
- if self.backend is None:
- raise ValueError("Can not JIT-compile programs without backend (embedded execution).")
-
- @property
- def definition(self) -> types.FunctionType:
- # `PastProgramDefinition` doesn't have `definition`
- assert isinstance(self.program, ffront_stages.ProgramDefinition)
- return self.program.definition
-
- def with_backend(self, backend: next_backend.Backend) -> FrozenProgram:
- return self.__class__(program=self.program, backend=backend)
-
- def with_grid_type(self, grid_type: common.GridType) -> FrozenProgram:
- return self.__class__(
- program=dataclasses.replace(self.program, grid_type=grid_type), backend=self.backend
- )
-
- def jit(
- self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any
- ) -> stages.CompiledProgram:
- return self.backend.jit(self.program, *args, offset_provider=offset_provider, **kwargs)
-
- def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None:
- args, kwargs = signature.convert_to_positional(self.program, *args, **kwargs)
-
- if not self._compiled_program:
- super().__setattr__(
- "_compiled_program", self.jit(*args, offset_provider=offset_provider, **kwargs)
- )
- self._compiled_program(*args, offset_provider=offset_provider, **kwargs) # type: ignore[misc] # _compiled_program is not None
-
try:
from gt4py.next.program_processors.runners.dace.program import ( # type: ignore[assignment]
@@ -407,42 +368,6 @@ def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs:
pass
-# TODO(tehrengruber): This class does not follow the Liskov-Substitution principle as it doesn't
-# have a program definition. Revisit.
-@dataclasses.dataclass(frozen=True)
-class ProgramFromPast(Program):
- """
- This version of program has no DSL definition associated with it.
-
- PAST nodes can be built programmatically from field operators or from scratch.
- This wrapper provides the appropriate toolchain entry points.
- """
-
- past_stage: ffront_stages.PastProgramDefinition
-
- @override
- def __call__(
- self, *args: Any, offset_provider: Optional[common.OffsetProvider] = None, **kwargs: Any
- ) -> None:
- if self.backend is None:
- raise NotImplementedError(
- "Programs created from a PAST node (without a function definition) can not be executed in embedded mode"
- )
-
- if offset_provider is None:
- offset_provider = {}
- # TODO(ricoh): add test that does the equivalent of IDim + 1 in a ProgramFromPast
- self.backend(
- self.past_stage,
- *args,
- **(kwargs | {"offset_provider": {**offset_provider}}),
- )
-
- # TODO(ricoh): linting should become optional, up to the backend.
- def __post_init__(self) -> None:
- self._frontend_transforms.past_lint(self.past_stage) # type: ignore[arg-type] # ignored because the class has more TODO than code
-
-
@dataclasses.dataclass(frozen=True)
class ProgramWithBoundArgs(Program):
bound_args: dict[str, float | int | bool] = dataclasses.field(default_factory=dict)
@@ -525,9 +450,7 @@ def program(
*,
backend: next_backend.Backend | eve.NothingType | None,
grid_type: common.GridType | None,
- enable_jit: bool | None,
- static_params: Sequence[str] | None,
- frozen: bool,
+ **compilation_options: Unpack[options.CompilationOptionsArgs],
) -> Callable[[types.FunctionType], Program]: ...
@@ -537,10 +460,8 @@ def program(
# `NOTHING` -> default backend, `None` -> no backend (embedded execution)
backend: next_backend.Backend | eve.NothingType | None = eve.NOTHING,
grid_type: common.GridType | None = None,
- enable_jit: bool | None = None, # only relevant if static_params are set
- static_params: Sequence[str] | None = None,
- frozen: bool = False,
-) -> Program | FrozenProgram | Callable[[types.FunctionType], Program | FrozenProgram]:
+ **compilation_options: Unpack[options.CompilationOptionsArgs],
+) -> Program | Callable[[types.FunctionType], Program]:
"""
Generate an implementation of a program from a Python function object.
@@ -566,11 +487,8 @@ def program_inner(definition: types.FunctionType) -> Program:
next_backend.Backend | None, DEFAULT_BACKEND if backend is eve.NOTHING else backend
),
grid_type=grid_type,
- enable_jit=enable_jit,
- static_params=static_params,
+ **compilation_options,
)
- if frozen:
- return program.freeze() # type: ignore[return-value] # TODO(havogt): Should `FrozenProgram` be a `Program`?
return program
return program_inner if definition is None else program_inner(definition)
@@ -580,7 +498,9 @@ def program_inner(definition: types.FunctionType) -> Program:
@dataclasses.dataclass(frozen=True)
-class FieldOperator(GTCallable, Generic[OperatorNodeT]):
+class FieldOperator(
+ _ProgramLikeMixin[ffront_stages.FieldOperatorDefinition], GTCallable, Generic[OperatorNodeT]
+):
"""
Construct a field operator object from a FOAST node.
@@ -601,12 +521,6 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):
it will be deduced from actually occurring dimensions.
"""
- definition_stage: ffront_stages.FieldOperatorDefinition
- backend: Optional[next_backend.Backend]
- _program_cache: dict = dataclasses.field(
- init=False, default_factory=dict
- ) # init=False ensure the cache is not copied in calls to replace
-
@classmethod
def from_function(
cls,
@@ -616,6 +530,7 @@ def from_function(
*,
operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, # type: ignore[assignment] # TODO(ricoh): understand why mypy complains
operator_attributes: Optional[dict[str, Any]] = None,
+ **compilation_options: Unpack[options.CompilationOptionsArgs],
) -> FieldOperator[OperatorNodeT]:
return cls(
definition_stage=ffront_stages.FieldOperatorDefinition(
@@ -625,6 +540,7 @@ def from_function(
attributes=operator_attributes or {},
),
backend=backend,
+ compilation_options=options.CompilationOptions(**compilation_options),
)
# TODO(ricoh): linting should become optional, up to the backend.
@@ -662,9 +578,6 @@ def __gt_type__(self) -> ts.CallableType:
assert isinstance(type_, ts.CallableType)
return type_
- def with_backend(self, backend: next_backend.Backend) -> FieldOperator:
- return dataclasses.replace(self, backend=backend)
-
def with_grid_type(self, grid_type: common.GridType) -> FieldOperator:
return dataclasses.replace(
self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type)
@@ -683,27 +596,7 @@ def __gt_gtir__(self) -> itir.FunctionDefinition:
def __gt_closure_vars__(self) -> dict[str, Any]:
return self.foast_stage.closure_vars
- def as_program(self, compiletime_args: arguments.CompileTimeArgs) -> Program:
- foast_with_types = (
- toolchain.CompilableProgram(
- data=self.foast_stage,
- args=compiletime_args,
- ),
- )
-
- past_stage = self._frontend_transforms.field_view_op_to_prog.foast_to_past( # type: ignore[attr-defined] # TODO(havogt): needs more work
- foast_with_types
- ).data
- return ProgramFromPast(
- definition_stage=None, # type: ignore[arg-type] # ProgramFromPast needs to be fixed
- past_stage=past_stage,
- backend=self.backend,
- connectivities=None,
- enable_jit=False, # TODO(havogt): revisit ProgramFromPast
- static_params=None, # TODO(havogt): revisit ProgramFromPast
- )
-
- def __call__(self, *args: Any, **kwargs: Any) -> Any:
+ def __call__(self, *args: Any, enable_jit: bool | None = None, **kwargs: Any) -> Any:
if not next_embedded.context.within_valid_context() and self.backend is not None:
# non embedded execution
offset_provider = {**kwargs.pop("offset_provider", {})}
@@ -716,15 +609,14 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
domain = utils.tree_map(lambda _: domain)(out)
out = utils.tree_map(lambda f, dom: f[dom])(out, domain)
- args, kwargs = type_info.canonicalize_arguments(
- self.foast_stage.foast_node.type, args, kwargs
- )
- return self.backend(
- self.definition_stage,
+ return self._compiled_programs(
*args,
+ **kwargs,
out=out,
offset_provider=offset_provider,
- **kwargs,
+ enable_jit=self.compilation_options.enable_jit
+ if enable_jit is None
+ else enable_jit,
)
else:
if not next_embedded.context.within_valid_context():
@@ -752,6 +644,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
return embedded_operators.field_operator_call(op, args, kwargs)
+# TODO(tehrengruber): This class does not follow the Liskov-Substitution principle as it doesn't
+# have a field operator definition. Currently implementation is merely a hack to keep the only
+# test relying on this working. Revisit.
@dataclasses.dataclass(frozen=True)
class FieldOperatorFromFoast(FieldOperator):
"""
@@ -767,7 +662,10 @@ class FieldOperatorFromFoast(FieldOperator):
@override
def __call__(self, *args: Any, **kwargs: Any) -> Any:
assert self.backend is not None
- return self.backend(self.foast_stage, *args, **kwargs)
+ compiled_fo = self.backend.compile(
+ self.foast_stage, arguments.CompileTimeArgs.from_concrete(*args, **kwargs)
+ )
+ return compiled_fo(*args, **kwargs)
@typing.overload
@@ -790,6 +688,7 @@ def field_operator(
*,
backend: next_backend.Backend | eve.NothingType | None = eve.NOTHING,
grid_type: common.GridType | None = None,
+ **compilation_options: Unpack[options.CompilationOptionsArgs],
) -> (
FieldOperator[foast.FieldOperator]
| Callable[[types.FunctionType], FieldOperator[foast.FieldOperator]]
@@ -817,6 +716,7 @@ def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast.
next_backend.Backend | None, DEFAULT_BACKEND if backend is eve.NOTHING else backend
),
grid_type,
+ **compilation_options,
)
return field_operator_inner if definition is None else field_operator_inner(definition)
@@ -918,9 +818,3 @@ def add_foast_fieldop_to_fingerprint(
def add_program_to_fingerprint(obj: Program, hasher: xtyping.HashlibAlgorithm) -> None:
ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher)
ffront_stages.add_content_to_fingerprint(obj.backend, hasher)
-
-
-@ffront_stages.add_content_to_fingerprint.register
-def add_past_program_to_fingerprint(obj: ProgramFromPast, hasher: xtyping.HashlibAlgorithm) -> None:
- ffront_stages.add_content_to_fingerprint(obj.past_stage, hasher)
- ffront_stages.add_content_to_fingerprint(obj.backend, hasher)
diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py
index f16ad104bb..5e03e37b8d 100644
--- a/src/gt4py/next/ffront/foast_to_past.py
+++ b/src/gt4py/next/ffront/foast_to_past.py
@@ -14,6 +14,7 @@
foast_to_gtir,
program_ast as past,
stages as ffront_stages,
+ type_info as ffront_type_info,
type_specifications as ts_ffront,
)
from gt4py.next.ffront.past_passes import closure_var_type_deduction, type_deduction
@@ -70,9 +71,12 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]):
>>> op_to_prog = OperatorToProgram(foast_to_gtir.adapted_foast_to_gtir_factory())
>>> compile_time_args = arguments.CompileTimeArgs(
- ... args=tuple(param.type for param in copy.foast_stage.foast_node.definition.params),
+ ... args=(
+ ... *(param.type for param in copy.foast_stage.foast_node.definition.params),
+ ... copy.foast_stage.foast_node.definition.type.returns,
+ ... ),
... kwargs={},
- ... offset_provider={"I", IDim},
+ ... offset_provider={"I": IDim},
... column_axis=None,
... argument_descriptor_contexts={},
... )
@@ -94,20 +98,23 @@ def __call__(self, inp: AOT_FOP) -> AOT_PRG:
# of arg and kwarg types
# TODO(tehrengruber): check foast operator has no out argument that clashes
# with the out argument of the program we generate here.
-
arg_types, kwarg_types = inp.args.args, inp.args.kwargs
+ assert not kwarg_types
+ type_ = inp.data.foast_node.type
loc = inp.data.foast_node.location
- definition = inp.data.foast_node.type.definition
+ partial_program_type = ffront_type_info.type_in_program_context(inp.data.foast_node.type)
+ assert isinstance(partial_program_type, ts_ffront.ProgramType)
args_names = [
- *definition.pos_only_args,
- *definition.pos_or_kw_args.keys(),
- *definition.kw_only_args.keys(),
+ *partial_program_type.definition.pos_only_args,
+ *partial_program_type.definition.pos_or_kw_args.keys(),
+ *partial_program_type.definition.kw_only_args.keys(),
]
- if isinstance(inp.data.foast_node.type, ts_ffront.ScanOperatorType):
- args_names = args_names[1:] # carry argument is not in parameter list
+ assert arg_types[-1] == type_info.return_type(
+ type_, with_args=list(arg_types), with_kwargs=kwarg_types
+ )
+ assert args_names[-1] == "out"
- type_ = inp.data.foast_node.type
params_decl: list[past.Symbol] = [
past.DataSymbol(
id=name,
@@ -121,13 +128,7 @@ def __call__(self, inp: AOT_FOP) -> AOT_PRG:
strict=True,
)
]
- params_ref = [past.Name(id=pdecl.id, location=loc) for pdecl in params_decl]
- out_sym: past.Symbol = past.DataSymbol(
- id="out",
- type=type_info.return_type(type_, with_args=list(arg_types), with_kwargs=kwarg_types),
- namespace=dialect_ast_enums.Namespace.LOCAL,
- location=loc,
- )
+ params_ref = [past.Name(id=pdecl.id, location=loc) for pdecl in params_decl[:-1]]
out_ref = past.Name(id="out", location=loc)
if inp.data.foast_node.id in inp.data.closure_vars:
@@ -147,7 +148,7 @@ def __call__(self, inp: AOT_FOP) -> AOT_PRG:
untyped_past_node = past.Program(
id=f"__field_operator_{inp.data.foast_node.id}",
type=ts.DeferredType(constraint=ts_ffront.ProgramType),
- params=[*params_decl, out_sym],
+ params=params_decl,
body=[
past.Call(
func=past.Name(id=inp.data.foast_node.id, location=loc),
diff --git a/src/gt4py/next/ffront/signature.py b/src/gt4py/next/ffront/signature.py
deleted file mode 100644
index 4a58d56f57..0000000000
--- a/src/gt4py/next/ffront/signature.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# GT4Py - GridTools Framework
-#
-# Copyright (c) 2014-2024, ETH Zurich
-# All rights reserved.
-#
-# Please, refer to the LICENSE file in the root directory.
-# SPDX-License-Identifier: BSD-3-Clause
-
-# TODO(ricoh): This overlaps with `canonicalize_arguments`, solutions:
-# - merge the two
-# - extract the signature gathering functionality from canonicalize_arguments
-# and use it to pass the signature through the toolchain so that the
-# decorate step can take care of it. Then get rid of all pre-toolchain
-# arguments rearranging (including this module)
-
-from __future__ import annotations
-
-import functools
-import inspect
-import types
-from typing import Any, Callable
-
-from gt4py.next.ffront import (
- field_operator_ast as foast,
- program_ast as past,
- stages as ffront_stages,
-)
-from gt4py.next.type_system import type_specifications as ts
-
-
-def should_be_positional(param: inspect.Parameter) -> bool:
- return (param.kind is inspect.Parameter.POSITIONAL_ONLY) or (
- param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
- )
-
-
-@functools.singledispatch
-def make_signature(func: Any) -> inspect.Signature:
- """Make a signature for a Python or DSL callable, which suffices for use in 'convert_to_positional'."""
- if isinstance(func, types.FunctionType):
- return inspect.signature(func)
- raise NotImplementedError(f"'make_signature' not implemented for {type(func)}.")
-
-
-@make_signature.register(foast.ScanOperator)
-@make_signature.register(past.Program)
-@make_signature.register(foast.FieldOperator)
-def signature_from_fieldop(func: foast.FieldOperator) -> inspect.Signature:
- if isinstance(func.type, ts.DeferredType):
- raise NotImplementedError(
- f"'make_signature' not implemented for pre type deduction {type(func)}."
- )
- fieldview_signature = func.type.definition
- return inspect.Signature(
- parameters=[
- inspect.Parameter(name=str(i), kind=inspect.Parameter.POSITIONAL_ONLY)
- for i, param in enumerate(fieldview_signature.pos_only_args)
- ]
- + [
- inspect.Parameter(name=k, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
- for k in fieldview_signature.pos_or_kw_args
- ],
- )
-
-
-@make_signature.register(ffront_stages.FieldOperatorDefinition)
-def signature_from_fieldop_def(func: ffront_stages.FieldOperatorDefinition) -> inspect.Signature:
- signature = make_signature(func.definition)
- if func.node_class == foast.ScanOperator:
- return inspect.Signature(list(signature.parameters.values())[1:])
- return signature
-
-
-@make_signature.register(ffront_stages.ProgramDefinition)
-def signature_from_program_def(func: ffront_stages.ProgramDefinition) -> inspect.Signature:
- return make_signature(func.definition)
-
-
-@make_signature.register(ffront_stages.FoastOperatorDefinition)
-def signature_from_foast_stage(func: ffront_stages.FoastOperatorDefinition) -> inspect.Signature:
- return make_signature(func.foast_node)
-
-
-@make_signature.register
-def signature_from_past_stage(func: ffront_stages.PastProgramDefinition) -> inspect.Signature:
- return make_signature(func.past_node)
-
-
-def convert_to_positional(
- func: Callable
- | foast.FieldOperator
- | foast.ScanOperator
- | ffront_stages.FieldOperatorDefinition
- | ffront_stages.FoastOperatorDefinition
- | ffront_stages.ProgramDefinition
- | ffront_stages.PastProgramDefinition,
- *args: Any,
- **kwargs: Any,
-) -> tuple[tuple[Any, ...], dict[str, Any]]:
- """
- Convert arguments given as keyword args to positional ones where possible.
-
- Raises en error if and only if there are clearly missing positional arguments,
- Without awareness of the peculiarities of DSL function signatures. A more
- thorough check on whether the signature is fulfilled is expected to happen
- later in the toolchain.
-
- Note that positional-or-keyword arguments with defaults will have their defaults
- inserted even if not strictly necessary. This is to reduce complexity and should
- be changed if the current behavior is found harmful in some way.
-
- Examples:
- >>> def example(posonly, /, pos_or_key, pk_with_default=42, *, key_only=43):
- ... pass
-
- >>> convert_to_positional(example, 1, pos_or_key=2, key_only=3)
- ((1, 2, 42), {'key_only': 3})
- >>> # inserting the default value '42' here could be avoided
- >>> # but this is not the current behavior.
- """
- signature = make_signature(func)
- new_args = list(args)
- modified_kwargs = kwargs.copy()
- missing = []
- interesting_params = [p for p in signature.parameters.values() if should_be_positional(p)]
-
- for param in interesting_params[len(args) :]:
- if param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD and param.name in modified_kwargs:
- # if keyword allowed, check if was given as kwarg
- new_args.append(modified_kwargs.pop(param.name))
- else:
- # add default and report as missing if no default
- # note: this treats POSITIONAL_ONLY params correctly, as they can not have a default.
- new_args.append(param.default)
- if param.default is inspect._empty:
- missing.append(param.name)
- if missing:
- raise TypeError(f"Missing positional argument(s): {', '.join(missing)}.")
- return tuple(new_args), modified_kwargs
diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py
index af9af29eae..1f77224885 100644
--- a/src/gt4py/next/ffront/type_info.py
+++ b/src/gt4py/next/ffront/type_info.py
@@ -5,16 +5,17 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
-
import functools
+import inspect
from collections.abc import Callable
-from typing import Iterator, Sequence, cast
+from typing import Any, Iterator, Sequence, cast
import gt4py.next.ffront.type_specifications as ts_ffront
import gt4py.next.type_system.type_specifications as ts
from gt4py.eve import datamodels
from gt4py.eve.extended_typing import NestedTuple
from gt4py.next import common, utils
+from gt4py.next.ffront.type_specifications import ProgramType
from gt4py.next.type_system import type_info
@@ -349,3 +350,111 @@ def return_type_scanop(
ts.TypeSpec,
tree_map_type(lambda arg: ts.FieldType(dims=promoted_dims, dtype=arg))(carry_dtype),
)
+
+
+def type_in_program_context(callable_type: ts.CallableType) -> ProgramType | ts.FunctionType:
+ """
+ Return the type of a callable when encountered in context of a program.
+
+ A callable can be a field-, scan-operator or a simple function (though the latter is not
+ implemented in the frontent). The program context is either inside of a program or even
+ outside the GT4Py where all callables behave as if they were called from inside a program.
+
+ For example a simple field operator like
+
+ ```
+ @field_operator
+ def identity(a: IField) -> IField: ...
+ ```
+
+ has the signature of the following program in the context of a program.
+
+ ```
+ @program
+ def identity(a: IField, *, out: IField) -> None: ...
+ ```
+ """
+ if isinstance(callable_type, ts_ffront.FieldOperatorType):
+ definition = callable_type.definition
+ return ProgramType(
+ definition=ts.FunctionType(
+ pos_only_args=definition.pos_only_args,
+ pos_or_kw_args=definition.pos_or_kw_args | {"out": definition.returns},
+ kw_only_args=definition.kw_only_args,
+ returns=ts.VoidType(),
+ )
+ )
+ elif isinstance(callable_type, ts_ffront.ScanOperatorType):
+ as_deferred_type_with_same_structure = tree_map_type(
+ lambda _: ts.DeferredType(constraint=None)
+ )
+ scan_pass_type = callable_type.definition
+ _, *non_carry_args = scan_pass_type.pos_or_kw_args.items()
+ pos_or_kw_args = dict(non_carry_args) | {"out": scan_pass_type.returns}
+ assert not scan_pass_type.pos_only_args
+ return ProgramType(
+ ts.FunctionType(
+ pos_only_args=[],
+ # TODO(tehrengruber): What we actually want is a generic type here, but we don't
+ # have that concept yet.
+ pos_or_kw_args={
+ k: as_deferred_type_with_same_structure(t) for k, t in pos_or_kw_args.items()
+ },
+ kw_only_args={
+ k: as_deferred_type_with_same_structure(t)
+ for k, t in scan_pass_type.kw_only_args.items()
+ },
+ returns=ts.VoidType(),
+ )
+ )
+ assert isinstance(callable_type, (ts.FunctionType, ts_ffront.ProgramType))
+ return callable_type
+
+
+def _signature_from_callable_in_program_context(
+ callable_type: ts.CallableType,
+) -> inspect.Signature:
+ if isinstance(callable_type, ts_ffront.ProgramType):
+ return _signature_from_callable_in_program_context(callable_type.definition)
+ elif isinstance(callable_type, ts_ffront.FieldOperatorType | ts_ffront.ScanOperatorType):
+ operator_signature = _signature_from_callable_in_program_context(callable_type.definition)
+ params = list(operator_signature.parameters.values())
+ if isinstance(callable_type, ts_ffront.ScanOperatorType):
+ params = params[1:] # Remove the carry state arg
+ return inspect.Signature(
+ parameters=[*params, inspect.Parameter("out", inspect.Parameter.KEYWORD_ONLY)],
+ return_annotation=inspect.Signature.empty,
+ )
+ assert isinstance(callable_type, ts.FunctionType)
+ return inspect.Signature(
+ parameters=(
+ [
+ *(
+ inspect.Parameter(name=str(i), kind=inspect.Parameter.POSITIONAL_ONLY)
+ for i, type_ in enumerate(callable_type.pos_only_args)
+ ),
+ *(
+ inspect.Parameter(name=name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ for name, type_ in callable_type.pos_or_kw_args.items()
+ ),
+ *(
+ inspect.Parameter(name=name, kind=inspect.Parameter.KEYWORD_ONLY)
+ for name, type_ in callable_type.kw_only_args.items()
+ ),
+ ]
+ ),
+ return_annotation=callable_type.returns,
+ )
+
+
+def make_args_canonicalizer(
+ callable_type: ts.CallableType, **kwargs: Any
+) -> Callable[..., tuple[tuple, dict[str, Any]]]:
+ """
+ Create a call arguments canonicalizer function from a given signature.
+
+ See :ref:`utils.make_args_canonicalizer`.
+ """
+ return utils.make_args_canonicalizer(
+ _signature_from_callable_in_program_context(callable_type), **kwargs
+ )
diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py
index a995385b0f..18c4f9f897 100644
--- a/src/gt4py/next/iterator/runtime.py
+++ b/src/gt4py/next/iterator/runtime.py
@@ -20,6 +20,7 @@
from gt4py.next import common, config
from gt4py.next.iterator import builtins, dispatcher
from gt4py.next.iterator.builtins import BackendNotSelectedError, builtin_dispatch
+from gt4py.next.otf import arguments
from gt4py.next.program_processors import program_formatter
from gt4py.next.type_system import type_specifications as ts, type_translation as tt
@@ -88,8 +89,11 @@ def __call__(
if isinstance(backend, next_backend.Backend):
assert isinstance(backend, next_backend.Backend)
- compiled_program = backend.jit(
- itir_node, *args, offset_provider=offset_provider, column_axis=column_axis
+ compiled_program = backend.compile(
+ itir_node,
+ arguments.CompileTimeArgs.from_concrete(
+ *args, offset_provider=offset_provider, column_axis=column_axis
+ ),
)
compiled_program(*args, offset_provider=offset_provider)
elif isinstance(backend, program_formatter.ProgramFormatter):
diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py
index f4e8e0564b..301e06292f 100644
--- a/src/gt4py/next/iterator/tracing.py
+++ b/src/gt4py/next/iterator/tracing.py
@@ -14,7 +14,7 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import Node, utils as eve_utils
from gt4py.next import common, iterator
-from gt4py.next.iterator import builtins, ir as itir
+from gt4py.next.iterator import builtins, ir as itir, runtime as iterator_runtime
from gt4py.next.iterator.ir import (
AxisLiteral,
Expr,
@@ -100,7 +100,7 @@ def _s(id_):
def trace_function_argument(arg):
- if isinstance(arg, iterator.runtime.FundefDispatcher):
+ if isinstance(arg, iterator_runtime.FundefDispatcher):
make_function_definition(arg.fun)
return _s(arg.fun.__name__)
return arg
@@ -148,7 +148,7 @@ def make_node(o):
return lambdadef(o)
if hasattr(o, "__code__") and o.__code__.co_flags & inspect.CO_NESTED:
return lambdadef(o)
- if isinstance(o, iterator.runtime.Offset):
+ if isinstance(o, iterator_runtime.Offset):
return OffsetLiteral(value=o.value)
if isinstance(o, core_defs.Scalar):
return im.literal_from_value(o)
@@ -187,7 +187,7 @@ def make_function_definition(fun):
class FundefTracer:
- def __call__(self, fundef_dispatcher: iterator.runtime.FundefDispatcher):
+ def __call__(self, fundef_dispatcher: iterator_runtime.FundefDispatcher):
def fun(*args):
res = make_function_definition(fundef_dispatcher.fun)
return res(*args)
@@ -198,7 +198,7 @@ def __bool__(self):
return iterator.builtins.builtin_dispatch.key == TRACING
-iterator.runtime.FundefDispatcher.register_hook(FundefTracer())
+iterator_runtime.FundefDispatcher.register_hook(FundefTracer())
class TracerContext:
@@ -229,12 +229,12 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
iterator.builtins.builtin_dispatch.pop_key()
-@iterator.runtime.set_at.register(TRACING)
+@iterator_runtime.set_at.register(TRACING)
def set_at(expr: itir.Expr, domain: itir.Expr, target: itir.Expr) -> None:
TracerContext.add_stmt(itir.SetAt(expr=expr, domain=domain, target=target))
-@iterator.runtime.if_stmt.register(TRACING)
+@iterator_runtime.if_stmt.register(TRACING)
def if_stmt(
cond: itir.Expr, true_branch_f: typing.Callable, false_branch_f: typing.Callable
) -> None:
@@ -255,7 +255,7 @@ def if_stmt(
)
-@iterator.runtime.temporary.register(TRACING)
+@iterator_runtime.temporary.register(TRACING)
def temporary(
domain: itir.Expr,
dtype: Callable, # the gt4py type builtin
@@ -265,7 +265,7 @@ def temporary(
itir.Temporary(
id=id_,
domain=domain,
- dtype=iterator.runtime._dtypebuiltin_to_ts(dtype),
+ dtype=iterator_runtime._dtypebuiltin_to_ts(dtype),
)
)
return itir.SymRef(id=id_)
diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py
index 437dcbf93a..84c83ad34e 100644
--- a/src/gt4py/next/otf/arguments.py
+++ b/src/gt4py/next/otf/arguments.py
@@ -237,7 +237,11 @@ def make_primitive_value_args_extractor(
The returned function has the signature `(*args, **kwargs) -> (args, kwargs)`,
where `args` is a tuple of positional arguments and `kwargs` is a dictionary of
keyword arguments containing the extracted primitive values where needed.
+
+ This function only uses structural information of the type, primitive values may be
+ passed as :ref:`ts.DeferredType`.
"""
+
args_param = "args"
kwargs_param = "kwargs"
num_args_to_extract = 0
diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py
index 4af678c607..5e603b52ed 100644
--- a/src/gt4py/next/otf/compiled_program.py
+++ b/src/gt4py/next/otf/compiled_program.py
@@ -12,13 +12,19 @@
import dataclasses
import functools
import itertools
+import warnings
from collections.abc import Callable, Hashable, Sequence
from typing import Any, TypeAlias, TypeVar
from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping, utils as eve_utils
from gt4py.next import backend as gtx_backend, common, config, errors, utils as gtx_utils
-from gt4py.next.ffront import stages as ffront_stages, type_specifications as ts_ffront
+from gt4py.next.ffront import (
+ stages as ffront_stages,
+ type_info as ffront_type_info,
+ type_specifications as ts_ffront,
+ type_translation,
+)
from gt4py.next.instrumentation import metrics
from gt4py.next.otf import arguments, stages
from gt4py.next.type_system import type_info, type_specifications as ts
@@ -28,7 +34,7 @@
T = TypeVar("T")
ScalarOrTupleOfScalars: TypeAlias = xtyping.MaybeNestedInTuple[core_defs.Scalar]
-CompiledProgramsKey: TypeAlias = tuple[tuple[Hashable, ...], int]
+CompiledProgramsKey: TypeAlias = tuple[tuple[Hashable, ...], int, None | str]
ArgumentDescriptors: TypeAlias = dict[
type[arguments.ArgStaticDescriptor], dict[str, arguments.ArgStaticDescriptor]
]
@@ -216,7 +222,9 @@ class CompiledProgramsPool:
"""
backend: gtx_backend.Backend
- definition_stage: ffront_stages.ProgramDefinition
+ definition_stage: ffront_stages.ProgramDefinition | ffront_stages.FieldOperatorDefinition
+ # Note: This type can be incomplete, i.e. contain DeferredType, whenever the operator is a
+ # scan operator. In the future it could also be the type of a generic program.
program_type: ts_ffront.ProgramType
#: mapping from an argument descriptor type to a list of parameters or expression thereof
#: e.g. `{arguments.StaticArg: ["static_int_param"]}`
@@ -259,7 +267,31 @@ def __call__(
else:
args, kwargs = canonical_args, canonical_kwargs
static_args_values = self._argument_descriptor_cache_key_from_args(*args, **kwargs)
- key = (static_args_values, common.hash_offset_provider_items_by_id(offset_provider))
+
+ if self._is_generic:
+ # In case the program or operator is generic, i.e. callable for arguments of varying
+ # type, add the argument types to the cache key as the argument types are used during
+ # compilation. In case the program is not generic we can avoid the potentially
+ # expensive type deduction for all arguments and not include it in the key.
+ warnings.warn(
+ "Calling generic programs / direct calls to scan operators are not optimized. "
+ "Consider calling a specialized version instead.",
+ stacklevel=2,
+ )
+ arg_specialization_key = eve_utils.content_hash(
+ (
+ tuple(type_translation.from_value(arg) for arg in canonical_args),
+ {k: type_translation.from_value(v) for k, v in canonical_kwargs.items()},
+ )
+ )
+ else:
+ arg_specialization_key = None
+
+ key = (
+ static_args_values,
+ common.hash_offset_provider_items_by_id(offset_provider),
+ arg_specialization_key,
+ )
try:
program = self.compiled_programs[key]
@@ -284,6 +316,12 @@ def __call__(
argument_descriptors=_make_argument_descriptors(
self.program_type, self.argument_descriptor_mapping, args, kwargs
),
+ # note: it is important to use the args before named collections are extracted
+ # as otherwise the implicit program generation from an operator fails
+ arg_specialization_info=(
+ tuple(type_translation.from_value(arg) for arg in canonical_args),
+ {k: type_translation.from_value(v) for k, v in canonical_kwargs.items()},
+ ),
offset_provider=offset_provider,
call_key=key,
)
@@ -295,9 +333,31 @@ def __call__(
) # passing `enable_jit=False` because a cache miss should be a hard-error in this call`
raise RuntimeError("No program compiled for this set of static arguments.") from e
+ @functools.cached_property
+ def _is_generic(self) -> bool:
+ """
+ Is the operator or program generic in the sense that it can be called for different
+ argument types.
+
+ Right now this is only the case for scan operators.
+ """
+ # TODO(tehrengruber): This concept does not exist elsewhere and is not properly reflected
+ # in the type system. For now we just use `DeferredType` to communicate between
+ # here and `type_info.type_in_program_context`.
+ return any(
+ isinstance(t, ts.DeferredType)
+ for t in itertools.chain(
+ self.program_type.definition.pos_only_args,
+ self.program_type.definition.pos_or_kw_args.values(),
+ self.program_type.definition.kw_only_args.values(),
+ )
+ )
+
@functools.cached_property
def _args_canonicalizer(self) -> Callable[..., tuple[tuple, dict[str, Any]]]:
- return gtx_utils.make_args_canonicalizer_for_function(self.definition_stage.definition)
+ return ffront_type_info.make_args_canonicalizer(
+ self.program_type, name=self.definition_stage.definition.__name__
+ )
@functools.cached_property
def _metrics_key_from_pool_key(self) -> Callable[[CompiledProgramsKey], str]:
@@ -393,6 +453,11 @@ def _compile_variant(
self,
argument_descriptors: ArgumentDescriptors,
offset_provider: common.OffsetProviderType | common.OffsetProvider,
+ #: tuple consisting of the types of the positional and keyword arguments.
+ arg_specialization_info: tuple[tuple[ts.TypeSpec, ...], dict[str, ts.TypeSpec]]
+ | None = None,
+ # argument used only to validate key computed in a call / dispatch agrees with the
+ # key computed here
call_key: CompiledProgramsKey | None = None,
) -> None:
if not common.is_offset_provider(offset_provider):
@@ -412,6 +477,7 @@ def _compile_variant(
key = (
self._argument_descriptor_cache_key_from_descriptors(argument_descriptor_contexts),
common.hash_offset_provider_items_by_id(offset_provider),
+ eve_utils.content_hash(arg_specialization_info) if self._is_generic else None,
)
assert call_key is None or call_key == key
@@ -431,12 +497,24 @@ def _compile_variant(
},
)
+ if arg_specialization_info:
+ arg_types, kwarg_types = arg_specialization_info
+ else:
+ if self._is_generic:
+ raise ValueError(
+ "Can not precompile generic program or scan operator without argument types."
+ )
+ arg_types = (
+ *self.program_type.definition.pos_only_args,
+ *self.program_type.definition.pos_or_kw_args.values(),
+ )
+ kwarg_types = self.program_type.definition.kw_only_args
+
compile_time_args = arguments.CompileTimeArgs(
offset_provider=offset_provider,
column_axis=None, # TODO(havogt): column_axis seems to a unused, even for programs with scans
- args=tuple(self.program_type.definition.pos_only_args)
- + tuple(self.program_type.definition.pos_or_kw_args.values()),
- kwargs=self.program_type.definition.kw_only_args,
+ args=arg_types,
+ kwargs=kwarg_types,
argument_descriptor_contexts=argument_descriptor_contexts,
)
compile_call = functools.partial(
@@ -449,7 +527,7 @@ def _compile_variant(
self.compiled_programs[key] = _async_compilation_pool.submit(compile_call)
# TODO(tehrengruber): Rework the interface to allow precompilation with compile time
- # domains.
+ # domains and of scans.
def compile(
self,
offset_providers: list[common.OffsetProvider | common.OffsetProviderType],
diff --git a/src/gt4py/next/otf/options.py b/src/gt4py/next/otf/options.py
new file mode 100644
index 0000000000..303996e458
--- /dev/null
+++ b/src/gt4py/next/otf/options.py
@@ -0,0 +1,41 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2024, ETH Zurich
+# All rights reserved.
+#
+# Please, refer to the LICENSE file in the root directory.
+# SPDX-License-Identifier: BSD-3-Clause
+
+import dataclasses
+from typing import Sequence, TypedDict
+
+from gt4py.next import common, config
+
+
+class CompilationOptionsArgs(TypedDict, total=False):
+ enable_jit: bool
+ static_params: Sequence[str]
+ connectivities: common.OffsetProvider
+
+
+@dataclasses.dataclass(frozen=True)
+class CompilationOptions:
+ #: Enable Just-in-Time compilation, otherwise a program has to be compiled manually by a call
+ #: to `compile` before calling.
+ # Uses a factory to make changes to the config after module import time take effect. This is
+ # mostly important for testing. Users should not rely on it.
+ enable_jit: bool = dataclasses.field(default_factory=lambda: config.ENABLE_JIT_DEFAULT)
+
+ #: if the user requests static params, they will be used later to initialize CompiledPrograms
+ static_params: Sequence[str] | None = (
+ None # TODO: describe that this value will eventually be a sequence of strings
+ )
+
+ # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
+ #: A dictionary holding static/compile-time information about the offset providers.
+ #: For now, it is used for ahead of time compilation in DaCe orchestrated programs,
+ #: i.e. DaCe programs that call GT4Py Programs -SDFGConvertible interface-.
+ connectivities: common.OffsetProvider | None = None
+
+
+assert CompilationOptionsArgs.__annotations__.keys() == CompilationOptions.__annotations__.keys()
diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py
index 8b5496989f..bcb11953cf 100644
--- a/src/gt4py/next/program_processors/runners/dace/program.py
+++ b/src/gt4py/next/program_processors/runners/dace/program.py
@@ -39,7 +39,7 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG:
if (self.backend is None) or "dace" not in self.backend.name.lower():
raise ValueError("The SDFG can be generated only for the DaCe backend.")
- offset_provider: gtx_common.OffsetProvider = self.connectivities or {}
+ offset_provider: gtx_common.OffsetProvider = self.compilation_options.connectivities or {}
column_axis = kwargs.get("column_axis", None)
# TODO(ricoh): connectivity tables required here for now.
@@ -150,12 +150,12 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[
the offset providers are not part of GT4Py Program's arguments.
Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method.
"""
- if not self.connectivities:
+ if not self.compilation_options.connectivities:
return {}
used_connectivities: dict[str, gtx_common.NeighborConnectivity] = {
conn_id: conn
- for offset, conn in self.connectivities.items()
+ for offset, conn in self.compilation_options.connectivities.items()
if gtx_common.is_neighbor_table(conn)
and (conn_id := gtx_dace_args.connectivity_identifier(offset))
in self.sdfg_closure_cache["arrays"]
@@ -171,7 +171,9 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[
# Build the closure dictionary
closure_dict: dict[str, dace.data.Array] = {}
- offset_provider_type = gtx_common.offset_provider_to_type(self.connectivities)
+ offset_provider_type = gtx_common.offset_provider_to_type(
+ self.compilation_options.connectivities
+ )
for conn_id, conn in used_connectivities.items():
if conn_id not in self.connectivity_tables_data_descriptors:
self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array(
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py
index 8225974e0a..8437e71367 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py
@@ -16,12 +16,13 @@
from gt4py import next as gtx
from gt4py._core import definitions as core_defs
from gt4py.next import errors, config
-from gt4py.next.otf import compiled_program
+from gt4py.next.otf import compiled_program, options
from gt4py.next.ffront.decorator import Program
from gt4py.next.ffront.fbuiltins import int32, neighbor_sum
from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import (
+ KDim,
V2E,
cartesian_case,
mesh_descriptor,
@@ -39,9 +40,14 @@
_raise_on_compile.compile.side_effect = AssertionError("This function should never be called.")
-@pytest.fixture
-def compile_testee(cartesian_case):
- @gtx.field_operator
+@pytest.fixture(
+ params=[
+ pytest.param(True, id="program"),
+ pytest.param(False, id="field-operator"),
+ ]
+)
+def compile_testee(request, cartesian_case):
+ @gtx.field_operator(backend=cartesian_case.backend)
def testee_op(a: cases.IField, b: cases.IField) -> cases.IField:
return a + b
@@ -49,7 +55,11 @@ def testee_op(a: cases.IField, b: cases.IField) -> cases.IField:
def testee(a: cases.IField, b: cases.IField, out: cases.IField):
testee_op(a, b, out=out)
- return testee
+ wrap_in_program = request.param
+ if wrap_in_program:
+ return testee
+ else:
+ return testee_op
@pytest.fixture
@@ -65,9 +75,14 @@ def testee(a: cases.IField, b: cases.IField, out: cases.IField, isize: gtx.int32
return testee
-@pytest.fixture
-def compile_testee_scan(cartesian_case):
- @gtx.scan_operator(axis=cases.KDim, forward=True, init=0)
+@pytest.fixture(
+ params=[
+ pytest.param(True, id="program"),
+ pytest.param(False, id="scan-operator"),
+ ]
+)
+def compile_testee_scan(request, cartesian_case):
+ @gtx.scan_operator(axis=cases.KDim, forward=True, init=0, backend=cartesian_case.backend)
def testee_op(carry: gtx.int32, inp: gtx.int32) -> gtx.int32:
return carry + inp
@@ -75,7 +90,11 @@ def testee_op(carry: gtx.int32, inp: gtx.int32) -> gtx.int32:
def testee(a: cases.KField, out: cases.KField):
testee_op(a, out=out)
- return testee
+ wrap_in_program = request.param
+ if wrap_in_program:
+ return testee
+ else:
+ return testee_op
def test_compile(cartesian_case, compile_testee):
@@ -126,15 +145,20 @@ def test_compile_scan(cartesian_case, compile_testee_scan):
if cartesian_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")
+ if isinstance(compile_testee_scan, gtx.ffront.decorator.FieldOperator):
+ pytest.xfail(reason="Scan operators can not be precompiled yet.")
+
compile_testee_scan.compile(offset_provider=cartesian_case.offset_provider)
- args, kwargs = cases.get_default_data(cartesian_case, compile_testee_scan)
+ k_size = cartesian_case.default_sizes[KDim]
+ inp = cartesian_case.as_field([KDim], np.arange(k_size, dtype=np.int32))
+ out = cartesian_case.as_field([KDim], np.zeros(k_size, dtype=np.int32))
# make sure the backend is never called
object.__setattr__(compile_testee_scan, "backend", _raise_on_compile)
- compile_testee_scan(*args, offset_provider=cartesian_case.offset_provider, **kwargs)
- assert np.allclose(kwargs["out"].ndarray, np.cumsum(args[0].ndarray))
+ compile_testee_scan(inp, out=out, offset_provider=cartesian_case.offset_provider)
+ assert np.allclose(out.ndarray, np.cumsum(inp.ndarray))
def test_compile_domain(cartesian_case, compile_testee_domain):
@@ -376,7 +400,11 @@ def test_compile_variants(cartesian_case, compile_variants_testee):
# make sure the backend is never called
object.__setattr__(compile_variants_testee, "backend", _raise_on_compile)
- assert compile_variants_testee.static_params == ("scalar_int", "scalar_float", "scalar_bool")
+ assert compile_variants_testee.compilation_options.static_params == (
+ "scalar_int",
+ "scalar_float",
+ "scalar_bool",
+ )
field_a = cases.allocate(cartesian_case, compile_variants_testee, "field_a")()
field_b = cases.allocate(cartesian_case, compile_variants_testee, "field_b")()
@@ -430,7 +458,7 @@ def test_compile_variants_args_and_kwargs(cartesian_case, compile_variants_teste
def test_compile_variants_not_compiled(cartesian_case, compile_variants_testee):
- object.__setattr__(compile_variants_testee, "enable_jit", False)
+ object.__setattr__(compile_variants_testee.compilation_options, "enable_jit", False)
field_a = cases.allocate(cartesian_case, compile_variants_testee, "field_a")()
field_b = cases.allocate(cartesian_case, compile_variants_testee, "field_b")()
@@ -452,7 +480,7 @@ def test_compile_variants_not_compiled_but_jit_enabled_on_call(
cartesian_case, compile_variants_testee
):
# disable jit on the program
- object.__setattr__(compile_variants_testee, "enable_jit", False)
+ object.__setattr__(compile_variants_testee.compilation_options, "enable_jit", False)
field_a = cases.allocate(cartesian_case, compile_variants_testee, "field_a")()
field_b = cases.allocate(cartesian_case, compile_variants_testee, "field_b")()
@@ -484,36 +512,55 @@ def test_compile_variants_not_compiled_but_jit_enabled_on_call(
assert np.allclose(out[1].ndarray, field_b.ndarray - 4.0)
-def test_compile_variants_config_default_disable_jit(cartesian_case, compile_variants_testee):
+def test_compile_variants_config_default_disable_jit(cartesian_case):
"""
Checks that changing the config default will be picked up at call time.
"""
- field_a = cases.allocate(cartesian_case, compile_variants_testee, "field_a")()
- field_b = cases.allocate(cartesian_case, compile_variants_testee, "field_b")()
- out = cases.allocate(cartesian_case, compile_variants_testee, "out")()
+ # One of the 2 cases will be the non-default. The program has to be defined after the config
+ # has been altered in order for the value to take effect.
+ if cartesian_case.backend is None:
+ pytest.skip("Embedded compiled program doesn't make sense.")
- # One of the 2 cases will be the non-default.
with mock.patch.object(config, "ENABLE_JIT_DEFAULT", True):
- compile_variants_testee(
- field_a,
- int32(3), # variant does not exist
- 4.0,
- False,
- field_b,
+
+ @gtx.field_operator
+ def identity(a: cases.IField):
+ return a
+
+ @gtx.program(backend=cartesian_case.backend)
+ def testee(inp: cases.IField, out: cases.IField):
+ identity(inp, out=out)
+
+ assert testee.compilation_options.enable_jit == True
+
+ inp = cases.allocate(cartesian_case, testee, "inp")()
+ out = cases.allocate(cartesian_case, testee, "out")()
+
+ testee(
+ inp,
out=out,
offset_provider=cartesian_case.offset_provider,
)
- assert np.allclose(out[0].ndarray, field_a.ndarray - 3)
- assert np.allclose(out[1].ndarray, field_b.ndarray - 4.0)
+ assert np.allclose(out.ndarray, inp.ndarray)
with mock.patch.object(config, "ENABLE_JIT_DEFAULT", False):
with pytest.raises(RuntimeError):
- compile_variants_testee(
- field_a,
- int32(-42), # other value than before
- 4.0,
- False,
- field_b,
+
+ @gtx.field_operator
+ def identity(a: cases.IField):
+ return a
+
+ @gtx.program(backend=cartesian_case.backend)
+ def testee(inp: cases.IField, out: cases.IField):
+ identity(inp, out=out)
+
+ assert testee.compilation_options.enable_jit == False
+
+ inp = cases.allocate(cartesian_case, testee, "inp")()
+ out = cases.allocate(cartesian_case, testee, "out")()
+
+ testee(
+ inp,
out=out,
offset_provider=cartesian_case.offset_provider,
)
@@ -526,13 +573,13 @@ def test_compile_variants_not_compiled_then_reset_static_params(
This test ensures that after calling ".with_static_params(None)" the previously compiled programs are gone
and we can compile for the generic version.
"""
- object.__setattr__(compile_variants_testee, "enable_jit", True)
+ object.__setattr__(compile_variants_testee.compilation_options, "enable_jit", True)
field_a = cases.allocate(cartesian_case, compile_variants_testee, "field_a")()
field_b = cases.allocate(cartesian_case, compile_variants_testee, "field_b")()
# the compile_variants_testee has static_params set and is compiled (in a previous test)
- assert len(compile_variants_testee.static_params) > 0
+ assert len(compile_variants_testee.compilation_options.static_params) > 0
assert compile_variants_testee._compiled_programs is not None
# but now we reset the compiled programs
@@ -577,13 +624,13 @@ def test_compile_variants_not_compiled_then_set_new_static_params(
This test ensures that after calling `with_static_params("scalar_float", "scalar_bool")`
the previously compiled programs are gone and we can compile for the new `static_params`.
"""
- object.__setattr__(compile_variants_testee, "enable_jit", False)
+ object.__setattr__(compile_variants_testee.compilation_options, "enable_jit", False)
field_a = cases.allocate(cartesian_case, compile_variants_testee, "field_a")()
field_b = cases.allocate(cartesian_case, compile_variants_testee, "field_b")()
# the compile_variants_testee has static_params set and is compiled (in a previous test)
- assert len(compile_variants_testee.static_params) > 0
+ assert len(compile_variants_testee.compilation_options.static_params) > 0
assert compile_variants_testee._compiled_programs is not None
# but now we reset the compiled programs and fix to other static params
@@ -623,7 +670,7 @@ def test_compile_variants_not_compiled_then_set_new_static_params(
def test_compile_variants_jit(cartesian_case, compile_variants_testee):
- object.__setattr__(compile_variants_testee, "enable_jit", True)
+ object.__setattr__(compile_variants_testee.compilation_options, "enable_jit", True)
field_a = cases.allocate(cartesian_case, compile_variants_testee, "field_a")()
field_b = cases.allocate(cartesian_case, compile_variants_testee, "field_b")()
@@ -660,7 +707,7 @@ def test_compile_variants_jit(cartesian_case, compile_variants_testee):
def test_compile_variants_with_static_params_jit(
cartesian_case, compile_variants_testee_not_compiled
):
- object.__setattr__(compile_variants_testee_not_compiled, "enable_jit", True)
+ object.__setattr__(compile_variants_testee_not_compiled.compilation_options, "enable_jit", True)
testee_with_static_params = compile_variants_testee_not_compiled.with_static_params(
"scalar_int", "scalar_float", "scalar_bool"
)
@@ -839,7 +886,11 @@ def test_synchronous_compilation(cartesian_case, compile_testee):
a = cases.allocate(cartesian_case, compile_testee, "a")()
b = cases.allocate(cartesian_case, compile_testee, "b")()
- out = cases.allocate(cartesian_case, compile_testee, "out")()
+ if isinstance(compile_testee, gtx.ffront.decorator.FieldOperator):
+ out = cases.allocate(cartesian_case, compile_testee, cases.RETURN)()
+ else:
+ out = cases.allocate(cartesian_case, compile_testee, "out")()
+
compile_testee(
a,
b,
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py
index 72736191b4..4578576f02 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py
@@ -37,39 +37,6 @@ def testee(a: cases.IField, out: cases.IField):
assert isinstance(testee.with_backend(cartesian_case.backend).gtir, itir.Program)
-def test_frozen(cartesian_case):
- if cartesian_case.backend is None:
- pytest.skip("Frozen Program with embedded execution is not possible.")
-
- @gtx.field_operator
- def testee_op(a: cases.IField) -> cases.IField:
- return a
-
- @gtx.program(backend=cartesian_case.backend, frozen=True)
- def testee(a: cases.IField, out: cases.IField):
- testee_op(a, out=out)
-
- assert isinstance(testee, gtx.ffront.decorator.FrozenProgram)
-
- # first run should JIT compile
- args_1, kwargs_1 = cases.get_default_data(cartesian_case, testee)
- testee(*args_1, offset_provider=cartesian_case.offset_provider, **kwargs_1)
-
- # _compiled_program should be set after JIT compiling
- args_2, kwargs_2 = cases.get_default_data(cartesian_case, testee)
- testee._compiled_program(*args_2, offset_provider=cartesian_case.offset_provider, **kwargs_2)
-
- # and give expected results
- xp = args_1[0].array_ns
- assert xp.allclose(kwargs_2["out"].ndarray, args_2[0].ndarray)
-
- # with_backend returns a new instance, which is frozen but not compiled yet
- assert testee.with_backend(cartesian_case.backend)._compiled_program is None
-
- # with_grid_type returns a new instance, which is frozen but not compiled yet
- assert testee.with_grid_type(cartesian_case.grid_type)._compiled_program is None
-
-
@pytest.mark.parametrize(
"metrics_level,expected_names",
[
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py
index 172b936665..46b40085d5 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py
@@ -21,6 +21,7 @@
)
from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction
from gt4py.next import backend as next_backend
+from gt4py.next.otf import options
from gt4py.next.type_system import type_translation
from next_tests.integration_tests import cases
@@ -112,6 +113,7 @@ def make_builtin_field_operator(builtin_name: str, backend: Optional[next_backen
closure_vars=closure_vars,
grid_type=None,
),
+ compilation_options=options.CompilationOptions(),
backend=backend,
)
@@ -133,6 +135,6 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp
builtin_field_op = make_builtin_field_operator(builtin_name, cartesian_case.backend)
- builtin_field_op(*inps, out=out, offset_provider={})
+ builtin_field_op(*inps, out, offset_provider={})
assert np.allclose(out.asnumpy(), expected)
diff --git a/tests/next_tests/unit_tests/otf_tests/test_arguments.py b/tests/next_tests/unit_tests/otf_tests/test_arguments.py
index 54d30be5fc..8f256062af 100644
--- a/tests/next_tests/unit_tests/otf_tests/test_arguments.py
+++ b/tests/next_tests/unit_tests/otf_tests/test_arguments.py
@@ -195,3 +195,6 @@ def test_make_primitive_value_args_extractor_mixed_args(
args, kwargs = extractor(container1, 3.14, c=container2)
assert args == ((1.0, 2.0), 3.14)
assert kwargs == {"c": (3.0, 4.0)}
+
+
+# TODO: write test for scan