第42章 使用VGG模型實現MNIST數據集分類




import array
import gzip
import os
import ssl
import struct
import urllib.request
import jax.numpy

from os import path
from tqdm import tqdm

data_dir = "/tmp/JAX/Shares/Datasets/MNIST/"

def _download(url, name):
    """
    Download a URL to a file in the JAX data temporary directory.

    The file is stored as ``data_dir/name``. Nothing is fetched when that
    file already exists.

    Args:
        url: Address of the resource to fetch.
        name: File name to store the download under; also used as the label
            of the progress bar.
    """
    # urlretrieve cannot write into a directory that does not exist yet.
    os.makedirs(data_dir, exist_ok = True)

    out_file = path.join(data_dir, name)

    if not path.isfile(out_file):

        # NOTE(review): this turns off HTTPS certificate verification for the
        # whole process, not only for this request. Kept because the original
        # code relied on it; confirm it is still needed.
        ssl._create_default_https_context = ssl._create_unverified_context

        # Fetch into a scratch file first so that an interrupted transfer is
        # never mistaken for a complete download on the next run.
        partial_file = out_file + ".part"

        with tqdm(unit = "B", unit_scale = True, unit_divisor = 1024, miniters = 1, desc = name) as bar:

            urllib.request.urlretrieve(url, partial_file, reporthook = report_hook(bar))

        os.replace(partial_file, out_file)

        print(f"Downloaded {url} to {data_dir}")

def report_hook(bar: "tqdm"):
    """
    Build a ``urlretrieve`` report hook that drives a tqdm progress bar.

    Args:
        bar: Progress bar to update. The hook reads ``bar.n``, assigns
            ``bar.total`` and calls ``bar.update``.

    Returns:
        A callable ``hook(block_counter, block_size, total_size)`` suitable
        for the ``reporthook`` argument of ``urllib.request.urlretrieve``.
    """
    def hook(block_counter = 0, block_size = 1, total_size = None):

        if total_size is not None:

            bar.total = total_size

        # update() expects an increment, so subtract what is already counted
        # from the absolute number of bytes transferred so far.
        bar.update(block_counter * block_size - bar.n)

    return hook

def mnist_raw():
    """
    Download and parse the raw MNIST dataset.

    Any of the four archives that is not yet present in ``data_dir`` is
    downloaded first.

    Returns:
        A tuple ``(train_images, train_labels, test_images, test_labels)`` of
        uint8 arrays. Images are reshaped to ``(number, rows, columns)`` using
        the sizes read from the file header; labels are one-dimensional.
    """
    # CVDF mirror of http://yann.lecun.com/exdb/mnist/
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

    def parse_labels(file):

        with gzip.open(file, "rb") as handler:

            # Skip the 8-byte header (two big-endian unsigned integers).
            _ = struct.unpack(">II", handler.read(8))

            return jax.numpy.array(array.array("B", handler.read()), dtype = jax.numpy.uint8)

    def parse_images(file):

        with gzip.open(file, "rb") as handler:

            # The 16-byte header holds four big-endian unsigned integers; the
            # last three give the image count and the size of each image.
            _, number, rows, columns = struct.unpack(">IIII", handler.read(16))

            return jax.numpy.array(array.array("B", handler.read()), dtype = jax.numpy.uint8).reshape(number, rows, columns)

    archives = (
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    )

    for name in archives:

        local_file = path.join(data_dir, name)

        if not path.exists(local_file):
            _download(base_url + name, name)

    train_images = parse_images(path.join(data_dir, "train-images-idx3-ubyte.gz"))
    train_labels = parse_labels(path.join(data_dir, "train-labels-idx1-ubyte.gz"))
    test_images = parse_images(path.join(data_dir, "t10k-images-idx3-ubyte.gz"))
    test_labels = parse_labels(path.join(data_dir, "t10k-labels-idx1-ubyte.gz"))

    return train_images, train_labels, test_images, test_labels

