多分類(lèi)
背景:多分類(lèi)是指具有兩類(lèi)以上的分類(lèi)任務(wù); 例如涡匀,分類(lèi)一組可能是橘子盯腌,蘋(píng)果或梨的水果圖像。本文旨在為大家提供一段即寫(xiě)即用的代碼陨瘩,跳過(guò)對(duì)原理的解說(shuō)腕够,直接上手跑一版baseline。當(dāng)然舌劳,后續(xù)的優(yōu)化任務(wù)還是需要一定的算法基礎(chǔ)帚湘,比如模型參數(shù)以及性能參數(shù)優(yōu)化。
初步結(jié)論
本數(shù)據(jù)集上甚淡, 在迭代次數(shù)量級(jí)基本一致的情況下大诸,lightgbm表現(xiàn)更優(yōu):樹(shù)的固有多分類(lèi)特性使得不需要OVR或者OVO式的開(kāi)銷(xiāo),而且lightgbm本身就對(duì)決策樹(shù)進(jìn)行了優(yōu)化,因此性能和分類(lèi)能力都較好资柔。
模型 | AUC | 精確率 | 耗時(shí)(s) |
---|---|---|---|
linearSVC | 0.9169 | 0.6708 | 883 |
LR | 0.9226 | 0.6571 | 944 |
lightgbm | 0.9332 | 0.6947 | 600 |
數(shù)據(jù)定義
一個(gè)樣本僅對(duì)應(yīng)一個(gè)標(biāo)簽
數(shù)據(jù)量: 800M(32w樣本量 * 929 特征)
數(shù)據(jù)格式
特征1|特征2|...|特征N|label
評(píng)測(cè)算法
- LR
- linearSVC
- lightgbm
notice: 樹(shù)模型是天生的多分類(lèi)模型焙贷,LR、linearSVC則是基于“One-Vs-The-Rest”贿堰,即為N類(lèi)訓(xùn)練N個(gè)模型辙芍,為樣本選擇一個(gè)最佳類(lèi)別。
參考:多分類(lèi)和多標(biāo)簽算法
版本
系統(tǒng) 64bit centOS
sklearn 0.19.1
代碼走讀
import pandas as pd
import numpy as np
import time
import logging
import os, sys
import psutil
import lightgbm as lgb
from datetime import datetime
from itertools import cycle
from sklearn import svm
from sklearn.metrics import *
from sklearn.cross_validation import *
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.externals import joblib
from scipy import interp
# 循環(huán)讀取多個(gè)文件
path = "./data/d20190416/"
os.chdir(path)
files = os.listdir(path)
files_csv = list(filter(lambda x: x[:4]=='part' , files))[:200]
data_list = []
for file in files_csv:
tmp = pd.read_csv(path + file, sep = '|', header=None)
data_list.append(tmp)
data_set = pd.concat(data_list, axis = 0)
del data_list
#配置列名
sample_cnt, col_cnt = data_set.shape
cols = ["x_%d"%(i) for i in range(col_cnt - 1)]
cols.append("y")
data_set.columns = cols
# 數(shù)據(jù)預(yù)覽,事先準(zhǔn)備好one-hot特征羹与,最后一列為label={0,1,2,3}
# >> 0|1|1|1|0|0....|2
- 模型 OneVsRestClassifier
元分類(lèi)器 svm.LinearSVC
說(shuō)明:OneVsRestClassifier模塊, 是通過(guò)將分類(lèi)問(wèn)題分解為二進(jìn)制分類(lèi)問(wèn)題來(lái)解決沸手,因此構(gòu)建樣本時(shí)需要將label列轉(zhuǎn)為二進(jìn)制格式 e.g. 2 -> [0, 0, 1, 0] 0 ->[1, 0, 0, 0]
性能:?jiǎn)魏?83s, 迭代1000次
AUC : 0.9169
精確率:0.6708
#########################################################################################
# 模型 OneVsRestClassifier
# 元分類(lèi)器 svm.LinearSVC
#########################################################################################
y = label_binarize(data_set["y"], classes=[0,1,2,3])
X = data_set.iloc[:, :-1]
# 隨機(jī)化數(shù)據(jù),并劃分訓(xùn)練數(shù)據(jù)和測(cè)試數(shù)據(jù)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,random_state=0)
#訓(xùn)練
model = OneVsRestClassifier(svm.LinearSVC(random_state = 0, verbose = 1))
btime = datetime.now()
model.fit(X_train, y_train)
print 'all tasks done. total time used:%s s.\n\n'%((datetime.now() - btime).total_seconds())
# 評(píng)價(jià)
y_score = model.decision_function(X_test) # 計(jì)算屬于各個(gè)類(lèi)別的概率注簿,返回值的shape = [n_samples, n_classes]
# 1契吉、調(diào)用函數(shù)計(jì)算micro類(lèi)型的AUC
print '調(diào)用函數(shù)auc:', roc_auc_score(y_test, y_score, average='micro')
# 2、混淆矩陣
y_pred = model.predict(X_test) # 預(yù)測(cè)屬于哪個(gè)類(lèi)別
confusion_matrix(y_test.argmax(axis=1), y_pred1.argmax(axis=1)) # 需要0诡渴、1捐晶、2、3而不是OH編碼格式
# 3妄辩、經(jīng)典-精確率惑灵、召回率、F1分?jǐn)?shù)
precision_score(y_test, y_pred,average='micro')
recall_score(y_test, y_pred,average='micro')
f1_score(y_test, y_pred,average='micro')
# 4眼耀、模型報(bào)告
classification_report(y_test, y_pred, digits=4)
''' precision recall f1-score support
0 0.78 0.85 0.81 42276
1 0.83 0.66 0.74 18960
2 0.59 0.34 0.44 13591
3 0.59 0.35 0.44 13170
4 0.00 0.00 0.00 8151
avg / total 0.67 0.60 0.62 96148
'''
# 保存模型
joblib.dump(model, './model/LinearSVC.pkl')
- 模型(分類(lèi)器) LR
說(shuō)明:LogisticRegression模塊英支,設(shè)置multi_class='ovr',會(huì)訓(xùn)練出“類(lèi)別數(shù)”個(gè)分類(lèi)器哮伟,構(gòu)建樣本時(shí)需要原始label即可
性能:?jiǎn)魏?44s , 迭代1000次
AUC : 0.9226
精確率:0.6571
#########################################################################################
# 模型 LogisticRegression(random_state=0, solver='sag',multi_class='ovr', verbose = 1)
#########################################################################################
from sklearn.linear_model import LogisticRegression
# 準(zhǔn)備數(shù)據(jù)
X = data_set.iloc[:, :-1]
X_train, X_test, y_train, y_test = train_test_split(X, data_set["y"], test_size=0.3,random_state=0)
# 訓(xùn)練
btime = datetime.now()
lr_clf = LogisticRegression(random_state=0, solver='sag',multi_class='ovr', verbose = 1)
lr_clf.fit(X_train, y_train)
print 'all tasks done. total time used:%s s.\n\n'%((datetime.now() - btime).total_seconds())
# 1干花、AUC
y_pred_pa = lr_clf.predict_proba(X_test)
y_test_oh = label_binarize(y_test, classes=[0,1,2,3])
print '調(diào)用函數(shù)auc:', roc_auc_score(y_test_oh, y_pred_pa, average='micro')
# 2、混淆矩陣
y_pred = lr_clf.predict(X_test)
confusion_matrix(y_test, y_pred_1)
# 3楞黄、經(jīng)典-精確率池凄、召回率、F1分?jǐn)?shù)
precision_score(y_test, y_pred_1,average='micro')
recall_score(y_test, y_pred_1,average='micro')
f1_score(y_test, y_pred_1,average='micro')
# 4鬼廓、模型報(bào)告
print(classification_report(y_test, y_pred , digits=4))
# 保存模型
joblib.dump(lr_clf, './model/lr_clf.pkl')
- 模型(分類(lèi)器) lightgbm
說(shuō)明:樹(shù)的輸出本身就可以是多分類(lèi)肿仑,應(yīng)該是操作最簡(jiǎn)單的,構(gòu)建樣本時(shí)需要原始label即可
性能:?jiǎn)魏?00s, 迭代200次
AUC : 0.9332
精確率:0.6947
#########################################################################################
# 模型 lightgbm
#########################################################################################
import lightgbm as lgb
# 準(zhǔn)備數(shù)據(jù)
X = data_set.iloc[:, :-1]
X_train, X_test, y_train, y_test = train_test_split(X, data_set["y"], test_size=0.3,random_state=0)
# 訓(xùn)練
btime = datetime.now()
train_data=lgb.Dataset(X_train,label=y_train)
validation_data=lgb.Dataset(X_test,label=y_test)
params={
'learning_rate':0.1,
'lambda_l1':0.1,
'lambda_l2':0.2,
'max_depth':6,
'objective':'multiclass',
'num_class':4,
}
clf=lgb.train(params,train_data,valid_sets=[validation_data])
print 'all tasks done. total time used:%s s.\n\n'%((datetime.now() - btime).total_seconds())
# 1碎税、AUC
y_pred_pa = clf.predict(X_test) # !!!注意lgm預(yù)測(cè)的是分?jǐn)?shù)尤慰,類(lèi)似 sklearn的predict_proba
y_test_oh = label_binarize(y_test, classes= [0,1,2,3])
print '調(diào)用函數(shù)auc:', roc_auc_score(y_test_oh, y_pred_pa, average='micro')
# 2、混淆矩陣
y_pred = y_pred_pa .argmax(axis=1)
confusion_matrix(y_test, y_pred )
# 3雷蹂、經(jīng)典-精確率伟端、召回率、F1分?jǐn)?shù)
precision_score(y_test, y_pred,average='micro')
recall_score(y_test, y_pred,average='micro')
f1_score(y_test, y_pred,average='micro')
# 4萎河、模型報(bào)告
print(classification_report(y_test, y_pred))
# 保存模型
joblib.dump(clf, './model/lgb.pkl')