caffe實現多標簽分類模型

常規(guī)的基于CNN的圖像分類網絡如Lenet扣孟、Alexnet勉耀、VGGnet等都是單分類模型瘫想,本文記錄在ubuntu16.04下如何對傳統(tǒng)的單分類模型進行調整醋粟,實現多標簽分類的效果粪躬,這里主要指的是對固定長度字符串的識別,相同原理可用于驗證碼識別車牌識別昔穴。

聲明:本文代碼主要來自于以下兩篇博文:

下面整理了使用caffe完成多標簽分類(multi-label classification)模型訓練測試的整個流程镰官,主要分為4個部分:

  1. 如何制作多標簽分類數據集;
  2. 修改caffe源代碼吗货,實現多標簽數據集的轉換和讀扔具搿;
  3. 修改分類模型Alexnet宙搬,實現多標簽分類笨腥;
  4. 模型的訓練和測試拓哺。

1.如何制作多標簽分類數據集


制作的數據集圖片類似于:

這里的每張圖片中包含4個字符(0-9或者A-Z),通過對代碼的簡單修改脖母,可以擴展成任意長度士鸥。

為了簡單,將車牌識別中的不分割字符的端到端(End-to-End)識別中的源代碼修改簡化谆级。

首先建立一個名為multi-label-classification的文件夾烤礁,下面的子目錄/子文件如下:

其中藍色的是文件夾,其他顏色的是文件肥照。

生成多標簽字符圖片的思路大概是:

  • 首先確定字符串的長度脚仔,即想要生成包含幾個標簽的圖片;

  • 根據字符串的長度舆绎,確定圖像的尺寸;比如我生成4個字符的圖片吕朵,再考慮單個字符和長寬比,字符間的間隙努溃,以及字體的大小,確定4字符圖像的長和寬是90x30;

  • 需要找到一種.ttf格式的字體则拷,這根據實際情況選擇合適的字體;

  • 接下來斥铺,需要確定圖像要用什么樣的背景;比如我隨便找了十幾種顏色的背景圖片(放在../background/文件夾下)坛善,部分顯示如下晾蜘,每張都是90x30大小。


  • 最后考慮需要對字符串圖像做什么處理眠屎,比如隨機旋轉剔交,畸變處理,加入噪聲改衩,模糊等等岖常,用于增強模型的泛化能力。

下面是gen_character.py的代碼:

#coding=utf-8
import PIL
from PIL import ImageFont
from PIL import Image
from PIL import ImageDraw
import cv2
import numpy as np
import os
from math import *

chars = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A","B", "C", "D", "E", "F", "G", "H", "I",
         "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X","Y", "Z"];

# 引入畸變葫督,將使字符隨機地向左或者向右傾斜一個隨機的角度(4-10個像素值)
def distortionRandom(img):
    w = img.shape[1]
    h = img.shape[0]
    pts1 = np.float32([[0, 0], [0, h], [w, 0], [w, h]])
    pos_or_neg = np.random.random_integers(0,1)
    distortion_value = np.random.random_integers(4,10)
    if(pos_or_neg==0):
        pts2 = np.float32([[0, 0], [distortion_value, h], [w-distortion_value, 0], [w, h]])
    else:
        pts2 = np.float32([[distortion_value, 0], [0, h], [w, 0], [w-distortion_value, h]])
    M = cv2.getPerspectiveTransform(pts1, pts2)
    dst = cv2.warpPerspective(img, M, (w,h))
    return dst

# 在背景圖像塊中寫入一個字符
def GenCh(f,val):
    img=Image.new("RGB", (16,28),(255,255,255))
    draw = ImageDraw.Draw(img)
    draw.text((2, 0),val.decode('utf-8'),(0,0,0),font=f, align="center")
    A = np.array(img)
    A = cv2.resize(A, (22,28))
    return A

# 定義一個類GenCharacter竭鞍,用于生成固定長度多標簽圖片
class GenCharacter:
    def __init__(self,font):
        # 初始化所用的字符字體
        self.fontE =  ImageFont.truetype(font,28,0)
        # 初始化多標簽圖片的大小為90x30
        self.img=np.array(Image.new("RGB", (90,30),(255,255,255)))
        # 初始化標簽圖片所用的背景板惑,這里在./background/文件夾中準備了十幾張90x30的不同背景
        # 全部讀取到一個list中,生成多標簽圖片時隨機選擇某一個背景
        self.bgs = []
        for file in os.listdir("./background/"):
            bg  = cv2.resize(cv2.imread("./background/"+file),(90,30))
            self.bgs.append(bg)

    # 將長度為4的字符串寫入90x30的圖片中
    def draw(self,val):
        offset = 2
        for i in range(4):
            base = offset + i*22
            self.img[0:28, base:base+22]= GenCh(self.fontE,val[i])
        return self.img

    # 生成一張帶隨機背景的隨機字符串
    def generate(self,text):
        if len(text) == 4:
            fg = self.draw(text.decode(encoding="utf-8"))
            fg = cv2.bitwise_not(fg)
            k = np.random.random_integers(0,len(self.bgs)-1)
            com = cv2.bitwise_or(fg,self.bgs[k])
            com = distortionRandom(com)
            com = cv2.bitwise_or(com,self.bgs[k])
            return com
    # 隨機生成長度為4的字符串
    def genCharacterString(self):
        CharacterStr = ""
        box = [0,0,0,0]
        for unit,cpos in zip(box,xrange(len(box))):
                CharacterStr += chars[np.random.random_integers(0,35)]
        return CharacterStr
    # 生成指定批次大小的多標簽圖片偎快,病保存到指定文件夾
    def genBatch(self, batchSize,outputPath):
        if (not os.path.exists(outputPath)):
            os.mkdir(outputPath)
        for i in xrange(batchSize):
            CharacterStr = G.genCharacterString()
            img =  G.generate(CharacterStr)
            filename = os.path.join(outputPath, str(i).zfill(6) + '.' + CharacterStr + ".jpg")
            cv2.imwrite(filename, img)

G = GenCharacter('./font/platechar.ttf')
G.genBatch(30000,"./data/train")
G.genBatch(10000,"./data/val")

直接在/multi-label-classification/文件夾下打開bash,執(zhí)行

python ./gen_character.py

生成30000張訓練集圖片和10000張驗證集圖片冯乘。

如何生成train.txt和val.txt文本文件?

使用過caffe分類模型的同學應該清楚晒夹,除了圖片文件之外裆馒,還需要保存有圖片名和對應gt-label的train.txt和val.txt文本文件,寫了一個簡單的python腳本實現:

create_train_txt.py:

#coding=utf-8
#根據圖像名的特點如000001.5GSB.jpg惋戏,生成gt-label文件
import os
train_src_path = "data/train/"
train_dst_file = "data/train.txt"

if __name__ == '__main__':
    train_file = open(train_dst_file, 'w')
    k=0
    for file in os.listdir(train_src_path):
        lines = file
        strs = file.split('.')
        for i in range(4):
            cha = strs[1][i]
            # '0'-'9'對應的ASCII碼值是48-57,'A'-'Z'對應的ASCII碼值是65-90,
            # 這里為了方便领追,將'0'-'9'減去48映射到0-9;將'A'-'Z'減去55映射到10-35,
            if ord(cha)>=65:
                num = ord(cha)-55
            else:
                num = ord(cha)-48
            lines+=' '+str(num)
        lines+='\n'
        train_file.writelines(lines)
        k+=1
    train_file.close()
    print('there are %d images in total' % int(k))
    print('done')

create_train_txt.py文件放在/multi-label-classification/文件夾下,在/multi-label-classification/文件夾下打開bash,執(zhí)行

