Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c62537f
refactor: reorganize imports and refine type variable bounds in `alge…
inakleinbottle Mar 31, 2026
cbecbf0
refactor: enforce exception for undefined operations in `get_operatio…
inakleinbottle Mar 31, 2026
715a067
Defer type in __all_operations dict
inakleinbottle Mar 31, 2026
8e32c35
Defer type operation return
inakleinbottle Mar 31, 2026
bccf692
remove generic typevar for operations
inakleinbottle Mar 31, 2026
3bd0e0f
fix: correct unpacking of `bases` in `result_basis` call
inakleinbottle Mar 31, 2026
5272de6
fix global declaration of cache
inakleinbottle Mar 31, 2026
948d5a8
noqa on initialisation of unknown type
inakleinbottle Mar 31, 2026
1883906
fix: specify argument type as `jax.Array` in `convert_args_dtypes` me…
inakleinbottle Mar 31, 2026
bec6ca6
fix: explicitly cast default min degree arguments to `np.int32` in de…
inakleinbottle Mar 31, 2026
cf5a4a1
refactor: replace repeated `np.int32(0)` declarations with constant `…
inakleinbottle Mar 31, 2026
ec9fc54
annotate: add `# type: ignore` and `# noqa` comments to suppress lint…
inakleinbottle Mar 31, 2026
4d91301
fix: remove unused `TensorBasis` import in `strategies.py`
inakleinbottle Mar 31, 2026
97724f6
refactor: consolidate type annotations for tensor operations, replaci…
inakleinbottle Mar 31, 2026
5faaa58
fix: correct tensor construction by unpacking single-element `out_dat…
inakleinbottle Mar 31, 2026
b15b0e0
fix: add type ignores to suppress linter warnings and adjust function…
inakleinbottle Mar 31, 2026
bfe525a
refactor: simplify `from_jax_cotangent` signature and replace overloa…
inakleinbottle Mar 31, 2026
1fed042
fix: correct function calls to use `lie_to_tensor` and adjust keyword…
inakleinbottle Mar 31, 2026
0ca876e
fix: allow `width` and `depth` to accept both `np.int32` and `int` types
inakleinbottle Mar 31, 2026
7562798
refactor: refine type annotations in `dense_algebra.py` by replacing …
inakleinbottle Mar 31, 2026
0ada84e
refactor: update type hints in `ops.py` for improved clarity and cons…
inakleinbottle Mar 31, 2026
e684f5a
refactor: replace `np.int32(0)` with constant `INT32_ZERO` and standa…
inakleinbottle Mar 31, 2026
af6990e
refactor: update type hints in `lie_increment_stream.py` and remove u…
inakleinbottle Mar 31, 2026
2e60cc8
refactor: update type annotations in `piecewise_abelian_stream.py` an…
inakleinbottle Mar 31, 2026
7a24114
refactor: remove unnecessary type ignore in `strategies.py` to improv…
inakleinbottle Mar 31, 2026
9a6428f
refactor: remove `Hashable` from `Basis` and explicitly add `__hash__…
inakleinbottle Mar 31, 2026
93d1ee9
refactor: update type hints and imports in tests for improved clarity…
inakleinbottle Mar 31, 2026
6a15065
refactor: specify `LieBasis` and `TensorBasis` types for `_lie_basis`…
inakleinbottle Mar 31, 2026
0f52633
refactor: add type ignores for `_lie_basis` and `_group_basis` to sup…
inakleinbottle Mar 31, 2026
f47b0c6
refactor: update type annotation for `from_stream` method in `lie_inc…
inakleinbottle Mar 31, 2026
23ce1f7
refactor: reorder arguments in `lie_increment_stream.py` constructor …
inakleinbottle Mar 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
578 changes: 305 additions & 273 deletions roughpy_jax/algebra.py

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions roughpy_jax/bases.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing
from collections.abc import Hashable, Iterable
from collections.abc import Iterable
from typing import Literal, TypeVar

import jax
Expand All @@ -11,19 +11,20 @@


@typing.runtime_checkable
class Basis(typing.Protocol, Hashable):
class Basis(typing.Protocol):
"""
Structural protocol shared by basis objects used in ``roughpy_jax``.

Any object implementing this protocol provides the width, truncation depth,
and degree offsets needed to construct compatible tensor or Lie bases.
"""

width: np.int32
depth: np.int32
width: np.int32 | int
depth: np.int32 | int
degree_begin: DegreeBeginArray

def size(self) -> int: ...
def __hash__(self) -> int: ...
def __eq__(self, other: object) -> bool: ...


Expand Down
17 changes: 8 additions & 9 deletions roughpy_jax/dense_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from roughpy_jax.bases import BasisT, TensorBasis

AlgebraT = TypeVar("AlgebraT")
_T = TypeVar("_T")
AlgebraT = TypeVar("AlgebraT", bound="DenseAlgebra")


def get_batch_shape(operand) -> tuple[int, ...]:
Expand Down Expand Up @@ -116,7 +115,7 @@ def _pad_final_dim(data: jax.Array, size: int) -> jax.Array:


def _algebra_add(
a: AlgebraT, b: AlgebraT, *, impl: Callable[[jax.Array, ...], jax.Array]
a: AlgebraT, b: AlgebraT, *, impl: Callable[[jax.Array, jax.Array], jax.Array]
) -> AlgebraT:
"""
Apply a pointwise binary operation to two compatible dense algebra objects.
Expand Down Expand Up @@ -307,11 +306,11 @@ def tree_unflatten(cls, aux_data, children):

@classmethod
def zero(
cls: type[_T],
cls: type[AlgebraT],
basis: BasisT,
dtype: jax.typing.DTypeLike = jnp.dtype("float32"),
batch_dims: tuple[int, ...] = tuple(),
) -> _T:
) -> AlgebraT:
"""
Construct the additive identity in the given basis.

Expand Down Expand Up @@ -344,11 +343,11 @@ class DenseTensor(DenseAlgebra[TensorBasis]):

@classmethod
def identity(
cls: Type[_T],
cls: Type[AlgebraT],
basis: TensorBasis,
dtype: jax.typing.DTypeLike = jnp.dtype("float32"),
batch_dims: tuple[int, ...] = tuple(),
) -> _T:
) -> AlgebraT:
"""
Construct the multiplicative identity in a tensor basis.

Expand All @@ -372,7 +371,7 @@ def identity(
DenseTensor.DualVector = DenseTensor


def zero_like(algebra: _T, dtype: jax.typing.DTypeLike | None = None) -> _T:
def zero_like(algebra: AlgebraT, dtype: jax.typing.DTypeLike | None = None) -> AlgebraT:
"""
Construct a zero element with the same type, shape, and basis as ``algebra``.

Expand All @@ -389,7 +388,7 @@ class of the input while replacing all coefficients with zeros. An
return type(algebra)(data, algebra.basis)


def identity_like(tensor: _T, dtype: jax.typing.DTypeLike | None = None) -> _T:
def identity_like(tensor: AlgebraT, dtype: jax.typing.DTypeLike | None = None) -> AlgebraT:
"""
Creates an identity-like object based on the structure and type of the provided tensor. The resulting object retains
the basis of the input tensor but modifies its data to follow an identity pattern. The primary element indicating
Expand Down
Loading