精簡CNN模型系列之三:SkipNet

介紹

CNN模型為了追求精度提高層數(shù)已經(jīng)是愈來愈多萄金,可更多的層次帶來的精度邊際提升卻不斷減小备畦。或者對某些輸入圖片而言,真正所需的layers并非那么多畔勤,只有一些真正模糊、特征不明顯扒磁、即使人看上去也較難分辨的圖片才需要較多的layers處理最終得到能分別其類別的表達特征庆揪。

SkipNet主要是以此假設出發(fā),通過在傳統(tǒng)CNN的每個layer(或module)上設置判斷其是否需要執(zhí)行的Gate module來決定是否需要真的執(zhí)行此層計算妨托,若判斷為否則直接將activation feature maps傳入到下一層缸榛,越過當下層的運算不做吝羞。無益這樣做可以有效地節(jié)省傳統(tǒng)CNN模型在部署時進行推理工作所需的時間。

就這樣一旦訓練好内颗,SkipNet在做圖片推理時可根據(jù)輸入的feature maps不同靈活地決定是否執(zhí)行某一網(wǎng)絡中的層钧排。下圖可反映SkipNet這一根本特點。

SkipNet根本思想

SkiptNet

對于每一層操作而言均澳,SkipNet可表示為:xi+1 = GiFi(xi)+(1-Gi)xi恨溜。其中xi和Fi(xi)分別表示第ith layer的輸入與輸出feature maps;Gi ∈{0,1} 為第ith layer的Gate函數(shù)找前。

對于此處的Gate函數(shù)糟袁,作者實驗了兩種不同的表示方法。Paper中SkipNet基于的CNN網(wǎng)絡為Resnet纸厉,其中Gate即可以被獨立地添加在各個Residual block上面作為單獨的個體系吭,有著不同的參數(shù)即Feed-forward Gate;還可以所有的Residual blocks復用一個Gate module即Recurrent Gate颗品。其不同之處可從下圖中看出肯尺。

SkipNet中兩種不同的Gate函數(shù)選擇

Gate module設計

作者在論文中共嘗試了三種不同的Gate module設計,它們對計算與accuracy的考量略有不同躯枢。

FFGate-I: MaxPool(2x2) -> Conv(3x3, 1) -> Conv(3x3, 2) -> AvgPool -> FC则吟,整體計算量約為Residual block的19%,在論文中主要用于較淺的一些網(wǎng)絡(層數(shù)小于100)锄蹂;
FFGate-II: Conv(3x3, 2) -> AvgPool -> FC氓仲,整體計算量約為Residual block的12.5%,主要用于較深的一些網(wǎng)絡(層數(shù)大于100)得糜;
RNNGate: AvgPool -> Conv(1x1) -> LSTM(10 hidden units) -> FC敬扛,整體計算量約為Residual block的0.04%,是論文中首選的Gate函數(shù)朝抖。在深層次網(wǎng)絡中它相對于Feed-forward Gate有較大的性能與分類精度優(yōu)勢啥箭,只是在較淺的層次上它精度略低,但計算開銷仍有較大優(yōu)勢治宣。

下圖為以上三種Gate module的概況描述急侥。

三種具體的Gate_module設計

使用Hybrid RL的Skipping policy學習

對于上節(jié)所介紹的Gate函數(shù)可理解為是這么一種決策:Π(xi,i) = P(Gi(xi) = gi),(其中gi∈{0,1}侮邀,分別表示執(zhí)行還是略過第ith層執(zhí)行的兩種離散決策)坏怪。

這樣對于有N層的CNN來說,我們在forward時需要決定下如此一個輸入為x的決策序列:g = [g1,....,gN] ? Π(F<sub>&theta;</sub>)绊茧。在這里F&theta; = [F&theta;1,....,F&theta;N]表示CNN網(wǎng)絡中N個layers的計算铝宵。

而整體的目標函數(shù)則可表示如下:

Skip_learning中使用Hybrid_RL時的整體目標函數(shù)

其中Ri = (1-gi)Ci表示的是每個Gate module所節(jié)省的計算,亦為它的激勵函數(shù)华畏。因為paper中用的是Resnet捉超,故假定所有的Ci相同胧卤,設為1。然后α 則為CNN分類準確率與計算節(jié)省之間的平衡系數(shù)拼岳≈μ埽可以看出這里的目標函數(shù)設計同時考慮了模型分類精度與計算效率并力圖在其中尋找平衡。

下式為具體計算時的梯度計算公式惜纸∫度觯可以看出它主要由兩部分組成,第一部分表示的是學習分類精度的supervised loss耐版,第二部分則是要接合RL最終學習出來的反映計算節(jié)省的Skip learning policy祠够。

Skip_learning中使用Hybrid_RL時的梯度計算

下圖為使用Hybrid RL的具體算法概述。

Hybrid_RL_learning算法

實驗結果

下圖為SkipNet在各大數(shù)據(jù)集上得到的分類精度結果粪牲。

