[0.3] 續(xù)--Tensorflow踩坑記之tf.metrics

【續(xù)】--Tensorflow踩坑記之tf.metrics

欠下的帳總歸還是要還的醋界,之前一直拖著服鹅,總是懶得寫tf.metrics這個(gè)API的一些用法混蔼,今天總算是克服了懶癌涧尿,總結(jié)一下tf.metrics遇到的一些坑系奉。

  • 博客中涉及到的所有代碼都已經(jīng)傳到瀾子的Github上啦~歡迎互粉哇圈盔。
  • 本篇博客也傳到了瀾子的個(gè)人博客匪蝙,歡迎大噶多多關(guān)注哇炊汹。

插一句閑話头谜,這一次的博客基本上用的都是 Jupyter,感覺一級(jí)好用啊脂新⊙崞可以一邊寫代碼槽地,一邊記markdown号阿,忍不住上一張效果圖并鸵,再次歡迎大噶去我的Github上看一看,而且Github支持 jupyter notebook 顯示扔涧,真得效果很好园担。

jupyter

在這篇偽Tensorflow-tf-metrics中届谈,瀾子介紹了tf.metrics中涉及的一些指標(biāo)和概念,包括:精確率(precision)弯汰,召回率(recall)艰山,準(zhǔn)確率(accuracy),AUC咏闪,混淆矩陣(confusion matrix)曙搬。下面先給出官方的API文檔,看看這個(gè)模塊中都有哪些隱藏秘笈鸽嫂。

看了官方文檔之后纵装,大噶可能會(huì)發(fā)現(xiàn)其中有好多可以調(diào)用的函數(shù),不僅有precision / accuracy/ auc/ recall溪胶,還有precision_at_k / recall_at_k搂擦,更有precision_at_thresholds/ precision_at_top_k/ sparse_precision_at_k...天啦嚕稳诚,這都是什么呀哗脖,瀾子已經(jīng)徹底暈了,到底要怎么用鞍饣埂(眼冒金星中)才避。別急,讓我一個(gè)坑一個(gè)坑地告訴你氨距。

劃重點(diǎn)

首先桑逝,這篇文章是受到Ronny Restrepo的啟發(fā),
這是一篇很好的文章俏让,將tf.metrics.accuracy()講解滴很清楚楞遏,本文就模仿他的思路,驗(yàn)證一下precision的計(jì)算首昔。

精確率的計(jì)算公式

Precision = \frac{truePositive}{truePositive + falsePositive}

讓我們先造點(diǎn)數(shù)據(jù)寡喝,傳統(tǒng)算算看

import tensorflow as tf
import numpy as np

labels = np.array([[1,1,1,0],
                   [1,1,1,0],
                   [1,1,1,0],
                   [1,1,1,0]], dtype=np.uint8)

predictions = np.array([[1,0,0,0],
                        [1,1,0,0],
                        [1,1,1,0],
                        [0,1,1,1]], dtype=np.uint8)

n_batches = len(labels)
# First,calculate precision over entire set of batches 
# using formula mentioned above
pred_p = (predictions > 0).sum()
# print(pred_p)
true_p = (labels*predictions > 0).sum()
# print(true_p)
precision = true_p / pred_p
print("Precision :%1.4f" %(precision))

上述方法的問題

由于硬件方面的一些限制,導(dǎo)致此方法不能擴(kuò)展到大型數(shù)據(jù)集勒奇,比如當(dāng)數(shù)據(jù)集很大時(shí)预鬓,就無法一次性適應(yīng)內(nèi)存。
因而赊颠,為了使其可擴(kuò)展格二,我們希望使評(píng)估指標(biāo)能夠逐步更新,每批新的預(yù)測(cè)和標(biāo)簽竣蹦。 為此顶猜,我們需要跟蹤兩個(gè)值。

  • 正確預(yù)測(cè)的正樣本數(shù)量
  • 預(yù)測(cè)樣本中所有正樣本的數(shù)量

所以我們要這么做

# Initialize running variables
N_TRUE_P = 0
N_PRED_P = 0

