-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* agents, agent tests * sb3 and hugging face utils * updated train script
- Loading branch information
Showing
9 changed files
with
179 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from rl4caribou.agents.const_escapement import constEsc | ||
from rl4caribou.agents.const_action import constAction |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import numpy as np | ||
|
||
class constAction:
    """Agent that returns the same action on every call to ``predict``.

    The action is derived once, at construction time, from a vector of
    mortalities via ``action = 2 * mortality_vec - 1``.
    """

    def __init__(self, mortality_vec=np.zeros(2, dtype=np.float32), env=None, **kwargs):
        """Store the mortality vector and pre-compute the constant action.

        Args:
            mortality_vec: array-like of mortalities. Lists, tuples and other
                non-ndarray sequences are converted to a ``float32`` array;
                an ndarray keeps its dtype.
            env: optional environment; stored but not used by this class.
            **kwargs: accepted and ignored.
        """
        if isinstance(mortality_vec, np.ndarray):
            # Copy so the instance never aliases the shared default argument
            # (or a caller-owned buffer that may be mutated later).
            mortality_vec = mortality_vec.copy()
        else:
            # Any other sequence (list, tuple, ...) becomes a float32 array;
            # without this, ``2 * (a, b) - 1`` would fail for tuples.
            mortality_vec = np.array(mortality_vec, dtype=np.float32)

        self.mortality_vec = mortality_vec
        self.action = 2 * self.mortality_vec - 1
        self.env = env

    def predict(self, observation, **kwargs):
        """Return ``(action, {})``; ``observation`` and ``kwargs`` are ignored."""
        return self.action, {}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import numpy as np | ||
|
||
class constEsc:
    """Constant-escapement agent.

    For each managed population the agent removes whatever exceeds a fixed
    escapement level: mortality is ``(pop - escapement) / pop`` when the
    population is above the level and ``0`` otherwise.
    """

    def __init__(self, escapement_vec, env=None):
        """Store the escapement levels and the observation bound.

        Args:
            escapement_vec: two escapement levels, indexed ``[0]`` for moose
                and ``[1]`` for wolves. Lists are converted to ``float32``;
                negative entries are clipped to 0.
            env: optional environment. When given, its ``bound`` attribute is
                used to rescale observations; otherwise the bound is 1.
        """
        if isinstance(escapement_vec, list):
            escapement_vec = np.float32(escapement_vec)
        # Flooring at zero also guarantees the division in
        # ``_excess_fraction`` is only reached with a strictly positive
        # population (pop > escapement >= 0).
        self.escapement_vec = np.clip(escapement_vec, a_min=0, a_max=None)
        self.env = env
        self.bound = 1 if self.env is None else self.env.bound

    def predict(self, observation, **kwargs):
        """Return ``(action, {})`` for the given observation.

        The observation is mapped from [-1, 1] to [0, bound]; entry 0 is read
        as the moose population and entry 2 as the wolf population. The two
        mortalities are mapped back to [-1, 1] to form the action.
        """
        obs_nat_units = self.bound * self.to_01(observation)
        mortality = np.float32(
            [
                self.moose_mortality(obs_nat_units[0]),
                self.wolf_mortality(obs_nat_units[2]),
            ]
        )
        return self.to_pm1(mortality), {}

    @staticmethod
    def _excess_fraction(pop, escapement):
        """Fraction of ``pop`` lying above ``escapement`` (0 if none does)."""
        if pop <= escapement:
            return 0
        return (pop - escapement) / pop

    def moose_mortality(self, moose_pop):
        """Mortality that brings the moose population down to its escapement."""
        return self._excess_fraction(moose_pop, self.escapement_vec[0])

    def wolf_mortality(self, wolf_pop):
        """Mortality that brings the wolf population down to its escapement."""
        return self._excess_fraction(wolf_pop, self.escapement_vec[1])

    def to_01(self, val):
        """Map a value from [-1, 1] to [0, 1]."""
        return (val + 1) / 2

    def to_pm1(self, val):
        """Map a value from [0, 1] to [-1, 1]."""
        return 2 * val - 1
|
||
|
||
|
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from huggingface_sb3 import load_from_hub, package_to_hub | ||
from huggingface_hub import HfApi | ||
|
||
from os.path import basename | ||
import pathlib | ||
|
||
def upload_to_hf(path, path_in_repo, repo, clean=False):
    """Upload a local file to a Hugging Face model repository.

    Args:
        path: local path of the file to upload.
        path_in_repo: destination path inside the repository; when ``None``
            the base name of ``path`` is used.
        repo: repository id passed to the Hub as ``repo_id``.
        clean: when true, delete the local file after the upload call returns.
    """
    client = HfApi()
    destination = basename(path) if path_in_repo is None else path_in_repo
    client.upload_file(
        path_or_fileobj=path,
        path_in_repo=destination,
        repo_id=repo,
        repo_type="model",
    )
    if not clean:
        return
    # Remove the local copy only once the upload call has completed.
    pathlib.Path(path).unlink()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import yaml | ||