python ./create_train_txt.py

將在/multi-label-classification/data/下面生成train.txt文件响逢。

將上面代碼中路徑名的train改成val,相同的方法些膨,生成val.txt文件订雾。

比如train.txt文件的部分內容如下:

接下來,需要將多標簽的訓練集和驗證集轉換成LMDB格式噩峦,這一步需要對/caffe/tools/convert_imageset.cpp文件做修改识补,所以這一步留到后面進行凭涂。

2.修改caffe源代碼切油,實現多標簽數據集的轉換和讀取


下載的caffe源碼中有一個/caffe/tools/convert_imageset.cpp文件,使用它可以將圖像圖像格式的數據集轉換成LMDB格式,但它只能處理單標簽的數據集岛琼,為了處理多標簽數據集,需要修改convert_imageset.cpp文件阁苞;而convert_imageset.cpp的實現涉及到io.hpp和io.cpp中的函數悼沿,于是要修改io.hpp和io.cpp糟趾。

同樣义郑,caffe的Data層也只能讀取單標簽的數據集非驮,為了處理多標簽數據集劫笙,需要修改data_layer.cpp文件。

另外栋盹,需要在caffe.proto中添加一個參數敷矫。

總的來說,需要修改以下幾個文件:

  • /caffe/tools/convert_imageset.cpp
  • /caffe/include/caffe/util/io.hpp
  • /caffe/src/caffe/util/io.cpp
  • /caffe/src/caffe/proto/caffe.proto
  • /caffe/src/caffe/layers/data_layer.cpp

原來的代碼用/* ... */注釋掉怎茫,新增的代碼用////////////// ...... //////////////////包圍起來

修改/caffe/tools/convert_imageset.cpp轨蛤,在約74行處:

/*
  std::ifstream infile(argv[2]);
  std::vector<std::pair<std::string, int> > lines;
  std::string line;
  size_t pos;
  int label;
  while (std::getline(infile, line)) {
    pos = line.find_last_of(' ');
    label = atoi(line.substr(pos + 1).c_str());
    lines.push_back(std::make_pair(line.substr(0, pos), label));
  }
  */
  ////////////////////////////
  std::ifstream infile(argv[2]);
  std::vector<std::pair<std::string, vector<float> > > lines;
  std::string filename;
  vector<float> labels(4);
  while (infile >> filename >> labels[0] >> labels[1] >> labels[2] >> labels[3]){
      lines.push_back(std::make_pair(filename, labels));
  }
  ///////////////////////////

修改/caffe/include/caffe/util/io.hpp圃验。

在其中新加入/////// ..... ///////內的兩個成員函數聲明澳窑,不刪除原來的任何代碼,下面的前兩個函數聲明是原來文件中就有的栈暇,可以看到悲立,原來代碼中的label參數是int類型薪夕,只能處理單標簽字符原献;新增的兩個成員函數就是參考上面兩個函數姑隅,將const int label參數改成了std::vector<float> labels,以接受多標簽字符鄙陡。

bool ReadImageToDatum(const string& filename, const int label,
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum);
bool ReadFileToDatum(const string& filename, const int label, Datum* datum);
//////////////////////////////////////////
bool ReadImageToDatum(const string& filename, std::vector<float> labels,
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum);
bool ReadFileLabelsToDatum(const string& filename, std::vector<float> labels,
    Datum* datum);
///////////////////////////////////

修改/caffe/src/caffe/util/io.cpp趁矾。

在ReadImageToDatum()函數實現下面添加下面函數實現毫捣,約143行處:

