Vision Transformer階段性總結(jié) (2021.10)

0. 前言

近兩年學(xué)術(shù)界對(duì)Transformer在CV上的應(yīng)用可謂異常青睞,這里重點(diǎn)強(qiáng)調(diào)學(xué)術(shù)界的原因是目前工業(yè)界還是比較冷靜的(部分公司已經(jīng)開始考慮Vision Trransformer的落地了)酪穿,畢竟新方法從學(xué)術(shù)界到工業(yè)界落地一般都會(huì)晚幾年。在當(dāng)前背景下衰琐,個(gè)人對(duì)Vision Transformer進(jìn)行了較淺的調(diào)研宏娄,本文便是該工作的階段性總結(jié)毙石。本文由三部分組成:

  • 第一部分結(jié)合代碼介紹了Transformer的基本概念和原理,由于我們主要關(guān)注的是Transformer在CV上的應(yīng)用颓遏,所以不會(huì)涉及過多關(guān)于NLP相關(guān)的細(xì)節(jié)徐矩。
  • 第二部分會(huì)介紹幾個(gè)自己驗(yàn)證過效果較好的Vision Transformer方法。
  • 第三部分是一些個(gè)人對(duì)Vision Transformer的看法叁幢,部分觀點(diǎn)已有實(shí)驗(yàn)驗(yàn)證滤灯,但因?yàn)樯婕肮ぷ鲀?nèi)容,所以不便貼出來曼玩,另外一些觀點(diǎn)只是個(gè)人的假設(shè)鳞骤,并沒有經(jīng)過驗(yàn)證。

由于時(shí)間倉促以及本人能力有限黍判,無法完全保證文中沒有錯(cuò)誤之處豫尽,歡迎指正。

1. Transformer

本部分我們以翻譯任務(wù)為例介紹Transformer顷帖。假設(shè)我們有一個(gè)法語句子需要翻譯為英語美旧,即將"Je suis étudiant"翻譯為"I am a student"(為了配合使用別人的圖片:)),所以使用這個(gè)例子)贬墩。如果不關(guān)注細(xì)節(jié)的話榴嗅,我們可以將翻譯模型視為黑箱,翻譯過程就如下圖所示陶舞。

1.1 Transformer之前的Seq2Seq模型

現(xiàn)在我們來探究在Transformer出現(xiàn)之前上圖中黑箱的大致細(xì)節(jié)和原理嗽测,從而引入一些常用的概念(用黑體標(biāo)出)。

一個(gè)語言句子其實(shí)就是一些詞(word)的序列(Sequence)吊说,而翻譯的過程其實(shí)就是將源語言的詞序列( x_1, \dots ,x_n)翻譯為目標(biāo)語言的詞序列( y_1, \dots ,y_m)论咏,即Sequence to Sequence,簡(jiǎn)稱Seq2Seq颁井。Transformer之前主流Seq2Seq神經(jīng)網(wǎng)絡(luò)模型是RNN以及一些變種厅贪,主要有LSTM、GRU(GRU是LSTM的變種)等雅宾。下面簡(jiǎn)單了解下RNN和LSTM养涮。

我們?nèi)祟愒诜g一個(gè)句子的時(shí)候并不是看到一個(gè)詞就將其立即翻譯成目標(biāo)語言對(duì)應(yīng)的一個(gè)單詞,并且忘掉之前讀過的單詞眉抬,而是通讀一遍句子并理解句子的含義贯吓,然后將含義再翻譯為目標(biāo)語言。用計(jì)算機(jī)語言描述就是先將輸入的源詞序列編碼(Encoding)為一個(gè)表示含義的張量(tensor蜀变,一組數(shù)字)悄谐,其實(shí)主要是對(duì)每個(gè)詞進(jìn)行編碼,然后再將編碼的張量解碼(Decoding)為目標(biāo)語言的詞序列作為輸出库北。如下圖所示爬舰,整個(gè)模型由兩部分組成们陆,編碼器(Encoder)和解碼器(Decoder)。

上圖只是直觀的展示情屹,其實(shí)在訓(xùn)練和推理(翻譯)時(shí)是無法直接將原始的語言句子輸入給模型進(jìn)行計(jì)算的坪仇,因?yàn)橛?jì)算機(jī)本質(zhì)上只支持?jǐn)?shù)字運(yùn)算,所以需要將序列中的每個(gè)詞處理成數(shù)字才行垃你,常用的方法是word2vec椅文,即將每個(gè)單詞用唯一的一個(gè)張量進(jìn)行表示,輸入序列就是很多個(gè)張量的列表惜颇,模型的輸出也是一個(gè)元素為張量的列表皆刺,其中每個(gè)張量表示目標(biāo)語言的一個(gè)詞,下圖為源語言句子用word2vec處理后的張量(x_1, x_2, x_3)官还。

但是芹橡,早期的翻譯模型僅僅能夠做到簡(jiǎn)單地將源序列的一個(gè)或多個(gè)單詞映射為目標(biāo)語言的一個(gè)或多個(gè)單詞,句子如果簡(jiǎn)單的話望伦,這種方法還是比較有效的,但是對(duì)于一些大長句就比較吃力了煎殷,尤其是無法解決語言翻譯中的一個(gè)典型問題屯伞,指代消除。指代消除是指將句子中表示同一個(gè)對(duì)象的詞關(guān)聯(lián)起來豪直。下面是一個(gè)需要進(jìn)行指代消除的句子:

The animal didn’t cross the street because it was too tired.

句子中的it是一個(gè)代名詞劣摇,具體指的是什么?我們需要結(jié)合句子中其他的詞才能知道弓乙。將it替代為animal便是指代消除末融。為了更好地編碼序列中某個(gè)位置的詞而需要關(guān)注序列中其他位置的詞,即考慮上下文信息暇韧。RNN(Recurrent Neural Networks)主要就是針對(duì)像語言這種長度可變且序列中存在依賴關(guān)系的任務(wù)設(shè)計(jì)的勾习。下圖左側(cè)是簡(jiǎn)略的RNN網(wǎng)絡(luò)結(jié)構(gòu),結(jié)構(gòu)中具有的循環(huán)機(jī)制(圖中指向A自己的箭頭懈玻,這也是名稱中R的來源)使得模型在編碼后續(xù)詞的時(shí)候可以關(guān)注之前詞的編碼信息巧婶,好像具有了記憶功能一樣。訓(xùn)練和推理的時(shí)候都是將序列中的單詞逐個(gè)輸入到模型中進(jìn)行編碼和解碼涂乌,序列中所有詞都是共享同一模型參數(shù)艺栈。下圖右側(cè)是按時(shí)間序列展開的結(jié)構(gòu),x_t為序列中第t個(gè)時(shí)間步的輸入湾盒,也就是序列中第t個(gè)位置的詞張量, A表示模型(可學(xué)習(xí)參數(shù))湿右,h_t表示第t個(gè)輸出,即目標(biāo)語言序列中第t個(gè)位置的詞張量罚勾。

下圖是標(biāo)準(zhǔn)RNN模型的大致內(nèi)部結(jié)構(gòu)毅人,這里不進(jìn)行細(xì)致探討:

RNN存在一個(gè)嚴(yán)重問題是無法捕獲序列中相隔較遠(yuǎn)的詞之間的依賴關(guān)系吭狡,即所謂的長程依賴問題。另外堰塌,輸入序列中的不同詞一般具有不同的重要程度赵刑,而且輸出序列中不同位置的詞有時(shí)需要考慮輸入中的不同位置的詞才能更好地進(jìn)行解碼。例如场刑,在翻譯任務(wù)中般此,輸出的第一個(gè)單詞一般是基于輸入的前幾個(gè)詞確定的,輸出的最后幾個(gè)詞可能基于輸入的最后幾個(gè)詞牵现。注意力機(jī)制(下一節(jié)會(huì)詳細(xì)說明)的引入為解碼器提供了在每個(gè)解碼時(shí)間步上查看整個(gè)輸入序列的能力铐懊,而且解碼器可以在任何時(shí)間步?jīng)Q定哪些輸入單詞是重要的。LSTM(Long Short Term Memory)便是為了解決長程依賴問題而提出的瞎疼,其后續(xù)改進(jìn)版引入了注意力機(jī)制科乎,使得RNN模型的表現(xiàn)不斷提升,下圖是原始LSTM的大致結(jié)構(gòu)贼急,其中添加了記憶篩選機(jī)制茅茂,使得模型在后續(xù)編解碼中有選擇地記住重要詞的編碼信息,并忘掉不重要詞的編碼信息太抓,這也是其名稱的由來空闲。

下圖是增加了注意力機(jī)制的LSTM模型的改進(jìn)版本,可以看到預(yù)測(cè)\hat y_2(圖中的“hit”)的解碼時(shí)不僅依賴于解碼器前一個(gè)時(shí)間步的輸出(圖中的"<START>"和"he")和編碼器的編碼信息走敌,還使用到了由輸入序列中每個(gè)詞的編碼信息計(jì)算得到的注意力輸出(圖中的"Attention output")碴倾。"<START>"詞為序列開始解碼的指示符,因?yàn)镈ecoder的第一次預(yù)測(cè)時(shí)沒有來自前一個(gè)時(shí)間步的輸出詞作為輸入掉丽,便用特殊的詞代替跌榔,類似于占位符。

雖然LSTM已經(jīng)很優(yōu)秀了捶障,但是還是有明顯的缺陷:無法進(jìn)行并行訓(xùn)練僧须。由于語言數(shù)據(jù)本身是不定長的,RNN恰恰就是設(shè)計(jì)來處理不定長數(shù)據(jù)任務(wù)的残邀,訓(xùn)練時(shí)都是單個(gè)序列進(jìn)行訓(xùn)練皆辽。有些工作也嘗試解決這個(gè)問題,但是仍然無法從根本上解決芥挣。這個(gè)缺陷大大阻礙了將現(xiàn)代GPU用于加速訓(xùn)練大規(guī)模語言模型驱闷,而這便是Transformer要解決的問題。

