diff --git a/tmrl/networking.py b/tmrl/networking.py
index 8060107..beba8b1 100644
--- a/tmrl/networking.py
+++ b/tmrl/networking.py
@@ -7,6 +7,7 @@
 import json
 import shutil
 import tempfile
+import itertools
 from os.path import exists
 
 # third-party imports
@@ -641,10 +642,12 @@ def collect_train_episode(self, max_samples=None):
         """
         if max_samples is None:
             max_samples = self.max_samples_per_episode
+        iterator = range(max_samples) if max_samples != np.inf else itertools.count()
+
         ret = 0.0
         steps = 0
         obs, info = self.reset(collect_samples=True)
-        for i in range(max_samples):
+        for i in iterator:
             obs, rew, terminated, truncated, info = self.step(obs=obs, test=False, collect_samples=True, last_step=i == max_samples - 1)
             ret += rew
             steps += 1
@@ -665,10 +668,10 @@ def run_episodes(self, max_samples_per_episode=None, nb_episodes=np.inf, train=F
         """
         if max_samples_per_episode is None:
             max_samples_per_episode = self.max_samples_per_episode
-        counter = 0
-        while counter < nb_episodes:
+        iterator = range(nb_episodes) if nb_episodes != np.inf else itertools.count()
+
+        for _ in iterator:
             self.run_episode(max_samples_per_episode, train=train)
-            counter += 1
 
     def run_episode(self, max_samples=None, train=False):
         """
@@ -683,10 +686,12 @@ def run_episode(self, max_samples=None, train=False):
         """
         if max_samples is None:
             max_samples = self.max_samples_per_episode
+        iterator = range(max_samples) if max_samples != np.inf else itertools.count()
+
         ret = 0.0
         steps = 0
         obs, info = self.reset(collect_samples=False)
-        for _ in range(max_samples):
+        for _ in iterator:
             obs, rew, terminated, truncated, info = self.step(obs=obs, test=not train, collect_samples=False)
             ret += rew
             steps += 1
@@ -711,39 +716,36 @@ def run(self, test_episode_interval=0, nb_episodes=np.inf, verbose=True, expert=
             verbose (bool): whether to log INFO messages
             expert (bool): experts send training samples without updating their model nor running test episodes.
         """
-        episode = 0
+        iterator = range(nb_episodes) if nb_episodes != np.inf else itertools.count()
+
         if expert:
             if not verbose:
-                while episode < nb_episodes:
+                for _ in iterator:
                     self.collect_train_episode(self.max_samples_per_episode)
                     self.send_and_clear_buffer()
                     self.ignore_actor_weights()
-                    episode += 1
             else:
-                while episode < nb_episodes:
+                for _ in iterator:
                     print_with_timestamp("collecting expert episode")
                     self.collect_train_episode(self.max_samples_per_episode)
                     print_with_timestamp("copying buffer for sending")
                     self.send_and_clear_buffer()
                     self.ignore_actor_weights()
-                    episode += 1
         elif not verbose:
             if not test_episode_interval:
-                while episode < nb_episodes:
+                for _ in iterator:
                     self.collect_train_episode(self.max_samples_per_episode)
                     self.send_and_clear_buffer()
                     self.update_actor_weights(verbose=False)
-                    episode += 1
             else:
-                while episode < nb_episodes:
+                for episode in iterator:
                     if episode % test_episode_interval == 0 and not self.crc_debug:
                         self.run_episode(self.max_samples_per_episode, train=False)
                     self.collect_train_episode(self.max_samples_per_episode)
                     self.send_and_clear_buffer()
                     self.update_actor_weights(verbose=False)
-                    episode += 1
         else:
-            while episode < nb_episodes:
+            for episode in iterator:
                 if test_episode_interval and episode % test_episode_interval == 0 and not self.crc_debug:
                     print_with_timestamp("running test episode")
                     self.run_episode(self.max_samples_per_episode, train=False)
@@ -753,7 +755,6 @@ def run(self, test_episode_interval=0, nb_episodes=np.inf, verbose=True, expert=
                 self.send_and_clear_buffer()
                 print_with_timestamp("checking for new weights")
                 self.update_actor_weights(verbose=True)
-                episode += 1
 
     def run_synchronous(self,
                         test_episode_interval=0,
@@ -908,6 +909,9 @@ def run_env_benchmark(self, nb_steps, test=False, verbose=True):
             test (int): whether the actor is called in test or train mode
             verbose (bool): whether to log INFO messages
         """
+        if nb_steps == np.inf or nb_steps < 0:
+            raise RuntimeError(f"Invalid number of steps: {nb_steps}")
+
         obs, info = self.reset(collect_samples=False)
         for _ in range(nb_steps):
             obs, rew, terminated, truncated, info = self.step(obs=obs, test=test, collect_samples=False)
diff --git a/tmrl/tuto/competition/competition_eval.py b/tmrl/tuto/competition/competition_eval.py
index f5c9dca..c08cbcb 100644
--- a/tmrl/tuto/competition/competition_eval.py
+++ b/tmrl/tuto/competition/competition_eval.py
@@ -11,7 +11,7 @@
 import tmrl.config.config_constants as cfg
 import tmrl.config.config_objects as cfg_obj
 
-from tmrl.tuto.competition.custom_actor_module import MyActorModule  # change this to match your ActorModule name
+from custom_actor_module import MyActorModule  # change this to match your ActorModule name
 
 
 # rtgym environment class (full TrackMania2020 Gymnasium environment with replays enabled):
@@ -23,7 +23,7 @@
 device_worker = 'cpu'
 
 try:
-    from tmrl.tuto.competition.custom_actor_module import obs_preprocessor
+    from custom_actor_module import obs_preprocessor
 except Exception as e:
     obs_preprocessor = cfg_obj.OBS_PREPROCESSOR
 
diff --git a/tmrl/tuto/competition/custom_actor_module.py b/tmrl/tuto/competition/custom_actor_module.py
index cd73421..710182d 100644
--- a/tmrl/tuto/competition/custom_actor_module.py
+++ b/tmrl/tuto/competition/custom_actor_module.py
@@ -851,9 +851,13 @@ def train(self, batch):
                              server_port=server_port,
                              password=password,
                              security=security)
-        my_trainer.run_with_wandb(entity=wandb_entity,
-                                  project=wandb_project,
-                                  run_id=wandb_run_id)
+        my_trainer.run()
+
+        # Note: if you want to log training metrics to wandb, replace my_trainer.run() with:
+        # my_trainer.run_with_wandb(entity=wandb_entity,
+        #                           project=wandb_project,
+        #                           run_id=wandb_run_id)
+
     elif args.worker or args.test:
         rw = RolloutWorker(env_cls=env_cls,
                            actor_module_cls=MyActorModule,
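
A quick illustration of the looping pattern this diff applies throughout networking.py: range() cannot take np.inf, so each loop now builds its iterator up front and falls back to itertools.count() when the bound is infinite. The episode_iterator helper below is purely illustrative (it is not part of tmrl); it is a minimal sketch of the idea:

import itertools

import numpy as np


def episode_iterator(nb_episodes):
    # Illustrative helper, not part of tmrl.
    # Bounded case: a plain range over the requested number of episodes.
    # Unbounded case (nb_episodes == np.inf): count forever; the caller must break out itself.
    return range(nb_episodes) if nb_episodes != np.inf else itertools.count()


# Bounded: behaves exactly like range(3).
print(list(episode_iterator(3)))  # [0, 1, 2]

# Unbounded: loops until the caller breaks.
for episode in episode_iterator(np.inf):
    if episode >= 5:
        break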