Skip to content

Commit

Permalink
Merge pull request #118 from dattalab/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
calebweinreb authored Dec 15, 2023
2 parents ae1b360 + 5d98ae7 commit d406906
Show file tree
Hide file tree
Showing 11 changed files with 622 additions and 40 deletions.
3 changes: 1 addition & 2 deletions docs/keypoint_moseq_colab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -528,15 +528,14 @@
"source": [
"# load the most recent model checkpoint and pca object\n",
"# model = kpms.load_checkpoint(project_dir, model_name)[0]\n",
"# pca = kpms.load_pca(project_dir)\n",
"\n",
"# # load new data (e.g. from deeplabcut)\n",
"# new_data = 'path/to/new/data/' # can be a file, a directory, or a list of files\n",
"# coordinates, confidences, bodyparts = kpms.load_keypoints(new_data, 'deeplabcut')\n",
"# data, metadata = kpms.format_data(coordinates, confidences, **config())\n",
"\n",
"# # apply saved model to new data\n",
"# results = kpms.apply_model(model, pca, data, metadata, project_dir, model_name, **config())\n",
"# results = kpms.apply_model(model, data, metadata, project_dir, model_name, **config())\n",
"\n",
"# optionally rerun `save_results_as_csv` to export the new results\n",
"# kpms.save_results_as_csv(results, project_dir, model_name)"
Expand Down
Binary file added docs/source/_static/EML_scores.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/confusion_matrix.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/source/_static/kappa_scan.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
203 changes: 197 additions & 6 deletions docs/source/advanced.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Exporting pose estimates
~~~~~~~~~~~~~~~~~~~~~~~~
------------------------

During fitting, keypoint-MoSeq tries to estimate the "true" pose trajectory of the animal, discounting anomolous or low-confidence keypoints. The pose trajectory is stored in the model as a variable "x" that encodes a low-dimensional representation of the keypoints (similar to PCA). The code below shows how to project the pose trajectory back into the original coordinate space. This is useful for visualizing the estimated pose trajectory.::

Expand Down Expand Up @@ -51,27 +51,218 @@ The following code generates a video showing frames 0-3600 from one recording wi


Automatic kappa scan
~~~~~~~~~~~~~~~~~~~~
--------------------

Keypoint-MoSeq includes a hyperparameter called ``kappa`` that determines the rate of transitions between syllables. Higher values of kappa lead to longer syllables and smaller values lead to shorter syllables. Users should choose a value of kappa based their desired distribution of syllable durations. The code below shows how to automatically scan over a range of kappa values and choose te optimal value.::

import numpy as np

kappas = np.logspace(3,7,5)
decrease_kappa_factor = 10
num_ar_iters = 50
num_full_iters = 200

prefix = 'my_kappa_scan'

for kappa in kappas:
print(f"Fitting model with kappa={kappa}")
model = kpms.update_hypparams(model, kappa=kappa)
model_name = f'{prefix}-{kappa}'
model = kpms.init_model(data, pca=pca, **config())
# stage 1: fit the model with AR only
model = kpms.update_hypparams(model, kappa=kappa)
model = kpms.fit_model(
model,
data,
metadata,
project_dir,
model_name,
ar_only=True,
num_iters=num_ar_iters,
save_every_n_iters=25
)[0];

# stage 2: fit the full model
model = kpms.update_hypparams(model, kappa=kappa/decrease_kappa_factor)
kpms.fit_model(
model, data, metadata, project_dir,
model_name, ar_only=True, num_iters=100,
save_every_n_iters=25);
model,
data,
metadata,
project_dir,
model_name,
ar_only=False,
start_iter=num_ar_iters,
num_iters=num_full_iters,
save_every_n_iters=25
);

kpms.plot_kappa_scan(kappas, project_dir, prefix)


.. image:: _static/kappa_scan.jpg
:align: center




Model selection and comparison
------------------------------

Keypoint-MoSeq uses a stochastic fitting procedure, and thus produces slightly different syllable segmentations when run multiple times with different random seeds. Below, we show how to fit multiple models, compare the resulting syllables, and then select an optimal model for further analysis. It may also be useful in some cases to show that downstream analyses are robust to the choice of model.


