ResNet-V1-50卷積神經(jīng)網(wǎng)絡(luò)遷移學習進行不同品種的花的分類識別

運行環(huán)境

python3.6.3奴璃、tensorflow1.10.0
Intel@AIDevCloud:Intel Xeon Gold 6128 processors集群

數(shù)據(jù)和模型來源

數(shù)據(jù)集:http://download.tensorflow.org/example_images/flower_photos.tgz
模型:http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz

思路

數(shù)據(jù)集分析及處理

數(shù)據(jù)集文件解壓后共有五個文件夾讲竿,每個文件夾都包含一定數(shù)量的花的圖片妖混,一個文件夾對應(yīng)一個品種插龄,圖片各種尺寸都有昙衅,均為jpg格式腺办,均為彩色圖片。這里利用tensorflow提供的圖片處理工具將所有圖片轉(zhuǎn)為300×300×3的格式述呐,然后將所有圖片的80%當作訓練集,10%當作驗證集蕉毯,10%當作測試集乓搬,并且將訓練集進行隨機打亂,將得到的數(shù)據(jù)存在一個numpy文件中代虾,以待后續(xù)訓練使用进肯。

模型構(gòu)建

這里采用了ResNet-V1-50卷積神經(jīng)網(wǎng)絡(luò)來進行訓練,模型結(jié)構(gòu)在slim中都提供好了棉磨,另外采用官方已經(jīng)訓練好的參數(shù)進行遷移學習江掩,只是在模型的最后根據(jù)問題的實際需要再定義一層輸出層,只訓練最后的自定義的全連接輸出層的參數(shù)乘瓤,訓練500次环形,每次batch樣本數(shù)取32,學習率取0.0001衙傀。

源代碼

load_data.py
# -*- coding: UTF-8 -*-
#Author:Yinli

import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

#定義輸入文件夾和數(shù)據(jù)存儲文件名
INPUT_DATA = 'flower_photos'
OUTPUT_FILE = 'flower_processed_data.npy'

#設(shè)定驗證集和測試集的百分比
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10

def create_image_list(sess, testing_percentage, validation_percentage):

    #列出輸入文件夾下的所有子文件夾抬吟,此時sub_dirs里面除了有子文件夾還有它自身,在第一個
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
    #設(shè)置一個bool值统抬,指定第一次循環(huán)的時候跳過母文件夾
    is_root_dir = True
    #print(sub_dirs)

    #初始化數(shù)據(jù)矩陣
    training_images = []
    training_labels = []
    testing_images = []
    testing_labels = []
    validation_images = []
    validation_labels= []
    current_label = 0

    #分別處理每個子文件夾
    for sub_dir in sub_dirs:
        #跳過第一個值火本,即跳過母文件夾
        if is_root_dir:
            is_root_dir = False
            continue

        #獲取子目錄中的所有圖片文件
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        #用列表記錄所有圖片文件
        file_list = []
        #獲取此子目錄的名字比如daisy
        dir_name = os.path.basename(sub_dir)
        #對此子目錄中所有圖片后綴的文件
        for extension in extensions:
            #獲取每種圖片的所有正則表達式
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            print(file_glob)
            #將所有符合正則表達式的文件名加入文件列表
            file_list.extend(glob.glob(file_glob))
        print(file_list)
        #如果沒有文件跳出循環(huán)
        if not file_list:
            continue
        print("processing ", dir_name)

        i = 0
        #對于每張圖片
        for file_name in file_list:
            i+=1
            #打開圖片文件
            #print("process num : ",i,"   processing", file_name, file=f)
            image_raw_data = gfile.FastGFile(file_name,'rb').read()
            #解碼
            image = tf.image.decode_jpeg(image_raw_data)
            #如果圖片格式不是float32則轉(zhuǎn)為float32
            if image.dtype != tf.float32:
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)
            #將圖片源數(shù)據(jù)轉(zhuǎn)為299*299
            image = tf.image.resize_images(image, [300,300])
            #得到此圖片的數(shù)據(jù)
            image_value = sess.run(image)
            print(np.shape(image_value))

            #生成一個100以內(nèi)的數(shù)
            chance = np.random.randint(100)
            #按概率隨機分到三個數(shù)據(jù)集中
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < (testing_percentage + validation_percentage):
                testing_images.append(image_value)
                testing_labels.append(current_label)
            else:
                training_images.append(image_value)
                training_labels.append(current_label)
            if i%200 == 0:
                print("processing...")
        #處理完此種品種就將標簽+1
        current_label += 1

    #將訓練數(shù)據(jù)和標簽以同樣的方式打亂
    state = np.random.get_state()
    np.random.shuffle(training_images)
    np.random.set_state(state)
    np.random.shuffle(training_labels)

    #返回所有數(shù)據(jù)
    return np.asarray([training_images, training_labels,
                       validation_images, validation_labels, testing_images, testing_labels])


def main():
    with tf.Session() as sess:
        processed_data = create_image_list(sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
        #將數(shù)據(jù)存到文件中
        np.save(OUTPUT_FILE, processed_data)

if __name__ == "__main__":
    main()
resnet.py
# -*- coding: UTF-8 -*-
# Author:Yinli

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
# 加載通過slim定義好的resnet_v1模型
import tensorflow.contrib.slim.python.slim.nets.resnet_v1 as resnet_v1

# 數(shù)據(jù)文件
INPUT_DATA = "./flower_processed_data.npy"
# 保存訓練好的模型
TRAIN_FILE = "./save_model/my_model"
# 提供的已經(jīng)訓練好的模型
CKPT_FILE = "./resnet_v1_50.ckpt"

# 定義訓練所用參數(shù)
LEARNING_RATE = 0.0001
STEPS = 500
BATCH = 32
N_CLASSES = 5

# 這里指出了不需要從訓練好的模型中加載的參數(shù)危队,就是最后的自定義的全連接層
CHECKPOINT_EXCLUDE_SCOPES = 'Logits'
# 指定最后的全連接層為可訓練的參數(shù)
TRAINABLE_SCOPES = 'Logits'


# 加載所有需要從訓練好的模型加載的參數(shù)
def get_tuned_variables():
    ##不需要加載的范圍
    exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(",")]
    # 初始化需要加載的參數(shù)
    variables_to_restore = []

    # 遍歷模型中的所有參數(shù)
    for var in slim.get_model_variables():
        # 先指定為不需要移除
        excluded = False
        # 遍歷exclusions,如果在exclusions中钙畔,就指定為需要移除
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        # 如果遍歷完后還是不需要移除茫陆,就把參數(shù)加到列表里
        if not excluded:
            variables_to_restore.append(var)
    return variables_to_restore


# 獲取所有需要訓練的參數(shù)
def get_trainable_variables():
    # 同上
    scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(",")]
    variables_to_train = []
    # 枚舉所有需要訓練的參數(shù)的前綴,并找到這些前綴的所有參數(shù)
    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train


def main():
    # 加載數(shù)據(jù)
    processed_data = np.load(INPUT_DATA)
    training_images = processed_data[0]
    n_training_example = len(training_images)
    training_labels = processed_data[1]
    validation_images = processed_data[2]
    validation_labels = processed_data[3]
    testing_images = processed_data[4]
    testing_labels = processed_data[5]

    print("there is %d training examples, %d validation examples, %d testing examples" %
          (n_training_example, len(validation_labels), len(testing_labels)))

    # 定義數(shù)據(jù)格式
    images = tf.placeholder(tf.float32, [None, 300, 300, 3], name='input_images')
    labels = tf.placeholder(tf.int64, [None], name='labels')

    # 定義模型擎析,因為給出的只有參數(shù)簿盅,并沒有模型,這里需要指定模型的具體結(jié)構(gòu)
    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
        # logits就是最后預(yù)測值揍魂,images就是輸入數(shù)據(jù)挪鹏,指定num_classes=None是為了使resnet模型最后的輸出層禁用
        logits, _ = resnet_v1.resnet_v1_50(images, num_classes=None)

    #自定義的輸出層
    with tf.variable_scope("Logits"):
        #將原始模型的輸出數(shù)據(jù)去掉維度為2和3的維度,最后只剩維度1的batch數(shù)和維度4的300*300*3
        #也就是將原來的二三四維度全部壓縮到第四維度
        net = tf.squeeze(logits, axis=[1,2])
        #加入一層dropout層
        net = slim.dropout(net, keep_prob=0.5,scope='dropout_scope')
        #加入一層全連接層愉烙,指定最后輸出大小
        logits = slim.fully_connected(net, num_outputs=N_CLASSES, scope='fc')


    # 獲取需要訓練的變量
    trainable_variables = get_trainable_variables()

    # 定義損失讨盒,模型定義的時候已經(jīng)考慮了正則化了
    tf.losses.softmax_cross_entropy(tf.one_hot(labels, N_CLASSES), logits, weights=1.0)
    # 定義訓練過程
    train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(tf.losses.get_total_loss())

    # 定義測試和驗證過程
    with tf.name_scope('evaluation'):
        correct_prediction = tf.equal(tf.argmax(logits, 1), labels)
        evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # 定義加載模型的函數(shù),就是重新定義load_fn函數(shù)步责,從文件中獲取參數(shù)返顺,獲取指定的變量,忽略缺省值
    load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE, get_tuned_variables(), ignore_missing_vars=True)

    # 定義保存新的訓練好的模型的函數(shù)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # 初始化沒有加載進來的變量蔓肯,一定要在模型加載之前遂鹊,否則會將訓練好的參數(shù)重新賦值
        init = tf.global_variables_initializer()
        sess.run(init)

        # 加載訓練好的模型
        print("加載谷歌訓練好的模型...")
        load_fn(sess)

        start = 0
        end = BATCH
        for i in range(STEPS):
            # 訓練...
            sess.run(train_step, feed_dict={images: training_images[start:end],
                                            labels: training_labels[start:end]})
            # 間斷地保存模型,并在驗證集上驗證
            if i % 50 == 0 or i + 1 == STEPS:
                saver.save(sess, TRAIN_FILE, global_step=i)
                validation_accuracy = sess.run(evaluation_step, feed_dict={images: validation_images,
                                                                           labels: validation_labels})
                print("經(jīng)過%d次訓練后蔗包,在驗證集上的正確率為%.3f" % (i, validation_accuracy))

            # 更新起始和末尾
            start = end
            if start == n_training_example:
                start = 0
            end = start + BATCH
            if end > n_training_example:
                end = n_training_example

        # 訓練完了在測試集上測試正確率
        testing_accuracy = sess.run(evaluation_step, feed_dict={images: testing_images,
                                                                labels: testing_labels})
        print("最后在測試集上的正確率為%.3f" % testing_accuracy)


if __name__ == '__main__':
    main()

運行結(jié)果

result.png

結(jié)果分析

從結(jié)果中可以看到秉扑,利用已經(jīng)訓練好的復(fù)雜模型的參數(shù),再根據(jù)問題加上一層自定義的輸出層调限,可以在短時間內(nèi)利用較少的資源將模型遷移到不同的問題上舟陆,在200次訓練的時候就可以在這個問題上達到90%的正確率,經(jīng)過500次訓練后可以在測試集上達到接近95%的正確率耻矮,驗證了目前的主流卷積神經(jīng)網(wǎng)絡(luò)具有很好的普適性和遷移性秦躯。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市裆装,隨后出現(xiàn)的幾起案子踱承,更是在濱河造成了極大的恐慌,老刑警劉巖哨免,帶你破解...
    沈念sama閱讀 219,427評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件茎活,死亡現(xiàn)場離奇詭異,居然都是意外死亡琢唾,警方通過查閱死者的電腦和手機载荔,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,551評論 3 395
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來慧耍,“玉大人身辨,你說我怎么就攤上這事丐谋∩直蹋” “怎么了煌珊?”我有些...
    開封第一講書人閱讀 165,747評論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長泌豆。 經(jīng)常有香客問我定庵,道長,這世上最難降的妖魔是什么踪危? 我笑而不...
    開封第一講書人閱讀 58,939評論 1 295
  • 正文 為了忘掉前任蔬浙,我火速辦了婚禮,結(jié)果婚禮上贞远,老公的妹妹穿的比我還像新娘畴博。我一直安慰自己,他們只是感情好蓝仲,可當我...
    茶點故事閱讀 67,955評論 6 392
  • 文/花漫 我一把揭開白布俱病。 她就那樣靜靜地躺著,像睡著了一般袱结。 火紅的嫁衣襯著肌膚如雪亮隙。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,737評論 1 305
  • 那天垢夹,我揣著相機與錄音溢吻,去河邊找鬼。 笑死果元,一個胖子當著我的面吹牛促王,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播而晒,決...
    沈念sama閱讀 40,448評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼硼砰,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了欣硼?” 一聲冷哼從身側(cè)響起题翰,我...
    開封第一講書人閱讀 39,352評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎诈胜,沒想到半個月后豹障,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,834評論 1 317
  • 正文 獨居荒郊野嶺守林人離奇死亡焦匈,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,992評論 3 338
  • 正文 我和宋清朗相戀三年血公,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片缓熟。...
    茶點故事閱讀 40,133評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡累魔,死狀恐怖摔笤,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情垦写,我是刑警寧澤吕世,帶...
    沈念sama閱讀 35,815評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站梯投,受9級特大地震影響命辖,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜分蓖,卻給世界環(huán)境...
    茶點故事閱讀 41,477評論 3 331
  • 文/蒙蒙 一尔艇、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧么鹤,春花似錦终娃、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,022評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至迅皇,卻和暖如春昧辽,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背登颓。 一陣腳步聲響...
    開封第一講書人閱讀 33,147評論 1 272
  • 我被黑心中介騙來泰國打工搅荞, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人框咙。 一個月前我還...
    沈念sama閱讀 48,398評論 3 373
  • 正文 我出身青樓咕痛,卻偏偏與公主長得像,于是被迫代替她去往敵國和親喇嘱。 傳聞我的和親對象是個殘疾皇子茉贡,可洞房花燭夜當晚...
    茶點故事閱讀 45,077評論 2 355

推薦閱讀更多精彩內(nèi)容