1. Steps of Network Pruning
Some weights and neurons in a neural network can be pruned, because those weights may be zero, or the neurons' outputs may be zero most of the time, which indicates that the weights or neurons are redundant.
The network pruning process consists of the following main steps:
① Train the network.
② Evaluate the importance of the weights and neurons. For example, the L1 or L2 norm can be used to measure the importance of a weight, and how often its output is non-zero can be used to measure the importance of a neuron (a minimal sketch of this follows the list).
③ Sort the weights or neurons by importance, then remove the unimportant ones.
④ Fine-tune to recover. Removing some weights or neurons hurts the network's accuracy somewhat, so we fine-tune, i.e., update the remaining parameters on the original training data, which usually restores the accuracy.
⑤ Iterate. To keep pruning from damaging the model too much, we never remove too many weights or neurons at once. The process is therefore iterative: after one round of pruning and fine-tuning, if the pruned model is still not small enough, go back to the earlier steps and repeat until it is.
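A minimal sketch of steps ② and ③, using weight magnitude (L1) as the importance score on a random tensor; this is an illustration only, not the torch.nn.utils.prune API shown later:

import torch

torch.manual_seed(0)
w = torch.randn(8, 8)                        # weights of some layer
importance = w.abs()                         # step ②: L1 magnitude as the importance score
k = int(0.5 * w.numel())                     # prune the 50% least important weights
threshold = importance.flatten().kthvalue(k).values
mask = (importance > threshold).float()      # step ③: keep only the important weights
w_pruned = w * mask
print(int((w_pruned == 0).sum()), "of", w.numel(), "weights pruned")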
2. Pruning Methods and Targets
Unstructured Pruning
Unstructured pruning removes individual parameter elements, such as a single weight in a fully connected layer, a single element of a convolution kernel in a convolutional layer, or a scaling float in a custom layer. The key point is that the pruned weights are scattered without any particular structure, hence the name unstructured pruning. Because unstructured pruning requires dedicated library support on the hardware side and is not as easy to deploy as structured pruning, it is gradually fading out of focus.
Structured Pruning
In contrast to unstructured pruning, structured pruning removes entire parameter structures, for example dropping whole rows or columns of a weight matrix, or dropping entire filters in a convolutional layer.
Concretely, the common pruning targets are (a shape-level sketch follows the list):
1. Weight pruning
2. Neuron pruning
3. Filter pruning
4. Channel pruning
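A shape-level sketch of the last two targets (the layer sizes below are arbitrary): a Conv2d weight has shape [out_channels, in_channels, kH, kW], so filter pruning removes slices along dim 0 and channel pruning removes slices along dim 1.

import torch.nn as nn

conv = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3)
print(conv.weight.shape)            # torch.Size([8, 4, 3, 3])
# Filter pruning: drop whole output filters -> slices along dim 0
print(conv.weight[:6].shape)        # torch.Size([6, 4, 3, 3]) after keeping 6 of 8 filters
# Channel pruning: drop whole input channels -> slices along dim 1
print(conv.weight[:, :3].shape)     # torch.Size([8, 3, 3, 3]) after keeping 3 of 4 input channels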
One might ask: pruning only sets part of the parameter matrix to zero (the values are not actually removed, just assigned zero), so does it really compress the model, given that the number of parameters does not decrease? The answer is that the parameter count indeed stays the same, but the FLOPs (floating-point operations) that actually need to be computed are reduced, so pruning still serves to compress and accelerate the model.
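A quick check of this point, as a sketch (the layer size and pruning ratio are arbitrary): the pruned layer keeps the same number of parameters, but half of them become zeros, so half of the multiply-accumulate operations involve a zero operand.

import torch
import torch.nn as nn
from torch.nn.utils import prune

fc = nn.Linear(100, 100)
prune.l1_unstructured(fc, name="weight", amount=0.5)  # zero out the 50% smallest-magnitude weights
print(fc.weight.nelement())                           # 10000 -- the parameter count is unchanged
print(int(torch.sum(fc.weight == 0)))                 # 5000  -- but half of the values are now zero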
3. Code Implementation
- PyTorch's torch.nn.utils.prune module is a toolkit dedicated to pruning neural network models.
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed by the LeNet example below
from torch.nn.utils import prune
torch.manual_seed(888)
# Create a simple model with a single convolutional layer
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)

model = SimpleModel()
# Pruning 1: randomly select 50% of the weights and prune them
prune.random_unstructured(model.conv, name="weight", amount=0.5)
print("Weight after RandomUnstructured pruning (50%):", model.conv.weight)
# Besides the weight, the bias can be pruned as well, with name="bias"
# amount can also be a non-negative integer, i.e., the absolute number of connections to prune
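# For example (illustrative), an integer amount prunes exactly that many connections;
# this conv layer's bias has a single element, so amount=1 zeroes it out entirely.
prune.random_unstructured(model.conv, name="bias", amount=1)
print("Bias after pruning 1 connection:", model.conv.bias)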
# Pruning 2: prune the 50% of weights with the smallest L1 norm (absolute value)
prune.l1_unstructured(model.conv, name="weight", amount=0.5)
print("Weight after L1Unstructured pruning (50%):", model.conv.weight)
# Pruning 3: structured pruning -- rank whole slices along dim=0 (output channels, i.e., filters)
# by their L2 norm (n=2) and zero out the lowest-ranked 50%. Note that this toy layer has only one
# output channel, so there is just one structure to rank; with more output channels the effect is visible.
prune.ln_structured(model.conv, name="weight", amount=0.5, n=2, dim=0)
print(model.conv.weight)
print(model.state_dict().keys())
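# Parameters that were pruned above no longer appear as plain tensors in the state_dict;
# each shows up as a *_orig parameter plus a *_mask buffer from which the pruned tensor is recomputed.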
# Pruning 4: pruning several parameters across multiple modules
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
new_model = LeNet().to(device=device)
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist
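# Optional check (illustrative): per-layer sparsity after the module-wise pruning above
for name, module in new_model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        zeros = 100. * float(torch.sum(module.weight == 0)) / module.weight.nelement()
        print(f"{name}: {zeros:.2f}% of weights pruned")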
# Global pruning: prune 20% of all the listed weights at once, ranked by L1 magnitude across layers
model = LeNet()
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
#Global sparsity: 20.00%
# Make pruning permanent: remove the re-parametrization saved during pruning (weight_orig and weight_mask)
# using the remove function from torch.nn.utils.prune, called on each pruned module and parameter name:
for module, param_name in parameters_to_prune:
    prune.remove(module, param_name)
print(list(model.named_parameters()))
- Notes on important parameters
n specifies the norm used to rank the structures (e.g., n=1 for L1, n=2 for L2), and dim specifies the dimension along which structures are pruned.
- For torch.nn.Linear:
dim = 0: remove a neuron (an output unit).
dim = 1: remove all of the connections to one input feature.
- For torch.nn.Conv2d (a short sketch follows):
dim = 0 (Channels): channel/filter pruning, i.e., prune whole output filters.
dim = 1 (Neurons): prune 2D convolution kernels, i.e., the kernels connected to one input channel.
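A small sketch of the Conv2d case (the layer sizes below are arbitrary and only for illustration): pruning along dim=0 zeroes whole output filters, while pruning along dim=1 zeroes the kernels attached to whole input channels.

import torch
import torch.nn as nn
from torch.nn.utils import prune

conv_a = nn.Conv2d(in_channels=4, out_channels=6, kernel_size=3)
prune.ln_structured(conv_a, name="weight", amount=0.5, n=2, dim=0)   # zero out half of the 6 filters
print(conv_a.weight.abs().sum(dim=(1, 2, 3)) == 0)                   # True for the pruned filters

conv_b = nn.Conv2d(in_channels=4, out_channels=6, kernel_size=3)
prune.ln_structured(conv_b, name="weight", amount=0.5, n=2, dim=1)   # zero out half of the 4 input-channel kernels
print(conv_b.weight.abs().sum(dim=(0, 2, 3)) == 0)                   # True for the pruned input channels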