01 Vision Transformer

使用 TensorFlow 構建 ViT B-16 模型朽们。

1. 引言

在計算機視覺任務中通常使用注意力機制對特征進行增強或者使用注意力機制替換某些卷積層的方式來實現(xiàn)對網(wǎng)絡結構的優(yōu)化菜枷,這些方法都在原有卷積網(wǎng)絡的結構中運用注意力機制進行特征增強

而 ViT 依賴于原有的編碼器結構進行搭建蚊锹,并將其用于圖像分類任務牡昆,在減少模型參數(shù)量的同時提高了檢測準確度。

將 Transformer 用于圖像分類任務主要有以下 5 個過程:(1)將輸入圖像或特征進行序列化铅协;(2)添加位置編碼;(3)添加可學習的嵌入向量说墨;(4)輸入到編碼器中進行編碼姜贡;(5)將輸出的可學習嵌入向量用于分類楼咳。結構圖如下:


2. Patch Embedding

以 b×224×224×3 的輸入圖片為例。首先進行圖像分塊苹熏,將原圖片切分為 14×14 個圖像塊(Patch),每個 Patch 的大小為 16×16干发,通過提取輸入圖片中的平坦像素向量枉长,將每個輸入 Patch 送入線性投影層搀暑,得到 Patch Embeddings。
在代碼中桂敛,先經(jīng)過一個 kernel=(16,16),strides=16 的卷積層劃分圖像塊粗仓,再將 h和w 維度整合為 num_patches 維度塘淑,代表一共有 196 個 patch,每個 patch 為 16×16捌治。

3. 添加類別標簽和位置編碼

為了輸出融合了全局語義信息的向量表示,在第一個輸入向量前添加可學習分類變量构韵。經(jīng)過編碼器編碼后,在最后一層輸出中显拳,該位置對應的輸出向量就可以用于分類任務杂数。與其他位置對應的輸出向量相比,該向量可以更好的融合圖像中各個圖像塊之間的依賴關系那伐。
在 Transformer 更新的過程中,輸入序列的順序信息會丟失诉探。Transformer 本身并沒有辦法學習這個信息肾胯,所以需要一種方法將位置表示聚合到模型的輸入嵌入中怕敬。我們對每個 Patch 進行位置編碼畸陡,該位置編碼采用隨機初始化,之后參與模型訓練丁恭。與傳統(tǒng)三角函數(shù)的位置編碼方法不同曹动,該方法是可學習的。
最后牲览,將 Patch-Embeddings 和 class-token 進行堆疊墓陈,和 Position-Embeddings 進行疊加,得到最終嵌入向量第献,該向量輸入給 Transformer 層進行后續(xù)處理贡必。


4. 多頭自注意力模塊

Transformer 層中飒赃,主要包含多頭注意力機制和多層感知機模塊挠乳,下面先介紹多頭自注意力模塊。
單個的注意力機制,其每個輸入包含三個不同的向量娩贷,分別為 Query向量(Q),Key向量(K)羽历,Value向量(V)。他們的結果分別由輸入特征圖和三個權重做矩陣乘法得到颈抚。



接著為每一個輸入計算一個得分Score=q*k赐稽。
為了使梯度穩(wěn)定伶选,對 Score 的值進行歸一化處理掸鹅,并將結果通過 softmax 函數(shù)進行映射矮嫉。之后再和 v 做矩陣相乘援岩,得到加權后每個輸入向量的得分 v限寞。計算完后再乘以一個權重張量 W 提取特征妈橄。
計算公式如下,其中 \sqrt{d_{k}代表 K 向量維度的平方根洞焙。


5. MLP 多層感知器

這個部分簡單來看就是兩個全連接層提取特征兵迅,流程圖如下。第一個全連接層通道上升4倍交洗,第二個全連接層通道下降為原來沪伙。


6. 特征提取模塊

Transformer 的單個特征提取模塊是由 多頭注意力機制 和 多層感知機模塊 組合而成谍倦,encoder_block 模塊的流程圖如下塞赂。
輸入圖像像經(jīng)過 LayerNormalization 標準化后,再經(jīng)過我們上面定義的多頭注意力模塊昼蛀,將輸出結果和輸入特征圖殘差連接宴猾,圖像在特征提取過程中shape保持不變圆存。

將輸出結果再經(jīng)過標準化,然后送入多層感知器提取特征仇哆,再使用殘差連接輸入和輸出沦辙。



而 transformer 的特征提取模塊是由多個 encoder_block 疊加而成,這里連續(xù)使用12個 encoder_block 模塊來提取特征税产。

7. 主干網(wǎng)絡

接下來就搭建網(wǎng)絡了怕轿,將上面所有的模塊組合到一起,如下圖所示辟拷。

在下面代碼中要注意的是 cls_ticks = x[:,0] 取出所有的類別標簽撞羽。 因為在 cls_pos_embed 模塊中,我們將 cls_token 和輸入圖像在 patch 維度上堆疊 layers.concate衫冻,用于學習每張?zhí)卣鲌D的類別信息诀紊,取出的類別標簽 cls_ticks 的 shape 為 [b, 768]。最后經(jīng)過一個全連接層得出每張圖片屬于每個類別的得分隅俘。


