-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathMLP2_Data.m
113 lines (91 loc) · 1.95 KB
/
MLP2_Data.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
% MLP2_Data: two-hidden-layer sigmoid MLP trained with full-batch gradient
% descent on data4.xlsx (columns 1-7 = features, column 8 = class label 1..3).
clc;
clear;              % 'clear all' also unloads functions/MEX files (slow, unnecessary); only variables need resetting
close all;

% ---- Network / training hyperparameters ----
Q1 = 7;             % neurons in hidden layer 1
Q2 = 3;             % neurons in hidden layer 2
m = 3;              % output neurons (one per class)
mu = 0.75;          % learning rate
a = 0.0001;         % coefficient scaling the weights at every update (intended as weight decay)
T = 1000;           % maximum number of training epochs
MSETarget = 1e-20;  % stop training early once the training MSE falls below this

data = xlsread('data4.xlsx');
% Scale each feature column (1-7) into [-1, 1] by its largest absolute value.
% The label column (8) is left untouched. An all-zero column keeps its zeros
% instead of becoming NaN through 0/0.
featCols = 1:7;
colScale = max(abs(data(:,featCols)), [], 1);
colScale(colScale == 0) = 1;
data(:,featCols) = bsxfun(@rdivide, data(:,featCols), colScale);
% Split features and labels, then build one-hot training targets.
X = data(:,(1:7));
Y = data(:,8);

% Encoding is reversed relative to the label value:
%   class 1 -> [0 0 1], class 2 -> [0 1 0], class 3 -> [1 0 0].
% Any other label would otherwise silently yield an all-zero target row,
% so it is rejected explicitly.
oneHot = [0 0 1; 0 1 0; 1 0 0];
if ~all(ismember(Y, 1:3))
    error('MLP2_Data:badLabel', 'Column 8 must contain only class labels 1, 2 or 3.');
end
Zy = oneHot(Y,:);   % length(Y) x 3, preallocated in one step
% Hold out 30% of the samples for testing. Passing the label vector lets
% cvpartition stratify the split by class.
C = cvpartition(Y, 'HoldOut', 0.3);
tr = training(C);
te = test(C);

Ztr = Zy(tr,:);
Zte = Zy(te,:);
Xtr = X(tr,:);
Xte = X(te,:);

% p / p2 = number of training / test samples; N / N2 = number of features.
[p, N]   = size(Xtr);
[p2, N2] = size(Xte);

% Prepend a constant bias input column (value -1) to both feature matrices.
bias = -1;
Xtr = [repmat(bias, p, 1), Xtr];
Xte = [repmat(bias, p2, 1), Xte];
% Initial weights are drawn uniformly from [0, 1). Column 1 of each matrix
% multiplies the bias input, hence the "+1" in every column count.
% NOTE(review): all-positive, non-centred initialisation; consider a
% zero-mean scheme if training stalls.
W1 = rand([Q1, N + 1]);
W2 = rand([Q2, Q1 + 1]);
W3 = rand([m,  Q2 + 1]);

MSETemp = zeros([1, T]);   % per-epoch training MSE history
% ---- Full-batch gradient-descent training ----
% Each epoch runs three steps:
%   1. Forward pass over all p training samples.
%   2. Log the MSE, stopping early if it is below MSETarget.
%   3. Backpropagation.
% All three layer deltas use the SAME weights that produced the forward pass;
% the weight matrices are only modified afterwards.
for i = 1:T
    % Forward pass (samples are columns from here on).
    V1 = W1*Xtr';
    Z1 = 1./(1+exp(-V1));
    S1 = [bias*ones(1,p); Z1];      % prepend bias row
    V2 = W2*S1;
    Z2 = 1./(1+exp(-V2));
    S2 = [bias*ones(1,p); Z2];      % prepend bias row
    G  = W3*S2;
    Yout = 1./(1+exp(-G));          % m x p network output (does not overwrite the label vector Y)

    E = Ztr - Yout';                % p x m error
    mse = mean(mean(E.^2));
    MSETemp(i) = mse;
    if mse < MSETarget
        break;                      % 'return' would end the whole script and skip test-set evaluation
    end

    % Backward pass: local gradients, using the sigmoid derivative s.*(1-s).
    dG3 = Yout.*(1-Yout).*E';
    dG2 = S2.*(1-S2).*(W3'*dG3);
    dG2 = dG2(2:end,:);             % drop bias row: it has no incoming weights
    dG1 = S1.*(1-S1).*(W2'*dG2);
    dG1 = dG1(2:end,:);             % drop bias row

    % Weight updates. (1-a) shrinks the weights each epoch (weight decay);
    % (1+a) would make them grow.
    % NOTE(review): the batch gradient is scaled by mu/N, where N is the number
    % of input features rather than the number of samples p. Confirm this is intended.
    DW3 = mu/N * dG3*S2';
    DW2 = mu/N * dG2*S1';
    DW1 = mu/N * dG1*Xtr;
    W3 = (1-a)*W3 + DW3;
    W2 = (1-a)*W2 + DW2;
    W1 = (1-a)*W1 + DW1;
end
MSE = MSETemp(1:i);                 % MSE history of the epochs actually run
% ---- Evaluate the trained network on the held-out test set ----
V1 = W1*Xte';
Z1 = 1./(1+exp(-V1));
S1 = [bias*ones(1,p2); Z1];
V2 = W2*S1;
Z2 = 1./(1+exp(-V2));
S2 = [bias*ones(1,p2); Z2];
G  = W3*S2;
Yp = 1./(1+exp(-G));            % m x p2 network outputs
Ypp = Yp';                      % p2 x m, one row per test sample

% Winning output unit per test sample (vectorised arg-max over each row).
[~, pl] = max(Ypp, [], 2);      % predicted unit index
[~, pa] = max(Zte, [], 2);      % true unit index, from the one-hot targets
pl = pl';
pa = pa';

% The one-hot encoding maps class c to unit m+1-c (class 1 -> [0 0 1]).
% Convert unit indices back to class labels so the confusion matrix and
% 'order' are reported in class order rather than reversed.
pl = (m + 1) - pl;
pa = (m + 1) - pa;

% cm(i,j) = number of test samples of true class order(i) predicted as order(j).
[cm, order] = confusionmat(pa, pl);
Accuracy = ((trace(cm))/sum(cm(:)))*100;   % percentage of correctly classified test samples
display(cm);
display(Accuracy);