Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(specs): implement sample method #116

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 7 additions & 29 deletions jumanji/environments/packing/bin_pack/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import chex
import jax
import jax.numpy as jnp
import numpy as np
import pytest

from jumanji import tree_utils
Expand All @@ -33,29 +32,11 @@
item_from_space,
location_from_space,
)
from jumanji.testing.env_not_smoke import SelectActionFn, check_env_does_not_smoke
from jumanji.testing.env_not_smoke import check_env_does_not_smoke
from jumanji.testing.pytrees import assert_is_jax_array_tree
from jumanji.types import TimeStep


@pytest.fixture
def bin_pack_random_select_action(bin_pack: BinPack) -> SelectActionFn:
    """Provide a jitted policy that draws a random (EMS, item) action weighted by the mask."""
    ems_count, item_count = np.asarray(bin_pack.action_spec().num_values)

    def select_action(key: chex.PRNGKey, observation: Observation) -> chex.Array:
        """Randomly sample valid actions, as determined by `observation.action_mask`."""
        # Draw a single index over the flattened (EMS, item) grid; the flattened mask
        # supplies the weights, so entries with zero weight are never chosen.
        flat_choice = jax.random.choice(
            key=key,
            a=ems_count * item_count,
            p=observation.action_mask.flatten(),
        )
        # Turn the flat index back into its (EMS, item) pair.
        chosen_ems, chosen_item = jnp.divmod(flat_choice, item_count)
        return jnp.array([chosen_ems, chosen_item], jnp.int32)

    return jax.jit(select_action)  # type: ignore


@pytest.fixture(scope="function")
def normalize_dimensions(request: pytest.FixtureRequest) -> bool:
    """Expose the parameter attached to the current fixture request.

    Args:
        request: pytest's fixture request object; its `param` attribute holds the value
            this fixture was parametrized with.

    Returns:
        the parametrized value, passed through unchanged.
    """
    # `pytest.FixtureRequest` is the public request type. The previous annotation,
    # `pytest.mark.FixtureRequest`, went through the mark generator and therefore created
    # an unregistered mark rather than naming a type.
    return request.param  # type: ignore
Expand Down Expand Up @@ -160,25 +141,22 @@ def test_bin_pack__render_does_not_smoke(bin_pack: BinPack, dummy_state: State)
bin_pack.close()


def test_bin_pack__does_not_smoke(
bin_pack: BinPack,
bin_pack_random_select_action: SelectActionFn,
) -> None:
def test_bin_pack__does_not_smoke(bin_pack: BinPack) -> None:
"""Test that we can run an episode without any errors."""
check_env_does_not_smoke(bin_pack, bin_pack_random_select_action)
check_env_does_not_smoke(bin_pack)


def test_bin_pack__pack_all_items_dummy_instance(
bin_pack: BinPack, bin_pack_random_select_action: SelectActionFn
) -> None:
def test_bin_pack__pack_all_items_dummy_instance(bin_pack: BinPack) -> None:
"""Functional test to check that the dummy instance can be completed with a random agent."""
step_fn = jax.jit(bin_pack.step)
key = jax.random.PRNGKey(0)
state, timestep = bin_pack.reset(key)

while not timestep.last():
action_key, key = jax.random.split(key)
action = bin_pack_random_select_action(action_key, timestep.observation)
action = bin_pack.action_spec().sample(
action_key, timestep.observation.action_mask
)
state, timestep = step_fn(state, action)

assert jnp.array_equal(state.items_placed, state.items_mask)
Expand Down
16 changes: 2 additions & 14 deletions jumanji/environments/routing/cleaner/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from jumanji.environments.routing.cleaner.constants import CLEAN, DIRTY, WALL
from jumanji.environments.routing.cleaner.env import Cleaner
from jumanji.environments.routing.cleaner.generator import Generator
from jumanji.environments.routing.cleaner.types import Observation, State
from jumanji.environments.routing.cleaner.types import State
from jumanji.testing.env_not_smoke import check_env_does_not_smoke
from jumanji.testing.pytrees import assert_is_jax_array_tree
from jumanji.types import StepType, TimeStep
Expand Down Expand Up @@ -176,19 +176,7 @@ def test_cleaner__action_mask(self, cleaner: Cleaner, key: chex.PRNGKey) -> None
assert jnp.all(action_mask[2] == jnp.array([False, False, False, True]))