1.2 Transformer

如第1.1部分所述空免,Transformer的提出主要是為了解決LSTM無法進(jìn)行并行訓(xùn)練的問題空另,但是仍然沿用Seq2Seq的Encoder-Decoder結(jié)構(gòu),主要的改進(jìn)有兩點(diǎn):

  1. Encoder和Decoder模塊都是由多個(gè)Attention(論文中稱為Multi-Head Attention)和MLP(論文中稱為Feed-Forward Networks)子模塊堆疊組成(暫時(shí)忽略LayerNorm層);
  2. 類似CNN網(wǎng)絡(luò)蹋砚,輸入可以一次接受包含多個(gè)序列的數(shù)據(jù)扼菠,大大加速了模型的訓(xùn)練摄杂。

下面是Transformer的整體網(wǎng)絡(luò)結(jié)構(gòu),我們將依次對(duì)各個(gè)組件內(nèi)部進(jìn)行詳細(xì)探究循榆。

Encoder的輸入(Input Embedding)

與LSTM類似析恢,我們需要將輸入(語言)序列進(jìn)行數(shù)值化,由于Transformer支持一次輸入多個(gè)序列秧饮,而每個(gè)序列的長度不一映挂,為了能夠?qū)⒉煌L度的序列組成一個(gè)batch(為了進(jìn)行并行計(jì)算),我們需要將所有序列的長度進(jìn)行對(duì)齊盗尸,簡(jiǎn)單的做法便是設(shè)置最大序列長度柑船,該最大長度為訓(xùn)練數(shù)據(jù)集中最長序列的長度加1(加1用于放置序列結(jié)束符),其他不足最大長度的序列泼各,多余的部分全部填充為序列結(jié)束符鞍时,例如符號(hào)“<EOS>”表示End of Sentence(實(shí)際中可能用另外的符號(hào),如"<blank>")扣蜻。假設(shè)一個(gè)數(shù)據(jù)集中只有兩個(gè)句子分別為:

Do we really need Attention?
Yes or no?

數(shù)據(jù)集中最大句子長度為6(包括標(biāo)點(diǎn)符號(hào))逆巍,則我們?cè)O(shè)置最大序列長度為7。所有句子長度對(duì)齊后如下莽使。具體實(shí)現(xiàn)上蒸苇,訓(xùn)練時(shí)只需要將一個(gè)batch內(nèi)的長度對(duì)齊便可,補(bǔ)充的多余部分可以在訓(xùn)練時(shí)用mask屏蔽掉,從而不參與訓(xùn)練吮旅,mask方式類似后面要介紹的Masked Multi-Head Attention部分。

下一步我們將序列中的每個(gè)詞進(jìn)行數(shù)值化(又稱詞嵌入, word embedding)味咳,假設(shè)整個(gè)數(shù)據(jù)集中最多有10個(gè)詞(包括標(biāo)點(diǎn)符號(hào))庇勃,則我們可以用one-hot方式將每個(gè)詞表示成唯一的張量,張量的維度為10槽驶,例如用“0000000001”表示“do”责嚷,用“0000000010”表示“we”,其他依次類推掂铐,則對(duì)齊的兩個(gè)句子數(shù)值化后如下(為了方便展示罕拂,對(duì)序列矩陣進(jìn)行了轉(zhuǎn)置),如果batch size=2全陨,則例子中處理后的輸入矩陣大小為[2, 7, 10]爆班,這便是上圖中的Input Embedding。當(dāng)然這個(gè)是最簡(jiǎn)單的策略辱姨,存在的主要問題是Embedding的維度隨著語料庫的詞匯數(shù)量增加柿菩,更常用的是 Word2Vec 方法,這里不再展開雨涛。

Encoder和Decoder

Encoder

我們先從整體上了解下Encoder的結(jié)構(gòu)枢舶,如下圖懦胞,從圖中可以看出Encoder由N個(gè)相同的層堆疊組成(論文中N=6),每個(gè)層又由兩個(gè)子層(sub-layer)組成凉泄,第一個(gè)子層稱為Multi-head Attention躏尉,另一個(gè)子層稱為Feed Forward。每個(gè)子層都引入了殘差連接后众,子層的輸出與其輸入逐元素相加(element-wise addition)胀糜,再進(jìn)行LayerNorm(類似于BatchNorm的層歸一化),作為本層的輸出吼具。Encoder中第一個(gè)層(Multi-head Attention+Feed Forward)的輸入來自模型的輸入僚纷,即經(jīng)過位置編碼(圖中的Position Encoding)的詞嵌入矩陣,其他后續(xù)層的輸入都來自前一層的輸出拗盒。

下面我們一一研究Encoder中的每個(gè)子層怖竭。在研究Multi-Head Attention子層之前,我們需要先弄明白什么是self-attention陡蝇。

自注意力(Self-Attention)

所謂“自”注意力痊臭,個(gè)人理解就是指在沒有人為先驗(yàn)經(jīng)驗(yàn)指導(dǎo)的情況下讓模型自己學(xué)習(xí)掌握一種能力,這種能力能夠建模所需的注意力登夫。至于注意力機(jī)制识椰,論文中給出了非常簡(jiǎn)潔的定義:為了更好地編碼一個(gè)序列的表征(representions)牲距,將序列中的不同位置關(guān)聯(lián)起來的一種機(jī)制,類似于我們?cè)诒鎰e照片中一個(gè)不清晰物體時(shí)會(huì)借助周圍其他物體。

常用的關(guān)聯(lián)不同位置信息的方式便是加權(quán)求和榜旦。具體地,就是先對(duì)序列中每個(gè)位置進(jìn)行獨(dú)立(個(gè)人使用的非正式說法)編碼环肘,然后將不同位置的編碼進(jìn)行加權(quán)求和勿璃,得到包含注意力能力的編碼信息,這里的權(quán)重便是模型需要學(xué)習(xí)的參數(shù)(self的來源)狮斗。這種機(jī)制其實(shí)在CV中也已經(jīng)有比較成熟的應(yīng)用绽乔,比如2018年提出SENet中使用Squeeze-and-Excitation便是一種通道注意力機(jī)制,可以認(rèn)為一個(gè)通道特征就是序列中的一個(gè)位置的編碼碳褒,如下圖所示折砸,模型通過學(xué)習(xí)自注意力函數(shù)F_{ex}(\cdot;W) ,該函數(shù)接受壓縮后的通道編碼信息沙峻,產(chǎn)出每個(gè)通道的權(quán)重值睦授,權(quán)重?cái)?shù)量與原始特征通道數(shù)量相同,這樣便可以將每個(gè)權(quán)重與原始通道進(jìn)行相乘专酗,這里相乘后并沒有多通道相加睹逃,但可以認(rèn)為是單通上的加權(quán)。

我們可以用簡(jiǎn)單的公式表示這種加權(quán)的注意力機(jī)制:

A_i = w_i V = \sum_{j=0}^{N} w_{ij} * v_j

其中,A_i表示得是第i個(gè)位置詞包含注意力的編碼信息沉填,{w_i} \in R^{N}是模型學(xué)習(xí)到的權(quán)重疗隶,每個(gè)位置對(duì)應(yīng)一個(gè),V \in R^{N * M}表示所有位置的獨(dú)立編碼信息翼闹,{v_j} \in R^{M}表示第j個(gè)位置的獨(dú)立編碼信息斑鼻,M為編碼信息的長度,N為序列的長度猎荠,即序列中詞的數(shù)量坚弱。如果要使用矩陣一次計(jì)算所有位置的注意力編碼信息,則可以使用下面的公式:

A = W^{T} V \text{(公式1)}

其中关摇,W \in R^{N*N}荒叶,A \in R^{N*M}

在Transformer中這種自注意力機(jī)制得到了進(jìn)一步的發(fā)展输虱,但主要的思想是不變的些楣,具體地,要進(jìn)行加權(quán)的各個(gè)位置的編碼信息矩陣V是不變的宪睹,變化的是得到權(quán)重W的方式愁茁。如前所述,我們將其他位置的編碼信息以加權(quán)的方式加入到當(dāng)前位置的編碼中的主要目的是序列中多個(gè)位置的詞之間存在一定的相關(guān)性亭病,這種關(guān)系有助于更好地對(duì)當(dāng)前位置的詞進(jìn)行編碼(比如解決長程依賴問題)鹅很。因此,可以讓模型嘗試學(xué)習(xí)對(duì)這種相關(guān)性進(jìn)行編碼(這是一種1-N的關(guān)系罪帖,包括當(dāng)前詞自己與自己的相關(guān)性)促煮,然后將這種相關(guān)性直接轉(zhuǎn)換為加權(quán)值,相關(guān)性越高權(quán)值越大整袁,對(duì)詞最終的編碼的貢獻(xiàn)越大污茵。

Transformer使用類似查字典的方式來顯式地對(duì)這種相關(guān)性進(jìn)行建模。在實(shí)際使用字典的時(shí)候葬项,我們首先得有個(gè)查詢的字(這里稱為query,簡(jiǎn)稱為q),然后使用一定的標(biāo)準(zhǔn)逐一與字典中的關(guān)鍵字(這里稱為key迹蛤,簡(jiǎn)稱為k)進(jìn)行比較民珍,符合比較標(biāo)準(zhǔn)(比如拼音或筆畫相同)的我們就認(rèn)為找到了要查找的內(nèi)容,或者說兩個(gè)字匹配上了盗飒。在Transformer中嚷量,也是顯示地為序列中的每個(gè)詞指定一個(gè)查詢的字{q_i},并且為序列中每個(gè)詞指定一個(gè)關(guān)鍵字{k_i}逆趣,它們都是與獨(dú)立編碼信息{v_i}(這里稱為value)類似的張量蝶溶。到這里,有兩個(gè)問題需要解決,一是如何得到具體的{q}抖所、{k}{v}張量梨州,另一個(gè)是如何計(jì)算{q}{k}的相似性。

第一個(gè)問題田轧,由于我們希望模型自己學(xué)習(xí)到序列中不同{v}的相似性暴匠,最直接的方法就是通過不斷地迭代訓(xùn)練,用監(jiān)督的方式讓模型自己學(xué)習(xí)每個(gè)詞(或稱為token傻粘,對(duì)于更深的層來說每窖,不能再稱為詞)對(duì)應(yīng)的{q_i}{k_i}{v_i}弦悉,這其實(shí)類似于CNN中讓模型自己學(xué)習(xí)特征窒典,而不是人工設(shè)計(jì)特征。簡(jiǎn)單的做法就是各使用一個(gè)可學(xué)習(xí)的矩陣將序列的每個(gè)詞線性映射為{q_i}稽莉,{k_i}{v_i}瀑志,假設(shè)這三個(gè)矩陣分別為W^QW^KW^V肩祥,輸入為{x_i}(行向量后室,表示一個(gè)詞嵌入),則每個(gè)詞對(duì)應(yīng)的{q_i}混狠,{k_i}{v_i}可通過如下計(jì)算得到:

{q_i} = {x_i} W^Q

{k_i} = {x_i} W^K

{v_i} = {x_i} W^V

過程如下圖示例岸霹,圖中假設(shè)輸入{x_i}的維度是4,W^Q \in R^{4*3},W^K \in R^{4*3}将饺,W^V \in R^{4*3}贡避,:

如果使用矩陣一次計(jì)算序列中所有位置的{q}{k}{v}予弧,則可以使用下面的公式:
Q = X W^Q

K = X W^K

V = X W^V

其中假設(shè)輸入序列矩陣X \in R^{N \times d_{model}}刮吧,每個(gè)詞張量為d_{model}維,一共N個(gè)詞掖蛤,W^Q \in R^{d_{model} \times d_{k}},W^K \in R^{d_{model} \times d_{k}}杀捻,W^V \in R^{d_{model} \times d_{v}}d_{model}蚓庭、d_{k}d_{v}是模型超參致讥,后續(xù)我們還會(huì)提到他們,計(jì)算的Q \in R^{N \times d_{k}},K \in R^{N \times d_{k}}, V \in R^{N \times d_{v}}器赞。

過程如下圖示例:


Tranformer論文中作者將分別計(jì)算Q垢袱、K、V的輸入直接用Q港柜、K请契、V表示,如下公式,但其實(shí)Q爽锥、K涌韩、V是同一輸入X的副本。
Q = Q W^Q

K = K W^K

V = V W^V

對(duì)于第二個(gè)問題救恨,相對(duì)比較簡(jiǎn)單了贸辈,兩個(gè)張量相似的話,可以用它們的歐式距離衡量肠槽,距離越小越相關(guān)擎淤,或者先計(jì)算兩個(gè)張量的點(diǎn)積,然后使用sigmoid函數(shù)映射到[0,1]區(qū)間秸仙,越相關(guān)則值越接近1嘴拢。如前所述,每個(gè){q_i}都要與序列中所有的{k_j}計(jì)算相似性寂纪,而且序列中的這種相似性可能不是獨(dú)立的(互斥性)席吴,那么可以讓{q_i}與每個(gè){k_j}先計(jì)算點(diǎn)積,然后在產(chǎn)生的點(diǎn)積張量(一個(gè)點(diǎn)積一個(gè)標(biāo)量捞蛋,多次點(diǎn)積就是一個(gè)多維張量)上執(zhí)行softmax函數(shù)孝冒,點(diǎn)積張量的每個(gè)值都會(huì)映射到[0,1]區(qū)間,即可得到第i個(gè)詞與序列中其他詞之間的相關(guān)性權(quán)重拟杉。計(jì)算的方式如下:

每個(gè)張量獨(dú)立計(jì)算點(diǎn)積:
r_{ij} = {q_i} {k_j}^T

{q_i}與所有的key的點(diǎn)積計(jì)算如下:
{r_i} = {q_i} K^T

使用矩陣一次計(jì)算序列中所有詞之間的點(diǎn)積:
R = Q K^T

r_{ij}表示第i個(gè)詞的query {q_i}與第j個(gè)詞的key {k_j}的點(diǎn)積庄涡。

權(quán)重計(jì)算如下:
{w_i} = softmax({r_i})
其中
{w_{ij}} = \frac{e^{r_{ij}}}{\sum_{k=1}^{k=N} e^{r_{ik}}}

過程如下圖示例,該過程展示了計(jì)算x_1x_1x_2的權(quán)重搬设,結(jié)果分別為0.88和0.12:

使用矩陣一次計(jì)算序列中所有詞的權(quán)重張量:
W = softmax(R) = softmax(Q K^T)

該權(quán)重矩陣便可以用于前面公式1中計(jì)算注意力編碼了:
A = W^{T} V = softmax(Q K^T) V
可以看到穴店,這個(gè)公式與Transformer論文中的注意力計(jì)算公式已經(jīng)很接近了:
A = W^{T} V = softmax(\frac{Q K^T}{\sqrt{d_k}}) V \text{ (公式2)}

過程如下圖示例:


作者增加縮放因子\frac{1}{\sqrt{d_k}}的原因,是當(dāng)QK元素的維度d_k太大時(shí)會(huì)導(dǎo)致softmax函數(shù)處于梯度極小的區(qū)域拿穴,這樣會(huì)導(dǎo)致模型的最終效果較差泣洞,所以增加一個(gè)縮放因子。作者也將這樣的注意力機(jī)制稱為帶縮放的點(diǎn)積注意力(Scaled Dot-Product Attention)默色,模型的示意圖如下(后面會(huì)介紹Mask的作用):

Pytorch代碼實(shí)現(xiàn)如下:

import torch

class ScaledDotProductAttention(nn.Module):
    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = torch.nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        # self.temperature就是縮放因子
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        # dropout是正則化方法
        attn = self.dropout(torch.nn.functional.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn
Multi-Head Attention

Transformer作者發(fā)現(xiàn)球凰,相比使用單一的注意力函數(shù)(即公式2),對(duì)輸入序列X并行執(zhí)行多次線性映射(即計(jì)算不同的Q腿宰、K弟蚀、V)并使用多個(gè)注意力函數(shù),模型效果會(huì)更好酗失。個(gè)人猜測(cè)原因是一個(gè)序列中一個(gè)詞可能與多個(gè)詞的相關(guān)性同樣重要,但是使用了softmax函數(shù)就是顯式地約束相關(guān)性是互斥的昧绣。通過并行計(jì)算多個(gè)注意力函數(shù)规肴,可以大大緩解這個(gè)問題。但是,每增加一個(gè)注意力函數(shù)拖刃,就是顯著增加模型的參數(shù)量和計(jì)算量删壮。作者做法是將原來的Q、K兑牡、V張量維度分為h份央碟,每一部分在一個(gè)獨(dú)立的注意力函數(shù)中進(jìn)行計(jì)算,這樣的話均函,參數(shù)量和計(jì)算量并未增加亿虽,但將單注意力(Single—Head Attention)變成了多注意力(Multi-Head Attention)。例如苞也,在使用Multi-Head Attention之前的Q洛勉、K、V維度為d_k=d_v=d_{model}如迟,使用后每個(gè)Attention中的Q收毫、K、V維度為d_k=d_v=d_{model}/h殷勘,總體上的張量維度是不變的此再,還是d_{model}維。每個(gè)獨(dú)立的Attention分別計(jì)算出一部分注意力編碼后再將它們拼接起來(Concat)玲销,然后使用一個(gè)矩陣進(jìn)行線性變換(融合)输拇,具體公式如下:

其中,W_i^Q \in R^{d_{model} \times d_{k}},W_i^K \in R^{d_{model} \times d_{k}}痒玩,W_i^V \in R^{d_{model} \times d_{v}}淳附,計(jì)算的Q \in R^{N \times d_{model}},K \in R^{N \times d_{model}}, V \in R^{N \times d_{model}}。論文中蠢古,h=8, d_k=d_v=d_{model}/h=64奴曙,整體的示意圖如下。

代碼實(shí)現(xiàn)如下:

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformer.Modules import ScaledDotProductAttention

class MultiHeadAttention(nn.Module):
    # n_head就是上述的h超參
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # 為了實(shí)現(xiàn)高效草讶,不同頭的Q洽糟、K和V的計(jì)算矩陣可以是同一個(gè),這其實(shí)應(yīng)用了矩陣分塊計(jì)算
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        # len_q堕战、len_k坤溃、len_v分別是序列的長度,即詞的數(shù)量
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = 
        # 同時(shí)計(jì)算所有head中的q,k,v嘱丢,然后再分拆薪介,便于在多個(gè)head中分別計(jì)算注意力
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1) 

        # 實(shí)現(xiàn)上看起來是Single-Head, 但是q,k,v的shape比Single-Head的多一維
        # 使用了矩陣運(yùn)算的高效實(shí)現(xiàn)
        q, attn = self.attention(q, k, v, mask=mask)
        
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        # 對(duì)多個(gè)head的注意力編碼進(jìn)行融合
        q = self.dropout(self.fc(q))
        q += residual

        q = self.layer_norm(q)

        return q, attn
Position-wise Feed-Forward Networks

Multi-Head Attention子層輸出的編碼張量與殘差連接的原始張量相加并進(jìn)行LayerNorm后,作為Position-wise Feed-Forward Networks的輸入越驻,一個(gè)由二層全連接網(wǎng)絡(luò)和位于兩者之間的ReLU激活函數(shù)組成的子層汁政,加“Position-wise”的原因是每個(gè)全連接線性變換是應(yīng)用到序列中的每個(gè)詞的道偷,即所有詞共用一組權(quán)重參數(shù),公式表達(dá)如下:

x表示序列中某個(gè)位置“詞”的注意力編碼记劈,x \in R^{d_{model}}勺鸦,W_1 \in R^{d_{model} \times d_{ff}}W_2 \in R^{d_{ff} \times d_{model}}目木。論文中d_{model}=512换途,d_{ff}=2048

和Multi-Head Attention子層一樣刽射,兩個(gè)全連接網(wǎng)絡(luò)產(chǎn)生的編碼張量與殘差連接的原始張量相加军拟,并進(jìn)行LayerNorm,代碼實(shí)現(xiàn)如下:

import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid) 
        self.w_2 = nn.Linear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual
        x = self.layer_norm(x)

        return x
LayerNorm

CV中不同Norm的計(jì)算方式比較:

為什么Transformer使用LayerNorm柄冲,而不是BatchNorm?

Layer normalization is used in the transformer because the statistics of language data exhibit large fluctuations across the batch dimension, and this leads to instability in batch normalization.

image.png

具體實(shí)現(xiàn)上吻谋,NLP和CV中LayerNorm的計(jì)算方式不一樣,NLP中的計(jì)算方式如下圖左側(cè)所示现横,其只在在每個(gè)token的embedding維度上計(jì)算均值和方式漓拾,然后進(jìn)行歸一化:

image.png

Pytorch中的示例代碼:

>>> # NLP Example
>>> batch, sentence_length, embedding_dim = 20, 5, 10
>>> embedding = torch.randn(batch, sentence_length, embedding_dim)
>>> layer_norm = nn.LayerNorm(embedding_dim)
>>> # Activate module
>>> layer_norm(embedding)
>>>
>>> # Image Example
>>> N, C, H, W = 20, 5, 10, 10
>>> input = torch.randn(N, C, H, W)
>>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
>>> # as shown in the image below
>>> layer_norm = nn.LayerNorm([C, H, W])
>>> output = layer_norm(input)
位置編碼(Positional Encoding)

從前面的注意力機(jī)制我們可以看到,Transformer最大的好處就是讓序列中每個(gè)詞的編碼是同時(shí)進(jìn)行的戒祠,且對(duì)序列中每個(gè)詞進(jìn)行編碼時(shí)都會(huì)注意到序列中的其他詞骇两,有利于對(duì)每個(gè)詞進(jìn)行更好地編碼。但是姜盈,序列中詞的相對(duì)位置關(guān)系也是特別重要的低千,在LSTM模型中序列的每個(gè)詞是以先后順序輸入到模型中的,所以天然地具有利用詞之間的位置信息馏颂。Transformer的解決方法是顯式地將位置信息直接附加到詞的embedding張量中示血。所謂直接附加,就是將每個(gè)詞的位置信息直接編碼成與每個(gè)詞相同維度(即d_{model})的張量救拉,與詞的embedding相加难审,如下圖所示。

這里還有個(gè)問題亿絮,那就是每個(gè)詞的位置編碼怎么得到告喊。Transformer的作者指出直接編碼位置信息或者讓模型學(xué)習(xí)每個(gè)詞的位置信息都是可行的,且效果相當(dāng)派昧,但是作者用了sine和cosine函數(shù)來對(duì)每個(gè)詞的位置信息進(jìn)行編碼黔姜,公式如下:

其中pos表示詞在序列中的位置,i表示詞張量中的位置蒂萎。代碼的實(shí)現(xiàn)如下:

import torch
import torch.nn as nn
import numpy as np

class PositionalEncoding(nn.Module):

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()

這樣的位置編碼方式秆吵,使模型很容易學(xué)習(xí)到序列中每個(gè)詞的相對(duì)位置,因?yàn)閷?duì)于任何一個(gè)固定的位置偏移k五慈,PE_{pos+k}可以用PE_{pos}位置的編碼信息進(jìn)行線性表示纳寂,如下公式所示实苞,而且這樣的編碼信息可以使模型的編碼位置擴(kuò)展到訓(xùn)練時(shí)未見過的長度。

以上是Encoder模塊中一個(gè)層的組成烈疚,較詳細(xì)的組成如下圖,為了便于理解聪轿,圖中將一個(gè)序列的x_1x_2分開輸入給Encoder爷肝,串聯(lián)重復(fù)N次就組成了Encoder,論文中N=6陆错〉婆祝總體上,Encoder的編碼輸出維度為 outputs \in R^{B * N * d_{model}}音瓷,B為Batch size对嚼。

其實(shí),如果只是想要了解Transformer的Attention機(jī)制绳慎,并應(yīng)用到CV中纵竖,那么了解整個(gè)Encoder就可以了,因?yàn)榻^大部分的CV任務(wù)(除了類似Image Captrain這樣的任務(wù))從本質(zhì)上來說并不需要用Encoder-Decoder網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行建模杏愤,只需要對(duì)輸入進(jìn)行encoding就行靡砌。所以,目前幾乎所有的(個(gè)人的片面經(jīng)驗(yàn))ViT模型都只利用了Transformer的Encoder部分的思想珊楼。

Decoder

Decoder也是由N個(gè)(論文中是6)相同的層串聯(lián)組成通殃,每個(gè)層除了由分別帶有殘差連接的Multi-Head Attention和Feed Forward子層組成外(與Encoder一樣),在兩個(gè)子層之間還加入了一個(gè)Multi-Head Attention子層厕宗,該子層也帶有殘差連接和LayerNorm画舌。其主要作用是融合Encoder的輸出和Docoder的中間編碼信息,便于更好地進(jìn)行解碼已慢。其中的Feed Forward子層與Encoder中的完全一樣曲聂。整體的結(jié)構(gòu)如下圖所示。

Masked Multi-Head Attention

細(xì)心的讀者肯定已經(jīng)發(fā)現(xiàn)蛇受,在前面的Multi-Head Attention源碼的forward函數(shù)有個(gè)mask參數(shù)句葵,這其實(shí)就是為Masked Multi-Head Attention實(shí)現(xiàn)的。所以兢仰,Masked Multi-Head Attention與Multi-Head Attention幾乎是一樣的乍丈,只是多了一個(gè)mask操作。

總體上來說把将,這mask是為了屏蔽掉訓(xùn)練時(shí)位于序列前面的詞與序列中位于該詞后面的其他詞計(jì)算注意力轻专,并將它們的編碼信息用于計(jì)算該詞的編碼。屏蔽的原因主要是Decoder在訓(xùn)練時(shí)與推理時(shí)的使用方式不同造成的察蹲,在推理的時(shí)候请垛,Decoder與RNN模型一樣催训,每個(gè)時(shí)間步只能預(yù)測(cè)一個(gè)單詞,不斷循環(huán)直到模型預(yù)測(cè)序列結(jié)束符或者達(dá)到預(yù)定義的最大長度宗收,即只能根據(jù)前面所有已經(jīng)預(yù)測(cè)的單詞預(yù)測(cè)當(dāng)前時(shí)間步的單詞漫拭。例如,將"Je suis étudiant"翻譯為"I am a student"混稽,在推理時(shí)采驻,要先給解碼器輸入開始解碼標(biāo)志符挂洛,這是一個(gè)特殊的單詞句喷,例如“<START>”,如果沒有這個(gè)預(yù)先輸入的單詞俐东,Decoder就無法對(duì)句子的第一個(gè)單詞“I”進(jìn)行解碼洽洁,因?yàn)樵谛蛄猩纤堑谝粋€(gè)痘系,沒有可以依賴的之前時(shí)間步的輸入。當(dāng)輸入“<START>”的embedding張量(包含位置編碼信息)饿自,Decoder結(jié)合embedding和Encoder的編碼信息預(yù)測(cè)輸出“I”汰翠,然后再將“<START>”和“I”的embedding作為Decoder的輸入,結(jié)合Encoder的編碼信息預(yù)測(cè)輸出“am”,以此類推璃俗,不斷循環(huán)奴璃,直到模型預(yù)測(cè)出句子結(jié)束符“<EOS>”或者達(dá)到提前設(shè)定的最大預(yù)測(cè)序列長度。下面的動(dòng)圖展示了推理的過程城豁,但是其中沒有從“<START>”作為輸出開始解碼苟穆。

在訓(xùn)練的時(shí)候,我們是可以一次看到要預(yù)測(cè)序列中的所有詞的唱星,如果還像推理時(shí)那樣雳旅,那么整個(gè)模型并沒有完全實(shí)現(xiàn)真正的多個(gè)序列并行訓(xùn)練,Decoder將是整個(gè)模型的訓(xùn)練效率瓶頸间聊。所以攒盈,當(dāng)然是希望訓(xùn)練時(shí)像Encoder一樣,一次將多個(gè)序列同時(shí)作為Decoder的輸入哎榴,直接預(yù)測(cè)后續(xù)的單詞型豁。但是,如果使用不帶mask的Multi-Head Attention尚蝌,那么模型就可以很開心地作弊了迎变,因?yàn)橐呀?jīng)全部知道要預(yù)測(cè)的答案了,為什么還要費(fèi)力氣學(xué)習(xí)解碼呢飘言,直接將輸入作為輸出不要太簡(jiǎn)單衣形,而且訓(xùn)練Loss直接就是為0。解決這個(gè)問題的方式就是在訓(xùn)練的時(shí)姿鸿,將序列中后續(xù)的詞的編碼信息遮蓋住谆吴,不讓其參與到前面詞的編碼中倒源,這便有效阻止了模型作弊,而且可以提升并行訓(xùn)練句狼。這里說的遮蓋其實(shí)很容易在Multi-Head Attention基礎(chǔ)上實(shí)現(xiàn)的笋熬,在Multi-Head Attention中會(huì)計(jì)算一個(gè)詞與序列中每個(gè)詞的權(quán)重,這樣會(huì)得到一個(gè)權(quán)重矩陣W腻菇,矩陣的每一行表示改行對(duì)應(yīng)的單詞與序列其他詞的相關(guān)性突诬,那么在計(jì)算這個(gè)權(quán)重前,我們只需要將該詞位置之后的點(diǎn)積值重置為一個(gè)特別小的值(例如e^{-9})芜繁,在進(jìn)softmax計(jì)算后,那些位于該詞位置之后的詞的權(quán)重都特別小绒极,即與重置的權(quán)重對(duì)應(yīng)的詞的編碼信息就幾乎不會(huì)參與到當(dāng)前詞的解碼信息中骏令。例如下面是對(duì)“"I am a student”解碼時(shí),在mask之前計(jì)算的點(diǎn)積矩陣:

經(jīng)過mask處理后如下所示垄提,可以看到權(quán)重矩陣變成了一個(gè)下三角矩陣榔袋。

具體的代碼實(shí)現(xiàn)可以參考Multi-Head Attention小結(jié)。

Decoder的輸入

從前面的Decoder整體結(jié)構(gòu)圖中我們可以看到Decoder的輸入包含兩部分铡俐,一部分是Encoder對(duì)原始序列(源語言序列)的編碼信息凰兑,一部分是目標(biāo)序列(目標(biāo)語言句子)。其中审丘,訓(xùn)練時(shí)的目標(biāo)序列是ground truth序列吏够,而推理時(shí)每次的輸入是前面所有時(shí)間步預(yù)測(cè)的詞組成的序列。與Encoder一樣滩报,也需要對(duì)輸入的序列進(jìn)行位置編碼锅知,編碼方式與Encoder相同,這里不再贅述脓钾。

從圖中可以看到售睹,Masked Multi-Head Attetion模塊的輸入都是來自輸入的目標(biāo)序列,為了利用原序列中的編碼信息可训,每個(gè)Encoder層(一共N個(gè))的第二個(gè)Multi-Head Attetion子層會(huì)接受Encoder中的編碼信息作為注意力計(jì)算中的K和V昌妹,將前一層的輸出作為Q,目的是用找到與目標(biāo)序列詞相關(guān)性較大的源序列詞的編碼信息握截,從而更好地預(yù)測(cè)當(dāng)前時(shí)間步的詞飞崖。

這里有個(gè)細(xì)節(jié)需要注意,那就是Encoder本身的輸出只是一個(gè)outputs=R^{B * N * d_{model}}的矩陣川蒙,那又是如何轉(zhuǎn)換為兩個(gè)K和V矩陣的蚜厉?Transformer的做法是直接將outputs既作為K,也作為V直接使用畜眨。

具體可以參考下面這個(gè)較詳細(xì)的結(jié)構(gòu)圖昼牛。

整個(gè)Decoder的輸出與Encoder一樣术瓮,都是outputs=R^{B * N * d_{model}}

Linear贰健、Softmax和Loss

Transformer將翻譯視為分類任務(wù)進(jìn)行目標(biāo)語言輸出胞四,具體地,首先獲取數(shù)據(jù)集中目標(biāo)語言的詞數(shù)量伶椿,即詞庫的大小辜伟,假設(shè)為M,然后每個(gè)時(shí)間步就是進(jìn)行M分類脊另,將對(duì)應(yīng)預(yù)測(cè)概率最高的詞作為本次時(shí)間步的輸出导狡。

所以,就像大部分分類模型一樣偎痛,Transformer會(huì)在Decoder后面跟一個(gè)線性分類層和Softmax層旱捧。這里詳細(xì)說明下,Linear是如何將Decoder的outputs=R^{B * N * d_{model}}輸出映射為分類的logits張量的踩麦。類似于Feed Forward子層一樣枚赡,這個(gè)Linear層其實(shí)也是Position-wise的,即序列中所有位置的詞都是共用同一個(gè)線性映射權(quán)重的谓谦。如前贫橙,假設(shè)詞庫的大小為M,則Linear的權(quán)重大小為W_l \in R^{d_{model} * M}反粥,則映射后的logits張量的大小為logits \in R^{B * N * M}卢肃,然后在logits的最后一個(gè)維度(M)上執(zhí)行Softmax操作,選擇概率最大的作為預(yù)測(cè)輸出才顿,就可以得到B * N個(gè)詞的預(yù)測(cè)践剂,即序列中每個(gè)位置詞的預(yù)測(cè),大致流程如下圖所示娜膘。訓(xùn)練時(shí)逊脯,使用分類任務(wù)常用的CrossEntropy函數(shù)就可以了。

到此竣贪,Transformer相關(guān)的內(nèi)容基本上就介紹完了军洼,下面是Vision Transformer相關(guān)的內(nèi)容。

2. Vision Transformer

2.1 ViT

在ViT之前的很多工作都或多或少嘗試在CNN網(wǎng)絡(luò)中添加各種各樣的attention機(jī)制演怎。但是ViT的提出完全顛覆了之前的觀點(diǎn)匕争,它讓我們看到,對(duì)于圖像識(shí)別任務(wù)爷耀,也可以完全使用attention的深度網(wǎng)絡(luò)來解決甘桑。總體上,理解了Transformer的原理跑杭,就會(huì)發(fā)現(xiàn)ViT的實(shí)現(xiàn)還是很簡(jiǎn)單的铆帽。

下面是整個(gè)模型的結(jié)構(gòu)圖,可以看出德谅,核心就使用了Transformer的Encoder模塊(下圖右側(cè))爹橱,從圖中看出該模塊并非是原始的Transformer論文的Encoder,而是后續(xù)的改進(jìn)版窄做,主要是將LayerNorm放在每個(gè)子層的前面執(zhí)行愧驱。另外,MLP子層中的全連接隱藏層后非線性激活函數(shù)使用的是GELU椭盏。

為了使用原生的Encoder模塊组砚,作者對(duì)輸入的2D圖像做了以下處理,從而適應(yīng)Encoder的一維序列形式的輸入要求掏颊。

  1. 將圖片按固定分辨率大小切成N個(gè)切片(patches)惫确,然后將切片內(nèi)的所有像素值扁平化,即延通道方向拼接起來蚯舱。假設(shè)切片大小為P \times P,圖片的大小為x \in R^{H \times W \times C}掩蛤,則經(jīng)過切片和扁平化處理后的張量為x_p \in R^{N \times (P^{2} \dot C)}枉昏, 其中N為序列的長度且N=\frac{HW}{P^{2}}
  2. 對(duì)每個(gè)切片張量進(jìn)行線性映射揍鸟,整個(gè)Encoder的每個(gè)子層中使用的序列詞的embedding維度為D兄裂,所以使用一個(gè)全連接層對(duì)其進(jìn)行線性映射,則維度變成x_p \in R^{N \times D}阳藻;
  3. 類似于BERT晰奖,在序列的頭部添加一個(gè)可學(xué)習(xí)的類別embedding,維度與切片的維度相同(至于為什么這么做腥泥,筆者沒有深究匾南,感興趣的讀者可以進(jìn)一步閱讀BERT文章),這樣序列的長度就變成N+1蛔外,輸入變成x_p \in R^{(N+1) \times D}蛆楞;
  4. 與Transformer一樣,需要對(duì)切片的嵌入添加位置編碼夹厌,ViT使用的一維可學(xué)習(xí)位置編碼(維度與切片embedding的維度相同)豹爹,并未使用二維位置編碼,因?yàn)樽髡甙l(fā)現(xiàn)二者在效果上并沒有差異矛纹。

經(jīng)過以上處理后臂聋,就可以作為Encoder的輸入,如前所述,Encoder最后一個(gè)層的輸出仍然是z_l \in R^{(N+1) \times D}孩等。ViT并沒有使用整個(gè)嵌入z_l進(jìn)行類別計(jì)算艾君,而是只使用序列的第0個(gè)位置,即可學(xué)習(xí)類別embeddingz^0_l作為類別推理時(shí)使用的圖片表征y瞎访,這也是BERT的做法腻贰。最后,使用一個(gè)MLP模塊進(jìn)行分類預(yù)測(cè)扒秸。該MLP模塊在大規(guī)模數(shù)據(jù)集上預(yù)訓(xùn)練時(shí)使用的是兩層結(jié)構(gòu)播演,在小數(shù)據(jù)集上微調(diào)時(shí)只使用單層,輸出層參數(shù)維度為R^{D \times K}伴奥,K為類別概率写烤。

整體的推理計(jì)算過程如下面公式所示:

以上就是ViT所有的相關(guān)細(xì)節(jié),下面是github上ViT的Pytorch實(shí)現(xiàn)主頁提供的動(dòng)態(tài)推理效果圖拾徙。具體的代碼實(shí)現(xiàn)也很簡(jiǎn)單洲炊,可參考Pytorch實(shí)現(xiàn)版本。

評(píng)價(jià)

總體上尼啡,ViT還是屬于挖坑的方法暂衡,存在很多的問題需要解決,比如很弱的歸納偏置崖瞭,這需要相比CNN更大的訓(xùn)練數(shù)據(jù)集狂巢,當(dāng)然,可以認(rèn)為這也是它的優(yōu)點(diǎn)书聚。還有就是特別耗GPU內(nèi)存唧领,難以直接遷移到檢測(cè)、分割等下游任務(wù)雌续,計(jì)算復(fù)雜度與輸入尺寸是平方關(guān)系斩个,缺乏像CNN這樣的局部注意力機(jī)制,推理速度慢等問題驯杜。后續(xù)的工作受啥,基本都是針對(duì)這些問題展開的。下面我們介紹的Swin Transformer和PvtV2主要解決ViT的高計(jì)算復(fù)雜度和難以直接遷移到下游任務(wù)的問題鸽心。

