機(jī)器學(xué)習(xí)基礎(chǔ)(11)條件隨機(jī)場(chǎng)的理解及BI-LSTM+CRF實(shí)戰(zhàn)

在NLP領(lǐng)域,在神經(jīng)網(wǎng)絡(luò)興起之前嚎于,條件隨機(jī)場(chǎng)(CRF)一直是作為主力模型的存在沟启,就算是在RNN系(包括BERT系)的模型興起之后,也通常會(huì)在模型的最后添加一個(gè)CRF層庵芭,以提高準(zhǔn)確率。因此雀监,CRF是所有NLPer必須要精通且掌握的一個(gè)模型双吆,本文將優(yōu)先闡述清楚與CRF有關(guān)的全部基本概念,并詳細(xì)對(duì)比HMM会前,最后獻(xiàn)上BI-LSTM+CRF的實(shí)戰(zhàn)代碼及理解好乐。相信讀完本文,將對(duì)CRF的認(rèn)識(shí)有一個(gè)新的高度瓦宜。

在閱讀本文之前蔚万,務(wù)必對(duì)概率圖模型基礎(chǔ)有一個(gè)全盤的掌握,若對(duì)此沒有信心的临庇,可以先參考我之前的一篇總結(jié)文:概率圖模型基礎(chǔ)

本文的基本目錄如下:

  1. 基礎(chǔ)知識(shí)
    1.1 CRF到底是什么反璃?
    1.2 如何用CRF建模?
    1.3 CRF與HMM的區(qū)別是什么假夺?

  2. BILSTM+CRF實(shí)戰(zhàn)
    2.1 為什么需要添加CRF層淮蜈?
    2.2 如何計(jì)算損失函數(shù)?
    2.3 實(shí)戰(zhàn)環(huán)節(jié)

------------------第一菇 - 基礎(chǔ)知識(shí)------------------

1.1 CRF到底是什么已卷?

本段主要用于講述與CRF有關(guān)的基礎(chǔ)概念礁芦。

大部分人理解CRF都會(huì)被帶到一個(gè)奇怪的誤區(qū)里面(包括我之前),因?yàn)榭偸抢斫馔炅薍MM以后悼尾,就會(huì)立馬投入到CRF的學(xué)習(xí)里面柿扣,所以就會(huì)理所當(dāng)然的認(rèn)為CRF就是HMM的升級(jí)版(確實(shí)從模型效果上可以這么理解),然后一直把HMM的各自概念往CRF上套闺魏,之后兩廂一對(duì)比未状,就會(huì)有點(diǎn)犯迷糊了,好多東西也對(duì)不上啊??~然后析桥,再一看書里的結(jié)論司草,啥?CRF竟然是判別式模型泡仗,HMM是生成式模型埋虹!這是什么鬼啦?CRF不就是HMM解除各自限制(有向圖變無向圖娩怎,箭頭指的方向更多了嗎I巍?=匾唷爬泥?)柬讨,怎么突然就變成判別式模型啦!袍啡?廢話不多說踩官,如果有此疑問的同學(xué),就說明看對(duì)文章了境输,不要心急慢慢看蔗牡,我將一一解釋清楚;而對(duì)此毫無疑問的同學(xué)嗅剖,那可以直接跳到實(shí)戰(zhàn)環(huán)節(jié)了哈辩越。

CRF真的是判別式模型!準(zhǔn)確說是窗悯,判別式無向圖模型区匣!

大家要牢記偷拔,區(qū)別判別式模型與生成式模型最基本的就是去判斷模型是對(duì)聯(lián)合分布進(jìn)行建模蒋院,還是對(duì)條件分布進(jìn)行建模。HMM中很顯然莲绰,模型是對(duì)x,y的聯(lián)合分布進(jìn)行建模(不清楚的同學(xué)欺旧,還請(qǐng)移步HMM專區(qū)),而CRF則不然蛤签,其試圖對(duì)多個(gè)變量在給定觀測(cè)值后的條件概率進(jìn)行建模辞友,因此屬于判別式模型。(各位抱著學(xué)新東西的心態(tài)來學(xué)CRF震肮,把HMM拋在腦后把)

具體展開來看称龙,若令x = \{x_1, x_2, x_3, ..., x_n\}為觀測(cè)序列,y = \{y_1, y_2, ..., y_n\}為與之對(duì)應(yīng)的標(biāo)記序列戳晌,則條件隨機(jī)場(chǎng)的目標(biāo)是構(gòu)建條件概率模型P(y | x)鲫尊。值得注意的是,標(biāo)記變量y可以是結(jié)構(gòu)型變量沦偎,即其分量之間具有某種相關(guān)性疫向。就比如在NLP領(lǐng)域中的詞性標(biāo)注任務(wù),觀測(cè)數(shù)據(jù)為單詞序列(即為x)豪嚎,標(biāo)記為相應(yīng)的詞性序列(即為y)搔驼,且其具有線性序列結(jié)構(gòu)。

1.2 如何用CRF建模侈询?

G = <V,E>表示結(jié)點(diǎn)與標(biāo)記變量y中元素一一對(duì)應(yīng)的無向圖舌涨,y_v表示與結(jié)點(diǎn)v對(duì)應(yīng)的標(biāo)記變量,n(v)表示結(jié)點(diǎn)v的鄰接結(jié)點(diǎn)扔字,若圖G的每個(gè)變量y_v都滿足馬爾可夫性(即只與其相鄰的結(jié)點(diǎn)有關(guān))泼菌,即谍肤,

