深入理解TensorFlow中的tf.metrics算子

前言

本文將深入介紹Tensorflow內(nèi)置的評估指標(biāo)算子塘娶。

  • tf.metrics.accuracy()

  • tf.metrics.precision()

  • tf.metrics.recall()

  • tf.metrics.mean_iou()

簡單起見葵姥,本文在示例中使用tf.metrics.accuracy()溯街,但它的模式以及它背后的原理將適用于所有評估指標(biāo)切蟋。如果您只想看到有關(guān)如何使用tf.metrics的示例代碼忱辅,請?zhí)D(zhuǎn)到5.1和5.2節(jié)焰宣,如果您想要了解為何使用這種方式霉囚,請繼續(xù)閱讀。

這篇文章將通過一個非常簡單的代碼示例來理解tf.metrics的原理匕积,這里使用Numpy創(chuàng)建自己的評估指標(biāo)盈罐。這將有助于對Tensorflow中的評估指標(biāo)如何工作有一個很好的直覺認(rèn)識。然后闪唆,我們將給出如何采用tf.metrics快速實現(xiàn)同樣的功能盅粪。但首先,我先講述一下寫下這篇博客的由來悄蕾。

背景

這篇文章的由來是來自于我嘗試使用tf.metrics.mean_iou評估指標(biāo)進行圖像分割票顾,但卻獲得完全奇怪和不正確的結(jié)果。我花了一天半的時間來弄清楚我哪里出錯了帆调。你會發(fā)現(xiàn)奠骄,自己可能會非常容易錯誤地使用tf的評估指標(biāo)。截至2017年9月11日番刊,tensorflow文檔并沒有非常清楚地介紹如何正確使用Tensorflow的評估指標(biāo)含鳞。

因此,這篇文章旨在幫助其他人避免同樣的錯誤芹务,并且深入理解其背后的原理蝉绷,以便了解如何正確地使用它們。

生成數(shù)據(jù)

在我們開始使用任何評估指標(biāo)之前枣抱,讓我們先從簡單的數(shù)據(jù)開始熔吗。我們將使用以下Numpy數(shù)組作為我們預(yù)測的標(biāo)簽和真實標(biāo)簽。數(shù)組的每一行視為一個batch佳晶,因此這個例子中共有4個batch桅狠。

import numpy as np
labels = np.array([[1,1,1,0],
                   [1,1,1,0],
                   [1,1,1,0],
                   [1,1,1,0]], dtype=np.uint8)
predictions = np.array([[1,0,0,0],
                        [1,1,0,0],
                        [1,1,1,0],
                        [0,1,1,1]], dtype=np.uint8)
n_batches = len(labels)

建立評價指標(biāo)

為了簡單起見,這里采用的評估指標(biāo)是準(zhǔn)確度(accuracy):

如果我們想計算整個數(shù)據(jù)集上的accuracy,可以這樣計算:

n_items = labels.size
accuracy = (labels ==  predictions).sum() / n_items
print("Accuracy :", accuracy)
[OUTPUT]
Accuracy : 0.6875

這種方法的問題在于它不能擴展到大型數(shù)據(jù)集垂攘,這些數(shù)據(jù)集太大而無法一次性加載到內(nèi)存维雇。為了使其可擴展,我們希望使評估指標(biāo)能夠逐步更新晒他,每次更新一個batch中預(yù)測值和標(biāo)簽吱型。為此,我們需要跟蹤兩個值:

  • 正確預(yù)測的例子總和

  • 目前所有例子的總數(shù)

在Python中陨仅,我們創(chuàng)建兩個全局變量:

# Initialize running variables
N_CORRECT = 0
N_ITEMS_SEEN = 0

每次新來一個batch津滞,我們將這個batch中的預(yù)測情況更新到這兩個變量中:

# Update running variables
N_CORRECT += (batch_labels == batch_predictions).sum()
N_ITEMS_SEEN += batch_labels.size

而且,我們可以實時地計算每個點處的accuracy:

