-
Notifications
You must be signed in to change notification settings - Fork 0
/
entrance.asv
58 lines (57 loc) · 2.25 KB
/
entrance.asv
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
function entrance()
% ENTRANCE  Load the heart-failure CSV and run the gradient-descent fit.
%
% Reads eight numeric columns from the data file in the current directory,
% standardizes every column (including the target) with zscore, and hands
% them to Compute2 together with an initial parameter vector, a learning
% rate of 0.0015 and 100000 passes over the data.
%
% Expected column order in the CSV:
%age,creatinine_phosphokinase,ejection_fraction,platelets,serum_creatinine,serum_sodium,time,DEATH_EVENT
% Example row:
%75 582 20 265000 1.9 130 4 1

% Build the path with fullfile so the separator is correct on every
% platform (a literal '.\' only resolves on Windows).
filename = fullfile('.', 'heart_failure_clinical_records_dataset.csv');
if exist(filename, 'file') ~= 2
    error('entrance:fileNotFound', 'Data file not found: %s', filename);
end

% NOTE(review): an empty format string makes textread parse every field as
% a number. If the CSV starts with a header line this call will need
% 'headerlines', 1 -- confirm against the actual data file.
[age, creatinine_phosphokinase, ejection_fraction, platelets, ...
    serum_creatinine, serum_sodium, time, DEATH_EVENT] = ...
    textread(filename, '', 'delimiter', ',');

% One parameter for the intercept plus one per feature (7 features).
theta = [4;4;4;4;4;4;4;4];

learningRate = 0.0015;
passes = 100000;   % an earlier setting was 1000000

Compute2(zscore(age), ...
    zscore(creatinine_phosphokinase), ...
    zscore(ejection_fraction), ...
    zscore(platelets), ...
    zscore(serum_creatinine), ...
    zscore(serum_sodium), ...
    zscore(time), ...
    zscore(DEATH_EVENT), ...
    theta, learningRate, passes);
end
function Compute2(age,creatinine_phosphokinase,ejection_fraction,platelets,serum_creatinine,serum_sodium,time,DEATH_EVENT,Theta,a,iterationTimes)
% COMPUTE2  Fit a linear model with per-sample gradient steps, log and plot it.
%
% Inputs
%   age, creatinine_phosphokinase, ejection_fraction, platelets,
%   serum_creatinine, serum_sodium, time
%                   - m-by-1 feature columns (used as given; the caller is
%                     responsible for any scaling).
%   DEATH_EVENT     - m-by-1 target column.
%   Theta           - initial parameters: intercept followed by one weight
%                     per feature, i.e. 8 elements.
%   a               - learning rate.
%   iterationTimes  - number of full passes over the m samples.
%
% Behaviour
%   For every sample, in order, all parameters are updated simultaneously
%   from that sample's residual with step a*(1/m). After each update the
%   mean squared error over the whole data set and the current parameters
%   are stored in one row of a history matrix and printed.
%   Finally two subplots are drawn: the age/target scatter with the line
%   Theta(1)+Theta(2)*age, and the loss history. The final Theta is echoed
%   to the command window.
%
% Errors
%   Compute2:thetaSize if Theta does not have one element per design-matrix
%   column.

X = age;
Y = DEATH_EVENT;
m = size(age, 1);

% Design matrix: a leading column of ones for the intercept.
X2 = [ones(m,1), age, creatinine_phosphokinase, ejection_fraction, ...
      platelets, serum_creatinine, serum_sodium, time];
p = size(X2, 2);

if numel(Theta) ~= p
    error('Compute2:thetaSize', ...
        'Theta must have %d elements (intercept + %d features), got %d.', ...
        p, p - 1, numel(Theta));
end
Theta = Theta(:);   % work with a column vector

% History: [update index, loss, Theta(1..p)]. One row per sample update,
% preallocated so the matrix is not regrown on every step.
% Note this is m*iterationTimes rows, which is large for long runs.
data = zeros(m * iterationTimes, 2 + p);

% Log line format, sized to the actual number of parameters.
fmt = ['第%d次迭代\t 损失函数结果 %.5f\t', ...
       sprintf(' Theta%d %%.2f\\t', 1:p), '\n'];

step = a * (1/m);
timess = 1;
for pass = 1:iterationTimes
    for i = 1:m
        % Simultaneous update of every parameter from sample i only.
        residual = X2(i,:) * Theta - DEATH_EVENT(i);
        Theta = Theta - step * (residual * X2(i,:).');

        data(timess, 1) = timess;
        data(timess, 2) = (1/m) * sum((X2 * Theta - DEATH_EVENT).^2);
        data(timess, 3:end) = Theta.';
        fprintf(fmt, data(timess, :));
        timess = timess + 1;
    end
end

% Scatter of the first feature against the target, with the fitted
% intercept + first-feature line overlaid.
subplot(2,2,1)
scatter(X, Y)
hold on;
plot(X, Theta(1,1) + Theta(2,1) * X)
title('数据样本散点图&数据假想回归直线')
xlabel('x轴')
ylabel('y轴')
set(gca, 'FontSize', 13);

Theta   % intentionally unsuppressed: echo the final parameters

% Loss history: column 2 holds the loss (column 4 is Theta(2)).
subplot(2,2,2)
plot(data(:,1), data(:,2))
title('梯度下降损失函数值')
xlabel('迭代次数')
ylabel('损失函数值')
set(gca, 'FontSize', 13);
% subplot(2,2,3)
end