KD-Tree 算法的 C++ 實(shí)現(xiàn)

閱讀本文前帆阳,建議查閱相關(guān)資料哺壶,了解 KNN 算法與 KD 樹。

基礎(chǔ)知識(shí)

如圖所示蜒谤,假設(shè)一個(gè)點(diǎn) a 目前的最近鄰點(diǎn)為 b山宾,如果存在相對(duì)于 ba 更近的點(diǎn),那么這個(gè)點(diǎn)一定在以 a 為圓心芭逝,ab 為半徑的圓內(nèi)塌碌。
現(xiàn)右側(cè)的區(qū)域是未知的,如果 a 到分界線的距離 l 大于目前的最近距離 L(圓半徑)旬盯,則沒有必要在右側(cè)的未知區(qū)域繼續(xù)尋找最近鄰點(diǎn)(如圖一)台妆,反之,則要繼續(xù)尋找(如圖二)胖翰。
相應(yīng)的接剩,投射到多維空間,假如切分邊界為第 i 維萨咳,切分點(diǎn)的值為 v(標(biāo)量)懊缺,當(dāng)前最近鄰點(diǎn)為 y(向量),如果目標(biāo)點(diǎn) x(向量) 到切分邊界的距離 |x[i] - v| 滿足以下關(guān)系


時(shí)培他,需要在另一側(cè)繼續(xù)搜索鹃两。

圖1:不需要在右側(cè)未知區(qū)域繼續(xù)搜索的情況

圖2:需要在右側(cè)未知區(qū)域繼續(xù)搜索的情況

通常地,一個(gè)機(jī)器學(xué)習(xí)算法分為 fitpredict 兩個(gè)階段舀凛,基于線性搜索的 KNN 是一種惰性算法俊扳,它將全部的計(jì)算任務(wù)放到了 predict 階段,predict 的時(shí)間復(fù)雜度為 O(n)猛遍,KD 樹之所以比線性搜索快馋记,就是因?yàn)樗鼘⒁徊糠秩蝿?wù)放到了 fit(建立 KD 樹) 階段,從而在搜索時(shí)可以略去大量不必搜索的結(jié)點(diǎn)(最優(yōu)情況下時(shí)間復(fù)雜度為 O(1))懊烤。
上面說的比較簡(jiǎn)單梯醒,關(guān)于 KNN 算法和 KD 樹的詳細(xì)內(nèi)容,請(qǐng)參考李航博士的《統(tǒng)計(jì)學(xué)習(xí)方法》腌紧。

代碼

我們給出部分關(guān)鍵性的代碼茸习。

基本數(shù)據(jù)結(jié)構(gòu)

  • 訓(xùn)練集用一個(gè)一維數(shù)組 double *data 表示,它的長度為 n_samples * n_features壁肋,標(biāo)簽集也用一個(gè)一維數(shù)組 double *labels 表示号胚,它的長度為 n_samples代箭。
  • 樹的結(jié)點(diǎn)用以下數(shù)據(jù)結(jié)構(gòu)表示
     struct tree_node
     {
         size_t id;               // 表示訓(xùn)練集中的第 i 個(gè)數(shù)據(jù)
         size_t split;            // 切分的維度
         tree_node *left, *right; // 左、右子樹
     };
    
  • 一個(gè) KD 樹的模型可用以下結(jié)構(gòu)表示
     struct tree_model
     {
         tree_node *root;        // 根結(jié)點(diǎn)
         const double *datas;    // X
         const double *labels;   // y
         size_t n_samples;       // 樣例數(shù)
         size_t n_features;      // 每個(gè)樣例的特征數(shù)
         double p;               // 距離度量
     };
    
  • 求 K-近鄰時(shí)需要用到大頂堆涕刚,我們直接用 C++ 的優(yōu)先隊(duì)列來表示,堆內(nèi)現(xiàn)有的 n(n <= k) 個(gè)近鄰點(diǎn)中乙帮,距離測(cè)試點(diǎn)最遠(yuǎn)的在堆頂
    struct neighbor_heap_cmp {
        bool operator()(const std::tuple<size_t, double> &i, 
                        const std::tuple<size_t, double> &j) {
              return std::get<1>(i) < std::get<1>(j);
          }
      };
    
    typedef std::tuple<size_t, double> neighbor;
    typedef std::priority_queue<neighbor,
            std::vector<neighbor>, neighbor_heap_cmp> neighbor_heap_;
    
    neighbor_heap k_neighbor_heap_;
    

KD-Tree 類

我們用類 KDTree 表示一個(gè) KD 樹類杜漠,它應(yīng)該具有的功能有建樹搜索

//(簡(jiǎn)化的代碼察净,完整的代碼詳見最后)
class KDTree {
public:
    // 建樹
    KDTree(const double *datas, const double *labels, size_t rows, size_t cols, double p)
    // 返回樹
    tree_node *GetRoot() { return root; }
    // 求一個(gè)測(cè)試點(diǎn)的 k 鄰
    std::vector<std::tuple<size_t, double>> FindKNearests(const double *coor, size_t k);
private:
    tree_node *root_;
}

