diff --git a/scripts/train.py b/scripts/train.py index a652f3b..18393a5 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -6,6 +6,17 @@ import rl4caribou +# training +# +from rl4caribou.utils import sb3_train +model_save_id, train_options = sb3_train(args.file) -from rl4eco.utils import sb3_train -sb3_train(args.file) +# hugging face +# +if 'repo' in train_options: + from rl4caribou.utils import upload_to_hf + try: + upload_to_hf(args.file, "sb3/"+args.file, repo=train_options['repo']) + upload_to_hf(model_save_id, "sb3/"+model_save_id+".zip", repo=train_options['repo']) + except: + print("Couldn't upload to hf!") \ No newline at end of file diff --git a/src/rl4caribou/agents/__init__.py b/src/rl4caribou/agents/__init__.py new file mode 100644 index 0000000..d4b412a --- /dev/null +++ b/src/rl4caribou/agents/__init__.py @@ -0,0 +1,2 @@ +from rl4caribou.agents.const_escapement import constEsc +from rl4caribou.agents.const_action import constAction \ No newline at end of file diff --git a/src/rl4caribou/agents/const_action.py b/src/rl4caribou/agents/const_action.py new file mode 100644 index 0000000..464af2f --- /dev/null +++ b/src/rl4caribou/agents/const_action.py @@ -0,0 +1,15 @@ +import numpy as np + +class constAction: + def __init__(self, mortality_vec=np.zeros(2, dtype=np.float32), env = None, **kwargs): + # + # preprocess + if isinstance(mortality_vec, list): + mortality_vec = np.float32(mortality_vec) + # + self.mortality_vec = mortality_vec + self.action = 2 * self.mortality_vec - 1 + self.env = env + + def predict(self, observation, **kwargs): + return self.action, {} \ No newline at end of file diff --git a/src/rl4caribou/agents/const_escapement.py b/src/rl4caribou/agents/const_escapement.py new file mode 100644 index 0000000..5d05be1 --- /dev/null +++ b/src/rl4caribou/agents/const_escapement.py @@ -0,0 +1,45 @@ +import numpy as np + +class constEsc: + def __init__(self, escapement_vec, env = None): + # + # preprocess + if isinstance(escapement_vec, list): + escapement_vec = np.float32(escapement_vec) + escapement_vec = np.clip( + escapement_vec, a_min = 0, a_max = None + ) + # + self.escapement_vec = escapement_vec + self.env = env + self.bound = 1 + if self.env is not None: + self.bound = self.env.bound + + def predict(self, observation, **kwargs): + obs_nat_units = self.bound * self.to_01(observation) + m_mort = self.moose_mortality(obs_nat_units[0]) + w_mort = self.wolf_mortality(obs_nat_units[2]) + mortality = np.float32([m_mort, w_mort]) + return self.to_pm1(mortality), {} + + def moose_mortality(self, moose_pop): + if moose_pop <= self.escapement_vec[0]: + return 0 + else: + return (moose_pop - self.escapement_vec[0]) / moose_pop + + def wolf_mortality(self, wolf_pop): + if wolf_pop <= self.escapement_vec[1]: + return 0 + else: + return (wolf_pop - self.escapement_vec[1]) / wolf_pop + + def to_01(self, val): + return (val + 1 ) / 2 + + def to_pm1(self, val): + return 2 * val - 1 + + + diff --git a/src/rl4caribou/envs/__init__.py b/src/rl4caribou/envs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/rl4caribou/utils/__init__.py b/src/rl4caribou/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/rl4caribou/utils/hugging_face.py b/src/rl4caribou/utils/hugging_face.py new file mode 100644 index 0000000..88665c1 --- /dev/null +++ b/src/rl4caribou/utils/hugging_face.py @@ -0,0 +1,17 @@ +from huggingface_sb3 import load_from_hub, package_to_hub +from huggingface_hub import HfApi + +from os.path import basename +import pathlib + +def upload_to_hf(path, path_in_repo, repo, clean=False): + api = HfApi() + if path_in_repo is None: + path_in_repo = basename(path) + api.upload_file( + path_or_fileobj=path, + path_in_repo=path_in_repo, + repo_id=repo, + repo_type="model") + if clean: + pathlib.Path(path).unlink() \ No newline at end of file diff --git a/src/rl4caribou/utils/sb3.py b/src/rl4caribou/utils/sb3.py new file mode 100644 index 0000000..12a572a --- /dev/null +++ b/src/rl4caribou/utils/sb3.py @@ -0,0 +1,71 @@ +import yaml +import os + +import gymnasium as gym +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3 import PPO, A2C, DQN, SAC, TD3, HER, DDPG +from sb3_contrib import TQC, ARS, RecurrentPPO + +def algorithm(algo): + algos = { + 'PPO': PPO, + 'ppo': PPO, + 'RecurrentPPO': RecurrentPPO, + 'RPPO': RecurrentPPO, + 'recurrentppo': RecurrentPPO, + 'rppo': RecurrentPPO, + # + 'ARS': ARS, + 'ars': ARS, + 'A2C': A2C, + 'a2c':A2C , + # + 'DDPG': DDPG, + 'ddpg': DDPG, + # + 'HER': HER, + 'her': HER, + # + 'SAC': SAC, + 'sac': SAC, + # + 'TD3': TD3, + 'td3': TD3, + # + 'TQC': TQC, + 'tqc': TQC, + } + return algos[algo] + +def sb3_train(config_file, **kwargs): + with open(config_file, "r") as stream: + options = yaml.safe_load(stream) + options = {**options, **kwargs} + # updates / expands on yaml options with optional user-provided input + + if "n_envs" in options: + env = make_vec_env( + options["env_id"], options["n_envs"], env_kwargs={"config": options["config"]} + ) + else: + env = gym.make(options["env_id"]) + ALGO = algorithm(options["algo"]) + model_id = options["algo"] + "-" + options["env_id"] + "-" + options["id"] + save_id = os.path.join(options["save_path"], model_id) + + model = ALGO( + "MlpPolicy", + env, + verbose=0, + tensorboard_log=options["tensorboard"], + use_sde=options["use_sde"], + ) + + progress_bar = options.get("progress_bar", False) + model.learn(total_timesteps=options["total_timesteps"], tb_log_name=model_id, progress_bar=progress_bar) + + os.makedirs(options["save_path"], exist_ok=True) + model.save(save_id) + print(f"Saved {options['algo']} model at {save_id}") + + return save_id, options \ No newline at end of file diff --git a/tests/test_caribou.py b/tests/test_caribou.py index 1c1a66b..6deb5da 100644 --- a/tests/test_caribou.py +++ b/tests/test_caribou.py @@ -1,7 +1,23 @@ # Confirm environment is correctly defined: from stable_baselines3.common.env_checker import check_env +import numpy as np from rl4caribou import Caribou +from rl4caribou.agents import constAction, constEsc def test_Caribou(): check_env(Caribou(), warn=True) +def test_constAction(): + ca1 = constAction(mortality_vec = [0,0]) + ca2 = constAction(mortality_vec = np.zeros(2)) + obs = np.zeros(3) + pr1, _ = ca1.predict(obs) + pr2, _ = ca2.predict(obs) + +def test_constEsc(): + ce1 = constEsc(escapement_vec = [0,0]) + ce2 = constEsc(escapement_vec = np.zeros(2)) + obs = np.zeros(3) + pr1, _ = ce1.predict(obs) + pr2, _ = ce2.predict(obs) +