diff --git a/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py b/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py
index ae40d42794..838558fc0f 100644
--- a/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py
+++ b/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py
@@ -21,6 +21,7 @@ import tensorflow.compat.v2 as tf
 
 from tensorflow_probability.python.distributions import mvn_tril
+from tensorflow_probability.python.internal import dtype_util
 from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import samplers
 from tensorflow_probability.python.math import linalg
@@ -625,11 +626,8 @@ def kalman_filter(transition_matrix,
           axis=0),
       added_cov=time_dep.observation_cov)
 
-  # TODO(srvasude): The JVP for this can be implemented more efficiently.
-  log_likelihoods = mvn_tril.MultivariateNormalTriL(
-      loc=observation_means,
-      scale_tril=tf.linalg.cholesky(observation_covs)).log_prob(
-          observation.y)
+  log_likelihoods = _mvn_log_prob(
+      observation_means, observation_covs, observation.y)
   if observation.mask is not None:
     log_likelihoods = tf.where(observation.mask,
                                tf.zeros([], dtype=log_likelihoods.dtype),
@@ -644,6 +642,17 @@ def kalman_filter(transition_matrix,
           observation_covs)
 
 
+def _mvn_log_prob(mean, covariance, y):
+  cholesky_matrix = tf.linalg.cholesky(covariance)
+  log_prob = -0.5 * linalg.hpsd_quadratic_form_solvevec(
+      covariance, y - mean, cholesky_matrix=cholesky_matrix)
+  log_prob = log_prob - 0.5 * linalg.hpsd_logdet(
+      covariance, cholesky_matrix=cholesky_matrix)
+  event_dims = ps.shape(mean)[-1]
+  return log_prob - 0.5 * event_dims * dtype_util.as_numpy_dtype(
+      mean.dtype)(np.log(2 * np.pi))
+
+
 def _extract_batch_shape(x, sample_ndims, event_ndims):
   """Slice out the batch component of `x`'s shape."""
   if x is None:
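
Note (not part of the patch): the new `_mvn_log_prob` helper evaluates the multivariate normal log-density log N(y; m, S) = -0.5 * [(y - m)^T S^{-1} (y - m) + log|S| + d * log(2*pi)] directly via `hpsd_quadratic_form_solvevec` and `hpsd_logdet`, which accept a precomputed Cholesky factor and define their own derivative rules, replacing the `MultivariateNormalTriL.log_prob` path flagged by the removed TODO about a more efficient JVP. Below is a minimal self-contained sketch of the same computation against the public `tfp.math` API, useful for checking numerical agreement with `MultivariateNormalTriL`; the helper name `mvn_log_prob_check` and the test data are illustrative assumptions, not part of the commit.

# Sketch only: checks the algebra behind `_mvn_log_prob` using the public
# `tfp.math` re-exports of the internal `linalg.hpsd_*` helpers.
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

def mvn_log_prob_check(mean, covariance, y):
  # Same structure as `_mvn_log_prob`: one Cholesky factorization is shared
  # by the quadratic form and the log-determinant.
  chol = tf.linalg.cholesky(covariance)
  quad = tfp.math.hpsd_quadratic_form_solvevec(
      covariance, y - mean, cholesky_matrix=chol)
  logdet = tfp.math.hpsd_logdet(covariance, cholesky_matrix=chol)
  dims = tf.cast(tf.shape(mean)[-1], mean.dtype)
  return -0.5 * (quad + logdet + dims * np.log(2 * np.pi))

dim = 3
rng = np.random.RandomState(seed=0)
root = rng.randn(dim, dim)
cov = root @ root.T + dim * np.eye(dim)  # Well-conditioned SPD matrix.
mean, y = rng.randn(dim), rng.randn(dim)

expected = tfp.distributions.MultivariateNormalTriL(
    loc=mean, scale_tril=np.linalg.cholesky(cov)).log_prob(y)
actual = mvn_log_prob_check(tf.constant(mean), tf.constant(cov),
                            tf.constant(y))
np.testing.assert_allclose(expected.numpy(), actual.numpy(), rtol=1e-10)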