diff --git a/scripts/train.py b/scripts/train.py index d0d7122..18393a5 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -6,13 +6,17 @@ import rl4caribou - +# training +# from rl4caribou.utils import sb3_train model_save_id, train_options = sb3_train(args.file) -from rl4caribou.utils import upload_to_hf -try: - upload_to_hf(args.file, "sb3/"+args.file, repo=train_options['repo']) - upload_to_hf(model_save_id, "sb3/"+model_save_id+".zip", repo=train_options['repo']) -except: - print("Couldn't upload to hf!") \ No newline at end of file +# hugging face +# +if 'repo' in train_options: + from rl4caribou.utils import upload_to_hf + try: + upload_to_hf(args.file, "sb3/"+args.file, repo=train_options['repo']) + upload_to_hf(model_save_id, "sb3/"+model_save_id+".zip", repo=train_options['repo']) + except: + print("Couldn't upload to hf!") \ No newline at end of file