diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index ea393e2ad0..116772cc9b 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -11,6 +11,7 @@ from gt4py import eve from gt4py._core import definitions as core_defs +from gt4py.eve import extended_typing as xtyping from gt4py.next import common, errors, field_utils, utils from gt4py.next.embedded import common as embedded_common, context as embedded_context from gt4py.next.field_utils import get_array_ns @@ -108,7 +109,9 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> domain = kwargs.pop("domain", None) - out_domain = common.domain(domain) if domain is not None else _get_out_domain(out) + out_domain = ( + utils.tree_map(common.domain)(domain) if domain is not None else _get_out_domain(out) + ) new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) @@ -128,6 +131,7 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> return op(*args, **kwargs) +@utils.tree_map def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType: vertical_dim_filtered = [nr for nr in domain if nr.dim.kind == common.DimensionKind.VERTICAL] assert len(vertical_dim_filtered) <= 1 @@ -137,17 +141,19 @@ def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.Nothin def _tuple_assign_field( target: tuple[common.MutableField | tuple, ...] | common.MutableField, source: tuple[common.Field | tuple, ...] | common.Field, - domain: common.Domain, + domain: xtyping.MaybeNestedInTuple[common.Domain], ) -> None: @utils.tree_map - def impl(target: common.MutableField, source: common.Field) -> None: + def impl(target: common.MutableField, source: common.Field, domain: common.Domain) -> None: if isinstance(source, common.Field): target[domain] = source[domain] else: assert core_defs.is_scalar_type(source) target[domain] = source - impl(target, source) + if not isinstance(domain, tuple): + domain = utils.tree_map(lambda _: domain)(target) + impl(target, source, domain) def _intersect_scan_args( diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index bcfd1efbee..308097692c 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -714,8 +714,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") if "domain" in kwargs: - domain = common.domain(kwargs.pop("domain")) - out = utils.tree_map(lambda f: f[domain])(out) + domain = utils.tree_map(common.domain)(kwargs.pop("domain")) + if not isinstance(domain, tuple): + domain = utils.tree_map(lambda _: domain)(out) + out = utils.tree_map(lambda f, dom: f[dom])(out, domain) args, kwargs = type_info.canonicalize_arguments( self.foast_stage.foast_node.type, args, kwargs diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 9355273588..42330cbaaf 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -37,6 +37,53 @@ def _is_integral_scalar(expr: past.Expr) -> bool: return isinstance(expr.type, ts.ScalarType) and type_info.is_integral(expr.type) +def _validate_domain_out( + dom: past.Dict | past.TupleExpr, + out: ts.TypeSpec, + is_nested: bool = False, +) -> None: + if isinstance(dom, past.Dict): + # Only reject tuple outputs if nested + if 
is_nested and isinstance(out, ts.TupleType): + raise ValueError("Domain dict cannot map to tuple outputs.") + assert not (is_nested and isinstance(out, past.TupleExpr)) + + if len(dom.keys_) == 0: + raise ValueError("Empty domain not allowed.") + + for dim in dom.keys_: + if not isinstance(dim.type, ts.DimensionType): + raise ValueError( + f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'." + ) + + for domain_values in dom.values_: + if len(domain_values.elts) != 2: + raise ValueError( + f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." + ) + if any(not _is_integral_scalar(el) for el in domain_values.elts): + raise ValueError( + f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." + ) + + elif isinstance(dom, past.TupleExpr): + if isinstance(out, ts.TupleType): + out_elts = out.types + else: + raise ValueError(f"Tuple domain requires tuple output, got {type(out)}.") + + if len(dom.elts) != len(out_elts): + raise ValueError("Mismatched tuple lengths between domain and output.") + + for d, o in zip(dom.elts, out_elts, strict=True): + assert isinstance(d, (past.Dict, past.TupleExpr)) + _validate_domain_out(d, o, is_nested=True) + + else: + raise ValueError(f"'domain' must be Dict or TupleExpr, got {type(dom)}.") + + def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None: """ Perform checks for domain and output field types. @@ -53,32 +100,11 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None: if "out" not in new_kwargs: raise ValueError("Missing required keyword argument 'out'.") - if "domain" in new_kwargs: + if (domain := new_kwargs.get("domain")) is not None: _ensure_no_sliced_field(new_kwargs["out"]) - - domain_kwarg = new_kwargs["domain"] - if not isinstance(domain_kwarg, past.Dict): - raise ValueError(f"Only Dictionaries allowed in 'domain', got '{type(domain_kwarg)}'.") - - if len(domain_kwarg.values_) == 0 and len(domain_kwarg.keys_) == 0: - raise ValueError("Empty domain not allowed.") - - for dim in domain_kwarg.keys_: - if not isinstance(dim.type, ts.DimensionType): - raise ValueError( - f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'." - ) - for domain_values in domain_kwarg.values_: - if len(domain_values.elts) != 2: - raise ValueError( - f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." - ) - if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( - domain_values.elts[1] - ): - raise ValueError( - f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." 
- ) + out = new_kwargs["out"] + assert isinstance(out, past.Expr) and out.type is not None + _validate_domain_out(domain, out.type) class ProgramTypeDeduction(traits.VisitorWithSymbolTableTrait, NodeTranslator): @@ -131,11 +157,22 @@ def visit_Attribute(self, node: past.Attribute, **kwargs: Any) -> past.Attribute type=getattr(new_value.type, node.attr), ) + def visit_Dict(self, node: past.Dict, **kwargs: Any) -> past.Dict: + # the only supported dict for now is in domain specification + keys = self.visit(node.keys_, **kwargs) + assert all(isinstance(key.type, ts.DimensionType) for key in keys) + return past.Dict( + keys_=keys, + values_=self.visit(node.values_, **kwargs), + location=node.location, + type=ts.DomainType(dims=[key.type.dim for key in keys]), + ) + def visit_TupleExpr(self, node: past.TupleExpr, **kwargs: Any) -> past.TupleExpr: elts = self.visit(node.elts, **kwargs) - return past.TupleExpr( - elts=elts, type=ts.TupleType(types=[el.type for el in elts]), location=node.location - ) + ttype = ts.TupleType(types=[elt.type for elt in elts]) + + return past.TupleExpr(elts=elts, type=ttype, location=node.location) def _deduce_binop_type( self, node: past.BinOp, *, left: past.Expr, right: past.Expr, **kwargs: Any diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 628efb001c..271d79f153 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -10,11 +10,11 @@ import dataclasses import functools -from typing import Any, Optional, cast +from typing import Any, Optional, Sequence, cast import devtools -from gt4py.eve import NodeTranslator, concepts, traits, utils as eve_utils +from gt4py.eve import NodeTranslator, traits from gt4py.next import common, config, errors, utils as gtx_utils from gt4py.next.ffront import ( fbuiltins, @@ -164,19 +164,51 @@ def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension] return iter(scanops_per_axis.keys()).__next__() -def _range_arg_from_field(field_name: str, dim: int) -> str: - return f"__{field_name}_{dim}_range" +def _compute_field_slice(node: past.Subscript) -> list[past.Slice]: + out_field_name: past.Name = node.value + out_field_slice_: list[past.Slice] + if isinstance(node.slice_, past.TupleExpr) and all( + isinstance(el, past.Slice) for el in node.slice_.elts + ): + out_field_slice_ = cast(list[past.Slice], node.slice_.elts) # type ensured by if + elif isinstance(node.slice_, past.Slice): + out_field_slice_ = [node.slice_] + else: + raise AssertionError( + "Unexpected 'out' argument, must be tuple of slices or slice expression." + ) + node_dims = cast(ts.FieldType, node.type).dims + assert isinstance(node_dims, list) + if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims): + raise errors.DSLError( + node.location, + f"Too many indices for field '{out_field_name}': field is {len(node_dims)}" + f"-dimensional, but {len(out_field_slice_)} were indexed.", + ) + return out_field_slice_ + + +def _get_element_from_tuple_expr(node: past.Expr, path: tuple[int, ...]) -> past.Expr: + """Get element from a (nested) TupleExpr by following the given path. + + Pre-condition: `node` is a `past.TupleExpr` (if `path ! = ()`) + and `path` is a valid path through the nested tuple structure. 
+ """ + return functools.reduce(lambda e, i: e.elts[i], path, node) # type: ignore[attr-defined] # see pre-condition + +def _unwrap_tuple_expr(expr: past.Expr, path: tuple[int, ...]) -> tuple[past.Expr, Sequence[int]]: + """Unwrap (nested) TupleExpr by following the given path as long as possible. -def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript]: - if isinstance(node, (past.Name, past.Subscript)): - return [node] - elif isinstance(node, past.TupleExpr): - result = [] - for e in node.elts: - result.extend(_flatten_tuple_expr(e)) - return result - raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") + If a non-tuple expression is encountered, the current expression and the remaining path are + returned. + """ + path_remainder: Sequence[int] = path + while isinstance(expr, past.TupleExpr): + idx, *path_remainder = path_remainder + expr = expr.elts[idx] + + return expr, path_remainder @dataclasses.dataclass @@ -320,38 +352,18 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: def _construct_itir_domain_arg( self, - out_field: past.Name, + out_expr: itir.Expr, + out_type: ts.FieldType, node_domain: Optional[past.Expr], slices: Optional[list[past.Slice]] = None, ) -> itir.FunCall: - assert isinstance(out_field.type, ts.TypeSpec) - out_field_types = type_info.primitive_constituents(out_field.type).to_list() - out_dims = cast(ts.FieldType, out_field_types[0]).dims - if any( - not isinstance(out_field_type, ts.FieldType) or out_field_type.dims != out_dims - for out_field_type in out_field_types - ): - raise AssertionError( - f"Expected constituents of '{out_field.id}' argument to be" - " fields defined on the same dimensions. This error should be " - " caught in type deduction already." 
- ) - # if the out_field is a (potentially nested) tuple we get the domain from its first - # element - first_out_el_path = eve_utils.first( - type_info.primitive_constituents(out_field.type, with_path_arg=True) - )[1] - first_out_el = functools.reduce( - lambda expr, i: im.tuple_get(i, expr), first_out_el_path, out_field.id - ) - domain_args = [] - domain_args_kind = [] - for dim_i, dim in enumerate(out_dims): + for dim_i, dim in enumerate(out_type.dims): # an expression for the range of a dimension dim_range = im.call("get_domain_range")( - first_out_el, itir.AxisLiteral(value=dim.value, kind=dim.kind) + out_expr, itir.AxisLiteral(value=dim.value, kind=dim.kind) ) + dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) # bounds lower: itir.Expr @@ -381,7 +393,6 @@ def _construct_itir_domain_arg( args=[itir.AxisLiteral(value=dim.value, kind=dim.kind), lower, upper], ) ) - domain_args_kind.append(dim.kind) if self.grid_type == common.GridType.CARTESIAN: domain_builtin = "cartesian_domain" @@ -393,7 +404,7 @@ def _construct_itir_domain_arg( return itir.FunCall( fun=itir.SymRef(id=domain_builtin), args=domain_args, - location=(node_domain or out_field).location, + location=(node_domain or out_expr).location, ) def _construct_itir_initialized_domain_arg( @@ -409,80 +420,52 @@ def _construct_itir_initialized_domain_arg( return [self.visit(bound) for bound in node_domain.values_[dim_i].elts] - @staticmethod - def _compute_field_slice(node: past.Subscript) -> list[past.Slice]: - out_field_name: past.Name = node.value - out_field_slice_: list[past.Slice] - if isinstance(node.slice_, past.TupleExpr) and all( - isinstance(el, past.Slice) for el in node.slice_.elts - ): - out_field_slice_ = cast(list[past.Slice], node.slice_.elts) # type ensured by if - elif isinstance(node.slice_, past.Slice): - out_field_slice_ = [node.slice_] + def _split_field_and_slice( + self, field: past.Name | past.Subscript + ) -> tuple[past.Name, list[past.Slice] | None]: + if isinstance(field, past.Subscript): + return field.value, _compute_field_slice(field) else: - raise AssertionError( - "Unexpected 'out' argument, must be tuple of slices or slice expression." - ) - node_dims = cast(ts.FieldType, node.type).dims - assert isinstance(node_dims, list) - if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims): - raise errors.DSLError( - node.location, - f"Too many indices for field '{out_field_name}': field is {len(node_dims)}" - f"-dimensional, but {len(out_field_slice_)} were indexed.", - ) - return out_field_slice_ + assert isinstance(field, past.Name) + return field, None def _visit_stencil_call_out_arg( self, out_arg: past.Expr, domain_arg: Optional[past.Expr], **kwargs: Any ) -> tuple[itir.Expr, itir.FunCall]: - if isinstance(out_arg, past.Subscript): - # as the ITIR does not support slicing a field we have to do a deeper - # inspection of the PAST to emulate the behaviour - out_field_name: past.Name = out_arg.value - return ( - self._construct_itir_out_arg(out_field_name), - self._construct_itir_domain_arg( - out_field_name, domain_arg, self._compute_field_slice(out_arg) - ), - ) - elif isinstance(out_arg, past.Name): - return ( - self._construct_itir_out_arg(out_arg), - self._construct_itir_domain_arg(out_arg, domain_arg), + assert isinstance(out_arg, (past.Subscript, past.Name, past.TupleExpr)), ( + "Unexpected 'out' argument. Must be a 'past.Subscript', 'past.Name' or 'past.TupleExpr' node." 
+ ) + + @gtx_utils.tree_map( + collection_type=ts.TupleType, + with_path_arg=True, + unpack=True, + result_collection_constructor=lambda _, elts: im.make_tuple(*elts), + ) + def impl(out_type: ts.FieldType, path: tuple[int, ...]) -> tuple[itir.Expr, itir.Expr]: + out_field, path_remainder = _unwrap_tuple_expr(out_arg, path) + + assert isinstance(out_field, (past.Name, past.Subscript)) + out_field, slice_info = self._split_field_and_slice(out_field) + + domain_element = ( + _get_element_from_tuple_expr(domain_arg, path) + if isinstance(domain_arg, past.TupleExpr) + else domain_arg ) - elif isinstance(out_arg, past.TupleExpr): - flattened = _flatten_tuple_expr(out_arg) - - first_field = flattened[0] - assert all( - self.visit(field.type).dims == self.visit(first_field.type).dims - for field in flattened - ), "Incompatible fields in tuple: all fields must have the same dimensions." - - field_slice = None - if isinstance(first_field, past.Subscript): - assert all(isinstance(field, past.Subscript) for field in flattened), ( - "Incompatible field in tuple: either all fields or no field must be sliced." - ) - assert all( - concepts.eq_nonlocated( - first_field.slice_, - field.slice_, # type: ignore[union-attr] # mypy cannot deduce type - ) - for field in flattened - ), "Incompatible field in tuple: all fields must be sliced in the same way." - field_slice = self._compute_field_slice(first_field) - first_field = first_field.value - - return ( - self._construct_itir_out_arg(out_arg), - self._construct_itir_domain_arg(first_field, domain_arg, field_slice), + + lowered_out_field = functools.reduce( + lambda expr, i: im.tuple_get(i, expr), path_remainder, self.visit(out_field) ) - else: - raise AssertionError( - "Unexpected 'out' argument. Must be a 'past.Subscript', 'past.Name' or 'past.TupleExpr' node." 
+ lowered_domain = self._construct_itir_domain_arg( + lowered_out_field, + out_type, + domain_element, + slice_info, ) + return lowered_out_field, lowered_domain + + return impl(out_arg.type) def visit_Constant(self, node: past.Constant, **kwargs: Any) -> itir.Literal: if isinstance(node.type, ts.ScalarType) and node.type.shape is None: @@ -497,7 +480,7 @@ def visit_Constant(self, node: past.Constant, **kwargs: Any) -> itir.Literal: raise NotImplementedError("Only scalar literals supported currently.") def visit_Name(self, node: past.Name, **kwargs: Any) -> itir.SymRef: - return itir.SymRef(id=node.id) + return itir.SymRef(id=node.id, location=node.location) def visit_Symbol(self, node: past.Symbol, **kwargs: Any) -> itir.Sym: return itir.Sym(id=node.id, type=node.type) diff --git a/src/gt4py/next/ffront/program_ast.py b/src/gt4py/next/ffront/program_ast.py index ea579aa211..9e0eb30939 100644 --- a/src/gt4py/next/ffront/program_ast.py +++ b/src/gt4py/next/ffront/program_ast.py @@ -9,7 +9,7 @@ from typing import Any, Generic, Literal, Optional, TypeVar, Union import gt4py.eve as eve -from gt4py.eve import Coerced, Node, SourceLocation, SymbolName, SymbolRef +from gt4py.eve import Coerced, Node, SourceLocation, SymbolName, SymbolRef, datamodels from gt4py.eve.traits import SymbolTableTrait from gt4py.next.ffront import dialect_ast_enums, type_specifications as ts_ffront from gt4py.next.type_system import type_specifications as ts @@ -85,6 +85,12 @@ class Dict(Expr): keys_: list[Union[Name | Attribute]] values_: list[TupleExpr] + @datamodels.root_validator + @classmethod + def keys_values_length_validation(cls: type["Dict"], instance: "Dict") -> None: + if len(instance.keys_) != len(instance.values_): + raise ValueError("`Dict` must have same number of keys as values.") + class Slice(Expr): lower: Optional[Constant] diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 266bdf3000..efbad21c2b 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -45,7 +45,7 @@ overload, runtime_checkable, ) -from gt4py.next import common, field_utils +from gt4py.next import common, field_utils, utils from gt4py.next.embedded import ( context as embedded_context, exceptions as embedded_exceptions, @@ -1629,8 +1629,13 @@ def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProvider @runtime.set_at.register(EMBEDDED) -def set_at(expr: common.Field, domain: common.DomainLike, target: common.MutableField) -> None: - operators._tuple_assign_field(target, expr, common.domain(domain)) +def set_at( + expr: common.Field, + domain_like: xtyping.MaybeNestedInTuple[common.DomainLike], + target: common.MutableField, +) -> None: + domain = utils.tree_map(common.domain)(domain_like) + operators._tuple_assign_field(target, expr, domain) @runtime.get_domain_range.register(EMBEDDED) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index ce5c0c085d..08f4746cfb 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -8,7 +8,7 @@ import dataclasses from collections import ChainMap -from typing import Callable, Iterable, TypeVar +from typing import Callable, Iterable, TypeVar, cast from gt4py import eve from gt4py._core import definitions as core_defs @@ -229,9 +229,20 @@ def grid_type_from_domain(domain: itir.FunCall) -> common.GridType: return common.GridType.UNSTRUCTURED +def _flatten_tuple_expr(expr: itir.Expr) -> tuple[itir.Expr]: + if 
cpm.is_call_to(expr, "make_tuple"): + return sum( + (_flatten_tuple_expr(arg) for arg in expr.args), start=cast(tuple[itir.Expr], ()) + ) + else: + return (expr,) + + def grid_type_from_program(program: itir.Program) -> common.GridType: - domains = program.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() - grid_types = {grid_type_from_domain(d) for d in domains} + domain_exprs = program.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() + domains = sum((_flatten_tuple_expr(domain_expr) for domain_expr in domain_exprs), start=()) + assert all(isinstance(d, itir.FunCall) for d in domains) + grid_types = {grid_type_from_domain(d) for d in domains} # type: ignore[arg-type] # checked above if len(grid_types) != 1: raise ValueError( f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index a31ef082d2..d4a6543aa3 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -9,7 +9,6 @@ from __future__ import annotations import functools -from collections.abc import Sequence from typing import Callable, Literal, Optional, cast from gt4py.eve import utils as eve_utils @@ -30,7 +29,7 @@ def select_elems_by_domain( select_domain: SymbolicDomain, target: itir.Expr, - args: Sequence[itir.Expr], + source: itir.Expr, domains: tuple[SymbolicDomain, ...], ): """ @@ -40,12 +39,12 @@ def select_elems_by_domain( """ new_targets = [] new_els = [] - for i, (el, el_domain) in enumerate(zip(args, domains)): + for i, el_domain in enumerate(domains): current_target = im.tuple_get(i, target) + current_source = im.tuple_get(i, source) if isinstance(el_domain, tuple): - assert cpm.is_call_to(el, "make_tuple") more_targets, more_els = select_elems_by_domain( - select_domain, current_target, el.args, el_domain + select_domain, current_target, current_source, el_domain ) new_els.extend(more_els) new_targets.extend(more_targets) @@ -53,16 +52,15 @@ def select_elems_by_domain( assert isinstance(el_domain, SymbolicDomain) if el_domain == select_domain: new_targets.append(current_target) - new_els.append(el) + new_els.append(current_source) return new_targets, new_els def _set_at_for_domain(stmt: itir.SetAt, domain: SymbolicDomain) -> itir.SetAt: """Extract all elements with given domain into a new `SetAt` statement.""" tuple_expr = stmt.expr - assert cpm.is_call_to(tuple_expr, "make_tuple") targets, expr_els = select_elems_by_domain( - domain, stmt.target, tuple_expr.args, stmt.expr.annex.domain + domain, stmt.target, tuple_expr, stmt.expr.annex.domain ) new_expr = im.make_tuple(*expr_els) new_expr.annex.domain = domain diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 22889ea2de..d77ef9f096 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -483,7 +483,7 @@ def infer_expr( domain, fill_value=DomainAccessDescriptor.NEVER, # el_types already has the right structure, we only want to change domain - bidirectional=False, + bidirectional=False if not isinstance(expr.type, ts.DeferredType) else True, ) if cpm.is_applied_as_fieldop(expr) and cpm.is_call_to(expr.fun.args[0], "scan"): @@ -519,6 +519,13 @@ def infer_expr( return expr, accessed_domains +def _make_symbolic_domain_tuple(domains: itir.Node) -> DomainAccess: + if 
cpm.is_call_to(domains, "make_tuple"): + return tuple(_make_symbolic_domain_tuple(arg) for arg in domains.args) + else: + return SymbolicDomain.from_expr(domains) + + def _infer_stmt( stmt: itir.Stmt, **kwargs: Unpack[InferenceOptions], @@ -528,9 +535,9 @@ def _infer_stmt( # between the domain stored in IR and in the annex domain = constant_folding.ConstantFolding.apply(stmt.domain) - transformed_call, _ = infer_expr( - stmt.expr, domain_utils.SymbolicDomain.from_expr(domain), **kwargs - ) + symbolic_domain = _make_symbolic_domain_tuple(domain) + + transformed_call, _ = infer_expr(stmt.expr, symbolic_domain, **kwargs) return itir.SetAt( expr=transformed_call, diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_domain.py b/src/gt4py/next/program_processors/runners/dace/gtir_domain.py index 9641bd26a5..303bb5d936 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_domain.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_domain.py @@ -14,8 +14,11 @@ import dace from dace import subsets as dace_subsets +from gt4py import eve +from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import common as gtx_common -from gt4py.next.iterator.ir_utils import domain_utils +from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils from gt4py.next.program_processors.runners.dace import gtir_to_sdfg_utils @@ -57,6 +60,30 @@ def get_field_domain(domain: domain_utils.SymbolicDomain) -> FieldopDomain: ] +TargetDomain: TypeAlias = MaybeNestedInTuple[domain_utils.SymbolicDomain] +"""Symbolic domain which defines the range to write in the target field. + +For tuple output in fieldview, `TargetDomain` is a tree-like tuple of symbolic domains. +""" + + +def extract_target_domain(node: gtir.Expr) -> TargetDomain: + """ + Visit a GTIR domain expression and construct a `TargetDomain` symbolic domain. + + We use a visitor class to extract the tree-like structure for (nested) tuple of domains. + """ + + class TargetDomainParser(eve.visitors.NodeTranslator): + def visit_FunCall(self, node: gtir.FunCall) -> TargetDomain: + if cpm.is_call_to(node, "make_tuple"): + return tuple(self.visit(arg) for arg in node.args) + else: + return domain_utils.SymbolicDomain.from_expr(node) + + return TargetDomainParser().visit(node) + + def get_domain_indices( dims: Sequence[gtx_common.Dimension], origin: Optional[Sequence[dace.symbolic.SymExpr]] ) -> dace_subsets.Indices: diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py index b6e31a97ef..88165dc2ee 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py @@ -693,8 +693,7 @@ def visit_SetAt( """ # Visit the domain expression. - assert isinstance(stmt.domain.type, ts.DomainType) - domain = domain_utils.SymbolicDomain.from_expr(stmt.domain) + domain = gtir_domain.extract_target_domain(stmt.domain) # Visit the field operator expression. 
source_tree = self._visit_expression(stmt.expr, sdfg, state) @@ -754,10 +753,10 @@ def _visit_target( ) gtx_utils.tree_map( - lambda source, target, domain_=domain, target_state_=target_state: _visit_target( - source, target, domain_, target_state_ + lambda source, target, target_domain: _visit_target( + source, target, target_domain, target_state ) - )(source_tree, target_tree) + )(source_tree, target_tree, domain) if target_state.is_empty(): sdfg.remove_node(target_state) @@ -850,20 +849,23 @@ def visit_Lambda( lambda_arg_nodes: dict[str, gtir_to_sdfg_types.FieldopData] = {} for gt_symbol, arg in args.items(): gt_symbol_name = str(gt_symbol.id) - if isinstance(arg, gtir_to_sdfg_types.SymbolicData): + if arg is None: + pass # domain inference has detetcted that this argument is not used + elif isinstance(arg, gtir_to_sdfg_types.SymbolicData): symbolic_args[gt_symbol_name] = arg else: data_args[gt_symbol_name] = arg lambda_arg_nodes |= { str(nested_param.id): nested_arg for nested_param, nested_arg in gtir_to_sdfg_types.flatten_tuple(gt_symbol, arg) + if nested_arg is not None # we filter out arguments with empty domain } # inherit symbols from parent scope but eventually override with local symbols lambda_symbols = { sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) - } | {str(param.id): param.type for param, arg in args.items()} + } | {str(param.id): param.type for param, arg in args.items() if arg is not None} assert all(isinstance(_type, ts.DataType) for _type in lambda_symbols.values()) # lower let-statement lambda node as a nested SDFG @@ -888,6 +890,7 @@ def visit_Lambda( } input_memlets = {} + unused_data = set() for nsdfg_dataname, nsdfg_datadesc in lambda_ctx.sdfg.arrays.items(): if nsdfg_datadesc.transient: pass # nothing to do here @@ -895,13 +898,20 @@ def visit_Lambda( arg_node = lambda_arg_nodes[nsdfg_dataname] source_data = arg_node.dc_node.data input_memlets[nsdfg_dataname] = ctx.sdfg.make_array_memlet(source_data) - else: - assert nsdfg_dataname in ctx.sdfg.arrays + elif nsdfg_dataname in ctx.sdfg.arrays: source_data = nsdfg_dataname # ensure that connectivity tables are non-transient arrays in parent SDFG if source_data in connectivity_arrays: ctx.sdfg.arrays[source_data].transient = False input_memlets[nsdfg_dataname] = ctx.sdfg.make_array_memlet(source_data) + else: + # This argument has empty domain, which means that it is not used + # by the lambda expression, and does not need to be connected on + # the nested SDFG. + unused_data.add(nsdfg_dataname) + + for data in sorted(unused_data): # NOTE: remove the data in deterministic order + lambda_ctx.sdfg.remove_data(data, validate=__debug__) # Process lambda outputs # @@ -999,7 +1009,10 @@ def construct_output_for_nested_sdfg( # Non-transient nodes are just input nodes that are immediately returned # by the lambda expression. Therefore, these nodes are already available # in the parent context and can be directly accessed there. - outer_data = lambda_arg_nodes[inner_dataname] + outer_arg = lambda_arg_nodes[inner_dataname] + if outer_arg is None: + raise ValueError(f"Unexpected argument with empty domain {inner_data}.") + outer_data = outer_arg else: # This must be a symbol captured from the lambda parent scope. 
outer_node = ctx.state.add_access(inner_dataname) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py index f23817fecd..a3072187f7 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py @@ -426,12 +426,8 @@ def translate_if( false_br_result = sdfg_builder.visit(false_expr, ctx=fbranch_ctx) node_output = gtx_utils.tree_map( - lambda domain, - true_br, - false_br, - _ctx=ctx, - sdfg_builder=sdfg_builder: _construct_if_branch_output( - ctx=_ctx, + lambda domain, true_br, false_br: _construct_if_branch_output( + ctx=ctx, sdfg_builder=sdfg_builder, field_domain=gtir_domain.get_field_domain(domain), true_br=true_br, @@ -442,10 +438,10 @@ def translate_if( true_br_result, false_br_result, ) - gtx_utils.tree_map(lambda src, dst, _ctx=tbranch_ctx: _write_if_branch_output(_ctx, src, dst))( + gtx_utils.tree_map(lambda src, dst: _write_if_branch_output(tbranch_ctx, src, dst))( true_br_result, node_output ) - gtx_utils.tree_map(lambda src, dst, _ctx=fbranch_ctx: _write_if_branch_output(_ctx, src, dst))( + gtx_utils.tree_map(lambda src, dst: _write_if_branch_output(fbranch_ctx, src, dst))( false_br_result, node_output ) @@ -600,9 +596,21 @@ def translate_tuple_get( data_nodes = sdfg_builder.visit(node.args[1], ctx=ctx) if isinstance(data_nodes, gtir_to_sdfg_types.FieldopData): raise ValueError(f"Invalid tuple expression {node}") - unused_arg_nodes: Iterable[gtir_to_sdfg_types.FieldopData] = gtx_utils.flatten_nested_tuple( + # Now we remove the tuple fields that are not used, to avoid an SDFG validation + # error because of isolated access nodes. + unused_arg_nodes = gtx_utils.flatten_nested_tuple( tuple(arg for i, arg in enumerate(data_nodes) if i != index) ) + # However, for temporary fields inside the tuple (non-globals and non-scalar + # values, supposed to contain the result of some field operator) the gt4py + # domain inference should have already set an empty domain, so the corresponding + # `arg` is expected to be None and can be ignored. 
+ assert all( + not arg.dc_node.desc(ctx.sdfg).transient or isinstance(arg.gt_type, ts.ScalarType) + for arg in unused_arg_nodes + if arg is not None + ) + unused_arg_nodes = tuple(arg for arg in unused_arg_nodes if arg is not None) ctx.state.remove_nodes_from( [arg.dc_node for arg in unused_arg_nodes if ctx.state.degree(arg.dc_node) == 0] ) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py index 2693ffd8ec..575bd90d56 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py @@ -36,6 +36,7 @@ domain_utils, ir_makers as im, ) +from gt4py.next.iterator.transforms import infer_domain from gt4py.next.program_processors.runners.dace import ( gtir_dataflow, gtir_domain, @@ -82,17 +83,17 @@ def _parse_fieldop_arg_impl( return _parse_fieldop_arg_impl(arg) else: # handle tuples of fields - return gtx_utils.tree_map(lambda x: _parse_fieldop_arg_impl(x))(arg) + return gtx_utils.tree_map(_parse_fieldop_arg_impl)(arg) def _create_scan_field_operator_impl( ctx: gtir_to_sdfg.SubgraphContext, sdfg_builder: gtir_to_sdfg.SDFGBuilder, - field_domain: gtir_domain.FieldopDomain, - output_edge: gtir_dataflow.DataflowOutputEdge, + output_edge: gtir_dataflow.DataflowOutputEdge | None, + output_domain: infer_domain.NonTupleDomainAccess, output_type: ts.FieldType, map_exit: dace.nodes.MapExit, -) -> gtir_to_sdfg_types.FieldopData: +) -> gtir_to_sdfg_types.FieldopData | None: """ Helper method to allocate a temporary array that stores one field computed by the scan field operator. @@ -105,9 +106,25 @@ def _create_scan_field_operator_impl( Therefore, the memlet subset will write a slice into the result array, that corresponds to the full vertical shape for each horizontal grid point. + Another difference is that this function is called on all fields inside a tuple, + in case of tuple return. Note that a regular field operator only computes a + single field, never a tuple of fields. For tuples, it can happen that one of + the nested fields is not used, outside the scan field operator, and therefore + does not need to be computed. Then, the domain inferred by gt4py on this field + is empty and the corresponding `output_edge` argument to this function is None. + In this case, the function does not allocate an array node for the output field + and returns None. + Refer to `gtir_to_sdfg_primitives._create_field_operator_impl()` for - the description of function arguments and return values. + the description of function arguments. """ + if output_edge is None: + # According to domain inference, this tuple field does not need to be computed. 
+ assert output_domain == infer_domain.DomainAccessDescriptor.NEVER + return None + assert isinstance(output_domain, domain_utils.SymbolicDomain) + field_domain = gtir_domain.get_field_domain(output_domain) + dataflow_output_desc = output_edge.result.dc_node.desc(ctx.sdfg) assert isinstance(dataflow_output_desc, dace.data.Array) @@ -201,7 +218,8 @@ def _create_scan_field_operator( node_type: ts.FieldType | ts.TupleType, sdfg_builder: gtir_to_sdfg.SDFGBuilder, input_edges: Iterable[gtir_dataflow.DataflowInputEdge], - output: MaybeNestedInTuple[gtir_dataflow.DataflowOutputEdge], + output: MaybeNestedInTuple[gtir_dataflow.DataflowOutputEdge | None], + output_domain: infer_domain.DomainAccess, ) -> gtir_to_sdfg_types.FieldopResult: """ Helper method to build the output of a field operator, which can consist of @@ -213,7 +231,12 @@ def _create_scan_field_operator( by a loop region in a mapped nested SDFG. Refer to `gtir_to_sdfg_primitives._create_field_operator()` for the - description of function arguments and return values. + description of function arguments. Note that the return value is different, + because the scan field operator can return a tuple of fields, while a regular + field operator return a single field. The domain of the nested fields, in + a tuple, can be empty, in case the nested field is not used outside the scan. + In this case, the corresponding `output` edge will be None and this function + will also return None for the corresponding field inside the tree-like result. """ dims, _, _ = gtir_domain.get_field_layout(field_domain) @@ -253,17 +276,17 @@ def _create_scan_field_operator( ) return gtx_utils.tree_map( - lambda edge, sym: ( + lambda edge, domain, sym: ( _create_scan_field_operator_impl( ctx, sdfg_builder, - field_domain, edge, + domain, sym.type, map_exit, ) ) - )(output, dummy_output_symbol) + )(output, output_domain, dummy_output_symbol) def _scan_input_name(input_name: str) -> str: @@ -372,6 +395,7 @@ def get_scan_output_shape( if isinstance(init_data, tuple): lambda_result_shape = gtx_utils.tree_map(get_scan_output_shape)(init_data) else: + assert init_data is not None lambda_result_shape = get_scan_output_shape(init_data) # Create the body of the initialization state @@ -549,6 +573,29 @@ def _connect_nested_sdfg_output_to_temporaries( return gtir_dataflow.DataflowOutputEdge(outer_ctx.state, output_expr) +def _handle_dataflow_result_of_nested_sdfg( + nsdfg_node: dace.nodes.NestedSDFG, + inner_ctx: gtir_to_sdfg.SubgraphContext, + outer_ctx: gtir_to_sdfg.SubgraphContext, + inner_data: gtir_to_sdfg_types.FieldopData, + field_domain: infer_domain.NonTupleDomainAccess, +) -> gtir_dataflow.DataflowOutputEdge | None: + if isinstance(field_domain, domain_utils.SymbolicDomain): + # The field is used outside the nested SDFG, therefore it needs to be copied + # to a temporary array in the parent SDFG (outer context). + return _connect_nested_sdfg_output_to_temporaries( + inner_ctx, outer_ctx, nsdfg_node, inner_data + ) + else: + # The field is not used outside the nested SDFG. It is likely just storage + # for some internal state, accessed during column scan, and can be turned + # into a transient array inside the nested SDFG. 
+ assert field_domain == infer_domain.DomainAccessDescriptor.NEVER + inner_data.dc_node.desc(inner_ctx.sdfg).transient = True + nsdfg_node.out_connectors.pop(inner_data.dc_node.data) + return None + + def translate_scan( node: gtir.Node, ctx: gtir_to_sdfg.SubgraphContext, @@ -628,13 +675,18 @@ def translate_scan( lambda_args = [sdfg_builder.visit(arg, ctx=ctx) for arg in node.args] lambda_args_mapping = [ (im.sym(_scan_input_name(scan_carry), scan_carry_type), init_data), - ] + [(param, arg) for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True)] + ] + [ + (gt_symbol, arg) + for gt_symbol, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) + if arg is not None + ] lambda_arg_nodes: dict[str, gtir_to_sdfg_types.FieldopData] = {} for gt_symbol, arg in lambda_args_mapping: lambda_arg_nodes |= { str(nested_gt_symbol.id): nested_arg for nested_gt_symbol, nested_arg in gtir_to_sdfg_types.flatten_tuple(gt_symbol, arg) + if nested_arg is not None } # parse the dataflow output symbols @@ -686,12 +738,16 @@ def translate_scan( # for output connections, we create temporary arrays that contain the computation # results of a column slice for each point in the horizontal domain output_tree = gtx_utils.tree_map( - lambda output_data: _connect_nested_sdfg_output_to_temporaries( - lambda_ctx, ctx, nsdfg_node, output_data + lambda output_data, output_domain: _handle_dataflow_result_of_nested_sdfg( + nsdfg_node=nsdfg_node, + inner_ctx=lambda_ctx, + outer_ctx=ctx, + inner_data=output_data, + field_domain=output_domain, ) - )(lambda_output) + )(lambda_output, node.annex.domain) # we call a helper method to create a map scope that will compute the entire field return _create_scan_field_operator( - ctx, field_domain, node.type, sdfg_builder, input_edges, output_tree + ctx, field_domain, node.type, sdfg_builder, input_edges, output_tree, node.annex.domain ) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py index dba4243a5c..c2f5fbd081 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py @@ -180,8 +180,14 @@ def get_symbol_mapping( return symbol_mapping -FieldopResult: TypeAlias = MaybeNestedInTuple[FieldopData] -"""Result of a field operator, can be either a field or a tuple fields.""" +FieldopResult: TypeAlias = MaybeNestedInTuple[FieldopData | None] +"""Result of a field operator, can be either a field or a tuple fields. + +For tuple of fields, any of the nested fields can be None, in case it is not used +and therefore does not need to be computed. The information whether a field needs +to be computed or not is the result of GTIR domain inference, and it is stored in +the GTIR node annex domain. 
+""" @dataclasses.dataclass(frozen=True) @@ -194,7 +200,7 @@ class SymbolicData: """Data type used for field indexing.""" -def flatten_tuple(sym: gtir.Sym, arg: FieldopResult) -> list[tuple[gtir.Sym, FieldopData]]: +def flatten_tuple(sym: gtir.Sym, arg: FieldopResult) -> list[tuple[gtir.Sym, FieldopData | None]]: """ Visit a `FieldopResult`, potentially containing nested tuples, and construct a list of pairs `(gtir.Sym, FieldopData)` containing the symbol of each tuple diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 6361767f48..8f00b70ee0 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -20,6 +20,7 @@ Optional, ParamSpec, Protocol, + Sequence, TypeGuard, TypeVar, cast, @@ -90,6 +91,8 @@ def tree_map( *, collection_type: type | tuple[type, ...] = tuple, result_collection_constructor: Optional[Callable] = None, + unpack: bool = False, + with_path_arg: bool = False, ) -> Callable[..., _R | tuple[_R | tuple, ...]]: ... @@ -98,6 +101,8 @@ def tree_map( *, collection_type: type | tuple[type, ...] = tuple, result_collection_constructor: Optional[Callable] = None, + unpack: bool = False, + with_path_arg: bool = False, ) -> Callable[ [Callable[_P, _R]], Callable[..., Any] ]: ... # TODO(havogt): typing of `result_collection_constructor` is too weak here @@ -108,6 +113,8 @@ def tree_map( *, collection_type: type | tuple[type, ...] = tuple, result_collection_constructor: Optional[Callable] = None, + unpack: bool = False, + with_path_arg: bool = False, ) -> Callable[..., _R | tuple[_R | tuple, ...]] | Callable[[Callable[_P, _R]], Callable[..., Any]]: """ Apply `fun` to each entry of (possibly nested) collections (by default `tuple`s). @@ -116,7 +123,9 @@ def tree_map( fun: Function to apply to each entry of the collection. collection_type: Type of the collection to be traversed. Can be a single type or a tuple of types. result_collection_constructor: Type of the collection to be returned. If `None` the same type as `collection_type` is used. - + unpack: Replicate tuple structure returned from `fun` to the mapped result, i.e. return + tuple of result collections instead of result collections of tuples. + with_path_arg: Pass the path to access the current element to `fun`. Examples: >>> tree_map(lambda x: x + 1)(((1, 2), 3)) ((2, 3), 4) @@ -140,10 +149,32 @@ def tree_map( ... return x + 1 >>> impl(((1, 2), 3)) ((2, 3), 4) + + >>> @tree_map(with_path_arg=True) + ... def impl(x, path: tuple[int, ...]): + ... path_str = "".join(f"[{i}]" for i in path) + ... return f"t{path_str} = {x}" + >>> t = impl(((1, 2), 3)) + >>> t[0][0] + 't[0][0] = 1' + >>> t[0][1] + 't[0][1] = 2' + >>> t[1] + 't[1] = 3' + + >>> @tree_map(unpack=True) + ... def impl(x): + ... return (x, x**2) + >>> identity, squared = impl(((2, 3), 4)) + >>> identity + ((2, 3), 4) + >>> squared + ((4, 9), 16) """ if result_collection_constructor is None: if isinstance(collection_type, tuple): + # Note: that doesn't mean `collection_type=tuple`, but e.g. `collection_type=(list, tuple)` raise TypeError( "tree_map() requires `result_collection_constructor` when `collection_type` is a tuple of types." 
) @@ -154,22 +185,41 @@ def tree_map( @functools.wraps(fun) def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: if isinstance(args[0], collection_type): + non_path_args: Sequence[Any] + if with_path_arg: + *non_path_args, path = args + args = (*non_path_args, tuple((*path, i) for i in range(len(args[0])))) + else: + non_path_args = args + assert all( - isinstance(arg, collection_type) and len(args[0]) == len(arg) for arg in args + isinstance(arg, collection_type) and len(args[0]) == len(arg) + for arg in non_path_args ) assert result_collection_constructor is not None - return result_collection_constructor(args[0], (impl(*arg) for arg in zip(*args))) + ctor = functools.partial(result_collection_constructor, args[0]) + + mapped = [impl(*arg) for arg in zip(*args)] + if unpack: + return tuple(map(ctor, zip(*mapped))) + else: + return ctor(mapped) return fun( # type: ignore[call-arg] - *cast(_P.args, args) # type: ignore[valid-type] + *cast(_P.args, args), # type: ignore[valid-type] ) # mypy doesn't understand that `args` at this point is of type `_P.args` - return impl + if with_path_arg: + return lambda *args: impl(*args, ()) + else: + return impl else: return functools.partial( tree_map, collection_type=collection_type, result_collection_constructor=result_collection_constructor, + unpack=unpack, + with_path_arg=with_path_arg, ) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 967cf0ab11..19e8e3ea6d 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -50,6 +50,7 @@ JDim, Joff, KDim, + KHalfDim, Koff, V2EDim, Vertex, @@ -701,6 +702,7 @@ def from_cartesian_grid_descriptor( IDim: grid_descriptor.sizes[0], JDim: grid_descriptor.sizes[1], KDim: grid_descriptor.sizes[2], + KHalfDim: grid_descriptor.sizes[3], }, grid_type=common.GridType.CARTESIAN, allocator=allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py index 2ab97814b9..0d6c44977e 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py @@ -42,21 +42,6 @@ def exec_alloc_descriptor(request): yield request.param -@pytest.fixture -def cartesian(request, gtir_dace_backend): - yield cases.Case( - backend=gtir_dace_backend, - offset_provider={ - "Ioff": IDim, - "Joff": JDim, - "Koff": KDim, - }, - default_sizes={IDim: 10, JDim: 10, KDim: 10}, - grid_type=common.GridType.CARTESIAN, - allocator=gtir_dace_backend.allocator, - ) - - @pytest.fixture def unstructured(request, exec_alloc_descriptor, mesh_descriptor): # noqa: F811 yield cases.Case( @@ -66,7 +51,6 @@ def unstructured(request, exec_alloc_descriptor, mesh_descriptor): # noqa: F811 Vertex: mesh_descriptor.num_vertices, Edge: mesh_descriptor.num_edges, Cell: mesh_descriptor.num_cells, - KDim: 10, }, grid_type=common.GridType.UNSTRUCTURED, allocator=exec_alloc_descriptor.allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index b2cb8b0a2c..7640553e6a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -135,6 +135,7 @@ def debug_itir(tree): IDim = gtx.Dimension("IDim") 
JDim = gtx.Dimension("JDim") KDim = gtx.Dimension("KDim", kind=gtx.DimensionKind.VERTICAL) +KHalfDim = gtx.Dimension("KHalf", kind=gtx.DimensionKind.VERTICAL) Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) Joff = gtx.FieldOffset("Joff", source=JDim, target=(JDim,)) Koff = gtx.FieldOffset("Koff", source=KDim, target=(KDim,)) @@ -170,15 +171,18 @@ def offset_provider(self) -> common.OffsetProvider: ... def offset_provider_type(self) -> common.OffsetProviderType: ... -def simple_cartesian_grid(sizes: int | tuple[int, int, int] = 10) -> CartesianGridDescriptor: +def simple_cartesian_grid( + sizes: int | tuple[int, int, int, int] = (5, 7, 9, 11), +) -> CartesianGridDescriptor: if isinstance(sizes, int): - sizes = (sizes,) * 3 - assert len(sizes) == 3, "sizes must be a tuple of three integers" + sizes = (sizes,) * 4 + assert len(sizes) == 4, "sizes must be a tuple of four integers" offset_provider = { "Ioff": IDim, "Joff": JDim, "Koff": KDim, + "KHalfoff": KHalfDim, } return types.SimpleNamespace( @@ -207,9 +211,6 @@ def num_cells(self) -> int: ... @property def num_edges(self) -> int: ... - @property - def num_levels(self) -> int: ... - @property def offset_provider(self) -> common.OffsetProvider: ... diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index b8037b1082..647a94c8a2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -412,8 +412,9 @@ def testee(qc: cases.IKFloatField, scalar: float): qc = cases.allocate(cartesian_case, testee, "qc").zeros()() scalar = 1.0 + isize = cartesian_case.default_sizes[IDim] ksize = cartesian_case.default_sizes[KDim] - expected = np.full((ksize, ksize), np.arange(start=1, stop=11, step=1).astype(float64)) + expected = np.full((isize, ksize), np.arange(start=1, stop=ksize + 1, step=1).astype(float64)) cases.verify(cartesian_case, testee, qc, scalar, inout=qc, ref=expected) @@ -435,8 +436,9 @@ def testee_op( qc = cases.allocate(cartesian_case, testee_op, "qc").zeros()() tuple_scalar = (1.0, (1.0, 0.0)) + isize = cartesian_case.default_sizes[IDim] ksize = cartesian_case.default_sizes[KDim] - expected = np.full((ksize, ksize), np.arange(start=1.0, stop=11.0), dtype=float) + expected = np.full((isize, ksize), np.arange(start=1.0, stop=ksize + 1), dtype=float) cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected) @@ -795,7 +797,7 @@ def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.I @pytest.mark.parametrize("forward", [True, False]) def test_fieldop_from_scan(cartesian_case, forward): init = 1.0 - expected = np.arange(init + 1.0, init + 1.0 + cartesian_case.default_sizes[IDim], 1) + expected = np.arange(init + 1.0, init + 1.0 + cartesian_case.default_sizes[KDim], 1) out = cartesian_case.as_field([KDim], np.zeros((cartesian_case.default_sizes[KDim],))) if not forward: @@ -1090,22 +1092,22 @@ def fieldop_domain(a: cases.IField) -> cases.IField: return a + a @gtx.program - def program_domain(a: cases.IField, out: cases.IField): - fieldop_domain(a, out=out, domain={IDim: (minimum(1, 2), 9)}) + def program_domain(a: cases.IField, size: int32, out: cases.IField): + fieldop_domain(a, out=out, domain={IDim: (minimum(1, 2), size)}) a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - + 
size = cartesian_case.default_sizes[IDim] ref = out.asnumpy().copy() # ensure we are not writing to out outside the domain - ref[1:9] = a.asnumpy()[1:9] * 2 + ref[1:size] = a.asnumpy()[1:size] * 2 - cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref) + cases.verify(cartesian_case, program_domain, a, size, out, inout=out, ref=ref) @pytest.mark.uses_floordiv def test_domain_input_bounds(cartesian_case): lower_i = 1 - upper_i = 10 + upper_i = cartesian_case.default_sizes[IDim] + 1 @gtx.field_operator def fieldop_domain(a: cases.IField) -> cases.IField: @@ -1128,9 +1130,9 @@ def program_domain( def test_domain_input_bounds_1(cartesian_case): lower_i = 1 - upper_i = 9 - lower_j = 4 - upper_j = 6 + upper_i = cartesian_case.default_sizes[IDim] + lower_j = cartesian_case.default_sizes[JDim] - 3 + upper_j = cartesian_case.default_sizes[JDim] - 1 @gtx.field_operator def fieldop_domain(a: cases.IJField) -> cases.IJField: @@ -1180,19 +1182,30 @@ def fieldop_domain_tuple( @gtx.program def program_domain_tuple( - inp0: cases.IJField, inp1: cases.IJField, out0: cases.IJField, out1: cases.IJField + inp0: cases.IJField, + inp1: cases.IJField, + out0: cases.IJField, + out1: cases.IJField, + isize: int32, + jsize: int32, ): - fieldop_domain_tuple(inp0, inp1, out=(out0, out1), domain={IDim: (1, 9), JDim: (4, 6)}) + fieldop_domain_tuple( + inp0, inp1, out=(out0, out1), domain={IDim: (1, isize), JDim: (jsize - 2, jsize)} + ) inp0 = cases.allocate(cartesian_case, program_domain_tuple, "inp0")() inp1 = cases.allocate(cartesian_case, program_domain_tuple, "inp1")() out0 = cases.allocate(cartesian_case, program_domain_tuple, "out0")() out1 = cases.allocate(cartesian_case, program_domain_tuple, "out1")() + isize = cartesian_case.default_sizes[IDim] + jsize = cartesian_case.default_sizes[JDim] - 1 ref0 = out0.asnumpy().copy() - ref0[1:9, 4:6] = inp0.asnumpy()[1:9, 4:6] + inp1.asnumpy()[1:9, 4:6] + ref0[1:isize, jsize - 2 : jsize] = ( + inp0.asnumpy()[1:isize, jsize - 2 : jsize] + inp1.asnumpy()[1:isize, jsize - 2 : jsize] + ) ref1 = out1.asnumpy().copy() - ref1[1:9, 4:6] = inp1.asnumpy()[1:9, 4:6] + ref1[1:isize, jsize - 2 : jsize] = inp1.asnumpy()[1:isize, jsize - 2 : jsize] cases.verify( cartesian_case, @@ -1201,6 +1214,8 @@ def program_domain_tuple( inp1, out0, out1, + isize, + jsize, inout=(out0, out1), ref=(ref0, ref1), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py index 8438a735dc..952dcb31bd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py @@ -10,11 +10,11 @@ import numpy as np import gt4py.next as gtx -from gt4py.next import broadcast, astype +from gt4py.next import broadcast, astype, int32 from next_tests import integration_tests from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import cartesian_case +from next_tests.integration_tests.cases import cartesian_case, IDim, KDim from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -29,27 +29,29 @@ def mod_op(f: cases.IField) -> cases.IKField: return f_i_k @gtx.program - def mod_prog(f: cases.IField, out: cases.IKField): + def mod_prog(f: cases.IField, isize: int32, ksize: int32, out: cases.IKField): mod_op( f, out=out, domain={ integration_tests.cases.IDim: ( 0, - 
8, + isize, ), # Nested import done on purpose, do not change - cases.KDim: (0, 3), + cases.KDim: (0, ksize), }, ) f = cases.allocate(cartesian_case, mod_prog, "f")() out = cases.allocate(cartesian_case, mod_prog, "out")() expected = np.zeros_like(out.asnumpy()) - expected[0:8, 0:3] = np.reshape(np.repeat(f.asnumpy(), out.shape[1], axis=0), out.shape)[ - 0:8, 0:3 - ] + isize = cartesian_case.default_sizes[IDim] - 1 + ksize = cartesian_case.default_sizes[KDim] - 2 + expected[0:isize, 0:ksize] = np.reshape( + np.repeat(f.asnumpy(), out.shape[1], axis=0), out.shape + )[0:isize, 0:ksize] - cases.verify(cartesian_case, mod_prog, f, out=out, ref=expected) + cases.verify(cartesian_case, mod_prog, f, isize, ksize, out=out, ref=expected) # TODO: these set of features should be allowed as module imports in a later PR diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 1707adada8..bf6dd34cca 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -75,7 +75,9 @@ def test_mod(cartesian_case): def mod_fieldop(inp1: cases.IField) -> cases.IField: return inp1 % 2 - inp1 = cartesian_case.as_field([IDim], np.asarray(range(10), dtype=int32) - 5) + inp1 = cartesian_case.as_field( + [IDim], np.asarray(range(cartesian_case.default_sizes[IDim]), dtype=int32) - 5 + ) out = cases.allocate(cartesian_case, mod_fieldop, cases.RETURN)() cases.verify(cartesian_case, mod_fieldop, inp1, out=out, ref=inp1 % 2) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py index 7d634cec90..cd10c10437 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py @@ -11,6 +11,7 @@ import pytest from next_tests.integration_tests.cases import IDim, JDim, KDim, Koff, cartesian_case from gt4py import next as gtx +from gt4py.next import int32 from gt4py.next.ffront.fbuiltins import where, broadcast from next_tests.integration_tests import cases from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -27,8 +28,14 @@ def fieldop_where_k_offset( return where(k_index > 0, inp(Koff[-1]), 2) @gtx.program - def prog(inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType], out: cases.IKField): - fieldop_where_k_offset(inp, k_index, out=out, domain={IDim: (0, 10), KDim: (1, 10)}) + def prog( + inp: cases.IKField, + k_index: gtx.Field[[KDim], gtx.IndexType], + isize: int32, + ksize: int32, + out: cases.IKField, + ): + fieldop_where_k_offset(inp, k_index, out=out, domain={IDim: (0, isize), KDim: (1, ksize)}) inp = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")() k_index = cases.allocate( @@ -37,8 +44,9 @@ def prog(inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType], out: cas out = cases.allocate(cartesian_case, fieldop_where_k_offset, cases.RETURN)() ref = np.where(k_index.asnumpy() > 0, np.roll(inp.asnumpy(), 1, axis=1), out.asnumpy()) - - cases.verify(cartesian_case, prog, inp, k_index, out=out, ref=ref) + isize = cartesian_case.default_sizes[IDim] + ksize = cartesian_case.default_sizes[KDim] + cases.verify(cartesian_case, prog, inp, k_index, isize, ksize, out=out, ref=ref) 
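# A minimal, self-contained sketch (plain Python, not gt4py code) of how the frontend now
# distributes the `domain` argument over a tuple-valued `out` (see the decorator.py and
# embedded/operators.py hunks above). A simplified `tree_map` stands in for
# `gt4py.next.utils.tree_map`; strings and dicts stand in for fields and domains.
def tree_map(fun):
    def impl(*args):
        if isinstance(args[0], tuple):
            return tuple(impl(*elems) for elems in zip(*args))
        return fun(*args)
    return impl

out = ("out_b", ("out_a0", "out_a1"))   # tuple-structured output argument
single_domain = {"IDim": (0, 10)}       # a non-tuple domain is broadcast over `out` ...
domain = tree_map(lambda _: single_domain)(out)
assert domain == (single_domain, (single_domain, single_domain))
# ... and then applied element-wise, mirroring `tree_map(lambda f, dom: f[dom])(out, domain)`.
# When `domain` is already a (nested) tuple of dicts, it is used as-is and must match the
# structure of `out`, which is what the new `_validate_domain_out` check enforces.
sliced = tree_map(lambda f, dom: (f, dom))(out, domain)
assert sliced == (
    ("out_b", single_domain),
    (("out_a0", single_domain), ("out_a1", single_domain)),
)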
def test_same_size_fields(cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index 5727f29a2a..e5c5182b81 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -37,7 +37,7 @@ def test_allocate_default_unique(cartesian_case): a = cases.allocate(cartesian_case, mixed_args, "a")() assert np.min(a.asnumpy()) == 1 - assert np.max(a.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) + assert np.max(a.asnumpy()) == np.prod(tuple(list(cartesian_case.default_sizes.values())[:3])) b = cases.allocate(cartesian_case, mixed_args, "b")() @@ -46,7 +46,7 @@ def test_allocate_default_unique(cartesian_case): c = cases.allocate(cartesian_case, mixed_args, "c")() assert np.min(c.asnumpy()) == b + 1 - assert np.max(c.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 + 1 + assert np.max(c.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())[:3]) * 2 + 1 def test_allocate_return_default_zeros(cartesian_case): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py new file mode 100644 index 0000000000..dd30caa726 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -0,0 +1,682 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import copy + +import numpy as np +import pytest + +import gt4py.next as gtx +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import ( + IDim, + JDim, + KDim, + C2E, + E2V, + V2E, + Edge, + EField, + CField, + VField, + Cell, + cartesian_case, + unstructured_case, + Case, + IField, + JField, + KField, +) +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + exec_alloc_descriptor, + mesh_descriptor, +) + +KHalfDim = gtx.Dimension("KHalf", kind=gtx.DimensionKind.VERTICAL) +pytestmark = pytest.mark.uses_cartesian_shift + + +@gtx.field_operator +def testee_no_tuple(a: IField, b: JField) -> IField: + return a + + +@gtx.program +def prog_no_tuple( + a: IField, + b: JField, + out_a: IField, + i_size: gtx.int32, +): + testee_no_tuple(a, b, out=out_a, domain={IDim: (0, i_size)}) + + +def test_program_no_tuple(cartesian_case): + a = cases.allocate(cartesian_case, prog_no_tuple, "a")() + b = cases.allocate(cartesian_case, prog_no_tuple, "b")() + out_a = cases.allocate(cartesian_case, prog_no_tuple, "out_a")() + + cases.verify( + cartesian_case, + prog_no_tuple, + a, + b, + out_a, + cartesian_case.default_sizes[IDim], + inout=out_a, + ref=a, + ) + + +@gtx.field_operator +def fop_original(a: IField, b: IField) -> tuple[IField, IField]: + return b, a + + +@gtx.program +def prog_orig( + a: IField, + b: IField, + out_a: IField, + out_b: IField, + i_size: gtx.int32, +): + fop_original(a, b, out=(out_b, out_a), domain={IDim: (0, i_size)}) + + +def test_program_orig(cartesian_case): + a = cases.allocate(cartesian_case, prog_orig, "a")() + b = cases.allocate(cartesian_case, prog_orig, "b")() + out_a = cases.allocate(cartesian_case, prog_orig, "out_a")() + out_b = 
cases.allocate(cartesian_case, prog_orig, "out_b")() + + cases.verify( + cartesian_case, + prog_orig, + a, + b, + out_a, + out_b, + cartesian_case.default_sizes[IDim], + inout=(out_b, out_a), + ref=(b, a), + ) + + +@gtx.program +def prog_no_domain( + a: IField, + b: IField, + out_a: IField, + out_b: IField, +): + fop_original(a, b, out=(out_b, out_a)) + + +def test_program_no_domain(cartesian_case): + a = cases.allocate(cartesian_case, prog_no_domain, "a")() + b = cases.allocate(cartesian_case, prog_no_domain, "b")() + out_a = cases.allocate(cartesian_case, prog_no_domain, "out_a")() + out_b = cases.allocate(cartesian_case, prog_no_domain, "out_b")() + + cases.verify( + cartesian_case, + prog_no_domain, + a, + b, + out_a, + out_b, + inout=(out_b, out_a), + ref=(b, a), + ) + + +@gtx.field_operator +def fop_different_fields(a: IField, b: JField) -> tuple[JField, IField]: + return b, a + + +@gtx.program +def prog_no_domain_different_fields( + a: IField, + b: JField, + out_a: IField, + out_b: JField, +): + fop_different_fields(a, b, out=(out_b, out_a)) + + +def test_program_no_domain_different_fields( + cartesian_case, +): + a = cases.allocate(cartesian_case, prog_no_domain_different_fields, "a")() + b = cases.allocate(cartesian_case, prog_no_domain_different_fields, "b")() + out_a = cases.allocate(cartesian_case, prog_no_domain_different_fields, "out_a")() + out_b = cases.allocate(cartesian_case, prog_no_domain_different_fields, "out_b")() + + cases.verify( + cartesian_case, + prog_no_domain_different_fields, + a, + b, + out_a, + out_b, + inout=(out_b, out_a), + ref=(b, a), + ) + + +@gtx.program +def prog( + a: IField, + b: JField, + out_a: IField, + out_b: JField, + i_size: gtx.int32, + j_size: gtx.int32, +): + fop_different_fields( + a, b, out=(out_b, out_a), domain=({JDim: (0, j_size)}, {IDim: (0, i_size)}) + ) + + +def test_program(cartesian_case): + a = cases.allocate(cartesian_case, prog, "a")() + b = cases.allocate(cartesian_case, prog, "b")() + out_a = cases.allocate(cartesian_case, prog, "out_a")() + out_b = cases.allocate(cartesian_case, prog, "out_b")() + + cases.verify( + cartesian_case, + prog, + a, + b, + out_a, + out_b, + cartesian_case.default_sizes[IDim], + cartesian_case.default_sizes[JDim], + inout=(out_b, out_a), + ref=(b, a), + ) + + +@gtx.program +def prog_slicing( + a: IField, + b: JField, + out_a: IField, + out_b: JField, +): + fop_different_fields( + a, + b, + out=(out_b[2:-2], out_a[1:-1]), + ) + + +def test_program_slicing(cartesian_case): + a = cases.allocate(cartesian_case, prog, "a")() + b = cases.allocate(cartesian_case, prog, "b")() + out_a = cases.allocate(cartesian_case, prog, "out_a")() + out_b = cases.allocate(cartesian_case, prog, "out_b")() + out_a_ = copy.deepcopy(out_a) + out_b_ = copy.deepcopy(out_b) + cases.verify( + cartesian_case, + prog_slicing, + a, + b, + out_a, + out_b, + inout=(out_b, out_a), + ref=( + np.concatenate([out_b_.ndarray[0:2], b.ndarray[2:-2], out_b_.ndarray[-2:]]), + np.concatenate([out_a_.ndarray[0:1], a.ndarray[1:-1], out_a_.ndarray[-1:]]), + ), + ) + + +@gtx.program +def prog_out_as_tuple( + a: IField, + b: JField, + out: tuple[JField, IField], + i_size: gtx.int32, + j_size: gtx.int32, +): + fop_different_fields(a, b, out=out, domain=({JDim: (0, j_size)}, {IDim: (0, i_size)})) + + +def test_program_out_as_tuple( + cartesian_case, +): + a = cases.allocate(cartesian_case, prog_out_as_tuple, "a")() + b = cases.allocate(cartesian_case, prog_out_as_tuple, "b")() + out = cases.allocate(cartesian_case, prog_out_as_tuple, "out")() + + 
cases.verify( + cartesian_case, + prog_out_as_tuple, + a, + b, + out, + cartesian_case.default_sizes[IDim], + cartesian_case.default_sizes[JDim], + inout=(out), + ref=(b, a), + ) + + +@gtx.program +def prog_out_as_tuple_different_sizes( + a: IField, + b: JField, + out: tuple[JField, IField], + i_size: gtx.int32, + j_size: gtx.int32, + restrict_i_0: gtx.int32, + restrict_i_1: gtx.int32, + restrict_j_0: gtx.int32, + restrict_j_1: gtx.int32, +): + fop_different_fields( + a, + b, + out=out, + domain=( + {JDim: (restrict_j_0, j_size + restrict_j_1)}, + {IDim: (restrict_i_0, i_size + restrict_i_1)}, + ), + ) + + +def test_program_out_as_tuple_different_sizes( + cartesian_case, +): + restrict_i = (1, -3) + restrict_j = (2, -4) + i_size = cartesian_case.default_sizes[IDim] + j_size = cartesian_case.default_sizes[JDim] + a = cases.allocate(cartesian_case, prog_out_as_tuple_different_sizes, "a")() + b = cases.allocate(cartesian_case, prog_out_as_tuple_different_sizes, "b")() + out = cases.allocate( + cartesian_case, + prog_out_as_tuple_different_sizes, + "out", + extend={IDim: (-restrict_i[0], restrict_i[1]), JDim: (-restrict_j[0], restrict_j[1])}, + )() + + cases.verify( + cartesian_case, + prog_out_as_tuple_different_sizes, + a, + b, + out, + i_size, + j_size, + restrict_i[0], + restrict_i[1], + restrict_j[0], + restrict_j[1], + inout=(out), + ref=( + b.ndarray[restrict_j[0] : j_size + restrict_j[1]], + a.ndarray[restrict_i[0] : i_size + restrict_i[1]], + ), + ) + + +@gtx.field_operator +def fop_nested_tuples( + a: IField, + b: JField, + c: KField, +) -> tuple[ + tuple[IField, JField], + KField, +]: + return (a, b), c + + +@gtx.program +def prog_nested_tuples( + a: IField, + b: JField, + c: KField, + out_a: IField, + out_b: JField, + out_c: KField, + i_size: gtx.int32, + j_size: gtx.int32, + k_size: gtx.int32, +): + fop_nested_tuples( + a, + b, + c, + out=((out_a, out_b), out_c), + domain=(({IDim: (0, i_size)}, {JDim: (0, j_size)}), {KDim: (0, k_size)}), + ) + + +def test_program_nested_tuples( + cartesian_case, +): + a = cases.allocate(cartesian_case, prog_nested_tuples, "a")() + b = cases.allocate(cartesian_case, prog_nested_tuples, "b")() + c = cases.allocate(cartesian_case, prog_nested_tuples, "c")() + out_a = cases.allocate(cartesian_case, prog_nested_tuples, "out_a")() + out_b = cases.allocate(cartesian_case, prog_nested_tuples, "out_b")() + out_c = cases.allocate(cartesian_case, prog_nested_tuples, "out_c")() + + cases.verify( + cartesian_case, + prog_nested_tuples, + a, + b, + c, + out_a, + out_b, + out_c, + cartesian_case.default_sizes[IDim], + cartesian_case.default_sizes[JDim], + cartesian_case.default_sizes[KDim], + inout=((out_a, out_b), out_c), + ref=((a, b), c), + ) + + +@gtx.field_operator +def fop_double_nested_tuples( + a: IField, + b: JField, + c: KField, +) -> tuple[ + tuple[ + IField, + tuple[JField, KField], + ], + KField, +]: + return (a, (b, c)), c + + +@gtx.program +def prog_double_nested_tuples( + a: IField, + b: JField, + c: KField, + out_a: IField, + out_b: JField, + out_c0: KField, + out_c1: KField, + i_size: gtx.int32, + j_size: gtx.int32, + k_size: gtx.int32, +): + fop_double_nested_tuples( + a, + b, + c, + out=((out_a, (out_b, out_c0)), out_c1), + domain=( + ({IDim: (0, i_size)}, ({JDim: (0, j_size)}, {KDim: (0, k_size)})), + {KDim: (0, k_size)}, + ), + ) + + +def test_program_double_nested_tuples( + cartesian_case, +): + a = cases.allocate(cartesian_case, prog_double_nested_tuples, "a")() + b = cases.allocate(cartesian_case, prog_double_nested_tuples, "b")() + c = 
cases.allocate(cartesian_case, prog_double_nested_tuples, "c")() + out_a = cases.allocate(cartesian_case, prog_double_nested_tuples, "out_a")() + out_b = cases.allocate(cartesian_case, prog_double_nested_tuples, "out_b")() + out_c0 = cases.allocate(cartesian_case, prog_double_nested_tuples, "out_c0")() + out_c1 = cases.allocate(cartesian_case, prog_double_nested_tuples, "out_c1")() + + cases.verify( + cartesian_case, + prog_double_nested_tuples, + a, + b, + c, + out_a, + out_b, + out_c0, + out_c1, + cartesian_case.default_sizes[IDim], + cartesian_case.default_sizes[JDim], + cartesian_case.default_sizes[KDim], + inout=((out_a, (out_b, out_c0)), out_c1), + ref=((a, (b, c)), c), + ) + + +@gtx.field_operator +def fop_two_vertical_dims( + a: KField, b: gtx.Field[[KHalfDim], gtx.float32] +) -> tuple[gtx.Field[[KHalfDim], gtx.float32], KField]: + return b, a + + +@gtx.program +def prog_two_vertical_dims( + a: KField, + b: gtx.Field[[KHalfDim], gtx.float32], + out_a: KField, + out_b: gtx.Field[[KHalfDim], gtx.float32], + k_size: gtx.int32, + k_half_size: gtx.int32, +): + fop_two_vertical_dims( + a, b, out=(out_b, out_a), domain=({KHalfDim: (0, k_half_size)}, {KDim: (0, k_size)}) + ) + + +def test_program_two_vertical_dims(cartesian_case): + a = cases.allocate(cartesian_case, prog_two_vertical_dims, "a")() + b = cases.allocate(cartesian_case, prog_two_vertical_dims, "b")() + out_a = cases.allocate(cartesian_case, prog_two_vertical_dims, "out_a")() + out_b = cases.allocate(cartesian_case, prog_two_vertical_dims, "out_b")() + + cases.verify( + cartesian_case, + prog_two_vertical_dims, + a, + b, + out_a, + out_b, + cartesian_case.default_sizes[KDim], + cartesian_case.default_sizes[KHalfDim], + inout=(out_b, out_a), + ref=(b, a), + ) + + +@gtx.field_operator +def fop_shift_e2c(a: EField) -> tuple[CField, EField]: + return a(C2E[1]), a + + +@gtx.program +def prog_unstructured( + a: EField, + out_a: EField, + out_a_shifted: CField, + c_size: gtx.int32, + e_size: gtx.int32, +): + fop_shift_e2c(a, out=(out_a_shifted, out_a), domain=({Cell: (0, c_size)}, {Edge: (0, e_size)})) + + +def test_program_unstructured(unstructured_case): + a = cases.allocate(unstructured_case, prog_unstructured, "a")() + out_a = cases.allocate(unstructured_case, prog_unstructured, "out_a")() + out_a_shifted = cases.allocate(unstructured_case, prog_unstructured, "out_a_shifted")() + + cases.verify( + unstructured_case, + prog_unstructured, + a, + out_a, + out_a_shifted, + unstructured_case.default_sizes[Cell], + unstructured_case.default_sizes[Edge], + inout=(out_a_shifted, out_a), + ref=((a.ndarray)[unstructured_case.offset_provider["C2E"].asnumpy()[:, 1]], a), + ) + + +@gtx.field_operator +def fop_temporary(a: VField): + edge = a(E2V[1]) + cell = edge(C2E[1]) + return edge, cell + + +@gtx.program +def prog_temporary( + a: VField, + out_edge: EField, + out_cell: CField, + c_size: gtx.int32, + e_size: gtx.int32, + restrict_edge_0: gtx.int32, + restrict_edge_1: gtx.int32, + restrict_cell_0: gtx.int32, + restrict_cell_1: gtx.int32, +): + fop_temporary( + a, + out=(out_edge, out_cell), + domain=( + {Edge: (restrict_edge_0, e_size + restrict_edge_1)}, + {Cell: (restrict_cell_0, c_size + restrict_cell_1)}, + ), + ) + + +def test_program_temporary(unstructured_case): + restrict_edge = (4, -2) + restrict_cell = (3, -1) + cell_size = unstructured_case.default_sizes[Cell] + edge_size = unstructured_case.default_sizes[Edge] + a = cases.allocate(unstructured_case, prog_temporary, "a")() + out_edge = cases.allocate( + unstructured_case, + 
prog_temporary, + "out_edge", + extend={Edge: (-restrict_edge[0], restrict_edge[1])}, + )() + out_cell = cases.allocate( + unstructured_case, + prog_temporary, + "out_cell", + extend={Cell: (-restrict_cell[0], restrict_cell[1])}, + )() + + e2v = (a.ndarray)[unstructured_case.offset_provider["E2V"].asnumpy()[:, 1]] + cases.verify( + unstructured_case, + prog_temporary, + a, + out_edge, + out_cell, + cell_size, + edge_size, + restrict_edge[0], + restrict_edge[1], + restrict_cell[0], + restrict_cell[1], + inout=(out_edge, out_cell), + ref=( + e2v[restrict_edge[0] : edge_size + restrict_edge[1]], + e2v[unstructured_case.offset_provider["C2E"].asnumpy()[:, 1]][ + restrict_cell[0] : cell_size + restrict_cell[1] + ], + ), + ) + + +def test_direct_fo_orig(cartesian_case): + a = cases.allocate(cartesian_case, fop_original, "a")() + b = cases.allocate(cartesian_case, fop_original, "b")() + out = cases.allocate(cartesian_case, fop_original, cases.RETURN)() + + cases.verify( + cartesian_case, + fop_original, + a, + b, + out=out, + ref=(b, a), + domain={IDim: (0, cartesian_case.default_sizes[IDim])}, + ) + + +def test_direct_fo_nested(cartesian_case): + a = cases.allocate(cartesian_case, fop_nested_tuples, "a")() + b = cases.allocate(cartesian_case, fop_nested_tuples, "b")() + c = cases.allocate(cartesian_case, fop_nested_tuples, "c")() + out = cases.allocate(cartesian_case, fop_nested_tuples, cases.RETURN)() + + cases.verify( + cartesian_case, + fop_nested_tuples, + a, + b, + c, + out=out, + ref=((a, b), c), + domain=( + ( + {IDim: (0, cartesian_case.default_sizes[IDim])}, + {JDim: (0, cartesian_case.default_sizes[JDim])}, + ), + {KDim: (0, cartesian_case.default_sizes[KDim])}, + ), + ) + + +def test_direct_fo(cartesian_case): + a = cases.allocate(cartesian_case, fop_different_fields, "a")() + b = cases.allocate(cartesian_case, fop_different_fields, "b")() + out = cases.allocate(cartesian_case, fop_different_fields, cases.RETURN)() + + cases.verify( + cartesian_case, + fop_different_fields, + a, + b, + out=out, + ref=(b, a), + domain=( + {JDim: (0, cartesian_case.default_sizes[JDim])}, + {IDim: (0, cartesian_case.default_sizes[IDim])}, + ), + ) + + +def test_direct_fo_nested_no_domain(cartesian_case): + a = cases.allocate(cartesian_case, fop_nested_tuples, "a")() + b = cases.allocate(cartesian_case, fop_nested_tuples, "b")() + c = cases.allocate(cartesian_case, fop_nested_tuples, "c")() + out = cases.allocate(cartesian_case, fop_nested_tuples, cases.RETURN)() + + cases.verify( + cartesian_case, + fop_nested_tuples, + a, + b, + c, + out=out, + ref=((a, b), c), + ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 411fc68f54..24cd3426e8 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -256,6 +256,7 @@ def ksum_even_odd_fencil(i_size, k_size, inp, out): @pytest.mark.uses_scan +@pytest.mark.uses_tuple_iterator def test_ksum_even_odd_scan(program_processor): program_processor, validate = program_processor shape = [1, 7] @@ -304,6 +305,7 @@ def ksum_even_odd_nested_fencil(i_size, k_size, inp, out): @pytest.mark.uses_scan +@pytest.mark.uses_tuple_iterator def test_ksum_even_odd_nested_scan(program_processor): program_processor, validate = program_processor shape = [1, 7] diff --git 
a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index ad985e7ee8..2234895b5b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -154,7 +154,7 @@ def domain_format_1_program(in_field: gtx.Field[[IDim], float64]): assert exc_info.match("Invalid call to 'domain_format_1'") assert ( - re.search("Only Dictionaries allowed in 'domain'", exc_info.value.__cause__.args[0]) + re.search("Tuple domain requires tuple output", exc_info.value.__cause__.args[0]) is not None ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 81ed0aee62..3fe845cf08 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -247,7 +247,10 @@ def test_gtir_tuple_swap(): body=[ gtir.SetAt( expr=im.make_tuple("y", "x"), - domain=im.get_field_domain(gtx_common.GridType.CARTESIAN, "x", [IDim]), + domain=im.make_tuple( + im.get_field_domain(gtx_common.GridType.CARTESIAN, "x", [IDim]), + im.get_field_domain(gtx_common.GridType.CARTESIAN, "y", [IDim]), + ), # TODO(havogt): add a frontend check for this pattern target=im.make_tuple("x", "y"), ) @@ -455,10 +458,22 @@ def test_gtir_tuple_return(): body=[ gtir.SetAt( expr=im.make_tuple(im.make_tuple(im.op_as_fieldop("plus")("x", "y"), "x"), "y"), - domain=im.get_field_domain( - gtx_common.GridType.CARTESIAN, - im.tuple_get(0, im.tuple_get(0, "z")), - [IDim], + domain=im.make_tuple( + im.make_tuple( + im.get_field_domain( + gtx_common.GridType.CARTESIAN, + im.tuple_get(0, im.tuple_get(0, "z")), + [IDim], + ), + im.get_field_domain( + gtx_common.GridType.CARTESIAN, + im.tuple_get(1, im.tuple_get(0, "z")), + [IDim], + ), + ), + im.get_field_domain( + gtx_common.GridType.CARTESIAN, im.tuple_get(1, "z"), [IDim] + ), ), target=gtir.SymRef(id="z"), ) @@ -502,7 +517,10 @@ def test_gtir_tuple_target(): body=[ gtir.SetAt( expr=im.make_tuple(im.op_as_fieldop("plus")("x", 1.0), gtir.SymRef(id="x")), - domain=im.get_field_domain(gtx_common.GridType.CARTESIAN, "x", [IDim]), + domain=im.make_tuple( + im.get_field_domain(gtx_common.GridType.CARTESIAN, "x", [IDim]), + im.get_field_domain(gtx_common.GridType.CARTESIAN, "y", [IDim]), + ), target=im.make_tuple("x", "y"), ) ], @@ -1851,8 +1869,13 @@ def test_gtir_let_lambda_with_tuple1(): im.make_tuple(im.op_as_fieldop("plus", inner_domain)("x", "y"), "x"), "y" ), )(im.make_tuple(im.tuple_get(1, im.tuple_get(0, "t")), im.tuple_get(1, "t"))), - domain=im.get_field_domain( - gtx_common.GridType.CARTESIAN, im.tuple_get(0, "z"), [IDim] + domain=im.make_tuple( + im.get_field_domain( + gtx_common.GridType.CARTESIAN, im.tuple_get(0, "z"), [IDim] + ), + im.get_field_domain( + gtx_common.GridType.CARTESIAN, im.tuple_get(1, "z"), [IDim] + ), ), target=gtir.SymRef(id="z"), ) @@ -1905,8 +1928,16 @@ def test_gtir_let_lambda_with_tuple2(): ) ) ), - domain=im.get_field_domain( - gtx_common.GridType.CARTESIAN, im.tuple_get(0, "z"), [IDim] + domain=im.make_tuple( + im.get_field_domain( + gtx_common.GridType.CARTESIAN, im.tuple_get(0, "z"), [IDim] + ), + im.get_field_domain( + gtx_common.GridType.CARTESIAN, im.tuple_get(1, "z"), [IDim] + ), + im.get_field_domain( + 
+                    gtx_common.GridType.CARTESIAN, im.tuple_get(2, "z"), [IDim]
+                ),
             ),
             target=gtir.SymRef(id="z"),
         )
@@ -2248,7 +2279,7 @@ def test_gtir_scan(id, use_symbolic_column_size):
                     im.make_tuple(0.0, True),
                 )
             )("x"),
-            domain=domain,
+            domain=im.make_tuple(domain, domain),
             target=im.make_tuple(gtir.SymRef(id="y"), gtir.SymRef(id="z")),
         )
     ],
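# --- Illustrative sketch (not part of this diff) ------------------------------------
# The new tests exercise per-output domains: when a field operator returns a tuple,
# `domain` may be a tuple of dicts with the same (possibly nested) structure as `out`,
# so each output field gets its own write range. Below is a minimal direct call in the
# spirit of test_direct_fo; the names (swap, a, b, out_a, out_b) and the sizes are
# hypothetical, assuming an installed gt4py.next.
import numpy as np

import gt4py.next as gtx

IDim = gtx.Dimension("IDim")
JDim = gtx.Dimension("JDim")


@gtx.field_operator
def swap(
    a: gtx.Field[[IDim], gtx.float64], b: gtx.Field[[JDim], gtx.float64]
) -> tuple[gtx.Field[[JDim], gtx.float64], gtx.Field[[IDim], gtx.float64]]:
    return b, a


a = gtx.as_field([IDim], np.arange(10.0))
b = gtx.as_field([JDim], np.arange(12.0))
out_a = gtx.as_field([IDim], np.zeros(10))
out_b = gtx.as_field([JDim], np.zeros(12))

# One domain entry per entry of `out`, matching its tuple structure.
swap(a, b, out=(out_b, out_a), domain=({JDim: (0, 12)}, {IDim: (0, 10)}), offset_provider={})
# -------------------------------------------------------------------------------------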