深入淺出視覺Transformer

**原生:**

**QKV的含義和注意力機制的三個計算步驟:Q和所有K計算相似性茅坛;對相似性采用softmax轉(zhuǎn)化為概率分布观挎;將概率分布和V進(jìn)行一一對應(yīng)相乘挪凑,最后相加得到新的和Q一樣長的向量輸出**

**cv:**

發(fā)現(xiàn)其**小物體檢測能力**遠(yuǎn)遠(yuǎn)低于faster rcnn荠瘪,這是一個比較大的弊端夯巷。

導(dǎo)讀

Transformer整個網(wǎng)絡(luò)結(jié)構(gòu)完全由Attention機制組成,其出色的性能在多個任務(wù)上都取得了非常好的效果哀墓。本文從Transformer的結(jié)構(gòu)出發(fā)趁餐,結(jié)合視覺中的成果進(jìn)行了分析,能夠幫助初學(xué)者們快速入門篮绰。

0 摘要

transformer結(jié)構(gòu)是google在17年的Attention Is All You Need論文中提出后雷,在NLP的多個任務(wù)上取得了非常好的效果,可以說目前NLP發(fā)展都離不開transformer吠各。最大特點是拋棄了傳統(tǒng)的CNN和RNN臀突,整個網(wǎng)絡(luò)結(jié)構(gòu)完全是由Attention機制組成。由于其出色性能以及對下游任務(wù)的友好性或者說下游任務(wù)僅僅微調(diào)即可得到不錯效果贾漏,在計算機視覺領(lǐng)域不斷有人嘗試將transformer引入惧辈,近期也出現(xiàn)了一些效果不錯的嘗試,典型的如目標(biāo)檢測領(lǐng)域的detr和可變形detr磕瓷,分類領(lǐng)域的vision transformer等等。本文從transformer結(jié)構(gòu)出發(fā),結(jié)合視覺中的transformer成果(具體是vision transformer和detr)進(jìn)行分析困食,希望能夠幫助cv領(lǐng)域想了解transformer的初學(xué)者快速入門边翁。由于本人接觸transformer時間也不長,也算初學(xué)者硕盹,故如果有描述或者理解錯誤的地方歡迎指正符匾。

本文的大部分圖來自論文、國外博客和國內(nèi)翻譯博客瘩例,在此一并感謝前人工作啊胶,具體鏈接見參考資料。本文特別長垛贤,大概有3w字焰坪,請先點贊收藏然后慢慢看....

1 transformer介紹

一般講解transformer都會以機器翻譯任務(wù)為例子講解,機器翻譯任務(wù)是指將一種語言轉(zhuǎn)換得到另一種語言聘惦,例如英語翻譯為中文任務(wù)某饰。從最上層來看,如下所示:

圖片

1.1 早期seq2seq

機器翻譯是一個歷史悠久的問題善绎,本質(zhì)可以理解為序列轉(zhuǎn)序列問題黔漂,也就是我們常說的seq2seq結(jié)構(gòu),也可以稱為encoder-decoder結(jié)構(gòu)弥搞,如下所示:

圖片

encoder和decoder在早期一般是RNN模塊(因為其可以捕獲時序信息)谨娜,后來引入了LSTM或者GRU模塊毁枯,不管內(nèi)部組件是啥,其核心思想都是通過Encoder編碼成一個表示向量减途,即上下文編碼向量,然后交給Decoder來進(jìn)行解碼浩聋,翻譯成目標(biāo)語言观蜗。一個采用典型RNN進(jìn)行編碼碼翻譯的可視化圖如下:

圖片

可以看出,其解碼過程是順序進(jìn)行衣洁,每次僅解碼出一個單詞墓捻。對于CV領(lǐng)域初學(xué)者來說,RNN模塊構(gòu)建的seq2seq算法坊夫,理解到這個程度就可以了砖第,不需要深入探討如何進(jìn)行訓(xùn)練。但是上述結(jié)構(gòu)其實有缺陷环凿,具體來說是:

  • 不論輸入和輸出的語句長度是什么梧兼,中間的上下文向量長度都是固定的,一旦長度過長智听,僅僅靠一個固定長度的上下文向量明顯不合理
  • 僅僅利用上下文向量解碼羽杰,會有信息瓶頸渡紫,長度過長時候信息可能會丟失

通俗理解是編碼器與解碼器的連接點僅僅是編碼單元輸出的隱含向量,其包含的信息有限考赛,對于一些復(fù)雜任務(wù)可能信息不夠惕澎,如要翻譯的句子較長時,一個上下文向量可能存不下那么多信息颜骤,就會造成翻譯精度的下降唧喉。

1.2 基于attention的seq2seq

基于上述缺陷進(jìn)而提出帶有注意力機制Attention的seq2seq,同樣可以應(yīng)用于RNN忍抽、LSTM或者GRU模塊中八孝。注意力機制Attention對人類來說非常好理解,假設(shè)給定一張圖片鸠项,我們會自動聚焦到一些關(guān)鍵信息位置干跛,而不需要逐行掃描全圖。此處的attention也是同一個意思锈锤,其本質(zhì)是對輸入的自適應(yīng)加權(quán)驯鳖,結(jié)合cv領(lǐng)域的senet中的se模塊就能夠理解了。

圖片

se模塊最終是學(xué)習(xí)出一個1x1xc的向量久免,然后逐通道乘以原始輸入浅辙,從而對特征圖的每個通道進(jìn)行加權(quán)即通道注意力,對attention進(jìn)行抽象阎姥,不管啥領(lǐng)域其機制都可以歸納為下圖:

圖片

將Query(通常是向量)和4個Key(和Q長度相同的向量)分別計算相似性记舆,然后經(jīng)過softmax得到q和4個key相似性的概率權(quán)重分布,然后對應(yīng)權(quán)重乘以Value(和Q長度相同的向量)呼巴,最后相加即可得到包含注意力的attention值輸出泽腮,理解上應(yīng)該不難。舉個簡單例子說明:

  • 假設(shè)世界上所有小吃都可以被標(biāo)簽化衣赶,例如微辣诊赊、特辣、變態(tài)辣府瞄、微甜碧磅、有嚼勁....,總共有1000個標(biāo)簽遵馆,現(xiàn)在我想要吃的小吃是[微辣鲸郊、微甜、有嚼勁]货邓,這三個單詞就是我的Query
  • 來到東門老街一共100家小吃點秆撮,每個店鋪賣的東西不一樣,但是肯定可以被標(biāo)簽化换况,例如第一家小吃被標(biāo)簽化后是[微辣职辨、微咸],第二家小吃被標(biāo)簽化后是[特辣盗蟆、微臭、特咸]拨匆,第二家小吃被標(biāo)簽化后是[特辣姆涩、微甜、特咸惭每、有嚼勁],其余店鋪都可以被標(biāo)簽化亏栈,每個店鋪的標(biāo)簽就是Keys,但是每家店鋪由于賣的東西不一樣台腥,單品種類也不一樣,所以被標(biāo)簽化后每一家的標(biāo)簽List不一樣長
  • Values就是每家店鋪對應(yīng)的單品绒北,例如第一家小吃的Values是[烤羊肉串黎侈、炒花生]
  • 將Query和所有的Keys進(jìn)行一一比對,相當(dāng)于計算相似性闷游,此時就可以知道我想買的小吃和每一家店鋪的匹配情況峻汉,最后有了匹配列表,就可以去店鋪里面買東西了(Values和相似性加權(quán)求和)脐往。最終的情況可能是休吠,我在第一家店鋪買了烤羊肉串,然后在第10家店鋪買了個玉米业簿,最后在第15家店鋪買了個烤面筋

以上就是完整的注意力機制瘤礁,采用我心中的標(biāo)準(zhǔn)Query去和被標(biāo)簽化的所有店鋪Keys一一比對,此時就可以得到我的Query在每個店鋪中的匹配情況梅尤,最終去不同店鋪買不同東西的過程就是權(quán)重和Values加權(quán)求和過程柜思。簡要代碼如下:

# 假設(shè)q是(1,N,512),N就是最大標(biāo)簽化后的list長度,k是(1,M,512),M可以等于N巷燥,也可以不相等
# (1,N,512) x (1,512,M)-->(1,N,M)
attn = torch.matmul(q, k.transpose(2, 3))
# softmax轉(zhuǎn)化為概率赡盘,輸出(1,N,M),表示q中每個n和每個m的相關(guān)性
attn=F.softmax(attn, dim=-1)
# (1,N,M) x (1,M,512)-->(1,N,512)缰揪,V和k的shape相同
output = torch.matmul(attn, v)

帶有attention的RNN模塊組成的ser2seq,解碼時候可視化如下:

圖片

在沒有attention時候陨享,不同解碼階段都僅僅利用了同一個編碼層的最后一個隱含輸出,加入attention后可以通過在每個解碼時間步輸入的都是不同的上下文向量邀跃,以上圖為例霉咨,解碼階段會將第一個開啟解碼標(biāo)志<START>(也就是Q)與編碼器的每一個時間步的隱含狀態(tài)(一系列Key和Value)進(jìn)行點乘計算相似性得到每一時間步的相似性分?jǐn)?shù),然后通過softmax轉(zhuǎn)化為概率分布拍屑,然后將概率分布和對應(yīng)位置向量進(jìn)行加權(quán)求和得到新的上下文向量途戒,最后輸入解碼器中進(jìn)行解碼輸出,其詳細(xì)解碼可視化如下:

圖片

通過上述簡單的attention引入僵驰,可以將機器翻譯性能大幅提升喷斋,引入attention有以下幾個好處:

  • 注意力顯著提高了機器翻譯性能
  • 注意力允許解碼器以不同程度的權(quán)重利用到編碼器的所有信息唁毒,可以繞過瓶頸
  • 通過檢查注意力分布,可以看到解碼器在關(guān)注什么星爪,可解釋性強

1.3 基于transformer的seq2seq