def mnist(permute_train = False, seed = 0):
    """
    Download and parse the MNIST data, optionally shuffling the training set.

    The arrays are returned exactly as mnist_raw() produces them: no scaling
    and no one-hot encoding is applied here.

    Args:
        permute_train: When true, apply one random permutation to both the
            training images and the training labels.
        seed: Seed of the PRNG key used for that permutation.

    Returns:
        A tuple ``(train_images, train_labels, test_images, test_labels)``.
    """
    train_images, train_labels, test_images, test_labels = mnist_raw()

    if permute_train:

        # jax.random.permutation takes an explicit PRNG key as its first
        # argument; the same index order is used for images and labels so
        # that they stay aligned.
        key = jax.random.PRNGKey(seed)
        permutation = jax.random.permutation(key, train_images.shape[0])

        train_images = train_images[permutation]
        train_labels = train_labels[permutation]

    return train_images, train_labels, test_images, test_labels


import jax
import tensorflow_datasets as tfds
import time

import MnistDatasets

def setup():
    """
    Load MNIST and define the hyper-parameters of the training run.

    Returns:
        Three tuples: ``(train_images, train_labels)``,
        ``(test_images, test_labels)`` and
        ``(batch_size, epochs, prng, kernel_shapes)``.
    """
    train_images, train_labels, test_images, test_labels = MnistDatasets.mnist()
    batch_size = 600
    epochs = 9
    prng = jax.random.PRNGKey(15)
    # Four 3x3 convolution kernels followed by two dense layers.
    # 50176 = 28 * 28 * 64, the flattened size of a 28x28 map with 64 channels.
    kernel_shapes = [
        [3, 3, 1, 16],
        [3, 3, 16, 32],
        [3, 3, 32, 48],
        [3, 3, 48, 64],
        [50176, 128],
        [128, 10],
    ]
    return (train_images, train_labels), (test_images, test_labels), (batch_size, epochs, prng, kernel_shapes)

def one_hot(inputs, length = 10, dtype = jax.numpy.float32):
    """
    Encode a one-dimensional array of class indices as one-hot rows.

    Args:
        inputs: Array of class indices; it is indexed as ``inputs[:, None]``.
        length: Number of classes, i.e. the width of each encoded row.
        dtype: Element type of the returned array.

    Returns:
        An array with one row per input element, holding 1 in the column
        equal to the class index and 0 elsewhere.
    """
    classes = jax.numpy.arange(length)
    column = inputs[:, None]
    # Broadcasting the column against the class row yields a boolean mask.
    return jax.numpy.array(column == classes, dtype)

def partial_flatten(inputs):
    """
    Append a trailing channel axis and divide the values by 255.

    Despite its name, this function does not flatten anything: it adds one
    axis of size 1 at the end and rescales the values.

    jax.lax.expand_dims(inputs, [-1]): [60000, 28, 28] -> [60000, 28, 28, 1]
    jax.lax.expand_dims(inputs, [1, 2]): [60000, 28, 28] -> [60000, 1, 1, 28, 28]

    Args:
        inputs: Array of pixel values.

    Returns:
        The input with one extra trailing axis, divided by 255.
    """
    inputs = jax.lax.expand_dims(inputs, [-1])  # [60000, 28, 28] -> [60000, 28, 28, 1]

    return inputs / 255.
def test():
    """
    Print the dataset shapes and the hyper-parameters returned by setup().
    """
    train_set, test_set, hyper_parameters = setup()

    train_shapes = (train_set[0].shape, train_set[1].shape)
    test_shapes = (test_set[0].shape, test_set[1].shape)

    print(train_shapes, test_shapes, hyper_parameters)


((60000, 28, 28), (60000,)) ((10000, 28, 28), (10000,)) (600, 9, Array([ 0, 15], dtype=uint32), [[3, 3, 1, 16], [3, 3, 16, 32], [3, 3, 32, 48], [3, 3, 48, 64], [50176, 128], [128, 10]])





def conv(inputs, kernel, window_strides = 1):
    """
    Apply a SAME-padded 2-D convolution followed by a SELU activation.

    Args:
        inputs: Input batch in NHWC layout.
        kernel: Mapping whose "weight" entry is the kernel in HWIO layout.
        window_strides: Stride used along both spatial axes.

    Returns:
        The activated convolution result in NHWC layout.
    """
    weight = kernel["weight"]
    layout = jax.lax.conv_dimension_numbers(
        lhs_shape = inputs.shape,
        rhs_shape = weight.shape,
        dimension_numbers = ("NHWC", "HWIO", "NHWC"),
    )
    strides = [window_strides, window_strides]
    convolved = jax.lax.conv_general_dilated(
        inputs,
        weight,
        window_strides = strides,
        padding = "SAME",
        dimension_numbers = layout,
    )
    return jax.nn.selu(convolved)


