diff --git a/src/gt4py/cartesian/definitions.py b/src/gt4py/cartesian/definitions.py index 4781faca22..d535bc29d3 100644 --- a/src/gt4py/cartesian/definitions.py +++ b/src/gt4py/cartesian/definitions.py @@ -10,6 +10,7 @@ import functools import os import platform +import warnings from dataclasses import dataclass from typing import Literal, Tuple, Union @@ -42,6 +43,28 @@ """Default literal precision used for unspecific `float` types and casts.""" +def _check_boolean_env_var(name: str, default: bool) -> bool: + envvar = os.environ.get(name, default=default) + if type(envvar) is bool: + return envvar + + if type(envvar) is str: + if envvar.lower() in ["true", "1"]: + return True + if envvar in ["false", "0"]: + return False + + warnings.warn( + f"Could not match `{name}={envvar}` into a boolean value. Falling back to the default `{default}`.", + stacklevel=2, + ) + return default + + +FORCE_ANNOTATED_TEMPORARIES = _check_boolean_env_var("GT4PY_FORCE_ANNOTATED_TEMPORARIES", False) +"""If True, forces all temporaries in stencils to have type annotations.""" + + @enum.unique class AccessKind(enum.IntFlag): NONE = 0 @@ -123,6 +146,8 @@ class BuildOptions(AttributeClassLike): "Literal precision for `int` types and casts. Defaults to architecture precision unless overwritten by the environment variable `GT4PY_LITERAL_INT_PRECISION`." literal_float_precision = attribute(of=int, default=LITERAL_FLOAT_PRECISION) "Literal precision for `float` types and casts. Defaults to architecture precision unless overwritten by the environment variable `GT4PY_LITERAL_FLOAT_PRECISION`." + force_annotated_temporaries = attribute(of=bool, default=FORCE_ANNOTATED_TEMPORARIES) + "If True, enforce all temporaries to have type annotations. Defaults to False unless overwritten by the environment variable `GT4PY_FORCE_ANNOTATED_TEMPORARIES`." @property def qualified_name(self): diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index b9482c91bc..8c3befccb9 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -799,6 +799,7 @@ def __init__( self.domain = domain or nodes.Domain.LatLonGrid() self.literal_int_precision = options.literal_int_precision self.literal_float_precision = options.literal_float_precision + self.force_annotated_temporaries = options.force_annotated_temporaries self.temp_decls = temp_decls or {} self.parsing_context = None self.iteration_order = None @@ -1705,6 +1706,11 @@ def _resolve_assign( message="Temporaries with data dimensions need to be declared explicitly.", loc=nodes.Location.from_ast_node(t, scope=self.stencil_name), ) + if self.force_annotated_temporaries and target_annotation is None: + raise GTScriptSyntaxError( + message=f"Missing type hint for '{name}' in stencil '{self.stencil_name}'.", + loc=nodes.Location.from_ast_node(t, scope=self.stencil_name), + ) dtype = nodes.DataType.AUTO axes = nodes.Domain.LatLonGrid().axes_names if target_annotation is not None: diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index 093bd60fde..738fb75f3d 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -57,6 +57,7 @@ def parse_definition( dtypes: Dict[Type, Type] = None, literal_int_precision: int | None = None, literal_float_precision: int | None = None, + force_annotated_temporaries: bool | None = None, rebuild=False, **kwargs, ) -> nodes.StencilDefinition: @@ -73,6 +74,8 @@ def parse_definition( build_args["literal_int_precision"] = literal_int_precision if literal_float_precision is not None: build_args["literal_float_precision"] = literal_float_precision + if force_annotated_temporaries is not None: + build_args["force_annotated_temporaries"] = force_annotated_temporaries build_options = gt_definitions.BuildOptions(**build_args) @@ -2222,3 +2225,35 @@ def stencil(in_field: gtscript.Field[float], out_field: gtscript.Field[float]): name=inspect.stack()[0][3], module=self.__class__.__name__, ) + + +class TestForceAnnotatedTemporaries: + def test_missing_annotation(self): + def good_case(in_field: gtscript.Field[float], out_field: gtscript.Field[float]): + with computation(PARALLEL), interval(...): + tmp: float = 2 * in_field + out_field = tmp + + parsed = parse_definition( + good_case, name=inspect.stack()[0][3], module=self.__class__.__name__ + ) + + declaration = parsed.computations[0].body.stmts[0] + assert isinstance(declaration, nodes.FieldDecl) + assert declaration.data_type == nodes.DataType.FLOAT64 + + def bad_case(in_field: gtscript.Field[float], out_field: gtscript.Field[float]): + with computation(PARALLEL), interval(...): + tmp = 2 * in_field + out_field = tmp + + with pytest.raises( + gt_frontend.GTScriptSyntaxError, + match="Missing type hint for 'tmp' in stencil 'bad_case'.", + ): + parse_definition( + bad_case, + name=inspect.stack()[0][3], + module=self.__class__.__name__, + force_annotated_temporaries=True, + )