8. 查看模型結構

這里有個注意點邻奠,keras.Input() 的參數(shù)問題,創(chuàng)建輸入層時为居,參數(shù) shape 不需要指定batch維度碌宴,batch_shape 需要指定batch維度。

keras.Input(shape=None, batch_shape=None, name=None, dtype=K.floatx(), sparse=False, tensor=None)
'''
shape: 形狀元組(整型)蒙畴,不包括batch size贰镣。for instance, shape=(32,) 表示了預期的輸入將是一批32維的向量。
batch_shape: 形狀元組(整型)膳凝,包括了batch size碑隆。for instance, batch_shape=(10,32)表示了預期的輸入將是10個32維向量的批次。
'''

接收模型后蹬音,通過 model.summary() 查看模型結構和參數(shù)量上煤,通過 get_flops() 參看浮點計算量。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


# --------------------------------------------- #
# (1)Embedding 層
# inputs代表輸入圖像著淆,shape為224*224*3
# out_channel代表該模塊的輸出通道數(shù)劫狠,即第一個卷積層輸出通道數(shù)=768
# patch_size代表卷積核在圖像上每16*16個區(qū)域卷積得出一個值
# --------------------------------------------- #
def patch_embed(inputs, out_channel, patch_size=16):
 
    # 獲得輸入圖像的shape=[b,224,224,3]
    b, h, w, c = inputs.shape
 
    # 獲得劃分后每張圖像的size=(14,14)
    grid_h, grid_w = h//patch_size, w//patch_size
 
    # 計算圖像寬高共有多少個像素點 n = h*w
    num_patches = grid_h * grid_w
 
    # 卷積 [b,224,224,3]==>[b,14,14,768]
    x = layers.Conv2D(filters=out_channel, kernel_size=(patch_size,patch_size), strides=patch_size, padding='same')(inputs)
 
    # 維度調(diào)整 [b,h,w,c]==>[b,n,c]
    # [b,14,14,768]==>[b,196,768]
    x = tf.reshape(x, shape=[b, num_patches, out_channel])
 
    return x


# --------------------------------------------- #
# (2)類別標簽和位置編碼
# --------------------------------------------- #
def class_pos_add(inputs):
 
    # 獲得輸入特征圖的shape=[b,196,768]
    b, num_patches, channel = inputs.shape
 
    # 類別信息 [1,1,768]
    # 直接通過classtoken來判斷類別,classtoken能夠?qū)W到其他token中的分類相關的信息
    cls_token = layers.Layer().add_weight(name='classtoken', shape=[1,1,channel], dtype=tf.float32,
                                          initializer=keras.initializers.Zeros(), trainable=True)  
 
    # 可學習的位置變量 [1,197,768], 初始化為0牧抽,trainable=True代表可以通過反向傳播更新權重
    pos_embed = layers.Layer().add_weight(name='posembed', shape=[1,num_patches+1,channel], dtype=tf.float32,
                                          initializer=keras.initializers.RandomNormal(stddev=0.02), trainable=True)
 
    # 將類別信息在維度上廣播 [1,1,768]==>[b,1,768]
    cls_token = tf.broadcast_to(cls_token, shape=[b, 1, channel])
 
    # 在num_patches維度上堆疊嘉熊,注意要把cls_token放前面
    # [b,1,768]+[b,196,768]==>[b,197,768]
    x = layers.concatenate([cls_token, inputs], axis=1)
 
    # 將位置信息疊加上去
    x = tf.add(x, pos_embed)
 
    return x  # [b,197,768]