P(y_v | x, y_{V\setminus \{v\}}) = P(y_v | x, y_{n(v)})

(y,x)構(gòu)成一個(gè)條件隨機(jī)場(chǎng)。而理論上來說哗伯,圖G可具有任意結(jié)構(gòu)荒揣,只要能表示標(biāo)記變量之間的條件獨(dú)立性關(guān)系即可,但在現(xiàn)實(shí)應(yīng)用中焊刹,尤其是對(duì)標(biāo)記序列建模時(shí)候系任,最常用的仍然是鏈?zhǔn)浇Y(jié)構(gòu),即“鏈?zhǔn)綏l件隨機(jī)場(chǎng)(chain-structured CRF)”虐块,也是我們接下來主要要討論的一種條件隨機(jī)場(chǎng)俩滥。

鏈?zhǔn)綏l件隨機(jī)場(chǎng)

那我們?cè)撊绾味xP(y|x)呢?

參考《機(jī)器學(xué)習(xí)》第14章的原文贺奠,其定義的方式類似馬爾可夫隨機(jī)場(chǎng)模型定義的聯(lián)合概率霜旧。條件隨機(jī)場(chǎng)使用勢(shì)函數(shù)和圖結(jié)構(gòu)上的團(tuán)來定義條件概率P(y|x)!如上圖所示儡率,該鏈?zhǔn)綏l件隨機(jī)場(chǎng)主要包含兩種關(guān)于標(biāo)記變量的團(tuán)挂据,即單個(gè)標(biāo)記變量\{y_i\}以及相鄰的標(biāo)記變量\{y_{i-1}, y_i\}。選擇合適的勢(shì)函數(shù)儿普,即可得到形如馬爾可夫隨機(jī)場(chǎng)中聯(lián)合概率的定義崎逃。

在條件隨機(jī)場(chǎng)中,通過選用指數(shù)勢(shì)函數(shù)并引入特征函數(shù)眉孩,條件概率被定義如下个绍,

P(Y|X) = \frac{1}{Z}exp\left (\sum_{j}\sum_{i=1}^{n-1}\lambda_jt_j(y_{i+1}, y_i,X,i) + \sum_k\sum_{i=1}^{n}\mu_ks_k(y_i,X,i)\right )

注意哈,這里的X指的是整一個(gè)觀測(cè)序列浪汪!而且這里定義的條件概率計(jì)算方式巴柿,就只是將觀測(cè)序列X作為條件,并不對(duì)其作任何獨(dú)立性假設(shè)K涝狻9慊帧!(這點(diǎn)很重要殃姓!也是其是判別式模型的重要依據(jù))

其中袁波,t_j(y_{i+1}, y_i,X,i)是定義在觀測(cè)序列的倆個(gè)相鄰標(biāo)記位置上的轉(zhuǎn)移特征函數(shù),用于刻畫相鄰標(biāo)記變量之間的相關(guān)關(guān)系以及觀測(cè)序列對(duì)他們的影響蜗侈。即給定觀測(cè)序列X篷牌,其標(biāo)注序列在ii-1位置上標(biāo)記的轉(zhuǎn)移概率!而特征函數(shù)的定義往往不止一種踏幻,因此會(huì)有一個(gè)下標(biāo)j代表要遍歷計(jì)算每一種特征函數(shù)的取值枷颊。

另外,s_k(y_i,X,i)是定義在觀測(cè)序列的標(biāo)記位置i上的狀態(tài)特征函數(shù),用于刻畫觀測(cè)序列對(duì)標(biāo)記變量的影響夭苗。即表示對(duì)于觀察序列Xi位置的標(biāo)記概率信卡。同理,也有多種特征函數(shù)题造,所以會(huì)有下標(biāo)k傍菇。

剩下的就比較簡單理解,\lambda_j和\mu_k都是參數(shù)界赔,Z為規(guī)范化因子丢习,用于確保上式是被正確定義的概率(可以理解為類似softmax的操作)。

總結(jié)一下上式淮悼,可以理解為如下圖咐低,

概率定義理解圖

至此,整個(gè)概率的定義想必大家已經(jīng)爛熟于心了~顯然袜腥,要運(yùn)用好條件隨機(jī)場(chǎng)见擦,最重要的就是要去定義合適的特征函數(shù)了。特征函數(shù)通常是實(shí)值函數(shù)羹令,以刻畫數(shù)據(jù)的一些很可能成立或期望成立的經(jīng)驗(yàn)特性鲤屡。因此定義特征函數(shù)的時(shí)候,一般都可以定義一組關(guān)于觀察序列的\{0,1\}二值特征b(X,i)來表示訓(xùn)練樣本中某些分布特性特恬,比如詞性標(biāo)注任務(wù)执俩,

