使用Keras訓(xùn)練自己的數(shù)據(jù)集——以圖像多分類為例(基于卷積神經(jīng)網(wǎng)絡(luò))

1.準(zhǔn)備數(shù)據(jù)集:

本次以圖像三分類為例即寒,準(zhǔn)備貓议蟆、狗锡宋、熊貓三種動(dòng)物的圖片數(shù)據(jù)(每種各1000張圖片),依次存放在'./dataset/cats'媳否、'./dataset/dogs'栅螟、'./dataset/pandas'文件夾中荆秦。

2.網(wǎng)絡(luò)結(jié)構(gòu):

# 導(dǎo)入所需模塊
from keras.models import Sequential
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.initializers import TruncatedNormal
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.core import Dropout
from keras.layers.core import Dense
from keras import backend as K

class SimpleVGGNet:
    @staticmethod
    def build(width, height, depth, classes):
        model = Sequential()
        inputShape = (height, width, depth)
        chanDim = -1

        if K.image_data_format() == "channels_first":
            inputShape = (depth, height, width)
            chanDim = 1

        # CONV => RELU => POOL
        model.add(Conv2D(32, (3, 3), padding="same",
            input_shape=inputShape,kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=chanDim))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        #model.add(Dropout(0.25))

        # (CONV => RELU) * 2 => POOL
        model.add(Conv2D(64, (3, 3), padding="same",kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=chanDim))
        model.add(Conv2D(64, (3, 3), padding="same",kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=chanDim))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        #model.add(Dropout(0.25))

        # (CONV => RELU) * 3 => POOL
        model.add(Conv2D(128, (3, 3), padding="same",kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=chanDim))
        model.add(Conv2D(128, (3, 3), padding="same",kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=chanDim))
        model.add(Conv2D(128, (3, 3), padding="same",kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=chanDim))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        #model.add(Dropout(0.25))

        # FC層
        model.add(Flatten())
        model.add(Dense(512,kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(Dropout(0.6))

        # softmax 分類
        model.add(Dense(classes,kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("softmax"))

        return model

3.訓(xùn)練模型:

# 導(dǎo)入所需工具包
from CNN_net import SimpleVGGNet
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
import utils_paths
import matplotlib.pyplot as plt
import numpy as np
import argparse
import random
import pickle
import cv2
import os


# 讀取數(shù)據(jù)和標(biāo)簽
print("------開始讀取數(shù)據(jù)------")
data = []
labels = []

# 拿到圖像數(shù)據(jù)路徑,方便后續(xù)讀取
imagePaths = sorted(list(utils_paths.list_images('./dataset')))
random.seed(42)
random.shuffle(imagePaths)

# 遍歷讀取數(shù)據(jù)
for imagePath in imagePaths:
    # 讀取圖像數(shù)據(jù)
    image = cv2.imread(imagePath)
    image = cv2.resize(image, (64, 64))
    data.append(image)
    # 讀取標(biāo)簽
    label = imagePath.split(os.path.sep)[-2]
    labels.append(label)

# 對圖像數(shù)據(jù)做scale操作
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)

# 數(shù)據(jù)集切分
(trainX, testX, trainY, testY) = train_test_split(data,labels, test_size=0.25, random_state=42)

# 轉(zhuǎn)換標(biāo)簽為one-hot encoding格式
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)

# 數(shù)據(jù)增強(qiáng)處理
aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,
    height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
    horizontal_flip=True, fill_mode="nearest")

# 建立卷積神經(jīng)網(wǎng)絡(luò)
model = SimpleVGGNet.build(width=64, height=64, depth=3,classes=len(lb.classes_))

# 設(shè)置初始化超參數(shù)
INIT_LR = 0.01
EPOCHS = 30
BS = 32

# 損失函數(shù)力图,編譯模型
print("------準(zhǔn)備訓(xùn)練網(wǎng)絡(luò)------")
opt = SGD(lr=INIT_LR, decay=INIT_LR / EPOCHS)
model.compile(loss="categorical_crossentropy", optimizer=opt,metrics=["accuracy"])

# 訓(xùn)練網(wǎng)絡(luò)模型
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
    validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
    epochs=EPOCHS)
"""
H = model.fit(trainX, trainY, validation_data=(testX, testY),
    epochs=EPOCHS, batch_size=32)
"""


# 測試
print("------測試網(wǎng)絡(luò)------")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1),
    predictions.argmax(axis=1), target_names=lb.classes_))

# 繪制結(jié)果曲線
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, H.history["loss"], label="train_loss")
plt.plot(N, H.history["val_loss"], label="val_loss")
plt.plot(N, H.history["accuracy"], label="train_acc")
plt.plot(N, H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig('./output/cnn_plot.png')

# 保存模型
print("------正在保存模型------")
model.save('./output/cnn.model')
f = open('./output/cnn_lb.pickle', "wb")
f.write(pickle.dumps(lb))
f.close()

運(yùn)行得到如下文件數(shù)據(jù):

4.加載模型進(jìn)行預(yù)測:

# 導(dǎo)入所需工具包
from keras.models import load_model
import argparse
import pickle
import cv2


# 加載測試數(shù)據(jù)并進(jìn)行相同預(yù)處理操作
image = cv2.imread('./cs_image/panda.jpg')
output = image.copy()
image = cv2.resize(image, (64, 64))

# scale圖像數(shù)據(jù)
image = image.astype("float") / 255.0

# 對圖像進(jìn)行拉平操作
image = image.reshape((1, image.shape[0], image.shape[1],image.shape[2]))

# 讀取模型和標(biāo)簽
print("------讀取模型和標(biāo)簽------")
model = load_model('./output/cnn.model')
lb = pickle.loads(open('./output/cnn_lb.pickle', "rb").read())

# 預(yù)測
preds = model.predict(image)

# 得到預(yù)測結(jié)果以及其對應(yīng)的標(biāo)簽
i = preds.argmax(axis=1)[0]
label = lb.classes_[i]

# 在圖像中把結(jié)果畫出來
text = "{}: {:.2f}%".format(label, preds[0][i] * 100)
cv2.putText(output, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)

# 繪圖
cv2.imshow("Image", output)
cv2.waitKey(0)

最終得到預(yù)測結(jié)果:

5.附錄:

utils_paths.py代碼如下:

import os
 
 
image_types = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")
 
 
def list_images(basePath, contains=None):
    # 返回有效的圖片路徑數(shù)據(jù)集
    return list_files(basePath, validExts=image_types, contains=contains)
 
 
def list_files(basePath, validExts=None, contains=None):
    # 遍歷圖片數(shù)據(jù)目錄步绸,生成每張圖片的路徑
    for (rootDir, dirNames, filenames) in os.walk(basePath):
        # 循環(huán)遍歷當(dāng)前目錄中的文件名
        for filename in filenames:
            # if the contains string is not none and the filename does not contain
            # the supplied string, then ignore the file
            if contains is not None and filename.find(contains) == -1:
                continue
 
            # 通過確定.的位置,從而確定當(dāng)前文件的文件擴(kuò)展名
            ext = filename[filename.rfind("."):].lower()
 
            # 檢查文件是否為圖像搪哪,是否應(yīng)進(jìn)行處理
            if validExts is None or ext.endswith(validExts):
                # 構(gòu)造圖像路徑
                imagePath = os.path.join(rootDir, filename)
                yield imagePath
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末靡努,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子晓折,更是在濱河造成了極大的恐慌惑朦,老刑警劉巖,帶你破解...
    沈念sama閱讀 216,402評(píng)論 6 499
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件漓概,死亡現(xiàn)場離奇詭異漾月,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)胃珍,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,377評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門梁肿,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人觅彰,你說我怎么就攤上這事吩蔑。” “怎么了填抬?”我有些...
    開封第一講書人閱讀 162,483評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵烛芬,是天一觀的道長。 經(jīng)常有香客問我飒责,道長赘娄,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,165評(píng)論 1 292
  • 正文 為了忘掉前任宏蛉,我火速辦了婚禮遣臼,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘拾并。我一直安慰自己揍堰,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,176評(píng)論 6 388
  • 文/花漫 我一把揭開白布嗅义。 她就那樣靜靜地躺著屏歹,像睡著了一般。 火紅的嫁衣襯著肌膚如雪芥喇。 梳的紋絲不亂的頭發(fā)上西采,一...
    開封第一講書人閱讀 51,146評(píng)論 1 297
  • 那天,我揣著相機(jī)與錄音继控,去河邊找鬼械馆。 笑死胖眷,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的霹崎。 我是一名探鬼主播珊搀,決...
    沈念sama閱讀 40,032評(píng)論 3 417
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼尾菇!你這毒婦竟也來了境析?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 38,896評(píng)論 0 274
  • 序言:老撾萬榮一對情侶失蹤派诬,失蹤者是張志新(化名)和其女友劉穎劳淆,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體默赂,經(jīng)...
    沈念sama閱讀 45,311評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡沛鸵,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,536評(píng)論 2 332
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了缆八。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片曲掰。...
    茶點(diǎn)故事閱讀 39,696評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖奈辰,靈堂內(nèi)的尸體忽然破棺而出栏妖,到底是詐尸還是另有隱情,我是刑警寧澤奖恰,帶...
    沈念sama閱讀 35,413評(píng)論 5 343
  • 正文 年R本政府宣布吊趾,位于F島的核電站,受9級(jí)特大地震影響房官,放射性物質(zhì)發(fā)生泄漏趾徽。R本人自食惡果不足惜续滋,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,008評(píng)論 3 325
  • 文/蒙蒙 一翰守、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧疲酌,春花似錦蜡峰、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,659評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至粥诫,卻和暖如春油航,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背怀浆。 一陣腳步聲響...
    開封第一講書人閱讀 32,815評(píng)論 1 269
  • 我被黑心中介騙來泰國打工谊囚, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留怕享,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 47,698評(píng)論 2 368
  • 正文 我出身青樓镰踏,卻偏偏與公主長得像函筋,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個(gè)殘疾皇子奠伪,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,592評(píng)論 2 353

推薦閱讀更多精彩內(nèi)容