-
Notifications
You must be signed in to change notification settings - Fork 0
/
lr_scheduler.py
65 lines (53 loc) · 2.18 KB
/
lr_scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import tensorflow as tf
def MultiStepLR(initial_learning_rate, lr_steps, lr_rate, name='MultiStepLR'):
    """Multi-steps learning rate scheduler.

    Builds a piecewise-constant schedule whose first value is
    `initial_learning_rate` and whose every following value is the previous
    one multiplied by `lr_rate`, with one value change per entry of
    `lr_steps`.

    Args:
        initial_learning_rate: learning rate of the first segment.
        lr_steps: iterable of step boundaries separating the segments.
        lr_rate: multiplicative factor applied from one segment to the next.
        name: name forwarded to the underlying Keras schedule.

    Returns:
        A `tf.keras.optimizers.schedules.PiecewiseConstantDecay` instance.
    """
    # Materialize once so tuples and other iterables are accepted and the
    # number of values is guaranteed to be len(boundaries) + 1.
    boundaries = list(lr_steps)

    # Repeated multiplication (rather than lr_rate ** k) keeps the values
    # numerically identical to the historical behaviour of this function.
    values = [initial_learning_rate]
    for _ in boundaries:
        values.append(values[-1] * lr_rate)

    return tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=boundaries, values=values, name=name)
def CosineAnnealingLR_Restart(initial_learning_rate, t_period, lr_min):
    """Cosine annealing learning rate scheduler with restart.

    Args:
        initial_learning_rate: learning rate at the start of every period.
        t_period: number of steps passed as `first_decay_steps`.
        lr_min: minimum learning rate; converted to the `alpha` fraction of
            `initial_learning_rate` expected by the Keras schedule.

    Returns:
        A Keras `CosineDecayRestarts` learning rate schedule.
    """
    # Prefer the stable location of the schedule; fall back to the
    # `experimental` namespace for TensorFlow releases that only ship it
    # there.
    schedule_cls = getattr(
        tf.keras.optimizers.schedules, 'CosineDecayRestarts', None)
    if schedule_cls is None:
        schedule_cls = tf.keras.experimental.CosineDecayRestarts

    return schedule_cls(
        initial_learning_rate=initial_learning_rate,
        first_decay_steps=t_period,
        t_mul=1.0,  # period-length multiplier
        m_mul=1.0,  # restart learning-rate multiplier
        alpha=lr_min / initial_learning_rate)
if __name__ == "__main__":
    # Demo: sample a learning rate schedule and plot it.
    # Plotting packages are imported here so that importing this module for
    # its schedulers does not require matplotlib / seaborn.
    import matplotlib as mpl
    import matplotlib.ticker as mtick
    import seaborn
    from matplotlib import pyplot as plt

    # pretrain PSNR lr scheduler
    lr_scheduler = MultiStepLR(2e-4, [200000, 400000, 600000, 800000], 0.5)

    # ESRGAN lr scheduler
    # lr_scheduler = MultiStepLR(1e-4, [50000, 100000, 200000, 300000], 0.5)

    # Cosine Annealing lr scheduler
    # lr_scheduler = CosineAnnealingLR_Restart(2e-4, 250000, 1e-7)

    ##############################
    # Draw figure
    ##############################
    N_iter = 1000000
    # Sample the schedule every 1000 steps.
    step_list = list(range(0, N_iter, 1000))
    lr_list = [lr_scheduler(step).numpy() for step in step_list]

    mpl.style.use('default')
    seaborn.set(style='whitegrid')
    seaborn.set_context('paper')

    plt.figure(1)
    plt.subplot(111)
    plt.title('Title', fontsize=16, color='k')
    plt.plot(step_list, lr_list, linewidth=1.5, label='learning rate scheme')
    plt.legend(loc='upper right', shadow=False)

    ax = plt.gca()
    # Label x ticks in thousands of iterations ("200K"). A formatter is used
    # instead of fixed tick-label strings so the labels always match the
    # tick positions, including after zooming or panning.
    ax.xaxis.set_major_formatter(
        mtick.FuncFormatter(lambda value, _pos: str(int(value / 1000)) + 'K'))
    ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
    ax.set_ylabel('Learning rate')
    ax.set_xlabel('Iteration')

    plt.show()