From 5950b65fd8ae91bc3025cad21d35d91dd185de01 Mon Sep 17 00:00:00 2001 From: Alexey Skrynnik Date: Sun, 2 Jun 2024 18:02:32 +0300 Subject: [PATCH] Improving animations (#34) * Improvements for svg animations: - common attributes of elements via defs - removed borders - grid lines by default --- pogema/__init__.py | 5 +- pogema/animation.py | 720 ---------------------- pogema/svg_animation/__init__.py | 0 pogema/svg_animation/animation_drawer.py | 395 ++++++++++++ pogema/svg_animation/animation_wrapper.py | 172 ++++++ pogema/svg_animation/svg_objects.py | 77 +++ pogema/wrappers/persistence.py | 11 +- tests/test_deterministic_policy.py | 5 +- tests/test_pogema_env.py | 7 +- 9 files changed, 665 insertions(+), 727 deletions(-) delete mode 100644 pogema/animation.py create mode 100644 pogema/svg_animation/__init__.py create mode 100644 pogema/svg_animation/animation_drawer.py create mode 100644 pogema/svg_animation/animation_wrapper.py create mode 100644 pogema/svg_animation/svg_objects.py diff --git a/pogema/__init__.py b/pogema/__init__.py index 1a956e5..26ff7e5 100644 --- a/pogema/__init__.py +++ b/pogema/__init__.py @@ -1,6 +1,7 @@ from gymnasium import register from pogema.grid_config import GridConfig from pogema.integrations.make_pogema import pogema_v0 +from pogema.svg_animation.animation_wrapper import AnimationMonitor, AnimationConfig from pogema.a_star_policy import AStarAgent, BatchAStarAgent from pogema.grid_config import Easy8x8, Normal8x8, Hard8x8, ExtraHard8x8 @@ -8,12 +9,14 @@ from pogema.grid_config import Easy32x32, Normal32x32, Hard32x32, ExtraHard32x32 from pogema.grid_config import Easy64x64, Normal64x64, Hard64x64, ExtraHard64x64 -__version__ = '1.2.2' +__version__ = '1.2.3a2' __all__ = [ 'GridConfig', 'pogema_v0', 'AStarAgent', 'BatchAStarAgent', + "AnimationMonitor", "AnimationConfig", + 'Easy8x8', 'Normal8x8', 'Hard8x8', 'ExtraHard8x8', 'Easy16x16', 'Normal16x16', 'Hard16x16', 'ExtraHard16x16', 'Easy32x32', 'Normal32x32', 'Hard32x32', 'ExtraHard32x32', diff --git a/pogema/animation.py b/pogema/animation.py deleted file mode 100644 index c992881..0000000 --- a/pogema/animation.py +++ /dev/null @@ -1,720 +0,0 @@ -import os -import typing -from itertools import cycle -from gymnasium import logger, Wrapper - -from pydantic import BaseModel - -from pogema import GridConfig, pogema_v0 -from pogema.grid import Grid -from pogema.wrappers.persistence import PersistentWrapper, AgentState - - -class AnimationSettings(BaseModel): - """ - Settings for the animation. - """ - r: int = 35 - stroke_width: int = 10 - scale_size: int = 100 - time_scale: float = 0.28 - draw_start: int = 100 - rx: int = 15 - - obstacle_color: str = '#84A1AE' - ego_color: str = '#c1433c' - ego_other_color: str = '#72D5C8' - shaded_opacity: float = 0.2 - egocentric_shaded: bool = True - stroke_dasharray: int = 25 - - colors: list = [ - '#c1433c', - '#2e6f9e', - '#6e81af', - '#00b9c8', - '#72D5C8', - '#0ea08c', - '#8F7B66', - ] - - -class AnimationConfig(BaseModel): - """ - Configuration for the animation. - """ - directory: str = 'renders/' - static: bool = False - show_agents: bool = True - egocentric_idx: typing.Optional[int] = None - uid: typing.Optional[str] = None - save_every_idx_episode: typing.Optional[int] = 1 - show_border: bool = True - show_lines: bool = False - - -class GridHolder(BaseModel): - """ - Holds the grid and the history. - """ - obstacles: typing.Any = None - episode_length: int = None - height: int = None - width: int = None - colors: dict = None - history: list = None - - -class SvgObject: - """ - Main class for the SVG. - """ - tag = None - - def __init__(self, **kwargs): - self.attributes = kwargs - self.animations = [] - - def add_animation(self, animation): - self.animations.append(animation) - - @staticmethod - def render_attributes(attributes): - result = " ".join([f'{x.replace("_", "-")}="{y}"' for x, y in sorted(attributes.items())]) - return result - - def render(self): - animations = '\n'.join([a.render() for a in self.animations]) if self.animations else None - if animations: - return f"<{self.tag} {self.render_attributes(self.attributes)}> {animations} " - return f"<{self.tag} {self.render_attributes(self.attributes)} />" - - -class Rectangle(SvgObject): - """ - Rectangle class for the SVG. - """ - tag = 'rect' - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.attributes['y'] = -self.attributes['y'] - self.attributes['height'] - - -class Circle(SvgObject): - """ - Circle class for the SVG. - """ - tag = 'circle' - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.attributes['cy'] = -self.attributes['cy'] - - -class Line(SvgObject): - """ - Line class for the SVG. - """ - tag = 'line' - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.attributes['y1'] = -self.attributes['y1'] - self.attributes['y2'] = -self.attributes['y2'] - - -class Animation(SvgObject): - """ - Animation class for the SVG. - """ - tag = 'animate' - - def render(self): - return f"<{self.tag} {self.render_attributes(self.attributes)}/>" - - -class Drawing: - """ - Drawing, analog of the DrawSvg class in the pogema package. - """ - - def __init__(self, height, width, display_inline=False, origin=(0, 0)): - self.height = height - self.width = width - self.display_inline = display_inline - self.origin = origin - self.elements = [] - - def add_element(self, element): - self.elements.append(element) - - def render(self): - view_box = (0, -self.height, self.width, self.height) - results = [f''' - ''', - '\n\n', '\n'] - for element in self.elements: - results.append(element.render()) - results.append('') - return "\n".join(results) - - -class AnimationMonitor(Wrapper): - """ - Defines the animation, which saves the episode as SVG. - """ - - def __init__(self, env, animation_config=AnimationConfig()): - # Wrapping env using PersistenceWrapper for saving the history. - env = PersistentWrapper(env) - super().__init__(env) - - self.history = self.env.get_history() - - self.svg_settings: AnimationSettings = AnimationSettings() - self.animation_config: AnimationConfig = animation_config - - self._episode_idx = 0 - - def step(self, action): - """ - Saves information about the episode. - :param action: current actions - :return: obs, reward, done, info - """ - obs, reward, terminated, truncated, info = self.env.step(action) - - multi_agent_terminated = isinstance(terminated, (list, tuple)) and all(terminated) - single_agent_terminated = isinstance(terminated, (bool, int)) and terminated - multi_agent_truncated = isinstance(truncated, (list, tuple)) and all(truncated) - single_agent_truncated = isinstance(truncated, (bool, int)) and truncated - - if multi_agent_terminated or single_agent_terminated or multi_agent_truncated or single_agent_truncated: - save_tau = self.animation_config.save_every_idx_episode - if save_tau: - if (self._episode_idx + 1) % save_tau or save_tau == 1: - if not os.path.exists(self.animation_config.directory): - logger.info(f"Creating pogema monitor directory {self.animation_config.directory}", ) - os.makedirs(self.animation_config.directory, exist_ok=True) - - path = os.path.join(self.animation_config.directory, - self.pick_name(self.grid_config, self._episode_idx)) - self.save_animation(path) - - return obs, reward, terminated, truncated, info - - @staticmethod - def pick_name(grid_config: GridConfig, episode_idx=None, zfill_ep=5): - """ - Picks a name for the SVG file. - :param grid_config: configuration of the grid - :param episode_idx: idx of the episode - :param zfill_ep: zfill for the episode number - :return: - """ - gc = grid_config - name = 'pogema' - if episode_idx is not None: - name += f'-ep{str(episode_idx).zfill(zfill_ep)}' - if gc: - if gc.map_name: - name += f'-{gc.map_name}' - if gc.seed is not None: - name += f'-seed{gc.seed}' - else: - name += '-render' - return name + '.svg' - - def reset(self, **kwargs): - """ - Resets the environment and resets the current positions of agents and targets - :param kwargs: - :return: obs: observation - """ - obs = self.env.reset(**kwargs) - - self._episode_idx += 1 - self.history = self.env.get_history() - - return obs - - def create_animation(self, animation_config=None): - """ - Creates the animation. - :param animation_config: configuration of the animation - :return: drawing: drawing object - """ - anim_cfg = animation_config - if anim_cfg is None: - anim_cfg = self.animation_config - - grid: Grid = self.grid - cfg = self.svg_settings - colors = cycle(cfg.colors) - agents_colors = {index: next(colors) for index in range(self.grid_config.num_agents)} - - if anim_cfg.egocentric_idx is not None: - anim_cfg.egocentric_idx %= self.grid_config.num_agents - - decompressed_history: list[list[AgentState]] = self.env.decompress_history(self.history) - - # Change episode length for egocentric environment - if anim_cfg.egocentric_idx is not None: - episode_length = decompressed_history[anim_cfg.egocentric_idx][-1].step + 1 - for agent_idx in range(self.grid_config.num_agents): - decompressed_history[agent_idx] = decompressed_history[agent_idx][:episode_length] - else: - episode_length = len(decompressed_history[0]) - - # Add last observation one more time to highlight the final state - for agent_idx in range(self.grid_config.num_agents): - decompressed_history[agent_idx].append(decompressed_history[agent_idx][-1]) - - # Change episode length for static environment - if anim_cfg.static: - episode_length = 1 - decompressed_history = [[decompressed_history[idx][-1]] for idx in range(len(decompressed_history))] - - gh = GridHolder(width=len(grid.obstacles), height=len(grid.obstacles[0]), - obstacles=grid.obstacles, - colors=agents_colors, - episode_length=episode_length, - history=decompressed_history, ) - - render_width, render_height = gh.height * cfg.scale_size + cfg.scale_size, gh.width * cfg.scale_size + cfg.scale_size - - drawing = Drawing(width=render_width, height=render_height, display_inline=False, origin=(0, 0)) - obstacles = self.create_obstacles(gh, anim_cfg) - - agents = [] - targets = [] - - if anim_cfg.show_agents: - agents = self.create_agents(gh, anim_cfg) - targets = self.create_targets(gh, anim_cfg) - - if not anim_cfg.static: - self.animate_agents(agents, anim_cfg.egocentric_idx, gh) - self.animate_targets(targets, gh, anim_cfg) - if anim_cfg.show_lines: - grid_lines = self.create_grid_lines(gh, anim_cfg, render_width, render_height) - for line in grid_lines: - drawing.add_element(line) - for obj in [*obstacles, *agents, *targets, ]: - drawing.add_element(obj) - - if anim_cfg.egocentric_idx is not None: - field_of_view = self.create_field_of_view(grid_holder=gh, animation_config=anim_cfg) - if not anim_cfg.static: - self.animate_obstacles(obstacles=obstacles, grid_holder=gh, animation_config=anim_cfg) - self.animate_field_of_view(field_of_view, anim_cfg.egocentric_idx, gh) - drawing.add_element(field_of_view) - - return drawing - - def create_grid_lines(self, grid_holder: GridHolder, animation_config: AnimationConfig, render_width, - render_height): - """ - Creates the grid lines. - :param grid_holder: grid holder - :param animation_config: animation configuration - :return: grid_lines: list of grid lines - """ - cfg = self.svg_settings - grid_lines = [] - for i in range(-1, grid_holder.height + 1): - # vertical lines - x0 = x1 = i * cfg.scale_size + cfg.scale_size / 2 - y0 = 0 - y1 = render_height - grid_lines.append( - Line(x1=x0, y1=y0, x2=x1, y2=y1, stroke=cfg.obstacle_color, stroke_width=cfg.stroke_width // 1.5)) - for i in range(-1, grid_holder.width + 1): - # continue - # horizontal lines - x0 = 0 - y0 = y1 = i * cfg.scale_size + cfg.scale_size / 2 - x1 = render_width - grid_lines.append( - Line(x1=x0, y1=y0, x2=x1, y2=y1, stroke=cfg.obstacle_color, stroke_width=cfg.stroke_width // 1.5)) - - # for i in range(grid_holder.width): - # grid_lines.append(Line(start=(0, i * cfg.scale_size), - # end=(grid_holder.height * cfg.scale_size, i * cfg.scale_size), - # stroke=cfg.grid_color, stroke_width=cfg.grid_width)) - return grid_lines - - def save_animation(self, name='render.svg', animation_config: typing.Optional[AnimationConfig] = None): - """ - Saves the animation. - :param name: name of the file - :param animation_config: animation configuration - :return: None - """ - animation = self.create_animation(animation_config) - with open(name, "w") as f: - f.write(animation.render()) - - @staticmethod - def fix_point(x, y, length): - """ - Fixes the point to the grid. - :param x: coordinate x - :param y: coordinate y - :param length: size of the grid - :return: x, y: fixed coordinates - """ - return length - y - 1, x - - @staticmethod - def check_in_radius(x1, y1, x2, y2, r) -> bool: - """ - Checks if the point is in the radius. - :param x1: coordinate x1 - :param y1: coordinate y1 - :param x2: coordinate x2 - :param y2: coordinate y2 - :param r: radius - :return: - """ - return x2 - r <= x1 <= x2 + r and y2 - r <= y1 <= y2 + r - - def create_field_of_view(self, grid_holder, animation_config): - """ - Creates the field of view for the egocentric agent. - :param grid_holder: - :param animation_config: - :return: - """ - cfg = self.svg_settings - gh: GridHolder = grid_holder - ego_idx = animation_config.egocentric_idx - x, y = gh.history[ego_idx][0].get_xy() - cx = cfg.draw_start + y * cfg.scale_size - cy = cfg.draw_start + (gh.width - x - 1) * cfg.scale_size - - dr = (self.grid_config.obs_radius + 1) * cfg.scale_size - cfg.stroke_width * 2 - result = Rectangle(x=cx - dr + cfg.r, y=cy - dr + cfg.r, - width=2 * dr - 2 * cfg.r, height=2 * dr - 2 * cfg.r, - stroke=cfg.ego_color, stroke_width=cfg.stroke_width, - fill='none', - rx=cfg.rx, stroke_dasharray=cfg.stroke_dasharray, - ) - - return result - - def animate_field_of_view(self, view, agent_idx, grid_holder): - """ - Animates the field of view. - :param view: - :param agent_idx: - :param grid_holder: - :return: - """ - gh: GridHolder = grid_holder - cfg = self.svg_settings - x_path = [] - y_path = [] - for state in gh.history[agent_idx]: - x, y = state.get_xy() - dr = (self.grid_config.obs_radius + 1) * cfg.scale_size - cfg.stroke_width * 2 - cx = cfg.draw_start + y * cfg.scale_size - cy = -cfg.draw_start + -(gh.width - x - 1) * cfg.scale_size - x_path.append(str(cx - dr + cfg.r)) - y_path.append(str(cy - dr + cfg.r)) - - visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]] - - view.add_animation(self.compressed_anim('x', x_path, cfg.time_scale)) - view.add_animation(self.compressed_anim('y', y_path, cfg.time_scale)) - view.add_animation(self.compressed_anim('visibility', visibility, cfg.time_scale)) - - def animate_agents(self, agents, egocentric_idx, grid_holder): - """ - Animates the agents. - :param agents: - :param egocentric_idx: - :param grid_holder: - :return: - """ - gh: GridHolder = grid_holder - cfg = self.svg_settings - for agent_idx, agent in enumerate(agents): - x_path = [] - y_path = [] - opacity = [] - for agent_state in gh.history[agent_idx]: - x, y = agent_state.get_xy() - - x_path.append(str(cfg.draw_start + y * cfg.scale_size)) - y_path.append(str(-cfg.draw_start + -(gh.width - x - 1) * cfg.scale_size)) - - if egocentric_idx is not None: - ego_x, ego_y = agent_state.get_xy() - if self.check_in_radius(x, y, ego_x, ego_y, self.grid_config.obs_radius): - opacity.append('1.0') - else: - opacity.append(str(cfg.shaded_opacity)) - - visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]] - - agent.add_animation(self.compressed_anim('cy', y_path, cfg.time_scale)) - agent.add_animation(self.compressed_anim('cx', x_path, cfg.time_scale)) - agent.add_animation(self.compressed_anim('visibility', visibility, cfg.time_scale)) - if opacity: - agent.add_animation(self.compressed_anim('opacity', opacity, cfg.time_scale)) - - @classmethod - def compressed_anim(cls, attr_name, tokens, time_scale, rep_cnt='indefinite'): - """ - Compresses the animation. - :param attr_name: - :param tokens: - :param time_scale: - :param rep_cnt: - :return: - """ - tokens, times = cls.compress_tokens(tokens) - cumulative = [0, ] - for t in times: - cumulative.append(cumulative[-1] + t) - times = [str(round(value / cumulative[-1], 10)) for value in cumulative] - tokens = [tokens[0]] + tokens - - times = times - tokens = tokens - return Animation(attributeName=attr_name, - dur=f'{time_scale * (-1 + cumulative[-1])}s', - values=";".join(tokens), - repeatCount=rep_cnt, - keyTimes=";".join(times)) - - @staticmethod - def wisely_add(token, cnt, tokens, times): - """ - Adds the token to the tokens and times. - :param token: - :param cnt: - :param tokens: - :param times: - :return: - """ - if cnt > 1: - tokens += [token, token] - times += [1, cnt - 1] - else: - tokens.append(token) - times.append(cnt) - - @classmethod - def compress_tokens(cls, input_tokens: list): - """ - Compresses the tokens. - :param input_tokens: - :return: - """ - tokens = [] - times = [] - if input_tokens: - cur_idx = 0 - cnt = 1 - for idx in range(1, len(input_tokens)): - if input_tokens[idx] == input_tokens[cur_idx]: - cnt += 1 - else: - cls.wisely_add(input_tokens[cur_idx], cnt, tokens, times) - cnt = 1 - cur_idx = idx - cls.wisely_add(input_tokens[cur_idx], cnt, tokens, times) - return tokens, times - - def animate_targets(self, targets, grid_holder, animation_config): - """ - Animates the targets. - :param targets: - :param grid_holder: - :param animation_config: - :return: - """ - gh: GridHolder = grid_holder - cfg = self.svg_settings - ego_idx = animation_config.egocentric_idx - - for agent_idx, target in enumerate(targets): - target_idx = ego_idx if ego_idx is not None else agent_idx - - x_path = [] - y_path = [] - - for step_idx, state in enumerate(gh.history[target_idx]): - x, y = state.get_target_xy() - x_path.append(str(cfg.draw_start + y * cfg.scale_size)) - y_path.append(str(-cfg.draw_start + -(gh.width - x - 1) * cfg.scale_size)) - - visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]] - - if self.grid_config.on_target == 'restart': - target.add_animation(self.compressed_anim('cy', y_path, cfg.time_scale)) - target.add_animation(self.compressed_anim('cx', x_path, cfg.time_scale)) - target.add_animation(self.compressed_anim("visibility", visibility, cfg.time_scale)) - - def create_obstacles(self, grid_holder, animation_config): - """ - Creates the obstacles. - :param grid_holder: - :param animation_config: - :return: - """ - gh = grid_holder - cfg = self.svg_settings - - result = [] - r = self.grid_config.obs_radius - for i in range(gh.height): - for j in range(gh.width): - x, y = self.fix_point(i, j, gh.width) - if not animation_config.show_border: - if i == r - 1 or j == r - 1 or j == gh.width - r or i == gh.height - r: - continue - if gh.obstacles[x][y] != self.grid_config.FREE: - obs_settings = {} - obs_settings.update(x=cfg.draw_start + i * cfg.scale_size - cfg.r, - y=cfg.draw_start + j * cfg.scale_size - cfg.r, - width=cfg.r * 2, - height=cfg.r * 2, - rx=cfg.rx, - fill=self.svg_settings.obstacle_color) - - if animation_config.egocentric_idx is not None and cfg.egocentric_shaded: - initial_positions = [agent_states[0].get_xy() for agent_states in gh.history] - ego_x, ego_y = initial_positions[animation_config.egocentric_idx] - if not self.check_in_radius(x, y, ego_x, ego_y, self.grid_config.obs_radius): - obs_settings.update(opacity=cfg.shaded_opacity) - - result.append(Rectangle(**obs_settings)) - - return result - - def animate_obstacles(self, obstacles, grid_holder, animation_config): - """ - - :param obstacles: - :param grid_holder: - :param animation_config: - :return: - """ - gh: GridHolder = grid_holder - obstacle_idx = 0 - cfg = self.svg_settings - - for i in range(gh.height): - for j in range(gh.width): - x, y = self.fix_point(i, j, gh.width) - if gh.obstacles[x][y] == self.grid_config.FREE: - continue - opacity = [] - seen = set() - for step_idx, agent_state in enumerate(gh.history[animation_config.egocentric_idx]): - ego_x, ego_y = agent_state.get_xy() - if self.check_in_radius(x, y, ego_x, ego_y, self.grid_config.obs_radius): - seen.add((x, y)) - if (x, y) in seen: - opacity.append(str(1.0)) - else: - opacity.append(str(cfg.shaded_opacity)) - - obstacle = obstacles[obstacle_idx] - obstacle.add_animation(self.compressed_anim('opacity', opacity, cfg.time_scale)) - - obstacle_idx += 1 - - def create_agents(self, grid_holder, animation_config): - """ - Creates the agents. - :param grid_holder: - :param animation_config: - :return: - """ - gh: GridHolder = grid_holder - cfg = self.svg_settings - - agents = [] - initial_positions = [agent_states[0].get_xy() for agent_states in gh.history] - for idx, (x, y) in enumerate(initial_positions): - - if not any([agent_state.is_active() for agent_state in gh.history[idx]]): - continue - - circle_settings = {} - circle_settings.update(cx=cfg.draw_start + y * cfg.scale_size, - cy=cfg.draw_start + (gh.width - x - 1) * cfg.scale_size, - r=cfg.r, fill=gh.colors[idx]) - ego_idx = animation_config.egocentric_idx - if ego_idx is not None: - ego_x, ego_y = initial_positions[ego_idx] - if not self.check_in_radius(x, y, ego_x, ego_y, self.grid_config.obs_radius) and cfg.egocentric_shaded: - circle_settings.update(opacity=cfg.shaded_opacity) - if ego_idx == idx: - circle_settings.update(fill=self.svg_settings.ego_color) - else: - circle_settings.update(fill=self.svg_settings.ego_other_color) - agent = Circle(**circle_settings) - agents.append(agent) - - return agents - - def create_targets(self, grid_holder, animation_config): - """ - Creates the targets. - :param grid_holder: - :param animation_config: - :return: - """ - gh: GridHolder = grid_holder - cfg = self.svg_settings - targets = [] - for agent_idx, agent_states in enumerate(gh.history): - - tx, ty = agent_states[0].get_target_xy() - x, y = ty, gh.width - tx - 1 - - if not any([agent_state.is_active() for agent_state in gh.history[agent_idx]]): - continue - - circle_settings = {} - circle_settings.update(cx=cfg.draw_start + x * cfg.scale_size, - cy=cfg.draw_start + y * cfg.scale_size, - r=cfg.r, - stroke=gh.colors[agent_idx], stroke_width=cfg.stroke_width, fill='none') - if animation_config.egocentric_idx is not None: - if animation_config.egocentric_idx != agent_idx: - continue - - circle_settings.update(stroke=cfg.ego_color) - target = Circle(**circle_settings) - targets.append(target) - return targets - - -def main(): - grid_config = GridConfig(size=8, num_agents=5, obs_radius=2, seed=9, on_target='finish', max_episode_steps=128) - env = pogema_v0(grid_config=grid_config) - env = AnimationMonitor(env) - - env.reset() - done = [False] - - while not all(done): - _, _, done, _ = env.step(env.sample_actions()) - - env.save_animation('out-static.svg', AnimationConfig(static=True, save_every_idx_episode=None)) - env.save_animation('out-static-ego.svg', AnimationConfig(egocentric_idx=0, static=True)) - env.save_animation('out-static-no-agents.svg', AnimationConfig(show_agents=False, static=True)) - env.save_animation("out.svg") - env.save_animation("out-ego.svg", AnimationConfig(egocentric_idx=0)) - - -if __name__ == '__main__': - main() diff --git a/pogema/svg_animation/__init__.py b/pogema/svg_animation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pogema/svg_animation/animation_drawer.py b/pogema/svg_animation/animation_drawer.py new file mode 100644 index 0000000..46d0db5 --- /dev/null +++ b/pogema/svg_animation/animation_drawer.py @@ -0,0 +1,395 @@ +import math +import typing +from dataclasses import dataclass + +from pogema import GridConfig +from pogema.svg_animation.svg_objects import Line, RectangleHref, Animation, Circle, Rectangle + + +@dataclass +class AnimationConfig: + directory: str = 'renders/' + static: bool = False + show_agents: bool = True + egocentric_idx: typing.Optional[int] = None + uid: typing.Optional[str] = None + save_every_idx_episode: typing.Optional[int] = 1 + show_grid_lines: bool = True + + +@dataclass +class SvgSettings: + r: int = 35 + stroke_width: int = 10 + scale_size: int = 100 + time_scale: float = 0.25 + draw_start: int = 100 + rx: int = 15 + + obstacle_color: str = '#84A1AE' + ego_color: str = '#c1433c' + ego_other_color: str = '#6e81af' + shaded_opacity: float = 0.2 + egocentric_shaded: bool = True + stroke_dasharray: int = 25 + + colors: tuple = ( + '#c1433c', + '#2e6f9e', + '#6e81af', + '#00b9c8', + '#72D5C8', + '#0ea08c', + '#8F7B66', + ) + + +@dataclass +class GridHolder: + obstacles: typing.Any = None + episode_length: int = None + height: int = None + width: int = None + colors: dict = None + history: list = None + obs_radius: int = None + grid_config: GridConfig = None + on_target: str = None + config: AnimationConfig = None + svg_settings: SvgSettings = None + + +class Drawing: + + def __init__(self, height, width, svg_settings): + self.height = height + self.width = width + self.origin = (0, 0) + self.elements = [] + self.svg_settings = svg_settings + + def add_element(self, element): + self.elements.append(element) + + def render(self): + scale = max(self.height, self.width) / 512 + scaled_width = math.ceil(self.width / scale) + scaled_height = math.ceil(self.height / scale) + + dx, dy = self.origin + view_box = (dx, dy - self.height, self.width, self.height) + + svg_header = f''' + ''' + + definitions = f''' + + + ''' + + elements_svg = [svg_header, '', definitions, '\n'] + elements_svg.extend(element.render() for element in self.elements) + elements_svg.append('') + return "\n".join(elements_svg) + + +class AnimationDrawer: + + def __init__(self): + pass + + def create_animation(self, grid_holder: GridHolder): + gh = grid_holder + render_width = gh.height * gh.svg_settings.scale_size + gh.svg_settings.scale_size + render_height = gh.width * gh.svg_settings.scale_size + gh.svg_settings.scale_size + drawing = Drawing(width=render_width, height=render_height, svg_settings=SvgSettings()) + obstacles = self.create_obstacles(gh) + + agents = [] + targets = [] + + if gh.config.show_agents: + agents = self.create_agents(gh) + targets = self.create_targets(gh) + + if not gh.config.static: + self.animate_agents(agents, gh) + self.animate_targets(targets, gh) + if gh.config.show_grid_lines: + grid_lines = self.create_grid_lines(gh, render_width, render_height) + for line in grid_lines: + drawing.add_element(line) + for obj in [*obstacles, *agents, *targets]: + drawing.add_element(obj) + + if gh.config.egocentric_idx is not None: + field_of_view = self.create_field_of_view(grid_holder=gh) + if not gh.config.static: + self.animate_obstacles(obstacles=obstacles, grid_holder=gh) + self.animate_field_of_view(field_of_view, gh) + drawing.add_element(field_of_view) + + return drawing + + @staticmethod + def fix_point(x, y, length): + return length - y - 1, x + + @staticmethod + def check_in_radius(x1, y1, x2, y2, r) -> bool: + return x2 - r <= x1 <= x2 + r and y2 - r <= y1 <= y2 + r + + @staticmethod + def create_grid_lines(grid_holder: GridHolder, render_width, render_height): + gh = grid_holder + offset = 0 + stroke_settings = {'class': 'line'} + grid_lines = [] + for i in range(-1, grid_holder.height + 1): + x = i * gh.svg_settings.scale_size + gh.svg_settings.scale_size / 2 + grid_lines.append(Line(x1=x, y1=offset, x2=x, y2=render_height - offset, **stroke_settings)) + + for i in range(-1, grid_holder.width + 1): + y = i * gh.svg_settings.scale_size + gh.svg_settings.scale_size / 2 + grid_lines.append(Line(x1=offset, y1=y, x2=render_width - offset, y2=y, **stroke_settings)) + + return grid_lines + + @staticmethod + def create_field_of_view(grid_holder): + gh: GridHolder = grid_holder + ego_idx = gh.config.egocentric_idx + x, y = gh.history[ego_idx][0].get_xy() + cx = gh.svg_settings.draw_start + y * gh.svg_settings.scale_size + cy = gh.svg_settings.draw_start + (gh.width - x - 1) * gh.svg_settings.scale_size + + dr = (grid_holder.obs_radius + 1) * gh.svg_settings.scale_size - gh.svg_settings.stroke_width * 2 + result = Rectangle( + x=cx - dr + gh.svg_settings.r, y=cy - dr + gh.svg_settings.r, + width=2 * dr - 2 * gh.svg_settings.r, height=2 * dr - 2 * gh.svg_settings.r, + stroke=gh.svg_settings.ego_color, stroke_width=gh.svg_settings.stroke_width, + fill='none', rx=gh.svg_settings.rx, stroke_dasharray=gh.svg_settings.stroke_dasharray + ) + + return result + + def animate_field_of_view(self, view, grid_holder): + gh: GridHolder = grid_holder + x_path = [] + y_path = [] + ego_idx = grid_holder.config.egocentric_idx + for state in gh.history[ego_idx]: + x, y = state.get_xy() + dr = (grid_holder.obs_radius + 1) * gh.svg_settings.scale_size - gh.svg_settings.stroke_width * 2 + cx = gh.svg_settings.draw_start + y * gh.svg_settings.scale_size + cy = -gh.svg_settings.draw_start + -(gh.width - x - 1) * gh.svg_settings.scale_size + x_path.append(str(cx - dr + gh.svg_settings.r)) + y_path.append(str(cy - dr + gh.svg_settings.r)) + + visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[ego_idx]] + + view.add_animation(self.compressed_anim('x', x_path, gh.svg_settings.time_scale)) + view.add_animation(self.compressed_anim('y', y_path, gh.svg_settings.time_scale)) + view.add_animation(self.compressed_anim('visibility', visibility, gh.svg_settings.time_scale)) + + def animate_agents(self, agents, grid_holder): + gh: GridHolder = grid_holder + ego_idx = gh.config.egocentric_idx + + for agent_idx, agent in enumerate(agents): + x_path = [] + y_path = [] + opacity = [] + for idx, agent_state in enumerate(gh.history[agent_idx]): + x, y = agent_state.get_xy() + + x_path.append(str(gh.svg_settings.draw_start + y * gh.svg_settings.scale_size)) + y_path.append(str(-gh.svg_settings.draw_start + -(gh.width - x - 1) * gh.svg_settings.scale_size)) + + if ego_idx is not None: + ego_x, ego_y = gh.history[ego_idx][idx].get_xy() + if self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius): + opacity.append('1.0') + else: + opacity.append(str(gh.svg_settings.shaded_opacity)) + + visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]] + + agent.add_animation(self.compressed_anim('cy', y_path, gh.svg_settings.time_scale)) + agent.add_animation(self.compressed_anim('cx', x_path, gh.svg_settings.time_scale)) + agent.add_animation(self.compressed_anim('visibility', visibility, gh.svg_settings.time_scale)) + if opacity: + agent.add_animation(self.compressed_anim('opacity', opacity, gh.svg_settings.time_scale)) + + @classmethod + def compressed_anim(cls, attr_name, tokens, time_scale, rep_cnt='indefinite'): + tokens, times = cls.compress_tokens(tokens) + cumulative = [0, ] + for t in times: + cumulative.append(cumulative[-1] + t) + times = [str(round(value / cumulative[-1], 10)) for value in cumulative] + tokens = [tokens[0]] + tokens + + times = times + tokens = tokens + return Animation( + attributeName=attr_name, dur=f'{time_scale * (-1 + cumulative[-1])}s', + values=";".join(tokens), repeatCount=rep_cnt, keyTimes=";".join(times) + ) + + @staticmethod + def wisely_add(token, cnt, tokens, times): + if cnt > 1: + tokens += [token, token] + times += [1, cnt - 1] + else: + tokens.append(token) + times.append(cnt) + + @classmethod + def compress_tokens(cls, input_tokens: list): + tokens = [] + times = [] + if input_tokens: + cur_idx = 0 + cnt = 1 + for idx in range(1, len(input_tokens)): + if input_tokens[idx] == input_tokens[cur_idx]: + cnt += 1 + else: + cls.wisely_add(input_tokens[cur_idx], cnt, tokens, times) + cnt = 1 + cur_idx = idx + cls.wisely_add(input_tokens[cur_idx], cnt, tokens, times) + return tokens, times + + def animate_targets(self, targets, grid_holder): + gh: GridHolder = grid_holder + ego_idx = gh.config.egocentric_idx + + for agent_idx, target in enumerate(targets): + target_idx = ego_idx if ego_idx is not None else agent_idx + + x_path = [] + y_path = [] + + for step_idx, state in enumerate(gh.history[target_idx]): + x, y = state.get_target_xy() + x_path.append(str(gh.svg_settings.draw_start + y * gh.svg_settings.scale_size)) + y_path.append(str(-gh.svg_settings.draw_start + -(gh.width - x - 1) * gh.svg_settings.scale_size)) + + visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]] + + if gh.on_target == 'restart' or gh.on_target == 'wait': + target.add_animation(self.compressed_anim('cy', y_path, gh.svg_settings.time_scale)) + target.add_animation(self.compressed_anim('cx', x_path, gh.svg_settings.time_scale)) + target.add_animation(self.compressed_anim("visibility", visibility, gh.svg_settings.time_scale)) + + def create_obstacles(self, grid_holder): + gh = grid_holder + result = [] + + for i in range(gh.height): + for j in range(gh.width): + x, y = self.fix_point(i, j, gh.width) + + if gh.obstacles[x][y]: + obs_settings = {} + obs_settings.update( + x=gh.svg_settings.draw_start + i * gh.svg_settings.scale_size - gh.svg_settings.r, + y=gh.svg_settings.draw_start + j * gh.svg_settings.scale_size - gh.svg_settings.r, + height=gh.svg_settings.r * 2, + ) + + if gh.config.egocentric_idx is not None and gh.svg_settings.egocentric_shaded: + initial_positions = [agent_states[0].get_xy() for agent_states in gh.history] + ego_x, ego_y = initial_positions[gh.config.egocentric_idx] + if not self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius): + obs_settings.update(opacity=gh.svg_settings.shaded_opacity) + + result.append(RectangleHref(**obs_settings)) + + return result + + def animate_obstacles(self, obstacles, grid_holder): + gh: GridHolder = grid_holder + obstacle_idx = 0 + + for i in range(gh.height): + for j in range(gh.width): + x, y = self.fix_point(i, j, gh.width) + if not gh.obstacles[x][y]: + continue + opacity = [] + seen = set() + for step_idx, agent_state in enumerate(gh.history[gh.config.egocentric_idx]): + ego_x, ego_y = agent_state.get_xy() + if self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius): + seen.add((x, y)) + if (x, y) in seen: + opacity.append(str(1.0)) + else: + opacity.append(str(gh.svg_settings.shaded_opacity)) + + obstacle = obstacles[obstacle_idx] + obstacle.add_animation(self.compressed_anim('opacity', opacity, gh.svg_settings.time_scale)) + + obstacle_idx += 1 + + def create_agents(self, grid_holder): + initial_positions = [state[0].get_xy() for state in grid_holder.history if state[0].is_active()] + agents = [] + gh: GridHolder = grid_holder + ego_idx = grid_holder.config.egocentric_idx + + for idx, (x, y) in enumerate(initial_positions): + circle_settings = { + 'cx': gh.svg_settings.draw_start + y * gh.svg_settings.scale_size, + 'cy': gh.svg_settings.draw_start + (grid_holder.width - x - 1) * gh.svg_settings.scale_size, + 'r': gh.svg_settings.r, 'fill': grid_holder.colors[idx], 'class': 'agent', + } + + if ego_idx is not None: + ego_x, ego_y = initial_positions[ego_idx] + is_out_of_radius = not self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius) + circle_settings['fill'] = gh.svg_settings.ego_other_color + if idx == ego_idx: + circle_settings['fill'] = gh.svg_settings.ego_color + elif is_out_of_radius and gh.svg_settings.egocentric_shaded: + circle_settings['opacity'] = gh.svg_settings.shaded_opacity + + agents.append(Circle(**circle_settings)) + + return agents + + @staticmethod + def create_targets(grid_holder): + gh: GridHolder = grid_holder + targets = [] + for agent_idx, agent_states in enumerate(gh.history): + + tx, ty = agent_states[0].get_target_xy() + x, y = ty, gh.width - tx - 1 + + if not any([agent_state.is_active() for agent_state in gh.history[agent_idx]]): + continue + + circle_settings = {"class": 'target'} + circle_settings.update( + cx=gh.svg_settings.draw_start + x * gh.svg_settings.scale_size, r=gh.svg_settings.r, + cy=gh.svg_settings.draw_start + y * gh.svg_settings.scale_size, stroke=gh.colors[agent_idx], + ) + + if gh.config.egocentric_idx is not None: + if gh.config.egocentric_idx != agent_idx: + continue + + circle_settings.update(stroke=gh.svg_settings.ego_color) + target = Circle(**circle_settings) + targets.append(target) + return targets diff --git a/pogema/svg_animation/animation_wrapper.py b/pogema/svg_animation/animation_wrapper.py new file mode 100644 index 0000000..c87aff1 --- /dev/null +++ b/pogema/svg_animation/animation_wrapper.py @@ -0,0 +1,172 @@ +import os +from itertools import cycle +from gymnasium import logger, Wrapper + +from pogema import GridConfig +from pogema.svg_animation.animation_drawer import AnimationConfig, SvgSettings, GridHolder, AnimationDrawer +from pogema.wrappers.persistence import PersistentWrapper, AgentState + + +class AnimationMonitor(Wrapper): + """ + Defines the animation, which saves the episode as SVG. + """ + + def __init__(self, env, animation_config=AnimationConfig()): + self._working_radius = env.grid_config.obs_radius - 1 + env = PersistentWrapper(env, xy_offset=-self._working_radius) + + super().__init__(env) + + self.history = env.get_history() + + self.svg_settings: SvgSettings = SvgSettings() + self.animation_config: AnimationConfig = animation_config + + self._episode_idx = 0 + + def step(self, action): + """ + Saves information about the episode. + :param action: current actions + :return: obs, reward, done, info + """ + obs, reward, terminated, truncated, info = self.env.step(action) + + multi_agent_terminated = isinstance(terminated, (list, tuple)) and all(terminated) + single_agent_terminated = isinstance(terminated, (bool, int)) and terminated + multi_agent_truncated = isinstance(truncated, (list, tuple)) and all(truncated) + single_agent_truncated = isinstance(truncated, (bool, int)) and truncated + + if multi_agent_terminated or single_agent_terminated or multi_agent_truncated or single_agent_truncated: + save_tau = self.animation_config.save_every_idx_episode + if save_tau: + if (self._episode_idx + 1) % save_tau or save_tau == 1: + if not os.path.exists(self.animation_config.directory): + logger.info(f"Creating pogema monitor directory {self.animation_config.directory}", ) + os.makedirs(self.animation_config.directory, exist_ok=True) + + path = os.path.join(self.animation_config.directory, + self.pick_name(self.grid_config, self._episode_idx)) + self.save_animation(path) + + return obs, reward, terminated, truncated, info + + @staticmethod + def pick_name(grid_config: GridConfig, episode_idx=None, zfill_ep=5): + """ + Picks a name for the SVG file. + :param grid_config: configuration of the grid + :param episode_idx: idx of the episode + :param zfill_ep: zfill for the episode number + :return: + """ + gc = grid_config + name = 'pogema' + if episode_idx is not None: + name += f'-ep{str(episode_idx).zfill(zfill_ep)}' + if gc: + if gc.map_name: + name += f'-{gc.map_name}' + if gc.seed is not None: + name += f'-seed{gc.seed}' + else: + name += '-render' + return name + '.svg' + + def reset(self, **kwargs): + """ + Resets the environment and resets the current positions of agents and targets + :param kwargs: + :return: obs: observation + """ + obs = self.env.reset(**kwargs) + + self._episode_idx += 1 + self.history = self.env.get_history() + + return obs + + def save_animation(self, name='render.svg', animation_config: AnimationConfig = AnimationConfig()): + """ + Saves the animation. + :param name: name of the file + :param animation_config: animation configuration + :return: None + """ + wr = self._working_radius + obstacles = self.env.get_obstacles(ignore_borders=False)[wr:-wr, wr:-wr] + history: list[list[AgentState]] = self.env.decompress_history(self.history) + + svg_settings = SvgSettings() + colors_cycle = cycle(svg_settings.colors) + agents_colors = {index: next(colors_cycle) for index in range(self.grid_config.num_agents)} + + for agent_idx in range(self.grid_config.num_agents): + history[agent_idx].append(history[agent_idx][-1]) + + episode_length = len(history[0]) + # Change episode length for egocentric environment + if animation_config.egocentric_idx is not None and self.grid_config.on_target == 'finish': + episode_length = history[animation_config.egocentric_idx][-1].step + 1 + for agent_idx in range(self.grid_config.num_agents): + history[agent_idx] = history[agent_idx][:episode_length] + + grid_holder = GridHolder( + width=len(obstacles), height=len(obstacles[0]), + obstacles=obstacles, + episode_length=episode_length, + history=history, + obs_radius=self.grid_config.obs_radius, + on_target=self.grid_config.on_target, + colors=agents_colors, + config=animation_config, + svg_settings=svg_settings + ) + + animation = AnimationDrawer().create_animation(grid_holder) + with open(name, "w") as f: + f.write(animation.render()) + + +def main(): + from pogema import GridConfig, pogema_v0, AnimationMonitor, BatchAStarAgent, AnimationConfig + + for egocentric_idx in [0, 1]: + for on_target in ['nothing', 'restart', 'finish']: + grid = """ + ....#.. + ..#.... + ....... + ....... + #.#.#.. + #.#.#.. + """ + grid_config = GridConfig(size=32, num_agents=2, obs_radius=2, seed=8, on_target=on_target, + max_episode_steps=16, + density=0.1, map=grid, observation_type="POMAPF") + env = pogema_v0(grid_config=grid_config) + env = AnimationMonitor(env, AnimationConfig(save_every_idx_episode=None)) + + obs, _ = env.reset() + truncated = terminated = [False] + + agent = BatchAStarAgent() + while not all(terminated) and not all(truncated): + obs, _, terminated, truncated, _ = env.step(agent.act(obs)) + + anim_folder = 'renders' + if not os.path.exists(anim_folder): + os.makedirs(anim_folder) + + env.save_animation(f'{anim_folder}/anim-{on_target}.svg') + env.save_animation(f'{anim_folder}/anim-{on_target}-ego-{egocentric_idx}.svg', + AnimationConfig(egocentric_idx=egocentric_idx)) + env.save_animation(f'{anim_folder}/anim-static.svg', AnimationConfig(static=True)) + env.save_animation(f'{anim_folder}/anim-static-ego.svg', AnimationConfig(egocentric_idx=0, static=True)) + env.save_animation(f'{anim_folder}/anim-static-no-agents.svg', + AnimationConfig(show_agents=False, static=True)) + + +if __name__ == '__main__': + main() diff --git a/pogema/svg_animation/svg_objects.py b/pogema/svg_animation/svg_objects.py new file mode 100644 index 0000000..c18c13f --- /dev/null +++ b/pogema/svg_animation/svg_objects.py @@ -0,0 +1,77 @@ +class SvgObject: + tag = None + + def __init__(self, **kwargs): + self.attributes = kwargs + self.animations = [] + + def add_animation(self, animation): + self.animations.append(animation) + + @staticmethod + def render_attributes(attributes): + result = " ".join([f'{x.replace("_", "-")}="{y}"' for x, y in sorted(attributes.items())]) + return result + + def render(self): + animations = '\n'.join([a.render() for a in self.animations]) if self.animations else None + if animations: + return f"<{self.tag} {self.render_attributes(self.attributes)}> {animations} " + return f"<{self.tag} {self.render_attributes(self.attributes)} />" + + +class Rectangle(SvgObject): + """ + Rectangle class for the SVG. + """ + tag = 'rect' + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.attributes['y'] = -self.attributes['y'] - self.attributes['height'] + + +class RectangleHref(SvgObject): + """ + Rectangle class for the SVG. + """ + tag = 'use' + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.attributes['y'] = -self.attributes['y'] - self.attributes['height'] + self.attributes['href'] = "#obstacle" + del self.attributes['height'] + + +class Circle(SvgObject): + """ + Circle class for the SVG. + """ + tag = 'circle' + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.attributes['cy'] = -self.attributes['cy'] + + +class Line(SvgObject): + """ + Line class for the SVG. + """ + tag = 'line' + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.attributes['y1'] = -self.attributes['y1'] + self.attributes['y2'] = -self.attributes['y2'] + + +class Animation(SvgObject): + """ + Animation class for the SVG. + """ + tag = 'animate' + + def render(self): + return f"<{self.tag} {self.render_attributes(self.attributes)}/>" diff --git a/pogema/wrappers/persistence.py b/pogema/wrappers/persistence.py index 56634b2..3bd51f4 100644 --- a/pogema/wrappers/persistence.py +++ b/pogema/wrappers/persistence.py @@ -26,12 +26,16 @@ def __eq__(self, other): o = other return self.x == o.x and self.y == o.y and self.tx == o.tx and self.ty == o.ty and self.active == o.active + def __str__(self): + return str([self.x, self.y, self.tx, self.ty, self.step, self.active]) + class PersistentWrapper(Wrapper): - def __init__(self, env): + def __init__(self, env, xy_offset=None): super().__init__(env) self._step = None self._agent_states = None + self._xy_offset = xy_offset def step(self, action): result = self.env.step(action) @@ -67,6 +71,11 @@ def _get_agent_state(self, grid, agent_idx): x, y = grid.positions_xy[agent_idx] tx, ty = grid.finishes_xy[agent_idx] active = grid.is_active[agent_idx] + if self._xy_offset: + x += self._xy_offset + y += self._xy_offset + tx += self._xy_offset + ty += self._xy_offset return AgentState(x, y, tx, ty, self._step, active) def reset(self, **kwargs): diff --git a/tests/test_deterministic_policy.py b/tests/test_deterministic_policy.py index f456c0d..4f79ddf 100644 --- a/tests/test_deterministic_policy.py +++ b/tests/test_deterministic_policy.py @@ -1,7 +1,8 @@ import numpy as np from heapq import heappop, heappush -from pogema import GridConfig, pogema_v0 -from pogema.animation import AnimationMonitor +from pogema import GridConfig, pogema_v0, AnimationMonitor + +# from pogema.animation import AnimationMonitor INF = 1000000007 diff --git a/tests/test_pogema_env.py b/tests/test_pogema_env.py index 1ccad67..7b5c32f 100644 --- a/tests/test_pogema_env.py +++ b/tests/test_pogema_env.py @@ -4,13 +4,12 @@ import numpy as np from tabulate import tabulate -from pogema import pogema_v0 +from pogema import pogema_v0, AnimationMonitor from pogema import Easy8x8, Normal8x8, Hard8x8, ExtraHard8x8 from pogema import Easy16x16, Normal16x16, Hard16x16, ExtraHard16x16 from pogema import Easy32x32, Normal32x32, Hard32x32, ExtraHard32x32 from pogema import Easy64x64, Normal64x64, Hard64x64, ExtraHard64x64 -from pogema.animation import AnimationMonitor from pogema.envs import ActionsSampler from pogema.grid import GridConfig @@ -122,7 +121,8 @@ def test_standard_pogema_animation(): def test_gym_pogema_animation(): import gymnasium env = gymnasium.make('Pogema-v0', - grid_config=GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='finish')) + grid_config=GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, + on_target='finish')) env = AnimationMonitor(env) env.reset() done = False @@ -131,6 +131,7 @@ def test_gym_pogema_animation(): if terminated or truncated: break + def test_non_disappearing_pogema(): env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='nothing')) env.reset()