Skip to content

Commit

Permalink
Adding Pyro
Browse files Browse the repository at this point in the history
  • Loading branch information
James Bristow committed Feb 7, 2024
1 parent fe32913 commit 4c320ee
Show file tree
Hide file tree
Showing 9 changed files with 1,570 additions and 4 deletions.
3 changes: 3 additions & 0 deletions pyro/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Pyro

See https://pyro.ai/examples/index.html
774 changes: 774 additions & 0 deletions pyro/hmm.py

Large diffs are not rendered by default.

218 changes: 218 additions & 0 deletions pyro/kalman-filter.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2e5c7973",
"metadata": {},
"source": [
"# https://pyro.ai/examples/ekf.html"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9ebb6829",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.8.6'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"import math\n",
"\n",
"import torch\n",
"import pyro\n",
"import pyro.distributions as dist\n",
"from pyro.infer.autoguide import AutoDelta\n",
"from pyro.optim import Adam\n",
"from pyro.infer import SVI, Trace_ELBO, config_enumerate\n",
"from pyro.contrib.tracking.extended_kalman_filter import EKFState\n",
"from pyro.contrib.tracking.distributions import EKFDistribution\n",
"from pyro.contrib.tracking.dynamic_models import NcvContinuous\n",
"from pyro.contrib.tracking.measurements import PositionMeasurement\n",
"\n",
"pyro.__version__"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "87dcc0f7",
"metadata": {},
"outputs": [],
"source": [
"dt = 1e-2\n",
"num_frames = 10\n",
"dim = 4\n",
"\n",
"# Continuous model\n",
"ncv = NcvContinuous(dim, 2.0)\n",
"\n",
"# Truth trajectory\n",
"xs_truth = torch.zeros(num_frames, dim)\n",
"# initial direction\n",
"theta0_truth = 0.0\n",
"# initial state\n",
"with torch.no_grad():\n",
" xs_truth[0, :] = torch.tensor([0.0, 0.0, math.cos(theta0_truth), math.sin(theta0_truth)])\n",
" for frame_num in range(1, num_frames):\n",
" # sample independent process noise\n",
" dx = pyro.sample('process_noise_{}'.format(frame_num), ncv.process_noise_dist(dt))\n",
" xs_truth[frame_num, :] = ncv(xs_truth[frame_num-1, :], dt=dt) + dx\n",
" \n",
"# Measurements\n",
"measurements = []\n",
"mean = torch.zeros(2)\n",
"# no correlations\n",
"cov = 1e-5 * torch.eye(2)\n",
"with torch.no_grad():\n",
" # sample independent measurement noise\n",
" dzs = pyro.sample('dzs', dist.MultivariateNormal(mean, cov).expand((num_frames,)))\n",
" # compute measurement means\n",
" zs = xs_truth[:, :2] + dzs"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "046c16cb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jbris/miniconda3/envs/data_assim/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11070). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)\n",
" Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: -15.20763874053955\n",
"loss: -15.339868545532227\n",
"loss: -15.413694381713867\n",
"loss: -15.473196029663086\n",
"loss: -15.507671356201172\n",
"loss: -15.523503303527832\n",
"loss: -15.5301513671875\n",
"loss: -15.532791137695312\n",
"loss: -15.533793449401855\n",
"loss: -15.534193992614746\n",
"loss: -15.534348487854004\n",
"loss: -15.534411430358887\n",
"loss: -15.534439086914062\n",
"loss: -15.534448623657227\n",
"loss: -15.534452438354492\n",
"loss: -15.534453392028809\n",
"loss: -15.534455299377441\n",
"loss: -15.534454345703125\n",
"loss: -15.534454345703125\n",
"loss: -15.534454345703125\n",
"loss: -15.534455299377441\n",
"loss: -15.534455299377441\n",
"loss: -15.534454345703125\n",
"loss: -15.534453392028809\n",
"loss: -15.534456253051758\n"
]
}
],
"source": [
"def model(data):\n",
" # a HalfNormal can be used here as well\n",
" R = pyro.sample('pv_cov', dist.HalfCauchy(2e-6)) * torch.eye(4)\n",
" Q = pyro.sample('measurement_cov', dist.HalfCauchy(1e-6)) * torch.eye(2)\n",
" # observe the measurements\n",
" pyro.sample('track_{}'.format(i), EKFDistribution(xs_truth[0], R, ncv,\n",
" Q, time_steps=num_frames),\n",
" obs=data)\n",
"\n",
"guide = AutoDelta(model) # MAP estimation\n",
"\n",
"optim = pyro.optim.Adam({'lr': 2e-2})\n",
"svi = SVI(model, guide, optim, loss=Trace_ELBO(retain_graph=True))\n",
"\n",
"pyro.set_rng_seed(0)\n",
"pyro.clear_param_store()\n",
"\n",
"for i in range(250):\n",
" loss = svi.step(zs)\n",
" if not i % 10:\n",
" print('loss: ', loss)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "bb429939",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35346980>,\n",
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35347c10>,\n",
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35347d60>,\n",
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35345d20>,\n",
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35344730>,\n",
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354ffe80>,\n",
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354ff0d0>,\n",
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354fded0>,\n",
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354ff9d0>,\n",
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354feb90>]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"R = guide()['pv_cov'] * torch.eye(4)\n",
"Q = guide()['measurement_cov'] * torch.eye(2)\n",
"ekf_dist = EKFDistribution(xs_truth[0], R, ncv, Q, time_steps=num_frames)\n",
"states= ekf_dist.filter_states(zs)\n",
"states"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88a05fba",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
44 changes: 44 additions & 0 deletions pyro/mixed_hmm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Hierarchical mixed-effect hidden Markov models

