混合精度訓(xùn)練

原文來自于機(jī)器學(xué)習(xí)算法與自然語言處理公眾號

混合精度訓(xùn)練

  • 《MIXED PRECISION TRAINING》這篇論文是百度&Nvidia研究院一起發(fā)表的,結(jié)合N卡底層計算優(yōu)化集侯,提出了一種灰常有效的神經(jīng)網(wǎng)絡(luò)訓(xùn)練加速方法件蚕,不僅是預(yù)訓(xùn)練,在全民finetune BERT的今天變得異常有用秋泳。而不僅百度的paddle框架支持混合精度訓(xùn)練,在Tensorflow和Pytorch中也有相應(yīng)的實現(xiàn)攒菠。下面我們先來講講理論迫皱,后面再分析混合精度訓(xùn)練在三大深度學(xué)習(xí)框架中的打開方式。

理論原理

  • 訓(xùn)練過神經(jīng)網(wǎng)絡(luò)的小伙伴都知道辖众,神經(jīng)網(wǎng)絡(luò)的參數(shù)和中間結(jié)果絕大部分都是單精度浮點數(shù)(即float32)存儲和計算的卓起,當(dāng)網(wǎng)絡(luò)變得超級大時,降低浮點數(shù)精度凹炸,比如使用半精度浮點數(shù)戏阅,顯然是提高計算速度,降低存儲開銷的一個很直接的辦法啤它。然而副作用也很顯然饲握,如果我們直接降低浮點數(shù)的精度直觀上必然導(dǎo)致模型訓(xùn)練精度的損失私杜。但是呢,天外有天救欧,這篇文章用了三種機(jī)制有效地防止了模型的精度損失衰粹。

權(quán)重備份(master weights)

  • 我們知道半精度浮點數(shù)(float16)在計算機(jī)中的表示分為1bit的符號位,5bits的指數(shù)位和10bits的尾數(shù)位笆怠,所以它能表示的最小的正數(shù)即2^-24(也就是精度到此為止了)铝耻。當(dāng)神經(jīng)網(wǎng)絡(luò)中的梯度灰常小的時候,網(wǎng)絡(luò)訓(xùn)練過程中每一步的迭代(灰常小的梯度 ? 也黑小的learning rate)會變得更小蹬刷,小到float16精度無法表示的時候瓢捉,相應(yīng)的梯度就無法得到更新。
  • 論文統(tǒng)計了一下在Mandarin數(shù)據(jù)集上訓(xùn)練DeepSpeech 2模型時產(chǎn)生過的梯度办成,發(fā)現(xiàn)在未乘以learning rate之前泡态,就有接近5%的梯度直接悲劇的變成0(精度比2^-24還要高的梯度會直接變成0),造成重大的損失迂卢。


  • 還有更難的某弦,假設(shè)迭代量逃過一劫準(zhǔn)備奉獻(xiàn)自己的時候。而克。靶壮。由于網(wǎng)絡(luò)中的權(quán)重往往遠(yuǎn)大于我們要更新的量,當(dāng)?shù)啃∮贔loat16當(dāng)前區(qū)間內(nèi)能表示的最小間隔的時候员萍,更新也會失敗腾降。
  • 作者這里提出了一個非常simple but effective的方法,就是前向傳播和梯度計算都用float16碎绎,但是存儲網(wǎng)絡(luò)參數(shù)的梯度時要用float32螃壤!這樣就可以一定程度上的解決上面說的兩個問題啦。
  • 我們來看一下訓(xùn)練曲線筋帖,藍(lán)色的線是正常的float32精度訓(xùn)練曲線奸晴,橙色的線是使用float32存儲網(wǎng)絡(luò)參數(shù)的learning curve,綠色滴是不使用float32存儲參數(shù)的曲線幕随,兩者一比就相形見絀啦蚁滋。

