Squeeze-and-Excitation Networks (SENet) won the classification task of the ImageNet 2017 competition.
Paper: https://arxiv.org/abs/1709.01507
This post gives a brief introduction to the SENet paper and provides an MXNet implementation of SE-ResNet (mainly based on the Gluon API).
In SENet, Squeeze and Excitation are the two key operations, illustrated in the schematic below:
Step 1: Squeeze compresses the features along the spatial dimensions, i.e. global average pooling, producing one descriptor per channel.
Step 2: Excitation feeds the squeezed descriptor through a small gating network ending in a sigmoid to generate a weight for each feature channel; these weights capture the dependencies between channels.
Step 3: Reweight multiplies the weights produced by Excitation onto the feature map channel by channel, recalibrating the original features along the channel dimension (a short numerical sketch follows this list).
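To make the three steps concrete, here is a minimal NDArray sketch on a toy feature map. The sizes are arbitrary, and the two fully connected layers of Excitation are left out here so the arithmetic stays visible; the full module appears further down.
from mxnet import nd

X = nd.random.uniform(shape=(1, 4, 2, 2))          # toy feature map: batch=1, 4 channels, 2x2 spatial
s = nd.mean(X, axis=(2, 3))                         # Squeeze: global average pooling -> shape (1, 4)
w = nd.sigmoid(s)                                   # Excitation (gating only; the FC layers are omitted here)
Y = nd.broadcast_mul(X, w.reshape((1, 4, 1, 1)))    # Reweight: rescale each channel by its weight
print(Y.shape)                                      # (1, 4, 2, 2)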
The SE module can easily be embedded into almost any neural network. Below is the network structure diagram of SE-ResNet:
Straight to the code. First, the original Residual block, included as a reference:
from mxnet import nd
from mxnet.gluon import nn

class Residual(nn.HybridBlock):
    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
                               strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            # 1x1 convolution so the shortcut matches the main branch in channels and stride
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
                                   strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

    def hybrid_forward(self, F, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return F.relu(Y + X)
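A quick sanity check of the block (the 96×96 input size and the variable names are just illustrative, reusing the imports above):
blk = Residual(64, use_1x1conv=True, strides=2)
blk.initialize()
X = nd.random.uniform(shape=(4, 3, 96, 96))
print(blk(X).shape)   # (4, 64, 48, 48): channels become 64, spatial size halved by strides=2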
The key part comes next: the SE module. To make it easier to follow, Squeeze and Excitation are written as a separate, standalone module:
def Attention(num_channels):
    # SE module: Squeeze (global average pooling) followed by Excitation (two FC layers + sigmoid)
    net = nn.HybridSequential()
    with net.name_scope():
        net.add(
            nn.GlobalAvgPool2D(),        # Squeeze: (N, C, H, W) -> (N, C, 1, 1)
            nn.Dense(num_channels),      # Excitation, first FC layer
            nn.Activation('relu'),
            nn.Dense(num_channels),      # Excitation, second FC layer
            nn.Activation('sigmoid')     # one weight per channel, in (0, 1)
        )
    return net
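As a sanity check with toy sizes, the module maps an (N, C, H, W) feature map to an (N, C) weight vector. Note that the paper inserts a bottleneck here: the first Dense layer has C/r units with a reduction ratio r (16 in the paper), whereas the simplified version above keeps C units in both layers.
se = Attention(64)
se.initialize()
Y = nd.random.uniform(shape=(4, 64, 48, 48))
print(se(Y).shape)   # (4, 64): one weight per channel, each in (0, 1)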
Next, embed the SE module into the Residual block and apply the weights with a broadcast multiply:
class SEResidual(nn.HybridBlock):
    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(SEResidual, self).__init__(**kwargs)
        self.num_channels = num_channels   # needed to reshape the attention weights in hybrid_forward
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
                               strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
                                   strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()
        self.weight = Attention(num_channels)    # the SE module defined above

    def hybrid_forward(self, F, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        W = self.weight(Y)                       # channel attention weights, shape (N, C)
        if self.conv3:
            X = self.conv3(X)
        # Reweight: broadcast the (N, C) weights to (N, C, 1, 1) and rescale each channel
        Y = F.broadcast_mul(Y, F.reshape(W, shape=(-1, self.num_channels, 1, 1)))
        return F.relu(Y + X)
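The same shape check as before, now with the SE branch in place (toy sizes again):
blk = SEResidual(64, use_1x1conv=True, strides=2)
blk.initialize()
X = nd.random.uniform(shape=(4, 3, 96, 96))
print(blk(X).shape)   # (4, 64, 48, 48), same as the plain Residual block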
Finally, just stack SE-Residual blocks like building blocks to get the full SE-ResNet; a rough sketch of one possible layout follows.
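This is only a sketch, loosely following the ResNet-18 layout: the stage widths, block counts and the 10-class output head are my own illustrative choices, not the configuration from the paper.
def se_resnet_stage(num_channels, num_residuals, first_block=False):
    stage = nn.HybridSequential()
    for i in range(num_residuals):
        if i == 0 and not first_block:
            stage.add(SEResidual(num_channels, use_1x1conv=True, strides=2))  # downsample at stage entry
        else:
            stage.add(SEResidual(num_channels))
    return stage

net = nn.HybridSequential()
net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
        nn.BatchNorm(), nn.Activation('relu'),
        nn.MaxPool2D(pool_size=3, strides=2, padding=1),
        se_resnet_stage(64, 2, first_block=True),
        se_resnet_stage(128, 2),
        se_resnet_stage(256, 2),
        se_resnet_stage(512, 2),
        nn.GlobalAvgPool2D(),
        nn.Dense(10))
net.initialize()
print(net(nd.random.uniform(shape=(1, 3, 224, 224))).shape)   # (1, 10)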
That's all, enjoy~