Chapter 42: Classifying the MNIST Dataset with a VGG Model

經(jīng)過前面兩章的介紹跋理,現(xiàn)在著手開始實現(xiàn)VGG模型進行MNIST分類任務的JAX代碼實戰(zhàn)绷跑。

數(shù)據(jù)準備

Dataset handling is wrapped in a small helper module. It downloads the raw MNIST files automatically and provides functions for parsing them and splitting them into training and test sets. The code for MnistDatasets.py is as follows:


import array
import gzip
import os
import ssl
import struct
import urllib.request
import jax
import jax.numpy

from os import path
from tqdm import tqdm

data_dir = "/tmp/JAX/Shares/Datasets/MNIST/"

def _download(url, name):

    """

    Download a URL to a file in the JAX data temporary directory

    """

    if not path.exists(data_dir):

        os.makedirs(data_dir)

    out_file = path.join(data_dir, name)

    if not path.isfile(out_file):

        ssl._create_default_https_context = ssl._create_unverified_context

        with tqdm(unit = "B", unit_scale = True, unit_divisor = 1024, miniters = 1, desc = name) as bar:

            urllib.request.urlretrieve(url, out_file, reporthook = report_hook(bar))

        print(f"Downloaded {url} to {data_dir}")

def report_hook(bar: tqdm):

    """

    Progress Bar of tqdm for downloads

    """

    def hook(block_counter = 0, block_size = 1, total_size = None):

        if total_size is not None:

            bar.total = total_size

        bar.update(block_counter * block_size - bar.n)

    return hook

def mnist_raw():

    """

    Download and parse the raw MNIST dataset.

    """

    # CVDF mirror of http://yann.lecun.com/exdb/mnist/
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

    def parse_labels(file):

        with gzip.open(file, "rb") as handler:

            _ = struct.unpack(">II", handler.read(8))

            return jax.numpy.array(array.array("B", handler.read()), dtype = jax.numpy.uint8)

    def parse_images(file):

        with gzip.open(file, "rb") as handler:

            _, number, rows, columns = struct.unpack(">IIII", handler.read(16))

            return jax.numpy.array(array.array("B", handler.read()), dtype = jax.numpy.uint8).reshape(number, rows, columns)

    for name in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]:

        file_path = path.join(data_dir, name)

        if not path.exists(file_path):
            _download(base_url + name, name)

    train_images = parse_images(path.join(data_dir, "train-images-idx3-ubyte.gz"))
    train_labels = parse_labels(path.join(data_dir, "train-labels-idx1-ubyte.gz"))
    test_images = parse_images(path.join(data_dir, "t10k-images-idx3-ubyte.gz"))
    test_labels = parse_labels(path.join(data_dir, "t10k-labels-idx1-ubyte.gz"))

    return train_images, train_labels, test_images, test_labels

def mnist(permute_train = False):

    """

    Download and parse the MNIST data, optionally shuffling the training set

    """

    train_images, train_labels, test_images, test_labels = mnist_raw()

    if permute_train:

        # jax.random.permutation requires an explicit PRNG key
        permutation = jax.random.permutation(jax.random.PRNGKey(0), train_images.shape[0])

        train_images = train_images[permutation]
        train_labels = train_labels[permutation]

    return train_images, train_labels, test_images, test_labels
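
As a quick sanity check that the module works, the following minimal sketch (not part of the chapter's files) downloads the data on first use and prints the raw array shapes:

import MnistDatasets

train_images, train_labels, test_images, test_labels = MnistDatasets.mnist()

print(train_images.shape, train_labels.shape)   # (60000, 28, 28) (60000,)
print(test_images.shape, test_labels.shape)     # (10000, 28, 28) (10000,)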


The code for VisualGeometryGroupMnist.py is as follows:

import jax
import tensorflow_datasets as tfds
import time

import MnistDatasets

def setup():

   train_images, train_labels, test_images, test_labels = MnistDatasets.mnist()
   
   batch_size = 600
   
   inputs_channels = 1
   epochs = 9
   
   prng = jax.random.PRNGKey(15)
   
   kernel_shapes = [
        [3, 3, 1, 16],
        [3, 3, 16, 32],
        [3, 3, 32, 48],
        [3, 3, 48, 64],
        [50176, 128],
        [128, 10]
    ]
   
   return (train_images, train_labels), (test_images, test_labels), (batch_size, epochs, prng, kernel_shapes)

def one_hot(inputs, length = 10, dtype = jax.numpy.float32):
    
    matches = jax.numpy.array(inputs[:, None] == jax.numpy.arange(length), dtype)
    
    return matches

def partial_flatten(inputs):

    """

    Add a trailing channel dimension and scale pixel values to [0, 1]

    jax.lax.expand_dims(inputs, [-1]): [60000, 28, 28] -> [60000, 28, 28, 1]
    jax.lax.expand_dims(inputs, [1, 2]): [60000, 28, 28] -> [60000, 1, 1, 28, 28]

    """
    inputs = jax.lax.expand_dims(inputs, [-1])  # [60000, 28, 28] -> [60000, 28, 28, 1]

    return inputs / 255.
   
def test():

    (train_images, train_labels), (test_images, test_labels), (batch_size, epochs, prng, kernel_shapes) = setup()
    
    print((train_images.shape, train_labels.shape), (test_images.shape, test_labels.shape), (batch_size, epochs, prng, kernel_shapes))

Running test() prints the following output:


((60000, 28, 28), (60000,)) ((10000, 28, 28), (10000,)) (600, 9, Array([ 0, 15], dtype=uint32), [[3, 3, 1, 16], [3, 3, 16, 32], [3, 3, 32, 48], [3, 3, 48, 64], [50176, 128], [128, 10]])

The printed output gives a quick preview of how the training and test sets are structured.
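
To see what the preprocessing helpers do to these arrays, the following minimal sketch (assuming one_hot and partial_flatten from the listing above are in scope) prints the transformed shapes:

import MnistDatasets

train_images, train_labels, _, _ = MnistDatasets.mnist()

print(one_hot(train_labels).shape)          # (60000, 10): labels become one-hot vectors
print(partial_flatten(train_images).shape)  # (60000, 28, 28, 1): a channel axis is added, pixels scaled to [0, 1]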

Implementing the Model

The previous chapter introduced and implemented several of VGG's building blocks, which we can reuse directly here.

Convolutional Layer


def conv(inputs, kernel, window_strides = 1):
    
    shape = inputs.shape
    dimension_numbers = jax.lax.conv_dimension_numbers(lhs_shape = shape, rhs_shape = kernel["weight"].shape, dimension_numbers = ("NHWC", "HWIO", "NHWC"))
    
    inputs = jax.lax.conv_general_dilated(inputs, kernel["weight"], window_strides = [window_strides, window_strides], padding = "SAME", dimension_numbers = dimension_numbers)
    inputs = jax.nn.selu(inputs)
    
    return inputs
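
As a quick check of this layer, the sketch below (assuming the conv function above is in scope) pushes a dummy NHWC batch through a single 3 x 3 kernel; with "SAME" padding and stride 1 the spatial size stays 28 x 28 and only the channel count changes:

import jax

prng = jax.random.PRNGKey(0)
images = jax.random.normal(prng, shape = (4, 28, 28, 1))             # NHWC dummy batch
kernel = {"weight": jax.random.normal(prng, shape = (3, 3, 1, 16))}  # HWIO kernel

print(conv(images, kernel).shape)   # (4, 28, 28, 16)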

Forward-Pass Function


@jax.jit
def forward(parameters, inputs):

    for i in range(len(parameters) - 2):

        inputs = conv(inputs, kernel = parameters[i])

    # Flatten: four SAME-padded, stride-1 convolutions keep 28 x 28, and the last kernel has
    # 64 output channels, so each example becomes 28 * 28 * 64 = 50176 features
    inputs = jax.numpy.reshape(inputs, (inputs.shape[0], 50176))

    for i in range(len(parameters) - 2, len(parameters) - 1):

        inputs = jax.numpy.matmul(inputs, parameters[i]["weight"]) + parameters[i]["bias"]
        inputs = jax.nn.selu(inputs)

    inputs = jax.numpy.matmul(inputs, parameters[-1]["weight"]) + parameters[-1]["bias"]
    inputs = jax.nn.softmax(inputs, axis = -1)

    return inputs
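
A minimal shape check of the forward pass is sketched below; it assumes the conv and forward functions above and the init_mlp_params helper from the complete listing in the next section are in scope:

import jax

prng = jax.random.PRNGKey(15)
kernel_shapes = [[3, 3, 1, 16], [3, 3, 16, 32], [3, 3, 32, 48], [3, 3, 48, 64], [50176, 128], [128, 10]]
params = init_mlp_params(kernel_shapes, prng)

dummy_images = jax.random.normal(prng, shape = (4, 28, 28, 1))
print(forward(params, dummy_images).shape)   # (4, 10): one probability per class for each example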

Prediction Model and Training

Parameter initialization, the loss function, and the optimizer were covered in earlier chapters and are not repeated here. The complete training code is shown below.


import jax
import tensorflow_datasets as tfds
import time

import MnistDatasets

def setup():

   train_images, train_labels, test_images, test_labels = MnistDatasets.mnist()
   
   batch_size = 600
   
   inputs_channels = 1
   epochs = 9
   
   prng = jax.random.PRNGKey(15)
   
   kernel_shapes = [
        [3, 3, 1, 16],
        [3, 3, 16, 32],
        [3, 3, 32, 48],
        [3, 3, 48, 64],
        [50176, 128],
        [128, 10]
    ]
   
   return (train_images, train_labels), (test_images, test_labels), (batch_size, epochs, prng, kernel_shapes)

def one_hot(inputs, length = 10, dtype = jax.numpy.float32):
    
    matches = jax.numpy.array(inputs[:, None] == jax.numpy.arange(length), dtype)
    
    return matches

def partial_flatten(inputs):

    """

    Add a trailing channel dimension and scale pixel values to [0, 1]

    jax.lax.expand_dims(inputs, [-1]): [60000, 28, 28] -> [60000, 28, 28, 1]
    jax.lax.expand_dims(inputs, [1, 2]): [60000, 28, 28] -> [60000, 1, 1, 28, 28]

    """
    inputs = jax.lax.expand_dims(inputs, [-1])  # [60000, 28, 28] -> [60000, 28, 28, 1]

    return inputs / 255.

def init_mlp_params(shapes, prng):
    
    params = []
    
    # Create the kernels for the 4 convolutional layers (the first len(shapes) - 2 entries)
    for i in range(len(shapes) - 2):
        
        weights = jax.random.normal(key = prng, shape = shapes[i]) / jax.numpy.sqrt(28. * 28.)
        
        _dict = dict(weight = weights)
        
        params.append(_dict)
         
    # Create the weights and biases for the 2 dense layers
    for i in range(len(shapes) - 2, len(shapes)):
        
        weights = jax.random.normal(key = prng, shape = shapes[i]) / jax.numpy.sqrt(28. * 28.)
        biases = jax.random.normal(key = prng, shape = (shapes[i][-1],)) / jax.numpy.sqrt(28. * 28.)
        
        _dict = dict(weight = weights, bias = biases)
        
        params.append(_dict)
        
    return params

def conv(inputs, kernel, window_strides = 1):
    
    shape = inputs.shape
    dimension_numbers = jax.lax.conv_dimension_numbers(lhs_shape = shape, rhs_shape = kernel["weight"].shape, dimension_numbers = ("NHWC", "HWIO", "NHWC"))
    
    inputs = jax.lax.conv_general_dilated(inputs, kernel["weight"], window_strides = [window_strides, window_strides], padding = "SAME", dimension_numbers = dimension_numbers)
    inputs = jax.nn.selu(inputs)
    
    return inputs

@jax.jit
def forward(parameters, inputs):

    for i in range(len(parameters) - 2):

        inputs = conv(inputs, kernel = parameters[i])

    # Flatten: four SAME-padded, stride-1 convolutions keep 28 x 28, and the last kernel has
    # 64 output channels, so each example becomes 28 * 28 * 64 = 50176 features
    inputs = jax.numpy.reshape(inputs, (inputs.shape[0], 50176))

    for i in range(len(parameters) - 2, len(parameters) - 1):

        inputs = jax.numpy.matmul(inputs, parameters[i]["weight"]) + parameters[i]["bias"]
        inputs = jax.nn.selu(inputs)

    inputs = jax.numpy.matmul(inputs, parameters[-1]["weight"]) + parameters[-1]["bias"]
    inputs = jax.nn.softmax(inputs, axis = -1)

    return inputs

@jax.jit
def cross_entropy(genuines, predictions):
    
    # Binary cross entropy against the one-hot targets; clipping keeps log() away from 0 and 1
    entropys = genuines * jax.numpy.log(jax.numpy.clip(predictions, 1e-9, 0.999)) + (1 - genuines) * jax.numpy.log(jax.numpy.clip(1 - predictions, 1e-9, 0.999))
    entropys = jax.numpy.sum(entropys, axis = 1)
    # Negate so that minimizing the loss maximizes the log-likelihood
    entropys = -jax.numpy.mean(entropys)
    
    return entropys

@jax.jit
def loss_function(parameters, inputs, genuines):

    predictions = forward(parameters, inputs)
    entropys = cross_entropy(genuines, predictions)

    return entropys

@jax.jit
def optimizer_function(parameters, inputs, genuines, learning_rate = 1e-3):

    grad_loss_function = jax.grad(loss_function)
    gradients = grad_loss_function(parameters, inputs, genuines)

    new_parameters = jax.tree_util.tree_map(lambda parameter, gradient: parameter - learning_rate * gradient, parameters, gradients)

    return new_parameters

@jax.jit
def verify_accuracy(params, inputs, targets):
    
    """
    Correct predictions over a mini batch
    """
    predictions = forward(params, inputs)
    _class = jax.numpy.argmax(predictions, axis = 1)
    targets = jax.numpy.argmax(targets, axis = 1)
    
    return jax.numpy.sum(_class == targets)

def train():
    
    (train_images, train_labels), (test_images, test_labels), (batch_size, epochs, prng, kernel_shapes) = setup()
    
    print(f"train_images.shape = {train_images.shape}, train_labels.shape = {train_labels.shape}), (test_images.shape = {test_images.shape}, test_labels.shape = {test_labels.shape}")
    
    '''
    train_images.shape = (60000, 28, 28), train_labels.shape = (60000,)), (test_images.shape = (10000, 28, 28), test_labels.shape = (10000,)
    '''
    
    train_images = partial_flatten(train_images)
    train_labels = one_hot(train_labels)
    
    test_images = partial_flatten(test_images)
    test_labels = one_hot(test_labels)
    
    params = init_mlp_params(kernel_shapes, prng)
    
    begin = time.time()
    
    for i in range(epochs):
        
        batch_number = train_images.shape[0] // batch_size
                
        for j in range(batch_number):
            
            start = batch_size * j
            stop = batch_size * (j + 1)
            
            images_batch = train_images[start: stop]
            labels_batch = train_labels[start: stop]
            
            params = optimizer_function(params, images_batch, labels_batch)
            
            print(f"Bacth number {j + 1}/{batch_number} within epoch {i + 1}/{epochs} is completed")
            
        if (i + 1) % 2 == 0:
            
            loss = loss_function(params, train_images, train_labels)
            
            end = time.time()
            
            # Accuracy is averaged over the full 10,000-example test set
            accuracies = verify_accuracy(params, test_images, test_labels) / float(test_images.shape[0])
            
            print(f"Now the loss is {loss}, accuracy is {accuracies} after {i + 1} epochs, {end - begin:.2f} seconds elapsed")
            
if __name__ == "__main__":
    
    train()

The printed output during training is as follows:


…
Batch number 95/100 within epoch 1/9 is completed
Batch number 96/100 within epoch 1/9 is completed
Batch number 97/100 within epoch 1/9 is completed
Batch number 98/100 within epoch 1/9 is completed
Batch number 99/100 within epoch 1/9 is completed
Batch number 100/100 within epoch 1/9 is completed
Batch number 1/100 within epoch 2/9 is completed
Batch number 2/100 within epoch 2/9 is completed
Batch number 3/100 within epoch 2/9 is completed
Batch number 4/100 within epoch 2/9 is completed
…

As you can see, after the nine training epochs the model reaches a good level of accuracy, a large improvement over the fully connected classifiers built in the earlier chapters.

Conclusion

VGG is one of the most classic convolutional neural network classification models and still plays an important role in many fields today. Make sure you understand and can reproduce the design and training of the VGG model completed in this chapter.

In addition, make sure you understand convolution in JAX: for deep learning, convolution is the most widely used data-processing and feature-extraction technique in computer vision, parts of natural language processing, and reinforcement learning.
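
To close, here is a minimal, self-contained sketch of jax.lax.conv_general_dilated on its own (a toy example, separate from the chapter's code): a 5 x 5 single-channel image convolved with a 3 x 3 averaging kernel in the NHWC/HWIO layout used throughout this chapter.

import jax
import jax.numpy

image = jax.numpy.ones((1, 5, 5, 1))              # NHWC input
kernel = jax.numpy.full((3, 3, 1, 1), 1. / 9.)    # HWIO averaging filter

dimension_numbers = jax.lax.conv_dimension_numbers(image.shape, kernel.shape, ("NHWC", "HWIO", "NHWC"))
outputs = jax.lax.conv_general_dilated(image, kernel, window_strides = [1, 1], padding = "SAME", dimension_numbers = dimension_numbers)

print(outputs.shape)   # (1, 5, 5, 1): SAME padding with stride 1 preserves the spatial size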
