Skip to content

Commit

Permalink
feat: start jnp.linalg (#71)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Aug 5, 2024
1 parent 13b23f7 commit fd56ae3
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion src/quaxed/numpy/linalg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,34 @@
"""Quaxed :mod:`jax.numpy.linalg`."""

__all__: list[str] = []
__all__ = [ # noqa: F822
"det",
]

import sys
from collections.abc import Callable
from typing import Any

import jax.numpy as jnp
from quax import quaxify


def __dir__() -> list[str]:
return sorted(__all__)


def __getattr__(name: str) -> Callable[..., Any]: # TODO: better type hint
"""Get the object from the `jax.numpy` module."""
if name not in __all__:
msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)

# Get the object
jnp_obj = getattr(jnp.linalg, name)

# Quaxify?
out = quaxify(jnp_obj)

# Cache the function in this module
setattr(sys.modules[__name__], name, out)

return out

0 comments on commit fd56ae3

Please sign in to comment.