Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .test-conda-env-py3.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies:
- pyopencl
- python=3
- gmsh
- jax

# test scripts use ompi-specific arguments
- openmpi
Expand Down
2 changes: 2 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
intersphinx_mapping = {
"arraycontext": ("https://documen.tician.de/arraycontext/", None),
"loopy": ("https://documen.tician.de/loopy/", None),
"jax": ("https://docs.jax.dev/en/latest/", None),
"meshmode": ("https://documen.tician.de/meshmode/", None),
"modepy": ("https://documen.tician.de/modepy/", None),
"mpi4py": ("https://mpi4py.readthedocs.io/en/stable", None),
Expand All @@ -33,6 +34,7 @@
os.environ["PYOPENCL_TEST"] = "port:cpu"

nitpick_ignore_regex = [
["py:mod", r"jax"], # FIXME: not sure why this does not work
["py:class", r"np\.ndarray"],
["py:data|py:class", r"arraycontext.*ContainerTc"],
]
Expand Down
35 changes: 34 additions & 1 deletion grudge/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
.. autoclass:: MPIPyOpenCLArrayContext
.. autoclass:: MPINumpyArrayContext
.. class:: MPIPytatoArrayContext
.. autoclass:: MPIEagerJAXArrayContext
.. autofunction:: get_reasonable_array_context_class
"""

Expand Down Expand Up @@ -75,9 +76,10 @@
_HAVE_FUSION_ACTX = False


from arraycontext import ArrayContext, NumpyArrayContext
from arraycontext import ArrayContext, EagerJAXArrayContext, NumpyArrayContext
from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller
from arraycontext.pytest import (
_PytestEagerJaxArrayContextFactory,
_PytestNumpyArrayContextFactory,
_PytestPyOpenCLArrayContextFactoryWithClass,
_PytestPytatoPyOpenCLArrayContextFactory,
Expand Down Expand Up @@ -443,6 +445,26 @@
# }}}


# {{{ distributed + eager jax

class MPIEagerJAXArrayContext(EagerJAXArrayContext, MPIBasedArrayContext):

Check failure on line 450 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Base classes for class "MPIEagerJAXArrayContext" define method "einsum" in incompatible way   Return type mismatch: base method returns type "Array", override returns type "Array"     "Array" is incompatible with protocol "Array"       "arraycontext.typing.Array" is not assignable to "jax._src.basearray.Array"       "arraycontext.typing.Array" is not assignable to "jax._src.basearray.Array"       "arraycontext.typing.Array" is not assignable to "jax._src.basearray.Array"       "__getitem__" is an incompatible type         Type "(key: Unknown) -> Array" is not assignable to type "(index: Any) -> Array"           Parameter name mismatch: "index" versus "key" ... (reportIncompatibleMethodOverride)

Check failure on line 450 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Multiple inheritance is not allowed because the following base classes contain `__init__` or `__new__` methods that may not get called: ArrayContext (reportUnsafeMultipleInheritance)
"""An array context for using distributed computation with :mod:`jax`
eager evaluation.

.. autofunction:: __init__
"""

def __init__(self, mpi_communicator) -> None:

Check warning on line 457 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type annotation is missing for parameter "mpi_communicator" (reportMissingParameterType)

Check warning on line 457 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of parameter "mpi_communicator" is unknown (reportUnknownParameterType)
super().__init__()

self.mpi_communicator = mpi_communicator

Check warning on line 460 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type annotation for attribute `mpi_communicator` is required because this class is not decorated with `@final` (reportUnannotatedClassAttribute)

def clone(self) -> Self:

Check warning on line 462 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Method "clone" is not marked as override but is overriding a method in class "EagerJAXArrayContext" (reportImplicitOverride)
return type(self)(self.mpi_communicator)

# }}}


# {{{ distributed + pytato array context subclasses

class MPIBasePytatoPyOpenCLArrayContext(
Expand Down Expand Up @@ -542,12 +564,23 @@
return self.actx_class()


class PytestEagerJAXArrayContextFactory(_PytestEagerJaxArrayContextFactory):
actx_class = EagerJAXArrayContext

Check warning on line 568 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type annotation for attribute `actx_class` is required because this class is not decorated with `@final` (reportUnannotatedClassAttribute)

def __call__(self):

Check warning on line 570 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Method "__call__" is not marked as override but is overriding a method in class "_PytestEagerJaxArrayContextFactory" (reportImplicitOverride)
import jax
jax.config.update("jax_enable_x64", True)

Check warning on line 572 in grudge/array_context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "update" is partially unknown   Type of "update" is "(name: Unknown, val: Unknown) -> None" (reportUnknownMemberType)
return self.actx_class()


register_pytest_array_context_factory("grudge.pyopencl",
PytestPyOpenCLArrayContextFactory)
register_pytest_array_context_factory("grudge.pytato-pyopencl",
PytestPytatoPyOpenCLArrayContextFactory)
register_pytest_array_context_factory("grudge.numpy",
PytestNumpyArrayContextFactory)
register_pytest_array_context_factory("grudge.eager-jax",
PytestEagerJAXArrayContextFactory)

# }}}

Expand Down
22 changes: 12 additions & 10 deletions grudge/geometry/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,22 +597,24 @@
dd_base.untrace(), dd_base
)
assert isinstance(all_faces_conn, DirectDiscretizationConnection)
signed_ones = dcoll.discr_from_dd(dd.with_discr_tag(DISCR_TAG_BASE)).zeros(
actx, dtype=dcoll.real_dtype
) + 1

signed_face_ones_numpy = actx.to_numpy(signed_ones)
discr = dcoll.discr_from_dd(dd.with_discr_tag(DISCR_TAG_BASE))

new_group_arrays = []

for dgrp, grp in zip(discr.groups, all_faces_conn.groups, strict=True):
sign = np.ones((dgrp.nelements, dgrp.nunit_dofs),
dtype=discr.real_dtype)

for igrp, grp in enumerate(all_faces_conn.groups):
for batch in grp.batches:
assert batch.to_element_face is not None
i = actx.to_numpy(actx.thaw(batch.to_element_indices))
grp_field = signed_face_ones_numpy[igrp].reshape(-1)
grp_field[i] = ( # pyright: ignore[reportIndexIssue]
(2.0 * (batch.to_element_face % 2) - 1.0) * grp_field[i]
)
sign[i, :] = 2.0 * (batch.to_element_face % 2) - 1.0

new_group_arrays.append(sign)

Check warning on line 614 in grudge/geometry/metrics.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "append" is partially unknown   Type of "append" is "(object: Unknown, /) -> None" (reportUnknownMemberType)

return actx.from_numpy(signed_face_ones_numpy)
from meshmode.dof_array import DOFArray
return actx.from_numpy(DOFArray(actx, tuple(new_group_arrays)))

Check warning on line 617 in grudge/geometry/metrics.py

View workflow job for this annotation

GitHub Actions / basedpyright

Argument type is partially unknown   Argument corresponds to parameter "iterable" in function "__new__"   Argument type is "list[Unknown]" (reportUnknownArgumentType)

Check warning on line 617 in grudge/geometry/metrics.py

View workflow job for this annotation

GitHub Actions / basedpyright

Argument type is partially unknown   Argument corresponds to parameter "data" in function "__init__"   Argument type is "tuple[Unknown, ...]" (reportUnknownArgumentType)


def parametrization_derivative(
Expand Down
4 changes: 3 additions & 1 deletion test/test_dt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from arraycontext import ArrayContextFactory, pytest_generate_tests_for_array_contexts

from grudge.array_context import (
PytestEagerJAXArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
Expand All @@ -40,7 +41,8 @@
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory])
PytestNumpyArrayContextFactory,
PytestEagerJAXArrayContextFactory])

import logging

Expand Down
10 changes: 8 additions & 2 deletions test/test_euler_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,18 @@
)

from grudge import op
from grudge.array_context import PytestPyOpenCLArrayContextFactory
from grudge.array_context import (
PytestEagerJAXArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
)


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
[PytestPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestEagerJAXArrayContextFactory])


@pytest.mark.parametrize("order", [1, 2, 3])
Expand Down
10 changes: 8 additions & 2 deletions test/test_grudge.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,19 @@
from meshmode.mesh import TensorProductElementGroup

from grudge import dof_desc, geometry, op
from grudge.array_context import PytestPyOpenCLArrayContextFactory
from grudge.array_context import (
PytestEagerJAXArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
)
from grudge.discretization import make_discretization_collection


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
[PytestPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestEagerJAXArrayContextFactory])


# {{{ mass operator trig integration
Expand Down
4 changes: 3 additions & 1 deletion test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from meshmode.dof_array import flat_norm

from grudge.array_context import (
PytestEagerJAXArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
Expand All @@ -47,7 +48,8 @@
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory,
PytestPytatoPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory])
PytestNumpyArrayContextFactory,
PytestEagerJAXArrayContextFactory])


# {{{ inverse metric
Expand Down
10 changes: 8 additions & 2 deletions test/test_modal_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,18 @@

from arraycontext import ArrayContextFactory, pytest_generate_tests_for_array_contexts

from grudge.array_context import PytestPyOpenCLArrayContextFactory
from grudge.array_context import (
PytestEagerJAXArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
)
from grudge.discretization import make_discretization_collection


pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
[PytestPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestEagerJAXArrayContextFactory])

import pytest

Expand Down
14 changes: 12 additions & 2 deletions test/test_mpi_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
from meshmode.dof_array import flat_norm

from grudge import dof_desc, op
from grudge.array_context import MPIPyOpenCLArrayContext, MPIPytatoArrayContext
from grudge.array_context import (
MPIEagerJAXArrayContext,
MPINumpyArrayContext,
MPIPyOpenCLArrayContext,
MPIPytatoArrayContext,
)
from grudge.discretization import make_discretization_collection
from grudge.shortcuts import compiled_lsrk45_step

Expand All @@ -52,7 +57,8 @@ class SimpleTag:

# {{{ mpi test infrastructure

DISTRIBUTED_ACTXS = [MPIPyOpenCLArrayContext, MPIPytatoArrayContext]
DISTRIBUTED_ACTXS = [MPIPyOpenCLArrayContext, MPIPytatoArrayContext,
MPIEagerJAXArrayContext, MPINumpyArrayContext]


def run_test_with_mpi(num_ranks, f, *args):
Expand Down Expand Up @@ -90,6 +96,10 @@ def run_test_with_mpi_inner():
actx = actx_class(comm, queue, mpi_base_tag=15000)
elif actx_class is MPIPyOpenCLArrayContext:
actx = actx_class(comm, queue)
elif actx_class is MPIEagerJAXArrayContext:
actx = actx_class(comm)
elif actx_class is MPINumpyArrayContext:
actx = actx_class(comm)
else:
raise ValueError("unknown actx_class")

Expand Down
10 changes: 8 additions & 2 deletions test/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@
from meshmode.mesh import BTAG_ALL

from grudge import geometry, op
from grudge.array_context import PytestPyOpenCLArrayContextFactory
from grudge.array_context import (
PytestEagerJAXArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
)
from grudge.discretization import make_discretization_collection
from grudge.dof_desc import (
DISCR_TAG_BASE,
Expand All @@ -55,7 +59,9 @@

logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
[PytestPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestEagerJAXArrayContextFactory])


# {{{ gradient
Expand Down
10 changes: 8 additions & 2 deletions test/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,19 @@
from meshmode.dof_array import DOFArray

from grudge import op
from grudge.array_context import PytestPyOpenCLArrayContextFactory
from grudge.array_context import (
PytestEagerJAXArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
)
from grudge.discretization import make_discretization_collection


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
[PytestPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestEagerJAXArrayContextFactory])


@pytest.mark.parametrize(("mesh_size", "with_initial"), [
Expand Down Expand Up @@ -169,7 +175,7 @@

# {{{ Array container tests

@with_container_arithmetic(bcasts_across_obj_array=False,

Check warning on line 178 in test/test_reductions.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3

Broadcasting array context array types across <class 'test_reductions.MyContainer'> has been implicitly enabled. As of 2026, this will no longer work. Use arraycontext.Bcast* object wrappers for roughly equivalent functionality. See the discussion in https://github.com/inducer/arraycontext/pull/190. To opt out now (and avoid this warning), pass _bcast_actx_array_type=False.

Check warning on line 178 in test/test_reductions.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3

Broadcasting array context array types across <class 'test_reductions.MyContainer'> has been implicitly enabled. As of 2026, this will no longer work. Use arraycontext.Bcast* object wrappers for roughly equivalent functionality. See the discussion in https://github.com/inducer/arraycontext/pull/190. To opt out now (and avoid this warning), pass _bcast_actx_array_type=False.
eq_comparison=False,
rel_comparison=False,
_cls_has_array_context_attr=True,
Expand Down
4 changes: 0 additions & 4 deletions test/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,12 @@
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])

import logging

import pytest

import pytools.obj_array as obj_array


logger = logging.getLogger(__name__)


# {{{ map_subarrays and rec_map_subarrays

@dataclass(frozen=True, eq=True)
Expand Down
10 changes: 8 additions & 2 deletions test/test_trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,20 @@
from arraycontext import ArrayContextFactory, pytest_generate_tests_for_array_contexts
from meshmode.dof_array import DOFArray

from grudge.array_context import PytestPyOpenCLArrayContextFactory
from grudge.array_context import (
PytestEagerJAXArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
)
from grudge.discretization import make_discretization_collection
from grudge.trace_pair import TracePair


logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])
[PytestPyOpenCLArrayContextFactory,
PytestNumpyArrayContextFactory,
PytestEagerJAXArrayContextFactory])


def test_trace_pair(actx_factory: ArrayContextFactory):
Expand Down
Loading