Skip to content

Commit 398749a

Browse files
authored
Add source code and data set.
1 parent 308c8ae commit 398749a

File tree

2 files changed

+380
-0
lines changed

2 files changed

+380
-0
lines changed

main.py

+363
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,363 @@
1+
from libsvm.svmutil import *
2+
from libsvm.svm import *
3+
import numpy as np
4+
import pandas as pd
5+
import matplotlib.pyplot as plt
6+
import matplotlib as mpl
7+
import seaborn as sns
8+
9+
mpl.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 指定默认字体
10+
mpl.rcParams['axes.unicode_minus'] = False # 解决保存图像时负号'-'显示为方块的问题
11+
12+
13+
def list_of_dict_to_list_of_list(list_of_dict: list) -> list:
14+
"""
15+
将元素为字典的列表转换成元素为列表的列表,仅取value
16+
17+
Args:
18+
list_of_dict (list): 元素为字典的列表
19+
20+
Returns:
21+
list: 元素为列表的列表
22+
"""
23+
list_of_list = [[v for v in d.values()] for d in list_of_dict]
24+
return list_of_list
25+
26+
27+
def get_dividing_point(y: list):
28+
"""
29+
找出不同样例的分界点
30+
31+
Args:
32+
y (list): 数据标签
33+
34+
Returns:
35+
int: -1表示全部相同,否则表示分界点
36+
37+
"""
38+
last = y[0]
39+
for i, yi in enumerate(y):
40+
if yi != last:
41+
return i
42+
else:
43+
last = yi
44+
return -1
45+
46+
47+
def scatter_training_set(x: list, y: list, axes):
48+
"""
49+
绘制训练集散点图
50+
51+
Args:
52+
x (list): 数据特征
53+
y (list): 数据标签
54+
axes (matplotlib.axes._base._AxesBase): 要绘图的Axes实例
55+
56+
Returns:
57+
None
58+
"""
59+
x_array = np.array(list_of_dict_to_list_of_list(x))
60+
x1 = x_array[:, 0]
61+
x2 = x_array[:, 1]
62+
dividing_point = get_dividing_point(y)
63+
axes.scatter(x1[:dividing_point], x2[:dividing_point])
64+
axes.scatter(x1[dividing_point:], x2[dividing_point:])
65+
66+
67+
def leave_one_out(x: list, y: list, param_str: str):
68+
"""
69+
进行留一交叉验证
70+
71+
Args:
72+
x (list): 数据特征
73+
y (list): 数据标签
74+
param_str (str): SVM参数指令
75+
76+
Returns:
77+
留一交叉验证精度
78+
"""
79+
param_str += " -v " + str(len(y))
80+
accuracy = svm_train(y, x, param_str)
81+
return accuracy
82+
83+
84+
def solve_predict(x: list, y: list, param_str: str):
85+
"""
86+
训练模型SVM并用于分类
87+
88+
Args:
89+
x (list): 数据特征
90+
y (list): 数据标签
91+
param_str (str): SVM参数指令
92+
93+
Returns:
94+
p_label, p_acc, p_val, model
95+
"""
96+
prob = svm_problem(y, x)
97+
param = svm_parameter(param_str)
98+
model = svm_train(prob, param)
99+
p_label, p_acc, p_val = svm_predict(y, x, model)
100+
return p_label, p_acc, p_val, model
101+
102+
103+
def tuning_gauss(x: list, y: list, c_range: np.ndarray, g_range: np.ndarray):
104+
"""
105+
SVM高斯核调参
106+
107+
Args:
108+
x (list): 数据特征
109+
y (list): 数据标签
110+
c_range (np.ndarray): c参数所有取值
111+
g_range (np.ndarray): g参数所有取值
112+
113+
Returns:
114+
best_result (dict): 调参的最优结果,包含精度和c、g取值
115+
result_frame (pd.DataFrame): 调参过程中所有c、g和对应精度
116+
"""
117+
best_result = {"Accuracy": -1, "c": -1, "g": -1}
118+
result_file_name = "best_result.txt"
119+
result_array = []
120+
clear_file(result_file_name)
121+
for c in c_range:
122+
for g in g_range:
123+
param_str = '-q -t 2 -c ' + str(c) + ' -g ' + str(g)
124+
accuracy = leave_one_out(x, y, param_str)
125+
result_array.append([float(format(c, '.6f')), float(format(g, '.6f')), accuracy])
126+
if accuracy >= best_result["Accuracy"]:
127+
best_result["Accuracy"] = accuracy
128+
best_result["c"] = c
129+
best_result["g"] = g
130+
append_dict_to_file(result_file_name, best_result)
131+
result_frame = pd.DataFrame(result_array, columns=['c', 'g', 'Accuracy'])
132+
return best_result, result_frame
133+
134+
135+
def clear_file(filename: str):
136+
"""
137+
清空文件
138+
139+
Args:
140+
filename (str): 文件名
141+
142+
Returns:
143+
None
144+
"""
145+
with open(filename, mode='r+', encoding='UTF-8') as file_object:
146+
file_object.truncate()
147+
148+
149+
def append_dict_to_file(filename: str, content: dict):
150+
"""
151+
将字典内容写入文件
152+
153+
Args:
154+
filename (str): 文件名
155+
content (dict): 要写入的字典
156+
157+
Returns:
158+
None
159+
"""
160+
newline = '' # 要写入的内容
161+
for k, v in content.items():
162+
newline += str(k) + ': ' + str(v) + '\t'
163+
newline += '\n'
164+
append_to_file(filename, newline)
165+
166+
167+
def append_to_file(filename: str, content: str):
168+
"""
169+
将字符串写入文件
170+
171+
Args:
172+
filename (str): 文件名
173+
content (str): 要写入的字符串
174+
175+
Returns:
176+
None
177+
"""
178+
with open(filename, mode='r+', encoding='UTF-8') as file_object:
179+
file_object.seek(0, 2)
180+
file_object.writelines(content)
181+
182+
183+
def plot_tuning_result(result_frame: pd.DataFrame):
184+
"""
185+
绘制调参结果的热力图
186+
187+
Args:
188+
result_frame (pd.DataFrame): 调参结果
189+
190+
Returns:
191+
None
192+
"""
193+
fig, ax = plt.subplots(figsize=(10, 10))
194+
# sns.set()
195+
result_frame = result_frame.pivot("c", "g", "Accuracy")
196+
hm = sns.heatmap(result_frame, ax=ax, cmap="YlGnBu")
197+
hm.set_xlabel(hm.get_xlabel(), labelpad=0, rotation=0)
198+
plt.yticks(rotation=0)
199+
plt.savefig('parameter heat map.png', dpi=260)
200+
201+
202+
def calculate_laplace_kernel(x: list, y: list, gamma: float, result_file_name: str):
203+
"""
204+
计算拉普拉斯核并写入文件
205+
206+
Args:
207+
x (list): 数据特征
208+
y (list): 数据标签
209+
gamma (float): gamma参数
210+
result_file_name (str): 要写入的文件名
211+
212+
Returns:
213+
None
214+
"""
215+
x_array = np.array(list_of_dict_to_list_of_list(x))
216+
clear_file(result_file_name)
217+
for i in range(len(y)):
218+
kernels = []
219+
for j in range(len(y)):
220+
x_dif = x_array[i, :] - x_array[j, :] # 可以利用传播性质优化计算
221+
x_distance = np.power(np.sum(np.power(x_dif, 2)), 0.5)
222+
kernel = np.exp(-gamma * x_distance)
223+
kernels.append(kernel)
224+
content = str(y[i]) + " 0:" + str(i + 1)
225+
for k in range(len(y)):
226+
content += ' ' + str(k + 1) + ':' + str(kernels[k])
227+
content += '\n'
228+
append_to_file(result_file_name, content)
229+
230+
231+
def use_laplace(x, y, c: float):
232+
"""
233+
使用拉普拉斯核进行留一交叉验证和分类
234+
235+
Args:
236+
x (list): 数据特征
237+
y (list): 数据标签
238+
c (float): c参数
239+
240+
Returns:
241+
accuracy, p_label, p_acc, p_val, model
242+
"""
243+
param_str = '-q -t 4 -c ' + str(c)
244+
accuracy = leave_one_out(x, y, param_str)
245+
p_label, p_acc, p_val, model = solve_predict(x, y, param_str)
246+
return accuracy, p_label, p_acc, p_val, model
247+
248+
249+
def tuning_laplace(x: list, y: list, kernel_file_name: str, c_range: np.ndarray, g_range: np.ndarray):
250+
"""
251+
SVM拉普拉斯核调参
252+
253+
Args:
254+
x (list): 数据特征
255+
y (list): 数据标签
256+
kernel_file_name (str): 要写入拉普拉斯核的文件名
257+
c_range (np.ndarray): c参数所有取值
258+
g_range (np.ndarray): g参数所有取值
259+
260+
Returns:
261+
best_result (dict): 调参的最优结果,包含精度和c、g取值
262+
result_frame (pd.DataFrame): 调参过程中所有c、g和对应精度
263+
"""
264+
best_result = {"Accuracy": -1, "c": -1, "g": -1}
265+
result_file_name = "best_laplace_result.txt"
266+
result_array = []
267+
clear_file(result_file_name)
268+
for g in g_range:
269+
calculate_laplace_kernel(x, y, g, kernel_file_name)
270+
ly, lx = svm_read_problem(kernel_file_name)
271+
for c in c_range:
272+
param_str = '-q -t 4 -c ' + str(c)
273+
accuracy = leave_one_out(lx, ly, param_str)
274+
result_array.append([float(format(c, '.2f')), float(format(g, '.2f')), accuracy])
275+
# result_array.append([c, g, accuracy])
276+
if accuracy >= best_result["Accuracy"]:
277+
best_result["Accuracy"] = accuracy
278+
best_result["c"] = c
279+
best_result["g"] = g
280+
append_dict_to_file(result_file_name, best_result)
281+
result_frame = pd.DataFrame(result_array, columns=['c', 'g', 'Accuracy'])
282+
return best_result, result_frame
283+
284+
285+
def plot_sv(model, customed_model: bool, axes, x: np.ndarray = np.array([])):
286+
"""
287+
在图中标注支持向量
288+
289+
Args:
290+
model (): SVM模型
291+
customed_model (bool): 是否使用自定义核(拉普拉斯核)
292+
axes (matplotlib.axes._base._AxesBase): 要绘图的Axes实例
293+
x (np.ndarray): 使用自定义核时的原始数据
294+
295+
Returns:
296+
None
297+
"""
298+
if not customed_model:
299+
sv_dict = model.get_SV()
300+
sv = np.array(list_of_dict_to_list_of_list(sv_dict))
301+
else:
302+
if x.size == 0:
303+
raise Exception("x数据缺失")
304+
sv_indices = np.array(model.get_sv_indices(), dtype=np.int32) - 1
305+
sv = x[sv_indices]
306+
x1 = sv[:, 0]
307+
x2 = sv[:, 1]
308+
plt.scatter(x1, x2, marker='o', facecolor='none', edgecolors='black', s=200)
309+
310+
311+
def plot_data_and_sv(x, y, model, customed_model: bool, title: str, fig_file_name: str = "data and SV"):
312+
"""
313+
绘制原始数据并标注支持向量
314+
315+
Args:
316+
x (list): 数据特征
317+
y (list): 数据标签
318+
model (): SVM模型
319+
customed_model (bool): 是否使用自定义核(拉普拉斯核)
320+
title (str): 绘图标题
321+
fig_file_name (str): 保存图片的文件名
322+
323+
Returns:
324+
None
325+
"""
326+
fig, axes = plt.subplots(1, 1)
327+
scatter_training_set(x, y, axes)
328+
x = np.array(list_of_dict_to_list_of_list(x))
329+
plot_sv(model=model, customed_model=customed_model, axes=axes, x=x)
330+
props = {'xlabel': '密度', 'ylabel': '含糖率', 'title': title}
331+
axes.set(**props)
332+
axes.set_ylabel(axes.get_ylabel(), labelpad=20, rotation=0)
333+
plt.savefig(fig_file_name + ".png", dpi=260)
334+
335+
336+
if __name__ == '__main__':
337+
param_str = '-q -t 2 -c 1.4 -g 110'
338+
y, x = svm_read_problem('training set.txt')
339+
'''高斯'''
340+
accuracy = leave_one_out(x, y, param_str)
341+
p_label, p_acc, p_val, gauss_model = solve_predict(x, y, param_str)
342+
plot_data_and_sv(x=x, y=y, model=gauss_model, customed_model=False, title="SVM-高斯核, C=1.4, γ=110",
343+
fig_file_name="gauss data and SV")
344+
# best_gauss_result, gauss_result_frame = tuning_gauss(x, y, np.linspace(1, 10, int((10 - 1) * 1) + 1),
345+
# np.linspace(0, 128, int((128 - 0) * 1) + 1))
346+
# best_gauss_result, gauss_result_frame = tuning_gauss(x, y, np.logspace(-4, 4, num=513, base=10),
347+
# np.logspace(-4, 4, num=513, base=10))
348+
# plot_tuning_result(gauss_result_frame)
349+
'''拉普拉斯'''
350+
calculate_laplace_kernel(x, y, gamma=9, result_file_name="laplace_kernel.txt")
351+
ly, lx = svm_read_problem("laplace_kernel.txt")
352+
l_accuracy, l_p_label, l_p_acc, l_p_val, laplace_model = use_laplace(lx, ly, c=0.8)
353+
plot_data_and_sv(x=x, y=y, model=laplace_model, customed_model=True,
354+
title="SVM-拉普拉斯核, C=0.8, γ=9",
355+
fig_file_name="laplace data and SV")
356+
# best_laplace_result, laplace_result_frame = tuning_laplace(x, y, "laplace_kernel.txt",
357+
# np.logspace(-4, 4, num=129, base=10),
358+
# np.logspace(-4, 4, num=129, base=10))
359+
# best_laplace_result, laplace_result_frame = tuning_laplace(x, y, "laplace_kernel.txt",
360+
# np.linspace(0.2, 1.4, int((1.4 - 0.2) * 10) + 1),
361+
# np.linspace(0, 40, int((40 - 0) * 10) + 1))
362+
# plot_tuning_result(laplace_result_frame)
363+
plt.show()

training set.txt

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
1 1:0.697 2:0.460
2+
1 1:0.774 2:0.376
3+
1 1:0.634 2:0.264
4+
1 1:0.608 2:0.318
5+
1 1:0.556 2:0.215
6+
1 1:0.403 2:0.237
7+
1 1:0.481 2:0.149
8+
1 1:0.437 2:0.211
9+
0 1:0.666 2:0.091
10+
0 1:0.243 2:0.267
11+
0 1:0.245 2:0.057
12+
0 1:0.343 2:0.099
13+
0 1:0.639 2:0.161
14+
0 1:0.657 2:0.198
15+
0 1:0.360 2:0.370
16+
0 1:0.593 2:0.042
17+
0 1:0.719 2:0.103

0 commit comments

Comments
 (0)