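Having learned how to classify images with a CNN in TensorFlow, we now extend these methods to the multi-task, multi-label case. As a running example, suppose we want to recognize the 10 digits 0-9 and the 24 letters A-Z (I and O excluded). All data are generated with captcha, which readers of the article "Training a CNN classifier with TensorFlow" will already know. The following code (named generate_train_data.py) uses captcha to generate 100,000 images of size 28 x 28. Each image shows a single character with heavy noise. All characters are listed in the alphabets list in the code below, and all images are saved in the folder ./datasets/images. Each image is named image<image index>_<class label>.jpg, where the class label is the character's index in the alphabets list.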
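```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 22 13:43:34 2018
@author: shirhe-lyh
"""
import os

import cv2
import numpy as np
from captcha.image import ImageCaptcha


def generate_captcha(text='1'):
    """Render `text` as a noisy 28 x 28 RGB captcha image (uint8 array)."""
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


if __name__ == '__main__':
    output_dir = './datasets/images/'
    # cv2.imwrite returns False instead of raising if the folder is missing.
    os.makedirs(output_dir, exist_ok=True)
    # 10 digits followed by 24 letters (I and O excluded): 34 classes.
    alphabets = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J',
                 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T',
                 'U', 'V', 'W', 'X', 'Y', 'Z']
    for i in range(100000):
        label = np.random.randint(0, 34)  # upper bound is exclusive: 0..33
        image = generate_captcha(alphabets[label])
        image_name = 'image{}_{}.jpg'.format(i + 1, label)
        output_path = os.path.join(output_dir, image_name)
        # PIL produces RGB while OpenCV expects BGR, so convert before saving.
        cv2.imwrite(output_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
```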
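Our goal is to train a simple CNN to classify these images. The problem is easy enough that a single 34-class classifier would do the job. However, training generally gets harder as the number of classes grows, so we take a divide-and-conquer approach and split the 34-class problem into two sub-problems: 1. recognize digits only; 2. recognize letters only. The split is justified because digits and letters look very different, so they can be treated as two separate categories and handled as independent classification tasks. The question then becomes: how do we recognize the 10 digits and the 24 letters at the same time? This is a multi-task, multi-label problem: there are two tasks (digit recognition and letter recognition), and each involves multiple labels (10 and 24 labels respectively).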
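To make the unified labeling concrete, here is a small plain-Python sketch. The helper names `parse_image_name` and `split_label` are hypothetical and do not come from the multi_task_test repository. The sketch assumes the `image<index>_<label>.jpg` naming scheme and the `alphabets` order from generate_train_data.py, where the 10 digits come before the 24 letters.

```python
import re

# Same characters and order as `alphabets` in generate_train_data.py.
ALPHABETS = list('0123456789') + list('ABCDEFGHJKLMNPQRSTUVWXYZ')
NUM_CLASSES = {'digits': 10, 'letters': 24}
CLASS_ORDER = ['digits', 'letters']


def parse_image_name(filename):
    """Hypothetical helper: 'image1_16.jpg' -> (1, 16)."""
    match = re.fullmatch(r'image(\d+)_(\d+)\.jpg', filename)
    if match is None:
        raise ValueError('unexpected file name: ' + filename)
    return int(match.group(1)), int(match.group(2))


def split_label(label):
    """Map a unified label (0-33) to (task name, label within that task)."""
    start = 0
    for task in CLASS_ORDER:
        if label < start + NUM_CLASSES[task]:
            return task, label - start
        start += NUM_CLASSES[task]
    raise ValueError('label out of range: %d' % label)


index, label = parse_image_name('image1_16.jpg')
print(index, label, ALPHABETS[label], split_label(label))
```

For example, the file image1_16.jpg holds the letter G. It has label 16 in the unified 34-class space and label 6 within the letter task.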
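The example in this article is very simple, but the same method (combined with pre-trained models) also works for much harder problems, such as Alibaba's FashionAI Global Challenge on apparel attribute recognition. Interested readers can fine-tune an 8-task model starting from a pre-trained ResNet-50.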
All code for this article is on GitHub: multi_task_test. Feel free to visit, discuss, and report issues!
1. Multi-Branch CNN Model Definition
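Although the two tasks are independent, we want them to share most of the network layers. Sharing saves computation and usually improves accuracy as well. The network is therefore designed as follows. The shared layers are the ones used to recognize the 10 digits 0-9 in the article "Training a CNN classification model with TensorFlow-slim".

[Figure: structure of the multi-branch CNN]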
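An input image (digit or letter) passes through the first convolutional block (conv1), the second convolutional block (conv2), and so on, up to the second fully connected layer (fc2). All of these layers are shared by the two tasks, and their job is to extract image features. After fc2 the network splits into two branches, one per task. One branch outputs a score (logit) for each digit (digits_output), and the other outputs a score for each letter (letters_output). Applying softmax to either branch turns its scores into probabilities. The network is defined as follows (layer names may differ from those in the figure above):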
```python
def predict(self, preprocessed_inputs):
    """Predict prediction tensors from inputs tensor.

    Outputs of this function can be passed to loss or postprocess functions.

    Args:
        preprocessed_inputs: A float32 tensor with shape [batch_size,
            height, width, num_channels] representing a batch of images.

    Returns:
        prediction_dict: A dictionary holding prediction tensors to be
            passed to the Loss or Postprocess functions.
    """
    # Shared layers (conv1 ... fc2).
    net = preprocessed_inputs
    net = slim.repeat(net, 2, slim.conv2d, 32, [3, 3], scope='conv1')
    net = slim.max_pool2d(net, [2, 2], scope='pool1')
    net = slim.repeat(net, 2, slim.conv2d, 64, [3, 3], scope='conv2')
    net = slim.max_pool2d(net, [2, 2], scope='pool2')
    net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv3')
    net = slim.flatten(net, scope='flatten')
    net = slim.dropout(net, keep_prob=0.5,
                       is_training=self._is_training)
    net = slim.fully_connected(net, 512, scope='fc1')
    net = slim.fully_connected(net, 512, scope='fc2')
    # One linear output branch per task.
    prediction_dict = {}
    for class_name, num_classes in self.num_classes_dict.items():
        logits = slim.fully_connected(net, num_outputs=num_classes,
                                      activation_fn=None,
                                      scope='Predict/' + class_name)
        prediction_dict[class_name] = logits
    return prediction_dict
```
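As a sanity check on layer sizes, the sketch below traces each layer's output shape for a 28 x 28 input. It relies on the TF-slim defaults that the code above uses implicitly: `slim.conv2d` has stride 1 and `'SAME'` padding, so spatial size is unchanged, and `slim.max_pool2d` has stride 2 and `'VALID'` padding. The function names are illustrative only.

```python
def pool_out(size, kernel=2, stride=2):
    """Spatial size after a 'VALID' max pool: floor((size - kernel) / stride) + 1."""
    return (size - kernel) // stride + 1


def trace_shapes(size=28, num_classes_dict=None):
    """Return (layer name, output shape) pairs, batch dimension omitted."""
    if num_classes_dict is None:
        num_classes_dict = {'digits': 10, 'letters': 24}
    shapes = [('conv1', (size, size, 32))]      # 'SAME' convs keep the size
    size = pool_out(size)
    shapes.append(('pool1', (size, size, 32)))
    shapes.append(('conv2', (size, size, 64)))
    size = pool_out(size)
    shapes.append(('pool2', (size, size, 64)))
    shapes.append(('conv3', (size, size, 128)))
    shapes.append(('flatten', (size * size * 128,)))
    shapes.append(('fc1', (512,)))
    shapes.append(('fc2', (512,)))
    for name, n in num_classes_dict.items():    # one output branch per task
        shapes.append(('Predict/' + name, (n,)))
    return shapes


for layer, shape in trace_shapes():
    print(layer, shape)
```

The flattened feature vector therefore has 7 * 7 * 128 = 6272 entries, and the two branches output 10 and 24 values respectively.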
As the code shows, defining a multi-task, multi-label CNN is simple: a single for loop over the tasks creates one output branch per task. Next we define the loss function and the accuracy function.
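Here is the same shared-trunk idea as a NumPy sketch with TensorFlow removed. Each image has one shared feature vector, and a loop over `num_classes_dict` gives every task its own linear head. The random weights and the names `make_heads` / `multi_head_logits` are hypothetical stand-ins for the trained `Predict/digits` and `Predict/letters` layers.

```python
import numpy as np


def make_heads(feature_dim, num_classes_dict, seed=0):
    """One (weights, bias) pair per task, like the 'Predict/<task>' layers."""
    rng = np.random.default_rng(seed)
    return {task: (rng.normal(scale=0.01, size=(feature_dim, n)), np.zeros(n))
            for task, n in num_classes_dict.items()}


def multi_head_logits(features, heads):
    """features: [batch, feature_dim] shared features (e.g. the fc2 output)."""
    # No activation: each head returns raw logits, as with activation_fn=None.
    return {task: features @ w + b for task, (w, b) in heads.items()}


heads = make_heads(512, {'digits': 10, 'letters': 24})
features = np.ones((4, 512))          # stand-in for a batch of fc2 outputs
logits = multi_head_logits(features, heads)
print({task: out.shape for task, out in logits.items()})
```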
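When the images were generated, each file was named image<image index>_<class label>.jpg. For example, if the first image is the letter G, its class label is 16 = 10 + 7 - 1 (the 10 digits come first, and G is the 7th letter), so the file is named image1_16.jpg. This label 16 is relative to all 34 classes. Within the letters alone, G's label would be 6. Digits and letters share one unified label space precisely so that the loss and accuracy functions are easy to define. Because we treat digits and letters as two independent branch tasks, the letter G should contribute loss only to the letter branch and none to the digit branch. With unified labels, the one-hot encoding of G's label 16 is: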
0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
The - only marks the boundary between the two tasks and is not part of the encoding. When computing the loss, this one-hot encoding is split in two:
0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
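The first part is G's (not strictly one-hot) encoding within the digit task 0-9. It is all zeros, so its categorical cross-entropy is 0 (the cross-entropy -Σ_k y_k log p_k vanishes when every y_k is 0), which is what we want. The second part is exactly G's one-hot encoding within the letter task (A-Z excluding I and O), so it can be used directly to compute the categorical cross-entropy, which is also what we want. Unified labels therefore make the loss computation very convenient. With this in mind, the loss function is defined as follows: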
```python
def loss(self, prediction_dict, groundtruth_lists):
    """Compute scalar loss tensors with respect to provided groundtruth.

    Args:
        prediction_dict: A dictionary holding prediction tensors.
        groundtruth_lists: An int tensor of shape [batch_size] holding the
            unified class labels (0-33) of the images in the batch.

    Returns:
        A dictionary mapping strings (loss names) to scalar tensors
        representing loss values.
    """
    onehot_labels_dict = self._onehot_groundtruth_dict(groundtruth_lists)
    for class_name in self.num_classes_dict:
        # Weight 1 for samples of this task, 0 for samples whose one-hot
        # slice is all zeros (they belong to the other task).
        weights = tf.cast(tf.greater(
            tf.reduce_sum(onehot_labels_dict[class_name], axis=1), 0),
            dtype=tf.float32)
        slim.losses.softmax_cross_entropy(
            logits=prediction_dict[class_name],
            onehot_labels=onehot_labels_dict[class_name],
            weights=weights,
            scope='Loss/' + class_name)
    loss = slim.losses.get_total_loss()
    loss_dict = {'loss': loss}
    return loss_dict


def _onehot_groundtruth_dict(self, groundtruth_lists):
    """Transform unified groundtruth labels to per-task one-hot labels.

    Args:
        groundtruth_lists: An int tensor of shape [batch_size] holding the
            unified class labels of the images in the batch.

    Returns:
        onehot_labels_dict: A dictionary mapping strings (class names)
            to one-hot label tensors.
    """
    one_hot = tf.one_hot(
        groundtruth_lists, depth=sum(self.num_classes_dict.values()))
    onehot_labels_dict = {}
    start_index = 0
    for class_name in self._class_order:
        # Columns [start_index, start_index + num_classes) belong to this task.
        onehot_labels_dict[class_name] = tf.slice(
            one_hot, [0, start_index],
            [-1, self.num_classes_dict[class_name]])
        start_index += self.num_classes_dict[class_name]
    return onehot_labels_dict
```
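To see what `_onehot_groundtruth_dict` produces, here is a NumPy mirror of it. This is a sketch, not the repository code. `np.eye(34)[labels]` plays the role of `tf.one_hot`, and column slicing plays the role of `tf.slice`. The sketch assumes the same task order as `self._class_order`, with digits first and letters second.

```python
import numpy as np

NUM_CLASSES = {'digits': 10, 'letters': 24}
CLASS_ORDER = ['digits', 'letters']


def onehot_split(labels):
    """Split one-hot encodings of unified labels (0-33) into per-task parts."""
    total = sum(NUM_CLASSES.values())
    one_hot = np.eye(total, dtype=np.float32)[labels]   # [batch, 34]
    parts, start = {}, 0
    for task in CLASS_ORDER:
        n = NUM_CLASSES[task]
        # Same columns as tf.slice(one_hot, [0, start], [-1, n]).
        parts[task] = one_hot[:, start:start + n]
        start += n
    return parts


parts = onehot_split(np.array([16, 3]))   # the letter G, then the digit 3
print(parts['digits'][0], parts['letters'][0])
```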
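The function _onehot_groundtruth_dict splits the one-hot encoding of the unified label into two one-hot encodings, one for the digit task and one for the letter task. After that, loss computes ordinary categorical cross-entropy for each task. The weights variable guarantees that an all-zero encoding contributes zero loss. When a sample's encoding for a task is all zeros, its loss weight for that task is 0. In other words, samples that do not belong to a task contribute nothing to that task's loss. In theory the cross-entropy of an all-zero encoding is already 0, so weights mainly rules out surprises. It also affects averaging. By default the weighted loss is divided by the number of non-zero weights, so each task's loss is the mean over that task's own samples rather than being diluted by samples of the other task.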
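The effect of `weights` can be checked numerically. The NumPy sketch below computes per-sample softmax cross-entropy and then averages only over samples with non-zero weight. That averaging rule is an assumption modeled on the default reduction of TF 1.x losses (`tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS`). The function names are hypothetical.

```python
import numpy as np


def softmax_cross_entropy(logits, onehot):
    """Per-sample cross-entropy -sum_k y_k * log softmax(logits)_k."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(onehot * log_probs).sum(axis=1)


def task_loss(logits, onehot):
    """Weighted mean loss of one task; all-zero label rows get weight 0."""
    weights = (onehot.sum(axis=1) > 0).astype(np.float64)
    weighted = softmax_cross_entropy(logits, onehot) * weights
    num_nonzero = np.count_nonzero(weights)
    return weighted.sum() / num_nonzero if num_nonzero else 0.0


logits = np.zeros((3, 24))        # uniform predictions over the 24 letters
onehot = np.zeros((3, 24))
onehot[0, 6] = 1.0                # letter G
onehot[2, 0] = 1.0                # letter A; row 1 is a digit image (all zeros)
print(softmax_cross_entropy(logits, onehot))
print(task_loss(logits, onehot), np.log(24))
```

The digit image (row 1) has zero loss, and the task loss is log 24, the average over the two letter images only. A plain mean over all three rows would give 2/3 · log 24 instead.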
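The accuracy function is even simpler. An image passed through the network yields outputs for both branch tasks, and we do not need to know which task the image came from, because that does not affect the computation. For each branch, tf.argmax of the output gives the predicted label within that task, and that label is converted into a one-hot encoding within the same task. Each predicted one-hot encoding is then added element-wise to the matching ground-truth half obtained by the split used in the loss computation. A 2 in a sum means the prediction is correct; otherwise it is wrong. Accumulating over all images in a batch gives the accuracy. Continuing the example above, splitting the one-hot encoding of G's label 16 gives: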
0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
Suppose the two branches predict the labels 1 and 6 respectively. Their one-hot encodings are:
0 1 0 0 0 0 0 0 0 0
0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
Adding these encodings element-wise gives:
0 1 0 0 0 0 0 0 0 0
0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
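The first result (0 1 0 0 0 0 0 0 0 0) contains no 2, so the predicted and actual labels do not coincide and the accuracy is unaffected. The second result (0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0) has a 2 at index 6, so the predicted and actual labels agree and the correct count increases by 1. Each image adds either 0 (the branch of the task it belongs to predicts wrongly) or 1 (that branch predicts correctly). It can never add 2, because one of its two ground-truth encodings is all zeros. The accuracy computed this way is therefore correct. Note that it only checks the branch of the image's true task, so it measures accuracy given the task. For the full details, see the complete code below (named model.py):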
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 16:54:02 2018
@author: shirhe-lyh
"""
from abc import ABCMeta
from abc import abstractmethod

import tensorflow as tf

slim = tf.contrib.slim


class BaseModel(metaclass=ABCMeta):
    """Abstract base class for any model."""

    def __init__(self, num_classes_dict):
        """Constructor.

        Args:
            num_classes_dict: A dict mapping task names to numbers of classes.
        """
        self._num_classes_dict = num_classes_dict

    @property
    def num_classes_dict(self):
        return self._num_classes_dict

    @abstractmethod
    def preprocess(self, inputs):
        """Input preprocessing. To be overridden by implementations.

        Args:
            inputs: A float32 tensor with shape [batch_size, height, width,
                num_channels] representing a batch of images.

        Returns:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
        """
        pass

    @abstractmethod
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.

        Outputs of this function can be passed to loss or postprocess functions.

        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.

        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        pass

    @abstractmethod
    def postprocess(self, prediction_dict, **params):
        """Convert predicted output tensors to final forms.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
            **params: Additional keyword arguments for specific implementations
                of specified models.

        Returns:
            A dictionary containing the postprocessed results.
        """
        pass

    @abstractmethod
    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: An int tensor of shape [batch_size] holding the
                groundtruth labels of the images in the batch.

        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
            representing loss values.
        """
        pass


class Model(BaseModel):
    """Multi-task CNN: shared layers plus one classification branch per task."""

    def __init__(self, is_training, num_classes_dict=None):
        """Constructor.

        Args:
            is_training: A boolean indicating whether the training version of
                computation graph should be constructed.
            num_classes_dict: A dict mapping task names to numbers of classes,
                {'digits': 10, 'letters': 24} by default.
        """
        if num_classes_dict is None:
            num_classes_dict = {'digits': 10, 'letters': 24}
        super(Model, self).__init__(num_classes_dict=num_classes_dict)
        self._is_training = is_training
        # Task order in the unified label space: digits 0-9, letters 10-33.
        self._class_order = ['digits', 'letters']

    def preprocess(self, inputs):
        """Scale pixel values from [0, 255] to [-1, 1).

        Args:
            inputs: A tensor with shape [batch_size, height, width,
                num_channels] representing a batch of images.

        Returns:
            preprocessed_inputs: A float32 tensor with the same shape.
        """
        preprocessed_inputs = tf.to_float(inputs)
        preprocessed_inputs = tf.subtract(preprocessed_inputs, 128.0)
        preprocessed_inputs = tf.div(preprocessed_inputs, 128.0)
        return preprocessed_inputs

    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.

        Outputs of this function can be passed to loss or postprocess functions.

        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.

        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        # Shared layers (conv1 ... fc2).
        net = preprocessed_inputs
        net = slim.repeat(net, 2, slim.conv2d, 32, [3, 3], scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')
        net = slim.repeat(net, 2, slim.conv2d, 64, [3, 3], scope='conv2')
        net = slim.max_pool2d(net, [2, 2], scope='pool2')
        net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv3')
        net = slim.flatten(net, scope='flatten')
        net = slim.dropout(net, keep_prob=0.5,
                           is_training=self._is_training)
        net = slim.fully_connected(net, 512, scope='fc1')
        net = slim.fully_connected(net, 512, scope='fc2')
        # One linear output branch per task.
        prediction_dict = {}
        for class_name, num_classes in self.num_classes_dict.items():
            logits = slim.fully_connected(net, num_outputs=num_classes,
                                          activation_fn=None,
                                          scope='Predict/' + class_name)
            prediction_dict[class_name] = logits
        return prediction_dict

    def postprocess(self, prediction_dict):
        """Convert predicted output tensors to final forms.

        Args:
            prediction_dict: A dictionary holding prediction tensors.

        Returns:
            A dictionary containing the postprocessed results.
        """
        postprocessed_dict = {}
        for class_name in self.num_classes_dict:
            logits = prediction_dict[class_name]
            # logits = tf.nn.softmax(logits, name=class_name)
            postprocessed_dict[class_name] = logits
        return postprocessed_dict

    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: An int tensor of shape [batch_size] holding the
                unified class labels (0-33) of the images in the batch.

        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
            representing loss values.
        """
        onehot_labels_dict = self._onehot_groundtruth_dict(groundtruth_lists)
        for class_name in self.num_classes_dict:
            # Weight 0 for samples that belong to the other task.
            weights = tf.cast(tf.greater(
                tf.reduce_sum(onehot_labels_dict[class_name], axis=1), 0),
                dtype=tf.float32)
            slim.losses.softmax_cross_entropy(
                logits=prediction_dict[class_name],
                onehot_labels=onehot_labels_dict[class_name],
                weights=weights,
                scope='Loss/' + class_name)
        loss = slim.losses.get_total_loss()
        loss_dict = {'loss': loss}
        return loss_dict

    def _onehot_groundtruth_dict(self, groundtruth_lists):
        """Transform unified groundtruth labels to per-task one-hot labels.

        Args:
            groundtruth_lists: An int tensor of shape [batch_size] holding the
                unified class labels of the images in the batch.

        Returns:
            onehot_labels_dict: A dictionary mapping strings (class names)
                to one-hot label tensors.
        """
        one_hot = tf.one_hot(
            groundtruth_lists, depth=sum(self.num_classes_dict.values()))
        onehot_labels_dict = {}
        start_index = 0
        for class_name in self._class_order:
            onehot_labels_dict[class_name] = tf.slice(
                one_hot, [0, start_index],
                [-1, self.num_classes_dict[class_name]])
            start_index += self.num_classes_dict[class_name]
        return onehot_labels_dict

    def accuracy(self, postprocessed_dict, groundtruth_lists):
        """Calculate accuracy.

        Args:
            postprocessed_dict: A dictionary containing the postprocessed
                results.
            groundtruth_lists: An int tensor of shape [batch_size] holding the
                unified class labels of the images in the batch.

        Returns:
            accuracy: The scalar accuracy.
        """
        onehot_labels_dict = self._onehot_groundtruth_dict(groundtruth_lists)
        num_corrections = 0.
        for class_name in self.num_classes_dict:
            predicted_argmax = tf.argmax(tf.nn.softmax(
                postprocessed_dict[class_name]), axis=1)
            onehot_predicted = tf.one_hot(
                predicted_argmax, depth=self.num_classes_dict[class_name])
            # An entry equals 2 only where prediction and groundtruth agree.
            onehot_sum = tf.add(onehot_labels_dict[class_name],
                                onehot_predicted)
            correct = tf.greater(onehot_sum, 1)
            num = tf.reduce_sum(tf.cast(correct, tf.float32))
            num_corrections += num
        total_nums = tf.cast(tf.shape(groundtruth_lists)[0], dtype=tf.float32)
        accuracy = num_corrections / total_nums
        return accuracy
```
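The one-hot-sum rule in `accuracy()` can be reproduced in NumPy to check the worked example with G. This is an illustrative sketch with a hypothetical function name that hardcodes the 10 + 24 split. Counting entries greater than 1 amounts to checking the argmax of the branch that matches each image's true task.

```python
import numpy as np


def onehot_sum_accuracy(labels, digit_logits, letter_logits):
    """Mirror of Model.accuracy for the 10-digit / 24-letter split."""
    one_hot = np.eye(34)[labels]
    truth = {'digits': one_hot[:, :10], 'letters': one_hot[:, 10:]}
    logits = {'digits': digit_logits, 'letters': letter_logits}
    num_correct = 0
    for task, task_logits in logits.items():
        n = task_logits.shape[1]
        predicted = np.eye(n)[task_logits.argmax(axis=1)]   # one-hot of argmax
        # A 2 appears only where the prediction hits this task's true label.
        num_correct += np.count_nonzero(truth[task] + predicted > 1)
    return num_correct / len(labels)


labels = np.array([16, 1])                    # the letter G and the digit 1
digit_logits = np.eye(10)[[1, 1]]             # digit branch predicts 1 for both
letter_logits = np.eye(24)[[6, 0]]            # letter branch predicts G, then A
print(onehot_sum_accuracy(labels, digit_logits, letter_logits))
```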
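In postprocess I commented out the statement

```python
logits = tf.nn.softmax(logits, name=class_name)
```

which leaves the function doing nothing. I did this on purpose, to inspect the raw outputs of the two branches of predict. Consider an image fed into the network without our knowing whether it is a digit or a letter. The network produces outputs for both tasks, so how do we decide which label of which task the image belongs to? If we already know the task, say the digit task, we just take tf.argmax of the digit branch's output to get the predicted label. The key question is whether, without knowing the task, we can infer it directly from the two branches' outputs. The answer is yes, although only on empirical grounds. Train the model, export it as a .pb file, and run evaluate.py (many times) to observe the raw outputs of the two branches. You will find that the label with the largest output across both tasks is the network's prediction. In other words, comparing all the outputs of both tasks lets us predict which task the image comes from, and hence its label: the task containing the largest output value can be taken as the image's source task.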
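The empirical rule above (take the maximum over the raw outputs of both branches) amounts to concatenating the two logit vectors in the unified label order and taking a single argmax. The sketch below is a hypothetical NumPy helper, not part of the original code. The two heads are trained separately and are never explicitly calibrated against each other, so treat this rule as a heuristic rather than a guarantee.

```python
import numpy as np

NUM_DIGITS = 10  # digits occupy unified labels 0-9, letters 10-33


def joint_predict(digit_logits, letter_logits):
    """Unified labels from the largest raw output across both branches."""
    combined = np.concatenate([digit_logits, letter_logits], axis=1)  # [batch, 34]
    return combined.argmax(axis=1)


def task_of(unified_labels):
    """'digits' for labels below 10, 'letters' otherwise."""
    return np.where(np.asarray(unified_labels) < NUM_DIGITS, 'digits', 'letters')


digit_logits = np.zeros((2, 10))
letter_logits = np.zeros((2, 24))
digit_logits[0, 3], letter_logits[0, 5] = 4.0, 2.0   # digit branch is larger
digit_logits[1, 1], letter_logits[1, 6] = 1.0, 3.0   # letter branch is larger
unified = joint_predict(digit_logits, letter_logits)
print(unified, task_of(unified))
```

Because the digits come first in the concatenation, the result can be used directly as an index into the `alphabets` list.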
2. Model Training and Saving
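The training code is the same as train.py in the article "Training a CNN classification model with TensorFlow-slim (continued)", so it is omitted here. All code for this article is also available on GitHub: multi_task_test.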
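Once you have the code, first create the folder datasets/images under the project directory, then run the following in a terminal in that directory:

```shell
python3 generate_train_data.py
```

This generates the 100,000 training images. Next, run

```shell
python3 generate_tfrecord.py \
    --images_path ./datasets/images/ \
    --output_path ./datasets/train.record
```

to produce the training .record file. Then create a folder named training under the project directory and run the following command in the terminal:

```shell
python3 train.py --record_path ./datasets/train.record --logdir ./training/
```

Training now starts. To watch the loss and accuracy curves, run the following in a terminal in the current directory:

```shell
tensorboard --logdir ./training/
```

Open the local link it prints in a browser to monitor the whole training process. For example, after I trained for a little over 5,000 steps, the accuracy and loss curves looked like this:

[Figure: TensorBoard accuracy and loss curves after about 5,000 training steps]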
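Once the training accuracy is high enough and the folder training contains a saved checkpoint for the current step, press Ctrl + C to stop training. The next step is to convert the .ckpt model files in training into a .pb file and then test the trained model. For custom conversion from .ckpt to .pb, see the article "Exporting custom TensorFlow models: converting .ckpt to .pb". That article points out that, for a different classification model, you mainly need to adapt the functions that take a model argument, especially _add_output_tensor_nodes, which creates the output nodes. Our model has two branch outputs, so _add_output_tensor_nodes becomes: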
```python
import tensorflow as tf


def _add_output_tensor_nodes(postprocessed_tensors,
                             output_collection_name='inference_op'):
    """Adds output nodes.

    Adjust according to specified implementations.

    Adds one named output node per task:
      * digits: A float32 tensor of shape [batch_size, 10] holding the raw
        outputs (logits) of the digit branch.
      * letters: A float32 tensor of shape [batch_size, 24] holding the raw
        outputs (logits) of the letter branch.

    Args:
        postprocessed_tensors: A dictionary mapping task names to their
            postprocessed output tensors.
        output_collection_name: Name of collection to add output tensors to.

    Returns:
        A tensor dict containing the added output tensor nodes.
    """
    outputs = {}
    for class_name, logits in postprocessed_tensors.items():
        # Naming each node after its task lets evaluate.py fetch 'digits:0'.
        outputs[class_name] = tf.identity(logits, name=class_name)
    for output_key in outputs:
        tf.add_to_collection(output_collection_name, outputs[output_key])
    return outputs
```
No other function needs to change. See export.py in github: multi_task_test for the complete file. Then run the model export command in a terminal in the project directory:

```shell
python3 export_inference_graph.py \
    --trained_checkpoint_prefix ./training/model.ckpt-5265 \
    --output_directory ./training/inference_graph_pb
```
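A new folder, inference_graph_pb, appears inside training. The file frozen_inference_graph.pb in it is what we use for inference. Replace model.ckpt-5265 in the command above to match your own run: I trained for only a little over 5,000 steps and used the checkpoint from step 5265 for inference.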
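If everything has gone smoothly, congratulations, you have reached the last step: it is time to check how well your trained model works. Write a simple evaluation script, evaluate.py: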
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Evaluate the trained CNN model.

Created on Mon Apr 2 14:02:05 2018
@author: shirhe-lyh

Example Usage:
---------------
python3 evaluate.py \
    --frozen_graph_path path/to/frozen_inference_graph.pb
"""
import numpy as np
import tensorflow as tf
from captcha.image import ImageCaptcha

flags = tf.app.flags
flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
FLAGS = flags.FLAGS


def generate_captcha(text='1'):
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


def main(_):
    alphabets = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J',
                 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T',
                 'U', 'V', 'W', 'X', 'Y', 'Z']

    # Load the frozen graph exported by export_inference_graph.py.
    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    with model_graph.as_default():
        with tf.Session(graph=model_graph) as sess:
            inputs = model_graph.get_tensor_by_name('image_tensor:0')
            # Raw branch outputs, named in _add_output_tensor_nodes.
            digits = model_graph.get_tensor_by_name('digits:0')
            digit_classes = tf.argmax(tf.nn.softmax(digits), axis=1)
            letters = model_graph.get_tensor_by_name('letters:0')
            letter_classes = tf.argmax(tf.nn.softmax(letters), axis=1)
            for i in range(10):
                label = np.random.randint(0, 34)
                image = generate_captcha(alphabets[label])
                image_np = np.expand_dims(image, axis=0)  # add batch dimension
                predicted_ = sess.run([digits, digit_classes,
                                       letters, letter_classes],
                                      feed_dict={inputs: image_np})
                predicted_digits = np.round(predicted_[0], 2)
                predicted_digit_classes = predicted_[1]
                predicted_letters = np.round(predicted_[2], 2)
                predicted_letter_classes = predicted_[3]
                print(predicted_digits, '----', predicted_digit_classes)
                print(predicted_letters, '----', predicted_letter_classes)
                # The true label decides which branch's prediction to read.
                predicted_label = predicted_letter_classes[0] + 10
                if label < 10:
                    predicted_label = predicted_digit_classes[0]
                print(alphabets[predicted_label], ' vs ', alphabets[label])


if __name__ == '__main__':
    tf.app.run()
```
Run the following command in a terminal to evaluate the model:

```shell
python3 evaluate.py \
    --frozen_graph_path ./training/inference_graph_pb/frozen_inference_graph.pb
```
Look closely at the raw outputs of the two branches and check whether the task holding the largest value is the task the test image actually comes from.
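Instead of eyeballing the printed outputs, you could collect the raw `digits` and `letters` outputs over many test images and measure how often the branch holding the larger maximum matches the true task. For example, you could stack the arrays that `sess.run` returns in evaluate.py. The helper below is a hypothetical NumPy sketch of that measurement. It counts ties as digits.

```python
import numpy as np


def task_agreement_rate(digit_logits, letter_logits, labels, num_digits=10):
    """Fraction of images whose max-output branch matches their true task."""
    # Ties are counted as 'digits'.
    predicted_is_digit = digit_logits.max(axis=1) >= letter_logits.max(axis=1)
    actual_is_digit = np.asarray(labels) < num_digits
    return float(np.mean(predicted_is_digit == actual_is_digit))


digit_logits = np.zeros((3, 10))
letter_logits = np.zeros((3, 24))
digit_logits[0, 2], letter_logits[0, 1] = 5.0, 1.0   # predicted digit, true '2'
digit_logits[1, 0], letter_logits[1, 3] = 1.0, 4.0   # predicted letter, true 'D'
digit_logits[2, 4], letter_logits[2, 0] = 3.0, 2.0   # predicted digit, true 'L'
labels = [2, 13, 20]
print(task_agreement_rate(digit_logits, letter_logits, labels))
```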
Coming next: the next article will show how to implement a generative adversarial network (GAN) with TensorFlow. Stay tuned!