diff --git a/numpyro/_typing.py b/numpyro/_typing.py index c9ad8819a..10cc1c9e8 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -4,13 +4,16 @@ from collections import OrderedDict from collections.abc import Callable -from typing import Any, Protocol, runtime_checkable +from typing import Any, Optional, Protocol, Union, runtime_checkable +import weakref try: from typing import ParamSpec, TypeAlias except ImportError: from typing_extensions import ParamSpec, TypeAlias +import numpy as np + import jax from jax.typing import ArrayLike @@ -21,10 +24,26 @@ TraceT: TypeAlias = OrderedDict[str, Message] +NonScalarArray = Union[np.ndarray, jax.Array] +"""An alias for array-like types excluding scalars.""" + + +NumLike = Union[NonScalarArray, np.number, int, float, complex] +"""An alias for array-like types excluding `np.bool_` and `bool`.""" + + +PyTree: TypeAlias = Any +"""A generic type for a pytree, i.e. a nested structure of lists, tuples, dicts, and arrays.""" + + @runtime_checkable class ConstraintT(Protocol): - is_discrete: bool = ... - event_dim: int = ... + """A protocol for typing constraints.""" + + @property + def is_discrete(self) -> bool: ... + @property + def event_dim(self) -> int: ... def __call__(self, x: ArrayLike) -> ArrayLike: ... def __repr__(self) -> str: ... @@ -87,20 +106,24 @@ def is_discrete(self) -> bool: ... @runtime_checkable class TransformT(Protocol): - domain = ConstraintT - codomain = ConstraintT - _inv: "TransformT" = None - - def __call__(self, x: ArrayLike) -> ArrayLike: ... - def _inverse(self, y: ArrayLike) -> ArrayLike: ... - def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: ... - def call_with_intermediates(self, x: ArrayLike) -> tuple[ArrayLike, None]: ... - def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... - def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... + _inv: Optional[Union["TransformT", weakref.ref]] = ... + @property + def domain(self) -> ConstraintT: ... + @property + def codomain(self) -> ConstraintT: ... @property def inv(self) -> "TransformT": ... @property - def sign(self) -> ArrayLike: ... + def sign(self) -> NumLike: ... + + def __call__(self, x: NumLike) -> NumLike: ... + def _inverse(self, y: NumLike) -> NumLike: ... + def log_abs_det_jacobian( + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None + ) -> NumLike: ... + def call_with_intermediates( + self, x: NumLike + ) -> tuple[NumLike, Optional[PyTree]]: ... + def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... + def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index de0ad382e..b25d050d5 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -80,7 +80,7 @@ def _vmap_over_affine_transform( dist_axes = copy.copy(dist) dist_axes.loc = loc dist_axes.scale = scale - dist_axes.domain = domain + dist_axes._domain = domain return dist_axes diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 7671b0685..417b5f260 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -73,7 +73,7 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ConstraintT +from numpyro._typing import ConstraintT, NonScalarArray, NumLike class Constraint(object): @@ -153,7 +153,7 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: class _CorrCholesky(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tril = jnp.tril(x) lower_triangular = jnp.all( @@ -165,7 +165,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: unit_norm_row = jnp.all(jnp.abs(x_norm - 1) <= tol, axis=-1) return lower_triangular & positive_diagonal & unit_norm_row - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -174,7 +174,7 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: class _CorrMatrix(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) @@ -186,7 +186,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: ) return symmetric & positive & unit_variance - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -284,10 +284,10 @@ def is_dependent(constraint): class _GreaterThan(Constraint): - def __init__(self, lower_bound): + def __init__(self, lower_bound: NumLike) -> None: self.lower_bound = lower_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return x > self.lower_bound def __repr__(self) -> str: @@ -295,7 +295,7 @@ def __repr__(self) -> str: fmt_string += "(lower_bound={})".format(self.lower_bound) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.broadcast_to(self.lower_bound + 1, jax.numpy.shape(prototype)) def tree_flatten(self): @@ -308,7 +308,7 @@ def __eq__(self, other: ConstraintT) -> bool: class _GreaterThanEq(_GreaterThan): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return x >= self.lower_bound def __eq__(self, other: ConstraintT) -> bool: @@ -347,6 +347,10 @@ def __init__(self, base_constraint, reinterpreted_batch_ndims): self.reinterpreted_batch_ndims = reinterpreted_batch_ndims super().__init__() + @property + def is_discrete(self) -> bool: + return self.base_constraint.is_discrete + @property def event_dim(self) -> int: return self.base_constraint.event_dim + self.reinterpreted_batch_ndims @@ -405,10 +409,10 @@ def __init__(self) -> None: class _LessThan(Constraint): - def __init__(self, upper_bound: ArrayLike) -> None: + def __init__(self, upper_bound: NumLike) -> None: self.upper_bound = upper_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return x < self.upper_bound def __repr__(self) -> str: @@ -416,7 +420,7 @@ def __repr__(self) -> str: fmt_string += "(upper_bound={})".format(self.upper_bound) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.broadcast_to(self.upper_bound - 1, jax.numpy.shape(prototype)) def tree_flatten(self): @@ -429,7 +433,7 @@ def __eq__(self, other: ConstraintT) -> bool: class _LessThanEq(_LessThan): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return x <= self.upper_bound def __eq__(self, other: ConstraintT) -> bool: @@ -441,11 +445,11 @@ def __eq__(self, other: ConstraintT) -> bool: class _IntegerInterval(Constraint): is_discrete = True - def __init__(self, lower_bound: ArrayLike, upper_bound: ArrayLike) -> None: + def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: self.lower_bound = lower_bound self.upper_bound = upper_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return (x >= self.lower_bound) & (x <= self.upper_bound) & (x % 1 == 0) def __repr__(self) -> str: @@ -455,7 +459,7 @@ def __repr__(self) -> str: ) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype)) def tree_flatten(self): @@ -476,10 +480,10 @@ def __eq__(self, other: ConstraintT) -> bool: class _IntegerGreaterThan(Constraint): is_discrete = True - def __init__(self, lower_bound: ArrayLike) -> None: + def __init__(self, lower_bound: NumLike) -> None: self.lower_bound = lower_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return (x % 1 == 0) & (x >= self.lower_bound) def __repr__(self) -> str: @@ -487,7 +491,7 @@ def __repr__(self) -> str: fmt_string += "(lower_bound={})".format(self.lower_bound) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype)) def tree_flatten(self): @@ -510,11 +514,11 @@ def __init__(self) -> None: class _Interval(Constraint): - def __init__(self, lower_bound: ArrayLike, upper_bound: ArrayLike) -> None: + def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: self.lower_bound = lower_bound self.upper_bound = upper_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return (x >= self.lower_bound) & (x <= self.upper_bound) def __repr__(self) -> str: @@ -524,7 +528,7 @@ def __repr__(self) -> str: ) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.broadcast_to( (self.lower_bound + self.upper_bound) / 2, jax.numpy.shape(prototype) ) @@ -554,7 +558,7 @@ def __init__(self) -> None: class _OpenInterval(_Interval): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return (x > self.lower_bound) & (x < self.upper_bound) def __repr__(self) -> str: @@ -568,7 +572,7 @@ def __repr__(self) -> str: class _LowerCholesky(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tril = jnp.tril(x) lower_triangular = jnp.all( @@ -577,7 +581,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1) return lower_triangular & positive_diagonal - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -590,10 +594,10 @@ class _Multinomial(Constraint): def __init__(self, upper_bound: ArrayLike) -> None: self.upper_bound = upper_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: return (x >= 0).all(axis=-1) & (x.sum(axis=-1) == self.upper_bound) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: pad_width = ((0, 0),) * jax.numpy.ndim(self.upper_bound) + ( (0, prototype.shape[-1] - 1), ) @@ -617,22 +621,22 @@ class _L1Ball(_SingletonConstraint): event_dim = 1 reltol = 10.0 # Relative to finfo.eps. - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy eps = jnp.finfo(x.dtype).eps return jnp.abs(x).sum(axis=-1) < 1 + self.reltol * eps - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) class _OrderedVector(_SingletonConstraint): event_dim = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: return (x[..., 1:] > x[..., :-1]).all(axis=-1) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.arange(float(prototype.shape[-1])), prototype.shape ) @@ -641,7 +645,7 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: class _PositiveDefinite(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) @@ -649,7 +653,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: positive = jnp.linalg.eigh(x)[0][..., 0] > 0 return symmetric & positive - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -658,20 +662,20 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: class _PositiveDefiniteCirculantVector(_SingletonConstraint): event_dim = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tol = 10 * jnp.finfo(x.dtype).eps rfft = jnp.fft.rfft(x) return (jnp.abs(rfft.imag) < tol) & (rfft.real > -tol) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jnp.zeros_like(prototype).at[..., 0].set(1.0) class _PositiveSemiDefinite(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) @@ -679,7 +683,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: nonnegative = jnp.linalg.eigh(x)[0][..., 0] >= 0 return symmetric & nonnegative - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -693,41 +697,41 @@ class _PositiveOrderedVector(_SingletonConstraint): event_dim = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: return ordered_vector.check(x) & independent(positive, 1).check(x) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.exp(jax.numpy.arange(float(prototype.shape[-1]))), prototype.shape ) class _Complex(_SingletonConstraint): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: # XXX: consider to relax this condition to [-inf, inf] interval return (x == x) & (x != float("inf")) & (x != float("-inf")) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) class _Real(_SingletonConstraint): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: # XXX: consider to relax this condition to [-inf, inf] interval return (x == x) & (x != float("inf")) & (x != float("-inf")) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) class _Simplex(_SingletonConstraint): event_dim = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: x_sum = x.sum(axis=-1) return (x >= 0).all(axis=-1) & (x_sum < 1 + 1e-6) & (x_sum > 1 - 1e-6) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.full_like(prototype, 1 / prototype.shape[-1]) @@ -735,12 +739,12 @@ class _SoftplusPositive(_SingletonConstraint, _GreaterThan): def __init__(self) -> None: super().__init__(lower_bound=0.0) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.full(jax.numpy.shape(prototype), np.log(2)) class _SoftplusLowerCholesky(_LowerCholesky): - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]) * np.log(2), prototype.shape ) @@ -758,14 +762,14 @@ class _Sphere(_SingletonConstraint): event_dim = 1 reltol = 10.0 # Relative to finfo.eps. - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy eps = jnp.finfo(x.dtype).eps norm = jnp.linalg.norm(x, axis=-1) error = jnp.abs(norm - 1) return error < self.reltol * eps * x.shape[-1] ** 0.5 - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.full_like(prototype, prototype.shape[-1] ** (-0.5)) @@ -774,7 +778,7 @@ def __init__(self, event_dim: int = 1) -> None: self.event_dim = event_dim super().__init__() - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tol = jnp.finfo(x.dtype).eps * x.shape[-1] * 10 zerosum_true = True @@ -785,7 +789,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: def __eq__(self, other: ConstraintT) -> bool: return type(self) is type(other) and self.event_dim == other.event_dim - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.zeros_like(prototype) def tree_flatten(self): @@ -801,18 +805,18 @@ def tree_flatten(self): corr_cholesky: ConstraintT = _CorrCholesky() corr_matrix: ConstraintT = _CorrMatrix() dependent: ConstraintT = _Dependent() -greater_than: ConstraintT = _GreaterThan -greater_than_eq: ConstraintT = _GreaterThanEq -less_than: ConstraintT = _LessThan -less_than_eq: ConstraintT = _LessThanEq -independent: ConstraintT = _IndependentConstraint -integer_interval: ConstraintT = _IntegerInterval -integer_greater_than: ConstraintT = _IntegerGreaterThan -interval: ConstraintT = _Interval +greater_than = _GreaterThan +greater_than_eq = _GreaterThanEq +less_than = _LessThan +less_than_eq = _LessThanEq +independent = _IndependentConstraint +integer_interval = _IntegerInterval +integer_greater_than = _IntegerGreaterThan +interval = _Interval l1_ball: ConstraintT = _L1Ball() lower_cholesky: ConstraintT = _LowerCholesky() scaled_unit_lower_cholesky: ConstraintT = _ScaledUnitLowerCholesky() -multinomial: ConstraintT = _Multinomial +multinomial = _Multinomial nonnegative: ConstraintT = _Nonnegative() nonnegative_integer: ConstraintT = _IntegerNonnegative() ordered_vector: ConstraintT = _OrderedVector() @@ -830,5 +834,5 @@ def tree_flatten(self): softplus_positive: ConstraintT = _SoftplusPositive() sphere: ConstraintT = _Sphere() unit_interval: ConstraintT = _UnitInterval() -open_interval: ConstraintT = _OpenInterval -zero_sum: ConstraintT = _ZeroSum +open_interval = _OpenInterval +zero_sum = _ZeroSum diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 082b2df3a..4512225c8 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1,8 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 + import math -from typing import Optional, Sequence, Tuple +from typing import Generic, Optional, Sequence, Tuple, TypeVar, Union, cast import warnings import weakref @@ -17,7 +18,7 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import TransformT +from numpyro._typing import ConstraintT, NonScalarArray, NumLike, PyTree, TransformT from numpyro.distributions import constraints from numpyro.distributions.util import ( add_diag, @@ -59,44 +60,51 @@ ] -def _clipped_expit(x: ArrayLike) -> ArrayLike: +def _clipped_expit(x: NumLike) -> NumLike: finfo = jnp.finfo(jnp.result_type(x)) return jnp.clip(expit(x), finfo.tiny, 1.0 - finfo.eps) -class Transform(object): - domain = constraints.real - codomain = constraints.real - _inv = None +NumLikeT = TypeVar("NumLikeT", bound=NumLike) + + +class Transform(Generic[NumLikeT]): + _inv: Optional[Union[TransformT, weakref.ref]] = None + + @property + def domain(self) -> ConstraintT: + return constraints.real + + @property + def codomain(self) -> ConstraintT: + return constraints.real def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) @property - def inv(self) -> TransformT: + def inv(self: TransformT) -> TransformT: inv = None - if self._inv is not None: + if (self._inv is not None) and isinstance(self._inv, weakref.ref): inv = self._inv() if inv is None: inv = _InverseTransform(self) self._inv = weakref.ref(inv) - return inv + return cast(TransformT, inv) - def __call__(self, x: ArrayLike) -> ArrayLike: - raise NotImplementedError + def __call__(self, x: NumLikeT) -> NumLike: + raise NotImplementedError() - def _inverse(self, y: ArrayLike) -> ArrayLike: - raise NotImplementedError + def _inverse(self, y: NumLikeT) -> NumLike: + raise NotImplementedError() def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: - raise NotImplementedError + self, x: NumLikeT, y: NumLikeT, intermediates: Optional[PyTree] = None + ) -> NumLike: + raise NotImplementedError() - def call_with_intermediates( - self, x: ArrayLike - ) -> Tuple[ArrayLike, Optional[ArrayLike]]: + def call_with_intermediates(self, x: NumLikeT) -> Tuple[NumLike, Optional[PyTree]]: return self(x), None def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -114,7 +122,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape @property - def sign(self) -> ArrayLike: + def sign(self) -> NumLike: """ Sign of the derivative of the transform if it is bijective. """ @@ -144,41 +152,41 @@ def tree_unflatten(cls, aux_data, params): return self -class ParameterFreeTransform(Transform): +class ParameterFreeTransform(Transform[NumLikeT]): def tree_flatten(self): return (), ((), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) -class _InverseTransform(Transform): - def __init__(self, transform: TransformT) -> None: +class _InverseTransform(Transform[NumLike]): + def __init__(self, transform: TransformT): super().__init__() - self._inv = transform + self._inv: TransformT = transform @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: return self._inv.codomain @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: return self._inv.domain @property - def sign(self) -> ArrayLike: + def sign(self) -> NumLike: return self._inv.sign @property def inv(self) -> TransformT: return self._inv - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return self._inv._inverse(x) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None + ) -> NumLike: # NB: we don't use intermediates for inverse transform return -self._inv.log_abs_det_jacobian(y, x, None) @@ -191,23 +199,23 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self._inv,), (("_inv",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _InverseTransform): return False return self._inv == other._inv -class AbsTransform(ParameterFreeTransform): +class AbsTransform(ParameterFreeTransform[NumLike]): domain = constraints.real codomain = constraints.positive - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, AbsTransform) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return jnp.abs(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: warnings.warn( "AbsTransform is not a bijective transform." " The inverse of `y` will be `y`.", @@ -216,24 +224,25 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return y -class AffineTransform(Transform): +class AffineTransform(Transform[NumLike]): """ .. note:: When `scale` is a JAX tracer, we always assume that `scale > 0` when calculating `codomain`. """ def __init__( - self, - loc: ArrayLike, - scale: ArrayLike, - domain: constraints.Constraint = constraints.real, + self, loc: NumLike, scale: NumLike, domain: ConstraintT = constraints.real ): self.loc = loc self.scale = scale - self.domain = domain + self._domain = domain + + @property + def domain(self) -> ConstraintT: + return self._domain @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: if self.domain is constraints.real: return constraints.real elif isinstance(self.domain, constraints.greater_than): @@ -261,18 +270,18 @@ def codomain(self) -> constraints.Constraint: raise NotImplementedError @property - def sign(self) -> ArrayLike: + def sign(self) -> NumLike: return jnp.sign(self.scale) - def __call__(self, x: ArrayLike) -> ArrayLike: - return self.loc + self.scale * x + def __call__(self, x: NumLike) -> NumLike: + return self.loc + jnp.multiply(self.scale, x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return (y - self.loc) / self.scale def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None + ) -> NumLike: return jnp.broadcast_to(jnp.log(jnp.abs(self.scale)), jnp.shape(x)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -286,16 +295,20 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ) def tree_flatten(self): - return (self.loc, self.scale, self.domain), (("loc", "scale", "domain"), dict()) + return (self.loc, self.scale, self.domain), ( + ("loc", "scale", "_domain"), + dict(), + ) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, AffineTransform): return False - return ( + is_equal = ( jnp.array_equal(self.loc, other.loc) & jnp.array_equal(self.scale, other.scale) & (self.domain == other.domain) ) + return is_equal # type: ignore[return-value] def _get_compose_transform_input_event_dim(parts): @@ -316,54 +329,60 @@ def _get_compose_transform_output_event_dim(parts): return output_event_dim -class ComposeTransform(Transform): +class ComposeTransform(Transform[NumLike]): def __init__(self, parts: Sequence[TransformT]) -> None: self.parts = parts @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: input_event_dim = _get_compose_transform_input_event_dim(self.parts) first_input_event_dim = self.parts[0].domain.event_dim assert input_event_dim >= first_input_event_dim if input_event_dim == first_input_event_dim: return self.parts[0].domain else: - return constraints.independent( - self.parts[0].domain, input_event_dim - first_input_event_dim + return cast( + ConstraintT, + constraints.independent( + self.parts[0].domain, input_event_dim - first_input_event_dim + ), ) @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: output_event_dim = _get_compose_transform_output_event_dim(self.parts) last_output_event_dim = self.parts[-1].codomain.event_dim assert output_event_dim >= last_output_event_dim if output_event_dim == last_output_event_dim: return self.parts[-1].codomain else: - return constraints.independent( - self.parts[-1].codomain, output_event_dim - last_output_event_dim + return cast( + ConstraintT, + constraints.independent( + self.parts[-1].codomain, output_event_dim - last_output_event_dim + ), ) @property - def sign(self) -> ArrayLike: - sign = 1 + def sign(self) -> NumLike: + sign: NumLike = 1 for transform in self.parts: sign *= transform.sign return sign - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: for part in self.parts: x = part(x) return x - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: for part in self.parts[::-1]: y = part.inv(y) return y def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None + ) -> NumLike: if intermediates is not None: if len(intermediates) != len(self.parts): raise ValueError( @@ -389,10 +408,8 @@ def log_abs_det_jacobian( result = result + sum_rightmost(logdet, input_event_dim - part.domain.event_dim) return result - def call_with_intermediates( - self, x: ArrayLike - ) -> Tuple[ArrayLike, Optional[ArrayLike]]: - intermediates = [] + def call_with_intermediates(self, x: NumLike) -> Tuple[NumLike, Optional[PyTree]]: + intermediates: list[Optional[PyTree]] = [] for part in self.parts[:-1]: x, inter = part.call_with_intermediates(x) intermediates.append([x, inter]) @@ -414,10 +431,10 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.parts,), (("parts",), {}) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, ComposeTransform): return False - return jnp.logical_and(*(p1 == p2 for p1, p2 in zip(self.parts, other.parts))) + return jnp.logical_and(*(p1 == p2 for p1, p2 in zip(self.parts, other.parts))) # type: ignore[return-value] def _matrix_forward_shape(shape: tuple[int, ...], offset: int = 0) -> tuple[int, ...]: @@ -443,7 +460,7 @@ def _matrix_inverse_shape(shape: tuple[int, ...], offset: int = 0) -> tuple[int, return shape[:-2] + (N,) -class CholeskyTransform(ParameterFreeTransform): +class CholeskyTransform(ParameterFreeTransform[NonScalarArray]): r""" Transform via the mapping :math:`y = cholesky(x)`, where `x` is a positive definite matrix. @@ -452,15 +469,18 @@ class CholeskyTransform(ParameterFreeTransform): domain = constraints.positive_definite codomain = constraints.lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return jnp.linalg.cholesky(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: return jnp.matmul(y, jnp.swapaxes(y, -2, -1)) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13 n = jnp.shape(x)[-1] order = -jnp.arange(n, 0, -1) @@ -469,7 +489,7 @@ def log_abs_det_jacobian( ) -class CorrCholeskyTransform(ParameterFreeTransform): +class CorrCholeskyTransform(ParameterFreeTransform[NonScalarArray]): r""" Transforms a unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower @@ -499,12 +519,12 @@ class :class:`StickBreakingTransform` to transform :math:`X_i` into a domain = constraints.real_vector codomain = constraints.corr_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: # we interchange step 1 and step 2.a for a better performance t = jnp.tanh(x) return signed_stick_breaking_tril(t) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: # inverse stick-breaking z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1) pad_width = [(0, 0)] * y.ndim @@ -519,8 +539,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return jnp.arctanh(t) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # NB: because domain and codomain are two spaces with different dimensions, determinant of # Jacobian is not well-defined. Here we return `log_abs_det_jacobian` of `x` and the # flatten lower triangular part of `y`. @@ -552,24 +575,31 @@ class CorrMatrixCholeskyTransform(CholeskyTransform): codomain = constraints.corr_cholesky def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # NB: see derivation in LKJCholesky implementation n = jnp.shape(x)[-1] order = -jnp.arange(n - 1, -1, -1) return jnp.sum(order * jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1)), axis=-1) -class ExpTransform(Transform): +class ExpTransform(Transform[NumLike]): sign = 1 # TODO: refine domain/codomain logic through setters, especially when # transforms for inverses are supported - def __init__(self, domain=constraints.real): - self.domain = domain + def __init__(self, domain: ConstraintT = constraints.real): + self._domain = domain + + @property + def domain(self) -> ConstraintT: + return self._domain @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: if self.domain is constraints.ordered_vector: return constraints.positive_ordered_vector elif self.domain is constraints.real: @@ -584,43 +614,43 @@ def codomain(self) -> constraints.Constraint: else: raise NotImplementedError - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: # XXX consider to clamp from below for stability if necessary return jnp.exp(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return jnp.log(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None + ) -> NumLike: return x def tree_flatten(self): - return (self.domain,), (("domain",), dict()) + return (self.domain,), (("_domain",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, ExpTransform): return False return self.domain == other.domain -class IdentityTransform(ParameterFreeTransform): +class IdentityTransform(ParameterFreeTransform[NumLike]): sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return x - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return y def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None + ) -> NumLike: return jnp.zeros_like(x) -class IndependentTransform(Transform): +class IndependentTransform(Transform[NumLike]): """ Wraps a transform by aggregating over ``reinterpreted_batch_ndims``-many dims in :meth:`check`, so that an event is valid only if all its @@ -638,26 +668,32 @@ def __init__( super().__init__() @property - def domain(self) -> constraints.Constraint: - return constraints.independent( - self.base_transform.domain, self.reinterpreted_batch_ndims + def domain(self) -> ConstraintT: + return cast( + ConstraintT, + constraints.independent( + self.base_transform.domain, self.reinterpreted_batch_ndims + ), ) @property - def codomain(self) -> constraints.Constraint: - return constraints.independent( - self.base_transform.codomain, self.reinterpreted_batch_ndims + def codomain(self) -> ConstraintT: + return cast( + ConstraintT, + constraints.independent( + self.base_transform.codomain, self.reinterpreted_batch_ndims + ), ) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return self.base_transform(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return self.base_transform._inverse(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None + ) -> NumLike: result = self.base_transform.log_abs_det_jacobian( x, y, intermediates=intermediates ) @@ -666,9 +702,7 @@ def log_abs_det_jacobian( raise ValueError(f"Expected x.dim() >= {expected} but got {jnp.ndim(x)}") return sum_rightmost(result, self.reinterpreted_batch_ndims) - def call_with_intermediates( - self, x: ArrayLike - ) -> Tuple[ArrayLike, Optional[ArrayLike]]: + def call_with_intermediates(self, x: NumLike) -> Tuple[NumLike, Optional[PyTree]]: return self.base_transform.call_with_intermediates(x) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -683,7 +717,7 @@ def tree_flatten(self): dict(), ) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, IndependentTransform): return False return (self.base_transform == other.base_transform) & ( @@ -691,7 +725,7 @@ def __eq__(self, other: TransformT) -> bool: ) -class L1BallTransform(ParameterFreeTransform): +class L1BallTransform(ParameterFreeTransform[NonScalarArray]): r""" Transforms a unconstrained real vector :math:`x` into the unit L1 ball. """ @@ -699,7 +733,7 @@ class L1BallTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.l1_ball - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: # transform to (-1, 1) interval t = jnp.tanh(x) @@ -709,7 +743,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0) return t * remainder - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: # inverse stick-breaking remainder = 1 - jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1) pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)] @@ -723,8 +757,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return jnp.arctanh(t) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # compute stick-breaking logdet # t1 -> t1 # t2 -> t2 * (1 - abs(t1)) @@ -741,7 +778,7 @@ def log_abs_det_jacobian( return stick_breaking_logdet + tanh_logdet -class LowerCholeskyAffine(Transform): +class LowerCholeskyAffine(Transform[NonScalarArray]): r""" Transform via the mapping :math:`y = loc + scale\_tril\ @\ x`. @@ -765,7 +802,7 @@ class LowerCholeskyAffine(Transform): domain = constraints.real_vector codomain = constraints.real_vector - def __init__(self, loc: ArrayLike, scale_tril: Array): + def __init__(self, loc: NonScalarArray, scale_tril: NonScalarArray): if jnp.ndim(scale_tril) != 2: raise ValueError( "Only support 2-dimensional scale_tril matrix. " @@ -775,12 +812,12 @@ def __init__(self, loc: ArrayLike, scale_tril: Array): self.loc = loc self.scale_tril = scale_tril - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return self.loc + jnp.squeeze( jnp.matmul(self.scale_tril, x[..., jnp.newaxis]), axis=-1 ) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: y = y - self.loc original_shape = jnp.shape(y) yt = jnp.reshape(y, (-1, original_shape[-1])).T @@ -788,8 +825,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return jnp.reshape(xt.T, original_shape) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.broadcast_to( jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1), jnp.shape(x)[:-1], @@ -808,15 +848,15 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.loc, self.scale_tril), (("loc", "scale_tril"), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, LowerCholeskyAffine): return False return jnp.array_equal(self.loc, other.loc) & jnp.array_equal( self.scale_tril, other.scale_tril - ) + ) # type: ignore[return-value] -class LowerCholeskyTransform(ParameterFreeTransform): +class LowerCholeskyTransform(ParameterFreeTransform[NonScalarArray]): """ Transform a real vector to a lower triangular cholesky factor, where the strictly lower triangular submatrix is @@ -827,21 +867,24 @@ class LowerCholeskyTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = jnp.exp(x[..., -n:]) return add_diag(z, diag) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: z = matrix_to_tril_vec(y, diagonal=-1) return jnp.concatenate( [z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1 ) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # the jacobian is diagonal, so logdet is the sum of diagonal `exp` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) return x[..., -n:].sum(-1) @@ -869,27 +912,30 @@ class ScaledUnitLowerCholeskyTransform(LowerCholeskyTransform): domain = constraints.real_vector codomain = constraints.scaled_unit_lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = softplus(x[..., -n:]) - return add_diag(z, 1) * diag[..., None] + return add_diag(z, jnp.array(1)) * diag[..., None] - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: diag = jnp.diagonal(y, axis1=-2, axis2=-1) z = matrix_to_tril_vec(y / diag[..., None], diagonal=-1) return jnp.concatenate([z, _softplus_inv(diag)], axis=-1) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) diag = x[..., -n:] diag_softplus = jnp.diagonal(y, axis1=-2, axis2=-1) return (jnp.log(diag_softplus) * jnp.arange(n) - softplus(-diag)).sum(-1) -class OrderedTransform(ParameterFreeTransform): +class OrderedTransform(ParameterFreeTransform[NonScalarArray]): """ Transform a real vector to an ordered vector. @@ -913,33 +959,36 @@ class OrderedTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.ordered_vector - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: z = jnp.concatenate([x[..., :1], jnp.exp(x[..., 1:])], axis=-1) return jnp.cumsum(z, axis=-1) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: x = jnp.log(y[..., 1:] - y[..., :-1]) return jnp.concatenate([y[..., :1], x], axis=-1) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.sum(x[..., 1:], -1) -class PermuteTransform(Transform): +class PermuteTransform(Transform[NonScalarArray]): domain = constraints.real_vector codomain = constraints.real_vector def __init__(self, permutation: Array) -> None: self.permutation = permutation - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return x[..., self.permutation] - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: size = self.permutation.size - permutation_inv = ( + permutation_inv: NonScalarArray = ( jnp.zeros(size, dtype=jnp.result_type(int)) .at[self.permutation] .set(jnp.arange(size)) @@ -947,36 +996,39 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return y[..., permutation_inv] def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.full(jnp.shape(x)[:-1], 0.0) def tree_flatten(self): return (self.permutation,), (("permutation",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, PermuteTransform): return False - return jnp.array_equal(self.permutation, other.permutation) + return jnp.array_equal(self.permutation, other.permutation) # type: ignore[return-value] -class PowerTransform(Transform): +class PowerTransform(Transform[NumLike]): domain = constraints.positive codomain = constraints.positive - def __init__(self, exponent: ArrayLike) -> None: + def __init__(self, exponent: NumLike) -> None: self.exponent = exponent - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return jnp.power(x, self.exponent) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return jnp.power(y, 1 / self.exponent) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: - return jnp.log(jnp.abs(self.exponent * y / x)) + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None + ) -> NumLike: + return jnp.log(jnp.abs(jnp.multiply(self.exponent, y) / x)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return lax.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) @@ -987,33 +1039,33 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.exponent,), (("exponent",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, PowerTransform): return False - return jnp.array_equal(self.exponent, other.exponent) + return jnp.array_equal(self.exponent, other.exponent) # type: ignore[return-value] @property - def sign(self) -> ArrayLike: + def sign(self) -> NumLike: return jnp.sign(self.exponent) -class SigmoidTransform(ParameterFreeTransform): +class SigmoidTransform(ParameterFreeTransform[NumLike]): codomain = constraints.unit_interval sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return _clipped_expit(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return logit(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None + ) -> NumLike: return -softplus(x) - softplus(-x) -class SimplexToOrderedTransform(Transform): +class SimplexToOrderedTransform(Transform[NonScalarArray]): """ Transform a simplex into an ordered vector (via difference in Logistic CDF between cutpoints) Used in [1] to induce a prior on latent cutpoints via transforming ordered category probabilities. @@ -1045,12 +1097,12 @@ class SimplexToOrderedTransform(Transform): def __init__(self, anchor_point: ArrayLike = 0.0) -> None: self.anchor_point = anchor_point - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: s = jnp.cumsum(x[..., :-1], axis=-1) y = logit(s) + jnp.expand_dims(self.anchor_point, -1) return y - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: y = y - jnp.expand_dims(self.anchor_point, -1) s = expit(y) # x0 = s0, x1 = s1 - s0, x2 = s2 - s1,..., xn = 1 - s[n-1] @@ -1061,8 +1113,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return x def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # |dp/dc| = |dx/dy| = prod(ds/dy) = prod(expit'(y)) # we know log derivative of expit(y) is `-softplus(y) - softplus(-y)` J_logdet = (softplus(y) + softplus(-y)).sum(-1) @@ -1071,10 +1126,10 @@ def log_abs_det_jacobian( def tree_flatten(self): return (self.anchor_point,), (("anchor_point",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, SimplexToOrderedTransform): return False - return jnp.array_equal(self.anchor_point, other.anchor_point) + return jnp.array_equal(self.anchor_point, other.anchor_point) # type: ignore[return-value] def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (shape[-1] - 1,) @@ -1083,11 +1138,11 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (shape[-1] + 1,) -def _softplus_inv(y: ArrayLike) -> ArrayLike: +def _softplus_inv(y: NumLike) -> NumLike: return jnp.log(-jnp.expm1(-y)) + y -class SoftplusTransform(ParameterFreeTransform): +class SoftplusTransform(ParameterFreeTransform[NumLike]): r""" Transform from unconstrained space to positive domain via softplus :math:`y = \log(1 + \exp(x))`. The inverse is computed as :math:`x = \log(\exp(y) - 1)`. @@ -1097,19 +1152,19 @@ class SoftplusTransform(ParameterFreeTransform): codomain = constraints.softplus_positive sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return softplus(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return _softplus_inv(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None + ) -> NumLike: return -softplus(-x) -class SoftplusLowerCholeskyTransform(ParameterFreeTransform): +class SoftplusLowerCholeskyTransform(ParameterFreeTransform[NonScalarArray]): """ Transform from unconstrained vector to lower-triangular matrices with nonnegative diagonal entries. This is useful for parameterizing positive @@ -1119,20 +1174,23 @@ class SoftplusLowerCholeskyTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.softplus_lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = softplus(x[..., -n:]) return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: z = matrix_to_tril_vec(y, diagonal=-1) diag = _softplus_inv(jnp.diagonal(y, axis1=-2, axis2=-1)) return jnp.concatenate([z, diag], axis=-1) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # the jacobian is diagonal, so logdet is the sum of diagonal # `softplus` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) @@ -1145,11 +1203,11 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return _matrix_inverse_shape(shape) -class StickBreakingTransform(ParameterFreeTransform): +class StickBreakingTransform(ParameterFreeTransform[NonScalarArray]): domain = constraints.real_vector codomain = constraints.simplex - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: # we shift x to obtain a balanced mapping (0, 0, ..., 0) -> (1/K, 1/K, ..., 1/K) x = x - jnp.log(x.shape[-1] - jnp.arange(x.shape[-1])) # convert to probabilities (relative to the remaining) of each fraction of the stick @@ -1165,7 +1223,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: ) return z_padded * z1m_cumprod_shifted - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: y_crop = y[..., :-1] z1m_cumprod = jnp.clip(1 - jnp.cumsum(y_crop, axis=-1), jnp.finfo(y.dtype).tiny) # hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod @@ -1173,8 +1231,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return x + jnp.log(x.shape[-1] - jnp.arange(x.shape[-1])) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html # |det|(J) = Product(y * (1 - sigmoid(x))) # = Product(y * sigmoid(x) * exp(-x)) @@ -1192,7 +1253,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (shape[-1] - 1,) -class UnpackTransform(Transform): +class UnpackTransform(Transform[NonScalarArray]): """ Transforms a contiguous array to a pytree of subarrays. @@ -1207,7 +1268,7 @@ def __init__(self, unpack_fn, pack_fn=None): self.unpack_fn = unpack_fn self.pack_fn = pack_fn - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: batch_shape = x.shape[:-1] if batch_shape: unpacked = vmap(self.unpack_fn)(x.reshape((-1,) + x.shape[-1:])) @@ -1217,7 +1278,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: else: return self.unpack_fn(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: if self.pack_fn is None: raise NotImplementedError( "pack_fn needs to be provided to perform UnpackTransform.inv." @@ -1238,8 +1299,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return self.pack_fn(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.zeros(jnp.shape(x)[:-1]) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -1252,7 +1316,7 @@ def tree_flatten(self): # XXX: what if unpack_fn is a parametrized callable pytree? return (), ((), {"unpack_fn": self.unpack_fn, "pack_fn": self.pack_fn}) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, UnpackTransform) and (self.unpack_fn is other.unpack_fn) @@ -1269,7 +1333,7 @@ def _get_target_shape( return shape[:batch_ndims] + forward_shape -class ReshapeTransform(Transform): +class ReshapeTransform(Transform[NonScalarArray]): """ Reshape a sample, leaving batch dimensions unchanged. @@ -1293,12 +1357,18 @@ def __init__( self._inverse_shape = inverse_shape @property - def domain(self) -> constraints.Constraint: - return constraints.independent(constraints.real, len(self._inverse_shape)) + def domain(self) -> ConstraintT: + return cast( + ConstraintT, + constraints.independent(constraints.real, len(self._inverse_shape)), + ) @property - def codomain(self) -> constraints.Constraint: - return constraints.independent(constraints.real, len(self._forward_shape)) + def codomain(self) -> ConstraintT: + return cast( + ConstraintT, + constraints.independent(constraints.real, len(self._forward_shape)), + ) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return _get_target_shape(shape, self._forward_shape, self._inverse_shape) @@ -1306,15 +1376,18 @@ def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return _get_target_shape(shape, self._inverse_shape, self._forward_shape) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return jnp.reshape(x, self.forward_shape(jnp.shape(x))) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: return jnp.reshape(y, self.inverse_shape(jnp.shape(y))) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.zeros_like(x, shape=x.shape[: x.ndim - len(self._inverse_shape)]) def tree_flatten(self): @@ -1324,7 +1397,7 @@ def tree_flatten(self): } return (), ((), aux_data) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, ReshapeTransform) and self._forward_shape == other._forward_shape @@ -1334,14 +1407,14 @@ def __eq__(self, other: TransformT) -> bool: def _normalize_rfft_shape( input_shape: tuple[int, ...], - shape: tuple[int, ...], + shape: Optional[tuple[int, ...]], ) -> tuple[int, ...]: if shape is None: return input_shape return input_shape[: len(input_shape) - len(shape)] + shape -class RealFastFourierTransform(Transform): +class RealFastFourierTransform(Transform[NonScalarArray]): """ N-dimensional discrete fast Fourier transform for real input. @@ -1365,11 +1438,11 @@ def __init__( self.transform_shape = transform_shape self.transform_ndims = transform_ndims - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: NonScalarArray) -> NonScalarArray: axes = tuple(range(-self.transform_ndims, 0)) return jnp.fft.rfftn(x, self.transform_shape, axes) - def _inverse(self, y: jnp.ndarray) -> jnp.ndarray: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: axes = tuple(range(-self.transform_ndims, 0)) return jnp.fft.irfftn(y, self.transform_shape, axes) @@ -1385,8 +1458,11 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (size,) def log_abs_det_jacobian( - self, x: Array, y: Array, intermediates: None = None - ) -> jnp.ndarray: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: batch_shape = jnp.broadcast_shapes( x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] ) @@ -1405,14 +1481,19 @@ def tree_flatten(self): return (), ((), aux_data) @property - def domain(self) -> constraints.Constraint: - return constraints.independent(constraints.real, self.transform_ndims) + def domain(self) -> ConstraintT: + return cast( + ConstraintT, constraints.independent(constraints.real, self.transform_ndims) + ) @property - def codomain(self) -> constraints.Constraint: - return constraints.independent(constraints.complex, self.transform_ndims) + def codomain(self) -> ConstraintT: + return cast( + ConstraintT, + constraints.independent(constraints.complex, self.transform_ndims), + ) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, RealFastFourierTransform) and self.transform_ndims == other.transform_ndims @@ -1420,7 +1501,7 @@ def __eq__(self, other: TransformT) -> bool: ) -class PackRealFastFourierCoefficientsTransform(Transform): +class PackRealFastFourierCoefficientsTransform(Transform[NonScalarArray]): """ Transform a real vector to complex coefficients of a real fast Fourier transform. @@ -1428,13 +1509,13 @@ class PackRealFastFourierCoefficientsTransform(Transform): """ domain = constraints.real_vector - codomain = constraints.independent(constraints.complex, 1) + codomain = cast(ConstraintT, constraints.independent(constraints.complex, 1)) def __init__(self, transform_shape: Optional[tuple[int, ...]] = None) -> None: assert transform_shape is None or len(transform_shape) == 1, ( "Packing Fourier coefficients is only implemented for vectors." ) - self.shape = transform_shape + self.shape: Optional[tuple[int, ...]] = transform_shape def tree_flatten(self): return (), ((), {"shape": self.shape}) @@ -1457,25 +1538,31 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return (*batch_shape, n) def log_abs_det_jacobian( - self, x: Array, y: Array, intermediates: None = None + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, ) -> Array: shape = jnp.broadcast_shapes(x.shape[:-1], y.shape[:-1]) return jnp.zeros_like(x, shape=shape) - def __call__(self, x: Array) -> Array: + def __call__(self, x: NonScalarArray) -> NonScalarArray: assert self.shape is None or self.shape == x.shape[-1:] n = x.shape[-1] n_real = n // 2 + 1 n_imag = n - n_real complex_dtype = jnp.result_type(x.dtype, jnp.complex64) return ( - x[..., :n_real] + jnp.asarray(x)[..., :n_real] .astype(complex_dtype) .at[..., 1 : 1 + n_imag] .add(1j * x[..., n_real:]) ) - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: + assert self.shape is not None, ( + "Shape must be specified in `__init__` for inverse transform." + ) (n,) = self.shape n_real = n // 2 + 1 n_imag = n - n_real @@ -1488,7 +1575,7 @@ def __eq__(self, other) -> bool: ) -class RecursiveLinearTransform(Transform): +class RecursiveLinearTransform(Transform[NonScalarArray]): """ Apply a linear transformation recursively such that :math:`y_t = A y_{t - 1} + x_t` for :math:`t > 0`, where :math:`x_t` and :math:`y_t` @@ -1540,7 +1627,11 @@ class RecursiveLinearTransform(Transform): domain = constraints.real_matrix codomain = constraints.real_matrix - def __init__(self, transition_matrix: Array, initial_value: Array = None) -> None: + def __init__( + self, + transition_matrix: NonScalarArray, + initial_value: Optional[NonScalarArray] = None, + ) -> None: event_shape = transition_matrix.shape[-1:] if initial_value is None: @@ -1567,7 +1658,7 @@ def _get_initial_value(self, sample_shape) -> Array: return jnp.broadcast_to(self.initial_value, batch_shape + event_shape) - def __call__(self, x: Array) -> Array: + def __call__(self, x: NonScalarArray) -> NonScalarArray: # Move the time axis to the first position so we can scan over it. sample_shape = x.shape[:-2] x = jnp.moveaxis(x, -2, 0) @@ -1581,7 +1672,7 @@ def f(y, x): _, y = lax.scan(f, initial_value, x) return jnp.moveaxis(y, 0, -2) - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: # Move the time axis to the first position so we can scan over it in reverse. sample_shape = y.shape[:-2] y = jnp.moveaxis(y, -2, 0) @@ -1597,7 +1688,12 @@ def f(y, prev): ) return jnp.moveaxis(x, 0, -2) - def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None): + def log_abs_det_jacobian( + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.zeros_like(x, shape=x.shape[:-2]) def tree_flatten(self): @@ -1606,13 +1702,15 @@ def tree_flatten(self): {}, ) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, RecursiveLinearTransform): return False - return jnp.array_equal(self.transition_matrix, other.transition_matrix) + return jnp.array_equal( + self.transition_matrix, other.transition_matrix + ) & jnp.array_equal(self.initial_value, other.initial_value) # type: ignore[return-value] -class ZeroSumTransform(Transform): +class ZeroSumTransform(Transform[NonScalarArray]): """A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3] :param transform_ndims: Number of trailing dimensions to transform. @@ -1627,26 +1725,28 @@ def __init__(self, transform_ndims: int = 1) -> None: self.transform_ndims = transform_ndims @property - def domain(self) -> constraints.Constraint: - return constraints.independent(constraints.real, self.transform_ndims) + def domain(self) -> ConstraintT: + return cast( + ConstraintT, constraints.independent(constraints.real, self.transform_ndims) + ) @property - def codomain(self) -> constraints.Constraint: - return constraints.zero_sum(self.transform_ndims) + def codomain(self) -> ConstraintT: + return cast(ConstraintT, constraints.zero_sum(self.transform_ndims)) - def __call__(self, x: Array) -> Array: + def __call__(self, x: NonScalarArray) -> NonScalarArray: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: x = self.extend_axis(x, axis=axis) return x - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: y = self.extend_axis_rev(y, axis=axis) return y - def extend_axis_rev(self, array: Array, axis: int) -> Array: + def extend_axis_rev(self, array: NonScalarArray, axis: int) -> NonScalarArray: normalized_axis = axis if axis >= 0 else jnp.ndim(array) + axis n = array.shape[normalized_axis] @@ -1657,7 +1757,7 @@ def extend_axis_rev(self, array: Array, axis: int) -> Array: slice_before = (slice(None, None),) * normalized_axis return array[(*slice_before, slice(None, -1))] + norm - def extend_axis(self, array: Array, axis: int) -> Array: + def extend_axis(self, array: NonScalarArray, axis: int) -> NonScalarArray: n = array.shape[axis] + 1 sum_vals = array.sum(axis, keepdims=True) @@ -1668,7 +1768,10 @@ def extend_axis(self, array: Array, axis: int) -> Array: return out - norm def log_abs_det_jacobian( - self, x: Array, y: Array, intermediates: None = None + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, ) -> jnp.ndarray: shape = jnp.broadcast_shapes( x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] @@ -1689,14 +1792,14 @@ def tree_flatten(self): aux_data = {"transform_ndims": self.transform_ndims} return (), ((), aux_data) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, ZeroSumTransform) and self.transform_ndims == other.transform_ndims ) -class ComplexTransform(ParameterFreeTransform): +class ComplexTransform(ParameterFreeTransform[NonScalarArray]): """ Transforms a pair of real numbers to a complex number. """ @@ -1704,14 +1807,19 @@ class ComplexTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.complex - def __call__(self, x: Array) -> Array: + def __call__(self, x: NonScalarArray) -> NonScalarArray: assert x.shape[-1] == 2, "Input must have a trailing dimension of size 2." return lax.complex(x[..., 0], x[..., 1]) - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: return jnp.stack([y.real, y.imag], axis=-1) - def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: + def log_abs_det_jacobian( + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NonScalarArray: return jnp.zeros_like(y) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: diff --git a/pyproject.toml b/pyproject.toml index d79a3e532..47925bf7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,9 +38,9 @@ extend-include = ["*.ipynb"] [tool.ruff.lint] select = ["ANN", "E", "F", "I", "W"] ignore = [ - "ANN002", # missing args type annotation - "ANN003", # missing kwargs type annotation - "ANN204", # missing type annotation for __call__ + "ANN002", # missing args type annotation + "ANN003", # missing kwargs type annotation + "ANN204", # missing type annotation for __call__ "E203", ] @@ -68,7 +68,9 @@ skip-magic-trailing-comma = false line-ending = "auto" [tool.ruff.lint.per-file-ignores] -"!numpyro/{diagnostics.py,handlers.py,optim.py,patch.py,primitives.py,infer/elbo.py}" = ["ANN"] # require type annotations in typed modules +"!numpyro/{diagnostics.py,handlers.py,optim.py,patch.py,primitives.py,infer/elbo.py}" = [ + "ANN", +] # require type annotations in typed modules [tool.ruff.lint.extend-per-file-ignores] "numpyro/contrib/tfp/distributions.py" = ["F811"] @@ -114,7 +116,6 @@ doctest_optionflags = [ [tool.mypy] ignore_errors = true ignore_missing_imports = true -plugins = ["numpy.typing.mypy_plugin"] [[tool.mypy.overrides]] module = [ @@ -129,5 +130,6 @@ module = [ "numpyro.primitives.*", "numpyro.patch.*", "numpyro.util.*", + "numpyro.distributions.transforms", ] ignore_errors = false diff --git a/test/test_distributions.py b/test/test_distributions.py index fa5c31f78..3eeb09498 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -556,14 +556,23 @@ def get_sp_dist(jax_dist): np.array([[0.8, 0.2], [-0.1, 1.1]]), np.array([0.1, 0.3, 0.25])[:, None, None] * np.array([[0.8, 0.2], [0.2, 0.7]]), ), - T( - dist.GaussianCopulaBeta, - np.array([7.0, 2.0]), - np.array([4.0, 10.0]), - np.array([[1.0, 0.75], [0.75, 1.0]]), + pytest.param( + *T( + dist.GaussianCopulaBeta, + np.array([7.0, 2.0]), + np.array([4.0, 10.0]), + np.array([[1.0, 0.75], [0.75, 1.0]]), + ), + marks=pytest.mark.xfail(reason="Beta copula does not work with jax 0.7.0"), + ), + pytest.param( + *T(dist.GaussianCopulaBeta, 2.0, 1.5, np.eye(3)), + marks=pytest.mark.xfail(reason="Beta copula does not work with jax 0.7.0"), + ), + pytest.param( + *T(dist.GaussianCopulaBeta, 2.0, 1.5, np.full((5, 3, 3), np.eye(3))), + marks=pytest.mark.xfail(reason="Beta copula does not work with jax 0.7.0"), ), - T(dist.GaussianCopulaBeta, 2.0, 1.5, np.eye(3)), - T(dist.GaussianCopulaBeta, 2.0, 1.5, np.full((5, 3, 3), np.eye(3))), T(dist.Gompertz, np.array([1.7]), np.array([[2.0], [3.0]])), T(dist.Gompertz, np.array([0.5, 1.3]), np.array([[1.0], [3.0]])), T(dist.Gumbel, 0.0, 1.0), @@ -2630,7 +2639,7 @@ def test_composed_transform(batch_shape): expected_log_det = ( jnp.log(2) * 6 + t2.log_abs_det_jacobian(x * 2, y / 2) + jnp.log(2) * 9 ) - assert_allclose(log_det, expected_log_det) + assert_allclose(log_det, expected_log_det, rtol=1e-6) @pytest.mark.parametrize("batch_shape", [(), (5,)])