【GiantPandaCV導(dǎo)語(yǔ)】CoAt=Convolution + Attention,paperwithcode榜單第一名把沼,通過(guò)結(jié)合卷積與Transformer實(shí)現(xiàn)性能上的突破都哭,方法部分設(shè)計(jì)非常規(guī)整秩伞,層層深入考慮模型的架構(gòu)設(shè)計(jì)。
引言
Transformer模型的容量大欺矫,由于缺乏正確的歸納偏置纱新,泛化能力要比卷積網(wǎng)絡(luò)差。
提出了CoAtNets模型族:
- 深度可分離卷積與self-attention能夠通過(guò)簡(jiǎn)單的相對(duì)注意力來(lái)統(tǒng)一化穆趴。
- 疊加卷積層和注意層在提高泛化能力和效率方面具有驚人的效果
方法
這部分主要關(guān)注如何將conv與transformer以一種最優(yōu)的方式結(jié)合:
- 在基礎(chǔ)的計(jì)算塊中脸爱,如果合并卷積與自注意力操作。
- 如何組織不同的計(jì)算模塊來(lái)構(gòu)建整個(gè)網(wǎng)絡(luò)未妹。
合并卷積與自注意力
卷積方面谷歌使用的是經(jīng)典的MBConv簿废, 使用深度可分離卷積來(lái)捕獲空間之間的交互空入。
卷積操作的表示:代表i周邊的位置,也即卷積處理的感受野族檬。
自注意力表示:表示全局空間感受野歪赢。
融合方法一:先求和,再softmax
融合方法二:先softmax单料,再求和
出于參數(shù)量埋凯、計(jì)算兩方面的考慮,論文打算采用第二種融合方法扫尖。
垂直布局設(shè)計(jì)
決定好合并卷積與注意力的方式后應(yīng)該考慮如何構(gòu)建網(wǎng)絡(luò)整體架構(gòu)递鹉,主要有三個(gè)方面的考量:
- 使用降采樣降低空間維度大小,然后使用global relative attention藏斩。
- 使用局部注意力,強(qiáng)制全局感受野限制在一定范圍內(nèi)却盘。典型代表有:
- Scaling local self-attention for parameter efficient visual backbone
- Swin Transformer
- 使用某種線性注意力方法來(lái)取代二次的softmax attention狰域。典型代表有:
- Efficient Attention
- Transformers are rnns
- Rethinking attention with performers
第二種方法實(shí)現(xiàn)效率不夠高,第三種方法性能不夠好黄橘,因此采用第一種方法兆览,如何設(shè)計(jì)降采樣的方式也有幾種方案:
- 使用卷積配合stride進(jìn)行降采樣。
- 使用pooling操作完成降采樣塞关,構(gòu)建multi-stage網(wǎng)絡(luò)范式抬探。
- 根據(jù)第一種方案提出
, 即使用ViT Stem,直接堆疊L層Transformer block使用relative attention帆赢。
- 根據(jù)第二種方案小压,采用multi-stage方案提出模型組:
,如下圖所示:
采用卷積以及MBConv,從
的幾個(gè)模塊采用Transformer 結(jié)構(gòu)椰于。具體Transformer內(nèi)部有以下幾個(gè)變體:C代表卷積怠益,T代表Transformer
- C-C-C-C
- C-C-C-T
- C-C-T-T
- C-T-T-T
初步測(cè)試模型泛化能力
泛化能力排序?yàn)椋海ㄗC明架構(gòu)中還是需要存在想當(dāng)比例的卷積操作)
初步測(cè)試模型容量
主要是從JFT以及ImageNet-1k上不同的表現(xiàn)來(lái)判定的,排序結(jié)果為:
測(cè)試模型遷移能力
為了進(jìn)一步比較CCTT與CTTT瘾婿,進(jìn)行了遷移能力測(cè)試蜻牢,發(fā)現(xiàn)CCTT能夠超越CTTT。
最終CCTT勝出偏陪!
實(shí)驗(yàn)
與SOTA模型比較結(jié)果:
實(shí)驗(yàn)結(jié)果:
消融實(shí)驗(yàn):
代碼
淺層使用的MBConv模塊如下:
class MBConv(nn.Module):
def __init__(self, inp, oup, image_size, downsample=False, expansion=4):
super().__init__()
self.downsample = downsample
stride = 1 if self.downsample == False else 2
hidden_dim = int(inp * expansion)
if self.downsample:
self.pool = nn.MaxPool2d(3, 2, 1)
self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
if expansion == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
else:
self.conv = nn.Sequential(
# pw
# down-sample in the first conv
nn.Conv2d(inp, hidden_dim, 1, stride, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
SE(inp, hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
self.conv = PreNorm(inp, self.conv, nn.BatchNorm2d)
def forward(self, x):
if self.downsample:
return self.proj(self.pool(x)) + self.conv(x)
else:
return x + self.conv(x)
主要關(guān)注Attention Block設(shè)計(jì)抢呆,引入Relative Position:
class Attention(nn.Module):
def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == inp)
self.ih, self.iw = image_size
self.heads = heads
self.scale = dim_head ** -0.5
# parameter table of relative position bias
self.relative_bias_table = nn.Parameter(
torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))
coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
coords = torch.flatten(torch.stack(coords), 1)
relative_coords = coords[:, :, None] - coords[:, None, :]
relative_coords[0] += self.ih - 1
relative_coords[1] += self.iw - 1
relative_coords[0] *= 2 * self.iw - 1
relative_coords = rearrange(relative_coords, 'c h w -> h w c')
relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
self.register_buffer("relative_index", relative_index)
self.attend = nn.Softmax(dim=-1)
self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, oup),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(
t, 'b n (h d) -> b h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# Use "gather" for more efficiency on GPUs
relative_bias = self.relative_bias_table.gather(
0, self.relative_index.repeat(1, self.heads))
relative_bias = rearrange(
relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
dots = dots + relative_bias
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out