Skip to content

Commit

Permalink
Add render_mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadene committed Apr 8, 2024
1 parent ec72008 commit c636f05
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions gym_aloha/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@


class AlohaEnv(gym.Env):
metadata = {"render_modes": [], "render_fps": 50}
# TODO(aliberts): add "human" render_mode
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}

def __init__(
self,
task,
obs_type="pixels",
render_mode="rgb_array",
observation_width=640,
observation_height=480,
visualization_width=640,
Expand All @@ -33,6 +35,7 @@ def __init__(
super().__init__()
self.task = task
self.obs_type = obs_type
self.render_mode = render_mode
self.observation_width = observation_width
self.observation_height = observation_height
self.visualization_width = visualization_width
Expand Down Expand Up @@ -82,14 +85,23 @@ def __init__(

self.action_space = spaces.Box(low=-1, high=1, shape=(len(ACTIONS),), dtype=np.float32)

def render(self, mode="rgb_array"):
def render(self):
return self._render(visualize=True)

def _render(self, visualize=False):
assert self.render_mode == "rgb_array"
width, height = (
(self.visualization_width, self.visualization_height)
if visualize
else (self.observation_width, self.observation_height)
)
# if mode in ["visualize", "human"]:
# height, width = self.visualize_height, self.visualize_width
# elif mode == "rgb_array":
# height, width = self.observation_height, self.observation_width
# else:
# raise ValueError(mode)
# TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
if mode in ["visualize", "human"]:
height, width = self.visualize_height, self.visualize_width
elif mode == "rgb_array":
height, width = self.observation_height, self.observation_width
else:
raise ValueError(mode)
image = self._env.physics.render(height=height, width=width, camera_id="top")
return image

Expand Down

0 comments on commit c636f05

Please sign in to comment.