Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/gt4py/cartesian/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import functools
import os
import platform
import warnings
from dataclasses import dataclass
from typing import Literal, Tuple, Union

Expand Down Expand Up @@ -42,6 +43,28 @@
"""Default literal precision used for unspecific `float` types and casts."""


def _check_boolean_env_var(name: str, default: bool) -> bool:
envvar = os.environ.get(name, default=default)
if type(envvar) is bool:
return envvar

if type(envvar) is str:
if envvar.lower() in ["true", "1"]:
return True
if envvar in ["false", "0"]:
return False

warnings.warn(
f"Could not match `{name}={envvar}` into a boolean value. Falling back to the default `{default}`.",
stacklevel=2,
)
return default


FORCE_ANNOTATED_TEMPORARIES = _check_boolean_env_var("GT4PY_FORCE_ANNOTATED_TEMPORARIES", False)
"""If True, forces all temporaries in stencils to have type annotations."""


@enum.unique
class AccessKind(enum.IntFlag):
NONE = 0
Expand Down Expand Up @@ -123,6 +146,8 @@ class BuildOptions(AttributeClassLike):
"Literal precision for `int` types and casts. Defaults to architecture precision unless overwritten by the environment variable `GT4PY_LITERAL_INT_PRECISION`."
literal_float_precision = attribute(of=int, default=LITERAL_FLOAT_PRECISION)
"Literal precision for `float` types and casts. Defaults to architecture precision unless overwritten by the environment variable `GT4PY_LITERAL_FLOAT_PRECISION`."
force_annotated_temporaries = attribute(of=bool, default=FORCE_ANNOTATED_TEMPORARIES)
"If True, enforce all temporaries to have type annotations. Defaults to False unless overwritten by the environment variable `GT4PY_FORCE_ANNOTATED_TEMPORARIES`."

@property
def qualified_name(self):
Expand Down
6 changes: 6 additions & 0 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,7 @@ def __init__(
self.domain = domain or nodes.Domain.LatLonGrid()
self.literal_int_precision = options.literal_int_precision
self.literal_float_precision = options.literal_float_precision
self.force_annotated_temporaries = options.force_annotated_temporaries
self.temp_decls = temp_decls or {}
self.parsing_context = None
self.iteration_order = None
Expand Down Expand Up @@ -1705,6 +1706,11 @@ def _resolve_assign(
message="Temporaries with data dimensions need to be declared explicitly.",
loc=nodes.Location.from_ast_node(t, scope=self.stencil_name),
)
if self.force_annotated_temporaries and target_annotation is None:
raise GTScriptSyntaxError(
message=f"Missing type hint for '{name}' in stencil '{self.stencil_name}'.",
loc=nodes.Location.from_ast_node(t, scope=self.stencil_name),
)
dtype = nodes.DataType.AUTO
axes = nodes.Domain.LatLonGrid().axes_names
if target_annotation is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def parse_definition(
dtypes: Dict[Type, Type] = None,
literal_int_precision: int | None = None,
literal_float_precision: int | None = None,
force_annotated_temporaries: bool | None = None,
rebuild=False,
**kwargs,
) -> nodes.StencilDefinition:
Expand All @@ -73,6 +74,8 @@ def parse_definition(
build_args["literal_int_precision"] = literal_int_precision
if literal_float_precision is not None:
build_args["literal_float_precision"] = literal_float_precision
if force_annotated_temporaries is not None:
build_args["force_annotated_temporaries"] = force_annotated_temporaries

build_options = gt_definitions.BuildOptions(**build_args)

Expand Down Expand Up @@ -2222,3 +2225,35 @@ def stencil(in_field: gtscript.Field[float], out_field: gtscript.Field[float]):
name=inspect.stack()[0][3],
module=self.__class__.__name__,
)


class TestForceAnnotatedTemporaries:
def test_missing_annotation(self):
def good_case(in_field: gtscript.Field[float], out_field: gtscript.Field[float]):
with computation(PARALLEL), interval(...):
tmp: float = 2 * in_field
out_field = tmp

parsed = parse_definition(
good_case, name=inspect.stack()[0][3], module=self.__class__.__name__
)

declaration = parsed.computations[0].body.stmts[0]
assert isinstance(declaration, nodes.FieldDecl)
assert declaration.data_type == nodes.DataType.FLOAT64

def bad_case(in_field: gtscript.Field[float], out_field: gtscript.Field[float]):
with computation(PARALLEL), interval(...):
tmp = 2 * in_field
out_field = tmp

with pytest.raises(
gt_frontend.GTScriptSyntaxError,
match="Missing type hint for 'tmp' in stencil 'bad_case'.",
):
parse_definition(
bad_case,
name=inspect.stack()[0][3],
module=self.__class__.__name__,
force_annotated_temporaries=True,
)