Megatron-LM:Transformer模型專用分布式張量模型并行方法

論文標(biāo)題:Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
論文鏈接:https://arxiv.org/abs/1909.08053
論文來源:NVIDIA

一份蝴、概述

隨著自然語(yǔ)言處理領(lǐng)域預(yù)訓(xùn)練語(yǔ)言模型的規(guī)模變得越來越大,它們超過了現(xiàn)代處理器的內(nèi)存限制,需要額外的內(nèi)存管理技術(shù),如激活檢查點(diǎn)(activation checkpointing)缸托。一些廣泛使用的優(yōu)化算法如Adam需要額外的內(nèi)存來存儲(chǔ)其中的動(dòng)量和其他優(yōu)化器狀態(tài)场斑,這降低了可以有效訓(xùn)練的模型大小。幾種模型并行方法通過劃分模型來克服這個(gè)限制赖草,這樣權(quán)重及其相關(guān)的優(yōu)化器狀態(tài)就不需要同時(shí)駐留在處理器上杆逗。例如乡翅,GPipe和Mesh-Tensorflow提供了不同種類的模型并行框架。但是髓迎,它們需要重寫模型峦朗,并依賴于仍在開發(fā)中的自定義編譯器和框架。

在這項(xiàng)工作中排龄,我們使用簡(jiǎn)單高效的層內(nèi)模型并行(intra-layer model-parallelism)來實(shí)現(xiàn)模型并行。我們利用transformer基礎(chǔ)語(yǔ)言模型中的固有結(jié)構(gòu)來實(shí)現(xiàn)一個(gè)簡(jiǎn)單的模型并行翎朱,它可以在PyTorch中高效訓(xùn)練橄维,而無需自定義C++代碼或編譯器。這種方法與GPipe等基于流水線的模型并行方法是正交的(可以同時(shí)使用拴曲,相互獨(dú)立而不沖突)争舞。

為了證明我們方法的可擴(kuò)展性,我們建立了一個(gè)baseline澈灼,在單個(gè)NVIDIA V100 32GB GPU上訓(xùn)練了一個(gè)12億參數(shù)的模型竞川,保持39 TeraFLOPs的計(jì)算速度。這是DGX-2H服務(wù)器中單GPU配置的理論峰值浮點(diǎn)運(yùn)算能力的30%叁熔,因此是一個(gè)很強(qiáng)的基線委乌。通過在512個(gè)GPU上以8路模型并行將模型擴(kuò)展到83億參數(shù),我們實(shí)現(xiàn)了高達(dá)每秒15.1PetaFLOPs的持續(xù)計(jì)算速度荣回。與單GPU情況相比遭贸,這代表了76%的擴(kuò)展效率。下圖顯示了更詳細(xì)的擴(kuò)展結(jié)果心软。

拓展效率

為了分析模型大小擴(kuò)展對(duì)準(zhǔn)確率的影響壕吹,我們訓(xùn)練了自回歸的GPT-2語(yǔ)言模型以及自編碼的BERT雙向transformer著蛙,并在幾個(gè)下游任務(wù)上對(duì)其進(jìn)行評(píng)估。我們發(fā)現(xiàn)耳贬,隨著模型大小的增加踏堡,現(xiàn)有的BERT架構(gòu)會(huì)導(dǎo)致模型退化。我們通過重新排列transformer層中的層標(biāo)準(zhǔn)化和殘差連接來克服這一挑戰(zhàn)咒劲,結(jié)果表明暂吉,進(jìn)行這一改變之后,下游任務(wù)的開發(fā)集結(jié)果隨著模型大小的增加單調(diào)提升缎患。此外慕的,我們的模型在WikiText103、LAMBADA上的閉包式預(yù)測(cè)準(zhǔn)確率以及RACE閱讀理解數(shù)據(jù)集上都取得了SOTA的測(cè)試集結(jié)果挤渔。

二肮街、背景

有兩種主要的范式可以將深度神經(jīng)網(wǎng)絡(luò)的訓(xùn)練擴(kuò)展到多個(gè)硬件加速器中:(1)數(shù)據(jù)并行,其中minibatch的訓(xùn)練被劃分到多個(gè)worker中;(2)模型并行判导,其將模型的內(nèi)存使用和計(jì)算分布到多個(gè)worker中嫉父。通過按可用worker的數(shù)量成比例增加minibatch大小(即weak scaling),可以觀察到訓(xùn)練數(shù)據(jù)吞吐量的近乎線性的擴(kuò)展眼刃。但是绕辖,大批量訓(xùn)練會(huì)使優(yōu)化過程更復(fù)雜,可能導(dǎo)致準(zhǔn)確率降低或收斂時(shí)間延長(zhǎng)擂红,反而抵消了增加訓(xùn)練吞吐量帶來的好處仪际。進(jìn)一步的研究開發(fā)了各種技術(shù)來緩解這些影響,降低大型神經(jīng)網(wǎng)絡(luò)的訓(xùn)練時(shí)間昵骤。為了進(jìn)一步擴(kuò)展訓(xùn)練規(guī)模树碱,一些并行工作將數(shù)據(jù)并行和激活檢查點(diǎn)結(jié)合起來:在前向傳播中重新計(jì)算而不是存儲(chǔ)激活,以減少內(nèi)存需求变秦。

