jittor簡述
由清華大學(xué)研制并開源了第一個(gè)我國高校自主的深度學(xué)習(xí)框架——計(jì)圖(Jittor)慧脱。計(jì)圖是一個(gè)完全動(dòng)態(tài)編譯(Just-in-time)训唱,基于元算子融合和統(tǒng)一計(jì)算圖的深度學(xué)習(xí)框架练链。計(jì)圖支持30多種的骨干網(wǎng)絡(luò)沃但,并且開源了多個(gè)模型庫:對(duì)抗生成網(wǎng)絡(luò)可款、圖像語義分割炉奴、檢測與實(shí)例分割逼庞、點(diǎn)云分類、可微渲染等一個(gè)完全即時(shí)(JIT)編譯的深度學(xué)習(xí)框架瞻赶。通過JIT編譯赛糟,我們可以實(shí)現(xiàn)更高的性能派任,同時(shí)使系統(tǒng)高度可定制。
2.jittor中的算子融合
Listing 1 Python implementation of convolution using three operators: reindex, broadcast, and sum
1: def conv(x, p):
2: N,C,H,W = x.shape
3: o,i,h,w = p.shape
4: xx = x.reindex(
5: shape=(N,o,H,W,i,h,w),
6: indices=("i0", "i4", "i2-i5", "i3-i6")
7: )
8: pp = p.broadcast(xx.shape, dims=(0,2,3))
9: yy = xx*pp
10: y = yy.sum(dims=(4,5,6))
11: return y
jittor中的算子融合屬于元算子融合璧南,這個(gè)例子展示了元算子運(yùn)算符實(shí)現(xiàn)卷積運(yùn)算掌逛。
代碼解讀:
第1行顯示conv有兩個(gè)參數(shù):x是圖像張量,p是參數(shù)張量司倚。
第2行和第3行解包了關(guān)于參數(shù)形狀的信息豆混。圖像張量x的布局為:批次數(shù)(N)、通道數(shù)(C)动知、圖像高度(H)和圖像寬度(W)皿伺。
參數(shù)張量p的布局為:輸出通道數(shù)(o)、輸入通道數(shù)(I)盒粮、內(nèi)核高度(h)和內(nèi)核寬度(w)鸵鸥。
第4–7行使用輸入張量x和輸出張量xx調(diào)用reindex運(yùn)算符。結(jié)果是:
xx(i0,i1,i2,i3,i4,i5,i6) = x(i0,i4,i2–i5,i3–i6)
第8行廣播參數(shù)張量p輸出張量pp丹皱,形狀與xx相同妒穴。廣播操作符是reindex操作符的專門化,相當(dāng)于pp = p.reindex(x.shape摊崭,indexs =(i1宰翅,i4,i5爽室,i6))。這個(gè)廣播運(yùn)營商的結(jié)果是: pp(i0,i1,i2,i3,i4,i5,i6) = p(i1,i4,i5,i6)
第9行對(duì)結(jié)果yy(i0淆攻,i1阔墩,i2,i3瓶珊,i4啸箫,i5,i6) = xx(i0伞芹,i1忘苛,i2,i3唱较,i4扎唾,i5,i6)pp(i0南缓,i1胸遇,i2,i3汉形,i4纸镊,i5倍阐,i6)執(zhí)行逐元素乘法。
第10行使用sum運(yùn)算符(reindex-reduce的專門化)來計(jì)算y(i0逗威,i1峰搪,i2,i3) = X i4凯旭,i5概耻,i6 yy(i0,i1尽纽,i2咐蚯,i3,i4弄贿,i5春锋,i6)。
上面的例子展示了如何通過對(duì)元運(yùn)算符的4次調(diào)用來實(shí)現(xiàn)卷積差凹。Jittor能夠?qū)⑦@4個(gè)元操作符融合成一個(gè)操作符期奔,這樣中間變量xx、pp危尿、yy就不需要實(shí)際計(jì)算了呐萌。融合所有4個(gè)元算子得到最終表達(dá)式:y(i0,i1谊娇,i2肺孤,i3) = X i4,i5济欢,i6 x(i0赠堵,i4,I2–i5法褥,i3–i6)p(i1茫叭,i4,i5半等,i6)揍愁。以類似的方式,元算子也可以用于實(shí)現(xiàn)各種卷積變體杀饵,例如膨脹卷積和群卷積莽囤。
3. Operator fuser
算子融合是Jittor后端的重要組成部分。它負(fù)責(zé)任意計(jì)算圖中的算子融合優(yōu)化切距。在上面烁登,我們展示了一個(gè)使用卷積計(jì)算的算子融合的例子。在實(shí)際應(yīng)用中,前端產(chǎn)生的計(jì)算圖要復(fù)雜得多饵沧。為了優(yōu)化任意情況锨络,我們將計(jì)算圖視為頂點(diǎn)和邊的有向無環(huán)圖,G = (V狼牺,E)羡儿,其中每個(gè)節(jié)點(diǎn)V代表一個(gè)算子,而每個(gè)邊E代表一個(gè)變量是钥。我們希望將G劃分為多個(gè)子圖Gi'?G掠归,其中每個(gè)子圖Gi' =(Vi',Ei')代表一個(gè)融合算子悄泥,每個(gè)節(jié)點(diǎn)恰好屬于一個(gè)子圖虏冻,每個(gè)邊可以屬于一個(gè)子圖或鏈接兩個(gè)子圖。目標(biāo)是選擇一個(gè)執(zhí)行所有子圖的成本最小的分區(qū)弹囚。然而厨相,準(zhǔn)確預(yù)測實(shí)際執(zhí)行成本是不可行的:它們?nèi)Q于硬件和其他因素的許多方面。因此鸥鹉,我們使用一種簡化的方法蛮穿,通過將成本定義為來確定子圖
其中We, 簡單地是由邊 e. Eq表示的變量的大小。這個(gè)式子對(duì)鏈接兩個(gè)不同子圖的每條邊e的權(quán)重求和毁渗,因此不屬于任何子圖Gi'践磅,這個(gè)代價(jià)相當(dāng)于讀寫指令的總數(shù)。這種方法是合理的灸异,因?yàn)榇蠖鄶?shù)深度學(xué)習(xí)模型都受到內(nèi)存帶寬的限制府适。融合可以通過減少內(nèi)存操作來提高性能。在最小化成本的同時(shí)肺樟,需要滿足以下規(guī)則:
規(guī)則1檐春。重新索引操作符不能與前面的元操作符融合,因?yàn)檫@種融合通常會(huì)導(dǎo)致性能下降儡嘶。
規(guī)則2。Reindex-reduce運(yùn)算符不能與以下元運(yùn)算符融合恍风。這種融合不會(huì)提高性能蹦狂。
規(guī)則3。融合不應(yīng)該在子圖之間創(chuàng)建有向循環(huán)朋贬。例如凯楔,給定一個(gè)有三個(gè)節(jié)點(diǎn)和三條邊的圖:(1 → 2),(2 → 3)锦募,(1 → 3)摆屯,如果第三條邊被融合,它將在結(jié)果中的子圖(1,3)和(平凡的)子圖2之間產(chǎn)生一個(gè)循環(huán):(1虐骑,3) ? 2准验。
4.最小化搜索成本
使用貪婪算法來最小化成本:在每次迭代中,我們選擇滿足規(guī)則1–3的邊e = (vstart廷没,vend)糊饱,并將vstart,vend融合到vstart所屬的子圖G’中颠黎,重復(fù)直到找不到滿足規(guī)則1–3的邊另锋。在實(shí)踐中,使用動(dòng)態(tài)規(guī)劃標(biāo)記算法來避免重復(fù)搜索滿足規(guī)則的邊狭归。該算法運(yùn)行良好夭坪,在大多數(shù)神經(jīng)網(wǎng)絡(luò)中取得了競爭性能。
圖顯示了經(jīng)典網(wǎng)絡(luò)組合的操作融合潭流,卷積-歸一化-激活竞惋。卷積層由兩個(gè)重新索引操作符組成,一個(gè)逐元素操作符和一個(gè)重新索引縮減操作符灰嫉。規(guī)范化層由一個(gè)重新索引操作符拆宛、一個(gè)逐元素操作符和一個(gè)重新索引-縮減操作符組成。激活層由多個(gè)元素操作符組成讼撒。在這種情況下浑厚,操作符可以跨卷積、歸一化和激活層進(jìn)行融合根盒。