在各大數(shù)據(jù)集上SkipNet得到的分類精度

下表中反映了不同SkipNet配置與訓練方法在達到與原生ResNet相似精度的情況下?lián)Q來的計算節(jié)省古瓤。

不同SkipNet配置在達到相似精度情況下得到的計算節(jié)省

代碼分析

如下為FFGate-I的設計實現(xiàn),其它Gate module的寫法并無太多不同腺阳。

# Feedforward-Gate (FFGate-I)
class FeedforwardGateI(nn.Module):
    """ Use Max Pooling First and then apply to multiple 2 conv layers.
    The first conv has stride = 1 and second has stride = 2"""
    def __init__(self, pool_size=5, channel=10):
        super(FeedforwardGateI, self).__init__()
        self.pool_size = pool_size
        self.channel = channel

        self.maxpool = nn.MaxPool2d(2)
        self.conv1 = conv3x3(channel, channel)
        self.bn1 = nn.BatchNorm2d(channel)
        self.relu1 = nn.ReLU(inplace=True)

        # adding another conv layer
        self.conv2 = conv3x3(channel, channel, stride=2)
        self.bn2 = nn.BatchNorm2d(channel)
        self.relu2 = nn.ReLU(inplace=True)

        pool_size = math.floor(pool_size/2)  # for max pooling
        pool_size = math.floor(pool_size/2 + 0.5)  # for conv stride = 2

        self.avg_layer = nn.AvgPool2d(pool_size)
        self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2,
                                      kernel_size=1, stride=1)
        self.prob_layer = nn.Softmax()
        self.logprob = nn.LogSoftmax()

    def forward(self, x):
        x = self.maxpool(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.avg_layer(x)
        x = self.linear_layer(x).squeeze()
        softmax = self.prob_layer(x)
        logprob = self.logprob(x)

        # discretize output in forward pass.
        # use softmax gradients in backward pass
        x = (softmax[:, 1] > 0.5).float().detach() - \
            softmax[:, 1].detach() + softmax[:, 1]

        x = x.view(x.size(0), 1, 1, 1)
        return x, logprob

下面這個class里面則具體實現(xiàn)了如何將Gate module與某一CNN網(wǎng)絡結合起來從而實現(xiàn)相關的SkipNet落君。

class ResNetFeedForwardRL(nn.Module):
    """Adding gating module on every basic block"""

    def __init__(self, block, layers, num_classes=10,
                 gate_type='ffgate1', **kwargs):
        self.inplanes = 16
        super(ResNetFeedForwardRL, self).__init__()

        self.num_layers = layers
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        self.gate_instances = []
        self.gate_type = gate_type
        self._make_group(block, 16, layers[0], group_id=1,
                         gate_type=gate_type, pool_size=32)
        self._make_group(block, 32, layers[1], group_id=2,
                         gate_type=gate_type, pool_size=16)
        self._make_group(block, 64, layers[2], group_id=3,
                         gate_type=gate_type, pool_size=8)

        # remove the last gate instance, (not optimized)
        del self.gate_instances[-1]

        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        self.softmax = nn.Softmax()
        self.saved_actions = []
        self.rewards = []

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(0) * m.weight.size(1)
                m.weight.data.normal_(0, math.sqrt(2. / n))

    def _make_group(self, block, planes, layers, group_id=1,
                    gate_type='fisher', pool_size=16):
        """ Create the whole group"""
        for i in range(layers):
            if group_id > 1 and i == 0:
                stride = 2
            else:
                stride = 1

            meta = self._make_layer_v2(block, planes, stride=stride,
                                       gate_type=gate_type,
                                       pool_size=pool_size)

            setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0])
            setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1])
            setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2])

            # add into gate instance collection
            self.gate_instances.append(meta[2])

    def _make_layer_v2(self, block, planes, stride=1,
                       gate_type='fisher', pool_size=16):
        """ create one block and optional a gate module """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),

            )
        layer = block(self.inplanes, planes, stride, downsample)
        self.inplanes = planes * block.expansion

        if gate_type == 'ffgate1':
            gate_layer = RLFeedforwardGateI(pool_size=pool_size,
                                            channel=planes*block.expansion)
        elif gate_type == 'ffgate2':
            gate_layer = RLFeedforwardGateII(pool_size=pool_size,
                                             channel=planes*block.expansion)
        else:
            gate_layer = None

        if downsample:
            return downsample, layer, gate_layer
        else:
            return None, layer, gate_layer

    def repackage_vars(self):
        self.saved_actions = repackage_hidden(self.saved_actions)

    def forward(self, x, reinforce=False):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        masks = []
        gprobs = []
        # must pass through the first layer in first group
        x = getattr(self, 'group1_layer0')(x)
        # gate takes the output of the current layer
        mask, gprob = getattr(self, 'group1_gate0')(x)
        gprobs.append(gprob)
        masks.append(mask.squeeze())
        prev = x  # input of next layer

        for g in range(3):
            for i in range(0 + int(g == 0), self.num_layers[g]):
                if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None:
                    prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev)
                x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x)
                # new mask is taking the current output
                prev = x = mask.expand_as(x) * x \
                           + (1 - mask).expand_as(prev) * prev
                mask, gprob = getattr(self, 'group{}_gate{}'.format(g+1, i))(x)
                gprobs.append(gprob)
                masks.append(mask.squeeze())

        del masks[-1]

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        # collect all actions
        for inst in self.gate_instances:
            self.saved_actions.append(inst.saved_action)

        if reinforce:  # for pure RL
            softmax = self.softmax(x)
            action = softmax.multinomial()
            self.saved_actions.append(action)

        return x, masks, gprobs

