【GiantPandaCV Introduction】This work from SenseTime and Nanyang Technological University also uses convolutions to strengthen a Transformer's ability to extract low-level features and to capture locality. Its core contribution is the LCA module, which captures feature representations from multiple layers.
Introduction
Previous Transformer architectures need large amounts of extra data, or extra supervision such as DeiT's distillation, to reach performance comparable to convolutional networks. To overcome this shortcoming, the authors bring CNN components into the Transformer and propose CeiT:
(1) An Image-to-Tokens (I2T) module is designed to obtain the patch embeddings from low-level features.
(2) The Feed-Forward module in the Transformer is replaced with a Locally-enhanced Feed-Forward (LeFF) module, which strengthens the correlation between neighbouring tokens.
(3) Layer-wise Class token Attention (LCA) is used to capture multi-level feature representations.
With these modifications, the model improves in efficiency and generalization, and its convergence is also better, as shown in the figure below:
Method
1. Image-to-Tokens
A convolution-plus-pooling stem replaces the original ViT's direct tokenization of the image into large 16x16 patches.
2. LeFF
The patch tokens are first reshaped back into a feature map, a depth-wise convolution injects locality, and a Linear layer then projects the result back to tokens.
3. LCA
The first two changes are fairly conventional; the third is the distinctive one: a Layer-wise Class-token Attention applied after all the Transformer layers, as shown in the figure below:
The LCA module takes the class tokens produced by every Transformer block as input and runs one MSA + FFN on top of them to obtain the final logits. The authors argue that this captures representations from multiple levels.
Experiments
Comparison with SOTA:
I2T ablation study:
LeFF ablation study:
LCA effectiveness comparison:
Convergence speed comparison:
Code
Module 1: I2T (Image-to-Tokens)
# I2T: snippet from the CeiT model's __init__ (uses Rearrange from einops.layers.torch)
self.conv = nn.Sequential(
    nn.Conv2d(in_channels, out_channels, conv_kernel, stride, 4),   # low-level convolutional stem
    nn.BatchNorm2d(out_channels),
    nn.MaxPool2d(pool_kernel, stride)
)
feature_size = image_size // 4    # conv stride x pool stride shrinks the resolution by a factor of 4
assert feature_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (feature_size // patch_size) ** 2
patch_dim = out_channels * patch_size ** 2
self.to_patch_embedding = nn.Sequential(
    # split the feature map into patch_size x patch_size patches and flatten each one
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
    nn.Linear(patch_dim, dim),    # project flattened patches to the token dimension
)
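To verify the tensor shapes, the fragment above can be wrapped into a standalone module. The hyperparameters below (out_channels=32, conv_kernel=7, pool_kernel=3, stride=2, patch_size=4, dim=192) are illustrative assumptions chosen so that the total downsampling matches feature_size = image_size // 4; they are not necessarily the paper's exact settings.

import torch
from torch import nn
from einops.layers.torch import Rearrange

class I2T(nn.Module):
    # standalone version of the snippet above; hyperparameter values are illustrative
    def __init__(self, image_size = 224, in_channels = 3, out_channels = 32,
                 conv_kernel = 7, stride = 2, pool_kernel = 3, patch_size = 4, dim = 192):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, conv_kernel, stride, 4),
            nn.BatchNorm2d(out_channels),
            nn.MaxPool2d(pool_kernel, stride)
        )
        patch_dim = out_channels * patch_size ** 2
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
        )

    def forward(self, img):
        return self.to_patch_embedding(self.conv(img))

tokens = I2T()(torch.randn(2, 3, 224, 224))
print(tokens.shape)   # torch.Size([2, 196, 192]): a 14 x 14 grid of tokens of dimension 192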
Module 2: LeFF
import torch
from torch import nn
from einops.layers.torch import Rearrange

class LeFF(nn.Module):
    def __init__(self, dim = 192, scale = 4, depth_kernel = 3):
        super().__init__()
        scale_dim = dim * scale
        # expand the channel dimension, then fold the token sequence back into a 14 x 14 feature map
        self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim),
                                     Rearrange('b n c -> b c n'),
                                     nn.BatchNorm1d(scale_dim),
                                     nn.GELU(),
                                     Rearrange('b c (h w) -> b c h w', h = 14, w = 14)
                                     )
        # depth-wise convolution adds locality between neighbouring tokens
        self.depth_conv = nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size = depth_kernel, padding = 1, groups = scale_dim, bias = False),
                                        nn.BatchNorm2d(scale_dim),
                                        nn.GELU(),
                                        Rearrange('b c h w -> b (h w) c', h = 14, w = 14)
                                        )
        # project back to the original token dimension
        self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim),
                                       Rearrange('b n c -> b c n'),
                                       nn.BatchNorm1d(dim),
                                       nn.GELU(),
                                       Rearrange('b c n -> b n c')
                                       )

    def forward(self, x):
        x = self.up_proj(x)
        x = self.depth_conv(x)
        x = self.down_proj(x)
        return x
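A quick shape check for LeFF as defined above. Note that h=14 and w=14 are hard-coded, so the input must contain exactly 196 patch tokens (the class token is excluded):

leff = LeFF(dim = 192, scale = 4, depth_kernel = 3)
x = torch.randn(2, 196, 192)     # 196 patch tokens, no class token
print(leff(x).shape)             # torch.Size([2, 196, 192])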
class TransformerLeFF(nn.Module):
    # Residual, PreNorm and Attention are helpers from the referenced repository; a minimal sketch is given below
    def __init__(self, dim, depth, heads, dim_head, scale = 4, depth_kernel = 3, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, LeFF(dim, scale, depth_kernel)))
            ]))

    def forward(self, x):
        c = list()
        for attn, leff in self.layers:
            x = attn(x)
            cls_tokens = x[:, 0]          # keep this layer's class token aside for LCA
            c.append(cls_tokens)
            x = leff(x[:, 1:])            # LeFF only processes the patch tokens
            x = torch.cat((cls_tokens.unsqueeze(1), x), dim = 1)
        return x, torch.stack(c).transpose(0, 1)   # tokens and the per-layer class tokens (b, depth, dim)
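The snippets above and below rely on Residual, PreNorm, Attention and FeedForward helpers that the referenced repository defines but which are not reproduced in this post. Below is a minimal sketch of them following the widespread ViT-in-PyTorch pattern (Attention is ordinary multi-head self-attention, i.e. LCAttention below without the restriction of querying only the last token); treat it as an assumption about those helpers rather than the repository's exact code.

from torch import nn, einsum
from einops import rearrange

class Residual(nn.Module):
    # y = fn(x) + x
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    # y = fn(LayerNorm(x))
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    # standard two-layer MLP with GELU
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    # standard multi-head self-attention: every token attends to every token
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))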
Module 3: LCA
class LCAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        q = q[:, :, -1, :].unsqueeze(2)   # only the last layer's class token is used as the query
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
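A shape check: feeding LCAttention the stacked class tokens of, say, 12 layers (an assumed depth) returns a single refined token, because only the last layer's class token acts as the query:

lc_attn = LCAttention(dim = 192, heads = 8, dim_head = 64)
cls_tokens = torch.randn(2, 12, 192)   # class tokens collected from 12 Transformer blocks
print(lc_attn(cls_tokens).shape)       # torch.Size([2, 1, 192])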
class LCA(nn.Module):
    # The Residual wrapper is removed here: the paper does not explicitly mention using a residual connection,
    # although the code would also work with one. PreNorm and FeedForward are the helpers sketched above.
    def __init__(self, dim, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layers.append(nn.ModuleList([
            PreNorm(dim, LCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
            PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
        ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x[:, -1].unsqueeze(1)   # attend over all class tokens, add the last one back
            x = x[:, -1].unsqueeze(1) + ff(x)
        return x
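Putting the modules together, here is a hedged end-to-end sketch of the classification head described in the Method section. The depth, head count and the final nn.Linear classifier are illustrative choices rather than the paper's official configuration, and the example assumes the helper classes sketched after TransformerLeFF:

depth, dim, num_classes = 12, 192, 1000
transformer = TransformerLeFF(dim = dim, depth = depth, heads = 3, dim_head = 64)
lca = LCA(dim = dim, heads = 3, dim_head = 64, mlp_dim = dim * 4)
head = nn.Linear(dim, num_classes)          # illustrative classifier head

tokens = torch.randn(2, 197, dim)           # 196 patch tokens from I2T plus one class token
x, cls_per_layer = transformer(tokens)      # cls_per_layer: (2, depth, dim)
final_cls = lca(cls_per_layer)              # (2, 1, dim)
logits = head(final_cls.squeeze(1))         # (2, num_classes)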
References
https://arxiv.org/abs/2103.11816
https://github.com/rishikksh20/CeiT-pytorch/blob/master/ceit.py