Add event_ndims to Transform class to fix shape errors in non-autobatching mode #321
Closes #297
Description of the problem:

When not using auto-batching, the `event_shape` of the distributions and potentials must match. The reason is that when the log-likelihood is calculated, it is not possible to reduce the sum over all dimensions, as the `batch_shape` has to be preserved. This works well for distributions; transformations, however, add a potential as `log_det_jacobian` with too high a dimensionality. `SamplingState.collect_unreduced_log_prob` assumes that each potential and distribution has only `batch_dim` as its shape, but the `backward_log_det_jacobian` also includes the `event_dim`.
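To make the mismatch concrete, here is a small NumPy sketch (the names and shapes are illustrative, not pymc4's actual code): a per-element log-det-Jacobian term still carries the event axes, so its shape does not match log-probs that only have `batch_shape`.

```python
import numpy as np

# Illustrative shapes: a distribution with batch_shape=(3,) and
# event_shape=(4,).
batch_shape, event_shape = (3,), (4,)

# A per-element log_det_jacobian term has shape batch_shape + event_shape.
per_element_ldj = np.ones(batch_shape + event_shape)

# A distribution's log-prob, already reduced over its event axes, has
# batch_shape only.
log_prob = np.zeros(batch_shape)

# The potential still carries the event dimension that
# collect_unreduced_log_prob does not expect.
print(per_element_ldj.shape)  # (3, 4)
print(log_prob.shape)         # (3,)
```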
Solution

An `event_ndims` attribute is added to the `Transform` class that stores the number of event dimensions of the distribution it belongs to. This number is passed to the `log_det_jacobian` function implemented in tensorflow_probability, which then reduces over these axes. At the end of the initialization of a pymc4 distribution, the `event_ndims` attribute of the transform is set by reading out the `event_shape` attribute of the just-initialized `tfp.distributions.Distribution`.

There is some additional code to take care of the case in which the transformation changes the number of dimensions. This change of dimensions is defined in the `inverse_event_ndims` method of the `tfp.bijectors.Bijector`, which is used to calculate the number of dimensions of the transformed variable.
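A minimal sketch of this logic (the class and method names below are illustrative stand-ins, not the actual pymc4 API): the distribution's `event_shape` fixes `event_ndims`, the log-det-Jacobian is reduced over that many trailing axes, and a dimension-changing bijector maps the count through `inverse_event_ndims`.

```python
import numpy as np


class Transform:
    """Illustrative stand-in for pymc4's Transform class."""

    def __init__(self):
        self.event_ndims = None  # set after the tfp distribution is built

    def set_event_ndims(self, event_shape, bijector=None):
        # Number of event dimensions of the owning distribution.
        ndims = len(event_shape)
        # If the bijector changes dimensionality, map the count through
        # its inverse_event_ndims (mirroring tfp.bijectors.Bijector).
        if bijector is not None:
            ndims = bijector.inverse_event_ndims(ndims)
        self.event_ndims = ndims

    def reduced_log_det_jacobian(self, per_element_ldj):
        # Reduce over the trailing event axes so only batch_shape remains,
        # as tfp's log_det_jacobian does when given event_ndims.
        axes = tuple(range(-self.event_ndims, 0))
        return per_element_ldj.sum(axis=axes)


t = Transform()
t.set_event_ndims(event_shape=(4,))  # e.g. event_shape read from the tfp dist
ldj = t.reduced_log_det_jacobian(np.ones((3, 4)))
print(ldj.shape)  # (3,) -- matches batch_shape, as the executor expects
```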
Comparison to tensorflow_probability

In tensorflow_probability's `TransformedTransitionKernel`, a similar problem has already been solved. Their solution is to calculate the rank of the untransformed likelihood function, which is equal to the length of the `batch_shape`, and subtract it from the rank of the input to each bijector to get `event_ndims` (link). This is, however, done at the sampling stage. One could theoretically implement something similar in PyMC4's SamplingExecutor, but I think it is cleaner to do it already at the initialization stage, because the logic for the event_shapes of the distributions is also implemented there.
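The rank-based recovery used there can be sketched as follows (a hedged illustration of the idea, not tfp's actual code):

```python
import numpy as np

# Shapes are illustrative. The untransformed likelihood has rank equal to
# len(batch_shape); subtracting it from the rank of a bijector's input
# recovers event_ndims at sampling time.
likelihood_batch_rank = 1        # rank of the untransformed likelihood
x = np.zeros((3, 4, 5))          # input passed to a bijector
event_ndims = x.ndim - likelihood_batch_rank
print(event_ndims)  # 2
```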
Implementation details

`event_ndims` is read out from the `tfp.Distribution` with `prefer_static`, which is copied from link.

Other Changes