vmas/scenarios/transport.py: 31 changes (20 additions, 11 deletions)
@@ -40,6 +40,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
         self.package_length = kwargs.get("package_length", 0.4)
         self.package_rotatable = kwargs.get("package_rotatable", True)
         self.package_mass = kwargs.get("package_mass", 3)
+        # how far away the packages can spawn from the goal
+        self.min_pkg_goal_spawn_dist = kwargs.get("min_pkg_goal_spawn_dist", 0.01)
 
         # partial obs
         self.partial_observations = kwargs.get("partial_observations", False)
@@ -52,15 +54,16 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
         # TODO: implement automated domain randomization here?
 
         # rewards
-        self.agent_package_dist_reward_factor = kwargs.get("agent_package_dist_reward_factor", 0.1)
-        self.package_goal_dist_reward_factor = kwargs.get("package_goal_dist_reward_factor", 100)
+        self.agent_package_dist_reward_factor = kwargs.get("agent_package_dist_reward_factor", 0)
+        self.package_goal_dist_reward_factor = kwargs.get("package_goal_dist_reward_factor", 0)
+        self.agent_near_pkg_rew_factor = kwargs.get("agent_near_pkg_rew_factor", 0)
 
         self.min_collision_distance = 0.05 * self.default_agent_radius  # default navigation collision dist is 5% of the agent radius
-        self.interagent_collision_penalty = kwargs.get("interagent_collision_penalty", -1)
+        self.interagent_collision_penalty = kwargs.get("interagent_collision_penalty", 0)
         assert self.interagent_collision_penalty <= 0, f"self.interagent_collision_penalty must be <= 0, current value is {self.interagent_collision_penalty}!"
 
         self.add_dense_reward = kwargs.get("add_dense_reward", True)
-        self.package_on_goal_reward_factor = kwargs.get("package_on_goal_reward_factor", 1.0)
+        self.package_on_goal_reward_factor = kwargs.get("package_on_goal_reward_factor", 0.0)
         self.agent_touching_package_reward_factor = kwargs.get("agent_touching_package_reward_factor", 0.0)
         self.time_penalty = kwargs.get("time_penalty", 0.0)
 
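With these new defaults every shaping factor starts at 0, so the dense terms become opt-in. A minimal sketch of how a caller might re-enable them, assuming the standard vmas.make_env entry point, which forwards extra kwargs to make_world (the kwarg names come from this diff; the numeric values are illustrative only):

    import vmas

    env = vmas.make_env(
        scenario="transport",
        num_envs=32,
        device="cpu",
        # new kwarg in this diff: spawn margin between packages and the goal
        min_pkg_goal_spawn_dist=0.05,
        # new kwarg in this diff: bonus for agents within 1.5 agent-diameters of a package
        agent_near_pkg_rew_factor=0.01,
        # pre-existing factors whose defaults drop to 0 in this diff
        package_goal_dist_reward_factor=100,
        interagent_collision_penalty=-1,
    )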
@@ -168,7 +171,7 @@ def reset_world_at(self, env_index: int = None):
             self.world,
             env_index,
             min_dist_between_entities=max(
-                package.shape.circumscribed_radius() + goal.shape.radius + 0.01
+                package.shape.circumscribed_radius() + goal.shape.radius + self.min_pkg_goal_spawn_dist
                 for package in self.packages
             ),
             x_bounds=(
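For scale: assuming the upstream package_width default of 0.15 and a goal radius of 0.15 (neither is shown in this diff, so treat both as assumptions), the circumscribed radius of the 0.4 × 0.15 package box is sqrt(0.2² + 0.075²) ≈ 0.214, so packages spawn at least ≈ 0.364 + min_pkg_goal_spawn_dist from the goal center. The default of 0.01 reproduces the previously hard-coded margin.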
@@ -253,17 +256,18 @@ def reward(self, agent: Agent):
                     Color.GREEN.value, device=self.world.device, dtype=torch.float32
                 )
 
-                # dense reward
                 if self.add_dense_reward:
+                    # reward for pushing the package closer to goal than previous step
                     package_shaping = package.dist_to_goal * self.package_goal_dist_reward_factor
                     self.rew[~package.on_goal] += (
                         package.global_shaping[~package.on_goal]
                         - package_shaping[~package.on_goal]
                     )
+                    # "global shaping" = the last package dist * goal_dist_rew_factor
                     package.global_shaping = package_shaping
 
                 # positive reward when the agent achieves the goal
-                self.rew[package.on_goal] += 1.0 * self.package_on_goal_reward_factor
+                # self.rew[package.on_goal] += 1.0 * self.package_on_goal_reward_factor
 
                 _time_penalty += self.time_penalty
                 # penalty (negative rew) for agent-agent collisions
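The shaping term above is the usual difference of potentials: while the package is off the goal, each step contributes (previous distance - current distance) * package_goal_dist_reward_factor. For example, with the old factor of 100, pushing the package from 0.50 to 0.48 away from the goal yields (0.50 - 0.48) × 100 = 2.0, and pushing it away gives the symmetric penalty. With the new default of 0, the term vanishes unless the factor is overridden.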
@@ -283,14 +287,19 @@ def reward(self, agent: Agent):
                     distance <= self.min_collision_distance
                 ] += self.interagent_collision_penalty
 
-        # reward for how close agents are to all packages
+        # reward agents for being near a package
         if self.add_dense_reward:
             for i, package in enumerate(self.packages):
                 dist_to_pkg = torch.linalg.vector_norm(agent.state.pos - package.state.pos, dim=-1)
-                agent_touching_package=self.world.is_overlapping(package, agent)
-                self.rew += (-dist_to_pkg * self.agent_package_dist_reward_factor) + self.agent_touching_package_reward_factor * agent_touching_package
+                agent_diameter = torch.ones(dist_to_pkg.shape, device=self.world.device) * (agent.shape.radius * 2)
+                near_pkg = dist_to_pkg < 1.5 * agent_diameter
+                self.rew[near_pkg] += 1.0 * self.agent_near_pkg_rew_factor
+                # self.rew[~near_pkg] -= 1.0 * self.agent_near_pkg_rew_factor
 
-        return self.rew + agent.agent_collision_rew + _time_penalty
+                # agent_touching_package=self.world.is_overlapping(package, agent)
+                # self.rew += (-dist_to_pkg * self.agent_package_dist_reward_factor) + self.agent_touching_package_reward_factor * agent_touching_package
+
+        return self.rew + agent.agent_collision_rew  # + _time_penalty
 
     def info(self, agent: Agent):
         """