diff --git a/src/gt4py/cartesian/definitions.py b/src/gt4py/cartesian/definitions.py index 4781faca22..62b87488e3 100644 --- a/src/gt4py/cartesian/definitions.py +++ b/src/gt4py/cartesian/definitions.py @@ -42,6 +42,21 @@ """Default literal precision used for unspecific `float` types and casts.""" +def get_integer_default_type(): + """Return the integer numpy type corresponding to the LITERAL_INT_PRECISION set.""" + # I'd love to return `numpy.signedinteger[LITERAL_INT_PRECISION]` but that won't work + if LITERAL_INT_PRECISION == 8: + return numpy.int8 + if LITERAL_INT_PRECISION == 32: + return numpy.int32 + if LITERAL_INT_PRECISION == 64: + return numpy.int64 + if LITERAL_INT_PRECISION == 128: + return numpy.int128 + + raise NotImplementedError("Unknown integer precision type") + + @enum.unique class AccessKind(enum.IntFlag): NONE = 0 diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index f4dd4dcb3b..62c71b9747 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -309,6 +309,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom): return node def visit_Attribute(self, node: ast.Attribute): + # An enum MyEnum.A would come has + # > ast.Attribute("A") + # - value: ast.Name("MyEnum") + # We want to replace the entire thing - so we capture the top level + # attribute. We don't use the `self.context` because of this + # two-step AST structure which doesn't fit the generic `replace_node`. + + if isinstance(node.value, ast.Name) and node.value.id in _ENUM_REGISTER.keys(): + int_value = getattr(_ENUM_REGISTER[node.value.id], node.attr) + return ast.Constant(value=int_value) + + # Common replace for all other nodes in context. return self._replace_node(node) def visit_Name(self, node: ast.Name): @@ -1808,6 +1820,11 @@ def visit_Assign(self, node: ast.Assign): raise invalid_target +_ENUM_REGISTER: dict[str, object] = {} +"""Register of IntEnum that will be available to parsing in stencils. Register +with @gtscript.enum()""" + + class GTScriptParser(ast.NodeVisitor): CONST_VALUE_TYPES = ( *gtscript._VALID_DATA_TYPES, @@ -1899,6 +1916,10 @@ def annotate_definition( and param.annotation in gtscript._VALID_DATA_TYPES ): dtype_annotation = np.dtype(param.annotation) + elif param.annotation in _ENUM_REGISTER.values(): + dtype_annotation = ( + gt_definitions.get_integer_default_type() + ) # We will replace all enums with `int` elif param.annotation is inspect.Signature.empty: dtype_annotation = None else: @@ -2024,6 +2045,10 @@ def collect_external_symbols(definition): local_symbols = CollectLocalSymbolsAstVisitor.apply(gtscript_ast) nonlocal_symbols = {} + # Remove enums from `context`, they will be turned into integers in the ValueReplacer + for enum_ in _ENUM_REGISTER.keys(): + context.pop(enum_, "") + name_nodes = gt_meta.collect_names(gtscript_ast, skip_annotations=False) for collected_name in name_nodes.keys(): if collected_name not in gtscript.builtins: @@ -2147,6 +2172,20 @@ def resolve_external_symbols( return result + @staticmethod + def register_enum(class_: type[enum.IntEnum]): + class_name = class_.__name__ + if class_name in _ENUM_REGISTER: + raise ValueError( + f"Enum names must be unique. @gtscript.enum {class_name} is already taken." + ) + + if not issubclass(class_, enum.IntEnum): + raise ValueError(f"Enum {class_name} needs to derive from `enum.IntEnum`.") + + _ENUM_REGISTER[class_name] = class_ + return class_ + def extract_arg_descriptors(self): api_signature = self.definition._gtscript_["api_signature"] api_annotations = self.definition._gtscript_["api_annotations"] diff --git a/src/gt4py/cartesian/gtscript.py b/src/gt4py/cartesian/gtscript.py index 56bf8874a8..6debce00de 100644 --- a/src/gt4py/cartesian/gtscript.py +++ b/src/gt4py/cartesian/gtscript.py @@ -16,6 +16,7 @@ import inspect import numbers import types +from enum import IntEnum from typing import Callable, Dict, Type, Union import numpy as np @@ -159,6 +160,14 @@ def _parse_annotation(arg, annotation): return original_annotations +def enum(class_: type[IntEnum]): + """Mark an IntEnum derived class as readable for GT4Py.""" + from gt4py.cartesian.frontend import gtscript_frontend as gt_frontend + + gt_frontend.GTScriptParser.register_enum(class_) + return class_ + + def function(func): """Mark a GTScript function.""" from gt4py.cartesian.frontend import gtscript_frontend as gt_frontend diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index 3e988149bc..97c46d2ba9 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -20,7 +20,14 @@ import numpy as np from gt4py.cartesian import backend as gt_backend -from gt4py.cartesian.definitions import AccessKind, DomainInfo, FieldInfo, ParameterInfo +from gt4py.cartesian.definitions import ( + AccessKind, + DomainInfo, + FieldInfo, + ParameterInfo, + get_integer_default_type, +) +from gt4py.cartesian.frontend import gtscript_frontend from gt4py.cartesian.gtc import utils as gtc_utils from gt4py.cartesian.gtc.definitions import Index, Shape from gt4py.storage.cartesian import utils as storage_utils @@ -558,8 +565,13 @@ def _call_run( exec_info["call_run_start_time"] = time.perf_counter() backend_cls = gt_backend.from_name(self.backend) device = backend_cls.storage_info["device"] - array_infos = _extract_array_infos(field_args, device) + # Normalize `gtscript.enum` to integers + for name, value in parameter_args.items(): + if type(value) in gtscript_frontend._ENUM_REGISTER.values(): + parameter_args[name] = get_integer_default_type()(value.value) + + array_infos = _extract_array_infos(field_args, device) cache_key = _compute_domain_origin_cache_key(array_infos, parameter_args, domain, origin) if cache_key not in self._domain_origin_cache: origin = self._normalize_origins(array_infos, self.field_info, origin) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 9ffcb3f16c..94c2d06ab3 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from enum import IntEnum import numpy as np import pytest @@ -22,6 +23,7 @@ J, K, IJ, + IJK, computation, horizontal, interval, @@ -1296,3 +1298,38 @@ def test_lower_dim_field( out_arr[:, :, :] = 0 test_lower_dim_field(k_arr, out_arr) assert (out_arr[:, :, :] == 42.42).all() + + +@gtscript.enum +class MyEnum(IntEnum): + Zero = 0 + A = 10 + B = 20 + C = 30 + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_enum_runtime(backend): + @gtscript.stencil(backend=backend) + def the_stencil(out_field: Field[IJK, int], order: MyEnum): + with computation(PARALLEL), interval(0, 1): + out_field = 32 + if order < MyEnum.A: + out_field = MyEnum.A + + with computation(PARALLEL), interval(1, 2): + out_field = 23 + out_field = MyEnum.B + + with computation(PARALLEL), interval(2, None): + out_field = 56 + out_field = MyEnum.C + + domain = (5, 5, 5) + out_arr = gt_storage.zeros(backend=backend, shape=domain, dtype=int) + + the_stencil(out_arr, MyEnum.Zero) + + assert out_arr[0, 0, 0] == MyEnum.A.value + assert out_arr[0, 0, 1] == MyEnum.B.value + assert (out_arr[0, 0, 2:] == MyEnum.C.value).all() diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index 00338ec790..2adca00867 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from enum import IntEnum import inspect import textwrap import types @@ -2029,3 +2030,33 @@ def test_assign_constant_numpy_typed(self): constant: nodes.ScalarLiteral = def_ir.computations[0].body.stmts[0].value assert isinstance(constant, nodes.ScalarLiteral) assert constant.data_type == nodes.DataType.FLOAT32 + + +@gtscript.enum +class LocalEnum(IntEnum): + A = 42 + B = 1000 + + +class TestEnum: + def setup_method(self): + def enum(field: gtscript.Field[float], order: LocalEnum): # type: ignore + with computation(PARALLEL), interval(0, 1): + if order > LocalEnum.A: + field[0, 0, 0] = LocalEnum.B + + self.stencil = enum + + def test_enum_in_stencil(self): + def_ir = parse_definition( + self.stencil, + name=inspect.stack()[0][3], + module=self.__class__.__name__, + ) + + assert isinstance(def_ir.computations[0].body.stmts[0].condition.rhs, nodes.ScalarLiteral) + assert def_ir.computations[0].body.stmts[0].condition.rhs.value == LocalEnum.A + assert isinstance( + def_ir.computations[0].body.stmts[0].main_body.stmts[0].value, nodes.ScalarLiteral + ) + assert def_ir.computations[0].body.stmts[0].main_body.stmts[0].value.value == LocalEnum.B