模型壓縮和加速——網(wǎng)絡(luò)剪枝(Network Pruning)及其pytorch實(shí)現(xiàn)

一背镇、網(wǎng)絡(luò)剪枝的步驟

神經(jīng)網(wǎng)絡(luò)中的一些權(quán)重和神經(jīng)元是可以被剪枝的,這是因?yàn)檫@些權(quán)重可能為零或者神經(jīng)元的輸出大多數(shù)時(shí)候?yàn)榱悖砻鬟@些權(quán)重或神經(jīng)元是冗余的贿条。
網(wǎng)絡(luò)剪枝的過程主要分以下幾步:
訓(xùn)練網(wǎng)絡(luò)处窥。
評(píng)估權(quán)重和神經(jīng)元的重要性嘱吗。例如可以用L1、L2來評(píng)估權(quán)重的重要性,用不是0的次數(shù)來衡量神經(jīng)元的重要性谒麦。
③對(duì)權(quán)重或者神經(jīng)元的重要性進(jìn)行排序俄讹,然后移除不重要的權(quán)重或神經(jīng)元。
恢復(fù)微調(diào)绕德。移除部分權(quán)重或者神經(jīng)元后網(wǎng)絡(luò)的準(zhǔn)確率會(huì)受到一些損傷患膛,因此我們要進(jìn)行微調(diào),也就是使用原來的訓(xùn)練數(shù)據(jù)更新一下參數(shù)耻蛇,往往就可以復(fù)原回來踪蹬。
迭代。為了不會(huì)使剪枝造成模型效果的過大損傷臣咖,我們每次都不會(huì)一次性剪掉太多的權(quán)重或神經(jīng)元跃捣,因此這個(gè)過程需要迭代,也就是說剪枝且微調(diào)一次后如果剪枝后的模型大小還不令人滿意就回到步驟后迭代上述過程直到滿意為止夺蛇。

二枝缔、剪枝的方法和對(duì)象

非結(jié)構(gòu)化剪枝(Unstructured Puning)
非結(jié)構(gòu)化剪枝(Unstructured Puning)是指修剪參數(shù)的單個(gè)元素,比如全連接層中的單個(gè)權(quán)重蚊惯、卷積層中的單個(gè)卷積核參數(shù)元素或者自定義層中的浮點(diǎn)數(shù)(scaling floats)愿卸。其重點(diǎn)在于,剪枝權(quán)重對(duì)象是隨機(jī)的截型,沒有特定結(jié)構(gòu)趴荸,因此被稱為非結(jié)構(gòu)化剪枝。由于非結(jié)構(gòu)化的剪枝在硬件方面需要有專用的庫支持宦焦,沒 有結(jié)構(gòu)化剪枝易于實(shí) 現(xiàn) 发钝,正在逐漸淡出人們的焦點(diǎn)。
結(jié)構(gòu)化剪枝
與非結(jié)構(gòu)化剪枝相反波闹,結(jié)構(gòu)化剪枝會(huì)剪枝整個(gè)參數(shù)結(jié)構(gòu)酝豪。比如,丟棄整行或整列的權(quán)重精堕,或者在卷積層中丟棄整個(gè)過濾器(Filter)孵淘。

具體來看,有以下幾種常見的剪枝對(duì)象:
1.weights剪枝
2.神經(jīng)元剪枝
3.Filters剪枝
4.通道剪枝

有人可能會(huì)問歹篓,就是將參數(shù)矩陣?yán)镆徊糠值闹导舻簦ú⒉皇钦娴囊瞥敝ぃ菍⑵滟x為0)。這真的能起到壓縮模型大小的作用嗎庄撮,Parameters并沒有減少氨嘲啤?回答是洞斯,Parameter的確沒有減少毡庆,但是FLOPsFLOPs(浮點(diǎn)運(yùn)算數(shù))減少了,所以仍然起到了壓縮模型大小。

三么抗、代碼實(shí)現(xiàn)

  1. PyTorch中的torch.nn.utils.prune模塊是一個(gè)專門用于神經(jīng)網(wǎng)絡(luò)模型剪枝的工具集毅否。
import torch
import torch.nn as nn
from torch.nn.utils import prune
torch.manual_seed(888)
# 創(chuàng)建一個(gè)簡單的卷積層
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()     
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)
model = SimpleModel()

#剪枝1:隨機(jī)選擇50%的權(quán)重進(jìn)行剪枝
prune.random_unstructured(model.conv, name="weight", amount=0.5) 
print("Weight after RandomUnstructured pruning (50%):", model.conv.weight)
#除了對(duì) weight 剪枝, 還可以對(duì) bias 剪枝, name="bias"
#amount還可以是非負(fù)整數(shù),表示要修剪的連接的絕對(duì)數(shù)量

