Skip to content

Commit

Permalink
Agents utils (#2)
Browse files Browse the repository at this point in the history
* agents, agent tests

* sb3 and hugging face utils

* updated train script
  • Loading branch information
felimomo committed Feb 28, 2024
1 parent 573cf20 commit cbe7934
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 2 deletions.
15 changes: 13 additions & 2 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@

import rl4caribou

# training
#
# Train the agent described by the config file at args.file (args is defined
# above this hunk - presumably parsed command-line arguments; confirm).
# The two returned values are reused by the upload step further down.
from rl4caribou.utils import sb3_train
model_save_id, train_options = sb3_train(args.file)

# NOTE(review): the next two lines look like the pre-change lines of this diff
# (the scraped view drops the +/- markers). As written they rebind sb3_train to
# the rl4eco implementation and call it a second time, discarding its result -
# confirm they were removed by the commit.
from rl4eco.utils import sb3_train
sb3_train(args.file)
# hugging face
#
# Upload the training config and the saved model when the training options
# name a Hugging Face repo. The upload is best-effort: a failure is reported
# but does not abort the script, since training has already finished.
if 'repo' in train_options:
    import os
    from rl4caribou.utils import upload_to_hf

    # Stable-Baselines3's save() appends ".zip" when the path has no
    # extension, so the file on disk is usually model_save_id + ".zip".
    # Fall back to the bare path only if that file really exists.
    if os.path.isfile(model_save_id):
        model_file = model_save_id
    else:
        model_file = model_save_id + ".zip"

    try:
        upload_to_hf(args.file, "sb3/"+args.file, repo=train_options['repo'])
        upload_to_hf(model_file, "sb3/"+model_save_id+".zip", repo=train_options['repo'])
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C / SystemExit still propagate;
        # include the cause so a failed upload can actually be diagnosed.
        print(f"Couldn't upload to hf! ({type(err).__name__}: {err})")
2 changes: 2 additions & 0 deletions src/rl4caribou/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from rl4caribou.agents.const_escapement import constEsc
from rl4caribou.agents.const_action import constAction
15 changes: 15 additions & 0 deletions src/rl4caribou/agents/const_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy as np

class constAction:
    """Agent that returns the same action at every step.

    The action is derived once, at construction, from a vector of mortality
    values via ``action = 2 * mortality - 1`` (so a mortality of 0 maps to -1
    and a mortality of 1 maps to +1).
    """

    def __init__(self, mortality_vec=np.zeros(2, dtype=np.float32), env = None, **kwargs):
        """Store the mortality vector and precompute the constant action.

        Parameters
        ----------
        mortality_vec : array-like
            Mortality values; any sequence (list, tuple, ndarray) is accepted.
        env : object, optional
            Stored on the instance as ``self.env``; not used by this class.
        **kwargs
            Accepted and ignored.
        """
        # np.array always copies, so the shared default array (and any array
        # the caller passes) is never aliased by the instance, and every
        # input type ends up as float32 rather than only lists.
        self.mortality_vec = np.array(mortality_vec, dtype=np.float32)
        self.action = 2 * self.mortality_vec - 1
        self.env = env

    def predict(self, observation, **kwargs):
        """Return ``(action, {})``; the observation and kwargs are ignored."""
        return self.action, {}
45 changes: 45 additions & 0 deletions src/rl4caribou/agents/const_escapement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np

class constEsc:
    """Constant-escapement agent for moose and wolves.

    For each of the two populations the agent holds an escapement level.
    When the observed population exceeds that level, the mortality returned
    is the fraction by which it exceeds it; otherwise the mortality is 0.
    Observations and actions use the [-1, 1] scale; populations are
    recovered as ``bound * (obs + 1) / 2``.
    """

    def __init__(self, escapement_vec, env = None):
        """Store the escapement levels (clipped below at 0) and the bound.

        ``bound`` is read from ``env.bound`` when an env is given, else 1.
        """
        levels = escapement_vec
        if isinstance(levels, list):
            # lists are converted to a float32 array before clipping
            levels = np.float32(levels)
        self.escapement_vec = np.clip(levels, a_min=0, a_max=None)
        self.env = env
        self.bound = 1 if self.env is None else self.env.bound

    def predict(self, observation, **kwargs):
        """Return ``(action, {})`` with action = [moose, wolf] on [-1, 1].

        Index 0 of the observation is read as the moose population and
        index 2 as the wolf population.
        """
        populations = self.bound * self.to_01(observation)
        mortality = np.float32([
            self.moose_mortality(populations[0]),
            self.wolf_mortality(populations[2]),
        ])
        return self.to_pm1(mortality), {}

    def moose_mortality(self, moose_pop):
        """Fraction of moose above the moose escapement level (0 if none)."""
        level = self.escapement_vec[0]
        return 0 if moose_pop <= level else (moose_pop - level) / moose_pop

    def wolf_mortality(self, wolf_pop):
        """Fraction of wolves above the wolf escapement level (0 if none)."""
        level = self.escapement_vec[1]
        return 0 if wolf_pop <= level else (wolf_pop - level) / wolf_pop

    def to_01(self, val):
        """Map a value from the [-1, 1] scale to the [0, 1] scale."""
        return (val + 1) / 2

    def to_pm1(self, val):
        """Map a value from the [0, 1] scale to the [-1, 1] scale."""
        return 2 * val - 1



Empty file added src/rl4caribou/envs/__init__.py
Empty file.
Empty file.
17 changes: 17 additions & 0 deletions src/rl4caribou/utils/hugging_face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from huggingface_sb3 import load_from_hub, package_to_hub
from huggingface_hub import HfApi

from os.path import basename
import pathlib

def upload_to_hf(path, path_in_repo, repo, clean=False):
    """Upload one local file to a Hugging Face model repo.

    Parameters
    ----------
    path : local path of the file to upload.
    path_in_repo : destination path inside the repo; when None, the base
        name of ``path`` is used.
    repo : id of the target repo (uploaded with ``repo_type="model"``).
    clean : when true, delete the local file after the upload call returns.
    """
    hub_client = HfApi()
    destination = basename(path) if path_in_repo is None else path_in_repo
    hub_client.upload_file(
        path_or_fileobj=path,
        path_in_repo=destination,
        repo_id=repo,
        repo_type="model",
    )
    if not clean:
        return
    # caller asked for the local copy to be removed once it has been uploaded
    pathlib.Path(path).unlink()
71 changes: 71 additions & 0 deletions src/rl4caribou/utils/sb3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import yaml
import os

import gymnasium as gym
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import PPO, A2C, DQN, SAC, TD3, HER, DDPG
from sb3_contrib import TQC, ARS, RecurrentPPO

def algorithm(algo):
    """Return the algorithm class registered under the name *algo*.

    The lookup is case-insensitive, so every spelling the previous table
    accepted (e.g. ``'PPO'``, ``'ppo'``, ``'RecurrentPPO'``, ``'RPPO'``)
    still resolves to the same class. ``'rppo'`` is an alias for
    ``RecurrentPPO``.

    Raises
    ------
    KeyError
        If *algo* does not name a known algorithm.
    """
    algos = {
        'ppo': PPO,
        'recurrentppo': RecurrentPPO,
        'rppo': RecurrentPPO,
        #
        'ars': ARS,
        'a2c': A2C,
        #
        'ddpg': DDPG,
        #
        # DQN is imported by this module but was missing from the table.
        'dqn': DQN,
        #
        'her': HER,
        #
        'sac': SAC,
        #
        'td3': TD3,
        #
        'tqc': TQC,
    }
    key = str(algo).lower()
    try:
        return algos[key]
    except KeyError:
        known = ", ".join(sorted(algos))
        # Same exception type as a plain dict lookup, but with a usable message.
        raise KeyError(
            f"Unknown algorithm {algo!r}; expected one of: {known}"
        ) from None

def sb3_train(config_file, **kwargs):
    """Train an agent as described by a YAML config file and save it.

    Parameters
    ----------
    config_file : path to a YAML file with the training options.
        Required keys: ``env_id``, ``algo``, ``id``, ``save_path``,
        ``total_timesteps``.
        Optional keys: ``n_envs`` (use a vectorized env), ``config``
        (passed to the environment as its ``config`` keyword),
        ``tensorboard`` (log directory), ``use_sde``, ``progress_bar``.
    **kwargs : overrides / additions applied on top of the YAML options.

    Returns
    -------
    (save_id, options) : the path the model was saved under (as passed to
        ``model.save``) and the merged options dict.
    """
    with open(config_file, "r") as stream:
        # an empty YAML file loads as None; treat it as "no options"
        options = yaml.safe_load(stream) or {}
    options = {**options, **kwargs}
    # updates / expands on yaml options with optional user-provided input

    # Pass the env config the same way on both paths; previously it was
    # silently dropped when no n_envs was given.
    env_kwargs = {"config": options["config"]} if "config" in options else {}
    if "n_envs" in options:
        env = make_vec_env(
            options["env_id"], options["n_envs"], env_kwargs=env_kwargs
        )
    else:
        env = gym.make(options["env_id"], **env_kwargs)

    ALGO = algorithm(options["algo"])
    # f-string so a numeric ``id`` in the YAML does not break concatenation
    model_id = f"{options['algo']}-{options['env_id']}-{options['id']}"
    save_id = os.path.join(options["save_path"], model_id)

    # Create the output directory before training so an unusable save path
    # fails immediately instead of after the whole training run.
    os.makedirs(options["save_path"], exist_ok=True)

    model_kwargs = {
        "verbose": 0,
        "tensorboard_log": options.get("tensorboard"),
    }
    # Not every algorithm constructor accepts use_sde (e.g. TD3 / DDPG),
    # so only forward it when the config asks for it explicitly.
    if "use_sde" in options:
        model_kwargs["use_sde"] = options["use_sde"]
    model = ALGO("MlpPolicy", env, **model_kwargs)

    progress_bar = options.get("progress_bar", False)
    model.learn(
        total_timesteps=options["total_timesteps"],
        tb_log_name=model_id,
        progress_bar=progress_bar,
    )

    model.save(save_id)
    print(f"Saved {options['algo']} model at {save_id}")

    return save_id, options
16 changes: 16 additions & 0 deletions tests/test_caribou.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
# Confirm environment is correctly defined:
from stable_baselines3.common.env_checker import check_env
import numpy as np
from rl4caribou import Caribou
from rl4caribou.agents import constAction, constEsc

def test_Caribou():
    """A default-constructed Caribou env passes the SB3 environment checker."""
    env = Caribou()
    check_env(env, warn=True)

def test_constAction():
    """constAction gives the same constant action for list and array input."""
    ca1 = constAction(mortality_vec = [0,0])
    ca2 = constAction(mortality_vec = np.zeros(2))
    obs = np.zeros(3)
    pr1, info1 = ca1.predict(obs)
    pr2, info2 = ca2.predict(obs)
    # zero mortality maps to -1 under action = 2 * mortality - 1
    assert np.shape(pr1) == (2,)
    assert np.shape(pr2) == (2,)
    assert np.allclose(pr1, -1)
    assert np.allclose(pr2, -1)
    assert info1 == {}
    assert info2 == {}

def test_constEsc():
    """constEsc with zero escapement removes the whole observed population."""
    ce1 = constEsc(escapement_vec = [0,0])
    ce2 = constEsc(escapement_vec = np.zeros(2))
    obs = np.zeros(3)
    pr1, info1 = ce1.predict(obs)
    pr2, info2 = ce2.predict(obs)
    # obs = 0 is a population of 0.5 (bound 1); with escapement 0 everything
    # above 0 is removed, i.e. mortality 1, which maps to an action of +1.
    assert np.shape(pr1) == (2,)
    assert np.shape(pr2) == (2,)
    assert np.allclose(pr1, 1)
    assert np.allclose(pr2, 1)
    assert info1 == {}
    assert info2 == {}

0 comments on commit cbe7934

Please sign in to comment.