From 806394f5749d41ef416dc108676d35404d9afd65 Mon Sep 17 00:00:00 2001 From: Srinivas Vasudevan Date: Fri, 3 Nov 2023 08:02:37 -0700 Subject: [PATCH] Improve backprop performance through experimental kalman filter, by changing out MVN log_prob calculation. PiperOrigin-RevId: 579184615 --- .../parallel_kalman_filter_lib.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py b/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py index ae40d42794..838558fc0f 100644 --- a/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py +++ b/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py @@ -21,6 +21,7 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.distributions import mvn_tril +from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.math import linalg @@ -625,11 +626,8 @@ def kalman_filter(transition_matrix, axis=0), added_cov=time_dep.observation_cov) - # TODO(srvasude): The JVP for this can be implemented more efficiently. - log_likelihoods = mvn_tril.MultivariateNormalTriL( - loc=observation_means, - scale_tril=tf.linalg.cholesky(observation_covs)).log_prob( - observation.y) + log_likelihoods = _mvn_log_prob( + observation_means, observation_covs, observation.y) if observation.mask is not None: log_likelihoods = tf.where(observation.mask, tf.zeros([], dtype=log_likelihoods.dtype), @@ -644,6 +642,17 @@ def kalman_filter(transition_matrix, observation_covs) +def _mvn_log_prob(mean, covariance, y): + cholesky_matrix = tf.linalg.cholesky(covariance) + log_prob = -0.5 * linalg.hpsd_quadratic_form_solvevec( + covariance, y - mean, cholesky_matrix=cholesky_matrix) + log_prob = log_prob - 0.5 * linalg.hpsd_logdet( + covariance, cholesky_matrix=cholesky_matrix) + event_dims = ps.shape(mean)[-1] + return log_prob - 0.5 * event_dims * dtype_util.as_numpy_dtype( + mean.dtype)(np.log(2 * np.pi)) + + def _extract_batch_shape(x, sample_ndims, event_ndims): """Slice out the batch component of `x`'s shape.""" if x is None: