Octave卷積學習筆記

本文首發(fā)于個人博客

Octave卷積

Octave卷積的主題思想來自于圖片的分頻思想,首先認為圖像可進行分頻:

  • 低頻部分:圖像低頻部分保存圖像的大體信息毙芜,信息數(shù)據(jù)量較少
  • 高頻部分:圖像高頻部分保留圖像的細節(jié)信息忽媒,信息數(shù)據(jù)量較大

由此,認為卷積神經(jīng)網(wǎng)絡中的feature map也可以進行分頻腋粥,可按channel分為高頻部分和低頻部分晦雨,如圖所示:

feature_map.png

對于一個feature map,將其按通道分為兩個部分隘冲,分別為低頻通道和高頻通道闹瞧。隨后將低頻通道的長寬各縮減一半,則將一個feature map分為了高頻和低頻兩個部分展辞,即為Octave卷積處理的基本feature map奥邮,使用X表示,該類型X可表示為X = [X^H,X^L]罗珍,其中X^H為高頻部分洽腺,X^L為低頻部分。

為了處理這種結構的feature map覆旱,其使用了如下所示的Octave卷積操作:

octave_conv.png

首先考慮低頻部分輸入X^L蘸朋,該部分進行兩個部分的操作:

  • X^L \to X^H:從低頻到高頻,首先使用指定卷積核W^{L \to H}進行卷積扣唱,隨后進行Upample操作生成與高頻部分長寬相同的Tensor藕坯,最終產(chǎn)生Y^{L\to H} = Upsample(Conv(X^L,W^{L \to H}),2)
  • X^L \to X^L:從低頻到低頻团南,這個部分為直接進行卷積操作Y^{L \to L} = Conv(X^L,W^{L \to L})

隨后考慮高頻部分,與低頻部分類似有兩個部分的操作:

  • X^H \to X^H:從高頻到高頻堕担,直接進行卷積操作Y^{H \to H} = Conv(X^H,W^{H \to H})
  • X^H \to X^L:從高頻到低頻已慢,首先進行stride和kernel均為2的平均值池化,再進行卷積操作霹购,生成與Y^L通道數(shù)相同的feature map佑惠,最終產(chǎn)生Y^{H \to L} = conv(avgpool(X^H,2),W^{H \to L}))

最終,有Y^L = Y^{H \to L} + Y^{L \to L}Y^H = Y^{H \to H} +Y^{L \to H}齐疙,因此可以總結如下公式:
Y^L = Y^{H \to L} + Y^{L \to L} = Y^{H \to L} = conv(avgpool(X^H,2),W^{H \to L})) + Conv(X^L,W^{L \to L}) \\ Y^H = Y^{H \to H} +Y^{L \to H} = Conv(X^H,W^{H \to H}) + Upsample(Conv(X^L,W^{L \to H}),2)
因此有四個部分的權值:

來源/去向 \to H \to L
H W^{H \to H} W^{H \to L}
L W^{L \to H} W^{L \to L}

另外進行使用時膜楷,在網(wǎng)絡的輸入和輸出需要將兩個頻率上的Tensor聚合,做法如下:

  • 輸入部分贞奋,取X = [X,0]赌厅,即有X^H = XX^L = 0轿塔,僅進行H \to LH \to H操作特愿,輸出輸出的低頻僅有X生成,即Y^H = Y^{H \to H}Y^L = Y^{H \to L}
  • 輸出部分勾缭,取X = [X^H,X^L]揍障,\alpha = 0。即僅進行L \to HH \to H的操作俩由,最終輸出為Y = Y^{L \to H} + Y^{H \to H}

性能分析

以下計算均取原Tensor尺寸為CI \times W \times H毒嫡,卷積尺寸為CO \times CI \times K \times K,輸出Tensor尺寸為CO \times W \times H(stride=1幻梯,padding設置使feature map尺寸不變)兜畸。

計算量分析

