FB等提出全新卷積操作OctConv,速度接近理論極限

引言

論文地址
這篇論文是周一時帶我的大佬(現(xiàn)在瑞士讀博士护蝶,據(jù)說還在nips上面發(fā)過文章??华烟,瑟瑟發(fā)抖)發(fā)給我一個一篇鏈接文章,博客是計劃周五就要寫出來的持灰,但是由于要將maxnet的代碼遷移到pytorch的resnet上面花費了一些時間盔夜。至今沒見過這位大佬,我這位本科大白只是每周一閱讀他發(fā)的論文和相關(guān)demo代碼,改寫或者遷移到現(xiàn)在的工業(yè)圖像分類上喂链。有想一起學(xué)習(xí)的可以加qq:1678354579進行討論返十。
下面的內(nèi)容由于時間有限,主要以代碼實現(xiàn)為主椭微。才疏學(xué)淺洞坑,如果那些錯誤還請大佬多多指正!

摘要

在自然圖像中蝇率,信息總是在不同頻率中表達的迟杂,其中高頻信號一般包含豐富的細節(jié)而低頻信號一般包含整體的結(jié)構(gòu)。類似地瓢剿,卷積層的輸出特征圖同樣可以被看作是混合了不同頻域的信息逢慌。在這項工作中,我們提出了如何根據(jù)頻域去分解信息混合的特征圖间狂,并設(shè)計了一個新穎的八度卷積(Octave Convolution攻泼,OctConv)操作來保存和處理那些在較低空間分辨率下變化“較慢”(Slower)的特征圖,從而減少存儲和計算開銷鉴象。與現(xiàn)有多尺度(multi-scale)方法不同的是忙菠,八度卷積被制定為一種單個通用的即插即用卷積單元,可以直接替換普通(vanilla)卷積而不需要對現(xiàn)有網(wǎng)絡(luò)有任何調(diào)整纺弊。它同時也是對一些表明有著更好拓撲(topologies)或者減少通道冗余的方法的補充牛欢,并且與這些方法正交(orthogonal)。通過簡單地用八度卷積替換普通卷積淆游,我們在實驗中發(fā)現(xiàn)我們在減少存儲和計算開銷的同時傍睹,還能持續(xù)提高圖像和視頻識別任務(wù)的準確率。一個使用八度卷積的ResNet-152網(wǎng)絡(luò)能夠在ImageNet上達到82.9%的Top-1分類準確率犹菱,而其浮點計算量僅僅只有22.2G(Giga)拾稳。

  • 總結(jié)下來就是:自然界的圖像中高頻的信息表示細膩而豐富的細節(jié),低頻表示整體的輪廓和布局腊脱。八度卷積最大的優(yōu)點就是節(jié)省存儲空間的運算力访得,而且有怎么如此強的功能只需要改動網(wǎng)絡(luò)中卷積部分即可實現(xiàn)即插即用的功能!我的代碼能力一般陕凹,大概花了一天左右的時間改寫了octconv版的resnet悍抑,后期經(jīng)過改動能夠適應(yīng)三種卷積的增強版
  • 加一句,關(guān)于低頻和高頻個人覺得可能搞美術(shù)的人更能理解杜耙。比如像畫人物一樣搜骡,大致的輪廓是差不多的,不經(jīng)常改變?yōu)榈皖l佑女。具體的細節(jié)浆兰,一顰一動每個人都不一樣為高頻磕仅。本人為工科宅男一枚,獻丑了??

原理淺談

