Replace poses fixture in test_filtering
lochhh committed Jan 20, 2025
1 parent 1d6a729 commit 4055d26
Showing 1 changed file with 26 additions and 79 deletions.
tests/test_unit/test_filtering.py (105 changes: 26 additions & 79 deletions)
@@ -12,7 +12,7 @@
 
 # Dataset fixtures
 list_valid_datasets_without_nans = [
-    "valid_poses_dataset",
+    "valid_poses_dataset_uniform_linear_motion",
     "valid_bboxes_dataset",
 ]
 list_valid_datasets_with_nans = [
@@ -28,7 +28,8 @@
     list_valid_datasets_with_nans,
 )
 @pytest.mark.parametrize(
-    "max_gap, expected_n_nans_in_position", [(None, 0), (0, 3), (1, 2), (2, 0)]
+    "max_gap, expected_n_nans_in_position",
+    [(None, [20, 0]), (0, [26, 6]), (1, [24, 4]), (2, [20, 0])],
 )
 def test_interpolate_over_time_on_position(
     valid_dataset_with_nan,
@@ -42,7 +43,6 @@ def test_interpolate_over_time_on_position(
     for different values of ``max_gap``.
     """
     valid_dataset_in_frames = request.getfixturevalue(valid_dataset_with_nan)
-
     # Get position array with time unit in frames & seconds
     # assuming 10 fps = 0.1 s per frame
     valid_dataset_in_seconds = valid_dataset_in_frames.copy()
@@ -53,9 +53,7 @@ def test_interpolate_over_time_on_position(
         "frames": valid_dataset_in_frames.position,
         "seconds": valid_dataset_in_seconds.position,
     }
-
-    # Count number of NaNs before and after interpolating position
-    n_nans_before = helpers.count_nans(position["frames"])
+    # Count number of NaNs
     n_nans_after_per_time_unit = {}
     for time_unit in ["frames", "seconds"]:
         # interpolate
@@ -66,39 +64,24 @@ def test_interpolate_over_time_on_position(
         n_nans_after_per_time_unit[time_unit] = helpers.count_nans(
             position_interp
         )
-
     # The number of NaNs should be the same for both datasets
     # as max_gap is based on number of missing observations (NaNs)
     assert (
         n_nans_after_per_time_unit["frames"]
         == n_nans_after_per_time_unit["seconds"]
     )
 
-    # The number of NaNs should decrease after interpolation
-    n_nans_after = n_nans_after_per_time_unit["frames"]
-    if max_gap == 0:
-        assert n_nans_after == n_nans_before
-    else:
-        assert n_nans_after < n_nans_before
-
-    # The number of NaNs after interpolating should be as expected
-    assert n_nans_after == (
-        valid_dataset_in_frames.sizes["space"]
-        * valid_dataset_in_frames.sizes.get("keypoints", 1)
-        # in bboxes dataset there is no keypoints dimension
-        * expected_n_nans_in_position
-    )
+    n_nans_after = n_nans_after_per_time_unit["frames"]
+    dataset_index = list_valid_datasets_with_nans.index(valid_dataset_with_nan)
+    assert n_nans_after == expected_n_nans_in_position[dataset_index]
 
 
 @pytest.mark.parametrize(
-    "valid_dataset_no_nans, n_low_confidence_kpts",
-    [
-        ("valid_poses_dataset", 20),
-        ("valid_bboxes_dataset", 5),
-    ],
+    "valid_dataset_no_nans",
+    list_valid_datasets_without_nans,
 )
 def test_filter_by_confidence_on_position(
-    valid_dataset_no_nans, n_low_confidence_kpts, helpers, request
+    valid_dataset_no_nans, helpers, request
 ):
     """Test that points below the default 0.6 confidence threshold
     are converted to NaN.
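Note on the new parametrization: expected_n_nans_in_position is now a two-element list indexed by each fixture's position in list_valid_datasets_with_nans, so every dataset carries its own expected NaN total. The max_gap behaviour the test exercises can be sketched in plain NumPy; interp_with_max_gap below is a hypothetical illustration, not movement's actual interpolate_over_time, which may treat time coordinates and array edges differently.

import numpy as np


def interp_with_max_gap(y, max_gap):
    """Linearly fill NaN runs no longer than max_gap (None = no limit)."""
    y = y.astype(float).copy()
    isnan = np.isnan(y)
    # Indices where each run of consecutive NaNs starts and stops
    edges = np.flatnonzero(np.diff(np.r_[0, isnan.astype(int), 0]))
    for start, stop in zip(edges[::2], edges[1::2]):
        if max_gap is not None and (stop - start) > max_gap:
            continue  # gap exceeds max_gap: leave it as NaN
        if start == 0 or stop == len(y):
            continue  # no flanking values at the array edges
        # Linear fill between the two flanking valid samples
        y[start:stop] = np.interp(
            np.arange(start, stop), [start - 1, stop], [y[start - 1], y[stop]]
        )
    return y


y = np.array([0.0, np.nan, 2.0, np.nan, np.nan, 5.0])
print(interp_with_max_gap(y, max_gap=1))  # [0. 1. 2. nan nan 5.]
print(interp_with_max_gap(y, max_gap=2))  # [0. 1. 2. 3. 4. 5.]

That the (None, [20, 0]) and (2, [20, 0]) rows expect identical counts suggests no fillable gap in these fixtures spans more than two missing observations; the NaNs that remain are presumably ones interpolation cannot reach.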
@@ -110,15 +93,14 @@ def test_filter_by_confidence_on_position(
         confidence=valid_input_dataset.confidence,
         threshold=0.6,
     )
-
     # Count number of NaNs in the full array
     n_nans = helpers.count_nans(position_filtered)
-
     # expected number of nans for poses:
     # 5 timepoints * 2 individuals * 2 keypoints
     # Note: we count the number of nans in the array, so we multiply
     # the number of low confidence keypoints by the number of
     # space dimensions
+    n_low_confidence_kpts = 5
     assert isinstance(position_filtered, xr.DataArray)
     assert n_nans == valid_input_dataset.sizes["space"] * n_low_confidence_kpts
 
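Note: since both remaining fixtures now expect five low-confidence points, n_low_confidence_kpts becomes an inline constant rather than a parameter. The thresholding rule itself amounts to a masked where; the snippet below is an illustrative sketch with made-up data, not movement's filter_by_confidence, which also records a log entry on the output.

import numpy as np
import xarray as xr

position = xr.DataArray(
    np.arange(10, dtype=float).reshape(5, 2), dims=("time", "space")
)
confidence = xr.DataArray(np.array([0.9, 0.5, 0.8, 0.3, 0.7]), dims="time")

# Samples below the threshold become NaN in every space dimension, which
# is why the test multiplies the low-confidence count by sizes["space"].
position_filtered = position.where(confidence >= 0.6)
assert int(position_filtered.isnull().sum()) == 2 * position.sizes["space"]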
@@ -147,12 +129,9 @@ def test_filter_on_position(
     position_filtered = filter_func(
         valid_input_dataset.position, **filter_kwargs
     )
-
     del position_filtered.attrs["log"]
-
     # filtered array is an xr.DataArray
     assert isinstance(position_filtered, xr.DataArray)
-
     # filtered data should not be equal to the original data
     assert not position_filtered.equals(valid_input_dataset.position)
 
@@ -163,12 +142,12 @@
     ("valid_dataset, expected_nans_in_filtered_position_per_indiv"),
     [
         (
-            "valid_poses_dataset",
-            {0: 0, 1: 0},
-        ),  # filtering should not introduce nans if input has no nans
-        ("valid_bboxes_dataset", {0: 0, 1: 0}),
-        ("valid_poses_dataset_with_nan", {0: 7, 1: 0}),
-        ("valid_bboxes_dataset_with_nan", {0: 7, 1: 0}),
+            "valid_poses_dataset_uniform_linear_motion",
+            [0, 0],  # no nans in the input data
+        ),
+        ("valid_bboxes_dataset", [0, 0]),  # no nans in the input data
+        ("valid_poses_dataset_uniform_linear_motion_with_nan", [38, 0]),
+        ("valid_bboxes_dataset_with_nan", [14, 0]),
     ],
 )
 @pytest.mark.parametrize(
@@ -189,49 +168,22 @@ def test_filter_with_nans_on_position(
     """Test NaN behaviour of the selected filter. The median and SG filters
     should set all values to NaN if one element of the sliding window is NaN.
     """
-
-    def _assert_n_nans_in_position_per_individual(
-        valid_input_dataset,
-        position_filtered,
-        expected_nans_in_filt_position_per_indiv,
-    ):
-        # compute n nans in position after filtering per individual
-        n_nans_after_filtering_per_indiv = {
-            i: helpers.count_nans(position_filtered.isel(individuals=i))
-            for i in range(valid_input_dataset.sizes["individuals"])
-        }
-
-        # check number of nans per indiv is as expected
-        for i in range(valid_input_dataset.sizes["individuals"]):
-            assert n_nans_after_filtering_per_indiv[i] == (
-                expected_nans_in_filt_position_per_indiv[i]
-                * valid_input_dataset.sizes["space"]
-                * valid_input_dataset.sizes.get("keypoints", 1)
-            )
-
     # Filter position
     valid_input_dataset = request.getfixturevalue(valid_dataset)
     position_filtered = filter_func(
         valid_input_dataset.position, **filter_kwargs
     )
-
-    # check number of nans per indiv is as expected
-    _assert_n_nans_in_position_per_individual(
-        valid_input_dataset,
-        position_filtered,
-        expected_nans_in_filtered_position_per_indiv,
+    # Compute n nans in position after filtering per individual
+    n_nans_after_filtering_per_indiv = [
+        helpers.count_nans(position_filtered.isel(individuals=i))
+        for i in range(valid_input_dataset.sizes["individuals"])
+    ]
+    # Check number of nans per indiv is as expected
+    assert (
+        n_nans_after_filtering_per_indiv
+        == expected_nans_in_filtered_position_per_indiv
     )
-
-    # if input had nans,
-    # individual 1's position at exact timepoints 0, 1 and 5 is not nan
-    n_nans_input = helpers.count_nans(valid_input_dataset.position)
-    if n_nans_input != 0:
-        assert not (
-            position_filtered.isel(individuals=0, time=[0, 1, 5])
-            .isnull()
-            .any()
-        )
 
 
 @pytest.mark.parametrize(
     "valid_dataset_with_nan",
@@ -256,24 +208,19 @@ def test_filter_with_nans_on_position_varying_window(
     kwargs = {"window": window}
     if filter_func == savgol_filter:
         kwargs["polyorder"] = 2
-
     # Filter position
     valid_input_dataset = request.getfixturevalue(valid_dataset_with_nan)
     position_filtered = filter_func(
         valid_input_dataset.position,
         **kwargs,
     )
-
     # Count number of NaNs in the input and filtered position data
     n_total_nans_initial = helpers.count_nans(valid_input_dataset.position)
     n_consecutive_nans_initial = helpers.count_consecutive_nans(
         valid_input_dataset.position
     )
-
     n_total_nans_filtered = helpers.count_nans(position_filtered)
-
     max_nans_increase = (window - 1) * n_consecutive_nans_initial
-
     # Check that filtering does not reduce number of nans
     assert n_total_nans_filtered >= n_total_nans_initial
     # Check that the increase in nans is below the expected threshold
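Note: the threshold in the final assertion relies on each run of consecutive NaNs contaminating at most window - 1 extra samples under a sliding-window filter (this reading assumes helpers.count_consecutive_nans counts such runs). A toy rolling median with a hypothetical helper name shows the bound at work.

import numpy as np


def rolling_median(y, window):
    """Toy centred rolling median; one NaN poisons its whole window."""
    half = window // 2
    out = y.astype(float).copy()  # edge samples left untouched in this sketch
    for i in range(half, len(y) - half):
        out[i] = np.median(y[i - half : i + half + 1])  # NaN if window has a NaN
    return out


y = np.array([1.0, 2.0, np.nan, 4.0, 5.0, 6.0, 7.0])
window = 3
filtered = rolling_median(y, window)

n_nan_runs = 1  # y contains a single run of consecutive NaNs
max_nans_increase = (window - 1) * n_nan_runs
increase = np.isnan(filtered).sum() - np.isnan(y).sum()
assert increase <= max_nans_increase  # here: 2 <= 2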