From b42c4bdede939d1936c124a2eaee0e6778024e6b Mon Sep 17 00:00:00 2001 From: Vikash Kumar Date: Wed, 17 May 2023 18:34:02 -0400 Subject: [PATCH 1/4] Users can now specify and pass wandb details from the config files --- examples/example_configs/hopper_npg.txt | 6 ++++++ examples/example_configs/swimmer_npg.txt | 6 ++++++ examples/example_configs/swimmer_ppo.txt | 8 +++++++- examples/policy_opt_job_script.py | 7 +++++++ mjrl/utils/logger.py | 7 ++++++- 5 files changed, 32 insertions(+), 2 deletions(-) diff --git a/examples/example_configs/hopper_npg.txt b/examples/example_configs/hopper_npg.txt index bd98381..bf16679 100644 --- a/examples/example_configs/hopper_npg.txt +++ b/examples/example_configs/hopper_npg.txt @@ -29,5 +29,11 @@ 'alg_hyper_params' : dict(), +'wandb_params': { + 'use_wandb' : True, + 'wandb_user' : 'vikashplus', + 'wandb_project' : 'mjrl_demo', + 'wandb_exp' : 'demo_exp', + } } diff --git a/examples/example_configs/swimmer_npg.txt b/examples/example_configs/swimmer_npg.txt index f8af3a8..8c09008 100644 --- a/examples/example_configs/swimmer_npg.txt +++ b/examples/example_configs/swimmer_npg.txt @@ -29,4 +29,10 @@ 'alg_hyper_params' : dict(), +'wandb_params': { + 'use_wandb' : True, + 'wandb_user' : 'vikashplus', + 'wandb_project' : 'mjrl_demo', + 'wandb_exp' : 'demo_exp', + } } \ No newline at end of file diff --git a/examples/example_configs/swimmer_ppo.txt b/examples/example_configs/swimmer_ppo.txt index e20d592..6c561e0 100644 --- a/examples/example_configs/swimmer_ppo.txt +++ b/examples/example_configs/swimmer_ppo.txt @@ -7,7 +7,7 @@ 'seed' : 123, 'sample_mode' : 'trajectories', 'rl_num_traj' : 10, -'rl_num_iter' : 50, +'rl_num_iter' : 10, 'num_cpu' : 2, 'save_freq' : 25, 'eval_rollouts' : None, @@ -29,4 +29,10 @@ 'alg_hyper_params' : dict(clip_coef=0.2, epochs=10, mb_size=64, learn_rate=5e-4), +'wandb_params': { + 'use_wandb' : True, + 'wandb_user' : 'vikashplus', + 'wandb_project' : 'mjrl_demo', + 'wandb_exp' : 'demo_exp', + } } \ No 
newline at end of file diff --git a/examples/policy_opt_job_script.py b/examples/policy_opt_job_script.py index 0ee68df..7253ed0 100644 --- a/examples/policy_opt_job_script.py +++ b/examples/policy_opt_job_script.py @@ -13,6 +13,7 @@ from mjrl.algos.batch_reinforce import BatchREINFORCE from mjrl.algos.ppo_clip import PPO from mjrl.utils.train_agent import train_agent +from mjrl.utils.logger import DataLog import os import json import gym @@ -82,6 +83,12 @@ # or defaults in the PPO algorithm will be used agent = PPO(e, policy, baseline, save_logs=True, **job_data['alg_hyper_params']) + +# Update logger if WandB in Config +if 'wandb_params' in job_data.keys() and job_data['wandb_params']['use_wandb']==True: + agent.logger = DataLog(**job_data['wandb_params'], wandb_config=job_data) + + print("========================================") print("Starting policy learning") print("========================================") diff --git a/mjrl/utils/logger.py b/mjrl/utils/logger.py index 6494eb2..9fd4845 100644 --- a/mjrl/utils/logger.py +++ b/mjrl/utils/logger.py @@ -13,14 +13,19 @@ class DataLog: - def __init__(self, use_wandb:bool = True, + def __init__(self, + use_wandb:bool = False, wandb_user:str = USERNAME, wandb_project:str = WANDB_PROJECT, + wandb_exp:str = None, wandb_config:dict = dict()) -> None: self.use_wandb = use_wandb if use_wandb: import wandb self.run = wandb.init(project=wandb_project, entity=wandb_user, config=wandb_config) + # Update exp name if explicitely specified + if wandb_exp is not None: wandb.run.name = wandb_exp + self.log = {} self.max_len = 0 self.global_step = 0 From cc6af69aaef97d3db6e5f748bf7e4e3f6e4bb723 Mon Sep 17 00:00:00 2001 From: Vikash Kumar Date: Sun, 21 May 2023 00:17:54 -0400 Subject: [PATCH 2/4] FEATURE: directory can be specified for wandb logs --- examples/policy_opt_job_script.py | 4 ++++ mjrl/utils/logger.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/policy_opt_job_script.py 
b/examples/policy_opt_job_script.py index 7253ed0..e0ae249 100644 --- a/examples/policy_opt_job_script.py +++ b/examples/policy_opt_job_script.py @@ -86,6 +86,10 @@ # Update logger if WandB in Config if 'wandb_params' in job_data.keys() and job_data['wandb_params']['use_wandb']==True: + if 'wandb_logdir' in job_data['wandb_params']: + job_data['wandb_params']['wandb_logdir'] = os.path.join(JOB_DIR, job_data['wandb_params']['wandb_logdir']) + else: + job_data['wandb_params']['wandb_logdir'] = JOB_DIR agent.logger = DataLog(**job_data['wandb_params'], wandb_config=job_data) diff --git a/mjrl/utils/logger.py b/mjrl/utils/logger.py index 9fd4845..f96a074 100644 --- a/mjrl/utils/logger.py +++ b/mjrl/utils/logger.py @@ -18,11 +18,12 @@ def __init__(self, wandb_user:str = USERNAME, wandb_project:str = WANDB_PROJECT, wandb_exp:str = None, + wandb_logdir:str = None, wandb_config:dict = dict()) -> None: self.use_wandb = use_wandb if use_wandb: import wandb - self.run = wandb.init(project=wandb_project, entity=wandb_user, config=wandb_config) + self.run = wandb.init(project=wandb_project, entity=wandb_user, dir=wandb_logdir, config=wandb_config) # Update exp name if explicitely specified if wandb_exp is not None: wandb.run.name = wandb_exp From 104732fd6f46e754e7d652d72ab4b60178d9e821 Mon Sep 17 00:00:00 2001 From: Vikash Kumar Date: Sat, 24 Jun 2023 14:28:57 -0400 Subject: [PATCH 3/4] Users should explicitly import these envs if need be. They have mujoco_py dependency that all setups have --- mjrl/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mjrl/__init__.py b/mjrl/__init__.py index 008b133..affb942 100644 --- a/mjrl/__init__.py +++ b/mjrl/__init__.py @@ -1 +1,2 @@ -import mjrl.envs \ No newline at end of file +# Users should explicitly import these envs if need be. 
They have mujoco_py dependency that all setups have +# import mjrl.envs \ No newline at end of file From 6e058535b8f0fa3368ce33e7090abbf275f9307e Mon Sep 17 00:00:00 2001 From: Vikash Kumar Date: Sat, 24 Jun 2023 14:35:20 -0400 Subject: [PATCH 4/4] Users should explicitly import these envs if need be. They have mujoco_py dependency that not all setups have --- mjrl/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mjrl/__init__.py b/mjrl/__init__.py index affb942..00e188e 100644 --- a/mjrl/__init__.py +++ b/mjrl/__init__.py @@ -1,2 +1,2 @@ -# Users should explicitly import these envs if need be. They have mujoco_py dependency that all setups have +# Users should explicitly import these envs if need be. They have mujoco_py dependency that not all setups have # import mjrl.envs \ No newline at end of file