Octave卷積的最大優(yōu)勢在于減小計算量,取參數(shù)\alpha為低頻通道占總通道的比例碘梢。首先考慮直接卷積的計算量咬摇,對于輸出feature map中的每個數(shù)據(jù),需要進行CI \times K \times K次乘加計算痘系,因此總的計算量為:
C_{conv} = (CO \times W \times H) \times (CI \times K \times K)
現(xiàn)考慮Octave卷積菲嘴,有四個卷積操作:

  • L \to L卷積:C_{L \to L} = \alpha^2 \times (CO \times \frac{W}{2} \times \frac{H}{2}) \times (CI \times K \times K) = \frac{\alpha^2}{4} \times C_{conv}
  • L \to H卷積:C_{L \to H} = ((1 - \alpha) \times CO \times \frac{W}{2} \times \frac{H}{2}) \times ( \alpha \times CI \times K \times K) = \frac{\alpha(1-\alpha)}{4} \times C_{conv}
  • H \to L卷積:C_{H \to L} = (\alpha \times CO \times \frac{W}{2} \times \frac{H}{2}) \times ((1 - \alpha) \times CI \times K \times K) = \frac{\alpha(1-\alpha)}{4} \times C_{conv}
  • H \to H卷積:C_{H \to H} = ((1 - \alpha) \times CO \times W \times H) \times ((1 - \alpha) \times CI \times K \times K) = (1 - \alpha)^2 \times C_{conv}

總上,可以得出計算量有:
\frac{C_{octave}}{C_{conv}} = \frac{\alpha^2 + 2\alpha(1-\alpha) + 4 (1 - \alpha)^2}{4} = 1 - \frac{3}{4}\alpha(2- \alpha)
\alpha \in [0,1]中單調遞減汰翠,當取\alpha = 1時,有\frac{C_{octave}}{C_{conv}} = \frac{1}{4}昭雌。

參數(shù)量分析

原卷積的參數(shù)量為:
W_{conv} = CO \times CI \times K \times K
Octave卷積將該部分分為四個复唤,對于每個卷積有:

  • L \to L卷積:W_{L \to L} =(\alpha\times CO) \times (\alpha \times CI) \times K \times K = \alpha^2 \times W_{conv}
  • L \to H卷積:W_{L \to H} =((1-\alpha) \times CO) \times (\alpha \times CI) \times K \times K = \alpha(1 - \alpha) \times W_{conv}
  • H \to L卷積:W_{H \to L} =(\alpha \times CO) \times ((1-\alpha) \times CI) \times K \times K = \alpha(1 - \alpha) \times W_{conv}
  • H \to H卷積:W_{H \to L} =((1-\alpha) \times CO) \times ((1-\alpha) \times CI) \times K \times K = (1 - \alpha)^2 \times W_{conv}

因此共有參數(shù)量:
C_{octave} = (\alpha^2 + 2\alpha(1 - \alpha) + (1 - \alpha)^2) \times C_{conv} = C_{conv}
由此,參數(shù)量沒有發(fā)生變化烛卧,該方法無法減少參數(shù)量佛纫。

Octave卷積實現(xiàn)

Octave卷積模塊

以下實現(xiàn)了一個兼容普通卷積的Octave卷積模塊妓局,針對不同的高頻低頻feature map的通道數(shù),分為以下幾種情況:

  • Lout_channel != 0 and Lin_channel != 0:通用Octave卷積呈宇,需要四個卷積參數(shù)
  • Lout_channel == 0 and Lin_channel != 0:輸出Octave卷積好爬,輸入有低頻部分,輸出無低頻部分甥啄,僅需要兩個卷積參數(shù)
  • Lout_channel != 0 and Lin_channel == 0:輸入Octave卷積存炮,輸入無低頻部分,輸出有低頻部分蜈漓,僅需要兩個卷積參數(shù)
  • Lout_channel == 0 and Lin_channel == 0:退化為普通卷積穆桂,輸入輸出均無低頻部分,僅有一個卷積參數(shù)
