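Conventional CNN-based image classification networks such as LeNet, AlexNet, and VGGNet are single-label classifiers. This post records how, on Ubuntu 16.04, to adapt such a single-label model so that it performs multi-label classification. Here that specifically means recognizing fixed-length character strings; the same principle applies to CAPTCHA recognition and license plate recognition.
Note: the code in this post is mainly drawn from the following two blog posts:
Below is the complete workflow for training and testing a multi-label classification model with Caffe. It has 4 parts: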
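- How to build a multi-label classification dataset;
- Modifying the Caffe source code to convert and read multi-label datasets;
- Modifying the AlexNet classification model for multi-label classification;
- Training and testing the model.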
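1. How to build a multi-label classification dataset
The generated dataset images look like this:
Each image contains 4 characters (0-9 or A-Z). With small changes to the code, this can be extended to any length.
To keep things simple, the generator is adapted and simplified from the source code of "End-to-End license plate recognition without character segmentation".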
First create a folder named multi-label-classification with the following subdirectories and files:
Blue entries are folders; entries in other colors are files.
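The approach to generating multi-label character images is roughly:
- Decide the string length, i.e. how many labels each generated image carries.
- From the string length, determine the image size. For example, for 4-character images, taking into account the aspect ratio of a single character, the gaps between characters, and the font size, I set the image size to 90x30 (width x height).
- Find a font in .ttf format, chosen to suit the actual use case.
- Decide what background the images use. For example, I picked a dozen or so background images in different colors (placed in the ./background/ folder), each 90x30; a few are shown below.
- Finally, decide what processing to apply to the string image, such as random rotation, distortion, added noise, or blur, to improve the model's generalization.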
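The relationship between string length and image width can be sketched as follows. This is a minimal illustration, assuming the 22-pixel character cell and 2-pixel left offset that gen_character.py uses; the helper name char_layout is hypothetical and not part of the original code.

```python
def char_layout(num_chars, char_width=22, left_offset=2):
    """Return (image_width, cells) for a row of fixed-width character cells.

    Each cell is an (x_start, x_end) pair in pixels.
    """
    cells = []
    for i in range(num_chars):
        x_start = left_offset + i * char_width
        cells.append((x_start, x_start + char_width))
    # The image must be at least as wide as the right edge of the last cell.
    image_width = cells[-1][1] if cells else left_offset
    return image_width, cells


width, cells = char_layout(4)
print(width)  # 90
print(cells)  # [(2, 24), (24, 46), (46, 68), (68, 90)]
```

With 4 characters the cells end exactly at pixel 90, which matches the 90x30 canvas. For a longer string, the canvas width and the background images would have to grow accordingly.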
Below is the code for gen_character.py:
# coding=utf-8
from PIL import ImageFont
from PIL import Image
from PIL import ImageDraw
import cv2
import numpy as np
import os

chars = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I",
         "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"]


# Apply a perspective distortion that slants the characters to the left or
# to the right by a random amount (4-10 pixels).
def distortionRandom(img):
    w = img.shape[1]
    h = img.shape[0]
    pts1 = np.float32([[0, 0], [0, h], [w, 0], [w, h]])
    # randint excludes the upper bound, so these draw from {0, 1} and {4..10}
    pos_or_neg = np.random.randint(0, 2)
    distortion_value = np.random.randint(4, 11)
    if pos_or_neg == 0:
        pts2 = np.float32([[0, 0], [distortion_value, h], [w - distortion_value, 0], [w, h]])
    else:
        pts2 = np.float32([[distortion_value, 0], [0, h], [w, 0], [w - distortion_value, h]])
    M = cv2.getPerspectiveTransform(pts1, pts2)
    dst = cv2.warpPerspective(img, M, (w, h))
    return dst


# Draw a single character on a white patch and return it as a 22x28 array
def GenCh(f, val):
    img = Image.new("RGB", (16, 28), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    draw.text((2, 0), val, (0, 0, 0), font=f, align="center")
    A = np.array(img)
    A = cv2.resize(A, (22, 28))
    return A


# GenCharacter generates fixed-length multi-label images
class GenCharacter:
    def __init__(self, font):
        # Font used for the characters
        self.fontE = ImageFont.truetype(font, 28, 0)
        # The multi-label image is 90x30
        self.img = np.array(Image.new("RGB", (90, 30), (255, 255, 255)))
        # Backgrounds: ./background/ holds a dozen or so 90x30 images.
        # Load them all into a list; one is picked at random per generated image.
        self.bgs = []
        for file in os.listdir("./background/"):
            bg = cv2.resize(cv2.imread("./background/" + file), (90, 30))
            self.bgs.append(bg)

    # Write a 4-character string into the 90x30 image
    def draw(self, val):
        offset = 2
        for i in range(4):
            base = offset + i * 22
            self.img[0:28, base:base + 22] = GenCh(self.fontE, val[i])
        return self.img

    # Render the given string on a random background
    def generate(self, text):
        if len(text) == 4:
            fg = self.draw(text)
            fg = cv2.bitwise_not(fg)
            k = np.random.randint(0, len(self.bgs))
            com = cv2.bitwise_or(fg, self.bgs[k])
            com = distortionRandom(com)
            com = cv2.bitwise_or(com, self.bgs[k])
            return com

    # Generate a random string of length 4
    def genCharacterString(self):
        CharacterStr = ""
        for _ in range(4):
            CharacterStr += chars[np.random.randint(0, len(chars))]
        return CharacterStr

    # Generate batchSize multi-label images and save them to outputPath
    def genBatch(self, batchSize, outputPath):
        if not os.path.exists(outputPath):
            os.makedirs(outputPath)
        for i in range(batchSize):
            # Use self rather than the global G so the method works on any instance
            CharacterStr = self.genCharacterString()
            img = self.generate(CharacterStr)
            filename = os.path.join(outputPath, str(i).zfill(6) + '.' + CharacterStr + ".jpg")
            cv2.imwrite(filename, img)


if __name__ == '__main__':
    G = GenCharacter('./font/platechar.ttf')
    G.genBatch(30000, "./data/train")
    G.genBatch(10000, "./data/val")
Open bash in the /multi-label-classification/ folder and run
python ./gen_character.py
This generates 30000 training images and 10000 validation images.
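The file naming scheme is what later carries the labels, so it is worth isolating. The sketch below reproduces only the naming logic with the standard library; the function names are hypothetical, and random.choice stands in for the NumPy call used in gen_character.py.

```python
import random
import string

# "0"-"9" followed by "A"-"Z": the same 36-symbol alphabet as gen_character.py
CHARS = string.digits + string.ascii_uppercase


def random_label_string(length=4, rng=random):
    """Draw `length` symbols uniformly at random from CHARS."""
    return "".join(rng.choice(CHARS) for _ in range(length))


def image_filename(index, label_string):
    """Build '<6-digit index>.<label string>.jpg', e.g. 000001.5GSB.jpg."""
    return "%s.%s.jpg" % (str(index).zfill(6), label_string)


print(image_filename(1, "5GSB"))  # 000001.5GSB.jpg
print(len(CHARS) ** 4)            # 1679616 possible 4-character strings
```

Because the zero-padded index is part of the name, two images that happen to draw the same random string still get distinct files.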
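How to generate the train.txt and val.txt text files
Anyone who has used a Caffe classification model knows that, besides the image files, you also need train.txt and val.txt text files that list each image name with its ground-truth labels. A simple Python script does this: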
create_train_txt.py:
# coding=utf-8
# Build the ground-truth label file from image names such as 000001.5GSB.jpg
import os

train_src_path = "data/train/"
train_dst_file = "data/train.txt"

if __name__ == '__main__':
    train_file = open(train_dst_file, 'w')
    k = 0
    for file in os.listdir(train_src_path):
        lines = file
        strs = file.split('.')
        for i in range(4):
            cha = strs[1][i]
            # ASCII codes: '0'-'9' are 48-57 and 'A'-'Z' are 65-90.
            # For convenience, map '0'-'9' to 0-9 by subtracting 48
            # and 'A'-'Z' to 10-35 by subtracting 55.
            if ord(cha) >= 65:
                num = ord(cha) - 55
            else:
                num = ord(cha) - 48
            lines += ' ' + str(num)
        lines += '\n'
        train_file.writelines(lines)
        k += 1
    train_file.close()
    print('there are %d images in total' % int(k))
    print('done')
Place create_train_txt.py in the /multi-label-classification/ folder, open bash there, and run
python ./create_train_txt.py
This generates train.txt under /multi-label-classification/data/.
Change train to val in the path names in the code above and run it the same way to generate val.txt.
For example, part of train.txt looks like this:
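As an illustration of the line format (derived from the 000001.5GSB.jpg example in the script comment, not copied from the author's actual file), the character-to-index mapping and one listing line can be sketched as follows. The function names are hypothetical.

```python
def char_to_index(ch):
    """Map '0'-'9' to 0-9 and 'A'-'Z' to 10-35, as create_train_txt.py does."""
    if "0" <= ch <= "9":
        return ord(ch) - 48
    if "A" <= ch <= "Z":
        return ord(ch) - 55
    raise ValueError("unsupported character: %r" % ch)


def index_to_char(idx):
    """Inverse of char_to_index, useful when decoding predictions."""
    return "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"[idx]


def listing_line(filename):
    """Build '<filename> l1 l2 l3 l4' from a name like 000001.5GSB.jpg."""
    label_string = filename.split(".")[1]
    labels = [str(char_to_index(c)) for c in label_string]
    return " ".join([filename] + labels)


print(listing_line("000001.5GSB.jpg"))  # 000001.5GSB.jpg 5 16 28 11
```

Unlike the original script, this sketch raises an error on characters outside 0-9 and A-Z instead of silently producing an out-of-range label.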
Next, the multi-label training and validation sets need to be converted to LMDB format. This step requires changes to /caffe/tools/convert_imageset.cpp, so it is deferred until later.
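2. Modifying the Caffe source code to convert and read multi-label datasets
The Caffe source tree contains /caffe/tools/convert_imageset.cpp, which converts an image dataset to LMDB format, but it only handles single-label datasets. To handle multi-label datasets, convert_imageset.cpp has to be modified. Its implementation relies on functions in io.hpp and io.cpp, so those have to be modified as well.
Likewise, Caffe's Data layer can only read single-label datasets, so data_layer.cpp also has to be modified.
In addition, one field has to be added to caffe.proto.
In total, the following files need to be modified: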
- /caffe/tools/convert_imageset.cpp
- /caffe/include/caffe/util/io.hpp
- /caffe/src/caffe/util/io.cpp
- /caffe/src/caffe/proto/caffe.proto
- /caffe/src/caffe/layers/data_layer.cpp
Original code is commented out with /* ... */, and new code is enclosed between ////////////// ...... ////////////////// lines.
Modify /caffe/tools/convert_imageset.cpp, at around line 74:
/*
std::ifstream infile(argv[2]);
std::vector<std::pair<std::string, int> > lines;
std::string line;
size_t pos;
int label;
while (std::getline(infile, line)) {
pos = line.find_last_of(' ');
label = atoi(line.substr(pos + 1).c_str());
lines.push_back(std::make_pair(line.substr(0, pos), label));
}
*/
////////////////////////////
std::ifstream infile(argv[2]);
std::vector<std::pair<std::string, vector<float> > > lines;
std::string filename;
vector<float> labels(4);
while (infile >> filename >> labels[0] >> labels[1] >> labels[2] >> labels[3]){
lines.push_back(std::make_pair(filename, labels));
}
///////////////////////////
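To make the expected listing format explicit, here is a minimal Python sketch of the same parsing logic. It is only an illustration of what the modified C++ loop reads; the function name parse_listing is hypothetical.

```python
def parse_listing(text, num_labels=4):
    """Parse lines of the form '<filename> l1 ... lN' into (filename, [floats]) pairs."""
    entries = []
    for line_no, line in enumerate(text.splitlines(), 1):
        fields = line.split()
        if not fields:
            continue  # skip blank lines
        if len(fields) != num_labels + 1:
            raise ValueError("line %d: expected a filename and %d labels, got %r"
                             % (line_no, num_labels, line))
        entries.append((fields[0], [float(v) for v in fields[1:]]))
    return entries


print(parse_listing("000001.5GSB.jpg 5 16 28 11\n"))
# [('000001.5GSB.jpg', [5.0, 16.0, 28.0, 11.0])]
```

One difference is deliberate: the C++ `while (infile >> ...)` loop simply stops at the first line it cannot parse, whereas this sketch raises an error, which makes a malformed train.txt easier to spot. Both split on whitespace, so file names must not contain spaces.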
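Modify /caffe/include/caffe/util/io.hpp.
Add the two function declarations enclosed in /////// ..... ///////, without deleting any existing code. The first two declarations below are already in the file; their label parameter is an int, so they can only handle a single label. The two new functions are modeled on them, with the const int label parameter changed to std::vector<float> labels so that they accept multiple labels.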
bool ReadImageToDatum(const string& filename, const int label,
const int height, const int width, const bool is_color,
const std::string & encoding, Datum* datum);
bool ReadFileToDatum(const string& filename, const int label, Datum* datum);
//////////////////////////////////////////
bool ReadImageToDatum(const string& filename, std::vector<float> labels,
const int height, const int width, const bool is_color,
const std::string & encoding, Datum* datum);
bool ReadFileLabelsToDatum(const string& filename, std::vector<float> labels,
Datum* datum);
///////////////////////////////////
Modify /caffe/src/caffe/util/io.cpp.
Add the following function below the existing ReadImageToDatum() implementation, at around line 143:
//////////////////////////////////////////////////////////////////////////
bool ReadImageToDatum(const string& filename, std::vector<float> labels,
const int height, const int width, const bool is_color,
const std::string & encoding, Datum* datum)
{
std::cout << filename << " " << labels[0] << " " << labels[1] << " " << labels[2] << " " << labels[3] << std::endl;
cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
if (cv_img.data) {
if (encoding.size()) {
if ((cv_img.channels() == 3) == is_color && !height && !width &&
matchExt(filename, encoding))
//return ReadFileToDatum(filename, label, datum);
return ReadFileLabelsToDatum(filename, labels, datum);//ReadFileToDatum -> ReadFileLabelsToDatum
std::vector<uchar> buf;
cv::imencode("." + encoding, cv_img, buf);
datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),
buf.size()));
//datum->set_label(label);
datum->clear_labels();
datum->add_labels(labels[0]);
datum->add_labels(labels[1]);
datum->add_labels(labels[2]);
datum->add_labels(labels[3]);
//////////////////
datum->set_encoded(true);
return true;
}
CVMatToDatum(cv_img, datum);
//datum->set_label(label);
datum->clear_labels();
datum->add_labels(labels[0]);
datum->add_labels(labels[1]);
datum->add_labels(labels[2]);
datum->add_labels(labels[3]);
//////////////////
return true;
}
else {
return false;
}
}
/////////////////////////////////////////////////////////////////////
Add the following function below the existing ReadFileToDatum() implementation, at around line 209:
//////////////////////////////////////////////////////////////////////
bool ReadFileLabelsToDatum(const string& filename, std::vector<float> labels,
Datum* datum)
{
std::streampos size;
fstream file(filename.c_str(), ios::in | ios::binary | ios::ate);
if (file.is_open()) {
size = file.tellg();
std::string buffer(size, ' ');
file.seekg(0, ios::beg);
file.read(&buffer[0], size);
file.close();
datum->set_data(buffer);
//datum->set_label(label);
datum->clear_labels();
datum->add_labels(labels[0]);
datum->add_labels(labels[1]);
datum->add_labels(labels[2]);
datum->add_labels(labels[3]);
//////////////////
datum->set_encoded(true);
return true;
}
else {
return false;
}
}
///////////////////////////////////////////////////////
Modify /caffe/src/caffe/proto/caffe.proto.
Add one line to the message below: a new repeated field named labels, so that a Datum can hold multiple labels.
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
// the actual image data, in bytes
optional bytes data = 4;
optional int32 label = 5;
// Optionally, the datum could also hold float data.
repeated float float_data = 6;
// If true data contains an encoded image that need to be decoded
optional bool encoded = 7 [default = false];
//////////////////////////////////
repeated float labels = 8;
//////////////////////////////////
}
Modify /caffe/src/caffe/layers/data_layer.cpp.
At around line 49:
// label
/*
if (this->output_labels_) {
vector<int> label_shape(1, batch_size);
top[1]->Reshape(label_shape);
for (int i = 0; i < this->prefetch_.size(); ++i) {
this->prefetch_[i]->label_.Reshape(label_shape);
}
}
*/
/////////////////////////////////////////////////
if (this->output_labels_){
top[1]->Reshape(batch_size, 4, 1, 1);
for (int i = 0; i < this->prefetch_.size(); ++i) {
this->prefetch_[i]->label_.Reshape(batch_size, 4, 1, 1);
}
}
//////////////////////////////////////////////////
At around line 128:
// Copy label.
/*
if (this->output_labels_) {
Dtype* top_label = batch->label_.mutable_cpu_data();
top_label[item_id] = datum.label();
}
*/
///////////////////////////////////////////////
if (this->output_labels_) {
Dtype* top_label = batch->label_.mutable_cpu_data();
for (int i = 0; i < 4; i++)
top_label[item_id * 4 + i] = datum.labels(i);
}
///////////////////////////////////////////////
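The index arithmetic `item_id * 4 + i` follows from the label blob being reshaped to (batch_size, 4, 1, 1) and stored contiguously in row-major order. A small Python sketch of that layout, with hypothetical helper names:

```python
def label_offset(item_id, label_idx, num_labels=4):
    """Flat offset of label `label_idx` of sample `item_id` in a (batch, num_labels, 1, 1) blob."""
    return item_id * num_labels + label_idx


def fill_label_buffer(batch_labels, num_labels=4):
    """Copy per-sample label lists into one flat buffer, as the modified Data layer does."""
    buf = [0.0] * (len(batch_labels) * num_labels)
    for item_id, labels in enumerate(batch_labels):
        for i in range(num_labels):
            buf[label_offset(item_id, i, num_labels)] = float(labels[i])
    return buf


print(fill_label_buffer([[5, 16, 28, 11], [0, 1, 2, 3]]))
# [5.0, 16.0, 28.0, 11.0, 0.0, 1.0, 2.0, 3.0]
```

The four labels of each sample sit next to each other in memory, which is what lets a Slice layer along axis 1 separate them later.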
With the changes done, run the following in the Caffe root directory:
make clean
make all -j8
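This recompiles the modified Caffe.
Converting the raw dataset to LMDB format
Once the modified Caffe is built, the convert_imageset tool can convert the raw dataset to LMDB format.
Run the script create_train_val_lmdb.sh to perform the conversion.
Contents of create_train_val_lmdb.sh: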
echo "create train lmdb..."
/home/ys/caffe/build/tools/convert_imageset \
--resize_height=227 \
--resize_width=227 \
--backend="lmdb" \
--shuffle \
/home/ys/caffe/models/multi-label-classification/data/train/ \
/home/ys/caffe/models/multi-label-classification/data/train.txt \
/home/ys/caffe/models/multi-label-classification/data/train_lmdb
echo "done"
echo "create val lmdb..."
/home/ys/caffe/build/tools/convert_imageset \
--resize_height=227 \
--resize_width=227 \
--backend="lmdb" \
--shuffle \
/home/ys/caffe/models/multi-label-classification/data/val/ \
/home/ys/caffe/models/multi-label-classification/data/val.txt \
/home/ys/caffe/models/multi-label-classification/data/val_lmdb
echo "done"
Change the file paths to match your own setup.
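Since the train and val invocations differ only in the split name, the argument list can be built once and reused. The sketch below is one possible way to do that from Python, using only the flags that appear in the script above; the function name and the example paths are placeholders, not part of the original post.

```python
import os


def convert_imageset_args(tool, dataset_root, split, size=227):
    """Build the convert_imageset argument list for one split ('train' or 'val')."""
    return [
        tool,
        "--resize_height=%d" % size,
        "--resize_width=%d" % size,
        "--backend=lmdb",
        "--shuffle",
        os.path.join(dataset_root, split) + "/",   # image root folder
        os.path.join(dataset_root, split + ".txt"),  # listing file
        os.path.join(dataset_root, split + "_lmdb"),  # output database
    ]


# Placeholder paths; substitute your own Caffe build and dataset locations.
for split in ("train", "val"):
    print(" ".join(convert_imageset_args("convert_imageset", "data", split)))
```

Each list can be passed to subprocess.check_call to run the conversion, which also stops on a non-zero exit code instead of continuing to the next split.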
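3. Modifying the AlexNet classification model for multi-label classification
The classic AlexNet model is in /caffe/models/bvlc_alexnet/. The structure of its train_val.prototxt is as follows:
The structure of the modified train_val.prototxt is as follows:
The Data layer is unchanged. A Slice layer is added after the Data layer to split the multi-label blob it reads into separate labels: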
layer {
name: "slicers"
type: "Slice"
bottom: "label"
top: "label_1"
top: "label_2"
top: "label_3"
top: "label_4"
slice_param {
axis: 1
slice_point: 1
slice_point: 2
slice_point: 3
}
}
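What this Slice layer does to the label blob can be sketched with plain Python lists. This is only an illustration of slicing along axis 1 at points 1, 2, and 3; the function name is hypothetical and the trailing singleton dimensions of the real blob are omitted.

```python
def slice_columns(rows, slice_points):
    """Split each row at the given column indices, returning one list of rows per output."""
    bounds = [0] + list(slice_points) + [len(rows[0])]
    return [[row[a:b] for row in rows] for a, b in zip(bounds[:-1], bounds[1:])]


labels = [[5, 16, 28, 11],
          [0, 1, 2, 3]]  # a batch of 2 samples with 4 labels each
label_1, label_2, label_3, label_4 = slice_columns(labels, [1, 2, 3])
print(label_1)  # [[5], [0]]
print(label_4)  # [[11], [3]]
```

Three slice points yield four outputs, so each of label_1 to label_4 holds exactly one label per sample, which is the shape that SoftmaxWithLoss and Accuracy expect.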
The layers from conv1 through the Dropout layer after fc6 stay the same. Everything from fc7 onward is replaced with the following:
layer {
name: "fc7_1"
type: "InnerProduct"
bottom: "fc6"
top: "fc7_1"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 4096
weight_filler {
type: "gaussian"
std: 0.005
}
bias_filler {
type: "constant"
value: 0.1
}
}
}
layer {
name: "relu7_1"
type: "ReLU"
bottom: "fc7_1"
top: "fc7_1"
}
layer {
name: "drop7_1"
type: "Dropout"
bottom: "fc7_1"
top: "fc7_1"
dropout_param {
dropout_ratio: 0.5
}
}
layer {
name: "fc7_2"
type: "InnerProduct"
bottom: "fc6"
top: "fc7_2"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 4096
weight_filler {
type: "gaussian"
std: 0.005
}
bias_filler {
type: "constant"
value: 0.1
}
}
}
layer {
name: "relu7_2"
type: "ReLU"
bottom: "fc7_2"
top: "fc7_2"
}
layer {
name: "drop7_2"
type: "Dropout"
bottom: "fc7_2"
top: "fc7_2"
dropout_param {
dropout_ratio: 0.5
}
}
layer {
name: "fc7_3"
type: "InnerProduct"
bottom: "fc6"
top: "fc7_3"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 4096
weight_filler {
type: "gaussian"
std: 0.005
}
bias_filler {
type: "constant"
value: 0.1
}
}
}
layer {
name: "relu7_3"
type: "ReLU"
bottom: "fc7_3"
top: "fc7_3"
}
layer {
name: "drop7_3"
type: "Dropout"
bottom: "fc7_3"
top: "fc7_3"
dropout_param {
dropout_ratio: 0.5
}
}
layer {
name: "fc7_4"
type: "InnerProduct"
bottom: "fc6"
top: "fc7_4"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 4096
weight_filler {
type: "gaussian"
std: 0.005
}
bias_filler {
type: "constant"
value: 0.1
}
}
}
layer {
name: "relu7_4"
type: "ReLU"
bottom: "fc7_4"
top: "fc7_4"
}
layer {
name: "drop7_4"
type: "Dropout"
bottom: "fc7_4"
top: "fc7_4"
dropout_param {
dropout_ratio: 0.5
}
}
layer {
name: "fc8_1"
type: "InnerProduct"
bottom: "fc7_1"
top: "fc8_1"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 36 #1000->36
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0
}
}
}
layer {
name: "fc8_2"
type: "InnerProduct"
bottom: "fc7_2"
top: "fc8_2"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 36 #1000->36
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0
}
}
}
layer {
name: "fc8_3"
type: "InnerProduct"
bottom: "fc7_3"
top: "fc8_3"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 36 #1000->36
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0
}
}
}
layer {
name: "fc8_4"
type: "InnerProduct"
bottom: "fc7_4"
top: "fc8_4"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 36 #1000->36
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0
}
}
}
layer {
name: "accuracy_1"
type: "Accuracy"
bottom: "fc8_1"
bottom: "label_1"
top: "accuracy_1"
include {
phase: TEST
}
}
layer {
name: "accuracy_2"
type: "Accuracy"
bottom: "fc8_2"
bottom: "label_2"
top: "accuracy_2"
include {
phase: TEST
}
}
layer {
name: "accuracy_3"
type: "Accuracy"
bottom: "fc8_3"
bottom: "label_3"
top: "accuracy_3"
include {
phase: TEST
}
}
layer {
name: "accuracy_4"
type: "Accuracy"
bottom: "fc8_4"
bottom: "label_4"
top: "accuracy_4"
include {
phase: TEST
}
}
layer {
name: "loss_1"
type: "SoftmaxWithLoss"
bottom: "fc8_1"
bottom: "label_1"
top: "loss_1"
loss_weight: 0.25
}
layer {
name: "loss_2"
type: "SoftmaxWithLoss"
bottom: "fc8_2"
bottom: "label_2"
top: "loss_2"
loss_weight: 0.25
}
layer {
name: "loss_3"
type: "SoftmaxWithLoss"
bottom: "fc8_3"
bottom: "label_3"
top: "loss_3"
loss_weight: 0.25
}
layer {
name: "loss_4"
type: "SoftmaxWithLoss"
bottom: "fc8_4"
bottom: "label_4"
top: "loss_4"
loss_weight: 0.25
}
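In other words, the single branch from fc7 onward is replaced with 4 branches, and loss and accuracy are then computed separately for each branch.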
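How the four weighted losses combine into the training objective can be sketched for a single sample as follows. This is a plain-Python illustration under the assumption that each head is a 36-way softmax with cross-entropy loss and a weight of 0.25, as in the prototxt above; the function names are hypothetical, and Caffe additionally averages over the batch.

```python
import math


def softmax_cross_entropy(logits, target):
    """Cross-entropy of softmax(logits) against the integer class `target`."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum - logits[target]


def total_loss(head_logits, targets, weights=(0.25, 0.25, 0.25, 0.25)):
    """Weighted sum of the per-head losses, one head per character position."""
    return sum(w * softmax_cross_entropy(logits, t)
               for w, logits, t in zip(weights, head_logits, targets))


# With uninformative (all-equal) scores, every head has loss ln(36).
uniform = [[0.0] * 36 for _ in range(4)]
print(round(total_loss(uniform, [5, 16, 28, 11]), 4))  # 3.5835
```

Because the weights sum to 1, the total is the mean of the four head losses, so a value near ln(36) ≈ 3.58 at the start of training indicates that the heads are still guessing uniformly.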
solver.prototxt:
net: "train_val.prototxt"
test_iter: 100
test_interval: 500
base_lr: 0.01
lr_policy: "step"
gamma: 0.1
stepsize: 6000
display: 10
max_iter: 10000
momentum: 0.9
weight_decay: 0.0005
snapshot: 10000
snapshot_prefix: "multi-label-classification"
solver_mode: GPU
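With lr_policy "step", Caffe multiplies the learning rate by gamma once every stepsize iterations. The schedule implied by the settings above can be sketched as follows; the function name is hypothetical.

```python
def step_lr(iteration, base_lr=0.01, gamma=0.1, stepsize=6000):
    """Learning rate under the 'step' policy: base_lr * gamma ^ floor(iteration / stepsize)."""
    return base_lr * gamma ** (iteration // stepsize)


for it in (0, 5999, 6000, 9999):
    print(it, step_lr(it))
```

With max_iter set to 10000 and stepsize set to 6000, the rate drops exactly once, from 0.01 to roughly 0.001 at iteration 6000, and stays there until training ends.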
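The model modifications are complete.
4. Training and testing the model
Training the model
The /multi-label-classification/ folder now contains the following:
Open bash in the /multi-label-classification/ folder and run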
/home/ys/caffe/build/tools/caffe train --solver solver.prototxt --gpu 0
This starts model training; the trained model files are saved in the /multi-label-classification/ folder.
Testing the model
Open bash in the /multi-label-classification/ folder and run
/home/ys/caffe/build/tools/caffe test \
-model train_val.prototxt \
-weights multi-label-classification_iter_10000.caffemodel \
-iterations 100
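This shows the test results of the trained model.
Visualizing test results with pycaffe
Following this article, use Caffe's Python interface to test a single image.
The /multi-label-classification/data/test_images/ folder now contains one test image:
The Python script pycaffe_test.py loads the trained Caffe model and runs a prediction on this image.
pycaffe_test.py: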
# encoding:utf-8
import numpy as np
import sys, os
import time
# Caffe root directory; change to match your setup
caffe_root = '/home/ys/caffe/'
# Add caffe/python to the module search path before importing caffe
sys.path.insert(0, caffe_root + 'python')
import caffe
caffe.set_device(0)
caffe.set_mode_gpu()
time_begin = time.time()
os.chdir(caffe_root)  # change the working directory
# Network definition
net_file = caffe_root + 'models/multi-label-classification/deploy.prototxt'
# Trained weights
caffe_model = caffe_root + 'models/multi-label-classification/multi-label-classification_iter_10000.caffemodel'
# Mean file
mean_file = caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy'
# The image preprocessing below is the same for any pycaffe program.
# Build a Net from the definition and the weights
net = caffe.Net(net_file, caffe_model, caffe.TEST)
# Create a transformer for the shape of the data blob
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
# caffe.io.load_image returns pixels in [0, 1], layout [h, w, c], RGB order.
# The network expects pixels in [0, 255], layout [c, h, w], BGR order, so convert.
# Move the channel axis to the front
transformer.set_transpose('data', (2, 0, 1))
transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))
# Scale pixel values to [0, 255]
transformer.set_raw_scale('data', 255)
# RGB --> BGR
transformer.set_channel_swap('data', (2, 1, 0))
# Load one test image
image_file = caffe_root + 'models/multi-label-classification/data/test_images/000001.A86I.jpg'
im = caffe.io.load_image(image_file)
# Preprocess the image with the transformer and copy it into the data blob
net.blobs['data'].data[...] = transformer.preprocess('data', im)
# Run the forward pass
output = net.forward()
# Per-position class probabilities for this image
output_prob1 = output['prob_1'][0]
output_prob2 = output['prob_2'][0]
output_prob3 = output['prob_3'][0]
output_prob4 = output['prob_4'][0]
# Pick the most probable class at each position
chars = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I",
         "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"]
predicted = [chars[p.argmax()] for p in (output_prob1, output_prob2, output_prob3, output_prob4)]
print('test image: %s' % image_file)
print('the predicted result is: %s' % ' '.join(predicted))
print('time used: %s s' % round(time.time() - time_begin, 4))
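The decoding step at the end of the script, taking the argmax of each head and mapping it back to a character, can be isolated and checked without Caffe. The sketch below uses plain lists in place of the probability blobs; the function names are hypothetical.

```python
CHARS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def argmax(values):
    """Index of the largest value (the first one on ties)."""
    best = 0
    for i, v in enumerate(values):
        if v > values[best]:
            best = i
    return best


def decode_prediction(head_probs):
    """Turn one probability vector per character position into the predicted string."""
    return "".join(CHARS[argmax(p)] for p in head_probs)


def one_hot(index, size=36):
    """Helper for the demo: a probability vector with all mass on `index`."""
    return [1.0 if i == index else 0.0 for i in range(size)]


# Indices 10, 8, 6, 18 correspond to 'A', '8', '6', 'I'.
print(decode_prediction([one_hot(10), one_hot(8), one_hot(6), one_hot(18)]))  # A86I
```

The index order must match the mapping used when train.txt was written (digits first, then letters); otherwise the decoded characters will be shifted.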
Open bash in the /multi-label-classification/ folder and run
python ./pycaffe_test.py
Output:
The code used in this post is available here.