Vision Transformer圖像分類(MindSpore實(shí)現(xiàn))

Vision Transformer(ViT)簡介

近些年千劈,隨著基于自注意(Self-Attention)結(jié)構(gòu)的模型的發(fā)展抬闷,特別是Transformer模型的提出魔市,極大的促進(jìn)了自然語言處理模型的發(fā)展督禽。由于Transformers的計算效率和可擴(kuò)展性颠黎,它已經(jīng)能夠訓(xùn)練具有超過100B參數(shù)的空前規(guī)模的模型另锋。

ViT則是自然語言處理和計算機(jī)視覺兩個領(lǐng)域的融合結(jié)晶。在不依賴卷積操作的情況下狭归,依然可以在圖像分類任務(wù)上達(dá)到很好的效果夭坪。

模型結(jié)構(gòu)

ViT模型的主體結(jié)構(gòu)是基于Transformer模型的Encoder部分(部分結(jié)構(gòu)順序有調(diào)整,如:normalization的位置與標(biāo)準(zhǔn)Transformer不同)过椎,其結(jié)構(gòu)圖如下:

[圖片上傳失敗...(image-54c387-1652419666222)]

模型特點(diǎn)

ViT模型是應(yīng)用于圖像分類領(lǐng)域室梅。因此,其模型結(jié)構(gòu)相較于傳統(tǒng)的Transformer有以下幾個特點(diǎn):

  1. 數(shù)據(jù)集的原圖像被劃分為多個patch后疚宇,將二維patch(不考慮channel)轉(zhuǎn)換為一維向量亡鼠,再加上類別向量與位置向量作為模型輸入。
  2. 模型主體的Block基于Transformer的Encoder部分敷待,但是調(diào)整了normaliztion的位置间涵,其中,最主要的結(jié)構(gòu)依然是Multi-head Attention結(jié)構(gòu)榜揖。
  3. 模型在Blocks堆疊后接全連接層接受類別向量輸出用于分類勾哩。通常情況下股耽,我們將最后的全連接層稱為Head,Transformer Encoder部分為backbone钳幅。

下面將通過代碼實(shí)例來詳細(xì)解釋基于ViT實(shí)現(xiàn)ImageNet分類任務(wù)物蝙。

環(huán)境準(zhǔn)備與數(shù)據(jù)讀取

本案例基于MindSpore-GPU版本,在單GPU卡上完成模型訓(xùn)練和驗(yàn)證敢艰。

首先導(dǎo)入相關(guān)模塊诬乞,配置相關(guān)超參數(shù)并讀取數(shù)據(jù)集,該部分代碼在Vision套件中都有API可直接調(diào)用钠导,詳情可以參考以下鏈接:<u style="text-decoration: none; border-bottom: 1px dashed rgb(128, 128, 128);">https://gitee.com/mindspore/vision</u> 震嫉。

可通過:<u style="text-decoration: none; border-bottom: 1px dashed rgb(128, 128, 128);">http://image-net.org/</u> 進(jìn)行數(shù)據(jù)集下載。

加載前先定義數(shù)據(jù)集路徑牡属,請確保你的數(shù)據(jù)集路徑如以下結(jié)構(gòu)票堵。

.ImageNet/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    └── val/
from mindspore import context
from mindvision.classification.dataset import ImageNet

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

data_url = './ImageNet/'
resize = 224
batch_size = 16

dataset_train = ImageNet(data_url,
                         split="train",
                         shuffle=True,
                         resize=resize,
                         batch_size=batch_size,
                         repeat_num=1,
                         num_parallel_workers=1).run()

模型解析

下面將通過代碼來細(xì)致剖析ViT模型的內(nèi)部結(jié)構(gòu)。

Transformer基本原理

Transformer模型源于2017年的一篇文章[2]逮栅。在這篇文章中提出的基于Attention機(jī)制的編碼器-解碼器型結(jié)構(gòu)在自然語言處理領(lǐng)域獲得了巨大的成功悴势。模型結(jié)構(gòu)如下圖所示:

[圖片上傳失敗...(image-251c0-1652419666222)]

其主要結(jié)構(gòu)為多個Encoder和Decoder模塊所組成,其中Encoder和Decoder的詳細(xì)結(jié)構(gòu)如下圖所示:

[圖片上傳失敗...(image-bf835b-1652419666222)]

Encoder與Decoder由許多結(jié)構(gòu)組成措伐,如:多頭注意力(Multi-Head Attention)層特纤,F(xiàn)eed Forward層,Normaliztion層侥加,甚至殘差連接(Residual Connection捧存,圖中的“add”)。不過担败,其中最重要的結(jié)構(gòu)是多頭注意力(Multi-Head Attention)結(jié)構(gòu)昔穴,該結(jié)構(gòu)基于自注意力(Self-Attention)機(jī)制,是多個Self-Attention的并行組成提前。

所以吗货,理解了Self-Attention就抓住了Transformer的核心。

Attention模塊

以下是Self-Attention的解釋岖研,其核心內(nèi)容是為輸入向量的每個單詞學(xué)習(xí)一個權(quán)重卿操。通過給定一個任務(wù)相關(guān)的查詢向量Query向量,計算Query和各個Key的相似性或者相關(guān)性得到注意力分布孙援,即得到每個Key對應(yīng)Value的權(quán)重系數(shù)害淤,然后對Value進(jìn)行加權(quán)求和得到最終的Attention數(shù)值。

在Self-Attention中:

  1. 最初的輸入向量首先會經(jīng)過Embedding層映射成Q(Query)拓售,K(Key)窥摄,V(Value)三個向量,由于是并行操作础淤,所以代碼中是映射成為dim x 3的向量然后進(jìn)行分割崭放,換言之哨苛,如果你的輸入向量為一個向量序列( 1x1 1x1, 2x2 2x2币砂, 3x3 3x3)建峭,其中的 1x1 1x1, 2x2 2x2决摧, 3x3 3x3都是一維向量亿蒸,那么每一個一維向量都會經(jīng)過Embedding層映射出Q,K掌桩,V三個向量边锁,只是Embedding矩陣不同,矩陣參數(shù)也是通過學(xué)習(xí)得到的波岛。這里大家可以認(rèn)為茅坛,Q,K则拷,V三個矩陣是發(fā)現(xiàn)向量之間關(guān)聯(lián)信息的一種手段贡蓖,需要經(jīng)過學(xué)習(xí)得到,至于為什么是Q隔躲,K摩梧,V三個物延,主要是因?yàn)樾枰獌蓚€向量點(diǎn)乘以獲得權(quán)重宣旱,又需要另一個向量來承載權(quán)重向加的結(jié)果,所以叛薯,最少需要3個矩陣浑吟,也是論文作者經(jīng)過實(shí)驗(yàn)得出的結(jié)論。

<picture><source media="(max-width: 320px)" srcset="https://www.zhihu.com/equation?tex=%5Cbegin%7Bcases%7D+q_i+%3D+W_q+%5Ccdot+x_i+%26+%5C%5C+k_i+%3D+W_k+%5Ccdot+x_i%2C%5Chspace%7B1em%7D+%26i+%3D+1%2C2%2C3+%5Cldots+%5C%5C+v_i+%3D+W_v+%5Ccdot+x_i+%26+%5Cend%7Bcases%7D+%5Ctag%7B1%7D&width=40"><source media="(max-width: 400px)" srcset="https://www.zhihu.com/equation?tex=%5Cbegin%7Bcases%7D+q_i+%3D+W_q+%5Ccdot+x_i+%26+%5C%5C+k_i+%3D+W_k+%5Ccdot+x_i%2C%5Chspace%7B1em%7D+%26i+%3D+1%2C2%2C3+%5Cldots+%5C%5C+v_i+%3D+W_v+%5Ccdot+x_i+%26+%5Cend%7Bcases%7D+%5Ctag%7B1%7D&width=50"><source media="(max-width: 480px)" srcset="https://www.zhihu.com/equation?tex=%5Cbegin%7Bcases%7D+q_i+%3D+W_q+%5Ccdot+x_i+%26+%5C%5C+k_i+%3D+W_k+%5Ccdot+x_i%2C%5Chspace%7B1em%7D+%26i+%3D+1%2C2%2C3+%5Cldots+%5C%5C+v_i+%3D+W_v+%5Ccdot+x_i+%26+%5Cend%7Bcases%7D+%5Ctag%7B1%7D&width=60">[圖片上傳失敗...(image-81b91f-1652419666221)]</picture>