def forward(parameters, inputs):
    """
    Run the network: convolutions, one hidden dense layer, softmax output.

    Args:
        parameters: Sequence of layer dicts. All entries but the last two are
            convolution kernels passed to conv(); the last two are dense
            layers with "weight" and "bias" entries. At least two dense
            layers are expected.
        inputs: Input batch; the first axis is the batch axis.

    Returns:
        Class probabilities, normalised along the last axis.
    """
    for kernel in parameters[:-2]:

        inputs = conv(inputs, kernel = kernel)

    # Flatten everything except the batch axis. Using -1 lets reshape infer
    # the feature count, so the network is not tied to 28 * 28 * 64 = 50176
    # features; the shape is passed positionally because the `newshape`
    # keyword of jax.numpy.reshape is deprecated.
    inputs = jax.numpy.reshape(inputs, (inputs.shape[0], -1))

    for layer in parameters[-2:-1]:

        inputs = jax.numpy.matmul(inputs, layer["weight"]) + layer["bias"]
        inputs = jax.nn.selu(inputs)

    inputs = jax.numpy.matmul(inputs, parameters[-1]["weight"]) + parameters[-1]["bias"]
    inputs = jax.nn.softmax(inputs, axis = -1)

    return inputs



import jax
import tensorflow_datasets as tfds
import time

import MnistDatasets

def setup():
    """
    Load MNIST and define the hyper-parameters of the training run.

    Returns:
        Three tuples: ``(train_images, train_labels)``,
        ``(test_images, test_labels)`` and
        ``(batch_size, epochs, prng, kernel_shapes)``.
    """
    train_images, train_labels, test_images, test_labels = MnistDatasets.mnist()
    batch_size = 600
    epochs = 9
    prng = jax.random.PRNGKey(15)
    # Four 3x3 convolution kernels followed by two dense layers.
    # 50176 = 28 * 28 * 64, the flattened size of a 28x28 map with 64 channels.
    kernel_shapes = [
        [3, 3, 1, 16],
        [3, 3, 16, 32],
        [3, 3, 32, 48],
        [3, 3, 48, 64],
        [50176, 128],
        [128, 10],
    ]
    return (train_images, train_labels), (test_images, test_labels), (batch_size, epochs, prng, kernel_shapes)

def one_hot(inputs, length = 10, dtype = jax.numpy.float32):
    """
    Encode a one-dimensional array of class indices as one-hot rows.

    Args:
        inputs: Array of class indices; it is indexed as ``inputs[:, None]``.
        length: Number of classes, i.e. the width of each encoded row.
        dtype: Element type of the returned array.

    Returns:
        An array with one row per input element, holding 1 in the column
        equal to the class index and 0 elsewhere.
    """
    classes = jax.numpy.arange(length)
    column = inputs[:, None]
    # Broadcasting the column against the class row yields a boolean mask.
    return jax.numpy.array(column == classes, dtype)

def partial_flatten(inputs):
    """
    Append a trailing channel axis and divide the values by 255.

    Despite its name, this function does not flatten anything: it adds one
    axis of size 1 at the end and rescales the values.

    jax.lax.expand_dims(inputs, [-1]): [60000, 28, 28] -> [60000, 28, 28, 1]
    jax.lax.expand_dims(inputs, [1, 2]): [60000, 28, 28] -> [60000, 1, 1, 28, 28]

    Args:
        inputs: Array of pixel values.

    Returns:
        The input with one extra trailing axis, divided by 255.
    """
    inputs = jax.lax.expand_dims(inputs, [-1])  # [60000, 28, 28] -> [60000, 28, 28, 1]

    return inputs / 255.

