作者使用了batch大型吆簟:8192编矾,使用了256 GPUs,在一個(gè)小時(shí)內(nèi)訓(xùn)練了ResNet-50氧映,并且得到了和256大小的batch同樣的訓(xùn)練精度。
2 Large Minibatch SGD
通常來(lái)說(shuō)脱货,我們?cè)谟?xùn)練有監(jiān)督任務(wù)的時(shí)候岛都,會(huì)最小化loss:
是網(wǎng)絡(luò)的參數(shù)律姨,
是訓(xùn)練集,
就是損失函數(shù)臼疫。
minibatch SGD就是在一個(gè)batch的訓(xùn)練集上择份,進(jìn)行參數(shù)的更新:
2.1 Learning Rates for Large Minibatches
論文的目的是在使用非常大的batch的時(shí)候能夠維持訓(xùn)練的準(zhǔn)確性和泛化性能。具體來(lái)說(shuō)烫堤,就是在使用多個(gè)worker來(lái)進(jìn)行數(shù)據(jù)并行訓(xùn)練的時(shí)候荣赶,不會(huì)犧牲模型的accuracy。
作者發(fā)現(xiàn)鸽斟,下面的learning rate scaling rule能夠適合于很大范圍的batch size拔创。
Linear Scaling Rule:當(dāng)minibatch size乘以一個(gè)數(shù)
,同樣learning rate也乘以這個(gè)數(shù)
富蓄。
所有其他超參數(shù)保持不變剩燥,
-
interpretation解釋:為什么上面的方法會(huì)有效呢?首先考慮一個(gè)網(wǎng)絡(luò)在某一個(gè)時(shí)刻的參數(shù)
立倍,和一組
個(gè)minibatches
灭红,每一個(gè)minibatch的大小為
。我們比較一下每個(gè)minibatch單獨(dú)訓(xùn)練和這
個(gè)batch一起訓(xùn)練的效果口注。
第一種情況:在進(jìn)行了
次更新后
-
第二種情況:訓(xùn)練是在這
個(gè)batch的合集上進(jìn)行变擒,batch size大小為
update rule
顯然兩個(gè)結(jié)果不太可能一樣,但是假如 并且
那么我們就可以得到
寝志。
2.2 Warmup熱身
當(dāng)網(wǎng)絡(luò)變化很劇烈的時(shí)候娇斑,上面提出的假設(shè)就不會(huì)成立,那么Linear Scaling Rule就不會(huì)有效果澈段。但是作者發(fā)現(xiàn)悠菜,這樣的情況可以通過(guò)一種熱身的方式來(lái)緩解舰攒,具體來(lái)說(shuō)就是败富,在訓(xùn)練的開(kāi)始,使用一個(gè)更小的learning rate摩窃。
-
Constant warmup:一種熱身的策略是使用一個(gè)小的定值作為初始的學(xué)習(xí)率兽叮,訓(xùn)練幾個(gè)回合。這種策略對(duì)于物體檢測(cè)猾愿,分割鹦聪,fine-tune等問(wèn)題在有些時(shí)候效果較好,但是當(dāng)
較大也就是batch較大的時(shí)候蒂秘,就不是那么有效了泽本,尤其在熱身結(jié)束的時(shí)候會(huì)出現(xiàn)error的峰值。
- gradual warmup:為了克服constant warmup的不足姻僧,作者使用了gradual warmup悼泌,就是一點(diǎn)一點(diǎn)地將學(xué)習(xí)率從小称近,增大惦费。并且在增大后,回復(fù)到原始的learning rate schedule冰抢。
2.3 Batch Normalization with Large Minibatches
BN在提高訓(xùn)練效率和精度有很大的效果。但是一個(gè)minibatch在計(jì)算一些統(tǒng)計(jì)量的時(shí)候艘狭,需要整個(gè)minibatch的數(shù)據(jù)挎扰,當(dāng)分布式或者多卡訓(xùn)練的時(shí)候,就會(huì)導(dǎo)致非常多的數(shù)據(jù)需要傳輸巢音。
當(dāng)使用BN的時(shí)候遵倦,每個(gè)sample的loss就會(huì)和整個(gè)batch的統(tǒng)計(jì)量相關(guān),我們用表示單個(gè)sample的loss官撼。用
表示整個(gè)batch的loss骇吭。那么整個(gè)訓(xùn)練集的loss表示為
。
表示一個(gè)大小為
的batch歧寺。
當(dāng)我們改變的大小的時(shí)候燥狰,就相當(dāng)于改變loss function。More specifically the mean/variance statics computed by BN with different
exhibit different levels of random variation斜筐。
在分布式和多卡訓(xùn)練的情況下龙致,如果每個(gè)worker的batch size大小為,那么總共的batch大小就是
顷链,相當(dāng)于從許多batch中選擇了
個(gè)samples目代,每個(gè)sample就是一個(gè)batch。那么之前的公式就變?yōu)?br>
We also note that the BN statics should not be computed across all workers, not only for the sake of reducing communication, but also for maintaining the same underlying loss function being optimized.
Subtleties and Pitfalls of Distributed SGD
Weight decay:weight decay是參數(shù)的L2-正則項(xiàng)嗤练。加入正則后的更新公式變?yōu)?br>
最后一項(xiàng)
Remark 1: Scaling the cross-entropy loss is not equivalent to scaling the learning rate.
Momentum correction:帶動(dòng)量的SGD被廣泛應(yīng)用于神經(jīng)網(wǎng)絡(luò)的更新中。一種常見(jiàn)的形式如下:
用
需要注意的是革答,
remark 2: Apply momentum correction after changing learning rate if using (10)
- Gradient aggregation:對(duì)于每個(gè)worker的訓(xùn)練結(jié)果,需要將梯度匯聚起來(lái),求平均用于更新參數(shù)栅组。
remark 3: Normalize the per-worker loss by total minibatch size
, not per-worker size
袱衷。
- Data shuffling:
remark 4: Use a single random shuffling of the training data (per epoch) that is divided amongst all
workers.
Communication
對(duì)于每一個(gè)參數(shù)的梯度,都是通過(guò)allreduce操作來(lái)進(jìn)行匯聚的笑窜。在進(jìn)行allreduce之前致燥,每個(gè)GPU都會(huì)計(jì)算自己的梯度,在allreduce*之后排截,每個(gè)GPU得到梯度的和嫌蚤。
論文中還討論了軟件和硬件的實(shí)現(xiàn)相關(guān),詳情可參考論文断傲。