[圖片上傳失敗...(image-8f818f-1652419666222)]

2. 自注意力機(jī)制的自注意主要體現(xiàn)在它的Q耗溜,K组力,V都來源于其自身,也就是該過程是在提取輸入的不同順序的向量的聯(lián)系與特征抖拴,最終通過不同順序向量之間的聯(lián)系緊密性(Q與K乘積經(jīng)過softmax的結(jié)果)來表現(xiàn)出來燎字。Q,K阿宅,V得到后就需要獲取向量間權(quán)重候衍,需要對Q和K進(jìn)行點(diǎn)乘并除以維度的平方根 ??√d ??√d,對所有向量的結(jié)果進(jìn)行Softmax處理洒放,通過公式(2)的操作蛉鹿,我們獲得了向量之間的關(guān)系權(quán)重。

<picture><source media="(max-width: 320px)" srcset="https://www.zhihu.com/equation?tex=%5Cbegin%7Bcases%7D+a_%7B1%2C1%7D+%3D+q_1+%5Ccdot+k_1+%2F+%5Csqrt+d+%5C%5C+a_%7B1%2C2%7D+%3D+q_1+%5Ccdot+k_2+%2F+%5Csqrt+d+%5C%5C+a_%7B1%2C3%7D+%3D+q_1+%5Ccdot+k_3+%2F+%5Csqrt+d+%5Cend%7Bcases%7D+%5Ctag%7B2%7D&width=40"><source media="(max-width: 400px)" srcset="https://www.zhihu.com/equation?tex=%5Cbegin%7Bcases%7D+a_%7B1%2C1%7D+%3D+q_1+%5Ccdot+k_1+%2F+%5Csqrt+d+%5C%5C+a_%7B1%2C2%7D+%3D+q_1+%5Ccdot+k_2+%2F+%5Csqrt+d+%5C%5C+a_%7B1%2C3%7D+%3D+q_1+%5Ccdot+k_3+%2F+%5Csqrt+d+%5Cend%7Bcases%7D+%5Ctag%7B2%7D&width=50"><source media="(max-width: 480px)" srcset="https://www.zhihu.com/equation?tex=%5Cbegin%7Bcases%7D+a_%7B1%2C1%7D+%3D+q_1+%5Ccdot+k_1+%2F+%5Csqrt+d+%5C%5C+a_%7B1%2C2%7D+%3D+q_1+%5Ccdot+k_2+%2F+%5Csqrt+d+%5C%5C+a_%7B1%2C3%7D+%3D+q_1+%5Ccdot+k_3+%2F+%5Csqrt+d+%5Cend%7Bcases%7D+%5Ctag%7B2%7D&width=60">[圖片上傳失敗...(image-22fb5f-1652419666220)]</picture>

[圖片上傳失敗...(image-8c029e-1652419666222)]

[圖片上傳失敗...(image-96d4ef-1652419666222)]

[圖片上傳失敗...(image-7101cc-1652419666222)]

3.其最終輸出則是通過V這個映射后的向量與QK經(jīng)過Softmax結(jié)果進(jìn)行weight sum獲得往湿,這個過程可以理解為在全局上進(jìn)行自注意表示妖异。每一組QKV最后都有一個V輸出惋戏,這是Self-Attention得到的最終結(jié)果,是當(dāng)前向量在結(jié)合了它與其他向量關(guān)聯(lián)權(quán)重后得到的結(jié)果他膳。

[圖片上傳失敗...(image-300ae1-1652419666222)]

通過下圖可以整體把握Self-Attention的全部過程响逢。

[圖片上傳失敗...(image-3bd17c-1652419666222)]

多頭注意力機(jī)制就是將原本self-Attention處理的向量分割為多個Head進(jìn)行處理,這一點(diǎn)也可以從代碼中體現(xiàn)棕孙,這也是attention結(jié)構(gòu)可以進(jìn)行并行加速的一個方面龄句。

