Bert模型tensorflow源碼解析(詳解transformer encoder數(shù)據(jù)運(yùn)算)

Github地址:https://github.com/1234560o/Bert-model-code-interpretation.git

Contents

  • 前言
  • 模型輸入
  • Padding_Mask
  • attention_layer
  • transformer_model
  • Bert_model class
  • 后續(xù)

前言

關(guān)于Bert模型的基本內(nèi)容這里就不講述了撒璧,可參考其它文章掺喻,這里有一個(gè)收集了很多講解bert文章的網(wǎng)址:

http://www.52nlp.cn/bert-paper-論文-文章-代碼資源匯總

與大多數(shù)文章不同的是,本文主要是對(duì)Bert模型部分的源碼進(jìn)行詳細(xì)解讀拯辙,搞清楚數(shù)據(jù)從Bert模型輸入到輸出的每一步變化剿牺,這對(duì)于我們理解Bert模型企垦、特別是改造Bert是具有極大幫助的。需要注意的是晒来,閱讀本文之前钞诡,請(qǐng)先對(duì)Transformer、Bert有個(gè)大致的了解,本文直接講述源碼中的數(shù)據(jù)運(yùn)算細(xì)節(jié)荧降,并不會(huì)涉及一些基礎(chǔ)內(nèi)容接箫。當(dāng)然,我們還是先來(lái)回顧下Bert模型結(jié)構(gòu):

01.png

Bert模型采用的是transformer的encoder部分(見上圖)誊抛,不同的是輸入部分Bert增加了segment_embedding且模型細(xì)節(jié)方面有些微區(qū)別列牺。下面直接進(jìn)入Bert源碼解析。Bert模型部分源碼地址:

https://github.com/google-research/bert/blob/master/modeling.py拗窃。

模型輸入

Bert的輸入有三部分:token_embedding、segment_embedding泌辫、position_embedding随夸,它們分別指得是詞的向量表示、詞位于哪句話中震放、詞的位置信息:

02.png

Bert輸入部分由下面兩個(gè)函數(shù)得到:

03.png

embedding_lookup得到token_embedding宾毒,embedding_postprocessor得到將這三個(gè)輸入向量相加的結(jié)果,注意embedding_postprocessor函數(shù)return最后結(jié)果之前有一個(gè)layer normalize和droupout處理:

04.png

Padding_Mask

由于輸入句子長(zhǎng)度不一樣殿遂,Bert作了填充處理诈铛,將填充的部分標(biāo)記為0,其余標(biāo)記為1墨礁,這樣是為了在做attention時(shí)能將填充部分得到的attention權(quán)重很少幢竹,從而能盡可能忽略padding部分對(duì)模型的影響:


05.png
06.png

attention_layer

為了方便分析數(shù)據(jù)流,對(duì)張量的維度作如下簡(jiǎn)記:

07.png

做了該簡(jiǎn)記后恩静,經(jīng)過(guò)詞向量層輸入Bert的張量維度為[B, F, embedding_size]焕毫,attention_mask維度為[B, F, T]。由于在Bert中是self-attention驶乾,F(xiàn)和T是相等的邑飒。接下來(lái)我詳細(xì)解讀一下attention_layer函數(shù),該函數(shù)是Bert的Multi-Head Attention级乐,也是模型最為復(fù)雜的部分疙咸。更詳細(xì)的代碼可以結(jié)合源碼看。在進(jìn)入這部分之前风科,也建議先了解一下2017年谷歌提出的transformer模型撒轮,推薦Jay Alammar可視化地介紹Transformer的博客文章The Illustrated Transformer ,非常容易理解整個(gè)機(jī)制丐重。而Bert采用的是transformer的encoding部分腔召,attention只用到了self-attention,self-attention可以看成Q=K的特殊情況扮惦。所以attention_layer函數(shù)參數(shù)中才會(huì)有from_tensor臀蛛,to_tensor這兩個(gè)變量,一個(gè)代表Q,另一個(gè)代表K及V(這里的Q浊仆,K客峭,V含義不作介紹,可參考transformer模型講解相關(guān)文章)抡柿。

? atterntion_layer函數(shù)里面首先定義了函數(shù)transpose_for_scores:

08.png

