From c2ca7c39d68d8f0a6893b463d1edd44439d42e84 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Wed, 13 Aug 2025 01:16:03 +0500 Subject: [PATCH 01/16] chore: incomplete changes for type hint in numpyro.distribution.transforms --- numpyro/_typing.py | 25 +- numpyro/distributions/constraints.py | 2 +- numpyro/distributions/transforms.py | 384 ++++++++++++++++----------- pyproject.toml | 11 +- 4 files changed, 257 insertions(+), 165 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index c9ad8819a..6a379bb7a 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -7,9 +7,11 @@ from typing import Any, Protocol, runtime_checkable try: - from typing import ParamSpec, TypeAlias + from typing import ParamSpec, TypeAlias, Union except ImportError: - from typing_extensions import ParamSpec, TypeAlias + from typing_extensions import ParamSpec, TypeAlias, Union + +import numpy as np import jax from jax.typing import ArrayLike @@ -20,6 +22,9 @@ Message: TypeAlias = dict[str, Any] TraceT: TypeAlias = OrderedDict[str, Message] +# ArrayLike type has StaticScalar, StrictArrayT has everything except StaticScalars +StrictArrayT: TypeAlias = Union[np.ndarray, jax.Array] + @runtime_checkable class ConstraintT(Protocol): @@ -87,20 +92,20 @@ def is_discrete(self) -> bool: ... @runtime_checkable class TransformT(Protocol): - domain = ConstraintT - codomain = ConstraintT + domain: ConstraintT = ... + codomain: ConstraintT = ... _inv: "TransformT" = None - def __call__(self, x: ArrayLike) -> ArrayLike: ... - def _inverse(self, y: ArrayLike) -> ArrayLike: ... + def __call__(self, x: jax.Array) -> jax.Array: ... + def _inverse(self, y: jax.Array) -> jax.Array: ... def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: ... - def call_with_intermediates(self, x: ArrayLike) -> tuple[ArrayLike, None]: ... + self, x: jax.Array, y: jax.Array, intermediates=None + ) -> jax.Array: ... + def call_with_intermediates(self, x: jax.Array) -> tuple[jax.Array, None]: ... def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... @property def inv(self) -> "TransformT": ... @property - def sign(self) -> ArrayLike: ... + def sign(self) -> jax.Array: ... diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 7671b0685..9e665bded 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -805,7 +805,7 @@ def tree_flatten(self): greater_than_eq: ConstraintT = _GreaterThanEq less_than: ConstraintT = _LessThan less_than_eq: ConstraintT = _LessThanEq -independent: ConstraintT = _IndependentConstraint +independent = _IndependentConstraint integer_interval: ConstraintT = _IntegerInterval integer_greater_than: ConstraintT = _IntegerGreaterThan interval: ConstraintT = _Interval diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 082b2df3a..4fb970772 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -17,7 +17,7 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import TransformT +from numpyro._typing import ConstraintT, StrictArrayT, TransformT from numpyro.distributions import constraints from numpyro.distributions.util import ( add_diag, @@ -67,36 +67,34 @@ def _clipped_expit(x: ArrayLike) -> ArrayLike: class Transform(object): domain = constraints.real codomain = constraints.real - _inv = None + _inv: Optional[TransformT] = None def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) @property - def inv(self) -> TransformT: + def inv(self: TransformT) -> TransformT: inv = None if self._inv is not None: inv = self._inv() if inv is None: - inv = _InverseTransform(self) + inv: TransformT = _InverseTransform(self) self._inv = weakref.ref(inv) return inv - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Array) -> Array: raise NotImplementedError - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Array) -> Array: raise NotImplementedError def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, x: Array, y: Array, intermediates: Optional[Array] = None + ) -> Array: raise NotImplementedError - def call_with_intermediates( - self, x: ArrayLike - ) -> Tuple[ArrayLike, Optional[ArrayLike]]: + def call_with_intermediates(self, x: Array) -> Tuple[Array, Optional[Array]]: return self(x), None def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -114,7 +112,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape @property - def sign(self) -> ArrayLike: + def sign(self) -> Array: """ Sign of the derivative of the transform if it is bijective. """ @@ -148,21 +146,21 @@ class ParameterFreeTransform(Transform): def tree_flatten(self): return (), ((), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) class _InverseTransform(Transform): def __init__(self, transform: TransformT) -> None: super().__init__() - self._inv = transform + self._inv: TransformT = transform @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: # type: ignore[override] return self._inv.codomain @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] return self._inv.domain @property @@ -173,12 +171,15 @@ def sign(self) -> ArrayLike: def inv(self) -> TransformT: return self._inv - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: return self._inv._inverse(x) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: # NB: we don't use intermediates for inverse transform return -self._inv.log_abs_det_jacobian(y, x, None) @@ -191,7 +192,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self._inv,), (("_inv",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _InverseTransform): return False return self._inv == other._inv @@ -201,13 +202,13 @@ class AbsTransform(ParameterFreeTransform): domain = constraints.real codomain = constraints.positive - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, AbsTransform) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: return jnp.abs(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: warnings.warn( "AbsTransform is not a bijective transform." " The inverse of `y` will be `y`.", @@ -226,14 +227,14 @@ def __init__( self, loc: ArrayLike, scale: ArrayLike, - domain: constraints.Constraint = constraints.real, + domain: ConstraintT = constraints.real, ): self.loc = loc self.scale = scale self.domain = domain @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] if self.domain is constraints.real: return constraints.real elif isinstance(self.domain, constraints.greater_than): @@ -264,15 +265,18 @@ def codomain(self) -> constraints.Constraint: def sign(self) -> ArrayLike: return jnp.sign(self.scale) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: return self.loc + self.scale * x - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: return (y - self.loc) / self.scale def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: return jnp.broadcast_to(jnp.log(jnp.abs(self.scale)), jnp.shape(x)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -288,14 +292,14 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.loc, self.scale, self.domain), (("loc", "scale", "domain"), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, AffineTransform): return False return ( jnp.array_equal(self.loc, other.loc) & jnp.array_equal(self.scale, other.scale) & (self.domain == other.domain) - ) + ) # type: ignore[return-value] def _get_compose_transform_input_event_dim(parts): @@ -321,7 +325,7 @@ def __init__(self, parts: Sequence[TransformT]) -> None: self.parts = parts @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: # type: ignore[override] input_event_dim = _get_compose_transform_input_event_dim(self.parts) first_input_event_dim = self.parts[0].domain.event_dim assert input_event_dim >= first_input_event_dim @@ -333,7 +337,7 @@ def domain(self) -> constraints.Constraint: ) @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] output_event_dim = _get_compose_transform_output_event_dim(self.parts) last_output_event_dim = self.parts[-1].codomain.event_dim assert output_event_dim >= last_output_event_dim @@ -351,19 +355,22 @@ def sign(self) -> ArrayLike: sign *= transform.sign return sign - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: for part in self.parts: x = part(x) return x - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: for part in self.parts[::-1]: y = part.inv(y) return y def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: if intermediates is not None: if len(intermediates) != len(self.parts): raise ValueError( @@ -414,7 +421,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.parts,), (("parts",), {}) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, ComposeTransform): return False return jnp.logical_and(*(p1 == p2 for p1, p2 in zip(self.parts, other.parts))) @@ -452,15 +459,18 @@ class CholeskyTransform(ParameterFreeTransform): domain = constraints.positive_definite codomain = constraints.lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: return jnp.linalg.cholesky(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: return jnp.matmul(y, jnp.swapaxes(y, -2, -1)) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13 n = jnp.shape(x)[-1] order = -jnp.arange(n, 0, -1) @@ -499,12 +509,12 @@ class :class:`StickBreakingTransform` to transform :math:`X_i` into a domain = constraints.real_vector codomain = constraints.corr_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: # we interchange step 1 and step 2.a for a better performance t = jnp.tanh(x) return signed_stick_breaking_tril(t) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: # inverse stick-breaking z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1) pad_width = [(0, 0)] * y.ndim @@ -519,8 +529,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return jnp.arctanh(t) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: # NB: because domain and codomain are two spaces with different dimensions, determinant of # Jacobian is not well-defined. Here we return `log_abs_det_jacobian` of `x` and the # flatten lower triangular part of `y`. @@ -552,8 +565,11 @@ class CorrMatrixCholeskyTransform(CholeskyTransform): codomain = constraints.corr_cholesky def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: # NB: see derivation in LKJCholesky implementation n = jnp.shape(x)[-1] order = -jnp.arange(n - 1, -1, -1) @@ -569,7 +585,7 @@ def __init__(self, domain=constraints.real): self.domain = domain @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] if self.domain is constraints.ordered_vector: return constraints.positive_ordered_vector elif self.domain is constraints.real: @@ -584,22 +600,25 @@ def codomain(self) -> constraints.Constraint: else: raise NotImplementedError - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: # XXX consider to clamp from below for stability if necessary return jnp.exp(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: return jnp.log(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: return x def tree_flatten(self): return (self.domain,), (("domain",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, ExpTransform): return False return self.domain == other.domain @@ -608,15 +627,18 @@ def __eq__(self, other: TransformT) -> bool: class IdentityTransform(ParameterFreeTransform): sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: return x - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: return y def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: return jnp.zeros_like(x) @@ -638,26 +660,29 @@ def __init__( super().__init__() @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: # type: ignore[override] return constraints.independent( self.base_transform.domain, self.reinterpreted_batch_ndims ) @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] return constraints.independent( self.base_transform.codomain, self.reinterpreted_batch_ndims ) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: return self.base_transform(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: return self.base_transform._inverse(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: result = self.base_transform.log_abs_det_jacobian( x, y, intermediates=intermediates ) @@ -683,7 +708,7 @@ def tree_flatten(self): dict(), ) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, IndependentTransform): return False return (self.base_transform == other.base_transform) & ( @@ -699,7 +724,7 @@ class L1BallTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.l1_ball - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: # transform to (-1, 1) interval t = jnp.tanh(x) @@ -709,7 +734,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0) return t * remainder - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: # inverse stick-breaking remainder = 1 - jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1) pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)] @@ -723,8 +748,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return jnp.arctanh(t) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: # compute stick-breaking logdet # t1 -> t1 # t2 -> t2 * (1 - abs(t1)) @@ -775,12 +803,12 @@ def __init__(self, loc: ArrayLike, scale_tril: Array): self.loc = loc self.scale_tril = scale_tril - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: return self.loc + jnp.squeeze( jnp.matmul(self.scale_tril, x[..., jnp.newaxis]), axis=-1 ) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: y = y - self.loc original_shape = jnp.shape(y) yt = jnp.reshape(y, (-1, original_shape[-1])).T @@ -788,8 +816,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return jnp.reshape(xt.T, original_shape) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: return jnp.broadcast_to( jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1), jnp.shape(x)[:-1], @@ -808,7 +839,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.loc, self.scale_tril), (("loc", "scale_tril"), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, LowerCholeskyAffine): return False return jnp.array_equal(self.loc, other.loc) & jnp.array_equal( @@ -827,21 +858,21 @@ class LowerCholeskyTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = jnp.exp(x[..., -n:]) return add_diag(z, diag) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: z = matrix_to_tril_vec(y, diagonal=-1) return jnp.concatenate( [z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1 ) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, x: StrictArrayT, y: StrictArrayT, intermediates=None + ) -> StrictArrayT: # the jacobian is diagonal, so logdet is the sum of diagonal `exp` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) return x[..., -n:].sum(-1) @@ -869,20 +900,23 @@ class ScaledUnitLowerCholeskyTransform(LowerCholeskyTransform): domain = constraints.real_vector codomain = constraints.scaled_unit_lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = softplus(x[..., -n:]) return add_diag(z, 1) * diag[..., None] - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: diag = jnp.diagonal(y, axis1=-2, axis2=-1) z = matrix_to_tril_vec(y / diag[..., None], diagonal=-1) return jnp.concatenate([z, _softplus_inv(diag)], axis=-1) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) diag = x[..., -n:] diag_softplus = jnp.diagonal(y, axis1=-2, axis2=-1) @@ -913,17 +947,20 @@ class OrderedTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.ordered_vector - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: z = jnp.concatenate([x[..., :1], jnp.exp(x[..., 1:])], axis=-1) return jnp.cumsum(z, axis=-1) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: x = jnp.log(y[..., 1:] - y[..., :-1]) return jnp.concatenate([y[..., :1], x], axis=-1) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: return jnp.sum(x[..., 1:], -1) @@ -934,12 +971,12 @@ class PermuteTransform(Transform): def __init__(self, permutation: Array) -> None: self.permutation = permutation - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: return x[..., self.permutation] - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: size = self.permutation.size - permutation_inv = ( + permutation_inv: StrictArrayT = ( jnp.zeros(size, dtype=jnp.result_type(int)) .at[self.permutation] .set(jnp.arange(size)) @@ -947,14 +984,17 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return y[..., permutation_inv] def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: return jnp.full(jnp.shape(x)[:-1], 0.0) def tree_flatten(self): return (self.permutation,), (("permutation",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, PermuteTransform): return False return jnp.array_equal(self.permutation, other.permutation) @@ -974,7 +1014,10 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return jnp.power(y, 1 / self.exponent) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None + self, + x: ArrayLike, + y: ArrayLike, + intermediates: Optional[ArrayLike] = None, ) -> ArrayLike: return jnp.log(jnp.abs(self.exponent * y / x)) @@ -987,7 +1030,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.exponent,), (("exponent",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, PowerTransform): return False return jnp.array_equal(self.exponent, other.exponent) @@ -1001,15 +1044,18 @@ class SigmoidTransform(ParameterFreeTransform): codomain = constraints.unit_interval sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: return _clipped_expit(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: return logit(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: return -softplus(x) - softplus(-x) @@ -1045,12 +1091,12 @@ class SimplexToOrderedTransform(Transform): def __init__(self, anchor_point: ArrayLike = 0.0) -> None: self.anchor_point = anchor_point - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: s = jnp.cumsum(x[..., :-1], axis=-1) y = logit(s) + jnp.expand_dims(self.anchor_point, -1) return y - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: y = y - jnp.expand_dims(self.anchor_point, -1) s = expit(y) # x0 = s0, x1 = s1 - s0, x2 = s2 - s1,..., xn = 1 - s[n-1] @@ -1061,8 +1107,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return x def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: # |dp/dc| = |dx/dy| = prod(ds/dy) = prod(expit'(y)) # we know log derivative of expit(y) is `-softplus(y) - softplus(-y)` J_logdet = (softplus(y) + softplus(-y)).sum(-1) @@ -1071,7 +1120,7 @@ def log_abs_det_jacobian( def tree_flatten(self): return (self.anchor_point,), (("anchor_point",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, SimplexToOrderedTransform): return False return jnp.array_equal(self.anchor_point, other.anchor_point) @@ -1097,15 +1146,18 @@ class SoftplusTransform(ParameterFreeTransform): codomain = constraints.softplus_positive sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: return softplus(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: return _softplus_inv(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: return -softplus(-x) @@ -1119,20 +1171,23 @@ class SoftplusLowerCholeskyTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.softplus_lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = softplus(x[..., -n:]) return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: z = matrix_to_tril_vec(y, diagonal=-1) diag = _softplus_inv(jnp.diagonal(y, axis1=-2, axis2=-1)) return jnp.concatenate([z, diag], axis=-1) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: # the jacobian is diagonal, so logdet is the sum of diagonal # `softplus` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) @@ -1149,7 +1204,7 @@ class StickBreakingTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.simplex - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: # we shift x to obtain a balanced mapping (0, 0, ..., 0) -> (1/K, 1/K, ..., 1/K) x = x - jnp.log(x.shape[-1] - jnp.arange(x.shape[-1])) # convert to probabilities (relative to the remaining) of each fraction of the stick @@ -1165,7 +1220,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: ) return z_padded * z1m_cumprod_shifted - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: y_crop = y[..., :-1] z1m_cumprod = jnp.clip(1 - jnp.cumsum(y_crop, axis=-1), jnp.finfo(y.dtype).tiny) # hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod @@ -1173,8 +1228,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return x + jnp.log(x.shape[-1] - jnp.arange(x.shape[-1])) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html # |det|(J) = Product(y * (1 - sigmoid(x))) # = Product(y * sigmoid(x) * exp(-x)) @@ -1207,7 +1265,7 @@ def __init__(self, unpack_fn, pack_fn=None): self.unpack_fn = unpack_fn self.pack_fn = pack_fn - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: batch_shape = x.shape[:-1] if batch_shape: unpacked = vmap(self.unpack_fn)(x.reshape((-1,) + x.shape[-1:])) @@ -1217,7 +1275,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: else: return self.unpack_fn(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: if self.pack_fn is None: raise NotImplementedError( "pack_fn needs to be provided to perform UnpackTransform.inv." @@ -1238,8 +1296,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return self.pack_fn(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: return jnp.zeros(jnp.shape(x)[:-1]) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -1252,7 +1313,7 @@ def tree_flatten(self): # XXX: what if unpack_fn is a parametrized callable pytree? return (), ((), {"unpack_fn": self.unpack_fn, "pack_fn": self.pack_fn}) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, UnpackTransform) and (self.unpack_fn is other.unpack_fn) @@ -1293,11 +1354,11 @@ def __init__( self._inverse_shape = inverse_shape @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: # type: ignore[override] return constraints.independent(constraints.real, len(self._inverse_shape)) @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] return constraints.independent(constraints.real, len(self._forward_shape)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -1306,15 +1367,18 @@ def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return _get_target_shape(shape, self._inverse_shape, self._forward_shape) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: StrictArrayT) -> StrictArrayT: return jnp.reshape(x, self.forward_shape(jnp.shape(x))) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: return jnp.reshape(y, self.inverse_shape(jnp.shape(y))) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: return jnp.zeros_like(x, shape=x.shape[: x.ndim - len(self._inverse_shape)]) def tree_flatten(self): @@ -1324,7 +1388,7 @@ def tree_flatten(self): } return (), ((), aux_data) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, ReshapeTransform) and self._forward_shape == other._forward_shape @@ -1385,8 +1449,11 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (size,) def log_abs_det_jacobian( - self, x: Array, y: Array, intermediates: None = None - ) -> jnp.ndarray: + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: batch_shape = jnp.broadcast_shapes( x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] ) @@ -1405,14 +1472,14 @@ def tree_flatten(self): return (), ((), aux_data) @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: # type: ignore[override] return constraints.independent(constraints.real, self.transform_ndims) @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] return constraints.independent(constraints.complex, self.transform_ndims) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, RealFastFourierTransform) and self.transform_ndims == other.transform_ndims @@ -1457,12 +1524,15 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return (*batch_shape, n) def log_abs_det_jacobian( - self, x: Array, y: Array, intermediates: None = None + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, ) -> Array: shape = jnp.broadcast_shapes(x.shape[:-1], y.shape[:-1]) return jnp.zeros_like(x, shape=shape) - def __call__(self, x: Array) -> Array: + def __call__(self, x: StrictArrayT) -> StrictArrayT: assert self.shape is None or self.shape == x.shape[-1:] n = x.shape[-1] n_real = n // 2 + 1 @@ -1475,7 +1545,7 @@ def __call__(self, x: Array) -> Array: .add(1j * x[..., n_real:]) ) - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: (n,) = self.shape n_real = n // 2 + 1 n_imag = n - n_real @@ -1540,7 +1610,11 @@ class RecursiveLinearTransform(Transform): domain = constraints.real_matrix codomain = constraints.real_matrix - def __init__(self, transition_matrix: Array, initial_value: Array = None) -> None: + def __init__( + self, + transition_matrix: StrictArrayT, + initial_value: Optional[StrictArrayT] = None, + ) -> None: event_shape = transition_matrix.shape[-1:] if initial_value is None: @@ -1567,7 +1641,7 @@ def _get_initial_value(self, sample_shape) -> Array: return jnp.broadcast_to(self.initial_value, batch_shape + event_shape) - def __call__(self, x: Array) -> Array: + def __call__(self, x: StrictArrayT) -> StrictArrayT: # Move the time axis to the first position so we can scan over it. sample_shape = x.shape[:-2] x = jnp.moveaxis(x, -2, 0) @@ -1581,7 +1655,7 @@ def f(y, x): _, y = lax.scan(f, initial_value, x) return jnp.moveaxis(y, 0, -2) - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: # Move the time axis to the first position so we can scan over it in reverse. sample_shape = y.shape[:-2] y = jnp.moveaxis(y, -2, 0) @@ -1597,7 +1671,12 @@ def f(y, prev): ) return jnp.moveaxis(x, 0, -2) - def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None): + def log_abs_det_jacobian( + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, + ) -> StrictArrayT: return jnp.zeros_like(x, shape=x.shape[:-2]) def tree_flatten(self): @@ -1606,10 +1685,10 @@ def tree_flatten(self): {}, ) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, RecursiveLinearTransform): return False - return jnp.array_equal(self.transition_matrix, other.transition_matrix) + return jnp.array_equal(self.transition_matrix, other.transition_matrix) # type: ignore[return-value] class ZeroSumTransform(Transform): @@ -1627,20 +1706,20 @@ def __init__(self, transform_ndims: int = 1) -> None: self.transform_ndims = transform_ndims @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: # type: ignore[override] return constraints.independent(constraints.real, self.transform_ndims) @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] return constraints.zero_sum(self.transform_ndims) - def __call__(self, x: Array) -> Array: + def __call__(self, x: StrictArrayT) -> StrictArrayT: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: x = self.extend_axis(x, axis=axis) return x - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: StrictArrayT) -> StrictArrayT: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: y = self.extend_axis_rev(y, axis=axis) @@ -1668,7 +1747,10 @@ def extend_axis(self, array: Array, axis: int) -> Array: return out - norm def log_abs_det_jacobian( - self, x: Array, y: Array, intermediates: None = None + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Optional[ArrayLike] = None, ) -> jnp.ndarray: shape = jnp.broadcast_shapes( x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] @@ -1689,7 +1771,7 @@ def tree_flatten(self): aux_data = {"transform_ndims": self.transform_ndims} return (), ((), aux_data) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, ZeroSumTransform) and self.transform_ndims == other.transform_ndims @@ -1704,14 +1786,16 @@ class ComplexTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.complex - def __call__(self, x: Array) -> Array: + def __call__(self, x: StrictArrayT) -> StrictArrayT: assert x.shape[-1] == 2, "Input must have a trailing dimension of size 2." return lax.complex(x[..., 0], x[..., 1]) - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: ArrayLike) -> ArrayLike: return jnp.stack([y.real, y.imag], axis=-1) - def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: + def log_abs_det_jacobian( + self, x: ArrayLike, y: ArrayLike, intermediates: Optional[ArrayLike] = None + ) -> ArrayLike: return jnp.zeros_like(y) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: diff --git a/pyproject.toml b/pyproject.toml index d79a3e532..104c746ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,9 +38,9 @@ extend-include = ["*.ipynb"] [tool.ruff.lint] select = ["ANN", "E", "F", "I", "W"] ignore = [ - "ANN002", # missing args type annotation - "ANN003", # missing kwargs type annotation - "ANN204", # missing type annotation for __call__ + "ANN002", # missing args type annotation + "ANN003", # missing kwargs type annotation + "ANN204", # missing type annotation for __call__ "E203", ] @@ -68,7 +68,9 @@ skip-magic-trailing-comma = false line-ending = "auto" [tool.ruff.lint.per-file-ignores] -"!numpyro/{diagnostics.py,handlers.py,optim.py,patch.py,primitives.py,infer/elbo.py}" = ["ANN"] # require type annotations in typed modules +"!numpyro/{diagnostics.py,handlers.py,optim.py,patch.py,primitives.py,infer/elbo.py}" = [ + "ANN", +] # require type annotations in typed modules [tool.ruff.lint.extend-per-file-ignores] "numpyro/contrib/tfp/distributions.py" = ["F811"] @@ -129,5 +131,6 @@ module = [ "numpyro.primitives.*", "numpyro.patch.*", "numpyro.util.*", + "numpyro.distributions.transforms", ] ignore_errors = false From 3b018369e9d3e2f18275edb14989157d0ec603df Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Mon, 25 Aug 2025 03:54:10 +0500 Subject: [PATCH 02/16] refactor: update type hints and improve type safety across transforms and constraints --- numpyro/_typing.py | 32 ++-- numpyro/distributions/constraints.py | 20 +-- numpyro/distributions/transforms.py | 215 ++++++++++++++------------- 3 files changed, 144 insertions(+), 123 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 6a379bb7a..38aea907d 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -4,12 +4,12 @@ from collections import OrderedDict from collections.abc import Callable -from typing import Any, Protocol, runtime_checkable +from typing import Any, Optional, Protocol, Union, runtime_checkable try: - from typing import ParamSpec, TypeAlias, Union + from typing import ParamSpec, TypeAlias except ImportError: - from typing_extensions import ParamSpec, TypeAlias, Union + from typing_extensions import ParamSpec, TypeAlias import numpy as np @@ -23,7 +23,7 @@ TraceT: TypeAlias = OrderedDict[str, Message] # ArrayLike type has StaticScalar, StrictArrayT has everything except StaticScalars -StrictArrayT: TypeAlias = Union[np.ndarray, jax.Array] +StrictArrayT = Union[np.ndarray, jax.Array] @runtime_checkable @@ -94,18 +94,28 @@ def is_discrete(self) -> bool: ... class TransformT(Protocol): domain: ConstraintT = ... codomain: ConstraintT = ... - _inv: "TransformT" = None + _inv: Optional["TransformT"] = ... - def __call__(self, x: jax.Array) -> jax.Array: ... - def _inverse(self, y: jax.Array) -> jax.Array: ... + def __call__(self, x: Union[jax.Array, Any]) -> Union[jax.Array, Any]: ... + def _inverse(self, y: Union[jax.Array, Any]) -> Union[jax.Array, Any]: ... def log_abs_det_jacobian( - self, x: jax.Array, y: jax.Array, intermediates=None - ) -> jax.Array: ... - def call_with_intermediates(self, x: jax.Array) -> tuple[jax.Array, None]: ... + self, + x: Union[jax.Array, Any], + y: Union[jax.Array, Any], + intermediates: Optional[Any] = None, + ) -> Union[jax.Array, Any]: ... + def call_with_intermediates( + self, x: Union[jax.Array, Optional[Any]] + ) -> tuple[Union[jax.Array, Any], Any]: ... def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... @property def inv(self) -> "TransformT": ... @property - def sign(self) -> jax.Array: ... + def sign(self) -> Union[ArrayLike, Any]: ... + + +class UnusedParam(object): + def __repr__(self): + return "UnusedParam" diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 9e665bded..317b56b2e 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -801,18 +801,18 @@ def tree_flatten(self): corr_cholesky: ConstraintT = _CorrCholesky() corr_matrix: ConstraintT = _CorrMatrix() dependent: ConstraintT = _Dependent() -greater_than: ConstraintT = _GreaterThan -greater_than_eq: ConstraintT = _GreaterThanEq -less_than: ConstraintT = _LessThan -less_than_eq: ConstraintT = _LessThanEq +greater_than = _GreaterThan +greater_than_eq = _GreaterThanEq +less_than = _LessThan +less_than_eq = _LessThanEq independent = _IndependentConstraint -integer_interval: ConstraintT = _IntegerInterval -integer_greater_than: ConstraintT = _IntegerGreaterThan -interval: ConstraintT = _Interval +integer_interval = _IntegerInterval +integer_greater_than = _IntegerGreaterThan +interval = _Interval l1_ball: ConstraintT = _L1Ball() lower_cholesky: ConstraintT = _LowerCholesky() scaled_unit_lower_cholesky: ConstraintT = _ScaledUnitLowerCholesky() -multinomial: ConstraintT = _Multinomial +multinomial = _Multinomial nonnegative: ConstraintT = _Nonnegative() nonnegative_integer: ConstraintT = _IntegerNonnegative() ordered_vector: ConstraintT = _OrderedVector() @@ -830,5 +830,5 @@ def tree_flatten(self): softplus_positive: ConstraintT = _SoftplusPositive() sphere: ConstraintT = _Sphere() unit_interval: ConstraintT = _UnitInterval() -open_interval: ConstraintT = _OpenInterval -zero_sum: ConstraintT = _ZeroSum +open_interval = _OpenInterval +zero_sum = _ZeroSum diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 4fb970772..78a2af11d 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1,8 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 + import math -from typing import Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple, Union import warnings import weakref @@ -17,7 +18,7 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, StrictArrayT, TransformT +from numpyro._typing import ConstraintT, StrictArrayT, TransformT, UnusedParam from numpyro.distributions import constraints from numpyro.distributions.util import ( add_diag, @@ -75,26 +76,32 @@ def __init_subclass__(cls, **kwargs): @property def inv(self: TransformT) -> TransformT: + # TODO: can not understand the implementation (type wise) inv = None if self._inv is not None: inv = self._inv() if inv is None: inv: TransformT = _InverseTransform(self) - self._inv = weakref.ref(inv) + self._inv: TransformT = weakref.ref(inv) return inv - def __call__(self, x: Array) -> Array: + def __call__(self, x: Union[Array, Any]) -> Union[Array, Any]: raise NotImplementedError - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: Union[Array, Any]) -> Union[Array, Any]: raise NotImplementedError def log_abs_det_jacobian( - self, x: Array, y: Array, intermediates: Optional[Array] = None - ) -> Array: + self, + x: Union[Array, Any], + y: Union[Array, Any], + intermediates: Optional[Any] = None, + ) -> Union[Array, Any]: raise NotImplementedError - def call_with_intermediates(self, x: Array) -> Tuple[Array, Optional[Array]]: + def call_with_intermediates( + self, x: Union[Array, Any] + ) -> Tuple[Union[Array, Any], Optional[Union[Array, Any]]]: return self(x), None def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -112,7 +119,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape @property - def sign(self) -> Array: + def sign(self) -> Union[Array, Any]: """ Sign of the derivative of the transform if it is bijective. """ @@ -178,7 +185,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: # NB: we don't use intermediates for inverse transform return -self._inv.log_abs_det_jacobian(y, x, None) @@ -265,18 +272,18 @@ def codomain(self) -> ConstraintT: # type: ignore[override] def sign(self) -> ArrayLike: return jnp.sign(self.scale) - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: ArrayLike) -> ArrayLike: return self.loc + self.scale * x - def _inverse(self, y: StrictArrayT) -> StrictArrayT: - return (y - self.loc) / self.scale + def _inverse(self, y: ArrayLike) -> ArrayLike: + return (y - self.loc) / self.scale # type: ignore[call-overload,operator] def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, - ) -> StrictArrayT: + x: ArrayLike, + y: Union[ArrayLike, UnusedParam], + intermediates: Union[Any, None, UnusedParam] = None, + ) -> ArrayLike: return jnp.broadcast_to(jnp.log(jnp.abs(self.scale)), jnp.shape(x)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -296,10 +303,10 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, AffineTransform): return False return ( - jnp.array_equal(self.loc, other.loc) - & jnp.array_equal(self.scale, other.scale) - & (self.domain == other.domain) - ) # type: ignore[return-value] + jnp.array_equal(self.loc, other.loc) # type: ignore[return-value] + & jnp.array_equal(self.scale, other.scale) # type: ignore[return-value] + & (self.domain == other.domain) # type: ignore[return-value] + ) def _get_compose_transform_input_event_dim(parts): @@ -332,7 +339,7 @@ def domain(self) -> ConstraintT: # type: ignore[override] if input_event_dim == first_input_event_dim: return self.parts[0].domain else: - return constraints.independent( + return constraints.independent( # type: ignore[return-value] self.parts[0].domain, input_event_dim - first_input_event_dim ) @@ -346,11 +353,11 @@ def codomain(self) -> ConstraintT: # type: ignore[override] else: return constraints.independent( self.parts[-1].codomain, output_event_dim - last_output_event_dim - ) + ) # type: ignore[return-value] @property def sign(self) -> ArrayLike: - sign = 1 + sign: ArrayLike = 1 for transform in self.parts: sign *= transform.sign return sign @@ -369,7 +376,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Optional[Any] = None, ) -> StrictArrayT: if intermediates is not None: if len(intermediates) != len(self.parts): @@ -379,7 +386,7 @@ def log_abs_det_jacobian( ) ) - result = 0.0 + result = jnp.zeros(()) input_event_dim = self.domain.event_dim for i, part in enumerate(self.parts[:-1]): y_tmp = part(x) if intermediates is None else intermediates[i][0] @@ -396,9 +403,7 @@ def log_abs_det_jacobian( result = result + sum_rightmost(logdet, input_event_dim - part.domain.event_dim) return result - def call_with_intermediates( - self, x: ArrayLike - ) -> Tuple[ArrayLike, Optional[ArrayLike]]: + def call_with_intermediates(self, x: ArrayLike) -> Tuple[ArrayLike, Sequence]: intermediates = [] for part in self.parts[:-1]: x, inter = part.call_with_intermediates(x) @@ -424,7 +429,7 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, ComposeTransform): return False - return jnp.logical_and(*(p1 == p2 for p1, p2 in zip(self.parts, other.parts))) + return jnp.logical_and(*(p1 == p2 for p1, p2 in zip(self.parts, other.parts))) # type: ignore[return-value] def _matrix_forward_shape(shape: tuple[int, ...], offset: int = 0) -> tuple[int, ...]: @@ -469,7 +474,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13 n = jnp.shape(x)[-1] @@ -532,7 +537,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: # NB: because domain and codomain are two spaces with different dimensions, determinant of # Jacobian is not well-defined. Here we return `log_abs_det_jacobian` of `x` and the @@ -568,7 +573,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: # NB: see derivation in LKJCholesky implementation n = jnp.shape(x)[-1] @@ -581,7 +586,7 @@ class ExpTransform(Transform): # TODO: refine domain/codomain logic through setters, especially when # transforms for inverses are supported - def __init__(self, domain=constraints.real): + def __init__(self, domain: ConstraintT = constraints.real): self.domain = domain @property @@ -600,19 +605,19 @@ def codomain(self) -> ConstraintT: # type: ignore[override] else: raise NotImplementedError - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: ArrayLike) -> ArrayLike: # XXX consider to clamp from below for stability if necessary return jnp.exp(x) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: ArrayLike) -> ArrayLike: return jnp.log(y) def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, - ) -> StrictArrayT: + x: ArrayLike, + y: Union[ArrayLike, UnusedParam], + intermediates: Union[Any, None, UnusedParam] = None, + ) -> ArrayLike: return x def tree_flatten(self): @@ -636,8 +641,8 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, x: StrictArrayT, - y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + y: Union[StrictArrayT, UnusedParam] = UnusedParam(), + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: return jnp.zeros_like(x) @@ -650,7 +655,7 @@ class IndependentTransform(Transform): """ def __init__( - self, base_transform: TransformT, reinterpreted_batch_ndims: int + self, base_transform: Transform, reinterpreted_batch_ndims: int ) -> None: assert isinstance(base_transform, Transform) assert isinstance(reinterpreted_batch_ndims, int) @@ -663,13 +668,13 @@ def __init__( def domain(self) -> ConstraintT: # type: ignore[override] return constraints.independent( self.base_transform.domain, self.reinterpreted_batch_ndims - ) + ) # type: ignore[return-value] @property def codomain(self) -> ConstraintT: # type: ignore[override] return constraints.independent( self.base_transform.codomain, self.reinterpreted_batch_ndims - ) + ) # type: ignore[return-value] def __call__(self, x: StrictArrayT) -> StrictArrayT: return self.base_transform(x) @@ -681,7 +686,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Optional[Any] = None, ) -> StrictArrayT: result = self.base_transform.log_abs_det_jacobian( x, y, intermediates=intermediates @@ -713,7 +718,7 @@ def __eq__(self, other: object) -> bool: return False return (self.base_transform == other.base_transform) & ( self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims - ) + ) # type: ignore[return-value] class L1BallTransform(ParameterFreeTransform): @@ -751,7 +756,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: # compute stick-breaking logdet # t1 -> t1 @@ -793,7 +798,7 @@ class LowerCholeskyAffine(Transform): domain = constraints.real_vector codomain = constraints.real_vector - def __init__(self, loc: ArrayLike, scale_tril: Array): + def __init__(self, loc: StrictArrayT, scale_tril: StrictArrayT): if jnp.ndim(scale_tril) != 2: raise ValueError( "Only support 2-dimensional scale_tril matrix. " @@ -818,8 +823,8 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, x: StrictArrayT, - y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + y: Union[StrictArrayT, UnusedParam], + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: return jnp.broadcast_to( jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1), @@ -844,7 +849,7 @@ def __eq__(self, other: object) -> bool: return False return jnp.array_equal(self.loc, other.loc) & jnp.array_equal( self.scale_tril, other.scale_tril - ) + ) # type: ignore[return-value] class LowerCholeskyTransform(ParameterFreeTransform): @@ -871,7 +876,10 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: ) def log_abs_det_jacobian( - self, x: StrictArrayT, y: StrictArrayT, intermediates=None + self, + x: StrictArrayT, + y: StrictArrayT, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: # the jacobian is diagonal, so logdet is the sum of diagonal `exp` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) @@ -904,7 +912,7 @@ def __call__(self, x: StrictArrayT) -> StrictArrayT: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = softplus(x[..., -n:]) - return add_diag(z, 1) * diag[..., None] + return add_diag(z, 1) * diag[..., None] # type: ignore[arg-type] def _inverse(self, y: StrictArrayT) -> StrictArrayT: diag = jnp.diagonal(y, axis1=-2, axis2=-1) @@ -915,7 +923,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) diag = x[..., -n:] @@ -959,7 +967,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: return jnp.sum(x[..., 1:], -1) @@ -987,7 +995,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: return jnp.full(jnp.shape(x)[:-1], 0.0) @@ -997,7 +1005,7 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, PermuteTransform): return False - return jnp.array_equal(self.permutation, other.permutation) + return jnp.array_equal(self.permutation, other.permutation) # type: ignore[return-value] class PowerTransform(Transform): @@ -1017,7 +1025,7 @@ def log_abs_det_jacobian( self, x: ArrayLike, y: ArrayLike, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> ArrayLike: return jnp.log(jnp.abs(self.exponent * y / x)) @@ -1033,7 +1041,7 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, PowerTransform): return False - return jnp.array_equal(self.exponent, other.exponent) + return jnp.array_equal(self.exponent, other.exponent) # type: ignore[return-value] @property def sign(self) -> ArrayLike: @@ -1044,19 +1052,19 @@ class SigmoidTransform(ParameterFreeTransform): codomain = constraints.unit_interval sign = 1 - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: ArrayLike) -> ArrayLike: return _clipped_expit(x) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: ArrayLike) -> ArrayLike: return logit(y) def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, - ) -> StrictArrayT: - return -softplus(x) - softplus(-x) + x: ArrayLike, + y: Union[ArrayLike, UnusedParam], + intermediates: Union[Any, None, UnusedParam] = None, + ) -> ArrayLike: + return -softplus(x) - softplus(-x) # type: ignore[operator] class SimplexToOrderedTransform(Transform): @@ -1108,9 +1116,9 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, - x: StrictArrayT, + x: Union[StrictArrayT, UnusedParam], y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: # |dp/dc| = |dx/dy| = prod(ds/dy) = prod(expit'(y)) # we know log derivative of expit(y) is `-softplus(y) - softplus(-y)` @@ -1123,7 +1131,7 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, SimplexToOrderedTransform): return False - return jnp.array_equal(self.anchor_point, other.anchor_point) + return jnp.array_equal(self.anchor_point, other.anchor_point) # type: ignore[return-value] def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (shape[-1] - 1,) @@ -1133,7 +1141,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def _softplus_inv(y: ArrayLike) -> ArrayLike: - return jnp.log(-jnp.expm1(-y)) + y + return jnp.log(-jnp.expm1(-y)) + y # type: ignore[operator] class SoftplusTransform(ParameterFreeTransform): @@ -1146,19 +1154,19 @@ class SoftplusTransform(ParameterFreeTransform): codomain = constraints.softplus_positive sign = 1 - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: ArrayLike) -> ArrayLike: return softplus(x) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: ArrayLike) -> ArrayLike: return _softplus_inv(y) def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, - ) -> StrictArrayT: - return -softplus(-x) + x: ArrayLike, + y: Union[ArrayLike, UnusedParam], + intermediates: Union[Any, None, UnusedParam] = None, + ) -> ArrayLike: + return -softplus(-x) # type: ignore[operator] class SoftplusLowerCholeskyTransform(ParameterFreeTransform): @@ -1185,8 +1193,8 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, x: StrictArrayT, - y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + y: Union[StrictArrayT, UnusedParam] = UnusedParam(), + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: # the jacobian is diagonal, so logdet is the sum of diagonal # `softplus` transform @@ -1231,7 +1239,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html # |det|(J) = Product(y * (1 - sigmoid(x))) @@ -1299,7 +1307,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: return jnp.zeros(jnp.shape(x)[:-1]) @@ -1355,11 +1363,11 @@ def __init__( @property def domain(self) -> ConstraintT: # type: ignore[override] - return constraints.independent(constraints.real, len(self._inverse_shape)) + return constraints.independent(constraints.real, len(self._inverse_shape)) # type: ignore[return-value] @property def codomain(self) -> ConstraintT: # type: ignore[override] - return constraints.independent(constraints.real, len(self._forward_shape)) + return constraints.independent(constraints.real, len(self._forward_shape)) # type: ignore[return-value] def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return _get_target_shape(shape, self._forward_shape, self._inverse_shape) @@ -1377,7 +1385,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: return jnp.zeros_like(x, shape=x.shape[: x.ndim - len(self._inverse_shape)]) @@ -1398,7 +1406,7 @@ def __eq__(self, other: object) -> bool: def _normalize_rfft_shape( input_shape: tuple[int, ...], - shape: tuple[int, ...], + shape: Optional[tuple[int, ...]], ) -> tuple[int, ...]: if shape is None: return input_shape @@ -1452,7 +1460,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: batch_shape = jnp.broadcast_shapes( x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] @@ -1473,11 +1481,11 @@ def tree_flatten(self): @property def domain(self) -> ConstraintT: # type: ignore[override] - return constraints.independent(constraints.real, self.transform_ndims) + return constraints.independent(constraints.real, self.transform_ndims) # type: ignore[return-value] @property def codomain(self) -> ConstraintT: # type: ignore[override] - return constraints.independent(constraints.complex, self.transform_ndims) + return constraints.independent(constraints.complex, self.transform_ndims) # type: ignore[return-value] def __eq__(self, other: object) -> bool: return ( @@ -1495,13 +1503,13 @@ class PackRealFastFourierCoefficientsTransform(Transform): """ domain = constraints.real_vector - codomain = constraints.independent(constraints.complex, 1) + codomain = constraints.independent(constraints.complex, 1) # type: ignore[assignment] def __init__(self, transform_shape: Optional[tuple[int, ...]] = None) -> None: assert transform_shape is None or len(transform_shape) == 1, ( "Packing Fourier coefficients is only implemented for vectors." ) - self.shape = transform_shape + self.shape: tuple[int, ...] = transform_shape # type: ignore[assignment] def tree_flatten(self): return (), ((), {"shape": self.shape}) @@ -1527,7 +1535,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> Array: shape = jnp.broadcast_shapes(x.shape[:-1], y.shape[:-1]) return jnp.zeros_like(x, shape=shape) @@ -1674,8 +1682,8 @@ def f(y, prev): def log_abs_det_jacobian( self, x: StrictArrayT, - y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + y: Union[StrictArrayT, UnusedParam] = UnusedParam(), + intermediates: Union[Any, None, UnusedParam] = None, ) -> StrictArrayT: return jnp.zeros_like(x, shape=x.shape[:-2]) @@ -1707,19 +1715,19 @@ def __init__(self, transform_ndims: int = 1) -> None: @property def domain(self) -> ConstraintT: # type: ignore[override] - return constraints.independent(constraints.real, self.transform_ndims) + return constraints.independent(constraints.real, self.transform_ndims) # type: ignore[return-value] @property def codomain(self) -> ConstraintT: # type: ignore[override] - return constraints.zero_sum(self.transform_ndims) + return constraints.zero_sum(self.transform_ndims) # type: ignore[return-value] - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: Array) -> Array: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: x = self.extend_axis(x, axis=axis) return x - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: Array) -> Array: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: y = self.extend_axis_rev(y, axis=axis) @@ -1750,7 +1758,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[ArrayLike] = None, + intermediates: Union[Any, None, UnusedParam] = None, ) -> jnp.ndarray: shape = jnp.broadcast_shapes( x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] @@ -1790,12 +1798,15 @@ def __call__(self, x: StrictArrayT) -> StrictArrayT: assert x.shape[-1] == 2, "Input must have a trailing dimension of size 2." return lax.complex(x[..., 0], x[..., 1]) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: ArrayLike) -> Array: return jnp.stack([y.real, y.imag], axis=-1) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates: Optional[ArrayLike] = None - ) -> ArrayLike: + self, + x: Union[ArrayLike, UnusedParam], + y: ArrayLike, + intermediates: Union[Any, None, UnusedParam] = None, + ) -> Array: return jnp.zeros_like(y) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: From 35e6d1b3dcac02af7d334be8e2ffd278efd8741a Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Sat, 30 Aug 2025 23:22:41 +0500 Subject: [PATCH 03/16] fix: improve type casting in Transform class and ensure proper handling of inverse transforms Co-authored-by: Juan Orduz --- numpyro/distributions/transforms.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 78a2af11d..98c002401 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -3,7 +3,7 @@ import math -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Union, cast import warnings import weakref @@ -76,13 +76,12 @@ def __init_subclass__(cls, **kwargs): @property def inv(self: TransformT) -> TransformT: - # TODO: can not understand the implementation (type wise) inv = None if self._inv is not None: - inv = self._inv() + inv = self._inv if inv is None: - inv: TransformT = _InverseTransform(self) - self._inv: TransformT = weakref.ref(inv) + inv = cast(TransformT, _InverseTransform(self)) + self._inv = cast(TransformT, weakref.ref(inv)) return inv def __call__(self, x: Union[Array, Any]) -> Union[Array, Any]: @@ -1547,7 +1546,7 @@ def __call__(self, x: StrictArrayT) -> StrictArrayT: n_imag = n - n_real complex_dtype = jnp.result_type(x.dtype, jnp.complex64) return ( - x[..., :n_real] + jnp.asarray(x)[..., :n_real] .astype(complex_dtype) .at[..., 1 : 1 + n_imag] .add(1j * x[..., n_real:]) From 23407d324357269e7aec31c18db05c501a130a36 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Sat, 30 Aug 2025 23:24:05 +0500 Subject: [PATCH 04/16] chore: remove `numpy.typing.mypy_plugin` from mypy configuration Co-authored-by: Juan Orduz --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 104c746ac..47925bf7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,7 +116,6 @@ doctest_optionflags = [ [tool.mypy] ignore_errors = true ignore_missing_imports = true -plugins = ["numpy.typing.mypy_plugin"] [[tool.mypy.overrides]] module = [ From 300fead9c44f45495d6b07aab5c2aef136ef7316 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Thu, 4 Sep 2025 16:55:28 +0500 Subject: [PATCH 05/16] fix: update `_inv` type to support weak references in `Transform` class Co-authored-by: Juan Orduz --- numpyro/distributions/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 98c002401..0f209b622 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -68,7 +68,7 @@ def _clipped_expit(x: ArrayLike) -> ArrayLike: class Transform(object): domain = constraints.real codomain = constraints.real - _inv: Optional[TransformT] = None + _inv: Optional[Union[TransformT, weakref.ReferenceType]] = None def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) @@ -77,8 +77,8 @@ def __init_subclass__(cls, **kwargs): @property def inv(self: TransformT) -> TransformT: inv = None - if self._inv is not None: - inv = self._inv + if (self._inv is not None) and isinstance(self._inv, weakref.ReferenceType): + inv = self._inv() if inv is None: inv = cast(TransformT, _InverseTransform(self)) self._inv = cast(TransformT, weakref.ref(inv)) From 8aa2592d656a2eeaf47765c3c5c7d0a763725dbe Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Sat, 6 Sep 2025 12:36:02 +0500 Subject: [PATCH 06/16] fix: update `intermediates` type to `Optional[PyTree]` in Transform classes and add PyTree type alias --- numpyro/_typing.py | 11 ++-- numpyro/distributions/transforms.py | 80 ++++++++++++++--------------- 2 files changed, 45 insertions(+), 46 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 38aea907d..9c6b2d9a4 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -26,6 +26,10 @@ StrictArrayT = Union[np.ndarray, jax.Array] +PyTree: TypeAlias = Any +"""A generic type for a pytree, i.e. a nested structure of lists, tuples, dicts, and arrays.""" + + @runtime_checkable class ConstraintT(Protocol): is_discrete: bool = ... @@ -102,7 +106,7 @@ def log_abs_det_jacobian( self, x: Union[jax.Array, Any], y: Union[jax.Array, Any], - intermediates: Optional[Any] = None, + intermediates: Optional[PyTree] = None, ) -> Union[jax.Array, Any]: ... def call_with_intermediates( self, x: Union[jax.Array, Optional[Any]] @@ -114,8 +118,3 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... def inv(self) -> "TransformT": ... @property def sign(self) -> Union[ArrayLike, Any]: ... - - -class UnusedParam(object): - def __repr__(self): - return "UnusedParam" diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 0f209b622..7b4f359ff 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -18,7 +18,7 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, StrictArrayT, TransformT, UnusedParam +from numpyro._typing import ConstraintT, PyTree, StrictArrayT, TransformT from numpyro.distributions import constraints from numpyro.distributions.util import ( add_diag, @@ -94,7 +94,7 @@ def log_abs_det_jacobian( self, x: Union[Array, Any], y: Union[Array, Any], - intermediates: Optional[Any] = None, + intermediates: Optional[PyTree] = None, ) -> Union[Array, Any]: raise NotImplementedError @@ -184,7 +184,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: # NB: we don't use intermediates for inverse transform return -self._inv.log_abs_det_jacobian(y, x, None) @@ -280,8 +280,8 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: def log_abs_det_jacobian( self, x: ArrayLike, - y: Union[ArrayLike, UnusedParam], - intermediates: Union[Any, None, UnusedParam] = None, + y: ArrayLike, + intermediates: Optional[PyTree] = None, ) -> ArrayLike: return jnp.broadcast_to(jnp.log(jnp.abs(self.scale)), jnp.shape(x)) @@ -375,7 +375,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[Any] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: if intermediates is not None: if len(intermediates) != len(self.parts): @@ -473,7 +473,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13 n = jnp.shape(x)[-1] @@ -536,7 +536,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: # NB: because domain and codomain are two spaces with different dimensions, determinant of # Jacobian is not well-defined. Here we return `log_abs_det_jacobian` of `x` and the @@ -572,7 +572,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: # NB: see derivation in LKJCholesky implementation n = jnp.shape(x)[-1] @@ -614,8 +614,8 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: def log_abs_det_jacobian( self, x: ArrayLike, - y: Union[ArrayLike, UnusedParam], - intermediates: Union[Any, None, UnusedParam] = None, + y: ArrayLike, + intermediates: Optional[PyTree] = None, ) -> ArrayLike: return x @@ -640,8 +640,8 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, x: StrictArrayT, - y: Union[StrictArrayT, UnusedParam] = UnusedParam(), - intermediates: Union[Any, None, UnusedParam] = None, + y: StrictArrayT, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: return jnp.zeros_like(x) @@ -685,7 +685,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Optional[Any] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: result = self.base_transform.log_abs_det_jacobian( x, y, intermediates=intermediates @@ -755,7 +755,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: # compute stick-breaking logdet # t1 -> t1 @@ -822,8 +822,8 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, x: StrictArrayT, - y: Union[StrictArrayT, UnusedParam], - intermediates: Union[Any, None, UnusedParam] = None, + y: StrictArrayT, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: return jnp.broadcast_to( jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1), @@ -878,7 +878,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: # the jacobian is diagonal, so logdet is the sum of diagonal `exp` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) @@ -922,7 +922,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) diag = x[..., -n:] @@ -966,7 +966,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: return jnp.sum(x[..., 1:], -1) @@ -994,7 +994,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: return jnp.full(jnp.shape(x)[:-1], 0.0) @@ -1024,7 +1024,7 @@ def log_abs_det_jacobian( self, x: ArrayLike, y: ArrayLike, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> ArrayLike: return jnp.log(jnp.abs(self.exponent * y / x)) @@ -1060,8 +1060,8 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: def log_abs_det_jacobian( self, x: ArrayLike, - y: Union[ArrayLike, UnusedParam], - intermediates: Union[Any, None, UnusedParam] = None, + y: ArrayLike, + intermediates: Optional[PyTree] = None, ) -> ArrayLike: return -softplus(x) - softplus(-x) # type: ignore[operator] @@ -1115,9 +1115,9 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, - x: Union[StrictArrayT, UnusedParam], + x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: # |dp/dc| = |dx/dy| = prod(ds/dy) = prod(expit'(y)) # we know log derivative of expit(y) is `-softplus(y) - softplus(-y)` @@ -1162,8 +1162,8 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: def log_abs_det_jacobian( self, x: ArrayLike, - y: Union[ArrayLike, UnusedParam], - intermediates: Union[Any, None, UnusedParam] = None, + y: ArrayLike, + intermediates: Optional[PyTree] = None, ) -> ArrayLike: return -softplus(-x) # type: ignore[operator] @@ -1192,8 +1192,8 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, x: StrictArrayT, - y: Union[StrictArrayT, UnusedParam] = UnusedParam(), - intermediates: Union[Any, None, UnusedParam] = None, + y: StrictArrayT, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: # the jacobian is diagonal, so logdet is the sum of diagonal # `softplus` transform @@ -1238,7 +1238,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html # |det|(J) = Product(y * (1 - sigmoid(x))) @@ -1306,7 +1306,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: return jnp.zeros(jnp.shape(x)[:-1]) @@ -1384,7 +1384,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: return jnp.zeros_like(x, shape=x.shape[: x.ndim - len(self._inverse_shape)]) @@ -1459,7 +1459,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: batch_shape = jnp.broadcast_shapes( x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] @@ -1534,7 +1534,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> Array: shape = jnp.broadcast_shapes(x.shape[:-1], y.shape[:-1]) return jnp.zeros_like(x, shape=shape) @@ -1681,8 +1681,8 @@ def f(y, prev): def log_abs_det_jacobian( self, x: StrictArrayT, - y: Union[StrictArrayT, UnusedParam] = UnusedParam(), - intermediates: Union[Any, None, UnusedParam] = None, + y: StrictArrayT, + intermediates: Optional[PyTree] = None, ) -> StrictArrayT: return jnp.zeros_like(x, shape=x.shape[:-2]) @@ -1757,7 +1757,7 @@ def log_abs_det_jacobian( self, x: StrictArrayT, y: StrictArrayT, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> jnp.ndarray: shape = jnp.broadcast_shapes( x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] @@ -1802,9 +1802,9 @@ def _inverse(self, y: ArrayLike) -> Array: def log_abs_det_jacobian( self, - x: Union[ArrayLike, UnusedParam], + x: ArrayLike, y: ArrayLike, - intermediates: Union[Any, None, UnusedParam] = None, + intermediates: Optional[PyTree] = None, ) -> Array: return jnp.zeros_like(y) From 1e2e67090dd80d4cf7c3b69d4fefcdbc9e2d3f08 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Mon, 22 Sep 2025 02:25:58 +0500 Subject: [PATCH 07/16] fix: update type hints in `Transform` classes to use `NonScalarArray` and `NumLike` --- numpyro/_typing.py | 25 +- numpyro/distributions/transforms.py | 340 ++++++++++++++-------------- 2 files changed, 185 insertions(+), 180 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 9c6b2d9a4..6ba07a039 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -22,8 +22,13 @@ Message: TypeAlias = dict[str, Any] TraceT: TypeAlias = OrderedDict[str, Message] -# ArrayLike type has StaticScalar, StrictArrayT has everything except StaticScalars -StrictArrayT = Union[np.ndarray, jax.Array] + +NonScalarArray = Union[np.ndarray, jax.Array] +"""An alias for array-like types excluding scalars.""" + + +NumLike = Union[NonScalarArray, np.number, int, float, complex] +"""An alias for array-like types excluding `np.bool_` and `bool`.""" PyTree: TypeAlias = Any @@ -100,21 +105,19 @@ class TransformT(Protocol): codomain: ConstraintT = ... _inv: Optional["TransformT"] = ... - def __call__(self, x: Union[jax.Array, Any]) -> Union[jax.Array, Any]: ... - def _inverse(self, y: Union[jax.Array, Any]) -> Union[jax.Array, Any]: ... + def __call__(self, x: NumLike) -> NumLike: ... + def _inverse(self, y: NumLike) -> NumLike: ... def log_abs_det_jacobian( self, - x: Union[jax.Array, Any], - y: Union[jax.Array, Any], + x: NumLike, + y: NumLike, intermediates: Optional[PyTree] = None, - ) -> Union[jax.Array, Any]: ... - def call_with_intermediates( - self, x: Union[jax.Array, Optional[Any]] - ) -> tuple[Union[jax.Array, Any], Any]: ... + ) -> NumLike: ... + def call_with_intermediates(self, x: NumLike) -> tuple[NumLike, PyTree]: ... def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... @property def inv(self) -> "TransformT": ... @property - def sign(self) -> Union[ArrayLike, Any]: ... + def sign(self) -> NumLike: ... diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 7b4f359ff..a94702f5d 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -18,7 +18,7 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, PyTree, StrictArrayT, TransformT +from numpyro._typing import ConstraintT, NonScalarArray, NumLike, PyTree, TransformT from numpyro.distributions import constraints from numpyro.distributions.util import ( add_diag, @@ -60,7 +60,7 @@ ] -def _clipped_expit(x: ArrayLike) -> ArrayLike: +def _clipped_expit(x: NumLike) -> NumLike: finfo = jnp.finfo(jnp.result_type(x)) return jnp.clip(expit(x), finfo.tiny, 1.0 - finfo.eps) @@ -84,23 +84,23 @@ def inv(self: TransformT) -> TransformT: self._inv = cast(TransformT, weakref.ref(inv)) return inv - def __call__(self, x: Union[Array, Any]) -> Union[Array, Any]: + def __call__(self, x: Union[NonScalarArray, Any]) -> Union[NonScalarArray, Any]: raise NotImplementedError - def _inverse(self, y: Union[Array, Any]) -> Union[Array, Any]: + def _inverse(self, y: Union[NonScalarArray, Any]) -> Union[NonScalarArray, Any]: raise NotImplementedError def log_abs_det_jacobian( self, - x: Union[Array, Any], - y: Union[Array, Any], + x: Union[NonScalarArray, Any], + y: Union[NonScalarArray, Any], intermediates: Optional[PyTree] = None, - ) -> Union[Array, Any]: + ) -> Union[NonScalarArray, Any]: raise NotImplementedError def call_with_intermediates( - self, x: Union[Array, Any] - ) -> Tuple[Union[Array, Any], Optional[Union[Array, Any]]]: + self, x: Union[NonScalarArray, Any] + ) -> Tuple[Union[NonScalarArray, Any], Optional[PyTree]]: return self(x), None def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -118,7 +118,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape @property - def sign(self) -> Union[Array, Any]: + def sign(self) -> NumLike: """ Sign of the derivative of the transform if it is bijective. """ @@ -170,22 +170,22 @@ def codomain(self) -> ConstraintT: # type: ignore[override] return self._inv.domain @property - def sign(self) -> ArrayLike: + def sign(self) -> NumLike: return self._inv.sign @property def inv(self) -> TransformT: return self._inv - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NumLike) -> NumLike: return self._inv._inverse(x) def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NumLike, + y: NumLike, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NumLike: # NB: we don't use intermediates for inverse transform return -self._inv.log_abs_det_jacobian(y, x, None) @@ -211,10 +211,10 @@ class AbsTransform(ParameterFreeTransform): def __eq__(self, other: object) -> bool: return isinstance(other, AbsTransform) - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return jnp.abs(x) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: warnings.warn( "AbsTransform is not a bijective transform." " The inverse of `y` will be `y`.", @@ -251,38 +251,40 @@ def codomain(self) -> ConstraintT: # type: ignore[override] return constraints.greater_than(self(self.domain.lower_bound)) elif isinstance(self.domain, constraints.less_than): if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)): - return constraints.greater_than(self(self.domain.upper_bound)) + return constraints.greater_than(self(self.domain.upper_bound)) # type: ignore[arg-type] # we suppose scale > 0 for any tracer else: - return constraints.less_than(self(self.domain.upper_bound)) + return constraints.less_than(self(self.domain.upper_bound)) # type: ignore[arg-type] elif isinstance(self.domain, constraints.interval): if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)): - return constraints.interval( - self(self.domain.upper_bound), self(self.domain.lower_bound) + return constraints.interval( # type: ignore[arg-type] + self(self.domain.upper_bound), # type: ignore[arg-type] + self(self.domain.lower_bound), # type: ignore[arg-type] ) else: - return constraints.interval( - self(self.domain.lower_bound), self(self.domain.upper_bound) + return constraints.interval( # type: ignore[arg-type] + self(self.domain.lower_bound), # type: ignore[arg-type] + self(self.domain.upper_bound), # type: ignore[arg-type] ) else: raise NotImplementedError @property - def sign(self) -> ArrayLike: + def sign(self) -> NumLike: return jnp.sign(self.scale) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return self.loc + self.scale * x - def _inverse(self, y: ArrayLike) -> ArrayLike: - return (y - self.loc) / self.scale # type: ignore[call-overload,operator] + def _inverse(self, y: NumLike) -> NumLike: + return (y - self.loc) / self.scale def log_abs_det_jacobian( self, - x: ArrayLike, - y: ArrayLike, + x: NumLike, + y: NumLike, intermediates: Optional[PyTree] = None, - ) -> ArrayLike: + ) -> NumLike: return jnp.broadcast_to(jnp.log(jnp.abs(self.scale)), jnp.shape(x)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -355,28 +357,28 @@ def codomain(self) -> ConstraintT: # type: ignore[override] ) # type: ignore[return-value] @property - def sign(self) -> ArrayLike: - sign: ArrayLike = 1 + def sign(self) -> NumLike: + sign: NumLike = 1 for transform in self.parts: sign *= transform.sign return sign - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NumLike) -> NumLike: for part in self.parts: x = part(x) return x - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NumLike) -> NumLike: for part in self.parts[::-1]: y = part.inv(y) return y def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NumLike, + y: NumLike, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NumLike: if intermediates is not None: if len(intermediates) != len(self.parts): raise ValueError( @@ -385,7 +387,7 @@ def log_abs_det_jacobian( ) ) - result = jnp.zeros(()) + result = 0.0 input_event_dim = self.domain.event_dim for i, part in enumerate(self.parts[:-1]): y_tmp = part(x) if intermediates is None else intermediates[i][0] @@ -402,7 +404,7 @@ def log_abs_det_jacobian( result = result + sum_rightmost(logdet, input_event_dim - part.domain.event_dim) return result - def call_with_intermediates(self, x: ArrayLike) -> Tuple[ArrayLike, Sequence]: + def call_with_intermediates(self, x: NumLike) -> Tuple[NumLike, PyTree]: intermediates = [] for part in self.parts[:-1]: x, inter = part.call_with_intermediates(x) @@ -463,18 +465,18 @@ class CholeskyTransform(ParameterFreeTransform): domain = constraints.positive_definite codomain = constraints.lower_cholesky - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return jnp.linalg.cholesky(x) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: return jnp.matmul(y, jnp.swapaxes(y, -2, -1)) def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13 n = jnp.shape(x)[-1] order = -jnp.arange(n, 0, -1) @@ -513,12 +515,12 @@ class :class:`StickBreakingTransform` to transform :math:`X_i` into a domain = constraints.real_vector codomain = constraints.corr_cholesky - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: # we interchange step 1 and step 2.a for a better performance t = jnp.tanh(x) return signed_stick_breaking_tril(t) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: # inverse stick-breaking z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1) pad_width = [(0, 0)] * y.ndim @@ -534,10 +536,10 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: # NB: because domain and codomain are two spaces with different dimensions, determinant of # Jacobian is not well-defined. Here we return `log_abs_det_jacobian` of `x` and the # flatten lower triangular part of `y`. @@ -570,10 +572,10 @@ class CorrMatrixCholeskyTransform(CholeskyTransform): def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: # NB: see derivation in LKJCholesky implementation n = jnp.shape(x)[-1] order = -jnp.arange(n - 1, -1, -1) @@ -597,26 +599,26 @@ def codomain(self) -> ConstraintT: # type: ignore[override] elif isinstance(self.domain, constraints.greater_than): return constraints.greater_than(self.__call__(self.domain.lower_bound)) elif isinstance(self.domain, constraints.interval): - return constraints.interval( - self.__call__(self.domain.lower_bound), - self.__call__(self.domain.upper_bound), + return constraints.interval( # type: ignore[arg-type] + self.__call__(self.domain.lower_bound), # type: ignore[arg-type] + self.__call__(self.domain.upper_bound), # type: ignore[arg-type] ) else: raise NotImplementedError - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: # XXX consider to clamp from below for stability if necessary return jnp.exp(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return jnp.log(y) def log_abs_det_jacobian( self, - x: ArrayLike, - y: ArrayLike, + x: NumLike, + y: NumLike, intermediates: Optional[PyTree] = None, - ) -> ArrayLike: + ) -> NumLike: return x def tree_flatten(self): @@ -631,18 +633,18 @@ def __eq__(self, other: object) -> bool: class IdentityTransform(ParameterFreeTransform): sign = 1 - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return x - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: return y def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: return jnp.zeros_like(x) @@ -675,18 +677,18 @@ def codomain(self) -> ConstraintT: # type: ignore[override] self.base_transform.codomain, self.reinterpreted_batch_ndims ) # type: ignore[return-value] - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return self.base_transform(x) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: return self.base_transform._inverse(y) def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: result = self.base_transform.log_abs_det_jacobian( x, y, intermediates=intermediates ) @@ -728,7 +730,7 @@ class L1BallTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.l1_ball - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: # transform to (-1, 1) interval t = jnp.tanh(x) @@ -738,7 +740,7 @@ def __call__(self, x: StrictArrayT) -> StrictArrayT: remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0) return t * remainder - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: # inverse stick-breaking remainder = 1 - jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1) pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)] @@ -753,10 +755,10 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: # compute stick-breaking logdet # t1 -> t1 # t2 -> t2 * (1 - abs(t1)) @@ -797,7 +799,7 @@ class LowerCholeskyAffine(Transform): domain = constraints.real_vector codomain = constraints.real_vector - def __init__(self, loc: StrictArrayT, scale_tril: StrictArrayT): + def __init__(self, loc: NonScalarArray, scale_tril: NonScalarArray): if jnp.ndim(scale_tril) != 2: raise ValueError( "Only support 2-dimensional scale_tril matrix. " @@ -807,12 +809,12 @@ def __init__(self, loc: StrictArrayT, scale_tril: StrictArrayT): self.loc = loc self.scale_tril = scale_tril - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return self.loc + jnp.squeeze( jnp.matmul(self.scale_tril, x[..., jnp.newaxis]), axis=-1 ) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: y = y - self.loc original_shape = jnp.shape(y) yt = jnp.reshape(y, (-1, original_shape[-1])).T @@ -821,10 +823,10 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: return jnp.broadcast_to( jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1), jnp.shape(x)[:-1], @@ -862,13 +864,13 @@ class LowerCholeskyTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.lower_cholesky - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = jnp.exp(x[..., -n:]) return add_diag(z, diag) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: z = matrix_to_tril_vec(y, diagonal=-1) return jnp.concatenate( [z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1 @@ -876,10 +878,10 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: # the jacobian is diagonal, so logdet is the sum of diagonal `exp` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) return x[..., -n:].sum(-1) @@ -907,23 +909,23 @@ class ScaledUnitLowerCholeskyTransform(LowerCholeskyTransform): domain = constraints.real_vector codomain = constraints.scaled_unit_lower_cholesky - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = softplus(x[..., -n:]) return add_diag(z, 1) * diag[..., None] # type: ignore[arg-type] - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: diag = jnp.diagonal(y, axis1=-2, axis2=-1) z = matrix_to_tril_vec(y / diag[..., None], diagonal=-1) return jnp.concatenate([z, _softplus_inv(diag)], axis=-1) def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) diag = x[..., -n:] diag_softplus = jnp.diagonal(y, axis1=-2, axis2=-1) @@ -954,20 +956,20 @@ class OrderedTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.ordered_vector - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: z = jnp.concatenate([x[..., :1], jnp.exp(x[..., 1:])], axis=-1) return jnp.cumsum(z, axis=-1) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: x = jnp.log(y[..., 1:] - y[..., :-1]) return jnp.concatenate([y[..., :1], x], axis=-1) def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: return jnp.sum(x[..., 1:], -1) @@ -978,12 +980,12 @@ class PermuteTransform(Transform): def __init__(self, permutation: Array) -> None: self.permutation = permutation - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return x[..., self.permutation] - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: size = self.permutation.size - permutation_inv: StrictArrayT = ( + permutation_inv: NonScalarArray = ( jnp.zeros(size, dtype=jnp.result_type(int)) .at[self.permutation] .set(jnp.arange(size)) @@ -992,10 +994,10 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: return jnp.full(jnp.shape(x)[:-1], 0.0) def tree_flatten(self): @@ -1014,18 +1016,18 @@ class PowerTransform(Transform): def __init__(self, exponent: ArrayLike) -> None: self.exponent = exponent - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return jnp.power(x, self.exponent) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return jnp.power(y, 1 / self.exponent) def log_abs_det_jacobian( self, - x: ArrayLike, - y: ArrayLike, + x: NumLike, + y: NumLike, intermediates: Optional[PyTree] = None, - ) -> ArrayLike: + ) -> NumLike: return jnp.log(jnp.abs(self.exponent * y / x)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -1043,7 +1045,7 @@ def __eq__(self, other: object) -> bool: return jnp.array_equal(self.exponent, other.exponent) # type: ignore[return-value] @property - def sign(self) -> ArrayLike: + def sign(self) -> NumLike: return jnp.sign(self.exponent) @@ -1051,18 +1053,18 @@ class SigmoidTransform(ParameterFreeTransform): codomain = constraints.unit_interval sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return _clipped_expit(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return logit(y) def log_abs_det_jacobian( self, - x: ArrayLike, - y: ArrayLike, + x: NumLike, + y: NumLike, intermediates: Optional[PyTree] = None, - ) -> ArrayLike: + ) -> NumLike: return -softplus(x) - softplus(-x) # type: ignore[operator] @@ -1098,12 +1100,12 @@ class SimplexToOrderedTransform(Transform): def __init__(self, anchor_point: ArrayLike = 0.0) -> None: self.anchor_point = anchor_point - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: s = jnp.cumsum(x[..., :-1], axis=-1) y = logit(s) + jnp.expand_dims(self.anchor_point, -1) return y - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: y = y - jnp.expand_dims(self.anchor_point, -1) s = expit(y) # x0 = s0, x1 = s1 - s0, x2 = s2 - s1,..., xn = 1 - s[n-1] @@ -1115,10 +1117,10 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: # |dp/dc| = |dx/dy| = prod(ds/dy) = prod(expit'(y)) # we know log derivative of expit(y) is `-softplus(y) - softplus(-y)` J_logdet = (softplus(y) + softplus(-y)).sum(-1) @@ -1139,7 +1141,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (shape[-1] + 1,) -def _softplus_inv(y: ArrayLike) -> ArrayLike: +def _softplus_inv(y: ArrayLike) -> NumLike: return jnp.log(-jnp.expm1(-y)) + y # type: ignore[operator] @@ -1153,18 +1155,18 @@ class SoftplusTransform(ParameterFreeTransform): codomain = constraints.softplus_positive sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return softplus(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return _softplus_inv(y) def log_abs_det_jacobian( self, - x: ArrayLike, - y: ArrayLike, + x: NumLike, + y: NumLike, intermediates: Optional[PyTree] = None, - ) -> ArrayLike: + ) -> NumLike: return -softplus(-x) # type: ignore[operator] @@ -1178,23 +1180,23 @@ class SoftplusLowerCholeskyTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.softplus_lower_cholesky - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = softplus(x[..., -n:]) return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: z = matrix_to_tril_vec(y, diagonal=-1) diag = _softplus_inv(jnp.diagonal(y, axis1=-2, axis2=-1)) return jnp.concatenate([z, diag], axis=-1) def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: # the jacobian is diagonal, so logdet is the sum of diagonal # `softplus` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) @@ -1211,7 +1213,7 @@ class StickBreakingTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.simplex - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: # we shift x to obtain a balanced mapping (0, 0, ..., 0) -> (1/K, 1/K, ..., 1/K) x = x - jnp.log(x.shape[-1] - jnp.arange(x.shape[-1])) # convert to probabilities (relative to the remaining) of each fraction of the stick @@ -1227,7 +1229,7 @@ def __call__(self, x: StrictArrayT) -> StrictArrayT: ) return z_padded * z1m_cumprod_shifted - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: y_crop = y[..., :-1] z1m_cumprod = jnp.clip(1 - jnp.cumsum(y_crop, axis=-1), jnp.finfo(y.dtype).tiny) # hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod @@ -1236,10 +1238,10 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html # |det|(J) = Product(y * (1 - sigmoid(x))) # = Product(y * sigmoid(x) * exp(-x)) @@ -1272,7 +1274,7 @@ def __init__(self, unpack_fn, pack_fn=None): self.unpack_fn = unpack_fn self.pack_fn = pack_fn - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: batch_shape = x.shape[:-1] if batch_shape: unpacked = vmap(self.unpack_fn)(x.reshape((-1,) + x.shape[-1:])) @@ -1282,7 +1284,7 @@ def __call__(self, x: StrictArrayT) -> StrictArrayT: else: return self.unpack_fn(x) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: if self.pack_fn is None: raise NotImplementedError( "pack_fn needs to be provided to perform UnpackTransform.inv." @@ -1304,10 +1306,10 @@ def _inverse(self, y: StrictArrayT) -> StrictArrayT: def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: return jnp.zeros(jnp.shape(x)[:-1]) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -1374,18 +1376,18 @@ def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return _get_target_shape(shape, self._inverse_shape, self._forward_shape) - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return jnp.reshape(x, self.forward_shape(jnp.shape(x))) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: return jnp.reshape(y, self.inverse_shape(jnp.shape(y))) def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: return jnp.zeros_like(x, shape=x.shape[: x.ndim - len(self._inverse_shape)]) def tree_flatten(self): @@ -1436,11 +1438,11 @@ def __init__( self.transform_shape = transform_shape self.transform_ndims = transform_ndims - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: NonScalarArray) -> NonScalarArray: axes = tuple(range(-self.transform_ndims, 0)) return jnp.fft.rfftn(x, self.transform_shape, axes) - def _inverse(self, y: jnp.ndarray) -> jnp.ndarray: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: axes = tuple(range(-self.transform_ndims, 0)) return jnp.fft.irfftn(y, self.transform_shape, axes) @@ -1457,10 +1459,10 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: batch_shape = jnp.broadcast_shapes( x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] ) @@ -1532,14 +1534,14 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, ) -> Array: shape = jnp.broadcast_shapes(x.shape[:-1], y.shape[:-1]) return jnp.zeros_like(x, shape=shape) - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: assert self.shape is None or self.shape == x.shape[-1:] n = x.shape[-1] n_real = n // 2 + 1 @@ -1552,7 +1554,7 @@ def __call__(self, x: StrictArrayT) -> StrictArrayT: .add(1j * x[..., n_real:]) ) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: (n,) = self.shape n_real = n // 2 + 1 n_imag = n - n_real @@ -1619,8 +1621,8 @@ class RecursiveLinearTransform(Transform): def __init__( self, - transition_matrix: StrictArrayT, - initial_value: Optional[StrictArrayT] = None, + transition_matrix: NonScalarArray, + initial_value: Optional[NonScalarArray] = None, ) -> None: event_shape = transition_matrix.shape[-1:] @@ -1648,7 +1650,7 @@ def _get_initial_value(self, sample_shape) -> Array: return jnp.broadcast_to(self.initial_value, batch_shape + event_shape) - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: # Move the time axis to the first position so we can scan over it. sample_shape = x.shape[:-2] x = jnp.moveaxis(x, -2, 0) @@ -1662,7 +1664,7 @@ def f(y, x): _, y = lax.scan(f, initial_value, x) return jnp.moveaxis(y, 0, -2) - def _inverse(self, y: StrictArrayT) -> StrictArrayT: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: # Move the time axis to the first position so we can scan over it in reverse. sample_shape = y.shape[:-2] y = jnp.moveaxis(y, -2, 0) @@ -1680,10 +1682,10 @@ def f(y, prev): def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> StrictArrayT: + ) -> NonScalarArray: return jnp.zeros_like(x, shape=x.shape[:-2]) def tree_flatten(self): @@ -1720,19 +1722,19 @@ def domain(self) -> ConstraintT: # type: ignore[override] def codomain(self) -> ConstraintT: # type: ignore[override] return constraints.zero_sum(self.transform_ndims) # type: ignore[return-value] - def __call__(self, x: Array) -> Array: + def __call__(self, x: NonScalarArray) -> NonScalarArray: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: x = self.extend_axis(x, axis=axis) return x - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: y = self.extend_axis_rev(y, axis=axis) return y - def extend_axis_rev(self, array: Array, axis: int) -> Array: + def extend_axis_rev(self, array: NonScalarArray, axis: int) -> NonScalarArray: normalized_axis = axis if axis >= 0 else jnp.ndim(array) + axis n = array.shape[normalized_axis] @@ -1743,7 +1745,7 @@ def extend_axis_rev(self, array: Array, axis: int) -> Array: slice_before = (slice(None, None),) * normalized_axis return array[(*slice_before, slice(None, -1))] + norm - def extend_axis(self, array: Array, axis: int) -> Array: + def extend_axis(self, array: NonScalarArray, axis: int) -> NonScalarArray: n = array.shape[axis] + 1 sum_vals = array.sum(axis, keepdims=True) @@ -1755,8 +1757,8 @@ def extend_axis(self, array: Array, axis: int) -> Array: def log_abs_det_jacobian( self, - x: StrictArrayT, - y: StrictArrayT, + x: NonScalarArray, + y: NonScalarArray, intermediates: Optional[PyTree] = None, ) -> jnp.ndarray: shape = jnp.broadcast_shapes( @@ -1793,7 +1795,7 @@ class ComplexTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.complex - def __call__(self, x: StrictArrayT) -> StrictArrayT: + def __call__(self, x: NonScalarArray) -> NonScalarArray: assert x.shape[-1] == 2, "Input must have a trailing dimension of size 2." return lax.complex(x[..., 0], x[..., 1]) @@ -1802,8 +1804,8 @@ def _inverse(self, y: ArrayLike) -> Array: def log_abs_det_jacobian( self, - x: ArrayLike, - y: ArrayLike, + x: NumLike, + y: NumLike, intermediates: Optional[PyTree] = None, ) -> Array: return jnp.zeros_like(y) From 2dbbce21626da93eb8eb4a27f5b7a42ae5ed8260 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Mon, 22 Sep 2025 03:27:53 +0500 Subject: [PATCH 08/16] fix: return type of `log_abs_det_jacobian` changed to `NumLike` --- numpyro/distributions/transforms.py | 36 ++++++++++++++--------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index a94702f5d..4ede31106 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -476,7 +476,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13 n = jnp.shape(x)[-1] order = -jnp.arange(n, 0, -1) @@ -539,7 +539,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: # NB: because domain and codomain are two spaces with different dimensions, determinant of # Jacobian is not well-defined. Here we return `log_abs_det_jacobian` of `x` and the # flatten lower triangular part of `y`. @@ -575,7 +575,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: # NB: see derivation in LKJCholesky implementation n = jnp.shape(x)[-1] order = -jnp.arange(n - 1, -1, -1) @@ -644,7 +644,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: return jnp.zeros_like(x) @@ -688,7 +688,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: result = self.base_transform.log_abs_det_jacobian( x, y, intermediates=intermediates ) @@ -758,7 +758,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: # compute stick-breaking logdet # t1 -> t1 # t2 -> t2 * (1 - abs(t1)) @@ -826,7 +826,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: return jnp.broadcast_to( jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1), jnp.shape(x)[:-1], @@ -881,7 +881,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: # the jacobian is diagonal, so logdet is the sum of diagonal `exp` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) return x[..., -n:].sum(-1) @@ -925,7 +925,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) diag = x[..., -n:] diag_softplus = jnp.diagonal(y, axis1=-2, axis2=-1) @@ -969,7 +969,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: return jnp.sum(x[..., 1:], -1) @@ -997,7 +997,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: return jnp.full(jnp.shape(x)[:-1], 0.0) def tree_flatten(self): @@ -1120,7 +1120,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: # |dp/dc| = |dx/dy| = prod(ds/dy) = prod(expit'(y)) # we know log derivative of expit(y) is `-softplus(y) - softplus(-y)` J_logdet = (softplus(y) + softplus(-y)).sum(-1) @@ -1196,7 +1196,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: # the jacobian is diagonal, so logdet is the sum of diagonal # `softplus` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) @@ -1241,7 +1241,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html # |det|(J) = Product(y * (1 - sigmoid(x))) # = Product(y * sigmoid(x) * exp(-x)) @@ -1309,7 +1309,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: return jnp.zeros(jnp.shape(x)[:-1]) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -1387,7 +1387,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: return jnp.zeros_like(x, shape=x.shape[: x.ndim - len(self._inverse_shape)]) def tree_flatten(self): @@ -1462,7 +1462,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: batch_shape = jnp.broadcast_shapes( x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] ) @@ -1685,7 +1685,7 @@ def log_abs_det_jacobian( x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None, - ) -> NonScalarArray: + ) -> NumLike: return jnp.zeros_like(x, shape=x.shape[:-2]) def tree_flatten(self): From c3d8bb25743a96e523348c1797f0b674450be3c7 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Mon, 22 Sep 2025 15:14:33 +0500 Subject: [PATCH 09/16] fix: update `call_with_intermediates` method to return `Optional[PyTree]` --- numpyro/_typing.py | 4 +++- numpyro/distributions/transforms.py | 8 +++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 6ba07a039..3be15dcb8 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -113,7 +113,9 @@ def log_abs_det_jacobian( y: NumLike, intermediates: Optional[PyTree] = None, ) -> NumLike: ... - def call_with_intermediates(self, x: NumLike) -> tuple[NumLike, PyTree]: ... + def call_with_intermediates( + self, x: NumLike + ) -> tuple[NumLike, Optional[PyTree]]: ... def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 4ede31106..6f187efd1 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -404,8 +404,8 @@ def log_abs_det_jacobian( result = result + sum_rightmost(logdet, input_event_dim - part.domain.event_dim) return result - def call_with_intermediates(self, x: NumLike) -> Tuple[NumLike, PyTree]: - intermediates = [] + def call_with_intermediates(self, x: NumLike) -> Tuple[NumLike, Optional[PyTree]]: + intermediates: list[Optional[PyTree]] = [] for part in self.parts[:-1]: x, inter = part.call_with_intermediates(x) intermediates.append([x, inter]) @@ -697,9 +697,7 @@ def log_abs_det_jacobian( raise ValueError(f"Expected x.dim() >= {expected} but got {jnp.ndim(x)}") return sum_rightmost(result, self.reinterpreted_batch_ndims) - def call_with_intermediates( - self, x: ArrayLike - ) -> Tuple[ArrayLike, Optional[ArrayLike]]: + def call_with_intermediates(self, x: NumLike) -> Tuple[NumLike, Optional[PyTree]]: return self.base_transform.call_with_intermediates(x) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: From 69c1ed5eeeb9ddc0841f723d7089cc51a0cd5a16 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 18 Oct 2025 12:22:24 -0400 Subject: [PATCH 10/16] using Generic type for the base transform with NumLike bound --- numpyro/_typing.py | 27 ++- numpyro/distributions/constraints.py | 126 +++++------ numpyro/distributions/transforms.py | 299 +++++++++++++++------------ 3 files changed, 252 insertions(+), 200 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 3be15dcb8..1a9695a83 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -5,6 +5,7 @@ from collections import OrderedDict from collections.abc import Callable from typing import Any, Optional, Protocol, Union, runtime_checkable +import weakref try: from typing import ParamSpec, TypeAlias @@ -37,8 +38,12 @@ @runtime_checkable class ConstraintT(Protocol): - is_discrete: bool = ... - event_dim: int = ... + """A protocol for typing constraints.""" + + @property + def is_discrete(self) -> bool: ... + @property + def event_dim(self) -> int: ... def __call__(self, x: ArrayLike) -> ArrayLike: ... def __repr__(self) -> str: ... @@ -101,9 +106,16 @@ def is_discrete(self) -> bool: ... @runtime_checkable class TransformT(Protocol): - domain: ConstraintT = ... - codomain: ConstraintT = ... - _inv: Optional["TransformT"] = ... + _inv: Optional[Union["TransformT", weakref.ref]] = ... + + @property + def domain(self) -> ConstraintT: ... + @property + def codomain(self) -> ConstraintT: ... + @property + def inv(self) -> "TransformT": ... + @property + def sign(self) -> NumLike: ... def __call__(self, x: NumLike) -> NumLike: ... def _inverse(self, y: NumLike) -> NumLike: ... @@ -118,8 +130,3 @@ def call_with_intermediates( ) -> tuple[NumLike, Optional[PyTree]]: ... def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... - - @property - def inv(self) -> "TransformT": ... - @property - def sign(self) -> NumLike: ... diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 317b56b2e..400ebb2f0 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -73,7 +73,7 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ConstraintT +from numpyro._typing import ConstraintT, NonScalarArray, NumLike class Constraint(object): @@ -153,7 +153,7 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: class _CorrCholesky(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tril = jnp.tril(x) lower_triangular = jnp.all( @@ -165,7 +165,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: unit_norm_row = jnp.all(jnp.abs(x_norm - 1) <= tol, axis=-1) return lower_triangular & positive_diagonal & unit_norm_row - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -174,7 +174,7 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: class _CorrMatrix(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) @@ -186,7 +186,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: ) return symmetric & positive & unit_variance - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -284,10 +284,10 @@ def is_dependent(constraint): class _GreaterThan(Constraint): - def __init__(self, lower_bound): + def __init__(self, lower_bound: NumLike) -> None: self.lower_bound = lower_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return x > self.lower_bound def __repr__(self) -> str: @@ -295,7 +295,7 @@ def __repr__(self) -> str: fmt_string += "(lower_bound={})".format(self.lower_bound) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.broadcast_to(self.lower_bound + 1, jax.numpy.shape(prototype)) def tree_flatten(self): @@ -304,17 +304,17 @@ def tree_flatten(self): def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _GreaterThan): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) + return self.lower_bound is other.lower_bound class _GreaterThanEq(_GreaterThan): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return x >= self.lower_bound def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _GreaterThanEq): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) + return self.lower_bound is other.lower_bound class _Positive(_SingletonConstraint, _GreaterThan): @@ -347,6 +347,10 @@ def __init__(self, base_constraint, reinterpreted_batch_ndims): self.reinterpreted_batch_ndims = reinterpreted_batch_ndims super().__init__() + @property + def is_discrete(self) -> bool: + return self.base_constraint.is_discrete + @property def event_dim(self) -> int: return self.base_constraint.event_dim + self.reinterpreted_batch_ndims @@ -405,10 +409,10 @@ def __init__(self) -> None: class _LessThan(Constraint): - def __init__(self, upper_bound: ArrayLike) -> None: + def __init__(self, upper_bound: NumLike) -> None: self.upper_bound = upper_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return x < self.upper_bound def __repr__(self) -> str: @@ -416,7 +420,7 @@ def __repr__(self) -> str: fmt_string += "(upper_bound={})".format(self.upper_bound) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.broadcast_to(self.upper_bound - 1, jax.numpy.shape(prototype)) def tree_flatten(self): @@ -425,27 +429,27 @@ def tree_flatten(self): def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _LessThan): return False - return jnp.array_equal(self.upper_bound, other.upper_bound) + return self.upper_bound is other.upper_bound class _LessThanEq(_LessThan): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return x <= self.upper_bound def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _LessThanEq): return False - return jnp.array_equal(self.upper_bound, other.upper_bound) + return self.upper_bound is other.upper_bound class _IntegerInterval(Constraint): is_discrete = True - def __init__(self, lower_bound: ArrayLike, upper_bound: ArrayLike) -> None: + def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: self.lower_bound = lower_bound self.upper_bound = upper_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return (x >= self.lower_bound) & (x <= self.upper_bound) & (x % 1 == 0) def __repr__(self) -> str: @@ -455,7 +459,7 @@ def __repr__(self) -> str: ) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype)) def tree_flatten(self): @@ -468,18 +472,18 @@ def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _IntegerInterval): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( - self.upper_bound, other.upper_bound + return (self.lower_bound is other.lower_bound) and ( + self.upper_bound is other.upper_bound ) class _IntegerGreaterThan(Constraint): is_discrete = True - def __init__(self, lower_bound: ArrayLike) -> None: + def __init__(self, lower_bound: NumLike) -> None: self.lower_bound = lower_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return (x % 1 == 0) & (x >= self.lower_bound) def __repr__(self) -> str: @@ -487,7 +491,7 @@ def __repr__(self) -> str: fmt_string += "(lower_bound={})".format(self.lower_bound) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype)) def tree_flatten(self): @@ -496,7 +500,7 @@ def tree_flatten(self): def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _IntegerGreaterThan): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) + return self.lower_bound is other.lower_bound class _IntegerPositive(_SingletonConstraint, _IntegerGreaterThan): @@ -510,11 +514,11 @@ def __init__(self) -> None: class _Interval(Constraint): - def __init__(self, lower_bound: ArrayLike, upper_bound: ArrayLike) -> None: + def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: self.lower_bound = lower_bound self.upper_bound = upper_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return (x >= self.lower_bound) & (x <= self.upper_bound) def __repr__(self) -> str: @@ -524,7 +528,7 @@ def __repr__(self) -> str: ) return fmt_string - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.broadcast_to( (self.lower_bound + self.upper_bound) / 2, jax.numpy.shape(prototype) ) @@ -532,8 +536,8 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _Interval): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( - self.upper_bound, other.upper_bound + return (self.lower_bound is other.lower_bound) and ( + self.upper_bound is other.upper_bound ) def tree_flatten(self): @@ -554,7 +558,7 @@ def __init__(self) -> None: class _OpenInterval(_Interval): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: return (x > self.lower_bound) & (x < self.upper_bound) def __repr__(self) -> str: @@ -568,7 +572,7 @@ def __repr__(self) -> str: class _LowerCholesky(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tril = jnp.tril(x) lower_triangular = jnp.all( @@ -577,7 +581,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1) return lower_triangular & positive_diagonal - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -590,10 +594,10 @@ class _Multinomial(Constraint): def __init__(self, upper_bound: ArrayLike) -> None: self.upper_bound = upper_bound - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: return (x >= 0).all(axis=-1) & (x.sum(axis=-1) == self.upper_bound) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: pad_width = ((0, 0),) * jax.numpy.ndim(self.upper_bound) + ( (0, prototype.shape[-1] - 1), ) @@ -606,7 +610,7 @@ def tree_flatten(self): def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _Multinomial): return False - return jnp.array_equal(self.upper_bound, other.upper_bound) + return self.upper_bound is other.upper_bound class _L1Ball(_SingletonConstraint): @@ -617,22 +621,22 @@ class _L1Ball(_SingletonConstraint): event_dim = 1 reltol = 10.0 # Relative to finfo.eps. - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy eps = jnp.finfo(x.dtype).eps return jnp.abs(x).sum(axis=-1) < 1 + self.reltol * eps - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) class _OrderedVector(_SingletonConstraint): event_dim = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: return (x[..., 1:] > x[..., :-1]).all(axis=-1) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.arange(float(prototype.shape[-1])), prototype.shape ) @@ -641,7 +645,7 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: class _PositiveDefinite(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) @@ -649,7 +653,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: positive = jnp.linalg.eigh(x)[0][..., 0] > 0 return symmetric & positive - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -658,20 +662,20 @@ def feasible_like(self, prototype: ArrayLike) -> ArrayLike: class _PositiveDefiniteCirculantVector(_SingletonConstraint): event_dim = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tol = 10 * jnp.finfo(x.dtype).eps rfft = jnp.fft.rfft(x) return (jnp.abs(rfft.imag) < tol) & (rfft.real > -tol) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jnp.zeros_like(prototype).at[..., 0].set(1.0) class _PositiveSemiDefinite(_SingletonConstraint): event_dim = 2 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) @@ -679,7 +683,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: nonnegative = jnp.linalg.eigh(x)[0][..., 0] >= 0 return symmetric & nonnegative - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]), prototype.shape ) @@ -693,41 +697,41 @@ class _PositiveOrderedVector(_SingletonConstraint): event_dim = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: return ordered_vector.check(x) & independent(positive, 1).check(x) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.exp(jax.numpy.arange(float(prototype.shape[-1]))), prototype.shape ) class _Complex(_SingletonConstraint): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: # XXX: consider to relax this condition to [-inf, inf] interval return (x == x) & (x != float("inf")) & (x != float("-inf")) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) class _Real(_SingletonConstraint): - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> ArrayLike: # XXX: consider to relax this condition to [-inf, inf] interval return (x == x) & (x != float("inf")) & (x != float("-inf")) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) class _Simplex(_SingletonConstraint): event_dim = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: x_sum = x.sum(axis=-1) return (x >= 0).all(axis=-1) & (x_sum < 1 + 1e-6) & (x_sum > 1 - 1e-6) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.full_like(prototype, 1 / prototype.shape[-1]) @@ -735,12 +739,12 @@ class _SoftplusPositive(_SingletonConstraint, _GreaterThan): def __init__(self) -> None: super().__init__(lower_bound=0.0) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.full(jax.numpy.shape(prototype), np.log(2)) class _SoftplusLowerCholesky(_LowerCholesky): - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( jax.numpy.eye(prototype.shape[-1]) * np.log(2), prototype.shape ) @@ -758,14 +762,14 @@ class _Sphere(_SingletonConstraint): event_dim = 1 reltol = 10.0 # Relative to finfo.eps. - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy eps = jnp.finfo(x.dtype).eps norm = jnp.linalg.norm(x, axis=-1) error = jnp.abs(norm - 1) return error < self.reltol * eps * x.shape[-1] ** 0.5 - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.full_like(prototype, prototype.shape[-1] ** (-0.5)) @@ -774,7 +778,7 @@ def __init__(self, event_dim: int = 1) -> None: self.event_dim = event_dim super().__init__() - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tol = jnp.finfo(x.dtype).eps * x.shape[-1] * 10 zerosum_true = True @@ -785,7 +789,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: def __eq__(self, other: ConstraintT) -> bool: return type(self) is type(other) and self.event_dim == other.event_dim - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.zeros_like(prototype) def tree_flatten(self): diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 6f187efd1..c2aabda21 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -3,7 +3,7 @@ import math -from typing import Any, Optional, Sequence, Tuple, Union, cast +from typing import Generic, Optional, Sequence, Tuple, TypeVar, Union, cast import warnings import weakref @@ -65,42 +65,52 @@ def _clipped_expit(x: NumLike) -> NumLike: return jnp.clip(expit(x), finfo.tiny, 1.0 - finfo.eps) -class Transform(object): - domain = constraints.real - codomain = constraints.real +NumLikeT = TypeVar("NumLikeT", bound=NumLike) + + +class Transform(Generic[NumLikeT]): _inv: Optional[Union[TransformT, weakref.ReferenceType]] = None + @property + def domain(self) -> ConstraintT: + if hasattr(self, "_domain"): + return self._domain + return constraints.real + + @property + def codomain(self) -> ConstraintT: + if hasattr(self, "_codomain"): + return self._codomain + return constraints.real + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) @property def inv(self: TransformT) -> TransformT: - inv = None - if (self._inv is not None) and isinstance(self._inv, weakref.ReferenceType): + if (self._inv is not None) and isinstance(self._inv, weakref.ref): inv = self._inv() - if inv is None: - inv = cast(TransformT, _InverseTransform(self)) - self._inv = cast(TransformT, weakref.ref(inv)) - return inv + else: + inv = _InverseTransform(self) + self._inv = weakref.ref(inv) + return cast(TransformT, inv) - def __call__(self, x: Union[NonScalarArray, Any]) -> Union[NonScalarArray, Any]: - raise NotImplementedError + def __call__(self, x: NumLikeT) -> NumLike: + raise NotImplementedError() - def _inverse(self, y: Union[NonScalarArray, Any]) -> Union[NonScalarArray, Any]: - raise NotImplementedError + def _inverse(self, y: NumLikeT) -> NumLike: + raise NotImplementedError() def log_abs_det_jacobian( self, - x: Union[NonScalarArray, Any], - y: Union[NonScalarArray, Any], + x: NumLikeT, + y: NumLikeT, intermediates: Optional[PyTree] = None, - ) -> Union[NonScalarArray, Any]: - raise NotImplementedError + ) -> NumLike: + raise NotImplementedError() - def call_with_intermediates( - self, x: Union[NonScalarArray, Any] - ) -> Tuple[Union[NonScalarArray, Any], Optional[PyTree]]: + def call_with_intermediates(self, x: NumLikeT) -> Tuple[NumLike, Optional[PyTree]]: return self(x), None def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -148,7 +158,7 @@ def tree_unflatten(cls, aux_data, params): return self -class ParameterFreeTransform(Transform): +class ParameterFreeTransform(Transform[NumLikeT]): def tree_flatten(self): return (), ((), dict()) @@ -156,17 +166,17 @@ def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) -class _InverseTransform(Transform): +class _InverseTransform(Transform[NumLike]): def __init__(self, transform: TransformT) -> None: super().__init__() self._inv: TransformT = transform @property - def domain(self) -> ConstraintT: # type: ignore[override] + def domain(self) -> ConstraintT: return self._inv.codomain @property - def codomain(self) -> ConstraintT: # type: ignore[override] + def codomain(self) -> ConstraintT: return self._inv.domain @property @@ -204,17 +214,17 @@ def __eq__(self, other: object) -> bool: return self._inv == other._inv -class AbsTransform(ParameterFreeTransform): +class AbsTransform(ParameterFreeTransform[NumLike]): domain = constraints.real codomain = constraints.positive def __eq__(self, other: object) -> bool: return isinstance(other, AbsTransform) - def __call__(self, x: NonScalarArray) -> NonScalarArray: + def __call__(self, x: NumLike) -> NumLike: return jnp.abs(x) - def _inverse(self, y: NonScalarArray) -> NonScalarArray: + def _inverse(self, y: NumLike) -> NumLike: warnings.warn( "AbsTransform is not a bijective transform." " The inverse of `y` will be `y`.", @@ -223,7 +233,7 @@ def _inverse(self, y: NonScalarArray) -> NonScalarArray: return y -class AffineTransform(Transform): +class AffineTransform(Transform[NumLike]): """ .. note:: When `scale` is a JAX tracer, we always assume that `scale > 0` when calculating `codomain`. @@ -231,16 +241,16 @@ class AffineTransform(Transform): def __init__( self, - loc: ArrayLike, - scale: ArrayLike, + loc: NumLike, + scale: NumLike, domain: ConstraintT = constraints.real, ): self.loc = loc self.scale = scale - self.domain = domain + self._domain = domain @property - def codomain(self) -> ConstraintT: # type: ignore[override] + def codomain(self) -> ConstraintT: if self.domain is constraints.real: return constraints.real elif isinstance(self.domain, constraints.greater_than): @@ -251,20 +261,20 @@ def codomain(self) -> ConstraintT: # type: ignore[override] return constraints.greater_than(self(self.domain.lower_bound)) elif isinstance(self.domain, constraints.less_than): if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)): - return constraints.greater_than(self(self.domain.upper_bound)) # type: ignore[arg-type] + return constraints.greater_than(self(self.domain.upper_bound)) # we suppose scale > 0 for any tracer else: - return constraints.less_than(self(self.domain.upper_bound)) # type: ignore[arg-type] + return constraints.less_than(self(self.domain.upper_bound)) elif isinstance(self.domain, constraints.interval): if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)): - return constraints.interval( # type: ignore[arg-type] - self(self.domain.upper_bound), # type: ignore[arg-type] - self(self.domain.lower_bound), # type: ignore[arg-type] + return constraints.interval( + self(self.domain.upper_bound), + self(self.domain.lower_bound), ) else: - return constraints.interval( # type: ignore[arg-type] - self(self.domain.lower_bound), # type: ignore[arg-type] - self(self.domain.upper_bound), # type: ignore[arg-type] + return constraints.interval( + self(self.domain.lower_bound), + self(self.domain.upper_bound), ) else: raise NotImplementedError @@ -274,7 +284,7 @@ def sign(self) -> NumLike: return jnp.sign(self.scale) def __call__(self, x: NumLike) -> NumLike: - return self.loc + self.scale * x + return self.loc + jnp.multiply(self.scale, x) def _inverse(self, y: NumLike) -> NumLike: return (y - self.loc) / self.scale @@ -298,15 +308,18 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ) def tree_flatten(self): - return (self.loc, self.scale, self.domain), (("loc", "scale", "domain"), dict()) + return (self.loc, self.scale, self.domain), ( + ("loc", "scale", "_domain"), + dict(), + ) def __eq__(self, other: object) -> bool: if not isinstance(other, AffineTransform): return False return ( - jnp.array_equal(self.loc, other.loc) # type: ignore[return-value] - & jnp.array_equal(self.scale, other.scale) # type: ignore[return-value] - & (self.domain == other.domain) # type: ignore[return-value] + (self.loc is other.loc) + and (self.scale is other.scale) + and (self.domain == other.domain) ) @@ -328,33 +341,39 @@ def _get_compose_transform_output_event_dim(parts): return output_event_dim -class ComposeTransform(Transform): +class ComposeTransform(Transform[NumLike]): def __init__(self, parts: Sequence[TransformT]) -> None: self.parts = parts @property - def domain(self) -> ConstraintT: # type: ignore[override] + def domain(self) -> ConstraintT: input_event_dim = _get_compose_transform_input_event_dim(self.parts) first_input_event_dim = self.parts[0].domain.event_dim assert input_event_dim >= first_input_event_dim if input_event_dim == first_input_event_dim: return self.parts[0].domain else: - return constraints.independent( # type: ignore[return-value] - self.parts[0].domain, input_event_dim - first_input_event_dim + return cast( + ConstraintT, + constraints.independent( + self.parts[0].domain, input_event_dim - first_input_event_dim + ), ) @property - def codomain(self) -> ConstraintT: # type: ignore[override] + def codomain(self) -> ConstraintT: output_event_dim = _get_compose_transform_output_event_dim(self.parts) last_output_event_dim = self.parts[-1].codomain.event_dim assert output_event_dim >= last_output_event_dim if output_event_dim == last_output_event_dim: return self.parts[-1].codomain else: - return constraints.independent( - self.parts[-1].codomain, output_event_dim - last_output_event_dim - ) # type: ignore[return-value] + return cast( + ConstraintT, + constraints.independent( + self.parts[-1].codomain, output_event_dim - last_output_event_dim + ), + ) @property def sign(self) -> NumLike: @@ -430,7 +449,7 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, ComposeTransform): return False - return jnp.logical_and(*(p1 == p2 for p1, p2 in zip(self.parts, other.parts))) # type: ignore[return-value] + return all(p1 == p2 for p1, p2 in zip(self.parts, other.parts)) def _matrix_forward_shape(shape: tuple[int, ...], offset: int = 0) -> tuple[int, ...]: @@ -456,7 +475,7 @@ def _matrix_inverse_shape(shape: tuple[int, ...], offset: int = 0) -> tuple[int, return shape[:-2] + (N,) -class CholeskyTransform(ParameterFreeTransform): +class CholeskyTransform(ParameterFreeTransform[NonScalarArray]): r""" Transform via the mapping :math:`y = cholesky(x)`, where `x` is a positive definite matrix. @@ -485,7 +504,7 @@ def log_abs_det_jacobian( ) -class CorrCholeskyTransform(ParameterFreeTransform): +class CorrCholeskyTransform(ParameterFreeTransform[NonScalarArray]): r""" Transforms a unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower @@ -582,16 +601,16 @@ def log_abs_det_jacobian( return jnp.sum(order * jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1)), axis=-1) -class ExpTransform(Transform): +class ExpTransform(Transform[NumLike]): sign = 1 # TODO: refine domain/codomain logic through setters, especially when # transforms for inverses are supported def __init__(self, domain: ConstraintT = constraints.real): - self.domain = domain + self._domain = domain @property - def codomain(self) -> ConstraintT: # type: ignore[override] + def codomain(self) -> ConstraintT: if self.domain is constraints.ordered_vector: return constraints.positive_ordered_vector elif self.domain is constraints.real: @@ -599,9 +618,9 @@ def codomain(self) -> ConstraintT: # type: ignore[override] elif isinstance(self.domain, constraints.greater_than): return constraints.greater_than(self.__call__(self.domain.lower_bound)) elif isinstance(self.domain, constraints.interval): - return constraints.interval( # type: ignore[arg-type] - self.__call__(self.domain.lower_bound), # type: ignore[arg-type] - self.__call__(self.domain.upper_bound), # type: ignore[arg-type] + return constraints.interval( + self.__call__(self.domain.lower_bound), + self.__call__(self.domain.upper_bound), ) else: raise NotImplementedError @@ -630,25 +649,25 @@ def __eq__(self, other: object) -> bool: return self.domain == other.domain -class IdentityTransform(ParameterFreeTransform): +class IdentityTransform(ParameterFreeTransform[NumLike]): sign = 1 - def __call__(self, x: NonScalarArray) -> NonScalarArray: + def __call__(self, x: NumLike) -> NumLike: return x - def _inverse(self, y: NonScalarArray) -> NonScalarArray: + def _inverse(self, y: NumLike) -> NumLike: return y def log_abs_det_jacobian( self, - x: NonScalarArray, - y: NonScalarArray, + x: NumLike, + y: NumLike, intermediates: Optional[PyTree] = None, ) -> NumLike: return jnp.zeros_like(x) -class IndependentTransform(Transform): +class IndependentTransform(Transform[NumLike]): """ Wraps a transform by aggregating over ``reinterpreted_batch_ndims``-many dims in :meth:`check`, so that an event is valid only if all its @@ -656,7 +675,7 @@ class IndependentTransform(Transform): """ def __init__( - self, base_transform: Transform, reinterpreted_batch_ndims: int + self, base_transform: TransformT, reinterpreted_batch_ndims: int ) -> None: assert isinstance(base_transform, Transform) assert isinstance(reinterpreted_batch_ndims, int) @@ -666,27 +685,33 @@ def __init__( super().__init__() @property - def domain(self) -> ConstraintT: # type: ignore[override] - return constraints.independent( - self.base_transform.domain, self.reinterpreted_batch_ndims - ) # type: ignore[return-value] + def domain(self) -> ConstraintT: + return cast( + ConstraintT, + constraints.independent( + self.base_transform.domain, self.reinterpreted_batch_ndims + ), + ) @property - def codomain(self) -> ConstraintT: # type: ignore[override] - return constraints.independent( - self.base_transform.codomain, self.reinterpreted_batch_ndims - ) # type: ignore[return-value] + def codomain(self) -> ConstraintT: + return cast( + ConstraintT, + constraints.independent( + self.base_transform.codomain, self.reinterpreted_batch_ndims + ), + ) - def __call__(self, x: NonScalarArray) -> NonScalarArray: + def __call__(self, x: NumLike) -> NumLike: return self.base_transform(x) - def _inverse(self, y: NonScalarArray) -> NonScalarArray: + def _inverse(self, y: NumLike) -> NumLike: return self.base_transform._inverse(y) def log_abs_det_jacobian( self, - x: NonScalarArray, - y: NonScalarArray, + x: NumLike, + y: NumLike, intermediates: Optional[PyTree] = None, ) -> NumLike: result = self.base_transform.log_abs_det_jacobian( @@ -717,10 +742,10 @@ def __eq__(self, other: object) -> bool: return False return (self.base_transform == other.base_transform) & ( self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims - ) # type: ignore[return-value] + ) -class L1BallTransform(ParameterFreeTransform): +class L1BallTransform(ParameterFreeTransform[NonScalarArray]): r""" Transforms a unconstrained real vector :math:`x` into the unit L1 ball. """ @@ -773,7 +798,7 @@ def log_abs_det_jacobian( return stick_breaking_logdet + tanh_logdet -class LowerCholeskyAffine(Transform): +class LowerCholeskyAffine(Transform[NonScalarArray]): r""" Transform via the mapping :math:`y = loc + scale\_tril\ @\ x`. @@ -846,12 +871,10 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, LowerCholeskyAffine): return False - return jnp.array_equal(self.loc, other.loc) & jnp.array_equal( - self.scale_tril, other.scale_tril - ) # type: ignore[return-value] + return (self.loc is other.loc) and (self.scale_tril is other.scale_tril) -class LowerCholeskyTransform(ParameterFreeTransform): +class LowerCholeskyTransform(ParameterFreeTransform[NonScalarArray]): """ Transform a real vector to a lower triangular cholesky factor, where the strictly lower triangular submatrix is @@ -911,7 +934,7 @@ def __call__(self, x: NonScalarArray) -> NonScalarArray: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = softplus(x[..., -n:]) - return add_diag(z, 1) * diag[..., None] # type: ignore[arg-type] + return add_diag(z, jnp.array(1)) * diag[..., None] def _inverse(self, y: NonScalarArray) -> NonScalarArray: diag = jnp.diagonal(y, axis1=-2, axis2=-1) @@ -930,7 +953,7 @@ def log_abs_det_jacobian( return (jnp.log(diag_softplus) * jnp.arange(n) - softplus(-diag)).sum(-1) -class OrderedTransform(ParameterFreeTransform): +class OrderedTransform(ParameterFreeTransform[NonScalarArray]): """ Transform a real vector to an ordered vector. @@ -971,7 +994,7 @@ def log_abs_det_jacobian( return jnp.sum(x[..., 1:], -1) -class PermuteTransform(Transform): +class PermuteTransform(Transform[NonScalarArray]): domain = constraints.real_vector codomain = constraints.real_vector @@ -1004,14 +1027,14 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, PermuteTransform): return False - return jnp.array_equal(self.permutation, other.permutation) # type: ignore[return-value] + return self.permutation is other.permutation -class PowerTransform(Transform): +class PowerTransform(Transform[NumLike]): domain = constraints.positive codomain = constraints.positive - def __init__(self, exponent: ArrayLike) -> None: + def __init__(self, exponent: NumLike) -> None: self.exponent = exponent def __call__(self, x: NumLike) -> NumLike: @@ -1026,7 +1049,7 @@ def log_abs_det_jacobian( y: NumLike, intermediates: Optional[PyTree] = None, ) -> NumLike: - return jnp.log(jnp.abs(self.exponent * y / x)) + return jnp.log(jnp.abs(jnp.multiply(self.exponent, y) / x)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return lax.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) @@ -1040,14 +1063,14 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, PowerTransform): return False - return jnp.array_equal(self.exponent, other.exponent) # type: ignore[return-value] + return self.exponent is other.exponent @property def sign(self) -> NumLike: return jnp.sign(self.exponent) -class SigmoidTransform(ParameterFreeTransform): +class SigmoidTransform(ParameterFreeTransform[NumLike]): codomain = constraints.unit_interval sign = 1 @@ -1063,10 +1086,10 @@ def log_abs_det_jacobian( y: NumLike, intermediates: Optional[PyTree] = None, ) -> NumLike: - return -softplus(x) - softplus(-x) # type: ignore[operator] + return -softplus(x) - softplus(-x) -class SimplexToOrderedTransform(Transform): +class SimplexToOrderedTransform(Transform[NonScalarArray]): """ Transform a simplex into an ordered vector (via difference in Logistic CDF between cutpoints) Used in [1] to induce a prior on latent cutpoints via transforming ordered category probabilities. @@ -1130,7 +1153,7 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, SimplexToOrderedTransform): return False - return jnp.array_equal(self.anchor_point, other.anchor_point) # type: ignore[return-value] + return self.anchor_point is other.anchor_point def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (shape[-1] - 1,) @@ -1139,11 +1162,11 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (shape[-1] + 1,) -def _softplus_inv(y: ArrayLike) -> NumLike: - return jnp.log(-jnp.expm1(-y)) + y # type: ignore[operator] +def _softplus_inv(y: NumLike) -> NumLike: + return jnp.log(-jnp.expm1(-y)) + y -class SoftplusTransform(ParameterFreeTransform): +class SoftplusTransform(ParameterFreeTransform[NumLike]): r""" Transform from unconstrained space to positive domain via softplus :math:`y = \log(1 + \exp(x))`. The inverse is computed as :math:`x = \log(\exp(y) - 1)`. @@ -1165,10 +1188,10 @@ def log_abs_det_jacobian( y: NumLike, intermediates: Optional[PyTree] = None, ) -> NumLike: - return -softplus(-x) # type: ignore[operator] + return -softplus(-x) -class SoftplusLowerCholeskyTransform(ParameterFreeTransform): +class SoftplusLowerCholeskyTransform(ParameterFreeTransform[NonScalarArray]): """ Transform from unconstrained vector to lower-triangular matrices with nonnegative diagonal entries. This is useful for parameterizing positive @@ -1207,7 +1230,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return _matrix_inverse_shape(shape) -class StickBreakingTransform(ParameterFreeTransform): +class StickBreakingTransform(ParameterFreeTransform[NonScalarArray]): domain = constraints.real_vector codomain = constraints.simplex @@ -1257,7 +1280,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (shape[-1] - 1,) -class UnpackTransform(Transform): +class UnpackTransform(Transform[NonScalarArray]): """ Transforms a contiguous array to a pytree of subarrays. @@ -1337,7 +1360,7 @@ def _get_target_shape( return shape[:batch_ndims] + forward_shape -class ReshapeTransform(Transform): +class ReshapeTransform(Transform[NonScalarArray]): """ Reshape a sample, leaving batch dimensions unchanged. @@ -1361,12 +1384,18 @@ def __init__( self._inverse_shape = inverse_shape @property - def domain(self) -> ConstraintT: # type: ignore[override] - return constraints.independent(constraints.real, len(self._inverse_shape)) # type: ignore[return-value] + def domain(self) -> ConstraintT: + return cast( + ConstraintT, + constraints.independent(constraints.real, len(self._inverse_shape)), + ) @property - def codomain(self) -> ConstraintT: # type: ignore[override] - return constraints.independent(constraints.real, len(self._forward_shape)) # type: ignore[return-value] + def codomain(self) -> ConstraintT: + return cast( + ConstraintT, + constraints.independent(constraints.real, len(self._forward_shape)), + ) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return _get_target_shape(shape, self._forward_shape, self._inverse_shape) @@ -1412,7 +1441,7 @@ def _normalize_rfft_shape( return input_shape[: len(input_shape) - len(shape)] + shape -class RealFastFourierTransform(Transform): +class RealFastFourierTransform(Transform[NonScalarArray]): """ N-dimensional discrete fast Fourier transform for real input. @@ -1479,12 +1508,17 @@ def tree_flatten(self): return (), ((), aux_data) @property - def domain(self) -> ConstraintT: # type: ignore[override] - return constraints.independent(constraints.real, self.transform_ndims) # type: ignore[return-value] + def domain(self) -> ConstraintT: + return cast( + ConstraintT, constraints.independent(constraints.real, self.transform_ndims) + ) @property - def codomain(self) -> ConstraintT: # type: ignore[override] - return constraints.independent(constraints.complex, self.transform_ndims) # type: ignore[return-value] + def codomain(self) -> ConstraintT: + return cast( + ConstraintT, + constraints.independent(constraints.complex, self.transform_ndims), + ) def __eq__(self, other: object) -> bool: return ( @@ -1494,7 +1528,7 @@ def __eq__(self, other: object) -> bool: ) -class PackRealFastFourierCoefficientsTransform(Transform): +class PackRealFastFourierCoefficientsTransform(Transform[NonScalarArray]): """ Transform a real vector to complex coefficients of a real fast Fourier transform. @@ -1502,13 +1536,13 @@ class PackRealFastFourierCoefficientsTransform(Transform): """ domain = constraints.real_vector - codomain = constraints.independent(constraints.complex, 1) # type: ignore[assignment] + codomain = cast(ConstraintT, constraints.independent(constraints.complex, 1)) def __init__(self, transform_shape: Optional[tuple[int, ...]] = None) -> None: assert transform_shape is None or len(transform_shape) == 1, ( "Packing Fourier coefficients is only implemented for vectors." ) - self.shape: tuple[int, ...] = transform_shape # type: ignore[assignment] + self.shape: Optional[tuple[int, ...]] = transform_shape def tree_flatten(self): return (), ((), {"shape": self.shape}) @@ -1553,6 +1587,9 @@ def __call__(self, x: NonScalarArray) -> NonScalarArray: ) def _inverse(self, y: NonScalarArray) -> NonScalarArray: + assert self.shape is not None, ( + "Shape must be specified in `__init__` for inverse transform." + ) (n,) = self.shape n_real = n // 2 + 1 n_imag = n - n_real @@ -1565,7 +1602,7 @@ def __eq__(self, other) -> bool: ) -class RecursiveLinearTransform(Transform): +class RecursiveLinearTransform(Transform[NonScalarArray]): """ Apply a linear transformation recursively such that :math:`y_t = A y_{t - 1} + x_t` for :math:`t > 0`, where :math:`x_t` and :math:`y_t` @@ -1695,10 +1732,12 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, RecursiveLinearTransform): return False - return jnp.array_equal(self.transition_matrix, other.transition_matrix) # type: ignore[return-value] + return (self.transition_matrix is other.transition_matrix) and ( + self.initial_value is other.initial_value + ) -class ZeroSumTransform(Transform): +class ZeroSumTransform(Transform[NonScalarArray]): """A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3] :param transform_ndims: Number of trailing dimensions to transform. @@ -1713,12 +1752,14 @@ def __init__(self, transform_ndims: int = 1) -> None: self.transform_ndims = transform_ndims @property - def domain(self) -> ConstraintT: # type: ignore[override] - return constraints.independent(constraints.real, self.transform_ndims) # type: ignore[return-value] + def domain(self) -> ConstraintT: + return cast( + ConstraintT, constraints.independent(constraints.real, self.transform_ndims) + ) @property - def codomain(self) -> ConstraintT: # type: ignore[override] - return constraints.zero_sum(self.transform_ndims) # type: ignore[return-value] + def codomain(self) -> ConstraintT: + return cast(ConstraintT, constraints.zero_sum(self.transform_ndims)) def __call__(self, x: NonScalarArray) -> NonScalarArray: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) @@ -1785,7 +1826,7 @@ def __eq__(self, other: object) -> bool: ) -class ComplexTransform(ParameterFreeTransform): +class ComplexTransform(ParameterFreeTransform[NonScalarArray]): """ Transforms a pair of real numbers to a complex number. """ From 37d821108982210d4261d191700a1bab855449ba Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 18 Oct 2025 12:46:51 -0400 Subject: [PATCH 11/16] formatting transform methods --- numpyro/_typing.py | 5 +- numpyro/distributions/transforms.py | 80 +++++++++-------------------- 2 files changed, 24 insertions(+), 61 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 1a9695a83..10cc1c9e8 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -120,10 +120,7 @@ def sign(self) -> NumLike: ... def __call__(self, x: NumLike) -> NumLike: ... def _inverse(self, y: NumLike) -> NumLike: ... def log_abs_det_jacobian( - self, - x: NumLike, - y: NumLike, - intermediates: Optional[PyTree] = None, + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None ) -> NumLike: ... def call_with_intermediates( self, x: NumLike diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index c2aabda21..acd09859e 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -73,14 +73,10 @@ class Transform(Generic[NumLikeT]): @property def domain(self) -> ConstraintT: - if hasattr(self, "_domain"): - return self._domain return constraints.real @property def codomain(self) -> ConstraintT: - if hasattr(self, "_codomain"): - return self._codomain return constraints.real def __init_subclass__(cls, **kwargs): @@ -103,10 +99,7 @@ def _inverse(self, y: NumLikeT) -> NumLike: raise NotImplementedError() def log_abs_det_jacobian( - self, - x: NumLikeT, - y: NumLikeT, - intermediates: Optional[PyTree] = None, + self, x: NumLikeT, y: NumLikeT, intermediates: Optional[PyTree] = None ) -> NumLike: raise NotImplementedError() @@ -191,10 +184,7 @@ def __call__(self, x: NumLike) -> NumLike: return self._inv._inverse(x) def log_abs_det_jacobian( - self, - x: NumLike, - y: NumLike, - intermediates: Optional[PyTree] = None, + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None ) -> NumLike: # NB: we don't use intermediates for inverse transform return -self._inv.log_abs_det_jacobian(y, x, None) @@ -240,15 +230,16 @@ class AffineTransform(Transform[NumLike]): """ def __init__( - self, - loc: NumLike, - scale: NumLike, - domain: ConstraintT = constraints.real, + self, loc: NumLike, scale: NumLike, domain: ConstraintT = constraints.real ): self.loc = loc self.scale = scale self._domain = domain + @property + def domain(self) -> ConstraintT: + return self._domain + @property def codomain(self) -> ConstraintT: if self.domain is constraints.real: @@ -268,13 +259,11 @@ def codomain(self) -> ConstraintT: elif isinstance(self.domain, constraints.interval): if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)): return constraints.interval( - self(self.domain.upper_bound), - self(self.domain.lower_bound), + self(self.domain.upper_bound), self(self.domain.lower_bound) ) else: return constraints.interval( - self(self.domain.lower_bound), - self(self.domain.upper_bound), + self(self.domain.lower_bound), self(self.domain.upper_bound) ) else: raise NotImplementedError @@ -290,10 +279,7 @@ def _inverse(self, y: NumLike) -> NumLike: return (y - self.loc) / self.scale def log_abs_det_jacobian( - self, - x: NumLike, - y: NumLike, - intermediates: Optional[PyTree] = None, + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None ) -> NumLike: return jnp.broadcast_to(jnp.log(jnp.abs(self.scale)), jnp.shape(x)) @@ -393,10 +379,7 @@ def _inverse(self, y: NumLike) -> NumLike: return y def log_abs_det_jacobian( - self, - x: NumLike, - y: NumLike, - intermediates: Optional[PyTree] = None, + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None ) -> NumLike: if intermediates is not None: if len(intermediates) != len(self.parts): @@ -609,6 +592,10 @@ class ExpTransform(Transform[NumLike]): def __init__(self, domain: ConstraintT = constraints.real): self._domain = domain + @property + def domain(self) -> ConstraintT: + return self._domain + @property def codomain(self) -> ConstraintT: if self.domain is constraints.ordered_vector: @@ -633,15 +620,12 @@ def _inverse(self, y: NumLike) -> NumLike: return jnp.log(y) def log_abs_det_jacobian( - self, - x: NumLike, - y: NumLike, - intermediates: Optional[PyTree] = None, + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None ) -> NumLike: return x def tree_flatten(self): - return (self.domain,), (("domain",), dict()) + return (self.domain,), (("_domain",), dict()) def __eq__(self, other: object) -> bool: if not isinstance(other, ExpTransform): @@ -659,10 +643,7 @@ def _inverse(self, y: NumLike) -> NumLike: return y def log_abs_det_jacobian( - self, - x: NumLike, - y: NumLike, - intermediates: Optional[PyTree] = None, + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None ) -> NumLike: return jnp.zeros_like(x) @@ -709,10 +690,7 @@ def _inverse(self, y: NumLike) -> NumLike: return self.base_transform._inverse(y) def log_abs_det_jacobian( - self, - x: NumLike, - y: NumLike, - intermediates: Optional[PyTree] = None, + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None ) -> NumLike: result = self.base_transform.log_abs_det_jacobian( x, y, intermediates=intermediates @@ -1044,10 +1022,7 @@ def _inverse(self, y: NumLike) -> NumLike: return jnp.power(y, 1 / self.exponent) def log_abs_det_jacobian( - self, - x: NumLike, - y: NumLike, - intermediates: Optional[PyTree] = None, + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None ) -> NumLike: return jnp.log(jnp.abs(jnp.multiply(self.exponent, y) / x)) @@ -1081,10 +1056,7 @@ def _inverse(self, y: NumLike) -> NumLike: return logit(y) def log_abs_det_jacobian( - self, - x: NumLike, - y: NumLike, - intermediates: Optional[PyTree] = None, + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None ) -> NumLike: return -softplus(x) - softplus(-x) @@ -1183,10 +1155,7 @@ def _inverse(self, y: NumLike) -> NumLike: return _softplus_inv(y) def log_abs_det_jacobian( - self, - x: NumLike, - y: NumLike, - intermediates: Optional[PyTree] = None, + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None ) -> NumLike: return -softplus(-x) @@ -1842,10 +1811,7 @@ def _inverse(self, y: ArrayLike) -> Array: return jnp.stack([y.real, y.imag], axis=-1) def log_abs_det_jacobian( - self, - x: NumLike, - y: NumLike, - intermediates: Optional[PyTree] = None, + self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None ) -> Array: return jnp.zeros_like(y) From e1942a0eb468a9d64e48649c2c4932af88213672 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 18 Oct 2025 12:50:12 -0400 Subject: [PATCH 12/16] fix typing of ComplexTransform --- numpyro/distributions/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index acd09859e..b70c91bbc 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1807,12 +1807,12 @@ def __call__(self, x: NonScalarArray) -> NonScalarArray: assert x.shape[-1] == 2, "Input must have a trailing dimension of size 2." return lax.complex(x[..., 0], x[..., 1]) - def _inverse(self, y: ArrayLike) -> Array: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: return jnp.stack([y.real, y.imag], axis=-1) def log_abs_det_jacobian( - self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None - ) -> Array: + self, x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None + ) -> NonScalarArray: return jnp.zeros_like(y) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: From e43d6317b95e76a957e1686f57d4266c07fff8f6 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 18 Oct 2025 12:58:58 -0400 Subject: [PATCH 13/16] format transforms module again --- numpyro/distributions/transforms.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index b70c91bbc..159e36316 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1811,7 +1811,10 @@ def _inverse(self, y: NonScalarArray) -> NonScalarArray: return jnp.stack([y.real, y.imag], axis=-1) def log_abs_det_jacobian( - self, x: NonScalarArray, y: NonScalarArray, intermediates: Optional[PyTree] = None + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, ) -> NonScalarArray: return jnp.zeros_like(y) From 34770a3359a941792a4f279d419e7618bd94dffe Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 18 Oct 2025 22:47:22 -0400 Subject: [PATCH 14/16] fix for transform.inv might get value None --- numpyro/distributions/transforms.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 159e36316..b87ecec07 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -69,7 +69,7 @@ def _clipped_expit(x: NumLike) -> NumLike: class Transform(Generic[NumLikeT]): - _inv: Optional[Union[TransformT, weakref.ReferenceType]] = None + _inv: Optional[Union[TransformT, weakref.ref]] = None @property def domain(self) -> ConstraintT: @@ -85,9 +85,10 @@ def __init_subclass__(cls, **kwargs): @property def inv(self: TransformT) -> TransformT: + inv = None if (self._inv is not None) and isinstance(self._inv, weakref.ref): inv = self._inv() - else: + if inv is None: inv = _InverseTransform(self) self._inv = weakref.ref(inv) return cast(TransformT, inv) @@ -160,7 +161,7 @@ def __eq__(self, other: object) -> bool: class _InverseTransform(Transform[NumLike]): - def __init__(self, transform: TransformT) -> None: + def __init__(self, transform: TransformT): super().__init__() self._inv: TransformT = transform From cd1ab43f37fd0872b34d12abb54b5f27241b7a05 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 19 Oct 2025 19:29:54 -0400 Subject: [PATCH 15/16] revert to use jnp.array_equal to compare the domains --- numpyro/distributions/batch_util.py | 2 +- numpyro/distributions/constraints.py | 20 ++++++++++---------- numpyro/distributions/transforms.py | 26 ++++++++++++++------------ test/test_distributions.py | 25 +++++++++++++++++-------- 4 files changed, 42 insertions(+), 31 deletions(-) diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index de0ad382e..b25d050d5 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -80,7 +80,7 @@ def _vmap_over_affine_transform( dist_axes = copy.copy(dist) dist_axes.loc = loc dist_axes.scale = scale - dist_axes.domain = domain + dist_axes._domain = domain return dist_axes diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 400ebb2f0..417b5f260 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -304,7 +304,7 @@ def tree_flatten(self): def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _GreaterThan): return False - return self.lower_bound is other.lower_bound + return jnp.array_equal(self.lower_bound, other.lower_bound) class _GreaterThanEq(_GreaterThan): @@ -314,7 +314,7 @@ def __call__(self, x: NumLike) -> ArrayLike: def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _GreaterThanEq): return False - return self.lower_bound is other.lower_bound + return jnp.array_equal(self.lower_bound, other.lower_bound) class _Positive(_SingletonConstraint, _GreaterThan): @@ -429,7 +429,7 @@ def tree_flatten(self): def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _LessThan): return False - return self.upper_bound is other.upper_bound + return jnp.array_equal(self.upper_bound, other.upper_bound) class _LessThanEq(_LessThan): @@ -439,7 +439,7 @@ def __call__(self, x: NumLike) -> ArrayLike: def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _LessThanEq): return False - return self.upper_bound is other.upper_bound + return jnp.array_equal(self.upper_bound, other.upper_bound) class _IntegerInterval(Constraint): @@ -472,8 +472,8 @@ def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _IntegerInterval): return False - return (self.lower_bound is other.lower_bound) and ( - self.upper_bound is other.upper_bound + return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( + self.upper_bound, other.upper_bound ) @@ -500,7 +500,7 @@ def tree_flatten(self): def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _IntegerGreaterThan): return False - return self.lower_bound is other.lower_bound + return jnp.array_equal(self.lower_bound, other.lower_bound) class _IntegerPositive(_SingletonConstraint, _IntegerGreaterThan): @@ -536,8 +536,8 @@ def feasible_like(self, prototype: NumLike) -> NumLike: def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _Interval): return False - return (self.lower_bound is other.lower_bound) and ( - self.upper_bound is other.upper_bound + return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( + self.upper_bound, other.upper_bound ) def tree_flatten(self): @@ -610,7 +610,7 @@ def tree_flatten(self): def __eq__(self, other: ConstraintT) -> bool: if not isinstance(other, _Multinomial): return False - return self.upper_bound is other.upper_bound + return jnp.array_equal(self.upper_bound, other.upper_bound) class _L1Ball(_SingletonConstraint): diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index b87ecec07..467da9f3c 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -304,10 +304,10 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, AffineTransform): return False return ( - (self.loc is other.loc) - and (self.scale is other.scale) - and (self.domain == other.domain) - ) + jnp.array_equal(self.loc, other.loc) + & jnp.array_equal(self.scale, other.scale) + & (self.domain == other.domain) + ) # type: ignore[return-value] def _get_compose_transform_input_event_dim(parts): @@ -433,7 +433,7 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, ComposeTransform): return False - return all(p1 == p2 for p1, p2 in zip(self.parts, other.parts)) + return jnp.logical_and(*(p1 == p2 for p1, p2 in zip(self.parts, other.parts))) # type: ignore[return-value] def _matrix_forward_shape(shape: tuple[int, ...], offset: int = 0) -> tuple[int, ...]: @@ -850,7 +850,9 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, LowerCholeskyAffine): return False - return (self.loc is other.loc) and (self.scale_tril is other.scale_tril) + return jnp.array_equal(self.loc, other.loc) & jnp.array_equal( + self.scale_tril, other.scale_tril + ) # type: ignore[return-value] class LowerCholeskyTransform(ParameterFreeTransform[NonScalarArray]): @@ -1006,7 +1008,7 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, PermuteTransform): return False - return self.permutation is other.permutation + return jnp.array_equal(self.permutation, other.permutation) # type: ignore[return-value] class PowerTransform(Transform[NumLike]): @@ -1039,7 +1041,7 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, PowerTransform): return False - return self.exponent is other.exponent + return jnp.array_equal(self.exponent, other.exponent) # type: ignore[return-value] @property def sign(self) -> NumLike: @@ -1126,7 +1128,7 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, SimplexToOrderedTransform): return False - return self.anchor_point is other.anchor_point + return jnp.array_equal(self.anchor_point, other.anchor_point) # type: ignore[return-value] def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (shape[-1] - 1,) @@ -1702,9 +1704,9 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, RecursiveLinearTransform): return False - return (self.transition_matrix is other.transition_matrix) and ( - self.initial_value is other.initial_value - ) + return jnp.array_equal( + self.transition_matrix, other.transition_matrix + ) & jnp.array_equal(self.initial_value, other.initial_value) # type: ignore[return-value] class ZeroSumTransform(Transform[NonScalarArray]): diff --git a/test/test_distributions.py b/test/test_distributions.py index fa5c31f78..3eeb09498 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -556,14 +556,23 @@ def get_sp_dist(jax_dist): np.array([[0.8, 0.2], [-0.1, 1.1]]), np.array([0.1, 0.3, 0.25])[:, None, None] * np.array([[0.8, 0.2], [0.2, 0.7]]), ), - T( - dist.GaussianCopulaBeta, - np.array([7.0, 2.0]), - np.array([4.0, 10.0]), - np.array([[1.0, 0.75], [0.75, 1.0]]), + pytest.param( + *T( + dist.GaussianCopulaBeta, + np.array([7.0, 2.0]), + np.array([4.0, 10.0]), + np.array([[1.0, 0.75], [0.75, 1.0]]), + ), + marks=pytest.mark.xfail(reason="Beta copula does not work with jax 0.7.0"), + ), + pytest.param( + *T(dist.GaussianCopulaBeta, 2.0, 1.5, np.eye(3)), + marks=pytest.mark.xfail(reason="Beta copula does not work with jax 0.7.0"), + ), + pytest.param( + *T(dist.GaussianCopulaBeta, 2.0, 1.5, np.full((5, 3, 3), np.eye(3))), + marks=pytest.mark.xfail(reason="Beta copula does not work with jax 0.7.0"), ), - T(dist.GaussianCopulaBeta, 2.0, 1.5, np.eye(3)), - T(dist.GaussianCopulaBeta, 2.0, 1.5, np.full((5, 3, 3), np.eye(3))), T(dist.Gompertz, np.array([1.7]), np.array([[2.0], [3.0]])), T(dist.Gompertz, np.array([0.5, 1.3]), np.array([[1.0], [3.0]])), T(dist.Gumbel, 0.0, 1.0), @@ -2630,7 +2639,7 @@ def test_composed_transform(batch_shape): expected_log_det = ( jnp.log(2) * 6 + t2.log_abs_det_jacobian(x * 2, y / 2) + jnp.log(2) * 9 ) - assert_allclose(log_det, expected_log_det) + assert_allclose(log_det, expected_log_det, rtol=1e-6) @pytest.mark.parametrize("batch_shape", [(), (5,)]) From 1d6b24df2a340efa6118910043a821a7dbdafcfb Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 19 Oct 2025 19:33:27 -0400 Subject: [PATCH 16/16] fix lint --- numpyro/distributions/transforms.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 467da9f3c..4512225c8 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -303,11 +303,12 @@ def tree_flatten(self): def __eq__(self, other: object) -> bool: if not isinstance(other, AffineTransform): return False - return ( + is_equal = ( jnp.array_equal(self.loc, other.loc) & jnp.array_equal(self.scale, other.scale) & (self.domain == other.domain) - ) # type: ignore[return-value] + ) + return is_equal # type: ignore[return-value] def _get_compose_transform_input_event_dim(parts):