diff --git a/roughpy_jax/streams/piecewise_abelian_stream.py b/roughpy_jax/streams/piecewise_abelian_stream.py index 4ebf5b6..bb0bb49 100644 --- a/roughpy_jax/streams/piecewise_abelian_stream.py +++ b/roughpy_jax/streams/piecewise_abelian_stream.py @@ -14,13 +14,17 @@ to_log_signature, to_signature, ) -from roughpy_jax.bases import LieBasis, TensorBasis, Basis +from roughpy_jax.bases import Basis, LieBasis, TensorBasis from roughpy_jax.intervals import Interval, Partition, RealInterval, intersection from .concepts import Stream -@partial(jax.tree_util.register_dataclass, data_fields=["_data", "_partition"],meta_fields=["_lie_basis", "_group_basis"],) +@partial( + jax.tree_util.register_dataclass, + data_fields=["_data", "_partition"], + meta_fields=["_lie_basis", "_group_basis"], +) @dataclass(frozen=True) class PiecewiseAbelianStream(Stream[DenseLie, DenseFreeTensor]): """A stream representing a piecewise abelian path.""" @@ -85,7 +89,9 @@ def log_signature(self, interval: Interval) -> DenseLie: "or single-element endpoint arrays" ) if inf.shape or sup.shape: - interval = RealInterval(inf.reshape(()), sup.reshape(()), interval.interval_type) + interval = RealInterval( + inf.reshape(()), sup.reshape(()), interval.interval_type + ) initial = FreeTensor.identity( self._group_basis, @@ -116,20 +122,16 @@ def get_piece(x_and_interval): ) intervals = self._partition.to_intervals() - all_tensors = [initial] + [ - get_piece((x, p)) for x, p in zip(self._data, intervals, strict=True) - ] # Stack all tensors along a leading axis into a single batched FreeTensor. - batched = jax.tree.map(lambda *arrs: jnp.stack(arrs), *all_tensors) + pieces = [get_piece((x, p)) for x, p in zip(self._data, intervals, strict=True)] + batched = jax.tree.map(lambda *arrs: jnp.stack(arrs), *pieces) - result_batched = lax.associative_scan( - lambda a, b: ft_fmexp(a, b, self._group_basis), - batched, - ) + def combine(carry, piece): + updated = ft_fmexp(carry, piece, self._group_basis) + return updated, None - # Take the last prefix (the full product over all selected pieces). - result = jax.tree.map(lambda x: x[-1], result_batched) + result, _ = lax.scan(combine, initial, batched) return to_log_signature(result) @jax.jit diff --git a/roughpy_jax/tests/streams/test_piecewise_abelian_stream.py b/roughpy_jax/tests/streams/test_piecewise_abelian_stream.py index dd87f56..15bb101 100644 --- a/roughpy_jax/tests/streams/test_piecewise_abelian_stream.py +++ b/roughpy_jax/tests/streams/test_piecewise_abelian_stream.py @@ -2,6 +2,7 @@ import jax.numpy as jnp import pytest import roughpy_jax as rpj +from roughpy_jax.algebra import FreeTensor, ft_fmexp, lie_to_tensor, to_log_signature from roughpy_jax.intervals import IntervalType, Partition, RealInterval from roughpy_jax.streams import PiecewiseAbelianStream @@ -119,6 +120,38 @@ def test_log_signature_cbh(self, pas_data): assert jnp.allclose(log_sig.data, expected_log_sig.data, atol=1e-6) + def test_log_signature_multi_piece_stream(self): + lie_basis = rpj.LieBasis(26, 3) + tensor_basis = rpj.to_tensor_basis(lie_basis) + indices = (0, 1, 4, 19) + + def make_lie(index): + data = jnp.zeros((lie_basis.size(),), dtype=jnp.float32) + data = data.at[index].set(1.0) + return rpj.Lie(data, lie_basis) + + lies = tuple(make_lie(index) for index in indices) + stream = PiecewiseAbelianStream( + _data=lies, + _partition=Partition([0.0, 1.0, 2.0, 3.0, 4.0], IntervalType.ClOpen), + _lie_basis=lie_basis, + _group_basis=tensor_basis, + ) + query_interval = RealInterval(0.0, 4.0, IntervalType.ClOpen) + + actual_log_sig = stream.log_signature(query_interval) + + expected_signature = FreeTensor.identity(tensor_basis, dtype=jnp.float32) + for lie in lies: + expected_signature = ft_fmexp( + expected_signature, + lie_to_tensor(lie), + tensor_basis, + ) + expected_log_sig = to_log_signature(expected_signature) + + assert jnp.allclose(actual_log_sig.data, expected_log_sig.data, atol=1e-6) + def test_stream_metadata(self, pas_data): """Test that the stream exposes dtype and batch metadata.""" assert pas_data.stream.dtype == pas_data.dtype