d2lzh_pytorch

import torch

from IPython import display

from matplotlib import pyplot as plt

import random

import sys

import torchvision

import torchvision.transforms as transforms

from torch import nn

def use_svg_display():

? ? # 用矢量圖顯示

? ? display.set_matplotlib_formats('svg')

def set_figsize(figsize=(3.5, 2.5)):

? ? use_svg_display()

? ? # 設(shè)置圖的尺寸

? ? plt.rcParams['figure.figsize'] = figsize

# 每次返回batch_size(批量大写浇浮)個隨機(jī)樣本的特征和標(biāo)簽。

def data_iter(batch_size, features, labels):

? ? num_examples = len(features)

? ? indices = list(range(num_examples))

? ? random.shuffle(indices)? # 樣本的讀取順序是隨機(jī)的

? ? for i in range(0, num_examples, batch_size):

? ? ? ? j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)]) # 最后一次可能不足一個batch

? ? ? ? yield? features.index_select(0, j), labels.index_select(0, j)

def linreg(X, w, b):? # 定義線性模型

? ? return torch.mm(X, w) + b

def squared_loss(y_hat, y):? # 定義損失函數(shù)

? ? # 注意這里返回的是向量, 另外, pytorch里的MSELoss并沒有除以 2

? ? return (y_hat - y.view(y_hat.size())) ** 2 / 2

def sgd(params, lr, batch_size):? # 定義優(yōu)化算法玩荠,它通過不斷迭代模型參數(shù)來優(yōu)化損失函數(shù)存捺。這里自動求梯度模塊計算得來的梯度是一個批量樣本的梯度和。我們將它除以批量大小來得到平均值。

? ? for param in params:

? ? ? ? param.data -= lr * param.grad / batch_size # 注意這里更改param時用的param.data

# 讀取數(shù)據(jù)集的label名稱

def get_fashion_mnist_labels(labels):

? ? text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',

? ? ? ? ? ? ? ? ? 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

? ? return [text_labels[int(i)] for i in labels]

# 在一行里畫出多張圖像和對應(yīng)標(biāo)簽

def show_fashion_mnist(images, labels):

? ? use_svg_display()

? ? # 這里的_表示我們忽略(不使用)的變量

? ? _, figs = plt.subplots(1, len(images), figsize=(12, 12))

? ? for f, img, lbl in zip(figs, images, labels):

? ? ? ? f.imshow(img.view((28, 28)).numpy())

? ? ? ? f.set_title(lbl)

? ? ? ? f.axes.get_xaxis().set_visible(False)

? ? ? ? f.axes.get_yaxis().set_visible(False)

? ? plt.show()

def load_data_fashion_mnist(batch_size=256):

mnist_train = torchvision.datasets.FashionMNIST(root='D:/workspace/pytorch-test', train=True, download=False, transform=transforms.ToTensor())

mnist_test = torchvision.datasets.FashionMNIST(root='D:/workspace/pytorch-test', train=False, download=False, transform=transforms.ToTensor())

print(len(mnist_train), len(mnist_test))

if sys.platform.startswith('win'):

num_workers = 0? # 0表示不用額外的進(jìn)程來加速讀取數(shù)據(jù)

else:

num_workers = 4

train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)

test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

return train_iter, test_iter

def evaluate_accuracy(data_iter, net):

? ? acc_sum, n = 0.0, 0

? ? for X, y in data_iter:

? ? ? ? acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()

? ? ? ? n += y.shape[0]

? ? return acc_sum / n

def train_softmax(net, train_iter, test_iter, loss, num_epochs, batch_size,

? ? ? ? ? ? ? params=None, lr=None, optimizer=None):

? ? for epoch in range(num_epochs):

? ? ? ? train_l_sum, train_acc_sum, n = 0.0, 0.0, 0

? ? ? ? for X, y in train_iter:

? ? ? ? ? ? y_hat = net(X)

? ? ? ? ? ? l = loss(y_hat, y).sum()

? ? ? ? ? ? # 梯度清零

? ? ? ? ? ? if optimizer is not None:

