diff --git a/pyproject.toml b/pyproject.toml index 70936ed2..9d3e7a02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,7 @@ filterwarnings = [ "ignore:ast\\.Str is deprecated:DeprecationWarning", "ignore:numpy\\.ndarray size changed:RuntimeWarning", "ignore:Passing arguments 'a':DeprecationWarning", # TODO: from diffrax + "ignore:jax\\.core\\.pp_eqn_rules is deprecated:DeprecationWarning", ] log_cli_level = "INFO" markers = [ diff --git a/src/galax/_interop/galax_interop_gala/potential.py b/src/galax/_interop/galax_interop_gala/potential.py index 6e20fba4..253feada 100644 --- a/src/galax/_interop/galax_interop_gala/potential.py +++ b/src/galax/_interop/galax_interop_gala/potential.py @@ -6,6 +6,7 @@ import equinox as eqx import gala.potential as gp +import jax.numpy as jnp from astropy.units import Quantity as APYQuantity from gala.units import ( DimensionlessUnitSystem as GalaDimensionlessUnitSystem, @@ -1170,10 +1171,144 @@ def galax_to_gala(pot: gpx.LMJ09LogarithmicPotential, /) -> gp.LogarithmicPotent # ----------------------------------------------------------------------------- -# NFW potentials +# Multipole potentials @dispatch # type: ignore[misc] +def gala_to_galax( + gala: gp.MultipolePotential, / +) -> gpx.MultipoleInnerPotential | gpx.MultipoleOuterPotential | gpx.PotentialFrame: + params = gala.parameters + cls = ( + gpx.MultipoleInnerPotential + if params["inner"] == 1 + else gpx.MultipoleOuterPotential + ) + + l_max = gala._lmax # noqa: SLF001 + Slm = jnp.zeros((l_max + 1, l_max + 1), dtype=float) + Tlm = jnp.zeros_like(Slm) + + for l, m in zip(*jnp.tril_indices(l_max + 1), strict=True): + skey = f"S{l}{m}" + if skey in params: + Slm = Slm.at[l, m].set(params[skey]) + + tkey = f"T{l}{m}" + if tkey in params: + Tlm = Tlm.at[l, m].set(params[tkey]) + + pot = cls( + m_tot=params["m"], + r_s=params["r_s"], + l_max=l_max, + Slm=Slm, + Tlm=Tlm, + units=gala.units, + ) + return _apply_frame(_get_frame(gala), pot) + + +@dispatch.multi((gpx.MultipoleInnerPotential,), (gpx.MultipoleOuterPotential,)) # type: ignore[misc] +def galax_to_gala( + pot: gpx.MultipoleInnerPotential | gpx.MultipoleOuterPotential, / +) -> gp.MultipolePotential: + """Convert a Galax Multipole to a Gala potential.""" + _error_if_not_all_constant_parameters(pot, "m_tot", "r_s", "Slm", "Tlm") + + Slm, Tlm = pot.Slm(0).value, pot.Tlm(0).value + ls, ms = jnp.tril_indices(pot.l_max + 1) + + return gp.MultipolePotential( + m=convert(pot.m_tot(0), APYQuantity), + r_s=convert(pot.r_s(0), APYQuantity), + lmax=pot.l_max, + **{ + f"S{l}{m}": Slm[l, m] for l, m in zip(ls, ms, strict=True) if Slm[l, m] != 0 + }, + **{ + f"T{l}{m}": Tlm[l, m] for l, m in zip(ls, ms, strict=True) if Tlm[l, m] != 0 + }, + inner=isinstance(pot, gpx.MultipoleInnerPotential), + units=_galax_to_gala_units(pot.units), + ) + + +# ----------------------------------------------------------------------------- +# NFW potentials + + +@dispatch +def gala_to_galax(gala: gp.NFWPotential, /) -> gpx.NFWPotential | gpx.PotentialFrame: + """Convert a Gala NFWPotential to a Galax potential. + + Examples + -------- + >>> import gala.potential as gp + >>> import gala.units as gu + >>> import galax.potential as gpx + + >>> gpot = gp.NFWPotential(m=1e12, r_s=20, units=gu.galactic) + >>> gpx.io.convert_potential(gpx.io.GalaxLibrary, gpot) + NFWPotential( + units=UnitSystem(kpc, Myr, solMass, rad), + constants=ImmutableMap({'G': ...}), + m=ConstantParameter( ... ), + r_s=ConstantParameter( ... ) + ) + + """ + params = gala.parameters + pot = gpx.NFWPotential(m=params["m"], r_s=params["r_s"], units=gala.units) + return _apply_frame(_get_frame(gala), pot) + + +@dispatch +def gala_to_galax( + pot: gp.LeeSutoTriaxialNFWPotential, / +) -> gpx.LeeSutoTriaxialNFWPotential: + """Convert a :class:`gala.potential.LeeSutoTriaxialNFWPotential` to a :class:`galax.potential.LeeSutoTriaxialNFWPotential`. + + Examples + -------- + >>> import gala.potential as gp + >>> import gala.units as gu + >>> import galax.potential as gpx + + >>> gpot = gp.LeeSutoTriaxialNFWPotential( + ... v_c=220, r_s=20, a=1, b=0.9, c=0.8, units=gu.galactic ) + >>> gpx.io.convert_potential(gpx.io.GalaxLibrary, gpot) + LeeSutoTriaxialNFWPotential( + units=UnitSystem(kpc, Myr, solMass, rad), + constants=ImmutableMap({'G': ...}), + m=ConstantParameter( ... ), + r_s=ConstantParameter( ... ), + a1=ConstantParameter( ... ), + a2=ConstantParameter( ... ), + a3=ConstantParameter( ... ) + ) + + """ # noqa: E501 + units = pot.units + params = pot.parameters + G = Quantity(pot.G, units["length"] ** 3 / units["time"] ** 2 / units["mass"]) + + return gpx.LeeSutoTriaxialNFWPotential( + m=params["v_c"] ** 2 * params["r_s"] / G, + r_s=params["r_s"], + a1=params["a"], + a2=params["b"], + a3=params["c"], + units=units, + constants={"G": G}, + ) + + +# ----------------------------------------------------------------------------- +# NFW potentials + + +@dispatch def gala_to_galax(gala: gp.NFWPotential, /) -> gpx.NFWPotential | gpx.PotentialFrame: """Convert a Gala NFWPotential to a Galax potential. diff --git a/src/galax/potential/__init__.py b/src/galax/potential/__init__.py index 98523acb..f9098327 100644 --- a/src/galax/potential/__init__.py +++ b/src/galax/potential/__init__.py @@ -31,6 +31,11 @@ # logarithmic "LogarithmicPotential", "LMJ09LogarithmicPotential", + # multipole + "AbstractMultipolePotential", + "MultipoleInnerPotential", + "MultipoleOuterPotential", + "MultipolePotential", # nfw "NFWPotential", "LeeSutoTriaxialNFWPotential", @@ -79,6 +84,12 @@ LMJ09LogarithmicPotential, LogarithmicPotential, ) + from ._potential.builtin.multipole import ( + AbstractMultipolePotential, + MultipoleInnerPotential, + MultipoleOuterPotential, + MultipolePotential, + ) from ._potential.builtin.nfw import ( LeeSutoTriaxialNFWPotential, NFWPotential, diff --git a/src/galax/potential/_potential/builtin/__init__.py b/src/galax/potential/_potential/builtin/__init__.py index 72a8c438..086a149c 100644 --- a/src/galax/potential/_potential/builtin/__init__.py +++ b/src/galax/potential/_potential/builtin/__init__.py @@ -1,10 +1,11 @@ """``galax`` Potentials.""" # ruff:noqa: F401 -from . import bars, builtin, logarithmic, nfw, special +from . import bars, builtin, logarithmic, multipole, nfw, special from .bars import * from .builtin import * from .logarithmic import * +from .multipole import * from .nfw import * from .special import * @@ -12,5 +13,6 @@ __all__ += builtin.__all__ __all__ += bars.__all__ __all__ += logarithmic.__all__ +__all__ += multipole.__all__ __all__ += nfw.__all__ __all__ += special.__all__ diff --git a/src/galax/potential/_potential/builtin/multipole.py b/src/galax/potential/_potential/builtin/multipole.py new file mode 100644 index 00000000..65372336 --- /dev/null +++ b/src/galax/potential/_potential/builtin/multipole.py @@ -0,0 +1,279 @@ +"""Multipole potential.""" + +__all__ = [ + "AbstractMultipolePotential", + "MultipoleInnerPotential", + "MultipoleOuterPotential", + "MultipolePotential", +] + +from dataclasses import KW_ONLY +from functools import partial +from typing import final + +import jax +from equinox import field +from jax.scipy.special import sph_harm +from jaxtyping import Array, Float + +import quaxed.array_api as xp +import quaxed.numpy as jnp +from unxt import Quantity + +import galax.typing as gt +from galax.potential._potential.core import AbstractPotential +from galax.potential._potential.params.core import AbstractParameter +from galax.potential._potential.params.field import ParameterField + + +class AbstractMultipolePotential(AbstractPotential): + """Abstract Multipole Potential.""" + + m_tot: AbstractParameter = ParameterField(dimensions="mass") # type: ignore[assignment] + """Total mass of the multipole potential.""" + + r_s: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment] + """Scale radius.""" + + _: KW_ONLY + l_max: int = field(static=True) + + +@final +class MultipoleInnerPotential(AbstractMultipolePotential): + r"""Multipole inner expansion potential. + + .. math:: + + \Phi^l_\mathrm{max}(r,\theta,\phi) = + \sum_{l=0}^{l=l_\mathrm{max}}\sum_{m=0}^{m=l} + r^l \, (S_{lm} \, \cos{m\,\phi} + T_{lm} \, \sin{m\,\phi}) + \, P_l^m(\cos\theta) + + """ + + Slm: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + r"""Spherical harmonic coefficients for the $\cos(m \phi)$ terms.""" + + Tlm: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + r"""Spherical harmonic coefficients for the $\sin(m \phi)$ terms.""" + + def __check_init__(self) -> None: + shape = (self.l_max + 1, self.l_max + 1) + t = Quantity(0.0, "Gyr") + s_shape, t_shape = self.Slm(t).shape, self.Tlm(t).shape + # TODO: check shape across time. + if s_shape != shape or t_shape != shape: + msg = ( + "Slm and Tlm must have the shape (l_max + 1, l_max + 1)." + f"Slm shape: {s_shape}, Tlm shape: {t_shape}" + ) + raise ValueError(msg) + + @partial(jax.jit, inline=True) + def _potential( + self, q: gt.BatchQVec3, t: gt.BatchableRealQScalar, / + ) -> gt.BatchFloatQScalar: + # Compute the parameters + m_tot, r_s = self.m_tot(t), self.r_s(t) + Slm, Tlm = self.Slm(t).value, self.Tlm(t).value + + # spherical coordinates + is_scalar = q.ndim == 1 + s, theta, phi = cartesian_to_normalized_spherical(jnp.atleast_2d(q), r_s) + + # Compute the summation over l and m + l_max = self.l_max + ls, ms = jnp.tril_indices(l_max + 1) + + # TODO: vectorize compute_Ylm over l, m, then don't need a vmap? + def summand(l: int, m: int) -> Float[Array, "*batch"]: + cPlm, sPlm = compute_Ylm(l, m, theta, phi, l_max=l_max) + return xp.pow(s, l) * (Slm[l, m] * cPlm + Tlm[l, m] * sPlm) + + summation = xp.sum(jax.vmap(summand, in_axes=(0, 0))(ls, ms), axis=0) + if is_scalar: + summation = summation[0] + + return self.constants["G"] * m_tot / r_s * summation + + +@final +class MultipoleOuterPotential(AbstractMultipolePotential): + r"""Multipole outer expansion potential. + + .. math:: + + \Phi^l_\mathrm{max}(r,\theta,\phi) = + \sum_{l=0}^{l=l_\mathrm{max}}\sum_{m=0}^{m=l} + r^{-(l+1)} \, (S_{lm} \, \cos{m\,\phi} + T_{lm} \, \sin{m\,\phi}) + \, P_l^m(\cos\theta) + + """ + + Slm: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + r"""Spherical harmonic coefficients for the $\cos(m \phi)$ terms.""" + + Tlm: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + r"""Spherical harmonic coefficients for the $\sin(m \phi)$ terms.""" + + def __check_init__(self) -> None: + shape = (self.l_max + 1, self.l_max + 1) + t = Quantity(0.0, "Gyr") + s_shape, t_shape = self.Slm(t).shape, self.Tlm(t).shape + # TODO: check shape across time. + if s_shape != shape or t_shape != shape: + msg = ( + "Slm and Tlm must have the shape (l_max + 1, l_max + 1)." + f"Slm shape: {s_shape}, Tlm shape: {t_shape}" + ) + raise ValueError(msg) + + @partial(jax.jit, inline=True) + def _potential( + self, q: gt.BatchQVec3, t: gt.BatchableRealQScalar, / + ) -> gt.BatchFloatQScalar: + # Compute the parameters + m_tot, r_s = self.m_tot(t), self.r_s(t) + Slm, Tlm = self.Slm(t).value, self.Tlm(t).value + + # spherical coordinates + is_scalar = q.ndim == 1 + s, theta, phi = cartesian_to_normalized_spherical(jnp.atleast_2d(q), r_s) + + # Compute the summation over l and m + l_max = self.l_max + ls, ms = jnp.tril_indices(l_max + 1) + + # TODO: vectorize compute_Ylm over l, m, then don't need a vmap? + def summand(l: int, m: int) -> Float[Array, "*batch"]: + cPlm, sPlm = compute_Ylm(l, m, theta, phi, l_max=l_max) + return xp.pow(s, -(l + 1)) * (Slm[l, m] * cPlm + Tlm[l, m] * sPlm) + + summation = xp.sum(jax.vmap(summand, in_axes=(0, 0))(ls, ms), axis=0) + if is_scalar: + summation = summation[0] + + return self.constants["G"] * m_tot / r_s * summation + + +@final +class MultipolePotential(AbstractMultipolePotential): + r"""Multipole inner and outer expansion potential. + + .. math:: + + \Phi^l_\mathrm{max}(r,\theta,\phi) = + \sum_{l=0}^{l=l_\mathrm{max}}\sum_{m=0}^{m=l} + [ (r^l IS_{lm} + r^{-(l+1)} OS_{lm}) \, \cos{m\,\phi} + + (r^l IT_{lm} + r^{-(l+1)} OT_{lm}) \, \sin{m\,\phi}] + \, P_l^m(\cos\theta) + + """ + + ISlm: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + r"""Inner spherical harmonic coefficients for the $\cos(m \phi)$ terms.""" + + ITlm: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + r"""Inner Spherical harmonic coefficients for the $\sin(m \phi)$ terms.""" + + OSlm: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + r"""Outer spherical harmonic coefficients for the $\cos(m \phi)$ terms.""" + + OTlm: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + r"""Outer Spherical harmonic coefficients for the $\sin(m \phi)$ terms.""" + + def __check_init__(self) -> None: + shape = (self.l_max + 1, self.l_max + 1) + t = Quantity(0.0, "Gyr") + is_shape, it_shape = self.ISlm(t).shape, self.ITlm(t).shape + os_shape, ot_shape = self.OSlm(t).shape, self.OTlm(t).shape + # TODO: check shape across time. + if ( + is_shape != shape + or it_shape != shape + or os_shape != shape + or ot_shape != shape + ): + msg = "I/OSlm and I/OTlm must have the shape (l_max + 1, l_max + 1)." + raise ValueError(msg) + + @partial(jax.jit, inline=True) + def _potential( + self, q: gt.BatchQVec3, t: gt.BatchableRealQScalar, / + ) -> gt.BatchFloatQScalar: + # Compute the parameters + m_tot, r_s = self.m_tot(t), self.r_s(t) + ISlm, ITlm = self.ISlm(t).value, self.ITlm(t).value + OSlm, OTlm = self.OSlm(t).value, self.OTlm(t).value + + # spherical coordinates + is_scalar = q.ndim == 1 + s, theta, phi = cartesian_to_normalized_spherical(jnp.atleast_2d(q), r_s) + + # Compute the summation over l and m + l_max = self.l_max + ls, ms = jnp.tril_indices(l_max + 1) + + # TODO: vectorize compute_Ylm over l, m, then don't need a vmap? + def summand(l: int, m: int) -> Float[Array, "*batch"]: + cPlm, sPlm = compute_Ylm(l, m, theta, phi, l_max=l_max) + inner = xp.pow(s, l) * (ISlm[l, m] * cPlm + ITlm[l, m] * sPlm) + outer = xp.pow(s, -l - 1) * (OSlm[l, m] * cPlm + OTlm[l, m] * sPlm) + return inner + outer + + summation = xp.sum(jax.vmap(summand, in_axes=(0, 0))(ls, ms), axis=0) + if is_scalar: + summation = summation[0] + + return self.constants["G"] * m_tot / r_s * summation + + +# ===== Helper functions ===== + + +def cartesian_to_normalized_spherical( + q: gt.BatchQVec3, r_s: Quantity, / +) -> tuple[gt.BatchFloatScalar, gt.BatchFloatScalar, Quantity]: + r"""Convert Cartesian coordinates to normalized spherical coordinates. + + .. math:: + + r = \sqrt{x^2 + y^2 + z^2} + X = \cos(\theta) = z / r + \phi = \tan^{-1}\left(\frac{y}{x}\right) + + Parameters + ---------- + q : Array[float, (*batch, 3), "length"] + Cartesian coordinates. + + Returns + ------- + s : Array[float, (*batch,)] + Normalized radius. + theta : Array[float, (*batch,)] + theta angle. + phi : Quantity[float, (*batch,), "angle"] + phi angle. + """ + r = xp.linalg.vector_norm(q, axis=-1) + s = r / r_s + theta = xp.acos(q[..., 2] / r).to_value("rad") # theta + phi = xp.atan2(q[..., 1], q[..., 0]).to_value("rad") # atan(y/x) + + # Return, converting Quantity["dimensionless"] -> Array + return s.value, theta, phi + + +# TODO: vectorize such that it's signature="(l),(l),(N),(N)->(l, N)": +def compute_Ylm( + l: int, + m: int, + theta: Float[Array, "*batch"], + phi: Float[Array, "*batch"], + *, + l_max: int, +) -> tuple[Float[Array, "*batch"], Float[Array, "*batch"]]: + Ylm = sph_harm(jnp.atleast_1d(m), jnp.atleast_1d(l), phi, theta, n_max=l_max) + return Ylm.real, Ylm.imag diff --git a/src/galax/potential/_potential/params/core.py b/src/galax/potential/_potential/params/core.py index a771c0d6..8f55ba41 100644 --- a/src/galax/potential/_potential/params/core.py +++ b/src/galax/potential/_potential/params/core.py @@ -17,10 +17,10 @@ import equinox as eqx import jax -import quaxed.array_api as xp from unxt import Quantity from galax.typing import BatchableRealQScalar, FloatQAnyShape, Unit +from galax.utils._shape import expand_batch_dims if TYPE_CHECKING: from typing import Self @@ -97,6 +97,8 @@ class ConstantParameter(AbstractParameter): value: FloatQAnyShape = eqx.field( converter=lambda x: Quantity.constructor(x, dtype=float) ) + """The time-independent value of the parameter.""" + _: KW_ONLY unit: Unit = eqx.field(static=True, converter=u.Unit) @@ -126,7 +128,7 @@ def __call__(self, t: BatchableRealQScalar = t0, **_: Any) -> FloatQAnyShape: Array[float, "*shape"] The constant parameter value. """ - return xp.broadcast_to(self.value, t.shape) + return expand_batch_dims(self.value, t.ndim) # ------------------------------------------- diff --git a/tests/unit/potential/builtin/multipole/__init__.py b/tests/unit/potential/builtin/multipole/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/potential/builtin/multipole/test_abstractmultipole.py b/tests/unit/potential/builtin/multipole/test_abstractmultipole.py new file mode 100644 index 00000000..7c444852 --- /dev/null +++ b/tests/unit/potential/builtin/multipole/test_abstractmultipole.py @@ -0,0 +1,113 @@ +"""Test AbstractMultipolePotential.""" + +import jax.numpy as jnp +import pytest +from jaxtyping import Array, Shaped + +import quaxed.numpy as qnp +from unxt import Quantity + +import galax.potential as gp +from ...param.test_field import ParameterFieldMixin + + +class ParameterAngularCoefficientsMixin(ParameterFieldMixin): + """Test the angular coefficients.""" + + @pytest.fixture(scope="class") + def field_l_max(self) -> int: + """l_max static field.""" + return 3 + + +class ParameterSlmMixin(ParameterAngularCoefficientsMixin): + """Test the Slm parameter.""" + + pot_cls: type[gp.AbstractPotential] + + @pytest.fixture(scope="class") + def field_Slm(self, field_l_max) -> Shaped[Array, "3 3"]: + """Slm parameter.""" + Slm = jnp.zeros((field_l_max + 1, field_l_max + 1)) + Slm = Slm.at[1, 0].set(5.0) + return Slm # noqa: RET504 + + # ===================================================== + + def test_Slm_units(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + Slm = jnp.zeros((l_max + 1, l_max + 1)) + Slm = Slm.at[1, :].set(5.0) + + fields["Slm"] = Quantity(Slm, "") + pot = pot_cls(**fields) + assert isinstance(pot.Slm, gp.params.ConstantParameter) + assert qnp.allclose(pot.Slm.value, Quantity(Slm, "")) + + def test_Slm_constant(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + Slm = jnp.zeros((l_max + 1, l_max + 1)) + Slm = Slm.at[1, 0].set(5.0) + + fields["Slm"] = Slm + pot = pot_cls(**fields) + assert qnp.allclose(pot.Slm(t=Quantity(0, "Myr")), Slm) + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_Slm_userfunc(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + Slm = jnp.zeros((l_max + 1, l_max + 1)) + Slm = Slm.at[1, 0].set(5.0) + + fields["Slm"] = lambda t: Slm * qnp.exp(-qnp.abs(t)) + pot = pot_cls(**fields) + assert qnp.allclose(pot.Slm(t=Quantity(0, "Myr")), Slm) + + +class ParameterTlmMixin(ParameterAngularCoefficientsMixin): + """Test the Tlm parameter.""" + + pot_cls: type[gp.AbstractPotential] + + @pytest.fixture(scope="class") + def field_Tlm(self, field_l_max) -> Shaped[Array, "3 3"]: + """Tlm parameter.""" + return jnp.zeros((field_l_max + 1, field_l_max + 1)) + + # ===================================================== + + def test_Tlm_units(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + Tlm = jnp.zeros((l_max + 1, l_max + 1)) + Tlm = Tlm.at[1, :].set(5.0) + + fields["Tlm"] = Quantity(Tlm, "") + fields["l_max"] = l_max + pot = pot_cls(**fields) + assert isinstance(pot.Tlm, gp.params.ConstantParameter) + assert qnp.allclose(pot.Tlm.value, Quantity(Tlm, "")) + + def test_Tlm_constant(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + Tlm = jnp.zeros((l_max + 1, l_max + 1)) + Tlm = Tlm.at[1, 0].set(5.0) + + fields["Tlm"] = Tlm + pot = pot_cls(**fields) + assert qnp.allclose(pot.Tlm(t=Quantity(0, "Myr")), Tlm) + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_Tlm_userfunc(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + Tlm = jnp.zeros((l_max + 1, l_max + 1)) + Tlm = Tlm.at[1, :].set(5.0) + + fields["Tlm"] = lambda t: Tlm * qnp.exp(-qnp.abs(t)) + pot = pot_cls(**fields) + assert qnp.allclose(pot.Tlm(t=Quantity(0, "Myr")), Tlm) diff --git a/tests/unit/potential/builtin/multipole/test_innermultipole.py b/tests/unit/potential/builtin/multipole/test_innermultipole.py new file mode 100644 index 00000000..eebf3cc8 --- /dev/null +++ b/tests/unit/potential/builtin/multipole/test_innermultipole.py @@ -0,0 +1,144 @@ +"""Test the `MultipoleInnerPotential` class.""" + +from typing import Any + +import astropy.units as u +import pytest +from jaxtyping import Array, Shaped +from plum import convert +from typing_extensions import override + +import quaxed.numpy as qnp +from unxt import Quantity +from unxt.unitsystems import AbstractUnitSystem + +import galax.potential as gp +import galax.typing as gt +from ...test_core import AbstractPotential_Test +from ..test_common import ParameterMTotMixin, ParameterScaleRadiusMixin +from .test_abstractmultipole import ParameterSlmMixin, ParameterTlmMixin +from galax.utils._optional_deps import GSL_ENABLED, HAS_GALA + +############################################################################### + + +class TestMultipoleInnerPotential( + AbstractPotential_Test, + # Parameters + ParameterMTotMixin, + ParameterScaleRadiusMixin, + ParameterSlmMixin, + ParameterTlmMixin, +): + @pytest.fixture(scope="class") + @override + def pot_cls(self) -> type[gp.MultipoleInnerPotential]: + return gp.MultipoleInnerPotential + + @pytest.fixture(scope="class") + @override + def fields_( + self, + field_m_tot: u.Quantity, + field_r_s: u.Quantity, + field_l_max: int, + field_Slm: Shaped[Array, "3 3"], + field_Tlm: Shaped[Array, "3 3"], + field_units: AbstractUnitSystem, + ) -> dict[str, Any]: + return { + "m_tot": field_m_tot, + "r_s": field_r_s, + "l_max": field_l_max, + "Slm": field_Slm, + "Tlm": field_Tlm, + "units": field_units, + } + + # ========================================================================== + + def test_check_init( + self, pot_cls: type[gp.MultipoleInnerPotential], fields_: dict[str, Any] + ) -> None: + """Test the `MultipoleInnerPotential.__check_init__` method.""" + fields_["Slm"] = fields_["Slm"][::2] # make it the wrong shape + with pytest.raises(ValueError, match="Slm and Tlm must have the shape"): + pot_cls(**fields_) + + # ========================================================================== + + def test_potential(self, pot: gp.MultipoleInnerPotential, x: gt.QVec3) -> None: + expect = Quantity(32.96969177, unit="kpc2 / Myr2") + assert qnp.isclose( + pot.potential(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_gradient(self, pot: gp.MultipoleInnerPotential, x: gt.QVec3) -> None: + expect = Quantity( + [4.74751335e-16, 9.49502670e-16, 10.9898973], pot.units["acceleration"] + ) + got = convert(pot.gradient(x, t=0), Quantity) + assert qnp.allclose(got, expect, atol=Quantity(1e-8, expect.unit)) + + def test_density(self, pot: gp.MultipoleInnerPotential, x: gt.QVec3) -> None: + expect = Quantity(2.89194575e-05, unit="solMass / kpc3") + assert qnp.isclose( + pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_hessian(self, pot: gp.MultipoleInnerPotential, x: gt.QVec3) -> None: + expect = Quantity( + [ + [3.81496608e-16, -1.86509453e-16, 7.62993217e-17], + [-1.86509453e-16, 1.01732429e-16, 1.52598643e-16], + [-3.78931294e-16, -7.57862587e-16, 1.15158342e-15], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.hessian(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: gp.AbstractPotentialBase, x: gt.QVec3) -> None: + """Test the `AbstractPotentialBase.tidal_tensor` method.""" + expect = Quantity( + [ + [-1.63440876e-16, -1.86509453e-16, 7.62993217e-17], + [-1.86509453e-16, -4.43205056e-16, 1.52598643e-16], + [-3.78931294e-16, -7.57862587e-16, 6.06645933e-16], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # ========================================================================== + # Interoperability + + @pytest.mark.skipif(not HAS_GALA or not GSL_ENABLED, reason="requires gala + GSL") + @pytest.mark.parametrize( + ("method0", "method1", "atol"), + [ + ("potential", "energy", 1e-8), + ("gradient", "gradient", 1e-8), + ("density", "density", 3e-5), # TODO: get gala and galax to agree + ("hessian", "hessian", 1e-8), # TODO: get gala and galax to agree + ], + ) + def test_method_gala( + self, + pot: gp.AbstractPotentialBase, + method0: str, + method1: str, + x: gt.QVec3, + atol: float, + ) -> None: + """Test the equivalence of methods between gala and galax. + + This test only runs if the potential can be mapped to gala. + """ + super().test_method_gala(pot, method0, method1, x, atol) diff --git a/tests/unit/potential/builtin/multipole/test_multipole.py b/tests/unit/potential/builtin/multipole/test_multipole.py new file mode 100644 index 00000000..4e1e4648 --- /dev/null +++ b/tests/unit/potential/builtin/multipole/test_multipole.py @@ -0,0 +1,335 @@ +"""Test the `MultipolePotential` class.""" + +import re +from typing import Any + +import astropy.units as u +import pytest +from jaxtyping import Array, Shaped +from plum import convert +from typing_extensions import override + +import quaxed.numpy as jnp +from unxt import Quantity +from unxt.unitsystems import AbstractUnitSystem + +import galax.potential as gp +import galax.typing as gt +from ...io.test_gala import parametrize_test_method_gala +from ...test_core import AbstractPotential_Test +from ..test_common import ParameterMTotMixin, ParameterScaleRadiusMixin +from .test_abstractmultipole import ParameterAngularCoefficientsMixin + +############################################################################### + + +class ParameterISlmMixin(ParameterAngularCoefficientsMixin): + """Test the ISlm parameter.""" + + pot_cls: type[gp.AbstractPotential] + + @pytest.fixture(scope="class") + def field_ISlm(self, field_l_max) -> Shaped[Array, "3 3"]: + """ISlm parameter.""" + ISlm = jnp.zeros((field_l_max + 1, field_l_max + 1)) + ISlm = ISlm.at[1, 0].set(5.0) + return ISlm # noqa: RET504 + + # ===================================================== + + def test_ISlm_units(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + ISlm = jnp.zeros((l_max + 1, l_max + 1)) + ISlm = ISlm.at[1, :].set(5.0) + + fields["ISlm"] = Quantity(ISlm, "") + pot = pot_cls(**fields) + assert isinstance(pot.ISlm, gp.params.ConstantParameter) + assert jnp.allclose(pot.ISlm.value, Quantity(ISlm, "")) + + def test_ISlm_constant(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + ISlm = jnp.zeros((l_max + 1, l_max + 1)) + ISlm = ISlm.at[1, 0].set(5.0) + + fields["ISlm"] = ISlm + pot = pot_cls(**fields) + assert jnp.allclose(pot.ISlm(t=Quantity(0, "Myr")), ISlm) + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_ISlm_userfunc(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + ISlm = jnp.zeros((l_max + 1, l_max + 1)) + ISlm = ISlm.at[1, 0].set(5.0) + + fields["ISlm"] = lambda t: ISlm * jnp.exp(-jnp.abs(t)) + pot = pot_cls(**fields) + assert jnp.allclose(pot.ISlm(t=Quantity(0, "Myr")), ISlm) + + +class ParameterITlmMixin(ParameterAngularCoefficientsMixin): + """Test the ITlm parameter.""" + + pot_cls: type[gp.AbstractPotential] + + @pytest.fixture(scope="class") + def field_ITlm(self, field_l_max) -> Shaped[Array, "3 3"]: + """ITlm parameter.""" + return jnp.zeros((field_l_max + 1, field_l_max + 1)) + + # ===================================================== + + def test_ITlm_units(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + ITlm = jnp.zeros((l_max + 1, l_max + 1)) + ITlm = ITlm.at[1, :].set(5.0) + + fields["ITlm"] = Quantity(ITlm, "") + fields["l_max"] = l_max + pot = pot_cls(**fields) + assert isinstance(pot.ITlm, gp.params.ConstantParameter) + assert jnp.allclose(pot.ITlm.value, Quantity(ITlm, "")) + + def test_ITlm_constant(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + ITlm = jnp.zeros((l_max + 1, l_max + 1)) + ITlm = ITlm.at[1, 0].set(5.0) + + fields["ITlm"] = ITlm + pot = pot_cls(**fields) + assert jnp.allclose(pot.ITlm(t=Quantity(0, "Myr")), ITlm) + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_ITlm_userfunc(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + ITlm = jnp.zeros((l_max + 1, l_max + 1)) + ITlm = ITlm.at[1, :].set(5.0) + + fields["ITlm"] = lambda t: ITlm * jnp.exp(-jnp.abs(t)) + pot = pot_cls(**fields) + assert jnp.allclose(pot.ITlm(t=Quantity(0, "Myr")), ITlm) + + +class ParameterOSlmMixin(ParameterAngularCoefficientsMixin): + """Test the OSlm parameter.""" + + pot_cls: type[gp.AbstractPotential] + + @pytest.fixture(scope="class") + def field_OSlm(self, field_l_max) -> Shaped[Array, "3 3"]: + """OSlm parameter.""" + OSlm = jnp.zeros((field_l_max + 1, field_l_max + 1)) + OSlm = OSlm.at[1, 0].set(5.0) + return OSlm # noqa: RET504 + + # ===================================================== + + def test_OSlm_units(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + OSlm = jnp.zeros((l_max + 1, l_max + 1)) + OSlm = OSlm.at[1, :].set(5.0) + + fields["OSlm"] = Quantity(OSlm, "") + pot = pot_cls(**fields) + assert isinstance(pot.OSlm, gp.params.ConstantParameter) + assert jnp.allclose(pot.OSlm.value, Quantity(OSlm, "")) + + def test_OSlm_constant(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + OSlm = jnp.zeros((l_max + 1, l_max + 1)) + OSlm = OSlm.at[1, 0].set(5.0) + + fields["OSlm"] = OSlm + pot = pot_cls(**fields) + assert jnp.allclose(pot.OSlm(t=Quantity(0, "Myr")), OSlm) + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_OSlm_userfunc(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + OSlm = jnp.zeros((l_max + 1, l_max + 1)) + OSlm = OSlm.at[1, 0].set(5.0) + + fields["OSlm"] = lambda t: OSlm * jnp.exp(-jnp.abs(t)) + pot = pot_cls(**fields) + assert jnp.allclose(pot.OSlm(t=Quantity(0, "Myr")), OSlm) + + +class ParameterOTlmMixin(ParameterAngularCoefficientsMixin): + """Test the OTlm parameter.""" + + pot_cls: type[gp.AbstractPotential] + + @pytest.fixture(scope="class") + def field_OTlm(self, field_l_max) -> Shaped[Array, "3 3"]: + """OTlm parameter.""" + return jnp.zeros((field_l_max + 1, field_l_max + 1)) + + # ===================================================== + + def test_OTlm_units(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + OTlm = jnp.zeros((l_max + 1, l_max + 1)) + OTlm = OTlm.at[1, :].set(5.0) + + fields["OTlm"] = Quantity(OTlm, "") + fields["l_max"] = l_max + pot = pot_cls(**fields) + assert isinstance(pot.OTlm, gp.params.ConstantParameter) + assert jnp.allclose(pot.OTlm.value, Quantity(OTlm, "")) + + def test_OTlm_constant(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + OTlm = jnp.zeros((l_max + 1, l_max + 1)) + OTlm = OTlm.at[1, 0].set(5.0) + + fields["OTlm"] = OTlm + pot = pot_cls(**fields) + assert jnp.allclose(pot.OTlm(t=Quantity(0, "Myr")), OTlm) + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_OTlm_userfunc(self, pot_cls, fields): + """Test the mass parameter.""" + l_max = fields["l_max"] + OTlm = jnp.zeros((l_max + 1, l_max + 1)) + OTlm = OTlm.at[1, :].set(5.0) + + fields["OTlm"] = lambda t: OTlm * jnp.exp(-jnp.abs(t)) + pot = pot_cls(**fields) + assert jnp.allclose(pot.OTlm(t=Quantity(0, "Myr")), OTlm) + + +############################################################################### + + +class TestMultipolePotential( + AbstractPotential_Test, + # Parameters + ParameterMTotMixin, + ParameterScaleRadiusMixin, + ParameterISlmMixin, + ParameterITlmMixin, + ParameterOSlmMixin, + ParameterOTlmMixin, +): + @pytest.fixture(scope="class") + @override + def pot_cls(self) -> type[gp.MultipolePotential]: + return gp.MultipolePotential + + @pytest.fixture(scope="class") + @override + def fields_( + self, + field_m_tot: u.Quantity, + field_r_s: u.Quantity, + field_l_max: int, + field_ISlm: Shaped[Array, "3 3"], + field_ITlm: Shaped[Array, "3 3"], + field_OSlm: Shaped[Array, "3 3"], + field_OTlm: Shaped[Array, "3 3"], + field_units: AbstractUnitSystem, + ) -> dict[str, Any]: + return { + "m_tot": field_m_tot, + "r_s": field_r_s, + "l_max": field_l_max, + "ISlm": field_ISlm, + "ITlm": field_ITlm, + "OSlm": field_OSlm, + "OTlm": field_OTlm, + "units": field_units, + } + + # ========================================================================== + + def test_check_init( + self, pot_cls: type[gp.MultipoleInnerPotential], fields_: dict[str, Any] + ) -> None: + """Test the `MultipoleInnerPotential.__check_init__` method.""" + fields_["ISlm"] = fields_["ISlm"][::2] # make it the wrong shape + match = re.escape("I/OSlm and I/OTlm must have the shape") + with pytest.raises(ValueError, match=match): + pot_cls(**fields_) + + # ========================================================================== + + def test_potential(self, pot: gp.MultipolePotential, x: gt.QVec3) -> None: + expect = Quantity(33.59908611, unit="kpc2 / Myr2") + assert jnp.isclose( + pot.potential(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_gradient(self, pot: gp.MultipolePotential, x: gt.QVec3) -> None: + expect = Quantity( + [-0.13487022, -0.26974043, 10.79508472], pot.units["acceleration"] + ) + got = convert(pot.gradient(x, t=0), Quantity) + assert jnp.allclose(got, expect, atol=Quantity(1e-8, expect.unit)) + + def test_density(self, pot: gp.MultipolePotential, x: gt.QVec3) -> None: + expect = Quantity(4.73805126e-05, pot.units["mass density"]) + assert jnp.isclose( + pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_hessian(self, pot: gp.MultipolePotential, x: gt.QVec3) -> None: + expect = Quantity( + [ + [-0.08670228, 0.09633587, 0.09954706], + [0.09633587, 0.05780152, 0.19909413], + [0.09954706, 0.19909413, 0.02890076], + ], + "1/Myr2", + ) + assert jnp.allclose( + pot.hessian(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: gp.AbstractPotentialBase, x: gt.QVec3) -> None: + """Test the `AbstractPotentialBase.tidal_tensor` method.""" + expect = Quantity( + [ + [-0.08670228, 0.09633587, 0.09954706], + [0.09633587, 0.05780152, 0.19909413], + [0.09954706, 0.19909413, 0.02890076], + ], + "1/Myr2", + ) + assert jnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # ========================================================================== + # Interoperability + + @pytest.mark.xfail() + def test_galax_to_gala_to_galax_roundtrip( + self, pot: gp.AbstractPotentialBase, x: gt.QVec3 + ) -> None: + super().test_galax_to_gala_to_galax_roundtrip(pot, x) + + @pytest.mark.xfail() + @parametrize_test_method_gala + def test_method_gala( + self, + pot: gp.MultipolePotential, + method0: str, + method1: str, + x: gt.QVec3, + atol: float, + ) -> None: + super().test_method_gala(pot, method0, method1, x, atol) diff --git a/tests/unit/potential/builtin/multipole/test_outermultipole.py b/tests/unit/potential/builtin/multipole/test_outermultipole.py new file mode 100644 index 00000000..345543bc --- /dev/null +++ b/tests/unit/potential/builtin/multipole/test_outermultipole.py @@ -0,0 +1,144 @@ +"""Test the `MultipoleOuterPotential` class.""" + +from typing import Any + +import astropy.units as u +import pytest +from jaxtyping import Array, Shaped +from plum import convert +from typing_extensions import override + +import quaxed.numpy as qnp +from unxt import Quantity +from unxt.unitsystems import AbstractUnitSystem + +import galax.potential as gp +import galax.typing as gt +from ...test_core import AbstractPotential_Test +from ..test_common import ParameterMTotMixin, ParameterScaleRadiusMixin +from .test_abstractmultipole import ParameterSlmMixin, ParameterTlmMixin +from galax.utils._optional_deps import GSL_ENABLED, HAS_GALA + +############################################################################### + + +class TestMultipoleOuterPotential( + AbstractPotential_Test, + # Parameters + ParameterMTotMixin, + ParameterScaleRadiusMixin, + ParameterSlmMixin, + ParameterTlmMixin, +): + @pytest.fixture(scope="class") + @override + def pot_cls(self) -> type[gp.MultipoleOuterPotential]: + return gp.MultipoleOuterPotential + + @pytest.fixture(scope="class") + @override + def fields_( + self, + field_m_tot: u.Quantity, + field_r_s: u.Quantity, + field_l_max: int, + field_Slm: Shaped[Array, "3 3"], + field_Tlm: Shaped[Array, "3 3"], + field_units: AbstractUnitSystem, + ) -> dict[str, Any]: + return { + "m_tot": field_m_tot, + "r_s": field_r_s, + "l_max": field_l_max, + "Slm": field_Slm, + "Tlm": field_Tlm, + "units": field_units, + } + + # ========================================================================== + + def test_check_init( + self, pot_cls: type[gp.MultipoleInnerPotential], fields_: dict[str, Any] + ) -> None: + """Test the `MultipoleInnerPotential.__check_init__` method.""" + fields_["Slm"] = fields_["Slm"][::2] # make it the wrong shape + with pytest.raises(ValueError, match="Slm and Tlm must have the shape"): + pot_cls(**fields_) + + # ========================================================================== + + def test_potential(self, pot: gp.MultipoleOuterPotential, x: gt.QVec3) -> None: + expect = Quantity(0.62939434, unit="kpc2 / Myr2") + assert qnp.isclose( + pot.potential(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_gradient(self, pot: gp.MultipoleOuterPotential, x: gt.QVec3) -> None: + expect = Quantity( + [-0.13487022, -0.26974043, -0.19481253], pot.units["acceleration"] + ) + got = convert(pot.gradient(x, t=0), Quantity) + assert qnp.allclose(got, expect, atol=Quantity(1e-8, expect.unit)) + + def test_density(self, pot: gp.MultipoleOuterPotential, x: gt.QVec3) -> None: + expect = Quantity(4.90989768e-07, unit="solMass / kpc3") + assert qnp.isclose( + pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_hessian(self, pot: gp.MultipoleOuterPotential, x: gt.QVec3) -> None: + expect = Quantity( + [ + [-0.08670228, 0.09633587, 0.09954706], + [0.09633587, 0.05780152, 0.19909413], + [0.09954706, 0.19909413, 0.02890076], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.hessian(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: gp.AbstractPotentialBase, x: gt.QVec3) -> None: + """Test the `AbstractPotentialBase.tidal_tensor` method.""" + expect = Quantity( + [ + [-0.08670228, 0.09633587, 0.09954706], + [0.09633587, 0.05780152, 0.19909413], + [0.09954706, 0.19909413, 0.02890076], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # ========================================================================== + # Interoperability + + @pytest.mark.skipif(not HAS_GALA or not GSL_ENABLED, reason="requires gala + GSL") + @pytest.mark.parametrize( + ("method0", "method1", "atol"), + [ + ("potential", "energy", 1e-8), + ("gradient", "gradient", 1e-8), + ("density", "density", 6e-7), # TODO: get gala and galax to agree + ("hessian", "hessian", 1e0), # TODO: THIS IS BAD!! + ], + ) + def test_method_gala( + self, + pot: gp.AbstractPotentialBase, + method0: str, + method1: str, + x: gt.QVec3, + atol: float, + ) -> None: + """Test the equivalence of methods between gala and galax. + + This test only runs if the potential can be mapped to gala. + """ + super().test_method_gala(pot, method0, method1, x, atol) diff --git a/tests/unit/potential/io/test_gala.py b/tests/unit/potential/io/test_gala.py index 24042fc6..8257c4e8 100644 --- a/tests/unit/potential/io/test_gala.py +++ b/tests/unit/potential/io/test_gala.py @@ -42,9 +42,8 @@ def test_galax_to_gala_to_galax_roundtrip( if not self.HAS_GALA_COUNTERPART: pytest.skip("potential does not have a gala counterpart") - rpot = gp.io.convert_potential( - gp.io.GalaxLibrary, gp.io.convert_potential(gp.io.GalaLibrary, pot) - ) + gala_pot = gp.io.convert_potential(gp.io.GalaLibrary, pot) + rpot = gp.io.convert_potential(gp.io.GalaxLibrary, gala_pot) # quick test that the potential energies are the same got = rpot(x, 0) diff --git a/tests/unit/potential/param/test_core.py b/tests/unit/potential/param/test_core.py index fd1a9a02..c230c9c1 100644 --- a/tests/unit/potential/param/test_core.py +++ b/tests/unit/potential/param/test_core.py @@ -81,9 +81,9 @@ def test_call(self, param: T, field_value: float) -> None: """Test `galax.potential.ConstantParameter` call method.""" assert param(t=1.0) == field_value assert param(t=1.0 * u.s) == field_value - assert qnp.array_equal( - param(t=xp.asarray([1.0, 2.0])), [field_value, field_value] - ) + # Calling the parameter doesn't broadcast the shape, only the number of + # dimensions + assert qnp.array_equal(param(t=xp.asarray([1.0, 2.0])), [field_value]) def test_mul(self, param: T, field_value: float) -> None: """Test `galax.potential.ConstantParameter` multiplication."""