基于attention的seq2seq的結(jié)構(gòu)雖然說解決了很多問題浆西,但是其依然存在不足:

  • 不管是采用RNN、LSTM還是GRU都不利于并行訓(xùn)練和推理顽腾,因為相關(guān)算法只能從左向右依次計算或者從右向左依次計算
  • 長依賴信息丟失問題近零,順序計算過程中信息會丟失,雖然LSTM號稱有緩解抄肖,但是無法徹底解決

最大問題應(yīng)該是無法并行訓(xùn)練久信,不利于大規(guī)模快速訓(xùn)練和部署漓摩,也不利于整個算法領(lǐng)域發(fā)展裙士,故在Attention Is All You Need論文中拋棄了傳統(tǒng)的CNN和RNN,將attention機制發(fā)揮到底管毙,整個網(wǎng)絡(luò)結(jié)構(gòu)完全是由Attention機制組成腿椎,這是一個比較大的進(jìn)步。

google所提基于transformer的seq2seq整體結(jié)構(gòu)如下所示:

圖片

其包括6個結(jié)構(gòu)完全相同的編碼器夭咬,和6個結(jié)構(gòu)完全相同的解碼器啃炸,其中每個編碼器和解碼器設(shè)計思想完全相同,只不過由于任務(wù)不同而有些許區(qū)別皱埠,整體詳細(xì)結(jié)構(gòu)如下所示:

圖片

第一眼看有點復(fù)雜肮帐,其中N=6,由于基于transformer的翻譯任務(wù)已經(jīng)轉(zhuǎn)化為分類任務(wù)(目標(biāo)翻譯句子有多長边器,那么就有多少個分類樣本)训枢,故在解碼器最后會引入fc+softmax層進(jìn)行概率輸出,訓(xùn)練也比較簡單忘巧,直接采用ce loss即可恒界,對于采用大量數(shù)據(jù)訓(xùn)練好的預(yù)訓(xùn)練模型,下游任務(wù)僅僅需要訓(xùn)練fc層即可砚嘴。上述結(jié)構(gòu)看起來有點復(fù)雜十酣,一個稍微抽象點的圖示如下:

圖片

看起來比基于RNN或者其余結(jié)構(gòu)構(gòu)建的seq2seq簡單很多。下面結(jié)合代碼和原理進(jìn)行深入分析际长。

1.4 transformer深入分析

前面寫了一大堆耸采,沒有理解沒有關(guān)系,對于cv初學(xué)者來說其實只需要理解QKV的含義和注意力機制的三個計算步驟:Q和所有K計算相似性工育;對相似性采用softmax轉(zhuǎn)化為概率分布虾宇;將概率分布和V進(jìn)行一一對應(yīng)相乘,最后相加得到新的和Q一樣長的向量輸出即可如绸,重點是下面要講的transformer結(jié)構(gòu)嘱朽。

下面按照 編碼器輸入數(shù)據(jù)處理->編碼器運行->解碼器輸入數(shù)據(jù)處理->解碼器運行->分類head 的實際運行流程進(jìn)行講解旭贬。

1.4.1 編碼器輸入數(shù)據(jù)處理

(1) 源單詞嵌入

以上面翻譯任務(wù)為例,原始待翻譯輸入是三個單詞:

圖片

輸入是三個單詞搪泳,為了能夠?qū)⑽谋緝?nèi)容輸入到網(wǎng)絡(luò)中肯定需要進(jìn)行向量化(不然單詞如何計算稀轨?),具體是采用nlp領(lǐng)域的embedding算法進(jìn)行詞嵌入岸军,也就是常說的Word2Vec奋刽。對于cv來說知道是干嘛的就行,不必了解細(xì)節(jié)凛膏。假設(shè)每個單詞都可以嵌入成512個長度的向量杨名,故此時輸入即為3x512,注意Word2Vec操作只會輸入到第一個編碼器中猖毫,后面的編碼器接受的輸入是前一個編碼器輸出。

為了便于組成batch(不同訓(xùn)練句子單詞個數(shù)肯定不一樣)進(jìn)行訓(xùn)練须喂,可以簡單統(tǒng)計所有訓(xùn)練句子的單詞個數(shù)吁断,取最大即可,假設(shè)統(tǒng)計后發(fā)現(xiàn)待翻譯句子最長是10個單詞坞生,那么編碼器輸入是10x512仔役,額外填充的512維向量可以采用固定的標(biāo)志編碼得到,例如$$是己。

(2) 位置編碼positional encoding

采用經(jīng)過單詞嵌入后的向量輸入到編碼器中還不夠又兵,因為transformer內(nèi)部沒有類似RNN的循環(huán)結(jié)構(gòu),沒有捕捉順序序列的能力卒废,或者說無論句子結(jié)構(gòu)怎么打亂沛厨,transformer都會得到類似的結(jié)果。為了解決這個問題摔认,在編碼詞向量時會額外引入了位置編碼position encoding向量表示兩個單詞i和j之間的距離逆皮,簡單來說就是在詞向量中加入了單詞的位置信息

加入位置信息的方式非常多参袱,最簡單的可以是直接將絕對坐標(biāo)0,1,2編碼成512個長度向量即可电谣。作者實際上提出了兩種方式:

  • 網(wǎng)絡(luò)自動學(xué)習(xí)
  • 自己定義規(guī)則

提前假設(shè)單詞嵌入并且組成batch后,shape為(b,N,512)抹蚀,N是序列最大長度剿牺,512是每個單詞的嵌入向量長度,b是batch

(a) 網(wǎng)絡(luò)自動學(xué)習(xí)

self.pos_embedding = nn.Parameter(torch.randn(1, N, 512))

比較簡單,因為位置編碼向量需要和輸入嵌入(b,N,512)相加环壤,所以其shape為(1,N,512)表示N個位置晒来,每個位置采用512長度向量進(jìn)行編碼

(b) 自己定義規(guī)則

自定義規(guī)則做法非常多,論文中采用的是sin-cos規(guī)則镐捧,具體做法是:

  • 將向量(N,512)采用如下函數(shù)進(jìn)行處理
圖片

pos即0~N,i是0-511

  • 將向量的512維度切分為奇數(shù)行和偶數(shù)行
  • 偶數(shù)行采用sin函數(shù)編碼潜索,奇數(shù)行采用cos函數(shù)編碼
  • 然后按照原始行號拼接
def get_position_angle_vec(position):
    # d_hid是0-511,position表示單詞位置0~N-1
    return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

# 每個單詞位置0~N-1都可以編碼得到512長度的向量
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
# 偶數(shù)列進(jìn)行sin
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
# 奇數(shù)列進(jìn)行cos
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

上面例子的可視化如下:

圖片

如此編碼的優(yōu)點是能夠擴展到未知的序列長度臭增,例如前向時候有特別長的句子,其可視化如下:

圖片

作者為啥要設(shè)計如此復(fù)雜的編碼規(guī)則竹习?原因是sin和cos的如下特性:

圖片

可以將用進(jìn)行線性表出:

圖片

假設(shè)k=1誊抛,那么下一個位置的編碼向量可以由前面的編碼向量線性表示,等價于以一種非常容易學(xué)會的方式告訴了網(wǎng)絡(luò)單詞之間的絕對位置整陌,讓模型能夠輕松學(xué)習(xí)到相對位置信息拗窃。注意編碼方式不是唯一的,將單詞嵌入向量和位置編碼向量相加就可以得到編碼器的真正輸入了泌辫,其輸出shape是(b,N,512)随夸。

1.4.2 編碼器前向過程

編碼器由兩部分組成:自注意力層和前饋神經(jīng)網(wǎng)絡(luò)層。

圖片

其前向可視化如下:

圖片

注意上圖沒有繪制出單詞嵌入向量和位置編碼向量相加過程震放,但是是存在的宾毒。

(1) 自注意力層

通過前面分析我們知道自注意力層其實就是attention操作,并且由于其QKV來自同一個輸入殿遂,故稱為自注意力層诈铛。我想大家應(yīng)該能想到這里attention層作用,在參考資料1博客里面舉了個簡單例子來說明attention的作用:假設(shè)我們想要翻譯的輸入句子為The animal didn't cross the street because it was too tired墨礁,這個“it”在這個句子是指什么呢幢竹?它指的是street還是這個animal呢?這對于人類來說是一個簡單的問題恩静,但是對于算法則不是焕毫。當(dāng)模型處理這個單詞“it”的時候,自注意力機制會允許“it”與“animal”建立聯(lián)系即隨著模型處理輸入序列的每個單詞驶乾,自注意力會關(guān)注整個輸入序列的所有單詞邑飒,幫助模型對本單詞更好地進(jìn)行編碼。實際上訓(xùn)練完成后確實如此轻掩,google提供了可視化工具幸乒,如下所示:

圖片

上述是從宏觀角度思考,如果從輸入輸出流角度思考唇牧,也比較容易:

圖片

假設(shè)我們現(xiàn)在要翻譯上述兩個單詞罕扎,首先將單詞進(jìn)行編碼,和位置編碼向量相加丐重,得到自注意力層輸入X,其shape為(b,N,512)腔召;然后定義三個可學(xué)習(xí)矩陣 [外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-ig5jEqK2-1675912069030)(null)] (通過nn.Linear實現(xiàn)),其shape為(512,M)扮惦,一般M等于前面維度512臀蛛,從而計算后維度不變;將X和矩陣[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-uKYMuXXa-1675912067818)(null)] 相乘,得到QKV輸出浊仆,shape為(b,N,M)客峭;然后將Q和K進(jìn)行點乘計算向量相似性;采用softmax轉(zhuǎn)換為概率分布抡柿;將概率分布和V進(jìn)行加權(quán)求和即可舔琅。其可視化如下:

圖片

