Skip to content

Commit

Permalink
Creating a Dirichlet Process Mixture class
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama committed Aug 31, 2022
1 parent 18fe6d0 commit faeed12
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pymc_experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
_log.addHandler(handler)


from pymc_experimental import distributions, gp, utils
from pymc_experimental import distributions, dp, gp, utils
from pymc_experimental.bart import *
from pymc_experimental.dp import *
18 changes: 18 additions & 0 deletions pymc_experimental/dp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pymc_experimental.dp.dp import DirichletProcessMixture

__all__ = ["DirichletProcessMixture"]
53 changes: 53 additions & 0 deletions pymc_experimental/dp/dp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pymc.model import modelcontext
from pymc.distributions import Mixture
import pymc as pm

__all__ = ["DirichletProcessMixture"]


class DirichletProcessMixture(Mixture):
    r"""
    Truncated Dirichlet Process Mixture for Bayesian Nonparametric Density Modelling

    Creating an instance registers a ``pm.StickBreakingWeights`` variable in the
    current model context and then builds a ``pm.Mixture`` that uses that variable
    as its weights ``w`` and ``G0`` as its component distribution.

    Parameters
    ----------
    name: str
        Name given to the mixture variable. It is also used to build the default
        name of the stick-breaking weights (see ``sbw_name``).
    alpha: tensor_like of float
        Scale concentration parameter (alpha > 0) specifying the size of "sticks", or generated
        weights, from the stick-breaking process. Ideally, alpha should have a prior and not be
        a fixed constant.
    G0: single batched distribution
        The base distribution for a Dirichlet Process Mixture should be created via the
        `.dist()` API as this class inherits from `pm.Mixture`. Be sure that the last
        dimension of G0 has size K+1.
    K: int
        The truncation parameter for the number of components of the Dirichlet Process Mixture.
        The Goldilocks Principle should be used in selecting an appropriate value of K: not too
        low to capture all possible clusters and not too high to induce a heavy computational
        burden for sampling.
    sbw_name: str, optional
        Name under which the stick-breaking weights are registered in the model.
        Defaults to ``f"sbw_{name}"``.
    **kwargs
        Remaining keyword arguments are forwarded to both
        ``pm.StickBreakingWeights.dist`` and ``pm.Mixture``.
    """

    def __new__(cls, name, alpha, G0, K, **kwargs):
        # ``sbw_name`` is an option of this class only. Pop it so that it is not
        # forwarded below to callables that do not declare such a keyword.
        sbw_name = kwargs.pop("sbw_name", f"sbw_{name}")

        # Resolve the model from the enclosing ``with pm.Model():`` context.
        model = modelcontext(None)

        # NOTE(review): the same ``kwargs`` are passed to the weights' ``.dist()``
        # call and to the mixture itself. Mixture-only keywords such as
        # ``observed`` are presumably not accepted by ``.dist()`` -- confirm and,
        # if so, stop forwarding them here.
        model.register_rv(
            pm.StickBreakingWeights.dist(alpha, K, **kwargs),
            sbw_name,
        )

        # Look the weights up through the model so the registered variable is used.
        return super().__new__(cls, name, w=model[sbw_name], comp_dists=G0, **kwargs)

0 comments on commit faeed12

Please sign in to comment.