損失放縮(loss scaling)

  • 雖然使用float32來存儲梯度宿接,確實不會丟失精度了赘淮,但是計算過程中出現(xiàn)的指數(shù)位小于 -24 的梯度不還是會丟失。于是loss scaling方法來了睦霎。首先作者統(tǒng)計了一下訓(xùn)練過程中激活函數(shù)梯度的分布情況梢卸,由于網(wǎng)絡(luò)中的梯度往往都非常小,導(dǎo)致在使用FP16的時候右邊有大量的范圍是沒有使用的副女。這種情況下蛤高, 我們可以通過放大loss來把整個梯度右移,減少因為精度隨時變?yōu)?的梯度。
  • 那么問題來了戴陡,怎么合理的放大loss呢塞绿?一個最簡單的方法是常數(shù)縮放,把loss一股腦統(tǒng)一放大S倍恤批。float16能表示的最大正數(shù)是215*(1+1-2-10)=65504异吻,我們可以統(tǒng)計網(wǎng)絡(luò)中的梯度,計算出一個常數(shù)S喜庞,使得最大的梯度不超過float16能表示的最大整數(shù)即可诀浪。
  • 當(dāng)然啦,還有更加智能的動態(tài)調(diào)整(automatic scaling) 我們先初始化一個很大的S延都,如果梯度溢出雷猪,我們就把S縮小為原來的二分之一;如果在很多次迭代中梯度都沒有溢出晰房,我們也可以嘗試把S放大兩倍求摇。以此類推,實現(xiàn)動態(tài)的loss scaling嫉你。

運算精度(precison of ops)

  • 精益求精再進(jìn)一步月帝,神經(jīng)網(wǎng)絡(luò)中的運算主要可以分為四大類,混合精度訓(xùn)練把一些有更高精度要求的運算幽污,在計算過程中使用float32嚷辅,存儲的時候再轉(zhuǎn)換為float16。
  • 像矩陣乘法和絕大多數(shù)pointwise的計算可以直接使用float16來計算并存儲距误,而reductions簸搞、loss function和一些pointwise(如exp,log准潭,pow等函數(shù)值遠(yuǎn)大于變量的函數(shù))需要更加精細(xì)的處理趁俊,所以在計算中使用用float32,再將結(jié)果轉(zhuǎn)換為float16來存儲刑然。

Pytorch

  • 導(dǎo)入Automatic Mixed Precision (AMP)
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 這里是“歐一”寺擂,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

import torch
from apex import amp
model = ... 
optimizer = ...

#包裝model和optimizer
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for data, label in data_iter: 
    out = model(data) 
    loss = criterion(out, label) 
    optimizer.zero_grad() 
    
    #loss scaling,代替loss.backward()
    with amp.scaled_loss(loss, optimizer) as scaled_loss:   
        scaled_loss.backward() 
optimizer.step()

Tensorflow

  • 一句話實現(xiàn)混合精度訓(xùn)練之修改環(huán)境變量泼掠,在python腳本中設(shè)置環(huán)境變量

os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
  • Graph-based示例
opt = tf.train.AdamOptimizer()

#add a line
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          opt,
          loss_scale='dynamic')
          
train_op = opt.miminize(loss)
  • Keras-based示例

opt = tf.keras.optimizers.Adam()

#add a line
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            opt,
            loss_scale='dynamic')
            
model.compile(loss=loss, optimizer=opt)
model.fit(...)

PaddlePaddle

  • 一句話實現(xiàn)混合精度訓(xùn)練之添加config --use_fp16=true
  • 舉個栗子怔软,基于BERT finetune XNLI任務(wù)時,只需在執(zhí)行時設(shè)置use_fp16為true即可择镇。

export FLAGS_sync_nccl_allreduce=0
export FLAGS_eager_delete_tensor_gb=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

BERT_BASE_PATH="chinese_L-12_H-768_A-12"
TASK_NAME='XNLI'
DATA_PATH=/path/to/xnli/data/
CKPT_PATH=/path/to/save/checkpoints/