然而成榜,這些技術(shù)在其可以處理的問題大小方面存在一個(gè)根本的局限:模型必須完全能夠在一個(gè)worker上進(jìn)行處理。隨著BERT和GPT-2等語(yǔ)言模型大小和復(fù)雜度的增加蹦玫,神經(jīng)網(wǎng)絡(luò)已經(jīng)接近了現(xiàn)代硬件加速器的內(nèi)存容量赎婚。這個(gè)問題的一個(gè)解決方案是采用參數(shù)共享來減少模型的內(nèi)存占用,但這限制了模型的總體容量樱溉。我們的方法是利用模型并行將模型劃分到多個(gè)加速器上挣输。這不僅減輕了內(nèi)存壓力,而且獨(dú)立于微批量大小增加了并行度饺窿。

在模型并行中歧焦,還有兩種進(jìn)一步的范式:逐層流水線并行(layer-wise pipeline parallelism)和更通用的分布式張量計(jì)算(distributed tensor computation)。在流水線模型并行中,一組操作首先在一個(gè)設(shè)備上執(zhí)行绢馍,然后將輸出傳遞到流水線中的下一個(gè)設(shè)備向瓷,在下一個(gè)設(shè)備上執(zhí)行不同的另一組操作。一些方法與流水線并行結(jié)合使用參數(shù)服務(wù)器舰涌。然而猖任,這些方法存在一致性問題。TensorFlow中的GPipe框架通過使用同步梯度下降來解決這種一致性問題瓷耙。這種方法需要額外的邏輯來處理這些通信和計(jì)算操作的高效流水線朱躺,并受到減少效率的流水線bubble的影響,或者對(duì)優(yōu)化器本身的更改會(huì)影響準(zhǔn)確性搁痛。

分布式張量計(jì)算是一種正交的长搀、更通用的方法,它將張量操作劃分到多個(gè)設(shè)備上以加速計(jì)算或增加模型大小鸡典。編排這種并行計(jì)算的深度學(xué)習(xí)框架FlexFlow提供了一種選擇最佳并行策略的方法源请。最近,Mesh-TensorFlow在TensorFlow中引入了一種指定分布式張量計(jì)算的通用類的語(yǔ)言彻况。并行維度由終端用戶在這一語(yǔ)言中指定谁尸,生成的圖由適當(dāng)?shù)募w原語(yǔ)編譯。我們利用類似于Mesh-TensorFlow中的見解纽甘,并利用transformer注意力頭的并行計(jì)算來并行化我們的transformer模型良蛮。但是,我們沒有實(shí)現(xiàn)一個(gè)用于模型并行的框架和編譯器悍赢,而是僅對(duì)現(xiàn)有的PyTorch transformer實(shí)現(xiàn)進(jìn)行了一些有針對(duì)性的修改决瞳。我們的方法很簡(jiǎn)單,不需要任何新的編譯器或代碼重寫泽裳,可以通過插入幾個(gè)簡(jiǎn)單的基元來完全實(shí)現(xiàn)瞒斩,如下一節(jié)所述。

三涮总、方法

我們利用transformer網(wǎng)絡(luò)的結(jié)構(gòu)來創(chuàng)建一個(gè)簡(jiǎn)單的模型并行實(shí)現(xiàn),只需要添加幾個(gè)同步原語(yǔ)祷舀。一個(gè)transformer層由一個(gè)自注意力塊和一個(gè)兩層多層感知器(MLP)組成瀑梗,如下圖所示。我們?cè)谶@兩個(gè)塊中分別引入模型并行裳扯。

模型架構(gòu)

首先詳細(xì)說明MLP塊抛丽。該塊的第一部分是一個(gè)GEMM(General Matrix Multiplication),后面是一個(gè)GELU非線性層:

Y = \mathrm{GeLU}(XA)

并行化GEMM的一種方法是按行切分權(quán)重矩陣A饰豺,按列切分輸入X:

X = [X_1, X_2]亿鲜, \quad A = \begin{bmatrix} A_1 \\ A_2 \end{bmatrix}

