|
10 | 10 | from lightning_pose.data.datatypes import MultiviewPredictionResult, PredictionResult
|
11 | 11 | from lightning_pose.models import ALLOWED_MODELS
|
12 | 12 | from lightning_pose.utils import io as io_utils
|
13 |
| -from lightning_pose.utils.predictions import ( |
14 |
| - generate_labeled_video as generate_labeled_video_fn, |
15 |
| -) |
| 13 | +from lightning_pose.utils.predictions import generate_labeled_video as generate_labeled_video_fn |
16 | 14 | from lightning_pose.utils.predictions import (
|
17 | 15 | load_model_from_checkpoint,
|
18 | 16 | predict_dataset,
|
@@ -131,15 +129,17 @@ def predict_on_label_csv(
|
131 | 129 |
|
132 | 130 | Args:
|
133 | 131 | csv_file (str | Path): Path to the CSV file of images, keypoint locations.
|
134 |
| - data_dir (str | Path, optional): Root path for relative paths in the CSV file. Defaults to the |
135 |
| - data_dir originally used when training. |
| 132 | + data_dir (str | Path, optional): Root path for relative paths in the CSV file. |
| 133 | + Defaults to the data_dir originally used when training. |
136 | 134 | compute_metrics (bool, optional): Whether to compute pixel error and loss metrics on
|
137 | 135 | predictions.
|
138 |
| - generate_labeled_images (bool, optional): Whether to save labeled images. Defaults to False. |
| 136 | + generate_labeled_images (bool, optional): Whether to save labeled images. |
| 137 | + Defaults to False. |
139 | 138 | output_dir (str | Path, optional): The directory to save outputs to.
|
140 |
| - Defaults to `{model_dir}/image_preds/{csv_file_name}`. If set to None, outputs are not saved. |
141 |
| - add_train_val_test_set (bool): When predicting on training dataset, set to true to add the `set` |
142 |
| - column to the prediction output. |
| 139 | + Defaults to `{model_dir}/image_preds/{csv_file_name}`. |
| 140 | + If set to None, outputs are not saved. |
| 141 | + add_train_val_test_set (bool): When predicting on training dataset, set to true to add |
| 142 | + the `set` column to the prediction output. |
143 | 143 | Returns:
|
144 | 144 | PredictionResult: A PredictionResult object containing the predictions and metrics.
|
145 | 145 | """
|
@@ -218,13 +218,17 @@ def predict_on_video_file(
|
218 | 218 |
|
219 | 219 | Args:
|
220 | 220 | video_file (str | Path): Path to the video file.
|
| 221 | + output_dir (str | Path, optional): The directory to save outputs to. |
| 222 | + Defaults to `{model_dir}/image_preds/{csv_file_name}`. |
| 223 | + If set to None, outputs are not saved. |
221 | 224 | compute_metrics (bool, optional): Whether to compute pixel error and loss metrics on
|
222 | 225 | predictions.
|
223 |
| - generate_labeled_video (bool, optional): Whether to save a labeled video. Defaults to False. |
224 |
| - output_dir (str | Path, optional): The directory to save outputs to. |
225 |
| - Defaults to `{model_dir}/image_preds/{csv_file_name}`. If set to None, outputs are not saved. |
| 226 | + generate_labeled_video (bool, optional): Whether to save a labeled video. |
| 227 | + Defaults to False. |
| 228 | +
|
226 | 229 | Returns:
|
227 | 230 | PredictionResult: A PredictionResult object containing the predictions and metrics.
|
| 231 | +
|
228 | 232 | """
|
229 | 233 | self._load()
|
230 | 234 | video_file = Path(video_file)
|
@@ -278,15 +282,26 @@ def predict_on_video_file_multiview(
|
278 | 282 | compute_metrics: bool = True,
|
279 | 283 | generate_labeled_video: bool = False,
|
280 | 284 | ) -> MultiviewPredictionResult:
|
281 |
| - """Version of `predict_on_video_file` that gives models access to multiple camera views of each frame. |
| 285 | + """Version of `predict_on_video_file` that accesses to multiple camera views of each frame. |
282 | 286 |
|
283 | 287 | Arguments:
|
284 |
| - video_file_per_view (list[str] | list[Path]): A list of video files each from a different view of the |
285 |
| - same session. Number of video files must match the `view_names` in the config file. Order of the list |
286 |
| - does not matter: video files are intelligently matched to views by their filename using |
287 |
| - `utils.io.collect_video_files_by_view`. |
| 288 | + video_file_per_view (list[str] | list[Path]): A list of video files each from a |
| 289 | + different view of the same session. |
| 290 | + Number of video files must match the `view_names` in the config file. |
| 291 | + Order of the list does not matter: video files are intelligently matched to views |
| 292 | + by their filename using `utils.io.collect_video_files_by_view`. |
| 293 | + output_dir (str | Path, optional): The directory to save outputs to. |
| 294 | + Defaults to `{model_dir}/image_preds/{csv_file_name}`. |
| 295 | + If set to None, outputs are not saved. |
| 296 | + compute_metrics (bool, optional): Whether to compute pixel error and loss metrics on |
| 297 | + predictions. |
| 298 | + generate_labeled_video (bool, optional): Whether to save a labeled video. |
| 299 | + Defaults to False. |
288 | 300 |
|
289 |
| - See `predict_on_video_file` docstring for other arguments.""" |
| 301 | + Returns: |
| 302 | + MultiviewPredictionResult: object containing the predictions and metrics for each view. |
| 303 | +
|
| 304 | + """ |
290 | 305 | assert self.config.is_multi_view()
|
291 | 306 | self._load()
|
292 | 307 |
|
|
0 commit comments