Adding Dirichlet Process submodule
larryshamalama committed Mar 12, 2023
1 parent 84aa791 commit d738d5c
Showing 6 changed files with 646 additions and 1 deletion.
224 changes: 224 additions & 0 deletions notebooks/dirichlet-process.ipynb
@@ -0,0 +1,224 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ddde367f",
"metadata": {},
"outputs": [],
"source": [
"import pymc as pm\n",
"import numpy as np\n",
"import pymc_experimental as pmx"
]
},
{
"cell_type": "markdown",
"id": "45731819",
"metadata": {},
"source": [
"I follow **Example 9** from [here](https://projecteuclid.org/ebooks/nsf-cbms-regional-conference-series-in-probability-and-statistics/Nonparametric-Bayesian-Inference/Chapter/Chapter-3-Dirichlet-Process/10.1214/cbms/1362163748) (displayed below) and attempt to also replicate the results shown in Figure 3.2. They assume that data are drawn i.i.d. from $\\mathcal{N}(2, 4)$ but assume a base distribution $G_0 = \\mathcal{N}(0, 1)$.\n",
"\n",
"<img src=\"dp-example-9.png\" width=500>"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "93713949",
"metadata": {},
"outputs": [],
"source": [
"alpha = 5.0 # concentration parameter\n",
"K = 19 # truncation parameter\n",
"\n",
"rng = np.random.default_rng(seed=34)\n",
"obs = rng.normal(2.0, 2.0, size=50)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7fab0297",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Multiprocess sampling (4 chains in 4 jobs)\n",
"CompoundStep\n",
">NUTS: [base_dist, sbw]\n",
">BinaryGibbsMetropolis: [idx]\n",
">CategoricalGibbsMetropolis: [atom_selection]\n"
]
},
{
"data": {
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <progress value='8000' class='' max='8000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [8000/8000 00:07&lt;00:00 Sampling 4 chains, 0 divergences]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds.\n",
"Chain <xarray.DataArray 'chain' ()>\n",
"array(0)\n",
"Coordinates:\n",
" chain int64 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.\n",
"Chain <xarray.DataArray 'chain' ()>\n",
"array(1)\n",
"Coordinates:\n",
" chain int64 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.\n",
"Chain <xarray.DataArray 'chain' ()>\n",
"array(2)\n",
"Coordinates:\n",
" chain int64 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.\n",
"Chain <xarray.DataArray 'chain' ()>\n",
"array(3)\n",
"Coordinates:\n",
" chain int64 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.\n"
]
}
],
"source": [
"with pm.Model() as model:\n",
" base_dist = pm.Normal(\"base_dist\", 0.0, 1.0, shape=(K + 1,))\n",
" sbw, atoms = pmx.dp.DirichletProcess(\"dp\", alpha, base_dist, K, observed=obs)\n",
"\n",
" trace = pm.sample()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9ca7b453",
"metadata": {},
"outputs": [],
"source": [
"x_plot = np.linspace(-4, 8, num=1001)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "19b4ca98",
"metadata": {},
"outputs": [],
"source": [
"dirac = np.less.outer(x_plot, trace.posterior[\"atoms\"].values[0, 0])"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "9875702b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 20)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trace.posterior[\"sbw\"].values[0]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "ee2fb857",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1001, 20)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dirac"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d00e46c2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pymc-dev",
"language": "python",
"name": "pymc-dev"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
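The notebook stops short of drawing the Figure 3.2 style plot. A minimal sketch of how posterior draws of the random CDF might be assembled from the trace is given below; it is not part of the commit, assumes `trace` and `x_plot` from the notebook above, adds a matplotlib import, and uses np.greater_equal (rather than the np.less used for `dirac`) so that each weighted sum is a draw of F(x) = sum_k w_k * 1{atom_k <= x}.

import matplotlib.pyplot as plt

sbw_draws = trace.posterior["sbw"].values[0]     # first chain, shape (1000, K + 1)
atom_draws = trace.posterior["atoms"].values[0]  # first chain, shape (1000, K + 1)

fig, ax = plt.subplots()
for i in range(0, sbw_draws.shape[0], 50):
    # indicator 1{atom_k <= x} on the plotting grid, one column per atom
    steps = np.greater_equal.outer(x_plot, atom_draws[i])
    # one posterior draw of the random CDF: weights dotted with the indicators
    ax.step(x_plot, steps @ sbw_draws[i], color="C0", alpha=0.1)
ax.set_xlabel("x")
ax.set_ylabel("F(x)")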
Binary file added notebooks/dp-example-9.png
284 changes: 284 additions & 0 deletions notebooks/dp-posterior-numpy.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pymc_experimental/__init__.py
@@ -26,6 +26,6 @@
_log.addHandler(handler)


from pymc_experimental import distributions, gp, utils
from pymc_experimental import distributions, dp, gp, utils
from pymc_experimental.inference.fit import fit
from pymc_experimental.marginal_model import MarginalModel
18 changes: 18 additions & 0 deletions pymc_experimental/dp/__init__.py
@@ -0,0 +1,18 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pymc_experimental.dp.dp import DirichletProcess, DirichletProcessMixture

__all__ = ["DirichletProcess", "DirichletProcessMixture"]
119 changes: 119 additions & 0 deletions pymc_experimental/dp/dp.py
@@ -0,0 +1,119 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc.distributions import Mixture
from pymc.model import modelcontext

__all__ = ["DirichletProcess", "DirichletProcessMixture"]


def DirichletProcess(name, alpha, base_dist, K, observed=None, sbw_name=None, atoms_name=None):
r"""
Truncated Dirichlet Process for Bayesian Nonparametric Density Modelling
Parameters
----------
alpha: tensor_like of float
Scale concentration parameter (alpha > 0) specifying the size of "sticks", or generated
weights, from the stick-breaking process. Ideally, alpha should have a prior and not be
a fixed constant.
base_dist: single batched distribution
The base distribution for a Dirichlet Process. `base_dist` must have shape (K + 1,).
K: int
The truncation parameter for the number of components of the Dirichlet Process Mixture.
The Goldilocks Principle should be used in selecting an appropriate value of K: not too
low to capture all possible clusters and not too high to induce a heavy computational
burden for sampling.
"""
if sbw_name is None:
sbw_name = "sbw"

if atoms_name is None:
atoms_name = "atoms"

if observed is not None:
observed = np.asarray(observed)

if observed.ndim > 1:
raise ValueError("Multi-dimensional Dirichlet Processes are not " "yet supported.")

N = observed.shape[0]

try:
modelcontext(None)
except TypeError:
raise ValueError(
"PyMC Dirichlet Processes are only available under a pm.Model() context manager."
)

sbw = pm.StickBreakingWeights(sbw_name, alpha, K)

if observed is None:
return sbw, pm.Deterministic(atoms_name, base_dist)

"""
idx samples a new atom from `base_dist` with probability alpha/(alpha + N)
and an existing atom from `observed` with probability N/(alpha + N).
If a new atom is not sampled, an atom from `observed` is sampled uniformly.
"""
idx = pm.Bernoulli("idx", p=alpha / (alpha + N), shape=(K + 1,))
atom_selection = pm.Categorical("atom_selection", p=[1 / N] * N, shape=(K + 1,))

atoms = pm.Deterministic(
atoms_name,
var=pt.stack([pt.constant(observed)[atom_selection], base_dist], axis=-1)[
pt.arange(K + 1), idx
],
)

return sbw, atoms
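
# A draw from the truncated process returned above is the discrete random measure
#     G = sum_{k=0}^{K} sbw[k] * delta(atoms[k]),
# so posterior summaries of G (its CDF, mean functionals, etc.) can be computed directly
# from the sampled `sbw` and `atoms` variables, as in the accompanying notebooks.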


class DirichletProcessMixture(Mixture):
r"""
Truncated Dirichlet Process Mixture
Parameters
----------
alpha: tensor_like of float
Scale concentration parameter (alpha > 0) specifying the size of "sticks", or generated
weights, from the stick-breaking process. Ideally, alpha should have a prior and not be
a fixed constant.
G0: single batched distribution
The base distribution for a Dirichlet Process Mixture should be created via the
`.dist()` API as this class inherits from `pm.Mixture`. Be sure that the last size
of G0 is K+1.
K: int
The truncation parameter for the number of components of the Dirichlet Process Mixture.
The Goldilocks Principle should be used in selecting an appropriate value of K: not too
low to capture all possible clusters and not too high to induce a heavy computational
burden for sampling.
"""

def __new__(cls, name, alpha, G0, K, **kwargs):
if "sbw_name" in kwargs:
sbw_name = kwargs["sbw_name"]
else:
sbw_name = f"sbw_{name}"

model = modelcontext(None)
model.register_rv(
            pm.StickBreakingWeights.dist(alpha, K),
sbw_name,
)
return super().__new__(cls, name, w=model[sbw_name], comp_dists=G0, **kwargs)
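
For completeness, a hypothetical usage sketch for DirichletProcessMixture (not part of this commit; the data, priors, and variable names below are illustrative only):

import numpy as np
import pymc as pm
from pymc_experimental.dp import DirichletProcessMixture

K = 19
rng = np.random.default_rng(34)
obs = rng.normal(2.0, 2.0, size=50)  # same synthetic data as the notebook above

with pm.Model():
    # base distribution created via .dist(), with last dimension of length K + 1
    G0 = pm.Normal.dist(0.0, 1.0, shape=(K + 1,))
    y = DirichletProcessMixture("y", alpha=5.0, G0=G0, K=K, observed=obs)
    idata = pm.sample()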
