-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.cpp
54 lines (39 loc) · 1.37 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include "buildTensorflow.h"
// Example of training a network on the buildTensorflow framework.
int main() {
// Load Dataset
Celsius2Fahrenheit<float,float> dataset;
dataset.create(5);
// Create Model
Dense<float> fc1(1,1,NO_ACTIVATION);
// Initialise Optimiser
SGD<float> sgd(0.01);
// Train
cout<<"Training started"<<endl;
for(int j = 0;j<2000;j++) {
for(auto i: dataset.data) {
// Get data
auto inp = new Tensor<float>({i.first}, {1,1});
auto tar = new Tensor<float>({i.second}, {1,1});
// Forward Prop
auto out = fc1.forward(inp);
// Get Loss
auto l = new Tensor<float>({-1}, {1,1});
auto k = tensorOps::multiply(l,tar);
auto loss = tensorOps::add(out,k); // error in loss
auto finalLoss = tensorOps::power(loss,(float)2);
// Compute backProp
finalLoss->backward();
// Perform Gradient Descent
sgd.minimise(finalLoss);
}
}
cout<<"Training completed"<<endl;
// Inference
float cel = 4;
auto test = new Tensor<float>({cel}, {1,1});
auto out1 = fc1.forward(test);
cout<<"The conversion of "<<cel<<" degrees celcius to fahrenheit is "<<out1->val<<endl; // For 4 Celcius: it's ~39.2
// Clean up
delete out1;
}