b(X,i) = \left\{\begin{matrix} 1, & X的\ i \ 位置為某個(gè)特定的詞\\ 0, & 否則 \end{matrix}\right.

等等類似的特征函數(shù)徐钠,能定義好多出來的癌刽。因此,小小總結(jié)一下尝丐,CRF與馬爾可夫隨機(jī)場(chǎng)均使用團(tuán)上的勢(shì)函數(shù)定義概率显拜,兩者在形式上并沒有顯著區(qū)別,只不過CRF處理的是條件概率爹袁,而馬爾可夫隨機(jī)場(chǎng)處理的是聯(lián)合概率远荠。至此,整個(gè)CRF的建模已經(jīng)講明白了失息。

1.3 CRF與HMM的區(qū)別是什么譬淳?

CRF與HMM的一些基本定義的概念區(qū)別這邊在講概率圖模型和上面的基礎(chǔ)定義時(shí)已經(jīng)表述的很清楚了,本段就不繼續(xù)展開了~這里主要講一下HMM的標(biāo)注偏置問題盹兢,以及CRF為何能解決這個(gè)問題邻梆。

其實(shí)要想解釋清楚標(biāo)注偏置問題,大家只要看如下貼的一張圖即可绎秒,

標(biāo)注偏置問題

大家可以發(fā)現(xiàn)浦妄,狀態(tài)1傾向于轉(zhuǎn)移到狀態(tài)2,狀態(tài)2傾向于轉(zhuǎn)移到狀態(tài)2本身,但是實(shí)際計(jì)算得到的最大概率路徑是1>1>1>1剂娄,狀態(tài)1并沒有轉(zhuǎn)移到狀態(tài)2蠢涝!這其實(shí)是與我們的直覺相悖的~究其本質(zhì)原因,從狀態(tài)2轉(zhuǎn)移出去可能的狀態(tài)包括1阅懦,2和二,3,4耳胎,5儿咱,概率在可能的狀態(tài)上分散了,而狀態(tài)1轉(zhuǎn)移出去的可能狀態(tài)僅僅為狀態(tài)1和2场晶,概率更加集中;觳骸(大家可以拿筆算一下,是不是這么個(gè)理~加深理解)由于局部歸一化的影響诗轻,隱狀態(tài)會(huì)傾向于轉(zhuǎn)移到那些后續(xù)狀態(tài)可能更少的狀態(tài)上钳宪,以提高整體的后驗(yàn)概率!這就是標(biāo)注偏置問題扳炬!

而CRF如上所述吏颖,因?yàn)橛袣w一化因子(Z)的存在,其在全局范圍內(nèi)進(jìn)行了歸一化恨樟,枚舉了整個(gè)隱狀態(tài)序列的全部可能半醉,從而解決了局部歸一化帶來的標(biāo)注偏置問題。而這也是CRF在很多問題上劝术,表現(xiàn)比HMM優(yōu)秀的原因~

------------------第二菇 - BILSTM+CRF實(shí)戰(zhàn)------------------

介紹了這么多CRF有關(guān)的東西缩多,想必各位也是躍躍欲試,我這邊也獻(xiàn)上一份BILSTM+CRF的實(shí)戰(zhàn)解析养晋,包括對(duì)此模型架構(gòu)的理解以及源碼的解讀~

這里貼一個(gè)鏈接衬吆,是一個(gè)外國小哥寫的博客,當(dāng)初就是看這個(gè)博客明白其原理的绳泉,所以也特地在這邊貼出來逊抡,英文好的同學(xué)也可以直接看這個(gè)鏈接,我就不單獨(dú)放在參考文獻(xiàn)里了零酪。

為了便于理解冒嫡,代碼都是用Pytorch寫的,且還是以命名實(shí)體識(shí)別任務(wù)為具體例子四苇。

如果看到這篇文章的是初學(xué)者孝凌,也不用慌,就簡單理解BILSTM和CRF為一個(gè)命名實(shí)體識(shí)別模型中的兩個(gè)層蛔琅。

為了便于理解下面的圖示胎许,這邊假設(shè)我們的數(shù)據(jù)集有兩大類峻呛,人名地名,與之相對(duì)應(yīng)在我們的訓(xùn)練數(shù)據(jù)集中辜窑,有五類標(biāo)簽:

* B-Person
* I-Person
* B-Organization
* I-Organization
* O

假設(shè)句子x由5個(gè)字符組成钩述,即x = (w_0, w_1, w_2, w_3, w_4, w_5),其中[w_0, w_1]人名實(shí)體穆碎,[w_3]組織實(shí)體牙勘,其他字符的標(biāo)簽為"O"。

2.1 為什么需要添加CRF層所禀?

這里先直接貼一張BILSTM-CRF的模型結(jié)構(gòu)圖方面,方便大家理解。

BiLSTM+CRF模型結(jié)構(gòu)圖

從下往上看色徘,最下面就是輸入(字或詞向量)恭金,由于是序列模型,因此褂策,在“時(shí)間”緯度上進(jìn)行展開横腿,就可以得到如圖所示的模型表示,對(duì)應(yīng)于一個(gè)時(shí)刻斤寂,就是輸入一個(gè)字/詞向量(一般都是預(yù)先訓(xùn)練得出的)耿焊。

首先,是經(jīng)過BiLSTM的結(jié)構(gòu)單元遍搞。這個(gè)比較好理解罗侯,本質(zhì)上就是倆個(gè)LSTM層,只不過一次是正序輸入溪猿,一次是倒序輸入钩杰,然后把倆個(gè)結(jié)果進(jìn)行concact(拼接),并輸入到CRF層再愈,最后由CRF層輸出每一個(gè)詞的標(biāo)簽~如果沒有CRF層的話榜苫,傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)都會(huì)加一層softmax層用于歸一化并輸出每個(gè)標(biāo)簽概率护戳。

