Skip to content

Commit 29b2062

Browse files
flake (#316)
* flake * update version
1 parent cef3f92 commit 29b2062

29 files changed

+124
-96
lines changed

lightning_pose/api/model.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
from lightning_pose.data.datatypes import MultiviewPredictionResult, PredictionResult
1111
from lightning_pose.models import ALLOWED_MODELS
1212
from lightning_pose.utils import io as io_utils
13-
from lightning_pose.utils.predictions import (
14-
generate_labeled_video as generate_labeled_video_fn,
15-
)
13+
from lightning_pose.utils.predictions import generate_labeled_video as generate_labeled_video_fn
1614
from lightning_pose.utils.predictions import (
1715
load_model_from_checkpoint,
1816
predict_dataset,
@@ -131,15 +129,17 @@ def predict_on_label_csv(
131129
132130
Args:
133131
csv_file (str | Path): Path to the CSV file of images, keypoint locations.
134-
data_dir (str | Path, optional): Root path for relative paths in the CSV file. Defaults to the
135-
data_dir originally used when training.
132+
data_dir (str | Path, optional): Root path for relative paths in the CSV file.
133+
Defaults to the data_dir originally used when training.
136134
compute_metrics (bool, optional): Whether to compute pixel error and loss metrics on
137135
predictions.
138-
generate_labeled_images (bool, optional): Whether to save labeled images. Defaults to False.
136+
generate_labeled_images (bool, optional): Whether to save labeled images.
137+
Defaults to False.
139138
output_dir (str | Path, optional): The directory to save outputs to.
140-
Defaults to `{model_dir}/image_preds/{csv_file_name}`. If set to None, outputs are not saved.
141-
add_train_val_test_set (bool): When predicting on training dataset, set to true to add the `set`
142-
column to the prediction output.
139+
Defaults to `{model_dir}/image_preds/{csv_file_name}`.
140+
If set to None, outputs are not saved.
141+
add_train_val_test_set (bool): When predicting on training dataset, set to true to add
142+
the `set` column to the prediction output.
143143
Returns:
144144
PredictionResult: A PredictionResult object containing the predictions and metrics.
145145
"""
@@ -218,13 +218,17 @@ def predict_on_video_file(
218218
219219
Args:
220220
video_file (str | Path): Path to the video file.
221+
output_dir (str | Path, optional): The directory to save outputs to.
222+
Defaults to `{model_dir}/image_preds/{csv_file_name}`.
223+
If set to None, outputs are not saved.
221224
compute_metrics (bool, optional): Whether to compute pixel error and loss metrics on
222225
predictions.
223-
generate_labeled_video (bool, optional): Whether to save a labeled video. Defaults to False.
224-
output_dir (str | Path, optional): The directory to save outputs to.
225-
Defaults to `{model_dir}/image_preds/{csv_file_name}`. If set to None, outputs are not saved.
226+
generate_labeled_video (bool, optional): Whether to save a labeled video.
227+
Defaults to False.
228+
226229
Returns:
227230
PredictionResult: A PredictionResult object containing the predictions and metrics.
231+
228232
"""
229233
self._load()
230234
video_file = Path(video_file)
@@ -278,15 +282,26 @@ def predict_on_video_file_multiview(
278282
compute_metrics: bool = True,
279283
generate_labeled_video: bool = False,
280284
) -> MultiviewPredictionResult:
281-
"""Version of `predict_on_video_file` that gives models access to multiple camera views of each frame.
285+
"""Version of `predict_on_video_file` that accesses to multiple camera views of each frame.
282286
283287
Arguments:
284-
video_file_per_view (list[str] | list[Path]): A list of video files each from a different view of the
285-
same session. Number of video files must match the `view_names` in the config file. Order of the list
286-
does not matter: video files are intelligently matched to views by their filename using
287-
`utils.io.collect_video_files_by_view`.
288+
video_file_per_view (list[str] | list[Path]): A list of video files each from a
289+
different view of the same session.
290+
Number of video files must match the `view_names` in the config file.
291+
Order of the list does not matter: video files are intelligently matched to views
292+
by their filename using `utils.io.collect_video_files_by_view`.
293+
output_dir (str | Path, optional): The directory to save outputs to.
294+
Defaults to `{model_dir}/image_preds/{csv_file_name}`.
295+
If set to None, outputs are not saved.
296+
compute_metrics (bool, optional): Whether to compute pixel error and loss metrics on
297+
predictions.
298+
generate_labeled_video (bool, optional): Whether to save a labeled video.
299+
Defaults to False.
288300
289-
See `predict_on_video_file` docstring for other arguments."""
301+
Returns:
302+
MultiviewPredictionResult: object containing the predictions and metrics for each view.
303+
304+
"""
290305
assert self.config.is_multi_view()
291306
self._load()
292307

lightning_pose/api/model_config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ def is_multi_view(self):
3535
)
3636
return True
3737

38-
## Eval ##
39-
4038
def test_video_files(self) -> list[Path]:
4139
files = check_video_paths(
4240
return_absolute_path(self.cfg.eval.test_videos_directory)

lightning_pose/cli/commands/crop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def register_parser(subparsers):
2424
"""\
2525
Crops a video or labeled frames based on model predictions.
2626
Requires model predictions to already have been generated using `litpose predict`.
27-
27+
2828
Cropped videos are saved to:
2929
<model_dir>/
3030
└── video_preds/
@@ -45,7 +45,7 @@ def register_parser(subparsers):
4545
└── a/b/c/<image_name>.png (cropped images)\
4646
"""
4747
),
48-
usage="litpose crop <model_dir> <input_path:video|csv>... --crop_ratio=CROP_RATIO --anchor_keypoints=x,y,z",
48+
usage="litpose crop <model_dir> <input_path:video|csv>... --crop_ratio=CROP_RATIO --anchor_keypoints=x,y,z", # noqa
4949
)
5050
crop_parser.add_argument(
5151
"model_dir", type=types.existing_model_dir, help="path to a model directory"

lightning_pose/cli/commands/predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def register_parser(subparsers):
4242
"input_path",
4343
type=Path,
4444
nargs="+",
45-
help="one or more paths. They can be video files, image files, CSV files, or directories.\n"
45+
help="one or more paths; can be video files, image files, CSV files, or directories.\n"
4646
" directory: predicts over videos or images in the directory.\n"
4747
" saves image outputs to `image_preds/<directory_name>`\n"
4848
" video file: predicts on the video\n"

lightning_pose/cli/commands/remap.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from pathlib import Path
44
from textwrap import dedent
55

6-
from .. import types
7-
86

97
def register_parser(subparsers):
108
"""Register the remap command parser."""
@@ -15,7 +13,8 @@ def register_parser(subparsers):
1513
Remaps predictions from cropped to original coordinate space.
1614
Requires model predictions to already have been generated using `litpose predict`.
1715
18-
Remapped predictions are saved as "remapped_{preds_file}" in the same folder as preds_file.
16+
Remapped predictions are saved as "remapped_{preds_file}" in the same folder as
17+
preds_file.
1918
"""
2019
),
2120
usage="litpose remap <preds_file> <bbox_file>",

lightning_pose/cli/commands/run_app.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ def handle(args):
2727
import importlib.util
2828
if not importlib.util.find_spec('litpose_app'):
2929
import sys
30-
print("❌ App not installed. To install the app:\n\n pip install lightning-pose-app\n", file=sys.stderr)
30+
print(
31+
"❌ App not installed. To install the app:\n\n pip install lightning-pose-app\n",
32+
file=sys.stderr,
33+
)
3134
sys.exit(1)
3235

3336
# Import lightning_pose modules only when needed

lightning_pose/cli/commands/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def register_parser(subparsers):
2929
type=types.config_file,
3030
help="path a config file.\n"
3131
"Download and modify the config template from: \n"
32-
"https://github.com/paninski-lab/lightning-pose/blob/main/scripts/configs/config_default.yaml",
32+
"https://github.com/paninski-lab/lightning-pose/blob/main/scripts/configs/config_default.yaml", # noqa
3333
)
3434
train_parser.add_argument(
3535
"--output_dir",

lightning_pose/data/augmentations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ def imgaug_transform(params_dict: dict | DictConfig) -> iaa.Sequential:
5050
>>> angle: [-90, 90]
5151
5252
Create a pipeline with
53-
- Rot90 transformation that is applied 100% of the time with rotations of 0, 90, 180, or 270 degrees.
53+
- Rot90 transformation applied 100% of the time with rotations of 0, 90, 180, 270 degrees.
5454
5555
>>> params_dict = {
56-
>>> 'Rot90': {'p': 1.0, 'kwargs': {'k': [[0, 1, 2, 3]]}}, # note the (required) nested list
56+
>>> 'Rot90': {'p': 1.0, 'kwargs': {'k': [[0, 1, 2, 3]]}}, # note required nested list
5757
>>> }
5858
5959
In a config file, this will look like:
@@ -64,8 +64,8 @@ def imgaug_transform(params_dict: dict | DictConfig) -> iaa.Sequential:
6464
>>> kwargs:
6565
>>> k: [0, 1, 2, 3]
6666
67-
NOTE: if you pass a list of exactly 2 values to Rot90 it will be parsed as a tuple and all (discrete) rotations
68-
between the two values will be sampled uniformly.
67+
NOTE: if you pass a list of exactly 2 values to Rot90 it will be parsed as a tuple and all
68+
(discrete) rotations between the two values will be sampled uniformly.
6969
For example, `k: [0, 2]` is equivalent to `k: [0, 1, 2]`.
7070
If you need to _only_ sample two non-contiguous integers please raise an issue.
7171

lightning_pose/data/datamodules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,14 @@ def _setup(self) -> None:
126126
self.val_dataset.dataset.imgaug_transform = resize_transform
127127
if hasattr(self.val_dataset.dataset, "dataset"):
128128
# this will get triggered for multiview datasets
129-
print(f"val: updating children datasets with resize imgaug pipeline")
129+
print("val: updating children datasets with resize imgaug pipeline")
130130
for view_name, dset in self.val_dataset.dataset.dataset.items():
131131
dset.imgaug_transform = resize_transform
132132

133133
self.test_dataset.dataset.imgaug_transform = resize_transform
134134
if hasattr(self.test_dataset.dataset, "dataset"):
135135
# this will get triggered for multiview datasets
136-
print(f"test: updating children datasets with resize imgaug pipeline")
136+
print("test: updating children datasets with resize imgaug pipeline")
137137
for view_name, dset in self.test_dataset.dataset.dataset.items():
138138
dset.imgaug_transform = resize_transform
139139

lightning_pose/data/datasets.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ def __init__(
9595

9696
csv_data = pd.read_csv(csv_file, header=header_rows, index_col=0)
9797
csv_data = io_utils.fix_empty_first_row(csv_data)
98-
self.keypoint_names = io_utils.get_keypoint_names(csv_file=csv_file, header_rows=header_rows)
98+
self.keypoint_names = io_utils.get_keypoint_names(
99+
csv_file=csv_file, header_rows=header_rows,
100+
)
99101
self.image_names = list(csv_data.index)
100102
self.keypoints = torch.tensor(csv_data.to_numpy(), dtype=torch.float32)
101103
# convert to x,y coordinates
@@ -452,7 +454,7 @@ def check_data_images_names(self):
452454
img_file_names.add(Path(heatmaps.image_names[idx]).name)
453455
if len(img_file_names) > 1:
454456
raise ImportError(
455-
f"Discrepancy in image file names across CSV files! "
457+
"Discrepancy in image file names across CSV files! "
456458
"index:{idx}, image file names:{img_file_names}"
457459
)
458460

0 commit comments

Comments
 (0)