-
-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PC prior distribution for Student T dof #252
base: main
Are you sure you want to change the base?
Changes from all commits
9dd5573
8b6e391
f8dc8d0
310f7e1
a2026f5
62fda67
14ea01e
60077e8
67549a2
8f61efe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -19,18 +19,28 @@ | |||||
The imports from pymc are not fully replicated here: add imports as necessary. | ||||||
""" | ||||||
|
||||||
from typing import List, Tuple, Union | ||||||
from typing import List, Optional, Tuple, Union | ||||||
|
||||||
import numpy as np | ||||||
import pytensor.tensor as pt | ||||||
from pymc.distributions.dist_math import check_parameters | ||||||
from pymc.distributions.continuous import ( | ||||||
DIST_PARAMETER_TYPES, | ||||||
PositiveContinuous, | ||||||
check_parameters, | ||||||
) | ||||||
from pymc.distributions.distribution import Continuous | ||||||
from pymc.distributions.shape_utils import rv_size_is_none | ||||||
from pymc.pytensorf import floatX | ||||||
from pytensor.tensor import TensorVariable | ||||||
from pytensor.tensor.random.op import RandomVariable | ||||||
from pytensor.tensor.variable import TensorVariable | ||||||
from scipy import stats | ||||||
|
||||||
from pymc_experimental.distributions.dist_math import ( | ||||||
pc_prior_studentt_kld_dist_inv_op, | ||||||
pc_prior_studentt_logp, | ||||||
studentt_kld_distance, | ||||||
) | ||||||
|
||||||
|
||||||
class GenExtremeRV(RandomVariable): | ||||||
name: str = "Generalized Extreme Value" | ||||||
|
@@ -216,3 +226,62 @@ def moment(rv, size, mu, sigma, xi): | |||||
if not rv_size_is_none(size): | ||||||
mode = pt.full(size, mode) | ||||||
return mode | ||||||
|
||||||
|
||||||
class PCPriorStudentT_dof_RV(RandomVariable):
    """Random-variable op for the PC prior on the Student-T degrees of freedom.

    A draw is made by sampling from an exponential distribution with rate
    ``lam`` and passing the result through the spline held by
    ``pc_prior_studentt_kld_dist_inv_op``.
    """

    name = "pc_prior_studentt_dof"
    ndim_supp = 0
    ndims_params = [0]
    dtype = "floatX"
    _print_name = ("PCTDoF", "\\operatorname{PCPriorStudentT_dof}")

    @classmethod
    def rng_fn(cls, rng, lam, size=None) -> np.ndarray:
        """Draw degrees-of-freedom samples for rate ``lam``.

        ``rng`` must provide an ``exponential(scale=..., size=...)`` method;
        ``size`` is forwarded to it unchanged.
        """
        # The exponential is parameterized by its scale, i.e. 1 / rate.
        exponential_draws = rng.exponential(scale=1.0 / lam, size=size)
        # Call the raw spline (not the wrapping op) so the result is numeric.
        return pc_prior_studentt_kld_dist_inv_op.spline(exponential_draws)


pc_prior_studentt_dof = PCPriorStudentT_dof_RV()
|
||||||
|
||||||
class PCPriorStudentT_dof(PositiveContinuous):
    """PC prior for the degrees of freedom of a Student-T likelihood.

    The distribution has a single rate parameter ``lam``. It can be passed
    directly, or derived from the pair ``alpha`` and ``U`` (see ``get_lam``).
    The log density comes from ``pc_prior_studentt_logp`` and is ``-inf`` for
    values below ``2 + 1e-6``.
    """

    rv_op = pc_prior_studentt_dof

    @classmethod
    def dist(
        cls,
        alpha: Optional[DIST_PARAMETER_TYPES] = None,
        U: Optional[DIST_PARAMETER_TYPES] = None,
        lam: Optional[DIST_PARAMETER_TYPES] = None,
        *args,
        **kwargs
    ):
        """Resolve the rate via ``get_lam`` and build the distribution from it."""
        rate = cls.get_lam(alpha, U, lam)
        return super().dist([rate], *args, **kwargs)

    def moment(rv, size, lam):
        """Return the inverse-distance op evaluated at ``1 / lam``.

        The value is broadcast to ``size`` with ``pt.full`` unless
        ``rv_size_is_none(size)`` holds.
        """
        point = pc_prior_studentt_kld_dist_inv_op(1.0 / lam)
        if rv_size_is_none(size):
            return point
        return pt.full(size, point)

    @classmethod
    def get_lam(cls, alpha=None, U=None, lam=None):
        """Return the rate parameter for the chosen parameterization.

        When both ``alpha`` and ``U`` are given the rate is
        ``-log(alpha) / studentt_kld_distance(U)``; any ``lam`` passed
        alongside them is ignored. Otherwise ``lam`` is returned as-is.

        NOTE(review): presumably ``alpha`` is a tail probability and ``U`` a
        degrees-of-freedom threshold -- confirm against the PC prior paper.

        Raises
        ------
        ValueError
            If neither the ``alpha``/``U`` pair nor ``lam`` is provided.
        """
        if alpha is not None and U is not None:
            return -np.log(alpha) / studentt_kld_distance(U)
        if lam is not None:
            return lam
        raise ValueError(
            "Incompatible parameterization. Either use alpha and U, or lam to specify the "
            "distribution."
        )

    def logp(value, lam):
        """Log density at ``value``, with ``-inf`` below the lower bound on nu."""
        nu_lower_bound = 2 + 1e-6  # 2 + 1e-6 smallest value for nu
        unbounded = pc_prior_studentt_logp(value, lam)
        below_bound = pt.lt(value, nu_lower_bound)
        bounded = pt.switch(below_bound, -np.inf, unbounded)
        return check_parameters(bounded, lam > 0, msg="lam > 0")
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright 2023 The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# coding: utf-8 | ||
|
||
import numpy as np | ||
import pytensor.tensor as pt | ||
from pymc.distributions.dist_math import SplineWrapper | ||
from scipy.interpolate import UnivariateSpline | ||
|
||
|
||
def studentt_kld_distance(nu):
    """KL-divergence-based distance between a Student-T and a normal variable.

    ``nu`` is the Student-T degrees of freedom. The expression was derived by
    Tang in https://arxiv.org/abs/1811.08042 and was originally described here
    as ``2 * sqrt(KL divergence)``.

    NOTE(review): confirm that scaling against the paper; PC priors are often
    stated in terms of ``sqrt(2 * KLD)``.
    """
    half_nu = nu / 2
    half_nu_plus_one = (nu + 1) / 2

    log_term = pt.log(2 * pt.reciprocal(nu - 2))
    gammaln_upper = 2 * pt.gammaln(half_nu_plus_one)
    gammaln_lower = 2 * pt.gammaln(half_nu)
    digamma_gap = pt.digamma(half_nu_plus_one) - pt.digamma(half_nu)

    # Terms are summed in the same order as the closed-form expression.
    squared_distance = 1 + log_term + gammaln_upper - gammaln_lower - (nu + 1) * digamma_gap
    return pt.sqrt(squared_distance)
|
||
|
||
def tri_gamma_approx(x):
    """Approximate the trigamma function with a truncated series expansion.

    Trigamma is the derivative of the digamma function, i.e. the second
    derivative of the log-gamma function. The series is taken from Wikipedia:
    https://en.wikipedia.org/wiki/Trigamma_function. It only uses division and
    integer powers of ``x``.

    NOTE(review): this was written because the trigamma function in pytensor
    had no gradient at the time. Reviewers pointed out that trigamma is now
    implemented there; once it has a gradient this function can be removed
    and replaced.
    """
    # (numerator, denominator, power of x) for each term, in summation order.
    series_terms = (
        (1, 1, 1),
        (1, 2, 2),
        (1, 6, 3),
        (-1, 30, 5),
        (1, 42, 7),
        (-1, 30, 9),
        (5, 66, 11),
        (-691, 2730, 13),
        (7, 6, 15),
    )

    total = None
    for numerator, denominator, power in series_terms:
        term = numerator / (denominator * x**power)
        # Seed with the first term so no extra "0 +" enters the sum.
        total = term if total is None else total + term
    return total
|
||
|
||
def pc_prior_studentt_logp(nu, lam):
    """Log density of the PC prior for the Student-T degrees of freedom.

    ``nu`` is the degrees-of-freedom value and ``lam`` the rate of the prior.
    Derived by Tang in https://arxiv.org/abs/1811.08042. The trigamma terms
    use ``tri_gamma_approx`` and the distance is ``studentt_kld_distance(nu)``.
    """
    distance = studentt_kld_distance(nu)
    half_nu_plus_one = (nu + 1) / 2
    trigamma_gap = tri_gamma_approx(half_nu_plus_one) - tri_gamma_approx(nu / 2)
    log_bracket = pt.log((1 / (nu - 2)) + half_nu_plus_one * trigamma_gap)

    # Same terms, in the same order, as the original single expression.
    return (
        pt.log(lam)
        + log_bracket
        - pt.log(4 * distance)
        - lam * distance
        + pt.log(2)
    )
|
||
|
||
def _make_pct_inv_func():
    """Build a spline approximating the inverse of `studentt_kld_distance`.

    The spline is fitted on degrees-of-freedom values from 2 + 1e-6 to 4000,
    with a denser grid between 2 + 1e-6 and 2.4. 2 is the smallest valid value
    for the Student-T degrees of freedom, and values above 4000 don't seem to
    change much (nearly Gaussian past 30). The result is a
    `scipy.interpolate.UnivariateSpline`; the module wraps it in
    `SplineWrapper` so it can be used as a PyTensor op.
    """
    NU_MIN = 2.0 + 1e-6
    fine_grid = np.linspace(NU_MIN, 2.4, 2000)
    coarse_grid = np.linspace(2.4 + 1e-4, 4000, 10000)
    nu_grid = np.concatenate((fine_grid, coarse_grid))

    # NOTE(review): `.eval()` was flagged as risky in review. Here the input
    # is a constant NumPy grid, not a random variable, so it is deterministic.
    distance_grid = studentt_kld_distance(nu_grid).eval()

    # Both arrays are reversed together; presumably the distance decreases as
    # nu grows and the spline needs increasing x values -- confirm.
    return UnivariateSpline(
        distance_grid[::-1],
        nu_grid[::-1],
        ext=3,
        k=3,
        s=0,
    )


pc_prior_studentt_kld_dist_inv_op = SplineWrapper(_make_pct_inv_func())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
needs a docstring
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usually we don't document the RV, but the Distribution class, which doesn't have a docstring either