有一段時(shí)間沒(méi)干正事了,早前一直沒(méi)找到合適的pytorch的剪枝壓縮代方法,現(xiàn)在看來(lái),主要是姿勢(shì)不對(duì)丰榴。這幾天集中突擊了一下網(wǎng)絡(luò)壓縮的pytorch剪枝,做一個(gè)記錄严望。
簡(jiǎn)述剪枝
以前早前弄了一下caffe模型的壓縮看這里多艇,懵懵懂懂只知道這是深度學(xué)習(xí)模型做好之后的后處理工作,實(shí)際場(chǎng)景進(jìn)行效率提升時(shí)候會(huì)用的到像吻,是深度學(xué)習(xí)的一個(gè)大方向峻黍。
先從這個(gè)小網(wǎng)絡(luò)剪枝的demo開(kāi)始說(shuō)起。
剪枝的主要流程
1. 網(wǎng)絡(luò)的重要程度評(píng)估
對(duì)網(wǎng)絡(luò)進(jìn)行剪枝首先要知道該刪除那些參數(shù)拨匆,對(duì)卷積層和全連接層的剪枝方法是不一樣的姆涩,jacobgil提到到L1和L2方法,Oracle pruning等惭每。
- L1: w = |w|
- L2: w = (w)^2
具體可以看看他提到的幾篇論文骨饿,我就不贅述了。使用這些方法得到網(wǎng)絡(luò)權(quán)重的排名台腥,按大小派宏赘,全連接就直接排序了,卷積就需要看了因?yàn)橛行┌赐ǖ烙行┦前磳舆M(jìn)行剪枝黎侈。具體看看這里有對(duì)比[Quantizing deep convolutional networks for efficient inference: A whitepaper]察署。demo就是安通道和L1范數(shù)進(jìn)行剪枝的。評(píng)估之后可以得到網(wǎng)絡(luò)同一層之間峻汉,通道的排序贴汪,統(tǒng)計(jì)模型所有conv layer的通道數(shù),同時(shí)記錄并返回它們的卷積層號(hào)和通道編號(hào)休吠。
主要由這兩個(gè)函數(shù)進(jìn)行操作扳埂。
compute_rank()
normalize_ranks_per_layer()
2. 移除不重要的網(wǎng)絡(luò)層
移除不重要的網(wǎng)絡(luò)層是剪枝的理所當(dāng)然的事,但是也是最復(fù)雜的一個(gè)操作瘤礁。受限于當(dāng)前的深度學(xué)習(xí)框架的限制阳懂,小白們對(duì)這種取參數(shù)和新建網(wǎng)路層這種復(fù)雜手藝表示看不懂。caffe剪枝可以看看這里L1剪枝,pytorch和caffe剪枝不一樣希太,因?yàn)閏affe的結(jié)構(gòu)克饶,是過(guò)文件進(jìn)行編寫的酝蜒,容易查看誊辉。pytorch的模型結(jié)構(gòu),就看個(gè)人的手藝了亡脑,不是自己寫的估計(jì)夠嗆啊堕澄。
和caffe的流程類似,通過(guò)建立一個(gè)新層霉咨,新層的輸入和輸出的數(shù)量通過(guò)移除的比例確定蛙紫,參數(shù)隨機(jī)初始化,把老層的參數(shù)進(jìn)行copy途戒。copy的手法就是坑傅,通過(guò)確認(rèn)遍歷整個(gè)模型參數(shù),copy重要的參數(shù)索引喷斋,不copy重要的參數(shù)索引唁毒。之后把新的層和新的參數(shù),重新裝入net中星爪。
- 類似:
net.feature = conv
- demo中的主要參數(shù)函數(shù)為
prune()
prune_conv_layer()
3.微調(diào)網(wǎng)絡(luò)
和普通的訓(xùn)練沒(méi)什么差別浆西。
4. 其他
已經(jīng)有大神把resnet18的剪枝方法,發(fā)出來(lái)了顽腾,提到BatchNorm層通道數(shù)修改,當(dāng)所有卷積層剪枝結(jié)束近零,依據(jù)鄰近上一個(gè)卷積層輸出通道數(shù),通過(guò)BatchNorm層繼承方式抄肖,它需修改成同樣的通道數(shù)久信。但是通用的方法是把BatchNorm和上一層conv進(jìn)行融合,其中會(huì)直接刪除一些結(jié)構(gòu)漓摩,需要一點(diǎn)時(shí)間裙士,好好看看才能實(shí)現(xiàn)。
參考:
大佬 jacobgil
對(duì) resnet18進(jìn)行剪枝
小網(wǎng)絡(luò)剪枝的demo
pytorch基于卷積層通道剪枝的方法
基于Pytorch的卷積神經(jīng)網(wǎng)絡(luò)剪枝