import os | ||
|
||
import gymnasium as gym | ||
from stable_baselines3.common.env_util import make_vec_env | ||
from stable_baselines3 import PPO, A2C, DQN, SAC, TD3, HER, DDPG | ||
from sb3_contrib import TQC, ARS, RecurrentPPO | ||
|
||
def algorithm(algo):
    """Return the algorithm class registered under the name ``algo``.

    The lookup is case-insensitive, so ``"PPO"``, ``"ppo"`` and ``"Ppo"`` all
    resolve to the same class. ``"rppo"`` is an alias for ``"recurrentppo"``.

    Args:
        algo: name of the algorithm.

    Returns:
        The matching class (not an instance).

    Raises:
        KeyError: if ``algo`` does not name a registered algorithm.
    """
    # Keys are stored lower-case only; the query is normalised below instead
    # of listing every spelling by hand.
    algos = {
        "ppo": PPO,
        "recurrentppo": RecurrentPPO,
        "rppo": RecurrentPPO,
        "ars": ARS,
        "a2c": A2C,
        "ddpg": DDPG,
        "her": HER,
        "sac": SAC,
        "td3": TD3,
        "tqc": TQC,
    }
    key = algo.lower() if isinstance(algo, str) else algo
    try:
        return algos[key]
    except KeyError:
        raise KeyError(
            f"unknown algorithm {algo!r}; expected one of {sorted(algos)}"
        ) from None
|
||
def sb3_train(config_file, **kwargs):
    """Train an agent as described by a YAML config file and save it.

    Args:
        config_file: path to a YAML file holding the training options.
        **kwargs: extra options; these override / extend the YAML options.

    Options read:
        env_id, algo, id, save_path, total_timesteps (required);
        n_envs (build a vectorised env when present);
        config (forwarded to the environment constructor when present);
        tensorboard (tensorboard log directory, default ``None``);
        use_sde (forwarded to the algorithm only when present);
        progress_bar (default ``False``).

    Returns:
        ``(save_id, options)``: the path the model was saved under and the
        merged options dictionary.
    """
    with open(config_file, "r") as stream:
        # An empty YAML document parses to None; treat it as "no options".
        options = yaml.safe_load(stream) or {}
    # updates / expands on yaml options with optional user-provided input
    options = {**options, **kwargs}

    # Forward the environment config in BOTH branches so a single env and a
    # vectorised env are built from the same settings.
    env_kwargs = {"config": options["config"]} if "config" in options else {}
    if "n_envs" in options:
        env = make_vec_env(options["env_id"], options["n_envs"], env_kwargs=env_kwargs)
    else:
        env = gym.make(options["env_id"], **env_kwargs)

    ALGO = algorithm(options["algo"])
    # f-string rather than "+" so a numeric ``id`` in the YAML does not fail.
    model_id = f"{options['algo']}-{options['env_id']}-{options['id']}"
    save_id = os.path.join(options["save_path"], model_id)

    # Create the output directory up front: failing here is cheaper than
    # failing after training has finished.
    os.makedirs(options["save_path"], exist_ok=True)

    algo_kwargs = {
        "verbose": 0,
        "tensorboard_log": options.get("tensorboard"),
    }
    # Not every registered algorithm accepts ``use_sde``; pass it only when
    # the user explicitly asked for it.
    if "use_sde" in options:
        algo_kwargs["use_sde"] = options["use_sde"]

    model = ALGO("MlpPolicy", env, **algo_kwargs)

    progress_bar = options.get("progress_bar", False)
    model.learn(
        total_timesteps=options["total_timesteps"],
        tb_log_name=model_id,
        progress_bar=progress_bar,
    )

    model.save(save_id)
    print(f"Saved {options['algo']} model at {save_id}")

    return save_id, options
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,23 @@ | ||
# Confirm environment is correctly defined: | ||
from stable_baselines3.common.env_checker import check_env | ||
import numpy as np | ||
from rl4caribou import Caribou | ||
from rl4caribou.agents import constAction, constEsc | ||
|
||
def test_Caribou():
    """Run the stable-baselines3 environment checker on a default Caribou env."""
    env = Caribou()
    check_env(env, warn=True)
|
||
def test_constAction():
    """constAction accepts lists and arrays and returns ``2 * mortality - 1``."""
    ca1 = constAction(mortality_vec=[0, 0])
    ca2 = constAction(mortality_vec=np.zeros(2))
    obs = np.zeros(3)
    pr1, _ = ca1.predict(obs)
    pr2, _ = ca2.predict(obs)
    # Zero mortality maps to the lower end of the [-1, 1] action range.
    expected = -np.ones(2)
    assert np.allclose(pr1, expected)
    assert np.allclose(pr2, expected)
|
||
def test_constEsc():
    """constEsc accepts lists and arrays; zero escapement removes everything."""
    ce1 = constEsc(escapement_vec=[0, 0])
    ce2 = constEsc(escapement_vec=np.zeros(2))
    obs = np.zeros(3)
    pr1, _ = ce1.predict(obs)
    pr2, _ = ce2.predict(obs)
    # A zero observation is a population of 0.5 (bound defaults to 1); with
    # zero escapement the mortality is (0.5 - 0) / 0.5 = 1, i.e. action +1.
    expected = np.ones(2)
    assert np.allclose(pr1, expected)
    assert np.allclose(pr2, expected)
|