Skip to content

Commit

Permalink
Fix infinite loop if support not specified (#187)
Browse files Browse the repository at this point in the history
* Add tests for signature equivariance of tree-like paths

Introduces three new tests to validate the equivariance property of signatures for tree-like paths. Ensures consistency between original paths and their pruned representations using `esig_stream2sig`. These additions strengthen coverage for stream behavior in edge cases.

* Adjust effective support handling in LieIncrementStream

Replaced right_unbounded interval creation with direct initialization of `effective_support`. Added fallback to `effective_support` when `md.support` is not provided, ensuring proper restriction behavior.
  • Loading branch information
inakleinbottle authored Dec 16, 2024
1 parent 64ac672 commit 3e3564d
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
11 changes: 8 additions & 3 deletions roughpy/src/streams/lie_increment_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ static py::object lie_increment_stream_from_increments(py::object data, py::kwar

dimn_t num_increments = ks_stream.row_count();

auto effective_support
= intervals::RealInterval::right_unbounded(0.0, md.interval_type);
intervals::RealInterval effective_support (0.0, 1.0, md.interval_type);

if (kwargs.contains("indices")) {
auto indices_arg = kwargs_pop(kwargs, "indices");
Expand Down Expand Up @@ -232,7 +231,13 @@ static py::object lie_increment_stream_from_increments(py::object data, py::kwar
*md.resolution},
md.schema
));
if (md.support) { result.restrict_to(*md.support); }

if (md.support) {
result.restrict_to(*md.support);
}
else {
result.restrict_to(effective_support);
}

return py::reinterpret_steal<py::object>(
python::RPyStream_FromStream(std::move(result))
Expand Down
58 changes: 58 additions & 0 deletions tests/streams/test_lie_increment_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,61 @@ def test_construct_sequence_of_lies():
lsig = stream.log_signature(rp.RealInterval(0, 1))

assert lsig == seq[0], f"{lsig} != {seq[0]}"


def esig_stream2sig(array, depth):
no_pts, width = array.shape
ctx = rp.get_context(width=width, depth=depth, coeffs=rp.DPReal)

increments = np.diff(array, axis=0)

stream = rp.LieIncrementStream.from_increments(increments, ctx=ctx)

return stream.signature()


def test_equivariance_esig_treelike1():

tree_like = np.array([
[0.0, 0],
[1, 0],
[1, 1],
[1, 0],
[2, 0],
[3, 1],
[2, 2],
[1, 1],
[2, 2],
[1, 3],
[2, 2],
[3, 3],
[2, 2],
[3, 1],
[4, 1],
[3, 1],
[2, 0],
[2, -1],
[2, 0],
[1, 0],
[1, -1],
[1, 0],
[0, 0],
], dtype=np.float64)

pruned = np.array([[0.0, 0.0]], dtype=np.float64)

assert_array_almost_equal(esig_stream2sig(tree_like, 2), esig_stream2sig(pruned, 2))


def test_equivariance_esig_treelike2():

tree_like = np.array([[0.0, 0], [1, 3], [0, 0], [1, 5], [2, 5], [1, 5], [0, 6], [1, 5], [0, 0]], dtype=np.float64)
pruned = np.array([[0., 0.]], dtype=np.float64)

assert_array_almost_equal(esig_stream2sig(tree_like, 2), esig_stream2sig(pruned, 2))

def test_equivariance_esig_treelike3():
tree_like = np.array([[0.0, 0], [1, 1], [3, 1], [2, 1], [1, 1], [2, 0]], dtype=np.float64)
pruned = np.array([[0., 0.], [1. ,1.], [2., 0]], dtype=np.float64)

assert_array_almost_equal(esig_stream2sig(tree_like, 2), esig_stream2sig(pruned, 2))

0 comments on commit 3e3564d

Please sign in to comment.