[docs] install + training + inference
themattinthehatt committed Dec 3, 2023
1 parent e92fa62 commit b10c6d4
Showing 7 changed files with 408 additions and 44 deletions.
47 changes: 3 additions & 44 deletions README.md
@@ -3,7 +3,9 @@
[![Documentation Status](https://readthedocs.org/projects/daart/badge/?version=latest)](https://daart.readthedocs.io/en/latest/?badge=latest)
[![DOI](https://zenodo.org/badge/334987729.svg)](https://zenodo.org/badge/latestdoi/334987729)

A collection of tools for the discrete classification of animal behaviors using low-dimensional representations of videos (such as skeletons provided by tracking algorithms). Our approach combines strong supervision, weak supervision, and self-supervision to improve model performance. See the preprint [here](https://www.biorxiv.org/content/10.1101/2021.06.16.448685v1) for more details.

This repo currently supports fitting the following types of base models on behavioral time series data:
* Dense MLP network with initial 1D convolutional layer
* RNNs - both LSTMs and GRUs
* Temporal Convolutional Networks (TCNs)
@@ -19,46 +21,3 @@ If you use daart in your analysis of behavioral data, please cite our preprint!
year={2021},
publisher={Cold Spring Harbor Laboratory}
}

## Installation



## Getting started

To fit models from the command line using [test-tube](https://williamfalcon.github.io/test-tube/)
for hyperparameter searching and model fitting, see `fit_models.py` in the `examples` directory.
This script fits one or more models based on three yaml configuration files: one describing the
data, one describing the model, and one describing the training procedure. Example configuration
files can be found in the `configs` directory.

**_Note:_** Test-tube will automatically perform a hyperparameter search over any field that is
provided as a list; for example, in the `model.yaml` file, change `n_hid_layers: 1` to
`n_hid_layers: [1, 2, 3]` to search over the number of hidden layers in the model.
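
For instance, a hypothetical `model.yaml` fragment (assuming a grid-style search over all
list-valued fields; `n_hid_units` is an assumed second field, not taken from the example configs):

```yaml
# one model is fit per combination: 3 x 2 = 6 models in total
n_hid_layers: [1, 2, 3]
n_hid_units: [32, 64]
```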

Once you have set the desired parameters in these files (see comment on data paths below), you can
then fit models like so:

```
(daart) $: python fit_models.py --data_config /path/to/data.yaml \
    --model_config /path/to/model.yaml --train_config /path/to/train.yaml
```

#### Data paths

The `data.yaml` file has a field for listing experiment/session/video ids (`expt_ids`), as well as
a `data_dir` field. The `fit_models.py` script assumes data is stored in the following way, though
this can easily be adapted by changing the appropriate lines in a copy of the `fit_models.py`
script:

* markers: `data_dir/markers/[expt_id]_labeled.csv` or `data_dir/markers/[expt_id]_labeled.h5`;
the standard file formats used by DLC/DGP (DeepLabCut/Deep Graph Pose) are currently supported.

* hand labels: `data_dir/labels-hand/[expt_id]_labels.csv`; a binary matrix of shape
`(T, n_classes + 1)`, where the first column represents the `background` class; the gradients
contributed by these background time points are zeroed out during training.

* heuristic labels: `data_dir/labels-heuristic/[expt_id]_labels.csv`; same format as the hand
labels.

See the directory `daart/data` for example fly data used in the preprint.
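
For a single hypothetical session id `sess0`, this layout looks like:

```
data_dir/
├── markers/
│   └── sess0_labeled.csv
├── labels-hand/
│   └── sess0_labels.csv
└── labels-heuristic/
    └── sess0_labels.csv
```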
1 change: 1 addition & 0 deletions docs/index.rst
@@ -24,6 +24,7 @@ data:
:caption: Contents:

source/installation
source/user_guide
source/api

Indices and tables
17 changes: 17 additions & 0 deletions docs/source/user_guide.rst
@@ -0,0 +1,17 @@
.. _user_guide:

##########
User guide
##########

This guide walks you through the steps required to use the daart package for semi-supervised
discrete behavior classification.

.. toctree::
:maxdepth: 2
:caption: Contents:

user_guide/organizing_your_data
user_guide/config_files
user_guide/training
user_guide/inference
60 changes: 60 additions & 0 deletions docs/source/user_guide/config_files.rst
@@ -0,0 +1,60 @@
.. _user_guide_configs:

#######################
The configuration files
#######################

Users interact with daart through a set of configuration (yaml) files.
These files point to the data directories, define the type of model to fit, and specify a wide
range of hyperparameters.

An example set of configuration files can be found
`here <https://github.com/themattinthehatt/daart/tree/main/data/configs>`_.
When training a model on a new dataset, you must copy/paste these templates onto your local
machine and update the arguments to match your data.

There are three configuration files:

* :ref:`data <config_data>`: where data is stored and model input type
* :ref:`model <config_model>`: model class and various network hyperparameters
* :ref:`train <config_train>`: training epochs, batch size, etc.

The sections below describe the most important parameters in each file;
see the example configs for all possible options.

.. _config_data:

Data
====

* **input_type**: name of directory containing input data: 'markers' | 'features' | ...
* **output_size**: number of classes (including background)
* **expt_ids**: list of experiment ids used for training the model
* **data_dir**: absolute path to directory that contains the data
* **results_dir**: absolute path to directory that stores model fitting results
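
Putting these together, an illustrative ``data.yaml`` fragment (ids, paths, and values are
placeholders, not a complete config):

.. code-block:: yaml

    input_type: markers
    output_size: 4  # number of classes, including background
    expt_ids: ['sess0', 'sess1']
    data_dir: /path/to/data
    results_dir: /path/to/results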

.. _config_model:

Model
=====

* **lambda_weak**: weight on heuristic/pseudo label classification loss
* **lambda_strong**: weight on hand label classification loss (can always leave this as 1)
* **lambda_recon**: weight on input reconstruction loss
* **lambda_pred**: weight on next-step-ahead prediction loss

For example, to fit a fully supervised classification model, set ``lambda_strong: 1`` and
all other "lambda" options to 0.

To fit a model that uses heuristic labels, set ``lambda_strong: 1``, ``lambda_weak: 1``, and
all other "lambda" options to 0. You can try several values of ``lambda_weak`` to see what works
best for your data.
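
As a sketch, the "lambda" block of a ``model.yaml`` for this heuristic-label setting would be:

.. code-block:: yaml

    # hand labels on, heuristic labels on, unsupervised losses off
    lambda_strong: 1
    lambda_weak: 1
    lambda_recon: 0
    lambda_pred: 0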

.. _config_train:

Train
=====

* **min/max_epochs**: control length of training
* **enable_early_stop**: exit training early if validation loss begins to increase
* **trial_splits**: fractions of data to use for train;val;test;gap ("gap" can always be set to 0 as long as you validate your model on completely held-out videos)
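
An illustrative ``train.yaml`` fragment combining these options (all values are placeholders):

.. code-block:: yaml

    min_epochs: 10
    max_epochs: 200
    enable_early_stop: true
    trial_splits: 8;1;1;0  # train;val;test;gap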
82 changes: 82 additions & 0 deletions docs/source/user_guide/inference.rst
@@ -0,0 +1,82 @@
.. _user_guide_inference:

#########
Inference
#########

Once you have trained a model, you'll likely want to run inference on new videos.

As with training, a set of high-level functions is used to perform inference and evaluate
performance; this page details the main steps.


Load model
==========

Given a saved model directory, construct the model and load its trained weights.

.. code-block:: python

    import os

    import torch
    import yaml

    from daart.models import Segmenter

    # path to the directory created during model training
    model_dir = '/path/to/model_dir'
    model_file = os.path.join(model_dir, 'best_val_model.pt')
    hparams_file = os.path.join(model_dir, 'hparams.yaml')

    # load the hyperparameters saved at training time
    hparams = yaml.safe_load(open(hparams_file, 'rb'))

    # construct the model and load the trained weights
    model = Segmenter(hparams)
    model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))
    model.to(hparams['device'])
    model.eval()

Build data generator
====================

To run inference on a new session, provide a csv file that contains markers or features from
that session (the inputs must be of the same type the model was trained on).

.. code-block:: python

    from daart.data import DataGenerator
    from daart.transforms import ZScore

    sess_id = '<name_of_session>'
    input_file = '/path/to/markers_or_features_csv'

    # define data generator signals
    signals = ['markers']  # same for markers or features
    transforms = [ZScore()]
    paths = [input_file]

    # build data generator
    data_gen_test = DataGenerator(
        [sess_id], [signals], [transforms], [paths], device=hparams['device'],
        sequence_length=hparams['sequence_length'], batch_size=hparams['batch_size'],
        trial_splits=hparams['trial_splits'],
        sequence_pad=hparams['sequence_pad'], input_type=hparams['input_type'],
    )

Run inference
=============

Inference can be performed by passing the newly constructed data generator to the model's
``predict_labels`` method:

.. code-block:: python

    import numpy as np

    # predict probabilities from model
    print('computing states for %s...' % sess_id, end='')
    tmp = model.predict_labels(data_gen_test, return_scores=True)
    probs = np.vstack(tmp['labels'][0])
    print('done')

    # get discrete state by taking argmax over probabilities at each time point
    states = np.argmax(probs, axis=1)
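
As an optional final step (not part of daart itself), you might save the per-frame states and
class probabilities for downstream analysis; a minimal sketch using pandas, with arbitrary
column names:

.. code-block:: python

    import pandas as pd

    # one row per time point: argmax state plus per-class probabilities
    df = pd.DataFrame(probs, columns=['prob_%i' % i for i in range(probs.shape[1])])
    df.insert(0, 'state', states)
    df.to_csv('%s_states.csv' % sess_id)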
