Introduction
The goal of this post is to deepen our understanding of the ViT model by implementing it in actual code. If you are not yet familiar with ViT, it is worth reading an introductory blog post on the overall architecture first.
This post is essentially a translation of Implementing Vision Transformer (ViT) in PyTorch, with some annotations of my own added. Readers who are more comfortable with the English version are encouraged to read the original article directly.
Overall Structure of the ViT Model
As is customary, let's start with the architecture diagram of the model:
The input image is first split into patches. These patches are fed through a fully connected layer to obtain embeddings, and a special cls token is prepended to the sequence of embeddings. Position encodings are then added to all of the embeddings. The result is fed into the Transformer Encoder, and the output feature of the cls token is passed to an MLP Head for classification. That is the overall flow. The structure of the code mirrors the structure of the model and can roughly be divided into the following parts: the patch embedding (including the cls token and the position embedding), the Transformer Encoder, and the classification head.
We will build up the ViT model step by step, in a bottom-up fashion.
Data
First, import the required dependencies:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
First, we open an image:
img = Image.open('./cat.jpg')
fig = plt.figure()
plt.imshow(img)
Result:
Next, we preprocess the image, which mainly consists of resizing it and converting it to a tensor:
# resize to imagenet size
transform = Compose([Resize((224, 224)), ToTensor()])
x = transform(img)
x = x.unsqueeze(0)  # add a batch dimension
x.shape
Patches Embeddings
According to the ViT architecture, the first step is to split the image into multiple patches and flatten them, as shown in the figure below.
We can use the einops library to keep this code concise:
patch_size = 16  # 16 pixels
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
patches.shape
The einops documentation covers the library in detail. Let's work out where the resulting shape [1, 196, 768] comes from. The original image tensor x has shape [1, 3, 224, 224]; splitting it into 16x16 patches gives 224 x 224 / (16 x 16) = 196 patches, and each flattened patch contains 16 x 16 x 3 = 768 values, hence [1, 196, 768]. Next, these patches need to go through a linear projection layer.
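A quick sanity check of that arithmetic, using the 224x224 input and 16x16 patch size from above (this snippet is my own addition, not part of the original post):
num_patches = (224 // 16) * (224 // 16)  # 14 * 14 = 196
patch_dim = 16 * 16 * 3                  # flattened patch length = 768
print(num_patches, patch_dim)            # 196 768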
To keep the code tidy, we wrap this projection into a PatchEmbedding class:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # break down the image into s1 x s2 patches and flatten them
            Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
            # note that the embedding size is also set to 768 here; it is configurable
            nn.Linear(patch_size * patch_size * in_channels, emb_size)
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x

PatchEmbedding()(x).shape
Alternatively, we can use a Conv2d layer to achieve the same result, which brings a performance gain. This is done by setting both the kernel size and the stride to patch_size. Intuitively, the convolution is then applied to each patch individually. So we can first apply the convolution and then flatten the result; the improved version looks like this:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            # flatten the patches produced by the convolution
            Rearrange('b e h w -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x

PatchEmbedding()(x).shape
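As a small side check (a sketch of my own, not from the original post, reusing the imports and the tensor x defined above), we can tie the weights of the two formulations and confirm they compute the same projection; the conv kernels just have to be reordered to match the '(s1 s2 c)' flattening used earlier:
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
linear = nn.Linear(16 * 16 * 3, 768)
with torch.no_grad():
    # conv.weight has shape [768, 3, 16, 16]; reorder it to the '(s1 s2 c)' layout
    linear.weight.copy_(rearrange(conv.weight, 'e c s1 s2 -> e (s1 s2 c)'))
    linear.bias.copy_(conv.bias)
out_linear = linear(rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=16, s2=16))
out_conv = rearrange(conv(x), 'b e h w -> b (h w) e')
print(torch.allclose(out_linear, out_conv, atol=1e-4))  # expected: True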
CLS Token
The next step is to add the cls token, and later the position embedding, to the projected patches. The cls token is just a randomly initialized torch Parameter; in the forward method it is repeated b times (where b is the batch size) and prepended to the patch sequence using torch.cat.
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.proj = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        # a learnable vector of size emb_size used as the cls_token
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape  # keep the batch size around
        x = self.proj(x)  # apply the convolutional projection
        # expand the cls_token b times
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        print(cls_tokens.shape)
        print(x.shape)
        # prepend the cls token to the input on dim 1
        x = torch.cat([cls_tokens, x], dim=1)
        return x

PatchEmbedding()(x).shape
Position Embedding
So far the model knows nothing about the original spatial position of the patches within the image. We need to pass that spatial information to the model. This can be done in many ways; in ViT, we simply let the model learn it. The position embedding is just a tensor of shape [N_PATCHES + 1 (cls token), EMBED_SIZE] that is added directly to the projected patches.
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # position embedding: (img_size // patch_size)**2 + 1 (cls token) position vectors in total
        self.positions = nn.Parameter(torch.randn((img_size // patch_size)**2 + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        # add the position embedding
        print(x.shape, self.positions.shape)
        x += self.positions
        return x

PatchEmbedding()(x).shape
We first define the position embedding as a learnable parameter, and then in the forward function we simply add it to the projected patch sequence; broadcasting takes care of the batch dimension.
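As a small standalone illustration (my own example) of the broadcast that happens in x += self.positions:
tokens = torch.randn(1, 197, 768)    # [batch, n_patches + 1, emb_size]
positions = torch.randn(197, 768)    # [n_patches + 1, emb_size]
print((tokens + positions).shape)    # torch.Size([1, 197, 768])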
Transformer
Now let's implement the Transformer part. ViT only uses the Encoder of the Transformer; its overall architecture is shown below. We will implement each component in turn.
Attention
The attention block takes three inputs: the queries, keys, and values matrices. The queries and keys are first used to compute an attention matrix, which is then multiplied with the values to produce the corresponding output. In the figure below, multi-head attention means the input is split into n parts and the computation is distributed across n heads.
We can either use PyTorch's nn.MultiheadAttention module or implement our own. For completeness, here is the full MultiHeadAttention code:
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 512, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values into num_heads
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        # dot product over the last axis -> batch, num_heads, query_len, key_len
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)
        # scale before the softmax, as in the attention formula
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # weighted sum over the values (third axis)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
Let's walk through the code step by step. We define four fully connected layers: one each for the queries, keys, and values, plus a final projection layer. For a much more in-depth explanation, read The Illustrated Transformer. The main idea is to use the product between the queries and the keys to measure how well each patch in the input sequence matches every other patch; these matching scores are then used to scale the corresponding values, which are summed up to form the output.
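To make the idea concrete, here is a tiny single-head toy example (my own sketch, not from the original post) with 3 "patches" of 4-dimensional features:
q = torch.randn(1, 3, 4)
k = torch.randn(1, 3, 4)
v = torch.randn(1, 3, 4)
scores = torch.einsum('bqd, bkd -> bqk', q, k) / (4 ** 0.5)  # patch-to-patch matching scores
weights = F.softmax(scores, dim=-1)                          # each row sums to 1
out = torch.einsum('bqk, bkd -> bqd', weights, v)            # weighted sum of the values
print(weights.shape, out.shape)  # torch.Size([1, 3, 3]) torch.Size([1, 3, 4])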
The forward method takes the output of the previous layer as input and applies the three linear layers to obtain the queries, keys, and values. Because we are implementing multi-head attention, the result then has to be rearranged into per-head form, which is done with the rearrange function from the einops library.
The queries, keys, and values all have the same shape, and for simplicity they are all computed from the same input x.
queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
After the rearrange, the queries, keys, and values have shape [BATCH, HEADS, SEQUENCE_LEN, EMBEDDING_SIZE // HEADS]. At the end, the heads are concatenated back together to form the final output.
Note: to speed up the computation, we can produce the queries, keys, and values in one shot with a single matrix. The improved code is as follows:
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values into one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values into num_heads
        print("1qkv's shape: ", self.qkv(x).shape)
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        print("2qkv's shape: ", qkv.shape)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        print("queries's shape: ", queries.shape)
        print("keys's shape: ", keys.shape)
        print("values's shape: ", values.shape)
        # dot product over the last axis -> batch, num_heads, query_len, key_len
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        print("energy's shape: ", energy.shape)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)
        # scale before the softmax, as in the attention formula
        scaling = self.emb_size ** (1/2)
        print("scaling: ", scaling)
        att = F.softmax(energy / scaling, dim=-1)
        print("att1's shape: ", att.shape)
        att = self.att_drop(att)
        print("att2's shape: ", att.shape)
        # weighted sum over the values (third axis)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        print("out1's shape: ", out.shape)
        out = rearrange(out, "b h n d -> b n (h d)")
        print("out2's shape: ", out.shape)
        out = self.projection(out)
        print("out3's shape: ", out.shape)
        return out
patches_embedded = PatchEmbedding()(x)
print("patches_embedding's shape: ", patches_embedded.shape)
MultiHeadAttention()(patches_embedded).shape
Residuals
The Transformer block also contains residual connections, as shown in the figure below. We can wrap the residual connection in a small class of its own:
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x
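A quick usage check (my own example): wrapping a sub-module in ResidualAdd leaves the tensor shape unchanged:
block = ResidualAdd(nn.Sequential(nn.LayerNorm(768), nn.Linear(768, 768)))
print(block(torch.randn(1, 197, 768)).shape)  # torch.Size([1, 197, 768])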
Next, the output of the attention block passes through a LayerNorm followed by a feed-forward block, a fully connected network that upsamples its input by an expansion factor. A ResNet-style residual connection is used here as well, as shown below:
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )
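A quick shape check (again my own example):
print(FeedForwardBlock(768)(torch.randn(1, 197, 768)).shape)  # torch.Size([1, 197, 768])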
The author remarks that, for some reason he does not understand, people rarely subclass nn.Sequential directly, even though doing so avoids having to write a forward method. (Translator's note: indeed, a handy trick to pick up.)
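For comparison, this is what the same block looks like written the conventional way with an explicit forward method (a sketch of my own, just to show what subclassing nn.Sequential saves):
class FeedForwardBlockExplicit(nn.Module):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)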
Finally, we can put together a complete Transformer Encoder block. Thanks to the ResidualAdd class defined above, it can be written quite elegantly:
class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 **kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            ))
        )
Let's test it:
patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape
Transformer
ViT only uses the Encoder part of the original Transformer. The Encoder consists of L such blocks; we expose the number of blocks through the depth parameter:
class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])
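As a quick check (my own example; a small depth keeps it fast on CPU):
encoder = TransformerEncoder(depth=2)
print(encoder(PatchEmbedding()(x)).shape)  # expected: torch.Size([1, 197, 768])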
Pretty simple, isn't it? The last piece we need is the classification head: it averages over the whole token sequence, applies a LayerNorm, and maps to the class scores with a single linear layer:
class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes))
Putting the PatchEmbedding, TransformerEncoder, and ClassificationHead defined above together, we obtain the final ViT model:
class ViT(nn.Sequential):
    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 1000,
                 **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )
We can use the summary function from torchsummary to inspect the model and count its parameters.
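A minimal way to do that (a sketch of my own; device='cpu' is assumed so it runs without a GPU):
model = ViT()
print(model(x).shape)  # a single forward pass: torch.Size([1, 1000])
summary(model, (3, 224, 224), device='cpu')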
Compared with other ViT implementations, this parameter count is roughly the same.
The code repository for the original article is at https://github.com/FrancescoSaverioZuppichini/ViT.