
Commit 1f2a9dc

Include lr_scheduler as state_dict

1 parent bb98403 commit 1f2a9dc

5 files changed: +116 -12 lines changed

Diff for: d3rlpy/optimizers/optimizers.py (+14 -1)

@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Iterable, Optional, Sequence, Tuple
+from typing import Any, Iterable, Mapping, Optional, Sequence, Tuple
 
 from torch import nn
 from torch.optim import SGD, Adam, AdamW, Optimizer, RMSprop
@@ -93,6 +93,19 @@ def step(self) -> None:
     def optim(self) -> Optimizer:
         return self._optim
 
+    def state_dict(self) -> Mapping[str, Any]:
+        return {
+            "optim": self._optim.state_dict(),
+            "lr_scheduler": (
+                self._lr_scheduler.state_dict() if self._lr_scheduler else None
+            ),
+        }
+
+    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
+        self._optim.load_state_dict(state_dict["optim"])
+        if self._lr_scheduler:
+            self._lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
+
 
 @dataclasses.dataclass()
 class OptimizerFactory(DynamicConfig):
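
With these methods, OptimizerWrapper exposes a combined state_dict that covers both the wrapped optimizer and the optional LR scheduler. A minimal usage sketch (not part of this commit; the model and factory values are illustrative only):

    import torch
    from torch import nn

    from d3rlpy.optimizers import OptimizerWrapper
    from d3rlpy.optimizers.lr_schedulers import CosineAnnealingLRFactory

    model = nn.Linear(10, 10)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3)
    lr_scheduler = CosineAnnealingLRFactory(100).create(optim)
    wrapper = OptimizerWrapper(
        params=list(model.parameters()),
        optim=optim,
        compiled=False,
        lr_scheduler=lr_scheduler,
    )

    state = wrapper.state_dict()      # {"optim": ..., "lr_scheduler": ...}
    wrapper.load_state_dict(state)    # restores optimizer and scheduler together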

Diff for: d3rlpy/torch_utility.py (+2 -5)

@@ -384,7 +384,7 @@ def __init__(
     def save(self, f: BinaryIO) -> None:
         # unwrap DDP
         modules = {
-            k: unwrap_ddp_model(v) if isinstance(v, nn.Module) else v.optim
+            k: unwrap_ddp_model(v) if isinstance(v, nn.Module) else v
             for k, v in self._modules.items()
         }
         states = {k: v.state_dict() for k, v in modules.items()}
@@ -393,10 +393,7 @@ def save(self, f: BinaryIO) -> None:
     def load(self, f: BinaryIO) -> None:
         chkpt = torch.load(f, map_location=map_location(self._device))
         for k, v in self._modules.items():
-            if isinstance(v, nn.Module):
-                v.load_state_dict(chkpt[k])
-            else:
-                v.optim.load_state_dict(chkpt[k])
+            v.load_state_dict(chkpt[k])
 
     @property
     def modules(self) -> Dict[str, Union[nn.Module, OptimizerWrapperProto]]:
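
Checkpointer.save and Checkpointer.load now treat nn.Module instances and optimizer wrappers uniformly through state_dict/load_state_dict. A hedged round-trip sketch (the module names and sizes are made up for illustration):

    from io import BytesIO

    import torch

    from d3rlpy.optimizers import OptimizerWrapper
    from d3rlpy.torch_utility import Checkpointer

    fc = torch.nn.Linear(100, 100)
    params = list(fc.parameters())
    optim = OptimizerWrapper(params, torch.optim.Adam(params), compiled=False)
    checkpointer = Checkpointer(modules={"fc": fc, "optim": optim}, device="cpu:0")

    buf = BytesIO()
    checkpointer.save(buf)   # serializes state_dict() of every registered module
    buf.seek(0)
    checkpointer.load(buf)   # restores parameters, optimizer, and scheduler state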

Diff for: d3rlpy/types.py (+7 -1)

@@ -1,4 +1,4 @@
-from typing import Any, Sequence, Union
+from typing import Any, Mapping, Sequence, Union
 
 import gym
 import gymnasium
@@ -42,3 +42,9 @@ class OptimizerWrapperProto(Protocol):
     @property
     def optim(self) -> Optimizer:
         raise NotImplementedError
+
+    def state_dict(self) -> Mapping[str, Any]:
+        raise NotImplementedError
+
+    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
+        raise NotImplementedError
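
Any non-nn.Module object registered with Checkpointer is expected to satisfy this extended protocol. A hypothetical minimal implementation, shown only as a sketch (MyOptimHolder is not code from the repository):

    from typing import Any, Mapping

    from torch.optim import Optimizer

    class MyOptimHolder:
        """Hypothetical wrapper that structurally satisfies OptimizerWrapperProto."""

        def __init__(self, optim: Optimizer):
            self._optim = optim

        @property
        def optim(self) -> Optimizer:
            return self._optim

        def state_dict(self) -> Mapping[str, Any]:
            return {"optim": self._optim.state_dict(), "lr_scheduler": None}

        def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
            self._optim.load_state_dict(state_dict["optim"])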

Diff for: tests/optimizers/test_optimizers.py (+65 -0)

@@ -1,17 +1,82 @@
+from typing import Optional
+
 import pytest
 import torch
 from torch import nn
 from torch.optim import SGD, Adam, AdamW, RMSprop
 
+from d3rlpy.optimizers.lr_schedulers import (
+    CosineAnnealingLRFactory,
+    LRSchedulerFactory,
+)
 from d3rlpy.optimizers.optimizers import (
     AdamFactory,
     AdamWFactory,
     GPTAdamWFactory,
+    OptimizerWrapper,
     RMSpropFactory,
     SGDFactory,
 )
 
 
+@pytest.mark.parametrize(
+    "lr_scheduler_factory", [None, CosineAnnealingLRFactory(100)]
+)
+@pytest.mark.parametrize("compiled", [False, True])
+@pytest.mark.parametrize("clip_grad_norm", [None, 1e-4])
+def test_optimizer_wrapper(
+    lr_scheduler_factory: Optional[LRSchedulerFactory],
+    compiled: bool,
+    clip_grad_norm: Optional[float],
+) -> None:
+    model = nn.Linear(100, 200)
+    optim = SGD(model.parameters(), lr=1)
+    lr_scheduler = (
+        lr_scheduler_factory.create(optim) if lr_scheduler_factory else None
+    )
+    wrapper = OptimizerWrapper(
+        params=list(model.parameters()),
+        optim=optim,
+        compiled=compiled,
+        clip_grad_norm=clip_grad_norm,
+        lr_scheduler=lr_scheduler,
+    )
+
+    loss = model(torch.rand(1, 100)).mean()
+    loss.backward()
+
+    # check zero grad
+    wrapper.zero_grad()
+    if compiled:
+        assert model.weight.grad is None
+        assert model.bias.grad is None
+    else:
+        assert torch.all(model.weight.grad == 0)
+        assert torch.all(model.bias.grad == 0)
+
+    # check step
+    before_weight = torch.zeros_like(model.weight)
+    before_weight.copy_(model.weight)
+    before_bias = torch.zeros_like(model.bias)
+    before_bias.copy_(model.bias)
+    loss = model(torch.rand(1, 100)).mean()
+    loss.backward()
+    model.weight.grad.add_(1)
+    model.weight.grad.mul_(10000)
+    model.bias.grad.add_(1)
+    model.bias.grad.mul_(10000)
+
+    wrapper.step()
+    assert torch.all(model.weight != before_weight)
+    assert torch.all(model.bias != before_bias)
+
+    # check clip_grad_norm
+    if clip_grad_norm:
+        assert torch.norm(model.weight.grad) < 1e-4
+    else:
+        assert torch.norm(model.weight.grad) > 1e-4
+
+
 @pytest.mark.parametrize("lr", [1e-4])
 @pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)])
 def test_sgd_factory(lr: float, module: torch.nn.Module) -> None:
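
For context on the zero_grad assertions in the new test: plain PyTorch distinguishes zero_grad(set_to_none=True), which drops gradients to None, from set_to_none=False, which zero-fills them; the test expects the former behavior in the compiled case and the latter otherwise. A standalone sketch with plain torch (independent of d3rlpy):

    import torch

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    model(torch.rand(1, 4)).sum().backward()
    opt.zero_grad(set_to_none=False)   # gradients kept as zero-filled tensors
    assert torch.all(model.weight.grad == 0)

    model(torch.rand(1, 4)).sum().backward()
    opt.zero_grad(set_to_none=True)    # gradients removed entirely
    assert model.weight.grad is None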

Diff for: tests/test_torch_utility.py (+28 -5)

@@ -1,7 +1,7 @@
 import copy
 import dataclasses
 from io import BytesIO
-from typing import Any, Dict, Sequence
+from typing import Any, Dict, Optional, Sequence
 from unittest.mock import Mock
 
 import numpy as np
@@ -10,6 +10,10 @@
 
 from d3rlpy.dataset import TrajectoryMiniBatch, Transition, TransitionMiniBatch
 from d3rlpy.optimizers import OptimizerWrapper
+from d3rlpy.optimizers.lr_schedulers import (
+    CosineAnnealingLRFactory,
+    LRSchedulerFactory,
+)
 from d3rlpy.torch_utility import (
     GEGLU,
     Checkpointer,
@@ -454,11 +458,22 @@ def test_torch_trajectory_mini_batch(
     assert torch.all(torch_batch2.masks == torch_batch.masks)
 
 
-def test_checkpointer() -> None:
+@pytest.mark.parametrize(
+    "lr_scheduler_factory", [None, CosineAnnealingLRFactory(100)]
+)
+def test_checkpointer(
+    lr_scheduler_factory: Optional[LRSchedulerFactory],
+) -> None:
     fc1 = torch.nn.Linear(100, 100)
     fc2 = torch.nn.Linear(100, 100)
     params = list(fc1.parameters())
-    optim = OptimizerWrapper(params, torch.optim.Adam(params), False)
+    raw_optim = torch.optim.Adam(params)
+    lr_scheduler = (
+        lr_scheduler_factory.create(raw_optim) if lr_scheduler_factory else None
+    )
+    optim = OptimizerWrapper(
+        params, raw_optim, lr_scheduler=lr_scheduler, compiled=False
+    )
     checkpointer = Checkpointer(
         modules={"fc1": fc1, "fc2": fc2, "optim": optim}, device="cpu:0"
     )
@@ -468,7 +483,7 @@ def test_checkpointer() -> None:
     states = {
         "fc1": fc1.state_dict(),
         "fc2": fc2.state_dict(),
-        "optim": optim.optim.state_dict(),
+        "optim": optim.state_dict(),
     }
     torch.save(states, ref_bytes)
 
@@ -480,7 +495,15 @@ def test_checkpointer() -> None:
     fc1_2 = torch.nn.Linear(100, 100)
     fc2_2 = torch.nn.Linear(100, 100)
     params_2 = list(fc1_2.parameters())
-    optim_2 = OptimizerWrapper(params_2, torch.optim.Adam(params_2), False)
+    raw_optim_2 = torch.optim.Adam(params_2)
+    lr_scheduler_2 = (
+        lr_scheduler_factory.create(raw_optim_2)
+        if lr_scheduler_factory
+        else None
+    )
+    optim_2 = OptimizerWrapper(
+        params_2, raw_optim_2, lr_scheduler=lr_scheduler_2, compiled=False
+    )
     checkpointer = Checkpointer(
         modules={"fc1": fc1_2, "fc2": fc2_2, "optim": optim_2}, device="cpu:0"
     )
