1. KNN (k-Nearest Neighbor)
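The k-nearest neighbor (kNN) algorithm is a basic method for classification and regression. It assumes a training data set whose instance classes are already known. To classify a new instance, it takes the k training instances nearest to it and predicts the class from theirs, for example by majority vote. kNN therefore has no explicit learning phase. In effect, it uses the training data set to partition the feature space, and that partition serves as the classification model.
The three basic elements of kNN are the choice of k, the distance metric, and the classification decision rule.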
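1.1 Distance metric
The distance between two instance points in the feature space reflects how similar they are. Common distance metrics include the Euclidean distance and the Lp distance (for more on distance metrics, see the blog post "From the K-Nearest Neighbor Algorithm and Distance Metrics to KD Trees and the SIFT+BBF Algorithm" by July_ on 博客園). Different distance metrics can yield different nearest neighbors for the same point.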
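To make the metrics concrete, here is a minimal sketch of the Lp (Minkowski) distance and its limiting case, the Chebyshev distance. The function names lp_distance and chebyshev_distance are my own for illustration and are not part of the code later in this post, which hard-codes the Euclidean distance.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Lp (Minkowski) distance between two points of equal dimension, for p >= 1.
// p = 1 is the Manhattan distance, p = 2 the Euclidean distance.
double lp_distance(const std::vector<double>& a, const std::vector<double>& b, double p) {
    assert(a.size() == b.size() && p >= 1.0);
    double sum = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i)
        sum += std::pow(std::fabs(a[i] - b[i]), p);
    return std::pow(sum, 1.0 / p);
}

// Chebyshev distance: the limit of the Lp distance as p goes to infinity,
// i.e. the largest coordinate-wise difference.
double chebyshev_distance(const std::vector<double>& a, const std::vector<double>& b) {
    assert(a.size() == b.size());
    double best = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i)
        best = std::max(best, std::fabs(a[i] - b[i]));
    return best;
}
```

For the points (0, 0) and (3, 4), these give 7 for p = 1, 5 for p = 2, and 4 for the Chebyshev distance, so which training points count as "nearest" can change with the metric.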
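1.2 Choosing k
If k is small, the prediction uses training instances from a small neighborhood, so only the instances close to the input affect the result. The prediction then becomes very sensitive to those nearby points: if one of them happens to be noise, the prediction is wrong. In other words, decreasing k makes the overall model more complex and prone to overfitting.
If k is large, training instances far from (and dissimilar to) the input also influence the prediction and can make it wrong. Increasing k makes the overall model simpler. In the extreme case where k equals the number of training instances, every input is assigned the most frequent class in the training set.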
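1.3 Classification decision rule
The usual choice is majority voting. The votes can also be weighted by distance, so that closer neighbors count for more, when deciding which class the input instance belongs to.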
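The two decision rules can be sketched as follows. This is only an illustration under my own assumptions: the Neighbor struct, the vote function, and the inverse-distance weight 1 / (distance + eps) are hypothetical choices, and the code later in this post uses the plain majority vote only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One neighbor returned by the search: its class label and its distance to the query.
struct Neighbor {
    int label;
    double distance;
};

// Returns the label with the largest total vote among the neighbors.
// weighted == false: every neighbor casts one vote (plain majority rule).
// weighted == true:  a vote counts 1 / (distance + eps), so closer neighbors weigh more.
// Ties are resolved in favor of the smallest label.
int vote(const std::vector<Neighbor>& neighbors, int num_classes, bool weighted) {
    assert(num_classes > 0);
    const double eps = 1e-9;  // avoids division by zero for an exact match
    std::vector<double> score(num_classes, 0.0);
    for (const Neighbor& n : neighbors) {
        assert(n.label >= 0 && n.label < num_classes);
        score[n.label] += weighted ? 1.0 / (n.distance + eps) : 1.0;
    }
    int best = 0;
    for (int c = 1; c < num_classes; ++c)
        if (score[c] > score[best]) best = c;
    return best;
}
```

With one neighbor of class 0 at distance 0.1 and two neighbors of class 1 at distances 1.0 and 1.2, the majority rule picks class 1, while the distance-weighted rule picks class 0.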
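2. kd-tree
When implementing kNN, the main concern is how to search the training data for the k nearest neighbors quickly. This matters most when the feature space has many dimensions or the training set is large. To make the search more efficient, the training data can be stored in a special structure that reduces the number of distance computations. One such structure is the kd-tree.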
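As a rough idea of what such a structure looks like, here is a compact sketch of kd-tree construction by median split. It is not the implementation given later in this post: KdNode, Point, and build are names I made up, and it finds the median with std::nth_element instead of fully sorting each subset.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

using Point = std::vector<double>;

struct KdNode {
    Point point;                    // the sample stored at this node
    std::size_t axis = 0;           // dimension this node splits on
    std::unique_ptr<KdNode> left;   // samples at or below the median along axis
    std::unique_ptr<KdNode> right;  // samples at or above the median along axis
};

// Builds a kd-tree over pts[lo, hi), cycling through the dimensions with depth.
// The range is reordered in place; std::nth_element puts the median at
// position mid in linear time on average.
std::unique_ptr<KdNode> build(std::vector<Point>& pts, std::size_t lo, std::size_t hi,
                              std::size_t depth) {
    if (lo >= hi) return nullptr;
    assert(!pts[lo].empty());
    const std::size_t axis = depth % pts[lo].size();
    const std::size_t mid = lo + (hi - lo) / 2;
    std::nth_element(pts.begin() + lo, pts.begin() + mid, pts.begin() + hi,
                     [axis](const Point& a, const Point& b) { return a[axis] < b[axis]; });
    std::unique_ptr<KdNode> node(new KdNode());
    node->point = pts[mid];
    node->axis = axis;
    node->left = build(pts, lo, mid, depth + 1);
    node->right = build(pts, mid + 1, hi, depth + 1);
    return node;
}
```

Called as build(pts, 0, pts.size(), 0) on the six 2-D points (2,3), (5,4), (9,6), (4,7), (8,1), (7,2), this puts (7,2) at the root, with (5,4) and (9,6) as its left and right children.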
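The k-nearest-neighbor search works as follows:
Input: a constructed kd-tree; the target point x; (auxiliary structure: an array holding at most k candidate nodes, kept sorted by distance to x; a max-heap works equally well, with the heap top playing the role of the array's tail)
Output: the k nearest neighbors of x
Common operation P: when visiting a node, if the array holds fewer than k nodes, add the node to the array. If the array already holds k nodes, check whether the current node is closer to x than the array's tail element (the farthest candidate); if so, replace the tail with the current node and re-sort the array.
(1) Starting from the root, descend the kd-tree recursively: if the coordinate of x in the current split dimension is less than that of the split point, move to the left child, otherwise move to the right child, until a leaf is reached. Perform common operation P.
(2) Back up recursively, doing the following at each node:
(a) Perform common operation P.
(b) Check whether the region of the sibling of the child just visited could contain a point closer than the array's tail element, or whether the array is not yet full. Concretely, check whether the sibling's region intersects the ball centered at the target point whose radius is the distance from the target point to the tail element.
If it intersects, or the array is not yet full, perform (1) with the sibling as the root.
(3) When the backtracking reaches the root, the search ends; the instances in the array are the k nearest neighbors.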
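Common operation P and the check in step (2)(b) can be isolated in a small helper. The sketch below is my own illustration, not the code from this post: the class BoundedMaxHeap and its member names are hypothetical, it uses a std::priority_queue as the max-heap variant of the auxiliary structure, and it replaces the ball-region intersection test with the simpler distance to the splitting hyperplane. The ball can only reach the sibling region if it crosses that hyperplane, so the simpler test never skips a region that needs searching.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

// Keeps the k closest candidates seen so far; the farthest of them is on top.
class BoundedMaxHeap {
public:
    explicit BoundedMaxHeap(std::size_t k) : k_(k) { assert(k > 0); }

    // Operation P: offer a candidate given its distance to the query and its sample index.
    void offer(double dist, int index) {
        if (heap_.size() < k_) {
            heap_.push(std::make_pair(dist, index));
        } else if (dist < heap_.top().first) {
            heap_.pop();  // drop the current farthest candidate
            heap_.push(std::make_pair(dist, index));
        }
    }

    bool full() const { return heap_.size() == k_; }

    // Distance of the farthest kept candidate; only valid when the heap is non-empty.
    double worst() const { return heap_.top().first; }

    // Step (2)(b): search the sibling subtree if the heap is not full, or if the
    // splitting hyperplane is closer to the query than the farthest kept candidate.
    bool must_visit_sibling(double query_coord, double split_value) const {
        return !full() || std::fabs(query_coord - split_value) < worst();
    }

    // Returns the kept sample indices, nearest first, and empties the heap.
    std::vector<int> drain() {
        std::vector<int> out;
        while (!heap_.empty()) {
            out.push_back(heap_.top().second);
            heap_.pop();
        }
        std::reverse(out.begin(), out.end());
        return out;
    }

private:
    std::size_t k_;
    std::priority_queue<std::pair<double, int>> heap_;  // ordered by distance, largest on top
};
```

A search routine would call offer at every visited node and must_visit_sibling while backing up, passing the query's coordinate in the node's split dimension and the node's split value.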
Note: a few days ago I finished a machine learning course project in which I implemented the KNN algorithm on the iris data set. This post is a summary of it.
Code:
The code is not very friendly!!!
kd_tree.h
#include<stdlib.h>
#include<vector>
#include<string>
#include<math.h>
#include<algorithm>
#include<iostream>
using namespace std;
#define K 4 //dimension of the input data
class kd_tree_node{
    //member objects
public:
    vector<float> node_data; //sample data stored at this node
    string node_type; //"leaf" for a leaf node, "body" for an internal (branch) node
    int numpoints; //number of training samples, i.e. how many nodes the whole tree has
    int index; //index of this node's sample in the original data
    int splitdim; //dimension chosen for the split at this node (1-based; -1 for a leaf)
    double splitval; //split value chosen at this node
    kd_tree_node* left_node, *right_node,*parents;
};
vector<int> median_data(vector<vector<float>>data, vector<int> index, int splitdim_num);//sorting function; returns the index sequence in sorted order
//build the kd_tree recursively
kd_tree_node* create_kd_tree(vector<vector<float>>data,int split_dim_num,vector<int>index,kd_tree_node *parent){
    //initialize: create a kd_tree_node that is the root of this subtree
    kd_tree_node * root = new kd_tree_node;
    root->numpoints = data.size();
    //termination condition: a single sample becomes a leaf
    if (index.size() == 1){
        //set the member variables
        root->left_node = NULL;
        root->right_node = NULL;
        root->node_type = "leaf";
        root->splitdim = -1;
        root->splitval = 0;
        root->parents = parent;
        root->node_data = data[index[0]];
        root->index = index[0];
    }
    else{
        //sort along the split dimension, then split around the median
        index = median_data(data, index, split_dim_num);
        int length = index.size();
        vector<int>left, right;
        for (int i = 0; i < length; i++){
            if (i < length / 2)
                left.push_back(index[i]);
            else if (i > length / 2)
                right.push_back(index[i]);
        }
        //set the member variables; the split dimension cycles through 1..K
        if (left.size() >= 1){
            root->left_node = create_kd_tree(data, split_dim_num % K + 1, left, root);
        }
        else
            root->left_node = NULL;
        if (right.size() >= 1){
            root->right_node = create_kd_tree(data, split_dim_num % K + 1, right, root);
        }
        else
            root->right_node = NULL;
        root->node_type = "body";
        root->splitdim = split_dim_num;
        root->splitval = data[index[length/2]][split_dim_num - 1]; //the median sample's value in the split dimension
        root->parents = parent;
        root->node_data = data[index[length/2]];
        root->index = index[length / 2];
    }
    return root;
}
//sorting function; returns the index sequence in sorted order
vector<int> median_data(vector<vector<float>>data, vector<int> index, int splitdim_num){
    //stable ascending sort of the indices by each sample's value in the split dimension
    //(same ordering as a hand-written bubble sort, but O(n log n))
    int d = splitdim_num - 1;
    stable_sort(index.begin(), index.end(),
        [&data, d](int a, int b){ return data[a][d] < data[b][d]; });
    return index;
}
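//k-nearest-neighbor search algorithm
/*****Common operation P: when visiting a node, if the max-heap holds fewer than k nodes, add the node to the heap; if the heap already holds k nodes, check whether the current node is closer to x than the heap top, and if so replace the heap top with the current node and re-adjust the heap.
(The code below keeps the candidates in a vector sorted by distance; its last element plays the role of the heap top.)
(1) Starting from the root, descend the kd-tree recursively: if the coordinate of x in the current split dimension is less than that of the split point, move to the left child, otherwise move to the right child, until a leaf is reached. Perform common operation P.
(2) Back up recursively, doing the following at each node:
(a) Perform common operation P.
(b) Check whether the region of the sibling of the child just visited could contain a point closer than the heap top, or whether the heap is not yet full. Concretely, check whether the sibling's region intersects the ball centered at the target point whose radius is the distance from the target point to the heap top.
If it intersects, or the heap is not yet full, perform (1) with the sibling as the root.
(3) When the backtracking reaches the root, the search ends; the instances in the heap are the k nearest neighbors.
****/
/*************
function: knn_k_search()
input:
test_data: the test sample
near_num: how many nearest neighbors to find
root: root node of the kd tree
output: an array with the indices, in the original data, of the near_num nearest neighbors found.
*************/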
vector<int> knn_k_search(vector<float>test_data, int near_num, kd_tree_node *root){
    vector<int> near_k_node_index(0); //indices of the k nearest neighbors
    vector<double>near_k_nodedist(0); //distances of the k nearest neighbors, kept in ascending order
    vector<kd_tree_node*> near_k_nodepoint; //kd_tree node pointers of the k nearest neighbors
    if (root->numpoints < near_num){
        cout << "do not have enough points" << endl;
        return near_k_node_index;
    }
    //first descend to a leaf, recording the search path
    kd_tree_node * leaf_node = NULL;
    int split_dim = 1;
    leaf_node = root;
    vector<kd_tree_node*>path;
    path.push_back(leaf_node);
    while (leaf_node->node_type != "leaf"){
        split_dim = leaf_node->splitdim;
        if (test_data[split_dim - 1] <= leaf_node->splitval)//choose a branch
            leaf_node = leaf_node->left_node;
        else{
            if (leaf_node->right_node == NULL)
                leaf_node = leaf_node->left_node;//only a left subtree exists, so descend into it
            else
                leaf_node = leaf_node->right_node;
        }
        path.push_back(leaf_node);
    }
    path.pop_back();
    //keep a copy of the path
    vector<kd_tree_node*>path_copy = path;
    //k-nearest-neighbor search by backtracking: find the K samples closest to the given test sample
    //distances are computed starting from the leaf node;
    //the leaf node that test_data falls into stays stored in leaf_node
    double dist1 = 0, max_dist = 0;
    //compute the Euclidean distance to the leaf
    for (int i = 0; i < test_data.size(); i++){
        dist1 += (leaf_node->node_data[i] - test_data[i])*(leaf_node->node_data[i] - test_data[i]);
    }
    dist1 = sqrt(dist1);
    max_dist = dist1;
    //push the leaf as the first candidate
    near_k_nodepoint.push_back(leaf_node);
    near_k_nodedist.push_back(max_dist);
    //pointer to the branch just visited
    kd_tree_node * rl_node = leaf_node; //marks that this branch has already been visited
    while (path.size() != 0){
        //backtrack: take the top of the search stack (not necessarily the parent node)
        kd_tree_node *back_point = path[path.size() - 1];
        path.pop_back();
        int split_s = back_point->splitdim - 1;
        double dist2 = 0;
        for (int i = 0; i < test_data.size(); i++){
            dist2 += (back_point->node_data[i] - test_data[i])*(back_point->node_data[i] - test_data[i]);
        }
        dist2 = sqrt(dist2);
        //decide whether to add the node to the candidate list: if the list is not full, insert it and update the maximum distance; if it is full, insert only when the node is closer than the current maximum
        if (near_k_nodepoint.size() == near_num && dist2 < max_dist)//list is full and this node is closer than the farthest candidate
        {
            near_k_nodepoint.pop_back();
            near_k_nodedist.pop_back();
            //the list now has room for one more
        }
        if (near_k_nodepoint.size() < near_num)//list not full: insert in sorted position
        {
            if (near_k_nodepoint.size() == 0){ //the list is empty
                near_k_nodepoint.push_back(back_point);
                near_k_nodedist.push_back(dist2);
                max_dist = dist2;
            }
            else{
                int i = 0;
                while (dist2>near_k_nodedist[i]){
                    i++;
                    if (i == near_k_nodepoint.size())
                        break;
                }
                //update the maximum distance
                max_dist = near_k_nodedist[near_k_nodedist.size() - 1];
                if (i == near_k_nodepoint.size())
                    max_dist = dist2;
                //insert before position i in both near_k_nodepoint and near_k_nodedist
                near_k_nodepoint.insert(near_k_nodepoint.begin() + i, back_point);
                near_k_nodedist.insert(near_k_nodedist.begin() + i, dist2);
            }
        }
        if (back_point->node_type == "leaf"){
            continue;//a leaf has no subtrees to explore; go on to the next stack entry
        }
        //distance from the test sample to this node's splitting hyperplane
        double dist3 = fabs(test_data[split_s] - back_point->node_data[split_s]);
        //decide whether the other branch must be searched
        if (near_k_nodepoint.size() < near_num || (dist3<max_dist)){
            //check whether back_point lies on the original search path of test_data
            bool on_path = false;
            for (int i = path_copy.size()-1; i >=0; i--){
                if (back_point == path_copy[i]){
                    on_path = true;
                }
            }
            if (on_path){
                //one child was already visited on the way down; push only the other one
                double query_val = test_data[split_s];
                double split_val = back_point->node_data[split_s];
                kd_tree_node *other = NULL;
                if (query_val <= split_val || back_point->right_node == NULL)
                    other = back_point->right_node;//the descent went left
                else
                    other = back_point->left_node;//the descent went right
                if (other != NULL)//no sibling means nothing is left to search here
                    path.push_back(other);
            }
            else{
                if (back_point->right_node != NULL) //push the right child onto the stack
                    path.push_back(back_point->right_node);
                if (back_point->left_node != NULL) //push the left child onto the stack
                    path.push_back(back_point->left_node);
            }
        }
    }
    //return the index vector
    for (int i = 0; i < near_k_nodepoint.size(); i++){
        near_k_node_index.push_back(near_k_nodepoint[i]->index);
    }
    return near_k_node_index;
}
knn.cpp
#include"kd_tree.h"
#include<fstream>
#include<string>
#define label_type 3 //there are three classes of samples
using namespace std;
string iris_name[label_type] = {"Iris-setosa","Iris-versicolor","Iris-virginica"}; //names of the three iris species
int main(){
    //read the data
    /**the data is split into a training file and a test file
    each record has five fields; the last one is the class the sample belongs to
    the values read are stored as data and labels, e.g. train_data and train_label
    fields are separated by spaces
    **/
    //training data
    string train_file = "train2.txt";
    ifstream ist(train_file.c_str());
    vector<vector<float>>train_data;
    vector<int>train_label;
    while (true){
        vector<float> single_data;
        for (int i = 0; i < K; i++){
            float temp = 0;
            ist >> temp;
            single_data.push_back(temp);
        }
        int label = 0;
        ist >> label;
        if (!ist) break; //end of file or malformed record: do not store a bogus sample
        train_label.push_back(label);
        train_data.push_back(single_data);
    }
    ist.close();
    //test data
    string test_file = "test2.txt";
    ifstream ist2(test_file.c_str());
    vector<vector<float>>test_data;
    vector<int>test_label;
    while (true){
        vector<float> single_data;
        for (int i = 0; i < K; i++){
            float temp = 0;
            ist2 >> temp;
            single_data.push_back(temp);
        }
        int label = 0;
        ist2 >> label;
        if (!ist2) break; //end of file or malformed record: do not store a bogus sample
        test_label.push_back(label);
        test_data.push_back(single_data);
    }
    ist2.close();
    //build the kd tree from the training data once; it does not depend on NUM
    kd_tree_node *iris_kd_tree = NULL;
    int numpoints = train_label.size();
    vector<int>index;
    for (int i = 0; i < numpoints; i++){
        index.push_back(i);
    }
    iris_kd_tree = create_kd_tree(train_data, 1, index, NULL);
    int NUM = 0; //NUM is the number of nearest neighbors used by kNN
    for (NUM = 1; NUM < 121; NUM++){
        //accuracy on the test samples
        int sum_num[label_type]; //total number of test samples of each class
        int right_num[label_type]; //number of correctly classified samples of each class
        int error_num[label_type]; //number of samples wrongly classified as each class
        //initialize
        for (int i = 0; i < label_type; i++){
            sum_num[i] = 0;
            right_num[i] = 0;
            error_num[i] = 0;
        }
        //k-nearest-neighbor search to decide each sample's class
        for (int i = 0; i < test_label.size(); i++){
            vector<int>k_index;
            vector<int>count_label(label_type, 0);
            k_index = knn_k_search(test_data[i], NUM, iris_kd_tree);//k-nearest-neighbor search
            for (int j = 0; j < k_index.size(); j++){
                int flag = train_label[k_index[j]];
                count_label[flag]++; //count how often each class appears among the k neighbors
            }
            //majority vote
            int max = count_label[0];
            int label_flag = 0;
            for (int j = 1; j < label_type; j++){
                if (max < count_label[j]){
                    max = count_label[j];
                    label_flag = j;
                }
            }
            if (label_flag == test_label[i]){
                right_num[test_label[i]]++;
            }
            else{
                error_num[label_flag]++;
            }
            sum_num[test_label[i]]++;
        }
        //collect the results and print them
        int sum = 0;
        int error = 0;
        for (int i = 0; i < label_type; i++){
            sum += sum_num[i];
            error += error_num[i];
        }
        cout << NUM << ":" << endl;
        for (int i = 0; i < label_type; i++){
            cout << iris_name[i] << " number of test samples: " << sum_num[i] << ", accuracy: " << right_num[i] / (sum_num[i] * 1.0) << ", samples misclassified as this class: " << error_num[i] << endl;
        }
        cout << "overall accuracy: " << 1-error*1.0/sum<<endl;
        cout << endl;
    }
    //draw the kd tree (optional)
    system("pause"); //Windows only: keeps the console window open
}