# --------------------------------------------- #
# (3)多頭自注意力模塊
# inputs: 代表編碼后的特征圖
# num_heads: 代表多頭注意力中heads個數(shù)
# qkv_bias: 計算qkv是否使用偏置
# atten_drop_rate, proj_drop_rate:代表兩個全連接層后面的dropout層
# --------------------------------------------- #
def attention(inputs, num_heads, qkv_bias=False, atten_drop_rate=0., proj_drop_rate=0.):
 
    # 獲取輸入特征圖的shape=[b,197,768]
    b, num_patches, channel = inputs.shape
    # 計算每個head的通道數(shù)
    head_channel = channel // num_heads
    # 公式的分母,根號d
    scale = head_channel ** 0.5
 
    # 經(jīng)過一個全連接層計算qkv [b,197,768]==>[b,197,768*3]
    qkv = layers.Dense(channel*3, use_bias=qkv_bias)(inputs)
    # 調(diào)整維度 [b,197,768*3]==>[b,197,3,num_heads,c//num_heads]
    qkv = tf.reshape(qkv, shape=[b, num_patches, 3, num_heads, channel//num_heads])
    # 維度重排 [b,197,3,num_heads,c//num_heads]==>[3,b,num_heads,197,c//num_heads]
    qkv = tf.transpose(qkv, perm=[2, 0, 3, 1, 4])
    # 獲取q扬舒、k阐肤、v的值==>[b,num_heads,197,c//num_heads]
    q, k, v = qkv[0], qkv[1], qkv[2]
 
    # 矩陣乘法, q 乘 k 的轉(zhuǎn)置,除以縮放因子。矩陣相乘計算最后兩個維度
    # [b,num_heads,197,c//num_heads] * [b,num_heads,c//num_heads,197] ==> [b,num_heads,197,197]
    atten = tf.matmul(a=q, b=k, transpose_b=True) / scale
    # 對每張?zhí)卣鲌D進行softmax函數(shù)
    atten = tf.nn.softmax(atten, axis=-1)
    # 經(jīng)過dropout層
    atten = layers.Dropout(rate=atten_drop_rate)(atten)
    # 再進行矩陣相乘==>[b,num_heads,197,c//num_heads]
    atten = tf.matmul(a=atten, b=v)
 
    # 維度重排==>[b,197,num_heads,c//num_heads]
    x = tf.transpose(atten, perm=[0, 2, 1, 3])
    # 維度調(diào)整==>[b,197,c]==[b,197,768]
    x = tf.reshape(x, shape=[b, num_patches, channel])
 
    # 調(diào)整之后再經(jīng)過一個全連接層提取特征==>[b,197,768]
    x = layers.Dense(channel)(x)
    # 經(jīng)過dropout
    x = layers.Dropout(rate=proj_drop_rate)(x)
 
    return x


# ------------------------------------------------------ #
# (4)MLP block
# inputs代表輸入特征圖孕惜;mlp_ratio代表第一個全連接層上升通道倍數(shù)愧薛;
# drop_rate代表殺死神經(jīng)元概率
# ------------------------------------------------------ #
def mlp_block(inputs, mlp_ratio=4.0, drop_rate=0.):
 
    # 獲取輸入圖像的shape=[b,197,768]
    b, num_patches, channel = inputs.shape
 
    # 第一個全連接上升通道數(shù)==>[b,197,768*4]
    x = layers.Dense(int(channel*mlp_ratio))(inputs)
    # GeLU激活函數(shù)
    x = layers.Activation('gelu')(x)
    # dropout層
    x = layers.Dropout(rate=drop_rate)(x)
 
    # 第二個全連接層恢復通道數(shù)==>[b,197,768]
    x = layers.Dense(channel)(x)
    # dropout層
    x = layers.Dropout(rate=drop_rate)(x)
 
    return x


# ------------------------------------------------------ #
# (5)單個特征提取模塊
# num_heads:代表自注意力的heads個數(shù)
# epsilon:小浮點數(shù)添加到方差中以避免除以零
# drop_rate:自注意力模塊之后的dropout概率
# ------------------------------------------------------ #
def encoder_block(inputs, num_heads, epsilon=1e-6, atten_drop_rate=0., proj_drop_rate=0., drop_rate=0.):
 
    # LayerNormalization
    x = layers.LayerNormalization(epsilon=epsilon)(inputs)
    # 自注意力模塊
    x = attention(x, num_heads=num_heads, atten_drop_rate=atten_drop_rate, proj_drop_rate=proj_drop_rate)
    # 殘差連接輸入和輸出
    # x1 = x + inputs
    x1 = layers.add([x, inputs])
    
    # LayerNormalization
    x = layers.LayerNormalization(epsilon=epsilon)(x1)
    # MLP模塊
    x = mlp_block(x, drop_rate=drop_rate)
    # 殘差連接
    # x2 = x + x1
    x2 = layers.add([x, x1])
 
    return x2  # [b,197,768]
 
# ------------------------------------------------------ #
# (6)連續(xù)12個特征提取模塊
# ------------------------------------------------------ #
def transformer_block(x, num_heads):
 
    # 重復堆疊12次
    for _ in range(12):
        # 本次的特征提取塊的輸出是下一次的輸入
        x = encoder_block(x, num_heads=num_heads)
 
    return x  # 返回特征提取12次后的特征圖


# ---------------------------------------------------------- # 
# (7)主干網(wǎng)絡
# batch_shape:代表輸入圖像的shape=[8,224,224,3]
# classes:代表最終的分類數(shù)
# drop_rate:代表位置編碼后的dropout層的drop率
# num_heads:代表自注意力機制的heads個數(shù)
# epsilon:小浮點數(shù)添加到方差中以避免除以零
# ---------------------------------------------------------- # 
def VIT(batch_shape, classes, drop_rate=0., num_heads=12, epsilon=1e-6):
 
    # 構造輸入層 [b,224,224,3]
    inputs = keras.Input(batch_shape=batch_shape)
 
    # PatchEmbedding層==>[b,196,768]
    x = patch_embed(inputs, out_channel=768)
 
    # 類別和位置編碼==>[b,197,768]
    x = class_pos_add(x)
 
    # dropout層
    x = layers.Dropout(rate=drop_rate)(x)
 
    # 經(jīng)過12次特征提取==>[b,197,768]
    x = transformer_block(x, num_heads=num_heads)
 
    # LayerNormalization
    x = layers.LayerNormalization(epsilon=epsilon)(x)
 
    # 取出特征圖的類別標簽,在第(2)步中我們把類別標簽放在了最前面
    cls_ticks = x[:,0]
    # 全連接層分類
    outputs = layers.Dense(classes)(cls_ticks)
 
    # 構建模型
    model = keras.Model(inputs, outputs)
 
    return model


# ---------------------------------------------------------- # 
# (8)接收模型
# ---------------------------------------------------------- # 
if __name__ == '__main__':
 
    batch_shape = [8,224,224,3]  # 輸入圖像的尺寸
    classes = 1000  # 分類數(shù)
 
    # 接收模型
    model = VIT(batch_shape, classes)
 
    # 查看模型結構
    model.summary()
    
    # 查看浮點計算量 flops = 51955425272
    from keras_flops import get_flops
    print('flops:', get_flops(model, batch_size=8))

上述代碼在tensorflow2.14.0中能運行衫画,但在2.16.0中出現(xiàn)問題毫炉,結合官網(wǎng)給出程序修改,如下

# 1 Setup
# 環(huán)境準備
# 導入 TensorFlow削罩、NumPy等相關庫
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import numpy as np
import matplotlib.pyplot as plt

# 2 Prepare the data
# 數(shù)據(jù)導入和查看
num_classes = 10  # 分類數(shù)目為10個
input_shape = (32, 32, 3)  # 輸入的圖片形狀為32x32像素瞄勾,3個通道(RGB)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# 加載 CIFAR10 數(shù)據(jù)集,并將其拆分為訓練和測試集弥激,x_train和x_test為圖像數(shù)據(jù)进陡,y_train和y_test為標簽數(shù)據(jù)

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
# 打印訓練集的圖像數(shù)據(jù)形狀和標簽數(shù)據(jù)形狀
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
# 打印測試集的圖像數(shù)據(jù)形狀和標簽數(shù)據(jù)形狀

#3 Configure the hyperparameters
# 配置超參數(shù)
learning_rate = 0.001  # 學習率,控制模型更新的步長大小
weight_decay = 0.0001  # 權重衰減微服,控制模型復雜度趾疚,防止過擬合
batch_size = 256  # 每次訓練使用的樣本數(shù)量
num_epochs = 10  # 訓練的總輪數(shù),真正訓練以蕴,可設置為100
image_size = 72  # 將輸入圖像的大小調(diào)整為此大小
patch_size = 6  # 從輸入圖像中提取的圖像塊patch的大小
num_patches = (image_size // patch_size)**2  # 輸入圖像中的圖像塊數(shù)
projection_dim = 64  # Transformer模型中的投影維度糙麦,用于計算每個圖像塊的嵌入向量
num_heads = 4  # Transformer模型中的注意力頭數(shù),用于計算每個圖像塊的特征向量
transformer_units = [  # Transformer層的大小丛肮,每一層都有兩個子層赡磅,一個是多頭自注意力子層,一個是全連接子層
    projection_dim * 2,  # 第一子層的大小是投影維度的兩倍
    projection_dim,  # 第二子層的大小是投影維度
]
transformer_layers = 8  # Transformer模型中Transformer層的數(shù)量
mlp_head_units = [2048, 1024]  # 最終分類器中的兩個全連接層的大小

# 4 Use data augmentation
# 數(shù)據(jù)增強
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),  # 歸一化圖像數(shù)據(jù)
        layers.Resizing(image_size, image_size),  # 調(diào)整圖像大小為指定大小
        layers.RandomFlip("horizontal"),  # 水平隨機翻轉(zhuǎn)圖像
        layers.RandomRotation(factor=0.02),  # 隨機旋轉(zhuǎn)圖像
        layers.RandomZoom(
            height_factor=0.2,
            width_factor=0.2  # 隨機縮放圖像
        ),
    ],
    name="data_augmentation",  # 給數(shù)據(jù)增強模型起個名字
)

