Skip to content

Commit

Permalink
Refactor likelihood, add small example
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Dec 13, 2024
1 parent ab3b4ed commit 3df6534
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 198 deletions.
9 changes: 9 additions & 0 deletions pymc_experimental/model/modular/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pymc_experimental.model.modular.components import Intercept, Regression, Spline
from pymc_experimental.model.modular.likelihood import NormalLikelihood

__all__ = [
"Intercept",
"Regression",
"Spline",
"NormalLikelihood",
]
9 changes: 5 additions & 4 deletions pymc_experimental/model/modular/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import pymc as pm
import pytensor.tensor as pt

from model.modular.utilities import (
from patsy import dmatrix
from pytensor.graph import Apply, Op

from pymc_experimental.model.modular.utilities import (
PRIOR_DEFAULT_KWARGS,
ColumnType,
PoolingType,
Expand All @@ -13,14 +16,12 @@
make_hierarchical_prior,
select_data_columns,
)
from patsy import dmatrix
from pytensor.graph import Apply, Op


class GLMModel(ABC):
"""Base class for GLM components. Subclasses should implement the build method to construct the component."""

def __init__(self, name):
def __init__(self, name=None):
self.model = None
self.compiled = False
self.name = name
Expand Down
210 changes: 29 additions & 181 deletions pymc_experimental/model/modular/likelihood.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from io import StringIO
from typing import Literal, get_args

import arviz as az
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import rich

from pymc.backends.arviz import apply_function_over_dataset
from pymc.model.fgraph import clone_model
Expand All @@ -14,6 +16,7 @@

from pymc_experimental.model.marginal.marginal_model import MarginalModel
from pymc_experimental.model.modular.utilities import ColumnType, encode_categoricals
from pymc_experimental.printing import model_table

LIKELIHOOD_TYPES = Literal["lognormal", "logt", "mixture", "unmarginalized-mixture"]
valid_likelihoods = get_args(LIKELIHOOD_TYPES)
Expand Down Expand Up @@ -43,7 +46,7 @@ def __init__(self, target_col: ColumnType, data: pd.DataFrame):

X_df = data.drop(columns=[target_col])

self.obs_dim = data.index.name
self.obs_dim = data.index.name if data.index.name is not None else "obs_idx"
self.coords = {
self.obs_dim: data.index.values,
}
Expand All @@ -70,6 +73,10 @@ def sample(self, **sample_kwargs):
with self.model:
return pm.sample(**sample_kwargs)

def sample_prior_predictive(self, **sample_kwargs):
with self.model:
return pm.sample_prior_predictive(**sample_kwargs)

def predict(
self,
idata: az.InferenceData,
Expand Down Expand Up @@ -137,212 +144,53 @@ def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalMo
"""Return the type on model used by the likelihood function"""
raise NotImplementedError

def register_mu(
self,
*,
df: pd.DataFrame,
mu=None,
):
def register_mu(self, mu=None):
with self.model:
if mu is not None:
return pm.Deterministic("mu", mu.build(df=df), dims=[self.obs_dim])
return pm.Deterministic("mu", mu.build(self.model), dims=[self.obs_dim])
return pm.Normal("mu", 0, 100)

def register_sigma(
self,
*,
df: pd.DataFrame,
sigma=None,
):
def register_sigma(self, sigma=None):
with self.model:
if sigma is not None:
return pm.Deterministic("sigma", pt.exp(sigma.build(df=df)), dims=[self.obs_dim])
return pm.Exponential("sigma", lam=1)


class LogNormalLikelihood(Likelihood):
"""Class to represent a log-normal likelihood function for a GLM component."""

def __init__(
self,
mu,
sigma,
target_col: ColumnType,
data: pd.DataFrame,
):
super().__init__(target_col=target_col, data=data)

with self.model:
self.register_data(data[target_col])
mu = self.register_mu(mu)
sigma = self.register_sigma(sigma)

pm.LogNormal(
target_col,
mu=mu,
sigma=sigma,
observed=self.model[f"{target_col}_observed"],
dims=[self.obs_dim],
)

def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
return pm.Model(coords=coords)


class LogTLikelihood(Likelihood):
"""
Class to represent a log-t likelihood function for a GLM component.
"""

def __init__(
self,
mu,
*,
sigma=None,
nu=None,
target_col: ColumnType,
data: pd.DataFrame,
):
def log_student_t(nu, mu, sigma, shape=None):
return pm.math.exp(pm.StudentT.dist(mu=mu, sigma=sigma, nu=nu, shape=shape))

super().__init__(target_col=target_col, data=data)

with self.model:
mu = self.register_mu(mu=mu, df=data)
sigma = self.register_sigma(sigma=sigma, df=data)
nu = self.register_nu(nu=nu, df=data)

pm.CustomDist(
target_col,
nu,
mu,
sigma,
observed=self.model[f"{target_col}_observed"],
shape=mu.shape,
dims=[self.obs_dim],
dist=log_student_t,
class_name="LogStudentT",
)

def register_nu(self, *, df, nu=None):
with self.model:
if nu is not None:
return pm.Deterministic("nu", pt.exp(nu.build(df=df)), dims=[self.obs_dim])
return pm.Uniform("nu", 2, 30)

def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
return pm.Model(coords=coords)


class BaseMixtureLikelihood(Likelihood):
"""
Base class for mixture likelihood functions to hold common methods for registering parameters.
"""

def register_sigma(self, *, df, sigma=None):
with self.model:
if sigma is None:
sigma_not_outlier = pm.Exponential("sigma_not_outlier", lam=1)
else:
sigma_not_outlier = pm.Deterministic(
"sigma_not_outlier", pt.exp(sigma.build(df=df)), dims=[self.obs_dim]
)
sigma_outlier_offset = pm.Gamma("sigma_outlier_offset", mu=0.2, sigma=0.5)
sigma = pm.Deterministic(
"sigma",
pt.as_tensor([sigma_not_outlier, sigma_not_outlier * (1 + sigma_outlier_offset)]),
dims=["outlier"],
)

return sigma

def register_p_outlier(self, *, df, p_outlier=None, **param_kwargs):
mean_p = param_kwargs.get("mean_p", 0.1)
concentration = param_kwargs.get("concentration", 50)

with self.model:
if p_outlier is not None:
return pm.Deterministic(
"p_outlier", pt.sigmoid(p_outlier.build(df=df)), dims=[self.obs_dim]
"sigma", pt.exp(sigma.build(self.model)), dims=[self.obs_dim]
)
return pm.Beta("p_outlier", mean_p * concentration, (1 - mean_p) * concentration)

def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
coords["outlier"] = [False, True]
return MarginalModel(coords=coords)

return pm.Exponential("sigma", lam=1)

class MixtureLikelihood(BaseMixtureLikelihood):
"""
Class to represent a mixture likelihood function for a GLM component. The mixture is implemented using pm.Mixture,
and does not allow for automatic marginalization of components.
"""
def __repr__(self):
table = model_table(self.model)
buffer = StringIO()
rich.print(table, file=buffer)

def __init__(
self,
mu,
sigma,
p_outlier,
target_col: ColumnType,
data: pd.DataFrame,
):
super().__init__(target_col=target_col, data=data)
return buffer.getvalue()

with self.model:
mu = self.register_mu(mu)
sigma = self.register_sigma(sigma)
p_outlier = self.register_p_outlier(p_outlier)
def to_graphviz(self):
return self.model.to_graphviz()

pm.Mixture(
target_col,
w=[1 - p_outlier, p_outlier],
comp_dists=pm.LogNormal.dist(mu[..., None], sigma=sigma.T),
shape=mu.shape,
observed=self.model[f"{target_col}_observed"],
dims=[self.obs_dim],
)
# def _repr_html_(self):
# return model_table(self.model)


class UnmarginalizedMixtureLikelihood(BaseMixtureLikelihood):
class NormalLikelihood(Likelihood):
"""
Class to represent an unmarginalized mixture likelihood function for a GLM component. The mixture is implemented using
a MarginalModel, and allows for automatic marginalization of components.
A model with normally distributed errors
"""

def __init__(
self,
mu,
sigma,
p_outlier,
target_col: ColumnType,
data: pd.DataFrame,
):
def __init__(self, mu, sigma, target_col: ColumnType, data: pd.DataFrame):
super().__init__(target_col=target_col, data=data)

with self.model:
mu = self.register_mu(mu)
sigma = self.register_sigma(sigma)
p_outlier = self.register_p_outlier(p_outlier)

is_outlier = pm.Bernoulli(
"is_outlier",
p_outlier,
dims=["cusip"],
# shape=X_pt.shape[0], # Uncomment after https://github.com/pymc-devs/pymc-experimental/pull/304
)

pm.LogNormal(
pm.Normal(
target_col,
mu=mu,
sigma=pm.math.switch(is_outlier, sigma[1], sigma[0]),
sigma=sigma,
observed=self.model[f"{target_col}_observed"],
shape=mu.shape,
dims=[data.index.name],
dims=[self.obs_dim],
)

self.model.marginalize(["is_outlier"])

def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
coords["outlier"] = [False, True]
return MarginalModel(coords=coords)
return pm.Model(coords=coords)
13 changes: 0 additions & 13 deletions pymc_experimental/model/modular/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def select_data_columns(
Returns
-------
X: TensorVariable
A tensor variable representing the selected columns of the independent data
"""
model = pm.modelcontext(model)
Expand Down Expand Up @@ -350,9 +349,6 @@ def make_unpooled_hierarchy(
):
coords = model.coords

sigma_dist = hierarchy_kwargs.pop("sigma_dist", "Gamma")
sigma_kwargs = hierarchy_kwargs.pop("sigma_kwargs", {"alpha": 2, "beta": 1})

if X.ndim == 1:
X = X[:, None]

Expand All @@ -367,17 +363,8 @@ def make_unpooled_hierarchy(
beta = Prior(f"{name}_mu", **prior_kwargs, dims=dims)

for i, (last_level, level) in enumerate(itertools.pairwise([None, *levels])):
if i == 0:
sigma_dims = dims
else:
sigma_dims = [*dims, last_level] if dims is not None else [last_level]
beta_dims = [*dims, level] if dims is not None else [level]

sigma = make_sigma(f"{name}_{level}_effect", sigma_dist, sigma_kwargs, sigma_dims)

prior_kwargs["mu"] = beta[..., idx_maps[i]]
scale_name = "b" if prior == "Laplace" else "sigma"
prior_kwargs[scale_name] = sigma[..., idx_maps[i]]

beta = Prior(f"{name}_{level}_effect", **prior_kwargs, dims=beta_dims)

Expand Down
31 changes: 31 additions & 0 deletions tests/model/modular/test_likelihood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
import pandas as pd
import pytest

from pymc_experimental.model.modular.likelihood import NormalLikelihood


@pytest.fixture(scope="session")
def rng():
return np.random.default_rng()


@pytest.fixture(scope="session")
def data(rng):
city = ["A", "B", "C"]
race = ["white", "black", "hispanic"]

df = pd.DataFrame(
{
"city": np.random.choice(city, 1000),
"age": rng.normal(size=1000),
"race": rng.choice(race, size=1000),
"income": rng.normal(size=1000),
}
)
return df


def test_normal_likelihood(data):
model = NormalLikelihood(mu=None, sigma=None, target_col="income", data=data)
idata = model.sample_prior_predictive()

0 comments on commit 3df6534

Please sign in to comment.