From b7724fd1ee5f924c0f9cf5a6fcb4860e855d5808 Mon Sep 17 00:00:00 2001 From: colcarroll Date: Tue, 21 Nov 2023 08:44:16 -0800 Subject: [PATCH] Allow `independent_joint_distribution_from_structure` to support (more) nested objects. If there were no items at a level of nesting which were distributions, `get_traverse_shallow` would terminate there, which would end up calling `independent_joint_distribution_from_structure` on the same structure, leading to a recursion error. This change guarantees that that in that case, the top layer at least will get mapped up to. PiperOrigin-RevId: 584336978 --- .../python/distributions/joint_distribution_util.py | 4 ++++ .../python/distributions/joint_distribution_util_test.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/tensorflow_probability/python/distributions/joint_distribution_util.py b/tensorflow_probability/python/distributions/joint_distribution_util.py index 2a60ceb696..a6c3e427e8 100644 --- a/tensorflow_probability/python/distributions/joint_distribution_util.py +++ b/tensorflow_probability/python/distributions/joint_distribution_util.py @@ -79,6 +79,10 @@ def independent_joint_distribution_from_structure(structure_of_distributions, next_level_shallow_structure = nest.get_traverse_shallow_structure( traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1, structure=element_depths) + if not nest.is_nested(next_level_shallow_structure): # is a boolean + next_level_shallow_structure = nest.get_traverse_shallow_structure( + traverse_fn=lambda x: x is element_depths, + structure=element_depths) structure_of_distributions = nest.map_structure_up_to( next_level_shallow_structure, functools.partial(independent_joint_distribution_from_structure, diff --git a/tensorflow_probability/python/distributions/joint_distribution_util_test.py b/tensorflow_probability/python/distributions/joint_distribution_util_test.py index 0eb32ec3cf..4917eafc1d 100644 --- a/tensorflow_probability/python/distributions/joint_distribution_util_test.py +++ b/tensorflow_probability/python/distributions/joint_distribution_util_test.py @@ -83,6 +83,12 @@ def test_independent_jd_from_nested_input(self): 'c': (dirichlet.Dirichlet([1., 1.]),)}], expect_isinstance=jds.JointDistributionSequential) + def test_independent_jd_from_nested_input_one_empty(self): + self._test_independent_joint_distribution_from_structure_helper( + structure={'a': {'b': normal.Normal(0., 1.)}, + 'c': {'d': normal.Normal(0., 1.)}}, + expect_isinstance=jdn.JointDistributionNamed) + def test_batch_ndims_nested_input(self): dist = jdu.independent_joint_distribution_from_structure( [normal.Normal(0., tf.ones([5, 4])),