diff --git a/tensorflow_probability/python/internal/backend/numpy/numpy_math.py b/tensorflow_probability/python/internal/backend/numpy/numpy_math.py index 30b5640c9e..2037d40e7e 100644 --- a/tensorflow_probability/python/internal/backend/numpy/numpy_math.py +++ b/tensorflow_probability/python/internal/backend/numpy/numpy_math.py @@ -17,6 +17,7 @@ import collections import functools import numpy as np +import numpy as onp # Disable JAX rewrite. # pylint: disable=reimported from tensorflow_probability.python.internal.backend.numpy import _utils as utils from tensorflow_probability.python.internal.backend.numpy.numpy_array import _reverse @@ -169,7 +170,7 @@ def _astuple(x): # In version 1.25 this was deprecated, causing a warning to be issued in the # below try/except. To avoid that, we just fall through in the case of an # np.ndarray. - if not isinstance(x, np.ndarray): + if not isinstance(x, onp.ndarray): try: return (int(x),) except TypeError: