在分類問題中,通常需要使用max()
函數(shù)對softmax
函數(shù)的輸出值進行操作,求出預測值索引,然后與標簽進行比對崩掘,計算準確率。下面講解一下torch.max()
函數(shù)的輸入及輸出值都是什么少办,便于我們理解該函數(shù)苞慢。
1. torch.max(input, dim) 函數(shù)
output = torch.max(input, dim)
輸入
input
是softmax函數(shù)輸出的一個tensor
dim
是max函數(shù)索引的維度0/1
,0
是每列的最大值英妓,1
是每行的最大值
輸出
- 函數(shù)會返回兩個
tensor
挽放,第一個tensor
是每行的最大值;第二個tensor
是每行最大值的索引鞋拟。
在多分類任務中我們并不需要知道各類別的預測概率骂维,所以返回值的第一個tensor
對分類任務沒有幫助,而第二個tensor
包含了預測最大概率的索引贺纲,所以在實際使用中我們僅獲取第二個tensor
即可航闺。
下面通過一個實例可以更容易理解這個函數(shù)的用法。
import torch
a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])
print(a)
輸出:
tensor([[ 1, 5, 62, 54],
[ 2, 6, 2, 6],
[ 2, 65, 2, 6]])
索引每行的最大值:
torch.max(a, 1)
輸出:
torch.return_types.max(
values=tensor([62, 6, 65]),
indices=tensor([2, 3, 1]))
在計算準確率時第一個tensor values
是不需要的猴誊,所以我們只需提取第二個tensor潦刃,并將tensor格式的數(shù)據(jù)轉(zhuǎn)換成array格式。
torch.max(a, 1)[1].numpy()
輸出:
array([2, 3, 1], dtype=int64)
這樣懈叹,我們就可以與標簽值進行比對乖杠,計算模型預測準確率。
*注:在有的地方我們會看到torch.max(a, 1).data.numpy()
的寫法澄成,這是因為在早期的pytorch的版本中胧洒,variable變量和tenosr是不一樣的數(shù)據(jù)格式畏吓,variable可以進行反向傳播,tensor不可以卫漫,需要將variable轉(zhuǎn)變成tensor再轉(zhuǎn)變成numpy》票現(xiàn)在的版本已經(jīng)將variable和tenosr合并,所以只用torch.max(a,1).numpy()
就可以了列赎。
2.準確率的計算
pred_y = torch.max(predict, 1)[1].numpy()
label_y = torch.max(label, 1)[1].data.numpy()
accuracy = (pred_y == label_y).sum() / len(label_y)
predict
- softmax函數(shù)輸出
label
- 樣本標簽宏悦,這里假設它是one-hot編碼