From 9aee8a3220054d8f535595cd58a9dee8b7d1b7b8 Mon Sep 17 00:00:00 2001 From: Yann Bouteiller Date: Sat, 25 Mar 2023 03:12:50 -0400 Subject: [PATCH] Better truncated --- rtgym/envs/real_time_env.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/rtgym/envs/real_time_env.py b/rtgym/envs/real_time_env.py index 15c7b3f..282311d 100644 --- a/rtgym/envs/real_time_env.py +++ b/rtgym/envs/real_time_env.py @@ -308,7 +308,7 @@ def __init__(self, config: dict=DEFAULT_CONFIG_DICT): self.last_act_on_reset = config["last_act_on_reset"] if "last_act_on_reset" in config else False self.act_prepro_func: callable = config["act_prepro_func"] if "act_prepro_func" in config else None self.obs_prepro_func = config["obs_prepro_func"] if "obs_prepro_func" in config else None - self.ep_max_length = config["ep_max_length"] - 1 + self.ep_max_length = config["ep_max_length"] self.time_step_duration = config["time_step_duration"] if "time_step_duration" in config else 0.0 self.time_step_timeout_factor = config["time_step_timeout_factor"] if "time_step_timeout_factor" in config else 1.0 @@ -330,7 +330,6 @@ def __init__(self, config: dict=DEFAULT_CONFIG_DICT): self.__obs = None self.__rew = None self.__terminated = None - self.__truncated = None self.__info = None self.__o_set_flag = False @@ -452,12 +451,11 @@ def __update_obs_rew_terminated_truncated(self): """ self.__o_lock.acquire() o, r, d, i = self.interface.get_obs_rew_terminated_info() - t = (self.current_step >= self.ep_max_length) if not d else False elt = o if self.obs_prepro_func: elt = self.obs_prepro_func(elt) elt = tuple(elt) - self.__obs, self.__rew, self.__terminated, self.__truncated, self.__info = elt, r, d, t, i + self.__obs, self.__rew, self.__terminated, self.__info = elt, r, d, i self.__o_set_flag = True self.__o_lock.release() @@ -468,11 +466,11 @@ def _retrieve_obs_rew_terminated_truncated_info(self): while c: self.__o_lock.acquire() if self.__o_set_flag: - elt, r, d, t, i = self.__obs, self.__rew, self.__terminated, self.__truncated, self.__info + elt, r, d, i = self.__obs, self.__rew, self.__terminated, self.__info self.__o_set_flag = False c = False self.__o_lock.release() - return elt, r, d, t, i + return elt, r, d, i def init_action_buffer(self): for _ in range(self.act_buf_len): @@ -551,7 +549,8 @@ def step(self, action): self._run_time_step(action) if not self.running: raise RuntimeError("The episode is terminated or truncated. Call reset before step.") - obs, rew, terminated, truncated, info = self._retrieve_obs_rew_terminated_truncated_info() + obs, rew, terminated, info = self._retrieve_obs_rew_terminated_truncated_info() + truncated = (self.current_step >= self.ep_max_length) if not terminated else False done = (terminated or truncated) if not done: # apply action only when not done self._run_time_step(action)