Skip to content

Commit

Permalink
Increase test tolerance for upcoming JAX PRNG change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 580023819
  • Loading branch information
jburnim authored and tensorflower-gardener committed Nov 7, 2023
1 parent fb642b0 commit d31cd56
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -769,8 +769,7 @@ def testMVN(self, event_shape, shift, tril, dynamic_shape):
num_samples = 7e3
y = fake_mvn.sample(int(num_samples), seed=test_util.test_seed())
x = y[0:5, ...]
self.assertAllMeansClose(y, expected_mean, axis=0,
atol=0.1, rtol=0.1)
self.assertAllMeansClose(y, expected_mean, axis=0, atol=0.25)
self.assertAllClose(expected_cov, sample_stats.covariance(y, sample_axis=0),
atol=0., rtol=0.1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def get_abs_sample_mean(skewness):

err = self.compute_max_gradient_error(
get_abs_sample_mean, [tf.constant(skewness, self.dtype)], delta=1e-1)
maxerr = 0.05 if self.dtype == np.float64 else 0.09
maxerr = 0.2
self.assertLess(err, maxerr)

@test_util.numpy_disable_gradient_test
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_probability/python/math/special_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,8 @@ def _test_betaincinv_value(self, a_high, b_high, dtype, atol, rtol):
"rtol": 2e-3},
{"testcase_name": "float64",
"dtype": np.float64,
"atol": 1e-12,
"rtol": 1e-11})
"atol": 3e-12,
"rtol": 3e-11})
def testBetaincinvSmall(self, dtype, atol, rtol):
self._test_betaincinv_value(
a_high=1., b_high=1., dtype=dtype, atol=atol, rtol=rtol)
Expand Down

0 comments on commit d31cd56

Please sign in to comment.