|
| 1 | +from libsvm.svmutil import * |
| 2 | +from libsvm.svm import * |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import matplotlib as mpl |
| 7 | +import seaborn as sns |
| 8 | + |
| 9 | +mpl.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 指定默认字体 |
| 10 | +mpl.rcParams['axes.unicode_minus'] = False # 解决保存图像时负号'-'显示为方块的问题 |
| 11 | + |
| 12 | + |
| 13 | +def list_of_dict_to_list_of_list(list_of_dict: list) -> list: |
| 14 | + """ |
| 15 | + 将元素为字典的列表转换成元素为列表的列表,仅取value |
| 16 | +
|
| 17 | + Args: |
| 18 | + list_of_dict (list): 元素为字典的列表 |
| 19 | +
|
| 20 | + Returns: |
| 21 | + list: 元素为列表的列表 |
| 22 | + """ |
| 23 | + list_of_list = [[v for v in d.values()] for d in list_of_dict] |
| 24 | + return list_of_list |
| 25 | + |
| 26 | + |
| 27 | +def get_dividing_point(y: list): |
| 28 | + """ |
| 29 | + 找出不同样例的分界点 |
| 30 | +
|
| 31 | + Args: |
| 32 | + y (list): 数据标签 |
| 33 | +
|
| 34 | + Returns: |
| 35 | + int: -1表示全部相同,否则表示分界点 |
| 36 | +
|
| 37 | + """ |
| 38 | + last = y[0] |
| 39 | + for i, yi in enumerate(y): |
| 40 | + if yi != last: |
| 41 | + return i |
| 42 | + else: |
| 43 | + last = yi |
| 44 | + return -1 |
| 45 | + |
| 46 | + |
| 47 | +def scatter_training_set(x: list, y: list, axes): |
| 48 | + """ |
| 49 | + 绘制训练集散点图 |
| 50 | +
|
| 51 | + Args: |
| 52 | + x (list): 数据特征 |
| 53 | + y (list): 数据标签 |
| 54 | + axes (matplotlib.axes._base._AxesBase): 要绘图的Axes实例 |
| 55 | +
|
| 56 | + Returns: |
| 57 | + None |
| 58 | + """ |
| 59 | + x_array = np.array(list_of_dict_to_list_of_list(x)) |
| 60 | + x1 = x_array[:, 0] |
| 61 | + x2 = x_array[:, 1] |
| 62 | + dividing_point = get_dividing_point(y) |
| 63 | + axes.scatter(x1[:dividing_point], x2[:dividing_point]) |
| 64 | + axes.scatter(x1[dividing_point:], x2[dividing_point:]) |
| 65 | + |
| 66 | + |
| 67 | +def leave_one_out(x: list, y: list, param_str: str): |
| 68 | + """ |
| 69 | + 进行留一交叉验证 |
| 70 | +
|
| 71 | + Args: |
| 72 | + x (list): 数据特征 |
| 73 | + y (list): 数据标签 |
| 74 | + param_str (str): SVM参数指令 |
| 75 | +
|
| 76 | + Returns: |
| 77 | + 留一交叉验证精度 |
| 78 | + """ |
| 79 | + param_str += " -v " + str(len(y)) |
| 80 | + accuracy = svm_train(y, x, param_str) |
| 81 | + return accuracy |
| 82 | + |
| 83 | + |
| 84 | +def solve_predict(x: list, y: list, param_str: str): |
| 85 | + """ |
| 86 | + 训练模型SVM并用于分类 |
| 87 | +
|
| 88 | + Args: |
| 89 | + x (list): 数据特征 |
| 90 | + y (list): 数据标签 |
| 91 | + param_str (str): SVM参数指令 |
| 92 | +
|
| 93 | + Returns: |
| 94 | + p_label, p_acc, p_val, model |
| 95 | + """ |
| 96 | + prob = svm_problem(y, x) |
| 97 | + param = svm_parameter(param_str) |
| 98 | + model = svm_train(prob, param) |
| 99 | + p_label, p_acc, p_val = svm_predict(y, x, model) |
| 100 | + return p_label, p_acc, p_val, model |
| 101 | + |
| 102 | + |
| 103 | +def tuning_gauss(x: list, y: list, c_range: np.ndarray, g_range: np.ndarray): |
| 104 | + """ |
| 105 | + SVM高斯核调参 |
| 106 | +
|
| 107 | + Args: |
| 108 | + x (list): 数据特征 |
| 109 | + y (list): 数据标签 |
| 110 | + c_range (np.ndarray): c参数所有取值 |
| 111 | + g_range (np.ndarray): g参数所有取值 |
| 112 | +
|
| 113 | + Returns: |
| 114 | + best_result (dict): 调参的最优结果,包含精度和c、g取值 |
| 115 | + result_frame (pd.DataFrame): 调参过程中所有c、g和对应精度 |
| 116 | + """ |
| 117 | + best_result = {"Accuracy": -1, "c": -1, "g": -1} |
| 118 | + result_file_name = "best_result.txt" |
| 119 | + result_array = [] |
| 120 | + clear_file(result_file_name) |
| 121 | + for c in c_range: |
| 122 | + for g in g_range: |
| 123 | + param_str = '-q -t 2 -c ' + str(c) + ' -g ' + str(g) |
| 124 | + accuracy = leave_one_out(x, y, param_str) |
| 125 | + result_array.append([float(format(c, '.6f')), float(format(g, '.6f')), accuracy]) |
| 126 | + if accuracy >= best_result["Accuracy"]: |
| 127 | + best_result["Accuracy"] = accuracy |
| 128 | + best_result["c"] = c |
| 129 | + best_result["g"] = g |
| 130 | + append_dict_to_file(result_file_name, best_result) |
| 131 | + result_frame = pd.DataFrame(result_array, columns=['c', 'g', 'Accuracy']) |
| 132 | + return best_result, result_frame |
| 133 | + |
| 134 | + |
| 135 | +def clear_file(filename: str): |
| 136 | + """ |
| 137 | + 清空文件 |
| 138 | +
|
| 139 | + Args: |
| 140 | + filename (str): 文件名 |
| 141 | +
|
| 142 | + Returns: |
| 143 | + None |
| 144 | + """ |
| 145 | + with open(filename, mode='r+', encoding='UTF-8') as file_object: |
| 146 | + file_object.truncate() |
| 147 | + |
| 148 | + |
| 149 | +def append_dict_to_file(filename: str, content: dict): |
| 150 | + """ |
| 151 | + 将字典内容写入文件 |
| 152 | +
|
| 153 | + Args: |
| 154 | + filename (str): 文件名 |
| 155 | + content (dict): 要写入的字典 |
| 156 | +
|
| 157 | + Returns: |
| 158 | + None |
| 159 | + """ |
| 160 | + newline = '' # 要写入的内容 |
| 161 | + for k, v in content.items(): |
| 162 | + newline += str(k) + ': ' + str(v) + '\t' |
| 163 | + newline += '\n' |
| 164 | + append_to_file(filename, newline) |
| 165 | + |
| 166 | + |
| 167 | +def append_to_file(filename: str, content: str): |
| 168 | + """ |
| 169 | + 将字符串写入文件 |
| 170 | +
|
| 171 | + Args: |
| 172 | + filename (str): 文件名 |
| 173 | + content (str): 要写入的字符串 |
| 174 | +
|
| 175 | + Returns: |
| 176 | + None |
| 177 | + """ |
| 178 | + with open(filename, mode='r+', encoding='UTF-8') as file_object: |
| 179 | + file_object.seek(0, 2) |
| 180 | + file_object.writelines(content) |
| 181 | + |
| 182 | + |
| 183 | +def plot_tuning_result(result_frame: pd.DataFrame): |
| 184 | + """ |
| 185 | + 绘制调参结果的热力图 |
| 186 | +
|
| 187 | + Args: |
| 188 | + result_frame (pd.DataFrame): 调参结果 |
| 189 | +
|
| 190 | + Returns: |
| 191 | + None |
| 192 | + """ |
| 193 | + fig, ax = plt.subplots(figsize=(10, 10)) |
| 194 | + # sns.set() |
| 195 | + result_frame = result_frame.pivot("c", "g", "Accuracy") |
| 196 | + hm = sns.heatmap(result_frame, ax=ax, cmap="YlGnBu") |
| 197 | + hm.set_xlabel(hm.get_xlabel(), labelpad=0, rotation=0) |
| 198 | + plt.yticks(rotation=0) |
| 199 | + plt.savefig('parameter heat map.png', dpi=260) |
| 200 | + |
| 201 | + |
| 202 | +def calculate_laplace_kernel(x: list, y: list, gamma: float, result_file_name: str): |
| 203 | + """ |
| 204 | + 计算拉普拉斯核并写入文件 |
| 205 | +
|
| 206 | + Args: |
| 207 | + x (list): 数据特征 |
| 208 | + y (list): 数据标签 |
| 209 | + gamma (float): gamma参数 |
| 210 | + result_file_name (str): 要写入的文件名 |
| 211 | +
|
| 212 | + Returns: |
| 213 | + None |
| 214 | + """ |
| 215 | + x_array = np.array(list_of_dict_to_list_of_list(x)) |
| 216 | + clear_file(result_file_name) |
| 217 | + for i in range(len(y)): |
| 218 | + kernels = [] |
| 219 | + for j in range(len(y)): |
| 220 | + x_dif = x_array[i, :] - x_array[j, :] # 可以利用传播性质优化计算 |
| 221 | + x_distance = np.power(np.sum(np.power(x_dif, 2)), 0.5) |
| 222 | + kernel = np.exp(-gamma * x_distance) |
| 223 | + kernels.append(kernel) |
| 224 | + content = str(y[i]) + " 0:" + str(i + 1) |
| 225 | + for k in range(len(y)): |
| 226 | + content += ' ' + str(k + 1) + ':' + str(kernels[k]) |
| 227 | + content += '\n' |
| 228 | + append_to_file(result_file_name, content) |
| 229 | + |
| 230 | + |
| 231 | +def use_laplace(x, y, c: float): |
| 232 | + """ |
| 233 | + 使用拉普拉斯核进行留一交叉验证和分类 |
| 234 | +
|
| 235 | + Args: |
| 236 | + x (list): 数据特征 |
| 237 | + y (list): 数据标签 |
| 238 | + c (float): c参数 |
| 239 | +
|
| 240 | + Returns: |
| 241 | + accuracy, p_label, p_acc, p_val, model |
| 242 | + """ |
| 243 | + param_str = '-q -t 4 -c ' + str(c) |
| 244 | + accuracy = leave_one_out(x, y, param_str) |
| 245 | + p_label, p_acc, p_val, model = solve_predict(x, y, param_str) |
| 246 | + return accuracy, p_label, p_acc, p_val, model |
| 247 | + |
| 248 | + |
| 249 | +def tuning_laplace(x: list, y: list, kernel_file_name: str, c_range: np.ndarray, g_range: np.ndarray): |
| 250 | + """ |
| 251 | + SVM拉普拉斯核调参 |
| 252 | +
|
| 253 | + Args: |
| 254 | + x (list): 数据特征 |
| 255 | + y (list): 数据标签 |
| 256 | + kernel_file_name (str): 要写入拉普拉斯核的文件名 |
| 257 | + c_range (np.ndarray): c参数所有取值 |
| 258 | + g_range (np.ndarray): g参数所有取值 |
| 259 | +
|
| 260 | + Returns: |
| 261 | + best_result (dict): 调参的最优结果,包含精度和c、g取值 |
| 262 | + result_frame (pd.DataFrame): 调参过程中所有c、g和对应精度 |
| 263 | + """ |
| 264 | + best_result = {"Accuracy": -1, "c": -1, "g": -1} |
| 265 | + result_file_name = "best_laplace_result.txt" |
| 266 | + result_array = [] |
| 267 | + clear_file(result_file_name) |
| 268 | + for g in g_range: |
| 269 | + calculate_laplace_kernel(x, y, g, kernel_file_name) |
| 270 | + ly, lx = svm_read_problem(kernel_file_name) |
| 271 | + for c in c_range: |
| 272 | + param_str = '-q -t 4 -c ' + str(c) |
| 273 | + accuracy = leave_one_out(lx, ly, param_str) |
| 274 | + result_array.append([float(format(c, '.2f')), float(format(g, '.2f')), accuracy]) |
| 275 | + # result_array.append([c, g, accuracy]) |
| 276 | + if accuracy >= best_result["Accuracy"]: |
| 277 | + best_result["Accuracy"] = accuracy |
| 278 | + best_result["c"] = c |
| 279 | + best_result["g"] = g |
| 280 | + append_dict_to_file(result_file_name, best_result) |
| 281 | + result_frame = pd.DataFrame(result_array, columns=['c', 'g', 'Accuracy']) |
| 282 | + return best_result, result_frame |
| 283 | + |
| 284 | + |
| 285 | +def plot_sv(model, customed_model: bool, axes, x: np.ndarray = np.array([])): |
| 286 | + """ |
| 287 | + 在图中标注支持向量 |
| 288 | +
|
| 289 | + Args: |
| 290 | + model (): SVM模型 |
| 291 | + customed_model (bool): 是否使用自定义核(拉普拉斯核) |
| 292 | + axes (matplotlib.axes._base._AxesBase): 要绘图的Axes实例 |
| 293 | + x (np.ndarray): 使用自定义核时的原始数据 |
| 294 | +
|
| 295 | + Returns: |
| 296 | + None |
| 297 | + """ |
| 298 | + if not customed_model: |
| 299 | + sv_dict = model.get_SV() |
| 300 | + sv = np.array(list_of_dict_to_list_of_list(sv_dict)) |
| 301 | + else: |
| 302 | + if x.size == 0: |
| 303 | + raise Exception("x数据缺失") |
| 304 | + sv_indices = np.array(model.get_sv_indices(), dtype=np.int32) - 1 |
| 305 | + sv = x[sv_indices] |
| 306 | + x1 = sv[:, 0] |
| 307 | + x2 = sv[:, 1] |
| 308 | + plt.scatter(x1, x2, marker='o', facecolor='none', edgecolors='black', s=200) |
| 309 | + |
| 310 | + |
| 311 | +def plot_data_and_sv(x, y, model, customed_model: bool, title: str, fig_file_name: str = "data and SV"): |
| 312 | + """ |
| 313 | + 绘制原始数据并标注支持向量 |
| 314 | +
|
| 315 | + Args: |
| 316 | + x (list): 数据特征 |
| 317 | + y (list): 数据标签 |
| 318 | + model (): SVM模型 |
| 319 | + customed_model (bool): 是否使用自定义核(拉普拉斯核) |
| 320 | + title (str): 绘图标题 |
| 321 | + fig_file_name (str): 保存图片的文件名 |
| 322 | +
|
| 323 | + Returns: |
| 324 | + None |
| 325 | + """ |
| 326 | + fig, axes = plt.subplots(1, 1) |
| 327 | + scatter_training_set(x, y, axes) |
| 328 | + x = np.array(list_of_dict_to_list_of_list(x)) |
| 329 | + plot_sv(model=model, customed_model=customed_model, axes=axes, x=x) |
| 330 | + props = {'xlabel': '密度', 'ylabel': '含糖率', 'title': title} |
| 331 | + axes.set(**props) |
| 332 | + axes.set_ylabel(axes.get_ylabel(), labelpad=20, rotation=0) |
| 333 | + plt.savefig(fig_file_name + ".png", dpi=260) |
| 334 | + |
| 335 | + |
| 336 | +if __name__ == '__main__': |
| 337 | + param_str = '-q -t 2 -c 1.4 -g 110' |
| 338 | + y, x = svm_read_problem('training set.txt') |
| 339 | + '''高斯''' |
| 340 | + accuracy = leave_one_out(x, y, param_str) |
| 341 | + p_label, p_acc, p_val, gauss_model = solve_predict(x, y, param_str) |
| 342 | + plot_data_and_sv(x=x, y=y, model=gauss_model, customed_model=False, title="SVM-高斯核, C=1.4, γ=110", |
| 343 | + fig_file_name="gauss data and SV") |
| 344 | + # best_gauss_result, gauss_result_frame = tuning_gauss(x, y, np.linspace(1, 10, int((10 - 1) * 1) + 1), |
| 345 | + # np.linspace(0, 128, int((128 - 0) * 1) + 1)) |
| 346 | + # best_gauss_result, gauss_result_frame = tuning_gauss(x, y, np.logspace(-4, 4, num=513, base=10), |
| 347 | + # np.logspace(-4, 4, num=513, base=10)) |
| 348 | + # plot_tuning_result(gauss_result_frame) |
| 349 | + '''拉普拉斯''' |
| 350 | + calculate_laplace_kernel(x, y, gamma=9, result_file_name="laplace_kernel.txt") |
| 351 | + ly, lx = svm_read_problem("laplace_kernel.txt") |
| 352 | + l_accuracy, l_p_label, l_p_acc, l_p_val, laplace_model = use_laplace(lx, ly, c=0.8) |
| 353 | + plot_data_and_sv(x=x, y=y, model=laplace_model, customed_model=True, |
| 354 | + title="SVM-拉普拉斯核, C=0.8, γ=9", |
| 355 | + fig_file_name="laplace data and SV") |
| 356 | + # best_laplace_result, laplace_result_frame = tuning_laplace(x, y, "laplace_kernel.txt", |
| 357 | + # np.logspace(-4, 4, num=129, base=10), |
| 358 | + # np.logspace(-4, 4, num=129, base=10)) |
| 359 | + # best_laplace_result, laplace_result_frame = tuning_laplace(x, y, "laplace_kernel.txt", |
| 360 | + # np.linspace(0.2, 1.4, int((1.4 - 0.2) * 10) + 1), |
| 361 | + # np.linspace(0, 40, int((40 - 0) * 10) + 1)) |
| 362 | + # plot_tuning_result(laplace_result_frame) |
| 363 | + plt.show() |
0 commit comments