Faiss 源碼解析

image.png

<meta charset="utf-8">

Faiss 源碼解析

faissfacebook 開源的一個專門用于做高維向量的相似性搜索的庫鸭廷,有 c++python 的接口胳蛮;目前項目地址在 https://github.com/facebookresearch/faiss。本文主要結(jié)合 faiss 的官方示例,介紹如何使用 faiss 以及 暴力/IVF/IVFPQ 檢索算法在 faiss 的具體實現(xiàn)。

檢索算法介紹

檢索算法的介紹可以參考 科普个盆,本文主要關(guān)注3種檢索算法:

  1. 暴力搜索:顧名思義,querybase 一一比對朵栖,選擇最近的
  2. IVF:首先在具有代表性的數(shù)據(jù)上訓(xùn)練聚類中心颊亮,然后將 base 加入到最近的聚類中心的桶里,在 search 的時候陨溅,query 先和聚類中心比對终惑,再在一定數(shù)目的桶里做暴力搜索
  3. IVFPQ:在 IVF 的基礎(chǔ)上,將 basePQ 量化门扇,加速比對

faiss 的編譯與安裝

可以參考官方給出的編譯方法狠鸳,這里我沒有安裝 cuda,所以采用的命令是

./configure --without-cuda && make

在編譯完 faiss 之后悯嗓,我們對官方提供的示例也進行編譯,路徑在 ./tutorial/cpp 下卸察,cd到目錄下直接 make 就可以了

如何使用 faiss

官方總共提供了五個示例脯厨,其中有兩個是 gpu 版本的,三個是 cpu 版本的坑质,我們這里主要關(guān)注 cpu 的合武,分別是 1-Flat.cpp2-IVFFLAT.cpp涡扼,3-IVFPQ.cpp稼跳,分別對應(yīng)著暴力算法檢索,IVF 算法檢索吃沪,IVFPQ 算法檢索汤善。不同的算法在用戶側(cè)代碼基本一致,我們選取 IVFPQ 做簡單介紹。

#include <cstdio>
#include <cstdlib>

#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>


int main() {
    int d = 64;                            // 特征維度
    int nb = 100000;                       // base 樣本數(shù)量
    int nq = 10000;                        // query 樣本數(shù)量

    float *xb = new float[d * nb];
    float *xq = new float[d * nq];

    for(int i = 0; i < nb; i++) {
        for(int j = 0; j < d; j++)
            xb[d * i + j] = drand48();
        xb[d * i] += i / 1000.;
    } // 隨機初始化 base 數(shù)據(jù)

    for(int i = 0; i < nq; i++) {
        for(int j = 0; j < d; j++)
            xq[d * i + j] = drand48();
        xq[d * i] += i / 1000.;
    }    // 隨機初始化 query 數(shù)據(jù)


    int nlist = 100;  // 聚類中心個數(shù)
    int k = 4;
    int m = 8;                             // bytes per vector
    faiss::IndexFlatL2 quantizer(d);       // 初始化用 L2 暴力 search 的 index
    faiss::IndexIVFPQ index(&quantizer, d, nlist, m, 8); // 初始化 ivfpq 的 index红淡,用 L2 暴力 search 的 index 初始化
    index.train(nb, xb); // 訓(xùn)練 index
    index.add(nb, xb); // 將 base 數(shù)據(jù)加入到 index 中不狮,用于之后的搜索

    {       // search xq
        long *I = new long[k * nq];
        float *D = new float[k * nq];

        index.nprobe = 10; // 搜索 10 個中心點
        index.search(nq, xq, k, D, I);

        printf("I=\n");
        for(int i = nq - 5; i < nq; i++) {
            for(int j = 0; j < k; j++)
                printf("%5ld ", I[i * k + j]);
            printf("\n");
        }

        delete [] I;
        delete [] D;
    }



    delete [] xb;
    delete [] xq;

    return 0;
}

這段代碼主要包括了四個部分,分別是

  1. 初始化 base/query 數(shù)據(jù)和 index
  2. 訓(xùn)練 index
  3. 加入baseindex
  4. querysearch

其中在旱,使用 faiss 主要包含了三步摇零。初始化數(shù)據(jù)準(zhǔn)備不用多說,faiss 中要求的數(shù)據(jù)格式都是 n * d 的矩陣格式桶蝎,然后被展平到一維 float 數(shù)組中驻仅。剩下的兩步,都是對 index 進行操作登渣。

源碼解析

檢索流程

參考官方給的例子噪服,檢索分為三步:trainadd绍豁,search芯咧,不同的檢索算法,體現(xiàn)在使用不同的 index 進行這三步上

  1. train:選取有代表性的數(shù)據(jù)竹揍,訓(xùn)練 index
  2. add:將 base 數(shù)據(jù)加入到 index
  3. search:對于給定的 query敬飒,返回其對應(yīng)的在底庫中的 topk

重要類

Index

index 的基類,后續(xù)各種各樣的檢索算法芬位,都會繼承這個基類或者這個類的派生類无拗,然后實現(xiàn)具體的方法,在這個類中昧碉,有如下的數(shù)據(jù)成員:

  • d:維度英染,每個向量的維度
  • ntotal:索引的向量的數(shù)目,可以理解成檢索時的 base 數(shù)目
  • metric_type:檢索時使用的 metric 類型被饿,比如 L2四康,內(nèi)積等

IndexFlat

用于做暴力搜索的 index 類,直接繼承 index狭握。暴力搜索思路很簡單孽江,無需 train滑频,add 的所有 base 都被存儲起來,然后在 search 的時候把 query 和所有 base 進行比對,選取最近的似袁。我們看下具體實現(xiàn)凳宙。

  • add

add 就是把所有的 base 都存儲起來

void IndexFlat::add (idx_t n, const float *x) {
    xb.insert(xb.end(), x, x + n * d);
    ntotal += n;
}

  • Search

Search 的時候殴胧,根據(jù) metric type 的不同荚坞,返回 querytopk。具體計算時采用了 openmpsse/avx 優(yōu)化

void IndexFlat::search (idx_t n, const float *x, idx_t k,
                               float *distances, idx_t *labels) const
{
    // we see the distances and labels as heaps

    if (metric_type == METRIC_INNER_PRODUCT) {
        float_minheap_array_t res = {
            size_t(n), size_t(k), labels, distances};
        knn_inner_product (x, xb.data(), d, n, ntotal, &res); //函數(shù)內(nèi)部有并行優(yōu)化
    } else if (metric_type == METRIC_L2) {
        float_maxheap_array_t res = {
            size_t(n), size_t(k), labels, distances};
        knn_L2sqr (x, xb.data(), d, n, ntotal, &res);
    } else {
        float_maxheap_array_t res = {
            size_t(n), size_t(k), labels, distances};
        knn_extra_metrics (x, xb.data(), d, n, ntotal,
                           metric_type, metric_arg,
                           &res);
    }
}

Clustering

實現(xiàn) K-means 聚類的類今妄,提供train 郑口,需要訓(xùn)練數(shù)據(jù)和 index(用于 search 最近的向量)鸳碧,結(jié)果得到訓(xùn)練數(shù)據(jù)的類中心向量,如果是量化的向量潘酗,那么還需要提供量化使用的 index codec杆兵,我們?nèi)コ炕牟糠郑豢?float 數(shù)據(jù)

核心代碼如下仔夺,包括如下部分:

  • search過程琐脏,將聚類中心作為底庫加入到 index 中,并對訓(xùn)練數(shù)據(jù)做 search缸兔,得到 assign
  • 計算新的聚類中心日裙,計算新的聚類中心的代碼在 compute_centroids中,具體就是對于相同的類別的向量惰蜜,將向量的均值作為新的中心昂拂,在實現(xiàn)上,利用 openmp 進行了并行優(yōu)化

重復(fù)以上兩步抛猖,就可以得到最優(yōu)的聚類中心

void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
                                const Index * codec, Index & index,
                                const float *weights) {
  // 前處理省略  
  for (int redo = 0; redo < nredo; redo++) {

        if (verbose && nredo > 1) {
            printf("Outer iteration %d / %d\n", redo, nredo);
        }

        // initialize (remaining) centroids with random points from the dataset
        centroids.resize (d * k);
        std::vector<int> perm (nx);

        rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);

        for (int i = n_input_centroids; i < k ; i++) {
          memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
        }

        post_process_centroids ();

        // prepare the index

        if (index.ntotal != 0) {
            index.reset();
        }

        index.add (k, centroids.data());

        // k-means iterations

        float err = 0;
        for (int i = 0; i < niter; i++) {
            double t0s = getmillisecs();
                        index.search (nx, reinterpret_cast<const float *>(x), 1,
                          dis.get(), assign.get());

            InterruptCallback::check();
            t_search_tot += getmillisecs() - t0s;

            // accumulate error
            err = 0;
            for (int j = 0; j < nx; j++) {
                err += dis[j];
            }

            // update the centroids
            std::vector<float> hassign (k);

            size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
            compute_centroids (
                  d, k, nx, k_frozen,
                  x, codec, assign.get(), weights,
                  hassign.data(), centroids.data()
            );

            index.reset ();
            if (update_index) {
                index.train (k, centroids.data());
            }

            index.add (k, centroids.data());
            InterruptCallback::check ();
        }

    }
    //保存最優(yōu)聚類中心
    if (nredo > 1) {
        centroids = best_centroids;
        iteration_stats = best_obj;
        index.reset();
        index.add(k, best_centroids.data());
    }

}

void compute_centroids (size_t d, size_t k, size_t n,
                       size_t k_frozen,
                       const uint8_t * x, const Index *codec,
                       const int64_t * assign,
                       const float * weights,
                       float * hassign,
                       float * centroids)
{
    k -= k_frozen;
    centroids += k_frozen * d;

    memset (centroids, 0, sizeof(*centroids) * d * k);

    size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);

#pragma omp parallel
    {
        int nt = omp_get_num_threads();
        int rank = omp_get_thread_num();

        // this thread is taking care of centroids c0:c1
        size_t c0 = (k * rank) / nt;
        size_t c1 = (k * (rank + 1)) / nt;
        std::vector<float> decode_buffer (d);

        for (size_t i = 0; i < n; i++) {
            int64_t ci = assign[i];
            assert (ci >= 0 && ci < k + k_frozen);
            ci -= k_frozen;
            if (ci >= c0 && ci < c1)  {
                float * c = centroids + ci * d;
                const float * xi;
                if (!codec) {
                    xi = reinterpret_cast<const float*>(x + i * line_size);
                } else {
                    float *xif = decode_buffer.data();
                    codec->sa_decode (1, x + i * line_size, xif);
                    xi = xif;
                }
                if (weights) {
                    float w = weights[i];
                    hassign[ci] += w;
                    for (size_t j = 0; j < d; j++) {
                        c[j] += xi[j] * w;
                    }
                } else {
                    hassign[ci] += 1.0;
                    for (size_t j = 0; j < d; j++) {
                        c[j] += xi[j];
                    }
                }
            }
        }

    }

#pragma omp parallel for
    for (size_t ci = 0; ci < k; ci++) {
        if (hassign[ci] == 0) {
            continue;
        }
        float norm = 1 / hassign[ci];
        float * c = centroids + ci * d;
        for (size_t j = 0; j < d; j++) {
            c[j] *= norm;
        }
    }

}

IndexIVF

用于做 IVF 搜索的 index 類格侯。

  • train

ivf 算法會把給定的數(shù)據(jù)進行聚類,得到固定數(shù)目的聚類中心财著。具體的联四,就是 train_q1? 的過程,train_residual 在 ivf 中是一個空函數(shù)

void IndexIVF::train (idx_t n, const float *x)
{
    train_q1 (n, x, verbose, metric_type);
    train_residual (n, x);
    is_trained = true;
}

void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
  if (verbose)
    printf("IndexIVF: no residual training\n");
  // does nothing by default
}

train_q1用的是 Level1Quantizer 的具體實現(xiàn)撑教,如下朝墩,對訓(xùn)練數(shù)據(jù)進行聚類,得到聚類中心并保存下來

