Skip to content

Commit

Permalink
SCF
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Dec 7, 2023
1 parent b43ac7a commit 8a31515
Show file tree
Hide file tree
Showing 13 changed files with 1,093 additions and 37 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [

[project.optional-dependencies]
test = [
"gala",
"hypothesis[numpy]",
"pytest >=6",
"pytest-cov >=3",
Expand Down Expand Up @@ -148,6 +149,7 @@ ignore = [
"F821", # undefined name <- jaxtyping
"FIX002", # Line contains TODO, consider resolving the issue
"N80", # Naming conventions.
"N816", # Variable in global scope should not be mixedCase
"PD", # pandas-vet
"PLR", # Design related pylint codes
"TCH00", # Move into a type-checking block
Expand Down
3 changes: 2 additions & 1 deletion src/galdynamix/potential/_potential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from .composite import *
from .core import *
from .param import *
from .scf import SCFPotential

__all__: list[str] = []
__all__ += base.__all__
__all__ += core.__all__
__all__ += composite.__all__
__all__ += param.__all__
__all__ += builtin.__all__
__all__ += ["scf"]
__all__ += ["scf", "SCFPotential"]
14 changes: 9 additions & 5 deletions src/galdynamix/potential/_potential/scf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

from . import gegenbauer
from .gegenbauer import *
from . import bfe, bfe_helper, coeffs, coeffs_helper
from .bfe import *
from .bfe_helper import *
from .coeffs import *
from .coeffs_helper import *

__all__: list[str] = []
__all__ += gegenbauer.__all__
__all__ += bfe.__all__
__all__ += bfe_helper.__all__
__all__ += coeffs.__all__
__all__ += coeffs_helper.__all__
258 changes: 258 additions & 0 deletions src/galdynamix/potential/_potential/scf/bfe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
"""Self-Consistent Field Potential."""

__all__ = ["SCFPotential", "STnlmSnapshotParameter"]

from collections.abc import Callable
from dataclasses import KW_ONLY
from typing import Any

import astropy.units as u
import equinox as eqx
import jax.numpy as xp
import jax.typing as jt
import jaxtyping as jtx
from typing_extensions import override

from galdynamix.potential._potential.core import AbstractPotential
from galdynamix.potential._potential.param import AbstractParameter, ParameterField
from galdynamix.utils import partial_jit

from .bfe_helper import phi_nl as calculate_phi_nl
from .bfe_helper import rho_nl as calculate_rho_nl
from .coeffs import compute_coeffs_discrete
from .gegenbauer import GegenbauerCalculator
from .utils import cartesian_to_spherical, expand_dim1, real_Ylm

##############################################################################


class SCFPotential(AbstractPotential):
r"""Self-Consistent Field (SCF) potential.
A gravitational potential represented as a basis function expansion. This
uses the self-consistent field (SCF) method of Hernquist & Ostriker (1992)
and Lowing et al. (2011), and represents all coefficients as real
quantities.
Parameters
----------
m : numeric
Scale mass.
r_s : numeric
Scale length.
Snlm : Array[float, (nmax+1, lmax+1, lmax+1)] | Callable
Array of coefficients for the cos() terms of the expansion. This should
be a 3D array with shape `(nmax+1, lmax+1, lmax+1)`, where `nmax` is the
number of radial expansion terms and `lmax` is the number of spherical
harmonic `l` terms. If a callable is provided, it should accept a
single argument `t` and return the array of coefficients for that time.
Tnlm : Array[float, (nmax+1, lmax+1, lmax+1)] | Callable
Array of coefficients for the sin() terms of the expansion. This should
be a 3D array with shape `(nmax+1, lmax+1, lmax+1)`, where `nmax` is the
number of radial expansion terms and `lmax` is the number of spherical
harmonic `l` terms. If a callable is provided, it should accept a
single argument `t` and return the array of coefficients for that time.
units : iterable
Unique list of non-reducable units that specify (at minimum) the length,
mass, time, and angle units.
"""

m: AbstractParameter = ParameterField(dimensions="mass") # type: ignore[assignment]
r_s: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]
Snlm: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment]
Tnlm: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment]

nmax: int = eqx.field(init=False, static=True, repr=False)
lmax: int = eqx.field(init=False, static=True, repr=False)
_ultra_sph: GegenbauerCalculator = eqx.field(init=False, static=True, repr=False)

def __post_init__(self) -> None:
super().__post_init__()

