TensorFlow 訓練多任務多標簽模型

????????在學習了使用 TensorFlow 的 CNN 進行圖像分類之后询枚,現(xiàn)在對這些方法做一個簡單的拓展违帆,即來處理多任務多標簽的情形浙巫。為了便于說明金蜀,我們假設現(xiàn)在要對 0-9 這 10 個數(shù)字 和 A-Z (排除 I、O) 這 24 個字母進行識別的畴,所有的數(shù)據(jù)都使用 captcha 生成(讀過 TensorFlow 訓練 CNN 分類器 這篇文章的讀者應該不陌生了)渊抄。以下的代碼(命名為 generate_train_data.py)使用 captcha 生成了 100000 萬張 28 x 28 的圖像,每張圖像都是帶有大量噪聲的一個字符(所有字符見下面代碼中的 alphabets 列表丧裁,所有的圖像保存在文件夾 ./datasets/images 中护桦,每張圖像命名為 image圖像序號_類標號.jpg,其中的類標號為該字符在列表 alphabets 中的下標)煎娇。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 22 13:43:34 2018

@author: shirhe-lyh
"""

import cv2
import numpy as np

from captcha.image import ImageCaptcha


def generate_captcha(text='1'):
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image
    
    
if __name__ == '__main__':
    output_dir = './datasets/images/'
    alphabets = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J',
                 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 
                 'U', 'V', 'W', 'X', 'Y', 'Z']
    for i in range(100000):
        label = np.random.randint(0, 34)
        image = generate_captcha(alphabets[label])
        image_name = 'image{}_{}.jpg'.format(i+1, label)
        output_path = output_dir + image_name
        cv2.imwrite(output_path, image)

我們的目的是訓練一個簡單的 CNN 模型將對這些圖像進行分類二庵,由于這個問題很簡單,直接訓練一個 34 類的分類器就達成目標了缓呛。但類別數(shù)越大催享,訓練就越困難,因此我們采取另一種分化的策略哟绊,將這個 34 類的問題分為兩個子問題因妙,分別是:1.只識別數(shù)字;2.只識別字母票髓。之所以可以這么分攀涵,是因為 數(shù)字字母 的差別很大,完全可以認為它們屬于兩種不同的范疇洽沟,從而可以看成獨立的分類任務來處理以故。這樣我們現(xiàn)在的問題是:怎樣同時識別 10 個數(shù)字和 24 個字母?這是一個多任務多標簽問題:我們要處理識別數(shù)字和識別字母這兩個任務裆操,其中每個任務都是涉及多個標簽(分別是 10 個標簽和 24 個標簽)怒详。

????????雖然這篇文章舉例的這個問題非常簡單,但這個方法(再加上預訓練模型技巧)可以用于更加復雜的問題跷车,比如 阿里的 FashionAI 服飾屬性識別全球挑戰(zhàn)賽棘利,感興趣的朋友可以用 ResNet-50 預訓練模型去微調一個 8 任務模型。

????????本文的所有代碼見 github:multi_task_test朽缴,歡迎訪問交流并反饋問題善玫!

一、多分支 CNN 模型定義

????????雖然我們要處理的是兩個獨立的任務,但我們希望這兩個任務共用大部分的神經(jīng)網(wǎng)絡層茅郎,這樣既可以節(jié)省計算量蜗元,一般來說,也可以提升準確率系冗。因此奕扣,我們將要定義的神經(jīng)網(wǎng)絡結構設計為(所有共用的層在文章 TensorFlow-slim 訓練 CNN 分類模型 中用來識別 0-9 這 10 個數(shù)字):

兩分支輸出的 CNN 用于識別數(shù)字和字母兩個任務

當獲取了一張圖像(數(shù)字或字母)之后,將它送入第一個卷積層(conv1)掌敬、第二個卷積層(conv2)惯豆、······,直到第二個全連接層(fc2)奔害,到此為止楷兽,這些層都是兩個任務共用的,它們的作用是用來提取圖像特征华临。然后芯杀,針對兩個不同的任務,將網(wǎng)絡分為兩個分支雅潭,一個用于輸出該圖像是各個數(shù)字的概率(digits_output)揭厚,另一個用于輸出該圖像是各個字母的概率(letters_output)。網(wǎng)絡的具體定義如下(網(wǎng)絡各層的名字可能和上圖不一致):

def predict(self, preprocessed_inputs):
    """Predict prediction tensors from inputs tensor.
        
    Outputs of this function can be passed to loss or postprocess functions.
        
   Args:
        preprocessed_inputs: A float32 tensor with shape [batch_size,
            height, width, num_channels] representing a batch of images.
            
    Returns:
        prediction_dict: A dictionary holding prediction tensors to be
            passed to the Loss or Postprocess functions.
    """
    net = preprocessed_inputs
    net = slim.repeat(net, 2, slim.conv2d, 32, [3, 3], scope='conv1')
    net = slim.max_pool2d(net, [2, 2], scope='pool1')
    net = slim.repeat(net, 2, slim.conv2d, 64, [3, 3], scope='conv2')
    net = slim.max_pool2d(net, [2, 2], scope='pool2')
    net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv3')
    net = slim.flatten(net, scope='flatten')
    net = slim.dropout(net, keep_prob=0.5,
                       is_training=self._is_training)
    net = slim.fully_connected(net, 512, scope='fc1')
    net = slim.fully_connected(net, 512, scope='fc2')
    prediction_dict = {}
    for class_name, num_classes in self.num_classes_dict.items():
        logits = slim.fully_connected(net, num_outputs=num_classes, 
                                      activation_fn=None, 
                                      scope='Predict/' + class_name)
        prediction_dict[class_name] = logits
    return prediction_dict

????????從以上代碼可以看到扶供,多任務多標簽任務的 CNN 定義也非常簡單筛圆,只需要引入一個 for 循環(huán)即可。接下來诚欠,要定義損失函數(shù)和準確率函數(shù)顽染。

????????在生成圖像的時候,圖片名字命名的模式是 image圖像序號_類標號.jpg轰绵,比如粉寞,假設第 1 張圖像是字母 G,那么它的類標號是 16 = 10 + 7 - 1左腔,因此它的名字是 image1_16.jpg唧垦。但這個類標號 16 是基于所有 34 個類來說的,實際上液样,如果只限于字母來說振亮,它的類標號應該是 6。之所以對數(shù)字和字母使用統(tǒng)一的類標號鞭莽,其實是為了便于定義損失和準確率函數(shù)坊秸。原因在于:對字母 G,因為我們現(xiàn)在是獨立處理數(shù)字和字母兩個分支任務澎怒,因此 G 應該只對分類字母的分支貢獻損失褒搔,而不應當對分類數(shù)字的分支產(chǎn)生損失。如果統(tǒng)一對數(shù)字和字母分配類標號,那么 G 的類標號 16 的獨熱(one-hot)編碼是:

0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

其中的 - 是為了便于看清兩個任務的分界線星瘾,實際請忽略走孽。此時,在計算損失時琳状,將這個獨熱編碼一分為二:

0 0 0 0 0 0 0 0 0 0                        0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

前一部分對應于 G 在分類 0-9 這 10 個數(shù)字的任務內的(非嚴格獨熱)編碼磕瓷,因為全部為 0,因此在計算分類交叉熵的時候損失為 0念逞,這是我們期望的困食;后一部分恰好是 G 在分類 A-Z(排除 I、O)這 24 個字母的任務內的獨熱編碼肮柜,正好用于計算分類交叉熵陷舅,也是我們期望的倒彰,可見統(tǒng)一分配類標號在計算損失時是非常方便的审洞。了解了這一點之后,損失函數(shù)的定義如下:

 def loss(self, prediction_dict, groundtruth_lists):
    """Compute scalar loss tensors with respect to provided groundtruth.
        
    Args:
        prediction_dict: A dictionary holding prediction tensors.
        groundtruth_lists: A list of tensors holding groundtruth
            information, with one entry for each task.
                
    Returns:
        A dictionary mapping strings (loss names) to scalar tensors
            representing loss values.
    """
    onehot_labels_dict = self._onehot_groundtruth_dict(groundtruth_lists)
    for class_name in self.num_classes_dict:
        weights = tf.cast(tf.greater(
            tf.reduce_sum(onehot_labels_dict[class_name], axis=1), 0),
            dtype=tf.float32)
        slim.losses.softmax_cross_entropy(
            logits=prediction_dict[class_name], 
            onehot_labels=onehot_labels_dict[class_name],
            weights=weights,
            scope='Loss/' + class_name)
    loss = slim.losses.get_total_loss()
    loss_dict = {'loss': loss}
    return loss_dict
    
def _onehot_groundtruth_dict(self, groundtruth_lists):
    """Transform groundtruth lables to one-hot formats.
        
    Args:
        groundtruth_lists: A dict of tensors holding groundtruth
            information, with one entry for task.
                
    Returns:
        onehot_labels_dict: A dictionary mapping strings (class names) 
            to one-hot lable tensors.
    """
    one_hot = tf.one_hot(
        groundtruth_lists, depth=sum(self.num_classes_dict.values()))
    onehot_labels_dict = {}
    start_index = 0
    for class_name in self._class_order:
        onehot_labels_dict[class_name] = tf.slice(
            one_hot, [0, start_index], 
            [-1, self.num_classes_dict[class_name]])
        start_index += self.num_classes_dict[class_name]
    return onehot_labels_dict

其中待讳,函數(shù) _onehot_groundtruth_dict 用于將統(tǒng)一分配的類標號對應的獨熱編碼分為數(shù)字和字母這兩個任務對應的兩個獨熱編碼芒澜,之后的 loss 函數(shù)就可以用來計算正常的分類交叉熵損失。為了確保全 0 的獨熱編碼對應 0 的損失创淡,定義了 weights 這一個變量痴晦,它的作用是:當編碼為全 0 時,該樣本對應的損失權重為 0琳彩,因此貢獻的損失為 0誊酌,即不屬于這個分類任務的樣本對這個分類任務的損失貢獻為 0(雖然理論上全 0 的獨熱編碼對應的分類交叉熵為 0,但為了確保這點而不出現(xiàn)意外露乏,weights 是非常必要的)碧浊。

????????至于,準確率函數(shù)的定義則更簡單瘟仿,想法如下:當一張圖像經(jīng)過神經(jīng)網(wǎng)絡預測后箱锐,我們得到兩個分支任務的概率輸出,我們不關心它來源于哪個任務劳较,因為這不影響準確率的計算驹止;分別對兩個任務的概率輸出取 tf.argmax 得到在每個任務內的預測類標號,然后對這兩個預測的類標號再計算它在對應任務內的獨熱編碼观蜗,把這兩個獨熱編碼與上面計算損失時切割得到的兩個獨熱編碼分別按對應元素求和臊恋,如果求和結果中出現(xiàn) 2 說明預測結果正確,否則錯誤墓捻;對一個批量中的所有圖像累計處理之后抖仅,即可算出準確率。繼續(xù)上面的例子,前面已經(jīng)說過岸售,G 的類標號 16 對應的獨熱編碼一分為二的結果為:

0 0 0 0 0 0 0 0 0 0                        0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

假如現(xiàn)在神經(jīng)網(wǎng)絡的兩個分支預測的類標號分別為 1 和 6践樱,那么它們分別對應獨熱編碼:

0 1 0 0 0 0 0 0 0 0                        0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

以上獨熱編碼按位置對應相加,得到:

0 1 0 0 0 0 0 0 0 0                        0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

前一個結果(即 0 1 0 0 0 0 0 0 0 0)所有位置上都沒有出現(xiàn) 2凸丸,說明預測和實際的類標號沒有重合拷邢,對準確率沒有產(chǎn)生作用;后一個結果(即 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0) 中屎慢,第 6 個索引位置出現(xiàn) 2 說明預測和實際的類標號是一樣的瞭稼,因此預測正確,預測正確數(shù)加 1腻惠。顯然环肘,每一張圖像要么加 0 (兩個任務都預測錯誤)要么加 1(其中一個任務預測正確),因此這樣計算準確率是正確的(不可能加 2集灌,因為實際的兩個獨熱編碼中悔雹,其中的一個全是 0)。詳細的細節(jié)請參考如下完整代碼(將其命名為 model.py):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 16:54:02 2018

@author: shirhe-lyh
"""

