-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNNLayer.cpp
80 lines (69 loc) · 1.32 KB
/
NNLayer.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include "Mat.h"
#include "Vect.h"
#include "NNLayer.h"
#include "NNUtils.h"
#include "Utils.h"
#include <stdio.h>
#include <math.h>
// Builds a fully-connected layer with `num_input` inputs and `num_output`
// outputs. Weights are drawn randomly and rescaled by sqrt(2/num_input)
// (He-style scaling for ReLU — presumably; confirm against fillRandomNorm);
// biases start at a small positive constant. Forward-pass caches start empty.
NNLayer::NNLayer(int num_input, int num_output)
    : num_input(num_input),
      num_output(num_output),
      w(new Mat(num_input, num_output)),
      b(new Vect(num_output)),
      _x(0),
      _z(0) {
    fillRandomNorm(w, -1.0f, 1.0f);
    w->multiply(sqrt(2.0f/num_input));
    fill(b, 0.1f);
}
int NNLayer::getNumInput() const {
return num_input;
}
int NNLayer::getNumOutput() const {
return num_output;
}
// Forward pass: returns relu(x * W + b). The caller owns the returned Vect.
// A copy of the input (_x) and the pre-activation (_z) are cached for a later
// backwardPropagate call; caches from any previous pass are freed first.
Vect* NNLayer::forwardPropagate(const Vect *x) {
    // _x and _z are always allocated together, so one check covers both.
    if (_x != 0) {
        delete _x;
        delete _z;
    }
    _x = new Vect(x);      // keep a copy of the input for the dW computation
    _z = x->multiplyC(w);  // pre-activation z = x * W
    _z->add(b);            // ... + b
    // Apply the non-linearity to a copy so _z stays the raw pre-activation.
    Vect *activated = new Vect(_z);
    relu(activated);
    return activated;
}
// Backward pass / SGD step. `e` is the error arriving from the next layer
// (dLoss/dActivation, length num_output), `lr` the learning rate.
// Updates the weights in place and returns the error to hand to the previous
// layer (caller owns the returned Vect). Requires forwardPropagate first.
Vect* NNLayer::backwardPropagate(const Vect *e, float lr) {
    if (_z == 0) {
        print();
        throwError("calling backProp without having called forwardProp.");
    }
    // delta = e (.) relu'(z) — push the error back through the activation.
    Vect *e_zrp = new Vect(_z);
    relu_prime(e_zrp);
    e_zrp->multiply(e);
    // Error for the previous layer: e_in = W^T * delta. This must be computed
    // (a) from the post-activation delta, not the raw incoming error, and
    // (b) BEFORE the weight update below, so it uses the pre-update weights.
    Vect *e_in = e_zrp->multiplyR(w);
    // Weight gradient dW = x^T (outer) delta; SGD step: W -= lr * dW.
    Mat *dw = e_zrp->multiplyCR(_x);
    dw->multiply(lr);
    w->sub(dw);
    // TODO(review): the bias b is never updated — db = lr * delta should be
    // subtracted from b, but Vect's scalar-multiply/sub API isn't visible
    // here to write that safely; confirm against Vect.h.
    delete e_zrp;
    delete dw;
    return e_in;
}
// Debug dump: layer dimensions followed by the weight matrix and bias vector
// (formatting of w/b is delegated to Mat::print / Vect::print).
void NNLayer::print() const {
printf("NNLayer(%d, %d)", num_input, num_output);
w->print();
b->print();
}
// Frees the layer parameters and any cached forward-pass state.
// _x/_z are 0 until the first forwardPropagate; `delete` on a null
// pointer is a no-op, so no guard is needed.
NNLayer::~NNLayer() {
    delete _x;
    delete _z;
    delete w;
    delete b;
    _x = 0;
    _z = 0;
    w = 0;
    b = 0;
}