該函數(shù)的作用是將attention層的輸入(Q舔琅,K,V)切割成維度為[B, N, F 或T, H]洲劣。了解transformer可以知道备蚓,Q、K囱稽、V是輸入的詞向量分別經(jīng)過(guò)一個(gè)線性變換得到的郊尝。在做線性變換即MLP層時(shí)先將input_tensor(維度為[B, F, embedding_size])reshape成二維的(其實(shí)源碼在下一個(gè)函數(shù)transformer_model中使用這個(gè)函數(shù)傳進(jìn)去的參數(shù)已經(jīng)變成二維的了,這一點(diǎn)看下一個(gè)函數(shù)transformer_model可以看到):

09.png

接下來(lái)就是MLP層战惊,即對(duì)輸入的詞向量input_tensor作三個(gè)不同的線性變換去得到Q流昏、K、V吞获,當(dāng)然這一步后維度還需要轉(zhuǎn)換一下才能得到最終的Q况凉、K、V:

10.png

MLP層將[B * F, embedding_size]變成[B * F, N * H]各拷。但從后面的代碼(transformer_model函數(shù))可以看到embedding_size等于hidden_size等于N * H刁绒,相當(dāng)于這個(gè)MLP層沒有改變維度大小,這一點(diǎn)也是比較難理解的:

11.png

之后撤逢,代碼通過(guò)先前介紹的transpose_for_scores函數(shù)得到Q膛锭、K、V蚊荣,維度分別為[B, N, F, H]初狰、[B, N, T, H]、[B, N, T, H]互例。不解得是奢入,后面的求V代碼并不是通過(guò)transpose_for_scores函數(shù)得到,而是又把transpose_for_scores函數(shù)體再寫了一遍媳叨。

到目前為止Q腥光、K、V我們都已經(jīng)得到了糊秆,我們?cè)賮?lái)回顧一下論文“Attention is all you need”中的attention公式:


equation_1.png

下面這部分得到的attention_scores得到的是softmax里面的部分武福。這里簡(jiǎn)單解釋下tf.matmul。這個(gè)函數(shù)實(shí)質(zhì)上是對(duì)最后兩維進(jìn)行普通的矩陣乘法痘番,前面的維度都當(dāng)做batch捉片,因此這要求相乘的兩個(gè)張量前面的維度是一樣的平痰,后面兩個(gè)維度滿足普通矩陣的乘法規(guī)則即可。細(xì)想一下attention的運(yùn)算過(guò)程伍纫,這剛好是可以用這個(gè)矩陣乘法來(lái)得到結(jié)果的宗雇。得到的attention_scores的維度為[B, N, F, T]。只看后面兩個(gè)維度(即只考慮一個(gè)數(shù)據(jù)莹规、一個(gè)attention)赔蒲,attention_scores其實(shí)就是一個(gè)attention中Q和K作用得到的權(quán)重系數(shù)(還未經(jīng)過(guò)softmax),而Q和K長(zhǎng)度分別是F和T良漱,因此共有F * T個(gè)這樣的系數(shù):

12.png

那么比較關(guān)鍵的一步來(lái)了——Mask舞虱,即將padding部分“mask”掉(這和Bert預(yù)測(cè)詞向量任務(wù)時(shí)的mask是完全不同的,詳情參考相關(guān)文章母市,這里只討論模型的詳細(xì)架構(gòu)):

13.png

我們?cè)谇懊娌襟E中得到的attention_mask的維度為[B, F, T]砾嫉,為了能實(shí)現(xiàn)矩陣加法,所以先在維度1上(指第二個(gè)維度窒篱,第一個(gè)維度axis=0)擴(kuò)充一維,得到維度為[B, 1, F, T]舶沿。然后利用python里面的廣播機(jī)制就可以相加了墙杯,要mask的部分加上-10000.0,不mask的部分加上0括荡。這個(gè)模型的mask是在softmax之前做的高镐,至于具體原因我也不太清楚,還是繼續(xù)跟著數(shù)據(jù)流走吧畸冲。加上mask之后就是softmax嫉髓,softmax之后又加了dropout:

14.png

