Skip to content

Commit

Permalink
feat(matplotlib): add plot method
Browse files Browse the repository at this point in the history
  • Loading branch information
nstarman committed Jan 28, 2024
1 parent 52b55ac commit b7d3d5e
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 1 deletion.
10 changes: 10 additions & 0 deletions src/galax/extensions/matplotlib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Matplotlib extension for :mod:`galax`."""

__all__ = ["plot_contours"]

from galax.utils._optional_deps import HAS_MATPLOTLIB

if HAS_MATPLOTLIB:
from ._mpl_mixin import plot_contours
else:
from ._mpl_mixin_noop import plot_contours
168 changes: 168 additions & 0 deletions src/galax/extensions/matplotlib/_mpl_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
__all__ = ["plot_contours"]


from typing import Any

import numpy as np
from astropy.utils import isiterable

import matplotlib.pyplot as plt
from galax.potential._potential.base import AbstractPotentialBase
from matplotlib.axes import Axes
from matplotlib.cm import Blues
from matplotlib.figure import Figure

# ============================================================================
# Plot contours


def _plot_countours_1d(
pot: AbstractPotentialBase,
t: float,
*,
ax: Axes,
grids: list[Any],
slices: list[Any],
labels: tuple[str, ...] | None,
kwargs: dict[str, Any],
) -> None:
x1 = grids[0][1]
r = np.zeros((len(grids) + len(slices), len(x1)))
r[grids[0][0]] = x1

for ii, slc in slices:
r[ii] = slc

Z = pot.potential_energy(r * pot.units["length"], t=t)
ax.plot(x1, Z, **kwargs)

if labels is not None:
ax.set_xlabel(labels[0])
ax.set_ylabel("potential")


def _plot_countours_2d(
pot: AbstractPotentialBase,
t: float,
*,
ax: Axes,
grids: list[Any],
slices: list[Any],
labels: tuple[str, ...] | None,
filled: bool,
kwargs: dict[str, Any],
) -> None:
x1, x2 = np.meshgrid(grids[0][1], grids[1][1])
shp = x1.shape
x1, x2 = x1.ravel(), x2.ravel()

r = np.zeros((len(grids) + len(slices), len(x1)))
r[grids[0][0]] = x1
r[grids[1][0]] = x2

for ii, slc in slices:
r[ii] = slc

Z = pot.potential_energy(r * pot.units["length"], t=t)

# make default colormap not suck
cmap = kwargs.setdefault("cmap", Blues)
if filled:
ax.contourf(x1.reshape(shp), x2.reshape(shp), Z.reshape(shp), **kwargs)
else:
ax.contour(x1.reshape(shp), x2.reshape(shp), Z.reshape(shp), **kwargs)

if labels is not None:
ax.set_xlabel(labels[0])
ax.set_ylabel(labels[1])


def plot_contours(
pot: AbstractPotentialBase,
/,
grid: tuple[Any, ...],
t: float = 0.0,
*,
filled: bool = True,
ax: Any | None = None,
labels: tuple[str, ...] | None = None,
subplots_kw: dict[str, Any] | None = None,
**kwargs: Any,
) -> Figure:
"""Plot contours of the potential.
Parameters
----------
pot : :class:`~galax.potential.base.AbstractPotentialBase`
Potential object to plot.
grid : tuple[Any, ...]
Coordinate grids or slice value for each dimension. Should be a
tuple of 1D arrays or numbers.
t : quantity-like, optional
The time to evaluate at.
filled : bool, optional keyword-only
Use :func:`~matplotlib.pyplot.contourf` instead of
:func:`~matplotlib.pyplot.contour`. Default is ``True``.
ax : `~matplotlib.Axes`, optional keyword-only
Axes object to plot on. If not specified, a new figure and axes will
be created.
labels : tuple[str, ...] or None, optional keyword-only
List of axis labels.
subplots_kw : dict[str, Any], optional keyword-only
kwargs passed to matplotlib's subplots() function if an axes object
is not specified.
**kwargs : Any, optional keyword-only
kwargs passed to either :func:`~matplotlib.pyplot.contourf` or
:func:`~matplotlib.pyplot.plot`.
Returns
-------
:class:`~matplotlib.Figure`
"""
# Make figure
if ax is None:
fig, ax = plt.subplots(1, 1, **(subplots_kw or {}))
else:
fig = ax.figure

# Figure out which elements are iterable, which are numeric
_grids = []
_slices = []
for i, g in enumerate(grid):
if isiterable(g):
_grids.append((i, g))
else:
_slices.append((i, g))

# Figure out the dimensionality
ndim = len(_grids)

if ndim == 0 or ndim > 2:
msg = (
f"ndim = {ndim}: you can only make contours on a 1D or 2D grid."
" For other dimensions, you have to specify values to slice."
)
raise ValueError(msg)

if ndim == 1:
_plot_countours_1d(
pot, t, ax=ax, grids=_grids, slices=_slices, labels=labels, kwargs=kwargs
)

else:
_plot_countours_2d(
pot,
t,
ax=ax,
grids=_grids,
slices=_slices,
labels=labels,
filled=filled,
kwargs=kwargs,
)

return fig


# ============================================================================
56 changes: 56 additions & 0 deletions src/galax/extensions/matplotlib/_mpl_mixin_noop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
__all__ = ["plot_contours"]


from typing import TYPE_CHECKING, Any, Never

from galax.potential._potential.base import AbstractPotentialBase

if TYPE_CHECKING:
from matplotlib.axes import Axes


def plot_contours(
pot: AbstractPotentialBase,
/,
grid: tuple[Any, ...],
t: float = 0.0,
*,
filled: bool = True,
ax: "Axes | None" = None,
labels: tuple[str, ...] | None = None,
subplots_kw: dict[str, Any] | None = None,
**kwargs: Any,
) -> Never:
"""Plot contours of the potential.
Parameters
----------
pot : :class:`~galax.potential.base.AbstractPotentialBase`
Potential object to plot.
grid : tuple[Any, ...]
Coordinate grids or slice value for each dimension. Should be a
tuple of 1D arrays or numbers.
t : quantity-like, optional
The time to evaluate at.
filled : bool, optional keyword-only
Use :func:`~matplotlib.pyplot.contourf` instead of
:func:`~matplotlib.pyplot.contour`. Default is ``True``.
ax : `~matplotlib.Axes`, optional keyword-only
Axes object to plot on. If not specified, a new figure and axes will
be created.
labels : iterable, optional keyword-only
List of axis labels.
subplots_kw : dict[str, Any], optional keyword-only
kwargs passed to matplotlib's subplots() function if an axes object
is not specified.
**kwargs : Any, optional keyword-only
kwargs passed to either :func:`~matplotlib.pyplot.contourf` or
:func:`~matplotlib.pyplot.plot`.
Returns
-------
:class:`~matplotlib.Figure`
"""
msg = r"No module named 'matplotlib'"
raise ModuleNotFoundError(msg)
2 changes: 1 addition & 1 deletion src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _init_units(self) -> None:

def potential_energy(
self,
q: BatchVec3 | Quantity | BaseRepresentation,
q: BaseRepresentation | BatchVec3 | Quantity,
/,
t: BatchableFloatOrIntScalarLike | Quantity,
) -> BatchFloatScalar:
Expand Down

0 comments on commit b7d3d5e

Please sign in to comment.