From 300bfe564df5099a48a38931c641425b1b79580f Mon Sep 17 00:00:00 2001 From: emilyaf Date: Mon, 20 Nov 2023 15:39:30 -0800 Subject: [PATCH] Fix dtype in parallel Kalman filter likelihood. PiperOrigin-RevId: 584146012 --- .../parallel_filter/parallel_kalman_filter_lib.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py b/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py index 838558fc0f..b375741071 100644 --- a/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py +++ b/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py @@ -649,8 +649,8 @@ def _mvn_log_prob(mean, covariance, y): log_prob = log_prob - 0.5 * linalg.hpsd_logdet( covariance, cholesky_matrix=cholesky_matrix) event_dims = ps.shape(mean)[-1] - return log_prob - 0.5 * event_dims * dtype_util.as_numpy_dtype( - mean.dtype)(np.log(2 * np.pi)) + return log_prob - dtype_util.as_numpy_dtype(mean.dtype)( + 0.5 * event_dims * np.log(2 * np.pi)) def _extract_batch_shape(x, sample_ndims, event_ndims):