Skip to content

Commit c6e6cd4

Browse files
committed
Support BC
1 parent b41be5a commit c6e6cd4

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

Diff for: d3rlpy/algos/qlearning/bc.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class BCConfig(LearnableConfig):
4949
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
5050
Observation preprocessor.
5151
action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
52+
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
5253
"""
5354

5455
batch_size: int = 100
@@ -95,7 +96,7 @@ def inner_create_impl(
9596
optim = self._config.optim_factory.create(
9697
imitator.named_modules(),
9798
lr=self._config.learning_rate,
98-
compiled=False,
99+
compiled=self.compiled,
99100
)
100101

101102
modules = BCModules(optim=optim, imitator=imitator)
@@ -105,6 +106,7 @@ def inner_create_impl(
105106
action_size=action_size,
106107
modules=modules,
107108
policy_type=self._config.policy_type,
109+
compiled=self.compiled,
108110
device=self._device,
109111
)
110112

@@ -139,6 +141,7 @@ class DiscreteBCConfig(LearnableConfig):
139141
beta (float): Regularization factor.
140142
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
141143
Observation preprocessor.
144+
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
142145
"""
143146

144147
batch_size: int = 100
@@ -172,7 +175,7 @@ def inner_create_impl(
172175
optim = self._config.optim_factory.create(
173176
imitator.named_modules(),
174177
lr=self._config.learning_rate,
175-
compiled=False,
178+
compiled=self.compiled,
176179
)
177180

178181
modules = DiscreteBCModules(optim=optim, imitator=imitator)
@@ -182,6 +185,7 @@ def inner_create_impl(
182185
action_size=action_size,
183186
modules=modules,
184187
beta=self._config.beta,
188+
compiled=self.compiled,
185189
device=self._device,
186190
)
187191

Diff for: d3rlpy/algos/qlearning/torch/bc_impl.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import dataclasses
22
from abc import ABCMeta, abstractmethod
3-
from typing import Dict, Union
3+
from typing import Callable, Dict, Union
44

55
import torch
66
from torch.optim import Optimizer
@@ -18,7 +18,7 @@
1818
compute_stochastic_imitation_loss,
1919
)
2020
from ....optimizers import OptimizerWrapper
21-
from ....torch_utility import Modules, TorchMiniBatch
21+
from ....torch_utility import CudaGraphWrapper, Modules, TorchMiniBatch
2222
from ....types import Shape, TorchObservation
2323
from ..base import QLearningAlgoImplBase
2424

@@ -32,12 +32,14 @@ class BCBaseModules(Modules):
3232

3333
class BCBaseImpl(QLearningAlgoImplBase, metaclass=ABCMeta):
3434
_modules: BCBaseModules
35+
_compute_imitator_grad: Callable[[TorchMiniBatch], ImitationLoss]
3536

3637
def __init__(
3738
self,
3839
observation_shape: Shape,
3940
action_size: int,
4041
modules: BCBaseModules,
42+
compiled: bool,
4143
device: str,
4244
):
4345
super().__init__(
@@ -46,15 +48,21 @@ def __init__(
4648
modules=modules,
4749
device=device,
4850
)
51+
self._compute_imitator_grad = (
52+
CudaGraphWrapper(self.compute_imitator_grad)
53+
if compiled
54+
else self.compute_imitator_grad
55+
)
4956

50-
def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
57+
def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss:
5158
self._modules.optim.zero_grad()
52-
5359
loss = self.compute_loss(batch.observations, batch.actions)
54-
5560
loss.loss.backward()
56-
self._modules.optim.step()
61+
return loss
5762

63+
def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
64+
loss = self._compute_imitator_grad(batch)
65+
self._modules.optim.step()
5866
return asdict_as_float(loss)
5967

6068
@abstractmethod
@@ -92,12 +100,14 @@ def __init__(
92100
action_size: int,
93101
modules: BCModules,
94102
policy_type: str,
103+
compiled: bool,
95104
device: str,
96105
):
97106
super().__init__(
98107
observation_shape=observation_shape,
99108
action_size=action_size,
100109
modules=modules,
110+
compiled=compiled,
101111
device=device,
102112
)
103113
self._policy_type = policy_type
@@ -145,12 +155,14 @@ def __init__(
145155
action_size: int,
146156
modules: DiscreteBCModules,
147157
beta: float,
158+
compiled: bool,
148159
device: str,
149160
):
150161
super().__init__(
151162
observation_shape=observation_shape,
152163
action_size=action_size,
153164
modules=modules,
165+
compiled=compiled,
154166
device=device,
155167
)
156168
self._beta = beta

0 commit comments

Comments
 (0)