void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricType metric_type)
{
    // 省略無關(guān)代碼
    Clustering clus (d, nlist, cp);
    quantizer->reset();
    if (clustering_index) {
      clus.train (n, x, *clustering_index);
      quantizer->add (nlist, clus.centroids.data());
    } else {
      clus.train (n, x, *quantizer);
    }
    quantizer->is_trained = true;
}

  • add

  • 分片伟姐。根據(jù)輸入的大小收苏,按照固定的大小依次進行 add

  • 建立 invlists。根據(jù) train得到的聚類中心(保存在 quantizer 中)愤兵,每一個類中心對應(yīng) invlists 中的一個桶鹿霸。

  • invlists 的桶里加入 base。利用了 openmp 進行了并行加速

void IndexIVF::add (idx_t n, const float * x)
{
    add_with_ids (n, x, nullptr);
}

void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
{
    // do some blocking to avoid excessive allocs
    idx_t bs = 65536;
    if (n > bs) {
        for (idx_t i0 = 0; i0 < n; i0 += bs) {
            idx_t i1 = std::min (n, i0 + bs);
            if (verbose) {
                printf("   IndexIVF::add_with_ids %ld:%ld\n", i0, i1);
            }
            add_with_ids (i1 - i0, x + i0 * d,
                          xids ? xids + i0 : nullptr);
        }
        return;
    }

    std::unique_ptr<idx_t []> idx(new idx_t[n]);
    quantizer->assign (n, x, idx.get());
    size_t nadd = 0, nminus1 = 0;

#pragma omp parallel reduction(+: nadd)
    {
        int nt = omp_get_num_threads();
        int rank = omp_get_thread_num();

        // each thread takes care of a subset of lists
        for (size_t i = 0; i < n; i++) {
            idx_t list_no = idx [i];
            if (list_no >= 0 && list_no % nt == rank) {
                idx_t id = xids ? xids[i] : ntotal + i;
                size_t ofs = invlists->add_entry (
                     list_no, id,
                     flat_codes.get() + i * code_size
                );

                dm_adder.add (i, list_no, ofs);

                nadd++;
            } else if (rank == 0 && list_no == -1) {
                dm_adder.add (i, -1, 0);
            }
        }
    }

    ntotal += n;
}

  • search

  • Search corse_dis秆乳。搜索離 query 最近的聚類中心

  • Search invlists懦鼠。在最近的 nprobe 個聚類中心對應(yīng)的 invlists 中進行暴力 heap 搜索,得到 topk

void IndexIVF::search (idx_t n, const float *x, idx_t k,
                         float *distances, idx_t *labels) const
{
    std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

    double t0 = getmillisecs();
    quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
    indexIVF_stats.quantization_time += getmillisecs() - t0;

    t0 = getmillisecs();
    invlists->prefetch_lists (idx.get(), n * nprobe);

    search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
                        distances, labels, false);
    indexIVF_stats.search_time += getmillisecs() - t0;
}

ProductQuantizer

用來做 PQ 量化算法的類矫夷,關(guān)于 PQ 量化算法,可以參考 pq算法憋槐。簡單來說双藕,我們需要得到用來量化的碼本,然后我們可以對輸入的向量進行解碼和編碼阳仔。得到碼本的過程在 ProductQuantizer::train 中忧陪,包含

  • 將輸入向量按照維度切分成 PQ 段扣泊,每段的維度是 dsub
  • 得到每段的聚類中心,這就是碼本

編碼和解碼的過程就是將輸入向量轉(zhuǎn)化為碼本里的 idx嘶摊,可以看出延蟹,量化是存在一定的誤差,其中叶堆,PQ 越大阱飘,誤差越小