為了更容易理解CRF層的作用翎冲,我們還是先要理清Bi-LSTM的輸出。這里再貼一張圖媳荒,方便大家理解抗悍,

BiLSTM層的輸出.png

大家可以看到,其輸出十分簡單清晰钳枕,就是對(duì)于每一個(gè)單詞缴渊,其對(duì)應(yīng)每一個(gè)標(biāo)簽的分值(score)。因此鱼炒,就算沒有CRF層衔沼,該模型依舊有效,我們只需要挑選每一個(gè)標(biāo)簽對(duì)應(yīng)最大的分值就可以,比如指蚁,w_0就是B-Person菩佑。因此,在原有模型本身就有效的情況下凝化,我們?cè)偬砑右粚覥RF的目的肯定只有一個(gè)稍坯,即提高模型的準(zhǔn)確率

接下來搓劫,我們就要重點(diǎn)分析一下瞧哟,CRF層的作用。先上結(jié)論枪向,CRF層的主要作用是為最后預(yù)測(cè)的標(biāo)簽添加一些約束來保證預(yù)測(cè)標(biāo)簽的合理性勤揩!比如,在命名實(shí)體識(shí)別任務(wù)中秘蛔,我們可以想到的約束可以是雄可,

1)開頭的標(biāo)簽只能是B, O,而不可能是I-
2)B-Person開頭的標(biāo)簽缠犀,后面不可能接一個(gè)I-Organization
数苫。。辨液。

有了如上約束虐急,我們就能保證,最終預(yù)測(cè)生成的標(biāo)簽序列的不合理性就會(huì)大大降低滔迈,而單憑BiLSTM的輸出來預(yù)測(cè)是無法保證標(biāo)簽序列的合理性的~

2.2 損失函數(shù)的定義及計(jì)算

弄清楚了CRF層的作用以后止吁,我們就要來仔細(xì)研究研究CRF層的運(yùn)行原理了,主要從其損失函數(shù)的角度來理解燎悍。

2.2.1 CRF中的兩種分?jǐn)?shù)

在CRF層的損失函數(shù)中敬惦,有兩種形式的score(分?jǐn)?shù)),第一個(gè)就是emission score(發(fā)射分?jǐn)?shù))谈山,主要就是來自于BiLSTM層的輸出(如上圖所示)俄删,假設(shè)我們給每一個(gè)標(biāo)簽一個(gè)索引,那么第一個(gè)單詞的emission score就是奏路,[1.5, 0.9, 0.1, 0.08, 0.05]

:大家可千萬注意了畴椰,不要把這里的score和HMM里面的發(fā)射概率矩陣相混淆!兩者可是完全不一樣的鸽粉,在HMM中斜脂,每一個(gè)單詞的發(fā)射概率,僅與當(dāng)前的隱狀態(tài)層有關(guān)触机!是由隱狀態(tài)決定了當(dāng)前單詞的發(fā)射概率帚戳!而這里的發(fā)射分?jǐn)?shù)玷或,是由當(dāng)前輸入的序列,決定的當(dāng)前狀態(tài)的概率片任!這里是整一個(gè)序列哦庐椒!若大家還能記得CRF層的定義和概率模型圖(往上翻一翻),想必對(duì)此并不會(huì)驚訝蚂踊!而這约谈,也是CRF層能接在神經(jīng)網(wǎng)絡(luò)最后一層的主要原因!大家對(duì)此一定要有深刻的理解和認(rèn)識(shí)犁钟。

第二個(gè)就是transition score(轉(zhuǎn)移分?jǐn)?shù))棱诱,這個(gè)倒是跟HMM中的狀態(tài)概率轉(zhuǎn)移矩陣相類似,也很好理解涝动,也是模型中主要學(xué)習(xí)的參數(shù)霸旗!而且為了使模型更具有魯棒性荐操,我們額外增加了倆個(gè)標(biāo)簽,STARTENDSTART代表句子的開始位置佳晶,而非第一個(gè)詞拔创,同理END代表句子的結(jié)束位置威始,這里也貼一張transition score矩陣的圖短绸,方便大家理解,

transition score 圖.png

大家從圖中應(yīng)該很清楚可以看到育苟,其能學(xué)到很多約束規(guī)則较鼓!因此,該矩陣也是模型主要訓(xùn)練的一個(gè)參數(shù)违柏,一般一開始都會(huì)初始化一個(gè)概率轉(zhuǎn)移矩陣博烂,隨著訓(xùn)練的迭代,逐漸合理~因此漱竖,接下來禽篱,我們就要來看看,其損失函數(shù)是如何設(shè)計(jì)的馍惹,才能學(xué)到合理的參數(shù)~

2.2.2 損失函數(shù)的設(shè)計(jì)

先明確一點(diǎn)躺率,損失函數(shù)就是我們要優(yōu)化的目標(biāo),那對(duì)于這樣一個(gè)序列標(biāo)注問題讼积,我們肯定是希望肥照,正確的序列,是所有的可能序列中勤众,得分最高的!就如同我們作HMM解碼的時(shí)候鲤脏,利用維特比算法解碼们颜,我們返回的肯定是概率最大的那條路徑一般吕朵,那反過來,我們訓(xùn)練的時(shí)候窥突,自然希望得到的參數(shù)努溃,能使得正確路徑的概率最大~

