-
-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
84aa791
commit d738d5c
Showing
6 changed files
with
646 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "ddde367f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import pymc as pm\n", | ||
"import numpy as np\n", | ||
"import pymc_experimental as pmx" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "45731819", | ||
"metadata": {}, | ||
"source": [ | ||
"I follow **Example 9** from [here](https://projecteuclid.org/ebooks/nsf-cbms-regional-conference-series-in-probability-and-statistics/Nonparametric-Bayesian-Inference/Chapter/Chapter-3-Dirichlet-Process/10.1214/cbms/1362163748) (displayed below) and attempt to also replicate the results shown in Figure 3.2. They assume that data are drawn i.i.d. from $\\mathcal{N}(2, 4)$ but assume a base distribution $G_0 = \\mathcal{N}(0, 1)$.\n", | ||
"\n", | ||
"<img src=\"dp-example-9.png\" width=500>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "93713949", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"alpha = 5.0 # concentration parameter\n", | ||
"K = 19 # truncation parameter\n", | ||
"\n", | ||
"rng = np.random.default_rng(seed=34)\n", | ||
"obs = rng.normal(2.0, 2.0, size=50)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "7fab0297", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Multiprocess sampling (4 chains in 4 jobs)\n", | ||
"CompoundStep\n", | ||
">NUTS: [base_dist, sbw]\n", | ||
">BinaryGibbsMetropolis: [idx]\n", | ||
">CategoricalGibbsMetropolis: [atom_selection]\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"\n", | ||
"<style>\n", | ||
" /* Turns off some styling */\n", | ||
" progress {\n", | ||
" /* gets rid of default border in Firefox and Opera. */\n", | ||
" border: none;\n", | ||
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | ||
" background-size: auto;\n", | ||
" }\n", | ||
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n", | ||
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n", | ||
" }\n", | ||
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | ||
" background: #F44336;\n", | ||
" }\n", | ||
"</style>\n" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"\n", | ||
" <div>\n", | ||
" <progress value='8000' class='' max='8000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | ||
" 100.00% [8000/8000 00:07<00:00 Sampling 4 chains, 0 divergences]\n", | ||
" </div>\n", | ||
" " | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds.\n", | ||
"Chain <xarray.DataArray 'chain' ()>\n", | ||
"array(0)\n", | ||
"Coordinates:\n", | ||
" chain int64 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.\n", | ||
"Chain <xarray.DataArray 'chain' ()>\n", | ||
"array(1)\n", | ||
"Coordinates:\n", | ||
" chain int64 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.\n", | ||
"Chain <xarray.DataArray 'chain' ()>\n", | ||
"array(2)\n", | ||
"Coordinates:\n", | ||
" chain int64 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.\n", | ||
"Chain <xarray.DataArray 'chain' ()>\n", | ||
"array(3)\n", | ||
"Coordinates:\n", | ||
" chain int64 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"with pm.Model() as model:\n", | ||
" base_dist = pm.Normal(\"base_dist\", 0.0, 1.0, shape=(K + 1,))\n", | ||
" sbw, atoms = pmx.dp.DirichletProcess(\"dp\", alpha, base_dist, K, observed=obs)\n", | ||
"\n", | ||
" trace = pm.sample()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"id": "9ca7b453", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"x_plot = np.linspace(-4, 8, num=1001)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 21, | ||
"id": "19b4ca98", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"dirac = np.less.outer(x_plot, trace.posterior[\"atoms\"].values[0, 0])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 26, | ||
"id": "9875702b", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(1000, 20)" | ||
] | ||
}, | ||
"execution_count": 26, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"trace.posterior[\"sbw\"].values[0]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 22, | ||
"id": "ee2fb857", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(1001, 20)" | ||
] | ||
}, | ||
"execution_count": 22, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"dirac" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d00e46c2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "pymc-dev", | ||
"language": "python", | ||
"name": "pymc-dev" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.8" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright 2020 The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from pymc_experimental.dp.dp import DirichletProcess, DirichletProcessMixture | ||
|
||
__all__ = ["DirichletProcess", "DirichletProcessMixture"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Copyright 2020 The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
import pymc as pm | ||
import pytensor.tensor as pt | ||
from pymc.distributions import Mixture | ||
from pymc.model import modelcontext | ||
|
||
__all__ = ["DirichletProcess", "DirichletProcessMixture"] | ||
|
||
|
||
def DirichletProcess(name, alpha, base_dist, K, observed=None, sbw_name=None, atoms_name=None): | ||
r""" | ||
Truncated Dirichlet Process for Bayesian Nonparametric Density Modelling | ||
Parameters | ||
---------- | ||
alpha: tensor_like of float | ||
Scale concentration parameter (alpha > 0) specifying the size of "sticks", or generated | ||
weights, from the stick-breaking process. Ideally, alpha should have a prior and not be | ||
a fixed constant. | ||
base_dist: single batched distribution | ||
The base distribution for a Dirichlet Process. `base_dist` must have shape (K + 1,). | ||
K: int | ||
The truncation parameter for the number of components of the Dirichlet Process Mixture. | ||
The Goldilocks Principle should be used in selecting an appropriate value of K: not too | ||
low to capture all possible clusters and not too high to induce a heavy computational | ||
burden for sampling. | ||
""" | ||
if sbw_name is None: | ||
sbw_name = "sbw" | ||
|
||
if atoms_name is None: | ||
atoms_name = "atoms" | ||
|
||
if observed is not None: | ||
observed = np.asarray(observed) | ||
|
||
if observed.ndim > 1: | ||
raise ValueError("Multi-dimensional Dirichlet Processes are not " "yet supported.") | ||
|
||
N = observed.shape[0] | ||
|
||
try: | ||
modelcontext(None) | ||
except TypeError: | ||
raise ValueError( | ||
"PyMC Dirichlet Processes are only available under a pm.Model() context manager." | ||
) | ||
|
||
sbw = pm.StickBreakingWeights(sbw_name, alpha, K) | ||
|
||
if observed is None: | ||
return sbw, pm.Deterministic(atoms_name, base_dist) | ||
|
||
""" | ||
idx samples a new atom from `base_dist` with probability alpha/(alpha + N) | ||
and an existing atom from `observed` with probability N/(alpha + N). | ||
If a new atom is not sampled, an atom from `observed` is sampled uniformly. | ||
""" | ||
idx = pm.Bernoulli("idx", p=alpha / (alpha + N), shape=(K + 1,)) | ||
atom_selection = pm.Categorical("atom_selection", p=[1 / N] * N, shape=(K + 1,)) | ||
|
||
atoms = pm.Deterministic( | ||
atoms_name, | ||
var=pt.stack([pt.constant(observed)[atom_selection], base_dist], axis=-1)[ | ||
pt.arange(K + 1), idx | ||
], | ||
) | ||
|
||
return sbw, atoms | ||
|
||
|
||
class DirichletProcessMixture(Mixture): | ||
r""" | ||
Truncated Dirichlet Process Mixture | ||
Parameters | ||
---------- | ||
alpha: tensor_like of float | ||
Scale concentration parameter (alpha > 0) specifying the size of "sticks", or generated | ||
weights, from the stick-breaking process. Ideally, alpha should have a prior and not be | ||
a fixed constant. | ||
G0: single batched distribution | ||
The base distribution for a Dirichlet Process Mixture should be created via the | ||
`.dist()` API as this class inherits from `pm.Mixture`. Be sure that the last size | ||
of G0 is K+1. | ||
K: int | ||
The truncation parameter for the number of components of the Dirichlet Process Mixture. | ||
The Goldilocks Principle should be used in selecting an appropriate value of K: not too | ||
low to capture all possible clusters and not too high to induce a heavy computational | ||
burden for sampling. | ||
""" | ||
|
||
def __new__(cls, name, alpha, G0, K, **kwargs): | ||
if "sbw_name" in kwargs: | ||
sbw_name = kwargs["sbw_name"] | ||
else: | ||
sbw_name = f"sbw_{name}" | ||
|
||
model = modelcontext(None) | ||
model.register_rv( | ||
pm.StickBreakingWeights.dist(alpha, K, **kwargs), | ||
sbw_name, | ||
) | ||
return super().__new__(cls, name, w=model[sbw_name], comp_dists=G0, **kwargs) |