-
Notifications
You must be signed in to change notification settings - Fork 0
/
learning_rate_range_test.py
144 lines (119 loc) · 5.24 KB
/
learning_rate_range_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import argparse
from config import LRRT_STOP_FACTOR, RANDOM_SEED
from callbacks import LearningRateVsLossCallback
import numpy as np
import wandb
import tensorflow as tf
import pandas as pd
from model import build_densenet121_model, build_efficientnet_model, build_mobilenetv2_model
from optimizer import build_sgd_optimizer
from split_dataset import SplitDataset
# Load the pre-extracted skeleton CSVs and wrap them in a 5-fold split helper.
training_df = pd.read_csv("wlasl100_skeletons_train.csv", index_col=0)
val_df = pd.read_csv("wlasl100_skeletons_val.csv", index_col=0)
dataset = SplitDataset(training_df, val_df, num_splits=5)
# Only `dataset` is needed from here on; drop the raw dataframes.
del training_df, val_df
def run_experiment(config=None, log_to_wandb=True, verbose=0):
    """Run a single learning-rate range test experiment.

    Builds the train/validation datasets, the SGD optimizer, and the model
    described by ``config``, trains for one epoch while the
    ``LearningRateVsLossCallback`` sweeps the learning rate, and returns
    the callback's logs.

    Args:
        config: Mapping of experiment hyperparameters (split, batch_size,
            augmentation, pipeline, learning-rate schedule, backbone,
            dropout, pretraining, callback settings).
        log_to_wandb: Forwarded to the callback; whether it logs to W&B.
        verbose: Keras verbosity level for ``model.fit``.

    Returns:
        The logs collected by the callback, or ``None`` if ``config``
        was not provided.

    Raises:
        ValueError: If ``config['backbone']`` is not one of
            "densenet", "mobilenet", or "efficientnet".
    """
    tf.keras.backend.clear_session()
    tf.keras.utils.set_random_seed(RANDOM_SEED)

    # check if config was provided
    if config is None:
        # Fixed message grammar (was "Not config provided.").
        print("No config provided.")
        return None
    print("[INFO] Configuration:", config, "\n")

    # generate train dataset (repeated + shuffled over the whole train set)
    train_dataset = dataset.get_training_set(
        split=config['split'],
        batch_size=config['batch_size'],
        buffer_size=dataset.num_train_examples,
        repeat=True,
        deterministic=True,
        augmentation=config['augmentation'],
        pipeline=config['pipeline'])

    # generate val dataset
    validation_dataset = dataset.get_validation_set(
        split=config['split'],
        batch_size=config['batch_size'],
        pipeline=config['pipeline'])

    print("[INFO] Dataset Total examples:", dataset.num_total_examples)
    print("[INFO] Dataset Training examples:", dataset.num_train_examples)
    print("[INFO] Dataset Validation examples:", dataset.num_val_examples)

    # setup optimizer (cyclical-LR style SGD; ramps from initial to maximal
    # learning rate over `step_size` steps)
    optimizer = build_sgd_optimizer(initial_learning_rate=config['initial_learning_rate'],
                                    maximal_learning_rate=config['maximal_learning_rate'],
                                    momentum=config['momentum'],
                                    nesterov=config['nesterov'],
                                    step_size=config['step_size'],
                                    weight_decay=config['weight_decay'])

    # setup model: dispatch on backbone name instead of an if/elif chain
    input_shape = [None, dataset.input_width, 3]
    backbone_builders = {
        "densenet": build_densenet121_model,
        "mobilenet": build_mobilenetv2_model,
        "efficientnet": build_efficientnet_model,
    }
    try:
        build_model = backbone_builders[config['backbone']]
    except KeyError:
        # ValueError is a subclass of Exception, so existing callers that
        # caught Exception still work.
        raise ValueError("Unknown model name")
    model = build_model(input_shape=input_shape,
                        dropout=config['dropout'],
                        optimizer=optimizer,
                        pretraining=config['pretraining'])

    # setup callback that records loss as the learning rate sweeps
    lrc = LearningRateVsLossCallback(
        validation_data=validation_dataset,
        log_each_steps=config['log_each_steps'],
        stop_factor=LRRT_STOP_FACTOR,
        stop_patience=config['stop_patience'],
        loss_min_delta=config['loss_min_delta'],
        log_to_wandb=log_to_wandb)

    # train model for exactly one LR sweep (step_size optimizer steps)
    model.fit(train_dataset,
              epochs=1,
              steps_per_epoch=int(config['step_size']),
              verbose=verbose,
              callbacks=[lrc])

    # get the logs of the callback
    logs = lrc.get_logs()
    return logs
def agent_fn(config=None):
    """W&B sweep agent entry point: derive schedule params, run one test."""
    wandb.init(config=config, reinit=True)

    max_lr = wandb.config.maximal_learning_rate
    init_lr = wandb.config.initial_learning_rate
    lr_delta = wandb.config.learning_rate_delta
    batch = wandb.config.batch_size

    # Steps needed to sweep the LR range in increments of lr_delta.
    step_size = (max_lr - init_lr) / lr_delta
    # Log roughly once per pass over the training data.
    log_each_steps = np.ceil(dataset.num_train_examples / batch)
    wandb.config.update({"step_size": step_size,
                         "log_each_steps": log_each_steps})

    _ = run_experiment(config=wandb.config, log_to_wandb=True, verbose=0)
    wandb.finish()
def main(args):
    """Start a W&B agent that runs `agent_fn` for the given sweep."""
    wandb.agent(args.sweep_id,
                project=args.project,
                entity=args.entity,
                function=agent_fn)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Learning rate range test cv.')
parser.add_argument('--entity',
type=str,
help='Entity',
default='cv_inside')
parser.add_argument('--project',
type=str,
help='Project name',
default='lrrt-wlasl100-tssi')
parser.add_argument('--sweep_id',
type=str,
help='Sweep id',
required=True)
args = parser.parse_args()
print(args)
main(args)