diff --git a/tensorflow_probability/python/distributions/joint_distribution_util.py b/tensorflow_probability/python/distributions/joint_distribution_util.py index 2a60ceb696..a6c3e427e8 100644 --- a/tensorflow_probability/python/distributions/joint_distribution_util.py +++ b/tensorflow_probability/python/distributions/joint_distribution_util.py @@ -79,6 +79,10 @@ def independent_joint_distribution_from_structure(structure_of_distributions, next_level_shallow_structure = nest.get_traverse_shallow_structure( traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1, structure=element_depths) + if not nest.is_nested(next_level_shallow_structure): # is a boolean + next_level_shallow_structure = nest.get_traverse_shallow_structure( + traverse_fn=lambda x: x is element_depths, + structure=element_depths) structure_of_distributions = nest.map_structure_up_to( next_level_shallow_structure, functools.partial(independent_joint_distribution_from_structure, diff --git a/tensorflow_probability/python/distributions/joint_distribution_util_test.py b/tensorflow_probability/python/distributions/joint_distribution_util_test.py index 0eb32ec3cf..4917eafc1d 100644 --- a/tensorflow_probability/python/distributions/joint_distribution_util_test.py +++ b/tensorflow_probability/python/distributions/joint_distribution_util_test.py @@ -83,6 +83,12 @@ def test_independent_jd_from_nested_input(self): 'c': (dirichlet.Dirichlet([1., 1.]),)}], expect_isinstance=jds.JointDistributionSequential) + def test_independent_jd_from_nested_input_one_empty(self): + self._test_independent_joint_distribution_from_structure_helper( + structure={'a': {'b': normal.Normal(0., 1.)}, + 'c': {'d': normal.Normal(0., 1.)}}, + expect_isinstance=jdn.JointDistributionNamed) + def test_batch_ndims_nested_input(self): dist = jdu.independent_joint_distribution_from_structure( [normal.Normal(0., tf.ones([5, 4])),