class OctaveConv(pt.nn.Module):

    def __init__(self,Lin_channel,Hin_channel,Lout_channel,Hout_channel,
            kernel,stride,padding):
        super(OctaveConv, self).__init__()
        if Lout_channel != 0 and Lin_channel != 0:
            self.convL2L = pt.nn.Conv2d(Lin_channel,Lout_channel, kernel,stride,padding)
            self.convH2L = pt.nn.Conv2d(Hin_channel,Lout_channel, kernel,stride,padding)
            self.convL2H = pt.nn.Conv2d(Lin_channel,Hout_channel, kernel,stride,padding)
            self.convH2H = pt.nn.Conv2d(Hin_channel,Hout_channel, kernel,stride,padding)
        elif Lout_channel == 0 and Lin_channel != 0:
            self.convL2L = None
            self.convH2L = None
            self.convL2H = pt.nn.Conv2d(Lin_channel,Hout_channel, kernel,stride,padding)
            self.convH2H = pt.nn.Conv2d(Hin_channel,Hout_channel, kernel,stride,padding)
        elif Lout_channel != 0 and Lin_channel == 0:
            self.convL2L = None
            self.convH2L = pt.nn.Conv2d(Hin_channel,Lout_channel, kernel,stride,padding)
            self.convL2H = None
            self.convH2H = pt.nn.Conv2d(Hin_channel,Hout_channel, kernel,stride,padding)
        else:
            self.convL2L = None
            self.convH2L = None
            self.convL2H = None
            self.convH2H = pt.nn.Conv2d(Hin_channel,Hout_channel, kernel,stride,padding)
        self.upsample = pt.nn.Upsample(scale_factor=2)
        self.pool = pt.nn.AvgPool2d(2)

    def forward(self,Lx,Hx):
        if self.convL2L is not None:
            L2Ly = self.convL2L(Lx)
        else:
            L2Ly = 0
        if self.convL2H is not None:
            L2Hy = self.upsample(self.convL2H(Lx))
        else:
            L2Hy = 0
        if self.convH2L is not None:
            H2Ly = self.convH2L(self.pool(Hx))
        else:
            H2Ly = 0
        if self.convH2H is not None:
            H2Hy = self.convH2H(Hx)
        else:
            H2Hy = 0
        return L2Ly+H2Ly,L2Hy+H2Hy

在前項傳播的過程中融虽,根據(jù)是否有對應的卷積操作參數(shù)判斷是否進行卷積享完,若不進行卷積,將輸出置為0有额。前向傳播時般又,輸入為低頻和高頻兩個feature map,輸出為低頻和高頻兩個feature map巍佑,輸入情況和參數(shù)配置應與通道數(shù)的配置匹配茴迁。

其他部分

使用MNIST數(shù)據(jù)集,構建了一個三層卷積+兩層全連接層的神經(jīng)網(wǎng)絡句狼,使用Adam優(yōu)化器訓練笋熬,代價函數(shù)使用交叉熵函數(shù),訓練3輪腻菇,最后在測試集上進行測試胳螟。

import torch as pt
import torchvision as ptv
# download dataset
train_dataset = ptv.datasets.MNIST("./",train=True,download=True,transform=ptv.transforms.ToTensor())
test_dataset = ptv.datasets.MNIST("./",train=False,download=True,transform=ptv.transforms.ToTensor())
train_loader = pt.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)
test_loader = pt.utils.data.DataLoader(test_dataset,batch_size=64,shuffle=True)

# build network
class mnist_model(pt.nn.Module):

    def __init__(self):
        super(mnist_model, self).__init__()
        self.conv1 = OctaveConv(0,1,8,8,kernel=3,stride=1,padding=1)        
        self.conv2 = OctaveConv(8,8,16,16,kernel=3,stride=1,padding=1)      
        self.conv3 = OctaveConv(16,16,0,64,kernel=3,stride=1,padding=1)
        self.pool =  pt.nn.MaxPool2d(2)
        self.relu = pt.nn.ReLU()
        self.fc1 = pt.nn.Linear(64*7*7,256)
        self.fc2 = pt.nn.Linear(256,10)

    def forward(self,x):
        out = [self.pool(self.relu(i)) for i in self.conv1(0,x)]
        out = self.conv2(*out)
        _,out = self.conv3(*out)
        out = self.fc1(self.pool(self.relu(out)).view(-1,64*7*7))
        return self.fc2(out)


net = mnist_model().cuda()
# print(net)
# prepare training
def acc(outputs,label):
    _,data = pt.max(outputs,dim=1)
    return pt.mean((data.float()==label.float()).float()).item()

