Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PPO maskable type annotations #233

Merged
merged 7 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to MaskablePPO

Bug Fixes:
^^^^^^^^^^
Expand All @@ -31,8 +32,10 @@ Deprecations:

Others:
^^^^^^^

- Fixed ``train_freq`` type annotation for tqc and qrdqn (@Armandpl)
- Fixed ``sb3_contrib/common/maskable/*.py`` type annotations
- Fixed ``sb3_contrib/ppo_mask/ppo_mask.py`` type annotations
- Fixed ``sb3_contrib/common/vec_env/async_eval.py`` type annotations

Documentation:
^^^^^^^^^^^^^^
Expand Down
6 changes: 0 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ exclude = """(?x)(
| sb3_contrib/ars/ars.py$
| sb3_contrib/common/recurrent/policies.py$
| sb3_contrib/common/recurrent/buffers.py$
| sb3_contrib/common/maskable/distributions.py$
| sb3_contrib/common/maskable/callbacks.py$
| sb3_contrib/common/maskable/policies.py$
| sb3_contrib/common/maskable/buffers.py$
| sb3_contrib/common/vec_env/async_eval.py$
| sb3_contrib/ppo_mask/ppo_mask.py$
| tests/test_train_eval_mode.py$
)"""

Expand Down
26 changes: 18 additions & 8 deletions sb3_contrib/common/maskable/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MaskableRolloutBufferSamples(NamedTuple):
action_masks: th.Tensor


class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples):
class MaskableDictRolloutBufferSamples(NamedTuple):
observations: TensorDict
actions: th.Tensor
old_values: th.Tensor
Expand All @@ -42,6 +42,8 @@ class MaskableRolloutBuffer(RolloutBuffer):
:param n_envs: Number of parallel environments
"""

action_masks: np.ndarray

def __init__(
self,
buffer_size: int,
Expand All @@ -53,14 +55,17 @@ def __init__(
n_envs: int = 1,
):
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
self.action_masks = None

def reset(self) -> None:
if isinstance(self.action_space, spaces.Discrete):
mask_dims = self.action_space.n
elif isinstance(self.action_space, spaces.MultiDiscrete):
mask_dims = sum(self.action_space.nvec)
elif isinstance(self.action_space, spaces.MultiBinary):
assert isinstance(self.action_space.n, int), (
f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
"You can flatten it instead."
)
mask_dims = 2 * self.action_space.n # One mask per binary outcome
else:
raise ValueError(f"Unsupported action space {type(self.action_space)}")
Expand All @@ -79,7 +84,7 @@ def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> Non

super().add(*args, **kwargs)

def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]: # type: ignore[override]
assert self.full, ""
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
Expand All @@ -105,7 +110,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBuff
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size

def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples: # type: ignore[override]
data = (
self.observations[batch_inds],
self.actions[batch_inds],
Expand Down Expand Up @@ -143,17 +148,18 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
:param n_envs: Number of parallel environments
"""

action_masks: np.ndarray

def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
observation_space: spaces.Dict,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
self.action_masks = None
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)

def reset(self) -> None:
Expand All @@ -162,6 +168,10 @@ def reset(self) -> None:
elif isinstance(self.action_space, spaces.MultiDiscrete):
mask_dims = sum(self.action_space.nvec)
elif isinstance(self.action_space, spaces.MultiBinary):
assert isinstance(self.action_space.n, int), (
f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
"You can flatten it instead."
)
mask_dims = 2 * self.action_space.n # One mask per binary outcome
else:
raise ValueError(f"Unsupported action space {type(self.action_space)}")
Expand All @@ -180,7 +190,7 @@ def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> Non

super().add(*args, **kwargs)

def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: # type: ignore[override]
assert self.full, ""
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
Expand All @@ -203,7 +213,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRollout
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size

def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples: # type: ignore[override]
return MaskableDictRolloutBufferSamples(
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
actions=self.to_torch(self.actions[batch_inds]),
Expand Down
8 changes: 5 additions & 3 deletions sb3_contrib/common/maskable/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _on_step(self) -> bool:

# Note that evaluate_policy() has been patched to support masking
episode_rewards, episode_lengths = evaluate_policy(
self.model,
self.model, # type: ignore[arg-type]
self.eval_env,
n_eval_episodes=self.n_eval_episodes,
render=self.render,
Expand All @@ -67,6 +67,8 @@ def _on_step(self) -> bool:
)

if self.log_path is not None:
assert isinstance(episode_rewards, list)
assert isinstance(episode_lengths, list)
self.evaluations_timesteps.append(self.num_timesteps)
self.evaluations_results.append(episode_rewards)
self.evaluations_length.append(episode_lengths)
Expand All @@ -87,7 +89,7 @@ def _on_step(self) -> bool:

mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
self.last_mean_reward = mean_reward
self.last_mean_reward = float(mean_reward)

if self.verbose > 0:
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
Expand All @@ -111,7 +113,7 @@ def _on_step(self) -> bool:
print("New best mean reward!")
if self.best_model_save_path is not None:
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
self.best_mean_reward = mean_reward
self.best_mean_reward = float(mean_reward)
# Trigger callback on new best model, if needed
if self.callback_on_new_best is not None:
continue_training = self.callback_on_new_best.on_step()
Expand Down
37 changes: 23 additions & 14 deletions sb3_contrib/common/maskable/distributions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, TypeVar
from typing import List, Optional, Tuple, TypeVar, Union

import numpy as np
import torch as th
Expand All @@ -13,6 +13,7 @@
SelfMaskableMultiCategoricalDistribution = TypeVar(
"SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
)
MaybeMasks = Union[th.Tensor, np.ndarray, None]


class MaskableCategorical(Categorical):
Expand All @@ -36,14 +37,14 @@ def __init__(
probs: Optional[th.Tensor] = None,
logits: Optional[th.Tensor] = None,
validate_args: Optional[bool] = None,
masks: Optional[np.ndarray] = None,
masks: MaybeMasks = None,
):
self.masks: Optional[th.Tensor] = None
super().__init__(probs, logits, validate_args)
self._original_logits = self.logits
self.apply_masking(masks)

def apply_masking(self, masks: Optional[np.ndarray]) -> None:
def apply_masking(self, masks: MaybeMasks) -> None:
"""
Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.

Expand Down Expand Up @@ -84,7 +85,7 @@ def entropy(self) -> th.Tensor:

class MaskableDistribution(Distribution, ABC):
@abstractmethod
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
def apply_masking(self, masks: MaybeMasks) -> None:
"""
Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0.

Expand All @@ -94,6 +95,13 @@ def apply_masking(self, masks: Optional[np.ndarray]) -> None:
previously applied masking is removed, and the original logits are restored.
"""

@abstractmethod
def proba_distribution_net(self, *args, **kwargs) -> nn.Module:
"""Create the layers and parameters that represent the distribution.

Subclasses must define this, but the arguments and return type vary between
concrete classes."""


class MaskableCategoricalDistribution(MaskableDistribution):
"""
Expand Down Expand Up @@ -154,7 +162,7 @@ def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.
log_prob = self.log_prob(actions)
return actions, log_prob

def apply_masking(self, masks: Optional[np.ndarray]) -> None:
def apply_masking(self, masks: MaybeMasks) -> None:
assert self.distribution is not None, "Must set distribution parameters"
self.distribution.apply_masking(masks)

Expand Down Expand Up @@ -192,7 +200,7 @@ def proba_distribution(
reshaped_logits = action_logits.view(-1, sum(self.action_dims))

self.distributions = [
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, tuple(self.action_dims), dim=1)
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, list(self.action_dims), dim=1)
]
return self

Expand Down Expand Up @@ -229,18 +237,16 @@ def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.
log_prob = self.log_prob(actions)
return actions, log_prob

def apply_masking(self, masks: Optional[np.ndarray]) -> None:
def apply_masking(self, masks: MaybeMasks) -> None:
assert len(self.distributions) > 0, "Must set distribution parameters"

split_masks = [None] * len(self.distributions)
if masks is not None:
masks = th.as_tensor(masks)

masks_tensor = th.as_tensor(masks)
# Restructure shape to align with logits
masks = masks.view(-1, sum(self.action_dims))

masks_tensor = masks_tensor.view(-1, sum(self.action_dims))
# Then split columnwise for each discrete action
split_masks = th.split(masks, tuple(self.action_dims), dim=1)
split_masks = th.split(masks_tensor, list(self.action_dims), dim=1) # type: ignore[assignment]

for distribution, mask in zip(self.distributions, split_masks):
distribution.apply_masking(mask)
Expand Down Expand Up @@ -268,10 +274,13 @@ def make_masked_proba_distribution(action_space: spaces.Space) -> MaskableDistri
"""

if isinstance(action_space, spaces.Discrete):
return MaskableCategoricalDistribution(action_space.n)
return MaskableCategoricalDistribution(int(action_space.n))
elif isinstance(action_space, spaces.MultiDiscrete):
return MaskableMultiCategoricalDistribution(action_space.nvec)
return MaskableMultiCategoricalDistribution(list(action_space.nvec))
elif isinstance(action_space, spaces.MultiBinary):
assert isinstance(
action_space.n, int
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
return MaskableBernoulliDistribution(action_space.n)
else:
raise NotImplementedError(
Expand Down
Loading
Loading