env can be given as direct input to agent.train_step
aravindr93 committed Sep 13, 2016
1 parent 45ed7c1 commit 10fbc7d
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions robustRL/algos.py
@@ -117,15 +117,21 @@ def train_step(self, N, T, gamma, env_mode='train',
                    idx=None,
                    mujoco_env=True,
                    normalized_env=False,
-                   sub_sample=None):
+                   sub_sample=None,
+                   train_env=None):
         """ N = number of trajectories
             T = horizon
             env_mode = can be 'train', 'test' or something else.
                        You need to write the appropriate function in MDP_funcs
         """
 
-        paths = sample_paths_parallel(N, self.policy, self.baseline, env_mode,
-            T, gamma, num_cpu=num_cpu, mujoco_env=mujoco_env, normalized_env=normalized_env)
+        if train_env == None:
+            paths = sample_paths_parallel(N, self.policy, self.baseline, env_mode,
+                T, gamma, num_cpu=num_cpu, mujoco_env=mujoco_env, normalized_env=normalized_env)
+        else:
+            paths = sample_paths(N, self.policy, self.baseline, env=train_env, T=T, gamma=gamma,
+                mujoco_env=mujoco_env, normalized_env=normalized_env)
+
         # save the paths used to make the policy update
         if save_paths == True and idx != None:
@@ -259,15 +265,17 @@ def train(self, N, T, gamma, niter, env_mode='train'):
         return eval_statistics
 
 
-    def train_step(self, N, T, gamma, env_mode='train'):
+    def train_step(self, N, T, gamma, env_mode='train', train_env=None):
         """ N = number of trajectories
             T = horizon
             env_mode = can be 'train', 'test' or something else.
                        You need to write the appropriate function in MDP_funcs
         """
 
-        paths = sample_paths_parallel(N, T, gamma, self.policy, self.baseline, env, num_cpu='max')
-
+        if train_env == None:
+            paths = sample_paths_parallel(N, T, gamma, self.policy, self.baseline, env, num_cpu='max')
+        else:
+            paths = sample_paths(N, T, gamma, self.policy, self.baseline, train_env)
         eval_statistics = self.train_from_paths(paths)
 
         return eval_statistics
@@ -311,4 +319,4 @@ def train_from_paths(self, paths):
         std_return = np.std(path_returns)
         min_return = np.amin(path_returns)
         max_return = np.amax(path_returns)
-        return (mean_return, std_return, min_return, max_return)
\ No newline at end of file
+        return (mean_return, std_return, min_return, max_return)
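For context, a minimal usage sketch of the new train_env parameter. The agent construction, environment name, and hyperparameter values below are illustrative assumptions and are not part of this commit:

# Hypothetical usage sketch; `agent` is assumed to be an already-constructed
# robustRL agent exposing train_step, and gym with MuJoCo is assumed available.
import gym

env = gym.make('Hopper-v1')

# Default behavior (train_env=None): trajectories are sampled in parallel via
# sample_paths_parallel, with the environment resolved from env_mode.
stats = agent.train_step(N=10, T=500, gamma=0.995, env_mode='train')

# New behavior: pass a concrete environment instance directly; sampling then
# uses this env via sample_paths instead of the env_mode lookup.
stats = agent.train_step(N=10, T=500, gamma=0.995, train_env=env)

When train_env is supplied, sampling goes through the serial sample_paths helper rather than sample_paths_parallel, so a pre-configured environment instance (for example, one with perturbed dynamics) can be trained against directly.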