關(guān)于詳細的原理簸呈,大家可以參考論文和一片中文博客榕订。我這里更深的理解也是來源這篇博客,推薦大家去看看蜕便。
這里我主要從個人代碼理解和實現(xiàn)的角度來聊一聊原理劫恒,說白了就是數(shù)學(xué)公式看的有點蒙逼。代碼和公式相結(jié)合能夠理解更深入轿腺。
傳統(tǒng)的圖像卷積是每一個卷積核為[kernel_size,kernel_size,in_channels]两嘴,通過一系列相乘相加操作后得出特征圖的一個像素點。如果是BP網(wǎng)絡(luò)這一步就已經(jīng)結(jié)束了族壳,但是卷積網(wǎng)絡(luò)會利用stride進行移動相同的卷積核得出下一個像素點憔辫。就這樣按照步長在圖像的寬高進行移動,得出一個通道的特征圖仿荆,那如果我想要out_channels個通道的特征圖贰您。我只需要out_channels個卷積和就可以了,所以卷積的參數(shù)維度就是[kernel_size,kernel_size,in_channels,out_channels]拢操。后期人們在消除特征圖的冗余锦亦,人們又提出了grop_conv和depth_wise的卷積,對應(yīng)的網(wǎng)絡(luò)就是現(xiàn)在的resenxt和mobilenet令境。關(guān)于冗余的理解之前看過一本書上講解是過多的輸出通道杠园,卷積核很大概率存在相似性,那么輸出的特征圖就會存在線性相關(guān)(簡單說就是特征圖的一個向量可以由另一個向量線性表示)舔庶。這部分如果大家有感到不太懂的抛蚁,自動google關(guān)鍵字√璩龋或者加我私聊篮绿,歡迎騷擾!

好像有點扯遠了吕漂,,尘应,惶凝,現(xiàn)在開始進入重點啦!犬钢!八度卷積是在分辨率的維度提出低頻的信息在傳統(tǒng)的卷積中也存在冗余苍鲜,通過將特征圖分離成低頻信息(低分辨率),高頻信息(高分辨率)的達到節(jié)省存儲和算力玷犹。大概估算一下混滔,如果每一個特征圖的一半為低頻信息,那么他的分辨率降低為原始特征圖的1/2,存儲會卷積運算會減少1/4坯屿。
下采樣剛才我們降低冗余是通過降低低頻信息的分辨率,那么現(xiàn)在的問題是如何進行分辨率的降低呢?卷積網(wǎng)絡(luò)中有兩種下采樣的方式鬓长,一種是池化(pool)叔收,一種是步長為2的卷積。論文的實驗是說池化的方式會更有效

消融實驗

將八度卷積嵌入到resnet中發(fā)現(xiàn)stride=2的卷積下采樣并沒有降低可訓(xùn)練的參數(shù)吠昭,而pool的下采樣方式則數(shù)十倍的降低了參數(shù)量喊括。具體的數(shù)值當(dāng)時沒有保存,應(yīng)該會降低的更過矢棚。pool我們好理解郑什,因為pool本來并沒有可訓(xùn)練卷積,而stride=2的卷積下采樣本質(zhì)是將原始的卷積核分解成四份(中間卷積)或者兩份(開始和結(jié)尾卷積)蒲肋,所以他的可訓(xùn)練參數(shù)是不會減少的蘑拯。
八度卷積路線圖
第一層卷積:輸入圖像默認全部為高頻信息,故alpha_int=0肉津,alpha_out=
在這里插入圖片描述

中間層卷積强胰,特征圖包含低頻和高頻信息,一般設(shè)置為alpha_int=alpha_out=
在這里插入圖片描述

最后一層卷積妹沙,回復(fù)正常特征圖偶洋,故alpha_int=,alpha_out=0
在這里插入圖片描述

這里的參數(shù)設(shè)置一般為0.5距糖,0.2玄窝。具體的參數(shù)設(shè)置會根據(jù)圖像的特征豐富程度調(diào)整。
簡單總結(jié):特征圖由第一層進入分為兩路(低頻信息和高頻信息)悍引,中間層一直是兩路信息恩脂,并且兩路信息之間有交互,最終匯聚為一路信息輸出趣斤。

具體實現(xiàn)代碼

版本一 pool池化

