我大部分時間都在考慮如何讓神經(jīng)網(wǎng)絡(luò)的深度學(xué)習(xí)更快、更省電。實際上绒怨,這意味著關(guān)注一個名為GEMM的函數(shù)。它是BLAS(基本線性代數(shù)子程序)庫的一部分甩牺,該庫最早創(chuàng)建于1979年,在我開始嘗試優(yōu)化神經(jīng)網(wǎng)絡(luò)之前累奈,我從未聽說過它。為了解釋為什么它如此重要急但,下面是我朋友賈楊青的論文中的一個圖表:
這就是使用Alex Krizhevsky的Imagenet架構(gòu)進(jìn)行圖像識別的典型深度卷積神經(jīng)網(wǎng)絡(luò)的時間澎媒。所有以fc(即:全連接層)或conv(即:卷積層)開頭的層都是使用GEMM實現(xiàn)的,幾乎所有的時間(95%的GPU版本波桩,89%的CPU版本)都花在這些層上戒努。
那么什么是GEMM呢?GEMM代表 GEneral Matrix to Matrix Multiplication (通用矩陣到矩陣乘法),本質(zhì)上它完全按照tin上所說的做储玫,將兩個輸入矩陣相乘侍筛,得到一個輸出矩陣。其與3D圖形世界中使用的矩陣運算的不同之處在于撒穷,它處理的矩陣通常非常大匣椰。例如,典型網(wǎng)絡(luò)中的單個網(wǎng)絡(luò)層可能需要將256行1152列的矩陣乘以1152行192列的矩陣端礼,以產(chǎn)生256行192列的結(jié)果禽笑。天真地說,這需要5700萬(256x1152x192)次浮點運算蛤奥,而且在現(xiàn)代網(wǎng)絡(luò)結(jié)構(gòu)中可能有幾十個這樣的網(wǎng)絡(luò)層佳镜,所以我經(jīng)常看到一個往往需要幾十億次浮點運算來計算單個圖像幀凡桥。下面是我繪制的一張圖表蟀伸,幫助我直觀地了解它的工作原理:
全連接層
全連接層是已經(jīng)存在了幾十年的經(jīng)典神經(jīng)網(wǎng)絡(luò)層,從FC層開始說明如何使用GEMM可能是最容易的缅刽。FC層的每個輸出值都可以看到輸入的每個值啊掏,將輸入乘以該輸入對應(yīng)的權(quán)重,然后對結(jié)果求和以獲得其輸出拷恨。根據(jù)上圖脖律,它看起來是這樣的:
上圖中有“k”個輸入值,“n”個神經(jīng)元腕侄,每個神經(jīng)元都有自己的學(xué)習(xí)權(quán)重集小泉。對應(yīng)的圖中有“n”個輸出值,每個神經(jīng)元對應(yīng)其中一個冕杠,該輸出值利用對其權(quán)重和輸入值進(jìn)行點積運算計算得到微姊。
卷積層
在卷積層中使用GEMM不是一個顯而易見的選擇。conv層將其輸入視為二維圖像分预,每個像素具有多個通道兢交,非常類似于具有寬度、高度和深度的經(jīng)典圖像笼痹。不過配喳,與我以前處理的圖像不同,通道的數(shù)量可以達(dá)到數(shù)百個凳干,而不僅僅是RGB或RGBA晴裹!
卷積運算通過獲取若干“卷積核”的權(quán)重來產(chǎn)生其輸出。并將其應(yīng)用于整個圖像救赐。以下是輸入圖像和單個卷積核:
每個卷積核是另一個三維數(shù)字?jǐn)?shù)組涧团,深度與輸入圖像的深度相同,但寬度和高度要小得多,通常是7×7泌绣。為了得到結(jié)果钮追,卷積運算將卷積核應(yīng)用于輸入圖像上的點網(wǎng)格。在其應(yīng)用的每個點阿迈,所有相應(yīng)的輸入值和權(quán)重都相乘元媚,然后求和,在該點產(chǎn)生一個輸出值仿滔。以下是視覺效果:
你可以將這個運算看做邊緣檢測器惠毁。卷積核包含一個權(quán)重圖案,當(dāng)它所查看的輸入圖像部分具有類似的模式時崎页,它會輸出一個高值鞠绰。當(dāng)輸入與模式不匹配時,結(jié)果是該位置的數(shù)字較低飒焦。以下是一些通過神經(jīng)網(wǎng)絡(luò)第一層學(xué)習(xí)到的典型權(quán)重圖案[1]:
因為神經(jīng)網(wǎng)絡(luò)第一層的輸入是RGB圖像蜈膨,所以所有這些卷積核也可以可視化為RGB,并且它們顯示了網(wǎng)絡(luò)正在尋找的原始圖案牺荠。這96個卷積核中的每一個都以網(wǎng)格模式應(yīng)用于整個卷基層的輸入翁巍,結(jié)果是96個二維數(shù)組,它們被視為深度為96個通道的輸出圖像休雌。如果你習(xí)慣了像Sobel算子這樣的圖像處理操作灶壶,你可能可以想象其中的每一個都有點像一個邊緣檢測器,為圖像中不同的重要模式進(jìn)行了優(yōu)化杈曲,因此每個通道都是這些權(quán)重圖案在輸入中出現(xiàn)的位置的映射驰凛。
你可能已經(jīng)注意到,我對卷積核應(yīng)用于什么樣的網(wǎng)格一直很模糊担扑。關(guān)鍵的控制因素是一個名為“stride”的參數(shù)恰响,它定義了應(yīng)用卷積核之間的間距。例如涌献,當(dāng)stride為1時胚宦,256×256的輸入圖像將在每個像素上應(yīng)用卷積核,并且輸出的寬度和高度將與輸入的寬度和高度相同燕垃。如果stride為4枢劝,則同一輸入圖像每四個像素應(yīng)用一個卷積核,因此輸出將僅為64×64卜壕。一般來說stride會小于卷積核的大小呈野,這意味著在卷積核可以看到的并應(yīng)用的圖表中,很多stride實際上會在邊緣重疊印叁。
GEMM是如何應(yīng)用于卷積層的?
卷積似乎是一個相當(dāng)專業(yè)的運算。它包含多次乘法計算和最后的求和計算轮蜕,比如完全連接層昨悼,但我們不清楚應(yīng)該如何或為什么要將其轉(zhuǎn)化為GEMM矩陣乘法。我將在最后討論將GEMM應(yīng)用與卷積層的動機跃洛,但這里會討論如何用矩陣乘法來表示卷積運算的率触。
第一步是將輸入圖像(實際上是3D數(shù)組)轉(zhuǎn)換為2D數(shù)組,我們可以將其視為矩陣汇竭。應(yīng)用每個卷積核的地方是圖像中的一個小三維立方體葱蝗,因此我們將每個輸入值立方體作為一列復(fù)制到矩陣中。這被稱為im2col细燎,即:image-to-column(圖像到列)两曼,我相信它來自一個原始的Matlab函數(shù),以下是將im2col可視化:
現(xiàn)在玻驻,如果你像我一樣是對圖像處理感興趣的極客悼凑,你可能會對進(jìn)行這種轉(zhuǎn)換時,如果stride小于卷積核大小璧瞬,所需內(nèi)存的增加感到震驚户辫。這意味著包含在重疊卷積核中的像素將在矩陣中進(jìn)行復(fù)制,這似乎效率低下嗤锉。不過渔欢,你必須相信我,這種內(nèi)存使用的浪費會帶來計算上的優(yōu)勢瘟忱。
現(xiàn)在你有了矩陣形式的輸入圖像奥额,你對每個卷積核的權(quán)重做了同樣的操作,將3D立方體序列化成行酷誓,作為矩陣乘法的第二個矩陣披坏。以下是最終GEMM的樣子:
這里的“k”是每個patch和卷積核中的值的個數(shù),所以它是卷積核寬度高度深度盐数。得到的矩陣列高為“patch數(shù)”棒拂,行寬為“卷積數(shù)”。通過后續(xù)操作玫氢,該矩陣實際上被視為一個3D數(shù)組帚屉,方法是以核數(shù)維度作為深度,然后根據(jù)patch在輸入圖像中的原始位置將patch拆分回行和列漾峡。
為什么GEMM可以應(yīng)用于卷積層攻旦?
希望你現(xiàn)在能看到如何使用矩陣計算實現(xiàn)卷積層,但是你為什么要這么做還不清楚生逸。簡單的回答是牢屋,事實證明且预,F(xiàn)ortran世界的科學(xué)程序員花了幾十年時間優(yōu)化代碼,以執(zhí)行大型的矩陣乘法(large matrix to matrix multiplications)烙无,而且非常規(guī)則的內(nèi)存訪問模式帶來的好處超過了浪費的存儲成本锋谐。這篇來自Nvidia的論文[2]很好地介紹了您可以使用的一些不同方法,但它們也描述了為什么最終以修改版的GEMM作為他們最喜歡的方法截酷。能夠同時對同一個卷積核批處理大量輸入圖像也有很多優(yōu)點涮拗,Caffe con TROL[3] 使用了這些方法,取得了很好的效果迂苛。GEMM方法的主要競爭對手是使用傅里葉變換在頻率空間中進(jìn)行運算三热,但在卷積中使用stride使其難以達(dá)到同樣的效率。
好消息是三幻,有一個單一的就漾、易于理解的功能占用了我們大部分時間,這為優(yōu)化速度和電力使用提供了一條非常清晰的途徑赌髓,既可以通過更好的軟件實現(xiàn)从藤,也可以通過調(diào)整硬件來很好地運行操作。由于deep networks已被證明對語音锁蠕、NLP和計算機視覺的大量應(yīng)用非常有效夷野,我期待著在未來幾年看到大規(guī)模的改進(jìn),就像對3D游戲的廣泛需求推動了GPU的革命荣倾,使得vertex和pixel處理運算發(fā)生革命一樣悯搔。
-
Deep Learning for Computer Vision with Caffe and cuDNN https://developer.nvidia.com/blog/deep-learning-computer-vision-caffe-cudnn/ ?
-
cuDNN: Efficient Primitives for Deep Learning https://arxiv.org/pdf/1410.0759.pdf ?
-
Caffe con Troll: Shallow Ideas to Speed Up Deep Learning https://arxiv.org/pdf/1504.04343v1.pdf ?