import tensorflow as tf

from abc import ABCMeta
from abc import abstractmethod

slim = tf.contrib.slim


class BaseModel(object):
    """Abstract base class for any model."""
    __metaclass__ = ABCMeta
    
    def __init__(self, num_classes_dict):
        """Constructor.
        
        Args:
            num_classes: Number of classes.
        """
        self._num_classes_dict = num_classes_dict
        
    @property
    def num_classes_dict(self):
        return self._num_classes_dict
    
    @abstractmethod
    def preprocess(self, inputs):
        """Input preprocessing. To be override by implementations.
        
        Args:
            inputs: A float32 tensor with shape [batch_size, height, width,
                num_channels] representing a batch of images.
            
        Returns:
            preprocessed_inputs: A float32 tensor with shape [batch_size, 
                height, widht, num_channels] representing a batch of images.
        """
        pass
    
    @abstractmethod
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        pass
    
    @abstractmethod
    def postprocess(self, prediction_dict, **params):
        """Convert predicted output tensors to final forms.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            **params: Additional keyword arguments for specific implementations
                of specified models.
                
        Returns:
            A dictionary containing the postprocessed results.
        """
        pass
    
    @abstractmethod
    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: A list of tensors holding groundtruth
                information, with one entry for each image in the batch.
                
        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        pass
    
        
