Skip to content

Commit e3053c1

Browse files
newstzpzfacebook-github-bot
authored andcommitted
Support specifying end lr for WarmupCosineLR.
Summary: Support specifying end lr for WarmupCosineLR * Use `cfg.SOLVER.BASE_LR_END` to specify the lr value for the last iteration, only used by `WarmupCosineLR`. Reviewed By: zhanghang1989 Differential Revision: D33292501 fbshipit-source-id: 07942d734f8d445fe03f82a85bdc78b65a8aebd0
1 parent 085fda4 commit e3053c1

File tree

3 files changed

+55
-4
lines changed

3 files changed

+55
-4
lines changed

detectron2/config/defaults.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,8 @@
517517
_C.SOLVER.MAX_ITER = 40000
518518

519519
_C.SOLVER.BASE_LR = 0.001
520+
# The end lr, only used by WarmupCosineLR
521+
_C.SOLVER.BASE_LR_END = 0.0
520522

521523
_C.SOLVER.MOMENTUM = 0.9
522524

detectron2/solver/build.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,9 @@ def build_lr_scheduler(
272272
num_updates=cfg.SOLVER.MAX_ITER,
273273
)
274274
elif name == "WarmupCosineLR":
275-
sched = CosineParamScheduler(1, 0)
275+
end_value = cfg.SOLVER.BASE_LR_END / cfg.SOLVER.BASE_LR
276+
assert end_value >= 0.0 and end_value <= 1.0, end_value
277+
sched = CosineParamScheduler(1, end_value)
276278
else:
277279
raise ValueError("Unknown LR scheduler: {}".format(name))
278280

tests/test_scheduler.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# Copyright (c) Facebook, Inc. and its affiliates.
22

33
import math
4-
import numpy as np
54
from unittest import TestCase
5+
6+
import numpy as np
67
import torch
8+
from detectron2.solver import LRMultiplier, WarmupParamScheduler, build_lr_scheduler
79
from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler
810
from torch import nn
911

10-
from detectron2.solver import LRMultiplier, WarmupParamScheduler
11-
1212

1313
class TestScheduler(TestCase):
1414
def test_warmup_multistep(self):
@@ -66,3 +66,50 @@ def test_warmup_cosine(self):
6666
self.assertAlmostEqual(lr, expected_cosine)
6767
else:
6868
self.assertNotAlmostEqual(lr, expected_cosine)
69+
70+
def test_warmup_cosine_end_value(self):
71+
from detectron2.config import CfgNode, get_cfg
72+
73+
def _test_end_value(cfg_dict):
74+
cfg = get_cfg()
75+
cfg.merge_from_other_cfg(CfgNode(cfg_dict))
76+
77+
p = nn.Parameter(torch.zeros(0))
78+
opt = torch.optim.SGD([p], lr=cfg.SOLVER.BASE_LR)
79+
80+
scheduler = build_lr_scheduler(cfg, opt)
81+
82+
p.sum().backward()
83+
opt.step()
84+
self.assertEqual(
85+
opt.param_groups[0]["lr"], cfg.SOLVER.BASE_LR * cfg.SOLVER.WARMUP_FACTOR
86+
)
87+
88+
lrs = []
89+
for _ in range(cfg.SOLVER.MAX_ITER):
90+
scheduler.step()
91+
lrs.append(opt.param_groups[0]["lr"])
92+
93+
self.assertAlmostEqual(lrs[-1], cfg.SOLVER.BASE_LR_END)
94+
95+
_test_end_value({
96+
"SOLVER": {
97+
"LR_SCHEDULER_NAME": "WarmupCosineLR",
98+
"MAX_ITER": 100,
99+
"WARMUP_ITERS": 10,
100+
"WARMUP_FACTOR": 0.1,
101+
"BASE_LR": 5.0,
102+
"BASE_LR_END": 0.0,
103+
}
104+
})
105+
106+
_test_end_value({
107+
"SOLVER": {
108+
"LR_SCHEDULER_NAME": "WarmupCosineLR",
109+
"MAX_ITER": 100,
110+
"WARMUP_ITERS": 10,
111+
"WARMUP_FACTOR": 0.1,
112+
"BASE_LR": 5.0,
113+
"BASE_LR_END": 0.5,
114+
}
115+
})

0 commit comments

Comments
 (0)