# Calculate accuracy on updated values
acc = float(N_CORRECT) / N_ITEMS_SEEN

合并前面的功能灼伤,我們創(chuàng)建如下的代碼:

# Create running variables
N_CORRECT = 0
N_ITEMS_SEEN = 0
def reset_running_variables():
    """ Resets the previous values of running variables to zero     """
    global N_CORRECT, N_ITEMS_SEEN
    N_CORRECT = 0
    N_ITEMS_SEEN = 0
def update_running_variables(labs, preds):
    global N_CORRECT, N_ITEMS_SEEN
    N_CORRECT += (labs == preds).sum()
    N_ITEMS_SEEN += labs.size
def calculate_accuracy():
    global N_CORRECT, N_ITEMS_SEEN
    return float(N_CORRECT) / N_ITEMS_SEEN

4.1 整體accuracy

使用上面的函數(shù)触徐,當(dāng)我們便利完所有的batch之后,可以計算出整體accuracy:

reset_running_variables()
for i in range(n_batches):
    update_running_variables(labs=labels[i], preds=predictions[i])
accuracy = calculate_accuracy()
print("[NP] SCORE: ", accuracy)
[OUTPUT]
[NP] SCORE:  0.6875

4.2 每個batch的accuracy

但是狐赡,如果我們想要計算每個batch的accuracy撞鹉,那就要重新組織我們的代碼了。每次更新全局變量之前颖侄,你需要先重置它們(歸為0):

for i in range(n_batches):
    reset_running_variables()
    update_running_variables(labs=labels[i], preds=predictions[i])
    acc = calculate_accuracy()
    print("- [NP] batch {} score: {}".format(i, acc))
[OUTPUT]
- [NP] batch 0 score: 0.5
- [NP] batch 1 score: 0.75
- [NP] batch 2 score: 1.0
- [NP] batch 3 score: 0.5

Tensorflow中的metrics

在第4節(jié)中我們將計算評估指標(biāo)的操作拆分為不同函數(shù)鸟雏,這其實與Tensorflow中tf.metrics背后原理是一樣的。當(dāng)我們調(diào)用tf.metrics.accuracy函數(shù)時览祖,類似的事情會發(fā)生:

  • 會同樣地創(chuàng)建兩個變量(變量會加入tf.GraphKeys.LOCAL_VARIABLES集合中)孝鹊,并將其放入幕后的計算圖中:

    total(相當(dāng)于N_CORRECT)

    count(相當(dāng)于N_ITEMS_SEEN)

  • 返回兩個tensorflow操作。

    accuracy(相當(dāng)于calculate_accuracy())

    update_op(相當(dāng)于update_running_variables())

為了初始化和重置變量展蒂,比如第4節(jié)中的reset_running_variables函數(shù)又活,我們首先需要獲得這些變量(total和count)。你可以在第一次調(diào)用時為tf.metrics.accuracy函數(shù)顯式指定一個名稱锰悼,比如:

tf.metrics.accuracy(label, prediction, name="my_metric")

然后就可以根據(jù)作用范圍找到隱式創(chuàng)建的2個變量:

# Isolate the variables stored behind the scenes by the metric operation
running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="my_metric")
<tf.Variable 'my_metric/total:0' shape=() dtype=float32_ref>,
<tf.Variable 'my_metric/count:0' shape=() dtype=float32_ref>

接下了我們可以創(chuàng)建一個初始化操作柳骄,以可以初始化或者重置兩個變量:

running_vars_initializer = tf.variables_initializer(var_list=running_vars)

當(dāng)你需要初始化或者重置變量時,只需要在session中運行一下即可:

session.run(running_vars_initializer)

注意:除了手動分離變量松捉,然后創(chuàng)建初始化op夹界,在TF中更常用的是下面的操作:

session.run(tf.local_variables_initializer())

所以馆里,有時候你看到上面的操作不要大驚小怪隘世,其實只是初始化了在tf.GraphKeys.LOCAL_VARIABLES集合中的變量,但是這樣做把所以變量都初始化了鸠踪,使用時要特別注意丙者。

