基于detectron2實(shí)現(xiàn)的PVT開(kāi)源了,歡迎star:https://github.com/xiaohu2015/pvt_detectron2
自從ViT之后衍锚,關(guān)于vision transformer的研究呈井噴式爆發(fā)友题,從思路上分主要沿著兩大個(gè)方向,一是提升ViT在圖像分類(lèi)的效果戴质;二就是將ViT應(yīng)用在其它圖像任務(wù)中度宦,比如分割和檢測(cè)任務(wù)上,這里介紹的PVT(Pyramid Vision Transformer) 就屬于后者告匠。PVT相比ViT引入了和CNN類(lèi)似的金字塔結(jié)構(gòu)戈抄,使得PVT像CNN那樣作為backbone應(yīng)用在dense prediction任務(wù)(分割和檢測(cè)等)。
CNN結(jié)構(gòu)常用的是一種金字塔架構(gòu)后专,如上圖所示划鸽,CNN網(wǎng)絡(luò)一般可以劃分為不同的stage,在每個(gè)stage開(kāi)始時(shí)戚哎,特征圖的長(zhǎng)和寬均減半漾稀,而特征維度(channel)擴(kuò)寬2倍。這主要有兩個(gè)方面的考慮建瘫,一是采用stride=2的卷積或者池化層對(duì)特征降維可以增大感受野崭捍,另外也可以減少計(jì)算量,但同時(shí)空間上的損失用channel維度的增加來(lái)彌補(bǔ)啰脚。但是ViT本身就是全局感受野殷蛇,所以ViT就比較簡(jiǎn)單直接了,直接將輸入圖像tokens化后就不斷堆積相同的transformer encoders橄浓,這應(yīng)用在圖像分類(lèi)上是沒(méi)有太大的問(wèn)題粒梦。但是如果應(yīng)用在密集任務(wù)上,會(huì)遇到問(wèn)題:一是分割和檢測(cè)往往需要較大的分辨率輸入荸实,當(dāng)輸入圖像增大時(shí)匀们,ViT的計(jì)算量會(huì)急劇上升;二是ViT直接采用較大patchs進(jìn)行token化准给,如采用16x16大小那么得到的粗粒度特征泄朴,對(duì)密集任務(wù)來(lái)說(shuō)損失較大。這正是PVT想要解決的問(wèn)題露氮,PVT采用和CNN類(lèi)似的架構(gòu)祖灰,將網(wǎng)絡(luò)分成不同的stages,每個(gè)stage相比之前的stage特征圖的維度是減半的畔规,這意味著tokens數(shù)量減少4倍局扶,具體結(jié)構(gòu)如下:
每個(gè)stage的輸入都是一個(gè)維度的3-D特征圖,對(duì)于第1個(gè)stage,輸入就是RGB圖像三妈,對(duì)于其它stage可以將tokens重新reshape成3-D特征圖畜埋。在每個(gè)stage開(kāi)始,首先像ViT一樣對(duì)輸入圖像進(jìn)行token化畴蒲,即進(jìn)行patch embedding由捎,patch大小均采用2x2大小(第1個(gè)stage的patch大小是4x4)饿凛,這意味著該stage最終得到的特征圖維度是減半的狞玛,tokens數(shù)量對(duì)應(yīng)減少4倍。PVT共4個(gè)stage涧窒,這和ResNet類(lèi)似心肪,4個(gè)stage得到的特征圖相比原圖大小分別是1/4,1/8纠吴,1/16和1/32硬鞍。由于不同的stage的tokens數(shù)量不一樣,所以每個(gè)stage采用不同的position embeddings戴已,在patch embed之后加上各自的position embedding固该,當(dāng)輸入圖像大小變化時(shí),position embeddings也可以通過(guò)插值來(lái)自適應(yīng)糖儡。
不同的stage的tokens數(shù)量不同伐坏,越靠前的stage的patchs數(shù)量越多,我們知道self-attention的計(jì)算量與sequence的長(zhǎng)度的平方成正比握联,如果PVT和ViT一樣桦沉,所有的transformer encoders均采用相同的參數(shù),那么計(jì)算量肯定是無(wú)法承受的金闽。PVT為了減少計(jì)算量纯露,不同的stages采用的網(wǎng)絡(luò)參數(shù)是不同的。PVT不同系列的網(wǎng)絡(luò)參數(shù)設(shè)置如下所示代芜,這里為patch的size埠褪,為特征維度大小,為MHA(multi-head attention)的heads數(shù)量挤庇,為FFN的擴(kuò)展系數(shù)钞速,transformer中默認(rèn)為4。
可以見(jiàn)到隨著stage罚随,特征的維度是逐漸增加的玉工,比如stage1的特征維度只有64羽资,而stage4的特征維度為512淘菩,這種設(shè)置和常規(guī)的CNN網(wǎng)絡(luò)設(shè)置是類(lèi)似的,所以前面stage的patchs數(shù)量雖然大,但是特征維度小潮改,所以計(jì)算量也不是太大狭郑。不同體量的PVT其差異主要體現(xiàn)在各個(gè)stage的transformer encoder的數(shù)量差異。
PVT為了進(jìn)一步減少計(jì)算量汇在,將常規(guī)的multi-head attention (MHA)用spatial-reduction attention (SRA)來(lái)替換翰萨。SRA的核心是減少attention層的key和value對(duì)的數(shù)量,常規(guī)的MHA在attention層計(jì)算時(shí)key和value對(duì)的數(shù)量為sequence的長(zhǎng)度糕殉,但是SRA將其降低為原來(lái)的亩鬼。SRA的具體結(jié)構(gòu)如下所示:
在實(shí)現(xiàn)上,首先將維度為的patch embeddings通過(guò)reshape變換到維度為的3-D特征圖阿蝶,然后均分大小為的patchs雳锋,每個(gè)patchs通過(guò)線性變換將得到維度為的patch embeddings(這里實(shí)現(xiàn)上其實(shí)和patch emb操作類(lèi)似,等價(jià)于一個(gè)卷積操作)羡洁,最后應(yīng)用一個(gè)layer norm層玷过,這樣就可以大大降低K和V的數(shù)量。具體實(shí)現(xiàn)代碼如下:
class?Attention(nn.Module):
????def?__init__(self,?dim,?num_heads=8,?qkv_bias=False,?qk_scale=None,?attn_drop=0.,?proj_drop=0.,?sr_ratio=1):
????????super().__init__()
????????assert?dim?%?num_heads?==?0,?f"dim?{dim}?should?be?divided?by?num_heads?{num_heads}."
????????self.dim?=?dim
????????self.num_heads?=?num_heads
????????head_dim?=?dim?//?num_heads
????????self.scale?=?qk_scale?or?head_dim?**?-0.5
????????self.q?=?nn.Linear(dim,?dim,?bias=qkv_bias)
????????self.kv?=?nn.Linear(dim,?dim?*?2,?bias=qkv_bias)
????????self.attn_drop?=?nn.Dropout(attn_drop)
????????self.proj?=?nn.Linear(dim,?dim)
????????self.proj_drop?=?nn.Dropout(proj_drop)
????????self.sr_ratio?=?sr_ratio
????????#?實(shí)現(xiàn)上這里等價(jià)于一個(gè)卷積層
????????if?sr_ratio?>?1:
????????????self.sr?=?nn.Conv2d(dim,?dim,?kernel_size=sr_ratio,?stride=sr_ratio)
????????????self.norm?=?nn.LayerNorm(dim)
????def?forward(self,?x,?H,?W):
????????B,?N,?C?=?x.shape
????????q?=?self.q(x).reshape(B,?N,?self.num_heads,?C?//?self.num_heads).permute(0,?2,?1,?3)
????????if?self.sr_ratio?>?1:
????????????x_?=?x.permute(0,?2,?1).reshape(B,?C,?H,?W)
????????????x_?=?self.sr(x_).reshape(B,?C,?-1).permute(0,?2,?1)?#?這里x_.shape?=?(B,?N/R^2,?C)
????????????x_?=?self.norm(x_)
????????????kv?=?self.kv(x_).reshape(B,?-1,?2,?self.num_heads,?C?//?self.num_heads).permute(2,?0,?3,?1,?4)
????????else:
????????????kv?=?self.kv(x).reshape(B,?-1,?2,?self.num_heads,?C?//?self.num_heads).permute(2,?0,?3,?1,?4)
????????k,?v?=?kv[0],?kv[1]
????????attn?=?(q?@?k.transpose(-2,?-1))?*?self.scale
????????attn?=?attn.softmax(dim=-1)
????????attn?=?self.attn_drop(attn)
????????x?=?(attn?@?v).transpose(1,?2).reshape(B,?N,?C)
????????x?=?self.proj(x)
????????x?=?self.proj_drop(x)
????????return?x
從PVT的網(wǎng)絡(luò)設(shè)置上筑煮,前面的stage的取較大的值辛蚊,比如stage1的,說(shuō)明這里直接將Q和V的數(shù)量直接減為原來(lái)的1/64真仲,這個(gè)就大大降低計(jì)算量了袋马。
PVT具體到圖像分類(lèi)任務(wù)上,和ViT一樣也通過(guò)引入一個(gè)class token來(lái)實(shí)現(xiàn)最后的分類(lèi)秸应,不過(guò)PVT是在最后的一個(gè)stage才引入:
????def?forward_features(self,?x):
????????B?=?x.shape[0]
????????#?stage?1
????????x,?(H,?W)?=?self.patch_embed1(x)
????????x?=?x?+?self.pos_embed1
????????x?=?self.pos_drop1(x)
????????for?blk?in?self.block1:
????????????x?=?blk(x,?H,?W)
????????x?=?x.reshape(B,?H,?W,?-1).permute(0,?3,?1,?2).contiguous()
????????#?stage?2
????????x,?(H,?W)?=?self.patch_embed2(x)
????????x?=?x?+?self.pos_embed2
????????x?=?self.pos_drop2(x)
????????for?blk?in?self.block2:
????????????x?=?blk(x,?H,?W)
????????x?=?x.reshape(B,?H,?W,?-1).permute(0,?3,?1,?2).contiguous()
????????#?stage?3
????????x,?(H,?W)?=?self.patch_embed3(x)
????????x?=?x?+?self.pos_embed3
????????x?=?self.pos_drop3(x)
????????for?blk?in?self.block3:
????????????x?=?blk(x,?H,?W)
????????x?=?x.reshape(B,?H,?W,?-1).permute(0,?3,?1,?2).contiguous()
????????#?stage?4
????????x,?(H,?W)?=?self.patch_embed4(x)
????????cls_tokens?=?self.cls_token.expand(B,?-1,?-1)?#?引入class?token
????????x?=?torch.cat((cls_tokens,?x),?dim=1)
????????x?=?x?+?self.pos_embed4
????????x?=?self.pos_drop4(x)
????????for?blk?in?self.block4:
????????????x?=?blk(x,?H,?W)
????????x?=?self.norm(x)
????????return?x[:,?0]
具體到分類(lèi)任務(wù)上飞蛹,PVT在ImageNet上的Top-1 Acc其實(shí)是和ViT差不多的。其實(shí)PVT最重要的應(yīng)用是作為dense任務(wù)如分割和檢測(cè)的backbone灸眼,一方面PVT通過(guò)一些巧妙的設(shè)計(jì)使得對(duì)于分辨率較大的輸入圖像卧檐,其模型計(jì)算量不像ViT那么大,論文中比較了ViT-Small/16 焰宣,ViT-Small霉囚,PVT-Small和ResNet50四種網(wǎng)絡(luò)在不同的輸入scale下的GFLOPs,可以看到PVT相比ViT要好不少匕积,當(dāng)輸入scale=640時(shí)盈罐,PVT-Small和ResNet50的計(jì)算量是類(lèi)似的,但是如果到更大的scale闪唆,PVT的增長(zhǎng)速度就遠(yuǎn)超過(guò)ResNet50了盅粪。
PVT的另外一個(gè)相比ViT的優(yōu)勢(shì)就是其可以輸出不同scale的特征圖,這對(duì)于分割和檢測(cè)都是非常重要的悄蕾。因?yàn)槟壳按蟛糠值姆指詈蜋z測(cè)模型都是采用FPN結(jié)構(gòu)票顾,而PVT這個(gè)特性可以使其作為替代CNN的backbone而無(wú)縫對(duì)接分割和檢測(cè)的heads础浮。論文中做了大量的關(guān)于檢測(cè),語(yǔ)義分割以及實(shí)例分割的實(shí)驗(yàn)奠骄,可以看到PVT在dense任務(wù)的優(yōu)勢(shì)豆同。比如,在更少的推理時(shí)間內(nèi)含鳞,基于PVT-Small的RetinaNet比基于R50的RetinaNet在COCO上的AP值更高(38.7 vs. 36.3)影锈,雖然繼續(xù)增加scale可以提升效果,但是就需要額外的推理時(shí)間:
所以雖然PVT可以解決一部分問(wèn)題蝉绷,但是如果輸入圖像分辨率特別大鸭廷,可能基于CNN的方案還是最優(yōu)的。另外曠視最新的一篇論文YOLOF指出其實(shí)ResNet一個(gè)C5特征加上一些增大感受野的模塊就可以在檢測(cè)上實(shí)現(xiàn)類(lèi)似的效果熔吗,這不得不讓人思考多尺度特征是不是必須的靴姿,而且transformer encoder本身就是全局感受野的。近期Intel提出的DPT直接在ViT模型的基礎(chǔ)上通過(guò)Reassembles operation來(lái)得到不同scale的特征圖以用于dense任務(wù)磁滚,并在ADE20K語(yǔ)義分割數(shù)據(jù)集上達(dá)到新的SOTA(mIoU 49.02)佛吓。而在近日,微軟提出的Swin Transformer和PVT的網(wǎng)絡(luò)架構(gòu)和很類(lèi)似垂攘,但其性能在各個(gè)檢測(cè)和分割數(shù)據(jù)集上效果達(dá)到SOTA(在ADE20K語(yǔ)義分割數(shù)據(jù)集mIoU 53.5)维雇,其核心提出了一種shifted window方法來(lái)減少self-attention的計(jì)算量。
相信未來(lái)會(huì)有更好的work晒他!期待吱型!
參考
- Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions
- whai362/PVT
- 大白話(huà)Pyramid Vision Transformer
- You Only Look One-level Feature
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
- Vision Transformers for Dense Prediction
推薦閱讀
谷歌提出Meta Pseudo Labels,刷新ImageNet上的SOTA陨仅!
大道至簡(jiǎn)!深度解讀CVPR2021論文RepVGG灼伤!
漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA
SWA:讓你的目標(biāo)檢測(cè)模型無(wú)痛漲點(diǎn)1% AP
CondInst:性能和速度均超越Mask RCNN的實(shí)例分割模型
centerX: 用新的視角的方式打開(kāi)CenterNet
mmdetection最小復(fù)刻版(十一):概率Anchor分配機(jī)制PAA深入分析
MMDetection新版本V2.7發(fā)布,支持DETR狐赡,還有YOLOV4在路上!
TF Object Detection 終于支持TF2了撞鹉!
無(wú)需tricks,知識(shí)蒸餾提升ResNet50在ImageNet上準(zhǔn)確度至80%+
不妨試試MoCo颖侄,來(lái)替換ImageNet上pretrain模型鸟雏!