? ? ? ? ? ? ? ? optimizer.zero_grad()

? ? ? ? ? ? elif params is not None and params[0].grad is not None:

? ? ? ? ? ? ? ? for param in params:

? ? ? ? ? ? ? ? ? ? param.grad.data.zero_()

? ? ? ? ? ? l.backward()

? ? ? ? ? ? if optimizer is None:

? ? ? ? ? ? ? ? sgd(params, lr, batch_size)

? ? ? ? ? ? else:

? ? ? ? ? ? ? ? optimizer.step()? # “softmax回歸的簡潔實(shí)現(xiàn)”一節(jié)將用到

? ? ? ? ? ? train_l_sum += l.item()

? ? ? ? ? ? train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()

? ? ? ? ? ? n += y.shape[0]

? ? ? ? test_acc = evaluate_accuracy(test_iter, net)

? ? ? ? print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'

? ? ? ? ? ? ? % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

class FlattenLayer(nn.Module):

? ? def __init__(self):

? ? ? ? super(FlattenLayer, self).__init__()

? ? def forward(self, x): # x shape: (batch, *, *, ...)

? ? ? ? return x.view(x.shape[0], -1)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市住册,隨后出現(xiàn)的幾起案子鱼响,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 222,807評論 6 518
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異,居然都是意外死亡卿捎,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 95,284評論 3 399
  • 文/潘曉璐 我一進(jìn)店門径密,熙熙樓的掌柜王于貴愁眉苦臉地迎上來午阵,“玉大人,你說我怎么就攤上這事睹晒√俗” “怎么了?”我有些...
    開封第一講書人閱讀 169,589評論 0 363
  • 文/不壞的土叔 我叫張陵伪很,是天一觀的道長戚啥。 經(jīng)常有香客問我,道長锉试,這世上最難降的妖魔是什么猫十? 我笑而不...
    開封第一講書人閱讀 60,188評論 1 300
  • 正文 為了忘掉前任,我火速辦了婚禮,結(jié)果婚禮上拖云,老公的妹妹穿的比我還像新娘贷笛。我一直安慰自己,他們只是感情好宙项,可當(dāng)我...
    茶點(diǎn)故事閱讀 69,185評論 6 398
  • 文/花漫 我一把揭開白布乏苦。 她就那樣靜靜地躺著,像睡著了一般尤筐。 火紅的嫁衣襯著肌膚如雪汇荐。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 52,785評論 1 314
  • 那天盆繁,我揣著相機(jī)與錄音掀淘,去河邊找鬼。 笑死油昂,一個胖子當(dāng)著我的面吹牛革娄,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播冕碟,決...
    沈念sama閱讀 41,220評論 3 423
  • 文/蒼蘭香墨 我猛地睜開眼拦惋,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了鸣哀?” 一聲冷哼從身側(cè)響起架忌,我...
    開封第一講書人閱讀 40,167評論 0 277
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎我衬,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體饰恕,經(jīng)...
    沈念sama閱讀 46,698評論 1 320
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡挠羔,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,767評論 3 343
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了埋嵌。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片破加。...
    茶點(diǎn)故事閱讀 40,912評論 1 353
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖雹嗦,靈堂內(nèi)的尸體忽然破棺而出范舀,到底是詐尸還是另有隱情,我是刑警寧澤了罪,帶...
    沈念sama閱讀 36,572評論 5 351
  • 正文 年R本政府宣布锭环,位于F島的核電站,受9級特大地震影響泊藕,放射性物質(zhì)發(fā)生泄漏辅辩。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 42,254評論 3 336
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望玫锋。 院中可真熱鬧蛾茉,春花似錦、人聲如沸撩鹿。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,746評論 0 25
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽节沦。三九已至键思,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間散劫,已是汗流浹背稚机。 一陣腳步聲響...
    開封第一講書人閱讀 33,859評論 1 274
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留获搏,地道東北人赖条。 一個月前我還...
    沈念sama閱讀 49,359評論 3 379
  • 正文 我出身青樓,卻偏偏與公主長得像常熙,于是被迫代替她去往敵國和親纬乍。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,922評論 2 361