知道上面的東西,我們很容易計算整體accuracy和batch中的accuracy营密。

5.1 計算整體accuracy

在TF中要計算整體accuracy械媒,只需要如此:

import tensorflow as tf
graph = tf.Graph()
with graph.as_default():
    # Placeholders to take in batches onf data
    tf_label = tf.placeholder(dtype=tf.int32, shape=[None])
    tf_prediction = tf.placeholder(dtype=tf.int32, shape=[None])
    # Define the metric and update operations
    tf_metric, tf_metric_update = tf.metrics.accuracy(tf_label,
                                                      tf_prediction,
                                                      name="my_metric")
    # Isolate the variables stored behind the scenes by the metric operation
    running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="my_metric")
    # Define initializer to initialize/reset running variables
    running_vars_initializer = tf.variables_initializer(var_list=running_vars)
with tf.Session(graph=graph) as session:
    session.run(tf.global_variables_initializer())
    # initialize/reset the running variables
    session.run(running_vars_initializer)
    for i in range(n_batches):
        # Update the running variables on new batch of samples
        feed_dict={tf_label: labels[i], tf_prediction: predictions[i]}
        session.run(tf_metric_update, feed_dict=feed_dict)
    # Calculate the score
    score = session.run(tf_metric)
    print("[TF] SCORE: ", score)
[OUTPUT]
[TF] SCORE:  0.6875

5.2 計算每個batch的accuracy

為了分別計算各個batch的準(zhǔn)確度,在每批新數(shù)據(jù)之前將變量重置為零:

with tf.Session(graph=graph) as session:
    session.run(tf.global_variables_initializer())
    for i in range(n_batches):
        # Reset the running variables
        session.run(running_vars_initializer)
        # Update the running variables on new batch of samples
        feed_dict={tf_label: labels[i], tf_prediction: predictions[i]}
        session.run(tf_metric_update, feed_dict=feed_dict)
        # Calculate the score on this batch
        score = session.run(tf_metric)
        print("[TF] batch {} score: {}".format(i, score))
[OUTPUT]
[TF] batch 0 score: 0.5
[TF] batch 1 score: 0.75
[TF] batch 2 score: 1.0
[TF] batch 3 score: 0.5

注意:如果每個batch計算之前不重置變量的話,其實計算的累積accuracy纷捞,就是目前已經(jīng)運行數(shù)據(jù)的accuracy痢虹。

5.3 要避免的問題

不要在相同的session.run()中同時運行tf_metrics和tf_metric_update,比如這樣:

_ , score = session.run([tf_metric_update, tf_metric], feed_dict=feed_dict)
score, _ = session.run([tf_metric, tf_metric_update], feed_dict=feed_dict)

在Tensorflow 1.3 (或許其它版本)中主儡,這可能得到不一致的結(jié)果奖唯。這兩個op,update_op才是真正負(fù)責(zé)更新變量糜值,而第一個op只是簡單根據(jù)當(dāng)前變量計算評價指標(biāo)丰捷,所以你應(yīng)該先執(zhí)行update_op,然后再用第一個op計算指標(biāo)寂汇。需要注意的病往,update_op執(zhí)行后一個作用是更新變量,另外會同時返回一個結(jié)果骄瓣,對于tf.metric.accuracy停巷,就是更新變量后實時計算的accuracy。

其它metrics

tf.metrics中的其他評估指標(biāo)將以相同的方式工作榕栏。它們之間的唯一區(qū)別可能是調(diào)用tf.metrics函數(shù)時需要額外參數(shù)叠穆。例如,tf.metrics.mean_iou需要額外的參數(shù)num_classes來表示預(yù)測的類別數(shù)臼膏。另一個區(qū)別是背后所創(chuàng)建的變量硼被,如tf.metrics.mean_iou創(chuàng)建的是一個混淆矩陣,但仍然可以按照我在本文第5部分中描述的方式收集和初始化它們渗磅。

