The previous chapter covered the ResNet model and its components, and introduced the built-in model components JAX provides under jax.example_libraries.stax. With that groundwork in place, we can start writing code. In this chapter, we use ResNet to classify the CIFAR100 dataset.
What is the CIFAR100 dataset
CIFAR10 and CIFAR100 are both datasets of small labeled images. Compared with CIFAR10, CIFAR100 contains 100 classes with 600 images each: 500 training images and 100 test images per class. The 100 classes of CIFAR100 are further grouped into 20 superclasses.
Superclass | Classes |
---|---|
aquatic mammal | beaver, dolphin, otter, seal, whale |
fish | aquarium fish, flatfish, ray, shark, trout |
flowers | orchids, poppies, roses, sunflowers, tulips |
food containers | bottles, bowls, cans, cups, plates |
fruit and vegetables | apples, mushrooms, oranges, pears, sweet peppers |
household electrical devices | clock, computer keyboard, lamp, telephone, television |
household furniture | bed, chair, couch, table, wardrobe |
insects | bee, beetle, butterfly, caterpillar, cockroach |
large carnivores | bear, leopard, lion, tiger, wolf |
large man-made outdoor things | bridge, castle, house, road, skyscraper |
large natural outdoor scenes | cloud, forest, mountain, plain, sea |
large omnivores and herbivores | camel, cattle, chimpanzee, elephant, kangaroo |
medium-sized mammals | fox, porcupine, possum, raccoon, skunk |
non-insect invertebrates | crab, lobster, snail, spider, worm |
people | baby, boy, girl, man, woman |
reptiles | crocodile, dinosaur, lizard, snake, turtle |
small mammals | hamster, mouse, rabbit, shrew, squirrel |
trees | maple, oak, palm, pine, willow
vehicles 1 | bicycle, bus, motorcycle, pickup truck, train |
vehicles 2 | lawn-mower, rocket, streetcar, tank, tractor |
Each image carries a "fine" label (the class it belongs to) and a "coarse" label (its superclass), and is 32x32 pixels in size.
The dataset can be downloaded in two ways:
- Direct download from https://www.cs.toronto.edu/~kriz/cifar.html
Version | Size | md5sum |
---|---|---|
CIFAR-100 python version | 161 MB | eb9058c3a382ffc7106e4002c42a8d85 |
CIFAR-100 Matlab version | 175 MB | 6a4bfa1dcd5c9453dda6bb54194911f4 |
CIFAR-100 binary version (suitable for C programs) | 161 MB | 03b5dce01913d631647c71ecec9e9cb8 |
We choose the python version.
- Download it with tensorflow_datasets.
Each approach is described below.
Generating the dataset from the downloaded CIFAR100
After downloading the CIFAR-100 python version, the files are laid out as follows:
train
test
meta
file.txt~
Here, meta holds the dataset information, train is the training set, and test is the test set. The dataset can be read with the following code:
import pickle

def setup():
    def load(fileName: str):
        with open(file = fileName, mode = "rb") as handler:
            data = pickle.load(file = handler, encoding = "latin1")
        return data
    trains = load("../../Shares/cifar-100-python/train")
    tests = load("../../Shares/cifar-100-python/test")
    metas = load("../../Shares/cifar-100-python/meta")
    return trains, tests, metas

def train():
    trains, tests, metas = setup()
    for key in trains.keys():
        print(f"key = {key}, len(trains[key]) = {len(trains[key])}")
    print("--------------------------------------------------")
    for key in tests.keys():
        print(f"key = {key}, len(tests[key]) = {len(tests[key])}")
    print("--------------------------------------------------")
    for key in metas.keys():
        print(f"key = {key}, len(metas[key]) = {len(metas[key])}")

def main():
    train()

if __name__ == "__main__":
    main()
Running it prints the following output:
key = filenames, len(trains[key]) = 50000
key = batch_label, len(trains[key]) = 21
key = fine_labels, len(trains[key]) = 50000
key = coarse_labels, len(trains[key]) = 50000
key = data, len(trains[key]) = 50000
--------------------------------------------------
key = filenames, len(tests[key]) = 10000
key = batch_label, len(tests[key]) = 20
key = fine_labels, len(tests[key]) = 10000
key = coarse_labels, len(tests[key]) = 10000
key = data, len(tests[key]) = 10000
--------------------------------------------------
key = fine_label_names, len(metas[key]) = 100
key = coarse_label_names, len(metas[key]) = 20
The fields are as follows:
- filenames: a list of length 50000; each entry is the file name of the corresponding image.
- batch_label: information about the batch.
- fine_labels: the class each image belongs to.
- coarse_labels: the superclass each image belongs to.
- data: a 50000 x 3072 two-dimensional array; each row holds the pixel values of one image.
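As a small standalone sketch (numpy only, using a synthetic row rather than the real file): in the standard CIFAR layout, each 3072-value row stores the 1024 red, 1024 green, and 1024 blue channel values in sequence, row-major, so a row can be restored to a 32 x 32 RGB image like this:

```python
import numpy as np

def row_to_image(row):
    # Each channel occupies 1024 consecutive values stored row-major,
    # so reshape to (channel, height, width) and move channels last.
    return row.reshape(3, 32, 32).transpose(1, 2, 0)

# Demonstrate on a synthetic row; the real `data` array has shape (50000, 3072).
fake_row = np.arange(3072)
image = row_to_image(fake_row)
print(image.shape)  # (32, 32, 3)
```

The resulting array is in the HWC layout that image viewers and the stax convolution layers below expect.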
Using tensorflow_datasets
import tensorflow as tf
import tensorflow_datasets as tfds
import jax

def setup():
    (trains, tests), metas = tfds.load("cifar100", data_dir = "/tmp/", split = [tfds.Split.TRAIN, tfds.Split.TEST], with_info = True, batch_size = -1)
    # tensorflow_datasets.show_examples(trains, metas)
    trains = tfds.as_numpy(trains)
    tests = tfds.as_numpy(tests)
    train_images, train_labels = trains["image"], trains["label"]
    test_images, test_labels = tests["image"], tests["label"]
    return (train_images, train_labels), (test_images, test_labels)

def train():
    (train_images, train_labels), (test_images, test_labels) = setup()
    print((train_images.shape, train_labels.shape), (test_images.shape, test_labels.shape))

def main():
    train()

if __name__ == "__main__":
    main()
Running it prints the following output:
((50000, 32, 32, 3), (50000,)) ((10000, 32, 32, 3), (10000,))
Using keras.datasets
import tensorflow as tf

def setup():
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar100.load_data()
    return (train_images, train_labels), (test_images, test_labels)
Running it prints the following output:
((50000, 32, 32, 3), (50000, 1)) ((10000, 32, 32, 3), (10000, 1))
Implementing the ResNet residual model
The ResNet architecture was introduced in the previous chapter. Its innovation is to stack the network in a "modular" way, so that feature information passed through a module is not lost.
As the figure below shows, a module is in fact three stacked convolution layers forming a bottleneck design. Each residual module uses 3 convolution layers: a 1 x 1, a 3 x 3, and another 1 x 1 layer, where the 1 x 1 layers first reduce and then restore the dimensions, leaving the 3 x 3 layer a bottleneck with smaller input and output dimensions.
The 3-layer convolution structure is implemented by the following code:
import jax.example_libraries.stax

def IdentityBlock(kernel_size, filters):
    kernel_size_ = kernel_size
    filters1, filters2 = filters
    # Generate a main path
    def make_main(inputs_shape):
        return jax.example_libraries.stax.serial(
            jax.example_libraries.stax.Conv(filters1, (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            # Adjust according to the inputs automatically
            jax.example_libraries.stax.Conv(inputs_shape[3], (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm()
        )
    Main = jax.example_libraries.stax.shape_dependent(make_layer = make_main)
    return jax.example_libraries.stax.serial(
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(Main, jax.example_libraries.stax.Identity),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu
    )
In this code, the input data first passes through a jax.example_libraries.stax.Conv() convolution that outputs a quarter of the output dimensions; this reduces the overall volume of the input data in preparation for the following [3, 3] layer. jax.example_libraries.stax.BatchNorm() is a batch-normalization layer, and jax.example_libraries.stax.Relu is an activation layer.
In addition, three classes we have not seen before are used here. Their purpose is to combine different computation paths: jax.example_libraries.stax.FanOut(2) duplicates the data, jax.example_libraries.stax.parallel(Main, Identity) runs the main path and the Identity path side by side, and jax.example_libraries.stax.FanInSum merges the outputs of the parallel paths.
During data propagation, the ResNet module adds an "information highway" called a shortcut. The shortcut connection simply performs an identity mapping: it introduces no extra parameters and no additional computational complexity, as shown in the figure below.
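To make this combination concrete, here is a minimal runnable sketch (not from the book) of the same FanOut/parallel/FanInSum pattern, with both branches set to Identity so the result is simply the input added to itself:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# Duplicate the input into two branches, pass each through Identity,
# then sum the two branch outputs back together.
init_fun, apply_fun = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.Identity, stax.Identity),
    stax.FanInSum
)

rng = jax.random.PRNGKey(0)
output_shape, params = init_fun(rng, (1, 4))
outputs = apply_fun(params, jnp.ones((1, 4)))
print(outputs)  # every element is 1 + 1 = 2
```

In the IdentityBlock, one of the Identity branches is replaced by the Main convolution path, which is exactly what turns this pattern into a residual connection.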
Moreover, the whole network can still be trained end to end with backpropagation; the shortcut is exactly the Identity branch in the IdentityBlock code above.
Sometimes, beyond deciding whether to process the input data, we must account for ResNet changing the data's dimensions along the way. When the input dimension differs from the module's required output dimension (input_channel is not equal to out_dim), the input dimensions must be adjusted: a padding operation completes the data, controlled by the padding parameter.
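The ConvolutionalBlock below resolves this channel mismatch by placing a 1 x 1 convolution on the shortcut. As a hedged standalone sketch (the shapes here are illustrative, not from the book), a 1 x 1 convolution changes only the channel dimension:

```python
import jax
from jax.example_libraries import stax

# A 1 x 1 convolution maps 3 input channels to 8 output channels
# without changing the 32 x 32 spatial dimensions (NHWC layout).
init_fun, apply_fun = stax.Conv(8, (1, 1), padding = "SAME")

rng = jax.random.PRNGKey(0)
output_shape, params = init_fun(rng, (2, 32, 32, 3))
print(output_shape)  # (2, 32, 32, 8)
```

This is why the shortcut branch can be summed with the main path even when the block changes the number of channels.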
ResNet network implementation
The ResNet network structure is shown in the figure below.
The figure shows ResNets of 5 depths: 18, 34, 50, 101, and 152. Each network is divided into 5 stages: conv1, conv2_x, conv3_x, conv4_x, and conv5_x.
Next we implement it. Note that a full ResNet implementation requires a fairly high-performance GPU. For ease of demonstration, the code below is modified: the pooling layer is removed and the number of filters and the number of layers per stage are reduced; please keep this in mind.
The complete ResNet50 code is as follows:
import jax.example_libraries.stax

def IdentityBlock(kernel_size, filters):
    kernel_size_ = kernel_size
    filters1, filters2 = filters
    # Generate a main path
    def make_main(inputs_shape):
        return jax.example_libraries.stax.serial(
            jax.example_libraries.stax.Conv(filters1, (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            # Adjust according to the inputs automatically
            jax.example_libraries.stax.Conv(inputs_shape[3], (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm()
        )
    Main = jax.example_libraries.stax.shape_dependent(make_layer = make_main)
    return jax.example_libraries.stax.serial(
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(Main, jax.example_libraries.stax.Identity),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu
    )

def ConvolutionalBlock(kernel_size, filters, strides = (1, 1)):
    kernel_size_ = kernel_size
    filters1, filters2, filters3 = filters
    Main = jax.example_libraries.stax.serial(
        jax.example_libraries.stax.Conv(filters1, (1, 1), strides = strides, padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Conv(filters3, (1, 1), strides = strides, padding = "SAME"),
        jax.example_libraries.stax.BatchNorm()
    )
    Shortcut = jax.example_libraries.stax.serial(
        jax.example_libraries.stax.Conv(filters3, (1, 1), strides, padding = "SAME")
    )
    return jax.example_libraries.stax.serial(
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(Main, Shortcut),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu
    )

def ResNet50(number_classes):
    return jax.example_libraries.stax.serial(
        jax.example_libraries.stax.Conv(64, (3, 3), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.MaxPool((3, 3), strides = (2, 2)),
        ConvolutionalBlock(3, [64, 64, 256]),
        IdentityBlock(3, [64, 64]),
        IdentityBlock(3, [64, 64]),
        ConvolutionalBlock(3, [128, 128, 512]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        ConvolutionalBlock(3, [256, 256, 1024]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        ConvolutionalBlock(3, [512, 512, 2048]),
        IdentityBlock(3, [512, 512]),
        IdentityBlock(3, [512, 512]),
        jax.example_libraries.stax.AvgPool((7, 7)),
        jax.example_libraries.stax.Flatten,
        jax.example_libraries.stax.Dense(number_classes),
        jax.example_libraries.stax.LogSoftmax
    )
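Like every stax model, the ResNet50 above is a pair of functions: init_fun builds the parameters from an input shape, and apply_fun runs the forward pass. As a hedged illustration of that calling convention (using a deliberately tiny stand-in network so it runs quickly; the real ResNet50 is called the same way):

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# A tiny stand-in classifier with the same external interface as ResNet50.
init_fun, apply_fun = stax.serial(
    stax.Conv(8, (3, 3), padding = "SAME"),
    stax.Relu,
    stax.Flatten,
    stax.Dense(100),  # 100 classes, as in CIFAR100
    stax.LogSoftmax
)

rng = jax.random.PRNGKey(0)
output_shape, params = init_fun(rng, (4, 32, 32, 3))  # NHWC batch of 4
log_probs = apply_fun(params, jnp.ones((4, 32, 32, 3)))
print(log_probs.shape)  # (4, 100)
```

The output is a batch of log-probabilities over the 100 classes, ready to be fed into a cross-entropy loss in the training loop.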
Conclusion
This chapter introduced the structure of the CIFAR100 dataset and presented the ResNet residual block and the full network implementation, again in preparation for the hands-on work ahead.