Skip to content

Commit

Permalink
Update notebook Dirichlet mixture of multinomials to PyMC 5 (#581)
Browse files Browse the repository at this point in the history
* Update notebook to pymc v5

Also remove unused import

* Make randomness deterministic

* Update notebook to match new simulated data

Since the simulated data has changed, the numerical issues we faced are slightly different.
The notebook has been modified to reflect this.

- No need to do metropolis sampling for the multinomial model
- Worse numerical issues for the explicit DM model. Addressed by increasing
target_accept and modify the discussion about poor rhat values
- Minor changes to match new data
- Also refer to species by name instead of index.

* Minor changes to notebook text

* Change target_accept from 0.95 to 0.9

This gets rid of most divergences, and gives much better r_hat statistics

* Add myself to author list

* Minor change

* Fix minor text mismatches after updates

* Update unrelated files
  • Loading branch information
erik-werner authored Oct 5, 2023
1 parent 5177871 commit 0e74cef
Show file tree
Hide file tree
Showing 4 changed files with 684 additions and 525 deletions.
1 change: 0 additions & 1 deletion examples/case_studies/nyc_bym.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,6 @@
"outputs": [],
"source": [
"with pm.Model(coords=coords) as BYM_model:\n",
"\n",
" # intercept\n",
" beta0 = pm.Normal(\"beta0\", 0, 1)\n",
"\n",
Expand Down
1 change: 0 additions & 1 deletion examples/case_studies/nyc_bym.myst.md
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ Lastly, we'll use a Poisson outcome distribution. The number of traffic accident

```{code-cell} ipython3
with pm.Model(coords=coords) as BYM_model:
# intercept
beta0 = pm.Normal("beta0", 0, 1)
Expand Down
1,100 changes: 648 additions & 452 deletions examples/mixture_models/dirichlet_mixture_of_multinomials.ipynb

Large diffs are not rendered by default.

107 changes: 36 additions & 71 deletions examples/mixture_models/dirichlet_mixture_of_multinomials.myst.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,10 @@ This notebook will demonstrate the performance benefits that come from taking th
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import pymc as pm
import scipy as sp
import scipy.stats
import seaborn as sns
print(f"Running on PyMC3 v{pm.__version__}")
print(f"Running on PyMC v{pm.__version__}")
```

```{code-cell} ipython3
Expand Down Expand Up @@ -121,8 +119,10 @@ k = len(trees)
n = len(forests)
total_count = 50
true_p = sp.stats.dirichlet(true_conc * true_frac).rvs(size=n)
observed_counts = np.vstack([sp.stats.multinomial(n=total_count, p=p_i).rvs() for p_i in true_p])
true_p = sp.stats.dirichlet(true_conc * true_frac).rvs(size=n, random_state=rng)
observed_counts = np.vstack(
[sp.stats.multinomial(n=total_count, p=p_i).rvs(random_state=rng) for p_i in true_p]
)
observed_counts
```
Expand All @@ -149,46 +149,16 @@ with pm.Model(coords=coords) as model_multinomial:
pm.model_to_graphviz(model_multinomial)
```

Interestingly, NUTS frequently runs into numerical problems on this model, perhaps an example of the
["Folk Theorem of Statistical Computing"](https://statmodeling.stat.columbia.edu/2008/05/13/the_folk_theore/).

Because of a couple of identities of the multinomial distribution,
we could reparameterize this model in a number of ways—we
would obtain equivalent models by exploding our $n$ observations
of $\mathrm{total\_count}$ items into $(n \times \mathrm{total\_count})$
independent categorical trials, or collapsing them down into
one Multinomial draw with $(n \times \mathrm{total\_count})$ items.
(Importantly, this is _not_ true for the DM distribution.)

Rather than _actually_ fixing our problem through reparameterization,
here we'll instead switch to the Metropolis step method,
which ignores some of the geometric pathologies of our naïve model.

**Important**: switching to Metropolis does not not _fix_ our model's issues, rather it _sweeps them under the rug_.
In fact, if you try running this model with NUTS (PyMC3's default step method), it will break loudly during sampling.
When that happens, this should be a **red alert** that there is something wrong in our model.

You'll also notice below that we have to increase considerably the number of draws we take from the posterior;
this is because Metropolis is much less efficient at
exploring the posterior than NUTS.

```{code-cell} ipython3
with model_multinomial:
trace_multinomial = pm.sample(
draws=5000, chains=4, step=pm.Metropolis(), return_inferencedata=True
)
trace_multinomial = pm.sample(chains=4)
```

Let's ignore the warning about inefficient sampling for now.

```{code-cell} ipython3
az.plot_trace(data=trace_multinomial, var_names=["frac"]);
```

The trace plots look fairly good;
visually, each parameter appears to be moving around the posterior well,
although some sharp parts of the KDE plot suggests that
sampling sometimes gets stuck in one place for a few steps.
The trace plots look fairly good; visually, each parameter appears to be moving around the posterior well.

```{code-cell} ipython3
summary_multinomial = az.summary(trace_multinomial, var_names=["frac"])
Expand All @@ -201,12 +171,9 @@ summary_multinomial
```

Likewise, diagnostics in the parameter summary table all look fine.
Here I've added a column estimating the effective sample size per
Here we've added a column estimating the effective sample size per
second of sampling.

Nonetheless, the fact that we were unable to use NUTS is still a red flag, and we should be
very cautious in using these results.

```{code-cell} ipython3
az.plot_forest(trace_multinomial, var_names=["frac"])
for j, (y_tick, frac_j) in enumerate(zip(plt.gca().get_yticks(), reversed(true_frac))):
Expand All @@ -219,7 +186,7 @@ frequencies are for each species (dashed lines), we can comment on the accuracy
of our inferences.
And now the issues with our model become apparent;
notice that the 94% HDIs _don't include the true values_ for
tree species 0, 2, 3.
tree species 0, 1, 3.
We might have seen _one_ HDI miss, but _three_???

...what's going on?
Expand All @@ -228,9 +195,7 @@ Let's troubleshoot this model using a posterior-predictive check, comparing our

```{code-cell} ipython3
with model_multinomial:
pp_samples = az.from_pymc3(
posterior_predictive=pm.fast_sample_posterior_predictive(trace=trace_multinomial)
)
pp_samples = pm.sample_posterior_predictive(trace=trace_multinomial)
# Concatenate with InferenceData object
trace_multinomial.extend(pp_samples)
Expand Down Expand Up @@ -283,12 +248,12 @@ axs[-1].set_ylim(0, 0.6);
Here we're plotting histograms of the predicted counts
against the observed counts for each species.

_(Notice that the y-axis isn't full height and clips the distributions for species-4 in purple.)_
_(Notice that the y-axis isn't full height and clips the distributions for species ``mahogany`` in purple.)_

And now we can start to see why our posterior HDI deviates from the _true_ parameters for three of five species (vertical lines).
See that for all of the species the observed counts are frequently quite far from the predictions
conditioned on the posterior distribution.
This is particularly obvious for (e.g.) species-2 where we have one observation of more than 20
This is particularly obvious for (e.g.) ``oak`` where we have one observation of more than 30
trees of this species, despite the posterior predicitive mass being concentrated far below that.

This is overdispersion at work, and a clear sign that we need to adjust our model to accommodate it.
Expand Down Expand Up @@ -330,13 +295,14 @@ accounting for overdispersion of counts relative to the simple multinomial model

```{code-cell} ipython3
with model_dm_explicit:
trace_dm_explicit = pm.sample(chains=4, return_inferencedata=True)
trace_dm_explicit = pm.sample(chains=4, target_accept=0.9)
```

We got a warning, although we'll ignore it for now.
More interesting is how much longer it took to sample this model than the
first.
This may be because our model has an additional ~$(n \times k)$ parameters,
Here we had to increase ``target_accept`` from 0.8 to 0.9 to not get drowned in divergences.

We also got a warning about the ``rhat`` statistic, although we'll ignore it for now.
More interesting is how much longer it took to sample this model than the first.
This is partly because our model has an additional ~$(n \times k)$ parameters,
but it seems like there are other geometric challenges for NUTS as well.

We'll see if we can fix these in the next model, but for now let's take a look at the traces.
Expand All @@ -345,7 +311,7 @@ We'll see if we can fix these in the next model, but for now let's take a look a
az.plot_trace(data=trace_dm_explicit, var_names=["frac", "conc"]);
```

Obviously some sampling issues, but it's hard to see where divergences are occurring.
The divergences seem to occur when the estimated fraction of the rare species (``mahogany``) is very close to zero.

```{code-cell} ipython3
az.plot_forest(trace_dm_explicit, var_names=["frac"])
Expand All @@ -372,8 +338,8 @@ summary_dm_explicit
```

This is great, but _we can do better_.
The larger $\hat{R}$ value for $\mathrm{frac}_4$ is mildly concerning, and it's surprising
that our $\mathrm{ESS} \; \mathrm{sec}^{-1}$ is relatively small.
The slightly too large $\hat{R}$ value for ``frac[mahogany]`` is a bit concerning, and it's surprising
that our $\mathrm{ESS} \; \mathrm{sec}^{-1}$ is quite small.

+++

Expand All @@ -383,7 +349,7 @@ that our $\mathrm{ESS} \; \mathrm{sec}^{-1}$ is relatively small.

Happily, the Dirichlet distribution is conjugate to the multinomial
and therefore there's a convenient, closed-form for the marginalized
distribution, i.e. the Dirichlet-multinomial distribution, which was added to PyMC3 in [3.11.0](https://github.com/pymc-devs/pymc3/releases/tag/v3.11.0).
distribution, i.e. the Dirichlet-multinomial distribution, which was added to PyMC in [3.11.0](https://github.com/pymc-devs/pymc3/releases/tag/v3.11.0).

Let's take advantage of this, marginalizing out the explicit latent parameter, $p_i$,
replacing the combination of this node and the multinomial
Expand All @@ -405,7 +371,7 @@ nodes together into a single DM node.

```{code-cell} ipython3
with model_dm_marginalized:
trace_dm_marginalized = pm.sample(chains=4, return_inferencedata=True)
trace_dm_marginalized = pm.sample(chains=4)
```

It samples much more quickly and without any of the warnings from before!
Expand Down Expand Up @@ -437,9 +403,7 @@ Posterior predictive checks to the rescue (again)!

```{code-cell} ipython3
with model_dm_marginalized:
pp_samples = az.from_pymc3(
posterior_predictive=pm.fast_sample_posterior_predictive(trace_dm_marginalized)
)
pp_samples = pm.sample_posterior_predictive(trace_dm_marginalized)
# Concatenate with InferenceData object
trace_dm_marginalized.extend(pp_samples)
Expand Down Expand Up @@ -494,7 +458,7 @@ axs[-1, 0].set_ylim(0, 0.6)
ax.set_ylim(0, 0.6);
```

_(Notice, again, that the y-axis isn't full height, and clips the distributions for species-4 in purple.)_
_(Notice, again, that the y-axis isn't full height, and clips the distributions for ``mahogany`` in purple.)_

Compared to the multinomial (plots on the right), PPCs for the DM (left) show that the observed data is
an entirely reasonable realization of our model.
Expand All @@ -512,15 +476,19 @@ We'll use leave-one-out cross validation to compare the
out-of-sample predictive ability of the two.

```{code-cell} ipython3
with model_multinomial:
pm.compute_log_likelihood(trace_multinomial)
with model_dm_marginalized:
pm.compute_log_likelihood(trace_dm_marginalized)
az.compare(
{"multinomial": trace_multinomial, "dirichlet_multinomial": trace_dm_marginalized}, ic="loo"
)
```

Unsurprisingly, the DM outclasses the multinomial by a mile, assigning a weight of nearly
100% to the over-dispersed model.
We can conclude that between the two, the DM should be greatly favored for prediction,
parameter inference, etc.
Unsurprisingly, the DM outclasses the multinomial by a mile, assigning a weight of 100% to the over-dispersed model.
While the ``warning=True`` flag for the multinomial distribution indicates that the numerical value cannot be fully trusted, the large difference in ``elpd_loo`` is further confirmation that between the two, the DM should be greatly favored for prediction, parameter inference, etc.

+++

Expand Down Expand Up @@ -549,19 +517,16 @@ In that case, swapping the vanilla Dirichlet distribution for something fancier
## Authors
* Authored by [Byron J. Smith](https://github.com/bsmith89) on Jan, 2021 ([pymc-examples#18](https://github.com/pymc-devs/pymc-examples/pull/18))
* Updated by Abhipsha Das and Oriol Abril-Pla on August, 2021 ([pymc-examples#212](https://github.com/pymc-devs/pymc-examples/pull/212))
* Updated to PyMC v5 by Erik Werner on Oct, 2023 ([pymc-examples#581](https://github.com/pymc-devs/pymc-examples/pull/581))

+++

## Watermark

```{code-cell} ipython3
%load_ext watermark
%watermark -n -u -v -iv -w -p theano,xarray
%watermark -n -u -v -iv -w -p pytensor,xarray
```

:::{include} page_footer.md
:::

```{code-cell} ipython3
```

0 comments on commit 0e74cef

Please sign in to comment.