Building on the previous two chapters, we now put theory into practice and implement a VGG model for the MNIST classification task in JAX.
Data Preparation
The dataset is managed by a small helper module that downloads the raw MNIST files, caches them locally, and parses them into training and test splits. The code of MnistDatasets.py is as follows:
import array
import gzip
import os
import ssl
import struct
import urllib.request
import jax
import jax.numpy
from os import path
from tqdm import tqdm
data_dir = "/tmp/JAX/Shares/Datasets/MNIST/"
def _download(url, name):
    """
    Download a URL to a file in the JAX data temporary directory
    """
    if not path.exists(data_dir):
        os.makedirs(data_dir)
    out_file = path.join(data_dir, name)
    if not path.isfile(out_file):
        # Skip certificate verification so the download also works behind proxies.
        ssl._create_default_https_context = ssl._create_unverified_context
        with tqdm(unit = "B", unit_scale = True, unit_divisor = 1024, miniters = 1, desc = name) as bar:
            urllib.request.urlretrieve(url, out_file, reporthook = report_hook(bar))
        print(f"Downloaded {url} to {data_dir}")
def report_hook(bar: tqdm):
    """
    tqdm progress bar hook for downloads
    """
    def hook(block_counter = 0, block_size = 1, total_size = None):
        if total_size is not None:
            bar.total = total_size
        bar.update(block_counter * block_size - bar.n)
    return hook
def mnist_raw():
    """
    Download and parse the raw MNIST dataset.
    """
    # CVDF mirror of http://yann.lecun.com/exdb/mnist/
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

    def parse_labels(file):
        with gzip.open(file, "rb") as handler:
            _ = struct.unpack(">II", handler.read(8))
            return jax.numpy.array(array.array("B", handler.read()), dtype = jax.numpy.uint8)

    def parse_images(file):
        with gzip.open(file, "rb") as handler:
            _, number, rows, columns = struct.unpack(">IIII", handler.read(16))
            return jax.numpy.array(array.array("B", handler.read()), dtype = jax.numpy.uint8).reshape(number, rows, columns)

    for name in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]:
        if not path.exists(path.join(data_dir, name)):
            _download(base_url + name, name)

    train_images = parse_images(path.join(data_dir, "train-images-idx3-ubyte.gz"))
    train_labels = parse_labels(path.join(data_dir, "train-labels-idx1-ubyte.gz"))
    test_images = parse_images(path.join(data_dir, "t10k-images-idx3-ubyte.gz"))
    test_labels = parse_labels(path.join(data_dir, "t10k-labels-idx1-ubyte.gz"))
    return train_images, train_labels, test_images, test_labels
def mnist(permute_train = False):
    """
    Download and parse the MNIST data, optionally permuting the training split
    """
    train_images, train_labels, test_images, test_labels = mnist_raw()
    if permute_train:
        # jax.random.permutation requires an explicit PRNG key; a fixed key is used here.
        permutation = jax.random.permutation(jax.random.PRNGKey(0), train_images.shape[0])
        train_images = train_images[permutation]
        train_labels = train_labels[permutation]
    return train_images, train_labels, test_images, test_labels
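As a quick sanity check, the module can be exercised on its own. A minimal sketch (the first call downloads the files, later calls reuse the local cache):
import MnistDatasets

train_images, train_labels, test_images, test_labels = MnistDatasets.mnist()
print(train_images.shape, train_labels.shape)  # (60000, 28, 28) (60000,)
print(test_images.shape, test_labels.shape)    # (10000, 28, 28) (10000,)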
The code of VisualGeometryGroupMnist.py is as follows:
import jax
import time
import MnistDatasets

def setup():
    train_images, train_labels, test_images, test_labels = MnistDatasets.mnist()
    batch_size = 600
    inputs_channels = 1
    epochs = 9
    prng = jax.random.PRNGKey(15)
    # Four 3 x 3 convolution kernels in HWIO layout, then two dense layers.
    kernel_shapes = [
        [3, 3, 1, 16],
        [3, 3, 16, 32],
        [3, 3, 32, 48],
        [3, 3, 48, 64],
        [50176, 128],
        [128, 10]
    ]
    return (train_images, train_labels), (test_images, test_labels), (batch_size, epochs, prng, kernel_shapes)

def one_hot(inputs, length = 10, dtype = jax.numpy.float32):
    matches = jax.numpy.array(inputs[:, None] == jax.numpy.arange(length), dtype)
    return matches

def partial_flatten(inputs):
    """
    Add a trailing channel dimension and scale pixels to the unit interval
    jax.lax.expand_dims(inputs, [-1]): [60000, 28, 28] -> [60000, 28, 28, 1]
    jax.lax.expand_dims(inputs, [1, 2]): [60000, 28, 28] -> [60000, 1, 1, 28, 28]
    """
    inputs = jax.lax.expand_dims(inputs, [-1]) # [60000, 28, 28] -> [60000, 28, 28, 1]
    return inputs / 255.

def test():
    (train_images, train_labels), (test_images, test_labels), (batch_size, epochs, prng, kernel_shapes) = setup()
    print((train_images.shape, train_labels.shape), (test_images.shape, test_labels.shape), (batch_size, epochs, prng, kernel_shapes))
Running it prints the following:
((60000, 28, 28), (60000,)) ((10000, 28, 28), (10000,)) (600, 9, Array([ 0, 15], dtype=uint32), [[3, 3, 1, 16], [3, 3, 16, 32], [3, 3, 32, 48], [3, 3, 48, 64], [50176, 128], [128, 10]])
The printed shapes give a preview of the composition of the training and test sets.
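The two preprocessing helpers are worth seeing in isolation: one_hot turns integer labels into rows of an (N, 10) indicator matrix, and partial_flatten adds a channel axis and scales pixels into [0, 1]. A minimal sketch, assuming the two functions above are in scope:
labels = jax.numpy.array([3, 1])
print(one_hot(labels))
# [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
#  [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]

images = jax.numpy.ones((2, 28, 28)) * 255.
print(partial_flatten(images).shape)  # (2, 28, 28, 1), all values scaled to 1.0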
Model Implementation
The previous chapter introduced and implemented several of VGG's building blocks, which we can reuse directly here.
The convolutional layer:
def conv(inputs, kernel, window_strides = 1):
    shape = inputs.shape
    dimension_numbers = jax.lax.conv_dimension_numbers(lhs_shape = shape, rhs_shape = kernel["weight"].shape, dimension_numbers = ("NHWC", "HWIO", "NHWC"))
    inputs = jax.lax.conv_general_dilated(inputs, kernel["weight"], window_strides = [window_strides, window_strides], padding = "SAME", dimension_numbers = dimension_numbers)
    inputs = jax.nn.selu(inputs)
    return inputs
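Since every convolution uses SAME padding with a stride of 1, the spatial size of the feature map is preserved and only the channel count changes. A minimal sketch with a random kernel, just to show the shapes (assuming the conv function above is in scope):
prng = jax.random.PRNGKey(0)
kernel = {"weight": jax.random.normal(prng, (3, 3, 1, 16))}  # HWIO layout
inputs = jax.random.normal(prng, (600, 28, 28, 1))           # NHWC layout
print(conv(inputs, kernel).shape)  # (600, 28, 28, 16)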
The forward-propagation function:
@jax.jit
def forward(parameters, inputs):
    # Convolution layers (all but the last two parameter groups).
    for i in range(len(parameters) - 2):
        inputs = conv(inputs, kernel = parameters[i])
    # Flatten each example before the dense layers.
    inputs = jax.numpy.reshape(inputs, (inputs.shape[0], 50176))
    # Hidden dense layer with SELU activation.
    for i in range(len(parameters) - 2, len(parameters) - 1):
        inputs = jax.numpy.matmul(inputs, parameters[i]["weight"]) + parameters[i]["bias"]
        inputs = jax.nn.selu(inputs)
    # Output layer with softmax over the 10 classes.
    inputs = jax.numpy.matmul(inputs, parameters[-1]["weight"]) + parameters[-1]["bias"]
    inputs = jax.nn.softmax(inputs, axis = -1)
    return inputs
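The hard-coded flatten size follows from the architecture: SAME padding with stride 1 and the absence of pooling keep the feature maps at 28 × 28 throughout, and the last convolution produces 64 channels, so each example flattens to 28 × 28 × 64 = 50176 features.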
Prediction Model and Training
Parameter initialization, the loss function, and the optimizer were covered in earlier chapters and are not repeated in detail here.
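The optimizer is plain mini-batch gradient descent: jax.grad differentiates the loss with respect to the parameter pytree, and jax.tree_util.tree_map applies the same update to every leaf. A minimal sketch of that pattern on a toy pytree (the names and shapes here are arbitrary):
import jax

toy_params = {"weight": jax.numpy.ones((2, 2)), "bias": jax.numpy.zeros((2,))}

def toy_loss(params):
    return jax.numpy.sum(params["weight"] ** 2) + jax.numpy.sum(params["bias"] ** 2)

gradients = jax.grad(toy_loss)(toy_params)
toy_params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, toy_params, gradients)
The complete training code is as follows.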
import jax
import time
import MnistDatasets
def setup():
    train_images, train_labels, test_images, test_labels = MnistDatasets.mnist()
    batch_size = 600
    inputs_channels = 1
    epochs = 9
    prng = jax.random.PRNGKey(15)
    # Four 3 x 3 convolution kernels in HWIO layout, then two dense layers.
    kernel_shapes = [
        [3, 3, 1, 16],
        [3, 3, 16, 32],
        [3, 3, 32, 48],
        [3, 3, 48, 64],
        [50176, 128],
        [128, 10]
    ]
    return (train_images, train_labels), (test_images, test_labels), (batch_size, epochs, prng, kernel_shapes)

def one_hot(inputs, length = 10, dtype = jax.numpy.float32):
    matches = jax.numpy.array(inputs[:, None] == jax.numpy.arange(length), dtype)
    return matches

def partial_flatten(inputs):
    """
    Add a trailing channel dimension and scale pixels to the unit interval
    jax.lax.expand_dims(inputs, [-1]): [60000, 28, 28] -> [60000, 28, 28, 1]
    jax.lax.expand_dims(inputs, [1, 2]): [60000, 28, 28] -> [60000, 1, 1, 28, 28]
    """
    inputs = jax.lax.expand_dims(inputs, [-1]) # [60000, 28, 28] -> [60000, 28, 28, 1]
    return inputs / 255.
def init_mlp_params(shapes, prng):
    params = []
    # Create the kernels for the four convolution layers
    for i in range(len(shapes) - 2):
        weights = jax.random.normal(key = prng, shape = shapes[i]) / jax.numpy.sqrt(28. * 28.)
        _dict = dict(weight = weights)
        params.append(_dict)
    # Create the weights and biases for the two dense layers
    # Note: every layer reuses the same key; jax.random.split would give independent draws.
    for i in range(len(shapes) - 2, len(shapes)):
        weights = jax.random.normal(key = prng, shape = shapes[i]) / jax.numpy.sqrt(28. * 28.)
        biases = jax.random.normal(key = prng, shape = (shapes[i][-1],)) / jax.numpy.sqrt(28. * 28.)
        _dict = dict(weight = weights, bias = biases)
        params.append(_dict)
    return params
def conv(inputs, kernel, window_strides = 1):
    shape = inputs.shape
    dimension_numbers = jax.lax.conv_dimension_numbers(lhs_shape = shape, rhs_shape = kernel["weight"].shape, dimension_numbers = ("NHWC", "HWIO", "NHWC"))
    inputs = jax.lax.conv_general_dilated(inputs, kernel["weight"], window_strides = [window_strides, window_strides], padding = "SAME", dimension_numbers = dimension_numbers)
    inputs = jax.nn.selu(inputs)
    return inputs
@jax.jit
def forward(parameters, inputs):
    # Convolution layers (all but the last two parameter groups).
    for i in range(len(parameters) - 2):
        inputs = conv(inputs, kernel = parameters[i])
    # Flatten each example before the dense layers.
    inputs = jax.numpy.reshape(inputs, (inputs.shape[0], 50176))
    # Hidden dense layer with SELU activation.
    for i in range(len(parameters) - 2, len(parameters) - 1):
        inputs = jax.numpy.matmul(inputs, parameters[i]["weight"]) + parameters[i]["bias"]
        inputs = jax.nn.selu(inputs)
    # Output layer with softmax over the 10 classes.
    inputs = jax.numpy.matmul(inputs, parameters[-1]["weight"]) + parameters[-1]["bias"]
    inputs = jax.nn.softmax(inputs, axis = -1)
    return inputs
@jax.jit
def cross_entropy(genuines, predictions):
    # Binary cross-entropy summed over the classes; clipping avoids log(0).
    entropys = genuines * jax.numpy.log(jax.numpy.clip(predictions, 1e-9, 0.999)) + (1 - genuines) * jax.numpy.log(jax.numpy.clip(1 - predictions, 1e-9, 0.999))
    entropys = jax.numpy.sum(entropys, axis = 1)
    # Negate so that gradient descent minimizes the loss.
    return -jax.numpy.mean(entropys)
@jax.jit
def loss_function(parameters, inputs, genuines):
    predictions = forward(parameters, inputs)
    entropys = cross_entropy(genuines, predictions)
    return entropys
@jax.jit
def optimizer_function(parameters, inputs, genuines, learning_rate = 1e-3):
    grad_loss_function = jax.grad(loss_function)
    gradients = grad_loss_function(parameters, inputs, genuines)
    new_parameters = jax.tree_util.tree_map(lambda parameter, gradient: parameter - learning_rate * gradient, parameters, gradients)
    return new_parameters
@jax.jit
def verify_accuracy(params, inputs, targets):
    """
    Correct predictions over a mini batch
    """
    predictions = forward(params, inputs)
    _class = jax.numpy.argmax(predictions, axis = 1)
    targets = jax.numpy.argmax(targets, axis = 1)
    return jax.numpy.sum(_class == targets)
def train():
    (train_images, train_labels), (test_images, test_labels), (batch_size, epochs, prng, kernel_shapes) = setup()
    print(f"train_images.shape = {train_images.shape}, train_labels.shape = {train_labels.shape}, test_images.shape = {test_images.shape}, test_labels.shape = {test_labels.shape}")
    '''
    train_images.shape = (60000, 28, 28), train_labels.shape = (60000,), test_images.shape = (10000, 28, 28), test_labels.shape = (10000,)
    '''
    train_images = partial_flatten(train_images)
    train_labels = one_hot(train_labels)
    test_images = partial_flatten(test_images)
    test_labels = one_hot(test_labels)
    params = init_mlp_params(kernel_shapes, prng)
    begin = time.time()
    for i in range(epochs):
        batch_number = train_images.shape[0] // batch_size
        for j in range(batch_number):
            start = batch_size * j
            stop = batch_size * (j + 1)
            images_batch = train_images[start: stop]
            labels_batch = train_labels[start: stop]
            params = optimizer_function(params, images_batch, labels_batch)
            print(f"Batch number {j + 1}/{batch_number} within epoch {i + 1}/{epochs} is completed")
        if (i + 1) % 2 == 0:
            loss = loss_function(params, train_images, train_labels)
            end = time.time()
            # Accuracy is the number of correct predictions over the test-set size.
            accuracies = verify_accuracy(params, test_images, test_labels) / float(test_images.shape[0])
            print(f"Now the loss is {loss}, accuracy is {accuracies} after {i + 1} epochs ({end - begin:.2f}s)")
            begin = time.time()

if __name__ == "__main__":
    train()
運行結果打印輸出如下词顾,
…
Batch number 95/100 within epoch 1/9 is completed
Batch number 96/100 within epoch 1/9 is completed
Batch number 97/100 within epoch 1/9 is completed
Batch number 98/100 within epoch 1/9 is completed
Batch number 99/100 within epoch 1/9 is completed
Batch number 100/100 within epoch 1/9 is completed
Batch number 1/100 within epoch 2/9 is completed
Batch number 2/100 within epoch 2/9 is completed
Batch number 3/100 within epoch 2/9 is completed
Batch number 4/100 within epoch 2/9 is completed
…
As the log shows, after the 9 training epochs the model reaches a good level of accuracy; compared with the fully connected classifiers used in the previous chapters, this is a substantial improvement.
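Once training finishes, the learned parameters can be used directly for inference. A minimal sketch, run inside VisualGeometryGroupMnist.py and assuming train() is modified to end with return params (it currently returns nothing):
params = train()
(_, _), (test_images, test_labels), _ = setup()
sample = partial_flatten(test_images)[:1]  # a single test image in NHWC layout
prediction = jax.numpy.argmax(forward(params, sample), axis = 1)
print(f"Predicted digit: {int(prediction[0])}, true label: {int(test_labels[0])}")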
Conclusion
VGG is one of the most classic convolutional neural network classification models, and it still plays an important role in many fields today. Make sure you understand and can reproduce the design and training of the VGG model completed in this chapter.
Also worth mastering is convolution in JAX itself: for deep learning, convolution is the most widely used data-processing and feature-extraction method in computer vision, parts of natural language processing, and reinforcement learning.