def init_mlp_params(shapes, prng):
    """
    Create randomly initialised parameters for every layer of the network.

    Args:
        shapes: Sequence of layer shapes. All entries but the last two
            describe convolution kernels; the last two describe dense layers.
        prng: JAX PRNG key the random draws are derived from.

    Returns:
        A list with one dict per entry of ``shapes``, in the same order.
        Convolution layers hold a "weight" entry; dense layers hold "weight"
        and "bias" entries, the bias having the size of the last dimension of
        the layer shape.
    """
    scale = jax.numpy.sqrt(28. * 28.)
    dense_start = len(shapes) - 2
    params = []

    for index, shape in enumerate(shapes):
        # Derive fresh keys for every draw; reusing a single key would make
        # every layer (and each bias) a copy of the same random stream.
        prng, weight_key, bias_key = jax.random.split(prng, 3)

        layer = dict(weight = jax.random.normal(key = weight_key, shape = tuple(shape)) / scale)

        # Only the trailing dense layers carry a bias term.
        if index >= dense_start:
            layer["bias"] = jax.random.normal(key = bias_key, shape = (shape[-1],)) / scale

        params.append(layer)

    return params

def conv(inputs, kernel, window_strides = 1):
    """
    Apply a SAME-padded 2-D convolution followed by a SELU activation.

    Args:
        inputs: Input batch in NHWC layout.
        kernel: Mapping whose "weight" entry is the kernel in HWIO layout.
        window_strides: Stride used along both spatial axes.

    Returns:
        The activated convolution result in NHWC layout.
    """
    weight = kernel["weight"]
    layout = jax.lax.conv_dimension_numbers(
        lhs_shape = inputs.shape,
        rhs_shape = weight.shape,
        dimension_numbers = ("NHWC", "HWIO", "NHWC"),
    )
    strides = [window_strides, window_strides]
    convolved = jax.lax.conv_general_dilated(
        inputs,
        weight,
        window_strides = strides,
        padding = "SAME",
        dimension_numbers = layout,
    )
    return jax.nn.selu(convolved)

def forward(parameters, inputs):
    """
    Run the network: convolutions, one hidden dense layer, softmax output.

    Args:
        parameters: Sequence of layer dicts. All entries but the last two are
            convolution kernels passed to conv(); the last two are dense
            layers with "weight" and "bias" entries. At least two dense
            layers are expected.
        inputs: Input batch; the first axis is the batch axis.

    Returns:
        Class probabilities, normalised along the last axis.
    """
    for kernel in parameters[:-2]:

        inputs = conv(inputs, kernel = kernel)

    # Flatten everything except the batch axis. Using -1 lets reshape infer
    # the feature count, so the network is not tied to 28 * 28 * 64 = 50176
    # features; the shape is passed positionally because the `newshape`
    # keyword of jax.numpy.reshape is deprecated.
    inputs = jax.numpy.reshape(inputs, (inputs.shape[0], -1))

    for layer in parameters[-2:-1]:

        inputs = jax.numpy.matmul(inputs, layer["weight"]) + layer["bias"]
        inputs = jax.nn.selu(inputs)

    inputs = jax.numpy.matmul(inputs, parameters[-1]["weight"]) + parameters[-1]["bias"]
    inputs = jax.nn.softmax(inputs, axis = -1)

    return inputs

def cross_entropy(genuines, predictions):
    """
    Mean element-wise binary cross-entropy between targets and predictions.

    Args:
        genuines: Target values, one row per sample.
        predictions: Predicted probabilities with the same shape.

    Returns:
        A scalar: the per-sample cross-entropy summed over axis 1 and then
        averaged over the samples. Smaller is better.
    """
    # Clip before taking the logarithm so that log(0) can never occur.
    log_hits = jax.numpy.log(jax.numpy.clip(predictions, 1e-9, 0.999))
    log_misses = jax.numpy.log(jax.numpy.clip(1 - predictions, 1e-9, 0.999))

    # The leading minus turns the log-likelihood into a loss; without it a
    # gradient-descent step would push the predictions away from the targets.
    entropys = -(genuines * log_hits + (1 - genuines) * log_misses)
    entropys = jax.numpy.sum(entropys, axis = 1)
    entropys = jax.numpy.mean(entropys)
    return entropys