lossfunc = pt.nn.CrossEntropyLoss().cuda()
optimizer = pt.optim.Adam(net.parameters())

# train
for _ in range(3):
    for i,(data,label) in enumerate(train_loader) :
        optimizer.zero_grad()
        # print(i,data,label)
        data,label = data.cuda(),label.cuda()
        outputs = net(data)
        loss = lossfunc(outputs,label)
        loss.backward()

        optimizer.step()
        if i % 100 == 0:
            print(i,loss.cpu().data.item(),acc(outputs,label))

# test
acc_list = []
for i,(data,label) in enumerate(test_loader):
    data,label = data.cuda(),label.cuda()
    outputs = net(data)
    acc_list.append(acc(outputs,label))
print("Test:",sum(acc_list)/len(acc_list))

# save
pt.save(net,"./model.pth")

最終獲得模型的準確率為0.988

?著作權歸作者所有,轉載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市筹吐,隨后出現(xiàn)的幾起案子糖耸,更是在濱河造成了極大的恐慌,老刑警劉巖丘薛,帶你破解...
    沈念sama閱讀 218,858評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件嘉竟,死亡現(xiàn)場離奇詭異,居然都是意外死亡洋侨,警方通過查閱死者的電腦和手機舍扰,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,372評論 3 395
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來希坚,“玉大人边苹,你說我怎么就攤上這事〔蒙” “怎么了个束?”我有些...
    開封第一講書人閱讀 165,282評論 0 356
  • 文/不壞的土叔 我叫張陵慕购,是天一觀的道長。 經(jīng)常有香客問我茬底,道長沪悲,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,842評論 1 295
  • 正文 為了忘掉前任阱表,我火速辦了婚禮殿如,結果婚禮上,老公的妹妹穿的比我還像新娘捶枢。我一直安慰自己握截,他們只是感情好,可當我...
    茶點故事閱讀 67,857評論 6 392
  • 文/花漫 我一把揭開白布烂叔。 她就那樣靜靜地躺著谨胞,像睡著了一般。 火紅的嫁衣襯著肌膚如雪蒜鸡。 梳的紋絲不亂的頭發(fā)上胯努,一...
    開封第一講書人閱讀 51,679評論 1 305
  • 那天,我揣著相機與錄音逢防,去河邊找鬼叶沛。 笑死,一個胖子當著我的面吹牛忘朝,可吹牛的內(nèi)容都是我干的灰署。 我是一名探鬼主播,決...
    沈念sama閱讀 40,406評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼局嘁,長吁一口氣:“原來是場噩夢啊……” “哼溉箕!你這毒婦竟也來了?” 一聲冷哼從身側響起悦昵,我...
    開封第一講書人閱讀 39,311評論 0 276
  • 序言:老撾萬榮一對情侶失蹤肴茄,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后但指,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體寡痰,經(jīng)...
    沈念sama閱讀 45,767評論 1 315
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,945評論 3 336
  • 正文 我和宋清朗相戀三年棋凳,在試婚紗的時候發(fā)現(xiàn)自己被綠了拦坠。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,090評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡剩岳,死狀恐怖贪婉,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情卢肃,我是刑警寧澤疲迂,帶...
    沈念sama閱讀 35,785評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站莫湘,受9級特大地震影響尤蒿,放射性物質發(fā)生泄漏。R本人自食惡果不足惜幅垮,卻給世界環(huán)境...
    茶點故事閱讀 41,420評論 3 331
  • 文/蒙蒙 一腰池、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧忙芒,春花似錦示弓、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,988評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至潮峦,卻和暖如春囱皿,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背忱嘹。 一陣腳步聲響...
    開封第一講書人閱讀 33,101評論 1 271
  • 我被黑心中介騙來泰國打工嘱腥, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人拘悦。 一個月前我還...
    沈念sama閱讀 48,298評論 3 372
  • 正文 我出身青樓齿兔,卻偏偏與公主長得像,于是被迫代替她去往敵國和親础米。 傳聞我的和親對象是個殘疾皇子分苇,可洞房花燭夜當晚...
    茶點故事閱讀 45,033評論 2 355

推薦閱讀更多精彩內(nèi)容