diff --git a/dreamerv3/nets.py b/dreamerv3/nets.py index d11addab..8cb8b59c 100644 --- a/dreamerv3/nets.py +++ b/dreamerv3/nets.py @@ -92,6 +92,8 @@ def obs_step(self, prev_state, prev_action, embed, is_first): lambda x, y: x + self._mask(y, is_first), prev_state, self.initial(len(is_first))) prior = self.img_step(prev_state, prev_action) + if len(embed.shape) > len(prior['deter'].shape): + embed = embed.reshape(embed.shape[0], -1) x = jnp.concatenate([prior['deter'], embed], -1) x = self.get('obs_out', Linear, **self._kw)(x) stats = self._stats('obs_stats', x) @@ -251,7 +253,7 @@ def __init__( if re.match(cnn_keys, k) and len(v) == 3} self.mlp_shapes = { k: v for k, v in shapes.items() - if re.match(mlp_keys, k) and len(v) == 1} + if re.match(mlp_keys, k) and len(v) in [1, 2]} self.shapes = {**self.cnn_shapes, **self.mlp_shapes} print('Decoder CNN shapes:', self.cnn_shapes) print('Decoder MLP shapes:', self.mlp_shapes)