-
Notifications
You must be signed in to change notification settings - Fork 22
/
optimizer.cpp
110 lines (72 loc) · 1.96 KB
/
optimizer.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
/*
 * optimizer.cpp
 *
 */
#include <thread>
#include "optimizer.h"
#include "variable.h"
/// Construct an optimizer for `model` with the given learning rate.
/// Gradient clipping is disabled: clip_grad() only clips when
/// clip_grad_threshold > 0, so 0.0 means "no clipping".
Optimizer::Optimizer(Model *model, float learning_rate) {
    this->model = model;
    lr = learning_rate;
    // Explicitly disable clipping. The original ctor left this member
    // untouched; unless the header gives it an in-class default (TODO:
    // confirm), clip_grad() would read an indeterminate value.
    clip_grad_threshold = 0.0;
}
/// Construct an optimizer for `model` with gradient clipping enabled:
/// clip_grad() clips gradients whenever clip_grad_threshold > 0.
Optimizer::Optimizer(Model *model, float learning_rate, float clip_grad_threshold) {
    this->model = model;
    this->lr = learning_rate;
    this->clip_grad_threshold = clip_grad_threshold;
}
// Destructor: frees every per-parameter OptimizerParams owned by `opts`
// via delOpts(); the Model itself is not owned and is left alone.
Optimizer::~Optimizer() {
delOpts();
}
void Optimizer::delOpts(){
for (int i = 0; i < opts.size(); i++) {
OptimizerParams *p = opts.at(i);
delete p;
}
opts.clear();
}
/// Factory for the per-parameter state attached to Variable `v`.
/// Subclasses (SGD, Adam, ...) are expected to override this and allocate
/// their own OptimizerParams specialization; the base version only logs.
///
/// Fix: the original fell off the end of a non-void function, which is
/// undefined behavior — init() would store an indeterminate pointer that
/// delOpts() later handed to `delete`. Returning nullptr is well-defined
/// and safe to delete.
OptimizerParams *Optimizer::createOptimizerParams(Variable *v){
    cout << "createOptimizerParams base" << endl;
    return nullptr;
}
void Optimizer::init() {
epoch = 1;
updateParams = model->getUpdateParams();
delOpts();
for (int i = 0; i < updateParams.size(); i++) {
UpdateParams *up = updateParams.at(i);
for (int j = 0; j < up->params.size(); j++) {
Variable *v = up->params.at(j);
opts.push_back(createOptimizerParams(v));
}
}
}
// Apply one optimization step to parameter `w` using its per-parameter
// state `opp`. Base-class stub: subclasses are expected to override this
// with the actual update rule; the base version only logs a message.
void Optimizer::update_param(Variable *w, OptimizerParams &opp) {
cout << "update_param" << endl;
}
void Optimizer::zero_grads() {
for (int i = 0; i < updateParams.size(); i++) {
UpdateParams *up = updateParams.at(i);
for(int j=0; j < up->params.size(); j++){
Variable *v = up->params.at(j);
v->grad *= 0.0;
}
}
}
/// Clip the gradient of `v` in place via element_wise_clip — presumably
/// clamping each element against clip_grad_threshold (exact contract lives
/// in the matrix class; TODO confirm). A threshold that is not positive
/// disables clipping entirely.
void Optimizer::clip_grad(Variable *v){
    // Guard clause; `!(x > 0.0)` preserves the original comparison exactly.
    if (!(clip_grad_threshold > 0.0)) {
        return;
    }
    v->grad.element_wise_clip(v->grad, clip_grad_threshold);
}
void Optimizer::update() {
int k = 0;
for (int i = 0; i < updateParams.size(); i++) {
UpdateParams *up = updateParams.at(i);
for(int j=0; j < up->params.size(); j++){
Variable *v = up->params.at(j);
clip_grad(v);
update_param(v, *opts.at(k));
k++;
}
}
epoch++;
zero_grads();
}