Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ markers = [
'uses_mesh_with_skip_values: tests that use a mesh with skip values',
'uses_concat_where: tests that use the concat_where builtin',
'uses_program_metrics: tests that require backend support for program metrics',
'uses_program_with_sliced_out_arguments: tests that use a sliced argument which is not supported for non-mutable arrays, e.g. JAX',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
]
norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']
Expand Down
74 changes: 67 additions & 7 deletions src/gt4py/next/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Optional, cast
from typing import Optional, Protocol, TypeGuard, cast

import gt4py.eve as eve
import gt4py.eve.extended_typing as xtyping
Expand All @@ -20,13 +20,33 @@
from gt4py._core import definitions as core_defs, ndarray_utils


# TODO take from https://github.com/data-apis/array-api-typing/pull/56/
# only added here for documentation purposes
class NDArrayNamespace(Protocol): ...


def is_array_namespace(obj: object) -> TypeGuard[NDArrayNamespace]:
"""Check whether `obj` is an array namespace.

An array namespace is any module that follows the Array API standard
(https://data-apis.org/array-api/latest/).
"""
return (
hasattr(obj, "empty")
and hasattr(obj, "zeros")
and hasattr(obj, "ones")
and hasattr(obj, "full")
and hasattr(obj, "asarray")
)


