
Commit 96c4955

polish(nyz): polish rl_utils api docs
1 parent 15ff277 commit 96c4955

File tree: 6 files changed, +78 −53 lines changed


ding/rl_utils/adder.py

+5 −4
@@ -23,7 +23,8 @@ def get_gae(cls, data: List[Dict[str, Any]], last_value: torch.Tensor, gamma: fl
  Overview:
  Get GAE advantage for stacked transitions(T timestep, 1 batch). Call ``gae`` for calculation.
  Arguments:
- - data (:obj:`list`): Transitions list, each element is a transition dict with at least ['value', 'reward']
+ - data (:obj:`list`): Transitions list, each element is a transition dict with at least \
+   ``['value', 'reward']``.
  - last_value (:obj:`torch.Tensor`): The last value(i.e.: the T+1 timestep)
  - gamma (:obj:`float`): The future discount factor, should be in [0, 1], defaults to 0.99.
  - gae_lambda (:obj:`float`): GAE lambda parameter, should be in [0, 1], defaults to 0.97, \
@@ -63,7 +64,7 @@ def get_gae_with_default_last_value(cls, data: deque, done: bool, gamma: float,
  Overview:
  Like ``get_gae`` above to get GAE advantage for stacked transitions. However, this function is designed in
  case ``last_value`` is not passed. If transition is not done yet, it wouold assign last value in ``data``
- as ``last_value``, discard the last element in ``data``(i.e. len(data) would decrease by 1), and then call
+ as ``last_value``, discard the last element in ``data`` (i.e. len(data) would decrease by 1), and then call
  ``get_gae``. Otherwise it would make ``last_value`` equal to 0.
  Arguments:
  - data (:obj:`deque`): Transitions list, each element is a transition dict with \
@@ -103,7 +104,7 @@ def get_nstep_return_data(
  ) -> deque:
  """
  Overview:
- Process raw traj data by updating keys ['next_obs', 'reward', 'done'] in data's dict element.
+ Process raw traj data by updating keys ``['next_obs', 'reward', 'done']`` in data's dict element.
  Arguments:
  - data (:obj:`deque`): Transitions list, each element is a transition dict
  - nstep (:obj:`int`): Number of steps. If equals to 1, return ``data`` directly; \
@@ -159,7 +160,7 @@ def get_train_sample(
  ) -> List[Dict[str, Any]]:
  """
  Overview:
- Process raw traj data by updating keys ['next_obs', 'reward', 'done'] in data's dict element.
+ Process raw traj data by updating keys ``['next_obs', 'reward', 'done']`` in data's dict element.
  If ``unroll_len`` equals to 1, which means no process is needed, can directly return ``data``.
  Otherwise, ``data`` will be splitted according to ``unroll_len``, process residual part according to
  ``last_fn_type`` and call ``lists_to_dicts`` to form sampled training data.
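
For reference, the GAE recursion these ``adder`` docstrings describe reduces to a short backward scan over TD residuals. A minimal standalone sketch, assuming each transition dict carries 0-dim ``value`` and ``reward`` tensors (an illustration of the math, not the repo's exact implementation):

```python
from typing import Any, Dict, List

import torch


def gae_sketch(data: List[Dict[str, Any]], last_value: torch.Tensor,
               gamma: float = 0.99, gae_lambda: float = 0.97) -> torch.Tensor:
    # stack the T values plus the value at timestep T+1 -> shape (T+1,)
    values = torch.stack([d['value'] for d in data] + [last_value])
    rewards = torch.stack([d['reward'] for d in data])   # shape (T,)
    deltas = rewards + gamma * values[1:] - values[:-1]  # one-step TD residuals
    adv, gae = torch.empty_like(rewards), 0.
    for t in reversed(range(len(data))):
        gae = deltas[t] + gamma * gae_lambda * gae  # backward recursion
        adv[t] = gae
    return adv
```

With ``gae_lambda=0`` this degenerates to the one-step TD residual, and with ``gae_lambda=1`` it becomes the full discounted return minus the value baseline.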

ding/rl_utils/beta_function.py

+27 −0
@@ -14,6 +14,15 @@
  # For CPW, eta = 0.71 most closely match human subjects
  # this function is locally concave for small values of τ and becomes locally convex for larger values of τ
  def cpw(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor, float]:
+ """
+ Overview:
+ The implementation of CPW function.
+ Arguments:
+ - x (:obj:`Union[torch.Tensor, float]`): The input value.
+ - eta (:obj:`float`): The hyperparameter of CPW function.
+ Returns:
+ - output (:obj:`Union[torch.Tensor, float]`): The output value.
+ """
  return (x ** eta) / ((x ** eta + (1 - x) ** eta) ** (1 / eta))


@@ -22,6 +31,15 @@ def cpw(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor,

  # CVaR is risk-averse
  def CVaR(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor, float]:
+ """
+ Overview:
+ The implementation of CVaR function, which is a risk-averse function.
+ Arguments:
+ - x (:obj:`Union[torch.Tensor, float]`): The input value.
+ - eta (:obj:`float`): The hyperparameter of CVaR function.
+ Returns:
+ - output (:obj:`Union[torch.Tensor, float]`): The output value.
+ """
  assert eta <= 1.0
  return x * eta

@@ -31,6 +49,15 @@ def CVaR(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor

  # risk-averse (eta < 0) or risk-seeking (eta > 0)
  def Pow(x: Union[torch.Tensor, float], eta: float = 0.0) -> Union[torch.Tensor, float]:
+ """
+ Overview:
+ The implementation of Pow function, which is risk-averse when eta < 0 and risk-seeking when eta > 0.
+ Arguments:
+ - x (:obj:`Union[torch.Tensor, float]`): The input value.
+ - eta (:obj:`float`): The hyperparameter of Pow function.
+ Returns:
+ - output (:obj:`Union[torch.Tensor, float]`): The output value.
+ """
  if eta >= 0:
  return x ** (1 / (1 + eta))
  else:
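
With the three beta functions now documented, a quick usage sketch (assuming the module-level functions are importable from ``ding.rl_utils.beta_function``, with τ being quantile fractions in [0, 1] as in distributional RL):

```python
import torch

from ding.rl_utils.beta_function import CVaR, Pow, cpw

tau = torch.linspace(0.05, 0.95, 10)  # quantile fractions in [0, 1]
print(cpw(tau))            # eta=0.71: concave for small tau, convex for large tau
print(CVaR(tau, eta=0.5))  # risk-averse: rescales tau toward 0
print(Pow(tau, eta=0.0))   # eta=0 takes the x ** (1 / (1 + eta)) branch: identity
```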

ding/rl_utils/exploration.py

+30 −29
@@ -12,13 +12,13 @@ def get_epsilon_greedy_fn(start: float, end: float, decay: int, type_: str = 'ex
  Overview:
  Generate an epsilon_greedy function with decay, which inputs current timestep and outputs current epsilon.
  Arguments:
- - start (:obj:`float`): Epsilon start value. For 'linear', it should be 1.0.
+ - start (:obj:`float`): Epsilon start value. For ``linear`` , it should be 1.0.
  - end (:obj:`float`): Epsilon end value.
  - decay (:obj:`int`): Controls the speed that epsilon decreases from ``start`` to ``end``. \
  We recommend epsilon decays according to env step rather than iteration.
- - type (:obj:`str`): How epsilon decays, now supports ['linear', 'exp'(exponential)]
+ - type (:obj:`str`): How epsilon decays, now supports ``['linear', 'exp'(exponential)]`` .
  Returns:
- - eps_fn (:obj:`function`): The epsilon greedy function with decay
+ - eps_fn (:obj:`function`): The epsilon greedy function with decay.
  """
  assert type_ in ['linear', 'exp'], type_
  if type_ == 'exp':
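
For reference, the two decay schedules named in this docstring can be sketched standalone; the exact in-repo formulas may differ slightly, and the exponential time constant below is an assumption:

```python
import math
from typing import Callable


def epsilon_greedy_fn_sketch(start: float, end: float, decay: int,
                             type_: str = 'exp') -> Callable[[int], float]:
    assert type_ in ['linear', 'exp'], type_
    if type_ == 'exp':
        # exponential decay from start toward end with time constant ``decay``
        return lambda step: (start - end) * math.exp(-step / decay) + end
    # linear decay reaches ``end`` after ``decay`` env steps, then stays there
    return lambda step: max(end, start - (start - end) * step / decay)


eps_fn = epsilon_greedy_fn_sketch(start=0.95, end=0.05, decay=10000)
print(eps_fn(0), eps_fn(10000))  # 0.95 -> ~0.38 with 'exp'
```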
@@ -48,27 +48,27 @@ class BaseNoise(ABC):
  def __init__(self) -> None:
  """
  Overview:
- Initialization method
+ Initialization method.
  """
  super().__init__()

  @abstractmethod
  def __call__(self, shape: tuple, device: str) -> torch.Tensor:
  """
  Overview:
- Generate noise according to action tensor's shape, device
+ Generate noise according to action tensor's shape, device.
  Arguments:
- - shape (:obj:`tuple`): size of the action tensor, output noise's size should be the same
- - device (:obj:`str`): device of the action tensor, output noise's device should be the same as it
+ - shape (:obj:`tuple`): size of the action tensor, output noise's size should be the same.
+ - device (:obj:`str`): device of the action tensor, output noise's device should be the same as it.
  Returns:
  - noise (:obj:`torch.Tensor`): generated action noise, \
- have the same shape and device with the input action tensor
+ have the same shape and device with the input action tensor.
  """
  raise NotImplementedError


  class GaussianNoise(BaseNoise):
- r"""
+ """
  Overview:
  Derived class for generating gaussian noise, which satisfies :math:`X \sim N(\mu, \sigma^2)`
  Interface:
@@ -78,10 +78,10 @@ class GaussianNoise(BaseNoise):
  def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None:
  """
  Overview:
- Initialize :math:`\mu` and :math:`\sigma` in Gaussian Distribution
+ Initialize :math:`\mu` and :math:`\sigma` in Gaussian Distribution.
  Arguments:
- - mu (:obj:`float`): :math:`\mu` , mean value
- - sigma (:obj:`float`): :math:`\sigma` , standard deviation, should be positive
+ - mu (:obj:`float`): :math:`\mu` , mean value.
+ - sigma (:obj:`float`): :math:`\sigma` , standard deviation, should be positive.
  """
  super(GaussianNoise, self).__init__()
  self._mu = mu
@@ -125,14 +125,15 @@ def __init__(
  """
  Overview:
  Initialize ``_alpha`` :math:`=\theta * dt\`,
- ``beta`` :math:`= \sigma * \sqrt{dt}`, in Ornstein-Uhlenbeck process
+ ``beta`` :math:`= \sigma * \sqrt{dt}`, in Ornstein-Uhlenbeck process.
  Arguments:
- - mu (:obj:`float`): :math:`\mu` , mean value
- - sigma (:obj:`float`): :math:`\sigma` , standard deviation of the perturbation noise
- - theta (:obj:`float`): how strongly the noise reacts to perturbations, \
- greater value means stronger reaction
- - dt (:obj:`float`): derivative of time t
- - x0 (:obj:`float` or :obj:`torch.Tensor`): initial action
+ - mu (:obj:`float`): :math:`\mu` , mean value.
+ - sigma (:obj:`float`): :math:`\sigma` , standard deviation of the perturbation noise.
+ - theta (:obj:`float`): How strongly the noise reacts to perturbations, \
+ greater value means stronger reaction.
+ - dt (:obj:`float`): The derivative of time t.
+ - x0 (:obj:`Union[float, torch.Tensor]`): The initial state of the noise, \
+ should be a scalar or tensor with the same shape as the action tensor.
  """
  super().__init__()
  self._mu = mu
@@ -144,21 +145,21 @@ def __init__(
  def reset(self) -> None:
  """
  Overview:
- Reset ``_x`` to the initial state ``_x0``
+ Reset ``_x`` to the initial state ``_x0``.
  """
  self._x = deepcopy(self._x0)

  def __call__(self, shape: tuple, device: str, mu: Optional[float] = None) -> torch.Tensor:
  """
  Overview:
- Generate gaussian noise according to action tensor's shape, device
+ Generate gaussian noise according to action tensor's shape, device.
  Arguments:
- - shape (:obj:`tuple`): size of the action tensor, output noise's size should be the same
- - device (:obj:`str`): device of the action tensor, output noise's device should be the same as it
- - mu (:obj:`float`): new mean value :math:`\mu`, you can set it to `None` if don't need it
+ - shape (:obj:`tuple`): The size of the action tensor, output noise's size should be the same.
+ - device (:obj:`str`): The device of the action tensor, output noise's device should be the same as it.
+ - mu (:obj:`float`): The new mean value :math:`\mu`, you can set it to `None` if don't need it.
  Returns:
  - noise (:obj:`torch.Tensor`): generated action noise, \
- have the same shape and device with the input action tensor
+ have the same shape and device with the input action tensor.
  """
  if self._x is None or \
  (isinstance(self._x, torch.Tensor) and self._x.shape != shape):
@@ -174,15 +175,15 @@ def __call__(self, shape: tuple, device: str, mu: Optional[float] = None) -> tor
  def x0(self) -> Union[float, torch.Tensor]:
  """
  Overview:
- Get ``self._x0``
+ Get ``self._x0``.
  """
  return self._x0

  @x0.setter
  def x0(self, _x0: Union[float, torch.Tensor]) -> None:
  """
  Overview:
- Set ``self._x0`` and reset ``self.x`` to ``self._x0`` as well
+ Set ``self._x0`` and reset ``self.x`` to ``self._x0`` as well.
  """
  self._x0 = _x0
  self.reset()
@@ -198,10 +199,10 @@ def create_noise_generator(noise_type: str, noise_kwargs: dict) -> BaseNoise:
  or raise an KeyError. In other words, a derived noise generator must first register,
  then call ``create_noise generator`` to get the instance object.
  Arguments:
- - noise_type (:obj:`str`): the type of noise generator to be created
+ - noise_type (:obj:`str`): the type of noise generator to be created.
  Returns:
  - noise (:obj:`BaseNoise`): the created new noise generator, should be an instance of one of \
- noise_mapping's values
+ noise_mapping's values.
  """
  if noise_type not in noise_mapping.keys():
  raise KeyError("not support noise type: {}".format(noise_type))
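
To make the Ornstein-Uhlenbeck docstrings concrete, here is a hedged sketch of the discretized OU update built from the documented coefficients ``_alpha = theta * dt`` and ``beta = sigma * sqrt(dt)``; the class name and defaults below are illustrative, not the repo's exact API:

```python
import torch


class OUNoiseSketch:
    """Mean-reverting noise: x <- x + alpha * (mu - x) + beta * N(0, 1)."""

    def __init__(self, mu: float = 0.0, sigma: float = 0.3,
                 theta: float = 0.15, dt: float = 1e-2, x0: float = 0.0) -> None:
        self._alpha = theta * dt          # strength of the pull back toward mu
        self._beta = sigma * (dt ** 0.5)  # scale of the white-noise kick
        self._mu, self._x0 = mu, x0
        self._x = x0

    def reset(self) -> None:
        self._x = self._x0

    def __call__(self, shape: tuple, device: str = 'cpu') -> torch.Tensor:
        noise = torch.randn(shape, device=device)
        self._x = self._x + self._alpha * (self._mu - self._x) + self._beta * noise
        return self._x


ou = OUNoiseSketch()
action_noise = ou(shape=(4, ))  # same shape/device as the action tensor
```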

ding/rl_utils/td.py

+8 −10
@@ -578,7 +578,7 @@ def v_nstep_td_error(
  nstep: int = 1,
  criterion: torch.nn.modules = nn.MSELoss(reduction='none')  # noqa
  ) -> torch.Tensor:
- r"""
+ """
  Overview:
  Multistep (n step) td_error for distributed value based algorithm
  Arguments:
@@ -588,14 +588,14 @@ def v_nstep_td_error(
  Returns:
  - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
  Shapes:
- - data (:obj:`dist_nstep_td_data`): The v_nstep_td_data containing\
+ - data (:obj:`dist_nstep_td_data`): The v_nstep_td_data containing \
  ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma']
  - v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
  - next_v (:obj:`torch.FloatTensor`): :math:`(B, )`
  - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
  - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
  - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
- - value_gamma (:obj:`torch.Tensor`): If the remaining data in the buffer is less than n_step\
+ - value_gamma (:obj:`torch.Tensor`): If the remaining data in the buffer is less than n_step \
  we use value_gamma as the gamma discount value for next_v rather than gamma**n_step
  Examples:
  >>> v = torch.randn(5).requires_grad_(True)
@@ -1098,7 +1098,7 @@ def qrdqn_nstep_td_error(
  Overview:
  Multistep (1 step or n step) td_error with in QRDQN
  Arguments:
- - data (:obj:`iqn_nstep_td_data`): The input data, iqn_nstep_td_data to calculate loss
+ - data (:obj:`qrdqn_nstep_td_data`): The input data, qrdqn_nstep_td_data to calculate loss
  - gamma (:obj:`float`): Discount factor
  - nstep (:obj:`int`): nstep num, default set to 1
  Returns:
@@ -1605,18 +1605,16 @@ def multistep_forward_view(
  lambda_: float,
  done: Optional[torch.Tensor] = None
  ) -> torch.Tensor:
- r"""
+ """
  Overview:
- Same as trfl.sequence_ops.multistep_forward_view
- Implementing (12.18) in Sutton & Barto
+ Same as trfl.sequence_ops.multistep_forward_view, which implements (12.18) in Sutton & Barto.
+ Assuming the first dim of input tensors correspond to the index in batch.

- ```
+ .. note::
  result[T-1] = rewards[T-1] + gammas[T-1] * bootstrap_values[T]
  for t in 0...T-2 :
  result[t] = rewards[t] + gammas[t]*(lambdas[t]*result[t+1] + (1-lambdas[t])*bootstrap_values[t+1])
- ```
- Assuming the first dim of input tensors correspond to the index in batch
  Arguments:
  - bootstrap_values (:obj:`torch.Tensor`): Estimation of the value at *step 1 to T*, of size [T_traj, batchsize]
  - rewards (:obj:`torch.Tensor`): The returns from 0 to T-1, of size [T_traj, batchsize]
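
The reshaped note above is easier to follow next to a concrete scan. A standalone sketch with a scalar ``lambda_`` (matching the signature in this hunk), under the indexing assumption that ``bootstrap_values[t]`` estimates the value at step ``t + 1`` ("value at step 1 to T"):

```python
import torch


def multistep_forward_view_sketch(bootstrap_values: torch.Tensor,
                                  rewards: torch.Tensor,
                                  gammas: torch.Tensor,
                                  lambda_: float) -> torch.Tensor:
    # all inputs are [T_traj, batch_size]
    T = rewards.shape[0]
    result = torch.empty_like(rewards)
    # last step: plain one-step bootstrap
    result[T - 1] = rewards[T - 1] + gammas[T - 1] * bootstrap_values[T - 1]
    for t in reversed(range(T - 1)):
        # mix the recursive lambda-return with the one-step bootstrap value
        result[t] = rewards[t] + gammas[t] * (
            lambda_ * result[t + 1] + (1 - lambda_) * bootstrap_values[t]
        )
    return result
```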

ding/rl_utils/upgo.py

+2 −2
@@ -44,7 +44,7 @@ def tb_cross_entropy(logit, label, mask=None):


  def upgo_returns(rewards: torch.Tensor, bootstrap_values: torch.Tensor) -> torch.Tensor:
- r"""
+ """
  Overview:
  Computing UPGO return targets. Also notice there is no special handling for the terminal state.
  Arguments:
@@ -82,7 +82,7 @@ def upgo_loss(
  bootstrap_values: torch.Tensor,
  mask=None
  ) -> torch.Tensor:
- r"""
+ """
  Overview:
  Computing UPGO loss given constant gamma and lambda. There is no special handling for terminal state value,
  if the last state in trajectory is the terminal, just pass a 0 as bootstrap_terminal_value.
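
Since neither docstring spells out the recursion, a hedged sketch of the UPGO return may help: follow the sampled return while it is "upgoing" (the one-step estimate beats the value baseline), otherwise bootstrap from the baseline. This follows the AlphaStar-style definition with gamma = 1; the shapes and details are assumptions, not the repo's exact implementation. Here ``bootstrap_values`` holds V at steps 0..T with shape [T+1, B] and ``rewards`` has shape [T, B]:

```python
import torch


def upgo_returns_sketch(rewards: torch.Tensor,
                        bootstrap_values: torch.Tensor) -> torch.Tensor:
    T = rewards.shape[0]
    returns = torch.empty_like(rewards)
    returns[T - 1] = rewards[T - 1] + bootstrap_values[T]  # bootstrap the last step
    for t in reversed(range(T - 1)):
        # keep following the behavior return only while it beats the baseline
        upgoing = (rewards[t + 1] + bootstrap_values[t + 2]) >= bootstrap_values[t + 1]
        returns[t] = rewards[t] + torch.where(
            upgoing, returns[t + 1], bootstrap_values[t + 1]
        )
    return returns
```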

ding/rl_utils/value_rescale.py

+6 −8
@@ -2,7 +2,7 @@


  def value_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
- r"""
+ """
  Overview:
  A function to reduce the scale of the action-value function.
  :math: `h(x) = sign(x)(\sqrt{(abs(x)+1)} - 1) + \eps * x` .
@@ -14,14 +14,13 @@ def value_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
  - (:obj:`torch.Tensor`) Normalized tensor.

  .. note::
- Observe and Look Further: Achieving Consistent Performance on Atari
- (https://arxiv.org/abs/1805.11593)
+ Observe and Look Further: Achieving Consistent Performance on Atari (https://arxiv.org/abs/1805.11593).
  """
  return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x


  def value_inv_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
- r"""
+ """
  Overview:
  The inverse form of value rescale.
  :math: `h^{-1}(x) = sign(x)({(\frac{\sqrt{1+4\eps(|x|+1+\eps)}-1}{2\eps})}^2-1)` .
@@ -36,7 +35,7 @@ def value_inv_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:


  def symlog(x: torch.Tensor) -> torch.Tensor:
- r"""
+ """
  Overview:
  A function to normalize the targets.
  :math: `symlog(x) = sign(x)(\ln{|x|+1})` .
@@ -46,14 +45,13 @@ def symlog(x: torch.Tensor) -> torch.Tensor:
  - (:obj:`torch.Tensor`) Normalized tensor.

  .. note::
- Mastering Diverse Domains through World Models
- (https://arxiv.org/abs/2301.04104)
+ Mastering Diverse Domains through World Models (https://arxiv.org/abs/2301.04104)
  """
  return torch.sign(x) * (torch.log(torch.abs(x) + 1))


  def inv_symlog(x: torch.Tensor) -> torch.Tensor:
- r"""
+ """
  Overview:
  The inverse form of symlog.
  :math: `symexp(x) = sign(x)(\exp{|x|}-1)` .
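
The four transforms in this file form two exact inverse pairs, which is easy to check numerically from the formulas quoted in the docstrings (a self-contained sketch, not an import of the repo's functions):

```python
import torch


def value_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    # h(x) = sign(x)(sqrt(|x| + 1) - 1) + eps * x
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x


def value_inv_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    # h^{-1}(x) = sign(x)(((sqrt(1 + 4*eps*(|x| + 1 + eps)) - 1) / (2*eps)) ** 2 - 1)
    return torch.sign(x) * (((torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps)) - 1) / (2 * eps)) ** 2 - 1)


def symlog(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * torch.log(torch.abs(x) + 1)


def inv_symlog(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)


x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
assert torch.allclose(value_inv_transform(value_transform(x)), x, atol=1e-3)
assert torch.allclose(inv_symlog(symlog(x)), x, atol=1e-3)
```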
