Skip to content

Commit

Permalink
Merge pull request #8 from boettiger-lab/nice-little-things
Browse files Browse the repository at this point in the history
Nice little things
  • Loading branch information
nicolalove committed Jul 5, 2024
2 parents 62f9f9e + fca8c5b commit 8b3c063
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion hyperpars/rppo-caribou.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ tensorboard: "/home/rstudio/logs"
total_timesteps: 500000
config: {}
use_sde: True
id: "1"
# id: "1"
repo: "boettiger-lab/rl4eco"
save_path: "/home/rstudio/rl4caribou/saved_agents"
2 changes: 1 addition & 1 deletion hyperpars/td3-caribou.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ env_id: "CaribouScipy"
config: {}
tensorboard: "/home/rstudio/logs"
total_timesteps: 500000
id: "1"
# id: "1"
repo: "boettiger-lab/rl4eco"
save_path: "/home/rstudio/rl4caribou/saved_agents"
4 changes: 3 additions & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--file", help="Path config file", type=str)
parser.add_argument("-pb", "--progress_bar", help="Use progress bar for training", type=bool)
parser.add_argument("-pb", "--progress_bar", help="Use progress bar for training", type=bool, default=True)
parser.add_argument("-id", "--identifier", help="ID string for saving the agent", type=bool, default="0")
args = parser.parse_args()

import rl4caribou
Expand All @@ -21,6 +22,7 @@
#
from rl4caribou.utils import sb3_train
model_save_id, train_options = sb3_train(abs_filepath, progress_bar=args.progress_bar)
model_save_id = model_save_id + "_id_" + args.identifier

# hugging face
#
Expand Down
6 changes: 3 additions & 3 deletions src/rl4caribou/utils/sb3.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def sb3_train(config_file, **kwargs):
options['algo_config']['policy_kwargs'] = eval(options['algo_config']['policy_kwargs'])

ALGO = algorithm(options["algo"])
if "id" in options:
options["id"] = "-" + options["id"]
model_id = options["algo"] + "-" + options["env_id"] + options.get("id", "")
# if "id" in options:
# options["id"] = "-" + options["id"]
model_id = options["algo"] + "-" + options["env_id"] # + options.get("id", "")
save_id = os.path.join(options["save_path"], model_id)

model = ALGO(
Expand Down

0 comments on commit 8b3c063

Please sign in to comment.