==用于衡量模型的最終效果==
一蛾默、背景
在學(xué)習(xí)tensorflow的初級階段哈打,會常常搞不懂,metrics的具體意義和實際用途掩缓,接下來的文章一方面是對接自己的解答,也是一種學(xué)習(xí)路徑的記錄遵岩。
二你辣、基礎(chǔ)
混淆矩陣是理解眾多評價指標(biāo)的基礎(chǔ)。下面是混淆矩陣的表格
- TP:True Positive尘执,預(yù)測為正例舍哄,實際也為正例。
- FP:False Positive誊锭,預(yù)測為正例表悬,實際卻為負(fù)例。
- TN:True Negative丧靡,預(yù)測為負(fù)例蟆沫,實際也為負(fù)例。
- FN:False Negative温治,預(yù)測為負(fù)例饭庞,實際卻為正例。
統(tǒng)計正確預(yù)測的次數(shù)在總的數(shù)據(jù)集中的占比熬荆,==最常用==舟山,tensorflow中是通過記錄每次測試的批次的正確數(shù)量和總體數(shù)量確定的。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Time : 2020/4/4 9:18 下午
# Author : Dale Chen
# Description:
# File : accuracy.py
# Copyright: (c) 2020 year, 4399 Network CO.ltd. All Rights Reserved.
import numpy as np
labels = np.array([[1, 1, 1, 0],
[1, 1, 1, 0],
[1, 1, 1, 0],
[1, 1, 1, 0]], dtype=np.float)
predictions = np.array([[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[0, 1, 1, 1]], dtype=np.float)
N_CORRECT = 0
N_ITEM_SEEN = 0
def reset_running_variables():
"""reset variables"""
global N_CORRECT, N_ITEM_SEEN
N_CORRECT = 0
N_ITEM_SEEN = 0
def update_running_variables(labs, preds):
"""update variables"""
global N_CORRECT, N_ITEM_SEEN
N_CORRECT += (labs == preds).sum()
N_ITEM_SEEN += labs.size
def calculate_accuracy():
"""calculate accuracy"""
global N_CORRECT, N_ITEM_SEEN
return float(N_CORRECT) / N_ITEM_SEEN
reset_running_variables()
for i in range(len(labels)):
update_running_variables(labels[i], predictions[I])
accuracy = calculate_accuracy()
print("accuracy:", accuracy)
三卤恳、內(nèi)置評估指標(biāo)算子
- tf.metrics.accuracy()
- tf.metrics.precision()
- tf.metrics.recall()
- tf.metrics.mean_iou()
1. accuracy [??kj?r?si] (準(zhǔn)確率)
==所有預(yù)測正確的樣本(不論正例還是負(fù)例累盗,只看對錯)占總體數(shù)量的比例==
注意一下輸入的數(shù)據(jù)需要是true或者false, 實際上就是比對正確的數(shù)量占總體的百分比
import tensorflow as tf
import numpy as np
l = np.array([[0, 1, 0, 0],
[0, 0, 0, 0]], dtype=np.float)
p = np.array([[0, 1, 0, 0],
[0, 0, 0, 1]], dtype=np.float)
labels = tf.reshape(l, [2, 4])
predictions = tf.reshape(p, [2, 4])
op = tf.keras.metrics.BinaryAccuracy()
op.update_state(labels, predictions)
print("accuracy:", op.result().numpy())
//accuracy: 1.0
2.precision (精確度)
==精確率就是指 當(dāng)前劃分到正樣本類別中突琳,被正確分類的比例若债,即真正的正樣本所占所有預(yù)測為正樣本的比例。==
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Time : 2020/4/4 11:46 下午
# Author : Dale Chen
# Description:
# File : precision.py
# Copyright: (c) 2020 year, 4399 Network CO.ltd. All Rights Reserved.
import numpy as np
import tensorflow as tf
l = np.array([[0, 1, 1, 0],
[0, 0, 0, 0]])
p = np.array([[0, 1, 0, 0],
[0, 0, 0, 0]])
labels = tf.reshape(l, [1, 8])
predictions = tf.reshape(p, [1, 8])
op = tf.keras.metrics.Precision()
op.update_state(labels, predictions)
print("precision:", (op.result()).numpy())
//precision: 1
我們有一個樣本數(shù)量為50的數(shù)據(jù)集本今,其中正樣本的數(shù)量為20拆座。但是,在我們所有的預(yù)測結(jié)果中冠息,只預(yù)測出了一個正樣本挪凑,并且這個樣本也確實是正樣本,那么 TP=1逛艰,F(xiàn)P=0躏碳,Precision = TP/(TP+FP) = 1.0,那么我們的模型是不是就很好了呢散怖?當(dāng)然不是菇绵,我們還有19個正樣本都沒有預(yù)測成功. 這時候要使用回招率。
3.recall(回召率)
==召回率即指 當(dāng)前被分到正樣本類別中镇眷,真實的正樣本占所有正樣本的比例咬最,即召回了多少正樣本的比例。Recall = TP/(TP+FN)==
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Time : 2020/4/5 12:19 上午
# Author : Dale Chen
# Description:
# File : recall.py
# Copyright: (c) 2020 year, 4399 Network CO.ltd. All Rights Reserved.
import numpy as np
import tensorflow as tf
l = np.array([[0, 1, 1, 0],
[0, 0, 0, 0]])
p = np.array([[0, 1, 0, 0],
[0, 0, 0, 0]])
labels = tf.reshape(l, [1, 8])
predictions = tf.reshape(p, [1, 8])
op = tf.keras.metrics.Recall()
op.update_state(labels, predictions)
print("recall:", (op.result()).numpy())
//recall: 0.5
3.mean_iou
用于處理分類問題的欠动,需要輸入種類的數(shù)量永乌。標(biāo)簽的格式一定要是[0, 0 , 1, 0]的結(jié)構(gòu),只能有一種結(jié)果具伍, 但是預(yù)測的值可以是有多種結(jié)果翅雏。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Time : 2020/4/4 10:10 下午
# Author : Dale Chen
# Description:
# File : mean_iou.py
# Copyright: (c) 2020 year, 4399 Network CO.ltd. All Rights Reserved.
import numpy as np
import tensorflow as tf
l = np.array([[0, 1, 0, 0],
[0, 0, 0, 1]])
p = np.array([[0, 1, 0, 0],
[0, 0, 1, 0]])
labels = tf.reshape(l, [2, 4])
predictions = tf.reshape(p, [2, 4])
op = tf.keras.metrics.MeanIoU(num_classes=4)
op.update_state(labels, predictions)
print("iou_op", op.result().numpy())