diff --git a/tensorflow_probability/python/distributions/linear_gaussian_ssm.py b/tensorflow_probability/python/distributions/linear_gaussian_ssm.py index 0938fbd242..facb6e6b34 100644 --- a/tensorflow_probability/python/distributions/linear_gaussian_ssm.py +++ b/tensorflow_probability/python/distributions/linear_gaussian_ssm.py @@ -61,7 +61,7 @@ def _safe_concat(values): for x in values: try: full_values.append(ps.reshape(x, reference_shape)) - except (TypeError, ValueError): + except (TypeError, ValueError, ZeroDivisionError): # JAX/numpy don't like `-1`'s in size-zero shapes. full_values.append(ps.reshape(x, trivial_shape)) return ps.concat(full_values, axis=0)