10、S2Attention模塊
論文《S2-MLPV2: IMPROVED SPATIAL-SHIFT MLP ARCHITECTURE FOR VISION》
1瘤载、作用
S2-MLPv2是一個(gè)改進(jìn)的空間位移多層感知器(MLP)視覺骨架網(wǎng)絡(luò)否灾,旨在通過利用通道維度的擴(kuò)展和分割以及采用分割注意力(split-attention)操作來增強(qiáng)圖像識(shí)別準(zhǔn)確性。與傳統(tǒng)的S2-MLP相比鸣奔,S2-MLPv2在不同的部分執(zhí)行不同的空間位移操作墨技,然后利用分割注意力操作來融合這些部分。此外挎狸,該方法采用了較小尺度的圖像塊和金字塔結(jié)構(gòu)扣汪,進(jìn)一步提升圖像識(shí)別精度询件。
2拿撩、機(jī)制
1、特征圖擴(kuò)展和分割:
首先沿著通道維度擴(kuò)展特征圖瞪慧,然后將擴(kuò)展后的特征圖分割成多個(gè)部分。
2茅主、空間位移操作:
對(duì)每個(gè)分割的部分執(zhí)行不同的空間位移操作舞痰,以增強(qiáng)特征表征。
3诀姚、分割注意力操作:
使用分割注意力操作融合經(jīng)過空間位移處理的各個(gè)部分响牛,生成融合后的特征圖。
4赫段、金字塔結(jié)構(gòu):
采用較小尺度的圖像塊和層次化的金字塔結(jié)構(gòu)呀打,以捕獲更精細(xì)的視覺細(xì)節(jié),提高模型的識(shí)別精度糯笙。
3贬丛、獨(dú)特優(yōu)勢(shì)
1、增強(qiáng)的特征表征能力:
通過對(duì)特征圖進(jìn)行擴(kuò)展炬丸、分割和不同方向的空間位移操作瘫寝,S2-MLPv2能夠捕獲更加豐富的特征信息蜒蕾,提升模型的表征能力稠炬。
2、分割注意力機(jī)制:
利用分割注意力操作有效地融合了不同空間位移處理的特征咪啡,進(jìn)一步增強(qiáng)了特征的表征力首启。
3、金字塔結(jié)構(gòu)的應(yīng)用:
通過采用較小尺度的圖像塊和層次化的金字塔結(jié)構(gòu)撤摸,S2-MLPv2模型能夠更好地捕捉圖像中的細(xì)粒度細(xì)節(jié)毅桃,從而在圖像識(shí)別任務(wù)上達(dá)到更高的準(zhǔn)確率。
4准夷、高效的性能:
即使在沒有自注意力機(jī)制和額外訓(xùn)練數(shù)據(jù)的情況下钥飞,S2-MLPv2也能在ImageNet-1K基準(zhǔn)上達(dá)到83.6%的頂級(jí)1準(zhǔn)確率,表現(xiàn)優(yōu)于其他MLP模型衫嵌,同時(shí)參數(shù)數(shù)量更少读宙,表明其在實(shí)際部署中具有競(jìng)爭(zhēng)力。
4楔绞、代碼
import numpy as np
import torch
from torch import nn
from torch.nn import init
def spatial_shift1(x):
# 實(shí)現(xiàn)第一種空間位移结闸,位移圖像的四分之一塊
b, w, h, c = x.size()
# 以下四行代碼分別向左、向右酒朵、向上桦锄、向下移動(dòng)圖像的四分之一塊
x[:, 1:, :, :c // 4] = x[:, :w - 1, :, :c // 4]
x[:, :w - 1, :, c // 4:c // 2] = x[:, 1:, :, c // 4:c // 2]
x[:, :, 1:, c // 2:c * 3 // 4] = x[:, :, :h - 1, c // 2:c * 3 // 4]
x[:, :, :h - 1, 3 * c // 4:] = x[:, :, 1:, 3 * c // 4:]
return x
def spatial_shift2(x):
# 實(shí)現(xiàn)第二種空間位移,邏輯與spatial_shift1相似蔫耽,但位移方向不同
b, w, h, c = x.size()
# 對(duì)圖像的四分之一塊進(jìn)行空間位移
x[:, :, 1:, :c // 4] = x[:, :, :h - 1, :c // 4]
x[:, :, :h - 1, c // 4:c // 2] = x[:, :, 1:, c // 4:c // 2]
x[:, 1:, :, c // 2:c * 3 // 4] = x[:, :w - 1, :, c // 2:c * 3 // 4]
x[:, :w - 1, :, 3 * c // 4:] = x[:, 1:, :, 3 * c // 4:]
return x
class SplitAttention(nn.Module):
# 定義分割注意力模塊结耀,使用MLP層進(jìn)行特征轉(zhuǎn)換和注意力權(quán)重計(jì)算
def __init__(self, channel=512, k=3):
super().__init__()
self.channel = channel
self.k = k # 分割的塊數(shù)
# 定義MLP層和激活函數(shù)
self.mlp1 = nn.Linear(channel, channel, bias=False)
self.gelu = nn.GELU()
self.mlp2 = nn.Linear(channel, channel * k, bias=False)
self.softmax = nn.Softmax(1)
def forward(self, x_all):
# 計(jì)算分割注意力,并應(yīng)用于輸入特征
b, k, h, w, c = x_all.shape
x_all = x_all.reshape(b, k, -1, c) # 重塑維度
a = torch.sum(torch.sum(x_all, 1), 1) # 聚合特征
hat_a = self.mlp2(self.gelu(self.mlp1(a))) # 通過MLP計(jì)算注意力權(quán)重
hat_a = hat_a.reshape(b, self.k, c) # 調(diào)整形狀
bar_a = self.softmax(hat_a) # 應(yīng)用softmax獲取注意力分布
attention = bar_a.unsqueeze(-2) # 增加維度
out = attention * x_all # 將注意力權(quán)重應(yīng)用于特征
out = torch.sum(out, 1).reshape(b, h, w, c) # 聚合并調(diào)整形狀
return out
class S2Attention(nn.Module):
# S2注意力模塊,整合空間位移和分割注意力
def __init__(self, channels=512):
super().__init__()
# 定義MLP層
self.mlp1 = nn.Linear(channels, channels * 3)
self.mlp2 = nn.Linear(channels, channels)
self.split_attention = SplitAttention()
def forward(self, x):
b, c, w, h = x.size()
x = x.permute(0, 2, 3, 1) # 調(diào)整維度順序
x = self.mlp1(x) # 通過MLP層擴(kuò)展特征
x1 = spatial_shift1(x[:, :, :, :c]) # 應(yīng)用第一種空間位移
x2 = spatial_shift2(x[:, :, :, c:c * 2]) # 應(yīng)用第二種空間位移
x3 = x[:, :, :, c * 2:] # 保留原始特征的一部分
x_all = torch.stack([x1, x2, x3], 1) # 堆疊特征
a = self.split_attention(x_all) # 應(yīng)用分割注意力
x = self.mlp2(a) # 通過另一個(gè)MLP層縮減特征維度
x = x.permute(0, 3, 1, 2) # 調(diào)整維度順序回原始
return x
# 示例代碼
if __name__ == '__main__':
input = torch.randn(50, 512, 7, 7) # 創(chuàng)建輸入張量
s2att = S2Attention(channels=512) # 實(shí)例化S2注意力模塊
output = s2att(input) # 通過S2注意力模塊處理輸入
print(output.shape) # 打印輸出張量的形狀