訓練結果
image.png
image.png
image.png
image.png
image.png
image.png
image.png
完整工程
-
工程目錄結構
image.png
代碼
import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
# ---------------------------------------------------------
# Hyper-parameters and paths (values unchanged from the original script).
TRAIN_ROOT = './data/train'
VAL_ROOT = './data/val'
SAVE_PATH = './models/model_50.pth'
NUM_CLASSES = 2
BATCH_SIZE = 4
NUM_WORKERS = 4
EPOCH_NUMS = 50
LOG_EVERY = 10  # evaluate on the validation set / print every LOG_EVERY batches


def build_model(num_classes=NUM_CLASSES):
    """Return a pretrained AlexNet whose final layer has ``num_classes`` outputs.

    Args:
        num_classes: number of output classes of the new classifier head.
    """
    # NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
    # favour of ``weights=models.AlexNet_Weights.DEFAULT``; kept here so the
    # script still runs on older torchvision releases.
    model = models.alexnet(pretrained=True)
    # Replace the last classifier layer with a 4096 -> num_classes head.
    model.classifier[6] = nn.Linear(in_features=4096, out_features=num_classes)
    return model


def build_dataloaders(train_root=TRAIN_ROOT, val_root=VAL_ROOT,
                      batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
    """Build the training and validation loaders from ImageFolder-style directories.

    Returns:
        ``(train_dataloader, val_dataloader)``; only the training loader shuffles.
    """
    # NOTE(review): images are not normalised with the mean/std the pretrained
    # weights were trained with. Adding ``transforms.Normalize`` may help, but
    # any inference code would then have to apply the same transform.
    transform = transforms.Compose([transforms.Resize((227, 227)),
                                    transforms.ToTensor()])
    train_dataset = ImageFolder(root=train_root, transform=transform)
    val_dataset = ImageFolder(root=val_root, transform=transform)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                                  num_workers=num_workers, shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                                num_workers=num_workers, shuffle=False)
    return train_dataloader, val_dataloader


def evaluate(model, dataloader, device):
    """Return the top-1 accuracy of ``model`` over every batch in ``dataloader``.

    The model is switched to eval mode with autograd disabled, and its previous
    train/eval mode is restored afterwards.

    Args:
        model: classifier producing one score per class along dim 1.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``correct / total`` as a float, or ``0.0`` if ``dataloader`` is empty.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # validation needs no autograd graph; saves memory
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, prediction = torch.max(outputs, 1)
            correct += (prediction == labels).sum().item()
            total += labels.size(0)
    model.train(was_training)
    return correct / total if total else 0.0


def train(model, train_dataloader, val_dataloader, optimizer, loss_fc, scheduler,
          device, epoch_nums=EPOCH_NUMS, log_every=LOG_EVERY):
    """Train ``model`` and return ``(best_model_wts, best_acc)``.

    Every ``log_every`` training batches the model is evaluated on the whole
    validation set, and the mean loss of those batches and the accuracy are
    printed. A deep copy of the weights is kept whenever validation accuracy
    improves. The LR scheduler steps once at the end of each epoch.
    """
    # Deep-copy so the "best" snapshot is not mutated by later optimizer steps.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0
    for epoch in range(epoch_nums):
        model.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_dataloader):
            # Tensor.to() is not in-place: the result must be re-assigned.
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fc(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                accuracy = evaluate(model, val_dataloader, device)
                print('[{}, {}] running loss={:.5f}, accuracy={:.5f}'.format(
                    epoch + 1, i + 1, running_loss / log_every, accuracy))
                running_loss = 0.0
                if accuracy > best_acc:
                    best_acc = accuracy
                    best_model_wts = copy.deepcopy(model.state_dict())
        # Since PyTorch 1.1 the scheduler must step after the epoch's optimizer
        # steps; stepping at the start of the epoch decayed the LR one epoch early.
        scheduler.step()
    return best_model_wts, best_acc


def main():
    """Fine-tune AlexNet for NUM_CLASSES classes and save the best weights to SAVE_PATH."""
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = build_model().to(device)
    train_dataloader, val_dataloader = build_dataloaders()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    loss_fc = nn.CrossEntropyLoss()
    # Multiply the learning rate by 0.1 every 20 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, 20, 0.1)
    best_model_wts, _ = train(model, train_dataloader, val_dataloader, optimizer,
                              loss_fc, scheduler, device)
    print('Train finish')
    # torch.save does not create missing parent directories.
    save_dir = os.path.dirname(SAVE_PATH)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(best_model_wts, SAVE_PATH)


# The guard is required because DataLoader workers (num_workers > 0) started
# with the "spawn" method re-import this module; without it they re-run training.
if __name__ == '__main__':
    main()