Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Dec 8, 2023
1 parent 5647434 commit e422b6a
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 259 deletions.
27 changes: 14 additions & 13 deletions src/galdynamix/potential/_potential/param/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

import astropy.units as u
import equinox as eqx
from jaxtyping import Array, Float

from galdynamix.typing import ArrayAnyShape, FloatScalar
from galdynamix.typing import ArrayAnyShape, FloatLike, VecShape
from galdynamix.utils import partial_jit


Expand All @@ -30,19 +31,19 @@ class AbstractParameter(eqx.Module): # type: ignore[misc]
unit: u.Unit = eqx.field(static=True) # TODO: move this to an annotation?

@abc.abstractmethod
def __call__(self, t: FloatScalar, **kwargs: Any) -> ArrayAnyShape:
def __call__(self, t: FloatLike, **kwargs: Any) -> ArrayAnyShape:
"""Compute the parameter value at the given time(s).
Parameters
----------
t : Array
t : float | Array[float, ()]
The time(s) at which to compute the parameter value.
**kwargs
**kwargs : Any
Additional parameters to pass to the parameter function.
Returns
-------
Array
Array[float, "*shape"]
The parameter value at times ``t``.
"""
...
Expand All @@ -53,15 +54,15 @@ class ConstantParameter(AbstractParameter):

# TODO: unit handling
# TODO: link this shape to the return shape from __call__
value: ArrayAnyShape
value: VecShape

@partial_jit()
def __call__(self, t: FloatScalar = 0, **kwargs: Any) -> ArrayAnyShape:
def __call__(self, t: FloatLike = 0, **kwargs: Any) -> Float[Array, "{self.value}"]:
"""Return the constant parameter value.
Parameters
----------
t : Array, optional
t : float | Array[float, ()], optional
This is ignored and is thus optional.
Note that for most :class:`~galdynamix.potential.AbstractParameter`
the time is required.
Expand All @@ -70,7 +71,7 @@ def __call__(self, t: FloatScalar = 0, **kwargs: Any) -> ArrayAnyShape:
Returns
-------
Array
Array[float, "*shape"]
The constant parameter value.
"""
return self.value
Expand All @@ -84,19 +85,19 @@ def __call__(self, t: FloatScalar = 0, **kwargs: Any) -> ArrayAnyShape:
class ParameterCallable(Protocol):
"""Protocol for a Parameter callable."""

def __call__(self, t: FloatScalar, **kwargs: Any) -> ArrayAnyShape:
def __call__(self, t: FloatLike, **kwargs: Any) -> ArrayAnyShape:
"""Compute the parameter value at the given time(s).
Parameters
----------
t : Array
t : float | Array[float, ()]
Time(s) at which to compute the parameter value.
**kwargs : Any
Additional parameters to pass to the parameter function.
Returns
-------
Array
Array[float, "*shape"]
Parameter value(s) at the given time(s).
"""
...
Expand All @@ -109,5 +110,5 @@ class UserParameter(AbstractParameter):
func: ParameterCallable

@partial_jit()
def __call__(self, t: FloatScalar, **kwargs: Any) -> ArrayAnyShape:
def __call__(self, t: FloatLike, **kwargs: Any) -> ArrayAnyShape:
return self.func(t, **kwargs)
43 changes: 22 additions & 21 deletions src/galdynamix/potential/_potential/scf/bfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
import astropy.units as u
import equinox as eqx
import jax.numpy as xp
import jax.typing as jt
import jaxtyping as jtx
from jaxtyping import Array, Float
from typing_extensions import override

from galdynamix.potential._potential.core import AbstractPotential
from galdynamix.potential._potential.param import AbstractParameter, ParameterField
from galdynamix.utils import partial_jit
from galdynamix.typing import ArrayAnyShape, FloatLike, FloatScalar, Vec3
from galdynamix.utils import partial_jit, vectorize_method

from .bfe_helper import phi_nl as calculate_phi_nl
from .bfe_helper import rho_nl as calculate_rho_nl
Expand Down Expand Up @@ -81,9 +81,9 @@ def __post_init__(self) -> None:

@partial_jit()
@eqx.filter_vmap(in_axes=(None, 1, None)) # type: ignore[misc] # on `q` axis 1
def _potential_energy(
self, q: jtx.Float[jtx.Array, "3 N"], /, t: jtx.Float[jtx.Array, "1"]
) -> jtx.Float[jtx.Array, "N"]: # type: ignore[name-defined]
def _potential_energy_helper(
self, q: Float[Array, "3 N"], /, t: Float[Array, "1"]
) -> Float[Array, "N"]: # type: ignore[name-defined]
r, theta, phi = cartesian_to_spherical(q)
r_s = self.r_s(t)
s = xp.atleast_1d(r / r_s)[:, None, None, None]
Expand All @@ -110,14 +110,15 @@ def _potential_energy(
return out[0] if len(q.shape) == 1 else out

@partial_jit()
def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
@vectorize_method(signature="(3),()->()")
def _potential_energy(self, q: Vec3, /, t: FloatScalar) -> FloatScalar:
"""Compute the potential energy at the given position(s)."""
out = self._potential_energy(expand_dim1(q), t)
out = self._potential_energy_helper(expand_dim1(q), t)
return out[0] if len(q.shape) == 1 else out

# @partial_jit()
# @eqx.filter_vmap(in_axes=(None, 1, None)) # type: ignore[misc] # on `q` axis 1
# def _gradient(self, q: jtx.Float[jtx.Array, "3"], /, t: jt.Array) -> jt.Array:
# def _gradient(self, q: Float[Array, "3"], /, t: jt.Array) -> jt.Array:
# """Compute the gradient."""
# r, theta, phi = cartesian_to_spherical(q)
# r_s = self.r_s(t)
Expand Down Expand Up @@ -171,8 +172,8 @@ def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
@partial_jit()
@eqx.filter_vmap(in_axes=(None, 1, None)) # type: ignore[misc] # on `q` axis 1
def density(
self, q: jtx.Float[jtx.Array, "3 N"], /, t: jtx.Float[jtx.Array, "1"]
) -> jtx.Float[jtx.Array, "N"]: # type: ignore[name-defined]
self, q: Float[Array, "3 N"], /, t: Float[Array, "1"]
) -> Float[Array, "N"]: # type: ignore[name-defined]
"""Compute the density at the given position(s)."""
r, theta, phi = cartesian_to_spherical(q)
r_s = self.r_s(t)
Expand Down Expand Up @@ -208,41 +209,41 @@ class STnlmSnapshotParameter(AbstractParameter):
"""Parameter for the STnlm coefficients."""

snapshot: Callable[ # type: ignore[name-defined]
[jtx.Float[jtx.Array, "N"]],
tuple[jtx.Float[jtx.Array, "3 N"], jtx.Float[jtx.Array, "N"]],
[Float[Array, "N"]],
tuple[Float[Array, "3 N"], Float[Array, "N"]],
]
"""Cartesian coordinates of the snapshot.
This should be a callable that accepts a single argument `t` and returns
the cartesian coordinates and the masses of the snapshot at that time.
"""

nmax: int = eqx.field(static=True)
nmax: int = eqx.field(static=True, converter=int)
"""Radial expansion term."""

lmax: int = eqx.field(static=True)
lmax: int = eqx.field(static=True, converter=int)
"""Spherical harmonic term."""

_: KW_ONLY
unit: u.Unit = eqx.field(default=u.dimensionless_unscaled, static=True)
unit: u.Unit = eqx.field(default=u.one, static=True, converter=u.Unit)

def __post_init__(self) -> None:
super().__post_init__()
if self.unit != u.dimensionless_unscaled:
if self.unit != u.one:
msg = "unit must be dimensionless"
raise ValueError(msg)

@override
def __call__( # type: ignore[override]
self, t: float, *, r_s: float, **kwargs: Any
) -> tuple[jt.Array, jt.Array]:
self, t: FloatLike, *, r_s: float, **kwargs: Any
) -> tuple[ArrayAnyShape, ArrayAnyShape]:
"""Return the coefficients at the given time(s).
Parameters
----------
t : float
t : float | Array[float, ()]
Time at which to evaluate the coefficients.
r_s : float
r_s : float | Array[float, ()]
Scale length of the potential at the given time(s.
**kwargs : Any
Additional keyword arguments are ignored.
Expand Down
53 changes: 6 additions & 47 deletions src/galdynamix/potential/_potential/scf/bfe_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,65 +4,24 @@

import jax
import jax.numpy as xp
from jax.scipy.special import gamma
from jaxtyping import Array, Float

from galdynamix.potential._potential.scf.gegenbauer import GegenbauerCalculator
from galdynamix.utils import partial_jit

from .utils import factorial, psi_of_r


def normalization_Knl(n: int, l: int) -> float:
"""SCF normalization factor.
Parameters
----------
n : int
Radial expansion term.
l : int
Spherical harmonic term.
Returns
-------
float
"""
return 0.5 * n * (n + 4 * l + 3.0) + (l + 1) * (2 * l + 1)


@partial_jit()
def expansion_coeffs_Anl_discrete(n: int, l: int) -> Float[Array, "1"]:
"""Return normalization factor for the coefficients.
Equation 16 of Lowing et al. (2011).
Parameters
----------
n : int
Radial expansion term.
l : int
spherical harmonic term.
Returns
-------
float
"""
Knl = normalization_Knl(n=n, l=l)
prefac = -(2 ** (8.0 * l + 6)) / (4 * xp.pi * Knl)
numerator = factorial(n) * (n + 2 * l + 1.5) * gamma(2 * l + 1.5) ** 2
denominator = gamma(n + 4.0 * l + 3.0)
return prefac * (numerator / denominator)
from .coeffs_helper import normalization_Knl
from .utils import psi_of_r


@partial_jit(static_argnames=("gegenbauer",))
def rho_nl(
s: Float[Array, "samples"], n: int, l: int, *, gegenbauer: GegenbauerCalculator
) -> Float[Array, "samples"]:
s: Float[Array, "N"], n: int, l: int, *, gegenbauer: GegenbauerCalculator
) -> Float[Array, "N"]:
r"""Radial density expansion terms.
Parameters
----------
s : Array[(n_samples,), float]
s : Array[(n,), float]
Scaled radius :math:`r/r_s`.
n : int
Radial expansion term.
Expand All @@ -74,7 +33,7 @@ def rho_nl(
Returns
-------
Array[(n_samples,), float]
Array[(n,), float]
"""
return (
xp.sqrt(xp.pi * 4)
Expand Down
Loading

0 comments on commit e422b6a

Please sign in to comment.