From 0867ae32a66b3c63c85a03a14b25e917bf903ce8 Mon Sep 17 00:00:00 2001 From: fmuham Date: Wed, 11 Aug 2021 16:09:07 -0700 Subject: [PATCH] Fix Cache Key Mismatch for Function Input Signature - A TensorSpec with no name is now equivalent to a Tensor with the same dtype and shape - Fixed the shape mismatch issue for Tensors: (11,2) was equal to (1, 12) PiperOrigin-RevId: 390241221 --- tensorflow_probability/python/distributions/poisson_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow_probability/python/distributions/poisson_test.py b/tensorflow_probability/python/distributions/poisson_test.py index ff3a7e4459..a9b9413ee6 100644 --- a/tensorflow_probability/python/distributions/poisson_test.py +++ b/tensorflow_probability/python/distributions/poisson_test.py @@ -450,6 +450,12 @@ def testSampleGPU(self): def testSampleXLA(self): self.skip_if_no_xla() if not tf.executing_eagerly(): return # jit_compile is eager-only. + + # TODO(b/195975508): Reloading the function to reset the cache. + if not test_util.JAX_MODE: + poisson_lib.random_poisson = tf.function( + poisson_lib.random_poisson._python_function) + log_rates = np.random.rand(4, 3).astype(np.float32) dist = tfd.Poisson(log_rate=log_rates, validate_args=True) # Verify the compile succeeds going all the way through the distribution.