-
Notifications
You must be signed in to change notification settings - Fork 69
/
Copy path聚类 kmeans.py
67 lines (38 loc) · 1.16 KB
/
聚类 kmeans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# coding: utf-8
# In[1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# In[2]:
# Load the raw 2-D points: headerless file, columns named 'x' and 'y'
# (pd.read_table defaults to tab as the separator).
data = pd.read_table('kmeans_data.txt', header=None, names=['x', 'y'])
x, y = data['x'], data['y']

# Quick look at the unclustered points.
plt.scatter(x, y)
# In[3]:
# Bare expression: a notebook displays it, a plain script discards it.
data.head()
# In[5]:
def distance(data, centers):
    """Return the Euclidean distance from every point to every center.

    Parameters
    ----------
    data : pandas.DataFrame or array-like, shape (n_points, n_features)
        One point per row (in this script, the ``x``/``y`` columns).
    centers : array-like, shape (n_centers, n_features)
        One cluster center per row.

    Returns
    -------
    numpy.ndarray of float, shape (n_points, n_centers)
        ``dist[i, j]`` is the distance between point ``i`` and center ``j``.
    """
    points = np.asarray(data, dtype=float)
    cents = np.asarray(centers, dtype=float)
    # Broadcast (n, 1, d) - (1, k, d) -> (n, k, d): every point/center pair is
    # computed in one vectorized step instead of n*k Python-level iloc calls.
    diff = points[:, np.newaxis, :] - cents[np.newaxis, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=-1))
def near_center(data, centers):
    """Label each point with the index of its closest center.

    Returns a 1-D integer array with one entry per row of ``data``; when a
    point is equidistant from several centers, the lowest index wins
    (``argmin`` returns the first occurrence of the minimum).
    """
    return distance(data, centers).argmin(axis=1)
def kmeans(data, k, max_iter=10, *, verbose=True):
    """Cluster the rows of ``data`` into ``k`` groups with Lloyd's k-means.

    Parameters
    ----------
    data : pandas.DataFrame, shape (n_points, n_features)
        Points to cluster, one per row.
    k : int
        Number of clusters; must be at least 1.
    max_iter : int, default 10
        Upper bound on assignment/update rounds. The loop stops earlier once
        the point-to-cluster assignment no longer changes.
    verbose : bool, default True
        Print the randomly chosen initial centers.

    Returns
    -------
    centers : numpy.ndarray, shape (k, n_features)
        Final cluster centers.
    near_cen : numpy.ndarray, shape (n_points,)
        Index of the center each point was last assigned to.

    Raises
    ------
    ValueError
        If ``k`` or ``max_iter`` is less than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if max_iter < 1:
        raise ValueError(f"max_iter must be >= 1, got {max_iter}")

    n_features = data.shape[1]
    # step 1: initialise centers by sampling a 0.1-spaced grid on [-5, 5).
    # NOTE(review): this range is hard-coded and ignores the data's extent;
    # seeding from k random data points would be more robust.
    centers = np.random.choice(np.arange(-5, 5, 0.1), (k, n_features))
    if verbose:
        print(centers)

    near_cen = None
    for _ in range(max_iter):
        # step 2: assign every point to its nearest center
        prev = near_cen
        near_cen = near_center(data, centers)
        if prev is not None and np.array_equal(prev, near_cen):
            # Assignments are stable, so the update below would reproduce the
            # current centers exactly: the algorithm has converged.
            break
        # step 3: move each center to the mean of the points assigned to it
        for ci in range(k):
            members = data[near_cen == ci]
            if len(members) == 0:
                # An empty cluster's mean is NaN; np.argmin returns the index of
                # the first NaN, so a NaN center would absorb every point on the
                # next pass. Keep the previous center instead.
                continue
            centers[ci] = np.asarray(members.mean(axis=0))
    return centers, near_cen
# Cluster the points into 4 groups, then plot each point colored by its cluster
# with the final centers drawn as large red stars.
plt.figure()  # fresh figure so this plot does not draw over the raw-data scatter
centers, near_cen = kmeans(data, 4)
plt.scatter(x, y, c=near_cen)
plt.scatter(centers[:, 0], centers[:, 1], marker='*', s=500, c='r')
# Needed when this exported notebook runs as a plain script; harmless in Jupyter.
plt.show()