def test_cleaner__does_not_smoke(self, cleaner: Cleaner) -> None:
def select_actions(key: chex.PRNGKey, observation: Observation) -> chex.Array:
@jax.vmap # map over the keys and agents
def select_action(
key: chex.PRNGKey, agent_action_mask: chex.Array
) -> chex.Array:
return jax.random.choice(
key, jnp.arange(4), p=agent_action_mask.flatten()
)

subkeys = jax.random.split(key, cleaner.num_agents)
return select_action(subkeys, observation.action_mask)

check_env_does_not_smoke(cleaner, select_actions)
check_env_does_not_smoke(cleaner)

def test_cleaner__compute_extras(self, cleaner: Cleaner, key: chex.PRNGKey) -> None:
state, _ = cleaner.reset(key)
Expand Down
245 changes: 245 additions & 0 deletions jumanji/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
import copy
import functools
import inspect
from math import prod
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Type,
Expand Down Expand Up @@ -116,6 +119,27 @@ def generate_value(self) -> T:
)
return self._constructor(**constructor_kwargs)

def sample(self, key: chex.PRNGKey, mask: Optional[chex.Array] = None) -> T:
    """Sample a random value which conforms to this spec.

    Every nested spec is sampled with its own sub-key and the results are passed to
    the constructor as keyword arguments.

    Args:
        key: random number generator key; split into one sub-key per leaf of the nested
            specs.
        mask: optional pytree of masks laid out like the nested specs. Each entry is
            handed to the matching nested spec. When omitted, every nested spec
            receives `None`.

    Returns:
        random sample conforming to spec.
    """
    structure = jax.tree_util.tree_structure(self._specs)
    leaf_count = structure.num_leaves

    # Arrange the independent sub-keys in the same tree layout as the specs.
    split_keys = jax.random.split(key, leaf_count)
    key_tree = jax.tree_util.tree_unflatten(structure, split_keys)

    if mask is None:
        mask = jax.tree_util.tree_unflatten(structure, [None] * leaf_count)

    def sample_leaf(
        spec: Any, leaf_key: chex.PRNGKey, leaf_mask: Optional[chex.Array]
    ) -> Any:
        return spec.sample(leaf_key, leaf_mask)

    sampled = jax.tree_util.tree_map(sample_leaf, self._specs, key_tree, mask)
    return self._constructor(**sampled)