尋找切分維和切分點(diǎn)

在建樹之前驾茴,我們還要考慮如何選擇切分維度和切分點(diǎn)。切分維度的選擇有許多氢卡,一般的锈至,可以取 dim = floor % n_features,即當(dāng)前樹的層數(shù)對(duì)特征數(shù)取余译秦,我們?cè)谶@里使用 dim = argmax(nmax - nmin)峡捡,即選取當(dāng)前結(jié)點(diǎn)集合中極差最大的維度。

(這里是不完整的代碼筑悴,有些工具函數(shù)的定義請(qǐng)?jiān)斠娡暾创a)
size_t KDTree::FindSplitDim(const std::vector<size_t> &points) {
    if (points.size() == 1)
        return 0;
    size_t cur_best_dim = 0;
    double cur_largest_spread = -1;
    double cur_min_val;
    double cur_max_val;
    for (size_t dim = 0; dim < n_features; ++dim) {
        cur_min_val = GetDimVal(points[0], dim);
        cur_max_val = GetDimVal(points[0], dim);
        for (const auto &id : points) {
            if (GetDimVal(id, dim) > cur_max_val)
                cur_max_val = GetDimVal(id, dim);
            else if (GetDimVal(id, dim) < cur_min_val)
                cur_min_val = GetDimVal(id, dim);
        }

        if (cur_max_val - cur_min_val > cur_largest_spread) {
            cur_largest_spread = cur_max_val - cur_min_val;
            cur_best_dim = dim;
        }
    }
    return cur_best_dim;
}

選擇完切分維 k 之后们拙,我們需選取當(dāng)前結(jié)點(diǎn)集合中的結(jié)點(diǎn)在第 k 維的值的中位數(shù) x 作為切分點(diǎn)的值,除去該點(diǎn)之外的點(diǎn)阁吝,第 k 維的值小于等于 x 的砚婆,放入左子樹,反之放入右子樹突勇。
在求中位數(shù)時(shí)装盯,不要全排序,然后取中間的點(diǎn)甲馋,可以采用類似快排的方法埂奈,找到中位數(shù)時(shí)就停止排序,這里我們就不寫算法了摔刁,直接用 C++ 的函數(shù)挥转。

std::tuple<size_t, double> KDTree::MidElement(const std::vector<size_t> &points, size_t dim) {
    size_t len = points.size();
    for (size_t i = 0; i < points.size(); ++i)
        get_mid_buf_[i] = std::make_tuple(points[i], GetDimVal(points[i], dim));
    std::nth_element(get_mid_buf_,
                     get_mid_buf_ + len / 2,
                     get_mid_buf_ + len,
                     [](const std::tuple<size_t, double> &i, const std::tuple<size_t, double> &j) {
                         return std::get<1>(i) < std::get<1>(j);
                     });
    return get_mid_buf_[len / 2];
}

建樹

建樹直接按照建立二叉樹的方法即可

tree_node *KDTree::BuildTree(const std::vector<size_t> &points) {
    size_t dim = FindSplitDim(points);
    std::tuple<size_t, double> t = MidElement(points, dim);
    size_t arg_mid_val = std::get<0>(t);
    double mid_val = std::get<1>(t);

    tree_node *node = Malloc(tree_node, 1);
    node->left = nullptr;
    node->right = nullptr;
    node->id = arg_mid_val;
    node->split = dim;
    std::vector<size_t> left, right;
    for (auto &i : points) {
        if (i == arg_mid_val)
            continue;
        if (GetDimVal(i, dim) <= mid_val)
            left.emplace_back(i);
        else
            right.emplace_back(i);
    }
    if (!left.empty())
        node->left = BuildTree(left);
    if (!right.empty())
        node->right = BuildTree(right);
    return node;
}

搜索 K-近鄰的規(guī)則

一般書上所講的都是搜索最近鄰,但是我們這里是搜索 K-近鄰共屈,需要對(duì)書上的算法做少許的擴(kuò)充绑谣。
搜索最近鄰時(shí),我們一般設(shè)置兩個(gè)變量 cur_min_idcur_min_dist拗引,如果當(dāng)前搜索到的點(diǎn)到測(cè)試點(diǎn)的距離 l < cur_min_dist 時(shí)借宵,我們將上述兩個(gè)變量更新為新點(diǎn)的 iddist
相應(yīng)的矾削,在搜索 K-近鄰時(shí)壤玫,我們可以設(shè)置一個(gè)最多有 k 個(gè)元素的大頂堆豁护,這樣,在搜索時(shí)欲间,當(dāng)堆滿時(shí)楚里,只需比較當(dāng)前搜索點(diǎn)的 dist 是否小于堆頂點(diǎn)的 dist,如果小于猎贴,堆頂出堆班缎,并將當(dāng)前搜索點(diǎn)壓入,反之她渴,則不變达址;當(dāng)堆未滿時(shí),直接將該搜索點(diǎn)壓入趁耗。

搜索 K-近鄰的算法

