diff --git a/examples/example_configs/hopper_npg.txt b/examples/example_configs/hopper_npg.txt
index bd98381..bf16679 100644
--- a/examples/example_configs/hopper_npg.txt
+++ b/examples/example_configs/hopper_npg.txt
@@ -29,5 +29,11 @@
 
 'alg_hyper_params' : dict(),
 
+'wandb_params': {
+    'use_wandb' : True,
+    'wandb_user' : 'vikashplus',
+    'wandb_project' : 'mjrl_demo',
+    'wandb_exp' : 'demo_exp',
+    }
 
 }
diff --git a/examples/example_configs/swimmer_npg.txt b/examples/example_configs/swimmer_npg.txt
index f8af3a8..8c09008 100644
--- a/examples/example_configs/swimmer_npg.txt
+++ b/examples/example_configs/swimmer_npg.txt
@@ -29,4 +29,10 @@
 'alg_hyper_params' : dict(),
 
+'wandb_params': {
+    'use_wandb' : True,
+    'wandb_user' : 'vikashplus',
+    'wandb_project' : 'mjrl_demo',
+    'wandb_exp' : 'demo_exp',
+    }
 
 }
\ No newline at end of file
diff --git a/examples/example_configs/swimmer_ppo.txt b/examples/example_configs/swimmer_ppo.txt
index e20d592..6c561e0 100644
--- a/examples/example_configs/swimmer_ppo.txt
+++ b/examples/example_configs/swimmer_ppo.txt
@@ -7,7 +7,7 @@
 'seed' : 123,
 'sample_mode' : 'trajectories',
 'rl_num_traj' : 10,
-'rl_num_iter' : 50,
+'rl_num_iter' : 10,
 'num_cpu' : 2,
 'save_freq' : 25,
 'eval_rollouts' : None,
@@ -29,4 +29,10 @@
 'alg_hyper_params' : dict(clip_coef=0.2, epochs=10, mb_size=64, learn_rate=5e-4),
 
+'wandb_params': {
+    'use_wandb' : True,
+    'wandb_user' : 'vikashplus',
+    'wandb_project' : 'mjrl_demo',
+    'wandb_exp' : 'demo_exp',
+    }
 
 }
\ No newline at end of file
diff --git a/examples/policy_opt_job_script.py b/examples/policy_opt_job_script.py
index 0ee68df..e0ae249 100644
--- a/examples/policy_opt_job_script.py
+++ b/examples/policy_opt_job_script.py
@@ -13,6 +13,7 @@
 from mjrl.algos.batch_reinforce import BatchREINFORCE
 from mjrl.algos.ppo_clip import PPO
 from mjrl.utils.train_agent import train_agent
+from mjrl.utils.logger import DataLog
 import os
 import json
 import gym
@@ -82,6 +83,16 @@
     # or defaults in the PPO algorithm will be used
     agent = PPO(e, policy, baseline, save_logs=True, **job_data['alg_hyper_params'])
 
+
+# Update the logger if WandB is enabled in the config
+if 'wandb_params' in job_data.keys() and job_data['wandb_params']['use_wandb'] == True:
+    if 'wandb_logdir' in job_data['wandb_params']:
+        job_data['wandb_params']['wandb_logdir'] = os.path.join(JOB_DIR, job_data['wandb_params']['wandb_logdir'])
+    else:
+        job_data['wandb_params']['wandb_logdir'] = JOB_DIR
+    agent.logger = DataLog(**job_data['wandb_params'], wandb_config=job_data)
+
+
 print("========================================")
 print("Starting policy learning")
 print("========================================")
diff --git a/mjrl/__init__.py b/mjrl/__init__.py
index 008b133..00e188e 100644
--- a/mjrl/__init__.py
+++ b/mjrl/__init__.py
@@ -1 +1,2 @@
-import mjrl.envs
\ No newline at end of file
+# Users should explicitly import these envs if need be. They have a mujoco_py dependency that not all setups have.
+# import mjrl.envs
\ No newline at end of file
diff --git a/mjrl/utils/logger.py b/mjrl/utils/logger.py
index 6494eb2..f96a074 100644
--- a/mjrl/utils/logger.py
+++ b/mjrl/utils/logger.py
@@ -13,14 +13,20 @@
 
 class DataLog:
 
-    def __init__(self, use_wandb:bool = True,
+    def __init__(self,
+                 use_wandb:bool = False,
                  wandb_user:str = USERNAME,
                  wandb_project:str = WANDB_PROJECT,
+                 wandb_exp:str = None,
+                 wandb_logdir:str = None,
                  wandb_config:dict = dict()) -> None:
         self.use_wandb = use_wandb
         if use_wandb:
             import wandb
-            self.run = wandb.init(project=wandb_project, entity=wandb_user, config=wandb_config)
+            self.run = wandb.init(project=wandb_project, entity=wandb_user, dir=wandb_logdir, config=wandb_config)
+            # Update experiment name if explicitly specified
+            if wandb_exp is not None: wandb.run.name = wandb_exp
+
         self.log = {}
         self.max_len = 0
         self.global_step = 0
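
Note (not part of the patch): a minimal sketch of how the new 'wandb_params' config block flows into DataLog, assuming wandb is installed and the user is logged in. JOB_DIR and the trimmed job_data dict below are hypothetical stand-ins for what policy_opt_job_script.py normally provides.

import os
from mjrl.utils.logger import DataLog

JOB_DIR = 'results/demo_exp'   # hypothetical job output directory
job_data = {
    # ... remaining config keys from the example .txt files ...
    'wandb_params': {
        'use_wandb'     : True,
        'wandb_user'    : 'vikashplus',
        'wandb_project' : 'mjrl_demo',
        'wandb_exp'     : 'demo_exp',
    },
}

wandb_params = job_data['wandb_params']
if wandb_params.get('use_wandb', False):
    # Default the wandb log directory to the job directory, as the job script hook above does
    wandb_params.setdefault('wandb_logdir', JOB_DIR)
    os.makedirs(wandb_params['wandb_logdir'], exist_ok=True)
    # DataLog forwards these to wandb.init(project=..., entity=..., dir=...) per the logger.py change
    logger = DataLog(**wandb_params, wandb_config=job_data)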