import torch
from IPython import display
from matplotlib import pyplot as plt
import random
import sys
import torchvision
import torchvision.transforms as transforms
from torch import nn
def use_svg_display():
    # display plots as vector graphics (SVG)
    display.set_matplotlib_formats('svg')
def set_figsize(figsize=(3.5, 2.5)):
    use_svg_display()
    # set the figure size
    plt.rcParams['figure.figsize'] = figsize
# Return the features and labels of batch_size random examples at a time.
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # examples are read in random order
    for i in range(0, num_examples, batch_size):
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)])  # the last batch may be smaller than batch_size
        yield features.index_select(0, j), labels.index_select(0, j)
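# Usage sketch (not part of the original utilities): iterate over random
# minibatches of a small synthetic dataset. The shapes and values below are
# illustrative assumptions, not taken from the original text.
def _demo_data_iter():
    features = torch.randn(10, 2)   # 10 examples, 2 features each
    labels = torch.randn(10)        # 10 scalar labels
    for X, y in data_iter(batch_size=4, features=features, labels=labels):
        print(X.shape, y.shape)     # the last batch holds only 2 examples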
def linreg(X, w, b):  # the linear regression model
    return torch.mm(X, w) + b
def squared_loss(y_hat, y):  # the squared loss function
    # note that this returns a vector; also, PyTorch's MSELoss does not divide by 2
    return (y_hat - y.view(y_hat.size())) ** 2 / 2
def sgd(params, lr, batch_size):
    # Minibatch stochastic gradient descent: iteratively update the model parameters
    # to reduce the loss. The gradient computed by autograd is the sum over a batch
    # of examples, so we divide by the batch size to get the average.
    for param in params:
        param.data -= lr * param.grad / batch_size  # param.data is modified so the update is not tracked by autograd
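# Minimal end-to-end sketch combining data_iter, linreg, squared_loss and sgd on
# synthetic data y = Xw + b + noise. The true parameters and hyperparameters here
# are illustrative assumptions, not values prescribed by the original file.
def _demo_linreg_training():
    true_w, true_b = torch.tensor([[2.0], [-3.4]]), 4.2
    features = torch.randn(1000, 2)
    labels = torch.mm(features, true_w) + true_b + 0.01 * torch.randn(1000, 1)
    w = torch.zeros((2, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    lr, num_epochs, batch_size = 0.03, 3, 10
    for epoch in range(num_epochs):
        for X, y in data_iter(batch_size, features, labels):
            l = squared_loss(linreg(X, w, b), y).sum()  # minibatch loss (a scalar)
            l.backward()                                # compute gradients
            sgd([w, b], lr, batch_size)                 # update parameters in place
            w.grad.data.zero_()                         # reset gradients for the next step
            b.grad.data.zero_()
        epoch_loss = squared_loss(linreg(features, w, b), labels).mean().item()
        print('epoch %d, loss %f' % (epoch + 1, epoch_loss))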
# Map the dataset's numeric labels to text label names.
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
# Plot several images and their corresponding labels in a single row.
def show_fashion_mnist(images, labels):
    use_svg_display()
    # here _ denotes a variable we ignore (do not use)
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()
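# Usage sketch: label a few raw Fashion-MNIST examples and plot them in one row.
# It assumes the dataset files already sit under the same root directory used by
# load_data_fashion_mnist below; set download=True on a machine without them.
def _demo_show_fashion_mnist():
    mnist_train = torchvision.datasets.FashionMNIST(
        root='D:/workspace/pytorch-test', train=True, download=False,
        transform=transforms.ToTensor())
    X, y = [], []
    for i in range(10):
        X.append(mnist_train[i][0])   # image tensor of shape (1, 28, 28)
        y.append(mnist_train[i][1])   # integer class index
    show_fashion_mnist(X, get_fashion_mnist_labels(y))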
def load_data_fashion_mnist(batch_size=256):
    mnist_train = torchvision.datasets.FashionMNIST(root='D:/workspace/pytorch-test', train=True, download=False, transform=transforms.ToTensor())
    mnist_test = torchvision.datasets.FashionMNIST(root='D:/workspace/pytorch-test', train=False, download=False, transform=transforms.ToTensor())
    print(len(mnist_train), len(mnist_test))
    if sys.platform.startswith('win'):
        num_workers = 0  # 0 means no extra worker processes are used to speed up data loading
    else:
        num_workers = 4
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_iter, test_iter
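# Usage sketch: fetch the data iterators and inspect the shape of one batch.
# The batch size of 256 is simply the function's default, repeated for clarity.
def _demo_load_data_fashion_mnist():
    train_iter, test_iter = load_data_fashion_mnist(batch_size=256)
    for X, y in train_iter:
        print(X.shape, y.shape)   # expected: torch.Size([256, 1, 28, 28]) torch.Size([256])
        break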
def evaluate_accuracy(data_iter, net):
    # compute the classification accuracy of net over all batches of data_iter
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n
def train_softmax(net, train_iter, test_iter, loss, num_epochs, batch_size,
                  params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()
            # zero the gradients
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            l.backward()
            if optimizer is None:
                sgd(params, lr, batch_size)
            else:
                optimizer.step()  # used in the "concise implementation of softmax regression" section
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))
class FlattenLayer(nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x):  # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)
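# Minimal end-to-end sketch: softmax regression on Fashion-MNIST built from the
# helpers above. The learning rate and epoch count are illustrative assumptions,
# not values specified by the original file.
def _demo_train_softmax():
    batch_size, num_epochs, lr = 256, 5, 0.1
    train_iter, test_iter = load_data_fashion_mnist(batch_size)
    net = nn.Sequential(FlattenLayer(), nn.Linear(28 * 28, 10))
    nn.init.normal_(net[1].weight, mean=0, std=0.01)
    nn.init.constant_(net[1].bias, val=0)
    loss = nn.CrossEntropyLoss()   # note: this loss already averages over the batch
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    train_softmax(net, train_iter, test_iter, loss, num_epochs, batch_size,
                  optimizer=optimizer)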