結(jié)語

對于TF中所有metric嚷硫,其都是返回兩個op,一個是計算評價指標(biāo)的op始鱼,另外一個是更新op仔掸,這個op才是真正其更新作用的。我想之所以TF會采用這種方式医清,是因為metric所服務(wù)的其實是評估模型的時候起暮,此時你需要收集整個數(shù)據(jù)集上的預(yù)測結(jié)果,然后計算整體指標(biāo)会烙,而TF的metric這種設(shè)計恰好滿足這種需求负懦。但是在訓(xùn)練模型時使用它們,就是理解它的原理柏腻,才可以得到正確的結(jié)果纸厉。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市五嫂,隨后出現(xiàn)的幾起案子颗品,更是在濱河造成了極大的恐慌肯尺,老刑警劉巖,帶你破解...
    沈念sama閱讀 206,378評論 6 481
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件躯枢,死亡現(xiàn)場離奇詭異则吟,居然都是意外死亡,警方通過查閱死者的電腦和手機锄蹂,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,356評論 2 382
  • 文/潘曉璐 我一進店門逾滥,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人败匹,你說我怎么就攤上這事寨昙。” “怎么了掀亩?”我有些...
    開封第一講書人閱讀 152,702評論 0 342
  • 文/不壞的土叔 我叫張陵舔哪,是天一觀的道長。 經(jīng)常有香客問我槽棍,道長捉蚤,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 55,259評論 1 279
  • 正文 為了忘掉前任炼七,我火速辦了婚禮缆巧,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘豌拙。我一直安慰自己陕悬,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 64,263評論 5 371
  • 文/花漫 我一把揭開白布按傅。 她就那樣靜靜地躺著捉超,像睡著了一般。 火紅的嫁衣襯著肌膚如雪唯绍。 梳的紋絲不亂的頭發(fā)上拼岳,一...
    開封第一講書人閱讀 49,036評論 1 285
  • 那天,我揣著相機與錄音况芒,去河邊找鬼惜纸。 笑死,一個胖子當(dāng)著我的面吹牛绝骚,可吹牛的內(nèi)容都是我干的耐版。 我是一名探鬼主播,決...
    沈念sama閱讀 38,349評論 3 400
  • 文/蒼蘭香墨 我猛地睜開眼皮壁,長吁一口氣:“原來是場噩夢啊……” “哼椭更!你這毒婦竟也來了哪审?” 一聲冷哼從身側(cè)響起蛾魄,我...
    開封第一講書人閱讀 36,979評論 0 259
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后滴须,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體舌狗,經(jīng)...
    沈念sama閱讀 43,469評論 1 300
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 35,938評論 2 323
  • 正文 我和宋清朗相戀三年扔水,在試婚紗的時候發(fā)現(xiàn)自己被綠了痛侍。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 38,059評論 1 333
  • 序言:一個原本活蹦亂跳的男人離奇死亡魔市,死狀恐怖主届,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情待德,我是刑警寧澤君丁,帶...
    沈念sama閱讀 33,703評論 4 323
  • 正文 年R本政府宣布,位于F島的核電站将宪,受9級特大地震影響绘闷,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜较坛,卻給世界環(huán)境...
    茶點故事閱讀 39,257評論 3 307
  • 文/蒙蒙 一印蔗、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧丑勤,春花似錦华嘹、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,262評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至爪喘,卻和暖如春颜曾,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背秉剑。 一陣腳步聲響...
    開封第一講書人閱讀 31,485評論 1 262
  • 我被黑心中介騙來泰國打工泛豪, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人侦鹏。 一個月前我還...
    沈念sama閱讀 45,501評論 2 354
  • 正文 我出身青樓诡曙,卻偏偏與公主長得像,于是被迫代替她去往敵國和親略水。 傳聞我的和親對象是個殘疾皇子价卤,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 42,792評論 2 345

推薦閱讀更多精彩內(nèi)容