Skip to content

Commit cd10c87

Browse files
[docs] update config file info (#147)
* [docs] update min/max epochs description * [docs] add info about starting model training from a checkpoint
1 parent 9a5a765 commit cd10c87

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

docs/source/user_guide/config_file.rst

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,31 @@ The config file contains several sections:
2424
Data parameters
2525
===============
2626

27-
* ``data.image_orig_dims.height/width``: the current version of Lightning Pose requires all training images to be the same size. We are working on an updated version without this requirement. However, if you plan to use the PCA losses (Pose PCA or multiview PCA) then all training images **must** be the same size, otherwise the PCA subspace will erroneously contain variance related to image size.
28-
29-
* ``data.image_resize_dims.height/width``: images (and videos) will be resized to the specified height and width before being processed by the network. Supported values are {64, 128, 256, 384, 512}. The height and width need not be identical. Some points to keep in mind when selecting these values: if the resized images are too small, you will lose resolution/details; if they are too large, the model takes longer to train and might not train as well.
27+
* ``data.image_orig_dims.height/width``: the current version of Lightning Pose requires all
28+
training images to be the same size.
29+
We are working on an updated version without this requirement.
30+
However, if you plan to use the PCA losses (Pose PCA or multiview PCA) then all training images
31+
**must** be the same size, otherwise the PCA subspace will erroneously contain variance related
32+
to image size.
33+
34+
* ``data.image_resize_dims.height/width``: images (and videos) will be resized to the specified
35+
height and width before being processed by the network.
36+
Supported values are {64, 128, 256, 384, 512}.
37+
The height and width need not be identical.
38+
Some points to keep in mind when selecting these values:
39+
if the resized images are too small, you will lose resolution/details;
40+
if they are too large, the model takes longer to train and might not train as well.
3041

3142
* ``data.data_dir/video_dir``: update these to reflect your local paths
3243

33-
* ``data.num_keypoints``: the number of body parts. If using a mirrored setup, this should be the number of body parts summed across all views. If using a multiview setup, this number should indicate the number of keyponts per view (must be the same across all views).
44+
* ``data.num_keypoints``: the number of body parts.
45+
If using a mirrored setup, this should be the number of body parts summed across all views.
46+
If using a multiview setup, this number should indicate the number of keyponts per view
47+
(must be the same across all views).
3448

35-
* ``data.keypoint_names``: keypoint names should reflect the actual names/order in the csv file. This field is necessary if, for example, you are running inference on a machine that does not have the training data saved on it.
49+
* ``data.keypoint_names``: keypoint names should reflect the actual names/order in the csv file.
50+
This field is necessary if, for example, you are running inference on a machine that does not
51+
have the training data saved on it.
3652

3753
* ``data.columns_for_singleview_pca``: see the :ref:`Pose PCA documentation <unsup_loss_pcasv>`
3854

@@ -45,19 +61,36 @@ Model/training parameters
4561
Below is a list of some commonly modified arguments related to model architecture/training.
4662

4763
* ``training.train_batch_size``: batch size for labeled data
48-
* ``training.min_epochs`` / ``training.max_epochs``: length of training
64+
65+
* ``training.min_epochs`` / ``training.max_epochs``: length of training.
66+
An epoch is one full pass through the dataset.
67+
As an example, if you have 400 labeled frames, and ``training.train_batch_size=10``, then your
68+
dataset is divided into 400/10 = 40 batches.
69+
One "batch" in this case is equivalent to one "iteration" for DeepLabCut.
70+
Therefore, 300 epochs, at 40 batches per epoch, is equal to 300*40=12k total batches
71+
(or iterations).
72+
4973
* ``model.model_type``:
5074

5175
* regression: model directly outputs an (x, y) prediction for each keypoint; not recommended
5276
* heatmap: model outputs a 2D heatmap for each keypoint
53-
* heatmap_mhcrnn: the "multi-head convolutional RNN", this model takes a temporal window of frames as input, and outputs two heatmaps: one "context-aware" and one "static". The prediction with the highest confidence is automatically chosen.
77+
* heatmap_mhcrnn: the "multi-head convolutional RNN", this model takes a temporal window of
78+
frames as input, and outputs two heatmaps: one "context-aware" and one "static".
79+
The prediction with the highest confidence is automatically chosen.
5480

55-
* ``model.losses_to_use``: defines the unsupervised losses. An empty list indicates a fully supervised model. Each element of the list corresponds to an unsupervised loss. For example, ``model.losses_to_use=[pca_multiview,temporal]`` will fit both a pca_multiview loss and a temporal loss. Options include:
81+
* ``model.losses_to_use``: defines the unsupervised losses.
82+
An empty list indicates a fully supervised model.
83+
Each element of the list corresponds to an unsupervised loss.
84+
For example, ``model.losses_to_use=[pca_multiview,temporal]`` will fit both a pca_multiview loss
85+
and a temporal loss. Options include:
5686

5787
* pca_multiview: penalize inconsistencies between multiple camera views
5888
* pca_singleview: penalize implausible body configurations
5989
* temporal: penalize large temporal jumps
6090

91+
* ``model.checkpoint``: to continue training from an existing checkpoint, update this parameter
92+
to the absolute path of a pytorch .ckpt file
93+
6194
See the :ref:`Unsupervised losses <unsupervised_losses>` section for more details on the various
6295
losses and their associated hyperparameters.
6396

scripts/configs/config_default.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ model:
9595
heatmap_loss_type: mse
9696
# directory name for model saving
9797
model_name: test
98+
# load model from checkpoint
99+
checkpoint: null
98100

99101
dali:
100102
general:

0 commit comments

Comments
 (0)