所有代碼請(qǐng)移步GitHub——kNNbyPython
很多人在第一次聽到機(jī)器學(xué)習(xí)的時(shí)候都不知所措椿疗,無從下手漏峰。起初我也是這樣的,各種看別人的博客届榄,吳恩達(dá)的課程也死磕浅乔,但效果不佳。后來發(fā)現(xiàn)一個(gè)神奇的網(wǎng)站k-近鄰算法實(shí)現(xiàn)手寫數(shù)字識(shí)別系統(tǒng)--《機(jī)器學(xué)習(xí)實(shí)戰(zhàn) 》,跟著過了一遍之后感覺還不錯(cuò)铝条,也順便買了《機(jī)器學(xué)習(xí)實(shí)戰(zhàn)》這本書靖苇,接著就正式入坑機(jī)器學(xué)習(xí)。
KNN算法應(yīng)該是機(jī)器學(xué)習(xí)中最簡(jiǎn)單的算法之一班缰,作為機(jī)器學(xué)習(xí)的入門是個(gè)非常不錯(cuò)的選擇贤壁。
KNN算法思路
KNN(K-Nearest Neighbor)算法的理論基礎(chǔ)網(wǎng)上一查一大把,我這里就不贅述埠忘,這里我講自己的理解脾拆。
KNN算法屬于機(jī)器學(xué)習(xí)中的監(jiān)督算法,主要用于分類莹妒。
首先名船,在二維坐標(biāo)軸中,有四個(gè)點(diǎn)旨怠,分別是a1(1,1)渠驼,a2(1,2),b1(3,3)鉴腻,b2(3,4)迷扇。其中百揭,a1,a2為A類蜓席,b1信峻,b2為B類
這里用matplotlib實(shí)現(xiàn)一下這四個(gè)點(diǎn),更加直觀點(diǎn)瓮床。
實(shí)現(xiàn)這張圖的代碼盹舞,感興趣的可以看一下。
# -*- coding: utf-8 -*-
# @Date : 2017-04-28 16:52:44
# @Author : Alan Lau (rlalan@outlook.com)
# @Language : Python3.5
from matplotlib import pyplot as plt
import numpy as np
# 定義四個(gè)點(diǎn)的坐標(biāo)
a1 = np.array([1, 1])
a2 = np.array([1, 2])
b1 = np.array([3, 3])
b2 = np.array([3, 4])
# 四個(gè)點(diǎn)坐標(biāo)分別賦值給X,Y
X1, Y1 = a1
X2, Y2 = a2
X3, Y3 = b1
X4, Y4 = b2
plt.title('show data')
plt.scatter(X1, Y1, color="blue", label="a1")
plt.scatter(X2, Y2, color="blue", label="a2")
plt.scatter(X3, Y3, color="red", label="b1")
plt.scatter(X4, Y4, color="red", label="b2")
plt.legend(loc='upper left')
plt.annotate(r'a1(1,1)', xy=(X1, Y1), xycoords='data', xytext=(+10, +30), textcoords='offset points', fontsize=16, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
plt.annotate(r'a2(1,2)', xy=(X2, Y2), xycoords='data', xytext=(+10, +30), textcoords='offset points', fontsize=16, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
plt.annotate(r'b1(3,3)', xy=(X3, Y3), xycoords='data', xytext=(+10, +30), textcoords='offset points', fontsize=16, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
plt.annotate(r'b2(3,4)', xy=(X4, Y4), xycoords='data', xytext=(+10, +30), textcoords='offset points', fontsize=16, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
plt.show()
然后隘庄,問題出現(xiàn)了踢步,現(xiàn)在突然冒出個(gè)c(2,1)
我現(xiàn)在想知道的是,c(2,1)這個(gè)點(diǎn)丑掺,在AB兩個(gè)類中是屬于A類获印,還是數(shù)據(jù)B類。
怎么做街州?
1.計(jì)算c和其余所有點(diǎn)的距離兼丰。
2.將計(jì)算出的距離集合進(jìn)行升序排序(即距離最短的排列在前面)。
3.獲得距離集合降序排序的前k個(gè)距離唆缴。
4.統(tǒng)計(jì)出在前k個(gè)距離中鳍征,出現(xiàn)頻次最多的類別。
然后我們把已經(jīng)知道分類的四個(gè)點(diǎn)a1面徽,a2艳丛,b1,b3稱為訓(xùn)練數(shù)據(jù)趟紊,把未知類別的c稱為測(cè)試數(shù)據(jù)氮双。
這里的k取值一般為小于等于20的常數(shù),具體的取值霎匈,看不同的樣本戴差。同樣,如何確定k的值铛嘱,獲得最佳的計(jì)算結(jié)果暖释,也是kNN算法的一個(gè)難點(diǎn)。
現(xiàn)在跟著上面的例子走一遍弄痹,這里k取3(訓(xùn)練數(shù)據(jù)才4個(gè)饭入,最大只能取3)。
1.計(jì)算c和其余所有點(diǎn)的距離
計(jì)算距離的方法我這里使用歐式距離肛真,具體python代碼可以參考我的另一篇博文 python實(shí)現(xiàn)各種距離,同樣爽航,在眾多計(jì)算距離的方法中蚓让,確定使用kNN算法時(shí)用哪個(gè)距離算法也是該算法的難點(diǎn)之一乾忱。
此圖代碼:
# 如想運(yùn)行,請(qǐng)拼接上一段代碼
import math
def Euclidean(vec1, vec2):
npvec1, npvec2 = np.array(vec1), np.array(vec2)
return math.sqrt(((npvec1-npvec2)**2).sum())
# 顯示距離
def show_distance(exit_point, c):
line_point = np.array([exit_point, c])
x = (line_point.T)[0]
y = (line_point.T)[1]
o_dis = round(Euclidean(exit_point, c), 2) # 計(jì)算距離
mi_x, mi_y = (exit_point+c)/2 # 計(jì)算中點(diǎn)位置历极,來顯示“distance=xx”這個(gè)標(biāo)簽
plt.annotate('distance=%s' % str(o_dis), xy=(mi_x, mi_y), xycoords='data', xytext=(+10, 0), textcoords='offset points', fontsize=10, arrowprops=dict(arrowstyle="-", connectionstyle="arc3,rad=.2"))
return plt.plot(x, y, linestyle="--", color='black', lw=1)
show_distance(a1, c)
show_distance(a2, c)
show_distance(b1, c)
show_distance(b2, c)
plt.show()
代碼的注釋中怎么引用自己寫的包和.py窄瘟,看一參考我的博客python中import自己寫的.py
歐式距離計(jì)算方法
def Euclidean(vec1, vec2):
npvec1, npvec2 = np.array(vec1), np.array(vec2)
return math.sqrt(((npvec1-npvec2)**2).sum())
2.將計(jì)算出的距離集合進(jìn)行升序排序(即距離最短的排列在前面)
|升序序號(hào)|點(diǎn)標(biāo)簽|標(biāo)簽所屬類別|點(diǎn)坐標(biāo)|與c點(diǎn)距離|
| ------------- |:-------------: |:-------------:| -----:|
| 1 | a1 |A | (1,1) |1.0|
| 2 | a2 |A | (1,2) |1.41|
| 3 | b1 |B | (3,3) |2.24|
| 4 | b2 |B | (3,4) |3.16|
3.獲得距離集合升序排序的前k個(gè)距離
k取值為3,因此保留升序排序前三的距離
|升序序號(hào)|點(diǎn)標(biāo)簽|標(biāo)簽所屬類別|點(diǎn)坐標(biāo)|與c點(diǎn)距離|
| ------------- |:-------------: |:-------------:| -----:|
| 1 | a1 |A | (1,1) |1.0|
| 2 | a2 |A | (1,2) |1.41|
| 3 | b1 |B | (3,3) |2.24|
4.統(tǒng)計(jì)出在前k個(gè)距離中趟卸,出現(xiàn)頻次最多的類別
肉眼直接看出蹄葱,頻次最多的類別是A。因此锄列,c點(diǎn)屬于A類图云。
5.總結(jié)
在上面這個(gè)例子中我用了四個(gè)點(diǎn),即四個(gè)向量邻邮,同時(shí)為了方便理解竣况,我使用的是二維坐標(biāo)平面。但是在真正的kNN實(shí)戰(zhàn)中筒严,則涉及的訓(xùn)練數(shù)量是非常龐大的丹泉,同樣,也不會(huì)單單局限于二維鸭蛙,而是多維向量摹恨。但是,其實(shí)現(xiàn)方法都是相同的娶视。當(dāng)然睬塌,我上面舉的例子是不能用來實(shí)際使用的,因?yàn)橛?xùn)練數(shù)據(jù)太少歇万。
上述例子的所有代碼揩晴,感興趣可以自己過一遍:
# -*- coding: utf-8 -*-
# @Date : 2017-04-28 16:52:44
# @Author : Alan Lau (rlalan@outlook.com)
# @Language : Python3.5
from matplotlib import pyplot as plt
import numpy as np
import math
# 定義四個(gè)點(diǎn)的坐標(biāo)
a1 = np.array([1, 1])
a2 = np.array([1, 2])
b1 = np.array([3, 3])
b2 = np.array([3, 4])
c = np.array([2, 1])
# 四個(gè)點(diǎn)坐標(biāo)分別賦值給X,Y
X1, Y1 = a1
X2, Y2 = a2
X3, Y3 = b1
X4, Y4 = b2
X5, Y5 = c
plt.title('show data')
plt.scatter(X1, Y1, color="blue", label="a1")
plt.scatter(X2, Y2, color="blue", label="a2")
plt.scatter(X3, Y3, color="red", label="b1")
plt.scatter(X4, Y4, color="red", label="b2")
plt.scatter(X5, Y5, color="yellow", label="c")
plt.annotate(r'a1(1,1)', xy=(X1, Y1), xycoords='data', xytext=(+10, +20), textcoords='offset points',fontsize=12, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
plt.annotate(r'a2(1,2)', xy=(X2, Y2), xycoords='data', xytext=(+10, +20), textcoords='offset points',fontsize=12, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
plt.annotate(r'b1(3,3)', xy=(X3, Y3), xycoords='data', xytext=(+10, +20), textcoords='offset points',fontsize=12, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
plt.annotate(r'b2(3,4)', xy=(X4, Y4), xycoords='data', xytext=(+10, +20), textcoords='offset points', fontsize=12, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
plt.annotate(r'c(2,1)', xy=(X5, Y5), xycoords='data', xytext=(+30, 0), textcoords='offset points', fontsize=12, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
def Euclidean(vec1, vec2):
npvec1, npvec2 = np.array(vec1), np.array(vec2)
return math.sqrt(((npvec1-npvec2)**2).sum())
# 顯示距離
def show_distance(exit_point, c):
line_point = np.array([exit_point, c])
x = (line_point.T)[0]
y = (line_point.T)[1]
o_dis = round(Euclidean(exit_point, c), 2) # 計(jì)算距離
mi_x, mi_y = (exit_point+c)/2 # 計(jì)算中點(diǎn)位置,來顯示“distance=xx”這個(gè)標(biāo)簽
plt.annotate('distance=%s' % str(o_dis), xy=(mi_x, mi_y), xycoords='data', xytext=(+10, 0), textcoords='offset points', fontsize=10, arrowprops=dict(arrowstyle="-", connectionstyle="arc3,rad=.2"))
return plt.plot(x, y, linestyle="--", color='black', lw=1)
show_distance(a1, c)
show_distance(a2, c)
show_distance(b1, c)
show_distance(b2, c)
plt.show()
實(shí)戰(zhàn)
實(shí)戰(zhàn)這里使用k-近鄰算法實(shí)現(xiàn)手寫數(shù)字識(shí)別系統(tǒng)--《機(jī)器學(xué)習(xí)實(shí)戰(zhàn) 》中的數(shù)據(jù)進(jìn)行贪磺,但是本人的代碼與網(wǎng)站提供的代碼有差異硫兰。
準(zhǔn)備數(shù)據(jù)
在使用數(shù)據(jù)之前,我先對(duì)網(wǎng)站提供的數(shù)據(jù)進(jìn)行預(yù)處理寒锚,方便使用numpy讀取劫映。
網(wǎng)站提供數(shù)據(jù):
處理后的數(shù)據(jù):
實(shí)際上就是在數(shù)字之間加上空格,方便numpy識(shí)別并分割數(shù)據(jù)刹前。
數(shù)據(jù)預(yù)處理的代碼:
# -*- coding: utf-8 -*-
# @Date : 2017-04-03 16:04:19
# @Author : Alan Lau (rlalan@outlook.com)
def fwalker(path):
fileArray = []
for root, dirs, files in os.walk(path):
for fn in files:
eachpath = str(root+'\\'+fn)
fileArray.append(eachpath)
return fileArray
def writetxt(path, content, code):
with open(path, 'a', encoding=code)as f:
f.write(content)
return path+' is ok!'
def readtxt(path, encoding):
with open(path, 'r', encoding=encoding) as f:
lines = f.readlines()
return lines
def buildfile(echkeyfile):
if os.path.exists(echkeyfile):
#創(chuàng)建前先判斷是否存在文件夾泳赋,if存在則刪除
shutil.rmtree(echkeyfile)
os.makedirs(echkeyfile)
else:
os.makedirs(echkeyfile)#else則創(chuàng)建語句
return echkeyfile
def change_data(files, inputpath):
trainpath = buildfile(inputpath+'\\'+'trainingDigits')
testpath = buildfile(inputpath+'\\'+'testDigits')
for file in files:
ech_name = (file.split('\\'))[-2:]
new_path = inputpath+'\\'+'\\'.join(ech_name)
ech_content = readtxt(file, 'utf8')
new_content = []
for ech_line in ech_content:
line_ary = list(ech_line.replace('\n', '').replace('\r', ''))
new_content.append(' '.join(line_ary))
print(writetxt(new_path, '\n'.join(new_content), 'utf8'))
def main():
datapath =r'..\lab3_0930\digits'
inputpath = buildfile(r'..\lab3_0930\input_digits')
files = fwalker(datapath)
change_data(files, inputpath)
if __name__ == '__main__':
main()
實(shí)現(xiàn)代碼
教程網(wǎng)站中利用list下標(biāo)索引將標(biāo)簽和向量進(jìn)行對(duì)應(yīng),而我使用將每一個(gè)標(biāo)簽和向量放到分別一個(gè)list中喇喉,再將這些list放到一個(gè)list內(nèi)祖今,類似于實(shí)現(xiàn)字典。如[[label1,vector1],[label2,vector2],[label3,vector3],...]
。
# -*- coding: utf-8 -*-
# @Date : 2017-04-03 15:47:04
# @Author : Alan Lau (rlalan@outlook.com)
import os
import math
import collections
import numpy as np
def Euclidean(vec1, vec2):
npvec1, npvec2 = np.array(vec1), np.array(vec2)
return math.sqrt(((npvec1-npvec2)**2).sum())
def fwalker(path):
fileArray = []
for root, dirs, files in os.walk(path):
for fn in files:
eachpath = str(root+'\\'+fn)
fileArray.append(eachpath)
return fileArray
def orderdic(dic, reverse):
ordered_list = sorted(
dic.items(), key=lambda item: item[1], reverse=reverse)
return ordered_list
def get_data(data_path):
label_vec = []
files = fwalker(data_path)
for file in files:
ech_label_vec = []
ech_label = int((file.split('\\'))[-1][0])# 獲取每個(gè)向量的標(biāo)簽
ech_vec = ((np.loadtxt(file)).ravel())# 獲取每個(gè)文件的向量
ech_label_vec.append(ech_label) # 將一個(gè)文件夾的標(biāo)簽和向量放到同一個(gè)list內(nèi)
ech_label_vec.append(ech_vec) # 將一個(gè)文件夾的標(biāo)簽和向量放到同一個(gè)list內(nèi)千诬,目的是將標(biāo)簽和向量對(duì)應(yīng)起來耍目,類似于字典,這里不直接用字典因?yàn)樽值涞逆I(key)不可重復(fù)徐绑。
label_vec.append(ech_label_vec) # 再將所有的標(biāo)簽和向量存入一個(gè)list內(nèi)邪驮,構(gòu)成二維數(shù)組
return label_vec
def find_label(train_vec_list, vec, k):
get_label_list = []
for ech_trainlabel_vec in train_vec_list:
ech_label_distance = []
train_label, train_vec = ech_trainlabel_vec[0], ech_trainlabel_vec[1]
vec_distance = Euclidean(train_vec, vec)# 計(jì)算距離
ech_label_distance.append(train_label)
ech_label_distance.append(vec_distance)# 將距離和標(biāo)簽對(duì)應(yīng)存入list
get_label_list.append(ech_label_distance)
result_k = np.array(get_label_list)
order_distance = (result_k.T)[1].argsort()# 對(duì)距離進(jìn)行排序
order = np.array((result_k[order_distance].T)[0])
top_k = np.array(order[:k], dtype=int) # 獲取前k距離和標(biāo)簽
find_label = orderdic(collections.Counter(top_k), True)[0][0]# 統(tǒng)計(jì)在前k排名中標(biāo)簽出現(xiàn)頻次
return find_label
def classify(train_vec_list, test_vec_list, k):
error_counter = 0 #計(jì)數(shù)器,計(jì)算錯(cuò)誤率
for ech_label_vec in test_vec_list:
label, vec = ech_label_vec[0], ech_label_vec[1]
get_label = find_label(train_vec_list, vec, k) # 獲得學(xué)習(xí)得到的標(biāo)簽
print('Original label is:'+str(label) +
', kNN label is:'+str(get_label))
if str(label) != str(get_label):
error_counter += 1
else:
continue
true_probability = str(round((1-error_counter/len(test_vec_list))*100, 2))+'%'
print('Correct probability:'+true_probability)
def main():
k = 3
train_data_path =r'..\lab3_0930\input_digits\trainingDigits'
test_data_path =r'..\lab3_0930\input_digits\testDigits'
train_vec_list = get_data(train_data_path)
test_vec_list = get_data(test_data_path)
classify(train_vec_list, test_vec_list, k)
if __name__ == '__main__':
main()