void ProductQuantizer::train (int n, const float * x)
{
    if (train_type != Train_shared) {
        train_type_t final_train_type;
        final_train_type = train_type;
        if (train_type == Train_hypercube ||
            train_type == Train_hypercube_pca) {
            if (dsub < nbits) {
                final_train_type = Train_default;
                printf ("cannot train hypercube: nbits=%ld > log2(d=%ld)\n",
                        nbits, dsub);
            }
        }

        float * xslice = new float[n * dsub];
        ScopeDeleter<float> del (xslice);
        for (int m = 0; m < M; m++) {
            for (int j = 0; j < n; j++)
                memcpy (xslice + j * dsub,
                        x + j * d + m * dsub,
                        dsub * sizeof(float));

            Clustering clus (dsub, ksub, cp);

            // we have some initialization for the centroids
            if (final_train_type != Train_default) {
                clus.centroids.resize (dsub * ksub);
            }

            switch (final_train_type) {
            case Train_hypercube:
                init_hypercube (dsub, nbits, n, xslice,
                                clus.centroids.data ());
                break;
            case  Train_hypercube_pca:
                init_hypercube_pca (dsub, nbits, n, xslice,
                                    clus.centroids.data ());
                break;
            case  Train_hot_start:
                memcpy (clus.centroids.data(),
                        get_centroids (m, 0),
                        dsub * ksub * sizeof (float));
                break;
            default: ;
            }

            if(verbose) {
                clus.verbose = true;
                printf ("Training PQ slice %d/%zd\n", m, M);
            }
            IndexFlatL2 index (dsub);
            clus.train (n, xslice, assign_index ? *assign_index : index);
            set_params (clus.centroids.data(), m);
        }

    } else {

        Clustering clus (dsub, ksub, cp);

        if(verbose) {
            clus.verbose = true;
            printf ("Training all PQ slices at once\n");
        }

        IndexFlatL2 index (dsub);

        clus.train (n * M, x, assign_index ? *assign_index : index);
        for (int m = 0; m < M; m++) {
            set_params (clus.centroids.data(), m);
        }

    }
}

IndexIVFPQ

ivfpq 算法在 ivf 的基礎(chǔ)上,對 basepq虱颗。大家可以自行參考代碼

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末沥匈,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子忘渔,更是在濱河造成了極大的恐慌高帖,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,607評論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件畦粮,死亡現(xiàn)場離奇詭異散址,居然都是意外死亡,警方通過查閱死者的電腦和手機宣赔,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,239評論 3 395
  • 文/潘曉璐 我一進店門预麸,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人拉背,你說我怎么就攤上這事师崎。” “怎么了椅棺?”我有些...
    開封第一講書人閱讀 164,960評論 0 355
  • 文/不壞的土叔 我叫張陵犁罩,是天一觀的道長。 經(jīng)常有香客問我两疚,道長床估,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,750評論 1 294
  • 正文 為了忘掉前任诱渤,我火速辦了婚禮丐巫,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘勺美。我一直安慰自己递胧,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 67,764評論 6 392
  • 文/花漫 我一把揭開白布赡茸。 她就那樣靜靜地躺著缎脾,像睡著了一般。 火紅的嫁衣襯著肌膚如雪占卧。 梳的紋絲不亂的頭發(fā)上遗菠,一...
    開封第一講書人閱讀 51,604評論 1 305
  • 那天联喘,我揣著相機與錄音,去河邊找鬼辙纬。 笑死豁遭,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的贺拣。 我是一名探鬼主播蓖谢,決...
    沈念sama閱讀 40,347評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼纵柿!你這毒婦竟也來了蜈抓?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,253評論 0 276
  • 序言:老撾萬榮一對情侶失蹤昂儒,失蹤者是張志新(化名)和其女友劉穎沟使,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體渊跋,經(jīng)...
    沈念sama閱讀 45,702評論 1 315
  • 正文 獨居荒郊野嶺守林人離奇死亡腊嗡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,893評論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了拾酝。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片燕少。...
    茶點故事閱讀 40,015評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖蒿囤,靈堂內(nèi)的尸體忽然破棺而出客们,到底是詐尸還是另有隱情,我是刑警寧澤材诽,帶...
    沈念sama閱讀 35,734評論 5 346
  • 正文 年R本政府宣布底挫,位于F島的核電站,受9級特大地震影響脸侥,放射性物質(zhì)發(fā)生泄漏建邓。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,352評論 3 330
  • 文/蒙蒙 一睁枕、第九天 我趴在偏房一處隱蔽的房頂上張望官边。 院中可真熱鬧,春花似錦外遇、人聲如沸注簿。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,934評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽诡渴。三九已至,卻和暖如春塔嬉,著一層夾襖步出監(jiān)牢的瞬間玩徊,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,052評論 1 270
  • 我被黑心中介騙來泰國打工谨究, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留恩袱,地道東北人。 一個月前我還...
    沈念sama閱讀 48,216評論 3 371
  • 正文 我出身青樓胶哲,卻偏偏與公主長得像畔塔,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子鸯屿,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,969評論 2 355