Swin Transformer

目前transformer從語言到視覺任務(wù)的挑戰(zhàn)主要是由于這兩個領(lǐng)域間的差異:

  • 1场勤、尺度變化大
  • 2、高分辨率的輸入

為了解決以上兩點萨螺,我們提出了層級Transformer决乎,通過滑動窗口提取特征的方式將使得self.attention的計算量降低為和圖像尺寸的線性相關(guān)。

簡介

我們觀察到將語言領(lǐng)域遷移到視覺領(lǐng)域的主要問題可以被總結(jié)為兩種:

  • 1佃迄、不同于word token泼差,它的尺度是固定的,但是視覺領(lǐng)域的尺度變化非常劇烈
  • 2呵俏、相對于上下文中的words堆缘,圖片有著更高分辨率的像素,計算量會隨著圖片的尺寸成平方倍的增長柴信。

結(jié)構(gòu)

image.png

以上是論文中結(jié)構(gòu)圖套啤,每一個stage feature map的尺寸都會減半。易知主要分為四個模塊:

  • Patch Partition
  • Linear Embedding
  • Swin Transformer Block(主要模塊)
    • W-MSA:regular window partitionmutil-head self attention
    • SW-MSA: shift window partitionmutil-head self attention
  • Patch Merging

1随常、Patch Partition 和 Linear Embedding

在源碼實現(xiàn)中兩個模塊合二為一潜沦,稱為PatchEmbedding。輸入圖片尺寸為H \times W \times 3 的RGB圖片绪氛,將4x4x3視為一個patch唆鸡,用一個linear embedding 層將patch轉(zhuǎn)換為任意dimension(通道)的feature。源碼中使用4x4的stride=4的conv實現(xiàn)枣察。-> \frac{H}{4} \times \frac{W}{4} \times C

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim
       
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

2歌殃、Swin Transformer Block

這是這篇論文的核心模塊。

  • 如何解決計算量隨著輸入尺寸的增大成平方倍的增長致盟? 拋棄傳統(tǒng)的transformer基于全局來計算注意力的方法岩梳,將輸入劃分為不同的窗口,分別對每個窗口(window)施加注意力猿涨。
  • 僅僅對窗口(window)單獨施加注意力握童,如何解決窗口(window)之間的信息流動?交替使用W-MSASW-MSA模塊叛赚,因此SwinTransformerBlock必須是偶數(shù)澡绩。如下圖所示:
    image.png

    整體流程如下:
    • 先對特征圖進(jìn)行LayerNorm
    • 通過self.shift_size決定是否需要對特征圖進(jìn)行shift
    • 然后將特征圖切成一個個窗口
    • 計算Attention,通過self.attn_mask來區(qū)分Window Attention還是Shift Window Attention
    • 將各個窗口合并回來
    • 如果之前有做shift操作俺附,此時進(jìn)行reverse shift肥卡,把之前的shift操作恢復(fù)
    • 做dropout和殘差連接
    • 再通過一層LayerNorm+全連接層,以及dropout和殘差連接

2.1事镣、window partition

window partition分為regular window partitionshift window partition步鉴,對應(yīng)于W-MSASW-MSA。通過窗口劃分,將輸入的feature map B \times H \times W \times C轉(zhuǎn)換為num_windows*B, window_size, window_size, C氛琢,其中 num_windows = H*W / window_size / window_size只嚣。然后resize 到 num_windows*B, window_size*window_size, C進(jìn)行attention。源碼如下:

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
image.png
  • Layer1regular window partition艺沼,窗口的大小是4x4册舞,將圖片分成了4個窗口。
  • Layer2shift window partition障般,為了保證不同窗口的信息流動调鲸,起始點從(windows_size//2, windows_size//2)開始進(jìn)行劃分,將圖片分成了9個窗口挽荡∶晔可以看到移位后的窗口包含了原本相鄰窗口的元素。但是同時也引入了新的問題定拟,窗口大小不一致的問題于微,有2x2、2x4青自、4x2株依、4x4,最簡單的方法就是統(tǒng)一padding到4x4延窜,但是窗口數(shù)量由4增加至9恋腕,計算量變大了2.25倍。因此作者提出了cycle shift去解決這個問題逆瑞。
    image.png

    以下的示例圖片來自于:https://mp.weixin.qq.com/s/8x1pgRLWaMkFSjT7zjhTgQ
    image.png

    首先對窗口進(jìn)行shift window partition荠藤,得到左圖部分。不進(jìn)行padding获高,而是采用滾動的方式調(diào)整窗口哈肖,源碼中用torch.roll()函數(shù)實現(xiàn),得到了右圖念秧。這時候得到了和regular window partition一樣的4個2x2大小的window淤井,不同的是,在一個2x2的windows區(qū)域內(nèi)是不連續(xù)的(index不一樣)出爹。
    image.png

    我們希望在計算Attention的時候庄吼,讓具有相同index Q \times K^T進(jìn)行計算缎除,而忽略不同index QK計算結(jié)果严就。因此我們?yōu)槠涮砑由蟤ask。源碼計算mask實現(xiàn)如下:
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

2.2器罐、W-MSA

regular window partition模塊 和 mutil-head self attention模塊組成梢为。
W-MSA相比于直接使用MSA主要是為了降低計算量。傳統(tǒng)的transformer都是基于全局來計算注意力,因此計算復(fù)雜度非常高铸董。但是swin transformer通過對每個窗口施加注意力祟印,從而減少了計算量。attention的主要計算過程如下:
Q=x \times W^q \\ K=x \times W^k \\ V=x \times W^v \\ attn=Q \times K^T \\ Z = attn \times V \\ output = Z \times W^z
假設(shè)每一個window的區(qū)塊大小為M\times M粟害,輸入的尺寸為h \times w蕴忆,以下為原始的MSAW-MSA的計算復(fù)雜度:
\Omega(\mathrm{MSA})=4 h w C^{2}+2(h w)^{2} C\\ \Omega(\mathrm{W}-\mathrm{MSA})=4 h w C^{2}+2 M^{2} h w C

  • 對于MSA:對輸入的feature map做全局attention,Q悲幅、K套鹅、V的計算量分別是hwC^2attnZ的計算量分別是(hw)^2C汰具,output的計算量是hwC^2卓鹿。
  • 對于W-MSA:在windows內(nèi)的M \times M大小的區(qū)域內(nèi)做attention,feature map會被劃分為\frac{h}{M} \times \frac{w}{M}windows留荔,每個windows的尺寸為M \times M吟孙。QK聚蝶、V的計算量分別是hwC^2杰妓,attnZ的計算量的分別是M^2 hwCoutput的計算量是hwC^2碘勉。因此和輸入尺寸成線性關(guān)系稚失。

2.3、SW-MSA

雖然W-MSA降低了計算量恰聘,但是由于將attention限制在window內(nèi)句各,因此不重合的window缺乏聯(lián)系,限制了模型的性能晴叨。因此提出了SW-MSA模塊凿宾。在MSA前面加上一個cycle shift window partition

3、Patch Merging

swin transformer中沒有使用pooling進(jìn)行下采樣兼蕊,而是使用了和yolov5中的focus層進(jìn)行feature map的下采樣初厚。H\times W\times C -> \frac{H}{2} \times \frac{W}{2} \times 4C,在使用一個全連接層->\frac{H}{2} \times \frac{W}{2} \times 2C孙技,在一個stage中將feature map的高寬減半产禾,通道數(shù)翻倍。

image.png

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

不同尺寸的網(wǎng)絡(luò)結(jié)構(gòu)

基準(zhǔn)模型結(jié)構(gòu)命名為Swin-B牵啦,模型大小和計算復(fù)雜度和ViT-B/DeiT-B相近亚情。同時我們也提出了Swin-TSwin-SSwin-L哈雏,分別對應(yīng)0.25×, 0.5×倍的模型尺寸和計算復(fù)雜度楞件。Swin-TSwin-S的計算復(fù)雜度分別和ResNet-50衫生、ResNet-101相近。M默認(rèn)設(shè)置為7土浸。C代表第一層隱藏層的數(shù)量罪针。

  • Swin-T: C = 96, layer numbers = {2, 2, 6, 2}
  • Swin-S: C = 96, layer numbers ={2, 2, 18, 2}
  • Swin-B: C = 128, layer numbers ={2, 2, 18, 2}
  • Swin-L: C = 192, layer numbers ={2, 2, 18, 2}
