第一次用的損失函數(shù)是均方誤差MSELoss
程序正常運行沒有遇到問題谴餐,但當換成CrossEntropyLoss
后會報如下錯誤:
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward
搜了很多博客都沒有找到答案茄猫,這篇博客中說到:
交叉熵需要傳入一個output和一個target排监。nn.CrossEntropyLoss(output, target)
锥债。
其中:
output.dtype : torch.FloatTorch
target.dtype : torch.LongTorch
我的預測數(shù)據(jù)output和標簽數(shù)據(jù)target都是torch.float32數(shù)據(jù)類型,所以我在將array數(shù)據(jù)類型轉換成tensor數(shù)據(jù)類型時做了如下操作:
x = torch.from_numpy(x).float()
target = torch.from_numpy(target).long()
其中float是float32類型践樱,long是int64類型,但是問題依然存在凸丸。
在pytorch的官方論壇里有一個人也遇到了同樣的問題拷邢,他把nn.CrossEntropyLoss()
換成了nn.MultiLabelSoftMarginLoss()
就不再報錯了。但是經(jīng)過實驗發(fā)現(xiàn)屎慢,這個損失函數(shù)的效果非常差瞭稼,遠不如MSELoss
。
最終抛人,我找到了一篇運用交叉熵損失函數(shù)的多分類代碼一步步檢查發(fā)現(xiàn)了報錯的原因:
在多分類問題中弛姜,當損失函數(shù)為nn.CrossEntropyLoss()
時脐瑰,它會自動把標簽轉換成onehot形式妖枚。例如,MNIST數(shù)據(jù)集的標簽為0到9的數(shù)字苍在,有100個標簽绝页,則標簽的形狀為[100],而我們的模型的輸出則為onehot形式寂恬,其形狀為[100, 10]续誉。所以,我們在運用交叉熵損失函數(shù)時不必將標簽也轉換成onehot形式初肉。問題成功解決酷鸦。(target仍然需要為int64類型)