data_augmentation.layers[0].adapt(x_train)  # 對訓練集進行數(shù)據(jù)歸一化宝与,計算均值和方差用于后續(xù)的歸一化處理


# 5 Implement multilayer perceptron (MLP)
# MLP 實現(xiàn)
def mlp(x, hidden_units, dropout_rate):
    # 定義一個MLP函數(shù)仆邓,其中參數(shù)x表示輸入數(shù)據(jù),hidden_units表示每一層MLP的神經(jīng)元數(shù)伴鳖,dropout_rate表示Dropout比率
    for units in hidden_units:
        # 循環(huán)遍歷所有的隱藏層神經(jīng)元數(shù)
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        # 全連接層,其中units為該層神經(jīng)元個數(shù)徙硅,激活函數(shù)為gelu
        x = layers.Dropout(dropout_rate)(x)  # Dropout層榜聂,使部分神經(jīng)元隨機失活,防止過擬合
    return x  # 返回處理后的數(shù)據(jù)x


# 6 Implement patch creation as a layer
# 將patch創(chuàng)建實現(xiàn)為層
class Patches(layers.Layer):

    def __init__(self, patch_size):
        super().__init__()  # 繼承父類的初始化方法
        self.patch_size = patch_size  # 圖像塊的大小

    def call(self, images):
        input_shape = tf.shape(images)
        batch_size = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        channels = input_shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        patches = tf.image.extract_patches(  # 使用 TensorFlow 的圖像處理 API 獲取圖像塊
            images=images,  # 輸入的圖像
            sizes=[1, self.patch_size, self.patch_size, 1],  # 圖像塊的大小
            strides=[1, self.patch_size, self.patch_size, 1],  # 滑動步長
            rates=[1, 1, 1, 1],  # 對輸入數(shù)據(jù)進行擴展的因素
            padding="VALID",  # 填充方式
        )
        patches = tf.reshape(patches, [
            batch_size, num_patches_h * num_patches_w,
            self.patch_size * self.patch_size * channels
        ])  # 對圖像塊進行形狀變換
        return patches  # 返回處理后的圖像塊

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config


