Chapter 50: The CIFAR-100 Dataset and a ResNet Implementation

The previous chapter covered the ResNet model and its components, and introduced the built-in model components that JAX provides under jax.example_libraries.stax. With that groundwork in place, it is time to write some code: this chapter uses ResNet to classify the CIFAR-100 dataset.

What Is the CIFAR-100 Dataset

CIFAR-10 and CIFAR-100 are both datasets of small labelled images. Compared with CIFAR-10, CIFAR-100 has 100 classes, each containing 600 images: 500 training images and 100 test images per class. The 100 classes are in turn grouped into 20 superclasses, listed below.

aquatic mammals: beaver, dolphin, otter, seal, whale
fish: aquarium fish, flatfish, ray, shark, trout
flowers: orchids, poppies, roses, sunflowers, tulips
food containers: bottles, bowls, cans, cups, plates
fruit and vegetables: apples, mushrooms, oranges, pears, sweet peppers
household electrical devices: clock, computer keyboard, lamp, telephone, television
household furniture: bed, chair, couch, table, wardrobe
insects: bee, beetle, butterfly, caterpillar, cockroach
large carnivores: bear, leopard, lion, tiger, wolf
large man-made outdoor things: bridge, castle, house, road, skyscraper
large natural outdoor scenes: cloud, forest, mountain, plain, sea
large omnivores and herbivores: camel, cattle, chimpanzee, elephant, kangaroo
medium-sized mammals: fox, porcupine, possum, raccoon, skunk
non-insect invertebrates: crab, lobster, snail, spider, worm
people: baby, boy, girl, man, woman
reptiles: crocodile, dinosaur, lizard, snake, turtle
small mammals: hamster, mouse, rabbit, shrew, squirrel
trees: maple, oak, palm, pine, willow
vehicles 1: bicycle, bus, motorcycle, pickup truck, train
vehicles 2: lawn-mower, rocket, streetcar, tank, tractor

Each image carries a "fine" label (the class it belongs to) and a "coarse" label (its superclass), and each image is 32 x 32 pixels.

Figure 1: CIFAR classes

The dataset can be obtained in two ways. The first is to download it directly from the official site, which offers the following versions:

Version Size md5sum
CIFAR-100 python version 161 MB eb9058c3a382ffc7106e4002c42a8d85
CIFAR-100 Matlab version 175 MB 6a4bfa1dcd5c9453dda6bb54194911f4
CIFAR-100 binary version (suitable for C programs) 161 MB 03b5dce01913d631647c71ecec9e9cb8

Here we pick the Python version. The second way is to let tensorflow_datasets download the dataset for us. Both approaches, along with the keras.datasets shortcut, are described below.

Building a Dataset from the Downloaded CIFAR-100 Files

After the CIFAR-100 python version has been downloaded and unpacked, it has the following file structure:

train
test
meta
file.txt~

Here meta holds the dataset metadata, train is the training set and test is the test set. The data can be read with the following code:


import pickle

def setup():
    
    def load(fileName: str):
        
        # The CIFAR-100 files are pickled with latin1 encoding
        with open(file = fileName, mode = "rb") as handler:
            
            data = pickle.load(file = handler, encoding = "latin1")
            
        return data
    
    trains = load("../../Shares/cifar-100-python/train")
    tests = load("../../Shares/cifar-100-python/test")
    metas = load("../../Shares/cifar-100-python/meta")
    
    return trains, tests, metas
    
def train():
    
    trains, tests, metas = setup()
    
    for key in trains.keys():
        
        print(f"key = {key}, len(trains[key]) = {len(trains[key])}")
    
    print("--------------------------------------------------")
    
    for key in tests.keys():
        
        print(f"key = {key}, len(tests[key]) = {len(tests[key])}")
    
    print("--------------------------------------------------")
    
    for key in metas.keys():
        
        print(f"key = {key}, len(metas[key]) = {len(metas[key])}")
    
def main():
    
    train()

if __name__ == "__main__":
    
    main()

Running it prints the following:


key = filenames, len(trains[key]) = 50000
key = batch_label, len(trains[key]) = 21
key = fine_labels, len(trains[key]) = 50000
key = coarse_labels, len(trains[key]) = 50000
key = data, len(trains[key]) = 50000
--------------------------------------------------
key = filenames, len(tests[key]) = 10000
key = batch_label, len(tests[key]) = 20
key = fine_labels, len(tests[key]) = 10000
key = coarse_labels, len(tests[key]) = 10000
key = data, len(tests[key]) = 10000
--------------------------------------------------
key = fine_label_names, len(metas[key]) = 100
key = coarse_label_names, len(metas[key]) = 20

The keys mean the following:

  • filenames: a list of length 50000 in which each entry is the file name of one image.
  • batch_label: a description of the batch.
  • fine_labels: the class index of each image.
  • coarse_labels: the superclass index of each image.
  • data: a 50000 x 3072 two-dimensional array in which each row holds the pixel values of one image; the sketch below shows how to turn these rows back into images.
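
To make this layout concrete, the following minimal sketch (reusing the file paths assumed in setup() above) turns the flat 3072-value rows back into 32 x 32 RGB images: in every row the first 1024 values are the red channel, the next 1024 the green channel and the last 1024 the blue channel.

import numpy
import pickle

def load(fileName: str):
    
    with open(file = fileName, mode = "rb") as handler:
        
        return pickle.load(file = handler, encoding = "latin1")

trains = load("../../Shares/cifar-100-python/train")
metas = load("../../Shares/cifar-100-python/meta")

# (50000, 3072) -> (50000, 3, 32, 32) -> (50000, 32, 32, 3)
images = numpy.asarray(trains["data"], dtype = numpy.uint8)
images = images.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

fine_labels = numpy.asarray(trains["fine_labels"])
coarse_labels = numpy.asarray(trains["coarse_labels"])

print(images.shape)                                   # (50000, 32, 32, 3)
print(metas["fine_label_names"][fine_labels[0]])      # class name of the first image
print(metas["coarse_label_names"][coarse_labels[0]])  # its superclass name
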
Using tensorflow_datasets

import tensorflow_datasets as tfds

def setup():
    
    # batch_size = -1 loads each split into memory as one big batch
    (trains, tests), meta = tfds.load("cifar100",
                                      data_dir = "/tmp/",
                                      split = [tfds.Split.TRAIN, tfds.Split.TEST],
                                      with_info = True,
                                      batch_size = -1)
    
    # tfds.show_examples can visualise a few samples when loading without batch_size = -1
    
    trains = tfds.as_numpy(trains)
    tests = tfds.as_numpy(tests)
    
    train_images, train_labels = trains["image"], trains["label"]
    test_images, test_labels = tests["image"], tests["label"]
    
    return (train_images, train_labels), (test_images, test_labels)
    
def train():
    
    (train_images, train_labels), (test_images, test_labels) = setup()
    
    print((train_images.shape, train_labels.shape), (test_images.shape, test_labels.shape))
    
def main():
    
    train()
    
if __name__ == "__main__":
    
    main()

Running it prints the following:


((50000, 32, 32, 3), (50000,)) ((10000, 32, 32, 3), (10000,))
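
Because setup() returns plain NumPy arrays, cutting them into mini-batches for training is straightforward. The helper below is a minimal sketch; the name batches and the batch size of 128 are arbitrary choices for illustration.

import numpy

def batches(images, labels, batch_size = 128, seed = 0):
    
    # Shuffle the indices once, then yield consecutive mini-batches
    indices = numpy.random.RandomState(seed).permutation(len(images))
    
    for start in range(0, len(images) - batch_size + 1, batch_size):
        
        batch = indices[start: start + batch_size]
        
        yield images[batch], labels[batch]

(train_images, train_labels), (test_images, test_labels) = setup()

for images, labels in batches(train_images, train_labels):
    
    print(images.shape, labels.shape)  # (128, 32, 32, 3) (128,)
    break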

Using keras.datasets

import tensorflow as tf

def setup():
    
    # Keras downloads and caches CIFAR-100 automatically the first time this is called
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar100.load_data()

    return (train_images, train_labels), (test_images, test_labels)

The printed shapes are as follows (note that the labels now have shape (N, 1) rather than (N,)):


((50000, 32, 32, 3), (50000, 1)) ((10000, 32, 32, 3), (10000, 1))
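
Whichever loader is used, the images usually still need to be converted to floating point and the labels to one-hot vectors before training. The sketch below extends the keras-based setup() above; dividing by 255 and using 100 one-hot classes are conventional choices, not something load_data() does by itself.

import jax.nn
import jax.numpy
import tensorflow as tf

def setup():
    
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar100.load_data()
    
    # uint8 pixels in [0, 255] -> float32 in [0.0, 1.0]
    train_images = jax.numpy.asarray(train_images, dtype = jax.numpy.float32) / 255.0
    test_images = jax.numpy.asarray(test_images, dtype = jax.numpy.float32) / 255.0
    
    # (N, 1) integer labels -> (N, 100) one-hot vectors
    train_labels = jax.nn.one_hot(train_labels.squeeze(), 100)
    test_labels = jax.nn.one_hot(test_labels.squeeze(), 100)
    
    return (train_images, train_labels), (test_images, test_labels)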

Implementing the ResNet Residual Block

The ResNet architecture was introduced in the previous chapter. Its creative idea is to build the network by stacking modules, so that the features passed through a module are not lost.

As the figure below shows, the inside of a module is essentially three stacked convolution layers forming a bottleneck design. Each residual block uses three convolutions: a 1 x 1, a 3 x 3 and another 1 x 1 layer. The 1 x 1 layers first reduce and then restore the number of channels, so that the 3 x 3 layer is a bottleneck with smaller input and output dimensions.

The code implementing this three-layer convolution structure is as follows:


import jax.example_libraries.stax

def IdentityBlock(kernel_size, filters):
    
    kernel_size_ = kernel_size
    filters1, filters2 = filters
    
    # Generate a main path
    def make_main(inputs_shape):
        
        return jax.example_libraries.stax.serial(
            
            jax.example_libraries.stax.Conv(filters1, (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            
            jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            
            # Adjust according to the inputs automatically
            jax.example_libraries.stax.Conv(inputs_shape[3], (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm()
        )
    
    Main = jax.example_libraries.stax.shape_dependent(make_layer = make_main)
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(Main,
                                            jax.example_libraries.stax.Identity),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu
    )

In this code the input first goes through a jax.example_libraries.stax.Conv() convolution that reduces the number of channels to a quarter of the block's output dimension; this shrinks the amount of data before the following [3, 3] convolution. jax.example_libraries.stax.BatchNorm() is a batch-normalisation layer and jax.example_libraries.stax.Relu is the activation layer.

In addition, three combinators that we have not met before appear here. Their purpose is to combine different computation paths: jax.example_libraries.stax.FanOut(2) duplicates the data into two copies, jax.example_libraries.stax.parallel(Main, Identity) runs the main path and the identity path side by side, and jax.example_libraries.stax.FanInSum sums the results of the two parallel paths back together.
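
Like every stax combinator, IdentityBlock returns an (init_fun, apply_fun) pair, so a quick shape check can confirm that the residual block leaves the input shape unchanged. The following is a minimal sketch; the input shape (8, 32, 32, 256) and the filter sizes are arbitrary choices for illustration.

import jax
import jax.numpy
import jax.random

init_fun, apply_fun = IdentityBlock(3, [64, 64])

key = jax.random.PRNGKey(15)
inputs_shape = (8, 32, 32, 256)

output_shape, params = init_fun(key, inputs_shape)
print(output_shape)  # (8, 32, 32, 256), identical to the input thanks to the shortcut

inputs = jax.numpy.ones(inputs_shape)
outputs = apply_fun(params, inputs)
print(outputs.shape)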

During data propagation, the ResNet block adds a "shortcut" connection, an express lane for the data. The shortcut simply performs an identity mapping, so it introduces no extra parameters and no additional computational complexity, as shown in the figure below.

Moreover, the whole network can still be trained end to end with back-propagation; the IdentityBlock defined above already contains everything needed for that.


Sometimes, besides deciding whether to transform the input at all, the ResNet implementation also changes the dimensionality of the data. When the input dimension differs from the dimension the model is expected to output (input_channel is not equal to out_dim), the input has to be padded. Padding simply fills the data up to the required size; the sketch below shows one way to do this along the channel axis.
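
As a rough illustration of that idea (this is a hand-written sketch, not part of the stax model above): an identity branch with too few channels can be zero-padded along the channel axis before it is summed with the main branch.

import jax.numpy

def pad_channels(inputs, out_dim):
    
    # inputs is assumed to be NHWC; only the last (channel) axis is padded with zeros
    input_channel = inputs.shape[-1]
    
    if input_channel == out_dim:
        
        return inputs
    
    padding = [(0, 0), (0, 0), (0, 0), (0, out_dim - input_channel)]
    
    return jax.numpy.pad(inputs, padding)

inputs = jax.numpy.ones((8, 32, 32, 64))
print(pad_channels(inputs, 256).shape)  # (8, 32, 32, 256)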

Implementing the ResNet Network

The ResNet network structure is shown in the figure below.

The figure lists five ResNet depths: 18, 34, 50, 101 and 152 layers. All of them are divided into the same five stages: conv1, conv2_x, conv3_x, conv4_x and conv5_x.

Let us now implement it. Keep in mind that training a full ResNet requires a fairly powerful GPU. To suit the 32 x 32 CIFAR images and keep the demonstration manageable, the code below makes a few small changes to the original design, for example replacing the 7 x 7 stride-2 stem convolution of the paper with a 3 x 3 convolution; please bear this in mind.

The complete ResNet50 code is as follows:


import jax.example_libraries.stax

def IdentityBlock(kernel_size, filters):
    
    kernel_size_ = kernel_size
    filters1, filters2 = filters
    
    # Generate a main path
    def make_main(inputs_shape):
        
        return jax.example_libraries.stax.serial(
            
            jax.example_libraries.stax.Conv(filters1, (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            
            jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            
            # Adjust according to the inputs automatically
            jax.example_libraries.stax.Conv(inputs_shape[3], (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm()
        )
    
    Main = jax.example_libraries.stax.shape_dependent(make_layer = make_main)
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(Main,
                                            jax.example_libraries.stax.Identity),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu
    )

def ConvolutionalBlock(kernel_size, filters, strides = (1, 1)):
    
    kernel_size_ = kernel_size
    filters1, filters2, filters3 = filters
    
    Main = jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.Conv(filters1, (1, 1), strides = strides, padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        
        jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        
        # No stride on the final 1 x 1 convolution: the main path should downsample only once
        jax.example_libraries.stax.Conv(filters3, (1, 1), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm()
    )
    
    # The projection shortcut uses a 1 x 1 convolution so that its output matches the main path
    Shortcut = jax.example_libraries.stax.serial(
        jax.example_libraries.stax.Conv(filters3, (1, 1), strides, padding = "SAME")
    )
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(
            Main,
            Shortcut
        ),
        
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu)

def ResNet50(number_classes):
    
    return jax.example_libraries.stax.serial(
        
        # conv1: the stem convolution
        jax.example_libraries.stax.Conv(64, (3, 3), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        
        jax.example_libraries.stax.MaxPool((3, 3), strides = (2, 2)),
        
        # conv2_x
        ConvolutionalBlock(3, [64, 64, 256]),
        IdentityBlock(3, [64, 64]),
        IdentityBlock(3, [64, 64]),
        
        # conv3_x
        ConvolutionalBlock(3, [128, 128, 512]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        
        # conv4_x
        ConvolutionalBlock(3, [256, 256, 1024]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        
        # conv5_x
        ConvolutionalBlock(3, [512, 512, 2048]),
        IdentityBlock(3, [512, 512]),
        IdentityBlock(3, [512, 512]),
        
        # Pool, flatten and classify over number_classes categories
        jax.example_libraries.stax.AvgPool((7, 7)),
        jax.example_libraries.stax.Flatten,
        jax.example_libraries.stax.Dense(number_classes),
        jax.example_libraries.stax.LogSoftmax
    )
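
To sanity-check the network, it can be initialised on a CIFAR-100-sized input and its output shape inspected. The sketch below also computes a cross-entropy loss and one gradient; the batch size of 2, the dummy labels and the name loss_function are assumptions made purely for illustration.

import jax
import jax.nn
import jax.numpy
import jax.random

init_fun, apply_fun = ResNet50(100)

key = jax.random.PRNGKey(15)
inputs_shape = (2, 32, 32, 3)

output_shape, params = init_fun(key, inputs_shape)
print(output_shape)  # (2, 100): one log-probability per class

images = jax.numpy.ones(inputs_shape)
labels = jax.nn.one_hot(jax.numpy.zeros(2, dtype = jax.numpy.int32), 100)

def loss_function(params, images, labels):
    
    # apply_fun returns log-softmax outputs, so the cross entropy is a sum of products
    log_probabilities = apply_fun(params, images)
    
    return -jax.numpy.mean(jax.numpy.sum(log_probabilities * labels, axis = 1))

# Because every layer is differentiable, the whole network trains end to end with jax.grad
gradients = jax.grad(loss_function)(params, images, labels)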

Conclusion

This chapter examined the structure of the CIFAR-100 dataset and implemented the ResNet residual blocks and the full network, once again in preparation for the hands-on training to come.