//////////////////////////////////////////////////////////////////////////
bool ReadImageToDatum(const string& filename, std::vector<float> labels,
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum)
{
    std::cout << filename << " " << labels[0] << " " << labels[1] << " " << labels[2] << " " << labels[3] << std::endl;
    cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
    if (cv_img.data) {
        if (encoding.size()) {
            if ((cv_img.channels() == 3) == is_color && !height && !width &&
                matchExt(filename, encoding))
                //return ReadFileToDatum(filename, label, datum);
                return ReadFileLabelsToDatum(filename, labels, datum);//ReadFileToDatum -> ReadFileLabelsToDatum
            std::vector<uchar> buf;
            cv::imencode("." + encoding, cv_img, buf);
            datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),
                buf.size()));
            //datum->set_label(label);
            datum->clear_labels();
            datum->add_labels(labels[0]);
            datum->add_labels(labels[1]);
            datum->add_labels(labels[2]);
            datum->add_labels(labels[3]);
            //////////////////
            datum->set_encoded(true);
            return true;
        }
        CVMatToDatum(cv_img, datum);
        //datum->set_label(label);
        datum->clear_labels();
        datum->add_labels(labels[0]);
        datum->add_labels(labels[1]);
        datum->add_labels(labels[2]);
        datum->add_labels(labels[3]);
        //////////////////
        return true;
    }
    else {
        return false;
    }
}
/////////////////////////////////////////////////////////////////////

在ReadFileToDatum()函數實現下面添加下面的函數實現,約209行處:

//////////////////////////////////////////////////////////////////////
bool ReadFileLabelsToDatum(const string& filename, std::vector<float> labels,
    Datum* datum)
{
    std::streampos size;

    fstream file(filename.c_str(), ios::in | ios::binary | ios::ate);
    if (file.is_open()) {
        size = file.tellg();
        std::string buffer(size, ' ');
        file.seekg(0, ios::beg);
        file.read(&buffer[0], size);
        file.close();
        datum->set_data(buffer);
        //datum->set_label(label);
        datum->clear_labels();
        datum->add_labels(labels[0]);
        datum->add_labels(labels[1]);
        datum->add_labels(labels[2]);
        datum->add_labels(labels[3]);
        //////////////////
        datum->set_encoded(true);
        return true;
    }
    else {
        return false;
    }
}
///////////////////////////////////////////////////////

修改/caffe/src/caffe/proto/caffe.proto。

在下面的源代碼中添加一行代碼蹋宦,即添加一個labels冷冗,是repeated類型的蒿辙,以便接受多標簽數據集思灌。

message Datum {
  optional int32 channels = 1;
  optional int32 height = 2;
  optional int32 width = 3;
  // the actual image data, in bytes
  optional bytes data = 4;
  optional int32 label = 5;
  // Optionally, the datum could also hold float data.
  repeated float float_data = 6;
  // If true data contains an encoded image that need to be decoded
  optional bool encoded = 7 [default = false];
  //////////////////////////////////
  repeated float labels = 8;
  //////////////////////////////////
}

修改/caffe/src/caffe/layers/data_layer.cpp。

約49行處:

// label
  /*
  if (this->output_labels_) {
    vector<int> label_shape(1, batch_size);
    top[1]->Reshape(label_shape);
    for (int i = 0; i < this->prefetch_.size(); ++i) {
      this->prefetch_[i]->label_.Reshape(label_shape);
    }
  }
  */
  /////////////////////////////////////////////////
  if (this->output_labels_){
      top[1]->Reshape(batch_size, 4, 1, 1);
      for (int i = 0; i < this->prefetch_.size(); ++i) {
          this->prefetch_[i]->label_.Reshape(batch_size, 4, 1, 1);
      }
  }
  //////////////////////////////////////////////////

約128行處:

// Copy label.
    /*
    if (this->output_labels_) {
      Dtype* top_label = batch->label_.mutable_cpu_data();
      top_label[item_id] = datum.label();
    }
    */
    ///////////////////////////////////////////////
    if (this->output_labels_) {
      Dtype* top_label = batch->label_.mutable_cpu_data();
      for (int i = 0; i < 4; i++)
                 top_label[item_id * 4 + i] = datum.labels(i);
      }
    ///////////////////////////////////////////////

