CoAtNet: 90.88% Paperwithcode榜單第一异赫,層層深入考慮模型設(shè)計(jì)

【GiantPandaCV導(dǎo)語(yǔ)】CoAt=Convolution + Attention,paperwithcode榜單第一名把沼,通過(guò)結(jié)合卷積與Transformer實(shí)現(xiàn)性能上的突破都哭,方法部分設(shè)計(jì)非常規(guī)整秩伞,層層深入考慮模型的架構(gòu)設(shè)計(jì)。

image

引言

Transformer模型的容量大欺矫,由于缺乏正確的歸納偏置纱新,泛化能力要比卷積網(wǎng)絡(luò)差。

提出了CoAtNets模型族:

  • 深度可分離卷積與self-attention能夠通過(guò)簡(jiǎn)單的相對(duì)注意力來(lái)統(tǒng)一化穆趴。
  • 疊加卷積層和注意層在提高泛化能力和效率方面具有驚人的效果

方法

這部分主要關(guān)注如何將conv與transformer以一種最優(yōu)的方式結(jié)合:

  • 在基礎(chǔ)的計(jì)算塊中脸爱,如果合并卷積與自注意力操作。
  • 如何組織不同的計(jì)算模塊來(lái)構(gòu)建整個(gè)網(wǎng)絡(luò)未妹。

合并卷積與自注意力

卷積方面谷歌使用的是經(jīng)典的MBConv簿废, 使用深度可分離卷積來(lái)捕獲空間之間的交互空入。

卷積操作的表示:\mathcal{L}(i)代表i周邊的位置,也即卷積處理的感受野族檬。

y_{i}=\sum_{j \in \mathcal{L}(i)} w_{i-j} \odot x_{j} \quad \text { (depthwise convolution) }

自注意力表示:\mathcal{G}表示全局空間感受野歪赢。

y_{i}=\sum_{j \in \mathcal{G}} \underbrace{\frac{\exp \left(x_{i}^{\top} x_{j}\right)}{\sum_{k \in \mathcal{G}} \exp \left(x_{i}^{\top} x_{k}\right)}}_{A_{i, j}} x_{j} \quad \text { (self-attention) }

融合方法一:先求和,再softmax

y_{i}^{\text {post }}=\sum_{j \in \mathcal{G}}\left(\frac{\exp \left(x_{i}^{\top} x_{j}\right)}{\sum_{k \in \mathcal{G}} \exp \left(x_{i}^{\top} x_{k}\right)}+w_{i-j}\right) x_{j}

融合方法二:先softmax单料,再求和

y_{i}^{\text {pre }}=\sum_{j \in \mathcal{G}} \frac{\exp \left(x_{i}^{\top} x_{j}+w_{i-j}\right)}{\sum_{k \in \mathcal{G}} \exp \left(x_{i}^{\top} x_{k}+w_{i-k}\right)} x_{j}

出于參數(shù)量埋凯、計(jì)算兩方面的考慮,論文打算采用第二種融合方法扫尖。

垂直布局設(shè)計(jì)

決定好合并卷積與注意力的方式后應(yīng)該考慮如何構(gòu)建網(wǎng)絡(luò)整體架構(gòu)递鹉,主要有三個(gè)方面的考量:

  • 使用降采樣降低空間維度大小,然后使用global relative attention藏斩。
  • 使用局部注意力,強(qiáng)制全局感受野限制在一定范圍內(nèi)却盘。典型代表有:
    • Scaling local self-attention for parameter efficient visual backbone
    • Swin Transformer
  • 使用某種線性注意力方法來(lái)取代二次的softmax attention狰域。典型代表有:
    • Efficient Attention
    • Transformers are rnns
    • Rethinking attention with performers

第二種方法實(shí)現(xiàn)效率不夠高,第三種方法性能不夠好黄橘,因此采用第一種方法兆览,如何設(shè)計(jì)降采樣的方式也有幾種方案:

  • 使用卷積配合stride進(jìn)行降采樣。
  • 使用pooling操作完成降采樣塞关,構(gòu)建multi-stage網(wǎng)絡(luò)范式抬探。
  • 根據(jù)第一種方案提出ViT_{REL}, 即使用ViT Stem,直接堆疊L層Transformer block使用relative attention帆赢。
  • 根據(jù)第二種方案小压,采用multi-stage方案提出模型組:S_0,...,S_4,如下圖所示:
image

S_o-S_2采用卷積以及MBConv,從S_2-S_4的幾個(gè)模塊采用Transformer 結(jié)構(gòu)椰于。具體Transformer內(nèi)部有以下幾個(gè)變體:C代表卷積怠益,T代表Transformer

  • C-C-C-C
  • C-C-C-T
  • C-C-T-T
  • C-T-T-T

初步測(cè)試模型泛化能力

image

