CNN Transfer Learning in Practice with VGG16

Contents

  • Use cases
  • Prerequisites
  • Code example
  • Conclusion

Use cases

Suppose we need to classify images into a specific set of categories, for example:

  1. Classify images as cats, dogs, wolves, etc.
  2. Classify images as Mercedes-Benz, BMW, Audi
  3. ...

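Very few people train a network from scratch. Instead, they reuse the parameters of an already-trained network and adapt it to the new dataset; see transfer-learning: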

In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.
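The two options in the quote differ mainly in which layers get updated. The sketch below is a hypothetical illustration in plain Python (no TensorFlow): the scope names follow TF-slim's `vgg_16`, while `trainable_scopes` and `scopes_to_restore` are made-up helpers. The training script in this article combines both options: it first trains only fc8 with the pretrained layers frozen, then fine-tunes every layer.

```python
# Top-level layer scopes of TF-slim's vgg_16: five conv blocks and three fc layers
VGG16_SCOPES = ["vgg_16/conv%d" % i for i in range(1, 6)] + [
    "vgg_16/fc6",
    "vgg_16/fc7",
    "vgg_16/fc8",
]


def scopes_to_restore(exclude=("vgg_16/fc8",)):
    """Scopes whose weights are loaded from the ImageNet checkpoint (hypothetical helper).

    fc8 is excluded because its output size depends on the number of classes.
    """
    return [s for s in VGG16_SCOPES if not any(s.startswith(e) for e in exclude)]


def trainable_scopes(strategy):
    """Scopes updated by the optimizer under a (hypothetical) strategy name."""
    if strategy == "feature_extractor":
        # Pretrained layers stay frozen; only the new classifier is learned
        return ["vgg_16/fc8"]
    if strategy == "fine_tune":
        # Every layer is updated, usually with a small learning rate
        return list(VGG16_SCOPES)
    raise ValueError("unknown strategy: %r" % strategy)


print(scopes_to_restore())
print(trainable_scopes("feature_extractor"))
```

The prefix filter only mirrors the idea behind `get_variables_to_restore(exclude=[...])`; the real function works on TensorFlow variables, not strings.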

Prerequisites

  • How convolution works in a CNN
  • The TensorFlow API

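You can inspect the network structure by pasting the Caffe model definition (prototxt) into a visualization tool. The first few layers of VGG16 are shown below for reference.

[Figure: the first few layers of VGG16]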


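The closer a layer is to the input (the higher it sits in the figure), the simpler the patterns that activate it, so these layers transfer more easily to other tasks. If we take network parameters trained on ImageNet for 1000 classes, the first several layers can be treated as essentially already trained; we only replace the final fc layer with one whose output size equals the number of target classes.
For example, to tell cats from dogs, the final fc layer has two outputs, and this last layer has to be trained from scratch.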
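Replacing fc8 means only a small layer has to be learned from scratch. A quick count, assuming the standard VGG16 layout in which fc8 maps the 4096-dimensional fc7 output to one score per class (TF-slim implements it as a 1x1 convolution, which has the same number of weights); `dense_params` is a hypothetical helper:

```python
def dense_params(fan_in, fan_out):
    """Weights plus biases of a fully connected layer."""
    return fan_in * fan_out + fan_out


imagenet_fc8 = dense_params(4096, 1000)   # original 1000-class ImageNet head
cats_dogs_fc8 = dense_params(4096, 2)     # new head for the cat/dog task
print(imagenet_fc8)   # 4097000
print(cats_dogs_fc8)  # 8194
```

VGG16 has roughly 138 million parameters in total, so the freshly initialized head is a tiny fraction of the network; everything else starts from the ImageNet weights.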

Code example

Fine-tuning VGG16 with TensorFlow (TF-slim)
It helps to trace how the size of the convolution feature maps changes from layer to layer.
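A minimal plain-Python sketch of that trace, assuming the standard VGG16 configuration used by TF-slim's `vgg_16`: 3x3 convolutions with stride 1 and SAME padding (spatial size unchanged) and 2x2 max pooling with stride 2 (spatial size halved). `VGG16_BLOCKS` and `trace_vgg16` are names made up for this illustration.

```python
# (number of 3x3 conv layers, output channels) for each of VGG16's five blocks
VGG16_BLOCKS = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]


def trace_vgg16(height=224, width=224, channels=3):
    """Return [(layer name, (H, W, C)), ...] for VGG16's convolutional stack."""
    shapes = [("input", (height, width, channels))]
    for block, (num_convs, out_channels) in enumerate(VGG16_BLOCKS, start=1):
        for i in range(1, num_convs + 1):
            # 3x3 conv, stride 1, SAME padding: spatial size unchanged
            channels = out_channels
            shapes.append(("conv%d_%d" % (block, i), (height, width, channels)))
        # 2x2 max pool, stride 2: spatial size halved (rounded down)
        height, width = height // 2, width // 2
        shapes.append(("pool%d" % block, (height, width, channels)))
    return shapes


for name, shape in trace_vgg16():
    print(name, shape)

# 13 conv layers + fc6, fc7, fc8 = the 16 weight layers that give VGG16 its name
num_conv = sum(n for n, _ in VGG16_BLOCKS)
print(num_conv + 3)   # 16
h, w, c = trace_vgg16()[-1][1]
print(h * w * c)      # 25088 values from pool5 feed fc6
```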

For the data preprocessing involved, see neural-networks-2:

Mean subtraction is the most common form of preprocessing. It involves subtracting the mean across every individual feature in the data, and has the geometric interpretation of centering the cloud of data around the origin along every dimension. In numpy, this operation would be implemented as: X -= np.mean(X, axis = 0). With images specifically, for convenience it can be common to subtract a single value from all pixels (e.g. X -= np.mean(X)), or to do so separately across the three color channels.
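The per-channel variant mentioned at the end of the quote is what the training script below does with `VGG_MEAN`: it subtracts the fixed ImageNet channel means rather than means computed on the new dataset, since that is the preprocessing the pretrained weights were trained with. A plain-Python sketch on nested lists (the helper names are hypothetical):

```python
# Per-channel (R, G, B) ImageNet means, the same values as VGG_MEAN in the script
VGG_MEAN = [123.68, 116.78, 103.94]


def channel_means(images):
    """Mean of each color channel over a list of H x W x 3 images (nested lists)."""
    totals = [0.0, 0.0, 0.0]
    count = 0
    for image in images:
        for row in image:
            for pixel in row:
                for c in range(3):
                    totals[c] += pixel[c]
                count += 1
    return [t / count for t in totals]


def subtract_channel_mean(image, means=VGG_MEAN):
    """Center an H x W x 3 image by subtracting one mean value per channel."""
    return [[[pixel[c] - means[c] for c in range(3)] for pixel in row] for row in image]


toy = [[[10.0, 20.0, 30.0], [30.0, 40.0, 50.0]]]          # a 1 x 2 RGB image
print(channel_means([toy]))                                 # [20.0, 30.0, 40.0]
print(subtract_channel_mean(toy, channel_means([toy])))     # [[[-10.0, -10.0, -10.0], [10.0, 10.0, 10.0]]]
```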
The code is as follows:

"""
#訓練好的參數http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz
目錄結構
  train/
    貓/
      COCO_train2014_000000005785.jpg
      COCO_train2014_000000015870.jpg
    ??/
  val/
    貓/
    狗/
"""
import argparse
import os

import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets


parser = argparse.ArgumentParser()
# Training data directory
parser.add_argument('--train_dir', default='train')
# Validation data directory
parser.add_argument('--val_dir', default='val')
# Pretrained VGG16 checkpoint used to initialize the network
parser.add_argument('--model_path', default='vgg_16.ckpt', type=str)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--num_workers', default=4, type=int)
# Stage 1 trains only fc8 (num_epochs1, learning_rate1);
# stage 2 fine-tunes the whole network (num_epochs2, learning_rate2)
parser.add_argument('--num_epochs1', default=10, type=int)
parser.add_argument('--num_epochs2', default=10, type=int)
parser.add_argument('--learning_rate1', default=1e-3, type=float)
parser.add_argument('--learning_rate2', default=1e-5, type=float)
parser.add_argument('--dropout_keep_prob', default=0.5, type=float)
parser.add_argument('--weight_decay', default=5e-4, type=float)

# Per-channel (R, G, B) ImageNet mean pixel values, subtracted from every image
VGG_MEAN = [123.68, 116.78, 103.94]


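def list_images(directory):
    """Return (filenames, integer labels) for images laid out as directory/<label>/<file>."""
    # Only sub-directories are classes (skips stray files such as .DS_Store on OS X)
    labels = [d for d in os.listdir(directory)
              if os.path.isdir(os.path.join(directory, d))]
    files_and_labels = []
    for label in labels:
        for f in os.listdir(os.path.join(directory, label)):
            if f.startswith('.'):
                continue  # skip hidden files
            files_and_labels.append((os.path.join(directory, label, f), label))

    filenames, labels = zip(*files_and_labels)
    filenames = list(filenames)
    labels = list(labels)
    # sorted() rather than list(set(...)): set order is not guaranteed, so train/ and
    # val/ could otherwise map the same class name to different ids (assuming both
    # contain the same class folders)
    unique_labels = sorted(set(labels))

    label_to_int = {}
    for i, label in enumerate(unique_labels):
        label_to_int[label] = i

    labels = [label_to_int[l] for l in labels]

    return filenames, labels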


def check_accuracy(sess, correct_prediction, is_training, dataset_init_op):
    # Initialize the correct dataset
    sess.run(dataset_init_op)
    num_correct, num_samples = 0, 0
    while True:
        try:
            correct_pred = sess.run(correct_prediction, {is_training: False})
            num_correct += correct_pred.sum()
            num_samples += correct_pred.shape[0]
        except tf.errors.OutOfRangeError:
            break

    acc = float(num_correct) / num_samples
    return acc


def main(args):
    # Collect the training & validation file paths and their integer labels
    train_filenames, train_labels = list_images(args.train_dir)
    val_filenames, val_labels = list_images(args.val_dir)

    # One class per sub-directory of train_dir
    num_classes = len(set(train_labels))


    graph = tf.Graph()
    with graph.as_default():
        # Read and decode an image, then resize it so its shorter side is 256 pixels
        def _parse_function(filename, label):
            image_string = tf.read_file(filename)
            image_decoded = tf.image.decode_jpeg(image_string, channels=3)
            image = tf.cast(image_decoded, tf.float32)

            smallest_side = 256.0
            height, width = tf.shape(image)[0], tf.shape(image)[1]
            height = tf.to_float(height)
            width = tf.to_float(width)
            # Scale factor that maps the shorter side to 256, keeping the aspect ratio
            scale = tf.cond(tf.greater(height, width),
                            lambda: smallest_side / width,
                            lambda: smallest_side / height)
            new_height = tf.to_int32(height * scale)
            new_width = tf.to_int32(width * scale)

            resized_image = tf.image.resize_images(image, [new_height, new_width])
            return resized_image, label

        # Training preprocessing: random 224x224 crop, random horizontal flip, mean subtraction
        def training_preprocess(image, label):
            crop_image = tf.random_crop(image, [224, 224, 3])
            flip_image = tf.image.random_flip_left_right(crop_image)

            means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
            centered_image = flip_image - means

            return centered_image, label

        # Validation preprocessing: central 224x224 crop, mean subtraction
        def val_preprocess(image, label):
            crop_image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)

            means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
            centered_image = crop_image - means

            return centered_image, label

        # num_threads / output_buffer_size belong to the TF 1.2-era tf.contrib.data API;
        # the core tf.data API (TF 1.4+) uses num_parallel_calls and .prefetch() instead
        train_filenames = tf.constant(train_filenames)
        train_labels = tf.constant(train_labels)
        train_dataset = tf.contrib.data.Dataset.from_tensor_slices((train_filenames, train_labels))
        # Shuffle the (filename, label) pairs before decoding, so the shuffle buffer
        # holds short strings instead of up to 10000 decoded images
        train_dataset = train_dataset.shuffle(buffer_size=10000)
        train_dataset = train_dataset.map(_parse_function,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        train_dataset = train_dataset.map(training_preprocess,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        batched_train_dataset = train_dataset.batch(args.batch_size)


        val_filenames = tf.constant(val_filenames)
        val_labels = tf.constant(val_labels)
        val_dataset = tf.contrib.data.Dataset.from_tensor_slices((val_filenames, val_labels))
        val_dataset = val_dataset.map(_parse_function,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        val_dataset = val_dataset.map(val_preprocess,
            num_threads=args.num_workers, output_buffer_size=args.batch_size)
        batched_val_dataset = val_dataset.batch(args.batch_size)


        # One reinitializable iterator yields (images, labels) batches from either dataset
        iterator = tf.contrib.data.Iterator.from_structure(batched_train_dataset.output_types,
                                                           batched_train_dataset.output_shapes)
        images, labels = iterator.get_next()

        # Running one of these ops points the iterator at the training or validation set
        train_init_op = iterator.make_initializer(batched_train_dataset)
        val_init_op = iterator.make_initializer(batched_val_dataset)

        # Fed to vgg_16: True while training (dropout active), False for evaluation
        is_training = tf.placeholder(tf.bool)

        vgg = tf.contrib.slim.nets.vgg
        with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=args.weight_decay)):
            # VGG16 as packaged in TF-slim, with the output layer sized to num_classes
            logits, _ = vgg.vgg_16(images, num_classes=num_classes, is_training=is_training,
                                   dropout_keep_prob=args.dropout_keep_prob)

        model_path = args.model_path
        assert os.path.isfile(model_path), 'checkpoint not found: %s' % model_path

        # Restore every pretrained variable except fc8 (i.e. conv1 through fc7)
        variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['vgg_16/fc8'])
        init_fn = tf.contrib.framework.assign_from_checkpoint_fn(model_path, variables_to_restore)

        # fc8 now has num_classes outputs, so it gets its own initializer
        fc8_variables = tf.contrib.framework.get_variables('vgg_16/fc8')
        fc8_init = tf.variables_initializer(fc8_variables)

        # sparse_softmax_cross_entropy adds the loss to the tf.GraphKeys.LOSSES collection;
        # get_total_loss() sums it with the weight-decay regularization losses
        tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        loss = tf.losses.get_total_loss()

        # Stage 1: train only the fc8 parameters
        fc8_optimizer = tf.train.GradientDescentOptimizer(args.learning_rate1)
        fc8_train_op = fc8_optimizer.minimize(loss, var_list=fc8_variables)

        # Stage 2: fine-tune the whole network with a smaller learning rate
        full_optimizer = tf.train.GradientDescentOptimizer(args.learning_rate2)
        full_train_op = full_optimizer.minimize(loss)

        # Evaluation
        prediction = tf.to_int32(tf.argmax(logits, 1))
        correct_prediction = tf.equal(prediction, labels)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        tf.get_default_graph().finalize()

    with tf.Session(graph=graph) as sess:
        # Load the pretrained conv1-fc7 parameters
        init_fn(sess)
        # Randomly initialize the new fc8 parameters
        sess.run(fc8_init)

        # Stage 1: train fc8 only
        for epoch in range(args.num_epochs1):
            print('Starting epoch %d / %d' % (epoch + 1, args.num_epochs1))
            sess.run(train_init_op)
            while True:
                try:
                    # Images and labels come from the iterator; only is_training is fed
                    _ = sess.run(fc8_train_op, {is_training: True})
                except tf.errors.OutOfRangeError:
                    break

            train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
            val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
            print('Train accuracy: %f' % train_acc)
            print('Val accuracy: %f\n' % val_acc)

        # Stage 2: fine-tune the whole network
        for epoch in range(args.num_epochs2):
            print('Starting epoch %d / %d' % (epoch + 1, args.num_epochs2))
            sess.run(train_init_op)
            while True:
                try:
                    _ = sess.run(full_train_op, {is_training: True})
                except tf.errors.OutOfRangeError:
                    break

            train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
            val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
            print('Train accuracy: %f' % train_acc)
            print('Val accuracy: %f\n' % val_acc)


if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
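The size arithmetic in `_parse_function` and `val_preprocess` can be checked without TensorFlow. A plain-Python sketch; `resize_shorter_side` and `center_crop_offsets` are hypothetical helpers that only mirror what the graph ops above compute (the central crop corresponds to `resize_image_with_crop_or_pad` when the image is larger than 224 on both sides).

```python
def resize_shorter_side(height, width, smallest_side=256.0):
    """Mirror the size logic of _parse_function (hypothetical helper)."""
    # Scale so that the shorter side becomes smallest_side; aspect ratio is kept
    if height > width:
        scale = smallest_side / width
    else:
        scale = smallest_side / height
    # int() truncates toward zero, like tf.to_int32 for positive values;
    # float rounding can give 255 instead of 256, which is still >= 224
    return int(height * scale), int(width * scale)


def center_crop_offsets(height, width, crop=224):
    """Top-left corner of a central crop x crop window (assumes both sides >= crop)."""
    return (height - crop) // 2, (width - crop) // 2


h, w = resize_shorter_side(512, 1024)
print(h, w)                        # 256 512
print(center_crop_offsets(h, w))   # (16, 144)
```

Resizing to 256 and then cropping 224 leaves a margin for random crops during training (a cheap form of augmentation), while validation always evaluates the same central window.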

VGG16 is packaged in TensorFlow's slim library (tf.contrib.slim.nets.vgg). Its function signature is:

def vgg_16(inputs,
           num_classes=1000,
           is_training=True,
           dropout_keep_prob=0.5,
           spatial_squeeze=True,
           scope='vgg_16'):
  """Oxford Net VGG 16-Layers version D Example.
  Note: All the fully_connected layers have been transformed to conv2d layers.
        To use in classification mode, resize input to 224x224.
  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    num_classes: number of predicted classes.
    is_training: whether or not the model is being trained.
    dropout_keep_prob: the probability that activations are kept in the dropout
      layers during training.
    spatial_squeeze: whether or not should squeeze the spatial dimensions of the
      outputs. Useful to remove unnecessary dimensions for classification.
    scope: Optional scope for the variables.
  Returns:
    the last op containing the log predictions and end_points dict.
  """
  # (function body omitted)
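The docstring's note that the fully connected layers are conv2d layers explains why the input should be 224x224. A plain-Python sketch, assuming (as in TF-slim's implementation) that fc6 is a 7x7 convolution with VALID padding and fc7/fc8 are 1x1 convolutions: only a 224x224 input yields a 1x1 score map, which `spatial_squeeze=True` turns into `[batch_size, num_classes]`; a larger input yields a larger map, and squeezing axes 1 and 2 would then fail. `vgg16_score_map_size` is a hypothetical helper.

```python
def vgg16_score_map_size(input_size):
    """Spatial size of vgg_16's fc8 output for a square input (sketch).

    Assumes five 2x2/stride-2 max pools (rounding down) and fc6 implemented as a
    7x7 convolution with VALID padding, followed by 1x1 convolutions fc7 and fc8.
    """
    size = input_size
    for _ in range(5):
        size //= 2                 # pool1 ... pool5
    if size < 7:
        raise ValueError("input too small for the 7x7 fc6 convolution")
    return size - 7 + 1            # fc6 (VALID); fc7/fc8 keep the size


for s in (224, 256, 320):
    print(s, vgg16_score_map_size(s))
```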

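This example does not require a GPU; it runs on OS X with the CPU alone, although fine-tuning all VGG16 layers on a CPU is slow.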


Conclusion

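Engineers usually do not design new network architectures, and rarely make major changes to an existing one. Still, understanding the architecture and how the loss is optimized makes it much easier to apply transfer learning to a specific scenario.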
