Visual Transformer (ViT) 代碼實現(xiàn) PyTorch版本

簡介

本文的目的是通過實際代碼編寫來實現(xiàn)ViT模型,進一步加對ViT模型的理解恨锚,如果還不知道ViT模型的話宇驾,可以先看下博客了解一下ViT的整體結(jié)構(gòu)。
本文整體是對Implementing Vision Transformer (ViT) in PyTorch 的翻譯猴伶,但是也加上了一些自己的注解课舍。如果讀者更習慣看英文版,建議直接去看原文他挎。

ViT模型整體結(jié)構(gòu)

按照慣例筝尾,先放上模型的架構(gòu)圖,如下:

ViT模型
輸入圖片被劃分為一個個16x16的小塊办桨,也叫做patch筹淫。接著這些patch被送入一個全連接層得到embeddings,然后在embeddings前前加上一個特殊的cls token呢撞。然后給所有的embedding加上位置信息編碼positional encoding损姜。接著這個整體被送入Transformer Encoder饰剥,然后取cls token的輸出特征送入MLP Head去做分類,總體流程就是這樣摧阅。
代碼的整體結(jié)構(gòu)跟ViT模型的結(jié)構(gòu)類似汰蓉,大體可以分為以下幾個部分:

我們將以自底向上的方式來逐步實現(xiàn)ViT模型。

Data

首先需要導(dǎo)入相關(guān)的依賴庫棒卷,如下:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

首先我們需要打開一張圖片顾孽,如下:

img = Image.open('./cat.jpg')
fig = plt.figure()
plt.imshow(img)

結(jié)果:

Cat

接著我們需要對圖片進行預(yù)處理,主要是包含resize比规、向量化等操作若厚,代碼如下:

# resize to imagenet size 
transform = Compose([Resize((224, 224)), ToTensor()])
x = transform(img)
x = x.unsqueeze(0) # 主要是為了添加batch這個維度
x.shape
輸出

Pathches Embeddings

根據(jù)ViT模型結(jié)構(gòu),第一步是需要將圖片劃分為多個Patches苞俘,并且將其鋪平盹沈。如下圖:

原文的描述如下:
看起來很復(fù)雜,但是我們可以使用einops庫來簡化代碼編寫吃谣,如下:

patch_size = 16 # 16 pixels
pathes = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
pathes.shape

結(jié)果

關(guān)于einops庫的使用可以參考doc乞封。這里解釋一下這個結(jié)果[1,196,768]是怎么來的。我們知道原始圖片向量x的大小為[1,3,224,224]岗憋,當我們使用16x16大小的patch對其進行分割的時候肃晚,一共可以劃分為224x224/16/16 = 196個patches,其次每個patch大小為16x16x3=768仔戈,故大小為[1,196,768]关串。
接著我們需要將這些patches通過一個線性映射層。
這里可以定義一個名為PatchEmbedding的類來使代碼更加整潔:

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # break-down the image in s1 x s2 patches and flat them
            Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
            # 注意這里的隱層大小設(shè)置的也是768监徘,可以配置
            nn.Linear(patch_size * patch_size * in_channels, emb_size)
        )
                
    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x
    
PatchEmbedding()(x).shape

實際查看原作者的代碼晋修,他并沒有使用線性映射層來做這件事,出于效率考慮凰盔,作者使用了Conv2d層來實現(xiàn)相同的功能墓卦。這是通過設(shè)置卷積核大小和步長均為patch_size來實現(xiàn)的。直觀上來看户敬,卷積操作是分別應(yīng)用在每個patch上的落剪。所以,我們可以先應(yīng)用一個卷積層尿庐,然后再對結(jié)果進行鋪平忠怖,改進如下:

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            # 將卷積操作后的patch鋪平
            Rearrange('b e h w -> b (h w) e'),
        )
                
    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x
    
PatchEmbedding()(x).shape

CLS Token

下一步是對映射后的patches添加上cls token以及位置編碼信息。cls token是一個隨機初始化的torch Parameter對象抄瑟,在forward方法中它需要被拷貝b次(b是batch的數(shù)量)凡泣,然后使用torch.cat函數(shù)添加到patch前面。

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.proj = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        # 生成一個維度為emb_size的向量當做cls_token
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape # 單獨先將batch緩存起來
        x = self.proj(x) # 進行卷積操作
        # 將cls_token 擴展b次
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        print(cls_tokens.shape)
        print(x.shape)
        # prepend the cls token to the input on the dim 1
        x = torch.cat([cls_tokens, x], dim=1)
        return x
    
PatchEmbedding()(x).shape

Position Embedding

目前為止,模型還對patches的在圖像中的原始位置一無所知问麸。我們需要傳遞給模型這些空間上的信息往衷〕瑁可以有很多種方法來實現(xiàn)這個功能严卖,在ViT中,我們讓模型自己去學(xué)習這個布轿。位置編碼信息只是一個形狀為[N_PATCHES+1(token)m EMBED_SIZE]的張量哮笆,它直接加到映射后的patches上。

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))
        # 位置編碼信息汰扭,一共有(img_size // patch_size)**2 + 1(cls token)個位置向量
        self.positions = nn.Parameter(torch.randn((img_size // patch_size)**2 + 1, emb_size))
        
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        # add position embedding
        print(x.shape, self.positions.shape)
        x += self.positions
        return x
    
PatchEmbedding()(x).shape

我們首先定義位置embedding向量稠肘,然后在forward函數(shù)中將其加到線性映射后的patches向量上去。

Transformer

現(xiàn)在我們來實現(xiàn)Transformer模塊萝毛。ViT模型中只使用了Transformer的Encoder部分项阴,其整體架構(gòu)如下:
Encoder架構(gòu)

接下來依次實現(xiàn)。

Attention

attention部分有三個輸入笆包,分別是queries,keys,values矩陣环揽,首先使用queries,keys矩陣去計算注意力矩陣,然后與values矩陣相乘庵佣,得到對應(yīng)的輸出歉胶。在下圖中,multi-head注意力機制表示將輸入劃分成n份巴粪,然后將計算分到n個head上去通今。

Attentionn
我們可以使用pytorch的nn.MultiAttention模塊或者自己實現(xiàn)一個,為了完整起見肛根,我將完整的MultiAttention代碼貼出來:

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 512, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)
            
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out

接下來辫塌,一步一步分析下上述代碼。我們定義了4個全連接層派哲,分別用于queries,keys,values,以及最后的線性映射層臼氨。關(guān)于這塊更加詳細的內(nèi)容可以閱讀The Illustrated Transformer 。主要的思想是使用querieskeys之間的乘積來計算輸入序列中的每一個patch與剩余patch之間的匹配程度狮辽。然后使用這個匹配程度(數(shù)值)去對對應(yīng)的values做縮放一也,再累加起來作為Encoder的輸出。
forward方法將上一層的輸出作為輸入喉脖,使用三個線性映射層分別得到queries,keys,values椰苟。因為我們要實現(xiàn)multi-head注意力機制,我們需要將輸出重排成多個head的形式树叽。這一步是使用einops庫的rearrange函數(shù)來完成的舆蝴。
Queries,keys,values的形狀是一樣的,為了簡便起見,它們都是基于同一個輸入x洁仗。

queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.n_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.n_heads)
values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.n_heads)

經(jīng)過rearrange操作之后层皱,Queries,keys,values的形狀大小為[BATCH, HEADS, SEQUENCE_LEN, EMBEDDING_SIZE]. 然后我們將多個head拼接在一起就得到了最終的輸出。

注意: 為了加快計算赠潦,我們可以使用單個矩陣一次性計算出Queries,keys,values叫胖。

改進后的代碼如下:

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        print("1qkv's shape: ", self.qkv(x).shape)
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        print("2qkv's shape: ", qkv.shape)
        
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        print("queries's shape: ", queries.shape)
        print("keys's shape: ", keys.shape)
        print("values's shape: ", values.shape)
        
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        print("energy's shape: ", energy.shape)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)
        
        scaling = self.emb_size ** (1/2)
        print("scaling: ", scaling)
        att = F.softmax(energy, dim=-1) / scaling
        print("att1' shape: ", att.shape)
        att = self.att_drop(att)
        print("att2' shape: ", att.shape)
        
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        print("out1's shape: ", out.shape)
        out = rearrange(out, "b h n d -> b n (h d)")
        print("out2's shape: ", out.shape)
        out = self.projection(out)
        print("out3's shape: ", out.shape)
        return out
    
patches_embedded = PatchEmbedding()(x)
print("patches_embedding's shape: ", patches_embedded.shape)
MultiHeadAttention()(patches_embedded).shape
output

Residuals

Transformer模塊也包含了殘差連接,如下圖:
Skip Connection

我們可以單獨封裝一個處理殘差連接的類如下:

class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        
    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x

接著attention層的輸出首先通過BN層她奥,緊跟其后一個全連接層瓮增,全連接層中采用了一個expansion因子來對輸入進行上采樣。同樣這里也采用了類似resnet的殘差連接的方式哩俭。如下圖:

FFN
代碼如下:

class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

