在尋找輸入樣本的k個近鄰的時候,若進行線性掃描,對于大數(shù)據(jù)集來說耗時太久暖眼,為了加快搜索速度,提出了用kd樹實現(xiàn)k個近鄰的搜索纺裁,此時復(fù)雜度為O(logN)诫肠。
首先是建樹
這里假設(shè)輸入數(shù)據(jù)一個N×K的矩陣,N代表實例點的個數(shù)欺缘,K代表樣本空間的維度栋豫。每一行代表一個實例點。
每個節(jié)點包含六個屬性:
- SamplePoints:實例點的行號谚殊,表示該節(jié)點對應(yīng)區(qū)域包含的所有實例點
- SplitDim:切割對應(yīng)的區(qū)域時選擇的特征(維度)
- MidPoint:是一個元組丧鸯,(切分點的行號,切分特征的中位數(shù))
- left:指向左子節(jié)點
- right:指向右子節(jié)點
- father:指向父節(jié)點
- visited:該節(jié)點是否已被訪問的標志
包含兩個方法:
- get_median():獲取切割特征的中位數(shù)
- get_dim():獲取方差最大的特征作為切割特征
過程如下:
- 構(gòu)造根節(jié)點嫩絮,使根節(jié)點對應(yīng)于k維空間中包含所有實例點的超矩形區(qū)域丛肢;
- 在超矩形區(qū)域上選擇一個坐標軸和在一個切分點,確定一個超平面剿干,這個超平面通過選定的切分點并垂直于選定的坐標軸蜂怎,將當前超矩形區(qū)域切分為左右兩個子區(qū)域(子節(jié)點);這時置尔,實例被分到兩個子區(qū)域杠步。
- 將切分點保存在根節(jié)點上。
- 重復(fù)步驟2榜轿、3篮愉,直到子區(qū)域內(nèi)只含有不含實例時終止。
import numpy as np
from collections import Counter
class KdTreeNode(object):
def __init__(self, SamplePoints):
self.SamplePoints = SamplePoints
self.SplitDim = self.get_dim()
self.MidPoint = self.get_MidPoint()
self.left = None
self.right = None
self.father = None
self.visited = False
def get_dim(self):
variance = np.var(X[self.SamplePoints, :], axis = 0) #計算該節(jié)點包含的實例點每個特征的方差
#print(variance)
return np.argmax(variance) #選擇方差最大的特征
def get_MidPoint(self):
tmp = X[self.SamplePoints, self.SplitDim]
length = len(tmp)
index = np.argsort(tmp) #該函數(shù)返回的是數(shù)組值從小到大的索引值
return (self.SamplePoints[index[int(length/2)]], tmp[index[int(length/2)]]) #(中位數(shù)所在的行號差导,中位數(shù)的值)
def build_tree(SamplePoints, father = None): #構(gòu)建kd樹
if len(SamplePoints) == 0: #子區(qū)域不含實例點時停止
return None
root = KdTreeNode(SamplePoints)
LeftPoints = [] #分割區(qū)域依據(jù)的特征小于或等于median的實例點
RightPoints = [] #分割區(qū)域依據(jù)的特征大于median的實例點
for x in SamplePoints:
if x == root.MidPoint[0]:
continue
if X[x, root.SplitDim] <= root.MidPoint[1]:
LeftPoints.append(x)
else:
RightPoints.append(x)
root.father = father
if len(SamplePoints) > 1: #子區(qū)域只含一個點時停止
root.left = build_tree(LeftPoints, root) #構(gòu)建左子樹
root.right = build_tree(RightPoints, root) #構(gòu)建右子樹
return root
最近鄰搜索
- 從根節(jié)點出發(fā)试躏,遞歸地向下訪問kd樹。若目標點x當前維(即切割根節(jié)點對應(yīng)區(qū)域時選擇的維度)的坐標小于或等于切分點的坐標设褐,則移動到左子節(jié)點颠蕴,否則移動到右子節(jié)點。直到子節(jié)點為葉節(jié)點為止助析,記此葉節(jié)點為L犀被。
- 以此葉節(jié)點L上的切分點為“當前最近點Ncur”,記錄Ncur與目標點的距離為Dcur外冀。
- 判斷L的父節(jié)點是否已被訪問寡键。
3.1. 若未被訪問,檢查L的父節(jié)點的另一子節(jié)點(即L的兄弟節(jié)點)對應(yīng)的區(qū)域是否與以目標點為球心以Dcur為半徑的超球體相交雪隧。具體做法是在分割L的父節(jié)點區(qū)域時選擇的維度上計算目標點與切分點的坐標差值的絕對值西轩,然后將其與Dcur比較员舵。
a) 若大于Dur,說明不相交藕畔。則標記L的父節(jié)點已被訪問马僻,回到此步驟的開頭。
b) 若小于或等于Dcur注服,說明相交韭邓。先計算L的父節(jié)點上的切分點與目標點的距離,檢查是否要更新Pcur與Dcur溶弟,完成后標記L的父節(jié)點已被訪問女淑。從L的兄弟節(jié)點出發(fā),按照步驟1找到一個新的葉節(jié)點L辜御。計算L上的切分點與目標點的距離诗力,檢查是否要更新Pcur與Dcur,完成后回到此步驟的開頭我抠。
3.2 若已被訪問,判斷L的父節(jié)點是否為根節(jié)點袜茧。
a) 若是,則停止整個程序笛厦。Pcur即為目標點的最近鄰纳鼎。
b) 若不是,則回退到L的父節(jié)點,具做法為令L=L的父節(jié)點捌议,然后回到此步驟的開頭瓣颅。
def approx_nearest_neighbor(root, TargetPoint): #尋找樹中與目標點的近似最近鄰點倦逐,該最似最近鄰僅僅是與目標點在同一分區(qū)中,不一定是最近鄰
if root.left == None and root.right == None:
return root
if TargetPoint[root.SplitDim] <= root.MidPoint[1]:
if root.left == None: #若應(yīng)往左子樹走時發(fā)現(xiàn)左子樹為空宫补,轉(zhuǎn)向右子樹搜尋檬姥,保證最后返回的是一個葉節(jié)點
return approx_nearest_neighbor(root.right, TargetPoint)
return approx_nearest_neighbor(root.left, TargetPoint)
else:
if root.right == None: #若應(yīng)往右子樹走時發(fā)現(xiàn)左子樹為空曾我,轉(zhuǎn)向左子樹搜尋
return approx_nearest_neighbor(root.left, TargetPoint)
return approx_nearest_neighbor(root.right, TargetPoint)
def nearest_neighbor_search(root, TargetPoint): #搜索與目標點的歐氏距離最小的樣本點
Vis = approx_nearest_neighbor(root, TargetPoint) #表示以該節(jié)點為根節(jié)點的子樹已被搜索完成
Ncur = X[Vis.MidPoint[0], :]#開始時直接用近似最近鄰點作為當前最近鄰點
Dcur = np.sqrt(np.sum(np.square(Ncur - TargetPoint))) #目標點與當前最近鄰的歐式距離
if Vis == root: #當樣本空間中只有一個點則直接輸出該點,注意Vis是一個節(jié)點穿铆,Ncur是一個點向量
return (Ncur, Dcur)
while True:
if not Vis.father.visited: #若Vis的父節(jié)點未被訪問
VerticalDis = abs(TargetPoint[Vis.father.SplitDim] - Vis.father.MidPoint[1]) #目標點到以Vis父節(jié)點為切分點的分割超平面的垂直距離
#若Vis的兄弟節(jié)點代表的區(qū)域與以目標點為圓心Dcur為半徑的圓相交
if VerticalDis <= Dcur:
EuclideanDis = np.sqrt(np.sum(np.square(X[Vis.father.MidPoint[0], :] - TargetPoint))) #Vis的父節(jié)點與目標點的距離
if EuclideanDis < Dcur: #若比Dcur小您单,則將其作為當前最近鄰
Dcur = EuclideanDis
Ncur = X[Vis.father.MidPoint[0], :]
Vis.father.visited = True #此節(jié)點已被訪問
#尋找Vis的兄弟節(jié)點
if Vis.father.left == Vis:
brother = Vis.father.right
else:
brother = Vis.father.left
#若無兄弟節(jié)點,直接爬升到Vis的父節(jié)點
if brother == None:
continue
#若有兄弟節(jié)點
Vis = approx_nearest_neighbor(brother, TargetPoint)
EuclideanDis = np.sqrt(np.sum(np.square(X[Vis.MidPoint[0], :] - TargetPoint)))
if EuclideanDis < Dcur:
Dcur = EuclideanDis
Ncur = X[Vis.MidPoint[0], :]
continue
#若不相交
else:
Vis.father.visited = True
else: #若Vis的父節(jié)點已被訪問
if Vis.father == root: #若根節(jié)點已被訪問荞雏,則結(jié)束搜索
break
else:
Vis = Vis.father #向上爬升到Vis的父節(jié)點
return (Ncur, Dcur)
K近鄰搜索
k近鄰的搜索與最近鄰搜索類似虐秦,不過程序中的“當前最近鄰Ncur”要改為“當前K近鄰Kcur”,它是一個二維列表凤优,里面的每一行代表了K個近鄰點中的一個悦陋。在每次比較一個新的節(jié)點時,都需判斷是否要對它進行更新筑辨,用離目標點更近的點代替更遠的點俺驶。
def compare_dis(CurrentPoint, TargetPoint, Ncur, K): #計算樣本點與目標點的距離,若有必要的話對Ncur進行更新
EuclideanDis = np.sqrt(np.sum(np.square(CurrentPoint - TargetPoint))) #計算歐式距離
Ncur = sorted(Ncur, key = lambda x : -x[1]) #對Ncur中的K個點按照到目標點的距離從遠到近排序
if EuclideanDis < Ncur[0][1]: #如果當前目標點到目標點的距離比Ncur中最遠的點要近棍辕,則對Ncur進行更新
Ncur = Ncur[1:K]
Ncur.append((CurrentPoint, EuclideanDis))
return Ncur
def k_neighbor_search(root, TargetPoint, K): #搜索與目標點的歐氏距離最小的K個樣本點
Vis = approx_nearest_neighbor(root, TargetPoint) #Vis表示以該節(jié)點為根節(jié)點的子樹已被搜索完成
Ncur = [] #存儲當前K個近鄰點
for i in range(K): #用K個離目標點無窮遠的點作為Ncur的初始值
Ncur.append((X[i,:], float('inf')))
Ncur = compare_dis(X[Vis.MidPoint[0], :], TargetPoint, Ncur, K)
if Vis == root: #當K=1時暮现, 若樣本空間中只有一個點,則直接輸出該點
return Ncur
while True:
if not Vis.father.visited: #若Vis的父節(jié)點未被訪問
VerticalDis = abs(TargetPoint[Vis.father.SplitDim] - Vis.father.MidPoint[1]) #目標點到以Vis父節(jié)點為切分點的分割超平面的垂直距離
#若Vis的兄弟節(jié)點代表的區(qū)域與以目標點為圓心Dcur為半徑的圓相交
if VerticalDis <= sorted(Ncur, key = lambda x : -x[1])[0][1]:
Ncur = compare_dis(X[Vis.father.MidPoint[0], :], TargetPoint, Ncur, K) #判斷Vis的父節(jié)點是否要加入到Ncur中
Vis.father.visited = True #此節(jié)點已被訪問
brother = Vis.father.right if Vis.father.left == Vis else Vis.father.left #尋找Vis的兄弟節(jié)點
#若無兄弟節(jié)點楚昭,直接爬升到Vis的父節(jié)點
if brother == None:
continue
#若有兄弟節(jié)點
Vis = approx_nearest_neighbor(brother, TargetPoint)
Ncur = compare_dis(X[Vis.MidPoint[0], :], TargetPoint, Ncur, K)
continue
#若不相交
else:
Vis.father.visited = True
else: #若Vis的父節(jié)點已被訪問
if Vis.father == root: #若根節(jié)點已被訪問栖袋,則結(jié)束搜索
break
else:
Vis = Vis.father #向上爬升到Vis的父節(jié)點
return Ncur
測試程序
下圖中的紅色叉叉代表目標點。
#主程序
X = np.array([[2,3],
[5,4],
[9,6],
[4,7],
[8,1],
[7,2]]) #存儲樣本向量
TargetPoint = np.array([8, 0]) #輸入目標點
root = build_tree(range(len(X))) #建樹
while True:
K = int(input('Input K:').strip()) #若樣本點的個數(shù)沒有K個抚太,需重新設(shè)定K
if len(X) < K:
print('Retry')
continue
break
Ncur = k_neighbor_search(root, TargetPoint, K)
for point in Ncur:
print(point[0]) #輸出K個近鄰點的坐標
特征空間劃分