CeiT:訓練更快的多層特征抽取ViT

【GiantPandaCV導語】來自商湯和南洋理工的工作配喳,也是使用卷積來增強模型提出low-level特征的能力赂韵,增強模型獲取局部性的能力送淆,核心貢獻是LCA模塊章鲤,可以用于捕獲多層特征表示致板。

引言

針對先前Transformer架構(gòu)需要大量額外數(shù)據(jù)或者額外的監(jiān)督(Deit),才能獲得與卷積神經(jīng)網(wǎng)絡結(jié)構(gòu)相當?shù)男阅苡搅瑸榱丝朔@種缺陷斟或,提出結(jié)合CNN來彌補Transformer的缺陷,提出了CeiT:

(1)設計Image-to-Tokens模塊來從low-level特征中得到embedding集嵌。

(2)將Transformer中的Feed Forward模塊替換為Locally-enhanced Feed-Forward(LeFF)模塊萝挤,增加了相鄰token之間的相關(guān)性。

(3)使用Layer-wise Class Token Attention(LCA)捕獲多層的特征表示根欧。

經(jīng)過以上修改怜珍,可以發(fā)現(xiàn)模型效率方面以及泛化能力得到了提升,收斂性也有所改善凤粗,如下圖所示:

image

方法

1. Image-to-Tokens

image

使用卷積+池化來取代原先ViT中7x7的大型patch酥泛。

\mathbf{x}^{\prime}=\mathrm{I} 2 \mathrm{~T}(\mathbf{x})=\operatorname{MaxPool}(\operatorname{BN}(\operatorname{Conv}(\mathbf{x})))

2. LeFF

image

將tokens重新拼成feature map,然后使用深度可分離卷積添加局部性的處理,然后再使用一個Linear層映射至tokens柔袁。

\begin{aligned} \mathbf{x}_{c}^{h}, \mathbf{x}_{p}^{h} &=\operatorname{Split}\left(\mathbf{x}_{t}^{h}\right) \\ \mathbf{x}_{p}^{l_{1}} &=\operatorname{GELU}\left(\operatorname{BN}\left(\operatorname{Linear}\left(\left(\mathbf{x}_{p}^{h}\right)\right)\right)\right.\\ \mathbf{x}_{p}^{s} &=\operatorname{SpatialRestore}\left(\mathbf{x}_{p}^{l_{1}}\right) \\ \mathbf{x}_{p}^gmkyk2s &=\operatorname{GELU}\left(\operatorname{BN}\left(\operatorname{DWConv}\left(\mathbf{x}_{p}^{s}\right)\right)\right) \\ \mathbf{x}_{p}^{f} &=\operatorname{Flatten}\left(\mathbf{x}_{p}^kamicos\right) \\ \mathbf{x}_{p}^{l_{2}} &=\operatorname{GELU}\left(\operatorname{BN}\left(\operatorname{Linear} 2\left(\mathbf{x}_{p}^{f}\right)\right)\right) \\ \mathbf{x}_{t}^{h+1} &=\operatorname{Concat}\left(\mathbf{x}_{c}^{h}, \mathbf{x}_{p}^{l_{2}}\right) \end{aligned}

3. LCA

前兩個都比較常規(guī)呆躲,最后一個比較有特色,經(jīng)過所有Transformer層以后使用的Layer-wise Class-token Attention捶索,如下圖所示:

image

LCA模塊會將所有Transformer Block中得到的class token作為輸入插掂,然后再在其基礎上使用一個MSA+FFN得到最終的logits輸出。作者認為這樣可以獲取多尺度的表征腥例。

實驗

SOTA比較:

image

I2T消融實驗:

image

LeFF消融實驗:

image

LCA有效性比較:

image

收斂速度比較:

image

代碼

模塊1:I2T Image-to-Token

  # IoT
  self.conv = nn.Sequential(
      nn.Conv2d(in_channels, out_channels, conv_kernel, stride, 4),
      nn.BatchNorm2d(out_channels),
      nn.MaxPool2d(pool_kernel, stride)    
  )
  
  feature_size = image_size // 4

  assert feature_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
  num_patches = (feature_size // patch_size) ** 2
  patch_dim = out_channels * patch_size ** 2
  self.to_patch_embedding = nn.Sequential(
      Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
      nn.Linear(patch_dim, dim),
  )

模塊2:LeFF

class LeFF(nn.Module):
    
    def __init__(self, dim = 192, scale = 4, depth_kernel = 3):
        super().__init__()
        
        scale_dim = dim*scale
        self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim),
                                    Rearrange('b n c -> b c n'),
                                    nn.BatchNorm1d(scale_dim),
                                    nn.GELU(),
                                    Rearrange('b c (h w) -> b c h w', h=14, w=14)
                                    )
        
        self.depth_conv =  nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False),
                          nn.BatchNorm2d(scale_dim),
                          nn.GELU(),
                          Rearrange('b c h w -> b (h w) c', h=14, w=14)
                          )
        
        self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim),
                                    Rearrange('b n c -> b c n'),
                                    nn.BatchNorm1d(dim),
                                    nn.GELU(),
                                    Rearrange('b c n -> b n c')
                                    )
        
    def forward(self, x):
        x = self.up_proj(x)
        x = self.depth_conv(x)
        x = self.down_proj(x)
        return x
        