總結(jié)來說,多頭注意力機(jī)制在保持參數(shù)總量不變的情況下散罕,將同樣的query, key和value映射到原來的高維空間(Q,K,V)的不同子空間(Q_0,K_0,V_0)中進(jìn)行自注意力的計算分歇,最后再合并不同子空間中的注意力信息。

所以欧漱,對于同一個輸入向量职抡,多個注意力機(jī)制可以同時對其進(jìn)行處理,即利用并行計算加速處理過程误甚,又在處理的時候更充分的分析和利用了向量特征缚甩。下圖展示了多頭注意力機(jī)制,其并行能力的主要體現(xiàn)在下圖中的a_1a_2是同一個向量進(jìn)行分割獲得的窑邦。

[圖片上傳失敗...(image-ae8751-1652419666222)]

以下是vision套件中的Multi-Head Attention代碼擅威,結(jié)合上文的解釋,代碼清晰的展現(xiàn)了這一過程冈钦。

import mindspore.nn as nn

class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = Tensor(head_dim ** -0.5)

        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(keep_prob)

        self.mul = P.Mul()
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.unstack = P.Unstack(axis=0)
        self.attn_matmul_v = P.BatchMatMul()
        self.q_matmul_k = P.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        """Attention construct."""
        b, n, c = x.shape

        # 最初的輸入向量首先會經(jīng)過Embedding層映射成Q(Query)郊丛,K(Key),V(Value)三個向量
        # 由于是并行操作瞧筛,所以代碼中是映射成為dim*3的向量然后進(jìn)行分割
        qkv = self.qkv(x)

        #多頭注意力機(jī)制就是將原本self-Attention處理的向量分割為多個Head進(jìn)行處理
        qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = self.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = self.unstack(qkv)

        # 自注意力機(jī)制的自注意主要體現(xiàn)在它的Q,K较幌,V都來源于其自身
        # 也就是該過程是在提取輸入的不同順序的向量的聯(lián)系與特征
        # 最終通過不同順序向量之間的聯(lián)系緊密性(Q與K乘積經(jīng)過softmax的結(jié)果)來表現(xiàn)出來
        attn = self.q_matmul_k(q, k)
        attn = self.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        # 其最終輸出則是通過V這個映射后的向量與QK經(jīng)過Softmax結(jié)果進(jìn)行weight sum獲得
        # 這個過程可以理解為在全局上進(jìn)行自注意表示
        out = self.attn_matmul_v(attn, v)
        out = self.transpose(out, (0, 2, 1, 3))
        out = self.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)

        return out

Transformer Encoder

在了解了Self-Attention結(jié)構(gòu)之后揍瑟,通過與Feed Forward,Residual Connection等結(jié)構(gòu)的拼接就可以形成Transformer的基礎(chǔ)結(jié)構(gòu)乍炉,接下來就利用Self-Attention來構(gòu)建ViT模型中的TransformerEncoder部分绢片,類似于構(gòu)建了一個Transformer的編碼器部分。

[圖片上傳失敗...(image-87e954-1652419666222)]

  1. ViT模型中的基礎(chǔ)結(jié)構(gòu)與標(biāo)準(zhǔn)Transformer有所不同岛琼,主要在于Normalization的位置是放在Self-Attention和Feed Forward之前,其他結(jié)構(gòu)如Residual Connection底循,F(xiàn)eed Forward,Normalization都如Transformer中所設(shè)計衷恭。
  2. 從transformer結(jié)構(gòu)的圖片可以發(fā)現(xiàn)此叠,多個子encoder的堆疊就完成了模型編碼器的構(gòu)建,在ViT模型中,依然沿用這個思路灭袁,通過配置超參數(shù)num_layers猬错,就可以確定堆疊層數(shù)。
  3. Residual Connection茸歧,Normalization的結(jié)構(gòu)可以保證模型有很強(qiáng)的擴(kuò)展性(保證信息經(jīng)過深層處理不會出現(xiàn)退化的現(xiàn)象倦炒,這是Residual Connection的作用),Normalization和dropout的應(yīng)用可以增強(qiáng)模型泛化能力软瞎。

從以下源碼中就可以清晰看到Transformer的結(jié)構(gòu)逢唤。將TransformerEncoder結(jié)構(gòu)和一個多層感知器(MLP)結(jié)合,就構(gòu)成了ViT模型的backbone部分涤浇。