# Specific steps
# Create running variables
N_TRUE_P = 0
N_PRED_P = 0

def reset_running_variables():
    """ Resets the previous values of running variables to zero """
    global N_TRUE_P, N_PRED_P
    N_TRUE_P = 0
    c = 0

def update_running_variables(labs, preds):
    global N_TRUE_P, N_PRED_P
    N_TRUE_P += ((labs * preds) > 0).sum()
    N_PRED_P += (preds > 0).sum()

def calculate_precision():
    global N_TRUE_P, N_PRED_P
    return float (N_TRUE_P) / N_PRED_P

怎么用上面的函數(shù)呢痘括?

接下來的兩個(gè)例子长窄,給出了運(yùn)用的具體代碼,并且可以更好滴幫助我們理解tf.metrics.precision()的計(jì)算邏輯以及對(duì)應(yīng)輸出所代表的含義

樣本整體準(zhǔn)確率(直接計(jì)算)

# Overall precision
reset_running_variables()

for i in range(n_batches):
    update_running_variables(labs=labels[i], preds=predictions[i])

precision = calculate_precision()
print("[NP] SCORE: %1.4f" %precision)

批次準(zhǔn)確率(直接計(jì)算)

# Batch precision
for i in range(n_batches):
    reset_running_variables()
    update_running_variables(labs=labels[i], preds=predictions[i])
    prec = calculate_precision()
    print("- [NP] batch %d score: %1.4f" %(i, prec))
[NP] batch 0 score: 1.0000
[NP] batch 1 score: 1.0000
[NP] batch 2 score: 1.0000
[NP] batch 3 score: 0.6667

不要小瞧這兩個(gè)變量和三個(gè)函數(shù)

上面說了這么多,感覺沒有tensorflow的什么事哇抄淑,別急屠凶,先看一個(gè)tensorflow的官方文檔

放一個(gè)官方的解釋

The precision function creates two local variables,
true_positives and false_positives, that are used to compute the precision. This value is ultimately returned as precision, an idempotent operation that simply divides true_positives by the sum of true_positives and false_positives.
For estimation of the metric over a stream of data, the function creates an update_op operation that updates these variables and returns the precision.

兩個(gè)變量和 tf.metrics.precision()的關(guān)系

官方文檔提及的two local variablestrue_postivesfalse_positives分別對(duì)應(yīng)上文定義的兩個(gè)變量。

  • true_postives -- N_TRUE_P
  • false_postives -- N_PRED_P - N_TRUE_P

三個(gè)函數(shù)和頭大的update_op

官方文檔提及的update_opprecision分別對(duì)應(yīng)上文定義的兩個(gè)函數(shù)

  • precision--calculate_precision()
  • update_op--update_running_variables()

大家不要被這個(gè)update_op搞暈肆资,其實(shí)從字面來理解就是一個(gè)變量更新的操作矗愧,上文的代碼中,就是通過reset_running_variables()的位置來決定何時(shí)對(duì)變量進(jìn)行更新郑原,其實(shí)就是對(duì)應(yīng)于tf.variables_initializer()唉韭。我之所以一直用錯(cuò)這個(gè)API,是因?yàn)槲覍?code>tf.variables_initializer()放在了錯(cuò)誤的位置犯犁,導(dǎo)致變量沒有按照我的預(yù)期正常更新属愤,進(jìn)而結(jié)果一直不正確。具體看看tensorflow是怎么實(shí)現(xiàn)的吧酸役。

Overall precision using tensorflow

# Overall precision using tensorflow
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Placeholders to take in batches onf data
    tf_label = tf.placeholder(dtype=tf.int32, shape=[None])
    tf_prediction = tf.placeholder(dtype=tf.int32, shape=[None])

    # Define the metric and update operations
    tf_metric, tf_metric_update = tf.metrics.precision(tf_label,
                                                      tf_prediction,
                                                      name="my_metric")

    # Isolate the variables stored behind the scenes by the metric operation
    running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="my_metric")

    # Define initializer to initialize/reset running variables
    running_vars_initializer = tf.variables_initializer(var_list=running_vars)


with tf.Session(graph=graph) as session:
    session.run(tf.global_variables_initializer())

    # initialize/reset the running variables
    session.run(running_vars_initializer)

    for i in range(n_batches):
        # Update the running variables on new batch of samples
        feed_dict={tf_label: labels[i], tf_prediction: predictions[i]}
        session.run(tf_metric_update, feed_dict=feed_dict)

    # Calculate the score
    score = session.run(tf_metric)
    print("[TF] SCORE: %1.4f" %score)

[TF] SCORE: 0.8889

Batch precision using tensorflow

# Batch precision using tensorflow
with tf.Session(graph=graph) as session:
    session.run(tf.global_variables_initializer())

    for i in range(n_batches):
        # Reset the running variables
        session.run(running_vars_initializer)

        # Update the running variables on new batch of samples
        feed_dict={tf_label: labels[i], tf_prediction: predictions[i]}
        session.run(tf_metric_update, feed_dict=feed_dict)

        # Calculate the score on this batch
        score = session.run(tf_metric)
        print("[TF] batch %d score: %1.4f" %(i, score))

[TF] batch 0 score: 1.0000
[TF] batch 1 score: 1.0000
[TF] batch 2 score: 1.0000
[TF] batch 3 score: 0.6667

再次劃重點(diǎn)

大噶一定要注意

session.run(running_vars_initializer)
score = session.run(tf_metric)

這兩行代碼在計(jì)算整體樣本精確度以及批次精確度所在位置的不同住诸。
瀾子第一次的時(shí)候由于粗心,并沒有注意兩段代碼的不同涣澡,才會(huì)導(dǎo)致tf計(jì)算結(jié)果普通計(jì)算結(jié)果不一致

還需要注意的點(diǎn)

不要在一個(gè)sess.run()里面同時(shí)調(diào)用tf_metrictf_metric_update贱呐,下面的代碼是錯(cuò)誤的示范

_ , score = session.run([tf_metric_update,tf_metric],\
                        feed_dict=feed_dict)

update_op究竟返回了什么捏

此處參考了
stackoverflow的一個(gè)回答

具體代碼如下

rel = tf.placeholder(tf.int64, [1,3])
rec = tf.constant([[7, 5, 10, 6, 3, 1, 8, 12, 31, 88]], tf.int64) 
precision, update_op = tf.metrics.precision_at_k(rel, rec, 10)

sess = tf.Session()
sess.run(tf.local_variables_initializer())

stream_vars = [i for i in tf.local_variables()]
#Get the local variables true_positive and false_positive

print("[PRECSION_1]: ",sess.run(precision, {rel:[[1,5,10]]})) # nan
#tf.metrics.precision maintains two variables true_positives 
#and  false_positives, each starts at zero.
#so the output at this step is 'nan'

print("[UPDATE_OP_1]:",sess.run(update_op, {rel:[[1,5,10]]})) #0.2
#when the update_op is called, it updates true_positives 
#and false_positives using labels and predictions.

print("[STREAM_VARS_1]:",sess.run(stream_vars)) #[2.0, 8.0]
# Get true positive rate and false positive rate

print("[PRECISION_1]:",sess.run(precision,{rel:[[1,10,15]]})) # 0.2
#So calling precision will use true_positives and false_positives and outputs 0.2

print("[UPDATE_OP_2]:",sess.run(update_op,{rel:[[1,10,15]]})) #0.15
#the update_op updates the values to the new calculated value 0.15.

print("[STREAM_VARS_2]:",sess.run(stream_vars)) #[3.0, 17.0]

[STREAM_VARS_1]: [0.0, 0.0, 0.0, 0.0, 2.0, 8.0]
[PRECISION_1]: 0.2
[UPDATE_OP_2]: 0.15
[STREAM_VARS_2]: [0.0, 0.0, 0.0, 0.0, 3.0, 17.0]

tf.metrics.precision_at_k

