From d61f15c0d05a7a86ca464248e5c6f8814b56e457 Mon Sep 17 00:00:00 2001 From: Subekshya Bidari <37636707+sbidari@users.noreply.github.com> Date: Fri, 9 Aug 2024 18:42:07 -0400 Subject: [PATCH] avoids setting jax tracer as lazy property attribute (#1843) * remove tracer as attribute of truncated dist * streamline test * fix CI test run failure * Update test name Co-authored-by: Dylan H. Morris * Move test from test_distributions.py to test_distributions_util.py * Direct tests of tracer leaks * pre-commit changes --------- Co-authored-by: Dylan H. Morris Co-authored-by: Dylan H. Morris --- numpyro/distributions/util.py | 5 +++- test/test_distributions_util.py | 44 +++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index c83efb701..7b0e26325 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -14,6 +14,8 @@ from jax.scipy.linalg import solve_triangular from jax.scipy.special import digamma +from numpyro.util import not_jax_tracer + # Parameters for Transformed Rejection with Squeeze (TRS) algorithm - page 3. _tr_params = namedtuple( "tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"] @@ -692,7 +694,8 @@ def __get__(self, instance, obj_type=None): if instance is None: return self value = self.wrapped(instance) - setattr(instance, self.wrapped.__name__, value) + if not_jax_tracer(value): + setattr(instance, self.wrapped.__name__, value) return value diff --git a/test/test_distributions_util.py b/test/test_distributions_util.py index ab15966f8..84af13fca 100644 --- a/test/test_distributions_util.py +++ b/test/test_distributions_util.py @@ -8,10 +8,12 @@ import pytest import scipy +import jax from jax import lax, random, vmap import jax.numpy as jnp from jax.scipy.special import expit, xlog1py, xlogy +import numpyro.distributions as dist from numpyro.distributions.util import ( add_diag, binary_cross_entropy_with_logits, @@ -182,3 +184,45 @@ def test_add_diag(matrix_shape: tuple, diag_shape: tuple) -> None: expected = matrix + diag[..., None] * jnp.eye(matrix.shape[-1]) actual = add_diag(matrix, diag) np.testing.assert_allclose(actual, expected) + + +@pytest.mark.parametrize( + "my_dist", + [ + dist.TruncatedNormal(low=-1.0, high=2.0), + dist.TruncatedCauchy(low=-5, high=10), + dist.TruncatedDistribution(dist.StudentT(3), low=1.5), + ], +) +def test_no_tracer_leak_at_lazy_property_log_prob(my_dist): + """ + Tests that truncated distributions, which use @lazy_property + values in their log_prob() methods, do not + have tracer leakage when log_prob() is called. + Reference: https://github.com/pyro-ppl/numpyro/issues/1836, and + https://github.com/CDCgov/multisignal-epi-inference/issues/282 + """ + jit_lp = jax.jit(my_dist.log_prob) + with jax.check_tracer_leaks(): + jit_lp(1.0) + + +@pytest.mark.parametrize( + "my_dist", + [ + dist.TruncatedNormal(low=-1.0, high=2.0), + dist.TruncatedCauchy(low=-5, high=10), + dist.TruncatedDistribution(dist.StudentT(3), low=1.5), + ], +) +def test_no_tracer_leak_at_lazy_property_sample(my_dist): + """ + Tests that truncated distributions, which use @lazy_property + values in their sample() methods, do not + have tracer leakage when sample() is called. + Reference: https://github.com/pyro-ppl/numpyro/issues/1836, and + https://github.com/CDCgov/multisignal-epi-inference/issues/282 + """ + jit_sample = jax.jit(my_dist.sample) + with jax.check_tracer_leaks(): + jit_sample(jax.random.key(5))