# 7 Implement the patch encoding layer
# 實現(xiàn) patch 的編碼
# 添加類別編碼嗓蘑,并將可學習的位置嵌入到投影向量中
class PatchEncoder(layers.Layer):
    # 定義一個類须肆,繼承自Layer類
    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        self.num_patches = input_shape[-2]
        self.projection_dim = input_shape[-1]
        # 類別信息 [1,1,768]
        #直接通過classtoken來判斷類別,classtoken能夠?qū)W到其他token中的分類相關的信息
        self.cls_token = self.add_weight(
            shape=(1, 1, self.projection_dim),
            dtype=tf.float32,
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name='cls',
        )

        self.pe = self.add_weight(
            shape=[1, self.num_patches + 1, self.projection_dim],
            dtype=tf.float32,
            initializer=keras.initializers.RandomNormal(stddev=0.02),
            trainable=True,
            name='pos_embedding',
        )
        super(PatchEncoder, self).build(input_shape)

    def call(self, patch):
        batch_size = tf.shape(patch)[0]  # 獲取圖片的批次大小
        # 將類別信息在維度上廣播 [1,1,768]==>[b,1,768]
        cls_broadcasted = tf.cast(tf.broadcast_to(
            self.cls_token, shape=[batch_size, 1, self.projection_dim]),
                                  dtype=patch.dtype)
        # 定義call方法桩皿,用于前向傳播
        x = tf.concat([cls_broadcasted, patch], 1)
        # 在num_patches維度上堆疊豌汇,注意要把cls_token放前面
        # [b,1,768]+[b,196,768]==>[b,197,768]
        encoded = x + tf.cast(self.pe, dtype=patch.dtype)
        # 再加上嵌入的位置信息
        return encoded
        # 返回編碼結果


