From ce60e9bf85decb220ccd13cfcc0f1d0c13769d67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arne=20K=C3=BCderle?= Date: Wed, 27 Mar 2024 16:01:18 +0100 Subject: [PATCH] tooling update (bye black) --- .ruff.toml | 141 +++++++++++ _tasks.py | 22 +- docs/conf.py | 13 +- ...te_coordinate_definition_template_plots.py | 7 +- docs/sphinxext/githublink.py | 17 +- example_data/extract_example_data.py | 19 +- examples/advanced_features/algo_serialize.py | 1 + examples/advanced_features/caching.py | 1 - examples/advanced_features/multi_process.py | 1 + .../cross_validation.py | 5 +- .../datasets_and_pipelines/custom_dataset.py | 10 +- examples/datasets_and_pipelines/gridsearch.py | 4 +- .../datasets_and_pipelines/gridsearch_cv.py | 3 +- .../optimizable_pipelines.py | 4 +- .../event_detection/herzer_event_detection.py | 12 +- .../event_detection/rampp_event_detection.py | 21 +- examples/full_pipelines/mad_gait_pipeline.py | 13 +- .../ullrich_gait_sequence_detection.py | 4 +- .../generic_algorithms/base_dtw_generic.py | 6 +- .../automatic_sensor_alignment_details.py | 4 +- .../preprocessing/manual_sensor_alignment.py | 2 +- .../barth_dtw_custom_template.py | 4 +- .../barth_dtw_stride_segmentation.py | 2 +- .../barth_dtw_stride_segmentation_roi.py | 2 +- ...nstrained_barth_dtw_stride_segmentation.py | 2 +- .../roth_hmm_stride_segmentation.py | 3 +- .../segmentation_hmm_training.py | 1 + .../trajectory_reconstruction_region.py | 2 +- .../zupt_dependency.py | 5 +- .../_event_detection_mixin.py | 4 +- gaitmap/base.py | 4 +- gaitmap/data_transform/__init__.py | 1 + gaitmap/data_transform/_base.py | 9 +- gaitmap/data_transform/_feature_transform.py | 9 +- gaitmap/data_transform/_filter.py | 7 +- gaitmap/data_transform/_scaler.py | 15 +- gaitmap/evaluation_utils/event_detection.py | 39 ++-- gaitmap/evaluation_utils/parameter_errors.py | 4 +- gaitmap/evaluation_utils/scores.py | 33 ++- .../evaluation_utils/stride_segmentation.py | 69 +++--- gaitmap/event_detection/__init__.py | 1 + .../_herzer_event_detection.py | 7 +- gaitmap/gait_detection/__init__.py | 1 + gaitmap/parameters/__init__.py | 1 + gaitmap/parameters/_spatial_parameters.py | 10 +- gaitmap/parameters/_temporal_parameters.py | 5 +- gaitmap/preprocessing/__init__.py | 1 + .../sensor_alignment/_gravity_alignment.py | 1 + .../sensor_alignment/_pca_alignment.py | 6 +- .../_roi_stride_segmentation.py | 11 +- gaitmap/stride_segmentation/_utils.py | 1 + .../_region_level_trajectory.py | 16 +- .../_stride_level_trajectory.py | 9 +- .../_trajectory_wrapper.py | 5 +- .../orientation_methods/_madgwick.py | 3 +- .../_simple_gyro_integration.py | 3 +- .../_forward_backwards_integration.py | 2 +- .../trajectory_methods/_kalman_numba_funcs.py | 2 +- .../trajectory_methods/_rts_kalman.py | 20 +- gaitmap/utils/_algo_helper.py | 5 +- gaitmap/utils/_datatype_validation_helper.py | 36 ++- gaitmap/utils/_gaitmap_mad.py | 14 +- gaitmap/utils/_types.py | 1 + gaitmap/utils/array_handling.py | 24 +- gaitmap/utils/consts.py | 1 + gaitmap/utils/coordinate_conversion.py | 12 +- gaitmap/utils/datatype_helper.py | 80 +++---- gaitmap/utils/exceptions.py | 4 +- gaitmap/utils/fast_quaternion_math.py | 1 + gaitmap/utils/rotations.py | 9 +- gaitmap/utils/signal_processing.py | 6 +- gaitmap/utils/static_moment_detection.py | 7 +- gaitmap/utils/stride_list_conversion.py | 1 + gaitmap/utils/vector_math.py | 7 +- gaitmap/zupt_detection/__init__.py | 1 + .../zupt_detection/_combo_zupt_detector.py | 2 +- .../_moving_window_zupt_detector.py | 7 +- .../_stride_event_zupt_detector.py | 2 +- 
gaitmap_mad/gaitmap_mad/__init__.py | 1 + .../gaitmap_mad/event_detection/__init__.py | 1 + .../_filtered_rampp_event_detection.py | 3 +- .../event_detection/_rampp_event_detection.py | 5 +- .../_ullrich_gait_sequence_detection.py | 21 +- .../_forward_direction_alignment.py | 4 +- .../stride_segmentation/__init__.py | 1 + .../stride_segmentation/dtw/__init__.py | 1 + .../stride_segmentation/dtw/_barth_dtw.py | 9 +- .../stride_segmentation/dtw/_base_dtw.py | 35 +-- .../dtw/_constrained_barth_dtw.py | 3 +- .../dtw/_dtw_templates/templates.py | 12 +- .../dtw/_vendored_tslearn.py | 3 +- .../stride_segmentation/hmm/__init__.py | 1 + .../hmm/_hmm_feature_transform.py | 7 +- .../hmm/_hmm_stride_segmentation.py | 3 +- .../hmm/_segmentation_model.py | 3 +- .../stride_segmentation/hmm/_simple_model.py | 2 +- .../stride_segmentation/hmm/_utils.py | 20 +- ...piece_wise_linear_dedrifted_integration.py | 10 +- poetry.lock | 112 ++------- pyproject.toml | 156 ++----------- tests/_regression_utils.py | 9 +- tests/_test_gaitmap_mad_split.py | 17 +- tests/conftest.py | 8 +- tests/mixins/test_algorithm_mixin.py | 23 +- tests/mixins/test_caching_mixin.py | 9 +- tests/test_base.py | 27 +-- tests/test_data_transforms/test_base.py | 36 +-- .../test_feature_transformer.py | 26 +-- tests/test_data_transforms/test_filter.py | 8 +- tests/test_data_transforms/test_scalers.py | 28 +-- .../test_event_detection.py | 8 +- .../test_parameter_errors.py | 21 +- tests/test_evaluation_utlis/test_scores.py | 44 ++-- .../test_stride_segmentation.py | 86 +++---- .../test_event_detection_filtered_rampp.py | 8 +- .../test_event_detection_herzer.py | 28 +-- .../test_event_detection_rampp.py | 34 +-- tests/test_examples/test_all_examples.py | 56 ++--- .../test_ullrich_gait_sequence_detection.py | 56 ++--- .../test_spatial_parameters.py | 30 +-- .../test_temporal_parameter.py | 11 +- .../test_forward_direction_alignment.py | 14 +- .../test_preprocessing/test_pca_alignment.py | 20 +- .../test_sensor_alignment.py | 16 +- .../test_barth_dtw.py | 26 +-- .../test_stride_segmentation/test_base_dtw.py | 87 +++---- .../test_constrained_barth_dtw.py | 3 +- .../test_dtw_templates.py | 50 ++-- .../test_roi_stride_segmentation.py | 44 ++-- .../test_stride_segmentation/test_roth_hmm.py | 66 +++--- .../test_orientation_methods/test_madgwick.py | 2 +- .../test_ori_method_mixin.py | 10 +- ...piece_wise_linear_dedrifted_integration.py | 18 +- .../test_pos_method_mixin.py | 14 +- .../test_region_level_trajectory.py | 44 ++-- .../test_stride_level_trajectory.py | 4 +- .../test_rts_kalman.py | 10 +- .../test_trajectory_method_mixin.py | 12 +- .../test_trajectory_wrapper.py | 32 +-- tests/test_utils/test_array_handling.py | 116 ++++----- .../test_utils/test_coordinate_conversion.py | 26 +-- tests/test_utils/test_datatype_helper.py | 220 +++++++++--------- tests/test_utils/test_fast_quaternion_math.py | 10 +- tests/test_utils/test_rotations.py | 108 ++++----- tests/test_utils/test_signal_processing.py | 2 +- .../test_static_moment_detection.py | 34 +-- .../test_utils/test_stride_list_conversion.py | 27 ++- tests/test_utils/test_vector_math.py | 18 +- .../test_combo_zupt_detector.py | 14 +- .../test_moving_window_zupt_detector.py | 55 +++-- .../test_stride_event_zupt_detector.py | 10 +- 151 files changed, 1453 insertions(+), 1495 deletions(-) create mode 100644 .ruff.toml diff --git a/.ruff.toml b/.ruff.toml new file mode 100644 index 00000000..00c568a7 --- /dev/null +++ b/.ruff.toml @@ -0,0 +1,141 @@ +line-length = 120 +target-version = "py38" + +[lint] 
+select = [ + # pyflakes + "F", + # pycodestyle + "E", + "W", + # mccabe + "C90", + # isort + "I", + # pydocstyle + "D", + # pyupgrade + "UP", + # pep8-naming + "N", + # flake8-blind-except + "BLE", + # flake8-2020 + "YTT", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-errmsg + "EM", + # flake8-implicit-str-concat + "ISC", + # flake8-pytest-style + "PT", + # flake8-return + "RET", + # flake8-simplify + "SIM", + # flake8-unused-arguments + "ARG", + # pandas-vet + "PD", + # pygrep-hooks + "PGH", + # flake8-bugbear + "B", + # flake8-quotes + "Q", + # pylint + "PL", + # flake8-pie + "PIE", + # flake8-type-checking + "TCH", + # tryceratops + "TRY", + # flake8-use-pathlib + "PTH", + "RUF", + # NumPy rules + "NPY", + # Implicit namespace packages + "INP", + # No relative imports + "TID252", + # f-strings over string concatenation + "FLY", + # Annotations + # No enforced annotations +# "ANN" + + +] + +ignore = [ + # controversial + "B006", + # controversial + "B008", + "B010", + # Magic constants + "PLR2004", + # Strings in error messages + "EM101", + "EM102", + "EM103", + # Exception strings + "TRY003", + # Variables before return + "RET504", + # Abstract raise into inner function + "TRY301", + # df as variable name + "PD901", + # melt over stack + "PD013", + # No Any annotations + "ANN401", + # Self annotation + "ANN101", + # Too many arguments + "PLR0913", + # Class attribute shadows builtin + "A003", + # No typing for `cls` + "ANN102", + # Ignore because of formatting + "ISC001", + # Use type-checking block + "TCH001", + "TCH002", + "TCH003", + # No stacklevel + "B028", + # Overwriting loop variable + "PLW2901" + +] + + +exclude = [ + "doc/sphinxext/*.py", + "doc/build/*.py", + "doc/temp/*.py", + ".eggs/*.py", + "example_data", +] + + +[lint.per-file-ignores] +# https://github.com/astral-sh/ruff/issues/8925 +"examples/**/*.py" = ["D400"] + + +[lint.pydocstyle] +convention = "numpy" + +[format] +docstring-code-format = true diff --git a/_tasks.py b/_tasks.py index c2500167..d06cad39 100644 --- a/_tasks.py +++ b/_tasks.py @@ -1,5 +1,5 @@ -import platform import re +import shutil import subprocess import sys from pathlib import Path @@ -9,18 +9,18 @@ HERE = Path(__file__).parent -def task_docs(): +def task_docs(clean=False, builder="html") -> None: """Build the html docs using Sphinx.""" # Delete Autogenerated files from previous run - # shutil.rmtree(str(HERE / "docs/modules/generated"), ignore_errors=True) + if clean: + shutil.rmtree(str(HERE / "docs/modules/generated"), ignore_errors=True) + shutil.rmtree(str(HERE / "docs/_build"), ignore_errors=True) + shutil.rmtree(str(HERE / "docs/auto_examples"), ignore_errors=True) - if platform.system() == "Windows": - subprocess.run([HERE / "docs/make.bat", "html"], shell=False, check=True) - else: - subprocess.run(["make", "-C", HERE / "docs", "html"], shell=False, check=True) + subprocess.run(f"sphinx-build -b {builder} -j auto -d docs/_build docs docs/_build/html", shell=True, check=True) -def update_version_strings(file_path, new_version): +def update_version_strings(file_path, new_version) -> None: # taken from: # https://stackoverflow.com/questions/57108712/replace-updated-version-strings-in-files-via-python version_regex = re.compile(r"(^_*?version_*?\s*=\s*\")(\d+\.\d+\.\d+-?\S*)\"", re.M) @@ -37,7 +37,7 @@ def update_version_strings(file_path, new_version): f.truncate() -def update_version(version): +def update_version(version) -> None: subprocess.run(["poetry", "version", version],
shell=False, check=True) new_version = ( subprocess.run(["poetry", "version"], shell=False, check=True, capture_output=True) @@ -51,11 +51,11 @@ def update_version(version): update_version_strings(HERE / "gaitmap_mad/gaitmap_mad/__init__.py", new_version) -def task_update_version(): +def task_update_version() -> None: update_version(sys.argv[1]) -def task_bump_all_dev(): +def task_bump_all_dev() -> None: """Bump all dev dependencies.""" pyproject = toml.load(HERE.joinpath("pyproject.toml")) try: diff --git a/docs/conf.py b/docs/conf.py index 071851fc..d7f5919e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,7 +16,7 @@ from datetime import datetime from inspect import getsourcefile from pathlib import Path -from typing import List +from typing import List, Optional import toml from sphinx_gallery.sorting import ExplicitOrder @@ -32,7 +32,7 @@ def replace_gitlab_links(base_url, text): regex = base_url + r"-/(merge_requests|issues|commit)/(\w+)" - def substitute(matchobj): + def substitute(matchobj) -> str: tokens = {"merge_requests": "!", "issues": "#"} if matchobj.group(1) == "commit": return f"[mad-gitlab: {matchobj.group(2)[:5]}]({matchobj.group(0)})" @@ -45,7 +45,7 @@ def substitute(matchobj): def convert_github_links(base_url, text): regex = base_url + r"(pull|issues|commit)/(\w+)" - def substitute(matchobj): + def substitute(matchobj) -> str: if matchobj.group(1) == "commit": return f"[{matchobj.group(2)[:5]}]({matchobj.group(0)})" return f"[#{matchobj.group(2)}]({matchobj.group(0)})" @@ -229,10 +229,11 @@ def get_nested_attr(obj, attr): ) -def skip_properties(app, what, name, obj, skip, options): +def skip_properties(app, what, name, obj, skip, options) -> Optional[bool]: """This removes all properties from the documentation as they are expected to be documented in the docstring.""" if isinstance(obj, property): return True + return None GAITMAP_MAD_TEST = """ @@ -244,7 +245,7 @@ def skip_properties(app, what, name, obj, skip, options): """ -def add_info_about_origin(app, what, name, obj, options, lines: List[str]): +def add_info_about_origin(app, what, name, obj, options, lines: List[str]) -> None: """Add a short info text to all algorithms that are only available via gaitmap_mad.""" if what != "class": return @@ -259,6 +260,6 @@ def add_info_about_origin(app, what, name, obj, options, lines: List[str]): lines.insert(2, l) -def setup(app): +def setup(app) -> None: app.connect("autodoc-skip-member", skip_properties) app.connect("autodoc-process-docstring", add_info_about_origin) diff --git a/docs/image_src/create_coordinate_definition_template_plots.py b/docs/image_src/create_coordinate_definition_template_plots.py index f17f1ba3..179816fd 100644 --- a/docs/image_src/create_coordinate_definition_template_plots.py +++ b/docs/image_src/create_coordinate_definition_template_plots.py @@ -32,8 +32,9 @@ colors = [docs_red, docs_green, docs_blue] + # helper to plot different coordinate frames -def plot_stride(data, column_names, sensor_id, stride_id, export_name): +def plot_stride(data, column_names, sensor_id, stride_id, export_name) -> None: fig, axs = plt.subplots(2, figsize=(7, 7)) start = dtw.stride_list_[sensor_id].iloc[stride_id].start end = dtw.stride_list_[sensor_id].iloc[stride_id].end @@ -61,12 +62,12 @@ def plot_stride(data, column_names, sensor_id, stride_id, export_name): fig.savefig(sensor_id + col[3:] + ".pdf", bbox_inches="tight") -#%% +# %% # Plot "Stride-Template" in Sensor Frame plot_stride(dataset_sf, SF_COLS, "left_sensor", 5, 
"left_sensor_sensor_frame_template.pdf") plot_stride(dataset_sf, SF_COLS, "right_sensor", 18, "right_sensor_sensor_frame_template.pdf") -#%% +# %% # Plot "Stride-Template" in Body Frame plot_stride(dataset_bf, BF_COLS, "left_sensor", 5, "left_sensor_body_frame_template.pdf") plot_stride(dataset_bf, BF_COLS, "right_sensor", 18, "right_sensor_body_frame_template.pdf") diff --git a/docs/sphinxext/githublink.py b/docs/sphinxext/githublink.py index 24a999eb..f7480159 100644 --- a/docs/sphinxext/githublink.py +++ b/docs/sphinxext/githublink.py @@ -30,12 +30,13 @@ def _linkcode_resolve(domain, info, package, url_fmt, revision): This is called by sphinx.ext.linkcode An example with a long-untouched module that everyone has - >>> _linkcode_resolve('py', {'module': 'tty', - ... 'fullname': 'setraw'}, - ... package='tty', - ... url_fmt='http://hg.python.org/cpython/file/' - ... '{revision}/Lib/{package}/{path}#L{lineno}', - ... revision='xxxx') + >>> _linkcode_resolve( + ... "py", + ... {"module": "tty", "fullname": "setraw"}, + ... package="tty", + ... url_fmt="http://hg.python.org/cpython/file/" "{revision}/Lib/{package}/{path}#L{lineno}", + ... revision="xxxx", + ... ) 'http://hg.python.org/cpython/file/xxxx/Lib/tty/tty.py#L18' """ if revision is None: @@ -85,6 +86,4 @@ def make_linkcode_resolve(package, url_fmt): '{path}#L{lineno}') """ revision = _get_git_revision() - return partial( - _linkcode_resolve, revision=revision, package=package, url_fmt=url_fmt - ) + return partial(_linkcode_resolve, revision=revision, package=package, url_fmt=url_fmt) diff --git a/example_data/extract_example_data.py b/example_data/extract_example_data.py index 44c4501b..a6be8ba5 100644 --- a/example_data/extract_example_data.py +++ b/example_data/extract_example_data.py @@ -73,10 +73,7 @@ def normalize(v: np.ndarray) -> np.ndarray: If a 2D array is provided, each row is considered a vector, which is normalized independently. 
""" v = np.array(v) - if len(v.shape) == 1: - ax = 0 - else: - ax = 1 + ax = 0 if len(v.shape) == 1 else 1 return (v.T / np.linalg.norm(v, axis=ax)).T @@ -130,8 +127,8 @@ def normalize(v: np.ndarray) -> np.ndarray: rotation_from_angle(np.array([0, 0, 1]), np.deg2rad(-90)) * rotation_from_angle(np.array([1, 0, 0]), np.deg2rad(-90)) ).inv() -rotations = dict(left_sensor=left_rot, right_sensor=right_rot) -test_df = test_df.rename(columns={"l_{}".format(sensor): "left_sensor", "r_{}".format(sensor): "right_sensor"}) +rotations = {"left_sensor": left_rot, "right_sensor": right_rot} +test_df = test_df.rename(columns={f"l_{sensor}": "left_sensor", f"r_{sensor}": "right_sensor"}) test_df.columns = test_df.columns.set_names(("sensor", "axis")) test_df.sort_index(axis=1).to_csv("./imu_sample_not_rotated.csv") @@ -147,7 +144,7 @@ def normalize(v: np.ndarray) -> np.ndarray: test_df.to_csv("./imu_sample.csv") # Example events -test_events = test_borders = pd.read_csv(get_subject_mocap_folder(subject) / "{}_steps.csv".format(test), index_col=0) +test_events = test_borders = pd.read_csv(get_subject_mocap_folder(subject) / f"{test}_steps.csv", index_col=0) test_events = test_events.rename(columns={"hs": "ic", "to": "tc", "ms": "min_vel"}) # convert to 204.8 Hz test_events[["ic", "tc", "min_vel"]] *= 204.8 / 100 @@ -172,8 +169,8 @@ def normalize(v: np.ndarray) -> np.ndarray: # Back to 100 Hz test_events[["start", "end"]] *= 100 / 204.8 -test_orientation = dict() -test_position = dict() +test_orientation = {} +test_position = {} for sensor, short in [("left_sensor", "L"), ("right_sensor", "R")]: normal_vectors = find_plane_from_points( test_mocap[f"{short}_FCC"], test_mocap[f"{short}_TOE"], test_mocap[f"{short}_FM5"] @@ -182,8 +179,8 @@ def normalize(v: np.ndarray) -> np.ndarray: sidewards = np.cross(normal_vectors, forward_vector, axis=1) rot_mat = np.hstack([forward_vector, sidewards, normal_vectors]).reshape((-1, 3, 3)) ori = pd.DataFrame(Rotation.from_matrix(rot_mat).inv().as_quat(), columns=["q_x", "q_y", "q_z", "q_w"]) - ori_per_stride = dict() - pos_per_stride = dict() + ori_per_stride = {} + pos_per_stride = {} for _, s in test_events[test_events["foot"] == sensor.split("_")[0]].iterrows(): ori_per_stride[s["s_id"]] = ori.iloc[int(s["start"]) : int(s["end"])].reset_index(drop=True) pos = test_mocap[short + "_FCC"].iloc[int(s["start"]) : int(s["end"])].reset_index(drop=True) diff --git a/examples/advanced_features/algo_serialize.py b/examples/advanced_features/algo_serialize.py index 04e55cc7..5ef9a2be 100644 --- a/examples/advanced_features/algo_serialize.py +++ b/examples/advanced_features/algo_serialize.py @@ -17,6 +17,7 @@ reproducibility. This means you should save the exact library version together with the json version of the used algorithms. """ + from pprint import pprint # %% diff --git a/examples/advanced_features/caching.py b/examples/advanced_features/caching.py index bfcadbd8..8e81b644 100644 --- a/examples/advanced_features/caching.py +++ b/examples/advanced_features/caching.py @@ -32,7 +32,6 @@ # ---------------- # We will simply copy the stride segmentation example to have some data to work with. 
from gaitmap.example_data import get_healthy_example_imu_data -from gaitmap.stride_segmentation import BarthOriginalTemplate from gaitmap.utils.coordinate_conversion import convert_to_fbf data = get_healthy_example_imu_data().iloc[:2000] diff --git a/examples/advanced_features/multi_process.py b/examples/advanced_features/multi_process.py index ba83c256..d9de6ddc 100644 --- a/examples/advanced_features/multi_process.py +++ b/examples/advanced_features/multi_process.py @@ -30,6 +30,7 @@ Other Python helpers to spawn multiple processes will of course work as well. """ + from pprint import pprint from typing import Any, Dict diff --git a/examples/datasets_and_pipelines/cross_validation.py b/examples/datasets_and_pipelines/cross_validation.py index d515e459..89384761 100644 --- a/examples/datasets_and_pipelines/cross_validation.py +++ b/examples/datasets_and_pipelines/cross_validation.py @@ -25,6 +25,7 @@ If you want to have more information on how the dataset and pipeline is built, head over to this example. Here we will just copy the code over. """ + import numpy as np import pandas as pd from tpcp import CloneFactory, Dataset, OptimizableParameter, OptimizablePipeline, Parameter @@ -70,7 +71,7 @@ class MyPipeline(OptimizablePipeline): cost_func_: np.ndarray # We need to wrap the template in a `CloneFactory` call here to prevent issues with mutable defaults! - def __init__(self, max_cost: float = 3, template: BaseDtwTemplate = CloneFactory(BarthOriginalTemplate())): + def __init__(self, max_cost: float = 3, template: BaseDtwTemplate = CloneFactory(BarthOriginalTemplate())) -> None: self.max_cost = max_cost self.template = template @@ -227,7 +228,7 @@ def score(pipeline: MyPipeline, datapoint: MyDataset): optimized_pipeline = result_df["optimizer"][0] optimized_pipeline -#%% +# %% optimized_pipeline.optimized_pipeline_.get_params() # %% diff --git a/examples/datasets_and_pipelines/custom_dataset.py b/examples/datasets_and_pipelines/custom_dataset.py index 5e3f2b52..c4dbe426 100644 --- a/examples/datasets_and_pipelines/custom_dataset.py +++ b/examples/datasets_and_pipelines/custom_dataset.py @@ -46,7 +46,7 @@ trials = list(product(("rec_1", "rec_2", "rec_3"), ("trial_1",))) trials.append(("rec_3", "trial_2")) -index = [(p, *t) for p, t in product(("p{}".format(i) for i in range(1, 6)), trials)] +index = [(p, *t) for p, t in product((f"p{i}" for i in range(1, 6)), trials)] index = pd.DataFrame(index, columns=["participant", "recording", "trial"]) index @@ -92,7 +92,7 @@ def create_index(self): # Note, that each row itself is a dataset again, but just with a single entry. for row in final_subset: print(row) - print("This row contains {} data-point".format(len(row)), end="\n\n") + print(f"This row contains {len(row)} data-point", end="\n\n") # %% # However, in many cases, we don't want to iterate over all rows, but rather iterate over groups of the datasets ( @@ -109,7 +109,7 @@ def create_index(self): # Note that the grouped_subset shows the new groupby columns as the index in the representation and the length of the # dataset is reported to be the number of groups. grouped_subset = final_subset.groupby(["participant", "recording"]) -print("The dataset contains {} groups.".format(len(grouped_subset))) +print(f"The dataset contains {len(grouped_subset)} groups.") grouped_subset # %% @@ -118,7 +118,7 @@ def create_index(self): # Grouping also changes the meaning of a "single datapoint". # Each group reports a shape of `(1,)` independent of the number of rows in each group. 
for group in grouped_subset: - print("This group has the shape {}".format(group.shape)) + print(f"This group has the shape {group.shape}") print(group, end="\n\n") # %% @@ -331,7 +331,7 @@ def __init__( *, groupby_cols: Optional[Union[List[str], str]] = None, subset_index: Optional[pd.DataFrame] = None, - ): + ) -> None: self.data_folder = data_folder self.custom_config_para = custom_config_para super().__init__(groupby_cols=groupby_cols, subset_index=subset_index) diff --git a/examples/datasets_and_pipelines/gridsearch.py b/examples/datasets_and_pipelines/gridsearch.py index 19e91b2e..c933926d 100644 --- a/examples/datasets_and_pipelines/gridsearch.py +++ b/examples/datasets_and_pipelines/gridsearch.py @@ -9,7 +9,7 @@ Hence, it makes sense to cross-check the official examples. """ -import joblib + import pandas as pd # %% @@ -90,7 +90,7 @@ class MyPipeline(Pipeline): segmented_stride_list_: SingleSensorStrideList - def __init__(self, max_cost: float = 3): + def __init__(self, max_cost: float = 3) -> None: self.max_cost = max_cost def run(self, datapoint: MyDataset): diff --git a/examples/datasets_and_pipelines/gridsearch_cv.py b/examples/datasets_and_pipelines/gridsearch_cv.py index 70c9a06d..24bd2a74 100644 --- a/examples/datasets_and_pipelines/gridsearch_cv.py +++ b/examples/datasets_and_pipelines/gridsearch_cv.py @@ -18,6 +18,7 @@ tuning hyperparameters `_. """ + import random from typing import Optional @@ -87,7 +88,7 @@ def __init__( # We need to wrap the template in a `CloneFactory` call here to prevent issues with mutable defaults! template: InterpolatedDtwTemplate = CloneFactory(InterpolatedDtwTemplate(scaling=TrainableAbsMaxScaler())), n_train_strides: Optional[int] = None, - ): + ) -> None: self.max_cost = max_cost self.template = template self.n_train_strides = n_train_strides diff --git a/examples/datasets_and_pipelines/optimizable_pipelines.py b/examples/datasets_and_pipelines/optimizable_pipelines.py index 209c20af..294ad4c7 100644 --- a/examples/datasets_and_pipelines/optimizable_pipelines.py +++ b/examples/datasets_and_pipelines/optimizable_pipelines.py @@ -103,7 +103,9 @@ class MyPipeline(OptimizablePipeline): cost_func_: np.ndarray # We need to wrap the template in a `CloneFactory` call here to prevent issues with mutable defaults! - def __init__(self, max_cost: float = 3, template: BaseDtwTemplate = CloneFactory(InterpolatedDtwTemplate())): + def __init__( + self, max_cost: float = 3, template: BaseDtwTemplate = CloneFactory(InterpolatedDtwTemplate()) + ) -> None: self.max_cost = max_cost self.template = template diff --git a/examples/event_detection/herzer_event_detection.py b/examples/event_detection/herzer_event_detection.py index 2598aced..97445f10 100644 --- a/examples/event_detection/herzer_event_detection.py +++ b/examples/event_detection/herzer_event_detection.py @@ -62,7 +62,7 @@ # stride list. # As we passed a dataset with two sensors, the output will be a dictionary. min_vel_events_left = ed.min_vel_event_list_["left_sensor"] -print("Gait events for {} min_vel strides were detected.".format(len(min_vel_events_left))) +print(f"Gait events for {len(min_vel_events_left)} min_vel strides were detected.") min_vel_events_left.head() # %% @@ -70,7 +70,7 @@ # `min_vel_event_list_`, but the start and the end of each stride are unchanged compared to the input. # This also means that no strides are removed due to the conversion step explained below. 
segmented_events_left = ed.segmented_event_list_["left_sensor"] -print("Gait events for {} segmented strides were detected.".format(len(segmented_events_left))) +print(f"Gait events for {len(segmented_events_left)} segmented strides were detected.") segmented_events_left.head() # %% @@ -105,7 +105,7 @@ for ax, data in zip(axs, axs_data): ax.plot(data) - for i, stride in ed.min_vel_event_list_["left_sensor"].iterrows(): + for _i, stride in ed.min_vel_event_list_["left_sensor"].iterrows(): ax.axvline(stride["start"], color="g") ax.axvline(stride["end"], color="r") @@ -141,7 +141,7 @@ pre_ic_idx = ed.min_vel_event_list_["left_sensor"]["pre_ic"].to_numpy().astype(int) for ax, sensor in zip([ax1, ax2], ["gyr_ml", "acc_pa"]): - for i, stride in ed.min_vel_event_list_["left_sensor"].iterrows(): + for _i, stride in ed.min_vel_event_list_["left_sensor"].iterrows(): ax.axvline(stride["start"], color="g") ax.axvline(stride["end"], color="r") @@ -217,7 +217,7 @@ sensor_axis = "gyr_ml" ax1.plot(bf_data.reset_index(drop=True)["left_sensor"][sensor_axis]) -for i, stride in segmented_stride_list.iterrows(): +for _i, stride in segmented_stride_list.iterrows(): ax1.axvline(stride["start"], color="g") ax1.axvline(stride["end"], color="r") ax1.axvspan(stride["start"], stride["end"], alpha=0.2) @@ -229,7 +229,7 @@ min_vel_idx = ed2.min_vel_event_list_["min_vel"].to_numpy().astype(int) pre_ic_idx = ed2.min_vel_event_list_["pre_ic"].to_numpy().astype(int) -for i, stride in ed2.min_vel_event_list_.iterrows(): +for _i, stride in ed2.min_vel_event_list_.iterrows(): ax2.axvline(stride["start"], color="g") ax2.axvline(stride["end"], color="r") ax2.axvspan(stride["start"], stride["end"], alpha=0.2) diff --git a/examples/event_detection/rampp_event_detection.py b/examples/event_detection/rampp_event_detection.py index 1fcb88de..0cd405ed 100644 --- a/examples/event_detection/rampp_event_detection.py +++ b/examples/event_detection/rampp_event_detection.py @@ -13,6 +13,7 @@ sensor-based stride parameter calculation from gait sequences in geriatric patients. IEEE transactions on biomedical engineering, 62(4), 1089-1097.. https://doi.org/10.1109/TBME.2014.2368211 """ + # %% # Getting some example data # ------------------------- @@ -79,7 +80,7 @@ # stride list. # As we passed a dataset with two sensors, the output will be a dictionary. min_vel_events_left = ed.min_vel_event_list_["left_sensor"] -print("Gait events for {} min_vel strides were detected.".format(len(min_vel_events_left))) +print(f"Gait events for {len(min_vel_events_left)} min_vel strides were detected.") min_vel_events_left.head() # %% @@ -87,7 +88,7 @@ # `min_vel_event_list_`, but the start and the end of each stride are unchanged compared to the input. # This also means that no strides are removed due to the conversion step explained below. 
segmented_events_left = ed.segmented_event_list_["left_sensor"] -print("Gait events for {} segmented strides were detected.".format(len(segmented_events_left))) +print(f"Gait events for {len(segmented_events_left)} segmented strides were detected.") segmented_events_left.head() # %% @@ -109,7 +110,7 @@ min_vel_idx = ed.min_vel_event_list_["left_sensor"]["min_vel"].to_numpy().astype(int) for ax, sensor in zip([ax1, ax2], ["gyr_ml", "acc_pa"]): - for i, stride in ed.min_vel_event_list_["left_sensor"].iterrows(): + for _i, stride in ed.min_vel_event_list_["left_sensor"].iterrows(): ax.axvline(stride["start"], color="g") ax.axvline(stride["end"], color="r") @@ -169,7 +170,7 @@ pre_ic_idx = ed.min_vel_event_list_["left_sensor"]["pre_ic"].to_numpy().astype(int) for ax, sensor in zip([ax1, ax2], ["gyr_ml", "acc_pa"]): - for i, stride in ed.min_vel_event_list_["left_sensor"].iterrows(): + for _i, stride in ed.min_vel_event_list_["left_sensor"].iterrows(): ax.axvline(stride["start"], color="g") ax.axvline(stride["end"], color="r") @@ -246,7 +247,7 @@ sensor_axis = "gyr_ml" ax1.plot(bf_data.reset_index(drop=True)["left_sensor"][sensor_axis]) -for i, stride in segmented_stride_list.iterrows(): +for _i, stride in segmented_stride_list.iterrows(): ax1.axvline(stride["start"], color="g") ax1.axvline(stride["end"], color="r") ax1.axvspan(stride["start"], stride["end"], alpha=0.2) @@ -258,7 +259,7 @@ min_vel_idx = ed2.min_vel_event_list_["min_vel"].to_numpy().astype(int) pre_ic_idx = ed2.min_vel_event_list_["pre_ic"].to_numpy().astype(int) -for i, stride in ed2.min_vel_event_list_.iterrows(): +for _i, stride in ed2.min_vel_event_list_.iterrows(): ax2.axvline(stride["start"], color="g") ax2.axvline(stride["end"], color="r") ax2.axvspan(stride["start"], stride["end"], alpha=0.2) @@ -331,12 +332,10 @@ edfilt = FilteredRamppEventDetection(ic_lowpass_filter=ButterworthFilter(10, 15)) edfilt = edfilt.detect(data=bf_data, stride_list=stride_list, sampling_rate_hz=sampling_rate_hz) min_vel_events_left = edfilt.min_vel_event_list_["left_sensor"] -print("Gait events for {} min_vel strides using the filtered version were detected.".format(len(min_vel_events_left))) +print(f"Gait events for {len(min_vel_events_left)} min_vel strides using the filtered version were detected.") min_vel_events_left.head() segmented_events_left = edfilt.segmented_event_list_["left_sensor"] -print( - "Gait events for {} segmented strides using the filtered version were detected.".format(len(segmented_events_left)) -) +print(f"Gait events for {len(segmented_events_left)} segmented strides using the filtered version were detected.") segmented_events_left.head() fig, (ax1, ax2) = plt.subplots(2, sharex=True, figsize=(10, 5)) ax1.plot(bf_data.reset_index(drop=True)["left_sensor"][["gyr_ml"]]) @@ -347,7 +346,7 @@ min_vel_idx = edfilt.min_vel_event_list_["left_sensor"]["min_vel"].to_numpy().astype(int) for ax, sensor in zip([ax1, ax2], ["gyr_ml", "acc_pa"]): - for i, stride in edfilt.min_vel_event_list_["left_sensor"].iterrows(): + for _i, stride in edfilt.min_vel_event_list_["left_sensor"].iterrows(): ax.axvline(stride["start"], color="g") ax.axvline(stride["end"], color="r") diff --git a/examples/full_pipelines/mad_gait_pipeline.py b/examples/full_pipelines/mad_gait_pipeline.py index d6dc7992..b2fc7afa 100644 --- a/examples/full_pipelines/mad_gait_pipeline.py +++ b/examples/full_pipelines/mad_gait_pipeline.py @@ -15,6 +15,7 @@ - :ref:`Temporal Parameters ` and :ref:`Spatial Parameters ` """ + # %% # Load example data # ----------------- @@ 
-46,7 +47,7 @@ np.array([0, 0, 1]), np.deg2rad(90) ) -rotations = dict(left_sensor=left_rot, right_sensor=right_rot) +rotations = {"left_sensor": left_rot, "right_sensor": right_rot} dataset_sf = flip_dataset(example_dataset, rotations) # Align to Gravity @@ -120,22 +121,20 @@ import matplotlib.pyplot as plt print( - "The following number of strides were identified and parameterized for each sensor: {}".format( - {k: len(v) for k, v in ed.min_vel_event_list_.items()} - ) + f"The following number of strides were identified and parameterized for each sensor: {({k: len(v) for k, v in ed.min_vel_event_list_.items()})}" ) # %% for k, v in temporal_paras.parameters_pretty_.items(): v.plot() - plt.title("All temporal parameters of sensor {}".format(k)) + plt.title(f"All temporal parameters of sensor {k}") # %% for k, v in spatial_paras.parameters_pretty_.items(): v[["stride length [m]", "gait velocity [m/s]", "arc length [m]"]].plot() - plt.title("All spatial parameters of sensor {}".format(k)) + plt.title(f"All spatial parameters of sensor {k}") # %% for k, v in spatial_paras.parameters_pretty_.items(): v.filter(like="angle").plot() - plt.title("All angle parameters of sensor {}".format(k)) + plt.title(f"All angle parameters of sensor {k}") diff --git a/examples/gait_detection/ullrich_gait_sequence_detection.py b/examples/gait_detection/ullrich_gait_sequence_detection.py index 7163fb9e..7ad678da 100644 --- a/examples/gait_detection/ullrich_gait_sequence_detection.py +++ b/examples/gait_detection/ullrich_gait_sequence_detection.py @@ -100,7 +100,7 @@ # `end` of all detected gait sequences. It furthermore has a column `gs_id` for the gait sequence id which is used in # further processing steps to assign for example single strides to their respective `gs_id`. 
gait_sequences = gsd.gait_sequences_ -print("{} gait sequences were detected.".format(len(gait_sequences))) +print(f"{len(gait_sequences)} gait sequences were detected.") gait_sequences.head() # %% @@ -115,7 +115,7 @@ start_idx = gait_sequences["start"].to_numpy().astype(int) end_idx = gait_sequences["end"].to_numpy().astype(int) -for i, gs in gait_sequences.iterrows(): +for _i, gs in gait_sequences.iterrows(): start_sample = int(gs["start"]) end_sample = int(gs["end"]) ax1.axvline(start_sample, color="g") diff --git a/examples/generic_algorithms/base_dtw_generic.py b/examples/generic_algorithms/base_dtw_generic.py index 50206797..d87ac24a 100644 --- a/examples/generic_algorithms/base_dtw_generic.py +++ b/examples/generic_algorithms/base_dtw_generic.py @@ -38,8 +38,8 @@ sz1 = len(long_sequence) sz2 = len(short_sequence) -print("Shape long sequence: {}".format(long_sequence.shape)) -print("Shape short sequence: {}".format(short_sequence.shape)) +print(f"Shape long sequence: {long_sequence.shape}") +print(f"Shape short sequence: {short_sequence.shape}") # %% # Plot the sequences @@ -80,7 +80,7 @@ # Afterwards a set of results are available on the dtw object dtw = dtw.segment(long_sequence, sampling_rate_hz=sampling_rate_hz) -print("{} matches were found".format(len(dtw.matches_start_end_))) +print(f"{len(dtw.matches_start_end_)} matches were found") print(dtw.matches_start_end_) # %% diff --git a/examples/preprocessing/automatic_sensor_alignment_details.py b/examples/preprocessing/automatic_sensor_alignment_details.py index 9b614666..b3ea9147 100644 --- a/examples/preprocessing/automatic_sensor_alignment_details.py +++ b/examples/preprocessing/automatic_sensor_alignment_details.py @@ -261,7 +261,7 @@ # We see that the PCA alignment method applies a pure heading correction. # Further the angle value, basically matches the rotation we applied in step 1 perfectly. rot_angles = np.rad2deg(pca_rotation.as_euler("xyz")) -print("X-rot: %.1f deg, Y-rot: %.1f deg, Z-rot: %.1f deg" % (rot_angles[0], rot_angles[1], rot_angles[2])) +print(f"X-rot: {rot_angles[0]:.1f} deg, Y-rot: {rot_angles[1]:.1f} deg, Z-rot: {rot_angles[2]:.1f} deg") # %% @@ -338,7 +338,7 @@ # Lets look at the rotation angles in degree. # The forward direction sign alignment applied the required 180deg flip to the data. rot_angles = np.rad2deg(fdsa_rotation.as_euler("xyz")) -print("X-rot: %.1f deg, Y-rot: %.1f deg, Z-rot: %.1f deg" % (rot_angles[0], rot_angles[1], rot_angles[2])) +print(f"X-rot: {rot_angles[0]:.1f} deg, Y-rot: {rot_angles[1]:.1f} deg, Z-rot: {rot_angles[2]:.1f} deg") # %% # As a final result of the automatic alignment pipeline all misalignment around all axis were subsequently fixed. diff --git a/examples/preprocessing/manual_sensor_alignment.py b/examples/preprocessing/manual_sensor_alignment.py index fa630bb8..7a4ec401 100644 --- a/examples/preprocessing/manual_sensor_alignment.py +++ b/examples/preprocessing/manual_sensor_alignment.py @@ -71,7 +71,7 @@ np.array([0, 0, 1]), np.deg2rad(90) ) -rotations = dict(left_sensor=left_rot, right_sensor=right_rot) +rotations = {"left_sensor": left_rot, "right_sensor": right_rot} # As all rotations are just "axis-flips" we can use flip_dataset to apply the rotations, which is much faster than using # `rotate_dataset`. 
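The hunks before and after this point repeat the same three mechanical modernizations that the newly selected ruff rules enforce: `str.format` calls become f-strings (pyupgrade, UP032), `dict(...)` calls become dict literals (flake8-comprehensions, C408), and loop variables that are never used gain a leading underscore (flake8-bugbear, B007). A minimal, runnable sketch of the pattern follows; it is illustrative only, not an additional hunk of this patch, and the DataFrame content is made up:

import pandas as pd

stride_list_left = pd.DataFrame([[10, 20], [21, 30]], columns=["start", "end"])

# Old style, as removed throughout this patch (commented out so the sketch runs):
# print("{} strides were detected.".format(len(stride_list_left)))  # UP032
# rotations = dict(left_sensor="left_rot", right_sensor="right_rot")  # C408
# for i, stride in stride_list_left.iterrows():  # B007: `i` is never used

# New style, as added throughout this patch:
print(f"{len(stride_list_left)} strides were detected.")
rotations = {"left_sensor": "left_rot", "right_sensor": "right_rot"}
for _i, stride in stride_list_left.iterrows():  # underscore marks the index as unused
    print(stride["start"], stride["end"])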
diff --git a/examples/stride_segmentation/barth_dtw_custom_template.py b/examples/stride_segmentation/barth_dtw_custom_template.py index 18c90a55..78b40789 100644 --- a/examples/stride_segmentation/barth_dtw_custom_template.py +++ b/examples/stride_segmentation/barth_dtw_custom_template.py @@ -45,7 +45,7 @@ ax.set_title(f"{foot} foot") convert_left_foot_to_fbf(template_data[f"{foot}_sensor"])[BF_GYR].plot(ax=ax) # Mark stride borders with vertical lines - for i, val in template_stride_borders[f"{foot}_sensor"].iterrows(): + for _i, val in template_stride_borders[f"{foot}_sensor"].iterrows(): ax.axvline(x=val["end"] / sampling_rate_hz, color="k") ax.axvline(x=val["start"] / sampling_rate_hz, color="k") @@ -182,7 +182,7 @@ ax.set_title(f"{foot} foot") bf_data[f"{foot}_sensor"][BF_GYR].plot(ax=ax) # Mark stride borders with vertical lines - for i, val in dtw.stride_list_[f"{foot}_sensor"].iterrows(): + for _i, val in dtw.stride_list_[f"{foot}_sensor"].iterrows(): ax.axvline(x=val["end"] / sampling_rate_hz, color="k") ax.axvline(x=val["start"] / sampling_rate_hz, color="k") diff --git a/examples/stride_segmentation/barth_dtw_stride_segmentation.py b/examples/stride_segmentation/barth_dtw_stride_segmentation.py index 5e9490d5..0b1ff4a0 100644 --- a/examples/stride_segmentation/barth_dtw_stride_segmentation.py +++ b/examples/stride_segmentation/barth_dtw_stride_segmentation.py @@ -79,7 +79,7 @@ # The main output is the `stride_list_`, which contains the start and the end of all identified strides. # As we passed a dataset with two sensors, the output will be a dictionary. stride_list_left = dtw.stride_list_["left_sensor"] -print("{} strides were detected.".format(len(stride_list_left))) +print(f"{len(stride_list_left)} strides were detected.") stride_list_left.head() # %% diff --git a/examples/stride_segmentation/barth_dtw_stride_segmentation_roi.py b/examples/stride_segmentation/barth_dtw_stride_segmentation_roi.py index e32affeb..81893ac0 100644 --- a/examples/stride_segmentation/barth_dtw_stride_segmentation_roi.py +++ b/examples/stride_segmentation/barth_dtw_stride_segmentation_roi.py @@ -83,7 +83,7 @@ # The additional "roi_id" column indicates in which ROI a stride was identified in. stride_list_left = roi_seg.stride_list_["left_sensor"] -print("{} strides were detected.".format(len(stride_list_left))) +print(f"{len(stride_list_left)} strides were detected.") stride_list_left # %% diff --git a/examples/stride_segmentation/constrained_barth_dtw_stride_segmentation.py b/examples/stride_segmentation/constrained_barth_dtw_stride_segmentation.py index e1fee32e..15ea0cec 100644 --- a/examples/stride_segmentation/constrained_barth_dtw_stride_segmentation.py +++ b/examples/stride_segmentation/constrained_barth_dtw_stride_segmentation.py @@ -77,7 +77,7 @@ # ---------------------- -def plot_dtw(dtw, sensor="left_sensor"): +def plot_dtw(dtw, sensor="left_sensor") -> None: fig, axs = plt.subplots(nrows=3, sharex=True, figsize=(10, 5)) dtw.data[sensor]["gyr_ml"].reset_index(drop=True).plot(ax=axs[0]) axs[0].set_ylabel("gyro [deg/s]") diff --git a/examples/stride_segmentation/roth_hmm_stride_segmentation.py b/examples/stride_segmentation/roth_hmm_stride_segmentation.py index df183ead..5221adea 100644 --- a/examples/stride_segmentation/roth_hmm_stride_segmentation.py +++ b/examples/stride_segmentation/roth_hmm_stride_segmentation.py @@ -13,7 +13,6 @@ Hidden Markov Model based Stride Segmentation on Unsupervised Free-living Gait Data in Parkinson’s Disease Patients. 
Journal of NeuroEngineering and Rehabilitation, (JNER). """ -import json import matplotlib.pyplot as plt import numpy as np @@ -77,7 +76,7 @@ # The main output is the `stride_list_`, which contains the start and the end of all identified strides. # As we passed a dataset with two sensors, the output will be a dictionary. stride_list_left = hmm_seg.stride_list_["left_sensor"] -print("{} strides were detected.".format(len(stride_list_left))) +print(f"{len(stride_list_left)} strides were detected.") stride_list_left.head() # %% diff --git a/examples/stride_segmentation/segmentation_hmm_training.py b/examples/stride_segmentation/segmentation_hmm_training.py index 2432ef2e..55296aca 100644 --- a/examples/stride_segmentation/segmentation_hmm_training.py +++ b/examples/stride_segmentation/segmentation_hmm_training.py @@ -12,6 +12,7 @@ Hidden Markov Model based Stride Segmentation on Unsupervised Free-living Gait Data in Parkinson’s Disease Patients. Journal of NeuroEngineering and Rehabilitation, (JNER). """ + import numpy as np from matplotlib import pyplot as plt diff --git a/examples/trajectory_reconstruction/trajectory_reconstruction_region.py b/examples/trajectory_reconstruction/trajectory_reconstruction_region.py index 534f8185..51f2349a 100644 --- a/examples/trajectory_reconstruction/trajectory_reconstruction_region.py +++ b/examples/trajectory_reconstruction/trajectory_reconstruction_region.py @@ -4,7 +4,7 @@ This example shows how to calculate a IMU/foot trajectory_full over an entire gait sequence using :class:`~gaitmap.trajectory_reconstruction.RegionLevelTrajectory`. -If you need an introduction to trajectory reconstruction in general, have a look +If you need an introduction to trajectory reconstruction in general, have a look at:ref`this example `. """ diff --git a/examples/trajectory_reconstruction/zupt_dependency.py b/examples/trajectory_reconstruction/zupt_dependency.py index 47d30897..0705057c 100644 --- a/examples/trajectory_reconstruction/zupt_dependency.py +++ b/examples/trajectory_reconstruction/zupt_dependency.py @@ -18,6 +18,7 @@ Navigation.” """ + import pandas as pd # %% @@ -76,13 +77,13 @@ def create_index(self) -> pd.DataFrame: from gaitmap.base import BaseZuptDetector from gaitmap.trajectory_reconstruction import RtsKalman -from gaitmap.zupt_detection import AredZuptDetector, ShoeZuptDetector +from gaitmap.zupt_detection import ShoeZuptDetector class TrajectoryPipeline(Pipeline[HealthyImu]): trajectory_: pd.DataFrame - def __init__(self, zupt_method: BaseZuptDetector = cf(ShoeZuptDetector())): + def __init__(self, zupt_method: BaseZuptDetector = cf(ShoeZuptDetector())) -> None: self.zupt_method = zupt_method def run(self, datapoint: HealthyImu) -> Self: diff --git a/gaitmap/_event_detection_common/_event_detection_mixin.py b/gaitmap/_event_detection_common/_event_detection_mixin.py index 76dbbe9c..c96555a7 100644 --- a/gaitmap/_event_detection_common/_event_detection_mixin.py +++ b/gaitmap/_event_detection_common/_event_detection_mixin.py @@ -44,7 +44,7 @@ def __init__( memory: Optional[Memory] = None, enforce_consistency: bool = True, detect_only: Optional[Tuple[str, ...]] = None, - ): + ) -> None: self.memory = memory self.enforce_consistency = enforce_consistency self.detect_only = detect_only @@ -73,7 +73,7 @@ def detect(self, data: SensorData, stride_list: StrideList, *, sampling_rate_hz: if dataset_type != stride_list_type: raise ValidationError( "An invalid combination of stride list and dataset was provided." 
- "The dataset is {} sensor and the stride list is {} sensor.".format(dataset_type, stride_list_type) + f"The dataset is {dataset_type} sensor and the stride list is {stride_list_type} sensor." ) self.data = data diff --git a/gaitmap/base.py b/gaitmap/base.py index 7c4c8baf..a16f158c 100644 --- a/gaitmap/base.py +++ b/gaitmap/base.py @@ -44,7 +44,7 @@ class _CustomEncoder(json.JSONEncoder): def encode(self, o: Any) -> str: return super().encode(_hint_tuples(o)) - def default(self, o): # noqa: C901 + def default(self, o): # noqa: C901, PLR0911 if isinstance(o, _BaseSerializable): return o._to_json_dict() if isinstance(o, Rotation): @@ -84,7 +84,7 @@ def default(self, o): # noqa: C901 return super().default(o) -def _custom_deserialize(json_obj): # pylint: disable=too-many-return-statements +def _custom_deserialize(json_obj): # pylint: disable=too-many-return-statements # noqa: PLR0911 if "_gaitmap_obj" in json_obj: return _BaseSerializable._find_subclass(json_obj["_gaitmap_obj"])._from_json_dict(json_obj) if "_obj_type" in json_obj: diff --git a/gaitmap/data_transform/__init__.py b/gaitmap/data_transform/__init__.py index a13ea014..fcb899d2 100644 --- a/gaitmap/data_transform/__init__.py +++ b/gaitmap/data_transform/__init__.py @@ -1,4 +1,5 @@ """Classes representing data transformations as preprocessing for different algorithms.""" + from gaitmap.data_transform._base import ( BaseTransformer, ChainedTransformer, diff --git a/gaitmap/data_transform/_base.py b/gaitmap/data_transform/_base.py index 0105bcf6..aee0b18a 100644 --- a/gaitmap/data_transform/_base.py +++ b/gaitmap/data_transform/_base.py @@ -1,4 +1,5 @@ """Basic transformers for higher level functionality.""" + from copy import copy from functools import reduce from typing import List, Sequence, Set, Tuple, Union @@ -98,7 +99,7 @@ def __init__( self, transformer_mapping: List[Tuple[Union[_Hashable, Tuple[_Hashable, ...]], BaseTransformer]], keep_all_cols: bool = True, - ): + ) -> None: self.transformer_mapping = transformer_mapping self.keep_all_cols = keep_all_cols @@ -195,7 +196,7 @@ def _validate_mapping(self) -> Set[_Hashable]: unique_k.append(i) return set(unique_k) - def _validate(self, data: SingleSensorData, selected_cols: Set[_Hashable]): + def _validate(self, data: SingleSensorData, selected_cols: Set[_Hashable]) -> None: if not set(data.columns).issuperset(selected_cols): raise ValueError("You specified transformations for columns that do not exist. 
This is not supported!") @@ -238,7 +239,7 @@ class ChainedTransformer(BaseTransformer, TrainableTransformerMixin): chain: OptimizableParameter[List[Tuple[_Hashable, BaseTransformer]]] - def __init__(self, chain: List[Tuple[_Hashable, BaseTransformer]]): + def __init__(self, chain: List[Tuple[_Hashable, BaseTransformer]]) -> None: self.chain = chain def self_optimize(self, data: Sequence[SingleSensorData], **kwargs) -> Self: @@ -333,7 +334,7 @@ class ParallelTransformer(BaseTransformer, TrainableTransformerMixin): transformers: OptimizableParameter[List[Tuple[_Hashable, BaseTransformer]]] - def __init__(self, transformers: List[Tuple[_Hashable, BaseTransformer]]): + def __init__(self, transformers: List[Tuple[_Hashable, BaseTransformer]]) -> None: self.transformers = transformers def self_optimize(self, data: Sequence[SingleSensorData], **kwargs) -> Self: diff --git a/gaitmap/data_transform/_feature_transform.py b/gaitmap/data_transform/_feature_transform.py index 65018d64..7dcca02b 100644 --- a/gaitmap/data_transform/_feature_transform.py +++ b/gaitmap/data_transform/_feature_transform.py @@ -1,6 +1,7 @@ """A set of transformers that can be used to calculate traditional features from a timeseries.""" + from copy import copy -from typing import Optional +from typing import NoReturn, Optional import numpy as np import pandas as pd @@ -55,7 +56,7 @@ class Resample(BaseTransformer): def __init__( self, target_sampling_rate_hz: Optional[float] = None, - ): + ) -> None: self.target_sampling_rate_hz = target_sampling_rate_hz def transform( @@ -129,7 +130,7 @@ class BaseSlidingWindowFeatureTransform(BaseTransformer): sampling_rate_hz: float - def __init__(self, window_size_s: Optional[float] = None): + def __init__(self, window_size_s: Optional[float] = None) -> None: self.window_size_s = window_size_s @property @@ -289,7 +290,7 @@ def _get_centered_window_view(array, window_size_samples, pad_value=0.0): class _CustomSlidingWindowTransform(BaseSlidingWindowFeatureTransform): - def _apply_to_window_view(self, windowed_view: np.ndarray, data: pd.DataFrame): + def _apply_to_window_view(self, windowed_view: np.ndarray, data: pd.DataFrame) -> NoReturn: raise NotImplementedError def _transform(self, data: SingleSensorData, sampling_rate_hz: float, **_) -> SingleSensorData: # noqa: ARG002 diff --git a/gaitmap/data_transform/_filter.py b/gaitmap/data_transform/_filter.py index 2c9cef66..b8761712 100644 --- a/gaitmap/data_transform/_filter.py +++ b/gaitmap/data_transform/_filter.py @@ -1,4 +1,5 @@ """A set of filters that can be applied to data.""" + from typing import Literal, Optional, Tuple, Union import pandas as pd @@ -22,9 +23,7 @@ def filtered_data_(self) -> SingleSensorData: """ return self.transformed_data_ - def filter( # noqa: A003 - self, data: SingleSensorData, *, sampling_rate_hz: Optional[float] = None, **kwargs - ) -> Self: + def filter(self, data: SingleSensorData, *, sampling_rate_hz: Optional[float] = None, **kwargs) -> Self: """Filter the data. This will apply the filter along the **first** axis (axis=0) (aka each column will be filtered). 
@@ -84,7 +83,7 @@ def __init__( order: int, cutoff_freq_hz: Union[float, Tuple[float, float]], filter_type: Literal["lowpass", "highpass", "bandpass", "bandstop"] = "lowpass", - ): + ) -> None: self.order = order self.cutoff_freq_hz = cutoff_freq_hz self.filter_type = filter_type diff --git a/gaitmap/data_transform/_scaler.py b/gaitmap/data_transform/_scaler.py index 168bcd37..6c5570ec 100644 --- a/gaitmap/data_transform/_scaler.py +++ b/gaitmap/data_transform/_scaler.py @@ -1,4 +1,5 @@ """Transformers that scale data to certain data ranges.""" + from typing import Optional, Sequence, Tuple import numpy as np @@ -41,7 +42,7 @@ class FixedScaler(BaseTransformer): scale: Parameter[float] offset: Parameter[float] - def __init__(self, scale: float = 1, offset: float = 0): + def __init__(self, scale: float = 1, offset: float = 0) -> None: self.scale = scale self.offset = offset @@ -94,7 +95,7 @@ class StandardScaler(BaseTransformer): ddof: Parameter[int] = 1 - def __init__(self, ddof: int = 1): + def __init__(self, ddof: int = 1) -> None: self.ddof = ddof def transform(self, data: SingleSensorData, **_) -> Self: @@ -150,7 +151,7 @@ class TrainableStandardScaler(StandardScaler, TrainableTransformerMixin): mean: OptimizableParameter[Optional[float]] std: OptimizableParameter[Optional[float]] - def __init__(self, mean: Optional[float] = None, std: Optional[float] = None, ddof: int = 1): + def __init__(self, mean: Optional[float] = None, std: Optional[float] = None, ddof: int = 1) -> None: self.mean = mean self.std = std super().__init__(ddof=ddof) @@ -234,7 +235,7 @@ class AbsMaxScaler(BaseTransformer): out_max: Parameter[float] - def __init__(self, out_max: float = 1): + def __init__(self, out_max: float = 1) -> None: self.out_max = out_max def transform(self, data: SingleSensorData, **_) -> Self: @@ -313,7 +314,7 @@ class TrainableAbsMaxScaler(AbsMaxScaler, TrainableTransformerMixin): data_max: OptimizableParameter[Optional[float]] - def __init__(self, out_max: float = 1, data_max: Optional[float] = None): + def __init__(self, out_max: float = 1, data_max: Optional[float] = None) -> None: self.data_max = data_max super().__init__(out_max=out_max) @@ -394,7 +395,7 @@ class MinMaxScaler(BaseTransformer): def __init__( self, out_range: Tuple[float, float] = (0, 1.0), - ): + ) -> None: self.out_range = out_range def transform(self, data: SingleSensorData, **_) -> Self: @@ -488,7 +489,7 @@ def __init__( self, out_range: Tuple[float, float] = (0, 1.0), data_range: Optional[Tuple[float, float]] = None, - ): + ) -> None: self.data_range = data_range super().__init__(out_range=out_range) diff --git a/gaitmap/evaluation_utils/event_detection.py b/gaitmap/evaluation_utils/event_detection.py index 06a387ef..da5cb919 100644 --- a/gaitmap/evaluation_utils/event_detection.py +++ b/gaitmap/evaluation_utils/event_detection.py @@ -76,18 +76,13 @@ def evaluate_stride_event_list( Examples -------- >>> stride_list_ground_truth = DataFrame( - ... [[10,21, 10],[20,34, 30],[31,40, 20]], - ... columns=["start", "end", "ic"] - ... ).rename_axis('s_id') + ... [[10, 21, 10], [20, 34, 30], [31, 40, 20]], columns=["start", "end", "ic"] + ... ).rename_axis("s_id") >>> stride_list_seg = DataFrame( - ... [[10,20, 10],[21,30, 30],[31,40, 22]], - ... columns=["start", "end", "ic"] - ... ).rename_axis('s_id') + ... [[10, 20, 10], [21, 30, 30], [31, 40, 22]], columns=["start", "end", "ic"] + ... ).rename_axis("s_id") >>> matches = evaluate_stride_event_list( - ... ground_truth=stride_list_ground_truth, - ... 
stride_event_list=stride_list_seg, - ... match_cols="ic", - ... tolerance=3 + ... ground_truth=stride_list_ground_truth, stride_event_list=stride_list_seg, match_cols="ic", tolerance=3 ... ) >>> matches s_id s_id_ground_truth match_type @@ -96,28 +91,22 @@ def evaluate_stride_event_list( 2 2 2 tp >>> stride_list_ground_truth_left = DataFrame( - ... [[10,21,30],[20,34,20],[31,40,10], [10, 30 ,60]], - ... columns=["start", "end", "ic"] - ... ).rename_axis('s_id') + ... [[10, 21, 30], [20, 34, 20], [31, 40, 10], [10, 30, 60]], columns=["start", "end", "ic"] + ... ).rename_axis("s_id") >>> stride_list_ground_truth_right = DataFrame( - ... [[10,21,1],[20,34,2],[31,40,3]], - ... columns=["start", "end", "ic"] - ... ).rename_axis('s_id') - ... + ... [[10, 21, 1], [20, 34, 2], [31, 40, 3]], columns=["start", "end", "ic"] + ... ).rename_axis("s_id") >>> stride_list_seg_left = DataFrame( - ... [[10,20, 30],[21,30,20],[31,40,13]], - ... columns=["start", "end", "ic"] - ... ).rename_axis('s_id') + ... [[10, 20, 30], [21, 30, 20], [31, 40, 13]], columns=["start", "end", "ic"] + ... ).rename_axis("s_id") >>> stride_list_seg_right = DataFrame( - ... [[10,21, 1],[20,34, 2],[31,40, 3]], - ... columns=["start", "end", "ic"] - ... ).rename_axis('s_id') - ... + ... [[10, 21, 1], [20, 34, 2], [31, 40, 3]], columns=["start", "end", "ic"] + ... ).rename_axis("s_id") >>> matches_multi = evaluate_stride_event_list( ... ground_truth={"left_sensor": stride_list_ground_truth_left, "right_sensor": stride_list_ground_truth_right}, ... stride_event_list={"left_sensor": stride_list_seg_left, "right_sensor": stride_list_seg_right}, ... match_cols="ic", - ... tolerance=2 + ... tolerance=2, ... ) >>> matches_multi["left_sensor"] s_id s_id_ground_truth match_type diff --git a/gaitmap/evaluation_utils/parameter_errors.py b/gaitmap/evaluation_utils/parameter_errors.py index 89cce55d..049b987b 100644 --- a/gaitmap/evaluation_utils/parameter_errors.py +++ b/gaitmap/evaluation_utils/parameter_errors.py @@ -1,4 +1,5 @@ """A helper function to evaluate the output of the temporal or spatial parameter calculation against a ground truth.""" + import warnings from typing import Dict, Literal, Tuple, Union @@ -250,7 +251,6 @@ def calculate_aggregated_parameter_errors( >>> pd.set_option("display.max_columns", None) >>> pd.set_option("display.width", 0) - ... >>> predicted_sensor_left = pd.DataFrame(columns=["para"], data=[23, 82, 42]).rename_axis("s_id") >>> reference_sensor_left = pd.DataFrame(columns=["para"], data=[21, 86, 65]).rename_axis("s_id") >>> predicted_sensor_right = pd.DataFrame(columns=["para"], data=[26, -58, -3]).rename_axis("s_id") @@ -326,7 +326,7 @@ def calculate_aggregated_parameter_errors( >>> calculate_aggregated_parameter_errors( ... predicted_parameter={"left_sensor": predicted_sensor_left, "right_sensor": predicted_sensor_right}, ... reference_parameter={"left_sensor": reference_sensor_left, "right_sensor": reference_sensor_right}, - ... calculate_per_sensor=False + ... calculate_per_sensor=False, ... 
) # doctest: +NORMALIZE_WHITESPACE para predicted_mean 18.666667 diff --git a/gaitmap/evaluation_utils/scores.py b/gaitmap/evaluation_utils/scores.py index 9b9f8413..7f89a575 100644 --- a/gaitmap/evaluation_utils/scores.py +++ b/gaitmap/evaluation_utils/scores.py @@ -1,4 +1,5 @@ """A set of helper functions to score the output of the evaluation of a stride segmentation against ground truth.""" + import warnings from typing import Dict, Union, overload @@ -8,19 +9,21 @@ from gaitmap.utils._types import _Hashable from gaitmap.utils.datatype_helper import get_multi_sensor_names -_ScoresDict = TypedDict("_ScoresDict", {"precision": float, "recall": float, "f1_score": float}) + +class _ScoresDict(TypedDict): + precision: float + recall: float + f1_score: float @overload def recall_score( matches_df: Dict[_Hashable, pd.DataFrame], *, zero_division: Literal["warn", 0, 1] = "warn" -) -> Dict[_Hashable, float]: - ... +) -> Dict[_Hashable, float]: ... @overload -def recall_score(matches_df: pd.DataFrame, *, zero_division: Literal["warn", 0, 1] = "warn") -> float: - ... +def recall_score(matches_df: pd.DataFrame, *, zero_division: Literal["warn", 0, 1] = "warn") -> float: ... def recall_score(matches_df, *, zero_division: Literal["warn", 0, 1] = "warn"): @@ -76,13 +79,11 @@ def recall_score(matches_df, *, zero_division: Literal["warn", 0, 1] = "warn"): @overload def precision_score( matches_df: Dict[_Hashable, pd.DataFrame], *, zero_division: Literal["warn", 0, 1] = "warn" -) -> Dict[_Hashable, float]: - ... +) -> Dict[_Hashable, float]: ... @overload -def precision_score(matches_df: pd.DataFrame, *, zero_division: Literal["warn", 0, 1] = "warn") -> float: - ... +def precision_score(matches_df: pd.DataFrame, *, zero_division: Literal["warn", 0, 1] = "warn") -> float: ... def precision_score(matches_df, *, zero_division: Literal["warn", 0, 1] = "warn"): @@ -140,13 +141,11 @@ def precision_score(matches_df, *, zero_division: Literal["warn", 0, 1] = "warn" @overload def f1_score( matches_df: Dict[_Hashable, pd.DataFrame], *, zero_division: Literal["warn", 0, 1] = "warn" -) -> Dict[_Hashable, float]: - ... +) -> Dict[_Hashable, float]: ... @overload -def f1_score(matches_df: pd.DataFrame, *, zero_division: Literal["warn", 0, 1] = "warn") -> float: - ... +def f1_score(matches_df: pd.DataFrame, *, zero_division: Literal["warn", 0, 1] = "warn") -> float: ... def f1_score(matches_df, *, zero_division: Literal["warn", 0, 1] = "warn"): @@ -205,15 +204,13 @@ def f1_score(matches_df, *, zero_division: Literal["warn", 0, 1] = "warn"): @overload def precision_recall_f1_score( matches_df: Dict[str, pd.DataFrame], *, zero_division: Literal["warn", 0, 1] = "warn" -) -> Dict[str, _ScoresDict]: - ... +) -> Dict[str, _ScoresDict]: ... @overload def precision_recall_f1_score( matches_df: pd.DataFrame, *, zero_division: Literal["warn", 0, 1] = "warn" -) -> _ScoresDict: - ... +) -> _ScoresDict: ... 
def precision_recall_f1_score(matches_df, *, zero_division: Literal["warn", 0, 1] = "warn"): @@ -276,7 +273,7 @@ def precision_recall_f1_score(matches_df, *, zero_division: Literal["warn", 0, 1 def _get_match_type_dfs( - match_results: Union[pd.DataFrame, Dict[_Hashable, pd.DataFrame]] + match_results: Union[pd.DataFrame, Dict[_Hashable, pd.DataFrame]], ) -> Union[Dict[_Hashable, Dict[str, pd.DataFrame]], Dict[str, pd.DataFrame]]: is_not_dict = not isinstance(match_results, dict) if is_not_dict: diff --git a/gaitmap/evaluation_utils/stride_segmentation.py b/gaitmap/evaluation_utils/stride_segmentation.py index 4b4d4360..31568003 100644 --- a/gaitmap/evaluation_utils/stride_segmentation.py +++ b/gaitmap/evaluation_utils/stride_segmentation.py @@ -82,12 +82,14 @@ def evaluate_segmented_stride_list( Examples -------- - >>> stride_list_ground_truth = pd.DataFrame([[10,21],[20,34],[31,40]], columns=["start", "end"]).rename_axis('s_id') - >>> stride_list_seg = pd.DataFrame([[10,20],[21,30],[31,40],[50,60]], columns=["start", "end"]).rename_axis('s_id') + >>> stride_list_ground_truth = pd.DataFrame([[10, 21], [20, 34], [31, 40]], columns=["start", "end"]).rename_axis( + ... "s_id" + ... ) + >>> stride_list_seg = pd.DataFrame([[10, 20], [21, 30], [31, 40], [50, 60]], columns=["start", "end"]).rename_axis( + ... "s_id" + ... ) >>> matches = evaluate_segmented_stride_list( - ... ground_truth=stride_list_ground_truth, - ... segmented_stride_list=stride_list_seg, - ... tolerance=2 + ... ground_truth=stride_list_ground_truth, segmented_stride_list=stride_list_seg, tolerance=2 ... ) >>> matches s_id s_id_ground_truth match_type @@ -98,24 +100,21 @@ def evaluate_segmented_stride_list( 4 NaN 1 fn >>> stride_list_ground_truth_left = pd.DataFrame( - ... [[10,21],[20,34],[31,40]], - ... columns=["start", "end"] - ... ).rename_axis('s_id') + ... [[10, 21], [20, 34], [31, 40]], columns=["start", "end"] + ... ).rename_axis("s_id") >>> stride_list_ground_truth_right = pd.DataFrame( - ... [[10,21],[20,34],[31,40]], - ... columns=["start", "end"] - ... ).rename_axis('s_id') - ... + ... [[10, 21], [20, 34], [31, 40]], columns=["start", "end"] + ... ).rename_axis("s_id") >>> stride_list_seg_left = pd.DataFrame( - ... [[10,20],[21,30],[31,40],[50,60]], - ... columns=["start", "end"] - ... ).rename_axis('s_id') - >>> stride_list_seg_right = pd.DataFrame([[10,21],[20,34],[31,40]], columns=["start", "end"]).rename_axis('s_id') - ... + ... [[10, 20], [21, 30], [31, 40], [50, 60]], columns=["start", "end"] + ... ).rename_axis("s_id") + >>> stride_list_seg_right = pd.DataFrame([[10, 21], [20, 34], [31, 40]], columns=["start", "end"]).rename_axis( + ... "s_id" + ... ) >>> matches = evaluate_segmented_stride_list( ... ground_truth={"left_sensor": stride_list_ground_truth_left, "right_sensor": stride_list_ground_truth_right}, ... segmented_stride_list={"left_sensor": stride_list_seg_left, "right_sensor": stride_list_seg_right}, - ... tolerance=2 + ... tolerance=2, ... ) >>> matches["left_sensor"] s_id s_id_ground_truth match_type @@ -251,14 +250,16 @@ def match_stride_lists( -------- Single Sensor: - >>> stride_list_left = pd.DataFrame([[10,20],[21,30],[31,40],[50,60]], columns=["start", "end"]).rename_axis('s_id') - >>> stride_list_right = pd.DataFrame([[10,21],[20,34],[31,40]], columns=["start", "end"]).rename_axis('s_id') + >>> stride_list_left = pd.DataFrame([[10, 20], [21, 30], [31, 40], [50, 60]], columns=["start", "end"]).rename_axis( + ... "s_id" + ... 
) + >>> stride_list_right = pd.DataFrame([[10, 21], [20, 34], [31, 40]], columns=["start", "end"]).rename_axis("s_id") >>> match_stride_lists( ... stride_list_a=stride_list_left, ... stride_list_b=stride_list_right, ... tolerance=2, ... postfix_a="_left", - ... postfix_b="_right" + ... postfix_b="_right", ... ) s_id_left s_id_right 0 0 0 @@ -270,21 +271,21 @@ def match_stride_lists( Multi Sensor: >>> stride_list_left_11 = pd.DataFrame( - ... [[10,20],[21,30],[31,40],[50,60]], - ... columns=["start", "end"] - ... ).rename_axis('s_id') - >>> stride_list_right_12 = pd.DataFrame([[10,21],[20,34],[31,40]], columns=["start", "end"]).rename_axis('s_id') - ... + ... [[10, 20], [21, 30], [31, 40], [50, 60]], columns=["start", "end"] + ... ).rename_axis("s_id") + >>> stride_list_right_12 = pd.DataFrame([[10, 21], [20, 34], [31, 40]], columns=["start", "end"]).rename_axis( + ... "s_id" + ... ) >>> stride_list_left_21 = pd.DataFrame( - ... [[10,20],[31,41],[21,31],[50,60]], - ... columns=["start", "end"] - ... ).rename_axis('s_id') - >>> stride_list_right_22 = pd.DataFrame([[10,22],[31, 41],[20, 36]], columns=["start", "end"]).rename_axis('s_id') - ... + ... [[10, 20], [31, 41], [21, 31], [50, 60]], columns=["start", "end"] + ... ).rename_axis("s_id") + >>> stride_list_right_22 = pd.DataFrame([[10, 22], [31, 41], [20, 36]], columns=["start", "end"]).rename_axis( + ... "s_id" + ... ) >>> test_output = match_stride_lists( ... stride_list_a={"left_sensor": stride_list_left_11, "right_sensor": stride_list_right_12}, ... stride_list_b={"left_sensor": stride_list_left_21, "right_sensor": stride_list_right_22}, - ... tolerance=1 + ... tolerance=1, ... ) >>> test_output["left_sensor"] s_id_a s_id_b @@ -377,9 +378,7 @@ def _match_single_stride_lists( ) -> pd.DataFrame: if not (set(match_cols).issubset(stride_list_a.columns) and set(match_cols).issubset(stride_list_b.columns)): raise ValueError( - "One or more selected columns ({}) are missing in at least one of the provided stride lists".format( - match_cols - ) + f"One or more selected columns ({match_cols}) are missing in at least one of the provided stride lists" ) stride_list_a = set_correct_index(stride_list_a, SL_INDEX) stride_list_b = set_correct_index(stride_list_b, SL_INDEX) diff --git a/gaitmap/event_detection/__init__.py b/gaitmap/event_detection/__init__.py index 96eed713..315722c9 100644 --- a/gaitmap/event_detection/__init__.py +++ b/gaitmap/event_detection/__init__.py @@ -2,6 +2,7 @@ Different algorithms for event detection are going to be collected here. 
""" + from gaitmap.event_detection._herzer_event_detection import HerzerEventDetection from gaitmap.utils._gaitmap_mad import patch_gaitmap_mad_import diff --git a/gaitmap/event_detection/_herzer_event_detection.py b/gaitmap/event_detection/_herzer_event_detection.py index f6365c21..28d06810 100644 --- a/gaitmap/event_detection/_herzer_event_detection.py +++ b/gaitmap/event_detection/_herzer_event_detection.py @@ -1,4 +1,5 @@ """An event detection algorithm optimized for stair ambulation developed by Liv Herzer in her Bachelor Thesis .""" + from typing import Callable, Dict, Optional, Tuple, Union import numpy as np @@ -197,7 +198,7 @@ def __init__( memory: Optional[Memory] = None, enforce_consistency: bool = True, detect_only: Optional[Tuple[str, ...]] = None, - ): + ) -> None: self.min_vel_search_win_size_ms = min_vel_search_win_size_ms self.mid_swing_peak_prominence = mid_swing_peak_prominence self.mid_swing_n_considered_peaks = mid_swing_n_considered_peaks @@ -328,9 +329,7 @@ def _detect_ic( # and the start acc_pa max has to be the max before that which is not necessarily # the global max within the search region refined_search_region_end = int( - search_region[0] - + np.argmax(gyr_ml_grad[slice(*search_region)]) - + 1 + search_region[0] + np.argmax(gyr_ml_grad[slice(*search_region)]) + 1 # +1 because the min max distance is often very small # and in a search range the last value is normally not included but here it should be ) diff --git a/gaitmap/gait_detection/__init__.py b/gaitmap/gait_detection/__init__.py index 173485de..e8bc6d17 100644 --- a/gaitmap/gait_detection/__init__.py +++ b/gaitmap/gait_detection/__init__.py @@ -2,6 +2,7 @@ Different algorithms for gait sequence detection are going to be collected here. """ + from gaitmap.utils._gaitmap_mad import patch_gaitmap_mad_import _gaitmap_mad_modules = { diff --git a/gaitmap/parameters/__init__.py b/gaitmap/parameters/__init__.py index 319b0450..f86c3f4a 100644 --- a/gaitmap/parameters/__init__.py +++ b/gaitmap/parameters/__init__.py @@ -1,4 +1,5 @@ """Calculate biomechanical gait parameters based on all the information calculated in the rest of the pipeline.""" + from gaitmap.parameters._spatial_parameters import SpatialParameterCalculation from gaitmap.parameters._temporal_parameters import TemporalParameterCalculation diff --git a/gaitmap/parameters/_spatial_parameters.py b/gaitmap/parameters/_spatial_parameters.py index 72811f9a..d6752a46 100644 --- a/gaitmap/parameters/_spatial_parameters.py +++ b/gaitmap/parameters/_spatial_parameters.py @@ -1,4 +1,5 @@ """Calculate spatial parameters algorithm by Kanzler et al. 2015 and Rampp et al. 2014.""" + import warnings from typing import Dict, Literal, Optional, Sequence, Union @@ -136,11 +137,8 @@ class SpatialParameterCalculation(BaseSpatialParameterCalculation): >>> orientations = ... # from orientation estimation >>> spatial_paras = SpatialParameterCalculation() >>> spatial_paras = spatial_paras.calculate( - ... stride_event_list=stride_list, - ... positions=positions, - ... orientations=orientations, - ... sampling_rate_hz=204.8 - ... ) + ... stride_event_list=stride_list, positions=positions, orientations=orientations, sampling_rate_hz=204.8 + ... 
) >>> spatial_paras.parameters_ >>> spatial_paras.parameters_pretty_ @@ -167,7 +165,7 @@ def __init__( self, calculate_only: Optional[Sequence[ParamterNames]] = None, expected_stride_type: Literal["min_vel", "ic"] = "min_vel", - ): + ) -> None: self.calculate_only = calculate_only self.expected_stride_type = expected_stride_type diff --git a/gaitmap/parameters/_temporal_parameters.py b/gaitmap/parameters/_temporal_parameters.py index d6dca3bf..6dddecf5 100644 --- a/gaitmap/parameters/_temporal_parameters.py +++ b/gaitmap/parameters/_temporal_parameters.py @@ -1,4 +1,5 @@ """Calculate temporal gait parameters.""" + from typing import Dict, Literal, Tuple, TypeVar, Union import pandas as pd @@ -70,7 +71,7 @@ class TemporalParameterCalculation(BaseTemporalParameterCalculation): -------- This method requires the output of an event detection method as input. - >>> stride_list = ... # from event detection + >>> stride_list = ...  # from event detection >>> temporal_paras = TemporalParameterCalculation() >>> temporal_paras = temporal_paras.calculate(stride_event_list=stride_list, sampling_rate_hz=204.8) >>> temporal_paras.parameters_ @@ -91,7 +92,7 @@ class TemporalParameterCalculation(BaseTemporalParameterCalculation): sampling_rate_hz: float stride_event_list: StrideList - def __init__(self, expected_stride_type: Literal["min_vel", "ic"] = "min_vel"): + def __init__(self, expected_stride_type: Literal["min_vel", "ic"] = "min_vel") -> None: self.expected_stride_type = expected_stride_type @property diff --git a/gaitmap/preprocessing/__init__.py b/gaitmap/preprocessing/__init__.py index 04cc17ea..71bcd46e 100644 --- a/gaitmap/preprocessing/__init__.py +++ b/gaitmap/preprocessing/__init__.py @@ -1,4 +1,5 @@ """A set of functions that help to align the sensor orientation and prepare the dataset for use with gaitmap.""" + from gaitmap.preprocessing.sensor_alignment import align_dataset_to_gravity, align_heading_of_sensors __all__ = ["align_dataset_to_gravity", "align_heading_of_sensors"] diff --git a/gaitmap/preprocessing/sensor_alignment/_gravity_alignment.py b/gaitmap/preprocessing/sensor_alignment/_gravity_alignment.py index ecbccfc5..4f071fce 100644 --- a/gaitmap/preprocessing/sensor_alignment/_gravity_alignment.py +++ b/gaitmap/preprocessing/sensor_alignment/_gravity_alignment.py @@ -1,4 +1,5 @@ """Helpers to rotate the sensor in the predefined gaitmap sensor frame.""" + from typing import Dict, Union import numpy as np diff --git a/gaitmap/preprocessing/sensor_alignment/_pca_alignment.py b/gaitmap/preprocessing/sensor_alignment/_pca_alignment.py index 5c5460fb..d5dba50c 100644 --- a/gaitmap/preprocessing/sensor_alignment/_pca_alignment.py +++ b/gaitmap/preprocessing/sensor_alignment/_pca_alignment.py @@ -105,9 +105,9 @@ class PcaAlignment(BaseSensorAlignment): Align dataset to the medio-lateral plane by aligning the y-axis with the dominant component in the gyro x-y-plane - >>> pca_alignment = PcaAlignment(target_axis="y", pca_plane_axis=("gyr_x","gyr_y")) + >>> pca_alignment = PcaAlignment(target_axis="y", pca_plane_axis=("gyr_x", "gyr_y")) >>> pca_alignment = pca_alignment.align(data, 204.8) - >>> pca_alignment.aligned_data_['left_sensor'] + >>> pca_alignment.aligned_data_["left_sensor"] ...
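The reformatted doctest above never defines `data`; purely for illustration, here is a self-contained sketch of the same PcaAlignment call. The random input frames, the sensor names, and the public import paths are assumptions for this sketch, not part of the patch.

import numpy as np
import pandas as pd

from gaitmap.preprocessing.sensor_alignment import PcaAlignment  # assumed public re-export
from gaitmap.utils.consts import SF_COLS  # assumed to hold the six sensor-frame column names

# Dummy multi-sensor recording: one sensor-frame DataFrame per sensor.
data = {
    name: pd.DataFrame(np.random.randn(1024, len(SF_COLS)), columns=SF_COLS)
    for name in ("left_sensor", "right_sensor")
}
pca_alignment = PcaAlignment(target_axis="y", pca_plane_axis=("gyr_x", "gyr_y")).align(data, 204.8)
aligned_left = pca_alignment.aligned_data_["left_sensor"]  # rotated copy of the left input frame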
@@ -137,7 +137,7 @@ class PcaAlignment(BaseSensorAlignment): data: SensorData - def __init__(self, target_axis: str = "y", pca_plane_axis: Sequence[str] = ("gyr_x", "gyr_y")): + def __init__(self, target_axis: str = "y", pca_plane_axis: Sequence[str] = ("gyr_x", "gyr_y")) -> None: self.target_axis = target_axis self.pca_plane_axis = pca_plane_axis diff --git a/gaitmap/stride_segmentation/_roi_stride_segmentation.py b/gaitmap/stride_segmentation/_roi_stride_segmentation.py index b5cac80d..aaf9cd84 100644 --- a/gaitmap/stride_segmentation/_roi_stride_segmentation.py +++ b/gaitmap/stride_segmentation/_roi_stride_segmentation.py @@ -1,4 +1,5 @@ """Wrapper class to apply a stride segmentation to multiple regions of interest in a dataset.""" + from copy import deepcopy from typing import Dict, Generic, Optional, TypeVar, Union @@ -142,7 +143,7 @@ def __init__( segmentation_algorithm: Optional[StrideSegmentationAlgorithm] = None, s_id_naming: Literal["replace", "prefix"] = "replace", action_method: Optional[str] = None, - ): + ) -> None: self.segmentation_algorithm = segmentation_algorithm self.s_id_naming = s_id_naming self.action_method = action_method @@ -273,7 +274,7 @@ def _merge_single_sensor_stride_lists(self, stride_lists, index_name) -> StrideL ) return concat_stride_list.set_index("s_id") - def _validate_parameters(self): + def _validate_parameters(self) -> None: if self.segmentation_algorithm is None: raise ValueError( "`segmentation_algorithm` must be a valid instance of a StrideSegmentation algorithm. Currently `None`" @@ -281,7 +282,7 @@ def _validate_parameters(self): if self.s_id_naming not in ["replace", "prefix"]: raise ValueError("Invalid value for `s_id_naming`") - def _validate_other_parameters(self): + def _validate_other_parameters(self) -> None: self._multi_dataset = is_sensor_data(self.data, check_acc=False, check_gyr=False) == "multi" self._multi_roi = is_regions_of_interest_list(self.regions_of_interest, region_type="any") == "multi" if self._multi_roi and not self._multi_dataset: @@ -297,6 +298,6 @@ def _validate_other_parameters(self): missing_sensors = [key for key in self.regions_of_interest if key not in sensor_names] if len(missing_sensors) > 0: raise KeyError( - "The regions of interest list contains information for a sensor ({}) that is not in the " - "dataset.".format(missing_sensors) + f"The regions of interest list contains information for a sensor ({missing_sensors}) that is not " + "in the dataset." ) diff --git a/gaitmap/stride_segmentation/_utils.py b/gaitmap/stride_segmentation/_utils.py index ec78affe..ea18e784 100644 --- a/gaitmap/stride_segmentation/_utils.py +++ b/gaitmap/stride_segmentation/_utils.py @@ -1,4 +1,5 @@ """Some general utils for the stride segmentation algorithms.""" + from typing import Tuple, Union import numpy as np diff --git a/gaitmap/trajectory_reconstruction/_region_level_trajectory.py b/gaitmap/trajectory_reconstruction/_region_level_trajectory.py index 2a616751..84a24b67 100644 --- a/gaitmap/trajectory_reconstruction/_region_level_trajectory.py +++ b/gaitmap/trajectory_reconstruction/_region_level_trajectory.py @@ -1,4 +1,5 @@ """Wrapper to apply position and orientation estimation to multiple regions in a dataset.""" + from typing import Dict, List, Optional, Tuple, Union import pandas as pd @@ -139,9 +140,7 @@ class RegionLevelTrajectory(_TrajectoryReconstructionWrapperMixin, BaseTrajector >>> sampling_rate_hz = 204.8 >>> roi_list = ... >>> per_region_traj = per_region_traj.estimate( - ... data, - ... 
regions_of_interest=roi_list, - ... sampling_rate_hz=sampling_rate_hz + ... data, regions_of_interest=roi_list, sampling_rate_hz=sampling_rate_hz ... ) >>> per_region_traj.position_ @@ -164,10 +163,7 @@ class RegionLevelTrajectory(_TrajectoryReconstructionWrapperMixin, BaseTrajector >>> roi_list = ... >>> stride_list = ... >>> per_region_traj = per_region_traj.estimate_intersect( - ... data, - ... regions_of_interest=roi_list, - ... stride_event_list=stride_list, - ... sampling_rate_hz=sampling_rate_hz + ... data, regions_of_interest=roi_list, stride_event_list=stride_list, sampling_rate_hz=sampling_rate_hz ... ) >>> per_region_traj.position_ @@ -196,7 +192,7 @@ def __init__( pos_method: Optional[BasePositionMethod] = CloneFactory(ForwardBackwardIntegration()), trajectory_method: Optional[BaseTrajectoryMethod] = None, align_window_width: int = 8, - ): + ) -> None: # TODO: Make align window with a second value? self.align_window_width = align_window_width super().__init__(ori_method=ori_method, pos_method=pos_method, trajectory_method=trajectory_method) @@ -397,8 +393,8 @@ def intersect( ) from e if data_type != stride_list_type: raise ValidationError( - "You are trying to intersect the results from a {} sensor dataset with a {} " - "sensor stride list".format(data_type, stride_list_type) + f"You are trying to intersect the results from a {data_type} sensor dataset with a " + f"{stride_list_type} sensor stride list" ) if data_type == "single": data = self._intersect(data, self.regions_of_interest, stride_event_list) diff --git a/gaitmap/trajectory_reconstruction/_stride_level_trajectory.py b/gaitmap/trajectory_reconstruction/_stride_level_trajectory.py index 94789940..e8864c66 100644 --- a/gaitmap/trajectory_reconstruction/_stride_level_trajectory.py +++ b/gaitmap/trajectory_reconstruction/_stride_level_trajectory.py @@ -1,4 +1,5 @@ """Wrapper to apply position and orientation estimation to each stride of a dataset.""" + from typing import Optional from scipy.spatial.transform import Rotation @@ -112,9 +113,7 @@ class StrideLevelTrajectory(_TrajectoryReconstructionWrapperMixin, BaseTrajector >>> sampling_rate_hz = 204.8 >>> stride_list = ... >>> per_stride_traj = per_stride_traj.estimate( - ... data, - ... stride_event_list=stride_list, - ... sampling_rate_hz=sampling_rate_hz + ... data, stride_event_list=stride_list, sampling_rate_hz=sampling_rate_hz ... ) >>> per_stride_traj.position_ @@ -141,7 +140,7 @@ def __init__( pos_method: Optional[BasePositionMethod] = CloneFactory(ForwardBackwardIntegration()), trajectory_method: Optional[BaseTrajectoryMethod] = None, align_window_width: int = 8, - ): + ) -> None: # TODO: Make align window with a second value? self.align_window_width = align_window_width super().__init__(ori_method=ori_method, pos_method=pos_method, trajectory_method=trajectory_method) @@ -173,7 +172,7 @@ def estimate(self, data: SensorData, stride_event_list: StrideList, *, sampling_ if dataset_type != stride_list_type: raise ValidationError( "An invalid combination of stride list and dataset was provided. " - "The dataset is {} sensor and the stride list is {} sensor.".format(dataset_type, stride_list_type) + f"The dataset is {dataset_type} sensor and the stride list is {stride_list_type} sensor." 
) # For the per stride integration, we create a dummy stride list-list, containing only the single stride that is diff --git a/gaitmap/trajectory_reconstruction/_trajectory_wrapper.py b/gaitmap/trajectory_reconstruction/_trajectory_wrapper.py index 4dce1b27..3ca67a13 100644 --- a/gaitmap/trajectory_reconstruction/_trajectory_wrapper.py +++ b/gaitmap/trajectory_reconstruction/_trajectory_wrapper.py @@ -1,4 +1,5 @@ """A helper class with common utilities for TrajectoryReconstructionWrapper classes.""" + import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union @@ -44,12 +45,12 @@ def __init__( ori_method: Optional[BaseOrientationMethod] = cf(SimpleGyroIntegration()), pos_method: Optional[BasePositionMethod] = cf(ForwardBackwardIntegration()), trajectory_method: Optional[BaseTrajectoryMethod] = None, - ): + ) -> None: self.ori_method = ori_method self.pos_method = pos_method self.trajectory_method = trajectory_method - def _validate_methods(self): + def _validate_methods(self) -> None: if self.trajectory_method: if self.ori_method or self.pos_method: warnings.warn( diff --git a/gaitmap/trajectory_reconstruction/orientation_methods/_madgwick.py b/gaitmap/trajectory_reconstruction/orientation_methods/_madgwick.py index 592dfe76..9b19c97f 100644 --- a/gaitmap/trajectory_reconstruction/orientation_methods/_madgwick.py +++ b/gaitmap/trajectory_reconstruction/orientation_methods/_madgwick.py @@ -1,4 +1,5 @@ """Implementation of the MadgwickAHRS.""" + from typing import Optional, Union import numpy as np @@ -103,7 +104,7 @@ def __init__( beta: float = 0.2, initial_orientation: Union[np.ndarray, Rotation] = cf(np.array([0, 0, 0, 1.0])), memory: Optional[Memory] = None, - ): + ) -> None: self.initial_orientation = initial_orientation self.beta = beta self.memory = memory diff --git a/gaitmap/trajectory_reconstruction/orientation_methods/_simple_gyro_integration.py b/gaitmap/trajectory_reconstruction/orientation_methods/_simple_gyro_integration.py index 13417585..5d71c77a 100644 --- a/gaitmap/trajectory_reconstruction/orientation_methods/_simple_gyro_integration.py +++ b/gaitmap/trajectory_reconstruction/orientation_methods/_simple_gyro_integration.py @@ -1,4 +1,5 @@ """Naive integration of the gyroscope to estimate the orientation.""" + from typing import Optional, Union import numpy as np @@ -82,7 +83,7 @@ def __init__( self, initial_orientation: Union[np.ndarray, Rotation] = cf(np.array([0, 0, 0, 1.0])), memory: Optional[Memory] = None, - ): + ) -> None: self.initial_orientation = initial_orientation self.memory = memory diff --git a/gaitmap/trajectory_reconstruction/position_methods/_forward_backwards_integration.py b/gaitmap/trajectory_reconstruction/position_methods/_forward_backwards_integration.py index 9776015b..ef43a23f 100644 --- a/gaitmap/trajectory_reconstruction/position_methods/_forward_backwards_integration.py +++ b/gaitmap/trajectory_reconstruction/position_methods/_forward_backwards_integration.py @@ -113,7 +113,7 @@ def __init__( steepness: float = 0.08, level_assumption: bool = True, gravity: Optional[np.ndarray] = cf(GRAV_VEC), - ): + ) -> None: self.turning_point = turning_point self.steepness = steepness self.level_assumption = level_assumption diff --git a/gaitmap/trajectory_reconstruction/trajectory_methods/_kalman_numba_funcs.py b/gaitmap/trajectory_reconstruction/trajectory_methods/_kalman_numba_funcs.py index 359baeb5..88243f1c 100644 --- a/gaitmap/trajectory_reconstruction/trajectory_methods/_kalman_numba_funcs.py +++
b/gaitmap/trajectory_reconstruction/trajectory_methods/_kalman_numba_funcs.py @@ -83,7 +83,7 @@ def madgwick_motion_update(acc, gyro, orientation, position, velocity, sampling_ @njit() -def default_rts_kalman_forward_pass( # pylint: disable=too-many-statements +def default_rts_kalman_forward_pass( # pylint: disable=too-many-statements # noqa: PLR0915 accel, gyro, initial_orientation, diff --git a/gaitmap/trajectory_reconstruction/trajectory_methods/_rts_kalman.py b/gaitmap/trajectory_reconstruction/trajectory_methods/_rts_kalman.py index 642ac4b4..3e5f3cef 100644 --- a/gaitmap/trajectory_reconstruction/trajectory_methods/_rts_kalman.py +++ b/gaitmap/trajectory_reconstruction/trajectory_methods/_rts_kalman.py @@ -1,4 +1,5 @@ """An error state Kalman filter with Rauch-Tung-Striebel smoothing for estimating trajectories.""" + from typing import Optional, Union import numpy as np @@ -115,14 +116,13 @@ class RtsKalman(BaseTrajectoryMethod): >>> data = pd.DataFrame(..., columns=SF_COLS) >>> sampling_rate_hz = 100 >>> # Create an algorithm instance - >>> kalman = RtsKalman(initial_orientation=np.array([0, 0, 0, 1.0]), - ... zupt_variance=10e-8, - ... velocity_error_variance=10e5, - ... orientation_error_variance=10e-2, - ... zupt_detector=NormZuptDetector(semsor="gyr", - ... window_length_s=0.05 - ... ) - ... ) + >>> kalman = RtsKalman( + ... initial_orientation=np.array([0, 0, 0, 1.0]), + ... zupt_variance=10e-8, + ... velocity_error_variance=10e5, + ... orientation_error_variance=10e-2, + ... zupt_detector=NormZuptDetector(sensor="gyr", window_length_s=0.05), + ... ) >>> # Apply the algorithm >>> kalman = kalman.estimate(data, sampling_rate_hz=sampling_rate_hz) >>> # Inspect the results @@ -177,7 +177,7 @@ def __init__( sensor="gyr", window_length_s=0.05, window_overlap=0.5, metric="maximum", inactive_signal_threshold=34.0 ) ), - ): + ) -> None: self.initial_orientation = initial_orientation self.zupt_variance = zupt_variance self.velocity_error_variance = velocity_error_variance @@ -374,7 +374,7 @@ def __init__( ) ), madgwick_beta: float = 0.2, - ): + ) -> None: self.madgwick_beta = madgwick_beta super().__init__( initial_orientation=initial_orientation, diff --git a/gaitmap/utils/_algo_helper.py b/gaitmap/utils/_algo_helper.py index 52687aa5..cdb5667a 100644 --- a/gaitmap/utils/_algo_helper.py +++ b/gaitmap/utils/_algo_helper.py @@ -1,4 +1,5 @@ """A set of helper functions to make developing algorithms easier.""" + from __future__ import annotations from typing import Any @@ -7,7 +8,7 @@ def invert_result_dictionary( - nested_dict: dict[_Hashable, dict[_HashableVar, Any]] + nested_dict: dict[_Hashable, dict[_HashableVar, Any]], ) -> dict[_HashableVar, dict[_Hashable, Any]]: """Invert result dictionaries that are obtained from multi sensor results. @@ -33,7 +34,7 @@ def invert_result_dictionary( return out -def set_params_from_dict(obj: Any, param_dict: dict[str, Any], result_formatting: bool = False): +def set_params_from_dict(obj: Any, param_dict: dict[str, Any], result_formatting: bool = False) -> None: """Update object attributes from dictionary. The object will be updated inplace.
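The hunk above only reformats `invert_result_dictionary`; judging from its docstring and type hints, the function swaps the two nesting levels of a multi-sensor result dict. A minimal sketch, with sensor and result names invented for illustration:

from gaitmap.utils._algo_helper import invert_result_dictionary

# Hypothetical per-sensor results, as a multi-sensor algorithm would collect them.
per_sensor_results = {
    "left_sensor": {"stride_list": "left strides", "orientation": "left ori"},
    "right_sensor": {"stride_list": "right strides", "orientation": "right ori"},
}
# After inversion the result names form the outer level and the sensor names the inner one.
inverted = invert_result_dictionary(per_sensor_results)
assert inverted["stride_list"] == {"left_sensor": "left strides", "right_sensor": "right strides"}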
diff --git a/gaitmap/utils/_datatype_validation_helper.py b/gaitmap/utils/_datatype_validation_helper.py index 7bb824f2..04064728 100644 --- a/gaitmap/utils/_datatype_validation_helper.py +++ b/gaitmap/utils/_datatype_validation_helper.py @@ -1,4 +1,5 @@ """Internal helpers for dataset validation.""" + from typing import Dict, Iterable, List, Sequence, Tuple, Union import pandas as pd @@ -33,13 +34,13 @@ def _get_expected_dataset_cols( return expected_cols -def _assert_is_dtype(obj, dtype: Union[type, Tuple[type, ...]]): +def _assert_is_dtype(obj, dtype: Union[type, Tuple[type, ...]]) -> None: """Check if an object has a specific dtype.""" if not isinstance(obj, dtype): raise ValidationError(f"The dataobject is expected to be one of ({dtype},). But it is a {type(obj)}") -def _assert_has_multindex_cols(df: pd.DataFrame, nlevels: int = 2, expected: bool = True): +def _assert_has_multindex_cols(df: pd.DataFrame, nlevels: int = 2, expected: bool = True) -> None: """Check if a pd.DataFrame has a multiindex as columns. Parameters @@ -57,20 +58,20 @@ def _assert_has_multindex_cols(df: pd.DataFrame, nlevels: int = 2, expected: boo if expected is False: raise ValidationError( "The dataframe is expected to have a single level of columns. " - "But it has a MultiIndex with {} levels.".format(df.columns.nlevels) + f"But it has a MultiIndex with {df.columns.nlevels} levels." ) raise ValidationError( - "The dataframe is expected to have a MultiIndex with {} levels as columns. " - "It has just a single normal column level.".format(nlevels) + f"The dataframe is expected to have a MultiIndex with {nlevels} levels as columns. " + "It has just a single normal column level." ) if has_multiindex is True and not df.columns.nlevels == nlevels: raise ValidationError( - "The dataframe is expected to have a MultiIndex with {} levels as columns. " - "It has a MultiIndex with {} levels.".format(nlevels, df.columns.nlevels) + f"The dataframe is expected to have a MultiIndex with {nlevels} levels as columns. " + f"It has a MultiIndex with {df.columns.nlevels} levels." ) -def _assert_has_columns(df: pd.DataFrame, columns_sets: Sequence[Union[List[_Hashable], List[str]]]): +def _assert_has_columns(df: pd.DataFrame, columns_sets: Sequence[Union[List[_Hashable], List[str]]]) -> None: """Check if the dataframe has at least all columns sets. Examples @@ -93,34 +94,27 @@ def _assert_has_columns(df: pd.DataFrame, columns_sets: Sequence[Union[List[_Has else: helper_str = f"one of the following sets of columns: {columns_sets}" raise ValidationError( - "The dataframe is expected to have {}. Instead it has the following columns: {}".format( - helper_str, list(df.columns) - ) + f"The dataframe is expected to have {helper_str}. 
Instead it has the following columns: {list(df.columns)}" ) -def _assert_has_index_columns(df: pd.DataFrame, index_cols: Iterable[_Hashable]): +def _assert_has_index_columns(df: pd.DataFrame, index_cols: Iterable[_Hashable]) -> None: ex_index_cols = list(index_cols) ac_index_cols = list(df.index.names) if ex_index_cols != ac_index_cols: raise ValidationError( - "The dataframe is expected to have exactly the following index columns ({}), " - "but it has {}".format(index_cols, df.index.name) + f"The dataframe is expected to have exactly the following index columns ({index_cols}), " + f"but it has {df.index.names}" ) # This function exists to avoid cyclic imports in this module def _get_multi_sensor_data_names(dataset: Union[dict, pd.DataFrame]) -> Sequence[str]: - if isinstance(dataset, pd.DataFrame): - keys = dataset.columns.unique(level=0) - else: - # In case it is a dict - keys = dataset.keys() - + keys = dataset.columns.unique(level=0) if isinstance(dataset, pd.DataFrame) else dataset.keys() return keys -def _assert_multisensor_is_not_empty(obj: Union[pd.DataFrame, Dict]): +def _assert_multisensor_is_not_empty(obj: Union[pd.DataFrame, Dict]) -> None: sensors = _get_multi_sensor_data_names(obj) if len(sensors) == 0: raise ValidationError("The provided multi-sensor object does not contain any data/contains no sensors.") diff --git a/gaitmap/utils/_gaitmap_mad.py b/gaitmap/utils/_gaitmap_mad.py index 806bd71d..afa604d9 100644 --- a/gaitmap/utils/_gaitmap_mad.py +++ b/gaitmap/utils/_gaitmap_mad.py @@ -1,4 +1,5 @@ """Helper functions to handle the gaitmap/gaitmap_mad split.""" + from importlib.util import find_spec import gaitmap @@ -9,12 +10,13 @@ def patch_gaitmap_mad_import(_gaitmap_mad_modules, current_module_name): if find_spec("gaitmap_mad"): import gaitmap_mad # pylint: disable=import-outside-toplevel - assert (gm_version := gaitmap_mad.__version__) == (g_version := gaitmap.__version__), ( - "We only support using the exact same version of `gaitmap` and `gaitmap_mad`. " - f"Currently you have the versions `gaitmap`: v{g_version} and `gaitmap_mad`: v{gm_version}. " - "Update the `gaitmap` and `gaitmap_mad` packages to the same version (likely you just forgot to update " - "`gaitmap_mad` when you updated `gaitmap`)." - ) + if (gm_version := gaitmap_mad.__version__) != (g_version := gaitmap.__version__): + raise ImportError( + "We only support using the exact same version of `gaitmap` and `gaitmap_mad`. " + f"Currently you have the versions `gaitmap`: v{g_version} and `gaitmap_mad`: v{gm_version}. " + "Update the `gaitmap` and `gaitmap_mad` packages to the same version (likely you just forgot to update " + "`gaitmap_mad` when you updated `gaitmap`)." + ) return None from gaitmap.utils.exceptions import GaitmapMadImportError # pylint: disable=import-outside-toplevel diff --git a/gaitmap/utils/_types.py b/gaitmap/utils/_types.py index 86dc51a6..199f3877 100644 --- a/gaitmap/utils/_types.py +++ b/gaitmap/utils/_types.py @@ -2,6 +2,7 @@ For user facing type declarations, please see `gaitmap.utils.datatype_helper`.
""" + from typing import TYPE_CHECKING, Any, Hashable, TypeVar, Union import pandas as pd diff --git a/gaitmap/utils/array_handling.py b/gaitmap/utils/array_handling.py index 7309a885..ca045b2c 100644 --- a/gaitmap/utils/array_handling.py +++ b/gaitmap/utils/array_handling.py @@ -1,4 +1,5 @@ """A set of util functions that help to manipulate arrays in any imaginable way.""" + from typing import Iterable, Iterator, List, Optional, Tuple, Union import numba.typed @@ -52,8 +53,8 @@ def sliding_window_view(arr: np.ndarray, window_length: int, overlap: int, nan_p Examples -------- - >>> data = np.arange(0,10) - >>> windowed_view = sliding_window_view(arr = data, window_length = 5, overlap = 3, nan_padding = True) + >>> data = np.arange(0, 10) + >>> windowed_view = sliding_window_view(arr=data, window_length=5, overlap=3, nan_padding=True) >>> windowed_view array([[ 0., 1., 2., 3., 4.], [ 2., 3., 4., 5., 6.], @@ -114,12 +115,12 @@ def bool_array_to_start_end_array(bool_array: np.ndarray) -> np.ndarray: Examples -------- - >>> example_array = np.array([0,0,1,1,0,0,1,1,1]) + >>> example_array = np.array([0, 0, 1, 1, 0, 0, 1, 1, 1]) >>> start_end_list = bool_array_to_start_end_array(example_array) >>> start_end_list array([[2, 4], [6, 9]]) - >>> example_array[start_end_list[0, 0]: start_end_list[0, 1]] + >>> example_array[start_end_list[0, 0] : start_end_list[0, 1]] array([1, 1]) """ @@ -138,7 +139,7 @@ def bool_array_to_start_end_array(bool_array: np.ndarray) -> np.ndarray: return np.array([[s.start, s.stop] for s in slices]) -def start_end_array_to_bool_array(start_end_array: np.ndarray, pad_to_length: int = None) -> np.ndarray: +def start_end_array_to_bool_array(start_end_array: np.ndarray, pad_to_length: Optional[int] = None) -> np.ndarray: """Convert a start-end list to a bool array. Parameters @@ -162,11 +163,11 @@ def start_end_array_to_bool_array(start_end_array: np.ndarray, pad_to_length: in Examples -------- >>> import numpy as np - >>> example_array = np.array([[3,5],[7,8]]) + >>> example_array = np.array([[3, 5], [7, 8]]) >>> start_end_array_to_bool_array(example_array, pad_to_length=12) array([False, False, False, True, True, True, False, True, True, False, False, False]) - >>> example_array = np.array([[3,5],[7,8]]) + >>> example_array = np.array([[3, 5], [7, 8]]) >>> start_end_array_to_bool_array(example_array, pad_to_length=None) array([False, False, False, True, True, True, False, True, True]) @@ -273,7 +274,7 @@ def find_extrema_in_radius( if extrema_type not in extrema_funcs: raise ValueError(f"`extrema_type` must be one of {list(extrema_funcs.keys())}, not {extrema_type}") extrema_func = extrema_funcs[extrema_type] - if radius == 0 or radius == (0, 0): + if radius in (0, (0, 0)): # In case the search radius is 0 samples, we can just return the input. return indices if isinstance(radius, int): @@ -353,7 +354,6 @@ def multi_array_interpolation(arrays: List[np.ndarray], n_samples, kind: str = " @numba.njit() def _fast_linear_interpolation(arrays: numba.typed.List, n_samples) -> np.ndarray: - final_array = np.empty((len(arrays), arrays[0].shape[1], n_samples)) for i, s in enumerate(arrays): s_len = len(s) @@ -456,8 +456,8 @@ def iterate_region_data( "The label sequences must be either SingleSensorStrideList or SingleSensorRegionsOfInterestList. 
" "\n" "The validations failed with the following errors: \n\n" - f"Stride List: \n\n{str(e_stride_list)}\n\n" - f"Regions of Interest List: \n\n{str(e_roi)}" + f"Stride List: \n\n{e_stride_list!s}\n\n" + f"Regions of Interest List: \n\n{e_roi!s}" ) from e_roi if expected_col_order is None: # In the first iteration we pull the column order. @@ -465,5 +465,5 @@ def iterate_region_data( # user to put all the data in RAM first. expected_col_order = df.columns df = df.reindex(columns=expected_col_order) - for (_, s, e) in labels[["start", "end"]].itertuples(): + for _, s, e in labels[["start", "end"]].itertuples(): yield df.iloc[s:e] diff --git a/gaitmap/utils/consts.py b/gaitmap/utils/consts.py index 9b383dd6..92b24406 100644 --- a/gaitmap/utils/consts.py +++ b/gaitmap/utils/consts.py @@ -1,4 +1,5 @@ """Common constants used in the library.""" + import numpy as np #: The default names of the Gyroscope columns in the sensor frame diff --git a/gaitmap/utils/coordinate_conversion.py b/gaitmap/utils/coordinate_conversion.py index 9f73a2d6..fd98e30d 100644 --- a/gaitmap/utils/coordinate_conversion.py +++ b/gaitmap/utils/coordinate_conversion.py @@ -2,6 +2,7 @@ Definitions can be found in the :ref:`coordinate_systems` guide. """ + import warnings from typing import List, Optional @@ -83,8 +84,8 @@ def convert_to_fbf( data: MultiSensorData, left: Optional[List[str]] = None, right: Optional[List[str]] = None, - right_like: str = None, - left_like: str = None, + right_like: Optional[str] = None, + left_like: Optional[str] = None, ): """Convert the axes from the sensor frame to the body frame for one MultiSensorDataset. @@ -120,7 +121,7 @@ def convert_to_fbf( -------- These examples assume that your dataset has two sensors called `left_sensor` and `right_sensor`. - >>> dataset = ... # Sensordata in FSF + >>> dataset = ... # Sensordata in FSF >>> fbf_dataset = convert_to_fbf(dataset, left_like="left_", right_like="right_") Alternatively, you can specify the full sensor names. @@ -164,9 +165,8 @@ def _handle_foot(foot, foot_like, data, rot_func): foot = [sensor for sensor in get_multi_sensor_names(data) if foot_like in sensor] if not foot: warnings.warn( - "The substring {} is not contained in any sensor name. Available sensor names are: {}".format( - foot_like, get_multi_sensor_names(data) - ) + f"The substring {foot_like} is not contained in any sensor name. Available sensor names are: " + f"{get_multi_sensor_names(data)}" ) foot = foot or [] for s in foot: diff --git a/gaitmap/utils/datatype_helper.py b/gaitmap/utils/datatype_helper.py index 542f4084..6ad5edf3 100644 --- a/gaitmap/utils/datatype_helper.py +++ b/gaitmap/utils/datatype_helper.py @@ -1,4 +1,5 @@ """A couple of helper functions that easy the use of the typical gaitmap data formats.""" + from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast import numpy as np @@ -135,7 +136,7 @@ def is_single_sensor_data( if raise_exception is True: raise ValidationError( "The passed object does not seem to be SingleSensorData. " - "The validation failed with the following error:\n\n{}".format(str(e)) + f"The validation failed with the following error:\n\n{e!s}" ) from e return False return True @@ -194,7 +195,7 @@ def is_multi_sensor_data( if raise_exception is True: raise ValidationError( "The passed object does not seem to be MultiSensorData. 
" - "The validation failed with the following error:\n\n{}".format(str(e)) + f"The validation failed with the following error:\n\n{e!s}" ) from e return False @@ -205,9 +206,7 @@ def is_multi_sensor_data( if raise_exception is True: raise ValidationError( "The passed object appears to be MultiSensorData, " - 'but for the sensor with the name "{}", the following validation error was raised:\n\n{}'.format( - k, str(e) - ) + f'but for the sensor with the name "{k}", the following validation error was raised:\n\n{e!s}' ) from e return False return True @@ -269,10 +268,10 @@ def is_sensor_data( "Below you can find the errors raised for both checks:\n\n" "Single-Sensor\n" "=============\n" - f"{str(single_error)}\n\n" + f"{single_error!s}\n\n" "Multi-Sensor\n" "=============\n" - f"{str(multi_error)}" + f"{multi_error!s}" ) @@ -365,8 +364,8 @@ def is_single_sensor_stride_list( and not np.array_equal(stride_list["start"].to_numpy(), stride_list[start_event[stride_type]].to_numpy()) ): raise ValidationError( - "For a {} stride list, the start column is expected to be identical to the {} column, " - "but they are different.".format(stride_type, start_event[stride_type]) + f"For a {stride_type} stride list, the start column is expected to be identical to the " + f"{start_event[stride_type]} column, but they are different." ) # Check that the stride ids are unique if not stride_list.index.nunique() == stride_list.index.size: @@ -376,7 +375,7 @@ def is_single_sensor_stride_list( if raise_exception is True: raise ValidationError( "The passed object does not seem to be a SingleSensorStrideList. " - "The validation failed with the following error:\n\n{}".format(str(e)) + f"The validation failed with the following error:\n\n{e!s}" ) from e return False return True @@ -424,7 +423,7 @@ def is_multi_sensor_stride_list( if raise_exception is True: raise ValidationError( "The passed object does not seem to be a MultiSensorStrideList. 
" - "The validation failed with the following error:\n\n{}".format(str(e)) + f"The validation failed with the following error:\n\n{e!s}" ) from e return False @@ -440,9 +439,7 @@ def is_multi_sensor_stride_list( if raise_exception is True: raise ValidationError( "The passed object appears to be a MultiSensorStrideList, " - 'but for the sensor with the name "{}", the following validation error was raised:\n\n{}'.format( - k, str(e) - ) + f'but for the sensor with the name "{k}", the following validation error was raised:\n\n{e!s}' ) from e return False return True @@ -508,10 +505,10 @@ def is_stride_list( "Below you can find the errors raised for both checks:\n\n" "Single-Sensor\n" "=============\n" - f"{str(single_error)}\n\n" + f"{single_error!s}\n\n" "Multi-Sensor\n" "=============\n" - f"{str(multi_error)}" + f"{multi_error!s}" ) @@ -522,8 +519,8 @@ def get_single_sensor_regions_of_interest_types(roi_list: SingleSensorRegionsOfI matched_index_col = [col for col in roi_list_columns if col in valid_index_dict.values()] if not matched_index_col: raise ValidationError( - "The region of interest list is expected to have one of {} either as a column or in the " - "index".format(list(valid_index_dict.values())) + f"The region of interest list is expected to have one of {list(valid_index_dict.values())} either as a " + "column or in the index" ) region_type = cast( Literal["roi", "gs"], list(valid_index_dict.keys())[list(valid_index_dict.values()).index(matched_index_col[0])] @@ -574,10 +571,9 @@ def is_single_sensor_regions_of_interest_list( actual_region_type = get_single_sensor_regions_of_interest_types(roi_list) if region_type not in ("any", actual_region_type): raise ValidationError( - "A ROI list of type {} is expected to have a either an index or a column named {}. " - "The provided ROI list appears to be of the type {} instead.".format( - region_type, ROI_ID_COLS[region_type], actual_region_type - ) + f"A ROI list of type {region_type} is expected to have a either an index or a column named " + f"{ROI_ID_COLS[region_type]}." + "The provided ROI list appears to be of the type {actual_region_type} instead." ) roi_list = set_correct_index(roi_list, [ROI_ID_COLS[actual_region_type]]) @@ -590,7 +586,7 @@ def is_single_sensor_regions_of_interest_list( if raise_exception is True: raise ValidationError( "The passed object does not seem to be a SingleSensorRegionsOfInterestList. " - "The validation failed with the following error:\n\n{}".format(str(e)) + f"The validation failed with the following error:\n\n{e!s}" ) from e return False @@ -632,7 +628,7 @@ def is_multi_sensor_regions_of_interest_list( if raise_exception is True: raise ValidationError( "The passed object does not seem to be a MultiSensorRegionsOfInterestList. 
" - "The validation failed with the following error:\n\n{}".format(str(e)) + f"The validation failed with the following error:\n\n{e!s}" ) from e return False @@ -643,9 +639,7 @@ def is_multi_sensor_regions_of_interest_list( if raise_exception is True: raise ValidationError( "The passed object appears to be a MultiSensorRegionsOfInterestList, " - 'but for the sensor with the name "{}", the following validation error was raised:\n\n{}'.format( - k, str(e) - ) + f'but for the sensor with the name "{k}", the following validation error was raised:\n\n{e!s}' ) from e return False return True @@ -703,10 +697,10 @@ def is_regions_of_interest_list( "Below you can find the errors raised for both checks:\n\n" "Single-Sensor\n" "=============\n" - f"{str(single_error)}\n\n" + f"{single_error!s}\n\n" "Multi-Sensor\n" "=============\n" - f"{str(multi_error)}" + f"{multi_error!s}" ) @@ -725,7 +719,7 @@ def get_multi_sensor_names(dataset: MultiSensorData) -> Sequence[_Hashable]: def get_single_sensor_trajectory_list_types( - traj_list: Union[SingleSensorPositionList, SingleSensorOrientationList, SingleSensorVelocityList] + traj_list: Union[SingleSensorPositionList, SingleSensorOrientationList, SingleSensorVelocityList], ) -> Literal["roi", "gs", "stride"]: """Identify which type of trajectory list is passed by checking the existing columns.""" traj_list_columns = traj_list.reset_index().columns @@ -733,8 +727,8 @@ def get_single_sensor_trajectory_list_types( matched_index_col = [col for col in traj_list_columns if col in valid_index_dict.values()] if not matched_index_col: raise ValidationError( - "The trajectory (orientation, position, velocity) list is expected to have one of {} either as a column or " - "in the index".format(list(valid_index_dict.values())) + "The trajectory (orientation, position, velocity) list is expected to have one of " + f"{list(valid_index_dict.values())} either as a column or in the index" ) list_type = cast( Literal["roi", "gs", "stride"], @@ -771,8 +765,8 @@ def _is_single_sensor_trajectory_list( except ValidationError as e: if raise_exception is True: raise ValidationError( - "The passed object does not seem to be a {}. " - "The validation failed with the following error:\n\n{}".format(input_datatype, str(e)) + f"The passed object does not seem to be a {input_datatype}. " + f"The validation failed with the following error:\n\n{e!s}" ) from e return False return True @@ -791,8 +785,8 @@ def _is_multi_sensor_trajectory_list( except ValidationError as e: if raise_exception is True: raise ValidationError( - "The passed object does not seem to be a {}. " - "The validation failed with the following error:\n\n{}".format(input_datatype, str(e)) + f"The passed object does not seem to be a {input_datatype}. 
" + f"The validation failed with the following error:\n\n{e!s}" ) from e return False @@ -802,10 +796,8 @@ def _is_multi_sensor_trajectory_list( except ValidationError as e: if raise_exception is True: raise ValidationError( - "The passed object appears to be a {}, " - 'but for the sensor with the name "{}", the following validation error was raised:\n\n{}'.format( - input_datatype, k, str(e) - ) + f"The passed object appears to be a {input_datatype}, " + f'but for the sensor with the name "{k}", the following validation error was raised:\n\n{e!s}' ) from e return False return True @@ -837,10 +829,10 @@ def _is_trajectory_list( "Below you can find the errors raised for both checks:\n\n" "Single-Sensor\n" "=============\n" - f"{str(single_error)}\n\n" + f"{single_error!s}\n\n" "Multi-Sensor\n" "=============\n" - f"{str(multi_error)}" + f"{multi_error!s}" ) @@ -1254,8 +1246,8 @@ def set_correct_index( _assert_has_columns(df_just_right_index, [index_cols]) except ValidationError as e: raise ValidationError( - "The dataframe is expected to have the following columns either in the index or as columns ({}), " - "but it has {}".format(index_cols, df.columns) + "The dataframe is expected to have the following columns either in the index or as columns " + f"({index_cols}), but it has {df.columns}" ) from e return df_just_right_index.set_index(index_cols) diff --git a/gaitmap/utils/exceptions.py b/gaitmap/utils/exceptions.py index d04fa67a..8ba12273 100644 --- a/gaitmap/utils/exceptions.py +++ b/gaitmap/utils/exceptions.py @@ -8,12 +8,12 @@ class ValidationError(Exception): class GaitmapMadImportError(ImportError): """An error indicating that the algorithm is implemented in gaitmap-mad and not gaitmap.""" - def __init__(self, object_name: str, module_name: str): + def __init__(self, object_name: str, module_name: str) -> None: self.object_name = object_name self.module_name = module_name super().__init__() - def __str__(self): + def __str__(self) -> str: """Return a string representation of the error.""" return ( f"You are trying to import {self.object_name} from {self.module_name}." diff --git a/gaitmap/utils/fast_quaternion_math.py b/gaitmap/utils/fast_quaternion_math.py index a7d022fd..60a21660 100644 --- a/gaitmap/utils/fast_quaternion_math.py +++ b/gaitmap/utils/fast_quaternion_math.py @@ -2,6 +2,7 @@ Note that we follow the same order as :class:`~scipy.spatial.transform.Rotation` (x, y, z, w). """ + import numpy as np from numba import njit diff --git a/gaitmap/utils/rotations.py b/gaitmap/utils/rotations.py index 3dd2f563..2b393d7b 100644 --- a/gaitmap/utils/rotations.py +++ b/gaitmap/utils/rotations.py @@ -2,6 +2,7 @@ All util functions use :class:`scipy.spatial.transform.Rotation` to represent rotations. """ + from typing import Callable, Dict, List, Optional, Union import numpy as np @@ -41,7 +42,7 @@ def rotation_from_angle(axis: np.ndarray, angle: Union[float, np.ndarray]) -> Ro >>> rot = rotation_from_angle(np.array([1, 0, 0]), np.deg2rad(180)) >>> rot.as_quat().round(decimals=3) array([1., 0., 0., 0.]) - >>> rot.apply(np.array([[0, 0, 1.], [0, 1, 0.]])).round() + >>> rot.apply(np.array([[0, 0, 1.0], [0, 1, 0.0]])).round() array([[ 0., -0., -1.], [ 0., -1., 0.]]) @@ -53,7 +54,7 @@ def rotation_from_angle(axis: np.ndarray, angle: Union[float, np.ndarray]) -> Ro [1. , 0. , 0. , 0. 
]]) >>> # In case of multiple rotations, the first rotation is applied to the first vector >>> # and the second to the second - >>> rot.apply(np.array([[0, 0, 1.], [0, 1, 0.]])).round() + >>> rot.apply(np.array([[0, 0, 1.0], [0, 1, 0.0]])).round() array([[ 0., -1., 0.], [ 0., -1., 0.]]) @@ -340,7 +341,7 @@ def find_rotation_around_axis(rot: Rotation, rotation_axis: Union[np.ndarray, Li Examples -------- >>> # Create composite rotation around y and z axis - >>> rot = Rotation.from_rotvec([0, 0, np.pi / 2]) * Rotation.from_rotvec([0, np.pi / 4, 0 ]) + >>> rot = Rotation.from_rotvec([0, 0, np.pi / 2]) * Rotation.from_rotvec([0, np.pi / 4, 0]) >>> find_rotation_around_axis(rot, [0, 0, 1]).as_rotvec() # Extract part around z array([0. , 0. , 1.57079633]) >>> find_rotation_around_axis(rot, [0, 1, 0]).as_rotvec() # Extract part around y @@ -444,7 +445,7 @@ def find_unsigned_3d_angle(v1: np.ndarray, v2: np.ndarray) -> Union[np.ndarray, two vectors: 2D - >>> find_unsigned_3d_angle(np.array([[-1, 0, 0],[-1, 0, 0]]), np.array([[-1, 0, 0],[-1, 0, 0]])) + >>> find_unsigned_3d_angle(np.array([[-1, 0, 0], [-1, 0, 0]]), np.array([[-1, 0, 0], [-1, 0, 0]])) array([0,0]) """ diff --git a/gaitmap/utils/signal_processing.py b/gaitmap/utils/signal_processing.py index e155bdd3..4ee14fca 100644 --- a/gaitmap/utils/signal_processing.py +++ b/gaitmap/utils/signal_processing.py @@ -25,8 +25,8 @@ def butter_lowpass_filter_1d(data: np.ndarray, sampling_rate_hz: float, cutoff_f Examples -------- - >>> data = np.arange(0,100) - >>> data_filtered = butter_lowpass_filter_1d(data = data, sampling_rate_hz = 10, cutoff_freq_hz = 1, order = 4) + >>> data = np.arange(0, 100) + >>> data_filtered = butter_lowpass_filter_1d(data=data, sampling_rate_hz=10, cutoff_freq_hz=1, order=4) >>> data_filtered array([0.00000000e+00, 4.82434336e-03, 4.03774045e-02, 1.66525148e-01,...]) @@ -53,7 +53,7 @@ def row_wise_autocorrelation(array: np.ndarray, lag_max: int): Examples -------- - >>> t = np.arange(0,1,0.1) + >>> t = np.arange(0, 1, 0.1) >>> sin_wave = np.sin(t) >>> array = np.array([sin_wave, sin_wave]) >>> out = row_wise_autocorrelation(array, 5) diff --git a/gaitmap/utils/static_moment_detection.py b/gaitmap/utils/static_moment_detection.py index 234fe1f9..805fa63e 100644 --- a/gaitmap/utils/static_moment_detection.py +++ b/gaitmap/utils/static_moment_detection.py @@ -1,4 +1,5 @@ """A set of util functions to detect static regions in an IMU signal given certain constraints.""" + from functools import partial from typing import Callable, Optional, Sequence, Tuple, get_args @@ -96,7 +97,7 @@ def find_static_samples( Examples -------- >>> test_data = load_gyro_data(path) - >>> get_static_moments(gyro_data, window_length=128, overlap=64, inactive_signal_th = 5, metric = 'mean') + >>> get_static_moments(gyro_data, window_length=128, overlap=64, inactive_signal_th=5, metric="mean") References ---------- @@ -119,7 +120,7 @@ def find_static_samples( if window_length > len(signal): raise ValueError( "Invalid window length, window must be smaller than or equal to the given signal length. Given signal length: " - "{} with given window_length: {}.".format(len(signal), window_length) + f"{len(signal)} with given window_length: {window_length}."
) # add default overlap value @@ -221,7 +222,7 @@ def find_static_sequences( window_length: int, inactive_signal_th: float, metric: METRIC_FUNCTION_NAMES = "mean", - overlap: int = None, + overlap: Optional[int] = None, ) -> np.ndarray: """Search for static sequences within given input signal, based on windowed L2-norm thresholding. diff --git a/gaitmap/utils/stride_list_conversion.py b/gaitmap/utils/stride_list_conversion.py index af699f7a..8625425e 100644 --- a/gaitmap/utils/stride_list_conversion.py +++ b/gaitmap/utils/stride_list_conversion.py @@ -1,4 +1,5 @@ """A couple of utils to convert stride lists into different formats.""" + from typing import List, Tuple import numpy as np diff --git a/gaitmap/utils/vector_math.py b/gaitmap/utils/vector_math.py index ad308d19..c185de44 100644 --- a/gaitmap/utils/vector_math.py +++ b/gaitmap/utils/vector_math.py @@ -3,6 +3,7 @@ Wherever possible, these functions are designed to handle multiple vectors at the same time to perform efficient computations. """ + from typing import Union import numpy as np @@ -49,7 +50,7 @@ def is_almost_parallel_or_antiparallel( array of vectors - >>> is_almost_parallel_or_antiparallel(np.array([[0, 0, 1],[0,1,0]]), np.array([[0, 0, 2],[1,0,0]])) + >>> is_almost_parallel_or_antiparallel(np.array([[0, 0, 1], [0, 1, 0]]), np.array([[0, 0, 2], [1, 0, 0]])) array([True,False]) """ @@ -80,7 +81,7 @@ def normalize(v: np.ndarray) -> np.ndarray: 2D array - >>> normalize(np.array([[2, 0, 0],[2, 0, 0]])) + >>> normalize(np.array([[2, 0, 0], [2, 0, 0]])) array([[1., 0., 0.], [1., 0., 0.]]) @@ -143,7 +144,7 @@ def find_orthogonal(v1: np.ndarray, v2: np.ndarray) -> np.ndarray: Examples -------- - >>> find_orthogonal(np.array([1, 0, 0]),np.array([-1, 0, 0])) + >>> find_orthogonal(np.array([1, 0, 0]), np.array([-1, 0, 0])) array([0, 0, -1]) """ diff --git a/gaitmap/zupt_detection/__init__.py b/gaitmap/zupt_detection/__init__.py index fb20e100..8f8ade8c 100644 --- a/gaitmap/zupt_detection/__init__.py +++ b/gaitmap/zupt_detection/__init__.py @@ -1,4 +1,5 @@ """A set of methods to detect static regions/zero-velocity regions (ZUPTS) in a signal.""" + from gaitmap.zupt_detection._base import PerSampleZuptDetectorMixin, RegionZuptDetectorMixin from gaitmap.zupt_detection._combo_zupt_detector import ComboZuptDetector from gaitmap.zupt_detection._moving_window_zupt_detector import AredZuptDetector, NormZuptDetector, ShoeZuptDetector diff --git a/gaitmap/zupt_detection/_combo_zupt_detector.py b/gaitmap/zupt_detection/_combo_zupt_detector.py index 8d68d594..c3cfadda 100644 --- a/gaitmap/zupt_detection/_combo_zupt_detector.py +++ b/gaitmap/zupt_detection/_combo_zupt_detector.py @@ -47,7 +47,7 @@ class ComboZuptDetector(BaseZuptDetector, PerSampleZuptDetectorMixin): def __init__( self, detectors: Optional[List[Tuple[str, BaseZuptDetector]]] = None, operation: Literal["and", "or"] = "or" - ): + ) -> None: self.detectors = detectors self.operation = operation diff --git a/gaitmap/zupt_detection/_moving_window_zupt_detector.py b/gaitmap/zupt_detection/_moving_window_zupt_detector.py index 53cd46e1..dc4b62d4 100644 --- a/gaitmap/zupt_detection/_moving_window_zupt_detector.py +++ b/gaitmap/zupt_detection/_moving_window_zupt_detector.py @@ -1,4 +1,5 @@ """A Basic ZUPT detector based on moving windows on the norm.""" + from typing import Optional, Tuple import numpy as np @@ -184,7 +185,7 @@ def __init__( window_overlap_samples: Optional[int] = None, metric: METRIC_FUNCTION_NAMES = "mean", inactive_signal_threshold: float = 15, - ): + ) -> None: 
self.sensor = sensor self.window_length_s = window_length_s self.window_overlap = window_overlap @@ -333,7 +334,7 @@ def __init__( window_overlap_samples: Optional[int] = -1, metric: METRIC_FUNCTION_NAMES = "squared_mean", inactive_signal_threshold: float = 180, - ): + ) -> None: super().__init__( sensor=sensor, window_length_s=window_length_s, @@ -459,7 +460,7 @@ def __init__( window_overlap: Optional[float] = 0.5, window_overlap_samples: Optional[int] = None, inactive_signal_threshold: float = 2310129700, - ): + ) -> None: self.acc_noise_variance = acc_noise_variance self.gyr_noise_variance = gyr_noise_variance self.window_length_s = window_length_s diff --git a/gaitmap/zupt_detection/_stride_event_zupt_detector.py b/gaitmap/zupt_detection/_stride_event_zupt_detector.py index 107fafa9..aaeb2727 100644 --- a/gaitmap/zupt_detection/_stride_event_zupt_detector.py +++ b/gaitmap/zupt_detection/_stride_event_zupt_detector.py @@ -57,7 +57,7 @@ class StrideEventZuptDetector(BaseZuptDetector, RegionZuptDetectorMixin): half_region_size_s: float half_region_size_samples_: int - def __init__(self, half_region_size_s: float = 0.05): + def __init__(self, half_region_size_s: float = 0.05) -> None: self.half_region_size_s = half_region_size_s def detect( diff --git a/gaitmap_mad/gaitmap_mad/__init__.py b/gaitmap_mad/gaitmap_mad/__init__.py index 9775b90f..f28203fb 100644 --- a/gaitmap_mad/gaitmap_mad/__init__.py +++ b/gaitmap_mad/gaitmap_mad/__init__.py @@ -7,4 +7,5 @@ library. Note that we only support using the exact same version of the gaitmap library. """ + __version__ = "2.3.0" diff --git a/gaitmap_mad/gaitmap_mad/event_detection/__init__.py b/gaitmap_mad/gaitmap_mad/event_detection/__init__.py index a4c8c293..48d1ae62 100644 --- a/gaitmap_mad/gaitmap_mad/event_detection/__init__.py +++ b/gaitmap_mad/gaitmap_mad/event_detection/__init__.py @@ -2,6 +2,7 @@ Different algorithms for event detection are going to be collected here. """ + from gaitmap_mad.event_detection._filtered_rampp_event_detection import FilteredRamppEventDetection from gaitmap_mad.event_detection._rampp_event_detection import RamppEventDetection diff --git a/gaitmap_mad/gaitmap_mad/event_detection/_filtered_rampp_event_detection.py b/gaitmap_mad/gaitmap_mad/event_detection/_filtered_rampp_event_detection.py index 46acedd0..f5310a86 100644 --- a/gaitmap_mad/gaitmap_mad/event_detection/_filtered_rampp_event_detection.py +++ b/gaitmap_mad/gaitmap_mad/event_detection/_filtered_rampp_event_detection.py @@ -1,4 +1,5 @@ """The event detection algorithm by Rampp et al. 2014.""" + from typing import Dict, Optional, Tuple from joblib import Memory @@ -96,7 +97,7 @@ def __init__( memory: Optional[Memory] = None, enforce_consistency: bool = True, detect_only: Optional[Tuple[str, ...]] = None, - ): + ) -> None: self.ic_lowpass_filter = ic_lowpass_filter super().__init__( memory=memory, diff --git a/gaitmap_mad/gaitmap_mad/event_detection/_rampp_event_detection.py b/gaitmap_mad/gaitmap_mad/event_detection/_rampp_event_detection.py index 0de53493..63055915 100644 --- a/gaitmap_mad/gaitmap_mad/event_detection/_rampp_event_detection.py +++ b/gaitmap_mad/gaitmap_mad/event_detection/_rampp_event_detection.py @@ -1,4 +1,5 @@ """The event detection algorithm by Rampp et al.
2014.""" + from typing import Callable, Dict, Optional, Tuple, Union, cast import numpy as np @@ -161,7 +162,7 @@ def __init__( memory: Optional[Memory] = None, enforce_consistency: bool = True, detect_only: Optional[Tuple[str, ...]] = None, - ): + ) -> None: self.ic_search_region_ms = ic_search_region_ms self.min_vel_search_win_size_ms = min_vel_search_win_size_ms super().__init__(memory=memory, enforce_consistency=enforce_consistency, detect_only=detect_only) @@ -179,7 +180,7 @@ def _get_detect_kwargs(self) -> Dict[str, Union[Tuple[int, int], int]]: ) if all(v == 0 for v in ic_search_region): raise ValueError( - "The chosen values are smaller than the sample time ({} ms)".format((1 / self.sampling_rate_hz) * 1000) + f"The chosen values are smaller than the sample time ({(1 / self.sampling_rate_hz) * 1000} ms)" ) min_vel_search_win_size = int(self.min_vel_search_win_size_ms / 1000 * self.sampling_rate_hz) return { diff --git a/gaitmap_mad/gaitmap_mad/gait_detection/_ullrich_gait_sequence_detection.py b/gaitmap_mad/gaitmap_mad/gait_detection/_ullrich_gait_sequence_detection.py index f73b17dc..0f42ef24 100644 --- a/gaitmap_mad/gaitmap_mad/gait_detection/_ullrich_gait_sequence_detection.py +++ b/gaitmap_mad/gaitmap_mad/gait_detection/_ullrich_gait_sequence_detection.py @@ -1,7 +1,8 @@ """The gait sequence detection algorithm by Ullrich et al. 2020.""" + import copy import itertools -from typing import Dict, Tuple, TypeVar, Union +from typing import Dict, Optional, Tuple, TypeVar, Union import numpy as np import pandas as pd @@ -146,12 +147,12 @@ def __init__( sensor_channel_config: str = "gyr_ml", peak_prominence: float = 17.0, window_size_s: float = 10, - active_signal_threshold: float = None, + active_signal_threshold: Optional[float] = None, locomotion_band: Tuple[float, float] = (0.5, 3), harmonic_tolerance_hz: float = 0.3, merge_gait_sequences_from_sensors: bool = False, - additional_margin_s: float = None, - ): + additional_margin_s: Optional[float] = None, + ) -> None: self.sensor_channel_config = sensor_channel_config self.peak_prominence = peak_prominence self.window_size_s = window_size_s @@ -400,7 +401,7 @@ def _harmonics_analysis(self, s_1d, dominant_frequency, window_size, fft_factor, return valid_windows - def _assert_input_data(self, data): + def _assert_input_data(self, data) -> None: if self.merge_gait_sequences_from_sensors and is_multi_sensor_data(data) and not isinstance(data, pd.DataFrame): raise ValueError("Merging of data set is only possible for synchronized data sets.") @@ -429,11 +430,11 @@ def _assert_input_data(self, data): # cause edge cases for the flattened fft peak detection later on. if (self.sampling_rate_hz / 2) - (5 * self.locomotion_band[1]) < 5: raise ValueError( - "The upper limit of the locomotion band ({} Hz) is too close to the Nyquist frequency ({} Hz) of the " - "signal, given the sampling rate of {} Hz. The difference between upper limit of locomotion band and " - "Nyquist frequency should be smaller than 5 Hz.".format( - self.locomotion_band[1], self.sampling_rate_hz / 2, self.sampling_rate_hz - ) + f"The upper limit of the locomotion band ({self.locomotion_band[1]} Hz) is too close to the Nyquist " + f"frequency ({self.sampling_rate_hz / 2} Hz) of the signal, given the sampling rate of " + f"{self.sampling_rate_hz} Hz. " + "The difference between upper limit of locomotion band and Nyquist frequency should be smaller than " + "5 Hz." 
) def _merge_gait_sequences_multi_sensor_data( diff --git a/gaitmap_mad/gaitmap_mad/preprocessing/sensor_alignment/_forward_direction_alignment.py b/gaitmap_mad/gaitmap_mad/preprocessing/sensor_alignment/_forward_direction_alignment.py index aae3fc4b..54f6a3b0 100644 --- a/gaitmap_mad/gaitmap_mad/preprocessing/sensor_alignment/_forward_direction_alignment.py +++ b/gaitmap_mad/gaitmap_mad/preprocessing/sensor_alignment/_forward_direction_alignment.py @@ -86,7 +86,7 @@ class ForwardDirectionSignAlignment(BaseSensorAlignment): >>> fdsa = ForwardDirectionSignAlignment(forward_direction="x", rotation_axis="z", baseline_velocity_threshold=0.2) >>> fdsa = fdsa.align(data, 204.8) - >>> fdsa.aligned_data_['left_sensor'] + >>> fdsa.aligned_data_["left_sensor"] ... >>> fdsa.is_flipped_ # True when the data was rotated by 180 deg, False otherwise @@ -132,7 +132,7 @@ def __init__( gravity=GRAV_VEC, ) ), - ): + ) -> None: self.forward_direction = forward_direction self.rotation_axis = rotation_axis self.baseline_velocity_threshold = baseline_velocity_threshold diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/__init__.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/__init__.py index dfa925d8..dfdbeebc 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/__init__.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/__init__.py @@ -7,6 +7,7 @@ algorithm, as implemented in :py:mod:`gaitmap.event_detection`, to be able to provide information about biomechanical events. """ + from gaitmap_mad.stride_segmentation.dtw import ( BarthDtw, BarthOriginalTemplate, diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/__init__.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/__init__.py index 8a828a64..ab2d69cd 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/__init__.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/__init__.py @@ -1,4 +1,5 @@ """Dtw based Stride Segmentation.""" + from gaitmap_mad.stride_segmentation.dtw._barth_dtw import BarthDtw from gaitmap_mad.stride_segmentation.dtw._base_dtw import ( BaseDtw, diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_barth_dtw.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_barth_dtw.py index 3c9af2b0..0eddbbf1 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_barth_dtw.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_barth_dtw.py @@ -1,6 +1,7 @@ """The msDTW based stride segmentation algorithm by Barth et al 2013.""" + import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, NoReturn, Optional, Tuple, Union import numpy as np import pandas as pd @@ -183,7 +184,7 @@ def __init__( snap_to_min_axis: Optional[str] = "gyr_ml", conflict_resolution: bool = True, memory: Optional[Memory] = None, - ): + ) -> None: self.snap_to_min_win_ms = snap_to_min_win_ms self.snap_to_min_axis = snap_to_min_axis self.conflict_resolution = conflict_resolution @@ -208,7 +209,7 @@ def stride_list_(self) -> StrideList: return self._format_stride_list(start_ends) @stride_list_.setter - def stride_list_(self, arg: StrideList): # noqa: no-self-use + def stride_list_(self, arg: StrideList) -> NoReturn: # noqa: ARG002 """Fake setter for the stride list. This is required to be type compatible with the base class.
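The `stride_list_` setter above also shows a subtler typing fix: a setter that exists only so the subclass stays interface-compatible with its base is annotated `-> NoReturn` rather than `-> None`, telling type checkers that a call never returns normally, and the prose-style `# noqa: no-self-use` becomes the real rule code `ARG002` (flake8-unused-arguments). A minimal sketch of the pattern, assuming the fake setter simply raises:

```python
from typing import List, NoReturn


class WithComputedResult:
    @property
    def stride_list_(self) -> List[int]:
        # Stand-in for the real computation of the result.
        return [0, 1, 2]

    @stride_list_.setter
    def stride_list_(self, arg: List[int]) -> NoReturn:  # noqa: ARG002
        # The setter only keeps the property type compatible with the base
        # class; assigning to a computed result is always an error.
        raise AttributeError("stride_list_ is computed and cannot be set.")
```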
@@ -286,7 +287,7 @@ def _postprocess_matches( return matches_start_end, to_keep - def _post_postprocess_check(self, matches_start_end): + def _post_postprocess_check(self, matches_start_end) -> None: super()._post_postprocess_check(matches_start_end) # Check if there are still overlapping strides if np.any(np.diff(matches_start_end.flatten()) < 0): diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_base_dtw.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_base_dtw.py index 053a4c8d..b86927e6 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_base_dtw.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_base_dtw.py @@ -1,6 +1,7 @@ """An implementation of an sDTW that can be used independently of the context of Stride Segmentation.""" + import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, TypeVar, Union import numpy as np import pandas as pd @@ -257,7 +258,10 @@ class BaseDtw(BaseAlgorithm): data: Union[np.ndarray, SensorData] sampling_rate_hz: float - _allowed_methods_map = {"min_under_thres": find_matches_min_under_threshold, "find_peaks": find_matches_find_peaks} + _allowed_methods_map: ClassVar = { + "min_under_thres": find_matches_min_under_threshold, + "find_peaks": find_matches_find_peaks, + } _min_sequence_length: Optional[float] _max_sequence_length: Optional[float] _max_template_stretch: Optional[int] @@ -294,7 +298,7 @@ def __init__( max_template_stretch_ms: Optional[float] = None, max_signal_stretch_ms: Optional[float] = None, memory: Optional[Memory] = None, - ): + ) -> None: self.template = template self.max_cost = max_cost self.min_match_length_s = min_match_length_s @@ -400,9 +404,9 @@ def _segment_single_dataset( and self.resample_template is False ): warnings.warn( - "The data and template sampling rate are different ({} Hz vs. {} Hz), " - "but `resample_template` is False. " - "This might lead to unexpected results".format(template_sampling_rate, self.sampling_rate_hz) + f"The data and template sampling rate are different ({template_sampling_rate} Hz vs. " + f"{self.sampling_rate_hz} Hz), but `resample_template` is False. " + "This might lead to unexpected results" ) # Extract the parts of the data that is relevant for matching and apply potential data transforms defined in @@ -570,7 +574,7 @@ def _postprocess_matches( to_keep[indices[invalid_strides]] = False return matches_start_end, to_keep - def _post_postprocess_check(self, matches_start_end): + def _post_postprocess_check(self, matches_start_end) -> None: """Check that is invoked after all processing is done. Parameters @@ -614,9 +618,8 @@ def _extract_relevant_data_and_template( return template_array, data if template_array.shape[1] > data.shape[1]: raise ValueError( - "The provided data has less columns than the used template. ({} < {})".format( - data.shape[1], template_array.shape[1] - ) + "The provided data has fewer columns than the used template. " + f"({data.shape[1]} < {template_array.shape[1]})" ) return ( template_array, @@ -635,7 +638,7 @@ def _extract_relevant_data_and_template( "Some columns of the template are not available in the data! This might happen because you " "provided the data in the wrong coordinate frame (Sensor vs. Body)." "Review the general documentation for more information."
- "\n\nMissing columns: {}".format(list(set(template_array.columns) - set(data.columns))) + f"\n\nMissing columns: {list(set(template_array.columns) - set(data.columns))}" ) from e return template_array.to_numpy(), data.to_numpy() # TODO: Better error message @@ -650,26 +653,24 @@ def _find_multiple_paths(acc_cost_mat: np.ndarray, start_points: np.ndarray) -> paths.append(path_array) return paths - def _validate_basic_inputs(self): + def _validate_basic_inputs(self) -> None: if self.template is None: raise ValueError("A `template` must be specified.") if self.find_matches_method not in self._allowed_methods_map: raise ValueError( - "Invalid value for `find_matches_method`. Must be one of {}".format( - list(self._allowed_methods_map.keys()) - ) + f"Invalid value for `find_matches_method`. Must be one of {list(self._allowed_methods_map.keys())}" ) if self.max_template_stretch_ms is not None and self.max_template_stretch_ms <= 0: raise ValueError( "Invalid value for `max_template_stretch_ms`." - "The value must be a number larger than 0 and not {}".format(self.max_template_stretch_ms) + f"The value must be a number larger than 0 and not {self.max_template_stretch_ms}" ) if self.max_signal_stretch_ms is not None and self.max_signal_stretch_ms <= 0: raise ValueError( "Invalid value for `max_signal_stretch_ms`." - "The value must be a number larger than 0 and not {}".format(self.max_signal_stretch_ms) + f"The value must be a number larger than 0 and not {self.max_signal_stretch_ms}" ) def _calculate_constrains(self, template: BaseDtwTemplate): diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_constrained_barth_dtw.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_constrained_barth_dtw.py index 28bed00c..9dd1e9fa 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_constrained_barth_dtw.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_constrained_barth_dtw.py @@ -1,4 +1,5 @@ """A version of BarthDTW that used local warping constrains by default.""" + from typing import Dict, Optional, Union from joblib import Memory @@ -148,7 +149,7 @@ def __init__( snap_to_min_axis: Optional[str] = "gyr_ml", conflict_resolution: bool = True, memory: Optional[Memory] = None, - ): + ) -> None: super().__init__( template=template, max_cost=max_cost, diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_dtw_templates/templates.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_dtw_templates/templates.py index 8ba8033b..8bdda9a0 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_dtw_templates/templates.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_dtw_templates/templates.py @@ -1,4 +1,5 @@ """Dtw template base classes and helper.""" + from importlib.resources import open_text from typing import Iterable, List, Optional, Sequence, Tuple, Union, cast @@ -24,7 +25,7 @@ def __init__( *, scaling: Optional[BaseTransformer] = None, use_cols: Optional[Sequence[Union[str, int]]] = None, - ): + ) -> None: self.scaling = scaling self.use_cols = use_cols @@ -104,7 +105,9 @@ class BarthOriginalTemplate(BaseDtwTemplate): template_file_name = "barth_original_template.csv" sampling_rate_hz = 204.8 - def __init__(self, *, scaling=cf(FixedScaler(scale=500.0)), use_cols: Optional[Sequence[Union[str, int]]] = None): + def __init__( + self, *, scaling=cf(FixedScaler(scale=500.0)), use_cols: Optional[Sequence[Union[str, int]]] = None + ) -> None: super().__init__(scaling=scaling, use_cols=use_cols) def get_data(self) -> Union[np.ndarray, pd.DataFrame]: @@ -186,8 +189,7 @@ 
def __init__( sampling_rate_hz: Optional[float] = None, scaling: Optional[BaseTransformer] = None, use_cols: Optional[Sequence[Union[str, int]]] = None, - ): - + ) -> None: self.data = data self.sampling_rate_hz = sampling_rate_hz super().__init__(scaling=scaling, use_cols=use_cols) @@ -279,7 +281,7 @@ def __init__( interpolation_method: str = "linear", n_samples: Optional[int] = None, use_cols: Optional[Sequence[Union[str, int]]] = None, - ): + ) -> None: self.interpolation_method = interpolation_method self.n_samples = n_samples super().__init__( diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_vendored_tslearn.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_vendored_tslearn.py index ae1465cf..2fd84e5e 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_vendored_tslearn.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/dtw/_vendored_tslearn.py @@ -100,8 +100,7 @@ def subsequence_path(acc_cost_mat, idx_path_end): Examples -------- - >>> acc_cost_mat = numpy.array([[1., 0., 0., 1., 4.], - ... [5., 1., 1., 0., 1.]]) + >>> acc_cost_mat = numpy.array([[1.0, 0.0, 0.0, 1.0, 4.0], [5.0, 1.0, 1.0, 0.0, 1.0]]) >>> # calculate the globally optimal path >>> optimal_end_point = numpy.argmin(acc_cost_mat[-1, :]) >>> path = subsequence_path(acc_cost_mat, optimal_end_point) diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/__init__.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/__init__.py index 150b5c8e..292ab8f0 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/__init__.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/__init__.py @@ -1,4 +1,5 @@ """Roth et al. HMM based stride segmentation model.""" + import multiprocessing import warnings diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_hmm_feature_transform.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_hmm_feature_transform.py index bebf69b1..40562aa2 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_hmm_feature_transform.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_hmm_feature_transform.py @@ -1,5 +1,6 @@ """Feature transformation class for HMM.""" -from typing import List, Optional + +from typing import List, NoReturn, Optional import numpy as np import pandas as pd @@ -53,7 +54,7 @@ def transform( roi_list: Optional[SingleSensorRegionsOfInterestList] = None, sampling_rate_hz: Optional[float] = None, **kwargs, - ): + ) -> NoReturn: """Transform the data and the roi/stride list into to the feature space. Transforming the roi/stride list is only required, if the sampling rate of the features space is different from @@ -170,7 +171,7 @@ def __init__( features: List[str] = cf(["raw", "gradient"]), window_size_s: float = 0.2, standardization: bool = True, - ): + ) -> None: self.sampling_rate_feature_space_hz = sampling_rate_feature_space_hz self.low_pass_filter = low_pass_filter self.axes = axes diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_hmm_stride_segmentation.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_hmm_stride_segmentation.py index 9c09bc10..8e928915 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_hmm_stride_segmentation.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_hmm_stride_segmentation.py @@ -1,4 +1,5 @@ """HMM based stride segmentation by Roth et al. 
2021.""" + from contextlib import suppress from importlib.resources import open_text from pathlib import Path @@ -135,7 +136,7 @@ def __init__( *, snap_to_min_win_ms: Optional[float] = 100, snap_to_min_axis: str = "gyr_ml", - ): + ) -> None: self.snap_to_min_win_ms = snap_to_min_win_ms self.snap_to_min_axis = snap_to_min_axis self.model = model diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_segmentation_model.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_segmentation_model.py index 1de4fd62..056c5b48 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_segmentation_model.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_segmentation_model.py @@ -1,4 +1,5 @@ """Segmentation _model base classes and helper.""" + import copy from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple @@ -337,7 +338,7 @@ def __init__( name: str = "segmentation_model", model: Optional[pgHMM] = None, data_columns: Optional[Tuple[str, ...]] = None, - ): + ) -> None: self.stride_model = stride_model self.transition_model = transition_model self.feature_transform = feature_transform diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_simple_model.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_simple_model.py index ec13e2f9..666869cd 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_simple_model.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_simple_model.py @@ -272,7 +272,7 @@ def __init__( name: str = "my_model", model: Optional[pgHMM] = None, data_columns: Optional[Tuple[str, ...]] = None, - ): + ) -> None: self.n_states = n_states self.n_gmm_components = n_gmm_components self.algo_train = algo_train diff --git a/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_utils.py b/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_utils.py index 51aadd4a..dd290363 100644 --- a/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_utils.py +++ b/gaitmap_mad/gaitmap_mad/stride_segmentation/hmm/_utils.py @@ -1,4 +1,5 @@ """Utils and helper functions for HMM classes.""" + import json import warnings from typing import Any, List, Literal, Optional, Set, Tuple @@ -13,7 +14,7 @@ from gaitmap.utils.datatype_helper import SingleSensorData, SingleSensorRegionsOfInterestList, SingleSensorStrideList -def _add_transition(model, a, b, probability, pseudocount, group): +def _add_transition(model, a, b, probability, pseudocount, group) -> None: """Hacky way to add a transition when cloning a model in the "wrong" way.""" pseudocount = pseudocount or probability model.graph.add_edge(a, b, probability=probability, pseudocount=pseudocount, group=group) @@ -91,7 +92,7 @@ def _clone_model(orig_model: pg.HiddenMarkovModel, assert_correct: bool = True) new_state_order = [state.name for state in states] # Add all the edges to the model - for (start, end, data) in list(orig_model.graph.edges(data=True)): + for start, end, data in list(orig_model.graph.edges(data=True)): _add_transition( model, states[new_state_order.index(start.name)], @@ -181,7 +182,7 @@ def create_transition_matrix_left_right( return transition_matrix, start_probs, end_probs -def print_transition_matrix(model: pg.HiddenMarkovModel, precision: int = 3): +def print_transition_matrix(model: pg.HiddenMarkovModel, precision: int = 3) -> None: """Print model transition matrix in user-friendly format.""" np.set_printoptions(suppress=True) np.set_printoptions(precision) @@ -326,7 +327,7 @@ def model_params_are_finite(model: pg.HiddenMarkovModel) -> bool: return True -def 
check_history_for_training_failure(history: History): +def check_history_for_training_failure(history: History) -> None: """Check if training history contains any NaNs.""" if not np.all(np.isfinite(history.improvements)) or np.any(np.array(history.improvements) < 0): warnings.warn( @@ -348,7 +349,7 @@ def get_state_by_name(model: pg.HiddenMarkovModel, state_name: str) -> str: raise ValueError(f"State {state_name} not found within given _model!") -def add_transition(model: pg.HiddenMarkovModel, transition: Tuple[str, str], transition_probability: float): +def add_transition(model: pg.HiddenMarkovModel, transition: Tuple[str, str], transition_probability: float) -> None: """Add a transition to an existing model by state-names. add_transition(model, transition = ("s0","s1"), transition_probability = 0.5) @@ -447,11 +448,11 @@ def create_equidistant_label_sequence(n_labels: int, n_states: int) -> np.ndarra Example ------- - >>> create_equidistant_label_sequence(n_labels = 10, n_states = 5) + >>> create_equidistant_label_sequence(n_labels=10, n_states=5) array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) - >>> create_equidistant_label_sequence(n_labels = 10, n_states = 4) + >>> create_equidistant_label_sequence(n_labels=10, n_states=4) array([0, 0, 0, 1, 1, 2, 2, 3, 3, 3]) - >>> create_equidistant_label_sequence(n_labels = 10, n_states = 3) + >>> create_equidistant_label_sequence(n_labels=10, n_states=3) array([0, 0, 0, 1, 1, 1, 2, 2, 2, 2]) """ @@ -513,7 +514,6 @@ def get_train_data_sequences_transitions( n_too_short_transitions = 0 for data, stride_list in zip(data_train_sequence, stride_list_sequence): - # for each transition, get data and create some naive labels for initialization for start, end in convert_stride_list_to_transition_list(stride_list, data.shape[0])[ ["start", "end"] @@ -632,7 +632,7 @@ def predict( data = np.ascontiguousarray(data.to_numpy()) try: labels_predicted = np.asarray(model.predict(data.copy(), algorithm=algorithm)) - except Exception as e: # noqa: broad-except + except Exception as e: # noqa: BLE001 if not model_params_are_finite(model): raise ValueError( "Prediction failed! (See error above.). " diff --git a/gaitmap_mad/gaitmap_mad/trajectory_reconstruction/position_methods/_piece_wise_linear_dedrifted_integration.py b/gaitmap_mad/gaitmap_mad/trajectory_reconstruction/position_methods/_piece_wise_linear_dedrifted_integration.py index db465d83..54022bbc 100644 --- a/gaitmap_mad/gaitmap_mad/trajectory_reconstruction/position_methods/_piece_wise_linear_dedrifted_integration.py +++ b/gaitmap_mad/gaitmap_mad/trajectory_reconstruction/position_methods/_piece_wise_linear_dedrifted_integration.py @@ -81,11 +81,9 @@ class PieceWiseLinearDedriftedIntegration(BasePositionMethod): >>> data = pd.DataFrame(..., columns=SF_COLS) >>> sampling_rate_hz = 100 >>> # Create an algorithm instance - >>> pwli = PieceWiseLinearDedriftedIntegration(NormZuptDetector(window_length_s=0.15, - ... inactive_signal_threshold=15. - ... ), - ... gravity=np.array([0, 0, 9.81]) - ... ) + >>> pwli = PieceWiseLinearDedriftedIntegration( + ... NormZuptDetector(window_length_s=0.15, inactive_signal_threshold=15.0), gravity=np.array([0, 0, 9.81]) + ... 
) >>> # Apply the algorithm >>> pwli = pwli.estimate(data, sampling_rate_hz=sampling_rate_hz) >>> # Inspect the results @@ -144,7 +142,7 @@ def __init__( ), level_assumption: bool = True, gravity: Optional[np.ndarray] = cf(GRAV_VEC), - ): + ) -> None: self.zupt_detector = zupt_detector self.level_assumption = level_assumption self.gravity = gravity diff --git a/poetry.lock b/poetry.lock index cc9c2d65..195465a5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -103,41 +103,6 @@ charset-normalizer = ["charset-normalizer"] html5lib = ["html5lib"] lxml = ["lxml"] -[[package]] -name = "black" -version = "22.12.0" -description = "The uncompromising code formatter." -optional = false -python-versions = ">=3.7" -files = [ - {file = "black-22.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eedd20838bd5d75b80c9f5487dbcb06836a43833a37846cf1d8c1cc01cef59d"}, - {file = "black-22.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:159a46a4947f73387b4d83e87ea006dbb2337eab6c879620a3ba52699b1f4351"}, - {file = "black-22.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d30b212bffeb1e252b31dd269dfae69dd17e06d92b87ad26e23890f3efea366f"}, - {file = "black-22.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:7412e75863aa5c5411886804678b7d083c7c28421210180d67dfd8cf1221e1f4"}, - {file = "black-22.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c116eed0efb9ff870ded8b62fe9f28dd61ef6e9ddd28d83d7d264a38417dcee2"}, - {file = "black-22.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:1f58cbe16dfe8c12b7434e50ff889fa479072096d79f0a7f25e4ab8e94cd8350"}, - {file = "black-22.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77d86c9f3db9b1bf6761244bc0b3572a546f5fe37917a044e02f3166d5aafa7d"}, - {file = "black-22.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:82d9fe8fee3401e02e79767016b4907820a7dc28d70d137eb397b92ef3cc5bfc"}, - {file = "black-22.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101c69b23df9b44247bd88e1d7e90154336ac4992502d4197bdac35dd7ee3320"}, - {file = "black-22.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:559c7a1ba9a006226f09e4916060982fd27334ae1998e7a38b3f33a37f7a2148"}, - {file = "black-22.12.0-py3-none-any.whl", hash = "sha256:436cc9167dd28040ad90d3b404aec22cedf24a6e4d7de221bec2730ec0c97bcf"}, - {file = "black-22.12.0.tar.gz", hash = "sha256:229351e5a18ca30f447bf724d007f890f97e13af070bb6ad4c0a441cd7596a2f"}, -] - -[package.dependencies] -click = ">=8.0.0" -mypy-extensions = ">=0.4.3" -pathspec = ">=0.9.0" -platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} -typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} - -[package.extras] -colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)"] -jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] -uvloop = ["uvloop (>=0.15.2)"] - [[package]] name = "certifi" version = "2024.2.2" @@ -312,20 +277,6 @@ files = [ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] -[[package]] -name = "click" -version = "8.1.7" -description = "Composable command line interface toolkit" -optional = false -python-versions = ">=3.7" -files = [ - {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, - {file = "click-8.1.7.tar.gz", hash = 
"sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, -] - -[package.dependencies] -colorama = {version = "*", markers = "platform_system == \"Windows\""} - [[package]] name = "colorama" version = "0.4.6" @@ -1387,17 +1338,6 @@ files = [ [package.dependencies] psutil = "*" -[[package]] -name = "mypy-extensions" -version = "1.0.0" -description = "Type system extensions for programs checked with the mypy type checker." -optional = false -python-versions = ">=3.5" -files = [ - {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, - {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, -] - [[package]] name = "myst-parser" version = "1.0.0" @@ -1595,8 +1535,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -1666,17 +1606,6 @@ files = [ {file = "pastel-0.2.1.tar.gz", hash = "sha256:e6581ac04e973cac858828c6202c1e1e81fee1dc7de7683f3e1ffe0bfd8a573d"}, ] -[[package]] -name = "pathspec" -version = "0.12.1" -description = "Utility library for gitignore style pattern matching of file paths." -optional = false -python-versions = ">=3.8" -files = [ - {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, - {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, -] - [[package]] name = "patsy" version = "0.5.6" @@ -2336,27 +2265,28 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "ruff" -version = "0.0.235" -description = "An extremely fast Python linter, written in Rust." +version = "0.3.4" +description = "An extremely fast Python linter and code formatter, written in Rust." 
optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.0.235-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:50327fe28aa914c4b2e3d06c3e41f47bcfbd595843a26f5f7fda30ca5318755f"}, - {file = "ruff-0.0.235-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:d29966029ff77a1c336004ff3e1effd33db8554ad9ec9f87ff339d0f3d44ae35"}, - {file = "ruff-0.0.235-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50baf2635584b93c09d1e69bca51041eb4ff584b20b0a443124feb7019591a4e"}, - {file = "ruff-0.0.235-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cc67f4e8095ad4af9bdd81f76db9cdc4e07533aeb91037dc3548d1384200de0f"}, - {file = "ruff-0.0.235-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fa9d2ba750180e3d7c23ee0151c52f1900e601be54ab516283ada368b1bb1672"}, - {file = "ruff-0.0.235-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:fbd10fbb7643a8334e0f6ca1095a877e2f1fb240bbd0ee23f8196592e0c092d3"}, - {file = "ruff-0.0.235-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8738cabb41216d467ac92d747380c6c943d74dd4d7d1bf8a3106787ecccbd36f"}, - {file = "ruff-0.0.235-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64de46e14c30a6eb9c6a458c62048b1711c46e45ff0468f14118c4d24d2fa750"}, - {file = "ruff-0.0.235-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9b475d800a6f356a7e7afae89a8ce1297e06f365eaa23b9eb80e6cb16a0915f"}, - {file = "ruff-0.0.235-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ed8771ab7bbaa9b350eb64a3d6d6628e800800cb15c5c3cc6e3e3217ff67703d"}, - {file = "ruff-0.0.235-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:e60c855babdc3d8df77ac044fb3f893c2084efebc606726ecb078edc9d3c5702"}, - {file = "ruff-0.0.235-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9efb9b87b92deeaeb707581a884e1764343165df0d37c3bdc4dc297edd837dce"}, - {file = "ruff-0.0.235-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:856ec6bfda0912f8010e15ffe04c33f2793971452379dfc8bd1f30b849483ede"}, - {file = "ruff-0.0.235-py3-none-win32.whl", hash = "sha256:82cf33ce2a998d1762517cc2e4ec0f79bbd985d005b312f31674411100c41899"}, - {file = "ruff-0.0.235-py3-none-win_amd64.whl", hash = "sha256:4a8b0284d52ea7b486894899cf5ba705c7b03a9d5fa780d55ac99ab64d3967ad"}, - {file = "ruff-0.0.235.tar.gz", hash = "sha256:270c0c83c01d00370851813edfd1502f2146a0a0b4e75b723e0c388252840f5a"}, + {file = "ruff-0.3.4-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:60c870a7d46efcbc8385d27ec07fe534ac32f3b251e4fc44b3cbfd9e09609ef4"}, + {file = "ruff-0.3.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6fc14fa742e1d8f24910e1fff0bd5e26d395b0e0e04cc1b15c7c5e5fe5b4af91"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3ee7880f653cc03749a3bfea720cf2a192e4f884925b0cf7eecce82f0ce5854"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf133dd744f2470b347f602452a88e70dadfbe0fcfb5fd46e093d55da65f82f7"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f3860057590e810c7ffea75669bdc6927bfd91e29b4baa9258fd48b540a4365"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:986f2377f7cf12efac1f515fc1a5b753c000ed1e0a6de96747cdf2da20a1b369"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:c4fd98e85869603e65f554fdc5cddf0712e352fe6e61d29d5a6fe087ec82b76c"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64abeed785dad51801b423fa51840b1764b35d6c461ea8caef9cf9e5e5ab34d9"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df52972138318bc7546d92348a1ee58449bc3f9eaf0db278906eb511889c4b50"}, + {file = "ruff-0.3.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:98e98300056445ba2cc27d0b325fd044dc17fcc38e4e4d2c7711585bd0a958ed"}, + {file = "ruff-0.3.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:519cf6a0ebed244dce1dc8aecd3dc99add7a2ee15bb68cf19588bb5bf58e0488"}, + {file = "ruff-0.3.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:bb0acfb921030d00070539c038cd24bb1df73a2981e9f55942514af8b17be94e"}, + {file = "ruff-0.3.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cf187a7e7098233d0d0c71175375c5162f880126c4c716fa28a8ac418dcf3378"}, + {file = "ruff-0.3.4-py3-none-win32.whl", hash = "sha256:af27ac187c0a331e8ef91d84bf1c3c6a5dea97e912a7560ac0cef25c526a4102"}, + {file = "ruff-0.3.4-py3-none-win_amd64.whl", hash = "sha256:de0d5069b165e5a32b3c6ffbb81c350b1e3d3483347196ffdf86dc0ef9e37dd6"}, + {file = "ruff-0.3.4-py3-none-win_arm64.whl", hash = "sha256:6810563cc08ad0096b57c717bd78aeac888a1bfd38654d9113cb3dc4d3f74232"}, + {file = "ruff-0.3.4.tar.gz", hash = "sha256:f0f4484c6541a99862b693e13a151435a279b271cff20e37101116a21e2a1ad1"}, ] [[package]] @@ -2943,4 +2873,4 @@ stats = ["pingouin"] [metadata] lock-version = "2.0" python-versions = ">=3.8.0,<4.0" -content-hash = "12c3a78051fceafb4bde8d3d05f01089d74939ff52000d07c62e63e623c77c82" +content-hash = "2762d7255132c41c6cd8379db685e7d7766bd9980209e7b8091cd65681fb0543" diff --git a/pyproject.toml b/pyproject.toml index 6f9cabb3..25042502 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ all = ["pomegranate", "pingouin"] poethepoet = "^0.18.1" pytest = "^7.2.1" pytest-cov = "^4.0.0" -black = "^22.12.0" coverage = "^7.0.5" ipykernel = "^6.20.2" IPython = "^8.8.0" @@ -62,151 +61,24 @@ sphinx-gallery = "^0.11.1" pydata-sphinx-theme = "^0.14.0" numpydoc = "^1.5.0" Sphinx = "^6.1.3" -ruff = "^0.0.235" +ruff = "^0.3.4" myst-parser = "^1.0.0" -[tool.black] -line-length = 120 -target-version = ['py38'] -exclude = ''' -( - /( - \.eggs # exclude a few common directories in the - | \.git # root of the project - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | _debug_test - | docs - | build - | dist - | \.virtual_documents - )/ -) -''' - -[tool.ruff] -line-length = 120 -update-check = true -target-version = "py38" - -select = [ - # pyflakes - "F", - # pycodestyle - "E", - "W", - # mccabe - "C90", - # isort - "I", - # pydocstyle - "D", - # pyupgrade - "UP", - # pep8-naming - "N", - # flake8-blind-except - "BLE", - # flake8-2020 - "YTT", - # flake8-builtins - "A", - # flake8-comprehensions - "C4", - # flake8-debugger - "T10", - # flake8-errmsg - "EM", - # flake8-implicit-str-concat - "ISC", - # flake8-pytest-style - "PT", - # flake8-return - "RET", - # flake8-simplify - "SIM", - # flake8-unused-arguments - "ARG", - # pandas-vet - "PD", - # pygrep-hooks - "PGH", - # flake8-bugbear - "B", - # flake8-quotes - "Q", - # pylint - "PL", - # flake8-pie - "PIE", - # flake8-type-checking - "TCH", - # tryceratops - "TRY", - # flake8-use-pathlib - "PTH", - "RUF" -] - -ignore = [ - # controversial - "B006", - # controversial - "B008", - "B010", - # Magic constants - "PLR2004", - # Strings in error messages - "EM101", - "EM102", - 
"EM103", - # Multiline docstring summary - "D213", - # Varaibles before return - "RET504", - # Abstract raise into inner function - "TRY301", - # Use type-checking block - "TCH001", - "TCH002", - "TCH003", - # df as varaible name - "PD901", - # melt over stack - "PD013" -] - - -exclude = [ - "doc/sphinxext/*.py", - "doc/build/*.py", - "doc/temp/*.py", - ".eggs/*.py", - "example_data", - "examples" -] - - -[tool.ruff.pydocstyle] -convention = "numpy" - -[tool.ruff.pyupgrade] -# Preserve types, even if a file imports `from __future__ import annotations`. -keep-runtime-typing = true - [tool.poe.tasks] -_format_black = "black ." -_format_ruff = "ruff . --fix-only" -format = { sequence = ["_format_black", "_format_ruff"], ignore_fail = true } -lint = { cmd = "ruff gaitmap gaitmap_mad --fix", help = "Lint all files with Prospector." } -_lint_ci = "ruff gaitmap gaitmap_mad --format=github" -_check_black = "black . --check" -ci_check = { sequence = ["_check_black", "_lint_ci"], help = "Check all potential format and linting issues." } -test = { cmd = "pytest --cov=gaitmap --cov=gaitmap_mad --cov-report=term-missing --cov-report=xml", help = "Run Pytest with coverage." } +_format = "ruff format ." +_auto_fix = "ruff check . --fix-only --show-fixes --exit-zero" +_auto_fix_unsafe = "ruff check . --fix-only --show-fixes --exit-zero --unsafe-fixes" +format = ["_auto_fix", "_format"] +format_unsafe = ["_auto_fix_unsafe", "_format"] +lint = { cmd = "ruff check gaitmap gaitmap_mad --fix", help = "Lint all files with ruff." } +_lint_ci = "ruff check gaitmap gaitmap_mad --output-format=github" +_check_format = "ruff format . --check" +ci_check = { sequence = ["_check_format", "_lint_ci"], help = "Check all potential format and linting issues." } +test = { cmd = "pytest --cov=gaitmap --cov-report=term-missing --cov-report=xml", help = "Run Pytest with coverage." } docs = { "script" = "_tasks:task_docs()", help = "Build the html docs using Sphinx." } +docs_clean = { "script" = "_tasks:task_docs(clean=True)", help = "Remove all old build files and build a clean version of the docs." } +docs_linkcheck = { "script" = "_tasks:task_docs(builder='linkcheck')", help = "Check all links in the built html docs." } +docs_preview = { cmd = "python -m http.server --directory docs/_build/html", help = "Preview the built html docs." } register_jupyter = { "script" = "_tasks:task_register_ipykernel()", help = "Register the gaitmap environment as a Jupyter kernel for testing." } version = { "script" = "_tasks:task_update_version()", help = "Bump version in all relevant places." } bump_dev = { script = "_tasks:task_bump_all_dev()", help= "Update all dev dependencies to their @latest version."} diff --git a/tests/_regression_utils.py b/tests/_regression_utils.py index 6a6bf355..a6ac5345 100644 --- a/tests/_regression_utils.py +++ b/tests/_regression_utils.py @@ -3,6 +3,7 @@ This is inspired by github.com/syrusakbary/snapshottest. Note that it can not be used in combination with this module! """ + import re from pathlib import Path @@ -11,7 +12,7 @@ from pandas._testing import assert_frame_equal -def pytest_addoption(parser): +def pytest_addoption(parser) -> None: group = parser.getgroup("snapshottest") group.addoption( "--snapshot-update", action="store_true", default=False, dest="snapshot_update", help="Update the snapshots." 
@@ -23,7 +24,7 @@ class SnapshotNotFound(Exception): class PyTestSnapshotTest: - def __init__(self, request=None): + def __init__(self, request=None) -> None: self.request = request self.curr_snapshot_number = 0 super().__init__() @@ -64,7 +65,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): pass - def store(self, value): + def store(self, value) -> None: self.snapshot_folder.mkdir(parents=True, exist_ok=True) if isinstance(value, pd.DataFrame): value.to_json(self.file_name_json, indent=4, orient="table") @@ -97,7 +98,7 @@ def retrieve(self, dtype): else: raise ValueError(f"The dtype {dtype} is not supported for snapshot testing") - def assert_match(self, value, name="", **kwargs): + def assert_match(self, value, name="", **kwargs) -> None: self.curr_snapshot = name or self.curr_snapshot_number if self.update: self.store(value) diff --git a/tests/_test_gaitmap_mad_split.py b/tests/_test_gaitmap_mad_split.py index 4c4d8286..277275a3 100644 --- a/tests/_test_gaitmap_mad_split.py +++ b/tests/_test_gaitmap_mad_split.py @@ -6,6 +6,7 @@ Hence, these tests are excluded (Leading "_" in filename), but can be run manually. """ + import importlib import sys @@ -20,12 +21,12 @@ def _gaitmap_mad_sys_modifier(): # entry to None. # This import will force gaitmap_mad to be in sys.modules. - import gaitmap_mad + import gaitmap_mad  # noqa: F401 sys.modules["gaitmap_mad"] = None yield sys.modules.pop("gaitmap_mad") - import gaitmap_mad # noqa: unused-import + import gaitmap_mad # noqa: F401 # We just go overboard to be safe and reimport all gaitmap modules after the cleanup. modules_to_reload = [] @@ -48,12 +49,12 @@ def _gaitmap_mad_change_version(): importlib.reload(gaitmap_mad) -def test_raises_error_gaitmap_mad_not_installed(_gaitmap_mad_sys_modifier): +def test_raises_error_gaitmap_mad_not_installed(_gaitmap_mad_sys_modifier) -> None: # First we need to remove gaitmap_mad from sys.modules, so that it is not imported.
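The fixture above leans on a documented quirk of Python's import machinery: if a `sys.modules` entry is `None`, any subsequent import of that name raises `ImportError`, which lets the test simulate an uninstalled `gaitmap_mad` without touching the environment. A self-contained sketch with a hypothetical package name:

```python
import sys

# A None entry in sys.modules makes the import system raise ImportError
# for the (hypothetical) package "example_pkg" from now on.
sys.modules["example_pkg"] = None
try:
    import example_pkg  # noqa: F401
except ImportError:
    print("import blocked, as the fixture expects")
finally:
    # Clean up so later imports see a fresh state again.
    sys.modules.pop("example_pkg", None)
```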
# We need to make sure that this is a fresh import: sys.modules.pop("gaitmap.stride_segmentation", None) with pytest.raises(GaitmapMadImportError) as e: - from gaitmap.stride_segmentation import BarthDtw # noqa: unused-import + from gaitmap.stride_segmentation import BarthDtw # noqa: F401 assert e.value.object_name == "BarthDtw" assert e.value.module_name == "gaitmap.stride_segmentation" @@ -61,16 +62,16 @@ def test_raises_error_gaitmap_mad_not_installed(_gaitmap_mad_sys_modifier): assert "gaitmap.stride_segmentation" in str(e.value) -def test_gaitmap_mad_version_mismatch(_gaitmap_mad_change_version): +def test_gaitmap_mad_version_mismatch(_gaitmap_mad_change_version) -> None: # We need to make sure that this is a fresh import: sys.modules.pop("gaitmap.stride_segmentation", None) with pytest.raises(AssertionError): - from gaitmap.stride_segmentation import BarthDtw # noqa: unused-import + from gaitmap.stride_segmentation import BarthDtw # noqa: F401 -def test_raises_no_error_gaitmap_mad_installed(): +def test_raises_no_error_gaitmap_mad_installed() -> None: # We need to make sure that this is a fresh import: sys.modules.pop("gaitmap.stride_segmentation", None) - from gaitmap.stride_segmentation import BarthDtw # noqa: unused-import + from gaitmap.stride_segmentation import BarthDtw # noqa: F401 assert True diff --git a/tests/conftest.py b/tests/conftest.py index 2e423046..99947899 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,7 +28,7 @@ @pytest.fixture(autouse=True) -def reset_random_seed(): +def reset_random_seed() -> None: np.random.seed(10) random.seed(10) @@ -39,7 +39,7 @@ def snapshot(request): yield snapshot_test -def pytest_addoption(parser): +def pytest_addoption(parser) -> None: group = parser.getgroup("snapshottest") group.addoption( "--snapshot-update", action="store_true", default=False, dest="snapshot_update", help="Update the snapshots." 
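A related cleanup runs through all the test files: prose-style suppressions like `# noqa: unused-import` are replaced with the actual rule code, since ruff matches `noqa` comments by code (`F401` is pyflakes' "imported but unused"), and with the `RUF` rules enabled, `RUF100` can additionally flag directives that no longer suppress anything. The typical pattern in these tests:

```python
# The import exists purely to assert that it succeeds, so the
# unused-import diagnostic is silenced with its real code:
from gaitmap.stride_segmentation import BarthDtw  # noqa: F401
```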
@@ -59,7 +59,7 @@ def _get_params_without_nested_class(instance: BaseTpcpObject) -> Dict[str, Any] return {k: v for k, v in instance.get_params().items() if not hasattr(v, "get_params")} -def compare_algo_objects(a, b): +def compare_algo_objects(a, b) -> None: parameters = _get_params_without_nested_class(a) b_parameters = _get_params_without_nested_class(b) @@ -70,7 +70,7 @@ def compare_algo_objects(a, b): compare_val(value, json_val, p) -def compare_val(value, json_val, name): +def compare_val(value, json_val, name) -> None: if isinstance(value, BaseTpcpObject): compare_algo_objects(value, json_val) elif isinstance(value, np.ndarray): diff --git a/tests/mixins/test_algorithm_mixin.py b/tests/mixins/test_algorithm_mixin.py index f01d1469..017f9037 100644 --- a/tests/mixins/test_algorithm_mixin.py +++ b/tests/mixins/test_algorithm_mixin.py @@ -1,4 +1,5 @@ """A mixin for all common tests that should be run on all algorithm classes.""" + import inspect import joblib @@ -22,7 +23,7 @@ def valid_instance(self, after_action_instance): def after_action_instance(self) -> BaseType: pass - def test_init(self): + def test_init(self) -> None: """Test that all init paras are passed through untouched.""" field_names = get_param_names(self.algorithm_class) init_dict = {k: k for k in field_names} @@ -32,11 +33,11 @@ def test_init(self): for k, v in init_dict.items(): assert getattr(test_instance, k) == v, k - def test_empty_init(self): + def test_empty_init(self) -> None: """Test that the class has only optional kwargs.""" self.algorithm_class() - def test_all_parameters_documented(self): + def test_all_parameters_documented(self) -> None: docs = NumpyDocString(inspect.getdoc(self.algorithm_class)) documented_names = {p.name for p in docs["Parameters"]} @@ -44,7 +45,7 @@ def test_all_parameters_documented(self): assert documented_names == actual_names - def test_all_attributes_documented(self, after_action_instance): + def test_all_attributes_documented(self, after_action_instance) -> None: if not after_action_instance: pytest.skip("The testclass did not implement the correct `after_action_instance` fixture.") docs = NumpyDocString(inspect.getdoc(self.algorithm_class)) @@ -54,7 +55,7 @@ def test_all_attributes_documented(self, after_action_instance): assert documented_names == actual_names - def test_all_other_parameters_documented(self, after_action_instance): + def test_all_other_parameters_documented(self, after_action_instance) -> None: if not after_action_instance: pytest.skip("The testclass did not implement the correct `after_action_instance` fixture.") docs = NumpyDocString(inspect.getdoc(self.algorithm_class)) @@ -64,7 +65,7 @@ def test_all_other_parameters_documented(self, after_action_instance): assert documented_names == actual_names - def test_action_method_returns_self(self, after_action_instance): + def test_action_method_returns_self(self, after_action_instance) -> None: # call the action method a second time to test the output if not after_action_instance: pytest.skip("The testclass did not implement the correct `after_action_instance` fixture.") @@ -73,7 +74,7 @@ def test_action_method_returns_self(self, after_action_instance): assert id(results) == id(after_action_instance) - def test_set_params_valid(self, valid_instance): + def test_set_params_valid(self, valid_instance) -> None: instance = valid_instance.clone() valid_names = get_param_names(instance) values = list(range(len(valid_names))) @@ -82,7 +83,7 @@ def test_set_params_valid(self, valid_instance): for k, v in 
zip(valid_names, values): assert getattr(instance, k) == v, k - def test_set_params_invalid(self, valid_instance): + def test_set_params_invalid(self, valid_instance) -> None: instance = valid_instance.clone() with pytest.raises(ValueError) as e: @@ -91,7 +92,7 @@ def test_set_params_invalid(self, valid_instance): assert "an_invalid_name" in str(e) assert self.algorithm_class.__name__ in str(e) - def test_json_roundtrip(self, valid_instance): + def test_json_roundtrip(self, valid_instance) -> None: instance = valid_instance.clone() json_str = instance.to_json() @@ -100,13 +101,13 @@ def test_json_roundtrip(self, valid_instance): compare_algo_objects(instance, instance_from_json) - def test_hashing(self, valid_instance): + def test_hashing(self, valid_instance) -> None: """This checks if caching with joblib will work as expected.""" instance = valid_instance.clone() assert joblib.hash(instance) == joblib.hash(instance.clone()) - def test_nested_algo_marked_default(self): + def test_nested_algo_marked_default(self) -> None: init = self.algorithm_class.__init__ if init is object.__init__: # No explicit constructor to introspect diff --git a/tests/mixins/test_caching_mixin.py b/tests/mixins/test_caching_mixin.py index fb8c63ca..e4a32e42 100644 --- a/tests/mixins/test_caching_mixin.py +++ b/tests/mixins/test_caching_mixin.py @@ -1,5 +1,6 @@ import re from tempfile import TemporaryDirectory +from typing import NoReturn import joblib import pytest @@ -19,15 +20,15 @@ class TestCachingMixin: def after_action_instance(self) -> BaseType: pass - def assert_after_action_instance(self, instance): + def assert_after_action_instance(self, instance) -> NoReturn: """Test some aspects of the resulting instance to ensure that results retrieved from cache are correct.""" raise NotImplementedError() - def test_memory_as_params(self, after_action_instance): + def test_memory_as_params(self, after_action_instance) -> None: assert hasattr(after_action_instance, "memory") assert "memory" in after_action_instance.get_params() - def test_cached_call_works(self, after_action_instance, capsys): + def test_cached_call_works(self, after_action_instance, capsys) -> None: parameters = get_action_params(after_action_instance) algo = after_action_instance.clone() @@ -80,7 +81,7 @@ def test_cached_call_works(self, after_action_instance, capsys): assert algo_json == after_first_json == after_second_json assert algo_hash == after_first_hash == after_second_hash - def test_cached_json_export(self): + def test_cached_json_export(self) -> None: """Test that there is a warning on json export.""" instance = self.algorithm_class(memory=Memory(None)) with pytest.warns(UserWarning) as w: diff --git a/tests/test_base.py b/tests/test_base.py index 85691360..608ce59c 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,4 +1,5 @@ """This tests the BaseAlgorithm and fundamental functionality.""" + from inspect import Parameter, signature from typing import Any, Dict, Tuple @@ -11,7 +12,7 @@ def _init_getter(): - def _fake_init(self, **kwargs): + def _fake_init(self, **kwargs) -> None: for k, v in kwargs.items(): setattr(self, k, v) @@ -75,7 +76,7 @@ def example_test_class_after_action(example_test_class_initialised) -> Tuple[Bas return test_instance, params -def test_get_action_method(example_test_class_after_action): +def test_get_action_method(example_test_class_after_action) -> None: instance, test_parameters = example_test_class_after_action assert get_action_methods_names(instance)[0] == test_parameters["action_method_name"] 
@@ -86,25 +87,25 @@ def test_get_action_method(example_test_class_after_action): get_action_method(instance) -def test_get_attributes(example_test_class_after_action): +def test_get_attributes(example_test_class_after_action) -> None: instance, test_parameters = example_test_class_after_action assert get_results(instance) == test_parameters["attributes"] -def test_get_parameter(example_test_class_after_action): +def test_get_parameter(example_test_class_after_action) -> None: instance, test_parameters = example_test_class_after_action assert instance.get_params() == test_parameters["params"] -def test_get_other_parameter(example_test_class_after_action): +def test_get_other_parameter(example_test_class_after_action) -> None: instance, test_parameters = example_test_class_after_action assert get_action_params(instance) == test_parameters["other_params"] -def test_normal_wrong_attr_still_raises_attr_error(example_test_class_initialised): +def test_normal_wrong_attr_still_raises_attr_error(example_test_class_initialised) -> None: instance, test_parameters = example_test_class_initialised key = "not_existend_without_underscore" @@ -118,7 +119,7 @@ def test_normal_wrong_attr_still_raises_attr_error(example_test_class_initialise @pytest.mark.parametrize("key", ["wrong_with_", "wrong_without"]) -def test_attribute_helper_after_action_wrong(example_test_class_after_action, key): +def test_attribute_helper_after_action_wrong(example_test_class_after_action, key) -> None: instance, test_parameters = example_test_class_after_action if not test_parameters["attributes"]: @@ -132,13 +133,13 @@ def test_attribute_helper_after_action_wrong(example_test_class_after_action, ke assert get_action_methods_names(instance)[0] not in str(e.value) -def test_action_is_not_applied(example_test_class_initialised): +def test_action_is_not_applied(example_test_class_initialised) -> None: instance, _ = example_test_class_initialised assert is_action_applied(instance) is False -def test_action_is_applied(example_test_class_after_action): +def test_action_is_applied(example_test_class_after_action) -> None: instance, test_parameters = example_test_class_after_action if not test_parameters["attributes"]: @@ -147,7 +148,7 @@ def test_action_is_applied(example_test_class_after_action): assert is_action_applied(instance) is True -def test_nested_get_params(): +def test_nested_get_params() -> None: nested_instance = create_test_class("nested", params={"nested1": "n1", "nested2": "n2"}) top_level_params = {"test1": "t1"} test_instance = create_test_class("test", params={**top_level_params, "nested_class": nested_instance}) @@ -163,7 +164,7 @@ def test_nested_get_params(): assert params[k] == v -def test_nested_set_params(): +def test_nested_set_params() -> None: nested_instance = create_test_class("nested", params={"nested1": "n1", "nested2": "n2"}) top_level_params = {"test1": "t1"} test_instance = create_test_class("test", params={**top_level_params, "nested_class": nested_instance}) @@ -181,7 +182,7 @@ def test_nested_set_params(): assert params_nested[k] == v -def test_nested_clone(): +def test_nested_clone() -> None: nested_instance = create_test_class("nested", params={"nested1": "n1", "nested2": "n2"}) top_level_params = {"test1": "t1"} test_instance = create_test_class("test", params={**top_level_params, "nested_class": nested_instance}) @@ -199,7 +200,7 @@ def test_nested_clone(): assert cloned_params[k] == v -def test_clone_pomegranate(): +def test_clone_pomegranate() -> None: pytest.importorskip("pomegranate") from 
gaitmap.stride_segmentation.hmm import PreTrainedRothSegmentationModel diff --git a/tests/test_data_transforms/test_base.py b/tests/test_data_transforms/test_base.py index b21e427a..86c6fe88 100644 --- a/tests/test_data_transforms/test_base.py +++ b/tests/test_data_transforms/test_base.py @@ -32,10 +32,10 @@ class TestMetaFunctionality(TestAlgorithmMixin): __test__ = True @pytest.fixture(params=all_base_transformer, autouse=True) - def set_algo_class(self, request): + def set_algo_class(self, request) -> None: self.algorithm_class = request.param - def test_empty_init(self): + def test_empty_init(self) -> None: pytest.skip() @pytest.fixture() def after_action_instance(self, healthy_example_imu_data, healthy_example_stride class TestIdentityTransformer: - def test_transform(self): + def test_transform(self) -> None: t = IdentityTransformer() data = pd.DataFrame(np.random.rand(10, 3)) t.transform(data) @@ -67,7 +67,7 @@ class TestGroupedTransformer: @pytest.mark.parametrize("keep_all_cols", [True, False]) - def test_transform_no_opti(self, keep_all_cols): + def test_transform_no_opti(self, keep_all_cols) -> None: data = pd.DataFrame(np.ones((10, 3)), columns=list("abc")) t = GroupedTransformer( transformer_mapping=[("b", FixedScaler(3)), ("a", FixedScaler(2))], keep_all_cols=keep_all_cols @@ -85,7 +85,7 @@ def test_transform_no_opti(self, keep_all_cols): # Test that the order of columns matches the data assert t.transformed_data_.columns.tolist() == ["a", "b", "c"] if keep_all_cols else ["a", "b"] - def test_multi_scale(self): + def test_multi_scale(self) -> None: data = pd.DataFrame(np.ones((10, 3)), columns=list("abc")) t = GroupedTransformer(transformer_mapping=[(("a", "b", "c"), FixedScaler(3))]) t.transform(data) @@ -93,18 +93,18 @@ def test_multi_scale(self): assert id(t.data) == id(data) assert_frame_equal(t.transformed_data_, data / 3.0) - def test_error_when_transformer_not_unique(self): + def test_error_when_transformer_not_unique(self) -> None: scaler = FixedScaler(3) t = GroupedTransformer(transformer_mapping=[("b", scaler), ("a", scaler)]) with pytest.raises(ValueError): t.self_optimize([pd.DataFrame(np.ones((10, 3)))]) - def test_error_attempting_double_transform(self): + def test_error_attempting_double_transform(self) -> None: t = GroupedTransformer(transformer_mapping=[("b", FixedScaler()), ("b", FixedScaler())]) with pytest.raises(ValueError): t.transform(pd.DataFrame(np.ones((10, 3)))) - def test_optimization(self): + def test_optimization(self) -> None: data = pd.DataFrame(np.ones((10, 3)), columns=list("abc")) scale_vals = [1, 2, 3] train_data = pd.DataFrame(np.ones((10, 3)), columns=list("abc")) * scale_vals @@ -123,7 +123,7 @@ class TestChainedTransformer: - def test_simple_chaining(self): + def test_simple_chaining(self) -> None: data = pd.DataFrame(np.ones((10, 3)), columns=list("abc")) * 2 t = ChainedTransformer(chain=[("first", FixedScaler(3, 1)), ("second", FixedScaler(2))]) make_action_safe(t.transform)(t, data) assert id(t.data) == id(data) assert_frame_equal(t.transformed_data_, (data - 1) / 3 / 2) - def test_chaining_with_training(self): + def test_chaining_with_training(self) -> None: data = pd.DataFrame(np.ones((10, 3)), columns=list("abc")) train_data = pd.DataFrame(np.ones((10, 3)), columns=list("abc")) * 5 # The first scaler is expected to learn the original data scale (5), the second scaler is expected to learn @@ -147,13 +147,13
@@ def test_chaining_with_training(self): assert_frame_equal(t.transformed_data_, data / 5) - def test_error_when_transformer_not_unique(self): + def test_error_when_transformer_not_unique(self) -> None: scaler = FixedScaler(3) t = ChainedTransformer(chain=[("first", scaler), ("second", scaler)]) with pytest.raises(ValueError): t.self_optimize([pd.DataFrame(np.ones((10, 3)))]) - def test_composite_get_set(self): + def test_composite_get_set(self) -> None: t = ChainedTransformer(chain=[("x", FixedScaler()), ("y", FixedScaler(2))]) t.set_params(chain__y__offset=1) params = t.get_params() @@ -163,7 +163,7 @@ def test_composite_get_set(self): class TestParallelTransformer: - def test_simple_parallel(self): + def test_simple_parallel(self) -> None: data = pd.DataFrame(np.ones((10, 3)), columns=list("abc")) t = ParallelTransformer([("x", FixedScaler(2)), ("y", FixedScaler(3))]) @@ -176,7 +176,7 @@ def test_simple_parallel(self): assert_equal(t.transformed_data_.filter(like="x__").to_numpy(), data.to_numpy() / 2) assert_equal(t.transformed_data_.filter(like="y__").to_numpy(), data.to_numpy() / 3) - def test_with_optimization(self): + def test_with_optimization(self) -> None: data = pd.DataFrame(np.ones((10, 3)), columns=list("abc")) train_data = pd.DataFrame(np.ones((10, 3)), columns=list("abc")) * 5 # Both scaler are trained independently and should learn the same thing. @@ -192,19 +192,19 @@ def test_with_optimization(self): assert_equal(t.transformed_data_.filter(like="x__").to_numpy(), data.to_numpy() / 5) assert_equal(t.transformed_data_.filter(like="y__").to_numpy(), data.to_numpy() / 5) - def test_error_when_transformer_not_unique(self): + def test_error_when_transformer_not_unique(self) -> None: scaler = FixedScaler(3) t = ParallelTransformer([("b", scaler), ("a", scaler)]) with pytest.raises(ValueError): t.self_optimize([pd.DataFrame(np.ones((10, 3)))]) - @pytest.mark.parametrize("transformer", (["bla"], [("x", FixedScaler()), ("x", 3)])) - def test_invalid_transformer_mappings(self, transformer): + @pytest.mark.parametrize("transformer", [["bla"], [("x", FixedScaler()), ("x", 3)]]) + def test_invalid_transformer_mappings(self, transformer) -> None: t = ParallelTransformer(transformer) with pytest.raises(TypeError): t.self_optimize([pd.DataFrame(np.ones((10, 3)))]) - def test_composite_get_set(self): + def test_composite_get_set(self) -> None: t = ParallelTransformer(transformers=[("x", FixedScaler()), ("y", FixedScaler(2))]) t.set_params(transformers__y__offset=1) params = t.get_params() diff --git a/tests/test_data_transforms/test_feature_transformer.py b/tests/test_data_transforms/test_feature_transformer.py index 570a84e1..e59bfcd6 100644 --- a/tests/test_data_transforms/test_feature_transformer.py +++ b/tests/test_data_transforms/test_feature_transformer.py @@ -47,7 +47,7 @@ class TestMetaFunctionalityRollingTransforms(TestAlgorithmMixin): @pytest.fixture( params=[*all_rolling_transformer, (SlidingWindowGradient, {"window_size_s": 1}, None)], autouse=True ) - def set_algo_class(self, request): + def set_algo_class(self, request) -> None: self.algorithm_class, self.algo_params, _ = request.param @pytest.fixture() @@ -70,7 +70,7 @@ def df_or_series_imu_data(request, healthy_example_imu_data): class TestResample: - def test_resample(self, df_or_series_imu_data): + def test_resample(self, df_or_series_imu_data) -> None: instance = Resample(target_sampling_rate_hz=10) after_instance = instance.transform(df_or_series_imu_data, sampling_rate_hz=100) assert 
after_instance.transformed_data_.shape[0] == 100 @@ -79,7 +79,7 @@ def test_resample(self, df_or_series_imu_data): assert hasattr(after_instance, "roi_list") is False assert hasattr(after_instance, "transformed_roi_list_") is False - def test_new_sampling_rate_is_equal_to_old(self, df_or_series_imu_data, healthy_example_stride_borders): + def test_new_sampling_rate_is_equal_to_old(self, df_or_series_imu_data, healthy_example_stride_borders) -> None: compare_func = assert_series_equal if isinstance(df_or_series_imu_data, pd.Series) else assert_frame_equal instance = Resample(target_sampling_rate_hz=100) @@ -91,7 +91,7 @@ def test_new_sampling_rate_is_equal_to_old(self, df_or_series_imu_data, healthy_ assert_frame_equal(after_instance.transformed_roi_list_, healthy_example_stride_borders["left_sensor"]) assert after_instance.transformed_roi_list_ is not healthy_example_stride_borders["left_sensor"] - def test_resample_roi_list(self, healthy_example_stride_borders): + def test_resample_roi_list(self, healthy_example_stride_borders) -> None: in_stride_borders = healthy_example_stride_borders["left_sensor"].astype({"start": "Int64", "end": "Int64"}) instance = Resample(target_sampling_rate_hz=10) after_instance = instance.transform(roi_list=in_stride_borders, sampling_rate_hz=100) @@ -110,7 +110,7 @@ def test_resample_roi_list(self, healthy_example_stride_borders): assert hasattr(after_instance, "transformed_data_") is False assert hasattr(after_instance, "data") is False - def test_resample_data_and_roi_list(self, df_or_series_imu_data, healthy_example_stride_borders): + def test_resample_data_and_roi_list(self, df_or_series_imu_data, healthy_example_stride_borders) -> None: in_stride_borders = healthy_example_stride_borders["left_sensor"] instance = Resample(target_sampling_rate_hz=10) after_instance = instance.transform(df_or_series_imu_data, roi_list=in_stride_borders, sampling_rate_hz=100) @@ -119,14 +119,14 @@ def test_resample_data_and_roi_list(self, df_or_series_imu_data, healthy_example assert after_instance.roi_list is in_stride_borders assert after_instance.data is df_or_series_imu_data - def test_require_sampling_rate(self): + def test_require_sampling_rate(self) -> None: instance = Resample(target_sampling_rate_hz=10) with pytest.raises(ValueError) as e: instance.transform() assert "sampling_rate_hz" in str(e.value) - def test_require_target_sampling_rate(self): + def test_require_target_sampling_rate(self) -> None: instance = Resample() with pytest.raises(ValueError) as e: instance.transform(sampling_rate_hz=10) @@ -141,19 +141,19 @@ class _TestSlidingWindowTransformer: @pytest.mark.parametrize( ("window_size_s", "effective_win_size"), [(1, 101), (0.5, 51), (0.1, 11), (0.23, 23), (0.111, 11)] ) - def test_effective_window_size_samples(self, healthy_example_imu_data, window_size_s, effective_win_size): + def test_effective_window_size_samples(self, healthy_example_imu_data, window_size_s, effective_win_size) -> None: data_left = healthy_example_imu_data["left_sensor"].iloc[:100] data_left.columns = BF_COLS instance = self.algorithm_class(window_size_s=window_size_s) after_instance = instance.transform(data_left, sampling_rate_hz=100) assert after_instance.effective_window_size_samples_ == effective_win_size - def test_window_size_s_required(self): + def test_window_size_s_required(self) -> None: with pytest.raises(ValueError) as e: self.algorithm_class().transform([], sampling_rate_hz=100) assert "window_size_s" in str(e.value) - def test_sampling_rate_required(self): + def 
test_sampling_rate_required(self) -> None: with pytest.raises(ValueError) as e: self.algorithm_class(window_size_s=1).transform([]) assert "sampling_rate_hz" in str(e.value) @@ -167,11 +167,11 @@ class TestSlidingWindowTransformers(_TestSlidingWindowTransformer): equivalent_method: Callable @pytest.fixture(params=all_rolling_transformer, autouse=True) - def set_algo_class(self, request): + def set_algo_class(self, request) -> None: self.algorithm_class, self.algo_params, self.equivalent_method = request.param @pytest.mark.parametrize("win_size_s", [0.1, 0.2]) - def test_output(self, df_or_series_imu_data, win_size_s): + def test_output(self, df_or_series_imu_data, win_size_s) -> None: """Test the output shape and the values for the first couple of windows.""" instance = self.algorithm_class(window_size_s=win_size_s) after_instance = instance.transform(df_or_series_imu_data, sampling_rate_hz=100) @@ -219,7 +219,7 @@ class TestSlidingWindowGradient(_TestSlidingWindowTransformer): algorithm_class = SlidingWindowGradient @pytest.mark.parametrize("win_size_s", [0.1, 0.2]) - def test_output(self, df_or_series_imu_data, win_size_s): + def test_output(self, df_or_series_imu_data, win_size_s) -> None: """Test the output shape and the values for the first couple of windows.""" instance = SlidingWindowGradient(window_size_s=win_size_s) after_instance = instance.transform(df_or_series_imu_data, sampling_rate_hz=100) diff --git a/tests/test_data_transforms/test_filter.py b/tests/test_data_transforms/test_filter.py index a91ccd21..66b60b40 100644 --- a/tests/test_data_transforms/test_filter.py +++ b/tests/test_data_transforms/test_filter.py @@ -14,7 +14,7 @@ class TestButterworthMetaFunctionality(TestAlgorithmMixin): algorithm_class = ButterworthFilter - def test_empty_init(self): + def test_empty_init(self) -> None: pytest.skip() @pytest.fixture() @@ -27,8 +27,8 @@ def after_action_instance(self, healthy_example_imu_data, healthy_example_stride class TestButterworth: - @pytest.mark.parametrize("in_val", (pd.DataFrame(np.random.rand(50, 3)), pd.Series(np.random.rand(50)))) - def test_input_type_and_shape_conserved(self, in_val): + @pytest.mark.parametrize("in_val", [pd.DataFrame(np.random.rand(50, 3)), pd.Series(np.random.rand(50))]) + def test_input_type_and_shape_conserved(self, in_val) -> None: filter = ButterworthFilter(1, 5) before_dtype = type(in_val) @@ -49,7 +49,7 @@ def test_input_type_and_shape_conserved(self, in_val): assert filter.transformed_data_.shape == before_shape assert isinstance(filter.transformed_data_, before_dtype) - def test_filter_is_applied_correctly(self): + def test_filter_is_applied_correctly(self) -> None: filter = ButterworthFilter(1, 5, "highpass") data = pd.DataFrame(np.random.rand(50, 3)) sampling_rate_hz = 100 diff --git a/tests/test_data_transforms/test_scalers.py b/tests/test_data_transforms/test_scalers.py index 338c0cbb..68af7d4c 100644 --- a/tests/test_data_transforms/test_scalers.py +++ b/tests/test_data_transforms/test_scalers.py @@ -32,7 +32,7 @@ class TestMetaFunctionality(TestAlgorithmMixin): __test__ = True @pytest.fixture(params=all_scaler, autouse=True) - def set_algo_class(self, request): + def set_algo_class(self, request) -> None: self.algorithm_class = request.param @pytest.fixture() @@ -47,7 +47,7 @@ def after_action_instance(self, healthy_example_imu_data, healthy_example_stride class TestFixedScaler: - def test_transform_default(self): + def test_transform_default(self) -> None: t = FixedScaler() data = pd.DataFrame(np.random.rand(10, 10)) 
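# FixedScaler evidently applies (data - offset) / scale (cf. the (data - 1) / 3 / 2 assertion in the chaining test above); with the presumed identity defaults (scale=1, offset=0) the transform should leave the data untouched, which is what the equality assertion below checks.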
t.transform(data) @@ -56,7 +56,7 @@ def test_transform_default(self): assert t.transformed_data_.equals(data) @pytest.mark.parametrize(("scale", "offset"), [(1, 1), (2, 1), (3, 2)]) - def test_transform(self, scale, offset): + def test_transform(self, scale, offset) -> None: t = FixedScaler(scale=scale, offset=offset) data = pd.DataFrame(np.random.rand(10, 10)) t.transform(data) @@ -66,7 +66,7 @@ def test_transform(self, scale, offset): class TestStandardScaler: @pytest.mark.parametrize("ddof", [0, 1, 2]) - def test_transform(self, ddof): + def test_transform(self, ddof) -> None: t = StandardScaler(ddof=ddof) data = pd.DataFrame(np.random.rand(10, 10)) t.transform(data) @@ -77,7 +77,7 @@ def test_transform(self, ddof): class TestTrainableStandardScaler: @pytest.mark.parametrize("ddof", [0, 1, 2]) - def test_transform(self, ddof): + def test_transform(self, ddof) -> None: t = TrainableStandardScaler(ddof=ddof) train_data = pd.DataFrame(np.random.rand(10, 10)) test_data = pd.DataFrame(np.random.rand(10, 10)) @@ -100,7 +100,7 @@ def test_transform(self, ddof): (test_data - train_data.to_numpy().mean()) / train_data.to_numpy().std(ddof=ddof) ) - def test_iterative_std_calculation(self): + def test_iterative_std_calculation(self) -> None: test_data = [pd.DataFrame(np.random.rand(10, 10)) for _ in range(5)] test_data_concatenated = pd.concat(test_data) t = TrainableStandardScaler() @@ -109,7 +109,7 @@ def test_iterative_std_calculation(self): assert t.std == pytest.approx(test_data_concatenated.to_numpy().std(ddof=t.ddof), rel=1e-5) - def test_raise_error_before_optimization(self): + def test_raise_error_before_optimization(self) -> None: t = TrainableStandardScaler() with pytest.raises(ValueError): t.transform(pd.DataFrame(np.random.rand(10, 10))) @@ -118,7 +118,7 @@ def test_raise_error_before_optimization(self): class TestAbsMaxScaler: @pytest.mark.parametrize("out_max", [2, 3, 0.3]) @pytest.mark.parametrize("data_factor", [1, -2, 0.3]) - def test_transform(self, out_max, data_factor): + def test_transform(self, out_max, data_factor) -> None: t = AbsMaxScaler(out_max=out_max) data = pd.DataFrame([[0, 1, 1], [2, 1, 3]]) * data_factor t.transform(data) @@ -129,7 +129,7 @@ def test_transform(self, out_max, data_factor): class TestTrainableAbsMaxScaler: @pytest.mark.parametrize("out_max", [2, 3, 0.3]) - def test_transform(self, out_max): + def test_transform(self, out_max) -> None: t = TrainableAbsMaxScaler(out_max=out_max) train_data = pd.DataFrame(np.random.rand(10, 10)) test_data = pd.DataFrame(np.random.rand(10, 10)) @@ -148,7 +148,7 @@ def test_transform(self, out_max): assert id(t.data) == id(test_data) assert_frame_equal(t.transformed_data_, test_data / np.max(np.abs(train_data.to_numpy())) * out_max) - def test_raise_error_before_optimization(self): + def test_raise_error_before_optimization(self) -> None: t = TrainableAbsMaxScaler() with pytest.raises(ValueError): t.transform(pd.DataFrame(np.random.rand(10, 10))) @@ -156,7 +156,7 @@ def test_raise_error_before_optimization(self): class TestMinMaxScaler: @pytest.mark.parametrize("out_range", [(0, 1), (0, 2), (-1, 1), (-1, 2)]) - def test_transform(self, out_range): + def test_transform(self, out_range) -> None: t = MinMaxScaler(out_range=out_range) data = pd.DataFrame(np.random.rand(10, 10)) t.transform(data) @@ -165,7 +165,7 @@ def test_transform(self, out_range): assert t.transformed_data_.to_numpy().max() == pytest.approx(out_range[1], rel=1e-3) @pytest.mark.parametrize("out_range", [(0, 0), (1, -1), (2, 2)]) - def 
test_raise_error_for_invalid_out_range(self, out_range): + def test_raise_error_for_invalid_out_range(self, out_range) -> None: data = pd.DataFrame(np.random.rand(10, 10)) t = MinMaxScaler(out_range=out_range) with pytest.raises(ValueError): @@ -174,7 +174,7 @@ def test_raise_error_for_invalid_out_range(self, out_range): class TestTrainableMinMaxScaler: @pytest.mark.parametrize("out_range", [(0, 1), (0, 2), (-1, 1), (-1, 2)]) - def test_transform(self, out_range): + def test_transform(self, out_range) -> None: t = TrainableMinMaxScaler(out_range=out_range) train_data = pd.DataFrame(np.random.rand(10, 10)) test_data = pd.DataFrame(np.random.rand(10, 10)) @@ -200,7 +200,7 @@ def test_transform(self, out_range): + out_range[0], ) - def test_raise_error_before_optimization(self): + def test_raise_error_before_optimization(self) -> None: t = TrainableMinMaxScaler() with pytest.raises(ValueError): t.transform(pd.DataFrame(np.random.rand(10, 10))) diff --git a/tests/test_evaluation_utlis/test_event_detection.py b/tests/test_evaluation_utlis/test_event_detection.py index 5b943499..f879d58e 100644 --- a/tests/test_evaluation_utlis/test_event_detection.py +++ b/tests/test_evaluation_utlis/test_event_detection.py @@ -20,8 +20,8 @@ def _create_valid_list(self, labels, extra_columns=None): return df - @pytest.mark.parametrize("value", (["1", "2", "3"], "wrong_column", "")) - def test_invalid_column_values(self, value): + @pytest.mark.parametrize("value", [["1", "2", "3"], "wrong_column", ""]) + def test_invalid_column_values(self, value) -> None: sl = self._create_valid_list([[0, 1, 10], [1, 2, 20]], "ic") with pytest.raises(ValueError) as e: @@ -30,7 +30,7 @@ def test_invalid_column_values(self, value): assert "One or more selected columns" in str(e.value) assert str(value) in str(e.value) - def test_perfect_match(self): + def test_perfect_match(self) -> None: sl = self._create_valid_list([[0, 1, 10], [1, 2, 20], [2, 3, 30]], "ic") out = evaluate_stride_event_list(ground_truth=sl, stride_event_list=sl, match_cols="ic", tolerance=0) @@ -42,7 +42,7 @@ def test_perfect_match(self): assert len(out["fn"]) == 0 assert len(out) == (len(out["tp"]) + len(out["fn"])) - def test_match(self): + def test_match(self) -> None: sl1 = self._create_valid_list([[0, 1, 0], [1, 2, 20], [2, 3, 30]], "ic") sl2 = self._create_valid_list([[0, 1, 10], [1, 2, 20], [2, 3, 30]], "ic") diff --git a/tests/test_evaluation_utlis/test_parameter_errors.py b/tests/test_evaluation_utlis/test_parameter_errors.py index 4ab900e0..be396174 100644 --- a/tests/test_evaluation_utlis/test_parameter_errors.py +++ b/tests/test_evaluation_utlis/test_parameter_errors.py @@ -3,6 +3,7 @@ NOTE: I decided not to check every single error value and trust that the internal functions (as we are just calling pandas functions) handle calculation of the error correctly.
""" + import doctest import numpy as np @@ -72,7 +73,7 @@ class TestCalculateAggregatedParameterErrors: ), ], ) - def test_invalid_input(self, input_param, ground_truth, expected_error): + def test_invalid_input(self, input_param, ground_truth, expected_error) -> None: with pytest.raises(ValidationError) as e: calculate_aggregated_parameter_errors(predicted_parameter=input_param, reference_parameter=ground_truth) @@ -110,7 +111,7 @@ def test_invalid_input(self, input_param, ground_truth, expected_error): ), ], ) - def test_valid_single_sensor_input(self, input_param, ground_truth, expectation): + def test_valid_single_sensor_input(self, input_param, ground_truth, expectation) -> None: output_normal = calculate_aggregated_parameter_errors( predicted_parameter=input_param, reference_parameter=ground_truth ) @@ -163,7 +164,7 @@ def test_valid_single_sensor_input(self, input_param, ground_truth, expectation) ), ], ) - def test_valid_multi_sensor_input(self, input_param, ground_truth, sensor_names, expectations): + def test_valid_multi_sensor_input(self, input_param, ground_truth, sensor_names, expectations) -> None: output_normal = calculate_aggregated_parameter_errors( predicted_parameter=input_param, reference_parameter=ground_truth ) @@ -207,7 +208,7 @@ def test_valid_multi_sensor_input(self, input_param, ground_truth, sensor_names, ), ], ) - def test_calculate_not_per_sensor_input(self, input_param, ground_truth, expectation): + def test_calculate_not_per_sensor_input(self, input_param, ground_truth, expectation) -> None: output_normal = calculate_aggregated_parameter_errors( predicted_parameter=input_param, reference_parameter=ground_truth, calculate_per_sensor=False ) @@ -216,7 +217,7 @@ def test_calculate_not_per_sensor_input(self, input_param, ground_truth, expecta assert_array_equal(np.round(output_normal.loc[error_type], 5), expectation[error_type]) @pytest.mark.parametrize("per_sensor", [True, False]) - def test_n_strides_missing(self, per_sensor): + def test_n_strides_missing(self, per_sensor) -> None: input_param = _create_valid_input(["param"], [[1, 2, 3], [4, 5, np.nan]], is_dict=True, sensors=["1", "2"]) ground_truth = _create_valid_input( ["param"], [[1, np.nan, np.nan], [4, 5, 6]], is_dict=True, sensors=["1", "2"] @@ -239,7 +240,7 @@ def test_n_strides_missing(self, per_sensor): assert output["param"].loc["n_additional_predicted"] == 2 @pytest.mark.parametrize("single_sensor", [True, False]) - def test_n_strides_missing_multi_param(self, single_sensor): + def test_n_strides_missing_multi_param(self, single_sensor) -> None: if single_sensor: input_param = _create_valid_input( ["param1", "param2"], [[[1, 2, 3], [4, 5, np.nan]]], is_dict=True, sensors=["1"] @@ -288,12 +289,12 @@ def test_n_strides_missing_multi_param(self, single_sensor): assert param2.loc["n_additional_reference"] == 1 assert param2.loc["n_additional_predicted"] == 0 - def test_doctest(self): + def test_doctest(self) -> None: pytest.importorskip("statsmodels") doctest_results = doctest.testmod(m=parameter_errors) assert doctest_results.failed == 0 - def test_calculate_per_sensor(self): + def test_calculate_per_sensor(self) -> None: input_param = _create_valid_input(["param"], [1, 2, 3], is_dict=False) ground_truth = _create_valid_input(["param"], [1, 2, 3], is_dict=False) with_per_sensor = calculate_aggregated_parameter_errors( @@ -310,7 +311,7 @@ def test_calculate_per_sensor(self): class TestCalculateParameterErrors: # We don't test a lot here, as this method is used internally by 
calculate_aggregated_parameter_errors # So all the tests are done there - def test_simple(self): + def test_simple(self) -> None: predicted_parameter = _create_valid_input( ["param1", "param2"], [[[1, 2, 3], [4, 5, 6]]], is_dict=True, sensors=["1"] ) @@ -342,7 +343,7 @@ def test_simple(self): "n_additional_predicted", } - def test_simple_non_dict(self): + def test_simple_non_dict(self) -> None: predicted_parameter = _create_valid_input( ["param1", "param2"], np.array([[1, 2, 3], [4, 5, 6]]).T, is_dict=False ) diff --git a/tests/test_evaluation_utlis/test_scores.py b/tests/test_evaluation_utlis/test_scores.py index b12ea990..1ca0101e 100644 --- a/tests/test_evaluation_utlis/test_scores.py +++ b/tests/test_evaluation_utlis/test_scores.py @@ -22,112 +22,112 @@ def _create_valid_matches_df(self, tp, fp, fn): return pd.concat([tp_df, fp_df, fn_df]) - def test_precision_single(self): + def test_precision_single(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13]) precision = precision_score(matches_df) assert_array_equal(precision, 0.6) - def test_perfect_precision_single(self): + def test_perfect_precision_single(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [], [10, 11, 12, 13]) precision = precision_score(matches_df) assert_array_equal(precision, 1.0) - def test_precision_multi(self): + def test_precision_multi(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13]) precision = precision_score({"sensor": matches_df}) assert_array_equal(precision["sensor"], 0.6) - def test_perfect_precision_multi(self): + def test_perfect_precision_multi(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [], [10, 11, 12, 13]) precision = precision_score({"sensor": matches_df}) assert_array_equal(precision["sensor"], 1.0) - def test_recall_single(self): + def test_recall_single(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13]) recall = recall_score(matches_df) assert_array_equal(recall, 0.6) - def test_perfect_recall_single(self): + def test_perfect_recall_single(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [6, 7, 8, 9], []) recall = recall_score(matches_df) assert_array_equal(recall, 1.0) - def test_recall_multi(self): + def test_recall_multi(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13]) recall = recall_score({"sensor": matches_df}) assert_array_equal(recall["sensor"], 0.6) - def test_perfect_recall_multi(self): + def test_perfect_recall_multi(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [6, 7, 8, 9], []) recall = recall_score({"sensor": matches_df}) assert_array_equal(recall["sensor"], 1.0) - def test_f1_score_single(self): + def test_f1_score_single(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13]) f1 = f1_score(matches_df) assert_array_equal(f1, 0.6) - def test_perfect_f1_score_single(self): + def test_perfect_f1_score_single(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [], []) f1 = f1_score(matches_df) assert_array_equal(f1, 1.0) - def test_f1_score_multi(self): + def test_f1_score_multi(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13]) f1 = f1_score({"sensor": matches_df}) 
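# Worked check, assuming the standard definitions: with 6 tp, 4 fp, and 4 fn, precision = recall = 6 / (6 + 4) = 0.6, so f1 = 2 * 0.6 * 0.6 / (0.6 + 0.6) = 0.6, which is the value asserted below.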
assert_array_equal(f1["sensor"], 0.6) - def test_perfect_f1_score_multi(self): + def test_perfect_f1_score_multi(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [], []) f1 = f1_score({"sensor": matches_df}) assert_array_equal(f1["sensor"], 1.0) - def test_precision_recall_f1_single(self): + def test_precision_recall_f1_single(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13]) eval_metrics = precision_recall_f1_score(matches_df) assert_array_equal(list(eval_metrics.values()), [0.6, 0.6, 0.6]) - def test_perfect_precision_recall_f1_single(self): + def test_perfect_precision_recall_f1_single(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [], []) eval_metrics = precision_recall_f1_score(matches_df) assert_array_equal(list(eval_metrics.values()), [1.0, 1.0, 1.0]) - def test_precision_recall_f1_multi(self): + def test_precision_recall_f1_multi(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13]) eval_metrics = precision_recall_f1_score({"sensor": matches_df}) assert_array_equal(list(eval_metrics["sensor"].values()), [0.6, 0.6, 0.6]) - def test_perfect_precision_recall_f1_multi(self): + def test_perfect_precision_recall_f1_multi(self) -> None: matches_df = self._create_valid_matches_df([0, 1, 2, 3, 4, 5], [], []) eval_metrics = precision_recall_f1_score({"sensor": matches_df}) @@ -169,10 +169,10 @@ class TestDivisionByZeroReturn: [precision_recall_f1_score, [[], [], []], 1, 1], ), ) - def make_methods(self, request): + def make_methods(self, request) -> None: self.func, self.arguments, self.zero_division, self.expected_output = request.param - def test_division_by_zero_return(self): + def test_division_by_zero_return(self) -> None: matches_df = _create_valid_matches_df(*self.arguments) eval_metrics = self.func(matches_df, zero_division=self.zero_division) @@ -193,10 +193,10 @@ class TestDivisionByZeroWarnings: [precision_recall_f1_score, [[], [], []], "warn", "calculating the f1 score"], ), ) - def make_methods(self, request): + def make_methods(self, request) -> None: self.func, self.arguments, self.zero_division, self.warning_message = request.param - def test_division_by_zero_warnings(self): + def test_division_by_zero_warnings(self) -> None: with pytest.warns(UserWarning) as w: self.func(_create_valid_matches_df(*self.arguments), zero_division=self.zero_division) @@ -218,10 +218,10 @@ class TestDivisionByZeroError: [precision_recall_f1_score, [[], [], []], 2], ), ) - def make_methods(self, request): + def make_methods(self, request) -> None: self.func, self.arguments, self.zero_division = request.param - def test_division_by_zero_warnings(self): + def test_division_by_zero_warnings(self) -> None: with pytest.raises(ValueError) as e: self.func(_create_valid_matches_df(*self.arguments), zero_division=self.zero_division) diff --git a/tests/test_evaluation_utlis/test_stride_segmentation.py b/tests/test_evaluation_utlis/test_stride_segmentation.py index f8ec1230..788d3fb8 100644 --- a/tests/test_evaluation_utlis/test_stride_segmentation.py +++ b/tests/test_evaluation_utlis/test_stride_segmentation.py @@ -15,7 +15,7 @@ def _create_valid_list(self, labels): df.index.name = "s_id" return df - def test_invalid_stride_list(self): + def test_invalid_stride_list(self) -> None: sl = self._create_valid_list([[0, 1], [1, 2]]) with pytest.raises(ValidationError) as e: @@ -28,7 +28,7 @@ def test_invalid_stride_list(self): assert 
"SingleSensorStrideList" in str(e.value) - def test_invalid_postfix(self): + def test_invalid_postfix(self) -> None: sl = self._create_valid_list([[0, 1], [1, 2]]) with pytest.raises(ValueError) as e: @@ -36,7 +36,7 @@ def test_invalid_postfix(self): assert "The postfix" in str(e) - def test_invalid_tolerance(self): + def test_invalid_tolerance(self) -> None: sl = self._create_valid_list([[0, 1], [1, 2]]) with pytest.raises(ValueError) as e: @@ -44,15 +44,15 @@ def test_invalid_tolerance(self): assert "larger 0" in str(e) - def test_change_postfix(self): + def test_change_postfix(self) -> None: sl = self._create_valid_list([[0, 1], [1, 2]]) out = match_stride_lists(stride_list_a=sl, stride_list_b=sl, postfix_a="_a_different", postfix_b="_b_different") assert_array_equal(list(out.columns), ["s_id_a_different", "s_id_b_different"]) - @pytest.mark.parametrize("one_to_one", (True, False)) - def test_simple_one_to_one_match_tolerance(self, one_to_one): + @pytest.mark.parametrize("one_to_one", [True, False]) + def test_simple_one_to_one_match_tolerance(self, one_to_one) -> None: list_left = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) list_right = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) @@ -61,7 +61,7 @@ def test_simple_one_to_one_match_tolerance(self, one_to_one): assert_array_equal(out["s_id_a"].to_numpy(), [0, 1, 2, 3]) assert_array_equal(out["s_id_b"].to_numpy(), [0, 1, 2, 3]) - def test_simple_match_with_tolerance(self): + def test_simple_match_with_tolerance(self) -> None: list_left = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) list_left += 0.1 list_right = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) @@ -74,7 +74,7 @@ def test_simple_match_with_tolerance(self): assert_array_equal(out["s_id_a"].to_numpy(), [0, 1, 2, 3]) assert_array_equal(out["s_id_b"].to_numpy(), [0, 1, 2, 3]) - def test_simple_missing_strides_no_tolerance(self): + def test_simple_missing_strides_no_tolerance(self) -> None: list_left = self._create_valid_list([[0, 1], [2, 3], [3, 4]]) list_right = self._create_valid_list([[0, 1], [1, 2], [2, 3]]) @@ -83,7 +83,7 @@ def test_simple_missing_strides_no_tolerance(self): assert_array_equal(out["s_id_a"].to_numpy().astype(float), [0.0, 1, 2, np.nan]) assert_array_equal(out["s_id_b"].to_numpy().astype(float), [0.0, 2, np.nan, 1]) - def test_simple_double_match_no_tolerance(self): + def test_simple_double_match_no_tolerance(self) -> None: list_left = self._create_valid_list([[0, 1], [1, 2], [1, 2], [2, 3], [3, 4]]) list_right = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4], [3, 4]]) @@ -92,7 +92,7 @@ def test_simple_double_match_no_tolerance(self): assert_array_equal(out["s_id_a"].to_numpy(), [0, 1, 2, 3, 4, 4]) assert_array_equal(out["s_id_b"].to_numpy(), [0, 1, 1, 2, 3, 4]) - def test_simple_double_match_no_tolerance_enforce_one_to_one(self): + def test_simple_double_match_no_tolerance_enforce_one_to_one(self) -> None: list_left = self._create_valid_list([[0, 1], [1, 2], [1, 2], [2, 3], [3, 4]]) list_right = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4], [3, 4]]) @@ -101,7 +101,7 @@ def test_simple_double_match_no_tolerance_enforce_one_to_one(self): assert_array_equal(out["s_id_a"].to_numpy().astype(float), [0, 1, 2, 3, 4, np.nan]) assert_array_equal(out["s_id_b"].to_numpy().astype(float), [0, 1, np.nan, 2, 3, 4]) - def test_double_match_with_tolerance_enforce_one_to_one(self): + def test_double_match_with_tolerance_enforce_one_to_one(self) -> None: list_left = self._create_valid_list([[0, 1], [1.1, 
2.1], [1, 2], [2, 3], [3, 4]]) list_right = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3.1, 3.9], [3, 4]]) @@ -110,7 +110,7 @@ def test_double_match_with_tolerance_enforce_one_to_one(self): assert_array_equal(out["s_id_a"].to_numpy().astype(float), [0, 1, 2, 3, 4, np.nan]) assert_array_equal(out["s_id_b"].to_numpy().astype(float), [0, np.nan, 1, 2, 4, 3]) - def test_one_sided_double_match_no_tolerance_enforce_one_to_one(self): + def test_one_sided_double_match_no_tolerance_enforce_one_to_one(self) -> None: list_left = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4], [1, 2]]) list_right = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) @@ -119,9 +119,9 @@ def test_one_sided_double_match_no_tolerance_enforce_one_to_one(self): assert_array_equal(out["s_id_a"].to_numpy().astype(float), [0, 1, 2, 3, 4]) assert_array_equal(out["s_id_b"].to_numpy().astype(float), [0, 1, 2, 3, np.nan]) - @pytest.mark.parametrize("side", ("a", "b")) - def test_empty_stride_lists(self, side): - opposite = [s for s in ("a", "b") if s != side][0] + @pytest.mark.parametrize("side", ["a", "b"]) + def test_empty_stride_lists(self, side) -> None: + opposite = next(s for s in ("a", "b") if s != side) sl = self._create_valid_list([[0, 1], [1, 2], [2, 3]]) empty = self._create_valid_list([]) @@ -133,14 +133,14 @@ def test_empty_stride_lists(self, side): assert_array_equal(out["s_id_" + side].to_numpy().astype(float), [np.nan, np.nan, np.nan]) assert_array_equal(out["s_id_" + opposite].to_numpy().astype(float), [0.0, 1, 2]) - def test_empty_stride_lists_both(self): + def test_empty_stride_lists_both(self) -> None: empty = self._create_valid_list([]) out = match_stride_lists(stride_list_a=empty, stride_list_b=empty) assert len(out) == 0 - def test_multi_stride_lists_no_tolerance(self): + def test_multi_stride_lists_no_tolerance(self) -> None: stride_list_left_a = self._create_valid_list([[0, 1], [2, 3], [4, 5], [6, 7]]) stride_list_right_a = self._create_valid_list([[1, 2], [3, 4], [5, 6]]) multi_stride_list_a = {"left_sensor": stride_list_left_a, "right_sensor": stride_list_right_a} @@ -157,7 +157,7 @@ def test_multi_stride_lists_no_tolerance(self): assert_array_equal(out["right_sensor"]["s_id_a"].to_numpy().astype(float), [0, 1, 2]) assert_array_equal(out["right_sensor"]["s_id_b"].to_numpy().astype(float), [1, 0, 2]) - def test_multi_stride_lists_with_tolerance(self): + def test_multi_stride_lists_with_tolerance(self) -> None: stride_list_left_a = self._create_valid_list([[0, 1], [2, 3], [4, 5], [6, 7]]) stride_list_right_a = self._create_valid_list([[1, 2], [3, 4], [5, 6]]) multi_stride_list_a = {"left_sensor": stride_list_left_a, "right_sensor": stride_list_right_a} @@ -174,7 +174,7 @@ def test_multi_stride_lists_with_tolerance(self): assert_array_equal(out["right_sensor"]["s_id_a"].to_numpy().astype(float), [0, 1, 2, np.nan]) assert_array_equal(out["right_sensor"]["s_id_b"].to_numpy().astype(float), [0, 1, np.nan, 2]) - def test_empty_multi_stride_lists_both(self): + def test_empty_multi_stride_lists_both(self) -> None: empty = self._create_valid_list([]) out = match_stride_lists(stride_list_a={"left": empty}, stride_list_b={"left": empty}) @@ -182,7 +182,7 @@ def test_empty_multi_stride_lists_both(self): for dataframe in out.values(): assert dataframe.empty - def test_empty_multi_stride_lists(self): + def test_empty_multi_stride_lists(self) -> None: full = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) empty = self._create_valid_list([]) @@ -201,7 +201,7 @@ def 
test_empty_multi_stride_lists(self): assert "object does not contain any data/contains no sensors" in str(e.value) - def test_one_multi_one_single_list(self): + def test_one_multi_one_single_list(self) -> None: multi = {"sensor": self._create_valid_list([[0, 1], [2, 3], [4, 5], [6, 7]])} single = self._create_valid_list([[1, 2], [3, 4], [5, 6]]) @@ -215,7 +215,7 @@ def test_one_multi_one_single_list(self): assert "not of same type" in str(e) - def test_no_common_sensors_multi_stride_lists(self): + def test_no_common_sensors_multi_stride_lists(self) -> None: full = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) with pytest.raises(ValidationError) as e: @@ -223,7 +223,7 @@ def test_no_common_sensors_multi_stride_lists(self): assert "do not have any common sensors" in str(e) - def test_some_common_sensors_multi_stride_lists(self): + def test_some_common_sensors_multi_stride_lists(self) -> None: stride_list_left_a = self._create_valid_list([[0, 1], [2, 3], [4, 5], [6, 7]]) stride_list_right_a = self._create_valid_list([[1, 2], [3, 4], [5, 6]]) multi_stride_list_a = {"left_sensor": stride_list_left_a, "right_sensor": stride_list_right_a} @@ -258,8 +258,8 @@ def _create_valid_list(self, labels, extra_columns=None): df.index.name = "s_id" return df - @pytest.mark.parametrize("value", ("wrong_column", ["1", "2", "3"])) - def test_invalid_match_cols(self, value): + @pytest.mark.parametrize("value", ["wrong_column", ["1", "2", "3"]]) + def test_invalid_match_cols(self, value) -> None: sl = self._create_valid_list([[0, 1, 10], [1, 2, 20]], "ic") with pytest.raises(ValueError) as e: @@ -268,7 +268,7 @@ def test_invalid_match_cols(self, value): assert "One or more selected columns" in str(e.value) assert str(value) in str(e.value) - def test_perfect_match(self): + def test_perfect_match(self) -> None: sl = self._create_valid_list([[0, 1, 10], [1, 2, 20], [2, 3, 30]], "ic") out = match_stride_lists(stride_list_a=sl, stride_list_b=sl, match_cols="ic", tolerance=0) @@ -276,7 +276,7 @@ def test_perfect_match(self): assert_array_equal(out["s_id_a"].to_numpy(), [0, 1, 2]) assert_array_equal(out["s_id_b"].to_numpy(), [0, 1, 2]) - def test_match(self): + def test_match(self) -> None: sl1 = self._create_valid_list([[0, 1, 0], [1, 2, 20], [2, 3, 30]], "ic") sl2 = self._create_valid_list([[0, 1, 10], [1, 2, 20], [2, 3, 30]], "ic") @@ -346,7 +346,7 @@ def test_match(self): ), ], ) - def test_match_label_lists_edgecases(self, input_param, ground_truth, tolerance, one_to_one, expectation): + def test_match_label_lists_edgecases(self, input_param, ground_truth, tolerance, one_to_one, expectation) -> None: output1, output2 = _match_label_lists(input_param, ground_truth, tolerance, one_to_one) assert_array_equal(output1, expectation[0]) @@ -359,14 +359,14 @@ def _create_valid_list(self, labels): df.index.name = "s_id" return df - def test_segmented_stride_list_perfect_match(self): + def test_segmented_stride_list_perfect_match(self) -> None: list_ground_truth = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) list_predicted = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) matches = evaluate_segmented_stride_list(ground_truth=list_ground_truth, segmented_stride_list=list_predicted) assert np.all(matches["match_type"] == "tp") - def test_segmented_stride_list_empty_ground_truth(self): + def test_segmented_stride_list_empty_ground_truth(self) -> None: list_ground_truth = self._create_valid_list([]) list_predicted = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) matches = 
evaluate_segmented_stride_list(ground_truth=list_ground_truth, segmented_stride_list=list_predicted) @@ -382,7 +382,7 @@ def test_segmented_stride_list_empty_ground_truth(self): ) assert len(list_ground_truth) == (len(matches["tp"]) + len(matches["fn"])) - def test_segmented_stride_list_empty_prediction(self): + def test_segmented_stride_list_empty_prediction(self) -> None: list_ground_truth = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) list_predicted = self._create_valid_list([]) @@ -397,7 +397,7 @@ def test_segmented_stride_list_empty_prediction(self): assert_array_equal(matches["fn"]["s_id"].to_numpy().astype(float), [np.nan, np.nan, np.nan, np.nan]) assert len(list_ground_truth) == (len(matches["tp"]) + len(matches["fn"])) - def test_segmented_stride_list_match(self): + def test_segmented_stride_list_match(self) -> None: list_ground_truth = self._create_valid_list([[20, 30], [30, 40], [40, 50], [50, 60]]) list_predicted = self._create_valid_list([[0, 10], [11, 19], [19, 30], [30, 41], [70, 80], [80, 90]]) @@ -420,7 +420,7 @@ def test_segmented_stride_list_match(self): assert len(list_ground_truth) == (len(matches["tp"]) + len(matches["fn"])) - def test_segmented_stride_list_no_match(self): + def test_segmented_stride_list_no_match(self) -> None: list_ground_truth = self._create_valid_list([[20, 30], [30, 40], [40, 50]]) list_predicted = self._create_valid_list([[60, 70], [70, 80], [90, 100]]) @@ -442,7 +442,7 @@ def test_segmented_stride_list_no_match(self): assert len(list_ground_truth) == (len(matches["tp"]) + len(matches["fn"])) - def test_segmented_stride_list_double_match_predicted_many_to_one(self): + def test_segmented_stride_list_double_match_predicted_many_to_one(self) -> None: list_ground_truth = self._create_valid_list([[20, 30]]) list_predicted = self._create_valid_list([[18, 30], [20, 28]]) @@ -460,7 +460,7 @@ def test_segmented_stride_list_double_match_predicted_many_to_one(self): assert len(list_ground_truth) != (len(matches["tp"]) + len(matches["fn"])) - def test_segmented_stride_list_double_match_predicted_one_to_one(self): + def test_segmented_stride_list_double_match_predicted_one_to_one(self) -> None: list_ground_truth = self._create_valid_list([[20, 30]]) list_predicted = self._create_valid_list([[18, 30], [20, 28]]) @@ -480,7 +480,7 @@ def test_segmented_stride_list_double_match_predicted_one_to_one(self): assert len(list_ground_truth) == (len(matches["tp"]) + len(matches["fn"])) - def test_segmented_multi_stride_list_perfect_match(self): + def test_segmented_multi_stride_list_perfect_match(self) -> None: list_ground_truth_left = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) list_predicted_left = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) @@ -495,7 +495,7 @@ def test_segmented_multi_stride_list_perfect_match(self): assert np.all(matches["left_sensor"]["match_type"] == "tp") assert np.all(matches["right_sensor"]["match_type"] == "tp") - def test_segmented_multi_stride_list_empty_ground_truth(self): + def test_segmented_multi_stride_list_empty_ground_truth(self) -> None: list_ground_truth_left = self._create_valid_list([]) list_predicted_left = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) @@ -529,7 +529,7 @@ def test_segmented_multi_stride_list_empty_ground_truth(self): ) assert len(list_ground_truth_right) == (len(matches["right_sensor"]["tp"]) + len(matches["right_sensor"]["fn"])) - def test_segmented_multi_stride_list_empty_prediction(self): + def test_segmented_multi_stride_list_empty_prediction(self) -> 
None: list_predicted_left = self._create_valid_list([]) list_ground_truth_left = self._create_valid_list([[0, 1], [1, 2], [2, 3], [3, 4]]) @@ -561,7 +561,7 @@ def test_segmented_multi_stride_list_empty_prediction(self): ) assert len(list_ground_truth_right) == (len(matches["right_sensor"]["tp"]) + len(matches["right_sensor"]["fn"])) - def test_segmented_multi_stride_list_match(self): + def test_segmented_multi_stride_list_match(self) -> None: list_ground_truth = self._create_valid_list([[20, 30], [30, 40], [40, 50], [50, 60]]) list_predicted = self._create_valid_list([[0, 10], [11, 19], [19, 30], [30, 41], [70, 80], [80, 90]]) @@ -585,7 +585,7 @@ def test_segmented_multi_stride_list_match(self): assert len(list_ground_truth) == (len(matches["left"]["tp"]) + len(matches["left"]["fn"])) - def test_segmented_multi_stride_list_no_match(self): + def test_segmented_multi_stride_list_no_match(self) -> None: list_ground_truth = self._create_valid_list([[20, 30], [30, 40], [40, 50]]) list_predicted = self._create_valid_list([[60, 70], [70, 80], [90, 100]]) @@ -607,7 +607,7 @@ def test_segmented_multi_stride_list_no_match(self): assert len(list_ground_truth) == (len(matches["left"]["tp"]) + len(matches["left"]["fn"])) - def test_segmented_multi_stride_list_double_match_predicted_many_to_one(self): + def test_segmented_multi_stride_list_double_match_predicted_many_to_one(self) -> None: list_ground_truth = self._create_valid_list([[20, 30]]) list_predicted = self._create_valid_list([[18, 30], [20, 28]]) @@ -628,7 +628,7 @@ def test_segmented_multi_stride_list_double_match_predicted_many_to_one(self): assert len(list_ground_truth) != (len(matches["left"]["tp"]) + len(matches["left"]["fn"])) - def test_segmented_multi_stride_list_double_match_predicted_one_to_one(self): + def test_segmented_multi_stride_list_double_match_predicted_one_to_one(self) -> None: list_ground_truth = self._create_valid_list([[20, 30]]) list_predicted = self._create_valid_list([[18, 30], [20, 28]]) @@ -651,7 +651,7 @@ def test_segmented_multi_stride_list_double_match_predicted_one_to_one(self): assert len(list_ground_truth) == (len(matches["left"]["tp"]) + len(matches["left"]["fn"])) - def test_one_multi_one_single_list(self): + def test_one_multi_one_single_list(self) -> None: multi = {"sensor": self._create_valid_list([[0, 1], [2, 3], [4, 5], [6, 7]])} single = self._create_valid_list([[1, 2], [3, 4], [5, 6]]) diff --git a/tests/test_event_detection/test_event_detection_filtered_rampp.py b/tests/test_event_detection/test_event_detection_filtered_rampp.py index d8d2b5e7..ee25ead1 100644 --- a/tests/test_event_detection/test_event_detection_filtered_rampp.py +++ b/tests/test_event_detection/test_event_detection_filtered_rampp.py @@ -36,7 +36,9 @@ class TestCachingFunctionality(MetaTestConfig, TestCachingMixin): class TestEventDetectionRamppFiltered(TestEventDetectionRampp): algorithm_class = FilteredRamppEventDetection - def test_is_identical_to_normal_rampp(self, healthy_example_imu_data, healthy_example_stride_borders, snapshot): + def test_is_identical_to_normal_rampp( + self, healthy_example_imu_data, healthy_example_stride_borders, snapshot + ) -> None: """Test if the output is the same as normal Rampp for lax filter parameters.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -51,7 +53,9 @@ def test_is_identical_to_normal_rampp(self, healthy_example_imu_data, healthy_ex assert_frame_equal(ed.segmented_event_list_[sensor], 
rampp_ed.segmented_event_list_[sensor]) @pytest.mark.parametrize("filter_paras", [(3, 5), (2, 10)]) - def test_correct_arguments_are_passed(self, healthy_example_imu_data, healthy_example_stride_borders, filter_paras): + def test_correct_arguments_are_passed( + self, healthy_example_imu_data, healthy_example_stride_borders, filter_paras + ) -> None: data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] ) diff --git a/tests/test_event_detection/test_event_detection_herzer.py b/tests/test_event_detection/test_event_detection_herzer.py index 897aa811..b9c368c6 100644 --- a/tests/test_event_detection/test_event_detection_herzer.py +++ b/tests/test_event_detection/test_event_detection_herzer.py @@ -41,7 +41,7 @@ class TestCachingFunctionality(MetaTestConfig, TestCachingMixin): class TestEventDetectionHerzer: """Test the event detection by Herzer.""" - def test_multi_sensor_input(self, healthy_example_imu_data, healthy_example_stride_borders, snapshot): + def test_multi_sensor_input(self, healthy_example_imu_data, healthy_example_stride_borders, snapshot) -> None: """Dummy test to see if the algorithm is generally working on the example data.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -55,8 +55,8 @@ def test_multi_sensor_input(self, healthy_example_imu_data, healthy_example_stri snapshot.assert_match(ed.segmented_event_list_["left_sensor"], "left_segmented", check_dtype=False) snapshot.assert_match(ed.segmented_event_list_["right_sensor"], "right_segmented", check_dtype=False) - @pytest.mark.parametrize(("var1", "output"), ((True, 2), (False, 0))) - def test_postprocessing(self, healthy_example_imu_data, healthy_example_stride_borders, var1, output): + @pytest.mark.parametrize(("var1", "output"), [(True, 2), (False, 0)]) + def test_postprocessing(self, healthy_example_imu_data, healthy_example_stride_borders, var1, output) -> None: data_left = healthy_example_imu_data["left_sensor"] data_left.columns = BF_COLS # only use the first entry of the stride list @@ -74,10 +74,10 @@ def mock_func(event_list, *args, **kwargs): assert mock.call_count == output - @pytest.mark.parametrize(("enforce_consistency", "output"), ((False, False), (True, True))) + @pytest.mark.parametrize(("enforce_consistency", "output"), [(False, False), (True, True)]) def test_disable_min_vel_event_list( self, healthy_example_imu_data, healthy_example_stride_borders, enforce_consistency, output - ): + ) -> None: data_left = healthy_example_imu_data["left_sensor"] data_left.columns = BF_COLS # only use the first entry of the stride list @@ -88,7 +88,7 @@ def test_disable_min_vel_event_list( assert hasattr(ed, "min_vel_event_list_") == output - def test_multi_sensor_input_dict(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_multi_sensor_input_dict(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test to see if the algorithm is generally working on the example data when provided as dict.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -107,7 +107,7 @@ def test_multi_sensor_input_dict(self, healthy_example_imu_data, healthy_example assert list(datatype_helper.get_multi_sensor_names(ed.min_vel_event_list_)) == dict_keys assert list(datatype_helper.get_multi_sensor_names(ed.segmented_event_list_)) == dict_keys - def test_equal_output_dict_df(self, 
healthy_example_imu_data, healthy_example_stride_borders): + def test_equal_output_dict_df(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test if output is similar for input dicts or regular multisensor data sets.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -129,7 +129,7 @@ def test_equal_output_dict_df(self, healthy_example_imu_data, healthy_example_st assert_frame_equal(ed_df.min_vel_event_list_["left_sensor"], ed_dict.min_vel_event_list_["l"]) assert_frame_equal(ed_df.min_vel_event_list_["right_sensor"], ed_dict.min_vel_event_list_["r"]) - def test_valid_input_data(self, healthy_example_stride_borders): + def test_valid_input_data(self, healthy_example_stride_borders) -> None: """Test if error is raised correctly on invalid input data type.""" data = pd.DataFrame({"a": [0, 1, 2], "b": [3, 4, 5]}) ed = HerzerEventDetection() @@ -138,7 +138,7 @@ def test_valid_input_data(self, healthy_example_stride_borders): assert "The passed object appears to be neither single- or multi-sensor data" in str(e) - def test_valid_min_vel_search_win_size_ms(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_valid_min_vel_search_win_size_ms(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test if error is raised correctly on too large min_vel_search_win_size_ms.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) @@ -147,7 +147,7 @@ def test_valid_min_vel_search_win_size_ms(self, healthy_example_imu_data, health with pytest.raises(ValueError, match=r"min_vel_search_win_size_ms is*"): ed.detect(data_left, stride_list_left, sampling_rate_hz=204.8) - def test_input_stride_list_size_one(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_input_stride_list_size_one(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test if gait event detection also works with stride list of length 1.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) @@ -160,7 +160,7 @@ def test_input_stride_list_size_one(self, healthy_example_imu_data, healthy_exam # per default segmented_event_list_ has 5 columns assert_array_equal(np.array(ed.segmented_event_list_.shape[1]), 5) - def test_correct_s_id(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_correct_s_id(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test if the s_id from the stride list is correctly transferred to the output of event detection.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) @@ -182,7 +182,7 @@ def test_correct_s_id(self, healthy_example_imu_data, healthy_example_stride_bor assert np.all(combined["start_x"] == combined["start_y"]) assert np.all(combined["end_x"] == combined["end_y"]) - def test_single_data_multi_stride_list(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_single_data_multi_stride_list(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test correct error for combination of single sensor data set and multi sensor stride list.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) @@ -191,7 +191,7 @@ def test_single_data_multi_stride_list(self, 
healthy_example_imu_data, healthy_e with pytest.raises(ValidationError): ed.detect(data_left, stride_list_left, sampling_rate_hz=204.8) - def test_multi_data_single_stride_list(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_multi_data_single_stride_list(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test correct error for combination of multi sensor data set and single sensor stride list.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) @@ -210,7 +210,7 @@ def test_multi_data_single_stride_list(self, healthy_example_imu_data, healthy_e ("ic", "tc"), ], ) - def test_detect_only(self, detect_only, healthy_example_imu_data, healthy_example_stride_borders): + def test_detect_only(self, detect_only, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test if only the specified events are detected.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) diff --git a/tests/test_event_detection/test_event_detection_rampp.py b/tests/test_event_detection/test_event_detection_rampp.py index 037c1b14..e49b6012 100644 --- a/tests/test_event_detection/test_event_detection_rampp.py +++ b/tests/test_event_detection/test_event_detection_rampp.py @@ -45,7 +45,7 @@ class TestEventDetectionRampp: algorithm_class = RamppEventDetection - def test_multi_sensor_input(self, healthy_example_imu_data, healthy_example_stride_borders, snapshot): + def test_multi_sensor_input(self, healthy_example_imu_data, healthy_example_stride_borders, snapshot) -> None: """Dummy test to see if the algorithm is generally working on the example data.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -59,8 +59,8 @@ def test_multi_sensor_input(self, healthy_example_imu_data, healthy_example_stri snapshot.assert_match(ed.segmented_event_list_["left_sensor"], "left_segmented", check_dtype=False) snapshot.assert_match(ed.segmented_event_list_["right_sensor"], "right_segmented", check_dtype=False) - @pytest.mark.parametrize(("var1", "output"), ((True, 2), (False, 0))) - def test_postprocessing(self, healthy_example_imu_data, healthy_example_stride_borders, var1, output): + @pytest.mark.parametrize(("var1", "output"), [(True, 2), (False, 0)]) + def test_postprocessing(self, healthy_example_imu_data, healthy_example_stride_borders, var1, output) -> None: data_left = healthy_example_imu_data["left_sensor"] data_left.columns = BF_COLS # only use the first entry of the stride list @@ -78,10 +78,10 @@ def mock_func(event_list, *args, **kwargs): assert mock.call_count == output - @pytest.mark.parametrize(("enforce_consistency", "output"), ((False, False), (True, True))) + @pytest.mark.parametrize(("enforce_consistency", "output"), [(False, False), (True, True)]) def test_disable_min_vel_event_list( self, healthy_example_imu_data, healthy_example_stride_borders, enforce_consistency, output - ): + ) -> None: data_left = healthy_example_imu_data["left_sensor"] data_left.columns = BF_COLS # only use the first entry of the stride list @@ -92,7 +92,7 @@ def test_disable_min_vel_event_list( assert hasattr(ed, "min_vel_event_list_") == output - def test_multi_sensor_input_dict(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_multi_sensor_input_dict(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test to see if 
the algorithm is generally working on the example data when provided as dict.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -111,7 +111,7 @@ def test_multi_sensor_input_dict(self, healthy_example_imu_data, healthy_example assert list(datatype_helper.get_multi_sensor_names(ed.min_vel_event_list_)) == dict_keys assert list(datatype_helper.get_multi_sensor_names(ed.segmented_event_list_)) == dict_keys - def test_equal_output_dict_df(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_equal_output_dict_df(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test if output is similar for input dicts or regular multisensor data sets.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -133,7 +133,7 @@ def test_equal_output_dict_df(self, healthy_example_imu_data, healthy_example_st assert_frame_equal(ed_df.min_vel_event_list_["left_sensor"], ed_dict.min_vel_event_list_["l"]) assert_frame_equal(ed_df.min_vel_event_list_["right_sensor"], ed_dict.min_vel_event_list_["r"]) - def test_valid_input_data(self, healthy_example_stride_borders): + def test_valid_input_data(self, healthy_example_stride_borders) -> None: """Test if error is raised correctly on invalid input data type.""" data = pd.DataFrame({"a": [0, 1, 2], "b": [3, 4, 5]}) ed = self.algorithm_class() @@ -142,13 +142,13 @@ def test_valid_input_data(self, healthy_example_stride_borders): assert "The passed object appears to be neither single- or multi-sensor data" in str(e) - def test_min_vel_search_win_size_ms_dummy_data(self): + def test_min_vel_search_win_size_ms_dummy_data(self) -> None: """Test if error is raised correctly if windows size matches the size of the input data.""" dummy_gyr = np.ones((100, 3)) with pytest.raises(ValueError, match=r"min_vel_search_win_size_ms is*"): _detect_min_vel_gyr_energy(dummy_gyr, dummy_gyr.size) - def test_valid_min_vel_search_win_size_ms(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_valid_min_vel_search_win_size_ms(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test if error is raised correctly on too large min_vel_search_win_size_ms.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) @@ -157,7 +157,7 @@ def test_valid_min_vel_search_win_size_ms(self, healthy_example_imu_data, health with pytest.raises(ValueError, match=r"min_vel_search_win_size_ms is *"): ed.detect(data_left, stride_list_left, sampling_rate_hz=204.8) - def test_valid_ic_search_region_ms(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_valid_ic_search_region_ms(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test if error is raised correctly on too small ic_search_region_ms.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) @@ -166,7 +166,7 @@ def test_valid_ic_search_region_ms(self, healthy_example_imu_data, healthy_examp with pytest.raises(ValueError): ed.detect(data_left, stride_list_left, sampling_rate_hz=204.8) - def test_input_stride_list_size_one(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_input_stride_list_size_one(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test if gait event detection also works with stride list of 
length 1.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) @@ -179,7 +179,7 @@ def test_input_stride_list_size_one(self, healthy_example_imu_data, healthy_exam # per default segmented_event_list_ has 5 columns assert_array_equal(np.array(ed.segmented_event_list_.shape[1]), 5) - def test_correct_s_id(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_correct_s_id(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test if the s_id from the stride list is correctly transferred to the output of event detection.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) @@ -201,7 +201,7 @@ def test_correct_s_id(self, healthy_example_imu_data, healthy_example_stride_bor assert np.all(combined["start_x"] == combined["start_y"]) assert np.all(combined["end_x"] == combined["end_y"]) - def test_single_data_multi_stride_list(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_single_data_multi_stride_list(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test correct error for combination of single sensor data set and multi sensor stride list.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) @@ -210,7 +210,7 @@ def test_single_data_multi_stride_list(self, healthy_example_imu_data, healthy_e with pytest.raises(ValidationError): ed.detect(data_left, stride_list_left, sampling_rate_hz=204.8) - def test_multi_data_single_stride_list(self, healthy_example_imu_data, healthy_example_stride_borders): + def test_multi_data_single_stride_list(self, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test correct error for combination of multi sensor data set and single sensor stride list.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) @@ -219,7 +219,7 @@ def test_multi_data_single_stride_list(self, healthy_example_imu_data, healthy_e with pytest.raises(ValidationError): ed.detect(data_left, stride_list_left, sampling_rate_hz=204.8) - def test_sign_change_for_detect_tc(self): + def test_sign_change_for_detect_tc(self) -> None: """Test correct handling of signal that does or does not provide a change of the sign.""" # with sign change signal1 = np.concatenate([np.ones(10), np.ones(10) * -1]) @@ -239,7 +239,7 @@ def test_sign_change_for_detect_tc(self): ("ic", "tc"), ], ) - def test_detect_only(self, detect_only, healthy_example_imu_data, healthy_example_stride_borders): + def test_detect_only(self, detect_only, healthy_example_imu_data, healthy_example_stride_borders) -> None: """Test if only the specified events are detected.""" data_left = healthy_example_imu_data["left_sensor"] data_left = coordinate_conversion.convert_left_foot_to_fbf(data_left) diff --git a/tests/test_examples/test_all_examples.py b/tests/test_examples/test_all_examples.py index a98ac430..65dac807 100644 --- a/tests/test_examples/test_all_examples.py +++ b/tests/test_examples/test_all_examples.py @@ -11,21 +11,21 @@ matplotlib.use("Agg") -def test_base_dtw_generic(snapshot): +def test_base_dtw_generic(snapshot) -> None: from examples.generic_algorithms.base_dtw_generic import dtw assert len(dtw.matches_start_end_) == 5 snapshot.assert_match(dtw.matches_start_end_) -def test_barth_dtw_example(snapshot): +def 
test_barth_dtw_example(snapshot) -> None: from examples.stride_segmentation.barth_dtw_stride_segmentation import dtw assert len(dtw.matches_start_end_["left_sensor"]) == 28 snapshot.assert_match(dtw.matches_start_end_["left_sensor"]) -def test_barth_dtw_custom_template(snapshot): +def test_barth_dtw_custom_template(snapshot) -> None: from examples.stride_segmentation.barth_dtw_custom_template import dtw snapshot.assert_match(dtw.template.get_data()[:10]) @@ -33,7 +33,7 @@ def test_barth_dtw_custom_template(snapshot): assert dtw.template.scaling.data_max == 524.7659568483249 -def test_constrained_barth_dtw_example(snapshot): +def test_constrained_barth_dtw_example(snapshot) -> None: from examples.stride_segmentation.constrained_barth_dtw_stride_segmentation import cdtw, default_cdtw, dtw assert len(dtw.matches_start_end_["left_sensor"]) == 74 @@ -46,13 +46,13 @@ def test_constrained_barth_dtw_example(snapshot): snapshot.assert_match(default_cdtw.matches_start_end_["left_sensor"], "default_cdtw") -def test_roi(snapshot): +def test_roi(snapshot) -> None: from examples.stride_segmentation.barth_dtw_stride_segmentation_roi import roi_seg snapshot.assert_match(roi_seg.stride_list_["left_sensor"]) -def test_preprocessing_example(snapshot): +def test_preprocessing_example(snapshot) -> None: from examples.preprocessing.manual_sensor_alignment import dataset_sf_aligned_to_gravity desired_acc_vec = np.array([0.0, 0.0, 9.81]) @@ -69,7 +69,7 @@ def test_preprocessing_example(snapshot): snapshot.assert_match(dataset_sf_aligned_to_gravity.to_numpy()[:1000]) -def test_sensor_alignment_detailed_example(snapshot): +def test_sensor_alignment_detailed_example(snapshot) -> None: from examples.preprocessing.automatic_sensor_alignment_details import ( forward_aligned_data, gravity_aligned_data, @@ -93,7 +93,7 @@ def test_sensor_alignment_detailed_example(snapshot): snapshot.assert_match(forward_aligned_data[sensor].to_numpy()[:1000]) -def test_sensor_alignment_detailed_simple(snapshot): +def test_sensor_alignment_detailed_simple(snapshot) -> None: from examples.preprocessing.automatic_sensor_alignment_details import ( forward_aligned_data, gravity_aligned_data, @@ -107,19 +107,19 @@ def test_sensor_alignment_detailed_simple(snapshot): snapshot.assert_match(forward_aligned_data[sensor].to_numpy()[:1000]) -def test_temporal_parameters(snapshot): +def test_temporal_parameters(snapshot) -> None: from examples.parameters.temporal_parameters import p snapshot.assert_match(p.parameters_["left_sensor"]) -def test_spatial_parameters(snapshot): +def test_spatial_parameters(snapshot) -> None: from examples.parameters.spatial_parameters import p snapshot.assert_match(p.parameters_["left_sensor"]) -def test_rampp_event_detection(snapshot): +def test_rampp_event_detection(snapshot) -> None: from examples.event_detection.rampp_event_detection import ed, edfilt assert len(ed.min_vel_event_list_["left_sensor"]) == 26 @@ -132,7 +132,7 @@ def test_rampp_event_detection(snapshot): assert_frame_equal(ed.min_vel_event_list_["right_sensor"], edfilt.min_vel_event_list_["right_sensor"]) -def test_herzer_event_detection(snapshot): +def test_herzer_event_detection(snapshot) -> None: from examples.event_detection.herzer_event_detection import ed assert len(ed.min_vel_event_list_["left_sensor"]) == 26 @@ -141,7 +141,7 @@ def test_herzer_event_detection(snapshot): snapshot.assert_match(ed.min_vel_event_list_["right_sensor"]) -def test_json_example(snapshot): +def test_json_example(snapshot) -> None: from 
examples.advanced_features.algo_serialize import json_str, loaded_slt, slt snapshot.assert_match(json_str) @@ -149,7 +149,7 @@ compare_algo_objects(slt, loaded_slt) -def test_trajectory_reconstruction(snapshot): +def test_trajectory_reconstruction(snapshot) -> None: from examples.trajectory_reconstruction.trajectory_reconstruction import trajectory # just look at the last values to see if the final result is correct and to save runtime @@ -157,7 +157,7 @@ snapshot.assert_match(trajectory.orientation_["left_sensor"].tail(20)) -def test_region_trajectory_reconstruction(snapshot): +def test_region_trajectory_reconstruction(snapshot) -> None: from examples.trajectory_reconstruction.trajectory_reconstruction_region import ( trajectory_full, trajectory_per_stride, @@ -171,7 +171,7 @@ snapshot.assert_match(trajectory_per_stride.orientation_["left_sensor"].loc[4].tail(20)) -def test_mad_pipeline(snapshot): +def test_mad_pipeline(snapshot) -> None: from examples.full_pipelines.mad_gait_pipeline import ed, spatial_paras, temporal_paras snapshot.assert_match(ed.min_vel_event_list_["left_sensor"], "strides_left") @@ -185,14 +185,14 @@ snapshot.assert_match(temporal_paras.parameters_pretty_["left_sensor"], "temporal_paras_left") -def test_ullrich_gait_sequence_detection(snapshot): +def test_ullrich_gait_sequence_detection(snapshot) -> None: from examples.gait_detection.ullrich_gait_sequence_detection import gsd assert len(gsd.gait_sequences_) == 2 snapshot.assert_match(gsd.gait_sequences_.astype(np.int64)) -def test_caching(snapshot): +def test_caching(snapshot) -> None: from examples.advanced_features.caching import first_call_results, second_call_results # We will not store the actual outputs, but just check if they are actually identical @@ -200,20 +200,20 @@ assert_frame_equal(s_list, second_call_results.stride_list_[sensor]) -def test_custom_dataset(): +def test_custom_dataset() -> None: # There is not really anything specific we want to test here, so we just run everything and check that there are # no errors.
import examples.datasets_and_pipelines.custom_dataset # noqa -def test_grid_search(snapshot): +def test_grid_search(snapshot) -> None: from examples.datasets_and_pipelines.gridsearch import results, segmented_stride_list snapshot.assert_match(segmented_stride_list, check_dtype=False) snapshot.assert_match(pd.DataFrame(results).drop("score_time", axis=1), check_dtype=False) -def test_optimizable_pipelines(snapshot): +def test_optimizable_pipelines(snapshot) -> None: from examples.datasets_and_pipelines.optimizable_pipelines import optimized_results, results snapshot.assert_match(results.segmented_stride_list_, check_dtype=False) @@ -221,14 +221,14 @@ snapshot.assert_match(optimized_results.template.get_data()) -def test_cross_validation(snapshot): +def test_cross_validation(snapshot) -> None: from examples.datasets_and_pipelines.cross_validation import result_df result_df = result_df.drop(["score_time", "optimize_time", "optimizer"], axis=1) snapshot.assert_match(result_df, check_dtype=False) -def test_gridsearch_cv(snapshot): +def test_gridsearch_cv(snapshot) -> None: from examples.datasets_and_pipelines.gridsearch_cv import cached_results, results_df ignore_cols = ["mean_score_time", "mean_optimize_time", "std_optimize_time", "std_score_time"] @@ -241,7 +241,7 @@ snapshot.assert_match(results_df, check_dtype=False) -def test_multi_process(): +def test_multi_process() -> None: """Test the multiprocess example. We do not test the multi process example. @@ -249,7 +249,7 @@ """ -def test_roth_hmm_stride_segmentation(snapshot): +def test_roth_hmm_stride_segmentation(snapshot) -> None: import pytest pytest.importorskip("pomegranate") @@ -259,7 +259,7 @@ snapshot.assert_match(hmm_seg.stride_list_["right_sensor"], "right_sensor") -def test_segmentation_hmm_training(snapshot): +def test_segmentation_hmm_training(snapshot) -> None: import pytest pytest.importorskip("pomegranate") @@ -272,14 +272,14 @@ snapshot.assert_match(hmm.stride_list_["right_sensor"], "right_sensor") -def test_zupt_dependency(): +def test_zupt_dependency() -> None: from examples.trajectory_reconstruction.zupt_dependency import gs assert_almost_equal(gs.best_score_, 0.10171051541126) assert_almost_equal(gs.best_params_["zupt_method__inactive_signal_threshold"], 2782559402.207125, decimal=3) -def test_advanced_kalman(snapshot): +def test_advanced_kalman(snapshot) -> None: from examples.trajectory_reconstruction.advanced_kalman_filter_usage import combo_zupt, madgwick_rts_no_zupt # We test some of the results of methods that don't really have regression tests anywhere else.
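The edits in this file repeat the two mechanical patterns applied across the whole test suite: every test function gains an explicit `-> None` return annotation, and `pytest.mark.parametrize` argument values become lists instead of tuples. A minimal sketch of the target style follows; the test itself is purely illustrative, and attributing the rewrites to ruff's pytest-style and annotation rules (e.g. PT007 and ANN201) is an assumption based on the `.ruff.toml` added by this patch:

import pytest


@pytest.mark.parametrize("window_size_s", [1, 5, 10])  # argvalues as a list, not a tuple (assumed rule: PT007)
def test_window_size_is_positive(window_size_s) -> None:  # tests always return None (assumed rule: ANN201)
    assert window_size_s > 0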
diff --git a/tests/test_gait_detection/test_ullrich_gait_sequence_detection.py b/tests/test_gait_detection/test_ullrich_gait_sequence_detection.py index 1c188d8e..57f837e8 100644 --- a/tests/test_gait_detection/test_ullrich_gait_sequence_detection.py +++ b/tests/test_gait_detection/test_ullrich_gait_sequence_detection.py @@ -32,7 +32,7 @@ class TestMetaFunctionality(MetaTestConfig, TestAlgorithmMixin): class TestUllrichGaitSequenceDetection: """Test the gait sequence detection by Ullrich.""" - def test_single_sensor_input(self, healthy_example_imu_data, snapshot): + def test_single_sensor_input(self, healthy_example_imu_data, snapshot) -> None: """Dummy test to see if the algorithm is generally working on the example data.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -51,7 +51,7 @@ def test_single_sensor_input(self, healthy_example_imu_data, snapshot): assert isinstance(gsd.start_, np.ndarray) assert isinstance(gsd.end_, np.ndarray) - def test_multi_sensor_input(self, healthy_example_imu_data, snapshot): + def test_multi_sensor_input(self, healthy_example_imu_data, snapshot) -> None: """Dummy test to see if the algorithm is generally working on the example data.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -72,7 +72,7 @@ def test_multi_sensor_input(self, healthy_example_imu_data, snapshot): @pytest.mark.parametrize( ("sensor_channel_config", "peak_prominence", "merge_gait_sequences_from_sensors"), - ( + [ ("gyr_ml", 17, False), ("gyr_ml", 17, True), ("acc_si", 8, False), @@ -81,7 +81,7 @@ def test_multi_sensor_input(self, healthy_example_imu_data, snapshot): ("acc", 13, True), ("gyr", 11, False), ("gyr", 11, True), - ), + ], ) def test_different_activities_different_configs( self, @@ -90,7 +90,7 @@ def test_different_activities_different_configs( peak_prominence, merge_gait_sequences_from_sensors, snapshot, - ): + ) -> None: """Test if the algorithm is generally working with different sensor channel configs and their respective optimal peak prominence thresholds. 
""" @@ -126,7 +126,7 @@ def test_different_activities_different_configs( assert all(gsd.end_["left_sensor"] == gsd.gait_sequences_["left_sensor"]["end"]) snapshot.assert_match(gsd.gait_sequences_["left_sensor"], check_dtype=False) - def test_signal_length_one_window_size(self, healthy_example_imu_data, snapshot): + def test_signal_length_one_window_size(self, healthy_example_imu_data, snapshot) -> None: """Test to see if the algorithm is working if the signal length equals to one window size.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -145,7 +145,7 @@ def test_signal_length_one_window_size(self, healthy_example_imu_data, snapshot) assert len(gsd.start_) == 1 assert len(gsd.end_) == 1 - def test_on_signal_without_activity(self, snapshot): + def test_on_signal_without_activity(self, snapshot) -> None: """Test to see if the algorithm is working if the signal contains no activity at all.""" data_columns = BF_COLS @@ -160,7 +160,7 @@ def test_on_signal_without_activity(self, snapshot): assert len(gsd.start_) == 0 assert len(gsd.end_) == 0 - def test_on_signal_with_only_nongait(self, snapshot): + def test_on_signal_with_only_nongait(self, snapshot) -> None: """Test to see if the algorithm is working if the signal contains only non-gait activity.""" data_columns = BF_COLS @@ -182,7 +182,7 @@ def test_on_signal_with_only_nongait(self, snapshot): assert len(gsd.start_) == 0 assert len(gsd.end_) == 0 - def test_invalid_sensor_channel_config_type(self, healthy_example_imu_data): + def test_invalid_sensor_channel_config_type(self, healthy_example_imu_data) -> None: """Check if ValueError is raised for wrong sensor_channel_config data type.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -194,7 +194,7 @@ def test_invalid_sensor_channel_config_type(self, healthy_example_imu_data): gsd.detect(data, 204.8) @pytest.mark.parametrize("sensor_channel_config", "dummy") - def test_invalid_sensor_channel_config_value(self, healthy_example_imu_data, sensor_channel_config): + def test_invalid_sensor_channel_config_value(self, healthy_example_imu_data, sensor_channel_config) -> None: """Check if ValueError is raised for wrong sensor_channel_config data type.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -204,7 +204,7 @@ def test_invalid_sensor_channel_config_value(self, healthy_example_imu_data, sen gsd = UllrichGaitSequenceDetection(sensor_channel_config=sensor_channel_config) gsd.detect(data, 204.8) - def test_invalid_window_size(self, healthy_example_imu_data): + def test_invalid_window_size(self, healthy_example_imu_data) -> None: """Check if ValueError is raised for window size higher than len of signal.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -217,8 +217,8 @@ def test_invalid_window_size(self, healthy_example_imu_data): gsd = UllrichGaitSequenceDetection(window_size_s=window_size_s) gsd.detect(data, 204.8) - @pytest.mark.parametrize("locomotion_band", ([1], (0, 1, 2))) - def test_invalid_locomotion_band_size(self, healthy_example_imu_data, locomotion_band): + @pytest.mark.parametrize("locomotion_band", [[1], (0, 1, 2)]) + def test_invalid_locomotion_band_size(self, healthy_example_imu_data, locomotion_band) -> None: """Check if ValueError is raised for locomotion band with other than two values.""" data 
= coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -228,8 +228,8 @@ def test_invalid_locomotion_band_size(self, healthy_example_imu_data, locomotion gsd1 = UllrichGaitSequenceDetection(locomotion_band=locomotion_band) gsd1.detect(data, 204.8) - @pytest.mark.parametrize("locomotion_band", ((3, 0.5), (0.5, 0.5))) - def test_invalid_locomotion_value_order(self, healthy_example_imu_data, locomotion_band): + @pytest.mark.parametrize("locomotion_band", [(3, 0.5), (0.5, 0.5)]) + def test_invalid_locomotion_value_order(self, healthy_example_imu_data, locomotion_band) -> None: """Check if ValueError is raised for locomotion band where the second value is smaller than or equal to the first.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -239,7 +239,7 @@ def test_invalid_locomotion_value_order(self, healthy_example_imu_data, locomoti gsd = UllrichGaitSequenceDetection(locomotion_band=locomotion_band) gsd.detect(data, 204.8) - def test_invalid_locomotion_upper_value(self, healthy_example_imu_data): + def test_invalid_locomotion_upper_value(self, healthy_example_imu_data) -> None: """Check if ValueError is raised for locomotion band where the upper limit is too close to Nyquist freq.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -250,8 +250,8 @@ def test_invalid_locomotion_upper_value(self, healthy_example_imu_data): gsd = UllrichGaitSequenceDetection(locomotion_band=locomotion_band) gsd.detect(data, 204.8) - @pytest.mark.parametrize("harmonic_tolerance_hz", (-3, 0)) - def test_invalid_harmonic_tolerance(self, healthy_example_imu_data, harmonic_tolerance_hz): + @pytest.mark.parametrize("harmonic_tolerance_hz", [-3, 0]) + def test_invalid_harmonic_tolerance(self, healthy_example_imu_data, harmonic_tolerance_hz) -> None: """Check if ValueError is raised for a too small harmonic tolerance in Hz.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -261,7 +261,7 @@ def test_invalid_harmonic_tolerance(self, healthy_example_imu_tol gsd_1 = UllrichGaitSequenceDetection(harmonic_tolerance_hz=harmonic_tolerance_hz) gsd_1.detect(data, 204.8) - def test_invalid_merging_gait_sequences(self, healthy_example_imu_data): + def test_invalid_merging_gait_sequences(self, healthy_example_imu_data) -> None: """Check if data and value for merge_gait_sequences_from_sensors fit to each other. Only gait sequences detected from synced data can be merged.
""" @@ -276,7 +276,7 @@ def test_invalid_merging_gait_sequences(self, healthy_example_imu_data): gsd = UllrichGaitSequenceDetection(merge_gait_sequences_from_sensors=merge_gait_sequences_from_sensors) gsd.detect(data_dict, 204.8) - def test_merging_gait_sequences(self, healthy_example_imu_data): + def test_merging_gait_sequences(self, healthy_example_imu_data) -> None: """Check if merging of gait sequences works for synchronized data.""" data = coordinate_conversion.convert_to_fbf( healthy_example_imu_data, left=["left_sensor"], right=["right_sensor"] @@ -298,7 +298,7 @@ def test_merging_gait_sequences(self, healthy_example_imu_data): assert_frame_equal(gsd_merged.gait_sequences_["left_sensor"], gsd_merged.gait_sequences_["right_sensor"]) - def test_merging_for_no_activity(self): + def test_merging_for_no_activity(self) -> None: """Test to see if the merging is working if the signal contains no activity at all.""" data_columns = BF_COLS @@ -318,7 +318,7 @@ def test_merging_for_no_activity(self): for sensor in ["left_sensor", "right_sensor"]: assert gsd.gait_sequences_[sensor].empty - def test_merging_on_signal_with_only_nongait(self): + def test_merging_on_signal_with_only_nongait(self) -> None: """Test to see if the merging is working if the signal contains only non-gait activity.""" data_columns = BF_COLS @@ -345,7 +345,7 @@ def test_merging_on_signal_with_only_nongait(self): for sensor in ["left_sensor", "right_sensor"]: assert gsd.gait_sequences_[sensor].empty - def test_merge_gait_sequences_multi_sensor_data(self): + def test_merge_gait_sequences_multi_sensor_data(self) -> None: """Test the pure functionality of the merging with dummy data.""" data_columns = BF_COLS @@ -368,7 +368,7 @@ def test_merge_gait_sequences_multi_sensor_data(self): for sensor in ["left_sensor", "right_sensor"]: assert_frame_equal(out[sensor], expected_merged, check_dtype=False) - def test_gait_sequence_concat(self): + def test_gait_sequence_concat(self) -> None: """Test the concatenation of subsequent gait sequences.""" sig_length = 95 window_size = 10 @@ -383,13 +383,13 @@ def test_gait_sequence_concat(self): @pytest.mark.parametrize( ("margin_s", "output"), - ( + [ (10, np.array([[90, 510], [590, 810], [990, 1410], [1890, 2000]])), # simple case, no overlaps (200, np.array([[0, 1600], [1700, 2000]])), # simple overlaps, exceeding signal range (800, np.array([[0, 2000]])), # multiple overlaps - ), + ], ) - def test_adding_of_margin(self, margin_s, output): + def test_adding_of_margin(self, margin_s, output) -> None: """Test the addition of a symmetric margin to all gait sequences.""" # dummy array with start and end sample values of gait sequences gait_sequences_start_end = np.array([[100, 500], [600, 800], [1000, 1400], [1900, 2000]]) @@ -405,7 +405,7 @@ def test_adding_of_margin(self, margin_s, output): # margin np.testing.assert_array_equal(margin_added, output) - def test_adding_of_margin_empty_gsd(self): + def test_adding_of_margin_empty_gsd(self) -> None: """Test correct behavior in case no gait sequences are detected.""" gait_sequences_start_end = np.array([]) sig_length = 2000 diff --git a/tests/test_parameters/test_spatial_parameters.py b/tests/test_parameters/test_spatial_parameters.py index e2a20698..f0608b65 100644 --- a/tests/test_parameters/test_spatial_parameters.py +++ b/tests/test_parameters/test_spatial_parameters.py @@ -116,20 +116,20 @@ def after_action_instance( class TestIndividualParameter: - def test_stride_length(self, single_sensor_position_list_with_index, 
single_sensor_stride_length): + def test_stride_length(self, single_sensor_position_list_with_index, single_sensor_stride_length) -> None: assert_series_equal(_calc_stride_length(single_sensor_position_list_with_index), single_sensor_stride_length) - def test_arc_length(self, single_sensor_position_list_with_index, single_sensor_arc_length): + def test_arc_length(self, single_sensor_position_list_with_index, single_sensor_arc_length) -> None: assert_series_equal(_calc_arc_length(single_sensor_position_list_with_index), single_sensor_arc_length) - def test_turning_angle(self, single_sensor_orientation_list_with_index, single_sensor_turning_angle): + def test_turning_angle(self, single_sensor_orientation_list_with_index, single_sensor_turning_angle) -> None: assert_series_equal( _calc_turning_angle(single_sensor_orientation_list_with_index), single_sensor_turning_angle, check_exact=False, ) - def test_turning_angle_empty_orientation(self): + def test_turning_angle_empty_orientation(self) -> None: """Test the turning angle computation in case of empty orientation input. For scipy<=1.5.4 this produced an empty Series automatically. For scipy>1.6.0 we need to handle the case of @@ -147,12 +147,12 @@ def test_turning_angle_empty_orientation(self): check_exact=False, ) - def test_sole_angle(self, single_sensor_orientation_list_with_index, single_sensor_sole_angle_course): + def test_sole_angle(self, single_sensor_orientation_list_with_index, single_sensor_sole_angle_course) -> None: assert_series_equal( _compute_sole_angle_course(single_sensor_orientation_list_with_index), single_sensor_sole_angle_course ) - def test_sole_angle_empty_orientation(self): + def test_sole_angle_empty_orientation(self) -> None: """Test the sole angle computation in case of empty orientation input. For scipy<=1.5.4 this produced an empty Series automatically. 
For scipy>1.6.0 we need to handle the case of @@ -187,7 +187,7 @@ class TestSpatialParameterCalculation: def test_single_sensor( self, single_sensor_stride_list, single_sensor_position_list, single_sensor_orientation_list - ): + ) -> None: """Test calculating spatial parameters for a single sensor.""" t = SpatialParameterCalculation() t.calculate(single_sensor_stride_list, single_sensor_position_list, single_sensor_orientation_list, 100) @@ -198,7 +198,7 @@ def test_single_sensor( def test_multiple_sensor( self, single_sensor_stride_list, single_sensor_position_list, single_sensor_orientation_list - ): + ) -> None: """Test calculating spatial parameters for multiple sensors.""" stride_events_list = {"sensor1": single_sensor_stride_list, "sensor2": single_sensor_stride_list} position_list = {"sensor1": single_sensor_position_list, "sensor2": single_sensor_position_list} @@ -230,7 +230,7 @@ def test_partial_info( single_sensor_stride_list, single_sensor_position_list, single_sensor_orientation_list, - ): + ) -> None: """Test that it is possible to calculate spatial parameters with partial information.""" stride_list = single_sensor_stride_list.drop(list(exclude), axis=1) stride_events_list = {"sensor1": stride_list, "sensor2": stride_list} @@ -242,7 +242,7 @@ def test_partial_info( assert set(sensor.columns) == set(self.parameters) - set(expected_missing) assert len(sensor) == len(single_sensor_stride_list) - def test_only_ori_provided(self, single_sensor_stride_list, single_sensor_orientation_list): + def test_only_ori_provided(self, single_sensor_stride_list, single_sensor_orientation_list) -> None: """Test calculating spatial parameters when only the orientation list is provided.""" t = SpatialParameterCalculation() t.calculate( @@ -260,7 +260,7 @@ def test_only_pos_provided( self, single_sensor_stride_list, single_sensor_position_list, - ): + ) -> None: t = SpatialParameterCalculation() t.calculate( stride_event_list=single_sensor_stride_list, @@ -282,7 +282,7 @@ def test_only_pos_provided( @pytest.mark.parametrize("calculate_only", [["stride_length"], ["gait_velocity"], ["arc_length", "stride_length"]]) def test_calculate_only( self, single_sensor_stride_list, single_sensor_position_list, single_sensor_orientation_list, calculate_only - ): + ) -> None: t = SpatialParameterCalculation(calculate_only=calculate_only) t.calculate(single_sensor_stride_list, single_sensor_position_list, single_sensor_orientation_list, 100) @@ -290,7 +290,7 @@ def test_calculate_only( def test_stride_list_types( self, single_sensor_stride_list, single_sensor_position_list, single_sensor_orientation_list - ): + ) -> None: # The default single_sensor_stride_list is a min_vel stride list. # If we set expected_stride_type to "ic", we should get an error. 
t = SpatialParameterCalculation(expected_stride_type="ic") @@ -327,7 +327,7 @@ def test_stride_list_types( def test_empty_stride_list_throws_no_error( self, single_sensor_stride_list, single_sensor_position_list, single_sensor_orientation_list - ): + ) -> None: """Test that an empty stride list does not throw an error.""" t = SpatialParameterCalculation(expected_stride_type="ic") single_sensor_stride_list = single_sensor_stride_list.iloc[0:0] @@ -341,7 +341,7 @@ def test_empty_stride_list_throws_no_error( class TestSpatialParameterRegression: def test_regression_on_example_data( self, healthy_example_orientation, healthy_example_position, healthy_example_stride_events, snapshot - ): + ) -> None: # Convert stride list back to mocap samples: healthy_example_stride_events["left_sensor"][["start", "end", "tc", "ic", "min_vel", "pre_ic"]] *= 100 / 204.8 healthy_example_stride_events["right_sensor"][["start", "end", "tc", "ic", "min_vel", "pre_ic"]] *= 100 / 204.8 diff --git a/tests/test_parameters/test_temporal_parameter.py b/tests/test_parameters/test_temporal_parameter.py index 051030a2..a4d2dd05 100644 --- a/tests/test_parameters/test_temporal_parameter.py +++ b/tests/test_parameters/test_temporal_parameter.py @@ -73,15 +73,16 @@ def stride_list(self, stride_list_type): return _min_vel_stride_list() elif stride_list_type == "ic": return ic_stride_list() + return None - def test_single_sensor_multiple_strides(self, stride_list, stride_list_type): + def test_single_sensor_multiple_strides(self, stride_list, stride_list_type) -> None: """Test calculating temporal parameters for a single sensor.""" stride_events_list, temporal_parameters = stride_list t = TemporalParameterCalculation(expected_stride_type=stride_list_type) t.calculate(stride_events_list, 100) assert_frame_equal(t.parameters_, temporal_parameters) - def test_multiple_sensor(self, stride_list, stride_list_type): + def test_multiple_sensor(self, stride_list, stride_list_type) -> None: """Test calculating temporal parameters for multiple sensors, with multiple strides for all sensors.""" stride_events_list1, temporal_parameters = stride_list stride_events_list = {"sensor1": stride_events_list1.iloc[:2], "sensor2": stride_events_list1} @@ -96,14 +97,14 @@ def test_multiple_sensor(self, stride_list, stride_list_type): class TestTemporalParametersIcStrideList: - def test_single_sensor_multiple_strides(self, min_vel_stride_list): + def test_single_sensor_multiple_strides(self, min_vel_stride_list) -> None: """Test calculating temporal parameters for a single sensor.""" stride_events_list, temporal_parameters = min_vel_stride_list t = TemporalParameterCalculation() t.calculate(stride_events_list, 100) assert_frame_equal(t.parameters_, temporal_parameters) - def test_multiple_sensor(self, min_vel_stride_list): + def test_multiple_sensor(self, min_vel_stride_list) -> None: """Test calculating temporal parameters for multiple sensors, with multiple strides for all sensors.""" stride_events_list1, temporal_parameters = min_vel_stride_list stride_events_list = {"sensor1": stride_events_list1.iloc[:2], "sensor2": stride_events_list1} @@ -118,7 +119,7 @@ def test_multiple_sensor(self, min_vel_stride_list): class TestTemporalParameterRegression: - def test_regression_on_example_data(self, healthy_example_stride_events, snapshot): + def test_regression_on_example_data(self, healthy_example_stride_events, snapshot) -> None: healthy_example_stride_events = healthy_example_stride_events["left_sensor"] t = TemporalParameterCalculation() 
t.calculate(healthy_example_stride_events, 204.8) diff --git a/tests/test_preprocessing/test_forward_direction_alignment.py b/tests/test_preprocessing/test_forward_direction_alignment.py index 6a1d91ee..0a368c37 100644 --- a/tests/test_preprocessing/test_forward_direction_alignment.py +++ b/tests/test_preprocessing/test_forward_direction_alignment.py @@ -28,7 +28,7 @@ class TestMetaFunctionality(MetaTestConfig, TestAlgorithmMixin): class TestForwardDirectionSignAlignment: """Test the forward direction sign alignment class `ForwardDirectionSignAlignment`.""" - def test_single_sensor_input(self, healthy_example_imu_data): + def test_single_sensor_input(self, healthy_example_imu_data) -> None: """Dummy test to see if the algorithm is generally working on the example data.""" data = healthy_example_imu_data["left_sensor"] @@ -43,7 +43,7 @@ def test_single_sensor_input(self, healthy_example_imu_data): assert isinstance(fdsa.ori_method_, BaseOrientationMethod) assert isinstance(fdsa.pos_method_, BasePositionMethod) - def test_multi_sensor_input(self, healthy_example_imu_data): + def test_multi_sensor_input(self, healthy_example_imu_data) -> None: """Dummy test to see if the algorithm is generally working on the example data.""" data = healthy_example_imu_data @@ -59,7 +59,7 @@ def test_multi_sensor_input(self, healthy_example_imu_data): assert isinstance(fdsa.ori_method_[sensor], BaseOrientationMethod) assert isinstance(fdsa.pos_method_[sensor], BasePositionMethod) - def test_invalid_axis_combination(self): + def test_invalid_axis_combination(self) -> None: """Test if value error is raised correctly if invalid axes are defined.""" with pytest.raises(ValueError, match=r".*Invalid rotation axis! *"): ForwardDirectionSignAlignment(forward_direction="x", rotation_axis="a").align(1, sampling_rate_hz=1) @@ -68,17 +68,17 @@ def test_invalid_axis_combination(self): with pytest.raises(ValueError, match=r".*Invalid combination of rotation and forward direction axis! 
*"): ForwardDirectionSignAlignment(forward_direction="x", rotation_axis="x").align(1, sampling_rate_hz=1) - def test_invalid_ori_method(self): + def test_invalid_ori_method(self) -> None: """Test if value error is raised correctly if invalid ori_method class is passed.""" with pytest.raises(ValueError, match=r".*The provided `ori_method` *"): ForwardDirectionSignAlignment(ori_method="abc").align(1, sampling_rate_hz=2) - def test_invalid_pos_method(self): + def test_invalid_pos_method(self) -> None: """Test if value error is raised correctly if invalid pos_method class is passed.""" with pytest.raises(ValueError, match=r".*The provided `pos_method` *"): ForwardDirectionSignAlignment(pos_method="abc").align(1, sampling_rate_hz=2) - def test_no_rotation(self, healthy_example_imu_data): + def test_no_rotation(self, healthy_example_imu_data) -> None: """Test that no rotation is applied if the data is not rotated.""" data = healthy_example_imu_data @@ -88,7 +88,7 @@ def test_no_rotation(self, healthy_example_imu_data): assert_almost_equal(data[sensor].to_numpy(), fwdsa.aligned_data_[sensor].to_numpy()) assert_almost_equal(np.rad2deg(fwdsa.rotation_[sensor].as_euler("zxy")), np.array([0.0, 0.0, 0.0])) - def test_flip_rotation(self, healthy_example_imu_data): + def test_flip_rotation(self, healthy_example_imu_data) -> None: """Test that a correct 180flip is applied if the data is upside-down.""" data = healthy_example_imu_data dataset_flipped = rotate_dataset(data, Rotation.from_euler("z", 180, degrees=True)) diff --git a/tests/test_preprocessing/test_pca_alignment.py b/tests/test_preprocessing/test_pca_alignment.py index c79dcf48..3c225e96 100644 --- a/tests/test_preprocessing/test_pca_alignment.py +++ b/tests/test_preprocessing/test_pca_alignment.py @@ -25,7 +25,7 @@ def after_action_instance(self, healthy_example_imu_data) -> PcaAlignment: class TestPcaAlignment: """Test the pca alignment class `PcaAlignment`.""" - def test_single_sensor_input(self, healthy_example_imu_data): + def test_single_sensor_input(self, healthy_example_imu_data) -> None: """Dummy test to see if the algorithm is generally working on the example data.""" data = healthy_example_imu_data["left_sensor"] @@ -38,7 +38,7 @@ def test_single_sensor_input(self, healthy_example_imu_data): assert isinstance(pca_align.rotation_, Rotation) assert isinstance(pca_align.pca_, PCA) - def test_multi_sensor_input(self, healthy_example_imu_data): + def test_multi_sensor_input(self, healthy_example_imu_data) -> None: """Dummy test to see if the algorithm is generally working on the example data.""" data = healthy_example_imu_data @@ -52,7 +52,7 @@ def test_multi_sensor_input(self, healthy_example_imu_data): assert isinstance(pca_align.rotation_[sensor], Rotation) assert isinstance(pca_align.pca_[sensor], PCA) - def test_invalid_pca_plane_axis(self, healthy_example_imu_data): + def test_invalid_pca_plane_axis(self, healthy_example_imu_data) -> None: """Test if value error is raised correctly if invalid axis for the search plane are defined.""" data = healthy_example_imu_data @@ -65,7 +65,7 @@ def test_invalid_pca_plane_axis(self, healthy_example_imu_data): with pytest.raises(ValueError, match=r".*Invalid axis for pca plane *"): PcaAlignment(target_axis="y", pca_plane_axis=("acc_x")).align(data) - def test_invalid_target_axis(self, healthy_example_imu_data): + def test_invalid_target_axis(self, healthy_example_imu_data) -> None: """Test if value error is raised correctly if invalid axis for the search plane are defined.""" data = 
healthy_example_imu_data @@ -77,12 +77,12 @@ def test_invalid_target_axis(self, healthy_example_imu_data): @pytest.mark.parametrize( ("axis", "rot"), - ( + [ ("x", np.array([[0.28177506, 0.95948049, 0.0], [-0.95948049, 0.28177506, -0.0], [-0.0, 0.0, 1.0]])), ("y", np.array([[0.95948049, -0.28177506, 0.0], [0.28177506, 0.95948049, 0.0], [0.0, 0.0, 1.0]])), - ), + ], ) - def test_correct_rotation_regression(self, healthy_example_imu_data, snapshot, axis, rot): + def test_correct_rotation_regression(self, healthy_example_imu_data, snapshot, axis, rot) -> None: """Test if the alignment actually returns the expected rotation matrix on real imu data.""" data = healthy_example_imu_data["left_sensor"] @@ -92,7 +92,7 @@ def test_correct_rotation_regression(self, healthy_example_imu_data, snapshot, a assert_almost_equal(rot, pca_align.rotation_.as_matrix()) snapshot.assert_match(pca_align.aligned_data_, check_names=False) - def test_correct_rotation_complementary(self, healthy_example_imu_data): + def test_correct_rotation_complementary(self, healthy_example_imu_data) -> None: """Test if the alignment actually returns the expected rotation matrix on real imu data.""" data = healthy_example_imu_data["left_sensor"] @@ -113,8 +113,8 @@ def test_correct_rotation_complementary(self, healthy_example_imu_data): pca_align_y.aligned_data_["gyr_x"].to_numpy(), -pca_align_x.aligned_data_["gyr_y"].to_numpy() ) - @pytest.mark.parametrize("axis", ("x", "y")) - def test_is_righthanded_rotation(self, healthy_example_imu_data, axis): + @pytest.mark.parametrize("axis", ["x", "y"]) + def test_is_righthanded_rotation(self, healthy_example_imu_data, axis) -> None: """Test if the resulting rotation object is a valid righthanded rotation.""" data = healthy_example_imu_data["left_sensor"] diff --git a/tests/test_preprocessing/test_sensor_alignment.py b/tests/test_preprocessing/test_sensor_alignment.py index 6ab6a5ce..f9b9d993 100644 --- a/tests/test_preprocessing/test_sensor_alignment.py +++ b/tests/test_preprocessing/test_sensor_alignment.py @@ -17,7 +17,7 @@ class TestAlignToGravity: sample_sensor_dataset: MultiSensorData @pytest.fixture(autouse=True, params=("dict", "frame")) - def _sample_sensor_data(self, request): + def _sample_sensor_data(self, request) -> None: """Create some sample data. This data is recreated before each test (using pytest.fixture). @@ -32,7 +32,7 @@ def _sample_sensor_data(self, request): elif request.param == "frame": self.sample_sensor_dataset = pd.concat(dataset, axis=1) - def test_no_static_moments_in_dataset(self): + def test_no_static_moments_in_dataset(self) -> None: """Test if value error is raised correctly if no static window can be found on dataset with given user settings. 
""" @@ -45,7 +45,7 @@ def test_no_static_moments_in_dataset(self): metric="maximum", ) - def test_mulit_sensor_dataset_misaligned(self): + def test_mulit_sensor_dataset_misaligned(self) -> None: """Test basic alignment using different 180 deg rotations on each dataset.""" gravity = np.array([0.0, 0.0, 1.0]) @@ -62,7 +62,7 @@ def test_mulit_sensor_dataset_misaligned(self): assert_almost_equal(aligned_dataset["s1"][SF_ACC].to_numpy(), np.repeat(gravity[None, :], 5, axis=0)) assert_almost_equal(aligned_dataset["s2"][SF_ACC].to_numpy(), np.repeat(gravity[None, :], 5, axis=0)) - def test_single_sensor_dataset_misaligned(self): + def test_single_sensor_dataset_misaligned(self) -> None: """Test basic alignment using different 180 deg rotations on single sensor.""" gravity = np.array([0.0, 0.0, 1.0]) @@ -79,7 +79,7 @@ def test_single_sensor_dataset_misaligned(self): class TestXYAlignment: @pytest.mark.parametrize("angle", [90, 180.0, 22.0, 45.0, -90, -45]) - def test_xy_alignment_simple(self, angle): + def test_xy_alignment_simple(self, angle) -> None: signal = np.random.normal(scale=1000, size=(500, 3)) rot_signal = rotation_from_angle(np.array([0, 0, 1]), np.deg2rad(angle)).apply(signal) rot = align_heading_of_sensors(signal, rot_signal) @@ -92,7 +92,7 @@ def test_xy_alignment_simple(self, angle): def test_xy_alignment_dummy( self, - ): + ) -> None: signal = np.random.normal(scale=1000, size=(500, 3)) rot_signal = rotation_from_angle(np.array([0, 0, 1]), 0).apply(signal) rot = align_heading_of_sensors(signal, rot_signal) @@ -101,7 +101,7 @@ def test_xy_alignment_dummy( assert_almost_equal(rotvec, [0, 0, 0]) @pytest.mark.parametrize("angle", [90, 180.0, 22.0, 45.0, -90, -45]) - def test_xy_alignment_with_noise(self, angle): + def test_xy_alignment_with_noise(self, angle) -> None: signal = np.random.normal(scale=1000, size=(500, 3)) rot_signal = rotation_from_angle(np.array([0, 0, 1]), np.deg2rad(angle)).apply(signal) @@ -116,7 +116,7 @@ def test_xy_alignment_with_noise(self, angle): assert_almost_equal(np.abs(rotvec / norm(rotvec) @ [0, 0, 1]), 1, 3) @pytest.mark.parametrize("angle", [90, 180.0, 22.0, 45.0, -90, -45]) - def test_smoothing(self, angle): + def test_smoothing(self, angle) -> None: signal = np.random.normal(scale=1000, size=(500, 3)) rot_signal = rotation_from_angle(np.array([0, 0, 1]), np.deg2rad(angle)).apply(signal) diff --git a/tests/test_stride_segmentation/test_barth_dtw.py b/tests/test_stride_segmentation/test_barth_dtw.py index 0673ad33..84b367cc 100644 --- a/tests/test_stride_segmentation/test_barth_dtw.py +++ b/tests/test_stride_segmentation/test_barth_dtw.py @@ -39,7 +39,7 @@ class TestCachingFunctionality(MetaTestConfig, TestCachingMixin): class TestRegressionOnRealData: - def test_real_data_both_feed_regression(self, healthy_example_imu_data, snapshot): + def test_real_data_both_feed_regression(self, healthy_example_imu_data, snapshot) -> None: data = convert_to_fbf(healthy_example_imu_data, right=["right_sensor"], left=["left_sensor"]) dtw = BarthDtw() # Test with default paras dtw.segment(data, sampling_rate_hz=204.8) @@ -50,7 +50,7 @@ def test_real_data_both_feed_regression(self, healthy_example_imu_data, snapshot snapshot.assert_match(dtw.stride_list_["left_sensor"], "left") snapshot.assert_match(dtw.stride_list_["right_sensor"], "right") - def test_snapping_on_off(self, healthy_example_imu_data): + def test_snapping_on_off(self, healthy_example_imu_data) -> None: data = convert_to_fbf(healthy_example_imu_data, right=["right_sensor"], 
left=["left_sensor"]).iloc[:1000] # off dtw = BarthDtw(snap_to_min_win_ms=None) @@ -69,7 +69,7 @@ def test_snapping_on_off(self, healthy_example_imu_data): assert not np.array_equal(dtw.matches_start_end_["left_sensor"], dtw.matches_start_end_original_["left_sensor"]) assert_array_equal(dtw.matches_start_end_original_["left_sensor"], out_without_snapping) - def test_conflict_resolution_on_off(self, healthy_example_imu_data): + def test_conflict_resolution_on_off(self, healthy_example_imu_data) -> None: data = convert_to_fbf(healthy_example_imu_data, right=["right_sensor"], left=["left_sensor"]).iloc[:1000] # For both cases set the threshold so high that wrong matches will occure max_cost = 5 @@ -106,7 +106,7 @@ def init_dtw(self, template, **kwargs): class TestBarthDtwAdditions(DtwTestBaseBarth): # TODO: Add a test were a stride ends at the last sample before snapping - def test_stride_list(self): + def test_stride_list(self) -> None: """Test that the output of the stride list is correct.""" sequence = 2 * [*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2] template = DtwTemplate(data=np.array([0, 1.0, 0]), sampling_rate_hz=100.0) @@ -119,7 +119,7 @@ def test_stride_list(self): expected_stride_list = expected_stride_list.set_index("s_id") assert_frame_equal(dtw.stride_list_.astype(np.int64), expected_stride_list.astype(np.int64)) - def test_stride_list_multi_d(self): + def test_stride_list_multi_d(self) -> None: """Test that the output of the stride list is correct.""" sensor1 = np.array([*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2]) sensor1 = pd.DataFrame(sensor1, columns=["col1"]) @@ -142,14 +142,14 @@ def test_stride_list_multi_d(self): pd.DataFrame([[0, 2, 5]], columns=["s_id", "start", "end"]).set_index("s_id").astype(np.int64), ) - def test_stride_list_passes_test_func(self): + def test_stride_list_passes_test_func(self) -> None: sequence = 2 * [*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2] template = DtwTemplate(data=np.array([0, 1.0, 0]), sampling_rate_hz=100.0) dtw = self.init_dtw(template).segment(np.array(sequence), sampling_rate_hz=100.0) assert is_single_sensor_stride_list(dtw.stride_list_) - def test_stride_list_passes_test_func_multiple(self): + def test_stride_list_passes_test_func_multiple(self) -> None: sensor1 = np.array([*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2]) sensor1 = pd.DataFrame(sensor1, columns=["col1"]) sensor2 = np.array([*np.ones(2) * 2, 0, 1.0, 0, *np.ones(8) * 2]) @@ -171,7 +171,7 @@ class TestPostProcessing: Snapping is not tested here (fully) as it is well covered by the regression tests """ - def test_simple_stride_time(self): + def test_simple_stride_time(self) -> None: example_stride_list = np.array([np.arange(10), np.arange(10) + 1.0]).T bad_strides_short = [2, 5, 9] example_stride_list[bad_strides_short, 1] -= 0.5 @@ -203,7 +203,7 @@ def test_simple_stride_time(self): # Check that the correct strides were identified assert np.all(~to_keep[bad_strides]) - def test_simple_double_start(self): + def test_simple_double_start(self) -> None: example_stride_list = np.array([np.arange(10), np.arange(10) + 1.0]).T cost = np.ones(len(example_stride_list)) # Introduce errors @@ -241,7 +241,7 @@ def test_simple_double_start(self): # Check that the correct 3 strides were identified assert np.all(~to_keep[bad_strides]) - def test_previous_removal_double_start(self): + def test_previous_removal_double_start(self) -> None: example_stride_list = np.array([np.arange(10), np.arange(10) + 1.0]).T cost = np.ones(len(example_stride_list)) # Introduce errors @@ -287,7 +287,7 @@ 
def test_previous_removal_double_start(self): # Check that the correct 5 strides were identified assert np.all(~to_keep[bad_strides]) - def test_previous_removal_double_start_unsorted(self): + def test_previous_removal_double_start_unsorted(self) -> None: example_stride_list = np.array([np.arange(10), np.arange(10) + 1.0]).T cost = np.ones(len(example_stride_list)) # Introduce errors @@ -333,7 +333,7 @@ def test_previous_removal_double_start_unsorted(self): # Check that the correct 5 strides were identified assert np.all(~to_keep[bad_strides]) - def test_post_post_warning_is_raised(self, healthy_example_imu_data): + def test_post_post_warning_is_raised(self, healthy_example_imu_data) -> None: data = convert_to_fbf(healthy_example_imu_data, right=["right_sensor"], left=["left_sensor"])[:1000] # Disable all conflict resolutions to force a double match dtw = BarthDtw(max_cost=10000, min_match_length_s=None, conflict_resolution=False, snap_to_min_win_ms=None) @@ -351,7 +351,7 @@ def test_post_post_warning_is_raised(self, healthy_example_imu_data): # Check that no UserWarning was recorded w.pop(UserWarning) - def test_snapping_edge_case(self): + def test_snapping_edge_case(self) -> None: """Testing if snapping works, even if one of the strides ends on the very last sample. This is a special case, as we add 1 to the end of all matches, so that the end is exclusive. diff --git a/tests/test_stride_segmentation/test_base_dtw.py b/tests/test_stride_segmentation/test_base_dtw.py index 65051fb8..5c9909e1 100644 --- a/tests/test_stride_segmentation/test_base_dtw.py +++ b/tests/test_stride_segmentation/test_base_dtw.py @@ -8,6 +8,7 @@ - The same is true for the threshold/max_cost """ + from typing import Dict, Union from unittest.mock import patch @@ -54,15 +55,15 @@ def init_dtw(self, template, **kwargs): class TestIOErrors(DtwTestBase): """Test that the correct errors are raised if wrong parameter values are provided.""" - def test_no_template_provided(self): + def test_no_template_provided(self) -> None: with pytest.raises(ValueError) as e: dtw = self.init_dtw(template=None) dtw.segment(None, None) assert "`template` must be specified" in str(e) - @pytest.mark.parametrize("data", (pd.DataFrame, [], None)) - def test_unsuitable_datatype(self, data): + @pytest.mark.parametrize("data", [pd.DataFrame, [], None]) + def test_unsuitable_datatype(self, data) -> None: """No proper sensor data provided.""" template = DtwTemplate(data=np.array([0, 1.0, 0]), sampling_rate_hz=100.0) with pytest.raises(ValidationError) as e: @@ -71,7 +72,7 @@ def test_unsuitable_datatype(self, data): assert "neither single- or multi-sensor data" in str(e) - def test_invalid_template_combination(self): + def test_invalid_template_combination(self) -> None: """Invalid combinations of template and dataset format.""" template = DtwTemplate(data=np.array([0, 1.0, 0]), sampling_rate_hz=100.0) with pytest.raises(ValueError) as e: @@ -80,7 +81,7 @@ def test_invalid_template_combination(self): assert "Invalid combination of data and template" in str(e) - def test_multi_sensor_dataset_without_proper_template(self): + def test_multi_sensor_dataset_without_proper_template(self) -> None: """Invalid combination of template and multisensor dataset.""" # This template cannot be used with multi sensor dataframes. 
template = DtwTemplate(data=np.array([0, 1.0, 0]), sampling_rate_hz=100.0) @@ -96,7 +97,7 @@ def test_multi_sensor_dataset_without_proper_template(self): assert "template must either be of type `Dict[str, DtwTemplate]`" in str(e) - def test_invalid_find_matches_method(self): + def test_invalid_find_matches_method(self) -> None: template = DtwTemplate(data=np.array([0, 1.0, 0]), sampling_rate_hz=100.0) with pytest.raises(ValueError) as e: dtw = self.init_dtw(template=template, find_matches_method="invalid_name") @@ -104,9 +105,9 @@ def test_invalid_find_matches_method(self): assert "find_matches_method" in str(e) - @pytest.mark.parametrize("para", ("max_template_stretch_ms", "max_signal_stretch_ms")) - @pytest.mark.parametrize("value", (-2, 0)) - def test_invalid_stretch_larger_zero(self, para, value): + @pytest.mark.parametrize("para", ["max_template_stretch_ms", "max_signal_stretch_ms"]) + @pytest.mark.parametrize("value", [-2, 0]) + def test_invalid_stretch_larger_zero(self, para, value) -> None: template = DtwTemplate(data=np.array([0, 1.0, 0]), sampling_rate_hz=100.0) with pytest.raises(ValueError) as e: dtw = self.init_dtw(template=template, **{para: value}) @@ -115,8 +116,8 @@ def test_invalid_stretch_larger_zero(self, para, value): assert para in str(e) assert str(value) in str(e) - @pytest.mark.parametrize("paras", ((None, 100), (100, None))) - def test_none_constrain_is_inf(self, paras): + @pytest.mark.parametrize("paras", [(None, 100), (100, None)]) + def test_none_constrain_is_inf(self, paras) -> None: para_names = ("max_template_stretch_ms", "max_signal_stretch_ms") func_names = ("max_subseq_steps", "max_longseq_steps") none_index = list(paras).index(None) @@ -139,7 +140,7 @@ def test_none_constrain_is_inf(self, paras): assert kwargs[func_names[none_index]] == np.inf assert kwargs[func_names[not none_index]] == paras[para_names[not none_index]] / 1000 * sampling_rate - def test_constrains_without_template_sampling_rate(self): + def test_constrains_without_template_sampling_rate(self) -> None: template = DtwTemplate(data=np.array([0, 1.0, 0]), sampling_rate_hz=None) data = np.array([*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2]) dtw = self.init_dtw(template=template, max_template_stretch_ms=3) @@ -149,7 +150,7 @@ def test_constrains_without_template_sampling_rate(self): assert "sampling_rate_hz" in str(e) - def test_constrains_correct_sampling_rate_used(self): + def test_constrains_correct_sampling_rate_used(self) -> None: template_sampling_rate = 100 signal_sampling_rate = 10 template = DtwTemplate(data=np.ones(30), sampling_rate_hz=template_sampling_rate) @@ -191,14 +192,14 @@ class TestSimpleSegment(DtwTestBase): template = DtwTemplate(data=np.array([0, 1.0, 0]), sampling_rate_hz=100.0) @pytest.fixture(params=list(BaseDtw._allowed_methods_map.keys()), autouse=True) - def _create_instance(self, request): + def _create_instance(self, request) -> None: dtw = self.init_dtw( template=self.template, find_matches_method=request.param, ) self.dtw = dtw - def test_sdtw_simple_match(self): + def test_sdtw_simple_match(self) -> None: """Test dtw with single match and hand calculated outcomes.""" sequence = [*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2] @@ -221,7 +222,7 @@ def test_sdtw_simple_match(self): np.testing.assert_array_equal(dtw.data, sequence) - def test_sdtw_multi_match(self): + def test_sdtw_multi_match(self) -> None: """Test dtw with multiple matches and hand calculated outcomes.""" sequence = 2 * [*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2] @@ -240,13 +241,13 @@ def 
test_sdtw_multi_match(self): class TestMultiDimensionalArrayInputs(DtwTestBase): @pytest.mark.parametrize( ("template", "data"), - ( + [ (np.array([[0, 1.0, 0]]), np.array(2 * [*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2])), (np.array([0, 1.0, 0]), np.array([2 * [*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2]])), (np.array([[0, 1.0, 0]]), np.array([2 * [*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2]])), - ), + ], ) - def test_pseudo_2d_inputs(self, template, data): + def test_pseudo_2d_inputs(self, template, data) -> None: template = DtwTemplate(data=template, sampling_rate_hz=100.0) dtw = self.init_dtw(template=template) @@ -254,7 +255,7 @@ def test_pseudo_2d_inputs(self, template, data): np.testing.assert_array_equal(dtw.matches_start_end_, [[5, 8], [18, 21]]) - def test_1d_dataframe_inputs(self): + def test_1d_dataframe_inputs(self) -> None: template = pd.DataFrame(np.array([0, 1.0, 0]), columns=["col1"]) data = pd.DataFrame(np.array(2 * [*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2]), columns=["col1"]) template = DtwTemplate(data=template, sampling_rate_hz=100.0) @@ -264,7 +265,7 @@ def test_1d_dataframe_inputs(self): np.testing.assert_array_equal(dtw.matches_start_end_, [[5, 8], [18, 21]]) - def test_no_matches_found(self): + def test_no_matches_found(self) -> None: """Test that no errors are raised when no matches are found.""" template = pd.DataFrame(np.array([0, 1.0, 0]), columns=["col1"]) data = pd.DataFrame(np.ones(10), columns=["col1"]) @@ -277,9 +278,9 @@ def test_no_matches_found(self): np.testing.assert_array_equal(dtw.paths_, []) np.testing.assert_array_equal(dtw.costs_, []) - @pytest.mark.parametrize("m_cols", (2, 3)) - @pytest.mark.parametrize("input_type", (np.array, pd.DataFrame)) - def test_valid_multi_d_input(self, m_cols, input_type): + @pytest.mark.parametrize("m_cols", [2, 3]) + @pytest.mark.parametrize("input_type", [np.array, pd.DataFrame]) + def test_valid_multi_d_input(self, m_cols, input_type) -> None: """Test if we get the same results with simple multi D inputs. Data and template are repeated to have the shape (n, m_cols), where n is the number of samples @@ -300,8 +301,8 @@ def test_valid_multi_d_input(self, m_cols, input_type): np.testing.assert_array_equal(dtw.matches_start_end_, [[5, 8], [18, 21]]) - @pytest.mark.parametrize("input_type", (np.array, pd.DataFrame)) - def test_data_has_more_cols_than_template(self, input_type): + @pytest.mark.parametrize("input_type", [np.array, pd.DataFrame]) + def test_data_has_more_cols_than_template(self, input_type) -> None: """In case the data has more cols, only the first m cols of the data are used. Note that this does not really test if this works, but just that it doesn't throw an error. 
@@ -324,7 +325,7 @@ def test_data_has_more_cols_than_template(self, input_type): np.testing.assert_array_equal(dtw.matches_start_end_, [[5, 8], [18, 21]]) - def test_data_has_less_cols_than_template_array(self): + def test_data_has_less_cols_than_template_array(self) -> None: """In case the data has fewer cols than the template, an error is raised.""" n_cols_template = 3 n_cols_data = 2 @@ -339,7 +340,7 @@ def test_data_has_less_cols_than_template_array(self): assert "less columns" in str(e) - def test_data_has_wrong_cols_than_template_df(self): + def test_data_has_wrong_cols_than_template_df(self) -> None: """An error should be raised if the template has columns that are not in the df.""" n_cols_template = 3 n_cols_data = 2 @@ -356,7 +357,7 @@ def test_data_has_wrong_cols_than_template_df(self): assert str(["col3"]) in str(e) - def test_no_sampling_rate_for_resample(self): + def test_no_sampling_rate_for_resample(self) -> None: """Error is raised when resample is True, but no sampling rate provided.""" template = DtwTemplate(data=np.ndarray([])) @@ -366,7 +367,7 @@ def test_no_sampling_rate_for_resample(self): assert "sampling_rate_hz" in str(e) - def test_sampling_rate_mismatch_warning(self): + def test_sampling_rate_mismatch_warning(self) -> None: """Test if a warning is raised when template and data do not have the same sampling rate and resample is False.""" template = pd.DataFrame(np.array([0, 1.0, 0]), columns=["col1"]) data = pd.DataFrame(np.array(2 * [*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2]), columns=["col1"]) @@ -384,7 +385,7 @@ class TestMultiSensorInputs(DtwTestBase): data: Union[pd.DataFrame, Dict[str, pd.DataFrame]] @pytest.fixture(params=("dict", "frame"), autouse=True) - def multi_sensor_dataset(self, request): + def multi_sensor_dataset(self, request) -> None: sensor1 = np.array([*np.ones(5) * 2, 0, 1.0, 0, *np.ones(5) * 2]) sensor1 = pd.DataFrame(sensor1, columns=["col1"]) sensor2 = np.array([*np.ones(2) * 2, 0, 1.0, 0, *np.ones(8) * 2]) @@ -395,7 +396,7 @@ def multi_sensor_dataset(self, request): data = {"sensor1": sensor1, "sensor2": sensor2} if request.param == "dict": self.data = data elif request.param == "frame": self.data = pd.concat(data, axis=1) - def test_single_template_multi_sensors(self): + def test_single_template_multi_sensors(self) -> None: """In case a single template and multiple sensors are provided, the template is applied to all sensors.""" template = [0, 1.0, 0] template = pd.DataFrame(template, columns=["col1"]) @@ -438,7 +439,7 @@ def test_single_template_multi_sensors(self): ], ) - def test_no_matches_found_multiple(self): + def test_no_matches_found_multiple(self) -> None: """Test postprocessing still works, even when there are no matches.""" template = [0, 1.0, 0] template = pd.DataFrame(template, columns=["col1"]) @@ -457,7 +458,7 @@ def test_no_matches_found_multiple(self): np.testing.assert_array_equal(dtw.paths_[s], []) np.testing.assert_array_equal(dtw.costs_[s], []) - def test_multi_template_multi_sensors(self): + def test_multi_template_multi_sensors(self) -> None: """Test multiple templates with multiple sensors.
In case multiple templates and multiple sensors are provided, each template is applied to data with the same @@ -509,13 +510,13 @@ def test_multi_template_multi_sensors(self): class TestTemplateResampling(DtwTestBase): - def test_resample_length(self): + def test_resample_length(self) -> None: """Unit test the resample function.""" test_template = np.ones(100) resampled_template = BaseDtw._resample_template(test_template, 100, 200) assert len(resampled_template) == 200 - def test_resample_real_example(self): + def test_resample_real_example(self) -> None: """Toy example with hand calculated outcomes.""" # Test that this works in general template = [0, 1, 2, 3, 4, 3, 2, 1, 0] @@ -548,8 +549,8 @@ def test_resample_real_example(self): assert len(dtw.matches_start_end_) == 0 - @pytest.mark.parametrize("sampling_rate", (102.4, 204.8, 100, 500, 409.6)) - def test_resample_decimal_values(self, sampling_rate): + @pytest.mark.parametrize("sampling_rate", [102.4, 204.8, 100, 500, 409.6]) + def test_resample_decimal_values(self, sampling_rate) -> None: """As a test, we try to resample the Barth Template to various sampling rates.""" test_template = BarthOriginalTemplate().get_data() resampled_template = BaseDtw._resample_template( @@ -559,7 +560,7 @@ def test_resample_decimal_values(self, sampling_rate): # We just check the length of the resampled template assert len(resampled_template) == int(200 * sampling_rate / 204.8) - def test_bug_missing_if(self): + def test_bug_missing_if(self) -> None: """Test that no error is thrown if sampling rates are different, but the template has no sampling rate.""" template = [0, 1, 2, 3, 4, 3, 2, 1, 0] test_data = np.array([0, 0, *template, 0, 0, *template, 0, 0]) @@ -572,7 +573,7 @@ def test_bug_missing_if(self): except Exception as e: pytest.fail(f"Raised unexpected error: {e}", pytrace=True) - def test_error_if_not_template_sampling_rate(self): + def test_error_if_not_template_sampling_rate(self) -> None: """If the template should be resampled, but the template has no sampling rate, raise an error.""" template = [0, 1, 2, 3, 4, 3, 2, 1, 0] test_data = np.array([0, 0, *template, 0, 0, *template, 0, 0]) @@ -585,7 +586,7 @@ def test_error_if_not_template_sampling_rate(self): assert "resample the template" in str(e) - def test_warning_when_no_resample_but_different_sampling_rate(self): + def test_warning_when_no_resample_but_different_sampling_rate(self) -> None: """In case the template and the data have different sampling rates, but resample is False, the user will be warned.
""" @@ -613,7 +614,7 @@ def test_warning_when_no_resample_but_different_sampling_rate(self): class TestDtwConstrains(DtwTestBase): """Test that the local warping constrains work as expected.""" - def test_signal_constrained(self): + def test_signal_constrained(self) -> None: template = np.array([0, 1, 0]) sequence = [*np.ones(5) * 2, *np.repeat(template, 5), *np.ones(5) * 2, *np.repeat(template, 3), *np.ones(5) * 2] template = DtwTemplate(data=template, sampling_rate_hz=1) @@ -638,7 +639,7 @@ def test_signal_constrained(self): assert dtw.costs_[0] == 0.0 assert len(dtw.paths_[0]) == 5 - def test_template_constrained(self): + def test_template_constrained(self) -> None: template = np.array([0, 1, 0]) sequence = [*np.ones(5) * 2, *template, *np.ones(5) * 2, *np.repeat(template, 2), *np.ones(5) * 2] template = np.repeat(template, 6) diff --git a/tests/test_stride_segmentation/test_constrained_barth_dtw.py b/tests/test_stride_segmentation/test_constrained_barth_dtw.py index 0a7e80f1..07ba44d6 100644 --- a/tests/test_stride_segmentation/test_constrained_barth_dtw.py +++ b/tests/test_stride_segmentation/test_constrained_barth_dtw.py @@ -2,6 +2,7 @@ We only test the meta functionality and a regressions, as it is basically identical to BarthDtw, except some defaults. """ + import numpy as np import pytest @@ -35,7 +36,7 @@ class TestCachingFunctionality(MetaTestConfig, TestCachingMixin): class TestRegressionOnRealDataConstrainedDtw: """These regression tests run on a MS dataset, which produces a bunch of issues wiht the normal dtw.""" - def test_real_data_both_feed_regression(self, ms_example_imu_data, snapshot): + def test_real_data_both_feed_regression(self, ms_example_imu_data, snapshot) -> None: data = convert_to_fbf(ms_example_imu_data, right=["right_sensor"], left=["left_sensor"]) dtw = ConstrainedBarthDtw(template=BarthOriginalTemplate(use_cols=("gyr_ml", "gyr_si"))) # Test with default # paras diff --git a/tests/test_stride_segmentation/test_dtw_templates.py b/tests/test_stride_segmentation/test_dtw_templates.py index c9d5548e..74310fcc 100644 --- a/tests/test_stride_segmentation/test_dtw_templates.py +++ b/tests/test_stride_segmentation/test_dtw_templates.py @@ -16,8 +16,8 @@ class TestSerialize: """Test that templates can be serialized correctly.""" - @pytest.mark.parametrize("dtype", (list, np.array, pd.Series, pd.DataFrame)) - def test_different_dtypes(self, dtype): + @pytest.mark.parametrize("dtype", [list, np.array, pd.Series, pd.DataFrame]) + def test_different_dtypes(self, dtype) -> None: template = dtype(list(range(10))) instance = DtwTemplate(data=template) @@ -26,7 +26,7 @@ def test_different_dtypes(self, dtype): compare_algo_objects(instance, new_instance) - def test_index_order_long_dfs(self): + def test_index_order_long_dfs(self) -> None: """Loading df based templates might change their index.""" template = pd.DataFrame(list(range(20))) @@ -38,7 +38,7 @@ def test_index_order_long_dfs(self): class TestTemplateBaseClass: - def test_template_provided(self): + def test_template_provided(self) -> None: """Test very simple case where the template is directly stored in the class instance.""" template = np.arange(10) @@ -46,21 +46,21 @@ def test_template_provided(self): assert_array_equal(instance.get_data(), template) - def test_no_valid_info_provided(self): + def test_no_valid_info_provided(self) -> None: """Test that an error is raised, if neither a filename nor a array is provided.""" instance = DtwTemplate() with pytest.raises(ValueError): _ = instance.get_data() - def 
test_use_columns_array(self): + def test_use_columns_array(self) -> None: template = np.stack((np.arange(10), np.arange(10, 20))).T instance = DtwTemplate(data=template, use_cols=[1]) assert_array_equal(instance.get_data(), template[:, 1]) - def test_use_columns_dataframe(self): + def test_use_columns_dataframe(self) -> None: template = np.stack((np.arange(10), np.arange(10, 20))).T template = pd.DataFrame(template, columns=["col_1", "col_2"]) @@ -68,7 +68,7 @@ def test_use_columns_dataframe(self): assert_array_equal(instance.get_data(), template[["col_1"]]) - def test_use_columns_wrong_dim(self): + def test_use_columns_wrong_dim(self) -> None: template = np.arange(10) instance = DtwTemplate(data=template, use_cols=[1]) @@ -76,7 +76,7 @@ def test_use_columns_wrong_dim(self): with pytest.raises(ValueError): _ = instance.get_data() - def test_get_data_applies_scaling(self): + def test_get_data_applies_scaling(self) -> None: template = pd.DataFrame(np.arange(10)) instance = DtwTemplate(data=template, scaling=FixedScaler(scale=2)) @@ -85,7 +85,7 @@ def test_get_data_applies_scaling(self): class TestBartTemplate: - def test_load(self): + def test_load(self) -> None: with open_text( "gaitmap_mad.stride_segmentation.dtw._dtw_templates", "barth_original_template.csv" ) as test_data: @@ -97,7 +97,7 @@ def test_load(self): assert barth_instance.sampling_rate_hz == 204.8 assert barth_instance.scaling.get_params() == FixedScaler(500.0, 0).get_params() - def test_hashing(self): + def test_hashing(self) -> None: """Test that calling `get_data` does not modify the hash of the object.""" barth_instance = BarthOriginalTemplate() @@ -109,7 +109,7 @@ def test_hashing(self): class TestCreateTemplate: - def test_create_template_simple(self): + def test_create_template_simple(self) -> None: template = np.arange(10) sampling_rate_hz = 100 @@ -118,7 +118,7 @@ def test_create_template_simple(self): assert_array_equal(instance.get_data(), template) assert instance.sampling_rate_hz == sampling_rate_hz - def test_create_template_use_col(self): + def test_create_template_use_col(self) -> None: template = np.stack((np.arange(10), np.arange(10, 20))).T template = pd.DataFrame(template, columns=["col_1", "col_2"]) sampling_rate_hz = 100 @@ -133,10 +133,10 @@ def test_create_template_use_col(self): class TestCreateInterpolatedTemplate: @pytest.fixture(autouse=True, params=["linear", "nearest"]) - def select_kind(self, request): + def select_kind(self, request) -> None: self.kind = request.param - def test_create_interpolated_template_single_dataset(self): + def test_create_interpolated_template_single_dataset(self) -> None: """Test function can handle single dataset input.""" template_data = pd.DataFrame(np.array([0, 1, 2, 1, 0]), columns=["dummy_col"]) instance = InterpolatedDtwTemplate(interpolation_method=self.kind, n_samples=None).self_optimize( @@ -146,7 +146,7 @@ def test_create_interpolated_template_single_dataset(self): assert_array_almost_equal(instance.get_data(), template_data[["dummy_col"]]) assert isinstance(instance, DtwTemplate) - def test_create_interpolated_template_dataset_list(self): + def test_create_interpolated_template_dataset_list(self) -> None: """Test function can handle lists of dataset input.""" template_data1 = pd.DataFrame(np.array([0, 1, 2, 1, 0]), columns=["dummy_col"]) template_data2 = pd.DataFrame(np.array([0, 1, 2, 1, 0]), columns=["dummy_col"]) @@ -159,7 +159,7 @@ def test_create_interpolated_template_dataset_list(self): assert_array_almost_equal(instance.get_data(), 
template_data1[["dummy_col"]]) assert isinstance(instance, DtwTemplate) - def test_create_interpolated_template_different_indices(self): + def test_create_interpolated_template_different_indices(self) -> None: """Test that interpolation works even if strides have different indices.""" template_data1 = pd.DataFrame(np.array([0, 1, 2, 1, 0]), columns=["dummy_col"]) template_data2 = pd.DataFrame(np.array([0, 1, 2, 1, 0]), columns=["dummy_col"]) @@ -173,7 +173,7 @@ def test_create_interpolated_template_different_indices(self): assert_array_almost_equal(instance.get_data(), template_data1[["dummy_col"]]) assert isinstance(instance, DtwTemplate) - def test_create_interpolated_template_calculates_mean(self): + def test_create_interpolated_template_calculates_mean(self) -> None: """Test if result is actually mean over all inputs.""" template_data1 = pd.DataFrame(np.array([0, 1, 2, 1, 0]), columns=["dummy_col"]) template_data2 = pd.DataFrame(np.array([0, -1, -2, -1, 0]), columns=["dummy_col"]) @@ -187,7 +187,7 @@ def test_create_interpolated_template_calculates_mean(self): assert_array_almost_equal(instance.get_data(), result_template_df.to_numpy()) assert isinstance(instance, DtwTemplate) - def test_create_interpolated_mean_length_over_input_sequences_template(self): + def test_create_interpolated_mean_length_over_input_sequences_template(self) -> None: """Test template has mean length of all input sequences.""" template_data1 = pd.DataFrame(np.array([0, 1, 2, 3, 4]), columns=["dummy_col"]) template_data2 = pd.DataFrame(np.array([0, 1, 2]), columns=["dummy_col"]) @@ -201,7 +201,7 @@ def test_create_interpolated_mean_length_over_input_sequences_template(self): assert instance.sampling_rate_hz == 1 assert isinstance(instance, DtwTemplate) - def test_create_interpolated_fixed_length_template_upsample(self): + def test_create_interpolated_fixed_length_template_upsample(self) -> None: """Test template has specified length for upsampling.""" template_data1 = pd.DataFrame(np.array([0, 1, 2, 3, 4]), columns=["dummy_col"]) template_data2 = pd.DataFrame(np.array([0, 1, 2]), columns=["dummy_col"]) @@ -215,7 +215,7 @@ def test_create_interpolated_fixed_length_template_upsample(self): assert instance.sampling_rate_hz == 5 / 4 assert isinstance(instance, DtwTemplate) - def test_create_interpolated_fixed_length_template_downsample(self): + def test_create_interpolated_fixed_length_template_downsample(self) -> None: """Test template has specified length for downsampling.""" template_data1 = pd.DataFrame(np.array([0, 1, 2, 3, 5]), columns=["dummy_col"]) template_data2 = pd.DataFrame(np.array([0, 1, 2, 3, 4, 5, 6]), columns=["dummy_col"]) @@ -229,7 +229,7 @@ def test_create_interpolated_fixed_length_template_downsample(self): assert instance.sampling_rate_hz == 3 / 6 assert isinstance(instance, DtwTemplate) - def test_create_interpolated_template_check_multisensordataset_exception(self): + def test_create_interpolated_template_check_multisensordataset_exception(self) -> None: """Test only single sensor datasets are valid input.""" template_data1 = pd.DataFrame(np.array([[0, 1, 2], [0, 1, 2]]), columns=["col_1", "col_2", "col_3"]) template_data2 = pd.DataFrame(np.array([[0, 1, 2], [0, 1, 2]]), columns=["col_1", "col_2", "col_3"]) @@ -239,11 +239,11 @@ def test_create_interpolated_template_check_multisensordataset_exception(self): with pytest.raises(ValidationError, match=r".* SingleSensorData*"): InterpolatedDtwTemplate().self_optimize(dataset, kind=self.kind, n_samples=None) - def test_scaling_retraining(self): + 
def test_scaling_retraining(self) -> None: class CustomScaler(IdentityTransformer, TrainableTransformerMixin): """Dummy scaler that records the data it is trained with.""" - def __init__(self, opti_data=None): + def __init__(self, opti_data=None) -> None: self.opti_data = opti_data def self_optimize(self, data, **_): @@ -264,7 +264,7 @@ def self_optimize(self, data, **_): assert_array_equal(scaler_instance.opti_data[0], instance.data) - def test_column_selection(self): + def test_column_selection(self) -> None: data1 = pd.DataFrame(np.ones((5, 3)), columns=["col_1", "col_2", "col_3"]) data2 = pd.DataFrame(np.ones((5, 3)), columns=["col_2", "col_1", "col_3"]) diff --git a/tests/test_stride_segmentation/test_roi_stride_segmentation.py b/tests/test_stride_segmentation/test_roi_stride_segmentation.py index da15140b..9b55ae5e 100644 --- a/tests/test_stride_segmentation/test_roi_stride_segmentation.py +++ b/tests/test_stride_segmentation/test_roi_stride_segmentation.py @@ -48,11 +48,11 @@ def create_dummy_multi_sensor_roi(): class TestParameterValidation: @pytest.fixture(autouse=True) - def _create_instance(self): + def _create_instance(self) -> None: instance = RoiStrideSegmentation(BarthDtw()) self.instance = instance - def test_empty_algorithm_raises_error(self): + def test_empty_algorithm_raises_error(self) -> None: instance = RoiStrideSegmentation() with pytest.raises(ValueError): instance.segment( @@ -63,8 +63,8 @@ def test_empty_algorithm_raises_error(self): assert "`segmentation_algorithm` must be a valid instance of a StrideSegmentation algorithm" - @pytest.mark.parametrize("data", (pd.DataFrame, [], None)) - def test_unsuitable_datatype(self, data): + @pytest.mark.parametrize("data", [pd.DataFrame, [], None]) + def test_unsuitable_datatype(self, data) -> None: """No proper SensorData provided.""" with pytest.raises(ValidationError) as e: self.instance.segment( @@ -75,8 +75,8 @@ def test_unsuitable_datatype(self, data): assert "neither single- or multi-sensor data" in str(e) - @pytest.mark.parametrize("roi", (pd.DataFrame(), None)) - def test_invalid_roi_single_dataset(self, roi): + @pytest.mark.parametrize("roi", [pd.DataFrame(), None]) + def test_invalid_roi_single_dataset(self, roi) -> None: """Test that an error is raised if an invalid roi is provided.""" with pytest.raises(ValidationError) as e: # call segment with invalid ROI @@ -84,7 +84,7 @@ def test_invalid_roi_single_dataset(self, roi): assert "neither a single- or a multi-sensor regions of interest list" in str(e) - def test_multi_roi_single_sensor(self): + def test_multi_roi_single_sensor(self) -> None: with pytest.raises(ValidationError) as e: # call segment with invalid ROI self.instance.segment( @@ -93,7 +93,7 @@ def test_multi_roi_single_sensor(self): assert "multi-sensor regions of interest list with a single sensor" in str(e) - def test_invalid_roi_multiple_dataset(self): + def test_invalid_roi_multiple_dataset(self) -> None: """Test that an error is raised if an invalid roi is provided.""" with pytest.raises(ValidationError) as e: # call segment with invalid ROI @@ -101,7 +101,7 @@ def test_invalid_roi_multiple_dataset(self): assert "neither a single- or a multi-sensor regions of interest list" in str(e) - def test_single_roi_unsync_multi(self): + def test_single_roi_unsync_multi(self) -> None: with pytest.raises(ValidationError) as e: # call segment with invalid ROI # Note that the empty dataframe as data is actually a valid data object and will not raise a validation # error. @@ -112,7 +112,7 @@ def 
test_single_roi_unsync_multi(self): assert "single-sensor regions of interest list with an unsynchronised" in str(e) - def test_invalid_stride_id_naming(self): + def test_invalid_stride_id_naming(self) -> None: self.instance.set_params(s_id_naming="wrong") with pytest.raises(ValueError) as e: @@ -123,7 +123,7 @@ def test_invalid_stride_id_naming(self): ) assert "s_id_naming" in str(e) - def test_additional_sensors_in_roi(self): + def test_additional_sensors_in_roi(self) -> None: with pytest.raises(KeyError) as e: # Note that the empty dataframe as data is actually a valid data object and will not raise a validation # error. @@ -139,7 +139,7 @@ class MockStrideSegmentation(BaseStrideSegmentation): _action_methods = ("segment", "secondary_segment") - def __init__(self, n=3): + def __init__(self, n=3) -> None: self.n = n def segment(self: BaseType, data: SensorData, sampling_rate_hz: float, **kwargs) -> BaseType: @@ -176,10 +176,10 @@ class TestCombinedStridelist: """Test the actual ROI stuff.""" @pytest.fixture(autouse=True, params=["replace", "prefix"]) - def _s_id_naming(self, request): + def _s_id_naming(self, request) -> None: self.s_id_naming = request.param - def test_single_sensor(self): + def test_single_sensor(self) -> None: roi_seg = RoiStrideSegmentation(MockStrideSegmentation(), self.s_id_naming) data = pd.DataFrame(np.ones(27)) roi = pd.DataFrame(np.array([[0, 1, 3], [0, 9, 18], [8, 17, 26]]).T, columns=["roi_id", "start", "end"]) @@ -189,7 +189,7 @@ def test_single_sensor(self): assert is_single_sensor_stride_list(roi_seg.stride_list_) assert len(roi_seg.instances_per_roi_) == len(roi) - assert all([isinstance(o, MockStrideSegmentation) for o in roi_seg.instances_per_roi_.values()]) + assert all(isinstance(o, MockStrideSegmentation) for o in roi_seg.instances_per_roi_.values()) if self.s_id_naming == "replace": assert_array_equal(roi_seg.stride_list_.index, list(range(len(roi_seg.stride_list_)))) @@ -204,7 +204,7 @@ def test_single_sensor(self): if r[1]["roi_id"] == stride[1]["roi_id"]: assert stride[1]["start"] >= r[1]["start"] - def test_multi_sensor(self): + def test_multi_sensor(self) -> None: roi_seg = RoiStrideSegmentation(MockStrideSegmentation(), self.s_id_naming) data = {"s1": pd.DataFrame(np.ones(27)), "s2": pd.DataFrame(np.zeros(27))} roi = pd.DataFrame(np.array([[0, 1, 3], [0, 9, 18], [8, 17, 26]]).T, columns=["roi_id", "start", "end"]) @@ -213,10 +213,10 @@ def test_multi_sensor(self): roi_seg.segment(data, sampling_rate_hz=100, regions_of_interest=roi) assert is_multi_sensor_stride_list(roi_seg.stride_list_) assert len(roi_seg.instances_per_roi_) == len(roi) - assert all([isinstance(o, dict) for o in roi_seg.instances_per_roi_.values()]) + assert all(isinstance(o, dict) for o in roi_seg.instances_per_roi_.values()) for sensor in ["s1", "s2"]: - assert all([isinstance(o, MockStrideSegmentation) for o in roi_seg.instances_per_roi_[sensor].values()]) + assert all(isinstance(o, MockStrideSegmentation) for o in roi_seg.instances_per_roi_[sensor].values()) assert len(roi_seg.stride_list_[sensor]) == len(roi[sensor]) * roi_seg.segmentation_algorithm.n if self.s_id_naming == "replace": assert_array_equal(roi_seg.stride_list_[sensor].index, list(range(len(roi_seg.stride_list_[sensor])))) @@ -232,7 +232,7 @@ def test_multi_sensor(self): if r[1]["roi_id"] == stride[1]["roi_id"]: assert stride[1]["start"] >= r[1]["start"] - def test_multi_sensor_sync(self): + def test_multi_sensor_sync(self) -> None: roi_seg = RoiStrideSegmentation(MockStrideSegmentation(), self.s_id_naming)
data = pd.concat({"s1": pd.DataFrame(np.ones(27)), "s2": pd.DataFrame(np.zeros(27))}, axis=1) roi = pd.DataFrame(np.array([[0, 1, 3], [0, 9, 18], [8, 17, 26]]).T, columns=["roi_id", "start", "end"]) @@ -240,7 +240,7 @@ def test_multi_sensor_sync(self): roi_seg.segment(data, sampling_rate_hz=100, regions_of_interest=roi) assert is_multi_sensor_stride_list(roi_seg.stride_list_) assert len(roi_seg.instances_per_roi_) == len(roi) - assert all([isinstance(o, MockStrideSegmentation) for o in roi_seg.instances_per_roi_.values()]) + assert all(isinstance(o, MockStrideSegmentation) for o in roi_seg.instances_per_roi_.values()) for sensor in ["s1", "s2"]: assert len(roi_seg.stride_list_[sensor]) == len(roi) * roi_seg.segmentation_algorithm.n @@ -259,8 +259,8 @@ def test_multi_sensor_sync(self): assert stride[1]["start"] >= r[1]["start"] -@pytest.mark.parametrize("action_method", (None, "segment", "secondary_segment")) -def test_alternative_action_method(action_method): +@pytest.mark.parametrize("action_method", [None, "segment", "secondary_segment"]) +def test_alternative_action_method(action_method) -> None: roi_seg = RoiStrideSegmentation(MockStrideSegmentation(), action_method=action_method) data = pd.concat({"s1": pd.DataFrame(np.ones(27)), "s2": pd.DataFrame(np.zeros(27))}, axis=1) roi = pd.DataFrame(np.array([[0, 1, 3], [0, 9, 18], [8, 17, 26]]).T, columns=["roi_id", "start", "end"]) diff --git a/tests/test_stride_segmentation/test_roth_hmm.py b/tests/test_stride_segmentation/test_roth_hmm.py index 04eaab88..13293d69 100644 --- a/tests/test_stride_segmentation/test_roth_hmm.py +++ b/tests/test_stride_segmentation/test_roth_hmm.py @@ -85,13 +85,13 @@ class TestMetaFunctionalitySimpleHMM(TestAlgorithmMixin): def valid_instance(self, after_action_instance): return SimpleHmm(n_states=5, n_gmm_components=3) - def test_empty_init(self): + def test_empty_init(self) -> None: pytest.skip() class TestRothHmmFeatureTransform: @pytest.mark.parametrize("target_sampling_rate", [50, 25]) - def test_inverse_transform_state_sequence(self, target_sampling_rate): + def test_inverse_transform_state_sequence(self, target_sampling_rate) -> None: transform = RothHmmFeatureTransformer(sampling_rate_feature_space_hz=target_sampling_rate) in_state_sequence = np.array([0, 1, 2, 3, 4, 5]) state_sequence = transform.inverse_transform_state_sequence( @@ -104,7 +104,7 @@ def test_inverse_transform_state_sequence(self, target_sampling_rate): @pytest.mark.parametrize("features", [["raw"], ["raw", "gradient"], ["raw", "gradient", "mean"]]) @pytest.mark.parametrize("axes", [["gyr_ml"], ["acc_pa"], ["gyr_ml", "acc_pa"]]) - def test_select_features(self, features, healthy_example_imu_data, axes): + def test_select_features(self, features, healthy_example_imu_data, axes) -> None: transform = RothHmmFeatureTransformer( features=features, axes=axes, @@ -120,7 +120,7 @@ def test_select_features(self, features, healthy_example_imu_data, axes): f"{feature}{feature_prefixes[feature]}__{axis}" for feature in features for axis in axes } - def test_actual_output(self, healthy_example_imu_data): + def test_actual_output(self, healthy_example_imu_data) -> None: # We disable downsampling, standardization, and filtering for this test transform = RothHmmFeatureTransformer( sampling_rate_feature_space_hz=100, @@ -150,20 +150,20 @@ def test_actual_output(self, healthy_example_imu_data): assert transform.data is data - def test_type_error_filter(self): + def test_type_error_filter(self) -> None: with pytest.raises(TypeError) as e: 
RothHmmFeatureTransformer(low_pass_filter="test").transform([], sampling_rate_hz=100) assert "low_pass_filter" in str(e.value) @pytest.mark.parametrize(("roi", "data"), [(None, []), ([], None)]) - def test_value_error_missing_sampling_rate(self, roi, data): + def test_value_error_missing_sampling_rate(self, roi, data) -> None: with pytest.raises(ValueError) as e: RothHmmFeatureTransformer().transform(data, roi_list=roi, sampling_rate_hz=None) assert "sampling_rate_hz" in str(e.value) - def test_resample_roi(self): + def test_resample_roi(self) -> None: transform = RothHmmFeatureTransformer(sampling_rate_feature_space_hz=50) roi = pd.DataFrame(np.array([[0, 100], [200, 300], [400, 500]]), columns=["start", "end"]) resampled_roi = transform.transform(roi_list=roi, sampling_rate_hz=100).transformed_roi_list_ @@ -173,7 +173,7 @@ def test_resample_roi(self): class TestSimpleModel: - def test_error_on_different_number_data_and_labels(self): + def test_error_on_different_number_data_and_labels(self) -> None: with pytest.raises(ValueError) as e: SimpleHmm(n_states=5, n_gmm_components=3).self_optimize( [np.random.rand(100, 3)], [np.random.rand(100), np.random.rand(100)] @@ -181,7 +181,7 @@ def test_error_on_different_number_data_and_labels(self): assert "The given training sequence and initial training labels" in str(e.value) - def test_error_if_datasequence_shorter_nstates(self): + def test_error_if_datasequence_shorter_nstates(self) -> None: with pytest.raises(ValueError) as e: SimpleHmm(n_states=5, n_gmm_components=3).self_optimize( [np.random.rand(100, 3), np.random.rand(3, 3)], [np.random.rand(100), np.random.rand(3)] @@ -189,7 +189,7 @@ def test_error_if_datasequence_shorter_nstates(self): assert "Invalid training sequence!" in str(e.value) - def test_error_on_different_length_data_and_labels(self): + def test_error_on_different_length_data_and_labels(self) -> None: with pytest.raises(ValueError) as e: SimpleHmm(n_states=5, n_gmm_components=3).self_optimize( [pd.DataFrame(np.random.rand(100, 3))], [pd.Series(np.random.rand(99))] @@ -197,7 +197,7 @@ def test_error_on_different_length_data_and_labels(self): assert "a different number of samples" in str(e.value) - def test_invalid_label_sequence(self): + def test_invalid_label_sequence(self) -> None: n_states = 5 with pytest.raises(ValueError) as e: SimpleHmm(n_states=n_states, n_gmm_components=3).self_optimize( @@ -217,7 +217,7 @@ def test_invalid_label_sequence(self): # We test one value with n_states > 10, as this should trigger a sorting bug in pomegranate that we are handling # explicitly @pytest.mark.parametrize("n_states", [5, 12]) - def test_optimize_with_single_sequence(self, data, n_gmm_components, n_states): + def test_optimize_with_single_sequence(self, data, n_gmm_components, n_states) -> None: model = SimpleHmm(n_states=n_states, n_gmm_components=n_gmm_components, max_iterations=1) model.self_optimize([data], [pd.Series(np.tile(np.arange(n_states), int(np.ceil(100 / n_states)))[:100])]) @@ -235,7 +235,7 @@ def test_optimize_with_single_sequence(self, data, n_gmm_components, n_states): dists = state.distribution.distributions assert {d.name for d in dists} == {"MultivariateGaussianDistribution"} - def test_model_exists_warning(self): + def test_model_exists_warning(self) -> None: model = SimpleHmm(n_states=5, n_gmm_components=3) model.self_optimize([pd.DataFrame(np.random.rand(100, 3))], [pd.Series(np.random.choice(5, 100))]) with pytest.warns(UserWarning) as e: @@ -243,7 +243,7 @@ def test_model_exists_warning(self): assert "Model 
already exists" in str(e[0].message) - def test_predict_rasies_error_without_optimize(self): + def test_predict_rasies_error_without_optimize(self) -> None: with pytest.raises(ValueError) as e: SimpleHmm(n_states=5, n_gmm_components=3).predict_hidden_state_sequence( pd.DataFrame(np.random.rand(100, 3)) @@ -251,7 +251,7 @@ def test_predict_rasies_error_without_optimize(self): assert "You need to train the HMM before calling `predict_hidden_state_sequence`" in str(e.value) - def test_predict_raises_error_on_invalid_columns(self): + def test_predict_raises_error_on_invalid_columns(self) -> None: model = SimpleHmm(n_states=5, n_gmm_components=3) col_names = ["feature1", "feature2", "feature3"] invalid_col_names = ["feature1", "feature2", "feature4"] @@ -265,7 +265,7 @@ def test_predict_raises_error_on_invalid_columns(self): assert str(tuple(col_names)) in str(e.value) @pytest.mark.parametrize("algorithm", ["viterbi", "map"]) - def test_predict(self, algorithm): + def test_predict(self, algorithm) -> None: model = SimpleHmm(n_states=5, n_gmm_components=3) model.self_optimize([pd.DataFrame(np.random.rand(100, 3))], [pd.Series(np.random.choice(5, 100))]) pred = model.predict_hidden_state_sequence(pd.DataFrame(np.random.rand(100, 3)), algorithm=algorithm) @@ -273,7 +273,7 @@ def test_predict(self, algorithm): assert set(pred) == set(range(5)) @pytest.mark.parametrize("architecture", ["left-right-strict", "left-right-loose", "fully-connected"]) - def test_different_architectures(self, architecture): + def test_different_architectures(self, architecture) -> None: # We test initialization directly, otherwise training will modify the transition matrizes model = initialize_hmm( [np.random.rand(100, 3)], @@ -307,7 +307,7 @@ def test_different_architectures(self, architecture): expected[:5, 6] = 1 / 2 assert_almost_equal(transition_matrix, expected) - def test_self_optimize_calls_self_optimize_with_info(self): + def test_self_optimize_calls_self_optimize_with_info(self) -> None: data, labels = [pd.DataFrame(np.random.rand(100, 3))], [pd.Series(np.random.choice(5, 100))] with patch.object(SimpleHmm, "self_optimize_with_info") as mock: @@ -317,14 +317,14 @@ def test_self_optimize_calls_self_optimize_with_info(self): mock.assert_called_once_with(data, labels) - def test_self_optimize_with_info_returns_history(self): + def test_self_optimize_with_info_returns_history(self) -> None: data, labels = [pd.DataFrame(np.random.rand(100, 3))], [pd.Series(np.random.choice(5, 100))] instance = SimpleHmm(n_states=5, n_gmm_components=3) trained_instance, history = instance.self_optimize_with_info(data, labels) assert instance is trained_instance assert isinstance(history, History) - def test_invalid_architecture_raises_error(self): + def test_invalid_architecture_raises_error(self) -> None: with pytest.raises(ValueError) as e: SimpleHmm(n_states=5, n_gmm_components=3, architecture="invalid").self_optimize( [pd.DataFrame(np.random.rand(100, 3))], [pd.Series(np.random.choice(5, 100))] @@ -334,13 +334,13 @@ def test_invalid_architecture_raises_error(self): class TestRothSegmentationHmm: - def test_predict_without_model_raises_error(self): + def test_predict_without_model_raises_error(self) -> None: with pytest.raises(ValueError) as e: RothSegmentationHmm().predict(pd.DataFrame(np.random.rand(100, 3)), sampling_rate_hz=100) assert "No trained model for prediction available!" 
in str(e.value) - def test_self_optimize_calls_self_optimize_with_info(self): + def test_self_optimize_calls_self_optimize_with_info(self) -> None: data, labels = [pd.DataFrame(np.random.rand(100, 3))], [pd.DataFrame({"start": [0], "end": [100]})] with patch.object(RothSegmentationHmm, "self_optimize_with_info") as mock: @@ -350,7 +350,7 @@ def test_self_optimize_calls_self_optimize_with_info(self): mock.assert_called_once_with(data, labels, sampling_rate_hz=100) - def test_self_optimize_with_info_returns_history(self): + def test_self_optimize_with_info_returns_history(self) -> None: data, labels = ( [pd.DataFrame(np.random.rand(120, 6), columns=BF_COLS)], [pd.DataFrame({"start": [0, 40, 70], "end": [30, 70, 100]})], ) @@ -366,7 +366,7 @@ def test_self_optimize_with_info_returns_history(self): assert isinstance(v, History) assert set(history.keys()) == {"stride_model", "transition_model", "self"} - def test_short_strides_raise_warning(self): + def test_short_strides_raise_warning(self) -> None: data, labels = ( [pd.DataFrame(np.random.rand(130, 6), columns=BF_COLS)], [pd.DataFrame({"start": [0, 40, 70, 110], "end": [30, 70, 100, 114]})], ) @@ -381,7 +381,7 @@ def test_short_strides_raise_warning(self): assert "1 strides (out of 4)" in str(w[0].message) - def test_short_transitions_raise_warning(self): + def test_short_transitions_raise_warning(self) -> None: data, labels = ( [pd.DataFrame(np.random.rand(250, 6), columns=BF_COLS)], [pd.DataFrame({"start": [0, 70, 102, 125, 170], "end": [30, 100, 125, 170, 200]})], ) @@ -398,7 +398,7 @@ def test_short_transitions_raise_warning(self): # The first warning is the warning about negative improvements during training assert "1 transitions (out of 3)" in str(w[1].message) - def test_strange_inputs_trigger_nan_error(self): + def test_strange_inputs_trigger_nan_error(self) -> None: # XXXX: We skip the test at the moment because it is not deterministic... pytest.skip() @@ -422,7 +422,7 @@ def test_strange_inputs_trigger_nan_error(self): assert "During training the improvement per epoch became NaN/infinite or negative!" in str(w[0].message) assert "the provided pomegranate model has non-finite/NaN parameters." 
in str(e.value) - def test_training_updates_all_models(self): + def test_training_updates_all_models(self) -> None: """Training should modify the stride model, the transition model, and the model itself.""" data, labels = ( [pd.DataFrame(np.random.rand(250, 6), columns=BF_COLS)], @@ -447,7 +447,7 @@ class TestHmmStrideSegmentation: - def test_segment_with_single_dataset(self, healthy_example_imu_data): + def test_segment_with_single_dataset(self, healthy_example_imu_data) -> None: data = convert_left_foot_to_fbf(healthy_example_imu_data["left_sensor"]) model = PreTrainedRothSegmentationModel() instance = HmmStrideSegmentation(model=model) @@ -462,7 +462,7 @@ def test_segment_with_single_dataset(self, healthy_example_imu_data): assert isinstance(result.hidden_state_sequence_, np.ndarray) assert result.hidden_state_sequence_ is result.result_model_.hidden_state_sequence_ - def test_segment_with_multi_dataset(self, healthy_example_imu_data): + def test_segment_with_multi_dataset(self, healthy_example_imu_data) -> None: data = convert_to_fbf(healthy_example_imu_data, left_like="left_", right_like="right_") model = PreTrainedRothSegmentationModel() instance = HmmStrideSegmentation(model=model) @@ -482,14 +482,14 @@ def test_segment_with_multi_dataset(self, healthy_example_imu_data): assert isinstance(result.hidden_state_sequence_[sensor], np.ndarray) assert result.hidden_state_sequence_[sensor] is result.result_model_[sensor].hidden_state_sequence_ - def test_matches_start_end_and_stride_list_identical(self, healthy_example_imu_data): + def test_matches_start_end_and_stride_list_identical(self, healthy_example_imu_data) -> None: data = convert_left_foot_to_fbf(healthy_example_imu_data["left_sensor"])[:3000] instance = HmmStrideSegmentation() result: HmmStrideSegmentation = instance.segment(data, 204.8) assert np.array_equal(result.matches_start_end_, result.stride_list_.to_numpy()) - def test_matches_start_end_original_identical_without_post(self, healthy_example_imu_data): + def test_matches_start_end_original_identical_without_post(self, healthy_example_imu_data) -> None: data = convert_left_foot_to_fbf(healthy_example_imu_data["left_sensor"])[:3000] # With post processing (default), they should be different @@ -551,7 +551,7 @@ def test_matches_start_end_original_identical_without_post(self, healthy_example ), ], ) - def test_hidden_state_sequence_start_end(self, starts, ends, correct): + def test_hidden_state_sequence_start_end(self, starts, ends, correct) -> None: """Test that the start end values are correctly extracted.""" hidden_state_sequence = np.zeros(50) hidden_state_sequence[starts] = 1 @@ -562,5 +562,5 @@ def test_hidden_state_sequence_start_end(self, starts, ends, correct): assert_array_equal(starts_ends, correct) -def test_pre_trained_model_returns_correctly(): +def test_pre_trained_model_returns_correctly() -> None: assert isinstance(PreTrainedRothSegmentationModel(), RothSegmentationHmm) diff --git a/tests/test_trajectory_reconstruction/test_orientation_methods/test_madgwick.py b/tests/test_trajectory_reconstruction/test_orientation_methods/test_madgwick.py index ea0e25b9..c2d4d3fb 100644 --- a/tests/test_trajectory_reconstruction/test_orientation_methods/test_madgwick.py +++ b/tests/test_trajectory_reconstruction/test_orientation_methods/test_madgwick.py @@ -35,7 +35,7 @@ class TestSimpleRotations(TestOrientationMethodMixin): def init_algo_class(self) -> BaseOrientationMethod: return MadgwickAHRS() - def test_correction_works(self): + def 
test_correction_works(self) -> None: """Madgwick should be able to resist small rotations if acc does not change.""" ori = np.array([0, 0, 0, 1.0]) initial_ori = ori diff --git a/tests/test_trajectory_reconstruction/test_orientation_methods/test_ori_method_mixin.py b/tests/test_trajectory_reconstruction/test_orientation_methods/test_ori_method_mixin.py index 245bcf8e..33121243 100644 --- a/tests/test_trajectory_reconstruction/test_orientation_methods/test_ori_method_mixin.py +++ b/tests/test_trajectory_reconstruction/test_orientation_methods/test_ori_method_mixin.py @@ -16,9 +16,9 @@ def init_algo_class(self) -> BaseOrientationMethod: @pytest.mark.parametrize( ("axis_to_rotate", "vector_to_rotate", "expected_result"), - (([1, 0, 0], [0, 0, 1], [0, 0, -1]), ([0, 1, 0], [0, 0, 1], [0, 0, -1]), ([0, 0, 1], [1, 0, 0], [-1, 0, 0])), + [([1, 0, 0], [0, 0, 1], [0, 0, -1]), ([0, 1, 0], [0, 0, 1], [0, 0, -1]), ([0, 0, 1], [1, 0, 0], [-1, 0, 0])], ) - def test_180(self, axis_to_rotate: int, vector_to_rotate: list, expected_result: list): + def test_180(self, axis_to_rotate: int, vector_to_rotate: list, expected_result: list) -> None: """Rotate by 180 degrees around one axis and check resulting rotation by transforming a 3D vector with start and final rotation. @@ -46,7 +46,7 @@ def test_180(self, axis_to_rotate: int, vector_to_rotate: list, expected_result: np.testing.assert_array_almost_equal(Rotation(rot_final).apply(vector_to_rotate), expected_result, decimal=1) assert len(test.orientation_) == fs + 1 - def test_idiot_update(self): + def test_idiot_update(self) -> None: test = self.init_algo_class() fs = 10 sensor_data = np.repeat(np.array([0, 0, 0, 0, 0, 0])[None, :], fs, axis=0) * np.rad2deg(np.pi) @@ -54,7 +54,7 @@ def test_idiot_update(self): test.estimate(sensor_data, sampling_rate_hz=fs) np.testing.assert_array_equal(test.orientation_.iloc[-1], test.initial_orientation) - def test_output_formats(self): + def test_output_formats(self) -> None: test = self.init_algo_class() fs = 10 sensor_data = np.repeat(np.array([0, 0, 0, 0, 0, 0])[None, :], fs, axis=0) * np.rad2deg(np.pi) @@ -66,7 +66,7 @@ def test_output_formats(self): assert is_single_sensor_orientation_list(test.orientation_, orientation_list_type=None) assert len(test.orientation_) == len(sensor_data) + 1 - def test_single_stride_regression(self, healthy_example_imu_data, healthy_example_stride_events, snapshot): + def test_single_stride_regression(self, healthy_example_imu_data, healthy_example_stride_events, snapshot) -> None: """Simple regression test with default parameters.""" test = self.init_algo_class() fs = 204.8 diff --git a/tests/test_trajectory_reconstruction/test_postition_methods/test_piece_wise_linear_dedrifted_integration.py b/tests/test_trajectory_reconstruction/test_postition_methods/test_piece_wise_linear_dedrifted_integration.py index 4eb64885..46eb1dc6 100644 --- a/tests/test_trajectory_reconstruction/test_postition_methods/test_piece_wise_linear_dedrifted_integration.py +++ b/tests/test_trajectory_reconstruction/test_postition_methods/test_piece_wise_linear_dedrifted_integration.py @@ -35,8 +35,8 @@ def init_algo_class(self) -> BasePositionMethod: # For basic integration tests, we do not remove gravity return PieceWiseLinearDedriftedIntegration(gravity=None).set_params(zupt_detector__window_length_s=0.1) - @pytest.mark.parametrize("acc", ([0, 0, 1], [1, 2, 3])) - def test_symetric_velocity_integrations(self, acc): + @pytest.mark.parametrize("acc", [[0, 0, 1], [1, 2, 3]]) + def 
test_symetric_velocity_integrations(self, acc) -> None: """All test data starts and ends at zero.""" # we had to overwrite this test as the PieceWiseLinearDedriftedIntegration function requires some valid # zupt updates within the test data @@ -55,8 +55,8 @@ def test_symetric_velocity_integrations(self, acc): assert_array_equal(test.velocity_.to_numpy()[0], expected) assert_array_equal(test.velocity_.to_numpy()[-1], expected) - @pytest.mark.parametrize("acc", ([0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 2, 0], [1, 2, 0], [1, 2, 3])) - def test_all_axis(self, acc): + @pytest.mark.parametrize("acc", [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 2, 0], [1, 2, 0], [1, 2, 3]]) + def test_all_axis(self, acc) -> None: """Test against the physics equation.""" # we had to overwrite this test as the PieceWiseLinearDedriftedIntegration function requires some valid # zupt updates within the test data @@ -103,7 +103,7 @@ class TestPieceWiseLinearDedriftedIntegration: """Test the position estimation class `PieceWiseLinearDedriftedIntegration`.""" - def test_drift_model_simple(self): + def test_drift_model_simple(self) -> None: """Run a simple example and estimate its drift model.""" data = np.array( [ @@ -291,7 +291,7 @@ def test_drift_model_simple(self): ) assert_almost_equal(estimated_drift_model, expected_output) - def test_drift_model_multidimensional(self): + def test_drift_model_multidimensional(self) -> None: data = np.column_stack([np.linspace(1, 10, 10), np.linspace(10, 20, 10), np.linspace(20, 10, 10)]) zupt = np.array([[5, 10]]) estimated_drift_model = PieceWiseLinearDedriftedIntegration()._estimate_piece_wise_linear_drift_model( @@ -300,7 +300,7 @@ def test_drift_model_multidimensional(self): assert_almost_equal(estimated_drift_model, data) - def test_drift_model(self): + def test_drift_model(self) -> None: """Test the drift model on a simple slope with different ZUPT edge conditions.""" data = np.arange(20) zupt = np.repeat(False, 20) @@ -317,7 +317,7 @@ def test_drift_model(self): ) assert_almost_equal(data, estimated_drift_model) - def test_all_zupt_data(self): + def test_all_zupt_data(self) -> None: """Test the drift model when all samples are ZUPT.""" data = np.arange(20) zupt = np.repeat(True, 20) @@ -327,7 +327,7 @@ def test_all_zupt_data(self): ) assert_almost_equal(data, estimated_drift_model) - def test_no_zupt_data(self): + def test_no_zupt_data(self) -> None: """Test the drift model when no ZUPTs are available.""" data = np.arange(20) zupt = np.repeat(False, 20) diff --git a/tests/test_trajectory_reconstruction/test_postition_methods/test_pos_method_mixin.py b/tests/test_trajectory_reconstruction/test_postition_methods/test_pos_method_mixin.py index a2f84350..29f262d4 100644 --- a/tests/test_trajectory_reconstruction/test_postition_methods/test_pos_method_mixin.py +++ b/tests/test_trajectory_reconstruction/test_postition_methods/test_pos_method_mixin.py @@ -19,7 +19,7 @@ class TestPositionMethodNoGravityMixin: def init_algo_class(self) -> BasePositionMethod: raise NotImplementedError("Should be implemented by ChildClass") - def test_idiot_update(self): + def test_idiot_update(self) -> None: """Integrate zeros.""" test = self.init_algo_class() idiot_data = pd.DataFrame(np.zeros((10, 6)), columns=SF_COLS) @@ -35,7 +35,7 @@ def test_idiot_update(self): assert_frame_equal(test.velocity_, expected_vel) assert_frame_equal(test.position_, expected_pos) - def test_output_formats(self): + def test_output_formats(self) -> None: test = self.init_algo_class() sensor_data = pd.DataFrame(np.zeros((10, 6)), 
columns=SF_COLS) @@ -45,8 +45,8 @@ def test_output_formats(self): assert len(test.position_) == len(sensor_data) + 1 assert len(test.velocity_) == len(sensor_data) + 1 - @pytest.mark.parametrize("acc", ([0, 0, 1], [1, 2, 3])) - def test_symetric_velocity_integrations(self, acc): + @pytest.mark.parametrize("acc", [[0, 0, 1], [1, 2, 3]]) + def test_symetric_velocity_integrations(self, acc) -> None: """All test data starts and ends at zero.""" test = self.init_algo_class() @@ -60,8 +60,8 @@ def test_symetric_velocity_integrations(self, acc): assert_array_equal(test.velocity_.to_numpy()[0], expected) assert_array_equal(test.velocity_.to_numpy()[-1], expected) - @pytest.mark.parametrize("acc", ([0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 2, 0], [1, 2, 0], [1, 2, 3])) - def test_all_axis(self, acc): + @pytest.mark.parametrize("acc", [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 2, 0], [1, 2, 0], [1, 2, 3]]) + def test_all_axis(self, acc) -> None: """Test against the physics equation.""" test = self.init_algo_class() @@ -87,7 +87,7 @@ def test_all_axis(self, acc): assert_array_almost_equal(test.velocity_.to_numpy()[n_steps], expected_vel) assert_array_almost_equal(test.position_.to_numpy()[n_steps], expected_pos) - def test_single_stride_regression(self, healthy_example_imu_data, healthy_example_stride_events, snapshot): + def test_single_stride_regression(self, healthy_example_imu_data, healthy_example_stride_events, snapshot) -> None: """Simple regression test with default parameters.""" test = self.init_algo_class() fs = 204.8 diff --git a/tests/test_trajectory_reconstruction/test_region_level_trajectory.py b/tests/test_trajectory_reconstruction/test_region_level_trajectory.py index 012f55ed..ccaa42f8 100644 --- a/tests/test_trajectory_reconstruction/test_region_level_trajectory.py +++ b/tests/test_trajectory_reconstruction/test_region_level_trajectory.py @@ -37,15 +37,15 @@ def after_action_instance(self, healthy_example_imu_data, healthy_example_stride ) return trajectory - def test_all_other_parameters_documented(self, after_action_instance): + def test_all_other_parameters_documented(self, after_action_instance) -> None: # As the class has multiple action methods with different parameters, this test can not pass in its current # state pytest.skip() class TestEstimateIntersect: - @pytest.mark.parametrize(("sl_type", "roi_type"), (("single", "multi"), ("multi", "single"))) - def test_datatypes_mismatch(self, sl_type, roi_type, healthy_example_imu_data): + @pytest.mark.parametrize(("sl_type", "roi_type"), [("single", "multi"), ("multi", "single")]) + def test_datatypes_mismatch(self, sl_type, roi_type, healthy_example_imu_data) -> None: roi_list = pd.DataFrame({"start": [0], "end": [8]}).rename_axis("roi_id") stride_list = pd.DataFrame({"start": [0], "end": [1]}).rename_axis("s_id") if roi_type == "multi": @@ -65,7 +65,7 @@ def test_datatypes_mismatch(self, sl_type, roi_type, healthy_example_imu_data): assert f"The stride list is {sl_type} sensor and the ROI list is {roi_type} sensor." 
in str(e) - def test_simple(self): + def test_simple(self) -> None: acc_xy = pd.Series([0, 1, 0, -1, 0, 0, 1, 0, -1, 0, 0, 1, 0, -1, 0, 0]) acc_z = pd.Series([9.81] * len(acc_xy)) gyr = pd.Series([0.0] * len(acc_xy)) @@ -93,13 +93,13 @@ class TestIntersect: @pytest.mark.parametrize( ("position", "starts", "ends"), - ( + [ ([-1, 0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 3, 6], [3, 6, 9]), ([-1, 0, 0, 0, -1, 1, 1, 1, 2, 2, 2], [0, 4, 7], [3, 7, 10]), ([-1, 0, 0, 0, 1, 1, 1, 2, 2, 2], [0], [9]), - ), + ], ) - def test_intersect_single_region(self, position, starts, ends): + def test_intersect_single_region(self, position, starts, ends) -> None: # Note that the first value of position simulates the initial orientation calculated by the method. # The stride and roi list indices are relative to the data, which has one sample less. test_position = pd.DataFrame( @@ -121,7 +121,7 @@ def test_intersect_single_region(self, position, starts, ends): # Output should pass stride pos list test assert is_single_sensor_position_list(intersected_pos, "stride") - def test_strides_outside_region(self): + def test_strides_outside_region(self) -> None: # Strides outside regions should simply be ignored position = [-1, 0, 0, 0, 1, 1, 1, 2, 2, 2] # 2 outside, 2 inside @@ -140,7 +140,7 @@ def test_strides_outside_region(self): assert len(intersected_pos.groupby("s_id")) == 2 - def test_multiple_roi(self): + def test_multiple_roi(self) -> None: # two gait sequences with some padding in between position = {"gs1": pd.Series([-1, 0, 0, 0, 1, 1, 1]), "gs2": pd.Series([-1, 2, 2, 2, 3, 3, 3])} test_roi_list = pd.DataFrame({"roi_id": ["gs1", "gs2"], "start": [0, 8], "end": [6, 14]}) @@ -165,7 +165,7 @@ def test_multiple_roi(self): np.testing.assert_array_equal(intersected_pos["pos_x"].loc[2].to_numpy(), [-1, 2, 2, 2]) np.testing.assert_array_equal(intersected_pos["pos_x"].loc[3].to_numpy(), [2, 3, 3, 3]) - def test_overlapping_roi(self): + def test_overlapping_roi(self) -> None: # For overlapping rois, the stride information should be taken from the last roi.
position = {"gs1": pd.Series([-1, 0, 0, 0, 1, 1, 1]), "gs2": pd.Series([-1, 2, 2, 2, 3, 3, 3])} # ROIs completely overlap @@ -187,13 +187,13 @@ def test_overlapping_roi(self): np.testing.assert_array_equal(intersected_pos["pos_x"].loc[0].to_numpy(), [-1, 2, 2, 2]) np.testing.assert_array_equal(intersected_pos["pos_x"].loc[1].to_numpy(), [2, 3, 3, 3]) - def test_estimate_must_be_called(self): + def test_estimate_must_be_called(self) -> None: with pytest.raises(ValidationError) as e: RegionLevelTrajectory().intersect({}) assert "`estimate`" in str(e) - def test_estimate_intersect_was_called(self): + def test_estimate_intersect_was_called(self) -> None: # Simulate calling `estimate_intersect` by setting a stride-level pos list: position = [1, 1, 1] position_list = pd.DataFrame( @@ -210,8 +210,8 @@ def test_estimate_intersect_was_called(self): assert "`estimate_intersect`" in str(e) - @pytest.mark.parametrize(("sl_type", "pos_type"), (("single", "multi"), ("multi", "single"))) - def test_datatypes_mismatch(self, sl_type, pos_type): + @pytest.mark.parametrize(("sl_type", "pos_type"), [("single", "multi"), ("multi", "single")]) + def test_datatypes_mismatch(self, sl_type, pos_type) -> None: position = [1, 1, 1] position_list = pd.DataFrame( np.array([[0] * len(position), position, position, position]).T, @@ -234,15 +234,15 @@ def test_datatypes_mismatch(self, sl_type, pos_type): assert f"{pos_type} sensor dataset with a {sl_type} sensor stride list" in str(e) - @pytest.mark.parametrize("value", ((), "invalid", ("invalid1", "orientation"))) - def test_invalid_return_data(self, value): + @pytest.mark.parametrize("value", [(), "invalid", ("invalid1", "orientation")]) + def test_invalid_return_data(self, value) -> None: rlt = RegionLevelTrajectory() rlt.position_ = [1] # Something other than None with pytest.raises(ValueError) as e: rlt.intersect({}, value) assert str(("orientation", "position", "velocity")) in str(e) - def test_data_has_been_modified(self): + def test_data_has_been_modified(self) -> None: rlt = RegionLevelTrajectory() # Simulate non-valid result properties rlt.position_ = "not a valid position list" @@ -252,7 +252,7 @@ def test_data_has_been_modified(self): assert "manipulated the outputs" in str(e) - def test_multiple_sensors(self): + def test_multiple_sensors(self) -> None: position = [-1, 0, 0, 0, 1, 1, 1] starts = [0, 3] ends = [3, 6] @@ -276,8 +276,8 @@ def test_multiple_sensors(self): class TestRegionLevelTrajectory: - @pytest.mark.parametrize("method", ("estimate", "estimate_intersect")) - def test_event_list_forwarded(self, method): + @pytest.mark.parametrize("method", ["estimate", "estimate_intersect"]) + def test_event_list_forwarded(self, method) -> None: with patch.object(MockTrajectory, "estimate") as mock_estimate: mock_estimate.return_value = MockTrajectory() test = RegionLevelTrajectory(ori_method=None, pos_method=None, trajectory_method=MockTrajectory()) @@ -303,8 +303,8 @@ def test_event_list_forwarded(self, method): pd.DataFrame({"start": [0, 3], "end": [3, 5], "min_vel": [0, 3]}, index=pd.Series([4, 5], name="s_id")), ) - @pytest.mark.parametrize("method", ("estimate", "estimate_intersect")) - def test_event_list_forwarded_multi(self, method): + @pytest.mark.parametrize("method", ["estimate", "estimate_intersect"]) + def test_event_list_forwarded_multi(self, method) -> None: with patch.object(MockTrajectory, "estimate") as mock_estimate: mock_estimate.return_value = MockTrajectory() test = RegionLevelTrajectory(ori_method=None, pos_method=None, 
trajectory_method=MockTrajectory()) diff --git a/tests/test_trajectory_reconstruction/test_stride_level_trajectory.py b/tests/test_trajectory_reconstruction/test_stride_level_trajectory.py index b25404fd..4ce5c7e5 100644 --- a/tests/test_trajectory_reconstruction/test_stride_level_trajectory.py +++ b/tests/test_trajectory_reconstruction/test_stride_level_trajectory.py @@ -28,7 +28,7 @@ def after_action_instance(self, healthy_example_imu_data, healthy_example_stride class TestStrideLevelTrajectory: - def test_event_list_forwarded(self): + def test_event_list_forwarded(self) -> None: with patch.object(MockTrajectory, "estimate") as mock_estimate: mock_estimate.return_value = MockTrajectory() test = StrideLevelTrajectory(ori_method=None, pos_method=None, trajectory_method=MockTrajectory()) @@ -55,7 +55,7 @@ def test_event_list_forwarded(self): pd.DataFrame({"start": [0], "end": [7], "min_vel": [0]}, index=pd.Series([3], name="s_id")), ) - def test_event_list_forwarded_multi(self): + def test_event_list_forwarded_multi(self) -> None: with patch.object(MockTrajectory, "estimate") as mock_estimate: mock_estimate.return_value = MockTrajectory() test = StrideLevelTrajectory(ori_method=None, pos_method=None, trajectory_method=MockTrajectory()) diff --git a/tests/test_trajectory_reconstruction/test_trajectory_methods/test_rts_kalman.py b/tests/test_trajectory_reconstruction/test_trajectory_methods/test_rts_kalman.py index 56206136..04efea06 100644 --- a/tests/test_trajectory_reconstruction/test_trajectory_methods/test_rts_kalman.py +++ b/tests/test_trajectory_reconstruction/test_trajectory_methods/test_rts_kalman.py @@ -46,7 +46,7 @@ def init_algo_class(self, **kwargs) -> RtsKalman: kwargs = {**self.default_kwargs, **kwargs} return RtsKalman().set_params(**kwargs) - def test_covariance_output_format(self): + def test_covariance_output_format(self) -> None: test = self.init_algo_class(zupt_detector__window_length_s=1) fs = 15 sensor_data = np.repeat(np.array([0.0, 0.0, 9.81, 0.0, 0.0, 0.0])[None, :], fs, axis=0) @@ -55,7 +55,7 @@ def test_covariance_output_format(self): assert test.covariance_.shape == (len(sensor_data) + 1, 9 * 9) - def test_zupt_output(self): + def test_zupt_output(self) -> None: test = self.init_algo_class( zupt_detector__inactive_signal_threshold=10, zupt_detector__window_length_s=0.3, ) @@ -71,7 +71,7 @@ def test_zupt_output(self): assert_array_almost_equal(expected_zupts, test.zupts_) - def test_corrects_velocity_drift(self): + def test_corrects_velocity_drift(self) -> None: """Check that ZUPTs correct a velocity drift and set velocity to zero.""" test = self.init_algo_class(zupt_detector__window_length_s=0.3, level_walking=False) acc = np.array([5.0, 5.0, 12.81]) @@ -82,7 +82,7 @@ def test_corrects_velocity_drift(self): test.estimate(sensor_data, sampling_rate_hz=10) assert_array_almost_equal(test.velocity_.to_numpy()[-1], [0.0, 0.0, 0.0], decimal=10) - def test_corrects_z_position(self): + def test_corrects_z_position(self) -> None: """Check that level walking resets the position to zero during ZUPTs.""" test = self.init_algo_class(zupt_detector__window_length_s=1) accel_data = np.repeat(np.concatenate(([0.0, 0.0, 100], [0.0, 0.0, 40.0]))[None, :], 5, axis=0) @@ -93,7 +93,7 @@ def test_corrects_z_position(self): assert test.position_.to_numpy()[4][2] < -0.8 assert_array_almost_equal(test.position_.to_numpy()[-1], [0.0, 0.0, 0.0], decimal=10) - def test_stride_list_forwarded_to_zupt(self): + def test_stride_list_forwarded_to_zupt(self) -> None: """Test that the stride list passed to 
reconstruct is forwarded to the detect method of the ZUPT detector.""" class MockZUPTDetector(BaseZuptDetector): diff --git a/tests/test_trajectory_reconstruction/test_trajectory_methods/test_trajectory_method_mixin.py b/tests/test_trajectory_reconstruction/test_trajectory_methods/test_trajectory_method_mixin.py index d4546147..09600da3 100644 --- a/tests/test_trajectory_reconstruction/test_trajectory_methods/test_trajectory_method_mixin.py +++ b/tests/test_trajectory_reconstruction/test_trajectory_methods/test_trajectory_method_mixin.py @@ -21,7 +21,7 @@ class TestTrajectoryMethodMixin: def init_algo_class(self) -> BaseTrajectoryMethod: raise NotImplementedError("Should be implemented by ChildClass") - def test_idiot_update(self): + def test_idiot_update(self) -> None: """Integrate zeros except for gravity.""" test = self.init_algo_class() idiot_data = pd.DataFrame(np.zeros((15, 6)), columns=SF_COLS) @@ -41,7 +41,7 @@ def test_idiot_update(self): assert_frame_equal(test.velocity_, expected_vel) assert_frame_equal(test.orientation_, expected_ori) - def test_output_formats(self): + def test_output_formats(self) -> None: test = self.init_algo_class() fs = 100 sensor_data = np.repeat(np.array([0.0, 0.0, 9.81, 0.0, 0.0, 0.0])[None, :], fs, axis=0) @@ -59,9 +59,9 @@ def test_output_formats(self): @pytest.mark.parametrize( ("axis_to_rotate", "vector_to_rotate", "expected_result"), - (([1, 0, 0], [0, 0, 1], [0, 0, -1]), ([0, 1, 0], [0, 0, 1], [0, 0, -1]), ([0, 0, 1], [1, 0, 0], [-1, 0, 0])), + [([1, 0, 0], [0, 0, 1], [0, 0, -1]), ([0, 1, 0], [0, 0, 1], [0, 0, -1]), ([0, 0, 1], [1, 0, 0], [-1, 0, 0])], ) - def test_180(self, axis_to_rotate: int, vector_to_rotate: list, expected_result: list): + def test_180(self, axis_to_rotate: int, vector_to_rotate: list, expected_result: list) -> None: """Rotate by 180 degrees around one axis and check the resulting rotation by transforming a 3D vector with start and final rotation.
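The hunk above shows the two mechanical rewrites this patch applies across the entire test suite: the outer container of `pytest.mark.parametrize` values becomes a list instead of a tuple, and every test that returns nothing gains an explicit `-> None` annotation. A minimal sketch of the resulting convention, assuming the driver is ruff's pytest-style and annotation rules (PT007 and ANN201 would be the matching rule codes; `test_abs` itself is a hypothetical example, not part of the patch):

import pytest


@pytest.mark.parametrize(
    ("value", "expected"),
    [(1, 1), (-2, 2), (0, 0)],  # outer container: list; individual cases: tuples
)
def test_abs(value: int, expected: int) -> None:  # tests return nothing
    assert abs(value) == expected

The list form keeps parametrize diffs small when cases are appended and matches the container style enforced elsewhere in this patch.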
@@ -89,7 +89,7 @@ def test_180(self, axis_to_rotate: int, vector_to_rotate: list, expected_result: np.testing.assert_array_almost_equal(Rotation(rot_final).apply(vector_to_rotate), expected_result, decimal=1) assert len(test.orientation_) == fs + 1 - def test_symmetric_velocity_integrations(self): + def test_symmetric_velocity_integrations(self) -> None: """Test data starts and ends at zero.""" test = self.init_algo_class() acc = np.array([0.0, 0.0, 10.0]) @@ -106,7 +106,7 @@ def test_symmetric_velocity_integrations(self): assert_array_almost_equal(test.velocity_.to_numpy()[0], expected, decimal=10) assert_array_almost_equal(test.velocity_.to_numpy()[-1], expected, decimal=10) - def test_full_trajectory_regression(self, healthy_example_imu_data, snapshot): + def test_full_trajectory_regression(self, healthy_example_imu_data, snapshot) -> None: """Simple regression test with default parameters.""" test = self.init_algo_class() fs = 204.8 diff --git a/tests/test_trajectory_reconstruction/test_trajectory_wrapper.py b/tests/test_trajectory_reconstruction/test_trajectory_wrapper.py index c5558dfb..f1b740f6 100644 --- a/tests/test_trajectory_reconstruction/test_trajectory_wrapper.py +++ b/tests/test_trajectory_reconstruction/test_trajectory_wrapper.py @@ -28,7 +28,7 @@ class TestIODataStructures: output_list_type: Literal["roi", "stride"] @pytest.fixture(params=(StrideLevelTrajectory, RegionLevelTrajectory), autouse=True) - def select_wrapper(self, healthy_example_stride_events, request): + def select_wrapper(self, healthy_example_stride_events, request) -> None: self.wrapper_class = request.param if self.wrapper_class == RegionLevelTrajectory: self.example_region = { @@ -41,7 +41,7 @@ def select_wrapper(self, healthy_example_stride_events, request): self.output_list_type = "stride" self.key = "s_id" - def test_invalid_input_data(self, healthy_example_imu_data): + def test_invalid_input_data(self, healthy_example_imu_data) -> None: """Test if error is raised correctly on invalid input data type.""" data = healthy_example_imu_data stride_list = self.example_region @@ -57,23 +57,23 @@ def test_invalid_input_data(self, healthy_example_imu_data): with pytest.raises(ValidationError, match=r".*neither a single- or a multi-sensor "): gyr_int.estimate(data, fake_stride_list, sampling_rate_hz=204.8) - @pytest.mark.parametrize("method", ("ori_method", "pos_method")) - def test_invalid_input_method(self, healthy_example_imu_data, method): + @pytest.mark.parametrize("method", ["ori_method", "pos_method"]) + def test_invalid_input_method(self, healthy_example_imu_data, method) -> None: """Test if correct errors are raised for invalid pos and ori methods.""" instance = self.wrapper_class(**{method: "wrong"}) with pytest.raises(ValueError) as e: instance.estimate(healthy_example_imu_data, self.example_region, sampling_rate_hz=204.8) assert method in str(e) - @pytest.mark.parametrize("method", ("ori_method", "pos_method")) - def test_only_pos_or_ori_provided(self, healthy_example_imu_data, method): + @pytest.mark.parametrize("method", ["ori_method", "pos_method"]) + def test_only_pos_or_ori_provided(self, healthy_example_imu_data, method) -> None: instance = self.wrapper_class(**{method: None}) with pytest.raises(ValueError) as e: instance.estimate(healthy_example_imu_data, self.example_region, sampling_rate_hz=204.8) assert "either a `ori` and a `pos` method" in str(e) - @pytest.mark.parametrize("method", ("ori_method", "pos_method")) - def test_trajectory_warning(self, healthy_example_imu_data, method): + 
@pytest.mark.parametrize("method", ["ori_method", "pos_method"]) + def test_trajectory_warning(self, healthy_example_imu_data, method) -> None: instance = self.wrapper_class(**{method: RtsKalman()}) with pytest.warns(UserWarning) as w: instance.estimate( @@ -83,7 +83,7 @@ def test_trajectory_warning(self, healthy_example_imu_data, method): ) assert "trajectory method as ori or pos method" in str(w[0]) - def test_passed_both_warning(self, healthy_example_imu_data): + def test_passed_both_warning(self, healthy_example_imu_data) -> None: """Test that a warning is raised when passing ori and pos and trajectory emthods all at one. This will happen by default, when leaving ori and pos as default values @@ -97,7 +97,7 @@ def test_passed_both_warning(self, healthy_example_imu_data): ) assert "You provided a trajectory method AND an ori or pos method." in str(w[0]) - def test_single_sensor_output(self, healthy_example_imu_data, snapshot): + def test_single_sensor_output(self, healthy_example_imu_data, snapshot) -> None: test_stride_events = self.example_region["left_sensor"].iloc[:3] test_data = healthy_example_imu_data["left_sensor"].iloc[: int(test_stride_events.iloc[-1]["end"])] @@ -121,7 +121,7 @@ def test_single_sensor_output(self, healthy_example_imu_data, snapshot): snapshot.assert_match(instance.orientation_.loc[first_last_stride], "ori") snapshot.assert_match(instance.position_.loc[first_last_stride], "pos") - def test_single_sensor_output_empty_stride_list(self, healthy_example_imu_data): + def test_single_sensor_output_empty_stride_list(self, healthy_example_imu_data) -> None: empty_stride_events = pd.DataFrame(columns=self.example_region["left_sensor"].columns) test_data = healthy_example_imu_data["left_sensor"] @@ -133,7 +133,7 @@ def test_single_sensor_output_empty_stride_list(self, healthy_example_imu_data): assert len(instance.orientation_) == 0 assert len(instance.position_) == 0 - def test_multi_sensor_output(self, healthy_example_imu_data, snapshot): + def test_multi_sensor_output(self, healthy_example_imu_data, snapshot) -> None: test_stride_events = self.example_region test_data = healthy_example_imu_data @@ -172,14 +172,14 @@ class TestInitCalculation: No complicated tests here, as this uses `get_gravity_rotation`, which is well tested """ - def test_calc_initial_dummy(self): + def test_calc_initial_dummy(self) -> None: """No rotation expected as already aligned.""" dummy_data = pd.DataFrame(np.repeat(np.array([0, 0, 1, 0, 0, 0])[None, :], 20, axis=0), columns=SF_COLS) start_ori = _initial_orientation_from_start(dummy_data, 10, 8) assert_array_equal(start_ori.as_quat(), Rotation.identity().as_quat()) @pytest.mark.parametrize("start", [0, 99]) - def test_start_of_stride_equals_start_or_end_of_data(self, start): + def test_start_of_stride_equals_start_or_end_of_data(self, start) -> None: """If start is to close to the start or the end of the data a warning is emitted.""" dummy_data = pd.DataFrame(np.repeat(np.array([0, 0, 1, 0, 0, 0])[None, :], 100, axis=0), columns=SF_COLS) with pytest.warns(UserWarning) as w: @@ -187,14 +187,14 @@ def test_start_of_stride_equals_start_or_end_of_data(self, start): assert "complete window length" in str(w[0]) - def test_only_single_value(self): + def test_only_single_value(self) -> None: dummy_data = pd.DataFrame(np.repeat(np.array([0, 0, 1, 0, 0, 0])[None, :], 20, axis=0), columns=SF_COLS) start_ori = _initial_orientation_from_start(dummy_data, 10, 0) assert_array_equal(start_ori.as_quat(), Rotation.identity().as_quat()) class 
MockTrajectory(BaseTrajectoryMethod): - def __init__(self, initial_orientation=None): + def __init__(self, initial_orientation=None) -> None: self.initial_orientation = initial_orientation super().__init__() diff --git a/tests/test_utils/test_array_handling.py b/tests/test_utils/test_array_handling.py index 66e08eb7..de314772 100644 --- a/tests/test_utils/test_array_handling.py +++ b/tests/test_utils/test_array_handling.py @@ -19,34 +19,34 @@ class TestSlidingWindow: """Test the function `sliding_window_view`.""" - def test_invalid_inputs_overlap(self): + def test_invalid_inputs_overlap(self) -> None: """Test if value error is raised correctly on invalid overlap input.""" with pytest.raises(ValueError, match=r".* overlap .*"): sliding_window_view(np.arange(0, 10), window_length=4, overlap=4) - def test_invalid_inputs_window_size(self): + def test_invalid_inputs_window_size(self) -> None: """Test if value error is raised correctly on invalid window length input.""" with pytest.raises(ValueError, match=r".* window_length .*"): sliding_window_view(np.arange(0, 10), window_length=1, overlap=0) - def test_invalid_inputs_window_size_2(self): + def test_invalid_inputs_window_size_2(self) -> None: """Test if value error is raised correctly for window length > signal length.""" with pytest.raises(ValueError, match=r"negative dimensions are not allowed"): sliding_window_view(np.arange(0, 10), window_length=15, overlap=0) - def test_view_of_array(selfs): + def test_view_of_array(self) -> None: """Test if output is actually just a different view onto the input data.""" input_array = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) window_view = sliding_window_view(input_array, window_length=4, overlap=2) assert np.may_share_memory(input_array, window_view) is True - def test_copy_of_array_with_padding(self): + def test_copy_of_array_with_padding(self) -> None: """Test if the output is a copy of the input data.""" input_array = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) window_view = sliding_window_view(input_array, window_length=4, overlap=2, nan_padding=True) assert np.may_share_memory(input_array, window_view) is False - def test_nan_padding_of_type_nan(self): + def test_nan_padding_of_type_nan(self) -> None: """Test that the padded values are of type NaN.""" input_array = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) window_view = sliding_window_view(input_array, window_length=4, overlap=2, nan_padding=True) @@ -55,7 +55,7 @@ def test_nan_padding_of_type_nan(self): assert math.isnan(window_view[-1][-1]) - def test_sliding_window_1D_without_without_padding(self): + def test_sliding_window_1D_without_without_padding(self) -> None: """Test windowed view is correct for 1D array without need for nan padding.""" input_array = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) expected_output = np.array([[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]) @@ -63,7 +63,7 @@ def test_sliding_window_1D_without_without_padding(self): assert_array_equal(expected_output, window_view) - def test_sliding_window_1D_with_padding(self): + def test_sliding_window_1D_with_padding(self) -> None: """Test windowed view is correct for 1D array with need for nan padding.""" input_array = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) expected_output = np.array([[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9], [8, 9, 10, np.nan]]) @@ -71,7 +71,7 @@ def test_sliding_window_1D_with_padding(self): assert_array_equal(expected_output, window_view) - def test_sliding_window_1D_without_padding(self): + def test_sliding_window_1D_without_padding(self) -> None:
"""Test windowed view is correct for 1D array with need for padding but padding disabled.""" input_array = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) expected_output = np.array([[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]) @@ -79,7 +79,7 @@ def test_sliding_window_1D_without_padding(self): assert_array_equal(expected_output, window_view) - def test_sliding_window_1D_asym_with_padding(self): + def test_sliding_window_1D_asym_with_padding(self) -> None: """Test windowed view is correct for 1D array with need for nan padding and asymetrical overlap.""" input_array = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) expected_output = np.array([[0, 1, 2, 3, 4, 5, 6], [5, 6, 7, 8, 9, np.nan, np.nan]]) @@ -87,7 +87,7 @@ def test_sliding_window_1D_asym_with_padding(self): assert_array_equal(expected_output, window_view) - def test_sliding_window_1D_asym_without_padding(self): + def test_sliding_window_1D_asym_without_padding(self) -> None: """Test windowed view is correct for 1D array with need for nan padding but padding disabled and asymetrical overlap. """ @@ -97,7 +97,7 @@ def test_sliding_window_1D_asym_without_padding(self): assert_array_equal(expected_output, window_view) - def test_sliding_window_1D_no_overlap_without_padding(self): + def test_sliding_window_1D_no_overlap_without_padding(self) -> None: """Test windowed view is correct for 1D array with no overlap and no padding.""" input_array = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]) expected_output = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) @@ -105,7 +105,7 @@ def test_sliding_window_1D_no_overlap_without_padding(self): assert_array_equal(expected_output, window_view) - def test_sliding_window_1D_no_overlap_with_padding(self): + def test_sliding_window_1D_no_overlap_with_padding(self) -> None: """Test windowed view is correct for 1D array with no overlap and need for padding.""" input_array = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) expected_output = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, np.nan, np.nan]]) @@ -113,7 +113,7 @@ def test_sliding_window_1D_no_overlap_with_padding(self): assert_array_equal(expected_output, window_view) - def test_sliding_window_1D_asym_overlap_with_padding(self): + def test_sliding_window_1D_asym_overlap_with_padding(self) -> None: """Test windowed view is correct for 1D array with asym overlap need for padding.""" input_array = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) expected_output = np.array([[0, 1, 2, 3, 4], [2, 3, 4, 5, 6], [4, 5, 6, 7, 8], [6, 7, 8, 9, np.nan]]) @@ -121,7 +121,7 @@ def test_sliding_window_1D_asym_overlap_with_padding(self): assert_array_equal(expected_output, window_view) - def test_sliding_window_3D_without_edge_case(self): + def test_sliding_window_3D_without_edge_case(self) -> None: """Test windowed view is correct for 3D array with sym overlap and no need for padding.""" input_array = np.column_stack([np.arange(0, 10), np.arange(0, 10), np.arange(0, 10)]) expected_output = np.array( @@ -136,7 +136,7 @@ def test_sliding_window_3D_without_edge_case(self): assert_array_equal(expected_output, window_view) - def test_sliding_window_3D_with_padding(self): + def test_sliding_window_3D_with_padding(self) -> None: """Test windowed view is correct for 3D array with sym overlap and need for padding.""" input_array = np.column_stack([np.arange(0, 11), np.arange(0, 11), np.arange(0, 11)]) expected_output = np.array( @@ -152,7 +152,7 @@ def test_sliding_window_3D_with_padding(self): assert_array_equal(expected_output, window_view) - def test_sliding_window_3D_asym_with_padding(self): 
+ def test_sliding_window_3D_asym_with_padding(self) -> None: """Test windowed view is correct for 3D array with asym overlap and need for padding.""" input_array = np.column_stack([np.arange(0, 12), np.arange(0, 12), np.arange(0, 12)]) expected_output = np.array( @@ -174,7 +174,7 @@ def test_sliding_window_3D_asym_with_padding(self): assert_array_equal(expected_output, window_view) - def test_sliding_window_5D_without_padding(self): + def test_sliding_window_5D_without_padding(self) -> None: """Test windowed view is correct for high dimensional array with sym overlap and no need for padding.""" input_array = np.column_stack( [np.arange(0, 10), np.arange(0, 10), np.arange(0, 10), np.arange(0, 10), np.arange(0, 10)] ) @@ -195,53 +195,53 @@ class TestBoolArrayToStartEndArray: """Test the function `bool_array_to_start_end_array`.""" - def test_simple_input(self): + def test_simple_input(self) -> None: input_array = np.array([0, 0, 1, 1, 0, 0, 1, 1, 1]) output_array = bool_array_to_start_end_array(input_array) expected_output = np.array([[2, 4], [6, 9]]) assert_array_equal(expected_output, output_array) - def test_invalid_inputs_overlap(self): + def test_invalid_inputs_overlap(self) -> None: """Test if value error is raised correctly on invalid input array.""" with pytest.raises(ValueError, match=r".* boolean .*"): input_array = np.array([0, 0, 2, 2, 0, 0, 2, 2, 2]) bool_array_to_start_end_array(input_array) - def test_zeros_array(self): + def test_zeros_array(self) -> None: """Test zeros only input.""" input_array = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]) output_array = bool_array_to_start_end_array(input_array) assert output_array.size == 0 - def test_ones_array(self): + def test_ones_array(self) -> None: """Test ones only input.""" input_array = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1]) output_array = bool_array_to_start_end_array(input_array) expected_output = np.array([[0, 9]]) assert_array_equal(expected_output, output_array) - def test_edges_array(self): + def test_edges_array(self) -> None: """Test correct handling of edges.""" input_array = np.array([1, 1, 1, 0, 0, 0, 1, 1, 1]) output_array = bool_array_to_start_end_array(input_array) expected_output = np.array([[0, 3], [6, 9]]) assert_array_equal(expected_output, output_array) - def test_bool_value_array(self): + def test_bool_value_array(self) -> None: """Test correct handling of boolean values.""" input_array = np.array([True, True, True, False, False, False, True, True, True]) output_array = bool_array_to_start_end_array(input_array) expected_output = np.array([[0, 3], [6, 9]]) assert_array_equal(expected_output, output_array) - def test_float_value_array(self): + def test_float_value_array(self) -> None: """Test correct handling of float values.""" input_array = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) output_array = bool_array_to_start_end_array(input_array) expected_output = np.array([[0, 3], [6, 9]]) assert_array_equal(expected_output, output_array) - def test_empty_array(self): + def test_empty_array(self) -> None: """Test empty input.""" input_array = np.array([]) output_array = bool_array_to_start_end_array(input_array) @@ -251,26 +251,26 @@ class TestStartEndArrayToBoolArray: """Test the function `start_end_array_to_bool_array`.""" - def test_simple_input_no_padding(self): + def test_simple_input_no_padding(self) -> None: input_array = np.array([[2, 3], [5, 9]]) output_array = start_end_array_to_bool_array(input_array) expected_output
= np.array([0, 0, 1, 0, 0, 1, 1, 1, 1]).astype(bool) assert_array_equal(expected_output, output_array) - def test_simple_input_1d_no_padding(self): + def test_simple_input_1d_no_padding(self) -> None: input_array = np.array([2, 3]) output_array = start_end_array_to_bool_array(input_array, pad_to_length=5) expected_output = np.array([0, 0, 1, 0, 0]).astype(bool) assert_array_equal(expected_output, output_array) - def test_simple_input_with_padding(self): + def test_simple_input_with_padding(self) -> None: input_array = np.array([[2, 3], [5, 9]]) output_array = start_end_array_to_bool_array(input_array, pad_to_length=12) expected_output = np.array([0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0]).astype(bool) assert_array_equal(expected_output, output_array) @pytest.mark.parametrize("reverse", [True, False]) - def test_length_unsorted(self, reverse): + def test_length_unsorted(self, reverse) -> None: input_array = np.array([[5, 9], [2, 3]]) if reverse: input_array = input_array[::-1] @@ -278,31 +278,31 @@ def test_length_unsorted(self, reverse): expected_output = np.array([0, 0, 1, 0, 0, 1, 1, 1, 1]).astype(bool) assert_array_equal(expected_output, output_array) - def test_overlapping_input_with_padding(self): + def test_overlapping_input_with_padding(self) -> None: input_array = np.array([[2, 6], [5, 9]]) output_array = start_end_array_to_bool_array(input_array, pad_to_length=12) expected_output = np.array([0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]).astype(bool) assert_array_equal(expected_output, output_array) - def test_invalid_padding(self): + def test_invalid_padding(self) -> None: input_array = np.array([[2, 3], [5, 9]]) with pytest.raises(ValueError) as e: start_end_array_to_bool_array(input_array, pad_to_length=-1) assert "pad_to_length must be positive" in str(e) - def test_short_padding(self): + def test_short_padding(self) -> None: input_array = np.array([[2, 3], [5, 9]]) output_array = start_end_array_to_bool_array(input_array, pad_to_length=7) expected_output = np.array([0, 0, 1, 0, 0, 1, 1]).astype(bool) assert_array_equal(expected_output, output_array) - def test_correct_output_dtype(self): + def test_correct_output_dtype(self) -> None: input_array = np.array([[2, 3], [5, 9]]) output_array = start_end_array_to_bool_array(input_array) assert output_array.dtype == "bool" - def test_bool_array_round_trip(self): + def test_bool_array_round_trip(self) -> None: input_array = np.array([[2, 3], [5, 9]]) output_array = start_end_array_to_bool_array(input_array) output_array = bool_array_to_start_end_array(output_array) @@ -323,7 +323,7 @@ class TestLocalMinimaBelowThreshold: ), ], ) - def test_split_array_simple(self, data, result): + def test_split_array_simple(self, data, result) -> None: out = split_array_at_nan(data) assert len(out) == len(result) @@ -338,14 +338,14 @@ def test_split_array_simple(self, data, result): ([*np.ones(10), -1, -2, -1, *np.ones(10)], -3, []), ], ) - def test_find_extrema(self, data, threshold, results): + def test_find_extrema(self, data, threshold, results) -> None: out = find_local_minima_below_threshold(np.array(data), threshold) np.testing.assert_array_equal(out, results) class TestFindMinRadius: - def test_invalid_method_name_error(self): + def test_invalid_method_name_error(self) -> None: data = np.array([0, 0, 0, -1, 0, 0, 0]) # min at 3 radius = 1 indices = np.array([2, 3, 4]) # All should find the minima @@ -353,16 +353,16 @@ def test_invalid_method_name_error(self): with pytest.raises(ValueError): find_extrema_in_radius(data, indices, radius, 
extrema_type="invalid_type") - @pytest.mark.parametrize(("method", "ex_value"), (("min", -1), ("max", 1))) - def test_simple(self, method, ex_value): + @pytest.mark.parametrize(("method", "ex_value"), [("min", -1), ("max", 1)]) + def test_simple(self, method, ex_value) -> None: data = np.array([0, 0, 0, ex_value, 0, 0, 0]) # min at 3 radius = 1 indices = np.array([2, 3, 4]) # All should find the minima out = find_extrema_in_radius(data, indices, radius, extrema_type=method) assert_array_equal(out, [3, 3, 3]) - @pytest.mark.parametrize(("method", "ex_value"), (("min", -1), ("max", 1))) - def test_multiple_matches(self, method, ex_value): + @pytest.mark.parametrize(("method", "ex_value"), [("min", -1), ("max", 1)]) + def test_multiple_matches(self, method, ex_value) -> None: data = np.array([0, 0, 0, ex_value, 0, 0, 0, 0, 0, 0, ex_value, 0, 0, 0]) # min at 3, 10 radius = 2 indices = np.arange(2, len(data) - 2) @@ -370,21 +370,21 @@ def test_multiple_matches(self, method, ex_value): # 2 - 5 should see the first minimum, 6, 7 see no minimum, 8-11 see second minimum assert_array_equal(out, [3, 3, 3, 3, 4, 5, 10, 10, 10, 10]) - def test_edge_case_end(self): + def test_edge_case_end(self) -> None: data = np.array([0, 0, 0, 0, 0, -1, 0]) # min at 5 radius = 2 indices = np.array([4, 5]) # 5 overlap with end out = find_extrema_in_radius(data, indices, radius) assert_array_equal(out, [5, 5]) - def test_edge_case_start(self): + def test_edge_case_start(self) -> None: data = np.array([0, -1, 0, 0, 0, 0, 0]) # min at 1 radius = 2 indices = np.array([1, 2]) # 1 overlap with start out = find_extrema_in_radius(data, indices, radius) assert_array_equal(out, [1, 1]) - def test_full_dummy(self): + def test_full_dummy(self) -> None: """As there is no minimum, every index should return the start of the window.""" data = np.zeros(10) indices = np.arange(10) @@ -393,7 +393,7 @@ def test_full_dummy(self): assert_array_equal(out[radius:], indices[radius:] - radius) assert_array_equal(out[:radius], np.zeros(radius)) - def test_radius_zero(self): + def test_radius_zero(self) -> None: """Test that if the radius is 0 the value itself gets returned.""" data = np.array([0, 0, 0, 1, 0, 0, 0]) # min at 3 radius = 0 @@ -401,8 +401,8 @@ def test_radius_zero(self): out = find_extrema_in_radius(data, indices, radius, extrema_type="max") assert_array_equal(out, indices) - @pytest.mark.parametrize("radius", (1, 2, 3)) - def test_tuple_input_identical_to_single_input(self, radius): + @pytest.mark.parametrize("radius", [1, 2, 3]) + def test_tuple_input_identical_to_single_input(self, radius) -> None: data = np.array([0, 0, 0, 1, 0, 0, 0]) indices = np.array([2, 3, 4]) @@ -411,7 +411,7 @@ def test_tuple_input_identical_to_single_input(self, radius): assert_array_equal(out_single, out_tuple) - def test_non_equal_left_right_radius(self): + def test_non_equal_left_right_radius(self) -> None: data = np.array([0, 2, 0, 0, 0, 1, 0, 0]) indices = np.array([2, 3, 4]) radius = (1, 2) @@ -419,7 +419,7 @@ def test_non_equal_left_right_radius(self): out = find_extrema_in_radius(data, indices, radius, extrema_type="max") assert_array_equal(out, [1, 5, 5]) - def test_non_equal_radius_edgecase(self): + def test_non_equal_radius_edgecase(self) -> None: data = np.array([0, 2, 0, 0, 0, 1, 0, 0]) indices = np.array([2, 3, 4]) radius = (3, 5) @@ -435,17 +435,17 @@ class TestMultiArrayInterpolate: """ @pytest.fixture(autouse=True, params=["linear", "nearest"]) - def select_kind(self, request): + def select_kind(self, request) -> None: self.kind = 
request.param - def test_upsample(self): + def test_upsample(self) -> None: """Test if data is upsampled to the correct number of samples.""" data = np.array([[0, 1, 2, 3], [0, 1, 2, 3]]).T data_interpolated = multi_array_interpolation([data, data, data], n_samples=10, kind=self.kind) assert data_interpolated.shape == (3, 2, 10) - def test_downsample(self): + def test_downsample(self) -> None: """Test if data is downsampled to the correct number of samples.""" data = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]).T data_interpolated = multi_array_interpolation([data, data, data], n_samples=4, kind=self.kind) @@ -479,12 +479,12 @@ class TestMergeIntervals: ), ], ) - def test_merge_intervals(self, input_array, output_array, gap_size): + def test_merge_intervals(self, input_array, output_array, gap_size) -> None: assert_array_equal(output_array, merge_intervals(input_array, gap_size)) class TestIterateRegionData: - def test_simple_case(self): + def test_simple_case(self) -> None: data = pd.DataFrame(np.ones((40, 3))) rois = pd.DataFrame( {"start": [0, 10, 20], "end": [10, 20, 30], "s_id": [0, 1, 2]}, @@ -495,7 +495,7 @@ def test_simple_case(self): assert region.shape == (10, 3) assert_array_equal(region, data.iloc[rois.start[i] : rois.end[i]]) - def test_multi_data(self): + def test_multi_data(self) -> None: data1 = pd.DataFrame(np.ones((40, 3))) data2 = pd.DataFrame(np.ones((40, 3))) * 2 @@ -511,7 +511,7 @@ def test_multi_data(self): assert region.shape == (10, 3) assert_array_equal(region, data2.iloc[rois.start[i] : rois.end[i]]) - def test_col_order(self): + def test_col_order(self) -> None: data = pd.DataFrame(np.ones((40, 3)), columns=["a", "b", "c"]) rois = pd.DataFrame( {"start": [0, 10, 20], "end": [10, 20, 30], "s_id": [0, 1, 2]}, @@ -519,7 +519,7 @@ def test_col_order(self): region_data = list(iterate_region_data([data], [rois], expected_col_order=["b", "a", "c"])) assert region_data[0].columns.tolist() == ["b", "a", "c"] - def test_data_different_col_order(self): + def test_data_different_col_order(self) -> None: data1 = pd.DataFrame(np.ones((40, 3)), columns=["a", "b", "c"]) data2 = pd.DataFrame(np.ones((40, 3)), columns=["a", "c", "b"]) diff --git a/tests/test_utils/test_coordinate_conversion.py b/tests/test_utils/test_coordinate_conversion.py index c3b852dc..8ccffbb7 100644 --- a/tests/test_utils/test_coordinate_conversion.py +++ b/tests/test_utils/test_coordinate_conversion.py @@ -10,7 +10,7 @@ class TestConvertAxes: """Test the functions for converting either left or right foot.""" @pytest.fixture(autouse=True) - def _sample_sensor_data(self): + def _sample_sensor_data(self) -> None: """Create some sample dummy data frames.""" self.data_left = pd.DataFrame([[1, 2, 3, 4, 5, 6]], columns=SF_COLS) self.data_right = pd.DataFrame([[7, 8, 9, 10, 11, 12]], columns=SF_COLS) @@ -24,59 +24,59 @@ def _sample_sensor_data(self): self.data_dict_expected = {"left_sensor": self.data_left_expected, "right_sensor": self.data_right_expected} self.data_df_expected = pd.concat(self.data_dict_expected, axis=1) - def test_convert_left_foot(self): + def test_convert_left_foot(self) -> None: data_converted = convert_left_foot_to_fbf(self.data_left) assert_frame_equal(data_converted, self.data_left_expected) - def test_convert_right_foot(self): + def test_convert_right_foot(self) -> None: data_converted = convert_right_foot_to_fbf(self.data_right) assert_frame_equal(data_converted, self.data_right_expected) - def test_no_position_arguments(self): + def 
test_no_position_arguments(self) -> None: with pytest.raises(ValueError): convert_to_fbf(self.data_df) - def test_wrong_key_arguments(self): + def test_wrong_key_arguments(self) -> None: with pytest.raises(KeyError): convert_to_fbf(self.data_df, left=["abc"]) with pytest.raises(KeyError): convert_to_fbf(self.data_df, right=["abc"]) - def test_rotate_multisensor(self): + def test_rotate_multisensor(self) -> None: data_converted = convert_to_fbf(self.data_df, left=["left_sensor"], right=["right_sensor"]) assert_frame_equal(data_converted, self.data_df_expected) - def test_rotate_multisensor_left(self): + def test_rotate_multisensor_left(self) -> None: data_converted = convert_to_fbf(self.data_df, left=["left_sensor"]) assert_frame_equal(data_converted["left_sensor"], self.data_left_expected) assert_frame_equal(data_converted["right_sensor"], self.data_df["right_sensor"]) - def test_rotate_multisensor_right(self): + def test_rotate_multisensor_right(self) -> None: data_converted = convert_to_fbf(self.data_df, right=["right_sensor"]) assert_frame_equal(data_converted["right_sensor"], self.data_right_expected) assert_frame_equal(data_converted["left_sensor"], self.data_df["left_sensor"]) - def test_rotate_multisensor_dict(self): + def test_rotate_multisensor_dict(self) -> None: data_converted = convert_to_fbf(self.data_dict, left=["left_sensor"], right=["right_sensor"]) for sensor in self.data_dict_expected: assert_frame_equal(data_converted[sensor], self.data_dict_expected[sensor]) - def test_like_argument(self): + def test_like_argument(self) -> None: data_converted = convert_to_fbf(self.data_df, right_like="right_", left_like="left_") assert_frame_equal(data_converted["right_sensor"], self.data_right_expected) - def test_like_argument_error(self): + def test_like_argument_error(self) -> None: with pytest.raises(ValueError): convert_to_fbf(self.data_df, right=["right_sensor"], right_like="right_") - def test_does_not_match_warning(self): + def test_does_not_match_warning(self) -> None: with pytest.warns(UserWarning) as w: convert_to_fbf(self.data_df, right_like="not_in_any_name_") assert "not_in_any_name_" in w[0].message.args[0] - def test_only_multisensor_dataset_supported(self): + def test_only_multisensor_dataset_supported(self) -> None: with pytest.raises(ValueError) as e: convert_to_fbf(self.data_df["left_sensor"]) diff --git a/tests/test_utils/test_datatype_helper.py b/tests/test_utils/test_datatype_helper.py index 59a3de3d..03aea0ce 100644 --- a/tests/test_utils/test_datatype_helper.py +++ b/tests/test_utils/test_datatype_helper.py @@ -1,4 +1,5 @@ """Test the dataset helpers.""" + import numpy as np import pandas as pd import pytest @@ -74,43 +75,43 @@ def as_index(request): class TestIsSingleSensorDataset: @pytest.mark.parametrize( "value", - ({"test": pd.DataFrame}, list(range(6)), "test", np.arange(6), pd.DataFrame(columns=_create_test_multiindex())), + [{"test": pd.DataFrame}, list(range(6)), "test", np.arange(6), pd.DataFrame(columns=_create_test_multiindex())], ) - def test_wrong_datatype(self, value): + def test_wrong_datatype(self, value) -> None: assert not is_single_sensor_data(value, check_acc=False, check_gyr=False) - def test_correct_datatype(self): + def test_correct_datatype(self) -> None: assert is_single_sensor_data(pd.DataFrame(), check_acc=False, check_gyr=False) @pytest.mark.parametrize( ("cols", "frame_valid", "col_check_valid"), - ( + [ (SF_COLS, "sensor", "both"), (BF_COLS, "body", "both"), (BF_GYR, "body", "gyr"), (BF_ACC, "body", "acc"), (SF_GYR, "sensor", 
"gyr"), (SF_ACC, "sensor", "acc"), - ), + ], ) - def test_correct_columns(self, cols, frame_valid, col_check_valid, combinations, frame): + def test_correct_columns(self, cols, frame_valid, col_check_valid, combinations, frame) -> None: """Test all possible combinations of inputs.""" col_check, check_acc, check_gyro = combinations output = is_single_sensor_data( pd.DataFrame(columns=cols), check_acc=check_acc, check_gyr=check_gyro, frame=frame ) - valid_frame = (frame_valid == frame) or (frame == "any") - valid_cols = (col_check == col_check_valid) or (col_check_valid == "both") + valid_frame = frame in (frame_valid, "any") + valid_cols = col_check_valid in (col_check, "both") expected_outcome = valid_cols and valid_frame assert output == expected_outcome - def test_invalid_frame_argument(self): + def test_invalid_frame_argument(self) -> None: with pytest.raises(ValueError): is_single_sensor_data(pd.DataFrame(), frame="invalid_value") - def test_error_raising(self): + def test_error_raising(self) -> None: with pytest.raises(ValidationError) as e: is_single_sensor_data(pd.DataFrame(), frame="body", check_acc=True, check_gyr=False, raise_exception=True) @@ -121,28 +122,28 @@ def test_error_raising(self): class TestIsMultiSensorDataset: @pytest.mark.parametrize( "value", - (list(range(6)), "test", np.arange(6), {}, pd.DataFrame(), pd.DataFrame(columns=[*range(3)])), + [list(range(6)), "test", np.arange(6), {}, pd.DataFrame(), pd.DataFrame(columns=[*range(3)])], ) - def test_wrong_datatype(self, value): + def test_wrong_datatype(self, value) -> None: assert not is_multi_sensor_data(value, check_acc=False, check_gyr=False) - def test_correct_datatype(self): + def test_correct_datatype(self) -> None: assert is_multi_sensor_data( pd.DataFrame([[*range(9)]], columns=_create_test_multiindex()), check_acc=False, check_gyr=False ) @pytest.mark.parametrize( ("cols", "frame_valid", "col_check_valid"), - ( + [ (SF_COLS, "sensor", "both"), (BF_COLS, "body", "both"), (BF_GYR, "body", "gyr"), (BF_ACC, "body", "acc"), (SF_GYR, "sensor", "gyr"), (SF_ACC, "sensor", "acc"), - ), + ], ) - def test_correct_columns(self, cols, frame_valid, col_check_valid, combinations, frame): + def test_correct_columns(self, cols, frame_valid, col_check_valid, combinations, frame) -> None: """Test all possible combinations of inputs.""" col_check, check_acc, check_gyro = combinations output = is_multi_sensor_data( @@ -152,24 +153,24 @@ def test_correct_columns(self, cols, frame_valid, col_check_valid, combinations, frame=frame, ) - valid_frame = (frame_valid == frame) or (frame == "any") - valid_cols = (col_check == col_check_valid) or (col_check_valid == "both") + valid_frame = frame in (frame_valid, "any") + valid_cols = col_check_valid in (col_check, "both") expected_outcome = valid_cols and valid_frame assert output == expected_outcome - def test_invalid_frame_argument(self): + def test_invalid_frame_argument(self) -> None: with pytest.raises(ValueError): is_multi_sensor_data(pd.DataFrame([[*range(9)]], columns=_create_test_multiindex()), frame="invalid_value") - def test_error_raising(self): + def test_error_raising(self) -> None: with pytest.raises(ValidationError) as e: is_multi_sensor_data(pd.DataFrame(), raise_exception=True) assert "The passed object does not seem to be MultiSensorData." 
in str(e) assert "MultiIndex" in str(e) - def test_nested_error_raising(self): + def test_nested_error_raising(self) -> None: with pytest.raises(ValidationError) as e: is_multi_sensor_data( {"s1": pd.DataFrame()}, frame="body", check_acc=True, check_gyr=False, raise_exception=True @@ -181,7 +182,7 @@ def test_nested_error_raising(self): class TestIsDataset: - def test_raises_error_correctly(self): + def test_raises_error_correctly(self) -> None: with pytest.raises(ValidationError) as e: is_sensor_data(pd.DataFrame(), frame="body", check_acc=True, check_gyr=False) @@ -189,21 +190,21 @@ def test_raises_error_correctly(self): assert str(BF_ACC) in str(e.value) assert "MultiIndex" in str(e.value) - @pytest.mark.parametrize(("obj", "out"), ((pd.DataFrame(), "single"), ({"s1": pd.DataFrame()}, "multi"))) - def test_basic_function(self, obj, out): + @pytest.mark.parametrize(("obj", "out"), [(pd.DataFrame(), "single"), ({"s1": pd.DataFrame()}, "multi")]) + def test_basic_function(self, obj, out) -> None: assert is_sensor_data(obj, check_gyr=False, check_acc=False) == out class TestGetMultiSensorDatasetNames: - @pytest.mark.parametrize("obj", ({"a": [], "b": [], "c": []}, pd.DataFrame(columns=_create_test_multiindex()))) - def test_names_simple(self, obj): + @pytest.mark.parametrize("obj", [{"a": [], "b": [], "c": []}, pd.DataFrame(columns=_create_test_multiindex())]) + def test_names_simple(self, obj) -> None: assert set(get_multi_sensor_names(obj)) == {"a", "b", "c"} class TestIsSingleSensorStrideList: @pytest.mark.parametrize( "value", - ( + [ list(range(6)), "test", np.arange(6), @@ -211,14 +212,14 @@ class TestIsSingleSensorStrideList: pd.DataFrame(), pd.DataFrame(columns=[*range(3)]), pd.DataFrame([[*range(9)]], columns=_create_test_multiindex()), - ), + ], ) - def test_wrong_datatype(self, value): + def test_wrong_datatype(self, value) -> None: assert not is_single_sensor_stride_list(value) @pytest.mark.parametrize( ("cols", "stride_types_valid"), - ( + [ (["s_id", "start", "end", "gsd_id"], ["any"]), (["s_id", "start", "end", "gsd_id", "something_extra"], ["any"]), (["s_id", "start", "end", "gsd_id", "pre_ic", "ic", "min_vel", "tc"], ["segmented", "min_vel", "ic"]), @@ -228,9 +229,9 @@ def test_wrong_datatype(self, value): ), (["s_id", "start", "end", "gsd_id", "ic", "min_vel", "tc"], ["ic", "segmented"]), (["s_id", "start", "end", "gsd_id", "ic", "min_vel", "tc", "something_extra"], ["ic", "segmented"]), - ), + ], ) - def test_valid_versions(self, cols, stride_types_valid, stride_types, as_index): + def test_valid_versions(self, cols, stride_types_valid, stride_types, as_index) -> None: expected_outcome = stride_types in stride_types_valid or stride_types == "any" df = pd.DataFrame(columns=cols) if as_index: @@ -240,8 +241,8 @@ def test_valid_versions(self, cols, stride_types_valid, stride_types, as_index): assert expected_outcome == out - @pytest.mark.parametrize("check_additional_cols", (True, False, ("ic",))) - def test_check_additional_columns(self, check_additional_cols): + @pytest.mark.parametrize("check_additional_cols", [True, False, ("ic",)]) + def test_check_additional_columns(self, check_additional_cols) -> None: # We construct a df that only has the minimal columns for min_vel df = pd.DataFrame(columns=["s_id", "start", "end", "min_vel"]) @@ -257,9 +258,9 @@ def test_check_additional_columns(self, check_additional_cols): @pytest.mark.parametrize( ("start", "min_vel", "expected"), - ((np.arange(10), np.arange(10), True), (np.arange(10), np.arange(10) + 1, False), ([], [], 
True)), + [(np.arange(10), np.arange(10), True), (np.arange(10), np.arange(10) + 1, False), ([], [], True)], ) - def test_columns_same_min_vel(self, start, min_vel, expected): + def test_columns_same_min_vel(self, start, min_vel, expected) -> None: """Test that the column-equals check for min_vel_strides works.""" min_vel_cols = ["s_id", "start", "end", "gsd_id", "pre_ic", "ic", "min_vel", "tc"] stride_list = pd.DataFrame(columns=min_vel_cols) @@ -273,9 +274,9 @@ def test_columns_same_min_vel(self, start, min_vel, expected): @pytest.mark.parametrize( ("start", "ic", "expected"), - ((np.arange(10), np.arange(10), True), (np.arange(10), np.arange(10) + 1, False), ([], [], True)), + [(np.arange(10), np.arange(10), True), (np.arange(10), np.arange(10) + 1, False), ([], [], True)], ) - def test_columns_same_ic(self, start, ic, expected): + def test_columns_same_ic(self, start, ic, expected) -> None: """Test that the column-equals check for ic_strides works.""" min_vel_cols = ["s_id", "start", "end", "gsd_id", "ic", "min_vel", "tc"] stride_list = pd.DataFrame(columns=min_vel_cols) @@ -287,14 +288,14 @@ def test_columns_same_ic(self, start, ic, expected): assert out == expected - def test_invalid_stride_type_argument(self): + def test_invalid_stride_type_argument(self) -> None: valid_cols = ["s_id", "start", "end", "gsd_id"] valid = pd.DataFrame(columns=valid_cols) with pytest.raises(ValueError): is_single_sensor_stride_list(valid, stride_type="invalid_value") - def test_identical_stride_ids(self): + def test_identical_stride_ids(self) -> None: """Test that the search for identical stride ids works.""" min_vel_cols = ["s_id", "start", "end"] stride_list = pd.DataFrame(columns=min_vel_cols) @@ -305,7 +306,7 @@ def test_identical_stride_ids(self): assert expected_outcome == out - def test_error_raising(self): + def test_error_raising(self) -> None: with pytest.raises(ValidationError) as e: is_single_sensor_stride_list(pd.DataFrame(), raise_exception=True) @@ -316,14 +317,14 @@ class TestIsMultiSensorStrideList: @pytest.mark.parametrize( "value", - (list(range(6)), "test", np.arange(6), {}, pd.DataFrame(), pd.DataFrame(columns=[*range(3)])), + [list(range(6)), "test", np.arange(6), {}, pd.DataFrame(), pd.DataFrame(columns=[*range(3)])], ) - def test_wrong_datatype(self, value): + def test_wrong_datatype(self, value) -> None: assert not is_multi_sensor_stride_list(value) @pytest.mark.parametrize( ("cols", "stride_types_valid"), - ( (["s_id", "start", "end", "gsd_id"], ["any"]), (["s_id", "start", "end", "gsd_id", "something_extra"], ["any"]), (["s_id", "start", "end", "gsd_id", "pre_ic", "ic", "min_vel", "tc"], ["segmented", "min_vel", "ic"]), @@ -333,9 +334,9 @@ def test_wrong_datatype(self, value): ), (["s_id", "start", "end", "gsd_id", "ic", "min_vel", "tc"], ["ic", "segmented"]), (["s_id", "start", "end", "gsd_id", "ic", "min_vel", "tc", "something_extra"], ["ic", "segmented"]), - ), + ], ) - def test_valid_versions(self, cols, stride_types_valid, stride_types, as_index): + def test_valid_versions(self, cols, stride_types_valid, stride_types, as_index) -> None: expected_outcome = stride_types in stride_types_valid or stride_types == "any" df = pd.DataFrame(columns=cols) if as_index: @@ -345,7 +346,7 @@ def test_valid_versions(self, cols, stride_types_valid, stride_types, as_index): assert expected_outcome == out - def test_only_one_invalid(self): + def test_only_one_invalid(self) -> None: valid_cols = ["s_id", "start", "end", "gsd_id"] invalid_cols = ["start",
"end", "gsd_id"] valid = {"s1": pd.DataFrame(columns=valid_cols)} @@ -354,14 +355,14 @@ def test_only_one_invalid(self): assert is_multi_sensor_stride_list(valid) assert not is_multi_sensor_stride_list(invalid) - def test_invalid_stride_type_argument(self): + def test_invalid_stride_type_argument(self) -> None: valid_cols = ["s_id", "start", "end", "gsd_id"] valid = {"s1": pd.DataFrame(columns=valid_cols)} with pytest.raises(ValueError): is_multi_sensor_stride_list(valid, stride_type="invalid_value") - def test_nested_error_raising(self): + def test_nested_error_raising(self) -> None: with pytest.raises(ValidationError) as e: is_multi_sensor_stride_list({"s1": pd.DataFrame()}, raise_exception=True) @@ -371,7 +372,7 @@ def test_nested_error_raising(self): class TestIsStrideList: - def test_raises_error_correctly(self): + def test_raises_error_correctly(self) -> None: with pytest.raises(ValidationError) as e: is_stride_list(pd.DataFrame()) @@ -381,12 +382,12 @@ def test_raises_error_correctly(self): @pytest.mark.parametrize( ("obj", "out"), - ( + [ (pd.DataFrame(columns=["s_id", "start", "end", "gsd_id"]), "single"), ({"s1": pd.DataFrame(columns=["s_id", "start", "end", "gsd_id"])}, "multi"), - ), + ], ) - def test_basic_function(self, obj, out): + def test_basic_function(self, obj, out) -> None: assert is_stride_list(obj) == out @@ -400,12 +401,12 @@ class TestIsSingleSensorTrajLikeList: ), ids=("pos", "vel", "ori"), ) - def traj_like_lists(self, request): + def traj_like_lists(self, request) -> None: self.func, self.dtype, self.valid_cols = request.param @pytest.mark.parametrize( "value", - ( + [ list(range(6)), "test", np.arange(6), @@ -413,22 +414,22 @@ def traj_like_lists(self, request): pd.DataFrame(), pd.DataFrame(columns=[*range(3)]), pd.DataFrame(columns=["s_id", "sample", "wrong1", "wrong2"]), - ), + ], ) - def test_wrong_datatype(self, value): + def test_wrong_datatype(self, value) -> None: assert not self.func(value) @pytest.mark.parametrize( ("cols", "index"), - ( + [ (["s_id", "sample"], []), (["s_id", "sample", "something_else"], []), (["sample"], ["s_id"]), ([], ["s_id", "sample"]), (["something_else"], ["s_id", "sample"]), - ), + ], ) - def test_valid_versions(self, cols, index): + def test_valid_versions(self, cols, index) -> None: df = pd.DataFrame(columns=[*self.valid_cols, *cols, *index]) if index: df = df.set_index(index) @@ -437,15 +438,14 @@ def test_valid_versions(self, cols, index): @pytest.mark.parametrize( ("cols", "index", "both"), - ( + [ (["s_id", "sample"], [], True), (["sample"], [], False), ([], ["s_id", "sample"], True), ([], ["sample"], False), - ), + ], ) - def test_valid_versions_without_s_id(self, cols, index, both): - + def test_valid_versions_without_s_id(self, cols, index, both) -> None: df = pd.DataFrame(columns=[*self.valid_cols, *cols, *index]) if index: df = df.set_index(index) @@ -454,19 +454,19 @@ def test_valid_versions_without_s_id(self, cols, index, both): assert self.func(df) is True @pytest.mark.parametrize(("list_type", "index"), TRAJ_TYPE_COLS.items()) - def test_different_list_types(self, list_type, index): + def test_different_list_types(self, list_type, index) -> None: valid_cols = [index, "sample", *self.valid_cols] df = pd.DataFrame(columns=valid_cols) for k in TRAJ_TYPE_COLS: assert self.func(df, k) == (k == list_type) @pytest.mark.parametrize(("list_type", "index"), TRAJ_TYPE_COLS.items()) - def test_any_roi_list_type(self, list_type, index): + def test_any_roi_list_type(self, list_type, index) -> None: valid_cols = [index, 
"sample", *self.valid_cols] df = pd.DataFrame(columns=valid_cols) assert self.func(df, "any_roi") == (list_type in ["roi", "gs"]) - def test_error_raising(self): + def test_error_raising(self) -> None: with pytest.raises(ValidationError) as e: self.func(pd.DataFrame(), raise_exception=True) @@ -484,34 +484,34 @@ class TestIsMultiSensorTrajLikeList: ), ids=("pos", "vel", "ori"), ) - def traj_like_lists(self, request): + def traj_like_lists(self, request) -> None: self.func, self.dtype, self.valid_cols = request.param @pytest.mark.parametrize( "value", - (list(range(6)), "test", np.arange(6), {}, pd.DataFrame(), pd.DataFrame(columns=[*range(3)])), + [list(range(6)), "test", np.arange(6), {}, pd.DataFrame(), pd.DataFrame(columns=[*range(3)])], ) - def test_wrong_datatype(self, value): + def test_wrong_datatype(self, value) -> None: assert not self.func(value) @pytest.mark.parametrize( ("cols", "index"), - ( + [ (["s_id", "sample"], []), (["s_id", "sample", "something_else"], []), (["sample"], ["s_id"]), ([], ["s_id", "sample"]), (["something_else"], ["s_id", "sample"]), - ), + ], ) - def test_valid_versions(self, cols, index): + def test_valid_versions(self, cols, index) -> None: df = pd.DataFrame(columns=[*self.valid_cols, *cols, *index]) if index: df = df.set_index(index) assert self.func({"s1": df}, "stride") - def test_only_one_invalid(self): + def test_only_one_invalid(self) -> None: valid_cols = ["s_id", "sample", *self.valid_cols] invalid_cols = ["sample", *self.valid_cols] valid = {"s1": pd.DataFrame(columns=valid_cols)} @@ -520,7 +520,7 @@ def test_only_one_invalid(self): assert self.func(valid, "stride") assert not self.func(invalid, "stride") - def test_nested_error_raising(self): + def test_nested_error_raising(self) -> None: with pytest.raises(ValidationError) as e: self.func({"s1": pd.DataFrame()}, raise_exception=True) @@ -539,10 +539,10 @@ class TestIsTrajLikeList: ), ids=("pos", "vel", "ori"), ) - def traj_like_lists(self, request): + def traj_like_lists(self, request) -> None: self.func, self.dtype, self.valid_cols = request.param - def test_raises_error_correctly(self): + def test_raises_error_correctly(self) -> None: with pytest.raises(ValidationError) as e: self.func(pd.DataFrame()) @@ -550,7 +550,7 @@ def test_raises_error_correctly(self): assert "sample" in str(e.value) assert "'dict'" in str(e.value) - def test_basic_function(self): + def test_basic_function(self) -> None: valid_cols = ["s_id", "sample", *self.valid_cols] obj = pd.DataFrame(columns=valid_cols) assert self.func(obj) == "single" @@ -558,7 +558,7 @@ def test_basic_function(self): class TestSetCorrectIndex: - def test_no_change_needed(self): + def test_no_change_needed(self) -> None: index_names = ["t1", "t2"] test = _create_test_multiindex() test = test.rename(index_names) @@ -566,8 +566,8 @@ def test_no_change_needed(self): assert_frame_equal(df, set_correct_index(df, index_names)) - @pytest.mark.parametrize("level", (0, 1, [0, 1])) - def test_cols_to_index(self, level): + @pytest.mark.parametrize("level", [0, 1, [0, 1]]) + def test_cols_to_index(self, level) -> None: """Test what happens if one or multiple of the expected index cols are normal cols.""" index_names = ["t1", "t2"] test = _create_test_multiindex() @@ -582,7 +582,7 @@ def test_cols_to_index(self, level): # Nothing was changed besides setting the index assert_frame_equal(df, out) - def test_col_does_not_exist(self): + def test_col_does_not_exist(self) -> None: index_names = ["t1", "t2"] test = _create_test_multiindex() test = 
test.rename(index_names) @@ -591,8 +591,8 @@ def test_col_does_not_exist(self): with pytest.raises(ValidationError): set_correct_index(df, ["does_not_exist", *index_names]) - @pytest.mark.parametrize("drop_additional", (True, False)) - def test_additional_index_col(self, drop_additional): + @pytest.mark.parametrize("drop_additional", [True, False]) + def test_additional_index_col(self, drop_additional) -> None: index_names = ["t1", "t2"] test = _create_test_multiindex() test = test.rename(index_names) @@ -608,7 +608,7 @@ def test_additional_index_col(self, drop_additional): class TestIsSingleRegionsOfInterestList: @pytest.mark.parametrize( "value", - ( + [ list(range(6)), "test", np.arange(6), @@ -616,36 +616,36 @@ class TestIsSingleRegionsOfInterestList: pd.DataFrame(), pd.DataFrame(columns=[*range(3)]), pd.DataFrame([[*range(9)]], columns=_create_test_multiindex()), - ), + ], ) - def test_wrong_datatype(self, value): + def test_wrong_datatype(self, value) -> None: assert not is_single_sensor_regions_of_interest_list(value) @pytest.mark.parametrize( ("cols", "roi_type_valid"), - ( + [ (["start", "end", "gs_id"], "gs"), (["start", "end", "gs_id", "something_extra"], "gs"), (["start", "end", "roi_id"], "roi"), (["start", "end", "roi_id", "something_extra"], "roi"), - ), + ], ) - def test_valid_versions(self, cols, roi_type_valid, roi_types): + def test_valid_versions(self, cols, roi_type_valid, roi_types) -> None: expected_outcome = roi_types in roi_type_valid or roi_types == "any" out = is_single_sensor_regions_of_interest_list(pd.DataFrame(columns=cols), region_type=roi_types) assert expected_outcome == out - def test_invalid_region_type_argument(self): + def test_invalid_region_type_argument(self) -> None: valid_cols = ["start", "end", "gs_id"] valid = pd.DataFrame(columns=valid_cols) with pytest.raises(ValueError): is_single_sensor_regions_of_interest_list(valid, region_type="invalid_value") - @pytest.mark.parametrize("col_name", ("gs_id", "roi_id")) - def test_identical_region_ids(self, col_name): + @pytest.mark.parametrize("col_name", ["gs_id", "roi_id"]) + def test_identical_region_ids(self, col_name) -> None: """Test that the search for identical region ids works.""" cols = [col_name, "start", "end"] roi_list = pd.DataFrame(columns=cols) @@ -656,8 +656,8 @@ def test_identical_region_ids(self, col_name): assert expected_outcome == out - @pytest.mark.parametrize("col_name", ("gs_id", "roi_id")) - def test_id_col_as_index(self, col_name): + @pytest.mark.parametrize("col_name", ["gs_id", "roi_id"]) + def test_id_col_as_index(self, col_name) -> None: """Test that the id col can either be the index or a column.""" cols = [col_name, "start", "end"] roi_list = pd.DataFrame(columns=cols) @@ -667,7 +667,7 @@ def test_id_col_as_index(self, col_name): assert out is True - def test_error_raising(self): + def test_error_raising(self) -> None: with pytest.raises(ValidationError) as e: is_single_sensor_regions_of_interest_list(pd.DataFrame(), raise_exception=True) @@ -678,28 +678,28 @@ def test_error_raising(self): class TestIsMultiSensorRegionsOfInterestList: @pytest.mark.parametrize( "value", - (list(range(6)), "test", np.arange(6), {}, pd.DataFrame(), pd.DataFrame(columns=[*range(3)])), + [list(range(6)), "test", np.arange(6), {}, pd.DataFrame(), pd.DataFrame(columns=[*range(3)])], ) - def test_wrong_datatype(self, value): + def test_wrong_datatype(self, value) -> None: assert not is_multi_sensor_regions_of_interest_list(value) @pytest.mark.parametrize( ("cols", "roi_type_valid"), - ( + [ 
(["start", "end", "gs_id"], "gs"), (["start", "end", "gs_id", "something_extra"], "gs"), (["start", "end", "roi_id"], "roi"), (["start", "end", "roi_id", "something_extra"], "roi"), - ), + ], ) - def test_valid_versions(self, cols, roi_type_valid, roi_types): + def test_valid_versions(self, cols, roi_type_valid, roi_types) -> None: expected_outcome = roi_types in roi_type_valid or roi_types == "any" out = is_multi_sensor_regions_of_interest_list({"s1": pd.DataFrame(columns=cols)}, region_type=roi_types) assert expected_outcome == out - def test_only_one_invalid(self): + def test_only_one_invalid(self) -> None: valid_cols = ["gs_id", "start", "end"] invalid_cols = ["start", "end"] valid = {"s1": pd.DataFrame(columns=valid_cols)} @@ -708,14 +708,14 @@ def test_only_one_invalid(self): assert is_multi_sensor_regions_of_interest_list(valid) assert not is_multi_sensor_regions_of_interest_list(invalid) - def test_invalid_region_type_argument(self): + def test_invalid_region_type_argument(self) -> None: valid_cols = ["start", "end", "gs_id"] valid = pd.DataFrame(columns=valid_cols) with pytest.raises(ValueError): is_multi_sensor_regions_of_interest_list({"si": valid}, region_type="invalid_value") - def test_nested_error_raising(self): + def test_nested_error_raising(self) -> None: with pytest.raises(ValidationError) as e: is_multi_sensor_regions_of_interest_list({"s1": pd.DataFrame()}, raise_exception=True) @@ -725,7 +725,7 @@ def test_nested_error_raising(self): class TestIsRegionsOfInterestList: - def test_raises_error_correctly(self): + def test_raises_error_correctly(self) -> None: with pytest.raises(ValidationError) as e: is_regions_of_interest_list(pd.DataFrame()) @@ -735,19 +735,19 @@ def test_raises_error_correctly(self): @pytest.mark.parametrize( ("obj", "out"), - ( + [ (pd.DataFrame(columns=["gs_id", "start", "end"]), "single"), (pd.DataFrame(columns=["roi_id", "start", "end"]), "single"), ({"s1": pd.DataFrame(columns=["roi_id", "start", "end"])}, "multi"), ({"s1": pd.DataFrame(columns=["gs_id", "start", "end"])}, "multi"), - ), + ], ) - def test_basic_function(self, obj, out): + def test_basic_function(self, obj, out) -> None: assert is_regions_of_interest_list(obj) == out class TestToDictMultiSensorData: - def test_convert_simple(self): + def test_convert_simple(self) -> None: data = pd.DataFrame(np.ones((10, 3)), columns=BF_GYR) data = pd.concat([data, data], axis=1, keys=["s1", "s2"]) @@ -759,7 +759,7 @@ def test_convert_simple(self): assert out["s1"].shape == (10, 3) assert out["s2"].shape == (10, 3) - def test_dict_is_just_returned(self): + def test_dict_is_just_returned(self) -> None: data = pd.DataFrame(np.ones((10, 3)), columns=BF_GYR) data = {"s1": data, "s2": data} diff --git a/tests/test_utils/test_fast_quaternion_math.py b/tests/test_utils/test_fast_quaternion_math.py index f957865f..1fd66db9 100644 --- a/tests/test_utils/test_fast_quaternion_math.py +++ b/tests/test_utils/test_fast_quaternion_math.py @@ -24,7 +24,7 @@ class TestRateOfChangeFromGyro: ([0.5, 0.5, 0.5, 0.5], [1.0, 4.0, 0.4]), ], ) - def test_rate_of_change_from_gyro(self, q, g): + def test_rate_of_change_from_gyro(self, q, g) -> None: q = np.array(q) g = np.array(g) assert_array_almost_equal(rate_of_change_from_gyro(g, q), 0.5 * multiply(q, np.append(g, 0.0))) @@ -41,7 +41,7 @@ class TestMulitply: ([0.5, 0.5, 0.5, 0.5], [0.0, 0.0, 0.707107, 0.707107]), ], ) - def test_quaternion_multiplication(self, q1, q2): + def test_quaternion_multiplication(self, q1, q2) -> None: q1 = np.array(q1) q2 = np.array(q2) 
assert_array_almost_equal(multiply(q1, q2), (Rotation.from_quat(q1) * Rotation.from_quat(q2)).as_quat()) @@ -59,7 +59,7 @@ class TestRotateVector: ([0.5, 0.5, 0.5, 0.5], [1.0, 4.0, 0.4]), ], ) - def test_rotate_vector_by_quaternion(self, q, v): + def test_rotate_vector_by_quaternion(self, q, v) -> None: q = np.array(q) v = np.array(v) assert_array_almost_equal(rotate_vector(q, v), Rotation.from_quat(q).apply(v)) @@ -71,7 +71,7 @@ class TestQuatFromRotvec: @pytest.mark.parametrize( "v", [([1.0, 0.0, 0.0]), ([1.0, 1.0, 0.0]), ([1.0, 1.0, 1.0]), ([0.2, 0.1, 5.0]), ([10.0, 0.2, 0.0])] ) - def test_quat_from_rotation_vector(self, v): + def test_quat_from_rotation_vector(self, v) -> None: """Test `quat_from_rotation_vector`.""" v = np.array(v) assert_array_almost_equal(quat_from_rotvec(v), Rotation.from_rotvec(v).as_quat()) @@ -81,6 +81,6 @@ class TestFindNormalize(TestNormalize): def func(self, x): return normalize(x) - def test_normalize_all_zeros(self): + def test_normalize_all_zeros(self) -> None: """Test vector [0, 0, 0].""" assert_array_almost_equal(self.func(np.array([0.0, 0, 0])), [0.0, 0, 0]) diff --git a/tests/test_utils/test_rotations.py b/tests/test_utils/test_rotations.py index d03ba7d7..61dcd564 100644 --- a/tests/test_utils/test_rotations.py +++ b/tests/test_utils/test_rotations.py @@ -34,25 +34,25 @@ def cyclic_rotation(): class TestRotationFromAngle: """Test the function `rotation_from_angle`.""" - def test_single_angle(self): + def test_single_angle(self) -> None: """Test single axis, single angle.""" assert_almost_equal(rotation_from_angle(np.array([1, 0, 0]), np.pi).as_quat(), [1.0, 0, 0, 0]) - def test_multiple_axis_and_angles(self): + def test_multiple_axis_and_angles(self) -> None: """Test multiple axes, multiple angles.""" start = np.repeat(np.array([1.0, 0, 0])[None, :], 5, axis=0) goal = np.repeat(np.array([1.0, 0, 0, 0])[None, :], 5, axis=0) angle = np.array([np.pi] * 5) assert_almost_equal(rotation_from_angle(start, angle).as_quat(), goal) - def test_multiple_axis_single_angle(self): + def test_multiple_axis_single_angle(self) -> None: """Test multiple axes, single angles.""" start = np.repeat(np.array([1.0, 0, 0])[None, :], 5, axis=0) goal = np.repeat(np.array([1.0, 0, 0, 0])[None, :], 5, axis=0) angle = np.array(np.pi) assert_almost_equal(rotation_from_angle(start, angle).as_quat(), goal) - def test_single_axis_multiple_angle(self): + def test_single_axis_multiple_angle(self) -> None: """Test single axis, multiple angles.""" start = np.array([1.0, 0, 0])[None, :] goal = np.repeat(np.array([1.0, 0, 0, 0])[None, :], 5, axis=0) @@ -60,7 +60,7 @@ def test_single_axis_multiple_angle(self): assert_almost_equal(rotation_from_angle(start, angle).as_quat(), goal) -def _compare_cyclic(data, rotated_data, cycles=1): +def _compare_cyclic(data, rotated_data, cycles=1) -> None: """Quickly check if rotated data was rotated by a cyclic axis rotation. This can be used in combination with the :func:`cyclic_rotation` fixture, to test if this rotation was correctly @@ -83,7 +83,7 @@ class TestRotateDfDataset: multi_func = staticmethod(rotate_dataset) @pytest.fixture(autouse=True) - def _sample_sensor_data(self): + def _sample_sensor_data(self) -> None: """Create some sample data. This data is recreated before each test (using pytest.fixture).
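For orientation while reading the rotation hunks: the tests above pin down the broadcasting contract of rotation_from_angle (axis and angle may each be single or batched). A minimal sketch of that contract, assuming only the import path these tests already use:

import numpy as np
from gaitmap.utils.rotations import rotation_from_angle

# A pi rotation around x, returned as a scipy Rotation (quaternion order [x, y, z, w]).
assert np.allclose(rotation_from_angle(np.array([1, 0, 0]), np.pi).as_quat(), [1.0, 0, 0, 0])
# Axis and angle broadcast against each other: one (1, 3) axis with five angles gives five rotations.
rots = rotation_from_angle(np.array([1.0, 0, 0])[None, :], np.array([np.pi] * 5))
assert rots.as_quat().shape == (5, 4)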
@@ -95,8 +95,8 @@ def _sample_sensor_data(self): dataset = {"s1": self.sample_sensor_data, "s2": self.sample_sensor_data + 0.5} self.sample_sensor_dataset = pd.concat(dataset, axis=1) - @pytest.mark.parametrize("inputs", ({"dataset": "single", "rotation": {}},)) - def test_invalid_inputs(self, inputs): + @pytest.mark.parametrize("inputs", [{"dataset": "single", "rotation": {}}]) + def test_invalid_inputs(self, inputs) -> None: """Test input combinations that should lead to ValueErrors.""" # Select the dataset for test using strings, as you can not use self-parameters in the decorator. if inputs["dataset"] == "single": @@ -105,8 +105,8 @@ def test_invalid_inputs(self, inputs): with pytest.raises(ValueError): self.multi_func(**inputs) - @pytest.mark.parametrize("ascending", (True, False)) - def test_order_is_preserved_multiple_datasets(self, cyclic_rotation, ascending): + @pytest.mark.parametrize("ascending", [True, False]) + def test_order_is_preserved_multiple_datasets(self, cyclic_rotation, ascending) -> None: """Test if the function preserves the order of columns, if they are not sorted in the beginning. Different orders are simulated by sorting the columns once in ascending and once in descending order. @@ -130,7 +130,7 @@ class TestFlipDfDataset(TestRotateDfDataset): single_func = staticmethod(_flip_sensor) multi_func = staticmethod(flip_dataset) - def test_non_orthogonal_matrix_raises(self): + def test_non_orthogonal_matrix_raises(self) -> None: # Create rot matrix that is not just 90 deg rot = rotation_from_angle(np.array([0, 1, 0]), np.pi / 4) with pytest.raises(ValueError) as e: @@ -138,7 +138,7 @@ def test_non_orthogonal_matrix_raises(self): assert "Only 90 deg rotations are allowed" in str(e.value) - def test_raises_when_multi_d_rotation_provided(self): + def test_raises_when_multi_d_rotation_provided(self) -> None: # create rotation object with multiple rotations rot = rotation_from_angle(np.array([0, 1, 0]), np.array([np.pi / 2, np.pi / 2])) with pytest.raises(ValueError) as e: @@ -157,7 +157,7 @@ class TestRotateDataset: multi_func = staticmethod(rotate_dataset) @pytest.fixture(autouse=True, params=("dict", "frame")) - def _sample_sensor_data(self, request): + def _sample_sensor_data(self, request) -> None: """Create some sample data. This data is recreated before each test (using pytest.fixture). @@ -172,13 +172,13 @@ def _sample_sensor_data(self, request): elif request.param == "frame": self.sample_sensor_dataset = pd.concat(dataset, axis=1) - def test_rotate_sensor(self, cyclic_rotation): + def test_rotate_sensor(self, cyclic_rotation) -> None: """Test if rotation is correctly applied to gyr and acc of single sensor data.""" rotated_data = self.single_func(self.sample_sensor_data, cyclic_rotation) _compare_cyclic(self.sample_sensor_data, rotated_data) - def test_rotate_dataset_single(self, cyclic_rotation): + def test_rotate_dataset_single(self, cyclic_rotation) -> None: """Rotate a single dataset with `rotate_dataset`. This tests the input option where no MultiIndex df is used. @@ -187,7 +187,7 @@ def test_rotate_dataset_single(self, cyclic_rotation): _compare_cyclic(self.sample_sensor_data, rotated_data) - def test_rotate_single_named_dataset(self, cyclic_rotation): + def test_rotate_single_named_dataset(self, cyclic_rotation) -> None: """Rotate a single dataset with a named sensor. This tests MultiIndex input with a single sensor. 
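The surrounding tests all drive rotate_dataset with either a plain single-sensor frame or MultiIndex input. As a quick reference, a hedged usage sketch of the two rotation styles they cover; the SF_COLS constant and import paths are taken from the test imports, everything else is illustrative:

import numpy as np
import pandas as pd
from gaitmap.utils.consts import SF_COLS
from gaitmap.utils.rotations import rotate_dataset, rotation_from_angle

data = pd.DataFrame(np.ones((10, 6)), columns=SF_COLS)
dataset = pd.concat({"s1": data, "s2": data}, axis=1)  # MultiIndex input with two named sensors
rot = rotation_from_angle(np.array([0, 0, 1.0]), np.pi / 2)
rotated_all = rotate_dataset(dataset, rot)  # one rotation is applied to every sensor
rotated_s1_only = rotate_dataset(dataset, {"s1": rot})  # dict input rotates only "s1"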
@@ -200,7 +200,7 @@ def test_rotate_single_named_dataset(self, cyclic_rotation): _compare_cyclic(test_data["s1"], rotated_data["s1"]) - def test_rotate_multiple_named_dataset(self, cyclic_rotation): + def test_rotate_multiple_named_dataset(self, cyclic_rotation) -> None: """Rotate multiple datasets with named sensors. This tests MultiIndex input with multiple sensors. @@ -211,7 +211,7 @@ def test_rotate_multiple_named_dataset(self, cyclic_rotation): _compare_cyclic(test_data["s1"], rotated_data["s1"]) _compare_cyclic(test_data["s2"], rotated_data["s2"]) - def test_rotate_multiple_named_dataset_with_multiple_rotations(self, cyclic_rotation): + def test_rotate_multiple_named_dataset_with_multiple_rotations(self, cyclic_rotation) -> None: """Apply different rotations to each dataset.""" test_data = self.sample_sensor_dataset # Apply single cycle to "s1" and cycle twice to "s2" @@ -220,7 +220,7 @@ def test_rotate_multiple_named_dataset_with_multiple_rota _compare_cyclic(test_data["s1"], rotated_data["s1"]) _compare_cyclic(test_data["s2"], rotated_data["s2"], cycles=2) - def test_only_rotate_some_sensors(self, cyclic_rotation): + def test_only_rotate_some_sensors(self, cyclic_rotation) -> None: """Only apply rotation to some sensors and not all. This uses the dict input to only provide a rotation for s1 and not s2. @@ -231,7 +231,7 @@ def test_only_rotate_some_sensors(self, cyclic_rotation): _compare_cyclic(test_data["s1"], rotated_data["s1"]) assert_frame_equal(test_data["s2"], rotated_data["s2"]) - def test_rotate_dataset_is_copy(self, cyclic_rotation): + def test_rotate_dataset_is_copy(self, cyclic_rotation) -> None: """Test if the output is indeed a copy and the original dataset was not modified.""" org_data = self.sample_sensor_dataset.copy() rotated_data = self.multi_func(self.sample_sensor_dataset, cyclic_rotation) @@ -241,8 +241,8 @@ def test_rotate_dataset_is_copy(self, cyclic_rotation): for k in get_multi_sensor_names(org_data): assert_frame_equal(org_data[k], self.sample_sensor_dataset[k]) - @pytest.mark.parametrize("ascending", (True, False)) - def test_order_is_preserved_single_sensor(self, cyclic_rotation, ascending): + @pytest.mark.parametrize("ascending", [True, False]) + def test_order_is_preserved_single_sensor(self, cyclic_rotation, ascending) -> None: """Test if the function preserves the order of columns, if they are not sorted in the beginning. Different orders are simulated by sorting the columns once in ascending and once in descending order.
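The TestRotateDatasetSeries hunks below cover the per-sample variant, which expects exactly one scipy Rotation per data row. A short sketch of that contract (illustrative data; the length-mismatch ValueError is what the tests below assert):

import numpy as np
import pandas as pd
from scipy.spatial.transform import Rotation
from gaitmap.utils.consts import SF_COLS
from gaitmap.utils.rotations import rotate_dataset_series

data = pd.DataFrame(np.zeros((10, 6)), columns=SF_COLS)
# Exactly one rotation per sample: 10 rows need 10 rotations (11 would raise, as tested below).
rotated = rotate_dataset_series(data, Rotation.identity(10))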
@@ -267,20 +267,20 @@ class TestFlipDataset(TestRotateDataset): class TestRotateDatasetSeries: - def test_invalid_input(self): + def test_invalid_input(self) -> None: with pytest.raises(ValidationError) as e: rotate_dataset_series("bla", Rotation.identity(3)) assert "SingleSensorData" in str(e) - def test_invalid_input_length(self): + def test_invalid_input_length(self) -> None: data = pd.DataFrame(np.zeros((10, 6)), columns=SF_COLS) with pytest.raises(ValueError) as e: rotate_dataset_series(data, Rotation.identity(11)) assert "number of rotations" in str(e) - def test_simple_series_rotation(self): + def test_simple_series_rotation(self) -> None: input_acc = [0, 0, 1] input_gyro = [0, 1, 0] data = np.array([[*input_acc, *input_gyro]] * 4) @@ -304,7 +304,7 @@ def test_simple_series_rotation(self): class TestFindShortestRotation: """Test the function `find_shortest_rotation`.""" - def test_find_shortest_rotation(self): + def test_find_shortest_rotation(self) -> None: """Test shortest rotation between two vectors.""" goal = np.array([0, 0, 1]) start = np.array([1, 0, 0]) @@ -312,7 +312,7 @@ def test_find_shortest_rotation(self): rotated = rot.apply(start) assert_almost_equal(rotated, goal) - def test_find_shortest_rotation_unnormalized_vector(self): + def test_find_shortest_rotation_unnormalized_vector(self) -> None: """Test shortest rotation for invalid input (one of the vectors is not normalized).""" with pytest.raises(ValueError): find_shortest_rotation([2, 0, 0], [0, 1, 0]) @@ -323,7 +323,7 @@ class TestGetGravityRotation: # TODO: Does this need more complex tests? - def test_gravity_rotation_simple(self): + def test_gravity_rotation_simple(self) -> None: """Test simple gravity rotation.""" rotation_quad = get_gravity_rotation(np.array([1, 0, 0])) rotated_vector = rotation_quad.apply(np.array([1, 0, 0])) @@ -335,7 +335,7 @@ class TestFindRotationAroundAxis: @pytest.mark.parametrize( ("rotation", "axis", "out"), - ( + [ (Rotation.from_rotvec([0, 0, np.pi / 2]), [0, 0, 1], [0, 0, np.pi / 2]), (Rotation.from_rotvec([0, 0, np.pi / 2]), [0, 1, 0], [0, 0, 0]), (Rotation.from_rotvec([0, 0, np.pi / 2]), [1, 0, 0], [0, 0, 0]), @@ -349,18 +349,18 @@ class TestFindRotationAroundAxis: [0, 0, 1], [0, 0, np.pi / 2], ), - ), + ], ) - def test_simple_cases(self, rotation, axis, out): + def test_simple_cases(self, rotation, axis, out) -> None: assert_array_almost_equal(find_rotation_around_axis(rotation, axis).as_rotvec(), out) - def test_multi_input_single_axis(self): + def test_multi_input_single_axis(self) -> None: rot = Rotation.from_rotvec(np.repeat([[0, 0, np.pi / 2]], 5, axis=0)) axis = [0, 0, 1] out = np.repeat([[0, 0, np.pi / 2]], 5, axis=0) assert_array_almost_equal(find_rotation_around_axis(rot, axis).as_rotvec(), out) - def test_multi_input_multi_axis(self): + def test_multi_input_multi_axis(self) -> None: rot = Rotation.from_rotvec(np.repeat([[0, 0, np.pi / 2]], 3, axis=0)) axis = [[0, 0, 1], [0, 1, 0], [1, 0, 0]] out = [[0, 0, np.pi / 2], [0, 0, 0], [0, 0, 0]] @@ -372,7 +372,7 @@ class TestFindAngleBetweenOrientations: @pytest.mark.parametrize( ("ori1", "ori2", "axis", "out"), - ( + [ (Rotation.from_rotvec([0, 0, np.pi / 2]), Rotation.from_rotvec([0, 0, -np.pi / 2]), [0, 0, 1], np.pi), (Rotation.from_rotvec([0, 0, np.pi / 2]), Rotation.from_rotvec([0, 0, -np.pi / 2]), None, np.pi), (Rotation.from_rotvec([0, 0, np.pi / 2]), Rotation.from_rotvec([0, 0, -np.pi / 2]), [1, 0, 0], 0), @@ -403,21 +403,21 @@ class TestFindAngleBetweenOrientations: None, np.pi, # Yes, this really must be pi! 
), - ], ) - def test_simple_cases(self, ori1, ori2, axis, out): + def test_simple_cases(self, ori1, ori2, axis, out) -> None: result = find_angle_between_orientations(ori1, ori2, axis) assert_array_almost_equal(angle_diff(result, out), 0) assert isinstance(result, float) @pytest.mark.parametrize( ("ori1", "ori2", "axis", "out"), - ( + [ (Rotation.from_rotvec([0, 0, np.pi / 2]), Rotation.from_rotvec([0, 0, -np.pi / 2]), [0, 0, 0], "error"), (Rotation.from_rotvec([0, 0, np.pi / 2]), Rotation.from_rotvec([0, 0, np.pi / 2]), None, 0), - ), + ], ) - def test_zero_cases(self, ori1, ori2, axis, out): + def test_zero_cases(self, ori1, ori2, axis, out) -> None: if out == "error": with pytest.raises(ValueError): find_angle_between_orientations(ori1, ori2, axis) @@ -427,28 +427,28 @@ def test_zero_cases(self, ori1, ori2, axis, out): assert_array_almost_equal(angle_diff(result, out), 0) assert isinstance(result, float) - def test_multi_input_single_ref_single_axis(self): + def test_multi_input_single_ref_single_axis(self) -> None: rot = Rotation.from_rotvec(np.repeat([[0, 0, np.pi / 2]], 5, axis=0)) ref = Rotation.identity() axis = [0, 0, 1] out = [np.pi / 2] * 5 assert_array_almost_equal(find_angle_between_orientations(rot, ref, axis), out) - def test_single_input_multi_ref_single_axis(self): + def test_single_input_multi_ref_single_axis(self) -> None: ref = Rotation.from_rotvec(np.repeat([[0, 0, np.pi / 2]], 5, axis=0)) rot = Rotation.identity() axis = [0, 0, 1] out = [-np.pi / 2] * 5 assert_array_almost_equal(find_angle_between_orientations(rot, ref, axis), out) - def test_multi_input_multi_ref_single_axis(self): + def test_multi_input_multi_ref_single_axis(self) -> None: ref = Rotation.from_rotvec(np.repeat([[0, 0, np.pi / 2]], 5, axis=0)) rot = Rotation.identity(num=5) axis = [0, 0, 1] out = [-np.pi / 2] * 5 assert_array_almost_equal(find_angle_between_orientations(rot, ref, axis), out) - def test_multi_all(self): + def test_multi_all(self) -> None: ref = Rotation.from_rotvec(np.repeat([[0, 0, np.pi / 2]], 5, axis=0)) rot = Rotation.identity(num=5) axis = np.repeat([[0, 0, 1]], 5, axis=0) @@ -469,13 +469,13 @@ class TestFindUnsigned3dAngle: ([1, 0, 0], [-1, 0, 0], np.pi), ], ) - def test_find_unsigned_3d_angle(self, v1, v2, result): + def test_find_unsigned_3d_angle(self, v1, v2, result) -> None: """Test `find_unsigned_3d_angle` between two 1D vectors.""" v1 = np.array(v1) v2 = np.array(v2) assert_almost_equal(find_unsigned_3d_angle(v1, v2), result) - def test_find_3d_angle_array(self): + def test_find_3d_angle_array(self) -> None: """Test `find_unsigned_3d_angle` between two 2D vectors.""" v1 = np.array(4 * [[1, 0, 0]]) v2 = np.array(4 * [[0, 1, 0]]) @@ -487,22 +487,22 @@ def test_find_3d_angle_array(self): class TestAngleDiff: @pytest.mark.parametrize( ("a", "b", "out"), - ( + [ (-np.pi / 2, 0, -np.pi / 2), (0, -np.pi / 2, np.pi / 2), (-np.pi, np.pi, 0), (1.5 * np.pi, 0, -np.pi / 2), (np.array([-np.pi / 2, np.pi / 2]), np.array([0, 0]), np.array([-np.pi / 2, np.pi / 2])), - ), + ], ) - def test_various_inputs(self, a, b, out): + def test_various_inputs(self, a, b, out) -> None: assert_almost_equal(angle_diff(a, b), out) class TestSigned3DAngle: @pytest.mark.parametrize( ("v1", "v2", "n", "r"), - ( + [ ([0, 0, 1], [1, 0, 0], [0, 1, 0], 90), ([0, 0, 1], [1, 0, 0], [0, -1, 0], -90), ([0, 0, 1], [0, 0, 1], [0, 1, 0], 0), ([0, 1], [1, 0], [0, 0, 1], -90), ([0, 1], [1, 0], [0, 0, -1], 90), ([1, 0], [1, 0], [0, 0, 1], 0), - ), + ], ) - def 
test_simple_angle(self, v1, v2, n, r): + def test_simple_angle(self, v1, v2, n, r) -> None: result = find_signed_3d_angle(np.array(v1), np.array(v2), np.array(n)) assert result == np.deg2rad(r) @pytest.mark.parametrize( ("v1", "v2", "n", "r"), - ( + [ ([[0, 0, 1]], [1, 0, 0], [0, 1, 0], [90]), ([[0, 0, 1], [1, 0, 0]], [1, 0, 0], [0, 1, 0], [90, 0]), ([[0, 0, 1], [1, 0, 0]], [[1, 0, 0], [0, 0, 1]], [0, 1, 0], [90, -90]), ([[0, 0, 1], [1, 0, 0]], [[1, 0, 0], [0, 0, 1]], [[0, 1, 0], [0, -1, 0]], [90, 90]), - ), + ], ) - def test_angle_multi_d(self, v1, v2, n, r): + def test_angle_multi_d(self, v1, v2, n, r) -> None: result = find_signed_3d_angle(np.array(v1), np.array(v2), np.array(n)) assert_array_equal(np.deg2rad(r), result) diff --git a/tests/test_utils/test_signal_processing.py b/tests/test_utils/test_signal_processing.py index 57036c87..0d84a338 100644 --- a/tests/test_utils/test_signal_processing.py +++ b/tests/test_utils/test_signal_processing.py @@ -3,7 +3,7 @@ from gaitmap.utils.signal_processing import row_wise_autocorrelation -def test_row_wise_autocorrelation(): +def test_row_wise_autocorrelation() -> None: """Test if the manually implemented row-wise autocorrelation function produces results similar to the numpy implementation. diff --git a/tests/test_utils/test_static_moment_detection.py b/tests/test_utils/test_static_moment_detection.py index 655e36bf..bf556e57 100644 --- a/tests/test_utils/test_static_moment_detection.py +++ b/tests/test_utils/test_static_moment_detection.py @@ -12,13 +12,13 @@ class TestFindStaticSamples: """Test the function `find_static_samples`.""" - def test_invalid_input_dimension_default_overlap(self): + def test_invalid_input_dimension_default_overlap(self) -> None: """Test if value error is raised correctly on invalid input dimensions.""" test_input = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1]) with pytest.raises(ValueError, match=r".* dimensions.*"): find_static_samples(test_input, window_length=4, inactive_signal_th=0, metric="maximum") - def test_invalid_input_metric_default_overlap(self): + def test_invalid_input_metric_default_overlap(self) -> None: """Test if value error is raised correctly on an invalid metric.""" test_input = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1]) test_input = np.column_stack([test_input, test_input, test_input]) @@ -27,14 +27,14 @@ def test_invalid_input_metric_default_overlap(self): test_input, window_length=4, overlap=3, inactive_signal_th=0, metric="some_invalid_metric" ) - def test_invalid_window_length(self): + def test_invalid_window_length(self) -> None: """Test if value error is raised correctly on an invalid window length.""" test_input = np.array([0, 0, 0, 0, 0, 0]) test_input = np.column_stack([test_input, test_input, test_input]) with pytest.raises(ValueError, match=r".*Invalid window length*"): find_static_samples(test_input, window_length=10, overlap=3, inactive_signal_th=0, metric="maximum") - def test_single_window_fit(self): + def test_single_window_fit(self) -> None: """Test input where only a single window length fits within input signal.""" test_input = np.array([0, 0, 0, 0, 0, 1, 1, 1]) test_input = np.column_stack([test_input, test_input, test_input]) @@ -47,7 +47,7 @@ def test_single_window_fit(self): assert min_vel_index == 2 assert miv_vel_value == 0 - def test_max_overlap_metric_max_w4_default_overlap(self): + def test_max_overlap_metric_max_w4_default_overlap(self) -> None: """Test binary input data on max metric
with window size 4.""" test_input = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1]) test_input = np.column_stack([test_input, test_input, test_input]) @@ -61,7 +61,7 @@ def test_max_overlap_metric_max_w4_default_overlap(self): assert min_vel_index == 2 assert miv_vel_value == 0 - def test_max_overlap_metric_max_w3(self): + def test_max_overlap_metric_max_w3(self) -> None: """Test binary input data on max metric with window size 3.""" test_input = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1]) test_input = np.column_stack([test_input, test_input, test_input]) @@ -75,7 +75,7 @@ def test_max_overlap_metric_max_w3(self): assert min_vel_index == 1 assert miv_vel_value == 0 - def test_max_overlap_metric_max_w6_default_overlap(self): + def test_max_overlap_metric_max_w6_default_overlap(self) -> None: """Test binary input data on max metric with window size 6.""" test_input = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1]) test_input = np.column_stack([test_input, test_input, test_input]) @@ -89,7 +89,7 @@ def test_max_overlap_metric_max_w6_default_overlap(self): assert min_vel_index == 3 assert miv_vel_value == 0 - def test_max_overlap_mean_w3_with_noise_default_overlap(self): + def test_max_overlap_mean_w3_with_noise_default_overlap(self) -> None: """Test binary input data on mean metric with window size 4 after adding a bit of noise.""" test_input = np.array([0, 0.1, 0, 0, 0.1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0.1, 0, 0, 1, 1]) test_input = np.column_stack([test_input, test_input, test_input]) @@ -104,7 +104,7 @@ def test_max_overlap_mean_w3_with_noise_default_overlap(self): assert min_vel_index == 11 assert miv_vel_value == 0 - def test_max_overlap_max_w3_with_noise(self): + def test_max_overlap_max_w3_with_noise(self) -> None: """Test binary input data on max metric with window size 4 after adding a bit of noise.""" test_input = np.array([0, 0.1, 0, 0, 0.1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0.1, 0, 0, 1, 1]) test_input = np.column_stack([test_input, test_input, test_input]) @@ -122,7 +122,7 @@ class TestFindStaticSequences: """Test the function `find_static_sequences`.""" - def test_max_overlap_metric_max_w4_default_overlap(self): + def test_max_overlap_metric_max_w4_default_overlap(self) -> None: """Test binary input data on max metric with window size 4.""" test_input = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1]) test_input = np.column_stack([test_input, test_input, test_input]) @@ -134,7 +134,7 @@ def test_max_overlap_metric_max_w4_default_overlap(self): ) assert_array_equal(test_output, expected_output) - def test_max_overlap_metric_mean_w4_default_overlap(self): + def test_max_overlap_metric_mean_w4_default_overlap(self) -> None: """Test binary input data on mean metric with window size 4.""" test_input = np.array([0, 0, 0.1, 0, 0.1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0.1, 0, 0, 1, 1]) test_input = np.column_stack([test_input, test_input, test_input]) @@ -148,7 +148,7 @@ class TestFirstStaticWindowsMultiSensor: - def test_basic_single_sensor(self): + def test_basic_single_sensor(self) -> None: test_input = np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]) test_input = np.column_stack([test_input, test_input, test_input])[:, None, :] @@ -158,7 +158,7 @@ def test_basic_single_sensor(self): ) 
assert_array_equal(test_output, (6, 10)) - def test_basic_multi_sensor(self): + def test_basic_multi_sensor(self) -> None: test_input = np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) test_input_2 = np.array([1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]) test_input = np.column_stack([test_input, test_input, test_input]) @@ -170,7 +170,7 @@ def test_basic_multi_sensor(self): ) assert_array_equal(test_output, (8, 12)) - def test_invalid_shape_sub_array(self): + def test_invalid_shape_sub_array(self) -> None: test_input = np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) test_input_2 = np.array([1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]) @@ -181,7 +181,7 @@ def test_invalid_shape_sub_array(self): ) assert "2D" in str(e) - def test_invalid_shape_np_array(self): + def test_invalid_shape_np_array(self) -> None: test_input = np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) test_input = np.column_stack([test_input, test_input, test_input]) @@ -192,7 +192,7 @@ def test_invalid_shape_np_array(self): ) assert "3D" in str(e) - def test_invalid_metric(self): + def test_invalid_metric(self) -> None: test_input = np.array([1, 1, 1, 1, 1, 1]) test_input = np.column_stack([test_input, test_input, test_input])[:, None, :] @@ -205,7 +205,7 @@ def test_invalid_metric(self): assert "metric" in str(e) - def test_no_static_window(self): + def test_no_static_window(self) -> None: test_input = np.array([1, 1, 1, 1, 1, 1]) test_input = np.column_stack([test_input, test_input, test_input])[:, None, :] diff --git a/tests/test_utils/test_stride_list_conversion.py b/tests/test_utils/test_stride_list_conversion.py index d36023d6..2e20bad7 100644 --- a/tests/test_utils/test_stride_list_conversion.py +++ b/tests/test_utils/test_stride_list_conversion.py @@ -28,9 +28,9 @@ def _create_example_stride_list(self, stride_type: str): @pytest.mark.parametrize( "stride_type", - ("segmented", "min_vel", "ic"), + ["segmented", "min_vel", "ic"], ) - def test_all_good(self, stride_type): + def test_all_good(self, stride_type) -> None: event_list = self._create_example_stride_list(stride_type) filtered_event_list, removed_strides = enforce_stride_list_consistency(event_list, stride_type) assert_frame_equal(event_list, filtered_event_list) @@ -38,9 +38,9 @@ def test_all_good(self, stride_type): @pytest.mark.parametrize( "stride_type", - ("segmented", "min_vel", "ic"), + ["segmented", "min_vel", "ic"], ) - def test_simple_error(self, stride_type): + def test_simple_error(self, stride_type) -> None: event_list = self._create_example_stride_list(stride_type) wrong_s_ids = [0, 3, 5, 19] modified = SL_EVENT_ORDER[stride_type][-1] @@ -56,9 +56,9 @@ def test_simple_error(self, stride_type): @pytest.mark.parametrize( "stride_type", - ("segmented", "min_vel", "ic"), + ["segmented", "min_vel", "ic"], ) - def test_nan_removal(self, stride_type): + def test_nan_removal(self, stride_type) -> None: """Test that strides that contain NaN in any column are removed.""" event_list = self._create_example_stride_list(stride_type) nan_s_ids = [0, 3, 5, 19] @@ -73,7 +73,7 @@ def test_nan_removal(self, stride_type): assert_frame_equal(event_list[event_list["s_id"].isin(nan_s_ids)], removed_strides) assert len(removed_strides) == len(nan_s_ids) - def test_check_stride_list(self): + def test_check_stride_list(self) -> None: stride_type = "segmented" # First use a stride list that works event_list = self._create_example_stride_list(stride_type) @@ -107,8 +107,8 @@ def _create_example_stride_list_with_pause(self): stride_list = stride_list.drop(5) return stride_list - 
@pytest.mark.parametrize("target", ("ic", "min_vel")) - def test_simple_conversion(self, target): + @pytest.mark.parametrize("target", ["ic", "min_vel"]) + def test_simple_conversion(self, target) -> None: stride_list = self._create_example_stride_list_with_pause() converted, dropped = _segmented_stride_list_to_min_vel_single_sensor(stride_list, target_stride_type=target) @@ -124,7 +124,7 @@ def test_simple_conversion(self, target): # Check that the length of all strides is still 1 assert np.all((converted["end"] - converted["start"]).round(2) == 1.0) - def test_second_to_last_stride_is_break(self): + def test_second_to_last_stride_is_break(self) -> None: """Test an edge case where there is a break right before the last stride.""" stride_list = self._create_example_stride_list_with_pause() # Drop the second to last stride to create a pause @@ -137,8 +137,8 @@ def test_second_to_last_stride_is_break(self): assert len(dropped) == 3 assert list(dropped.index) == [4, 7, 9] - @pytest.mark.parametrize("target", ("ic", "min_vel")) - def test_simple_conversion_multiple(self, target): + @pytest.mark.parametrize("target", ["ic", "min_vel"]) + def test_simple_conversion_multiple(self, target) -> None: stride_list = self._create_example_stride_list_with_pause() converted = convert_segmented_stride_list(stride_list, target_stride_type=target) @@ -155,8 +155,7 @@ def test_simple_conversion_multiple(self, target): class TestIntersectStrideList: - def test_simple_with_overlap(self): - + def test_simple_with_overlap(self) -> None: stride_list = pd.DataFrame( {"start": [10, 25, 30, 50], "end": [20, 30, 40, 55]}, index=pd.Series([0, 1, 2, 3], name="s_id") ) diff --git a/tests/test_utils/test_vector_math.py b/tests/test_utils/test_vector_math.py index 1856361f..8b78970f 100644 --- a/tests/test_utils/test_vector_math.py +++ b/tests/test_utils/test_vector_math.py @@ -27,11 +27,11 @@ class TestIsAlmostParallelOrAntiprallel: ([0, -1, 2], [0, -1, 2], True), ], ) - def test_is_almost_parallel_or_antiprallel_single_vector(self, v1, v2, result): + def test_is_almost_parallel_or_antiprallel_single_vector(self, v1, v2, result) -> None: """Test single vectors if they parallel or antiprallel.""" assert is_almost_parallel_or_antiparallel(np.array(v1), np.array(v2)) == result - def test_is_almost_parallel_or_antiprallel_multiple_vector(self): + def test_is_almost_parallel_or_antiprallel_multiple_vector(self) -> None: """Test array of vectors.""" v1 = np.repeat(np.array([1.0, 0, 0])[None, :], 4, axis=0) v2 = np.repeat(np.array([2.0, 0, 0])[None, :], 4, axis=0) @@ -45,7 +45,7 @@ class TestNormalize: def func(self, x): return normalize(x) - def test_normalize_1d_array(self): + def test_normalize_1d_array(self) -> None: """Test 1D array.""" assert_array_equal(self.func(np.array([2.0, 0, 0])), np.array([1.0, 0, 0])) @@ -53,11 +53,11 @@ def test_normalize_1d_array(self): ("v1", "v2"), [([0, 2.0, 0], [0, 1, 0]), ([2.0, 0, 0], [1.0, 0, 0]), ([0.5, 0.5, 0], [0.707107, 0.707107, 0])], ) - def test_normalize_2d_array(self, v1, v2): + def test_normalize_2d_array(self, v1, v2) -> None: """Test 2D array.""" assert_array_almost_equal(self.func(np.array(v1)), np.array(v2)) - def test_normalize_all_zeros(self): + def test_normalize_all_zeros(self) -> None: """Test vector [0, 0, 0].""" assert_array_almost_equal(self.func(np.array([0, 0, 0])), [np.nan, np.nan, np.nan]) @@ -65,7 +65,7 @@ def test_normalize_all_zeros(self): class TestFindRandomOrthogonal: """Test the function `find_random_orthogonal`.""" - def 
test_find_random_orthogonal_general(self): + def test_find_random_orthogonal_general(self) -> None: """Test `find_random_orthogonal` for a general vector.""" v = np.array([0.5, 0.2, 1]) orthogonal = find_random_orthogonal(v) @@ -75,7 +75,7 @@ def test_find_random_orthogonal_general(self): @pytest.mark.parametrize( "vec", [[1, 0, 0], [2, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, -2, 0], [0, 0, 1], [0, 0, -1], [0, 0, -2]] ) - def test_find_random_orthogonal_special(self, vec): + def test_find_random_orthogonal_special(self, vec) -> None: """Test `find_random_orthogonal` for vectors parallel or antiparallel to [1,0,0], [0,1,0], [0,0,1].""" v = np.array(vec) orthogonal = find_random_orthogonal(v) @@ -96,7 +96,7 @@ class TestFindOrthogonal: ([1, 0.2, 1], [4, 1.2, 0]), ], ) - def test_find_orthogonal(self, v1, v2): + def test_find_orthogonal(self, v1, v2) -> None: """Test `find_orthogonal` for 1D vectors.""" v1 = np.array(v1) v2 = np.array(v2) @@ -105,7 +105,7 @@ def test_find_orthogonal(self, v1, v2): assert_almost_equal(np.dot(orthogonal, v2), 0) assert_almost_equal(norm(orthogonal), 1) - def test_find_orthogonal_array(self): + def test_find_orthogonal_array(self) -> None: """Test `find_orthogonal` for multidimensional vectors.""" v1 = np.array(4 * [[1, 0, 0]]) v2 = np.array(4 * [[0, 1, 0]]) diff --git a/tests/test_zupt_detection/test_combo_zupt_detector.py b/tests/test_zupt_detection/test_combo_zupt_detector.py index 2aed6519..6d7c2424 100644 --- a/tests/test_zupt_detection/test_combo_zupt_detector.py +++ b/tests/test_zupt_detection/test_combo_zupt_detector.py @@ -24,7 +24,7 @@ def after_action_instance(self, healthy_example_imu_data): class DummyZuptDetector(BaseZuptDetector, RegionZuptDetectorMixin): - def __init__(self, zupts): + def __init__(self, zupts) -> None: self.zupts = zupts def detect(self, data, sampling_rate_hz, **kwargs): @@ -35,11 +35,11 @@ class TestComboZuptDetector: @pytest.mark.parametrize("detector_list", [None, []]) - def test_empty_detector_list(self, detector_list): + def test_empty_detector_list(self, detector_list) -> None: with pytest.raises(ValueError): ComboZuptDetector(detector_list).detect(pd.DataFrame(), sampling_rate_hz=1) - def test_kwargs_forwarded(self): + def test_kwargs_forwarded(self) -> None: class MockZUPTDetector(BaseZuptDetector): zupts_ = pd.DataFrame(columns=["start", "end"]) per_sample_zupts_ = np.zeros(10) @@ -53,7 +53,7 @@ class MockZUPTDetector(BaseZuptDetector): for call in mock_detect.call_args_list: assert call.kwargs["foo"] == "bar" - def test_dummy(self): + def test_dummy(self) -> None: zupts = pd.DataFrame( [[0, 10], [30, 55], [85, 90]], columns=["start", "end"], @@ -62,7 +62,7 @@ def test_dummy(self): test.detect(pd.DataFrame(np.zeros(100)), sampling_rate_hz=1) assert_array_equal(test.zupts_, zupts) - def test_empty_data_edge_case(self): + def test_empty_data_edge_case(self) -> None: zupts = pd.DataFrame( [[0, 10], [30, 55], [85, 90]], columns=["start", "end"], @@ -71,7 +71,7 @@ def test_empty_data_edge_case(self): test.detect(pd.DataFrame(), sampling_rate_hz=1) assert_array_equal(test.zupts_, pd.DataFrame(columns=["start", "end"])) - def test_combine_with_or(self): + def test_combine_with_or(self) -> None: zupts_a = pd.DataFrame( [[0, 10], [30, 55], [85, 90]], columns=["start", "end"], @@ -86,7 +86,7 @@ def test_combine_with_or(self): test.zupts_, pd.DataFrame([[0, 10], [30, 60], [85, 90], [95, 100]], columns=["start", "end"]) ) - def test_combine_with_and(self): + def test_combine_with_and(self) 
-> None: zupts_a = pd.DataFrame( [[0, 10], [30, 55], [85, 90]], columns=["start", "end"], diff --git a/tests/test_zupt_detection/test_moving_window_zupt_detector.py b/tests/test_zupt_detection/test_moving_window_zupt_detector.py index d2886b95..79d397a6 100644 --- a/tests/test_zupt_detection/test_moving_window_zupt_detector.py +++ b/tests/test_zupt_detection/test_moving_window_zupt_detector.py @@ -49,11 +49,11 @@ class TestNormZuptDetector: algorithm_class: Union[Type[NormZuptDetector], Type[AredZuptDetector]] @pytest.fixture(params=(NormZuptDetector, AredZuptDetector), autouse=True) - def get_algorithm_class(self, request): + def get_algorithm_class(self, request) -> None: self.algorithm_class = request.param - @pytest.mark.parametrize(("ws", "sr"), ((1, 1), (1, 2), (2, 1), (2.49, 1))) - def test_error_window_to_small(self, healthy_example_imu_data, ws, sr): + @pytest.mark.parametrize(("ws", "sr"), [(1, 1), (1, 2), (2, 1), (2.49, 1)]) + def test_error_window_to_small(self, healthy_example_imu_data, ws, sr) -> None: with pytest.raises(ValidationError, match=r".*The effective window size is smaller*"): self.algorithm_class(window_length_s=ws).detect( healthy_example_imu_data["left_sensor"], sampling_rate_hz=sr @@ -61,7 +61,7 @@ def test_error_window_to_small(self, healthy_example_imu_data, ws, sr): @pytest.mark.parametrize( ("overlap1", "overlap2", "valid"), - ( + [ (0.5, 0.5, False), (1, None, False), (-0.5, None, False), @@ -71,21 +71,20 @@ def test_error_window_to_small(self, healthy_example_imu_data, ws, sr): (None, 200, False), (None, 201, False), (None, 0.5, False), - (None, 200, False), (None, 1, True), (None, -1, True), (None, 199, True), - ), + ], ) - def test_wrong_window_overlap(self, overlap1, overlap2, valid, healthy_example_imu_data): + def test_wrong_window_overlap(self, overlap1, overlap2, valid, healthy_example_imu_data) -> None: context = pytest.raises(ValidationError) if not valid else nullcontext() with context: self.algorithm_class(window_length_s=1, window_overlap=overlap1, window_overlap_samples=overlap2).detect( healthy_example_imu_data["left_sensor"], sampling_rate_hz=200 ) - @pytest.mark.parametrize(("win_len", "overlap", "valid"), ((1, 0, True), (10 / 200, 0.99, False))) - def test_effective_overlap_error(self, win_len, overlap, valid, healthy_example_imu_data): + @pytest.mark.parametrize(("win_len", "overlap", "valid"), [(1, 0, True), (10 / 200, 0.99, False)]) + def test_effective_overlap_error(self, win_len, overlap, valid, healthy_example_imu_data) -> None: if not valid: with pytest.raises(ValidationError, match=r".*The effective window overlap after rounding is 1"): self.algorithm_class( @@ -98,16 +97,16 @@ def test_effective_overlap_error(self, win_len, overlap, valid, healthy_example_ @pytest.mark.parametrize( ("sensor", "axis", "valid"), - ( + [ ("acc", SF_ACC, True), ("acc", BF_ACC, True), ("acc", BF_GYR, False), ("gyr", BF_GYR, True), ("gyr", SF_GYR, True), ("gyr", SF_ACC, False), - ), + ], ) - def test_single_sensor_data(self, sensor, axis, valid): + def test_single_sensor_data(self, sensor, axis, valid) -> None: data = pd.DataFrame(np.empty((10, len(axis))), columns=axis) if valid is False: with pytest.raises(ValidationError): @@ -115,7 +114,7 @@ def test_single_sensor_data(self, sensor, axis, valid): else: self.algorithm_class(sensor=sensor, window_length_s=0.5).detect(data, sampling_rate_hz=10) - def test_debug_outputs(self): + def test_debug_outputs(self) -> None: data = pd.DataFrame(np.empty((10, 3)), columns=BF_GYR) zupt = self.algorithm_class( 
sensor="gyr", window_length_s=0.5, window_overlap=0.2, window_overlap_samples=None @@ -123,14 +122,14 @@ def test_debug_outputs(self): assert zupt.window_length_samples_ == 5 assert zupt.window_overlap_samples_ == 1 - def test_invalid_input_metric_default_overlap(self, healthy_example_imu_data): + def test_invalid_input_metric_default_overlap(self, healthy_example_imu_data) -> None: """Test if value error is raised correctly on invalid input dimensions.""" with pytest.raises(ValueError, match=r".*Invalid metric passed!.*"): self.algorithm_class(metric="invalid").detect( healthy_example_imu_data["left_sensor"], sampling_rate_hz=204.8 ) - def test_invalid_window_length(self, healthy_example_imu_data): + def test_invalid_window_length(self, healthy_example_imu_data) -> None: """Test if value error is raised when window longer than signal.""" with pytest.raises(ValueError, match=r".*Invalid window length*"): self.algorithm_class(window_length_s=1000).detect( @@ -139,23 +138,23 @@ def test_invalid_window_length(self, healthy_example_imu_data): @pytest.mark.parametrize( ("win_overlap_samples", "expected"), - ( + [ (50, 50), (-1, 99), (-10, 90), (0, 0), (-100, 0), (1, 1), - ), + ], ) - def test_sample_win_overlap(self, win_overlap_samples, expected, healthy_example_imu_data): + def test_sample_win_overlap(self, win_overlap_samples, expected, healthy_example_imu_data) -> None: """Test that window overlap is correctly calculated when provided in samples.""" out = self.algorithm_class( window_length_s=100, window_overlap_samples=win_overlap_samples, window_overlap=None ).detect(healthy_example_imu_data["left_sensor"].iloc[:500], sampling_rate_hz=1.0) assert out.window_overlap_samples_ == expected - def test_single_window_fit(self): + def test_single_window_fit(self) -> None: """Test input where only a single window length fits within input signal.""" test_input = np.array([0, 0, 0, 0, 0, 1, 1, 1]) test_input = pd.DataFrame(np.column_stack([test_input, test_input, test_input]), columns=SF_GYR) @@ -170,7 +169,7 @@ def test_single_window_fit(self): assert test_output.min_vel_value_ == 0.0 assert test_output.min_vel_index_ == 2 - def test_max_overlap_metric_max_w4_default_overlap(self): + def test_max_overlap_metric_max_w4_default_overlap(self) -> None: """Test binary input data on max metric with window size 4.""" test_input = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1]) test_input = pd.DataFrame(np.column_stack([test_input, test_input, test_input]), columns=SF_GYR) @@ -190,7 +189,7 @@ def test_max_overlap_metric_max_w4_default_overlap(self): assert test_output.min_vel_value_ == 0.0 assert test_output.min_vel_index_ == 2 - def test_max_overlap_metric_max_w3(self): + def test_max_overlap_metric_max_w3(self) -> None: """Test binary input data on max metric with window size 3.""" test_input = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1]) test_input = pd.DataFrame(np.column_stack([test_input, test_input, test_input]), columns=SF_GYR) @@ -210,7 +209,7 @@ def test_max_overlap_metric_max_w3(self): assert test_output.min_vel_value_ == 0.0 assert test_output.min_vel_index_ == 1 - def test_max_overlap_metric_max_w6_default_overlap(self): + def test_max_overlap_metric_max_w6_default_overlap(self) -> None: """Test binary input data on max metric with window size 6.""" test_input = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1]) test_input = pd.DataFrame(np.column_stack([test_input, test_input, test_input]), 
columns=SF_GYR) @@ -230,7 +229,7 @@ def test_max_overlap_metric_max_w6_default_overlap(self): assert test_output.min_vel_value_ == 0.0 assert test_output.min_vel_index_ == 3 - def test_max_overlap_max_w3_with_noise_default_overlap(self): + def test_max_overlap_max_w3_with_noise_default_overlap(self) -> None: """Test binary input data on mean metric with window size 4 after adding a bit of noise.""" test_input = np.array([0, 0.1, 0, 0, 0.1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0.1, 0, 0, 1, 1]) test_input = pd.DataFrame(np.column_stack([test_input, test_input, test_input]), columns=SF_GYR) @@ -248,7 +247,7 @@ def test_max_overlap_max_w3_with_noise_default_overlap(self): assert test_output.min_vel_value_ == 0.0 assert test_output.min_vel_index_ == 11 - def test_max_overlap_max_w3_with_noise(self): + def test_max_overlap_max_w3_with_noise(self) -> None: """Test binary input data on max metric with window size 4 after adding a bit of noise.""" test_input = np.array([0, 0.1, 0, 0, 0.1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0.1, 0, 0, 1, 1]) test_input = pd.DataFrame(np.column_stack([test_input, test_input, test_input]), columns=SF_GYR) @@ -266,7 +265,7 @@ def test_max_overlap_max_w3_with_noise(self): assert test_output.min_vel_value_ == 0.0 assert test_output.min_vel_index_ == 11 - def test_no_zupt_detected(self): + def test_no_zupt_detected(self) -> None: test_input = np.array( [0.2, 0.4, 0.2, 0.2, 0.4, 0.2, 1, 1, 1, 1, 0.2, 0.2, 0.2, 1, 1, 1, 0.2, 0.2, 0.4, 0.2, 0.2, 1, 1] ) @@ -285,7 +284,7 @@ def test_no_zupt_detected(self): assert np.isnan(test_output.min_vel_value_) assert test_output.min_vel_index_ == 11 - def test_real_data_regression(self, healthy_example_imu_data, snapshot): + def test_real_data_regression(self, healthy_example_imu_data, snapshot) -> None: """Test real data with default parameters.""" test_output = self.algorithm_class().detect(healthy_example_imu_data["left_sensor"], sampling_rate_hz=204.8) snapshot.assert_match(test_output.zupts_) @@ -294,12 +293,12 @@ def test_real_data_regression(self, healthy_example_imu_data, snapshot): class TestShoeZuptDetector: # Note: We don't retest all the validation here, as this is basically identical to the other ZUPT detectors. 
- def test_real_data_regression(self, healthy_example_imu_data, snapshot): + def test_real_data_regression(self, healthy_example_imu_data, snapshot) -> None: """Test real data with default parameters.""" test_output = ShoeZuptDetector().detect(healthy_example_imu_data["left_sensor"], sampling_rate_hz=204.8) snapshot.assert_match(test_output.zupts_) - def test_no_zupt_detected(self): + def test_no_zupt_detected(self) -> None: test_input = np.array( [0.2, 0.4, 0.2, 0.2, 0.4, 0.2, 1, 1, 1, 1, 0.2, 0.2, 0.2, 1, 1, 1, 0.2, 0.2, 0.4, 0.2, 0.2, 1, 1] ) diff --git a/tests/test_zupt_detection/test_stride_event_zupt_detector.py b/tests/test_zupt_detection/test_stride_event_zupt_detector.py index 2662a060..31d7c360 100644 --- a/tests/test_zupt_detection/test_stride_event_zupt_detector.py +++ b/tests/test_zupt_detection/test_stride_event_zupt_detector.py @@ -25,7 +25,7 @@ def after_action_instance(self, healthy_example_imu_data): class TestStrideEventZuptDetector: - def test_improper_stride_list(self): + def test_improper_stride_list(self) -> None: with pytest.raises(ValidationError): StrideEventZuptDetector().detect( pd.DataFrame([[0, 0, 0, 0, 0, 0]] * 10, columns=SF_COLS), @@ -34,7 +34,7 @@ def test_improper_stride_list(self): stride_event_list=pd.DataFrame([[0, 5]], columns=["start", "end"]), ) - def test_region_0(self): + def test_region_0(self) -> None: stride_event_list = pd.DataFrame( [[0, 7, 0], [5, 10, 5]], columns=["start", "end", "min_vel"], index=pd.Series([0, 1], name="s_id") ) @@ -50,7 +50,7 @@ def test_region_0(self): zupts, pd.DataFrame([[0, 1], [5, 6], [7, 8], [10, 11]], columns=["start", "end"], dtype="Int64") ) - def test_edge_case(self): + def test_edge_case(self) -> None: """We test what happens if the zupt is exactly the first or last sample of the data or outside the range.""" stride_event_list = pd.DataFrame( [[0, 10, 0], [10, 15, 10]], columns=["start", "end", "min_vel"], index=pd.Series([0, 1], name="s_id") @@ -65,7 +65,7 @@ def test_edge_case(self): assert_frame_equal(zupts, pd.DataFrame([[0, 1]], columns=["start", "end"], dtype="Int64")) assert detector.half_region_size_samples_ == 0 - def test_with_overlap(self): + def test_with_overlap(self) -> None: stride_event_list = pd.DataFrame( [[0, 7, 0], [5, 10, 5]], columns=["start", "end", "min_vel"], index=pd.Series([0, 1], name="s_id") ) @@ -78,7 +78,7 @@ def test_with_overlap(self): assert_frame_equal(zupts, pd.DataFrame([[0, 11]], columns=["start", "end"], dtype="Int64")) assert detector.half_region_size_samples_ == 2 - def test_simple(self): + def test_simple(self) -> None: stride_event_list = pd.DataFrame( [[0, 5, 0], [10, 15, 10]], columns=["start", "end", "min_vel"], index=pd.Series([0, 1], name="s_id") )