上面的代碼中,我們看到運(yùn)用的是tf.metrics.precision_at_k()這個(gè)API入桂,這里的k是什么呢奄薇?
首先,我們要理解一個(gè)概念抗愁,究竟什么是Precision at k馁蒂,這里有兩份資料,應(yīng)該能很好地幫助你理解這個(gè)概念蜘腌。
瀾子就是看了這兩份資料之后沫屡,理解了Precision at k的概念的。

然后我們來看看這個(gè)函數(shù)是怎么用的逢捺,第一步當(dāng)然要先看看輸入啦谁鳍。

tf.metrics.precision_at_k(
    labels,
    predictions,
    k,
    class_id=None,
    weights=None,
    metrics_collections=None,
    updates_collections=None,
    name=None
)

我們重點(diǎn)關(guān)注labels,predictions,k這三個(gè)參數(shù),應(yīng)該可以滿足日常簡(jiǎn)單地使用了劫瞳。
labels,predictions,k的輸入形式是什么樣的呢倘潜?

閑話不說,直接看看上面的栗子志于。栗子中rel其實(shí)對(duì)應(yīng)為labels涮因,rec對(duì)應(yīng)為predictions,那k又是什么意思呢伺绽?
劃重點(diǎn):這里的k表明你需要對(duì)多少個(gè)預(yù)測(cè)樣本進(jìn)行排序养泡。這樣說可能有一點(diǎn)抽象嗜湃,給一個(gè)解釋。