class TransformerLeFF(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, scale = 4, depth_kernel = 3, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, LeFF(dim, scale, depth_kernel)))
            ]))
    def forward(self, x):
        c = list()
        for attn, leff in self.layers:
            x = attn(x)
            cls_tokens = x[:, 0]
            c.append(cls_tokens)
            x = leff(x[:, 1:])
            x = torch.cat((cls_tokens.unsqueeze(1), x), dim=1) 
        return x, torch.stack(c).transpose(0, 1)

模塊3:LCA

class LCAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        q = q[:, :, -1, :].unsqueeze(2) # Only Lth element use as query

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

class LCA(nn.Module):
    # I remove Residual connection from here, in paper author didn't explicitly mentioned to use Residual connection, 
    # so I removed it, althougth with Residual connection also this code will work.
    def __init__(self, dim, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layers.append(nn.ModuleList([
                PreNorm(dim, LCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x[:, -1].unsqueeze(1)
            x = x[:, -1].unsqueeze(1) + ff(x)
        return x

參考

https://arxiv.org/abs/2103.11816

https://github.com/rishikksh20/CeiT-pytorch/blob/master/ceit.py

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末辅甥,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子燎竖,更是在濱河造成了極大的恐慌璃弄,老刑警劉巖,帶你破解...
    沈念sama閱讀 217,734評論 6 505
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件构回,死亡現(xiàn)場離奇詭異谢揪,居然都是意外死亡,警方通過查閱死者的電腦和手機捐凭,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,931評論 3 394
  • 文/潘曉璐 我一進店門拨扶,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人茁肠,你說我怎么就攤上這事患民。” “怎么了垦梆?”我有些...
    開封第一講書人閱讀 164,133評論 0 354
  • 文/不壞的土叔 我叫張陵匹颤,是天一觀的道長。 經(jīng)常有香客問我托猩,道長印蓖,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,532評論 1 293
  • 正文 為了忘掉前任京腥,我火速辦了婚禮赦肃,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘公浪。我一直安慰自己他宛,他們只是感情好,可當我...
    茶點故事閱讀 67,585評論 6 392
  • 文/花漫 我一把揭開白布欠气。 她就那樣靜靜地躺著厅各,像睡著了一般。 火紅的嫁衣襯著肌膚如雪预柒。 梳的紋絲不亂的頭發(fā)上队塘,一...
    開封第一講書人閱讀 51,462評論 1 302
  • 那天袁梗,我揣著相機與錄音,去河邊找鬼憔古。 笑死遮怜,一個胖子當著我的面吹牛,可吹牛的內(nèi)容都是我干的投放。 我是一名探鬼主播,決...
    沈念sama閱讀 40,262評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼适贸,長吁一口氣:“原來是場噩夢啊……” “哼灸芳!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起拜姿,我...
    開封第一講書人閱讀 39,153評論 0 276
  • 序言:老撾萬榮一對情侶失蹤烙样,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后蕊肥,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體谒获,經(jīng)...
    沈念sama閱讀 45,587評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,792評論 3 336
  • 正文 我和宋清朗相戀三年壁却,在試婚紗的時候發(fā)現(xiàn)自己被綠了批狱。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 39,919評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡展东,死狀恐怖赔硫,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情盐肃,我是刑警寧澤爪膊,帶...
    沈念sama閱讀 35,635評論 5 345
  • 正文 年R本政府宣布,位于F島的核電站砸王,受9級特大地震影響推盛,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜谦铃,卻給世界環(huán)境...
    茶點故事閱讀 41,237評論 3 329
  • 文/蒙蒙 一耘成、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧驹闰,春花似錦凿跳、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,855評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至骡显,卻和暖如春疆栏,著一層夾襖步出監(jiān)牢的瞬間曾掂,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,983評論 1 269
  • 我被黑心中介騙來泰國打工壁顶, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留珠洗,地道東北人。 一個月前我還...
    沈念sama閱讀 48,048評論 3 370
  • 正文 我出身青樓若专,卻偏偏與公主長得像许蓖,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子调衰,可洞房花燭夜當晚...
    茶點故事閱讀 44,864評論 2 354

推薦閱讀更多精彩內(nèi)容