Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Truncated gamma #1187

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open

Truncated gamma #1187

wants to merge 8 commits into from

Conversation

quattro
Copy link
Contributor

@quattro quattro commented Oct 12, 2021

PR for issue #969 . Contains initial implementation that performs uniform sampling + inverse CDF of Left/Right/Doubly truncated Gamma. Relies on tensorflow functionality for igammainv function, which is not yet implemented at the lax/jax level (see jax-ml/jax#5350).

There is a test that fails, but it is not clear to me if this is purely a numerical issue with the uniform + iCDF sampling, or a larger issue that I missed at the time I implemented things.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @quattro! The implementation looks great. Could you expose those distributions to sphinx? For numerical issues, could you increase the thresholds a bit to make the tests pass. I guess gammaincinv does not have good precision (especially under float32).


def icdf(self, q):
# https://github.com/pyro-ppl/numpyro/issues/969
from numpyro.distributions.util import gammaincinv
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can move this import to the top.

@@ -411,3 +412,327 @@ def tree_flatten(self):
@classmethod
def tree_unflatten(cls, aux_data, params):
return cls(batch_shape=aux_data)


def TruncatedGamma(base_gamma, low=None, high=None, validate_args=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is better to expose the parameters of Gamma here (TruncatedGamma(concentration, rate, low=..., high=...), rather than using a nested pattern. There are a couple of benefits with that:

  • parameters of the distribution is defined probably in args_constraints
  • it is easier to test
  • no need to have flatten/unflatten logic

base_gamma = Gamma.tree_unflatten(base_aux, base_flatten)
return cls(base_gamma, low=low)

@validate_sample
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, currently validate_sample logic does not work with cdf :(

# until jax/lax has direct implementation we'll need to rely on tfp
# https://github.com/pyro-ppl/numpyro/issues/969
try:
import tensorflow_probability as tfpm
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can import tensorflow_probability.substrates.jax directly, to make sure that jax backend is installed.

return lprob - jnp.log(1.0 - lscale)

def _scale_moment(self, t):
assert t > -self.base_gamma.concentration
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work for jax arrays (which might have abstract values under jit compiling). You can use jnp.where to mask out the invalid cases like this.

def log_prob(self, value):
lprob = self.base_gamma.log_prob(value)
lscale = self.base_gamma.cdf(self.low)
return lprob - jnp.log(1.0 - lscale)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use log1p(-lscale) for a better numerical result

@fehiepsi
Copy link
Member

fehiepsi commented Nov 3, 2022

@quattro Looking the the PR is is the good shape - just have small comments above. Any chance we can have this in the next numpyro release?

@quattro
Copy link
Contributor Author

quattro commented Nov 3, 2022

Will try my best. Should have some time closer to Thanksgiving holidays, does that fall before next release schedule?

@fehiepsi
Copy link
Member

fehiepsi commented Nov 3, 2022

Absolutely, there is no plan for the release date yet. Thank you!

@disadone
Copy link

Will we have this feature in the future?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants