diff --git a/src/galax/potential/_potential/builtin/multipole.py b/src/galax/potential/_potential/builtin/multipole.py index 8efd9d8a..65372336 100644 --- a/src/galax/potential/_potential/builtin/multipole.py +++ b/src/galax/potential/_potential/builtin/multipole.py @@ -60,11 +60,13 @@ class MultipoleInnerPotential(AbstractMultipolePotential): def __check_init__(self) -> None: shape = (self.l_max + 1, self.l_max + 1) - # TODO: don't use .value - if self.Slm.value.shape != shape or self.Tlm.value.shape != shape: + t = Quantity(0.0, "Gyr") + s_shape, t_shape = self.Slm(t).shape, self.Tlm(t).shape + # TODO: check shape across time. + if s_shape != shape or t_shape != shape: msg = ( "Slm and Tlm must have the shape (l_max + 1, l_max + 1)." - f"Slm shape: {self.Slm.value.shape}, Tlm shape: {self.Tlm.value.shape}" + f"Slm shape: {s_shape}, Tlm shape: {t_shape}" ) raise ValueError(msg) @@ -117,9 +119,14 @@ class MultipoleOuterPotential(AbstractMultipolePotential): def __check_init__(self) -> None: shape = (self.l_max + 1, self.l_max + 1) - # TODO: don't use .value - if self.Slm.value.shape != shape or self.Tlm.value.shape != shape: - msg = "Slm and Tlm must have the shape (l_max + 1, l_max + 1)." + t = Quantity(0.0, "Gyr") + s_shape, t_shape = self.Slm(t).shape, self.Tlm(t).shape + # TODO: check shape across time. + if s_shape != shape or t_shape != shape: + msg = ( + "Slm and Tlm must have the shape (l_max + 1, l_max + 1)." + f"Slm shape: {s_shape}, Tlm shape: {t_shape}" + ) raise ValueError(msg) @partial(jax.jit, inline=True) @@ -178,12 +185,15 @@ class MultipolePotential(AbstractMultipolePotential): def __check_init__(self) -> None: shape = (self.l_max + 1, self.l_max + 1) - # TODO: don't use .value + t = Quantity(0.0, "Gyr") + is_shape, it_shape = self.ISlm(t).shape, self.ITlm(t).shape + os_shape, ot_shape = self.OSlm(t).shape, self.OTlm(t).shape + # TODO: check shape across time. if ( - self.ISlm.value.shape != shape - or self.ITlm.value.shape != shape - or self.OSlm.value.shape != shape - or self.OTlm.value.shape != shape + is_shape != shape + or it_shape != shape + or os_shape != shape + or ot_shape != shape ): msg = "I/OSlm and I/OTlm must have the shape (l_max + 1, l_max + 1)." raise ValueError(msg) diff --git a/tests/unit/potential/builtin/multipole/test_innermultipole.py b/tests/unit/potential/builtin/multipole/test_innermultipole.py index 2a4b7c7a..8d7701e5 100644 --- a/tests/unit/potential/builtin/multipole/test_innermultipole.py +++ b/tests/unit/potential/builtin/multipole/test_innermultipole.py @@ -57,6 +57,16 @@ def fields_( # ========================================================================== + def test_check_init( + self, pot_cls: type[gp.MultipoleInnerPotential], fields_: dict[str, Any] + ) -> None: + """Test the `MultipoleInnerPotential.__check_init__` method.""" + fields_["Slm"] = fields_["Slm"][::2] # make it the wrong shape + with pytest.raises(ValueError, match="Slm and Tlm must have the shape"): + pot_cls(**fields_) + + # ========================================================================== + def test_potential(self, pot: gp.MultipoleInnerPotential, x: gt.QVec3) -> None: expect = Quantity(32.96969177, unit="kpc2 / Myr2") assert qnp.isclose( diff --git a/tests/unit/potential/builtin/multipole/test_multipole.py b/tests/unit/potential/builtin/multipole/test_multipole.py index 9f4a0b0c..b9a3087b 100644 --- a/tests/unit/potential/builtin/multipole/test_multipole.py +++ b/tests/unit/potential/builtin/multipole/test_multipole.py @@ -1,5 +1,6 @@ """Test the `MultipolePotential` class.""" +import re from typing import Any import astropy.units as u @@ -252,6 +253,17 @@ def fields_( # ========================================================================== + def test_check_init( + self, pot_cls: type[gp.MultipoleInnerPotential], fields_: dict[str, Any] + ) -> None: + """Test the `MultipoleInnerPotential.__check_init__` method.""" + fields_["ISlm"] = fields_["ISlm"][::2] # make it the wrong shape + match = re.escape("I/OSlm and I/OTlm must have the shape") + with pytest.raises(ValueError, match=match): + pot_cls(**fields_) + + # ========================================================================== + def test_potential(self, pot: gp.MultipolePotential, x: gt.QVec3) -> None: expect = Quantity(33.59908611, unit="kpc2 / Myr2") assert jnp.isclose( diff --git a/tests/unit/potential/builtin/multipole/test_outermultipole.py b/tests/unit/potential/builtin/multipole/test_outermultipole.py index 1e097afd..3ab3b344 100644 --- a/tests/unit/potential/builtin/multipole/test_outermultipole.py +++ b/tests/unit/potential/builtin/multipole/test_outermultipole.py @@ -57,6 +57,16 @@ def fields_( # ========================================================================== + def test_check_init( + self, pot_cls: type[gp.MultipoleInnerPotential], fields_: dict[str, Any] + ) -> None: + """Test the `MultipoleInnerPotential.__check_init__` method.""" + fields_["Slm"] = fields_["Slm"][::2] # make it the wrong shape + with pytest.raises(ValueError, match="Slm and Tlm must have the shape"): + pot_cls(**fields_) + + # ========================================================================== + def test_potential(self, pot: gp.MultipoleOuterPotential, x: gt.QVec3) -> None: expect = Quantity(0.62939434, unit="kpc2 / Myr2") assert qnp.isclose(