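I could not find an intuitive multi-label classification example online, so I wrote one myself. It should at least give you a concrete feel for the topic when you are just getting started. This is a very simple implementation of multi-label classification with scikit-learn.
Data preparation
This example uses the emotions dataset, stored in ARFF format. It has 593 instances and 78 columns in total. The first 72 columns are features and the last 6 columns are labels.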
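import arff  # assumes the liac-arff package, whose load() returns a dict with a 'data' key
import numpy as np
with open('emotions.arff', 'r') as f:  # text mode, since liac-arff reads str on Python 3
    dataset = arff.load(f)
data = np.array(dataset['data'], dtype=float)  # force a float dtype, otherwise clf.fit fails later (np.float was removed from recent NumPy)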
data.shape
# output
(593, 78)
# extract features: take the first 72 columns as features
data[:, :-6]
# output
array([[ 0.034741, 0.089665, 0.091225, ..., 0.245457, 0.105065,
0.405399],
[ 0.081374, 0.272747, 0.085733, ..., 0.343547, 0.276366,
0.710924],
[ 0.110545, 0.273567, 0.08441 , ..., 0.188693, 0.045941,
0.457372],
...,
[ 0.042903, 0.089283, 0.080263, ..., 0.366192, 0.289227,
0.66168 ],
[ 0.038987, 0.05957 , 0.082053, ..., 0.581526, 0.047156,
0.774458],
[ 0.084866, 0.192814, 0.084549, ..., 0.533746, 0.587807,
1.121553]])
# extract labels: take the last 6 columns as labels
data[:, -6:]
# output
array([[ 0., 1., 1., 0., 0., 0.],
[ 1., 0., 0., 0., 0., 1.],
[ 0., 1., 0., 0., 0., 1.],
...,
[ 0., 1., 1., 0., 0., 0.],
[ 0., 0., 0., 1., 1., 0.],
[ 0., 1., 1., 0., 0., 0.]])
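The column split above can be wrapped in a small helper. This is only a minimal sketch, assuming the label columns are always the trailing ones. `split_features_labels` is a hypothetical name, and the toy array merely stands in for the real dataset.

```python
import numpy as np

def split_features_labels(data, n_labels):
    """Split a 2-D array into (features, labels); labels are the last n_labels columns."""
    X = data[:, :-n_labels]
    y = data[:, -n_labels:].astype(int)
    return X, y

# Toy stand-in (not the emotions data): 4 instances, 3 feature columns, 2 label columns
toy = np.array([
    [0.1, 0.2, 0.3, 0.0, 1.0],
    [0.4, 0.5, 0.6, 1.0, 1.0],
    [0.7, 0.8, 0.9, 0.0, 0.0],
    [1.0, 1.1, 1.2, 1.0, 0.0],
])
X, y = split_features_labels(toy, n_labels=2)
print(X.shape, y.shape)

# Average number of active labels per instance (often called label cardinality)
label_cardinality = y.sum(axis=1).mean()
print(label_cardinality)
```

For the emotions data the call would be `split_features_labels(data, n_labels=6)`.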
# split the dataset into a training set and a test set
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(data[:, :-6], data[:, -6:], test_size=0.33, random_state=42)
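To see what the split produces, here is a small sketch on random data with the same shape as the emotions dataset (593 rows, 72 features, 6 labels). The random arrays are placeholders, not the real data; only the resulting shapes matter.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholder arrays with the same shape as the emotions features and labels
rng = np.random.RandomState(0)
X = rng.rand(593, 72)
y = rng.randint(0, 2, size=(593, 6))

# Same arguments as in the walkthrough: 33% of the rows go to the test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42
)
print(X_train.shape, X_test.shape)  # (397, 72) (196, 72)
print(y_train.shape, y_test.shape)  # (397, 6) (196, 6)
```

Features and labels are shuffled together, so each row of `y_train` still belongs to the matching row of `X_train`.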
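# classifier: one-vs-rest strategy, SVM with a linear kernel
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
clf1 = OneVsRestClassifier(SVC(kernel='linear'), n_jobs=-1)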
# train
clf1.fit(X_train, y_train)
# output
OneVsRestClassifier(estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,decision_function_shape=None, degree=3, gamma='auto', kernel='linear', max_iter=-1, probability=False, random_state=None, shrinking=True,tol=0.001, verbose=False),n_jobs=-1)
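With a label matrix as the target, one-vs-rest trains one binary classifier per label column. The sketch below illustrates this on synthetic data from `make_multilabel_classification`. The sample and feature counts are arbitrary choices for illustration, not values from the emotions dataset.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

# Synthetic multi-label data: y is a (200, 6) binary indicator matrix
X, y = make_multilabel_classification(
    n_samples=200, n_features=10, n_classes=6, random_state=0
)

clf = OneVsRestClassifier(SVC(kernel='linear'))
clf.fit(X, y)

# One fitted binary estimator per label column
n_estimators = len(clf.estimators_)
print(n_estimators)

# Predictions come back as an indicator matrix with one column per label
pred = clf.predict(X[:5])
print(pred.shape)
```

Each underlying SVM only answers "does this label apply or not", which is why the prediction for a single sample can contain several ones, or none at all.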
# output the predicted labels
predict_class = clf1.predict(X_test)
predict_class
# output
array([[0, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
...,
[0, 0, 1, 1, 1, 0],
[1, 0, 0, 0, 0, 1],
[0, 0, 1, 1, 1, 0]])
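# accuracy: compare the predictions with the true labels (for multi-label targets, score is the subset accuracy, so all 6 labels of a sample must match)
clf1.score(X_test, y_test)
# output
0.27040816326530615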
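The score above looks low partly because, for multi-label targets, `score` reports subset accuracy: a sample counts as correct only when every one of its labels is predicted correctly. The sketch below contrasts that with the Hamming loss, the fraction of individual label entries that are wrong. The two small arrays are made up for illustration.

```python
import numpy as np

# Made-up ground truth and predictions: 4 samples, 3 labels each
y_true = np.array([[0, 1, 1],
                   [1, 0, 0],
                   [0, 1, 0],
                   [1, 1, 0]])
y_pred = np.array([[0, 1, 1],
                   [1, 0, 1],
                   [0, 0, 0],
                   [1, 1, 0]])

# Subset accuracy: a row is correct only if all of its labels match
subset_accuracy = np.all(y_true == y_pred, axis=1).mean()
print(subset_accuracy)  # 0.5

# Hamming loss: fraction of individual label entries that differ
hamming_loss = (y_true != y_pred).mean()
print(hamming_loss)  # 2 wrong entries out of 12
```

Here only 2 of 12 label entries are wrong, yet half of the samples count as misclassified. `sklearn.metrics.accuracy_score` and `sklearn.metrics.hamming_loss` compute the same two quantities for indicator matrices, so they can be applied to `y_test` and `predict_class` directly.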