Skip to content

Commit

Permalink
feat: quaxify (#140)
Browse files Browse the repository at this point in the history
* feat: quaxify
* feat: position mul
* docs: add examples
* docs: add

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Aug 7, 2024
1 parent bcf9d0a commit fd1f51f
Show file tree
Hide file tree
Showing 11 changed files with 544 additions and 129 deletions.
47 changes: 34 additions & 13 deletions src/coordinax/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
from dataclasses import fields
from enum import Enum
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Literal, TypeVar
from typing_extensions import Never
from typing import TYPE_CHECKING, Any, Literal, NoReturn, TypeVar

import astropy.units as u
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jax import Device
from plum import dispatch
from quax import ArrayValue

import quaxed.array_api as xp
from dataclassish import field_items, field_values, replace
Expand All @@ -45,7 +46,7 @@ class ToUnitsOptions(Enum):
# ===================================================================


class AbstractVector(eqx.Module): # type: ignore[misc]
class AbstractVector(ArrayValue): # type: ignore[misc]
"""Base class for all vector types.
A vector is a collection of components that can be represented in different
Expand Down Expand Up @@ -143,6 +144,18 @@ def constructor(cls: "type[AbstractVector]", obj: Quantity, /) -> "AbstractVecto
comps = {f.name: obj[..., i] for i, f in enumerate(fields(cls))}
return cls(**comps)

# ===============================================================
# Quax

def materialise(self) -> None:
msg = "Refusing to materialise `Quantity`."
raise RuntimeError(msg)

@abstractmethod
def aval(self) -> jax.core.ShapedArray:
"""Return the vector as a JAX array."""
raise NotImplementedError # pragma: no cover

# ===============================================================
# Array API

Expand Down Expand Up @@ -357,7 +370,7 @@ def __neg__(self) -> "Self":
def __rmul__(self: "AbstractVector", other: Any) -> Any:
return NotImplemented

def __setitem__(self, k: Any, v: Any) -> Never:
def __setitem__(self, k: Any, v: Any) -> NoReturn:
msg = f"{type(self).__name__} is immutable."
raise TypeError(msg)

Expand Down Expand Up @@ -518,17 +531,13 @@ def components(cls) -> tuple[str, ...]:
Examples
--------
We assume the following imports:
>>> from coordinax import CartesianPosition2D, SphericalPosition, RadialVelocity
We can get the components of a vector:
>>> import coordinax as cx
>>> CartesianPosition2D.components
>>> cx.CartesianPosition2D.components
('x', 'y')
>>> SphericalPosition.components
>>> cx.SphericalPosition.components
('r', 'theta', 'phi')
>>> RadialVelocity.components
>>> cx.RadialVelocity.components
('d_r',)
"""
Expand Down Expand Up @@ -556,7 +565,19 @@ def shapes(self) -> MappingProxyType[str, tuple[int, ...]]:

@property
def sizes(self) -> MappingProxyType[str, int]:
"""Get the sizes of the vector's components."""
"""Get the sizes of the vector's components.
Examples
--------
>>> import coordinax as cx
>>> cx.CartesianPosition2D.constructor(Quantity([1, 2], "m")).sizes
mappingproxy({'x': 1, 'y': 1})
>>> cx.CartesianPosition2D.constructor(Quantity([[1, 2], [1, 2]], "m")).sizes
mappingproxy({'x': 2, 'y': 2})
"""
return MappingProxyType({k: v.size for k, v in field_items(self)})

# ===============================================================
Expand Down
7 changes: 7 additions & 0 deletions src/coordinax/_base_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def integral_cls(cls) -> type[AbstractVelocity]:
"""
raise NotImplementedError

# ===============================================================
# Quax

def aval(self) -> jax.core.ShapedArray:
"""Return the vector as a JAX array."""
raise NotImplementedError

# ===============================================================
# Unary operations

Expand Down
Loading

0 comments on commit fd1f51f

Please sign in to comment.