Date: 2020/08/03
Author: CW
Foreword:
各位煉丹者應(yīng)該都會(huì)有自己常用的一種或幾種深度學(xué)習(xí)框架鹤啡,如 MxNet、Caffe清蚀、Tensorflow、Pytorch爹谭、PaddlePaddle(百度)枷邪,甚至是國(guó)產(chǎn)新興框架 MegEngine(曠視)、MindSpore(華為)等诺凡,在涉及介紹這些框架的時(shí)候东揣,都會(huì)提及動(dòng)態(tài)圖和靜態(tài)圖這樣的概念,那么它們究竟是什么意思呢腹泌?在框架中又是如何體現(xiàn)與使用的呢嘶卧?本文會(huì)結(jié)合 Tensorflow、Pytorch 以及小鮮肉 MegEngine 的例子來(lái)為諸位揭開這神秘的面紗凉袱。
計(jì)算圖
不論是動(dòng)態(tài)圖還是靜態(tài)圖芥吟,它們都屬于計(jì)算圖侦铜。計(jì)算圖是用來(lái)描述運(yùn)算的有向無(wú)環(huán)圖,它有兩個(gè)主要元素:結(jié)點(diǎn)(Node)和邊(Edge)钟鸵。結(jié)點(diǎn)表示數(shù)據(jù)钉稍,如向量、矩陣棺耍、張量贡未,而邊表示運(yùn)算,如加減乘除卷積等蒙袍。
采用計(jì)算圖來(lái)描述運(yùn)算的好處不僅是讓運(yùn)算流的表達(dá)更加簡(jiǎn)潔清晰俊卤,還有一個(gè)更重要的原因是方便求導(dǎo)計(jì)算梯度。
上圖表示的是 y = (w + x) * (w + 1) 代表的計(jì)算圖害幅,若要計(jì)算y對(duì)w的導(dǎo)數(shù)消恍,那么結(jié)合鏈?zhǔn)角髮?dǎo)法則,就在計(jì)算圖中反向從y找到所有到w的路徑矫限,每條路徑上各段的導(dǎo)數(shù)相乘就是該路徑的偏導(dǎo)哺哼,最后再將所有路徑獲得的偏導(dǎo)求和即可。
葉子節(jié)點(diǎn)是用戶創(chuàng)建的變量叼风,如上圖的x與w取董,在Pytorch的實(shí)現(xiàn)中,為了節(jié)省內(nèi)存无宿,在梯度反向傳播結(jié)束后茵汰,非葉子節(jié)點(diǎn)的梯度都會(huì)被釋放掉。
動(dòng)態(tài)圖
動(dòng)態(tài)圖意味著計(jì)算圖的構(gòu)建和計(jì)算同時(shí)發(fā)生(define by run)孽鸡。這種機(jī)制由于能夠?qū)崟r(shí)得到中間結(jié)果的值蹂午,使得調(diào)試更加容易,同時(shí)我們將大腦中的想法轉(zhuǎn)化為代碼方案也變得更加容易彬碱,對(duì)于編程實(shí)現(xiàn)來(lái)說更友好豆胸。Pytorch使用的就是動(dòng)態(tài)圖機(jī)制,因此它更易上手巷疼,風(fēng)格更加pythonic晚胡,大受科研人員的喜愛。
靜態(tài)圖
靜態(tài)圖則意味著計(jì)算圖的構(gòu)建和實(shí)際計(jì)算是分開(define and run)的嚼沿。在靜態(tài)圖中估盘,會(huì)事先了解和定義好整個(gè)運(yùn)算流,這樣之后再次運(yùn)行的時(shí)候就不再需要重新構(gòu)建計(jì)算圖了(可理解為編譯)骡尽,因此速度會(huì)比動(dòng)態(tài)圖更快遣妥,從性能上來(lái)說更加高效,但這也意味著你所期望的程序與編譯器實(shí)際執(zhí)行之間存在著更多的代溝攀细,代碼中的錯(cuò)誤將難以發(fā)現(xiàn)箫踩,無(wú)法像動(dòng)態(tài)圖一樣隨時(shí)拿到中間計(jì)算結(jié)果爱态。Tensorflow默認(rèn)使用的是靜態(tài)圖機(jī)制,這也是其名稱的由來(lái),先定義好整個(gè)計(jì)算流(flow),然后再對(duì)數(shù)據(jù)(tensor)進(jìn)行計(jì)算哆键。
動(dòng)態(tài)圖 vs 靜態(tài)圖
通過一個(gè)例子來(lái)對(duì)比下動(dòng)態(tài)圖和靜態(tài)圖機(jī)制在編程實(shí)現(xiàn)上的差異凡傅,分別基于Pytorch和Tensorflow實(shí)現(xiàn),先來(lái)看看Pytorch的動(dòng)態(tài)圖機(jī)制:
import torch
first_counter = torch.Tensor([0])
second_counter = torch.Tensor([10])
while (first_counter < second_counter)[0]:
? ? first_counter += 2
? ? second_counter += 1
print(first_counter)
print(second_counter)
可以看到,這與普通的Python編程無(wú)異。
再來(lái)看看在基于Tensorflow的靜態(tài)圖機(jī)制下是如何實(shí)現(xiàn)上述程序的:
import tensorflow as tf
first_counter = tf.constant(0)
second_counter = tf.constant(10)
# tensorflow
import tensorflow as tf
first_counter = tf.constant(0)
second_counter = tf.constant(10)
def cond(first_counter, second_counter, *args):
? ? return first_counter < second_counter
def body(first_counter, second_counter):
? ? first_counter = tf.add(first_counter, 2)
? ? second_counter = tf.add(second_counter, 1)
? ? return first_counter, second_counter
c1, c2 = tf.while_loop(cond, body, [first_counter, second_counter])
with tf.Session() as sess:
? ? counter_1_res, counter_2_res = sess.run([c1, c2])
print(counter_1_res)
print(counter_2_res)
(⊙o⊙)… 對(duì)Tensorflow不熟悉的童鞋來(lái)說,第一反應(yīng)可能會(huì)是:這什么鬼6幻骸?確實(shí)猿规,看上去會(huì)有點(diǎn)難受..
Tensorflow在靜態(tài)圖的模式下衷快,每次運(yùn)算使用的計(jì)算圖都是同一個(gè),因此不能直接使用 Python 的 while 循環(huán)語(yǔ)句姨俩,而是要使用其內(nèi)置的輔助函數(shù) tf.while_loop蘸拔,而且還要tf.Session().run()之類的亂七八糟..
而Pytorch是動(dòng)態(tài)圖的模式,每次運(yùn)算會(huì)構(gòu)建新的計(jì)算圖环葵,在編程實(shí)現(xiàn)上不需要額外的學(xué)習(xí)成本(當(dāng)然首先你得會(huì)Python)调窍。
動(dòng)靜結(jié)合
在最近開源的框架MegEngine中,集成了兩種圖模式张遭,并且可以進(jìn)行相互切換邓萨,下面舉例說明將動(dòng)態(tài)圖轉(zhuǎn)換為靜態(tài)圖編譯過程中進(jìn)行的內(nèi)存和計(jì)算優(yōu)化:
y = w*x + b 的動(dòng)態(tài)計(jì)算圖如下:
可以看到,中間的運(yùn)算結(jié)果是被保留下來(lái)的菊卷,如p=w*x缔恳,這樣就一共需要5個(gè)變量的存儲(chǔ)空間。若切換為靜態(tài)圖洁闰,由于事先了解了整個(gè)計(jì)算流歉甚,因此可以讓y復(fù)用p的內(nèi)存空間,這樣一共就只需要4個(gè)變量的存儲(chǔ)空間扑眉。
另外铃芦,MegEngine 還使用了 算子融合 (Operator Fuse)的機(jī)制,用于減少計(jì)算開銷襟雷。對(duì)于上面的動(dòng)態(tài)計(jì)算圖,切換為靜態(tài)圖后可以將乘法和加法融合為一個(gè)三元操作(假設(shè)硬件支持):乘加(如下圖所示)仁烹,從而降低計(jì)算量耸弄。