diff --git a/pogema/__init__.py b/pogema/__init__.py
index 1a956e5..26ff7e5 100644
--- a/pogema/__init__.py
+++ b/pogema/__init__.py
@@ -1,6 +1,7 @@
from gymnasium import register
from pogema.grid_config import GridConfig
from pogema.integrations.make_pogema import pogema_v0
+from pogema.svg_animation.animation_wrapper import AnimationMonitor, AnimationConfig
from pogema.a_star_policy import AStarAgent, BatchAStarAgent
from pogema.grid_config import Easy8x8, Normal8x8, Hard8x8, ExtraHard8x8
@@ -8,12 +9,14 @@
from pogema.grid_config import Easy32x32, Normal32x32, Hard32x32, ExtraHard32x32
from pogema.grid_config import Easy64x64, Normal64x64, Hard64x64, ExtraHard64x64
-__version__ = '1.2.2'
+__version__ = '1.2.3a2'
__all__ = [
'GridConfig',
'pogema_v0',
'AStarAgent', 'BatchAStarAgent',
+ "AnimationMonitor", "AnimationConfig",
+
'Easy8x8', 'Normal8x8', 'Hard8x8', 'ExtraHard8x8',
'Easy16x16', 'Normal16x16', 'Hard16x16', 'ExtraHard16x16',
'Easy32x32', 'Normal32x32', 'Hard32x32', 'ExtraHard32x32',
diff --git a/pogema/animation.py b/pogema/animation.py
deleted file mode 100644
index c992881..0000000
--- a/pogema/animation.py
+++ /dev/null
@@ -1,720 +0,0 @@
-import os
-import typing
-from itertools import cycle
-from gymnasium import logger, Wrapper
-
-from pydantic import BaseModel
-
-from pogema import GridConfig, pogema_v0
-from pogema.grid import Grid
-from pogema.wrappers.persistence import PersistentWrapper, AgentState
-
-
-class AnimationSettings(BaseModel):
- """
- Settings for the animation.
- """
- r: int = 35
- stroke_width: int = 10
- scale_size: int = 100
- time_scale: float = 0.28
- draw_start: int = 100
- rx: int = 15
-
- obstacle_color: str = '#84A1AE'
- ego_color: str = '#c1433c'
- ego_other_color: str = '#72D5C8'
- shaded_opacity: float = 0.2
- egocentric_shaded: bool = True
- stroke_dasharray: int = 25
-
- colors: list = [
- '#c1433c',
- '#2e6f9e',
- '#6e81af',
- '#00b9c8',
- '#72D5C8',
- '#0ea08c',
- '#8F7B66',
- ]
-
-
-class AnimationConfig(BaseModel):
- """
- Configuration for the animation.
- """
- directory: str = 'renders/'
- static: bool = False
- show_agents: bool = True
- egocentric_idx: typing.Optional[int] = None
- uid: typing.Optional[str] = None
- save_every_idx_episode: typing.Optional[int] = 1
- show_border: bool = True
- show_lines: bool = False
-
-
-class GridHolder(BaseModel):
- """
- Holds the grid and the history.
- """
- obstacles: typing.Any = None
- episode_length: int = None
- height: int = None
- width: int = None
- colors: dict = None
- history: list = None
-
-
-class SvgObject:
- """
- Main class for the SVG.
- """
- tag = None
-
- def __init__(self, **kwargs):
- self.attributes = kwargs
- self.animations = []
-
- def add_animation(self, animation):
- self.animations.append(animation)
-
- @staticmethod
- def render_attributes(attributes):
- result = " ".join([f'{x.replace("_", "-")}="{y}"' for x, y in sorted(attributes.items())])
- return result
-
- def render(self):
- animations = '\n'.join([a.render() for a in self.animations]) if self.animations else None
- if animations:
- return f"<{self.tag} {self.render_attributes(self.attributes)}> {animations} {self.tag}>"
- return f"<{self.tag} {self.render_attributes(self.attributes)} />"
-
-
-class Rectangle(SvgObject):
- """
- Rectangle class for the SVG.
- """
- tag = 'rect'
-
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self.attributes['y'] = -self.attributes['y'] - self.attributes['height']
-
-
-class Circle(SvgObject):
- """
- Circle class for the SVG.
- """
- tag = 'circle'
-
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self.attributes['cy'] = -self.attributes['cy']
-
-
-class Line(SvgObject):
- """
- Line class for the SVG.
- """
- tag = 'line'
-
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self.attributes['y1'] = -self.attributes['y1']
- self.attributes['y2'] = -self.attributes['y2']
-
-
-class Animation(SvgObject):
- """
- Animation class for the SVG.
- """
- tag = 'animate'
-
- def render(self):
- return f"<{self.tag} {self.render_attributes(self.attributes)}/>"
-
-
-class Drawing:
- """
- Drawing, analog of the DrawSvg class in the pogema package.
- """
-
- def __init__(self, height, width, display_inline=False, origin=(0, 0)):
- self.height = height
- self.width = width
- self.display_inline = display_inline
- self.origin = origin
- self.elements = []
-
- def add_element(self, element):
- self.elements.append(element)
-
- def render(self):
- view_box = (0, -self.height, self.width, self.height)
- results = [f'''
- ')
- return "\n".join(results)
-
-
-class AnimationMonitor(Wrapper):
- """
- Defines the animation, which saves the episode as SVG.
- """
-
- def __init__(self, env, animation_config=AnimationConfig()):
- # Wrapping env using PersistenceWrapper for saving the history.
- env = PersistentWrapper(env)
- super().__init__(env)
-
- self.history = self.env.get_history()
-
- self.svg_settings: AnimationSettings = AnimationSettings()
- self.animation_config: AnimationConfig = animation_config
-
- self._episode_idx = 0
-
- def step(self, action):
- """
- Saves information about the episode.
- :param action: current actions
- :return: obs, reward, done, info
- """
- obs, reward, terminated, truncated, info = self.env.step(action)
-
- multi_agent_terminated = isinstance(terminated, (list, tuple)) and all(terminated)
- single_agent_terminated = isinstance(terminated, (bool, int)) and terminated
- multi_agent_truncated = isinstance(truncated, (list, tuple)) and all(truncated)
- single_agent_truncated = isinstance(truncated, (bool, int)) and truncated
-
- if multi_agent_terminated or single_agent_terminated or multi_agent_truncated or single_agent_truncated:
- save_tau = self.animation_config.save_every_idx_episode
- if save_tau:
- if (self._episode_idx + 1) % save_tau or save_tau == 1:
- if not os.path.exists(self.animation_config.directory):
- logger.info(f"Creating pogema monitor directory {self.animation_config.directory}", )
- os.makedirs(self.animation_config.directory, exist_ok=True)
-
- path = os.path.join(self.animation_config.directory,
- self.pick_name(self.grid_config, self._episode_idx))
- self.save_animation(path)
-
- return obs, reward, terminated, truncated, info
-
- @staticmethod
- def pick_name(grid_config: GridConfig, episode_idx=None, zfill_ep=5):
- """
- Picks a name for the SVG file.
- :param grid_config: configuration of the grid
- :param episode_idx: idx of the episode
- :param zfill_ep: zfill for the episode number
- :return:
- """
- gc = grid_config
- name = 'pogema'
- if episode_idx is not None:
- name += f'-ep{str(episode_idx).zfill(zfill_ep)}'
- if gc:
- if gc.map_name:
- name += f'-{gc.map_name}'
- if gc.seed is not None:
- name += f'-seed{gc.seed}'
- else:
- name += '-render'
- return name + '.svg'
-
- def reset(self, **kwargs):
- """
- Resets the environment and resets the current positions of agents and targets
- :param kwargs:
- :return: obs: observation
- """
- obs = self.env.reset(**kwargs)
-
- self._episode_idx += 1
- self.history = self.env.get_history()
-
- return obs
-
- def create_animation(self, animation_config=None):
- """
- Creates the animation.
- :param animation_config: configuration of the animation
- :return: drawing: drawing object
- """
- anim_cfg = animation_config
- if anim_cfg is None:
- anim_cfg = self.animation_config
-
- grid: Grid = self.grid
- cfg = self.svg_settings
- colors = cycle(cfg.colors)
- agents_colors = {index: next(colors) for index in range(self.grid_config.num_agents)}
-
- if anim_cfg.egocentric_idx is not None:
- anim_cfg.egocentric_idx %= self.grid_config.num_agents
-
- decompressed_history: list[list[AgentState]] = self.env.decompress_history(self.history)
-
- # Change episode length for egocentric environment
- if anim_cfg.egocentric_idx is not None:
- episode_length = decompressed_history[anim_cfg.egocentric_idx][-1].step + 1
- for agent_idx in range(self.grid_config.num_agents):
- decompressed_history[agent_idx] = decompressed_history[agent_idx][:episode_length]
- else:
- episode_length = len(decompressed_history[0])
-
- # Add last observation one more time to highlight the final state
- for agent_idx in range(self.grid_config.num_agents):
- decompressed_history[agent_idx].append(decompressed_history[agent_idx][-1])
-
- # Change episode length for static environment
- if anim_cfg.static:
- episode_length = 1
- decompressed_history = [[decompressed_history[idx][-1]] for idx in range(len(decompressed_history))]
-
- gh = GridHolder(width=len(grid.obstacles), height=len(grid.obstacles[0]),
- obstacles=grid.obstacles,
- colors=agents_colors,
- episode_length=episode_length,
- history=decompressed_history, )
-
- render_width, render_height = gh.height * cfg.scale_size + cfg.scale_size, gh.width * cfg.scale_size + cfg.scale_size
-
- drawing = Drawing(width=render_width, height=render_height, display_inline=False, origin=(0, 0))
- obstacles = self.create_obstacles(gh, anim_cfg)
-
- agents = []
- targets = []
-
- if anim_cfg.show_agents:
- agents = self.create_agents(gh, anim_cfg)
- targets = self.create_targets(gh, anim_cfg)
-
- if not anim_cfg.static:
- self.animate_agents(agents, anim_cfg.egocentric_idx, gh)
- self.animate_targets(targets, gh, anim_cfg)
- if anim_cfg.show_lines:
- grid_lines = self.create_grid_lines(gh, anim_cfg, render_width, render_height)
- for line in grid_lines:
- drawing.add_element(line)
- for obj in [*obstacles, *agents, *targets, ]:
- drawing.add_element(obj)
-
- if anim_cfg.egocentric_idx is not None:
- field_of_view = self.create_field_of_view(grid_holder=gh, animation_config=anim_cfg)
- if not anim_cfg.static:
- self.animate_obstacles(obstacles=obstacles, grid_holder=gh, animation_config=anim_cfg)
- self.animate_field_of_view(field_of_view, anim_cfg.egocentric_idx, gh)
- drawing.add_element(field_of_view)
-
- return drawing
-
- def create_grid_lines(self, grid_holder: GridHolder, animation_config: AnimationConfig, render_width,
- render_height):
- """
- Creates the grid lines.
- :param grid_holder: grid holder
- :param animation_config: animation configuration
- :return: grid_lines: list of grid lines
- """
- cfg = self.svg_settings
- grid_lines = []
- for i in range(-1, grid_holder.height + 1):
- # vertical lines
- x0 = x1 = i * cfg.scale_size + cfg.scale_size / 2
- y0 = 0
- y1 = render_height
- grid_lines.append(
- Line(x1=x0, y1=y0, x2=x1, y2=y1, stroke=cfg.obstacle_color, stroke_width=cfg.stroke_width // 1.5))
- for i in range(-1, grid_holder.width + 1):
- # continue
- # horizontal lines
- x0 = 0
- y0 = y1 = i * cfg.scale_size + cfg.scale_size / 2
- x1 = render_width
- grid_lines.append(
- Line(x1=x0, y1=y0, x2=x1, y2=y1, stroke=cfg.obstacle_color, stroke_width=cfg.stroke_width // 1.5))
-
- # for i in range(grid_holder.width):
- # grid_lines.append(Line(start=(0, i * cfg.scale_size),
- # end=(grid_holder.height * cfg.scale_size, i * cfg.scale_size),
- # stroke=cfg.grid_color, stroke_width=cfg.grid_width))
- return grid_lines
-
- def save_animation(self, name='render.svg', animation_config: typing.Optional[AnimationConfig] = None):
- """
- Saves the animation.
- :param name: name of the file
- :param animation_config: animation configuration
- :return: None
- """
- animation = self.create_animation(animation_config)
- with open(name, "w") as f:
- f.write(animation.render())
-
- @staticmethod
- def fix_point(x, y, length):
- """
- Fixes the point to the grid.
- :param x: coordinate x
- :param y: coordinate y
- :param length: size of the grid
- :return: x, y: fixed coordinates
- """
- return length - y - 1, x
-
- @staticmethod
- def check_in_radius(x1, y1, x2, y2, r) -> bool:
- """
- Checks if the point is in the radius.
- :param x1: coordinate x1
- :param y1: coordinate y1
- :param x2: coordinate x2
- :param y2: coordinate y2
- :param r: radius
- :return:
- """
- return x2 - r <= x1 <= x2 + r and y2 - r <= y1 <= y2 + r
-
- def create_field_of_view(self, grid_holder, animation_config):
- """
- Creates the field of view for the egocentric agent.
- :param grid_holder:
- :param animation_config:
- :return:
- """
- cfg = self.svg_settings
- gh: GridHolder = grid_holder
- ego_idx = animation_config.egocentric_idx
- x, y = gh.history[ego_idx][0].get_xy()
- cx = cfg.draw_start + y * cfg.scale_size
- cy = cfg.draw_start + (gh.width - x - 1) * cfg.scale_size
-
- dr = (self.grid_config.obs_radius + 1) * cfg.scale_size - cfg.stroke_width * 2
- result = Rectangle(x=cx - dr + cfg.r, y=cy - dr + cfg.r,
- width=2 * dr - 2 * cfg.r, height=2 * dr - 2 * cfg.r,
- stroke=cfg.ego_color, stroke_width=cfg.stroke_width,
- fill='none',
- rx=cfg.rx, stroke_dasharray=cfg.stroke_dasharray,
- )
-
- return result
-
- def animate_field_of_view(self, view, agent_idx, grid_holder):
- """
- Animates the field of view.
- :param view:
- :param agent_idx:
- :param grid_holder:
- :return:
- """
- gh: GridHolder = grid_holder
- cfg = self.svg_settings
- x_path = []
- y_path = []
- for state in gh.history[agent_idx]:
- x, y = state.get_xy()
- dr = (self.grid_config.obs_radius + 1) * cfg.scale_size - cfg.stroke_width * 2
- cx = cfg.draw_start + y * cfg.scale_size
- cy = -cfg.draw_start + -(gh.width - x - 1) * cfg.scale_size
- x_path.append(str(cx - dr + cfg.r))
- y_path.append(str(cy - dr + cfg.r))
-
- visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]]
-
- view.add_animation(self.compressed_anim('x', x_path, cfg.time_scale))
- view.add_animation(self.compressed_anim('y', y_path, cfg.time_scale))
- view.add_animation(self.compressed_anim('visibility', visibility, cfg.time_scale))
-
- def animate_agents(self, agents, egocentric_idx, grid_holder):
- """
- Animates the agents.
- :param agents:
- :param egocentric_idx:
- :param grid_holder:
- :return:
- """
- gh: GridHolder = grid_holder
- cfg = self.svg_settings
- for agent_idx, agent in enumerate(agents):
- x_path = []
- y_path = []
- opacity = []
- for agent_state in gh.history[agent_idx]:
- x, y = agent_state.get_xy()
-
- x_path.append(str(cfg.draw_start + y * cfg.scale_size))
- y_path.append(str(-cfg.draw_start + -(gh.width - x - 1) * cfg.scale_size))
-
- if egocentric_idx is not None:
- ego_x, ego_y = agent_state.get_xy()
- if self.check_in_radius(x, y, ego_x, ego_y, self.grid_config.obs_radius):
- opacity.append('1.0')
- else:
- opacity.append(str(cfg.shaded_opacity))
-
- visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]]
-
- agent.add_animation(self.compressed_anim('cy', y_path, cfg.time_scale))
- agent.add_animation(self.compressed_anim('cx', x_path, cfg.time_scale))
- agent.add_animation(self.compressed_anim('visibility', visibility, cfg.time_scale))
- if opacity:
- agent.add_animation(self.compressed_anim('opacity', opacity, cfg.time_scale))
-
- @classmethod
- def compressed_anim(cls, attr_name, tokens, time_scale, rep_cnt='indefinite'):
- """
- Compresses the animation.
- :param attr_name:
- :param tokens:
- :param time_scale:
- :param rep_cnt:
- :return:
- """
- tokens, times = cls.compress_tokens(tokens)
- cumulative = [0, ]
- for t in times:
- cumulative.append(cumulative[-1] + t)
- times = [str(round(value / cumulative[-1], 10)) for value in cumulative]
- tokens = [tokens[0]] + tokens
-
- times = times
- tokens = tokens
- return Animation(attributeName=attr_name,
- dur=f'{time_scale * (-1 + cumulative[-1])}s',
- values=";".join(tokens),
- repeatCount=rep_cnt,
- keyTimes=";".join(times))
-
- @staticmethod
- def wisely_add(token, cnt, tokens, times):
- """
- Adds the token to the tokens and times.
- :param token:
- :param cnt:
- :param tokens:
- :param times:
- :return:
- """
- if cnt > 1:
- tokens += [token, token]
- times += [1, cnt - 1]
- else:
- tokens.append(token)
- times.append(cnt)
-
- @classmethod
- def compress_tokens(cls, input_tokens: list):
- """
- Compresses the tokens.
- :param input_tokens:
- :return:
- """
- tokens = []
- times = []
- if input_tokens:
- cur_idx = 0
- cnt = 1
- for idx in range(1, len(input_tokens)):
- if input_tokens[idx] == input_tokens[cur_idx]:
- cnt += 1
- else:
- cls.wisely_add(input_tokens[cur_idx], cnt, tokens, times)
- cnt = 1
- cur_idx = idx
- cls.wisely_add(input_tokens[cur_idx], cnt, tokens, times)
- return tokens, times
-
- def animate_targets(self, targets, grid_holder, animation_config):
- """
- Animates the targets.
- :param targets:
- :param grid_holder:
- :param animation_config:
- :return:
- """
- gh: GridHolder = grid_holder
- cfg = self.svg_settings
- ego_idx = animation_config.egocentric_idx
-
- for agent_idx, target in enumerate(targets):
- target_idx = ego_idx if ego_idx is not None else agent_idx
-
- x_path = []
- y_path = []
-
- for step_idx, state in enumerate(gh.history[target_idx]):
- x, y = state.get_target_xy()
- x_path.append(str(cfg.draw_start + y * cfg.scale_size))
- y_path.append(str(-cfg.draw_start + -(gh.width - x - 1) * cfg.scale_size))
-
- visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]]
-
- if self.grid_config.on_target == 'restart':
- target.add_animation(self.compressed_anim('cy', y_path, cfg.time_scale))
- target.add_animation(self.compressed_anim('cx', x_path, cfg.time_scale))
- target.add_animation(self.compressed_anim("visibility", visibility, cfg.time_scale))
-
- def create_obstacles(self, grid_holder, animation_config):
- """
- Creates the obstacles.
- :param grid_holder:
- :param animation_config:
- :return:
- """
- gh = grid_holder
- cfg = self.svg_settings
-
- result = []
- r = self.grid_config.obs_radius
- for i in range(gh.height):
- for j in range(gh.width):
- x, y = self.fix_point(i, j, gh.width)
- if not animation_config.show_border:
- if i == r - 1 or j == r - 1 or j == gh.width - r or i == gh.height - r:
- continue
- if gh.obstacles[x][y] != self.grid_config.FREE:
- obs_settings = {}
- obs_settings.update(x=cfg.draw_start + i * cfg.scale_size - cfg.r,
- y=cfg.draw_start + j * cfg.scale_size - cfg.r,
- width=cfg.r * 2,
- height=cfg.r * 2,
- rx=cfg.rx,
- fill=self.svg_settings.obstacle_color)
-
- if animation_config.egocentric_idx is not None and cfg.egocentric_shaded:
- initial_positions = [agent_states[0].get_xy() for agent_states in gh.history]
- ego_x, ego_y = initial_positions[animation_config.egocentric_idx]
- if not self.check_in_radius(x, y, ego_x, ego_y, self.grid_config.obs_radius):
- obs_settings.update(opacity=cfg.shaded_opacity)
-
- result.append(Rectangle(**obs_settings))
-
- return result
-
- def animate_obstacles(self, obstacles, grid_holder, animation_config):
- """
-
- :param obstacles:
- :param grid_holder:
- :param animation_config:
- :return:
- """
- gh: GridHolder = grid_holder
- obstacle_idx = 0
- cfg = self.svg_settings
-
- for i in range(gh.height):
- for j in range(gh.width):
- x, y = self.fix_point(i, j, gh.width)
- if gh.obstacles[x][y] == self.grid_config.FREE:
- continue
- opacity = []
- seen = set()
- for step_idx, agent_state in enumerate(gh.history[animation_config.egocentric_idx]):
- ego_x, ego_y = agent_state.get_xy()
- if self.check_in_radius(x, y, ego_x, ego_y, self.grid_config.obs_radius):
- seen.add((x, y))
- if (x, y) in seen:
- opacity.append(str(1.0))
- else:
- opacity.append(str(cfg.shaded_opacity))
-
- obstacle = obstacles[obstacle_idx]
- obstacle.add_animation(self.compressed_anim('opacity', opacity, cfg.time_scale))
-
- obstacle_idx += 1
-
- def create_agents(self, grid_holder, animation_config):
- """
- Creates the agents.
- :param grid_holder:
- :param animation_config:
- :return:
- """
- gh: GridHolder = grid_holder
- cfg = self.svg_settings
-
- agents = []
- initial_positions = [agent_states[0].get_xy() for agent_states in gh.history]
- for idx, (x, y) in enumerate(initial_positions):
-
- if not any([agent_state.is_active() for agent_state in gh.history[idx]]):
- continue
-
- circle_settings = {}
- circle_settings.update(cx=cfg.draw_start + y * cfg.scale_size,
- cy=cfg.draw_start + (gh.width - x - 1) * cfg.scale_size,
- r=cfg.r, fill=gh.colors[idx])
- ego_idx = animation_config.egocentric_idx
- if ego_idx is not None:
- ego_x, ego_y = initial_positions[ego_idx]
- if not self.check_in_radius(x, y, ego_x, ego_y, self.grid_config.obs_radius) and cfg.egocentric_shaded:
- circle_settings.update(opacity=cfg.shaded_opacity)
- if ego_idx == idx:
- circle_settings.update(fill=self.svg_settings.ego_color)
- else:
- circle_settings.update(fill=self.svg_settings.ego_other_color)
- agent = Circle(**circle_settings)
- agents.append(agent)
-
- return agents
-
- def create_targets(self, grid_holder, animation_config):
- """
- Creates the targets.
- :param grid_holder:
- :param animation_config:
- :return:
- """
- gh: GridHolder = grid_holder
- cfg = self.svg_settings
- targets = []
- for agent_idx, agent_states in enumerate(gh.history):
-
- tx, ty = agent_states[0].get_target_xy()
- x, y = ty, gh.width - tx - 1
-
- if not any([agent_state.is_active() for agent_state in gh.history[agent_idx]]):
- continue
-
- circle_settings = {}
- circle_settings.update(cx=cfg.draw_start + x * cfg.scale_size,
- cy=cfg.draw_start + y * cfg.scale_size,
- r=cfg.r,
- stroke=gh.colors[agent_idx], stroke_width=cfg.stroke_width, fill='none')
- if animation_config.egocentric_idx is not None:
- if animation_config.egocentric_idx != agent_idx:
- continue
-
- circle_settings.update(stroke=cfg.ego_color)
- target = Circle(**circle_settings)
- targets.append(target)
- return targets
-
-
-def main():
- grid_config = GridConfig(size=8, num_agents=5, obs_radius=2, seed=9, on_target='finish', max_episode_steps=128)
- env = pogema_v0(grid_config=grid_config)
- env = AnimationMonitor(env)
-
- env.reset()
- done = [False]
-
- while not all(done):
- _, _, done, _ = env.step(env.sample_actions())
-
- env.save_animation('out-static.svg', AnimationConfig(static=True, save_every_idx_episode=None))
- env.save_animation('out-static-ego.svg', AnimationConfig(egocentric_idx=0, static=True))
- env.save_animation('out-static-no-agents.svg', AnimationConfig(show_agents=False, static=True))
- env.save_animation("out.svg")
- env.save_animation("out-ego.svg", AnimationConfig(egocentric_idx=0))
-
-
-if __name__ == '__main__':
- main()
diff --git a/pogema/svg_animation/__init__.py b/pogema/svg_animation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pogema/svg_animation/animation_drawer.py b/pogema/svg_animation/animation_drawer.py
new file mode 100644
index 0000000..46d0db5
--- /dev/null
+++ b/pogema/svg_animation/animation_drawer.py
@@ -0,0 +1,395 @@
+import math
+import typing
+from dataclasses import dataclass
+
+from pogema import GridConfig
+from pogema.svg_animation.svg_objects import Line, RectangleHref, Animation, Circle, Rectangle
+
+
+@dataclass
+class AnimationConfig:
+ directory: str = 'renders/'
+ static: bool = False
+ show_agents: bool = True
+ egocentric_idx: typing.Optional[int] = None
+ uid: typing.Optional[str] = None
+ save_every_idx_episode: typing.Optional[int] = 1
+ show_grid_lines: bool = True
+
+
+@dataclass
+class SvgSettings:
+ r: int = 35
+ stroke_width: int = 10
+ scale_size: int = 100
+ time_scale: float = 0.25
+ draw_start: int = 100
+ rx: int = 15
+
+ obstacle_color: str = '#84A1AE'
+ ego_color: str = '#c1433c'
+ ego_other_color: str = '#6e81af'
+ shaded_opacity: float = 0.2
+ egocentric_shaded: bool = True
+ stroke_dasharray: int = 25
+
+ colors: tuple = (
+ '#c1433c',
+ '#2e6f9e',
+ '#6e81af',
+ '#00b9c8',
+ '#72D5C8',
+ '#0ea08c',
+ '#8F7B66',
+ )
+
+
+@dataclass
+class GridHolder:
+ obstacles: typing.Any = None
+ episode_length: int = None
+ height: int = None
+ width: int = None
+ colors: dict = None
+ history: list = None
+ obs_radius: int = None
+ grid_config: GridConfig = None
+ on_target: str = None
+ config: AnimationConfig = None
+ svg_settings: SvgSettings = None
+
+
+class Drawing:
+
+ def __init__(self, height, width, svg_settings):
+ self.height = height
+ self.width = width
+ self.origin = (0, 0)
+ self.elements = []
+ self.svg_settings = svg_settings
+
+ def add_element(self, element):
+ self.elements.append(element)
+
+ def render(self):
+ scale = max(self.height, self.width) / 512
+ scaled_width = math.ceil(self.width / scale)
+ scaled_height = math.ceil(self.height / scale)
+
+ dx, dy = self.origin
+ view_box = (dx, dy - self.height, self.width, self.height)
+
+ svg_header = f'''
+ ')
+ return "\n".join(elements_svg)
+
+
+class AnimationDrawer:
+
+ def __init__(self):
+ pass
+
+ def create_animation(self, grid_holder: GridHolder):
+ gh = grid_holder
+ render_width = gh.height * gh.svg_settings.scale_size + gh.svg_settings.scale_size
+ render_height = gh.width * gh.svg_settings.scale_size + gh.svg_settings.scale_size
+ drawing = Drawing(width=render_width, height=render_height, svg_settings=SvgSettings())
+ obstacles = self.create_obstacles(gh)
+
+ agents = []
+ targets = []
+
+ if gh.config.show_agents:
+ agents = self.create_agents(gh)
+ targets = self.create_targets(gh)
+
+ if not gh.config.static:
+ self.animate_agents(agents, gh)
+ self.animate_targets(targets, gh)
+ if gh.config.show_grid_lines:
+ grid_lines = self.create_grid_lines(gh, render_width, render_height)
+ for line in grid_lines:
+ drawing.add_element(line)
+ for obj in [*obstacles, *agents, *targets]:
+ drawing.add_element(obj)
+
+ if gh.config.egocentric_idx is not None:
+ field_of_view = self.create_field_of_view(grid_holder=gh)
+ if not gh.config.static:
+ self.animate_obstacles(obstacles=obstacles, grid_holder=gh)
+ self.animate_field_of_view(field_of_view, gh)
+ drawing.add_element(field_of_view)
+
+ return drawing
+
+ @staticmethod
+ def fix_point(x, y, length):
+ return length - y - 1, x
+
+ @staticmethod
+ def check_in_radius(x1, y1, x2, y2, r) -> bool:
+ return x2 - r <= x1 <= x2 + r and y2 - r <= y1 <= y2 + r
+
+ @staticmethod
+ def create_grid_lines(grid_holder: GridHolder, render_width, render_height):
+ gh = grid_holder
+ offset = 0
+ stroke_settings = {'class': 'line'}
+ grid_lines = []
+ for i in range(-1, grid_holder.height + 1):
+ x = i * gh.svg_settings.scale_size + gh.svg_settings.scale_size / 2
+ grid_lines.append(Line(x1=x, y1=offset, x2=x, y2=render_height - offset, **stroke_settings))
+
+ for i in range(-1, grid_holder.width + 1):
+ y = i * gh.svg_settings.scale_size + gh.svg_settings.scale_size / 2
+ grid_lines.append(Line(x1=offset, y1=y, x2=render_width - offset, y2=y, **stroke_settings))
+
+ return grid_lines
+
+ @staticmethod
+ def create_field_of_view(grid_holder):
+ gh: GridHolder = grid_holder
+ ego_idx = gh.config.egocentric_idx
+ x, y = gh.history[ego_idx][0].get_xy()
+ cx = gh.svg_settings.draw_start + y * gh.svg_settings.scale_size
+ cy = gh.svg_settings.draw_start + (gh.width - x - 1) * gh.svg_settings.scale_size
+
+ dr = (grid_holder.obs_radius + 1) * gh.svg_settings.scale_size - gh.svg_settings.stroke_width * 2
+ result = Rectangle(
+ x=cx - dr + gh.svg_settings.r, y=cy - dr + gh.svg_settings.r,
+ width=2 * dr - 2 * gh.svg_settings.r, height=2 * dr - 2 * gh.svg_settings.r,
+ stroke=gh.svg_settings.ego_color, stroke_width=gh.svg_settings.stroke_width,
+ fill='none', rx=gh.svg_settings.rx, stroke_dasharray=gh.svg_settings.stroke_dasharray
+ )
+
+ return result
+
+ def animate_field_of_view(self, view, grid_holder):
+ gh: GridHolder = grid_holder
+ x_path = []
+ y_path = []
+ ego_idx = grid_holder.config.egocentric_idx
+ for state in gh.history[ego_idx]:
+ x, y = state.get_xy()
+ dr = (grid_holder.obs_radius + 1) * gh.svg_settings.scale_size - gh.svg_settings.stroke_width * 2
+ cx = gh.svg_settings.draw_start + y * gh.svg_settings.scale_size
+ cy = -gh.svg_settings.draw_start + -(gh.width - x - 1) * gh.svg_settings.scale_size
+ x_path.append(str(cx - dr + gh.svg_settings.r))
+ y_path.append(str(cy - dr + gh.svg_settings.r))
+
+ visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[ego_idx]]
+
+ view.add_animation(self.compressed_anim('x', x_path, gh.svg_settings.time_scale))
+ view.add_animation(self.compressed_anim('y', y_path, gh.svg_settings.time_scale))
+ view.add_animation(self.compressed_anim('visibility', visibility, gh.svg_settings.time_scale))
+
+ def animate_agents(self, agents, grid_holder):
+ gh: GridHolder = grid_holder
+ ego_idx = gh.config.egocentric_idx
+
+ for agent_idx, agent in enumerate(agents):
+ x_path = []
+ y_path = []
+ opacity = []
+ for idx, agent_state in enumerate(gh.history[agent_idx]):
+ x, y = agent_state.get_xy()
+
+ x_path.append(str(gh.svg_settings.draw_start + y * gh.svg_settings.scale_size))
+ y_path.append(str(-gh.svg_settings.draw_start + -(gh.width - x - 1) * gh.svg_settings.scale_size))
+
+ if ego_idx is not None:
+ ego_x, ego_y = gh.history[ego_idx][idx].get_xy()
+ if self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius):
+ opacity.append('1.0')
+ else:
+ opacity.append(str(gh.svg_settings.shaded_opacity))
+
+ visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]]
+
+ agent.add_animation(self.compressed_anim('cy', y_path, gh.svg_settings.time_scale))
+ agent.add_animation(self.compressed_anim('cx', x_path, gh.svg_settings.time_scale))
+ agent.add_animation(self.compressed_anim('visibility', visibility, gh.svg_settings.time_scale))
+ if opacity:
+ agent.add_animation(self.compressed_anim('opacity', opacity, gh.svg_settings.time_scale))
+
+ @classmethod
+ def compressed_anim(cls, attr_name, tokens, time_scale, rep_cnt='indefinite'):
+ tokens, times = cls.compress_tokens(tokens)
+ cumulative = [0, ]
+ for t in times:
+ cumulative.append(cumulative[-1] + t)
+ times = [str(round(value / cumulative[-1], 10)) for value in cumulative]
+ tokens = [tokens[0]] + tokens
+
+ times = times
+ tokens = tokens
+ return Animation(
+ attributeName=attr_name, dur=f'{time_scale * (-1 + cumulative[-1])}s',
+ values=";".join(tokens), repeatCount=rep_cnt, keyTimes=";".join(times)
+ )
+
+ @staticmethod
+ def wisely_add(token, cnt, tokens, times):
+ if cnt > 1:
+ tokens += [token, token]
+ times += [1, cnt - 1]
+ else:
+ tokens.append(token)
+ times.append(cnt)
+
+ @classmethod
+ def compress_tokens(cls, input_tokens: list):
+ tokens = []
+ times = []
+ if input_tokens:
+ cur_idx = 0
+ cnt = 1
+ for idx in range(1, len(input_tokens)):
+ if input_tokens[idx] == input_tokens[cur_idx]:
+ cnt += 1
+ else:
+ cls.wisely_add(input_tokens[cur_idx], cnt, tokens, times)
+ cnt = 1
+ cur_idx = idx
+ cls.wisely_add(input_tokens[cur_idx], cnt, tokens, times)
+ return tokens, times
+
+ def animate_targets(self, targets, grid_holder):
+ gh: GridHolder = grid_holder
+ ego_idx = gh.config.egocentric_idx
+
+ for agent_idx, target in enumerate(targets):
+ target_idx = ego_idx if ego_idx is not None else agent_idx
+
+ x_path = []
+ y_path = []
+
+ for step_idx, state in enumerate(gh.history[target_idx]):
+ x, y = state.get_target_xy()
+ x_path.append(str(gh.svg_settings.draw_start + y * gh.svg_settings.scale_size))
+ y_path.append(str(-gh.svg_settings.draw_start + -(gh.width - x - 1) * gh.svg_settings.scale_size))
+
+ visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]]
+
+ if gh.on_target == 'restart' or gh.on_target == 'wait':
+ target.add_animation(self.compressed_anim('cy', y_path, gh.svg_settings.time_scale))
+ target.add_animation(self.compressed_anim('cx', x_path, gh.svg_settings.time_scale))
+ target.add_animation(self.compressed_anim("visibility", visibility, gh.svg_settings.time_scale))
+
+ def create_obstacles(self, grid_holder):
+ gh = grid_holder
+ result = []
+
+ for i in range(gh.height):
+ for j in range(gh.width):
+ x, y = self.fix_point(i, j, gh.width)
+
+ if gh.obstacles[x][y]:
+ obs_settings = {}
+ obs_settings.update(
+ x=gh.svg_settings.draw_start + i * gh.svg_settings.scale_size - gh.svg_settings.r,
+ y=gh.svg_settings.draw_start + j * gh.svg_settings.scale_size - gh.svg_settings.r,
+ height=gh.svg_settings.r * 2,
+ )
+
+ if gh.config.egocentric_idx is not None and gh.svg_settings.egocentric_shaded:
+ initial_positions = [agent_states[0].get_xy() for agent_states in gh.history]
+ ego_x, ego_y = initial_positions[gh.config.egocentric_idx]
+ if not self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius):
+ obs_settings.update(opacity=gh.svg_settings.shaded_opacity)
+
+ result.append(RectangleHref(**obs_settings))
+
+ return result
+
+ def animate_obstacles(self, obstacles, grid_holder):
+ gh: GridHolder = grid_holder
+ obstacle_idx = 0
+
+ for i in range(gh.height):
+ for j in range(gh.width):
+ x, y = self.fix_point(i, j, gh.width)
+ if not gh.obstacles[x][y]:
+ continue
+ opacity = []
+ seen = set()
+ for step_idx, agent_state in enumerate(gh.history[gh.config.egocentric_idx]):
+ ego_x, ego_y = agent_state.get_xy()
+ if self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius):
+ seen.add((x, y))
+ if (x, y) in seen:
+ opacity.append(str(1.0))
+ else:
+ opacity.append(str(gh.svg_settings.shaded_opacity))
+
+ obstacle = obstacles[obstacle_idx]
+ obstacle.add_animation(self.compressed_anim('opacity', opacity, gh.svg_settings.time_scale))
+
+ obstacle_idx += 1
+
+ def create_agents(self, grid_holder):
+ initial_positions = [state[0].get_xy() for state in grid_holder.history if state[0].is_active()]
+ agents = []
+ gh: GridHolder = grid_holder
+ ego_idx = grid_holder.config.egocentric_idx
+
+ for idx, (x, y) in enumerate(initial_positions):
+ circle_settings = {
+ 'cx': gh.svg_settings.draw_start + y * gh.svg_settings.scale_size,
+ 'cy': gh.svg_settings.draw_start + (grid_holder.width - x - 1) * gh.svg_settings.scale_size,
+ 'r': gh.svg_settings.r, 'fill': grid_holder.colors[idx], 'class': 'agent',
+ }
+
+ if ego_idx is not None:
+ ego_x, ego_y = initial_positions[ego_idx]
+ is_out_of_radius = not self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius)
+ circle_settings['fill'] = gh.svg_settings.ego_other_color
+ if idx == ego_idx:
+ circle_settings['fill'] = gh.svg_settings.ego_color
+ elif is_out_of_radius and gh.svg_settings.egocentric_shaded:
+ circle_settings['opacity'] = gh.svg_settings.shaded_opacity
+
+ agents.append(Circle(**circle_settings))
+
+ return agents
+
+ @staticmethod
+ def create_targets(grid_holder):
+ gh: GridHolder = grid_holder
+ targets = []
+ for agent_idx, agent_states in enumerate(gh.history):
+
+ tx, ty = agent_states[0].get_target_xy()
+ x, y = ty, gh.width - tx - 1
+
+ if not any([agent_state.is_active() for agent_state in gh.history[agent_idx]]):
+ continue
+
+ circle_settings = {"class": 'target'}
+ circle_settings.update(
+ cx=gh.svg_settings.draw_start + x * gh.svg_settings.scale_size, r=gh.svg_settings.r,
+ cy=gh.svg_settings.draw_start + y * gh.svg_settings.scale_size, stroke=gh.colors[agent_idx],
+ )
+
+ if gh.config.egocentric_idx is not None:
+ if gh.config.egocentric_idx != agent_idx:
+ continue
+
+ circle_settings.update(stroke=gh.svg_settings.ego_color)
+ target = Circle(**circle_settings)
+ targets.append(target)
+ return targets
diff --git a/pogema/svg_animation/animation_wrapper.py b/pogema/svg_animation/animation_wrapper.py
new file mode 100644
index 0000000..c87aff1
--- /dev/null
+++ b/pogema/svg_animation/animation_wrapper.py
@@ -0,0 +1,172 @@
+import os
+from itertools import cycle
+from gymnasium import logger, Wrapper
+
+from pogema import GridConfig
+from pogema.svg_animation.animation_drawer import AnimationConfig, SvgSettings, GridHolder, AnimationDrawer
+from pogema.wrappers.persistence import PersistentWrapper, AgentState
+
+
+class AnimationMonitor(Wrapper):
+ """
+ Defines the animation, which saves the episode as SVG.
+ """
+
+ def __init__(self, env, animation_config=AnimationConfig()):
+ self._working_radius = env.grid_config.obs_radius - 1
+ env = PersistentWrapper(env, xy_offset=-self._working_radius)
+
+ super().__init__(env)
+
+ self.history = env.get_history()
+
+ self.svg_settings: SvgSettings = SvgSettings()
+ self.animation_config: AnimationConfig = animation_config
+
+ self._episode_idx = 0
+
+ def step(self, action):
+ """
+ Saves information about the episode.
+ :param action: current actions
+ :return: obs, reward, done, info
+ """
+ obs, reward, terminated, truncated, info = self.env.step(action)
+
+ multi_agent_terminated = isinstance(terminated, (list, tuple)) and all(terminated)
+ single_agent_terminated = isinstance(terminated, (bool, int)) and terminated
+ multi_agent_truncated = isinstance(truncated, (list, tuple)) and all(truncated)
+ single_agent_truncated = isinstance(truncated, (bool, int)) and truncated
+
+ if multi_agent_terminated or single_agent_terminated or multi_agent_truncated or single_agent_truncated:
+ save_tau = self.animation_config.save_every_idx_episode
+ if save_tau:
+ if (self._episode_idx + 1) % save_tau or save_tau == 1:
+ if not os.path.exists(self.animation_config.directory):
+ logger.info(f"Creating pogema monitor directory {self.animation_config.directory}", )
+ os.makedirs(self.animation_config.directory, exist_ok=True)
+
+ path = os.path.join(self.animation_config.directory,
+ self.pick_name(self.grid_config, self._episode_idx))
+ self.save_animation(path)
+
+ return obs, reward, terminated, truncated, info
+
+ @staticmethod
+ def pick_name(grid_config: GridConfig, episode_idx=None, zfill_ep=5):
+ """
+ Picks a name for the SVG file.
+ :param grid_config: configuration of the grid
+ :param episode_idx: idx of the episode
+ :param zfill_ep: zfill for the episode number
+ :return:
+ """
+ gc = grid_config
+ name = 'pogema'
+ if episode_idx is not None:
+ name += f'-ep{str(episode_idx).zfill(zfill_ep)}'
+ if gc:
+ if gc.map_name:
+ name += f'-{gc.map_name}'
+ if gc.seed is not None:
+ name += f'-seed{gc.seed}'
+ else:
+ name += '-render'
+ return name + '.svg'
+
+ def reset(self, **kwargs):
+ """
+ Resets the environment and resets the current positions of agents and targets
+ :param kwargs:
+ :return: obs: observation
+ """
+ obs = self.env.reset(**kwargs)
+
+ self._episode_idx += 1
+ self.history = self.env.get_history()
+
+ return obs
+
+ def save_animation(self, name='render.svg', animation_config: AnimationConfig = AnimationConfig()):
+ """
+ Saves the animation.
+ :param name: name of the file
+ :param animation_config: animation configuration
+ :return: None
+ """
+ wr = self._working_radius
+ obstacles = self.env.get_obstacles(ignore_borders=False)[wr:-wr, wr:-wr]
+ history: list[list[AgentState]] = self.env.decompress_history(self.history)
+
+ svg_settings = SvgSettings()
+ colors_cycle = cycle(svg_settings.colors)
+ agents_colors = {index: next(colors_cycle) for index in range(self.grid_config.num_agents)}
+
+ for agent_idx in range(self.grid_config.num_agents):
+ history[agent_idx].append(history[agent_idx][-1])
+
+ episode_length = len(history[0])
+ # Change episode length for egocentric environment
+ if animation_config.egocentric_idx is not None and self.grid_config.on_target == 'finish':
+ episode_length = history[animation_config.egocentric_idx][-1].step + 1
+ for agent_idx in range(self.grid_config.num_agents):
+ history[agent_idx] = history[agent_idx][:episode_length]
+
+ grid_holder = GridHolder(
+ width=len(obstacles), height=len(obstacles[0]),
+ obstacles=obstacles,
+ episode_length=episode_length,
+ history=history,
+ obs_radius=self.grid_config.obs_radius,
+ on_target=self.grid_config.on_target,
+ colors=agents_colors,
+ config=animation_config,
+ svg_settings=svg_settings
+ )
+
+ animation = AnimationDrawer().create_animation(grid_holder)
+ with open(name, "w") as f:
+ f.write(animation.render())
+
+
+def main():
+ from pogema import GridConfig, pogema_v0, AnimationMonitor, BatchAStarAgent, AnimationConfig
+
+ for egocentric_idx in [0, 1]:
+ for on_target in ['nothing', 'restart', 'finish']:
+ grid = """
+ ....#..
+ ..#....
+ .......
+ .......
+ #.#.#..
+ #.#.#..
+ """
+ grid_config = GridConfig(size=32, num_agents=2, obs_radius=2, seed=8, on_target=on_target,
+ max_episode_steps=16,
+ density=0.1, map=grid, observation_type="POMAPF")
+ env = pogema_v0(grid_config=grid_config)
+ env = AnimationMonitor(env, AnimationConfig(save_every_idx_episode=None))
+
+ obs, _ = env.reset()
+ truncated = terminated = [False]
+
+ agent = BatchAStarAgent()
+ while not all(terminated) and not all(truncated):
+ obs, _, terminated, truncated, _ = env.step(agent.act(obs))
+
+ anim_folder = 'renders'
+ if not os.path.exists(anim_folder):
+ os.makedirs(anim_folder)
+
+ env.save_animation(f'{anim_folder}/anim-{on_target}.svg')
+ env.save_animation(f'{anim_folder}/anim-{on_target}-ego-{egocentric_idx}.svg',
+ AnimationConfig(egocentric_idx=egocentric_idx))
+ env.save_animation(f'{anim_folder}/anim-static.svg', AnimationConfig(static=True))
+ env.save_animation(f'{anim_folder}/anim-static-ego.svg', AnimationConfig(egocentric_idx=0, static=True))
+ env.save_animation(f'{anim_folder}/anim-static-no-agents.svg',
+ AnimationConfig(show_agents=False, static=True))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/pogema/svg_animation/svg_objects.py b/pogema/svg_animation/svg_objects.py
new file mode 100644
index 0000000..c18c13f
--- /dev/null
+++ b/pogema/svg_animation/svg_objects.py
@@ -0,0 +1,77 @@
+class SvgObject:
+ tag = None
+
+ def __init__(self, **kwargs):
+ self.attributes = kwargs
+ self.animations = []
+
+ def add_animation(self, animation):
+ self.animations.append(animation)
+
+ @staticmethod
+ def render_attributes(attributes):
+ result = " ".join([f'{x.replace("_", "-")}="{y}"' for x, y in sorted(attributes.items())])
+ return result
+
+ def render(self):
+ animations = '\n'.join([a.render() for a in self.animations]) if self.animations else None
+ if animations:
+ return f"<{self.tag} {self.render_attributes(self.attributes)}> {animations} {self.tag}>"
+ return f"<{self.tag} {self.render_attributes(self.attributes)} />"
+
+
+class Rectangle(SvgObject):
+ """
+ Rectangle class for the SVG.
+ """
+ tag = 'rect'
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.attributes['y'] = -self.attributes['y'] - self.attributes['height']
+
+
+class RectangleHref(SvgObject):
+ """
+ Rectangle class for the SVG.
+ """
+ tag = 'use'
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.attributes['y'] = -self.attributes['y'] - self.attributes['height']
+ self.attributes['href'] = "#obstacle"
+ del self.attributes['height']
+
+
+class Circle(SvgObject):
+ """
+ Circle class for the SVG.
+ """
+ tag = 'circle'
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.attributes['cy'] = -self.attributes['cy']
+
+
+class Line(SvgObject):
+ """
+ Line class for the SVG.
+ """
+ tag = 'line'
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.attributes['y1'] = -self.attributes['y1']
+ self.attributes['y2'] = -self.attributes['y2']
+
+
+class Animation(SvgObject):
+ """
+ Animation class for the SVG.
+ """
+ tag = 'animate'
+
+ def render(self):
+ return f"<{self.tag} {self.render_attributes(self.attributes)}/>"
diff --git a/pogema/wrappers/persistence.py b/pogema/wrappers/persistence.py
index 56634b2..3bd51f4 100644
--- a/pogema/wrappers/persistence.py
+++ b/pogema/wrappers/persistence.py
@@ -26,12 +26,16 @@ def __eq__(self, other):
o = other
return self.x == o.x and self.y == o.y and self.tx == o.tx and self.ty == o.ty and self.active == o.active
+ def __str__(self):
+ return str([self.x, self.y, self.tx, self.ty, self.step, self.active])
+
class PersistentWrapper(Wrapper):
- def __init__(self, env):
+ def __init__(self, env, xy_offset=None):
super().__init__(env)
self._step = None
self._agent_states = None
+ self._xy_offset = xy_offset
def step(self, action):
result = self.env.step(action)
@@ -67,6 +71,11 @@ def _get_agent_state(self, grid, agent_idx):
x, y = grid.positions_xy[agent_idx]
tx, ty = grid.finishes_xy[agent_idx]
active = grid.is_active[agent_idx]
+ if self._xy_offset:
+ x += self._xy_offset
+ y += self._xy_offset
+ tx += self._xy_offset
+ ty += self._xy_offset
return AgentState(x, y, tx, ty, self._step, active)
def reset(self, **kwargs):
diff --git a/tests/test_deterministic_policy.py b/tests/test_deterministic_policy.py
index f456c0d..4f79ddf 100644
--- a/tests/test_deterministic_policy.py
+++ b/tests/test_deterministic_policy.py
@@ -1,7 +1,8 @@
import numpy as np
from heapq import heappop, heappush
-from pogema import GridConfig, pogema_v0
-from pogema.animation import AnimationMonitor
+from pogema import GridConfig, pogema_v0, AnimationMonitor
+
+# from pogema.animation import AnimationMonitor
INF = 1000000007
diff --git a/tests/test_pogema_env.py b/tests/test_pogema_env.py
index 1ccad67..7b5c32f 100644
--- a/tests/test_pogema_env.py
+++ b/tests/test_pogema_env.py
@@ -4,13 +4,12 @@
import numpy as np
from tabulate import tabulate
-from pogema import pogema_v0
+from pogema import pogema_v0, AnimationMonitor
from pogema import Easy8x8, Normal8x8, Hard8x8, ExtraHard8x8
from pogema import Easy16x16, Normal16x16, Hard16x16, ExtraHard16x16
from pogema import Easy32x32, Normal32x32, Hard32x32, ExtraHard32x32
from pogema import Easy64x64, Normal64x64, Hard64x64, ExtraHard64x64
-from pogema.animation import AnimationMonitor
from pogema.envs import ActionsSampler
from pogema.grid import GridConfig
@@ -122,7 +121,8 @@ def test_standard_pogema_animation():
def test_gym_pogema_animation():
import gymnasium
env = gymnasium.make('Pogema-v0',
- grid_config=GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='finish'))
+ grid_config=GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42,
+ on_target='finish'))
env = AnimationMonitor(env)
env.reset()
done = False
@@ -131,6 +131,7 @@ def test_gym_pogema_animation():
if terminated or truncated:
break
+
def test_non_disappearing_pogema():
env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='nothing'))
env.reset()