@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Callable, Generator, Optional
+from typing import Callable, Optional
 
 import numpy as np
 import torch
@@ -11,7 +11,10 @@
 from ...dataset import ReplayBufferBase
 from ...logging import FileAdapterFactory, LoggerAdapterFactory
 from ...metrics import EvaluatorProtocol
-from ...models.builders import create_continuous_q_function, create_deterministic_policy
+from ...models.builders import (
+    create_continuous_q_function,
+    create_deterministic_policy,
+)
 from ...models.encoders import EncoderFactory, make_encoder_field
 from ...models.q_functions import QFunctionFactory, make_q_func_field
 from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
@@ -188,43 +191,47 @@ def fit(
         dataset: ReplayBufferBase,
         n_steps: int,
         n_steps_per_epoch: int = 10000,
-        logging_steps: int = 500,
-        logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
         experiment_name: Optional[str] = None,
         with_timestamp: bool = True,
+        logging_steps: int = 500,
+        logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
         logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
         show_progress: bool = True,
         save_interval: int = 1,
         evaluators: Optional[dict[str, EvaluatorProtocol]] = None,
         callback: Optional[Callable[[Self, int, int], None]] = None,
         epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
-    ) -> Generator[tuple[int, dict[str, float]], None, None]:
+    ) -> list[tuple[int, dict[str, float]]]:
         observations = []
         actions = []
         for episode in dataset.buffer.episodes:
             for i in range(episode.transition_count):
                 transition = dataset.transition_picker(episode, i)
-                observations.append(transition.observation.reshape(1, -1))
-                actions.append(transition.action.reshape(1, -1))
+                observations.append(np.reshape(transition.observation, (1, -1)))
+                actions.append(np.reshape(transition.action, (1, -1)))
         observations = np.concatenate(observations, axis=0)
         actions = np.concatenate(actions, axis=0)
 
         build_scalers_with_transition_picker(self, dataset)
         if self.observation_scaler and self.observation_scaler.built:
-            observations = self.observation_scaler.transform(
-                torch.tensor(observations, device=self._device)
+            observations = (
+                self.observation_scaler.transform(
+                    torch.tensor(observations, device=self._device)
+                )
+                .cpu()
+                .numpy()
             )
-            observations = observations.cpu().numpy()
 
         if self.action_scaler and self.action_scaler.built:
-            actions = self.action_scaler.transform(
-                torch.tensor(actions, device=self._device)
+            actions = (
+                self.action_scaler.transform(torch.tensor(actions, device=self._device))
+                .cpu()
+                .numpy()
             )
-            actions = actions.cpu().numpy()
 
         self._nbsr.fit(
             np.concatenate(
-                [self._config.beta * observations, actions],
+                [np.multiply(observations, self._config.beta), actions],
                 axis=1,
             )
         )
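
Caller-side note: the return annotation of fit changes from a Generator to list[tuple[int, dict[str, float]]], so per-epoch metrics are fully materialized before the call returns, and logging_steps/logging_strategy now sit after with_timestamp, which only affects positional callers. A minimal usage sketch under those assumptions; algo and dataset are hypothetical placeholders for an instance of the algorithm defined in this module and an already-built ReplayBufferBase (imported above), not names taken from this diff:

    # Hypothetical caller of the reworked fit(); `algo` and `dataset` are placeholders.
    results = algo.fit(
        dataset,
        n_steps=100_000,
        n_steps_per_epoch=10_000,
        logging_steps=500,  # logging options moved in the signature; keyword use is unaffected
    )

    # `results` is a plain list now, not a generator to drain.
    for epoch, metrics in results:
        print(epoch, metrics)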