class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []

        # 從vit_architecture圖可以發(fā)現(xiàn)鳖藕,多個子encoder的堆疊就完成了模型編碼器的構(gòu)建
        # 在ViT模型中,依然沿用這個思路只锭,通過配置超參數(shù)num_layers著恩,就可以確定堆疊層數(shù)
        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)

            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)

            # ViT模型中的基礎(chǔ)結(jié)構(gòu)與標(biāo)準(zhǔn)Transformer有所不同
            # 主要在于Normalization的位置是放在Self-Attention和Feed Forward之前
            # 其他結(jié)構(gòu)如Residual Connection,F(xiàn)eed Forward蜻展,Normalization都如Transformer中所設(shè)計
            layers.append(
                nn.SequentialCell([
                    # Residual Connection喉誊,Normalization的結(jié)構(gòu)可以保證模型有很強(qiáng)的擴(kuò)展性
                    # 保證信息經(jīng)過深層處理不會出現(xiàn)退化的現(xiàn)象,這是Residual Connection的作用
                    # Normalization和dropout的應(yīng)用可以增強(qiáng)模型泛化能力
                    ResidualCell(nn.SequentialCell([normalization1,
                                                    attention])),

                    ResidualCell(nn.SequentialCell([normalization2,
                                                    feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        """Transformer construct."""
        return self.layers(x)

ViT模型的輸入

傳統(tǒng)的Transformer結(jié)構(gòu)主要用于處理自然語言領(lǐng)域的詞向量(Word Embedding or Word Vector)纵顾,詞向量與傳統(tǒng)圖像數(shù)據(jù)的主要區(qū)別在于伍茄,詞向量通常是1維向量進(jìn)行堆疊,而圖片則是二維矩陣的堆疊施逾,多頭注意力機(jī)制在處理1維詞向量的堆疊時會提取詞向量之間的聯(lián)系也就是上下文語義敷矫,這使得Transformer在自然語言處理領(lǐng)域非常好用,而2維圖片矩陣如何與1維詞向量進(jìn)行轉(zhuǎn)化就成為了Transformer進(jìn)軍圖像處理領(lǐng)域的一個小門檻音念。

在ViT模型中:

  1. 通過將輸入圖像在每個channel上劃分為1616個patch沪饺,這一步是通過卷積操作來完成的,當(dāng)然也可以人工進(jìn)行劃分闷愤,但卷積操作也可以達(dá)到目的同時還可以進(jìn)行一次而外的數(shù)據(jù)處理;例如一幅輸入224 x 224的圖像件余,首先經(jīng)過卷積處理得到16 x 16個patch讥脐,那么每一個patch的大小就是14 x 14。*
  2. 再將每一個patch的矩陣?yán)斐蔀橐粋€1維向量啼器,從而獲得了近似詞向量堆疊的效果旬渠。上一步得道的14 x 14的patch就轉(zhuǎn)換為長度為196的向量。

這是圖像輸入網(wǎng)絡(luò)經(jīng)過的第一步處理端壳。具體Patch Embedding的代碼如下所示:

class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4
    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        # 通過將輸入圖像在每個channel上劃分為16*16個patch
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()

    def construct(self, x):
        """Path Embedding construct."""
        x = self.conv(x)
        b, c, h, w = x.shape

        # 再將每一個patch的矩陣?yán)斐蔀橐粋€1維向量告丢,從而獲得了近似詞向量堆疊的效果;
        x = self.reshape(x, (b, c, h * w))
        x = self.transpose(x, (0, 2, 1))

        return x

由論文中的模型結(jié)構(gòu)可以得知损谦,輸入圖像在劃分為patch之后岖免,會經(jīng)過pos_embedding 和 class_embedding兩個過程岳颇。

  1. class_embedding主要借鑒了BERT模型的用于文本分類時的思想,在每一個word vector之前增加一個類別值颅湘,通常是加在向量的第一位,上一步得到的196維的向量加上class_embedding后變?yōu)?97維话侧。
  2. 增加的class_embedding是一個可以學(xué)習(xí)的參數(shù),經(jīng)過網(wǎng)絡(luò)的不斷訓(xùn)練闯参,最終以輸出向量的第一個維度的輸出來決定最后的輸出類別瞻鹏;由于輸入是16 x 16個patch,所以輸出進(jìn)行分類時是取 16 x 16個class_embedding進(jìn)行分類鹿寨。
  3. pos_embedding也是一組可以學(xué)習(xí)的參數(shù)新博,會被加入到經(jīng)過處理的patch矩陣中。
  4. 由于pos_embedding也是可以學(xué)習(xí)的參數(shù)脚草,所以它的加入類似于全鏈接網(wǎng)絡(luò)和卷積的bias叭披。這一步就是創(chuàng)造一個長度維197的可訓(xùn)練向量加入到經(jīng)過class_embedding的向量中。

從論文中可以得到玩讳,pos_embedding總共有4中方案涩蜘。但是經(jīng)過作者的論證,只有加上pos_embedding和不加pos_embedding有明顯影響熏纯,至于pos_embedding是1維還是2維對分類結(jié)果影響不大同诫,所以,在我們的代碼中樟澜,也是采用了1維的pos_embedding误窖,由于class_embedding是加在pos_embedding之前,所以pos_embedding的維度會比patch拉伸后的維度加1秩贰。

總的而言霹俺,ViT模型還是利用了Transformer模型在處理上下文語義時的優(yōu)勢,將圖像轉(zhuǎn)換為一種“變種詞向量”然后進(jìn)行處理毒费,而這樣轉(zhuǎn)換的意義在于丙唧,多個patch之間本身具有空間聯(lián)系,這類似于一種“空間語義”觅玻,從而獲得了比較好的處理效果想际。

整體構(gòu)建ViT

以下代碼構(gòu)建了一個完整的ViT模型。

from typing import Optional

class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls') -> None:
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        # 此處增加class_embedding和pos_embedding溪厘,如果不是進(jìn)行分類任務(wù)
        # 可以只增加pos_embedding胡本,通過pool參數(shù)進(jìn)行控制
        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        # pos_embedding也是一組可以學(xué)習(xí)的參數(shù),會被加入到經(jīng)過處理的patch矩陣中
        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        # axis=1定義了會在向量的開頭加入class_embedding
        self.concat = P.Concat(axis=1)

        self.pool = pool
        self.pos_dropout = nn.Dropout(keep_prob)
        self.norm = norm((embed_dim,))
        self.tile = P.Tile()
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)

    def construct(self, x):
        """ViT construct."""
        x = self.patch_embedding(x)

        # class_embedding主要借鑒了BERT模型的用于文本分類時的思想
        # 在每一個word vector之前增加一個類別值畸悬,通常是加在向量的第一位
        cls_tokens = self.tile(self.cls_token, (x.shape[0], 1, 1))
        x = self.concat((cls_tokens, x))
        x += self.pos_embedding

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)

        # 增加的class_embedding是一個可以學(xué)習(xí)的參數(shù)侧甫,經(jīng)過網(wǎng)絡(luò)的不斷訓(xùn)練
        # 最終以輸出向量的第一個維度的輸出來決定最后的輸出類別;
        x = x[:, 0]

        return x

[圖片上傳失敗...(image-a33976-1652419666221)]

模型訓(xùn)練與推理
模型訓(xùn)練
模型開始訓(xùn)練前,需要設(shè)定損失函數(shù)披粟,優(yōu)化器咒锻,回調(diào)函數(shù)等,直接調(diào)用mindvision提供的接口可以方便完成實(shí)例化僻爽。

import mindspore.nn as nn
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

from mindvision.classification.models import vit_b_16
from mindvision.engine.callback import LossMonitor
from mindvision.engine.loss import CrossEntropySmooth

# 定義超參數(shù)
epoch_size = 10
momentum = 0.9
step_size = dataset_train.get_dataset_size()
num_classes = 1000

# 構(gòu)建模型
network = vit_b_16(num_classes=num_classes, image_size=resize, pretrained=True)

# 定義遞減的學(xué)習(xí)率
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.003,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=90)

# 定義優(yōu)化器
network_opt = nn.Adam(network.trainable_params(), lr, momentum)

# 定義損失函數(shù)
network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  classes_num=num_classes)