python -u run_classifier.py --task_name ${TASK_NAME} \
                   --use_fp16=true \  #!!!!!!add a line
                   --use_cuda true \
                   --do_train true \
                   --do_val true \
                   --do_test true \
                   --batch_size 32 \
                   --in_tokens false \
                   --init_pretraining_params ${BERT_BASE_PATH}/params \
                   --data_dir ${DATA_PATH} \
                   --vocab_path ${BERT_BASE_PATH}/vocab.txt \
                   --checkpoints ${CKPT_PATH} \
                   --save_steps 1000 \
                   --weight_decay  0.01 \
                   --warmup_proportion 0.1 \
                   --validation_steps 100 \
                   --epoch 3 \
                   --max_seq_len 128 \
                   --bert_config_path ${BERT_BASE_PATH}/bert_config.json \
                   --learning_rate 5e-5 \
                   --skip_steps 10 \
                   --num_iteration_per_drop_scope 10 \
                   --verbose true
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末挡逼,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子腻豌,更是在濱河造成了極大的恐慌家坎,老刑警劉巖嘱能,帶你破解...
    沈念sama閱讀 216,372評論 6 498
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異虱疏,居然都是意外死亡惹骂,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,368評論 3 392
  • 文/潘曉璐 我一進(jìn)店門做瞪,熙熙樓的掌柜王于貴愁眉苦臉地迎上來析苫,“玉大人,你說我怎么就攤上這事穿扳●媒模” “怎么了?”我有些...
    開封第一講書人閱讀 162,415評論 0 353
  • 文/不壞的土叔 我叫張陵矛物,是天一觀的道長茫死。 經(jīng)常有香客問我,道長履羞,這世上最難降的妖魔是什么峦萎? 我笑而不...
    開封第一講書人閱讀 58,157評論 1 292
  • 正文 為了忘掉前任,我火速辦了婚禮忆首,結(jié)果婚禮上爱榔,老公的妹妹穿的比我還像新娘。我一直安慰自己糙及,他們只是感情好详幽,可當(dāng)我...
    茶點故事閱讀 67,171評論 6 388
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著浸锨,像睡著了一般唇聘。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上柱搜,一...
    開封第一講書人閱讀 51,125評論 1 297
  • 那天迟郎,我揣著相機(jī)與錄音,去河邊找鬼聪蘸。 笑死宪肖,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的健爬。 我是一名探鬼主播控乾,決...
    沈念sama閱讀 40,028評論 3 417
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼浑劳!你這毒婦竟也來了阱持?” 一聲冷哼從身側(cè)響起夭拌,我...
    開封第一講書人閱讀 38,887評論 0 274
  • 序言:老撾萬榮一對情侶失蹤魔熏,失蹤者是張志新(化名)和其女友劉穎衷咽,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體蒜绽,經(jīng)...
    沈念sama閱讀 45,310評論 1 310
  • 正文 獨居荒郊野嶺守林人離奇死亡镶骗,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,533評論 2 332
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了躲雅。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片鼎姊。...
    茶點故事閱讀 39,690評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖相赁,靈堂內(nèi)的尸體忽然破棺而出相寇,到底是詐尸還是另有隱情,我是刑警寧澤钮科,帶...
    沈念sama閱讀 35,411評論 5 343
  • 正文 年R本政府宣布唤衫,位于F島的核電站,受9級特大地震影響绵脯,放射性物質(zhì)發(fā)生泄漏佳励。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,004評論 3 325
  • 文/蒙蒙 一蛆挫、第九天 我趴在偏房一處隱蔽的房頂上張望赃承。 院中可真熱鬧,春花似錦悴侵、人聲如沸瞧剖。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,659評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽筒繁。三九已至,卻和暖如春巴元,著一層夾襖步出監(jiān)牢的瞬間毡咏,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,812評論 1 268
  • 我被黑心中介騙來泰國打工逮刨, 沒想到剛下飛機(jī)就差點兒被人妖公主榨干…… 1. 我叫王不留呕缭,地道東北人。 一個月前我還...
    沈念sama閱讀 47,693評論 2 368
  • 正文 我出身青樓修己,卻偏偏與公主長得像恢总,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子睬愤,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,577評論 2 353

推薦閱讀更多精彩內(nèi)容