class Model(BaseModel):
    """xxx definition."""
    
    def __init__(self,
                 is_training,
                 num_classes_dict={'digits': 10, 'letters': 24}):
        """Constructor.
        
        Args:
            is_training: A boolean indicating whether the training version of
                computation graph should be constructed.
            num_classes: Number of classes.
        """
        super(Model, self).__init__(num_classes_dict=num_classes_dict)
        
        self._is_training = is_training
        self._class_order = ['digits', 'letters']
        
    def preprocess(self, inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        preprocessed_inputs = tf.to_float(inputs)
        preprocessed_inputs = tf.subtract(preprocessed_inputs, 128.0)
        preprocessed_inputs = tf.div(preprocessed_inputs, 128.0)
        return preprocessed_inputs
    
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        net = preprocessed_inputs
        net = slim.repeat(net, 2, slim.conv2d, 32, [3, 3], scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')
        net = slim.repeat(net, 2, slim.conv2d, 64, [3, 3], scope='conv2')
        net = slim.max_pool2d(net, [2, 2], scope='pool2')
        net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv3')
        net = slim.flatten(net, scope='flatten')
        net = slim.dropout(net, keep_prob=0.5,
                           is_training=self._is_training)
        net = slim.fully_connected(net, 512, scope='fc1')
        net = slim.fully_connected(net, 512, scope='fc2')
        prediction_dict = {}
        for class_name, num_classes in self.num_classes_dict.items():
            logits = slim.fully_connected(net, num_outputs=num_classes, 
                                          activation_fn=None, 
                                          scope='Predict/' + class_name)
            prediction_dict[class_name] = logits
        return prediction_dict
    
    def postprocess(self, prediction_dict):
        """Convert predicted output tensors to final forms.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            **params: Additional keyword arguments for specific implementations
                of specified models.
                
        Returns:
            A dictionary containing the postprocessed results.
        """
        postprecessed_dict = {}
        for class_name in self.num_classes_dict:
            logits = prediction_dict[class_name]
#            logits = tf.nn.softmax(logits, name=class_name)
            postprecessed_dict[class_name] = logits
        return postprecessed_dict
    
    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: A list of tensors holding groundtruth
                information, with one entry for each task.
                
        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        onehot_labels_dict = self._onehot_groundtruth_dict(groundtruth_lists)
        for class_name in self.num_classes_dict:
            weights = tf.cast(tf.greater(
                tf.reduce_sum(onehot_labels_dict[class_name], axis=1), 0),
                dtype=tf.float32)
            slim.losses.softmax_cross_entropy(
                logits=prediction_dict[class_name], 
                onehot_labels=onehot_labels_dict[class_name],
                weights=weights,
                scope='Loss/' + class_name)
        loss = slim.losses.get_total_loss()
        loss_dict = {'loss': loss}
        return loss_dict
    
    def _onehot_groundtruth_dict(self, groundtruth_lists):
        """Transform groundtruth lables to one-hot formats.
        
        Args:
            groundtruth_lists: A dict of tensors holding groundtruth
                information, with one entry for task.
                
        Returns:
            onehot_labels_dict: A dictionary mapping strings (class names) 
                to one-hot lable tensors.
        """
        one_hot = tf.one_hot(
            groundtruth_lists, depth=sum(self.num_classes_dict.values()))
        onehot_labels_dict = {}
        start_index = 0
        for class_name in self._class_order:
            onehot_labels_dict[class_name] = tf.slice(
                one_hot, [0, start_index], 
                [-1, self.num_classes_dict[class_name]])
            start_index += self.num_classes_dict[class_name]
        return onehot_labels_dict
    
    def accuracy(self, postprocessed_dict, groundtruth_lists):
        """Calculate accuracy.
        
        Args:
            postprocessed_dict: A dictionary containing the postprocessed 
                results
            groundtruth_lists: A dict of tensors holding groundtruth
                information, with one entry for each image in the batch.
                
        Returns:
            accuracy: The scalar accuracy.
        """
        onehot_labels_dict = self._onehot_groundtruth_dict(groundtruth_lists)
        num_corrections = 0.
        for class_name in self.num_classes_dict:
            predicted_argmax = tf.argmax(tf.nn.softmax(
                postprocessed_dict[class_name]), axis=1)
            onehot_predicted = tf.one_hot(
                predicted_argmax, depth=self.num_classes_dict[class_name])
            onehot_sum = tf.add(onehot_labels_dict[class_name],
                                onehot_predicted)
            correct = tf.greater(onehot_sum, 1)
            num = tf.reduce_sum(tf.cast(correct, tf.float32))
            num_corrections += num
        total_nums = tf.cast(tf.shape(groundtruth_lists)[0], dtype=tf.float32)
        accuracy = num_corrections / total_nums
        return accuracy 

????????在定義 postprocess 函數(shù)時欣喧,我把語句:

logits = tf.nn.softmax(logits, name=class_name)

注釋掉了(這顯得這個函數(shù)沒有任何用處)腌零,我的本意是為了觀察 predict 函數(shù)中兩個網(wǎng)絡分支的最本原輸出,主要考慮的是:當一張圖片送到網(wǎng)絡入口時唆阿,如果根本不知道它是數(shù)字還是字母益涧,那么經(jīng)過神經(jīng)網(wǎng)絡處理后,我們面臨著兩個任務的輸出驯鳖,要怎么判斷它屬于哪個任務中的哪個標簽呢闲询?如果我們已經(jīng)知道這張圖像來源于其中某一個任務,比如來源于數(shù)字任務浅辙,那么直接對數(shù)字任務分支的輸出取 tf.argmax 就知道它對應的預測標簽了扭弧。但現(xiàn)在的關鍵問題是,如果不知道它屬于其中哪個任務摔握,能否根據(jù)兩個分支的輸出直接判斷出來呢寄狼?答案是可以的,盡管這是基于經(jīng)驗觀察的氨淌。通過模型訓練并導出為 .pb 文件之后泊愧,運行 evaluate.py 文件(很多次),可以觀察兩個分支的直接輸出盛正,你會發(fā)現(xiàn)兩個任務中所有這些輸出的最大值對應的標簽就是網(wǎng)絡的預測輸出删咱,也就是說:可以通過比較兩個任務的所有輸出,來預測圖像來源于哪個任務(進而預測屬于哪個標簽)——所有輸出的值中豪筝,最大值所在的任務就可以認為是圖像來源的任務痰滋。

二摘能、模型訓練與保存

????????因為模型訓練的代碼和文章 TensorFlow-slim 訓練 CNN 分類模型(續(xù))train.py 的是一樣的,這里直接忽略(也可以訪問 github:multi_task_test 獲取本文所有代碼)敲街。

????????當你獲取到代碼后团搞,首先在項目當前目錄下新建文件夾 datasets/images,然后在當前目錄下的終端運行

python3 generate_train_data.py

生成 100000 張訓練圖像多艇。之后逻恐,繼續(xù)運行

python3 generate_tfrecord.py \
    --images_path ./datasets/images/ \
    --output_path ./datasets/train.record

得到訓練的 .record 文件。 此時峻黍,在項目目錄下再新建文件夾 training复隆,接著在終端執(zhí)行如下命令

python3 train.py --record_path ./datasets/train.record --logdir ./training/

便開始了訓練過程。如果你要可視化的觀看損失和準確率的變化情況姆涩,在當前目錄下的終端執(zhí)行

tensorboard --logdir ./training/

得到本地瀏覽器鏈接挽拂,打開這個鏈接即可監(jiān)控訓練的全過程。比如骨饿,我訓練 5000 多次之后亏栈,準確率和損失的圖像如下:

Tensorboard 顯示的準確率和損失曲線

????????當你覺得訓練的準確率已經(jīng)足夠高了,并且文件夾 training 中也保存好了當前訓練次數(shù)的模型文件之后样刷,使用 Ctrl + C 中斷訓練過程仑扑。接下來,就是將 training 中的訓練模型文件 .ckpt 轉化為 .pb 文件置鼻,然后測試訓練效果了。有關自定義的將 .ckpt 格式轉化為 .pb 格式的模型文件請訪問文章 TensorFlow 自定義模型導出:將 .ckpt 格式轉化為 .pb 格式蜓竹。在那篇文章中箕母,已經(jīng)指出,需要針對不同的分類模型做出改變的地方主要是包含 model 參數(shù)的那些函數(shù)俱济,尤其是由輸入得到輸出的函數(shù) _add_output_tensor_nodes嘶是。比如,我們這篇文章有兩個分支任務的輸出蛛碌,對應的函數(shù) _add_output_tensor_nodes 修改為:

def _add_output_tensor_nodes(postprocessed_tensors,
                             output_collection_name='inference_op'):
    """Adds output nodes.
    
    Adjust according to specified implementations.
    
    Adds the following nodes for output tensors:
        * classes: A float32 tensor of shape [batch_size] containing class
            predictions.
    
    Args:
        postprocessed_tensors: A dictionary containing the following fields:
            'classes': [batch_size].
        output_collection_name: Name of collection to add output tensors to.
        
    Returns:
        A tensor dict containing the added output tensor nodes.
    """
    outputs = {}
    for class_name, logits in postprocessed_tensors.items():
        outputs[class_name] = tf.identity(logits, name=class_name)
    for output_key in outputs:
        tf.add_to_collection(output_collection_name, outputs[output_key])
    return outputs

其它函數(shù)不需要修改聂喇,完整文件請查看 github:multi_task_testexport.py 文件。然后蔚携,在項目的當前目錄終端執(zhí)行模型導出命令:

python3 export_inference_graph.py \
    --trained_checkpoint_prefix ./training/model.ckpt-5265 \
    --output_directory ./training/inference_graph_pb

你會在 training 文件夾中看到一個新的文件夾 inference_graph_pb希太,里面的文件 frozen_inference_graph.pb 就是我們用來做模型推斷的文件。上面一條命令中的 model.ckpt-5265 請根據(jù)你自己的訓練情況做修改酝蜒,這里我是只訓練了 5000 多次誊辉,然后使用訓練了 5265 次的模型用于圖像推斷。

????????當你一切都順利執(zhí)行之后亡脑,恭喜你來到最后一步堕澄,是時候驗證一下你訓練的模型的效果了邀跃。寫個簡單的模型驗證文件 evaluate.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr  2 14:02:05 2018

@author: shirhe-lyh
"""

"""Evaluate the trained CNN model.

Example Usage:
---------------
python3 evaluate.py \

    --frozen_graph_path: Path to model frozen graph.
"""

import numpy as np
import tensorflow as tf

from captcha.image import ImageCaptcha

flags = tf.app.flags
flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
FLAGS = flags.FLAGS


def generate_captcha(text='1'):
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


def main(_):
    alphabets = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J',
                 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 
                 'U', 'V', 'W', 'X', 'Y', 'Z']
    
    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    
    with model_graph.as_default():
        with tf.Session(graph=model_graph) as sess:
            inputs = model_graph.get_tensor_by_name('image_tensor:0')
            digits = model_graph.get_tensor_by_name('digits:0')
            digit_classes = tf.argmax(tf.nn.softmax(digits), axis=1)
            letters = model_graph.get_tensor_by_name('letters:0')
            letter_classes = tf.argmax(tf.nn.softmax(letters), axis=1)
            for i in range(10):
                label = np.random.randint(0, 34)
                image = generate_captcha(alphabets[label])
                image_np = np.expand_dims(image, axis=0)
                predicted_ = sess.run([digits, digit_classes,
                                       letters, letter_classes], 
                                           feed_dict={inputs: image_np})
                predicted_digits = np.round(predicted_[0], 2)
                predicted_digit_classes = predicted_[1]
                predicted_letters = np.round(predicted_[2], 2)
                predicted_letter_classes = predicted_[3]
                print(predicted_digits, '----', predicted_digit_classes)
                print(predicted_letters, '----', predicted_letter_classes)
                predicted_label = predicted_letter_classes[0] + 10
                if label < 10:
                    predicted_label = predicted_digit_classes[0]
                print(alphabets[predicted_label], ' vs ', alphabets[label])
            
            
if __name__ == '__main__':
    tf.app.run()

在終端執(zhí)行如下命令,進行模型評估:

python3 evaluate.py \
    --frozen_graph_path ./training/inference_graph_pb/frozen_inference_graph.pb

你可以仔細的觀察最后兩個分支的直接輸出蛙紫,看看最大值對應的那個任務是否恰好是驗證圖像實際來源的任務拍屑。

預告:下一篇文章將要介紹如何用 TensorFlow 實現(xiàn) 生成對抗網(wǎng)絡,敬請期待坑傅!

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
  • 序言:七十年代末丽涩,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子裁蚁,更是在濱河造成了極大的恐慌矢渊,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,546評論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件枉证,死亡現(xiàn)場離奇詭異矮男,居然都是意外死亡,警方通過查閱死者的電腦和手機室谚,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,224評論 3 395
  • 文/潘曉璐 我一進店門毡鉴,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人秒赤,你說我怎么就攤上這事猪瞬。” “怎么了入篮?”我有些...
    開封第一講書人閱讀 164,911評論 0 354
  • 文/不壞的土叔 我叫張陵陈瘦,是天一觀的道長。 經(jīng)常有香客問我潮售,道長痊项,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,737評論 1 294
  • 正文 為了忘掉前任酥诽,我火速辦了婚禮鞍泉,結果婚禮上,老公的妹妹穿的比我還像新娘肮帐。我一直安慰自己咖驮,他們只是感情好,可當我...
    茶點故事閱讀 67,753評論 6 392
  • 文/花漫 我一把揭開白布训枢。 她就那樣靜靜地躺著托修,像睡著了一般。 火紅的嫁衣襯著肌膚如雪肮砾。 梳的紋絲不亂的頭發(fā)上诀黍,一...
    開封第一講書人閱讀 51,598評論 1 305
  • 那天,我揣著相機與錄音仗处,去河邊找鬼眯勾。 笑死枣宫,一個胖子當著我的面吹牛,可吹牛的內容都是我干的吃环。 我是一名探鬼主播也颤,決...
    沈念sama閱讀 40,338評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼郁轻!你這毒婦竟也來了翅娶?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 39,249評論 0 276
  • 序言:老撾萬榮一對情侶失蹤好唯,失蹤者是張志新(化名)和其女友劉穎竭沫,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體骑篙,經(jīng)...
    沈念sama閱讀 45,696評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡蜕提,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 37,888評論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了靶端。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片谎势。...
    茶點故事閱讀 40,013評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖杨名,靈堂內的尸體忽然破棺而出脏榆,到底是詐尸還是另有隱情,我是刑警寧澤台谍,帶...
    沈念sama閱讀 35,731評論 5 346
  • 正文 年R本政府宣布须喂,位于F島的核電站,受9級特大地震影響典唇,放射性物質發(fā)生泄漏镊折。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,348評論 3 330
  • 文/蒙蒙 一介衔、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧骂因,春花似錦炎咖、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,929評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至俄烁,卻和暖如春绸栅,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背页屠。 一陣腳步聲響...
    開封第一講書人閱讀 33,048評論 1 270
  • 我被黑心中介騙來泰國打工粹胯, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留蓖柔,地道東北人。 一個月前我還...
    沈念sama閱讀 48,203評論 3 370
  • 正文 我出身青樓风纠,卻偏偏與公主長得像况鸣,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子竹观,可洞房花燭夜當晚...
    茶點故事閱讀 44,960評論 2 355

推薦閱讀更多精彩內容