Users can now specify and pass wandb details from the config files #48

Open · wants to merge 4 commits into base: pvr_beta_1
6 changes: 6 additions & 0 deletions examples/example_configs/hopper_npg.txt
@@ -29,5 +29,11 @@
 
     'alg_hyper_params' : dict(),
 
+    'wandb_params': {
+        'use_wandb' : True,
+        'wandb_user' : 'vikashplus',
+        'wandb_project' : 'mjrl_demo',
+        'wandb_exp' : 'demo_exp',
+    }
 }
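These example configs are plain-text Python dict literals rather than JSON (note the dict() and 5e-4 entries), so a job script has to evaluate the file body to load them. A minimal loading sketch under that assumption; load_config is a hypothetical helper, not part of this PR:

    # Hypothetical helper: evaluate a config .txt written as a Python dict
    # literal. eval (not ast.literal_eval) is needed because entries such as
    # dict() are expressions, so only load trusted config files this way.
    def load_config(path):
        with open(path) as f:
            return eval(f.read())

    job_data = load_config('examples/example_configs/hopper_npg.txt')
    print(job_data['wandb_params']['wandb_project'])   # -> mjrl_demo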

6 changes: 6 additions & 0 deletions examples/example_configs/swimmer_npg.txt
@@ -29,4 +29,10 @@
 
     'alg_hyper_params' : dict(),
 
+    'wandb_params': {
+        'use_wandb' : True,
+        'wandb_user' : 'vikashplus',
+        'wandb_project' : 'mjrl_demo',
+        'wandb_exp' : 'demo_exp',
+    }
 }
8 changes: 7 additions & 1 deletion examples/example_configs/swimmer_ppo.txt
@@ -7,7 +7,7 @@
     'seed' : 123,
     'sample_mode' : 'trajectories',
     'rl_num_traj' : 10,
-    'rl_num_iter' : 50,
+    'rl_num_iter' : 10,
     'num_cpu' : 2,
     'save_freq' : 25,
     'eval_rollouts' : None,
@@ -29,4 +29,10 @@
 
     'alg_hyper_params' : dict(clip_coef=0.2, epochs=10, mb_size=64, learn_rate=5e-4),
 
+    'wandb_params': {
+        'use_wandb' : True,
+        'wandb_user' : 'vikashplus',
+        'wandb_project' : 'mjrl_demo',
+        'wandb_exp' : 'demo_exp',
+    }
 }
11 changes: 11 additions & 0 deletions examples/policy_opt_job_script.py
@@ -13,6 +13,7 @@
 from mjrl.algos.batch_reinforce import BatchREINFORCE
 from mjrl.algos.ppo_clip import PPO
 from mjrl.utils.train_agent import train_agent
+from mjrl.utils.logger import DataLog
 import os
 import json
 import gym
@@ -82,6 +83,16 @@
 # or defaults in the PPO algorithm will be used
 agent = PPO(e, policy, baseline, save_logs=True, **job_data['alg_hyper_params'])
 
+
+# Update logger if WandB in Config
+if 'wandb_params' in job_data.keys() and job_data['wandb_params']['use_wandb'] == True:
+    if 'wandb_logdir' in job_data['wandb_params']:
+        job_data['wandb_params']['wandb_logdir'] = os.path.join(JOB_DIR, job_data['wandb_params']['wandb_logdir'])
+    else:
+        job_data['wandb_params']['wandb_logdir'] = JOB_DIR
+    agent.logger = DataLog(**job_data['wandb_params'], wandb_config=job_data)
+
+
 print("========================================")
 print("Starting policy learning")
 print("========================================")
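A note on the path handling above: os.path.join drops JOB_DIR entirely when the configured wandb_logdir is already absolute, so relative logdirs nest under the job directory while absolute ones pass through unchanged. A small self-contained sketch of that rule, with hypothetical values:

    import os

    JOB_DIR = '/tmp/demo_job'   # hypothetical job output directory

    # Relative logdir: anchored under the job directory.
    print(os.path.join(JOB_DIR, 'wandb_logs'))   # /tmp/demo_job/wandb_logs

    # Absolute logdir: os.path.join discards JOB_DIR and keeps it as-is.
    print(os.path.join(JOB_DIR, '/data/wandb'))  # /data/wandb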
3 changes: 2 additions & 1 deletion mjrl/__init__.py
@@ -1 +1,2 @@
-import mjrl.envs
+# Users should explicitly import these envs if need be. They have a mujoco_py dependency that not all setups have.
+# import mjrl.envs
10 changes: 8 additions & 2 deletions mjrl/utils/logger.py
@@ -13,14 +13,20 @@
 
 class DataLog:
 
-    def __init__(self, use_wandb:bool = True,
+    def __init__(self,
+                 use_wandb:bool = False,
                  wandb_user:str = USERNAME,
                  wandb_project:str = WANDB_PROJECT,
                  wandb_exp:str = None,
+                 wandb_logdir:str = None,
                  wandb_config:dict = dict()) -> None:
         self.use_wandb = use_wandb
         if use_wandb:
             import wandb
-            self.run = wandb.init(project=wandb_project, entity=wandb_user, config=wandb_config)
+            self.run = wandb.init(project=wandb_project, entity=wandb_user, dir=wandb_logdir, config=wandb_config)
+            # Update exp name if explicitly specified
+            if wandb_exp is not None: wandb.run.name = wandb_exp
 
         self.log = {}
         self.max_len = 0
         self.global_step = 0
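With the new defaults, wandb is opt-in (use_wandb is now False unless requested) and both the run directory and the displayed run name are configurable. A minimal direct-use sketch, assuming wandb is installed and the entity/project exist; the values mirror the example configs above:

    from mjrl.utils.logger import DataLog

    # Default: plain local logger, no wandb run is created.
    logger = DataLog()

    # wandb-backed logger, mirroring the 'wandb_params' blocks above.
    wandb_logger = DataLog(use_wandb=True,
                           wandb_user='vikashplus',
                           wandb_project='mjrl_demo',
                           wandb_exp='demo_exp',
                           wandb_logdir='/tmp/demo_job',
                           wandb_config={'seed': 123})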