Implement Predator-Prey Flock Environment #259
base: main
Conversation
* Initial prototype
* feat: Add environment tests
* fix: Update esquilax version to fix type issues
* docs: Add docstrings
* docs: Add docstrings
* test: Test multiple reward types
* test: Add smoke tests and add max-steps check
* feat: Implement pred-prey environment viewer
* refactor: Pull out common viewer functionality
* test: Add reward and view tests
* test: Add rendering tests and add test docstrings
* docs: Add predator-prey environment documentation page
* docs: Cleanup docstrings
* docs: Cleanup docstrings
Here you go @sash-a, this is correct now. Will take a look at the contributor license and CI failure now.

I think the CI issue is that I've got Esquilax pinned to a Python version.

The Python version PR is merged now, so hopefully it will pass 😄 Should have time during the week to review this, really appreciate the contribution!
An initial review with some high-level comments about jumanji conventions. Will go through it more in depth once these are addressed. In general it's looking really nice and well documented!

Not quite sure on the new swarms package, but also not sure where else we would put it. Not sure on it especially if we only have one env and no new ones planned.

One thing I don't quite understand is the benefit of amap over vmap, specifically in the case of this env?

Please @ me when it's ready for another review or if you have any questions.
import chex


@dataclass
I assume these are fixed attributes for each agent? If so, can we be explicit that it is frozen:

-@dataclass
+@dataclass(frozen=True)
Correct yeah these remain fixed after creation, will add.
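To illustrate the suggestion above, here is a minimal sketch of what a frozen parameter dataclass buys you. This uses the standard-library `dataclasses` decorator (the `chex.dataclass` decorator accepts the same `frozen=True` flag); the field names besides `min_speed`/`max_speed` are hypothetical, not the PR's actual definition:

```python
from dataclasses import dataclass, FrozenInstanceError

# frozen=True makes attribute assignment raise, so fixed per-agent
# parameters cannot be mutated after creation.
@dataclass(frozen=True)
class AgentParams:
    max_rotate: float  # hypothetical field: max heading change per step
    min_speed: float
    max_speed: float

params = AgentParams(max_rotate=0.1, min_speed=0.01, max_speed=0.05)

try:
    params.max_speed = 1.0  # rejected: the dataclass is frozen
    mutated = True
except FrozenInstanceError:
    mutated = False
```

Attempting to assign raises `FrozenInstanceError`, so `mutated` ends up `False` and the original value is preserved.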
return new_heading, new_speeds


@esquilax.transforms.amap
For this function why not just vmap? Since you don't use the params or key.
Yes in this case this is overkill, I'll use vmap inside the update function.
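For reference, a plain `jax.vmap` version of a per-agent update might look like the sketch below. The update rule and constants are illustrative stand-ins, not the PR's actual steering logic:

```python
import jax
import jax.numpy as jnp

def update_agent(heading, speed, action):
    # Hypothetical steering update: action[0] turns, action[1] accelerates.
    new_heading = (heading + 0.1 * action[0]) % (2.0 * jnp.pi)
    new_speed = jnp.clip(speed + 0.01 * action[1], 0.01, 0.05)
    return new_heading, new_speed

# vmap maps over the leading axis of every argument, with no need to
# thread params or a random key through the transform.
update_agents = jax.vmap(update_agent)

headings = jnp.zeros((4,))
speeds = jnp.full((4,), 0.02)
actions = jnp.ones((4, 2))
new_headings, new_speeds = update_agents(headings, speeds, actions)
```

Since the function is deterministic per agent, `vmap` alone is enough here; `esquilax.transforms.amap` would only be needed if shared parameters or per-agent keys had to be broadcast in.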
from . import types
from .types import AgentParams
By convention we don't do relative imports in jumanji. Also, I see you're using both types.AgentParams and just AgentParams; I think I prefer just using AgentParams, so:

-from . import types
-from .types import AgentParams
+from jumanji.environments.swarms.common.types import AgentParams
I'll switch out relative imports (there's a few around). In this case I also have types.AgentState, so will import the types module.
def init_state(
    n: int, params: types.AgentParams, key: chex.PRNGKey
) -> types.AgentState:
    """
    Randomly initialise state of a group of agents

    Args:
        n: Number of agents to initialise.
        params: Agent parameters.
        key: JAX random key.

    Returns:
        AgentState: Random agent states (i.e. position, headings, and speeds)
    """
    k1, k2, k3 = jax.random.split(key, 3)

    positions = jax.random.uniform(k1, (n, 2))
    speeds = jax.random.uniform(
        k2, (n,), minval=params.min_speed, maxval=params.max_speed
    )
    headings = jax.random.uniform(k3, (n,), minval=0.0, maxval=2.0 * jax.numpy.pi)

    return types.AgentState(
        pos=positions,
        speed=speeds,
        heading=headings,
    )
I think it would be nice to turn this into a generator, it's a convention for making it easy to switch the initial state distribution. See cleaner for a good example of how we do generators
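As a rough illustration of the generator pattern being suggested, here is a self-contained sketch. The class names and the tuple return type are placeholders (jumanji's actual generators return environment state types and live in a `generator.py` module), but the shape of the abstraction is the point: the initial-state distribution becomes swappable:

```python
import abc
import jax
import jax.numpy as jnp

class Generator(abc.ABC):
    """Interface for initial-state distributions (sketch of the pattern)."""

    def __init__(self, num_agents: int):
        self.num_agents = num_agents

    @abc.abstractmethod
    def __call__(self, key):
        """Generate initial positions, speeds, and headings from a PRNG key."""

class RandomGenerator(Generator):
    """Uniformly random initial states, mirroring the init_state logic above."""

    def __call__(self, key):
        k1, k2, k3 = jax.random.split(key, 3)
        pos = jax.random.uniform(k1, (self.num_agents, 2))
        speed = jax.random.uniform(k2, (self.num_agents,), minval=0.01, maxval=0.05)
        heading = jax.random.uniform(
            k3, (self.num_agents,), minval=0.0, maxval=2.0 * jnp.pi
        )
        return pos, speed, heading

pos, speed, heading = RandomGenerator(8)(jax.random.PRNGKey(0))
```

The environment would then take a `Generator` instance in its constructor, so a clustered or adversarial start distribution is just another subclass.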
        AgentState: Updated state of the agents after applying steering
            actions and updating positions.
    """
    actions = jax.numpy.clip(actions, min=-1.0, max=1.0)
Convention in jumanji is to import jax.numpy as jnp and then use jnp everywhere instead of jax.numpy.
Can we rename this to reward.py to keep with jumanji convention, and can you follow this convention for how we write our reward functions 🙏
Oh nice, yeah. I meant to ask about generic reward functions; reward tuning can be a large part of these multi-agent environments.
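A minimal sketch of pluggable reward functions, in the spirit of passing a reward object into the environment constructor. The interface, reward rules, and constants here are all illustrative assumptions, not the PR's actual code:

```python
import abc
import jax.numpy as jnp

class RewardFn(abc.ABC):
    """Maps agent positions to per-predator rewards; swapping the
    instance swaps the reward shaping without touching the env."""

    @abc.abstractmethod
    def __call__(self, predator_pos, prey_pos):
        ...

class DistanceReward(RewardFn):
    """Dense reward: negative distance to the closest prey."""

    def __call__(self, predator_pos, prey_pos):
        # Pairwise distances, shape (n_predators, n_prey).
        d = jnp.linalg.norm(predator_pos[:, None] - prey_pos[None, :], axis=-1)
        return -jnp.min(d, axis=1)

class SparseReward(RewardFn):
    """Sparse reward: 1.0 only when a prey is within capture range."""

    def __init__(self, capture_radius: float = 0.05):
        self.capture_radius = capture_radius

    def __call__(self, predator_pos, prey_pos):
        d = jnp.linalg.norm(predator_pos[:, None] - prey_pos[None, :], axis=-1)
        return (jnp.min(d, axis=1) < self.capture_radius).astype(jnp.float32)

pred = jnp.array([[0.0, 0.0], [0.5, 0.5]])
prey = jnp.array([[0.0, 0.01]])
dense = DistanceReward()(pred, prey)
sparse = SparseReward()(pred, prey)
```

With this shape, the `sparse_rewards` flag discussed later in the review becomes a choice of which `RewardFn` instance to construct the environment with.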
@dataclass
class Observation:
    """
    predators: Local view of predator agents.
    prey: Local view of prey agents.
    """

    predators: chex.Array
    prey: chex.Array
By convention our observations are NamedTuples and also need to be very well documented, see here.
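A sketch of what the NamedTuple form could look like. The field docs and shape hints are illustrative assumptions (`Any` stands in for `chex.Array` so the snippet is self-contained):

```python
from typing import Any, NamedTuple

class Observation(NamedTuple):
    """Observation for one team of agents.

    predators: local view of predator agents,
        e.g. shape (num_agents, num_vision)  # hypothetical shape
    prey: local view of prey agents,
        e.g. shape (num_agents, num_vision)  # hypothetical shape
    """

    predators: Any
    prey: Any

obs = Observation(predators=[0.1, 0.2], prey=[0.3])
```

One practical reason for the convention: NamedTuples are immutable and are handled as pytrees by JAX out of the box, with no extra registration needed, so they pass cleanly through `jit`/`vmap`.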
What's the reason for this (for my own knowledge)?
@dataclass
class Actions:
    """
    predators: Array of actions for predator agents.
    prey: Array of actions for prey agents.
    """

    predators: chex.Array
    prey: chex.Array


@dataclass
class Rewards:
    """
    predators: Array of individual rewards for predator agents.
    prey: Array of individual rewards for prey agents.
    """

    predators: chex.Array
    prey: chex.Array
I see this is repeated multiple times, maybe a PredatorPrey type would be best? Although not sure about this.
Yeah I did this in the prototype to indicate something that just had the two fields. I guess you could say for readability and in the strict typing sense these should be different things (was my thinking here)? But also appreciate the repetition is a bit ugly.
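One way to get both readable names and a single definition is one container plus aliases; a rough sketch (names and the alias approach are just one option, not a settled design):

```python
from typing import Any, NamedTuple

class PredatorPrey(NamedTuple):
    """Generic per-team container: one value for predators, one for prey."""

    predators: Any
    prey: Any

# Aliases keep signatures self-describing while sharing one definition.
# (Note these are the *same* type, so this trades away strict typing.)
Rewards = PredatorPrey
Actions = PredatorPrey

r = Rewards(predators=[1.0], prey=[-1.0])
```

If the strict-typing distinction matters more than the deduplication, keeping separate classes (as the PR does) is the defensible alternative.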
pos: chex.Array
heading: chex.Array
speed: chex.Array
For all types we add shape comments so it's easy to understand what we're expecting when debugging, e.g. here.
Oh yes, I'll add them in.
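For concreteness, a sketch of the fields above with shape comments in the suggested style. The shapes follow from the `init_state` code earlier in the review (positions drawn uniformly in the unit square, headings in [0, 2π), speeds between the parameter bounds); the plain `@dataclass` stands in for the chex one:

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class AgentState:
    """State of a group of agents, with jumanji-style shape comments."""

    pos: Any      # (num_agents, 2) positions in [0, 1)
    heading: Any  # (num_agents,) headings in [0, 2*pi)
    speed: Any    # (num_agents,) speeds in [min_speed, max_speed]

state = AgentState(pos=[[0.5, 0.5]], heading=[0.0], speed=[0.02])
```

The comments carry the expected array shapes so a debugger or a failing test immediately tells you which axis went wrong.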
if self.sparse_rewards:
    rewards = self._state_to_sparse_rewards(state)
else:
    rewards = self._state_to_distance_rewards(state)
Can you change this to how we set up different reward functions in jumanji, see here
As for your questions in the description:

* Nope, just the environment is fine.
* Please do add animation, it's a great help.
* We do want defaults; I think we can discuss what makes sense.
* It's generated with mkdocs; we need an entry in the docs config.

One big thing I've realized this is missing after my review is training code. We like to validate that the env works. I'm not 100% sure if this is possible because the env has two teams, so which reward do you optimize? Maybe train against a simple heuristic, e.g. you are the predator and the prey moves randomly? For examples see the
Add a predator-prey flock environment where two sets of agents attempt to catch/evade each other.

Changes

* Added a new swarm environment group/type (was not sure the new environment fit into an existing group, but happy to move if you think it would better fit somewhere else)

Todo

Questions

* The environment is forwarded in jumanji.environments; do types also need forwarding somewhere?
* Didn't add an animate method to the environment, but saw that some others do? Easy enough to add.