1. Prepare the dataset:
This walkthrough uses three-class image classification as the example: prepare image data for cats, dogs, and pandas (1,000 images per class), stored in the './dataset/cats', './dataset/dogs', and './dataset/pandas' folders respectively. A quick check of this layout is sketched below.
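Before training, a short script like the following (a minimal sketch, assuming exactly the folder layout above) can confirm that each class folder really contains its images:

import os

dataset_dir = './dataset'
for class_name in sorted(os.listdir(dataset_dir)):
    class_dir = os.path.join(dataset_dir, class_name)
    if not os.path.isdir(class_dir):
        continue
    # Count the image files in this class folder
    count = len([f for f in os.listdir(class_dir)
                 if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
    print("{}: {} images".format(class_name, count))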
2. Network architecture:
# Import the required modules
from keras.models import Sequential
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.initializers import TruncatedNormal
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.core import Dropout
from keras.layers.core import Dense
from keras import backend as K
class SimpleVGGNet:
    @staticmethod
    def build(width, height, depth, classes):
        model = Sequential()
        inputShape = (height, width, depth)
        chanDim = -1

        # Adjust the input shape and channel axis for "channels_first" backends
        if K.image_data_format() == "channels_first":
            inputShape = (depth, height, width)
            chanDim = 1

        # CONV => RELU => POOL
        model.add(Conv2D(32, (3, 3), padding="same", input_shape=inputShape,
                         kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=chanDim))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        # model.add(Dropout(0.25))

        # (CONV => RELU) * 2 => POOL
        model.add(Conv2D(64, (3, 3), padding="same",
                         kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=chanDim))
        model.add(Conv2D(64, (3, 3), padding="same",
                         kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=chanDim))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        # model.add(Dropout(0.25))

        # (CONV => RELU) * 3 => POOL
        model.add(Conv2D(128, (3, 3), padding="same",
                         kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=chanDim))
        model.add(Conv2D(128, (3, 3), padding="same",
                         kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=chanDim))
        model.add(Conv2D(128, (3, 3), padding="same",
                         kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=chanDim))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        # model.add(Dropout(0.25))

        # Fully-connected layer
        model.add(Flatten())
        model.add(Dense(512, kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(Dropout(0.6))

        # Softmax classifier
        model.add(Dense(classes, kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)))
        model.add(Activation("softmax"))
        return model
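To check that the architecture builds as intended, the class can be instantiated and inspected with model.summary(). This is only a quick sketch: it assumes the class above is saved as CNN_net.py (which is how the training script in the next step imports it) and uses the same 64x64x3 input and 3 classes as that script.

from CNN_net import SimpleVGGNet

# Build the network for 64x64 RGB inputs and 3 output classes, then print a layer-by-layer summary
model = SimpleVGGNet.build(width=64, height=64, depth=3, classes=3)
model.summary()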
3. Train the model:
# Import the required packages
from CNN_net import SimpleVGGNet
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
import utils_paths
import matplotlib.pyplot as plt
import numpy as np
import argparse
import random
import pickle
import cv2
import os
# Read the data and labels
print("------ Loading data ------")
data = []
labels = []

# Collect the image paths so they can be read in below
imagePaths = sorted(list(utils_paths.list_images('./dataset')))
random.seed(42)
random.shuffle(imagePaths)

# Loop over the images
for imagePath in imagePaths:
    # Read and resize the image
    image = cv2.imread(imagePath)
    image = cv2.resize(image, (64, 64))
    data.append(image)
    # The label is the name of the parent folder
    label = imagePath.split(os.path.sep)[-2]
    labels.append(label)

# Scale the pixel values to [0, 1]
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)

# Split the dataset into training and test sets
(trainX, testX, trainY, testY) = train_test_split(data, labels, test_size=0.25, random_state=42)

# Convert the labels to one-hot encoding
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
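# Illustrative note: LabelBinarizer sorts the class names (taken from the folder
# names), so with this dataset the one-hot rows are
#   'cats'   -> [1, 0, 0]
#   'dogs'   -> [0, 1, 0]
#   'pandas' -> [0, 0, 1]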
# Data augmentation
aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,
                         height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
                         horizontal_flip=True, fill_mode="nearest")

# Build the convolutional neural network
model = SimpleVGGNet.build(width=64, height=64, depth=3, classes=len(lb.classes_))

# Initial hyperparameters
INIT_LR = 0.01
EPOCHS = 30
BS = 32

# Set the loss function and compile the model
print("------ Training the network ------")
opt = SGD(lr=INIT_LR, decay=INIT_LR / EPOCHS)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])

# Train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
                        validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
                        epochs=EPOCHS)
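# Note: fit_generator is the Keras 2.x API used here; in newer TensorFlow/Keras
# releases it is deprecated and model.fit accepts the generator directly.
# The commented-out block below shows a plain model.fit call on the raw arrays,
# i.e. training without the augmentation generator.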
"""
H = model.fit(trainX, trainY, validation_data=(testX, testY),
epochs=EPOCHS, batch_size=32)
"""
# Evaluate on the test set
print("------ Evaluating the network ------")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1),
predictions.argmax(axis=1), target_names=lb.classes_))
# Plot the training curves
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, H.history["loss"], label="train_loss")
plt.plot(N, H.history["val_loss"], label="val_loss")
plt.plot(N, H.history["accuracy"], label="train_acc")
plt.plot(N, H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig('./output/cnn_plot.png')
# Save the model and the label binarizer
print("------ Saving the model ------")
model.save('./output/cnn.model')
f = open('./output/cnn_lb.pickle', "wb")
f.write(pickle.dumps(lb))
f.close()
Running the script produces the following files in the ./output folder: cnn.model (the trained network), cnn_lb.pickle (the serialized label binarizer), and cnn_plot.png (the training/validation curves).
4. Load the model for prediction:
# Import the required packages
from keras.models import load_model
import argparse
import pickle
import cv2
# Load a test image and apply the same preprocessing as in training
image = cv2.imread('./cs_image/panda.jpg')
output = image.copy()
image = cv2.resize(image, (64, 64))
# Scale the pixel values to [0, 1]
image = image.astype("float") / 255.0
# Add a batch dimension: (1, height, width, channels)
image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
# Load the trained model and the label binarizer
print("------ Loading the model and labels ------")
model = load_model('./output/cnn.model')
lb = pickle.loads(open('./output/cnn_lb.pickle', "rb").read())
# Predict
preds = model.predict(image)
# Get the index of the highest-scoring class and its label
i = preds.argmax(axis=1)[0]
label = lb.classes_[i]
# Draw the result on the image
text = "{}: {:.2f}%".format(label, preds[0][i] * 100)
cv2.putText(output, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
# Display the annotated image
cv2.imshow("Image", output)
cv2.waitKey(0)
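If the script runs on a machine without a display, cv2.imshow will fail; as a minimal alternative (hypothetical output path), the annotated image can be written to disk instead:

# Save the annotated image instead of opening a window
cv2.imwrite('./output/prediction.jpg', output)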
The final prediction result (the class label and its confidence drawn on the image):
5. Appendix:
utils_paths.py
The code is as follows:
import os

# Recognized image file extensions
image_types = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

def list_images(basePath, contains=None):
    # Return the paths of all valid image files
    return list_files(basePath, validExts=image_types, contains=contains)

def list_files(basePath, validExts=None, contains=None):
    # Walk the directory tree and yield the path of every matching file
    for (rootDir, dirNames, filenames) in os.walk(basePath):
        # Loop over the filenames in the current directory
        for filename in filenames:
            # If the contains string is not None and the filename does not
            # contain the supplied string, ignore the file
            if contains is not None and filename.find(contains) == -1:
                continue
            # Determine the file extension from the position of the last "."
            ext = filename[filename.rfind("."):].lower()
            # Check whether the file is an image that should be processed
            if validExts is None or ext.endswith(validExts):
                # Construct the image path and yield it
                imagePath = os.path.join(rootDir, filename)
                yield imagePath
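For reference, the generator can also be exercised on its own (a small sketch, assuming the dataset layout from step 1):

from utils_paths import list_images

# Print the first five image paths found under ./dataset
for i, imagePath in enumerate(sorted(list_images('./dataset'))):
    print(imagePath)
    if i >= 4:
        break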