Vision Transformer(ViT)簡介
近些年千劈,隨著基于自注意(Self-Attention)結(jié)構(gòu)的模型的發(fā)展抬闷,特別是Transformer模型的提出魔市,極大的促進(jìn)了自然語言處理模型的發(fā)展督禽。由于Transformers的計算效率和可擴(kuò)展性颠黎,它已經(jīng)能夠訓(xùn)練具有超過100B參數(shù)的空前規(guī)模的模型另锋。
ViT則是自然語言處理和計算機(jī)視覺兩個領(lǐng)域的融合結(jié)晶。在不依賴卷積操作的情況下狭归,依然可以在圖像分類任務(wù)上達(dá)到很好的效果夭坪。
模型結(jié)構(gòu)
ViT模型的主體結(jié)構(gòu)是基于Transformer模型的Encoder部分(部分結(jié)構(gòu)順序有調(diào)整,如:normalization的位置與標(biāo)準(zhǔn)Transformer不同)过椎,其結(jié)構(gòu)圖如下:
[圖片上傳失敗...(image-54c387-1652419666222)]
模型特點(diǎn)
ViT模型是應(yīng)用于圖像分類領(lǐng)域室梅。因此,其模型結(jié)構(gòu)相較于傳統(tǒng)的Transformer有以下幾個特點(diǎn):
- 數(shù)據(jù)集的原圖像被劃分為多個patch后疚宇,將二維patch(不考慮channel)轉(zhuǎn)換為一維向量亡鼠,再加上類別向量與位置向量作為模型輸入。
- 模型主體的Block基于Transformer的Encoder部分敷待,但是調(diào)整了normaliztion的位置间涵,其中,最主要的結(jié)構(gòu)依然是Multi-head Attention結(jié)構(gòu)榜揖。
- 模型在Blocks堆疊后接全連接層接受類別向量輸出用于分類勾哩。通常情況下股耽,我們將最后的全連接層稱為Head,Transformer Encoder部分為backbone钳幅。
下面將通過代碼實(shí)例來詳細(xì)解釋基于ViT實(shí)現(xiàn)ImageNet分類任務(wù)物蝙。
環(huán)境準(zhǔn)備與數(shù)據(jù)讀取
本案例基于MindSpore-GPU版本,在單GPU卡上完成模型訓(xùn)練和驗(yàn)證敢艰。
首先導(dǎo)入相關(guān)模塊诬乞,配置相關(guān)超參數(shù)并讀取數(shù)據(jù)集,該部分代碼在Vision套件中都有API可直接調(diào)用钠导,詳情可以參考以下鏈接:<u style="text-decoration: none; border-bottom: 1px dashed rgb(128, 128, 128);">https://gitee.com/mindspore/vision</u> 震嫉。
可通過:<u style="text-decoration: none; border-bottom: 1px dashed rgb(128, 128, 128);">http://image-net.org/</u> 進(jìn)行數(shù)據(jù)集下載。
加載前先定義數(shù)據(jù)集路徑牡属,請確保你的數(shù)據(jù)集路徑如以下結(jié)構(gòu)票堵。
.ImageNet/
├── ILSVRC2012_devkit_t12.tar.gz
├── train/
└── val/
from mindspore import context
from mindvision.classification.dataset import ImageNet
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
data_url = './ImageNet/'
resize = 224
batch_size = 16
dataset_train = ImageNet(data_url,
split="train",
shuffle=True,
resize=resize,
batch_size=batch_size,
repeat_num=1,
num_parallel_workers=1).run()
模型解析
下面將通過代碼來細(xì)致剖析ViT模型的內(nèi)部結(jié)構(gòu)。
Transformer基本原理
Transformer模型源于2017年的一篇文章[2]逮栅。在這篇文章中提出的基于Attention機(jī)制的編碼器-解碼器型結(jié)構(gòu)在自然語言處理領(lǐng)域獲得了巨大的成功悴势。模型結(jié)構(gòu)如下圖所示:
[圖片上傳失敗...(image-251c0-1652419666222)]
其主要結(jié)構(gòu)為多個Encoder和Decoder模塊所組成,其中Encoder和Decoder的詳細(xì)結(jié)構(gòu)如下圖所示:
[圖片上傳失敗...(image-bf835b-1652419666222)]
Encoder與Decoder由許多結(jié)構(gòu)組成措伐,如:多頭注意力(Multi-Head Attention)層特纤,F(xiàn)eed Forward層,Normaliztion層侥加,甚至殘差連接(Residual Connection捧存,圖中的“add”)。不過担败,其中最重要的結(jié)構(gòu)是多頭注意力(Multi-Head Attention)結(jié)構(gòu)昔穴,該結(jié)構(gòu)基于自注意力(Self-Attention)機(jī)制,是多個Self-Attention的并行組成提前。
所以吗货,理解了Self-Attention就抓住了Transformer的核心。
Attention模塊
以下是Self-Attention的解釋岖研,其核心內(nèi)容是為輸入向量的每個單詞學(xué)習(xí)一個權(quán)重卿操。通過給定一個任務(wù)相關(guān)的查詢向量Query向量,計算Query和各個Key的相似性或者相關(guān)性得到注意力分布孙援,即得到每個Key對應(yīng)Value的權(quán)重系數(shù)害淤,然后對Value進(jìn)行加權(quán)求和得到最終的Attention數(shù)值。
在Self-Attention中:
- 最初的輸入向量首先會經(jīng)過Embedding層映射成Q(Query)拓售,K(Key)窥摄,V(Value)三個向量,由于是并行操作础淤,所以代碼中是映射成為dim x 3的向量然后進(jìn)行分割崭放,換言之哨苛,如果你的輸入向量為一個向量序列( 1x1 1x1, 2x2 2x2币砂, 3x3 3x3)建峭,其中的 1x1 1x1, 2x2 2x2决摧, 3x3 3x3都是一維向量亿蒸,那么每一個一維向量都會經(jīng)過Embedding層映射出Q,K掌桩,V三個向量边锁,只是Embedding矩陣不同,矩陣參數(shù)也是通過學(xué)習(xí)得到的波岛。這里大家可以認(rèn)為茅坛,Q,K则拷,V三個矩陣是發(fā)現(xiàn)向量之間關(guān)聯(lián)信息的一種手段贡蓖,需要經(jīng)過學(xué)習(xí)得到,至于為什么是Q隔躲,K摩梧,V三個物延,主要是因?yàn)樾枰獌蓚€向量點(diǎn)乘以獲得權(quán)重宣旱,又需要另一個向量來承載權(quán)重向加的結(jié)果,所以叛薯,最少需要3個矩陣浑吟,也是論文作者經(jīng)過實(shí)驗(yàn)得出的結(jié)論。
<picture><source media="(max-width: 320px)" srcset="https://www.zhihu.com/equation?tex=%5Cbegin%7Bcases%7D+q_i+%3D+W_q+%5Ccdot+x_i+%26+%5C%5C+k_i+%3D+W_k+%5Ccdot+x_i%2C%5Chspace%7B1em%7D+%26i+%3D+1%2C2%2C3+%5Cldots+%5C%5C+v_i+%3D+W_v+%5Ccdot+x_i+%26+%5Cend%7Bcases%7D+%5Ctag%7B1%7D&width=40"><source media="(max-width: 400px)" srcset="https://www.zhihu.com/equation?tex=%5Cbegin%7Bcases%7D+q_i+%3D+W_q+%5Ccdot+x_i+%26+%5C%5C+k_i+%3D+W_k+%5Ccdot+x_i%2C%5Chspace%7B1em%7D+%26i+%3D+1%2C2%2C3+%5Cldots+%5C%5C+v_i+%3D+W_v+%5Ccdot+x_i+%26+%5Cend%7Bcases%7D+%5Ctag%7B1%7D&width=50"><source media="(max-width: 480px)" srcset="https://www.zhihu.com/equation?tex=%5Cbegin%7Bcases%7D+q_i+%3D+W_q+%5Ccdot+x_i+%26+%5C%5C+k_i+%3D+W_k+%5Ccdot+x_i%2C%5Chspace%7B1em%7D+%26i+%3D+1%2C2%2C3+%5Cldots+%5C%5C+v_i+%3D+W_v+%5Ccdot+x_i+%26+%5Cend%7Bcases%7D+%5Ctag%7B1%7D&width=60">[圖片上傳失敗...(image-81b91f-1652419666221)]</picture>
[圖片上傳失敗...(image-8f818f-1652419666222)]
2. 自注意力機(jī)制的自注意主要體現(xiàn)在它的Q耗溜,K组力,V都來源于其自身,也就是該過程是在提取輸入的不同順序的向量的聯(lián)系與特征抖拴,最終通過不同順序向量之間的聯(lián)系緊密性(Q與K乘積經(jīng)過softmax的結(jié)果)來表現(xiàn)出來燎字。Q,K阿宅,V得到后就需要獲取向量間權(quán)重候衍,需要對Q和K進(jìn)行點(diǎn)乘并除以維度的平方根 ??√d ??√d,對所有向量的結(jié)果進(jìn)行Softmax處理洒放,通過公式(2)的操作蛉鹿,我們獲得了向量之間的關(guān)系權(quán)重。
<picture><source media="(max-width: 320px)" srcset="https://www.zhihu.com/equation?tex=%5Cbegin%7Bcases%7D+a_%7B1%2C1%7D+%3D+q_1+%5Ccdot+k_1+%2F+%5Csqrt+d+%5C%5C+a_%7B1%2C2%7D+%3D+q_1+%5Ccdot+k_2+%2F+%5Csqrt+d+%5C%5C+a_%7B1%2C3%7D+%3D+q_1+%5Ccdot+k_3+%2F+%5Csqrt+d+%5Cend%7Bcases%7D+%5Ctag%7B2%7D&width=40"><source media="(max-width: 400px)" srcset="https://www.zhihu.com/equation?tex=%5Cbegin%7Bcases%7D+a_%7B1%2C1%7D+%3D+q_1+%5Ccdot+k_1+%2F+%5Csqrt+d+%5C%5C+a_%7B1%2C2%7D+%3D+q_1+%5Ccdot+k_2+%2F+%5Csqrt+d+%5C%5C+a_%7B1%2C3%7D+%3D+q_1+%5Ccdot+k_3+%2F+%5Csqrt+d+%5Cend%7Bcases%7D+%5Ctag%7B2%7D&width=50"><source media="(max-width: 480px)" srcset="https://www.zhihu.com/equation?tex=%5Cbegin%7Bcases%7D+a_%7B1%2C1%7D+%3D+q_1+%5Ccdot+k_1+%2F+%5Csqrt+d+%5C%5C+a_%7B1%2C2%7D+%3D+q_1+%5Ccdot+k_2+%2F+%5Csqrt+d+%5C%5C+a_%7B1%2C3%7D+%3D+q_1+%5Ccdot+k_3+%2F+%5Csqrt+d+%5Cend%7Bcases%7D+%5Ctag%7B2%7D&width=60">[圖片上傳失敗...(image-22fb5f-1652419666220)]</picture>
[圖片上傳失敗...(image-8c029e-1652419666222)]
[圖片上傳失敗...(image-96d4ef-1652419666222)]
[圖片上傳失敗...(image-7101cc-1652419666222)]
3.其最終輸出則是通過V這個映射后的向量與QK經(jīng)過Softmax結(jié)果進(jìn)行weight sum獲得往湿,這個過程可以理解為在全局上進(jìn)行自注意表示妖异。每一組QKV最后都有一個V輸出惋戏,這是Self-Attention得到的最終結(jié)果,是當(dāng)前向量在結(jié)合了它與其他向量關(guān)聯(lián)權(quán)重后得到的結(jié)果他膳。
[圖片上傳失敗...(image-300ae1-1652419666222)]
通過下圖可以整體把握Self-Attention的全部過程响逢。
[圖片上傳失敗...(image-3bd17c-1652419666222)]
多頭注意力機(jī)制就是將原本self-Attention處理的向量分割為多個Head進(jìn)行處理,這一點(diǎn)也可以從代碼中體現(xiàn)棕孙,這也是attention結(jié)構(gòu)可以進(jìn)行并行加速的一個方面龄句。
總結(jié)來說,多頭注意力機(jī)制在保持參數(shù)總量不變的情況下散罕,將同樣的query, key和value映射到原來的高維空間(Q,K,V)的不同子空間(Q_0,K_0,V_0)中進(jìn)行自注意力的計算分歇,最后再合并不同子空間中的注意力信息。
所以欧漱,對于同一個輸入向量职抡,多個注意力機(jī)制可以同時對其進(jìn)行處理,即利用并行計算加速處理過程误甚,又在處理的時候更充分的分析和利用了向量特征缚甩。下圖展示了多頭注意力機(jī)制,其并行能力的主要體現(xiàn)在下圖中的和
是同一個向量進(jìn)行分割獲得的窑邦。
[圖片上傳失敗...(image-ae8751-1652419666222)]
以下是vision套件中的Multi-Head Attention代碼擅威,結(jié)合上文的解釋,代碼清晰的展現(xiàn)了這一過程冈钦。
import mindspore.nn as nn
class Attention(nn.Cell):
def __init__(self,
dim: int,
num_heads: int = 8,
keep_prob: float = 1.0,
attention_keep_prob: float = 1.0):
super(Attention, self).__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = Tensor(head_dim ** -0.5)
self.qkv = nn.Dense(dim, dim * 3)
self.attn_drop = nn.Dropout(attention_keep_prob)
self.out = nn.Dense(dim, dim)
self.out_drop = nn.Dropout(keep_prob)
self.mul = P.Mul()
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.unstack = P.Unstack(axis=0)
self.attn_matmul_v = P.BatchMatMul()
self.q_matmul_k = P.BatchMatMul(transpose_b=True)
self.softmax = nn.Softmax(axis=-1)
def construct(self, x):
"""Attention construct."""
b, n, c = x.shape
# 最初的輸入向量首先會經(jīng)過Embedding層映射成Q(Query)郊丛,K(Key),V(Value)三個向量
# 由于是并行操作瞧筛,所以代碼中是映射成為dim*3的向量然后進(jìn)行分割
qkv = self.qkv(x)
#多頭注意力機(jī)制就是將原本self-Attention處理的向量分割為多個Head進(jìn)行處理
qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
qkv = self.transpose(qkv, (2, 0, 3, 1, 4))
q, k, v = self.unstack(qkv)
# 自注意力機(jī)制的自注意主要體現(xiàn)在它的Q,K较幌,V都來源于其自身
# 也就是該過程是在提取輸入的不同順序的向量的聯(lián)系與特征
# 最終通過不同順序向量之間的聯(lián)系緊密性(Q與K乘積經(jīng)過softmax的結(jié)果)來表現(xiàn)出來
attn = self.q_matmul_k(q, k)
attn = self.mul(attn, self.scale)
attn = self.softmax(attn)
attn = self.attn_drop(attn)
# 其最終輸出則是通過V這個映射后的向量與QK經(jīng)過Softmax結(jié)果進(jìn)行weight sum獲得
# 這個過程可以理解為在全局上進(jìn)行自注意表示
out = self.attn_matmul_v(attn, v)
out = self.transpose(out, (0, 2, 1, 3))
out = self.reshape(out, (b, n, c))
out = self.out(out)
out = self.out_drop(out)
return out
Transformer Encoder
在了解了Self-Attention結(jié)構(gòu)之后揍瑟,通過與Feed Forward,Residual Connection等結(jié)構(gòu)的拼接就可以形成Transformer的基礎(chǔ)結(jié)構(gòu)乍炉,接下來就利用Self-Attention來構(gòu)建ViT模型中的TransformerEncoder部分绢片,類似于構(gòu)建了一個Transformer的編碼器部分。
[圖片上傳失敗...(image-87e954-1652419666222)]
- ViT模型中的基礎(chǔ)結(jié)構(gòu)與標(biāo)準(zhǔn)Transformer有所不同岛琼,主要在于Normalization的位置是放在Self-Attention和Feed Forward之前,其他結(jié)構(gòu)如Residual Connection底循,F(xiàn)eed Forward,Normalization都如Transformer中所設(shè)計衷恭。
- 從transformer結(jié)構(gòu)的圖片可以發(fā)現(xiàn)此叠,多個子encoder的堆疊就完成了模型編碼器的構(gòu)建,在ViT模型中,依然沿用這個思路灭袁,通過配置超參數(shù)num_layers猬错,就可以確定堆疊層數(shù)。
- Residual Connection茸歧,Normalization的結(jié)構(gòu)可以保證模型有很強(qiáng)的擴(kuò)展性(保證信息經(jīng)過深層處理不會出現(xiàn)退化的現(xiàn)象倦炒,這是Residual Connection的作用),Normalization和dropout的應(yīng)用可以增強(qiáng)模型泛化能力软瞎。
從以下源碼中就可以清晰看到Transformer的結(jié)構(gòu)逢唤。將TransformerEncoder結(jié)構(gòu)和一個多層感知器(MLP)結(jié)合,就構(gòu)成了ViT模型的backbone部分涤浇。
class TransformerEncoder(nn.Cell):
def __init__(self,
dim: int,
num_layers: int,
num_heads: int,
mlp_dim: int,
keep_prob: float = 1.,
attention_keep_prob: float = 1.0,
drop_path_keep_prob: float = 1.0,
activation: nn.Cell = nn.GELU,
norm: nn.Cell = nn.LayerNorm):
super(TransformerEncoder, self).__init__()
layers = []
# 從vit_architecture圖可以發(fā)現(xiàn)鳖藕,多個子encoder的堆疊就完成了模型編碼器的構(gòu)建
# 在ViT模型中,依然沿用這個思路只锭,通過配置超參數(shù)num_layers著恩,就可以確定堆疊層數(shù)
for _ in range(num_layers):
normalization1 = norm((dim,))
normalization2 = norm((dim,))
attention = Attention(dim=dim,
num_heads=num_heads,
keep_prob=keep_prob,
attention_keep_prob=attention_keep_prob)
feedforward = FeedForward(in_features=dim,
hidden_features=mlp_dim,
activation=activation,
keep_prob=keep_prob)
# ViT模型中的基礎(chǔ)結(jié)構(gòu)與標(biāo)準(zhǔn)Transformer有所不同
# 主要在于Normalization的位置是放在Self-Attention和Feed Forward之前
# 其他結(jié)構(gòu)如Residual Connection,F(xiàn)eed Forward蜻展,Normalization都如Transformer中所設(shè)計
layers.append(
nn.SequentialCell([
# Residual Connection喉誊,Normalization的結(jié)構(gòu)可以保證模型有很強(qiáng)的擴(kuò)展性
# 保證信息經(jīng)過深層處理不會出現(xiàn)退化的現(xiàn)象,這是Residual Connection的作用
# Normalization和dropout的應(yīng)用可以增強(qiáng)模型泛化能力
ResidualCell(nn.SequentialCell([normalization1,
attention])),
ResidualCell(nn.SequentialCell([normalization2,
feedforward]))
])
)
self.layers = nn.SequentialCell(layers)
def construct(self, x):
"""Transformer construct."""
return self.layers(x)
ViT模型的輸入
傳統(tǒng)的Transformer結(jié)構(gòu)主要用于處理自然語言領(lǐng)域的詞向量(Word Embedding or Word Vector)纵顾,詞向量與傳統(tǒng)圖像數(shù)據(jù)的主要區(qū)別在于伍茄,詞向量通常是1維向量進(jìn)行堆疊,而圖片則是二維矩陣的堆疊施逾,多頭注意力機(jī)制在處理1維詞向量的堆疊時會提取詞向量之間的聯(lián)系也就是上下文語義敷矫,這使得Transformer在自然語言處理領(lǐng)域非常好用,而2維圖片矩陣如何與1維詞向量進(jìn)行轉(zhuǎn)化就成為了Transformer進(jìn)軍圖像處理領(lǐng)域的一個小門檻音念。
在ViT模型中:
- 通過將輸入圖像在每個channel上劃分為1616個patch沪饺,這一步是通過卷積操作來完成的,當(dāng)然也可以人工進(jìn)行劃分闷愤,但卷積操作也可以達(dá)到目的同時還可以進(jìn)行一次而外的數(shù)據(jù)處理;例如一幅輸入224 x 224的圖像件余,首先經(jīng)過卷積處理得到16 x 16個patch讥脐,那么每一個patch的大小就是14 x 14。*
- 再將每一個patch的矩陣?yán)斐蔀橐粋€1維向量啼器,從而獲得了近似詞向量堆疊的效果旬渠。上一步得道的14 x 14的patch就轉(zhuǎn)換為長度為196的向量。
這是圖像輸入網(wǎng)絡(luò)經(jīng)過的第一步處理端壳。具體Patch Embedding的代碼如下所示:
class PatchEmbedding(nn.Cell):
MIN_NUM_PATCHES = 4
def __init__(self,
image_size: int = 224,
patch_size: int = 16,
embed_dim: int = 768,
input_channels: int = 3):
super(PatchEmbedding, self).__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size) ** 2
# 通過將輸入圖像在每個channel上劃分為16*16個patch
self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)
self.reshape = P.Reshape()
self.transpose = P.Transpose()
def construct(self, x):
"""Path Embedding construct."""
x = self.conv(x)
b, c, h, w = x.shape
# 再將每一個patch的矩陣?yán)斐蔀橐粋€1維向量告丢,從而獲得了近似詞向量堆疊的效果;
x = self.reshape(x, (b, c, h * w))
x = self.transpose(x, (0, 2, 1))
return x
由論文中的模型結(jié)構(gòu)可以得知损谦,輸入圖像在劃分為patch之后岖免,會經(jīng)過pos_embedding 和 class_embedding兩個過程岳颇。
- class_embedding主要借鑒了BERT模型的用于文本分類時的思想,在每一個word vector之前增加一個類別值颅湘,通常是加在向量的第一位,上一步得到的196維的向量加上class_embedding后變?yōu)?97維话侧。
- 增加的class_embedding是一個可以學(xué)習(xí)的參數(shù),經(jīng)過網(wǎng)絡(luò)的不斷訓(xùn)練闯参,最終以輸出向量的第一個維度的輸出來決定最后的輸出類別瞻鹏;由于輸入是16 x 16個patch,所以輸出進(jìn)行分類時是取 16 x 16個class_embedding進(jìn)行分類鹿寨。
- pos_embedding也是一組可以學(xué)習(xí)的參數(shù)新博,會被加入到經(jīng)過處理的patch矩陣中。
- 由于pos_embedding也是可以學(xué)習(xí)的參數(shù)脚草,所以它的加入類似于全鏈接網(wǎng)絡(luò)和卷積的bias叭披。這一步就是創(chuàng)造一個長度維197的可訓(xùn)練向量加入到經(jīng)過class_embedding的向量中。
從論文中可以得到玩讳,pos_embedding總共有4中方案涩蜘。但是經(jīng)過作者的論證,只有加上pos_embedding和不加pos_embedding有明顯影響熏纯,至于pos_embedding是1維還是2維對分類結(jié)果影響不大同诫,所以,在我們的代碼中樟澜,也是采用了1維的pos_embedding误窖,由于class_embedding是加在pos_embedding之前,所以pos_embedding的維度會比patch拉伸后的維度加1秩贰。
總的而言霹俺,ViT模型還是利用了Transformer模型在處理上下文語義時的優(yōu)勢,將圖像轉(zhuǎn)換為一種“變種詞向量”然后進(jìn)行處理毒费,而這樣轉(zhuǎn)換的意義在于丙唧,多個patch之間本身具有空間聯(lián)系,這類似于一種“空間語義”觅玻,從而獲得了比較好的處理效果想际。
整體構(gòu)建ViT
以下代碼構(gòu)建了一個完整的ViT模型。
from typing import Optional
class ViT(nn.Cell):
def __init__(self,
image_size: int = 224,
input_channels: int = 3,
patch_size: int = 16,
embed_dim: int = 768,
num_layers: int = 12,
num_heads: int = 12,
mlp_dim: int = 3072,
keep_prob: float = 1.0,
attention_keep_prob: float = 1.0,
drop_path_keep_prob: float = 1.0,
activation: nn.Cell = nn.GELU,
norm: Optional[nn.Cell] = nn.LayerNorm,
pool: str = 'cls') -> None:
super(ViT, self).__init__()
self.patch_embedding = PatchEmbedding(image_size=image_size,
patch_size=patch_size,
embed_dim=embed_dim,
input_channels=input_channels)
num_patches = self.patch_embedding.num_patches
# 此處增加class_embedding和pos_embedding溪厘,如果不是進(jìn)行分類任務(wù)
# 可以只增加pos_embedding胡本,通過pool參數(shù)進(jìn)行控制
self.cls_token = init(init_type=Normal(sigma=1.0),
shape=(1, 1, embed_dim),
dtype=ms.float32,
name='cls',
requires_grad=True)
# pos_embedding也是一組可以學(xué)習(xí)的參數(shù),會被加入到經(jīng)過處理的patch矩陣中
self.pos_embedding = init(init_type=Normal(sigma=1.0),
shape=(1, num_patches + 1, embed_dim),
dtype=ms.float32,
name='pos_embedding',
requires_grad=True)
# axis=1定義了會在向量的開頭加入class_embedding
self.concat = P.Concat(axis=1)
self.pool = pool
self.pos_dropout = nn.Dropout(keep_prob)
self.norm = norm((embed_dim,))
self.tile = P.Tile()
self.transformer = TransformerEncoder(dim=embed_dim,
num_layers=num_layers,
num_heads=num_heads,
mlp_dim=mlp_dim,
keep_prob=keep_prob,
attention_keep_prob=attention_keep_prob,
drop_path_keep_prob=drop_path_keep_prob,
activation=activation,
norm=norm)
def construct(self, x):
"""ViT construct."""
x = self.patch_embedding(x)
# class_embedding主要借鑒了BERT模型的用于文本分類時的思想
# 在每一個word vector之前增加一個類別值畸悬,通常是加在向量的第一位
cls_tokens = self.tile(self.cls_token, (x.shape[0], 1, 1))
x = self.concat((cls_tokens, x))
x += self.pos_embedding
x = self.pos_dropout(x)
x = self.transformer(x)
x = self.norm(x)
# 增加的class_embedding是一個可以學(xué)習(xí)的參數(shù)侧甫,經(jīng)過網(wǎng)絡(luò)的不斷訓(xùn)練
# 最終以輸出向量的第一個維度的輸出來決定最后的輸出類別;
x = x[:, 0]
return x
[圖片上傳失敗...(image-a33976-1652419666221)]
模型訓(xùn)練與推理
模型訓(xùn)練
模型開始訓(xùn)練前,需要設(shè)定損失函數(shù)披粟,優(yōu)化器咒锻,回調(diào)函數(shù)等,直接調(diào)用mindvision提供的接口可以方便完成實(shí)例化僻爽。
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindvision.classification.models import vit_b_16
from mindvision.engine.callback import LossMonitor
from mindvision.engine.loss import CrossEntropySmooth
# 定義超參數(shù)
epoch_size = 10
momentum = 0.9
step_size = dataset_train.get_dataset_size()
num_classes = 1000
# 構(gòu)建模型
network = vit_b_16(num_classes=num_classes, image_size=resize, pretrained=True)
# 定義遞減的學(xué)習(xí)率
lr = nn.cosine_decay_lr(min_lr=float(0),
max_lr=0.003,
total_step=epoch_size * step_size,
step_per_epoch=step_size,
decay_epoch=90)
# 定義優(yōu)化器
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
# 定義損失函數(shù)
network_loss = CrossEntropySmooth(sparse=True,
reduction="mean",
smooth_factor=0.1,
classes_num=num_classes)
# 設(shè)定checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
# 初始化模型
model = Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"})
# 訓(xùn)練
model.train(epoch_size,
dataset_train,
callbacks=[ckpt_callback, LossMonitor(lr)],
dataset_sink_mode=False)
結(jié)果:
Epoch:[ 0/ 10], step:[ 1/80072], loss:[1.963/1.963], time:8171.241 ms, lr:0.00300
Epoch:[ 0/ 10], step:[ 2/80072], loss:[7.809/4.886], time:769.321 ms, lr:0.00300
Epoch:[ 0/ 10], step:[ 3/80072], loss:[8.851/6.208], time:779.355 ms, lr:0.00300
....
Epoch:[ 9/ 10], step:[80070/80072], loss:[1.112/6.657], time:780.714 ms, lr:0.00240
Epoch:[ 9/ 10], step:[80071/80072], loss:[1.111/6.708], time:781.860 ms, lr:0.00240
Epoch:[ 9/ 10], step:[80072/80072], loss:[1.102/6.777], time:782.859 ms, lr:0.00240
模型驗(yàn)證
模型驗(yàn)證過程主要應(yīng)用了nn虫碉,Model,context胸梆,ImageNet敦捧,CrossEntropySmooth和vit_b_16等接口。
通過改變ImageNet接口的split參數(shù)即可調(diào)用驗(yàn)證集碰镜。
與訓(xùn)練過程相似兢卵,首先調(diào)用vit_b_16接口定義網(wǎng)絡(luò)結(jié)構(gòu),加載預(yù)訓(xùn)練模型參數(shù)绪颖。隨后設(shè)置損失函數(shù)秽荤,評價指標(biāo)等,編譯模型后進(jìn)行驗(yàn)證柠横。
dataset_analyse = ImageNet(data_url,
split="val",
num_parallel_workers=1,
resize=resize,
batch_size=batch_size)
dataset_eval = dataset_analyse.run()
network = vit_b_16(num_classes=num_classes, image_size=resize, pretrained=True)
network_loss = CrossEntropySmooth(sparse=True,
reduction="mean",
smooth_factor=0.1,
classes_num=num_classes)
# 定義評價指標(biāo)
eval_metrics = {'Top_1_Accuracy': nn.Top1CategoricalAccuracy(),
'Top_5_Accuracy': nn.Top5CategoricalAccuracy()}
model = Model(network, network_loss, metrics=eval_metrics)
# 評估模型
result = model.eval(dataset_eval)
print(result)
結(jié)果:
{'Top_1_Accuracy': 0.73524, 'Top_5_Accuracy': 0.91756}
模型推理
在進(jìn)行模型推理之前窃款,首先要定義一個對推理圖片進(jìn)行數(shù)據(jù)預(yù)處理的方法。該方法可以對我們的推理圖片進(jìn)行resize和normalize處理牍氛,這樣才能與我們訓(xùn)練時的輸入數(shù)據(jù)匹配晨继。
import mindspore.dataset.vision.c_transforms as transforms
# 數(shù)據(jù)預(yù)處理操作
def infer_transform(dataset, columns_list, resize):
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
trans = [transforms.Decode(),
transforms.Resize([resize, resize]),
transforms.Normalize(mean=mean, std=std),
transforms.HWC2CHW()]
dataset = dataset.map(operations=trans,
input_columns=columns_list[0],
num_parallel_workers=1)
dataset = dataset.batch(1)
return dataset
接下來,我們將調(diào)用模型的predict方法進(jìn)行模型推理搬俊,需要注意的是紊扬,推理圖片需要自備,同時給予準(zhǔn)確的路徑利用read_dataset接口讀推理圖片路徑唉擂,利用GeneratorDataset來生成測試集餐屎。
在推理過程中,ImageNet接口主要負(fù)責(zé)對原數(shù)據(jù)集標(biāo)簽和模型輸出進(jìn)行配對玩祟。通過index2label就可以獲取對應(yīng)標(biāo)簽倘核,再通過show_result接口將結(jié)果寫在對應(yīng)圖片上井氢。
import numpy as np
import mindspore.dataset as ds
from mindspore import Tensor
from mindvision.dataset.generator import DatasetGenerator
from mindvision.dataset.download import read_dataset
from mindvision.classification.utils.image import show_result
# 讀取推理圖片
image_list, image_label = read_dataset('./infer')
columns_list = ('image', 'label')
dataset_infer = ds.GeneratorDataset(DatasetGenerator(image_list, image_label),
column_names=list(columns_list),
num_parallel_workers=1)
dataset_infer = infer_transform(dataset_infer, columns_list, resize)
# 讀取數(shù)據(jù)進(jìn)行推理
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
image = image["image"]
image = Tensor(image)
prob = model.predict(image)
label = np.argmax(prob.asnumpy(), axis=1)
predict = dataset_analyse.index2label[int(label)]
output = {int(label): predict}
print(output)
show_result(img=image_list[i], result=output, out_file=image_list[i])
結(jié)果:
{236: 'Doberman'}
推理過程完成后躏仇,在推理文件夾下可以找到圖片的推理結(jié)果箫踩,如下圖所示:
[圖片上傳失敗...(image-756be5-1652419666221)]
總結(jié)
本案例完成了一個ViT模型在ImageNet數(shù)據(jù)上進(jìn)行訓(xùn)練昼激,驗(yàn)證和推理的過程鞋诗,其中疗杉,對關(guān)鍵的ViT模型結(jié)構(gòu)和原理作了講解尽狠。通過學(xué)習(xí)本案例象对,理解源碼可以幫助學(xué)員掌握Multi-Head Attention黑忱,TransformerEncoder,pos_embedding等關(guān)鍵概念,如果要詳細(xì)理解ViT的模型原理甫煞,建議基于源碼更深層次的詳細(xì)閱讀菇曲,可以參考vision套件:
<u style="text-decoration: none; border-bottom: 1px dashed rgb(128, 128, 128);">https://gitee.com/mindspore/vision/tree/master/examples/classification/vit</u> 。
引用
[1] Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).
[2] Vaswani, Ashish, et al. "Attention is all you need."Advances in Neural Information Processing Systems. (2017).