Skip to content

Commit

Permalink
Better truncated
Browse files Browse the repository at this point in the history
  • Loading branch information
yannbouteiller committed Mar 25, 2023
1 parent 6b8a0cb commit 9aee8a3
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions rtgym/envs/real_time_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def __init__(self, config: dict=DEFAULT_CONFIG_DICT):
self.last_act_on_reset = config["last_act_on_reset"] if "last_act_on_reset" in config else False
self.act_prepro_func: callable = config["act_prepro_func"] if "act_prepro_func" in config else None
self.obs_prepro_func = config["obs_prepro_func"] if "obs_prepro_func" in config else None
self.ep_max_length = config["ep_max_length"] - 1
self.ep_max_length = config["ep_max_length"]

self.time_step_duration = config["time_step_duration"] if "time_step_duration" in config else 0.0
self.time_step_timeout_factor = config["time_step_timeout_factor"] if "time_step_timeout_factor" in config else 1.0
Expand All @@ -330,7 +330,6 @@ def __init__(self, config: dict=DEFAULT_CONFIG_DICT):
self.__obs = None
self.__rew = None
self.__terminated = None
self.__truncated = None
self.__info = None
self.__o_set_flag = False

Expand Down Expand Up @@ -452,12 +451,11 @@ def __update_obs_rew_terminated_truncated(self):
"""
self.__o_lock.acquire()
o, r, d, i = self.interface.get_obs_rew_terminated_info()
t = (self.current_step >= self.ep_max_length) if not d else False
elt = o
if self.obs_prepro_func:
elt = self.obs_prepro_func(elt)
elt = tuple(elt)
self.__obs, self.__rew, self.__terminated, self.__truncated, self.__info = elt, r, d, t, i
self.__obs, self.__rew, self.__terminated, self.__info = elt, r, d, i
self.__o_set_flag = True
self.__o_lock.release()

Expand All @@ -468,11 +466,11 @@ def _retrieve_obs_rew_terminated_truncated_info(self):
while c:
self.__o_lock.acquire()
if self.__o_set_flag:
elt, r, d, t, i = self.__obs, self.__rew, self.__terminated, self.__truncated, self.__info
elt, r, d, i = self.__obs, self.__rew, self.__terminated, self.__info
self.__o_set_flag = False
c = False
self.__o_lock.release()
return elt, r, d, t, i
return elt, r, d, i

def init_action_buffer(self):
for _ in range(self.act_buf_len):
Expand Down Expand Up @@ -551,7 +549,8 @@ def step(self, action):
self._run_time_step(action)
if not self.running:
raise RuntimeError("The episode is terminated or truncated. Call reset before step.")
obs, rew, terminated, truncated, info = self._retrieve_obs_rew_terminated_truncated_info()
obs, rew, terminated, info = self._retrieve_obs_rew_terminated_truncated_info()
truncated = (self.current_step >= self.ep_max_length) if not terminated else False
done = (terminated or truncated)
if not done: # apply action only when not done
self._run_time_step(action)
Expand Down

0 comments on commit 9aee8a3

Please sign in to comment.