有了上述的核心思想,我們?cè)賮硐胍幌胱栉剩绾吻蠼馕嗨啊<僭O(shè)一共有N種可能的標(biāo)簽序列組合,記第i個(gè)標(biāo)簽序列的得分為P_i称近,那么所有可能標(biāo)簽序列組合的總得分為第队,

P_{total} = P_1 + P_2 + ... + P_N = e^{S_1} +e^{S_2} +... + e^{S_N}

因此,我們可以設(shè)想出一個(gè)損失函數(shù)刨秆,就是真實(shí)序列的分?jǐn)?shù)在所有可能的序列中占比最高凳谦,即,

L = \frac{P_{real}}{P_{total}}

由這個(gè)損失函數(shù)衡未,引出2個(gè)問題尸执,

1)如何定義計(jì)算每一個(gè)序列的得分?
2)如何計(jì)算所有標(biāo)簽序列的總得分缓醋?

2.2.3 求解一個(gè)序列的得分

先來看第一個(gè)問題如失,上述曾提過倆個(gè)分?jǐn)?shù)概念,emission score和transition score送粱,因此岖常,一個(gè)序列的得分也有這倆個(gè)構(gòu)成。

S_i = EmissionScore + TransitionScore

看到這個(gè)公式葫督,大家再與CRF定義相聯(lián)系竭鞍,有木有看出點(diǎn)什么花頭?沒錯(cuò)橄镜,這個(gè)跟CRF定義的條件概率幾乎是類似的哦偎快,基本上可以理解為,BiLSTM的輸出(也就是EmissionScore)取代來狀態(tài)特征函數(shù)的位置洽胶,而我們要學(xué)習(xí)的參數(shù)也就是轉(zhuǎn)移特征函數(shù)及其權(quán)重晒夹。所有,CRF與神經(jīng)網(wǎng)絡(luò)的配套組合并不是強(qiáng)行加上或者巧合姊氓,而是有理論作強(qiáng)支撐的哈哈~

我們逐一來理解每一個(gè)分?jǐn)?shù)的計(jì)算過程丐怯,假設(shè)我們有一個(gè)正確的序列標(biāo)注為,[START, B-Person, I-Person, O, B-Organization, O, END]

那么翔横,
EmissionScore = x_{0, START} + x_{1, B-Person} + ... x_{6, END}

其中读跷,x_{index,label}就表示第index個(gè)詞被標(biāo)記為label的得分(直接是從神經(jīng)網(wǎng)絡(luò)的輸出能拿到的)。

而另一個(gè)轉(zhuǎn)移分?jǐn)?shù)即為禾唁,
TransitionScore = t_{START, B-Person} + t_{B-Person, I-Person} + ... t_{O, END}

其中效览,t_{label1, label2}就表示label1 到 label2的轉(zhuǎn)移概率无切,也就是模型要學(xué)習(xí)的參數(shù)~

2.2.4 計(jì)算所有序列總分?jǐn)?shù)的方法

至此,每一條路徑的總得分丐枉,就可以根據(jù)上面的式子很輕松的計(jì)算得出~但顯然哆键,如果真實(shí)計(jì)算也是如此遍歷操作的話,時(shí)間復(fù)雜度會(huì)吃不消的瘦锹,因此我們需要一個(gè)高效的算法來計(jì)算~

我們先簡化一下?lián)p失函數(shù)籍嘹,

損失函數(shù)的簡化.png

簡化之后,可以很輕松的看出弯院,前半部分的計(jì)算是固定的辱士,我們只需要高效的計(jì)算出后半部分即可,

e^{S_1} + e^{S_2} + ... + e^{S_N}

很明顯抽兆,這里會(huì)運(yùn)用到動(dòng)態(tài)規(guī)劃的思想(不懂動(dòng)態(tài)規(guī)劃的识补,直接去看一下維特比算法,加深理解)來進(jìn)行求解辫红,即利用w_0的總得分來推出w_1的總得分凭涂,最后以此類推,每一次計(jì)算都需要利用到上一步計(jì)算得到的結(jié)果贴妻。

這里舉一個(gè)簡單的示例切油,假設(shè)句子長度為3([w_0, w_1, w_2]),標(biāo)簽有2個(gè)([l_1, l_2])名惩,我們學(xué)到的Emission Score 矩陣如下(BiLSTM輸出)澎胡,

l_1 l_2
w_0 x_{01} x_{02}
w_1 x_{11} x_{12}
w_2 x_{21} x_{22}

學(xué)習(xí)到的Transition Score矩陣如下,

l_1 l_2
l_1 t_{11} t_{12}
l_2 t_{21} t_{22}

接下來娩鹉,將演示如何計(jì)算總得分攻谁,因?yàn)槭莿?dòng)態(tài)規(guī)劃的思想,只需演繹出其中一步即可~

