diff --git a/vmas/scenarios/balance.py b/vmas/scenarios/balance.py
index dc7fa410..8fa753d0 100644
--- a/vmas/scenarios/balance.py
+++ b/vmas/scenarios/balance.py
@@ -2,6 +2,8 @@
 #  ProrokLab (https://www.proroklab.org/)
 #  All rights reserved.
+import numpy as np
+import random
 import torch
 
 from vmas import render_interactively
@@ -12,27 +14,110 @@
 class Scenario(BaseScenario):
+    def get_rng_state(self, device):
+        """
+        Return a tuple of the form
+        (numpy random state, python random state, torch random state, torch.cuda random state).
+        """
+        np_rng_state = np.random.get_state()
+        py_rng_state = random.getstate()
+        torch_rng_state = torch.get_rng_state()
+        # torch.cuda has no state to capture on CPU-only machines
+        torch_cuda_rng_state = (
+            torch.cuda.get_rng_state(device) if torch.cuda.is_available() else None
+        )
+
+        return (np_rng_state, py_rng_state, torch_rng_state, torch_cuda_rng_state)
+
+    def set_eval_seed(self, eval_seed):
+        """
+        Seed numpy, python's random, torch, and torch.cuda with eval_seed.
+
+        Intended to be used only with eval_seed and wrapped by get_rng_state()/set_rng_state().
+        """
+        torch.manual_seed(eval_seed)
+        torch.cuda.manual_seed_all(eval_seed)
+        random.seed(eval_seed)
+        np.random.seed(eval_seed)
+
+    def set_rng_state(self, old_rng_state, device):
+        """
+        Restore the prior RNG state (based on the return value of get_rng_state()).
+        """
+        assert (
+            old_rng_state is not None
+        ), "set_rng_state() must be called with the return value of get_rng_state()!"
+
+        np_rng_state, py_rng_state, torch_rng_state, torch_cuda_rng_state = old_rng_state
+
+        np.random.set_state(np_rng_state)
+        random.setstate(py_rng_state)
+        torch.set_rng_state(torch_rng_state)
+        if torch_cuda_rng_state is not None:
+            torch.cuda.set_rng_state(torch_cuda_rng_state, device)
+
     def make_world(self, batch_dim: int, device: torch.device, **kwargs):
         self.n_agents = kwargs.get("n_agents", 3)
         self.package_mass = kwargs.get("package_mass", 5)
         self.random_package_pos_on_line = kwargs.get("random_package_pos_on_line", True)
+        self.world_semidim = kwargs.get("world_semidim", 1.0)
+        self.gravity = kwargs.get("gravity", -0.05)
+        self.eval_seed = kwargs.get("eval_seed", None)
+
+        # capabilities
+        self.capability_mult_range = kwargs.get("capability_mult_range", [0.75, 1.25])
+        self.multiple_ranges = kwargs.get("multiple_ranges", False)
+        if not self.multiple_ranges:
+            self.capability_mult_min = self.capability_mult_range[0]
+            self.capability_mult_max = self.capability_mult_range[1]
+        self.capability_representation = kwargs.get("capability_representation", "raw")
+        self.default_u_multiplier = kwargs.get("default_u_multiplier", 0.7)
+        self.default_agent_radius = kwargs.get("default_agent_radius", 0.03)
+        self.default_agent_mass = kwargs.get("default_agent_mass", 1)
+
+        # metrics
+        self.success_rate = None
+
+        # rng: save the incoming RNG state, then seed deterministically for evaluation
+        rng_state = None
+        if self.eval_seed is not None:
+            rng_state = self.get_rng_state(device)
+            self.set_eval_seed(self.eval_seed)
 
         assert self.n_agents > 1
 
         self.line_length = 0.8
-        self.agent_radius = 0.03
-        self.shaping_factor = 100
-        self.fall_reward = -10
+        self.shaping_factor = 1
+        self.fall_reward = -0.1
 
         # Make world
-        world = World(batch_dim, device, gravity=(0.0, -0.05), y_semidim=1)
+        world = World(batch_dim, device, gravity=(0.0, self.gravity), y_semidim=self.world_semidim)
         # Add agents
+        capabilities = []  # save capabilities for relative capabilities later
         for i in range(self.n_agents):
+            # sample per-agent capability multipliers, optionally re-drawing the range per capability
+            if self.multiple_ranges:
+                cap_idx = int(random.choice(np.arange(len(self.capability_mult_range))))
+                self.capability_mult_min = self.capability_mult_range[cap_idx][0]
+                self.capability_mult_max = self.capability_mult_range[cap_idx][1]
HERE") + max_u = self.default_u_multiplier * random.uniform(self.capability_mult_min, self.capability_mult_max) + if self.multiple_ranges: + cap_idx = int(random.choice(np.arange(len(self.capability_mult_range)))) + self.capability_mult_min = self.capability_mult_range[cap_idx][0] + self.capability_mult_max = self.capability_mult_range[cap_idx][1] + radius = self.default_agent_radius * random.uniform(self.capability_mult_min, self.capability_mult_max) + if self.multiple_ranges: + cap_idx = int(random.choice(np.arange(len(self.capability_mult_range)))) + self.capability_mult_min = self.capability_mult_range[cap_idx][0] + self.capability_mult_max = self.capability_mult_range[cap_idx][1] + mass = self.default_agent_mass * random.uniform(self.capability_mult_min, self.capability_mult_max) + agent = Agent( - name=f"agent_{i}", shape=Sphere(self.agent_radius), u_multiplier=0.7 + name=f"agent_{i}", + shape=Sphere(radius), + u_multiplier=max_u, + mass=mass, + render_action=True, ) + capabilities.append([max_u, agent.shape.radius, agent.mass]) world.add_agent(agent) + self.capabilities = torch.tensor(capabilities) + print(self.capabilities) goal = Landmark( name="goal", @@ -75,9 +160,46 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): self.pos_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32) self.ground_rew = self.pos_rew.clone() + if self.eval_seed: + self.set_rng_state(rng_state, device) + return world def reset_world_at(self, env_index: int = None): + rng_state = None + if self.eval_seed: + rng_state = self.get_rng_state(self.world.device) + self.set_eval_seed(self.eval_seed) + + # reset capabilities, only do this during batched resets! + if not env_index: + capabilities = [] # save capabilities for relative capabilities later + for agent in self.world.agents: + if self.multiple_ranges: + cap_idx = int(random.choice(np.arange(len(self.capability_mult_range)))) + self.capability_mult_min = self.capability_mult_range[cap_idx][0] + self.capability_mult_max = self.capability_mult_range[cap_idx][1] + max_u = self.default_u_multiplier * random.uniform(self.capability_mult_min, self.capability_mult_max) + if self.multiple_ranges: + cap_idx = int(random.choice(np.arange(len(self.capability_mult_range)))) + self.capability_mult_min = self.capability_mult_range[cap_idx][0] + self.capability_mult_max = self.capability_mult_range[cap_idx][1] + radius = self.default_agent_radius * random.uniform(self.capability_mult_min, self.capability_mult_max) + if self.multiple_ranges: + cap_idx = int(random.choice(np.arange(len(self.capability_mult_range)))) + self.capability_mult_min = self.capability_mult_range[cap_idx][0] + self.capability_mult_max = self.capability_mult_range[cap_idx][1] + mass = self.default_agent_mass * random.uniform(self.capability_mult_min, self.capability_mult_max) + + # capabilities.append([max_u, agent.shape.radius, agent.mass]) + capabilities.append([max_u, radius, mass]) + + agent.u_multiplier=max_u + agent.shape=Sphere(radius) + agent.mass=mass + + self.capabilities = torch.tensor(capabilities) + goal_pos = torch.cat( [ torch.zeros( @@ -111,7 +233,8 @@ def reset_world_at(self, env_index: int = None): ), torch.full( (1, 1) if env_index is not None else (self.world.batch_dim, 1), - -self.world.y_semidim + self.agent_radius * 2, + -self.world.y_semidim + self.default_agent_radius * self.capability_mult_max * 2 if not self.multiple_ranges else \ + -self.world.y_semidim + self.default_agent_radius * self.capability_mult_range[-1][1] * 2, 
                     device=self.world.device,
                     dtype=torch.float32,
                 ),
@@ -151,7 +274,7 @@ def reset_world_at(self, env_index: int = None):
                         + i * (self.line_length - agent.shape.radius) / (self.n_agents - 1),
-                        -self.agent_radius * 2,
+                        -agent.shape.radius * 2,
                     ],
                     device=self.world.device,
                     dtype=torch.float32,
@@ -182,7 +305,10 @@ def reset_world_at(self, env_index: int = None):
                     0,
                     -self.world.y_semidim
                     - self.floor.shape.width / 2
-                    - self.agent_radius,
+                    - (
+                        self.default_agent_radius * self.capability_mult_max if not self.multiple_ranges
+                        else self.default_agent_radius * self.capability_mult_range[-1][1]
+                    ),
                 ],
                 device=self.world.device,
             ),
@@ -205,6 +331,9 @@ def reset_world_at(self, env_index: int = None):
             * self.shaping_factor
         )
 
+        if self.eval_seed is not None:
+            self.set_rng_state(rng_state, self.world.device)
+
     def compute_on_the_ground(self):
         self.on_the_ground = self.world.is_overlapping(
             self.line, self.floor
         )
@@ -230,8 +359,59 @@ def reward(self, agent: Agent):
 
         return self.ground_rew + self.pos_rew
 
+    def get_capability_repr(self, agent: Agent):
+        """
+        Get the capability representation:
+        raw = raw multiplier values
+        relative = zero-meaned (relative to the mean of the team)
+        mixed = raw + relative (concatenated)
+        """
+        # agent's raw capabilities
+        max_u = agent.u_multiplier
+        radius = agent.shape.radius
+        mass = agent.mass
+
+        # compute the mean capabilities across the team's agents,
+        # then compute the "relative capability" of this agent by subtracting the mean
+        team_mean = list(torch.mean(self.capabilities, dim=0))
+        rel_max_u = max_u - team_mean[0].item()
+        rel_radius = radius - team_mean[1].item()
+        rel_mass = mass - team_mean[2].item()
+
+        raw_capability_repr = [
+            torch.tensor(
+                max_u, device=self.world.device
+            ).repeat(self.world.batch_dim, 1),
+            torch.tensor(
+                radius, device=self.world.device
+            ).repeat(self.world.batch_dim, 1),
+            torch.tensor(
+                mass, device=self.world.device
+            ).repeat(self.world.batch_dim, 1),
+        ]
+
+        rel_capability_repr = [
+            torch.tensor(
+                rel_max_u, device=self.world.device
+            ).repeat(self.world.batch_dim, 1),
+            torch.tensor(
+                rel_radius, device=self.world.device
+            ).repeat(self.world.batch_dim, 1),
+            torch.tensor(
+                rel_mass, device=self.world.device
+            ).repeat(self.world.batch_dim, 1),
+        ]
+
+        if self.capability_representation == "raw":
+            return raw_capability_repr
+        elif self.capability_representation == "relative":
+            return rel_capability_repr
+        elif self.capability_representation == "mixed":
+            return raw_capability_repr + rel_capability_repr
+        raise ValueError(f"Unknown capability_representation: {self.capability_representation}")
+
     def observation(self, agent: Agent):
         # get positions of all entities in this agent's reference frame
+        capability_repr = self.get_capability_repr(agent)
         return torch.cat(
             [
                 agent.state.pos,
@@ -243,7 +423,7 @@ def observation(self, agent: Agent):
                 self.line.state.vel,
                 self.line.state.ang_vel,
                 self.line.state.rot % torch.pi,
-            ],
+            ] + capability_repr,
             dim=-1,
         )
@@ -253,7 +433,10 @@ def done(self):
         )
 
     def info(self, agent: Agent):
-        info = {"pos_rew": self.pos_rew, "ground_rew": self.ground_rew}
+        self.success_rate = self.world.is_overlapping(
+            self.package, self.package.goal
+        )
+        info = {"pos_rew": self.pos_rew, "ground_rew": self.ground_rew, "success_rate": self.success_rate}
         return info
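Below is a minimal usage sketch (not part of the patch) of the new scenario kwargs. It assumes the standard vmas.make_env entry point, which forwards extra keyword arguments to Scenario.make_world; all values are illustrative.

# Minimal usage sketch (assumption: vmas.make_env forwards these kwargs to make_world).
from vmas import make_env

env = make_env(
    scenario="balance",
    num_envs=32,
    device="cpu",
    n_agents=3,
    eval_seed=42,                        # fixes the RNG state around make_world()/reset_world_at()
    capability_mult_range=[0.75, 1.25],  # single [min, max] range (multiple_ranges=False)
    capability_representation="mixed",   # "raw", "relative", or "mixed"
    default_u_multiplier=0.7,
    default_agent_radius=0.03,
    default_agent_mass=1,
)
obs = env.reset()  # each agent's observation now ends with its capability representation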