Skip to content

Commit

Permalink
Nice little things (#13)
Browse files (browse the repository at this point in the history)
* np.inf discontinued (pt3)

* typo

* added zips to gitignore

* fixed saving bug on train - sb3.sb3_train() interaction

* bug

---------

Co-authored-by: Felipe Montealegre-Mora <[email protected]>
  • Loading branch information
felimomo and Felipe Montealegre-Mora committed Jul 9, 2024
1 parent fb03276 commit a530489
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,7 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# output
saved_agents/
*.zip
8 changes: 6 additions & 2 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--file", help="Path config file", type=str)
parser.add_argument("-pb", "--progress_bar", help="Use progress bar for training", type=bool, default=True)
parser.add_argument("-id", "--identifier", help="ID string for saving the agent", type=bool, default="0")
parser.add_argument("-id", "--identifier", help="ID string for saving the agent", type=str, default="0")
args = parser.parse_args()

import rl4caribou
Expand All @@ -21,7 +21,11 @@
# training
#
from rl4caribou.utils import sb3_train
model_save_id, train_options = sb3_train(abs_filepath, progress_bar=args.progress_bar)
model_save_id, train_options = sb3_train(
abs_filepath,
progress_bar=args.progress_bar,
identifier=args.identifier,
)
model_save_id = model_save_id + "_id_" + args.identifier

# hugging face
Expand Down
6 changes: 3 additions & 3 deletions src/rl4caribou/utils/sb3.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def algorithm(algo):
'tqc': TQC,
}
return algos[algo]
def sb3_train(config_file, **kwargs):

def sb3_train(config_file, progress_bar=True, identifier="0", **kwargs):
with open(config_file, "r") as stream:
options = yaml.safe_load(stream)
options = {**options, **kwargs}
Expand Down Expand Up @@ -65,15 +66,14 @@ def sb3_train(config_file, **kwargs):
ALGO = algorithm(options["algo"])
# if "id" in options:
# options["id"] = "-" + options["id"]
model_id = options["algo"] + "-" + options["env_id"] # + options.get("id", "")
model_id = options["algo"] + "-" + options["env_id"] + "_id_" + identifier
save_id = os.path.join(options["save_path"], model_id)

model = ALGO(
env=env,
**options['algo_config']
)

progress_bar = options.get("progress_bar", False)
model.learn(total_timesteps=options["total_timesteps"], tb_log_name=model_id, progress_bar=progress_bar)

os.makedirs(options["save_path"], exist_ok=True)
Expand Down

0 comments on commit a530489

Please sign in to comment.