Paper: Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context
重點關(guān)注論文中的相對位置編碼及提高融合了相對位置信息的attention score的計算效率的部分患亿。
Abstract
Transformer具有學(xué)習(xí)長依賴的能力,但受限于語言模型固定長度上下文的限定押逼。本文提出的Transformer-XL神經(jīng)網(wǎng)絡(luò)架構(gòu)可以在不打破時序關(guān)系的前提下突破固定長度上下文的限制步藕,學(xué)習(xí)文本間的依賴關(guān)系。模型具體包括一個片段級別的循環(huán)機制和一個全新的位置編碼方式挑格。該架構(gòu)不僅可以學(xué)習(xí)文本中的長依賴關(guān)系咙冗,還解決了上下文碎片問題(context fragmentation problem)。最終漂彤,Transformer-XL可以習(xí)得相較RNN長80%雾消、相較原始Transformer長450%的依賴關(guān)系,并且在評估時的速度最多比原始Transformer快1800多倍挫望。作者還提供了Transformer-XL的Tensorflow和PyTorch的實現(xiàn)版本立润。
Introduction
本文關(guān)注的是基于神經(jīng)網(wǎng)絡(luò)的架構(gòu)使得模型具備為序列數(shù)據(jù)的長依賴關(guān)系進行建模的能力的問題。RNN由于梯度消失和梯度爆炸的問題難以優(yōu)化媳板,即使是引入了門機制的LSTM和梯度裁剪技術(shù)桑腮,以上問題仍舊未能得到完全解決,同時一般而言蛉幸,LSTM語言模型平均使用長度為200的上下文單詞破讨,因此尚有一定的提升空間。
另一方面奕纫,可以直接捕捉兩個距離較遠的單詞之間關(guān)系的attention機制或許有助于實現(xiàn)長依賴關(guān)系的學(xué)習(xí)提陶。相關(guān)的研究有很多,但受限于固定長度的上下文匹层,模型無法捕捉那些長度超過預(yù)定義的上文長度的文本依賴關(guān)系隙笆。還有方法在不考慮句子或其它語義邊界的情況下選擇連續(xù)字符構(gòu)成長度固定的片段(fixed-length segment)進行建模,但這樣的模型在前幾步的預(yù)測中缺乏必要的上下文信息,繼而帶來優(yōu)化及性能方面的問題仲器,本文將該方面的問題稱為上下文碎片問題(context fragmentation)煤率。
為了解決上述問題,本文提出了名為Transformer-XL的架構(gòu)乏冀,其中XL意為extra long,該機制將循環(huán)的概念引入了深度自注意力網(wǎng)絡(luò)中洋只。具體而言辆沦,對于每一個新片段(segment),在計算其隱藏狀態(tài)時會復(fù)用之前片段的隱藏狀態(tài)识虚,而非從頭開始計算肢扯。復(fù)用的隱藏狀態(tài)作為當(dāng)前片段的記憶單元,從而建立起了片段之間的循環(huán)連接担锤。這樣的做法由于信息可以在片段的循環(huán)連接之間得以傳播使得為特別長的依賴關(guān)系建模成為可能蔚晨,同時也解決了上下文碎片的問題。此外肛循。為了在復(fù)用狀態(tài)時不會引發(fā)時許混淆的問題铭腕,本文展現(xiàn)了使用相對位置編碼的必要性。一個簡單但更高效的相對位置編碼公式也有利于那些長度超過訓(xùn)練時的注意力長度的內(nèi)容學(xué)習(xí)上的泛化多糠。
Transformer-XL是首個同時在字符級別(character-level)和詞級別(word-level)語言模型上超越RNN模型的自注意力模型累舷。
Related Work
語言模型領(lǐng)域近年來的發(fā)展有很多,如設(shè)計更好的編碼上下文的新架構(gòu)夹孔、改進的正則化或優(yōu)化算法被盈、softmax計算的加速以及對輸出分布的優(yōu)化等等。
為了捕捉語言模型中的長范圍的上下文搭伤,部分工作直接將更長的上下文表示作為附加輸入送入神經(jīng)網(wǎng)絡(luò)中≈辉酰現(xiàn)有的工作包括人為定義上下文表示以及從數(shù)據(jù)中學(xué)習(xí)篇章級別的主題等等。
更廣泛而言怜俐,在一般的序列建模問題中身堡,如何捕捉長依賴關(guān)系一直是一個長期存在的研究問題。由于LSTM的普適性佑菩,大量工作關(guān)注解決其梯度消失的問題盾沫,包括更好的參數(shù)初始化、附加的損失計算殿漠、增強的記憶單元結(jié)構(gòu)以及一些修改RNN結(jié)構(gòu)以便于優(yōu)化的方法等赴精。與這些做法不同的是,本文的工作基于Transformer架構(gòu)绞幌,同時證明了學(xué)習(xí)長依賴關(guān)系的能力對現(xiàn)實任務(wù)中的語言建模的優(yōu)勢蕾哟。
Model
給定token的語料庫:,語言模型的任務(wù)是估計聯(lián)合概率√啡罚基于因式分解帘营,該問題簡化為估計各條件因子。本工作采用標(biāo)準(zhǔn)的神經(jīng)網(wǎng)絡(luò)方法為各條件概率建模逐哈。具體而言芬迄,以一個可訓(xùn)練的神經(jīng)網(wǎng)絡(luò)將上下文編碼為一個固定大小的隱藏狀態(tài),繼而乘上詞嵌入以獲得其邏輯表示昂秃,該表示將送入softmax方程產(chǎn)生下一個token的概率分布禀梳。
Vanilla Transformer Language Models
將Transformer或自注意力機制用于語言模型的一個可行方案是,將整個語料劃分為若干較短的可管理的片段肠骆,同時忽略之前片段的上下文信息算途,僅在各片段內(nèi)部訓(xùn)練模型。本文將該模型稱為Vanilla Model蚀腿,其過程如Figure 1所示嘴瓤。
在該模型下的訓(xùn)練過程中,信息無法在片段間流動莉钙。使用固定長度的上下文存在兩點關(guān)鍵限制:①可能獲取的依賴長度上限由片段長度決定廓脆。而在字符級別的語言模型中,片段長度需要有好幾百胆胰,即使自注意力機制能在一定程度上緩解RNN梯度消失的問題狞贱,但該模型認為充分利用自注意力機制的這一優(yōu)化優(yōu)勢。②盡管可以通過padding延續(xù)文本的句子或其他語義邊界特性蜀涨,但事實上為了提高效率瞎嬉,簡單將長文本劃分成固定長度的片段已然成為標(biāo)準(zhǔn)做法,繼而引發(fā)了上文提及的上下文碎片問題厚柳。
在評估階段的每一步中氧枣,the vanilla model依舊采用訓(xùn)練階段相同的片段長度,但僅對最后一個位置進行預(yù)測别垮。而在下一步中便监,片段將向右平移一個位置,再重新從頭開始處理整個片段進行當(dāng)前片段最后位置的預(yù)測碳想。如圖所示烧董,該過程確保每一次預(yù)測用到了訓(xùn)練階段能看到的最長的上下文,同時緩解了訓(xùn)練階段的上下文碎片問題胧奔。但相應(yīng)的評估階段的計算成本也有所提高逊移。這一點在本文提出的框架中得以解決。
Segment-Level Recurrence with State Reuse
為了解決使用固定長度上下文帶來的限制龙填,本文提出在Transformer架構(gòu)中引入循環(huán)機制胳泉,其過程如Figure 2所示拐叉。
在訓(xùn)練階段,前一個片段計算得到的隱藏狀態(tài)序列將被固定(fixed)并緩存起來(cached)扇商,在模型處理接下來的一個新片段時凤瘦,剛剛緩存的隱藏層序列將作為一個擴展上下文進行復(fù)用。盡管梯度仍保留于每個片段內(nèi)部案铺,但這個附加的輸入使得網(wǎng)絡(luò)可以處理歷史信息蔬芥,繼而使得模型可以對長依賴建模,同時避免了上下文碎片問題控汉。該過程以公式化形式將表述如下坝茎,將兩個長度為的連續(xù)片段分別表示如下:和;將第個片段的第層的隱藏狀態(tài)序列表示為暇番,其中是隱藏層狀態(tài)維度。然后思喊,將第個片段的第層的隱藏狀態(tài)序列的計算過程如下:
其中函數(shù)表示停止梯度計算馒过,表示兩個隱藏層序列沿length維度的拼接莱衩,表示模型參數(shù)。與標(biāo)準(zhǔn)Transformer相比,關(guān)鍵的不同點在于汁雷,key 和value 基于擴展后上下文得來,因此可從之前的片段中獲取信息起便。Figure 2(a)中由綠色路徑標(biāo)注了本文的特殊設(shè)計典挑。
通過在每兩個連續(xù)片段之間應(yīng)用循環(huán)機制,建立起了隱藏層片段級別的循環(huán)纲辽。因此有效的上下文信息將不僅僅在兩個片段內(nèi)被利用颜武。然而需要注意的是,和之間的循環(huán)依賴每一片段將向下移動一層拖吼,這與傳統(tǒng)的基于RNN的語言模型中的同層循環(huán)是有所不同的鳞上。最終,最長依賴長度隨層數(shù)和片段長度呈線性增長吊档,即篙议,如Figure 2(b)的陰影部分所示。這與一訓(xùn)練基于RNN的語言模型采用的方法truncated BPTT類似怠硼。但與其不同的是鬼贱,本文提出的方法緩存的是隱藏狀態(tài)序列而非上一序列,同時還應(yīng)該結(jié)合后文將介紹的相對位置編碼技術(shù)一起使用香璃。
該架構(gòu)除了能利用更長的上下文以及解決上下文碎片問題外这难,循環(huán)機制還使得評估時的效率顯著提高。具體而言增显,在評估階段雁佳,之前片段的表示可以同the vanilla model一樣進行重用脐帝。
最后需要注意的是,循環(huán)機制不必僅限于鄰接的前一個片段糖权。理論上堵腹,在GPU內(nèi)存允許的情況下,可以緩存盡可能多的之前的片段星澳,并在處理當(dāng)前片段時復(fù)用所有的這些片段作為額外的上下文疚顷。因此,可以緩存一個預(yù)定義的長度——個舊的隱藏狀態(tài)禁偎,并將它們表示為記憶單元腿堤。實驗中,本文將設(shè)為等同于片段長度的大小如暖,并在評估中笆檀,將其值加倍增長。
Relative Positional Encoding
上述方案存在的問題是重用隱藏層狀態(tài)的順序問題盒至,即在重時是如何保證位置信息的連貫問題(the positional information coherent)酗洒。在標(biāo)準(zhǔn)Transformer中,序列順序信息是通過一個位置編碼集合提供的枷遂,其中第行表示某一片段中第個絕對位置樱衷,表示建模的最大長度;隨后輸入將有文本的詞嵌入表示和位置編碼相加得來酒唉。倘若將這樣的位置編碼方式直接運用到本文的循環(huán)機制中矩桂,隱藏狀態(tài)序列的計算如下:
其中表示序列的詞嵌入,表示一個轉(zhuǎn)換方程痪伦。需要注意的是侄榴,和用到了同樣的位置編碼。因此流妻,對于任意的牲蜀,模型沒有用于分辨和位置區(qū)別的信息,繼而造成嚴(yán)重的性能損失绅这。
為了避免上述問題涣达,最基礎(chǔ)的想法是在隱藏狀態(tài)中僅編碼相對位置信息。從概念上來說证薇,位置編碼給予了模型如何匯聚信息的時序線索度苔。出于同樣的目的,可以在每一層中將類似的信息映射到attention分值上浑度。更重要的是寇窑,以相對位置定義時序偏差是更直觀且有利于泛化的。例如箩张,當(dāng)一個query向量在key向量上計算注意力時甩骏,無需了解每一個key向量的絕對位置窗市,了解每一個key向量和自身的相對位置即可反映片段內(nèi)的時序關(guān)系。在實踐上饮笛,可以創(chuàng)建一個相對位置編碼集合咨察,其中第行表示i和其它位置的相對距離。通過將相對位置動態(tài)地映射到注意力分值中福青,query向量可以輕松地根據(jù)不同的距離區(qū)分和的表示摄狱,繼而使得狀態(tài)重用機制可行。與此同時无午,由于絕對位置信息可以遞歸地從相對距離中獲取媒役,時序信息并未丟失。
過去宪迟,相對位置編碼的思想已被用于機器翻譯和音樂生成任務(wù)中酣衷。這里,本文提出一種不同的相對位置編碼新形式的推導(dǎo)次泽,不僅與其絕對位置有一對一的對應(yīng)關(guān)系鸥诽,而且具有更好的泛化能力。首先箕憾,在標(biāo)準(zhǔn)Transformer中,同一片段內(nèi)的query向量和key向量之間的注意力得分計算可做如下分解:
根據(jù)僅依賴相對位置信息的思想拳昌,將上式中的四項重新參數(shù)化如下:
- 將所有在公式項和中出現(xiàn)的計算key向量用到的絕對位置編碼替換為對應(yīng)的絕對位置編碼袭异。這從本質(zhì)上反映了只考慮相對位置的前提。需要注意的是炬藤,是沒有可訓(xùn)練參數(shù)的正弦編碼矩陣御铃。
- 引入了可訓(xùn)練參數(shù)來替代公式項中的query向量。在采用相對位置編碼的情況下沈矿,無論是哪個查詢位置上真,此處的query向量應(yīng)是一致的,其位置信息由相對位置編碼反映羹膳,因此此處采用一個可訓(xùn)練的參數(shù)表示睡互。出于同樣的原因,將公式項中的替換為可訓(xùn)練參數(shù)陵像。
- 將兩個權(quán)重矩陣和區(qū)別開來就珠,以分別表示基于內(nèi)容的key向量和基于位置的key向量。
在這樣全新的參數(shù)化表示下醒颖,每一項都具備一個直觀的含義:表示內(nèi)容上的關(guān)聯(lián)(content-based addressing)妻怎;捕捉了依賴內(nèi)容的位置偏差;控制著全局內(nèi)容偏差泞歉;編碼了全局位置偏差逼侦。
綜上匿辩,帶有單個注意力頭的層的Transformer-XL的計算過程如下:對于:
初始化為詞嵌入序列。此外榛丢,計算的效率隨序列長度呈二次方變化铲球。下面將介紹一個對的高效計算方式,其效率隨序列長度呈線性變化涕滋。
Efficient Computation of the Attention with Relative Positional Embedding
倘若以基本方法計算考慮相對位置的attention score睬辐,其中對于所有對的計算呈二次方的消耗。因此宾肺,本文提出一種線性消耗的計算方法溯饵。已知,相對距離的只能是到的整數(shù)值锨用,M是記憶長度丰刊,L是片段長度。令增拥,則:
令啄巧,則
接下來,對于attention score中的(b)項掌栅,收集所有的對秩仆,形成如下一個的矩陣:
接下來,定義一個新矩陣:
將和比較可以發(fā)現(xiàn)將的第\mathbf{B}i$行猾封。
類似的澄耍,對于attention score中的項,收集所有的對晌缘,形成如下一個的矩陣:
同樣齐莲,可以定義:
此時的每一行可由向左平移得來。
上述方法中磷箕,平移的消耗較少选酗,主要的計算量在于和的矩陣乘法上,從而效率得以提升岳枷。