def replace(self, **kwargs: Any) -> "Spec":
"""Returns a new copy of `self` with specified attributes replaced.

Expand Down Expand Up @@ -205,6 +229,63 @@ def validate(self, value: chex.Numeric) -> chex.Array:
)
return value

def _sample_scalar(
    self, key: chex.PRNGKey, minimum: chex.Scalar, maximum: chex.Scalar
) -> chex.Scalar:
    """Draw one value of this spec's dtype between `minimum` and `maximum`.

    Args:
        key: random number generator key.
        minimum: lower bound (inclusive).
        maximum: upper bound; inclusive for integers and bools.

    Returns:
        a scalar of dtype `self.dtype`:
            integers and bools are drawn uniformly from [minimum, maximum],
            floats are drawn uniformly when both bounds are finite, from a shifted
            exponential when only one bound is finite, and from a standard normal
            otherwise.

    Raises:
        ValueError: if the dtype is not integer, floating or bool.
    """
    dtype = self.dtype

    if jnp.issubdtype(dtype, jnp.integer):
        # `randint` excludes its upper bound, so shift it to make `maximum` reachable.
        # NOTE(review): `maximum + 1` presumably overflows when `maximum` is the
        # dtype's largest value - confirm callers never pass that.
        return jax.random.randint(key, (), minimum, maximum + 1, dtype)

    if jnp.issubdtype(dtype, jnp.floating):

        def both_finite() -> chex.Scalar:
            return jax.random.uniform(key, (), dtype, minimum, maximum)

        def only_lower_finite() -> chex.Scalar:
            return minimum + jax.random.exponential(key, (), dtype)

        def only_upper_finite() -> chex.Scalar:
            return maximum - jax.random.exponential(key, (), dtype)

        def unbounded() -> chex.Scalar:
            return jax.random.normal(key, (), dtype)

        # Two-bit selector: the high bit flags a -inf minimum, the low bit an
        # infinite maximum. The branch order below must match this encoding.
        lower_is_neginf = jnp.isneginf(minimum).astype(int)
        upper_is_inf = jnp.isinf(maximum).astype(int)
        selector = 2 * lower_is_neginf + upper_is_inf
        return jax.lax.switch(
            selector,
            (both_finite, only_lower_finite, only_upper_finite, unbounded),
        )

    if jnp.issubdtype(dtype, jnp.bool_):
        return jax.random.randint(key, (), minimum, maximum + 1).astype(dtype)

    raise ValueError(f"Sampling Scalar of type {self.dtype} not supported.")

def _sample_array(
    self, key: chex.PRNGKey, minimum: chex.Array, maximum: chex.Array
) -> chex.Array:
    """Sample every element of an array of shape `self.shape` independently.

    Args:
        key: random number generator key; split into one sub-key per element.
        minimum: lower bound(s), passed element-wise to `_sample_scalar`.
        maximum: upper bound(s), passed element-wise to `_sample_scalar`.

    Returns:
        array whose elements are each produced by `_sample_scalar`.
    """
    element_count = prod(self.shape)
    # Lay the sub-keys out in the spec's shape; the trailing axis of length 2 is the
    # per-element key (assumes raw keys of length 2 - TODO confirm for typed keys).
    element_keys = jax.random.split(key, element_count).reshape(*self.shape, 2)
    # The signature consumes the key axis and treats the bounds as scalars, so the
    # scalar sampler is applied once per element.
    sample_elementwise = jnp.vectorize(
        self._sample_scalar, signature="(n),(),()->()"
    )
    return sample_elementwise(element_keys, minimum, maximum)

def sample(
    self, key: chex.PRNGKey, mask: Optional[chex.Array] = None
) -> chex.Array:
    """Sample a random value which conforms to this spec.

    Args:
        key: random number generator key.
        mask: an optional mask array (must be None).

    Returns:
        jax array produced by `_sample_array` with dtype-dependent bounds:
            ints use the bounds (MIN_INT, MAX_INT - 1),
            floats use the bounds (-inf, inf),
            bools use the bounds (False, True).

    Raises:
        ValueError: if mask is not None, or if the dtype is not integer, floating
            or bool.
    """
    if mask is not None:
        raise ValueError("Sampling Array from a mask not supported.")

    dtype = self.dtype
    if jnp.issubdtype(dtype, jnp.integer):
        limits = jnp.iinfo(dtype)
        # The upper bound stops one short of the dtype's maximum.
        lower, upper = limits.min, limits.max - 1
    elif jnp.issubdtype(dtype, jnp.floating):
        lower, upper = -jnp.inf, jnp.inf
    elif jnp.issubdtype(dtype, jnp.bool_):
        lower, upper = False, True
    else:
        raise ValueError(f"Sampling Array of type {self.dtype} not supported.")

    return self._sample_array(key, lower, upper)

def _get_constructor_kwargs(self) -> Dict[str, Any]:
"""Returns constructor kwargs for instantiating a new copy of this spec."""
# Get the names and kinds of the constructor parameters.
Expand Down Expand Up @@ -360,6 +441,35 @@ def validate(self, value: chex.Numeric) -> chex.Array:
)
return value

def sample(
    self, key: chex.PRNGKey, mask: Optional[chex.Array] = None
) -> chex.Array:
    """Sample a random value which conforms to this spec.

    Args:
        key: random number generator key.
        mask: an optional mask array (must be None).

    Returns:
        jax array sampled by `_sample_array` between `self.minimum` and
        `self.maximum`.

    Raises:
        ValueError: if mask is not None.
    """
    if mask is None:
        lower, upper = self.minimum, self.maximum
        return self._sample_array(key, lower, upper)
    raise ValueError("Sampling BoundedArray from a mask not supported.")

def __eq__(self, other: "BoundedArray") -> bool: # type: ignore[override]
if not isinstance(other, BoundedArray):
return NotImplemented
Expand Down Expand Up @@ -426,6 +536,31 @@ def num_values(self) -> int:
"""Returns the number of items."""
return self._num_values

def sample(
    self, key: chex.PRNGKey, mask: Optional[chex.Array] = None
) -> chex.Array:
    """Sample a random value which conforms to this spec.

    Args:
        key: random number generator key.
        mask: an optional mask array. Must be of shape (num_values,).
            Mask will be interpreted as weights to random.choice.

    Returns:
        jax array containing random discrete value. When a mask is given the value
        has dtype `self.dtype`.

    Raises:
        ValueError: if mask is not proper shape.
    """
    if mask is None:
        return super().sample(key)

    expected_shape = (self.num_values,)
    if mask.shape != expected_shape:
        raise ValueError(
            f"Expected mask of shape {expected_shape}, "
            + f"but recieved mask of shape {mask.shape}."
        )

    # Build the candidates in the spec's dtype so a masked sample conforms to the spec
    # instead of silently taking JAX's default integer dtype.
    candidates = jnp.arange(self.num_values, dtype=self.dtype)
    return jax.random.choice(key, candidates, p=mask)

def __eq__(self, other: "DiscreteArray") -> bool: # type: ignore[override]
if not isinstance(other, DiscreteArray):
return NotImplemented
Expand Down Expand Up @@ -493,6 +628,116 @@ def num_values(self) -> chex.Array:
"""Returns the number of possible values for each element of the action vector."""
return self._num_values

@property
def _tuple_nv(self) -> Tuple[int, ...]:
    """Return every entry of `num_values`, flattened, as a tuple of Python values."""
    flat_values = self.num_values.flatten().tolist()
    return tuple(flat_values)

@property
def _first_nv(self) -> int:
    """Return the first entry of the flattened `num_values`."""
    flattened = self._tuple_nv
    return flattened[0]

@property
def _all_eq_nv(self) -> bool:
    """Return whether every entry of `num_values` is the same.

    An empty `num_values` counts as all-equal.
    """
    # Read the tuple once: each access to `_tuple_nv` rebuilds it from the array, so
    # comparing against `self._first_nv` inside the loop repeated that work per element.
    values = self._tuple_nv
    return all(nv == values[0] for nv in values)

def _validate_mode(
    self, mask: Optional[chex.Array], mode: Literal["table", "index"]
) -> None:
    """Check that an explicitly requested sampling mode is usable with `mask`.

    Args:
        mask: mask array the caller wants to sample with.
        mode: requested sampling mode, 'table' or 'index'.

    Raises:
        ValueError: if no mask accompanies the mode, the mode is unknown, table mode
            is requested while the entries of `num_values` differ, or the mask shape
            does not match the one the mode requires.
    """
    if mask is None:
        raise ValueError("Sampling mode is provided without a mask.")

    if mode not in ("table", "index"):
        raise ValueError(
            f"Sampling mode {mode} not supported. Use table or index mode."
        )

    if mode == "table":
        # Table mode needs one row of weights per element, all of the same length.
        if not self._all_eq_nv:
            raise ValueError(
                f"Table mode not supported for num_vals of {self.num_values}. "
                + "num_vals must all be equal."
            )
        table_shape = (*self.shape, self._first_nv)
        if mask.shape != table_shape:
            raise ValueError(
                f"Expected mask of shape {table_shape} for table mode, "
                + f"but recieved mask shape of {mask.shape}.",
            )
        return

    # Index mode: the mask has one axis per entry of `num_values`.
    index_shape = self._tuple_nv
    if mask.shape != index_shape:
        raise ValueError(
            f"Expected mask shape of {index_shape} for index mode, "
            + f"but recieved mask shape of {mask.shape}."
        )

def _get_valid_mode(
    self, mask: Optional[chex.Array], mode: Optional[Literal["table", "index"]]
) -> Optional[Literal["table", "index"]]:
    """Resolve the sampling mode, inferring it from the mask shape when not given.

    Args:
        mask: optional mask array.
        mode: optional explicit mode; validated against `mask` and returned as-is
            when provided.

    Returns:
        the explicit mode when one is given, None when neither mode nor mask is given,
        otherwise 'table' or 'index' depending on which layout the mask shape matches.

    Raises:
        ValueError: if the explicit mode fails validation, the mask shape matches both
            layouts, or it matches neither.
    """
    if mode is not None:
        self._validate_mode(mask, mode)
        return mode
    if mask is None:
        return None

    observed_shape = mask.shape
    # The short-circuit keeps `_first_nv` from being read when the entries differ.
    fits_table = self._all_eq_nv and observed_shape == (*self.shape, self._first_nv)
    fits_index = observed_shape == self._tuple_nv

    if fits_table and fits_index:
        raise ValueError(
            "Sampling mode is ambiguous. Provide a mode of table or index."
        )
    if fits_table:
        return "table"
    if fits_index:
        return "index"

    raise ValueError(
        f"Expected mask shape of {(*self.shape, self._first_nv)} for table mode or "
        + f"{self._tuple_nv} for index mode, but recieved mask shape of {mask.shape}."
    )

def sample(
    self,
    key: chex.PRNGKey,
    mask: Optional[chex.Array] = None,
    mode: Optional[Literal["table", "index"]] = None,
) -> chex.Array:
    """Sample a random value which conforms to this spec.

    Args:
        key: random number generator key.
        mask: an optional mask array. In table mode mask must be of shape
            (*self.shape, num_values) and is interpreted as weights to random.choice for each
            element. See JobShop for example. In index mode, mask must be of shape
            (*self.num_values.flatten()) and is interpreted as weights to random.choice
            for an index. See Minesweeper for example.
        mode: sampling mode ('table', 'index', or None). If mode is None, will attempt to
            determine mode from shape of mask.

    Returns:
        jax array containing random discrete values. With a mask, the values have
        dtype `self.dtype`.

    Raises:
        ValueError: if mask is not proper shape or mode is incorrect.
    """
    mode = self._get_valid_mode(mask, mode)

    if mask is None:
        return super().sample(key)

    if mode == "table":
        # One key per element; each element draws from its own row of weights.
        keys = jax.random.split(key, prod(self.shape)).reshape(*self.shape, 2)
        indices = jnp.arange(self._first_nv, dtype=self.dtype)
        sample_array = jnp.vectorize(
            lambda key, mask: jax.random.choice(key, indices, (), p=mask),
            signature="(n),(m)->()",
        )
        return sample_array(keys, mask)

    # Index mode (the only other mode `_get_valid_mode` returns alongside a mask).
    # Read the tuple once: every access to `_tuple_nv` rebuilds it from the array.
    nv = self._tuple_nv
    # Row-major strides of the mask: stride i is the product of all later entries.
    strides = jnp.array([prod(nv[i + 1 :]) for i in range(len(nv))]).reshape(
        self.shape
    )
    # The flat index ranges over prod(nv) cells, which can exceed the range of a
    # narrow spec dtype, so draw and decompose it in the default integer dtype and
    # cast only the per-element result, which is bounded by `num_values`.
    flat_index = jax.random.choice(
        key, jnp.arange(prod(nv)), (), p=mask.flatten()
    )
    return (flat_index // strides % self.num_values).astype(self.dtype)

def __eq__(self, other: "MultiDiscreteArray") -> bool: # type: ignore[override]
if not isinstance(other, MultiDiscreteArray):
return NotImplemented
Expand Down
Loading