根據(jù)統(tǒng)計學(xué)習(xí)方法寫的KdTree實現(xiàn)欢峰,###
參考了這個博客的主要思路,但是在關(guān)于如何搜索最近鄰上有些不同嚎尤。
1.我采取在發(fā)現(xiàn)可能的路徑后苦酱,采取擴展路徑到葉子節(jié)點,生成一個新路徑后重新計算最近路徑育特。而這個博客中只檢查了路徑上與超球體相交的點丙号。沒有遞歸搜索
2.他的博客用利用方差確定分割的方向。我則選用了簡單的依次更換策略缰冤。
#include <iostream>
#include <vector>
#include <algorithm>
#include <stack>
using namespace std;
struct Node
{
double x;
double y;
};
struct KdTree
{
Node val;
int split; /描述根據(jù)X或Y進行劃分/
KdTree* left;
KdTree* right;
};
KdTree myKdTree{};
const int N = 6;
const int dim = 2;
Node dataSet[N] = {
{ 2,3 },
{ 5,4 },
{ 9,6 },
{ 4,7 },
{ 8,1 },
{ 7,2 }
};
int time = 0;/記錄尋找分割次數(shù)/
stack<KdTree> search_path;/記錄搜索過程的路經(jīng)*/
/*結(jié)果結(jié)構(gòu)*/
struct result {
Node resNode;
double dist;
};
/*X犬缨,Y維比較函數(shù)*/
bool compareX(Node a,Node b) {
return a.x > b.x;
}
bool compareY(Node a, Node b) {
return a.y > b.y;
}
void chooseSplit(Node unsortSet[],Node& splitData,int size) {
if (time % 2 == 0) {
/*根據(jù)x維分割*/
sort(unsortSet, unsortSet + size, compareX);
}
else {
/*根據(jù)y維分割*/
sort(unsortSet, unsortSet + size, compareY);
}
int mid;
if (size % 2 == 0) {
mid = size / 2 - 1;
}
else {
mid = size / 2;
}
splitData.x = unsortSet[mid].x;
splitData.y = unsortSet[mid].y;
time++;
}
/*構(gòu)造kdTree*/
KdTree* build(int size,Node unsortSet[], KdTree* tree) {
if (size == 0) {
return 0;
}
else {
int split;
Node splitData;
chooseSplit(unsortSet,splitData, size);
Node leftset[100]{};
Node rightset[100]{};
int leftnum = 0;
int rightnum = 0;
if (time % 2 == 1) {
/*根據(jù)x維分割,time加一后*/
split = 0;
for (int i = 0; i < size; i++) {
if (splitData.x > unsortSet[i].x) {
leftset[leftnum] = unsortSet[i];
leftnum++;
}
else if(splitData.x < unsortSet[i].x) {
rightset[rightnum] = unsortSet[i];
rightnum++;
}
}
}
else {
split = 1;
for (int i = 0; i < size; i++) {
if (splitData.y > unsortSet[i].y) {
leftset[leftnum] = unsortSet[i];
leftnum++;
}
else if (splitData.y < unsortSet[i].y) {
rightset[rightnum] = unsortSet[i];
rightnum++;
}
}
}
tree = new KdTree;
tree->val = splitData;
tree->split = split;
tree->left = build(leftnum, leftset, tree->left);
tree->right = build(rightnum, rightset, tree->right);
return tree;
}
}
/*計算距離 p=2*/
double distance(Node a, Node b) {
return (a.x - b.x)*(a.x - b.x) + (a.y - b.y)*(a.y - b.y);
}
/*建立搜索路徑*/
void buildpath(Node target, KdTree* tree) {
KdTree* pSearch = tree;
while (pSearch != NULL) {
search_path.push(pSearch);
if (pSearch->split == 0) {
if (target.x < pSearch->val.x) {
pSearch = pSearch->left;
}
else {
pSearch = pSearch->right;
}
}
else {
if (target.y < tree->val.y) {
pSearch = pSearch->left;
}
else {
pSearch = pSearch->right;
}
}
}
}
/*根據(jù)搜索路徑查找最近鄰*/
result findnearest (Node target,KdTree* tree){
/*初始化搜索路徑*/
buildpath(target, tree);
Node nearest = search_path.top()->val;
double dist = distance(nearest, target);
search_path.pop();
//搜索潛在的路徑上最近點。
KdTree* pBack;
while (search_path.size() != 0) {
pBack = search_path.top();
search_path.pop();
if (pBack->left == NULL && pBack->right == NULL) {
if (distance(pBack->val, target) < dist) {
dist = distance(pBack->val, target);
nearest = pBack->val;
}
}
else {
if (pBack->split == 0) {
if (abs(target.x - pBack->val.x) < dist) {//X方向相交棉浸。
KdTree* newTree{};
if ((target.x > pBack->val.x)&&(pBack->left !=NULL)) {//點在右側(cè)怀薛,向左搜索。
search_path.push(pBack->left);
newTree = pBack->left;
}
if ((target.x < pBack->val.x) && (pBack->right != NULL)) {
search_path.push(pBack->right);
newTree = pBack->right;
};
//搜索新發(fā)現(xiàn)的路徑
buildpath(target, newTree);
}
}
else {
if (abs(target.y - pBack->val.y) < dist) {//Y方向相交迷郑。
KdTree* newTree{};
if ((target.y > pBack->val.y) && (pBack->left != NULL)) {//點在右側(cè)枝恋,向左搜索。
search_path.push(pBack->left);
newTree = pBack->left;
}
if ((target.y < pBack->val.y) && (pBack->right != NULL)) {
search_path.push(pBack->right);
newTree = pBack->right;
};
//搜索新發(fā)現(xiàn)的路徑
buildpath(target, newTree);
}
}
}
}
return result{ nearest ,dist };
}
//打印樹結(jié)構(gòu)
void printNode(Node node) {
cout << "("<<node.x<<","<<node.y<<")"<<endl;
}
void printTree_rootfirst(KdTree* root) {
printNode(root->val);
if (root->left != NULL) {
printTree_rootfirst(root->left);
}
if (root->right != NULL) {
printTree_rootfirst(root->right);
}
}
void printTree_leftfirst(KdTree* root) {
if (root->left != NULL) {
printTree_leftfirst(root->left);
}
printNode(root->val);
if (root->right != NULL) {
printTree_leftfirst(root->right);
}
}
int main() {
KdTree * root = NULL;
root = build(N, dataSet, root);
Node target = {2,4.5};
result res = findnearest(target,root);
cout <<"最近距離:"<< res.dist << endl;
cout <<"X方向:"<< res.resNode.x << endl;
cout << "Y方向:" << res.resNode.y << endl;
system("pause");
}