Note: This is a cleaned-up version of the seal experiments in [Obermeyer et al 2019], which are a simplified variant of some of the analysis in the [momentuHMM harbour seal example](https://github.com/bmcclintock/momentuHMM/blob/master/vignettes/harbourSealExample.R) [McClintock et al 2018].

Recent advances in sensor technology have made it possible to capture the movements of multiple wild animals within a single population at high spatiotemporal resolution over long periods of time [McClintock et al 2013, Towner et al 2016]. Discrete state-space models, where the latent state is thought of as corresponding to a behavior state such as "foraging" or "resting", have become popular computational tools for analyzing these new datasets thanks to their interpretability and tractability.

This example applies several different hierarchical discrete state-space models to location data recorded from a colony of harbour seals on foraging excursions in the North Sea [McClintock et al 2013].
The raw data are irregularly sampled time series (roughly 5-15 minutes between samples) of GPS coordinates and diving activity for each individual in the colony (10 male and 7 female) over the course of a single day recorded by lightweight tracking devices physically attached to each animal by researchers. They have been preprocessed using the momentuHMM example code into smoothed, temporally regular series of step sizes, turn angles, and diving activity for each individual.

The models are special cases of a time-inhomogeneous discrete state space model
whose state transition distribution is specified by a hierarchical generalized linear mixed model (GLMM).
At each timestep `t`, for each individual trajectory `b` in each group `a`, we have

```
logit(p(x[t,a,b] = state i | x[t-1,a,b] = state j)) =
(epsilon_G[a] + epsilon_I[a,b] + Z_I[a,b].T @ beta1 + Z_G[a].T @ beta2 + Z_T[t,a,b].T @ beta3)[i,j]
```

where `a,b` correspond to plate indices, `epsilon_G` and `epsilon_I` are independent random variables for each group and individual within each group respectively, `Z`s are covariates, and `beta`s are parameter vectors.

The random variables `epsilon` may be either discrete or continuous.
If continuous, they are normally distributed.
If discrete, they are sampled from a set of three possible values shared across the innermost plate of a particular variable.
That is, for each individual trajectory `b` in each group `a`, we sample single random effect values for an entire trajectory:

```
iota_G[a] ~ Categorical(pi_G)
epsilon_G[a] = Theta_G[iota_G[a]]
iota_I[a,b] ~ Categorical(pi_I[a])
epsilon_I[a,b] = Theta_I[a][iota_I[a,b]]
```

Here `pi_G`, `Theta_G`, `pi_I`, and `Theta_I` are all learnable real-valued parameter vectors and `epsilon` values are batches of vectors the size of state transition matrices.

Observations `y[t,a,b]` are represented as sequences of real-valued step lengths and turn angles, modelled by zero-inflated Gamma and von Mises likelihoods respectively.
The seal models also include a third observed variable indicating the amount of diving activity between successive locations, which we model with a zero-inflated Beta distribution following [McClintock et al 2018].

We grouped animals by sex and implemented versions of this model with (i) no random effects (as a baseline), and with random effects present at the (ii) group, (iii) individual, or (iv) group+individual levels. Unlike the models in [Towner et al 2016], we do not consider fixed effects on any of the parameters.

# References
* [Obermeyer et al 2019] Obermeyer, F.\*, Bingham, E.\*, Jankowiak, M.\*, Chiu, J., Pradhan, N., Rush, A., and Goodman, N. Tensor Variable Elimination for Plated Factor Graphs, 2019
* [McClintock et al 2013] McClintock, B. T., Russell, D. J., Matthiopoulos, J., and King, R. Combining individual animal movement and ancillary biotelemetry data to investigate population-level activity budgets. Ecology, 94(4):838–849, 2013
* [McClintock et al 2018] McClintock, B. T. and Michelot, T. momentuHMM: R package for generalized hidden Markov models of animal movement. Methods in Ecology and Evolution, 9(6):1518–1530, 2018. doi: 10.1111/2041-210X.12995
* [Towner et al 2016] Towner, A. V., Leos-Barajas, V., Langrock, R., Schick, R. S., Smale, M. J., Kaschke, T., Jewell, O. J., and Papastamatiou, Y. P. Sex-specific and individual preferences for hunting strategies in white sharks. Functional Ecology, 30(8):1397–1407, 2016.
Empty file added pyro/mixed_hmm/__init__.py
Empty file.
Loading

0 comments on commit 4c320ee

Please sign in to comment.