Pvr beta 1vk #51

Open: wants to merge 13 commits into base: master
6 changes: 6 additions & 0 deletions examples/example_configs/hopper_npg.txt
@@ -29,5 +29,11 @@

'alg_hyper_params' : dict(),

'wandb_params': {
'use_wandb' : True,
'wandb_user' : 'vikashplus',
'wandb_project' : 'mjrl_demo',
'wandb_exp' : 'demo_exp',
}
}
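
The same wandb_params block is added to each example config in this PR. A sketch of the full block, under the assumption (taken from the policy_opt_job_script.py change later in this diff) that an optional wandb_logdir key is also honored and resolved relative to the job directory; the wandb_logs value is a hypothetical placeholder, not something this PR sets:

'wandb_params': {
    'use_wandb'     : True,
    'wandb_user'    : 'vikashplus',
    'wandb_project' : 'mjrl_demo',
    'wandb_exp'     : 'demo_exp',
    'wandb_logdir'  : 'wandb_logs',   # optional; hypothetical value, joined with the job directory by the job script
}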

6 changes: 6 additions & 0 deletions examples/example_configs/swimmer_npg.txt
@@ -29,4 +29,10 @@

'alg_hyper_params' : dict(),

'wandb_params': {
'use_wandb' : True,
'wandb_user' : 'vikashplus',
'wandb_project' : 'mjrl_demo',
'wandb_exp' : 'demo_exp',
}
}
8 changes: 7 additions & 1 deletion examples/example_configs/swimmer_ppo.txt
@@ -7,7 +7,7 @@
'seed' : 123,
'sample_mode' : 'trajectories',
'rl_num_traj' : 10,
'rl_num_iter' : 50,
'rl_num_iter' : 10,
'num_cpu' : 2,
'save_freq' : 25,
'eval_rollouts' : None,
@@ -29,4 +29,10 @@

'alg_hyper_params' : dict(clip_coef=0.2, epochs=10, mb_size=64, learn_rate=5e-4),

'wandb_params': {
'use_wandb' : True,
'wandb_user' : 'vikashplus',
'wandb_project' : 'mjrl_demo',
'wandb_exp' : 'demo_exp',
}
}
11 changes: 11 additions & 0 deletions examples/policy_opt_job_script.py
@@ -13,6 +13,7 @@
from mjrl.algos.batch_reinforce import BatchREINFORCE
from mjrl.algos.ppo_clip import PPO
from mjrl.utils.train_agent import train_agent
from mjrl.utils.logger import DataLog
import os
import json
import gym
@@ -82,6 +83,16 @@
# or defaults in the PPO algorithm will be used
agent = PPO(e, policy, baseline, save_logs=True, **job_data['alg_hyper_params'])


# Replace the default logger with a WandB-enabled DataLog if wandb_params is present in the config
if 'wandb_params' in job_data.keys() and job_data['wandb_params']['use_wandb']==True:
if 'wandb_logdir' in job_data['wandb_params']:
job_data['wandb_params']['wandb_logdir'] = os.path.join(JOB_DIR, job_data['wandb_params']['wandb_logdir'])
else:
job_data['wandb_params']['wandb_logdir'] = JOB_DIR
agent.logger = DataLog(**job_data['wandb_params'], wandb_config=job_data)


print("========================================")
print("Starting policy learning")
print("========================================")
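
The block above swaps the agent's default logger for a wandb-aware DataLog built from the config. DataLog's wandb handling lives in mjrl/utils/logger.py and is not shown in this diff; the sketch below is only one plausible shape for it, assuming the config keys map onto wandb.init() arguments as named (wandb_user to entity, wandb_project to project, wandb_exp to name, wandb_logdir to dir):

import wandb

class WandbDataLogSketch:
    """Hypothetical stand-in for mjrl.utils.logger.DataLog; illustrative only."""
    def __init__(self, use_wandb=False, wandb_user=None, wandb_project=None,
                 wandb_exp=None, wandb_logdir=None, wandb_config=None):
        self.log = {}
        self.use_wandb = use_wandb
        if use_wandb:
            # entity/project/name/dir/config are real wandb.init() arguments;
            # the mapping from the mjrl config keys above is an assumption.
            wandb.init(entity=wandb_user, project=wandb_project, name=wandb_exp,
                       dir=wandb_logdir, config=wandb_config)

    def log_kv(self, key, value):
        # Keep the in-memory log that mjrl code reads, and mirror it to wandb.
        self.log.setdefault(key, []).append(value)
        if self.use_wandb:
            wandb.log({key: value})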
3 changes: 2 additions & 1 deletion mjrl/__init__.py
@@ -1 +1,2 @@
import mjrl.envs
# Users should explicitly import these envs if need be. They have mujoco_py dependency that not all setups have
# import mjrl.envs
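
Commenting out the import means `import mjrl` no longer pulls in mujoco_py. An alternative, shown only as a sketch of the trade-off rather than what this PR does, would be to keep auto-registration when the dependency is available and degrade gracefully otherwise:

try:
    import mjrl.envs   # registers the MuJoCo-backed gym envs
except ImportError:
    # mujoco_py (or another env dependency) is missing; the algorithms and
    # utilities remain usable without the bundled envs.
    pass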
2 changes: 2 additions & 0 deletions mjrl/algos/behavior_cloning.py
@@ -118,13 +118,15 @@ def fit(self, data, suppress_fit_tqdm=False, **kwargs):
self.logger.log_kv('loss_before', loss_val)

# train loop
self.policy.model.train()
for ep in config_tqdm(range(self.epochs), suppress_fit_tqdm):
for mb in range(int(num_samples / self.mb_size)):
rand_idx = np.random.choice(num_samples, size=self.mb_size)
self.optimizer.zero_grad()
loss = self.loss(data, idx=rand_idx)
loss.backward()
self.optimizer.step()
self.policy.model.eval()
params_after_opt = self.policy.get_param_values()
self.policy.set_param_values(params_after_opt, set_new=True, set_old=True)

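
The fit loop above is now bracketed by model.train() and model.eval(). This matters once the policy network contains BatchNorm or Dropout layers (as the new BatchNormMLP later in this diff does): the two modes give different forward passes. A minimal, self-contained illustration of the toggle:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Dropout(0.5), nn.Linear(8, 2))
x = torch.randn(16, 4)

net.train()                 # batch statistics + active dropout, as inside the fit loop
y_train_mode = net(x)

net.eval()                  # running statistics + dropout disabled, as restored after fitting
with torch.no_grad():
    y_eval_mode = net(x)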
1 change: 0 additions & 1 deletion mjrl/algos/trpo.py
@@ -12,7 +12,6 @@

# samplers
import mjrl.samplers.core as trajectory_sampler
import mjrl.samplers.batch_sampler as batch_sampler

# utility functions
import mjrl.utils.process_samples as process_samples
96 changes: 69 additions & 27 deletions mjrl/policies/gaussian_mlp.py
@@ -1,62 +1,44 @@
import numpy as np
from mjrl.utils.fc_network import FCNetwork
from mjrl.utils.fc_network import FCNetwork, FCNetworkWithBatchNorm
import torch
from torch.autograd import Variable


class MLP:
def __init__(self, env_spec,
hidden_sizes=(64,64),
min_log_std=-3,
init_log_std=0,
seed=None):
"""
:param env_spec: specifications of the env (see utils/gym_env.py)
:param hidden_sizes: network hidden layer sizes (currently 2 layers only)
:param min_log_std: log_std is clamped at this value and can't go below
:param init_log_std: initial log standard deviation
:param seed: random seed
"""
self.n = env_spec.observation_dim # number of states
self.m = env_spec.action_dim # number of actions
self.min_log_std = min_log_std

# Set seed
# ------------------------
if seed is not None:
torch.manual_seed(seed)
np.random.seed(seed)

# Policy network
# ------------------------
self.model = FCNetwork(self.n, self.m, hidden_sizes)
# make weights small
for param in list(self.model.parameters())[-2:]: # only last layer
param.data = 1e-2 * param.data
self.log_std = Variable(torch.ones(self.m) * init_log_std, requires_grad=True)
self.trainable_params = list(self.model.parameters()) + [self.log_std]

# Old Policy network
# ------------------------
self.old_model = FCNetwork(self.n, self.m, hidden_sizes)
self.old_log_std = Variable(torch.ones(self.m) * init_log_std)
self.old_params = list(self.old_model.parameters()) + [self.old_log_std]
for idx, param in enumerate(self.old_params):
param.data = self.trainable_params[idx].data.clone()

# Easy access variables
# -------------------------
self.log_std_val = np.float64(self.log_std.data.numpy().ravel())
self.param_shapes = [p.data.numpy().shape for p in self.trainable_params]
self.param_sizes = [p.data.numpy().size for p in self.trainable_params]
self.d = np.sum(self.param_sizes) # total number of params

# Placeholders
# ------------------------
self.obs_var = Variable(torch.randn(self.n), requires_grad=False)

# Utility functions
# ============================================
def show_activations(self):
return self.model.activations

def get_param_values(self):
params = np.concatenate([p.contiguous().view(-1).data.numpy()
for p in self.trainable_params])
@@ -70,10 +52,8 @@ def set_param_values(self, new_params, set_new=True, set_old=True):
vals = vals.reshape(self.param_shapes[idx])
param.data = torch.from_numpy(vals).float()
current_idx += self.param_sizes[idx]
# clip std at minimum value
self.trainable_params[-1].data = \
torch.clamp(self.trainable_params[-1], self.min_log_std).data
# update log_std_val for sampling
self.log_std_val = np.float64(self.log_std.data.numpy().ravel())
if set_old:
current_idx = 0
@@ -82,12 +62,9 @@ def set_param_values(self, new_params, set_new=True, set_old=True):
vals = vals.reshape(self.param_shapes[idx])
param.data = torch.from_numpy(vals).float()
current_idx += self.param_sizes[idx]
# clip std at minimum value
self.old_params[-1].data = \
torch.clamp(self.old_params[-1], self.min_log_std).data

# Main functions
# ============================================
def get_action(self, observation):
o = np.float32(observation.reshape(1, -1))
self.obs_var.data = torch.from_numpy(o)
@@ -143,3 +120,68 @@ def mean_kl(self, new_dist_info, old_dist_info):
Dr = 2 * new_std ** 2 + 1e-8
sample_kl = torch.sum(Nr / Dr + new_log_std - old_log_std, dim=1)
return torch.mean(sample_kl)

# Close the writer when done
def close_writer(self):
self.model.close_writer()
self.old_model.close_writer()



class BatchNormMLP(MLP):
def __init__(self, env_spec,
hidden_sizes=(64,64),
min_log_std=-3,
init_log_std=0,
seed=None,
nonlinearity='relu',
dropout=0,
log_dir='runs/activations_with_batchnorm',
*args, **kwargs):
self.n = env_spec.observation_dim # number of states
self.m = env_spec.action_dim # number of actions
self.min_log_std = min_log_std

if seed is not None:
torch.manual_seed(seed)
np.random.seed(seed)

self.model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity, dropout, log_dir=log_dir)

for param in list(self.model.parameters())[-2:]: # only last layer
param.data = 1e-2 * param.data
self.log_std = Variable(torch.ones(self.m) * init_log_std, requires_grad=True)
self.trainable_params = list(self.model.parameters()) + [self.log_std]
self.model.eval()

self.old_model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity, dropout, log_dir=log_dir)
self.old_log_std = Variable(torch.ones(self.m) * init_log_std)
self.old_params = list(self.old_model.parameters()) + [self.old_log_std]
for idx, param in enumerate(self.old_params):
param.data = self.trainable_params[idx].data.clone()
self.old_model.eval()

self.log_std_val = np.float64(self.log_std.data.numpy().ravel())
self.param_shapes = [p.data.numpy().shape for p in self.trainable_params]
self.param_sizes = [p.data.numpy().size for p in self.trainable_params]
self.d = np.sum(self.param_sizes) # total number of params

self.obs_var = Variable(torch.randn(self.n), requires_grad=False)

# Register hooks to log activations
self.model.register_hooks()
self.old_model.register_hooks()
self.close_writer()

def get_action(self, observation):
o = np.float32(observation.reshape(1, -1))
self.obs_var.data = torch.from_numpy(o)
mean = self.model(self.obs_var).data.numpy().ravel()
noise = np.exp(self.log_std_val) * np.random.randn(self.m)
action = mean + noise
return [action, {'mean': mean, 'log_std': self.log_std_val, 'evaluation': mean}]

# Close the writer when done
def close_writer(self):
self.model.close_writer()
self.old_model.close_writer()
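
BatchNormMLP builds its networks from FCNetworkWithBatchNorm and immediately puts them in eval() mode; in training mode, BatchNorm1d cannot normalize the single-observation batches that get_action() produces, so eval mode (running statistics) is required for action sampling. FCNetworkWithBatchNorm itself lives in mjrl/utils/fc_network.py and is not part of this diff; the sketch below is a hypothetical approximation of such a network, not the PR's implementation:

import torch
import torch.nn as nn

class FCNetworkWithBatchNormSketch(nn.Module):
    """Hypothetical stand-in for mjrl.utils.fc_network.FCNetworkWithBatchNorm."""
    def __init__(self, obs_dim, act_dim, hidden_sizes=(64, 64),
                 nonlinearity='relu', dropout=0):
        super().__init__()
        act = nn.ReLU if nonlinearity == 'relu' else nn.Tanh
        layers, in_dim = [], obs_dim
        for h in hidden_sizes:
            layers += [nn.Linear(in_dim, h), nn.BatchNorm1d(h), act()]
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_dim = h
        layers.append(nn.Linear(in_dim, act_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Accept a single flat observation by adding (and removing) a batch dim.
        single = x.dim() == 1
        if single:
            x = x.unsqueeze(0)
        out = self.net(x)
        return out.squeeze(0) if single else out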
29 changes: 26 additions & 3 deletions mjrl/samplers/core.py
@@ -6,6 +6,7 @@
import multiprocessing as mp
import time as timer
logging.disable(logging.CRITICAL)
import gc


# Single core rollout to sample trajectories
@@ -93,6 +94,7 @@ def do_rollout(
paths.append(path)

del(env)
gc.collect()
return paths


@@ -134,7 +136,7 @@ def sample_paths(
start_time = timer.time()
print("####### Gathering Samples #######")

results = _try_multiprocess(do_rollout, input_dict_list,
results = _try_multiprocess_cf(do_rollout, input_dict_list,
num_cpu, max_process_time, max_timeouts)
paths = []
# result is a paths type and results is list of paths
@@ -186,7 +188,7 @@ def sample_data_batch(
return paths


def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts):
def _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_timeouts):

# Base case
if max_timeouts == 0:
@@ -202,9 +204,30 @@ def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts):
pool.close()
pool.terminate()
pool.join()
return _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1)
return _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1)

pool.close()
pool.terminate()
pool.join()
return results


def _try_multiprocess_cf(func, input_dict_list, num_cpu, max_process_time, max_timeouts):
import concurrent.futures
results = None
if max_timeouts != 0:
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpu) as executor:
submit_futures = [executor.submit(func, **input_dict) for input_dict in input_dict_list]
try:
results = [f.result() for f in submit_futures]
except TimeoutError as e:
print(str(e))
print("Timeout Error raised...")
except concurrent.futures.CancelledError as e:
print(str(e))
print("Future Cancelled Error raised...")
except Exception as e:
print(str(e))
print("Error raised...")
raise e
return results
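
sample_paths now dispatches through _try_multiprocess_cf, which replaces the multiprocessing pool with concurrent.futures.ProcessPoolExecutor. Note that f.result() is called without a deadline above, so the TimeoutError branch will not fire on slow rollouts. A hedged sketch of how max_process_time could be enforced per future, if that is the intent of keeping the parameter:

import concurrent.futures

def _try_multiprocess_cf_with_timeout(func, input_dict_list, num_cpu,
                                      max_process_time, max_timeouts):
    # Sketch only: mirrors _try_multiprocess_mp's retry-on-timeout behaviour.
    if max_timeouts == 0:
        return None
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpu) as executor:
        futures = [executor.submit(func, **kwargs) for kwargs in input_dict_list]
        try:
            # A timeout here makes result() raise concurrent.futures.TimeoutError
            # for overdue rollouts, which triggers the retry below.
            return [f.result(timeout=max_process_time) for f in futures]
        except concurrent.futures.TimeoutError:
            for f in futures:
                f.cancel()
    return _try_multiprocess_cf_with_timeout(func, input_dict_list, num_cpu,
                                             max_process_time, max_timeouts - 1)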