Skip to content

Commit

Permalink
Allow independent_joint_distribution_from_structure to support (mor…
Browse files Browse the repository at this point in the history
…e) nested objects.

If there were no items at a level of nesting which were distributions, `get_traverse_shallow` would terminate there, which would end up calling `independent_joint_distribution_from_structure` on the same structure, leading to a recursion error. This change guarantees that that in that case, the top layer at least will get mapped up to.

PiperOrigin-RevId: 584336978
  • Loading branch information
ColCarroll authored and tensorflower-gardener committed Nov 21, 2023
1 parent 300bfe5 commit b7724fd
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def independent_joint_distribution_from_structure(structure_of_distributions,
next_level_shallow_structure = nest.get_traverse_shallow_structure(
traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
structure=element_depths)
if not nest.is_nested(next_level_shallow_structure): # is a boolean
next_level_shallow_structure = nest.get_traverse_shallow_structure(
traverse_fn=lambda x: x is element_depths,
structure=element_depths)
structure_of_distributions = nest.map_structure_up_to(
next_level_shallow_structure,
functools.partial(independent_joint_distribution_from_structure,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def test_independent_jd_from_nested_input(self):
'c': (dirichlet.Dirichlet([1., 1.]),)}],
expect_isinstance=jds.JointDistributionSequential)

def test_independent_jd_from_nested_input_one_empty(self):
self._test_independent_joint_distribution_from_structure_helper(
structure={'a': {'b': normal.Normal(0., 1.)},
'c': {'d': normal.Normal(0., 1.)}},
expect_isinstance=jdn.JointDistributionNamed)

def test_batch_ndims_nested_input(self):
dist = jdu.independent_joint_distribution_from_structure(
[normal.Normal(0., tf.ones([5, 4])),
Expand Down

0 comments on commit b7724fd

Please sign in to comment.