From c362c37005659185a76d93be73a94c32a730f91d Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 20 Aug 2025 09:39:39 +0200 Subject: [PATCH 1/4] WIP 1 --- .github/workflows/daily-ci.yml | 2 +- ci/cscs-ci.yml | 4 +- .../exercises/1_simple_addition.ipynb | 2 +- .../exercises/3_gradient_exercise.ipynb | 2 +- .../workshop/exercises/4_curl_exercise.ipynb | 2 +- .../workshop/exercises/6_where_domain.ipynb | 12 +-- .../workshop/exercises/7_scan_operator.ipynb | 12 +-- pyproject.toml | 75 +++++++++++-------- scripts/update.py | 2 +- src/gt4py/cartesian/backend/numpy_backend.py | 2 +- src/gt4py/cartesian/caching.py | 2 +- src/gt4py/cartesian/frontend/defir_to_gtir.py | 2 +- .../cartesian/gtc/dace/treeir_to_stree.py | 6 +- src/gt4py/cartesian/gtc/oir.py | 2 +- src/gt4py/eve/codegen.py | 2 +- src/gt4py/next/constructors.py | 4 +- src/gt4py/next/embedded/common.py | 2 +- src/gt4py/next/embedded/nd_array_field.py | 4 +- src/gt4py/next/ffront/decorator.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 5 +- src/gt4py/next/ffront/field_operator_ast.py | 2 +- .../ffront/foast_passes/type_deduction.py | 4 +- src/gt4py/next/ffront/type_info.py | 2 +- src/gt4py/next/iterator/embedded.py | 10 ++- src/gt4py/next/iterator/ir_utils/ir_makers.py | 4 +- .../next/iterator/type_system/inference.py | 2 +- .../compilation/build_systems/compiledb.py | 2 +- src/gt4py/next/otf/workflow.py | 2 +- .../codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py | 4 +- .../runners/dace/gtir_dataflow.py | 3 +- .../runners/dace/transformations/gpu_utils.py | 6 +- .../dace/transformations/loop_blocking.py | 2 +- .../transformations/map_fusion_extended.py | 4 +- .../redundant_array_removers.py | 2 +- .../dace/transformations/splitting_tools.py | 2 +- .../runners/dace/workflow/translation.py | 6 +- src/gt4py/next/type_system/type_info.py | 2 +- src/gt4py/storage/allocators.py | 2 +- tests/next_tests/integration_tests/cases.py | 2 +- .../ffront_tests/test_import_from_mod.py | 2 +- .../test_math_builtin_execution.py 
| 2 +- .../test_func_to_foast_error_line_number.py | 2 +- .../iterator_tests/test_type_inference.py | 16 ++-- .../transforms_tests/test_symbol_ref_utils.py | 2 +- .../unit_tests/test_constructors.py | 8 +- 45 files changed, 128 insertions(+), 113 deletions(-) diff --git a/.github/workflows/daily-ci.yml b/.github/workflows/daily-ci.yml index c4693a6798..94370d1e70 100644 --- a/.github/workflows/daily-ci.yml +++ b/.github/workflows/daily-ci.yml @@ -34,7 +34,7 @@ jobs: # [arg: --resolution, env: UV_RESOLUTION=] dependencies-strategy: ["lowest-direct", "highest"] gt4py-module: ["cartesian", "eve", "next", "storage"] - # TODO: switch to macos-latest once macos-15 is default, see https://github.com/actions/runner-images/issues/12520 + # TODO(): switch to macos-latest once macos-15 is default, see https://github.com/actions/runner-images/issues/12520 # On macos-14 we see the error: call to 'isfinite' is ambiguous. os: ["ubuntu-latest", "macos-15"] python-version: ${{ fromJSON(needs.get-python-versions.outputs.python-versions) }} diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 2485b86c23..489905ecbe 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -59,7 +59,7 @@ stages: BASE_IMAGE: jfrog.svc.cscs.ch/dockerhub/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} EXTRA_UV_SYNC_ARGS: "--extra cuda12" -# TODO: rocm steps are in draft state for now to show how to add support in the future +# TODO(): rocm steps are in draft state for now to show how to add support in the future # .build_extra_rocm: # variables: # # jfrog.svc.cscs.ch/dockerhub/rocm is the cached version of docker.io/rocm @@ -107,7 +107,7 @@ build_cscs_gh200: rules: - if: $SUBPACKAGE == 'next' && $VARIANT == 'dace' && $DETAIL == 'nomesh' variables: - # TODO: investigate why the dace tests seem to hang with multiple jobs + # TODO(): investigate why the dace tests seem to hang with multiple jobs GT4PY_BUILD_JOBS: 1 SLURM_TIMELIMIT: "00:15:00" - when: on_success diff --git 
a/docs/user/next/workshop/exercises/1_simple_addition.ipynb b/docs/user/next/workshop/exercises/1_simple_addition.ipynb index 7f42f2b9d8..ce30bd4c7d 100644 --- a/docs/user/next/workshop/exercises/1_simple_addition.ipynb +++ b/docs/user/next/workshop/exercises/1_simple_addition.ipynb @@ -60,7 +60,7 @@ "metadata": {}, "outputs": [], "source": [ - "def addition(): ... # TODO fix this cell" + "def addition(): ... # EXERCISE: fix this cell" ] }, { diff --git a/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb b/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb index 2b422b1823..f3e575c72e 100644 --- a/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb +++ b/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb @@ -72,7 +72,7 @@ " A: gtx.Field[Dims[C], float],\n", " edge_orientation: gtx.Field[Dims[C, C2EDim], float],\n", ") -> gtx.tuple[gtx.Field[Dims[C], float], gtx.Field[Dims[C], float]]:\n", - " # TODO: fix components of gradient\n", + " # TODO(): fix components of gradient\n", " f_x = A\n", " f_y = A\n", " return f_x, f_y" diff --git a/docs/user/next/workshop/exercises/4_curl_exercise.ipynb b/docs/user/next/workshop/exercises/4_curl_exercise.ipynb index dc321f1bdd..215f13b2e0 100644 --- a/docs/user/next/workshop/exercises/4_curl_exercise.ipynb +++ b/docs/user/next/workshop/exercises/4_curl_exercise.ipynb @@ -94,7 +94,7 @@ " dualA: gtx.Field[Dims[V], float],\n", " edge_orientation: gtx.Field[Dims[V, V2EDim], float],\n", ") -> gtx.Field[Dims[V], float]:\n", - " # TODO: fix curl\n", + " # TODO(): fix curl\n", " uv_curl = dualA\n", "\n", " return uv_curl" diff --git a/docs/user/next/workshop/exercises/6_where_domain.ipynb b/docs/user/next/workshop/exercises/6_where_domain.ipynb index 3c50da2245..ab8e65949c 100644 --- a/docs/user/next/workshop/exercises/6_where_domain.ipynb +++ b/docs/user/next/workshop/exercises/6_where_domain.ipynb @@ -99,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": 
"80cbe9d5", "metadata": {}, "outputs": [ @@ -134,7 +134,7 @@ } ], "source": [ - "# TODO implement the field_operator\n", + "# EXERCISE: implement the field_operator\n", "\n", "\n", "@gtx.program(backend=backend)\n", @@ -201,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "ebe2cd9d", "metadata": {}, "outputs": [ @@ -238,7 +238,7 @@ "@gtx.program(backend=backend)\n", "def program_domain(\n", " a: gtx.Field[Dims[K], float], b: gtx.Field[Dims[K], float]\n", - "): ... # TODO write the call to fieldop_domain" + "): ... # EXERCISE: write the call to fieldop_domain" ] }, { @@ -338,7 +338,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "cac80fbb", "metadata": {}, "outputs": [ @@ -367,7 +367,7 @@ "source": [ "@gtx.field_operator\n", "def fieldop_domain_where(a: gtx.Field[Dims[K], float]) -> gtx.Field[Dims[K], float]:\n", - " return # TODO\n", + " return # ...\n", "\n", "\n", "@gtx.program(backend=backend)\n", diff --git a/docs/user/next/workshop/exercises/7_scan_operator.ipynb b/docs/user/next/workshop/exercises/7_scan_operator.ipynb index 626fd4ecd9..23957e5865 100644 --- a/docs/user/next/workshop/exercises/7_scan_operator.ipynb +++ b/docs/user/next/workshop/exercises/7_scan_operator.ipynb @@ -148,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "69bf6022", "metadata": {}, "outputs": [ @@ -182,18 +182,18 @@ " sedimentaion_constant = 0.05\n", "\n", " # unpack state of previous iteration\n", - " # TODO\n", + " # ...\n", "\n", " # Autoconversion: Cloud Drops -> Rain Drops\n", - " # TODO\n", + " # ...\n", "\n", " ## Add sedimentation flux from level above\n", - " # TODO\n", + " # ...\n", "\n", " # Remove mass due to sedimentation flux\n", - " # TODO\n", + " # ...\n", "\n", - " return # TODO" + " return # ..." 
] }, { diff --git a/pyproject.toml b/pyproject.toml index 57d2879a57..bbdab9988b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -202,7 +202,7 @@ module = 'gt4py.*' [[tool.mypy.overrides]] # The following ignore_errors are only temporary. -# TODO: Fix errors and enable these settings. +# TODO(): Fix errors and enable these settings. disallow_incomplete_defs = false disallow_untyped_defs = false follow_imports = 'silent' @@ -233,7 +233,7 @@ module = 'gt4py.eve.extended_typing' warn_unused_ignores = false [[tool.mypy.overrides]] -# TODO: Make this false and fix errors +# TODO(): Make this false and fix errors disallow_untyped_defs = false follow_imports = 'silent' module = 'gt4py.storage.*' @@ -315,44 +315,53 @@ target-version = 'py310' docstring-code-format = true [tool.ruff.lint] -# -- Rules set to be considered -- -# A: flake8-builtins -# B: flake8-bugbear -# C4: flake8-comprehensions -# CPY: flake8-copyright -# D: pydocstyle -# DOC: pydoclint -# E: pycodestyle -# ERA: eradicate -# F: Pyflakes -# FA100: future-rewritable-type-annotation -# FBT: flake8-boolean-trap -# FLY: flynt -# I: isort -# ICN: flake8-import-conventions -# ISC: flake8-implicit-str-concat -# N: pep8-naming -# NPY: NumPy-specific rules -# PERF: Perflint -# PGH: pygrep-hooks -# PTH: flake8-use-pathlib -# Q: flake8-quotes -# RUF: Ruff-specific rules -# SIM: flake8-simplify -# T10: flake8-debugger -# TD: flake8-todos -# UP: pyupgrade -# YTT: flake8-2020 exclude = ['docs/**', "examples/**", "tests/**"] explicit-preview-rules = true -extend-select = ["F822"] # TODO(egparedes): remove when not longer in preview ignore = [ + 'PLR0913', 'E501', # [line-too-long] 'B905', # [zip-without-explicit-strict] # TODO(egparedes): remove when possible - 'TD003' # [missing-todo-link] + 'ISC001', # [single-line-implicit-string-concatenation] + 'TD003', # [missing-todo-link] + 'UP038' # [non-pep604-isinstance] ] preview = true # use only with explicit-preview-rules=true -select = ['A', 'B', 'CPY', 'E', 'ERA', 'F', 
'FA100', 'I', 'ISC', 'NPY', 'Q', 'RUF', 'T10', 'YTT'] +select = [ + 'A', # flake8-builtins + 'B', # flake8-bugbear + # 'C4', # flake8-comprehensions + 'CPY', # flake8-copyright + # 'D', # pydocstyle + 'DOC', # pydoclint + 'E', # pycodestyle + 'ERA', # eradicate + 'F', # Pyflakes + 'F822', # [undefined-export] (preview) # TODO(egparedes): remove when not longer in preview + 'FA100', # future-rewritable-type-annotation + # 'FBT', # flake8-boolean-trap + 'FLY', # flynt + 'I', # isort + # 'ICN', # flake8-import-conventions + 'ISC', # flake8-implicit-str-concat + # 'N', # pep8-naming + 'NPY', # NumPy-specific rules + 'PERF', # Perflint + #'PGH', # pygrep-hooks + 'PLC', # Pylint-Convention + 'PLE', # Pylint-Error + #'PLR', # Pylint-Refactor + 'PLW', # Pylint-Warning + 'PTH', # flake8-use-pathlib + 'Q', # flake8-quotes + 'RUF', # Ruff-specific rules + #'SIM', # flake8-simplify + 'T10', # flake8-debugger + 'TD', # flake8-todos + #'UP', # pyupgrade + 'YTT' # flake8-2020 +] + + typing-modules = ['gt4py.eve.extended_typing'] unfixable = [] diff --git a/scripts/update.py b/scripts/update.py index 08d28c9928..8780e2ab20 100755 --- a/scripts/update.py +++ b/scripts/update.py @@ -42,7 +42,7 @@ def dependencies() -> None: def precommit() -> None: """Update versions of pre-commit hooks.""" subprocess.run( - f"uv run --quiet --locked --project {common.REPO_ROOT} pre-commit autoupdate", shell=True + f"uv run --quiet --locked --project {common.REPO_ROOT} pre-commit autoupdate", shell=True, check=False ) try: diff --git a/src/gt4py/cartesian/backend/numpy_backend.py b/src/gt4py/cartesian/backend/numpy_backend.py index 80da38d4f7..fa48a41643 100644 --- a/src/gt4py/cartesian/backend/numpy_backend.py +++ b/src/gt4py/cartesian/backend/numpy_backend.py @@ -31,7 +31,7 @@ class NumpyBackend(backend.BaseBackend): name = "numpy" options: ClassVar[dict[str, Any]] = { "oir_pipeline": {"versioning": True, "type": passes.OirPipeline}, - # TODO: Implement this option in source code + # TODO(): 
Implement this option in source code "ignore_np_errstate": {"versioning": True, "type": bool}, } storage_info = layout.NaiveCPULayout diff --git a/src/gt4py/cartesian/caching.py b/src/gt4py/cartesian/caching.py index ff31719aae..79fd8edb5d 100644 --- a/src/gt4py/cartesian/caching.py +++ b/src/gt4py/cartesian/caching.py @@ -369,7 +369,7 @@ class NoCachingStrategy(CachingStrategy): name = "nocaching" - def __init__(self, builder: StencilBuilder, *, output_path: pathlib.Path = pathlib.Path(".")): + def __init__(self, builder: StencilBuilder, *, output_path: pathlib.Path = pathlib.Path()): super().__init__(builder) self._output_path = output_path diff --git a/src/gt4py/cartesian/frontend/defir_to_gtir.py b/src/gt4py/cartesian/frontend/defir_to_gtir.py index 8a3b9061a4..78616a68cf 100644 --- a/src/gt4py/cartesian/frontend/defir_to_gtir.py +++ b/src/gt4py/cartesian/frontend/defir_to_gtir.py @@ -59,7 +59,7 @@ def _convert_dtype(data_type) -> common.DataType: dtype = common.DataType(int(data_type)) if dtype == common.DataType.DEFAULT: - # TODO: this will be a frontend choice later + # TODO(): this will be a frontend choice later # in non-GTC parts, this is set in the backend dtype = cast( common.DataType, common.DataType.FLOAT64 diff --git a/src/gt4py/cartesian/gtc/dace/treeir_to_stree.py b/src/gt4py/cartesian/gtc/dace/treeir_to_stree.py index bbf6106f0d..f6e25b3a8d 100644 --- a/src/gt4py/cartesian/gtc/dace/treeir_to_stree.py +++ b/src/gt4py/cartesian/gtc/dace/treeir_to_stree.py @@ -152,7 +152,8 @@ def _for_scope_header(node: tir.VerticalLoop) -> dcf.ForScope: Only setup the required data, default or mock the rest. - TODO: In DaCe 2.x this will be replaced by an SDFG concept which should + Todo: + In DaCe 2.x this will be replaced by an SDFG concept which should be closer and required less mockup. """ if not dace_version.startswith("1."): @@ -201,7 +202,8 @@ def _while_scope_header(node: tir.While) -> dcf.WhileScope: Only setup the required data, default or mock the rest. 
- TODO: In DaCe 2.x this will be replaced by an SDFG concept which should + Todo: + In DaCe 2.x this will be replaced by an SDFG concept which should be closer and required less mockup. """ if not dace_version.startswith("1."): diff --git a/src/gt4py/cartesian/gtc/oir.py b/src/gt4py/cartesian/gtc/oir.py index 1ba36b5077..0c4225cd26 100644 --- a/src/gt4py/cartesian/gtc/oir.py +++ b/src/gt4py/cartesian/gtc/oir.py @@ -296,7 +296,7 @@ def valid_section_intervals(cls: Type[VerticalLoop], instance: VerticalLoop) -> class Stencil(LocNode, eve.ValidatedSymbolTableTrait): name: str - # TODO: fix to be List[Union[ScalarDecl, FieldDecl]] + # TODO(): fix to be List[Union[ScalarDecl, FieldDecl]] params: List[Decl] vertical_loops: List[VerticalLoop] declarations: List[Temporary] diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index 3869ff313b..d3e8af56d3 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -125,7 +125,7 @@ def _get_clang_format() -> Optional[str]: executable = os.getenv("CLANG_FORMAT_EXECUTABLE", "clang-format") try: assert isinstance(executable, str) - if subprocess.run([executable, "--version"], capture_output=True).returncode != 0: + if subprocess.run([executable, "--version"], capture_output=True, check=False).returncode != 0: return None except Exception: return None diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index b790922784..f597d338e9 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -189,7 +189,7 @@ def as_field( aligned_index: Optional[Sequence[common.NamedIndex]] = None, allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, device: Optional[core_defs.Device] = None, - # TODO: copy=False + # TODO(): copy=False ) -> nd_array_field.NdArrayField: """Create a Field from an array-like object using the given (or device-default) allocator. 
@@ -293,7 +293,7 @@ def as_connectivity( allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, device: Optional[core_defs.Device] = None, skip_value: core_defs.IntegralScalar | eve.NothingType | None = eve.NOTHING, - # TODO: copy=False + # TODO(): copy=False ) -> common.Connectivity: """ Construct a `Connectivity` from the given domain, codomain, and data. diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 2510aee8b4..5d4ed0d8b6 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -188,7 +188,7 @@ def _find_index_of_dim( def canonicalize_any_index_sequence(index: common.AnyIndexSpec) -> common.AnyIndexSpec: - # TODO: instead of canonicalizing to `NamedRange`, we should canonicalize to `NamedSlice` + # TODO(): instead of canonicalizing to `NamedRange`, we should canonicalize to `NamedSlice` new_index: common.AnyIndexSpec = (index,) if isinstance(index, slice) else index if isinstance(new_index, tuple) and all(isinstance(i, slice) for i in new_index): new_index = tuple([_named_slice_to_named_range(i) for i in new_index]) # type: ignore[arg-type, assignment] # all i's are slices as per if statement diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 25ce060c7c..096d054d26 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -163,7 +163,7 @@ def from_array( cls, data: ( npt.ArrayLike | core_defs.NDArrayObject - ), # TODO: NDArrayObject should be part of ArrayLike + ), # TODO(): NDArrayObject should be part of ArrayLike /, *, domain: common.DomainLike, @@ -814,7 +814,7 @@ def _compute_mask_slices( mask: core_defs.NDArrayObject, ) -> list[tuple[bool, slice]]: """Take a 1-dimensional mask and return a sequence of mappings from boolean values to slices.""" - # TODO: does it make sense to upgrade this naive algorithm to numpy? 
+ # TODO(): does it make sense to upgrade this naive algorithm to numpy? assert mask.ndim == 1 cur = bool(mask[0].item()) ind = 0 diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index fe3e2410fc..74b5d37520 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -319,7 +319,7 @@ def __call__( start = time.time() if __debug__: - # TODO: remove or make dependency on self.past_stage optional + # TODO(): remove or make dependency on self.past_stage optional past_process_args._validate_args( self.past_stage.past_node, arg_types=[type_translation.from_value(arg) for arg in args], diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 82832dd0f6..8769bd7a57 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -105,13 +105,12 @@ def dispatch(self, *args: Any) -> Callable[_P, _R]: arg_types = tuple(type(arg) for arg in args) for atype in arg_types: # current strategy is to select the implementation of the first arg that supports the operation - # TODO: define a strategy that converts or prevents conversion + # TODO(): define a strategy that converts or prevents conversion if (dispatcher := getattr(atype, "__gt_builtin_func__", None)) is not None and ( op_func := dispatcher(self) ) is not NotImplemented: return op_func - else: - return self.function + return self.function def __gt_type__(self) -> ts.FunctionType: signature = inspect.signature(self.function) diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 8e127d33d2..24a3015ba2 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -77,7 +77,7 @@ class Name(Expr): class Constant(Expr): - value: Any # TODO: be more specific + value: Any # TODO(): be more specific class Subscript(Expr): diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py 
b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 1b49f883a6..19954a1778 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -450,7 +450,7 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs: Any) -> foast.IfStmt: f"Inconsistent types between two branches for variable '{sym}': " f"got types '{true_type}' and '{false_type}.", ) - # TODO: properly patch symtable (new node?) + # TODO(): properly patch symtable (new node?) symtable[sym].type = new_node.annex.propagated_symbols[sym].type = ( new_true_branch.annex.symtable[sym].type ) @@ -889,7 +889,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> foast.Call: def _visit_reduction(self, node: foast.Call, **kwargs: Any) -> foast.Call: field_type = cast(ts.FieldType, node.args[0].type) reduction_dim = cast(ts.DimensionType, node.kwargs["axis"].type).dim - # TODO: This code does not handle ellipses for dimensions. Fix it. + # TODO(): This code does not handle ellipses for dimensions. Fix it. assert field_type.dims is not ... if reduction_dim not in field_type.dims: field_dims_str = ", ".join(str(dim) for dim in field_type.dims) diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 2b162382b7..3717f1a7db 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -189,7 +189,7 @@ def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType: # The structure of the scan passes argument and the requested # argument type differ. As such we can not extract the dimensions # and just return a generic field shown in the error later on. - # TODO: we want some generic field type here, but our type system does not support it yet. + # TODO(): we want some generic field type here, but our type system does not support it yet. 
return ts.FieldType(dims=[common.Dimension("...")], dtype=dtype) res = type_info.apply_to_primitive_constituents(_as_field, param, with_path_arg=True) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 3888ccf2de..7b6a4de386 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1136,7 +1136,8 @@ class IndexField(common.Field): """ Minimal index field implementation. - TODO: Improve implementation (e.g. support slicing) and move out of this module. + Todo: + Improve implementation (e.g. support slicing) and move out of this module. """ _dimension: common.Dimension @@ -1201,7 +1202,7 @@ def restrict(self, item: common.AnyIndexSpec) -> Self: assert isinstance(r, core_defs.INTEGRAL_TYPES) # TODO(tehrengruber): Use a regular zero dimensional field instead. return self.__class__(self._dimension, r) - # TODO: set a domain... + # TODO(): set a domain... raise NotImplementedError() __call__ = premap @@ -1274,7 +1275,8 @@ class ConstantField(common.Field[Any, core_defs.ScalarT]): """ Minimal constant field implementation. - TODO: Improve implementation (e.g. support slicing) and move out of this module. + Todo: + Improve implementation (e.g. support slicing) and move out of this module. 
""" _value: core_defs.ScalarT @@ -1494,7 +1496,7 @@ def make_const_list(value): @builtins.reduce.register(EMBEDDED) def reduce(fun, init): def sten(*lists): - # TODO: assert check_that_all_lists_are_compatible(*lists) + # TODO(): assert check_that_all_lists_are_compatible(*lists) lst = None for cur in lists: if isinstance(cur, _List): diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index c0ef1a47f1..aeae9f0e6c 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -440,7 +440,7 @@ def promote_to_lifted_stencil(op: str | itir.SymRef | Callable) -> Callable[..., def _impl(*its: itir.Expr) -> itir.FunCall: args = [ f"__arg{i}" for i in range(len(its)) - ] # TODO: `op` must not contain `SymRef(id="__argX")` + ] # TODO(): `op` must not contain `SymRef(id="__argX")` return lift(lambda_(*args)(op(*[deref(arg) for arg in args])))(*its) return _impl @@ -533,7 +533,7 @@ def op_as_fieldop( def _impl(*its: itir.Expr) -> itir.FunCall: args = [ f"__arg{i}" for i in range(len(its)) - ] # TODO: `op` must not contain `SymRef(id="__argX")` + ] # TODO(): `op` must not contain `SymRef(id="__argX")` return as_fieldop(lambda_(*args)(op(*[deref(arg) for arg in args])), domain)(*its) return _impl diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index c33e3a71d0..9d04961638 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -494,7 +494,7 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionType: return ts.DimensionType(dim=common.Dimension(value=node.value, kind=node.kind)) - # TODO: revisit what we want to do with OffsetLiterals as we already have an Offset type in + # TODO(): revisit what we want to do with OffsetLiterals as we already have an Offset type in 
# the frontend. def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.OffsetLiteralType: if _is_representable_as_int(node.value): diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 9017ae1ff1..73aa578453 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -168,7 +168,7 @@ def ignore_function(folder: str, children: list[str]) -> list[str]: build_data.write_data( data=build_data.BuildData( status=build_data.BuildStatus.INITIALIZED, - module=pathlib.Path(""), + module=pathlib.Path(), entry_point_name=self.program_name, ), path=self.root_path, diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index ef3a4083b9..1b1749d5c1 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -198,7 +198,7 @@ class StepSequence(ChainableWorkflowMixin[StartT, EndT]): class __Steps: inner: tuple[Workflow[Any, Any], ...] - # todo(ricoh): replace with normal tuple with TypeVarTuple hints + # TODO(ricoh): replace with normal tuple with TypeVarTuple hints # to enable automatic deduction StartT and EndT fom constructor # calls. TypeVarTuple is available in typing_extensions in # Python <= 3.11. Revise after mypy constraint is > 1.0.1, diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py index b2aea05641..4868306f41 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -24,7 +24,7 @@ ) -# TODO: start of code clone from unroll_reduce.py. This is necessary since whilet the IR nodes are compatible between itir and gtfn_ir, +# TODO(): start of code clone from unroll_reduce.py. 
This is necessary since whilet the IR nodes are compatible between itir and gtfn_ir, # the structure of the ir is slightly different, hence functions like _is_shifted and _get_partial_offset_tag are slightly changed # in this version of the code clone. To be removed asap def _is_shifted(arg: gtfn_ir_common.Expr) -> TypeGuard[gtfn_ir.FunCall]: @@ -83,7 +83,7 @@ def _is_reduce(node: gtfn_ir.FunCall) -> TypeGuard[gtfn_ir.FunCall]: ) -# TODO: end of code clone +# TODO(): end of code clone class PlugInCurrentIdx(NodeTranslator): diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index 42766666e6..889cb0f800 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -918,8 +918,7 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp if nstate.degree(data_node) == 0: assert not data_node.desc(nsdfg).transient nsdfg.remove_node(data_node) - else: - result = outer_value + result = outer_value outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))} diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index 3b5ce24512..fcb235b94c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -703,7 +703,8 @@ def gt_remove_trivial_gpu_maps( validate: Perform validation at the end of the function. validate_all: Perform validation also on intermediate steps. - Todo: Improve this function. + Todo: + Improve this function. """ # First we try to promote and fuse them with other non-trivial maps. 
@@ -785,7 +786,8 @@ class TrivialGPUMapElimination(dace_transformation.SingleStateTransformation): is run within the context of `gt_gpu_transformation()`. - This transformation must be run after the GPU Transformation. - Todo: Figuring out if this transformation is still needed. + Todo: + Figuring out if this transformation is still needed. """ only_gpu_maps = dace_properties.Property( diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py index 39450004f1..c9fb595390 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py @@ -426,7 +426,7 @@ def _classify_node( # A Tasklet must write to an AccessNode, because otherwise there would # be nothing that could be used to cache anything. Furthermore, this # AccessNode must be outside of the inner loop, i.e. be independent. - # TODO: Make this check stronger to ensure that there is always an + # TODO(): Make this check stronger to ensure that there is always an # AccessNode that is independent. 
if not all( isinstance(out_edge.dst, dace_nodes.AccessNode) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_extended.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_extended.py index ad50bac233..863456cdab 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_extended.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_extended.py @@ -88,7 +88,7 @@ def gt_vertical_map_split_fusion( find_single_use_data = dace_analysis.FindSingleUseData() single_use_data = find_single_use_data.apply_pass(sdfg, None) - # TODO: Restrict MapFusion such that it only applies to the Maps that have + # TODO(): Restrict MapFusion such that it only applies to the Maps that have # been split and not some other random Maps. transformations = [ VerticalSplitMapRange( @@ -477,7 +477,7 @@ def can_be_applied( ): return False - # TODO: Ensure that the fusion can be performed. + # TODO(): Ensure that the fusion can be performed. return True diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py index 83fbbd08c8..4db1f293dd 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py @@ -361,7 +361,7 @@ def can_be_applied( # do not have to adjust maps. # NOTE: In previous versions there was an ad hoc rule, to bypass the "full # read rule". However, it caused problems, so it was removed. - # TODO: We have to improve this test, because sometimes the expressions are + # TODO(): We have to improve this test, because sometimes the expressions are # so complex that without information about relations, such as # `vertical_start <= vertical_end` it is not possible to prove this check. 
a1_range = dace_sbs.Range.from_array(a1_desc) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py b/src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py index bcbd60f6d9..75bd79b88c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py @@ -844,7 +844,7 @@ def _perform_node_split_with_bypass_impl( ) -> list[dace_graph.MultiConnectorEdge]: """Performs the splitting but the edge might go directly to the consumer. - # TODO: Remove the producer edge, run reconfiguration, split operation. + # TODO(): Remove the producer edge, run reconfiguration, split operation. # TODO ADDING PRODUCER TO THE SET OF PROCESSED NODES """ diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index af1d312665..5e98f6b331 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -69,7 +69,8 @@ def make_sdfg_call_async(sdfg: dace.SDFG, gpu: bool) -> None: for work that runs on the GPU. Furthermore, all work is scheduled on the default stream. - Todo: Revisit this function once DaCe changes its behaviour in this regard. + Todo: + Revisit this function once DaCe changes its behaviour in this regard. """ # This is only a problem on GPU. @@ -106,7 +107,8 @@ def make_sdfg_call_sync(sdfg: dace.SDFG, gpu: bool) -> None: have _finished_ and the results are available. This function only has an effect for work that runs on the GPU. Furthermore, all work is scheduled on the default stream. - Todo: Revisit this function once DaCe changes its behaviour in this regard. + Todo: + Revisit this function once DaCe changes its behaviour in this regard. """ # This is only a problem on GPU. 
diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 438d98fc32..30e8a5da14 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -639,7 +639,7 @@ def return_type_field( source_dim = with_args[0].source target_dims = with_args[0].target new_dims = [] - # TODO: This code does not handle ellipses for dimensions. Fix it. + # TODO(): This code does not handle ellipses for dimensions. Fix it. assert field_type.dims is not ... for d in field_type.dims: if d != source_dim: diff --git a/src/gt4py/storage/allocators.py b/src/gt4py/storage/allocators.py index e2311e3e60..1499acbc8b 100644 --- a/src/gt4py/storage/allocators.py +++ b/src/gt4py/storage/allocators.py @@ -45,7 +45,7 @@ _NDBuffer: TypeAlias = Union[ - # TODO: add `xtyping.Buffer` once we update typing_extensions + # TODO(): add `xtyping.Buffer` once we update typing_extensions xtyping.ArrayInterface, xtyping.CUDAArrayInterface, xtyping.DLPackBuffer ] diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 967cf0ab11..efd335d797 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -367,7 +367,7 @@ def allocate( domain = extend_domain( domain, extend - ) # TODO: this should take into account the Domain of the allocated field + ) # TODO(): this should take into account the Domain of the allocated field arg_type = get_param_types(fieldview_prog)[name] if strategy is None: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py index 87bf0e5bd7..21dbbfb364 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py @@ -54,7 +54,7 @@ def mod_prog(f: 
cases.IField, out: cases.IKField): cases.verify(cartesian_case, mod_prog, f, out=out, ref=expected) -# TODO: these set of features should be allowed as module imports in a later PR +# TODO(): these set of features should be allowed as module imports in a later PR def test_import_module_errors_future_allowed(cartesian_case): with pytest.raises(gtx.errors.DSLError): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 172b936665..90c8f8fa3a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -31,7 +31,7 @@ from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data -# TODO: reduce duplication with `test_math_unary_builtins` +# TODO(): reduce duplication with `test_math_unary_builtins` # TODO(tehrengruber): add tests for scalar arguments to builtin. To avoid code # bloat this is postponed until programatically creating field operators diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py index 62a8f1b755..c760a8641d 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py @@ -86,4 +86,4 @@ def field_operator_with_undeclared_symbol(): assert exc_info.value.location.end_column == 33 -# TODO: test program type deduction? +# TODO(): test program type deduction? 
diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 43ab1fcaf1..06c907b815 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -7,12 +7,12 @@ # SPDX-License-Identifier: BSD-3-Clause import copy -# TODO: test failure when something is not typed after inference is run -# TODO: test lift with no args -# TODO: lambda function that is not called -# TODO: partially applied function in a let -# TODO: function calling itself should fail -# TODO: lambda function called with different argument types +# TODO(): test failure when something is not typed after inference is run +# TODO(): test lift with no args +# TODO(): lambda function that is not called +# TODO(): partially applied function in a let +# TODO(): function calling itself should fail +# TODO(): lambda function called with different argument types import pytest @@ -132,8 +132,8 @@ def expression_test_cases(): ), # cast (im.cast_(1, int_type), int_type), - # TODO: lift - # TODO: scan + # TODO(): lift + # TODO(): scan # map ( im.map_(im.ref("plus"))(im.ref("a", int_list_type), im.ref("b", int_list_type)), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py index c162860c7c..3444836915 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py @@ -33,4 +33,4 @@ def test_get_user_defined_symbols(): def test_collect_symbol_refs(): ... 
- # TODO: Test collect_symbol_refs + # TODO(): Test collect_symbol_refs diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index 0998ab8eab..e33a42213d 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -21,7 +21,7 @@ sizes = {I: 10, J: 10, K: 10} -# TODO: parametrize with gpu backend and compare with cupy array +# TODO(): parametrize with gpu backend and compare with cupy array @pytest.mark.parametrize( "allocator, device", [ @@ -40,7 +40,7 @@ def test_empty(allocator, device): assert a.shape == ref.shape -# TODO: parametrize with gpu backend and compare with cupy array +# TODO(): parametrize with gpu backend and compare with cupy array @pytest.mark.parametrize( "allocator, device", [ @@ -62,7 +62,7 @@ def test_zeros(allocator, device): assert np.array_equal(a.ndarray, ref) -# TODO: parametrize with gpu backend and compare with cupy array +# TODO(): parametrize with gpu backend and compare with cupy array @pytest.mark.parametrize( "allocator, device", [ @@ -82,7 +82,7 @@ def test_ones(allocator, device): assert np.array_equal(a.ndarray, ref) -# TODO: parametrize with gpu backend and compare with cupy array +# TODO(): parametrize with gpu backend and compare with cupy array @pytest.mark.parametrize( "allocator, device", [ From d83b55c027b083d7ac3881492f8a5d70fb16be23 Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 20 Aug 2025 09:39:45 +0200 Subject: [PATCH 2/4] WIP2 --- .github/workflows/_disabled/gt4py-sphinx.yml | 2 +- src/gt4py/_core/definitions.py | 2 +- src/gt4py/cartesian/backend/cuda_backend.py | 2 +- src/gt4py/cartesian/backend/dace_backend.py | 2 +- src/gt4py/cartesian/backend/gtcpp_backend.py | 2 +- src/gt4py/cartesian/frontend/defir_to_gtir.py | 2 +- src/gt4py/cartesian/frontend/nodes.py | 2 +- src/gt4py/cartesian/gtc/common.py | 6 +++--- src/gt4py/cartesian/gtc/dace/oir_to_treeir.py | 2 +- 
src/gt4py/cartesian/gtc/gtir.py | 6 +++--- src/gt4py/next/common.py | 2 +- src/gt4py/next/constructors.py | 6 ++---- src/gt4py/next/embedded/nd_array_field.py | 6 +++--- src/gt4py/next/ffront/fbuiltins.py | 6 +++--- src/gt4py/next/ffront/past_process_args.py | 3 +-- src/gt4py/next/iterator/builtins.py | 2 +- src/gt4py/next/iterator/dispatcher.py | 2 +- src/gt4py/next/iterator/embedded.py | 12 ++++++------ src/gt4py/next/iterator/library.py | 2 +- .../next/program_processors/codegens/gtfn/codegen.py | 2 +- .../program_processors/codegens/gtfn/gtfn_module.py | 2 +- .../next/program_processors/runners/dace/program.py | 2 +- .../runners/dace/transformations/splitting_tools.py | 2 +- src/gt4py/next/type_system/type_translation.py | 3 +-- .../feature_tests/ffront_tests/test_concat_where.py | 2 +- .../ffront_tests/test_math_builtin_execution.py | 2 +- .../ffront_tests/test_math_unary_builtins.py | 2 +- .../feature_tests/iterator_tests/test_tuple.py | 2 +- .../ffront_tests/test_ffront_fvm_nabla.py | 2 +- .../iterator_tests/test_column_stencil.py | 2 +- .../iterator_tests/test_fvm_nabla.py | 2 +- .../iterator_tests/test_with_toy_connectivity.py | 4 ++-- 32 files changed, 47 insertions(+), 51 deletions(-) diff --git a/.github/workflows/_disabled/gt4py-sphinx.yml b/.github/workflows/_disabled/gt4py-sphinx.yml index 6606950eb5..711dcb2c7e 100644 --- a/.github/workflows/_disabled/gt4py-sphinx.yml +++ b/.github/workflows/_disabled/gt4py-sphinx.yml @@ -29,6 +29,6 @@ jobs: run: | python -m pip install .[dace] - name: Build documentation - # TODO re-enable SPHINXOPTS=-W + # TODO(): re-enable SPHINXOPTS=-W run: | cd docs && make -e html diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 41a592c3d4..3d7e77629e 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -140,7 +140,7 @@ def is_positive_integral_type(integral_type: type) -> TypeGuard[Type[PositiveInt TensorShape: TypeAlias = Sequence[ int -] # TODO(egparedes) 
figure out if PositiveIntegral can be made to work +] # TODO(egparedes): figure out if PositiveIntegral can be made to work def is_valid_tensor_shape(value: Sequence[IntegralScalar]) -> TypeGuard[TensorShape]: diff --git a/src/gt4py/cartesian/backend/cuda_backend.py b/src/gt4py/cartesian/backend/cuda_backend.py index 8097b5c474..abf53a1869 100644 --- a/src/gt4py/cartesian/backend/cuda_backend.py +++ b/src/gt4py/cartesian/backend/cuda_backend.py @@ -144,7 +144,7 @@ def generate_extension(self) -> None: def generate(self) -> type[StencilObject]: self.check_options(self.builder.options) - # TODO(havogt) add bypass if computation has no effect + # TODO(havogt): add bypass if computation has no effect self.generate_extension() # Generate and return the Python wrapper class diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 5ad028d183..f74ec21679 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -814,7 +814,7 @@ class BaseDaceBackend(BaseGTBackend): def generate(self) -> type[StencilObject]: self.check_options(self.builder.options) - # TODO(havogt) add bypass if computation has no effect + # TODO(havogt): add bypass if computation has no effect self.generate_extension() # Generate and return the Python wrapper class diff --git a/src/gt4py/cartesian/backend/gtcpp_backend.py b/src/gt4py/cartesian/backend/gtcpp_backend.py index 35419fb2d9..adfca17c7e 100644 --- a/src/gt4py/cartesian/backend/gtcpp_backend.py +++ b/src/gt4py/cartesian/backend/gtcpp_backend.py @@ -132,7 +132,7 @@ def _generate_extension(self, uses_cuda: bool) -> None: def generate(self) -> type[StencilObject]: self.check_options(self.builder.options) - # TODO(havogt) add bypass if computation has no effect + # TODO(havogt): add bypass if computation has no effect self.generate_extension() # Generate and return the Python wrapper class diff --git a/src/gt4py/cartesian/frontend/defir_to_gtir.py 
b/src/gt4py/cartesian/frontend/defir_to_gtir.py index 78616a68cf..2075d06a61 100644 --- a/src/gt4py/cartesian/frontend/defir_to_gtir.py +++ b/src/gt4py/cartesian/frontend/defir_to_gtir.py @@ -544,7 +544,7 @@ def visit_AxisInterval(self, node: AxisInterval) -> Tuple[gtir.AxisBound, gtir.A return self.visit(node.start), self.visit(node.end) def visit_AxisBound(self, node: AxisBound) -> gtir.AxisBound: - # TODO(havogt) add support VarRef + # TODO(havogt): add support VarRef return gtir.AxisBound( level=self.GT4PY_LEVELMARKER_TO_GTIR_LEVELMARKER[node.level], offset=node.offset ) diff --git a/src/gt4py/cartesian/frontend/nodes.py b/src/gt4py/cartesian/frontend/nodes.py index 511dc9f760..b71da1b8e3 100644 --- a/src/gt4py/cartesian/frontend/nodes.py +++ b/src/gt4py/cartesian/frontend/nodes.py @@ -755,7 +755,7 @@ def get_offset(bound: AxisBound) -> int: ) -# TODO Find a better place for this in the file. +# TODO(): Find a better place for this in the file. @attribclass class HorizontalIf(Statement): intervals = attribute(of=DictOf[str, AxisInterval]) diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index c6545b6f6c..0ecca2b8bc 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -306,7 +306,7 @@ def compute_kind(*values: Expr) -> ExprKind: class Literal(eve.Node): - # TODO(havogt) reconsider if `str` is a good representation for value, + # TODO(havogt): reconsider if `str` is a good representation for value, # maybe it should be Union[float,int,str] etc? 
value: Union[BuiltInLiteral, str] dtype: DataType @@ -656,8 +656,8 @@ def _allowed_flags(self, loop_order: LoopOrder) -> List[Tuple[bool, bool, bool]] return allowed_flags -# TODO(ricoh) consider making gtir.Decl & oir.Decl common and / or adding a VerticalLoop baseclass -# TODO(ricoh) in common instead of passing type arguments +# TODO(ricoh): consider making gtir.Decl & oir.Decl common and / or adding a VerticalLoop baseclass +# TODO(ricoh): in common instead of passing type arguments def validate_lvalue_dims( vertical_loop_type: Type[eve.Node], decl_type: Type[eve.Node] ) -> datamodels.RootValidator: diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_treeir.py b/src/gt4py/cartesian/gtc/dace/oir_to_treeir.py index e1cbbef5ce..a30678b3c9 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_treeir.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_treeir.py @@ -347,7 +347,7 @@ def visit_Stencil(self, node: oir.Stencil) -> tir.TreeRoot: tir.Axis.J: -field_extent[1][0], tir.Axis.K: max(k_bound[0], 0), } - # TODO / Dev Note: Persistent memory is an overkill here - we should scope + # TODO(): / Dev Note: Persistent memory is an overkill here - we should scope # the temporary as close to the tasklets as we can, but any lifetime lower # than persistent will yield issues with memory leaks. containers[field.name] = data.Array( diff --git a/src/gt4py/cartesian/gtc/gtir.py b/src/gt4py/cartesian/gtc/gtir.py index 0ee4f7ebe1..3ca91025dd 100644 --- a/src/gt4py/cartesian/gtc/gtir.py +++ b/src/gt4py/cartesian/gtc/gtir.py @@ -119,7 +119,7 @@ def verify_scalar_condition(self, attribute: datamodels.Attribute, value: Expr) if value.kind != common.ExprKind.FIELD: raise ValueError("Condition is not a field expression") - # TODO(havogt) add validator for the restriction (it's a pass over the subtrees...) + # TODO(havogt): add validator for the restriction (it's a pass over the subtrees...) 
class ScalarIfStmt(common.IfStmt[BlockStmt, Expr], Stmt): @@ -171,7 +171,7 @@ class NativeFuncCall(common.NativeFuncCall[Expr], Expr): _dtype_propagation = common.native_func_call_dtype_propagation(strict=False) -class Decl(LocNode): # TODO probably Stmt +class Decl(LocNode): # TODO(): probably Stmt name: eve.Coerced[eve.SymbolName] dtype: common.DataType @@ -195,7 +195,7 @@ class Interval(LocNode): end: AxisBound -# TODO(havogt) should vertical loop open a scope? +# TODO(havogt): should vertical loop open a scope? class VerticalLoop(LocNode): interval: Interval loop_order: common.LoopOrder diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index dc6f24e9dd..2d44d84867 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -660,7 +660,7 @@ class GTFieldInterface(core_defs.GTDimsInterface, core_defs.GTOriginInterface, P @property def __gt_domain__(self) -> Domain: - # TODO probably should be changed to `DomainLike` (with a new concept `DimensionLike`) + # TODO(): probably should be changed to `DomainLike` (with a new concept `DimensionLike`) # to allow implementations without having to import gtx.Domain. ... 
diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index f597d338e9..feb0b6e7a8 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -265,8 +265,7 @@ def as_field( if dtype is None: dtype = storage_utils.asarray(data).dtype dtype = core_defs.dtype(dtype) - assert dtype.tensor_shape == () # TODO - + assert dtype.tensor_shape == () # TODO(): fix if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) @@ -347,8 +346,7 @@ def as_connectivity( if dtype is None: dtype = storage_utils.asarray(data).dtype dtype = core_defs.dtype(dtype) - assert dtype.tensor_shape == () # TODO - + assert dtype.tensor_shape == () # TODO(): fix if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 096d054d26..55b4c23406 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -117,7 +117,7 @@ class NdArrayField( _domain: common.Domain _ndarray: core_defs.NDArrayObject - array_ns: ClassVar[ModuleType] # TODO(havogt) introduce a NDArrayNamespace protocol + array_ns: ClassVar[ModuleType] # TODO(havogt): introduce a NDArrayNamespace protocol @property def domain(self) -> common.Domain: @@ -784,7 +784,7 @@ def _hyperslice( fbuiltins.power, # type: ignore[attr-defined] NdArrayField.__pow__, ) -# TODO gamma +# TODO(): gamma for name in ( fbuiltins.UNARY_MATH_FP_BUILTIN_NAMES @@ -1080,7 +1080,7 @@ def __setitem__( index: common.AnyIndexSpec, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: - # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` + # TODO(havogt): use something like 
`self.ndarray = self.ndarray.at(index).set(value)` raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.") common._field.register(jnp.ndarray, JaxArrayField.from_array) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 8769bd7a57..e2f0d4d197 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -147,13 +147,13 @@ def __call__(self, cond: CondT, true_field: FieldT, false_field: FieldT) -> _R: if isinstance(true_field, tuple) or isinstance(false_field, tuple): if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)): raise ValueError( - # TODO(havogt) find a strategy to unify parsing and embedded error messages + # TODO(havogt): find a strategy to unify parsing and embedded error messages f"Either both or none can be tuple in '{true_field=}' and '{false_field=}'." ) if len(true_field) != len(false_field): raise ValueError( "Tuple of different size not allowed." ) # TODO(havogt) find a strategy to unify parsing and embedded error messages + ) # TODO(havogt): find a strategy to unify parsing and embedded error messages return tuple(self(cond, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` return super().__call__(cond, true_field, false_field) @@ -180,7 +180,7 @@ def broadcast( assert core_defs.is_scalar_type( field ) # default implementation for scalars, Fields are handled via dispatch - # TODO(havogt) implement with FunctionField, the workaround is to ignore broadcasting on scalars as they broadcast automatically, but we lose the check for compatible dimensions + # TODO(havogt): implement with FunctionField, the workaround is to ignore broadcasting on scalars as they broadcast automatically, but we lose the check for compatible dimensions return field # type: ignore[return-value] # see comment above diff --git a/src/gt4py/next/ffront/past_process_args.py
b/src/gt4py/next/ffront/past_process_args.py index f0360e05ba..af2e9807ef 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -122,8 +122,8 @@ def _field_constituents_range_and_dims( yield from _field_constituents_range_and_dims(el, el_type) case ts.FieldType(): dims = type_info.extract_dims(arg_type) - if isinstance(arg, ts.TypeSpec): # TODO - yield (tuple(), dims) + if isinstance(arg, ts.TypeSpec): # TODO(): fix + yield (tuple(), dims) elif dims: assert ( hasattr(arg, "domain") diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index e3f45f6c74..5454902b26 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -443,7 +443,7 @@ def concat_where(*args): "multiplies", "divides", "mod", - "floordiv", # TODO see https://github.com/GridTools/gt4py/issues/1136 + "floordiv", # TODO(): see https://github.com/GridTools/gt4py/issues/1136 "minimum", "maximum", "fmod", diff --git a/src/gt4py/next/iterator/dispatcher.py b/src/gt4py/next/iterator/dispatcher.py index fae92dc971..d362c99c22 100644 --- a/src/gt4py/next/iterator/dispatcher.py +++ b/src/gt4py/next/iterator/dispatcher.py @@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, List -# TODO test +# TODO(): test class _fun_dispatcher: diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 7b6a4de386..bfaa7f60ce 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause -# TODO(havogt) move public definitions and make this module private +# TODO(havogt): move public definitions and make this module private from __future__ import annotations @@ -433,7 +433,7 @@ def __init__( self.offsets = offsets or [] self.elem = elem - # TODO needs to be supported by all iterators that represent tuples + # TODO(): needs to be supported by all iterators that represent tuples def __getitem__(self, index): return _WrappedIterator(self.stencil, self.args, offsets=self.offsets, elem=index) @@ -487,7 +487,7 @@ def unstructured_domain(*args: NamedRange) -> runtime.UnstructuredDomain: @builtins.named_range.register(EMBEDDED) def named_range(tag: Tag | common.Dimension, start: int, end: int) -> NamedRange: - # TODO revisit this pattern after the discussion of 0d-field vs scalar + # TODO(): revisit this pattern after the discussion of 0d-field vs scalar if isinstance(start, ConstantField): start = start.value if isinstance(end, ConstantField): @@ -1190,7 +1190,7 @@ def premap( index_field: common.Connectivity | fbuiltins.FieldOffset, *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: - # TODO can be implemented by constructing and ndarray (but do we know of which kind?) + # TODO(): can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() def restrict(self, item: common.AnyIndexSpec) -> Self: @@ -1317,11 +1317,11 @@ def premap( index_field: common.Connectivity | fbuiltins.FieldOffset, *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: - # TODO can be implemented by constructing and ndarray (but do we know of which kind?) + # TODO(): can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() def restrict(self, item: common.AnyIndexSpec) -> Self: - # TODO set a domain... + # TODO(): set a domain... 
return self def as_scalar(self) -> core_defs.ScalarT: diff --git a/src/gt4py/next/iterator/library.py b/src/gt4py/next/iterator/library.py index 9d2db8afc5..871b8bdf0f 100644 --- a/src/gt4py/next/iterator/library.py +++ b/src/gt4py/next/iterator/library.py @@ -13,7 +13,7 @@ def sum_(fun=None): if fun is None: return reduce(lambda a, b: a + b, 0.0) else: - return reduce(lambda first, a, b: first + fun(a, b), 0.0) # TODO tracing for *args + return reduce(lambda first, a, b: first + fun(a, b), 0.0) # TODO(): tracing for *args def dot(a, b): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 37b7620bdc..88a39cfa00 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -160,7 +160,7 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> str: Lambda = as_mako( "[=](${','.join('auto ' + p for p in params)}){return ${expr};}" - ) # TODO capture + ) # TODO(): capture Backend = as_fmt("make_backend(backend, {domain})") diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 90441ec61a..3aaa0d5f95 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -48,7 +48,7 @@ class GTFNTranslationStep( ], ): language_settings: Optional[languages.LanguageWithHeaderFilesSettings] = None - # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 + # TODO(): replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 enable_itir_transforms: bool = True use_imperative_backend: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU diff --git a/src/gt4py/next/program_processors/runners/dace/program.py 
b/src/gt4py/next/program_processors/runners/dace/program.py index acf12880e5..1109278444 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -131,7 +131,7 @@ def single_horizontal_dim_per_field( ) sdfg.gt4py_program_output_fields = dict(single_horizontal_dim_per_field(output_fields)) - # TODO (ricoh): bring back sdfg.offset_providers_per_input_field. + # TODO(ricoh): bring back sdfg.offset_providers_per_input_field. # A starting point would be to use the "trace_shifts" pass on GTIR # and associate the extracted shifts with each input field. # Analogous to the version in `runners.dace_iterator.__init__`, which diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py b/src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py index 75bd79b88c..9c7273a2b8 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py @@ -845,7 +845,7 @@ def _perform_node_split_with_bypass_impl( """Performs the splitting but the edge might go directly to the consumer. # TODO(): Remove the producer edge, run reconfiguration, split operation. 
- # TODO ADDING PRODUCER TO THE SET OF PROCESSED NODES + # TODO(): ADDING PRODUCER TO THE SET OF PROCESSED NODES """ producer_edge_desc = next( diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 8e0007e315..dea06c30d6 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -151,8 +151,8 @@ def from_type_hint( return ts.FunctionType( pos_only_args=new_args, pos_or_kw_args=kwargs, - kw_only_args={}, # TODO - returns=returns, + kw_only_args={}, # TODO(): fix + returns=returns, ) raise ValueError(f"'{type_hint}' type is not supported.") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 8ce734ef22..9ad435b1cb 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -274,7 +274,7 @@ def test_lap_like(cartesian_case): def testee( inp: cases.IJField, boundary: np.int32, shape: tuple[np.int32, np.int32] ) -> cases.IJField: - # TODO add support for multi-dimensional concat_where masks + # TODO(): add support for multi-dimensional concat_where masks return concat_where( (IDim == 0) | (IDim == shape[0] - 1), boundary, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 90c8f8fa3a..827581efb0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -119,7 +119,7 @@ def make_builtin_field_operator(builtin_name: str, backend: Optional[next_backen
math_builtin_test_data()) def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inputs): if cartesian_case.backend is None: - # TODO(havogt) find a way that works for embedded + # TODO(havogt): find a way that works for embedded pytest.xfail("Test does not have a field view program.") if builtin_name == "gamma": # numpy has no gamma function diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 1707adada8..827154da38 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -175,7 +175,7 @@ def test_unary_not(cartesian_case): pytest.xfail( "We accidentally supported 'not' on fields. This is wrong, we should raise an error." ) - with pytest.raises: # TODO 'not' on a field should be illegal + with pytest.raises: # TODO(): 'not' on a field should be illegal @gtx.field_operator def not_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index ea89bb23ba..cdf0d40bc1 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -230,7 +230,7 @@ def test_tuple_field_input(program_processor): [IDim, JDim, KDim], rng.normal( size=(shape[0], shape[1], shape[2] + 1) - ), # TODO(havogt) currently we allow different sizes, needed for icon4py compatibility + ), # TODO(havogt): currently we allow different sizes, needed for icon4py compatibility ) out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) diff --git 
a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index da354be7ea..66bf9652bd 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -109,7 +109,7 @@ def test_ffront_nabla(exec_alloc_descriptor): }, ) - # TODO this check is not sensitive enough, need to implement a proper numpy reference! + # TODO(): this check is not sensitive enough, need to implement a proper numpy reference! assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX.asnumpy())) assert_close(3.5455427772565435e-003, np.max(pnabla_MXX.asnumpy())) assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY.asnumpy())) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 22b9f5f9e8..962e61cf6a 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -478,4 +478,4 @@ def test_different_vertical_sizes_with_origin(program_processor): assert np.allclose(ref, out.asnumpy()) -# TODO(havogt) test tuple_get builtin on a Column +# TODO(havogt): test tuple_get builtin on a Column diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 22b4d8b3c5..30649b42e2 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -270,7 +270,7 @@ def 
compute_pnabla_sign(pp, S_M, vol, node_index, is_pole_edge): @fendef def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_pole_edge): - # TODO replace by single stencil which returns tuple + # TODO(): replace by single stencil which returns tuple domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) set_at( as_fieldop(compute_pnabla_sign, domain)(pp, S_MXX, vol, node_index, is_pole_edge), diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index ff87de7348..3119c6a36c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -49,11 +49,11 @@ from next_tests.unit_tests.conftest import program_processor, run_processor -def edge_index_field(): # TODO replace by gtx.index_field once supported in bindings +def edge_index_field(): # TODO(): replace by gtx.index_field once supported in bindings return gtx.as_field([Edge], np.arange(e2v_arr.shape[0], dtype=np.int32)) -def vertex_index_field(): # TODO replace by gtx.index_field once supported in bindings +def vertex_index_field(): # TODO(): replace by gtx.index_field once supported in bindings return gtx.as_field([Vertex], np.arange(v2e_arr.shape[0], dtype=np.int32)) From b0aaf64a51a7e6e0f0cdb40a6b3f05c8cfba053b Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 20 Aug 2025 09:47:42 +0200 Subject: [PATCH 3/4] WIP 3 --- pyproject.toml | 12 +- scripts/nox_sessions.py | 19 ++- scripts/update.py | 6 +- src/gt4py/_core/definitions.py | 12 +- src/gt4py/cartesian/backend/base.py | 3 +- src/gt4py/cartesian/backend/cuda_backend.py | 8 +- .../cartesian/backend/dace_lazy_stencil.py | 17 +-- .../cartesian/backend/dace_stencil_object.py | 31 ++--- 
src/gt4py/cartesian/backend/gtc_common.py | 12 +- src/gt4py/cartesian/backend/gtcpp_backend.py | 4 +- .../cartesian/backend/module_generator.py | 6 +- src/gt4py/cartesian/backend/pyext_builder.py | 46 +++---- src/gt4py/cartesian/caching.py | 28 ++-- src/gt4py/cartesian/config.py | 12 +- src/gt4py/cartesian/definitions.py | 24 ++-- src/gt4py/cartesian/frontend/defir_to_gtir.py | 42 +++--- src/gt4py/cartesian/frontend/exceptions.py | 28 ++-- .../cartesian/frontend/gtscript_frontend.py | 98 ++++++-------- src/gt4py/cartesian/frontend/node_util.py | 5 +- src/gt4py/cartesian/frontend/nodes.py | 5 +- src/gt4py/cartesian/gt_cache_manager.py | 5 +- src/gt4py/cartesian/gtc/common.py | 124 ++++++++---------- src/gt4py/cartesian/gtc/cuir/cuir.py | 56 ++++---- src/gt4py/cartesian/gtc/cuir/cuir_codegen.py | 17 +-- .../cartesian/gtc/cuir/extent_analysis.py | 8 +- src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py | 20 +-- src/gt4py/cartesian/gtc/dace/oir_to_treeir.py | 6 +- src/gt4py/cartesian/gtc/dace/treeir.py | 3 +- src/gt4py/cartesian/gtc/dace/utils.py | 4 +- .../cartesian/gtc/debug/debug_codegen.py | 3 +- src/gt4py/cartesian/gtc/definitions.py | 36 +++-- src/gt4py/cartesian/gtc/gtcpp/gtcpp.py | 46 +++---- .../cartesian/gtc/gtcpp/gtcpp_codegen.py | 11 +- src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py | 31 ++--- src/gt4py/cartesian/gtc/gtir.py | 34 ++--- src/gt4py/cartesian/gtc/gtir_to_oir.py | 20 +-- src/gt4py/cartesian/gtc/numpy/npir.py | 40 +++--- src/gt4py/cartesian/gtc/numpy/npir_codegen.py | 51 ++++--- src/gt4py/cartesian/gtc/numpy/oir_to_npir.py | 24 ++-- .../cartesian/gtc/numpy/scalars_to_temps.py | 9 +- src/gt4py/cartesian/gtc/oir.py | 42 +++--- .../gtir_definitive_assignment_analysis.py | 13 +- .../gtc/passes/gtir_dtype_resolver.py | 8 +- .../cartesian/gtc/passes/gtir_k_boundary.py | 12 +- .../cartesian/gtc/passes/gtir_pipeline.py | 5 +- .../cartesian/gtc/passes/gtir_upcaster.py | 7 +- .../cartesian/gtc/passes/horizontal_masks.py | 14 +- 
.../cartesian/gtc/passes/oir_access_kinds.py | 10 +- .../gtc/passes/oir_optimizations/caches.py | 67 +++++----- .../horizontal_execution_merging.py | 29 ++-- .../gtc/passes/oir_optimizations/inlining.py | 22 ++-- .../oir_optimizations/mask_stmt_merging.py | 3 +- .../gtc/passes/oir_optimizations/pruning.py | 4 +- .../passes/oir_optimizations/temporaries.py | 13 +- .../gtc/passes/oir_optimizations/utils.py | 55 ++++---- .../cartesian/gtc/passes/oir_pipeline.py | 5 +- src/gt4py/cartesian/gtc/utils.py | 9 +- src/gt4py/cartesian/gtscript.py | 5 +- src/gt4py/cartesian/gtscript_imports.py | 13 +- src/gt4py/cartesian/lazy_stencil.py | 4 +- src/gt4py/cartesian/loader.py | 2 +- src/gt4py/cartesian/stencil_object.py | 7 +- .../cartesian/testing/input_strategies.py | 11 +- src/gt4py/cartesian/testing/suites.py | 16 +-- src/gt4py/cartesian/utils/attrib.py | 34 ++--- src/gt4py/cartesian/utils/meta.py | 19 +-- src/gt4py/eve/codegen.py | 29 ++-- src/gt4py/eve/concepts.py | 4 +- src/gt4py/eve/datamodels/core.py | 25 ++-- src/gt4py/eve/extended_typing.py | 4 +- src/gt4py/eve/traits.py | 4 +- src/gt4py/eve/utils.py | 72 +++++----- src/gt4py/next/embedded/operators.py | 3 +- src/gt4py/next/ffront/dialect_parser.py | 2 +- src/gt4py/next/ffront/experimental.py | 7 +- src/gt4py/next/ffront/fbuiltins.py | 19 +-- src/gt4py/next/ffront/field_operator_ast.py | 14 +- .../ffront/foast_passes/type_deduction.py | 4 +- src/gt4py/next/ffront/foast_pretty_printer.py | 4 +- src/gt4py/next/ffront/foast_to_gtir.py | 3 +- src/gt4py/next/ffront/func_to_foast.py | 5 +- src/gt4py/next/ffront/lowering_utils.py | 4 +- src/gt4py/next/ffront/past_process_args.py | 6 +- src/gt4py/next/ffront/program_ast.py | 8 +- src/gt4py/next/ffront/signature.py | 3 +- src/gt4py/next/ffront/transform_utils.py | 3 +- src/gt4py/next/ffront/type_info.py | 3 +- src/gt4py/next/iterator/dispatcher.py | 7 +- src/gt4py/next/iterator/embedded.py | 9 +- src/gt4py/next/iterator/ir.py | 20 +-- .../ir_utils/common_pattern_matcher.py | 6 
+- .../next/iterator/ir_utils/domain_utils.py | 3 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 13 +- src/gt4py/next/iterator/ir_utils/misc.py | 3 +- src/gt4py/next/iterator/pretty_parser.py | 5 +- src/gt4py/next/iterator/runtime.py | 5 +- src/gt4py/next/iterator/tracing.py | 13 +- src/gt4py/next/iterator/transforms/cse.py | 5 +- .../transforms/fixed_point_transformation.py | 4 +- .../next/iterator/transforms/global_tmps.py | 4 +- .../iterator/transforms/inline_fundefs.py | 4 +- .../iterator/transforms/inline_into_scan.py | 3 +- .../iterator/transforms/inline_lambdas.py | 3 +- .../next/iterator/transforms/inline_lifts.py | 3 +- .../next/iterator/transforms/remap_symbols.py | 12 +- .../next/iterator/transforms/trace_shifts.py | 4 +- .../next/iterator/type_system/inference.py | 6 +- .../iterator/type_system/type_synthesizer.py | 2 +- src/gt4py/next/metrics.py | 4 +- src/gt4py/next/otf/arguments.py | 3 +- src/gt4py/next/otf/binding/cpp_interface.py | 3 +- src/gt4py/next/otf/binding/nanobind.py | 12 +- .../compilation/build_systems/cmake_lists.py | 2 +- .../compilation/build_systems/compiledb.py | 2 +- src/gt4py/next/otf/workflow.py | 4 +- .../codegens/gtfn/codegen.py | 5 +- .../codegens/gtfn/gtfn_im_ir.py | 8 +- .../codegens/gtfn/gtfn_ir.py | 17 +-- .../codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py | 11 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 15 ++- .../program_processors/program_formatter.py | 3 +- .../runners/dace/gtir_dataflow.py | 37 ++---- .../runners/dace/gtir_domain.py | 3 +- .../runners/dace/gtir_python_codegen.py | 3 +- .../runners/dace/gtir_to_sdfg.py | 39 ++---- .../runners/dace/gtir_to_sdfg_primitives.py | 3 +- .../runners/dace/gtir_to_sdfg_scan.py | 3 +- .../runners/dace/gtir_to_sdfg_utils.py | 8 +- .../runners/dace/program.py | 3 +- .../dace/transformations/auto_optimize.py | 3 +- .../dead_dataflow_elimination.py | 3 +- .../runners/dace/transformations/gpu_utils.py | 29 ++-- .../dace/transformations/loop_blocking.py | 8 +- 
.../dace/transformations/map_fusion.py | 7 +- .../transformations/map_fusion_extended.py | 12 +- .../dace/transformations/map_orderer.py | 10 +- .../dace/transformations/map_promoter.py | 15 +-- ...ulti_state_global_self_copy_elimination.py | 8 +- .../redundant_array_removers.py | 3 +- .../runners/dace/transformations/simplify.py | 5 +- ...ngle_state_global_self_copy_elimination.py | 11 +- .../transformations/split_access_nodes.py | 5 +- .../dace/transformations/splitting_tools.py | 13 +- .../runners/dace/transformations/utils.py | 13 +- .../program_processors/runners/dace/utils.py | 7 +- .../runners/dace/workflow/bindings.py | 8 +- .../runners/dace/workflow/common.py | 3 +- .../runners/dace/workflow/compilation.py | 3 +- .../runners/dace/workflow/decoration.py | 3 +- src/gt4py/next/type_system/type_info.py | 17 +-- .../next/type_system/type_specifications.py | 5 +- src/gt4py/next/utils.py | 3 +- src/gt4py/storage/allocators.py | 6 +- src/gt4py/storage/cartesian/interface.py | 13 +- src/gt4py/storage/cartesian/layout.py | 39 ++---- src/gt4py/storage/cartesian/utils.py | 25 ++-- 156 files changed, 1060 insertions(+), 1160 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bbdab9988b..a6e77becb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -346,22 +346,20 @@ select = [ # 'N', # pep8-naming 'NPY', # NumPy-specific rules 'PERF', # Perflint - #'PGH', # pygrep-hooks - 'PLC', # Pylint-Convention + # 'PGH', # pygrep-hooks + # 'PLC', # Pylint-Convention 'PLE', # Pylint-Error - #'PLR', # Pylint-Refactor + # 'PLR', # Pylint-Refactor 'PLW', # Pylint-Warning 'PTH', # flake8-use-pathlib 'Q', # flake8-quotes 'RUF', # Ruff-specific rules - #'SIM', # flake8-simplify + # 'SIM', # flake8-simplify 'T10', # flake8-debugger 'TD', # flake8-todos - #'UP', # pyupgrade + 'UP', # pyupgrade 'YTT' # flake8-2020 ] - - typing-modules = ['gt4py.eve.extended_typing'] unfixable = [] diff --git a/scripts/nox_sessions.py b/scripts/nox_sessions.py index 95ee6e2da3..e206edc5fe 
100644 --- a/scripts/nox_sessions.py +++ b/scripts/nox_sessions.py @@ -48,17 +48,14 @@ class ExitCode(enum.IntEnum): {"name": str, "paths": NotRequired[list[str]], "ignore-paths": NotRequired[list[str]]}, ) -NoxSessionDefinition = TypedDict( - "NoxSessionDefinition", - { - "session": str, - "name": str, - "description": str, - "python": str, - "tags": list[str], - "call_spec": dict[str, str], - }, -) + +class NoxSessionDefinition(TypedDict): + session: str + name: str + description: str + python: str + tags: list[str] + call_spec: dict[str, str] cli = typer.Typer(no_args_is_help=True, name="nox-sessions", help=__doc__) diff --git a/scripts/update.py b/scripts/update.py index 8780e2ab20..da93b1d8c3 100755 --- a/scripts/update.py +++ b/scripts/update.py @@ -42,7 +42,9 @@ def dependencies() -> None: def precommit() -> None: """Update versions of pre-commit hooks.""" subprocess.run( - f"uv run --quiet --locked --project {common.REPO_ROOT} pre-commit autoupdate", shell=True, check=False + f"uv run --quiet --locked --project {common.REPO_ROOT} pre-commit autoupdate", + shell=True, + check=False, ) try: @@ -57,7 +59,7 @@ def precommit() -> None: try: pre_commit_path = common.REPO_ROOT / ".pre-commit-config.yaml" - with open(pre_commit_path, "r", encoding="utf-8") as f: + with open(pre_commit_path, encoding="utf-8") as f: content = f.read() new_content = re.sub( diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 3d7e77629e..ae3cd40d7d 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -71,7 +71,7 @@ float32 = np.float32 float64 = np.float64 -BoolScalar: TypeAlias = Union[bool_, bool] +BoolScalar: TypeAlias = bool_ | bool BoolT = TypeVar("BoolT", bound=BoolScalar) BOOL_TYPES: Final[Tuple[type, ...]] = cast( Tuple[type, ...], @@ -79,7 +79,7 @@ ) -IntScalar: TypeAlias = Union[int8, int16, int32, int64, int] +IntScalar: TypeAlias = int8 | int16 | int32 | int64 | int IntT = TypeVar("IntT", bound=IntScalar) 
INT_TYPES: Final[Tuple[type, ...]] = cast( Tuple[type, ...], @@ -87,7 +87,7 @@ ) -UnsignedIntScalar: TypeAlias = Union[uint8, uint16, uint32, uint64] +UnsignedIntScalar: TypeAlias = uint8 | uint16 | uint32 | uint64 UnsignedIntT = TypeVar("UnsignedIntT", bound=UnsignedIntScalar) UINT_TYPES: Final[Tuple[type, ...]] = cast( Tuple[type, ...], @@ -95,12 +95,12 @@ ) -IntegralScalar: TypeAlias = Union[IntScalar, UnsignedIntScalar] +IntegralScalar: TypeAlias = IntScalar | UnsignedIntScalar IntegralT = TypeVar("IntegralT", bound=IntegralScalar) INTEGRAL_TYPES: Final[Tuple[type, ...]] = (*INT_TYPES, *UINT_TYPES) -FloatingScalar: TypeAlias = Union[float32, float64, float] +FloatingScalar: TypeAlias = float32 | float64 | float FloatingT = TypeVar("FloatingT", bound=FloatingScalar) FLOAT_TYPES: Final[Tuple[type, ...]] = cast( Tuple[type, ...], @@ -109,7 +109,7 @@ #: Type alias for all scalar types supported by GT4Py -Scalar: TypeAlias = Union[BoolScalar, IntegralScalar, FloatingScalar] +Scalar: TypeAlias = BoolScalar | IntegralScalar | FloatingScalar ScalarT = TypeVar("ScalarT", bound=Scalar) SCALAR_TYPES: Final[tuple[type, ...]] = (*BOOL_TYPES, *INTEGRAL_TYPES, *FLOAT_TYPES) diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index d37352d09e..297ac78bf9 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -15,7 +15,8 @@ import pathlib import time import warnings -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, ClassVar, Protocol from typing_extensions import deprecated diff --git a/src/gt4py/cartesian/backend/cuda_backend.py b/src/gt4py/cartesian/backend/cuda_backend.py index abf53a1869..ecbb982dbe 100644 --- a/src/gt4py/cartesian/backend/cuda_backend.py +++ b/src/gt4py/cartesian/backend/cuda_backend.py @@ -82,8 +82,8 @@ def visit_FieldDecl(self, node: cuir.FieldDecl, **kwargs): data_ndim = 
len(node.data_dims) sid_ndim = domain_ndim + data_ndim if kwargs["external_arg"]: - return "py::object {name}, std::array {name}_origin".format( - name=node.name, sid_ndim=sid_ndim + return ( + f"py::object {node.name}, std::array {node.name}_origin" ) else: return pybuffer_to_sid( @@ -98,9 +98,9 @@ def visit_FieldDecl(self, node: cuir.FieldDecl, **kwargs): def visit_ScalarDecl(self, node: cuir.ScalarDecl, **kwargs): if "external_arg" in kwargs: if kwargs["external_arg"]: - return "{dtype} {name}".format(name=node.name, dtype=self.visit(node.dtype)) + return f"{self.visit(node.dtype)} {node.name}" else: - return "gridtools::stencil::global_parameter({name})".format(name=node.name) + return f"gridtools::stencil::global_parameter({node.name})" def visit_Program(self, node: cuir.Program, **kwargs): assert "module_name" in kwargs diff --git a/src/gt4py/cartesian/backend/dace_lazy_stencil.py b/src/gt4py/cartesian/backend/dace_lazy_stencil.py index 38a41ac466..bf82b984b6 100644 --- a/src/gt4py/cartesian/backend/dace_lazy_stencil.py +++ b/src/gt4py/cartesian/backend/dace_lazy_stencil.py @@ -8,7 +8,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Tuple +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Optional import dace from dace.frontend.python.common import SDFGConvertible @@ -30,7 +31,7 @@ def __init__(self, builder: StencilBuilder): super().__init__(builder=builder) @property - def field_info(self) -> Dict[str, Any]: + def field_info(self) -> dict[str, Any]: """ Return same value as compiled stencil object's `field_info` attribute. 
@@ -40,10 +41,10 @@ def field_info(self) -> Dict[str, Any]: def closure_resolver( self, - constant_args: Dict[str, Any], - given_args: Set[str], - parent_closure: Optional["dace.frontend.python.common.SDFGClosure"] = None, - ) -> "dace.frontend.python.common.SDFGClosure": + constant_args: dict[str, Any], + given_args: set[str], + parent_closure: Optional[dace.frontend.python.common.SDFGClosure] = None, + ) -> dace.frontend.python.common.SDFGClosure: return dace.frontend.python.common.SDFGClosure() def __sdfg__(self, *args, **kwargs) -> dace.SDFG: @@ -67,9 +68,9 @@ def __sdfg__(self, *args, **kwargs) -> dace.SDFG: **norm_kwargs, ) - def __sdfg_closure__(self, reevaluate: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]: return {} - def __sdfg_signature__(self) -> Tuple[Sequence[str], Sequence[str]]: + def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: args = [arg.name for arg in self.builder.gtir.api_signature] return (args, []) diff --git a/src/gt4py/cartesian/backend/dace_stencil_object.py b/src/gt4py/cartesian/backend/dace_stencil_object.py index 600295fc24..3dbb38eaf4 100644 --- a/src/gt4py/cartesian/backend/dace_stencil_object.py +++ b/src/gt4py/cartesian/backend/dace_stencil_object.py @@ -11,8 +11,9 @@ import copy import inspect import os +from collections.abc import Iterable, Sequence from dataclasses import dataclass -from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple +from typing import Any, Optional import dace import dace.data @@ -26,7 +27,7 @@ from gt4py.cartesian.utils import shash -def _extract_array_infos(field_args, device) -> Dict[str, Optional[ArgsInfo]]: +def _extract_array_infos(field_args, device) -> dict[str, Optional[ArgsInfo]]: return { name: ArgsInfo( array=arg, @@ -40,8 +41,8 @@ def _extract_array_infos(field_args, device) -> Dict[str, Optional[ArgsInfo]]: def add_optional_fields( sdfg: dace.SDFG, - 
field_info: Dict[str, Any], - parameter_info: Dict[str, Any], + field_info: dict[str, Any], + parameter_info: dict[str, Any], **kwargs: Any, ) -> dace.SDFG: sdfg = copy.deepcopy(sdfg) @@ -65,8 +66,8 @@ def add_optional_fields( @dataclass(frozen=True) class DaCeFrozenStencil(FrozenStencil, SDFGConvertible): stencil_object: DaCeStencilObject - origin: Dict[str, Tuple[int, ...]] - domain: Tuple[int, ...] + origin: dict[str, tuple[int, ...]] + domain: tuple[int, ...] sdfg: dace.SDFG def __sdfg__(self, **kwargs): @@ -106,8 +107,8 @@ def _get_domain_origin_key(domain, origin): def freeze( self: DaCeStencilObject, *, - origin: Dict[str, Tuple[int, ...]], - domain: Tuple[int, ...], + origin: dict[str, tuple[int, ...]], + domain: tuple[int, ...], ) -> DaCeFrozenStencil: key = DaCeStencilObject._get_domain_origin_key(domain, origin) @@ -142,8 +143,8 @@ def sdfg(cls) -> dace.SDFG: def closure_resolver( self, - constant_args: Dict[str, Any], - given_args: Set[str], + constant_args: dict[str, Any], + given_args: set[str], parent_closure: Optional[dace.frontend.python.common.SDFGClosure] = None, ) -> dace.frontend.python.common.SDFGClosure: return dace.frontend.python.common.SDFGClosure() @@ -161,10 +162,10 @@ def __sdfg__(self, *args, **kwargs) -> dace.SDFG: frozen_stencil = self.freeze(origin=norm_kwargs["origin"], domain=norm_kwargs["domain"]) return frozen_stencil.__sdfg__(**norm_kwargs) - def __sdfg_closure__(self, reevaluate: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]: return {} - def __sdfg_signature__(self) -> Tuple[Sequence[str], Sequence[str]]: + def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: special_args = {"self", "domain", "origin", "validate_args", "exec_info"} args = [] for arg in ( @@ -182,9 +183,9 @@ def normalize_arg_fields( backend: str, arg_names: Iterable[str], domain_info: DomainInfo, - field_info: Dict[str, FieldInfo], - domain: 
Optional[Tuple[int, ...]] = None, - origin: Optional[Dict[str, Tuple[int, ...]]] = None, + field_info: dict[str, FieldInfo], + domain: Optional[tuple[int, ...]] = None, + origin: Optional[dict[str, tuple[int, ...]]] = None, **kwargs, ): """Normalize Fields in argument list to the proper domain/origin""" diff --git a/src/gt4py/cartesian/backend/gtc_common.py b/src/gt4py/cartesian/backend/gtc_common.py index f5e694be28..7af70a4d88 100644 --- a/src/gt4py/cartesian/backend/gtc_common.py +++ b/src/gt4py/cartesian/backend/gtc_common.py @@ -41,13 +41,9 @@ def pybuffer_to_sid( as_sid = "as_cuda_sid" if backend.storage_info["device"] == "gpu" else "as_sid" - sid_def = """gt::{as_sid}<{ctype}, {sid_ndim}, - gt::integral_constant>({name})""".format( - name=name, ctype=ctype, unique_index=stride_kind_index, sid_ndim=sid_ndim, as_sid=as_sid - ) - sid_def = "gt::sid::shift_sid_origin({sid_def}, {name}_origin)".format( - sid_def=sid_def, name=name - ) + sid_def = f"""gt::{as_sid}<{ctype}, {sid_ndim}, + gt::integral_constant>({name})""" + sid_def = f"gt::sid::shift_sid_origin({sid_def}, {name}_origin)" if domain_ndim != 3: gt_dims = [ f"gt::stencil::dim::{dim}" @@ -149,7 +145,7 @@ def generate_implementation(self) -> str: for decl in ir.params: args.append(decl.name) if isinstance(decl, gtir.FieldDecl): - args.append("list(_origin_['{}'])".format(decl.name)) + args.append(f"list(_origin_['{decl.name}'])") # only generate implementation if any multi_stages are present. e.g. 
if no statement in the # stencil has any effect on the API fields, this may not be the case since they could be diff --git a/src/gt4py/cartesian/backend/gtcpp_backend.py b/src/gt4py/cartesian/backend/gtcpp_backend.py index adfca17c7e..3d41423c13 100644 --- a/src/gt4py/cartesian/backend/gtcpp_backend.py +++ b/src/gt4py/cartesian/backend/gtcpp_backend.py @@ -101,9 +101,9 @@ def visit_FieldDecl(self, node: gtcpp.FieldDecl, **kwargs): def visit_GlobalParamDecl(self, node: gtcpp.GlobalParamDecl, **kwargs): if "external_arg" in kwargs: if kwargs["external_arg"]: - return "{dtype} {name}".format(name=node.name, dtype=self.visit(node.dtype)) + return f"{self.visit(node.dtype)} {node.name}" else: - return "gridtools::stencil::global_parameter({name})".format(name=node.name) + return f"gridtools::stencil::global_parameter({node.name})" def visit_Program(self, node: gtcpp.Program, **kwargs): assert "module_name" in kwargs diff --git a/src/gt4py/cartesian/backend/module_generator.py b/src/gt4py/cartesian/backend/module_generator.py index 6a37fe0558..c1f51a35ba 100644 --- a/src/gt4py/cartesian/backend/module_generator.py +++ b/src/gt4py/cartesian/backend/module_generator.py @@ -266,14 +266,12 @@ def generate_signature(self) -> str: for arg in self.builder.gtir.api_signature: if arg.is_keyword: if arg.default: - keyword_args.append( - "{name}={default}".format(name=arg.name, default=arg.default) - ) + keyword_args.append(f"{arg.name}={arg.default}") else: keyword_args.append(arg.name) else: if arg.default: - args.append("{name}={default}".format(name=arg.name, default=arg.default)) + args.append(f"{arg.name}={arg.default}") else: args.append(arg.name) diff --git a/src/gt4py/cartesian/backend/pyext_builder.py b/src/gt4py/cartesian/backend/pyext_builder.py index 579763a6a2..974b3da541 100644 --- a/src/gt4py/cartesian/backend/pyext_builder.py +++ b/src/gt4py/cartesian/backend/pyext_builder.py @@ -12,7 +12,7 @@ import os import shutil import threading -from typing import Any, Dict, 
List, Literal, Optional, Tuple, Type, TypedDict, Union +from typing import Any, Literal, Optional, TypedDict import pybind11 import setuptools @@ -58,7 +58,7 @@ def get_gt_pyext_build_opts( add_profile_info: bool = False, uses_openmp: bool = True, uses_cuda: bool = False, -) -> Dict[str, Union[str, List[str], Dict[str, Any]]]: +) -> dict[str, str | list[str] | dict[str, Any]]: include_dirs: list[str] = [] extra_compile_args_from_config = gt_config.build_settings["extra_compile_args"] is_rocm_gpu = core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM @@ -88,7 +88,7 @@ def get_gt_pyext_build_opts( # A compiler is allowed to choose if `char` is signed or unsigned. We force the signed behavior # because `char` is used to represent the `int8` type in GT4Py programs. "-fsigned-char", - "-isystem{}".format(gt_include_path), + f"-isystem{gt_include_path}", *extra_compile_args_from_config["cxx"], ] ) @@ -99,14 +99,14 @@ def get_gt_pyext_build_opts( ] if is_rocm_gpu: extra_compile_args["cuda"] += [ - "-isystem{}".format(gt_include_path), + f"-isystem{gt_include_path}", "-fvisibility=hidden", "-fPIC", ] else: extra_compile_args["cuda"] += [ - "-isystem={}".format(gt_include_path), - "-arch=sm_{}".format(cuda_arch), + f"-isystem={gt_include_path}", + f"-arch=sm_{cuda_arch}", "--expt-relaxed-constexpr", "--compiler-options", "-fvisibility=hidden", @@ -199,15 +199,15 @@ def build_pybind_ext( build_path: str, target_path: str, *, - include_dirs: Optional[List[str]] = None, - library_dirs: Optional[List[str]] = None, - libraries: Optional[List[str]] = None, - extra_compile_args: Optional[Union[List[str], Dict[str, List[str]]]] = None, - extra_link_args: Optional[List[str]] = None, - build_ext_class: Optional[Type] = None, + include_dirs: Optional[list[str]] = None, + library_dirs: Optional[list[str]] = None, + libraries: Optional[list[str]] = None, + extra_compile_args: Optional[list[str] | dict[str, list[str]]] = None, + extra_link_args: Optional[list[str]] = None, + 
build_ext_class: Optional[type] = None, verbose: bool = False, clean: bool = False, -) -> Tuple[str, str]: +) -> tuple[str, str]: # Hack to remove warning about "-Wstrict-prototypes" not having effect in C++ replaced_flags_backup = copy.deepcopy(distutils.sysconfig._config_vars) _clean_build_flags(distutils.sysconfig._config_vars) @@ -239,8 +239,8 @@ def build_pybind_ext( ext_modules=[py_extension], script_args=[ "build_ext", - "--build-temp={}".format(build_path), - "--build-lib={}".format(build_path), + f"--build-temp={build_path}", + f"--build-lib={build_path}", "--force", ], ) @@ -280,14 +280,14 @@ def build_pybind_cuda_ext( build_path: str, target_path: str, *, - include_dirs: Optional[List[str]] = None, - library_dirs: Optional[List[str]] = None, - libraries: Optional[List[str]] = None, - extra_compile_args: Optional[Union[List[str], Dict[str, List[str]]]] = None, - extra_link_args: Optional[List[str]] = None, + include_dirs: Optional[list[str]] = None, + library_dirs: Optional[list[str]] = None, + libraries: Optional[list[str]] = None, + extra_compile_args: Optional[list[str] | dict[str, list[str]]] = None, + extra_link_args: Optional[list[str]] = None, verbose: bool = False, clean: bool = False, -) -> Tuple[str, str]: +) -> tuple[str, str]: include_dirs = include_dirs or [] include_dirs = [*include_dirs, gt_config.build_settings["cuda_include_path"]] library_dirs = library_dirs or [] @@ -315,7 +315,7 @@ def build_pybind_cuda_ext( ) -def _clean_build_flags(config_vars: Dict[str, str]) -> None: +def _clean_build_flags(config_vars: dict[str, str]) -> None: for key, value in config_vars.items(): if isinstance(value, str): value = " " + value + " " @@ -329,7 +329,7 @@ def _clean_build_flags(config_vars: Dict[str, str]) -> None: config_vars[key] = " ".join(value.split()) -class CUDABuildExtension(build_ext, object): +class CUDABuildExtension(build_ext): # Refs: # - https://github.com/pytorch/pytorch/torch/utils/cpp_extension.py # - 
https://github.com/rmcgibbo/npcuda-example/blob/master/cython/setup.py diff --git a/src/gt4py/cartesian/caching.py b/src/gt4py/cartesian/caching.py index 79fd8edb5d..98755ad0d8 100644 --- a/src/gt4py/cartesian/caching.py +++ b/src/gt4py/cartesian/caching.py @@ -16,7 +16,7 @@ import pickle import sys import types -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Optional from cached_property import cached_property @@ -57,7 +57,7 @@ def cache_info_path(self) -> Optional[pathlib.Path]: raise NotImplementedError @abc.abstractmethod - def generate_cache_info(self) -> Dict[str, Any]: + def generate_cache_info(self) -> dict[str, Any]: """ Generate the cache info dict. @@ -102,7 +102,7 @@ def stencil_id(self) -> StencilID: @property @abc.abstractmethod - def cache_info(self) -> Dict[str, Any]: + def cache_info(self) -> dict[str, Any]: """ Read currently stored cache info from file into a dictionary. @@ -151,7 +151,7 @@ def class_name(self) -> str: """Calculate the name for the stencil class, default is to read from build options.""" return self.builder.options.name - def capture_externals(self) -> Dict[str, Any]: + def capture_externals(self) -> dict[str, Any]: """Extract externals from the annotated stencil definition for fingerprinting. 
Freezes the references.""" return {} @@ -197,9 +197,7 @@ def root_path(self) -> pathlib.Path: @property def backend_root_path(self) -> pathlib.Path: - cpython_id = "py{version.major}{version.minor}_{api_version}".format( - version=sys.version_info, api_version=sys.api_version - ) + cpython_id = f"py{sys.version_info.major}{sys.version_info.minor}_{sys.api_version}" backend_root = self.root_path / cpython_id / gt_utils.slugify(self.builder.backend.name) if not backend_root.exists(): if not backend_root.parent.exists(): @@ -212,7 +210,7 @@ def cache_info_path(self) -> Optional[pathlib.Path]: """Get the cache info file path from the stencil module path.""" return self.builder.module_path.parent / f"{self.builder.module_path.stem}.cacheinfo" - def generate_cache_info(self) -> Dict[str, Any]: + def generate_cache_info(self) -> dict[str, Any]: return { "backend": self.builder.backend.name, "stencil_name": self.builder.stencil_id.qualified_name, @@ -265,7 +263,7 @@ def is_cache_info_available_and_consistent( return result @property - def cache_info(self) -> Dict[str, Any]: + def cache_info(self) -> dict[str, Any]: if not self.cache_info_path: return {} if not self.cache_info_path.exists(): @@ -273,7 +271,7 @@ def cache_info(self) -> Dict[str, Any]: return self._unpickle_cache_info_file(self.cache_info_path) @staticmethod - def _unpickle_cache_info_file(cache_info_path: pathlib.Path) -> Dict[str, Any]: + def _unpickle_cache_info_file(cache_info_path: pathlib.Path) -> dict[str, Any]: with cache_info_path.open("rb") as cache_info_file: return pickle.load(cache_info_file) @@ -283,19 +281,19 @@ def options_id(self) -> str: return self.builder.backend.filter_options_for_id(self.builder.options).shashed_id return self.builder.options.shashed_id - def capture_externals(self) -> Dict[str, Any]: + def capture_externals(self) -> dict[str, Any]: """Extract externals from the annotated stencil definition for fingerprinting.""" return self._externals @cached_property - def 
_externals(self) -> Dict[str, Any]: + def _externals(self) -> dict[str, Any]: """Extract externals from the annotated stencil definition for fingerprinting.""" return { name: value._gtscript_["canonical_ast"] if hasattr(value, "_gtscript_") else value for name, value in self.builder.definition._gtscript_["externals"].items() } - def _extract_api_annotations(self) -> List[str]: + def _extract_api_annotations(self) -> list[str]: """Extract API annotations from the annotated stencil definition for fingerprinting.""" return [str(item) for item in self.builder.definition._gtscript_["api_annotations"]] @@ -389,7 +387,7 @@ def backend_root_path(self) -> pathlib.Path: def cache_info_path(self) -> Optional[pathlib.Path]: return None - def generate_cache_info(self) -> Dict[str, Any]: + def generate_cache_info(self) -> dict[str, Any]: return {} def update_cache_info(self) -> None: @@ -399,7 +397,7 @@ def is_cache_info_available_and_consistent(self, *, validate_hash: bool) -> bool return False @property - def cache_info(self) -> Dict[str, Any]: + def cache_info(self) -> dict[str, Any]: return {} @property diff --git a/src/gt4py/cartesian/config.py b/src/gt4py/cartesian/config.py index a9ded21ec7..4d3af7f322 100644 --- a/src/gt4py/cartesian/config.py +++ b/src/gt4py/cartesian/config.py @@ -8,7 +8,7 @@ import multiprocessing import os -from typing import Any, Dict, List, Optional +from typing import Any, Optional import gridtools_cpp @@ -32,13 +32,13 @@ # Settings dict GT4PY_EXTRA_COMPILE_ARGS: str = os.environ.get("GT4PY_EXTRA_COMPILE_ARGS", "") -extra_compile_args: List[str] = ( +extra_compile_args: list[str] = ( list(GT4PY_EXTRA_COMPILE_ARGS.split(" ")) if GT4PY_EXTRA_COMPILE_ARGS else [] ) GT4PY_EXTRA_LINK_ARGS: str = os.environ.get("GT4PY_EXTRA_LINK_ARGS", "") -extra_link_args: List[str] = list(GT4PY_EXTRA_LINK_ARGS.split(" ")) if GT4PY_EXTRA_LINK_ARGS else [] +extra_link_args: list[str] = list(GT4PY_EXTRA_LINK_ARGS.split(" ")) if GT4PY_EXTRA_LINK_ARGS else [] 
-build_settings: Dict[str, Any] = { +build_settings: dict[str, Any] = { "cuda_bin_path": os.path.join(CUDA_ROOT, "bin"), "cuda_include_path": os.path.join(CUDA_ROOT, "include"), "cuda_arch": os.environ.get("CUDA_ARCH", None), @@ -58,14 +58,14 @@ if CUDA_HOST_CXX is not None: build_settings["extra_compile_args"]["cuda"].append(f"-ccbin={CUDA_HOST_CXX}") -cache_settings: Dict[str, Any] = { +cache_settings: dict[str, Any] = { "dir_name": os.environ.get("GT_CACHE_DIR_NAME", ".gt_cache"), "root_path": os.environ.get("GT_CACHE_ROOT", os.path.abspath(".")), "load_retries": int(os.environ.get("GT_CACHE_LOAD_RETRIES", 3)), "load_retry_delay": int(os.environ.get("GT_CACHE_LOAD_RETRY_DELAY", 100)), # unit milliseconds } -code_settings: Dict[str, Any] = {"root_package_name": "_GT_"} +code_settings: dict[str, Any] = {"root_package_name": "_GT_"} os.environ.setdefault("DACE_CONFIG", os.path.join(os.path.abspath("."), ".dace.conf")) diff --git a/src/gt4py/cartesian/definitions.py b/src/gt4py/cartesian/definitions.py index 2d7ac99656..30c5db97be 100644 --- a/src/gt4py/cartesian/definitions.py +++ b/src/gt4py/cartesian/definitions.py @@ -11,7 +11,7 @@ import os import platform from dataclasses import dataclass -from typing import Literal, Tuple, Union +from typing import Literal import numpy @@ -55,7 +55,7 @@ def __str__(self): @dataclass(frozen=True) class DomainInfo: - parallel_axes: Tuple[str, ...] + parallel_axes: tuple[str, ...] sequential_axis: str min_sequential_axis_size: int ndim: int @@ -65,18 +65,12 @@ class DomainInfo: class FieldInfo: access: AccessKind boundary: Boundary - axes: Tuple[str, ...] - data_dims: Tuple[int, ...] + axes: tuple[str, ...] + data_dims: tuple[int, ...] 
dtype: numpy.dtype def __repr__(self): - return "FieldInfo(access=AccessKind.{access}, boundary={boundary}, axes={axes}, data_dims={data_dims}, dtype={dtype})".format( - access=self.access.name, - boundary=repr(self.boundary), - axes=repr(self.axes), - data_dims=repr(self.data_dims), - dtype=repr(self.dtype), - ) + return f"FieldInfo(access=AccessKind.{self.access.name}, boundary={self.boundary!r}, axes={self.axes!r}, data_dims={self.data_dims!r}, dtype={self.dtype!r})" @functools.cached_property def domain_mask(self): @@ -97,13 +91,11 @@ def ndim(self): @dataclass(frozen=True) class ParameterInfo: - access: Union[Literal[AccessKind.NONE], Literal[AccessKind.READ]] + access: Literal[AccessKind.NONE] | Literal[AccessKind.READ] dtype: numpy.dtype def __repr__(self): - return "ParameterInfo(access=AccessKind.{access}, dtype={dtype})".format( - access=self.access.name, dtype=repr(self.dtype) - ) + return f"ParameterInfo(access=AccessKind.{self.access.name}, dtype={self.dtype!r})" @attribkwclass @@ -126,7 +118,7 @@ class BuildOptions(AttributeClassLike): @property def qualified_name(self): - return ".".join([self.module, self.name]) + return f"{self.module}.{self.name}" @property def shashed_id(self): diff --git a/src/gt4py/cartesian/frontend/defir_to_gtir.py b/src/gt4py/cartesian/frontend/defir_to_gtir.py index 2075d06a61..96e94d3088 100644 --- a/src/gt4py/cartesian/frontend/defir_to_gtir.py +++ b/src/gt4py/cartesian/frontend/defir_to_gtir.py @@ -10,7 +10,7 @@ import functools import itertools import numbers -from typing import Any, Dict, Final, List, Optional, Tuple, Union, cast +from typing import Any, Final, Optional, cast import numpy as np @@ -68,7 +68,7 @@ def _convert_dtype(data_type) -> common.DataType: def _make_literal(v: numbers.Number) -> gtir.Literal: - value: Union[BuiltinLiteral, str] + value: BuiltinLiteral | str if isinstance(v, (bool, np.bool_)): dtype = common.DataType.BOOL value = common.BuiltInLiteral.TRUE if v else common.BuiltInLiteral.FALSE @@ 
-101,7 +101,7 @@ class UnrollVectorAssignments(IRNodeMapper): def apply(cls, root, **kwargs): return cls().visit(root, **kwargs) - def _is_vector_assignment(self, stmt: Node, fields_decls: Dict[str, FieldDecl]) -> bool: + def _is_vector_assignment(self, stmt: Node, fields_decls: dict[str, FieldDecl]) -> bool: if not isinstance(stmt, Assign): return False @@ -109,7 +109,7 @@ def _is_vector_assignment(self, stmt: Node, fields_decls: Dict[str, FieldDecl]) return fields_decls[stmt.target.name].data_dims and not stmt.target.data_index def visit_StencilDefinition( - self, node: StencilDefinition, *, fields_decls: Dict[str, FieldDecl], **kwargs + self, node: StencilDefinition, *, fields_decls: dict[str, FieldDecl], **kwargs ) -> StencilDefinition: node = copy.deepcopy(node) @@ -128,14 +128,14 @@ def visit_StencilDefinition( return node # computes dimensions of nested lists - def _nested_list_dim(self, a: List) -> List[int]: + def _nested_list_dim(self, a: list) -> list[int]: if not isinstance(a, list): return [] return [len(a), *self._nested_list_dim(a[0])] def visit_Assign( - self, node: Assign, *, fields_decls: Dict[str, FieldDecl], **kwargs - ) -> Union[gtir.ParAssignStmt, List[gtir.ParAssignStmt]]: + self, node: Assign, *, fields_decls: dict[str, FieldDecl], **kwargs + ) -> gtir.ParAssignStmt | list[gtir.ParAssignStmt]: if self._is_vector_assignment(node, fields_decls): assert isinstance(node.target, FieldRef) or isinstance(node.target, VarRef) target_dims = fields_decls[node.target.name].data_dims @@ -170,7 +170,7 @@ def visit_Assign( class UnrollVectorExpressions(IRNodeMapper): @classmethod - def apply(cls, root, *, expected_dim: Tuple[int, ...], fields_decls: Dict[str, FieldDecl]): + def apply(cls, root, *, expected_dim: tuple[int, ...], fields_decls: dict[str, FieldDecl]): result = cls().visit(root, fields_decls=fields_decls) # if the expression is just a scalar broadcast to the expected dimensions if not isinstance(result, list): @@ -179,10 +179,10 @@ def 
apply(cls, root, *, expected_dim: Tuple[int, ...], fields_decls: Dict[str, F ) return result - def visit_FieldRef(self, node: FieldRef, *, fields_decls: Dict[str, FieldDecl], **kwargs): + def visit_FieldRef(self, node: FieldRef, *, fields_decls: dict[str, FieldDecl], **kwargs): name = node.name if fields_decls[name].data_dims: - field_list: List[Union[FieldRef, List[FieldRef]]] = [] + field_list: list[FieldRef | list[FieldRef]] = [] # vector if len(fields_decls[name].data_dims) == 1: dims = fields_decls[name].data_dims[0] @@ -197,7 +197,7 @@ def visit_FieldRef(self, node: FieldRef, *, fields_decls: Dict[str, FieldDecl], elif len(fields_decls[name].data_dims) == 2: rows, cols = fields_decls[name].data_dims for row in range(rows): - row_list: List[FieldRef] = [] + row_list: list[FieldRef] = [] for col in range(cols): data_type = DataType.INT32 data_index = [ @@ -221,7 +221,7 @@ def visit_FieldRef(self, node: FieldRef, *, fields_decls: Dict[str, FieldDecl], return node - def visit_UnaryOpExpr(self, node: UnaryOpExpr, *, fields_decls: Dict[str, FieldDecl], **kwargs): + def visit_UnaryOpExpr(self, node: UnaryOpExpr, *, fields_decls: dict[str, FieldDecl], **kwargs): if node.op == UnaryOperator.TRANSPOSED: node = self.visit(node.arg, fields_decls=fields_decls, **kwargs) assert isinstance(node, list) and all( @@ -233,10 +233,10 @@ def visit_UnaryOpExpr(self, node: UnaryOpExpr, *, fields_decls: Dict[str, FieldD return self.generic_visit(node, **kwargs) - def visit_BinOpExpr(self, node: BinOpExpr, *, fields_decls: Dict[str, FieldDecl], **kwargs): + def visit_BinOpExpr(self, node: BinOpExpr, *, fields_decls: dict[str, FieldDecl], **kwargs): lhs = self.visit(node.lhs, fields_decls=fields_decls, **kwargs) rhs = self.visit(node.rhs, fields_decls=fields_decls, **kwargs) - result: Union[List[BinOpExpr], BinOpExpr] = [] + result: list[BinOpExpr] | BinOpExpr = [] if node.op == BinaryOperator.MATMULT: for j in range(len(lhs)): @@ -388,7 +388,7 @@ def visit_StencilDefinition(self, 
node: StencilDefinition) -> gtir.Stencil: loc=location_to_source_location(node.loc), ) - def visit_ArgumentInfo(self, node: ArgumentInfo, all_params: Dict[str, gtir.Decl]) -> gtir.Decl: + def visit_ArgumentInfo(self, node: ArgumentInfo, all_params: dict[str, gtir.Decl]) -> gtir.Decl: return all_params[node.name] def visit_ComputationBlock(self, node: ComputationBlock) -> gtir.VerticalLoop: @@ -412,7 +412,7 @@ def visit_ComputationBlock(self, node: ComputationBlock) -> gtir.VerticalLoop: loc=location_to_source_location(node.loc), ) - def visit_BlockStmt(self, node: BlockStmt) -> List[gtir.Stmt]: + def visit_BlockStmt(self, node: BlockStmt) -> list[gtir.Stmt]: return [self.visit(s) for s in node.stmts] def visit_Assign(self, node: Assign) -> gtir.ParAssignStmt: @@ -433,7 +433,7 @@ def visit_UnaryOpExpr(self, node: UnaryOpExpr) -> gtir.UnaryOp: loc=location_to_source_location(node.loc), ) - def visit_BinOpExpr(self, node: BinOpExpr) -> Union[gtir.BinaryOp, gtir.NativeFuncCall]: + def visit_BinOpExpr(self, node: BinOpExpr) -> gtir.BinaryOp | gtir.NativeFuncCall: if node.op in (BinaryOperator.POW, BinaryOperator.MOD): return gtir.NativeFuncCall( func=common.NativeFunction[node.op.name], @@ -485,7 +485,7 @@ def visit_FieldRef(self, node: FieldRef) -> gtir.FieldAccess: loc=location_to_source_location(node.loc), ) - def visit_If(self, node: If) -> Union[gtir.FieldIfStmt, gtir.ScalarIfStmt]: + def visit_If(self, node: If) -> gtir.FieldIfStmt | gtir.ScalarIfStmt: cond = self.visit(node.condition) if cond.kind == ExprKind.FIELD: return gtir.FieldIfStmt( @@ -540,7 +540,7 @@ def visit_While(self, node: While) -> gtir.While: def visit_VarRef(self, node: VarRef, **kwargs) -> gtir.ScalarAccess: return gtir.ScalarAccess(name=node.name, loc=location_to_source_location(node.loc)) - def visit_AxisInterval(self, node: AxisInterval) -> Tuple[gtir.AxisBound, gtir.AxisBound]: + def visit_AxisInterval(self, node: AxisInterval) -> tuple[gtir.AxisBound, gtir.AxisBound]: return 
self.visit(node.start), self.visit(node.end) def visit_AxisBound(self, node: AxisBound) -> gtir.AxisBound: @@ -570,8 +570,8 @@ def visit_VarDecl(self, node: VarDecl) -> gtir.ScalarDecl: ) def transform_offset( - self, offset: Dict[str, Union[int, Expr]], **kwargs: Any - ) -> Union[common.CartesianOffset, gtir.VariableKOffset]: + self, offset: dict[str, int | Expr], **kwargs: Any + ) -> common.CartesianOffset | gtir.VariableKOffset: k_val = offset.get("K", 0) if isinstance(k_val, numbers.Integral): return common.CartesianOffset(i=offset.get("I", 0), j=offset.get("J", 0), k=k_val) diff --git a/src/gt4py/cartesian/frontend/exceptions.py b/src/gt4py/cartesian/frontend/exceptions.py index e9e38ab53c..3c9801a6d2 100644 --- a/src/gt4py/cartesian/frontend/exceptions.py +++ b/src/gt4py/cartesian/frontend/exceptions.py @@ -20,13 +20,9 @@ class GTScriptSymbolError(GTScriptSyntaxError): def __init__(self, name, message=None, *, loc=None): if message is None: if loc is None: - message = "Unknown symbol '{name}' symbol".format(name=name) + message = f"Unknown symbol '{name}' symbol" else: - message = ( - "Unknown symbol '{name}' symbol in '{scope}' (line: {line}, col: {col})".format( - name=name, scope=loc.scope, line=loc.line, col=loc.column - ) - ) + message = f"Unknown symbol '{name}' symbol in '{loc.scope}' (line: {loc.line}, col: {loc.column})" super().__init__(message, loc=loc) self.name = name @@ -35,11 +31,9 @@ class GTScriptDefinitionError(GTScriptSyntaxError): def __init__(self, name, value, message=None, *, loc=None): if message is None: if loc is None: - message = "Invalid definition for '{name}' symbol".format(name=name) + message = f"Invalid definition for '{name}' symbol" else: - message = "Invalid definition for '{name}' symbol in '{scope}' (line: {line}, col: {col})".format( - name=name, scope=loc.scope, line=loc.line, col=loc.column - ) + message = f"Invalid definition for '{name}' symbol in '{loc.scope}' (line: {loc.line}, col: {loc.column})" 
super().__init__(message, loc=loc) self.name = name self.value = value @@ -49,13 +43,9 @@ class GTScriptValueError(GTScriptDefinitionError): def __init__(self, name, value, message=None, *, loc=None): if message is None: if loc is None: - message = "Invalid value for '{name}' symbol ".format(name=name) + message = f"Invalid value for '{name}' symbol " else: - message = ( - "Invalid value for '{name}' in '{scope}' (line: {line}, col: {col})".format( - name=name, scope=loc.scope, line=loc.line, col=loc.column - ) - ) + message = f"Invalid value for '{name}' in '{loc.scope}' (line: {loc.line}, col: {loc.column})" super().__init__(name, value, message, loc=loc) @@ -63,11 +53,9 @@ class GTScriptDataTypeError(GTScriptSyntaxError): def __init__(self, name, data_type, message=None, *, loc=None): if message is None: if loc is None: - message = "Invalid data type for '{name}' numeric symbol ".format(name=name) + message = f"Invalid data type for '{name}' numeric symbol " else: - message = "Invalid data type for '{name}' numeric symbol in '{scope}' (line: {line}, col: {col})".format( - name=name, scope=loc.scope, line=loc.line, col=loc.column - ) + message = f"Invalid data type for '{name}' numeric symbol in '{loc.scope}' (line: {loc.line}, col: {loc.column})" super().__init__(message, loc=loc) self.name = name self.data_type = data_type diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 0f2ce8aba3..907ecd904d 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -15,20 +15,8 @@ import textwrap import time import types -from typing import ( - Any, - Callable, - Dict, - Final, - List, - Literal, - Optional, - Sequence, - Set, - Tuple, - Type, - Union, -) +from collections.abc import Callable, Sequence +from typing import Any, Final, Literal, Optional import numpy as np @@ -54,11 +42,11 @@ class AssertionChecker(ast.NodeTransformer): """Check 
assertions and remove from the AST for further parsing.""" @classmethod - def apply(cls, func_node: ast.FunctionDef, context: Dict[str, Any], source: str): + def apply(cls, func_node: ast.FunctionDef, context: dict[str, Any], source: str): checker = cls(context, source) checker.visit(func_node) - def __init__(self, context: Dict[str, Any], source: str): + def __init__(self, context: dict[str, Any], source: str): self.context = context self.source = source @@ -104,7 +92,7 @@ class AxisIntervalParser(gt_meta.ASTPass): @classmethod def apply( cls, - node: Union[ast.Ellipsis, ast.Slice, ast.Subscript, ast.Constant], + node: ast.Ellipsis | ast.Slice | ast.Subscript | ast.Constant, axis_name: str, loc: Optional[nodes.Location] = None, ) -> nodes.AxisInterval: @@ -165,7 +153,7 @@ def slice_from_value(node: ast.Expr) -> ast.Slice: def _make_axis_bound( self, - value: Union[int, None, gtscript.AxisIndex, nodes.AxisBound, nodes.VarRef], + value: int | None | gtscript.AxisIndex | nodes.AxisBound | nodes.VarRef, endpt: nodes.LevelMarker, ) -> nodes.AxisBound: if isinstance(value, nodes.AxisBound): @@ -196,7 +184,7 @@ def _make_axis_bound( def visit_Name(self, node: ast.Name) -> nodes.VarRef: return nodes.VarRef(name=node.id, loc=nodes.Location.from_ast_node(node)) - def visit_Constant(self, node: ast.Constant) -> Union[int, gtscript.AxisIndex, None]: + def visit_Constant(self, node: ast.Constant) -> int | gtscript.AxisIndex | None: if isinstance(node.value, gtscript.AxisIndex): return node.value elif isinstance(node.value, numbers.Number): @@ -209,7 +197,7 @@ def visit_Constant(self, node: ast.Constant) -> Union[int, gtscript.AxisIndex, N loc=self.loc, ) - def visit_BinOp(self, node: ast.BinOp) -> Union[gtscript.AxisIndex, nodes.AxisBound, int]: + def visit_BinOp(self, node: ast.BinOp) -> gtscript.AxisIndex | nodes.AxisBound | int: left = self.visit(node.left) right = self.visit(node.right) @@ -355,13 +343,13 @@ class CallInliner(ast.NodeTransformer): @classmethod def apply( - 
cls, func_node: ast.FunctionDef, context: dict, *, call_stack: Optional[Set[str]] = None + cls, func_node: ast.FunctionDef, context: dict, *, call_stack: Optional[set[str]] = None ): inliner = cls(context, call_stack=call_stack or set()) inliner(func_node) return inliner.all_skip_names - def __init__(self, context: dict, call_stack: Optional[Set[str]] = None): + def __init__(self, context: dict, call_stack: Optional[set[str]] = None): self.context = context self.current_block = None self.call_stack = call_stack @@ -590,10 +578,10 @@ def visit_Expr(self, node: ast.Expr): class CompiledIfInliner(ast.NodeTransformer): @classmethod - def apply(cls, ast_object: ast.AST, context: Dict[str, Any]): + def apply(cls, ast_object: ast.AST, context: dict[str, Any]): cls(context).visit(ast_object) - def __init__(self, context: Dict[str, Any]): + def __init__(self, context: dict[str, Any]): self.context = context def visit_If(self, node: ast.If): @@ -618,8 +606,8 @@ def visit_If(self, node: ast.If): def _make_temp_decls( - descriptors: Dict[str, gtscript._FieldDescriptor], -) -> Dict[str, nodes.FieldDecl]: + descriptors: dict[str, gtscript._FieldDescriptor], +) -> dict[str, nodes.FieldDecl]: return { name: nodes.FieldDecl( name=name, @@ -633,12 +621,12 @@ def _make_temp_decls( def _make_init_computations( - temp_decls: Dict[str, nodes.FieldDecl], init_values: Dict[str, Any], func_node: ast.AST -) -> List[nodes.ComputationBlock]: + temp_decls: dict[str, nodes.FieldDecl], init_values: dict[str, Any], func_node: ast.AST +) -> list[nodes.ComputationBlock]: if not temp_decls: return [] - stmts: List[nodes.Assign] = [] + stmts: list[nodes.Assign] = [] for name in init_values: decl = temp_decls[name] stmts.append(decl) @@ -673,8 +661,8 @@ def _make_init_computations( ] -def _find_accesses_with_offsets(node: nodes.Node) -> Set[str]: - names: Set[str] = set() +def _find_accesses_with_offsets(node: nodes.Node) -> set[str]: + names: set[str] = set() class FindRefs(node_util.IRNodeVisitor): 
def visit_FieldRef(self, node: nodes.FieldRef) -> None: @@ -720,8 +708,8 @@ def __init__( *, domain: nodes.Domain, options: gt_definitions.BuildOptions, - temp_decls: Optional[Dict[str, nodes.FieldDecl]] = None, - dtypes: Optional[Dict[Type, Type]] = None, + temp_decls: Optional[dict[str, nodes.FieldDecl]] = None, + dtypes: Optional[dict[type, type]] = None, ): fields = fields or {} parameters = parameters or {} @@ -743,7 +731,7 @@ def __init__( self.iteration_order = None self.decls_stack = [] self.parsing_horizontal_region = False - self.written_vars: Set[str] = set() + self.written_vars: set[str] = set() self.dtypes = dtypes self.python_symbol_to_ir_op = { "abs": nodes.NativeFunction.ABS, @@ -831,7 +819,7 @@ def _is_local_symbol(self, name: str): def _is_known(self, name: str): return self._is_field(name) or self._is_parameter(name) or self._is_local_symbol(name) - def _are_blocks_sorted(self, compute_blocks: List[nodes.ComputationBlock]): + def _are_blocks_sorted(self, compute_blocks: list[nodes.ComputationBlock]): def sort_blocks_key(comp_block): start = comp_block.interval.start assert isinstance(start.level, nodes.LevelMarker) @@ -861,7 +849,7 @@ def sort_blocks_key(comp_block): # if sorting didn't change anything it was already sorted return compute_blocks == compute_blocks_sorted - def _parse_region_intervals(self, node: ast.Tuple) -> Dict[str, nodes.AxisInterval]: + def _parse_region_intervals(self, node: ast.Tuple) -> dict[str, nodes.AxisInterval]: # Since Python 3.9: directly returns a Tuple for region[0, 1] list_of_exprs = [axis_node for axis_node in node.elts] axes_names = [axis.name for axis in self.domain.parallel_axes] @@ -872,7 +860,7 @@ def _parse_region_intervals(self, node: ast.Tuple) -> Dict[str, nodes.AxisInterv def _visit_with_horizontal( self, node: ast.withitem, loc: nodes.Location - ) -> List[Dict[str, nodes.AxisInterval]]: + ) -> list[dict[str, nodes.AxisInterval]]: syntax_error = GTScriptSyntaxError( f"Invalid 'with' statement at line 
{loc.line} (column {loc.column})", loc=loc ) @@ -885,7 +873,7 @@ def _visit_with_horizontal( return [self._parse_region_intervals(arg.slice) for arg in call_args] - def _are_intervals_nonoverlapping(self, compute_blocks: List[nodes.ComputationBlock]): + def _are_intervals_nonoverlapping(self, compute_blocks: list[nodes.ComputationBlock]): for i, block in enumerate(compute_blocks[1:]): other = compute_blocks[i] if not block.interval.disjoint_from(other.interval): @@ -1021,7 +1009,7 @@ def visit_Raise(self): # -- Literal nodes -- def visit_Constant( self, node: ast.Constant - ) -> Union[nodes.ScalarLiteral, nodes.BuiltinLiteral, nodes.Cast]: + ) -> nodes.ScalarLiteral | nodes.BuiltinLiteral | nodes.Cast: value = node.value if value is None: return nodes.BuiltinLiteral(value=nodes.Builtin.from_value(value)) @@ -1094,8 +1082,8 @@ def visit_Index(self, node: ast.Index): return index def _eval_new_spatial_index( - self, index_nodes: Sequence[nodes.Expr], field_axes: Optional[Set[Literal["I", "J", "K"]]] - ) -> List[int]: + self, index_nodes: Sequence[nodes.Expr], field_axes: Optional[set[Literal["I", "J", "K"]]] + ) -> list[int]: index_dict = {} all_spatial_axes = ("I", "J", "K") last_index = -1 @@ -1138,8 +1126,8 @@ def _eval_new_spatial_index( return [index_dict.get(axis, 0) for axis in ("I", "J", "K") if axis in field_axes] def _eval_index( - self, node: ast.Subscript, field_axes: Optional[Set[Literal["I", "J", "K"]]] = None - ) -> Optional[List[int]]: + self, node: ast.Subscript, field_axes: Optional[set[Literal["I", "J", "K"]]] = None + ) -> Optional[list[int]]: tuple_or_expr = node.slice.value if isinstance(node.slice, ast.Index) else node.slice index_nodes = gt_utils.listify( tuple_or_expr.elts if isinstance(tuple_or_expr, ast.Tuple) else tuple_or_expr @@ -1240,7 +1228,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp): op = self.visit(node.op) arg = self.visit(node.operand) if isinstance(arg, numbers.Number): - result = eval("{op}{arg}".format(op=op.python_symbol, 
arg=arg)) + result = eval(f"{op.python_symbol}{arg}") else: result = nodes.UnaryOpExpr(op=op, arg=arg, loc=nodes.Location.from_ast_node(node)) @@ -1433,8 +1421,8 @@ def visit_Call(self, node: ast.Call): # -- Statement nodes -- def _parse_assign_target( - self, target_node: Union[ast.Subscript, ast.Name] - ) -> Tuple[str, Optional[List[int]], Optional[List[int]]]: + self, target_node: ast.Subscript | ast.Name + ) -> tuple[str, Optional[list[int]], Optional[list[int]]]: invalid_target = GTScriptSyntaxError( message="Invalid target in assignment.", loc=target_node ) @@ -1470,8 +1458,8 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> list: def _resolve_assign( self, - node: Union[ast.AnnAssign, ast.Assign], - targets: List[Any], + node: ast.AnnAssign | ast.Assign, + targets: list[Any], target_annotation: Optional[Any] = None, ) -> list: result = [] @@ -1659,7 +1647,7 @@ def visit_With(self, node: ast.With): # Mixing nested `with` blocks with stmts not allowed raise syntax_error - def visit_FunctionDef(self, node: ast.FunctionDef) -> List[nodes.ComputationBlock]: + def visit_FunctionDef(self, node: ast.FunctionDef) -> list[nodes.ComputationBlock]: blocks = [] for stmt in filter(lambda s: not isinstance(s, ast.AnnAssign), node.body): blocks.extend(self.visit(stmt)) @@ -1736,7 +1724,7 @@ def __init__(self, definition, *, options, externals=None, dtypes=None): def __str__(self) -> str: result = " {\n" - result += "\n".join("\t{}: {}".format(name, getattr(self, name)) for name in vars(self)) + result += "\n".join(f"\t{name}: {getattr(self, name)}" for name in vars(self)) result += "\n}" return result @@ -1766,7 +1754,7 @@ def annotate_definition( api_signature = [] api_annotations = [] - qualified_name = "{}.{}".format(definition.__module__, definition.__name__) + qualified_name = f"{definition.__module__}.{definition.__name__}" sig = inspect.signature(definition) for param in sig.parameters.values(): if param.kind == inspect.Parameter.VAR_POSITIONAL: @@ -1827,8 
+1815,8 @@ def annotate_definition( ) # Gather temporary - temp_annotations: Dict[str, gtscript._FieldDescriptor] = {} - temp_init_values: Dict[str, numbers.Number] = {} + temp_annotations: dict[str, gtscript._FieldDescriptor] = {} + temp_init_values: dict[str, numbers.Number] = {} frontend_types_to_native_types = nodes.frontend_type_to_native_type( options.literal_int_precision @@ -1916,7 +1904,7 @@ def collect_external_symbols(definition): wrong_imports.append(key) if wrong_imports: - raise GTScriptSyntaxError("Invalid 'import' statements ({})".format(wrong_imports)) + raise GTScriptSyntaxError(f"Invalid 'import' statements ({wrong_imports})") context, unbound = gt_meta.get_closure( definition, included_nonlocals=True, include_builtins=False @@ -1971,7 +1959,7 @@ def eval_external(name: str, context: dict, loc=None): raise GTScriptDefinitionError( name=name, value="", - message="Missing or invalid value for external symbol {name}".format(name=name), + message=f"Missing or invalid value for external symbol {name}", loc=loc, ) from e return value diff --git a/src/gt4py/cartesian/frontend/node_util.py b/src/gt4py/cartesian/frontend/node_util.py index 52496b12a4..15e6b48bde 100644 --- a/src/gt4py/cartesian/frontend/node_util.py +++ b/src/gt4py/cartesian/frontend/node_util.py @@ -8,7 +8,8 @@ import collections import operator -from typing import Generator, Optional, Type +from collections.abc import Generator +from typing import Optional import boltons.typeutils @@ -111,7 +112,7 @@ def generic_visit(self, node: Node, **kwargs): return node -def iter_nodes_of_type(root_node: Node, node_type: Type) -> Generator[Node, None, None]: +def iter_nodes_of_type(root_node: Node, node_type: type) -> Generator[Node, None, None]: """Yield an iterator over the nodes of node_type inside root_node in DFS order.""" def recurse(node: Node) -> Generator[Node, None, None]: diff --git a/src/gt4py/cartesian/frontend/nodes.py b/src/gt4py/cartesian/frontend/nodes.py index 
b71da1b8e3..420b80f0b8 100644 --- a/src/gt4py/cartesian/frontend/nodes.py +++ b/src/gt4py/cartesian/frontend/nodes.py @@ -138,7 +138,8 @@ import enum import operator import sys -from typing import List, Optional, Sequence +from collections.abc import Sequence +from typing import Optional import numpy as np @@ -364,7 +365,7 @@ class FieldRef(Ref): @classmethod def at_center( - cls, name: str, axes: Sequence[str], data_index: Optional[List[int]] = None, loc=None + cls, name: str, axes: Sequence[str], data_index: Optional[list[int]] = None, loc=None ): return cls( name=name, offset={axis: 0 for axis in axes}, data_index=data_index or [], loc=loc diff --git a/src/gt4py/cartesian/gt_cache_manager.py b/src/gt4py/cartesian/gt_cache_manager.py index 5f3c87bb50..0ec8eab244 100644 --- a/src/gt4py/cartesian/gt_cache_manager.py +++ b/src/gt4py/cartesian/gt_cache_manager.py @@ -12,7 +12,8 @@ import os import pathlib import shutil -from typing import List, Optional, Sequence +from collections.abc import Sequence +from typing import Optional from gt4py.cartesian import config as gt_config @@ -29,7 +30,7 @@ def _get_cache_name() -> str: return result -def find_caches(root: Optional[str] = None, cache_name: Optional[str] = None) -> List[pathlib.Path]: +def find_caches(root: Optional[str] = None, cache_name: Optional[str] = None) -> list[pathlib.Path]: root_path = pathlib.Path(root or _get_root()) cache_name = cache_name or _get_cache_name() diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index 0ecca2b8bc..be19a79cce 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -11,20 +11,8 @@ import enum import functools import typing -from typing import ( - Any, - ClassVar, - Dict, - Final, - Generic, - List, - Mapping, - Optional, - Tuple, - Type, - TypeVar, - Union, -) +from collections.abc import Mapping +from typing import Any, ClassVar, Final, Generic, Optional, TypeVar import numpy as np @@ -186,7 +174,7 @@ class 
NativeFunction(eve.StrEnum): FLOAT32 = "float32" FLOAT64 = "float64" - IR_OP_TO_NUM_ARGS: ClassVar[Dict[NativeFunction, int]] + IR_OP_TO_NUM_ARGS: ClassVar[dict[NativeFunction, int]] @property def arity(self) -> int: @@ -273,15 +261,15 @@ class Stmt(LocNode): def verify_condition_is_boolean(parent_node_cls: datamodels.DataModel, cond: Expr) -> None: if cond.dtype and cond.dtype is not DataType.BOOL: - raise ValueError("Condition in `{}` must be boolean.".format(type(parent_node_cls))) + raise ValueError(f"Condition in `{type(parent_node_cls)}` must be boolean.") def verify_and_get_common_dtype( - node_cls: Type[datamodels.DataModel], exprs: List[Expr], *, strict: bool = True + node_cls: type[datamodels.DataModel], exprs: list[Expr], *, strict: bool = True ) -> Optional[DataType]: assert len(exprs) > 0 if all(e.dtype is not DataType.AUTO for e in exprs): - dtypes: List[DataType] = [e.dtype for e in exprs] # guaranteed to be not None + dtypes: list[DataType] = [e.dtype for e in exprs] # guaranteed to be not None dtype = dtypes[0] if strict: if all(dt == dtype for dt in dtypes): @@ -308,7 +296,7 @@ def compute_kind(*values: Expr) -> ExprKind: class Literal(eve.Node): # TODO(havogt): reconsider if `str` is a good representation for value, # maybe it should be Union[float,int,str] etc? 
- value: Union[BuiltInLiteral, str] + value: BuiltInLiteral | str dtype: DataType kind: ExprKind = ExprKind.SCALAR @@ -328,14 +316,14 @@ class CartesianOffset(eve.Node): def zero(cls) -> CartesianOffset: return cls(i=0, j=0, k=0) - def to_dict(self) -> Dict[str, int]: + def to_dict(self) -> dict[str, int]: return {"i": self.i, "j": self.j, "k": self.k} class VariableKOffset(eve.GenericNode, Generic[ExprT]): k: ExprT - def to_dict(self) -> Dict[str, Optional[int]]: + def to_dict(self) -> dict[str, Optional[int]]: return {"i": 0, "j": 0, "k": None} @datamodels.validator("k") @@ -352,8 +340,8 @@ class ScalarAccess(LocNode): class FieldAccess(eve.GenericNode, Generic[ExprT, VariableKOffsetT]): name: eve.Coerced[eve.SymbolRef] - offset: Union[CartesianOffset, VariableKOffsetT] - data_index: List[ExprT] = eve.field(default_factory=list) + offset: CartesianOffset | VariableKOffsetT + data_index: list[ExprT] = eve.field(default_factory=list) kind: ExprKind = ExprKind.FIELD @classmethod @@ -362,7 +350,7 @@ def centered(cls, *, name: str, loc: Optional[eve.SourceLocation] = None) -> Fie @datamodels.validator("data_index") def data_index_exprs_are_int(self, attribute: datamodels.Attribute, value: Any) -> None: - value = typing.cast(List[Expr], value) + value = typing.cast(list[Expr], value) if value and any( index.dtype is not DataType.AUTO and not index.dtype.isinteger() for index in value ): @@ -370,7 +358,7 @@ def data_index_exprs_are_int(self, attribute: datamodels.Attribute, value: Any) class BlockStmt(eve.GenericNode, eve.SymbolTableTrait, Generic[StmtT]): - body: List[StmtT] + body: list[StmtT] class IfStmt(eve.GenericNode, Generic[StmtT, ExprT]): @@ -397,7 +385,7 @@ class While(eve.GenericNode, Generic[StmtT, ExprT]): """ cond: ExprT - body: List[StmtT] + body: list[StmtT] @datamodels.validator("cond") def condition_is_boolean(self, attribute: datamodels.Attribute, value: Expr) -> None: @@ -414,7 +402,7 @@ def _make_root_validator(impl: datamodels.RootValidator) -> 
datamodels.RootValid def assign_stmt_dtype_validation(*, strict: bool) -> datamodels.RootValidator: - def _impl(cls: Type[datamodels.DataModel], instance: datamodels.DataModel) -> None: + def _impl(cls: type[datamodels.DataModel], instance: datamodels.DataModel) -> None: assert isinstance(instance, AssignStmt) verify_and_get_common_dtype(cls, [instance.left, instance.right], strict=strict) @@ -433,17 +421,17 @@ class UnaryOp(eve.GenericNode, Generic[ExprT]): @datamodels.root_validator @classmethod - def dtype_propagation(cls: Type[UnaryOp], instance: UnaryOp) -> None: + def dtype_propagation(cls: type[UnaryOp], instance: UnaryOp) -> None: instance.dtype = instance.expr.dtype # type: ignore[attr-defined] @datamodels.root_validator @classmethod - def kind_propagation(cls: Type[UnaryOp], instance: UnaryOp) -> None: + def kind_propagation(cls: type[UnaryOp], instance: UnaryOp) -> None: instance.kind = instance.expr.kind # type: ignore[attr-defined] @datamodels.root_validator @classmethod - def op_to_dtype_check(cls: Type[UnaryOp], instance: UnaryOp) -> None: + def op_to_dtype_check(cls: type[UnaryOp], instance: UnaryOp) -> None: if instance.expr.dtype: if instance.op == UnaryOperator.NOT: if not instance.expr.dtype == DataType.BOOL: @@ -464,18 +452,18 @@ class BinaryOp(eve.GenericNode, Generic[ExprT]): """ # consider parametrizing on op - op: Union[ArithmeticOperator, ComparisonOperator, LogicalOperator] + op: ArithmeticOperator | ComparisonOperator | LogicalOperator left: ExprT right: ExprT @datamodels.root_validator @classmethod - def kind_propagation(cls: Type[BinaryOp], instance: BinaryOp) -> None: + def kind_propagation(cls: type[BinaryOp], instance: BinaryOp) -> None: instance.kind = compute_kind(instance.left, instance.right) # type: ignore[attr-defined] def binary_op_dtype_propagation(*, strict: bool) -> datamodels.RootValidator: - def _impl(cls: Type[BinaryOp], instance: BinaryOp) -> None: + def _impl(cls: type[BinaryOp], instance: BinaryOp) -> None: 
common_dtype = verify_and_get_common_dtype( cls, [instance.left, instance.right], strict=strict ) @@ -518,12 +506,12 @@ def condition_is_boolean(self, attribute: datamodels.Attribute, value: Expr) -> @datamodels.root_validator @classmethod - def kind_propagation(cls: Type[TernaryOp], instance: TernaryOp) -> None: + def kind_propagation(cls: type[TernaryOp], instance: TernaryOp) -> None: instance.kind = compute_kind(instance.true_expr, instance.false_expr) # type: ignore[attr-defined] def ternary_op_dtype_propagation(*, strict: bool) -> datamodels.RootValidator: - def _impl(cls: Type[TernaryOp], instance: TernaryOp) -> None: + def _impl(cls: type[TernaryOp], instance: TernaryOp) -> None: common_dtype = verify_and_get_common_dtype( cls, [instance.true_expr, instance.false_expr], strict=strict ) @@ -539,17 +527,17 @@ class Cast(eve.GenericNode, Generic[ExprT]): @datamodels.root_validator @classmethod - def kind_propagation(cls: Type[Cast], instance: Cast) -> None: + def kind_propagation(cls: type[Cast], instance: Cast) -> None: instance.kind = compute_kind(instance.expr) # type: ignore[attr-defined] class NativeFuncCall(eve.GenericNode, Generic[ExprT]): func: NativeFunction - args: List[ExprT] + args: list[ExprT] @datamodels.root_validator @classmethod - def arity_check(cls: Type[NativeFuncCall], instance: NativeFuncCall) -> None: + def arity_check(cls: type[NativeFuncCall], instance: NativeFuncCall) -> None: if instance.func.arity != len(instance.args): raise ValueError( f"{instance.func} accepts {instance.func.arity} arguments, {len(instance.args)} where passed." 
@@ -557,7 +545,7 @@ def arity_check(cls: Type[NativeFuncCall], instance: NativeFuncCall) -> None: @datamodels.root_validator @classmethod - def kind_propagation(cls: Type[NativeFuncCall], instance: NativeFuncCall) -> None: + def kind_propagation(cls: type[NativeFuncCall], instance: NativeFuncCall) -> None: instance.kind = compute_kind(*instance.args) # type: ignore[attr-defined] @@ -573,7 +561,7 @@ def _precision_to_datatype(func: NativeFunction) -> DataType: return DataType.FLOAT64 raise NotImplementedError(f"Found unknown precision specification {func}") - def _impl(cls: Type[NativeFuncCall], instance: NativeFuncCall) -> None: + def _impl(cls: type[NativeFuncCall], instance: NativeFuncCall) -> None: if instance.func in (NativeFunction.ISFINITE, NativeFunction.ISINF, NativeFunction.ISNAN): instance.dtype = DataType.BOOL # type: ignore[attr-defined] elif instance.func in ( @@ -593,8 +581,8 @@ def _impl(cls: Type[NativeFuncCall], instance: NativeFuncCall) -> None: def validate_dtype_is_set() -> datamodels.RootValidator: - def _impl(cls: Type[ExprT], instance: ExprT) -> None: - dtype_nodes: List[ExprT] = [] + def _impl(cls: type[ExprT], instance: ExprT) -> None: + dtype_nodes: list[ExprT] = [] for v in utils.flatten(datamodels.astuple(instance)): if isinstance(v, eve.Node): dtype_nodes.extend(v.walk_values().if_hasattr("dtype")) @@ -605,18 +593,18 @@ def _impl(cls: Type[ExprT], instance: ExprT) -> None: nodes_without_dtype.append(node) if len(nodes_without_dtype) > 0: - raise ValueError("Nodes without dtype detected {}".format(nodes_without_dtype)) + raise ValueError(f"Nodes without dtype detected {nodes_without_dtype}") return _make_root_validator(_impl) class _LvalueDimsValidator(eve.VisitorWithSymbolTableTrait): - def __init__(self, vertical_loop_type: Type[eve.Node], decl_type: Type[eve.Node]) -> None: + def __init__(self, vertical_loop_type: type[eve.Node], decl_type: type[eve.Node]) -> None: if vertical_loop_type.__annotations__.get("loop_order") is not 
LoopOrder: raise ValueError( f"Vertical loop type {vertical_loop_type} has no `loop_order` attribute" ) - if not decl_type.__annotations__.get("dimensions") == Tuple[bool, bool, bool]: + if not decl_type.__annotations__.get("dimensions") == tuple[bool, bool, bool]: raise ValueError( f"Field decl type {decl_type} must have a `dimensions` " "attribute of type `Tuple[bool, bool, bool]`." @@ -632,11 +620,11 @@ def visit_Node( self.generic_visit(node, loop_order=loop_order, **kwargs) def visit_AssignStmt( - self, node: AssignStmt, *, loop_order: LoopOrder, symtable: Dict[str, Any], **kwargs: Any + self, node: AssignStmt, *, loop_order: LoopOrder, symtable: dict[str, Any], **kwargs: Any ) -> None: decl = symtable.get(node.left.name, None) if decl is None: - raise ValueError("Symbol {} not found.".format(node.left.name)) + raise ValueError(f"Symbol {node.left.name} not found.") if not isinstance(decl, self.decl_type): return None @@ -649,7 +637,7 @@ def visit_AssignStmt( ) return None - def _allowed_flags(self, loop_order: LoopOrder) -> List[Tuple[bool, bool, bool]]: + def _allowed_flags(self, loop_order: LoopOrder) -> list[tuple[bool, bool, bool]]: allowed_flags = [(True, True, True)] # ijk always allowed if loop_order is not LoopOrder.PARALLEL: allowed_flags.append((True, True, False)) # ij only allowed in FORWARD and BACKWARD @@ -659,7 +647,7 @@ def _allowed_flags(self, loop_order: LoopOrder) -> List[Tuple[bool, bool, bool]] # TODO(ricoh): consider making gtir.Decl & oir.Decl common and / or adding a VerticalLoop baseclass # TODO(ricoh): in common instead of passing type arguments def validate_lvalue_dims( - vertical_loop_type: Type[eve.Node], decl_type: Type[eve.Node] + vertical_loop_type: type[eve.Node], decl_type: type[eve.Node] ) -> datamodels.RootValidator: """ Validate lvalue dimensions using the root node symbol table. @@ -687,7 +675,7 @@ def validate_lvalue_dims( `Tuple[bool, bool, bool]` in an attribute named `dimensions`. 
""" - def _impl(cls: Type[datamodels.DataModel], instance: datamodels.DataModel) -> None: + def _impl(cls: type[datamodels.DataModel], instance: datamodels.DataModel) -> None: _LvalueDimsValidator(vertical_loop_type, decl_type).visit(instance) return _make_root_validator(_impl) @@ -772,7 +760,7 @@ def at_endpt( @datamodels.root_validator @classmethod - def check_start_before_end(cls: Type[HorizontalInterval], instance: HorizontalInterval) -> None: + def check_start_before_end(cls: type[HorizontalInterval], instance: HorizontalInterval) -> None: if instance.start and instance.end and not (instance.start <= instance.end): raise ValueError( f"End ({instance.end}) is not after or equal to start ({instance.start})" @@ -816,7 +804,7 @@ class HorizontalMask(LocNode): j: HorizontalInterval @property - def intervals(self) -> Tuple[HorizontalInterval, HorizontalInterval]: + def intervals(self) -> tuple[HorizontalInterval, HorizontalInterval]: return (self.i, self.j) @@ -824,7 +812,7 @@ class HorizontalRestriction(eve.GenericNode, Generic[StmtT]): """A specialization of the horizontal space.""" mask: HorizontalMask - body: List[StmtT] + body: list[StmtT] def data_type_to_typestr(dtype: DataType) -> str: @@ -849,21 +837,17 @@ def data_type_to_typestr(dtype: DataType) -> str: # different operators use the same key: UnaryOperator.POS == BinaryOperator.ADD OP_TO_UFUNC_NAME: Final[ Mapping[ - Union[ - Type[UnaryOperator], - Type[ArithmeticOperator], - Type[ComparisonOperator], - Type[LogicalOperator], - Type[NativeFunction], - ], + type[UnaryOperator] + | type[ArithmeticOperator] + | type[ComparisonOperator] + | type[LogicalOperator] + | type[NativeFunction], Mapping[ - Union[ - UnaryOperator, - ArithmeticOperator, - ComparisonOperator, - LogicalOperator, - NativeFunction, - ], + UnaryOperator + | ArithmeticOperator + | ComparisonOperator + | LogicalOperator + | NativeFunction, str, ], ] @@ -929,9 +913,7 @@ def data_type_to_typestr(dtype: DataType) -> str: def op_to_ufunc( - op: 
Union[ - UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction - ], + op: UnaryOperator | ArithmeticOperator | ComparisonOperator | LogicalOperator | NativeFunction, ) -> np.ufunc: if not isinstance( op, (UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction) @@ -942,7 +924,7 @@ def op_to_ufunc( return getattr(ufuncs, OP_TO_UFUNC_NAME[type(op)][op]) -@functools.lru_cache(maxsize=None) +@functools.cache def typestr_to_data_type(typestr: str) -> DataType: if not isinstance(typestr, str) or len(typestr) < 3 or not typestr[2:].isnumeric(): return DataType.INVALID diff --git a/src/gt4py/cartesian/gtc/cuir/cuir.py b/src/gt4py/cartesian/gtc/cuir/cuir.py index fb6d28d071..c8b2824697 100644 --- a/src/gt4py/cartesian/gtc/cuir/cuir.py +++ b/src/gt4py/cartesian/gtc/cuir/cuir.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union from gt4py import eve from gt4py.cartesian.gtc import common @@ -57,7 +57,7 @@ def zero_k_offset(self, attribute: datamodels.Attribute, value: CartesianOffset) raise ValueError("No k-offset allowed") @datamodels.validator("data_index") - def no_additional_dimensions(self, attribute: datamodels.Attribute, value: List[int]) -> None: + def no_additional_dimensions(self, attribute: datamodels.Attribute, value: list[int]) -> None: if len(value) > 0: raise ValueError("IJ-cached higher-dimensional fields are not supported") @@ -67,7 +67,7 @@ class KCacheAccess(common.FieldAccess[Expr, VariableKOffset], Expr): @datamodels.validator("offset") def has_no_ij_offset( - self, attribute: datamodels.Attribute, value: Union[CartesianOffset, VariableKOffset] + self, attribute: datamodels.Attribute, value: CartesianOffset | VariableKOffset ) -> None: offsets = value.to_dict() if not offsets["i"] == offsets["j"] == 0: @@ -75,13 +75,13 @@ def has_no_ij_offset( @datamodels.validator("offset") def 
not_variable_offset( - self, attribute: datamodels.Attribute, value: Union[CartesianOffset, VariableKOffset] + self, attribute: datamodels.Attribute, value: CartesianOffset | VariableKOffset ) -> None: if isinstance(value, VariableKOffset): raise ValueError("Cannot k-cache a variable k offset") @datamodels.validator("data_index") - def no_additional_dimensions(self, attribute: datamodels.Attribute, value: List[int]) -> None: + def no_additional_dimensions(self, attribute: datamodels.Attribute, value: list[int]) -> None: if len(value) > 0: raise ValueError("K-cached higher-dimensional fields are not supported") @@ -94,7 +94,7 @@ class AssignStmt( class MaskStmt(Stmt): mask: Expr - body: List[Stmt] + body: list[Stmt] class While(common.While[Stmt, Expr], Stmt): @@ -132,8 +132,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: class FieldDecl(Decl): - dimensions: Tuple[bool, bool, bool] - data_dims: Tuple[int, ...] = eve.field(default_factory=tuple) + dimensions: tuple[bool, bool, bool] + data_dims: tuple[int, ...] = eve.field(default_factory=tuple) class ScalarDecl(Decl): @@ -145,7 +145,7 @@ class LocalScalar(Decl): class Temporary(Decl): - data_dims: Tuple[int, ...] = eve.field(default_factory=tuple) + data_dims: tuple[int, ...] 
= eve.field(default_factory=tuple) class Positional(Decl): @@ -154,15 +154,15 @@ class Positional(Decl): class IJExtent(LocNode): - i: Tuple[int, int] - j: Tuple[int, int] + i: tuple[int, int] + j: tuple[int, int] @classmethod def zero(cls) -> IJExtent: return cls(i=(0, 0), j=(0, 0)) @classmethod - def from_offset(cls, offset: Union[CartesianOffset, VariableKOffset]) -> IJExtent: + def from_offset(cls, offset: CartesianOffset | VariableKOffset) -> IJExtent: if isinstance(offset, VariableKOffset): return cls(i=(0, 0), j=(0, 0)) return cls(i=(offset.i, offset.i), j=(offset.j, offset.j)) @@ -181,14 +181,14 @@ def __add__(self, other: IJExtent) -> IJExtent: class KExtent(LocNode): - k: Tuple[int, int] + k: tuple[int, int] @classmethod def zero(cls) -> KExtent: return cls(k=(0, 0)) @classmethod - def from_offset(cls, offset: Union[CartesianOffset, VariableKOffset]) -> KExtent: + def from_offset(cls, offset: CartesianOffset | VariableKOffset) -> KExtent: MAX_OFFSET = 1000 if isinstance(offset, VariableKOffset): return cls(k=(-MAX_OFFSET, MAX_OFFSET)) @@ -207,29 +207,29 @@ class KCacheDecl(Decl): class HorizontalExecution(LocNode, eve.SymbolTableTrait): - body: List[Stmt] - declarations: List[LocalScalar] + body: list[Stmt] + declarations: list[LocalScalar] extent: Optional[IJExtent] = None class VerticalLoopSection(LocNode): start: AxisBound end: AxisBound - horizontal_executions: List[HorizontalExecution] + horizontal_executions: list[HorizontalExecution] class VerticalLoop(LocNode): loop_order: LoopOrder - sections: List[VerticalLoopSection] - ij_caches: List[IJCacheDecl] - k_caches: List[KCacheDecl] + sections: list[VerticalLoopSection] + ij_caches: list[IJCacheDecl] + k_caches: list[KCacheDecl] class Kernel(LocNode): - vertical_loops: List[VerticalLoop] + vertical_loops: list[VerticalLoop] @datamodels.validator("vertical_loops") - def check_loops(self, attribute: datamodels.Attribute, value: List[VerticalLoop]) -> None: + def check_loops(self, attribute: 
datamodels.Attribute, value: list[VerticalLoop]) -> None: if len(value) < 1: raise ValueError("At least one loop required") parallel = [loop.loop_order == LoopOrder.PARALLEL for loop in value] @@ -237,7 +237,7 @@ def check_loops(self, attribute: datamodels.Attribute, value: List[VerticalLoop] raise ValueError("Mixed k-parallelism in kernel") -def axis_size_decls() -> List[ScalarDecl]: +def axis_size_decls() -> list[ScalarDecl]: return [ ScalarDecl(name="i_size", dtype=common.DataType.INT32), ScalarDecl(name="j_size", dtype=common.DataType.INT32), @@ -247,8 +247,8 @@ def axis_size_decls() -> List[ScalarDecl]: class Program(LocNode, eve.ValidatedSymbolTableTrait): name: str - params: List[Decl] - positionals: List[Positional] - temporaries: List[Temporary] - kernels: List[Kernel] - axis_sizes: List[ScalarDecl] = eve.field(default_factory=axis_size_decls) + params: list[Decl] + positionals: list[Positional] + temporaries: list[Temporary] + kernels: list[Kernel] + axis_sizes: list[ScalarDecl] = eve.field(default_factory=axis_size_decls) diff --git a/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py b/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py index 47059984df..0867e2bd0d 100644 --- a/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py +++ b/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Collection, Dict, Final, List, Set, Union +from collections.abc import Collection +from typing import Any, Final import numpy as np @@ -55,7 +56,7 @@ def visit_FieldAccess(self, node: cuir.FieldAccess, **kwargs: Any): if isinstance(node, cuir.KCacheAccess): return self.generic_visit(node, **kwargs) - symtable: Dict[str, cuir.Decl] = kwargs["symtable"] + symtable: dict[str, cuir.Decl] = kwargs["symtable"] def maybe_const(s): try: @@ -79,7 +80,7 @@ def maybe_const(s): return f"{name}({offset}{data_index_str})" def visit_IJCacheAccess( - self, node: cuir.IJCacheAccess, symtable: Dict[str, Any], **kwargs: Any + self, node: cuir.IJCacheAccess, symtable: dict[str, Any], **kwargs: Any ) -> str: decl = symtable[node.name] assert isinstance(decl, cuir.IJCacheDecl) @@ -291,7 +292,7 @@ def k_cache_var(name: str, offset: int) -> str: return name + (f"p{offset}" if offset >= 0 else f"m{-offset}") @classmethod - def k_cache_vars(cls, k_cache: cuir.KCacheDecl) -> List[str]: + def k_cache_vars(cls, k_cache: cuir.KCacheDecl) -> list[str]: assert k_cache.extent return [ cls.k_cache_var(k_cache.name, offset) @@ -299,8 +300,8 @@ def k_cache_vars(cls, k_cache: cuir.KCacheDecl) -> List[str]: ] def visit_VerticalLoop( - self, node: cuir.VerticalLoop, *, symtable: Dict[str, Any], **kwargs: Any - ) -> Union[str, Collection[str]]: + self, node: cuir.VerticalLoop, *, symtable: dict[str, Any], **kwargs: Any + ) -> str | Collection[str]: fields = { name: data_dims for name, data_dims in node.walk_values() @@ -420,7 +421,7 @@ def visit_VerticalLoop( """ ) - def visit_Program(self, node: cuir.Program, **kwargs: Any) -> Union[str, Collection[str]]: + def visit_Program(self, node: cuir.Program, **kwargs: Any) -> str | Collection[str]: def loop_start(vertical_loop: cuir.VerticalLoop) -> str: if vertical_loop.loop_order == cuir.LoopOrder.FORWARD: return self.visit(vertical_loop.sections[0].start, **kwargs) @@ -428,7 +429,7 @@ def 
loop_start(vertical_loop: cuir.VerticalLoop) -> str: return self.visit(vertical_loop.sections[0].end, **kwargs) + " - 1" return "0" - def loop_fields(vertical_loop: cuir.VerticalLoop) -> Set[str]: + def loop_fields(vertical_loop: cuir.VerticalLoop) -> set[str]: return ( vertical_loop.walk_values().if_isinstance(cuir.FieldAccess).getattr("name").to_set() ) diff --git a/src/gt4py/cartesian/gtc/cuir/extent_analysis.py b/src/gt4py/cartesian/gtc/cuir/extent_analysis.py index 3fd365fb76..8456a65780 100644 --- a/src/gt4py/cartesian/gtc/cuir/extent_analysis.py +++ b/src/gt4py/cartesian/gtc/cuir/extent_analysis.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from collections import defaultdict -from typing import Any, Dict +from typing import Any from gt4py import eve from gt4py.cartesian.gtc.cuir import cuir @@ -15,17 +15,17 @@ class CacheExtents(eve.NodeTranslator): def visit_IJCacheDecl( - self, node: cuir.IJCacheDecl, *, ij_extents: Dict[str, cuir.KExtent], **kwargs: Any + self, node: cuir.IJCacheDecl, *, ij_extents: dict[str, cuir.KExtent], **kwargs: Any ) -> cuir.IJCacheDecl: return cuir.IJCacheDecl(name=node.name, dtype=node.dtype, extent=ij_extents[node.name]) def visit_KCacheDecl( - self, node: cuir.KCacheDecl, *, k_extents: Dict[str, cuir.KExtent], **kwargs: Any + self, node: cuir.KCacheDecl, *, k_extents: dict[str, cuir.KExtent], **kwargs: Any ) -> cuir.KCacheDecl: return cuir.KCacheDecl(name=node.name, dtype=node.dtype, extent=k_extents[node.name]) def visit_VerticalLoop(self, node: cuir.VerticalLoop) -> cuir.VerticalLoop: - ij_extents: Dict[str, cuir.IJExtent] = defaultdict(cuir.IJExtent.zero) + ij_extents: dict[str, cuir.IJExtent] = defaultdict(cuir.IJExtent.zero) for horizontal_execution in node.walk_values().if_isinstance(cuir.HorizontalExecution): ij_access_extents = ( horizontal_execution.walk_values() diff --git a/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py b/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py index fa95ec8cba..62ae177677 100644 --- 
a/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py +++ b/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py @@ -10,7 +10,7 @@ import functools from dataclasses import dataclass, field -from typing import Any, Dict, List, Set, Union, cast +from typing import Any, cast from typing_extensions import Protocol @@ -46,8 +46,8 @@ class OIRToCUIR(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): @dataclass class Context: new_symbol_name: SymbolNameCreator - accessed_fields: Set[str] = field(default_factory=set) - positionals: Dict[int, cuir.Positional] = field(default_factory=dict) + accessed_fields: set[str] = field(default_factory=set) + positionals: dict[int, cuir.Positional] = field(default_factory=dict) def make_positional(self, axis: int) -> cuir.FieldAccess: axis_name = ["i", "j", "k"][axis] @@ -96,7 +96,7 @@ def visit_VariableKOffset( return cuir.VariableKOffset(k=self.visit(node.k, **kwargs)) def _mask_to_expr(self, mask: common.HorizontalMask, ctx: Context) -> cuir.Expr: - mask_expr: List[cuir.Expr] = [] + mask_expr: list[cuir.Expr] = [] for axis_index, interval in enumerate(mask.intervals): if interval.is_single_index(): assert interval.start is not None @@ -140,11 +140,11 @@ def visit_FieldAccess( self, node: oir.FieldAccess, *, - ij_caches: Dict[str, cuir.IJCacheDecl], - k_caches: Dict[str, cuir.KCacheDecl], + ij_caches: dict[str, cuir.IJCacheDecl], + k_caches: dict[str, cuir.KCacheDecl], ctx: Context, **kwargs: Any, - ) -> Union[cuir.FieldAccess, cuir.IJCacheAccess, cuir.KCacheAccess]: + ) -> cuir.FieldAccess | cuir.IJCacheAccess | cuir.KCacheAccess: data_index = self.visit( node.data_index, ij_caches=ij_caches, k_caches=k_caches, ctx=ctx, **kwargs ) @@ -169,8 +169,8 @@ def visit_FieldAccess( ) def visit_ScalarAccess( - self, node: oir.ScalarAccess, *, symtable: Dict[str, Any], **kwargs: Any - ) -> Union[cuir.ScalarAccess, cuir.FieldAccess]: + self, node: oir.ScalarAccess, *, symtable: dict[str, Any], **kwargs: Any + ) -> cuir.ScalarAccess | cuir.FieldAccess: if 
isinstance(symtable.get(node.name, None), oir.ScalarDecl): return cuir.FieldAccess( name=node.name, offset=common.CartesianOffset.zero(), dtype=node.dtype @@ -228,7 +228,7 @@ def visit_VerticalLoopSection( ) def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, symtable: Dict[str, Any], ctx: Context, **kwargs: Any + self, node: oir.VerticalLoop, *, symtable: dict[str, Any], ctx: Context, **kwargs: Any ) -> cuir.Kernel: assert not any(c.fill or c.flush for c in node.caches if isinstance(c, oir.KCache)) ij_caches = { diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_treeir.py b/src/gt4py/cartesian/gtc/dace/oir_to_treeir.py index a30678b3c9..4c979d26b7 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_treeir.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_treeir.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, List, TypeAlias +from typing import Any, TypeAlias from dace import data, dtypes, symbolic @@ -82,8 +82,8 @@ def _group_statements(self, node: ControlFlow) -> list[oir.CodeBlock | ControlFl This function only groups statements. The job of visiting the groups statements is left to the caller. 
""" - statements: List[ControlFlow | oir.CodeBlock | common.Stmt] = [] - groups: List[ControlFlow | oir.CodeBlock] = [] + statements: list[ControlFlow | oir.CodeBlock | common.Stmt] = [] + groups: list[ControlFlow | oir.CodeBlock] = [] for statement in node.body: if isinstance(statement, ControlFlow): diff --git a/src/gt4py/cartesian/gtc/dace/treeir.py b/src/gt4py/cartesian/gtc/dace/treeir.py index 390ef328f1..30f0cd5003 100644 --- a/src/gt4py/cartesian/gtc/dace/treeir.py +++ b/src/gt4py/cartesian/gtc/dace/treeir.py @@ -8,9 +8,10 @@ from __future__ import annotations +from collections.abc import Generator from dataclasses import dataclass from types import TracebackType -from typing import Generator, TypeAlias +from typing import TypeAlias from dace import Memlet, data, dtypes, nodes diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index 6bc0de399f..d20a5e7f77 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import re -from functools import lru_cache +from functools import cache import numpy as np from dace import data, dtypes, symbolic @@ -65,7 +65,7 @@ def data_type_to_dace_typeclass(data_type: common.DataType) -> dtypes.typeclass: return dtypes.typeclass(dtype.type) -@lru_cache(maxsize=None) +@cache def get_dace_symbol( name: eve.SymbolRef, dtype: common.DataType = common.DataType.INT32 ) -> symbolic.symbol: diff --git a/src/gt4py/cartesian/gtc/debug/debug_codegen.py b/src/gt4py/cartesian/gtc/debug/debug_codegen.py index 6b4edd666c..122fe859ff 100644 --- a/src/gt4py/cartesian/gtc/debug/debug_codegen.py +++ b/src/gt4py/cartesian/gtc/debug/debug_codegen.py @@ -7,9 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause -from collections.abc import Generator +from collections.abc import Generator, Mapping from contextlib import contextmanager -from typing import Mapping from gt4py import eve from gt4py.cartesian.gtc import common 
as gtc_common, definitions as gtc_definitions, oir diff --git a/src/gt4py/cartesian/gtc/definitions.py b/src/gt4py/cartesian/gtc/definitions.py index 4c9efad84c..5f09bf5869 100644 --- a/src/gt4py/cartesian/gtc/definitions.py +++ b/src/gt4py/cartesian/gtc/definitions.py @@ -47,7 +47,7 @@ def is_valid(cls, value, *, ndims=(1, None)): if isinstance(ndims, numbers.Integral): ndims = tuple([ndims] * 2) elif not isinstance(ndims, tuple) or len(ndims) != 2: - raise ValueError("Invalid 'ndims' definition ({})".format(ndims)) + raise ValueError(f"Invalid 'ndims' definition ({ndims})") try: cls._check_value(value, ndims) @@ -90,12 +90,12 @@ def __new__(cls, sizes, *args, ndims=None): elif isinstance(ndims, int): ndims = tuple([ndims] * 2) elif not isinstance(ndims, tuple) or len(ndims) != 2: - raise ValueError("Invalid 'ndims' definition ({})".format(ndims)) + raise ValueError(f"Invalid 'ndims' definition ({ndims})") try: cls._check_value(sizes, ndims=ndims) except Exception as e: - raise TypeError("Invalid {} definition".format(cls.__name__)) from e + raise TypeError(f"Invalid {cls.__name__} definition") from e else: return super().__new__(cls, sizes) @@ -104,7 +104,7 @@ def __getattr__(self, name): value = self[CartesianSpace.Axis.symbols.index(name)] except (IndexError, ValueError) as e: raise AttributeError( - "'{}' object has no attribute '{}'".format(self.__class__.__name__, name) + f"'{self.__class__.__name__}' object has no attribute '{name}'" ) from e else: return value @@ -179,9 +179,7 @@ def __ge__(self, other): ) def __repr__(self): - return "{cls_name}({value})".format( - cls_name=type(self).__name__, value=tuple.__repr__(self) - ) + return f"{type(self).__name__}({tuple.__repr__(self)})" def __hash__(self): return tuple.__hash__(self) @@ -210,7 +208,7 @@ def intersection(self, other): def _apply(self, other, func): if not isinstance(other, type(self)) or len(self) != len(other): - raise ValueError("Incompatible instance '{obj}'".format(obj=other)) + raise 
ValueError(f"Incompatible instance '{other}'") return type(self)([func(a, b) for a, b in zip(self, other)]) @@ -224,7 +222,7 @@ def _broadcast(self, value): def _compare(self, other, op, reduction_op): if len(self) != len(other): # or not isinstance(other, type(self)) - raise ValueError("Incompatible instance '{obj}'".format(obj=other)) + raise ValueError(f"Incompatible instance '{other}'") return reduction_op(op(a, b) for a, b in zip(self, other)) @@ -277,7 +275,7 @@ def is_valid(cls, value, *, ndims=(1, None)): if isinstance(ndims, int): ndims = tuple([ndims] * 2) elif not isinstance(ndims, tuple) or len(ndims) != 2: - raise ValueError("Invalid 'ndims' definition ({})".format(ndims)) + raise ValueError(f"Invalid 'ndims' definition ({ndims})") try: cls._check_value(value, ndims) @@ -321,7 +319,7 @@ def __getattr__(self, name): value = self[CartesianSpace.Axis.symbols.index(name)] except (IndexError, ValueError) as e: raise AttributeError( - "'{}' object has no attribute '{}'".format(self.__class__.__name__, name) + f"'{self.__class__.__name__}' object has no attribute '{name}'" ) from e else: return value @@ -372,9 +370,7 @@ def __ge__(self, other): return self._compare(self._broadcast(other), operator.ge) def __repr__(self): - return "{cls_name}({value})".format( - cls_name=self.__class__.__name__, value=tuple.__repr__(self) - ) + return f"{self.__class__.__name__}({tuple.__repr__(self)})" def __hash__(self): return tuple.__hash__(self) @@ -428,7 +424,7 @@ def intersection(self, other): def _apply(self, other, left_func, right_func=None): if not isinstance(other, FrameTuple) or len(self) != len(other): - raise ValueError("Incompatible instance '{obj}'".format(obj=other)) + raise ValueError(f"Incompatible instance '{other}'") right_func = right_func or left_func return type(self)( @@ -440,7 +436,7 @@ def _reduce(self, reduce_func, out_type=tuple): def _compare(self, other, left_op, right_op=None): if len(self) != len(other): # or not isinstance(other, Frame) - 
raise ValueError("Incompatible instance '{obj}'".format(obj=other)) + raise ValueError(f"Incompatible instance '{other}'") right_op = right_op or left_op return all(left_op(a[0], b[0]) and right_op(a[1], b[1]) for a, b in zip(self, other)) @@ -479,7 +475,7 @@ def _check_value(cls, value, ndims): @classmethod def from_offset(cls, offset): if not Index.is_valid(offset): - raise ValueError("Invalid offset value ({})".format(offset)) + raise ValueError(f"Invalid offset value ({offset})") return cls([(-1 * min(0, i), max(0, i)) for i in offset]) @property @@ -524,7 +520,7 @@ def empty(cls, ndims=CartesianSpace.ndim): @classmethod def from_offset(cls, offset): if not Index.is_valid(offset): - raise ValueError("Invalid offset value ({})".format(offset)) + raise ValueError(f"Invalid offset value ({offset})") return cls([(i, i) for i in offset]) def __and__(self, other): @@ -571,7 +567,7 @@ def to_boundary(self): def _apply(self, other, left_func, right_func=None): if not isinstance(other, FrameTuple) or len(self) != len(other): - raise ValueError("Incompatible instance '{obj}'".format(obj=other)) + raise ValueError(f"Incompatible instance '{other}'") right_func = right_func or left_func result = [None] * len(self) @@ -622,7 +618,7 @@ def empty(cls, ndims=CartesianSpace.ndim): @classmethod def from_offset(cls, offset): if not Index.is_valid(offset): - raise ValueError("Invalid offset value ({})".format(offset)) + raise ValueError(f"Invalid offset value ({offset})") return cls([(min(i, 0), max(i, 0)) for i in offset]) def to_boundary(self): diff --git a/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py b/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py index 5ca766c272..7a8e221a39 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py +++ b/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py @@ -9,7 +9,7 @@ from __future__ import annotations import enum -from typing import Any, List, Tuple, Union +from typing import Any, Union from gt4py import eve from gt4py.cartesian.gtc import common @@ -54,7 +54,7 @@ class 
BlockStmt(common.BlockStmt[Stmt], Stmt): class AssignStmt(common.AssignStmt[Union[LocalAccess, AccessorRef], Expr], Stmt): @datamodels.validator("left") def no_horizontal_offset_in_assignment( - self, attribute: datamodels.Attribute, value: Union[LocalAccess, AccessorRef] + self, attribute: datamodels.Attribute, value: LocalAccess | AccessorRef ) -> None: if isinstance(value, AccessorRef): offsets = value.offset.to_dict() @@ -95,7 +95,7 @@ class Cast(common.Cast[Expr], Expr): class Temporary(LocNode): name: eve.Coerced[eve.SymbolName] dtype: common.DataType - data_dims: Tuple[int, ...] = eve.field(default_factory=tuple) + data_dims: tuple[int, ...] = eve.field(default_factory=tuple) class GTLevel(LocNode): @@ -120,8 +120,8 @@ class LocalVarDecl(LocNode): class GTApplyMethod(LocNode): interval: GTInterval - body: List[Stmt] - local_variables: List[LocalVarDecl] + body: list[Stmt] + local_variables: list[LocalVarDecl] @enum.unique @@ -131,15 +131,15 @@ class Intent(eve.StrEnum): class GTExtent(LocNode): - i: Tuple[int, int] - j: Tuple[int, int] - k: Tuple[int, int] + i: tuple[int, int] + j: tuple[int, int] + k: tuple[int, int] @classmethod def zero(cls) -> GTExtent: return cls(i=(0, 0), j=(0, 0), k=(0, 0)) - def __add__(self, offset: Union[common.CartesianOffset, VariableKOffset]) -> GTExtent: + def __add__(self, offset: common.CartesianOffset | VariableKOffset) -> GTExtent: if isinstance(offset, common.CartesianOffset): return GTExtent( i=(min(self.i[0], offset.i), max(self.i[1], offset.i)), @@ -162,12 +162,12 @@ class GTAccessor(LocNode): class GTParamList(LocNode): - accessors: List[GTAccessor] + accessors: list[GTAccessor] class GTFunctor(LocNode, eve.SymbolTableTrait): name: eve.Coerced[eve.SymbolName] - applies: List[GTApplyMethod] + applies: list[GTApplyMethod] param_list: GTParamList @@ -194,8 +194,8 @@ def __init__(self, *args: Any, **kwargs: Any): class FieldDecl(ApiParamDecl): - dimensions: Tuple[bool, bool, bool] - data_dims: Tuple[int, ...] 
= eve.field(default_factory=tuple) + dimensions: tuple[bool, bool, bool] + data_dims: tuple[int, ...] = eve.field(default_factory=tuple) class GlobalParamDecl(ApiParamDecl): @@ -220,10 +220,10 @@ class GTStage(LocNode): functor: eve.Coerced[eve.SymbolRef] # `args` are SymbolRefs to GTComputation `arguments` (interpreted as parameters) # or `temporaries` - args: List[Arg] + args: list[Arg] @datamodels.validator("args") - def has_args(self, attribute: datamodels.Attribute, value: List[Arg]) -> None: + def has_args(self, attribute: datamodels.Attribute, value: list[Arg]) -> None: if not value: raise ValueError("At least one argument required") @@ -243,8 +243,8 @@ class KCache(Cache): class GTMultiStage(LocNode): loop_order: common.LoopOrder - stages: List[GTStage] - caches: List[Cache] + stages: list[GTStage] + caches: list[Cache] class GTComputationCall(LocNode, eve.SymbolTableTrait): @@ -252,18 +252,18 @@ class GTComputationCall(LocNode, eve.SymbolTableTrait): # and the parameters of the function object. # We could represent this closer to the C++ code by splitting call and definition of the # function object. 
- arguments: List[Arg] - extra_decls: List[ComputationDecl] - temporaries: List[Temporary] - multi_stages: List[GTMultiStage] + arguments: list[Arg] + extra_decls: list[ComputationDecl] + temporaries: list[Temporary] + multi_stages: list[GTMultiStage] class Program(LocNode, eve.ValidatedSymbolTableTrait): name: str - parameters: List[ + parameters: list[ ApiParamDecl ] # in the current implementation these symbols can be accessed by the functor body - functors: List[GTFunctor] + functors: list[GTFunctor] gt_computation: GTComputationCall # here could be the CtrlFlow region _validate_dtype_is_set = common.validate_dtype_is_set() diff --git a/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py b/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py index 4fbc3645fd..6acf224c17 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py +++ b/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Collection, Dict, Final, Optional, Union +from collections.abc import Collection +from typing import Any, Final, Optional import numpy as np @@ -77,8 +78,8 @@ def visit_AccessorRef( self, accessor_ref: gtcpp.AccessorRef, *, - symtable: Dict[str, gtcpp.GTAccessor], - temp_decls: Optional[Dict[str, gtcpp.Temporary]] = None, + symtable: dict[str, gtcpp.GTAccessor], + temp_decls: Optional[dict[str, gtcpp.Temporary]] = None, **kwargs: Any, ): temp_decls = temp_decls or {} @@ -258,7 +259,7 @@ def visit_Temporary(self, node: gtcpp.Temporary, **kwargs: Any) -> str: def visit_GTComputationCall( self, node: gtcpp.GTComputationCall, **kwargs: Any - ) -> Union[str, Collection[str]]: + ) -> str | Collection[str]: computation_name = type(node).__name__ + str(id(node)) return self.generic_visit(node, computation_name=computation_name, **kwargs) @@ -282,7 +283,7 @@ def visit_GTComputationCall( """ ) - def visit_Program(self, node: gtcpp.Program, **kwargs: Any) -> Union[str, 
Collection[str]]: + def visit_Program(self, node: gtcpp.Program, **kwargs: Any) -> str | Collection[str]: temp_decls = {temp.name: temp for temp in node.gt_computation.temporaries} return self.generic_visit(node, temp_decls=temp_decls, **kwargs) diff --git a/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py b/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py index 0d5b1517c5..5381335c87 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py +++ b/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py @@ -10,8 +10,9 @@ import functools import itertools +from collections.abc import Callable from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Set, Union, cast +from typing import Any, cast from devtools import debug # noqa: F401 [unused-import] from typing_extensions import Protocol @@ -30,7 +31,7 @@ # - Each VerticalLoop is MultiStage -def _extract_accessors(node: eve.Node, temp_names: Set[str]) -> List[gtcpp.GTAccessor]: +def _extract_accessors(node: eve.Node, temp_names: set[str]) -> list[gtcpp.GTAccessor]: extents = ( node.walk_values() .if_isinstance(gtcpp.AccessorRef) @@ -42,7 +43,7 @@ def _extract_accessors(node: eve.Node, temp_names: Set[str]) -> List[gtcpp.GTAcc ) ) - inout_fields: Set[str] = ( + inout_fields: set[str] = ( node.walk_values() .if_isinstance(gtcpp.AssignStmt) .getattr("left") @@ -96,31 +97,31 @@ def __call__(self, name: str) -> str: ... 
class OIRToGTCpp(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): @dataclass class ProgramContext: - functors: List[gtcpp.GTFunctor] = field(default_factory=list) + functors: list[gtcpp.GTFunctor] = field(default_factory=list) - def add_functor(self, functor: gtcpp.GTFunctor) -> "OIRToGTCpp.ProgramContext": + def add_functor(self, functor: gtcpp.GTFunctor) -> OIRToGTCpp.ProgramContext: self.functors.append(functor) return self @dataclass class GTComputationContext: create_symbol_name: SymbolNameCreator - temporaries: List[gtcpp.Temporary] = field(default_factory=list) - positionals: Dict[int, gtcpp.Positional] = field(default_factory=dict) - axis_lengths: Dict[int, gtcpp.AxisLength] = field(default_factory=dict) - _arguments: Set[str] = field(default_factory=set) + temporaries: list[gtcpp.Temporary] = field(default_factory=list) + positionals: dict[int, gtcpp.Positional] = field(default_factory=dict) + axis_lengths: dict[int, gtcpp.AxisLength] = field(default_factory=dict) + _arguments: set[str] = field(default_factory=set) def add_temporaries( - self, temporaries: List[gtcpp.Temporary] + self, temporaries: list[gtcpp.Temporary] ) -> OIRToGTCpp.GTComputationContext: self.temporaries.extend(temporaries) return self @property - def arguments(self) -> List[gtcpp.Arg]: + def arguments(self) -> list[gtcpp.Arg]: return [gtcpp.Arg(name=name) for name in self._arguments] - def add_arguments(self, arguments: Set[str]) -> OIRToGTCpp.GTComputationContext: + def add_arguments(self, arguments: set[str]) -> OIRToGTCpp.GTComputationContext: self._arguments.update(arguments) return self @@ -147,7 +148,7 @@ def make_length(self, axis: int) -> gtcpp.AccessorRef: return self._make_scalar_accessor(length.name) @property - def extra_decls(self) -> List[gtcpp.ComputationDecl]: + def extra_decls(self) -> list[gtcpp.ComputationDecl]: return list(self.positionals.values()) + list(self.axis_lengths.values()) def visit_Literal(self, node: oir.Literal, **kwargs: Any) -> gtcpp.Literal: @@ 
-192,7 +193,7 @@ def visit_FieldAccess(self, node: oir.FieldAccess, **kwargs: Any) -> gtcpp.Acces def visit_ScalarAccess( self, node: oir.ScalarAccess, **kwargs: Any - ) -> Union[gtcpp.AccessorRef, gtcpp.LocalAccess]: + ) -> gtcpp.AccessorRef | gtcpp.LocalAccess: assert "symtable" in kwargs if node.name in kwargs["symtable"]: symbol = kwargs["symtable"][node.name] @@ -225,7 +226,7 @@ def visit_Interval(self, node: oir.Interval, **kwargs: Any) -> gtcpp.GTInterval: def _mask_to_expr( self, mask: common.HorizontalMask, comp_ctx: GTComputationContext ) -> gtcpp.Expr: - mask_expr: List[gtcpp.Expr] = [] + mask_expr: list[gtcpp.Expr] = [] for axis_index, interval in enumerate(mask.intervals): if interval.is_single_index(): assert interval.start is not None diff --git a/src/gt4py/cartesian/gtc/gtir.py b/src/gt4py/cartesian/gtc/gtir.py index 3ca91025dd..c897329f0e 100644 --- a/src/gt4py/cartesian/gtc/gtir.py +++ b/src/gt4py/cartesian/gtc/gtir.py @@ -21,7 +21,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Set, Tuple, Type +from typing import Any from gt4py import eve from gt4py.cartesian.gtc import common @@ -79,7 +79,7 @@ def no_horizontal_offset_in_assignment( @datamodels.root_validator @classmethod def no_write_and_read_with_offset_of_same_field( - cls: Type[ParAssignStmt], instance: ParAssignStmt + cls: type[ParAssignStmt], instance: ParAssignStmt ) -> None: if isinstance(instance.left, FieldAccess): offset_reads = ( @@ -144,7 +144,7 @@ class While(common.While[Stmt, Expr], Stmt): @datamodels.validator("body") def _no_write_and_read_with_horizontal_offset_all( - self, attribute: datamodels.Attribute, value: List[Stmt] + self, attribute: datamodels.Attribute, value: list[Stmt] ) -> None: """In a while loop all variables must not be written and read with a horizontal offset.""" if names := _written_and_read_with_offset(value): @@ -182,8 +182,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: class FieldDecl(Decl): - dimensions: 
Tuple[bool, bool, bool] - data_dims: Tuple[int, ...] = eve.field(default_factory=tuple) + dimensions: tuple[bool, bool, bool] + data_dims: tuple[int, ...] = eve.field(default_factory=tuple) class ScalarDecl(Decl): @@ -199,13 +199,13 @@ class Interval(LocNode): class VerticalLoop(LocNode): interval: Interval loop_order: common.LoopOrder - temporaries: List[FieldDecl] - body: List[Stmt] + temporaries: list[FieldDecl] + body: list[Stmt] @datamodels.root_validator @classmethod def _no_write_and_read_with_horizontal_offset( - cls: Type[VerticalLoop], instance: VerticalLoop + cls: type[VerticalLoop], instance: VerticalLoop ) -> None: """ In the same VerticalLoop a field must not be written and read with a horizontal offset. @@ -231,15 +231,15 @@ class Argument(eve.Node): class Stencil(LocNode, eve.ValidatedSymbolTableTrait): name: str - api_signature: List[Argument] - params: List[Decl] - vertical_loops: List[VerticalLoop] - externals: Dict[str, Literal] - sources: Dict[str, str] + api_signature: list[Argument] + params: list[Decl] + vertical_loops: list[VerticalLoop] + externals: dict[str, Literal] + sources: dict[str, str] docstring: str @property - def param_names(self) -> List[str]: + def param_names(self) -> list[str]: return [p.name for p in self.params] _validate_lvalue_dims = common.validate_lvalue_dims(VerticalLoop, FieldDecl) @@ -254,16 +254,16 @@ def _variablek_fieldaccess(node) -> bool: # TODO(havogt): either move to eve or will be removed in the attr-based eve if a List[Node] is represented as a CollectionNode -def _written_and_read_with_offset(stmts: List[Stmt]) -> Set[str]: +def _written_and_read_with_offset(stmts: list[Stmt]) -> set[str]: """Return a list of names that are written to and read with offset.""" - def _writes(stmts: List[Stmt]) -> Set[str]: + def _writes(stmts: list[Stmt]) -> set[str]: result = set() for left in eve.walk_values(stmts).if_isinstance(ParAssignStmt).getattr("left"): result |= 
eve.walk_values(left).if_isinstance(FieldAccess).getattr("name").to_set() return result - def _reads_with_offset(stmts: List[Stmt]) -> Set[str]: + def _reads_with_offset(stmts: list[Stmt]) -> set[str]: return ( eve.walk_values(stmts) .filter(_cartesian_fieldaccess) diff --git a/src/gt4py/cartesian/gtc/gtir_to_oir.py b/src/gt4py/cartesian/gtc/gtir_to_oir.py index 3207a8bf5b..4546a6b4cc 100644 --- a/src/gt4py/cartesian/gtc/gtir_to_oir.py +++ b/src/gt4py/cartesian/gtc/gtir_to_oir.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from dataclasses import dataclass, field -from typing import Any, List, Set, Union +from typing import Any from gt4py import eve from gt4py.cartesian import utils @@ -25,7 +25,7 @@ def validate_stencil_memory_accesses(node: oir.Stencil) -> oir.Stencil: indirect read-with-offset through temporaries. """ - def _writes(node: oir.Stencil) -> Set[str]: + def _writes(node: oir.Stencil) -> set[str]: result = set() for left in node.walk_values().if_isinstance(oir.AssignStmt).getattr("left"): result |= left.walk_values().if_isinstance(oir.FieldAccess).getattr("name").to_set() @@ -36,7 +36,7 @@ def _writes(node: oir.Stencil) -> Set[str]: field_extents = compute_fields_extents(node) - names: Set[str] = set() + names: set[str] = set() for name in write_fields: if not field_extents[name].is_zero: names.add(name) @@ -50,8 +50,8 @@ def _writes(node: oir.Stencil) -> Set[str]: class GTIRToOIR(eve.NodeTranslator): @dataclass class Context: - local_scalars: List[oir.ScalarDecl] = field(default_factory=list) - temp_fields: List[oir.FieldDecl] = field(default_factory=list) + local_scalars: list[oir.ScalarDecl] = field(default_factory=list) + temp_fields: list[oir.FieldDecl] = field(default_factory=list) def reset_local_scalars(self): self.local_scalars = [] @@ -132,7 +132,7 @@ def visit_HorizontalRestriction( return oir.HorizontalRestriction(mask=node.mask, body=body) def visit_While(self, node: gtir.While, **kwargs: Any) -> oir.While: - body: 
List[oir.Stmt] = [] + body: list[oir.Stmt] = [] for statement in node.body: oir_statement = self.visit(statement, **kwargs) body.extend(utils.flatten(utils.listify(oir_statement))) @@ -146,12 +146,12 @@ def visit_FieldIfStmt( *, ctx: Context, **kwargs: Any, - ) -> List[Union[oir.AssignStmt, oir.MaskStmt]]: + ) -> list[oir.AssignStmt | oir.MaskStmt]: mask_field_decl = oir.Temporary( name=f"mask_{id(node)}", dtype=DataType.BOOL, dimensions=(True, True, True) ) ctx.temp_fields.append(mask_field_decl) - statements: List[Union[oir.AssignStmt, oir.MaskStmt]] = [ + statements: list[oir.AssignStmt | oir.MaskStmt] = [ oir.AssignStmt( left=oir.FieldAccess( name=mask_field_decl.name, @@ -192,7 +192,7 @@ def visit_ScalarIfStmt( *, ctx: Context, **kwargs: Any, - ) -> List[oir.MaskStmt]: + ) -> list[oir.MaskStmt]: condition = self.visit(node.cond) body = utils.flatten( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] @@ -214,7 +214,7 @@ def visit_Interval(self, node: gtir.Interval) -> oir.Interval: # --- Control flow --- def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context) -> oir.VerticalLoop: - horizontal_executions: List[oir.HorizontalExecution] = [] + horizontal_executions: list[oir.HorizontalExecution] = [] for statement in node.body: ctx.reset_local_scalars() body = utils.flatten(utils.listify(self.visit(statement, ctx=ctx))) diff --git a/src/gt4py/cartesian/gtc/numpy/npir.py b/src/gt4py/cartesian/gtc/numpy/npir.py index 1a7b0816d9..c195bdff98 100644 --- a/src/gt4py/cartesian/gtc/numpy/npir.py +++ b/src/gt4py/cartesian/gtc/numpy/npir.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import List, Optional, Tuple, Union +from typing import Optional from gt4py import eve from gt4py.cartesian.gtc import common @@ -25,8 +25,8 @@ class AxisName(eve.StrEnum): # - They are expressed relative to the iteration domain of the statement # - Each axis is a tuple of two common.AxisBound instead of common.HorizontalInterval class HorizontalMask(eve.Node): - i: Tuple[common.AxisBound, common.AxisBound] - j: Tuple[common.AxisBound, common.AxisBound] + i: tuple[common.AxisBound, common.AxisBound] + j: tuple[common.AxisBound, common.AxisBound] # --- Decls --- @@ -57,8 +57,8 @@ class LocalScalarDecl(Decl): class FieldDecl(Decl): """General field shared across HorizontalBlocks.""" - dimensions: Tuple[bool, bool, bool] - data_dims: Tuple[int, ...] = eve.field(default_factory=tuple) + dimensions: tuple[bool, bool, bool] + data_dims: tuple[int, ...] = eve.field(default_factory=tuple) extent: Extent @@ -72,9 +72,9 @@ class TemporaryDecl(Decl): padding: Buffer added to compute domain as field size. """ - data_dims: Tuple[int, ...] = eve.field(default_factory=tuple) - offset: Tuple[int, int] - padding: Tuple[int, int] + data_dims: tuple[int, ...] 
= eve.field(default_factory=tuple) + offset: tuple[int, int] + padding: tuple[int, int] # --- Expressions --- @@ -119,13 +119,13 @@ class FieldSlice(VectorLValue): name: eve.Coerced[eve.SymbolRef] i_offset: int j_offset: int - k_offset: Union[int, VarKOffset] - data_index: List[Expr] = eve.field(default_factory=list) + k_offset: int | VarKOffset + data_index: list[Expr] = eve.field(default_factory=list) kind: common.ExprKind = common.ExprKind.FIELD @datamodels.validator("data_index") def data_indices_are_scalar( - self, attribute: datamodels.Attribute, data_index: List[Expr] + self, attribute: datamodels.Attribute, data_index: list[Expr] ) -> None: for index in data_index: if index.kind != common.ExprKind.SCALAR: @@ -143,7 +143,7 @@ class LocalScalarAccess(VectorLValue): class VectorArithmetic(common.BinaryOp[Expr], Expr): - op: Union[common.ArithmeticOperator, common.ComparisonOperator] + op: common.ArithmeticOperator | common.ComparisonOperator _dtype_propagation = common.binary_op_dtype_propagation(strict=True) @@ -188,21 +188,21 @@ class While(common.While[Stmt, Expr], Stmt): # --- Control Flow --- class HorizontalBlock(common.LocNode, eve.SymbolTableTrait): - body: List[Stmt] + body: list[Stmt] extent: Extent - declarations: List[LocalScalarDecl] + declarations: list[LocalScalarDecl] class VerticalPass(common.LocNode): - body: List[HorizontalBlock] + body: list[HorizontalBlock] lower: common.AxisBound upper: common.AxisBound direction: common.LoopOrder class Computation(common.LocNode, eve.SymbolTableTrait): - arguments: List[str] - api_field_decls: List[FieldDecl] - param_decls: List[ScalarDecl] - temp_decls: List[TemporaryDecl] - vertical_passes: List[VerticalPass] + arguments: list[str] + api_field_decls: list[FieldDecl] + param_decls: list[ScalarDecl] + temp_decls: list[TemporaryDecl] + vertical_passes: list[VerticalPass] diff --git a/src/gt4py/cartesian/gtc/numpy/npir_codegen.py b/src/gt4py/cartesian/gtc/numpy/npir_codegen.py index a6e7a81fda..96b8c6dd42 
100644 --- a/src/gt4py/cartesian/gtc/numpy/npir_codegen.py +++ b/src/gt4py/cartesian/gtc/numpy/npir_codegen.py @@ -10,8 +10,9 @@ import numbers import textwrap +from collections.abc import Collection from dataclasses import dataclass, field -from typing import Any, Collection, List, Optional, Set, Tuple, Union, cast +from typing import Any, Optional, cast from gt4py import eve from gt4py.cartesian.gtc import common @@ -32,7 +33,7 @@ def _offset_to_str(offset: int) -> str: return "" -def _slice_string(ch: str, offset: int, interval: Tuple[common.AxisBound, common.AxisBound]) -> str: +def _slice_string(ch: str, offset: int, interval: tuple[common.AxisBound, common.AxisBound]) -> str: start_ch = ch if interval[0].level == common.LevelMarker.START else ch.upper() end_ch = ch if interval[1].level == common.LevelMarker.START else ch.upper() @@ -44,11 +45,11 @@ def _slice_string(ch: str, offset: int, interval: Tuple[common.AxisBound, common def _make_slice_access( - offset: Tuple[Optional[int], Optional[int], Union[str, Optional[int]]], + offset: tuple[Optional[int], Optional[int], str | Optional[int]], is_serial: bool, interval: Optional[npir.HorizontalMask] = None, -) -> List[str]: - axes: List[str] = [] +) -> list[str]: + axes: list[str] = [] if interval is None: interval = npir.HorizontalMask( @@ -78,16 +79,14 @@ def _make_slice_access( class NpirCodegen(codegen.TemplatedGenerator, eve.VisitorWithSymbolTableTrait): @dataclass class BlockContext: - locals_declared: Set[str] = field(default_factory=set) + locals_declared: set[str] = field(default_factory=set) def add_declared(self, *args): self.locals_declared |= set(args) FieldDecl = as_fmt("{name} = Field({name}, _origin_['{name}'], ({', '.join(dimensions)}))") - def visit_TemporaryDecl( - self, node: npir.TemporaryDecl, **kwargs - ) -> Union[str, Collection[str]]: + def visit_TemporaryDecl(self, node: npir.TemporaryDecl, **kwargs) -> str | Collection[str]: shape = [f"_dI_ + {node.padding[0]}", f"_dJ_ + 
{node.padding[1]}", "_dK_"] + [ str(dim) for dim in node.data_dims ] @@ -101,13 +100,13 @@ def visit_TemporaryDecl( VarKOffset = as_fmt("lk + {k}") - def visit_FieldSlice(self, node: npir.FieldSlice, **kwargs: Any) -> Union[str, Collection[str]]: + def visit_FieldSlice(self, node: npir.FieldSlice, **kwargs: Any) -> str | Collection[str]: k_offset = ( self.visit(node.k_offset, **kwargs) if isinstance(node.k_offset, npir.VarKOffset) else node.k_offset ) - offsets: Tuple[Optional[int], Optional[int], Union[str, int, None]] = ( + offsets: tuple[Optional[int], Optional[int], str | int | None] = ( node.i_offset, node.j_offset, k_offset, @@ -118,7 +117,7 @@ def visit_FieldSlice(self, node: npir.FieldSlice, **kwargs: Any) -> Union[str, C decl = kwargs["symtable"][node.name] dimensions = decl.dimensions if isinstance(decl, npir.FieldDecl) else [True] * 3 offsets = cast( - Tuple[Optional[int], Optional[int], Union[str, int, None]], + tuple[Optional[int], Optional[int], str | int | None], tuple(off if has_dim else None for has_dim, off in zip(dimensions, offsets)), ) @@ -136,7 +135,7 @@ def visit_LocalScalarAccess( is_serial: bool, horizontal_mask: Optional[npir.HorizontalMask] = None, **kwargs: Any, - ) -> Union[str, Collection[str]]: + ) -> str | Collection[str]: args = _make_slice_access((0, 0, 0), is_serial, horizontal_mask) if is_serial: args[2] = ":" @@ -144,16 +143,14 @@ def visit_LocalScalarAccess( ParamAccess = as_fmt("{name}") - def visit_DataType(self, node: common.DataType, **kwargs: Any) -> Union[str, Collection[str]]: + def visit_DataType(self, node: common.DataType, **kwargs: Any) -> str | Collection[str]: # `np.bool` is a deprecated alias for the builtin `bool` or `np.bool_`. 
if node not in {common.DataType.BOOL}: return f"np.{node.name.lower()}" else: return node.name.lower() - def visit_BuiltInLiteral( - self, node: common.BuiltInLiteral, **kwargs - ) -> Union[str, Collection[str]]: + def visit_BuiltInLiteral(self, node: common.BuiltInLiteral, **kwargs) -> str | Collection[str]: if node is common.BuiltInLiteral.TRUE: return "True" elif node is common.BuiltInLiteral.FALSE: @@ -163,7 +160,7 @@ def visit_BuiltInLiteral( def visit_ScalarLiteral( self, node: npir.ScalarLiteral, *, inside_slice: bool = False, **kwargs: Any - ) -> Union[str, Collection[str]]: + ) -> str | Collection[str]: # This could be trivial, but it's convenient for reading if the dtype is omitted in slices. dtype = self.visit(node.dtype, inside_slice=inside_slice, **kwargs) value = self.visit(node.value, inside_slice=inside_slice, **kwargs) @@ -175,7 +172,7 @@ def visit_ScalarLiteral( def visit_NativeFunction( self, node: common.NativeFunction, **kwargs: Any - ) -> Union[str, Collection[str]]: + ) -> str | Collection[str]: return f"ufuncs.{common.OP_TO_UFUNC_NAME[common.NativeFunction][node]}" def visit_NativeFuncCall( @@ -188,7 +185,7 @@ def visit_NativeFuncCall( def visit_VectorAssign( self, node: npir.VectorAssign, *, ctx: BlockContext, **kwargs: Any - ) -> Union[str, Collection[str]]: + ) -> str | Collection[str]: left = self.visit(node.left, horizontal_mask=node.horizontal_mask, **kwargs) right = self.visit(node.right, horizontal_mask=node.horizontal_mask, **kwargs) return f"{left} = {right}" @@ -199,7 +196,7 @@ def visit_VectorAssign( def visit_UnaryOperator( self, node: common.UnaryOperator, **kwargs: Any - ) -> Union[str, Collection[str]]: + ) -> str | Collection[str]: if node is common.UnaryOperator.NOT: return "np.bitwise_not" return self.generic_visit(node, **kwargs) @@ -208,12 +205,10 @@ def visit_UnaryOperator( VectorTernaryOp = as_fmt("np.where({cond}, {true_expr}, {false_expr})") - def visit_LevelMarker( - self, node: common.LevelMarker, **kwargs: Any - ) 
-> Union[str, Collection[str]]: + def visit_LevelMarker(self, node: common.LevelMarker, **kwargs: Any) -> str | Collection[str]: return "K" if node == common.LevelMarker.END else "k" - def visit_AxisBound(self, node: common.AxisBound, **kwargs: Any) -> Union[str, Collection[str]]: + def visit_AxisBound(self, node: common.AxisBound, **kwargs: Any) -> str | Collection[str]: if node.offset > 0: voffset = f" + {node.offset}" elif node.offset == 0: @@ -224,7 +219,7 @@ def visit_AxisBound(self, node: common.AxisBound, **kwargs: Any) -> Union[str, C AxisBound = as_fmt("_d{level}_{voffset}") - def visit_LoopOrder(self, node: common.LoopOrder, **kwargs) -> Union[str, Collection[str]]: + def visit_LoopOrder(self, node: common.LoopOrder, **kwargs) -> str | Collection[str]: if node is common.LoopOrder.FORWARD: return "for k_ in range(k, K):" elif node is common.LoopOrder.BACKWARD: @@ -281,7 +276,7 @@ def visit_VerticalPass(self, node: npir.VerticalPass, **kwargs): def visit_HorizontalBlock( self, node: npir.HorizontalBlock, **kwargs: Any - ) -> Union[str, Collection[str]]: + ) -> str | Collection[str]: lower = (-node.extent[0][0], -node.extent[1][0]) upper = (node.extent[0][1], node.extent[1][1]) return self.generic_visit(node, lower=lower, upper=upper, ctx=self.BlockContext(), **kwargs) @@ -303,7 +298,7 @@ def visit_HorizontalBlock( def visit_Computation( self, node: npir.Computation, *, ignore_np_errstate: bool = True, **kwargs: Any - ) -> Union[str, Collection[str]]: + ) -> str | Collection[str]: signature = ["*", *node.arguments, "_domain_", "_origin_"] return self.generic_visit( node, diff --git a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py index 8606852d87..e56298a775 100644 --- a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py +++ b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional from gt4py import eve from gt4py.cartesian import utils @@ -22,7 +22,7 @@ class OirToNpir(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): # --- Decls --- def visit_FieldDecl( - self, node: oir.FieldDecl, *, field_extents: Dict[str, Extent], **kwargs: Any + self, node: oir.FieldDecl, *, field_extents: dict[str, Extent], **kwargs: Any ) -> npir.FieldDecl: extent = field_extents.get(node.name, Extent.zeros(ndims=2)) return npir.FieldDecl( @@ -40,7 +40,7 @@ def visit_LocalScalar(self, node: oir.LocalScalar, **kwargs: Any) -> npir.LocalS return npir.LocalScalarDecl(name=node.name, dtype=node.dtype) def visit_Temporary( - self, node: oir.Temporary, *, field_extents: Dict[str, Extent], **kwargs: Any + self, node: oir.Temporary, *, field_extents: dict[str, Extent], **kwargs: Any ) -> npir.TemporaryDecl: temp_extent = field_extents[node.name] offset = tuple(-ext[0] for ext in temp_extent) @@ -60,8 +60,8 @@ def visit_Literal(self, node: oir.Literal, **kwargs: Any) -> npir.ScalarLiteral: return npir.ScalarLiteral(value=node.value, dtype=node.dtype, kind=node.kind) def visit_ScalarAccess( - self, node: oir.ScalarAccess, *, symtable: Dict[str, oir.Decl], **kwargs: Any - ) -> Union[npir.ParamAccess, npir.LocalScalarAccess]: + self, node: oir.ScalarAccess, *, symtable: dict[str, oir.Decl], **kwargs: Any + ) -> npir.ParamAccess | npir.LocalScalarAccess: assert node.kind == common.ExprKind.SCALAR if isinstance(symtable[node.name], oir.LocalScalar): return npir.LocalScalarAccess(name=node.name, dtype=symtable[node.name].dtype) @@ -70,12 +70,12 @@ def visit_ScalarAccess( def visit_CartesianOffset( self, node: common.CartesianOffset, **kwargs: Any - ) -> Tuple[int, int, int]: + ) -> tuple[int, int, int]: return node.i, node.j, node.k def visit_VariableKOffset( self, node: oir.VariableKOffset, **kwargs: Any - ) -> Tuple[int, int, eve.Node]: + ) -> 
tuple[int, int, eve.Node]: return 0, 0, npir.VarKOffset(k=self.visit(node.k, **kwargs)) def visit_FieldAccess(self, node: oir.FieldAccess, **kwargs: Any) -> npir.FieldSlice: @@ -95,7 +95,7 @@ def visit_UnaryOp(self, node: oir.UnaryOp, **kwargs: Any) -> npir.VectorUnaryOp: def visit_BinaryOp( self, node: oir.BinaryOp, **kwargs: Any - ) -> Union[npir.VectorArithmetic, npir.VectorLogic]: + ) -> npir.VectorArithmetic | npir.VectorLogic: args = dict( op=node.op, left=self.visit(node.left, **kwargs), right=self.visit(node.right, **kwargs) ) @@ -111,7 +111,7 @@ def visit_TernaryOp(self, node: oir.TernaryOp, **kwargs: Any) -> npir.VectorTern false_expr=self.visit(node.false_expr, **kwargs), ) - def visit_Cast(self, node: oir.Cast, **kwargs: Any) -> Union[npir.VectorCast, npir.ScalarCast]: + def visit_Cast(self, node: oir.Cast, **kwargs: Any) -> npir.VectorCast | npir.ScalarCast: expr = self.visit(node.expr, **kwargs) args = {"dtype": node.dtype, "expr": expr} return ( @@ -128,7 +128,7 @@ def visit_NativeFuncCall(self, node: oir.NativeFuncCall, **kwargs: Any) -> npir. 
# --- Statements --- def visit_MaskStmt( self, node: oir.MaskStmt, *, mask: Optional[npir.Expr] = None, **kwargs: Any - ) -> List[npir.Stmt]: + ) -> list[npir.Stmt]: mask_expr = self.visit(node.mask, **kwargs) if mask: mask_expr = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=mask_expr) @@ -181,7 +181,7 @@ def visit_HorizontalExecution( self, node: oir.HorizontalExecution, *, - block_extents: Optional[Dict[int, Extent]] = None, + block_extents: Optional[dict[int, Extent]] = None, **kwargs: Any, ) -> npir.HorizontalBlock: if block_extents: @@ -204,7 +204,7 @@ def visit_VerticalLoopSection( direction=loop_order, ) - def visit_VerticalLoop(self, node: oir.VerticalLoop, **kwargs: Any) -> List[npir.VerticalPass]: + def visit_VerticalLoop(self, node: oir.VerticalLoop, **kwargs: Any) -> list[npir.VerticalPass]: return self.visit(node.sections, loop_order=node.loop_order, **kwargs) def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> npir.Computation: diff --git a/src/gt4py/cartesian/gtc/numpy/scalars_to_temps.py b/src/gt4py/cartesian/gtc/numpy/scalars_to_temps.py index 89afa40fa2..95dd0ebc41 100644 --- a/src/gt4py/cartesian/gtc/numpy/scalars_to_temps.py +++ b/src/gt4py/cartesian/gtc/numpy/scalars_to_temps.py @@ -9,7 +9,6 @@ """An optimization to convert npir.LocalScalarDecl to npir.TemporaryDecl.""" from dataclasses import dataclass -from typing import Dict from gt4py import eve from gt4py.cartesian import utils @@ -30,7 +29,7 @@ def _all_local_scalars_are_unique_type(stencil: npir.Computation) -> bool: stencil.walk_values().if_isinstance(npir.HorizontalBlock).getattr("declarations").to_list() ) - name_to_dtype: Dict[str, common.DataType] = {} + name_to_dtype: dict[str, common.DataType] = {} for decl in all_declarations: if decl.name in name_to_dtype: if decl.dtype != name_to_dtype[decl.name]: @@ -43,7 +42,7 @@ def _all_local_scalars_are_unique_type(stencil: npir.Computation) -> bool: class ScalarsToTemporaries(eve.NodeTranslator): def 
visit_LocalScalarAccess( - self, node: npir.LocalScalarAccess, *, temps_from_scalars: Dict[str, Temporary] + self, node: npir.LocalScalarAccess, *, temps_from_scalars: dict[str, Temporary] ) -> npir.FieldSlice: return npir.FieldSlice( name=node.name, @@ -54,7 +53,7 @@ def visit_LocalScalarAccess( ) def visit_HorizontalBlock( - self, node: npir.HorizontalBlock, *, temps_from_scalars: Dict[str, Temporary] + self, node: npir.HorizontalBlock, *, temps_from_scalars: dict[str, Temporary] ) -> npir.HorizontalBlock: for decl in node.declarations: if decl.name not in temps_from_scalars: @@ -76,7 +75,7 @@ def visit_Computation(self, node: npir.Computation) -> npir.Computation: "The numpy backend currently assumes this is not the case." ) - temps_from_scalars: Dict[str, Temporary] = {} + temps_from_scalars: dict[str, Temporary] = {} vertical_passes = self.visit(node.vertical_passes, temps_from_scalars=temps_from_scalars) diff --git a/src/gt4py/cartesian/gtc/oir.py b/src/gt4py/cartesian/gtc/oir.py index 0c4225cd26..fc1ff6035c 100644 --- a/src/gt4py/cartesian/gtc/oir.py +++ b/src/gt4py/cartesian/gtc/oir.py @@ -15,7 +15,7 @@ from __future__ import annotations -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from gt4py import eve from gt4py.cartesian.gtc import common @@ -56,7 +56,7 @@ class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): class AssignStmt(common.AssignStmt[Union[ScalarAccess, FieldAccess], Expr], Stmt): @datamodels.validator("left") def no_horizontal_offset_in_assignment( - self, attribute: datamodels.Attribute, value: Union[ScalarAccess, FieldAccess] + self, attribute: datamodels.Attribute, value: ScalarAccess | FieldAccess ) -> None: if isinstance(value, FieldAccess): offsets = value.offset.to_dict() @@ -68,7 +68,7 @@ def no_horizontal_offset_in_assignment( class MaskStmt(Stmt): mask: Expr - body: List[Stmt] + body: list[Stmt] @datamodels.validator("mask") def mask_is_boolean_field_expr(self, 
attribute: datamodels.Attribute, v: Expr) -> None: @@ -115,8 +115,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: class FieldDecl(Decl): - dimensions: Tuple[bool, bool, bool] - data_dims: Tuple[int, ...] = eve.field(default_factory=tuple) + dimensions: tuple[bool, bool, bool] + data_dims: tuple[int, ...] = eve.field(default_factory=tuple) class ScalarDecl(Decl): @@ -131,7 +131,7 @@ class Temporary(FieldDecl): pass -def _check_interval(instance: Union[Interval, UnboundedInterval]) -> None: +def _check_interval(instance: Interval | UnboundedInterval) -> None: start, end = instance.start, instance.end if ( start is not None @@ -158,7 +158,7 @@ class Interval(LocNode): @datamodels.root_validator @classmethod - def check(cls: Type[Interval], instance: Interval) -> None: + def check(cls: type[Interval], instance: Interval) -> None: _check_interval(instance) def covers(self, other: Interval) -> bool: @@ -169,7 +169,7 @@ def covers(self, other: Interval) -> bool: def intersects(self, other: Interval) -> bool: return not (other.start >= self.end or self.start >= other.end) - def shifted(self, offset: Optional[int]) -> Union[Interval, UnboundedInterval]: + def shifted(self, offset: Optional[int]) -> Interval | UnboundedInterval: if offset is None: return UnboundedInterval() start = AxisBound(level=self.start.level, offset=self.start.offset + offset) @@ -187,10 +187,10 @@ class UnboundedInterval: @datamodels.root_validator @classmethod - def check(cls: Type[UnboundedInterval], instance: UnboundedInterval) -> None: + def check(cls: type[UnboundedInterval], instance: UnboundedInterval) -> None: _check_interval(instance) - def covers(self, other: Union[Interval, UnboundedInterval]) -> bool: + def covers(self, other: Interval | UnboundedInterval) -> bool: if self.start is None and self.end is None: return True if ( @@ -215,7 +215,7 @@ def covers(self, other: Union[Interval, UnboundedInterval]) -> bool: assert isinstance(other, Interval) return 
Interval(start=self.start, end=self.end).covers(other) - def intersects(self, other: Union[Interval, UnboundedInterval]) -> bool: + def intersects(self, other: Interval | UnboundedInterval) -> bool: no_overlap_high = ( self.end is not None and other.start is not None and other.start >= self.end ) @@ -246,8 +246,8 @@ def full(cls): class HorizontalExecution(LocNode, eve.SymbolTableTrait): - body: List[Stmt] - declarations: List[LocalScalar] + body: list[Stmt] + declarations: list[LocalScalar] class CacheDesc(LocNode): @@ -265,22 +265,22 @@ class KCache(CacheDesc): class VerticalLoopSection(LocNode): interval: Interval - horizontal_executions: List[HorizontalExecution] + horizontal_executions: list[HorizontalExecution] class VerticalLoop(LocNode): loop_order: common.LoopOrder - sections: List[VerticalLoopSection] - caches: List[CacheDesc] = eve.field(default_factory=list) + sections: list[VerticalLoopSection] + caches: list[CacheDesc] = eve.field(default_factory=list) @datamodels.validator("sections") - def nonempty_loop(self, attribute: datamodels.Attribute, v: List[VerticalLoopSection]) -> None: + def nonempty_loop(self, attribute: datamodels.Attribute, v: list[VerticalLoopSection]) -> None: if not v: raise ValueError("Empty vertical loop is not allowed") @datamodels.root_validator @classmethod - def valid_section_intervals(cls: Type[VerticalLoop], instance: VerticalLoop) -> None: + def valid_section_intervals(cls: type[VerticalLoop], instance: VerticalLoop) -> None: starts, ends = zip(*((s.interval.start, s.interval.end) for s in instance.sections)) if instance.loop_order == common.LoopOrder.BACKWARD: starts, ends = starts[:-1], ends[1:] @@ -297,9 +297,9 @@ def valid_section_intervals(cls: Type[VerticalLoop], instance: VerticalLoop) -> class Stencil(LocNode, eve.ValidatedSymbolTableTrait): name: str # TODO(): fix to be List[Union[ScalarDecl, FieldDecl]] - params: List[Decl] - vertical_loops: List[VerticalLoop] - declarations: List[Temporary] + params: list[Decl] + 
vertical_loops: list[VerticalLoop] + declarations: list[Temporary] _validate_dtype_is_set = common.validate_dtype_is_set() _validate_lvalue_dims = common.validate_lvalue_dims(VerticalLoop, FieldDecl) diff --git a/src/gt4py/cartesian/gtc/passes/gtir_definitive_assignment_analysis.py b/src/gt4py/cartesian/gtc/passes/gtir_definitive_assignment_analysis.py index d11ad9f441..7c1c5cd6d5 100644 --- a/src/gt4py/cartesian/gtc/passes/gtir_definitive_assignment_analysis.py +++ b/src/gt4py/cartesian/gtc/passes/gtir_definitive_assignment_analysis.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings -from typing import List, Set from gt4py import eve from gt4py.cartesian.gtc import gtir @@ -25,7 +24,7 @@ class DefinitiveAssignmentAnalysis(eve.NodeVisitor): result of the condition. """ - def visit_IfStmt(self, node: gtir.FieldIfStmt, *, alive_vars: Set[str], **kwargs) -> None: + def visit_IfStmt(self, node: gtir.FieldIfStmt, *, alive_vars: set[str], **kwargs) -> None: true_branch_vars = {*alive_vars} false_branch_vars = {*alive_vars} self.visit(node.true_branch, alive_vars=true_branch_vars, **kwargs) @@ -33,7 +32,7 @@ def visit_IfStmt(self, node: gtir.FieldIfStmt, *, alive_vars: Set[str], **kwargs alive_vars.update(true_branch_vars & false_branch_vars) def visit_ParAssignStmt( - self, node: gtir.ParAssignStmt, *, alive_vars: Set[str], **kwargs + self, node: gtir.ParAssignStmt, *, alive_vars: set[str], **kwargs ) -> None: self.visit(node.right, alive_vars=alive_vars, **kwargs) alive_vars.add(node.left.name) @@ -42,17 +41,17 @@ def visit_FieldAccess( self, node: gtir.FieldAccess, *, - alive_vars: Set[str], - invalid_accesses: List[gtir.FieldAccess], + alive_vars: set[str], + invalid_accesses: list[gtir.FieldAccess], **kwargs, ) -> None: if node.name not in alive_vars: invalid_accesses.append(node) @classmethod - def apply(cls, gtir_stencil_expr: gtir.Stencil) -> List[gtir.FieldAccess]: + def apply(cls, gtir_stencil_expr: gtir.Stencil) -> 
list[gtir.FieldAccess]: """Execute analysis and return all accesses to undefined symbols.""" - invalid_accesses: List[gtir.FieldAccess] = [] + invalid_accesses: list[gtir.FieldAccess] = [] DefinitiveAssignmentAnalysis().visit( gtir_stencil_expr, alive_vars=set(gtir_stencil_expr.param_names), diff --git a/src/gt4py/cartesian/gtc/passes/gtir_dtype_resolver.py b/src/gt4py/cartesian/gtc/passes/gtir_dtype_resolver.py index ade4dc3c1a..57d1ef98d0 100644 --- a/src/gt4py/cartesian/gtc/passes/gtir_dtype_resolver.py +++ b/src/gt4py/cartesian/gtc/passes/gtir_dtype_resolver.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict +from typing import Any from gt4py import eve from gt4py.cartesian.gtc import gtir @@ -27,7 +27,7 @@ class _GTIRUpdateAutoDecl(eve.NodeTranslator): """Updates FieldDecls with resolved types.""" def visit_FieldDecl( - self, node: gtir.FieldDecl, new_symbols: Dict[str, Any], **kwargs: Any + self, node: gtir.FieldDecl, new_symbols: dict[str, Any], **kwargs: Any ) -> gtir.FieldDecl: if node.dtype == DataType.AUTO: dtype = new_symbols[node.name].dtype @@ -38,7 +38,7 @@ def visit_FieldDecl( return node def visit_FieldAccess( - self, node: gtir.FieldAccess, *, symtable: Dict[str, Any], **kwargs: Any + self, node: gtir.FieldAccess, *, symtable: dict[str, Any], **kwargs: Any ) -> gtir.FieldAccess: if symtable[node.name].dtype == DataType.AUTO: assert "new_dtype" in kwargs @@ -89,7 +89,7 @@ def visit_FieldAccess(self, node: gtir.FieldAccess, **kwargs: Any) -> gtir.Field ) def visit_ScalarAccess( - self, node: gtir.ScalarAccess, *, symtable: Dict[str, Any], **kwargs: Any + self, node: gtir.ScalarAccess, *, symtable: dict[str, Any], **kwargs: Any ) -> gtir.ScalarAccess: return gtir.ScalarAccess(name=node.name, dtype=symtable[node.name].dtype, loc=node.loc) diff --git a/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py b/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py 
index 40c31dca53..2381c22941 100644 --- a/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py +++ b/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py @@ -8,7 +8,7 @@ import math import typing -from typing import Any, Dict, Tuple, Union +from typing import Any from gt4py import eve from gt4py.cartesian.gtc import gtir @@ -16,7 +16,7 @@ def _iter_field_names( - node: Union[gtir.Stencil, gtir.ParAssignStmt], + node: gtir.Stencil | gtir.ParAssignStmt, ) -> eve.utils.XIterable[gtir.FieldAccess]: return node.walk_values().if_isinstance(gtir.FieldDecl).getattr("name").unique() @@ -24,7 +24,7 @@ def _iter_field_names( class KBoundaryVisitor(eve.NodeVisitor): """For every field compute the boundary in k, e.g. (2, -1) if [k_origin-2, k_origin+k_domain-1] is accessed.""" - def visit_Stencil(self, node: gtir.Stencil, **kwargs: Any) -> Dict[str, Tuple[int, int]]: + def visit_Stencil(self, node: gtir.Stencil, **kwargs: Any) -> dict[str, tuple[int, int]]: field_boundaries = {name: (-math.inf, -math.inf) for name in _iter_field_names(node)} for vloop in node.vertical_loops: self.generic_visit(vloop.body, vloop=vloop, field_boundaries=field_boundaries, **kwargs) @@ -34,13 +34,13 @@ def visit_Stencil(self, node: gtir.Stencil, **kwargs: Any) -> Dict[str, Tuple[in b[0] if b[0] != -math.inf else 0, b[1] if b[1] != -math.inf else 0, ) - return typing.cast(Dict[str, Tuple[int, int]], field_boundaries) + return typing.cast(dict[str, tuple[int, int]], field_boundaries) def visit_FieldAccess( self, node: gtir.FieldAccess, vloop: gtir.VerticalLoop, - field_boundaries: Dict[str, Tuple[Union[float, int], Union[float, int]]], + field_boundaries: dict[str, tuple[float | int, float | int]], **_: Any, ): boundary = field_boundaries[node.name] @@ -64,7 +64,7 @@ def visit_FieldAccess( field_boundaries[node.name] = boundary -def compute_k_boundary(node: gtir.Stencil) -> Dict[str, Tuple[int, int]]: +def compute_k_boundary(node: gtir.Stencil) -> dict[str, tuple[int, int]]: # loop from START to END is not 
considered as it might be empty. additional check possible in the future return KBoundaryVisitor().visit(node) diff --git a/src/gt4py/cartesian/gtc/passes/gtir_pipeline.py b/src/gt4py/cartesian/gtc/passes/gtir_pipeline.py index ee7974d895..4033e77728 100644 --- a/src/gt4py/cartesian/gtc/passes/gtir_pipeline.py +++ b/src/gt4py/cartesian/gtc/passes/gtir_pipeline.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Callable, ClassVar, Dict, Optional, Sequence, Tuple +from collections.abc import Callable, Sequence +from typing import ClassVar, Optional from gt4py.cartesian.definitions import StencilID from gt4py.cartesian.gtc import gtir @@ -29,7 +30,7 @@ class GtirPipeline: """ # Cache pipelines across all instances - _cache: ClassVar[Dict[Tuple[StencilID, Tuple[PASS_T, ...]], gtir.Stencil]] = {} + _cache: ClassVar[dict[tuple[StencilID, tuple[PASS_T, ...]], gtir.Stencil]] = {} def __init__(self, node: gtir.Stencil, stencil_id: StencilID): self.gtir = node diff --git a/src/gt4py/cartesian/gtc/passes/gtir_upcaster.py b/src/gt4py/cartesian/gtc/passes/gtir_upcaster.py index e97ef26396..fa77571bc3 100644 --- a/src/gt4py/cartesian/gtc/passes/gtir_upcaster.py +++ b/src/gt4py/cartesian/gtc/passes/gtir_upcaster.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import functools -from typing import Any, Callable, Dict, Iterator, List, TypeVar +from collections.abc import Callable, Iterator +from typing import Any, TypeVar import numpy as np @@ -24,7 +25,7 @@ def _upcast_node(target_dtype: DataType, node: Expr) -> Expr: def _upcast_nodes(*exprs: Expr, upcasting_rule: Callable) -> Iterator[Expr]: assert all(e.dtype for e in exprs) - dtypes: List[DataType] = [e.dtype for e in exprs] # guaranteed to be not None + dtypes: list[DataType] = [e.dtype for e in exprs] # guaranteed to be not None target_dtypes = upcasting_rule(*dtypes) return iter(_upcast_node(target_dtype, arg) for target_dtype, 
arg in zip(target_dtypes, exprs)) @@ -32,7 +33,7 @@ def _upcast_nodes(*exprs: Expr, upcasting_rule: Callable) -> Iterator[Expr]: _T = TypeVar("_T", bound=eve.Node) -def _update_node(node: _T, updated_children: Dict[str, eve.RootNode]) -> _T: +def _update_node(node: _T, updated_children: dict[str, eve.RootNode]) -> _T: # create new node only if children changed old_children = datamodels.asdict(node) if any([old_children[k] != updated_children[k] for k in updated_children.keys()]): diff --git a/src/gt4py/cartesian/gtc/passes/horizontal_masks.py b/src/gt4py/cartesian/gtc/passes/horizontal_masks.py index d1fc083be8..e20f030b86 100644 --- a/src/gt4py/cartesian/gtc/passes/horizontal_masks.py +++ b/src/gt4py/cartesian/gtc/passes/horizontal_masks.py @@ -6,15 +6,15 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Optional, Tuple +from typing import Optional from gt4py.cartesian.gtc import common from gt4py.cartesian.gtc.definitions import Extent def _overlap_along_axis( - extent: Tuple[int, int], interval: common.HorizontalInterval -) -> Optional[Tuple[int, int]]: + extent: tuple[int, int], interval: common.HorizontalInterval +) -> Optional[tuple[int, int]]: """Return a tuple of the distances to the edge of the compute domain, if overlapping.""" start_diff: Optional[int] end_diff: Optional[int] @@ -59,10 +59,10 @@ def mask_overlap_with_extent( def _compute_relative_interval( - extent: Tuple[int, int], interval: common.HorizontalInterval -) -> Optional[Tuple[common.AxisBound, common.AxisBound]]: + extent: tuple[int, int], interval: common.HorizontalInterval +) -> Optional[tuple[common.AxisBound, common.AxisBound]]: def _offset( - extent: Tuple[int, int], bound: Optional[common.AxisBound], start: bool = True + extent: tuple[int, int], bound: Optional[common.AxisBound], start: bool = True ) -> int: if bound: if start: @@ -98,7 +98,7 @@ def _offset( def compute_relative_mask( extent: Extent, mask: 
common.HorizontalMask ) -> Optional[ - Tuple[Tuple[common.AxisBound, common.AxisBound], Tuple[common.AxisBound, common.AxisBound]] + tuple[tuple[common.AxisBound, common.AxisBound], tuple[common.AxisBound, common.AxisBound]] ]: """ Output a HorizontalMask that is relative to and always inside the extent instead of the compute domain. diff --git a/src/gt4py/cartesian/gtc/passes/oir_access_kinds.py b/src/gt4py/cartesian/gtc/passes/oir_access_kinds.py index d5774106e6..0290e89d39 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_access_kinds.py +++ b/src/gt4py/cartesian/gtc/passes/oir_access_kinds.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import collections -from typing import Any, Dict +from typing import Any from gt4py import eve from gt4py.cartesian.definitions import AccessKind @@ -19,7 +19,7 @@ class AccessKindComputer(eve.NodeVisitor): def _visit_Access( - self, name, *, access: Dict[str, AccessKind], kind: AccessKind, **kwargs: Any + self, name, *, access: dict[str, AccessKind], kind: AccessKind, **kwargs: Any ) -> None: if kind == AccessKind.WRITE and access.get(name, None) == AccessKind.READ: access[name] = AccessKind.READ_WRITE @@ -56,12 +56,12 @@ def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs: Any) -> None: def visit_HorizontalExecution(self, node: oir.HorizontalExecution, **kwargs: Any) -> None: self.generic_visit(node, horizontal_extent=kwargs["block_extents"][id(node)], **kwargs) - def visit_Stencil(self, node: oir.Stencil) -> Dict[str, AccessKind]: - access: Dict[str, AccessKind] = collections.defaultdict(lambda: AccessKind.NONE) + def visit_Stencil(self, node: oir.Stencil) -> dict[str, AccessKind]: + access: dict[str, AccessKind] = collections.defaultdict(lambda: AccessKind.NONE) block_extents = compute_horizontal_block_extents(node) self.generic_visit(node, access=access, block_extents=block_extents) return access -def compute_access_kinds(stencil: oir.Stencil) -> Dict[str, AccessKind]: +def compute_access_kinds(stencil: 
oir.Stencil) -> dict[str, AccessKind]: return AccessKindComputer().visit(stencil) diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py index ecc424b09b..220b0cf2b1 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py @@ -7,8 +7,9 @@ # SPDX-License-Identifier: BSD-3-Clause import collections +from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterable, List, Set, Tuple +from typing import Any from gt4py import eve from gt4py.cartesian.gtc import common, oir @@ -43,7 +44,7 @@ class IJCacheDetection(eve.NodeTranslator): def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, local_tmps: Set[str], **kwargs: Any + self, node: oir.VerticalLoop, *, local_tmps: set[str], **kwargs: Any ) -> oir.VerticalLoop: if node.loop_order != common.LoopOrder.PARALLEL or not local_tmps: return node @@ -51,7 +52,7 @@ def visit_VerticalLoop( def already_cached(field: str) -> bool: return any(c.name == field for c in node.caches) - def has_vertical_offset(offsets: Set[Tuple[int, int, int]]) -> bool: + def has_vertical_offset(offsets: set[tuple[int, int, int]]) -> bool: return any(offset[2] != 0 for offset in offsets) accesses = AccessCollector.apply(node).cartesian_accesses().offsets() @@ -108,17 +109,17 @@ def visit_VerticalLoop(self, node: oir.VerticalLoop, **kwargs: Any) -> oir.Verti if any(off[2] is None for off in offsets) } - def accessed_more_than_once(offsets: Set[Any]) -> bool: + def accessed_more_than_once(offsets: set[Any]) -> bool: return len(offsets) > 1 def already_cached(field: str) -> bool: return field in {c.name for c in node.caches} # TODO(fthaler): k-caches with non-zero ij offsets? 
- def has_horizontal_offset(offsets: Set[Tuple[int, int, int]]) -> bool: + def has_horizontal_offset(offsets: set[tuple[int, int, int]]) -> bool: return any(offset[:2] != (0, 0) for offset in offsets) - def offsets_within_limits(offsets: Set[Tuple[int, int, int]]) -> bool: + def offsets_within_limits(offsets: set[tuple[int, int, int]]) -> bool: return all(abs(offset[2]) <= self.max_cacheable_offset for offset in offsets) def has_variable_offset_reads(field: str) -> bool: @@ -151,7 +152,7 @@ class PruneKCacheFills(eve.NodeTranslator): If none of the conditions holds for any loop section, the fill is considered as unneeded. """ - def visit_KCache(self, node: oir.KCache, *, pruneable: Set[str], **kwargs: Any) -> oir.KCache: + def visit_KCache(self, node: oir.KCache, *, pruneable: set[str], **kwargs: Any) -> oir.KCache: if node.name in pruneable: return oir.KCache(name=node.name, fill=False, flush=node.flush) return self.generic_visit(node, **kwargs) @@ -162,7 +163,7 @@ def visit_VerticalLoop(self, node: oir.VerticalLoop, **kwargs: Any) -> oir.Verti return self.generic_visit(node, **kwargs) assert node.loop_order != common.LoopOrder.PARALLEL - def pruneable_fields(section: oir.VerticalLoopSection) -> Set[str]: + def pruneable_fields(section: oir.VerticalLoopSection) -> set[str]: accesses = AccessCollector.apply(section).cartesian_accesses() offsets = accesses.offsets() center_accesses = [a for a in accesses.ordered_accesses() if a.offset == (0, 0, 0)] @@ -222,7 +223,7 @@ class PruneKCacheFlushes(eve.NodeTranslator): * There are no read accesses to the field in a following loop. 
""" - def visit_KCache(self, node: oir.KCache, *, pruneable: Set[str], **kwargs: Any) -> oir.KCache: + def visit_KCache(self, node: oir.KCache, *, pruneable: set[str], **kwargs: Any) -> oir.KCache: if node.name in pruneable: return oir.KCache(name=node.name, fill=node.fill, flush=False, loc=node.loc) return self.generic_visit(node, **kwargs) @@ -237,7 +238,7 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil: read_only_fields = flushing_fields & ( accesses[i].read_fields() - accesses[i].write_fields() ) - future_reads: Set[str] = set() + future_reads: set[str] = set() future_reads = future_reads.union(*(acc.read_fields() for acc in accesses[i + 1 :])) tmps_without_reuse = ( flushing_fields & {str(d.name) for d in node.declarations} @@ -265,7 +266,7 @@ class FillFlushToLocalKCaches(eve.NodeTranslator, eve.VisitorWithSymbolTableTrai """ def visit_FieldAccess( - self, node: oir.FieldAccess, *, name_map: Dict[str, str], **kwargs: Any + self, node: oir.FieldAccess, *, name_map: dict[str, str], **kwargs: Any ) -> oir.FieldAccess: if node.name in name_map: return oir.FieldAccess( @@ -288,9 +289,9 @@ def visit_HorizontalExecution( self, node: oir.HorizontalExecution, *, - name_map: Dict[str, str], - fills: List[oir.Stmt], - flushes: List[oir.Stmt], + name_map: dict[str, str], + fills: list[oir.Stmt], + flushes: list[oir.Stmt], **kwargs: Any, ) -> oir.HorizontalExecution: return oir.HorizontalExecution( @@ -302,7 +303,7 @@ def visit_HorizontalExecution( @staticmethod def _fill_limits( loop_order: common.LoopOrder, section: oir.VerticalLoopSection - ) -> Dict[str, Tuple[int, int]]: + ) -> dict[str, tuple[int, int]]: """Direction-normalized min and max read accesses for each accessed field. Args: @@ -313,7 +314,7 @@ def _fill_limits( A dict, mapping field names to min and max read offsets relative to loop order (i.e., positive means in the direction of the loop order). 
""" - def directional_k_offset(offset: Tuple[int, int, int]) -> int: + def directional_k_offset(offset: tuple[int, int, int]) -> int: """Positive k-offset for forward loops, negative for backward.""" return offset[2] if loop_order == common.LoopOrder.FORWARD else -offset[2] @@ -338,7 +339,7 @@ def _split_entry_level( loop_order: common.LoopOrder, section: oir.VerticalLoopSection, new_symbol_name: Callable[[str], str], - ) -> Tuple[oir.VerticalLoopSection, oir.VerticalLoopSection]: + ) -> tuple[oir.VerticalLoopSection, oir.VerticalLoopSection]: """Split the entry level of a loop section. Args: @@ -392,9 +393,9 @@ def _split_section_with_multiple_fills( loop_order: common.LoopOrder, section: oir.VerticalLoopSection, filling_fields: Iterable[str], - first_unfilled: Dict[str, int], + first_unfilled: dict[str, int], new_symbol_name: Callable[[str], str], - ) -> Tuple[Tuple[oir.VerticalLoopSection, ...], Dict[str, int]]: + ) -> tuple[tuple[oir.VerticalLoopSection, ...], dict[str, int]]: """Split loop sections that require multiple fills. Args: @@ -422,10 +423,10 @@ def _fill_stmts( cls, loop_order: common.LoopOrder, section: oir.VerticalLoopSection, - filling_fields: Dict[str, str], - first_unfilled: Dict[str, int], - symtable: Dict[str, Any], - ) -> Tuple[List[oir.AssignStmt], Dict[str, int]]: + filling_fields: dict[str, str], + first_unfilled: dict[str, int], + symtable: dict[str, Any], + ) -> tuple[list[oir.AssignStmt], dict[str, int]]: """Generate fill statements for the given loop section. Args: @@ -464,9 +465,9 @@ def _flush_stmts( cls, loop_order: common.LoopOrder, section: oir.VerticalLoopSection, - flushing_fields: Dict[str, str], - symtable: Dict[str, Any], - ) -> List[oir.AssignStmt]: + flushing_fields: dict[str, str], + symtable: dict[str, Any], + ) -> list[oir.AssignStmt]: """Generate flush statements for the given loop section. 
Args: @@ -501,17 +502,17 @@ def visit_VerticalLoop( self, node: oir.VerticalLoop, *, - new_tmps: List[oir.Temporary], - symtable: Dict[str, Any], + new_tmps: list[oir.Temporary], + symtable: dict[str, Any], new_symbol_name: Callable[[str], str], **kwargs: Any, ) -> oir.VerticalLoop: - filling_fields: Dict[str, str] = { + filling_fields: dict[str, str] = { c.name: new_symbol_name(c.name) for c in node.caches if isinstance(c, oir.KCache) and c.fill } - flushing_fields: Dict[str, str] = { + flushing_fields: dict[str, str] = { c.name: filling_fields[c.name] if c.name in filling_fields else new_symbol_name(c.name) for c in node.caches if isinstance(c, oir.KCache) and c.flush @@ -534,8 +535,8 @@ def visit_VerticalLoop( if filling_fields: # split sections where more than one fill operations are required at the entry level - first_unfilled: Dict[str, int] = dict() - split_sections: List[oir.VerticalLoopSection] = [] + first_unfilled: dict[str, int] = dict() + split_sections: list[oir.VerticalLoopSection] = [] for section in node.sections: split_section, _previous_fills = self._split_section_with_multiple_fills( node.loop_order, section, filling_fields, first_unfilled, new_symbol_name @@ -573,7 +574,7 @@ def visit_VerticalLoop( ) def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil: - new_tmps: List[oir.Temporary] = [] + new_tmps: list[oir.Temporary] = [] return oir.Stencil( name=node.name, params=node.params, diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/horizontal_execution_merging.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/horizontal_execution_merging.py index 6eb87a5c56..8c00940d45 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/horizontal_execution_merging.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/horizontal_execution_merging.py @@ -6,8 +6,9 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Optional from gt4py import eve from gt4py.cartesian.gtc import common, oir @@ -34,7 +35,7 @@ def visit_VerticalLoopSection( self, node: oir.VerticalLoopSection, *, - block_extents: Dict[int, Extent], + block_extents: dict[int, Extent], new_symbol_name: Callable[[str], str], **kwargs: Any, ) -> oir.VerticalLoopSection: @@ -42,8 +43,8 @@ def visit_VerticalLoopSection( class UncheckedHorizontalExecution: # local replacement without type checking for type-checked oir node # required to reach reasonable run times for large node counts - body: List[oir.Stmt] - declarations: List[oir.LocalScalar] + body: list[oir.Stmt] + declarations: list[oir.LocalScalar] loc: Optional[eve.SourceLocation] assert set(oir.HorizontalExecution.__datamodel_fields__.keys()) == { @@ -122,7 +123,7 @@ def to_oir(self) -> oir.HorizontalExecution: ) def visit_ScalarAccess( - self, node: oir.ScalarAccess, *, scalar_map: Dict[str, str], **kwargs: Any + self, node: oir.ScalarAccess, *, scalar_map: dict[str, str], **kwargs: Any ) -> oir.ScalarAccess: return oir.ScalarAccess( name=scalar_map[node.name] if node.name in scalar_map else node.name, @@ -147,7 +148,7 @@ def visit_CartesianOffset( self, node: common.CartesianOffset, *, - shift: Optional[Tuple[int, int, int]] = None, + shift: Optional[tuple[int, int, int]] = None, **kwargs: Any, ) -> common.CartesianOffset: if shift: @@ -159,9 +160,9 @@ def visit_FieldAccess( self, node: oir.FieldAccess, *, - offset_symbol_map: Optional[Dict[Tuple[str, Tuple[int, int, int]], str]] = None, + offset_symbol_map: Optional[dict[tuple[str, tuple[int, int, int]], str]] = None, **kwargs: Any, - ) -> Union[oir.FieldAccess, oir.ScalarAccess]: + ) -> oir.FieldAccess | oir.ScalarAccess: if offset_symbol_map: offset = self.visit(node.offset, **kwargs) key = 
node.name, (offset.i, offset.j, offset.k) @@ -170,7 +171,7 @@ def visit_FieldAccess( return self.generic_visit(node, **kwargs) def visit_ScalarAccess( - self, node: oir.ScalarAccess, *, scalar_map: Dict[str, str], **kwargs: Any + self, node: oir.ScalarAccess, *, scalar_map: dict[str, str], **kwargs: Any ) -> oir.ScalarAccess: return oir.ScalarAccess( name=scalar_map[node.name] if node.name in scalar_map else node.name, @@ -180,11 +181,11 @@ def visit_ScalarAccess( def _merge( self, - horizontal_executions: List[oir.HorizontalExecution], - symtable: Dict[str, Any], + horizontal_executions: list[oir.HorizontalExecution], + symtable: dict[str, Any], new_symbol_name: Callable[[str], str], - protected_fields: Set[str], - ) -> List[oir.HorizontalExecution]: + protected_fields: set[str], + ) -> list[oir.HorizontalExecution]: """Recursively merge horizontal executions. Uses the following algorithm: @@ -255,7 +256,7 @@ def first_has_horizontal_restriction() -> bool: writes = first_accesses.write_fields() others_otf = [] for horizontal_execution in others: - read_offsets: Set[Tuple[int, int, int]] = set() + read_offsets: set[tuple[int, int, int]] = set() read_offsets = read_offsets.union( *( offsets diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/inlining.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/inlining.py index 0233e80af1..7715fdcb7a 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/inlining.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/inlining.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import copy as cp -from typing import Any, Dict, Optional, Set, Union, cast +from typing import Any, Optional, cast from gt4py import eve from gt4py.cartesian.gtc import oir @@ -17,18 +17,18 @@ class MaskCollector(eve.NodeVisitor): """Collects the boolean expressions defining mask statements that are boolean fields.""" def visit_AssignStmt( - self, node: oir.AssignStmt, *, masks_to_inline: Dict[str, oir.Expr] + self, node: 
oir.AssignStmt, *, masks_to_inline: dict[str, oir.Expr] ) -> None: if node.left.name in masks_to_inline: assert masks_to_inline[node.left.name] is None masks_to_inline[node.left.name] = node.right def visit_MaskStmt( - self, node: oir.MaskStmt, *, masks_to_inline: Dict[str, oir.Expr], **kwargs: Any + self, node: oir.MaskStmt, *, masks_to_inline: dict[str, oir.Expr], **kwargs: Any ) -> None: if isinstance(node.mask, oir.FieldAccess) and node.mask.name in masks_to_inline: # Find all reads in condition - condition_reads: Set[str] = ( + condition_reads: set[str] = ( masks_to_inline[node.mask.name] .walk_values() .if_isinstance(oir.FieldAccess, oir.ScalarAccess) @@ -36,15 +36,15 @@ def visit_MaskStmt( .to_set() ) # Find all writes in body - body_writes: Set[str] = { + body_writes: set[str] = { child.left.name for child in node.body if isinstance(child, oir.AssignStmt) } # Do not inline the mask if there is an intersection if condition_reads.intersection(body_writes): masks_to_inline.pop(node.mask.name) - def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> Dict[str, oir.Expr]: - masks_to_inline: Dict[str, Optional[oir.Expr]] = { + def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> dict[str, oir.Expr]: + masks_to_inline: dict[str, Optional[oir.Expr]] = { mask_stmt.mask.name: None for mask_stmt in node.walk_values() .if_isinstance(oir.MaskStmt) @@ -52,7 +52,7 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> Dict[str, oir.Expr] } self.visit(node.vertical_loops, masks_to_inline=masks_to_inline, **kwargs) assert all(value is not None for value in masks_to_inline.values()) - return cast(Dict[str, oir.Expr], masks_to_inline) + return cast(dict[str, oir.Expr], masks_to_inline) class MaskInlining(eve.NodeTranslator): @@ -63,15 +63,15 @@ class MaskInlining(eve.NodeTranslator): """ def visit_FieldAccess( - self, node: oir.FieldAccess, *, masks_to_inline: Dict[str, oir.Expr], **kwargs: Any + self, node: oir.FieldAccess, *, masks_to_inline: dict[str, 
oir.Expr], **kwargs: Any ) -> oir.Expr: if node.name in masks_to_inline: return cp.copy(masks_to_inline[node.name]) return self.generic_visit(node, masks_to_inline=masks_to_inline, **kwargs) def visit_AssignStmt( - self, node: oir.AssignStmt, *, masks_to_inline: Dict[str, oir.Expr], **kwargs: Any - ) -> Union[oir.AssignStmt, eve.NothingType]: + self, node: oir.AssignStmt, *, masks_to_inline: dict[str, oir.Expr], **kwargs: Any + ) -> oir.AssignStmt | eve.NothingType: if node.left.name in masks_to_inline: return eve.NOTHING return self.generic_visit(node, masks_to_inline=masks_to_inline, **kwargs) diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/mask_stmt_merging.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/mask_stmt_merging.py index 465a63a2a8..297e059d5c 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/mask_stmt_merging.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/mask_stmt_merging.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import List from gt4py import eve from gt4py.cartesian.gtc import oir @@ -14,7 +13,7 @@ class MaskStmtMerging(eve.NodeTranslator): - def _merge(self, stmts: List[oir.Stmt]) -> List[oir.Stmt]: + def _merge(self, stmts: list[oir.Stmt]) -> list[oir.Stmt]: merged = [self.visit(stmts[0])] for stmt in stmts[1:]: stmt = self.visit(stmt) diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/pruning.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/pruning.py index 3b93ce4669..f7526a5607 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/pruning.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/pruning.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict +from typing import Any from gt4py import eve from gt4py.cartesian.gtc import oir @@ -67,7 +67,7 @@ def visit_Stencil(self, node: oir.Stencil) -> oir.Stencil: return self.generic_visit(node, block_extents=block_extents) def visit_HorizontalExecution( - self, node: oir.HorizontalExecution, *, block_extents: Dict[int, Extent] + self, node: oir.HorizontalExecution, *, block_extents: dict[int, Extent] ) -> oir.HorizontalExecution: return self.generic_visit(node, block_extent=block_extents[id(node)]) diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py index d4dfd6a118..970b9a5665 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import collections -from typing import Any, Callable, Dict, Set, Union +from collections.abc import Callable +from typing import Any from gt4py import eve from gt4py.cartesian.gtc import oir @@ -20,8 +21,8 @@ class TemporariesToScalarsBase(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): def visit_FieldAccess( - self, node: oir.FieldAccess, *, tmps_name_map: Dict[str, str], **kwargs: Any - ) -> Union[oir.FieldAccess, oir.ScalarAccess]: + self, node: oir.FieldAccess, *, tmps_name_map: dict[str, str], **kwargs: Any + ) -> oir.FieldAccess | oir.ScalarAccess: offsets = node.offset.to_dict() if node.name in tmps_name_map: assert offsets["i"] == offsets["j"] == offsets["k"] == 0, ( @@ -33,8 +34,8 @@ def visit_FieldAccess( def visit_HorizontalExecution( self, node: oir.HorizontalExecution, - tmps_to_replace: Set[str], - symtable: Dict[str, Any], + tmps_to_replace: set[str], + symtable: dict[str, Any], new_symbol_name: Callable[[str], str], **kwargs: Any, ) -> oir.HorizontalExecution: @@ -60,7 +61,7 @@ def visit_HorizontalExecution( ) def 
visit_VerticalLoop( - self, node: oir.VerticalLoop, tmps_to_replace: Set[str], **kwargs: Any + self, node: oir.VerticalLoop, tmps_to_replace: set[str], **kwargs: Any ) -> oir.VerticalLoop: return oir.VerticalLoop( loop_order=node.loop_order, diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py index 0c78c00c17..73da3dc0c7 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py @@ -10,8 +10,9 @@ import dataclasses import re +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, TypeVar, cast +from typing import Any, Generic, Optional, TypeVar, cast from gt4py import eve from gt4py.cartesian.gtc import common, oir @@ -21,7 +22,7 @@ OffsetT = TypeVar("OffsetT") -GeneralOffsetTuple = Tuple[int, int, Optional[int]] +GeneralOffsetTuple = tuple[int, int, Optional[int]] _digits_at_end_pattern = re.compile(r"[0-9]+$") _generated_name_pattern = re.compile(r".+_gen_[0-9]+") @@ -32,7 +33,7 @@ class GenericAccess(Generic[OffsetT]): field: str offset: OffsetT is_write: bool - data_index: List[oir.Expr] = dataclasses.field(default_factory=list) + data_index: list[oir.Expr] = dataclasses.field(default_factory=list) horizontal_mask: Optional[common.HorizontalMask] = None @property @@ -53,10 +54,10 @@ def to_extent( """ if centered: offset_as_extent = CenteredExtent.from_offset( - cast(Tuple[int, int, int], self.offset)[:2] + cast(tuple[int, int, int], self.offset)[:2] ) else: - offset_as_extent = Extent.from_offset(cast(Tuple[int, int, int], self.offset)[:2]) + offset_as_extent = Extent.from_offset(cast(tuple[int, int, int], self.offset)[:2]) zeros = Extent.zeros(ndims=2) if self.horizontal_mask and not ignore_horizontal_mask: if dist_from_edge := mask_overlap_with_extent(self.horizontal_mask, horizontal_extent): @@ -66,7 +67,7 @@ def 
to_extent( return horizontal_extent + offset_as_extent -class CartesianAccess(GenericAccess[Tuple[int, int, int]]): +class CartesianAccess(GenericAccess[tuple[int, int, int]]): pass @@ -84,7 +85,7 @@ def visit_FieldAccess( self, node: oir.FieldAccess, *, - accesses: List[GeneralAccess], + accesses: list[GeneralAccess], is_write: bool, horizontal_mask: Optional[common.HorizontalMask] = None, **kwargs: Any, @@ -118,55 +119,55 @@ def visit_HorizontalRestriction(self, node: oir.HorizontalRestriction, **kwargs: @dataclass class GenericAccessCollection(Generic[AccessT, OffsetT]): - _ordered_accesses: List[AccessT] + _ordered_accesses: list[AccessT] @staticmethod - def _offset_dict(accesses: eve.utils.XIterable) -> Dict[str, Set[OffsetT]]: + def _offset_dict(accesses: eve.utils.XIterable) -> dict[str, set[OffsetT]]: return accesses.reduceby( lambda acc, x: acc | {x.offset}, "field", init=set(), as_dict=True ) - def offsets(self) -> Dict[str, Set[OffsetT]]: + def offsets(self) -> dict[str, set[OffsetT]]: """Get a dictionary, mapping all accessed fields' names to sets of offset tuples.""" return self._offset_dict(eve.utils.XIterable(self._ordered_accesses)) - def read_offsets(self) -> Dict[str, Set[OffsetT]]: + def read_offsets(self) -> dict[str, set[OffsetT]]: """Get a dictionary, mapping read fields' names to sets of offset tuples.""" return self._offset_dict( eve.utils.XIterable(self._ordered_accesses).filter(lambda x: x.is_read) ) - def read_accesses(self) -> List[AccessT]: + def read_accesses(self) -> list[AccessT]: """Get the sub-list of read accesses.""" return list(eve.utils.XIterable(self._ordered_accesses).filter(lambda x: x.is_read)) - def write_offsets(self) -> Dict[str, Set[OffsetT]]: + def write_offsets(self) -> dict[str, set[OffsetT]]: """Get a dictionary, mapping written fields' names to sets of offset tuples.""" return self._offset_dict( eve.utils.XIterable(self._ordered_accesses).filter(lambda x: x.is_write) ) - def write_accesses(self) -> List[AccessT]: + 
def write_accesses(self) -> list[AccessT]: """Get the sub-list of write accesses.""" return list(eve.utils.XIterable(self._ordered_accesses).filter(lambda x: x.is_write)) - def fields(self) -> Set[str]: + def fields(self) -> set[str]: """Get a set of all accessed fields' names.""" return {acc.field for acc in self._ordered_accesses} - def read_fields(self) -> Set[str]: + def read_fields(self) -> set[str]: """Get a set of all read fields' names.""" return {acc.field for acc in self._ordered_accesses if acc.is_read} - def write_fields(self) -> Set[str]: + def write_fields(self) -> set[str]: """Get a set of all written fields' names.""" return {acc.field for acc in self._ordered_accesses if acc.is_write} - def ordered_accesses(self) -> List[AccessT]: + def ordered_accesses(self) -> list[AccessT]: """Get a list of ordered accesses.""" return self._ordered_accesses - class CartesianAccessCollection(GenericAccessCollection[CartesianAccess, Tuple[int, int, int]]): + class CartesianAccessCollection(GenericAccessCollection[CartesianAccess, tuple[int, int, int]]): pass class GeneralAccessCollection(GenericAccessCollection[GeneralAccess, GeneralOffsetTuple]): @@ -175,7 +176,7 @@ def cartesian_accesses(self) -> AccessCollector.CartesianAccessCollection: [ CartesianAccess( field=acc.field, - offset=cast(Tuple[int, int, int], acc.offset), + offset=cast(tuple[int, int, int], acc.offset), data_index=acc.data_index, is_write=acc.is_write, ) @@ -194,7 +195,7 @@ def apply(cls, node: eve.RootNode, **kwargs: Any) -> AccessCollector.GeneralAcce return result -def symbol_name_creator(used_names: Set[str]) -> Callable[[str], str]: +def symbol_name_creator(used_names: set[str]) -> Callable[[str], str]: """Create a function that generates symbol names that are not already in use. 
Args: @@ -219,7 +220,7 @@ def new_symbol_name(name: str) -> str: return new_symbol_name -def collect_symbol_names(node: eve.RootNode) -> Set[str]: +def collect_symbol_names(node: eve.RootNode) -> set[str]: return ( eve.walk_values(node) .if_isinstance(eve.SymbolTableTrait) @@ -241,8 +242,8 @@ class StencilExtentComputer(eve.NodeVisitor): @dataclass class Context: - fields: Dict[str, Extent] = dataclasses.field(default_factory=dict) - blocks: Dict[int, Extent] = dataclasses.field(default_factory=dict) + fields: dict[str, Extent] = dataclasses.field(default_factory=dict) + blocks: dict[int, Extent] = dataclasses.field(default_factory=dict) def __init__( self, @@ -295,18 +296,18 @@ def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *, ctx: Conte ctx.fields[access.field] = extent -def compute_horizontal_block_extents(node: oir.Stencil, **kwargs: Any) -> Dict[int, Extent]: +def compute_horizontal_block_extents(node: oir.Stencil, **kwargs: Any) -> dict[int, Extent]: ctx = StencilExtentComputer(**kwargs).visit(node) return ctx.blocks -def compute_fields_extents(node: oir.Stencil, **kwargs: Any) -> Dict[str, Extent]: +def compute_fields_extents(node: oir.Stencil, **kwargs: Any) -> dict[str, Extent]: ctx = StencilExtentComputer(**kwargs).visit(node) return ctx.fields def compute_extents( node: oir.Stencil, **kwargs: Any -) -> Tuple[Dict[str, Extent], Dict[int, Extent]]: +) -> tuple[dict[str, Extent], dict[int, Extent]]: ctx = StencilExtentComputer(**kwargs).visit(node) return ctx.fields, ctx.blocks diff --git a/src/gt4py/cartesian/gtc/passes/oir_pipeline.py b/src/gt4py/cartesian/gtc/passes/oir_pipeline.py index de5488881c..4a06148098 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_pipeline.py +++ b/src/gt4py/cartesian/gtc/passes/oir_pipeline.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause from abc import abstractmethod -from typing import Callable, Optional, Protocol, Sequence, Type, Union +from collections.abc import Callable, Sequence +from 
typing import Optional, Protocol, Union from gt4py import eve from gt4py.cartesian.gtc import oir @@ -34,7 +35,7 @@ from gt4py.cartesian.gtc.passes.oir_optimizations.vertical_loop_merging import AdjacentLoopMerging -PassT = Union[Callable[[oir.Stencil], oir.Stencil], Type[eve.NodeVisitor]] +PassT = Union[Callable[[oir.Stencil], oir.Stencil], type[eve.NodeVisitor]] class OirPipeline(Protocol): diff --git a/src/gt4py/cartesian/gtc/utils.py b/src/gt4py/cartesian/gtc/utils.py index 2b77afdb39..78c7fc2bfd 100644 --- a/src/gt4py/cartesian/gtc/utils.py +++ b/src/gt4py/cartesian/gtc/utils.py @@ -6,16 +6,17 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Sequence, Tuple +from collections.abc import Sequence +from typing import Any -def dimension_flags_to_names(mask: Tuple[bool, bool, bool]) -> str: +def dimension_flags_to_names(mask: tuple[bool, bool, bool]) -> str: labels = ["i", "j", "k"] selection = [i for i, flag in enumerate(mask) if flag] return "".join(labels[i] for i in selection) -def interpolate_mask(seq: Sequence[Any], mask: Sequence[bool], default) -> Tuple[Any, ...]: +def interpolate_mask(seq: Sequence[Any], mask: Sequence[bool], default) -> tuple[Any, ...]: """ Replace True values by those from the seq in the mask, else default. @@ -30,7 +31,7 @@ def interpolate_mask(seq: Sequence[Any], mask: Sequence[bool], default) -> Tuple return tuple(next(it) if m else default for m in mask) -def filter_mask(seq: Sequence[Any], mask: Sequence[bool]) -> Tuple[Any, ...]: +def filter_mask(seq: Sequence[Any], mask: Sequence[bool]) -> tuple[Any, ...]: """ Return a reduced-size tuple, with indices where mask[i]=False removed. 
diff --git a/src/gt4py/cartesian/gtscript.py b/src/gt4py/cartesian/gtscript.py index 27241e8641..a5dc801a84 100644 --- a/src/gt4py/cartesian/gtscript.py +++ b/src/gt4py/cartesian/gtscript.py @@ -16,7 +16,8 @@ import inspect import numbers import types -from typing import Callable, Dict, Type, Union +from collections.abc import Callable +from typing import Union import numpy as np @@ -126,7 +127,7 @@ ) -def _set_arg_dtypes(definition: Callable[..., None], dtypes: Dict[Type, Type]): +def _set_arg_dtypes(definition: Callable[..., None], dtypes: dict[type, type]): def _parse_annotation(arg, annotation): # This function evaluates the type hint 'annotation' for the stencil argument 'arg'. # Note that 'typing.get_type_hints()' cannot be used here since field diff --git a/src/gt4py/cartesian/gtscript_imports.py b/src/gt4py/cartesian/gtscript_imports.py index 6fe49f18dd..bb5b66dbfb 100644 --- a/src/gt4py/cartesian/gtscript_imports.py +++ b/src/gt4py/cartesian/gtscript_imports.py @@ -36,9 +36,10 @@ import pathlib import sys import tempfile +from collections.abc import Generator, Iterator from contextlib import contextmanager from types import ModuleType -from typing import Any, Generator, Iterator, List, Optional, Union +from typing import Any, Optional GTS_EXTENSIONS = [".gt.py"] @@ -76,8 +77,8 @@ class GtsFinder(importlib.abc.MetaPathFinder): def __init__( self, - search_path: Optional[List[Union[str, pathlib.Path]]] = None, - generate_path: Optional[Union[str, pathlib.Path]] = None, + search_path: Optional[list[str | pathlib.Path]] = None, + generate_path: Optional[str | pathlib.Path] = None, in_source: bool = False, ): if in_source: @@ -94,7 +95,7 @@ def get_generate_path(self, src_file_path: pathlib.Path) -> pathlib.Path: return self.generate_path or src_file_path.parent def iter_search_candidates( - self, fullname: str, path: Optional[List[pathlib.Path]] + self, fullname: str, path: Optional[list[pathlib.Path]] ) -> Generator[pathlib.Path, None, None]: """Iterate 
possible source file paths.""" search_paths = [p for p in self.search_path or sys.path] @@ -196,8 +197,8 @@ def create_module(self, spec: importlib.machinery.ModuleSpec) -> ModuleType: def enable( *, - search_path: Optional[List[Union[str, pathlib.Path]]] = None, - generate_path: Optional[Union[str, pathlib.Path]] = None, + search_path: Optional[list[str | pathlib.Path]] = None, + generate_path: Optional[str | pathlib.Path] = None, in_source: bool = False, ) -> GtsFinder: """ diff --git a/src/gt4py/cartesian/lazy_stencil.py b/src/gt4py/cartesian/lazy_stencil.py index a2209e0b7f..c605aa0789 100644 --- a/src/gt4py/cartesian/lazy_stencil.py +++ b/src/gt4py/cartesian/lazy_stencil.py @@ -10,7 +10,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any from cached_property import cached_property @@ -55,7 +55,7 @@ def backend(self) -> Backend: return self.builder.backend @property - def field_info(self) -> Dict[str, Any]: + def field_info(self) -> dict[str, Any]: """ Access the compiled stencil object's `field_info` attribute. 
diff --git a/src/gt4py/cartesian/loader.py b/src/gt4py/cartesian/loader.py index 8b597759c5..cfe5923ac8 100644 --- a/src/gt4py/cartesian/loader.py +++ b/src/gt4py/cartesian/loader.py @@ -58,7 +58,7 @@ def gtscript_loader( dtypes: dict[type, type], ) -> StencilObject: if not isinstance(definition_func, types.FunctionType): - raise ValueError("Invalid stencil definition object ({obj})".format(obj=definition_func)) + raise ValueError(f"Invalid stencil definition object ({definition_func})") if not build_options.name: build_options.name = f"{definition_func.__name__}" diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index 3e988149bc..f1a87b7e52 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -12,10 +12,11 @@ import collections.abc import sys import time +from collections.abc import Callable from dataclasses import dataclass from numbers import Number from pickle import dumps -from typing import Any, Callable, ClassVar, Literal, Union, cast +from typing import Any, ClassVar, Literal, Union, cast import numpy as np @@ -293,7 +294,7 @@ def _make_origin_dict( except Exception: pass - raise ValueError("Invalid 'origin' value ({})".format(origin)) + raise ValueError(f"Invalid 'origin' value ({origin})") @staticmethod def _get_max_domain( @@ -370,7 +371,7 @@ def _validate_args( # Function is too complex try: domain = Shape(domain) except Exception as ex: - raise ValueError("Invalid 'domain' value ({})".format(domain)) from ex + raise ValueError(f"Invalid 'domain' value ({domain})") from ex if not domain > Shape.zeros(domain_ndim): raise ValueError(f"Compute domain contains zero sizes '{domain}')") diff --git a/src/gt4py/cartesian/testing/input_strategies.py b/src/gt4py/cartesian/testing/input_strategies.py index 486f0935d9..abf844b217 100644 --- a/src/gt4py/cartesian/testing/input_strategies.py +++ b/src/gt4py/cartesian/testing/input_strategies.py @@ -11,7 +11,8 @@ import enum import 
itertools import numbers -from typing import Any, Callable, Optional, Sequence, Tuple +from collections.abc import Callable, Sequence +from typing import Any, Optional import hypothesis.strategies as hyp_st import numpy as np @@ -31,17 +32,17 @@ class SymbolKind(enum.Enum): @dataclasses.dataclass(frozen=True) class _SymbolStrategy: kind: SymbolKind - boundary: Optional[Sequence[Tuple[int, int]]] + boundary: Optional[Sequence[tuple[int, int]]] axes: Optional[str] - data_dims: Optional[Tuple[int, ...]] + data_dims: Optional[tuple[int, ...]] value_st_factory: Callable[..., hyp_st.SearchStrategy] @dataclasses.dataclass(frozen=True) class _SymbolValueTuple: kind: str - boundary: Sequence[Tuple[int, int]] - values: Tuple[Any] + boundary: Sequence[tuple[int, int]] + values: tuple[Any] def global_name(*, singleton=None, symbol=None, one_of=None, in_range=None): diff --git a/src/gt4py/cartesian/testing/suites.py b/src/gt4py/cartesian/testing/suites.py index 61c97b166d..2ef1a73415 100644 --- a/src/gt4py/cartesian/testing/suites.py +++ b/src/gt4py/cartesian/testing/suites.py @@ -208,9 +208,7 @@ def hyp_wrapper(test_hyp, hypothesis_data): if test["suite"] == cls_name: name = test["backend"] name += "".join(f"_{key}_{value}" for key, value in test["constants"].items()) - name += "".join( - "_{}_{}".format(key, value.name) for key, value in test["dtypes"].items() - ) + name += "".join(f"_{key}_{value.name}" for key, value in test["dtypes"].items()) marks = test["marks"].copy() if gt_backend.from_name(test["backend"]).storage_info["device"] == "gpu": @@ -242,9 +240,7 @@ def hyp_wrapper(test_hyp, hypothesis_data): if test["suite"] == cls_name: name = test["backend"] name += "".join(f"_{key}_{value}" for key, value in test["constants"].items()) - name += "".join( - "_{}_{}".format(key, value.name) for key, value in test["dtypes"].items() - ) + name += "".join(f"_{key}_{value.name}" for key, value in test["dtypes"].items()) marks = test["marks"].copy() if 
gt_backend.from_name(test["backend"]).storage_info["device"] == "gpu": @@ -273,9 +269,7 @@ def _validate_new_args(cls, cls_name, cls_dict): missing_members = cls.required_members - cls_dict.keys() if len(missing_members) > 0: raise TypeError( - "Missing {missing} required members in '{name}' definition".format( - missing=missing_members, name=cls_name - ) + f"Missing {missing_members} required members in '{cls_name}' definition" ) # Check class dict domain_range = cls_dict["domain_range"] @@ -309,7 +303,7 @@ def _validate_new_args(cls, cls_name, cls_dict): backends = [pytest.param(b) if isinstance(b, str) else b for b in backends] for b in backends: if b.values[0] not in gt_backend.REGISTRY.names: - raise ValueError("backend '{backend}' not supported".format(backend=b)) + raise ValueError(f"backend '{b}' not supported") # Check definition and validation functions if not isinstance(cls_dict["definition"], types.FunctionType): @@ -602,7 +596,7 @@ def _run_test_implementation(cls, parameters_dict, implementation): # too compl rtol=RTOL, atol=ATOL, equal_nan=EQUAL_NAN, - err_msg="Wrong data in output field '{name}'".format(name=name), + err_msg=f"Wrong data in output field '{name}'", ) @classmethod diff --git a/src/gt4py/cartesian/utils/attrib.py b/src/gt4py/cartesian/utils/attrib.py index cbcec19de6..26ff80a15d 100644 --- a/src/gt4py/cartesian/utils/attrib.py +++ b/src/gt4py/cartesian/utils/attrib.py @@ -28,7 +28,7 @@ def __repr__(self): elif isinstance(a, _TypeDescriptor): arg_names.append(repr(a)) args = "[{}]".format(", ".join(arg_names)) if len(arg_names) > 0 else "" - return "{}{}".format(self.name, args) + return f"{self.name}{args}" @property def validator(self): @@ -92,9 +92,7 @@ def _is_sequence_of_validator(instance, attribute, value): assert isinstance([item_validator(instance, attribute, v) for v in value], list) except Exception as ex: raise ValueError( - "Expr ({value}) does not match the '{name}' specification".format( - value=value, name=attribute.name 
- ) + f"Expr ({value}) does not match the '{attribute.name}' specification" ) from ex return _is_sequence_of_validator @@ -121,9 +119,7 @@ def _is_dict_of_validator(instance, attribute, value): ) except Exception as ex: raise ValueError( - "Expr ({value}) does not match the '{name}' specification".format( - value=value, name=attribute.name - ) + f"Expr ({value}) does not match the '{attribute.name}' specification" ) from ex return _is_dict_of_validator @@ -146,9 +142,7 @@ def _is_tuple_of_validator(instance, attribute, value): ) except Exception as ex: raise ValueError( - "Expr ({value}) does not match the '{name}' specification".format( - value=value, name=attribute.name - ) + f"Expr ({value}) does not match the '{attribute.name}' specification" ) from ex return _is_tuple_of_validator @@ -171,11 +165,7 @@ def _is_union_of_validator(instance, attribute, value): passed = False if not passed: - raise ValueError( - "Expr ({value}) does not match the '{name}' specification".format( - value=value, name=attribute.name - ) - ) + raise ValueError(f"Expr ({value}) does not match the '{attribute.name}' specification") return _is_union_of_validator @@ -197,10 +187,10 @@ def _is_nothing_validator(instance, attribute, value): Any = _TypeDescriptor("Any", None, _make_any_validator, typing.Any) Sequence = _GenericTypeDescriptor("Sequence", 1, _make_sequence_validator, typing.Sequence) -List = _GenericTypeDescriptor("List", 1, _make_list_validator, typing.List) -Dict = _GenericTypeDescriptor("Dict", 2, _make_dict_validator, typing.Dict) -Set = _GenericTypeDescriptor("Set", 1, _make_set_validator, typing.Set) -Tuple = _GenericTypeDescriptor("Tuple", (1, None), _make_tuple_validator, typing.Tuple) +List = _GenericTypeDescriptor("List", 1, _make_list_validator, list) +Dict = _GenericTypeDescriptor("Dict", 2, _make_dict_validator, dict) +Set = _GenericTypeDescriptor("Set", 1, _make_set_validator, set) +Tuple = _GenericTypeDescriptor("Tuple", (1, None), _make_tuple_validator, tuple) 
Union = _GenericTypeDescriptor("Union", (2, None), _make_union_validator, typing.Union) Optional = _GenericTypeDescriptor( @@ -222,7 +212,7 @@ def attribute(of, optional=False, **kwargs): attr_type_hint = of else: - raise ValueError("Invalid attribute type '{}'".format(of)) + raise ValueError(f"Invalid attribute type '{of}'") if optional: attr_validator = attr.validators.optional(attr_validator) @@ -263,9 +253,7 @@ def _make_attrs_class_wrapper(cls): for name, member in extra_members.items(): if name in cls.__dict__.keys(): raise ValueError( - "Name clashing with a existing '{name}' member of the decorated class ".format( - name=name - ) + f"Name clashing with a existing '{name}' member of the decorated class " ) setattr(cls, name, member) diff --git a/src/gt4py/cartesian/utils/meta.py b/src/gt4py/cartesian/utils/meta.py index a6ffb0832f..9a683749cd 100644 --- a/src/gt4py/cartesian/utils/meta.py +++ b/src/gt4py/cartesian/utils/meta.py @@ -13,7 +13,8 @@ import inspect import operator import textwrap -from typing import Callable, Dict, Final, List, Tuple, Type +from collections.abc import Callable +from typing import Final from gt4py.cartesian.utils.base import shashed_id @@ -52,7 +53,7 @@ def get_source(func): return source -def get_ast(func_or_source_or_ast, *, feature_version: Tuple[int, int]): +def get_ast(func_or_source_or_ast, *, feature_version: tuple[int, int]): if callable(func_or_source_or_ast): func_or_source_or_ast = get_source(func_or_source_or_ast) if isinstance(func_or_source_or_ast, str): @@ -62,7 +63,7 @@ def get_ast(func_or_source_or_ast, *, feature_version: Tuple[int, int]): if isinstance(func_or_source_or_ast, (ast.AST, list)): ast_root = func_or_source_or_ast else: - raise ValueError("Invalid function definition ({})".format(func_or_source_or_ast)) + raise ValueError(f"Invalid function definition ({func_or_source_or_ast})") return ast_root @@ -71,7 +72,7 @@ def ast_dump( *, skip_annotations: bool = True, skip_decorators: bool = True, - 
feature_version: Tuple[int, int], + feature_version: tuple[int, int], ) -> str: def _dump(node: ast.AST, excluded_names): if isinstance(node, ast.AST): @@ -91,7 +92,7 @@ def _dump(node: ast.AST, excluded_names): [ node.__class__.__name__, "({content})".format( - content=", ".join("{}={}".format(name, value) for name, value in fields) + content=", ".join(f"{name}={value}" for name, value in fields) ), ] ) @@ -243,7 +244,7 @@ def generic_visit(self, node, **kwargs): class ASTEvaluator(ASTPass): - AST_OP_TO_OP: Final[Dict[Type, Callable]] = { + AST_OP_TO_OP: Final[dict[type, Callable]] = { # Arithmetic operations ast.UAdd: operator.pos, ast.USub: operator.neg, @@ -328,7 +329,7 @@ def visit_Compare(self, node: ast.Compare): return all(comparisons) def generic_visit(self, node): - raise ValueError("Invalid AST node for evaluation: {}".format(repr(node))) + raise ValueError(f"Invalid AST node for evaluation: {node!r}") ast_eval = ASTEvaluator.apply @@ -414,7 +415,7 @@ def visit_Name(self, node: ast.Name): self.name_nodes[node.id].append(node) def _get_name_components(self, node: ast.AST): - components: List + components: list if isinstance(node, ast.Name): components = [node.id] valid = self.prefixes is None or node.id in self.prefixes @@ -466,7 +467,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom): module = node.module for alias in node.names: - name = ".".join([module, alias.name]) + name = f"{module}.{alias.name}" as_name = alias.asname if alias.asname else name imports_dict[name] = as_name diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index d3e8af56d3..72cc003a2a 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -125,7 +125,10 @@ def _get_clang_format() -> Optional[str]: executable = os.getenv("CLANG_FORMAT_EXECUTABLE", "clang-format") try: assert isinstance(executable, str) - if subprocess.run([executable, "--version"], capture_output=True, check=False).returncode != 0: + if ( + subprocess.run([executable, "--version"], 
capture_output=True, check=False).returncode + != 0 + ): return None except Exception: return None @@ -327,7 +330,7 @@ def indent_str(self) -> str: """Indentation string for new lines (in the current state).""" return self.indent_char * (self.indent_level * self.indent_size) - def __iadd__(self, source_line: Union[str, AnyTextSequence]) -> TextBlock: + def __iadd__(self, source_line: str | AnyTextSequence) -> TextBlock: if isinstance(source_line, str): return self.append(source_line) else: @@ -453,7 +456,7 @@ class StringTemplate(BaseTemplate): definition: string.Template - def __init__(self, definition: Union[str, string.Template], **kwargs: Any) -> None: + def __init__(self, definition: str | string.Template, **kwargs: Any) -> None: super().__init__() if isinstance(definition, str): definition = string.Template(definition) @@ -485,7 +488,7 @@ class JinjaTemplate(BaseTemplate): __jinja_env__ = jinja2.Environment(undefined=jinja2.StrictUndefined) - def __init__(self, definition: Union[str, jinja2.Template], **kwargs: Any) -> None: + def __init__(self, definition: str | jinja2.Template, **kwargs: Any) -> None: super().__init__() try: if isinstance(definition, str): @@ -649,7 +652,7 @@ def apply( # redefinition of symbol @classmethod def apply( # redefinition of symbol cls, root: RootNode, **kwargs: Any - ) -> Union[str, Collection[str]]: + ) -> str | Collection[str]: """Public method to build a class instance and visit an IR node. Args: @@ -678,18 +681,16 @@ def generic_visit(self, node: Node, **kwargs: Any) -> str: ... @overload def generic_visit( self, - node: Union[ - list, - tuple, - collections.abc.Set, - collections.abc.Sequence, - dict, - collections.abc.Mapping, - ], + node: list + | tuple + | collections.abc.Set + | collections.abc.Sequence + | dict + | collections.abc.Mapping, **kwargs: Any, ) -> Collection[str]: ... 
- def generic_visit(self, node: RootNode, **kwargs: Any) -> Union[str, Collection[str]]: + def generic_visit(self, node: RootNode, **kwargs: Any) -> str | Collection[str]: if isinstance(node, Node): template, key = self.get_template(node) if template: diff --git a/src/gt4py/eve/concepts.py b/src/gt4py/eve/concepts.py index 73ea58aad0..2cc05b2bc4 100644 --- a/src/gt4py/eve/concepts.py +++ b/src/gt4py/eve/concepts.py @@ -93,10 +93,10 @@ class SourceLocationGroup: """A group of merged source code locations (with optional info).""" locations: Tuple[SourceLocation, ...] = datamodels.field(validator=_validators.non_empty()) - context: Optional[Union[str, Tuple[str, ...]]] + context: Optional[str | Tuple[str, ...]] def __init__( - self, *locations: SourceLocation, context: Optional[Union[str, Tuple[str, ...]]] = None + self, *locations: SourceLocation, context: Optional[str | Tuple[str, ...]] = None ) -> None: self.__auto_init__(locations=locations, context=context) # type: ignore[attr-defined] # __auto_init__ added dynamically diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index 09eaa871b3..1ceed6deba 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -51,7 +51,6 @@ TypeAlias, TypeAnnotation, TypeVar, - Union, cast, overload, ) @@ -104,13 +103,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: ... class GenericDataModelTP(DataModelTP, Protocol): - __args__: ClassVar[Tuple[Union[Type, TypeVar], ...]] = () + __args__: ClassVar[Tuple[Type | TypeVar, ...]] = () __parameters__: ClassVar[Tuple[TypeVar, ...]] = () @classmethod def __class_getitem__( - cls: Type[GenericDataModelTP], args: Union[Type, Tuple[Type, ...]] - ) -> Union[DataModelTP, GenericDataModelTP]: ... + cls: Type[GenericDataModelTP], args: Type | Tuple[Type, ...] + ) -> DataModelTP | GenericDataModelTP: ... 
_DM = TypeVar("_DM", bound="DataModel") @@ -165,7 +164,7 @@ class ForwardRefValidator: factory: type_val.TypeValidatorFactory """Type factory used to create the actual field validator.""" - validator: Union[type_val.FixedTypeValidator, None, NothingType] = NOTHING + validator: type_val.FixedTypeValidator | None | NothingType = NOTHING """Actual type validator created after resolving the forward references.""" def __call__(self, instance: DataModel, attribute: Attribute, value: Any) -> None: @@ -302,7 +301,7 @@ def datamodel( # redefinition of unused symbol coerce: bool = _COERCE_DEFAULT, generic: bool = _GENERIC_DEFAULT, type_validation_factory: Optional[FieldTypeValidatorFactory] = DefaultFieldTypeValidatorFactory, -) -> Union[Type[_T], Callable[[Type[_T]], Type[_T]]]: +) -> Type[_T] | Callable[[Type[_T]], Type[_T]]: """Add generated special methods to classes according to the specified attributes (class decorator). It converts the class to an `attrs `_ with some extra features. @@ -392,7 +391,7 @@ def __call__( type_validation_factory: Optional[ FieldTypeValidatorFactory ] = DefaultFieldTypeValidatorFactory, - ) -> Union[Type[_T], Callable[[Type[_T]], Type[_T]]]: ... + ) -> Type[_T] | Callable[[Type[_T]], Type[_T]]: ... frozenmodel: _DataModelDecoratorTP = functools.partial(datamodel, frozen=True) @@ -624,7 +623,7 @@ def is_generic_datamodel_class(cls: Type) -> bool: return is_datamodel(cls) and xtyping.has_type_parameters(cls) -def get_fields(model: Union[DataModel, Type[DataModel]]) -> utils.FrozenNamespace: +def get_fields(model: DataModel | Type[DataModel]) -> utils.FrozenNamespace: """Return the field meta-information of a Data Model. 
Arguments: @@ -852,8 +851,8 @@ def _get_attribute_from_bases( def _substitute_typevars( - type_hint: Type, type_params_map: Mapping[TypeVar, Union[Type, TypeVar]] -) -> Tuple[Union[Type, TypeVar], bool]: + type_hint: Type, type_params_map: Mapping[TypeVar, Type | TypeVar] +) -> Tuple[Type | TypeVar, bool]: if isinstance(type_hint, typing.TypeVar): assert type_hint in type_params_map return type_params_map[type_hint], True @@ -938,7 +937,7 @@ def __pretty__( def _make_data_model_class_getitem() -> classmethod: def __class_getitem__( - cls: Type[GenericDataModelT], args: Union[Type, Tuple[Type]] + cls: Type[GenericDataModelT], args: Type | Tuple[Type] ) -> Type[DataModelT] | Type[GenericDataModelT]: """Return an instance compatible with aliases created by :class:`typing.Generic` classes. @@ -1349,8 +1348,8 @@ class FrozenModel(DataModel, frozen=True): class GenericDataModel(GenericDataModelTP): @classmethod def __class_getitem__( - cls: Type[GenericDataModelTP], args: Union[Type, Tuple[Type, ...]] - ) -> Union[DataModelTP, GenericDataModelTP]: ... + cls: Type[GenericDataModelTP], args: Type | Tuple[Type, ...] + ) -> DataModelTP | GenericDataModelTP: ... 
else: diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 4baacb8cb1..f12345e958 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -245,7 +245,7 @@ class ArrayInterfaceTypedDict(TypedDict): descr: NotRequired[List[Tuple]] data: NotRequired[Tuple[int, bool]] strides: NotRequired[Optional[Tuple[int, ...]]] - mask: NotRequired[Optional["StrictArrayInterface"]] + mask: NotRequired[Optional[StrictArrayInterface]] offset: NotRequired[int] version: int @@ -271,7 +271,7 @@ class CUDAArrayInterfaceTypedDict(TypedDict): version: int strides: NotRequired[Optional[Tuple[int, ...]]] descr: NotRequired[List[Tuple]] - mask: NotRequired[Optional["StrictCUDAArrayInterface"]] + mask: NotRequired[Optional[StrictCUDAArrayInterface]] stream: NotRequired[Optional[int]] diff --git a/src/gt4py/eve/traits.py b/src/gt4py/eve/traits.py index 8c84b75bb3..81bfe8e78f 100644 --- a/src/gt4py/eve/traits.py +++ b/src/gt4py/eve/traits.py @@ -102,9 +102,7 @@ def _validate_symbol_refs(cls: Type[SymbolRefsValidatorTrait], instance: concept validator.visit(child_node, symtable=symtable) if validator.missing_symbols: - raise exceptions.EveValueError( - "Symbols {} not found.".format(validator.missing_symbols) - ) + raise exceptions.EveValueError(f"Symbols {validator.missing_symbols} not found.") class SymbolRefsValidator(visitors.NodeVisitor): def __init__(self) -> None: diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 01c388a6c5..24f07cffe0 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -80,7 +80,7 @@ T = TypeVar("T") -def first(iterable: Iterable[T], *, default: Union[T, NothingType] = NOTHING) -> T: +def first(iterable: Iterable[T], *, default: T | NothingType = NOTHING) -> T: try: return next(iter(iterable)) except StopIteration as error: @@ -89,7 +89,7 @@ def first(iterable: Iterable[T], *, default: Union[T, NothingType] = NOTHING) -> raise error -def isinstancechecker(type_info: 
Union[Type, Iterable[Type]]) -> Callable[[Any], bool]: +def isinstancechecker(type_info: Type | Iterable[Type]) -> Callable[[Any], bool]: """Return a callable object that checks if operand is an instance of `type_info`. Examples: @@ -413,7 +413,7 @@ def optional_lru_cache( def optional_lru_cache( func: Optional[Callable[_P, _T]] = None, *, maxsize: Optional[int] = 128, typed: bool = False -) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]: +) -> Callable[_P, _T] | Callable[[Callable[_P, _T]], Callable[_P, _T]]: """Wrap :func:`functools.lru_cache` to fall back to the original function if arguments are not hashable. Examples: @@ -465,7 +465,7 @@ def lru_cache( key: Optional[Callable[_P, int]] = None, maxsize: Optional[int] = 128, typed: bool = False, -) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]: +) -> Callable[_P, _T] | Callable[[Callable[_P, _T]], Callable[_P, _T]]: """ Wrap :func:`functools.lru_cache` but allow customizing the cache key. @@ -531,7 +531,7 @@ def with_fluid_partial(func: Callable[_P, _T], *args: Any, **kwargs: Any) -> Cal def with_fluid_partial( func: Optional[Callable[..., Any]] = None, *args: Any, **kwargs: Any -) -> Union[Callable[..., Any], Callable[[Callable[..., Any]], Callable[..., Any]]]: +) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]: """Add a `partial` attribute to the decorated function. 
The `partial` attribute is a function that behaves like `functools.partial`, @@ -687,7 +687,7 @@ class CASE_STYLE(enum.Enum): KEBAB = "kebab" @classmethod - def split(cls, name: str, case_style: Union[CASE_STYLE, str]) -> List[str]: + def split(cls, name: str, case_style: CASE_STYLE | str) -> List[str]: if isinstance(case_style, str): case_style = cls.CASE_STYLE(case_style) assert isinstance(case_style, cls.CASE_STYLE) @@ -698,7 +698,7 @@ def split(cls, name: str, case_style: Union[CASE_STYLE, str]) -> List[str]: return splitter(name) @classmethod - def join(cls, words: AnyWordsIterable, case_style: Union[CASE_STYLE, str]) -> str: + def join(cls, words: AnyWordsIterable, case_style: CASE_STYLE | str) -> str: if isinstance(case_style, str): case_style = cls.CASE_STYLE(case_style) assert isinstance(case_style, cls.CASE_STYLE) @@ -712,7 +712,7 @@ def join(cls, words: AnyWordsIterable, case_style: Union[CASE_STYLE, str]) -> st @classmethod def convert( - cls, name: str, source_style: Union[CASE_STYLE, str], target_style: Union[CASE_STYLE, str] + cls, name: str, source_style: CASE_STYLE | str, target_style: CASE_STYLE | str ) -> str: return cls.join(cls.split(name, source_style), target_style) @@ -940,7 +940,7 @@ class XIterable(Iterable[T]): iterator: Iterator[T] - def __init__(self, it: Union[Iterable[T], Iterator[T]]) -> None: + def __init__(self, it: Iterable[T] | Iterator[T]) -> None: object.__setattr__(self, "iterator", iter(it)) def __getattr__(self, name: str) -> Any: @@ -1184,7 +1184,7 @@ def getattr( # A003: shadowing a python builtin """ return XIterable(map(attrgetter_(*names, default=default), self.iterator)) - def getitem(self, *indices: Union[int, str], default: Any = NOTHING) -> XIterable[Any]: + def getitem(self, *indices: int | str, default: Any = NOTHING) -> XIterable[Any]: """Get provided indices data from each item in a sequence. Equivalent to ``toolz.itertoolz.pluck(indices, self)``. 
@@ -1221,7 +1221,7 @@ def getitem(self, *indices: Union[int, str], default: Any = NOTHING) -> XIterabl else: return XIterable(toolz.itertoolz.pluck(ind, self.iterator, default)) - def chain(self, *others: Iterable) -> XIterable[Union[T, S]]: + def chain(self, *others: Iterable) -> XIterable[T | S]: """Chain iterators. Equivalent to ``itertools.chain(self, *others)``. @@ -1242,7 +1242,7 @@ def chain(self, *others: Iterable) -> XIterable[Union[T, S]]: return XIterable(itertools.chain(self.iterator, *iterators)) def diff( - self, *others: Iterable, default: Any = NOTHING, key: Union[NOTHING, Callable] = NOTHING + self, *others: Iterable, default: Any = NOTHING, key: NOTHING | Callable = NOTHING ) -> XIterable[Tuple[T, S]]: """Diff iterators. @@ -1285,9 +1285,7 @@ def diff( iterators = [it.iterator if isinstance(it, XIterable) else it for it in others] return XIterable(toolz.itertoolz.diff(self.iterator, *iterators, **kwargs)) - def product( - self, other: Union[Iterable[S], int] - ) -> Union[XIterable[Tuple[T, S]], XIterable[Tuple[T, T]]]: + def product(self, other: Iterable[S] | int) -> XIterable[Tuple[T, S]] | XIterable[Tuple[T, T]]: """Product of iterators. Equivalent to ``itertools.product(it_a, it_b)``. @@ -1435,7 +1433,7 @@ def islice(self, __start: int, __stop: int, __step: int = 1) -> XIterable[T]: .. def islice( self, __start_or_stop: int, - __stop_or_nothing: Union[int, NothingType] = NOTHING, + __stop_or_nothing: int | NothingType = NOTHING, step: int = 1, ) -> XIterable[T]: """Select elements from an iterable. @@ -1484,7 +1482,7 @@ def select(self, selectors: Iterable[bool]) -> XIterable[T]: raise TypeError(f"Non-iterable 'selectors' value: '{selectors}'.") return XIterable(itertools.compress(self.iterator, selectors)) - def unique(self, *, key: Union[NOTHING, Callable] = NOTHING) -> XIterable[T]: + def unique(self, *, key: NOTHING | Callable = NOTHING) -> XIterable[T]: """Return only unique elements of a sequence. 
Equivalent to ``toolz.itertoolz.unique(self)``. @@ -1525,8 +1523,8 @@ def groupby( ) -> XIterable[Tuple[Any, List[T]]]: ... def groupby( - self, key: Union[str, List[Any], Callable[[T], Any]], *attr_keys: str, as_dict: bool = False - ) -> Union[XIterable[Tuple[Any, List[T]]], Dict]: + self, key: str | List[Any] | Callable[[T], Any], *attr_keys: str, as_dict: bool = False + ) -> XIterable[Tuple[Any, List[T]]] | Dict: """Group a sequence by a given key. More or less equivalent to ``toolz.itertoolz.groupby(key, self)`` with some caveats. @@ -1650,7 +1648,7 @@ def reduceby( key: str, *, as_dict: Literal[False], - init: Union[S, NothingType], + init: S | NothingType, ) -> XIterable[Tuple[str, S]]: ... @typing.overload @@ -1661,7 +1659,7 @@ def reduceby( __attr_keys1: str, *attr_keys: str, as_dict: Literal[False], - init: Union[S, NothingType], + init: S | NothingType, ) -> XIterable[Tuple[Tuple[str, ...], S]]: ... @typing.overload @@ -1671,7 +1669,7 @@ def reduceby( key: str, *, as_dict: Literal[True], - init: Union[S, NothingType], + init: S | NothingType, ) -> Dict[str, S]: ... @typing.overload @@ -1682,7 +1680,7 @@ def reduceby( __attr_keys1: str, *attr_keys: str, as_dict: Literal[True], - init: Union[S, NothingType], + init: S | NothingType, ) -> Dict[Tuple[str, ...], S]: ... @typing.overload @@ -1692,7 +1690,7 @@ def reduceby( key: List[K], *, as_dict: Literal[False], - init: Union[S, NothingType], + init: S | NothingType, ) -> XIterable[Tuple[K, S]]: ... @typing.overload @@ -1702,7 +1700,7 @@ def reduceby( key: List[K], *, as_dict: Literal[True], - init: Union[S, NothingType], + init: S | NothingType, ) -> Dict[K, S]: ... @typing.overload @@ -1712,7 +1710,7 @@ def reduceby( key: Callable[[T], K], *, as_dict: Literal[False], - init: Union[S, NothingType], + init: S | NothingType, ) -> XIterable[Tuple[K, S]]: ... 
@typing.overload @@ -1722,24 +1720,24 @@ def reduceby( key: Callable[[T], K], *, as_dict: Literal[True], - init: Union[S, NothingType], + init: S | NothingType, ) -> Dict[K, S]: ... def reduceby( self, bin_op_func: Callable[[S, T], S], - key: Union[str, List[K], Callable[[T], K]], + key: str | List[K] | Callable[[T], K], *attr_keys: str, as_dict: bool = False, - init: Union[S, NothingType] = NOTHING, - ) -> Union[ - XIterable[Tuple[str, S]], - Dict[str, S], - XIterable[Tuple[Tuple[str, ...], S]], - Dict[Tuple[str, ...], S], - XIterable[Tuple[K, S]], - Dict[K, S], - ]: + init: S | NothingType = NOTHING, + ) -> ( + XIterable[Tuple[str, S]] + | Dict[str, S] + | XIterable[Tuple[Tuple[str, ...], S]] + | Dict[Tuple[str, ...], S] + | XIterable[Tuple[K, S]] + | Dict[K, S] + ): """Group a sequence by a given key and simultaneously perform a reduction inside the groups. More or less equivalent to ``toolz.itertoolz.reduceby(key, bin_op_func, self, init)`` diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index ea393e2ad0..8177817502 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import Any, Callable, Generic, Optional, ParamSpec, Sequence, TypeVar +from collections.abc import Callable, Sequence +from typing import Any, Generic, Optional, ParamSpec, TypeVar from gt4py import eve from gt4py._core import definitions as core_defs diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index 79d188cdf2..57c365f6c9 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -9,8 +9,8 @@ import ast import textwrap import typing +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable from gt4py.eve.concepts import SourceLocation from gt4py.eve.extended_typing import Any, Generic, TypeVar 
diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index b30b25b309..3378d3e498 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Tuple from gt4py._core import definitions as core_defs from gt4py.next import common @@ -21,10 +20,10 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi @WhereBuiltinFunction def concat_where( cond: common.Domain, - true_field: common.Field | core_defs.ScalarT | Tuple, - false_field: common.Field | core_defs.ScalarT | Tuple, + true_field: common.Field | core_defs.ScalarT | tuple, + false_field: common.Field | core_defs.ScalarT | tuple, /, -) -> common.Field | Tuple: +) -> common.Field | tuple: """ Concatenates two field fields based on a 1D mask. diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index e2f0d4d197..73715e1f2e 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -12,7 +12,8 @@ import math import operator from builtins import bool, float, int, tuple # noqa: A004 shadowing a Python built-in -from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast +from collections.abc import Callable +from typing import Any, Final, Generic, ParamSpec, TypeAlias, TypeVar, Union, cast import numpy as np from numpy import float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64 @@ -72,7 +73,7 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ( ts.FunctionType ) # our type of type is currently represented by the type constructor function - elif t is Tuple or (hasattr(t, "__origin__") and t.__origin__ is tuple): + elif t is tuple or (hasattr(t, "__origin__") and t.__origin__ is tuple): return ts.TupleType elif hasattr(t, 
"__origin__") and t.__origin__ is Union: types = [_type_conversion_helper(e) for e in t.__args__] # type: ignore[attr-defined] @@ -136,8 +137,8 @@ def __gt_type__(self) -> ts.FunctionType: ) -CondT = TypeVar("CondT", bound=Union[common.Field, common.Domain]) -FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple]) +CondT = TypeVar("CondT", bound=common.Field | common.Domain) +FieldT = TypeVar("FieldT", bound=common.Field | core_defs.Scalar | tuple) class WhereBuiltinFunction( @@ -187,17 +188,17 @@ def broadcast( @WhereBuiltinFunction def where( mask: common.Field, - true_field: common.Field | core_defs.ScalarT | Tuple, - false_field: common.Field | core_defs.ScalarT | Tuple, + true_field: common.Field | core_defs.ScalarT | tuple, + false_field: common.Field | core_defs.ScalarT | tuple, /, -) -> common.Field | Tuple: +) -> common.Field | tuple: raise NotImplementedError() @BuiltInFunction def astype( - value: common.Field | core_defs.ScalarT | Tuple, type_: type, / -) -> common.Field | core_defs.ScalarT | Tuple: + value: common.Field | core_defs.ScalarT | tuple, type_: type, / +) -> common.Field | core_defs.ScalarT | tuple: if isinstance(value, tuple): return tuple(astype(v, type_) for v in value) # default implementation for scalars, Fields are handled via dispatch diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 24a3015ba2..b6ae9b3a1d 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Generic, TypeVar, Union +from typing import Any, Generic, TypeVar from gt4py.eve import Coerced, Node, SourceLocation, SymbolName, SymbolRef, datamodels from gt4py.eve.traits import SymbolTableTrait @@ -46,7 +46,7 @@ def __str__(self) -> str: # class Symbol(LocatedNode, Generic[SymbolT]): id: Coerced[SymbolName] - type: Union[SymbolT, ts.DeferredType] # A003 + type: 
SymbolT | ts.DeferredType # A003 namespace: dialect_ast_enums.Namespace = dialect_ast_enums.Namespace( dialect_ast_enums.Namespace.LOCAL ) @@ -151,11 +151,11 @@ class Stmt(LocatedNode): ... class Starred(Expr): - id: Union[FieldSymbol, TupleSymbol, ScalarSymbol] + id: FieldSymbol | TupleSymbol | ScalarSymbol class Assign(Stmt): - target: Union[FieldSymbol, TupleSymbol, ScalarSymbol] + target: FieldSymbol | TupleSymbol | ScalarSymbol value: Expr @@ -196,13 +196,13 @@ class FunctionDefinition(LocatedNode, SymbolTableTrait): params: list[DataSymbol] body: BlockStmt closure_vars: list[Symbol] - type: Union[ts.FunctionType, ts.DeferredType] = ts.DeferredType(constraint=ts.FunctionType) + type: ts.FunctionType | ts.DeferredType = ts.DeferredType(constraint=ts.FunctionType) class FieldOperator(LocatedNode, SymbolTableTrait): id: Coerced[SymbolName] definition: FunctionDefinition - type: Union[ts_ffront.FieldOperatorType, ts.DeferredType] = ts.DeferredType( + type: ts_ffront.FieldOperatorType | ts.DeferredType = ts.DeferredType( constraint=ts_ffront.FieldOperatorType ) @@ -213,6 +213,6 @@ class ScanOperator(LocatedNode, SymbolTableTrait): forward: Constant init: Constant definition: FunctionDefinition # scan pass - type: Union[ts_ffront.ScanOperatorType, ts.DeferredType] = ts.DeferredType( + type: ts_ffront.ScanOperatorType | ts.DeferredType = ts.DeferredType( constraint=ts_ffront.ScanOperatorType ) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 19954a1778..e14502938f 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -709,7 +709,7 @@ def _deduce_binop_type( raise errors.DSLError( node.location, f"{err_msg} Operator " - f"must be one of {', '.join((str(op) for op in logical_ops))}.", + f"must be one of {', '.join(str(op) for op in logical_ops)}.", ) return ts.DomainType(dims=promote_dims(left.type.dims, 
right.type.dims)) else: @@ -1075,7 +1075,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call: broadcast_dims = [cast(ts.DimensionType, elt.type).dim for elt in broadcast_dims_expr] - if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)): + if not set(arg_dims := type_info.extract_dims(arg_type)).issubset(set(broadcast_dims)): raise errors.DSLError( node.location, f"Incompatible broadcast dimensions in '{node.func!s}': expected " diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 5e145f32bf..e4563b5f3a 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -10,7 +10,7 @@ import enum import textwrap -from typing import Any, Final, TypeAlias, Union +from typing import Any, Final, TypeAlias import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako, TemplatedGenerator @@ -18,7 +18,7 @@ from gt4py.next.type_system import type_specifications as ts -PropertyIdentifier: TypeAlias = Union[type[foast.LocatedNode], tuple[type[foast.LocatedNode], str]] +PropertyIdentifier: TypeAlias = type[foast.LocatedNode] | tuple[type[foast.LocatedNode], str] INDENTATION_PREFIX: Final[str] = " " diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index c1e34e2e57..a05acb082b 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -8,7 +8,8 @@ import dataclasses -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional from gt4py import eve from gt4py.eve import utils as eve_utils diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 997e819a1a..00de95dabd 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ 
-12,7 +12,8 @@ import builtins import textwrap import typing -from typing import Any, Callable, Iterable, Mapping, Type +from collections.abc import Callable, Iterable, Mapping +from typing import Any import gt4py.eve as eve from gt4py.next import errors @@ -295,7 +296,7 @@ def visit_Assign( if not isinstance(target, ast.Name): raise errors.DSLError(self.get_location(node), "Can only assign to names.") new_value = self.visit(node.value) - constraint_type: Type[ts.DataType] = ts.DataType + constraint_type: type[ts.DataType] = ts.DataType if isinstance(new_value, foast.TupleExpr): constraint_type = ts.TupleType elif ( diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index 7049f70021..e01cb19fb1 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -6,8 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from collections.abc import Iterable -from typing import Any, Callable, Optional, TypeVar +from collections.abc import Callable, Iterable +from typing import Any, Optional, TypeVar from gt4py.eve import utils as eve_utils from gt4py.next.ffront import type_info as ti_ffront diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index af2e9807ef..79eacebdbb 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Iterator, Sequence, TypeAlias +from collections.abc import Iterator, Sequence +from typing import Any, TypeAlias from gt4py.next import common, errors from gt4py.next.ffront import ( @@ -122,7 +123,8 @@ def _field_constituents_range_and_dims( yield from _field_constituents_range_and_dims(el, el_type) case ts.FieldType(): dims = type_info.extract_dims(arg_type) - if isinstance(arg, ts.TypeSpec): # TODO(): fix yield (tuple(), dims) + if isinstance(arg, ts.TypeSpec): # TODO(): fix + yield (tuple(), dims) elif dims: assert ( hasattr(arg, "domain") diff --git a/src/gt4py/next/ffront/program_ast.py b/src/gt4py/next/ffront/program_ast.py index ea579aa211..fe2e9966b1 100644 --- a/src/gt4py/next/ffront/program_ast.py +++ b/src/gt4py/next/ffront/program_ast.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Generic, Literal, Optional, TypeVar, Union +from typing import Any, Generic, Literal, Optional, TypeVar import gt4py.eve as eve from gt4py.eve import Coerced, Node, SourceLocation, SymbolName, SymbolRef @@ -24,7 +24,7 @@ class LocatedNode(Node): class Symbol(eve.GenericNode, LocatedNode, Generic[SymbolT]): id: Coerced[SymbolName] - type: Union[SymbolT, ts.DeferredType] # A003 + type: SymbolT | ts.DeferredType # A003 namespace: dialect_ast_enums.Namespace = dialect_ast_enums.Namespace( dialect_ast_enums.Namespace.LOCAL ) @@ -82,7 +82,7 @@ class Constant(Expr): class Dict(Expr): - keys_: list[Union[Name | Attribute]] + keys_: list[Name | Attribute] values_: list[TupleExpr] @@ -97,7 +97,7 @@ class Stmt(LocatedNode): ... 
class Program(LocatedNode, SymbolTableTrait): id: Coerced[SymbolName] - type: Union[ts_ffront.ProgramType, ts.DeferredType] # A003 + type: ts_ffront.ProgramType | ts.DeferredType # A003 params: list[DataSymbol] body: list[Call] closure_vars: list[Symbol] diff --git a/src/gt4py/next/ffront/signature.py b/src/gt4py/next/ffront/signature.py index 4a58d56f57..0759f9c25e 100644 --- a/src/gt4py/next/ffront/signature.py +++ b/src/gt4py/next/ffront/signature.py @@ -18,7 +18,8 @@ import functools import inspect import types -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from gt4py.next.ffront import ( field_operator_ast as foast, diff --git a/src/gt4py/next/ffront/transform_utils.py b/src/gt4py/next/ffront/transform_utils.py index e6299ce302..443ca6f491 100644 --- a/src/gt4py/next/ffront/transform_utils.py +++ b/src/gt4py/next/ffront/transform_utils.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import collections -from typing import Any, Iterable, Optional +from collections.abc import Iterable +from typing import Any, Optional from gt4py.next import common from gt4py.next.ffront import fbuiltins diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 3717f1a7db..8324fa72c9 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -6,8 +6,9 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +from collections.abc import Iterator, Sequence from functools import reduce -from typing import Iterator, Sequence, cast +from typing import cast import gt4py.next.ffront.type_specifications as ts_ffront import gt4py.next.type_system.type_specifications as ts diff --git a/src/gt4py/next/iterator/dispatcher.py b/src/gt4py/next/iterator/dispatcher.py index d362c99c22..b88fcc7d8b 100644 --- a/src/gt4py/next/iterator/dispatcher.py +++ b/src/gt4py/next/iterator/dispatcher.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import functools -from typing import Any, Callable, Dict, List +from collections.abc import Callable +from typing import Any # TODO(): test @@ -38,8 +39,8 @@ def _impl(fun): class Dispatcher: def __init__(self) -> None: - self._funs: Dict[str, Dict[str, Callable]] = {} - self.key_stack: List[str] = [] + self._funs: dict[str, dict[str, Callable]] = {} + self.key_stack: list[str] = [] @property def key(self): diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index bfaa7f60ce..f3432aa5d7 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -41,7 +41,6 @@ TypeAlias, TypeGuard, TypeVar, - Union, cast, overload, runtime_checkable, @@ -83,7 +82,7 @@ FieldAxis: TypeAlias = common.Dimension TupleAxis: TypeAlias = type[None] -Axis: TypeAlias = Union[FieldAxis, TupleAxis] +Axis: TypeAlias = FieldAxis | TupleAxis Scalar: TypeAlias = ( SupportsInt | SupportsFloat @@ -227,7 +226,7 @@ def skip_value( ConcretePosition: TypeAlias = dict[Tag, PositionEntry] IncompletePosition: TypeAlias = dict[Tag, IncompletePositionEntry] -Position: TypeAlias = Union[ConcretePosition, IncompletePosition] +Position: TypeAlias = ConcretePosition | IncompletePosition #: A ``None`` position flags invalid not-a-neighbor results in neighbor-table lookups MaybePosition: TypeAlias = Optional[Position] @@ -1397,7 +1396,7 @@ def constant_field(value: Any, dtype_like: 
Optional[core_defs.DTypeLike] = None) @builtins.shift.register(EMBEDDED) -def shift(*offsets: Union[runtime.Offset, int]) -> Callable[[ItIterator], ItIterator]: +def shift(*offsets: runtime.Offset | int) -> Callable[[ItIterator], ItIterator]: def impl(it: ItIterator) -> ItIterator: return it.shift(*list(o.value if isinstance(o, runtime.Offset) else o for o in offsets)) @@ -1468,7 +1467,7 @@ def list_get(i, lst: _List[Optional[DT]]) -> Optional[DT]: def _get_offset(*lists: _List | _ConstList) -> Optional[runtime.Offset]: - offsets = set((lst.offset for lst in lists if hasattr(lst, "offset"))) + offsets = set(lst.offset for lst in lists if hasattr(lst, "offset")) if len(offsets) == 0: return None if len(offsets) == 1: diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index e25eaeee1e..56f06f6cde 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -8,7 +8,7 @@ from __future__ import annotations import typing -from typing import TYPE_CHECKING, ClassVar, List, Optional, Union +from typing import TYPE_CHECKING, ClassVar, Optional import gt4py.eve as eve from gt4py.eve import Coerced, SymbolName, SymbolRef @@ -82,7 +82,7 @@ def __str__(self): class OffsetLiteral(Expr): - value: Union[int, str] + value: int | str class AxisLiteral(Expr): @@ -97,18 +97,18 @@ class SymRef(Expr): class Lambda(Expr, SymbolTableTrait): - params: List[Sym] + params: list[Sym] expr: Expr class FunCall(Expr): fun: Expr # VType[Callable] - args: List[Expr] + args: list[Expr] class FunctionDefinition(Node, SymbolTableTrait): id: Coerced[SymbolName] - params: List[Sym] + params: list[Sym] expr: Expr @@ -135,13 +135,13 @@ class Temporary(Node): class Program(Node, ValidatedSymbolTableTrait): id: Coerced[SymbolName] - function_definitions: List[FunctionDefinition] - params: List[Sym] - declarations: List[Temporary] - body: List[Stmt] + function_definitions: list[FunctionDefinition] + params: list[Sym] + declarations: list[Temporary] + body: 
list[Stmt] implicit_domain: bool = False - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ + _NODE_SYMBOLS_: ClassVar[list[Sym]] = [ Sym(id=name) for name in sorted(BUILTINS) ] # sorted for serialization stability diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index da13d20bb6..7c1ee9ae7e 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from collections.abc import Iterable -from typing import Any, Generic, List, TypeAlias, TypeGuard, TypeVar +from typing import Any, Generic, TypeAlias, TypeGuard, TypeVar from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -18,7 +18,7 @@ class _FunCallTo(itir.FunCall, Generic[_Fun]): fun: _Fun - args: List[itir.Expr] + args: list[itir.Expr] _FunCallToSymRef: TypeAlias = _FunCallTo[itir.SymRef] @@ -49,7 +49,7 @@ def is_call_to(node: Any, fun: str | Iterable[str]) -> TypeGuard[_FunCallToSymRe and node.fun.id == fun ) else: - return any((is_call_to(node, f) for f in fun)) + return any(is_call_to(node, f) for f in fun) _FunCallToFunCallToRef: TypeAlias = _FunCallTo[_FunCallToSymRef] diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 52853899c4..81e389bd16 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -10,7 +10,8 @@ import dataclasses import functools -from typing import Any, Callable, Iterable, Literal, Mapping, Optional +from collections.abc import Callable, Iterable, Mapping +from typing import Any, Literal, Optional from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 
aeae9f0e6c..6bbd52a229 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import typing -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional from gt4py._core import definitions as core_defs from gt4py.next import common @@ -15,7 +16,7 @@ from gt4py.next.type_system import type_specifications as ts, type_translation -def sym(sym_or_name: Union[str, itir.Sym], type_: str | ts.TypeSpec | None = None) -> itir.Sym: +def sym(sym_or_name: str | itir.Sym, type_: str | ts.TypeSpec | None = None) -> itir.Sym: """ Convert to Sym if necessary. @@ -38,7 +39,7 @@ def sym(sym_or_name: Union[str, itir.Sym], type_: str | ts.TypeSpec | None = Non def ref( - ref_or_name: Union[str, itir.SymRef], + ref_or_name: str | itir.SymRef, type_: str | ts.TypeSpec | None = None, annex: dict[str, Any] | None = None, ) -> itir.SymRef: @@ -68,7 +69,7 @@ def ref( return ref -def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> itir.Expr: +def ensure_expr(literal_or_expr: str | core_defs.Scalar | itir.Expr) -> itir.Expr: """ Convert literals into a SymRef or Literal and let expressions pass unchanged. @@ -93,7 +94,7 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti return literal_or_expr -def ensure_offset(str_or_offset: Union[str, int, itir.OffsetLiteral]) -> itir.OffsetLiteral: +def ensure_offset(str_or_offset: str | int | itir.OffsetLiteral) -> itir.OffsetLiteral: """ Convert Python literals into an OffsetLiteral and let OffsetLiterals pass unchanged. 
@@ -447,7 +448,7 @@ def _impl(*its: itir.Expr) -> itir.FunCall: def domain( - grid_type: Union[common.GridType, str], + grid_type: common.GridType | str, ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]], ) -> itir.FunCall: """ diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 00ff9abbd9..c8632c4482 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -8,7 +8,8 @@ import dataclasses from collections import ChainMap -from typing import Callable, Iterable, TypeVar +from collections.abc import Callable, Iterable +from typing import TypeVar from gt4py import eve from gt4py.eve import utils as eve_utils diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index a077b39911..264dc36bb2 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Union from lark import lark, lexer as lark_lexer, tree as lark_tree, visitors as lark_visitors @@ -101,7 +100,7 @@ class ToIrTransformer(lark_visitors.Transformer): def SYM(self, value: lark_lexer.Token) -> ir.Sym: return ir.Sym(id=value.value) - def SYM_REF(self, value: lark_lexer.Token) -> Union[ir.SymRef, ir.Literal]: + def SYM_REF(self, value: lark_lexer.Token) -> ir.SymRef | ir.Literal: if value.value in ("True", "False"): return im.literal(value.value, "bool") return ir.SymRef(id=value.value) @@ -118,7 +117,7 @@ def TYPE_LITERAL(self, value: lark_lexer.Token) -> ts.TypeSpec: raise NotImplementedError(f"Type {value} not supported.") def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral: - v: Union[int, str] = value.value[:-1] + v: int | str = value.value[:-1] try: v = int(v) except ValueError: diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index 699b7f2b4c..5ea523eeb1 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -11,8 +11,9 @@ import dataclasses import functools import types +from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Optional import devtools @@ -33,7 +34,7 @@ @dataclass(frozen=True) class Offset: - value: Union[int, str] + value: int | str def offset(value): diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index bc1d244f88..9534304062 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -9,7 +9,8 @@ import dataclasses import inspect import typing -from typing import Callable, ClassVar, List +from collections.abc import Callable +from typing import ClassVar from gt4py._core import definitions as core_defs from gt4py.eve import Node, utils as eve_utils @@ -202,9 +203,9 @@ def __bool__(self): class 
TracerContext: - fundefs: ClassVar[List[FunctionDefinition]] = [] - body: ClassVar[List[itir.Stmt]] = [] - declarations: ClassVar[List[itir.Temporary]] = [] + fundefs: ClassVar[list[FunctionDefinition]] = [] + body: ClassVar[list[itir.Stmt]] = [] + declarations: ClassVar[list[itir.Temporary]] = [] @classmethod def add_fundef(cls, fun): @@ -238,8 +239,8 @@ def set_at(expr: itir.Expr, domain: itir.Expr, target: itir.Expr) -> None: def if_stmt( cond: itir.Expr, true_branch_f: typing.Callable, false_branch_f: typing.Callable ) -> None: - true_branch: List[itir.Stmt] = [] - false_branch: List[itir.Stmt] = [] + true_branch: list[itir.Stmt] = [] + false_branch: list[itir.Stmt] = [] old_body = TracerContext.body TracerContext.body = true_branch diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 2fcbd5df0d..c90fafe67a 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -11,7 +11,8 @@ import collections import dataclasses import math -from typing import Callable, Iterable, TypeVar, Union, cast +from collections.abc import Callable, Iterable +from typing import TypeVar, cast import gt4py.next.iterator.ir_utils.ir_makers as im from gt4py.eve import ( @@ -268,7 +269,7 @@ def extract_subexpression( uid_generator: UIDGenerator, once_only: bool = False, deepest_expr_first: bool = False, -) -> tuple[itir.Expr, Union[dict[itir.Sym, itir.Expr], None], bool]: +) -> tuple[itir.Expr, dict[itir.Sym, itir.Expr] | None, bool]: """ Given an expression extract all subexprs and return a new expr with the subexprs replaced. 
diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py index 3818f3864a..d737ac5b1b 100644 --- a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py +++ b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py @@ -8,7 +8,7 @@ import dataclasses import enum -from typing import ClassVar, Optional, Type +from typing import ClassVar, Optional from gt4py import eve from gt4py.next.iterator import ir @@ -65,7 +65,7 @@ class CombinedFixedPointTransform(FixedPointTransformation): #: Enum of all transformation (names). The transformations need to be defined as methods #: named `transform_`. - Transformation: ClassVar[Type[enum.Flag]] + Transformation: ClassVar[type[enum.Flag]] #: All transformations enabled in this instance, e.g. `Transformation.T1 & Transformation.T2`. #: Usually the default value is chosen to be all transformations. diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index b3c81ca2d0..20a11f0808 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -9,8 +9,8 @@ from __future__ import annotations import functools -from collections.abc import Sequence -from typing import Callable, Literal, Optional, cast +from collections.abc import Callable, Sequence +from typing import Literal, Optional, cast from gt4py.eve import utils as eve_utils from gt4py.next import common, utils as next_utils diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index 2b8767e4a2..ee75af1fa2 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict +from typing import Any from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir as itir @@ -14,7 +14,7 @@ class InlineFundefs(PreserveLocationVisitor, NodeTranslator): - def visit_SymRef(self, node: itir.SymRef, *, symtable: Dict[str, Any]): + def visit_SymRef(self, node: itir.SymRef, *, symtable: dict[str, Any]): if node.id in symtable and isinstance( (symbol := symtable[node.id]), itir.FunctionDefinition ): diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index 33e36bfa4b..bb93818be5 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause # FIXME[#1582](tehrengruber): This transformation is not used anymore. Decide on its fate. 
-from typing import Sequence, TypeGuard +from collections.abc import Sequence +from typing import TypeGuard from gt4py import eve from gt4py.eve import NodeTranslator, traits diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index a41f74ebc1..f220868639 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import Mapping, Optional, TypeVar +from collections.abc import Mapping +from typing import Optional, TypeVar from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 8811f8ea25..118b4bccba 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -8,7 +8,8 @@ import dataclasses import enum -from typing import Callable, ClassVar, Optional +from collections.abc import Callable +from typing import ClassVar, Optional import gt4py.eve as eve from gt4py.eve import NodeTranslator, traits diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 5495f63ae1..866685c704 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict, Optional, Set +from typing import Any, Optional from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir @@ -17,10 +17,10 @@ class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): # This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex. PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") - def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): + def visit_SymRef(self, node: ir.SymRef, *, symbol_map: dict[str, ir.Node]): return symbol_map.get(str(node.id), node) - def visit_Lambda(self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node]): + def visit_Lambda(self, node: ir.Lambda, *, symbol_map: dict[str, ir.Node]): params = {str(p.id) for p in node.params} new_symbol_map = {k: v for k, v in symbol_map.items() if k not in params} return ir.Lambda(params=node.params, expr=self.visit(node.expr, symbol_map=new_symbol_map)) @@ -37,14 +37,14 @@ class RenameSymbols(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") def visit_Sym( - self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None + self, node: ir.Sym, *, name_map: dict[str, str], active: Optional[set[str]] = None ): if active and node.id in active: return ir.Sym(id=name_map.get(node.id, node.id)) return node def visit_SymRef( - self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None + self, node: ir.SymRef, *, name_map: dict[str, str], active: Optional[set[str]] = None ): if active and node.id in active: new_ref = ir.SymRef(id=name_map.get(node.id, node.id)) @@ -53,7 +53,7 @@ def visit_SymRef( return node def generic_visit( # type: ignore[override] - self, node: ir.Node, *, name_map: Dict[str, str], active: Optional[Set[str]] = None + self, node: ir.Node, *, name_map: dict[str, str], active: Optional[set[str]] = 
None ): if isinstance(node, SymbolTableTrait): if active is None: diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 8173ceebbb..675713439d 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -8,8 +8,8 @@ import dataclasses import sys -from collections.abc import Callable -from typing import Any, Final, Iterable, Literal, Optional +from collections.abc import Callable, Iterable +from typing import Any, Final, Literal, Optional from gt4py import eve from gt4py.eve import NodeTranslator, PreserveLocationVisitor diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 9d04961638..cbcfb4c720 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -15,7 +15,7 @@ from gt4py import eve from gt4py.eve import concepts -from gt4py.eve.extended_typing import Any, Callable, Optional, TypeVar, Union +from gt4py.eve.extended_typing import Any, Callable, Optional, TypeVar from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_as_fieldop, is_call_to @@ -62,7 +62,7 @@ def copy_type(from_: itir.Node, to: itir.Node, allow_untyped: bool = False) -> N _set_node_type(to, from_.type) -def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: +def on_inferred(callback: Callable, *args: ts.TypeSpec | ObservableTypeSynthesizer) -> None: """ Execute `callback` as soon as all `args` have a type. 
""" @@ -184,7 +184,7 @@ def __call__( *args: type_synthesizer.TypeOrTypeSynthesizer, offset_provider_type: common.OffsetProviderType, **kwargs, - ) -> Union[ts.TypeSpec, ObservableTypeSynthesizer]: + ) -> ts.TypeSpec | ObservableTypeSynthesizer: assert all(isinstance(arg, (ts.TypeSpec, ObservableTypeSynthesizer)) for arg in args), ( "ObservableTypeSynthesizer can only be used with arguments that are TypeSpec or ObservableTypeSynthesizer" ) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index ce99532645..e23d918c3e 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -87,7 +87,7 @@ def type_synthesizer(*, cache: bool = False) -> Callable[[F], TypeSynthesizer]: def type_synthesizer( fun: Optional[F] = None, cache: bool = False -) -> Union[TypeSynthesizer, Callable[[F], TypeSynthesizer]]: +) -> TypeSynthesizer | Callable[[F], TypeSynthesizer]: if fun is None: return functools.partial(TypeSynthesizer, cache=cache) return TypeSynthesizer(fun, cache=cache) diff --git a/src/gt4py/next/metrics.py b/src/gt4py/next/metrics.py index 48a399791d..e54470dd57 100644 --- a/src/gt4py/next/metrics.py +++ b/src/gt4py/next/metrics.py @@ -131,9 +131,7 @@ def metric_names(self) -> Sequence[str]: """Returns a list of all metric names across all collections in the store.""" return list( - dict.fromkeys( - (name for collection in self.values() for name in collection.keys()) - ).keys() + dict.fromkeys(name for collection in self.values() for name in collection.keys()).keys() ) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 39540baebb..4fdeeeda91 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -10,7 +10,8 @@ import dataclasses import typing -from typing import Any, Generic, Iterable, Iterator, Optional +from collections.abc import Iterable, Iterator +from 
typing import Any, Generic, Optional from typing_extensions import Self diff --git a/src/gt4py/next/otf/binding/cpp_interface.py b/src/gt4py/next/otf/binding/cpp_interface.py index b9058350a3..548f6b4858 100644 --- a/src/gt4py/next/otf/binding/cpp_interface.py +++ b/src/gt4py/next/otf/binding/cpp_interface.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Final, Sequence +from collections.abc import Sequence +from typing import Final from gt4py.next.otf import cpp_utils, languages from gt4py.next.otf.binding import interface diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 326863f395..678f7f12e3 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -10,8 +10,8 @@ from __future__ import annotations -from collections.abc import Collection -from typing import Any, Optional, Sequence, TypeVar, Union +from collections.abc import Collection, Sequence +from typing import Any, Optional, TypeVar import gt4py.eve as eve from gt4py.eve.codegen import JinjaTemplate as as_jinja, TemplatedGenerator @@ -40,7 +40,7 @@ class BufferSID(Expr): class Tuple(Expr): - elems: list[Union[Expr, str]] + elems: list[Expr | str] class FunctionCall(Expr): @@ -64,7 +64,7 @@ class FunctionParameter(eve.Node): class WrapperFunction(eve.Node): name: str parameters: Sequence[FunctionParameter] - body: Union[ExprStmt, ReturnStmt] + body: ExprStmt | ReturnStmt on_device: bool = False @@ -121,9 +121,7 @@ class BindingCodeGenerator(TemplatedGenerator): """ ) - def visit_WrapperFunction( - self, node: WrapperFunction, **kwargs: Any - ) -> Union[str, Collection[str]]: + def visit_WrapperFunction(self, node: WrapperFunction, **kwargs: Any) -> str | Collection[str]: return_stmt = "return _gt4py_return;" if isinstance(node.body, ReturnStmt) else "" return self.generic_visit(node, return_stmt=return_stmt) diff --git 
a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py index b05cf1af89..8b5887d05f 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Sequence +from collections.abc import Sequence import gt4py.eve as eve from gt4py.eve.codegen import JinjaTemplate as as_jinja diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 73aa578453..1b916aa674 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -247,7 +247,7 @@ def _cc_prototype_program_name( base_name = "compile_commands_cache" deps_str = "_".join(f"{dep.name}_{dep.version}" for dep in deps) flags_str = "_".join(re.sub(r"\W+", "", f) for f in flags) - return "_".join([base_name, deps_str, build_type, flags_str]).replace(".", "_") + return f"{base_name}_{deps_str}_{build_type}_{flags_str}".replace(".", "_") def _cc_prototype_program_source( diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 1b1749d5c1..79a7cbe89b 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -12,8 +12,8 @@ import dataclasses import functools import typing -from collections.abc import MutableMapping -from typing import Any, Callable, Generic, Protocol, TypeVar +from collections.abc import Callable, MutableMapping +from typing import Any, Generic, Protocol, TypeVar from typing_extensions import Self diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 88a39cfa00..df8c3a2f30 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ 
b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Collection, Final, Union +from collections.abc import Collection +from typing import Any, Final from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako @@ -250,7 +251,7 @@ def visit_TemporaryAllocation(self, node: gtfn_ir.TemporaryAllocation, **kwargs: "auto {id} = gridtools::sid::shift_sid_origin(gtfn::allocate_global_tmp<{dtype}>(tmp_alloc__, {tmp_sizes}), {shifts});" ) - def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Collection[str]]: + def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> str | Collection[str]: self.is_cartesian = node.grid_type == common.GridType.CARTESIAN self.user_defined_function_ids = list( str(fundef.id) for fundef in node.function_definitions diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py index b6a9b565ee..9805cc908c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py @@ -8,8 +8,6 @@ from __future__ import annotations -from typing import List, Union - from gt4py.eve import Coerced, Node, SymbolName from gt4py.eve.traits import SymbolTableTrait from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Sym, SymRef @@ -20,7 +18,7 @@ class Stmt(Node): ... 
class AssignStmt(Stmt): op: str = "=" - lhs: Union[Sym, SymRef] + lhs: Sym | SymRef rhs: Expr @@ -45,5 +43,5 @@ class ReturnStmt(Stmt): class ImperativeFunctionDefinition(Node, SymbolTableTrait): id: Coerced[SymbolName] - params: List[Sym] - fun: List[Stmt] + params: list[Sym] + fun: list[Stmt] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index f7445461c0..268274cd7c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -8,7 +8,8 @@ from __future__ import annotations -from typing import Callable, ClassVar, Optional, Union +from collections.abc import Callable +from typing import ClassVar, Optional from gt4py.eve import Coerced, SymbolName, datamodels from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait @@ -50,7 +51,7 @@ class IntegralConstant(Expr): class OffsetLiteral(Expr): - value: Union[int, str] + value: int | str class Lambda(Expr, SymbolTableTrait): @@ -94,7 +95,7 @@ class UnstructuredDomain(Node): class Backend(Node): - domain: Union[SymRef, CartesianDomain, UnstructuredDomain] + domain: SymRef | CartesianDomain | UnstructuredDomain def _is_tuple_expr_of(pred: Callable[[Expr], bool], expr: Expr) -> bool: @@ -175,8 +176,8 @@ class Stmt(Node): class StencilExecution(Stmt): backend: Backend stencil: SymRef - output: Union[SymRef, SidComposite] - inputs: list[Union[SymRef, SidComposite, SidFromScalar, FunCall]] + output: SymRef | SidComposite + inputs: list[SymRef | SidComposite | SidFromScalar | FunCall] @datamodels.validator("inputs") def _arg_validator( @@ -220,7 +221,7 @@ class IfStmt(Stmt): class TemporaryAllocation(Node): id: SymbolName dtype: str - domain: Union[SymRef, CartesianDomain, UnstructuredDomain] + domain: SymRef | CartesianDomain | UnstructuredDomain GTFN_BUILTINS = [ @@ -243,14 +244,14 @@ class TemporaryAllocation(Node): class TagDefinition(Node): 
name: Sym - alias: Optional[Union[str, SymRef]] = None + alias: Optional[str | SymRef] = None class Program(Node, ValidatedSymbolTableTrait): id: SymbolName params: list[Sym] function_definitions: list[ - Union[FunctionDefinition, ScanPassDefinition, ImperativeFunctionDefinition] + FunctionDefinition | ScanPassDefinition | ImperativeFunctionDefinition ] executions: list[Stmt] offset_definitions: list[TagDefinition] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py index 4868306f41..9a1cfddfd2 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import Any, Dict, Iterable, Iterator, List, Optional, TypeGuard, Union +from collections.abc import Iterable, Iterator +from typing import Any, Optional, TypeGuard import gt4py.eve as eve from gt4py.eve import NodeTranslator, concepts @@ -127,7 +128,7 @@ def commit_args( def _expand_lambda( self, node: gtfn_ir.FunCall, - new_args: List[gtfn_ir.FunCall], + new_args: list[gtfn_ir.FunCall], red_idx: str, max_neighbors: int, **kwargs: Any, @@ -157,7 +158,7 @@ def visit_Expr(self, node: gtfn_ir_common.Expr) -> gtfn_ir_common.Expr: def _expand_symref( self, node: gtfn_ir.FunCall, - new_args: List[gtfn_ir.FunCall], + new_args: list[gtfn_ir.FunCall], red_idx: str, max_neighbors: int, **kwargs: Any, @@ -244,8 +245,8 @@ def visit_TernaryExpr(self, node: gtfn_ir.TernaryExpr, **kwargs: Any) -> gtfn_ir def visit_FunctionDefinition( self, node: gtfn_ir.FunctionDefinition, **kwargs: Any ) -> ImperativeFunctionDefinition: - self.imp_list_ir: List[Union[Stmt, Conditional]] = [] - self.sym_table: Dict[gtfn_ir_common.Sym, gtfn_ir_common.SymRef] = node.annex.symtable + self.imp_list_ir: list[Stmt | Conditional] = [] + self.sym_table: 
dict[gtfn_ir_common.Sym, gtfn_ir_common.SymRef] = node.annex.symtable ret = self.visit(node.expr, localized_symbols={}, **kwargs) return ImperativeFunctionDefinition( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index a445390583..6afb6321a6 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -8,7 +8,8 @@ import dataclasses import functools -from typing import Any, Callable, ClassVar, Iterable, Optional, Type, TypeGuard, Union +from collections.abc import Callable, Iterable +from typing import Any, ClassVar, Optional, TypeGuard import gt4py.eve as eve from gt4py.eve import utils as eve_utils @@ -360,7 +361,7 @@ def visit_Lambda( force_function_extraction: bool = False, extracted_functions: Optional[list] = None, **kwargs: Any, - ) -> Union[SymRef, Lambda]: + ) -> SymRef | Lambda: if force_function_extraction: assert extracted_functions is not None fun_id = self.uids.sequential_id(prefix="_fun") @@ -405,7 +406,7 @@ def _make_domain(self, node: itir.FunCall) -> tuple[TaggedValues, TaggedValues]: @staticmethod def _collect_offset_or_axis_node( - node_type: Type, tree: eve.Node | Iterable[eve.Node] + node_type: type, tree: eve.Node | Iterable[eve.Node] ) -> set[str]: if not isinstance(tree, Iterable): tree = [tree] @@ -531,8 +532,8 @@ def check_el_type(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> Expr: @staticmethod def _merge_scans( - executions: list[Union[StencilExecution, ScanExecution]], - ) -> list[Union[StencilExecution, ScanExecution]]: + executions: list[StencilExecution | ScanExecution], + ) -> list[StencilExecution | ScanExecution]: def merge(a: ScanExecution, b: ScanExecution) -> ScanExecution: assert a.backend == b.backend assert a.axis == b.axis @@ -586,7 +587,7 @@ def visit_IfStmt(self, node: itir.IfStmt, **kwargs: Any) -> IfStmt: def 
visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any - ) -> Union[StencilExecution, ScanExecution]: + ) -> StencilExecution | ScanExecution: if _is_tuple_of_ref_or_literal(node.expr): node.expr = im.as_fieldop("deref", node.domain)(node.expr) @@ -661,7 +662,7 @@ def convert_el_to_sid(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> E ) def visit_Program(self, node: itir.Program, **kwargs: Any) -> Program: - extracted_functions: list[Union[FunctionDefinition, ScanPassDefinition]] = [] + extracted_functions: list[FunctionDefinition | ScanPassDefinition] = [] executions = self.visit(node.body, extracted_functions=extracted_functions) executions = self._merge_scans(executions) function_definitions = self.visit(node.function_definitions) + extracted_functions diff --git a/src/gt4py/next/program_processors/program_formatter.py b/src/gt4py/next/program_processors/program_formatter.py index 321c09668c..67936e609c 100644 --- a/src/gt4py/next/program_processors/program_formatter.py +++ b/src/gt4py/next/program_processors/program_formatter.py @@ -23,7 +23,8 @@ import abc import dataclasses -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import gt4py.next.iterator.ir as itir diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index 889cb0f800..e233a284cd 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -10,20 +10,8 @@ import abc import dataclasses -from typing import ( - Any, - Dict, - Final, - Iterable, - List, - Optional, - Protocol, - Sequence, - Set, - Tuple, - TypeAlias, - Union, -) +from collections.abc import Iterable, Sequence +from typing import Any, Final, Optional, Protocol, TypeAlias import dace from dace import subsets as dace_subsets @@ -397,12 +385,9 @@ def _add_edge( def _add_map( self, 
name: str, - ndrange: Union[ - Dict[str, Union[str, dace.subsets.Subset]], - List[Tuple[str, Union[str, dace.subsets.Subset]]], - ], + ndrange: dict[str, str | dace.subsets.Subset] | list[tuple[str, str | dace.subsets.Subset]], **kwargs: Any, - ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: + ) -> tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: """ Helper method to add a map in current state. @@ -414,8 +399,8 @@ def _add_map( def _add_tasklet( self, name: str, - inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], - outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + inputs: set[str] | dict[str, dace.dtypes.typeclass], + outputs: set[str] | dict[str, dace.dtypes.typeclass], code: str, **kwargs: Any, ) -> dace.nodes.Tasklet: @@ -439,11 +424,11 @@ def _add_tasklet( def _add_mapped_tasklet( self, name: str, - map_ranges: Dict[str, str | dace.subsets.Subset] - | List[Tuple[str, str | dace.subsets.Subset]], - inputs: Dict[str, dace.Memlet], + map_ranges: dict[str, str | dace.subsets.Subset] + | list[tuple[str, str | dace.subsets.Subset]], + inputs: dict[str, dace.Memlet], code: str, - outputs: Dict[str, dace.Memlet], + outputs: dict[str, dace.Memlet], **kwargs: Any, ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit]: """ @@ -1674,7 +1659,7 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: builtin_name, set(node_connections.keys()), {out_connector}, - "{} = {}".format(out_connector, code), + f"{out_connector} = {code}", ) for connector, arg_expr in node_connections.items(): diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_domain.py b/src/gt4py/next/program_processors/runners/dace/gtir_domain.py index 4582c98bcb..a5d8bb9812 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_domain.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_domain.py @@ -9,7 +9,8 @@ from __future__ import annotations import dataclasses -from typing import Optional, Sequence, TypeAlias +from 
collections.abc import Sequence +from typing import Optional, TypeAlias import dace import sympy diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py index 1219262f51..5219b4f3f2 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py @@ -8,7 +8,8 @@ from __future__ import annotations -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import numpy as np import sympy diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py index b399b5d8aa..c7fda10a30 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py @@ -17,19 +17,8 @@ import abc import dataclasses import itertools -from typing import ( - Any, - Dict, - Iterable, - List, - Mapping, - Optional, - Protocol, - Sequence, - Set, - Tuple, - Union, -) +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Optional, Protocol import dace @@ -90,12 +79,9 @@ def add_map( self, name: str, state: dace.SDFGState, - ndrange: Union[ - Dict[str, Union[str, dace.subsets.Subset]], - List[Tuple[str, Union[str, dace.subsets.Subset]]], - ], + ndrange: dict[str, str | dace.subsets.Subset] | list[tuple[str, str | dace.subsets.Subset]], **kwargs: Any, - ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: + ) -> tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: """Wrapper of `dace.SDFGState.add_map` that assigns unique name.""" unique_name = self.unique_map_name(name) return state.add_map(unique_name, ndrange, **kwargs) @@ -104,8 +90,8 @@ def add_tasklet( self, name: str, state: dace.SDFGState, - inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], - outputs: Union[Set[str], Dict[str, 
dace.dtypes.typeclass]], + inputs: set[str] | dict[str, dace.dtypes.typeclass], + outputs: set[str] | dict[str, dace.dtypes.typeclass], code: str, **kwargs: Any, ) -> dace.nodes.Tasklet: @@ -117,11 +103,11 @@ def add_mapped_tasklet( self, name: str, state: dace.SDFGState, - map_ranges: Dict[str, str | dace.subsets.Subset] - | List[Tuple[str, str | dace.subsets.Subset]], - inputs: Dict[str, dace.Memlet], + map_ranges: dict[str, str | dace.subsets.Subset] + | list[tuple[str, str | dace.subsets.Subset]], + inputs: dict[str, dace.Memlet], code: str, - outputs: Dict[str, dace.Memlet], + outputs: dict[str, dace.Memlet], **kwargs: Any, ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit]: """Wrapper of `dace.SDFGState.add_mapped_tasklet` that assigns unique name.""" @@ -417,10 +403,7 @@ def _make_array_shape_and_strides( # expression of domain range 'stop - start' shape.append( dace.symbolic.pystr_to_symbolic( - "{} - {}".format( - gtx_dace_utils.range_stop_symbol(name, i), - gtx_dace_utils.range_start_symbol(name, i), - ) + f"{gtx_dace_utils.range_stop_symbol(name, i)} - {gtx_dace_utils.range_start_symbol(name, i)}" ) ) strides = [ diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py index 1a1d056f1b..d7666c4509 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py @@ -9,7 +9,8 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Optional, Protocol import dace from dace import subsets as dace_subsets diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py index 25e70f4844..7fe2794e27 100644 --- 
a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py @@ -23,7 +23,8 @@ from __future__ import annotations import itertools -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any import dace from dace import subsets as dace_subsets diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py index 8bd708b298..7b8cc71b9e 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Dict, Optional, TypeVar +from typing import Optional, TypeVar import dace @@ -129,14 +129,14 @@ class ReplaceSymbols(eve.PreserveLocationVisitor, eve.NodeTranslator): T = TypeVar("T", gtir.Sym, gtir.SymRef) - def _replace_sym(self, node: T, symtable: Dict[str, str]) -> T: + def _replace_sym(self, node: T, symtable: dict[str, str]) -> T: sym = str(node.id) return type(node)(id=symtable.get(sym, sym), type=node.type) - def visit_Sym(self, node: gtir.Sym, *, symtable: Dict[str, str]) -> gtir.Sym: + def visit_Sym(self, node: gtir.Sym, *, symtable: dict[str, str]) -> gtir.Sym: return self._replace_sym(node, symtable) - def visit_SymRef(self, node: gtir.SymRef, *, symtable: Dict[str, str]) -> gtir.SymRef: + def visit_SymRef(self, node: gtir.SymRef, *, symtable: dict[str, str]) -> gtir.SymRef: return self._replace_sym(node, symtable) # program arguments are checked separetely, because they cannot be replaced diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index 1109278444..b339a30321 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -10,7 +10,8 
@@ import dataclasses import itertools import typing -from typing import Any, ClassVar, Optional, Sequence +from collections.abc import Sequence +from typing import Any, ClassVar, Optional import dace import numpy as np diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index f0f27a091c..9095155810 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -8,7 +8,8 @@ """Fast access to the auto optimization on DaCe.""" -from typing import Any, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Optional import dace from dace import data as dace_data diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/dead_dataflow_elimination.py b/src/gt4py/next/program_processors/runners/dace/transformations/dead_dataflow_elimination.py index 80c6895892..fbabdda6bd 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/dead_dataflow_elimination.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/dead_dataflow_elimination.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any, Optional import dace from dace import properties as dace_properties, transformation as dace_transformation diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index fcb235b94c..fd1e51418e 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -11,7 +11,8 @@ from __future__ import annotations import copy -from typing import Any, Callable, Final, Optional, Sequence, Union +from collections.abc import Callable, Sequence +from typing import Any, Final, Optional import dace from dace import ( @@ -174,7 +175,7 @@ def restrict_fusion_to_newly_created_maps_vertical( self: gtx_transformations.MapFusionVertical, map_exit_1: dace_nodes.MapExit, map_entry_2: dace_nodes.MapEntry, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, sdfg: dace.SDFG, ) -> bool: return (map_entry_2 in new_maps) or (graph.entry_node(map_exit_1) in new_maps) @@ -183,7 +184,7 @@ def restrict_fusion_to_newly_created_maps_horizontal( self: gtx_transformations.MapFusionHorizontal, map_entry_1: dace_nodes.MapEntry, map_entry_2: dace_nodes.MapEntry, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, sdfg: dace.SDFG, ) -> bool: return (map_entry_1 in new_maps) or (map_entry_2 in new_maps) @@ -392,7 +393,7 @@ def gt_set_gpu_blocksize( configured_maps = 0 for state in sdfg.states(): - scope_dict: Union[None, dict[Any, Any]] = None + scope_dict: None | dict[Any, Any] = None cfg_id = state.parent_graph.cfg_id state_id = state.block_id for node in state.nodes(): @@ -425,7 +426,7 @@ def gt_set_gpu_blocksize( def _make_gpu_block_parser_for( dim: int, -) -> Callable[["GPUSetBlockSize", 
Any], None]: +) -> Callable[[GPUSetBlockSize, Any], None]: """Generates a parser for GPU blocks for dimension `dim`. The returned function can be used as parser for the `GPUSetBlockSize.block_size_*d` @@ -473,11 +474,11 @@ def _gpu_block_parser( def _make_gpu_block_getter_for( dim: int, -) -> Callable[["GPUSetBlockSize"], tuple[int, int, int]]: +) -> Callable[[GPUSetBlockSize], tuple[int, int, int]]: """Makes the getter for the block size of dimension `dim`.""" def _gpu_block_getter( - self: "GPUSetBlockSize", + self: GPUSetBlockSize, ) -> tuple[int, int, int]: """Used as getter in the `GPUSetBlockSize.block_size` property.""" return getattr(self, f"_block_size_{dim}d") @@ -619,7 +620,7 @@ def expressions(cls) -> Any: def can_be_applied( self, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, expr_index: int, sdfg: dace.SDFG, permissive: bool = False, @@ -642,7 +643,7 @@ def can_be_applied( def apply( self, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, sdfg: dace.SDFG, ) -> None: """Modify the map as requested.""" @@ -721,10 +722,10 @@ def gt_remove_trivial_gpu_maps( # Now we try to fuse them together, however, we restrict the fusion to trivial # GPU map. 
def restrict_to_trivial_gpu_maps( - self: Union[gtx_transformations.MapFusionVertical, gtx_transformations.MapFusionHorizontal], - map_node_1: Union[dace_nodes.MapEntry, dace_nodes.MapExit], + self: gtx_transformations.MapFusionVertical | gtx_transformations.MapFusionHorizontal, + map_node_1: dace_nodes.MapEntry | dace_nodes.MapExit, map_entry_2: dace_nodes.MapEntry, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, sdfg: dace.SDFG, ) -> bool: map_entry_1 = ( @@ -829,7 +830,7 @@ def expressions(cls) -> Any: def can_be_applied( self, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, expr_index: int, sdfg: dace.SDFG, permissive: bool = False, @@ -907,7 +908,7 @@ def can_be_applied( return True - def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + def apply(self, graph: dace.SDFGState | dace.SDFG, sdfg: dace.SDFG) -> None: """Performs the Map Promoting. The function will first perform the promotion of the trivial map and then diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py index c9fb595390..ed0d54afc6 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import copy -from typing import Any, Optional, Union +from typing import Any, Optional import dace from dace import ( @@ -79,7 +79,7 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): def __init__( self, blocking_size: Optional[int] = None, - blocking_parameter: Optional[Union[gtx_common.Dimension, str]] = None, + blocking_parameter: Optional[gtx_common.Dimension | str] = None, require_independent_nodes: Optional[bool] = None, ) -> None: super().__init__() @@ -100,7 +100,7 @@ def expressions(cls) -> Any: 
def can_be_applied( self, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, expr_index: int, sdfg: dace.SDFG, permissive: bool = False, @@ -150,7 +150,7 @@ def can_be_applied( def apply( self, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, sdfg: dace.SDFG, ) -> None: """Creates a blocking map. diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py index 3e8746fbba..cdded21120 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py @@ -9,7 +9,8 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from typing import Any, Callable, Optional, TypeAlias, Union +from collections.abc import Callable +from typing import Any, Optional, TypeAlias import dace from dace import nodes as dace_nodes, properties as dace_properties @@ -88,7 +89,7 @@ def __init__( def can_be_applied( self, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, expr_index: int, sdfg: dace.SDFG, permissive: bool = False, @@ -127,7 +128,7 @@ def __init__( def can_be_applied( self, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, expr_index: int, sdfg: dace.SDFG, permissive: bool = False, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_extended.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_extended.py index 863456cdab..dfbb7d767d 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_extended.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_extended.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Union +from typing import Any, Optional import dace from 
dace import ( @@ -161,7 +161,7 @@ def __init__( def split_maps( self, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, sdfg: dace.SDFG, first_map_entry: dace_nodes.MapEntry, first_map_exit: dace_nodes.MapExit, @@ -295,7 +295,7 @@ def can_be_applied( return False # Test if the map is in the right scope. - map_scope: Union[dace_nodes.Node, None] = scope_dict[self.first_map_entry] + map_scope: dace_nodes.Node | None = scope_dict[self.first_map_entry] if self.only_toplevel_maps and (map_scope is not None): return False @@ -370,7 +370,7 @@ def apply( # would need to obtain the scope dict after every iteration again, which would require # a rescan. But since the operations do not alter the scopes, at least not in a way that # would affect us, we can be more efficient and get the thing at the beginning. - scope_dict: Dict = graph.scope_dict() + scope_dict: dict = graph.scope_dict() for first_map_fragment_entry, second_map_fragment_entry in matched_map_fragments: first_map_fragment_exit = graph.exit_node(first_map_fragment_entry) @@ -446,7 +446,7 @@ def can_be_applied( second_map = self.second_map_entry.map # Test if the map is in the right scope. 
- map_scope: Union[dace_nodes.Node, None] = graph.scope_dict()[first_map_entry] + map_scope: dace_nodes.Node | None = graph.scope_dict()[first_map_entry] if self.only_toplevel_maps and (map_scope is not None): return False @@ -481,7 +481,7 @@ def can_be_applied( return True - def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + def apply(self, graph: dace.SDFGState | dace.SDFG, sdfg: dace.SDFG) -> None: """Split the map range in order to obtain an overlapping range between the first and second map.""" first_map_entry: dace_nodes.MapEntry = graph.entry_node(self.first_map_exit) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py index c018542c46..6625131a3a 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Optional, Union +from typing import Any, Optional import dace from dace import properties as dace_properties, transformation as dace_transformation @@ -19,7 +19,7 @@ def gt_set_iteration_order( sdfg: dace.SDFG, unit_strides_dim: Optional[ - Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + str | gtx_common.Dimension | list[str | gtx_common.Dimension] ] = None, unit_strides_kind: Optional[gtx_common.DimensionKind] = None, validate: bool = True, @@ -115,7 +115,7 @@ class MapIterationOrder(dace_transformation.SingleStateTransformation): def __init__( self, unit_strides_dims: Optional[ - Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + str | gtx_common.Dimension | list[str | gtx_common.Dimension] ] = None, unit_strides_kind: Optional[gtx_common.DimensionKind] = None, *args: Any, @@ -144,7 +144,7 @@ def expressions(cls) -> Any: def can_be_applied( self, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, expr_index: int, sdfg: dace.SDFG, permissive: bool = False, @@ -160,7 +160,7 @@ def can_be_applied( def apply( self, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, sdfg: dace.SDFG, ) -> None: """Performs the actual parameter reordering. 
diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py index 0a1934a46c..a4a47d0fba 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py @@ -8,7 +8,8 @@ import copy import warnings -from typing import Any, Callable, Mapping, Optional, TypeAlias, Union +from collections.abc import Callable, Mapping +from typing import Any, Optional, TypeAlias import dace from dace import ( @@ -218,7 +219,7 @@ def __init__( def can_be_applied( self, - graph: Union[dace.SDFGState, dace.SDFG], + graph: dace.SDFGState | dace.SDFG, expr_index: int, sdfg: dace.SDFG, permissive: bool = False, @@ -227,7 +228,7 @@ def can_be_applied( second_map_entry: dace_nodes.MapEntry = self.entry_second_map if self.only_inner_maps or self.only_toplevel_maps: - scope_dict: Mapping[dace_nodes.Node, Union[dace_nodes.Node, None]] = graph.scope_dict() + scope_dict: Mapping[dace_nodes.Node, dace_nodes.Node | None] = graph.scope_dict() if self.only_inner_maps and (scope_dict[second_map_entry] is None): return False if self.only_toplevel_maps and (scope_dict[second_map_entry] is not None): @@ -278,10 +279,8 @@ def can_be_applied( # According to [issue#2095](https://github.com/spcl/dace/issues/2095) DaCe is quite # liberal concerning the positivity assumption, but in GT4Py this is not possible. second_map_iterations = second_map_iterations.subs( - ( - (sym, dace.symbol(sym.name, nonnegative=False)) - for sym in list(second_map_iterations.free_symbols) - ) + (sym, dace.symbol(sym.name, nonnegative=False)) + for sym in list(second_map_iterations.free_symbols) ) if (second_map_iterations > 0) != True: # noqa: E712 [true-false-comparison] # SymPy fuzzy bools. 
return False @@ -313,7 +312,7 @@ def _promote_first_map( first_map_exit.map.params = copy.deepcopy(second_map_entry.map.params) first_map_exit.map.range = copy.deepcopy(second_map_entry.map.range) - def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + def apply(self, graph: dace.SDFGState | dace.SDFG, sdfg: dace.SDFG) -> None: first_map_exit: dace_nodes.MapExit = self.exit_first_map access_node: dace_nodes.AccessNode = self.access_node second_map_entry: dace_nodes.MapEntry = self.entry_second_map diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/multi_state_global_self_copy_elimination.py b/src/gt4py/next/program_processors/runners/dace/transformations/multi_state_global_self_copy_elimination.py index 7854299780..5cc4f6394a 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/multi_state_global_self_copy_elimination.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/multi_state_global_self_copy_elimination.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Optional, TypeAlias, Union +from typing import Any, Optional, TypeAlias import dace from dace import ( @@ -623,7 +623,7 @@ def _filter_candidate( transient_data: str, write_locations: list[AccessLocation], read_locations: list[AccessLocation], - ) -> Union[None, str]: + ) -> None | str: """Test if the transient can be eliminated. The function tests if transient data can be eliminated and be replaced by a @@ -657,7 +657,7 @@ def _filter_candidate( # TODO(phimuell): To better handle `concat_where` also allow multiple producers. # TODO(phimuell): In `concat_where` we are using `dynamic` Memlets, they should # also be checked. 
- global_data: Union[None, str] = None + global_data: None | str = None for state, transient_access_node in write_locations: for iedge in state.in_edges(transient_access_node): src_node = iedge.src @@ -700,7 +700,7 @@ def _find_exclusive_read_and_write_locations_of( self, sdfg: dace.SDFG, data_name: str, - ) -> Union[None, tuple[list[AccessLocation], list[AccessLocation]]]: + ) -> None | tuple[list[AccessLocation], list[AccessLocation]]: """The function finds all locations were `data_name` is written and read. The function will scan the SDFG and returns all places where `data_name` is diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py index 4db1f293dd..f7b8f150dd 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Optional import dace from dace import ( diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py index 69332b3a72..fa89cfcdbd 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py @@ -11,7 +11,8 @@ import collections import copy import uuid -from typing import Any, Iterable, Optional, TypeAlias +from collections.abc import Iterable +from typing import Any, Optional, TypeAlias import dace from dace import ( @@ -775,7 +776,7 @@ def _check_read_write_dependency_impl( if dnode.data != global_data_name: continue dnode_degree = sum( - (1 for oedge in state_to_inspect.out_edges(dnode) if not oedge.data.is_empty()) + 1 for oedge in state_to_inspect.out_edges(dnode) if not oedge.data.is_empty() ) if dnode_degree > 1: return True diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/single_state_global_self_copy_elimination.py b/src/gt4py/next/program_processors/runners/dace/transformations/single_state_global_self_copy_elimination.py index d5ab08946c..74b64722ed 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/single_state_global_self_copy_elimination.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/single_state_global_self_copy_elimination.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Final, Literal, Optional, Sequence, Union, overload +from collections.abc import Sequence +from typing import Any, Final, Literal, Optional, overload import dace from dace import ( @@ -562,7 +563,7 @@ def _compute_tmp_to_g_mapping( node_tmp: dace_nodes.AccessNode, node_g2: dace_nodes.AccessNode, check_only: Literal[True], - ) -> Union[None, dict[dace_sbs.Subset, dace_sbs.Subset]]: ... + ) -> None | dict[dace_sbs.Subset, dace_sbs.Subset]: ... @overload def _compute_tmp_to_g_mapping( @@ -581,7 +582,7 @@ def _compute_tmp_to_g_mapping( node_tmp: dace_nodes.AccessNode, node_g2: dace_nodes.AccessNode, check_only: bool, - ) -> Union[None, dict[dace_sbs.Subset, dace_sbs.Subset]]: + ) -> None | dict[dace_sbs.Subset, dace_sbs.Subset]: """Computes a mapping that describes how `tmp` maps into `g`. The function returns a `dict`, that maps subsets of the `tmp_node` @@ -737,7 +738,7 @@ def _check_merging_strategy( node_g1: dace_nodes.AccessNode, node_tmp: dace_nodes.AccessNode, node_g2: dace_nodes.AccessNode, - ) -> Union[None, int]: + ) -> None | int: """Tests which merging strategy should be used. By default the transformation tries to merge the three nodes together, @@ -804,7 +805,7 @@ def _compute_offset( tdesc: gtx_dace_split.EdgeConnectionSpec, tmp_to_g_mapping: dict[dace_sbs.Subset, dace_sbs.Subset], check_only: bool, - ) -> Union[list[dace_sym.SymbolicType], bool]: + ) -> list[dace_sym.SymbolicType] | bool: """Computes the offset to turn a subset described in terms of `tmp` into a `g`. 
`tdesc` describes an edge that interacts with `tmp`, the function returns diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/split_access_nodes.py b/src/gt4py/next/program_processors/runners/dace/transformations/split_access_nodes.py index 1fb4cda322..ebee830e32 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/split_access_nodes.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/split_access_nodes.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings -from typing import Any, Iterable, Optional +from collections.abc import Iterable +from typing import Any, Optional import dace from dace import properties as dace_properties, transformation as dace_transformation @@ -251,7 +252,7 @@ def _find_edge_reassignment( # AccessNode has to be kept alive. warnings.warn( "'SplitAccessNode': found producers " - + ", ".join((str(p) for p in unused_producers)) + + ", ".join(str(p) for p in unused_producers) + " that generates data but that is never read.", stacklevel=0, ) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py b/src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py index 9c7273a2b8..9d4cbe429c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py @@ -8,7 +8,8 @@ import copy import dataclasses -from typing import Any, Iterable, Optional, Sequence, Union +from collections.abc import Iterable, Sequence +from typing import Any, Optional import dace from dace import data as dace_data, subsets as dace_sbs, symbolic as dace_sym @@ -307,7 +308,7 @@ def split_edge( sdfg: dace.SDFG, edge_to_split: dace_graph.MultiConnectorEdge, split_description: Sequence[dace_sbs.Subset], -) -> dict[Union[dace_sbs.Range, None], set[dace_graph.MultiConnectorEdge]]: +) -> dict[dace_sbs.Range | None, 
set[dace_graph.MultiConnectorEdge]]: """Tries to split `edge_to_split` into multiple edges. How the edge is split is described by `split_description`, which is a @@ -366,7 +367,7 @@ def split_edge( new_fully_splitted_subsets.append(consumer) fully_splitted_subsets = new_fully_splitted_subsets - new_edges: dict[Union[dace_sbs.Range, None], dace_graph.MultiConnectorEdge] = { + new_edges: dict[dace_sbs.Range | None, dace_graph.MultiConnectorEdge] = { split: set() for split in split_description } new_edges[None] = set() @@ -413,7 +414,7 @@ def split_edge( def decompose_subset( producer: dace_sbs.Subset, consumer: dace_sbs.Subset, -) -> Union[list[dace_sbs.Subset], None]: +) -> list[dace_sbs.Subset] | None: """ Decompose `consumer` into pieces either covered by `producer` or have no intersection. @@ -544,7 +545,7 @@ def decompose_subset( def subset_merger( - subsets: Union[Sequence[EdgeConnectionSpec], Sequence[dace_sbs.Subset]], + subsets: Sequence[EdgeConnectionSpec] | Sequence[dace_sbs.Subset], ) -> list[dace_sbs.Subset]: """Merges subsets together. @@ -615,7 +616,7 @@ def _subset_merger_impl( def _try_to_merge_subsets( subset1: dace_sbs.Subset, subset2: dace_sbs.Subset, -) -> Union[None, dace_sbs.Subset]: +) -> None | dace_sbs.Subset: """Tries to merge the subsets together, it it is impossible return `None`. 
Two subset can only be merged if they have the same bounds in all but one diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index 0561bde568..d8923359e4 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -8,7 +8,8 @@ """Common functionality for the transformations/optimization pipeline.""" -from typing import Any, Container, Optional, Sequence, TypeVar, Union +from collections.abc import Container, Sequence +from typing import Any, Optional, TypeVar import dace from dace import data as dace_data, libraries as dace_lib, subsets as dace_sbs, symbolic as dace_sym @@ -237,8 +238,8 @@ def is_accessed_downstream( def is_reachable( - start: Union[dace_nodes.Node, Sequence[dace_nodes.Node]], - target: Union[dace_nodes.Node, Sequence[dace_nodes.Node]], + start: dace_nodes.Node | Sequence[dace_nodes.Node], + target: dace_nodes.Node | Sequence[dace_nodes.Node], state: dace.SDFGState, ) -> bool: """Explores the graph from `start` and checks if `target` is reachable. @@ -268,7 +269,7 @@ def is_reachable( def is_source_node_of( sink: dace_nodes.Node, - possible_sources: Union[dace_nodes.Node, Sequence[dace_nodes.Node]], + possible_sources: dace_nodes.Node | Sequence[dace_nodes.Node], state: dace.SDFGState, ) -> bool: """Explores the graph and checks if `possible_sources` produce data for `sink`. 
@@ -303,7 +304,7 @@ def is_source_node_of( def is_view( - node: Union[dace_nodes.AccessNode, dace_data.Data], + node: dace_nodes.AccessNode | dace_data.Data, sdfg: Optional[dace.SDFG] = None, ) -> bool: """Tests if `node` points to a view or not.""" @@ -643,7 +644,7 @@ def find_successor_state(state: dace.SDFGState) -> list[dace.SDFGState]: """ def _impl( - state: Union[dace.SDFGState, dace.sdfg.state.AbstractControlFlowRegion], + state: dace.SDFGState | dace.sdfg.state.AbstractControlFlowRegion, graph: dace.sdfg.state.AbstractControlFlowRegion, ) -> list[dace.sdfg.state.AbstractControlFlowRegion]: assert state is not graph diff --git a/src/gt4py/next/program_processors/runners/dace/utils.py b/src/gt4py/next/program_processors/runners/dace/utils.py index 01792e4fab..41767e601c 100644 --- a/src/gt4py/next/program_processors/runners/dace/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/utils.py @@ -9,7 +9,8 @@ from __future__ import annotations import re -from typing import Final, Literal, Mapping, Optional, Union +from collections.abc import Mapping +from typing import Final, Literal, Optional import dace @@ -131,9 +132,7 @@ def filter_connectivity_types( def safe_replace_symbolic( val: dace.symbolic.SymbolicType, - symbol_mapping: Mapping[ - Union[dace.symbolic.SymbolicType, str], Union[dace.symbolic.SymbolicType, str] - ], + symbol_mapping: Mapping[dace.symbolic.SymbolicType | str, dace.symbolic.SymbolicType | str], ) -> dace.symbolic.SymbolicType: """ Replace free symbols in a dace symbolic expression, using `safe_replace()` diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py b/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py index bf5449bcd6..1880504ed2 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py @@ -253,13 +253,7 @@ def {_cb_get_stride}(ndarray, dim_index): """) code.empty_line() code.append( 
- "def {funname}({arg0}, {arg1}, {arg2}, {arg3}):".format( - funname=bind_func_name, - arg0=_cb_device, - arg1=_cb_sdfg_argtypes, - arg2=_cb_args, - arg3=_cb_last_call_args, - ) + f"def {bind_func_name}({_cb_device}, {_cb_sdfg_argtypes}, {_cb_args}, {_cb_last_call_args}):" ) code.indent() for i, param in enumerate(program_source.entry_point.parameters): diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/common.py b/src/gt4py/next/program_processors/runners/dace/workflow/common.py index e9915f8633..0cff55fd7d 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/common.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/common.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import contextlib -from typing import Any, Generator, Optional +from collections.abc import Generator +from typing import Any, Optional import dace diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index ed36ef1947..1a080ca21b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -12,7 +12,8 @@ import importlib import os import pathlib -from typing import Any, Callable, Sequence +from collections.abc import Callable, Sequence +from typing import Any import dace import factory diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 4eb8b12ed9..9082ff465c 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -9,7 +9,8 @@ from __future__ import annotations import functools -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import dace diff --git 
a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 30e8a5da14..5413191ee2 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -8,19 +8,8 @@ import functools import types -from collections.abc import Callable, Iterator -from typing import ( - Any, - Generic, - Literal, - Protocol, - Sequence, - Type, - TypeGuard, - TypeVar, - cast, - overload, -) +from collections.abc import Callable, Iterator, Sequence +from typing import Any, Generic, Literal, Protocol, TypeGuard, TypeVar, cast, overload import numpy as np @@ -64,7 +53,7 @@ def is_concrete(symbol_type: ts.TypeSpec) -> TypeGuard[ts.TypeSpec]: return False -def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: +def type_class(symbol_type: ts.TypeSpec) -> type[ts.TypeSpec]: """ Determine which class should be used to create a compatible concrete type. diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 23737d8ba1..5f20efb36f 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Iterator, Optional, Sequence, Union +from collections.abc import Iterator, Sequence +from typing import Optional from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types from gt4py.next import common @@ -145,7 +146,7 @@ class FunctionType(TypeSpec, CallableType): pos_only_args: Sequence[TypeSpec] pos_or_kw_args: dict[str, TypeSpec] kw_only_args: dict[str, TypeSpec] - returns: Union[TypeSpec] + returns: TypeSpec def __str__(self) -> str: arg_strs = [str(arg) for arg in self.pos_only_args] diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 16f3629c92..c09b05f036 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -8,7 +8,8 @@ import functools import itertools -from typing import Any, Callable, ClassVar, Optional, ParamSpec, TypeGuard, TypeVar, cast, overload +from collections.abc import Callable +from typing import Any, ClassVar, Optional, ParamSpec, TypeGuard, TypeVar, cast, overload class RecursionGuard: diff --git a/src/gt4py/storage/allocators.py b/src/gt4py/storage/allocators.py index 1499acbc8b..085c04b360 100644 --- a/src/gt4py/storage/allocators.py +++ b/src/gt4py/storage/allocators.py @@ -34,7 +34,6 @@ Type, TypeAlias, TypeGuard, - Union, ) @@ -44,10 +43,7 @@ cp = None -_NDBuffer: TypeAlias = Union[ - # TODO(): add `xtyping.Buffer` once we update typing_extensions - xtyping.ArrayInterface, xtyping.CUDAArrayInterface, xtyping.DLPackBuffer -] +_NDBuffer: TypeAlias = xtyping.ArrayInterface | xtyping.CUDAArrayInterface | xtyping.DLPackBuffer #: Tuple of positive integers encoding a permutation of the dimensions, such that #: layout_map[i] = j means that the i-th dimension of the tensor corresponds diff --git a/src/gt4py/storage/cartesian/interface.py b/src/gt4py/storage/cartesian/interface.py index 8b38bcdd42..f865584a83 100644 --- a/src/gt4py/storage/cartesian/interface.py +++ b/src/gt4py/storage/cartesian/interface.py @@ -9,7 +9,8 @@ from 
__future__ import annotations import numbers -from typing import Optional, Sequence, Union +from collections.abc import Sequence +from typing import Optional import numpy as np @@ -44,7 +45,7 @@ def empty( backend: str, aligned_index: Optional[Sequence[int]] = None, dimensions: Optional[Sequence[str]] = None, -) -> Union[np.ndarray, "cp.ndarray"]: +) -> np.ndarray | cp.ndarray: """Allocate an array of uninitialized (undefined) values with performance-optimal strides and alignment. Parameters @@ -109,7 +110,7 @@ def ones( backend: str, aligned_index: Optional[Sequence[int]] = None, dimensions: Optional[Sequence[str]] = None, -) -> Union[np.ndarray, "cp.ndarray"]: +) -> np.ndarray | cp.ndarray: """Allocate an array with values initialized to 1.0 with performance-optimal strides and alignment. Parameters @@ -162,7 +163,7 @@ def full( backend: str, aligned_index: Optional[Sequence[int]] = None, dimensions: Optional[Sequence[str]] = None, -) -> Union[np.ndarray, "cp.ndarray"]: +) -> np.ndarray | cp.ndarray: """Allocate an array with values initialized to `fill_value` with performance-optimal strides and alignment. Parameters @@ -216,7 +217,7 @@ def zeros( backend: str, aligned_index: Optional[Sequence[int]] = None, dimensions: Optional[Sequence[str]] = None, -) -> Union[np.ndarray, "cp.ndarray"]: +) -> np.ndarray | cp.ndarray: """Allocate an array with values initialized to 0.0 with performance-optimal strides and alignment. Parameters @@ -268,7 +269,7 @@ def from_array( backend: str, aligned_index: Optional[Sequence[int]] = None, dimensions: Optional[Sequence[str]] = None, -) -> Union[np.ndarray, "cp.ndarray"]: +) -> np.ndarray | cp.ndarray: """Allocate an array with values initialized to those of `data` with performance-optimal strides and alignment. This copies the values from `data` to the resulting buffer. 
diff --git a/src/gt4py/storage/cartesian/layout.py b/src/gt4py/storage/cartesian/layout.py index af5dc0a2ba..35040f2458 100644 --- a/src/gt4py/storage/cartesian/layout.py +++ b/src/gt4py/storage/cartesian/layout.py @@ -6,19 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Final, - Literal, - Optional, - Sequence, - Tuple, - TypedDict, - Union, -) +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, Final, Literal, Optional, TypedDict, Union import numpy as np @@ -33,11 +22,11 @@ class LayoutInfo(TypedDict): alignment: int # measured in bytes device: Literal["cpu", "gpu"] - layout_map: Callable[[Tuple[str, ...]], Tuple[int, ...]] - is_optimal_layout: Callable[[Any, Tuple[str, ...]], bool] + layout_map: Callable[[tuple[str, ...]], tuple[int, ...]] + is_optimal_layout: Callable[[Any, tuple[str, ...]], bool] -REGISTRY: Dict[str, LayoutInfo] = {} +REGISTRY: dict[str, LayoutInfo] = {} def from_name(name: str) -> Optional[LayoutInfo]: @@ -66,9 +55,9 @@ def check_layout(layout_map, strides): def layout_maker_factory( - base_layout: Tuple[int, ...], -) -> Callable[[Tuple[str, ...]], Tuple[int, ...]]: - def layout_maker(dimensions: Tuple[str, ...]) -> Tuple[int, ...]: + base_layout: tuple[int, ...], +) -> Callable[[tuple[str, ...]], tuple[int, ...]]: + def layout_maker(dimensions: tuple[str, ...]) -> tuple[int, ...]: mask = [dim in dimensions for dim in "IJK"] mask += [True] * (len(dimensions) - sum(mask)) ranks = [] @@ -90,7 +79,7 @@ def layout_maker(dimensions: Tuple[str, ...]) -> Tuple[int, ...]: def layout_checker_factory(layout_maker): - def layout_checker(field: Union[np.ndarray, "cp.ndarray"], dimensions: Tuple[str, ...]) -> bool: + def layout_checker(field: Union[np.ndarray, "cp.ndarray"], dimensions: tuple[str, ...]) -> bool: layout_map = layout_maker(dimensions) return check_layout(layout_map, 
field.strides) @@ -98,8 +87,8 @@ def layout_checker(field: Union[np.ndarray, "cp.ndarray"], dimensions: Tuple[str def _permute_layout_to_dimensions( - layout: Sequence[int], dimensions: Tuple[str, ...] -) -> Tuple[int, ...]: + layout: Sequence[int], dimensions: tuple[str, ...] +) -> tuple[int, ...]: data_dims = [int(d) for d in dimensions if d.isdigit()] canonical_dimensions = [d for d in "IJK" if d in dimensions] + [ str(d) for d in sorted(data_dims) @@ -110,14 +99,14 @@ def _permute_layout_to_dimensions( return tuple(res_layout) -def make_gtcpu_kfirst_layout_map(dimensions: Tuple[str, ...]) -> Tuple[int, ...]: +def make_gtcpu_kfirst_layout_map(dimensions: tuple[str, ...]) -> tuple[int, ...]: layout = [i for i in range(len(dimensions))] naxes = sum(dim in dimensions for dim in "IJK") layout = [*layout[-naxes:], *layout[:-naxes]] return _permute_layout_to_dimensions([lt for lt in layout if lt is not None], dimensions) -def make_gtcpu_ifirst_layout_map(dimensions: Tuple[str, ...]) -> Tuple[int, ...]: +def make_gtcpu_ifirst_layout_map(dimensions: tuple[str, ...]) -> tuple[int, ...]: ctr = reversed(range(len(dimensions))) layout = [next(ctr) for dim in "IJK" if dim in dimensions] + list(ctr) if "K" in dimensions and "J" in dimensions: @@ -128,7 +117,7 @@ def make_gtcpu_ifirst_layout_map(dimensions: Tuple[str, ...]) -> Tuple[int, ...] 
return _permute_layout_to_dimensions(layout, dimensions) -def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[int, ...]: +def make_cuda_layout_map(dimensions: tuple[str, ...]) -> tuple[int, ...]: layout = tuple(reversed(range(len(dimensions)))) return _permute_layout_to_dimensions(layout, dimensions) diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 2275c1cd57..f964f46803 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -12,7 +12,8 @@ import functools import math import numbers -from typing import Literal, Optional, Sequence, Tuple, Union, cast +from collections.abc import Sequence +from typing import Literal, Optional, Union, cast import numpy as np import numpy.typing as npt @@ -75,7 +76,7 @@ def _strides_from_padded_shape(padded_size, order_idx, itemsize): return list(strides) -def dimensions_to_mask(dimensions: Tuple[str, ...]) -> Tuple[bool, ...]: +def dimensions_to_mask(dimensions: tuple[str, ...]) -> tuple[bool, ...]: ndata_dims = sum(d.isdigit() for d in dimensions) mask = [(d in dimensions) for d in "IJK"] + [True for _ in range(ndata_dims)] return tuple(mask) @@ -86,7 +87,7 @@ def normalize_storage_spec( shape: Sequence[int], dtype: DTypeLike, dimensions: Optional[Sequence[str]], -) -> Tuple[Sequence[int], Sequence[int], np.dtype, Tuple[str, ...]]: +) -> tuple[Sequence[int], Sequence[int], np.dtype, tuple[str, ...]]: """Normalize the fields of the storage spec in a homogeneous representation. 
Returns @@ -149,14 +150,14 @@ def normalize_storage_spec( aligned_index = tuple(aligned_index) if any(i < 0 for i in aligned_index): - raise ValueError("aligned_index ({}) contains negative value.".format(aligned_index)) + raise ValueError(f"aligned_index ({aligned_index}) contains negative value.") else: raise TypeError("aligned_index must be an iterable of ints.") dtype = np.dtype(dtype) if dtype.shape: # Subarray dtype - sub_dtype, sub_shape = cast(Tuple[np.dtype, Tuple[int, ...]], dtype.subdtype) + sub_dtype, sub_shape = cast(tuple[np.dtype, tuple[int, ...]], dtype.subdtype) aligned_index = (*aligned_index, *((0,) * dtype.ndim)) shape = (*shape, *sub_shape) dimensions = (*dimensions, *(str(d) for d in range(dtype.ndim))) @@ -165,7 +166,7 @@ def normalize_storage_spec( return aligned_index, shape, dtype, dimensions -def cpu_copy(array: Union[np.ndarray, "cp.ndarray"]) -> np.ndarray: +def cpu_copy(array: np.ndarray | cp.ndarray) -> np.ndarray: if cp is not None: # it's not clear from the documentation if cp.asnumpy guarantees a copy. # worst case, this copies twice. 
@@ -216,14 +217,14 @@ def asarray( raise TypeError(f"Cannot convert {type(array)} to ndarray") -def get_dims(obj: Union[core_defs.GTDimsInterface, npt.NDArray]) -> Optional[Tuple[str, ...]]: +def get_dims(obj: core_defs.GTDimsInterface | npt.NDArray) -> Optional[tuple[str, ...]]: dims = getattr(obj, "__gt_dims__", None) if dims is None: return dims return tuple(str(d) for d in dims) -def get_origin(obj: Union[core_defs.GTDimsInterface, npt.NDArray]) -> Optional[Tuple[int, ...]]: +def get_origin(obj: core_defs.GTDimsInterface | npt.NDArray) -> Optional[tuple[int, ...]]: origin = getattr(obj, "__gt_origin__", None) if origin is None: return origin @@ -236,7 +237,7 @@ def allocate_cpu( dtype: DTypeLike, alignment_bytes: int, aligned_index: Optional[Sequence[int]], -) -> Tuple[allocators._NDBuffer, np.ndarray]: +) -> tuple[allocators._NDBuffer, np.ndarray]: device = core_defs.Device(core_defs.DeviceType.CPU, 0) buffer = _CPUBufferAllocator.allocate( shape, @@ -255,7 +256,7 @@ def _allocate_gpu( dtype: DTypeLike, alignment_bytes: int, aligned_index: Optional[Sequence[int]], -) -> Tuple["cp.ndarray", "cp.ndarray"]: +) -> tuple[cp.ndarray, cp.ndarray]: assert cp is not None assert _GPUBufferAllocator is not None, "GPU allocation library or device not found" if core_defs.CUPY_DEVICE_TYPE is None: @@ -282,7 +283,7 @@ def _allocate_gpu( if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: class CUDAArrayInterfaceNDArray(cp.ndarray): - def __new__(cls, input_array: "cp.ndarray") -> CUDAArrayInterfaceNDArray: + def __new__(cls, input_array: cp.ndarray) -> CUDAArrayInterfaceNDArray: return ( input_array if isinstance(input_array, CUDAArrayInterfaceNDArray) @@ -310,7 +311,7 @@ def _allocate_gpu_rocm( dtype: DTypeLike, alignment_bytes: int, aligned_index: Optional[Sequence[int]], - ) -> Tuple["cp.ndarray", "cp.ndarray"]: + ) -> tuple[cp.ndarray, cp.ndarray]: buffer, ndarray = _allocate_gpu(shape, layout_map, dtype, alignment_bytes, aligned_index) return buffer, 
CUDAArrayInterfaceNDArray(ndarray) From b7c2a25d97ef14c762ee5cdc5858f02d459b513e Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 20 Aug 2025 15:53:02 +0200 Subject: [PATCH 4/4] Fix bug introduced earlier --- src/gt4py/next/type_system/type_translation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index dea06c30d6..5f19e573db 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -151,7 +151,8 @@ def from_type_hint( return ts.FunctionType( pos_only_args=new_args, pos_or_kw_args=kwargs, - kw_only_args={}, # TODO(): fix returns=returns, + kw_only_args={}, # TODO(): fix + returns=returns, ) raise ValueError(f"'{type_hint}' type is not supported.")