Skip to content

Commit

Permalink
Fix tf.where-induced nan grads in NormalInverseGaussian
Browse files Browse the repository at this point in the history
Fixes #1778

PiperOrigin-RevId: 596612661
  • Loading branch information
csuter authored and tensorflower-gardener committed Jan 8, 2024
1 parent 526c975 commit 07c9704
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,35 @@


def _log1px2(x):
  """Safely compute log(1 + x ** 2).

  For small x, use log1p(x ** 2). For large x (x >> 1), use 2 * log(|x|),
  since adding 1 to x ** 2 no longer matters in float. Also avoid nan grads
  through `tf.where` by using the double-where trick for x ~= 0.

  Args:
    x: float `Tensor`.

  Returns:
    y: log(1 + x ** 2), same shape and dtype as `x`.
  """
  # The idea is to use 2 * log(|x|) when x ** 2 >> 1, so that adding 1
  # doesn't matter. This happens when |x| >> 1 / sqrt(eps). But the large-x
  # branch causes grad problems for zero input:
  #
  # If x is zero, log(1 + x**2) = log(1) = 0, but 2 * log(|x|) = 2 * log(0)
  # = -inf, and even though that branch is not selected, tf.where still
  # propagates nan through its gradient. So for ~0 input we substitute a safe
  # value (1) into both branches and select the trivially-correct answer
  # (|x|, which is 0 to within float for |x| < tiny) via an outer where.
  # (See, e.g., https://github.com/google/jax/issues/1052)
  finfo = np.finfo(dtype_util.as_numpy_dtype(x.dtype))
  is_basically_zero = tf.abs(x) < finfo.tiny
  # safe_x == 1 where x ~= 0, so neither inner branch produces inf/nan there.
  safe_x = tf.where(is_basically_zero, tf.ones_like(x), x)
  return tf.where(
      is_basically_zero,
      # log1p(x**2) ~= x**2 ~= 0 ~= |x| for |x| < tiny.
      tf.abs(x),
      tf.where(
          tf.abs(x) * np.sqrt(finfo.eps) <= 1.,
          tf.math.log1p(safe_x**2.),
          2 * tf.math.log(tf.math.abs(safe_x))))


class NormalInverseGaussian(distribution.AutoCompositeTensorDistribution):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.distributions import normal_inverse_gaussian as nig
from tensorflow_probability.python.internal import test_util
from tensorflow_probability.python.math import gradient


@test_util.test_all_tf_execution_regimes
Expand Down Expand Up @@ -218,6 +219,17 @@ def testModifiedVariableAssertion(self):
with tf.control_dependencies([skewness.assign(-2.)]):
self.evaluate(normal_inverse_gaussian.mean())

@test_util.numpy_disable_gradient_test
def testDoubleWhere(self):
  """Gradient w.r.t. loc must be finite when evaluating log_prob at loc."""
  point = 0.

  # Evaluating log_prob exactly at the location parameter historically
  # produced nan gradients via tf.where (see #1778); assert they are finite.
  def log_prob_at_point(loc_value):
    dist = nig.NormalInverseGaussian(
        loc=loc_value, scale=2., tailweight=1., skewness=2.)
    return dist.log_prob(point)

  _, grad = gradient.value_and_gradient(log_prob_at_point, point)
  self.assertAllNotNan(grad)


class NormalInverseGaussianTestFloat32(
test_util.TestCase, _NormalInverseGaussianTest):
Expand Down

0 comments on commit 07c9704

Please sign in to comment.