From 3e3564d0613c5710a19aaaf74ce36b35a98282dd Mon Sep 17 00:00:00 2001 From: Sam Morley <41870650+inakleinbottle@users.noreply.github.com> Date: Mon, 16 Dec 2024 10:47:09 +0000 Subject: [PATCH] Fix infinite loop if support not specified (#187) * Add tests for signature equivariance of tree-like paths Introduces three new tests to validate the equivariance property of signatures for tree-like paths. Ensures consistency between original paths and their pruned representations using `esig_stream2sig`. These additions strengthen coverage for stream behavior in edge cases. * Adjust effective support handling in LieIncrementStream Replaced right_unbounded interval creation with direct initialization of `effective_support`. Added fallback to `effective_support` when `md.support` is not provided, ensuring proper restriction behavior. --- roughpy/src/streams/lie_increment_stream.cpp | 11 +++- tests/streams/test_lie_increment_path.py | 58 ++++++++++++++++++++ 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/roughpy/src/streams/lie_increment_stream.cpp b/roughpy/src/streams/lie_increment_stream.cpp index d3dfb2e1d..817e76d45 100644 --- a/roughpy/src/streams/lie_increment_stream.cpp +++ b/roughpy/src/streams/lie_increment_stream.cpp @@ -114,8 +114,7 @@ static py::object lie_increment_stream_from_increments(py::object data, py::kwar dimn_t num_increments = ks_stream.row_count(); - auto effective_support - = intervals::RealInterval::right_unbounded(0.0, md.interval_type); + intervals::RealInterval effective_support (0.0, 1.0, md.interval_type); if (kwargs.contains("indices")) { auto indices_arg = kwargs_pop(kwargs, "indices"); @@ -232,7 +231,13 @@ static py::object lie_increment_stream_from_increments(py::object data, py::kwar *md.resolution}, md.schema )); - if (md.support) { result.restrict_to(*md.support); } + + if (md.support) { + result.restrict_to(*md.support); + } + else { + result.restrict_to(effective_support); + } return py::reinterpret_steal( python::RPyStream_FromStream(std::move(result)) diff --git a/tests/streams/test_lie_increment_path.py b/tests/streams/test_lie_increment_path.py index 58ea1ae0a..efd7fbf28 100644 --- a/tests/streams/test_lie_increment_path.py +++ b/tests/streams/test_lie_increment_path.py @@ -411,3 +411,61 @@ def test_construct_sequence_of_lies(): lsig = stream.log_signature(rp.RealInterval(0, 1)) assert lsig == seq[0], f"{lsig} != {seq[0]}" + + +def esig_stream2sig(array, depth): + no_pts, width = array.shape + ctx = rp.get_context(width=width, depth=depth, coeffs=rp.DPReal) + + increments = np.diff(array, axis=0) + + stream = rp.LieIncrementStream.from_increments(increments, ctx=ctx) + + return stream.signature() + + +def test_equivariance_esig_treelike1(): + + tree_like = np.array([ + [0.0, 0], + [1, 0], + [1, 1], + [1, 0], + [2, 0], + [3, 1], + [2, 2], + [1, 1], + [2, 2], + [1, 3], + [2, 2], + [3, 3], + [2, 2], + [3, 1], + [4, 1], + [3, 1], + [2, 0], + [2, -1], + [2, 0], + [1, 0], + [1, -1], + [1, 0], + [0, 0], + ], dtype=np.float64) + + pruned = np.array([[0.0, 0.0]], dtype=np.float64) + + assert_array_almost_equal(esig_stream2sig(tree_like, 2), esig_stream2sig(pruned, 2)) + + +def test_equivariance_esig_treelike2(): + + tree_like = np.array([[0.0, 0], [1, 3], [0, 0], [1, 5], [2, 5], [1, 5], [0, 6], [1, 5], [0, 0]], dtype=np.float64) + pruned = np.array([[0., 0.]], dtype=np.float64) + + assert_array_almost_equal(esig_stream2sig(tree_like, 2), esig_stream2sig(pruned, 2)) + +def test_equivariance_esig_treelike3(): + tree_like = np.array([[0.0, 0], [1, 1], [3, 1], [2, 1], [1, 1], [2, 0]], dtype=np.float64) + pruned = np.array([[0., 0.], [1. ,1.], [2., 0]], dtype=np.float64) + + assert_array_almost_equal(esig_stream2sig(tree_like, 2), esig_stream2sig(pruned, 2)) \ No newline at end of file