再之后就是softmax之后的權(quán)重系數(shù)乘上后面的V,得到維度為[B, N, F, H]邑闲,在維度為1和維度為2的位置轉(zhuǎn)置一下變成[B, F, N, H]算行,該函數(shù)可以返回兩種維度的張量:

  1. [B * F, N * H](源碼中注釋H變成了V,這一點(diǎn)是錯(cuò)誤嗎苫耸?還是我理解錯(cuò)了州邢?
  2. [B, F, N * H]
15.png

至此,我將bert模型中最為復(fù)雜的Multi-Head Attention數(shù)據(jù)變化形式講解完了褪子。下一個(gè)函數(shù)transformer_model搭建Bert整體模型。

transformer_model

下面我對(duì)transformer_model這個(gè)函數(shù)進(jìn)行解析嫌褪,該函數(shù)是將Transformer Encoded所有的組件結(jié)合在一起。 很多時(shí)候裙秋,結(jié)合圖形理解是非常有幫助的琅拌。下面我們先看一下下面這個(gè)盜的圖吧(我們把這個(gè)圖的結(jié)構(gòu)叫做transformer block吧):


16.png

整個(gè)Bert模型其實(shí)就是num_hidden_layers個(gè)這樣的結(jié)構(gòu)串連,相當(dāng)于有num_hidden_layers個(gè)transformer_block财忽。而self-attention部分在上個(gè)函數(shù)已經(jīng)梳理得很清楚了泣侮,剩下的其實(shí)都是一些熟悉的組件(殘差即彪、MLP、LN)活尊。transformer_model先處理好輸入的詞向量隶校,然后進(jìn)入一個(gè)循壞深胳,每個(gè)循壞就是一個(gè)block:

17.png

上面的截圖并未包括所有的循環(huán)代碼铜犬,我們一步步來(lái)走下去。顯然敛劝,代碼是將上一個(gè)transformer block的輸出作為下一個(gè)transformer block的輸入纷宇。那么第一個(gè)transformer block的輸入是什么呢?當(dāng)然是我們前面所說(shuō)的三個(gè)輸入向量相加得到的input_tensor上陕。至于每個(gè)block維度是否對(duì)得上拓春,計(jì)算是否準(zhǔn)確,繼續(xù)看后面的代碼就知道了痘儡。該代碼中還用了變量all_layer_outputs來(lái)保存每一個(gè)block的輸出結(jié)果沉删,設(shè)置參數(shù)do_return_all_layers可以選擇輸出每個(gè)block的結(jié)果或者最后一個(gè)block的結(jié)果。transformer_model中使用attention_layer函數(shù)的輸入數(shù)據(jù)維度為二維的([B * F或B * T, hidden_size])砖茸。詳細(xì)看attention_layer函數(shù)時(shí)是可以輸入二維張量數(shù)據(jù)的:

18.png

至于下面這部分為什么會(huì)有attention_heads這個(gè)變量殴穴,原因我也不知道凉夯,仿佛在這里是多此一舉,源碼中的解釋如下:

19.png

我們?cè)倩仡櫼幌律弦粋€(gè)函數(shù)attention_layer劲够,return的結(jié)果維度為[B * F, N * H]或[B, F, N * H]。注意這里面使用的attention_layer函數(shù)do_return_2d_tensor參數(shù)設(shè)置為True征绎,所以attention_output的維度為[B * F, N * H]人柿。然后再做一層MLP(該層并沒改變維度,因?yàn)閔idden_size=N * H)江咳、dropout哥放、layer_norm:

20.png

此時(shí)attention_output的維度還是[B * F, N * H或hidden_size]甥雕。由上面的圖可以接下來(lái)是繼續(xù)MLP層加dropout加layer_norm,只不過(guò)該層MLP的神經(jīng)元數(shù)intermediate_size是一個(gè)超參數(shù),可以人工指定:

21.png

由上面截圖的代碼可知接下來(lái)做了兩層MLP呵哨,維度變化[B * F, hidden_size]到[B * F, intermediate_size]再到[B * F, hidden_size]轨奄,再經(jīng)過(guò)dropout和layer_norm維度大小不變。至此挨务,一個(gè)transformer block已經(jīng)走完了玉组。而此時(shí)得到的layer_out將作為下一個(gè)block的輸入,這個(gè)維度與該模型第一個(gè)block的的輸入是一樣的朝巫,然后就是這樣num_hidden_layers次循環(huán)下去得到最后一個(gè)block的輸出結(jié)果layer_output石景,維度依舊為[B * F, hidden_size]。

return的時(shí)候通過(guò)reshape_from_matrix函數(shù)把block的輸出變成維度和input_shape一樣的維度揪荣,即一開始詞向量輸入input_tensor的維度([batch_size, seq_length, hidden_size])

22.png

Bert_model class

