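k-nearest neighbor algorithm
Given a training data set and a new input instance, find the k training instances closest to the input and assign the input a class based on the classes of those k instances (usually by majority vote).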
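The definition above can be sketched as a brute-force classifier. This is only a minimal illustration: the function names `euclidean` and `knn_classify` and the toy data are assumptions made for this example, and Euclidean distance is assumed as the metric.

```python
from collections import Counter

def euclidean(a, b):
    """Euclidean distance between two equal-length points."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5

def knn_classify(train, query, k):
    """Classify query by majority vote among its k nearest training points.

    train is a list of (point, label) pairs.
    """
    # Sort the training pairs by distance to the query and keep the k closest.
    neighbors = sorted(train, key=lambda pair: euclidean(pair[0], query))[:k]
    # Count the labels among the neighbors and return the most common one.
    votes = Counter(label for _, label in neighbors)
    return votes.most_common(1)[0][0]

# Hypothetical toy data: two well-separated clusters labelled "a" and "b".
train = [((1, 1), "a"), ((1, 2), "a"), ((2, 1), "a"),
         ((6, 6), "b"), ((7, 6), "b"), ((6, 7), "b")]
print(knn_classify(train, (2, 2), k=3))  # a
print(knn_classify(train, (6, 5), k=3))  # b
```

Sorting the whole training set costs O(n log n) per query, which is fine for a demonstration but is exactly the linear-scan cost that a kd tree tries to avoid.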
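k-nearest neighbor model
The model has three elements: the distance metric, the choice of k, and the classification decision rule.
Model
Once the three elements are fixed, the class of any instance (training or input) is determined. This is equivalent to partitioning the feature space into subspaces, each with a fixed class.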
Distance metric
For the n-dimensional real vector space R^n, the L_p distance (also known as the Minkowski distance) is commonly used.
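The L_p distance between two points x_i and x_j is defined as:
$$L_p(x_i, x_j) = \left( \sum_{l=1}^{n} \left| x_i^{(l)} - x_j^{(l)} \right|^p \right)^{1/p}, \quad p \ge 1$$
When p = 2, it is called the Euclidean distance:
$$L_2(x_i, x_j) = \left( \sum_{l=1}^{n} \left| x_i^{(l)} - x_j^{(l)} \right|^2 \right)^{1/2}$$
When p = 1, it is called the Manhattan distance:
$$L_1(x_i, x_j) = \sum_{l=1}^{n} \left| x_i^{(l)} - x_j^{(l)} \right|$$
When p = ∞, it is the maximum of the per-coordinate distances, that is:
$$L_\infty(x_i, x_j) = \max_{l} \left| x_i^{(l)} - x_j^{(l)} \right|$$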
[Figure: graphical illustration of the L_p distances for different values of p]
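To make the three special cases concrete, here is a small sketch of the L_p distance in Python. The function name `lp_distance` and the sample points are chosen for this illustration only.

```python
def lp_distance(x, y, p):
    """L_p (Minkowski) distance between equal-length vectors x and y, p >= 1."""
    diffs = [abs(a - b) for a, b in zip(x, y)]
    if p == float("inf"):
        # Limit case: the largest per-coordinate difference.
        return max(diffs)
    return sum(d ** p for d in diffs) ** (1.0 / p)

x, y = [1, 1], [4, 5]                    # per-coordinate differences are 3 and 4
print(lp_distance(x, y, 1))              # Manhattan: 3 + 4 = 7.0
print(lp_distance(x, y, 2))              # Euclidean: sqrt(9 + 16) = 5.0
print(lp_distance(x, y, float("inf")))   # maximum coordinate difference: 4
```

The same pair of points is 7, 5, or 4 apart depending on p, so the set of "nearest" neighbors can change with the metric.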
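Choice of k
When k is small, the prediction is easily affected by noise and the model tends to overfit.
When k is large, training instances far from the input also influence the prediction, which makes wrong predictions more likely.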
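Both effects can be seen on a toy example. The sketch below is an illustration under assumed data: a small cluster of class "a" containing one noisy point labelled "b", plus a larger "b" cluster far away. The helper `knn_classify` is a hypothetical brute-force classifier written for this example.

```python
from collections import Counter

def knn_classify(train, query, k):
    """Majority vote among the k training points nearest to query (Euclidean)."""
    def sq_dist(p):
        return sum((a - b) ** 2 for a, b in zip(p, query))
    neighbors = sorted(train, key=lambda pair: sq_dist(pair[0]))[:k]
    return Counter(label for _, label in neighbors).most_common(1)[0][0]

# Hypothetical toy data.
train = [((1, 1), "a"), ((1, 2), "a"), ((2, 1), "a"),
         ((1.5, 1.5), "b"),                                # noisy point
         ((6, 6), "b"), ((6, 7), "b"), ((7, 6), "b"),
         ((7, 7), "b"), ((6.5, 6.5), "b")]
query = (1.6, 1.6)  # lies inside the "a" cluster

for k in (1, 3, 9):
    print(k, knn_classify(train, query, k))
```

With k = 1 the single nearest point is the noisy one, so the prediction is "b". With k = 3 the two genuine "a" neighbors outvote it. With k = 9 every training point votes and the distant "b" cluster wins again.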
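Classification decision rule
Measured with the 0-1 loss function, the misclassification rate when the neighborhood of x is assigned to class c_j is:
$$\frac{1}{k} \sum_{x_i \in N_k(x)} I(y_i \ne c_j) = 1 - \frac{1}{k} \sum_{x_i \in N_k(x)} I(y_i = c_j)$$
Here N_k(x) is the set of the k nearest neighbors of x. To minimize the left-hand side, the sum $\sum_{x_i \in N_k(x)} I(y_i = c_j)$ on the right-hand side must be as large as possible, so majority voting is equivalent to empirical risk minimization.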
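The equivalence can be checked numerically on a small example. The sketch below assumes a hypothetical list of neighbor labels; `empirical_error` is a name introduced for this illustration.

```python
from collections import Counter

def empirical_error(neighbor_labels, c):
    """0-1 loss averaged over the neighbors when all are assigned class c."""
    return sum(1 for y in neighbor_labels if y != c) / len(neighbor_labels)

# Hypothetical labels of the k = 5 nearest neighbors of some input.
neighbor_labels = ["a", "b", "a", "c", "a"]

# Misclassification rate for every candidate class.
errors = {c: empirical_error(neighbor_labels, c) for c in sorted(set(neighbor_labels))}
best_by_error = min(errors, key=errors.get)

# Majority vote over the same labels.
majority = Counter(neighbor_labels).most_common(1)[0][0]

print(errors)
print(best_by_error, majority)
```

The class with the smallest error is the one that appears most often, which is what the identity above states.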
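Implementing k-nearest neighbors: the kd tree
The core of the algorithm is finding the k nearest neighbors quickly. The naive approach is a linear scan over all training instances, which is too slow for large training sets. The method introduced here is the kd tree.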
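Constructing the kd tree
Let S be a subset of the data set T. Start with S = T, the current node node = root, and the dimension index i = 0, then apply the following to S recursively:
Find the point at the median of S along dimension i and draw a hyperplane through it, perpendicular to the i-th coordinate axis. Add this point as a child of node. The hyperplane splits the space into two parts. Repeat the procedure on each part (S = S', ++i, node = current), with i cycling through the dimensions, until no further split is possible.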
T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]


class node:
    def __init__(self, point):
        self.left = None
        self.right = None
        self.point = point


def median(lst):
    # Integer division keeps the index an int on Python 3.
    m = len(lst) // 2
    return lst[m], m


def build_kdtree(data, d):
    # Sort along the current axis d (0 = x, 1 = y) and split at the median.
    data = sorted(data, key=lambda x: x[d])
    p, m = median(data)
    tree = node(p)
    del data[m]
    print(data, p)
    # Points before the median go left, the rest go right; 1 - d switches axis.
    if m > 0:
        tree.left = build_kdtree(data[:m], 1 - d)
    if len(data) > 1:
        tree.right = build_kdtree(data[m:], 1 - d)
    return tree


kd_tree = build_kdtree(T, 0)
print(kd_tree.point)  # the point stored at the root
Visualization
Visualizing the construction takes a little more work: the intermediate results have to be saved and then displayed appropriately.
# -*- coding:utf-8 -*-
# Filename: kdtree.py
# Author:hankcs
# Date: 2015/2/4 15:01
import copy
import itertools
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib import animation

T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]


def draw_point(data):
    X, Y = [], []
    for p in data:
        X.append(p[0])
        Y.append(p[1])
    plt.plot(X, Y, 'bo')


def draw_line(xy_list):
    for xy in xy_list:
        x, y = xy
        plt.plot(x, y, 'g', lw=2)


def draw_square(square_list):
    currentAxis = plt.gca()
    colors = itertools.cycle(["r", "b", "g", "c", "m", "y", '#EB70AA', '#0099FF'])
    for square in square_list:
        currentAxis.add_patch(
            Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
                      color=next(colors)))


def median(lst):
    m = len(lst) // 2
    return lst[m], m


# every region visited during construction, as [[x_min, y_min], [x_max, y_max]]
history_quare = []


def build_kdtree(data, d, square):
    history_quare.append(square)
    data = sorted(data, key=lambda x: x[d])
    p, m = median(data)
    del data[m]
    print(data, p)
    # region on the lower side of the splitting line: shrink the upper bound
    sub_square = copy.deepcopy(square)
    if d == 0:
        sub_square[1][0] = p[0]
    else:
        sub_square[1][1] = p[1]
    history_quare.append(sub_square)
    if m > 0:
        build_kdtree(data[:m], 1 - d, sub_square)
    if len(data) > 1:
        # region on the upper side of the splitting line: raise the lower bound
        sub_square = copy.deepcopy(square)
        if d == 0:
            sub_square[0][0] = p[0]
        else:
            sub_square[0][1] = p[1]
        build_kdtree(data[m:], 1 - d, sub_square)


build_kdtree(T, 0, [[0, 0], [10, 10]])
print(history_quare)

# draw an animation to show how it works, the data comes from history
# first set up the figure and the axes we want to animate
fig = plt.figure()
ax = plt.axes(xlim=(0, 10), ylim=(0, 10))


# initialization function: plot the background of each frame
def init():
    plt.axis([0, 10, 0, 10])
    plt.grid(True)
    plt.xlabel('x_1')
    plt.ylabel('x_2')
    plt.title('build kd tree (www.hankcs.com)')
    draw_point(T)


currentAxis = plt.gca()
colors = itertools.cycle(["#FF6633", "g", "#3366FF", "c", "m", "y", '#EB70AA', '#0099FF', '#66FFFF'])


# animation function. this is called sequentially, once per saved region
def animate(i):
    square = history_quare[i]
    currentAxis.add_patch(
        Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
                  color=next(colors)))
    return


# call the animator. blit=True means only re-draw the parts that have changed.
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(history_quare), interval=1000, repeat=False,
                               blit=False)
plt.show()
# the 'imagemagick' writer requires ImageMagick to be installed
anim.save('kdtree_build.gif', fps=2, writer='imagemagick')
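Searching the kd tree
The code above does not actually search the kd tree, so search is implemented next.
Searching works much like it does in a binary tree and is a recursive process. First find the position where the target point would be inserted, then walk back up the tree. At each step, draw a hypersphere centered on the target with the current best distance as its radius, and use the points that fall inside it to update the nearest neighbor (or the k nearest neighbors). Taking the single nearest neighbor as an example, the implementation is as follows (because the test data is simple, this implementation omits the check for whether the hypersphere intersects a hyperrectangle):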
# -*- coding:utf-8 -*-
# Filename: search_kdtree.py
# Author:hankcs
# Date: 2015/2/4 15:01
T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]


class node:
    def __init__(self, point):
        self.left = None
        self.right = None
        self.point = point
        self.parent = None

    def set_left(self, left):
        if left is None:
            return
        left.parent = self
        self.left = left

    def set_right(self, right):
        if right is None:
            return
        right.parent = self
        self.right = right


def median(lst):
    m = len(lst) // 2
    return lst[m], m


def build_kdtree(data, d):
    data = sorted(data, key=lambda x: x[d])
    p, m = median(data)
    tree = node(p)
    del data[m]
    if m > 0:
        tree.set_left(build_kdtree(data[:m], 1 - d))
    if len(data) > 1:
        tree.set_right(build_kdtree(data[m:], 1 - d))
    return tree


def distance(a, b):
    # print every pair that gets compared, to trace the search
    print(a, b)
    return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5


def search_kdtree(tree, d, target):
    # descend to the node where target would be inserted
    if target[d] < tree.point[d]:
        if tree.left is not None:
            return search_kdtree(tree.left, 1 - d, target)
    else:
        if tree.right is not None:
            return search_kdtree(tree.right, 1 - d, target)

    def update_best(t, best):
        if t is None:
            return
        t = t.point
        dist = distance(t, target)
        if dist < best[1]:
            best[1] = dist
            best[0] = t

    # best = [point, distance]; start from the reached node with an infinite distance
    best = [tree.point, float('inf')]
    # walk back up and compare both children of every ancestor
    # (simplified: the root's own point is never compared and no subtree is pruned)
    while tree.parent is not None:
        update_best(tree.parent.left, best)
        update_best(tree.parent.right, best)
        tree = tree.parent
    return best[0]


kd_tree = build_kdtree(T, 0)
print(search_kdtree(kd_tree, 0, [9, 4]))
Output
[8, 1] [9, 4]
[5, 4] [9, 4]
[9, 6] [9, 4]
[9, 6]
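The implementation above leaves out the intersection test between the hypersphere and the region on the other side of a split. The following is one possible sketch of a search that includes that test and also returns the k nearest neighbors. It is not the author's code: the names `KDNode`, `build`, `sq_dist` and `k_nearest` are introduced here for illustration, and squared Euclidean distance is assumed.

```python
import heapq

class KDNode:
    def __init__(self, point, axis, left=None, right=None):
        self.point, self.axis = point, axis
        self.left, self.right = left, right

def build(points, depth=0):
    """Build a kd tree by splitting at the median along a cycling axis."""
    if not points:
        return None
    axis = depth % len(points[0])
    points = sorted(points, key=lambda p: p[axis])
    m = len(points) // 2
    return KDNode(points[m], axis,
                  build(points[:m], depth + 1),
                  build(points[m + 1:], depth + 1))

def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def k_nearest(root, target, k=1):
    """Return the k points closest to target, nearest first."""
    heap = []  # max-heap of the best candidates, stored as (-squared_distance, point)

    def visit(node):
        if node is None:
            return
        d = sq_dist(node.point, target)
        if len(heap) < k:
            heapq.heappush(heap, (-d, node.point))
        elif d < -heap[0][0]:
            heapq.heapreplace(heap, (-d, node.point))
        # Signed distance from the target to this node's splitting hyperplane.
        diff = target[node.axis] - node.point[node.axis]
        near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
        visit(near)
        # Cross the hyperplane only if the hypersphere around the target, whose
        # radius is the current k-th best distance, reaches the other side.
        if len(heap) < k or diff ** 2 < -heap[0][0]:
            visit(far)

    visit(root)
    return [p for _, p in sorted(heap, reverse=True)]

T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
tree = build(T)
print(k_nearest(tree, [9, 4]))       # [[9, 6]]
print(k_nearest(tree, [9, 4], k=3))  # [[9, 6], [7, 2], [8, 1]]
```

For the target [9, 4] the nearest neighbor agrees with the output above. The pruning condition compares the squared distance to the splitting hyperplane with the worst distance currently kept, so whole subtrees are skipped when they cannot contain a closer point.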