論文標(biāo)題:Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
論文鏈接:https://arxiv.org/abs/1909.08053
論文來源:NVIDIA
一份蝴、概述
隨著自然語(yǔ)言處理領(lǐng)域預(yù)訓(xùn)練語(yǔ)言模型的規(guī)模變得越來越大,它們超過了現(xiàn)代處理器的內(nèi)存限制,需要額外的內(nèi)存管理技術(shù),如激活檢查點(diǎn)(activation checkpointing)缸托。一些廣泛使用的優(yōu)化算法如Adam需要額外的內(nèi)存來存儲(chǔ)其中的動(dòng)量和其他優(yōu)化器狀態(tài)场斑,這降低了可以有效訓(xùn)練的模型大小。幾種模型并行方法通過劃分模型來克服這個(gè)限制赖草,這樣權(quán)重及其相關(guān)的優(yōu)化器狀態(tài)就不需要同時(shí)駐留在處理器上杆逗。例如乡翅,GPipe和Mesh-Tensorflow提供了不同種類的模型并行框架。但是髓迎,它們需要重寫模型峦朗,并依賴于仍在開發(fā)中的自定義編譯器和框架。
在這項(xiàng)工作中排龄,我們使用簡(jiǎn)單高效的層內(nèi)模型并行(intra-layer model-parallelism)來實(shí)現(xiàn)模型并行。我們利用transformer基礎(chǔ)語(yǔ)言模型中的固有結(jié)構(gòu)來實(shí)現(xiàn)一個(gè)簡(jiǎn)單的模型并行翎朱,它可以在PyTorch中高效訓(xùn)練橄维,而無需自定義C++代碼或編譯器。這種方法與GPipe等基于流水線的模型并行方法是正交的(可以同時(shí)使用拴曲,相互獨(dú)立而不沖突)争舞。
為了證明我們方法的可擴(kuò)展性,我們建立了一個(gè)baseline澈灼,在單個(gè)NVIDIA V100 32GB GPU上訓(xùn)練了一個(gè)12億參數(shù)的模型竞川,保持39 TeraFLOPs的計(jì)算速度。這是DGX-2H服務(wù)器中單GPU配置的理論峰值浮點(diǎn)運(yùn)算能力的30%叁熔,因此是一個(gè)很強(qiáng)的基線委乌。通過在512個(gè)GPU上以8路模型并行將模型擴(kuò)展到83億參數(shù),我們實(shí)現(xiàn)了高達(dá)每秒15.1PetaFLOPs的持續(xù)計(jì)算速度荣回。與單GPU情況相比遭贸,這代表了76%的擴(kuò)展效率。下圖顯示了更詳細(xì)的擴(kuò)展結(jié)果心软。
為了分析模型大小擴(kuò)展對(duì)準(zhǔn)確率的影響壕吹,我們訓(xùn)練了自回歸的GPT-2語(yǔ)言模型以及自編碼的BERT雙向transformer著蛙,并在幾個(gè)下游任務(wù)上對(duì)其進(jìn)行評(píng)估。我們發(fā)現(xiàn)耳贬,隨著模型大小的增加踏堡,現(xiàn)有的BERT架構(gòu)會(huì)導(dǎo)致模型退化。我們通過重新排列transformer層中的層標(biāo)準(zhǔn)化和殘差連接來克服這一挑戰(zhàn)咒劲,結(jié)果表明暂吉,進(jìn)行這一改變之后,下游任務(wù)的開發(fā)集結(jié)果隨著模型大小的增加單調(diào)提升缎患。此外慕的,我們的模型在WikiText103、LAMBADA上的閉包式預(yù)測(cè)準(zhǔn)確率以及RACE閱讀理解數(shù)據(jù)集上都取得了SOTA的測(cè)試集結(jié)果挤渔。
二肮街、背景
有兩種主要的范式可以將深度神經(jīng)網(wǎng)絡(luò)的訓(xùn)練擴(kuò)展到多個(gè)硬件加速器中:(1)數(shù)據(jù)并行,其中minibatch的訓(xùn)練被劃分到多個(gè)worker中;(2)模型并行判导,其將模型的內(nèi)存使用和計(jì)算分布到多個(gè)worker中嫉父。通過按可用worker的數(shù)量成比例增加minibatch大小(即weak scaling),可以觀察到訓(xùn)練數(shù)據(jù)吞吐量的近乎線性的擴(kuò)展眼刃。但是绕辖,大批量訓(xùn)練會(huì)使優(yōu)化過程更復(fù)雜,可能導(dǎo)致準(zhǔn)確率降低或收斂時(shí)間延長(zhǎng)擂红,反而抵消了增加訓(xùn)練吞吐量帶來的好處仪际。進(jìn)一步的研究開發(fā)了各種技術(shù)來緩解這些影響,降低大型神經(jīng)網(wǎng)絡(luò)的訓(xùn)練時(shí)間昵骤。為了進(jìn)一步擴(kuò)展訓(xùn)練規(guī)模树碱,一些并行工作將數(shù)據(jù)并行和激活檢查點(diǎn)結(jié)合起來:在前向傳播中重新計(jì)算而不是存儲(chǔ)激活,以減少內(nèi)存需求变秦。
然而成榜,這些技術(shù)在其可以處理的問題大小方面存在一個(gè)根本的局限:模型必須完全能夠在一個(gè)worker上進(jìn)行處理。隨著BERT和GPT-2等語(yǔ)言模型大小和復(fù)雜度的增加蹦玫,神經(jīng)網(wǎng)絡(luò)已經(jīng)接近了現(xiàn)代硬件加速器的內(nèi)存容量赎婚。這個(gè)問題的一個(gè)解決方案是采用參數(shù)共享來減少模型的內(nèi)存占用,但這限制了模型的總體容量樱溉。我們的方法是利用模型并行將模型劃分到多個(gè)加速器上挣输。這不僅減輕了內(nèi)存壓力,而且獨(dú)立于微批量大小增加了并行度饺窿。
在模型并行中歧焦,還有兩種進(jìn)一步的范式:逐層流水線并行(layer-wise pipeline parallelism)和更通用的分布式張量計(jì)算(distributed tensor computation)。在流水線模型并行中,一組操作首先在一個(gè)設(shè)備上執(zhí)行绢馍,然后將輸出傳遞到流水線中的下一個(gè)設(shè)備向瓷,在下一個(gè)設(shè)備上執(zhí)行不同的另一組操作。一些方法與流水線并行結(jié)合使用參數(shù)服務(wù)器舰涌。然而猖任,這些方法存在一致性問題。TensorFlow中的GPipe框架通過使用同步梯度下降來解決這種一致性問題瓷耙。這種方法需要額外的邏輯來處理這些通信和計(jì)算操作的高效流水線朱躺,并受到減少效率的流水線bubble的影響,或者對(duì)優(yōu)化器本身的更改會(huì)影響準(zhǔn)確性搁痛。
分布式張量計(jì)算是一種正交的长搀、更通用的方法,它將張量操作劃分到多個(gè)設(shè)備上以加速計(jì)算或增加模型大小鸡典。編排這種并行計(jì)算的深度學(xué)習(xí)框架FlexFlow提供了一種選擇最佳并行策略的方法源请。最近,Mesh-TensorFlow在TensorFlow中引入了一種指定分布式張量計(jì)算的通用類的語(yǔ)言彻况。并行維度由終端用戶在這一語(yǔ)言中指定谁尸,生成的圖由適當(dāng)?shù)募w原語(yǔ)編譯。我們利用類似于Mesh-TensorFlow中的見解纽甘,并利用transformer注意力頭的并行計(jì)算來并行化我們的transformer模型良蛮。但是,我們沒有實(shí)現(xiàn)一個(gè)用于模型并行的框架和編譯器悍赢,而是僅對(duì)現(xiàn)有的PyTorch transformer實(shí)現(xiàn)進(jìn)行了一些有針對(duì)性的修改决瞳。我們的方法很簡(jiǎn)單,不需要任何新的編譯器或代碼重寫泽裳,可以通過插入幾個(gè)簡(jiǎn)單的基元來完全實(shí)現(xiàn)瞒斩,如下一節(jié)所述。
三涮总、方法
我們利用transformer網(wǎng)絡(luò)的結(jié)構(gòu)來創(chuàng)建一個(gè)簡(jiǎn)單的模型并行實(shí)現(xiàn),只需要添加幾個(gè)同步原語(yǔ)祷舀。一個(gè)transformer層由一個(gè)自注意力塊和一個(gè)兩層多層感知器(MLP)組成瀑梗,如下圖所示。我們?cè)谶@兩個(gè)塊中分別引入模型并行裳扯。
首先詳細(xì)說明MLP塊抛丽。該塊的第一部分是一個(gè)GEMM(General Matrix Multiplication),后面是一個(gè)GELU非線性層:
并行化GEMM的一種方法是按行切分權(quán)重矩陣饰豺,按列切分輸入
:
這將得到。由于GeLU是非線性函數(shù),
蒿柳,這種切分方式需要在GeLU函數(shù)之前進(jìn)行同步饶套。
另一種切分方法是按列切分。這種切分方式允許獨(dú)立地對(duì)每個(gè)切分后的GEMM的輸出應(yīng)用GeLU非線性函數(shù):
這種方法的優(yōu)點(diǎn)是去除了一個(gè)同步點(diǎn)垒探。因此妓蛮,我們采用這種列并行方式切分第一個(gè)GEMM,直接將GeLU層的輸出作為第二個(gè)GEMM(這個(gè)GEMM以按行切分的方式并行)的輸入圾叼,而不需要任何通信蛤克,如下圖(a)所示。
第二個(gè)GEMM的輸出通過dropout層之前先跨GPU進(jìn)行reduce夷蚊。這種方法將MLP塊中的兩個(gè)GEMM切分到不同的GPU上构挤,前向傳播中只需要一個(gè)all-reduce操作(g操作符),反向傳播中也只需要一個(gè)all-reduce(f操作符)惕鼓。這兩個(gè)操作符互為共軛筋现,可以通過PyTorch中的幾行代碼實(shí)現(xiàn)。例如呜笑,f操作符的實(shí)現(xiàn)如下:
class f(torch.autograd.Function):
def forward(ctx, x):
return x
def backward(ctx, gradient):
all_reduce(gradient)
return gradient
如上圖3(b)所示夫否,對(duì)于自注意力塊,我們利用了多頭注意力操作中固有的并行性叫胁,以列并行的方式切分key(K)凰慈、query(Q)和value(V)對(duì)應(yīng)的GEMM,這樣每個(gè)注意力頭對(duì)應(yīng)的矩陣乘法在一個(gè)GPU上本地計(jì)算驼鹅。這允許我們?cè)贕PU之間切分每個(gè)注意力頭的參數(shù)和工作量微谓,并且不需要任何直接的通信就可以完成自注意力。在自注意力之后的輸出線性層的GEMM(自注意力之后)沿其行進(jìn)行并行化输钩,直接獲取并行注意力層的輸出豺型,而不需要GPU之間的通信。MLP和自注意力這兩種塊的并行方法都融合了兩組GEMM买乃,消除了中間的同步點(diǎn)姻氨,從而獲得了更好的擴(kuò)展性。這使我們能夠使用僅兩個(gè)all-reduce在前向路徑中完成transformer層中的所有GEMM的計(jì)算剪验,并在反向路徑中也使用兩個(gè)all-reduce(如下圖所示)肴焊。
Transformer語(yǔ)言模型的輸出嵌入維度為隱層維度乘以詞匯表大小
。由于詞匯表的大小是萬級(jí)的(例如功戚,GPT-2使用的詞匯表大小為50娶眷,257),將輸出嵌入GEMM并行化是有益的啸臀。但是届宠,在transformer語(yǔ)言模型中,輸出嵌入層與輸入嵌入共享權(quán)重,這需要同時(shí)修改這兩者豌注。我們沿詞典維度
切分輸入嵌入權(quán)重矩陣
為
(列向)伤塌。由于每個(gè)切分部分現(xiàn)在只包含嵌入表的一部分,在輸入嵌入之后需要一個(gè)all-reduce(g操作符)幌羞。對(duì)于輸出嵌入寸谜,一種方法是執(zhí)行并行GEMM
以獲得logits,添加一個(gè)all-gather
属桦,并將結(jié)果發(fā)送到交叉熵?fù)p失函數(shù)熊痴。但是,在這種情況下聂宾,all-gather將通信
個(gè)元素(其中
是batch大小果善,
是序列長(zhǎng)度),由于詞匯表大小很大系谐,這會(huì)產(chǎn)生巨大的通信量巾陕。為了減小通信量,我們將并行GEMM
的輸出與交叉熵?fù)p失函數(shù)融合纪他,這將維度降低到
鄙煤。通信標(biāo)量損失而不是logits極大地減少了通信量,這極大地提高了我們的模型并行方法的效率茶袒。
我們的模型并行方法的很大一部分可以歸納為針對(duì)減少通信并保持GPU計(jì)算限度的技術(shù)梯刚。對(duì)于dropout、層標(biāo)準(zhǔn)化薪寓、殘差連接等計(jì)算亡资,我們選擇跨GPU重復(fù)計(jì)算,而不是讓一個(gè)GPU計(jì)算部分然后廣播結(jié)果到其他GPU向叉。具體來說锥腻,我們?cè)诿總€(gè)GPU上維護(hù)層標(biāo)準(zhǔn)化參數(shù)的副本,并在模型并行區(qū)域的輸出上運(yùn)行dropout和殘差連接母谎,然后將其饋送作為下一個(gè)模型并行區(qū)域的輸入瘦黑。對(duì)于模型的優(yōu)化,我們?cè)试S每個(gè)模型并行worker優(yōu)化自己的參數(shù)集合奇唤。由于所有的值要么是本地的(比如GEMM的參數(shù))供璧,要么被復(fù)制在多個(gè)GPU上(比如層標(biāo)準(zhǔn)化的參數(shù)),所以在這種形式中不需要通信更新的參數(shù)值冻记。
總而言之,我們上述的方法實(shí)現(xiàn)起來很監(jiān)督来惧,只需要在前向和反向傳播中添加幾個(gè)額外的all-reduce操作冗栗。它不需要編譯器,與GPipe等方法提倡的流水線模型并行是正交互補(bǔ)的。
四隅居、實(shí)驗(yàn)
- 并行效率
- GPT-2實(shí)驗(yàn)
- BERT實(shí)驗(yàn)
在進(jìn)行BERT相關(guān)的實(shí)驗(yàn)時(shí)钠至,我們發(fā)現(xiàn)到模型參數(shù)規(guī)模超過BERT-large(336M)時(shí)會(huì)出現(xiàn)模型性能退化。我們發(fā)現(xiàn)按照下圖7的方式重排層標(biāo)準(zhǔn)化與殘差連接的順序后可以解決這個(gè)問題胎源。