From a5304893c1afa696ca64f603eec9d4374a59c3cd Mon Sep 17 00:00:00 2001 From: Felipe Montealegre-Mora <34276401+felimomo@users.noreply.github.com> Date: Tue, 9 Jul 2024 15:47:31 -0700 Subject: [PATCH] Nice little things (#13) * np.inf discontinued (pt3) * typo * added zips to gitignore * fixed saving bug on train - sb3.sb3_train() interaction * bug --------- Co-authored-by: Felipe Montealegre-Mora --- .gitignore | 4 ++++ scripts/train.py | 8 ++++++-- src/rl4caribou/utils/sb3.py | 6 +++--- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 68bc17f..22d2d49 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,7 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# output +saved_agents/ +*.zip diff --git a/scripts/train.py b/scripts/train.py index 346fac5..02a61d4 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -3,7 +3,7 @@ parser = argparse.ArgumentParser() parser.add_argument("-f", "--file", help="Path config file", type=str) parser.add_argument("-pb", "--progress_bar", help="Use progress bar for training", type=bool, default=True) -parser.add_argument("-id", "--identifier", help="ID string for saving the agent", type=bool, default="0") +parser.add_argument("-id", "--identifier", help="ID string for saving the agent", type=str, default="0") args = parser.parse_args() import rl4caribou @@ -21,7 +21,11 @@ # training # from rl4caribou.utils import sb3_train -model_save_id, train_options = sb3_train(abs_filepath, progress_bar=args.progress_bar) +model_save_id, train_options = sb3_train( + abs_filepath, + progress_bar=args.progress_bar, + identifier=args.identifier, +) model_save_id = model_save_id + "_id_" + args.identifier # hugging face diff --git a/src/rl4caribou/utils/sb3.py b/src/rl4caribou/utils/sb3.py index b472ffd..c640afc 100644 --- a/src/rl4caribou/utils/sb3.py +++ b/src/rl4caribou/utils/sb3.py @@ -36,7 +36,8 @@ def algorithm(algo): 'tqc': TQC, } return algos[algo] -def sb3_train(config_file, **kwargs): + +def sb3_train(config_file, progress_bar=True, identifier="0", **kwargs): with open(config_file, "r") as stream: options = yaml.safe_load(stream) options = {**options, **kwargs} @@ -65,7 +66,7 @@ def sb3_train(config_file, **kwargs): ALGO = algorithm(options["algo"]) # if "id" in options: # options["id"] = "-" + options["id"] - model_id = options["algo"] + "-" + options["env_id"] # + options.get("id", "") + model_id = options["algo"] + "-" + options["env_id"] + "_id_" + identifier save_id = os.path.join(options["save_path"], model_id) model = ALGO( @@ -73,7 +74,6 @@ def sb3_train(config_file, **kwargs): **options['algo_config'] ) - progress_bar = options.get("progress_bar", False) model.learn(total_timesteps=options["total_timesteps"], tb_log_name=model_id, progress_bar=progress_bar) os.makedirs(options["save_path"], exist_ok=True)