Skip to content

Commit

Permalink
Fix Cache Key Mismatch for Function Input Signature
Browse files Browse the repository at this point in the history
- A TensorSpec with no name is now equivalent to a Tensor with the same dtype and shape
- Fixed the shape mismatch issue for Tensors: (11,2) was equal to (1, 12)

PiperOrigin-RevId: 390241221
  • Loading branch information
faizan-m authored and tensorflower-gardener committed Aug 11, 2021
1 parent 956d09f commit 0867ae3
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tensorflow_probability/python/distributions/poisson_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,12 @@ def testSampleGPU(self):
def testSampleXLA(self):
self.skip_if_no_xla()
if not tf.executing_eagerly(): return # jit_compile is eager-only.

# TODO(b/195975508): Reloading the function to reset the cache.
if not test_util.JAX_MODE:
poisson_lib.random_poisson = tf.function(
poisson_lib.random_poisson._python_function)

log_rates = np.random.rand(4, 3).astype(np.float32)
dist = tfd.Poisson(log_rate=log_rates, validate_args=True)
# Verify the compile succeeds going all the way through the distribution.
Expand Down

0 comments on commit 0867ae3

Please sign in to comment.