.. _fitting-multiple-models:

Fitting multiple models
~~~~~~~~~~~~~~~~~~~~~~~

The code below shows how to fit multiple models with different random seeds.::

import jax

num_model_fits = 20
prefix = 'my_models'

ar_only_kappa = 1e6
num_ar_iters = 50

full_model_kappa = 1e4
num_full_iters = 500

for restart in range(num_model_fits):
print(f"Fitting model {restart}")
model_name = f'{prefix}-{restart}'
model = kpms.init_model(
data, pca=pca, **config(), seed=jax.random.PRNGKey(restart)
)

# stage 1: fit the model with AR only
model = kpms.update_hypparams(model, kappa=ar_only_kappa)
model = kpms.fit_model(
model,
data,
metadata,
project_dir,
model_name,
ar_only=True,
num_iters=num_ar_iters
)[0]

# stage 2: fit the full model
model = kpms.update_hypparams(model, kappa=full_model_kappa)
kpms.fit_model(
model,
data,
metadata,
project_dir,
model_name,
ar_only=False,
start_iter=num_ar_iters,
num_iters=num_full_iters
);

kpms.reindex_syllables_in_checkpoint(project_dir, model_name);
model, data, metadata, current_iter = kpms.load_checkpoint(project_dir, model_name)
results = kpms.extract_results(model, metadata, project_dir, model_name)

Comparing syllables
~~~~~~~~~~~~~~~~~~~

To get a sense of the variability across model runs, it may be useful to compare syllables produced by each model. The code below shows how to load results from two models runs (e.g., produced by the code above) and plot a confusion matrix showing the overlap between syllable labels.::

model_name_1 = 'my_models-0'
model_name_2 = 'my_models-1'

results_1 = kpms.load_results(project_dir, model_name_1)
results_2 = kpms.load_results(project_dir, model_name_2)

kpms.plot_confusion_matrix(results_1, results_2);


.. image:: _static/confusion_matrix.jpg


Selecting a model
~~~~~~~~~~~~~~~~~

We developed a matric called the expected marginal likelihood (EML) score that can be used to rank models. To calculate EML scores, you must first fit an ensemble of models to a given dataset, as shown in :ref:`Fitting multiple models <fitting-multiple-models>`. The code below loads this ensemble and then calculates the EML score for each model. The model with the highest EML score can then be selected for further analysis.::


# change the following line as needed
model_names = ['my_models-{}'.format(i) for i in range(20)]

eml_scores, eml_std_errs = kpms.expected_marginal_likelihoods(project_dir, model_names)
best_model = model_names[np.argmax(eml_scores)]
print(f"Best model: {best_model_name}")

kpms.plot_eml_scores(eml_scores, eml_std_errs, model_names)


.. image:: _static/EML_scores.jpg


Model averaging
~~~~~~~~~~~~~~~

Keypoint-MoSeq is probabilistic. So even once fitting is complete and the syllable parameters are fixed, there is still a distribution of possible syllable sequences given the observed data. In the default pipeline, one such sequence is sampled from this distribution and used for downstream analyses. Alternatively, one can estimate the marginal probability distribution over syllable labels at each timepoint. The code below shows how to do this. It can be applied to new data or the same data that was used for fitting (or a combination of the two).::

burnin_iters = 50
num_samples = 100
steps_per_sample = 5

# load the model (change `project_dir` and `model_name` as needed)
model = kpms.load_checkpoint(project_dir, model_name)[0]

# load data (e.g. from deeplabcut)
data_path = 'path/to/data/' # can be a file, a directory, or a list of files
coordinates, confidences, bodyparts = kpms.load_keypoints(data_path, 'deeplabcut')
data, metadata = kpms.format_data(coordinates, confidences, **config())

# compute the marginal probabilities of syllable labels
marginal_probs = kpms.estimate_syllable_marginals(
model, data, metadata, burnin_iters, num_samples, steps_per_sample, **config()
)


Location-aware modeling
-----------------------