@eve.utils.with_fluid_partial
def empty(
domain: common.DomainLike,
dtype: core_defs.DTypeLike = core_defs.Float64DType(()), # noqa: B008 [function-call-in-default-argument]
*,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None,
allocator: Optional[next_allocators.FieldBufferAllocationUtil | NDArrayNamespace] = None,
device: Optional[core_defs.Device] = None,
) -> nd_array_field.NdArrayField:
"""Create a `Field` of uninitialized (undefined) values using the given (or device-default) allocator.
Expand Down Expand Up @@ -77,6 +97,12 @@ def empty(
(3, 3)
"""
dtype = core_defs.dtype(dtype)
if allocator is not None and is_array_namespace(allocator):
domain = common.domain(domain)
shape = domain.shape
return common._field(
allocator.empty(shape, dtype=dtype.scalar_type), domain=domain
) # TODO device
if allocator is None and device is None:
device = core_defs.Device(core_defs.DeviceType.CPU, device_id=0)
buffer = next_allocators.allocate(
Expand All @@ -94,7 +120,7 @@ def zeros(
dtype: core_defs.DTypeLike = core_defs.Float64DType(()), # noqa: B008 [function-call-in-default-argument]
*,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None,
allocator: Optional[next_allocators.FieldBufferAllocationUtil | NDArrayNamespace] = None,
device: Optional[core_defs.Device] = None,
) -> nd_array_field.NdArrayField:
"""Create a Field containing all zeros using the given (or device-default) allocator.
Expand All @@ -108,6 +134,13 @@ def zeros(
>>> gtx.zeros({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray
array([0., 0., 0., 0., 0., 0., 0.])
"""
if allocator is not None and is_array_namespace(allocator):
dtype = core_defs.dtype(dtype) if dtype is not None else None
domain = common.domain(domain)
shape = domain.shape
return common._field(
allocator.zeros(shape, dtype=dtype.scalar_type), domain=domain
) # TODO device
field = empty(
domain=domain, dtype=dtype, aligned_index=aligned_index, allocator=allocator, device=device
)
Expand All @@ -121,7 +154,7 @@ def ones(
dtype: core_defs.DTypeLike = core_defs.Float64DType(()), # noqa: B008 [function-call-in-default-argument]
*,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None,
allocator: Optional[next_allocators.FieldBufferAllocationUtil | NDArrayNamespace] = None,
device: Optional[core_defs.Device] = None,
) -> nd_array_field.NdArrayField:
"""Create a Field containing all ones using the given (or device-default) allocator.
Expand All @@ -135,6 +168,13 @@ def ones(
>>> gtx.ones({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray
array([1., 1., 1., 1., 1., 1., 1.])
"""
if allocator is not None and is_array_namespace(allocator):
dtype = core_defs.dtype(dtype) if dtype is not None else None
domain = common.domain(domain)
shape = domain.shape
return common._field(
allocator.ones(shape, dtype=dtype.scalar_type), domain=domain
) # TODO device
field = empty(
domain=domain, dtype=dtype, aligned_index=aligned_index, allocator=allocator, device=device
)
Expand All @@ -149,7 +189,7 @@ def full(
dtype: Optional[core_defs.DTypeLike] = None,
*,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None,
allocator: Optional[next_allocators.FieldBufferAllocationUtil | NDArrayNamespace] = None,
device: Optional[core_defs.Device] = None,
) -> nd_array_field.NdArrayField:
"""Create a Field where all values are set to `fill_value` using the given (or device-default) allocator.
Expand All @@ -168,6 +208,13 @@ def full(
>>> gtx.full({IDim: 3}, 5, allocator=gtx.itir_python).ndarray
array([5, 5, 5])
"""
if allocator is not None and is_array_namespace(allocator):
dtype = core_defs.dtype(dtype) if dtype is not None else None
domain = common.domain(domain)
shape = domain.shape
return common._field(
allocator.full(shape, fill_value, dtype=dtype.scalar_type), domain=domain
) # TODO device
field = empty(
domain=domain,
dtype=dtype if dtype is not None else core_defs.dtype(type(fill_value)),
Expand All @@ -187,7 +234,7 @@ def as_field(
*,
origin: Optional[Mapping[common.Dimension, int]] = None,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None,
allocator: Optional[next_allocators.FieldBufferAllocationUtil | NDArrayNamespace] = None,
device: Optional[core_defs.Device] = None,
# TODO: copy=False
) -> nd_array_field.NdArrayField:
Expand Down Expand Up @@ -267,6 +314,11 @@ def as_field(
dtype = core_defs.dtype(dtype)
assert dtype.tensor_shape == () # TODO

if allocator is not None and is_array_namespace(allocator):
return common._field(
allocator.asarray(data, dtype=dtype.scalar_type), domain=actual_domain
) # TODO device

if (allocator is None) and (device is None) and xtyping.supports_dlpack(data):
device = core_defs.Device(*data.__dlpack_device__())

Expand All @@ -290,7 +342,7 @@ def as_connectivity(
data: core_defs.NDArrayObject,
dtype: Optional[core_defs.DType] = None,
*,
allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None,
allocator: Optional[next_allocators.FieldBufferAllocationUtil | NDArrayNamespace] = None,
device: Optional[core_defs.Device] = None,
skip_value: core_defs.IntegralScalar | eve.NothingType | None = eve.NOTHING,
# TODO: copy=False
Expand Down Expand Up @@ -349,6 +401,14 @@ def as_connectivity(
dtype = core_defs.dtype(dtype)
assert dtype.tensor_shape == () # TODO

if allocator is not None and is_array_namespace(allocator):
# return common._field(
# allocator.asarray(data, dtype=dtype.scalar_type), domain=actual_domain
# ) # TODO device
return common._connectivity(
data, codomain=codomain, domain=actual_domain, skip_value=skip_value
)

if (allocator is None) and (device is None) and xtyping.supports_dlpack(data):
device = core_defs.Device(*data.__dlpack_device__())
buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device)
Expand Down
17 changes: 15 additions & 2 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
cp: Optional[ModuleType] = None # type: ignore[no-redef]

try:
import jax
from jax import numpy as jnp
except ImportError:
jax: Optional[ModuleType] = None # type: ignore[no-redef]
jnp: Optional[ModuleType] = None # type: ignore[no-redef]

try:
Expand Down Expand Up @@ -1087,7 +1089,10 @@ class CuPyArrayConnectivityField(NdArrayConnectivityField):

# JAX
if jnp:
assert jax is not None

_nd_array_implementations.append(jnp)
jax.config.update("jax_enable_x64", True)

@dataclasses.dataclass(frozen=True, eq=False)
class JaxArrayField(NdArrayField):
Expand All @@ -1102,8 +1107,16 @@ def __setitem__(
index: common.AnyIndexSpec,
value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT,
) -> None:
# TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)`
raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.")
target_domain, target_slice = self._slice(index)

if isinstance(value, common.Field):
if not value.domain == target_domain:
raise ValueError(
f"Incompatible 'Domain' in assignment. Source domain = '{value.domain}', target domain = '{target_domain}'."
)
value = value.ndarray

object.__setattr__(self, "_ndarray", self._ndarray.at[target_slice].set(value))

common._field.register(jnp.ndarray, JaxArrayField.from_array)

Expand Down
9 changes: 8 additions & 1 deletion tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest

from gt4py.next import allocators as next_allocators
from gt4py.next.embedded import nd_array_field


# Skip definitions
Expand Down Expand Up @@ -65,11 +66,13 @@ class EmbeddedDummyBackend:
cupy_execution = EmbeddedDummyBackend(
"EmbeddedCuPy", next_allocators.StandardGPUFieldBufferAllocator()
)
jax_numpy_execution = EmbeddedDummyBackend("EmbeddedJaxNumPy", nd_array_field.jnp)


class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum):
NUMPY_EXECUTION = "next_tests.definitions.numpy_execution"
CUPY_EXECUTION = "next_tests.definitions.cupy_execution"
JAX_NUMPY_EXECUTION = "next_tests.definitions.jax_numpy_execution"


class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):
Expand Down Expand Up @@ -128,6 +131,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_PROGRAM_METRICS = "uses_program_metrics"
USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo"
USES_CONCAT_WHERE = "uses_concat_where"
USES_PROGRAM_WITH_SLICED_OUT_ARGUMENTS = "uses_program_with_sliced_out_arguments"
CHECKS_SPECIFIC_ERROR = "checks_specific_error"

# Skip messages (available format keys: 'marker', 'backend')
Expand All @@ -138,7 +142,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
)
# Common list of feature markers to skip
COMMON_SKIP_TEST_LIST = [
(REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
(USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
(USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE),
Expand Down Expand Up @@ -172,6 +175,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
), # we can't extract the field type from scan args
(USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE),
]
JAX_EMBEDDED_SKIP_LIST = EMBEDDED_SKIP_LIST + [
(USES_PROGRAM_WITH_SLICED_OUT_ARGUMENTS, XFAIL, UNSUPPORTED_MESSAGE),
]
ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [
(USES_PROGRAM_METRICS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE),
Expand Down Expand Up @@ -199,6 +205,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
BACKEND_SKIP_TEST_MATRIX = {
EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST,
EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST,
EmbeddedIds.JAX_NUMPY_EXECUTION: JAX_EMBEDDED_SKIP_LIST,
OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST,
OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST,
OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def __gt_allocator__(
pytest.param(
next_tests.definitions.EmbeddedIds.CUPY_EXECUTION, marks=pytest.mark.requires_gpu
),
pytest.param(
next_tests.definitions.EmbeddedIds.JAX_NUMPY_EXECUTION,
marks=pytest.mark.requires_jax,
),
pytest.param(
next_tests.definitions.OptionalProgramBackendId.DACE_CPU,
marks=pytest.mark.requires_dace,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,9 @@ def testee(inp: IField) -> IField:
out = cases.allocate(
cartesian_case, testee, cases.RETURN, strategy=cases.ConstInitializer(42)
)()
ref = inp.array_ns.zeros(size)
ref = np.zeros(size)
ref[0] = ref[-1] = 42
ref[1:-1] = inp.ndarray[1:-1]
ref[1:-1] = inp.asnumpy()[1:-1]

cases.verify(cartesian_case, testee, inp, out=out, domain={IDim: (1, size - 1)}, ref=ref)

Expand All @@ -323,8 +323,8 @@ def testee(inp: IField) -> tuple[IField, IField]:
out = cases.allocate(
cartesian_case, testee, cases.RETURN, strategy=cases.ConstInitializer(42)
)()
ref = inp.array_ns.zeros(size)
ref = np.zeros(size)
ref[0] = ref[-1] = 42
ref[1:-1] = inp.ndarray[1:-1]
ref[1:-1] = inp.asnumpy()[1:-1]

cases.verify(cartesian_case, testee, inp, out=out, domain={IDim: (1, size - 1)}, ref=(ref, ref))
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def testee(
)


@pytest.mark.uses_program_with_sliced_out_arguments
def test_single_value_field(cartesian_case):
@gtx.field_operator
def testee_fo(a: cases.IKField) -> cases.IKField:
Expand Down Expand Up @@ -1199,6 +1200,7 @@ def program_domain(
)


@pytest.mark.uses_program_with_sliced_out_arguments
def test_domain_tuple(cartesian_case):
@gtx.field_operator
def fieldop_domain_tuple(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_identity_fo_execution(cartesian_case, identity_def):
)


@pytest.mark.uses_program_with_sliced_out_arguments
@pytest.mark.uses_cartesian_shift
def test_shift_by_one_execution(cartesian_case):
@gtx.field_operator
Expand Down Expand Up @@ -95,6 +96,7 @@ def test_double_copy_execution(cartesian_case, double_copy_program_def):
)


@pytest.mark.uses_program_with_sliced_out_arguments
def test_copy_restricted_execution(cartesian_case, copy_restrict_program_def):
copy_restrict_program = gtx.program(copy_restrict_program_def, backend=cartesian_case.backend)

Expand Down Expand Up @@ -154,6 +156,7 @@ def prog(
assert np.allclose((a.asnumpy(), b.asnumpy()), (out_a.asnumpy(), out_b.asnumpy()))


@pytest.mark.uses_program_with_sliced_out_arguments
def test_tuple_program_return_constructed_inside_with_slicing(cartesian_case):
@gtx.field_operator
def pack_tuple(
Expand Down Expand Up @@ -279,12 +282,12 @@ def identity(a: cases.IField) -> cases.IField:
def copy_program(a: cases.IField, out: cases.IField):
identity(a, out=out, domain={IDim: (1, 9)})

inp = constructors.empty(
inp = constructors.full(
common.domain({IDim: (1, 9)}),
42,
dtype=np.int32,
allocator=cartesian_case.allocator,
)
inp.ndarray[...] = 42
out = cases.allocate(cartesian_case, copy_program, "out", sizes={IDim: 10})()
ref = out.asnumpy().copy() # ensure we are not writing to `out` outside the domain
ref[1:9] = inp.asnumpy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@

pytest.importorskip("atlas4py")

import gt4py._core.definitions as core_defs
from gt4py import next as gtx
from gt4py.next import allocators, neighbor_sum
from gt4py.next.iterator import atlas_utils
from gt4py.next import neighbor_sum

from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
exec_alloc_descriptor,
Expand Down Expand Up @@ -65,10 +63,6 @@ def pnabla(

@pytest.mark.requires_atlas
def test_ffront_compute_zavgS(exec_alloc_descriptor):
# TODO(havogt): fix nabla setup to work with GPU
if exec_alloc_descriptor.allocator.device_type != core_defs.DeviceType.CPU:
pytest.skip("This test is only supported on CPU devices yet")

setup = nabla_setup(allocator=exec_alloc_descriptor.allocator)

zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=exec_alloc_descriptor.allocator)
Expand All @@ -88,10 +82,6 @@ def test_ffront_compute_zavgS(exec_alloc_descriptor):

@pytest.mark.requires_atlas
def test_ffront_nabla(exec_alloc_descriptor):
# TODO(havogt): fix nabla setup to work with GPU
if exec_alloc_descriptor.allocator.device_type != core_defs.DeviceType.CPU:
pytest.skip("This test is only supported on CPU devices yet")

setup = nabla_setup(allocator=exec_alloc_descriptor.allocator)

pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=exec_alloc_descriptor.allocator)
Expand Down
Loading
Loading