Commit f3c5540

Support CudaGraph and torch.compile (#428)

* Add CudaGraphWrapper
* Fix lint errors
* Fix TD3
* Fix DiscreteSAC
* Update torch dependency
* Fix lint error
* Add compiled flag to OptimizerWrapper
* Workaround DiscreteSAC test
* Add compiled property
* Add tests
* Update python version in readthedocs
* Remove unnecessary change
* Rename compile_graph to compiled
* Support BC
* Add compile option to reproduction scripts

1 parent 3b01da3 commit f3c5540


72 files changed: +824 / -276 lines
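Every algorithm config touched by this commit gains a `compile_graph` flag (surfaced on the impls as the `compiled` property) that turns on JIT compilation and CUDAGraph for the training update. A minimal usage sketch, assuming a CUDA device and an already-loaded dataset (both illustrative, not part of this commit):

    import d3rlpy

    # enable torch.compile + CUDAGraph for the update step (new in this commit)
    cql = d3rlpy.algos.CQLConfig(compile_graph=True).create(device="cuda:0")
    cql.fit(dataset, n_steps=100000)  # dataset: any pre-built d3rlpy dataset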

Diff for: .readthedocs.yaml (+1 -1)
@@ -2,7 +2,7 @@ version: 2
 build:
   os: ubuntu-22.04
   tools:
-    python: "3.8"
+    python: "3.10"
 sphinx:
   builder: html
   configuration: docs/conf.py

Diff for: README.md (+1 -1)
@@ -54,7 +54,7 @@ d3rlpy supports Linux, macOS and Windows.
 
 ### Dependencies
 Installing d3rlpy package will install or upgrade the following packages to satisfy requirements:
-- torch>=2.0.0
+- torch>=2.5.0
 - tqdm>=4.66.3
 - gym>=0.26.0
 - gymnasium>=1.0.0

Diff for: d3rlpy/__init__.py (+5)
@@ -1,3 +1,4 @@
+# pylint: disable=protected-access
 import random
 
 import gymnasium
@@ -68,6 +69,10 @@ def seed(n: int) -> None:
 # run healthcheck
 run_healthcheck()
 
+if torch.cuda.is_available():
+    # enable autograd compilation
+    torch._dynamo.config.compiled_autograd = True
+    torch.set_float32_matmul_precision("high")
 
 # register Shimmy if available
 try:

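Note that both of the settings enabled above are process-wide PyTorch toggles, so importing d3rlpy on a CUDA machine changes them for the whole program: `torch._dynamo.config.compiled_autograd = True` lets Dynamo capture the backward pass in addition to the forward pass, and `torch.set_float32_matmul_precision("high")` permits TF32 matrix multiplications on GPUs that support them.
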
Diff for: d3rlpy/algos/qlearning/awac.py (+8 -2)
@@ -70,6 +70,7 @@ class AWACConfig(LearnableConfig):
         n_action_samples (int): Number of sampled actions to calculate
             :math:`A^\pi(s_t, a_t)`.
         n_critics (int): Number of Q functions for ensemble.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     actor_learning_rate: float = 3e-4
@@ -130,10 +131,14 @@ def inner_create_impl(
         )
 
         actor_optim = self._config.actor_optim_factory.create(
-            policy.named_modules(), lr=self._config.actor_learning_rate
+            policy.named_modules(),
+            lr=self._config.actor_learning_rate,
+            compiled=self.compiled,
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.named_modules(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(),
+            lr=self._config.critic_learning_rate,
+            compiled=self.compiled,
         )
 
         dummy_log_temp = Parameter(torch.zeros(1, 1))
@@ -158,6 +163,7 @@ def inner_create_impl(
             tau=self._config.tau,
             lam=self._config.lam,
             n_action_samples=self._config.n_action_samples,
+            compiled=self.compiled,
             device=self._device,
         )

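The `compiled=self.compiled` arguments threaded through the optimizer factories and impl constructors in this and the following files gate whether the update step is run through torch.compile and replayed as a CUDA graph. The CudaGraphWrapper added by this commit is not part of this excerpt; the sketch below only illustrates the general PyTorch capture/replay pattern it builds on, with all names hypothetical:

    import torch

    class GraphedStep:
        """Hypothetical sketch: capture a fixed-shape update step into a CUDA graph."""

        def __init__(self, step_fn, example_batch):
            self._step_fn = step_fn
            self._static_in = example_batch.clone()
            # warm up on a side stream before capture, as required by CUDA graph capture
            side = torch.cuda.Stream()
            side.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(side):
                self._step_fn(self._static_in)
            torch.cuda.current_stream().wait_stream(side)
            # record a single invocation; replays reuse the same memory addresses
            self._graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self._graph):
                self._static_out = self._step_fn(self._static_in)

        def __call__(self, batch):
            self._static_in.copy_(batch)  # copy new data into the captured input buffer
            self._graph.replay()
            return self._static_out

Threading `compiled` into the optimizer factories presumably lets the optimizer step be captured in the same graph as the forward and backward passes, so a single replay covers the whole update.
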
Diff for: d3rlpy/algos/qlearning/bc.py (+10 -2)
@@ -49,6 +49,7 @@ class BCConfig(LearnableConfig):
         observation_scaler (d3rlpy.preprocessing.ObservationScaler):
             Observation preprocessor.
         action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     batch_size: int = 100
@@ -93,7 +94,9 @@ def inner_create_impl(
             raise ValueError(f"invalid policy_type: {self._config.policy_type}")
 
         optim = self._config.optim_factory.create(
-            imitator.named_modules(), lr=self._config.learning_rate
+            imitator.named_modules(),
+            lr=self._config.learning_rate,
+            compiled=self.compiled,
         )
 
         modules = BCModules(optim=optim, imitator=imitator)
@@ -103,6 +106,7 @@ def inner_create_impl(
             action_size=action_size,
             modules=modules,
             policy_type=self._config.policy_type,
+            compiled=self.compiled,
             device=self._device,
         )
@@ -137,6 +141,7 @@ class DiscreteBCConfig(LearnableConfig):
         beta (float): Reguralization factor.
         observation_scaler (d3rlpy.preprocessing.ObservationScaler):
             Observation preprocessor.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     batch_size: int = 100
@@ -168,7 +173,9 @@ def inner_create_impl(
         )
 
         optim = self._config.optim_factory.create(
-            imitator.named_modules(), lr=self._config.learning_rate
+            imitator.named_modules(),
+            lr=self._config.learning_rate,
+            compiled=self.compiled,
         )
 
         modules = DiscreteBCModules(optim=optim, imitator=imitator)
@@ -178,6 +185,7 @@ def inner_create_impl(
             action_size=action_size,
             modules=modules,
             beta=self._config.beta,
+            compiled=self.compiled,
             device=self._device,
         )

Diff for: d3rlpy/algos/qlearning/bcq.py (+14 -3)
@@ -137,6 +137,7 @@ class BCQConfig(LearnableConfig):
         rl_start_step (int): Steps to start to update policy function and Q
             functions. If this is large, RL training would be more stabilized.
         beta (float): KL reguralization term for Conditional VAE.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     actor_learning_rate: float = 1e-3
@@ -228,15 +229,20 @@ def inner_create_impl(
         )
 
         actor_optim = self._config.actor_optim_factory.create(
-            policy.named_modules(), lr=self._config.actor_learning_rate
+            policy.named_modules(),
+            lr=self._config.actor_learning_rate,
+            compiled=self.compiled,
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.named_modules(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(),
+            lr=self._config.critic_learning_rate,
+            compiled=self.compiled,
         )
         vae_optim = self._config.imitator_optim_factory.create(
             list(vae_encoder.named_modules())
             + list(vae_decoder.named_modules()),
             lr=self._config.imitator_learning_rate,
+            compiled=self.compiled,
         )
 
         modules = BCQModules(
@@ -264,6 +270,7 @@ def inner_create_impl(
             action_flexibility=self._config.action_flexibility,
             beta=self._config.beta,
             rl_start_step=self._config.rl_start_step,
+            compiled=self.compiled,
             device=self._device,
         )
@@ -331,6 +338,7 @@ class DiscreteBCQConfig(LearnableConfig):
         target_update_interval (int): Interval to update the target network.
         share_encoder (bool): Flag to share encoder between Q-function and
             imitation models.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     learning_rate: float = 6.25e-5
@@ -402,7 +410,9 @@ def inner_create_impl(
         q_func_params = list(q_funcs.named_modules())
         imitator_params = list(imitator.named_modules())
         optim = self._config.optim_factory.create(
-            q_func_params + imitator_params, lr=self._config.learning_rate
+            q_func_params + imitator_params,
+            lr=self._config.learning_rate,
+            compiled=self.compiled,
        )
 
         modules = DiscreteBCQModules(
@@ -422,6 +432,7 @@ def inner_create_impl(
             gamma=self._config.gamma,
             action_flexibility=self._config.action_flexibility,
             beta=self._config.beta,
+            compiled=self.compiled,
             device=self._device,
         )

Diff for: d3rlpy/algos/qlearning/bear.py (+15 -4)
@@ -114,6 +114,7 @@ class BEARConfig(LearnableConfig):
             policy training.
         warmup_steps (int): Number of steps to warmup the policy
             function.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     actor_learning_rate: float = 1e-4
@@ -217,21 +218,30 @@ def inner_create_impl(
         )
 
         actor_optim = self._config.actor_optim_factory.create(
-            policy.named_modules(), lr=self._config.actor_learning_rate
+            policy.named_modules(),
+            lr=self._config.actor_learning_rate,
+            compiled=self.compiled,
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.named_modules(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(),
+            lr=self._config.critic_learning_rate,
+            compiled=self.compiled,
        )
         vae_optim = self._config.imitator_optim_factory.create(
             list(vae_encoder.named_modules())
             + list(vae_decoder.named_modules()),
             lr=self._config.imitator_learning_rate,
+            compiled=self.compiled,
         )
         temp_optim = self._config.temp_optim_factory.create(
-            log_temp.named_modules(), lr=self._config.temp_learning_rate
+            log_temp.named_modules(),
+            lr=self._config.temp_learning_rate,
+            compiled=self.compiled,
         )
         alpha_optim = self._config.alpha_optim_factory.create(
-            log_alpha.named_modules(), lr=self._config.actor_learning_rate
+            log_alpha.named_modules(),
+            lr=self._config.actor_learning_rate,
+            compiled=self.compiled,
         )
 
         modules = BEARModules(
@@ -266,6 +276,7 @@ def inner_create_impl(
             mmd_sigma=self._config.mmd_sigma,
             vae_kl_weight=self._config.vae_kl_weight,
             warmup_steps=self._config.warmup_steps,
+            compiled=self.compiled,
             device=self._device,
         )

Diff for: d3rlpy/algos/qlearning/cal_ql.py (+14 -5)
@@ -69,6 +69,7 @@ class CalQLConfig(CQLConfig):
             :math:`\log{\sum_a \exp{Q(s, a)}}`.
         soft_q_backup (bool): Flag to use SAC-style backup.
         max_q_backup (bool): Flag to sample max Q-values for target.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     def create(
@@ -88,7 +89,6 @@ def inner_create_impl(
         assert not (
             self._config.soft_q_backup and self._config.max_q_backup
         ), "soft_q_backup and max_q_backup are mutually exclusive."
-
         policy = create_normal_policy(
             observation_shape,
             action_size,
@@ -128,20 +128,28 @@ def inner_create_impl(
         )
 
         actor_optim = self._config.actor_optim_factory.create(
-            policy.named_modules(), lr=self._config.actor_learning_rate
+            policy.named_modules(),
+            lr=self._config.actor_learning_rate,
+            compiled=self.compiled,
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.named_modules(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(),
+            lr=self._config.critic_learning_rate,
+            compiled=self.compiled,
         )
         if self._config.temp_learning_rate > 0:
             temp_optim = self._config.temp_optim_factory.create(
-                log_temp.named_modules(), lr=self._config.temp_learning_rate
+                log_temp.named_modules(),
+                lr=self._config.temp_learning_rate,
+                compiled=self.compiled,
             )
         else:
             temp_optim = None
         if self._config.alpha_learning_rate > 0:
             alpha_optim = self._config.alpha_optim_factory.create(
-                log_alpha.named_modules(), lr=self._config.alpha_learning_rate
+                log_alpha.named_modules(),
+                lr=self._config.alpha_learning_rate,
+                compiled=self.compiled,
             )
         else:
             alpha_optim = None
@@ -171,6 +179,7 @@ def inner_create_impl(
             n_action_samples=self._config.n_action_samples,
             soft_q_backup=self._config.soft_q_backup,
             max_q_backup=self._config.max_q_backup,
+            compiled=self.compiled,
             device=self._device,
         )

Diff for: d3rlpy/algos/qlearning/cql.py (+19 -6)
@@ -100,6 +100,7 @@ class CQLConfig(LearnableConfig):
             :math:`\log{\sum_a \exp{Q(s, a)}}`.
         soft_q_backup (bool): Flag to use SAC-style backup.
         max_q_backup (bool): Flag to sample max Q-values for target.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     actor_learning_rate: float = 1e-4
@@ -142,7 +143,6 @@ def inner_create_impl(
         assert not (
             self._config.soft_q_backup and self._config.max_q_backup
         ), "soft_q_backup and max_q_backup are mutually exclusive."
-
         policy = create_normal_policy(
             observation_shape,
             action_size,
@@ -182,20 +182,28 @@ def inner_create_impl(
         )
 
         actor_optim = self._config.actor_optim_factory.create(
-            policy.named_modules(), lr=self._config.actor_learning_rate
+            policy.named_modules(),
+            lr=self._config.actor_learning_rate,
+            compiled=self.compiled,
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.named_modules(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(),
+            lr=self._config.critic_learning_rate,
+            compiled=self.compiled,
         )
         if self._config.temp_learning_rate > 0:
             temp_optim = self._config.temp_optim_factory.create(
-                log_temp.named_modules(), lr=self._config.temp_learning_rate
+                log_temp.named_modules(),
+                lr=self._config.temp_learning_rate,
+                compiled=self.compiled,
             )
         else:
             temp_optim = None
         if self._config.alpha_learning_rate > 0:
             alpha_optim = self._config.alpha_optim_factory.create(
-                log_alpha.named_modules(), lr=self._config.alpha_learning_rate
+                log_alpha.named_modules(),
+                lr=self._config.alpha_learning_rate,
+                compiled=self.compiled,
            )
         else:
             alpha_optim = None
@@ -225,6 +233,7 @@ def inner_create_impl(
             n_action_samples=self._config.n_action_samples,
             soft_q_backup=self._config.soft_q_backup,
             max_q_backup=self._config.max_q_backup,
+            compiled=self.compiled,
             device=self._device,
         )
 
@@ -272,6 +281,7 @@ class DiscreteCQLConfig(LearnableConfig):
         target_update_interval (int): Interval to synchronize the target
             network.
         alpha (float): math:`\alpha` value above.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     learning_rate: float = 6.25e-5
@@ -318,7 +328,9 @@ def inner_create_impl(
         )
 
         optim = self._config.optim_factory.create(
-            q_funcs.named_modules(), lr=self._config.learning_rate
+            q_funcs.named_modules(),
+            lr=self._config.learning_rate,
+            compiled=self.compiled,
         )
 
         modules = DQNModules(
@@ -336,6 +348,7 @@ def inner_create_impl(
             target_update_interval=self._config.target_update_interval,
             gamma=self._config.gamma,
             alpha=self._config.alpha,
+            compiled=self.compiled,
             device=self._device,
         )
