Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions roughpy_jax/streams/piecewise_abelian_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@
to_log_signature,
to_signature,
)
from roughpy_jax.bases import LieBasis, TensorBasis, Basis
from roughpy_jax.bases import Basis, LieBasis, TensorBasis
from roughpy_jax.intervals import Interval, Partition, RealInterval, intersection

from .concepts import Stream


@partial(jax.tree_util.register_dataclass, data_fields=["_data", "_partition"],meta_fields=["_lie_basis", "_group_basis"],)
@partial(
jax.tree_util.register_dataclass,
data_fields=["_data", "_partition"],
meta_fields=["_lie_basis", "_group_basis"],
)
@dataclass(frozen=True)
class PiecewiseAbelianStream(Stream[DenseLie, DenseFreeTensor]):
"""A stream representing a piecewise abelian path."""
Expand Down Expand Up @@ -85,7 +89,9 @@ def log_signature(self, interval: Interval) -> DenseLie:
"or single-element endpoint arrays"
)
if inf.shape or sup.shape:
interval = RealInterval(inf.reshape(()), sup.reshape(()), interval.interval_type)
interval = RealInterval(
inf.reshape(()), sup.reshape(()), interval.interval_type
)

initial = FreeTensor.identity(
self._group_basis,
Expand Down Expand Up @@ -116,20 +122,16 @@ def get_piece(x_and_interval):
)

intervals = self._partition.to_intervals()
all_tensors = [initial] + [
get_piece((x, p)) for x, p in zip(self._data, intervals, strict=True)
]

# Stack all tensors along a leading axis into a single batched FreeTensor.
batched = jax.tree.map(lambda *arrs: jnp.stack(arrs), *all_tensors)
pieces = [get_piece((x, p)) for x, p in zip(self._data, intervals, strict=True)]
batched = jax.tree.map(lambda *arrs: jnp.stack(arrs), *pieces)

result_batched = lax.associative_scan(
lambda a, b: ft_fmexp(a, b, self._group_basis),
batched,
)
def combine(carry, piece):
updated = ft_fmexp(carry, piece, self._group_basis)
return updated, None

# Take the last prefix (the full product over all selected pieces).
result = jax.tree.map(lambda x: x[-1], result_batched)
result, _ = lax.scan(combine, initial, batched)
return to_log_signature(result)

@jax.jit
Expand Down
33 changes: 33 additions & 0 deletions roughpy_jax/tests/streams/test_piecewise_abelian_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import jax.numpy as jnp
import pytest
import roughpy_jax as rpj
from roughpy_jax.algebra import FreeTensor, ft_fmexp, lie_to_tensor, to_log_signature
from roughpy_jax.intervals import IntervalType, Partition, RealInterval
from roughpy_jax.streams import PiecewiseAbelianStream

Expand Down Expand Up @@ -119,6 +120,38 @@ def test_log_signature_cbh(self, pas_data):

assert jnp.allclose(log_sig.data, expected_log_sig.data, atol=1e-6)

def test_log_signature_multi_piece_stream(self):
    """The log signature over the whole domain matches the ordered product of pieces.

    Builds a four-piece abelian stream where each piece is a one-hot Lie basis
    element, then checks that querying the full interval [0, 4) agrees with
    composing the pieces one at a time via ``ft_fmexp`` and taking the log.
    """
    lie_basis = rpj.LieBasis(26, 3)
    tensor_basis = rpj.to_tensor_basis(lie_basis)

    def basis_lie(hot_index):
        # One-hot Lie element: 1.0 at hot_index, zeros elsewhere.
        coeffs = jnp.zeros((lie_basis.size(),), dtype=jnp.float32)
        coeffs = coeffs.at[hot_index].set(1.0)
        return rpj.Lie(coeffs, lie_basis)

    piece_lies = tuple(basis_lie(i) for i in (0, 1, 4, 19))
    stream = PiecewiseAbelianStream(
        _data=piece_lies,
        _partition=Partition([0.0, 1.0, 2.0, 3.0, 4.0], IntervalType.ClOpen),
        _lie_basis=lie_basis,
        _group_basis=tensor_basis,
    )

    actual_log_sig = stream.log_signature(
        RealInterval(0.0, 4.0, IntervalType.ClOpen)
    )

    # Reference value: fold each piece into the running signature in order,
    # starting from the group identity, then map back to a log signature.
    signature = FreeTensor.identity(tensor_basis, dtype=jnp.float32)
    for piece in piece_lies:
        signature = ft_fmexp(signature, lie_to_tensor(piece), tensor_basis)
    expected_log_sig = to_log_signature(signature)

    assert jnp.allclose(actual_log_sig.data, expected_log_sig.data, atol=1e-6)

def test_stream_metadata(self, pas_data):
"""Test that the stream exposes dtype and batch metadata."""
assert pas_data.stream.dtype == pas_data.dtype
Expand Down