diff --git a/pyproject.toml b/pyproject.toml index 0461d13a78..dd00401726 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -304,6 +304,7 @@ markers = [ 'uses_mesh_with_skip_values: tests that use a mesh with skip values', 'uses_concat_where: tests that use the concat_where builtin', 'uses_program_metrics: tests that require backend support for program metrics', + 'uses_program_with_sliced_out_arguments: tests that use a sliced argument which is not supported for non-mutable arrays, e.g. JAX', 'checks_specific_error: tests that rely on the backend to produce a specific error message' ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 14adb85d0a..6435f1a8b4 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -9,7 +9,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence -from typing import Optional, cast +from typing import Optional, Protocol, TypeGuard, cast import gt4py.eve as eve import gt4py.eve.extended_typing as xtyping @@ -20,13 +20,33 @@ from gt4py._core import definitions as core_defs, ndarray_utils +# TODO take from https://github.com/data-apis/array-api-typing/pull/56/ +# only added here for documentation purposes +class NDArrayNamespace(Protocol): ... + + +def is_array_namespace(obj: object) -> TypeGuard[NDArrayNamespace]: + """Check whether `obj` is an array namespace. + + An array namespace is any module that follows the Array API standard + (https://data-apis.org/array-api/latest/). + """ + return ( + hasattr(obj, "empty") + and hasattr(obj, "zeros") + and hasattr(obj, "ones") + and hasattr(obj, "full") + and hasattr(obj, "asarray") + ) + + @eve.utils.with_fluid_partial def empty( domain: common.DomainLike, dtype: core_defs.DTypeLike = core_defs.Float64DType(()), # noqa: B008 [function-call-in-default-argument] *, aligned_index: Optional[Sequence[common.NamedIndex]] = None, - allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None, + allocator: Optional[next_allocators.FieldBufferAllocationUtil | NDArrayNamespace] = None, device: Optional[core_defs.Device] = None, ) -> nd_array_field.NdArrayField: """Create a `Field` of uninitialized (undefined) values using the given (or device-default) allocator. @@ -77,6 +97,12 @@ def empty( (3, 3) """ dtype = core_defs.dtype(dtype) + if allocator is not None and is_array_namespace(allocator): + domain = common.domain(domain) + shape = domain.shape + return common._field( + allocator.empty(shape, dtype=dtype.scalar_type), domain=domain + ) # TODO device if allocator is None and device is None: device = core_defs.Device(core_defs.DeviceType.CPU, device_id=0) buffer = next_allocators.allocate( @@ -94,7 +120,7 @@ def zeros( dtype: core_defs.DTypeLike = core_defs.Float64DType(()), # noqa: B008 [function-call-in-default-argument] *, aligned_index: Optional[Sequence[common.NamedIndex]] = None, - allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None, + allocator: Optional[next_allocators.FieldBufferAllocationUtil | NDArrayNamespace] = None, device: Optional[core_defs.Device] = None, ) -> nd_array_field.NdArrayField: """Create a Field containing all zeros using the given (or device-default) allocator. @@ -108,6 +134,13 @@ def zeros( >>> gtx.zeros({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([0., 0., 0., 0., 0., 0., 0.]) """ + if allocator is not None and is_array_namespace(allocator): + dtype = core_defs.dtype(dtype) if dtype is not None else None + domain = common.domain(domain) + shape = domain.shape + return common._field( + allocator.zeros(shape, dtype=dtype.scalar_type), domain=domain + ) # TODO device field = empty( domain=domain, dtype=dtype, aligned_index=aligned_index, allocator=allocator, device=device ) @@ -121,7 +154,7 @@ def ones( dtype: core_defs.DTypeLike = core_defs.Float64DType(()), # noqa: B008 [function-call-in-default-argument] *, aligned_index: Optional[Sequence[common.NamedIndex]] = None, - allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None, + allocator: Optional[next_allocators.FieldBufferAllocationUtil | NDArrayNamespace] = None, device: Optional[core_defs.Device] = None, ) -> nd_array_field.NdArrayField: """Create a Field containing all ones using the given (or device-default) allocator. @@ -135,6 +168,13 @@ def ones( >>> gtx.ones({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([1., 1., 1., 1., 1., 1., 1.]) """ + if allocator is not None and is_array_namespace(allocator): + dtype = core_defs.dtype(dtype) if dtype is not None else None + domain = common.domain(domain) + shape = domain.shape + return common._field( + allocator.ones(shape, dtype=dtype.scalar_type), domain=domain + ) # TODO device field = empty( domain=domain, dtype=dtype, aligned_index=aligned_index, allocator=allocator, device=device ) @@ -149,7 +189,7 @@ def full( dtype: Optional[core_defs.DTypeLike] = None, *, aligned_index: Optional[Sequence[common.NamedIndex]] = None, - allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None, + allocator: Optional[next_allocators.FieldBufferAllocationUtil | NDArrayNamespace] = None, device: Optional[core_defs.Device] = None, ) -> nd_array_field.NdArrayField: """Create a Field where all values are set to `fill_value` using the given (or device-default) allocator. @@ -168,6 +208,13 @@ def full( >>> gtx.full({IDim: 3}, 5, allocator=gtx.itir_python).ndarray array([5, 5, 5]) """ + if allocator is not None and is_array_namespace(allocator): + dtype = core_defs.dtype(dtype) if dtype is not None else None + domain = common.domain(domain) + shape = domain.shape + return common._field( + allocator.full(shape, fill_value, dtype=dtype.scalar_type), domain=domain + ) # TODO device field = empty( domain=domain, dtype=dtype if dtype is not None else core_defs.dtype(type(fill_value)), @@ -187,7 +234,7 @@ def as_field( *, origin: Optional[Mapping[common.Dimension, int]] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, - allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None, + allocator: Optional[next_allocators.FieldBufferAllocationUtil | NDArrayNamespace] = None, device: Optional[core_defs.Device] = None, # TODO: copy=False ) -> nd_array_field.NdArrayField: @@ -267,6 +314,11 @@ def as_field( dtype = core_defs.dtype(dtype) assert dtype.tensor_shape == () # TODO + if allocator is not None and is_array_namespace(allocator): + return common._field( + allocator.asarray(data, dtype=dtype.scalar_type), domain=actual_domain + ) # TODO device + if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) @@ -290,7 +342,7 @@ def as_connectivity( data: core_defs.NDArrayObject, dtype: Optional[core_defs.DType] = None, *, - allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None, + allocator: Optional[next_allocators.FieldBufferAllocationUtil | NDArrayNamespace] = None, device: Optional[core_defs.Device] = None, skip_value: core_defs.IntegralScalar | eve.NothingType | None = eve.NOTHING, # TODO: copy=False @@ -349,6 +401,14 @@ def as_connectivity( dtype = core_defs.dtype(dtype) assert dtype.tensor_shape == () # TODO + if allocator is not None and is_array_namespace(allocator): + # return common._field( + # allocator.asarray(data, dtype=dtype.scalar_type), domain=actual_domain + # ) # TODO device + return common._connectivity( + data, codomain=codomain, domain=actual_domain, skip_value=skip_value + ) + if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index f4aee67332..234487b930 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -43,8 +43,10 @@ cp: Optional[ModuleType] = None # type: ignore[no-redef] try: + import jax from jax import numpy as jnp except ImportError: + jax: Optional[ModuleType] = None # type: ignore[no-redef] jnp: Optional[ModuleType] = None # type: ignore[no-redef] try: @@ -1087,7 +1089,10 @@ class CuPyArrayConnectivityField(NdArrayConnectivityField): # JAX if jnp: + assert jax is not None + _nd_array_implementations.append(jnp) + jax.config.update("jax_enable_x64", True) @dataclasses.dataclass(frozen=True, eq=False) class JaxArrayField(NdArrayField): @@ -1102,8 +1107,16 @@ def __setitem__( index: common.AnyIndexSpec, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: - # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` - raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.") + target_domain, target_slice = self._slice(index) + + if isinstance(value, common.Field): + if not value.domain == target_domain: + raise ValueError( + f"Incompatible 'Domain' in assignment. Source domain = '{value.domain}', target domain = '{target_domain}'." + ) + value = value.ndarray + + object.__setattr__(self, "_ndarray", self._ndarray.at[target_slice].set(value)) common._field.register(jnp.ndarray, JaxArrayField.from_array) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 0bff0b0aa7..94dac86da8 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -16,6 +16,7 @@ import pytest from gt4py.next import allocators as next_allocators +from gt4py.next.embedded import nd_array_field # Skip definitions @@ -65,11 +66,13 @@ class EmbeddedDummyBackend: cupy_execution = EmbeddedDummyBackend( "EmbeddedCuPy", next_allocators.StandardGPUFieldBufferAllocator() ) +jax_numpy_execution = EmbeddedDummyBackend("EmbeddedJaxNumPy", nd_array_field.jnp) class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum): NUMPY_EXECUTION = "next_tests.definitions.numpy_execution" CUPY_EXECUTION = "next_tests.definitions.cupy_execution" + JAX_NUMPY_EXECUTION = "next_tests.definitions.jax_numpy_execution" class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): @@ -128,6 +131,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_PROGRAM_METRICS = "uses_program_metrics" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" USES_CONCAT_WHERE = "uses_concat_where" +USES_PROGRAM_WITH_SLICED_OUT_ARGUMENTS = "uses_program_with_sliced_out_arguments" CHECKS_SPECIFIC_ERROR = "checks_specific_error" # Skip messages (available format keys: 'marker', 'backend') @@ -138,7 +142,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ) # Common list of feature markers to skip COMMON_SKIP_TEST_LIST = [ - (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), @@ -172,6 +175,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ), # we can't extract the field type from scan args (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] +JAX_EMBEDDED_SKIP_LIST = EMBEDDED_SKIP_LIST + [ + (USES_PROGRAM_WITH_SLICED_OUT_ARGUMENTS, XFAIL, UNSUPPORTED_MESSAGE), +] ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_PROGRAM_METRICS, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), @@ -199,6 +205,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): BACKEND_SKIP_TEST_MATRIX = { EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, + EmbeddedIds.JAX_NUMPY_EXECUTION: JAX_EMBEDDED_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 7640553e6a..d9c8bb1e8f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -82,6 +82,10 @@ def __gt_allocator__( pytest.param( next_tests.definitions.EmbeddedIds.CUPY_EXECUTION, marks=pytest.mark.requires_gpu ), + pytest.param( + next_tests.definitions.EmbeddedIds.JAX_NUMPY_EXECUTION, + marks=pytest.mark.requires_jax, + ), pytest.param( next_tests.definitions.OptionalProgramBackendId.DACE_CPU, marks=pytest.mark.requires_dace, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index e88a72b3a6..c840ddd219 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -305,9 +305,9 @@ def testee(inp: IField) -> IField: out = cases.allocate( cartesian_case, testee, cases.RETURN, strategy=cases.ConstInitializer(42) )() - ref = inp.array_ns.zeros(size) + ref = np.zeros(size) ref[0] = ref[-1] = 42 - ref[1:-1] = inp.ndarray[1:-1] + ref[1:-1] = inp.asnumpy()[1:-1] cases.verify(cartesian_case, testee, inp, out=out, domain={IDim: (1, size - 1)}, ref=ref) @@ -323,8 +323,8 @@ def testee(inp: IField) -> tuple[IField, IField]: out = cases.allocate( cartesian_case, testee, cases.RETURN, strategy=cases.ConstInitializer(42) )() - ref = inp.array_ns.zeros(size) + ref = np.zeros(size) ref[0] = ref[-1] = 42 - ref[1:-1] = inp.ndarray[1:-1] + ref[1:-1] = inp.asnumpy()[1:-1] cases.verify(cartesian_case, testee, inp, out=out, domain={IDim: (1, size - 1)}, ref=(ref, ref)) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8060d5bb36..6a4db8e52a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -496,6 +496,7 @@ def testee( ) +@pytest.mark.uses_program_with_sliced_out_arguments def test_single_value_field(cartesian_case): @gtx.field_operator def testee_fo(a: cases.IKField) -> cases.IKField: @@ -1199,6 +1200,7 @@ def program_domain( ) +@pytest.mark.uses_program_with_sliced_out_arguments def test_domain_tuple(cartesian_case): @gtx.field_operator def fieldop_domain_tuple( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 1abaa47d03..3339858c8e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -52,6 +52,7 @@ def test_identity_fo_execution(cartesian_case, identity_def): ) +@pytest.mark.uses_program_with_sliced_out_arguments @pytest.mark.uses_cartesian_shift def test_shift_by_one_execution(cartesian_case): @gtx.field_operator @@ -95,6 +96,7 @@ def test_double_copy_execution(cartesian_case, double_copy_program_def): ) +@pytest.mark.uses_program_with_sliced_out_arguments def test_copy_restricted_execution(cartesian_case, copy_restrict_program_def): copy_restrict_program = gtx.program(copy_restrict_program_def, backend=cartesian_case.backend) @@ -154,6 +156,7 @@ def prog( assert np.allclose((a.asnumpy(), b.asnumpy()), (out_a.asnumpy(), out_b.asnumpy())) +@pytest.mark.uses_program_with_sliced_out_arguments def test_tuple_program_return_constructed_inside_with_slicing(cartesian_case): @gtx.field_operator def pack_tuple( @@ -279,12 +282,12 @@ def identity(a: cases.IField) -> cases.IField: def copy_program(a: cases.IField, out: cases.IField): identity(a, out=out, domain={IDim: (1, 9)}) - inp = constructors.empty( + inp = constructors.full( common.domain({IDim: (1, 9)}), + 42, dtype=np.int32, allocator=cartesian_case.allocator, ) - inp.ndarray[...] = 42 out = cases.allocate(cartesian_case, copy_program, "out", sizes={IDim: 10})() ref = out.asnumpy().copy() # ensure we are not writing to `out` outside the domain ref[1:9] = inp.asnumpy() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index da354be7ea..6c72bff19d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -13,10 +13,8 @@ pytest.importorskip("atlas4py") -import gt4py._core.definitions as core_defs from gt4py import next as gtx -from gt4py.next import allocators, neighbor_sum -from gt4py.next.iterator import atlas_utils +from gt4py.next import neighbor_sum from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -65,10 +63,6 @@ def pnabla( @pytest.mark.requires_atlas def test_ffront_compute_zavgS(exec_alloc_descriptor): - # TODO(havogt): fix nabla setup to work with GPU - if exec_alloc_descriptor.allocator.device_type != core_defs.DeviceType.CPU: - pytest.skip("This test is only supported on CPU devices yet") - setup = nabla_setup(allocator=exec_alloc_descriptor.allocator) zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=exec_alloc_descriptor.allocator) @@ -88,10 +82,6 @@ def test_ffront_compute_zavgS(exec_alloc_descriptor): @pytest.mark.requires_atlas def test_ffront_nabla(exec_alloc_descriptor): - # TODO(havogt): fix nabla setup to work with GPU - if exec_alloc_descriptor.allocator.device_type != core_defs.DeviceType.CPU: - pytest.skip("This test is only supported on CPU devices yet") - setup = nabla_setup(allocator=exec_alloc_descriptor.allocator) pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=exec_alloc_descriptor.allocator) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 6d7fd9df2b..8cb14ba384 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -214,6 +214,7 @@ class setup: @pytest.mark.uses_tuple_returns +@pytest.mark.uses_program_with_sliced_out_arguments def test_solve_nonhydro_stencil_52_like_z_q(test_setup): cases.verify( test_setup.case, @@ -232,6 +233,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): @pytest.mark.uses_tuple_returns +@pytest.mark.uses_program_with_sliced_out_arguments def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): if test_setup.case.backend == test_definitions.ProgramBackendId.ROUNDTRIP.load(): pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") @@ -251,6 +253,7 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): @pytest.mark.uses_tuple_returns +@pytest.mark.uses_program_with_sliced_out_arguments def test_solve_nonhydro_stencil_52_like(test_setup): cases.run( test_setup.case, @@ -267,6 +270,7 @@ def test_solve_nonhydro_stencil_52_like(test_setup): @pytest.mark.uses_tuple_returns +@pytest.mark.uses_program_with_sliced_out_arguments def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): if test_setup.case.backend == test_definitions.ProgramBackendId.ROUNDTRIP.load(): pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py index 5c6bd5a54a..f6a184ae27 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py @@ -84,6 +84,7 @@ def skewedlap_ref(inp): return -4.0 * inp[1:-1, 1:-1] + inp[2:, 2:] + inp[2:, :-2] + inp[:-2, 2:] + inp[:-2, :-2] +@pytest.mark.uses_program_with_sliced_out_arguments def test_ffront_lap(cartesian_case): in_field = cases.allocate(cartesian_case, lap_program, "in_field")() in_field = square(in_field) @@ -99,6 +100,7 @@ def test_ffront_lap(cartesian_case): ) +@pytest.mark.uses_program_with_sliced_out_arguments def test_ffront_skewedlap(cartesian_case): in_field = cases.allocate(cartesian_case, skewedlap_program, "in_field")() in_field = square(in_field) @@ -114,6 +116,7 @@ def test_ffront_skewedlap(cartesian_case): ) +@pytest.mark.uses_program_with_sliced_out_arguments def test_ffront_laplap(cartesian_case): in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() in_field = square(in_field) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index dd30caa726..7b96da2201 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -219,6 +219,7 @@ def prog_slicing( ) +@pytest.mark.uses_program_with_sliced_out_arguments def test_program_slicing(cartesian_case): a = cases.allocate(cartesian_case, prog, "a")() b = cases.allocate(cartesian_case, prog, "b")() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index cd853bd69c..760575c9ca 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -294,6 +294,7 @@ def test_nabla_sign(program_processor): pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) + vertex_index = gtx.as_field([Vertex], np.arange(setup.nodes_size, dtype=np.int32)) run_processor( nabla_sign, @@ -305,7 +306,7 @@ def test_nabla_sign(program_processor): S_MXX, S_MYY, setup.vol_field, - embedded.index_field(Vertex), + vertex_index, # TODO(havogt): should be an index function field setup.is_pole_edge_field, offset_provider={ "E2V": setup.edges2node_connectivity,