這篇文章需要大家對深度學(xué)習(xí)里的神經(jīng)網(wǎng)絡(luò)訓(xùn)練有一定的基礎(chǔ)厢蒜,我以前訓(xùn)練網(wǎng)絡(luò)一直都是用的TensorFlow叁征,后面需要把模型和數(shù)據(jù)遷移到Pytorch平臺上去,發(fā)現(xiàn)很多里面有很多知識點需要注意,寫這篇文章一方面是給自己做個筆記梯找,總結(jié)下自己的經(jīng)驗情组,另一方面是為了方便想要快速上手Pytorch的同學(xué)燥筷。這篇文章主要內(nèi)容有:
- Tensorflow的PlayGround
- Pytorch介紹和安裝
- Torch和Torchvision里的常用包
- Variable、Tensor院崇、Numpy之間的關(guān)系
- CPU與GPU
- 示例--GAN生成MINIST數(shù)據(jù)
Tensorflow的PlayGround
PlayGround是一個在線演示肆氓、實驗的神經(jīng)網(wǎng)絡(luò)平臺,是一個入門神經(jīng)網(wǎng)絡(luò)非常直觀的網(wǎng)站底瓣。這個圖形化平臺非常強大谢揪,將神經(jīng)網(wǎng)絡(luò)的訓(xùn)練過程直接可視化。假若有的同學(xué)剛剛想入門深度學(xué)習(xí)這一領(lǐng)域捐凭,可以去看看:
PlayGround地址:http://playground.tensorflow.org
這里也有一篇PlayGround介紹寫的非常詳細的文章:
參考地址:https://finthon.com/tensorflow-playground-nn/
Pytorch介紹和安裝
2017年1月拨扶,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。Pytorch和Torch底層實現(xiàn)都用的是C語言茁肠,但是Torch的調(diào)用需要掌握Lua語言患民,相比而言使用Python的人更多,根本不是一個數(shù)量級垦梆,所以Pytorch基于Torch做了些底層修改匹颤、優(yōu)化并且支持Python語言調(diào)用。
它是一個基于Python的可續(xù)計算包托猩,目標(biāo)用戶有兩類:
- 使用GPU來運算numpy
- 一個深度學(xué)習(xí)平臺印蓖,提供最大的靈活型和速度
如何安裝Pytorch呢?
- 基礎(chǔ)環(huán)境
一臺PC設(shè)備京腥、一張高性能NVIDIA顯卡(可選)赦肃、Ubuntu系統(tǒng) - 安裝步驟
- Anaconda(可選)和Python
- 顯卡驅(qū)動和CUDA
- 運行Pytorch的安裝命令
- 相關(guān)資料
詳細的安裝教程https://blog.csdn.net/zzlyw/article/details/78674543
Pytorch中文網(wǎng)https://www.pytorchtutorial.com
Torch和Torchvision里的常用包
Torch
torch
:張量相關(guān)的運算,例如創(chuàng)建
公浪、索引
他宛、切片
、連接
因悲、轉(zhuǎn)置
堕汞、加減乘除
等torch.nn
:包含搭建網(wǎng)絡(luò)層的模塊(Modules)和一系列的loss函數(shù),例如全連接
晃琳、卷積
讯检、池化
琐鲁、BN批處理
、dropout
人灼、CrossEntropyLoss
围段、MSELoss
等torch.nn.functional
:常用的激活函數(shù)relu
、leaky_relu
投放、sigmoid
等torch.autograd
:提供Tensor所有操作的自動求導(dǎo)方法torch.optim
:各種參數(shù)優(yōu)化方法奈泪,例如SGD
、AdaGrad
灸芳、RMSProp
涝桅、Adam
等torch.nn.init
:可以用它更改nn.Module
的默認參數(shù)初始化方式torch.utils.data
:用于加載數(shù)據(jù)
Torchvision
torchvision.datasets
:常用數(shù)據(jù)集,MNIST
烙样、COCO
冯遂、CIFAR10
、Imagenet
等torchvision.models
:常用模型谒获,AlextNet
蛤肌、VGG
、ResNet
批狱、DenseNet
等torchvision.transforms
:圖片相關(guān)處理裸准,裁剪
、尺寸縮放
赔硫、歸一化
等torchvision.utils
:將給定的Tensor保存成image文件
Variable炒俱、Tensor、Numpy之間的關(guān)系
- Numpy
NumPy是Python語言的一個擴充程序庫卦停。支持高級大量的維度數(shù)組與矩陣運算,此外也針對數(shù)組運算提供大量的數(shù)學(xué)函數(shù)庫向胡。
例子:
>>> import numpy as np
>>> x=np.array([[1,2,3],[9,8,7],[6,5,4]])
-
Tensor
PyTorch 提供一種類似 NumPy 的抽象方法來表征張量(或多維數(shù)組)恼蓬,它可以利用 GPU 來加速訓(xùn)練惊完。
-
Variable
- PyTorch 張量的簡單封裝
- 幫助建立計算圖
- Autograd(自動微分庫)的必要部分
- 將關(guān)于這些變量的梯度保存在 .grad 中
- Tensor、Variable处硬、Numpy之間相互轉(zhuǎn)化
- 將Numpy矩陣轉(zhuǎn)換為Tensor張量
sub_ts = torch.from_numpy(sub_img)
- 將Tensor張量轉(zhuǎn)化為Numpy矩陣
sub_np1 = sub_ts.numpy()
- 將Tensor轉(zhuǎn)換為Variable
sub_va = Variable(sub_ts)
- 將Variable轉(zhuǎn)換為Tensor
sub_np2 = sub_va.data
CPU與GPU
Pytorch支持CPU運行小槐,但是速度非常慢,一張好的NVIDIA顯卡能夠大大減少網(wǎng)絡(luò)訓(xùn)練時間荷辕,以我自己經(jīng)驗來看凿跳,15年MacBook Pro 與戴爾工作站附加一張顯存11GB的1080ti顯卡相比,后者速度是前者速度的224倍疮方,尤其訓(xùn)練復(fù)雜網(wǎng)絡(luò)一定要在GPU上跑控嗜。Pytorch中把數(shù)據(jù)和模型從CPU遷移到GPU非常簡單:
直接對變量、張量骡显、模型使用.cuda()
即可把他們遷移到GPU上疆栏,反過來遷移到CPU上曾掂,使用.cpu()
。
當(dāng)有多行顯卡時壁顶,想充分利用它們珠洗,則可使用model = nn.DataParallel(model)
命令:
常見問題
- 這里的不同位置包含GPU與CPU,還包含不同GPU之間
- 不同位置的
Variable
之間不能直接相互運算 - 不同位置的
Tensor
直接不能直接相互運算 - 不同位置的
Variable
和模型
不能直接訓(xùn)練 - 使用指定顯卡:
.cuda(<顯卡號數(shù)>)
示例--GAN生成MINIST數(shù)據(jù)
最后看個實例若专,如何使用GAN網(wǎng)絡(luò)生成MINIST 數(shù)據(jù)许蓖,主要內(nèi)容有:
MNIST數(shù)據(jù)集
MNIST數(shù)據(jù)集是一個手寫體數(shù)據(jù)集,圖片大小都是28x28调衰,包含0-9共10個數(shù)字膊爪,各種風(fēng)格:
下載好的數(shù)據(jù)集:
測試集t10k
開頭,訓(xùn)練集train
開頭嚎莉,images
是圖片蚁飒,labels
是標(biāo)簽
GAN網(wǎng)絡(luò)模型
輸入100長度的噪聲向量,經(jīng)過一個全連接萝喘,兩個卷積層淮逻,一個下采樣之后生成成28x28大小的圖片,這一部分是生成器
生成的假圖片和MNIST里的真圖片經(jīng)過兩個卷積層下采樣之后阁簸,再次經(jīng)歷兩個全連接層后輸出一個1長度的單位向量爬早,
1
代表輸入圖片為真,0
代表輸入圖片為假
GAN訓(xùn)練和Loss
訓(xùn)練判別器D時启妹,要使得V整體變大筛严,訓(xùn)練生成器G時,要使得V整體變小饶米。
這是一個博弈的過程桨啃,就像制造假錢的犯罪團伙和驗鈔機的關(guān)系,犯罪團伙需要努力提高技術(shù)檬输,讓驗鈔機無法識別出來其制造的假幣照瘾,而驗鈔機要能夠正確的分辨出真正的紙幣還有假幣。
理論上當(dāng)判別器D只有一半的概率
0.5
能識別出假圖片時丧慈,就已經(jīng)收斂了析命,實際上達不到一半的概率,沒關(guān)系逃默,使得假圖片概率盡量高就行了鹃愤,最終看上去效果不錯。這是一張由生成器生成的假圖片完域,你能區(qū)分出來嗎软吐?
可視化
可視化方式有兩種,一種是利用torchvision
里面的包 torchvision.utils
吟税,另外一種是利用visdom
插件凹耙,下面是二者的對比:
上面那張生成的假圖片就是利用
torchvision.utils
里的save_image函數(shù)來存儲在本地的鸟蟹。而以下這張圖是利用
visdom
,在瀏覽器中查看到的效果:visdom
不光可以查看圖片使兔,還可以查看loss變化曲線圖等各種功能建钥。
具體的代碼實現(xiàn)去工程里查看,這里給出分享地址:
https://github.com/gcfrun/GAN_MNIST_Pytorch
mnist_data.py
:數(shù)據(jù)輸入模塊
mnist_net.py
:網(wǎng)絡(luò)模型模塊
mnist_loss.py
:Loss計算模塊
mnist_train.py
:迭代訓(xùn)練模塊
mnist_visual.py
:可視化模塊