K-Means-1.png
K-Means-2.png
K-Means-3.png
K-Means-4.png
程序
# coding: utf-8
# # Pattern Recognition: Assignment 3 (第三次模式識別作業)
# In[1]:
# Enable inline figure rendering when running under IPython/Jupyter.
# `run_line_magic` replaces the deprecated `magic()` shell method. The
# NameError guard lets this exported script also run under plain Python,
# where `get_ipython` is not defined.
try:
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    pass
# In[2]:
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import numpy as np
# In[ ]:
# Number of clusters to fit.
K = 3
# Load the iris dataset once, then unpack the feature matrix and the targets.
iris = load_iris()
X, Y = iris.data, iris.target
# # Make every run produce the same random numbers
# In[ ]:
# Seed first: the shuffle below draws from the same global generator, so
# seeding only after it (as before) left the sample order different on each run.
np.random.seed(980406)
# # Randomly shuffle the data
# In[13]:
# One shared permutation keeps each row of X paired with its entry in Y.
shuffle_para = np.arange(Y.shape[0])
np.random.shuffle(shuffle_para)
X, Y = X[shuffle_para], Y[shuffle_para]
# # Group sample indices by true class
# In[ ]:
# cla[i] is the np.where index tuple of the samples whose target equals i.
cla = [np.where(Y == class_id) for class_id in range(K)]
# # Initial centres
# In[14]:
# Draw K distinct row indices. The previous randint call hard-coded 3 instead
# of K and sampled with replacement, so the same row could be chosen twice,
# giving two identical centres and an empty cluster on the first assignment.
initial_point = X[np.random.choice(X.shape[0], size=K, replace=False)]
initial_point  # shown as the cell output in a notebook; no effect in a script
# In[15]:
# Copy so the working centres never alias initial_point.
mean_point = initial_point.copy()
# In[16]:
print(X.shape)
# # Start iterating
# In[17]:
# Lloyd-style k-means loop: assign every sample to its nearest centre, move
# each centre to the mean of its samples, and stop when no centre moves by
# TOL or more, or after MAX_ITER recorded iterations.
TOL = 0.001     # movement threshold (the old comment said 0.01; the code used 0.001)
MAX_ITER = 20   # upper bound on the number of recorded iterations

accu = []  # fraction of samples that fall in their class's dominant cluster
n = 0      # number of recorded iterations; always equals len(accu)
while True:
    # Euclidean distance from every sample to each centre, one row per centre.
    distances = np.array([np.linalg.norm(X - p, axis=1) for p in mean_point])
    # Index of the nearest centre for every sample.
    y = np.argmin(distances, axis=0)

    # Keep the previous centres to measure how far they move.
    last_point = np.asarray(mean_point, dtype=float)

    # Compute the new centres.
    mean_point = last_point.copy()
    for i in range(K):
        members = X[y == i]
        # The mean of an empty cluster is NaN, and np.argmin then picks the
        # NaN centre for every sample; keep the previous position instead.
        if members.shape[0] > 0:
            mean_point[i] = members.mean(axis=0)

    # Per-centre displacement; converged when every centre moved less than TOL.
    J = np.linalg.norm(last_point - mean_point, axis=1)
    if np.all(J < TOL):
        break
    if n == MAX_ITER:
        # NOTE(review): message text is kept byte-for-byte; it appears to
        # contain transliteration residue - confirm the intended wording.
        print('到達(dá)最大迭代次數(shù)')
        break

    # For each true class, count the samples in its most common cluster.
    # The iteration that triggers a break above is not recorded.
    corr = sum(np.bincount(y[c]).max() for c in cla)
    accu.append(corr / Y.shape[0])
    print(accu[-1])
    n += 1
# # Plot
# In[18]:
# Accuracy after each recorded iteration (accu holds n values, indexed 0..n-1).
plt.ylim([0.6, 1])
plt.xticks(list(range(n)), rotation=20)
plt.xlabel('Iterations')  # was misspelled 'Interations'
plt.ylabel('Accuracy')
plt.plot(np.arange(n), accu)
# In[19]:
# Notebook cell output: shape of the centre array (a no-op when run as a script).
mean_point.shape
# In[20]:
# Axis captions for the two feature pairs: columns (0, 1) and columns (2, 3).
label = (('Sepal length', 'Sepal width'), ('Petal length', 'Petal width'))


def scat(i):
    """Scatter feature pair *i* of X coloured by cluster, plus the centres.

    Pair ``i`` puts column ``2*i`` on the x-axis and column ``2*i + 1`` on
    the y-axis. Samples are drawn as '+' coloured by ``y``; the rows of
    ``mean_point`` are drawn as 'o' coloured by their index.
    """
    x_col = 2 * i
    y_col = x_col + 1
    plt.scatter(X[:, x_col], X[:, y_col], c=y, marker='+')
    plt.scatter(mean_point[:, x_col], mean_point[:, y_col], c=np.arange(K), marker='o')
    x_caption, y_caption = label[i]
    plt.xlabel(x_caption)
    plt.ylabel(y_caption)


i = 0
scat(i)
# In[21]:
scat(1)