基于tensorflow_slim模型調(diào)參的flower102鮮花分類過程
實驗軟件環(huán)境如下
windows10
tensorflow-gpu 1.11
python3.5
1.數(shù)據(jù)分析工作
1.1數(shù)據(jù)介紹
該數(shù)據(jù)集由102類產(chǎn)自英國的花卉組成。每類由40-258張圖片組成生百。具體示例如下圖所示:
![png] (https://my-picture-bed-1256685253.cos.ap-shanghai.myqcloud.com/201812/%E8%8A%B1%E7%A4%BA%E4%BE%8B.PNG)
下載地址為:http://www.robots.ox.ac.uk/~vgg/data/flowers/102/
其中有兩個mat文件標(biāo)記了整個數(shù)據(jù)集的label,具體結(jié)構(gòu)如下:
-imagelabels.mat
總共有8189列繁扎,每列上的數(shù)字代表類別號籍滴。
-setid.mat
-trnid字段:總共有1020列受楼,每10列為一類花卉的圖片转唉,每列上的數(shù)字代表圖片號命迈。
-valid字段:總共有1020列贩绕,每10列為一類花卉的圖片火的,每列上的數(shù)字代表圖片號。
-tstid字段:總共有6149列淑倾,每一類花卉的列數(shù)不定馏鹤,每列上的數(shù)字代表圖片號。
2.數(shù)據(jù)預(yù)處理
tensorflow-slim 程序包是由谷歌公司提供的圖像分類工具包,其中預(yù)訓(xùn)練的比較流行的圖像分類的神經(jīng)網(wǎng)絡(luò),比如VGG16,VGG19,InceptionV1~V4,殘差網(wǎng)絡(luò)等等,實驗中我們使用了比較新的InceptionV3模型進行訓(xùn)練.
2.1數(shù)據(jù)集圖像格式處理
對于InceptionV3網(wǎng)絡(luò),要求輸入的圖片分辨率保持一致,由于數(shù)據(jù)集中的圖片大小不一,所以需要修改分辨率后保存,這里將圖片統(tǒng)一保存為256*256的jpg格式,具體代碼如下:
#flower_dir[tid]為原圖片的絕對地址
img=Image.open(flower_dir[tid])
img = img.resize((256, 256),Image.ANTIALIAS)
#despath為生成標(biāo)準(zhǔn)圖片的保存地址
img.save(despath)
2.2數(shù)據(jù)集存儲路徑處理
在slim框架中,對于數(shù)據(jù)集的存儲路徑以及存儲格式是由要求的,具體示例如下:
data_prepare/
pic/
train/
class1/
img1
img2
...
class2
img1
img2
...
validation/
class1/
img1
img2
...
class2
img1
img2
...
所以需要根據(jù)數(shù)據(jù)集提供的標(biāo)簽規(guī)整圖片的路徑.總體代碼如下:
import scipy.io
import numpy as np
import os
from PIL import Image
import shutil
########取出 imagelabels 文件的值############
imagelabels_path='I:\\dataSet\\imagelabels.mat'
labels = scipy.io.loadmat(imagelabels_path)
labels = np.array(labels['labels'][0])-1
######## 取出 flower dataset: train test valid 數(shù)據(jù)id標(biāo)識 ########
setid_path='I:\\dataSet\\setid.mat'
setid = scipy.io.loadmat(setid_path)
validation = np.array(setid['valid'][0]) - 1
np.random.shuffle(validation)
train = np.array(setid['trnid'][0]) - 1
np.random.shuffle(train)
test=np.array(setid['tstid'][0]) -1
np.random.shuffle(test)
######## flower data path 數(shù)據(jù)保存路徑 ########
flower_dir = list()
######## flower data dirs 生成保存數(shù)據(jù)的絕對路徑和名稱 ########
for img in os.listdir("I:\\dataSet\\102flowers"):
######## flower data ########
flower_dir.append(os.path.join("I:\\dataSet\\102flowers", img))
######## flower data dirs sort 數(shù)據(jù)的絕對路徑和名稱排序 從小到大 ########
flower_dir.sort()
#print(flower_dir)
#####生成flower data train的分類數(shù)據(jù) #######
des_folder_train="I:\\dataSet\\prepare_pic\\train"
for tid in train:
######## open image and get label ########
img=Image.open(flower_dir[tid])
#print(flower_dir[tid])
######## resize img #######
img = img.resize((256, 256),Image.ANTIALIAS)
lable=labels[tid]
#print(lable)
path=flower_dir[tid]
#print("path:",path)
base_path=os.path.basename(path)
#print("base_path:",base_path)
######類別目錄路徑
classes="c"+str(lable)
class_path=os.path.join(des_folder_train,classes)
if not os.path.exists(class_path):
os.makedirs(class_path)
#print("class_path:",class_path)
despath=os.path.join(class_path,base_path)
#print("despath:",despath)
img.save(despath)
#####生成flower data validation的分類數(shù)據(jù) #######
des_folder_validation="I:\\dataSet\\prepare_pic\\validation"
for tid in validation:
######## open image and get label ########
img=Image.open(flower_dir[tid])
#print(flower_dir[tid])
img = img.resize((256, 256),Image.ANTIALIAS)
lable=labels[tid]
#print(lable)
path=flower_dir[tid]
print("path:",path)
base_path=os.path.basename(path)
print("base_path:",base_path)
classes="c"+str(lable)
class_path=os.path.join(des_folder_validation,classes)
# 判斷結(jié)果
if not os.path.exists(class_path):
os.makedirs(class_path)
print("class_path:",class_path)
despath=os.path.join(class_path,base_path)
print("despath:",despath)
img.save(despath)
#####生成flower data test的分類數(shù)據(jù) #######
des_folder_test="I:\\dataSet\\prepare_pic\\test"
for tid in test:
######## open image and get label ########
img=Image.open(flower_dir[tid])
#print(flower_dir[tid])
img = img.resize((256, 256),Image.ANTIALIAS)
lable=labels[tid]
#print(lable)
path=flower_dir[tid]
print("path:",path)
base_path=os.path.basename(path)
print("base_path:",base_path)
classes="c"+str(lable)
class_path=os.path.join(des_folder_test,classes)
# 判斷結(jié)果
if not os.path.exists(class_path):
os.makedirs(class_path)
print("class_path:",class_path)
despath=os.path.join(class_path,base_path)
print("despath:",despath)
img.save(despath)
數(shù)據(jù)生成之后,共生成三個目錄,分別為train,test,validation如下目錄格式:
文件數(shù)量如下所示:
train:
102類:1020個圖片
validation:
102類:1020幅圖片
test:
102類:6149幅圖片
標(biāo)準(zhǔn)圖片已經(jīng)路徑的處理工作完成之后,需要使用slim提供的腳本將圖片轉(zhuǎn)換為tfrecord格式,該格式作為tensorflow高速讀取的二進制文件,數(shù)據(jù)的高速傳輸提供了接口,具體使用的教程可以參考該博主.
在實驗過程中,我們使用預(yù)先編譯好的腳本文件data_convert.py對圖片進行轉(zhuǎn)換,進入到該文件所在目錄,使用如下命令:
python data_convert.py -t I:\\prepare_data\\prepare_pic #生成圖片根目錄路徑
--train-shards 5\ #切成5兩個tfrecord train文件
--validation-shards 5\ #切成5兩個tfrecord train文件
--num-threads 5\ #啟動五個線程運算
--dataset-name flower102 #文件名頭
運行完成后生成以下文件:
3.模型選擇
4.模型微調(diào)
4.1 拷貝文件到數(shù)據(jù)集目錄
- 首先將生成的tfrecord文件以及l(fā)abel.txt拷貝到slim模型中,具體路徑為slim/flower102/data
4.2定義新的datasets文件
對模型有一定的了解之后,我們進入到模型微調(diào)階段,要將slim/datasets文件中的flowers.py做一些修改,并且另存flowers102.py具體修改以及解釋如下:
#將tfrecord文件的文件頭改為flower102,對應(yīng)生成tfrecord文件過程中的--dataset-name flower102命令
_FILE_PATTERN = 'flower102_%s_*.tfrecord'
# 設(shè)置訓(xùn)練集與驗證集的圖片個數(shù),都是1020
SPLITS_TO_SIZES = {'train': 1020, 'validation': 1020}
#設(shè)置類別個數(shù):102
_NUM_CLASSES = 102
#將圖片格式改為"jpg"
keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
'image/class/label': tf.FixedLenFeature(
[], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
}
修改完flowers102.py后,還需要對同目錄下的dataset_factory.py進行修改,具體修改內(nèi)容如下:
from datasets import flower102
datasets_map = {
'flower102':flower102,
}
具體就是把剛才新建的flower102添加到包中.
5.訓(xùn)練模型
5.1 準(zhǔn)備訓(xùn)練文件夾:
在slim文件中建立以下目錄結(jié)構(gòu):
slim/
flower102/
data/
pretrained/
train_dir/
- data中存放tfrecord數(shù)據(jù),已經(jīng)在4.1步完成
- pretrained中放置已經(jīng)訓(xùn)練好的InceptionV3的模型,可以在網(wǎng)上下載,源文件中也已經(jīng)包含.
- train_dir是用來保存訓(xùn)練過程中存儲的模型的文件夾.
5.2 開始訓(xùn)練模型
在slim文件夾中,使用train_image_classifier.py文件對模型進行訓(xùn)練,具體命令行以及解釋如下:
python train_image_classifier.py \
#模型保存路徑
--train_dir=flower102/train_dir \
#數(shù)據(jù)集名稱
--dataset_name=flower102 \
#數(shù)據(jù)集切分后的第二名稱(train)
--dataset_split_name=train \
#數(shù)據(jù)集所在目錄
--dataset_dir=flower102/data \
#使用的模型名稱
--model_name=inception_v3 \
#使用的模型的地址
--checkpoint_path=flower102/pretrained/inception_v3.ckpt \
#微調(diào)層(在恢復(fù)訓(xùn)練模型時,不恢復(fù)這兩層,這兩層對V3模型的末端層,原模型對應(yīng)1000類,而新模型只對應(yīng)102類)
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
#最大迭代次數(shù)
--max_number_of_steps=100 \
#batch_size
--batch_size=32
#學(xué)習(xí)率
--learning_rate=0.001 \
#學(xué)習(xí)率是否自動下降 此處為固定值
--learning_rate_decay_type=fixed \
#間隔多久保存一次模型
--save_interval_secs=50 \
#間隔多久寫入日志以供tensorborad查看
--save_summaries_secs=2 \
#間隔迭代次數(shù)打印
--log_every_n_steps=10 \
#選定優(yōu)化器
--optimizer=rmsprop \
#選定模型中2次正則化超參數(shù)
--weight_decay=0.00004 \
使用該命令對模型進行訓(xùn)練,訓(xùn)練過程部分截圖如下:
使用tensorboard工具可以查看到損失函數(shù)下降的過程:
tensorboard --logdir flower102/train_dir
6.驗證模型
驗證過程與訓(xùn)練過程所使用的命令類似,如下:
python eval_image_classifier.py \
--checkpoint_path=/tmp/tfmodel/model.ckpt-10000 \
--eval_dir=flower102/eval_dir \
--dataset_dir=flower102/data \
--dataset_name=flower102 \
--dataset_split_name=validation \
--model_name=inception_v3
驗證結(jié)果如下:
可以看出,準(zhǔn)確率有83%,而top2的召回率有90%左右的成績.
也可以使用tensorboard查看驗證過程:
7.測試模型
7.1導(dǎo)出模型
tensorflow_slim提供了導(dǎo)出模型框架的腳本export_inference_graph.py,可以將模型框架導(dǎo)出,在通過使用freeze_graph.py將訓(xùn)練好的參數(shù)值導(dǎo)入到模型中去.
step 1
輸出框架
python export_inference_graph.py \
--alsologtosterr \
--model_name=inception_v3 \
--output_file=flower102/inception_v3_inf_graph.pb \
--dataset_name flower102
step 2
注入?yún)?shù)數(shù)據(jù)
進入freeze_graph.py所在文件目錄,輸入:
python freeze_graph.py \
--input_graph slim/flower102/inception_v3_inf_graph.pb \
--input_checkpoint flower102\train_dir/model.ckpt-100000 \
--input_binary true \
--output_node_names InceptionV3/Predictions/Reshape_1 \
--output_graph slim/flower102/frozen_graph.pb
經(jīng)過這兩步之后,帶有參數(shù)值的模型就構(gòu)造好了,接下來就可以使用這個模型進行測試工作:
運行根目錄下的classify_image_incepetion_v3.py,并對以下輸入?yún)?shù)進行修改,更正為自己所使用的測試圖片與模型名稱:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_path',#模型的路徑,使用填充數(shù)據(jù)的模型框架
default='slim/flower102/frozen_graph.pb',
type=str,
)
parser.add_argument(
'--label_path',#label地址,在生成tfrecord文件過程中自動生成了label.txt,制定為其地址.
default='slim/flower102/data/label.txt',
type=str,
)
parser.add_argument(
'--image_file',#測試圖片的地址,這里使用了相對地址
type=str,
default='image_07111.jpg',
help='Absolute path to image file.'
)
parser.add_argument(
'--num_top_predictions',#給出top n的可能結(jié)果
type=int,
default=5,
help='Display this many predictions.'
)
以下為驗證的結(jié)果截圖:
測試image_07111.jpg這張圖片,結(jié)果如下:
可以看出C9的概率最高,對比該圖片與C9類,可見結(jié)果正確.