
Commit 312cdc0

Switch to ruff (#429)
* Add ruff
* Add typing-extensions back
* Move Protocol to typing from typing_extensions
1 parent 2c67af3 · commit 312cdc0
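
The third bullet refers to files beyond the hunks shown below, but the migration it names is simple: `Protocol` has been in the standard-library `typing` module since Python 3.8, so the `typing_extensions` backport import is no longer needed for it. A minimal sketch of the pattern with illustrative names (not d3rlpy's actual classes), assuming Python 3.9+ to match the builtin-generic annotations adopted elsewhere in this commit:

```python
from typing import Protocol  # previously: from typing_extensions import Protocol


class SupportsUpdate(Protocol):
    """Structural type: any object with a matching update() method conforms."""

    def update(self, batch: object) -> dict[str, float]: ...


class DummyAlgo:
    def update(self, batch: object) -> dict[str, float]:
        return {"loss": 0.0}


def run_step(algo: SupportsUpdate, batch: object) -> dict[str, float]:
    # mypy accepts DummyAlgo here because it matches the protocol structurally,
    # without inheriting from SupportsUpdate.
    return algo.update(batch)


print(run_step(DummyAlgo(), batch=None))  # {'loss': 0.0}
```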


82 files changed: +384 −1011 lines

Diff for: .github/workflows/format_check.yml (+1 −4)

@@ -25,10 +25,7 @@ jobs:
           pip install Cython numpy
           pip install -e .
           pip install -r dev.requirements.txt
-      - name: Check format
-        run: |
-          ./scripts/format
-      - name: Linter
+      - name: Static analysis
         run: |
           ./scripts/lint

Diff for: CONTRIBUTING.md (+2 −8)

@@ -29,16 +29,10 @@ $ ./scripts/test
 ```
 
 ### Coding style check
-This repository is styled with [black](https://github.com/psf/black) formatter.
-Also, [isort](https://github.com/PyCQA/isort) is used to format package imports.
+This repository is styled and analyzed with [Ruff](https://docs.astral.sh/ruff/).
 [docformatter](https://github.com/PyCQA/docformatter) is additionally used to format docstrings.
-```
-$ ./scripts/format
-```
-
-### Linter
 This repository is fully type-annotated and checked by [mypy](https://github.com/python/mypy).
-Also, [pylint](https://github.com/PyCQA/pylint) checks code consistency.
+Before you submit your PR, please execute this command:
 ```
 $ ./scripts/lint
 ```

Diff for: d3rlpy/algos/qlearning/base.py (+7 −10)

@@ -2,13 +2,10 @@
 from collections import defaultdict
 from typing import (
     Callable,
-    Dict,
     Generator,
     Generic,
-    List,
     Optional,
     Sequence,
-    Tuple,
     TypeVar,
 )
 
@@ -67,13 +64,13 @@
 
 class QLearningAlgoImplBase(ImplBase):
     @train_api
-    def update(self, batch: TorchMiniBatch, grad_step: int) -> Dict[str, float]:
+    def update(self, batch: TorchMiniBatch, grad_step: int) -> dict[str, float]:
         return self.inner_update(batch, grad_step)
 
     @abstractmethod
     def inner_update(
         self, batch: TorchMiniBatch, grad_step: int
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         pass
 
     @eval_api
@@ -382,10 +379,10 @@ def fit(
         logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
         show_progress: bool = True,
         save_interval: int = 1,
-        evaluators: Optional[Dict[str, EvaluatorProtocol]] = None,
+        evaluators: Optional[dict[str, EvaluatorProtocol]] = None,
         callback: Optional[Callable[[Self, int, int], None]] = None,
         epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
-    ) -> List[Tuple[int, Dict[str, float]]]:
+    ) -> list[tuple[int, dict[str, float]]]:
         """Trains with given dataset.
 
         .. code-block:: python
@@ -448,10 +445,10 @@ def fitter(
         logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
         show_progress: bool = True,
         save_interval: int = 1,
-        evaluators: Optional[Dict[str, EvaluatorProtocol]] = None,
+        evaluators: Optional[dict[str, EvaluatorProtocol]] = None,
         callback: Optional[Callable[[Self, int, int], None]] = None,
         epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
-    ) -> Generator[Tuple[int, Dict[str, float]], None, None]:
+    ) -> Generator[tuple[int, dict[str, float]], None, None]:
         """Iterate over epochs steps to train with the given dataset. At each
         iteration algo methods and properties can be changed or queried.
 
@@ -859,7 +856,7 @@ def collect(
 
         return buffer
 
-    def update(self, batch: TransitionMiniBatch) -> Dict[str, float]:
+    def update(self, batch: TransitionMiniBatch) -> dict[str, float]:
         """Update parameters with mini-batch of data.
 
         Args:
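
The change repeated through this file (and the rest of the commit) is the PEP 585 migration: the deprecated `typing.Dict`, `typing.List`, and `typing.Tuple` aliases become the builtin `dict`, `list`, and `tuple`, which are subscriptable in annotations on Python 3.9+. This is the kind of rewrite Ruff's pyupgrade-derived rules (e.g. UP006) automate. A minimal before/after sketch with an invented function, not code from d3rlpy:

```python
# Before: typing aliases (what this commit removes)
# from typing import Dict, List, Tuple
# def summarize(metrics: Dict[str, float]) -> List[Tuple[str, float]]: ...

# After: builtin generics, no typing import needed for plain containers (Python 3.9+)
def summarize(metrics: dict[str, float]) -> list[tuple[str, float]]:
    # Pair each metric name with its value, sorted by name.
    return sorted(metrics.items())


print(summarize({"critic_loss": 0.42, "actor_loss": 0.17}))
# -> [('actor_loss', 0.17), ('critic_loss', 0.42)]
```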

Diff for: d3rlpy/algos/qlearning/random_policy.py (+8 −5)

@@ -1,5 +1,4 @@
 import dataclasses
-from typing import Dict
 
 import numpy as np
 
@@ -35,7 +34,9 @@ class RandomPolicyConfig(LearnableConfig):
     distribution: str = "uniform"
     normal_std: float = 1.0
 
-    def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "RandomPolicy":  # type: ignore
+    def create(  # type: ignore
+        self, device: DeviceArg = False, enable_ddp: bool = False
+    ) -> "RandomPolicy":
         return RandomPolicy(self)
 
     @staticmethod
@@ -83,7 +84,7 @@ def sample_action(self, x: Observation) -> NDArray:
     def predict_value(self, x: Observation, action: NDArray) -> NDArray:
         raise NotImplementedError
 
-    def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def inner_update(self, batch: TorchMiniBatch) -> dict[str, float]:
         raise NotImplementedError
 
     def get_action_type(self) -> ActionSpace:
@@ -98,7 +99,9 @@ class DiscreteRandomPolicyConfig(LearnableConfig):
     ``fit`` and ``fit_online`` methods will raise exceptions.
     """
 
-    def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "DiscreteRandomPolicy":  # type: ignore
+    def create(  # type: ignore
+        self, device: DeviceArg = False, enable_ddp: bool = False
+    ) -> "DiscreteRandomPolicy":
         return DiscreteRandomPolicy(self)
 
     @staticmethod
@@ -128,7 +131,7 @@ def sample_action(self, x: Observation) -> NDArray:
     def predict_value(self, x: Observation, action: NDArray) -> NDArray:
         raise NotImplementedError
 
-    def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def inner_update(self, batch: TorchMiniBatch) -> dict[str, float]:
         raise NotImplementedError
 
     def get_action_type(self) -> ActionSpace:

Diff for: d3rlpy/algos/qlearning/torch/bc_impl.py (+3 −3)

@@ -1,6 +1,6 @@
 import dataclasses
 from abc import ABCMeta, abstractmethod
-from typing import Callable, Dict, Union
+from typing import Callable, Union
 
 import torch
 from torch.optim import Optimizer
@@ -60,7 +60,7 @@ def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss:
         loss.loss.backward()
         return loss
 
-    def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]:
         loss = self._compute_imitator_grad(batch)
         self._modules.optim.step()
         return asdict_as_float(loss)
@@ -81,7 +81,7 @@ def inner_predict_value(
 
     def inner_update(
         self, batch: TorchMiniBatch, grad_step: int
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         return self.update_imitator(batch)

Diff for: d3rlpy/algos/qlearning/torch/bcq_impl.py (+5 −5)

@@ -1,6 +1,6 @@
 import dataclasses
 import math
-from typing import Callable, Dict, cast
+from typing import Callable, cast
 
 import torch
 import torch.nn.functional as F
@@ -50,7 +50,7 @@ class BCQModules(DDPGBaseModules):
 
 class BCQImpl(DDPGBaseImpl):
     _modules: BCQModules
-    _compute_imitator_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]]
+    _compute_imitator_grad: Callable[[TorchMiniBatch], dict[str, torch.Tensor]]
     _lam: float
     _n_action_samples: int
     _action_flexibility: float
@@ -124,7 +124,7 @@ def compute_actor_grad(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss:
 
     def compute_imitator_grad(
         self, batch: TorchMiniBatch
-    ) -> Dict[str, torch.Tensor]:
+    ) -> dict[str, torch.Tensor]:
         self._modules.vae_optim.zero_grad()
         loss = compute_vae_error(
             vae_encoder=self._modules.vae_encoder,
@@ -136,7 +136,7 @@ def compute_imitator_grad(
         loss.backward()
         return {"loss": loss}
 
-    def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]:
         loss = self._compute_imitator_grad(batch)
         self._modules.vae_optim.step()
         return {"vae_loss": float(loss["loss"].cpu().detach().numpy())}
@@ -214,7 +214,7 @@ def update_actor_target(self) -> None:
 
     def inner_update(
         self, batch: TorchMiniBatch, grad_step: int
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         metrics = {}
 
         metrics.update(self.update_imitator(batch))

Diff for: d3rlpy/algos/qlearning/torch/bear_impl.py (+8 −8)

@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Callable, Dict, Optional
+from typing import Callable, Optional
 
 import torch
 
@@ -61,9 +61,9 @@ class BEARActorLoss(SACActorLoss):
 class BEARImpl(SACImpl):
     _modules: BEARModules
     _compute_warmup_actor_grad: Callable[
-        [TorchMiniBatch], Dict[str, torch.Tensor]
+        [TorchMiniBatch], dict[str, torch.Tensor]
     ]
-    _compute_imitator_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]]
+    _compute_imitator_grad: Callable[[TorchMiniBatch], dict[str, torch.Tensor]]
     _alpha_threshold: float
     _lam: float
     _n_action_samples: int
@@ -143,13 +143,13 @@ def compute_actor_loss(
 
     def compute_warmup_actor_grad(
         self, batch: TorchMiniBatch
-    ) -> Dict[str, torch.Tensor]:
+    ) -> dict[str, torch.Tensor]:
         self._modules.actor_optim.zero_grad()
         loss = self._compute_mmd_loss(batch.observations)
         loss.backward()
         return {"loss": loss}
 
-    def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def warmup_actor(self, batch: TorchMiniBatch) -> dict[str, float]:
         loss = self._compute_warmup_actor_grad(batch)
         self._modules.actor_optim.step()
         return {"actor_loss": float(loss["loss"].cpu().detach().numpy())}
@@ -161,13 +161,13 @@ def _compute_mmd_loss(self, obs_t: TorchObservation) -> torch.Tensor:
 
     def compute_imitator_grad(
         self, batch: TorchMiniBatch
-    ) -> Dict[str, torch.Tensor]:
+    ) -> dict[str, torch.Tensor]:
         self._modules.vae_optim.zero_grad()
         loss = self.compute_imitator_loss(batch)
         loss.backward()
         return {"loss": loss}
 
-    def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]:
         loss = self._compute_imitator_grad(batch)
         self._modules.vae_optim.step()
         return {"imitator_loss": float(loss["loss"].cpu().detach().numpy())}
@@ -301,7 +301,7 @@ def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
 
     def inner_update(
         self, batch: TorchMiniBatch, grad_step: int
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         metrics = {}
         metrics.update(self.update_imitator(batch))
         metrics.update(self.update_critic(batch))

Diff for: d3rlpy/algos/qlearning/torch/cal_ql_impl.py (+1 −2)

@@ -1,4 +1,3 @@
-from typing import Tuple
 
 import torch
 
@@ -14,7 +13,7 @@ def _compute_policy_is_values(
         policy_obs: TorchObservation,
         value_obs: TorchObservation,
         returns_to_go: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         values, log_probs = super()._compute_policy_is_values(
             policy_obs=policy_obs,
             value_obs=value_obs,

Diff for: d3rlpy/algos/qlearning/torch/cql_impl.py (+3 −3)

@@ -1,6 +1,6 @@
 import dataclasses
 import math
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -119,7 +119,7 @@ def _compute_policy_is_values(
         policy_obs: TorchObservation,
         value_obs: TorchObservation,
         returns_to_go: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         return sample_q_values_with_policy(
             policy=self._modules.policy,
             q_func_forwarder=self._q_func_forwarder,
@@ -131,7 +131,7 @@ def _compute_policy_is_values(
 
     def _compute_random_is_values(
         self, obs: TorchObservation
-    ) -> Tuple[torch.Tensor, float]:
+    ) -> tuple[torch.Tensor, float]:
         # (batch, observation) -> (batch, n, observation)
         repeated_obs = expand_and_repeat_recursively(
             obs, self._n_action_samples

Diff for: d3rlpy/algos/qlearning/torch/crr_impl.py (+1 −2)

@@ -1,5 +1,4 @@
 import dataclasses
-from typing import Dict
 
 import torch
 import torch.nn.functional as F
@@ -186,7 +185,7 @@ def update_actor_target(self) -> None:
 
     def inner_update(
         self, batch: TorchMiniBatch, grad_step: int
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         metrics = {}
         metrics.update(self.update_critic(batch))
         metrics.update(self.update_actor(batch))

Diff for: d3rlpy/algos/qlearning/torch/ddpg_impl.py (+5 −5)

@@ -1,6 +1,6 @@
 import dataclasses
 from abc import ABCMeta, abstractmethod
-from typing import Callable, Dict
+from typing import Callable
 
 import torch
 from torch import nn
@@ -105,7 +105,7 @@ def compute_critic_grad(self, batch: TorchMiniBatch) -> DDPGBaseCriticLoss:
         loss.critic_loss.backward()
         return loss
 
-    def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def update_critic(self, batch: TorchMiniBatch) -> dict[str, float]:
         loss = self._compute_critic_grad(batch)
         self._modules.critic_optim.step()
         return asdict_as_float(loss)
@@ -130,7 +130,7 @@ def compute_actor_grad(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss:
         loss.actor_loss.backward()
         return loss
 
-    def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def update_actor(self, batch: TorchMiniBatch) -> dict[str, float]:
         # Q function should be inference mode for stability
         self._modules.q_funcs.eval()
         loss = self._compute_actor_grad(batch)
@@ -139,7 +139,7 @@ def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
 
     def inner_update(
         self, batch: TorchMiniBatch, grad_step: int
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         metrics = {}
         metrics.update(self.update_critic(batch))
         metrics.update(self.update_actor(batch))
@@ -241,7 +241,7 @@ def update_actor_target(self) -> None:
 
     def inner_update(
         self, batch: TorchMiniBatch, grad_step: int
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         metrics = super().inner_update(batch, grad_step)
         self.update_actor_target()
         return metrics

Diff for: d3rlpy/algos/qlearning/torch/dqn_impl.py (+2 −2)

@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Callable, Dict
+from typing import Callable
 
 import torch
 from torch import nn
@@ -79,7 +79,7 @@ def compute_grad(self, batch: TorchMiniBatch) -> DQNLoss:
 
     def inner_update(
         self, batch: TorchMiniBatch, grad_step: int
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         loss = self._compute_grad(batch)
         self._modules.optim.step()
         if grad_step % self._target_update_interval == 0:
