diff --git a/numpyro/_typing.py b/numpyro/_typing.py
index 9220048c7..6642bc1cf 100644
--- a/numpyro/_typing.py
+++ b/numpyro/_typing.py
@@ -4,29 +4,39 @@
 from collections import OrderedDict
 from collections.abc import Callable
-from typing import Any, Protocol, runtime_checkable
+from typing import Any, Optional, Protocol, Union, runtime_checkable

 from typing_extensions import ParamSpec, TypeAlias

 import jax
+from jax import Array
 from jax.typing import ArrayLike

+from numpyro.distributions import MaskedDistribution
+
 P = ParamSpec("P")
 ModelT: TypeAlias = Callable[P, Any]
 Message: TypeAlias = dict[str, Any]
 TraceT: TypeAlias = OrderedDict[str, Message]
+PRNGKeyT: TypeAlias = Union[jax.dtypes.prng_key, ArrayLike]


 @runtime_checkable
 class ConstraintT(Protocol):
-    is_discrete: bool = ...
-    event_dim: int = ...
+    # is_discrete: bool = ...
+    # event_dim: int = ...

-    def __call__(self, x: ArrayLike) -> ArrayLike: ...
+    def __init__(self, *args: Any, **kwargs: Any) -> None: ...
+    def __call__(self, x: Array) -> Array: ...
     def __repr__(self) -> str: ...
-    def check(self, value: ArrayLike) -> ArrayLike: ...
-    def feasible_like(self, prototype: ArrayLike) -> ArrayLike: ...
+    def check(self, value: Array) -> Array: ...
+    def feasible_like(self, prototype: Array) -> Array: ...
+
+    @property
+    def is_discrete(self) -> bool: ...
+    @property
+    def event_dim(self) -> int: ...


 @runtime_checkable
@@ -38,27 +48,35 @@ class DistributionT(Protocol):
     """

     arg_constraints: dict[str, ConstraintT] = ...
-    support: ConstraintT = ...
-    has_enumerate_support: bool = ...
     reparametrized_params: list[str] = ...
-    _validate_args: bool = ...
     pytree_data_fields: tuple = ...
     pytree_aux_fields: tuple = ...

     def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
     def rsample(
-        self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
-    ) -> ArrayLike: ...
+        self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = ()
+    ) -> Array: ...
     def sample(
-        self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
-    ) -> ArrayLike: ...
-    def log_prob(self, value: ArrayLike) -> ArrayLike: ...
-    def cdf(self, value: ArrayLike) -> ArrayLike: ...
-    def icdf(self, q: ArrayLike) -> ArrayLike: ...
-    def entropy(self) -> ArrayLike: ...
-    def enumerate_support(self, expand: bool = True) -> ArrayLike: ...
+        self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = ()
+    ) -> Array: ...
+    def log_prob(self, value: Array) -> Array: ...
+    def cdf(self, value: Array) -> Array: ...
+    def icdf(self, q: Array) -> Array: ...
+    def entropy(self) -> Array: ...
+    def enumerate_support(self, expand: bool = True) -> Array: ...
     def shape(self, sample_shape: tuple[int, ...] = ()) -> tuple[int, ...]: ...
+    def to_event(
+        self, reinterpreted_batch_ndims: Optional[int] = None
+    ) -> "DistributionT": ...
+    def expand(self, batch_shape: tuple[int, ...]) -> "DistributionT": ...
+    def expand_by(self, sample_shape: tuple[int, ...]) -> "DistributionT": ...
+    def mask(self, mask: Array) -> MaskedDistribution: ...
+    @classmethod
+    def infer_shapes(cls, *args, **kwargs): ...
+
+    @property
+    def support(self) -> ConstraintT: ...

     @property
     def batch_shape(self) -> tuple[int, ...]: ...
@@ -76,6 +94,8 @@ def variance(self) -> ArrayLike: ...

     @property
     def is_discrete(self) -> bool: ...
+    @property
+    def has_enumerate_support(self) -> bool: ...


 # To avoid breaking changes for user code that uses `DistributionLike`
@@ -84,20 +104,18 @@
 @runtime_checkable
 class TransformT(Protocol):
-    domain = ConstraintT
-    codomain = ConstraintT
-    _inv: "TransformT" = None
-
-    def __call__(self, x: ArrayLike) -> ArrayLike: ...
-    def _inverse(self, y: ArrayLike) -> ArrayLike: ...
-    def log_abs_det_jacobian(
-        self, x: ArrayLike, y: ArrayLike, intermediates=None
-    ) -> ArrayLike: ...
-    def call_with_intermediates(self, x: ArrayLike) -> tuple[ArrayLike, None]: ...
+    domain: ConstraintT = ...
+    codomain: ConstraintT = ...
+    _inv: Optional["TransformT"] = None
+
+    def __call__(self, x: Array) -> Array: ...
+    def _inverse(self, y: Array) -> Array: ...
+    def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: ...
+    def call_with_intermediates(self, x: Array) -> tuple[Array, Optional[Array]]: ...
     def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
     def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...

     @property
-    def inv(self) -> "TransformT": ...
+    def inv(self) -> Optional["TransformT"]: ...

     @property
-    def sign(self) -> ArrayLike: ...
+    def sign(self) -> Array: ...
diff --git a/numpyro/contrib/hsgp/approximation.py b/numpyro/contrib/hsgp/approximation.py
index 1b71361a1..149ae4866 100644
--- a/numpyro/contrib/hsgp/approximation.py
+++ b/numpyro/contrib/hsgp/approximation.py
@@ -9,7 +9,6 @@
 from jax import Array
 import jax.numpy as jnp
-from jax.typing import ArrayLike

 import numpyro
 from numpyro.contrib.hsgp.laplacian import eigenfunctions, eigenfunctions_periodic
@@ -62,7 +61,7 @@ def linear_approximation(


 def hsgp_squared_exponential(
-    x: ArrayLike,
+    x: Array,
     alpha: float,
     length: float,
     ell: float | int | list[float | int],
@@ -84,7 +83,7 @@
     2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

-    :param ArrayLike x: input data
+    :param Array x: input data
     :param float alpha: amplitude of the squared exponential kernel
     :param float length: length scale of the squared exponential kernel
     :param float | int | list[float | int] ell: positive value that parametrizes the length of the D-dimensional box so
@@ -110,7 +109,7 @@


 def hsgp_matern(
-    x: ArrayLike,
+    x: Array,
     nu: float,
     alpha: float,
     length: float,
@@ -133,7 +132,7 @@
     2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

-    :param ArrayLike x: input data
+    :param Array x: input data
     :param float nu: smoothness parameter
     :param float alpha: amplitude of the squared exponential kernel
     :param float length: length scale of the squared exponential kernel
@@ -160,7 +159,7 @@


 def hsgp_periodic_non_centered(
-    x: ArrayLike, alpha: float, length: float, w0: float, m: int
+    x: Array, alpha: float, length: float, w0: float, m: int
 ) -> Array:
     """
     Low rank approximation for the periodic squared exponential kernel in the non-centered parametrization.
@@ -172,7 +171,7 @@
     1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

-    :param ArrayLike x: input data
+    :param Array x: input data
     :param float alpha: amplitude
     :param float length: length scale
     :param float w0: frequency of the periodic kernel
diff --git a/numpyro/contrib/hsgp/laplacian.py b/numpyro/contrib/hsgp/laplacian.py
index ae0cbbd6b..a1250ae00 100644
--- a/numpyro/contrib/hsgp/laplacian.py
+++ b/numpyro/contrib/hsgp/laplacian.py
@@ -11,7 +11,6 @@

 from jax import Array
 import jax.numpy as jnp
-from jax.typing import ArrayLike


 def eigenindices(m: list[int] | int, dim: int) -> Array:
@@ -76,7 +75,7 @@ def eigenindices(m: list[int] | int, dim: int) -> Array:


 def sqrt_eigenvalues(
-    ell: ArrayLike | list[int | float], m: list[int] | int, dim: int
+    ell: Array | list[int | float], m: list[int] | int, dim: int
 ) -> Array:
     """
     The first :math:`m^\\star \\times D` square root of eigenvalues of the laplacian operator in
@@ -101,7 +100,7 @@
     return S * jnp.pi / 2 / ell_  # dim x prod(m) array of eigenvalues


-def eigenfunctions(x: ArrayLike, ell: float | list[float], m: int | list[int]) -> Array:
+def eigenfunctions(x: Array, ell: float | list[float], m: int | list[int]) -> Array:
     """
     The first :math:`m^\\star` eigenfunctions of the laplacian operator in
     :math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]`
@@ -137,7 +136,7 @@
     1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression. Stat Comput 30, 419-446 (2020)

-    :param ArrayLike x: The points at which to evaluate the eigenfunctions.
+    :param Array x: The points at which to evaluate the eigenfunctions.
         If `x` is 1D the problem is assumed unidimensional.
         Otherwise, the dimension of the input space is inferred as the last dimension of `x`.
         Other dimensions are treated as batch dimensions.
@@ -162,11 +161,11 @@
     )


-def eigenfunctions_periodic(x: ArrayLike, w0: float, m: int) -> tuple[Array, Array]:
+def eigenfunctions_periodic(x: Array, w0: float, m: int) -> tuple[Array, Array]:
     """
     Basis functions for the approximation of the periodic kernel.

-    :param ArrayLike x: The points at which to evaluate the eigenfunctions.
+    :param Array x: The points at which to evaluate the eigenfunctions.
     :param float w0: The frequency of the periodic kernel.
     :param int m: The number of eigenfunctions to compute.
@@ -188,13 +187,13 @@
     return cosines, sines


-def _convert_ell(ell: float | int | list[float | int] | ArrayLike, dim: int) -> Array:
+def _convert_ell(ell: float | int | list[float | int] | Array, dim: int) -> Array:
     """
     Process the half-length of the approximation interval and return a `D \\times 1` array.

     If `ell` is a scalar, it is converted to a list of length dim, then transformed into an Array.

-    :param float | int | list[float | int] | ArrayLike ell: The length of the interval in each dimension divided by 2.
+    :param float | int | list[float | int] | Array ell: The length of the interval in each dimension divided by 2.
         If a float or int, the same length is used in each dimension.
     :param int dim: The dimension of the space.
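
Editorial note, not part of the patch: the hunks above introduce runtime_checkable protocols in numpyro/_typing.py and narrow the HSGP helper signatures from jax.typing.ArrayLike to jax.Array. The sketch below illustrates the intended usage, assuming the diff is applied; all names come from the hunks above or from the public JAX/NumPyro API, and the isinstance checks only verify member presence, not signatures.

# Editorial sketch (not part of the diff): exercising the re-typed API.
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro._typing import ConstraintT, DistributionT
from numpyro.contrib.hsgp.laplacian import eigenfunctions

# The protocols are @runtime_checkable, so structural isinstance checks
# should hold for existing distributions and their constraints.
d = dist.Normal(0.0, 1.0)
assert isinstance(d, DistributionT)
assert isinstance(d.support, ConstraintT)

# The HSGP helpers now annotate inputs as jax.Array rather than ArrayLike,
# so callers convert NumPy/Python inputs explicitly, e.g. via jnp.asarray.
x = jnp.linspace(-1.0, 1.0, 100)                # a jax.Array, matching the new hints
phi = eigenfunctions(x, ell=1.2, m=10)          # first 10 Laplacian eigenfunctions on [-1.2, 1.2]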
diff --git a/numpyro/contrib/hsgp/spectral_densities.py b/numpyro/contrib/hsgp/spectral_densities.py index 383a57a61..1b5f89ee0 100644 --- a/numpyro/contrib/hsgp/spectral_densities.py +++ b/numpyro/contrib/hsgp/spectral_densities.py @@ -10,7 +10,6 @@ from jax import Array, vmap import jax.numpy as jnp from jax.scipy import special -from jax.typing import ArrayLike from numpyro.contrib.hsgp.laplacian import sqrt_eigenvalues @@ -20,7 +19,7 @@ def align_param(dim, param): def spectral_density_squared_exponential( - dim: int, w: ArrayLike, alpha: float, length: float | ArrayLike + dim: int, w: Array, alpha: float, length: float | Array ) -> Array: """ Spectral density of the squared exponential kernel. @@ -41,7 +40,7 @@ def spectral_density_squared_exponential( approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023). :param int dim: dimension - :param ArrayLike w: frequency + :param Array w: frequency :param float alpha: amplitude :param float length: length scale :return: spectral density value @@ -54,7 +53,7 @@ def spectral_density_squared_exponential( def spectral_density_matern( - dim: int, nu: float, w: ArrayLike, alpha: float, length: float | ArrayLike + dim: int, nu: float, w: Array, alpha: float, length: float | Array ) -> float: """ Spectral density of the Matérn kernel. @@ -77,7 +76,7 @@ def spectral_density_matern( :param int dim: dimension :param float nu: smoothness - :param ArrayLike w: frequency + :param Array w: frequency :param float alpha: amplitude :param float length: length scale :return: spectral density value diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index de0ad382e..9a7d7e9c8 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -3,11 +3,11 @@ import copy from functools import singledispatch -from typing import Union import jax import jax.numpy as jnp +from numpyro._typing import DistributionT from numpyro.distributions import constraints from numpyro.distributions.conjugate import ( BetaBinomial, @@ -17,7 +17,6 @@ NegativeBinomialLogits, NegativeBinomialProbs, ) -from numpyro.distributions.constraints import Constraint from numpyro.distributions.continuous import ( CAR, LKJ, @@ -59,7 +58,6 @@ AffineTransform, CorrCholeskyTransform, PowerTransform, - Transform, ) from numpyro.distributions.truncated import ( LeftTruncatedDistribution, @@ -69,7 +67,7 @@ @singledispatch -def vmap_over(d: Union[Distribution, Transform, Constraint], **kwargs): +def vmap_over(d: DistributionT, **kwargs): raise NotImplementedError @@ -498,12 +496,12 @@ def _vmap_over_half_normal(dist: HalfNormal, scale=None): @singledispatch -def promote_batch_shape(d: Distribution): +def promote_batch_shape(d: DistributionT) -> DistributionT: raise NotImplementedError @promote_batch_shape.register -def _default_promote_batch_shape(d: Distribution): +def _default_promote_batch_shape(d: DistributionT) -> DistributionT: attr_batch_shapes = [d.batch_shape] for attr_name, constraint in d.arg_constraints.items(): try: @@ -515,12 +513,12 @@ def _default_promote_batch_shape(d: Distribution): attr_batch_shapes.append(jnp.shape(attr)[:attr_batch_ndim]) resolved_batch_shape = jnp.broadcast_shapes(*attr_batch_shapes) new_self = copy.deepcopy(d) - new_self._batch_shape = resolved_batch_shape + new_self._batch_shape = resolved_batch_shape # type: ignore return new_self @promote_batch_shape.register -def _promote_batch_shape_expanded(d: ExpandedDistribution): +def _promote_batch_shape_expanded(d: 
ExpandedDistribution) -> ExpandedDistribution: orig_delta_batch_shape = d.batch_shape[ : len(d.batch_shape) - len(d.base_dist.batch_shape) ] @@ -560,7 +558,7 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution): @promote_batch_shape.register -def _promote_batch_shape_masked(d: MaskedDistribution): +def _promote_batch_shape_masked(d: MaskedDistribution) -> MaskedDistribution: new_self = copy.copy(d) new_base_dist = promote_batch_shape(d.base_dist) new_self._batch_shape = new_base_dist.batch_shape @@ -569,7 +567,7 @@ def _promote_batch_shape_masked(d: MaskedDistribution): @promote_batch_shape.register -def _promote_batch_shape_independent(d: Independent): +def _promote_batch_shape_independent(d: Independent) -> DistributionT: new_self = copy.copy(d) new_base_dist = promote_batch_shape(d.base_dist) new_self._batch_shape = new_base_dist.batch_shape[: -d.event_dim] @@ -578,5 +576,5 @@ def _promote_batch_shape_independent(d: Independent): @promote_batch_shape.register -def _promote_batch_shape_unit(d: Unit): +def _promote_batch_shape_unit(d: Unit) -> Unit: return d diff --git a/numpyro/distributions/conjugate.py b/numpyro/distributions/conjugate.py index 19932ae72..6dfac54e7 100644 --- a/numpyro/distributions/conjugate.py +++ b/numpyro/distributions/conjugate.py @@ -4,13 +4,12 @@ from typing import Optional -import jax -from jax import lax, nn, random +from jax import Array, lax, nn, random import jax.numpy as jnp from jax.scipy.special import betainc, betaln, gammaln from jax.typing import ArrayLike -from numpyro._typing import ConstraintT +from numpyro._typing import ConstraintT, PRNGKeyT from numpyro.distributions import constraints from numpyro.distributions.continuous import Beta, Dirichlet, Gamma from numpyro.distributions.discrete import ( @@ -54,8 +53,8 @@ class BetaBinomial(Distribution): def __init__( self, - concentration1: ArrayLike, - concentration0: ArrayLike, + concentration1: Array, + concentration0: Array, total_count: int = 1, *, validate_args: Optional[bool] = None, @@ -72,8 +71,8 @@ def __init__( super(BetaBinomial, self).__init__(batch_shape, validate_args=validate_args) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) key_beta, key_binom = random.split(key) probs = self._beta.sample(key_beta, sample_shape) @@ -82,7 +81,7 @@ def sample( ) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: return ( -_log_beta_1(self.total_count - value + 1, value) + betaln( @@ -132,7 +131,7 @@ class DirichletMultinomial(Distribution): def __init__( self, - concentration: ArrayLike, + concentration: Array, total_count: int = 1, *, total_count_max: Optional[int] = None, @@ -159,8 +158,8 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) key_dirichlet, key_multinom = random.split(key) probs = self._dirichlet.sample(key_dirichlet, sample_shape) @@ -171,7 +170,7 @@ def sample( ).sample(key_multinom) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: alpha = self.concentration return _log_beta_1(alpha.sum(-1), value.sum(-1)) - _log_beta_1( alpha, value @@ -195,7 +194,7 @@ def support(self) -> ConstraintT: @staticmethod def infer_shapes( - concentration: ArrayLike, total_count=() + concentration: Array, total_count=() ) -> tuple[tuple[int, ...], tuple[int, ...]]: batch_shape = lax.broadcast_shapes(concentration[:-1], total_count) event_shape = concentration[-1:] @@ -222,8 +221,8 @@ class GammaPoisson(Distribution): def __init__( self, - concentration: ArrayLike, - rate: ArrayLike = 1.0, + concentration: Array, + rate: Array = 1.0, *, validate_args: Optional[bool] = None, ): @@ -234,15 +233,15 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) key_gamma, key_poisson = random.split(key) rate = self._gamma.sample(key_gamma, sample_shape) return Poisson(rate).sample(key_poisson) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: post_value = self.concentration + value return ( -betaln(self.concentration, value + 1) @@ -259,15 +258,15 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return self.concentration / jnp.square(self.rate) * (1 + self.rate) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: bt = betainc(self.concentration, value + 1.0, self.rate / (self.rate + 1.0)) return bt def NegativeBinomial( total_count: int, - probs: Optional[ArrayLike] = None, - logits: Optional[ArrayLike] = None, + probs: Optional[Array] = None, + logits: Optional[Array] = None, *, validate_args: Optional[bool] = None, ): @@ -289,7 +288,7 @@ class NegativeBinomialProbs(GammaPoisson): def __init__( self, total_count: int, - probs: ArrayLike, + probs: Array, *, validate_args: Optional[bool] = None, ): @@ -309,7 +308,7 @@ class NegativeBinomialLogits(GammaPoisson): def __init__( self, total_count: int, - logits: ArrayLike, + logits: Array, *, validate_args: Optional[bool] = None, ): @@ -319,7 +318,7 @@ def __init__( super().__init__(concentration, rate, validate_args=validate_args) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: return -( self.total_count * nn.softplus(self.logits) + value * nn.softplus(-self.logits) @@ -341,8 +340,8 @@ class NegativeBinomial2(GammaPoisson): def __init__( self, - mean: ArrayLike, - concentration: ArrayLike, + mean: Array, + concentration: Array, *, validate_args: Optional[bool] = None, ): @@ -351,11 +350,11 @@ def __init__( def ZeroInflatedNegativeBinomial2( - mean: ArrayLike, - concentration: ArrayLike, + mean: Array, + concentration: Array, *, - gate: Optional[ArrayLike] = None, - gate_logits: Optional[ArrayLike] = None, + gate: Optional[Array] = None, + gate_logits: Optional[Array] = None, validate_args: Optional[bool] = None, ): return ZeroInflatedDistribution( diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 7671b0685..1a15291fd 100644 --- a/numpyro/distributions/constraints.py +++ 
b/numpyro/distributions/constraints.py @@ -68,6 +68,7 @@ import numpy as np +from jax import Array import jax.numpy import jax.numpy as jnp from jax.tree_util import register_pytree_node @@ -91,20 +92,20 @@ def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: raise NotImplementedError def __repr__(self) -> str: return self.__class__.__name__[1:] + "()" - def check(self, value: ArrayLike) -> ArrayLike: + def check(self, value: Array) -> Array: """ Returns a byte tensor of `sample_shape + batch_shape` indicating whether each event in value satisfies this constraint. """ return self(value) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: """ Get a feasible value which has the same shape as dtype as `prototype`. """ @@ -143,17 +144,17 @@ def __new__(cls): class _Boolean(_SingletonConstraint): is_discrete = True - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return (x == 0) | (x == 1) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.zeros_like(prototype) class _CorrCholesky(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tril = jnp.tril(x) lower_triangular = jnp.all( @@ -165,7 +166,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: unit_norm_row = jnp.all(jnp.abs(x_norm - 1) <= tol, axis=-1) return lower_triangular & positive_diagonal & unit_norm_row - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -174,7 +175,7 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: class _CorrMatrix(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) @@ -186,7 +187,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: ) return symmetric & positive & unit_variance - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -219,14 +220,14 @@ def is_discrete(self): return self._is_discrete @property - def event_dim(self) -> int: + def event_dim(self) -> int: # type: ignore[override] if self._event_dim is NotImplemented: raise NotImplementedError(".event_dim cannot be determined statically") return self._event_dim def __call__( self, - x: Optional[ArrayLike] = None, + x: Optional[Array] = None, *, is_discrete: bool = NotImplemented, event_dim: int = NotImplemented, @@ -242,12 +243,12 @@ def __call__( event_dim = self._event_dim return _Dependent(is_discrete=is_discrete, event_dim=event_dim) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: return ( type(self) is type(other) - and self._is_discrete == other._is_discrete - and self._event_dim == other._event_dim - ) + and self._is_discrete == other._is_discrete # type: 
ignore[attr-defined] + and self._event_dim == other._event_dim # type: ignore[attr-defined] + ) # type: ignore def tree_flatten(self): return (), ( @@ -266,7 +267,7 @@ def __init__( self._is_discrete = is_discrete self._event_dim = event_dim - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Optional[Array]) -> ConstraintT: # type: ignore[override] if not callable(x): return super().__call__(x) @@ -276,7 +277,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: # ... return dependent_property( x, is_discrete=self._is_discrete, event_dim=self._event_dim - ) + ) # type: ignore def is_dependent(constraint): @@ -287,7 +288,7 @@ class _GreaterThan(Constraint): def __init__(self, lower_bound): self.lower_bound = lower_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return x > self.lower_bound def __repr__(self) -> str: @@ -295,26 +296,26 @@ def __repr__(self) -> str: fmt_string += "(lower_bound={})".format(self.lower_bound) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to(self.lower_bound + 1, jax.numpy.shape(prototype)) def tree_flatten(self): return (self.lower_bound,), (("lower_bound",), dict()) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _GreaterThan): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) + return jnp.array_equal(self.lower_bound, other.lower_bound) # type: ignore class _GreaterThanEq(_GreaterThan): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return x >= self.lower_bound - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _GreaterThanEq): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) + return jnp.array_equal(self.lower_bound, other.lower_bound) # type: ignore class _Positive(_SingletonConstraint, _GreaterThan): @@ -334,24 +335,26 @@ class _IndependentConstraint(Constraint): independent entries are valid. 
""" - def __init__(self, base_constraint, reinterpreted_batch_ndims): + def __init__( + self, base_constraint: ConstraintT, reinterpreted_batch_ndims: int + ) -> None: assert isinstance(base_constraint, Constraint) assert isinstance(reinterpreted_batch_ndims, int) assert reinterpreted_batch_ndims >= 0 if isinstance(base_constraint, _IndependentConstraint): reinterpreted_batch_ndims = ( - reinterpreted_batch_ndims + base_constraint.reinterpreted_batch_ndims + reinterpreted_batch_ndims + base_constraint.reinterpreted_batch_ndims # type: ignore ) - base_constraint = base_constraint.base_constraint + base_constraint = base_constraint.base_constraint # type: ignore self.base_constraint = base_constraint self.reinterpreted_batch_ndims = reinterpreted_batch_ndims super().__init__() @property - def event_dim(self) -> int: + def event_dim(self) -> int: # type: ignore[override] return self.base_constraint.event_dim + self.reinterpreted_batch_ndims - def __call__(self, value: ArrayLike) -> ArrayLike: + def __call__(self, value: Array) -> Array: result = self.base_constraint(value) if self.reinterpreted_batch_ndims == 0: return result @@ -376,7 +379,7 @@ def __repr__(self) -> str: self.reinterpreted_batch_ndims, ) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return self.base_constraint.feasible_like(prototype) def tree_flatten(self): @@ -385,13 +388,13 @@ def tree_flatten(self): {"reinterpreted_batch_ndims": self.reinterpreted_batch_ndims}, ) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _IndependentConstraint): return False return (self.base_constraint == other.base_constraint) & ( self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims - ) + ) # type: ignore class _RealVector(_IndependentConstraint, _SingletonConstraint): @@ -405,10 +408,10 @@ def __init__(self) -> None: class _LessThan(Constraint): - def __init__(self, upper_bound: ArrayLike) -> None: + def __init__(self, upper_bound: Array) -> None: self.upper_bound = upper_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return x < self.upper_bound def __repr__(self) -> str: @@ -416,36 +419,36 @@ def __repr__(self) -> str: fmt_string += "(upper_bound={})".format(self.upper_bound) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to(self.upper_bound - 1, jax.numpy.shape(prototype)) def tree_flatten(self): return (self.upper_bound,), (("upper_bound",), dict()) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _LessThan): return False - return jnp.array_equal(self.upper_bound, other.upper_bound) + return jnp.array_equal(self.upper_bound, other.upper_bound) # type: ignore class _LessThanEq(_LessThan): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return x <= self.upper_bound - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _LessThanEq): return False - return jnp.array_equal(self.upper_bound, other.upper_bound) + return jnp.array_equal(self.upper_bound, other.upper_bound) # type: ignore class _IntegerInterval(Constraint): is_discrete = True - def __init__(self, lower_bound: ArrayLike, upper_bound: ArrayLike) -> None: + def __init__(self, lower_bound: Array, upper_bound: 
Array) -> None: self.lower_bound = lower_bound self.upper_bound = upper_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return (x >= self.lower_bound) & (x <= self.upper_bound) & (x % 1 == 0) def __repr__(self) -> str: @@ -455,7 +458,7 @@ def __repr__(self) -> str: ) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype)) def tree_flatten(self): @@ -464,22 +467,22 @@ def tree_flatten(self): dict(), ) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _IntegerInterval): return False return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( self.upper_bound, other.upper_bound - ) + ) # type: ignore class _IntegerGreaterThan(Constraint): is_discrete = True - def __init__(self, lower_bound: ArrayLike) -> None: + def __init__(self, lower_bound: Array) -> None: self.lower_bound = lower_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return (x % 1 == 0) & (x >= self.lower_bound) def __repr__(self) -> str: @@ -487,26 +490,26 @@ def __repr__(self) -> str: fmt_string += "(lower_bound={})".format(self.lower_bound) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype)) def tree_flatten(self): return (self.lower_bound,), (("lower_bound",), dict()) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _IntegerGreaterThan): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) + return jnp.array_equal(self.lower_bound, other.lower_bound) # type: ignore class _IntegerPositive(_SingletonConstraint, _IntegerGreaterThan): def __init__(self) -> None: - super().__init__(1) + super().__init__(jnp.ones((), dtype=jnp.result_type(int))) class _IntegerNonnegative(_SingletonConstraint, _IntegerGreaterThan): def __init__(self) -> None: - super().__init__(0) + super().__init__(jnp.zeros((), dtype=jnp.result_type(int))) class _Interval(Constraint): @@ -514,7 +517,7 @@ def __init__(self, lower_bound: ArrayLike, upper_bound: ArrayLike) -> None: self.lower_bound = lower_bound self.upper_bound = upper_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return (x >= self.lower_bound) & (x <= self.upper_bound) def __repr__(self) -> str: @@ -524,17 +527,17 @@ def __repr__(self) -> str: ) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to( (self.lower_bound + self.upper_bound) / 2, jax.numpy.shape(prototype) ) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _Interval): return False return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( self.upper_bound, other.upper_bound - ) + ) # type: ignore def tree_flatten(self): return (self.lower_bound, self.upper_bound), ( @@ -545,16 +548,17 @@ def tree_flatten(self): class _Circular(_SingletonConstraint, _Interval): def __init__(self) -> None: - super().__init__(-math.pi, math.pi) + pi = jnp.asarray(math.pi) + super().__init__(-pi, pi) class _UnitInterval(_SingletonConstraint, 
_Interval): def __init__(self) -> None: - super().__init__(0.0, 1.0) + super().__init__(jnp.zeros(()), jnp.ones(())) class _OpenInterval(_Interval): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return (x > self.lower_bound) & (x < self.upper_bound) def __repr__(self) -> str: @@ -568,7 +572,7 @@ def __repr__(self) -> str: class _LowerCholesky(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tril = jnp.tril(x) lower_triangular = jnp.all( @@ -577,7 +581,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1) return lower_triangular & positive_diagonal - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -587,13 +591,13 @@ class _Multinomial(Constraint): is_discrete = True event_dim = 1 - def __init__(self, upper_bound: ArrayLike) -> None: + def __init__(self, upper_bound: Array) -> None: self.upper_bound = upper_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return (x >= 0).all(axis=-1) & (x.sum(axis=-1) == self.upper_bound) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: pad_width = ((0, 0),) * jax.numpy.ndim(self.upper_bound) + ( (0, prototype.shape[-1] - 1), ) @@ -603,10 +607,10 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: def tree_flatten(self): return (self.upper_bound,), (("upper_bound",), dict()) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _Multinomial): return False - return jnp.array_equal(self.upper_bound, other.upper_bound) + return jnp.array_equal(self.upper_bound, other.upper_bound) # type: ignore class _L1Ball(_SingletonConstraint): @@ -617,22 +621,22 @@ class _L1Ball(_SingletonConstraint): event_dim = 1 reltol = 10.0 # Relative to finfo.eps. 
- def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy eps = jnp.finfo(x.dtype).eps return jnp.abs(x).sum(axis=-1) < 1 + self.reltol * eps - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.zeros_like(prototype) class _OrderedVector(_SingletonConstraint): event_dim = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return (x[..., 1:] > x[..., :-1]).all(axis=-1) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to( jax.numpy.arange(float(prototype.shape[-1])), prototype.shape ) @@ -641,7 +645,7 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: class _PositiveDefinite(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) @@ -649,7 +653,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: positive = jnp.linalg.eigh(x)[0][..., 0] > 0 return symmetric & positive - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -658,20 +662,20 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: class _PositiveDefiniteCirculantVector(_SingletonConstraint): event_dim = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tol = 10 * jnp.finfo(x.dtype).eps rfft = jnp.fft.rfft(x) return (jnp.abs(rfft.imag) < tol) & (rfft.real > -tol) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jnp.zeros_like(prototype).at[..., 0].set(1.0) class _PositiveSemiDefinite(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) @@ -679,7 +683,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: nonnegative = jnp.linalg.eigh(x)[0][..., 0] >= 0 return symmetric & nonnegative - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -693,41 +697,41 @@ class _PositiveOrderedVector(_SingletonConstraint): event_dim = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return ordered_vector.check(x) & independent(positive, 1).check(x) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to( jax.numpy.exp(jax.numpy.arange(float(prototype.shape[-1]))), prototype.shape ) class _Complex(_SingletonConstraint): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: # XXX: consider to relax this condition to [-inf, inf] interval return (x == x) & (x != float("inf")) & (x != float("-inf")) - def feasible_like(self, 
prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.zeros_like(prototype) class _Real(_SingletonConstraint): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: # XXX: consider to relax this condition to [-inf, inf] interval return (x == x) & (x != float("inf")) & (x != float("-inf")) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.zeros_like(prototype) class _Simplex(_SingletonConstraint): event_dim = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: x_sum = x.sum(axis=-1) return (x >= 0).all(axis=-1) & (x_sum < 1 + 1e-6) & (x_sum > 1 - 1e-6) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.full_like(prototype, 1 / prototype.shape[-1]) @@ -735,12 +739,12 @@ class _SoftplusPositive(_SingletonConstraint, _GreaterThan): def __init__(self) -> None: super().__init__(lower_bound=0.0) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.full(jax.numpy.shape(prototype), np.log(2)) class _SoftplusLowerCholesky(_LowerCholesky): - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]) * np.log(2), prototype.shape ) @@ -758,14 +762,14 @@ class _Sphere(_SingletonConstraint): event_dim = 1 reltol = 10.0 # Relative to finfo.eps. - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy eps = jnp.finfo(x.dtype).eps norm = jnp.linalg.norm(x, axis=-1) error = jnp.abs(norm - 1) return error < self.reltol * eps * x.shape[-1] ** 0.5 - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.full_like(prototype, prototype.shape[-1] ** (-0.5)) @@ -774,18 +778,18 @@ def __init__(self, event_dim: int = 1) -> None: self.event_dim = event_dim super().__init__() - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tol = jnp.finfo(x.dtype).eps * x.shape[-1] * 10 - zerosum_true = True + zerosum_true = jnp.ones((), dtype=bool) for dim in range(-self.event_dim, 0): zerosum_true = zerosum_true & jnp.allclose(x.sum(dim), 0, atol=tol) return zerosum_true - def __eq__(self, other: ConstraintT) -> bool: - return type(self) is type(other) and self.event_dim == other.event_dim + def __eq__(self, other: object) -> bool: + return type(self) is type(other) and self.event_dim == other.event_dim # type: ignore[attr-defined] - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: Array) -> Array: return jax.numpy.zeros_like(prototype) def tree_flatten(self): @@ -795,40 +799,40 @@ def tree_flatten(self): # TODO: Make types consistent # See https://github.com/pytorch/pytorch/issues/50616 -boolean: ConstraintT = _Boolean() -circular: ConstraintT = _Circular() -complex: ConstraintT = _Complex() -corr_cholesky: ConstraintT = _CorrCholesky() -corr_matrix: ConstraintT = _CorrMatrix() -dependent: ConstraintT = _Dependent() -greater_than: ConstraintT = _GreaterThan -greater_than_eq: ConstraintT = _GreaterThanEq -less_than: 
ConstraintT = _LessThan -less_than_eq: ConstraintT = _LessThanEq -independent: ConstraintT = _IndependentConstraint -integer_interval: ConstraintT = _IntegerInterval -integer_greater_than: ConstraintT = _IntegerGreaterThan -interval: ConstraintT = _Interval -l1_ball: ConstraintT = _L1Ball() -lower_cholesky: ConstraintT = _LowerCholesky() -scaled_unit_lower_cholesky: ConstraintT = _ScaledUnitLowerCholesky() -multinomial: ConstraintT = _Multinomial -nonnegative: ConstraintT = _Nonnegative() -nonnegative_integer: ConstraintT = _IntegerNonnegative() -ordered_vector: ConstraintT = _OrderedVector() -positive: ConstraintT = _Positive() -positive_definite: ConstraintT = _PositiveDefinite() -positive_definite_circulant_vector: ConstraintT = _PositiveDefiniteCirculantVector() -positive_semidefinite: ConstraintT = _PositiveSemiDefinite() -positive_integer: ConstraintT = _IntegerPositive() -positive_ordered_vector: ConstraintT = _PositiveOrderedVector() -real: ConstraintT = _Real() -real_vector: ConstraintT = _RealVector() -real_matrix: ConstraintT = _RealMatrix() -simplex: ConstraintT = _Simplex() -softplus_lower_cholesky: ConstraintT = _SoftplusLowerCholesky() -softplus_positive: ConstraintT = _SoftplusPositive() -sphere: ConstraintT = _Sphere() -unit_interval: ConstraintT = _UnitInterval() -open_interval: ConstraintT = _OpenInterval -zero_sum: ConstraintT = _ZeroSum +boolean = _Boolean() +circular = _Circular() +complex = _Complex() +corr_cholesky = _CorrCholesky() +corr_matrix = _CorrMatrix() +dependent = _Dependent() +greater_than = _GreaterThan +greater_than_eq = _GreaterThanEq +less_than = _LessThan +less_than_eq = _LessThanEq +independent = _IndependentConstraint +integer_interval = _IntegerInterval +integer_greater_than = _IntegerGreaterThan +interval = _Interval +l1_ball = _L1Ball() +lower_cholesky = _LowerCholesky() +scaled_unit_lower_cholesky = _ScaledUnitLowerCholesky() +multinomial = _Multinomial +nonnegative = _Nonnegative() +nonnegative_integer = _IntegerNonnegative() +ordered_vector = _OrderedVector() +positive = _Positive() +positive_definite = _PositiveDefinite() +positive_definite_circulant_vector = _PositiveDefiniteCirculantVector() +positive_semidefinite = _PositiveSemiDefinite() +positive_integer = _IntegerPositive() +positive_ordered_vector = _PositiveOrderedVector() +real = _Real() +real_vector = _RealVector() +real_matrix = _RealMatrix() +simplex = _Simplex() +softplus_lower_cholesky = _SoftplusLowerCholesky() +softplus_positive = _SoftplusPositive() +sphere = _Sphere() +unit_interval = _UnitInterval() +open_interval = _OpenInterval +zero_sum = _ZeroSum diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 583c88d58..6b2ca8f3e 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -30,7 +30,6 @@ import numpy as np -import jax from jax import Array, lax, vmap from jax.experimental.sparse import BCOO from jax.lax import scan @@ -55,13 +54,10 @@ from jax.scipy.stats import norm as jax_norm from jax.typing import ArrayLike +from numpyro._typing import ConstraintT, DistributionT, PRNGKeyT from numpyro.distributions import constraints from numpyro.distributions.discrete import _to_logits_bernoulli -from numpyro.distributions.distribution import ( - Distribution, - DistributionT, - TransformedDistribution, -) +from numpyro.distributions.distribution import Distribution, TransformedDistribution from numpyro.distributions.transforms import ( AffineTransform, CholeskyTransform, @@ -96,12 +92,12 @@ class 
AsymmetricLaplace(Distribution): arg_constraints = { - "loc": constraints.real, - "scale": constraints.positive, - "asymmetry": constraints.positive, + "loc": constraints.real, # type: ignore[has-type] + "scale": constraints.positive, # type: ignore[has-type] + "asymmetry": constraints.positive, # type: ignore[has-type] } reparametrized_params = ["loc", "scale", "asymmetry"] - support = constraints.real + support = constraints.real # type: ignore[has-type] def __init__( self, @@ -129,7 +125,7 @@ def left_scale(self): def right_scale(self): return self.scale / self.asymmetry - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: if self._validate_args: self._validate_sample(value) z = value - self.loc @@ -137,8 +133,8 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: return z - jnp.log(self.left_scale + self.right_scale) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) shape = (2,) + sample_shape + self.batch_shape + self.event_shape u, v = random.exponential(key, shape=shape) @@ -160,7 +156,7 @@ def variance(self) -> ArrayLike: variance = p * left**2 + q * right**2 + p * q * total**2 return jnp.broadcast_to(variance, self.batch_shape) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: z = value - self.loc k = self.asymmetry return jnp.where( @@ -169,7 +165,7 @@ def cdf(self, value: ArrayLike) -> ArrayLike: k**2 / (1 + k**2) * jnp.exp(-jnp.abs(z) / self.left_scale), ) - def icdf(self, value: ArrayLike) -> ArrayLike: + def icdf(self, value: Array) -> Array: k = self.asymmetry temp = k**2 / (1 + k**2) return jnp.where( @@ -181,17 +177,17 @@ def icdf(self, value: ArrayLike) -> ArrayLike: class Beta(Distribution): arg_constraints = { - "concentration1": constraints.positive, - "concentration0": constraints.positive, + "concentration1": constraints.positive, # type: ignore[has-type] + "concentration0": constraints.positive, # type: ignore[has-type] } reparametrized_params = ["concentration1", "concentration0"] - support = constraints.unit_interval + support = constraints.unit_interval # type: ignore[has-type] pytree_data_fields = ("concentration0", "concentration1", "_dirichlet") def __init__( self, - concentration1: ArrayLike, - concentration0: ArrayLike, + concentration1: Array, + concentration0: Array, *, validate_args: Optional[bool] = None, ) -> None: @@ -209,13 +205,13 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) return self._dirichlet.sample(key, sample_shape)[..., 0] @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: return self._dirichlet.log_prob(jnp.stack([value, 1.0 - value], -1)) @property @@ -227,13 +223,13 @@ def variance(self) -> ArrayLike: total = self.concentration1 + self.concentration0 return self.concentration1 * self.concentration0 / (total**2 * (total + 1)) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: return betainc(self.concentration1, self.concentration0, value) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: return betaincinv(self.concentration1, self.concentration0, q) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: total = self.concentration0 + self.concentration1 return ( betaln(self.concentration0, self.concentration1) @@ -244,8 +240,11 @@ def entropy(self) -> ArrayLike: class Cauchy(Distribution): - arg_constraints = {"loc": constraints.real, "scale": constraints.positive} - support = constraints.real + arg_constraints = { + "loc": constraints.real, # type: ignore[has-type] + "scale": constraints.positive, # type: ignore[has-type] + } + support = constraints.real # type: ignore[has-type] reparametrized_params = ["loc", "scale"] def __init__( @@ -262,14 +261,14 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) eps = random.cauchy(key, shape=sample_shape + self.batch_shape) return self.loc + eps * self.scale @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: return ( -jnp.log(jnp.pi) - jnp.log(self.scale) @@ -284,23 +283,23 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return jnp.full(self.batch_shape, jnp.nan) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: scaled = (value - self.loc) / self.scale return jnp.arctan(scaled) / jnp.pi + 0.5 - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: return self.loc + self.scale * jnp.tan(jnp.pi * (q - 0.5)) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return jnp.broadcast_to(jnp.log(4 * np.pi * self.scale), self.batch_shape) class Dirichlet(Distribution): arg_constraints = { - "concentration": constraints.independent(constraints.positive, 1) + "concentration": constraints.independent(constraints.positive, 1) # type: ignore[has-type] } reparametrized_params = ["concentration"] - support = constraints.simplex + support = constraints.simplex # type: ignore[has-type] def __init__( self, @@ -321,15 +320,15 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) shape = sample_shape + self.batch_shape samples = random.dirichlet(key, self.concentration, shape=shape) return jnp.clip(samples, jnp.finfo(samples).tiny, 1 - jnp.finfo(samples).eps) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: normalize_term = jnp.sum(gammaln(self.concentration), axis=-1) - gammaln( jnp.sum(self.concentration, axis=-1) ) @@ -353,7 +352,7 @@ def infer_shapes(concentration): event_shape = concentration[-1:] return batch_shape, event_shape - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: (n,) = self.event_shape total = self.concentration.sum(axis=-1) return ( @@ -378,7 +377,9 @@ class EulerMaruyama(Distribution): [1] https://en.wikipedia.org/wiki/Euler-Maruyama_method """ - arg_constraints = {"t": constraints.ordered_vector} + arg_constraints = { + "t": constraints.ordered_vector, # type: ignore[has-type] + } pytree_data_fields = ("t", "init_dist") pytree_aux_fields = ("sde_fn",) @@ -405,13 +406,13 @@ def __init__( batch_shape, event_shape, validate_args=validate_args ) - @constraints.dependent_property(is_discrete=False) - def support(self) -> constraints.Constraint: + @constraints.dependent_property(is_discrete=False) # type: ignore[arg-type] + def support(self) -> ConstraintT: # type: ignore[override] return constraints.independent(constraints.real, self.event_dim) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) batch_shape = sample_shape + self.batch_shape @@ -451,7 +452,7 @@ def scan_fn(init, noise, tm1, dt): return sde_out @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: sample_shape = lax.broadcast_shapes( value.shape[: -self.event_dim], self.batch_shape ) @@ -504,12 +505,14 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: class Exponential(Distribution): reparametrized_params = ["rate"] - arg_constraints = {"rate": constraints.positive} - support = constraints.positive + arg_constraints = { + "rate": constraints.positive, # type: ignore[has-type] + } + support = constraints.positive # type: ignore[has-type] def __init__( self, - rate: ArrayLike = 1.0, + rate: Array = 1.0, *, validate_args: Optional[bool] = None, ) -> None: @@ -519,15 +522,15 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) return ( random.exponential(key, shape=sample_shape + self.batch_shape) / self.rate ) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: return jnp.log(self.rate) - self.rate * value @property @@ -538,28 +541,28 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return jnp.reciprocal(self.rate**2) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: return -jnp.expm1(-self.rate * value) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: return -jnp.log1p(-q) / self.rate - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return 1 - jnp.log(self.rate) class Gamma(Distribution): arg_constraints = { - "concentration": constraints.positive, - "rate": constraints.positive, + "concentration": constraints.positive, # type: ignore[has-type] + "rate": constraints.positive, # type: ignore[has-type] } - support = constraints.positive + support = constraints.positive # type: ignore[has-type] reparametrized_params = ["concentration", "rate"] def __init__( self, - concentration: ArrayLike, - rate: ArrayLike = 1.0, + concentration: Array, + rate: Array = 1.0, *, validate_args: Optional[bool] = None, ) -> None: @@ -570,14 +573,14 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) shape = sample_shape + self.batch_shape + self.event_shape return random.gamma(key, self.concentration, shape=shape) / self.rate @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: normalize_term = gammaln(self.concentration) - self.concentration * jnp.log( self.rate ) @@ -598,10 +601,10 @@ def variance(self) -> ArrayLike: def cdf(self, x): return gammainc(self.concentration, self.rate * x) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: return gammaincinv(self.concentration, q) / self.rate - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return ( self.concentration - jnp.log(self.rate) @@ -611,10 +614,12 @@ def entropy(self) -> ArrayLike: class Chi2(Gamma): - arg_constraints = {"df": constraints.positive} + arg_constraints = { + "df": constraints.positive, # type: ignore[has-type] + } reparametrized_params = ["df"] - def __init__(self, df: ArrayLike, *, validate_args: Optional[bool] = None) -> None: + def __init__(self, df: Array, *, validate_args: Optional[bool] = None) -> None: self.df = df super(Chi2, self).__init__(0.5 * df, 0.5, validate_args=validate_args) @@ -642,12 +647,12 @@ class GaussianStateSpace(TransformedDistribution): """ arg_constraints = { - "covariance_matrix": constraints.positive_definite, - "precision_matrix": constraints.positive_definite, - "scale_tril": constraints.lower_cholesky, - "transition_matrix": constraints.real_matrix, + "covariance_matrix": constraints.positive_definite, # type: ignore[has-type] + "precision_matrix": constraints.positive_definite, # type: ignore[has-type] + "scale_tril": constraints.lower_cholesky, # type: ignore[has-type] + "transition_matrix": constraints.real_matrix, # type: ignore[has-type] } - support = constraints.real_matrix + support = constraints.real_matrix # type: ignore[has-type] pytree_aux_fields = ("num_steps",) def __init__( @@ -723,8 +728,10 @@ def precision_matrix(self): class 
GaussianRandomWalk(Distribution): - arg_constraints = {"scale": constraints.positive} - support = constraints.real_vector + arg_constraints = { + "scale": constraints.positive, # type: ignore[has-type] + } + support = constraints.real_vector # type: ignore[has-type] reparametrized_params = ["scale"] pytree_aux_fields = ("num_steps",) @@ -746,15 +753,15 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) shape = sample_shape + self.batch_shape + self.event_shape walks = random.normal(key, shape=shape) return jnp.cumsum(walks, axis=-1) * jnp.expand_dims(self.scale, axis=-1) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: init_prob = Normal(0.0, self.scale).log_prob(value[..., 0]) scale = jnp.expand_dims(self.scale, -1) step_probs = Normal(value[..., :-1], scale).log_prob(value[..., 1:]) @@ -774,8 +781,10 @@ def variance(self) -> ArrayLike: class HalfCauchy(Distribution): reparametrized_params = ["scale"] - support = constraints.positive - arg_constraints = {"scale": constraints.positive} + support = constraints.positive # type: ignore[has-type] + arg_constraints = { + "scale": constraints.positive, # type: ignore[has-type] + } pytree_data_fields = ("_cauchy", "scale") def __init__( @@ -791,19 +800,19 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) return jnp.abs(self._cauchy.sample(key, sample_shape)) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: return self._cauchy.log_prob(value) + jnp.log(2) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: return self._cauchy.cdf(value) * 2 - 1 - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: return self._cauchy.icdf((q + 1) / 2) @property @@ -817,8 +826,10 @@ def variance(self) -> ArrayLike: class HalfNormal(Distribution): reparametrized_params = ["scale"] - support = constraints.positive - arg_constraints = {"scale": constraints.positive} + support = constraints.positive # type: ignore[has-type] + arg_constraints = { + "scale": constraints.positive, # type: ignore[has-type] + } pytree_data_fields = ("_normal", "scale") def __init__( @@ -834,19 +845,19 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) return jnp.abs(self._normal.sample(key, sample_shape)) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: return self._normal.log_prob(value) + jnp.log(2) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: return self._normal.cdf(value) * 2 - 1 - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: return self._normal.icdf((q + 1) / 2) @property @@ -866,16 +877,16 @@ class InverseGamma(TransformedDistribution): """ arg_constraints = { - "concentration": constraints.positive, - "rate": constraints.positive, + "concentration": constraints.positive, # type: ignore[has-type] + "rate": constraints.positive, # type: ignore[has-type] } reparametrized_params = ["concentration", "rate"] - support = constraints.positive + support = constraints.positive # type: ignore[has-type] def __init__( self, - concentration: ArrayLike, - rate: ArrayLike = 1.0, + concentration: Array, + rate: Array = 1.0, *, validate_args: Optional[bool] = None, ) -> None: @@ -898,7 +909,7 @@ def variance(self) -> ArrayLike: a = (self.rate / (self.concentration - 1)) ** 2 / (self.concentration - 2) return jnp.where(self.concentration <= 2, jnp.inf, a) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return ( self.concentration + jnp.log(self.rate) @@ -924,16 +935,16 @@ class Gompertz(Distribution): """ arg_constraints = { - "concentration": constraints.positive, - "rate": constraints.positive, + "concentration": constraints.positive, # type: ignore[has-type] + "rate": constraints.positive, # type: ignore[has-type] } - support = constraints.positive + support = constraints.positive # type: ignore[has-type] reparametrized_params = ["concentration", "rate"] def __init__( self, - concentration: ArrayLike, - rate: ArrayLike = 1.0, + concentration: Array, + rate: Array = 1.0, *, validate_args: Optional[bool] = None, ) -> None: @@ -944,15 +955,15 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) random_shape = sample_shape + self.batch_shape + self.event_shape unifs = random.uniform(key, shape=random_shape) return self.icdf(unifs) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: scaled_value = value * self.rate return ( jnp.log(self.concentration) @@ -961,10 +972,10 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: - self.concentration * jnp.expm1(scaled_value) ) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: return -jnp.expm1(-self.concentration * jnp.expm1(value * self.rate)) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: return jnp.log1p(-jnp.log1p(-q) / self.concentration) / self.rate @property @@ -973,8 +984,11 @@ def mean(self) -> ArrayLike: class Gumbel(Distribution): - arg_constraints = {"loc": constraints.real, "scale": constraints.positive} - support = constraints.real + arg_constraints = { + "loc": constraints.real, # type: ignore[has-type] + "scale": constraints.positive, # type: ignore[has-type] + } + support = constraints.real # type: ignore[has-type] reparametrized_params = ["loc", "scale"] def __init__( @@ -992,8 +1006,8 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] 
= () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) standard_gumbel_sample = random.gumbel( key, shape=sample_shape + self.batch_shape + self.event_shape @@ -1001,7 +1015,7 @@ def sample( return self.loc + self.scale * standard_gumbel_sample @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: z = (value - self.loc) / self.scale return -(z + jnp.exp(-z)) - jnp.log(self.scale) @@ -1015,20 +1029,20 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return jnp.broadcast_to(jnp.pi**2 / 6.0 * self.scale**2, self.batch_shape) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: return jnp.exp(-jnp.exp((self.loc - value) / self.scale)) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: return self.loc - self.scale * jnp.log(-jnp.log(q)) class Kumaraswamy(Distribution): arg_constraints = { - "concentration1": constraints.positive, - "concentration0": constraints.positive, + "concentration1": constraints.positive, # type: ignore[has-type] + "concentration0": constraints.positive, # type: ignore[has-type] } reparametrized_params = ["concentration1", "concentration0"] - support = constraints.unit_interval + support = constraints.unit_interval # type: ignore[has-type] # XXX: This flag is used to approximate the Taylor expansion # of KL(Kumaraswamy||Beta) following # https://arxiv.org/abs/1605.06197 Formula (12) @@ -1038,8 +1052,8 @@ class Kumaraswamy(Distribution): def __init__( self, - concentration1: ArrayLike, - concentration0: ArrayLike, + concentration1: Array, + concentration0: Array, *, validate_args: Optional[bool] = None, ) -> None: @@ -1052,8 +1066,8 @@ def __init__( super().__init__(batch_shape=batch_shape, validate_args=validate_args) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) finfo = jnp.finfo(jnp.result_type(float)) u = random.uniform( @@ -1064,7 +1078,7 @@ def sample( return jnp.clip(jnp.exp(log_sample), finfo.tiny, 1 - finfo.eps) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: finfo = jnp.finfo(jnp.result_type(float)) normalize_term = jnp.log(self.concentration0) + jnp.log(self.concentration1) value_con1 = jnp.clip(value**self.concentration1, None, 1 - finfo.eps) @@ -1086,8 +1100,11 @@ def variance(self) -> ArrayLike: class Laplace(Distribution): - arg_constraints = {"loc": constraints.real, "scale": constraints.positive} - support = constraints.real + arg_constraints = { + "loc": constraints.real, # type: ignore[has-type] + "scale": constraints.positive, # type: ignore[has-type] + } + support = constraints.real # type: ignore[has-type] reparametrized_params = ["loc", "scale"] def __init__( @@ -1104,8 +1121,8 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) eps = random.laplace( key, shape=sample_shape + self.batch_shape + self.event_shape @@ -1113,7 +1130,7 @@ def sample( return self.loc + eps * self.scale @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: normalize_term = jnp.log(2 * self.scale) value_scaled = jnp.abs(value - self.loc) / self.scale return -value_scaled - normalize_term @@ -1126,15 +1143,15 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return jnp.broadcast_to(2 * self.scale**2, self.batch_shape) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: scaled = (value - self.loc) / self.scale return 0.5 - 0.5 * jnp.sign(scaled) * jnp.expm1(-jnp.abs(scaled)) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: a = q - 0.5 return self.loc - self.scale * jnp.sign(a) * jnp.log1p(-2 * jnp.abs(a)) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return jnp.log(2 * self.scale) + 1 @@ -1185,15 +1202,17 @@ def model(y): # y has dimension N x d Daniel Lewandowski, Dorota Kurowicka, Harry Joe """ - arg_constraints = {"concentration": constraints.positive} + arg_constraints = { + "concentration": constraints.positive, # type: ignore[has-type] + } reparametrized_params = ["concentration"] - support = constraints.corr_matrix + support = constraints.corr_matrix # type: ignore[has-type] pytree_aux_fields = ("dimension", "sample_method") def __init__( self, dimension: int, - concentration: ArrayLike = 1.0, + concentration: Array = 1.0, sample_method: Literal["onion", "cvine"] = "onion", *, validate_args: Optional[bool] = None, @@ -1267,16 +1286,18 @@ def model(y): # y has dimension N x d Daniel Lewandowski, Dorota Kurowicka, Harry Joe """ - arg_constraints = {"concentration": constraints.positive} + arg_constraints = { + "concentration": constraints.positive, # type: ignore[has-type] + } reparametrized_params = ["concentration"] - support = constraints.corr_cholesky + support = constraints.corr_cholesky # type: ignore[has-type] pytree_data_fields = ("_beta", "concentration") pytree_aux_fields = ("dimension", "sample_method") def __init__( self, dimension: int, - concentration: ArrayLike = 1.0, + concentration: Array = 1.0, sample_method: Literal["onion", "cvine"] = "onion", *, validate_args: Optional[bool] = None, @@ -1324,7 +1345,7 @@ def __init__( validate_args=validate_args, ) - def _cvine(self, key: jax.dtypes.prng_key, size): + def _cvine(self, key: Optional[PRNGKeyT], size): # C-vine method first uses beta_dist to generate partial correlations, # then apply signed stick breaking to transform to cholesky factor. # Here is an attempt to prove that using signed stick breaking to @@ -1345,7 +1366,7 @@ def _cvine(self, key: jax.dtypes.prng_key, size): partial_correlation = 2 * beta_sample - 1 # scale to domain to (-1, 1) return signed_stick_breaking_tril(partial_correlation) - def _onion(self, key: jax.dtypes.prng_key, size): + def _onion(self, key: Optional[PRNGKeyT], size): key_beta, key_normal = random.split(key) # Now we generate w term in Algorithm 3.2 of [1]. beta_sample = self._beta.sample(key_beta, size) @@ -1372,8 +1393,8 @@ def _onion(self, key: jax.dtypes.prng_key, size): return add_diag(cholesky, diag) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) if self.sample_method == "onion": return self._onion(key, sample_shape) @@ -1381,7 +1402,7 @@ def sample( return self._cvine(key, sample_shape) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: # Note about computing Jacobian of the transformation from Cholesky factor to # correlation matrix: # @@ -1427,8 +1448,11 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: class LogNormal(TransformedDistribution): - arg_constraints = {"loc": constraints.real, "scale": constraints.positive} - support = constraints.positive + arg_constraints = { + "loc": constraints.real, # type: ignore[has-type] + "scale": constraints.positive, # type: ignore[has-type] + } + support = constraints.positive # type: ignore[has-type] reparametrized_params = ["loc", "scale"] def __init__( @@ -1452,13 +1476,16 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return (jnp.exp(self.scale**2) - 1) * jnp.exp(2 * self.loc + self.scale**2) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return (1 + jnp.log(2 * jnp.pi)) / 2 + self.loc + jnp.log(self.scale) class Logistic(Distribution): - arg_constraints = {"loc": constraints.real, "scale": constraints.positive} - support = constraints.real + arg_constraints = { + "loc": constraints.real, # type: ignore[has-type] + "scale": constraints.positive, + } + support = constraints.real # type: ignore[has-type] reparametrized_params = ["loc", "scale"] def __init__( @@ -1473,8 +1500,8 @@ def __init__( super(Logistic, self).__init__(batch_shape, validate_args=validate_args) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) z = random.logistic( key, shape=sample_shape + self.batch_shape + self.event_shape @@ -1482,7 +1509,7 @@ def sample( return self.loc + z * self.scale @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: log_exponent = (self.loc - value) / self.scale log_denominator = jnp.log(self.scale) + 2 * nn.softplus(log_exponent) return log_exponent - log_denominator @@ -1496,26 +1523,29 @@ def variance(self) -> ArrayLike: var = (self.scale**2) * (jnp.pi**2) / 3 return jnp.broadcast_to(var, self.batch_shape) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: scaled = (value - self.loc) / self.scale return expit(scaled) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: return self.loc + self.scale * logit(q) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return jnp.broadcast_to(jnp.log(self.scale) + 2, self.batch_shape) class LogUniform(TransformedDistribution): - arg_constraints = {"low": constraints.positive, "high": constraints.positive} + arg_constraints = { + "low": constraints.positive, # type: ignore[has-type] + "high": constraints.positive, # type: ignore[has-type] + } reparametrized_params = ["low", "high"] pytree_data_fields = ("low", "high", "_support") def __init__( self, - low: ArrayLike, - high: ArrayLike, + low: Array, + high: Array, *, validate_args: Optional[bool] = None, ) -> None: @@ -1526,8 +1556,8 @@ def __init__( base_dist, ExpTransform(), validate_args=validate_args ) - @constraints.dependent_property(is_discrete=False, event_dim=0) - def support(self) -> constraints.Constraint: + @constraints.dependent_property(is_discrete=False, event_dim=0) # type: ignore[arg-type] + def support(self) -> ConstraintT: # type: ignore[override] return self._support @property @@ -1541,7 +1571,7 @@ def variance(self) -> ArrayLike: - self.mean**2 ) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: log_low = jnp.log(self.low) log_high = jnp.log(self.high) return (log_low + log_high) / 2 + jnp.log(log_high - log_low) @@ -1613,11 +1643,11 @@ class MatrixNormal(Distribution): """ arg_constraints = { - "loc": constraints.real_vector, - "scale_tril_row": constraints.lower_cholesky, - "scale_tril_column": constraints.lower_cholesky, + "loc": constraints.real_vector, # type: ignore[has-type] + "scale_tril_row": constraints.lower_cholesky, # type: ignore[has-type] + "scale_tril_column": constraints.lower_cholesky, # type: ignore[has-type] } - support = constraints.real_matrix + support = constraints.real_matrix # type: ignore[has-type] reparametrized_params = [ "loc", "scale_tril_row", @@ -1656,8 +1686,8 @@ def mean(self) -> ArrayLike: return jnp.broadcast_to(self.loc, self.shape()) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: eps = random.normal( key, shape=sample_shape + self.batch_shape + self.event_shape ) @@ -1738,12 +1768,12 @@ def _batch_mahalanobis(bL, bx): class MultivariateNormal(Distribution): arg_constraints = { - "loc": constraints.real_vector, - "covariance_matrix": constraints.positive_definite, - "precision_matrix": constraints.positive_definite, - "scale_tril": constraints.lower_cholesky, + "loc": constraints.real_vector, # type: ignore[has-type] + "covariance_matrix": constraints.positive_definite, # type: ignore[has-type] + "precision_matrix": constraints.positive_definite, # type: ignore[has-type] + "scale_tril": constraints.lower_cholesky, # type: ignore[has-type] } - support = constraints.real_vector + support = constraints.real_vector # type: ignore[has-type] reparametrized_params = [ "loc", "covariance_matrix", @@ -1789,8 +1819,8 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) eps = random.normal( key, shape=sample_shape + self.batch_shape + self.event_shape @@ -1800,7 +1830,7 @@ def sample( ) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: M = _batch_mahalanobis(self.scale_tril, value - self.loc) half_log_det = tri_logabsdet(self.scale_tril) normalize_term = half_log_det + 0.5 * self.scale_tril.shape[-1] * jnp.log( @@ -1845,7 +1875,7 @@ def infer_shapes( event_shape = lax.broadcast_shapes(event_shape, matrix[-1:]) return batch_shape, event_shape - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: (n,) = self.event_shape half_log_det = tri_logabsdet(self.scale_tril) return n * (jnp.log(2 * np.pi) + 1) / 2 + half_log_det @@ -1884,12 +1914,12 @@ class CAR(Distribution): """ arg_constraints = { - "loc": constraints.real_vector, - "correlation": constraints.open_interval(-1, 1), - "conditional_precision": constraints.positive, - "adj_matrix": constraints.dependent(is_discrete=False, event_dim=2), + "loc": constraints.real_vector, # type: ignore[has-type] + "correlation": constraints.open_interval(-1, 1), # type: ignore[has-type] + "conditional_precision": constraints.positive, # type: ignore[has-type] + "adj_matrix": constraints.dependent(is_discrete=False, event_dim=2), # type: ignore[has-type] } - support = constraints.real_vector + support = constraints.real_vector # type: ignore[has-type] reparametrized_params = [ "loc", "correlation", @@ -1900,7 +1930,7 @@ class CAR(Distribution): def __init__( self, - loc: ArrayLike, + loc: Array, correlation: Array, conditional_precision: Array, adj_matrix: Array, @@ -1969,14 +1999,14 @@ def __init__( ), "adjacency matrix must be symmetric" def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: # TODO: look into a sparse sampling method mvn = MultivariateNormal(self.mean, precision_matrix=self.precision_matrix) return mvn.sample(key, sample_shape=sample_shape) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: phi = value - self.loc adj_matrix = self.adj_matrix @@ -2075,17 +2105,17 @@ def tree_unflatten(cls, aux_data, params): class MultivariateStudentT(Distribution): arg_constraints = { - "df": constraints.positive, - "loc": constraints.real_vector, - "scale_tril": constraints.lower_cholesky, + "df": constraints.positive, # type: ignore[has-type] + "loc": constraints.real_vector, # type: ignore[has-type] + "scale_tril": constraints.lower_cholesky, # type: ignore[has-type] } - support = constraints.real_vector + support = constraints.real_vector # type: ignore[has-type] reparametrized_params = ["df", "loc", "scale_tril"] pytree_data_fields = ("df", "loc", "scale_tril", "_chi2") def __init__( self, - df: ArrayLike, + df: Array, loc: ArrayLike = 0.0, scale_tril: Optional[Array] = None, *, @@ -2110,8 +2140,8 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) key_normal, key_chi2 = random.split(key) std_normal = random.normal( @@ -2125,7 +2155,7 @@ def sample( ) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: n = self.scale_tril.shape[-1] Z = ( tri_logabsdet(self.scale_tril) @@ -2224,11 +2254,11 @@ def _batch_lowrank_mahalanobis( class LowRankMultivariateNormal(Distribution): arg_constraints = { - "loc": constraints.real_vector, - "cov_factor": constraints.independent(constraints.real, 2), - "cov_diag": constraints.independent(constraints.positive, 1), + "loc": constraints.real_vector, # type: ignore[has-type] + "cov_factor": constraints.independent(constraints.real, 2), # type: ignore[has-type] + "cov_diag": constraints.independent(constraints.positive, 1), # type: ignore[has-type] } - support = constraints.real_vector + support = constraints.real_vector # type: ignore[has-type] reparametrized_params = ["loc", "cov_factor", "cov_diag"] pytree_data_fields = ("loc", "cov_factor", "cov_diag", "_capacitance_tril") @@ -2279,11 +2309,11 @@ def __init__( ) @property - def mean(self) -> Array: + def mean(self) -> ArrayLike: return self.loc @lazy_property - def variance(self) -> Array: + def variance(self) -> ArrayLike: # type: ignore[override] raw_variance = jnp.square(self.cov_factor).sum(-1) + self.cov_diag return jnp.broadcast_to(raw_variance, self.batch_shape + self.event_shape) @@ -2322,10 +2352,10 @@ def precision_matrix(self) -> Array: return add_diag(-jnp.matmul(jnp.swapaxes(A, -1, -2), A), inverse_cov_diag) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) - key_W, key_D = random.split(key) + key_W, key_D = random.split(key) # type: ignore[arg-type] batch_shape = sample_shape + self.batch_shape W_shape = batch_shape + self.cov_factor.shape[-1:] D_shape = batch_shape + self.cov_diag.shape[-1:] @@ -2338,7 +2368,7 @@ def sample( ) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: diff = value - self.loc M = _batch_lowrank_mahalanobis( self.cov_factor, self.cov_diag, diff, self._capacitance_tril @@ -2348,7 +2378,7 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: ) return -0.5 * (self.loc.shape[-1] * jnp.log(2 * jnp.pi) + log_det + M) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: log_det = _batch_lowrank_logdet( self.cov_factor, self.cov_diag, self._capacitance_tril ) @@ -2363,8 +2393,11 @@ def infer_shapes(loc, cov_factor, cov_diag): class Normal(Distribution): - arg_constraints = {"loc": constraints.real, "scale": constraints.positive} - support = constraints.real + arg_constraints = { + "loc": constraints.real, # type: ignore[has-type] + "scale": constraints.positive, # type: ignore[has-type] + } + support = constraints.real # type: ignore[has-type] reparametrized_params = ["loc", "scale"] def __init__( @@ -2381,8 +2414,8 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) eps = random.normal( key, shape=sample_shape + self.batch_shape + self.event_shape @@ -2390,19 +2423,19 @@ def sample( return self.loc + eps * self.scale @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale) value_scaled = (value - self.loc) / self.scale return -0.5 * value_scaled**2 - normalize_term - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: scaled = (value - self.loc) / self.scale return ndtr(scaled) - def log_cdf(self, value: ArrayLike) -> ArrayLike: + def log_cdf(self, value: Array) -> Array: return jax_norm.logcdf(value, loc=self.loc, scale=self.scale) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: return self.loc + self.scale * ndtri(q) @property @@ -2413,20 +2446,23 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return jnp.broadcast_to(self.scale**2, self.batch_shape) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return jnp.broadcast_to( (jnp.log(2 * np.pi * self.scale**2) + 1) / 2, self.batch_shape ) class Pareto(TransformedDistribution): - arg_constraints = {"scale": constraints.positive, "alpha": constraints.positive} + arg_constraints = { + "scale": constraints.positive, # type: ignore[has-type] + "alpha": constraints.positive, # type: ignore[has-type] + } reparametrized_params = ["scale", "alpha"] def __init__( self, - scale: ArrayLike, - alpha: ArrayLike, + scale: Array, + alpha: Array, *, validate_args: Optional[bool] = None, ) -> None: @@ -2455,22 +2491,25 @@ def variance(self) -> ArrayLike: return jnp.where(self.alpha <= 2, jnp.inf, a) # override the default behaviour to save computations - @constraints.dependent_property(is_discrete=False, event_dim=0) - def support(self) -> constraints.Constraint: + @constraints.dependent_property(is_discrete=False, event_dim=0) # type: ignore[arg-type] + def support(self) -> ConstraintT: # 
type: ignore[override] return constraints.greater_than(self.scale) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return jnp.log(self.scale / self.alpha) + 1 + 1 / self.alpha class RelaxedBernoulliLogits(TransformedDistribution): - arg_constraints = {"temperature": constraints.positive, "logits": constraints.real} - support = constraints.unit_interval + arg_constraints = { + "temperature": constraints.positive, # type: ignore[has-type] + "logits": constraints.real, # type: ignore[has-type] + } + support = constraints.unit_interval # type: ignore[has-type] def __init__( self, - temperature: ArrayLike, - logits: ArrayLike, + temperature: Array, + logits: Array, *, validate_args: Optional[bool] = None, ) -> None: @@ -2508,14 +2547,17 @@ class SoftLaplace(Distribution): :param scale: Scale parameter. """ - arg_constraints = {"loc": constraints.real, "scale": constraints.positive} - support = constraints.real + arg_constraints = { + "loc": constraints.real, # type: ignore[has-type] + "scale": constraints.positive, # type: ignore[has-type] + } + support = constraints.real # type: ignore[has-type] reparametrized_params = ["loc", "scale"] def __init__( self, - loc: ArrayLike, - scale: ArrayLike, + loc: Array, + scale: Array, *, validate_args: Optional[bool] = None, ) -> None: @@ -2524,13 +2566,13 @@ def __init__( super().__init__(batch_shape=batch_shape, validate_args=validate_args) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: z = (value - self.loc) / self.scale return jnp.log(2 / jnp.pi) - jnp.log(self.scale) - jnp.logaddexp(z, -z) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) dtype = jnp.result_type(float) finfo = jnp.finfo(dtype) @@ -2539,11 +2581,11 @@ def sample( return self.icdf(u) # TODO: refactor validate_sample to only does validation check and use it here - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: z = (value - self.loc) / self.scale return jnp.arctan(jnp.exp(z)) * (2 / jnp.pi) - def icdf(self, value: ArrayLike) -> ArrayLike: + def icdf(self, value: Array) -> Array: return jnp.log(jnp.tan(value * (jnp.pi / 2))) * self.scale + self.loc @property @@ -2557,17 +2599,17 @@ def variance(self) -> ArrayLike: class StudentT(Distribution): arg_constraints = { - "df": constraints.positive, - "loc": constraints.real, - "scale": constraints.positive, + "df": constraints.positive, # type: ignore[has-type] + "loc": constraints.real, # type: ignore[has-type] + "scale": constraints.positive, # type: ignore[has-type] } - support = constraints.real + support = constraints.real # type: ignore[has-type] reparametrized_params = ["df", "loc", "scale"] pytree_data_fields = ("df", "loc", "scale", "_chi2") def __init__( self, - df: ArrayLike, + df: Array, loc: ArrayLike = 0.0, scale: ArrayLike = 1.0, *, @@ -2584,8 +2626,8 @@ def __init__( super(StudentT, self).__init__(batch_shape, validate_args=validate_args) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) key_normal, key_chi2 = random.split(key) std_normal = random.normal(key_normal, shape=sample_shape + self.batch_shape) @@ -2594,7 +2636,7 @@ def sample( return self.loc + self.scale * y @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: y = (value - self.loc) / self.scale z = ( jnp.log(self.scale) @@ -2620,7 +2662,7 @@ def variance(self) -> ArrayLike: var = jnp.where(self.df <= 1, jnp.nan, var) return jnp.broadcast_to(var, self.batch_shape) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: # Ref: https://en.wikipedia.org/wiki/Student's_t-distribution#Related_distributions # X^2 ~ F(1, df) -> df / (df + X^2) ~ Beta(df/2, 0.5) scaled = (value - self.loc) / self.scale @@ -2635,13 +2677,13 @@ def cdf(self, value: ArrayLike) -> ArrayLike: - jnp.sign(scaled) * betainc(0.5 * self.df, 0.5, beta_value) ) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: beta_value = betaincinv(0.5 * self.df, 0.5, 1 - jnp.abs(1 - 2 * q)) scaled_squared = self.df * (1 / beta_value - 1) scaled = jnp.sign(q - 0.5) * jnp.sqrt(scaled_squared) return scaled * self.scale + self.loc - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return jnp.broadcast_to( (self.df + 1) / 2 * (digamma((self.df + 1) / 2) - digamma(self.df / 2)) + jnp.log(self.df) / 2 @@ -2653,16 +2695,16 @@ def entropy(self) -> ArrayLike: class Uniform(Distribution): arg_constraints = { - "low": constraints.dependent(is_discrete=False, event_dim=0), - "high": constraints.dependent(is_discrete=False, event_dim=0), + "low": constraints.dependent(is_discrete=False, event_dim=0), # type: ignore[has-type] + "high": constraints.dependent(is_discrete=False, event_dim=0), # type: ignore[has-type] } reparametrized_params = ["low", "high"] pytree_data_fields = ("low", "high", "_support") def __init__( self, - low: ArrayLike = 0.0, - high: ArrayLike = 1.0, + low: Array = 0.0, + high: Array = 1.0, *, validate_args: Optional[bool] = None, ) -> None: @@ -2671,26 +2713,26 @@ def __init__( self._support = constraints.interval(low, high) super().__init__(batch_shape, validate_args=validate_args) - @constraints.dependent_property(is_discrete=False, event_dim=0) - def support(self) -> constraints.Constraint: + @constraints.dependent_property(is_discrete=False, event_dim=0) # type: ignore[arg-type] + def support(self) -> ConstraintT: # type: ignore[override] return self._support def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: shape = sample_shape + self.batch_shape return random.uniform(key, shape=shape, minval=self.low, maxval=self.high) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape) return -jnp.broadcast_to(jnp.log(self.high - self.low), shape) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: cdf = (value - self.low) / (self.high - self.low) return jnp.clip(cdf, 0.0, 1.0) - def icdf(self, value: ArrayLike) -> ArrayLike: + def icdf(self, value: Array) -> Array: return self.low + value * (self.high - self.low) @property @@ -2709,22 +2751,22 @@ def infer_shapes( event_shape: tuple[int, ...] 
= () return batch_shape, event_shape - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return jnp.log(self.high - self.low) class Weibull(Distribution): arg_constraints = { - "scale": constraints.positive, - "concentration": constraints.positive, + "scale": constraints.positive, # type: ignore[has-type] + "concentration": constraints.positive, # type: ignore[has-type] } - support = constraints.positive + support = constraints.positive # type: ignore[has-type] reparametrized_params = ["scale", "concentration"] def __init__( self, - scale: ArrayLike, - concentration: ArrayLike, + scale: Array, + concentration: Array, *, validate_args: Optional[bool] = None, ) -> None: @@ -2733,8 +2775,8 @@ def __init__( super().__init__(batch_shape=batch_shape, validate_args=validate_args) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) return random.weibull_min( key, @@ -2744,14 +2786,14 @@ def sample( ) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: ll = -jnp.power(value / self.scale, self.concentration) ll += jnp.log(self.concentration) ll += (self.concentration - 1.0) * jnp.log(value) ll -= self.concentration * jnp.log(self.scale) return ll - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: return 1 - jnp.exp(-((value / self.scale) ** self.concentration)) @property @@ -2765,7 +2807,7 @@ def variance(self) -> ArrayLike: - jnp.exp(gammaln(1.0 + 1.0 / self.concentration)) ** 2 ) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return ( jnp.euler_gamma * (1 - 1 / self.concentration) + jnp.log(self.scale / self.concentration) @@ -2785,17 +2827,17 @@ class BetaProportion(Beta): """ arg_constraints = { - "mean": constraints.open_interval(0.0, 1.0), - "concentration": constraints.positive, + "mean": constraints.open_interval(0.0, 1.0), # type: ignore[has-type] + "concentration": constraints.positive, # type: ignore[has-type] } reparametrized_params = ["mean", "concentration"] - support = constraints.unit_interval - pytree_data_fields = ("concentration",) + support = constraints.unit_interval # type: ignore[has-type] + pytree_data_fields = ("concentration",) # type: ignore[assignment] def __init__( self, - mean: ArrayLike, - concentration: ArrayLike, + mean: Array, + concentration: Array, *, validate_args: Optional[bool] = None, ) -> None: @@ -2826,19 +2868,19 @@ class AsymmetricLaplaceQuantile(Distribution): """ arg_constraints = { - "loc": constraints.real, - "scale": constraints.positive, - "quantile": constraints.open_interval(0.0, 1.0), + "loc": constraints.real, # type: ignore[has-type] + "scale": constraints.positive, # type: ignore[has-type] + "quantile": constraints.open_interval(0.0, 1.0), # type: ignore[has-type] } reparametrized_params = ["loc", "scale", "quantile"] - support = constraints.real + support = constraints.real # type: ignore[has-type] pytree_data_fields = ("loc", "scale", "quantile", "_ald") def __init__( self, loc: ArrayLike = 0.0, scale: ArrayLike = 1.0, - quantile: ArrayLike = 0.5, + quantile: Array = 0.5, *, validate_args: Optional[bool] = None, ) -> None: @@ -2855,14 +2897,14 @@ def __init__( scale_classic = scale * asymmetry / quantile self._ald = AsymmetricLaplace(loc=loc, scale=scale_classic, asymmetry=asymmetry) - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: 
Array) -> Array: if self._validate_args: self._validate_sample(value) return self._ald.log_prob(value) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: return self._ald.sample(key, sample_shape=sample_shape) @property @@ -2873,10 +2915,10 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return self._ald.variance - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: return self._ald.cdf(value) - def icdf(self, value: ArrayLike) -> ArrayLike: + def icdf(self, value: Array) -> Array: return self._ald.icdf(value) @@ -2942,12 +2984,14 @@ class ZeroSumNormal(TransformedDistribution): [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/ """ - arg_constraints = {"scale": constraints.positive} + arg_constraints = { + "scale": constraints.positive, # type: ignore[has-type] + } reparametrized_params = ["scale"] def __init__( self, - scale: ArrayLike, + scale: Array, event_shape: tuple[int, ...], *, validate_args: Optional[bool] = None, @@ -2961,8 +3005,8 @@ def __init__( validate_args=validate_args, ) - @constraints.dependent_property(is_discrete=False) - def support(self) -> constraints.Constraint: + @constraints.dependent_property(is_discrete=False) # type: ignore[arg-type] + def support(self) -> ConstraintT: # type: ignore[override] return constraints.zero_sum(len(self.event_shape)) @property @@ -2995,12 +3039,12 @@ class Wishart(TransformedDistribution): """ arg_constraints = { - "concentration": constraints.dependent(is_discrete=False, event_dim=0), - "scale_matrix": constraints.positive_definite, - "rate_matrix": constraints.positive_definite, - "scale_tril": constraints.lower_cholesky, + "concentration": constraints.dependent(is_discrete=False, event_dim=0), # type: ignore[has-type] + "scale_matrix": constraints.positive_definite, # type: ignore[has-type] + "rate_matrix": constraints.positive_definite, # type: ignore[has-type] + "scale_tril": constraints.lower_cholesky, # type: ignore[has-type] } - support = constraints.positive_definite + support = constraints.positive # type: ignore[has-type] reparametrized_params = [ "scale_matrix", "rate_matrix", @@ -3009,7 +3053,7 @@ class Wishart(TransformedDistribution): def __init__( self, - concentration: ArrayLike, + concentration: Array, scale_matrix: Optional[Array] = None, rate_matrix: Optional[Array] = None, scale_tril: Optional[Array] = None, @@ -3024,7 +3068,9 @@ def __init__( validate_args=validate_args, ) super().__init__( - base_dist, CholeskyTransform().inv, validate_args=validate_args + base_dist, + CholeskyTransform().inv, # type: ignore[arg-type] + validate_args=validate_args, ) @lazy_property @@ -3062,8 +3108,8 @@ def infer_shapes( concentration, scale_matrix, rate_matrix, scale_tril ) - def entropy(self) -> ArrayLike: - p = self.event_shape[-1] + def entropy(self) -> Array: + p = jnp.asarray(self.event_shape[-1]) return ( (p + 1) * tri_logabsdet(self.scale_tril) + p * (p + 1) / 2 * jnp.log(2) @@ -3088,12 +3134,12 @@ class WishartCholesky(Distribution): """ arg_constraints = { - "concentration": constraints.dependent(is_discrete=False, event_dim=0), - "scale_matrix": constraints.positive_definite, - "rate_matrix": constraints.positive_definite, - "scale_tril": constraints.lower_cholesky, + "concentration": constraints.dependent(is_discrete=False, event_dim=0), # type: ignore[has-type] + 
"scale_matrix": constraints.positive_definite, # type: ignore[has-type] + "rate_matrix": constraints.positive_definite, # type: ignore[has-type] + "scale_tril": constraints.lower_cholesky, # type: ignore[has-type] } - support = constraints.lower_cholesky + support = constraints.lower_cholesky # type: ignore[has-type] reparametrized_params = [ "scale_matrix", "rate_matrix", @@ -3102,7 +3148,7 @@ class WishartCholesky(Distribution): def __init__( self, - concentration: ArrayLike, + concentration: Array, scale_matrix: Optional[Array] = None, rate_matrix: Optional[Array] = None, scale_tril: Optional[Array] = None, @@ -3139,7 +3185,7 @@ def __init__( ) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: # The log density of the Wishart distribution includes a term # t = trace(rate_matrix @ cov). Here, value = cholesky(cov) such that # t = trace(value.T @ rate_matrix @ value) by the cyclical property of the @@ -3149,7 +3195,7 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: # rewrite as t = trace(x.T @ x) for x = inv(scale_tril) @ value which we can # obtain easily by solving a triangular system. x is again triangular such that # trace(x @ x.T) is equal to the sum of squares of elements. - x = solve_triangular(*jnp.broadcast_arrays(self.scale_tril, value), lower=True) + x = solve_triangular(*jnp.broadcast_arrays(self.scale_tril, value), lower=True) # type: ignore[arg-type] trace = jnp.square(x).sum(axis=(-1, -2)) p = value.shape[-1] return ( @@ -3177,8 +3223,8 @@ def rate_matrix(self): return cho_solve((self.scale_tril, True), identity) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) # Sample using the Bartlett decomposition # (https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition). @@ -3256,14 +3302,14 @@ class Levy(Distribution): """ arg_constraints = { - "loc": constraints.positive, - "scale": constraints.positive, + "loc": constraints.positive, # type: ignore[has-type] + "scale": constraints.positive, # type: ignore[has-type] } def __init__( self, - loc: ArrayLike, - scale: ArrayLike, + loc: Array, + scale: Array, *, validate_args: Optional[bool] = None, ) -> None: @@ -3272,12 +3318,12 @@ def __init__( self._support = constraints.greater_than(loc) super(Levy, self).__init__(batch_shape, validate_args=validate_args) - @constraints.dependent_property(is_discrete=False) - def support(self) -> constraints.Constraint: + @constraints.dependent_property(is_discrete=False) # type: ignore[arg-type] + def support(self) -> ConstraintT: # type: ignore[override] return self._support @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: r"""Compute the log probability density function of the Lévy distribution. .. math:: @@ -3293,12 +3339,14 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: jnp.log(2.0 * jnp.pi) - jnp.log(self.scale) + self.scale / shifted_value ) - 1.5 * jnp.log(shifted_value) - def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> ArrayLike: + def sample( + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) - u = random.uniform(key, shape=sample_shape + self.batch_shape) + u = random.uniform(key, shape=sample_shape + self.batch_shape) # type: ignore[arg-type] return self.icdf(u) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: r""" The inverse cumulative distribution function of Lévy distribution is given by, @@ -3312,7 +3360,7 @@ def icdf(self, q: ArrayLike) -> ArrayLike: """ return self.loc + self.scale * jnp.power(ndtri(1 - 0.5 * q), -2) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: r"""The cumulative distribution function of Lévy distribution is given by, .. math:: @@ -3334,7 +3382,7 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return jnp.broadcast_to(jnp.inf, self.batch_shape) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: r"""If :math:`X \sim \text{Levy}(\mu, c)`, then the entropy of :math:`X` is given by, .. math:: @@ -3384,17 +3432,17 @@ class CirculantNormal(TransformedDistribution): """ arg_constraints = { - "loc": constraints.real_vector, - "covariance_row": constraints.positive_definite_circulant_vector, - "covariance_rfft": constraints.independent(constraints.positive, 1), + "loc": constraints.real_vector, # type: ignore[has-type] + "covariance_row": constraints.positive_definite_circulant_vector, # type: ignore[has-type] + "covariance_rfft": constraints.independent(constraints.positive, 1), # type: ignore[has-type] } - support = constraints.real_vector + support = constraints.real_vector # type: ignore[has-type] def __init__( self, - loc: ArrayLike, - covariance_row: Optional[ArrayLike] = None, - covariance_rfft: Optional[ArrayLike] = None, + loc: Array, + covariance_row: Optional[Array] = None, + covariance_rfft: Optional[Array] = None, *, validate_args: Optional[bool] = None, ) -> None: @@ -3439,9 +3487,9 @@ def __init__( super().__init__( base_distribution, [ - PackRealFastFourierCoefficientsTransform((n,)), - RealFastFourierTransform((n,)).inv, - AffineTransform(loc, scale=1.0), + PackRealFastFourierCoefficientsTransform((n,)), # type: ignore + RealFastFourierTransform((n,)).inv, # type: ignore + AffineTransform(loc, scale=1.0), # type: ignore ], validate_args=validate_args, ) @@ -3451,11 +3499,11 @@ def mean(self) -> ArrayLike: return jnp.broadcast_to(self.loc, self.shape()) @lazy_property - def covariance_row(self) -> ArrayLike: + def covariance_row(self) -> Array: return jnp.fft.irfft(self.covariance_rfft, n=self.event_shape[-1]) @lazy_property - def covariance_matrix(self) -> ArrayLike: + def covariance_matrix(self) -> Array: *leading_shape, n = self.covariance_row.shape if leading_shape: # `toeplitz` flattens the input, and we need to broadcast manually. 
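(An illustrative aside: the `covariance_rfft` parametrization used by `CirculantNormal` above rests on the fact that a circulant covariance matrix is diagonalized by the discrete Fourier transform, so its eigenvalues are the FFT of `covariance_row`. The short sketch below checks this for a small, hand-made symmetric row; the numbers and variable names are hypothetical and not part of the patch.)

import jax.numpy as jnp

n = 6
# A valid circulant covariance row is symmetric: row[k] == row[n - k].
row = jnp.array([2.0, 0.9, 0.2, 0.1, 0.2, 0.9])

# Dense circulant covariance: cov[i, j] = row[(j - i) % n].
idx = jnp.arange(n)
cov = row[(idx[None, :] - idx[:, None]) % n]

# Eigenvalues of a circulant matrix are the DFT of its first row (real here,
# since the row is symmetric); CirculantNormal stores only the rfft half of them.
eig_fft = jnp.sort(jnp.fft.fft(row).real)
eig_dense = jnp.linalg.eigvalsh(cov)
assert jnp.allclose(eig_fft, eig_dense, atol=1e-5)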
@@ -3467,7 +3515,7 @@ def covariance_matrix(self) -> ArrayLike: return toeplitz(self.covariance_row) @lazy_property - def variance(self) -> ArrayLike: + def variance(self) -> ArrayLike: # type: ignore[override] return jnp.broadcast_to(self.covariance_row[..., 0, None], self.shape()) @staticmethod @@ -3483,7 +3531,7 @@ def infer_shapes( event_shape = loc[-1:] return batch_shape, event_shape - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: (n,) = self.event_shape log_abs_det_jacobian = 2 * jnp.log(2) * ((n - 1) // 2) - jnp.log(n) * n return self.base_dist.entropy() + log_abs_det_jacobian / 2 @@ -3491,18 +3539,18 @@ def entropy(self) -> ArrayLike: class Dagum(Distribution): arg_constraints = { - "concentration": constraints.positive, - "sharpness": constraints.positive, - "scale": constraints.positive, + "concentration": constraints.positive, # type: ignore[has-type] + "sharpness": constraints.positive, # type: ignore[has-type] + "scale": constraints.positive, # type: ignore[has-type] } - support = constraints.positive + support = constraints.positive # type: ignore[has-type] reparametrized_params = ["concentration", "sharpness", "scale"] def __init__( self, - concentration: ArrayLike, - sharpness: ArrayLike, - scale: ArrayLike, + concentration: Array, + sharpness: Array, + scale: Array, *, validate_args: Optional[bool] = None, ) -> None: @@ -3530,7 +3578,7 @@ def __init__( super().__init__(batch_shape=batch_shape, validate_args=validate_args) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: a_ln_x_m_ln_b = xlogy(self.sharpness, value) - xlogy(self.sharpness, self.scale) return ( jnp.log(self.sharpness) @@ -3540,7 +3588,7 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: - (self.concentration + 1.0) * nn.softplus(a_ln_x_m_ln_b) ) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: return jnp.exp( -self.concentration * nn.softplus( @@ -3548,15 +3596,15 @@ def cdf(self, value: ArrayLike) -> ArrayLike: ) ) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: q_root_p = jnp.power(q, -jnp.reciprocal(self.concentration)) return self.scale * jnp.power(q_root_p - 1.0, -jnp.reciprocal(self.sharpness)) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () ) -> jnp.ndarray: assert is_prng_key(key) - return self.icdf(random.uniform(key, shape=self.shape(sample_shape))) + return self.icdf(random.uniform(key, shape=self.shape(sample_shape))) # type: ignore[arg-type] @property def mean(self) -> ArrayLike: diff --git a/numpyro/distributions/copula.py b/numpyro/distributions/copula.py index 6e0e6d467..a18be5218 100644 --- a/numpyro/distributions/copula.py +++ b/numpyro/distributions/copula.py @@ -4,11 +4,10 @@ from typing import Optional -import jax from jax import Array, lax, numpy as jnp from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, DistributionT +from numpyro._typing import ConstraintT, DistributionT, PRNGKeyT import numpyro.distributions.constraints as constraints from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal from numpyro.distributions.distribution import Distribution @@ -67,8 +66,8 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) shape = sample_shape + self.batch_shape @@ -77,7 +76,7 @@ def sample( return self.marginal_dist.icdf(cdf) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: # Ref: https://en.wikipedia.org/wiki/Copula_(probability_theory)#Gaussian_copula # see also https://github.com/pyro-ppl/numpyro/pull/1506#discussion_r1037525015 marginal_lps = self.marginal_dist.log_prob(value) @@ -124,8 +123,8 @@ class GaussianCopulaBeta(GaussianCopula): def __init__( self, - concentration1: ArrayLike, - concentration0: ArrayLike, + concentration1: Array, + concentration0: Array, correlation_matrix: Optional[Array] = None, correlation_cholesky: Optional[Array] = None, *, diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index e4f85ca94..5fd5ca086 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -9,7 +9,6 @@ import operator from typing import Optional -import jax from jax import Array, lax import jax.numpy as jnp import jax.random as random @@ -17,6 +16,7 @@ from jax.scipy.special import erf, i0e, i1e, logsumexp from jax.typing import ArrayLike +from numpyro._typing import PRNGKeyT from numpyro.distributions import constraints from numpyro.distributions.distribution import Distribution from numpyro.distributions.util import ( @@ -34,7 +34,7 @@ def _numel(shape: tuple[int, ...]) -> int: return functools.reduce(operator.mul, shape, 1) -def log_I1(orders: int, value: ArrayLike, terms: int = 250) -> Array: +def log_I1(orders: int, value: Array, terms: int = 250) -> Array: r"""Compute first n log modified bessel function of first kind .. math :: \log(I_v(z)) = v*\log(z/2) + \log(\sum_{k=0}^\inf \exp\left[2*k*\log(z/2) - \sum_kk^k log(kk) @@ -105,8 +105,8 @@ def model(): def __init__( self, - loc: ArrayLike, - concentration: ArrayLike, + loc: Array, + concentration: Array, *, validate_args: Optional[bool] = None, ): @@ -124,8 +124,8 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: """Generate sample from von Mises distribution :param key: random number generator key @@ -142,7 +142,7 @@ def sample( return samples @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: return -( jnp.log(2 * jnp.pi) + jnp.log(i0e(self.concentration)) ) + self.concentration * (jnp.cos((value - self.loc) % (2 * jnp.pi)) - 1) @@ -233,7 +233,7 @@ def model(obs): def __init__( self, base_dist: Distribution, - skewness: ArrayLike, + skewness: Array, *, validate_args: Optional[bool] = None, ): @@ -268,8 +268,8 @@ def __repr__(self) -> str: ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: base_key, skew_key = random.split(key) bd = self.base_dist ys = bd.sample(base_key, sample_shape) @@ -285,7 +285,7 @@ def sample( ) - jnp.pi return samples - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: if self._validate_args: self._validate_sample(value) if self.base_dist._validate_args: @@ -367,12 +367,12 @@ class SineBivariateVonMises(Distribution): def __init__( self, - phi_loc: ArrayLike, - psi_loc: ArrayLike, - phi_concentration: ArrayLike, - psi_concentration: ArrayLike, - correlation: Optional[ArrayLike] = None, - weighted_correlation: Optional[ArrayLike] = None, + phi_loc: Array, + psi_loc: Array, + phi_concentration: Array, + psi_concentration: Array, + correlation: Optional[Array] = None, + weighted_correlation: Optional[Array] = None, *, validate_args: Optional[bool] = None, ): @@ -436,7 +436,7 @@ def norm_const(self) -> Array: return norm_const.reshape(jnp.shape(self.phi_loc)) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: indv = self.phi_concentration * jnp.cos( value[..., 0] - self.phi_loc ) + self.psi_concentration * jnp.cos(value[..., 1] - self.psi_loc) @@ -448,8 +448,8 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: return indv + corr - self.norm_const def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: """ ** References: ** 1. A New Unified Approach for the Simulation of a Wide Class of Directional Distributions @@ -513,14 +513,14 @@ def update_fn(curr: PhiMarginalState) -> PhiMarginalState: x, axis=1, keepdims=True ) # Angular Central Gaussian distribution - lf: ArrayLike = ( + lf: Array = ( conc[0] * (x[:, 0] - 1) + eigmin + log_I1(0, jnp.sqrt(conc[1] ** 2 + (corr * x[:, 1]) ** 2)).squeeze(0) - phi_den ) - lg_inv: ArrayLike = 1.0 - b0 / 2 + jnp.log(b0 / 2 + (eig * x**2).sum(1)) + lg_inv: Array = 1.0 - b0 / 2 + jnp.log(b0 / 2 + (eig * x**2).sum(1)) assert lg_inv.shape == lf.shape accepted = random.uniform(accept_key, lf.shape) < jnp.exp(lf + lg_inv) @@ -614,13 +614,13 @@ def mode(self): return safe_normalize(self.concentration) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: shape = sample_shape + self.batch_shape + self.event_shape eps = random.normal(key, shape=shape) return safe_normalize(self.concentration + eps) - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: if self._validate_args: event_shape = value.shape[-1:] if event_shape != self.event_shape: @@ -669,7 +669,7 @@ def _dot(x, y): def _projected_normal_log_prob_3(concentration, value): - def _dot(x: Array, y: Array) -> ArrayLike: + def _dot(x: Array, y: Array) -> Array: return (x[..., None, :] @ y[..., None])[..., 0, 0] # We integrate along a ray, factorizing the integrand as a product of: diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index f6130e4f0..baef177f6 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -38,8 +38,9 @@ from jax.scipy.special import expit, gammaincc, gammaln, logsumexp, xlog1py, xlogy from jax.typing import ArrayLike +from numpyro._typing import PRNGKeyT from numpyro.distributions import constraints, transforms -from numpyro.distributions.distribution import Distribution, DistributionT +from numpyro.distributions.distribution import ConstraintT, Distribution, DistributionT from numpyro.distributions.util import ( assert_one_of, binary_cross_entropy_with_logits, @@ -54,38 +55,40 @@ from numpyro.util import is_prng_key, not_jax_tracer -def _to_probs_bernoulli(logits: ArrayLike) -> ArrayLike: +def _to_probs_bernoulli(logits: Array) -> Array: return expit(logits) -def _to_logits_bernoulli(probs: ArrayLike) -> ArrayLike: +def _to_logits_bernoulli(probs: Array) -> Array: ps_clamped = clamp_probs(probs) return jnp.log(ps_clamped) - jnp.log1p(-ps_clamped) -def _to_probs_multinom(logits: ArrayLike) -> ArrayLike: +def _to_probs_multinom(logits: Array) -> Array: return softmax(logits, axis=-1) -def _to_logits_multinom(probs: ArrayLike) -> ArrayLike: +def _to_logits_multinom(probs: Array) -> Array: minval = jnp.finfo(jnp.result_type(probs)).min return jnp.clip(jnp.log(probs), minval) class BernoulliProbs(Distribution): - arg_constraints = {"probs": constraints.unit_interval} - support = constraints.boolean + arg_constraints = { + "probs": constraints.unit_interval, # type: ignore[has-type] + } + support = constraints.boolean # type: ignore[has-type] has_enumerate_support = True - def __init__(self, probs: ArrayLike, *, validate_args: Optional[bool] = None): + def __init__(self, probs: Array, *, validate_args: Optional[bool] = None): self.probs = probs super(BernoulliProbs, self).__init__( batch_shape=jnp.shape(self.probs), validate_args=validate_args ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) samples = random.bernoulli( key, self.probs, shape=sample_shape + self.batch_shape @@ -93,13 +96,13 @@ def sample( return samples.astype(jnp.result_type(samples, int)) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: ps_clamped = clamp_probs(self.probs) value = jnp.array(value, jnp.result_type(float)) return xlogy(value, ps_clamped) + xlog1py(1 - value, -ps_clamped) @lazy_property - def logits(self) -> ArrayLike: + def logits(self) -> Array: return _to_logits_bernoulli(self.probs) @property @@ -110,32 +113,34 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return self.probs * (1 - self.probs) - def enumerate_support(self, expand: bool = True) -> ArrayLike: + def enumerate_support(self, expand: bool = True) -> Array: values = jnp.arange(2).reshape((-1,) + (1,) * len(self.batch_shape)) if expand: values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape) return values - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return -self.probs * jnp.log(self.probs) - (1 - self.probs) * jnp.log1p( -self.probs ) class BernoulliLogits(Distribution): - arg_constraints = {"logits": constraints.real} - support = constraints.boolean + arg_constraints = { + "logits": constraints.real, # type: ignore[has-type] + } + support = constraints.boolean # type: ignore[has-type] has_enumerate_support = True - def __init__(self, logits: ArrayLike, *, validate_args: Optional[bool] = None): + def __init__(self, logits: Array, *, validate_args: Optional[bool] = None): self.logits = logits super(BernoulliLogits, self).__init__( batch_shape=jnp.shape(self.logits), validate_args=validate_args ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) samples = random.bernoulli( key, self.probs, shape=sample_shape + self.batch_shape @@ -143,11 +148,11 @@ def sample( return samples.astype(jnp.result_type(samples, int)) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: return -binary_cross_entropy_with_logits(self.logits, value) @lazy_property - def probs(self) -> ArrayLike: + def probs(self) -> Array: return _to_probs_bernoulli(self.logits) @property @@ -158,20 +163,20 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return self.probs * (1 - self.probs) - def enumerate_support(self, expand: bool = True) -> ArrayLike: + def enumerate_support(self, expand: bool = True) -> Array: values = jnp.arange(2).reshape((-1,) + (1,) * len(self.batch_shape)) if expand: values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape) return values - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: nexp = jnp.exp(-self.logits) return ((1 + nexp) * jnp.log1p(nexp) + nexp * self.logits) / (1 + nexp) def Bernoulli( - probs: Optional[ArrayLike] = None, - logits: Optional[ArrayLike] = None, + probs: Optional[Array] = None, + logits: Optional[Array] = None, *, validate_args: Optional[bool] = None, ) -> Union[BernoulliProbs, BernoulliLogits]: @@ -184,14 +189,14 @@ def Bernoulli( class BinomialProbs(Distribution): arg_constraints = { - "probs": constraints.unit_interval, - "total_count": constraints.nonnegative_integer, + "probs": constraints.unit_interval, # type: ignore[has-type] + "total_count": constraints.nonnegative_integer, # type: ignore[has-type] } has_enumerate_support = True def __init__( self, - probs: ArrayLike, + probs: Array, total_count: int = 1, *, validate_args: Optional[bool] = None, @@ -203,15 +208,15 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) return binomial( key, self.probs, n=self.total_count, shape=sample_shape + self.batch_shape ) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: value = jnp.array(value, jnp.result_type(float)) log_factorial_n = gammaln(self.total_count + 1) log_factorial_k = gammaln(value + 1) @@ -226,7 +231,7 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: ) @lazy_property - def logits(self) -> ArrayLike: + def logits(self) -> Array: return _to_logits_bernoulli(self.probs) @property @@ -240,10 +245,10 @@ def variance(self) -> ArrayLike: ) @constraints.dependent_property(is_discrete=True, event_dim=0) - def support(self) -> constraints.Constraint: + def support(self) -> ConstraintT: return constraints.integer_interval(0, self.total_count) - def enumerate_support(self, expand: bool = True) -> ArrayLike: + def enumerate_support(self, expand: bool = True) -> Array: if not_jax_tracer(self.total_count): total_count = np.amax(self.total_count) # NB: the error can't be raised if inhomogeneous issue happens when tracing @@ -263,15 +268,15 @@ def enumerate_support(self, expand: bool = True) -> ArrayLike: class BinomialLogits(Distribution): arg_constraints = { - "logits": constraints.real, - "total_count": constraints.nonnegative_integer, + "logits": constraints.real, # type: ignore[has-type] + "total_count": constraints.nonnegative_integer, # type: ignore[has-type] } has_enumerate_support = True enumerate_support = BinomialProbs.enumerate_support def __init__( self, - logits: ArrayLike, + logits: Array, total_count: int = 1, *, validate_args: Optional[bool] = None, @@ -283,15 +288,15 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) return binomial( key, self.probs, n=self.total_count, shape=sample_shape + self.batch_shape ) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: total_count = jnp.array(self.total_count, dtype=jnp.result_type(float)) log_factorial_n = gammaln(total_count + 1) log_factorial_k = gammaln(value + 1) @@ -306,7 +311,7 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: ) @lazy_property - def probs(self) -> ArrayLike: + def probs(self) -> Array: return _to_probs_bernoulli(self.logits) @property @@ -320,14 +325,14 @@ def variance(self) -> ArrayLike: ) @constraints.dependent_property(is_discrete=True, event_dim=0) - def support(self) -> constraints.Constraint: + def support(self) -> ConstraintT: return constraints.integer_interval(0, self.total_count) def Binomial( total_count: int = 1, - probs: Optional[ArrayLike] = None, - logits: Optional[ArrayLike] = None, + probs: Optional[Array] = None, + logits: Optional[Array] = None, *, validate_args: Optional[bool] = None, ) -> Union[BinomialProbs, BinomialLogits]: @@ -339,7 +344,9 @@ def Binomial( class CategoricalProbs(Distribution): - arg_constraints = {"probs": constraints.simplex} + arg_constraints = { + "probs": constraints.simplex, # type: ignore[has-type] + } has_enumerate_support = True def __init__(self, probs: Array, *, validate_args: Optional[bool] = None): @@ -351,13 +358,13 @@ def __init__(self, probs: Array, *, validate_args: Optional[bool] = None): ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] 
= () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) return categorical(key, self.probs, shape=sample_shape + self.batch_shape) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: batch_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape) value = jnp.expand_dims(value, axis=-1) value = jnp.broadcast_to(value, batch_shape + (1,)) @@ -366,7 +373,7 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: return jnp.take_along_axis(log_pmf, value, axis=-1)[..., 0] @lazy_property - def logits(self) -> ArrayLike: + def logits(self) -> Array: return _to_logits_multinom(self.probs) @property @@ -378,10 +385,10 @@ def variance(self) -> ArrayLike: return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.probs)) @constraints.dependent_property(is_discrete=True, event_dim=0) - def support(self) -> constraints.Constraint: + def support(self) -> ConstraintT: return constraints.integer_interval(0, jnp.shape(self.probs)[-1] - 1) - def enumerate_support(self, expand: bool = True) -> ArrayLike: + def enumerate_support(self, expand: bool = True) -> Array: values = jnp.arange(self.probs.shape[-1]).reshape( (-1,) + (1,) * len(self.batch_shape) ) @@ -389,12 +396,14 @@ def enumerate_support(self, expand: bool = True) -> ArrayLike: values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape) return values - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return -(self.probs * jnp.log(self.probs)).sum(axis=-1) class CategoricalLogits(Distribution): - arg_constraints = {"logits": constraints.real_vector} + arg_constraints = { + "logits": constraints.real_vector, # type: ignore[has-type] + } has_enumerate_support = True def __init__(self, logits: Array, *, validate_args: Optional[bool] = None): @@ -406,15 +415,15 @@ def __init__(self, logits: Array, *, validate_args: Optional[bool] = None): ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) return random.categorical( key, self.logits, shape=sample_shape + self.batch_shape ) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: batch_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape) value = jnp.expand_dims(value, -1) value = jnp.broadcast_to(value, batch_shape + (1,)) @@ -423,7 +432,7 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: return jnp.take_along_axis(log_pmf, value, -1)[..., 0] @lazy_property - def probs(self) -> ArrayLike: + def probs(self) -> Array: return _to_probs_multinom(self.logits) @property @@ -435,10 +444,10 @@ def variance(self) -> ArrayLike: return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.logits)) @constraints.dependent_property(is_discrete=True, event_dim=0) - def support(self) -> constraints.Constraint: + def support(self) -> ConstraintT: return constraints.integer_interval(0, jnp.shape(self.logits)[-1] - 1) - def enumerate_support(self, expand: bool = True) -> ArrayLike: + def enumerate_support(self, expand: bool = True) -> Array: values = jnp.arange(self.logits.shape[-1]).reshape( (-1,) + (1,) * len(self.batch_shape) ) @@ -446,7 +455,7 @@ def enumerate_support(self, expand: bool = True) -> ArrayLike: values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape) return values - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: probs = softmax(self.logits, axis=-1) return -(probs * self.logits).sum(axis=-1) + logsumexp(self.logits, axis=-1) @@ -461,16 +470,16 @@ def Categorical(probs=None, logits=None, *, validate_args: Optional[bool] = None class DiscreteUniform(Distribution): arg_constraints = { - "low": constraints.dependent(is_discrete=True, event_dim=0), - "high": constraints.dependent(is_discrete=True, event_dim=0), + "low": constraints.dependent(is_discrete=True, event_dim=0), # type: ignore[has-type] + "high": constraints.dependent(is_discrete=True, event_dim=0), # type: ignore[has-type] } has_enumerate_support = True pytree_data_fields = ("low", "high", "_support") def __init__( self, - low: ArrayLike = 0, - high: ArrayLike = 1, + low: Array = 0, + high: Array = 1, *, validate_args: Optional[bool] = None, ): @@ -480,25 +489,25 @@ def __init__( super().__init__(batch_shape, validate_args=validate_args) @constraints.dependent_property(is_discrete=True, event_dim=0) - def support(self) -> constraints.Constraint: + def support(self) -> ConstraintT: return self._support def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: shape = sample_shape + self.batch_shape return random.randint(key, shape=shape, minval=self.low, maxval=self.high + 1) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape) return -jnp.broadcast_to(jnp.log(self.high + 1 - self.low), shape) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: cdf = (jnp.floor(value) + 1 - self.low) / (self.high - self.low + 1) return jnp.clip(cdf, 0.0, 1.0) - def icdf(self, value: ArrayLike) -> ArrayLike: + def icdf(self, value: Array) -> Array: return self.low + value * (self.high - self.low + 1) - 1 @property @@ -509,7 +518,7 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return ((self.high - self.low + 1) ** 2 - 1) / 12.0 - def enumerate_support(self, expand: bool = True) -> ArrayLike: + def enumerate_support(self, expand: bool = True) -> Array: if not not_jax_tracer(self.high) or not not_jax_tracer(self.low): raise NotImplementedError("Both `low` and `high` must not be a JAX Tracer.") if np.any(np.amax(self.low) != self.low): @@ -529,7 +538,7 @@ def enumerate_support(self, expand: bool = True) -> ArrayLike: values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape) return values - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return jnp.log(self.high - self.low + 1) @@ -548,14 +557,14 @@ class OrderedLogistic(CategoricalProbs): """ arg_constraints = { - "predictor": constraints.real, - "cutpoints": constraints.ordered_vector, + "predictor": constraints.real, # type: ignore[has-type] + "cutpoints": constraints.ordered_vector, # type: ignore[has-type] } def __init__( self, - predictor: ArrayLike, - cutpoints: ArrayLike, + predictor: Array, + cutpoints: Array, *, validate_args: Optional[bool] = None, ): @@ -574,14 +583,14 @@ def infer_shapes(predictor, cutpoints): event_shape = () return batch_shape, event_shape - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: raise NotImplementedError class MultinomialProbs(Distribution): arg_constraints = { - "probs": constraints.simplex, - "total_count": constraints.nonnegative_integer, + "probs": constraints.simplex, # type: ignore[has-type] + "total_count": constraints.nonnegative_integer, # type: ignore[has-type] } pytree_data_fields = ("probs",) pytree_aux_fields = ("total_count", "total_count_max") @@ -609,8 +618,8 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) return multinomial( key, @@ -621,14 +630,14 @@ def sample( ) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: value = jnp.array(value, jnp.result_type(float)) return gammaln(self.total_count + 1) + jnp.sum( xlogy(value, self.probs) - gammaln(value + 1), axis=-1 ) @lazy_property - def logits(self) -> ArrayLike: + def logits(self) -> Array: return _to_logits_multinom(self.probs) @property @@ -640,7 +649,7 @@ def variance(self) -> ArrayLike: return jnp.expand_dims(self.total_count, -1) * self.probs * (1 - self.probs) @constraints.dependent_property(is_discrete=True, event_dim=1) - def support(self) -> constraints.Constraint: + def support(self) -> ConstraintT: return constraints.multinomial(self.total_count) @staticmethod @@ -654,8 +663,8 @@ def infer_shapes( class MultinomialLogits(Distribution): arg_constraints = { - "logits": constraints.real_vector, - "total_count": constraints.nonnegative_integer, + "logits": constraints.real_vector, # type: ignore[has-type] + "total_count": constraints.nonnegative_integer, # type: ignore[has-type] } pytree_data_fields = ("logits",) pytree_aux_fields = ("total_count", "total_count_max") @@ -685,8 +694,8 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) return multinomial( key, @@ -697,7 +706,7 @@ def sample( ) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: if self._validate_args: self._validate_sample(value) normalize_term = self.total_count * logsumexp(self.logits, axis=-1) - gammaln( @@ -708,7 +717,7 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: ) @lazy_property - def probs(self) -> ArrayLike: + def probs(self) -> Array: return _to_probs_multinom(self.logits) @property @@ -720,7 +729,7 @@ def variance(self) -> ArrayLike: return jnp.expand_dims(self.total_count, -1) * self.probs * (1 - self.probs) @constraints.dependent_property(is_discrete=True, event_dim=1) - def support(self) -> constraints.Constraint: + def support(self) -> ConstraintT: return constraints.multinomial(self.total_count) @staticmethod @@ -780,13 +789,15 @@ class Poisson(Distribution): :meth:`log_prob`, which can speed up computation when data is sparse. """ - arg_constraints = {"rate": constraints.positive} - support = constraints.nonnegative_integer + arg_constraints = { + "rate": constraints.positive, # type: ignore[has-type] + } + support = constraints.nonnegative_integer # type: ignore[has-type] pytree_aux_fields = ("is_sparse",) def __init__( self, - rate: ArrayLike, + rate: Array, *, is_sparse: bool = False, validate_args: Optional[bool] = None, @@ -796,13 +807,13 @@ def __init__( super(Poisson, self).__init__(jnp.shape(rate), validate_args=validate_args) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) return random.poisson(key, self.rate, shape=sample_shape + self.batch_shape) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: if self._validate_args: self._validate_sample(value) if ( @@ -834,19 +845,21 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return self.rate - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: k = jnp.floor(value) + 1 return gammaincc(k, self.rate) class ZeroInflatedProbs(Distribution): - arg_constraints = {"gate": constraints.unit_interval} + arg_constraints = { + "gate": constraints.unit_interval, # type: ignore[has-type] + } pytree_data_fields = ("base_dist", "gate") def __init__( self, base_dist: DistributionT, - gate: ArrayLike, + gate: Array, *, validate_args: Optional[bool] = None, ): @@ -867,8 +880,8 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) key_bern, key_base = random.split(key) shape = sample_shape + self.batch_shape @@ -877,12 +890,12 @@ def sample( return jnp.where(mask, 0, samples) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: log_prob = jnp.log1p(-self.gate) + self.base_dist.log_prob(value) return jnp.where(value == 0, jnp.log(self.gate + jnp.exp(log_prob)), log_prob) @constraints.dependent_property(is_discrete=True, event_dim=0) - def support(self) -> constraints.Constraint: + def support(self) -> ConstraintT: return self.base_dist.support @lazy_property @@ -896,20 +909,22 @@ def variance(self) -> ArrayLike: ) - self.mean**2 @property - def has_enumerate_support(self): + def has_enumerate_support(self) -> bool: return self.base_dist.has_enumerate_support - def enumerate_support(self, expand: bool = True) -> ArrayLike: + def enumerate_support(self, expand: bool = True) -> Array: return self.base_dist.enumerate_support(expand=expand) class ZeroInflatedLogits(ZeroInflatedProbs): - arg_constraints = {"gate_logits": constraints.real} + arg_constraints = { + "gate_logits": constraints.real, # type: ignore[has-type] + } def __init__( self, base_dist: DistributionT, - gate_logits: ArrayLike, + gate_logits: Array, *, validate_args: Optional[bool] = None, ): @@ -919,7 +934,7 @@ def __init__( super().__init__(base_dist, gate, validate_args=validate_args) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: log_prob_minus_log_gate = -self.gate_logits + self.base_dist.log_prob(value) log_gate = -softplus(-self.gate_logits) log_prob = log_prob_minus_log_gate + log_gate @@ -930,8 +945,8 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: def ZeroInflatedDistribution( base_dist: DistributionT, *, - gate: Optional[ArrayLike] = None, - gate_logits: Optional[ArrayLike] = None, + gate: Optional[Array] = None, + gate_logits: Optional[Array] = None, validate_args: Optional[bool] = None, ) -> Union[ZeroInflatedProbs, ZeroInflatedLogits]: """ @@ -956,16 +971,19 @@ class ZeroInflatedPoisson(ZeroInflatedProbs): :param numpy.ndarray rate: rate of Poisson distribution. 
""" - arg_constraints = {"gate": constraints.unit_interval, "rate": constraints.positive} - support = constraints.nonnegative_integer + arg_constraints = { + "gate": constraints.unit_interval, # type: ignore[has-type] + "rate": constraints.positive, # type: ignore[has-type] + } + support = constraints.nonnegative_integer # type: ignore[has-type] pytree_data_fields = ("rate",) # TODO: resolve inconsistent parameter order w.r.t. Pyro # and support `gate_logits` argument def __init__( self, - gate: ArrayLike, - rate: ArrayLike = 1.0, + gate: Array, + rate: Array = 1.0, *, validate_args: Optional[bool] = None, ) -> None: @@ -974,18 +992,20 @@ def __init__( class GeometricProbs(Distribution): - arg_constraints = {"probs": constraints.unit_interval} - support = constraints.nonnegative_integer + arg_constraints = { + "probs": constraints.unit_interval, # type: ignore[has-type] + } + support = constraints.nonnegative_integer # type: ignore[has-type] - def __init__(self, probs: ArrayLike, *, validate_args: Optional[bool] = None): + def __init__(self, probs: Array, *, validate_args: Optional[bool] = None): self.probs = probs super(GeometricProbs, self).__init__( batch_shape=jnp.shape(self.probs), validate_args=validate_args ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) probs = self.probs dtype = jnp.result_type(probs) @@ -994,12 +1014,12 @@ def sample( return jnp.floor(jnp.log1p(-u) / jnp.log1p(-probs)) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: probs = jnp.where((self.probs == 1) & (value == 0), 0, self.probs) return value * jnp.log1p(-probs) + jnp.log(probs) @lazy_property - def logits(self) -> ArrayLike: + def logits(self) -> Array: return _to_logits_bernoulli(self.probs) @property @@ -1010,29 +1030,31 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return (1.0 / self.probs - 1.0) / self.probs - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return -(1 - self.probs) * jnp.log1p(-self.probs) / self.probs - jnp.log( self.probs ) class GeometricLogits(Distribution): - arg_constraints = {"logits": constraints.real} - support = constraints.nonnegative_integer + arg_constraints = { + "logits": constraints.real, # type: ignore[has-type] + } + support = constraints.nonnegative_integer # type: ignore[has-type] - def __init__(self, logits: ArrayLike, *, validate_args: Optional[bool] = None): + def __init__(self, logits: Array, *, validate_args: Optional[bool] = None): self.logits = logits super(GeometricLogits, self).__init__( batch_shape=jnp.shape(self.logits), validate_args=validate_args ) @lazy_property - def probs(self) -> ArrayLike: + def probs(self) -> Array: return _to_probs_bernoulli(self.logits) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) logits = self.logits dtype = jnp.result_type(logits) @@ -1041,7 +1063,7 @@ def sample( return jnp.floor(jnp.log1p(-u) / -softplus(logits)) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: return (-value - 1) * softplus(self.logits) + self.logits @property @@ -1052,7 +1074,7 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return (1.0 / self.probs - 1.0) / self.probs - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: logq = -jax.nn.softplus(self.logits) logp = -jax.nn.softplus(-self.logits) p = jax.scipy.special.expit(self.logits) @@ -1061,8 +1083,8 @@ def entropy(self) -> ArrayLike: def Geometric( - probs: Optional[ArrayLike] = None, - logits: Optional[ArrayLike] = None, + probs: Optional[Array] = None, + logits: Optional[Array] = None, *, validate_args: Optional[bool] = None, ) -> Union[GeometricProbs, GeometricLogits]: diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 7e67edad5..0efbf8d58 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -34,6 +34,7 @@ import warnings import numpy as np +from typing_extensions import ParamSpecKwargs import jax from jax import Array, lax, tree_util @@ -41,7 +42,7 @@ from jax.scipy.special import logsumexp from jax.typing import ArrayLike -from numpyro._typing import DistributionT, TransformT +from numpyro._typing import ConstraintT, DistributionT, PRNGKeyT, TransformT from numpyro.distributions.transforms import AbsTransform, ComposeTransform, Transform from numpyro.distributions.util import ( lazy_property, @@ -134,7 +135,7 @@ class Distribution(metaclass=DistributionMeta): """ arg_constraints: dict[str, Any] = {} - support = None + support: Optional[ConstraintT] = None has_enumerate_support: bool = False reparametrized_params: list[str] = [] _validate_args: bool = False @@ -313,8 +314,8 @@ def has_rsample(self) -> bool: return set(self.reparametrized_params) == set(self.arg_constraints) def rsample( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: if self.has_rsample: return self.sample(key, sample_shape=sample_shape) @@ -336,8 +337,8 @@ def shape(self, sample_shape: tuple[int, ...] = ()) -> tuple[int, ...]: return sample_shape + self.batch_shape + self.event_shape def sample( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: """ Returns a sample from the distribution having shape given by `sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty, @@ -352,8 +353,8 @@ def sample( raise NotImplementedError def sample_with_intermediates( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: """ Same as ``sample`` except that any intermediate computations are returned (useful for `TransformedDistribution`). @@ -365,14 +366,14 @@ def sample_with_intermediates( """ return self.sample(key, sample_shape=sample_shape), [] - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: """ Evaluates the log probability density for a batch of samples given by `value`. 
:param value: A batch of samples from the distribution. :return: an array with shape `value.shape[:-self.event_shape]` - :rtype: ArrayLike + :rtype: Array """ raise NotImplementedError @@ -390,7 +391,7 @@ def variance(self) -> ArrayLike: """ raise NotImplementedError - def _validate_sample(self, value: ArrayLike) -> ArrayLike: + def _validate_sample(self, value: Array) -> Array: mask = self.support(value) if not_jax_tracer(mask): if not np.all(mask): @@ -401,7 +402,7 @@ def _validate_sample(self, value: ArrayLike) -> ArrayLike: ) return mask - def __call__(self, *args, **kwargs) -> ArrayLike: + def __call__(self, *args, **kwargs) -> Array: key = kwargs.pop("rng_key") sample_intermediates = kwargs.pop("sample_intermediates", False) if sample_intermediates: @@ -410,7 +411,7 @@ def __call__(self, *args, **kwargs) -> ArrayLike: def to_event( self, reinterpreted_batch_ndims: Optional[int] = None - ) -> "Distribution": + ) -> DistributionT: """ Interpret the rightmost `reinterpreted_batch_ndims` batch dimensions as dependent event dimensions. @@ -426,20 +427,20 @@ def to_event( return self return Independent(self, reinterpreted_batch_ndims) - def enumerate_support(self, expand: bool = True) -> ArrayLike: + def enumerate_support(self, expand: bool = True) -> Array: """ Returns an array with shape `len(support) x batch_shape` containing all values in the support. """ raise NotImplementedError - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: """ Returns the entropy of the distribution. """ raise NotImplementedError - def expand(self, batch_shape: tuple[int, ...]) -> "Distribution": + def expand(self, batch_shape: tuple[int, ...]) -> DistributionT: """ Returns a new :class:`ExpandedDistribution` instance with batch dimensions expanded to `batch_shape`. @@ -453,7 +454,7 @@ def expand(self, batch_shape: tuple[int, ...]) -> "Distribution": return self return ExpandedDistribution(self, batch_shape) - def expand_by(self, sample_shape: tuple[int, ...]) -> "Distribution": + def expand_by(self, sample_shape: tuple[int, ...]) -> DistributionT: """ Expands a distribution by adding ``sample_shape`` to the left side of its :attr:`~numpyro.distributions.distribution.Distribution.batch_shape`. @@ -558,7 +559,7 @@ def infer_shapes(cls, *args, **kwargs): event_shape = () return batch_shape, event_shape - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: """ The cumulative distribution function of this distribution. @@ -567,7 +568,7 @@ def cdf(self, value: ArrayLike) -> ArrayLike: """ raise NotImplementedError - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: """ The inverse cumulative distribution function of this distribution. @@ -597,7 +598,7 @@ class ExpandedDistribution(Distribution): ) def __init__( - self, base_dist: Distribution, batch_shape: tuple[int, ...] = () + self, base_dist: DistributionT, batch_shape: tuple[int, ...] = () ) -> None: if isinstance(base_dist, ExpandedDistribution): batch_shape, _, _ = self._broadcast_shape( @@ -660,12 +661,12 @@ def has_rsample(self) -> bool: def _sample( self, sample_fn: Callable[ - [Optional[jax.dtypes.prng_key], tuple[int, ...]], - tuple[ArrayLike, list[ArrayLike]], + [Optional[PRNGKeyT], ParamSpecKwargs[tuple[int, ...]]], + tuple[Array, list[Array]], ], - key: Optional[jax.dtypes.prng_key], + key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= (), - ) -> tuple[ArrayLike, list[ArrayLike]]: + ) -> tuple[Array, list[Array]]: interstitial_sizes = tuple(self._interstitial_sizes.values()) expanded_sizes = tuple(self._expanded_sizes.values()) batch_shape = expanded_sizes + interstitial_sizes @@ -686,7 +687,7 @@ def _sample( for dim1, dim2 in zip(interstitial_dims, interstitial_sample_dims): permutation[dim1], permutation[dim2] = permutation[dim2], permutation[dim1] - def reshape_sample(x: ArrayLike) -> ArrayLike: + def reshape_sample(x: Array) -> Array: """ Reshapes samples and intermediates to ensure that the output shape is correct: This implicitly replaces the interstitial dims @@ -701,9 +702,7 @@ def reshape_sample(x: ArrayLike) -> ArrayLike: samples = reshape_sample(samples) return samples, intermediates - def rsample( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ): + def rsample(self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = ()): return self._sample( lambda *args, **kwargs: (self.base_dist.rsample(*args, **kwargs), []), key, @@ -711,20 +710,20 @@ def rsample( )[0] @property - def support(self) -> constraints.Constraint: + def support(self) -> ConstraintT: return self.base_dist.support def sample_with_intermediates( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ) -> tuple[ArrayLike, list[ArrayLike]]: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> tuple[Array, list[Array]]: return self._sample(self.base_dist.sample_with_intermediates, key, sample_shape) def sample( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: return self.sample_with_intermediates(key, sample_shape)[0] - def log_prob(self, value: ArrayLike, intermediates=None) -> ArrayLike: + def log_prob(self, value: Array, intermediates=None) -> Array: # TODO: utilize `intermediates` shape = lax.broadcast_shapes( self.batch_shape, @@ -733,7 +732,7 @@ def log_prob(self, value: ArrayLike, intermediates=None) -> ArrayLike: log_prob = self.base_dist.log_prob(value) return jnp.broadcast_to(log_prob, shape) - def enumerate_support(self, expand: bool = True) -> ArrayLike: + def enumerate_support(self, expand: bool = True) -> Array: samples = self.base_dist.enumerate_support(expand=False) enum_shape = samples.shape[:1] samples = samples.reshape(enum_shape + (1,) * len(self.batch_shape)) @@ -753,7 +752,7 @@ def variance(self) -> ArrayLike: self.base_dist.variance, self.batch_shape + self.event_shape ) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return jnp.broadcast_to(self.base_dist.entropy(), self.batch_shape) @@ -826,12 +825,12 @@ def __init__( super().__init__(batch_shape, event_shape, validate_args=validate_args) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: batch_shape = jnp.shape(value)[: jnp.ndim(value) - len(self.event_shape)] batch_shape = lax.broadcast_shapes(batch_shape, self.batch_shape) return jnp.zeros(batch_shape) - def _validate_sample(self, value: ArrayLike) -> ArrayLike: + def _validate_sample(self, value: Array) -> Array: mask = super(ImproperUniform, self)._validate_sample(value) batch_dim = jnp.ndim(value) - len(self.event_shape) if batch_dim < jnp.ndim(mask): @@ -891,7 +890,7 @@ def __init__( ) @property - def support(self) -> constraints.Constraint: + def support(self) -> ConstraintT: return constraints.independent( 
self.base_dist.support, self.reinterpreted_batch_ndims ) @@ -916,17 +915,15 @@ def variance(self) -> ArrayLike: def has_rsample(self) -> bool: return self.base_dist.has_rsample - def rsample( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ): + def rsample(self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = ()): return self.base_dist.rsample(key, sample_shape=sample_shape) def sample( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: return self.base_dist(rng_key=key, sample_shape=sample_shape) - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: log_prob = self.base_dist.log_prob(value) return sum_rightmost(log_prob, self.reinterpreted_batch_ndims) @@ -938,7 +935,7 @@ def expand(self, batch_shape): self.reinterpreted_batch_ndims ) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: axes = range(-self.reinterpreted_batch_ndims, 0) return self.base_dist.entropy().sum(axes) @@ -982,20 +979,20 @@ def has_rsample(self) -> bool: return self.base_dist.has_rsample def rsample( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ): + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: return self.base_dist.rsample(key, sample_shape=sample_shape) @property - def support(self): + def support(self) -> ConstraintT: return self.base_dist.support def sample( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: return self.base_dist(rng_key=key, sample_shape=sample_shape) - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: if self._mask is False: shape = lax.broadcast_shapes( tuple(self.base_dist.batch_shape), @@ -1072,7 +1069,7 @@ class TransformedDistribution(Distribution): def __init__( self, base_distribution: DistributionT, - transforms: list[TransformT], + transforms: Union[TransformT, list[TransformT]], *, validate_args: Optional[bool] = None, ): @@ -1131,16 +1128,14 @@ def __init__( def has_rsample(self) -> bool: return self.base_dist.has_rsample - def rsample( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ): + def rsample(self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = ()): x = self.base_dist.rsample(key, sample_shape=sample_shape) for transform in self.transforms: x = transform(x) return x @property - def support(self): + def support(self) -> ConstraintT: codomain = self.transforms[-1].codomain codomain_event_dim = codomain.event_dim assert self.event_dim >= codomain_event_dim @@ -1152,15 +1147,15 @@ def support(self): ) def sample( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: x = self.base_dist(rng_key=key, sample_shape=sample_shape) for transform in self.transforms: x = transform(x) return x def sample_with_intermediates( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () ): x = self.base_dist(rng_key=key, sample_shape=sample_shape) intermediates = [] @@ -1171,7 +1166,7 @@ def sample_with_intermediates( return x, intermediates @validate_sample - def log_prob(self, value: ArrayLike, intermediates=None): + def log_prob(self, value: Array, intermediates=None): if intermediates is not None: if len(intermediates) != len(self.transforms): raise ValueError( @@ -1204,7 +1199,7 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: raise NotImplementedError - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: sign = 1 for transform in reversed(self.transforms): sign *= transform.sign @@ -1212,7 +1207,7 @@ def cdf(self, value: ArrayLike) -> ArrayLike: q = self.base_dist.cdf(value) return jnp.where(sign < 0, 1 - q, q) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: sign = 1 for transform in self.transforms: sign *= transform.sign @@ -1241,7 +1236,7 @@ def __init__( super().__init__(base_dist, AbsTransform(), validate_args=validate_args) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: dim = max(len(self.batch_shape), jnp.ndim(value)) plus_minus = jnp.array([1.0, -1.0]).reshape((2,) + (1,) * dim) return logsumexp(self.base_dist.log_prob(plus_minus * value), axis=0) @@ -1279,19 +1274,19 @@ def __init__( ) @constraints.dependent_property - def support(self): + def support(self) -> ConstraintT: return constraints.independent(constraints.real, self.event_dim) def sample( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: if not sample_shape: return self.v shape = sample_shape + self.batch_shape + self.event_shape return jnp.broadcast_to(self.v, shape) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: log_prob = jnp.where(value == self.v, 0, -jnp.inf) log_prob = sum_rightmost(log_prob, len(self.event_shape)) return log_prob + self.log_density @@ -1304,7 +1299,7 @@ def mean(self) -> ArrayLike: def variance(self) -> ArrayLike: return jnp.zeros(self.batch_shape + self.event_shape) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return -jnp.broadcast_to(self.log_density, self.batch_shape) @@ -1320,7 +1315,7 @@ class Unit(Distribution): arg_constraints = {"log_factor": constraints.real} support = constraints.real - def __init__(self, log_factor: ArrayLike, *, validate_args: Optional[bool] = None): + def __init__(self, log_factor: Array, *, validate_args: Optional[bool] = None): batch_shape = jnp.shape(log_factor) event_shape = (0,) # This satisfies .size == 0. self.log_factor = log_factor @@ -1328,11 +1323,9 @@ def __init__(self, log_factor: ArrayLike, *, validate_args: Optional[bool] = Non batch_shape, event_shape, validate_args=validate_args ) - def sample( - self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = () - ): + def sample(self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= ()): return jnp.empty(sample_shape + self.batch_shape + self.event_shape) - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: shape = lax.broadcast_shapes(self.batch_shape, jnp.shape(value)[:-1]) return jnp.broadcast_to(self.log_factor, shape) diff --git a/numpyro/distributions/flows.py b/numpyro/distributions/flows.py index 62eef556a..2d4438d7f 100644 --- a/numpyro/distributions/flows.py +++ b/numpyro/distributions/flows.py @@ -2,19 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 -from jax import lax +from typing import Optional + +from jax import Array, lax import jax.numpy as jnp -from jax.typing import ArrayLike -from numpyro._typing import TransformT from numpyro.distributions.constraints import real_vector from numpyro.distributions.transforms import Transform from numpyro.util import fori_loop -def _clamp_preserve_gradients( - x: ArrayLike, min: ArrayLike, max: ArrayLike -) -> ArrayLike: +def _clamp_preserve_gradients(x: Array, min: Array, max: Array) -> Array: return x + lax.stop_gradient(jnp.clip(x, min, max) - x) @@ -40,8 +38,8 @@ class InverseAutoregressiveTransform(Transform): def __init__( self, autoregressive_nn, - log_scale_min_clip: ArrayLike = -5.0, - log_scale_max_clip: ArrayLike = 3.0, + log_scale_min_clip: Array = -5.0, # type: ignore + log_scale_max_clip: Array = 3.0, # type: ignore ): """ :param autoregressive_nn: an autoregressive neural network whose forward call returns a real-valued @@ -51,13 +49,13 @@ def __init__( self.log_scale_min_clip = log_scale_min_clip self.log_scale_max_clip = log_scale_max_clip - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: """ :param numpy.ndarray x: the input into the transform """ return self.call_with_intermediates(x)[0] - def call_with_intermediates(self, x: ArrayLike) -> ArrayLike: + def call_with_intermediates(self, x: Array) -> tuple[Array, Optional[Array]]: mean, log_scale = self.arn(x) log_scale = _clamp_preserve_gradients( log_scale, self.log_scale_min_clip, self.log_scale_max_clip @@ -65,7 +63,7 @@ def call_with_intermediates(self, x: ArrayLike) -> ArrayLike: scale = jnp.exp(log_scale) return scale * x + mean, log_scale - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: """ :param numpy.ndarray y: the output of the transform to be inverted """ @@ -84,9 +82,7 @@ def _update_x(i, x): x = fori_loop(0, y.shape[-1], _update_x, jnp.zeros(y.shape)) return x - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: """ Calculates the elementwise determinant of the log jacobian. 
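# A minimal sketch of the straight-through clamping used by
# `_clamp_preserve_gradients` above: values are clipped in the forward pass
# while gradients pass through as if no clipping happened, so clamped
# log-scales still receive gradient signal. `clamp_preserve_gradients` below
# is a stand-alone re-implementation for illustration, not the library helper.
import jax
import jax.numpy as jnp
from jax import lax

def clamp_preserve_gradients(x, lo, hi):
    # forward value equals jnp.clip(x, lo, hi); gradient is that of the identity
    return x + lax.stop_gradient(jnp.clip(x, lo, hi) - x)

value = clamp_preserve_gradients(10.0, -5.0, 3.0)                          # 3.0 (clipped)
grad = jax.grad(lambda x: clamp_preserve_gradients(x, -5.0, 3.0))(10.0)    # 1.0 (unclipped gradient)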
@@ -109,14 +105,14 @@ def tree_flatten(self): {"arn": self.arn}, ) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, InverseAutoregressiveTransform): return False return ( (self.arn is other.arn) & jnp.array_equal(self.log_scale_min_clip, other.log_scale_min_clip) & jnp.array_equal(self.log_scale_max_clip, other.log_scale_max_clip) - ) + ) # type: ignore class BlockNeuralAutoregressiveTransform(Transform): @@ -135,25 +131,23 @@ class BlockNeuralAutoregressiveTransform(Transform): def __init__(self, bn_arn): self.bn_arn = bn_arn - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: """ :param numpy.ndarray x: the input into the transform """ return self.call_with_intermediates(x)[0] - def call_with_intermediates(self, x: ArrayLike) -> ArrayLike: + def call_with_intermediates(self, x: Array) -> tuple[Array, Optional[Array]]: y, logdet = self.bn_arn(x) return y, logdet - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: raise NotImplementedError( "Block neural autoregressive transform does not have an analytic" " inverse implemented." ) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: """ Calculates the elementwise determinant of the log jacobian. @@ -170,7 +164,7 @@ def log_abs_det_jacobian( def tree_flatten(self): return (), ((), {"bn_arn": self.bn_arn}) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, BlockNeuralAutoregressiveTransform) and self.bn_arn is other.bn_arn diff --git a/numpyro/distributions/mixtures.py b/numpyro/distributions/mixtures.py index c6631f0d9..dcdc4dea5 100644 --- a/numpyro/distributions/mixtures.py +++ b/numpyro/distributions/mixtures.py @@ -9,7 +9,7 @@ import jax.numpy as jnp from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, DistributionT +from numpyro._typing import ConstraintT, DistributionT, PRNGKeyT from numpyro.distributions import Distribution, constraints from numpyro.distributions.discrete import CategoricalLogits, CategoricalProbs from numpyro.distributions.util import validate_sample @@ -62,22 +62,22 @@ class _MixtureBase(Distribution): """ @property - def component_mean(self) -> ArrayLike: + def component_mean(self) -> Array: raise NotImplementedError @property - def component_variance(self) -> ArrayLike: + def component_variance(self) -> Array: raise NotImplementedError - def component_log_probs(self, value: ArrayLike) -> ArrayLike: + def component_log_probs(self, value: Array) -> Array: raise NotImplementedError def component_sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: raise NotImplementedError - def component_cdf(self, samples: ArrayLike) -> ArrayLike: + def component_cdf(self, samples: Array) -> Array: raise NotImplementedError @property @@ -112,7 +112,7 @@ def variance(self) -> ArrayLike: var_cond_mean = jnp.sum(probs * sq_deviation, axis=self.mixture_dim) return mean_cond_var + var_cond_mean - def cdf(self, samples: ArrayLike) -> ArrayLike: + def cdf(self, samples: Array) -> Array: """The cumulative distribution function :param value: samples from this distribution. 
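# A minimal usage sketch of the mixture API typed above (the parameter values
# are assumptions, not taken from the diff): a two-component Normal mixture
# built from a Categorical mixing distribution, sampled with an explicit PRNG
# key and scored with `log_prob`.
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

mixture = dist.MixtureSameFamily(
    mixing_distribution=dist.Categorical(probs=jnp.array([0.3, 0.7])),
    component_distribution=dist.Normal(
        loc=jnp.array([-2.0, 2.0]), scale=jnp.array([1.0, 0.5])
    ),
)
key = jax.random.PRNGKey(0)
draws = mixture.sample(key, sample_shape=(5,))   # shape (5,)
log_density = mixture.log_prob(draws)            # shape (5,)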
@@ -125,8 +125,8 @@ def cdf(self, samples: ArrayLike) -> ArrayLike: return jnp.sum(cdf_components * self.mixing_distribution.probs, axis=-1) def sample_with_intermediates( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> tuple[ArrayLike, list[ArrayLike]]: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> tuple[Array, list[Array]]: """ A version of ``sample`` that also returns the sampled component indices @@ -142,7 +142,7 @@ def sample_with_intermediates( samples = self.component_sample(key_comp, sample_shape=sample_shape) # Sample selection indices from the categorical (shape will be sample_shape) - indices: ArrayLike = self.mixing_distribution.expand( + indices: Array = self.mixing_distribution.expand( sample_shape + self.batch_shape ).sample(key_ind) n_expand = self.event_dim + 1 @@ -157,12 +157,12 @@ def sample_with_intermediates( return jnp.squeeze(samples_selected, axis=self.mixture_dim), [indices] def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: return self.sample_with_intermediates(key=key, sample_shape=sample_shape)[0] @validate_sample - def log_prob(self, value: ArrayLike, intermediates=None) -> ArrayLike: + def log_prob(self, value: Array, intermediates=None) -> Array: del intermediates sum_log_probs = self.component_log_probs(value) safe_sum_log_probs = jnp.where( @@ -261,26 +261,26 @@ def is_discrete(self) -> bool: return self.component_distribution.is_discrete @property - def component_mean(self) -> ArrayLike: + def component_mean(self) -> Array: return self.component_distribution.mean @property - def component_variance(self) -> ArrayLike: + def component_variance(self) -> Array: return self.component_distribution.variance - def component_cdf(self, samples: ArrayLike) -> ArrayLike: + def component_cdf(self, samples: Array) -> Array: return self.component_distribution.cdf( jnp.expand_dims(samples, axis=self.mixture_dim) ) def component_sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: return self.component_distribution.expand( sample_shape + self.batch_shape + (self.mixture_size,) ).sample(key) - def component_log_probs(self, value: ArrayLike) -> ArrayLike: + def component_log_probs(self, value: Array) -> Array: value = jnp.expand_dims(value, self.mixture_dim) component_log_probs = self.component_distribution.log_prob(value) return jax.nn.log_softmax(self.mixing_distribution.logits) + component_log_probs @@ -429,33 +429,33 @@ def is_discrete(self) -> bool: return self.component_distributions[0].is_discrete @property - def component_mean(self) -> ArrayLike: + def component_mean(self) -> Array: return jnp.stack( [d.mean for d in self.component_distributions], axis=self.mixture_dim ) @property - def component_variance(self) -> ArrayLike: + def component_variance(self) -> Array: return jnp.stack( [d.variance for d in self.component_distributions], axis=self.mixture_dim ) - def component_cdf(self, samples: ArrayLike) -> Array: + def component_cdf(self, samples: Array) -> Array: return jnp.stack( [d.cdf(samples) for d in self.component_distributions], axis=self.mixture_dim, ) def component_sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: keys = jax.random.split(key, self.mixture_size) samples = [] for k, d in zip(keys, self.component_distributions): samples.append(d.expand(sample_shape + self.batch_shape).sample(k)) return jnp.stack(samples, axis=self.mixture_dim) - def component_log_probs(self, value: ArrayLike) -> ArrayLike: + def component_log_probs(self, value: Array) -> Array: component_log_probs = [] for d in self.component_distributions: log_prob = d.log_prob(value) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index c1f7a6fa8..d764598ca 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple, Union import warnings import weakref @@ -17,7 +17,7 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import TransformT +from numpyro._typing import ConstraintT, TransformT from numpyro.distributions import constraints from numpyro.distributions.util import ( add_diag, @@ -59,7 +59,7 @@ ] -def _clipped_expit(x: ArrayLike) -> ArrayLike: +def _clipped_expit(x: Array) -> Array: finfo = jnp.finfo(jnp.result_type(x)) return jnp.clip(expit(x), finfo.tiny, 1.0 - finfo.eps) @@ -74,7 +74,7 @@ def __init_subclass__(cls, **kwargs): register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) @property - def inv(self) -> TransformT: + def inv(self) -> Optional[TransformT]: inv = None if self._inv is not None: inv = self._inv() @@ -83,20 +83,16 @@ def inv(self) -> TransformT: self._inv = weakref.ref(inv) return inv - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: raise NotImplementedError - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: raise NotImplementedError - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: raise NotImplementedError - def call_with_intermediates( - self, x: ArrayLike - ) -> Tuple[ArrayLike, Optional[ArrayLike]]: + def call_with_intermediates(self, x: Array) -> Tuple[Array, Optional[Array]]: return self(x), None def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -114,7 +110,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape @property - def sign(self) -> ArrayLike: + def sign(self) -> Array: """ Sign of the derivative of the transform if it is bijective. 
""" @@ -148,7 +144,7 @@ class ParameterFreeTransform(Transform): def tree_flatten(self): return (), ((), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> Union[bool, Array]: return isinstance(other, type(self)) @@ -158,27 +154,25 @@ def __init__(self, transform: TransformT) -> None: self._inv = transform @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: return self._inv.codomain @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: return self._inv.domain @property - def sign(self) -> ArrayLike: + def sign(self) -> Array: return self._inv.sign @property - def inv(self) -> TransformT: + def inv(self) -> Optional[TransformT]: return self._inv - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return self._inv._inverse(x) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: # NB: we don't use intermediates for inverse transform return -self._inv.log_abs_det_jacobian(y, x, None) @@ -191,7 +185,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self._inv,), (("_inv",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> Union[bool, Array]: if not isinstance(other, _InverseTransform): return False return self._inv == other._inv @@ -201,13 +195,13 @@ class AbsTransform(ParameterFreeTransform): domain = constraints.real codomain = constraints.positive - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> Union[bool, Array]: return isinstance(other, AbsTransform) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return jnp.abs(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: warnings.warn( "AbsTransform is not a bijective transform." 
" The inverse of `y` will be `y`.", @@ -226,14 +220,14 @@ def __init__( self, loc: ArrayLike, scale: ArrayLike, - domain: constraints.Constraint = constraints.real, + domain: ConstraintT = constraints.real, ): self.loc = loc self.scale = scale self.domain = domain @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: if self.domain is constraints.real: return constraints.real elif isinstance(self.domain, constraints.greater_than): @@ -261,18 +255,16 @@ def codomain(self) -> constraints.Constraint: raise NotImplementedError @property - def sign(self) -> ArrayLike: + def sign(self) -> Array: return jnp.sign(self.scale) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return self.loc + self.scale * x - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: return (y - self.loc) / self.scale - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: return jnp.broadcast_to(jnp.log(jnp.abs(self.scale)), jnp.shape(x)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -288,7 +280,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.loc, self.scale, self.domain), (("loc", "scale", "domain"), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> Union[bool, Array]: if not isinstance(other, AffineTransform): return False return ( @@ -321,7 +313,7 @@ def __init__(self, parts: Sequence[TransformT]) -> None: self.parts = parts @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: input_event_dim = _get_compose_transform_input_event_dim(self.parts) first_input_event_dim = self.parts[0].domain.event_dim assert input_event_dim >= first_input_event_dim @@ -333,7 +325,7 @@ def domain(self) -> constraints.Constraint: ) @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: output_event_dim = _get_compose_transform_output_event_dim(self.parts) last_output_event_dim = self.parts[-1].codomain.event_dim assert output_event_dim >= last_output_event_dim @@ -345,25 +337,23 @@ def codomain(self) -> constraints.Constraint: ) @property - def sign(self) -> ArrayLike: + def sign(self) -> Array: sign = 1 for transform in self.parts: sign *= transform.sign return sign - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: for part in self.parts: x = part(x) return x - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: for part in self.parts[::-1]: y = part.inv(y) return y - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: if intermediates is not None: if len(intermediates) != len(self.parts): raise ValueError( @@ -389,9 +379,7 @@ def log_abs_det_jacobian( result = result + sum_rightmost(logdet, input_event_dim - part.domain.event_dim) return result - def call_with_intermediates( - self, x: ArrayLike - ) -> Tuple[ArrayLike, Optional[ArrayLike]]: + def call_with_intermediates(self, x: Array) -> Tuple[Array, Optional[Array]]: intermediates = [] for part in self.parts[:-1]: x, inter = part.call_with_intermediates(x) @@ -414,7 +402,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, 
...]: def tree_flatten(self): return (self.parts,), (("parts",), {}) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, ComposeTransform): return False return jnp.logical_and(*(p1 == p2 for p1, p2 in zip(self.parts, other.parts))) @@ -452,15 +440,13 @@ class CholeskyTransform(ParameterFreeTransform): domain = constraints.positive_definite codomain = constraints.lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return jnp.linalg.cholesky(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: return jnp.matmul(y, jnp.swapaxes(y, -2, -1)) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13 n = jnp.shape(x)[-1] order = -jnp.arange(n, 0, -1) @@ -499,12 +485,12 @@ class :class:`StickBreakingTransform` to transform :math:`X_i` into a domain = constraints.real_vector codomain = constraints.corr_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: # we interchange step 1 and step 2.a for a better performance t = jnp.tanh(x) return signed_stick_breaking_tril(t) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: # inverse stick-breaking z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1) pad_width = [(0, 0)] * y.ndim @@ -518,9 +504,7 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: # inverse of tanh return jnp.arctanh(t) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: # NB: because domain and codomain are two spaces with different dimensions, determinant of # Jacobian is not well-defined. Here we return `log_abs_det_jacobian` of `x` and the # flatten lower triangular part of `y`. 
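For context on the recurring `__eq__` change above: the parameter is widened from `TransformT` to `object`, which matches the signature of `object.__eq__` that type checkers expect for overrides, and the `Union[bool, Array]` return reflects that comparisons built on `jnp.array_equal` produce a scalar `jax.Array` rather than a Python `bool`. A minimal sketch of the pattern (illustrative only, not part of this diff; `ScaleBy` is a made-up stand-in for a transform):

from typing import Union

from jax import Array
import jax.numpy as jnp


class ScaleBy:
    """Toy class used only to illustrate the __eq__ signature pattern."""

    def __init__(self, scale: Array) -> None:
        self.scale = scale

    def __eq__(self, other: object) -> Union[bool, Array]:
        # Accept `object` and narrow with isinstance, mirroring the transforms above.
        if not isinstance(other, ScaleBy):
            return False
        # jnp.array_equal returns a scalar jax.Array, hence the Union return type.
        return jnp.array_equal(self.scale, other.scale)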
@@ -551,9 +535,7 @@ class CorrMatrixCholeskyTransform(CholeskyTransform): domain = constraints.corr_matrix codomain = constraints.corr_cholesky - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: # NB: see derivation in LKJCholesky implementation n = jnp.shape(x)[-1] order = -jnp.arange(n - 1, -1, -1) @@ -569,7 +551,7 @@ def __init__(self, domain=constraints.real): self.domain = domain @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: if self.domain is constraints.ordered_vector: return constraints.positive_ordered_vector elif self.domain is constraints.real: @@ -584,22 +566,20 @@ def codomain(self) -> constraints.Constraint: else: raise NotImplementedError - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: # XXX consider to clamp from below for stability if necessary return jnp.exp(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: return jnp.log(y) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: return x def tree_flatten(self): return (self.domain,), (("domain",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, ExpTransform): return False return self.domain == other.domain @@ -608,15 +588,13 @@ def __eq__(self, other: TransformT) -> bool: class IdentityTransform(ParameterFreeTransform): sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return x - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: return y - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: return jnp.zeros_like(x) @@ -638,26 +616,24 @@ def __init__( super().__init__() @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: return constraints.independent( self.base_transform.domain, self.reinterpreted_batch_ndims ) @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: return constraints.independent( self.base_transform.codomain, self.reinterpreted_batch_ndims ) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return self.base_transform(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: return self.base_transform._inverse(y) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: result = self.base_transform.log_abs_det_jacobian( x, y, intermediates=intermediates ) @@ -666,9 +642,7 @@ def log_abs_det_jacobian( raise ValueError(f"Expected x.dim() >= {expected} but got {jnp.ndim(x)}") return sum_rightmost(result, self.reinterpreted_batch_ndims) - def call_with_intermediates( - self, x: ArrayLike - ) -> Tuple[ArrayLike, Optional[ArrayLike]]: + def call_with_intermediates(self, x: Array) -> Tuple[Array, Optional[Array]]: return self.base_transform.call_with_intermediates(x) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -683,7 +657,7 @@ def 
tree_flatten(self): dict(), ) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, IndependentTransform): return False return (self.base_transform == other.base_transform) & ( @@ -699,7 +673,7 @@ class L1BallTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.l1_ball - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: # transform to (-1, 1) interval t = jnp.tanh(x) @@ -709,7 +683,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0) return t * remainder - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: # inverse stick-breaking remainder = 1 - jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1) pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)] @@ -722,9 +696,7 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: t = jnp.clip(t, -1 + finfo.eps, 1 - finfo.eps) return jnp.arctanh(t) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: # compute stick-breaking logdet # t1 -> t1 # t2 -> t2 * (1 - abs(t1)) @@ -765,7 +737,7 @@ class LowerCholeskyAffine(Transform): domain = constraints.real_vector codomain = constraints.real_vector - def __init__(self, loc: ArrayLike, scale_tril: Array): + def __init__(self, loc: Array, scale_tril: Array): if jnp.ndim(scale_tril) != 2: raise ValueError( "Only support 2-dimensional scale_tril matrix. " @@ -775,21 +747,19 @@ def __init__(self, loc: ArrayLike, scale_tril: Array): self.loc = loc self.scale_tril = scale_tril - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return self.loc + jnp.squeeze( jnp.matmul(self.scale_tril, x[..., jnp.newaxis]), axis=-1 ) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: y = y - self.loc original_shape = jnp.shape(y) yt = jnp.reshape(y, (-1, original_shape[-1])).T xt = solve_triangular(self.scale_tril, yt, lower=True) return jnp.reshape(xt.T, original_shape) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: return jnp.broadcast_to( jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1), jnp.shape(x)[:-1], @@ -808,7 +778,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.loc, self.scale_tril), (("loc", "scale_tril"), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, LowerCholeskyAffine): return False return jnp.array_equal(self.loc, other.loc) & jnp.array_equal( @@ -827,21 +797,19 @@ class LowerCholeskyTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = jnp.exp(x[..., -n:]) return add_diag(z, diag) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: z = matrix_to_tril_vec(y, diagonal=-1) return jnp.concatenate( [z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1 ) - def log_abs_det_jacobian( - self, x: ArrayLike, y: 
ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: # the jacobian is diagonal, so logdet is the sum of diagonal `exp` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) return x[..., -n:].sum(-1) @@ -869,20 +837,18 @@ class ScaledUnitLowerCholeskyTransform(LowerCholeskyTransform): domain = constraints.real_vector codomain = constraints.scaled_unit_lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = softplus(x[..., -n:]) return add_diag(z, 1) * diag[..., None] - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: diag = jnp.diagonal(y, axis1=-2, axis2=-1) z = matrix_to_tril_vec(y / diag[..., None], diagonal=-1) return jnp.concatenate([z, _softplus_inv(diag)], axis=-1) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) diag = x[..., -n:] diag_softplus = jnp.diagonal(y, axis1=-2, axis2=-1) @@ -913,17 +879,15 @@ class OrderedTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.ordered_vector - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: z = jnp.concatenate([x[..., :1], jnp.exp(x[..., 1:])], axis=-1) return jnp.cumsum(z, axis=-1) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: x = jnp.log(y[..., 1:] - y[..., :-1]) return jnp.concatenate([y[..., :1], x], axis=-1) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: return jnp.sum(x[..., 1:], -1) @@ -934,10 +898,10 @@ class PermuteTransform(Transform): def __init__(self, permutation: Array) -> None: self.permutation = permutation - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return x[..., self.permutation] - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: size = self.permutation.size permutation_inv = ( jnp.zeros(size, dtype=jnp.result_type(int)) @@ -946,15 +910,13 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: ) return y[..., permutation_inv] - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: return jnp.full(jnp.shape(x)[:-1], 0.0) def tree_flatten(self): return (self.permutation,), (("permutation",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, PermuteTransform): return False return jnp.array_equal(self.permutation, other.permutation) @@ -964,18 +926,16 @@ class PowerTransform(Transform): domain = constraints.positive codomain = constraints.positive - def __init__(self, exponent: ArrayLike) -> None: + def __init__(self, exponent: Array) -> None: self.exponent = exponent - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return jnp.power(x, self.exponent) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: return jnp.power(y, 1 / self.exponent) - def 
log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: return jnp.log(jnp.abs(self.exponent * y / x)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -987,13 +947,13 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.exponent,), (("exponent",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, PowerTransform): return False return jnp.array_equal(self.exponent, other.exponent) @property - def sign(self) -> ArrayLike: + def sign(self) -> Array: return jnp.sign(self.exponent) @@ -1001,15 +961,13 @@ class SigmoidTransform(ParameterFreeTransform): codomain = constraints.unit_interval sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return _clipped_expit(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: return logit(y) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: return -softplus(x) - softplus(-x) @@ -1042,15 +1000,15 @@ class SimplexToOrderedTransform(Transform): domain = constraints.simplex codomain = constraints.ordered_vector - def __init__(self, anchor_point: ArrayLike = 0.0) -> None: + def __init__(self, anchor_point: Array = 0.0) -> None: self.anchor_point = anchor_point - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: s = jnp.cumsum(x[..., :-1], axis=-1) y = logit(s) + jnp.expand_dims(self.anchor_point, -1) return y - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: y = y - jnp.expand_dims(self.anchor_point, -1) s = expit(y) # x0 = s0, x1 = s1 - s0, x2 = s2 - s1,..., xn = 1 - s[n-1] @@ -1060,9 +1018,7 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: x = s[..., 1:] - s[..., :-1] return x - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: # |dp/dc| = |dx/dy| = prod(ds/dy) = prod(expit'(y)) # we know log derivative of expit(y) is `-softplus(y) - softplus(-y)` J_logdet = (softplus(y) + softplus(-y)).sum(-1) @@ -1071,7 +1027,7 @@ def log_abs_det_jacobian( def tree_flatten(self): return (self.anchor_point,), (("anchor_point",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, SimplexToOrderedTransform): return False return jnp.array_equal(self.anchor_point, other.anchor_point) @@ -1083,7 +1039,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (shape[-1] + 1,) -def _softplus_inv(y: ArrayLike) -> ArrayLike: +def _softplus_inv(y: Array) -> Array: return jnp.log(-jnp.expm1(-y)) + y @@ -1097,15 +1053,13 @@ class SoftplusTransform(ParameterFreeTransform): codomain = constraints.softplus_positive sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return softplus(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: return _softplus_inv(y) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, 
intermediates=None) -> Array: return -softplus(-x) @@ -1119,20 +1073,18 @@ class SoftplusLowerCholeskyTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.softplus_lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = softplus(x[..., -n:]) return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: z = matrix_to_tril_vec(y, diagonal=-1) diag = _softplus_inv(jnp.diagonal(y, axis1=-2, axis2=-1)) return jnp.concatenate([z, diag], axis=-1) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: # the jacobian is diagonal, so logdet is the sum of diagonal # `softplus` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) @@ -1149,7 +1101,7 @@ class StickBreakingTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.simplex - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: # we shift x to obtain a balanced mapping (0, 0, ..., 0) -> (1/K, 1/K, ..., 1/K) x = x - jnp.log(x.shape[-1] - jnp.arange(x.shape[-1])) # convert to probabilities (relative to the remaining) of each fraction of the stick @@ -1165,16 +1117,14 @@ def __call__(self, x: ArrayLike) -> ArrayLike: ) return z_padded * z1m_cumprod_shifted - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: y_crop = y[..., :-1] z1m_cumprod = jnp.clip(1 - jnp.cumsum(y_crop, axis=-1), jnp.finfo(y.dtype).tiny) # hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod x = jnp.log(y_crop / z1m_cumprod) return x + jnp.log(x.shape[-1] - jnp.arange(x.shape[-1])) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html # |det|(J) = Product(y * (1 - sigmoid(x))) # = Product(y * sigmoid(x) * exp(-x)) @@ -1207,7 +1157,7 @@ def __init__(self, unpack_fn, pack_fn=None): self.unpack_fn = unpack_fn self.pack_fn = pack_fn - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: batch_shape = x.shape[:-1] if batch_shape: unpacked = vmap(self.unpack_fn)(x.reshape((-1,) + x.shape[-1:])) @@ -1217,7 +1167,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: else: return self.unpack_fn(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: if self.pack_fn is None: raise NotImplementedError( "pack_fn needs to be provided to perform UnpackTransform.inv." @@ -1237,9 +1187,7 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: ) return self.pack_fn(y) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: return jnp.zeros(jnp.shape(x)[:-1]) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -1252,7 +1200,7 @@ def tree_flatten(self): # XXX: what if unpack_fn is a parametrized callable pytree? 
return (), ((), {"unpack_fn": self.unpack_fn, "pack_fn": self.pack_fn}) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, UnpackTransform) and (self.unpack_fn is other.unpack_fn) @@ -1293,11 +1241,11 @@ def __init__( self._inverse_shape = inverse_shape @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: return constraints.independent(constraints.real, len(self._inverse_shape)) @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: return constraints.independent(constraints.real, len(self._forward_shape)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -1306,15 +1254,13 @@ def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return _get_target_shape(shape, self._inverse_shape, self._forward_shape) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: return jnp.reshape(x, self.forward_shape(jnp.shape(x))) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: return jnp.reshape(y, self.inverse_shape(jnp.shape(y))) - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: return jnp.zeros_like(x, shape=x.shape[: x.ndim - len(self._inverse_shape)]) def tree_flatten(self): @@ -1324,7 +1270,7 @@ def tree_flatten(self): } return (), ((), aux_data) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, ReshapeTransform) and self._forward_shape == other._forward_shape @@ -1405,14 +1351,14 @@ def tree_flatten(self): return (), ((), aux_data) @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: return constraints.independent(constraints.real, self.transform_ndims) @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: return constraints.independent(constraints.complex, self.transform_ndims) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, RealFastFourierTransform) and self.transform_ndims == other.transform_ndims @@ -1571,7 +1517,7 @@ def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None): def tree_flatten(self): return (self.transition_matrix,), (("transition_matrix",), {}) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, RecursiveLinearTransform): return False return jnp.array_equal(self.transition_matrix, other.transition_matrix) @@ -1592,11 +1538,11 @@ def __init__(self, transform_ndims: int = 1) -> None: self.transform_ndims = transform_ndims @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: return constraints.independent(constraints.real, self.transform_ndims) @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: return constraints.zero_sum(self.transform_ndims) def __call__(self, x: Array) -> Array: @@ -1654,7 +1600,7 @@ def tree_flatten(self): aux_data = {"transform_ndims": self.transform_ndims} return (), ((), aux_data) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, ZeroSumTransform) and self.transform_ndims == 
other.transform_ndims @@ -1700,7 +1646,7 @@ def register(self, constraint, factory=None): if factory is None: return lambda factory: self.register(constraint, factory) - if isinstance(constraint, constraints.Constraint): + if isinstance(constraint, ConstraintT): constraint = type(constraint) self._registry[constraint] = factory diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index efda0d64e..f08951617 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -5,13 +5,13 @@ from typing import Optional, Union import jax -from jax import lax +from jax import Array, lax import jax.numpy as jnp import jax.random as random from jax.scipy.special import logsumexp from jax.typing import ArrayLike -from numpyro._typing import ConstraintT +from numpyro._typing import ConstraintT, PRNGKeyT from numpyro.distributions import constraints from numpyro.distributions.continuous import ( Cauchy, @@ -73,8 +73,8 @@ def _tail_prob_at_high(self): return jnp.where(self.low <= self.base_dist.loc, 1.0, 0.0) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) dtype = jnp.result_type(float) finfo = jnp.finfo(dtype) @@ -82,7 +82,7 @@ def sample( u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval) return self.icdf(u) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: loc = self.base_dist.loc sign = jnp.where(loc >= self.low, 1.0, -1.0) ppf = (1 - sign) * loc + sign * self.base_dist.icdf( @@ -91,7 +91,7 @@ def icdf(self, q: ArrayLike) -> ArrayLike: return jnp.where(q < 0, jnp.nan, ppf) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: sign = jnp.where(self.base_dist.loc >= self.low, 1.0, -1.0) return self.base_dist.log_prob(value) - jnp.log( sign * (self._tail_prob_at_high - self._tail_prob_at_low) @@ -108,7 +108,7 @@ def mean(self) -> ArrayLike: raise NotImplementedError("mean only available for Normal and Cauchy") @property - def var(self) -> ArrayLike: + def var(self) -> Array: if isinstance(self.base_dist, Normal): low_prob = jnp.exp(self.log_prob(self.low)) return (self.base_dist.scale**2) * ( @@ -152,12 +152,12 @@ def support(self) -> ConstraintT: return self._support @lazy_property - def _cdf_at_high(self) -> ArrayLike: + def _cdf_at_high(self) -> Array: return self.base_dist.cdf(self.high) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) dtype = jnp.result_type(float) finfo = jnp.finfo(dtype) @@ -165,12 +165,12 @@ def sample( u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval) return self.icdf(u) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: ppf = self.base_dist.icdf(q * self._cdf_at_high) return jnp.where(q > 1, jnp.nan, ppf) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: return self.base_dist.log_prob(value) - jnp.log(self._cdf_at_high) @property @@ -184,7 +184,7 @@ def mean(self) -> ArrayLike: raise NotImplementedError("mean only available for Normal and Cauchy") @property - def var(self) -> ArrayLike: + def var(self) -> Array: if isinstance(self.base_dist, Normal): high_prob = jnp.exp(self.log_prob(self.high)) return (self.base_dist.scale**2) * ( @@ -235,21 +235,21 @@ def support(self) -> ConstraintT: return self._support @lazy_property - def _tail_prob_at_low(self) -> ArrayLike: + def _tail_prob_at_low(self) -> Array: # if low < loc, returns cdf(low); otherwise returns 1 - cdf(low) loc = self.base_dist.loc sign = jnp.where(loc >= self.low, 1.0, -1.0) return self.base_dist.cdf(loc - sign * (loc - self.low)) @lazy_property - def _tail_prob_at_high(self) -> ArrayLike: + def _tail_prob_at_high(self) -> Array: # if low < loc, returns cdf(high); otherwise returns 1 - cdf(high) loc = self.base_dist.loc sign = jnp.where(loc >= self.low, 1.0, -1.0) return self.base_dist.cdf(loc - sign * (loc - self.high)) @lazy_property - def _log_diff_tail_probs(self) -> ArrayLike: + def _log_diff_tail_probs(self) -> Array: # use log_cdf method, if available, to avoid inf's in log_prob # fall back to cdf, if log_cdf not available log_cdf = getattr(self.base_dist, "log_cdf", None) @@ -266,8 +266,8 @@ def _log_diff_tail_probs(self) -> ArrayLike: return jnp.log(sign * (self._tail_prob_at_high - self._tail_prob_at_low)) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) dtype = jnp.result_type(float) finfo = jnp.finfo(dtype) @@ -275,7 +275,7 @@ def sample( u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval) return self.icdf(u) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: # NB: we use a more numerically stable formula for a symmetric base distribution # A = icdf(cdf(low) + (cdf(high) - cdf(low)) * q) = icdf[(1 - q) * cdf(low) + q * cdf(high)] # will suffer by precision issues when low is large; @@ -291,7 +291,7 @@ def icdf(self, q: ArrayLike) -> ArrayLike: return jnp.where(jnp.logical_or(q < 0, q > 1), jnp.nan, ppf) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: # NB: we use a more numerically stable formula for a symmetric base distribution # if low < loc # cdf(high) - cdf(low) = as-is @@ -311,7 +311,7 @@ def mean(self) -> ArrayLike: raise NotImplementedError("mean only available for Normal and Cauchy") @property - def var(self) -> ArrayLike: + def var(self) -> Array: if isinstance(self.base_dist, Normal): low_prob = jnp.exp(self.log_prob(self.low)) high_prob = jnp.exp(self.log_prob(self.high)) @@ -329,8 +329,8 @@ def var(self) -> ArrayLike: def TruncatedDistribution( base_dist: Union[Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT], - low: Optional[ArrayLike] = None, - high: Optional[ArrayLike] = None, + low: Optional[Array] = None, + high: Optional[Array] = None, *, validate_args: Optional[bool] = None, ): @@ -363,11 +363,11 @@ def TruncatedDistribution( def TruncatedCauchy( - loc: ArrayLike = 0.0, - scale: ArrayLike = 1.0, + loc: Array = 0.0, + scale: Array = 1.0, *, - low: Optional[ArrayLike] = None, - high: Optional[ArrayLike] = None, + low: Optional[Array] = None, + high: Optional[Array] = None, validate_args: Optional[bool] = None, ): return TruncatedDistribution( @@ -376,11 +376,11 @@ def TruncatedCauchy( def TruncatedNormal( - loc: ArrayLike = 0.0, - scale: ArrayLike = 1.0, + loc: Array = 0.0, + scale: Array = 1.0, *, - low: Optional[ArrayLike] = None, - high: Optional[ArrayLike] = None, + low: Optional[Array] = None, + high: Optional[Array] = None, validate_args: Optional[bool] = None, ): return TruncatedDistribution( @@ -405,8 +405,8 @@ def __init__( ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates)) x = random.gamma( @@ -416,7 +416,7 @@ def sample( return jnp.clip(x * (0.5 / jnp.pi**2), None, self.truncation_point) @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: value = value[..., None] all_indices = jnp.arange(0, self.num_log_prob_terms) two_n_plus_one = 2.0 * all_indices + 1.0 @@ -464,9 +464,9 @@ class DoublyTruncatedPowerLaw(Distribution): def __init__( self, - alpha: ArrayLike, - low: ArrayLike, - high: ArrayLike, + alpha: Array, + low: Array, + high: Array, *, validate_args: Optional[bool] = None, ): @@ -484,7 +484,7 @@ def support(self) -> ConstraintT: return self._support @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: r"""Logarithmic probability distribution: Z inequal minus one: .. 
math:: @@ -497,14 +497,12 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: """ @jax.custom_jvp - def f( - x: ArrayLike, alpha: ArrayLike, low: ArrayLike, high: ArrayLike - ) -> ArrayLike: + def f(x: Array, alpha: Array, low: Array, high: Array) -> Array: neq_neg1_mask = jnp.not_equal(alpha, -1.0) neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0) # eq_neg1_alpha = jnp.where(~neq_neg1_mask, alpha, -1.0) - def neq_neg1_fn() -> ArrayLike: + def neq_neg1_fn() -> Array: one_more_alpha = 1.0 + neq_neg1_alpha return jnp.log( jnp.power(x, neq_neg1_alpha) @@ -512,16 +510,16 @@ def neq_neg1_fn() -> ArrayLike: / (jnp.power(high, one_more_alpha) - jnp.power(low, one_more_alpha)) ) - def eq_neg1_fn() -> ArrayLike: + def eq_neg1_fn() -> Array: return -jnp.log(x) - jnp.log(jnp.log(high) - jnp.log(low)) return jnp.where(neq_neg1_mask, neq_neg1_fn(), eq_neg1_fn()) @f.defjvp def f_jvp( - primals: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike], - tangents: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike], - ) -> tuple[ArrayLike, ArrayLike]: + primals: tuple[Array, Array, Array, Array], + tangents: tuple[Array, Array, Array, Array], + ) -> tuple[Array, Array]: x, alpha, low, high = primals x_t, alpha_t, low_t, high_t = tangents @@ -539,7 +537,7 @@ def f_jvp( # Alpha tangent with approximation # Variable part for all values alpha unequal -1 - def alpha_tangent_variable(alpha: ArrayLike) -> ArrayLike: + def alpha_tangent_variable(alpha: Array) -> Array: one_more_alpha = 1.0 + alpha low_pow_one_more_alpha = jnp.power(low, one_more_alpha) high_pow_one_more_alpha = jnp.power(high, one_more_alpha) @@ -602,7 +600,7 @@ def alpha_tangent_variable(alpha: ArrayLike) -> ArrayLike: return f(value, self.alpha, self.low, self.high) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: r"""Cumulated probability distribution: Z inequal minus one: @@ -620,20 +618,18 @@ def cdf(self, value: ArrayLike) -> ArrayLike: """ @jax.custom_jvp - def f( - x: ArrayLike, alpha: ArrayLike, low: ArrayLike, high: ArrayLike - ) -> ArrayLike: + def f(x: Array, alpha: Array, low: Array, high: Array) -> Array: neq_neg1_mask = jnp.not_equal(alpha, -1.0) neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0) - def cdf_when_alpha_neq_neg1() -> ArrayLike: + def cdf_when_alpha_neq_neg1() -> Array: one_more_alpha = 1.0 + neq_neg1_alpha low_pow_one_more_alpha = jnp.power(low, one_more_alpha) return (jnp.power(x, one_more_alpha) - low_pow_one_more_alpha) / ( jnp.power(high, one_more_alpha) - low_pow_one_more_alpha ) - def cdf_when_alpha_eq_neg1() -> ArrayLike: + def cdf_when_alpha_eq_neg1() -> Array: return jnp.log(x / low) / jnp.log(high / low) cdf_val = jnp.where( @@ -645,9 +641,9 @@ def cdf_when_alpha_eq_neg1() -> ArrayLike: @f.defjvp def f_jvp( - primals: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike], - tangents: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike], - ) -> tuple[ArrayLike, ArrayLike]: + primals: tuple[Array, Array, Array, Array], + tangents: tuple[Array, Array, Array, Array], + ) -> tuple[Array, Array]: x, alpha, low, high = primals x_t, alpha_t, low_t, high_t = tangents @@ -663,13 +659,13 @@ def f_jvp( primal_out = f(*primals) # Tangents for alpha not equals -1 - def x_neq_neg1(alpha: ArrayLike) -> ArrayLike: + def x_neq_neg1(alpha: Array) -> Array: one_more_alpha = 1.0 + alpha return (one_more_alpha * jnp.power(x, alpha)) / ( jnp.power(high, one_more_alpha) - jnp.power(low, one_more_alpha) ) - def alpha_neq_neg1(alpha: ArrayLike) -> ArrayLike: + def alpha_neq_neg1(alpha: Array) -> 
Array: one_more_alpha = 1.0 + alpha low_pow_one_more_alpha = jnp.power(low, one_more_alpha) high_pow_one_more_alpha = jnp.power(high, one_more_alpha) @@ -686,7 +682,7 @@ def alpha_neq_neg1(alpha: ArrayLike) -> ArrayLike: ) / jnp.square(high_pow_one_more_alpha - low_pow_one_more_alpha) return term1 - term2 - def low_neq_neg1(alpha: ArrayLike) -> ArrayLike: + def low_neq_neg1(alpha: Array) -> Array: one_more_alpha = 1.0 + alpha low_pow_one_more_alpha = jnp.power(low, one_more_alpha) high_pow_one_more_alpha = jnp.power(high, one_more_alpha) @@ -696,7 +692,7 @@ def low_neq_neg1(alpha: ArrayLike) -> ArrayLike: term1 = term2 * (x_pow_one_more_alpha - low_pow_one_more_alpha) / change return term1 - term2 - def high_neq_neg1(alpha: ArrayLike) -> ArrayLike: + def high_neq_neg1(alpha: Array) -> Array: one_more_alpha = 1.0 + alpha low_pow_one_more_alpha = jnp.power(low, one_more_alpha) high_pow_one_more_alpha = jnp.power(high, one_more_alpha) @@ -708,15 +704,15 @@ def high_neq_neg1(alpha: ArrayLike) -> ArrayLike: ) / jnp.square(high_pow_one_more_alpha - low_pow_one_more_alpha) # Tangents for alpha equals -1 - def x_eq_neg1() -> ArrayLike: + def x_eq_neg1() -> Array: return jnp.reciprocal(x * (log_high - log_low)) - def low_eq_neg1() -> ArrayLike: + def low_eq_neg1() -> Array: return (log_x - log_low) / ( jnp.square(log_high - log_low) * low ) - jnp.reciprocal((log_high - log_low) * low) - def high_eq_neg1() -> ArrayLike: + def high_eq_neg1() -> Array: return (log_x - log_low) / (jnp.square(log_high - log_low) * high) # Including approximation for alpha = -1 @@ -744,7 +740,7 @@ def high_eq_neg1() -> ArrayLike: return f(value, self.alpha, self.low, self.high) - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: r"""Inverse cumulated probability distribution: Z inequal minus one: @@ -760,13 +756,11 @@ def icdf(self, q: ArrayLike) -> ArrayLike: """ @jax.custom_jvp - def f( - q: ArrayLike, alpha: ArrayLike, low: ArrayLike, high: ArrayLike - ) -> ArrayLike: + def f(q: Array, alpha: Array, low: Array, high: Array) -> Array: neq_neg1_mask = jnp.not_equal(alpha, -1.0) neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0) - def icdf_alpha_neq_neg1() -> ArrayLike: + def icdf_alpha_neq_neg1() -> Array: one_more_alpha = 1.0 + neq_neg1_alpha low_pow_one_more_alpha = jnp.power(low, one_more_alpha) high_pow_one_more_alpha = jnp.power(high, one_more_alpha) @@ -776,7 +770,7 @@ def icdf_alpha_neq_neg1() -> ArrayLike: jnp.reciprocal(one_more_alpha), ) - def icdf_alpha_eq_neg1() -> ArrayLike: + def icdf_alpha_eq_neg1() -> Array: return jnp.power(high / low, q) * low icdf_val = jnp.where( @@ -788,9 +782,9 @@ def icdf_alpha_eq_neg1() -> ArrayLike: @f.defjvp def f_jvp( - primals: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike], - tangents: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike], - ) -> tuple[ArrayLike, ArrayLike]: + primals: tuple[Array, Array, Array, Array], + tangents: tuple[Array, Array, Array, Array], + ) -> tuple[Array, Array]: x, alpha, low, high = primals x_t, alpha_t, low_t, high_t = tangents @@ -805,7 +799,7 @@ def f_jvp( primal_out = f(*primals) # Tangents for alpha not equal -1 - def x_neq_neg1(alpha: ArrayLike) -> ArrayLike: + def x_neq_neg1(alpha: Array) -> Array: one_more_alpha = 1.0 + alpha low_pow_one_more_alpha = jnp.power(low, one_more_alpha) high_pow_one_more_alpha = jnp.power(high, one_more_alpha) @@ -818,7 +812,7 @@ def x_neq_neg1(alpha: ArrayLike) -> ArrayLike: ) ) / one_more_alpha - def alpha_neq_neg1(alpha: ArrayLike) -> ArrayLike: + def alpha_neq_neg1(alpha: 
Array) -> Array: one_more_alpha = 1.0 + alpha low_pow_one_more_alpha = jnp.power(low, one_more_alpha) high_pow_one_more_alpha = jnp.power(high, one_more_alpha) @@ -837,7 +831,7 @@ def alpha_neq_neg1(alpha: ArrayLike) -> ArrayLike: term3 = jnp.log(factor0) / jnp.square(one_more_alpha) return term1 * (term2 - term3) - def low_neq_neg1(alpha: ArrayLike) -> ArrayLike: + def low_neq_neg1(alpha: Array) -> Array: one_more_alpha = 1.0 + alpha low_pow_one_more_alpha = jnp.power(low, one_more_alpha) high_pow_one_more_alpha = jnp.power(high, one_more_alpha) @@ -851,7 +845,7 @@ def low_neq_neg1(alpha: ArrayLike) -> ArrayLike: ) ) - def high_neq_neg1(alpha: ArrayLike) -> ArrayLike: + def high_neq_neg1(alpha: Array) -> Array: one_more_alpha = 1.0 + alpha low_pow_one_more_alpha = jnp.power(low, one_more_alpha) high_pow_one_more_alpha = jnp.power(high, one_more_alpha) @@ -866,16 +860,16 @@ def high_neq_neg1(alpha: ArrayLike) -> ArrayLike: ) # Tangents for alpha equals -1 - def dx_eq_neg1() -> ArrayLike: + def dx_eq_neg1() -> Array: return low * jnp.power(high_over_low, x) * (log_high - log_low) - def low_eq_neg1() -> ArrayLike: + def low_eq_neg1() -> Array: return ( jnp.power(high_over_low, x) - (high * x * jnp.power(high_over_low, x - 1)) / low ) - def high_eq_neg1() -> ArrayLike: + def high_eq_neg1() -> Array: return x * jnp.power(high_over_low, x - 1) # Including approximation for alpha = -1 \ @@ -904,8 +898,8 @@ def high_eq_neg1() -> ArrayLike: return f(q, self.alpha, self.low, self.high) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: assert is_prng_key(key) u = random.uniform(key, sample_shape + self.batch_shape) samples = self.icdf(u) @@ -945,7 +939,7 @@ class LowerTruncatedPowerLaw(Distribution): pytree_aux_fields = ("_support",) def __init__( - self, alpha: ArrayLike, low: ArrayLike, *, validate_args: Optional[bool] = None + self, alpha: Array, low: Array, *, validate_args: Optional[bool] = None ): self.alpha, self.low = promote_shapes(alpha, low) batch_shape = lax.broadcast_shapes(jnp.shape(alpha), jnp.shape(low)) @@ -959,7 +953,7 @@ def support(self) -> ConstraintT: return self._support @validate_sample - def log_prob(self, value: ArrayLike) -> ArrayLike: + def log_prob(self, value: Array) -> Array: one_more_alpha = 1.0 + self.alpha return ( self.alpha * jnp.log(value) @@ -967,7 +961,7 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: - one_more_alpha * jnp.log(self.low) ) - def cdf(self, value: ArrayLike) -> ArrayLike: + def cdf(self, value: Array) -> Array: cdf_val = jnp.where( jnp.less_equal(value, self.low), jnp.zeros_like(value), @@ -975,7 +969,7 @@ def cdf(self, value: ArrayLike) -> ArrayLike: ) return cdf_val - def icdf(self, q: ArrayLike) -> ArrayLike: + def icdf(self, q: Array) -> Array: nan_mask = jnp.logical_or(jnp.isnan(q), jnp.less(q, 0.0)) nan_mask = jnp.logical_or(nan_mask, jnp.greater(q, 1.0)) return jnp.where( @@ -985,8 +979,8 @@ def icdf(self, q: ArrayLike) -> ArrayLike: ) def sample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] 
= () + ) -> Array: assert is_prng_key(key) u = random.uniform(key, sample_shape + self.batch_shape) samples = self.icdf(u) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index f690bbecf..8b1d37b97 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -9,11 +9,10 @@ import numpy as np import jax -from jax import jit, lax, random, vmap +from jax import Array, jit, lax, random, vmap import jax.numpy as jnp from jax.scipy.linalg import solve_triangular from jax.scipy.special import digamma -from jax.typing import ArrayLike from numpyro.util import not_jax_tracer @@ -433,7 +432,7 @@ def logmatmulexp(x, y): @jax.custom_jvp -def log1mexp(x: ArrayLike) -> ArrayLike: +def log1mexp(x: Array) -> Array: """ Numerically stable calculation of the quantity :math:`\\log(1 - \\exp(x))`, following the algorithm @@ -463,7 +462,7 @@ def log1mexp(x: ArrayLike) -> ArrayLike: log1mexp.defjvps(lambda t, ans, x: -t / jnp.expm1(-x)) -def logdiffexp(a: ArrayLike, b: ArrayLike) -> ArrayLike: +def logdiffexp(a: Array, b: Array) -> Array: """ Numerically stable calculation of the quantity :math:`\\log(\\exp(a) - \\exp(b))`, @@ -489,7 +488,7 @@ def logdiffexp(a: ArrayLike, b: ArrayLike) -> ArrayLike: ) -def clamp_probs(probs: ArrayLike) -> ArrayLike: +def clamp_probs(probs: Array) -> Array: finfo = jnp.finfo(jnp.result_type(probs, float)) return jnp.clip(probs, finfo.tiny, 1.0 - finfo.eps) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index b57700b9c..c7a1a274c 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -102,7 +102,6 @@ def seeded_model(data): from jax import Array, random import jax.numpy as jnp -from jax.typing import ArrayLike import numpyro from numpyro._typing import Message, TraceT @@ -444,7 +443,7 @@ class condition(Messenger): def __init__( self, fn: Optional[Callable] = None, - data: Optional[dict[str, ArrayLike]] = None, + data: Optional[dict[str, Array]] = None, condition_fn: Optional[Callable] = None, ) -> None: self.condition_fn = condition_fn @@ -583,7 +582,7 @@ class mask(Messenger): def __init__( self, fn: Optional[Callable] = None, - mask: Optional[ArrayLike] = True, + mask: Optional[Array] = True, ) -> None: if jnp.result_type(mask) != "bool": raise ValueError("`mask` should be a bool array.") @@ -677,7 +676,7 @@ class scale(Messenger): def __init__( self, fn: Optional[Callable] = None, - scale: ArrayLike = 1.0, + scale: Array = 1.0, ) -> None: if not_jax_tracer(scale): if np.any(np.less_equal(scale, 0)): @@ -975,7 +974,7 @@ class do(Messenger): def __init__( self, fn: Optional[Callable] = None, - data: Optional[dict[str, ArrayLike]] = None, + data: Optional[dict[str, Array]] = None, ) -> None: self.data = data self._intervener_id = str(id(self)) diff --git a/numpyro/optim.py b/numpyro/optim.py index 72decf79b..05b3b9610 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -12,13 +12,12 @@ from typing import Any, Literal, Optional, Protocol import jax -from jax import jacfwd, lax, value_and_grad +from jax import Array, jacfwd, lax, value_and_grad from jax.example_libraries import optimizers from jax.flatten_util import ravel_pytree import jax.numpy as jnp from jax.scipy.optimize import minimize from jax.tree_util import register_pytree_node -from jax.typing import ArrayLike __all__ = [ "Adam", @@ -34,7 +33,7 @@ _Params = Any _OptState = Any -_IterOptState = tuple[ArrayLike, _OptState] +_IterOptState = tuple[Array, _OptState] def _value_and_grad(f, x, forward_mode_differentiation: bool = False) -> tuple: # noqa: 
ANN001 @@ -55,7 +54,7 @@ class UpdateExtraArgsFn(Protocol): def __call__( self, - arr: ArrayLike, + arr: Array, params: _Params, state: _OptState, **extra_args, @@ -85,7 +84,7 @@ def init(self, params: _Params) -> _IterOptState: return jnp.array(0), opt_state def update( - self, g: _Params, state: _IterOptState, value: Optional[ArrayLike] = None + self, g: _Params, state: _IterOptState, value: Optional[Array] = None ) -> _IterOptState: """ Gradient update for the optimizer. @@ -202,7 +201,7 @@ def __init__(self, *args, clip_norm: float = 10.0, **kwargs) -> None: super(ClippedAdam, self).__init__(optimizers.adam, *args, **kwargs) def update( - self, g: _Params, state: _IterOptState, value: Optional[ArrayLike] = None + self, g: _Params, state: _IterOptState, value: Optional[Array] = None ) -> _IterOptState: i, opt_state = state # clip norm @@ -255,8 +254,8 @@ def __init__(self, *args, **kwargs) -> None: # When arbitrary pytree is supported in JAX, we can just simply use # identity functions for `init_fn` and `get_params`. class _MinimizeState(namedtuple("_MinimizeState", ["flat_params", "unravel_fn"])): - flat_params: ArrayLike - unravel_fn: Callable[[ArrayLike], _Params] + flat_params: Array + unravel_fn: Callable[[Array], _Params] register_pytree_node( @@ -276,7 +275,7 @@ def init_fn(params: _Params) -> _MinimizeState: return _MinimizeState(flat_params, unravel_fn) def update_fn( - i: ArrayLike, grad_tree: ArrayLike, opt_state: _MinimizeState + i: Array, grad_tree: Array, opt_state: _MinimizeState ) -> _MinimizeState: # we don't use update_fn in Minimize, so let it do nothing return opt_state @@ -378,10 +377,10 @@ def init_fn(params: _Params) -> tuple[_Params, Any]: return params, opt_state def update_fn( - step: ArrayLike, - grads: ArrayLike, + step: Array, + grads: Array, state: tuple[_Params, Any], - value: ArrayLike, + value: Array, ) -> tuple[_Params, Any]: params, opt_state = state updates, opt_state = optax.with_extra_args_support(transformation).update( diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 865c6a093..395cfc63c 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -11,7 +11,6 @@ import jax from jax import Array, lax, random import jax.numpy as jnp -from jax.typing import ArrayLike import numpyro from numpyro._typing import Message @@ -122,10 +121,10 @@ def __call__(self, *args, **kwargs): def _masked_observe( name: str, fn: DistributionT, - obs: Optional[ArrayLike], + obs: Optional[Array], obs_mask, # noqa: ANN001 **kwargs, -) -> ArrayLike: +) -> Array: # Split into two auxiliary sample sites. with numpyro.handlers.mask(mask=obs_mask): observed = sample(f"{name}_observed", fn, **kwargs, obs=obs) @@ -142,12 +141,12 @@ def _masked_observe( def sample( name: str, fn: DistributionT, - obs: Optional[ArrayLike] = None, - rng_key: Optional[ArrayLike] = None, + obs: Optional[Array] = None, + rng_key: Optional[Array] = None, sample_shape: tuple[int, ...] = (), infer: Optional[dict] = None, - obs_mask: Optional[ArrayLike] = None, -) -> ArrayLike: + obs_mask: Optional[Array] = None, +) -> Array: """ Returns a random sample from the stochastic function `fn`. 
This can have additional side effects when wrapped inside effect handlers like @@ -251,9 +250,9 @@ def sample( def param( name: str, - init_value: Optional[Union[ArrayLike, Callable]] = None, + init_value: Optional[Union[Array, Callable]] = None, **kwargs, -) -> Optional[ArrayLike]: +) -> Optional[Array]: """ Annotate the given site as an optimizable parameter for use with :mod:`jax.example_libraries.optimizers`. For an example of how `param` statements @@ -287,7 +286,7 @@ def param( if callable(init_value): - def fn(init_fn: Callable, *args, **kwargs) -> ArrayLike: + def fn(init_fn: Callable, *args, **kwargs) -> Array: return init_fn(prng_key()) else: @@ -310,7 +309,7 @@ def fn(init_fn: Callable, *args, **kwargs) -> ArrayLike: return msg["value"] -def deterministic(name: str, value: ArrayLike) -> ArrayLike: +def deterministic(name: str, value: Array) -> Array: """ Used to designate deterministic sites in the model. Note that most effect handlers will not operate on deterministic sites (except @@ -336,9 +335,7 @@ def deterministic(name: str, value: ArrayLike) -> ArrayLike: return msg["value"] -def mutable( - name: str, init_value: Optional[ArrayLike] = None -) -> Union[ArrayLike, None]: +def mutable(name: str, init_value: Optional[Array] = None) -> Union[Array, None]: """ This primitive is used to store a mutable value that can be changed during model execution:: @@ -394,7 +391,7 @@ def _inspect() -> dict: return msg -def get_mask() -> Union[ArrayLike, None]: +def get_mask() -> Union[Array, None]: """ Records the effects of enclosing ``handlers.mask`` handlers. This is useful for avoiding expensive ``numpyro.factor()`` computations during @@ -441,8 +438,8 @@ def module(name: str, nn: tuple, input_shape: Optional[tuple] = None) -> Callabl def _subsample_fn( - size: int, subsample_size: int, rng_key: Optional[ArrayLike] = None -) -> ArrayLike: + size: int, subsample_size: int, rng_key: Optional[Array] = None +) -> Array: if rng_key is None: raise ValueError( "Missing random key to generate subsample indices." @@ -651,7 +648,7 @@ def plate_stack( yield -def factor(name: str, log_factor: ArrayLike) -> None: +def factor(name: str, log_factor: Array) -> None: """ Factor statement to add arbitrary log probability factor to a probabilistic model. @@ -690,7 +687,7 @@ def prng_key() -> Union[Array, None]: return msg["value"] -def subsample(data: ArrayLike, event_dim: int) -> ArrayLike: +def subsample(data: Array, event_dim: int) -> Array: """ EXPERIMENTAL Subsampling statement to subsample data based on enclosing :class:`~numpyro.primitives.plate` s. 
diff --git a/pyproject.toml b/pyproject.toml index d79a3e532..3648ab50b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,9 +38,9 @@ extend-include = ["*.ipynb"] [tool.ruff.lint] select = ["ANN", "E", "F", "I", "W"] ignore = [ - "ANN002", # missing args type annotation - "ANN003", # missing kwargs type annotation - "ANN204", # missing type annotation for __call__ + "ANN002", # missing args type annotation + "ANN003", # missing kwargs type annotation + "ANN204", # missing type annotation for __call__ "E203", ] @@ -68,7 +68,9 @@ skip-magic-trailing-comma = false line-ending = "auto" [tool.ruff.lint.per-file-ignores] -"!numpyro/{diagnostics.py,handlers.py,optim.py,patch.py,primitives.py,infer/elbo.py}" = ["ANN"] # require type annotations in typed modules +"!numpyro/{diagnostics.py,handlers.py,optim.py,patch.py,primitives.py,infer/elbo.py}" = [ + "ANN", +] # require type annotations in typed modules [tool.ruff.lint.extend-per-file-ignores] "numpyro/contrib/tfp/distributions.py" = ["F811"] @@ -123,6 +125,7 @@ module = [ "numpyro.contrib.hsgp.*", "numpyro.contrib.stochastic_support.*", "numpyro.diagnostics.*", + "numpyro.distributions.*", "numpyro.handlers.*", "numpyro.infer.elbo.*", "numpyro.optim.*", diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index dea2c9383..244bc971a 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -13,7 +13,6 @@ from jax import Array, random import jax.numpy as jnp -from jax.typing import ArrayLike import numpyro from numpyro.contrib.hsgp.approximation import ( @@ -32,7 +31,7 @@ def generate_synthetic_one_dim_data( - rng_key: ArrayLike, start: float, stop: float, num: int, scale: float + rng_key: Array, start: float, stop: float, num: int, scale: float ) -> tuple[Array, Array]: x = jnp.linspace(start=start, stop=stop, num=num) y = jnp.sin(4 * jnp.pi * x) + jnp.sin(7 * jnp.pi * x) @@ -53,7 +52,7 @@ def synthetic_one_dim_data() -> tuple[Array, Array]: def generate_synthetic_two_dim_data( - rng_key: ArrayLike, start: float, stop: float, num: int, scale: float + rng_key: Array, start: float, stop: float, num: int, scale: float ) -> tuple[Array, Array]: x = random.uniform(rng_key, shape=(num, 2), minval=start, maxval=stop) y = jnp.sin(4 * jnp.pi * x[:, 0]) + jnp.sin(7 * jnp.pi * x[:, 1]) @@ -117,9 +116,9 @@ def synthetic_two_dim_data() -> tuple[Array, Array]: ], ) def test_kernel_approx_squared_exponential( - x1: ArrayLike, - x2: ArrayLike, - length: Union[float, ArrayLike], + x1: Array, + x2: Array, + length: Union[float, Array], ell: float, xfail: bool, ): @@ -201,7 +200,7 @@ def _exact_rbf(length): ], ) def test_kernel_approx_squared_matern( - x1: ArrayLike, x2: ArrayLike, nu: float, length: ArrayLike, ell: float + x1: Array, x2: Array, nu: float, length: Array, ell: float ): """ensure that the approximation of the matern kernel is accurate, matching the exact kernel implementation from sklearn. 
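The test changes in this file swap `ArrayLike` for `Array` in fixture and parameter annotations, in line with the rest of the patch. A common idiom when making that switch (illustrative only, not part of this diff; `standardize` is a made-up helper) is to keep `ArrayLike` at the public boundary and normalize once with `jnp.asarray`:

from jax import Array
import jax.numpy as jnp
from jax.typing import ArrayLike


def standardize(x: ArrayLike) -> Array:
    # Accept loose inputs (lists, NumPy arrays, Python scalars) at the boundary,
    # convert once, and work with jax.Array internally.
    x_arr = jnp.asarray(x)
    return (x_arr - x_arr.mean()) / x_arr.std()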
@@ -243,8 +242,8 @@ def _exact_matern(length): ], ) def test_kernel_approx_periodic( - x1: ArrayLike, - x2: ArrayLike, + x1: Array, + x2: Array, w0: float, length: float, ): @@ -281,7 +280,7 @@ def test_kernel_approx_periodic( ids=["non_centered", "centered", "non_centered-large-domain", "non_centered-2d"], ) def test_approximation_squared_exponential( - x: ArrayLike, + x: Array, alpha: float, length: float, ell: Union[int, float, list[Union[int, float]]], @@ -332,7 +331,7 @@ def model(x, alpha, length, ell, m, non_centered): ids=["non_centered", "centered", "non_centered-large-domain", "non_centered-2d"], ) def test_approximation_matern( - x: ArrayLike, + x: Array, nu: float, alpha: float, length: float, diff --git a/test/contrib/hsgp/test_laplacian.py b/test/contrib/hsgp/test_laplacian.py index 2dcc9b1b4..57fc898b3 100644 --- a/test/contrib/hsgp/test_laplacian.py +++ b/test/contrib/hsgp/test_laplacian.py @@ -9,8 +9,8 @@ import numpy as np import pytest +from jax import Array import jax.numpy as jnp -from jax.typing import ArrayLike from numpyro.contrib.hsgp.laplacian import ( _convert_ell, @@ -110,7 +110,7 @@ def test_sqrt_eigenvalues(ell: float | int, m: int | list[int], dim: int): ], ids=["x_pos", "x_contains_zero", "x_neg2", "x_pos2-large", "x_2d", "x_batch"], ) -def test_eigenfunctions(x: ArrayLike, ell: float | int, m: int | list[int]): +def test_eigenfunctions(x: Array, ell: float | int, m: int | list[int]): phi = eigenfunctions(x=x, ell=ell, m=m) if isinstance(m, int): m = [m] diff --git a/test/test_typing.py b/test/test_typing.py index 0ae066890..be2918d0f 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -2,14 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any +from typing import Any, Optional -import jax +from jax import Array import jax.numpy as jnp from jax.random import PRNGKey -from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, DistributionT +from numpyro._typing import ConstraintT, DistributionT, PRNGKeyT import numpyro.distributions as dist @@ -59,14 +58,14 @@ def icdf(self, q): return jnp.array(0.0) def rsample( - self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () - ) -> ArrayLike: + self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () + ) -> Array: return self.sample(key, sample_shape) - def entropy(self) -> ArrayLike: + def entropy(self) -> Array: return jnp.array(0.0) - def enumerate_support(self, expand: bool = True) -> ArrayLike: + def enumerate_support(self, expand: bool = True) -> Array: return jnp.array([]) def shape(self, sample_shape: tuple[int, ...] = ()) -> tuple[int, ...]:
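Since the protocol stubs in `test/test_typing.py` close out the patch, here is a short usage sketch (illustrative only, not part of this diff) of how the runtime-checkable `DistributionT` and `ConstraintT` protocols can be consumed. Note that `isinstance` against a runtime-checkable Protocol only verifies that the required attributes exist, not their signatures:

from jax import random
import jax.numpy as jnp

import numpyro.distributions as dist
from numpyro._typing import ConstraintT, DistributionT

d = dist.Normal(loc=jnp.zeros(3), scale=jnp.ones(3))
assert isinstance(d, DistributionT)        # structural presence check only
assert isinstance(d.support, ConstraintT)

samples = d.sample(random.PRNGKey(0), sample_shape=(5,))  # jax.Array of shape (5, 3)
log_p = d.log_prob(samples)                               # jax.Array of shape (5, 3)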