def loss_function(parameters, inputs, genuines):
    """
    Evaluate the network on a batch and score it with cross_entropy().

    Args:
        parameters: Network parameters passed to forward().
        inputs: Input batch passed to forward().
        genuines: Target values the predictions are compared against.

    Returns:
        The value returned by cross_entropy() for this batch.
    """
    return cross_entropy(genuines, forward(parameters, inputs))

def optimizer_function(parameters, inputs, genuines, learning_rate = 1e-3):
    """
    Take one plain gradient-descent step on loss_function().

    Args:
        parameters: Current network parameters (a pytree).
        inputs: Input batch.
        genuines: Target values for the batch.
        learning_rate: Step size applied to every gradient.

    Returns:
        A pytree with the same structure as ``parameters`` in which every
        leaf has been moved against its gradient.
    """
    # jax.grad differentiates with respect to the first argument.
    gradients = jax.grad(loss_function)(parameters, inputs, genuines)

    def descend(parameter, gradient):
        return parameter - learning_rate * gradient

    return jax.tree_util.tree_map(descend, parameters, gradients)

def verify_accuracy(params, inputs, targets):
    """
    Count the correct predictions over a mini batch.

    Args:
        params: Network parameters passed to forward().
        inputs: Input batch.
        targets: One-hot encoded targets; the class is taken as the argmax
            along axis 1.

    Returns:
        The number of samples whose predicted class equals the target class.
        This is a count, not a ratio.
    """
    predictions = forward(params, inputs)
    _class = jax.numpy.argmax(predictions, axis = 1)
    targets = jax.numpy.argmax(targets, axis = 1)
    return jax.numpy.sum(_class == targets)

def train():
    """
    Train the network on MNIST with mini-batch gradient descent.

    Loads the data through setup(), scales the images and one-hot encodes the
    labels, then runs ``epochs`` passes over the training set. A progress line
    is printed after every batch; the training loss and the test accuracy are
    printed after every second epoch.
    """
    (train_images, train_labels), (test_images, test_labels), (batch_size, epochs, prng, kernel_shapes) = setup()
    print(f"train_images.shape = {train_images.shape}, train_labels.shape = {train_labels.shape}), (test_images.shape = {test_images.shape}, test_labels.shape = {test_labels.shape}")
    # Example output of the line above:
    # train_images.shape = (60000, 28, 28), train_labels.shape = (60000,)), (test_images.shape = (10000, 28, 28), test_labels.shape = (10000,)
    train_images = partial_flatten(train_images)
    train_labels = one_hot(train_labels)
    test_images = partial_flatten(test_images)
    test_labels = one_hot(test_labels)
    params = init_mlp_params(kernel_shapes, prng)
    # Samples beyond the last full batch are not visited.
    batch_number = train_images.shape[0] // batch_size
    for i in range(epochs):
        for j in range(batch_number):
            start = batch_size * j
            stop = batch_size * (j + 1)
            images_batch = train_images[start: stop]
            labels_batch = train_labels[start: stop]
            params = optimizer_function(params, images_batch, labels_batch)
            print(f"Bacth number {j + 1}/{batch_number} within epoch {i + 1}/{epochs} is completed")
        if (i + 1) % 2 == 0:
            loss = loss_function(params, train_images, train_labels)
            # verify_accuracy() returns a count, so divide by the number of
            # test samples to obtain a ratio.
            accuracies = verify_accuracy(params, test_images, test_labels) / test_images.shape[0]
            print(f"Now the loss is {loss}, accuracy is {accuracies} after {i + 1} iterations")
# Run the training loop only when this file is executed as a script.
if __name__ == "__main__":
    train()


Bacth number 95/100 within epoch 1/9 is completed
Bacth number 96/100 within epoch 1/9 is completed
Bacth number 97/100 within epoch 1/9 is completed
Bacth number 98/100 within epoch 1/9 is completed
Bacth number 99/100 within epoch 1/9 is completed
Bacth number 100/100 within epoch 1/9 is completed
Bacth number 1/100 within epoch 2/9 is completed
Bacth number 2/100 within epoch 2/9 is completed
Bacth number 3/100 within epoch 2/9 is completed
Bacth number 4/100 within epoch 2/9 is completed





