多任務(wù)學(xué)習(xí)(Multi-task learning)簡介
多任務(wù)學(xué)習(xí)(Multi-task learning)是遷移學(xué)習(xí)(Transfer Learning)的一種捆探,而遷移學(xué)習(xí)指的是將從源領(lǐng)域的知識(shí)(source domin)學(xué)到的知識(shí)用于目標(biāo)領(lǐng)域(target domin),提升目標(biāo)領(lǐng)域的學(xué)習(xí)效果酸钦。 而多任務(wù)學(xué)習(xí)也是希望模型同時(shí)做多個(gè)任務(wù)時(shí),能將其他任務(wù)學(xué)到的知識(shí),用于目標(biāo)任務(wù)中度液,從而提升目標(biāo)任務(wù)效果。
如果我們換個(gè)角度理解画舌,其實(shí)多任務(wù)學(xué)習(xí)堕担,其實(shí)是對目標(biāo)任務(wù)做了一定的約束,或者叫做regularization曲聂。我們不希望模型只局限于目標(biāo)任務(wù)的學(xué)習(xí)霹购,而是能夠適應(yīng)多個(gè)任務(wù)場景,這樣可以大大的增加模型的泛函能力(generalization)朋腋。
舉個(gè)形象的例子齐疙,單人多任務(wù)學(xué)習(xí)模型就像一個(gè)一門心思只做一樣事情的匠人,在他自己的領(lǐng)域旭咽,他可能可以做一百分贞奋,如果換個(gè)任務(wù)也許他就會(huì)做的不是特別好,而多任務(wù)學(xué)習(xí)模型就像一個(gè)什么任務(wù)都做得還算優(yōu)秀但是不完美的人轻专∫涿可是在實(shí)際深度學(xué)習(xí)任務(wù)中,測試集和訓(xùn)練集的分布還是會(huì)有一定的偏差,那測試集可能就意味給讓模型做一個(gè)目標(biāo)微調(diào)后任務(wù)催训。所以在測試集上洽议,多任務(wù)模型大概率是表現(xiàn)優(yōu)異那一個(gè)。
這里需要強(qiáng)調(diào)一點(diǎn)漫拭,這里的多任務(wù)的各個(gè)任務(wù)之間一定要有強(qiáng)相關(guān)性亚兄,如果任務(wù)之間本身的關(guān)聯(lián)性就不大,多任務(wù)學(xué)習(xí)并不會(huì)對模型的提升并不一定會(huì)有用采驻。
多任務(wù)學(xué)習(xí)(Multi-task learning)的兩種模式
深度學(xué)習(xí)中兩種多任務(wù)學(xué)習(xí)模式:隱層參數(shù)的硬共享與軟共享审胚。
- 隱層參數(shù)硬共享,指的是多個(gè)任務(wù)之間共享網(wǎng)絡(luò)的同幾層隱藏層礼旅,只不過在網(wǎng)絡(luò)的靠近輸出部分開始分叉去做不同的任務(wù)膳叨。
- 隱層參數(shù)軟共享,不同的任務(wù)使用不同的網(wǎng)絡(luò)痘系,但是不同任務(wù)的網(wǎng)絡(luò)參數(shù)菲嘴,采用距離(L1,L2)等作為約束,鼓勵(lì)參數(shù)相似化汰翠。
而本次的代碼實(shí)現(xiàn)采用的是隱層參數(shù)硬共享龄坪,也就是兩個(gè)任務(wù)共享網(wǎng)絡(luò)淺層的參數(shù)。
多任務(wù)學(xué)習(xí)keras實(shí)現(xiàn)
這里筆者簡單的介紹一下如何通過keras簡單的搭建一個(gè)多任務(wù)學(xué)習(xí)網(wǎng)絡(luò)复唤。
這里筆者的目標(biāo)任務(wù)是一個(gè)10分類的關(guān)系分類任務(wù)健田,對關(guān)系分類任務(wù)不是很了解的同學(xué)可以移步到筆者之前的文章中去了解一下,而我將訓(xùn)練文本中兩個(gè)存在關(guān)系的實(shí)體(entity)標(biāo)了出來佛纫,在模型中加了一個(gè)命名體識(shí)別(NER)任務(wù)構(gòu)成了多任務(wù)學(xué)習(xí)模型妓局。
筆者的網(wǎng)絡(luò)架構(gòu)如下圖所示:
- 句子向量和位置向量拼接構(gòu)成模型的輸入,
- 經(jīng)過一層共享的LSTM編碼層雳旅,后模型開始分叉跟磨,
- 其中一條路徑是經(jīng)過一層MaxPooling和以及全連接層后輸出文本分類的預(yù)測輸出间聊,
-
另外一條路徑是經(jīng)過一層CRF層后輸出命名實(shí)體識(shí)別的預(yù)測輸出攒盈。
模型代碼部分
這里模型構(gòu)建不需要注意,筆者這里強(qiáng)調(diào)的是:
各個(gè)任務(wù)的輸出層一定要命名哎榴,比如筆者這個(gè)模型的文本分類任務(wù)的輸出層Dense(10, activation='softmax',name = "out1")(out1)中的name ="out1"型豁,以及NER的輸出層crf = CRF(2, sparse_target=True,name ="crf_output")中的name ="crf_output"不能省略。
第二個(gè)就是model.compile中的loss和loss的權(quán)重需要和任務(wù)輸出層的name進(jìn)行對應(yīng)尚蝌,如下:
loss={'out1': 'categorical_crossentropy','crf_output': crf.loss_function}
loss_weights={'out1':1, 'crf_output': 1}
下面是實(shí)現(xiàn)代碼迎变,發(fā)現(xiàn)沒有,Keras搭建多任務(wù)學(xué)習(xí)模型是不是So easy飘言。
from keras.layers import Input,LSTM,Bidirectional,Dense,Dropout,Concatenate,Embedding,GlobalMaxPool1D
from keras.models import Model
from keras_contrib.layers import CRF
import keras.backend as K
from keras.utils import plot_model
K.clear_session()
maxlen = 40
###輸入
inputs = Input(shape=(maxlen,768),name="sen_emb")
pos1_en = Input(shape=(maxlen,),name="pos_en1_id")
pos2_en = Input(shape=(maxlen,),name="pos_en2_id")
pos1_emb = Embedding(maxlen,8,input_length=maxlen,name = "pos_en1_emb")(pos1_en)
pos2_emb = Embedding(maxlen,8,input_length=maxlen,name = "pos_en2_emb")(pos2_en)
x = Concatenate(axis=2)([inputs,pos1_emb,pos2_emb])
###參數(shù)共享部分
x = Bidirectional(LSTM(128,return_sequences=True))(x)
###任務(wù)一衣形,10分類的文本分類任務(wù)
out1 = GlobalMaxPool1D()(x)
out1 = Dense(64, activation='relu')(out1)
out1 = Dropout(0.5)(out1)
out1 = Dense(10, activation='softmax',name = "out1")(out1)
###任務(wù)二,實(shí)體識(shí)別任務(wù)
crf = CRF(2, sparse_target=True,name ="crf_output")
crf_output = crf(x)
###模型有兩個(gè)輸出out1,crf_output
model = Model(inputs=[inputs,pos1_en,pos2_en], outputs=[out1,crf_output])
model.summary()
###模型有兩個(gè)loss,categorical_crossentropy和crf.loss_function
model.compile(optimizer='adam',
loss={'out1': 'categorical_crossentropy','crf_output': crf.loss_function},
loss_weights={'out1':1, 'crf_output': 1},
metrics=["acc"])
plot_model(model,to_file="model.png")
結(jié)語
筆者利用這個(gè)多任務(wù)學(xué)習(xí)的模型和去掉CRF實(shí)體識(shí)別分支的單任務(wù)模型做了對比實(shí)驗(yàn),確實(shí)多任務(wù)學(xué)習(xí)模型比單任務(wù)模型在測試集上的F1得分要好2個(gè)百分點(diǎn)左右谆吴。多任務(wù)訓(xùn)練模型的泛化能力確實(shí)很強(qiáng)倒源。如果你的深度學(xué)習(xí)模型遇到瓶頸了,可以嘗試一下多任務(wù)學(xué)習(xí)模型哦句狼。
參考文獻(xiàn)
https://blog.csdn.net/xuluohongshang/article/details/79044325