修改完成耗跛,在caffe根目錄執(zhí)行:

make clean
make all -j8

將修改后的caffe重新編譯。

將原始數據集轉換成LMDB格式

修改編譯caffe后羔砾,就可以使用convert_imageset工具將原始數據集轉換成LMDB格式了姜凄。

執(zhí)行腳本create_train_val_lmdb.sh進行完成數據集轉換檀葛。

create_train_val_lmdb.sh內容:

echo "create train lmdb..."
/home/ys/caffe/build/tools/convert_imageset \
--resize_height=227 \
--resize_width=227 \
--backend="lmdb" \
--shuffle \
/home/ys/caffe/models/multi-label-classification/data/train/ \
/home/ys/caffe/models/multi-label-classification/data/train.txt \
/home/ys/caffe/models/multi-label-classification/data/train_lmdb
echo "done"
echo "create val lmdb..."
/home/ys/caffe/build/tools/convert_imageset \
--resize_height=227 \
--resize_width=227 \
--backend="lmdb" \
--shuffle \
/home/ys/caffe/models/multi-label-classification/data/val/ \
/home/ys/caffe/models/multi-label-classification/data/val.txt \
/home/ys/caffe/models/multi-label-classification/data/val_lmdb
echo "done"

文件路徑根據自己的實際情況更改。

3.修改分類模型Alexnet藏鹊,實現多標簽分類


在/caffe/models/bvlc_alexnet/下有經典的Alexnet模型楚殿,其train_val.prototxt模型結構如下:

alexnet.png

將其修改后的train_val.prototxt模型結構如下:

multi-label-alexnet.png

Data層不改動,在Data層后面新增了一個Slice層影涉,將Data層讀取的多標簽分解:

layer {
  name: "slicers"
  type: "Slice"
  bottom: "label"
  top: "label_1"
  top: "label_2"
  top: "label_3"
  top: "label_4"
  slice_param {
    axis: 1
    slice_point: 1
    slice_point: 2
    slice_point: 3
  }
}

之后的Conv1層一直到fc6層的Dropout層都不變匣缘,然后將后面的fc7層以后的內容改成如下:

layer {
  name: "fc7_1"
  type: "InnerProduct"
  bottom: "fc6"
  top: "fc7_1"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  inner_product_param {
    num_output: 4096
    weight_filler {
      type: "gaussian"
      std: 0.005
    }
    bias_filler {
      type: "constant"
      value: 0.1
    }
  }
}
layer {
  name: "relu7_1"
  type: "ReLU"
  bottom: "fc7_1"
  top: "fc7_1"
}
layer {
  name: "drop7_1"
  type: "Dropout"
  bottom: "fc7_1"
  top: "fc7_1"
  dropout_param {
    dropout_ratio: 0.5
  }
}
layer {
  name: "fc7_2"
  type: "InnerProduct"
  bottom: "fc6"
  top: "fc7_2"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  inner_product_param {
    num_output: 4096
    weight_filler {
      type: "gaussian"
      std: 0.005
    }
    bias_filler {
      type: "constant"
      value: 0.1
    }
  }
}
layer {
  name: "relu7_2"
  type: "ReLU"
  bottom: "fc7_2"
  top: "fc7_2"
}
layer {
  name: "drop7_2"
  type: "Dropout"
  bottom: "fc7_2"
  top: "fc7_2"
  dropout_param {
    dropout_ratio: 0.5
  }
}
layer {
  name: "fc7_3"
  type: "InnerProduct"
  bottom: "fc6"
  top: "fc7_3"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  inner_product_param {
    num_output: 4096
    weight_filler {
      type: "gaussian"
      std: 0.005
    }
    bias_filler {
      type: "constant"
      value: 0.1
    }
  }
}
layer {
  name: "relu7_3"
  type: "ReLU"
  bottom: "fc7_3"
  top: "fc7_3"
}
layer {
  name: "drop7_3"
  type: "Dropout"
  bottom: "fc7_3"
  top: "fc7_3"
  dropout_param {
    dropout_ratio: 0.5
  }
}
layer {
  name: "fc7_4"
  type: "InnerProduct"
  bottom: "fc6"
  top: "fc7_4"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  inner_product_param {
    num_output: 4096
    weight_filler {
      type: "gaussian"
      std: 0.005
    }
    bias_filler {
      type: "constant"
      value: 0.1
    }
  }
}
layer {
  name: "relu7_4"
  type: "ReLU"
  bottom: "fc7_4"
  top: "fc7_4"
}
layer {
  name: "drop7_4"
  type: "Dropout"
  bottom: "fc7_4"
  top: "fc7_4"
  dropout_param {
    dropout_ratio: 0.5
  }
}