# -*- coding: utf-8 -*-
# @Time    : 2019/4/22 13:29
# @Author  : ljf
import torch
import torch.nn.functional as F
from torch import nn


class OctConv2d_v1(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 alpha_in=0.5,
                 alpha_out=0.5
                 ):
        """adapt first octconv , octconv and last octconv

        """
        assert alpha_in >= 0 and alpha_in <= 1, "the value of alpha_in should be in range of [0,1],but get {}".format(
            alpha_in)
        assert alpha_out >= 0 and alpha_out <= 1, "the value of alpha_in should be in range of [0,1],but get {}".format(
            alpha_out)
        super(OctConv2d_v1, self).__init__(in_channels,
                                        out_channels,
                                        dilation,
                                        groups,
                                        bias,)
        self.alpha_in = alpha_in
        self.alpha_out = alpha_out
        self.kernel_size = (1,1)
        self.stride = (1,1)
        self.avgPool = nn.AvgPool2d(kernel_size, stride, padding)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.inChannelSplitIndex = int(
            self.alpha_in * self.in_channels)
        self.outChannelSplitIndex = int(
            self.alpha_out * self.out_channels)
        # split bias
        if bias:
            self.hh_bias = self.bias[self.outChannelSplitIndex:]
            self.hl_bias = self.bias[:self.outChannelSplitIndex]
            self.ll_bias = self.bias[ :self.outChannelSplitIndex]
            self.lh_bias = self.bias[ self.outChannelSplitIndex:]
        else:
            self.hh_bias = None
            self.hl_bias = None
            self.ll_bias = None
            self.ll_bias = None

        # conv and upsample
        self.upsample = F.interpolate

    def forward(self, x):
        if not isinstance(x, tuple):
            # first octconv
            input_h = x if self.alpha_in == 0 else None
            input_l = x if self.alpha_in == 1 else None
        else:
            input_l = x[0]
            input_h = x[1]

        output = [0, 0]
        # H->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != self.in_channels:
            output_hh = F.conv2d(self.avgPool(input_h),
                                 self.weight[
                                 self.outChannelSplitIndex:,
                                 self.inChannelSplitIndex:,
                                 :, :],
                                 self.bias[self.outChannelSplitIndex:],
                                 self.kernel_size
                                 )

            output[1] += output_hh

        # H->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != self.in_channels:
            output_hl = F.conv2d(self.avgpool(self.avgPool(input_h)),
                                 self.weight[
                :self.outChannelSplitIndex,
                self.inChannelSplitIndex:,
                                     :, :],
                                 self.bias[:self.outChannelSplitIndex],
                                 self.kernel_size
                                 )

            output[0] += output_hl

        # L->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != 0:
            output_ll = F.conv2d((self.avgPool(input_l)),
                                 self.weight[
                                 :self.outChannelSplitIndex,
                                 :self.inChannelSplitIndex,
                                 :, :],
                                 self.bias[:self.outChannelSplitIndex],
                                 self.kernel_size
                                 )

            output[0] += output_ll

        # L->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != 0:
            output_lh = F.conv2d(self.avgPool(input_l),
                                 self.weight[
                                 self.outChannelSplitIndex:,
                                 :self.inChannelSplitIndex,
                                 :, :],
                                 self.bias[self.outChannelSplitIndex:],
                                 self.kernel_size
                                 )
            output_lh = self.upsample(output_lh, scale_factor=2)

            output[1] += output_lh

        if isinstance(output[0], int):
            out = output[1]
        else:
            out = tuple(output)
        return out
