Skip to content

Commit

Permalink
Updated competition script
Browse files Browse the repository at this point in the history
  • Loading branch information
yannbouteiller committed May 15, 2024
1 parent 0af2469 commit 0557136
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tmrl/tuto/competition/competition_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@
device=device_worker,
obs_preprocessor=obs_preprocessor,
standalone=True)
rw.run()
rw.run_episodes()
12 changes: 7 additions & 5 deletions tmrl/tuto/competition/custom_actor_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
"PORT": <port of the server (usually requires port forwarding if accessed via the Internet)>,
If you are training over the Internet, please read the security instructions on the TMRL GitHub page.
IMPORTANT: Set a custom 'RUN_NAME' in config.json, otherwise this script will not work.
"""

# Let us start our tutorial by importing some useful stuff.
Expand Down Expand Up @@ -795,11 +797,11 @@ def train(self, batch):

training_agent_cls = partial(SACTrainingAgent,
model_cls=VanillaCNNActorCritic,
gamma=0.99,
gamma=0.995,
polyak=0.995,
alpha=0.02,
lr_actor=0.000005,
lr_critic=0.00003)
alpha=0.01,
lr_actor=0.00001,
lr_critic=0.00005)


# =====================================================================
Expand Down Expand Up @@ -870,7 +872,7 @@ def train(self, batch):
max_samples_per_episode=max_samples_per_episode,
obs_preprocessor=obs_preprocessor,
standalone=args.test)
rw.run()
rw.run(test_episode_interval=10)
elif args.server:
import time
serv = Server(port=server_port,
Expand Down

0 comments on commit 0557136

Please sign in to comment.