泛化能力排序?yàn)椋海ㄗC明架構(gòu)中還是需要存在想當(dāng)比例的卷積操作)

image

初步測(cè)試模型容量

主要是從JFT以及ImageNet-1k上不同的表現(xiàn)來(lái)判定的,排序結(jié)果為:

image

測(cè)試模型遷移能力

image

為了進(jìn)一步比較CCTT與CTTT瘾婿,進(jìn)行了遷移能力測(cè)試蜻牢,發(fā)現(xiàn)CCTT能夠超越CTTT。

最終CCTT勝出偏陪!

實(shí)驗(yàn)

與SOTA模型比較結(jié)果:

image

實(shí)驗(yàn)結(jié)果:

image

消融實(shí)驗(yàn):

image
image
image

代碼

淺層使用的MBConv模塊如下:

class MBConv(nn.Module):
    def __init__(self, inp, oup, image_size, downsample=False, expansion=4):
        super().__init__()
        self.downsample = downsample
        stride = 1 if self.downsample == False else 2
        hidden_dim = int(inp * expansion)

        if self.downsample:
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                # down-sample in the first conv
                nn.Conv2d(inp, hidden_dim, 1, stride, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
                          groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                SE(inp, hidden_dim),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        
        self.conv = PreNorm(inp, self.conv, nn.BatchNorm2d)

    def forward(self, x):
        if self.downsample:
            return self.proj(self.pool(x)) + self.conv(x)
        else:
            return x + self.conv(x)

主要關(guān)注Attention Block設(shè)計(jì)抢呆,引入Relative Position:

class Attention(nn.Module):
    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
        coords = torch.flatten(torch.stack(coords), 1)
        relative_coords = coords[:, :, None] - coords[:, None, :]

        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

參考

https://arxiv.org/pdf/2106.04803.pdf

https://github.com/chinhsuanwu/coatnet-pytorch

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市笛谦,隨后出現(xiàn)的幾起案子抱虐,更是在濱河造成了極大的恐慌,老刑警劉巖揪罕,帶你破解...
    沈念sama閱讀 217,734評(píng)論 6 505
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件梯码,死亡現(xiàn)場(chǎng)離奇詭異宝泵,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)轩娶,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,931評(píng)論 3 394
  • 文/潘曉璐 我一進(jìn)店門(mén)儿奶,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人鳄抒,你說(shuō)我怎么就攤上這事闯捎。” “怎么了许溅?”我有些...
    開(kāi)封第一講書(shū)人閱讀 164,133評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵瓤鼻,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我贤重,道長(zhǎng)茬祷,這世上最難降的妖魔是什么? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 58,532評(píng)論 1 293
  • 正文 為了忘掉前任并蝗,我火速辦了婚禮祭犯,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘滚停。我一直安慰自己沃粗,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,585評(píng)論 6 392
  • 文/花漫 我一把揭開(kāi)白布键畴。 她就那樣靜靜地躺著最盅,像睡著了一般。 火紅的嫁衣襯著肌膚如雪起惕。 梳的紋絲不亂的頭發(fā)上涡贱,一...
    開(kāi)封第一講書(shū)人閱讀 51,462評(píng)論 1 302
  • 那天,我揣著相機(jī)與錄音疤祭,去河邊找鬼盼产。 笑死,一個(gè)胖子當(dāng)著我的面吹牛勺馆,可吹牛的內(nèi)容都是我干的戏售。 我是一名探鬼主播,決...
    沈念sama閱讀 40,262評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼草穆,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼灌灾!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起悲柱,我...
    開(kāi)封第一講書(shū)人閱讀 39,153評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤锋喜,失蹤者是張志新(化名)和其女友劉穎,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體嘿般,經(jīng)...
    沈念sama閱讀 45,587評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡段标,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,792評(píng)論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了炉奴。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片逼庞。...
    茶點(diǎn)故事閱讀 39,919評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖瞻赶,靈堂內(nèi)的尸體忽然破棺而出赛糟,到底是詐尸還是另有隱情,我是刑警寧澤砸逊,帶...
    沈念sama閱讀 35,635評(píng)論 5 345
  • 正文 年R本政府宣布璧南,位于F島的核電站,受9級(jí)特大地震影響师逸,放射性物質(zhì)發(fā)生泄漏司倚。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,237評(píng)論 3 329
  • 文/蒙蒙 一篓像、第九天 我趴在偏房一處隱蔽的房頂上張望对湃。 院中可真熱鬧,春花似錦遗淳、人聲如沸。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 31,855評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至脂男,卻和暖如春养叛,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背宰翅。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 32,983評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工弃甥, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人汁讼。 一個(gè)月前我還...
    沈念sama閱讀 48,048評(píng)論 3 370
  • 正文 我出身青樓淆攻,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親嘿架。 傳聞我的和親對(duì)象是個(gè)殘疾皇子瓶珊,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,864評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容