# --------------------------------------------- #
# 自行編寫:多頭自注意力模塊
# inputs: 代表編碼后的特征圖
# num_heads: 代表多頭注意力中heads個數(shù)
# qkv_bias: 計算qkv是否使用偏置
# atten_drop_rate, proj_drop_rate:代表兩個全連接層后面的dropout層
# --------------------------------------------- #
class attention(layers.Layer):

    def __init__(self,
                 num_heads,
                 projection_dim=64,
                 qkv_bias=False,
                 atten_drop_rate=0.,
                 proj_drop_rate=0.):
        super().__init__()
        self.num_heads = num_heads
        self.projection_dim = projection_dim
        self.qkv_bias = qkv_bias
        # 計算每個head的通道數(shù)
        self.head_channel = self.projection_dim // self.num_heads
        # 公式的分母,根號d
        self.scale = self.head_channel**0.5
        self.drop1 = layers.Dropout(rate=atten_drop_rate)
        self.dense1 = layers.Dense(self.projection_dim)
        self.drop2 = layers.Dropout(rate=proj_drop_rate)

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        # 獲取輸入特征圖的shape=[b,197,768]
        # 調(diào)整維度 [b,197,768*3]==>[b,197,3,num_heads,c//num_heads]
        inputs = tf.reshape(
            inputs, [batch_size, -1, 3, self.num_heads, self.head_channel])
        # 公式的分母泄隔,根號d
        scale = self.head_channel**0.5

        # 維度重排 [b,197,3,num_heads,c//num_heads]==>[3,b,num_heads,197,c//num_heads]
        inputs = tf.transpose(inputs, perm=[2, 0, 3, 1, 4])
        # 獲取q拒贱、k、v的值==>[b,num_heads,197,c//num_heads]
        q, k, v = inputs[0], inputs[1], inputs[2]

        # 矩陣乘法, q 乘 k 的轉(zhuǎn)置,除以縮放因子逻澳。矩陣相乘計算最后兩個維度
        # [b,num_heads,197,c//num_heads] * [b,num_heads,c//num_heads,197] ==> [b,num_heads,197,197]
        atten = tf.matmul(a=q, b=k, transpose_b=True) / scale
        # 對每張?zhí)卣鲌D進行softmax函數(shù)
        atten = tf.nn.softmax(atten, axis=-1)
        # 經(jīng)過dropout層
        atten = self.drop1(atten)
        # 再進行矩陣相乘==>[b,num_heads,197,c//num_heads]
        atten = tf.matmul(a=atten, b=v)

        # 維度重排==>[b,197,num_heads,c//num_heads]
        x = tf.transpose(atten, perm=[0, 2, 1, 3])
        # 維度調(diào)整==>[b,197,c]==[b,197,768]
        x = tf.reshape(x, [batch_size, -1, self.projection_dim])

        # 調(diào)整之后再經(jīng)過一個全連接層提取特征==>[b,197,768]
        x = self.dense1(x)
        # 經(jīng)過dropout
        x = self.drop2(x)
        return x