Because keypoint-MoSeq uses centered and aligned pose estimates to define syllables, it is effectively blind to absolute movements of the animal in space. The only thing that keypoint-MoSeq normally cares about is change in pose -- defined here as the relative location of each keypoint. For example, if an animal were capable of simply sliding forward without otherwise moving, this would fail to show up in the syllable segmentation. To address this gap, we developed an experimental version of keypoint-MoSeq that leverages location and heading dynamics (in addition to pose) when defining syllables. To use this "location-aware" model, simply pass ``location_aware=True`` as an additional argument when calling the following functions.

- :py:func:`keypoint_moseq.init_model`
- :py:func:`keypoint_moseq.fit_model`
- :py:func:`keypoint_moseq.apply_model`
- :py:func:`keypoint_moseq.estimate_syllable_marginals`

Note that the location-aware model was not tested in the keypoint-MoSeq paper remains experimental. We welcome feedback and suggestions for improvement.


Mathematical details
~~~~~~~~~~~~~~~~~~~~

In the published version of keypoint-MoSeq, the animal's location :math:`v_t` and heading :math:`h_t` at each timepoint are conditionally independent of the current syllable :math:`z_t`. In particular, we assume

.. math::
v_{t+1} & \sim \mathcal{N}(v_t, \sigma^2_\text{loc} I_2) \\
h_{t+1} & \sim \text{Uniform}(-\pi, \pi)
In the location-aware model, we relax this assumption and allow the animal's location and heading to depend on the current syllable. Specifically, each syllable is associated with a pair of normal distributions that specify the animal's expected rotation and translation at each timestep. This can be expressed formally as follows:

.. math::
h_{t+1} = h_t + \Delta h_{z_t} + \epsilon_h,
& \ \text{ where } \
\epsilon_h \mid z_t \sim \mathcal{N}(0, \sigma^2_{h,z_t}) \\
v_{t+1} = v_t + R(h_t)^\top \Delta v_{z_t} + \epsilon_v,
& \ \text{ where } \
\epsilon_v \mid z_t \sim \mathcal{N}(0, \sigma^2_{v, z_t} I_2)
where :math:`R(h)` is a rotation matrix that rotates a vector by angle :math:`h`. The parameters :math:`\Delta h_i`, :math:`\Delta v_i`, :math:`\sigma^2_{h,i}`, and :math:`\sigma^2_{v,i}` for each syllable :math:`i` have a normal-inverse-gamma prior:

.. math::
\sigma^2_{v,i} & \sim \text{InverseGamma}(\alpha_v, \beta_v), \ \ \ \ \Delta v_i \sim \mathcal{N}(0, \sigma^2_{v,i} I_2 / \lambda_v) \\
\sigma^2_{h,i} & \sim \text{InverseGamma}(\alpha_h, \beta_h), \ \ \ \ \Delta h_i \sim \mathcal{N}(0, \sigma^2_{h,i} / \lambda_h)
3 changes: 1 addition & 2 deletions docs/source/modeling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1136,15 +1136,14 @@
"source": [
"# load the most recent model checkpoint and pca object\n",
"model = kpms.load_checkpoint(project_dir, model_name)[0]\n",
"pca = kpms.load_pca(project_dir)\n",
"\n",
"# load new data (e.g. from deeplabcut)\n",
"new_data = 'path/to/new/data/' # can be a file, a directory, or a list of files\n",
"coordinates, confidences, bodyparts = kpms.load_keypoints(new_data, 'deeplabcut')\n",
"data, metadata = kpms.format_data(coordinates, confidences, **config())\n",
"\n",
"# apply saved model to new data\n",
"results = kpms.apply_model(model, pca, data, metadata, project_dir, model_name, **config())\n",
"results = kpms.apply_model(model, data, metadata, project_dir, model_name, **config())\n",
"\n",
"# optionally rerun `save_results_as_csv` to export the new results\n",
"# kpms.save_results_as_csv(results, project_dir, model_name)"
Expand Down
2 changes: 1 addition & 1 deletion keypoint_moseq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .analysis import *
from .calibration import noise_calibration

from jax_moseq.models.keypoint_slds import fit_pca, init_model
from jax_moseq.models.keypoint_slds import fit_pca
from jax_moseq.utils import get_frequencies, get_durations

from . import _version
Expand Down
Loading

0 comments on commit d406906

Please sign in to comment.