1. bert模型架構(gòu)
基礎(chǔ)架構(gòu)——transformer的encoder部分(如下圖)
transformer 是多層encoder-多層decoder結(jié)構(gòu)嗡髓。input = word_embedding + positional_encoding(word_embedding?詞向量,可以是隨機初始化裆甩,也可以使用 word2vec场绿;positional_embedding用正余弦表示位置特征)
而bert是多層encoder 結(jié)構(gòu)。
1. 輸入部分
input = token_embedding + segment_embedding + positional_embedding
token_embedding?——詞向量,可以是隨機初始化浦箱,也可以使用 word2vec
segment_embedding? ——用于對兩個句子進行區(qū)分大渤,CLS到到中間SEP制妄,全部都是(整個句子都是一樣的),也可以全部用0泵三;SEP到SEP耕捞,全部用,全部用1.
positional_embedding ——區(qū)別于transformer中的?positional_encoding烫幕,使用隨機初始化(但為什么用隨機初始化砸脊,而不用正余弦函數(shù),沒有找到很好地解釋)
input?中的特殊字符:CLS纬霞,SEP
(bert有兩個訓(xùn)練任務(wù)凌埂,一個是NSP,next sentence prediction诗芜,用來判斷兩個句子之間的關(guān)系瞳抓,因此,需要有一個字符告訴計算機句子與句子之間的分割伏恐,所以有了SEP字符孩哑。同時,NSP也是句子之間的二分類任務(wù)翠桦,因此横蜒,作者在句子前接了一個CLS字符,訓(xùn)練的時候销凑,將CLS輸出后面接一個二分類器丛晌,這就是CLS的作用。但是斗幼,有一個問題澎蛛,很多人認(rèn)為CLS代表整個句子的語義信息,但是原文并沒有這種說法蜕窿,所以可以作為一個思考點谋逻。B站UP主做過一個測試,他也提供了一些其他證據(jù)桐经,bert預(yù)訓(xùn)練模型直接拿來做sentence embedding毁兆,效果甚至不如word embedding,CLS效果最差阴挣。)
2. 多頭注意力部分
3. 前饋神經(jīng)網(wǎng)絡(luò)
區(qū)分:encoder與 decoder
2. 如何做預(yù)訓(xùn)練
2.1 MLM 掩碼語言模型
bert使用大量無監(jiān)督預(yù)料進行預(yù)訓(xùn)練气堕。無監(jiān)督模型有兩種目標(biāo)函數(shù),比較受重視:
舉例:原始預(yù)料“我愛吃飯”
1. AR(auto-regressive 自回歸模型):只考慮單側(cè)信息,典型的GPT
P(我愛吃飯) =P(我)P(愛|我)P(吃|我愛)P(飯|我愛吃)
AR的優(yōu)化目標(biāo)是:“我愛吃飯”的概率 = “我”出現(xiàn)的概率 “我”出現(xiàn)的條件下送巡,“愛”?出現(xiàn)的概率...
這個優(yōu)化目標(biāo)摹菠,是有一個前后依賴關(guān)系的。所以說骗爆,AR只用到了單側(cè)信息
2. AE(auto-encoding 自編碼模型):從損壞的輸入數(shù)據(jù)中預(yù)測重建原始數(shù)據(jù)次氨,可以使用上下文信息
對句子進行 mask,原句編程 “我愛mask飯”
P(我愛吃飯|我愛mask飯)=P(mask=吃|我愛飯)
AE的優(yōu)化目標(biāo):“我愛吃飯”的概率 = “我愛飯”出現(xiàn)的條件下摘投,mask=吃的概率
本質(zhì):打破文本原有的信息煮寡,讓模型訓(xùn)練的時候,進行文本重建
模型缺點:
P(我愛吃飯|我愛mask mask)=P(吃|我愛)P(飯|我愛)
mask之后犀呼,吃 和 飯 之間被看做是獨立的幸撕,但是本身是有關(guān)系的。
mask策略:
隨機mask 15%的單詞外臂,這 15%中坐儿,0.1被替換成其他單詞(有可能選到這個單詞本身),0.1原封不動宋光,0.8替換成其他(這個比例問題沒有知道到解釋)貌矿,如下圖:
2.2 NSP
NSP樣本如下:
1. 從訓(xùn)練語料庫匯總?cè)〕?b>兩個連續(xù)的段落作為正樣本(說明兩個段落來自于同一個文檔,同一個主題罪佳,且順序沒有顛倒)
2. 從不同文檔中隨機創(chuàng)建一對段落作為負(fù)樣本(不同的主題)
缺點:主題預(yù)測 和 連貫性預(yù)測 合并為一個單項任務(wù)
3. 如何微調(diào)bert逛漫,提升下游任務(wù)中的效果
四個常見的下游任務(wù):
a. 句子對分類任務(wù)——本質(zhì)是文本匹配,把兩個句子拼接起來赘艳,判斷是否相似酌毡。CLS接二分類器,輸出 0-相似蕾管,1-不相似
b. 單個句子分類任務(wù)——CLS輸出枷踏,接一個分類器,進行分類
c. 問答任務(wù)
d. 序列標(biāo)注任務(wù)——把所有的token 輸出娇掏,然后接softmax呕寝,進行標(biāo)注(比如詞性標(biāo)注,命名實體識別)
如何提升bert下游任務(wù)表現(xiàn)婴梧?即微調(diào)策略。
基本步驟:1-獲取一個訓(xùn)練好的bert客蹋,比如谷歌中文BERT塞蹭;2-基于任務(wù)數(shù)據(jù)進行微調(diào)。
比如讶坯,做微博情感分析:
1-獲取通用的預(yù)訓(xùn)練模型番电,比如谷歌中文BERT
2-在相同領(lǐng)域上繼續(xù)做模型訓(xùn)練,比如在微博文本上進行訓(xùn)練(Domain tansfer 領(lǐng)域自適應(yīng),或 領(lǐng)域遷移)
3-在任務(wù)相關(guān)的小數(shù)據(jù)上繼續(xù)訓(xùn)練漱办,在微博情感任務(wù)文本上進行訓(xùn)練(task transfer 任務(wù)自適應(yīng)这刷,或 任務(wù)遷移)
4-在任務(wù)相關(guān)數(shù)據(jù)上做具體的任務(wù) (fine-tune 微調(diào))
在上面 領(lǐng)域遷移 步驟中,還可以進行 further pre-traning娩井,比如:
1-動態(tài)mask:每次epoch去訓(xùn)練的時候mask
2-n-gram mask:ERNIE 和 SpanBERT類似于做了實體詞 的 mask
可以對參數(shù)進行優(yōu)化暇屋,從而提升模型效果:
batch size:16,32---128,影響不大洞辣,主要看及其效果
learning rate(Adam):5e-5, 3e-5, 2e-5 盡可能小一點咐刨,避免災(zāi)難性遺忘
number of epochs: 3,4
weighted decay:修改后的Adam,使用warmup, 搭配線性衰減
其他:
數(shù)據(jù)增強扬霜、自蒸餾定鸟、外部知識融入
比如 ERNIE ,融入了知識圖譜著瓶,加入了實體信息
4. 如何在脫敏數(shù)據(jù)中心使用BERT等預(yù)訓(xùn)練模型
如果本身語料很大联予,可以從0開始訓(xùn)練一個bert
否則,按照詞頻材原,把脫敏數(shù)字對找到中文(假如是中文語料)沸久,使用中文做bert初始化,然后基于新的中文語料訓(xùn)練bert
5. 代碼解讀
代碼最核心的一點华糖,MLM損失函數(shù)的計算:
15%的詞匯被mask麦向,8:1:1的比例進行了不同的處理,那么損失函數(shù)究竟計算的是哪一部分客叉?
最下面一行 是原始字符對應(yīng)的索引诵竭。
mask:把第三個字符 從索引 13 替換成了 4 對應(yīng)的字符。
接下來兼搏,經(jīng)過三個embedding卵慰,然后拼接,組合成input
然后佛呻,經(jīng)過 encoder 層(N層)
得到每個token的embedding:
第一個對應(yīng)的是 CLS裳朋,可以接linear層,做二分類任務(wù)
被mask的位置吓著,接linear層鲤嫡,在 詞表大小的維度(bert的此表大小是22128)接 softmax,挑選最有可能的詞匯绑莺,做損失函數(shù)
1. 參數(shù)部分
maxlen- 如果是document類很長的文本暖眼,一般會使用一些策略,使最大長度控制在256以內(nèi)纺裁。
max_pred -為了限制一個句子最大可以預(yù)測多少個token
n_layers - 選擇多少個 encoder诫肠,base一般選12司澎,large一般選24
n_head - 多少個頭(多頭注意力機制)
d_ff -前饋神經(jīng)網(wǎng)絡(luò)的維度,不用特意寫成768*4栋豫,沒什么意義挤安,直接寫成3072
d_q = d_v Q-K-V向量的維度,一般K和V要一致
n_segments NSP做二分類任務(wù)丧鸯,有兩個SEP(區(qū)分出兩個segments)
text 是 輸入樣本蛤铜,正常任務(wù)中肯定是從文本文件中讀取
sentences 是數(shù)據(jù)預(yù)處理的代碼,去除掉原始文本中一些沒用的字符
加下來骡送,創(chuàng)建一個詞表昂羡,其中 【PAD】編碼為0、【CLS】-1摔踱、【SEP】-2虐先、【MASK】-3等特殊字符進行編碼,而文本中的字符派敷,從3以后取蛹批。
token_list 是把文本轉(zhuǎn)化成數(shù)字
為什么補0?