這將得到Y = \mathrm{GeLU}(X_1A_1 + X_2A_2)。由于GeLU是非線性函數(shù),\mathrm{GeLU}(X_1A_1 + X_2A_2) \neq \mathrm{GeLU}(X_1A_1) + \mathrm{GeLU}(X_2A_2)蒿柳,這種切分方式需要在GeLU函數(shù)之前進(jìn)行同步饶套。

另一種切分方法是按列切分A = [A_1, A_2]。這種切分方式允許獨(dú)立地對(duì)每個(gè)切分后的GEMM的輸出應(yīng)用GeLU非線性函數(shù):

[Y_1, Y_2] = [\mathrm{GeLU}(XA_1), \mathrm{GeLU}(XA_2)]

這種方法的優(yōu)點(diǎn)是去除了一個(gè)同步點(diǎn)垒探。因此妓蛮,我們采用這種列并行方式切分第一個(gè)GEMM,直接將GeLU層的輸出作為第二個(gè)GEMM(這個(gè)GEMM以按行切分的方式并行)的輸入圾叼,而不需要任何通信蛤克,如下圖(a)所示。

并行方法

第二個(gè)GEMM的輸出通過dropout層之前先跨GPU進(jìn)行reduce夷蚊。這種方法將MLP塊中的兩個(gè)GEMM切分到不同的GPU上构挤,前向傳播中只需要一個(gè)all-reduce操作(g操作符),反向傳播中也只需要一個(gè)all-reduce(f操作符)惕鼓。這兩個(gè)操作符互為共軛筋现,可以通過PyTorch中的幾行代碼實(shí)現(xiàn)。例如呜笑,f操作符的實(shí)現(xiàn)如下:

class f(torch.autograd.Function):
    def forward(ctx, x):
        return x
    
    def backward(ctx, gradient):
        all_reduce(gradient) 
        return gradient

如上圖3(b)所示夫否,對(duì)于自注意力塊,我們利用了多頭注意力操作中固有的并行性叫胁,以列并行的方式切分key(K)凰慈、query(Q)和value(V)對(duì)應(yīng)的GEMM,這樣每個(gè)注意力頭對(duì)應(yīng)的矩陣乘法在一個(gè)GPU上本地計(jì)算驼鹅。這允許我們?cè)贕PU之間切分每個(gè)注意力頭的參數(shù)和工作量微谓,并且不需要任何直接的通信就可以完成自注意力。在自注意力之后的輸出線性層的GEMM(自注意力之后)沿其行進(jìn)行并行化输钩,直接獲取并行注意力層的輸出豺型,而不需要GPU之間的通信。MLP和自注意力這兩種塊的并行方法都融合了兩組GEMM买乃,消除了中間的同步點(diǎn)姻氨,從而獲得了更好的擴(kuò)展性。這使我們能夠使用僅兩個(gè)all-reduce在前向路徑中完成transformer層中的所有GEMM的計(jì)算剪验,并在反向路徑中也使用兩個(gè)all-reduce(如下圖所示)肴焊。

通信操作

Transformer語(yǔ)言模型的輸出嵌入維度為隱層維度H乘以詞匯表大小v。由于詞匯表的大小是萬級(jí)的(例如功戚,GPT-2使用的詞匯表大小為50娶眷,257),將輸出嵌入GEMM并行化是有益的啸臀。但是届宠,在transformer語(yǔ)言模型中,輸出嵌入層與輸入嵌入共享權(quán)重,這需要同時(shí)修改這兩者豌注。我們沿詞典維度v切分輸入嵌入權(quán)重矩陣E_{H\times v}E = [E_1, E_2] (列向)伤塌。由于每個(gè)切分部分現(xiàn)在只包含嵌入表的一部分,在輸入嵌入之后需要一個(gè)all-reduce(g操作符)幌羞。對(duì)于輸出嵌入寸谜,一種方法是執(zhí)行并行GEMM [Y_1, Y_2] = [XE_1, XE_2]以獲得logits,添加一個(gè)all-gather Y = \text{all-gather}([Y_1, Y_2])属桦,并將結(jié)果發(fā)送到交叉熵?fù)p失函數(shù)熊痴。但是,在這種情況下聂宾,all-gather將通信b \times s \times v個(gè)元素(其中b是batch大小果善,s是序列長(zhǎng)度),由于詞匯表大小很大系谐,這會(huì)產(chǎn)生巨大的通信量巾陕。為了減小通信量,我們將并行GEMM [Y_1, Y_2]的輸出與交叉熵?fù)p失函數(shù)融合纪他,這將維度降低到b \times s鄙煤。通信標(biāo)量損失而不是logits極大地減少了通信量,這極大地提高了我們的模型并行方法的效率茶袒。

