Merge branch 'opendilab:main' into q_transformner
rongkunxue authored Apr 23, 2024
2 parents 1839ded + 8392206 commit be60d5c
Showing 32 changed files with 800 additions and 88 deletions.
5 changes: 2 additions & 3 deletions README.md
@@ -276,7 +276,6 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
<details open>
<summary>(Click to Collapse)</summary>


| No | Environment | Label | Visualization | Code and Doc Links |
| :-: | :--------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| 1 | [Atari](https://github.com/openai/gym/tree/master/gym/envs/atari) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/atari/atari.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/atari/envs) <br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/atari.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/atari_zh.html) |
@@ -316,8 +315,8 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 35 | [metadrive](https://github.com/metadriverse/metadrive) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/metadrive/metadrive_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/metadrive/env)<br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/metadrive_zh.html) |
| 36 | [cliffwalking](https://github.com/openai/gym/blob/master/gym/envs/toy_text/cliffwalking.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/cliffwalking/cliff_walking.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/cliffwalking/envs)<br> env tutorial <br> 环境指南 |
| 37 | [tabmwp](https://promptpg.github.io/explore.html) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.jpeg) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp) <br> env tutorial <br> 环境指南 |
| 38 | [frozen_lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/frozen_lake/FrozenLake.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/frozen_lake) <br> env tutorial <br> 环境指南 |

| 38 | [frozen_lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/frozen_lake/FrozenLake.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/frozen_lake) <br> env tutorial <br> 环境指南 |
| 39 | [ising_model](https://github.com/mlii/mfrl/tree/master/examples/ising_model) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/ising_env/ising_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/ising_env) <br> env tutorial <br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/ising_model_zh.html) |

![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space

25 changes: 15 additions & 10 deletions ding/model/template/q_learning.py
@@ -26,16 +26,17 @@ class DQN(nn.Module):
"""

def __init__(
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
encoder_hidden_size_list: SequenceType = [128, 128, 64],
dueling: bool = True,
head_hidden_size: Optional[int] = None,
head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
dropout: Optional[float] = None
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
encoder_hidden_size_list: SequenceType = [128, 128, 64],
dueling: bool = True,
head_hidden_size: Optional[int] = None,
head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
dropout: Optional[float] = None,
init_bias: Optional[float] = None,
) -> None:
"""
Overview:
@@ -55,6 +56,7 @@ def __init__(
``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
- dropout (:obj:`Optional[float]`): The dropout rate of the dropout layer. \
if ``None`` then default disable dropout layer.
            - init_bias (:obj:`Optional[float]`): The initial value of the last layer bias in the head network.
"""
super(DQN, self).__init__()
# Squeeze data from tuple, list or dict to single object. For example, from (4, ) to 4
@@ -99,6 +101,9 @@ def __init__(
norm_type=norm_type,
dropout=dropout
)
if init_bias is not None and head_cls == DuelingHead:
            # Initialize the last layer bias of the advantage head to init_bias
self.head.A[-1][0].bias.data.fill_(init_bias)

def forward(self, x: torch.Tensor) -> Dict:
"""
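For reference, a minimal usage sketch of the new `init_bias` argument. This is a hypothetical illustration, assuming `DQN` is imported from `ding.model.template.q_learning` and a dueling head is used, as in the diff above:

```python
import torch
from ding.model.template.q_learning import DQN

# With dueling=True, init_bias fills the bias of the advantage head's
# last linear layer with the given constant; it is ignored otherwise.
model = DQN(obs_shape=8, action_shape=4, dueling=True, init_bias=0.0)

obs = torch.randn(2, 8)      # a batch of 2 flat observations
out = model(obs)
print(out['logit'].shape)    # expected: torch.Size([2, 4])
```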
13 changes: 11 additions & 2 deletions ding/policy/common_utils.py
@@ -61,8 +61,17 @@ def default_preprocess_learn(
reward = data['reward']
if len(reward.shape) == 1:
reward = reward.unsqueeze(1)
# reward: (batch_size, nstep) -> (nstep, batch_size)
data['reward'] = reward.permute(1, 0).contiguous()
# single agent reward: (batch_size, nstep) -> (nstep, batch_size)
# multi-agent reward: (batch_size, agent_dim, nstep) -> (nstep, batch_size, agent_dim)
# Assuming 'reward' is a PyTorch tensor with shape (batch_size, nstep) or (batch_size, agent_dim, nstep)
if reward.ndim == 2:
# For a 2D tensor, simply transpose it to get (nstep, batch_size)
data['reward'] = reward.transpose(0, 1).contiguous()
elif reward.ndim == 3:
# For a 3D tensor, move the last dimension to the front to get (nstep, batch_size, agent_dim)
data['reward'] = reward.permute(2, 0, 1).contiguous()
else:
raise ValueError("The 'reward' tensor must be either 2D or 3D. Got shape: {}".format(reward.shape))
else:
if data['reward'].dim() == 2 and data['reward'].shape[1] == 1:
data['reward'] = data['reward'].squeeze(-1)
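As a standalone sketch (not tied to DI-engine), the reward reshaping above can be exercised like this; the shapes follow the comments in the diff:

```python
import torch

def reshape_reward(reward: torch.Tensor) -> torch.Tensor:
    # Mirrors the logic above: move the nstep dimension to the front.
    if reward.ndim == 2:
        # single agent: (batch_size, nstep) -> (nstep, batch_size)
        return reward.transpose(0, 1).contiguous()
    elif reward.ndim == 3:
        # multi-agent: (batch_size, agent_dim, nstep) -> (nstep, batch_size, agent_dim)
        return reward.permute(2, 0, 1).contiguous()
    raise ValueError("The 'reward' tensor must be either 2D or 3D. Got shape: {}".format(reward.shape))

single = torch.randn(32, 3)      # batch of 32, nstep = 3
multi = torch.randn(32, 4, 3)    # batch of 32, 4 agents, nstep = 3
assert reshape_reward(single).shape == (3, 32)
assert reshape_reward(multi).shape == (3, 32, 4)
```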
19 changes: 12 additions & 7 deletions ding/rl_utils/adder.py
@@ -23,7 +23,8 @@ def get_gae(cls, data: List[Dict[str, Any]], last_value: torch.Tensor, gamma: fl
Overview:
Get GAE advantage for stacked transitions (T timestep, 1 batch). Call ``gae`` for calculation.
Arguments:
- data (:obj:`list`): Transitions list, each element is a transition dict with at least ['value', 'reward']
- data (:obj:`list`): Transitions list, each element is a transition dict with at least \
``['value', 'reward']``.
- last_value (:obj:`torch.Tensor`): The last value (i.e. the T+1 timestep)
- gamma (:obj:`float`): The future discount factor, should be in [0, 1], defaults to 0.99.
- gae_lambda (:obj:`float`): GAE lambda parameter, should be in [0, 1], defaults to 0.97, \
@@ -63,7 +64,7 @@ def get_gae_with_default_last_value(cls, data: deque, done: bool, gamma: float,
Overview:
Like ``get_gae`` above to get GAE advantage for stacked transitions. However, this function is designed in
case ``last_value`` is not passed. If the transition is not done yet, it would assign the last value in ``data``
as ``last_value``, discard the last element in ``data``(i.e. len(data) would decrease by 1), and then call
as ``last_value``, discard the last element in ``data`` (i.e. len(data) would decrease by 1), and then call
``get_gae``. Otherwise it would make ``last_value`` equal to 0.
Arguments:
- data (:obj:`deque`): Transitions list, each element is a transition dict with \
@@ -103,7 +104,7 @@ def get_nstep_return_data(
) -> deque:
"""
Overview:
Process raw traj data by updating keys ['next_obs', 'reward', 'done'] in data's dict element.
Process raw traj data by updating keys ``['next_obs', 'reward', 'done']`` in data's dict element.
Arguments:
- data (:obj:`deque`): Transitions list, each element is a transition dict
- nstep (:obj:`int`): Number of steps. If equals to 1, return ``data`` directly; \
@@ -121,7 +122,7 @@
"""
if nstep == 1:
return data
fake_reward = torch.zeros(1)
fake_reward = torch.zeros_like(data[0]['reward'])
next_obs_flag = 'next_obs' in data[0]
for i in range(len(data) - nstep):
# update keys ['next_obs', 'reward', 'done'] with their n-step value
@@ -130,7 +131,10 @@
if cum_reward:
data[i]['reward'] = sum([data[i + j]['reward'] * (gamma ** j) for j in range(nstep)])
else:
data[i]['reward'] = torch.cat([data[i + j]['reward'] for j in range(nstep)])
# data[i]['reward'].shape = (1) or (agent_num, 1)
# single agent env: shape (1) -> (n_step)
# multi-agent env: shape (agent_num, 1) -> (agent_num, n_step)
data[i]['reward'] = torch.cat([data[i + j]['reward'] for j in range(nstep)], dim=-1)
data[i]['done'] = data[i + nstep - 1]['done']
if correct_terminate_gamma:
data[i]['value_gamma'] = gamma ** nstep
@@ -142,7 +146,8 @@
else:
data[i]['reward'] = torch.cat(
[data[i + j]['reward']
for j in range(len(data) - i)] + [fake_reward for _ in range(nstep - (len(data) - i))]
for j in range(len(data) - i)] + [fake_reward for _ in range(nstep - (len(data) - i))],
dim=-1
)
data[i]['done'] = data[-1]['done']
if correct_terminate_gamma:
@@ -159,7 +164,7 @@ def get_train_sample(
) -> List[Dict[str, Any]]:
"""
Overview:
Process raw traj data by updating keys ['next_obs', 'reward', 'done'] in data's dict element.
Process raw traj data by updating keys ``['next_obs', 'reward', 'done']`` in data's dict element.
If ``unroll_len`` equals to 1, which means no process is needed, can directly return ``data``.
Otherwise, ``data`` will be split according to ``unroll_len``, the residual part will be processed according to
``last_fn_type``, and ``lists_to_dicts`` will be called to form the sampled training data.
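A hedged sketch of the updated n-step reward handling, assuming the methods above are classmethods of the `Adder` class in `ding/rl_utils/adder.py` and default values for `gamma`, `cum_reward`, and `correct_terminate_gamma`: per-step rewards are now concatenated along the last dimension, so both single-agent and multi-agent reward shapes work.

```python
import torch
from collections import deque
from ding.rl_utils.adder import Adder  # assumed location of the class shown above

# Single-agent transitions: each reward has shape (1,)
data = deque(
    {'obs': torch.randn(4), 'next_obs': torch.randn(4),
     'reward': torch.ones(1) * i, 'done': False}
    for i in range(6)
)
out = Adder.get_nstep_return_data(data, nstep=3)
print(out[0]['reward'].shape)    # expected: torch.Size([3]) -- (n_step,)

# Multi-agent transitions: each reward has shape (agent_num, 1)
data = deque(
    {'obs': torch.randn(4), 'next_obs': torch.randn(4),
     'reward': torch.ones(2, 1) * i, 'done': False}
    for i in range(6)
)
out = Adder.get_nstep_return_data(data, nstep=3)
print(out[0]['reward'].shape)    # expected: torch.Size([2, 3]) -- (agent_num, n_step)
```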
27 changes: 27 additions & 0 deletions ding/rl_utils/beta_function.py
@@ -14,6 +14,15 @@
# For CPW, eta = 0.71 most closely matches human subjects
# this function is locally concave for small values of τ and becomes locally convex for larger values of τ
def cpw(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor, float]:
"""
Overview:
The implementation of CPW function.
Arguments:
- x (:obj:`Union[torch.Tensor, float]`): The input value.
- eta (:obj:`float`): The hyperparameter of CPW function.
Returns:
- output (:obj:`Union[torch.Tensor, float]`): The output value.
"""
return (x ** eta) / ((x ** eta + (1 - x) ** eta) ** (1 / eta))


@@ -22,6 +31,15 @@ def cpw(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor,

# CVaR is risk-averse
def CVaR(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor, float]:
"""
Overview:
The implementation of CVaR function, which is a risk-averse function.
Arguments:
- x (:obj:`Union[torch.Tensor, float]`): The input value.
- eta (:obj:`float`): The hyperparameter of CVaR function.
Returns:
- output (:obj:`Union[torch.Tensor, float]`): The output value.
"""
assert eta <= 1.0
return x * eta

@@ -31,6 +49,15 @@ def CVaR(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor

# risk-averse (eta < 0) or risk-seeking (eta > 0)
def Pow(x: Union[torch.Tensor, float], eta: float = 0.0) -> Union[torch.Tensor, float]:
"""
Overview:
The implementation of Pow function, which is risk-averse when eta < 0 and risk-seeking when eta > 0.
Arguments:
- x (:obj:`Union[torch.Tensor, float]`): The input value.
- eta (:obj:`float`): The hyperparameter of Pow function.
Returns:
- output (:obj:`Union[torch.Tensor, float]`): The output value.
"""
if eta >= 0:
return x ** (1 / (1 + eta))
else:
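A brief usage sketch of the three distortion functions documented above, applied to a small batch of quantile fractions (e.g., for risk-sensitive value-distribution sampling); the import path simply follows the file location in this diff:

```python
import torch
from ding.rl_utils.beta_function import cpw, CVaR, Pow

tau = torch.linspace(0.05, 0.95, steps=5)   # quantile fractions in (0, 1)

print(cpw(tau))            # CPW (eta=0.71): concave for small tau, convex for large tau
print(CVaR(tau, eta=0.5))  # risk-averse: rescales tau into [0, 0.5]
print(Pow(tau, eta=-1.0))  # eta < 0 -> risk-averse power distortion
```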