針對(duì)w_0弯予,很輕松戚宦,因?yàn)闆]有轉(zhuǎn)移分?jǐn)?shù),僅有發(fā)射分?jǐn)?shù)锈嫩,因此在第一個(gè)位置的總得分即為兩種路徑的分?jǐn)?shù)總和受楼,而現(xiàn)在兩種路徑就是兩種可能性,要么就是標(biāo)簽1要么就是標(biāo)簽2呼寸,而這也是整個(gè)動(dòng)態(tài)規(guī)劃開始的初始條件~)

S_{w0} = log(e^{x_{01}} + e^{x_{02}})

接下來艳汽,我們要求在第二個(gè)位置w_1的總得分,注意我們的推導(dǎo)式子就是由w_0推出的w_1对雪,因此我們直接利用w_0計(jì)算得出的分?jǐn)?shù)即可~如下圖所演示的~

路徑求和示意圖.png

上述的動(dòng)態(tài)核心式子即為河狐,

S_{ij} = S_{i-1j} + t_{ij} + x_{ij}

最終,將在w1位置的所有狀態(tài)求和得分相加即是總路徑的得分~

有人可能會(huì)問,那你這不是只有w_1一個(gè)位置的得分了嗎甚牲?我們不是要求得總路徑得分嗎义郑?有這個(gè)疑惑的同學(xué)應(yīng)該就是還沒有領(lǐng)悟到動(dòng)態(tài)規(guī)劃的精髓蝶柿,建議自己手動(dòng)推導(dǎo)一遍丈钙,便可迎刃而解~至此,整一個(gè)所有序列路徑的求和方法交汤,我們已經(jīng)大致了解清楚了~(多提一句雏赦,預(yù)測(cè)階段的解碼思路,其實(shí)就是維特比算法芙扎,也是動(dòng)態(tài)規(guī)劃的思路星岗,十分簡單,這里就不多說了~)

2.3 實(shí)戰(zhàn)環(huán)節(jié)

上面兩節(jié)已經(jīng)把BiLSTM+CRF講的清清楚楚了~光看理論還不夠戒洼,我們要深入代碼實(shí)戰(zhàn)環(huán)節(jié)(注:此乃網(wǎng)上找的一個(gè)Pytorch的版本俏橘,個(gè)人覺得是寫的比較好的,只限用于理解理論圈浇,并非商業(yè)應(yīng)用)

我們首先導(dǎo)入相應(yīng)的包和定義一些后面要用到的輔助函數(shù)寥掐,如下,

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(1)

# some helper functions
def argmax(vec):
    # return the argmax as a python int
    # 第1維度上最大值的下標(biāo)
    # input: tensor([[2,3,4]])
    # output: 2
    _, idx = torch.max(vec,1)
    return idx.item()

def prepare_sequence(seq,to_ix):
    # 文本序列轉(zhuǎn)化為index的序列形式
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

def log_sum_exp(vec):
    #compute log sum exp in a numerically stable way for the forward algorithm
    # 用數(shù)值穩(wěn)定的方法計(jì)算正演算法的對(duì)數(shù)和exp
    # input: tensor([[2,3,4]])
    # max_score_broadcast: tensor([[4,4,4]])
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1,-1).expand(1,vec.size()[1])
    return max_score+torch.log(torch.sum(torch.exp(vec-max_score_broadcast)))

START_TAG = "<s>"
END_TAG = "<e>"

這里定義的幾個(gè)輔助函數(shù)都比較直觀磷蜀,唯獨(dú)log_sum_exp可能會(huì)對(duì)大家造成一點(diǎn)困擾召耘,但其實(shí)這是一種考慮數(shù)值穩(wěn)定性的求解辦法,具體大家參考這篇博文即可褐隆,深究一下也是好事情污它,不深究的就明白這個(gè)函數(shù)是為了求即可~

log(e^{S_1} + e^{S_2} ... e^{S_n})

我們接著看模型的定義,

