diff --git a/gymnasium/spaces/discrete.py b/gymnasium/spaces/discrete.py
index 9a4575252..54e7616ee 100644
--- a/gymnasium/spaces/discrete.py
+++ b/gymnasium/spaces/discrete.py
@@ -27,6 +27,7 @@ class Discrete(Space[np.int64]):
     def __init__(
         self,
         n: int | np.integer[Any],
         seed: int | np.random.Generator | None = None,
         start: int | np.integer[Any] = 0,
+        dtype: str | type[np.integer[Any]] = np.int64,
     ):
@@ -36,6 +37,7 @@ def __init__(
 
         Args:
             n (int): The number of elements of this space.
             seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Dict`` space.
             start (int): The smallest element of this space.
+            dtype: The integer dtype of the space's elements and samples, for example, ``int``, ``np.int64``, ``np.int32``, or ``np.uint8``.
         """
@@ -47,16 +49,27 @@ def __init__(
             type(start), np.integer
         ), f"Expects `start` to be an integer, actual type: {type(start)}"
 
-        self.n = np.int64(n)
-        self.start = np.int64(start)
-        super().__init__((), np.int64, seed)
+        # np.dtype(None) silently resolves to float64, so reject None explicitly
+        if dtype is None:
+            raise ValueError(f"Invalid Discrete dtype ({dtype}), cannot be None.")
+        self.dtype = np.dtype(dtype)
+
+        # only integer dtypes can represent the elements of a discrete space
+        if not (np.issubdtype(self.dtype, np.integer)):
+            raise ValueError(
+                f"Invalid Discrete dtype ({self.dtype}), must be an integer dtype"
+            )
+
+        self.n = self.dtype.type(n)
+        self.start = self.dtype.type(start)
+        super().__init__((), self.dtype, seed)
 
     @property
     def is_np_flattenable(self):
         """Checks whether this space can be flattened to a :class:`spaces.Box`."""
         return True
 
-    def sample(self, mask: MaskNDArray | None = None) -> np.int64:
+    def sample(self, mask: MaskNDArray | None = None) -> np.integer[Any]:
         """Generates a single random sample from this space.
 
         A sample will be chosen uniformly at random with the mask if provided
@@ -84,13 +97,13 @@ def sample(self, mask: MaskNDArray | None = None) -> np.int64:
                 np.logical_or(mask == 0, valid_action_mask)
             ), f"All values of a mask should be 0 or 1, actual values: {mask}"
             if np.any(valid_action_mask):
-                return self.start + self.np_random.choice(
-                    np.where(valid_action_mask)[0]
+                return self.start + self.dtype.type(
+                    self.np_random.choice(np.where(valid_action_mask)[0])
                 )
             else:
                 return self.start
 
-        return self.start + self.np_random.integers(self.n)
+        return self.start + self.np_random.integers(self.n).astype(self.dtype)
 
     def contains(self, x: Any) -> bool:
         """Return boolean specifying if x is a valid member of this space."""
@@ -137,7 +150,7 @@ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
 
         super().__setstate__(state)
 
-    def to_jsonable(self, sample_n: Sequence[np.int64]) -> list[int]:
+    def to_jsonable(self, sample_n: Sequence[np.integer[Any]]) -> list[int]:
         """Converts a list of samples to a list of ints."""
         return [int(x) for x in sample_n]
 
diff --git a/tests/spaces/test_discrete.py b/tests/spaces/test_discrete.py
index 71c4fcf51..c94905795 100644
--- a/tests/spaces/test_discrete.py
+++ b/tests/spaces/test_discrete.py
@@ -1,6 +1,7 @@
 from copy import deepcopy
 
 import numpy as np
+import pytest
 
 from gymnasium.spaces import Discrete
 
@@ -32,3 +33,35 @@ def test_sample_mask():
     assert space.sample(mask=np.array([0, 1, 0, 0], dtype=np.int8)) == 3
     assert space.sample(mask=np.array([0, 0, 0, 0], dtype=np.int8)) == 2
     assert space.sample(mask=np.array([0, 1, 0, 1], dtype=np.int8)) in [3, 5]
+
+
+@pytest.mark.parametrize(
+    "dtype, sample_dtype",
+    [
+        (int, np.int64),
+        (np.int64, np.int64),
+        (np.int32, np.int32),
+        (np.uint8, np.uint8),
+    ],
+)
+def test_dtype(dtype, sample_dtype):
+    space = Discrete(n=5, dtype=dtype, start=2)
+
+    sample = space.sample()
+    sample_mask = space.sample(mask=np.array([0, 1, 0, 0, 0], dtype=np.int8))
+    assert isinstance(sample, sample_dtype), type(sample)
+    assert isinstance(sample_mask, sample_dtype), type(sample_mask)
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        None,
+        str,
+        np.float32,
+        np.complex64,
+    ],
+)
+def test_dtype_error(dtype):
+    with pytest.raises(ValueError, match="Invalid Discrete dtype"):
+        Discrete(4, dtype=dtype)