Skip to content

Commit

Permalink
Avoid deprecated casting of size-1 np.ndarrays.
Browse files Browse the repository at this point in the history
This used to be allowed but is now deprecated. Some logic that lies downstream
of many of our distributions' log_prob methods would invoke this behavior (in a
try/except, so it would not fail even post-deprecation, but we get an annoyting
warning all the time). This change avoids that deprecated behavior.

PiperOrigin-RevId: 574008432
  • Loading branch information
csuter authored and tensorflower-gardener committed Oct 17, 2023
1 parent e6907a1 commit c1728d3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.util.seed_stream import SeedStream

Expand Down Expand Up @@ -1494,7 +1495,7 @@ def _random_unit_hypersphere(sample_shape, event_shape, dtype, seed):
target_shape = tf.concat([sample_shape, event_shape], axis=0)
return tf.math.l2_normalize(
tf.random.normal(target_shape, seed=seed, dtype=dtype),
axis=-1 - tf.range(tf.size(event_shape)))
axis=-1 - ps.range(ps.size(event_shape)))


def assert_multivariate_true_cdf_equal_on_projections_two_sample(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,15 @@

def _astuple(x):
"""Attempt to convert the given argument to be a Python tuple."""
try:
return (int(x),)
except TypeError:
pass
# Numpy used to allow casting a size-1 ndarray to python scalar literal types.
# In version 1.25 this was deprecated, causing a warning to be issued in the
# below try/except. To avoid that, we just fall through in the case of an
# np.ndarray.
if not isinstance(x, np.ndarray):
try:
return (int(x),)
except TypeError:
pass

try:
return tuple(x)
Expand Down

0 comments on commit c1728d3

Please sign in to comment.