上述繪制的不是矩陣形式,更好理解而已洲劣。對于第一個單詞的編碼過程是:將q1和所有的k進(jìn)行相似性計算备蚓,然后除以維度的平方根(論文中是64,本文可以認(rèn)為是512)使得梯度更加穩(wěn)定囱稽,然后通過softmax傳遞結(jié)果郊尝,這個softmax分?jǐn)?shù)決定了每個單詞對編碼當(dāng)下位置(“Thinking”)的貢獻(xiàn),最后對加權(quán)值向量求和得到z1战惊。

這個計算很明顯就是前面說的注意力機制計算過程流昏,每個輸入單詞的編碼輸出都會通過注意力機制引入其余單詞的編碼信息

上述為了方便理解才拆分這么細(xì)致吞获,實際上代碼層面采用矩陣實現(xiàn)非常簡單:

圖片

上面的操作很不錯横缔,但是還有改進(jìn)空間,論文中又增加一種叫做“多頭”注意力(“multi-headed” attention)的機制進(jìn)一步完善了自注意力層衫哥,并在兩方面提高了注意力層的性能:

  • 它擴展了模型專注于不同位置的能力。在上面的例子中襟锐,雖然每個編碼都在z1中有或多或少的體現(xiàn)撤逢,但是它可能被實際的單詞本身所支配。如果我們翻譯一個句子粮坞,比如“The animal didn’t cross the street because it was too tired”蚊荣,我們會想知道“it”指的是哪個詞,這時模型的“多頭”注意機制會起到作用莫杈。
  • 它給出了注意力層的多個“表示子空間",對于“多頭”注意機制互例,有多個查詢/鍵/值權(quán)重矩陣集(Transformer使用8個注意力頭,因此我們對于每個編碼器/解碼器有8個矩陣集合)筝闹。
圖片

簡單來說就是類似于分組操作媳叨,將輸入X分別輸入到8個attention層中,得到8個Z矩陣輸出关顷,最后對結(jié)果concat即可糊秆。論文圖示如下:

圖片

先忽略Mask的作用,左邊是單頭attention操作议双,右邊是n個單頭attention構(gòu)成的多頭自注意力層痘番。

代碼層面非常簡單,單頭attention操作如下:

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        # self.temperature是論文中的d_k ** 0.5,防止梯度過大
        # QxK/sqrt(dk)
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            # 屏蔽不想要的輸出
            attn = attn.masked_fill(mask == 0, -1e9)
        # softmax+dropout
        attn = self.dropout(F.softmax(attn, dim=-1))
        # 概率分布xV
        output = torch.matmul(attn, v)

        return output, attn

再次復(fù)習(xí)下Multi-Head Attention層的圖示汞舱,可以發(fā)現(xiàn)在前面講的內(nèi)容基礎(chǔ)上還加入了殘差設(shè)計和層歸一化操作伍纫,目的是為了防止梯度消失,加快收斂昂芜。

圖片

Multi-Head Attention實現(xiàn)在ScaledDotProductAttention基礎(chǔ)上構(gòu)建:

class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    # n_head頭的個數(shù)莹规,默認(rèn)是8
    # d_model編碼向量長度,例如本文說的512
    # d_k, d_v的值一般會設(shè)置為 n_head * d_k=d_model说铃,
    # 此時concat后正好和原始輸入一樣访惜,當(dāng)然不相同也可以,因為后面有fc層
    # 相當(dāng)于將可學(xué)習(xí)矩陣分成獨立的n_head份
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        # 假設(shè)n_head=8腻扇,d_k=64
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        # d_model輸入向量债热,n_head * d_k輸出向量
        # 可學(xué)習(xí)W^Q,W^K,W^V矩陣參數(shù)初始化
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        # 最后的輸出維度變換操作
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
        # 單頭自注意力
        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
        self.dropout = nn.Dropout(dropout)
        # 層歸一化
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        # 假設(shè)qkv輸入是(b,100,512),100是訓(xùn)練每個樣本最大單詞個數(shù)
        # 一般qkv相等幼苛,即自注意力
        residual = q
        # 將輸入x和可學(xué)習(xí)矩陣相乘窒篱,得到(b,100,512)輸出
        # 其中512的含義其實是8x64,8個head舶沿,每個head的可學(xué)習(xí)矩陣為64維度
        # q的輸出是(b,100,8,64),kv也是一樣
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # 變成(b,8,100,64)墙杯,方便后面計算,也就是8個頭單獨計算
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)   # For head axis broadcasting.
        # 輸出q是(b,8,100,64),維持不變,內(nèi)部計算流程是:
        # q*k轉(zhuǎn)置括荡,除以d_k ** 0.5高镐,輸出維度是b,8,100,100即單詞和單詞直接的相似性
        # 對最后一個維度進(jìn)行softmax操作得到b,8,100,100
        # 最后乘上V,得到b,8,100,64輸出
        q, attn = self.attention(q, k, v, mask=mask)

        # b,100,8,64-->b,100,512
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        # 殘差計算
        q += residual
        # 層歸一化畸冲,在512維度計算均值和方差嫉髓,進(jìn)行層歸一化
        q = self.layer_norm(q)

        return q, attn

現(xiàn)在pytorch新版本已經(jīng)把MultiHeadAttention當(dāng)做nn中的一個類了,可以直接調(diào)用邑闲。

(2) 前饋神經(jīng)網(wǎng)絡(luò)層

這個層就沒啥說的了算行,非常簡單:

class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        # 兩個fc層,對最后的512維度進(jìn)行變換
        self.w_1 = nn.Linear(d_in, d_hid) # position-wise
        self.w_2 = nn.Linear(d_hid, d_in) # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual

        x = self.layer_norm(x)

        return x

(3) 編碼層操作整體流程

可視化如下所示:

圖片

單個編碼層代碼如下所示:

class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        # Q K V是同一個苫耸,自注意力
        # enc_input來自源單詞嵌入向量或者前一個編碼器輸出
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn

將上述編碼過程重復(fù)n遍即可州邢,除了第一個模塊輸入是單詞嵌入向量與位置編碼的和外,其余編碼層輸入是上一個編碼器輸出即后面的編碼器輸入不需要位置編碼向量褪子。如果考慮n個編碼器的運行過程量淌,如下所示:

class Encoder(nn.Module):
    def __init__(
            self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
            d_model, d_inner, pad_idx, dropout=0.1, n_position=200):
        # nlp領(lǐng)域的詞嵌入向量生成過程(單詞在詞表里面的索引idx-->d_word_vec長度的向量)
        self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
        # 位置編碼
        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)
        # n個編碼器層
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        # 層歸一化
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, src_seq, src_mask, return_attns=False):
        # 對輸入序列進(jìn)行詞嵌入,加上位置編碼
        enc_output = self.dropout(self.position_enc(self.src_word_emb(src_seq)))
        enc_output = self.layer_norm(enc_output)
        # 作為編碼器層輸入
        for enc_layer in self.layer_stack:
            enc_output, _ = enc_layer(enc_output, slf_attn_mask=src_mask)
        return enc_output

到目前為止我們就講完了編碼部分的全部流程和代碼細(xì)節(jié)『稚福現(xiàn)在再來看整個transformer算法就會感覺親切很多了:

圖片

1.4.3 解碼器輸入數(shù)據(jù)處理

在分析解碼器結(jié)構(gòu)前先看下解碼器整體結(jié)構(gòu)类少,方便理解:

圖片

其輸入數(shù)據(jù)處理也要區(qū)分第一個解碼器和后續(xù)解碼器,和編碼器類似渔扎,第一個解碼器輸入不僅包括最后一個編碼器輸出硫狞,還需要額外的輸出嵌入向量,而后續(xù)解碼器輸入是來自最后一個編碼器輸出和前面解碼器輸出。

(1) 目標(biāo)單詞嵌入

這個操作和源單詞嵌入過程完全相同残吩,維度也是512财忽,假設(shè)輸出是i am a student,那么需要對這4個單詞也利用word2vec算法轉(zhuǎn)化為4x512的矩陣泣侮,作為第一個解碼器的單詞嵌入輸入即彪。

(2) 位置編碼

同樣的也需要對解碼器輸入引入位置編碼,做法和編碼器部分完全相同活尊,且將目標(biāo)單詞嵌入向量和位置編碼向量相加即可作為第一個解碼器輸入劲件。

和編碼器單詞嵌入不同的地方是在進(jìn)行目標(biāo)單詞嵌入前章蚣,還需要將目標(biāo)單詞即是i am a student右移動一位,新增加的一個位置采用提前定義好的標(biāo)志位BOS_WORD代替,現(xiàn)在就變成[BOS_WORD,i,am,a,student]择浊,為啥要右移艳丛?因為解碼過程和seq2seq一樣是順序解碼的荔茬,需要提供一個開始解碼標(biāo)志霞揉,。不然第一個時間步的解碼單詞i是如何輸出的呢癣猾?具體解碼過程其實是:輸入BOS_WORD敛劝,解碼器輸出i;輸入前面已經(jīng)解碼的BOS_WORD和i纷宇,解碼器輸出am...夸盟,輸入已經(jīng)解碼的BOS_WORD、i像捶、am满俗、a和student,解碼器輸出解碼結(jié)束標(biāo)志位EOS_WORD,每次解碼都會利用前面已經(jīng)解碼輸出的所有單詞嵌入信息

下面有個非常清晰的gif圖作岖,一目了然:

圖片

上圖沒有繪制BOS_WORD嵌入向量輸入,然后解碼出i單詞的過程五芝。

1.4.4 解碼器前向過程

仔細(xì)觀察解碼器結(jié)構(gòu)痘儡,其包括:帶有mask的MultiHeadAttention、MultiHeadAttention和前饋神經(jīng)網(wǎng)絡(luò)層三個組件枢步,帶有mask的MultiHeadAttention和MultiHeadAttention結(jié)構(gòu)和代碼寫法是完全相同沉删,唯一區(qū)別是是否輸入了mask。