# 設(shè)定checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# 初始化模型
model = Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"})

# 訓(xùn)練
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(lr)],
            dataset_sink_mode=False)

結(jié)果:

Epoch:[  0/ 10], step:[    1/80072], loss:[1.963/1.963], time:8171.241 ms, lr:0.00300
Epoch:[  0/ 10], step:[    2/80072], loss:[7.809/4.886], time:769.321 ms, lr:0.00300
Epoch:[  0/ 10], step:[    3/80072], loss:[8.851/6.208], time:779.355 ms, lr:0.00300
....
Epoch:[  9/ 10], step:[80070/80072], loss:[1.112/6.657], time:780.714 ms, lr:0.00240
Epoch:[  9/ 10], step:[80071/80072], loss:[1.111/6.708], time:781.860 ms, lr:0.00240
Epoch:[  9/ 10], step:[80072/80072], loss:[1.102/6.777], time:782.859 ms, lr:0.00240

模型驗(yàn)證

模型驗(yàn)證過程主要應(yīng)用了nn虫碉,Model,context胸梆,ImageNet敦捧,CrossEntropySmooth和vit_b_16等接口。

通過改變ImageNet接口的split參數(shù)即可調(diào)用驗(yàn)證集碰镜。

與訓(xùn)練過程相似兢卵,首先調(diào)用vit_b_16接口定義網(wǎng)絡(luò)結(jié)構(gòu),加載預(yù)訓(xùn)練模型參數(shù)绪颖。隨后設(shè)置損失函數(shù)秽荤,評價指標(biāo)等,編譯模型后進(jìn)行驗(yàn)證柠横。

