I found this code online; while reading through it I added the corresponding comments.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
BATCH_SIZE = 512  # needs roughly 2 GB of GPU memory
EPOCHS = 20       # total number of training epochs
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # let torch decide whether to use the GPU
# Load the data
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,  # change to download=False once the dataset is already on disk
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=BATCH_SIZE, shuffle=True)
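The two numbers passed to transforms.Normalize are the mean and standard deviation of the MNIST training pixels (roughly 0.1307 and 0.3081). As a rough sketch of where they come from, assuming the imports above, they can be recomputed from the raw training images (the variable names here are just for illustration):

raw_train = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
pixels = torch.stack([img for img, _ in raw_train])  # shape: 60000 x 1 x 28 x 28
print(pixels.mean().item(), pixels.std().item())     # approximately 0.1307 and 0.3081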
# Define the model
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        # input: 1 x 28 x 28
        self.conv1 = nn.Conv2d(1, 10, 5)   # -> 10 x 24 x 24
        self.conv2 = nn.Conv2d(10, 20, 3)  # -> 20 x 10 x 10
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        in_size = x.size(0)
        out = self.conv1(x)            # 24 x 24
        out = F.relu(out)
        out = F.max_pool2d(out, 2, 2)  # 12 x 12
        out = self.conv2(out)          # 10 x 10
        out = F.relu(out)
        out = out.view(in_size, -1)    # flatten to 1D so the features can go through the fully connected layers
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.log_softmax(out, dim=1)
        return out
model = ConvNet().to(DEVICE)
optimizer = optim.Adam(model.parameters())
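The 20*10*10 input size of fc1 follows from the shape comments above: 28x28 shrinks to 24x24 after the 5x5 conv, to 12x12 after pooling, and to 10x10 after the 3x3 conv. As an optional sanity check (not part of the original script), a dummy batch can be pushed through the convolutional part, assuming the model defined above:

with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28, device=DEVICE)  # one fake MNIST image
    feat = F.relu(model.conv1(dummy))                  # 1 x 10 x 24 x 24
    feat = F.max_pool2d(feat, 2, 2)                    # 1 x 10 x 12 x 12
    feat = F.relu(model.conv2(feat))                   # 1 x 20 x 10 x 10
    print(feat.shape)  # torch.Size([1, 20, 10, 10]) -> 20*10*10 = 2000 features feed fc1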
# Training procedure
'''
1. Compute the loss: feed the images and labels through the network (forward pass) to get predictions, then evaluate the loss function;
2. optimizer.zero_grad() clears the gradients accumulated in previous steps;
3. loss.backward() back-propagates and computes the current gradients;
4. optimizer.step() updates the network parameters from those gradients.
Reference: https://www.zhihu.com/question/303070254/answer/573037166
'''
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()  # clear the old gradients; see https://www.zhihu.com/question/303070254 for why this pattern is useful
        output = model(data)
        loss = F.nll_loss(output, target)  # built-in negative log-likelihood loss
        loss.backward()   # back-propagate to compute the current gradients
        optimizer.step()  # update the network parameters from the gradients
        if (batch_idx + 1) % 30 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))  # .item() extracts a Python number from a single-element tensor
# Evaluation procedure
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)  # forward pass to get predictions
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum the loss over the batch
            pred = output.max(1, keepdim=True)[1]  # index of the highest log-probability (same as argmax over dim 1)
            correct += pred.eq(target.view_as(pred)).sum().item()  # count correct predictions
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
# Run the training; this could also be wrapped in a main() function
for epoch in range(1, EPOCHS + 1):
    train(model, DEVICE, train_loader, optimizer, epoch)
    test(model, DEVICE, test_loader)
The code itself is fairly simple:
1. data loading;
2. building the network model ConvNet(nn.Module);
3. the training function train;
4. the test function test;
5. a note on model.train() and model.eval() in the code:
See the discussion of setting an instantiated model to train/eval mode when training and testing with PyTorch:
eval means evaluation mode and train means training mode. The switch only makes a difference when the model contains Dropout or BatchNorm layers: both are active during training, while at test time Dropout is normally disabled and BatchNorm falls back on the statistics accumulated during training, so the model should be put into evaluation mode before testing.
(At training time, μ and σ² are computed over the whole mini-batch, which contains something like 64 or 128 or some other number of examples. At test time you may need to process examples one at a time, so μ and σ² are instead estimated from the training set. There are many ways to do this: in theory you could run the final network over the entire training set to obtain μ and σ², but in practice one usually tracks them with an exponentially weighted average, sometimes called a running or moving average, of the μ and σ² values seen during training, and then uses those estimates at test time to normalize the hidden-unit values. In practice this procedure is fairly robust however you estimate μ and σ², and if you use a deep learning framework it will usually have a default way of estimating them that works just as well.)
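A minimal sketch of what the mode switch changes in PyTorch, using a small throwaway module (not the ConvNet above) containing one BatchNorm and one Dropout layer, assuming the imports from the script above:

bn_demo = nn.Sequential(nn.BatchNorm1d(4), nn.Dropout(p=0.5))
x = torch.randn(8, 4)
bn_demo.train()                 # training mode: batch statistics are used, running averages are updated, dropout is active
_ = bn_demo(x)
print(bn_demo[0].running_mean)  # running mean has moved toward this batch's statistics
bn_demo.eval()                  # evaluation mode: running statistics are used, dropout is a no-op
y1, y2 = bn_demo(x), bn_demo(x)
print(torch.equal(y1, y2))      # True: outputs are deterministic in eval mode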
6. The loss function: F.nll_loss is one of torch's built-in loss functions; the negative log-likelihood is used here, on top of the log_softmax output of the network. For the derivation, see a write-up on the negative log-likelihood.
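Since the network ends in F.log_softmax, F.nll_loss simply picks out (and negates, then averages) the log-probability of the true class; the combination is equivalent to F.cross_entropy applied to the raw scores. A small sketch of that equivalence with made-up logits, assuming the imports above:

logits = torch.randn(4, 10)          # raw scores for 4 samples, 10 classes
target = torch.tensor([3, 0, 7, 1])  # true class indices
loss_a = F.nll_loss(F.log_softmax(logits, dim=1), target)  # what the script above computes
loss_b = F.cross_entropy(logits, target)                   # fused equivalent
print(torch.allclose(loss_a, loss_b))                      # True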