# create model
class BiLSTM_CRF(nn.Module):
    def __init__(self,vocab_size, tag2ix, embedding_dim, hidden_dim):
        super(BiLSTM_CRF,self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.tag2ix = tag2ix
        self.tagset_size = len(tag2ix)

        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim//2, num_layers=1, bidirectional=True)

        # maps output of lstm to tog space
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

        # matrix of transition parameters
        # entry i, j is the score of transitioning to i from j
        # tag間的轉(zhuǎn)移矩陣庶弃,是CRF層的參數(shù)
        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))

        # these two statements enforce the constraint that we never transfer to the start tag
        # and we never transfer from the stop tag
        self.transitions.data[tag2ix[START_TAG], :] = -10000
        self.transitions.data[:, tag2ix[END_TAG]] = -10000

        self.hidden = self.init_hidden()

    def init_hidden(self):
        return (torch.randn(2, 1,self.hidden_dim//2),
                torch.randn(2, 1,self.hidden_dim//2))

    def _forward_alg(self, feats):
        # to compute partition function
        # 求歸一化項(xiàng)的值衫贬,應(yīng)用動(dòng)態(tài)歸化算法
        init_alphas = torch.full((1,self.tagset_size), -10000.)# tensor([[-10000.,-10000.,-10000.,-10000.,-10000.]])
        # START_TAG has all of the score
        init_alphas[0][self.tag2ix[START_TAG]] = 0#tensor([[-10000.,-10000.,-10000.,0,-10000.]])

        forward_var = init_alphas

        for feat in feats:
            #feat指Bi-LSTM模型每一步的輸出,大小為tagset_size
            alphas_t = []
            for next_tag in range(self.tagset_size):
                # 取其中的某個(gè)tag對(duì)應(yīng)的值進(jìn)行擴(kuò)張至(1歇攻,tagset_size)大小
                # 如tensor([3]) -> tensor([[3,3,3,3,3]])
                emit_score = feat[next_tag].view(1,-1).expand(1,self.tagset_size)
                # 增維操作
                trans_score = self.transitions[next_tag].view(1,-1)
                # 上一步的路徑和+轉(zhuǎn)移分?jǐn)?shù)+發(fā)射分?jǐn)?shù)
                next_tag_var = forward_var + trans_score + emit_score
                # log_sum_exp求和
                alphas_t.append(log_sum_exp(next_tag_var).view(1))
            # 增維
            forward_var = torch.cat(alphas_t).view(1,-1)
        terminal_var = forward_var+self.transitions[self.tag2ix[END_TAG]]
        alpha = log_sum_exp(terminal_var)
        #歸一項(xiàng)的值
        return alpha

    def _get_lstm_features(self,sentence):
        self.hidden = self.init_hidden()
        embeds = self.word_embeds(sentence).view(len(sentence),1,-1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def _score_sentence(self,feats,tags):
        # gives the score of a provides tag sequence
        # 求某一路徑的值
        score = torch.zeros(1)
        tags = torch.cat([torch.tensor([self.tag2ix[START_TAG]], dtype=torch.long), tags])
        for i , feat in enumerate(feats):
            score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        score = score + self.transitions[self.tag2ix[END_TAG], tags[-1]]
        return score

    def _viterbi_decode(self, feats):
        # 當(dāng)參數(shù)確定的時(shí)候固惯,求解最佳路徑
        backpointers = []

        init_vars = torch.full((1,self.tagset_size),-10000.)# tensor([[-10000.,-10000.,-10000.,-10000.,-10000.]])
        init_vars[0][self.tag2ix[START_TAG]] = 0#tensor([[-10000.,-10000.,-10000.,0,-10000.]])

        forward_var = init_vars
        for feat in feats:
            bptrs_t = [] # holds the back pointers for this step
            viterbivars_t = [] # holds the viterbi variables for this step

            for next_tag in range(self.tagset_size):
                next_tag_var = forward_var + self.transitions[next_tag]
                best_tag_id = argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)

        # Transition to STOP_TAG
        terminal_var = forward_var + self.transitions[self.tag2ix[END_TAG]]
        best_tag_id = argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]

        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # Pop off the start tag (we dont want to return that to the caller)
        start = best_path.pop()
        assert start == self.tag2ix[START_TAG]  # Sanity check
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, sentence, tags):
        # 由lstm層計(jì)算得的每一時(shí)刻屬于某一tag的值
        feats = self._get_lstm_features(sentence)
        # 歸一項(xiàng)的值
        forward_score = self._forward_alg(feats)
        # 正確路徑的值
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score# -(正確路徑的分值  -  歸一項(xiàng)的值)

    def forward(self, sentence):  # dont confuse this with _forward_alg above.
        # Get the emission scores from the BiLSTM
        lstm_feats = self._get_lstm_features(sentence)

        # Find the best path, given the features.
        score, tag_seq = self._viterbi_decode(lstm_feats)
        return score, tag_seq

上面的注釋應(yīng)該說是很詳細(xì)了,一開始的初始化定義也都是Pytorch的常規(guī)寫法(簡書的代碼顯示的略詭異掉伏,大家將就看看吧)~LSTM層也是直接掉的nn里面的缝呕,只有CRF層是自己手?jǐn)]上來的~所以,大家重點(diǎn)關(guān)注一下_forward_alg這個(gè)函數(shù)斧散,就是我們上面講的求解路徑總得分的函數(shù)~其中feats就是序列步長供常,自然是要順序遍歷每一個(gè)feat,其中每一個(gè)feat又要遍歷每一種tag的情況鸡捐,利用forward_var記錄每一個(gè)路徑的總得分(實(shí)時(shí)更新)栈暇,最后在求和即可!應(yīng)該說看懂了上面的解釋的同學(xué)箍镜,在看這個(gè)代碼源祈,簡直是太簡單了哈哈~其他的函數(shù)也沒啥好特地強(qiáng)調(diào)的煎源,大家掃一眼明白即可,對(duì)解碼不清楚的香缺,直接看代碼也難手销,手動(dòng)推演一遍,理解的更快~

最后图张,我們?cè)賮砜匆幌轮骱瘮?shù)锋拖,

