From c1728d3e5e323daac732a264a6b0a5063fa1425f Mon Sep 17 00:00:00 2001 From: Christopher Suter Date: Mon, 16 Oct 2023 20:39:25 -0700 Subject: [PATCH] Avoid deprecated casting of size-1 np.ndarrays. This used to be allowed but is now deprecated. Some logic that lies downstream of many of our distributions' log_prob methods would invoke this behavior (in a try/except, so it would not fail even post-deprecation, but we get an annoyting warning all the time). This change avoids that deprecated behavior. PiperOrigin-RevId: 574008432 --- .../distributions/internal/statistical_testing.py | 3 ++- .../python/internal/backend/numpy/numpy_math.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tensorflow_probability/python/distributions/internal/statistical_testing.py b/tensorflow_probability/python/distributions/internal/statistical_testing.py index 75fe286711..2cf7189a3f 100644 --- a/tensorflow_probability/python/distributions/internal/statistical_testing.py +++ b/tensorflow_probability/python/distributions/internal/statistical_testing.py @@ -127,6 +127,7 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import dtype_util +from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.util.seed_stream import SeedStream @@ -1494,7 +1495,7 @@ def _random_unit_hypersphere(sample_shape, event_shape, dtype, seed): target_shape = tf.concat([sample_shape, event_shape], axis=0) return tf.math.l2_normalize( tf.random.normal(target_shape, seed=seed, dtype=dtype), - axis=-1 - tf.range(tf.size(event_shape))) + axis=-1 - ps.range(ps.size(event_shape))) def assert_multivariate_true_cdf_equal_on_projections_two_sample( diff --git a/tensorflow_probability/python/internal/backend/numpy/numpy_math.py b/tensorflow_probability/python/internal/backend/numpy/numpy_math.py index 40b3a2526b..30b5640c9e 100644 --- a/tensorflow_probability/python/internal/backend/numpy/numpy_math.py +++ b/tensorflow_probability/python/internal/backend/numpy/numpy_math.py @@ -165,10 +165,15 @@ def _astuple(x): """Attempt to convert the given argument to be a Python tuple.""" - try: - return (int(x),) - except TypeError: - pass + # Numpy used to allow casting a size-1 ndarray to python scalar literal types. + # In version 1.25 this was deprecated, causing a warning to be issued in the + # below try/except. To avoid that, we just fall through in the case of an + # np.ndarray. + if not isinstance(x, np.ndarray): + try: + return (int(x),) + except TypeError: + pass try: return tuple(x)