if __name__ == "__main__":
    input = torch.randn(1, 3, 32, 32)
    octconv1 = OctConv2d(
        in_channels=3,
        out_channels=6,
        kernel_size=3,
        padding=1,
        stride=2,
        alpha_in=0,
        alpha_out=0.3)
    octconv2 = OctConv2d(
        in_channels=6,
        out_channels=16,
        kernel_size=2,
        padding=0,
        stride=2,
        alpha_in=0.3,
        alpha_out=0.5)
    lastconv = OctConv2d(
        in_channels=16,
        out_channels=32,
        kernel_size=2,
        padding=0,
        stride=2,
        alpha_in=0.5,
        alpha_out=0)
    # bn1 = OctBN(3,3)
    # ac1 = OctAc(name="relu")
    out = octconv1(input)
    print(len(out))
    print(out[0].size())
    print(out[1].size())
    out = octconv2(out)
    print(len(out))
    print(out[0].size())
    print(out[1].size())

    out = lastconv(out)
    print(len(out))
    print(out[0].size())
    print(out[1])

版本二 stride=2的卷積

# -*- coding: utf-8 -*-
# @Time    : 2019/4/22 10:35
# @Author  : ljf
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class OctConv2d_v2(nn.Conv2d):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
            alpha_in=0.5,
            alpha_out=0.5,):
        assert alpha_in >= 0 and alpha_in <= 1
        assert alpha_out >= 0 and alpha_out <= 1
        super(OctConv2d_v2, self).__init__(in_channels, out_channels,
                                           kernel_size, stride, padding,
                                           dilation, groups, bias)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.alpha_in = alpha_in
        self.alpha_out = alpha_out
        self.inChannelSplitIndex = math.floor(
            self.alpha_in * self.in_channels)
        self.outChannelSplitIndex = math.floor(
            self.alpha_out * self.out_channels)
        if bias:
            self.hh_bias = self.bias[self.outChannelSplitIndex:]
            self.hl_bias = self.bias[:self.outChannelSplitIndex]
            self.ll_bias = self.bias[ :self.outChannelSplitIndex]
            self.lh_bias = self.bias[ self.outChannelSplitIndex:]
        else:
            self.hh_bias = None
            self.hl_bias = None
            self.ll_bias = None
            self.lh_bias = None
    def forward(self, input):
        if not isinstance(input, tuple):
            assert self.alpha_in == 0 or self.alpha_in == 1
            inputLow = input if self.alpha_in == 1 else None
            inputHigh = input if self.alpha_in == 0 else None
        else:
            inputLow = input[0]
            inputHigh = input[1]

        output = [0, 0]
        # H->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != self.in_channels:
            outputH2H = F.conv2d(
                inputHigh,
                self.weight[
                    self.outChannelSplitIndex:,
                    self.inChannelSplitIndex:,
                    :,
                    :],
                self.hh_bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups)
            output[1] += outputH2H

        # H->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != self.in_channels:
            outputH2L = F.conv2d(
                self.avgpool(inputHigh),
                self.weight[
                    :self.outChannelSplitIndex,
                    self.inChannelSplitIndex:,
                    :,
                    :],
                self.hl_bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups)
            output[0] += outputH2L

        # L->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != 0:
            outputL2L = F.conv2d(
                inputLow,
                self.weight[
                    :self.outChannelSplitIndex,
                    :self.inChannelSplitIndex,
                    :,
                    :],
                self.ll_bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups)
            output[0] += outputL2L

        # L->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != 0:
            outputL2H = F.conv2d(
                F.interpolate(inputLow, scale_factor=2),
                self.weight[
                    self.outChannelSplitIndex:,
                    :self.inChannelSplitIndex,
                    :,
                    :],
                self.lh_bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups)
            output[1] += outputL2H
        if isinstance(output[0],int):
            out = output[1]
        else:
            out = tuple(output)
        return out


if __name__ == "__main__":
    input = torch.randn(1, 3, 32, 32)
    octconv1 = OctConv2d(in_channels=3,
                         out_channels=6,
                         kernel_size=3,
                         stride=2,
                         padding=1,
                         dilation=1,
                         groups=1,
                         bias=True,
                         alpha_in=0.,
                         alpha_out=0.25)
    octconv2 = OctConv2d(in_channels=6,
                         out_channels=16,
                         kernel_size=3,
                         stride=1,
                         padding=1,
                         dilation=1,
                         groups=1,
                         bias=True,
                         alpha_in=0.25,
                         alpha_out=0.5)
    out = octconv1(input)
    print(len(out))
    print(out[0].shape)
    print(out[1].size())

    out = octconv2(out)
    print(len(out))
    print(out[0].size())
    print(out[1].size())

