
Commit 1f8144b

Fix errors from lint

1 parent: eb7ee38

4 files changed (+37 -20 lines)

Diff for: d3rlpy/algos/qlearning/prdc.py

+21 -14

@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Callable, Generator, Optional
+from typing import Callable, Optional

 import numpy as np
 import torch
@@ -11,7 +11,10 @@
 from ...dataset import ReplayBufferBase
 from ...logging import FileAdapterFactory, LoggerAdapterFactory
 from ...metrics import EvaluatorProtocol
-from ...models.builders import create_continuous_q_function, create_deterministic_policy
+from ...models.builders import (
+    create_continuous_q_function,
+    create_deterministic_policy,
+)
 from ...models.encoders import EncoderFactory, make_encoder_field
 from ...models.q_functions import QFunctionFactory, make_q_func_field
 from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
@@ -188,43 +191,47 @@ def fit(
         dataset: ReplayBufferBase,
         n_steps: int,
         n_steps_per_epoch: int = 10000,
-        logging_steps: int = 500,
-        logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
         experiment_name: Optional[str] = None,
         with_timestamp: bool = True,
+        logging_steps: int = 500,
+        logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
         logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
         show_progress: bool = True,
         save_interval: int = 1,
         evaluators: Optional[dict[str, EvaluatorProtocol]] = None,
         callback: Optional[Callable[[Self, int, int], None]] = None,
         epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
-    ) -> Generator[tuple[int, dict[str, float]], None, None]:
+    ) -> list[tuple[int, dict[str, float]]]:
         observations = []
         actions = []
         for episode in dataset.buffer.episodes:
             for i in range(episode.transition_count):
                 transition = dataset.transition_picker(episode, i)
-                observations.append(transition.observation.reshape(1, -1))
-                actions.append(transition.action.reshape(1, -1))
+                observations.append(np.reshape(transition.observation, (1, -1)))
+                actions.append(np.reshape(transition.action, (1, -1)))
         observations = np.concatenate(observations, axis=0)
         actions = np.concatenate(actions, axis=0)

         build_scalers_with_transition_picker(self, dataset)
         if self.observation_scaler and self.observation_scaler.built:
-            observations = self.observation_scaler.transform(
-                torch.tensor(observations, device=self._device)
+            observations = (
+                self.observation_scaler.transform(
+                    torch.tensor(observations, device=self._device)
+                )
+                .cpu()
+                .numpy()
             )
-            observations = observations.cpu().numpy()

         if self.action_scaler and self.action_scaler.built:
-            actions = self.action_scaler.transform(
-                torch.tensor(actions, device=self._device)
+            actions = (
+                self.action_scaler.transform(torch.tensor(actions, device=self._device))
+                .cpu()
+                .numpy()
             )
-            actions = actions.cpu().numpy()

         self._nbsr.fit(
             np.concatenate(
-                [self._config.beta * observations, actions],
+                [np.multiply(observations, self._config.beta), actions],
                 axis=1,
             )
         )
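
For context, self._nbsr is the scikit-learn NearestNeighbors index declared in prdc_impl.py, and the fit() override above populates it with rows of [beta * observation, action] so the actor update can later look up the closest dataset action for a given state. A minimal standalone sketch of that indexing step, using dummy data and an illustrative beta value rather than anything taken from this commit:

import numpy as np
from sklearn.neighbors import NearestNeighbors

beta = 2.0  # weight on the state part of the distance (illustrative value)
observations = np.random.rand(1000, 17).astype(np.float32)  # dummy dataset states
actions = np.random.rand(1000, 6).astype(np.float32)        # dummy dataset actions

# Index beta-scaled states concatenated with actions, mirroring _nbsr.fit() above.
nbsr = NearestNeighbors(n_neighbors=1)
nbsr.fit(np.concatenate([beta * observations, actions], axis=1))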

Diff for: d3rlpy/algos/qlearning/torch/prdc_impl.py

+8 -5

@@ -8,7 +8,7 @@
 from ....torch_utility import TorchMiniBatch
 from ....types import Shape
 from .ddpg_impl import DDPGBaseActorLoss, DDPGModules
-from .td3_plus_bc_impl import TD3PlusBCImpl
+from .td3_impl import TD3Impl

 __all__ = ["PRDCImpl"]

@@ -18,8 +18,9 @@ class PRDCActorLoss(DDPGBaseActorLoss):
     dc_loss: torch.Tensor


-class PRDCImpl(TD3PlusBCImpl):
-    _beta: float = 2.0
+class PRDCImpl(TD3Impl):
+    _alpha: float
+    _beta: float
     _nbsr: NearestNeighbors

     def __init__(
@@ -50,11 +51,11 @@ def __init__(
             tau=tau,
             target_smoothing_sigma=target_smoothing_sigma,
             target_smoothing_clip=target_smoothing_clip,
-            alpha=alpha,
             update_actor_interval=update_actor_interval,
             compiled=compiled,
             device=device,
         )
+        self._alpha = alpha
         self._beta = beta
         self._nbsr = nbsr

@@ -66,7 +67,9 @@ def compute_actor_loss(
         )[0]
         lam = self._alpha / (q_t.abs().mean()).detach()
         key = (
-            torch.cat([self._beta * batch.observations, action.squashed_mu], dim=-1)
+            torch.cat(
+                [torch.mul(batch.observations, self._beta), action.squashed_mu], dim=-1
+            )
             .detach()
             .cpu()
             .numpy()
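
The key built in compute_actor_loss is the query side of that index: batch observations scaled by beta, concatenated with the policy's proposed actions. As a hedged illustration of how such a key is typically consumed to form PRDC's dataset-constraint penalty (an assumed sketch, not code from this commit; dc_regularizer and its arguments are hypothetical names):

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.neighbors import NearestNeighbors

def dc_regularizer(
    nbsr: NearestNeighbors,        # index fitted on rows of [beta * obs, action]
    dataset_actions: np.ndarray,   # dataset actions in the same row order as the index
    observations: torch.Tensor,    # (N, obs_dim) batch observations
    policy_actions: torch.Tensor,  # (N, action_dim) actor outputs
    beta: float,
) -> torch.Tensor:
    # Build the query key the same way compute_actor_loss does above.
    key = (
        torch.cat([beta * observations, policy_actions], dim=-1)
        .detach()
        .cpu()
        .numpy()
    )
    # Look up the nearest logged (state, action) row for each query.
    _, indices = nbsr.kneighbors(key, n_neighbors=1)
    nearest_actions = torch.as_tensor(
        dataset_actions[indices[:, 0]],
        dtype=policy_actions.dtype,
        device=policy_actions.device,
    )
    # Penalize deviation of the policy action from the closest dataset action,
    # the kind of term the dc_loss field in PRDCActorLoss is meant to carry.
    return F.mse_loss(policy_actions, nearest_actions)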

Diff for: mypy.ini

+3

@@ -71,3 +71,6 @@ follow_imports_for_stubs = True
 ignore_missing_imports = True
 follow_imports = skip
 follow_imports_for_stubs = True
+
+[mypy-sklearn.*]
+ignore_missing_imports = True
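
The new [mypy-sklearn.*] override mirrors the third-party entries above it: scikit-learn does not ship mypy type stubs, so without ignore_missing_imports the NearestNeighbors import used by the PRDC implementation would be reported as a missing import during linting.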

Diff for: tests/algos/qlearning/test_prdc.py

+5 -1

@@ -3,7 +3,11 @@
 import pytest

 from d3rlpy.algos.qlearning.prdc import PRDCConfig
-from d3rlpy.models import MeanQFunctionFactory, QFunctionFactory, QRQFunctionFactory
+from d3rlpy.models import (
+    MeanQFunctionFactory,
+    QFunctionFactory,
+    QRQFunctionFactory,
+)
 from d3rlpy.types import Shape

 from ...models.torch.model_test import DummyEncoderFactory
