閱讀本文前帆阳,建議查閱相關(guān)資料哺壶,了解 KNN 算法與 KD 樹。
基礎(chǔ)知識(shí)
如圖所示蜒谤,假設(shè)一個(gè)點(diǎn) a
目前的最近鄰點(diǎn)為 b
山宾,如果存在相對(duì)于 b
離 a
更近的點(diǎn),那么這個(gè)點(diǎn)一定在以 a
為圓心芭逝,ab
為半徑的圓內(nèi)塌碌。
現(xiàn)右側(cè)的區(qū)域是未知的,如果 a
到分界線的距離 l
大于目前的最近距離 L
(圓半徑)旬盯,則沒有必要在右側(cè)的未知區(qū)域繼續(xù)尋找最近鄰點(diǎn)(如圖一)台妆,反之,則要繼續(xù)尋找(如圖二)胖翰。
相應(yīng)的接剩,投射到多維空間,假如切分邊界為第 i
維萨咳,切分點(diǎn)的值為 v
(標(biāo)量)懊缺,當(dāng)前最近鄰點(diǎn)為 y
(向量),如果目標(biāo)點(diǎn) x
(向量) 到切分邊界的距離 |x[i] - v| 滿足以下關(guān)系
時(shí)培他,需要在另一側(cè)繼續(xù)搜索鹃两。
通常地,一個(gè)機(jī)器學(xué)習(xí)算法分為
fit
和 predict
兩個(gè)階段舀凛,基于線性搜索的 KNN
是一種惰性算法俊扳,它將全部的計(jì)算任務(wù)放到了 predict
階段,predict
的時(shí)間復(fù)雜度為 O(n)
猛遍,KD 樹之所以比線性搜索快馋记,就是因?yàn)樗鼘⒁徊糠秩蝿?wù)放到了 fit
(建立 KD 樹) 階段,從而在搜索時(shí)可以略去大量不必搜索的結(jié)點(diǎn)(最優(yōu)情況下時(shí)間復(fù)雜度為 O(1)
)懊烤。上面說的比較簡(jiǎn)單梯醒,關(guān)于 KNN 算法和 KD 樹的詳細(xì)內(nèi)容,請(qǐng)參考李航博士的《統(tǒng)計(jì)學(xué)習(xí)方法》腌紧。
代碼
我們給出部分關(guān)鍵性的代碼茸习。
基本數(shù)據(jù)結(jié)構(gòu)
- 訓(xùn)練集用一個(gè)一維數(shù)組
double *data
表示,它的長度為n_samples * n_features
壁肋,標(biāo)簽集也用一個(gè)一維數(shù)組double *labels
表示号胚,它的長度為n_samples
代箭。 - 樹的結(jié)點(diǎn)用以下數(shù)據(jù)結(jié)構(gòu)表示
struct tree_node { size_t id; // 表示訓(xùn)練集中的第 i 個(gè)數(shù)據(jù) size_t split; // 切分的維度 tree_node *left, *right; // 左、右子樹 };
- 一個(gè) KD 樹的模型可用以下結(jié)構(gòu)表示
struct tree_model { tree_node *root; // 根結(jié)點(diǎn) const double *datas; // X const double *labels; // y size_t n_samples; // 樣例數(shù) size_t n_features; // 每個(gè)樣例的特征數(shù) double p; // 距離度量 };
- 求 K-近鄰時(shí)需要用到大頂堆涕刚,我們直接用 C++ 的優(yōu)先隊(duì)列來表示,堆內(nèi)現(xiàn)有的
n(n <= k)
個(gè)近鄰點(diǎn)中乙帮,距離測(cè)試點(diǎn)最遠(yuǎn)的在堆頂struct neighbor_heap_cmp { bool operator()(const std::tuple<size_t, double> &i, const std::tuple<size_t, double> &j) { return std::get<1>(i) < std::get<1>(j); } }; typedef std::tuple<size_t, double> neighbor; typedef std::priority_queue<neighbor, std::vector<neighbor>, neighbor_heap_cmp> neighbor_heap_; neighbor_heap k_neighbor_heap_;
KD-Tree 類
我們用類 KDTree
表示一個(gè) KD 樹類杜漠,它應(yīng)該具有的功能有建樹
和搜索
。
//(簡(jiǎn)化的代碼察净,完整的代碼詳見最后)
class KDTree {
public:
// 建樹
KDTree(const double *datas, const double *labels, size_t rows, size_t cols, double p)
// 返回樹
tree_node *GetRoot() { return root; }
// 求一個(gè)測(cè)試點(diǎn)的 k 鄰
std::vector<std::tuple<size_t, double>> FindKNearests(const double *coor, size_t k);
private:
tree_node *root_;
}
尋找切分維和切分點(diǎn)
在建樹之前驾茴,我們還要考慮如何選擇切分維度和切分點(diǎn)。切分維度的選擇有許多氢卡,一般的锈至,可以取 dim = floor % n_features
,即當(dāng)前樹的層數(shù)對(duì)特征數(shù)取余译秦,我們?cè)谶@里使用 dim = argmax(nmax - nmin)
峡捡,即選取當(dāng)前結(jié)點(diǎn)集合中極差最大的維度。
(這里是不完整的代碼筑悴,有些工具函數(shù)的定義請(qǐng)?jiān)斠娡暾创a)
size_t KDTree::FindSplitDim(const std::vector<size_t> &points) {
if (points.size() == 1)
return 0;
size_t cur_best_dim = 0;
double cur_largest_spread = -1;
double cur_min_val;
double cur_max_val;
for (size_t dim = 0; dim < n_features; ++dim) {
cur_min_val = GetDimVal(points[0], dim);
cur_max_val = GetDimVal(points[0], dim);
for (const auto &id : points) {
if (GetDimVal(id, dim) > cur_max_val)
cur_max_val = GetDimVal(id, dim);
else if (GetDimVal(id, dim) < cur_min_val)
cur_min_val = GetDimVal(id, dim);
}
if (cur_max_val - cur_min_val > cur_largest_spread) {
cur_largest_spread = cur_max_val - cur_min_val;
cur_best_dim = dim;
}
}
return cur_best_dim;
}
選擇完切分維 k
之后们拙,我們需選取當(dāng)前結(jié)點(diǎn)集合中的結(jié)點(diǎn)在第 k
維的值的中位數(shù) x
作為切分點(diǎn)的值,除去該點(diǎn)之外的點(diǎn)阁吝,第 k
維的值小于等于 x
的砚婆,放入左子樹,反之放入右子樹突勇。
在求中位數(shù)時(shí)装盯,不要全排序,然后取中間的點(diǎn)甲馋,可以采用類似快排的方法埂奈,找到中位數(shù)時(shí)就停止排序,這里我們就不寫算法了摔刁,直接用 C++ 的函數(shù)挥转。
std::tuple<size_t, double> KDTree::MidElement(const std::vector<size_t> &points, size_t dim) {
size_t len = points.size();
for (size_t i = 0; i < points.size(); ++i)
get_mid_buf_[i] = std::make_tuple(points[i], GetDimVal(points[i], dim));
std::nth_element(get_mid_buf_,
get_mid_buf_ + len / 2,
get_mid_buf_ + len,
[](const std::tuple<size_t, double> &i, const std::tuple<size_t, double> &j) {
return std::get<1>(i) < std::get<1>(j);
});
return get_mid_buf_[len / 2];
}
建樹
建樹直接按照建立二叉樹的方法即可
tree_node *KDTree::BuildTree(const std::vector<size_t> &points) {
size_t dim = FindSplitDim(points);
std::tuple<size_t, double> t = MidElement(points, dim);
size_t arg_mid_val = std::get<0>(t);
double mid_val = std::get<1>(t);
tree_node *node = Malloc(tree_node, 1);
node->left = nullptr;
node->right = nullptr;
node->id = arg_mid_val;
node->split = dim;
std::vector<size_t> left, right;
for (auto &i : points) {
if (i == arg_mid_val)
continue;
if (GetDimVal(i, dim) <= mid_val)
left.emplace_back(i);
else
right.emplace_back(i);
}
if (!left.empty())
node->left = BuildTree(left);
if (!right.empty())
node->right = BuildTree(right);
return node;
}
搜索 K-近鄰的規(guī)則
一般書上所講的都是搜索最近鄰,但是我們這里是搜索 K-近鄰共屈,需要對(duì)書上的算法做少許的擴(kuò)充绑谣。
搜索最近鄰時(shí),我們一般設(shè)置兩個(gè)變量 cur_min_id
和 cur_min_dist
拗引,如果當(dāng)前搜索到的點(diǎn)到測(cè)試點(diǎn)的距離 l < cur_min_dist
時(shí)借宵,我們將上述兩個(gè)變量更新為新點(diǎn)的 id
和 dist
。
相應(yīng)的矾削,在搜索 K-近鄰時(shí)壤玫,我們可以設(shè)置一個(gè)最多有 k
個(gè)元素的大頂堆豁护,這樣,在搜索時(shí)欲间,當(dāng)堆滿時(shí)楚里,只需比較當(dāng)前搜索點(diǎn)的 dist
是否小于堆頂點(diǎn)的 dist
,如果小于猎贴,堆頂出堆班缎,并將當(dāng)前搜索點(diǎn)壓入,反之她渴,則不變达址;當(dāng)堆未滿時(shí),直接將該搜索點(diǎn)壓入趁耗。
搜索 K-近鄰的算法
我們直接使用二叉樹深度優(yōu)先遍歷的非遞歸算法(具體的描述詳見《統(tǒng)計(jì)學(xué)習(xí)方法》第 43 頁算法 3.3)沉唠。
std::vector<std::tuple<size_t, double>> KDTree::FindKNearests(const double *coor, size_t k) {
std::memset(visited_buf_, 0, sizeof(bool) * n_samples);
std::stack<tree_node *> paths;
tree_node *p = root;
while (p) {
HeapStackPush(paths, p, coor, k);
p = coor[p->split] <= GetDimVal(p->id, p->split) ? p = p->left : p = p->right;
}
while (!paths.empty()) {
p = paths.top();
paths.pop();
if (!p->left && !p->right)
continue;
if (k_neighbor_heap_.size() < k) {
if (p->left)
HeapStackPush(paths, p->left, coor, k);
if (p->right)
HeapStackPush(paths, p->right, coor, k);
} else {
double node_split_val = GetDimVal(p->id, p->split);
double coor_split_val = coor[p->split];
double heap_top_val = std::get<1>(k_neighbor_heap_.top());
if (coor_split_val > node_split_val) {
if (p->right)
HeapStackPush(paths, p->right, coor, k);
if ((coor_split_val - node_split_val) < heap_top_val && p->left)
HeapStackPush(paths, p->left, coor, k);
} else {
if (p->left)
HeapStackPush(paths, p->left, coor, k);
if ((node_split_val - coor_split_val) < heap_top_val && p->right)
HeapStackPush(paths, p->right, coor, k);
}
}
}
std::vector<std::tuple<size_t, double>> res;
while (!k_neighbor_heap_.empty()) {
res.emplace_back(k_neighbor_heap_.top());
k_neighbor_heap_.pop();
}
return res;
}
完整代碼
詳見 https://github.com/WiseDoge/libkdtree
完整代碼中除了 KD-Tree 的代碼外,還給出了測(cè)試代碼和 Python 接口代碼苛败,以及一些調(diào)用第三方庫來加速的手段满葛。