From 1b4a4807583c3dd50f53368657ebf281fd457423 Mon Sep 17 00:00:00 2001 From: Felipe Montealegre-Mora Date: Fri, 5 Jul 2024 17:40:36 +0000 Subject: [PATCH 1/2] id defined on terminal call to train.py --- scripts/train.py | 4 +++- src/rl4caribou/utils/sb3.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index ad49ef4..346fac5 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -2,7 +2,8 @@ import argparse parser = argparse.ArgumentParser() parser.add_argument("-f", "--file", help="Path config file", type=str) -parser.add_argument("-pb", "--progress_bar", help="Use progress bar for training", type=bool) +parser.add_argument("-pb", "--progress_bar", help="Use progress bar for training", type=bool, default=True) +parser.add_argument("-id", "--identifier", help="ID string for saving the agent", type=bool, default="0") args = parser.parse_args() import rl4caribou @@ -21,6 +22,7 @@ # from rl4caribou.utils import sb3_train model_save_id, train_options = sb3_train(abs_filepath, progress_bar=args.progress_bar) +model_save_id = model_save_id + "_id_" + args.identifier # hugging face # diff --git a/src/rl4caribou/utils/sb3.py b/src/rl4caribou/utils/sb3.py index e581eb9..b472ffd 100644 --- a/src/rl4caribou/utils/sb3.py +++ b/src/rl4caribou/utils/sb3.py @@ -63,9 +63,9 @@ def sb3_train(config_file, **kwargs): options['algo_config']['policy_kwargs'] = eval(options['algo_config']['policy_kwargs']) ALGO = algorithm(options["algo"]) - if "id" in options: - options["id"] = "-" + options["id"] - model_id = options["algo"] + "-" + options["env_id"] + options.get("id", "") + # if "id" in options: + # options["id"] = "-" + options["id"] + model_id = options["algo"] + "-" + options["env_id"] # + options.get("id", "") save_id = os.path.join(options["save_path"], model_id) model = ALGO( From fca8c5b69bbc9ae3f558758418372a960dbc644e Mon Sep 17 00:00:00 2001 From: Felipe Montealegre-Mora Date: Fri, 5 Jul 2024 17:42:36 +0000 Subject: [PATCH 2/2] no more id inside hyperpars/... yamls --- hyperpars/rppo-caribou.yml | 2 +- hyperpars/td3-caribou.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hyperpars/rppo-caribou.yml b/hyperpars/rppo-caribou.yml index 5478c47..4fc2b56 100644 --- a/hyperpars/rppo-caribou.yml +++ b/hyperpars/rppo-caribou.yml @@ -8,6 +8,6 @@ tensorboard: "/home/rstudio/logs" total_timesteps: 500000 config: {} use_sde: True -id: "1" +# id: "1" repo: "boettiger-lab/rl4eco" save_path: "/home/rstudio/rl4caribou/saved_agents" diff --git a/hyperpars/td3-caribou.yml b/hyperpars/td3-caribou.yml index 1e673f5..b77f76c 100644 --- a/hyperpars/td3-caribou.yml +++ b/hyperpars/td3-caribou.yml @@ -5,6 +5,6 @@ env_id: "CaribouScipy" config: {} tensorboard: "/home/rstudio/logs" total_timesteps: 500000 -id: "1" +# id: "1" repo: "boettiger-lab/rl4eco" save_path: "/home/rstudio/rl4caribou/saved_agents" \ No newline at end of file