為啥要mask醉途?原因依然是順序解碼導(dǎo)致的矾瑰。試想模型訓(xùn)練好了,開始進(jìn)行翻譯(測試)隘擎,其流程就是上面寫的:輸入BOS_WORD殴穴,解碼器輸出i;輸入前面已經(jīng)解碼的BOS_WORD和i,解碼器輸出am...采幌,輸入已經(jīng)解碼的BOS_WORD劲够、i、am休傍、a和student征绎,解碼器輸出解碼結(jié)束標(biāo)志位EOS_WORD,每次解碼都會利用前面已經(jīng)解碼輸出的所有單詞嵌入信息,這個測試過程是沒有問題磨取,但是訓(xùn)練時候我肯定不想采用上述順序解碼類似rnn即一個一個目標(biāo)單詞嵌入向量順序輸入訓(xùn)練人柿,肯定想采用類似編碼器中的矩陣并行算法,一步就把所有目標(biāo)單詞預(yù)測出來忙厌。要實現(xiàn)這個功能就可以參考編碼器的操作凫岖,把目標(biāo)單詞嵌入向量組成矩陣一次輸入即可,但是在解碼am時候慰毅,不能利用到后面單詞a和student的目標(biāo)單詞嵌入向量信息隘截,否則這就是作弊(測試時候不可能能未卜先知)。為此引入mask汹胃,目的是構(gòu)成下三角矩陣婶芭,右上角全部設(shè)置為負(fù)無窮(相當(dāng)于忽略),從而實現(xiàn)當(dāng)解碼第一個字的時候着饥,第一個字只能與第一個字計算相關(guān)性犀农,當(dāng)解出第二個字的時候,只能計算出第二個字與第一個字和第二個字的相關(guān)性宰掉。具體是:在解碼器中呵哨,自注意力層只被允許處理輸出序列中更靠前的那些位置,在softmax步驟前轨奄,它會把后面的位置給隱去(把它們設(shè)為-inf)孟害。

還有個非常重要點需要知道(看圖示可以發(fā)現(xiàn)):解碼器內(nèi)部的帶有mask的MultiHeadAttention的qkv向量輸入來自目標(biāo)單詞嵌入或者前一個解碼器輸出,三者是相同的挪拟,但是后面的MultiHeadAttention的qkv向量中的kv來自最后一層編碼器的輸入挨务,而q來自帶有mask的MultiHeadAttention模塊的輸出。

關(guān)于帶mask的注意力層寫法其實就是前面提到的代碼:

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        # 假設(shè)q是b,8,10,64(b是batch玉组,8是head個數(shù)谎柄,10是樣本最大單詞長度,
        # 64是每個單詞的編碼向量)
        # attn輸出維度是b,8,10,10
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
        # 故mask維度也是b,8,10,10
        # 忽略b,8惯雳,只關(guān)注10x10的矩陣朝巫,其是下三角矩陣,下三角位置全1石景,其余位置全0
        if mask is not None:
            # 提前算出mask劈猿,將為0的地方變成極小值-1e9拙吉,把這些位置的值設(shè)置為忽略
            # 目的是避免解碼過程中利用到未來信息
            attn = attn.masked_fill(mask == 0, -1e9)
        # softmax+dropout
        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn

可視化如下:圖片來源https://zhuanlan.zhihu.com/p/44731789

圖片

整個解碼器代碼和編碼器非常類似:

class DecoderLayer(nn.Module):
    ''' Compose with three layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(
            self, dec_input, enc_output,
            slf_attn_mask=None, dec_enc_attn_mask=None):
        # 標(biāo)準(zhǔn)的自注意力,QKV=dec_input來自目標(biāo)單詞嵌入或者前一個解碼器輸出
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        # KV來自最后一個編碼層輸出enc_output糙臼,Q來自帶有mask的self.slf_attn輸出
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        dec_output = self.pos_ffn(dec_output)
        return dec_output, dec_slf_attn, dec_enc_attn

考慮n個解碼器模塊庐镐,其整體流程為:

class Decoder(nn.Module):
    def __init__(
            self, n_trg_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
            d_model, d_inner, pad_idx, n_position=200, dropout=0.1):
        # 目標(biāo)單詞嵌入
        self.trg_word_emb = nn.Embedding(n_trg_vocab, d_word_vec, padding_idx=pad_idx)
        # 位置嵌入向量
        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)
        # n個解碼器
        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        # 層歸一化
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, trg_seq, trg_mask, enc_output, src_mask, return_attns=False):
        # 目標(biāo)單詞嵌入+位置編碼
        dec_output = self.dropout(self.position_enc(self.trg_word_emb(trg_seq)))
        dec_output = self.layer_norm(dec_output)
        # 遍歷每個解碼器
        for dec_layer in self.layer_stack:  
            # 需要輸入3個信息:目標(biāo)單詞嵌入+位置編碼、最后一個編碼器輸出enc_output
            # 和dec_enc_attn_mask变逃,解碼時候不能看到未來單詞信息
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output, enc_output, slf_attn_mask=trg_mask, dec_enc_attn_mask=src_mask)
        return dec_output

1.4.5 分類層

在進(jìn)行編碼器-解碼器后輸出依然是向量必逆,需要在后面接fc+softmax層進(jìn)行分類訓(xùn)練。假設(shè)當(dāng)前訓(xùn)練過程是翻譯任務(wù)需要輸出i am a student EOS_WORD這5個單詞揽乱。假設(shè)我們的模型是從訓(xùn)練集中學(xué)習(xí)一萬個不同的英語單詞(我們模型的“輸出詞表”)名眉。因此softmax后輸出為一萬個單元格長度的向量,每個單元格對應(yīng)某一個單詞的分?jǐn)?shù)凰棉,這其實就是普通多分類問題损拢,只不過維度比較大而已。

依然以前面例子為例撒犀,假設(shè)編碼器輸出shape是(b,100,512)福压,經(jīng)過fc后變成(b,100,10000),然后對最后一個維度進(jìn)行softmax操作或舞,得到bx100個單詞的概率分布荆姆,在訓(xùn)練過程中bx100個單詞是知道label的,故可以直接采用ce loss進(jìn)行訓(xùn)練映凳。

self.trg_word_prj = nn.Linear(d_model, n_trg_vocab, bias=False)
dec_output, *_ = self.model.decoder(trg_seq, trg_mask, enc_output, src_mask)
return F.softmax(self.model.trg_word_prj(dec_output), dim=-1)

1.4.6 前向流程

以翻譯任務(wù)為例:

  • 將源單詞進(jìn)行嵌入胆筒,組成矩陣(加上位置編碼矩陣)輸入到n個編碼器中,輸出編碼向量KV
  • 第一個解碼器先輸入一個BOS_WORD單詞嵌入向量诈豌,后續(xù)解碼器接受該解碼器輸出仆救,結(jié)合KV進(jìn)行第一次解碼
  • 將第一次解碼單詞進(jìn)行嵌入,聯(lián)合BOS_WORD單詞嵌入向量構(gòu)成矩陣再次輸入到解碼器中進(jìn)行第二次解碼矫渔,得到解碼單詞
  • 不斷循環(huán)彤蔽,每次的第一個解碼器輸入都不同,其包含了前面時間步長解碼出的所有單詞
  • 直到輸出EOS_WORD表示解碼結(jié)束或者強制設(shè)置最大時間步長即可

這個解碼過程其實就是標(biāo)準(zhǔn)的seq2seq流程庙洼。到目前為止就描述完了整個標(biāo)準(zhǔn)transformer訓(xùn)練和測試流程铆惑。

2 視覺領(lǐng)域的transformer

在理解了標(biāo)準(zhǔn)的transformer后,再來看視覺領(lǐng)域transformer就會非常簡單送膳,因為在cv領(lǐng)域應(yīng)用transformer時候大家都有一個共識:盡量不改動transformer結(jié)構(gòu),這樣才能和NLP領(lǐng)域發(fā)展對齊丑蛤,所以大家理解cv里面的transformer操作是非常簡單的叠聋。

2.1 分類vision transformer

論文題目:An Image is Worth 16x16 Words:Transformers for Image Recognition at Scale

論文地址:https://arxiv.org/abs/2010.11929

github: https://github.com/lucidrains/vit-pytorch

其做法超級簡單,只含有編碼器模塊:

圖片

本文出發(fā)點是徹底拋棄CNN受裹,以前的cv領(lǐng)域雖然引入transformer碌补,但是或多或少都用到了cnn或者rnn虏束,本文就比較純粹了,整個算法幾句話就說清楚了厦章,下面直接分析镇匀。

2.1.1 圖片分塊和降維

因為transformer的輸入需要序列,所以最簡單做法就是把圖片切分為patch袜啃,然后拉成序列即可汗侵。假設(shè)輸入圖片大小是256x256,打算分成64個patch群发,每個patch是32x32像素

x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

這個寫法是采用了愛因斯坦表達(dá)式晰韵,具體是采用了einops庫實現(xiàn),內(nèi)部集成了各種算子熟妓,rearrange就是其中一個雪猪,非常高效。不懂這種語法的請自行百度起愈。p就是patch大小只恨,假設(shè)輸入是b,3,256,256,則rearrange操作是先變成(b,3,8x32,8x32)抬虽,最后變成(b,8x8,32x32x3)即(b,64,3072)官觅,將每張圖片切分成64個小塊,每個小塊長度是32x32x3=3072斥赋,也就是說輸入長度為64的圖像序列缰猴,每個元素采用3072長度進(jìn)行編碼。

考慮到3072有點大疤剑,故作者先進(jìn)行降維:

# 將3072變成dim滑绒,假設(shè)是1024
self.patch_to_embedding = nn.Linear(patch_dim, dim)
x = self.patch_to_embedding(x)

仔細(xì)看論文上圖,可以發(fā)現(xiàn)假設(shè)切成9個塊隘膘,但是最終到transfomer輸入是10個向量疑故,額外追加了一個0和。為啥要追加弯菊?原因是我們現(xiàn)在沒有解碼器了纵势,而是編碼后直接就進(jìn)行分類預(yù)測,那么該解碼器就要負(fù)責(zé)一點點解碼器功能管钳,那就是:需要一個類似開啟解碼標(biāo)志钦铁,非常類似于標(biāo)準(zhǔn)transformer解碼器中輸入的目標(biāo)嵌入向量右移一位操作。試下如果沒有額外輸入才漆,9個塊輸入9個編碼向量輸出牛曹,那么對于分類任務(wù)而言,我應(yīng)該取哪個輸出向量進(jìn)行后續(xù)分類呢醇滥?選擇任何一個都說不通黎比,所以作者追加了一個可學(xué)習(xí)嵌入向量輸入超营。那么額外的可學(xué)習(xí)嵌入向量為啥要設(shè)計為可學(xué)習(xí),而不是類似nlp中采用固定的token代替阅虫?個人不負(fù)責(zé)任的猜測這應(yīng)該就是圖片領(lǐng)域和nlp領(lǐng)域的差別演闭,nlp里面每個詞其實都有具體含義,是離散的颓帝,但是圖像領(lǐng)域沒有這種真正意義上的離散token米碰,有的只是一堆連續(xù)特征或者圖像像素,如果不設(shè)置為可學(xué)習(xí)躲履,那還真不知道應(yīng)該設(shè)置為啥內(nèi)容比較合適见间,全0和全1也說不通。自此現(xiàn)在就是變成10個向量輸出工猜,輸出也是10個編碼向量米诉,然后取第0個編碼輸出進(jìn)行分類預(yù)測即可。從這個角度看可以認(rèn)為編碼器多了一點點解碼器功能篷帅。具體做法超級簡單史侣,0就是位置編碼向量,是可學(xué)習(xí)的patch嵌入向量魏身。

# dim=1024
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# 變成(b,64,1024)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
# 額外追加token惊橱,變成b,65,1024
x = torch.cat((cls_tokens, x), dim=1)

2.1.2 位置編碼

位置編碼也是必不可少的,長度應(yīng)該是1024箭昵,這里做的比較簡單税朴,沒有采用sincos編碼,而是直接設(shè)置為可學(xué)習(xí)家制,效果差不多

# num_patches=64正林,dim=1024,+1是因為多了一個cls開啟解碼標(biāo)志
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

對訓(xùn)練好的pos_embedding進(jìn)行可視化,如下所示:

[圖片上傳失敗...(image-a736e7-1675912126439)]

相鄰位置有相近的位置編碼向量颤殴,整體呈現(xiàn)2d空間位置排布一樣觅廓。

將patch嵌入向量和位置編碼向量相加即可作為編碼器輸入

x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)

2.1.3 編碼器前向過程

作者采用的是沒有任何改動的transformer,故沒有啥說的涵但。

self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)

假設(shè)輸入是(b,65,1024)杈绸,那么transformer輸出也是(b,65,1024)

2.1.4 分類head

在編碼器后接fc分類器head即可

self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, num_classes)
        )

# 65個輸出里面只需要第0個輸出進(jìn)行后續(xù)分類即可
self.mlp_head(x[:, 0])

到目前為止就全部寫完了,是不是非常簡單矮瘟,外層整體流程為:

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dropout=0.,emb_dropout=0.):
        super().__init__()
        # image_size輸入圖片大小 256
        # patch_size 每個patch的大小 32
        num_patches = (image_size // patch_size) ** 2  # 一共有多少個patch 8x8=64
        patch_dim = channels * patch_size ** 2  # 3x32x32=3072
        self.patch_size = patch_size  # 32
        # 1,64+1,1024,+1是因為token瞳脓,可學(xué)習(xí)變量,不是固定編碼
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # 圖片維度太大了澈侠,需要先降維
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        # 分類輸出位置標(biāo)志劫侧,否則分類輸出不知道應(yīng)該取哪個位置
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        # 編碼器
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
        # 輸出頭
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img, mask=None):
        p = self.patch_size

        # 先把圖片變成64個patch,輸出shape=b,64,3072
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        # 輸出 b,64,1024
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape
        # 輸出 b,1,1024
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        # 額外追加token,變成b,65,1024
        x = torch.cat((cls_tokens, x), dim=1)
        # 加上位置編碼1,64+1,1024
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x, mask)
        # 分類head,只需要x[0]即可
        # x = self.to_cls_token(x[:, 0])
        x = x[:, 0]
        return self.mlp_head(x)

2.1.5 實驗分析

作者得出的結(jié)論是:cv領(lǐng)域應(yīng)用transformer需要大量數(shù)據(jù)進(jìn)行預(yù)訓(xùn)練埋涧,在同等數(shù)據(jù)量的情況下性能不然cnn板辽。一旦數(shù)據(jù)量上來了,對應(yīng)的訓(xùn)練時間也會加長很多棘催,那么就可以輕松超越cnn劲弦。

[圖片上傳失敗...(image-c461a9-1675912126439)]

[圖片上傳失敗...(image-3cf3f4-1675912126439)]

同時應(yīng)用transformer,一個突出優(yōu)點是可解釋性比較強:

[圖片上傳失敗...(image-1f259b-1675912126439)]

2.2 目標(biāo)檢測detr

論文名稱:End-to-End Object Detection with Transformers

論文地址:https://arxiv.org/abs/2005.12872

github:https://github.com/facebookresearch/detr

detr是facebook提出的引入transformer到目標(biāo)檢測領(lǐng)域的算法醇坝,效果很好邑跪,做法也很簡單,符合其一貫的簡潔優(yōu)雅設(shè)計做法呼猪。

[圖片上傳失敗...(image-b7d99b-1675912126439)]

對于目標(biāo)檢測任務(wù)画畅,其要求輸出給定圖片中所有前景物體的類別和bbox坐標(biāo),該任務(wù)實際上是無序集合預(yù)測問題宋距。針對該問題轴踱,detr做法非常簡單:給定一張圖片,經(jīng)過CNN進(jìn)行特征提取谚赎,然后變成特征序列輸入到transformer的編解碼器中淫僻,直接輸出指定長度為N的無序集合,集合中每個元素包含物體類別和坐標(biāo)壶唤。其中N表示整個數(shù)據(jù)集中圖片上最多物體的數(shù)目雳灵,因為整個訓(xùn)練和測試都Batch進(jìn)行,如果不設(shè)置最大輸出集合數(shù)闸盔,無法進(jìn)行batch訓(xùn)練悯辙,如果圖片中物體不夠N個,那么就采用no object填充迎吵,表示該元素是背景躲撰。

整個思想看起來非常簡單,相比faster rcnn或者yolo算法那就簡單太多了钓觉,因為其不需要設(shè)置先驗anchor茴肥,超參幾乎沒有,也不需要nms(因為輸出的無序集合沒有重復(fù)情況)荡灾,并且在代碼程度相比faster rcnn那就不知道簡單多少倍了瓤狐,通過簡單修改就可以應(yīng)用于全景分割任務(wù)∨希可以推測础锐,如果transformer真正大規(guī)模應(yīng)用于CV領(lǐng)域,那么對初學(xué)者來說就是福音了荧缘,理解transformer就幾乎等于理解了整個cv領(lǐng)域了(當(dāng)然也可能是壞事)皆警。

2.2.1 detr核心思想分析

相比faster rcnn等做法,detr最大特點是將目標(biāo)檢測問題轉(zhuǎn)化為無序集合預(yù)測問題截粗。論文中特意指出faster rcnn這種設(shè)置一大堆anchor信姓,然后基于anchor進(jìn)行分類和回歸其實屬于代理做法即不是最直接做法鸵隧,目標(biāo)檢測任務(wù)就是輸出無序集合,而faster rcnn等算法通過各種操作意推,并結(jié)合復(fù)雜后處理最終才得到無序集合屬于繞路了豆瘫,而detr就比較純粹了。

盡管將transformer引入目標(biāo)檢測領(lǐng)域可以避免上述各種問題菊值,但是其依然存在兩個核心操作:

  • 無序集合輸出的loss計算
  • 針對目標(biāo)檢測的transformer改進(jìn)

2.2.2 detr算法實現(xiàn)細(xì)節(jié)

下面結(jié)合代碼和原理對其核心環(huán)節(jié)進(jìn)行深入分析外驱。

2.2.2.1 無序集合輸出的loss計算

在分析loss計算前,需要先明確N個無序集合的target構(gòu)建方式腻窒。作者在coco數(shù)據(jù)集上統(tǒng)計昵宇,一張圖片最多標(biāo)注了63個物體,所以N應(yīng)該要不小于63儿子,作者設(shè)置的是100瓦哎。為啥要設(shè)置為100?有人猜測是和coco評估指標(biāo)只取前100個預(yù)測結(jié)果算法指標(biāo)有關(guān)系典徊。

detr輸出是包括batchx100個無序集合杭煎,每個集合包括類別和坐標(biāo)信息。對于coco數(shù)據(jù)而言卒落,作者設(shè)置類別為91(coco類別標(biāo)注索引是1-91,但是實際就標(biāo)注了80個類別)羡铲,加上背景一共92個類別,對于坐標(biāo)分支采用4個歸一化值表征即cxcywh中心點儡毕、wh坐標(biāo)也切,然后除以圖片寬高進(jìn)行歸一化(沒有采用復(fù)雜變換策略),故每個集合是 [外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-Ah59wMHA-1675912067642)(null)] 腰湾,c是長度為92的分類向量雷恃,b是長度為4的bbox坐標(biāo)向量》逊唬總之detr輸出集合包括兩個分支:分類分支shape=(b,100,92)倒槐,bbox坐標(biāo)分支shape=(b,100,4),對應(yīng)的target也是包括分類target和bbox坐標(biāo)target附井,如果不夠100讨越,則采用背景填充,計算loss時候bbox分支僅僅計算有物體位置永毅,背景集合忽略把跨。

現(xiàn)在核心問題來了:輸出的bx100個檢測結(jié)果是無序的,如何和gt bbox計算loss沼死?這就需要用到經(jīng)典的雙邊匹配算法了着逐,也就是常說的匈牙利算法,該算法廣泛應(yīng)用于最優(yōu)分配問題,在bottom-up人體姿態(tài)估計算法中進(jìn)行分組操作時候也經(jīng)常使用耸别。detr中利用匈牙利算法先進(jìn)行最優(yōu)一對一匹配得到匹配索引健芭,然后對bx100個結(jié)果進(jìn)行重排就和gt bbox對應(yīng)上了(對gt bbox進(jìn)行重排也可以,沒啥區(qū)別)秀姐,就可以算loss了吟榴。

匈牙利算法是一個標(biāo)準(zhǔn)優(yōu)化算法,具體是組合優(yōu)化算法囊扳,在scipy.optimize.linear_sum_assignmen函數(shù)中有實現(xiàn),一行代碼就可以得到最優(yōu)匹配兜看,網(wǎng)上解讀也非常多锥咸,這里就不寫細(xì)節(jié)了,該函數(shù)核心是需要輸入A集合和B集合兩兩元素之間的連接權(quán)重细移,基于該重要性進(jìn)行內(nèi)部最優(yōu)匹配搏予,連接權(quán)重大的優(yōu)先匹配

上述描述優(yōu)化過程可以采用如下公式表達(dá):

[圖片上傳失敗...(image-688515-1675912126439)]

優(yōu)化對象是 [外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-A4lkvnJa-1675912068640)(null)] 弧轧,其是長度為N的list雪侥, [外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-kruQ4Upr-1675912069307)(null)] , [外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-uzZtmRDR-1675912070767)(null)] 表示無序gt bbox集合的哪個元素和輸出預(yù)測集合中的第i個匹配精绎。其實簡單來說就是找到最優(yōu)匹配速缨,因為在最佳匹配情況下l_match和最小即loss最小。

前面說過匈牙利算法核心是需要提供輸入A集合和B集合兩兩元素之間的連接權(quán)重代乃,這里就是要輸入N個輸出集合和M個gt bbox之間的關(guān)聯(lián)程度旬牲,如下所示

[圖片上傳失敗...(image-ded520-1675912126439)]

而Lbox具體是:

[圖片上傳失敗...(image-61c8ad-1675912126439)]

Hungarian意思就是匈牙利,也就是前面的L_match搁吓,上述意思是需要計算M個gt bbox和N個輸出集合兩兩之間的廣義距離原茅,距離越近表示越可能是最優(yōu)匹配關(guān)系,也就是兩者最密切堕仔。廣義距離的計算考慮了分類分支和bbox分支擂橘,下面結(jié)合代碼直接說明,比較簡單摩骨。

# detr分類輸出通贞,num_queries=100,shape是(b,100,92)
bs, num_queries = outputs["pred_logits"].shape[:2]
# 得到概率輸出(bx100,92)
out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) 
# 得到bbox分支輸出(bx100,4)
out_bbox = outputs["pred_boxes"].flatten(0, 1)

# 準(zhǔn)備分類target shape=(m,)里面存儲的是類別索引仿吞,m包括了整個batch內(nèi)部的所有g(shù)t bbox
tgt_ids = torch.cat([v["labels"] for v in targets]) 
# 準(zhǔn)備bbox target shape=(m,4)滑频,已經(jīng)歸一化了
tgt_bbox = torch.cat([v["boxes"] for v in targets])  

#核心
#bx100,92->bx100,m,對于每個預(yù)測結(jié)果唤冈,把目前gt里面有的所有類別值提取出來峡迷,其余值不需要參與匹配
#對應(yīng)上述公式,類似于nll loss,但是更加簡單
cost_class = -out_prob[:, tgt_ids]  
#計算out_bbox和tgt_bbox兩兩之間的l1距離 bx100,m
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
#額外多計算一個giou loss bx100,m
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

#得到最終的廣義距離bx100,m绘搞,距離越小越可能是最優(yōu)匹配
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
# bx100,m--> batch,100,m
C = C.view(bs, num_queries, -1).cpu() 

#計算每個batch內(nèi)部有多少物體彤避,后續(xù)計算時候按照單張圖片進(jìn)行匹配,沒必要batch級別匹配,徒增計算
sizes = [len(v["boxes"]) for v in targets]
#匈牙利最優(yōu)匹配夯辖,返回匹配索引
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]

return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

在得到匹配關(guān)系后算loss就水到渠成了琉预。分類分支計算ce loss,bbox分支計算l1 loss+giou loss

def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
    #shape是(b,100,92)
    src_logits = outputs['pred_logits']
  #得到匹配后索引蒿褂,作用在label上
    idx = self._get_src_permutation_idx(indices) 
    #得到匹配后的分類target
    target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
    #加入背景(self.num_classes)圆米,補齊bx100個
    target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                dtype=torch.int64, device=src_logits.device)
    #shape是(b,100,),存儲的是索引,不是one-hot
    target_classes[idx] = target_classes_o
    #計算ce loss,self.empty_weight前景和背景權(quán)重是1和0.1,克服類別不平衡
    loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
    losses = {'loss_ce': loss_ce}
    return losses
def loss_boxes(self, outputs, targets, indices, num_boxes):
    idx = self._get_src_permutation_idx(indices)
    src_boxes = outputs['pred_boxes'][idx]
    target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
    #l1 loss
    loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

    losses = {}
    losses['loss_bbox'] = loss_bbox.sum() / num_boxes
    #giou loss
    loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
        box_ops.box_cxcywh_to_xyxy(src_boxes),
        box_ops.box_cxcywh_to_xyxy(target_boxes)))
    losses['loss_giou'] = loss_giou.sum() / num_boxes
    return losses

2.2.2.2 針對目標(biāo)檢測的transformer改進(jìn)

分析完訓(xùn)練最關(guān)鍵的:雙邊匹配+loss計算部分啄栓,現(xiàn)在需要考慮在目標(biāo)檢測算法中transformer如何設(shè)計娄帖?下面按照算法的4個步驟講解。

[圖片上傳失敗...(image-39eacb-1675912126439)]

transformer細(xì)節(jié)如下:

[圖片上傳失敗...(image-43566-1675912126439)]

(1) cnn骨架特征提取

骨架網(wǎng)絡(luò)可以是任何一種昙楚,作者選擇resnet50近速,將最后一個stage即stride=32的特征圖作為編碼器輸入。由于resnet僅僅作為一個小部分且已經(jīng)經(jīng)過了imagenet預(yù)訓(xùn)練堪旧,故和常規(guī)操作一樣削葱,會進(jìn)行如下操作:

  • resnet中所有BN都固定,即采用全局均值和方差
  • resnet的stem和第一個stage不進(jìn)行參數(shù)更新淳梦,即parameter.requires_grad_(False)
  • backbone的學(xué)習(xí)率小于transformer,lr_backbone=1e-05,其余為0.0001

假設(shè)輸入是(b,c,h,w)析砸,則resnet50輸出是(b,1024,h//32,w//32),1024比較大爆袍,為了節(jié)省計算量干厚,先采用1x1卷積降維為256,最后轉(zhuǎn)化為序列格式輸入到transformer中,輸入shape=(h'xw',b,256)螃宙,h'=h//32

self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
# 輸出是(b,256,h//32,w//32)
src=self.input_proj(src)
# 變成序列模式蛮瞄,(h'xw',b,256),256是每個詞的編碼長度
src = src.flatten(2).permute(2, 0, 1)

(2) 編碼器設(shè)計和輸入

編碼器結(jié)構(gòu)設(shè)計沒有任何改變,但是輸入改變了谆扎。

a) 位置編碼需要考慮2d空間

由于圖像特征是2d特征挂捅,故位置嵌入向量也需要考慮xy方向。前面說過編碼方式可以采用sincos堂湖,也可以設(shè)置為可學(xué)習(xí)闲先,本文采用的依然是sincos模式,和前面說的一樣无蜂,但是需要考慮xy兩個方向(前面說的序列只有x方向)伺糠。

#輸入是b,c,h,w
#tensor_list的類型是NestedTensor,內(nèi)部自動附加了mask斥季,
#用于表示動態(tài)shape训桶,是pytorch中tensor新特性https://github.com/pytorch/nestedtensor
x = tensor_list.tensors # 原始tensor數(shù)據(jù)
# 附加的mask累驮,shape是b,h,w 全是false
mask = tensor_list.mask  
not_mask = ~mask
# 因為圖像是2d的,所以位置編碼也分為x,y方向
# 1 1 1 1 ..  2 2 2 2... 3 3 3...
y_embed = not_mask.cumsum(1, dtype=torch.float32) 
# 1 2 3 4 ... 1 2 3 4...
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
    eps = 1e-6
    y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
    x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

# 0~127 self.num_pos_feats=128,因為前面輸入向量是256舵揭,編碼是一半sin谤专,一半cos
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
# 歸一化
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
# 輸出shape=b,h,w,128
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
# 每個特征圖的xy位置都編碼成256的向量,其中前128是y方向編碼午绳,而128是x方向編碼
return pos  # b,n=256,h,w

可以看出對于h//32,w//32的2d圖像特征置侍,不是類似vision transoformer做法簡單的將其拉伸為h//32 x w//32,然后從0-n進(jìn)行長度為256的位置編碼拦焚,而是考慮了xy方向同時編碼蜡坊,每個方向各編碼128維向量,這種編碼方式更符合圖像特定赎败。

還有一個細(xì)節(jié)需要注意:原始transformer的n個編碼器輸入中算色,只有第一個編碼器需要輸入位置編碼向量,但是detr里面對每個編碼器都輸入了同一個位置編碼向量螟够,論文中沒有寫為啥要如此修改。

b) QKV處理邏輯不同

作者設(shè)置編碼器一共6個峡钓,并且位置編碼向量僅僅加到QK中妓笙,V中沒有加入位置信息,這個和原始做法不一樣能岩,原始做法是QKV都加上了位置編碼寞宫,論文中也沒有寫為啥要如此修改。

其余地方就完全相同了腐芍,故代碼就沒必要貼了稚照§排遥總結(jié)下和原始transformer編碼器不同的地方:

  • 輸入編碼器的位置編碼需要考慮2d空間位置
  • 位置編碼向量需要加入到每個編碼器中
  • 在編碼器內(nèi)部位置編碼僅僅和QK相加,V不做任何處理

經(jīng)過6個編碼器forward后钥屈,輸出shape為(h//32xw//32,b,256)。

c) 編碼器部分整體運行流程

6個編碼器整體forward流程如下:

class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        # 編碼器copy6份
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        # 內(nèi)部包括6個編碼器坝辫,順序運行
        # src是圖像特征輸入篷就,shape=hxw,b,256
        output = src
        for layer in self.layers:
            # 每個編碼器都需要加入pos位置編碼
            # 第一個編碼器輸入來自圖像特征,后面的編碼器輸入來自前一個編碼器輸出
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)
        return output

每個編碼器內(nèi)部運行流程如下:

def forward_post(self,
                src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
               pos: Optional[Tensor] = None):
    # 和標(biāo)準(zhǔn)做法有點不一樣近忙,src加上位置編碼得到q和k竭业,但是v依然還是src,
    # 也就是v和qk不一樣
    q = k = src+pos
    src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
    src = src + self.dropout1(src2)
    src = self.norm1(src)
    src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
    src = src + self.dropout2(src2)
    src = self.norm2(src)
    return src

(3) 解碼器設(shè)計和輸入

解碼器結(jié)構(gòu)設(shè)計沒有任何改變及舍,但是輸入也改變了未辆。

a) 新引入Object queries

object queries(shape是(100,256))可以簡單認(rèn)為是輸出位置編碼,其作用主要是在學(xué)習(xí)過程中提供目標(biāo)對象和全局圖像之間的關(guān)系,相當(dāng)于全局注意力,必不可少非常關(guān)鍵锯玛。代碼形式上是可學(xué)習(xí)位置編碼矩陣咐柜。和編碼器一樣,該可學(xué)習(xí)位置編碼向量也會輸入到每一個解碼器中。我們可以嘗試通俗理解:object queries矩陣內(nèi)部通過學(xué)習(xí)建模了100個物體之間的全局關(guān)系炕桨,例如房間里面的桌子旁邊(A類)一般是放椅子(B類)饭尝,而不會是放一頭大象(C類),那么在推理時候就可以利用該全局注意力更好的進(jìn)行解碼預(yù)測輸出献宫。

# num_queries=100,hidden_dim=256
self.query_embed = nn.Embedding(num_queries, hidden_dim)

論文中指出object queries作用非常類似faster rcnn中的anchor钥平,只不過這里是可學(xué)習(xí)的,不是提前設(shè)置好的姊途。

b) 位置編碼也需要

編碼器環(huán)節(jié)采用的sincos位置編碼向量也可以考慮引入涉瘾,且該位置編碼向量輸入到每個解碼器的第二個Multi-Head Attention中,后面有是否需要該位置編碼的對比實驗捷兰。

c) QKV處理邏輯不同

解碼器一共包括6個立叛,和編碼器中QKV一樣,V不會加入位置編碼贡茅。上述說的三個操作秘蛇,只要看下網(wǎng)絡(luò)結(jié)構(gòu)圖就一目了然了。

d) 一次解碼輸出全部無序集合

和原始transformer順序解碼操作不同的是顶考,detr一次就把N個無序框并行輸出了(因為任務(wù)是無序集合赁还,做成順序推理有序輸出沒有很大必要)。為了說明如何實現(xiàn)該功能驹沿,我們需要先回憶下原始transformer的順序解碼過程:輸入BOS_WORD艘策,解碼器輸出i;輸入前面已經(jīng)解碼的BOS_WORD和i渊季,解碼器輸出am...朋蔫,輸入已經(jīng)解碼的BOS_WORD、i却汉、am驯妄、a和student,解碼器輸出解碼結(jié)束標(biāo)志位EOS_WORD,每次解碼都會利用前面已經(jīng)解碼輸出的所有單詞嵌入信息『仙埃現(xiàn)在就是一次解碼富玷,故只需要初始化時候輸入一個全0的查詢向量A,類似于BOS_WORD作用既穆,然后第一個解碼器接受該輸入A赎懦,解碼輸出向量作為下一個解碼器輸入,不斷推理即可幻工,最后一層解碼輸出即為我們需要的輸出励两,不需要在第二個解碼器輸入時候考慮BOS_WORD和第一個解碼器輸出。

總結(jié)下和原始transformer解碼器不同的地方:

  • 額外引入可學(xué)習(xí)的Object queries囊颅,相當(dāng)于可學(xué)習(xí)anchor当悔,提供全局注意力
  • 編碼器采用的sincos位置編碼向量也需要輸入解碼器中傅瞻,并且每個解碼器都輸入
  • QKV處理邏輯不同
  • 不需要順序解碼,一次即可輸出N個無序集合

e) 解碼器整體運行流程

n個解碼器整體流程如下:

class TransformerDecoder(nn.Module):
    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        # 首先query_pos是query_embed盲憎,可學(xué)習(xí)輸出位置向量shape=100,b,256
        # tgt = torch.zeros_like(query_embed),用于進(jìn)行一次性解碼輸出
        output = tgt
        # 存儲每個解碼器輸出嗅骄,后面中繼監(jiān)督需要
        intermediate = []
        # 編碼每個解碼器
        for layer in self.layers:
            # 每個解碼器都需要輸入query_pos和pos
            # memory是最后一個編碼器輸出
            # 每個解碼器都接受output作為輸入,然后輸出新的output
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))
        if self.return_intermediate:
            return torch.stack(intermediate)  # 6個輸出都返回
        return output.unsqueeze(0)

內(nèi)部每個解碼器運行流程為:

def forward_post(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
    # query_pos首先是可學(xué)習(xí)的饼疙,其作用主要是在學(xué)習(xí)過程中提供目標(biāo)對象和全局圖像之間的關(guān)系
    # 這個相當(dāng)于全局注意力輸入溺森,是非常關(guān)鍵的
    # query_pos是解碼器特有
    q = k = tgt+query_pos
    # 第一個自注意力模塊
    tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                        key_padding_mask=tgt_key_padding_mask)[0]
    tgt = tgt + self.dropout1(tgt2)
    tgt = self.norm1(tgt)
    # memory是最后一個編碼器輸出,pos是和編碼器輸入中完全相同的sincos位置嵌入向量
    # 輸入?yún)?shù)是最核心細(xì)節(jié)窑眯,query是tgt+query_pos屏积,而key是memory+pos
    # v直接用memory
    tgt2 = self.multihead_attn(query=tgt+query_pos,
                            key=memory+pos,
                            value=memory, attn_mask=memory_mask,
                            key_padding_mask=memory_key_padding_mask)[0]
    tgt = tgt + self.dropout2(tgt2)
    tgt = self.norm2(tgt)
    tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
    tgt = tgt + self.dropout3(tgt2)
    tgt = self.norm3(tgt)
    return tgt

解碼器最終輸出shape是(6,b,100,256),6是指6個解碼器的輸出磅甩。

(4) 分類和回歸head

在解碼器輸出基礎(chǔ)上構(gòu)建分類和bbox回歸head即可輸出檢測結(jié)果炊林,比較簡單:

self.class_embed = nn.Linear(256, 92)
self.bbox_embed = MLP(256, 256, 4, 3)

# hs是(6,b,100,256),outputs_class輸出(6,b,100,92)卷要,表示6個分類分支
outputs_class = self.class_embed(hs)
# 輸出(6,b,100,4)渣聚,表示6個bbox坐標(biāo)回歸分支
outputs_coord = self.bbox_embed(hs).sigmoid() 
# 取最后一個解碼器輸出即可,分類輸出(b,100,92)僧叉,bbox回歸輸出(b,100,4)
out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
if self.aux_loss:
    # 除了最后一個輸出外奕枝,其余編碼器輸出都算輔助loss
    out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)

作者實驗發(fā)現(xiàn),如果對解碼器的每個輸出都加入輔助的分類和回歸loss彪标,可以提升性能,故作者除了對最后一個編碼層的輸出進(jìn)行Loss監(jiān)督外掷豺,還對其余5個編碼器采用了同樣的loss監(jiān)督捞烟,只不過權(quán)重設(shè)置低一點而已。

(5) 整體推理流程

基于transformer的detr算法当船,作者特意強調(diào)其突出優(yōu)點是部署代碼不超過50行题画,簡單至極。

[圖片上傳失敗...(image-387c-1675912126439)]

當(dāng)然上面是簡化代碼德频,和實際代碼不一樣苍息。具體流程是:

  • 將(b,3,800,1200)圖片輸入到resnet50中進(jìn)行特征提取,輸出shape=(b,1024,25,38)

  • 通過1x1卷積降維,變成(b,256,25,38)

  • 利用sincos函數(shù)計算位置編碼

  • 將圖像特征和位置編碼向量相加壹置,作為編碼器輸入竞思,輸出編碼后的向量,shape不變

  • 初始化全0的(100,b,256)的輸出嵌入向量钞护,結(jié)合位置編碼向量和query_embed盖喷,進(jìn)行解碼輸出,解碼器輸出shape為(6,b,100,256)

  • 將最后一個解碼器輸出輸入到分類和回歸head中难咕,得到100個無序集合

  • 對100個無序集合進(jìn)行后處理课梳,主要是提取前景類別和對應(yīng)的bbox坐標(biāo)距辆,乘上(800,1200)即可得到最終坐標(biāo),后處理代碼如下:

  • prob = F.softmax(out_logits, -1)
    scores, labels = prob[..., :-1].max(-1)
    # convert to [x0, y0, x1, y1] format
    boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
    # and from relative [0, 1] to absolute [0, height] coordinates
    img_h, img_w = target_sizes.unbind(1)
    scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
    boxes = boxes * scale_fct[:, None, :]
    results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
    

既然訓(xùn)練時候?qū)?個解碼器輸出都進(jìn)行了loss監(jiān)督,那么在測試時候也可以考慮將6個解碼器的分類和回歸分支輸出結(jié)果進(jìn)行nms合并暮刃,稍微有點性能提升跨算。

2.2.3 實驗分析

(1) 性能對比

[圖片上傳失敗...(image-9fcdc5-1675912126439)]

Faster RCNN-DC5是指的resnet的最后一個stage采用空洞率=stride設(shè)置代替stride,目的是在不進(jìn)行下采樣基礎(chǔ)上擴大感受野椭懊,輸出特征圖分辨率保持不變诸蚕。+號代表采用了額外的技巧提升性能例如giou、多尺度訓(xùn)練和9xepoch訓(xùn)練策略灾搏〈焱可以發(fā)現(xiàn)detr效果稍微好于faster rcnn各種版本,證明了視覺transformer的潛力狂窑。但是可以發(fā)現(xiàn)其小物體檢測能力遠(yuǎn)遠(yuǎn)低于faster rcnn媳板,這是一個比較大的弊端。

(2) 各個模塊分析

[圖片上傳失敗...(image-c5f1f8-1675912126439)]

編碼器數(shù)目越多效果越好泉哈,但是計算量也會增加很多蛉幸,作者最終選擇的是6。

[圖片上傳失敗...(image-438388-1675912126439)]

可以發(fā)現(xiàn)解碼器也是越多越好丛晦,還可以觀察到第一個解碼器輸出預(yù)測效果比較差奕纫,增加第二個解碼器后性能提升非常多。上圖中的NMS操作是指既然我們每個解碼層都可以輸入無序集合烫沙,那么將所有解碼器無序集合全部保留匹层,然后進(jìn)行nms得到最終輸出,可以發(fā)現(xiàn)性能稍微有提升锌蓄,特別是AP50升筏。

[圖片上傳失敗...(image-d44eec-1675912126439)]

作者對比了不同類型的位置編碼效果,因為query_embed(output pos)是必不可少的瘸爽,所以該列沒有進(jìn)行對比實驗您访,始終都有,最后一行效果最好剪决,所以作者采用的就是該方案灵汪,sine at attn表示每個注意力層都加入了sine位置編碼,相比僅僅在input增加位置編碼效果更好柑潦。

(3) 注意力可視化

前面說過transformer具有很好的可解釋性享言,故在訓(xùn)練完成后最終提出了幾種可視化形式

a) bbox輸出可視化

[圖片上傳失敗...(image-3e2195-1675912126439)]

這個就比較簡單了,直接對預(yù)測進(jìn)行后處理即可

probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
# 只保留概率大于0.9的bbox
keep = probas.max(-1).values > 0.9
# 還原到原圖渗鬼,然后繪制即可
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
plot_results(im, probas[keep], bboxes_scaled)

b) 解碼器自注意力層權(quán)重可視化

[圖片上傳失敗...(image-78456f-1675912126439)]

這里指的是最后一個解碼器內(nèi)部的第一個MultiheadAttention的自注意力權(quán)重担锤,其實就是QK相似性計算后然后softmax后的輸出可視化,具體是:

# multihead_attn注冊前向hook乍钻,output[1]指的就是softmax后輸出
model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(
    lambda self, input, output: dec_attn_weights.append(output[1])
)
# 假設(shè)輸入是(1,3,800,1066)
outputs = model(img)
# 那么dec_attn_weights是(1,100,850=800//32x1066//32)
# 這個就是QK相似性計算后然后softmax后的輸出肛循,即自注意力權(quán)重
dec_attn_weights = dec_attn_weights[0]

# 如果想看哪個bbox的權(quán)重铭腕,則輸入idx即可
dec_attn_weights[0, idx].view(800//32, 1066//32)

c) 編碼器自注意力層權(quán)重可視化

[圖片上傳失敗...(image-a3b93b-1675912126439)]

這個和解碼器操作完全相同。

model.transformer.encoder.layers[-1].self_attn.register_forward_hook(
    lambda self, input, output: enc_attn_weights.append(output[1])
)
outputs = model(img)
# 最后一個編碼器中的自注意力模塊權(quán)重輸出(b,h//32xw//32,h//32xw//32)多糠,其實就是qk計算然后softmax后的值即(1,25x34=850,850)
enc_attn_weights = enc_attn_weights[0]

# 變成(25, 34, 25, 34)
sattn = enc_attn_weights[0].reshape(shape + shape)

# 想看哪個特征點位置的注意力
idxs = [(200, 200), (280, 400), (200, 600), (440, 800), ]

for idx_o, ax in zip(idxs, axs):
    # 轉(zhuǎn)化到特征圖尺度
    idx = (idx_o[0] // fact, idx_o[1] // fact)
    # 直接sattn[..., idx[0], idx[1]]即可
    ax.imshow(sattn[..., idx[0], idx[1]], cmap='cividis', interpolation='nearest')

2.2.4 小結(jié)

detr整體做法非常簡單累舷,基本上沒有改動原始transformer結(jié)構(gòu),其顯著優(yōu)點是:不需要設(shè)置啥先驗夹孔,超參也比較少被盈,訓(xùn)練和部署代碼相比faster rcnn算法簡單很多,理解上也比較簡單搭伤。但是其缺點是:改了編解碼器的輸入只怎,在論文中也沒有解釋為啥要如此設(shè)計,而且很多操作都是實驗對比才確定的怜俐,比較迷身堡。算法層面訓(xùn)練epoch次數(shù)遠(yuǎn)遠(yuǎn)大于faster rcnn(300epoch),在同等epoch下明顯性能不如faster rcnn拍鲤,而且訓(xùn)練占用內(nèi)存也大于faster rcnn贴谎。

整體而言,雖然效果不錯季稳,但是整個做法還是顯得比較原始擅这,很多地方感覺是嘗試后得到的做法,沒有很好的解釋性景鼠,而且最大問題是訓(xùn)練epoch非常大和內(nèi)存占用比較多仲翎,對應(yīng)的就是收斂慢,期待后續(xù)作品铛漓。

3 總結(jié)

本文從transformer發(fā)展歷程入手溯香,并且深入介紹了transformer思想和實現(xiàn)細(xì)節(jié);最后結(jié)合計算機視覺領(lǐng)域的幾篇有典型代表文章進(jìn)行深入分析票渠,希望能夠給cv領(lǐng)域想快速理解transformer的初學(xué)者一點點幫助逐哈。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末芬迄,一起剝皮案震驚了整個濱河市问顷,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌禀梳,老刑警劉巖杜窄,帶你破解...
    沈念sama閱讀 216,372評論 6 498
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異算途,居然都是意外死亡塞耕,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,368評論 3 392
  • 文/潘曉璐 我一進(jìn)店門嘴瓤,熙熙樓的掌柜王于貴愁眉苦臉地迎上來扫外,“玉大人莉钙,你說我怎么就攤上這事∩秆瑁” “怎么了磁玉?”我有些...
    開封第一講書人閱讀 162,415評論 0 353
  • 文/不壞的土叔 我叫張陵,是天一觀的道長驾讲。 經(jīng)常有香客問我蚊伞,道長,這世上最難降的妖魔是什么吮铭? 我笑而不...
    開封第一講書人閱讀 58,157評論 1 292
  • 正文 為了忘掉前任时迫,我火速辦了婚禮,結(jié)果婚禮上谓晌,老公的妹妹穿的比我還像新娘掠拳。我一直安慰自己,他們只是感情好扎谎,可當(dāng)我...
    茶點故事閱讀 67,171評論 6 388
  • 文/花漫 我一把揭開白布碳想。 她就那樣靜靜地躺著,像睡著了一般毁靶。 火紅的嫁衣襯著肌膚如雪胧奔。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,125評論 1 297
  • 那天预吆,我揣著相機與錄音龙填,去河邊找鬼。 笑死拐叉,一個胖子當(dāng)著我的面吹牛岩遗,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播凤瘦,決...
    沈念sama閱讀 40,028評論 3 417
  • 文/蒼蘭香墨 我猛地睜開眼宿礁,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了蔬芥?” 一聲冷哼從身側(cè)響起梆靖,我...
    開封第一講書人閱讀 38,887評論 0 274
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎笔诵,沒想到半個月后返吻,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,310評論 1 310
  • 正文 獨居荒郊野嶺守林人離奇死亡乎婿,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,533評論 2 332
  • 正文 我和宋清朗相戀三年测僵,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片谢翎。...
    茶點故事閱讀 39,690評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡捍靠,死狀恐怖沐旨,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情榨婆,我是刑警寧澤希俩,帶...
    沈念sama閱讀 35,411評論 5 343
  • 正文 年R本政府宣布,位于F島的核電站纲辽,受9級特大地震影響颜武,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜拖吼,卻給世界環(huán)境...
    茶點故事閱讀 41,004評論 3 325
  • 文/蒙蒙 一鳞上、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧吊档,春花似錦篙议、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,659評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至香璃,卻和暖如春这难,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背葡秒。 一陣腳步聲響...
    開封第一講書人閱讀 32,812評論 1 268
  • 我被黑心中介騙來泰國打工姻乓, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人眯牧。 一個月前我還...
    沈念sama閱讀 47,693評論 2 368
  • 正文 我出身青樓蹋岩,卻偏偏與公主長得像,于是被迫代替她去往敵國和親学少。 傳聞我的和親對象是個殘疾皇子剪个,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,577評論 2 353

推薦閱讀更多精彩內(nèi)容