參考文獻

?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市亭引,隨后出現(xiàn)的幾起案子绎速,更是在濱河造成了極大的恐慌,老刑警劉巖焙蚓,帶你破解...
    沈念sama閱讀 217,826評論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件纹冤,死亡現(xiàn)場離奇詭異,居然都是意外死亡购公,警方通過查閱死者的電腦和手機萌京,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,968評論 3 395
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來宏浩,“玉大人枫夺,你說我怎么就攤上這事』婷疲” “怎么了?”我有些...
    開封第一講書人閱讀 164,234評論 0 354
  • 文/不壞的土叔 我叫張陵较坛,是天一觀的道長印蔗。 經(jīng)常有香客問我,道長丑勤,這世上最難降的妖魔是什么华嘹? 我笑而不...
    開封第一講書人閱讀 58,562評論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮法竞,結果婚禮上耙厚,老公的妹妹穿的比我還像新娘强挫。我一直安慰自己,他們只是感情好薛躬,可當我...
    茶點故事閱讀 67,611評論 6 392
  • 文/花漫 我一把揭開白布俯渤。 她就那樣靜靜地躺著,像睡著了一般型宝。 火紅的嫁衣襯著肌膚如雪八匠。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,482評論 1 302
  • 那天趴酣,我揣著相機與錄音梨树,去河邊找鬼。 笑死岖寞,一個胖子當著我的面吹牛抡四,可吹牛的內容都是我干的。 我是一名探鬼主播仗谆,決...
    沈念sama閱讀 40,271評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼指巡,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了胸私?” 一聲冷哼從身側響起厌处,我...
    開封第一講書人閱讀 39,166評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎岁疼,沒想到半個月后阔涉,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,608評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡捷绒,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 37,814評論 3 336
  • 正文 我和宋清朗相戀三年瑰排,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片暖侨。...
    茶點故事閱讀 39,926評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡椭住,死狀恐怖,靈堂內的尸體忽然破棺而出字逗,到底是詐尸還是另有隱情京郑,我是刑警寧澤,帶...
    沈念sama閱讀 35,644評論 5 346
  • 正文 年R本政府宣布葫掉,位于F島的核電站些举,受9級特大地震影響,放射性物質發(fā)生泄漏俭厚。R本人自食惡果不足惜户魏,卻給世界環(huán)境...
    茶點故事閱讀 41,249評論 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧叼丑,春花似錦关翎、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,866評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至症副,卻和暖如春店雅,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背贞铣。 一陣腳步聲響...
    開封第一講書人閱讀 32,991評論 1 269
  • 我被黑心中介騙來泰國打工闹啦, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人辕坝。 一個月前我還...
    沈念sama閱讀 48,063評論 3 370
  • 正文 我出身青樓窍奋,卻偏偏與公主長得像,于是被迫代替她去往敵國和親酱畅。 傳聞我的和親對象是個殘疾皇子琳袄,可洞房花燭夜當晚...
    茶點故事閱讀 44,871評論 2 354

推薦閱讀更多精彩內容

  • 介紹 SqueezeNet同這個系列要介紹的其它任一CNN模型一樣不只關心模型分類精度,同樣也重視其計算速度與模型...
    manofmountain閱讀 3,777評論 0 4
  • 最近發(fā)現(xiàn)自己的一個缺點纺酸,很多原理雖然從理論上或著數(shù)學上理解了窖逗,但是難以用一種簡潔的偏于溝通的方式表達出來。所以合上...
    給力桃閱讀 1,708評論 0 0
  • 文章作者:Tyan博客:noahsnail.com | CSDN | 簡書 聲明:作者翻譯論文僅為學習餐蔬,如有侵權請...
    SnailTyan閱讀 9,096評論 0 16
  • 介紹 終于可以說一下Resnet分類網(wǎng)絡了碎紊,它差不多是當前應用最為廣泛的CNN特征提取網(wǎng)絡。它的提出始于2015年...
    manofmountain閱讀 295,307評論 3 79
  • 這是一道再平常不過的門 被歲月侵蝕的藍幾處剝落 任夕陽下固執(zhí)的風 拂過不動的環(huán) 空自搖落幾度黃昏 門后的兩院青草 ...
    張秉初閱讀 291評論 0 6