Skip to content

Commit

Permalink
Fix bug in windowed_mcmc related to constrained distributions having …
Browse files Browse the repository at this point in the history
…different event_shapes from unconstrained.

PiperOrigin-RevId: 390202596
  • Loading branch information
ColCarroll authored and tensorflower-gardener committed Aug 11, 2021
1 parent b71b6af commit 956d09f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def _get_flat_unconstraining_bijector(jd_model):
event_space_bij = jd_model.experimental_default_event_space_bijector()
flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())

unconstrained_shapes = flat_bijector.inverse_event_shape_tensor(
jd_model.event_shape_tensor())
unconstrained_shapes = event_space_bij(
flat_bijector).inverse_event_shape_tensor(jd_model.event_shape_tensor())

# this reshaping is required as as split can produce a tensor of shape [1]
# when the distribution event shape is []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,12 @@ def mk_y(x):
self.assertEqual((2, 64, 10, 3), states['x'].shape)
self.assertEqual((2, 10, 1), trace['step_size'].shape)

def test_bijector(self):
dist = tfd.JointDistributionSequential([tfd.Dirichlet(tf.ones(2))])
bij, _ = windowed_sampling._get_flat_unconstraining_bijector(dist)
draw = dist.sample(seed=test_util.test_seed())
self.assertAllCloseNested(bij.inverse(bij(draw)), draw)


@test_util.test_graph_and_eager_modes
class WindowedSamplingStepSizeTest(test_util.TestCase):
Expand Down

0 comments on commit 956d09f

Please sign in to comment.