diff --git a/jumanji/environments/packing/bin_pack/env_test.py b/jumanji/environments/packing/bin_pack/env_test.py index 921ce7025..0a6836858 100644 --- a/jumanji/environments/packing/bin_pack/env_test.py +++ b/jumanji/environments/packing/bin_pack/env_test.py @@ -17,7 +17,6 @@ import chex import jax import jax.numpy as jnp -import numpy as np import pytest from jumanji import tree_utils @@ -33,29 +32,11 @@ item_from_space, location_from_space, ) -from jumanji.testing.env_not_smoke import SelectActionFn, check_env_does_not_smoke +from jumanji.testing.env_not_smoke import check_env_does_not_smoke from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep -@pytest.fixture -def bin_pack_random_select_action(bin_pack: BinPack) -> SelectActionFn: - num_ems, num_items = np.asarray(bin_pack.action_spec().num_values) - - def select_action(key: chex.PRNGKey, observation: Observation) -> chex.Array: - """Randomly sample valid actions, as determined by `observation.action_mask`.""" - ems_item_id = jax.random.choice( - key=key, - a=num_ems * num_items, - p=observation.action_mask.flatten(), - ) - ems_id, item_id = jnp.divmod(ems_item_id, num_items) - action = jnp.array([ems_id, item_id], jnp.int32) - return action - - return jax.jit(select_action) # type: ignore - - @pytest.fixture(scope="function") def normalize_dimensions(request: pytest.mark.FixtureRequest) -> bool: return request.param # type: ignore @@ -160,17 +141,12 @@ def test_bin_pack__render_does_not_smoke(bin_pack: BinPack, dummy_state: State) bin_pack.close() -def test_bin_pack__does_not_smoke( - bin_pack: BinPack, - bin_pack_random_select_action: SelectActionFn, -) -> None: +def test_bin_pack__does_not_smoke(bin_pack: BinPack) -> None: """Test that we can run an episode without any errors.""" - check_env_does_not_smoke(bin_pack, bin_pack_random_select_action) + check_env_does_not_smoke(bin_pack) -def test_bin_pack__pack_all_items_dummy_instance( - bin_pack: BinPack, bin_pack_random_select_action: SelectActionFn -) -> None: +def test_bin_pack__pack_all_items_dummy_instance(bin_pack: BinPack) -> None: """Functional test to check that the dummy instance can be completed with a random agent.""" step_fn = jax.jit(bin_pack.step) key = jax.random.PRNGKey(0) @@ -178,7 +154,9 @@ def test_bin_pack__pack_all_items_dummy_instance( while not timestep.last(): action_key, key = jax.random.split(key) - action = bin_pack_random_select_action(action_key, timestep.observation) + action = bin_pack.action_spec().sample( + action_key, timestep.observation.action_mask + ) state, timestep = step_fn(state, action) assert jnp.array_equal(state.items_placed, state.items_mask) diff --git a/jumanji/environments/routing/cleaner/env_test.py b/jumanji/environments/routing/cleaner/env_test.py index d88ec6acd..dd494eed5 100644 --- a/jumanji/environments/routing/cleaner/env_test.py +++ b/jumanji/environments/routing/cleaner/env_test.py @@ -20,7 +20,7 @@ from jumanji.environments.routing.cleaner.constants import CLEAN, DIRTY, WALL from jumanji.environments.routing.cleaner.env import Cleaner from jumanji.environments.routing.cleaner.generator import Generator -from jumanji.environments.routing.cleaner.types import Observation, State +from jumanji.environments.routing.cleaner.types import State from jumanji.testing.env_not_smoke import check_env_does_not_smoke from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import StepType, TimeStep @@ -176,19 +176,7 @@ def test_cleaner__action_mask(self, cleaner: Cleaner, key: chex.PRNGKey) -> None assert jnp.all(action_mask[2] == jnp.array([False, False, False, True])) def test_cleaner__does_not_smoke(self, cleaner: Cleaner) -> None: - def select_actions(key: chex.PRNGKey, observation: Observation) -> chex.Array: - @jax.vmap # map over the keys and agents - def select_action( - key: chex.PRNGKey, agent_action_mask: chex.Array - ) -> chex.Array: - return jax.random.choice( - key, jnp.arange(4), p=agent_action_mask.flatten() - ) - - subkeys = jax.random.split(key, cleaner.num_agents) - return select_action(subkeys, observation.action_mask) - - check_env_does_not_smoke(cleaner, select_actions) + check_env_does_not_smoke(cleaner) def test_cleaner__compute_extras(self, cleaner: Cleaner, key: chex.PRNGKey) -> None: state, _ = cleaner.reset(key) diff --git a/jumanji/specs.py b/jumanji/specs.py index 6dc40237b..9e7519e51 100644 --- a/jumanji/specs.py +++ b/jumanji/specs.py @@ -16,13 +16,16 @@ import copy import functools import inspect +from math import prod from typing import ( Any, Callable, Dict, Generic, Iterable, + Literal, NamedTuple, + Optional, Sequence, Tuple, Type, @@ -116,6 +119,27 @@ def generate_value(self) -> T: ) return self._constructor(**constructor_kwargs) + def sample(self, key: chex.PRNGKey, mask: Optional[chex.Array] = None) -> T: + """Sample a random value which conforms to this spec. + + Args: + key: random number generator key. + mask: an optional mask array. + + Returns: + random sample conforming to spec. + """ + treedef = jax.tree_util.tree_structure(self._specs) + keys_flat = jax.random.split(key, treedef.num_leaves) + keys = jax.tree_util.tree_unflatten(treedef, keys_flat) + if mask is None: + mask = jax.tree_util.tree_unflatten(treedef, [None] * treedef.num_leaves) + return self._constructor( + **jax.tree_util.tree_map( + lambda spec, key, mask: spec.sample(key, mask), self._specs, keys, mask + ) + ) + def replace(self, **kwargs: Any) -> "Spec": """Returns a new copy of `self` with specified attributes replaced. @@ -205,6 +229,63 @@ def validate(self, value: chex.Numeric) -> chex.Array: ) return value + def _sample_scalar( + self, key: chex.PRNGKey, minimum: chex.Scalar, maximum: chex.Scalar + ) -> chex.Scalar: + if jnp.issubdtype(self.dtype, jnp.integer): + return jax.random.randint(key, (), minimum, maximum + 1, self.dtype) + elif jnp.issubdtype(self.dtype, jnp.floating): + return jax.lax.switch( + 2 * jnp.isneginf(minimum).astype(int) + jnp.isinf(maximum).astype(int), + ( + lambda: jax.random.uniform(key, (), self.dtype, minimum, maximum), + lambda: minimum + jax.random.exponential(key, (), self.dtype), + lambda: maximum - jax.random.exponential(key, (), self.dtype), + lambda: jax.random.normal(key, (), self.dtype), + ), + ) + elif jnp.issubdtype(self.dtype, jnp.bool_): + return jax.random.randint(key, (), minimum, maximum + 1).astype(self.dtype) + raise ValueError(f"Sampling Scalar of type {self.dtype} not supported.") + + def _sample_array( + self, key: chex.PRNGKey, minimum: chex.Array, maximum: chex.Array + ) -> chex.Array: + keys = jax.random.split(key, prod(self.shape)).reshape(*self.shape, 2) + sample_array = jnp.vectorize(self._sample_scalar, signature="(n),(),()->()") + return sample_array(keys, minimum, maximum) + + def sample( + self, key: chex.PRNGKey, mask: Optional[chex.Array] = None + ) -> chex.Array: + """Sample a random value which conforms to this spec. + + Args: + key: random number generator key. + mask: an optional mask array (must be None). + + Returns: + jax array containing random values depending on dtype. + ints are drawn from [MIN_INT, MAX_INT]. + floats are drawn from the normal distribution. + bools are drawn from the bernoulli distribution. + + Raises: + ValueError: if mask is not None. + """ + if mask is not None: + raise ValueError("Sampling Array from a mask not supported.") + if jnp.issubdtype(self.dtype, jnp.integer): + info = jnp.iinfo(self.dtype) + minimum, maximum = info.min, info.max - 1 + elif jnp.issubdtype(self.dtype, jnp.floating): + minimum, maximum = -jnp.inf, jnp.inf + elif jnp.issubdtype(self.dtype, jnp.bool_): + minimum, maximum = False, True + else: + raise ValueError(f"Sampling Array of type {self.dtype} not supported.") + return self._sample_array(key, minimum, maximum) + def _get_constructor_kwargs(self) -> Dict[str, Any]: """Returns constructor kwargs for instantiating a new copy of this spec.""" # Get the names and kinds of the constructor parameters. @@ -360,6 +441,35 @@ def validate(self, value: chex.Numeric) -> chex.Array: ) return value + def sample( + self, key: chex.PRNGKey, mask: Optional[chex.Array] = None + ) -> chex.Array: + """Sample a random value which conforms to this spec. + + Args: + key: random number generator key. + mask: an optional mask array (must be None). + + Returns: + jax array containing random values depending on dtype. + ints are drawn from [self.minimum, self.maximum]. + floats are drawn from: + normal distribution if unbounded, + shifted exponential distribution if bounded below, + shifted negative exponential distribution if bounded above, + [self.minimum, self.maximum) if bounded on both sides. + bools are drawn from: + True if self.minimum = True, + False if self.maximum = False, + bernoulli distribution otherwise. + + Raises: + ValueError: if mask is not None. + """ + if mask is not None: + raise ValueError("Sampling BoundedArray from a mask not supported.") + return self._sample_array(key, self.minimum, self.maximum) + def __eq__(self, other: "BoundedArray") -> bool: # type: ignore[override] if not isinstance(other, BoundedArray): return NotImplemented @@ -426,6 +536,31 @@ def num_values(self) -> int: """Returns the number of items.""" return self._num_values + def sample( + self, key: chex.PRNGKey, mask: Optional[chex.Array] = None + ) -> chex.Array: + """Sample a random value which conforms to this spec. + + Args: + key: random number generator key. + mask: an optional mask array. Must be of shape (num_values,). + Mask will be interpreted as weights to random.choice. + + Returns: + jax array containing random discrete value. + + Raises: + ValueError: if mask is not proper shape. + """ + if mask is None: + return super().sample(key) + elif mask.shape != (self.num_values,): + raise ValueError( + f"Expected mask of shape {(self.num_values,)}, " + + f"but recieved mask of shape {mask.shape}." + ) + return jax.random.choice(key, jnp.arange(self.num_values), p=mask) + def __eq__(self, other: "DiscreteArray") -> bool: # type: ignore[override] if not isinstance(other, DiscreteArray): return NotImplemented @@ -493,6 +628,116 @@ def num_values(self) -> chex.Array: """Returns the number of possible values for each element of the action vector.""" return self._num_values + @property + def _tuple_nv(self) -> Tuple[int, ...]: + return (*self.num_values.flatten().tolist(),) + + @property + def _first_nv(self) -> int: + return self._tuple_nv[0] + + @property + def _all_eq_nv(self) -> bool: + return all(nv == self._first_nv for nv in self._tuple_nv) + + def _validate_mode( + self, mask: Optional[chex.Array], mode: Literal["table", "index"] + ) -> None: + if mask is None: + raise ValueError("Sampling mode is provided without a mask.") + elif mode != "table" and mode != "index": + raise ValueError( + f"Sampling mode {mode} not supported. Use table or index mode." + ) + elif mode == "table" and not self._all_eq_nv: + raise ValueError( + f"Table mode not supported for num_vals of {self.num_values}. " + + "num_vals must all be equal." + ) + elif mode == "table" and mask.shape != (*self.shape, self._first_nv): + raise ValueError( + f"Expected mask of shape {(*self.shape, self._first_nv)} for table mode, " + + f"but recieved mask shape of {mask.shape}.", + ) + elif mode == "index" and mask.shape != self._tuple_nv: + raise ValueError( + f"Expected mask shape of {self._tuple_nv} for index mode, " + + f"but recieved mask shape of {mask.shape}." + ) + + def _get_valid_mode( + self, mask: Optional[chex.Array], mode: Optional[Literal["table", "index"]] + ) -> Optional[Literal["table", "index"]]: + if mode is not None: + self._validate_mode(mask, mode) + return mode + elif mask is None: + return None + + table_valid = self._all_eq_nv and mask.shape == (*self.shape, self._first_nv) + index_valid = mask.shape == self._tuple_nv + + if table_valid and index_valid: + raise ValueError( + "Sampling mode is ambiguous. Provide a mode of table or index." + ) + elif table_valid: + return "table" + elif index_valid: + return "index" + + raise ValueError( + f"Expected mask shape of {(*self.shape, self._first_nv)} for table mode or " + + f"{self._tuple_nv} for index mode, but recieved mask shape of {mask.shape}." + ) + + def sample( + self, + key: chex.PRNGKey, + mask: Optional[chex.Array] = None, + mode: Optional[Literal["table", "index"]] = None, + ) -> chex.Array: + """Sample a random value which conforms to this spec. + + Args: + key: random number generator key. + mask: an optional mask array. In table mode mask must be of shape + (*self.shape, num_values) and is interpreted as weights to random.choice for each + element. See JobShop for example. In index mode, mask must be of shape + (*self.num_values.flatten()) and is interpreted as weights to random.choice + for an index. See Minesweeper for example. + mode: sampling mode ('table', 'index', or None). If mode is None, will attempt to + determine mode from shape of mask. + + Returns: + jax array containing random discrete values. + + Raises: + ValueError: if mask is not proper shape or mode is incorrect. + """ + mode = self._get_valid_mode(mask, mode) + + if mask is None: + return super().sample(key) + + elif mode == "table": + keys = jax.random.split(key, prod(self.shape)).reshape(*self.shape, 2) + indices = jnp.arange(self._first_nv, dtype=self.dtype) + sample_array = jnp.vectorize( + lambda key, mask: jax.random.choice(key, indices, (), p=mask), + signature="(n),(m)->()", + ) + return sample_array(keys, mask) + + elif mode == "index": + indices = jnp.arange(prod(self._tuple_nv), dtype=self.dtype) + dividers = jnp.array( + [prod(self._tuple_nv[i + 1 :]) for i in range(len(self._tuple_nv))], + dtype=self.dtype, + ).reshape(self.shape) + index = jax.random.choice(key, indices, (), p=mask.flatten()) + return index // dividers % self.num_values + def __eq__(self, other: "MultiDiscreteArray") -> bool: # type: ignore[override] if not isinstance(other, MultiDiscreteArray): return NotImplemented diff --git a/jumanji/specs_test.py b/jumanji/specs_test.py index f3c94fd29..baca4322f 100644 --- a/jumanji/specs_test.py +++ b/jumanji/specs_test.py @@ -16,11 +16,12 @@ # ============================================================================ import pickle from collections import namedtuple -from typing import Any, NamedTuple, Sequence, Union +from typing import Any, Literal, NamedTuple, Optional, Sequence, Tuple, Union import chex import dm_env.specs import gym.spaces +import jax import jax.numpy as jnp import numpy as np import pytest @@ -121,24 +122,28 @@ def test_spec__generate_value(self, triply_nested_spec: specs.Spec) -> None: SinglyNested, ) + def test_spec__sample(self, triply_nested_spec: specs.Spec) -> None: + key = jax.random.PRNGKey(0) + assert isinstance(triply_nested_spec.sample(key), TriplyNested) + assert isinstance(triply_nested_spec["doubly_nested"].sample(key), DoublyNested) + assert isinstance( + triply_nested_spec["doubly_nested"]["singly_nested"].sample(key), + SinglyNested, + ) + def test_spec__validate(self, triply_nested_spec: specs.Spec) -> None: - singly_nested = triply_nested_spec["doubly_nested"][ - "singly_nested" - ].generate_value() + key = jax.random.PRNGKey(0) + singly_nested = triply_nested_spec["doubly_nested"]["singly_nested"].sample(key) + singly_nested = triply_nested_spec["doubly_nested"]["singly_nested"].validate( + singly_nested + ) assert isinstance(singly_nested, SinglyNested) - doubly_nested = DoublyNested( - singly_nested=singly_nested, - discrete_array=jnp.ones((), jnp.int32), - ) + doubly_nested = triply_nested_spec["doubly_nested"].sample(key) doubly_nested = triply_nested_spec["doubly_nested"].validate(doubly_nested) assert isinstance(doubly_nested, DoublyNested) - triply_nested = TriplyNested( - doubly_nested=doubly_nested, - bounded_array=jnp.ones((7, 9), jnp.int32), - discrete_array=jnp.ones((), jnp.int32), - ) + triply_nested = triply_nested_spec.sample(key) triply_nested = triply_nested_spec.validate(triply_nested) assert isinstance(triply_nested, TriplyNested) @@ -238,6 +243,18 @@ def test_generate_value(self) -> None: test_value = spec.generate_value() spec.validate(test_value) + @pytest.mark.parametrize("dtype", (float, int, bool)) + def test_sample(self, dtype: jnp.dtype) -> None: + key = jax.random.PRNGKey(0) + + spec = specs.Array((1, 2), dtype) + test_sample = spec.sample(key) + spec.validate(test_sample) + + mask = jnp.zeros((1, 2), jnp.bool_) + with pytest.raises(ValueError): + spec.sample(key, mask) + def test_serialization(self) -> None: spec = specs.Array([1, 5], jnp.float32, "pickle_test") loaded_spec = pickle.loads(pickle.dumps(spec)) @@ -398,6 +415,30 @@ def test_generate_value(self) -> None: test_value = spec.generate_value() spec.validate(test_value) + @pytest.mark.parametrize( + "dtype,minimum,maximum", + ( + (int, jnp.array(0), jnp.array((10, 20, 30))), + (float, jnp.array(0.0), jnp.array((3.14, 15.9, 265.4))), + (float, jnp.array(-jnp.inf), jnp.array((3.14, 15.9, 265.4))), + (float, jnp.array(0.0), jnp.array((jnp.inf, 15.9, jnp.inf))), + (float, jnp.array(-jnp.inf), jnp.array((jnp.inf, 15.9, jnp.inf))), + (bool, jnp.array(False), jnp.array((False, True, True))), + ), + ) + def test_sample( + self, dtype: jnp.dtype, minimum: chex.Array, maximum: chex.Array + ) -> None: + key = jax.random.PRNGKey(0) + + spec = specs.BoundedArray((1, 2, 3), dtype, minimum, maximum) + test_sample = spec.sample(key) + spec.validate(test_sample) + + mask = jnp.zeros((1, 2, 3), bool) + with pytest.raises(ValueError): + spec.sample(key, mask) + def test_scalar_bounds(self) -> None: spec = specs.BoundedArray((), float, minimum=0.0, maximum=1.0) @@ -488,6 +529,18 @@ def test_properties(self) -> None: assert spec.dtype == jnp.int32 assert spec.num_values == num_values + def test_sample(self) -> None: + key = jax.random.PRNGKey(0) + + spec = specs.DiscreteArray(5) + test_sample = spec.sample(key) + spec.validate(test_sample) + + mask = jnp.array((False, False, True, False, False), bool) + test_sample = spec.sample(key, mask) + spec.validate(test_sample) + assert test_sample == 2 + def test_serialization(self) -> None: spec = specs.DiscreteArray(2, jnp.int32, "pickle_test") loaded_spec = pickle.loads(pickle.dumps(spec)) @@ -544,6 +597,42 @@ def test_properties(self) -> None: assert spec.dtype == jnp.int32 assert (spec.num_values == num_values).all() + @pytest.mark.parametrize( + "num_values, mask, mode, is_valid, expected_sample", + [ + ((5, 6), None, None, True, None), + ((5, 6), None, "table", False, None), + ((5, 6), jnp.zeros((3, 4), bool), None, False, None), + ((2, 2), jnp.zeros((2, 2), bool), None, False, None), + ((5, 5), jnp.zeros((5, 5), bool), "wrong", False, None), + ((5, 6), jnp.zeros((2, 5), bool), "table", False, None), + ((5, 5), jnp.zeros((2, 6), bool), "table", False, None), + ((5, 5), jnp.zeros((2, 5), bool), "index", False, None), + ((5, 5), jnp.zeros((2, 5), bool).at[:, 3].set(True), None, True, (3, 3)), + ((5, 6), jnp.zeros((5, 6), bool).at[4, 0].set(True), None, True, (4, 0)), + ((2, 2), jnp.zeros((2, 2), bool).at[0, 1].set(True), "table", True, (1, 0)), + ((2, 2), jnp.zeros((2, 2), bool).at[0, 1].set(True), "index", True, (0, 1)), + ], + ) + def test_sample( + self, + num_values: Tuple[int], + mask: chex.Array, + mode: Optional[Literal["table", "index"]], + is_valid: bool, + expected_sample: Tuple[int], + ) -> None: + key = jax.random.PRNGKey(4) + spec = specs.MultiDiscreteArray(jnp.array(num_values)) + if is_valid: + test_sample = spec.sample(key, mask, mode) + spec.validate(test_sample) + if expected_sample is not None: + chex.assert_trees_all_close(test_sample, jnp.array(expected_sample)) + else: + with pytest.raises(ValueError): + spec.sample(key, mask, mode) + def test_serialization(self) -> None: spec = specs.MultiDiscreteArray( jnp.array([5, 6], dtype=int), jnp.int32, "pickle_test" diff --git a/jumanji/testing/env_not_smoke.py b/jumanji/testing/env_not_smoke.py index 8a3cb34b4..ba9133037 100644 --- a/jumanji/testing/env_not_smoke.py +++ b/jumanji/testing/env_not_smoke.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, TypeVar, Union +from typing import Callable, Optional, TypeVar import chex import jax -import jax.numpy as jnp from jumanji import specs from jumanji.env import Environment @@ -26,42 +25,13 @@ SelectActionFn = Callable[[chex.PRNGKey, Observation], Action] -def make_random_select_action_fn( - action_spec: Union[ - specs.BoundedArray, specs.DiscreteArray, specs.MultiDiscreteArray - ] -) -> SelectActionFn: +def make_random_select_action_fn(action_spec: specs.Spec) -> SelectActionFn: """Create select action function that chooses random actions.""" - def select_action(key: chex.PRNGKey, state: chex.ArrayTree) -> chex.ArrayTree: - del state - if ( - isinstance(action_spec, specs.DiscreteArray) - or isinstance(action_spec, specs.MultiDiscreteArray) - or jnp.issubdtype(action_spec.dtype, jnp.integer) - ): - action = jax.random.randint( - key=key, - shape=action_spec.shape, - minval=action_spec.minimum, - maxval=action_spec.maximum + 1, - dtype=action_spec.dtype, - ) - elif isinstance(action_spec, specs.BoundedArray): - assert jnp.issubdtype(action_spec.dtype, jnp.floating) - action = jax.random.uniform( - key=key, - shape=action_spec.shape, - dtype=action_spec.dtype, - minval=action_spec.minimum, - maxval=action_spec.maximum, - ) - else: - raise ValueError( - "Only supported for action specs of type `specs.BoundedArray, " - "specs.DiscreteArray or specs.MultiDiscreteArray`." - ) - return action + def select_action(key: chex.PRNGKey, observation: chex.ArrayTree) -> chex.ArrayTree: + if hasattr(observation, "action_mask"): + return action_spec.sample(key, observation.action_mask) + return action_spec.sample(key) return select_action @@ -74,17 +44,8 @@ def check_env_does_not_smoke( """Run an episode of the environment, with a jitted step function to check no errors occur.""" action_spec = env.action_spec() if select_action is None: - if isinstance(action_spec, specs.BoundedArray) or isinstance( - action_spec, specs.DiscreteArray - ): - select_action = make_random_select_action_fn(action_spec) - else: - raise NotImplementedError( - f"Currently the `make_random_select_action_fn` only works for environments with " - f"either discrete actions or bounded continuous actions. The input environment to " - f"this test has an action spec of type {action_spec}, and therefore requires " - f"a custom `SelectActionFn` to be provided to this test." - ) + select_action = make_random_select_action_fn(action_spec) + key = jax.random.PRNGKey(0) key, reset_key = jax.random.split(key) state, timestep = env.reset(reset_key)