#剪枝2:剪掉L1范數(shù)(絕對(duì)值)最小的50%
prune.l1_unstructured(model.conv, name="weight", amount=0.5)
print("Weight after L1Unstructured pruning (50%):", model.conv.weight)

#剪枝3:結(jié)構(gòu)化剪枝
prune.ln_structured(model.conv, name="weight", amount=0.5, n=2, dim=0)
print(model.conv.weight)
print(model.state_dict().keys())

#剪枝4:設(shè)置多個(gè)參數(shù)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

new_model = LeNet().to(device=device)
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

#global prune
model = LeNet()
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2
)

print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        ))
#Global sparsity: 20.00%

#移除修建過程中保存的參數(shù)乖坠,如修剪前的參數(shù)和偏置等搀突,使用 torch.nn.utils.prune 中的 remove 功能刀闷。
prune.remove(model , 'weight')
print(list(model .named_parameters()))
  1. 重要參數(shù)說明
    n 表示剪枝的范數(shù)熊泵,dim 表示剪枝的維度。
  • 對(duì)于 torch.nn.Linear:
    dim = 0:移除一個(gè)神經(jīng)元甸昏。
    dim = 1:移除與一個(gè)輸入的所有連接顽分。
  • 對(duì)于 torch.nn.Conv2d:
    dim = 0 (Channels): 通道 channels 剪枝/過濾器 filters 剪枝
    dim = 1(Neurons): 二維卷積核 kernel 剪枝,即與輸入通道相連接的 kernel
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末施蜜,一起剝皮案震驚了整個(gè)濱河市卒蘸,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌翻默,老刑警劉巖缸沃,帶你破解...
    沈念sama閱讀 219,110評(píng)論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異修械,居然都是意外死亡趾牧,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,443評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門肯污,熙熙樓的掌柜王于貴愁眉苦臉地迎上來翘单,“玉大人,你說我怎么就攤上這事蹦渣『逦撸” “怎么了?”我有些...
    開封第一講書人閱讀 165,474評(píng)論 0 356
  • 文/不壞的土叔 我叫張陵柬唯,是天一觀的道長认臊。 經(jīng)常有香客問我,道長锄奢,這世上最難降的妖魔是什么美尸? 我笑而不...
    開封第一講書人閱讀 58,881評(píng)論 1 295
  • 正文 為了忘掉前任,我火速辦了婚禮斟薇,結(jié)果婚禮上师坎,老公的妹妹穿的比我還像新娘。我一直安慰自己堪滨,他們只是感情好胯陋,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,902評(píng)論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般遏乔。 火紅的嫁衣襯著肌膚如雪义矛。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,698評(píng)論 1 305
  • 那天盟萨,我揣著相機(jī)與錄音凉翻,去河邊找鬼。 笑死捻激,一個(gè)胖子當(dāng)著我的面吹牛制轰,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播胞谭,決...
    沈念sama閱讀 40,418評(píng)論 3 419
  • 文/蒼蘭香墨 我猛地睜開眼垃杖,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了丈屹?” 一聲冷哼從身側(cè)響起调俘,我...
    開封第一講書人閱讀 39,332評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎旺垒,沒想到半個(gè)月后彩库,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,796評(píng)論 1 316
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡先蒋,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,968評(píng)論 3 337
  • 正文 我和宋清朗相戀三年骇钦,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片鞭达。...
    茶點(diǎn)故事閱讀 40,110評(píng)論 1 351
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡司忱,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出畴蹭,到底是詐尸還是另有隱情坦仍,我是刑警寧澤,帶...
    沈念sama閱讀 35,792評(píng)論 5 346
  • 正文 年R本政府宣布叨襟,位于F島的核電站繁扎,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏糊闽。R本人自食惡果不足惜梳玫,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,455評(píng)論 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望右犹。 院中可真熱鬧提澎,春花似錦、人聲如沸念链。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,003評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至谦纱,卻和暖如春看成,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背跨嘉。 一陣腳步聲響...
    開封第一講書人閱讀 33,130評(píng)論 1 272
  • 我被黑心中介騙來泰國打工川慌, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人祠乃。 一個(gè)月前我還...
    沈念sama閱讀 48,348評(píng)論 3 373
  • 正文 我出身青樓梦重,卻偏偏與公主長得像,于是被迫代替她去往敵國和親跳纳。 傳聞我的和親對(duì)象是個(gè)殘疾皇子忍饰,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,047評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容