dataset_analyse = ImageNet(data_url,
                           split="val",
                           num_parallel_workers=1,
                           resize=resize,
                           batch_size=batch_size)
dataset_eval = dataset_analyse.run()

network = vit_b_16(num_classes=num_classes, image_size=resize, pretrained=True)

network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  classes_num=num_classes)

# 定義評價指標(biāo)
eval_metrics = {'Top_1_Accuracy': nn.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': nn.Top5CategoricalAccuracy()}

model = Model(network, network_loss, metrics=eval_metrics)

# 評估模型
result = model.eval(dataset_eval)
print(result)

結(jié)果:

{'Top_1_Accuracy': 0.73524, 'Top_5_Accuracy': 0.91756}

模型推理

在進(jìn)行模型推理之前窃款,首先要定義一個對推理圖片進(jìn)行數(shù)據(jù)預(yù)處理的方法。該方法可以對我們的推理圖片進(jìn)行resize和normalize處理牍氛,這樣才能與我們訓(xùn)練時的輸入數(shù)據(jù)匹配晨继。

import mindspore.dataset.vision.c_transforms as transforms

# 數(shù)據(jù)預(yù)處理操作
def infer_transform(dataset, columns_list, resize):

    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    trans = [transforms.Decode(),
             transforms.Resize([resize, resize]),
             transforms.Normalize(mean=mean, std=std),
             transforms.HWC2CHW()]

    dataset = dataset.map(operations=trans,
                          input_columns=columns_list[0],
                          num_parallel_workers=1)
    dataset = dataset.batch(1)

    return dataset

接下來,我們將調(diào)用模型的predict方法進(jìn)行模型推理搬俊,需要注意的是紊扬,推理圖片需要自備,同時給予準(zhǔn)確的路徑利用read_dataset接口讀推理圖片路徑唉擂,利用GeneratorDataset來生成測試集餐屎。

在推理過程中,ImageNet接口主要負(fù)責(zé)對原數(shù)據(jù)集標(biāo)簽和模型輸出進(jìn)行配對玩祟。通過index2label就可以獲取對應(yīng)標(biāo)簽倘核,再通過show_result接口將結(jié)果寫在對應(yīng)圖片上井氢。

import numpy as np

import mindspore.dataset as ds
from mindspore import Tensor

from mindvision.dataset.generator import DatasetGenerator
from mindvision.dataset.download import read_dataset
from mindvision.classification.utils.image import show_result

# 讀取推理圖片
image_list, image_label = read_dataset('./infer')
columns_list = ('image', 'label')

dataset_infer = ds.GeneratorDataset(DatasetGenerator(image_list, image_label),
                                    column_names=list(columns_list),
                                    num_parallel_workers=1)

dataset_infer = infer_transform(dataset_infer, columns_list, resize)

# 讀取數(shù)據(jù)進(jìn)行推理
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]
    image = Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)

    predict = dataset_analyse.index2label[int(label)]
    output = {int(label): predict}
    print(output)
    show_result(img=image_list[i], result=output, out_file=image_list[i])

