Skip to content

Commit

Permalink
feat: distance modulus (#131)
Browse files Browse the repository at this point in the history
* feat: distance modulus

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Jul 8, 2024
1 parent 2977ff0 commit b4bb989
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 2 deletions.
71 changes: 69 additions & 2 deletions src/unxt/_quantity/distance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=import-error, no-member, unsubscriptable-object
# b/c it doesn't understand dataclass fields

__all__ = ["AbstractDistance", "Distance", "Parallax"]
__all__ = ["AbstractDistance", "Distance", "Parallax", "DistanceModulus"]

from abc import abstractmethod
from dataclasses import KW_ONLY
Expand Down Expand Up @@ -143,7 +143,61 @@ def parallax(self) -> "Parallax":
@property
def distance_modulus(self) -> Quantity:
"""The distance modulus."""
return self.distance.distance_modulus # TODO: shortcut
return self.distance.distance_modulus # TODO: specific shortcut


##############################################################################


class DistanceModulus(AbstractDistance):
"""Distance modulus quantity."""

def __check_init__(self) -> None:
"""Check the initialization."""
if self.unit != u.mag:
msg = "Distance modulus must have units of magnitude."
raise ValueError(msg)

@property
def distance(self) -> Distance:
"""The distance.
The distance is calculated as :math:`10^{(m / 5 + 1)}`.
Examples
--------
>>> from unxt import DistanceModulus
>>> DistanceModulus(10, "mag").distance
Distance(Array(1000., dtype=float32, ...), unit='pc')
"""
return Distance(10 ** (self.value / 5 + 1), "pc")

@property
def parallax(self) -> Parallax:
"""The parallax.
Examples
--------
>>> from unxt import DistanceModulus
>>> DistanceModulus(10, "mag").parallax.to("mas")
Parallax(Array(0.99999994, dtype=float32), unit='mas')
"""
return self.distance.parallax # TODO: specific shortcut

@property
def distance_modulus(self) -> "DistanceModulus":
"""The distance modulus.
Examples
--------
>>> from unxt import DistanceModulus
>>> DistanceModulus(10, "mag").distance_modulus
DistanceModulus(Array(10, dtype=int32, ...), unit='mag')
"""
return self


# ============================================================================
Expand All @@ -159,6 +213,19 @@ def constructor(
return cls(xp.asarray(d.value, dtype=dtype), d.unit)


@Distance.constructor._f.register # type: ignore[no-redef] # noqa: SLF001
def constructor(
cls: type[Distance],
value: DistanceModulus | Quantity["mag"],
/,
*,
dtype: Any = None,
) -> Distance:
"""Construct a `Distance` from a mag through the dist mod."""
d = 10 ** (value.to_units_value("mag") / 5 + 1)
return cls(xp.asarray(d, dtype=dtype), "pc")


@Parallax.constructor._f.register # type: ignore[no-redef] # noqa: SLF001
def constructor(
cls: type[Parallax], value: Distance | Quantity["length"], /, *, dtype: Any = None
Expand Down
33 changes: 33 additions & 0 deletions src/unxt/_quantity/register_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,21 @@ def _abs_p(x: AbstractQuantity) -> AbstractQuantity:
>>> abs(q)
UncheckedQuantity(Array(1, dtype=int32, ...), unit='m')
>>> from unxt import Distance
>>> d = Distance(-1, "m")
>>> xp.abs(d)
Distance(Array(1, dtype=int32, ...), unit='m')
>>> from unxt import Parallax
>>> p = Parallax(-1, "mas", check_negative=False)
>>> xp.abs(p)
Parallax(Array(1, dtype=int32, ...), unit='mas')
>>> from unxt import DistanceModulus
>>> dm = DistanceModulus(-1, "mag")
>>> xp.abs(dm)
DistanceModulus(Array(1, dtype=int32, weak_type=True), unit='mag')
"""
return replace(x, value=lax.abs(x.value))

Expand Down Expand Up @@ -159,6 +174,24 @@ def _add_p_aqaq(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity:
>>> q1 + q2
Quantity['length'](Array(1.5, dtype=float32, ...), unit='km')
>>> from unxt import Distance
>>> d1 = Distance(1.0, "km")
>>> d2 = Distance(500.0, "m")
>>> xp.add(d1, d2)
Distance(Array(1.5, dtype=float32, ...), unit='km')
>>> from unxt import Parallax
>>> p1 = Parallax(1.0, "mas")
>>> p2 = Parallax(500.0, "uas")
>>> xp.add(p1, p2)
Parallax(Array(1.5, dtype=float32, ...), unit='mas')
>>> from unxt import DistanceModulus
>>> dm1 = DistanceModulus(1.0, "mag")
>>> dm2 = DistanceModulus(500.0, "mag")
>>> xp.add(dm1, dm2)
DistanceModulus(Array(501., dtype=float32, ...), unit='mag')
"""
return replace(x, value=lax.add(x.value, y.to_units_value(x.unit)))

Expand Down

0 comments on commit b4bb989

Please sign in to comment.