作者說绷跑,不知道為什么,很少看見人們直接繼承nn.Sequential類凡资,這樣就可以避免重寫forward方法了砸捏。
譯者著,確實~又學(xué)到了一招隙赁。

最終垦藏,我們可以創(chuàng)建一個完整的Transformer Encoder 塊了。

Encoder
利用我們之前定義好的ResidualAdd類鸳谜,我們可以很優(yōu)雅地定義出Transformer Encoder Block膝藕,如下:

class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))

我們來測試一下:

patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape
output

Transformer

在ViT中只使用了原始Transformer中的Encoder部分。Encoder一共包含L個block咐扭,我們使用參數(shù)depth來指定芭挽,代碼如下:

class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])

很簡單不是嘛?

ViT的最后一層就是一個簡單的全連接層蝗肪,輸出分類的概率值袜爪。它對整個序列執(zhí)行一個mean操作。
MLP Head
class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size), 
            nn.Linear(emb_size, n_classes))

我們將之前定義好的PatchEMbedding, TransformerEncoder,ClassificationHead整合起來薛闪,搭建出最終的ViT代碼模型如下:

class ViT(nn.Sequential):
    def __init__(self,     
                in_channels: int = 3,
                patch_size: int = 16,
                emb_size: int = 768,
                img_size: int = 224,
                depth: int = 12,
                n_classes: int = 1000,
                **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )

我們可以使用torchsummary函數(shù)來計算參數(shù)量辛馆,輸出如下:

參數(shù)量

與其他ViT實現(xiàn)代碼相比,這個參數(shù)量是差不多的豁延。
原文的代碼倉庫在https://github.com/FrancescoSaverioZuppichini/ViT昙篙。

參考

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市诱咏,隨后出現(xiàn)的幾起案子苔可,更是在濱河造成了極大的恐慌,老刑警劉巖袋狞,帶你破解...
    沈念sama閱讀 218,682評論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件焚辅,死亡現(xiàn)場離奇詭異映屋,居然都是意外死亡,警方通過查閱死者的電腦和手機同蜻,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,277評論 3 395
  • 文/潘曉璐 我一進店門棚点,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人湾蔓,你說我怎么就攤上這事瘫析。” “怎么了卵蛉?”我有些...
    開封第一講書人閱讀 165,083評論 0 355
  • 文/不壞的土叔 我叫張陵颁股,是天一觀的道長么库。 經(jīng)常有香客問我傻丝,道長,這世上最難降的妖魔是什么诉儒? 我笑而不...
    開封第一講書人閱讀 58,763評論 1 295
  • 正文 為了忘掉前任葡缰,我火速辦了婚禮,結(jié)果婚禮上忱反,老公的妹妹穿的比我還像新娘泛释。我一直安慰自己,他們只是感情好温算,可當我...
    茶點故事閱讀 67,785評論 6 392
  • 文/花漫 我一把揭開白布怜校。 她就那樣靜靜地躺著,像睡著了一般注竿。 火紅的嫁衣襯著肌膚如雪茄茁。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,624評論 1 305
  • 那天巩割,我揣著相機與錄音裙顽,去河邊找鬼。 笑死宣谈,一個胖子當著我的面吹牛愈犹,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播闻丑,決...
    沈念sama閱讀 40,358評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼漩怎,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了嗦嗡?” 一聲冷哼從身側(cè)響起勋锤,我...
    開封第一講書人閱讀 39,261評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎酸钦,沒想到半個月后怪得,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體咱枉,經(jīng)...
    沈念sama閱讀 45,722評論 1 315
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,900評論 3 336
  • 正文 我和宋清朗相戀三年徒恋,在試婚紗的時候發(fā)現(xiàn)自己被綠了蚕断。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,030評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡入挣,死狀恐怖亿乳,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情径筏,我是刑警寧澤葛假,帶...
    沈念sama閱讀 35,737評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站滋恬,受9級特大地震影響聊训,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜恢氯,卻給世界環(huán)境...
    茶點故事閱讀 41,360評論 3 330
  • 文/蒙蒙 一带斑、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧勋拟,春花似錦勋磕、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,941評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至啸胧,卻和暖如春赶站,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背吓揪。 一陣腳步聲響...
    開封第一講書人閱讀 33,057評論 1 270
  • 我被黑心中介騙來泰國打工亲怠, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人柠辞。 一個月前我還...
    沈念sama閱讀 48,237評論 3 371
  • 正文 我出身青樓团秽,卻偏偏與公主長得像,于是被迫代替她去往敵國和親叭首。 傳聞我的和親對象是個殘疾皇子习勤,可洞房花燭夜當晚...
    茶點故事閱讀 44,976評論 2 355

推薦閱讀更多精彩內(nèi)容