Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Felipe Montealegre-Mora committed Jul 9, 2024
2 parents 5446e3c + e23d679 commit 0572f61
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 5 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,7 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# output
saved_agents/
*.zip
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,21 @@ A DRL-based approach to Caribou conservation based on the methods of the

A three-species foodweb is considered, including interactions between Caribou, Elk and Wolf populations.

## Installation

```
git clone https://github.com/boettiger-lab/rl4caribou.git
cd rl4caribou
pip install .
```
## Train an agent

```
python scripts/train.py -f path/to/config/file.yml -id "string saving id" [-pb if you want a progress bar displayed]
```

A quick example:

```
python scripts/train.py -f hyperpars/example.yml -id "my_first_agent" -pb
```
8 changes: 6 additions & 2 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--file", help="Path config file", type=str)
parser.add_argument("-pb", "--progress_bar", help="Use progress bar for training", type=bool, default=True)
parser.add_argument("-id", "--identifier", help="ID string for saving the agent", type=bool, default="0")
parser.add_argument("-id", "--identifier", help="ID string for saving the agent", type=str, default="0")
args = parser.parse_args()

import rl4caribou
Expand All @@ -21,7 +21,11 @@
# training
#
from rl4caribou.utils import sb3_train
model_save_id, train_options = sb3_train(abs_filepath, progress_bar=args.progress_bar)
model_save_id, train_options = sb3_train(
abs_filepath,
progress_bar=args.progress_bar,
identifier=args.identifier,
)
model_save_id = model_save_id + "_id_" + args.identifier

# hugging face
Expand Down
6 changes: 3 additions & 3 deletions src/rl4caribou/utils/sb3.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def algorithm(algo):
'tqc': TQC,
}
return algos[algo]
def sb3_train(config_file, **kwargs):

def sb3_train(config_file, progress_bar=True, identifier="0", **kwargs):
with open(config_file, "r") as stream:
options = yaml.safe_load(stream)
options = {**options, **kwargs}
Expand Down Expand Up @@ -65,15 +66,14 @@ def sb3_train(config_file, **kwargs):
ALGO = algorithm(options["algo"])
# if "id" in options:
# options["id"] = "-" + options["id"]
model_id = options["algo"] + "-" + options["env_id"] # + options.get("id", "")
model_id = options["algo"] + "-" + options["env_id"] + "_id_" + identifier
save_id = os.path.join(options["save_path"], model_id)

model = ALGO(
env=env,
**options['algo_config']
)

progress_bar = options.get("progress_bar", False)
model.learn(total_timesteps=options["total_timesteps"], tb_log_name=model_id, progress_bar=progress_bar)

os.makedirs(options["save_path"], exist_ok=True)
Expand Down

0 comments on commit 0572f61

Please sign in to comment.