2.2 Swin Transformer

概述

如前所述腔呜,Swin Transformer(后面簡(jiǎn)稱Swin)主要解決難以直接遷移到下游任務(wù)和ViT的高計(jì)算復(fù)雜度的問題。下面是Swin網(wǎng)絡(luò)的結(jié)構(gòu)圖再悼。

為了解決ViT難以直接遷移到下游任務(wù)的問題核畴,Swin實(shí)現(xiàn)了非常類似ResNet系列的網(wǎng)絡(luò)結(jié)構(gòu),整個(gè)網(wǎng)絡(luò)仍然由4個(gè)獨(dú)立的stage串聯(lián)而成冲九,第2~4個(gè)stage輸出的feature map的高和寬都是前一個(gè)stage的1/2谤草,通道數(shù)都是前一個(gè)stage的2倍跟束。這樣的網(wǎng)絡(luò)結(jié)構(gòu)非常適合與后續(xù)的FPN結(jié)合并應(yīng)用于檢測(cè)和分割任務(wù)上。對(duì)feature map空間進(jìn)行降采樣并對(duì)通道進(jìn)行升維的操作主要由網(wǎng)絡(luò)中的Patch Merging模塊實(shí)現(xiàn)丑孩。

為了解決計(jì)算復(fù)雜度問題冀宴,作者將ViT中(global)計(jì)算全局注意力的方式(feature map中每個(gè)patch都會(huì)與其他所有的patch計(jì)算注意力)改為局部窗口(local window)計(jì)算,每個(gè)窗口由固定數(shù)量的相鄰patch劃分得到温学,窗口之間沒有重疊略贮,如下圖左側(cè)部分所示(紅色框?yàn)榇翱冢疑驗(yàn)閜atch)仗岖,這樣便可以將計(jì)算復(fù)雜度與輸入分辨率的二次方關(guān)系(O((wh)^2))降低為線性關(guān)系(O(wh))逃延。該功能通過將ViT中的MSA(Multi-head Self-Attention)子層替換為基于局部非重疊窗口計(jì)算注意力的W-MSA(Window MSA)來實(shí)現(xiàn)。然而轧拄,局部窗口注意力的引入限制了feature map中patch之間的信息交互揽祥,這會(huì)影響模型的表達(dá)能力¢莸纾基于此拄丰,作者將W-MSA改進(jìn)為帶偏移的(shifted)SW-MSA模塊,簡(jiǎn)單理解就是將前一層W-MSA的所有窗口平移一定的位置俐末,這樣之前不在一個(gè)窗口內(nèi)的patch經(jīng)過窗口重新劃分后處于同一窗口(窗口大小不變)內(nèi)料按,這樣便能達(dá)到更大范圍的patch進(jìn)行信息交互的目的。如下圖右側(cè)所示卓箫,圖中標(biāo)示藍(lán)點(diǎn)的patch處于不同層的不同窗口內(nèi)载矿,便能夠與更大范圍的patch計(jì)算注意力。為了實(shí)現(xiàn)以上效果并兼容較低的計(jì)算復(fù)雜度丽柿,作者將配置有W-MSA的注意力子模塊與配置有SW-MS的注意力子模塊串聯(lián)起來,作為一個(gè)整體應(yīng)用到網(wǎng)絡(luò)結(jié)構(gòu)的設(shè)計(jì)中魂挂,文中稱為Swin Transformer Block甫题,如上圖b所示。根據(jù)網(wǎng)絡(luò)規(guī)模的設(shè)計(jì)涂召,不同stage會(huì)配置不同數(shù)量的Swin Transformer Block坠非,非常類似ResNet的Residule Block和BottleNeck Block。

以上是Swin整體的概述果正,下面我們?cè)俜謩e探究下網(wǎng)絡(luò)中幾個(gè)關(guān)鍵組件的具體細(xì)節(jié)炎码。

Patch Partition & Linear Embedding

類似于ViT,Swin先將原圖片按照固定數(shù)量的像素值(patch size)切分成不重疊的patch(或稱為token)秋泳,作為注意力計(jì)算的最小單元潦闲,Swin的patch size為4,假設(shè)原圖為H * W * 3迫皱,則經(jīng)過Patch Partition后的尺寸為\frac{H}{4} * \frac{W}{4} * 48歉闰。Swin中第一個(gè)stage的通道輸入假設(shè)為C(論文中C=96或128或192三種),則需要將輸入線性映射為\frac{H}{4} * \frac{W}{4} * C『途矗總體上凹炸,這個(gè)功能是通過一個(gè)kernel_size=4,輸入通道數(shù)為3昼弟,輸出通道數(shù)為C啤它,步長=4的卷積操作一步完成。

Patch Merging

如前所述舱痘,第2~4個(gè)stage輸出的feature map的高和寬都是前一個(gè)stage的1/2变骡,通道數(shù)都是前一個(gè)stage的2倍。這由網(wǎng)絡(luò)中的Patch Merging模塊實(shí)現(xiàn)衰粹。這個(gè)過程其實(shí)也可以用一個(gè)kernel_size=2锣光,輸入通道數(shù)為C(假設(shè)是第2個(gè)stage),輸出通道數(shù)為2 * C铝耻,步長=2的卷積操作完成誊爹。但是Swin的實(shí)現(xiàn)中并未這么做,而是像ViT中劃分patch的方式一樣瓢捉,先對(duì)輸入矩陣進(jìn)行切分频丘,然后對(duì)屬于同一patch的像素進(jìn)行通道維度的合并(假設(shè)為4 * C),最后使用線性映射泡态,將每個(gè)patch的通道數(shù)映射為指定大新(如2 * C)。個(gè)人覺得這兩種實(shí)現(xiàn)基本等價(jià)某弦,可能是為了降低模型的參數(shù)量桐汤,所以作者選擇了后者。

Swin Transformer Block

如前所述靶壮,具體的Swin Transformer Block分為兩種怔毛,分別是使用W-MSA的Block和SW-MSA block,二者的區(qū)別僅僅就是計(jì)算注意力時(shí)窗口的劃分方式有些區(qū)別腾降,下面分別介紹這兩種注意力拣度。

W-MSA

W-MSA的原理相對(duì)比較簡(jiǎn)單,ViT的注意力計(jì)算是在feature map空間內(nèi)的全部patch上進(jìn)行的螃壤,為了在不重疊的局部窗口內(nèi)計(jì)算注意力抗果,Swin將原特征圖矩陣重新排列成維度更高尺寸更小的特征圖,然后計(jì)算注意力就和在原特征圖上沒什么區(qū)別了奸晴。例如冤馏,原特征圖尺寸為(H, W, C),假設(shè)窗口大小為M寄啼,即一個(gè)窗口由相鄰的M * M個(gè)patch組成宿接,那么重新排列后的特征圖維度為(\frac{H}{M}, \frac{W}{M}, M, M, C)赘淮,這樣注意力計(jì)算就在每個(gè)M * M * C的子特征圖上進(jìn)行。代碼實(shí)現(xiàn)如下:

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

論文中也給出了W-MSA將注意力的計(jì)算復(fù)雜度由與輸出尺寸hw的平方關(guān)系睦霎,降到了線性關(guān)系梢卸,如下圖公式(2)所示,公式中忽略了softmax計(jì)算副女。

現(xiàn)給出以上公式具體的推理過程蛤高。假設(shè)原圖已經(jīng)通過Patch Partition劃分為(hw) \times C的特征圖X。為了便于閱讀碑幅,我們先給出Transformer的注意力計(jì)算公式:

Q = X W^Q

K = X W^K

V = X W^V

Z = softmax(\frac{Q K^T}{\sqrt{d_k}}) V

Z = Z W^O

假設(shè)Q戴陡、K、V的維度都是R^{(hw) \times C }沟涨,則W^Q恤批、W^KW^VW^O的維度都是R^{C \times C }裹赴。在MSA中喜庞,計(jì)算Q、K和V棋返,需要3hwC^2的計(jì)算量延都,然后計(jì)算QK^T(維度為R^{(hw) \times (hw)})需要(hw)^2 C的計(jì)算量,最后進(jìn)行加權(quán)平均睛竣,也需要(hw)^2 C晰房,計(jì)算過程中使用的是多頭注意力,多頭特征融合(Z W^O)還需要hwC^2射沟,所以總體上就是公式(1)的復(fù)雜度殊者。但是由于W-MSA中的注意力是在M \times M的窗口內(nèi)進(jìn)行的,所以計(jì)算QK^T僅需要\frac{h}{M} \frac{w}{W}個(gè)(MM)^2 C的計(jì)算量验夯,即(hw) M^2 C猖吴。同樣的,計(jì)算加權(quán)平均也需要(hw) M^2 C計(jì)算量簿姨,其他部分計(jì)算量不變距误,則最終計(jì)算量如公式(2)所示簸搞。推理過程參考了這篇博客扁位。

SW-MSA

