From 50475294912eccc95fc4b79a3742e3786a90a835 Mon Sep 17 00:00:00 2001 From: colcarroll Date: Tue, 5 Nov 2024 14:19:36 -0800 Subject: [PATCH] Disable test_gradient_with_additional_parameters for JAX backend. PiperOrigin-RevId: 693475540 --- .../distributions/batch_broadcast_test.py | 25 ++++++++++++++++--- ...lar_function_with_inferred_inverse_test.py | 2 ++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/tensorflow_probability/python/distributions/batch_broadcast_test.py b/tensorflow_probability/python/distributions/batch_broadcast_test.py index 7f0ef2fe45..9bc8541133 100644 --- a/tensorflow_probability/python/distributions/batch_broadcast_test.py +++ b/tensorflow_probability/python/distributions/batch_broadcast_test.py @@ -40,12 +40,23 @@ from tensorflow_probability.python.internal import test_util from tensorflow_probability.python.random import random_ops +_DIFFERENT_HYPOTHESIS_KWARGS = {} + +# This check is done on recent versions of hypothesis, but not all, +# as of November 2024. +if hasattr(hp.HealthCheck, 'differing_executors'): + _DIFFERENT_HYPOTHESIS_KWARGS['suppress_health_check'] = [ + hp.HealthCheck.differing_executors + ] + @test_util.test_all_tf_execution_regimes class _BatchBroadcastTest(object): @hp.given(hps.data()) - @tfp_hps.tfp_hp_settings(default_max_examples=5) + @tfp_hps.tfp_hp_settings( + default_max_examples=5, + **_DIFFERENT_HYPOTHESIS_KWARGS) def test_shapes(self, data): batch_shape = data.draw(tfp_hps.shapes()) bcast_arg, dist_batch_shp = data.draw( @@ -63,7 +74,9 @@ def test_shapes(self, data): dist.event_shape_tensor()) @hp.given(hps.data()) - @tfp_hps.tfp_hp_settings(default_max_examples=5) + @tfp_hps.tfp_hp_settings( + default_max_examples=5, + **_DIFFERENT_HYPOTHESIS_KWARGS) def test_sample(self, data): batch_shape = data.draw(tfp_hps.shapes()) bcast_arg, dist_batch_shp = data.draw( @@ -109,7 +122,9 @@ def test_sample(self, data): self.assertAllClose(lp, dist.log_prob(sample2)) @hp.given(hps.data()) - @tfp_hps.tfp_hp_settings(default_max_examples=5) + @tfp_hps.tfp_hp_settings( + default_max_examples=5, + **_DIFFERENT_HYPOTHESIS_KWARGS) def test_log_prob(self, data): batch_shape = data.draw(tfp_hps.shapes()) bcast_arg, dist_batch_shp = data.draw( @@ -235,7 +250,9 @@ def test_docstring_example(self): self.evaluate(lp) @hp.given(hps.data()) - @tfp_hps.tfp_hp_settings(default_max_examples=5) + @tfp_hps.tfp_hp_settings( + default_max_examples=5, + **_DIFFERENT_HYPOTHESIS_KWARGS) def test_default_bijector(self, data): batch_shape = data.draw(tfp_hps.shapes()) bcast_arg, dist_batch_shp = data.draw( diff --git a/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py b/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py index ecec2b671a..0f9d27c34e 100644 --- a/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py +++ b/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py @@ -102,6 +102,8 @@ def ildj_fn(y): self.assertAllClose(ildj, ildj_true, atol=1e-4) self.assertAllClose(ildj_grad, ildj_grad_true, rtol=1e-4) + @test_util.disable_test_for_backend( + disable_jax=True, reason='Tracer leak from additional parameters.') @test_util.numpy_disable_gradient_test @parameterized.named_parameters( {