github地址

功力有限俩块,還請各位多多包涵,多多指證浓领。
參考文章:https://mp.weixin.qq.com/s?__biz=MzUyMjE2MTE0Mw==&mid=2247487810&idx=1&sn=1428510ec154a24a9e779d82f693930d&chksm=f9d14fdacea6c6cc42a630e57726c1789a54dc8e31616bd747fb2c35f41dbbd86f2c2a0b8998&mpshare=1&scene=23&srcid=#rd

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末玉凯,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子联贩,更是在濱河造成了極大的恐慌漫仆,老刑警劉巖,帶你破解...
    沈念sama閱讀 210,914評論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件泪幌,死亡現(xiàn)場離奇詭異盲厌,居然都是意外死亡署照,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 89,935評論 2 383
  • 文/潘曉璐 我一進店門吗浩,熙熙樓的掌柜王于貴愁眉苦臉地迎上來建芙,“玉大人,你說我怎么就攤上這事拓萌∷甑觯” “怎么了?”我有些...
    開封第一講書人閱讀 156,531評論 0 345
  • 文/不壞的土叔 我叫張陵微王,是天一觀的道長屡限。 經(jīng)常有香客問我,道長炕倘,這世上最難降的妖魔是什么钧大? 我笑而不...
    開封第一講書人閱讀 56,309評論 1 282
  • 正文 為了忘掉前任,我火速辦了婚禮罩旋,結(jié)果婚禮上啊央,老公的妹妹穿的比我還像新娘。我一直安慰自己涨醋,他們只是感情好瓜饥,可當(dāng)我...
    茶點故事閱讀 65,381評論 5 384
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著浴骂,像睡著了一般乓土。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上溯警,一...
    開封第一講書人閱讀 49,730評論 1 289
  • 那天趣苏,我揣著相機與錄音,去河邊找鬼梯轻。 笑死食磕,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的喳挑。 我是一名探鬼主播彬伦,決...
    沈念sama閱讀 38,882評論 3 404
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼伊诵!你這毒婦竟也來了单绑?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 37,643評論 0 266
  • 序言:老撾萬榮一對情侶失蹤日戈,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后孙乖,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體浙炼,經(jīng)...
    沈念sama閱讀 44,095評論 1 303
  • 正文 獨居荒郊野嶺守林人離奇死亡份氧,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,448評論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了弯屈。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片蜗帜。...
    茶點故事閱讀 38,566評論 1 339
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖资厉,靈堂內(nèi)的尸體忽然破棺而出厅缺,到底是詐尸還是另有隱情,我是刑警寧澤宴偿,帶...
    沈念sama閱讀 34,253評論 4 328
  • 正文 年R本政府宣布湘捎,位于F島的核電站,受9級特大地震影響窄刘,放射性物質(zhì)發(fā)生泄漏窥妇。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 39,829評論 3 312
  • 文/蒙蒙 一娩践、第九天 我趴在偏房一處隱蔽的房頂上張望活翩。 院中可真熱鬧,春花似錦翻伺、人聲如沸材泄。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,715評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽拉宗。三九已至,卻和暖如春未妹,著一層夾襖步出監(jiān)牢的瞬間簿废,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,945評論 1 264
  • 我被黑心中介騙來泰國打工络它, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留族檬,地道東北人。 一個月前我還...
    沈念sama閱讀 46,248評論 2 360
  • 正文 我出身青樓化戳,卻偏偏與公主長得像单料,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子点楼,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 43,440評論 2 348

推薦閱讀更多精彩內(nèi)容