目前transformer從語言到視覺任務(wù)的挑戰(zhàn)主要是由于這兩個領(lǐng)域間的差異:
- 1场勤、尺度變化大
- 2、高分辨率的輸入
為了解決以上兩點萨螺,我們提出了層級Transformer决乎,通過滑動窗口提取特征的方式將使得self.attention
的計算量降低為和圖像尺寸的線性相關(guān)。
簡介
我們觀察到將語言領(lǐng)域遷移到視覺領(lǐng)域的主要問題可以被總結(jié)為兩種:
- 1佃迄、不同于word token泼差,它的尺度是固定的,但是視覺領(lǐng)域的尺度變化非常劇烈
- 2呵俏、相對于上下文中的words堆缘,圖片有著更高分辨率的像素,計算量會隨著圖片的尺寸成平方倍的增長柴信。
結(jié)構(gòu)
以上是論文中結(jié)構(gòu)圖套啤,每一個
stage feature map
的尺寸都會減半。易知主要分為四個模塊:
- Patch Partition
- Linear Embedding
-
Swin Transformer Block(主要模塊):
- W-MSA:
regular window partition
和mutil-head self attention
- SW-MSA:
shift window partition
和mutil-head self attention
- W-MSA:
- Patch Merging
1随常、Patch Partition 和 Linear Embedding
在源碼實現(xiàn)中兩個模塊合二為一潜沦,稱為PatchEmbedding
。輸入圖片尺寸為 的RGB圖片绪氛,將4x4x3
視為一個patch唆鸡,用一個linear embedding 層將patch轉(zhuǎn)換為任意dimension(通道)的feature。源碼中使用4x4的stride=4的conv實現(xiàn)枣察。->
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
2歌殃、Swin Transformer Block
這是這篇論文的核心模塊。
- 如何解決計算量隨著輸入尺寸的增大成平方倍的增長致盟? 拋棄傳統(tǒng)的transformer基于全局來計算注意力的方法岩梳,將輸入劃分為不同的窗口,分別對每個窗口(window)施加注意力猿涨。
- 僅僅對窗口(window)單獨施加注意力握童,如何解決窗口(window)之間的信息流動?交替使用
W-MSA
和SW-MSA
模塊叛赚,因此SwinTransformerBlock
必須是偶數(shù)澡绩。如下圖所示:
整體流程如下:- 先對特征圖進(jìn)行LayerNorm
- 通過self.shift_size決定是否需要對特征圖進(jìn)行shift
- 然后將特征圖切成一個個窗口
- 計算Attention,通過self.attn_mask來區(qū)分Window Attention還是Shift Window Attention
- 將各個窗口合并回來
- 如果之前有做shift操作俺附,此時進(jìn)行reverse shift肥卡,把之前的shift操作恢復(fù)
- 做dropout和殘差連接
- 再通過一層LayerNorm+全連接層,以及dropout和殘差連接
2.1事镣、window partition
window partition
分為regular window partition
和shift window partition
步鉴,對應(yīng)于W-MSA
和SW-MSA
。通過窗口劃分,將輸入的feature map
轉(zhuǎn)換為num_windows*B, window_size, window_size, C
氛琢,其中 num_windows = H*W / window_size / window_size
只嚣。然后resize 到 num_windows*B, window_size*window_size, C
進(jìn)行attention。源碼如下:
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
-
Layer1
是regular window partition
艺沼,窗口的大小是4x4册舞,將圖片分成了4個窗口。 -
Layer2
是shift window partition
障般,為了保證不同窗口的信息流動调鲸,起始點從(windows_size//2, windows_size//2)開始進(jìn)行劃分,將圖片分成了9個窗口挽荡∶晔可以看到移位后的窗口包含了原本相鄰窗口的元素。但是同時也引入了新的問題定拟,窗口大小不一致的問題于微,有2x2、2x4青自、4x2株依、4x4,最簡單的方法就是統(tǒng)一padding到4x4延窜,但是窗口數(shù)量由4增加至9恋腕,計算量變大了2.25倍。因此作者提出了cycle shift
去解決這個問題逆瑞。
以下的示例圖片來自于:https://mp.weixin.qq.com/s/8x1pgRLWaMkFSjT7zjhTgQ
首先對窗口進(jìn)行shift window partition
荠藤,得到左圖部分。不進(jìn)行padding
获高,而是采用滾動的方式調(diào)整窗口哈肖,源碼中用torch.roll()
函數(shù)實現(xiàn),得到了右圖念秧。這時候得到了和regular window partition
一樣的4個2x2大小的window
淤井,不同的是,在一個2x2的windows
區(qū)域內(nèi)是不連續(xù)的(index
不一樣)出爹。
我們希望在計算Attention的時候庄吼,讓具有相同index
進(jìn)行計算缎除,而忽略不同index QK計算結(jié)果严就。因此我們?yōu)槠涮砑由蟤ask。源碼計算mask
實現(xiàn)如下:
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
2.2器罐、W-MSA
由regular window partition
模塊 和 mutil-head self attention
模塊組成梢为。
W-MSA相比于直接使用MSA主要是為了降低計算量。傳統(tǒng)的transformer都是基于全局來計算注意力,因此計算復(fù)雜度非常高铸董。但是swin transformer通過對每個窗口施加注意力祟印,從而減少了計算量。attention的主要計算過程如下:
假設(shè)每一個window
的區(qū)塊大小為粟害,輸入的尺寸為蕴忆,以下為原始的和的計算復(fù)雜度:
- 對于:對輸入的
feature map
做全局attention,悲幅、套鹅、的計算量分別是,和的計算量分別是汰具,的計算量是卓鹿。 - 對于:在
windows
內(nèi)的大小的區(qū)域內(nèi)做attention,feature map
會被劃分為個windows
留荔,每個windows
的尺寸為吟孙。、聚蝶、的計算量分別是杰妓,和的計算量的分別是,的計算量是碘勉。因此和輸入尺寸成線性關(guān)系稚失。
2.3、SW-MSA
雖然降低了計算量恰聘,但是由于將attention限制在window
內(nèi)句各,因此不重合的window
缺乏聯(lián)系,限制了模型的性能晴叨。因此提出了模塊凿宾。在MSA
前面加上一個cycle shift window partition
3、Patch Merging
swin transformer中沒有使用pooling
進(jìn)行下采樣兼蕊,而是使用了和yolov5中的focus
層進(jìn)行feature map
的下采樣初厚。 -> ,在使用一個全連接層->孙技,在一個stage中將feature map的高寬減半产禾,通道數(shù)翻倍。
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
不同尺寸的網(wǎng)絡(luò)結(jié)構(gòu)
基準(zhǔn)模型結(jié)構(gòu)命名為Swin-B
牵啦,模型大小和計算復(fù)雜度和ViT-B
/DeiT-B
相近亚情。同時我們也提出了Swin-T
,Swin-S
和 Swin-L
哈雏,分別對應(yīng)0.25×
, 0.5×
和 2×
倍的模型尺寸和計算復(fù)雜度楞件。Swin-T
和 Swin-S
的計算復(fù)雜度分別和ResNet-50
衫生、ResNet-101
相近。默認(rèn)設(shè)置為7土浸。代表第一層隱藏層的數(shù)量罪针。
- Swin-T: C = 96, layer numbers = {2, 2, 6, 2}
- Swin-S: C = 96, layer numbers ={2, 2, 18, 2}
- Swin-B: C = 128, layer numbers ={2, 2, 18, 2}
- Swin-L: C = 192, layer numbers ={2, 2, 18, 2}
不同數(shù)據(jù)集的實驗結(jié)果
1、ImageNet
2黄伊、COCO Object Detection
- 在不同的模型上使用swin transformer 作為特征提取網(wǎng)絡(luò)
- 在cascade mask rcnn上使用swin transformer 作為backbone
-
直接對比其他目標(biāo)檢測網(wǎng)絡(luò)