diff --git a/vmas/scenarios/joint_passage_size.py b/vmas/scenarios/joint_passage_size.py index 4fa6780e..4fc72d35 100644 --- a/vmas/scenarios/joint_passage_size.py +++ b/vmas/scenarios/joint_passage_size.py @@ -70,6 +70,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): assert self.n_passages == 3 or self.n_passages == 4 + # NOTE: I changed how our VMAS viewer works s.t. it requires a world_semidim, my bad + self.world_semidim = 1.0 + self.plot_grid = False # Make world @@ -100,7 +103,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): self.n_boxes = int(self.scenario_length // self.passage_length) self.min_collision_distance = 0.005 - cotnroller_params = [2.0, 10, 0.00001] + controller_params = [1.0, 0.0, 0.0] # Add agents agent = Agent( @@ -112,7 +115,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): f_range=10, ) agent.controller = VelocityController( - agent, world, cotnroller_params, "standard" + agent, world, controller_params, "standard" ) world.add_agent(agent) agent = Agent( @@ -126,7 +129,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): f_range=10, ) agent.controller = VelocityController( - agent, world, cotnroller_params, "standard" + agent, world, controller_params, "standard" ) world.add_agent(agent) @@ -510,6 +513,20 @@ def observation(self, agent: Agent): angle_to_vector(self.goal.state.rot), ] + ([angle_to_vector(joint_angle)] if self.observe_joint_angle else []) + # NOTE: I add the size capability in "mixed" form manually here: + radius = agent.shape.radius + mean_radius = (self.agent_radius + self.agent_radius_2) / 2 + relative_radius = agent.shape.radius - mean_radius + capability_repr = [ + torch.tensor( + radius, device=self.world.device + ).repeat(self.world.batch_dim, 1), + torch.tensor( + relative_radius, device=self.world.device + ).repeat(self.world.batch_dim, 1), + ] + observations += capability_repr + if self.obs_noise > 0: for i, obs in enumerate(observations): noise = torch.zeros( @@ -526,25 +543,35 @@ def observation(self, agent: Agent): ) def done(self): - return torch.all( - ( - torch.linalg.vector_norm( - self.joint.landmark.state.pos - self.goal.state.pos, dim=1 - ) - <= 0.01 - ) - * ( - get_line_angle_dist_0_180( - self.joint.landmark.state.rot, self.goal.state.rot - ).unsqueeze(-1) - <= 0.01 - ), - dim=1, - ) + # reimplementation + pos_done = torch.linalg.vector_norm(self.joint.landmark.state.pos - self.goal.state.pos, dim=1) <= 0.01 + rot_done = get_line_angle_dist_0_180(self.joint.landmark.state.rot, self.goal.state.rot) <= 0.01 + done = torch.logical_and(pos_done, rot_done) + return done + + # original + # return torch.all( + # ( + # torch.linalg.vector_norm( + # self.joint.landmark.state.pos - self.goal.state.pos, dim=1 + # ) + # <= 0.01 + # ) + # * ( + # get_line_angle_dist_0_180( + # self.joint.landmark.state.rot, self.goal.state.rot + # ) # .unsqueeze(-1) + # <= 0.01 + # ), + # dim=1, + # ) def info(self, agent: Agent) -> Dict[str, Tensor]: is_first = self.world.agents[0] == agent if is_first: + dist_to_goal = torch.linalg.vector_norm(self.joint.landmark.state.pos - self.goal.state.pos, dim=1) + rot_to_goal = get_line_angle_dist_0_180(self.joint.landmark.state.rot, self.goal.state.rot) + just_passed = self.all_passed * (self.passed == 0) self.passed[just_passed] = 100 self.info_stored = { @@ -553,6 +580,9 @@ def info(self, agent: Agent) -> Dict[str, Tensor]: "collision_rew": self.collision_rew, "energy_rew": self.energy_rew, "passed": just_passed.to(torch.int), + "success_rate": self.done(), # wandb averages over n_envs + "dist_to_goal": dist_to_goal, + "rot_to_goal": rot_to_goal, } return self.info_stored