# shape parameters
shape = self.Snlm(0).shape
object.__setattr__(self, "nmax", shape[0] - 1)
object.__setattr__(self, "lmax", shape[1] - 1)

# gegenbauer calculator
object.__setattr__(self, "_ultra_sph", GegenbauerCalculator(self.nmax))

# ==========================================================================

@partial_jit()
@eqx.filter_vmap(in_axes=(None, 1, None)) # type: ignore[misc] # on `q` axis 1
def _potential_energy(
self, q: jtx.Float[jtx.Array, "3 N"], /, t: jtx.Float[jtx.Array, "1"]
) -> jtx.Float[jtx.Array, "N"]: # type: ignore[name-defined]
r, theta, phi = cartesian_to_spherical(q)
r_s = self.r_s(t)
s = xp.atleast_1d(r / r_s)[:, None, None, None]
theta = xp.atleast_1d(theta)[:, None, None, None]
phi = xp.atleast_1d(phi)[:, None, None, None]

ns = xp.arange(self.nmax + 1)[None, :, None, None] # ([N], n, [l], [m])
ls = xp.arange(self.lmax + 1)[None, None, :, None] # ([N], [n], l, [m])
phi_nl = calculate_phi_nl(s, ns, ls, gegenbauer=self._ultra_sph)

li, mi = xp.tril_indices(self.lmax + 1) # (l*(l+1)//2,)
shape = (1, 1, self.lmax + 1, self.lmax + 1)
midx = xp.zeros(shape, dtype=int).at[:, :, li, mi].set(mi)
Ylm = xp.zeros((len(theta), 1, self.lmax + 1, self.lmax + 1))
Ylm = Ylm.at[:, :, li, mi].set(real_Ylm(li[None], mi[None], theta[:, :, 0, 0]))

Snlm = self.Snlm(t, r_s=r_s)[None]
Tnlm = self.Tnlm(t, r_s=r_s)[None]

out = (self._G * self.m(t) / r_s) * xp.sum(
Ylm * phi_nl * (Snlm * xp.cos(midx * phi) + Tnlm * xp.sin(midx * phi)),
axis=(1, 2, 3),
)
return out[0] if len(q.shape) == 1 else out

@partial_jit()
def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
"""Compute the potential energy at the given position(s)."""
out = self._potential_energy(expand_dim1(q), t)
return out[0] if len(q.shape) == 1 else out

# @partial_jit()
# @eqx.filter_vmap(in_axes=(None, 1, None)) # type: ignore[misc] # on `q` axis 1
# def _gradient(self, q: jtx.Float[jtx.Array, "3"], /, t: jt.Array) -> jt.Array:
# """Compute the gradient."""
# r, theta, phi = cartesian_to_spherical(q)
# r_s = self.r_s(t)
# s = xp.atleast_1d(r / r_s)[:, None, None, None]
# theta = xp.atleast_1d(theta)[:, None, None, None]
# phi = xp.atleast_1d(phi)[:, None, None, None]

# ns = xp.arange(self.nmax + 1)[None, :, None, None] # ([N], n, [l], [m])
# ls = xp.arange(self.lmax + 1)[None, None, :, None] # ([N], [n], l, [m])
# phi_nl = calculate_phi_nl(s, ns, ls, gegenbauer=self._ultra_sph)
# dphi_nl_dr = phi_nl_grad(s, ns, ls, self._ultra_sph)

# li, mi = xp.tril_indices(self.lmax + 1) # (l*(l+1)//2,)
# shape = (1, 1, self.lmax + 1, self.lmax + 1)
# lidx = xp.zeros(shape, dtype=int).at[:, :, li, mi].set(li)
# midx = xp.zeros(shape, dtype=int).at[:, :, li, mi].set(mi)
# mvalid = xp.zeros(shape).at[:, :, li, mi].set(1) # m <= l
# Ylm = real_Ylm(lidx, midx, theta)
# dYlm_dtheta = calculate_dYlm_dtheta(lidx, midx, theta)

# Snlm = self.Snlm(t, r_s=r_s)[None]
# Tnlm = self.Tnlm(t, r_s=r_s)[None]