基于SW-MSA的Block是位于W-MSA Block之后,換句話說SW-MSA Block就是為了擴(kuò)展W-MSA Block的注意力視野的趁俊,所以都是在前一個(gè)W-MSA Block配置上將所有窗口的位置延H和W方向平移M/2的距離域仇,從而產(chǎn)生新的窗口,如上面圖2右側(cè)所示寺擂。然而暇务,這樣會(huì)帶來兩個(gè)問題泼掠,第一個(gè)是窗口的數(shù)量增加了,從原來的(\lfloor \frac{H}{M} \rfloor * \lfloor \frac{W}{M} \rfloor)變成(\lfloor \frac{H}{M} \rfloor + 1) * (\lfloor \frac{W}{M} \rfloor + 1)垦细,如果窗口的數(shù)量較小時(shí)择镇,如2 * 2喉祭,則會(huì)增加到3 * 3纵刘,增加了2.25倍。另外一個(gè)問題是這樣劃分后型宙,每個(gè)窗口內(nèi)的patch數(shù)量不相等嘱能,這樣就很難進(jìn)行矩陣計(jì)算吝梅,從而影響計(jì)算效率,簡(jiǎn)單的方式是使用padding的方式惹骂,將每個(gè)窗口的尺寸補(bǔ)全苏携,但是這樣會(huì)增加計(jì)算量。作者通過循環(huán)偏移的方式高效解決了這兩個(gè)問題对粪。整體的過程如下圖所示右冻。

光看上圖有些抽象,我將該過程拆解為如下圖5步進(jìn)行解釋衩侥,首先我們假設(shè)原特征圖的尺寸為4 * 4国旷,窗口大小為2 * 2,則窗口數(shù)量為4茫死,下圖步驟1為W-MSA的窗口劃分方式跪但,稱為正常窗口劃分(regular window partition), 步驟2為偏移過的窗口劃分,可以看到窗口數(shù)量增加為9個(gè)(9個(gè)不同顏色表示)峦萎,其中只有中間的黑色窗口的尺寸是4屡久,其他都不正常,為了保持黑色窗口內(nèi)的4個(gè)patch關(guān)系爱榔,作者將特征圖矩陣沿著(0,0)位置進(jìn)行滾動(dòng)被环,從而將黑色窗口patch張量組移到特征圖空間的左上角,具體的做法如下圖中的步驟2和步驟3详幽,首先將其延H方向滾動(dòng)1個(gè)像素筛欢,然后延W方向滾動(dòng)1個(gè)像素,就得到了步驟4中的特征圖唇聘,其中步驟2和步驟3可以用類似torch.roll的函數(shù)一次實(shí)現(xiàn)版姑。可以看到步驟4既保持了黑色窗口內(nèi)的patch組合關(guān)系沒有被打亂迟郎,而且將窗口數(shù)量減少到4剥险,且每個(gè)窗口的尺寸都是一樣的,如圖5所示宪肖。另外表制,計(jì)算完注意力后還要將特征矩陣還原回去健爬,即按照相反的方式將步驟5中每個(gè)位置的patch還原回到步驟1中的位置。

雖然目前的SW-MSA解決了窗口數(shù)量和尺寸的問題么介,但是又引入了一個(gè)新的問題娜遵,有些窗口內(nèi)的patch并不是相鄰的,以上圖中步驟5個(gè)紅色窗口為例壤短,里面的4個(gè)patch在原特征圖中剛好位于四個(gè)角上魔熏,互相之間的距離太遠(yuǎn),作者并不希望這些patch之間進(jìn)行信息交互(這樣會(huì)不會(huì)帶來什么負(fù)面影響鸽扁,論文中并沒有提及蒜绽,可能是因?yàn)樗鼤?huì)破壞作者引入的局部窗口的歸納偏置吧)。解決這個(gè)問題的方法桶现,可以使用Transformer中Decoder模塊用到的mask方法屏蔽掉這些patch之間計(jì)算注意力躲雅。具體的圖文介紹可以參考這篇博客

由于每次平移的尺寸都是M/2骡和,那么無論特征圖和窗口的尺寸是多少(窗口尺寸必須小于等于特征圖尺寸)相赁,總體上只有三種mask生成情況,這一點(diǎn)在這篇博客中并未詳細(xì)說明慰于,分別是:窗口內(nèi)的所有patch之前就位于同一窗口(稱為正常窗口)钮科,即上圖中的黑色窗口;第二種是正常窗口的右側(cè)和下方婆赠,窗口內(nèi)的patch由之前的兩個(gè)窗口的patch組成绵脯,即上圖中的綠色和黃色窗口,最后一種就是由之前至少四個(gè)窗口的patch組成的紅色窗口休里。下面是特征尺寸為8 * 8, 窗口尺寸為4 * 4蛆挫,平移尺寸為2的窗口劃分矩陣,數(shù)字相同的表示在roll之前位于同一窗口內(nèi)妙黍。

tensor([[0., 0., 0., 0., 1., 1., 2., 2.],
        [0., 0., 0., 0., 1., 1., 2., 2.],
        [0., 0., 0., 0., 1., 1., 2., 2.],
        [0., 0., 0., 0., 1., 1., 2., 2.],
        [3., 3., 3., 3., 4., 4., 5., 5.],
        [3., 3., 3., 3., 4., 4., 5., 5.],
        [6., 6., 6., 6., 7., 7., 8., 8.],
        [6., 6., 6., 6., 7., 7., 8., 8.]])

具體的mask生成過程就不贅述了悴侵,這里截取了作者開源代碼生成mask的部分,可以進(jìn)一步驗(yàn)證:

def window_partition(x, window_size):
    H, W = x.shape
    x = x.view( H // window_size, window_size, W // window_size, window_size)
    windows = x.permute(0, 2, 1, 3).contiguous().view(-1, window_size, window_size)
    return windows

def get_mask(res, window_size, shift_size):
    H, W = res
    img_mask = torch.zeros(H, W)
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[h, w] = cnt
            cnt += 1
            
    print(img_mask)

    mask_windows = window_partition(img_mask, window_size)
    print(mask_windows)
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    
    return attn_mask

mask = get_mask((8, 8), 4, 2)

另外拭嫁,值得一提的是作者這里使用mask的方式與Transformer中的不太一樣可免,Transformer中的方式是直接用極小值填充Q、K計(jì)算的注意力矩陣:

attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
if mask is not None:
    attn = attn.masked_fill(mask == 0, -1e9)

而Swin的實(shí)現(xiàn)是將一個(gè)包含負(fù)值(-100)的mask(參考上面的mask生成代碼)與原attention進(jìn)行element-wise加和:

attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)

這樣使用的原因個(gè)人不太清楚(也許這樣是等價(jià)的)做粤,有了解的讀者浇借,還望不吝賜教。

相對(duì)位置編碼

與ViT中類似驮宴,Swin中也是用的是可學(xué)習(xí)的位置編碼逮刨,而且是2維相對(duì)位置編碼:

然而呕缭,可學(xué)習(xí)位置編碼參數(shù)的初始化大小為窗口尺寸M的大致2倍堵泽,稱為參數(shù)表\hat B \in R^{(2M-1) \times (2M-1)}修己,每次使用的位置參數(shù)B都是從\hat B中預(yù)定義的位置獲取,論文并未詳細(xì)說明這樣實(shí)現(xiàn)的原因迎罗,可能是引用了某篇論文中相對(duì)位置編碼方法睬愤,這里先記錄下,回頭研究了再補(bǔ)充纹安。

評(píng)價(jià)

Swin雖然在效果和性能上取得了顯著的提升尤辱,但是還有個(gè)明顯的缺點(diǎn),那就是在224分辨率上預(yù)訓(xùn)練的模型無法直接用于任意分辨率的finetune或下游任務(wù)厢岂,而且推理時(shí)并不支持像CNN一樣的任意分辨率輸入光督。

2.3 PVTv2

概述

由于PVT的v1和v2兩篇論文很接近,而且v2顯著優(yōu)于v1塔粒,所以這里就不分別介紹v1和v2了结借,直接總結(jié)v2版本。PVT要解決的問題和Swin一樣卒茬,構(gòu)建直接可以用于下游檢測(cè)和分割任務(wù)的backbone網(wǎng)絡(luò)船老,并且降低ViT的計(jì)算量。下圖是從v1論文中截取的網(wǎng)絡(luò)結(jié)構(gòu)圖圃酵,v2和v1在總體結(jié)構(gòu)上沒有差異柳畔。類似于ResNet系列模型結(jié)構(gòu),PVT和Swin都是由多個(gè)stage組成郭赐,每個(gè)stage內(nèi)部由若干個(gè)統(tǒng)一的注意力模塊組成薪韩,每個(gè)stage都會(huì)對(duì)輸入的feature map的空間尺度進(jìn)行降采樣并增加通道維數(shù),最后輸出呈金字塔形的多層feature map捌锭。而且二者在降低空間尺寸和增加通道維度的方法上基本類似躬存,最大的區(qū)別是feature map的注意力計(jì)算方式上。

具體地舀锨,PVT的每個(gè)stage由一個(gè)Patch Embedding子層和多個(gè)串聯(lián)Attention模塊(下圖中的Transformer Encoder)組成岭洲。Patch Embedding便是用來降低空間尺寸和增加通道維度的,類似于Swin中的Patch Merging子層坎匿。每個(gè)stage會(huì)將輸入feature map的寬和高減少為原來的1/2(第一個(gè)stage為1/4)盾剩,通道數(shù)也會(huì)增加,但不是增加2倍替蔬,每個(gè)stage的輸出通道數(shù)分別為64告私、128、320和512承桥。

下面分別探究PVT的部分細(xì)節(jié)驻粟。

Patch Embedding

PVTv2的Patch Embedding實(shí)現(xiàn)與Swin類似,使用一層卷積操作實(shí)現(xiàn)patch劃分和線性映射。不同的是蜀撑,PVTv2使用了重疊的大尺寸卷積操作挤巡,這樣使得相鄰的patch之間有一定比例的信息重疊,作者認(rèn)為這樣能夠保持圖像數(shù)據(jù)的局部連續(xù)性(并未解釋這樣的好處酷麦,文中也沒有消融實(shí)驗(yàn))矿卑,如下圖(a)所示。具體地沃饶,第i個(gè)stage中Patch Embedding的卷積核步長為S_i母廷,卷積核大小為2S_i-1,padding尺寸為S_i - 1糊肤,輸入通道數(shù)為C_{i-1}琴昆,輸出通道數(shù)為C_{i}(即卷積核數(shù)量),輸出為\frac{H}{S_i} \times \frac{W}{S_i} \times C_i(flatten之前的shape)馆揉,S_i即為每個(gè)stage的空間降采樣比例椎咧,然后接一個(gè)LayerNorm。

