Skip to content

Commit

Permalink
Adding RuntimeMetricWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Tviskaron committed Jun 9, 2024
1 parent 745a937 commit 3925565
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions pogema/wrappers/metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import time

import numpy as np
from gymnasium import Wrapper

Expand Down Expand Up @@ -150,3 +152,28 @@ def reset(self, **kwargs):
observations, info = self.env.reset(**kwargs)
self.count_agents(observations)
return observations, info


class RuntimeMetricWrapper(Wrapper):
def __init__(self, env):
super().__init__(env)
self._start_time = None
self._env_step_time = None

def step(self, actions):
env_step_start = time.monotonic()
observations, rewards, terminated, truncated, infos = self.env.step(actions)
env_step_end = time.monotonic()
self._env_step_time += env_step_end - env_step_start
if all(terminated) or all(truncated):
final_time = time.monotonic() - self._start_time - self._env_step_time
if 'metrics' not in infos[0]:
infos[0]['metrics'] = {}
infos[0]['metrics'].update(runtime=final_time)
return observations, rewards, terminated, truncated, infos

def reset(self, **kwargs):
obs = self.env.reset(**kwargs)
self._start_time = time.monotonic()
self._env_step_time = 0.0
return obs

0 comments on commit 3925565

Please sign in to comment.