損失函數(shù)(criterion)
- 通過實例化各種損失函數(shù)類進行定義赘淮,一般實例化名為criterion
import torch.nn as nn
# 均方誤差,適用于回歸任務(wù)
criterion = nn.MSELoss()
# 交叉熵?fù)p失耿眉,用于多分類任務(wù)
criterion = nn.CrossEntropyLoss()
# 二元交叉熵?fù)p失悲关,用于二分類任務(wù)胆绊,輸出層神經(jīng)元個數(shù)為1
criterion = nn.BCELoss()
優(yōu)化器(optimizer)
優(yōu)化器在PyTorch中是用來管理模型參數(shù)和梯度的间影。包括最基本的SGD、SGD with Momentum晴氨、AdaGrad康嘉、RMSprop、Adam籽前。
Adam是目前最常用的優(yōu)化算法之一凄鼻,結(jié)合了動量和RMSprop的優(yōu)點。通過一下代碼實例化一個基于Adam的optimizer
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=0.001) # model為實例化后的模型
在訓(xùn)練中使用損失函數(shù)和優(yōu)化器
- 向前傳播
# 輸入數(shù)據(jù)經(jīng)過模型聚假,得到outputs
outputs = model(inputs)
# outputs和labels進行損失計算
loss = criterion(outputs, labels)
- 后向傳播
# 梯度清零
# 每次反向傳播時,我們希望計算的是當(dāng)前批次數(shù)據(jù)所對應(yīng)的梯度
# 如果不清零梯度闰非,當(dāng)前批次的梯度會被之前批次的梯度污染膘格,導(dǎo)致梯度計算不準(zhǔn)確
optimizer.zero_grad()
# 計算損失相對于每個參數(shù)的梯度
loss.backward()
# 根據(jù)當(dāng)前的梯度更新模型的參數(shù)
optimizer.step()