# grad_r = xp.sum(
# (mvalid * Ylm)
# * dphi_nl_dr
# * (Snlm * xp.cos(midx * phi) + Tnlm * xp.sin(midx * phi)),
# axis=(1, 2, 3),
# )
# grad_theta = (1 / s[:, 0, 0, 0]) * xp.sum(
# (mvalid * dYlm_dtheta)
# * phi_nl
# * (Snlm * xp.cos(midx * phi) + Tnlm * xp.sin(midx * phi)),
# axis=(1, 2, 3),
# )
# grad_phi = (1 / s[:, 0, 0, 0]) * xp.sum(
# (mvalid * Ylm / xp.sin(theta))
# * phi_nl
# * (Tnlm * xp.cos(midx * phi) - Snlm * xp.sin(midx * phi)),
# axis=(1, 2, 3),
# )
# return (self._G * self.m(t) / r_s) * xp.stack([grad_r, grad_theta, grad_phi],
# axis=-1)

# @partial_jit()
# def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
# """Compute the potential energy at the given position(s)."""
# out = self._gradient(expand_dim1(q), t)
# return out[0, 0] if len(q.shape) == 1 else out[:, 0] # TODO: fix this

@partial_jit()
@eqx.filter_vmap(in_axes=(None, 1, None)) # type: ignore[misc] # on `q` axis 1
def density(
self, q: jtx.Float[jtx.Array, "3 N"], /, t: jtx.Float[jtx.Array, "1"]
) -> jtx.Float[jtx.Array, "N"]: # type: ignore[name-defined]
"""Compute the density at the given position(s)."""
r, theta, phi = cartesian_to_spherical(q)
r_s = self.r_s(t)
s = xp.atleast_1d(r / r_s)[:, None, None, None]
theta = xp.atleast_1d(theta)[:, None, None, None]
phi = xp.atleast_1d(phi)[:, None, None, None]

ns = xp.arange(self.nmax + 1)[:, None, None] # (n, [l], [m])
ls = xp.arange(self.lmax + 1)[None, :, None] # ([n], l, [m])

phi_nl = calculate_rho_nl(s, ns[None], ls[None], gegenbauer=self._ultra_sph)

li, mi = xp.tril_indices(self.lmax + 1) # (l*(l+1)//2,)
shape = (1, 1, self.lmax + 1, self.lmax + 1)
midx = xp.zeros(shape, dtype=int).at[:, :, li, mi].set(mi)
Ylm = xp.zeros((len(theta), 1, self.lmax + 1, self.lmax + 1))
Ylm = Ylm.at[:, :, li, mi].set(real_Ylm(li[None], mi[None], theta[:, :, 0, 0]))

Snlm = self.Snlm(t, r_s=r_s)[None]
Tnlm = self.Tnlm(t, r_s=r_s)[None]

out = (self._G * self.m(t) / r_s) * xp.sum(
Ylm * phi_nl * (Snlm * xp.cos(midx * phi) + Tnlm * xp.sin(midx * phi)),
axis=(1, 2, 3),
)
return out[0] if len(q.shape) == 1 else out


# =============================================================================


class STnlmSnapshotParameter(AbstractParameter):
"""Parameter for the STnlm coefficients."""

snapshot: Callable[ # type: ignore[name-defined]
[jtx.Float[jtx.Array, "N"]],
tuple[jtx.Float[jtx.Array, "3 N"], jtx.Float[jtx.Array, "N"]],
]
"""Cartesian coordinates of the snapshot.
This should be a callable that accepts a single argument `t` and returns
the cartesian coordinates and the masses of the snapshot at that time.
"""

nmax: int = eqx.field(static=True)
"""Radial expansion term."""

lmax: int = eqx.field(static=True)
"""Spherical harmonic term."""

_: KW_ONLY
unit: u.Unit = eqx.field(default=u.dimensionless_unscaled, static=True)

def __post_init__(self) -> None:
super().__post_init__()
if self.unit != u.dimensionless_unscaled:
msg = "unit must be dimensionless"
raise ValueError(msg)

@override
def __call__( # type: ignore[override]
self, t: float, *, r_s: float, **kwargs: Any
) -> tuple[jt.Array, jt.Array]:
"""Return the coefficients at the given time(s).
Parameters
----------
t : float
Time at which to evaluate the coefficients.
r_s : float
Scale length of the potential at the given time(s.
**kwargs : Any
Additional keyword arguments are ignored.
Returns
-------
Snlm : Array[(nmax+1, lmax+1, lmax+1), float]
The value of the cosine expansion coefficient.
Tnlm : Array[(nmax+1, lmax+1, lmax+1), float]
The value of the sine expansion coefficient.
"""
xyz, m = self.snapshot(t)
return compute_coeffs_discrete(xyz, m, nmax=self.nmax, lmax=self.lmax, r_s=r_s)
Loading

0 comments on commit 8a31515

Please sign in to comment.