Introduction
Paper link
This paper came to me on Monday as a link from the mentor who guides me (he is now doing a PhD in Switzerland and has reportedly published at NIPS). The blog post was planned for Friday, but porting the MXNet code onto a PyTorch ResNet took some time. I have never met this mentor in person; as a mere undergraduate, I simply read the papers and demo code he sends every Monday and adapt or port them to my current industrial image-classification work. If you would like to study together, feel free to add my QQ: 1678354579.
下面的內(nèi)容由于時間有限,主要以代碼實現(xiàn)為主椭微。才疏學(xué)淺洞坑,如果那些錯誤還請大佬多多指正!
Abstract
In natural images, information is conveyed at different frequencies: high-frequency components usually encode rich details, while low-frequency components usually encode the global structure. Likewise, the output feature maps of a convolutional layer can be seen as a mixture of information at different frequencies. In this work, we propose to factorize the mixed feature maps by frequency, and design a novel Octave Convolution (OctConv) operation that stores and processes the feature maps that vary "slower" at a lower spatial resolution, reducing both memory and computation cost. Unlike existing multi-scale methods, OctConv is formulated as a single, generic, plug-and-play convolutional unit that can directly replace vanilla convolutions without any adjustment to the existing network. It is also complementary and orthogonal to methods that propose better topologies or reduce channel redundancy. In our experiments, simply replacing vanilla convolutions with OctConv consistently improves accuracy on both image and video recognition tasks while reducing memory and computation cost. A ResNet-152 equipped with OctConv achieves 82.9% top-1 classification accuracy on ImageNet at a cost of merely 22.2 GFLOPs.
- To summarize: in natural images, high-frequency information carries the fine, rich details, while low-frequency information carries the overall contours and layout. OctConv's biggest advantage is that it saves memory and compute, and gaining this capability only requires changing the convolution layers of the network, so it is genuinely plug-and-play. My coding skills are average; it took me about a day to rewrite an OctConv version of ResNet, which I later extended into an enhanced version that supports all three convolution variants.
- One more remark: I suspect people with an art background grasp low versus high frequency more intuitively. When drawing a portrait, for example, the rough outline is much the same for everyone and rarely changes, which is the low frequency; the specific details, every expression and gesture, differ from person to person, which is the high frequency. I'm just an engineering nerd, so forgive the crude analogy.
A brief look at the principle
關(guān)于詳細的原理簸呈,大家可以參考論文和一片中文博客榕订。我這里更深的理解也是來源這篇博客,推薦大家去看看蜕便。
這里我主要從個人代碼理解和實現(xiàn)的角度來聊一聊原理劫恒,說白了就是數(shù)學(xué)公式看的有點蒙逼。代碼和公式相結(jié)合能夠理解更深入轿腺。
In a traditional image convolution, each kernel has shape [kernel_size, kernel_size, in_channels]; a series of multiply-accumulate operations produces one pixel of the output feature map. In a plain fully-connected (BP) network that would be the end of it, but a convolutional network slides the same kernel along by stride to produce the next pixel. Sliding this way across the width and height of the image yields one output channel; if we want out_channels feature maps, we simply use out_channels kernels, so the parameter tensor of a convolution has shape [kernel_size, kernel_size, in_channels, out_channels].
Later, to remove redundancy in the feature maps, people proposed group convolution and depthwise convolution, which correspond to today's ResNeXt and MobileNet. As for what "redundancy" means: a book I read explains that with too many output channels, the kernels are very likely to be similar, so the output feature maps become linearly correlated (put simply, one feature-map vector can be expressed as a linear combination of another). If any of this is unfamiliar, Google the keywords, or message me; questions are welcome!
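As a quick sanity check of these shapes (a minimal PyTorch sketch of my own, with made-up channel counts; note that PyTorch stores conv weights as [out_channels, in_channels/groups, kH, kW]):

import torch.nn as nn

vanilla = nn.Conv2d(16, 32, kernel_size=3)               # weight: (32, 16, 3, 3)
grouped = nn.Conv2d(16, 32, kernel_size=3, groups=4)     # weight: (32, 4, 3, 3)
depthwise = nn.Conv2d(16, 16, kernel_size=3, groups=16)  # weight: (16, 1, 3, 3)
for conv in (vanilla, grouped, depthwise):
    print(tuple(conv.weight.shape))

The shrinking second dimension is exactly where group and depthwise convolutions save parameters.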
That digression aside, now we get to the point! OctConv observes that, along the resolution dimension, the low-frequency information in a traditional convolution is also redundant. By splitting the feature maps into low-frequency information (low resolution) and high-frequency information (high resolution), it saves both memory and compute. A rough estimate: if half of each layer's feature maps carry low-frequency information and their resolution is reduced to 1/2 of the original, then the storage and convolution cost of that half drop to about 1/4 of what they were.
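A back-of-the-envelope version of that estimate (my own arithmetic, not a number from the paper): with a low-frequency ratio alpha, the low branch lives at half the height and width, so it costs (1/2)^2 = 1/4 per channel:

alpha = 0.5  # fraction of channels stored as low frequency
relative_storage = (1 - alpha) * 1.0 + alpha * 0.25
print(relative_storage)  # 0.625: the feature maps need ~62.5% of the original memory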
Downsampling. We just reduced redundancy by lowering the resolution of the low-frequency information, so the next question is how to lower the resolution. Convolutional networks have two ways to downsample: pooling, and convolution with stride 2. The paper's experiments indicate that pooling is the more effective choice.
When I embedded OctConv into ResNet, I found that downsampling with stride-2 convolutions does not reduce the number of trainable parameters at all, while downsampling with pooling reduced the parameter count by tens of times (I did not save the exact numbers at the time; the reduction may be even larger). Pooling is easy to understand, since it has no trainable parameters to begin with; a stride-2 convolutional downsampling, by contrast, essentially decomposes the original kernel into four parts (middle convolutions) or two parts (the first and last convolutions), so its trainable parameter count does not decrease. A small comparison sketch follows this paragraph.
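To make the parameter gap concrete, here is a small comparison of my own (the channel sizes 64 -> 128 are made up): a stride-2 3x3 convolution versus average pooling followed by a 1x1 convolution, which is the design version 1 below uses:

import torch.nn as nn

conv_stride2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
pool_then_1x1 = nn.Sequential(nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
                              nn.Conv2d(64, 128, kernel_size=1))
for m in (conv_stride2, pool_then_1x1):
    print(sum(p.numel() for p in m.parameters()))
# 73856 vs. 8320 trainable parameters -- roughly a 9x reduction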
OctConv roadmap
- First convolution layer: the input image is treated as all high-frequency, so alpha_in = 0 and alpha_out = α.
- Middle convolution layers: the feature maps carry both low- and high-frequency information; typically alpha_in = alpha_out = α.
- Last convolution layer: restores a normal single feature map, so alpha_in = α and alpha_out = 0.
這里的參數(shù)設(shè)置一般為0.5距糖,0.2玄窝。具體的參數(shù)設(shè)置會根據(jù)圖像的特征豐富程度調(diào)整。
Quick summary: the feature maps split into two paths (low-frequency and high-frequency) after the first layer; the middle layers keep both paths, with information exchanged between them; and the final layer merges everything back into a single output path. The four information paths are sketched below, before the full implementations.
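Here is a minimal runnable sketch of the four paths (my own illustration; the channel counts are made up, conv_hh/conv_hl/conv_ll/conv_lh are hypothetical names, pooling is 2x2 average pooling, and upsampling uses F.interpolate's default nearest mode):

import torch
import torch.nn as nn
import torch.nn.functional as F

h = torch.randn(1, 4, 32, 32)   # high-frequency input
l = torch.randn(1, 4, 16, 16)   # low-frequency input at half resolution
conv_hh = nn.Conv2d(4, 4, 3, padding=1)  # H->H
conv_hl = nn.Conv2d(4, 4, 3, padding=1)  # H->L
conv_ll = nn.Conv2d(4, 4, 3, padding=1)  # L->L
conv_lh = nn.Conv2d(4, 4, 3, padding=1)  # L->H

out_h = conv_hh(h) + F.interpolate(conv_lh(l), scale_factor=2)          # high path
out_l = conv_ll(l) + conv_hl(F.avg_pool2d(h, kernel_size=2, stride=2))  # low path
print(out_h.size(), out_l.size())  # (1, 4, 32, 32) and (1, 4, 16, 16)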
Implementation code
Version 1: pooling
# -*- coding: utf-8 -*-
# @Time : 2019/4/22 13:29
# @Author : ljf
import torch
import torch.nn.functional as F
from torch import nn
class OctConv2d_v1(nn.Conv2d):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
alpha_in=0.5,
alpha_out=0.5
):
"""adapt first octconv , octconv and last octconv
"""
assert alpha_in >= 0 and alpha_in <= 1, "the value of alpha_in should be in range of [0,1],but get {}".format(
alpha_in)
assert alpha_out >= 0 and alpha_out <= 1, "the value of alpha_in should be in range of [0,1],but get {}".format(
alpha_out)
        # spatial aggregation is handled by average pooling below, so the
        # learnable kernel here is 1x1 -- this is what removes parameters
        super(OctConv2d_v1, self).__init__(in_channels,
                                           out_channels,
                                           kernel_size=1,
                                           stride=1,
                                           padding=0,
                                           dilation=dilation,
                                           groups=groups,
                                           bias=bias)
        self.alpha_in = alpha_in
        self.alpha_out = alpha_out
        # avgPool mimics the receptive field and stride of the original conv
        self.avgPool = nn.AvgPool2d(kernel_size, stride, padding)
        # avgpool is the frequency downsampler used on the H->L path
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.inChannelSplitIndex = int(
            self.alpha_in * self.in_channels)
        self.outChannelSplitIndex = int(
            self.alpha_out * self.out_channels)
        # split the bias: channels [:split] are low frequency, [split:] high.
        # each output receives its bias exactly once, through the H->H / H->L
        # paths; the L->* paths pass bias=None so it is not added twice
        if bias:
            self.hh_bias = self.bias[self.outChannelSplitIndex:]
            self.hl_bias = self.bias[:self.outChannelSplitIndex]
        else:
            self.hh_bias = None
            self.hl_bias = None
        # L->H outputs are brought back to high resolution by interpolation
        self.upsample = F.interpolate
def forward(self, x):
if not isinstance(x, tuple):
# first octconv
input_h = x if self.alpha_in == 0 else None
input_l = x if self.alpha_in == 1 else None
else:
input_l = x[0]
input_h = x[1]
        # NOTE: the bias scheme assumes a high-frequency input branch exists
        # (alpha_in < 1), which holds for every configuration used in this post
        output = [0, 0]  # [low-frequency output, high-frequency output]
        # H->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != self.in_channels:
            output_hh = F.conv2d(self.avgPool(input_h),
                                 self.weight[
                                     self.outChannelSplitIndex:,
                                     self.inChannelSplitIndex:,
                                     :, :],
                                 self.hh_bias)
            output[1] += output_hh
        # H->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != self.in_channels:
            output_hl = F.conv2d(self.avgpool(self.avgPool(input_h)),
                                 self.weight[
                                     :self.outChannelSplitIndex,
                                     self.inChannelSplitIndex:,
                                     :, :],
                                 self.hl_bias)
            output[0] += output_hl
        # L->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != 0:
            output_ll = F.conv2d(self.avgPool(input_l),
                                 self.weight[
                                     :self.outChannelSplitIndex,
                                     :self.inChannelSplitIndex,
                                     :, :],
                                 None)  # bias already added by the H->L path
            output[0] += output_ll
        # L->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != 0:
            output_lh = F.conv2d(self.avgPool(input_l),
                                 self.weight[
                                     self.outChannelSplitIndex:,
                                     :self.inChannelSplitIndex,
                                     :, :],
                                 None)  # bias already added by the H->H path
            output_lh = self.upsample(output_lh, scale_factor=2)
            output[1] += output_lh
if isinstance(output[0], int):
out = output[1]
else:
out = tuple(output)
return out
if __name__ == "__main__":
    input = torch.randn(1, 3, 32, 32)
    octconv1 = OctConv2d_v1(
        in_channels=3,
        out_channels=6,
        kernel_size=3,
        padding=1,
        stride=2,
        alpha_in=0,
        alpha_out=0.3)
    octconv2 = OctConv2d_v1(
        in_channels=6,
        out_channels=16,
        kernel_size=2,
        padding=0,
        stride=2,
        alpha_in=0.3,
        alpha_out=0.5)
    lastconv = OctConv2d_v1(
        in_channels=16,
        out_channels=32,
        kernel_size=2,
        padding=0,
        stride=2,
        alpha_in=0.5,
        alpha_out=0)
    out = octconv1(input)
    print(len(out))
    print(out[0].size())
    print(out[1].size())
    out = octconv2(out)
    print(len(out))
    print(out[0].size())
    print(out[1].size())
    out = lastconv(out)
    print(type(out))   # a plain tensor again once alpha_out = 0
    print(out.size())
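Design note on version 1: the learnable kernel here is effectively 1x1; the k x k spatial extent and the stride are both absorbed into the avgPool layer. That is why, as noted earlier, this variant cuts the trainable parameter count so sharply when embedded in ResNet; the trade-off is that it is not numerically equivalent to a k x k convolution, so the two versions are not interchangeable.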
Version 2: stride-2 convolution
# -*- coding: utf-8 -*-
# @Time : 2019/4/22 10:35
# @Author : ljf
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class OctConv2d_v2(nn.Conv2d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
alpha_in=0.5,
alpha_out=0.5,):
assert alpha_in >= 0 and alpha_in <= 1
assert alpha_out >= 0 and alpha_out <= 1
super(OctConv2d_v2, self).__init__(in_channels, out_channels,
kernel_size, stride, padding,
dilation, groups, bias)
self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
self.alpha_in = alpha_in
self.alpha_out = alpha_out
self.inChannelSplitIndex = math.floor(
self.alpha_in * self.in_channels)
self.outChannelSplitIndex = math.floor(
self.alpha_out * self.out_channels)
        # split the bias: channels [:split] are low frequency, [split:] high.
        # as in version 1, each output receives its bias exactly once, via the
        # H->H / H->L paths; the L->* paths below pass bias=None
        if bias:
            self.hh_bias = self.bias[self.outChannelSplitIndex:]
            self.hl_bias = self.bias[:self.outChannelSplitIndex]
        else:
            self.hh_bias = None
            self.hl_bias = None
def forward(self, input):
if not isinstance(input, tuple):
assert self.alpha_in == 0 or self.alpha_in == 1
inputLow = input if self.alpha_in == 1 else None
inputHigh = input if self.alpha_in == 0 else None
else:
inputLow = input[0]
inputHigh = input[1]
        output = [0, 0]  # [low-frequency output, high-frequency output]
# H->H
if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != self.in_channels:
outputH2H = F.conv2d(
inputHigh,
self.weight[
self.outChannelSplitIndex:,
self.inChannelSplitIndex:,
:,
:],
self.hh_bias,
self.stride,
self.padding,
self.dilation,
self.groups)
output[1] += outputH2H
# H->L
if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != self.in_channels:
outputH2L = F.conv2d(
self.avgpool(inputHigh),
self.weight[
:self.outChannelSplitIndex,
self.inChannelSplitIndex:,
:,
:],
self.hl_bias,
self.stride,
self.padding,
self.dilation,
self.groups)
output[0] += outputH2L
# L->L
if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != 0:
outputL2L = F.conv2d(
inputLow,
self.weight[
:self.outChannelSplitIndex,
:self.inChannelSplitIndex,
:,
:],
                None,  # bias already added by the H->L path
self.stride,
self.padding,
self.dilation,
self.groups)
output[0] += outputL2L
# L->H
if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != 0:
outputL2H = F.conv2d(
F.interpolate(inputLow, scale_factor=2),
self.weight[
self.outChannelSplitIndex:,
:self.inChannelSplitIndex,
:,
:],
                None,  # bias already added by the H->H path
self.stride,
self.padding,
self.dilation,
self.groups)
output[1] += outputL2H
if isinstance(output[0],int):
out = output[1]
else:
out = tuple(output)
return out
if __name__ == "__main__":
    input = torch.randn(1, 3, 32, 32)
    octconv1 = OctConv2d_v2(in_channels=3,
                            out_channels=6,
                            kernel_size=3,
                            stride=2,
                            padding=1,
                            dilation=1,
                            groups=1,
                            bias=True,
                            alpha_in=0.,
                            alpha_out=0.25)
    octconv2 = OctConv2d_v2(in_channels=6,
                            out_channels=16,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            dilation=1,
                            groups=1,
                            bias=True,
                            alpha_in=0.25,
                            alpha_out=0.5)
    out = octconv1(input)
    print(len(out))
    print(out[0].size())
    print(out[1].size())
    out = octconv2(out)
    print(len(out))
    print(out[0].size())
    print(out[1].size())
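Design note on version 2: unlike version 1, this version keeps the full k x k learnable kernel along with the original stride, dilation, and groups from nn.Conv2d, so its parameter count is identical to a vanilla convolution; only the low-frequency branch runs at half resolution. This matches the observation above that stride-2 downsampling does not reduce parameters, and it is the variant I used as a drop-in replacement when porting ResNet.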
GitHub link
My abilities are limited, so please bear with me, and feel free to point out any mistakes.
Reference: https://mp.weixin.qq.com/s?__biz=MzUyMjE2MTE0Mw==&mid=2247487810&idx=1&sn=1428510ec154a24a9e779d82f693930d&chksm=f9d14fdacea6c6cc42a630e57726c1789a54dc8e31616bd747fb2c35f41dbbd86f2c2a0b8998&mpshare=1&scene=23&srcid=#rd