我們的模型并行方法的很大一部分可以歸納為針對(duì)減少通信并保持GPU計(jì)算限度的技術(shù)梯刚。對(duì)于dropout、層標(biāo)準(zhǔn)化薪寓、殘差連接等計(jì)算亡资,我們選擇跨GPU重復(fù)計(jì)算,而不是讓一個(gè)GPU計(jì)算部分然后廣播結(jié)果到其他GPU向叉。具體來說锥腻,我們?cè)诿總€(gè)GPU上維護(hù)層標(biāo)準(zhǔn)化參數(shù)的副本,并在模型并行區(qū)域的輸出上運(yùn)行dropout和殘差連接母谎,然后將其饋送作為下一個(gè)模型并行區(qū)域的輸入瘦黑。對(duì)于模型的優(yōu)化,我們?cè)试S每個(gè)模型并行worker優(yōu)化自己的參數(shù)集合奇唤。由于所有的值要么是本地的(比如GEMM的參數(shù))供璧,要么被復(fù)制在多個(gè)GPU上(比如層標(biāo)準(zhǔn)化的參數(shù)),所以在這種形式中不需要通信更新的參數(shù)值冻记。

總而言之,我們上述的方法實(shí)現(xiàn)起來很監(jiān)督来惧,只需要在前向和反向傳播中添加幾個(gè)額外的all-reduce操作冗栗。它不需要編譯器,與GPipe等方法提倡的流水線模型并行是正交互補(bǔ)的。

四隅居、實(shí)驗(yàn)

  1. 并行效率
實(shí)驗(yàn)
  1. GPT-2實(shí)驗(yàn)
實(shí)驗(yàn)
實(shí)驗(yàn)
實(shí)驗(yàn)
  1. BERT實(shí)驗(yàn)

在進(jìn)行BERT相關(guān)的實(shí)驗(yàn)時(shí)钠至,我們發(fā)現(xiàn)到模型參數(shù)規(guī)模超過BERT-large(336M)時(shí)會(huì)出現(xiàn)模型性能退化。我們發(fā)現(xiàn)按照下圖7的方式重排層標(biāo)準(zhǔn)化與殘差連接的順序后可以解決這個(gè)問題胎源。

實(shí)驗(yàn)
實(shí)驗(yàn)
實(shí)驗(yàn)
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末棉钧,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子涕蚤,更是在濱河造成了極大的恐慌宪卿,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,682評(píng)論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件万栅,死亡現(xiàn)場(chǎng)離奇詭異佑钾,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)烦粒,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,277評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門休溶,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人扰她,你說我怎么就攤上這事兽掰。” “怎么了徒役?”我有些...
    開封第一講書人閱讀 165,083評(píng)論 0 355
  • 文/不壞的土叔 我叫張陵孽尽,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我廉涕,道長(zhǎng)泻云,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,763評(píng)論 1 295
  • 正文 為了忘掉前任狐蜕,我火速辦了婚禮宠纯,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘层释。我一直安慰自己婆瓜,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,785評(píng)論 6 392
  • 文/花漫 我一把揭開白布贡羔。 她就那樣靜靜地躺著廉白,像睡著了一般。 火紅的嫁衣襯著肌膚如雪乖寒。 梳的紋絲不亂的頭發(fā)上猴蹂,一...
    開封第一講書人閱讀 51,624評(píng)論 1 305
  • 那天,我揣著相機(jī)與錄音楣嘁,去河邊找鬼磅轻。 笑死珍逸,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的聋溜。 我是一名探鬼主播谆膳,決...
    沈念sama閱讀 40,358評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼撮躁!你這毒婦竟也來了漱病?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,261評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤把曼,失蹤者是張志新(化名)和其女友劉穎杨帽,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體祝迂,經(jīng)...
    沈念sama閱讀 45,722評(píng)論 1 315
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡睦尽,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,900評(píng)論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了型雳。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片当凡。...
    茶點(diǎn)故事閱讀 40,030評(píng)論 1 350
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖纠俭,靈堂內(nèi)的尸體忽然破棺而出沿量,到底是詐尸還是另有隱情,我是刑警寧澤冤荆,帶...
    沈念sama閱讀 35,737評(píng)論 5 346
  • 正文 年R本政府宣布朴则,位于F島的核電站,受9級(jí)特大地震影響钓简,放射性物質(zhì)發(fā)生泄漏乌妒。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,360評(píng)論 3 330
  • 文/蒙蒙 一外邓、第九天 我趴在偏房一處隱蔽的房頂上張望撤蚊。 院中可真熱鬧,春花似錦损话、人聲如沸侦啸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,941評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)光涂。三九已至,卻和暖如春拧烦,著一層夾襖步出監(jiān)牢的瞬間忘闻,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,057評(píng)論 1 270
  • 我被黑心中介騙來泰國(guó)打工恋博, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,237評(píng)論 3 371
  • 正文 我出身青樓郭宝,卻偏偏與公主長(zhǎng)得像蝇裤,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子秦士,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,976評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容