Precision@k = (Recommended items @k that are relevant) / (# Recommended items @k)

可以先去看一下Github澜掩,發(fā)現(xiàn)其實(shí)在tf.metrics.precision_at_k這個(gè)函數(shù)中购披,對(duì)于predictions會(huì)根據(jù)輸入的k值進(jìn)行top_k操作。
對(duì)應(yīng)上面的代碼中肩榕,當(dāng)k=10刚陡,即對(duì)rec = tf.constant([[7, 5, 10, 6, 3, 1, 8, 12, 31, 88]], tf.int64)
所有的樣本進(jìn)行排序,進(jìn)而在函數(shù)中實(shí)際運(yùn)用的是rec樣本數(shù)值從大到小排列的索引值株汉。這樣解釋應(yīng)該就能看懂上面代碼的意思了筐乳。

后來,瀾子又在

看到有人問怎么用tf.metrics.sparse_average_precision_at_k乔妈,就又去求是了一波蝙云,
還完成了知乎的技術(shù)首答以及stackoverflow上第一個(gè)贊
歡迎互粉知乎stackoverflow哇路召。下面給出栗子和簡(jiǎn)單解釋啦勃刨。

import tensorflow as tf
import numpy as np

y_true = np.array([[2], [1], [0], [3], [0]]).astype(np.int64)
y_true = tf.identity(y_true)

y_pred = np.array([[0.1, 0.2, 0.6, 0.1],
                   [0.8, 0.05, 0.1, 0.05],
                   [0.3, 0.4, 0.1, 0.2],
                   [0.6, 0.25, 0.1, 0.05],
                   [0.1, 0.2, 0.6, 0.1]
                   ]).astype(np.float32)
y_pred = tf.identity(y_pred)

_, m_ap = tf.metrics.sparse_average_precision_at_k(y_true, y_pred, 3)

sess = tf.Session()
sess.run(tf.local_variables_initializer())

stream_vars = [i for i in tf.local_variables()]

tf_map = sess.run(m_ap)
print("TF_MAP",tf_map)

print("STREAM_VARS",(sess.run(stream_vars)))

tmp_rank = tf.nn.top_k(y_pred,3)

print("TMP_RANK",sess.run(tmp_rank))

簡(jiǎn)單解釋一下

  • 首先y_true代表標(biāo)簽值(未經(jīng)過one-hot),shape:(batch_size, num_labels) ,y_pred代表預(yù)測(cè)值(logit值) 优训,shape:(batch_size, num_classes)

  • 其次朵你,要注意的是tf.metrics.sparse_average_precision_at_k中會(huì)采用top_k根據(jù)不同的k值對(duì)y_pred進(jìn)行排序操作 ,所以tmp_rank是為了幫助大噶理解究竟y_pred在函數(shù)中進(jìn)行了怎樣的轉(zhuǎn)換揣非。

  • 然后,stream_vars = [i for i in tf.local_variables()]這一行是為了幫助大噶理解 tf.metrics.sparse_average_precision_at_k創(chuàng)建的tf.local_varibles 實(shí)際輸出值躲因,進(jìn)而可以更好地理解這個(gè)函數(shù)的用法早敬。

  • 具體看這個(gè)例子,當(dāng)k=1時(shí)大脉,只有第一個(gè)batch的預(yù)測(cè)輸出是和標(biāo)簽匹配的 搞监,所以最終輸出為:1/6 = 0.166666 ;當(dāng)k=2時(shí)镰矿,除了第一個(gè)batch的預(yù)測(cè)輸出琐驴,第三個(gè)batch的預(yù)測(cè)輸出也是和標(biāo)簽匹配的,所以最終輸出為:(1+(1/2))/6 = 0.25秤标。

P.S:在以后的tf版本里绝淡,將tf.metrics.average_precision_at_k替代tf.metrics.sparse_average_precision_at_k

簡(jiǎn)直超累的苍姜,目測(cè)是最近的最后一篇博客啦牢酵,有什么錯(cuò)誤一定告訴我啦。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末衙猪,一起剝皮案震驚了整個(gè)濱河市馍乙,隨后出現(xiàn)的幾起案子布近,更是在濱河造成了極大的恐慌,老刑警劉巖丝格,帶你破解...
    沈念sama閱讀 218,858評(píng)論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件撑瞧,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡显蝌,警方通過查閱死者的電腦和手機(jī)季蚂,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,372評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來琅束,“玉大人扭屁,你說我怎么就攤上這事∩鳎” “怎么了料滥?”我有些...
    開封第一講書人閱讀 165,282評(píng)論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)艾船。 經(jīng)常有香客問我葵腹,道長(zhǎng),這世上最難降的妖魔是什么屿岂? 我笑而不...
    開封第一講書人閱讀 58,842評(píng)論 1 295
  • 正文 為了忘掉前任践宴,我火速辦了婚禮,結(jié)果婚禮上爷怀,老公的妹妹穿的比我還像新娘阻肩。我一直安慰自己,他們只是感情好运授,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,857評(píng)論 6 392
  • 文/花漫 我一把揭開白布烤惊。 她就那樣靜靜地躺著,像睡著了一般吁朦。 火紅的嫁衣襯著肌膚如雪柒室。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,679評(píng)論 1 305
  • 那天逗宜,我揣著相機(jī)與錄音雄右,去河邊找鬼。 笑死纺讲,一個(gè)胖子當(dāng)著我的面吹牛擂仍,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播刻诊,決...
    沈念sama閱讀 40,406評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼防楷,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了则涯?” 一聲冷哼從身側(cè)響起复局,我...
    開封第一講書人閱讀 39,311評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤冲簿,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后亿昏,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體峦剔,經(jīng)...
    沈念sama閱讀 45,767評(píng)論 1 315
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,945評(píng)論 3 336
  • 正文 我和宋清朗相戀三年角钩,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了吝沫。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 40,090評(píng)論 1 350
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡递礼,死狀恐怖惨险,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情脊髓,我是刑警寧澤辫愉,帶...
    沈念sama閱讀 35,785評(píng)論 5 346
  • 正文 年R本政府宣布,位于F島的核電站将硝,受9級(jí)特大地震影響恭朗,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜依疼,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,420評(píng)論 3 331
  • 文/蒙蒙 一痰腮、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧律罢,春花似錦膀值、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,988評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至稀余,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間趋翻,已是汗流浹背睛琳。 一陣腳步聲響...
    開封第一講書人閱讀 33,101評(píng)論 1 271
  • 我被黑心中介騙來泰國(guó)打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留踏烙,地道東北人师骗。 一個(gè)月前我還...
    沈念sama閱讀 48,298評(píng)論 3 372
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像讨惩,于是被迫代替她去往敵國(guó)和親辟癌。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,033評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容