預(yù)訓(xùn)練模型與遷移學(xué)習(xí)

一鞭缭、什么是遷移學(xué)習(xí)

聰明人都喜歡"偷懶",因?yàn)檫@樣的偷懶能幫我們節(jié)省大量的時(shí)間提高效率魏颓。有一種偷懶是 "站在巨人的肩膀上"岭辣,也就是表示要善于學(xué)習(xí)先輩的經(jīng)驗(yàn)。這句話放在機(jī)器學(xué)習(xí)中就是指的遷移學(xué)習(xí)甸饱。
遷移學(xué)習(xí)是一種機(jī)器學(xué)習(xí)技術(shù)沦童,顧名思義就是指將知識(shí)從一個(gè)領(lǐng)域遷移到另一個(gè)領(lǐng)域的能力。
我們知道叹话,神經(jīng)網(wǎng)絡(luò)需要用數(shù)據(jù)來訓(xùn)練偷遗,它從數(shù)據(jù)中獲得信息,進(jìn)而把它們轉(zhuǎn)換成相應(yīng)的權(quán)重驼壶。這些權(quán)重能夠被提取出來氏豌,遷移到其他的神經(jīng)網(wǎng)絡(luò)中,我們"遷移"了這些學(xué)來的特征热凹,就不需要從零開始訓(xùn)練一個(gè)神經(jīng)網(wǎng)絡(luò)了 泵喘。

遷移學(xué)習(xí)的價(jià)值
復(fù)用現(xiàn)有知識(shí)域數(shù)據(jù),已有的大量工作不至于完全丟棄般妙;
不需要再去花費(fèi)巨大代價(jià)去重新采集和標(biāo)定龐大的新數(shù)據(jù)集涣旨,也有可能數(shù)據(jù)根本無法獲取股冗;
對(duì)于快速出現(xiàn)的新領(lǐng)域霹陡,能夠快速遷移和應(yīng)用,體現(xiàn)時(shí)效性優(yōu)勢(shì)止状。

二烹棉、遷移學(xué)習(xí)的載體:預(yù)訓(xùn)練模型

在計(jì)算機(jī)視覺領(lǐng)域中,遷移學(xué)習(xí)通常是通過使用預(yù)訓(xùn)練模型來體現(xiàn)的怯疤。預(yù)訓(xùn)練模型是在大型基準(zhǔn)數(shù)據(jù)集上訓(xùn)練的模型浆洗,用于解決相似的問題。由于訓(xùn)練這種模型的計(jì)算成本較高集峦,因此伏社,導(dǎo)入已發(fā)布的成果并使用相應(yīng)的模型是比較常見的做法。

1塔淤、keras.Application

Kera的應(yīng)用模塊Application提供了帶有預(yù)訓(xùn)練權(quán)重的Keras模型摘昌,這些模型可以用來進(jìn)行預(yù)測(cè)、特征提取和finetune高蜂。
目前聪黎,Keras 包含有 5 個(gè)預(yù)訓(xùn)練模型,分別為:VGG16备恤,VGG19稿饰,ResNet50锦秒,InceptionV3,Xception喉镰,MobileNet
(1)VGG16/ VGG19
Keras 導(dǎo)入 VGG16 和 VGG19 模型及默認(rèn)參數(shù)如下:

from keras.applications import vgg16
from keras.applications import vgg19
vgg16.VGG16(include_top=True, weights='imagenet', input_tensor=None, input_shape=None, pooling=None, classes=1000)
vgg19.VGG19(include_top=True, weights='imagenet', input_tensor=None, input_shape=None, pooling=None, classes=1000)

(2)ResNet50
Keras 導(dǎo)入 ResNet50 模型及默認(rèn)參數(shù)如下:

from keras.applications import resnet50
resnet50.ResNet50(include_top=True, weights='imagenet', input_tensor=None, input_shape=None, pooling=None, classes=1000)

(3)InceptionV3
Keras 導(dǎo)入 InceptionV3 模型及默認(rèn)參數(shù)如下:

from keras.applications import inception_v3
inception_v3.InceptionV3(include_top=True, weights='imagenet', input_tensor=None, input_shape=None, pooling=None, classes=1000)

(4)Xception
Keras 導(dǎo)入 Xception 模型及默認(rèn)參數(shù)如下

from keras.applications import xception
xception.Xception(include_top=True, weights='imagenet', input_tensor=None, input_shape=None, pooling=None, classes=1000)

(5)MobileNet
Keras 導(dǎo)入 MobileNet 模型及默認(rèn)參數(shù)如下:

from keras.applications import mobilenet
mobilenet.MobileNet(input_shape=None, alpha=1.0, depth_multiplier=1, dropout=1e-3, include_top=True, weights='imagenet', input_tensor=None, pooling=None, classes=1000)

舉例:使用預(yù)訓(xùn)練模型輸出圖像分類預(yù)測(cè)

import numpy as np
from keras.preprocessing import image
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from keras.applications.inception_v3 import decode_predictions

# 新建模型旅择,此處實(shí)際上是導(dǎo)入預(yù)訓(xùn)練模型
model = InceptionV3()
model.summary()

# 按照 InceptionV3 模型的默認(rèn)輸入尺寸,載入 demo1 圖像
img = image.load_img('demo1.jpg', target_size=(299, 299))

# 提取特征
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

# 預(yù)測(cè)并輸出概率最高的三個(gè)類別
preds = model.predict(x)
print('Predicted:', decode_predictions(preds, top=3)[0])

官方文檔

2侣姆、自己下載預(yù)訓(xùn)練權(quán)重

VGG16:
WEIGHTS_PATH = ‘https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5
WEIGHTS_PATH_NO_TOP = ‘https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
VGG19:
WEIGHTS_PATH = ‘https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels.h5
WEIGHTS_PATH_NO_TOP = ‘https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
RESNET50:
WEIGHTS_PATH = ‘https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5
WEIGHTS_PATH_NO_TOP = ‘https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
INCEPTIONS_V3:
WEIGHTS_PATH = ‘https://github.com/fchollet/deep-learning-models/releases/download/v0.5/inception_v3_weights_tf_dim_ordering_tf_kernels.h5
WEIGHTS_PATH_NO_TOP = ‘https://github.com/fchollet/deep-learning-models/releases/download/v0.5/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
XCEPTION:
WEIGHTS_PATH = ‘https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels.h5
WEIGHTS_PATH_NO_TOP = ‘https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels_notop.h5

三砌左、代碼實(shí)現(xiàn)

現(xiàn)在應(yīng)用遷移學(xué)習(xí)來實(shí)現(xiàn)一個(gè)特定的圖像分類任務(wù)

# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
import os, shutil
FILE_DIR = os.path.dirname(os.path.abspath(__file__))
from tqdm import tqdm
from sklearn.datasets import load_files
# 載入畫圖所需要的庫 matplotlib
import matplotlib.pyplot as plt 
# 導(dǎo)入karas神經(jīng)網(wǎng)絡(luò)框架
import keras
from keras.optimizers import Adam
from keras.models import Model
from keras.utils import np_utils,plot_model
from keras.preprocessing import image
from keras.applications import inception_v3
from keras.layers import Conv2D, GlobalAveragePooling2D, Activation, Dropout, Dense
from keras.callbacks import ModelCheckpoint,EarlyStopping,ReduceLROnPlateau,TensorBoard

#========================================================
#  全局參數(shù)
#========================================================

# 訓(xùn)練參數(shù)
num_epochs = 10
batch_size = 32

# 模型參數(shù)存儲(chǔ)
weight_url = 'D://saved_models/V3.hdf5'
best_weight_url = 'D://saved_models/V3_best.hdf5'

#========================================================
#  文件準(zhǔn)備
#========================================================

def image_preparation(original_dir, base_dir, labels):
    '''
    圖像分類文件準(zhǔn)備, 將文件復(fù)制到訓(xùn)練\驗(yàn)證\測(cè)試集目錄
    INPUT  -> 原始數(shù)據(jù)集地址, 數(shù)據(jù)集存放地址, 分類列表
    '''
    # 定義文件地址
    train_dir = os.path.join(base_dir, 'train')
    if not os.path.exists(train_dir):
        os.mkdir(train_dir)
    validation_dir = os.path.join(base_dir, 'validation')
    if not os.path.exists(validation_dir):
        os.mkdir(validation_dir)
    test_dir = os.path.join(base_dir, 'test')
    if not os.path.exists(test_dir):
        os.mkdir(test_dir)
    
    names = locals()
    # 圖片遷移
    for label in labels:
        names["train_"+str(label)+"dir"] =  os.path.join(train_dir, str(label))
        if not os.path.exists(names["train_"+str(label)+"dir"]):
            os.mkdir(names["train_"+str(label)+"dir"])
        names["validation_"+str(label)+"dir"] =  os.path.join(validation_dir, str(label))
        if not os.path.exists(names["validation_"+str(label)+"dir"]):
            os.mkdir(names["validation_"+str(label)+"dir"])
        names["test_"+str(label)+"dir"] =  os.path.join(test_dir, str(label))
        if not os.path.exists(names["test_"+str(label)+"dir"]):
            os.mkdir(names["test_"+str(label)+"dir"])

        fnames = [str(label)+'.{}.jpg'.format(i) for i in range(1000)]
        for fname in fnames:
            src = os.path.join(original_dir, fname)
            dst = os.path.join(names["train_"+str(label)+"dir"], fname)
            shutil.copyfile(src, dst)
        fnames = [str(label)+'.{}.jpg'.format(i) for i in range(1000, 1500)]
        for fname in fnames:
            src = os.path.join(original_dir, fname)
            dst = os.path.join(names["validation_"+str(label)+"dir"], fname)
            shutil.copyfile(src, dst)
        fnames = [str(label)+'.{}.jpg'.format(i) for i in range(1500, 2000)]
        for fname in fnames:
            src = os.path.join(original_dir, fname)
            dst = os.path.join(names["test_"+str(label)+"dir"], fname)
            shutil.copyfile(src, dst)
        print('total train '+str(label)+' images:', len(os.listdir(names["train_"+str(label)+"dir"])))
        print('total validation '+str(label)+' images:', len(os.listdir(names["validation_"+str(label)+"dir"])))
        print('total test '+str(label)+' images:', len(os.listdir(names["test_"+str(label)+"dir"])))

# 將數(shù)據(jù)分別存到各個(gè)文件夾
originial_dataset_dir = 'D:\download\kaggle_original_data'
base_dir = 'D:\cats_and_dogs'
if not os.path.exists(base_dir):
    os.mkdir(base_dir)
image_preparation(originial_dataset_dir, base_dir, ['cat','dog'])

# 分類數(shù)
n_classes = 0
for fn in os.listdir(os.path.join(base_dir, 'train')):
    n_classes += 1
#========================================================
#  圖像加載
#========================================================
def load_dataset(path):
    data = load_files(path)
    data_files = np.array(data['filenames'])
    data_targets = np_utils.to_categorical(np.array(data['target']), n_classes)
    return data_files, data_targets

train_files, train_targets = load_dataset(os.path.join(base_dir, 'train'))
valid_files, valid_targets = load_dataset(os.path.join(base_dir, 'validation'))

#========================================================
#  圖像預(yù)處理
#========================================================
def path_to_tensor(img_path):
    '''單個(gè)圖片格式處理'''
    img = image.load_img(img_path, target_size=(299, 299))
    x = image.img_to_array(img)
    # 將3維張量轉(zhuǎn)化為格式為(1, 299, 299, 3)的4維張量并進(jìn)行歸一化到0-1
    return np.expand_dims(x, axis=0).astype('float32')/255.0

def paths_to_tensor(img_paths):
    '''批量圖片格式處理'''
    list_of_tensors = [path_to_tensor(img_path) for img_path in tqdm(img_paths)]
    return np.vstack(list_of_tensors)

train_tensors = paths_to_tensor(train_files)
valid_tensors = paths_to_tensor(valid_files)

#========================================================
#  模型聲明
#========================================================
def InceptionV3_model(lr=0.005):
    '''構(gòu)造基于InceptionV3的遷移學(xué)習(xí)模型'''
    base_model = inception_v3.InceptionV3(weights='imagenet', include_top=False)

    # 凍結(jié)base_model所有層,這樣就可以正確獲得bottleneck特征
    for layer in base_model.layers:
        layer.trainable = False

    x = base_model.output
    # 重新配置全連接層,添加自己的全鏈接分類層
    x = GlobalAveragePooling2D(name='average_pooling2d_new')(x)
    x = Dense(1024, activation='relu', name='dense_new')(x)
    predictions = Dense(n_classes, activation='softmax', name='dense_output')(x)
    # 創(chuàng)建最終模型
    model = Model(inputs=base_model.input, outputs=predictions)

    # 模型編譯
    adam = Adam(lr=lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
    model.compile(loss='categorical_crossentropy', 
                  optimizer=adam, 
                  metrics=['accuracy'])
    
    # model.summary()
    # plot_model(model, to_file='V3_model.png')
    return model

# 實(shí)例化模型
V3_model = InceptionV3_model()
#========================================================
#  模型訓(xùn)練
#========================================================
def train(X, Y, X_val, Y_val, model):
    # 載入已保存的權(quán)重, 繼續(xù)訓(xùn)練
    if os.path.exists(weight_url):
        model.load_weights(weight_url)

    # 訓(xùn)練過程中的回調(diào)函數(shù)(檢查點(diǎn)\早期停止\動(dòng)態(tài)學(xué)習(xí)率\訓(xùn)練日志)
    Checkpoint = ModelCheckpoint(filepath=best_weight_url,
                                 save_best_only=True,
                                 verbose=1)
    EarlyStop = EarlyStopping(monitor='val_loss',
                              patience=5,
                              mode='auto',
                              verbose=1)
    lrate = ReduceLROnPlateau(monitor='val_loss',  
                              factor=0.1,  # 每次減少學(xué)習(xí)率的因子铺敌,學(xué)習(xí)率將以lr = lr*factor的形式被減少 
                              patience=3,  # 當(dāng)patience個(gè)epoch過去而模型性能不提升時(shí)汇歹,學(xué)習(xí)率減少的動(dòng)作會(huì)被觸發(fā)
                              mode='auto', 
                              min_delta=0.0001, # 閾值,用來確定是否進(jìn)入檢測(cè)值的“平原區(qū)” 
                              cooldown=0, # 學(xué)習(xí)率減少后偿凭,會(huì)經(jīng)過cooldown個(gè)epoch才重新進(jìn)行正常操作
                              min_lr=0,  # 學(xué)習(xí)率的下限 
                              verbose=1)
    tb = TensorBoard(log_dir=FILE_DIR,  # log 目錄
                     histogram_freq=1,  # 按照何等頻率(epoch)來計(jì)算直方圖产弹,0為不計(jì)算
                     batch_size=batch_size,  # 用多大量的數(shù)據(jù)計(jì)算直方圖
                     write_graph=True,       # 是否存儲(chǔ)網(wǎng)絡(luò)結(jié)構(gòu)圖
                     write_grads=False,      # 是否可視化梯度直方圖
                     write_images=False,     # 是否可視化參數(shù)
                     embeddings_freq=0,
                     embeddings_layer_names=None,
                     embeddings_metadata=None)
                                
    history_ft = model.fit(X, Y,
                           validation_data = (X_val, Y_val),
                           # validation_split = 0.2,
                           epochs=num_epochs,
                           batch_size=batch_size,
                           # steps_per_epoch=None, # steps_per_epoch=10,則就是將一個(gè)epoch分為10份弯囊,不能和batch_size共同使用
                           # validation_steps=None, # 當(dāng)steps_per_epoch被啟用的時(shí)候才有用痰哨,驗(yàn)證集的batch_size
                           callbacks=[Checkpoint, EarlyStop, lrate, tb], 
                           verbose=1
                           )

    # 參數(shù)保存,留待下次繼續(xù)訓(xùn)練
    model.save_weights(weight_url, overwrite=True)
    return history_ft

def plot_training(data):
    '''繪制模型正確率曲線和損失曲線'''
    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    # 正確率曲線
    plt.figure()
    plt.title('Train and valid accuracy')
    plt.plot(data.epoch,acc,label="train_acc")
    plt.plot(data.epoch,val_acc,label="val_acc")
    plt.scatter(data.epoch,data.history['acc'],marker='*')
    plt.scatter(data.epoch,data.history['val_acc'],marker='*')
    plt.legend()
    plt.show()
    # 損失曲線
    plt.figure()
    plt.title('Train and valid loss')
    plt.plot(data.epoch,loss,label="train_loss")
    plt.plot(data.epoch,val_loss,label="val_loss")
    plt.scatter(data.epoch,data.history['loss'],marker='*')
    plt.scatter(data.epoch,data.history['val_loss'],marker='*')
    plt.legend()
    plt.show()

history = train(X=train_tensors, Y=train_targets,  X_val=valid_tensors, Y_val=valid_targets, model=V3_model)
plot_training(history)
#========================================================
#  模型預(yù)測(cè)
#========================================================
def img_predict(model, img_path):
    '''判斷單張圖片'''
    prediction = model.predict(path_to_tensor(img_path))
    index = np.argmax(prediction)
    return index

# 加載最佳的模型參數(shù)
V3_model.load_weights(best_weight_url)
img_predict(V3_model, 'D://cats_and_dogs/test/cat/cat.8.jpg')
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市匾嘱,隨后出現(xiàn)的幾起案子斤斧,更是在濱河造成了極大的恐慌,老刑警劉巖霎烙,帶你破解...
    沈念sama閱讀 212,383評(píng)論 6 493
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件撬讽,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡悬垃,警方通過查閱死者的電腦和手機(jī)游昼,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,522評(píng)論 3 385
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來尝蠕,“玉大人烘豌,你說我怎么就攤上這事】幢耍” “怎么了廊佩?”我有些...
    開封第一講書人閱讀 157,852評(píng)論 0 348
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)靖榕。 經(jīng)常有香客問我标锄,道長(zhǎng),這世上最難降的妖魔是什么序矩? 我笑而不...
    開封第一講書人閱讀 56,621評(píng)論 1 284
  • 正文 為了忘掉前任鸯绿,我火速辦了婚禮,結(jié)果婚禮上簸淀,老公的妹妹穿的比我還像新娘瓶蝴。我一直安慰自己,他們只是感情好租幕,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,741評(píng)論 6 386
  • 文/花漫 我一把揭開白布舷手。 她就那樣靜靜地躺著,像睡著了一般劲绪。 火紅的嫁衣襯著肌膚如雪男窟。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 49,929評(píng)論 1 290
  • 那天贾富,我揣著相機(jī)與錄音歉眷,去河邊找鬼。 笑死颤枪,一個(gè)胖子當(dāng)著我的面吹牛汗捡,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播畏纲,決...
    沈念sama閱讀 39,076評(píng)論 3 410
  • 文/蒼蘭香墨 我猛地睜開眼扇住,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了盗胀?” 一聲冷哼從身側(cè)響起艘蹋,我...
    開封第一講書人閱讀 37,803評(píng)論 0 268
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎票灰,沒想到半個(gè)月后女阀,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,265評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡屑迂,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,582評(píng)論 2 327
  • 正文 我和宋清朗相戀三年强品,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片屈糊。...
    茶點(diǎn)故事閱讀 38,716評(píng)論 1 341
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡的榛,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出逻锐,到底是詐尸還是另有隱情夫晌,我是刑警寧澤,帶...
    沈念sama閱讀 34,395評(píng)論 4 333
  • 正文 年R本政府宣布昧诱,位于F島的核電站晓淀,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏盏档。R本人自食惡果不足惜凶掰,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 40,039評(píng)論 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧懦窘,春花似錦前翎、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,798評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至午衰,卻和暖如春立宜,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背臊岸。 一陣腳步聲響...
    開封第一講書人閱讀 32,027評(píng)論 1 266
  • 我被黑心中介騙來泰國(guó)打工橙数, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人帅戒。 一個(gè)月前我還...
    沈念sama閱讀 46,488評(píng)論 2 361
  • 正文 我出身青樓灯帮,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親蜘澜。 傳聞我的和親對(duì)象是個(gè)殘疾皇子施流,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,612評(píng)論 2 350