Skip to content

Commit

Permalink
Increase test tolerances for upcoming JAX PRNG change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 579940356
  • Loading branch information
jburnim authored and tensorflower-gardener committed Nov 6, 2023
1 parent be84b38 commit fb642b0
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def testSampleLarge(self):
self.assertAllClose(true_mean, sample_mean_, atol=0., rtol=0.03)
self.assertAllClose(true_mean, analytical_mean_, atol=0., rtol=1e-6)

self.assertAllClose(true_covariance, sample_covariance_, atol=0., rtol=0.03)
self.assertAllClose(true_covariance, sample_covariance_, atol=0., rtol=0.04)
self.assertAllClose(
true_covariance, analytical_covariance_, atol=0., rtol=1e-6)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def model():
self.strategy_run(
run, (self.key,), in_axes=None))
for i in range(test_lib.NUM_DEVICES):
self.assertAllClose(sharded_log_prob[i], true_log_prob, atol=2e-2)
self.assertAllClose(sharded_log_prob[i], true_log_prob, atol=0.025)
self.assertAllClose(sharded_log_prob_grad[i], true_log_prob_grad,
atol=2e-2)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def testMatrixValuesAreCorrect(
amplitudes, length_scale, dim, x, y, method='matrix')

self.assertAllClose(
self.evaluate(actual), self.evaluate(expected), rtol=1e-5)
self.evaluate(actual), self.evaluate(expected), rtol=3e-5)

@test_util.disable_test_for_backend(
disable_numpy=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def test_posterior_on_nonzero_subset_matches_bayesian_regression(
self.assertAllClose(
nonzero_subvector(self.evaluate(
initial_state.conditional_weights_mean)),
restricted_weights_posterior_mean)
restricted_weights_posterior_mean,
atol=5e-5)
self.assertAllClose(
nonzero_submatrix(initial_state.conditional_posterior_precision_chol),
tf.linalg.cholesky(restricted_weights_posterior_prec.to_dense()))
Expand Down Expand Up @@ -346,7 +347,7 @@ def loop_body(var_weights_seed, _):
tf.float32)
self.assertAllClose(nonzero_prior_prob,
tf.reduce_mean(nonzero_weight_samples),
atol=0.03)
atol=0.04)

@parameterized.named_parameters(('', False), ('_xla', True))
def test_deterministic_given_seed(self, use_xla):
Expand Down

0 comments on commit fb642b0

Please sign in to comment.