From 31ce2f2ae792f8da9706b2daaf54a6c8cd4d7d17 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Fri, 6 Dec 2024 20:14:00 +0800
Subject: [PATCH] initial commit

---
 pymc_experimental/model/modular/__init__.py   |   0
 pymc_experimental/model/modular/components.py | 590 ++++++++++++++++++
 pymc_experimental/model/modular/likelihood.py | 359 +++++++++++
 pymc_experimental/model/modular/utilities.py  | 276 ++++++++
 4 files changed, 1225 insertions(+)
 create mode 100644 pymc_experimental/model/modular/__init__.py
 create mode 100644 pymc_experimental/model/modular/components.py
 create mode 100644 pymc_experimental/model/modular/likelihood.py
 create mode 100644 pymc_experimental/model/modular/utilities.py

diff --git a/pymc_experimental/model/modular/__init__.py b/pymc_experimental/model/modular/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pymc_experimental/model/modular/components.py b/pymc_experimental/model/modular/components.py
new file mode 100644
index 00000000..16315285
--- /dev/null
+++ b/pymc_experimental/model/modular/components.py
@@ -0,0 +1,590 @@
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from typing import Literal, get_args
+
+import pandas as pd
+import pymc as pm
+import pytensor.tensor as pt
+
+from patsy import dmatrix
+
+from pymc_experimental.model.modular.utilities import ColumnType, hierarchical_prior_to_requested_depth
+
+POOLING_TYPES = Literal["none", "complete", "partial"]
+valid_pooling = get_args(POOLING_TYPES)
+
+CURVE_TYPES = Literal["log", "abc", "ns", "nss", "box-cox"]
+valid_curves = get_args(CURVE_TYPES)
+
+
+# Parameter names per curve type; build_curve consumes 3 parameters for "ns" and 5 for "nss"
+FEATURE_DICT = {
+    "log": ["slope"],
+    "box-cox": ["lambda", "slope", "intercept"],
+    "ns": ["tau", "beta1", "beta2"],
+    "nss": ["tau1", "tau2", "beta1", "beta2", "beta3"],
+    "abc": ["a", "b", "c"],
+}
+
+
+def _validate_pooling_params(pooling_columns: ColumnType, pooling: POOLING_TYPES):
+    """
+    Helper function to validate inputs to a GLM component.
+
+    Parameters
+    ----------
+    pooling_columns: str or sequence of str, optional
+        Column(s) used to build hierarchical priors
+
+    pooling: str
+        Type of pooling to use in the component
+
+    Returns
+    -------
+    None
+    """
+    if pooling_columns is not None and pooling == "complete":
+        raise ValueError("Index data provided but complete pooling was requested")
+    if pooling_columns is None and pooling != "complete":
+        raise ValueError(
+            "Index data must be provided for partial pooling (pooling = 'partial') or no pooling "
+            "(pooling = 'none')"
+        )
+
+
+def _get_x_cols(
+    cols: str | Sequence[str],
+    model: pm.Model | None = None,
+) -> pt.TensorVariable:
+    model = pm.modelcontext(model)
+    # Don't upcast a single column to a column matrix
+    if isinstance(cols, str):
+        [cols_idx] = [i for i, col in enumerate(model.coords["feature"]) if col == cols]
+    else:
+        cols_idx = [i for i, col in enumerate(model.coords["feature"]) if col in cols]
+    return model["X_data"][:, cols_idx]
+
+
+class GLMModel(ABC):
+    """Base class for GLM components.
+
+    Subclasses should implement the ``build`` method to construct the component.
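+
+    Components compose with ``+`` and ``*``. A minimal sketch (assuming the ``Intercept`` and ``Regression``
+    components defined below, and a feature DataFrame ``X``):
+
+    .. code-block:: python
+
+        mu = Intercept(pooling_cols="country", pooling="partial") + Regression("controls", X=X)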
+    """
+
+    def __init__(self):
+        self.model = None
+        self.compiled = False
+
+    @abstractmethod
+    def build(self, model=None):
+        pass
+
+    def __add__(self, other):
+        return AdditiveGLMComponent(self, other)
+
+    def __mul__(self, other):
+        return MultiplicativeGLMComponent(self, other)
+
+
+class AdditiveGLMComponent(GLMModel):
+    """Class to represent an additive combination of GLM components"""
+
+    def __init__(self, left, right):
+        self.left = left
+        self.right = right
+        super().__init__()
+
+    def build(self, *args, **kwargs):
+        return self.left.build(*args, **kwargs) + self.right.build(*args, **kwargs)
+
+
+class MultiplicativeGLMComponent(GLMModel):
+    """Class to represent a multiplicative combination of GLM components"""
+
+    def __init__(self, left, right):
+        self.left = left
+        self.right = right
+        super().__init__()
+
+    def build(self, *args, **kwargs):
+        return self.left.build(*args, **kwargs) * self.right.build(*args, **kwargs)
+
+
+class Intercept(GLMModel):
+    def __init__(
+        self,
+        name: str | None = None,
+        *,
+        pooling_cols: ColumnType = None,
+        pooling: POOLING_TYPES = "complete",
+        hierarchical_params: dict | None = None,
+        prior: str = "Normal",
+        prior_params: dict | None = None,
+    ):
+        """
+        Class to represent an intercept term in a GLM model.
+
+        Here, an intercept is any constant term in the model that is not a function of the input data. It can be
+        a single constant, or a hierarchical prior that creates fixed effects across the levels of one or more
+        categorical variables.
+
+        Parameters
+        ----------
+        name: str, optional
+            Name of the intercept term. If None, a default name is generated based on the pooling columns.
+        pooling_cols: str or sequence of str, optional
+            Column(s) of the model's feature data used to build hierarchical priors. If there are multiple
+            columns, they are treated as levels of a "telescoping" hierarchy, with the leftmost column
+            representing the top level of the hierarchy, and depth increasing to the right.
+        pooling: str, one of ["none", "complete", "partial"], default "complete"
+            Type of pooling to use for the intercept term. If "none", no pooling is applied, and each group in
+            the pooling columns is treated as independent. If "complete", complete pooling is applied, and all
+            data are treated as coming from the same group. If "partial", a hierarchical prior is constructed
+            that shares information across groups in the pooling columns.
+        hierarchical_params: dict, optional
+            Additional keyword arguments to configure priors in the hierarchical_prior_to_requested_depth
+            function. Options include:
+                sigma_dist: str
+                    Name of the distribution to use for the standard deviation of the hierarchy. Default is
+                    "Gamma"
+                sigma_kwargs: dict
+                    Additional keyword arguments to pass to the sigma distribution specified by the sigma_dist
+                    argument. Default is {"alpha": 2, "beta": 1}
+                offset_dist: str, one of ["zerosum", "normal", "laplace"]
+                    Name of the distribution to use for the offset distribution. Default is "zerosum"
+        prior: str, optional
+            Name of the PyMC distribution to use for the intercept term. Default is "Normal".
+        prior_params: dict, optional
+            Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
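+
+        Examples
+        --------
+        A sketch (assuming the model's feature data include a categorical "country" column):
+
+        .. code-block:: python
+
+            intercept = Intercept(pooling_cols="country", pooling="partial")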
Default is "zerosum" + """ + _validate_pooling_params(pooling_cols, pooling) + + self.pooling_cols = pooling_cols + self.hierarchical_params = hierarchical_params if hierarchical_params is not None else {} + self.pooling = pooling if pooling_cols is not None else "complete" + + self.prior = prior + self.prior_params = prior_params if prior_params is not None else {} + + if pooling_cols is None: + pooling_cols = [] + elif isinstance(pooling_cols, str): + pooling_cols = [pooling_cols] + + data_name = ", ".join(pooling_cols) + self.name = name or f"Constant(pooling_cols={data_name})" + super().__init__() + + def build(self, model=None): + model = pm.modelcontext(model) + with model: + if self.pooling == "complete": + intercept = getattr(pm, self.prior)(f"{self.name}", **self.prior_params) + return intercept + + [i for i, col in enumerate(model.coords["feature"]) if col in self.pooling_cols] + + intercept = hierarchical_prior_to_requested_depth( + self.name, + model.X_df[self.pooling_cols], # TODO: Reconsider this + model=model, + dims=None, + no_pooling=self.pooling == "none", + **self.hierarchical_params, + ) + return intercept + + +def build_curve( + time_pt: pt.TensorVariable, + beta: pt.TensorVariable, + curve_type: Literal["log", "abc", "ns", "nss", "box-cox"], +): + """ + Build a curve based on the time data and parameters beta. + + In this context, a "curve" is a deterministic function that maps time to a value. The curve should (in general) be + strictly increasing with time (df(t)/dt > 0), and should (in general) exhibit diminishing marginal growth with time + (d^2f(t)/dt^2 < 0). These properties are not strictly necessary; some curve functions (such as nss) allow for + local reversals. + + Parameters + ---------- + time_pt: TensorVariable + A pytensor variable representing the time data to build the curve from. + beta: TensorVariable + A pytensor variable representing the parameters of the curve. The number of parameters and their meaning depend + on the curve_type. + + .. warning:: + Currently no checks are in place to ensure that the number of parameters in beta matches the expected number + for the curve_type. + + curve_type: str, one of ["log", "abc", "ns", "nss", "box-cox"] + Type of curve to build. Options are: + + - "log": + A simple log-linear curve. The curve is defined as: + + .. math:: + + \beta \\log(t) + + - "abc": + A curve parameterized by "a", "b", and "c", such that the minimum value of the curve is "a", the + maximum value is "a + b", and the inflection point is "a + b / c". "C" thus controls the speed of change + from the minimum to the maximum value. The curve is defined as: + + .. math:: + + \frac{a + bc t}{1 + ct} + + - "ns": + The Nelson-Siegel yield curve model. The curve is parameterized by three parameters: :math:`\tau`, + :math:`\beta_1`, and :math:`\beta_2`. :math:`\tau` is the decay rate of the exponential term, and + :math:`\beta_1` and :math:`\beta_2` control the slope and curvature of the curve. The curve is defined as: + + .. math:: + + \begin{align} + x_t &= \beta_1 \\phi(t) + \beta_2 \\left (\\phi(t) - \\exp(-t/\tau) \right ) \\ + \\phi(t) &= \frac{1 - \\exp(-t/\tau)}{t/\tau} + \\end{align} + + - "nss": + The Nelson-Siegel-Svensson yield curve model. The curve is parameterized by four parameters: + :math:`\tau_1`, :math:`\tau_2`, :math:`\beta_1`, and :math:`\beta_2`. 
+            The Nelson-Siegel-Svensson yield curve model. The curve is parameterized by five parameters:
+            :math:`\tau_1`, :math:`\tau_2`, :math:`\beta_1`, :math:`\beta_2`, and :math:`\beta_3`, where
+            :math:`\tau_1` and :math:`\tau_2` are the decay rates of the two exponential terms, :math:`\beta_1`
+            controls the slope of the curve, and :math:`\beta_2` and :math:`\beta_3` control the curvature of
+            the curve. To ensure that short-term rates are strictly positive, one typically restricts
+            :math:`\beta_1 + \beta_2 > 0`.
+
+            The curve is defined as:
+
+            .. math::
+
+                \begin{align}
+                x_t &= \beta_1 \phi_1(t) + \beta_2 \left(\phi_1(t) - \exp(-t/\tau_1)\right)
+                       + \beta_3 \left(\phi_2(t) - \exp(-t/\tau_2)\right) \\
+                \phi_1(t) &= \frac{1 - \exp(-t/\tau_1)}{t/\tau_1} \\
+                \phi_2(t) &= \frac{1 - \exp(-t/\tau_2)}{t/\tau_2}
+                \end{align}
+
+            Note that this definition omits the constant term that is typically included in the
+            Nelson-Siegel-Svensson model; you are assumed to have already accounted for this with another
+            component in the model.
+
+        - "box-cox":
+            A curve that applies a box-cox transformation to the time data. The curve is parameterized by two
+            parameters: :math:`\lambda` and :math:`\beta`, where :math:`\lambda` is the box-cox parameter that
+            interpolates between the log and linear transformations, and :math:`\beta` is the slope of the curve.
+
+            The curve is defined as:
+
+            .. math::
+
+                \beta \left(\frac{t^{\lambda} - 1}{\lambda}\right)
+
+    Returns
+    -------
+    TensorVariable
+        A pytensor variable representing the curve.
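+
+    Examples
+    --------
+    A minimal symbolic-evaluation sketch (the parameter values here are illustrative only):
+
+    .. code-block:: python
+
+        import numpy as np
+
+        t = pt.vector("t")
+        beta = pt.vector("beta")  # for curve_type="ns": [log(tau), beta1, beta2]
+        curve = build_curve(t, beta, "ns")
+        curve.eval({t: np.arange(1.0, 6.0), beta: np.array([0.0, 1.0, 0.5])})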
+    """
+    if curve_type == "box-cox":
+        lam = beta[0] + 1e-12
+        time_scaled = (time_pt**lam - 1) / lam
+        curve = beta[1] * time_scaled
+
+    elif curve_type == "log":
+        time_scaled = pt.log(time_pt)
+        curve = beta[0] * time_scaled
+
+    elif curve_type == "ns":
+        tau = pt.exp(beta[0])
+        t_over_tau = time_pt / tau
+        time_scaled = (1 - pt.exp(-t_over_tau)) / t_over_tau
+        curve = beta[1] * time_scaled + beta[2] * (time_scaled - pt.exp(-t_over_tau))
+
+    elif curve_type == "nss":
+        tau = pt.exp(beta[:2])
+        beta = beta[2:]
+        t_over_tau_1 = time_pt / tau[0]
+        t_over_tau_2 = time_pt / tau[1]
+        time_scaled_1 = (1 - pt.exp(-t_over_tau_1)) / t_over_tau_1
+        time_scaled_2 = (1 - pt.exp(-t_over_tau_2)) / t_over_tau_2
+        curve = (
+            beta[0] * time_scaled_1
+            + beta[1] * (time_scaled_1 - pt.exp(-t_over_tau_1))
+            + beta[2] * (time_scaled_2 - pt.exp(-t_over_tau_2))
+        )
+
+    elif curve_type == "abc":
+        curve = (beta[0] + beta[1] * beta[2] * time_pt) / (1 + beta[2] * time_pt)
+
+    else:
+        raise ValueError(f"Unknown curve type: {curve_type}")
+
+    return curve
+
+
+class Curve(GLMModel):
+    def __init__(
+        self,
+        name: str,
+        t: pd.Series | pd.DataFrame,
+        prior: str = "Normal",
+        index_data: pd.Series | pd.DataFrame | None = None,
+        pooling: POOLING_TYPES = "complete",
+        curve_type: CURVE_TYPES = "log",
+        prior_params: dict | None = None,
+        hierarchical_params: dict | None = None,
+    ):
+        """
+        Class to represent a curve in a GLM model.
+
+        A curve is a deterministic function that transforms time data via a non-linear function. Currently, the
+        following curve types are supported:
+
+        - "log": A simple log-linear curve.
+        - "abc": A curve that starts at a minimum value (a) and approaches a maximum value (b), with transition
+          speed controlled by c.
+        - "ns": The Nelson-Siegel yield curve model.
+        - "nss": The Nelson-Siegel-Svensson yield curve model.
+        - "box-cox": A curve that applies a box-cox transformation to the time data.
+
+        Parameters
+        ----------
+        name: str
+            Name of the curve term.
+        t: Series or DataFrame
+            Time data used to build the curve. If a Series, it must have a name attribute. If a DataFrame, it
+            must have exactly one column.
+        index_data: Series or DataFrame, optional
+            Index data used to build hierarchical priors. If there are multiple columns, the columns are treated
+            as levels of a "telescoping" hierarchy, with the leftmost column representing the top level of the
+            hierarchy, and depth increasing to the right.
+
+            The index of the index_data must match the index of the observed data.
+        prior: str, optional
+            Name of the PyMC distribution to use for the curve parameters. Default is "Normal".
+        pooling: str, one of ["none", "complete", "partial"], default "complete"
+            Type of pooling to use for the curve parameters. If "none", no pooling is applied, and each group in
+            the index_data is treated as independent. If "complete", complete pooling is applied, and all data
+            are treated as coming from the same group. If "partial", a hierarchical prior is constructed that
+            shares information across groups in the index_data.
+        curve_type: str, one of ["log", "abc", "ns", "nss", "box-cox"]
+            Type of curve to build. For details, see the build_curve function.
+        prior_params: dict, optional
+            Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
+        hierarchical_params: dict, optional
+            Additional keyword arguments to configure priors in the hierarchical_prior_to_requested_depth
+            function. Options include:
+                sigma_dist: str
+                    Name of the distribution to use for the standard deviation of the hierarchy. Default is
+                    "Gamma"
+                sigma_kwargs: dict
+                    Additional keyword arguments to pass to the sigma distribution specified by the sigma_dist
+                    argument. Default is {"alpha": 2, "beta": 1}
+                offset_dist: str, one of ["zerosum", "normal", "laplace"]
+                    Name of the distribution to use for the offset distribution. Default is "zerosum"
+        """
+
+        _validate_pooling_params(index_data, pooling)
+
+        self.name = name
+        self.t = t if isinstance(t, pd.Series) else t.iloc[:, 0]
+        self.curve_type = curve_type
+
+        self.index_data = index_data
+        self.pooling = pooling
+
+        self.prior = prior
+        self.prior_params = prior_params if prior_params is not None else {}
+        self.hierarchical_params = hierarchical_params if hierarchical_params is not None else {}
+
+        super().__init__()
+
+    def build(self, model=None):
+        model = pm.modelcontext(model)
+        obs_dim = self.t.index.name
+        feature_dim = f"{self.name}_features"
+        if feature_dim not in model.coords:
+            model.add_coord(feature_dim, FEATURE_DICT[self.curve_type])
+
+        with model:
+            t_pt = pm.Data("t", self.t.values, dims=[obs_dim])
+            if self.pooling == "complete":
+                beta = getattr(pm, self.prior)(
+                    f"{self.name}_beta", **self.prior_params, dims=[feature_dim]
+                )
+                curve = build_curve(t_pt, beta, self.curve_type)
+                return pm.Deterministic(f"{self.name}", curve, dims=[obs_dim])
+
+            beta = hierarchical_prior_to_requested_depth(
+                self.name,
+                self.index_data,
+                model=model,
+                dims=[feature_dim],
+                no_pooling=self.pooling == "none",
+                **self.hierarchical_params,
+            )
+
+            curve = build_curve(t_pt, beta, self.curve_type)
+            return pm.Deterministic(f"{self.name}", curve, dims=[obs_dim])
+
+
+class Regression(GLMModel):
+    def __init__(
+        self,
+        name: str,
+        X: pd.DataFrame,
+        prior: str = "Normal",
+        index_data: pd.Series = None,
+        pooling: POOLING_TYPES = "complete",
+        **prior_params,
+    ):
+        """
+        Class to represent a regression component in a GLM model.
+
+        A regression component is a linear combination of input data and a set of parameters. The input data
+        should be a DataFrame with the same index as the observed data. Parameters can be made hierarchical by
+        providing an index_data Series or DataFrame (which should have the same index as the observed data).
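+
+        A sketch with hierarchical coefficients (``df`` and the column names are illustrative):
+
+        .. code-block:: python
+
+            effect = Regression("controls", X=df[["x1", "x2"]], index_data=df["country"], pooling="partial")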
+
+        Parameters
+        ----------
+        name: str
+            Name of the regression term.
+        X: DataFrame
+            Exogenous data used to build the regression component. Each column of the DataFrame represents a
+            feature used in the regression. The index of the DataFrame should match the index of the observed
+            data.
+        index_data: Series or DataFrame, optional
+            Index data used to build hierarchical priors. If there are multiple columns, the columns are treated
+            as levels of a "telescoping" hierarchy, with the leftmost column representing the top level of the
+            hierarchy, and depth increasing to the right.
+
+            The index of the index_data must match the index of the observed data.
+        prior: str, optional
+            Name of the PyMC distribution to use for the regression coefficients. Default is "Normal".
+        pooling: str, one of ["none", "complete", "partial"], default "complete"
+            Type of pooling to use for the regression coefficients. If "none", no pooling is applied, and each
+            group in the index_data is treated as independent. If "complete", complete pooling is applied, and
+            all data are treated as coming from the same group. If "partial", a hierarchical prior is
+            constructed that shares information across groups in the index_data.
+        prior_params: dict, optional
+            Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
+        """
+        _validate_pooling_params(index_data, pooling)
+
+        self.name = name
+        self.X = X
+        self.index_data = index_data
+        self.pooling = pooling
+
+        self.prior = prior
+        self.prior_params = prior_params
+
+        super().__init__()
+
+    def build(self, model=None):
+        model = pm.modelcontext(model)
+        feature_dim = f"{self.name}_features"
+        obs_dim = self.X.index.name
+
+        if feature_dim not in model.coords:
+            model.add_coord(feature_dim, self.X.columns)
+
+        with model:
+            X_pt = pm.Data(f"{self.name}_data", self.X.values, dims=[obs_dim, feature_dim])
+            if self.pooling == "complete":
+                beta = getattr(pm, self.prior)(
+                    f"{self.name}", **self.prior_params, dims=[feature_dim]
+                )
+                return X_pt @ beta
+
+            beta = hierarchical_prior_to_requested_depth(
+                self.name,
+                self.index_data,
+                model=model,
+                dims=[feature_dim],
+                no_pooling=self.pooling == "none",
+            )
+
+            regression_effect = (X_pt * beta.T).sum(axis=-1)
+            return regression_effect
+
+
+class Spline(Regression):
+    def __init__(
+        self,
+        name: str,
+        n_knots: int = 10,
+        spline_data: pd.Series | pd.DataFrame | None = None,
+        prior: str = "Normal",
+        index_data: pd.Series | None = None,
+        pooling: POOLING_TYPES = "complete",
+        **prior_params,
+    ):
+        """
+        Class to represent a spline component in a GLM model.
+
+        A spline component is a linear combination of piecewise-polynomial basis functions. The basis functions
+        are constructed using the ``bs`` function from the patsy library. The spline_data should be a Series
+        with the same index as the observed data.
+
+        The weights of the spline components can be made hierarchical by providing an index_data Series or
+        DataFrame (which should have the same index as the observed data).
+
+        Parameters
+        ----------
+        name: str
+            Name of the spline term.
+        n_knots: int, default 10
+            Number of knots to use in the spline basis.
+        spline_data: Series or DataFrame
+            Exogenous data to be interpolated using basis splines. If a Series, it must have a name attribute.
+            If a DataFrame, it must have exactly one column. In either case, the index of the data should match
+            the index of the observed data.
+        index_data: Series or DataFrame, optional
+            Index data used to build hierarchical priors. If there are multiple columns, the columns are treated
+            as levels of a "telescoping" hierarchy, with the leftmost column representing the top level of the
+            hierarchy, and depth increasing to the right.
+
+            The index of the index_data must match the index of the observed data.
+        prior: str, optional
+            Name of the PyMC distribution to use for the spline weights. Default is "Normal".
+        pooling: str, one of ["none", "complete", "partial"], default "complete"
+            Type of pooling to use for the spline weights. If "none", no pooling is applied, and each group in
+            the index_data is treated as independent. If "complete", complete pooling is applied, and all data
+            are treated as coming from the same group. If "partial", a hierarchical prior is constructed that
+            shares information across groups in the index_data.
+        prior_params: dict, optional
+            Additional keyword arguments to pass to the PyMC distribution specified by the prior argument.
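+
+        Examples
+        --------
+        A sketch (``df`` and its "age" column are illustrative):
+
+        .. code-block:: python
+
+            age_effect = Spline("age_effect", n_knots=5, spline_data=df["age"])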
+        """
+        _validate_pooling_params(index_data, pooling)
+
+        spline_features = dmatrix(
+            f"bs(maturity_years, df={n_knots}, degree=3) - 1",
+            {"maturity_years": spline_data},
+        )
+        X = pd.DataFrame(
+            spline_features,
+            index=spline_data.index,
+            columns=[f"Spline_{i}" for i in range(n_knots)],
+        )
+
+        super().__init__(
+            name=name, X=X, prior=prior, index_data=index_data, pooling=pooling, **prior_params
+        )
diff --git a/pymc_experimental/model/modular/likelihood.py b/pymc_experimental/model/modular/likelihood.py
new file mode 100644
index 00000000..32b4f432
--- /dev/null
+++ b/pymc_experimental/model/modular/likelihood.py
@@ -0,0 +1,359 @@
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from typing import Literal, get_args
+
+import arviz as az
+import pandas as pd
+import pymc as pm
+import pytensor.tensor as pt
+
+from pymc.backends.arviz import apply_function_over_dataset
+from pymc.model.fgraph import clone_model
+from pymc.pytensorf import reseed_rngs
+from pytensor.tensor.random.type import RandomType
+
+from pymc_experimental.model.marginal.marginal_model import MarginalModel
+from pymc_experimental.model.modular.utilities import ColumnType
+
+LIKELIHOOD_TYPES = Literal["lognormal", "logt", "mixture", "unmarginalized-mixture"]
+valid_likelihoods = get_args(LIKELIHOOD_TYPES)
+
+
+class Likelihood(ABC):
+    """Class to represent a likelihood function for a GLM component.
+
+    Subclasses should implement the _get_model_class method to return the type of model used by the likelihood
+    function, and should implement a ``register_xx`` method for each parameter unique to that likelihood
+    function.
+    """
+
+    def __init__(self, target_col: ColumnType, data: pd.DataFrame):
+        """
+        Initialization logic common to all likelihoods.
+
+        All subclasses should call super().__init__(target_col, data) to register data and create a model
+        object. The subclass __init__ method should then create a PyMC model inside the self.model context.
+
+        Parameters
+        ----------
+        target_col: str
+            Name of the column in ``data`` that holds the observed data.
+        data: DataFrame
+            Observed data, including the target column. Its index must have a name attribute.
+        """
+
+        if not isinstance(target_col, str):
+            [target_col] = target_col
+        self.target_col = target_col
+
+        # TODO: Reconsider this (two sources of nearly the same info is not good)
+        X_df = data.drop(columns=[target_col])
+        X_data = X_df.copy()
+        self.column_labels = {}
+        for col, dtype in X_data.dtypes.to_dict().items():
+            if dtype.name.startswith("float"):
+                pass
+            elif dtype.name == "object":
+                # TODO: We definitely need to save these if we want to factorize predict data
+                col_array, labels = pd.factorize(X_data[col], sort=True)
+                X_data[col] = col_array.astype("float64")
+                self.column_labels[col] = {label: i for i, label in enumerate(labels.values)}
+            elif dtype.name.startswith("int"):
+                X_data[col] = X_data[col].astype("float64")
+            else:
+                raise NotImplementedError(
+                    f"Haven't decided how to handle the following type: {dtype.name}"
+                )
+
+        self.obs_dim = data.index.name
+        coords = {
+            self.obs_dim: data.index.values,
+            "feature": list(X_data.columns),
+        }
+        with self._get_model_class(coords) as self.model:
+            self.model.X_df = X_df  # FIXME: Definitely not a solution
+            pm.Data(f"{target_col}_observed", data[target_col], dims=self.obs_dim)
+            pm.Data(
+                "X_data",
+                X_data,
+                dims=(self.obs_dim, "feature"),
+                shape=(None, len(coords["feature"])),
+            )
+
+        self._predict_fn = None  # Cache the compiled function for faster predictions
+
+    def sample(self, **sample_kwargs):
+        with self.model:
+            return pm.sample(**sample_kwargs)
+
+    def predict(
+        self,
+        idata: az.InferenceData,
+        predict_df: pd.DataFrame,
+        random_seed=None,
+        compile_kwargs=None,
+    ):
+        # Make sure only features present during fitting are used, in the same order, during prediction
+        X_data = predict_df[list(self.model.coords["feature"])].copy()
+        for col, dtype in X_data.dtypes.to_dict().items():
+            if dtype.name.startswith("float"):
+                pass
+            elif dtype.name == "object":
+                X_data[col] = X_data[col].map(self.column_labels[col]).astype("float64")
+            elif dtype.name.startswith("int"):
+                X_data[col] = X_data[col].astype("float64")
+            else:
+                raise NotImplementedError(
+                    f"Haven't decided how to handle the following type: {dtype.name}"
+                )
+
+        coords = {self.obs_dim: X_data.index.values}
+
+        predict_fn = self._predict_fn
+
+        if predict_fn is None:
+            model_copy = clone_model(self.model)
+            # TODO: Freeze everything that is not supposed to change, when PyMC allows it
+            # dims = [dim for dim in self.model.coords.keys() if dim != self.obs_dim]
+            # model_copy = freeze_dims_and_data(model_copy, dims=dims, data=[])
+            observed_RVs = model_copy.observed_RVs
+            if compile_kwargs is None:
+                compile_kwargs = {}
+            predict_fn = model_copy.compile_fn(
+                observed_RVs,
+                inputs=model_copy.free_RVs,
+                mode=compile_kwargs.pop("mode", None),
+                on_unused_input="ignore",
+                **compile_kwargs,
+            )
+            predict_fn.trust_input = True
+            self._predict_fn = predict_fn
+
+        [X_var] = [shared for shared in predict_fn.f.get_shared() if shared.name == "X_data"]
+        if random_seed is not None:
+            rngs = [
+                shared
+                for shared in predict_fn.f.get_shared()
+                if isinstance(shared.type, RandomType)
+            ]
+            reseed_rngs(rngs, random_seed)
+        X_var.set_value(X_data.values, borrow=True)
+
+        return apply_function_over_dataset(
+            fn=predict_fn,
+            dataset=idata.posterior[[rv.name for rv in self.model.free_RVs]],
+            output_var_names=[rv.name for rv in self.model.observed_RVs],
+            dims={rv.name: [self.obs_dim] for rv in self.model.observed_RVs},
+            coords=coords,
+            progressbar=False,
+        )
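+
+    # A minimal usage sketch (all names below are illustrative): construct a concrete
+    # subclass, sample, then predict on new data with the same schema.
+    #
+    #     like = LogNormalLikelihood(mu=mu, sigma=None, target_col="y", data=df)
+    #     idata = like.sample()
+    #     preds = like.predict(idata, predict_df=new_df, random_seed=1)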
+
+    @abstractmethod
+    def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
+        """Return the type of model used by the likelihood function."""
+        raise NotImplementedError
+
+    def register_mu(self, mu=None):
+        with self.model:
+            if mu is not None:
+                return pm.Deterministic("mu", mu.build(self.model), dims=[self.obs_dim])
+            return pm.Normal("mu", 0, 100)
+
+    def register_sigma(self, sigma=None):
+        with self.model:
+            if sigma is not None:
+                return pm.Deterministic(
+                    "sigma", pt.exp(sigma.build(self.model)), dims=[self.obs_dim]
+                )
+            return pm.Exponential("sigma", lam=1)
+
+
+class LogNormalLikelihood(Likelihood):
+    """Class to represent a log-normal likelihood function for a GLM component."""
+
+    def __init__(
+        self,
+        mu,
+        sigma,
+        target_col: ColumnType,
+        data: pd.DataFrame,
+    ):
+        super().__init__(target_col=target_col, data=data)
+
+        with self.model:
+            mu = self.register_mu(mu=mu)
+            sigma = self.register_sigma(sigma=sigma)
+
+            pm.LogNormal(
+                target_col,
+                mu=mu,
+                sigma=sigma,
+                observed=self.model[f"{target_col}_observed"],
+                dims=[self.obs_dim],
+            )
+
+    def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
+        return pm.Model(coords=coords)
+
+
+class LogTLikelihood(Likelihood):
+    """
+    Class to represent a log-t likelihood function for a GLM component.
+    """
+
+    def __init__(
+        self,
+        mu,
+        *,
+        sigma=None,
+        nu=None,
+        target_col: ColumnType,
+        data: pd.DataFrame,
+    ):
+        def log_student_t(nu, mu, sigma, shape=None):
+            return pm.math.exp(pm.StudentT.dist(mu=mu, sigma=sigma, nu=nu, shape=shape))
+
+        super().__init__(target_col=target_col, data=data)
+
+        with self.model:
+            mu = self.register_mu(mu=mu)
+            sigma = self.register_sigma(sigma=sigma)
+            nu = self.register_nu(nu=nu)
+
+            pm.CustomDist(
+                target_col,
+                nu,
+                mu,
+                sigma,
+                observed=self.model[f"{target_col}_observed"],
+                shape=mu.shape,
+                dims=[self.obs_dim],
+                dist=log_student_t,
+                class_name="LogStudentT",
+            )
+
+    def register_nu(self, nu=None):
+        with self.model:
+            if nu is not None:
+                return pm.Deterministic("nu", pt.exp(nu.build(self.model)), dims=[self.obs_dim])
+            return pm.Uniform("nu", 2, 30)
+
+    def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
+        return pm.Model(coords=coords)
+
+
+class BaseMixtureLikelihood(Likelihood):
+    """
+    Base class for mixture likelihood functions; holds common methods for registering parameters.
+    """
+
+    def register_sigma(self, sigma=None):
+        with self.model:
+            if sigma is None:
+                sigma_not_outlier = pm.Exponential("sigma_not_outlier", lam=1)
+            else:
+                sigma_not_outlier = pm.Deterministic(
+                    "sigma_not_outlier", pt.exp(sigma.build(self.model)), dims=[self.obs_dim]
+                )
+            sigma_outlier_offset = pm.Gamma("sigma_outlier_offset", mu=0.2, sigma=0.5)
+            sigma = pm.Deterministic(
+                "sigma",
+                pt.as_tensor([sigma_not_outlier, sigma_not_outlier * (1 + sigma_outlier_offset)]),
+                dims=["outlier"],
+            )
+
+            return sigma
+
+    def register_p_outlier(self, p_outlier=None, **param_kwargs):
+        mean_p = param_kwargs.get("mean_p", 0.1)
+        concentration = param_kwargs.get("concentration", 50)
+
+        with self.model:
+            if p_outlier is not None:
+                return pm.Deterministic(
+                    "p_outlier", pt.sigmoid(p_outlier.build(self.model)), dims=[self.obs_dim]
+                )
+            return pm.Beta("p_outlier", mean_p * concentration, (1 - mean_p) * concentration)
+
+    def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
+        coords["outlier"] = [False, True]
+        return MarginalModel(coords=coords)
+
+
+class MixtureLikelihood(BaseMixtureLikelihood):
+    """
+    Class to represent a two-component outlier mixture likelihood for a GLM component. The mixture is implemented
+    using pm.Mixture, which marginalizes the component indicator analytically rather than sampling it.
+    """
+
+    def __init__(
+        self,
+        mu,
+        sigma,
+        p_outlier,
+        target_col: ColumnType,
+        data: pd.DataFrame,
+    ):
+        super().__init__(target_col=target_col, data=data)
+
+        with self.model:
+            mu = self.register_mu(mu=mu)
+            sigma = self.register_sigma(sigma=sigma)
+            p_outlier = self.register_p_outlier(p_outlier=p_outlier)
+
+            pm.Mixture(
+                target_col,
+                w=[1 - p_outlier, p_outlier],
+                comp_dists=pm.LogNormal.dist(mu[..., None], sigma=sigma.T),
+                shape=mu.shape,
+                observed=self.model[f"{target_col}_observed"],
+                dims=[self.obs_dim],
+            )
+
+
+class UnmarginalizedMixtureLikelihood(BaseMixtureLikelihood):
+    """
+    Class to represent a two-component outlier mixture likelihood built from an explicit latent indicator
+    variable. The mixture is implemented using a MarginalModel, which marginalizes the indicator automatically.
+    """
+
+    def __init__(
+        self,
+        mu,
+        sigma,
+        p_outlier,
+        target_col: ColumnType,
+        data: pd.DataFrame,
+    ):
+        super().__init__(target_col=target_col, data=data)
+
+        with self.model:
+            mu = self.register_mu(mu=mu)
+            sigma = self.register_sigma(sigma=sigma)
+            p_outlier = self.register_p_outlier(p_outlier=p_outlier)
+
+            is_outlier = pm.Bernoulli(
+                "is_outlier",
+                p_outlier,
+                dims=[self.obs_dim],
+                # shape=X_pt.shape[0],  # Uncomment after https://github.com/pymc-devs/pymc-experimental/pull/304
+            )
+
+            pm.LogNormal(
+                target_col,
+                mu=mu,
+                sigma=pm.math.switch(is_outlier, sigma[1], sigma[0]),
+                observed=self.model[f"{target_col}_observed"],
+                shape=mu.shape,
+                dims=[self.obs_dim],
+            )
+
+        self.model.marginalize(["is_outlier"])
+
+    def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
+        coords["outlier"] = [False, True]
+        return MarginalModel(coords=coords)
diff --git a/pymc_experimental/model/modular/utilities.py b/pymc_experimental/model/modular/utilities.py
new file mode 100644
index 00000000..8e9c27a5
--- /dev/null
+++ b/pymc_experimental/model/modular/utilities.py
@@ -0,0 +1,276 @@
+import itertools
+
+from collections.abc import Sequence
+
+import pandas as pd
+import pymc as pm
+import pytensor.tensor as pt
+
+ColumnType = str | Sequence[str] | None
+
+# Dictionary to define offset distributions for hierarchical models
+OFFSET_DIST_FACTORY = {
+    "zerosum": lambda name, offset_dims: pm.ZeroSumNormal(f"{name}_offset", dims=offset_dims),
+    "normal": lambda name, offset_dims: pm.Normal(f"{name}_offset", dims=offset_dims),
+    "laplace": lambda name, offset_dims: pm.Laplace(f"{name}_offset", mu=0, b=1, dims=offset_dims),
+}
+
+# Default kwargs for sigma distributions
+SIGMA_DEFAULT_KWARGS = {
+    "Gamma": {"alpha": 2, "beta": 1},
+    "Exponential": {"lam": 1},
+    "HalfNormal": {"sigma": 1},
+    "HalfCauchy": {"beta": 1},
+}
+
+
+def _get_x_cols(
+    cols: str | Sequence[str],
+    model: pm.Model | None = None,
+) -> pt.TensorVariable:
+    model = pm.modelcontext(model)
+    # Don't upcast a single column to a column matrix
+    if isinstance(cols, str):
+        [cols_idx] = [i for i, col in enumerate(model.coords["feature"]) if col == cols]
+    else:
+        cols_idx = [i for i, col in enumerate(model.coords["feature"]) if col in cols]
+    return model["X_data"][:, cols_idx]
+
+
+def make_level_maps(df: pd.DataFrame, ordered_levels: list[str]):
+    r"""
+    For each row of data, create a mapping between levels of an arbitrary set of levels defined by
+    ``ordered_levels``.
+
+    Consider a set of levels (A, B, C) with members A: [A], B: [B1, B2], C: [C1, C2, C3, C4], arranged in a tree
+    like::
+
+                 A
+               /   \
+             B1     B2
+            /  \   /  \
+           C1  C2 C3  C4
+
+    A "deep hierarchy" will have the following priors::
+
+        A ~ F(...)
+        B1, B2 ~ F(A, ...)
+        C1, C2 ~ F(B1, ...)
+        C3, C4 ~ F(B2, ...)
+
+    Noting that there could be multiple such trees in a dataset, to create these priors in a memory-efficient way
+    we need two mappings: B to A, and C to B. These need to be generated at inference time, and also re-generated
+    for out-of-sample prediction.
+
+    Parameters
+    ----------
+    df: pd.DataFrame
+        DataFrame containing one column for each level of the hierarchy named in ``ordered_levels``.
+
+    ordered_levels: list[str]
+        Sequence of level names, ordered from highest to lowest. In the above example,
+        ordered_levels = ['A', 'B', 'C']
+
+    Returns
+    -------
+    labels: list[pd.Index]
+        Unique labels generated for each level, sorted alphabetically. Ordering corresponds to the integers in
+        the corresponding mapping, à la pd.factorize.
+
+    mappings: list[np.ndarray]
+        ``len(ordered_levels) - 1`` arrays mapping each level's labels to the labels of the level above it,
+        followed by a final array mapping each row of ``df`` to the labels of the deepest level.
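+
+    Examples
+    --------
+    A small sketch of the returned structures:
+
+    .. code-block:: python
+
+        df = pd.DataFrame({"B": ["B1", "B1", "B2", "B2"], "C": ["C1", "C2", "C3", "C4"]})
+        labels, mappings = make_level_maps(df, ["B", "C"])
+        # labels:   [Index(['B1', 'B2']), Index(['C1', 'C2', 'C3', 'C4'])]
+        # mappings: [array([0, 0, 1, 1]), array([0, 1, 2, 3])]
+        #            ^ C -> parent B map  ^ row -> C label map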
+    """
+    # TODO: Raise an error if there are one-to-many mappings between levels?
+    if not all([level in df for level in ordered_levels]):
+        missing = set(ordered_levels) - set(df.columns)
+        raise ValueError(f'Requested levels were not in provided dataframe: {", ".join(missing)}')
+
+    level_pairs = itertools.pairwise(ordered_levels)
+    mappings = []
+    labels = []
+    for pair in level_pairs:
+        _, level_labels = pd.factorize(df[pair[0]], sort=True)
+        edges = df[list(pair)].drop_duplicates().set_index(pair[1])[pair[0]].sort_index()
+        idx = edges.map({k: i for i, k in enumerate(level_labels)}).values
+        labels.append(level_labels)
+        mappings.append(idx)
+
+    last_map, last_labels = pd.factorize(df[ordered_levels[-1]], sort=True)
+    labels.append(last_labels)
+    mappings.append(last_map)
+
+    return labels, mappings
+
+
+def make_next_level_hierarchy_variable(
+    name: str,
+    mu,
+    sigma_dist: str = "Gamma",
+    sigma_kwargs: dict | None = None,
+    mapping=None,
+    sigma_dims=None,
+    offset_dims=None,
+    offset_dist="Normal",
+    no_pooling=False,
+):
+    if no_pooling:
+        if mapping is None:
+            return pm.Deterministic(f"{name}", mu[..., None], dims=offset_dims)
+        else:
+            return pm.Deterministic(f"{name}", mu[..., mapping], dims=offset_dims)
+
+    d_sigma = getattr(pm, sigma_dist)
+
+    if sigma_kwargs is None:
+        if sigma_dist not in SIGMA_DEFAULT_KWARGS:
+            raise NotImplementedError(
+                f"No defaults implemented for {sigma_dist}. Pass sigma_kwargs explicitly."
+            )
+        sigma_kwargs = SIGMA_DEFAULT_KWARGS[sigma_dist]
+
+    sigma_ = d_sigma(f"{name}_sigma", **sigma_kwargs, dims=sigma_dims)
+
+    offset_dist = offset_dist.lower()
+    if offset_dist not in OFFSET_DIST_FACTORY:
+        raise NotImplementedError(f"Offset distribution {offset_dist} is not supported.")
+
+    offset = OFFSET_DIST_FACTORY[offset_dist](name, offset_dims)
+
+    if mapping is None:
+        return pm.Deterministic(
+            f"{name}", mu[..., None] + sigma_[..., None] * offset, dims=offset_dims
+        )
+    else:
+        return pm.Deterministic(
+            f"{name}", mu[..., mapping] + sigma_[..., mapping] * offset, dims=offset_dims
+        )
+
+
+def hierarchical_prior_to_requested_depth(
+    name: str,
+    df: pd.DataFrame,
+    model: pm.Model = None,
+    dims: list[str] | None = None,
+    no_pooling: bool = False,
+    **hierarchy_kwargs,
+):
+    r"""
+    Given a dataframe of categorical data, construct a hierarchical prior that pools data telescopically, moving
+    from left to right across the columns of the dataframe.
+
+    At its simplest, this function can be used to construct a simple hierarchical prior for a single categorical
+    variable. Consider the following example:
+
+    .. code-block:: python
+
+        df = pd.DataFrame(['Apple', 'Apple', 'Banana', 'Banana'], columns=['fruit'])
+        coords = {'fruit': ['Apple', 'Banana']}
+        with pm.Model(coords=coords) as model:
+            fruit_effect = hierarchical_prior_to_requested_depth('fruit_effect', df)
+
+    This will construct a simple, non-centered hierarchical intercept corresponding to the 'fruit' feature of the
+    data. The power of the function comes from its ability to handle multiple categorical variables, and to
+    construct a hierarchical prior that pools data across multiple levels of a hierarchy. Consider the following
+    example:
+    .. code-block:: python
+
+        df = pd.DataFrame({'fruit': ['Apple', 'Apple', 'Banana', 'Banana'],
+                           'color': ['Red', 'Green', 'Yellow', 'Brown']})
+        coords = {'fruit': ['Apple', 'Banana'], 'color': ['Red', 'Green', 'Yellow', 'Brown']}
+        with pm.Model(coords=coords) as model:
+            fruit_effect = hierarchical_prior_to_requested_depth('fruit_effect', df[['fruit', 'color']])
+
+    This will construct a two-level hierarchy. The first level will pool all rows of data with the same 'fruit'
+    value, and the second level will pool color values within each fruit. The structure of the hierarchy will
+    be::
+
+          Apple        Banana
+         /     \      /      \
+       Red   Green  Yellow  Brown
+
+    That is, estimates for each of "Red" and "Green" will be centered on the estimate for "Apple", and estimates
+    for "Yellow" and "Brown" will be centered on the estimate for "Banana".
+
+    .. warning::
+        Currently, each child level in the data **must** map to exactly one parent level. In the above example,
+        we could not consider green bananas, because the "Green" level would not uniquely map to "Apple". This is
+        a limitation of the current implementation.
+
+
+    Parameters
+    ----------
+    name: str
+        Name of the variable to construct
+    df: DataFrame
+        DataFrame of categorical data. Each column represents a level of the hierarchy, with the leftmost column
+        representing the top level of the hierarchy, with depth increasing to the right.
+    model: pm.Model, optional
+        PyMC model to add the variable to. If None, the model on the current context stack is used.
+    dims: list of str, optional
+        Additional dimensions to add to the variable. These are treated as batch dimensions, and are added to the
+        left of the hierarchy dimensions. For example, if dims=['feature'], and df has one column named
+        "country", the returned variable will have dimensions ['feature', 'country']
+    no_pooling: bool, optional
+        If True, no pooling is applied to the variable. Each level of the hierarchy is treated as independent,
+        with no information shared across members of a given level.
+    hierarchy_kwargs: dict
+        Additional keyword arguments to pass to the underlying PyMC distribution. Options include:
+            sigma_dist: str
+                Name of the distribution to use for the standard deviation of the hierarchy. Default is "Gamma"
+            sigma_kwargs: dict
+                Additional keyword arguments to pass to the sigma distribution specified by the sigma_dist
+                argument. Default is {"alpha": 2, "beta": 1}
+            offset_dist: str, one of ["zerosum", "normal", "laplace"]
+                Name of the distribution to use for the offset distribution. Default is "zerosum"
+
+    Returns
+    -------
+    pm.Distribution
+        PyMC variable representing the hierarchical prior.
+        The variable is indexed down to the deepest level of the hierarchy, so it has one entry per row of
+        ``df``, with shape ``(*dims, n_obs)``.
+    """
+
+    if isinstance(df, pd.Series):
+        df = df.to_frame()
+
+    model = pm.modelcontext(model)
+    sigma_dist = hierarchy_kwargs.pop("sigma_dist", "Gamma")
+    sigma_kwargs = hierarchy_kwargs.pop("sigma_kwargs", {"alpha": 2, "beta": 1})
+    offset_dist = hierarchy_kwargs.pop("offset_dist", "zerosum")
+
+    levels = [None, *df.columns.tolist()]
+    n_levels = len(levels) - 1
+    idx_maps = None
+    if n_levels > 1:
+        labels, idx_maps = make_level_maps(df, levels[1:])
+
+    if idx_maps:
+        idx_maps = [None, *idx_maps]
+    else:
+        idx_maps = [None]
+
+    for level_dim in levels[1:]:
+        _, labels = pd.factorize(df[level_dim], sort=True)
+        if level_dim not in model.coords:
+            model.add_coord(level_dim, labels)
+
+    # Danger zone: this assumes the factorization here matches the one used to build X_data
+    deepest_map = _get_x_cols(df.columns[-1]).astype("int")
+
+    with model:
+        beta = pm.Normal(f"{name}_effect", 0, 1, dims=dims)
+        for i, (level, last_level) in enumerate(zip(levels[1:], levels[:-1])):
+            if i == 0:
+                sigma_dims = dims
+            else:
+                sigma_dims = [*dims, last_level] if dims is not None else [last_level]
+            offset_dims = [*dims, level] if dims is not None else [level]

+            # TODO: Need a better way to handle different priors at each level.
+            if "beta" in sigma_kwargs:
+                sigma_kwargs["beta"] = sigma_kwargs["beta"] ** (i + 1)
+
+            beta = make_next_level_hierarchy_variable(
+                f"{name}_{level}_effect",
+                mu=beta,
+                sigma_dist=sigma_dist,
+                sigma_kwargs=sigma_kwargs,
+                mapping=idx_maps[i],
+                sigma_dims=sigma_dims,
+                offset_dims=offset_dims,
+                offset_dist=offset_dist,
+                no_pooling=no_pooling,
+            )
+
+    return beta[..., deepest_map]
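+
+
+# A minimal end-to-end sketch of the modular API (all names below are illustrative:
+# ``df`` is a DataFrame with a float target "y", a categorical "country" column, and
+# an index named "obs"):
+#
+#     from pymc_experimental.model.modular.components import Intercept
+#     from pymc_experimental.model.modular.likelihood import LogNormalLikelihood
+#
+#     mu = Intercept(pooling_cols="country", pooling="partial")
+#     likelihood = LogNormalLikelihood(mu=mu, sigma=None, target_col="y", data=df)
+#     idata = likelihood.sample()
+#     predictions = likelihood.predict(idata, predict_df=df.drop(columns="y"))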