結(jié)果:

{236: 'Doberman'}

推理過程完成后躏仇,在推理文件夾下可以找到圖片的推理結(jié)果箫踩,如下圖所示:

[圖片上傳失敗...(image-756be5-1652419666221)]

總結(jié)

本案例完成了一個ViT模型在ImageNet數(shù)據(jù)上進(jìn)行訓(xùn)練昼激,驗(yàn)證和推理的過程鞋诗,其中疗杉,對關(guān)鍵的ViT模型結(jié)構(gòu)和原理作了講解尽狠。通過學(xué)習(xí)本案例象对,理解源碼可以幫助學(xué)員掌握Multi-Head Attention黑忱,TransformerEncoder,pos_embedding等關(guān)鍵概念,如果要詳細(xì)理解ViT的模型原理甫煞,建議基于源碼更深層次的詳細(xì)閱讀菇曲,可以參考vision套件:

<u style="text-decoration: none; border-bottom: 1px dashed rgb(128, 128, 128);">https://gitee.com/mindspore/vision/tree/master/examples/classification/vit</u> 。

引用

[1] Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).

[2] Vaswani, Ashish, et al. "Attention is all you need."Advances in Neural Information Processing Systems. (2017).

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末抚吠,一起剝皮案震驚了整個濱河市常潮,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌楷力,老刑警劉巖喊式,帶你破解...
    沈念sama閱讀 218,204評論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異萧朝,居然都是意外死亡岔留,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,091評論 3 395
  • 文/潘曉璐 我一進(jìn)店門检柬,熙熙樓的掌柜王于貴愁眉苦臉地迎上來献联,“玉大人,你說我怎么就攤上這事何址±锬妫” “怎么了?”我有些...
    開封第一講書人閱讀 164,548評論 0 354
  • 文/不壞的土叔 我叫張陵用爪,是天一觀的道長原押。 經(jīng)常有香客問我,道長项钮,這世上最難降的妖魔是什么班眯? 我笑而不...
    開封第一講書人閱讀 58,657評論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮烁巫,結(jié)果婚禮上署隘,老公的妹妹穿的比我還像新娘。我一直安慰自己亚隙,他們只是感情好磁餐,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,689評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著阿弃,像睡著了一般诊霹。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上渣淳,一...
    開封第一講書人閱讀 51,554評論 1 305
  • 那天脾还,我揣著相機(jī)與錄音,去河邊找鬼入愧。 笑死鄙漏,一個胖子當(dāng)著我的面吹牛嗤谚,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播怔蚌,決...
    沈念sama閱讀 40,302評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼巩步,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了桦踊?” 一聲冷哼從身側(cè)響起椅野,我...
    開封第一講書人閱讀 39,216評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎籍胯,沒想到半個月后竟闪,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,661評論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡芒炼,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,851評論 3 336
  • 正文 我和宋清朗相戀三年瘫怜,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片本刽。...
    茶點(diǎn)故事閱讀 39,977評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡鲸湃,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出子寓,到底是詐尸還是另有隱情暗挑,我是刑警寧澤,帶...
    沈念sama閱讀 35,697評論 5 347
  • 正文 年R本政府宣布斜友,位于F島的核電站炸裆,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏鲜屏。R本人自食惡果不足惜烹看,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,306評論 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望洛史。 院中可真熱鬧惯殊,春花似錦、人聲如沸也殖。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,898評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽忆嗜。三九已至己儒,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間捆毫,已是汗流浹背闪湾。 一陣腳步聲響...
    開封第一講書人閱讀 33,019評論 1 270
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留绩卤,地道東北人响谓。 一個月前我還...
    沈念sama閱讀 48,138評論 3 370
  • 正文 我出身青樓损合,卻偏偏與公主長得像省艳,于是被迫代替她去往敵國和親娘纷。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,927評論 2 355

推薦閱讀更多精彩內(nèi)容