# 8 Build the ViT model
# 建立 ViT 模型
# ViT模型由多個Transformer塊組成闸天,每個塊使用layers.MultiHeadAttention層作為自注意機制
# Transformer塊生成一個[batch_size,num_patches斜做,projection_dim]張量
# 通過softmax分類器頭處理以生成最終的類別概率輸出苞氮。
def create_vit_classifier():
    # 輸入數(shù)據(jù)形狀。
    inputs = layers.Input(shape=input_shape)

    # 數(shù)據(jù)增強瓤逼。
    augmented = data_augmentation(inputs)

    # 將圖片切分成patch并embedding笼吟,
    # 必須使用超參數(shù)patch_size,有兩種方式:
    # 1. 直接卷積霸旗,這種方便贷帮,但靈活性差點
    # projection_1= layers.Conv2D(projection_dim,
    #                        patch_size,
    #                        strides=patch_size,
    #                        padding="valid",
    #                        name="patch_embed.proj")(augmented)
    # projection = layers.Reshape(
    #     ((image_size // patch_size) * (image_size // patch_size),
    #      projection_dim))(projection_1)
    # 2. 實現(xiàn)Patches類,并全連接映射定硝,使用超參數(shù)projection_dim
    patches = Patches(patch_size)(augmented)
    projection = layers.Dense(units=projection_dim)(patches)  # [batch,196,768]

    # 編碼圖像拼接塊皿桑。
    encoded_patches = PatchEncoder()(projection)

    # 創(chuàng)建多個Transformer塊。
    for _ in range(transformer_layers):
        # 第一層歸一化蔬啡。
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)

        #創(chuàng)建多頭注意力層诲侮。
        qkv = layers.Dense(int(projection_dim * 3))(x1)
        attention_output = attention(
            num_heads=num_heads,
            projection_dim=projection_dim,
            # 注意:head_channel=projection_dim//num_heads
            qkv_bias=False,
            atten_drop_rate=0.,
            proj_drop_rate=0.)(qkv)

        # attention_output = layers.MultiHeadAttention(num_heads=num_heads,
        #                                              key_dim=projection_dim,
        #                                              dropout=0.1)(x1, x1)

        # 跳躍連接1
        x2 = layers.Add()([attention_output, encoded_patches])
        # 第二層歸一化
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # 跳躍連接2
        encoded_patches = layers.Add()([x3, x2])

    # 創(chuàng)建一個 [batch_size, projection_dim] 張量。
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)

    # 取出特征圖的類別標簽箱蟆,在第(2)步中我們把類別標簽放在了最前面
    cls_ticks = layers.Lambda(lambda v: v[:, 0],
                              name="ExtractToken")(representation)
    # 全連接層分類
    logits = layers.Dense(num_classes)(cls_ticks)
    soft = layers.Softmax()(logits)
    # 創(chuàng)建Keras模型沟绪。
    model = keras.Model(inputs=inputs, outputs=soft)
    return model


# 9 Compile, train, and evaluate the mode
# 編譯、訓練和評估模型
def run_experiment(model):
    # 定義優(yōu)化器
    optimizer = tf.optimizers.AdamW(learning_rate=learning_rate,
                                    weight_decay=weight_decay)

    # 編譯模型空猜,指定優(yōu)化器和損失函數(shù)绽慈,同時定義評價指標
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5,
                                                        name="top-5-accuracy"),
        ],
    )

    # 設定模型訓練過程中的回調(diào)函數(shù),用于保存模型參數(shù)
    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",  # 監(jiān)控的評價指標
        save_best_only=True,  # 僅保存最好的模型
        save_weights_only=True,  # 僅保存模型參數(shù)
    )

    # 訓練模型
    history = model.fit(
        x=x_train,  # 輸入特征
        y=y_train,  # 輸入標簽
        batch_size=batch_size,  # 批次大小
        epochs=num_epochs,  # 訓練輪數(shù)
        validation_split=0.1,  # 用于驗證的數(shù)據(jù)比例
        callbacks=[checkpoint_callback],  # 回調(diào)函數(shù)列表
    )

    # 加載保存的最優(yōu)模型參數(shù)辈毯,并在測試集上進行評估
    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test,
                                                 y_test)  # 返回損失函數(shù)值坝疼、評價指標值
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    # 返回訓練歷史記錄
    return history