layer {
  name: "fc8_1"
  type: "InnerProduct"
  bottom: "fc7_1"
  top: "fc8_1"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  inner_product_param {
    num_output: 36 #1000->36
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
layer {
  name: "fc8_2"
  type: "InnerProduct"
  bottom: "fc7_2"
  top: "fc8_2"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  inner_product_param {
    num_output: 36 #1000->36
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
layer {
  name: "fc8_3"
  type: "InnerProduct"
  bottom: "fc7_3"
  top: "fc8_3"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  inner_product_param {
    num_output: 36 #1000->36
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
layer {
  name: "fc8_4"
  type: "InnerProduct"
  bottom: "fc7_4"
  top: "fc8_4"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  inner_product_param {
    num_output: 36 #1000->36
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}

layer {
  name: "accuracy_1"
  type: "Accuracy"
  bottom: "fc8_1"
  bottom: "label_1"
  top: "accuracy_1"
  include {
    phase: TEST
  }
}
layer {
  name: "accuracy_2"
  type: "Accuracy"
  bottom: "fc8_2"
  bottom: "label_2"
  top: "accuracy_2"
  include {
    phase: TEST
  }
}
layer {
  name: "accuracy_3"
  type: "Accuracy"
  bottom: "fc8_3"
  bottom: "label_3"
  top: "accuracy_3"
  include {
    phase: TEST
  }
}
layer {
  name: "accuracy_4"
  type: "Accuracy"
  bottom: "fc8_4"
  bottom: "label_4"
  top: "accuracy_4"
  include {
    phase: TEST
  }
}


layer {
  name: "loss_1"
  type: "SoftmaxWithLoss"
  bottom: "fc8_1"
  bottom: "label_1"
  top: "loss_1"
  loss_weight: 0.25
}
layer {
  name: "loss_2"
  type: "SoftmaxWithLoss"
  bottom: "fc8_2"
  bottom: "label_2"
  top: "loss_2"
  loss_weight: 0.25
}
layer {
  name: "loss_3"
  type: "SoftmaxWithLoss"
  bottom: "fc8_3"
  bottom: "label_3"
  top: "loss_3"
  loss_weight: 0.25
}
layer {
  name: "loss_4"
  type: "SoftmaxWithLoss"
  bottom: "fc8_4"
  bottom: "label_4"
  top: "loss_4"
  loss_weight: 0.25
}

也就是層之前的單個分支改成了4個分支柑爸。后面分別計算loss和accuracy竖配。

solver.protxt代碼:

net: "train_val.prototxt"
test_iter: 100
test_interval: 500
base_lr: 0.01
lr_policy: "step"
gamma: 0.1
stepsize: 6000
display: 10
max_iter: 10000
momentum: 0.9
weight_decay: 0.0005
snapshot: 10000
snapshot_prefix: "multi-label-classification"
solver_mode: GPU

模型修改完成。

4.模型的訓練和測試


訓練模型

現在胁镐,/multi-label-classification/文件夾下有如下內容:

/multi-label-classification/文件夾下打開bash,執(zhí)行

/home/ys/caffe/build/tools/caffe train --solver solver.prototxt --gpu 0

開始模型訓練盯漂,訓練好的模型文件保存在了/multi-label-classification/文件夾下就缆。

測試模型

/multi-label-classification/文件夾下打開bash,執(zhí)行

/home/ys/caffe/build/tools/caffe test \
-model train_val.prototxt \
-weights multi-label-classification_iter_10000.caffemodel \
-iterations 100

即可查看訓練好的模型的測試效果竭宰。

使用pycaffe可視化測試結果

參考這篇文章狞甚,使用caffe的python接口測試單張圖片哼审。

現在在/multi-label-classification/data/test_images/文件夾下有一張測試圖片:

使用python腳本pycaffe_test.py加載訓練好的caffe模型對這張圖片進行預測涩盾。

pycaffe_test.py:

#encoding:utf-8
import numpy as np
import sys,os
import caffe
import time
caffe.set_device(0)
caffe.set_mode_gpu()
time_begin = time.time()
# 設置當前的工作環(huán)境在caffe下, 根據自己實際情況更改
caffe_root = '/home/ys/caffe/'
# 我們也把caffe/python也添加到當前環(huán)境
sys.path.insert(0, caffe_root + 'python')
os.chdir(caffe_root)#更換工作目錄

# 設置網絡結構
net_file=caffe_root + 'models/multi-label-classification/deploy.prototxt'
# 添加訓練之后的參數
caffe_model=caffe_root + 'models/multi-label-classification/multi-label-classification_iter_10000.caffemodel'
# 均值文件
mean_file=caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy'

# 這里對任何一個程序都是通用的,就是處理圖片
# 把上面添加的兩個變量都作為參數構造一個Net
net = caffe.Net(net_file,caffe_model,caffe.TEST)
# 得到data的形狀终畅,這里的圖片是默認matplotlib底層加載的
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
# matplotlib加載的image是像素[0-1],圖片的數據格式[h,w,c]离福,RGB
# caffe加載的圖片需要的是[0-255]像素妖爷,數據格式[c,h,w],BGR絮识,那么就需要轉換

# channel 放到前面
transformer.set_transpose('data', (2,0,1))
transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))
# 圖片像素放大到[0-255]
transformer.set_raw_scale('data', 255)
# RGB-->BGR 轉換
transformer.set_channel_swap('data', (2,1,0))

# 加載一張測試圖片
image_file = caffe_root+'models/multi-label-classification/data/test_images/000001.A86I.jpg'
im=caffe.io.load_image(image_file)
# 用上面的transformer.preprocess來處理剛剛加載圖片
net.blobs['data'].data[...] = transformer.preprocess('data',im)
#注意,網絡開始向前傳播啦
output = net.forward()
# 最終的結果: 當前這個圖片的屬于哪個物體的概率(列表表示)
output_prob1 = output['prob_1'][0]
output_prob2 = output['prob_2'][0]
output_prob3 = output['prob_3'][0]
output_prob4 = output['prob_4'][0]
# 找出最大的那個概率
chars = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A","B", "C", "D", "E", "F", "G", "H", "I",
         "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X","Y", "Z"];
print 'test image: ', image_file
print 'the predicted result is:', chars[output_prob1.argmax()],' ',chars[output_prob2.argmax()],' ',chars[output_prob3.argmax()],' ',chars[output_prob4.argmax()]
print 'time used: ', round(time.time()-time_begin, 4), 's'

/multi-label-classification/文件夾下打開bash,執(zhí)行

python ./pycaffe_test.py

運行結果:


本文用到的代碼在這里

?著作權歸作者所有,轉載或內容合作請聯系作者
  • 序言:七十年代末彼念,一起剝皮案震驚了整個濱河市逐沙,隨后出現的幾起案子吩案,更是在濱河造成了極大的恐慌徘郭,老刑警劉巖崎岂,帶你破解...
    沈念sama閱讀 206,839評論 6 482
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現場離奇詭異江醇,居然都是意外死亡陶夜,警方通過查閱死者的電腦和手機条辟,發(fā)現死者居然都...
    沈念sama閱讀 88,543評論 2 382
  • 文/潘曉璐 我一進店門羽嫡,熙熙樓的掌柜王于貴愁眉苦臉地迎上來杭棵,“玉大人魂爪,你說我怎么就攤上這事滓侍〈志” “怎么了街图?”我有些...
    開封第一講書人閱讀 153,116評論 0 344
  • 文/不壞的土叔 我叫張陵耘擂,是天一觀的道長絮姆。 經常有香客問我,道長蚁阳,這世上最難降的妖魔是什么颠悬? 我笑而不...
    開封第一講書人閱讀 55,371評論 1 279
  • 正文 為了忘掉前任赔癌,我火速辦了婚禮灾票,結果婚禮上茫虽,老公的妹妹穿的比我還像新娘濒析。我一直安慰自己悼枢,他們只是感情好馒索,可當我...
    茶點故事閱讀 64,384評論 5 374
  • 文/花漫 我一把揭開白布旨怠。 她就那樣靜靜地躺著鉴腻,像睡著了一般百揭。 火紅的嫁衣襯著肌膚如雪器一。 梳的紋絲不亂的頭發(fā)上祈秕,一...
    開封第一講書人閱讀 49,111評論 1 285
  • 那天,我揣著相機與錄音瞭亮,去河邊找鬼统翩。 笑死玻孟,一個胖子當著我的面吹牛黍翎,可吹牛的內容都是我干的匣掸。 我是一名探鬼主播氮双,決...
    沈念sama閱讀 38,416評論 3 400
  • 文/蒼蘭香墨 我猛地睜開眼送爸,長吁一口氣:“原來是場噩夢啊……” “哼暖释!你這毒婦竟也來了纹磺?” 一聲冷哼從身側響起亮曹,我...
    開封第一講書人閱讀 37,053評論 0 259
  • 序言:老撾萬榮一對情侶失蹤式矫,失蹤者是張志新(化名)和其女友劉穎衷佃,沒想到半個月后氏义,有當地人在樹林里發(fā)現了一具尸體惯悠,經...
    沈念sama閱讀 43,558評論 1 300
  • 正文 獨居荒郊野嶺守林人離奇死亡克婶,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 36,007評論 2 325
  • 正文 我和宋清朗相戀三年鸭蛙,在試婚紗的時候發(fā)現自己被綠了娶视。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片肪获。...
    茶點故事閱讀 38,117評論 1 334
  • 序言:一個原本活蹦亂跳的男人離奇死亡孝赫,死狀恐怖,靈堂內的尸體忽然破棺而出致开,到底是詐尸還是另有隱情雌桑,我是刑警寧澤校坑,帶...
    沈念sama閱讀 33,756評論 4 324
  • 正文 年R本政府宣布膏斤,位于F島的核電站莫辨,受9級特大地震影響盘榨,放射性物質發(fā)生泄漏。R本人自食惡果不足惜蟆融,卻給世界環(huán)境...
    茶點故事閱讀 39,324評論 3 307
  • 文/蒙蒙 一草巡、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧型酥,春花似錦山憨、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,315評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至由境,卻和暖如春棚亩,著一層夾襖步出監(jiān)牢的瞬間拒担,已是汗流浹背州弟。 一陣腳步聲響...
    開封第一講書人閱讀 31,539評論 1 262
  • 我被黑心中介騙來泰國打工啃奴, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留老厌,地道東北人。 一個月前我還...
    沈念sama閱讀 45,578評論 2 355
  • 正文 我出身青樓菌赖,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 42,877評論 2 345

推薦閱讀更多精彩內容