Skip to content

Commit

Permalink
Update of GLM model selection notebook to v5
Browse files Browse the repository at this point in the history
  • Loading branch information
fonnesbeck committed Dec 24, 2024
1 parent 9857594 commit 3ea370e
Show file tree
Hide file tree
Showing 4 changed files with 683 additions and 798 deletions.
1,409 changes: 625 additions & 784 deletions examples/generalized_linear_models/GLM-model-selection.ipynb

Large diffs are not rendered by default.

36 changes: 22 additions & 14 deletions examples/generalized_linear_models/GLM-model-selection.myst.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import pymc as pm
import seaborn as sns
import xarray as xr
from ipywidgets import fixed, interactive
print(f"Running on PyMC3 v{pm.__version__}")
print(f"Running on PyMC v{pm.__version__}")
```

```{code-cell} ipython3
Expand All @@ -44,7 +44,7 @@ plt.rcParams["figure.constrained_layout.use"] = False
```

## Introduction
A fairly minimal reproducible example of Model Selection using WAIC, and LOO as currently implemented in PyMC3.
A fairly minimal reproducible example of Model Selection using WAIC, and LOO as currently implemented in PyMC.

This example creates two toy datasets under linear and quadratic models, and then tests the fit of a range of polynomial linear models upon those datasets by using Widely Applicable Information Criterion (WAIC), and leave-one-out (LOO) cross-validation using Pareto-smoothed importance sampling (PSIS).

Expand Down Expand Up @@ -198,12 +198,18 @@ def plot_posterior_cr(models, idatas, rawdata, xlims, datamodelnm="linear", mode
# Get traces and calc posterior prediction for npoints in x
npoints = 100
mdl = models[modelnm]
trc = idatas[modelnm].posterior.copy().drop_vars("y_sigma")
da = xr.concat([var for var in trc.values()], dim="order")
trc = idatas[modelnm].posterior.copy()
ordr = int(modelnm[-1:])
# Extract variables and stack them in correct order
vars_to_concat = []
for var in ["Intercept", "x"] + [f"np.power(x, {i})" for i in range(2, int(modelnm[-1:]) + 1)]:
if var in trc:
vars_to_concat.append(trc[var])
da = xr.concat(vars_to_concat, dim="order")
ordr = len(vars_to_concat)
x = xr.DataArray(np.linspace(xlims[0], xlims[1], npoints), dims=["x_plot"])
pwrs = xr.DataArray(np.arange(ordr + 1), dims=["order"])
pwrs = xr.DataArray(np.arange(ordr), dims=["order"])
X = x**pwrs
cr = xr.dot(X, da, dims="order")
Expand Down Expand Up @@ -337,7 +343,7 @@ $$y = a + bx + \epsilon$$

+++

### Define model using explicit PyMC3 method
### Define model using explicit PyMC method

```{code-cell} ipython3
with pm.Model() as mdl_ols:
Expand Down Expand Up @@ -417,7 +423,7 @@ def create_poly_modelspec(k=1):
def run_models(df, upper_order=5):
"""
Convenience function:
Fit a range of pymc3 models of increasing polynomial complexity.
Fit a range of pymc models of increasing polynomial complexity.
Suggest limit to max order 5 since calculation time is exponential.
"""
Expand All @@ -432,7 +438,9 @@ def run_models(df, upper_order=5):
models[nm] = bmb.Model(
fml, df, priors={"intercept": bmb.Prior("Normal", mu=0, sigma=100)}, family="gaussian"
)
results[nm] = models[nm].fit(draws=2000, tune=1000, init="advi+adapt_diag")
results[nm] = models[nm].fit(
draws=2000, tune=1000, init="advi+adapt_diag", idata_kwargs={"log_likelihood": True}
)
return models, results
```
Expand Down Expand Up @@ -499,11 +507,11 @@ dfwaic_quad
_, axs = plt.subplots(1, 2)
ax = axs[0]
az.plot_compare(dfwaic_lin, ax=ax)
az.plot_compare(dfwaic_lin, ax=ax, legend=False)
ax.set_title("Linear data")
ax = axs[1]
az.plot_compare(dfwaic_quad, ax=ax)
az.plot_compare(dfwaic_quad, ax=ax, legend=False)
ax.set_title("Quadratic data");
```

Expand Down Expand Up @@ -545,11 +553,11 @@ dfloo_quad
_, axs = plt.subplots(1, 2)
ax = axs[0]
az.plot_compare(dfloo_lin, ax=ax)
az.plot_compare(dfloo_lin, ax=ax, legend=False)
ax.set_title("Linear data")
ax = axs[1]
az.plot_compare(dfloo_quad, ax=ax)
az.plot_compare(dfloo_quad, ax=ax, legend=False)
ax.set_title("Quadratic data");
```

Expand Down
35 changes: 35 additions & 0 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ nutpie = ">=0.13.2,<0.14"
numba = ">=0.60.0,<0.61"
scikit-learn = ">=1.5.2,<2"
blackjax = ">=1.2.3,<2"
bambi = ">=0.15.0,<0.16"

[pypi-dependencies]
pymc-experimental = ">=0.1.2, <0.2"
Expand Down

0 comments on commit 3ea370e

Please sign in to comment.