if __name__ == "__main__":
    EMBEDDING_DIM = 5
    HIDDEN_DIM = 4

    # Make up some training data
    training_data = [(
        "the wall street journal reported today that apple corporation made money".split(),
        "B I I I O O O B I O O".split()
    ), (
        "georgia tech is a university in georgia".split(),
        "B I O O O O B".split()
    )]

    word2ix = {}
    for sentence, tags in training_data:
        for word in sentence:
            if word not in word2ix:
                word2ix[word] = len(word2ix)

    tag2ix = {"B": 0, "I": 1, "O": 2, START_TAG: 3, END_TAG: 4}

    model = BiLSTM_CRF(len(word2ix), tag2ix, EMBEDDING_DIM, HIDDEN_DIM)
    optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

    # Check predictions before training
    # 輸出訓(xùn)練前的預(yù)測(cè)序列
    with torch.no_grad():
        precheck_sent = prepare_sequence(training_data[0][0], word2ix)
        precheck_tags = torch.tensor([tag2ix[t] for t in training_data[0][1]], dtype=torch.long)
        print(model(precheck_sent))

    # Make sure prepare_sequence from earlier in the LSTM section is loaded
    for epoch in range(300):  # again, normally you would NOT do 300 epochs, it is toy data
        for sentence, tags in training_data:
            # Step 1. Remember that Pytorch accumulates gradients.
            # We need to clear them out before each instance
            model.zero_grad()

            # Step 2. Get our inputs ready for the network, that is,
            # turn them into Tensors of word indices.
            sentence_in = prepare_sequence(sentence, word2ix)
            targets = torch.tensor([tag2ix[t] for t in tags], dtype=torch.long)

            # Step 3. Run our forward pass.
            loss = model.neg_log_likelihood(sentence_in, targets)

            # Step 4. Compute the loss, gradients, and update the parameters by
            # calling optimizer.step()
            loss.backward()
            optimizer.step()

    # Check predictions after training
    with torch.no_grad():
        precheck_sent = prepare_sequence(training_data[0][0], word2ix)
        print(model(precheck_sent))

    # 輸出結(jié)果
    # (tensor(-9996.9365), [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
    # (tensor(-9973.2725), [0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2])

也是比較常規(guī)的寫法,還帶了示例~大家應(yīng)該很容易理解的祸轮!

至此兽埃,整一套跟CRF有關(guān)的知識(shí)點(diǎn)和代碼解釋已經(jīng)全部弄清楚了。簡單總結(jié)一下本文适袜,先是詳細(xì)解釋了一下與CRF有關(guān)的一些誤區(qū)和知識(shí)點(diǎn)柄错,接著展示了與CRF有關(guān)的用法和計(jì)算損失函數(shù)的方法,最后獻(xiàn)上了詳細(xì)的代碼解讀~希望大家讀完本文后對(duì)CRF的一些概念會(huì)有一個(gè)全新的認(rèn)識(shí)苦酱。有說的不對(duì)的地方也請(qǐng)大家指出售貌,多多交流,大家一起進(jìn)步~??

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末躏啰,一起剝皮案震驚了整個(gè)濱河市趁矾,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌给僵,老刑警劉巖毫捣,帶你破解...
    沈念sama閱讀 217,277評(píng)論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異帝际,居然都是意外死亡蔓同,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,689評(píng)論 3 393
  • 文/潘曉璐 我一進(jìn)店門蹲诀,熙熙樓的掌柜王于貴愁眉苦臉地迎上來斑粱,“玉大人,你說我怎么就攤上這事脯爪≡虮保” “怎么了?”我有些...
    開封第一講書人閱讀 163,624評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵痕慢,是天一觀的道長尚揣。 經(jīng)常有香客問我,道長掖举,這世上最難降的妖魔是什么快骗? 我笑而不...
    開封第一講書人閱讀 58,356評(píng)論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮,結(jié)果婚禮上方篮,老公的妹妹穿的比我還像新娘名秀。我一直安慰自己,他們只是感情好藕溅,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,402評(píng)論 6 392
  • 文/花漫 我一把揭開白布匕得。 她就那樣靜靜地躺著,像睡著了一般蜈垮。 火紅的嫁衣襯著肌膚如雪耗跛。 梳的紋絲不亂的頭發(fā)上裕照,一...
    開封第一講書人閱讀 51,292評(píng)論 1 301
  • 那天攒发,我揣著相機(jī)與錄音,去河邊找鬼晋南。 笑死惠猿,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的负间。 我是一名探鬼主播偶妖,決...
    沈念sama閱讀 40,135評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼政溃!你這毒婦竟也來了趾访?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 38,992評(píng)論 0 275
  • 序言:老撾萬榮一對(duì)情侶失蹤董虱,失蹤者是張志新(化名)和其女友劉穎扼鞋,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體愤诱,經(jīng)...
    沈念sama閱讀 45,429評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡云头,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,636評(píng)論 3 334
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了淫半。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片溃槐。...
    茶點(diǎn)故事閱讀 39,785評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖科吭,靈堂內(nèi)的尸體忽然破棺而出昏滴,到底是詐尸還是另有隱情,我是刑警寧澤对人,帶...
    沈念sama閱讀 35,492評(píng)論 5 345
  • 正文 年R本政府宣布谣殊,位于F島的核電站,受9級(jí)特大地震影響规伐,放射性物質(zhì)發(fā)生泄漏蟹倾。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,092評(píng)論 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望鲜棠。 院中可真熱鬧肌厨,春花似錦、人聲如沸豁陆。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,723評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽盒音。三九已至表鳍,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間祥诽,已是汗流浹背譬圣。 一陣腳步聲響...
    開封第一講書人閱讀 32,858評(píng)論 1 269
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留雄坪,地道東北人厘熟。 一個(gè)月前我還...
    沈念sama閱讀 47,891評(píng)論 2 370
  • 正文 我出身青樓,卻偏偏與公主長得像维哈,于是被迫代替她去往敵國和親绳姨。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,713評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容