SVM原理與C++的Eigen庫實(shí)現(xiàn)

具體的SVM詳解參考https://blog.csdn.net/c406495762/article/details/78072313, 講的特別詳細(xì), 如下代碼也是基于該鏈接中的講解而實(shí)現(xiàn)的

//model.h
#include <iostream>
#include "Eigen/Eigen"
#include<vector>
#include<string>
#include<fstream>
#include<sstream>
#include<iterator>
#include<algorithm>
#include<regex>
#include<set>
#include<unordered_map>
#include <assert.h>
#include <random>
#include <python2.7/Python.h>
#include <stdlib.h>
using namespace Eigen;
using std::pair;
using std::vector;
using std::cout;
using std::endl;
using std::ios;
using std::ifstream;
using std::string;
using std::regex;
using std::iterator;
using std::stringstream;
using std::sregex_token_iterator;
using std::set;
using std::istringstream;
using std::istream_iterator;
using std::unordered_map;
using std::make_pair;
using std::begin;
using std::end;
using std::min;
using std::max;
using std::abs;

using _MAT_VEC=pair<MatrixXf,VectorXf>;
using _PARAM=pair<VectorXf,float>;

#define filename "data/sample"
#define C 0.6f
#define Threshold 0.001f
#define Max_iter 40
#define Alpha_threshold 0.00001f
#define RBF_var 1.3




#ifndef DATA_HANDLE_
#define DATA_HANDLE_
inline _MAT_VEC load_data();
inline pair<_MAT_VEC,_MAT_VEC> train_test_split(const _MAT_VEC,float);
_PARAM SMO(const MatrixXf&,const VectorXf&);
inline VectorXf cal_weight(const MatrixXf&,const VectorXf&,const VectorXf&);
VectorXf kernel_RBF(MatrixXf,VectorXf,float);
inline void python_plot(const VectorXf&,float b);
#endif
//model.cpp
#pragma once
#include "model.h"

inline _MAT_VEC load_data(){
    ifstream ifile(filename,ios::in);
    if(!ifile.is_open()){
        cout<<"failed to open: "<<filename<<endl;
    }
    string line;
    vector<vector<float> > tempX;
    vector<float> tempY;
    while(getline(ifile,line)){
        istringstream iss( line );
        vector<float> nums{istream_iterator<float>( iss ), std::istream_iterator<float>()};
        tempY.push_back(nums[nums.size()-1]);
        nums.pop_back();
        tempX.push_back(nums);
    }
    assert(tempX.size()==tempY.size());
    int matC=tempX[0].size(),matR=tempX.size();
    MatrixXf X(matR,matC);
    VectorXf Y(matR);
    for(auto row=0;row<matR;++row){
        Y[row]=tempY[row];
        float *arr=tempX[row].data();
        X.row(row)=Map<VectorXf>(arr,matC);
    }
    return make_pair(X,Y);
}

inline pair<_MAT_VEC,_MAT_VEC> train_test_split(const _MAT_VEC &raw_data,float percent=0.8){
    int new_row=raw_data.first.rows()*percent;
    return make_pair(
        make_pair(raw_data.first.topRows(new_row),raw_data.second.topRows(new_row)),
        make_pair(raw_data.first.topRows(raw_data.first.rows()-new_row),raw_data.second.topRows(raw_data.first.rows()-new_row))
        );

}

int random_choice(int min,int max,int current){
    std::random_device seeder;
    std::mt19937 engine(seeder());
    std::uniform_int_distribution<int> dist(min, max);
    int rand=dist(engine);
    while(rand==current){
        rand=dist(engine);
    }
    return rand;
}

