數(shù)據(jù)集和代碼都在微信公眾號里面:一路向AI河胎,回復文本分類即可獲取关斜,后續(xù)會不定期更新文本數(shù)據(jù)和其它文本分類模型~
在上一篇文章中蛔外,描述了TextCNN用于文本分類內(nèi)在邏輯。今天應用這個模型來實踐一個文本多分類Demo也物。
一、數(shù)據(jù)集
先介紹下數(shù)據(jù)集列疗,數(shù)據(jù)集是從網(wǎng)上找到滑蚯,具體來源找不到了。數(shù)據(jù)集有女性抵栈、體育告材、文學、校園4個文件夾組成古劲,每個文件下有幾百個txt文件斥赋,每個txt文件包含一行文本。
首先讀取每個文件夾的所有數(shù)據(jù)作為我們的訓練數(shù)據(jù)产艾,而數(shù)據(jù)標簽則為每個txt文件所對應的文件夾名稱疤剑,即:女性、體育闷堡、文學隘膘、校園4個類別,這邊便于演示Demo缚窿,使用的數(shù)據(jù)量較屑摇:其中體育下299條數(shù)據(jù)、女性下992條數(shù)據(jù)倦零,文學下797條數(shù)據(jù)误续、校園下265條數(shù)據(jù),總共2353條數(shù)據(jù)扫茅。這顯示出文本數(shù)據(jù)類別不均衡蹋嵌,后續(xù)會對其進行一定的處理。
數(shù)據(jù)獲取完之后葫隙,對數(shù)據(jù)進行按以下步驟進行處理:
1. 數(shù)據(jù)分詞:使用jieba對文本進行分詞栽烂。
2. 文本過濾:首先過濾掉非中文字符,例如19 30 或者www url等。其次使用停用詞過濾掉一些無意義的中文字或詞腺办。
3. 數(shù)據(jù)填充:由于每行文本序列不一致焰手,為了便于建模,需要把所有序列填充到相同的長度怀喉,這里初略選取序列最大長度為25书妻,對長度小于25的序列后端補齊'0',對長度大于25的序列進行截斷處理躬拢。
def text_process(self, stopwords):
for i in range(len(self.text)):
# 使用正則表達式過濾非中文字符或數(shù)字
pattern = re.compile(r'[^\u4e00-\u9fa5]')
self.text[i] = re.sub(pattern, '', self.text[i])
# jieba 分詞
cut_result = list(jieba.cut(self.text[i]))
# 過濾停用詞
for j in range(len(cut_result)):
if cut_result[j] in stopwords:
cut_result[j] = ''
else:
# 把所有單詞存到集合里
if cut_result[j] not in self.words:
self.words.append(cut_result[j])
# 數(shù)據(jù)填充
tmp = self.data_padding([x.strip() for x in list(cut_result) if x != '' and x != ' '])
self.text[i] = ' '.join(tmp)
def data_padding(self, sequence):
# 序列小于最大長度填充'0'
if len(sequence) <= self.max_len:
sequence.extend(['0'] * (self.max_len - len(sequence)))
else:
# 序列大于最大長度進行截斷
sequence = sequence[:self.max_len]
return sequence
4. 數(shù)據(jù)編碼:對文本編碼:可以在上述過程中躲履,統(tǒng)計出分詞后所有單詞的個數(shù),并把其映射為單詞所對引的索引聊闯,然后把文本中的單詞轉(zhuǎn)換為其對應的索引工猜;對于標簽編碼,可以把標簽映射為{'體育':0菱蔬,'女性':1篷帅,'文學':2,'校園':3}處理,也可以直接進行onehot編碼: {'體育' : [1 0 0 0], '女性' :[0 1 0 0], '文學':[0 0 1 0], '校園' :[0 0 0 1]}汗销。
def data_encoding(self, texts, labels):
with open('../data/word2index.txt') as fp:
word2index = json.load(fp)
# 文本編碼 -- 找到每個詞對應的索引
data = []
for text in texts:
text = text.split(' ')
tmp = []
for i in range(len(text)):
text[i] = word2index.get(text[i], 0)
tmp.append(text[i])
data.extend(tmp)
# 標簽編碼
label2ind = {}
unique_label = list(set(labels))
for index, label in enumerate(unique_label):
label2ind[label] = index
for i in range(len(labels)):
labels[i] = label2ind[labels[i]]
# one hot 編碼
# labels = to_categorical(labels, len(set(labels)), dtype=int)
return np.array(data).reshape(-1, self.max_len), np.array(labels), word2index
5. 劃分數(shù)據(jù)集:把文本轉(zhuǎn)換成向量后犹褒,把數(shù)據(jù)集充分打亂之后,可以分為訓練集和測試集弛针。其中參數(shù)stratify = label 可以使劃分的訓練集和測試集各類比例與原始數(shù)據(jù)集分布一致叠骑,等同于各類等比例抽樣。
def split_data(self, data, label):
# shuffle data
data, label = shuffle(data, label, random_state=2020)
X_train, X_text, y_train, y_test = train_test_split(data, label, test_size=0.1, random_state=2020,
stratify=label)
return X_train, X_text, y_train, y_test
二削茁、TextCNN模型
TextCNN的核心思想是抓取文本的局部特征:通過不同的卷積核尺寸(確切的說是卷積核高度)來提取文本的N-gram信息宙枷,然后通過最大池化操作來突出各個卷積操作提取的最關(guān)鍵信息(頗有一番Attention的味道),拼接后通過全連接層對特征進行組合茧跋,最后通過多分類損失函數(shù)來訓練模型慰丛。
在本模型中TextCNN代碼如下:
def textcnn(wordsize, label, embedding_matrix=None):
input = Input(shape=(data_process.max_len,))
if embedding_matrix is None:
embedding = Embedding(input_dim=wordsize,
output_dim=32,
input_length=data_process.max_len,
trainable=True)(input)
else: # 使用預訓練矩陣初始化Embedding
embedding = Embedding(input_dim=wordsize,
output_dim=32,
weights=[embedding_matrix],
input_length=data_process.max_len,
trainable=False)(input)
convs = []
for kernel_size in [2, 3, 4]:
conv = Conv1D(64, kernel_size, activation='relu')(embedding)
pool = MaxPooling1D(pool_size=data_process.max_len - kernel_size + 1)(conv)
convs.append(pool)
print(pool)
concat = Concatenate()(convs)
flattern = Flatten()(concat)
dropout = Dropout(0.3)(flattern)
output = Dense(len(set(label)), activation='softmax')(dropout)
model = Model(input, output)
model.compile(loss='sparse_categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
print(model.summary())
return model
模型結(jié)構(gòu)如下:
開頭提到訓練數(shù)據(jù)不均衡,對待數(shù)據(jù)不均衡通常采用的方式為過采樣瘾杭、降采樣诅病、數(shù)據(jù)加權(quán)等。前兩種方式比較簡單粥烁,不作過多介紹贤笆,這里介紹下數(shù)據(jù)加權(quán),類別數(shù)量分布為{0: 893, 1: 238, 2: 269, 3: 717}讨阻,通過樣本總數(shù)除以每個類別總數(shù)來得到每個類別的樣本權(quán)重芥永,經(jīng)過處理后得到:{0: 2.37, 1: 8.89, 2: 7.87, 3: 2.95},可以看到樣本數(shù)目越多钝吮,樣本權(quán)重就越小埋涧。
def class_weight(self, y_train):
count_res = dict(Counter(y_train))
print(count_res)
for key in count_res.keys():
count_res[key] = round(len(y_train) / count_res[key], 2)
return count_res
樣本得到權(quán)重后板辽,怎么使用呢?可以在模型訓練的時候通過class_weight參數(shù)賦予給損失函數(shù)棘催。
history = model.fit(X_train, y_train, validation_split=0.05, batch_size=32, epochs=20, class_weight=class_weight,
verbose=2)
三劲弦、評估結(jié)果
模型訓練基本沒有調(diào)參,在測試集上的準確率達到93%左右,其它一些評估指標結(jié)果如下:混淆矩陣結(jié)果行代表真實標簽,列代表預測標簽坠非,可以看出把模型的第3類樣本預測為第2類樣本的數(shù)目最多為3個兼蕊,可以挑選出這些Badcase分析下是什么原因造成的。
混淆矩陣結(jié)果:
[[29 1 0 0]
[ 1 75 2 2]
[ 2 3 92 2]
[ 0 1 1 25]]
分類報告結(jié)果:
precision recall f1-score support
0 0.91 0.97 0.94 30
1 0.94 0.94 0.94 80
2 0.97 0.93 0.95 99
3 0.86 0.93 0.89 27
模型后續(xù)可改進的空間還有很多贸毕,比如說網(wǎng)格搜索+交叉驗證郑叠,模型不均衡數(shù)據(jù)集的處理,預訓練Embedding等等明棍,后續(xù)有時間會逐漸完善乡革。
由于時間比較倉促,文章寫的有點亂摊腋,數(shù)據(jù)集和代碼在公眾號回復文本分類即可獲取沸版,后續(xù)會不斷更新該系列文章,有興趣的可以關(guān)注一波兴蒸。