Transformers in Vision (3): Swin Transformer Principles and Code Walkthrough

Before getting started, I want to recommend a video that explains the evolution of the Transformer mechanism and where it is heading in considerable detail; I found it quite inspiring and suggest watching it. First, the paper and the corresponding code:
paper:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
code: microsoft/Swin-Transformer
Swin Transformer can be seen as a new generation of feature-extraction backbone; it shows up on many leaderboards and, as shown below, supports a wide range of downstream tasks. Compare this with the earlier post Transformers in Vision (1): ViT (Transformers for Image Recognition at Scale), where every patch had to attend to every other patch across the whole image.

一竞滓、 原理

In a standard Transformer, the more pixels the image has, the longer the token sequence we have to build, which hurts efficiency. Swin therefore replaces the single long sequence with windows and a hierarchical design.

1.1 Overall Network Architecture

  • Build the sequence of patch features (note: a convolution first produces a feature map, which is then split into patches)
  • Compute attention stage by stage (a progressive downsampling process)
  • The Block is the core component; it changes how attention is computed

From the figure below we can see that the feature map keeps shrinking spatially while its number of channels keeps growing.


Overall Swin network architecture
1.1.1 Patch Embedding

For example, with an input image of (224, 224, 3) the output is (3136, 96), i.e. a sequence of length 3136 where each vector has 96 dimensions. The convolution used here is Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4)), so 3136 comes from (224 / 4) * (224 / 4).

The resulting feature map is (56, 56, 96). With the default window size of 7, it can be split into 8 * 8 windows, giving a tensor of shape (64(8*8), 7, 7, 96): the unit of computation changes from the whole sequence to individual windows (64 windows in total).
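This shape bookkeeping can be verified with a minimal, self-contained sketch (illustrative variable names, batch size 1 for clarity; this is not the library code):

import torch
import torch.nn as nn

x = torch.randn(1, 3, 224, 224)                    # a single input image
proj = nn.Conv2d(3, 96, kernel_size=4, stride=4)   # the patch-embedding convolution
feat = proj(x)                                     # (1, 96, 56, 56)
tokens = feat.flatten(2).transpose(1, 2)           # (1, 3136, 96): 3136 = 56 * 56 tokens, 96-dim each
windows = feat.permute(0, 2, 3, 1).contiguous()    # (1, 56, 56, 96)
windows = windows.view(1, 8, 7, 8, 7, 96).permute(0, 1, 3, 2, 4, 5).reshape(-1, 7, 7, 96)
print(tokens.shape, windows.shape)                 # (1, 3136, 96) and (64, 7, 7, 96)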

1.1.2 Swin Transformer Block

Now let's look at what the Transformer Blocks in the figure above actually are, as shown below.

Swin Transformer Block

The two units in the figure are connected in series to form one Block: the left one is W-MSA (multi-head self-attention with regular window partitioning), and the right one is SW-MSA (multi-head self-attention with shifted windows), which recomputes attention after the windows have been shifted.

1. W-MSA (compute self-attention within each individual window; in the figure below, the differently colored rectangles denote different windows)


For the windows obtained above, self-attention scores are computed within each window. Stacking the q, k, v matrices together gives a tensor of shape (3, 64, 3, 49, 32):

  • 3 matrices (q, k, v)
  • 64 windows
  • 3 heads
  • a 7*7 window (each window contains 49 tokens)
  • 96/3 = 32 dimensions per head

So the attention result has shape (64, 3, 49, 49): each head produces the self-attention within each window (the 3 is the number of heads; think of it as, for every window and every head, the attention between that window's tokens).
From this we get new features of shape (64, 49, 96), which are then reshaped back to a (56, 56, 96) feature map so that the output matches the input resolution (but with attention already applied), because in a Transformer the input and output of each layer generally have the same size. The shape bookkeeping is sketched below.
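A minimal, self-contained shape sketch of what was just described (illustrative tensors only, not the library implementation; the window_reverse-style reshape at the end restores the (56, 56, 96) layout):

import torch
import torch.nn as nn

num_windows, tokens, dim, heads = 64, 49, 96, 3
x = torch.randn(num_windows, tokens, dim)                           # 64 windows of 49 tokens each
qkv = nn.Linear(dim, dim * 3)(x)                                    # (64, 49, 288)
qkv = qkv.reshape(num_windows, tokens, 3, heads, dim // heads).permute(2, 0, 3, 1, 4)  # (3, 64, 3, 49, 32)
q, k, v = qkv[0], qkv[1], qkv[2]                                    # each (64, 3, 49, 32)
attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)                    # (64, 3, 49, 49)
out = (attn @ v).transpose(1, 2).reshape(num_windows, tokens, dim)  # (64, 49, 96)
feat = out.reshape(8, 8, 7, 7, dim).permute(0, 2, 1, 3, 4).reshape(56, 56, dim)  # back to (56, 56, 96)
print(attn.shape, out.shape, feat.shape)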
A side note on the motivation: compared with ViT, where every position on the feature map has to perform QKV computation with every other position in order to exchange information, the window mechanism splits the feature map into windows and only performs MSA inside each window. This greatly reduces the computation, but has the drawback that windows cannot exchange information with one another, which shrinks the receptive field. The computation saved is quantified below.

Here h, w, and C denote the feature map height, width, and channel depth, and M is the window size.


Estimating the computation from the matrix shapes

The derivation of these complexity formulas can be found in the article Swin-Transformer網(wǎng)絡(luò)結(jié)構(gòu)詳解.
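For reference, the two formulas from the paper (its Eqs. (1) and (2)) are:

\Omega(\text{MSA}) = 4hwC^2 + 2(hw)^2C

\Omega(\text{W-MSA}) = 4hwC^2 + 2M^2hwC

With M fixed (7 by default), W-MSA scales linearly with hw rather than quadratically.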

2. SW-MSA (compute attention across different windows)
The W-MSA above only captures features inside each window; it tells us nothing about the relationship between windows, which SW-MSA is meant to fix. The key difference is the S (shift): how do we perform the shift?

Shifted-window illustration

In the figure above, the partition moves from the red grid (b) to the blue grid (c): the blue strip at the top is moved to the bottom and the red strip on the left is moved to the right. The purpose of doing this is explained here:
https://www.zhihu.com/question/492057377/answer/2213112296

Remember the shift is half a window, rounded down (e.g. for a window size of 3 the shift is 1).
In plain terms, the shift changes which tokens end up grouped in the same window, which makes the model more robust; that is all the shifting operation does, as illustrated below with torch.roll.
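A minimal standalone illustration of this cyclic shift using torch.roll (the same call used later in the code), with a hypothetical window size of 2 so the shift is 1:

import torch

grid = torch.arange(16).view(1, 4, 4)                               # a labeled 4x4 "feature map"
shift = 2 // 2                                                      # shift_size = window_size // 2 (floor division)
shifted = torch.roll(grid, shifts=(-shift, -shift), dims=(1, 2))    # top row wraps to the bottom, left column to the right
restored = torch.roll(shifted, shifts=(shift, shift), dims=(1, 2))  # rolling back recovers the original grid
print(shifted)
print(torch.equal(grid, restored))                                  # True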

Since the windows do not overlap, every self-attention computation obviously loses the information between windows. So how do we keep global information while still reducing the computation? That is exactly what the Shifted Window is for.


The figure above can be understood with the following schematic:




There is still a problem: we originally had 4 windows, but after the shift there are 9. How do we keep the computation parallelizable? We can apply the cyclic shift below.



This gives the following result:

The Attention Mask mechanism
Regions (5,3), (7,1) and (8,6,2,0) were not adjacent to each other originally, so we have to compute the MSA of each region separately. Take region (5,3) as an example; this blog post explains it very well: Swin-Transformer網(wǎng)絡(luò)結(jié)構(gòu)詳解, as shown below:


Here we only want to compute attention within region 5 without mixing in the information of region 3, which we achieve with a mask: the attention scores \alpha are small numbers, so after subtracting 100 and applying the softmax, those entries effectively become 0.
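A tiny numerical check of this trick (illustrative numbers only):

import torch

scores = torch.tensor([1.2, 0.7, 0.9, 0.4])        # illustrative attention scores inside one window
mask = torch.tensor([0.0, -100.0, 0.0, -100.0])    # -100 where the key token belongs to a different region
print(torch.softmax(scores + mask, dim=-1))        # masked positions get effectively zero weight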
Example 1

Example 2

Note that once all the computation is done, the data must be rolled back to its original positions. The overall procedure is illustrated below.
Step 1

Step 2

Step 3

Because many Transformer layers are stacked, the output of W-MSA and SW-MSA keeps the same size as the input (56*56*96).

1.1.3 Relative Position Bias


Below is a comparison of the results with and without the relative position bias.
Table 4. Ablation study on the shifted windows approach and different position embedding methods on three benchmarks, using the Swin-T architecture. w/o shifting: all self-attention modules adopt regular window partitioning, without shifting; abs. pos.: absolute position embedding term of ViT; rel. pos.: the default settings with an additional relative position bias term (see Eq. (4)); app.: the first scaled dot-product term in Eq. (4).

The results show that the relative position bias (rel. pos.) is the more reasonable choice.
Flattening each of the small relative-position matrices above and concatenating them gives the large matrix below.

How do we convert the two-dimensional relative coordinates into a single one-dimensional index? Let's see how the authors do it.
To make the offsets start from 0, add M - 1 to both the row and column coordinates (window size M = 2 in this example).

Then multiply the row coordinate by 2M - 1.

Finally, add the row and column coordinates together.

This yields the relative position bias B used in the attention formula below.
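From the paper (Eq. (4)), the bias B \in \mathbb{R}^{M^2 \times M^2} is added to the scaled dot-product before the softmax:

\text{Attention}(Q, K, V) = \text{SoftMax}\!\left(\frac{QK^{T}}{\sqrt{d}} + B\right)V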
1.1.4 Patch Merging

network structure

Now we come to the Patch Merging operation. It shrinks the spatial size of the feature map while increasing the number of channels (you can think of it as a downsampling step). A minimal sketch of the idea follows.
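In this sketch (illustrative only; the actual PatchMerging class appears in the code section below), the four interleaved 2x2 neighbours are concatenated along the channel axis, then a linear layer reduces the 4C channels to 2C:

import torch
import torch.nn as nn

B, H, W, C = 1, 56, 56, 96
x = torch.randn(B, H, W, C)
x0 = x[:, 0::2, 0::2, :]                           # even rows, even columns
x1 = x[:, 1::2, 0::2, :]                           # odd rows, even columns
x2 = x[:, 0::2, 1::2, :]                           # even rows, odd columns
x3 = x[:, 1::2, 1::2, :]                           # odd rows, odd columns
merged = torch.cat([x0, x1, x2, x3], dim=-1)       # (1, 28, 28, 384): spatial size halved, channels x4
out = nn.Linear(4 * C, 2 * C, bias=False)(merged)  # (1, 28, 28, 192): channels reduced to 2C
print(merged.shape, out.shape)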

2. Code Walkthrough

# file: models/swin_transformer.py
# class: SwinTransformer
class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        # the drop-path rate increases linearly across blocks, from 0 up to the configured drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), # the channel dimension doubles at every stage
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, # every stage pairs its blocks with a PatchMerging downsample, except the last stage, which has none
                               use_checkpoint=use_checkpoint) 
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

2.1 input embedding

# file: models/swin_transformer.py
# class: SwinTransformer
    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

Here the input has shape (4, 3, 224, 224): batch 4, 3 channels, height 224, width 224. It then goes through the self.patch_embed operation.

# file: swin_transformer.py
# class: PatchEmbed
    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

As in ViT, self.proj here is simply a convolution:

# kernel size 4 and stride 4, so the output feature map is 1/4 of the input in each spatial dimension -> (56 * 56)
# input and output channels are 3 and 96
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
# flatten collapses height and width, so the output has shape (4, 3136(56*56), 96)
x = self.proj(x).flatten(2).transpose(1, 2)

Then self.norm is applied, which is an nn.LayerNorm.
Next comes self.pos_drop(x), where self.pos_drop is nn.Dropout(p=drop_rate).
After that, the input (still of shape (4, 3136(56*56), 96)) is passed through each of the layers:

      for layer in self.layers:
            x = layer(x)

2.2 BasicLayer

Continuing from above, let's look at how self.layers is built.

# file: models/swin_transformer.py
# class: SwinTransformer
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

# file: models/swin_transformer.py
# class: BasicLayer
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

2.3 SwinTransformerBlock

# file: models/swin_transformer.py
# class: SwinTransformerBlock
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
2.3.1 Inputs to W-MSA and SW-MSA

We know the input goes through W-MSA first and then SW-MSA.
For W-MSA nothing is done to the input, i.e. shifted_x = x in the code; for SW-MSA the shift is performed with torch.roll, as shown below:

shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

The dims (1, 2) mean the roll is applied along both the height and width dimensions. Also note that self.shift_size is negated here, which means the shift has to be undone after the windowed attention,
as in the code below:

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C  # for the first block this gives (4, 56, 56, 96)

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

The final shape is again the original input shape (4, 3136, 96).
Next comes the following:

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

For example, in the first block the window partition produces a tensor of shape (256, 7, 7, 96), which is then viewed as (256, 49, 96): 256 windows, each containing 49 tokens of 96 dimensions.
The window_partition used above is defined as follows:

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

Here x is viewed as (4, 8, 7, 8, 7, 96), and the number of windows is
(H/window_size) * (W/window_size) * batch. With H = W = 56, window_size = 7 and batch = 4, the resulting windows tensor has shape (256, 7, 7, 96); a quick check is shown below.
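A standalone check of these numbers (the same view/permute as window_partition above, on illustrative random data):

import torch

B, H, W, C, ws = 4, 56, 56, 96, 7
x = torch.randn(B, H, W, C)
x = x.view(B, H // ws, ws, W // ws, ws, C)                  # (4, 8, 7, 8, 7, 96)
windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)
print(windows.shape)                                        # torch.Size([256, 7, 7, 96]) = 4 * 8 * 8 windows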

2.3.2 The Attention Mechanism

The output above is then fed into the attention mechanism.

attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

If x_windows comes from W-MSA, self.attn_mask is None; otherwise the attn_mask is added in. The relevant code is below (for a detailed explanation see the bilibili video, around the 31-minute mark, which covers it very well):

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

In short, the code above fills -100 wherever the dot product should not be counted, so that after the softmax those weights automatically become 0.
Next, the attention module itself:


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

As you can see, self.qkv is applied first to produce the q, k, v matrices; internally it is just a simple nn.Linear.

# here `dim` is set to 96
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# the reshaped qkv has shape [3, 256, 3, 49, 32]: the leading 3 indexes q/k/v,
# the 256 windows each attend independently,
# there are 3 heads in the first stage,
# each window contains 49 tokens,
# and each head has 32 dimensions
# q, k, v each have shape [256, 3, 49, 32]
q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
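The shapes in these comments can be verified in isolation (standalone illustrative tensors; batch of 4, hence 256 windows):

import torch
import torch.nn as nn

B_, N, C, heads = 256, 49, 96, 3                   # 256 windows, 49 tokens each, dim 96, 3 heads
x = torch.randn(B_, N, C)
qkv = nn.Linear(C, C * 3)(x).reshape(B_, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
print(qkv.shape, q.shape)                          # (3, 256, 3, 49, 32) and (256, 3, 49, 32)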

Then the attention weights are computed as attn = softmax(QK^{T} / \sqrt{d}), as shown below. Here self.scale is the 1/\sqrt{d} scaling factor (head_dim ** -0.5); the softmax itself is applied later, after the relative position bias (and the mask, if any) have been added:

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

Finally, the relative position bias is added to the attention scores, as shown below, giving the final attn:

attn = attn + relative_position_bias.unsqueeze(0)

The position bias is explained below.

2.3.3 Relative Position Bias Table

As mentioned above, the relative position bias table has (2M-1) * (2M-1) entries, where M is the window size (for a detailed explanation see the bilibili video, around the 56-minute mark, which covers it very well).

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

Below is the softmax step mentioned earlier: positions whose mask indices differ receive -100, so after the softmax their weights become 0.

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

After the operation

x = (attn @ v).transpose(1, 2).reshape(B_, N, C)

we obtain post-attention features of shape (256, 49, 96); self.proj projects them back to the embedding dimension, and self.proj_drop is a dropout layer.

2.3.4 FFN (residual connections)
The parts in the red and blue boxes in the figure correspond to two residual connections.

Finally, the residual connections are applied:

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

With that, the SwinTransformerBlock is complete.

3. Patch Merging

Overall Swin network architecture

From the architecture diagram we can see that the Swin Transformer Blocks are followed by a Patch Merging layer; the principle is illustrated below.


The corresponding code is:

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

4. Output Layer

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
x = self.norm(x)  # B L C
x = self.avgpool(x.transpose(1, 2))  # B C 1
x = torch.flatten(x, 1)

Average pooling turns the (4, 49, 768) features into (4, 768, 1), which is then fed into the fully connected head
nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity(). A shape sketch follows.
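A minimal shape sketch of this head (hypothetical batch of 4; after the last stage the feature map is 7x7 with 768 channels, i.e. 49 tokens of dimension 768):

import torch
import torch.nn as nn

x = torch.randn(4, 49, 768)                        # B L C after the final stage and LayerNorm
x = nn.AdaptiveAvgPool1d(1)(x.transpose(1, 2))     # (4, 768, 1): average over the 49 tokens
x = torch.flatten(x, 1)                            # (4, 768)
logits = nn.Linear(768, 1000)(x)                   # (4, 1000) class scores
print(logits.shape)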

References
[1] Swin Transformer
[2] 如何看待swin transformer成為ICCV2021的 best paper火窒?
[3] Swin-Transformer網(wǎng)絡(luò)結(jié)構(gòu)詳解
