From 10fbc7d11ac522ab76d056f91c98fcc254716dff Mon Sep 17 00:00:00 2001
From: aravindr93
Date: Mon, 12 Sep 2016 22:22:42 -0700
Subject: [PATCH] env can be given as direct input to agent.train_step

---
 robustRL/algos.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/robustRL/algos.py b/robustRL/algos.py
index 80f930f..e5d2e93 100644
--- a/robustRL/algos.py
+++ b/robustRL/algos.py
@@ -117,15 +117,21 @@ def train_step(self, N, T, gamma, env_mode='train',
                    idx=None,
                    mujoco_env=True,
                    normalized_env=False,
-                   sub_sample=None):
+                   sub_sample=None,
+                   train_env=None):
         """
             N = number of trajectories
             T = horizon
             env_mode = can be 'train', 'test' or something else. You need to
                 write the appropriate function in MDP_funcs
         """
-        paths = sample_paths_parallel(N, self.policy, self.baseline, env_mode,
-            T, gamma, num_cpu=num_cpu, mujoco_env=mujoco_env, normalized_env=normalized_env)
+
+        if train_env == None:
+            paths = sample_paths_parallel(N, self.policy, self.baseline, env_mode,
+                T, gamma, num_cpu=num_cpu, mujoco_env=mujoco_env, normalized_env=normalized_env)
+        else:
+            paths = sample_paths(N, self.policy, self.baseline, env=train_env, T=T, gamma=gamma,
+                mujoco_env=mujoco_env, normalized_env=normalized_env)
 
         # save the paths used to make the policy update
         if save_paths == True and idx != None:
@@ -259,15 +265,17 @@ def train(self, N, T, gamma, niter, env_mode='train'):
 
         return eval_statistics
 
-    def train_step(self, N, T, gamma, env_mode='train'):
+    def train_step(self, N, T, gamma, env_mode='train', train_env=None):
         """
             N = number of trajectories
             T = horizon
             env_mode = can be 'train', 'test' or something else. You need to
                 write the appropriate function in MDP_funcs
         """
-        paths = sample_paths_parallel(N, T, gamma, self.policy, self.baseline, env, num_cpu='max')
-
+        if train_env == None:
+            paths = sample_paths_parallel(N, T, gamma, self.policy, self.baseline, env, num_cpu='max')
+        else:
+            paths = sample_paths(N, T, gamma, self.policy, self.baseline, train_env)
         eval_statistics = self.train_from_paths(paths)
 
         return eval_statistics
@@ -311,4 +319,4 @@ def train_from_paths(self, paths):
         std_return = np.std(path_returns)
         min_return = np.amin(path_returns)
         max_return = np.amax(path_returns)
-        return (mean_return, std_return, min_return, max_return)
\ No newline at end of file
+        return (mean_return, std_return, min_return, max_return)
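
A minimal usage sketch of the new train_env argument, assuming agent is an already-constructed agent exposing the train_step method patched above and my_env is a compatible environment instance; both names are hypothetical and shown only for illustration.

    # Hypothetical sketch; `agent` and `my_env` are assumed to exist.
    # Default path: trajectories are sampled in parallel via sample_paths_parallel
    # using env_mode to pick the environment.
    stats = agent.train_step(N=50, T=1000, gamma=0.995, env_mode='train')

    # Path added by this patch: pass an environment object directly, which
    # routes sampling through sample_paths with env=train_env instead.
    stats_direct = agent.train_step(N=50, T=1000, gamma=0.995, train_env=my_env)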