我們直接使用二叉樹深度優(yōu)先遍歷的非遞歸算法(具體的描述詳見《統(tǒng)計(jì)學(xué)習(xí)方法》第 43 頁算法 3.3)沉唠。

std::vector<std::tuple<size_t, double>> KDTree::FindKNearests(const double *coor, size_t k) {
    std::memset(visited_buf_, 0, sizeof(bool) * n_samples);
    std::stack<tree_node *> paths;
    tree_node *p = root;

    while (p) {
        HeapStackPush(paths, p, coor, k);
        p = coor[p->split] <= GetDimVal(p->id, p->split) ? p = p->left : p = p->right;
    }
    while (!paths.empty()) {
        p = paths.top();
        paths.pop();

        if (!p->left && !p->right)
            continue;

        if (k_neighbor_heap_.size() < k) {
            if (p->left)
                HeapStackPush(paths, p->left, coor, k);
            if (p->right)
                HeapStackPush(paths, p->right, coor, k);
        } else {
            double node_split_val = GetDimVal(p->id, p->split);
            double coor_split_val = coor[p->split];
            double heap_top_val = std::get<1>(k_neighbor_heap_.top());
            if (coor_split_val > node_split_val) {
                if (p->right)
                    HeapStackPush(paths, p->right, coor, k);
                if ((coor_split_val - node_split_val) < heap_top_val && p->left)
                    HeapStackPush(paths, p->left, coor, k);
            } else {
                if (p->left)
                    HeapStackPush(paths, p->left, coor, k);
                if ((node_split_val - coor_split_val) < heap_top_val && p->right)
                    HeapStackPush(paths, p->right, coor, k);
            }
        }
    }
    std::vector<std::tuple<size_t, double>> res;

    while (!k_neighbor_heap_.empty()) {
        res.emplace_back(k_neighbor_heap_.top());
        k_neighbor_heap_.pop();
    }
    return res;
}

完整代碼

詳見 https://github.com/WiseDoge/libkdtree
完整代碼中除了 KD-Tree 的代碼外,還給出了測(cè)試代碼和 Python 接口代碼苛败,以及一些調(diào)用第三方庫來加速的手段满葛。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市罢屈,隨后出現(xiàn)的幾起案子纱扭,更是在濱河造成了極大的恐慌,老刑警劉巖儡遮,帶你破解...
    沈念sama閱讀 218,525評(píng)論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件乳蛾,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡鄙币,警方通過查閱死者的電腦和手機(jī)肃叶,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,203評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來十嘿,“玉大人因惭,你說我怎么就攤上這事〖ㄖ裕” “怎么了蹦魔?”我有些...
    開封第一講書人閱讀 164,862評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵,是天一觀的道長咳燕。 經(jīng)常有香客問我勿决,道長,這世上最難降的妖魔是什么招盲? 我笑而不...
    開封第一講書人閱讀 58,728評(píng)論 1 294
  • 正文 為了忘掉前任低缩,我火速辦了婚禮,結(jié)果婚禮上曹货,老公的妹妹穿的比我還像新娘咆繁。我一直安慰自己讳推,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,743評(píng)論 6 392
  • 文/花漫 我一把揭開白布玩般。 她就那樣靜靜地躺著银觅,像睡著了一般。 火紅的嫁衣襯著肌膚如雪坏为。 梳的紋絲不亂的頭發(fā)上设拟,一...
    開封第一講書人閱讀 51,590評(píng)論 1 305
  • 那天,我揣著相機(jī)與錄音久脯,去河邊找鬼。 笑死镰吆,一個(gè)胖子當(dāng)著我的面吹牛帘撰,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播万皿,決...
    沈念sama閱讀 40,330評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼摧找,長吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了牢硅?” 一聲冷哼從身側(cè)響起蹬耘,我...
    開封第一講書人閱讀 39,244評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎减余,沒想到半個(gè)月后综苔,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,693評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡位岔,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,885評(píng)論 3 336
  • 正文 我和宋清朗相戀三年如筛,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片抒抬。...
    茶點(diǎn)故事閱讀 40,001評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡杨刨,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出擦剑,到底是詐尸還是另有隱情妖胀,我是刑警寧澤,帶...
    沈念sama閱讀 35,723評(píng)論 5 346
  • 正文 年R本政府宣布惠勒,位于F島的核電站赚抡,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏纠屋。R本人自食惡果不足惜怕品,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,343評(píng)論 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望巾遭。 院中可真熱鬧肉康,春花似錦闯估、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,919評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至炫乓,卻和暖如春刚夺,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背末捣。 一陣腳步聲響...
    開封第一講書人閱讀 33,042評(píng)論 1 270
  • 我被黑心中介騙來泰國打工侠姑, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人箩做。 一個(gè)月前我還...
    沈念sama閱讀 48,191評(píng)論 3 370
  • 正文 我出身青樓莽红,卻偏偏與公主長得像,于是被迫代替她去往敵國和親邦邦。 傳聞我的和親對(duì)象是個(gè)殘疾皇子安吁,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,955評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容