diff --git a/pogema/wrappers/metrics.py b/pogema/wrappers/metrics.py index 409d1a1..984235a 100644 --- a/pogema/wrappers/metrics.py +++ b/pogema/wrappers/metrics.py @@ -1,3 +1,5 @@ +import time + import numpy as np from gymnasium import Wrapper @@ -150,3 +152,28 @@ def reset(self, **kwargs): observations, info = self.env.reset(**kwargs) self.count_agents(observations) return observations, info + + +class RuntimeMetricWrapper(Wrapper): + def __init__(self, env): + super().__init__(env) + self._start_time = None + self._env_step_time = None + + def step(self, actions): + env_step_start = time.monotonic() + observations, rewards, terminated, truncated, infos = self.env.step(actions) + env_step_end = time.monotonic() + self._env_step_time += env_step_end - env_step_start + if all(terminated) or all(truncated): + final_time = time.monotonic() - self._start_time - self._env_step_time + if 'metrics' not in infos[0]: + infos[0]['metrics'] = {} + infos[0]['metrics'].update(runtime=final_time) + return observations, rewards, terminated, truncated, infos + + def reset(self, **kwargs): + obs = self.env.reset(**kwargs) + self._start_time = time.monotonic() + self._env_step_time = 0.0 + return obs