為了方便訓(xùn)練,模型的整個(gè)過(guò)程都封裝在Bert_model類中佛舱,通過(guò)該類的實(shí)例可以訪問(wèn)模型中的結(jié)果揽乱。詳細(xì)的過(guò)程見代碼。上述幾個(gè)函數(shù)梳理之后便沒什么復(fù)雜的了损拢,只是把內(nèi)容整合在一起了撒犀。self.all_encoder_layers是經(jīng)過(guò)transformer_model函數(shù)返回每個(gè)block的結(jié)果或舞,self.sequence_output得到最后一個(gè)維度的結(jié)果,由上面的分析知維度為[Batch_szie, seq_length, hidden_size]胆筒,這和一開始詞向量的維度是一樣的诈豌,只不過(guò)這個(gè)結(jié)果是經(jīng)過(guò)Transformer Encoded提取特征之后的,包含重要的信息彤蔽,也是Bert想得到的結(jié)果:

23.png

在這一步之后顿痪,該類用成員變量self.pooled_output保存第一個(gè)位置再經(jīng)過(guò)一個(gè)MLP層的輸出結(jié)果油够。熟悉數(shù)據(jù)輸入形式的可以知道,這個(gè)位置是[CLS]撕阎,該位置的輸出在Bert預(yù)訓(xùn)練中是用來(lái)判斷句子上下文關(guān)系的:

24.png

這里保存該結(jié)果除了可以用于Bert預(yù)訓(xùn)練碌补,還可以微調(diào)Bert用于分類任務(wù),詳細(xì)可參考:

http://www.reibang.com/p/22e462f01d8c

后續(xù)

文中可能存在不少筆誤或者理解不正確的表達(dá)不清晰地方敬請(qǐng)諒解镇匀,非常歡迎能提出來(lái)共同學(xué)習(xí)汗侵。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市发乔,隨后出現(xiàn)的幾起案子雪猪,更是在濱河造成了極大的恐慌只恨,老刑警劉巖,帶你破解...
    沈念sama閱讀 206,968評(píng)論 6 482
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件纵菌,死亡現(xiàn)場(chǎng)離奇詭異休涤,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)闷堡,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,601評(píng)論 2 382
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)纵势,“玉大人管钳,你說(shuō)我怎么就攤上這事∨2埽” “怎么了醇滥?”我有些...
    開封第一講書人閱讀 153,220評(píng)論 0 344
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)阅虫。 經(jīng)常有香客問(wèn)我颓帝,道長(zhǎng),這世上最難降的妖魔是什么吕座? 我笑而不...
    開封第一講書人閱讀 55,416評(píng)論 1 279
  • 正文 為了忘掉前任瘪板,我火速辦了婚禮,結(jié)果婚禮上史侣,老公的妹妹穿的比我還像新娘魏身。我一直安慰自己箭昵,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,425評(píng)論 5 374
  • 文/花漫 我一把揭開白布正林。 她就那樣靜靜地躺著颤殴,像睡著了一般。 火紅的嫁衣襯著肌膚如雪佛呻。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 49,144評(píng)論 1 285
  • 那天景用,我揣著相機(jī)與錄音,去河邊找鬼。 笑死写妥,一個(gè)胖子當(dāng)著我的面吹牛劲弦,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播次坡,決...
    沈念sama閱讀 38,432評(píng)論 3 401
  • 文/蒼蘭香墨 我猛地睜開眼砸琅,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼轴踱!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起诱篷,我...
    開封第一講書人閱讀 37,088評(píng)論 0 261
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤棕所,失蹤者是張志新(化名)和其女友劉穎悯辙,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體针贬,經(jīng)...
    沈念sama閱讀 43,586評(píng)論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡桦他,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,028評(píng)論 2 325
  • 正文 我和宋清朗相戀三年谆棱,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了础锐。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片荧缘。...
    茶點(diǎn)故事閱讀 38,137評(píng)論 1 334
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖信姓,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情豆瘫,我是刑警寧澤菊值,帶...
    沈念sama閱讀 33,783評(píng)論 4 324
  • 正文 年R本政府宣布腻窒,位于F島的核電站,受9級(jí)特大地震影響瓦哎,放射性物質(zhì)發(fā)生泄漏柔逼。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,343評(píng)論 3 307
  • 文/蒙蒙 一犯助、第九天 我趴在偏房一處隱蔽的房頂上張望也切。 院中可真熱鬧腰湾,春花似錦、人聲如沸倒槐。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,333評(píng)論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)把跨。三九已至沼死,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間耸别,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,559評(píng)論 1 262
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留省有,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 45,595評(píng)論 2 355
  • 正文 我出身青樓狭瞎,卻偏偏與公主長(zhǎng)得像熊锭,于是被迫代替她去往敵國(guó)和親雪侥。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,901評(píng)論 2 345

推薦閱讀更多精彩內(nèi)容