_PARAM SMO(const MatrixXf &features,const VectorXf& labels){
    int rows=features.rows(),cols=features.cols();
    int b=0,iter_count=0;
    VectorXf alphas(rows);
    alphas.setZero();
    while(iter_count<=Max_iter){
        int pair_alpha_changed_count=0;
        for(int i=0;i<rows;++i){

            //計(jì)算拉格朗日表示fX_i和損失E_i
            float fX_i=(alphas.cwiseProduct(labels)).transpose()*(features*(features.row(i).transpose()))+b;
            //float fX_i=(alphas.cwiseProduct(labels)).transpose()*(kernel_RBF(features,VectorXf(features.row(i)),1.3f))+b;
            float E_i=fX_i-labels(i);
            if((labels(i)*E_i<-Threshold && alphas(i)<C) || (labels(i)*E_i>Threshold && alphas(i)>0)){
                
                //隨機(jī)挑選j并計(jì)算拉格朗日fX_j和E_j
                int j=random_choice(0,rows-1,i);
                float fX_j=(alphas.cwiseProduct(labels)).transpose()*(features*(features.row(j).transpose()))+b;
                //float fX_j=(alphas.cwiseProduct(labels)).transpose()*(kernel_RBF(features,VectorXf(features.row(j)),1.3f))+b;
                float E_j=fX_j-labels(j);

                //保留i和j所對應(yīng)的舊的alphas
                float alphas_old_i=alphas(i);
                float alphas_old_j=alphas(j);

                //計(jì)算上下界, 如果上下界相同則重新選擇
                float zero=0;
                float L=(labels(i)!=labels(j))?max(zero,alphas(j)-alphas(i)):max(zero,alphas(i)+alphas(j)-C);
                float H=(labels(i)!=labels(j))?min(C,C+alphas(j)-alphas(i)):min(C,alphas(i)+alphas(j));
                if(L==H) continue;

                //計(jì)算步長eta, 大于等于0說明不是支持向量
                float eta=static_cast<float>(2.0f*features.row(i)*(features.row(j).transpose()))-static_cast<float>(features.row(i)*(features.row(i).transpose()))-static_cast<float>(features.row(j)*(features.row(j).transpose()));
                if(eta>=0) continue;

                //更新alphas_j并對alphas_j加窗
                alphas(j)-=labels(j)*(E_i-E_j)/eta;
                alphas(j)=alphas(j)>H?H:(alphas(j)<L?L:alphas(j));

                //如果alphas_j變化太小則不更新  
                if(abs(alphas(j)-alphas_old_j)<Alpha_threshold) continue;
                
                //更新alphas_i
                alphas(i)+=static_cast<float>(labels.row(j)*labels.row(i))*(alphas_old_j-alphas(j));
                
                //更新b_1和b_2
                float b_1 = b - E_i- labels(i)*(alphas(i)-alphas_old_i)*features.row(i)*features.row(i).transpose() - labels(j)*(alphas(j)-alphas_old_j)*features.row(i)*features.row(j).transpose();
                float b_2 = b - E_j- labels(i)*(alphas(i)-alphas_old_i)*features.row(i)*features.row(j).transpose() - labels(j)*(alphas(j)-alphas_old_j)*features.row(j)*features.row(j).transpose();

                //更新b
                b=(0<alphas(i)&&C>alphas(i))?b_1:
                    ((0<alphas(j)&&C>alphas(j))?b_2:
                        (b_1+b_2)/2);
                
                ++pair_alpha_changed_count;
            }
        }
        cout<<"第 "<<iter_count<<" 次迭代. 在這次迭代中, 共有 "<<pair_alpha_changed_count<<" 個(gè)SMO對被改變"<<endl;
        if(!pair_alpha_changed_count) ++iter_count;
        else iter_count=0;
    }
    return make_pair(alphas,b);
}

inline VectorXf cal_weight(const MatrixXf &features,const VectorXf &labels,const VectorXf &alphas){
    int rows=features.rows(),cols=features.cols();
    VectorXf weight=labels.cwiseProduct(alphas).replicate(1,cols).cwiseProduct(features).colwise().sum().transpose();
    return weight;
}

//徑向基核函數(shù)
VectorXf kernel_RBF(MatrixXf features,VectorXf line,float var=RBF_var){
    features.rowwise()-=line.transpose();
    ArrayXf kV=ArrayXf((features*features.transpose()).diagonal()/(pow(var,2))*(-1));
    return kV.exp();
}


//調(diào)用python代碼作圖
void python_plot(VectorXf &weight,float b){
    setenv("PYTHONPATH",".",1); //將python路徑設(shè)為當(dāng)前工作路徑
    Py_Initialize();

    PyObject* myModuleString = PyString_FromString((char*)"svm");
    PyObject* myModule = PyImport_Import(myModuleString);

    PyObject* myFunction = PyObject_GetAttrString(myModule,(char*)"plot_points");

    //通過元組傳入?yún)?shù)
    PyObject *pArgs = PyTuple_New(3);
    PyTuple_SetItem(pArgs,0, PyFloat_FromDouble(static_cast<double>(weight(0))));
    PyTuple_SetItem(pArgs,1, PyFloat_FromDouble(static_cast<double>(weight(1))));
    PyTuple_SetItem(pArgs,2, PyFloat_FromDouble(static_cast<double>(b)));

    //調(diào)用函數(shù)
    PyObject_CallObject(myFunction, pArgs);
    Py_Finalize();
}
//main.cpp
#include "Eigen/Dense"
#include "model.h"
#include "model.cpp"
int main(){
    _MAT_VEC Train;
    _MAT_VEC Test;
    _MAT_VEC total_data=load_data();
    {
        auto whole_data=train_test_split(total_data);
        Train=move(whole_data.first);
        Test=move(whole_data.second);
    }
    assert(Train.first.rows()==Train.second.rows());
    _PARAM alpha_b=SMO(Train.first,Train.second);
    
    VectorXf weight=cal_weight(Train.first,Train.second,alpha_b.first);

    ArrayXf arr((Test.first*weight+alpha_b.second*VectorXf::Ones(Test.first.rows())).cwiseProduct(Test.second));

    cout<<"權(quán)重大小w為: "<<weight.transpose()<<"偏置項(xiàng)b為: "<<alpha_b.second<<endl;

    cout<<"訓(xùn)練樣本大小: "<<Train.first.rows()<<endl<<"測試樣本大小: "<<Test.first.rows()<<endl;
    cout<<"測試樣本正確的數(shù)量: "<<(arr>=0).count()<<endl;

    python_plot(weight,alpha_b.second);
}
#svm.py
# -*- coding:UTF-8 -*-


