diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index ac2be35..5114b6e 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -55,12 +55,12 @@ def rng_fn( # pylint: disable=W0237 if not size: size = None - if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)): - Y = cls.Y.eval() - else: - Y = cls.Y - if not cls.all_trees: + if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)): + Y = cls.Y.eval() + else: + Y = cls.Y + if size is not None: return np.full((size[0], Y.shape[0]), Y.mean()) else: