Training the MNIST example in Caffe is straightforward, but how do you run predictions with the trained model?
References:
http://blog.csdn.net/l691899397/article/details/52233454
http://www.reibang.com/p/9644f7ec0a03
http://www.reibang.com/p/9e30328a0a71
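For the theory, see the first blog post above, or read the paper.
I think prediction comes down to two key steps: 1. loading the trained caffemodel together with the deploy model description file; 2. feeding the image to be classified into the network correctly, which requires understanding Caffe's Blob. Now to the main topic:
1. Create the model description file
I have read up on Caffe's overall architecture (as a newcomer I only looked at the overall design and have not gone deep into the source code). Caffe is highly modular, which makes it easy to read and understand, and writing a whole neural network is close to being a visual process. My code can be downloaded from my GitHub: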
https://github.com/zefan7564/caffe
Below is the network definition I use for prediction.
name: "LeNet"
layer {
  name: "data"
  type: "Input"
  top: "data"
  input_param { shape: { dim: 1 dim: 1 dim: 28 dim: 28 } }
}
layer {
  name: "conv1"
  type: "Convolution"
  bottom: "data"
  top: "conv1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  convolution_param {
    num_output: 20
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool1"
  type: "Pooling"
  bottom: "conv1"
  top: "pool1"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "conv2"
  type: "Convolution"
  bottom: "pool1"
  top: "conv2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  convolution_param {
    num_output: 50
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool2"
  type: "Pooling"
  bottom: "conv2"
  top: "pool2"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "ip1"
  type: "InnerProduct"
  bottom: "pool2"
  top: "ip1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 500
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "relu1"
  type: "ReLU"
  bottom: "ip1"
  top: "ip1"
}
layer {
  name: "ip2"
  type: "InnerProduct"
  bottom: "ip1"
  top: "ip2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "prob"
  type: "Softmax"
  bottom: "ip2"
  top: "prob"
}
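As a sanity check on the definition above, the blob shapes can be worked out by hand. The following is only a sketch in plain Python, not Caffe code: the helper names `conv_out`, `pool_out` and `lenet_shapes` are made up for illustration, and it assumes no padding, which matches the layers above since none of them sets `pad`.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Output side length of a convolution (Caffe rounds down)."""
    return (size + 2 * pad - kernel) // stride + 1


def pool_out(size, kernel, stride, pad=0):
    """Output side length of a pooling layer (Caffe rounds up)."""
    return -(-(size + 2 * pad - kernel) // stride) + 1


def lenet_shapes(side=28):
    """Return (layer name, blob shape) pairs for a batch of one image."""
    shapes = [("data", (1, 1, side, side))]
    side = conv_out(side, kernel=5)            # conv1: 5x5 kernel, stride 1
    shapes.append(("conv1", (1, 20, side, side)))
    side = pool_out(side, kernel=2, stride=2)  # pool1: 2x2 max pooling
    shapes.append(("pool1", (1, 20, side, side)))
    side = conv_out(side, kernel=5)            # conv2: 5x5 kernel, stride 1
    shapes.append(("conv2", (1, 50, side, side)))
    side = pool_out(side, kernel=2, stride=2)  # pool2: 2x2 max pooling
    shapes.append(("pool2", (1, 50, side, side)))
    shapes.append(("ip1", (1, 500)))           # fully connected, 500 outputs
    shapes.append(("ip2", (1, 10)))            # fully connected, one score per digit
    shapes.append(("prob", (1, 10)))           # softmax keeps the shape
    return shapes


for name, shape in lenet_shapes():
    print(name, shape)
```

The side length goes 28, 24, 12, 8, 4, so `ip1` sees 50 * 4 * 4 = 800 inputs per image.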
Note the input layer in particular:
layer {
name: "data"
type: "Input"
top: "data"
input_param { shape: { dim: 1 dim: 1 dim: 28 dim: 28 } }
}
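Caffe stores data in four-dimensional arrays laid out as [num, channels, height, width], which is why the data layer is defined this way. The last layer, prob, also differs from training: prediction only needs a forward pass that outputs the class probabilities, so the loss and accuracy layers are not needed.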
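To make the layout concrete, here is a small NumPy sketch (not pycaffe code) that packs one grayscale image into a blob of shape (1, 1, 28, 28) and applies a softmax to ten scores, which is what the prob layer computes from ip2. The helper names and the score values are invented for illustration.

```python
import numpy as np


def to_blob(gray_image):
    """Add batch and channel axes: (H, W) -> (1, 1, H, W)."""
    return gray_image[np.newaxis, np.newaxis, :, :].astype(np.float32)


def softmax(scores):
    """Numerically stable softmax over the last axis."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


image = np.zeros((28, 28), dtype=np.uint8)  # stand-in for one MNIST digit
blob = to_blob(image)
print(blob.shape)

# Made-up ip2 scores for a single image, one per digit class.
scores = np.array([[0.1, 2.0, -1.0, 0.3, 0.0, 0.5, 4.0, 0.2, -0.7, 1.1]])
prob = softmax(scores)
print(prob.argmax(axis=1))
```

The probabilities sum to 1 for each image, and the predicted digit is simply the index of the largest one.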
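Next comes image preprocessing: subtracting the mean. Centering the inputs generally helps training converge faster, and prediction should apply the same preprocessing that was used for training.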
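As a rough illustration of mean subtraction, the NumPy sketch below centers an image either with a full mean image or with one mean value per channel. The random array is a stand-in for real training data, and none of the names come from Caffe.

```python
import numpy as np

rng = np.random.RandomState(0)
# Fake "training set": 100 single-channel 28x28 images with values in [0, 255].
train_images = rng.randint(0, 256, size=(100, 1, 28, 28)).astype(np.float64)

mean_image = train_images.mean(axis=0)       # shape (1, 28, 28): one mean per pixel
channel_mean = mean_image.mean(axis=(1, 2))  # shape (1,): one mean per channel

test_image = train_images[0]
centered_by_image = test_image - mean_image
centered_by_channel = test_image - channel_mean[:, np.newaxis, np.newaxis]

print(mean_image.shape, channel_mean.shape)
```

Whichever variant is chosen, the point is to use the same one at training time and at prediction time.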
Then you need the trained model; here I use lenet_iter_10000.caffemodel.
2. Predict with the trained model. Here I simply paste two scripts.
# coding=utf-8
# Code 1: classify MNIST digits with caffe.Classifier and OpenCV
# test by yuzefan
import os
import sys

import cv2
import numpy as np

caffe_root = '/home/ubuntu/caffe-master/'
sys.path.insert(0, caffe_root + 'python')  # add pycaffe to the path before importing it
import caffe

os.chdir(caffe_root)
MODEL_FILE = caffe_root + 'mytest/my-mnist/classificat_net.prototxt'
WEIGTHS = caffe_root + 'mytest/my-mnist/lenet_iter_10000.caffemodel'
IMAGE_PATH = caffe_root + 'mytest/my-mnist/'

caffe.set_mode_gpu()  # select the device before building the net
net = caffe.Classifier(MODEL_FILE, WEIGTHS)
font = cv2.FONT_HERSHEY_SIMPLEX  # normal size sans-serif font

for i in range(10):  # 0.png ... 9.png
    # astype() is a NumPy method that converts the array dtype.
    input_image = cv2.imread(IMAGE_PATH + '{}.png'.format(i), cv2.IMREAD_GRAYSCALE).astype(np.float32)
    # Enlarge a copy for display only; the net still receives the 28x28 image.
    resized = cv2.resize(input_image, (280, 280), None, 0, 0, cv2.INTER_AREA)
    input_image = input_image[:, :, np.newaxis]  # shape (28, 28, 1), dtype float32
    prediction = net.predict([input_image], oversample=False)
    label = prediction[0].argmax()
    cv2.putText(resized, str(label), (200, 280), font, 4, (255,), 2)
    # imshow treats float images as [0, 1], so convert to uint8 for display.
    cv2.imshow("Prediction", resized.astype(np.uint8))
    print('predicted class: {}'.format(label))
    keycode = cv2.waitKey(0) & 0xFF
    if keycode == 27:  # Esc quits
        break
# coding=utf-8
# Code 2: classify MNIST digits with caffe.Net and OpenCV
# test by yuzefan
import os
import sys

import cv2
import numpy as np

caffe_root = '/home/ubuntu/caffe-master/'
sys.path.insert(0, caffe_root + 'python')  # add pycaffe to the path before importing it
import caffe

os.chdir(caffe_root)
MODEL_FILE = caffe_root + 'mytest/my-mnist/classificat_net.prototxt'
WEIGTHS = caffe_root + 'mytest/my-mnist/lenet_iter_10000.caffemodel'
MEAN_FILE = caffe_root + 'mytest/my-mnist/mean.binaryproto'
LABEL_FILE = caffe_root + 'mytest/my-mnist/words.txt'
print('Params loaded!')
cv2.waitKey(1000)

caffe.set_mode_gpu()
net = caffe.Net(MODEL_FILE, WEIGTHS, caffe.TEST)

# Load the mean file. It is read here but not applied by the transformer below.
mean_blob = caffe.proto.caffe_pb2.BlobProto()
with open(MEAN_FILE, 'rb') as f:
    mean_blob.ParseFromString(f.read())
mean_npy = caffe.io.blobproto_to_array(mean_blob)
a = mean_npy[0, :, 0, 0]

print(net.blobs['data'].data.shape)
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
# Left commented out: not needed for grayscale images read with cv2.imread.
#transformer.set_transpose('data', (2, 0, 1))
#transformer.set_raw_scale('data', 255)
#transformer.set_channel_swap('data', (2, 1, 0))

# Read the class names once; the second space-separated field on each line is the name.
names = []
with open(LABEL_FILE, 'r') as f:
    for line in f:
        names.append(line.split(' ')[1].strip())
print(names)

for i in range(10):
    IMAGE_PATH = caffe_root + 'mytest/my-mnist/{}.png'.format(i)
    #im = caffe.io.load_image(IMAGE_PATH)
    input_image = cv2.imread(IMAGE_PATH, cv2.IMREAD_GRAYSCALE).astype(np.float32)
    resized = cv2.resize(input_image, (280, 280), None, 0, 0, cv2.INTER_AREA)
    net.blobs['data'].data[...] = transformer.preprocess('data', input_image)
    predict = net.forward()
    prob = net.blobs['prob'].data[0].flatten()
    print('prob: ', prob)
    print('class: ', names[np.argmax(prob)])
    # imshow treats float images as [0, 1], so convert to uint8 for display.
    cv2.imshow("Prediction", resized.astype(np.uint8))
    keycode = cv2.waitKey(0) & 0xFF
    if keycode == 27:  # Esc quits
        break
Below I refer to these two scripts as Code 1 and Code 2.
The key part of Code 1 is
net=caffe.Classifier(MODEL_FILE,WEIGTHS)
prediction = net.predict([input_image], oversample=False)
As you can see, OpenCV reads the image as grayscale and the result is passed to the net for prediction.
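The layout change behind that call can be sketched in NumPy. This is an illustration only, not the pycaffe source: the real Classifier also resizes and crops its inputs, and the helper name `hwc_to_batch` is made up.

```python
import numpy as np


def hwc_to_batch(images):
    """Stack H x W x C images into one N x C x H x W float32 array."""
    first = images[0].transpose(2, 0, 1)
    batch = np.zeros((len(images),) + first.shape, dtype=np.float32)
    for n, img in enumerate(images):
        batch[n] = img.transpose(2, 0, 1)  # H x W x C -> C x H x W
    return batch


gray = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)  # fake 28x28 image
input_image = gray[:, :, np.newaxis]  # (28, 28, 1), as in Code 1
batch = hwc_to_batch([input_image])
print(batch.shape)
```

This is why Code 1 adds a trailing channel axis with `np.newaxis`: the image has to look like H x W x C before it is rearranged into the blob layout.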
The key part of Code 2 is
net=caffe.Net(MODEL_FILE,WEIGTHS,caffe.TEST)
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
#transformer.set_transpose('data', (2, 0, 1))
##transformer.set_raw_scale('data', 255)
#transformer.set_channel_swap('data', (2, 1, 0))
#img = caffe.io.load_image(IMAGE_PATH)
input_image=cv2.imread(IMAGE_PATH,cv2.IMREAD_GRAYSCALE).astype(np.float32)
net.blobs['data'].data[...] = transformer.preprocess('data', input_image)
predict = net.forward()
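You can see that I commented out several lines; they matter, so I did not delete them. The transformer applies transformations to the input data. However, when I read the image with caffe.io.load_image(IMAGE_PATH), I always got a [28,28,3] array, that is, three channels every time, which produced the following error: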
could not broadcast input array from shape (28,28,3) into shape (1,1,28,28)
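With no better option I went back to OpenCV, which returns arrays in (H, W, C) order, or just (H, W) for a grayscale read. I hope someone more experienced can explain this to me.
For color images, though, caffe.io.load_image(IMAGE_PATH) is the one to use.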
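The error can be reproduced with NumPy alone. In this sketch the arrays are stand-ins for `net.blobs['data'].data` and for a three-channel image, so no Caffe installation is needed.

```python
import numpy as np

data = np.zeros((1, 1, 28, 28), dtype=np.float32)  # stand-in for the data blob
rgb = np.ones((28, 28, 3), dtype=np.float32)       # stand-in for a three-channel image

try:
    data[...] = rgb  # trailing axes 3 and 28 do not match, so this fails
    error = None
except ValueError as exc:
    error = str(exc)
print(error)

# A single (28, 28) channel does broadcast into (1, 1, 28, 28).
data[...] = rgb[:, :, 0]
print(data.shape)
```

An (H, W) array fits the blob because NumPy only has to add leading axes of length 1, while an (H, W, 3) array has no compatible alignment.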
Here is the result of running it (the screenshot, dated 2017-08-22 15:50:07, failed to upload).
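The question above is now resolved:
cv2.imread() with IMREAD_GRAYSCALE returns the image directly as grayscale with values in 0~255, so there is no need to rescale to [0,255] or to swap channels with [2,1,0]. In other words, transformer.set_raw_scale('data',255) and transformer.set_channel_swap('data',(2,1,0)) are not needed.
caffe.io.load_image(), on the other hand, returns RGB data as floats in 0~1. Its color argument defaults to True, which is why even a grayscale file came back with three channels. So before feature extraction the transformer needs transformer.set_raw_scale('data',255) (scale to 0~255)
and transformer.set_channel_swap('data',(2,1,0)) (convert RGB to BGR).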
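What those transformer settings do to an array can be imitated in NumPy. The function below is a sketch, not the pycaffe implementation; its name is made up, and it assumes the input is an H x W x 3 RGB image with float values in [0, 1], as load_image returns.

```python
import numpy as np


def preprocess_like_transformer(img_hwc_rgb01):
    """Apply transpose, channel swap and raw scale, in that order."""
    out = img_hwc_rgb01.astype(np.float32)
    out = out.transpose(2, 0, 1)  # like set_transpose('data', (2, 0, 1)): H x W x C -> C x H x W
    out = out[[2, 1, 0], :, :]    # like set_channel_swap('data', (2, 1, 0)): RGB -> BGR
    out = out * 255.0             # like set_raw_scale('data', 255): [0, 1] -> [0, 255]
    return out


img = np.zeros((28, 28, 3), dtype=np.float32)
img[..., 0] = 1.0  # pure red in RGB, values in [0, 1]
blob = preprocess_like_transformer(img)
print(blob.shape, blob[:, 0, 0])
```

After the swap the red values sit in the last channel, which is the BGR order that OpenCV also uses.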
Done!