image.png

不同數(shù)據(jù)集的實驗結(jié)果

1、ImageNet

image.png

2黄伊、COCO Object Detection

  • 在不同的模型上使用swin transformer 作為特征提取網(wǎng)絡(luò)
  • 在cascade mask rcnn上使用swin transformer 作為backbone
  • 直接對比其他目標(biāo)檢測網(wǎng)絡(luò)


    image.png

3泪酱、Semantic Segmentation on ADE20K

image.png

消融實驗

1、shifted windows 的有效性

image.png

2还最、position bias

image.png

3西篓、sliding window 和 shift window的速度和性能

image.png
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市憋活,隨后出現(xiàn)的幾起案子岂津,更是在濱河造成了極大的恐慌,老刑警劉巖悦即,帶你破解...
    沈念sama閱讀 217,277評論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件吮成,死亡現(xiàn)場離奇詭異,居然都是意外死亡辜梳,警方通過查閱死者的電腦和手機(jī)粱甫,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,689評論 3 393
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來作瞄,“玉大人茶宵,你說我怎么就攤上這事∽诨樱” “怎么了乌庶?”我有些...
    開封第一講書人閱讀 163,624評論 0 353
  • 文/不壞的土叔 我叫張陵,是天一觀的道長契耿。 經(jīng)常有香客問我瞒大,道長,這世上最難降的妖魔是什么搪桂? 我笑而不...
    開封第一講書人閱讀 58,356評論 1 293
  • 正文 為了忘掉前任透敌,我火速辦了婚禮,結(jié)果婚禮上踢械,老公的妹妹穿的比我還像新娘酗电。我一直安慰自己,他們只是感情好内列,可當(dāng)我...
    茶點故事閱讀 67,402評論 6 392
  • 文/花漫 我一把揭開白布撵术。 她就那樣靜靜地躺著,像睡著了一般德绿。 火紅的嫁衣襯著肌膚如雪荷荤。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,292評論 1 301
  • 那天移稳,我揣著相機(jī)與錄音蕴纳,去河邊找鬼。 笑死个粱,一個胖子當(dāng)著我的面吹牛古毛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播都许,決...
    沈念sama閱讀 40,135評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼稻薇,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了胶征?” 一聲冷哼從身側(cè)響起塞椎,我...
    開封第一講書人閱讀 38,992評論 0 275
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎睛低,沒想到半個月后案狠,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,429評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡钱雷,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,636評論 3 334
  • 正文 我和宋清朗相戀三年骂铁,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片罩抗。...
    茶點故事閱讀 39,785評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡拉庵,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出套蒂,到底是詐尸還是另有隱情钞支,我是刑警寧澤,帶...
    沈念sama閱讀 35,492評論 5 345
  • 正文 年R本政府宣布操刀,位于F島的核電站伸辟,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏馍刮。R本人自食惡果不足惜信夫,卻給世界環(huán)境...
    茶點故事閱讀 41,092評論 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望卡啰。 院中可真熱鬧静稻,春花似錦、人聲如沸匈辱。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,723評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽亡脸。三九已至押搪,卻和暖如春树酪,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背大州。 一陣腳步聲響...
    開封第一講書人閱讀 32,858評論 1 269
  • 我被黑心中介騙來泰國打工续语, 沒想到剛下飛機(jī)就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人厦画。 一個月前我還...
    沈念sama閱讀 47,891評論 2 370
  • 正文 我出身青樓疮茄,卻偏偏與公主長得像,于是被迫代替她去往敵國和親根暑。 傳聞我的和親對象是個殘疾皇子力试,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,713評論 2 354

推薦閱讀更多精彩內(nèi)容