Skip to content

Commit

Permalink
🎉 cleanup code
Browse files Browse the repository at this point in the history
  • Loading branch information
jvdd committed Jan 24, 2024
1 parent 3c6d9d9 commit 6916dbb
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 130 deletions.
3 changes: 3 additions & 0 deletions tests/test_algos_python_compliance.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
(MinMaxDownsampler(), MinMax_py()),
(M4Downsampler(), M4_py()),
(LTTBDownsampler(), LTTB_py()),
# Include NaN downsamplers
(NanMinMaxDownsampler(), NaNMinMax_py()),
(NaNM4Downsampler(), NaNM4_py()),
],
)
@pytest.mark.parametrize("n", [10_000, 10_032, 20_321, 23_489])
Expand Down
30 changes: 29 additions & 1 deletion tests/test_tsdownsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def test_error_invalid_args():
@pytest.mark.parametrize("downsampler", generate_rust_downsamplers())
def test_non_contiguous_array(downsampler: AbstractDownsampler):
"""Test non contiguous array."""
arr = np.random.randint(0, 100, size=10_000)
arr = np.random.randint(0, 100, size=10_000).astype(np.float32)
arr = arr[::2]
assert not arr.flags["C_CONTIGUOUS"]
with pytest.raises(ValueError) as e_msg:
Expand All @@ -361,3 +361,31 @@ def test_everynth_non_contiguous_array():
s_downsampled = downsampler.downsample(arr, n_out=100)
assert s_downsampled[0] == 0
assert s_downsampled[-1] == 4950


def test_nan_minmax_downsampler():
    """Check that NanMinMaxDownsampler selects only NaN positions.

    Every 5th sample of a random series is overwritten with NaN. After
    downsampling to 100 points, every selected index is expected to point at
    a NaN value.
    """
    series = np.random.randn(50_000)
    series[::5] = np.nan
    picked_idx = NanMinMaxDownsampler().downsample(series, n_out=100)
    assert np.isnan(series[picked_idx]).all()


def test_nan_m4_downsampler():
    """Check that NaNM4Downsampler reports NaN for the min and max slots.

    Every 5th sample of a random series is overwritten with NaN. The selected
    values are then inspected in strides of four: offsets 1 and 2 of each
    group must all be NaN.
    """
    series = np.random.randn(50_000)
    series[::5] = np.nan
    picked_idx = NaNM4Downsampler().downsample(series, n_out=100)
    picked_vals = series[picked_idx]
    min_slots = picked_vals[1::4]
    max_slots = picked_vals[2::4]
    assert np.isnan(min_slots).all()  # min is NaN
    assert np.isnan(max_slots).all()  # max is NaN


def test_nan_minmaxlttb_downsampler():
    """Check that NaNMinMaxLTTBDownsampler selects NaNs for all interior points.

    Every 5th sample of a random series is overwritten with NaN. After
    downsampling to 100 points, every selected value except the first and the
    last one must be NaN.
    """
    series = np.random.randn(50_000)
    series[::5] = np.nan
    picked_idx = NaNMinMaxLTTBDownsampler().downsample(series, n_out=100)
    # The first and last selected points are left out of the check.
    interior_vals = series[picked_idx][1:-1]
    assert np.isnan(interior_vals).all()
47 changes: 45 additions & 2 deletions tsdownsample/downsamplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@


class MinMaxDownsampler(AbstractRustDownsampler):
"""Downsampler that uses the MinMax algorithm. If the y data contains NaNs, these
are ignored (i.e. the NaNs are not taken into account when selecting data points).
For each bin, the indices of the minimum and maximum values are selected.
"""

@property
def rust_mod(self):
return _tsdownsample_rs.minmax
Expand All @@ -26,6 +32,12 @@ def _check_valid_n_out(n_out: int):


class NanMinMaxDownsampler(AbstractRustNaNDownsampler):
"""Downsampler that uses the MinMax algorithm. If the y data contains NaNs, the
indices of these NaNs are returned.
For each bin, the indices of the minimum and maximum values are selected.
"""

@property
def rust_mod(self):
return _tsdownsample_rs.minmax
Expand All @@ -38,6 +50,13 @@ def _check_valid_n_out(n_out: int):


class M4Downsampler(AbstractRustDownsampler):
"""Downsampler that uses the M4 algorithm. If the y data contains NaNs, these are
ignored (i.e. the NaNs are not taken into account when selecting data points).
For each bin, the indices of the first, last, minimum and maximum values are
selected.
"""

@property
def rust_mod(self):
return _tsdownsample_rs.m4
Expand All @@ -50,6 +69,13 @@ def _check_valid_n_out(n_out: int):


class NaNM4Downsampler(AbstractRustNaNDownsampler):
"""Downsampler that uses the M4 algorithm. If the y data contains NaNs, the indices
of these NaNs are returned.
For each bin, the indices of the first, last, minimum and maximum values are
selected.
"""

@property
def rust_mod(self):
return _tsdownsample_rs.m4
Expand All @@ -62,12 +88,21 @@ def _check_valid_n_out(n_out: int):


class LTTBDownsampler(AbstractRustDownsampler):
"""Downsampler that uses the LTTB (Largest-Triangle-Three-Buckets) algorithm.

``rust_mod`` exposes ``_tsdownsample_rs.lttb``; the base class presumably
dispatches the actual downsampling to it (base class not visible here).
"""

@property
def rust_mod(self):
# Submodule of the compiled extension that backs this downsampler.
return _tsdownsample_rs.lttb


class MinMaxLTTBDownsampler(AbstractRustDownsampler):
"""Downsampler that uses the MinMaxLTTB algorithm. If the y data contains NaNs,
these are ignored (i.e. the NaNs are not taken into account when selecting data
points).
MinMaxLTTB paper: https://arxiv.org/abs/2305.00332
"""

@property
def rust_mod(self):
return _tsdownsample_rs.minmaxlttb
Expand All @@ -82,23 +117,31 @@ def downsample(


class NaNMinMaxLTTBDownsampler(AbstractRustNaNDownsampler):
"""Downsampler that uses the MinMaxLTTB algorithm. If the y data contains NaNs, the
indices of these NaNs are returned.
MinMaxLTTB paper: https://arxiv.org/abs/2305.00332
"""

@property
def rust_mod(self):
return _tsdownsample_rs.minmaxlttb

def downsample(
self, *args, n_out: int, minmax_ratio: int = 30, n_threads: int = 1, **_
self, *args, n_out: int, minmax_ratio: int = 4, parallel: bool = False, **_
):
assert minmax_ratio > 0, "minmax_ratio must be greater than 0"
return super().downsample(
*args, n_out=n_out, n_threads=n_threads, ratio=minmax_ratio
*args, n_out=n_out, parallel=parallel, ratio=minmax_ratio
)


# ------------------ EveryNth Downsampler ------------------


class EveryNthDownsampler(AbstractDownsampler):
"""Downsampler that selects every nth data point"""

def __init__(self, **kwargs):
super().__init__(check_contiguous=False, **kwargs)

Expand Down
Loading

0 comments on commit 6916dbb

Please sign in to comment.