Debugging infinite-length episodes
yannbouteiller committed May 15, 2024
1 parent 6a3808a commit 0af2469
Showing 3 changed files with 29 additions and 21 deletions.
36 changes: 20 additions & 16 deletions tmrl/networking.py
@@ -7,6 +7,7 @@
 import json
 import shutil
 import tempfile
+import itertools
 from os.path import exists
 
 # third-party imports
@@ -641,10 +642,12 @@ def collect_train_episode(self, max_samples=None):
         if max_samples is None:
             max_samples = self.max_samples_per_episode
 
+        iterator = range(max_samples) if max_samples != np.inf else itertools.count()
+
         ret = 0.0
         steps = 0
         obs, info = self.reset(collect_samples=True)
-        for i in range(max_samples):
+        for i in iterator:
             obs, rew, terminated, truncated, info = self.step(obs=obs, test=False, collect_samples=True, last_step=i == max_samples - 1)
             ret += rew
             steps += 1
@@ -665,10 +668,10 @@ def run_episodes(self, max_samples_per_episode=None, nb_episodes=np.inf, train=False):
         if max_samples_per_episode is None:
             max_samples_per_episode = self.max_samples_per_episode
 
-        counter = 0
-        while counter < nb_episodes:
+        iterator = range(nb_episodes) if nb_episodes != np.inf else itertools.count()
+
+        for _ in iterator:
             self.run_episode(max_samples_per_episode, train=train)
-            counter += 1
 
     def run_episode(self, max_samples=None, train=False):
         """
@@ -683,10 +686,12 @@ def run_episode(self, max_samples=None, train=False):
         if max_samples is None:
             max_samples = self.max_samples_per_episode
 
+        iterator = range(max_samples) if max_samples != np.inf else itertools.count()
+
         ret = 0.0
         steps = 0
         obs, info = self.reset(collect_samples=False)
-        for _ in range(max_samples):
+        for _ in iterator:
             obs, rew, terminated, truncated, info = self.step(obs=obs, test=not train, collect_samples=False)
             ret += rew
             steps += 1
@@ -711,39 +716,36 @@ def run(self, test_episode_interval=0, nb_episodes=np.inf, verbose=True, expert=False):
             expert (bool): experts send training samples without updating their model nor running test episodes.
         """
 
-        episode = 0
+        iterator = range(nb_episodes) if nb_episodes != np.inf else itertools.count()
 
         if expert:
             if not verbose:
-                while episode < nb_episodes:
+                for _ in iterator:
                     self.collect_train_episode(self.max_samples_per_episode)
                     self.send_and_clear_buffer()
                     self.ignore_actor_weights()
-                    episode += 1
             else:
-                while episode < nb_episodes:
+                for _ in iterator:
                     print_with_timestamp("collecting expert episode")
                     self.collect_train_episode(self.max_samples_per_episode)
                     print_with_timestamp("copying buffer for sending")
                     self.send_and_clear_buffer()
                     self.ignore_actor_weights()
-                    episode += 1
         elif not verbose:
             if not test_episode_interval:
-                while episode < nb_episodes:
+                for _ in iterator:
                     self.collect_train_episode(self.max_samples_per_episode)
                     self.send_and_clear_buffer()
                     self.update_actor_weights(verbose=False)
-                    episode += 1
             else:
-                while episode < nb_episodes:
+                for episode in iterator:
                     if episode % test_episode_interval == 0 and not self.crc_debug:
                         self.run_episode(self.max_samples_per_episode, train=False)
                     self.collect_train_episode(self.max_samples_per_episode)
                     self.send_and_clear_buffer()
                     self.update_actor_weights(verbose=False)
-                    episode += 1
         else:
-            while episode < nb_episodes:
+            for episode in iterator:
                 if test_episode_interval and episode % test_episode_interval == 0 and not self.crc_debug:
                     print_with_timestamp("running test episode")
                     self.run_episode(self.max_samples_per_episode, train=False)
@@ -753,7 +755,6 @@ def run(self, test_episode_interval=0, nb_episodes=np.inf, verbose=True, expert=False):
                 self.send_and_clear_buffer()
                 print_with_timestamp("checking for new weights")
                 self.update_actor_weights(verbose=True)
-                episode += 1
 
     def run_synchronous(self,
                         test_episode_interval=0,
@@ -908,6 +909,9 @@ def run_env_benchmark(self, nb_steps, test=False, verbose=True):
             test (int): whether the actor is called in test or train mode
             verbose (bool): whether to log INFO messages
         """
+        if nb_steps == np.inf or nb_steps < 0:
+            raise RuntimeError(f"Invalid number of steps: {nb_steps}")
+
         obs, info = self.reset(collect_samples=False)
         for _ in range(nb_steps):
             obs, rew, terminated, truncated, info = self.step(obs=obs, test=test, collect_samples=False)
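For reference, the change applied throughout networking.py is one recurring pattern: range() cannot take np.inf as a bound, so whenever the episode or step budget may be infinite, the loop now iterates over itertools.count() instead of a hand-maintained counter. Below is a minimal standalone sketch of that pattern; run_loop and its stop condition are illustrative stand-ins, not code from tmrl.

import itertools

import numpy as np


def run_loop(nb_episodes=np.inf):
    # range() only accepts integers, so fall back to an endless counter when the
    # budget is np.inf; the loop body then decides when to stop.
    iterator = range(nb_episodes) if nb_episodes != np.inf else itertools.count()
    for episode in iterator:
        print(f"episode {episode}")
        if episode >= 2:  # stand-in stop condition so this sketch terminates
            break


run_loop(3)       # bounded: iterates over range(3)
run_loop(np.inf)  # unbounded: iterates over itertools.count(), stopped by the break

run_env_benchmark, by contrast, now rejects np.inf (and negative values) outright, since a benchmark needs a finite number of steps.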
4 changes: 2 additions & 2 deletions tmrl/tuto/competition/competition_eval.py
@@ -11,7 +11,7 @@
 import tmrl.config.config_constants as cfg
 import tmrl.config.config_objects as cfg_obj
 
-from tmrl.tuto.competition.custom_actor_module import MyActorModule  # change this to match your ActorModule name
+from custom_actor_module import MyActorModule  # change this to match your ActorModule name
 
 
 # rtgym environment class (full TrackMania2020 Gymnasium environment with replays enabled):
@@ -23,7 +23,7 @@
 device_worker = 'cpu'
 
 try:
-    from tmrl.tuto.competition.custom_actor_module import obs_preprocessor
+    from custom_actor_module import obs_preprocessor
 except Exception as e:
     obs_preprocessor = cfg_obj.OBS_PREPROCESSOR
 
10 changes: 7 additions & 3 deletions tmrl/tuto/competition/custom_actor_module.py
@@ -851,9 +851,13 @@ def train(self, batch):
                              server_port=server_port,
                              password=password,
                              security=security)
-        my_trainer.run_with_wandb(entity=wandb_entity,
-                                  project=wandb_project,
-                                  run_id=wandb_run_id)
+        my_trainer.run()
+
+        # Note: if you want to log training metrics to wandb, replace my_trainer.run() with:
+        # my_trainer.run_with_wandb(entity=wandb_entity,
+        #                           project=wandb_project,
+        #                           run_id=wandb_run_id)
 
     elif args.worker or args.test:
         rw = RolloutWorker(env_cls=env_cls,
                            actor_module_cls=MyActorModule,
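The tutorial trainer now calls my_trainer.run() by default and keeps the wandb variant as a comment. If you would rather switch between the two at launch time than edit the script, a command-line flag is one option; the sketch below is only an illustration, with DummyTrainer, the --wandb flag, and the credential values made up rather than taken from the tutorial.

import argparse


class DummyTrainer:
    # Stand-in for the tmrl Trainer, used only so this sketch runs on its own.
    def run(self):
        print("training without wandb logging")

    def run_with_wandb(self, entity, project, run_id):
        print(f"training with wandb logging: {entity}/{project}/{run_id}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb", action="store_true", help="log training metrics to wandb")
    args = parser.parse_args()

    my_trainer = DummyTrainer()
    if args.wandb:
        my_trainer.run_with_wandb(entity="my_entity", project="my_project", run_id="run_0")
    else:
        my_trainer.run()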
