diff --git a/src/galax/potential/_potential/builtin/multipole.py b/src/galax/potential/_potential/builtin/multipole.py index 996e541b..7c71d84c 100644 --- a/src/galax/potential/_potential/builtin/multipole.py +++ b/src/galax/potential/_potential/builtin/multipole.py @@ -59,11 +59,13 @@ class MultipoleInnerPotential(AbstractMultipolePotential): def __check_init__(self) -> None: shape = (self.l_max + 1, self.l_max + 1) - # TODO: don't use .value - if self.Slm.value.shape != shape or self.Tlm.value.shape != shape: + t = Quantity(0.0, "Gyr") + s_shape, t_shape = self.Slm(t).shape, self.Tlm(t).shape + # TODO: check shape across time. + if s_shape != shape or t_shape != shape: msg = ( "Slm and Tlm must have the shape (l_max + 1, l_max + 1)." - f"Slm shape: {self.Slm.value.shape}, Tlm shape: {self.Tlm.value.shape}" + f"Slm shape: {s_shape}, Tlm shape: {t_shape}" ) raise ValueError(msg) @@ -115,10 +117,14 @@ class MultipoleOuterPotential(AbstractMultipolePotential): r"""Spherical harmonic coefficients for the $\sin(m \phi)$ terms.""" def __check_init__(self) -> None: - shape = (self.l_max + 1, self.l_max + 1) - # TODO: don't use .value - if self.Slm.value.shape != shape or self.Tlm.value.shape != shape: - msg = "Slm and Tlm must have the shape (l_max + 1, l_max + 1)." + t = Quantity(0.0, "Gyr") + s_shape, t_shape = self.Slm(t).shape, self.Tlm(t).shape + # TODO: check shape across time. + if s_shape != shape or t_shape != shape: + msg = ( + "Slm and Tlm must have the shape (l_max + 1, l_max + 1)." + f"Slm shape: {s_shape}, Tlm shape: {t_shape}" + ) raise ValueError(msg) @partial(jax.jit, inline=True) @@ -177,12 +183,15 @@ class MultipolePotential(AbstractMultipolePotential): def __check_init__(self) -> None: shape = (self.l_max + 1, self.l_max + 1) - # TODO: don't use .value + t = Quantity(0.0, "Gyr") + is_shape, it_shape = self.ISlm(t).shape, self.ITlm(t).shape + os_shape, ot_shape = self.OSlm(t).shape, self.OTlm(t).shape + # TODO: check shape across time. if ( - self.ISlm.value.shape != shape - or self.ITlm.value.shape != shape - or self.OSlm.value.shape != shape - or self.OTlm.value.shape != shape + is_shape != shape + or it_shape != shape + or os_shape != shape + or ot_shape != shape ): msg = "I/OSlm and I/OTlm must have the shape (l_max + 1, l_max + 1)." raise ValueError(msg) diff --git a/tests/unit/potential/builtin/multipole/test_innermultipole.py b/tests/unit/potential/builtin/multipole/test_innermultipole.py index 2a4b7c7a..8d7701e5 100644 --- a/tests/unit/potential/builtin/multipole/test_innermultipole.py +++ b/tests/unit/potential/builtin/multipole/test_innermultipole.py @@ -57,6 +57,16 @@ def fields_( # ========================================================================== + def test_check_init( + self, pot_cls: type[gp.MultipoleInnerPotential], fields_: dict[str, Any] + ) -> None: + """Test the `MultipoleInnerPotential.__check_init__` method.""" + fields_["Slm"] = fields_["Slm"][::2] # make it the wrong shape + with pytest.raises(ValueError, match="Slm and Tlm must have the shape"): + pot_cls(**fields_) + + # ========================================================================== + def test_potential(self, pot: gp.MultipoleInnerPotential, x: gt.QVec3) -> None: expect = Quantity(32.96969177, unit="kpc2 / Myr2") assert qnp.isclose( diff --git a/tests/unit/potential/builtin/multipole/test_multipole.py b/tests/unit/potential/builtin/multipole/test_multipole.py index 9f4a0b0c..2fac103e 100644 --- a/tests/unit/potential/builtin/multipole/test_multipole.py +++ b/tests/unit/potential/builtin/multipole/test_multipole.py @@ -252,6 +252,16 @@ def fields_( # ========================================================================== + def test_check_init( + self, pot_cls: type[gp.MultipoleInnerPotential], fields_: dict[str, Any] + ) -> None: + """Test the `MultipoleInnerPotential.__check_init__` method.""" + fields_["ISlm"] = fields_["ISlm"][::2] # make it the wrong shape + with pytest.raises(ValueError, match="Slm and Tlm must have the shape"): + pot_cls(**fields_) + + # ========================================================================== + def test_potential(self, pot: gp.MultipolePotential, x: gt.QVec3) -> None: expect = Quantity(33.59908611, unit="kpc2 / Myr2") assert jnp.isclose( diff --git a/tests/unit/potential/builtin/multipole/test_outermultipole.py b/tests/unit/potential/builtin/multipole/test_outermultipole.py index 1e097afd..3ab3b344 100644 --- a/tests/unit/potential/builtin/multipole/test_outermultipole.py +++ b/tests/unit/potential/builtin/multipole/test_outermultipole.py @@ -57,6 +57,16 @@ def fields_( # ========================================================================== + def test_check_init( + self, pot_cls: type[gp.MultipoleInnerPotential], fields_: dict[str, Any] + ) -> None: + """Test the `MultipoleInnerPotential.__check_init__` method.""" + fields_["Slm"] = fields_["Slm"][::2] # make it the wrong shape + with pytest.raises(ValueError, match="Slm and Tlm must have the shape"): + pot_cls(**fields_) + + # ========================================================================== + def test_potential(self, pot: gp.MultipoleOuterPotential, x: gt.QVec3) -> None: expect = Quantity(0.62939434, unit="kpc2 / Myr2") assert qnp.isclose(