本文首發(fā)于個人博客
Octave卷積
Octave卷積的主題思想來自于圖片的分頻思想,首先認為圖像可進行分頻:
- 低頻部分:圖像低頻部分保存圖像的大體信息毙芜,信息數(shù)據(jù)量較少
- 高頻部分:圖像高頻部分保留圖像的細節(jié)信息忽媒,信息數(shù)據(jù)量較大
由此,認為卷積神經(jīng)網(wǎng)絡中的feature map也可以進行分頻腋粥,可按channel分為高頻部分和低頻部分晦雨,如圖所示:
對于一個feature map,將其按通道分為兩個部分隘冲,分別為低頻通道和高頻通道闹瞧。隨后將低頻通道的長寬各縮減一半,則將一個feature map分為了高頻和低頻兩個部分展辞,即為Octave卷積處理的基本feature map奥邮,使用X表示,該類型X可表示為罗珍,其中
為高頻部分洽腺,
為低頻部分。
為了處理這種結構的feature map覆旱,其使用了如下所示的Octave卷積操作:
首先考慮低頻部分輸入蘸朋,該部分進行兩個部分的操作:
-
:從低頻到高頻,首先使用指定卷積核
進行卷積扣唱,隨后進行Upample操作生成與高頻部分長寬相同的Tensor藕坯,最終產(chǎn)生
-
:從低頻到低頻团南,這個部分為直接進行卷積操作
隨后考慮高頻部分,與低頻部分類似有兩個部分的操作:
-
:從高頻到高頻堕担,直接進行卷積操作
-
:從高頻到低頻已慢,首先進行stride和kernel均為2的平均值池化,再進行卷積操作霹购,生成與
通道數(shù)相同的feature map佑惠,最終產(chǎn)生
最終,有和
齐疙,因此可以總結如下公式:
因此有四個部分的權值:
來源/去向 | ||
---|---|---|
H | ||
L |
另外進行使用時膜楷,在網(wǎng)絡的輸入和輸出需要將兩個頻率上的Tensor聚合,做法如下:
- 輸入部分贞奋,取
赌厅,即有
,
轿塔,僅進行
和
操作特愿,輸出輸出的低頻僅有X生成,即
和
- 輸出部分勾缭,取
揍障,
。即僅進行
和
的操作俩由,最終輸出為
性能分析
以下計算均取原Tensor尺寸為毒嫡,卷積尺寸為
,輸出Tensor尺寸為
(stride=1幻梯,padding設置使feature map尺寸不變)兜畸。
計算量分析
Octave卷積的最大優(yōu)勢在于減小計算量,取參數(shù)為低頻通道占總通道的比例碘梢。首先考慮直接卷積的計算量咬摇,對于輸出feature map中的每個數(shù)據(jù),需要進行
次乘加計算痘系,因此總的計算量為:
現(xiàn)考慮Octave卷積菲嘴,有四個卷積操作:
-
卷積:
-
卷積:
-
卷積:
-
卷積:
總上,可以得出計算量有:
在中單調遞減汰翠,當取
時,有
昭雌。
參數(shù)量分析
原卷積的參數(shù)量為:
Octave卷積將該部分分為四個复唤,對于每個卷積有:
-
卷積:
-
卷積:
-
卷積:
-
卷積:
因此共有參數(shù)量:
由此,參數(shù)量沒有發(fā)生變化烛卧,該方法無法減少參數(shù)量佛纫。
Octave卷積實現(xiàn)
Octave卷積模塊
以下實現(xiàn)了一個兼容普通卷積的Octave卷積模塊妓局,針對不同的高頻低頻feature map的通道數(shù),分為以下幾種情況:
-
Lout_channel != 0 and Lin_channel != 0
:通用Octave卷積呈宇,需要四個卷積參數(shù) -
Lout_channel == 0 and Lin_channel != 0
:輸出Octave卷積好爬,輸入有低頻部分,輸出無低頻部分甥啄,僅需要兩個卷積參數(shù) -
Lout_channel != 0 and Lin_channel == 0
:輸入Octave卷積存炮,輸入無低頻部分,輸出有低頻部分蜈漓,僅需要兩個卷積參數(shù) -
Lout_channel == 0 and Lin_channel == 0
:退化為普通卷積穆桂,輸入輸出均無低頻部分,僅有一個卷積參數(shù)
class OctaveConv(pt.nn.Module):
def __init__(self,Lin_channel,Hin_channel,Lout_channel,Hout_channel,
kernel,stride,padding):
super(OctaveConv, self).__init__()
if Lout_channel != 0 and Lin_channel != 0:
self.convL2L = pt.nn.Conv2d(Lin_channel,Lout_channel, kernel,stride,padding)
self.convH2L = pt.nn.Conv2d(Hin_channel,Lout_channel, kernel,stride,padding)
self.convL2H = pt.nn.Conv2d(Lin_channel,Hout_channel, kernel,stride,padding)
self.convH2H = pt.nn.Conv2d(Hin_channel,Hout_channel, kernel,stride,padding)
elif Lout_channel == 0 and Lin_channel != 0:
self.convL2L = None
self.convH2L = None
self.convL2H = pt.nn.Conv2d(Lin_channel,Hout_channel, kernel,stride,padding)
self.convH2H = pt.nn.Conv2d(Hin_channel,Hout_channel, kernel,stride,padding)
elif Lout_channel != 0 and Lin_channel == 0:
self.convL2L = None
self.convH2L = pt.nn.Conv2d(Hin_channel,Lout_channel, kernel,stride,padding)
self.convL2H = None
self.convH2H = pt.nn.Conv2d(Hin_channel,Hout_channel, kernel,stride,padding)
else:
self.convL2L = None
self.convH2L = None
self.convL2H = None
self.convH2H = pt.nn.Conv2d(Hin_channel,Hout_channel, kernel,stride,padding)
self.upsample = pt.nn.Upsample(scale_factor=2)
self.pool = pt.nn.AvgPool2d(2)
def forward(self,Lx,Hx):
if self.convL2L is not None:
L2Ly = self.convL2L(Lx)
else:
L2Ly = 0
if self.convL2H is not None:
L2Hy = self.upsample(self.convL2H(Lx))
else:
L2Hy = 0
if self.convH2L is not None:
H2Ly = self.convH2L(self.pool(Hx))
else:
H2Ly = 0
if self.convH2H is not None:
H2Hy = self.convH2H(Hx)
else:
H2Hy = 0
return L2Ly+H2Ly,L2Hy+H2Hy
在前項傳播的過程中融虽,根據(jù)是否有對應的卷積操作參數(shù)判斷是否進行卷積享完,若不進行卷積,將輸出置為0有额。前向傳播時般又,輸入為低頻和高頻兩個feature map,輸出為低頻和高頻兩個feature map巍佑,輸入情況和參數(shù)配置應與通道數(shù)的配置匹配茴迁。
其他部分
使用MNIST數(shù)據(jù)集,構建了一個三層卷積+兩層全連接層的神經(jīng)網(wǎng)絡句狼,使用Adam優(yōu)化器訓練笋熬,代價函數(shù)使用交叉熵函數(shù),訓練3輪腻菇,最后在測試集上進行測試胳螟。
import torch as pt
import torchvision as ptv
# download dataset
train_dataset = ptv.datasets.MNIST("./",train=True,download=True,transform=ptv.transforms.ToTensor())
test_dataset = ptv.datasets.MNIST("./",train=False,download=True,transform=ptv.transforms.ToTensor())
train_loader = pt.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)
test_loader = pt.utils.data.DataLoader(test_dataset,batch_size=64,shuffle=True)
# build network
class mnist_model(pt.nn.Module):
def __init__(self):
super(mnist_model, self).__init__()
self.conv1 = OctaveConv(0,1,8,8,kernel=3,stride=1,padding=1)
self.conv2 = OctaveConv(8,8,16,16,kernel=3,stride=1,padding=1)
self.conv3 = OctaveConv(16,16,0,64,kernel=3,stride=1,padding=1)
self.pool = pt.nn.MaxPool2d(2)
self.relu = pt.nn.ReLU()
self.fc1 = pt.nn.Linear(64*7*7,256)
self.fc2 = pt.nn.Linear(256,10)
def forward(self,x):
out = [self.pool(self.relu(i)) for i in self.conv1(0,x)]
out = self.conv2(*out)
_,out = self.conv3(*out)
out = self.fc1(self.pool(self.relu(out)).view(-1,64*7*7))
return self.fc2(out)
net = mnist_model().cuda()
# print(net)
# prepare training
def acc(outputs,label):
_,data = pt.max(outputs,dim=1)
return pt.mean((data.float()==label.float()).float()).item()
lossfunc = pt.nn.CrossEntropyLoss().cuda()
optimizer = pt.optim.Adam(net.parameters())
# train
for _ in range(3):
for i,(data,label) in enumerate(train_loader) :
optimizer.zero_grad()
# print(i,data,label)
data,label = data.cuda(),label.cuda()
outputs = net(data)
loss = lossfunc(outputs,label)
loss.backward()
optimizer.step()
if i % 100 == 0:
print(i,loss.cpu().data.item(),acc(outputs,label))
# test
acc_list = []
for i,(data,label) in enumerate(test_loader):
data,label = data.cuda(),label.cuda()
outputs = net(data)
acc_list.append(acc(outputs,label))
print("Test:",sum(acc_list)/len(acc_list))
# save
pt.save(net,"./model.pth")
最終獲得模型的準確率為0.988