def plot_points(a1,a2,b):
    import matplotlib.pyplot as plt
    import numpy as np
    import types

    fileName=''
    features = []; labels = []
    fr = open(fileName)
    for line in fr.readlines():
        lineArr = line.strip().split('\t')
        features.append([float(lineArr[0]), float(lineArr[1])])
        labels.append(float(lineArr[2]))

    data_plus = []
    data_minus = []
    for i in range(len(features)):
        if labels[i] > 0:
            data_plus.append(features[i])
        else:
            data_minus.append(features[i])
    data_plus_np = np.array(data_plus)
    data_minus_np = np.array(data_minus)
    plt.scatter(np.transpose(data_plus_np)[0], np.transpose(data_plus_np)[1], s=30, alpha=0.7,color='red')
    plt.scatter(np.transpose(data_minus_np)[0], np.transpose(data_minus_np)[1], s=30, alpha=0.7,color='blue')
    x1 = max(features)[0]
    x2 = min(features)[0]
    y1, y2 = (-b- a1*x1)/a2, (-b - a1*x2)/a2
    plt.plot([x1, x2], [y1, y2])
    plt.title("Sample data and the svm linear discriminant")
    plt.show()

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市射富,隨后出現(xiàn)的幾起案子懂拾,更是在濱河造成了極大的恐慌,老刑警劉巖墨坚,帶你破解...
    沈念sama閱讀 207,113評論 6 481
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異映挂,居然都是意外死亡泽篮,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,644評論 2 381
  • 文/潘曉璐 我一進(jìn)店門柑船,熙熙樓的掌柜王于貴愁眉苦臉地迎上來帽撑,“玉大人,你說我怎么就攤上這事鞍时】骼” “怎么了?”我有些...
    開封第一講書人閱讀 153,340評論 0 344
  • 文/不壞的土叔 我叫張陵逆巍,是天一觀的道長及塘。 經(jīng)常有香客問我,道長锐极,這世上最難降的妖魔是什么笙僚? 我笑而不...
    開封第一講書人閱讀 55,449評論 1 279
  • 正文 為了忘掉前任,我火速辦了婚禮灵再,結(jié)果婚禮上肋层,老公的妹妹穿的比我還像新娘。我一直安慰自己翎迁,他們只是感情好栋猖,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,445評論 5 374
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著鸳兽,像睡著了一般掂铐。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 49,166評論 1 284
  • 那天全陨,我揣著相機(jī)與錄音爆班,去河邊找鬼。 笑死辱姨,一個(gè)胖子當(dāng)著我的面吹牛柿菩,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播雨涛,決...
    沈念sama閱讀 38,442評論 3 401
  • 文/蒼蘭香墨 我猛地睜開眼枢舶,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了替久?” 一聲冷哼從身側(cè)響起凉泄,我...
    開封第一講書人閱讀 37,105評論 0 261
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎蚯根,沒想到半個(gè)月后后众,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 43,601評論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡颅拦,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,066評論 2 325
  • 正文 我和宋清朗相戀三年蒂誉,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片距帅。...
    茶點(diǎn)故事閱讀 38,161評論 1 334
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡右锨,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出碌秸,到底是詐尸還是另有隱情绍移,我是刑警寧澤,帶...
    沈念sama閱讀 33,792評論 4 323
  • 正文 年R本政府宣布哮肚,位于F島的核電站登夫,受9級特大地震影響广匙,放射性物質(zhì)發(fā)生泄漏允趟。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,351評論 3 307
  • 文/蒙蒙 一鸦致、第九天 我趴在偏房一處隱蔽的房頂上張望潮剪。 院中可真熱鬧,春花似錦分唾、人聲如沸抗碰。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,352評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽弧蝇。三九已至,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間看疗,已是汗流浹背沙峻。 一陣腳步聲響...
    開封第一講書人閱讀 31,584評論 1 261
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留两芳,地道東北人摔寨。 一個(gè)月前我還...
    沈念sama閱讀 45,618評論 2 355
  • 正文 我出身青樓,卻偏偏與公主長得像怖辆,于是被迫代替她去往敵國和親是复。 傳聞我的和親對象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,916評論 2 344