Scala Implementation: KD-Tree (k-dimensional tree)
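A kd-tree is a data structure that partitions a k-dimensional data space. It is mainly used to search multidimensional data and often appears in settings such as SIFT and KNN. Take KNN (k-nearest neighbors) as an example: a linear scan compares the query against every point and is inefficient. A k-d tree is essentially a partition of the multidimensional space, stored as a binary tree in which every node is a k-dimensional point, so a query can skip large parts of the data and search much faster.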
KD-Tree construction steps:
The construction steps are quoted from Dr. Li Hang's "Statistical Learning Methods".
As an example, build a KD-Tree from the data set {(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)}.
Figure: KD-Tree space partition for this data set.
Figure: space partition for three-dimensional data.
For data with more dimensions, the partition can only be imagined.
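To make the first split of the example data concrete, here is a minimal sketch that sorts the points along one axis and takes the element at index `length / 2` as the split point, which is the convention the implementation later in this post uses. The object `SplitDemo` and the helper `split` are names made up for this illustration.

```scala
object SplitDemo {
  type Point = Seq[Double]

  // Sort along one axis and split around the element at index length / 2.
  // Returns (left part, split point, right part).
  def split(points: Seq[Point], dim: Int): (Seq[Point], Point, Seq[Point]) = {
    val sorted = points.sortBy(_(dim))
    val mid = sorted.length / 2
    (sorted.take(mid), sorted(mid), sorted.drop(mid + 1))
  }

  // The example data set from the text
  val data: Seq[Point] = Seq(
    Seq(2.0, 3.0), Seq(5.0, 4.0), Seq(9.0, 6.0),
    Seq(4.0, 7.0), Seq(8.0, 1.0), Seq(7.0, 2.0))
}
```

Calling `SplitDemo.split(SplitDemo.data, 0)` picks (7,2) as the root, with (2,3), (4,7), (5,4) on the left and (8,1), (9,6) on the right. Splitting the left part on the y-axis then picks (5,4), and so on down the tree.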
KD-Tree nearest-neighbor search:
Starting from the root, descend the kd-tree recursively: compare the target point with the split point along the current split dimension, move to the left child if the target is smaller and to the right child if it is larger, until a leaf is reached.
Once at the leaf, take it as the "current nearest point".
Back up recursively, performing the following at each node passed on the way:
- If the current node is closer to the target point than the "current nearest point", make it the current nearest point.
- The current nearest point must lie in the region of one child of this node. Check whether the region of the other child intersects the hypersphere centered at the target point whose radius is the distance between the target point and the "current nearest point":
- 1. If they intersect, the other child's region may contain a point closer to the target; move to the other child and continue the nearest-neighbor search recursively.
- 2. If they do not intersect, back up one level.
When the search backs up to the root, it ends. The final "current nearest point" is the nearest neighbor of the target point x.
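The intersection check in the steps above reduces to a one-dimensional comparison when each node splits on a single axis: the hypersphere reaches the other side only if the distance from the target to the splitting hyperplane is smaller than the current best distance. A minimal sketch, where `PruneDemo` and `crossesSplit` are hypothetical names used only for this illustration:

```scala
object PruneDemo {
  // True when the hypersphere around `target` with radius `bestDist`
  // crosses the splitting hyperplane of a node that splits on `dim`.
  def crossesSplit(target: Seq[Double],
                   nodeValue: Seq[Double],
                   dim: Int,
                   bestDist: Double): Boolean =
    math.abs(target(dim) - nodeValue(dim)) < bestDist
}
```

For the target (2.1, 4.5) with a current best distance of about 1.503 (to the point (2,3)), the node (7,2) splitting on x gives |2.1 - 7| = 4.9, so its other side can be skipped. The node (5,4) splitting on y gives |4.5 - 4| = 0.5, so its other side must still be searched.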
Scala implementation
Defining the tree node
/**
  * A node of the kd-tree.
  *
  * @param value the k-dimensional point stored at this node
  * @param dim   the dimension this node splits on
  * @param left  left child (null if absent)
  * @param right right child (null if absent)
  */
case class TreeNode(value: Seq[Double],
                    dim: Int,
                    var left: TreeNode,
                    var right: TreeNode) {
  var parent: TreeNode = _      // parent node
  var brotherNode: TreeNode = _ // sibling node
  // Link the children back to this node and to each other
  if (left != null) {
    left.parent = this
    left.brotherNode = right
  }
  if (right != null) {
    right.parent = this
    right.brotherNode = left
  }
}
Building the KD-Tree
/**
  * Recursively builds a kd-tree by splitting on the middle element of the current dimension.
  *
  * @param value the points to insert
  * @param dim   the dimension to split on at this level
  * @param shape the number of dimensions of the data
  * @return the root of the (sub)tree, or null when value is empty
  */
def creatKdTree(value: Seq[Seq[Double]], dim: Int, shape: Int): TreeNode = {
  // Sort the points along the current split dimension
  val sorted = value.sortBy(_(dim))
  // Index of the middle element
  val midIndex: Int = value.length / 2
  sorted match {
    // No points left: return null
    case Seq() => null
    // Otherwise split around the middle element and recurse
    case _ =>
      val left = sorted.slice(0, midIndex)
      val right = sorted.slice(midIndex + 1, value.length)
      val leftNode = creatKdTree(left, (dim + 1) % shape, shape)   // build the left subtree
      val rightNode = creatKdTree(right, (dim + 1) % shape, shape) // build the right subtree
      TreeNode(sorted(midIndex), dim, leftNode, rightNode)
  }
}
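The builder above signals an empty subtree with `null`. As a design alternative, not the code used in the rest of this post, the same median split can be written with `Option` so that missing children are explicit in the type. This is only a sketch: `OptionKdTree`, `Node`, `build`, and `depth` are names invented here, and the parent and sibling links are left out.

```scala
object OptionKdTree {
  final case class Node(value: Seq[Double],
                        dim: Int,
                        left: Option[Node],
                        right: Option[Node])

  // Same split rule as creatKdTree, but an empty input yields None instead of null.
  def build(points: Seq[Seq[Double]], dim: Int, shape: Int): Option[Node] =
    if (points.isEmpty) None
    else {
      val sorted = points.sortBy(_(dim))
      val mid = sorted.length / 2
      val next = (dim + 1) % shape
      Some(Node(sorted(mid), dim,
        build(sorted.take(mid), next, shape),
        build(sorted.drop(mid + 1), next, shape)))
    }

  // Number of levels in the tree; 0 for an empty tree.
  def depth(node: Option[Node]): Int =
    node.fold(0)(n => 1 + math.max(depth(n.left), depth(n.right)))
}
```

For the six example points, `build(points, 0, 2)` gives a tree rooted at (7,2) with three levels.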
Nearest-neighbor search
// Euclidean distance between two points of the same dimension
def euclidean(p1: Seq[Double], p2: Seq[Double]): Double = {
  require(p1.size == p2.size)
  val d = p1
    .zip(p2)
    .map(tp => math.pow(tp._1 - tp._2, 2))
    .sum
  math.sqrt(d)
}
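As a quick check of the distance function, the distance between (2,3) and the query (2.1, 4.5) used later is sqrt(0.1² + 1.5²) = sqrt(2.26) ≈ 1.5033. Because the square root is monotonic, comparing squared distances gives the same ordering, so a search that only ranks candidates could skip the `sqrt`. The helper below is a sketch of that idea, not part of the original code; `DistanceDemo` and `squaredEuclidean` are hypothetical names.

```scala
object DistanceDemo {
  // Squared Euclidean distance: same ordering as the true distance, no sqrt.
  def squaredEuclidean(p1: Seq[Double], p2: Seq[Double]): Double = {
    require(p1.size == p2.size, "points must have the same dimension")
    p1.zip(p2).map { case (a, b) => (a - b) * (a - b) }.sum
  }
}
```

If squared distances were used inside the tree search, the pruning test would also have to compare the squared offset along the split dimension against the squared best distance.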
/**
  * Nearest-neighbor search.
  *
  * @param treeNode root of the kd-tree
  * @param data     query point
  * @return the node closest to the query point, or null for an empty tree
  */
def nearestSearch(treeNode: TreeNode, data: Seq[Double]): TreeNode = {
  var nearestNode: TreeNode = null      // current nearest node
  var minDist: Double = Double.MaxValue // current smallest distance
  def finder(treeNode: TreeNode): TreeNode = {
    treeNode match {
      case null => nearestNode
      case _ =>
        // Signed offset of the query from this node along the split dimension
        val dimr = data(treeNode.dim) - treeNode.value(treeNode.dim)
        // Descend first into the side that contains the query point
        if (dimr > 0) finder(treeNode.right) else finder(treeNode.left)
        val distc = euclidean(treeNode.value, data)
        if (distc <= minDist) {
          minDist = distc
          nearestNode = treeNode
        }
        // The hypersphere around the query crosses this node's splitting
        // hyperplane, so the other side may hold a closer point
        if (math.abs(dimr) < minDist)
          if (dimr > 0) finder(treeNode.left) else finder(treeNode.right)
        nearestNode
    }
  }
  finder(treeNode)
}
Checking the result
val nodes: Seq[Seq[Double]] =
Seq(Seq(2, 3), Seq(5, 4), Seq(9, 6), Seq(4, 7), Seq(8, 1), Seq(7, 2))
val treeNode: TreeNode = KdTree.creatKdTree(nodes, 0, 2)
println(treeNode)
println(KdTree.nearestSearch(treeNode, Seq(2.1, 4.5)).value)
println("==============")
nodes.map(x => {
val d = KdTree.euclidean(x, Seq(2.1, 4.5))
(d, x)
})
.sortBy(_._1)
.foreach(println)
TreeNode(List(7.0, 2.0),0,TreeNode(List(5.0, 4.0),1,TreeNode(List(2.0, 3.0),0,null,null),TreeNode(List(4.0, 7.0),0,null,null)),TreeNode(List(9.0, 6.0),1,TreeNode(List(8.0, 1.0),0,null,null),null))
List(2.0, 3.0)
==============
(1.503329637837291,List(2.0, 3.0))
(2.9427877939124323,List(5.0, 4.0))
(3.1400636936215163,List(4.0, 7.0))
(5.500909015790027,List(7.0, 2.0))
(6.860029154456998,List(8.0, 1.0))
(7.061161377563892,List(9.0, 6.0))
TODO: K-nearest-neighbor search (KNN)
/**
  * Starting from the root, search depth-first down to a leaf, keeping the visited
  * nodes in order on a stack (here, the recursion's call stack).
  * On reaching a leaf, that leaf becomes the nearest-neighbor candidate.
  * Then backtrack through the stack:
  * if the current node is closer than the nearest candidate, update the candidate;
  * then check whether the circle whose radius is the nearest distance crosses the
  * parent's splitting hyperplane.
  * If it does, the other side of the parent must be searched with the same DFS.
  * If it does not, keep backtracking; the nodes on the other side of the parent are
  * ruled out and no longer considered.
  * When the search returns to the root, it is complete.
  *
  * @param treeNode root of the kd-tree
  * @param data     query point
  * @param k        number of neighbors to return
  * @return (distance, node) pairs sorted by ascending distance
  */
def knn(treeNode: TreeNode, data: Seq[Double], k: Int): Array[(Double, TreeNode)] = {
  require(k > 0)
  // k best candidates so far, sorted by distance; unfilled slots hold (Double.MaxValue, null)
  var resArr = Array.fill[(Double, TreeNode)](k)((Double.MaxValue, null))
  def finder(treeNode: TreeNode): Unit = {
    if (treeNode != null) {
      val dimr = data(treeNode.dim) - treeNode.value(treeNode.dim)
      // Descend first into the side that contains the query point
      if (dimr > 0) finder(treeNode.right) else finder(treeNode.left)
      val distc: Double = euclidean(treeNode.value, data)
      // Replace the current worst candidate if this node is closer
      if (distc < resArr.last._1) {
        resArr.update(k - 1, (distc, treeNode))
        resArr = resArr.sortBy(_._1)
      }
      // Search the other side only if it could beat the current k-th best distance
      if (math.abs(dimr) < resArr.last._1)
        if (dimr > 0) finder(treeNode.left) else finder(treeNode.right)
    }
  }
  finder(treeNode)
  resArr
}
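The `knn` function above re-sorts the whole candidate array every time a closer node is found. A common alternative for this bookkeeping is a bounded max-heap, where the current worst of the k candidates sits at the head and can be replaced directly. The sketch below shows only that bookkeeping over a plain list of points, without the tree traversal; `KBestDemo` and `kSmallest` are hypothetical names, and `scala.collection.mutable.PriorityQueue` is the standard-library heap.

```scala
import scala.collection.mutable

object KBestDemo {
  type Candidate = (Double, Seq[Double]) // (distance to query, point)

  // Keep the k closest points seen so far in a max-heap ordered by distance,
  // so the current worst candidate is always at the head.
  def kSmallest(query: Seq[Double], points: Seq[Seq[Double]], k: Int): Seq[Candidate] = {
    require(k > 0, "k must be positive")
    val heap = mutable.PriorityQueue.empty[Candidate](Ordering.by[Candidate, Double](_._1))
    for (p <- points) {
      val d = math.sqrt(p.zip(query).map { case (a, b) => (a - b) * (a - b) }.sum)
      if (heap.size < k) heap.enqueue((d, p))
      else if (d < heap.head._1) {
        heap.dequeue()       // drop the current worst candidate
        heap.enqueue((d, p)) // and keep the closer point instead
      }
    }
    heap.toSeq.sortBy(_._1) // ascending by distance
  }
}
```

In a tree-based search the same enqueue-or-replace step would run once per visited node, and `heap.head._1` would serve as the pruning radius once the heap holds k candidates.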
Checking the KNN result
KdTree
.knn(treeNode, Seq(2.1, 4.5), 3)
.map(x => (x._1, x._2.value))
.foreach(println)
(1.503329637837291,List(2.0, 3.0))
(2.9427877939124323,List(5.0, 4.0))
(3.1400636936215163,List(4.0, 7.0))
References
https://baike.baidu.com/item/kd-tree/2302515?fr=aladdin#7_1
Li Hang, "Statistical Learning Methods" (《统计学习方法》)