Introduction
CNN models have been stacking ever more layers in pursuit of higher accuracy, yet the marginal accuracy gain from each extra layer keeps shrinking. Put differently, many input images do not actually need that many layers; only genuinely ambiguous images with weak, hard-to-discern features, the kind even a human would struggle to classify, need the deeper stack of layers to produce a representation discriminative enough to separate their class.
SkipNet starts from exactly this assumption. It attaches a gate module to every layer (or module) of a conventional CNN to decide whether that layer really needs to run; if not, the incoming activation feature maps are passed straight through to the next layer and the current layer's computation is skipped entirely. This clearly offers an effective way to cut the inference time a conventional CNN model needs once deployed.
Once trained, SkipNet can therefore decide flexibly at inference time, based on the incoming feature maps, whether to execute any given layer of the network. The figure below captures this defining property of SkipNet.
SkipNet
Each layer operation in SkipNet can be written as xi+1 = Gi·Fi(xi) + (1 − Gi)·xi, where xi and Fi(xi) are the input and output feature maps of the i-th layer, and Gi ∈ {0, 1} is the gate function of the i-th layer.
For this gate function the authors experiment with two different designs. Since the backbone CNN used in the paper is a ResNet, a gate can either be attached independently to each residual block, with its own parameters (the feed-forward gate), or a single gate module can be shared by all residual blocks (the recurrent gate). The difference between the two is shown in the figure below.
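As a minimal illustration of this update rule (my own sketch, not the authors' code; the names gated_residual_step and residual_block are hypothetical), note that with a hard gate the skipped block never needs to be evaluated at all, which is exactly where the inference-time saving comes from:

def gated_residual_step(x, residual_block, g):
    """Compute x_{i+1} = g * F_i(x) + (1 - g) * x for a hard gate g in {0, 1}."""
    if g == 0:
        # Skip: pass the input feature maps through unchanged,
        # without ever evaluating the block.
        return x
    # Execute: g == 1, so the update reduces to the block's output.
    return residual_block(x)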
Gate module design
The paper tries three different gate module designs, which balance computation overhead against accuracy somewhat differently:
- FFGate-I: MaxPool(2x2) -> Conv(3x3, 1) -> Conv(3x3, 2) -> AvgPool -> FC. Its overhead is roughly 19% of a residual block's computation; the paper uses it mainly for shallower networks (fewer than 100 layers);
- FFGate-II: Conv(3x3, 2) -> AvgPool -> FC. Its overhead is roughly 12.5% of a residual block's computation; it is used mainly for deeper networks (more than 100 layers);
- RNNGate: AvgPool -> Conv(1x1) -> LSTM(10 hidden units) -> FC. Its overhead is only about 0.04% of a residual block's computation, making it the preferred gate in the paper. On deep networks it offers a clear advantage over the feed-forward gates in both speed and classification accuracy; on shallower networks its accuracy is slightly lower, but its computational advantage remains large. (A rough sketch of this design follows the figure reference below.)
The figure below gives an overview of the three gate modules.
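The released code shown later in this post covers FFGate-I only, so as an illustration of the RNNGate layout, here is a rough PyTorch sketch reconstructed purely from the "AvgPool -> Conv(1x1) -> LSTM(10) -> FC" description above. Only that layer sequence and the 10 hidden units come from the paper; the class name, the sigmoid threshold, the straight-through trick and the fixed in_channels are my assumptions (in a real ResNet the channel count changes between groups, so a per-group projection would be needed):

import torch
import torch.nn as nn

class RNNGateSketch(nn.Module):
    """Shared recurrent gate: AvgPool -> Conv(1x1) -> LSTM(10 hidden units) -> FC.

    A single instance is reused by every residual block, so the LSTM state
    carries gating context across layers. Illustrative sketch only.
    """
    def __init__(self, in_channels, hidden_dim=10):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)           # global average pooling
        self.proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)
        self.hidden = None  # LSTM state, carried across blocks; reset it for each new input

    def forward(self, x):
        b = x.size(0)
        feat = self.proj(self.avgpool(x)).view(b, 1, -1)  # (B, 1, hidden_dim)
        out, self.hidden = self.lstm(feat, self.hidden)
        prob = torch.sigmoid(self.fc(out.view(b, -1)))    # P(execute this block)
        gate = (prob > 0.5).float()                       # hard decision in {0, 1}
        # straight-through style: hard gate in the forward pass,
        # sigmoid gradients in the backward pass
        gate = gate.detach() - prob.detach() + prob
        return gate.view(b, 1, 1, 1), prob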
Learning the skipping policy with hybrid RL
The gate function introduced above can be viewed as the following decision policy: Π(xi, i) = P(Gi(xi) = gi), where gi ∈ {0, 1} is the discrete decision to execute or to skip the i-th layer.
For a CNN with N layers, a forward pass on input x therefore amounts to drawing a decision sequence g = [g1, ..., gN] ~ Π(Fθ), where Fθ = [Fθ1, ..., FθN] denotes the computation carried out by the network's N layers.
The overall objective function can then be expressed as follows:
Here Ri = (1 − gi)Ci is the computation saved by the i-th gate module, which also serves as its reward. Because the paper uses ResNet, all Ci are assumed equal and set to 1. The coefficient α balances classification accuracy against computation savings; the objective is thus explicitly designed to trade off model accuracy against inference cost.
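From the quantities just defined, the objective should take roughly the following form (my paraphrase; see the paper for the exact expression): minimize, over the network and gate parameters θ, the expected value of L(ŷ(x, Fθ, g), y) − (α/N)·Σi Ri, i.e. the classification loss minus the scaled average reward for skipped computation.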
The equation below gives the gradient used in training. It consists of two parts: the first is the supervised loss that drives classification accuracy; the second is the RL-derived term through which the skipping policy reflecting computation savings is ultimately learned.
The figure below summarizes the concrete hybrid-RL algorithm.
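As a concrete illustration of how those two terms could be combined in code, here is a simplified sketch (my own, not the repository's training loop; it assumes the caller has collected, for each gated block, the log-probability of the sampled skip decision and the corresponding reward Ri):

import torch.nn.functional as F

def hybrid_loss(logits, targets, gate_log_probs, gate_rewards, alpha=0.1):
    """Supervised classification loss plus a REINFORCE term for the skipping policy.

    gate_log_probs: list of log-probabilities of the sampled decisions, shape (B,) each
    gate_rewards:   list of rewards R_i = (1 - g_i) * C_i, shape (B,) each
    """
    # Part 1: ordinary supervised loss driving classification accuracy.
    sup_loss = F.cross_entropy(logits, targets)

    # Part 2: REINFORCE on the skip decisions, rewarding saved computation.
    # Rewards are treated as constants; gradients flow only through the log-probs.
    rl_loss = 0.0
    for log_p, r in zip(gate_log_probs, gate_rewards):
        rl_loss = rl_loss - (log_p * r.detach()).mean()

    return sup_loss + alpha * rl_loss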
Experimental results
The figure below shows the classification accuracy SkipNet achieves on the major benchmark datasets.
The table below shows the computation savings that different SkipNet configurations and training schemes obtain while matching the accuracy of the original ResNet.
Code analysis
Below is the implementation of FFGate-I from the released repository; the other gate modules are written in much the same way.
import math

import torch.nn as nn

# `conv3x3` is the usual 3x3-convolution helper defined earlier in the
# repository's model file (not repeated here).


# Feedforward-Gate (FFGate-I)
class FeedforwardGateI(nn.Module):
    """ Use Max Pooling First and then apply to multiple 2 conv layers.
    The first conv has stride = 1 and second has stride = 2"""
    def __init__(self, pool_size=5, channel=10):
        super(FeedforwardGateI, self).__init__()
        self.pool_size = pool_size
        self.channel = channel

        self.maxpool = nn.MaxPool2d(2)
        self.conv1 = conv3x3(channel, channel)
        self.bn1 = nn.BatchNorm2d(channel)
        self.relu1 = nn.ReLU(inplace=True)

        # adding another conv layer
        self.conv2 = conv3x3(channel, channel, stride=2)
        self.bn2 = nn.BatchNorm2d(channel)
        self.relu2 = nn.ReLU(inplace=True)

        pool_size = math.floor(pool_size / 2)        # for max pooling
        pool_size = math.floor(pool_size / 2 + 0.5)  # for conv stride = 2

        self.avg_layer = nn.AvgPool2d(pool_size)
        self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2,
                                      kernel_size=1, stride=1)
        self.prob_layer = nn.Softmax()
        self.logprob = nn.LogSoftmax()

    def forward(self, x):
        x = self.maxpool(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.avg_layer(x)
        x = self.linear_layer(x).squeeze()
        softmax = self.prob_layer(x)
        logprob = self.logprob(x)

        # discretize output in forward pass.
        # use softmax gradients in backward pass
        x = (softmax[:, 1] > 0.5).float().detach() - \
            softmax[:, 1].detach() + softmax[:, 1]

        x = x.view(x.size(0), 1, 1, 1)
        return x, logprob
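As a quick sanity check of the shapes involved (my own usage snippet; it assumes the repository's conv3x3 helper is in scope and that pool_size matches the spatial size of the incoming feature maps):

import torch

# 16-channel, 32x32 feature maps, as produced by the first group of a CIFAR ResNet
gate = FeedforwardGateI(pool_size=32, channel=16)
feats = torch.randn(8, 16, 32, 32)

mask, logprob = gate(feats)
print(mask.shape)     # torch.Size([8, 1, 1, 1]) -- one hard 0/1 decision per image
print(logprob.shape)  # torch.Size([8, 2])       -- log-probabilities over skip/execute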
The class below shows how a gate module is combined with a CNN backbone to build the corresponding SkipNet.
class ResNetFeedForwardRL(nn.Module):
    """Adding gating module on every basic block"""
    def __init__(self, block, layers, num_classes=10,
                 gate_type='ffgate1', **kwargs):
        self.inplanes = 16
        super(ResNetFeedForwardRL, self).__init__()

        self.num_layers = layers
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        self.gate_instances = []
        self.gate_type = gate_type
        self._make_group(block, 16, layers[0], group_id=1,
                         gate_type=gate_type, pool_size=32)
        self._make_group(block, 32, layers[1], group_id=2,
                         gate_type=gate_type, pool_size=16)
        self._make_group(block, 64, layers[2], group_id=3,
                         gate_type=gate_type, pool_size=8)

        # remove the last gate instance, (not optimized)
        del self.gate_instances[-1]

        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        self.softmax = nn.Softmax()
        self.saved_actions = []
        self.rewards = []

        # standard He initialization for conv layers, unit/zero init for BN
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(0) * m.weight.size(1)
                m.weight.data.normal_(0, math.sqrt(2. / n))

    def _make_group(self, block, planes, layers, group_id=1,
                    gate_type='fisher', pool_size=16):
        """ Create the whole group"""
        for i in range(layers):
            if group_id > 1 and i == 0:
                stride = 2
            else:
                stride = 1

            meta = self._make_layer_v2(block, planes, stride=stride,
                                       gate_type=gate_type,
                                       pool_size=pool_size)

            setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0])
            setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1])
            setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2])

            # add into gate instance collection
            self.gate_instances.append(meta[2])

    def _make_layer_v2(self, block, planes, stride=1,
                       gate_type='fisher', pool_size=16):
        """ create one block and optional a gate module """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layer = block(self.inplanes, planes, stride, downsample)
        self.inplanes = planes * block.expansion

        if gate_type == 'ffgate1':
            gate_layer = RLFeedforwardGateI(pool_size=pool_size,
                                            channel=planes*block.expansion)
        elif gate_type == 'ffgate2':
            gate_layer = RLFeedforwardGateII(pool_size=pool_size,
                                             channel=planes*block.expansion)
        else:
            gate_layer = None

        if downsample:
            return downsample, layer, gate_layer
        else:
            return None, layer, gate_layer

    def repackage_vars(self):
        self.saved_actions = repackage_hidden(self.saved_actions)

    def forward(self, x, reinforce=False):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        masks = []
        gprobs = []
        # must pass through the first layer in first group
        x = getattr(self, 'group1_layer0')(x)
        # gate takes the output of the current layer
        mask, gprob = getattr(self, 'group1_gate0')(x)
        gprobs.append(gprob)
        masks.append(mask.squeeze())
        prev = x  # input of next layer

        for g in range(3):
            for i in range(0 + int(g == 0), self.num_layers[g]):
                if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None:
                    prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev)
                x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x)
                # new mask is taking the current output
                prev = x = mask.expand_as(x) * x \
                    + (1 - mask).expand_as(prev) * prev
                mask, gprob = getattr(self, 'group{}_gate{}'.format(g+1, i))(x)
                gprobs.append(gprob)
                masks.append(mask.squeeze())

        del masks[-1]

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        # collect all actions
        for inst in self.gate_instances:
            self.saved_actions.append(inst.saved_action)

        if reinforce:  # for pure RL
            softmax = self.softmax(x)
            action = softmax.multinomial()
            self.saved_actions.append(action)

        return x, masks, gprobs
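To give a feel for how this class would be driven end-to-end, here is a rough, hypothetical snippet for the supervised part only; BasicBlock and the RL gate classes come from the repository, the REINFORCE update for the gates themselves (via saved_actions and rewards) is omitted, and note that the repository code above targets an older PyTorch API (e.g. softmax.multinomial() without arguments):

import torch
import torch.nn as nn

# Hypothetical wiring: a ResNet-38-style SkipNet for CIFAR-10
# (layers=[6, 6, 6] basic blocks per group, each gated by FFGate-I).
model = ResNetFeedForwardRL(BasicBlock, [6, 6, 6], num_classes=10,
                            gate_type='ffgate1')
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

images = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))

optimizer.zero_grad()
logits, masks, gprobs = model(images)   # masks: one 0/1 gate decision per gated block
loss = criterion(logits, labels)        # supervised part of the hybrid objective
loss.backward()
optimizer.step()

# Fraction of blocks skipped for this batch (a rough proxy for saved computation).
skip_ratio = 1 - torch.stack(masks).float().mean()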
References
- SkipNet: Learning Dynamic Routing in Convolutional Networks, Xin Wang et al., 2018
- https://github.com/ucbdrive/skipnet