diff --git a/rllib/agents/sac/sac_tf_model.py b/rllib/agents/sac/sac_tf_model.py index b457f1e947e0..b3b6f9f9f2f9 100644 --- a/rllib/agents/sac/sac_tf_model.py +++ b/rllib/agents/sac/sac_tf_model.py @@ -241,7 +241,23 @@ def _get_q_value(self, model_out, actions, net): if self.concat_obs_and_actions: input_dict = {"obs": tf.concat([model_out, actions], axis=-1)} else: - input_dict = {"obs": force_list(model_out) + [actions]} + # For the discrete case, action is always None. + shapes = [] + for space in net.obs_space: + if isinstance(space, Discrete): + shapes.append(space.n) + elif isinstance(space, Box): + shapes.append(space.shape) + else: + raise RuntimeError("The space type is not supported.") + + input_dict = { + "obs": tf.split( + tf.concat(force_list(model_out) + [actions], axis=-1), + num_or_size_splits=shapes, + axis=-1, + ) + } # Discrete case -> return q-vals for all actions. else: input_dict = {"obs": model_out}