Transformer Encoder

SRA和Linear SRA

PVTv2對(duì)Transformer Encoder(已成為Transformer Block)的改進(jìn)也是為了減少Attention模塊的計(jì)算量把介,具體的做法是直接減少K和V的數(shù)量勤讽,這樣便能明顯較少矩陣計(jì)算。假設(shè)原始的Q拗踢、K脚牍、V的維度都是R^{(hw) \times C },長寬減小的比例為R_{i}巢墅,則K诸狭、V的維度都是R^{(\frac{hw}{R^2_{i}}) \times C },則計(jì)算的復(fù)雜度由原來的:

\Omega(A) = 4hwC^2 + 2(hw)^2C

減小到:

\Omega(SRA) = 4hwC^2 + 2(\frac{hw}{R_{i}})^2C
雖然不像Swin能夠直接將計(jì)算復(fù)雜度降低為hw的線性關(guān)系君纫。作者將這樣的Attention改進(jìn)版稱為SRA(Spatial Reduction Attention)驯遇。既然是空間的降采樣(K、V重排列為R^{h \times w \times C})蓄髓,同樣有多個(gè)實(shí)現(xiàn)方式叉庐,Patch劃分+線性映射的方式,或者使用步長大于1的卷積操作会喝,在PVTv2中使用的是卷積操作陡叠,PVTv1中使用的是前者。R_{i}是超參肢执,不同stage具體的值不同枉阵,文中使用的分別為8、4预茄、2和1兴溜。

另外作者還提出了一種稱為Linear SRA的注意力模塊,可以將計(jì)算復(fù)雜度降低為hw的線性關(guān)系。具體的做法是暴力地將K和V的Patch數(shù)量(或空間尺寸)降采樣為常數(shù)量P_i拙徽,具體地可以使用AdaptiveAvgPool2d實(shí)現(xiàn)刨沦,一般池化的窗口大小P_i設(shè)置地比較大,文中使用的是7斋攀。

\Omega(SRA) = 4hwC^2 + 2(hw * P^2_i)C

兩種SRA的比較如下圖。

Position Encoding

最后梧田,為了使得PVT像CNN網(wǎng)絡(luò)一樣可以接受任意尺寸的輸入淳蔼,作者將ViT中使用的與patch數(shù)量必須一致的可學(xué)習(xí)位置編碼改為帶0值padding的位置編碼方法(Zero Padding Postion Encoding),該方法首次是在”Conditional Positional Encodings for Vision Transformers"這篇論文中提出,具體的實(shí)現(xiàn)是在MLP子模塊的第一個(gè)Linear層與非線性激活函數(shù)(例如GELU)之間增加一個(gè)3 * 3的帶zero padding的深度可分離卷積裁眯,如上面的圖1(b)所示鹉梨。具體的原理這里就不展開了,關(guān)于Vision Transformer的位置編碼方式已經(jīng)可以另開一篇文章了穿稳。

評(píng)價(jià)

整體上PVT實(shí)現(xiàn)簡(jiǎn)單且有效存皂,尤其是用SRA的方式降低注意力計(jì)算量,應(yīng)該比較契合Kaiming He最新論文中的觀點(diǎn):

因?yàn)楸旧韴D像數(shù)據(jù)存在冗余性逢艘,均勻間隔地對(duì)Q旦袋、K的尺寸(flatten后指數(shù)量)做一定的降采樣,也不會(huì)影響模型的能力它改。

缺點(diǎn)是論文中缺乏消融實(shí)驗(yàn)疤孕,不清楚每個(gè)改進(jìn)點(diǎn)帶來多少的效果提升。

3. 個(gè)人的看法

  1. 理性看待paper里的效果央拖,尤其是在COCO和ImageNet上的效果祭阀。當(dāng)然,純粹發(fā)論文另當(dāng)別論鲜戒。
  2. Vision Transformer對(duì)細(xì)粒度分類可能有效专控。
  3. 背景單一的檢測(cè)任務(wù),可能并不需要注意力機(jī)制遏餐,例如工業(yè)數(shù)據(jù)集伦腐。
  4. 可遷移性還是較弱,可能配合自監(jiān)督預(yù)訓(xùn)練會(huì)有所改善失都。
  5. 非常吃顯存仍然是Vision Transformer的弊端蔗牡,這將嚴(yán)重導(dǎo)致Vision Transformer在檢測(cè)任務(wù)中的發(fā)揮,尤其對(duì)于小目標(biāo)檢測(cè)任務(wù)(大分辨率才能讓模型看清楚目標(biāo)特征)嗅剖。
  6. Vision Transformer依賴大規(guī)模數(shù)據(jù)集辩越,但是無監(jiān)督方法可能會(huì)緩解這個(gè)問題。
  7. 模型參數(shù)量容易達(dá)到飽和信粮,更多的參數(shù)黔攒,遷移到其他任務(wù)并沒有提升,也可能和數(shù)據(jù)量有關(guān)。
  8. Vision Transformer在通用檢測(cè)任務(wù)上提升很多督惰。

4. 參考

文中絕大部分圖片來自參考文獻(xiàn)不傅。

  1. Attention Is All You Need
  2. https://jalammar.github.io/illustrated-transformer
  3. https://zhuanlan.zhihu.com/p/54356280
  4. https://mp.weixin.qq.com/s/S89kak4El3hPZJc0EzRszA
  5. https://zhuanlan.zhihu.com/p/308301901
  6. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
  7. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  8. CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows
  9. PVTv1: Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions.
  10. PVTv2: Improved Baselines with Pyramid Vision Transformer
  11. Conditional Positional Encodings for Vision Transformers.
  12. https://colah.github.io/posts/2015-08-Understanding-LSTMs/
  13. https://looperxx.github.io/CS224n-2019-08-Machine%20Translation,%20Sequence-to-sequence%20and%20Attention/
  14. https://medium.com/deeper-learning/glossary-of-deep-learning-word-embedding-f90c3cec34ca
  15. https://github.com/jadore801120/attention-is-all-you-need-pytorch
  16. https://github.com/lucidrains/vit-pytorch
  17. https://zhuanlan.zhihu.com/p/360513527
  18. Masked Autoencoders Are Scalable Vision Learners
  19. https://www.borealisai.com/en/blog/tutorial-17-transformers-iii-training/
  20. PowerNorm: Rethinking Batch Normalization in Transformers
    20.https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html?highlight=layernorm#torch.nn.LayerNorm
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市赏胚,隨后出現(xiàn)的幾起案子访娶,更是在濱河造成了極大的恐慌,老刑警劉巖觉阅,帶你破解...
    沈念sama閱讀 216,372評(píng)論 6 498
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件崖疤,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡典勇,警方通過查閱死者的電腦和手機(jī)劫哼,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,368評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來割笙,“玉大人权烧,你說我怎么就攤上這事∩烁龋” “怎么了般码?”我有些...
    開封第一講書人閱讀 162,415評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵,是天一觀的道長乱顾。 經(jīng)常有香客問我侈询,道長,這世上最難降的妖魔是什么糯耍? 我笑而不...
    開封第一講書人閱讀 58,157評(píng)論 1 292
  • 正文 為了忘掉前任扔字,我火速辦了婚禮,結(jié)果婚禮上温技,老公的妹妹穿的比我還像新娘革为。我一直安慰自己,他們只是感情好舵鳞,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,171評(píng)論 6 388
  • 文/花漫 我一把揭開白布震檩。 她就那樣靜靜地躺著,像睡著了一般蜓堕。 火紅的嫁衣襯著肌膚如雪抛虏。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,125評(píng)論 1 297
  • 那天套才,我揣著相機(jī)與錄音迂猴,去河邊找鬼。 笑死背伴,一個(gè)胖子當(dāng)著我的面吹牛沸毁,可吹牛的內(nèi)容都是我干的峰髓。 我是一名探鬼主播,決...
    沈念sama閱讀 40,028評(píng)論 3 417
  • 文/蒼蘭香墨 我猛地睜開眼息尺,長吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼携兵!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起搂誉,我...
    開封第一講書人閱讀 38,887評(píng)論 0 274
  • 序言:老撾萬榮一對(duì)情侶失蹤徐紧,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后炭懊,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體并级,經(jīng)...
    沈念sama閱讀 45,310評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,533評(píng)論 2 332
  • 正文 我和宋清朗相戀三年凛虽,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了死遭。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片广恢。...
    茶點(diǎn)故事閱讀 39,690評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡凯旋,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出钉迷,到底是詐尸還是另有隱情至非,我是刑警寧澤,帶...
    沈念sama閱讀 35,411評(píng)論 5 343
  • 正文 年R本政府宣布糠聪,位于F島的核電站荒椭,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏舰蟆。R本人自食惡果不足惜趣惠,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,004評(píng)論 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望身害。 院中可真熱鬧味悄,春花似錦、人聲如沸塌鸯。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,659評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽丙猬。三九已至涨颜,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間茧球,已是汗流浹背庭瑰。 一陣腳步聲響...
    開封第一講書人閱讀 32,812評(píng)論 1 268
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留抢埋,地道東北人见擦。 一個(gè)月前我還...
    沈念sama閱讀 47,693評(píng)論 2 368
  • 正文 我出身青樓钉汗,卻偏偏與公主長得像,于是被迫代替她去往敵國和親鲤屡。 傳聞我的和親對(duì)象是個(gè)殘疾皇子损痰,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,577評(píng)論 2 353

推薦閱讀更多精彩內(nèi)容