# 創(chuàng)建一個 VIT 分類模型
vit_classifier = create_vit_classifier()
# 運行訓練實驗
history = run_experiment(vit_classifier)


def plot_history(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_history("loss")
plot_history("top-5-accuracy")

原文鏈接:https://blog.csdn.net/dgvv4/article/details/124792386
https://zhuanlan.zhihu.com/p/626375905
https://keras.io/examples/vision/image_classification_with_vision_transformer/

最后編輯于
?著作權歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市谆沃,隨后出現(xiàn)的幾起案子钝凶,更是在濱河造成了極大的恐慌,老刑警劉巖唁影,帶你破解...
    沈念sama閱讀 211,743評論 6 492
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件耕陷,死亡現(xiàn)場離奇詭異,居然都是意外死亡据沈,警方通過查閱死者的電腦和手機哟沫,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,296評論 3 385
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來锌介,“玉大人嗜诀,你說我怎么就攤上這事猾警。” “怎么了裹虫?”我有些...
    開封第一講書人閱讀 157,285評論 0 348
  • 文/不壞的土叔 我叫張陵肿嘲,是天一觀的道長。 經(jīng)常有香客問我筑公,道長雳窟,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 56,485評論 1 283
  • 正文 為了忘掉前任匣屡,我火速辦了婚禮封救,結果婚禮上,老公的妹妹穿的比我還像新娘捣作。我一直安慰自己誉结,他們只是感情好,可當我...
    茶點故事閱讀 65,581評論 6 386
  • 文/花漫 我一把揭開白布券躁。 她就那樣靜靜地躺著惩坑,像睡著了一般。 火紅的嫁衣襯著肌膚如雪也拜。 梳的紋絲不亂的頭發(fā)上以舒,一...
    開封第一講書人閱讀 49,821評論 1 290
  • 那天,我揣著相機與錄音慢哈,去河邊找鬼蔓钟。 笑死,一個胖子當著我的面吹牛卵贱,可吹牛的內(nèi)容都是我干的滥沫。 我是一名探鬼主播,決...
    沈念sama閱讀 38,960評論 3 408
  • 文/蒼蘭香墨 我猛地睜開眼键俱,長吁一口氣:“原來是場噩夢啊……” “哼兰绣!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起编振,我...
    開封第一講書人閱讀 37,719評論 0 266
  • 序言:老撾萬榮一對情侶失蹤狭魂,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后党觅,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,186評論 1 303
  • 正文 獨居荒郊野嶺守林人離奇死亡斋泄,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,516評論 2 327
  • 正文 我和宋清朗相戀三年杯瞻,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片炫掐。...
    茶點故事閱讀 38,650評論 1 340
  • 序言:一個原本活蹦亂跳的男人離奇死亡魁莉,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情旗唁,我是刑警寧澤畦浓,帶...
    沈念sama閱讀 34,329評論 4 330
  • 正文 年R本政府宣布,位于F島的核電站检疫,受9級特大地震影響讶请,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜屎媳,卻給世界環(huán)境...
    茶點故事閱讀 39,936評論 3 313
  • 文/蒙蒙 一夺溢、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧烛谊,春花似錦风响、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,757評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至双泪,卻和暖如春持搜,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背攒读。 一陣腳步聲響...
    開封第一講書人閱讀 31,991評論 1 266
  • 我被黑心中介騙來泰國打工朵诫, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人薄扁。 一個月前我還...
    沈念sama閱讀 46,370評論 2 360
  • 正文 我出身青樓剪返,卻偏偏與公主長得像,于是被迫代替她去往敵國和親邓梅。 傳聞我的和親對象是個殘疾皇子脱盲,可洞房花燭夜當晚...
    茶點故事閱讀 43,527評論 2 349

推薦閱讀更多精彩內(nèi)容