我們?cè)谠O(shè)計(jì)一個(gè)CNN網(wǎng)絡(luò)時(shí)深纲,一定要考慮的兩個(gè)事情澄步,一個(gè)是這個(gè)網(wǎng)絡(luò)需要的計(jì)算量有多大南缓,一個(gè)是這個(gè)模型的參數(shù)量有多少蛀骇。計(jì)算量決定網(wǎng)絡(luò)訓(xùn)練的快慢(硬件設(shè)備確定的情況),參數(shù)量決定計(jì)算設(shè)備需要多大的內(nèi)存或顯存厌秒。CNN的計(jì)算量以計(jì)算機(jī)做乘加次數(shù)為單位,即完成某個(gè)操作擅憔,需要執(zhí)行多少次乘法和加法操作鸵闪。參數(shù)量以參數(shù)個(gè)數(shù)為單位,要計(jì)算內(nèi)存或顯存的暑诸,用參數(shù)量乘以每個(gè)參數(shù)所占的字節(jié)數(shù)即可蚌讼。
1.計(jì)算量
對(duì)于輸入特征圖f=(B,H,W,C)辟灰,卷積核張量kernel=(K,S,C,O),其中篡石,B是batch size大小芥喇,H,W夏志,C分別是輸入特征圖的高乃坤,寬和通道數(shù)。K,S,C,O分別是卷積操作時(shí)沟蔑,卷積核的大小湿诊,移動(dòng)步長(stride),特征圖輸入通道及輸出通道數(shù)瘦材。
- 首先一次卷積的計(jì)算量
一個(gè)kk的卷積厅须,執(zhí)行一次卷積操作,需要kk次乘法操作(卷積核中每個(gè)參數(shù)都要和特征圖上的元素相乘一次)食棕,kk-1次加法操作(將卷積結(jié)果朗和,kk個(gè)數(shù)加起來)。所以簿晓,一次卷積操作需要的乘加次數(shù):(KK)+(KK-1)=2KK-1 - 在一個(gè)特征圖上執(zhí)行卷積需要進(jìn)行卷積的次數(shù):
在一個(gè)特征圖上需要執(zhí)行的卷積次數(shù):(((H-K+Ph)/S)+1)*(((W-K+Pw)/S)+1),Pw,Ph表示在高和寬方向填充的像素眶拉。 - C個(gè)特征圖上進(jìn)行卷積運(yùn)算的次數(shù)
C個(gè)輸入特征圖上進(jìn)行卷積運(yùn)算的次數(shù)為C - 輸出一個(gè)特征圖通道需要的加法次數(shù)
在C個(gè)輸入特征圖上進(jìn)行卷積之后需要將卷積的結(jié)果相加,得到一個(gè)輸出特征圖上的卷積結(jié)果憔儿,C個(gè)相加需要C-1次加法忆植,所以輸出一個(gè)特征圖的計(jì)算量是:
(C-1)C((H-K+Ph)/S+1)((W-K+Pw)/S+1)(2KK-1) - 輸出O個(gè)特征圖需要計(jì)算的次數(shù)
上面的結(jié)果要乘到的是一個(gè)通道需要的計(jì)算量,需要輸出O個(gè)通道谒臼,計(jì)算量還要乘以O(shè)朝刊,所以O(shè)個(gè)特征圖需要的計(jì)算量為:
O(C-1)C((H-K+Ph)/S+1)((W-K+Pw)/S+1)(2K*K-1) - 一個(gè)batch的樣本需要的計(jì)算量
上面是一個(gè)樣本的數(shù)據(jù)需要的計(jì)算量,每個(gè)樣本都需要進(jìn)行卷積運(yùn)算蜈缤,所以一個(gè)batch的樣本需要的計(jì)算量為:
BO(C-1)C((H-K+Ph)/S+1)((W-K+Pw)/S+1)(2KK-1)
上面的算式得到的是CNN一層所需要的計(jì)算量拾氓,將每層的計(jì)算量相加就可以得到整個(gè)網(wǎng)絡(luò)的計(jì)算量。通常包含乘加的操作有Pool,Relu,BN(含有除法),卷積等底哥。一般都是卷積操作占主要咙鞍。
2.參數(shù)量
CNN網(wǎng)絡(luò)的參數(shù)量和特征圖的尺寸無關(guān),僅和卷積核的大小趾徽,偏置及BN有關(guān),對(duì)于卷積張量kernel=(K,S,C,O)奶陈,權(quán)重參數(shù)量為KKCO,偏置參數(shù)量為O,如果使用BN附较,那么還有兩個(gè)可學(xué)習(xí)參數(shù) a,b吃粒,參數(shù)量都是O,總共2O拒课,終上所述徐勃,該卷積層所有的參數(shù)量為:
KKCO+3O
需要注意的是事示,上面計(jì)算的僅僅是模型的參數(shù)量,若要計(jì)算模型實(shí)際需要多少顯存僻肖,還要考慮特征圖的大小肖爵,因?yàn)槊恳粚泳矸e的輸出都需要緩存,還要BN計(jì)算出來的均值和偏差也需要緩存臀脏,權(quán)重的梯度也需要緩存劝堪。通常模型參數(shù)所占用的顯存比例很小。