From ea843cd19a8a89214267072f6d4b91a529f98fe4 Mon Sep 17 00:00:00 2001 From: ZHOU Bin Date: Sun, 16 Apr 2023 01:09:57 +0800 Subject: [PATCH] quick fix --- dreamerv3/nets.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dreamerv3/nets.py b/dreamerv3/nets.py index d11addab..8cb8b59c 100644 --- a/dreamerv3/nets.py +++ b/dreamerv3/nets.py @@ -92,6 +92,8 @@ def obs_step(self, prev_state, prev_action, embed, is_first): lambda x, y: x + self._mask(y, is_first), prev_state, self.initial(len(is_first))) prior = self.img_step(prev_state, prev_action) + if len(embed.shape) > len(prior['deter'].shape): + embed = embed.reshape(embed.shape[0], -1) x = jnp.concatenate([prior['deter'], embed], -1) x = self.get('obs_out', Linear, **self._kw)(x) stats = self._stats('obs_stats', x) @@ -251,7 +253,7 @@ def __init__( if re.match(cnn_keys, k) and len(v) == 3} self.mlp_shapes = { k: v for k, v in shapes.items() - if re.match(mlp_keys, k) and len(v) == 1} + if re.match(mlp_keys, k) and len(v) in [1, 2]} self.shapes = {**self.cnn_shapes, **self.mlp_shapes} print('Decoder CNN shapes:', self.cnn_shapes) print('Decoder MLP shapes:', self.mlp_shapes)