通過前面幾章的學(xué)習(xí)轻局,我們已經(jīng)掌握了PyTorch中大部分的基礎(chǔ)知識,本章將結(jié)合之前講的內(nèi)容,帶領(lǐng)讀者從頭實(shí)現(xiàn)一個(gè)完整的深度學(xué)習(xí)項(xiàng)目。本章的重點(diǎn)不在于如何使用PyTorch的接口,而在于合理地設(shè)計(jì)程序的結(jié)構(gòu)腋粥,使得程序更具可讀性晦雨、更易用。
6.1 編程實(shí)戰(zhàn):貓和狗二分類
在學(xué)習(xí)某個(gè)深度學(xué)習(xí)框架時(shí)隘冲,掌握其基本知識和接口固然重要闹瞧,但如何合理組織代碼,使得代碼具有良好的可讀性和可擴(kuò)展性也必不可少展辞。本文不會深入講解過多知識性的東西奥邮,更多的則是傳授一些經(jīng)驗(yàn),這些內(nèi)容可能有些爭議,因其受我個(gè)人喜好和coding風(fēng)格影響較大洽腺,讀者可以將這部分當(dāng)成是一種參考或提議脚粟,而不是作為必須遵循的準(zhǔn)則。歸根到底蘸朋,都是希望你能以一種更為合理的方式組織自己的程序核无。
在做深度學(xué)習(xí)實(shí)驗(yàn)或項(xiàng)目時(shí),為了得到最優(yōu)的模型結(jié)果藕坯,中間往往需要很多次的嘗試和修改团南。而合理的文件組織結(jié)構(gòu),以及一些小技巧可以極大地提高代碼的易讀易用性炼彪。根據(jù)筆者的個(gè)人經(jīng)驗(yàn)吐根,在從事大多數(shù)深度學(xué)習(xí)研究時(shí),程序都需要實(shí)現(xiàn)以下幾個(gè)功能:
- 模型定義
- 數(shù)據(jù)處理和加載
- 訓(xùn)練模型(Train&Validate)
- 訓(xùn)練過程的可視化
- 測試(Test/Inference)
另外程序還應(yīng)該滿足以下幾個(gè)要求:
- 模型需具有高度可配置性辐马,便于修改參數(shù)拷橘、修改模型,反復(fù)實(shí)驗(yàn)齐疙。
- 代碼應(yīng)具有良好的組織結(jié)構(gòu)膜楷,使人一目了然。
- 代碼應(yīng)具有良好的說明贞奋,使其他人能夠理解赌厅。
在之前的章節(jié)中,我們已經(jīng)講解了PyTorch中的絕大部分內(nèi)容轿塔。本章我們將應(yīng)用這些內(nèi)容特愿,并結(jié)合實(shí)際的例子,來講解如何用PyTorch完成Kaggle上的經(jīng)典比賽:Dogs vs. Cats勾缭。本文所有示例程序均在github上開源 揍障。
6.1.1 比賽介紹
Dogs vs. Cats是一個(gè)傳統(tǒng)的二分類問題,其訓(xùn)練集包含25000張圖片俩由,均放置在同一文件夾下毒嫡,命名格式為<category>.<num>.jpg
, 如cat.10000.jpg
、dog.100.jpg
幻梯,測試集包含12500張圖片兜畸,命名為<num>.jpg
,如1000.jpg
碘梢。參賽者需根據(jù)訓(xùn)練集的圖片訓(xùn)練模型咬摇,并在測試集上進(jìn)行預(yù)測,輸出它是狗的概率煞躬。最后提交的csv文件如下肛鹏,第一列是圖片的<num>
逸邦,第二列是圖片為狗的概率。
id,label
10001,0.889
10002,0.01
...
6.1.2 文件組織架構(gòu)
前面提到過在扰,程序主要包含以下功能:
- 模型定義
- 數(shù)據(jù)加載
- 訓(xùn)練和測試
首先來看程序文件的組織結(jié)構(gòu):
├── checkpoints/
├── data/
│ ├── __init__.py
│ ├── dataset.py
│ └── get_data.sh
├── models/
│ ├── __init__.py
│ ├── AlexNet.py
│ ├── BasicModule.py
│ └── ResNet34.py
└── utils/
│ ├── __init__.py
│ └── visualize.py
├── config.py
├── main.py
├── requirements.txt
├── README.md
其中:
-
checkpoints/
: 用于保存訓(xùn)練好的模型缕减,可使程序在異常退出后仍能重新載入模型,恢復(fù)訓(xùn)練健田。 -
data/
:數(shù)據(jù)相關(guān)操作烛卧,包括數(shù)據(jù)預(yù)處理、dataset實(shí)現(xiàn)等妓局。 -
models/
:模型定義总放,可以有多個(gè)模型,例如上面的AlexNet和ResNet34好爬,一個(gè)模型對應(yīng)一個(gè)文件局雄。 -
utils/
:可能用到的工具函數(shù),在本次實(shí)驗(yàn)中主要是封裝了可視化工具存炮。 -
config.py
:配置文件炬搭,所有可配置的變量都集中在此,并提供默認(rèn)值穆桂。 -
main.py
:主文件宫盔,訓(xùn)練和測試程序的入口,可通過不同的命令來指定不同的操作和參數(shù)享完。 -
requirements.txt
:程序依賴的第三方庫灼芭。 -
README.md
:提供程序的必要說明。
6.1.3 關(guān)于init.py
可以看到般又,幾乎每個(gè)文件夾下都有__init__.py
彼绷,一個(gè)目錄如果包含了__init__.py
文件,那么它就變成了一個(gè)包(package)茴迁。__init__.py
可以為空寄悯,也可以定義包的屬性和方法,但其必須存在堕义,其它程序才能從這個(gè)目錄中導(dǎo)入相應(yīng)的模塊或函數(shù)猜旬。例如在data/
文件夾下有__init__.py
,則在main.py
中就可以from data.dataset import DogCat
倦卖。而如果在__init__.py
中寫入from .dataset import DogCat
洒擦,則在main.py中就可以直接寫為:from data import DogCat
,或者import data; dataset = data.DogCat
糖耸,相比于from data.dataset import DogCat
更加便捷。
6.1.4 數(shù)據(jù)加載
數(shù)據(jù)的相關(guān)處理主要保存在data/dataset.py
中丘薛。關(guān)于數(shù)據(jù)加載的相關(guān)操作嘉竟,在上一章中我們已經(jīng)提到過,其基本原理就是使用Dataset
提供數(shù)據(jù)集的封裝,再使用Dataloader
實(shí)現(xiàn)數(shù)據(jù)并行加載舍扰。Kaggle提供的數(shù)據(jù)包括訓(xùn)練集和測試集倦蚪,而我們在實(shí)際使用中,還需專門從訓(xùn)練集中取出一部分作為驗(yàn)證集边苹。對于這三類數(shù)據(jù)集陵且,其相應(yīng)操作也不太一樣,而如果專門寫三個(gè)Dataset
个束,則稍顯復(fù)雜和冗余慕购,因此這里通過加一些判斷來區(qū)分。對于訓(xùn)練集茬底,我們希望做一些數(shù)據(jù)增強(qiáng)處理沪悲,如隨機(jī)裁剪、隨機(jī)翻轉(zhuǎn)阱表、加噪聲等殿如,而驗(yàn)證集和測試集則不需要。下面看dataset.py
的代碼:
# coding:utf8
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
class DogCat(data.Dataset):
def __init__(self, root, transforms=None, train=True, test=False):
"""
主要目標(biāo): 獲取所有圖片的地址最爬,并根據(jù)訓(xùn)練涉馁,驗(yàn)證,測試劃分?jǐn)?shù)據(jù)
"""
self.test = test
imgs = [os.path.join(root, img) for img in os.listdir(root)]
# test1: data/test1/8973.jpg
# train: data/train/cat.10004.jpg
if self.test:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
else:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
imgs_num = len(imgs)
if self.test:
self.imgs = imgs
elif train:
self.imgs = imgs[:int(0.7 * imgs_num)]
else:
self.imgs = imgs[int(0.7 * imgs_num):]
if transforms is None:
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
if self.test or not train:
self.transforms = T.Compose([
T.Resize(224),
T.CenterCrop(224),
T.ToTensor(),
normalize
])
else:
self.transforms = T.Compose([
T.Resize(256),
T.RandomCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize
])
def __getitem__(self, index):
"""
一次返回一張圖片的數(shù)據(jù)
"""
img_path = self.imgs[index]
if self.test:
label = int(self.imgs[index].split('.')[-2].split('/')[-1])
else:
label = 1 if 'dog' in img_path.split('/')[-1] else 0
data = Image.open(img_path)
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.imgs)
關(guān)于數(shù)據(jù)集使用的注意事項(xiàng)爱致,在上一章中已經(jīng)提到烤送,將文件讀取等費(fèi)時(shí)操作放在__getitem__
函數(shù)中,利用多進(jìn)程加速蒜鸡。避免一次性將所有圖片都讀進(jìn)內(nèi)存胯努,不僅費(fèi)時(shí)也會占用較大內(nèi)存,而且不易進(jìn)行數(shù)據(jù)增強(qiáng)等操作逢防。另外在這里叶沛,我們將訓(xùn)練集中的30%作為驗(yàn)證集,可用來檢查模型的訓(xùn)練效果忘朝,避免過擬合灰署。在使用時(shí),我們可通過dataloader加載數(shù)據(jù)局嘁。
train_dataset = DogCat(opt.train_data_root, train=True)
trainloader = DataLoader(train_dataset,
batch_size = opt.batch_size,
shuffle = True,
num_workers = opt.num_workers)
for ii, (data, label) in enumerate(trainloader):
train()
6.1.5 模型定義
模型的定義主要保存在models/
目錄下溉箕,其中BasicModule
是對nn.Module
的簡易封裝,提供快速加載和保存模型的接口悦昵。
# coding:utf8
import time
import torch as t
class BasicModule(t.nn.Module):
"""
封裝了nn.Module,主要是提供了save和load兩個(gè)方法
"""
def __init__(self):
super(BasicModule, self).__init__()
self.model_name = str(type(self)) # 默認(rèn)名字
def load(self, path):
"""
可加載指定路徑的模型
"""
self.load_state_dict(t.load(path))
def save(self, name=None):
"""
保存模型肴茄,默認(rèn)使用“模型名字+時(shí)間”作為文件名
"""
if name is None:
prefix = 'checkpoints/' + self.model_name + '_'
name = time.strftime(prefix + '%Y%m%d%H%M%S.pth')
t.save(self.state_dict(), name)
return name
def get_optimizer(self, lr, weight_decay):
return t.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
class Flat(t.nn.Module):
"""
把輸入reshape成(batch_size,dim_length)
"""
def __init__(self):
super(Flat, self).__init__()
# self.size = size
def forward(self, x):
return x.view(x.size(0), -1)
在實(shí)際使用中,直接調(diào)用model.save()
及model.load(opt.load_path)
即可但指。
其它自定義模型一般繼承BasicModule
寡痰,然后實(shí)現(xiàn)自己的模型抗楔。其中AlexNet.py
實(shí)現(xiàn)了AlexNet,ResNet34
實(shí)現(xiàn)了ResNet34拦坠。在models/__init__py
中连躏,代碼如下:
from .AlexNet import AlexNet
from .ResNet34 import ResNet34
這樣在主函數(shù)中就可以寫成:
from models import AlexNet
或
import models
model = models.AlexNet()
或
import models
model = getattr('models', 'AlexNet')()
其中最后一種寫法最為關(guān)鍵,這意味著我們可以通過字符串直接指定使用的模型贞滨,而不必使用判斷語句入热,也不必在每次新增加模型后都修改代碼。新增模型后只需要在models/__init__.py
中加上from .new_module import new_module
即可晓铆。
其它關(guān)于模型定義的注意事項(xiàng)勺良,在上一章中已詳細(xì)講解,這里就不再贅述尤蒿,總結(jié)起來就是:
- 盡量使用
nn.Sequential
(比如AlexNet)郑气。 - 將經(jīng)常使用的結(jié)構(gòu)封裝成子Module(比如GoogLeNet的Inception結(jié)構(gòu),ResNet的Residual Block結(jié)構(gòu))腰池。
- 將重復(fù)且有規(guī)律性的結(jié)構(gòu)尾组,用函數(shù)生成(比如VGG的多種變體,ResNet多種變體都是由多個(gè)重復(fù)卷積層組成)示弓。
6.1.6 工具函數(shù)
在項(xiàng)目中讳侨,我們可能會用到一些helper方法,這些方法可以統(tǒng)一放在utils/
文件夾下奏属,需要使用時(shí)再引入跨跨。在本例中主要是封裝了可視化工具visdom的一些操作,其代碼如下囱皿,在本次實(shí)驗(yàn)中只會用到plot
方法勇婴,用來統(tǒng)計(jì)損失信息。
# coding:utf8
import time
import numpy as np
import visdom
class Visualizer(object):
"""
封裝了visdom的基本操作嘱腥,但是你仍然可以通過`self.vis.function`
調(diào)用原生的visdom接口
"""
def __init__(self, env='default', **kwargs):
self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs)
# 畫的第幾個(gè)數(shù)耕渴,相當(dāng)于橫座標(biāo)
# 保存(’loss',23) 即loss的第23個(gè)點(diǎn)
self.index = {}
self.log_text = ''
def reinit(self, env='default', **kwargs):
"""
修改visdom的配置
"""
self.vis = visdom.Visdom(env=env, **kwargs)
return self
def plot_many(self, d):
"""
一次plot多個(gè)
@params d: dict (name,value) i.e. ('loss',0.11)
"""
for k, v in d.items():
self.plot(k, v)
def img_many(self, d):
for k, v in d.items():
self.img(k, v)
def plot(self, name, y, **kwargs):
"""
self.plot('loss',1.00)
"""
x = self.index.get(name, 0)
self.vis.line(Y=np.array([y]), X=np.array([x]),
win=name,
opts=dict(title=name),
update=None if x == 0 else 'append',
**kwargs
)
self.index[name] = x + 1
def img(self, name, img_, **kwargs):
"""
self.img('input_img',t.Tensor(64,64))
self.img('input_imgs',t.Tensor(3,64,64))
self.img('input_imgs',t.Tensor(100,1,64,64))
self.img('input_imgs',t.Tensor(100,3,64,64),nrows=10)
!3萃谩橱脸!don‘t ~~self.img('input_imgs',t.Tensor(100,64,64),nrows=10)~~!7治添诉!
"""
self.vis.images(img_.cpu().numpy(),
win=name,
opts=dict(title=name),
**kwargs
)
def log(self, info, win='log_text'):
"""
self.log({'loss':1,'lr':0.0001})
"""
self.log_text += ('[{time}] {info} <br>'.format(
time=time.strftime('%Y%m%d %H:%M:%S'),
info=info))
self.vis.text(self.log_text, win)
def __getattr__(self, name):
return getattr(self.vis, name)
6.1.7 配置文件
在模型定義、數(shù)據(jù)處理和訓(xùn)練等過程都有很多變量医寿,這些變量應(yīng)提供默認(rèn)值栏赴,并統(tǒng)一放置在配置文件中,這樣在后期調(diào)試靖秩、修改代碼或遷移程序時(shí)會比較方便须眷,在這里我們將所有可配置項(xiàng)放在config.py
中乌叶。
# coding:utf8
import warnings
import torch as t
class DefaultConfig(object):
env = 'default' # visdom 環(huán)境
vis_port = 8097 # visdom 端口
model = 'SqueezeNet' # 使用的模型,名字必須與models/__init__.py中的名字一致
train_data_root = './data/train/' # 訓(xùn)練集存放路徑
test_data_root = './data/test/' # 測試集存放路徑
load_model_path = None # 加載預(yù)訓(xùn)練的模型的路徑柒爸,為None代表不加載
batch_size = 32 # batch size
use_gpu = True # user GPU or not
num_workers = 0 # how many workers for loading data
print_freq = 20 # print info every N batch
debug_file = './debug/debug.txt' # if os.path.exists(debug_file): enter ipdb
result_file = 'result.csv'
max_epoch = 10
lr = 0.001 # initial learning rate
lr_decay = 0.5 # when val_loss increase, lr = lr*lr_decay
weight_decay = 0e-5 # 損失函數(shù)
def _parse(self, kwargs):
"""
根據(jù)字典kwargs 更新 config參數(shù)
"""
for k, v in kwargs.items():
if not hasattr(self, k):
warnings.warn("Warning: opt has not attribut %s" % k)
setattr(self, k, v)
opt.device = t.device('cuda') if opt.use_gpu else t.device('cpu')
print('user config:')
for k, v in self.__class__.__dict__.items():
if not k.startswith('_'):
print(k, getattr(self, k))
opt = DefaultConfig()
可配置的參數(shù)主要包括:
- 數(shù)據(jù)集參數(shù)(文件路徑、batch_size等)
- 訓(xùn)練參數(shù)(學(xué)習(xí)率事扭、訓(xùn)練epoch等)
- 模型參數(shù)
這樣我們在程序中就可以這樣使用:
import models
from config import DefaultConfig
opt = DefaultConfig()
lr = opt.lr
model = getattr(models, opt.model)
dataset = DogCat(opt.train_data_root)
這些都只是默認(rèn)參數(shù)捎稚,在這里還提供了更新函數(shù),根據(jù)字典更新配置參數(shù)求橄。
def _parse(self, kwargs):
"""
根據(jù)字典kwargs 更新 config參數(shù)
"""
for k, v in kwargs.items():
if not hasattr(self, k):
warnings.warn("Warning: opt has not attribut %s" % k)
setattr(self, k, v)
opt.device = t.device('cuda') if opt.use_gpu else t.device('cpu')
print('user config:')
for k, v in self.__class__.__dict__.items():
if not k.startswith('_'):
print(k, getattr(self, k))
這樣我們在實(shí)際使用時(shí)今野,并不需要每次都修改config.py
,只需要通過命令行傳入所需參數(shù)罐农,覆蓋默認(rèn)配置即可条霜。
例如:
opt = DefaultConfig()
new_config = {'lr':0.1,'use_gpu':False}
opt.parse(new_config)
opt.lr == 0.1
6.1.8 main.py
在講解主程序main.py
之前,我們先來看看2017年3月谷歌開源的一個(gè)命令行工具fire涵亏,通過pip install fire
即可安裝宰睡。下面來看看fire
的基礎(chǔ)用法,假設(shè)example.py
文件內(nèi)容如下:
import fire
def add(x, y):
return x + y
def mul(**kwargs):
a = kwargs['a']
b = kwargs['b']
return a * b
if __name__ == '__main__':
fire.Fire()
那么我們可以使用:
python example.py add 1 2 # 執(zhí)行add(1, 2)
python example.py mul --a=1 --b=2 # 執(zhí)行mul(a=1, b=2), kwargs={'a':1, 'b':2}
python example.py add --x=1 --y==2 # 執(zhí)行add(x=1, y=2)
可見气筋,只要在程序中運(yùn)行fire.Fire()
拆内,即可使用命令行參數(shù)python file <function> [args,] {--kwargs,}
。fire還支持更多的高級功能宠默,具體請參考官方指南《The Python Fire Guide》麸恍。
在主程序main.py
中,主要包含四個(gè)函數(shù)搀矫,其中三個(gè)需要命令行執(zhí)行抹沪,main.py
的代碼組織結(jié)構(gòu)如下:
def train(**kwargs):
"""
訓(xùn)練
"""
pass
def val(model, dataloader):
"""
計(jì)算模型在驗(yàn)證集上的準(zhǔn)確率等信息,用以輔助訓(xùn)練
"""
pass
def test(**kwargs):
"""
測試(inference)
"""
pass
def help():
"""
打印幫助的信息
"""
print('help')
if __name__=='__main__':
import fire
fire.Fire()
根據(jù)fire的使用方法瓤球,可通過python main.py <function> --args=xx
的方式來執(zhí)行訓(xùn)練或者測試融欧。
訓(xùn)練
訓(xùn)練的主要步驟如下:
- 定義網(wǎng)絡(luò)
- 定義數(shù)據(jù)
- 定義損失函數(shù)和優(yōu)化器
- 計(jì)算重要指標(biāo)
- 開始訓(xùn)練
- 訓(xùn)練網(wǎng)絡(luò)
- 可視化各種指標(biāo)
- 計(jì)算在驗(yàn)證集上的指標(biāo)
訓(xùn)練函數(shù)的代碼如下:
def train(**kwargs):
opt._parse(kwargs)
vis = Visualizer(opt.env, port=opt.vis_port)
# step1: configure model
model = getattr(models, opt.model)()
if opt.load_model_path:
model.load(opt.load_model_path)
model.to(opt.device)
# step2: data
train_data = DogCat(opt.train_data_root, train=True)
val_data = DogCat(opt.train_data_root, train=False)
train_dataloader = DataLoader(train_data, opt.batch_size,
shuffle=True, num_workers=opt.num_workers)
val_dataloader = DataLoader(val_data, opt.batch_size,
shuffle=False, num_workers=opt.num_workers)
# step3: criterion and optimizer
criterion = t.nn.CrossEntropyLoss()
lr = opt.lr
optimizer = model.get_optimizer(lr, opt.weight_decay)
# step4: meters
loss_meter = meter.AverageValueMeter()
confusion_matrix = meter.ConfusionMeter(2)
previous_loss = 1e10
# train
for epoch in range(opt.max_epoch):
loss_meter.reset()
confusion_matrix.reset()
for ii, (data, label) in tqdm(enumerate(train_dataloader)):
# train model
input = data.to(opt.device)
target = label.to(opt.device)
optimizer.zero_grad()
score = model(input)
loss = criterion(score, target)
loss.backward()
optimizer.step()
# meters update and visualize
loss_meter.add(loss.item())
# detach 一下更安全保險(xiǎn)
confusion_matrix.add(score.detach(), target.detach())
if (ii + 1) % opt.print_freq == 0:
vis.plot('loss', loss_meter.value()[0])
print("loss:", loss_meter.value()[0])
# 進(jìn)入debug模式
# if os.path.exists(opt.debug_file):
# import ipdb;
# ipdb.set_trace()
print("保存檢查點(diǎn)...")
model.save()
cm_value = confusion_matrix.value()
vis.plot('train_accuracy', 100. * (cm_value[0][0] + cm_value[1][1]) / cm_value.sum())
# validate and visualize
val_cm, val_accuracy = val(model, val_dataloader)
vis.plot('val_accuracy', val_accuracy)
vis.log("\tepoch:{epoch},\tlr:{lr},\tloss:{loss},\ttrain_cm:{train_cm},\tval_cm:{val_cm}\t".format(
epoch=epoch, lr=lr, loss=loss_meter.value()[0], train_cm=str(confusion_matrix.value()),
val_cm=str(val_cm.value())))
# update learning rate
if loss_meter.value()[0] > previous_loss:
lr = lr * opt.lr_decay
# 第二種降低學(xué)習(xí)率的方法:不會有moment等信息的丟失
for param_group in optimizer.param_groups:
param_group['lr'] = lr
previous_loss = loss_meter.value()[0]
這里用到了PyTorchNet里面的一個(gè)工具: meter。meter提供了一些輕量級的工具冰垄,用于幫助用戶快速統(tǒng)計(jì)訓(xùn)練過程中的一些指標(biāo)蹬癌。AverageValueMeter
能夠計(jì)算所有數(shù)的平均值和標(biāo)準(zhǔn)差,這里用來統(tǒng)計(jì)一個(gè)epoch中損失的平均值虹茶。confusionmeter
用來統(tǒng)計(jì)分類問題中的分類情況逝薪,是一個(gè)比準(zhǔn)確率更詳細(xì)的統(tǒng)計(jì)指標(biāo)。例如對于表格6-1蝴罪,共有50張狗的圖片董济,其中有35張被正確分類成了狗,還有15張被誤判成貓要门;共有100張貓的圖片虏肾,其中有91張被正確判為了貓廓啊,剩下9張被誤判成狗。相比于準(zhǔn)確率等統(tǒng)計(jì)信息封豪,混淆矩陣更能體現(xiàn)分類的結(jié)果谴轮,尤其是在樣本比例不均衡的情況下。
表6-1 混淆矩陣
樣本 | 判為狗 | 判為貓 |
---|---|---|
實(shí)際是狗 | 35 | 15 |
實(shí)際是貓 | 9 | 91 |
PyTorchNet從TorchNet遷移而來吹埠,提供了很多有用的工具第步,但其目前開發(fā)和文檔都還不是很完善,本書不做過多的講解缘琅。
驗(yàn)證
驗(yàn)證相對來說比較簡單粘都,但要注意需將模型置于驗(yàn)證模式(model.eval()
),驗(yàn)證完成后還需要將其置回為訓(xùn)練模式(model.train()
)刷袍,這兩句代碼會影響BatchNorm
和Dropout
等層的運(yùn)行模式翩隧。驗(yàn)證模型準(zhǔn)確率的代碼如下。
@t.no_grad()
def val(model, dataloader):
"""
計(jì)算模型在驗(yàn)證集上的準(zhǔn)確率等信息
"""
model.eval()
confusion_matrix = meter.ConfusionMeter(2)
for ii, (val_input, label) in tqdm(enumerate(dataloader)):
val_input = val_input.to(opt.device)
score = model(val_input)
confusion_matrix.add(score.detach().squeeze(), label.type(t.LongTensor))
model.train()
cm_value = confusion_matrix.value()
accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) / (cm_value.sum())
return confusion_matrix, accuracy
測試
測試時(shí)呻纹,需要計(jì)算每個(gè)樣本屬于狗的概率堆生,并將結(jié)果保存成csv文件。測試的代碼與驗(yàn)證比較相似雷酪,但需要自己加載模型和數(shù)據(jù)顽频。
@t.no_grad() # pytorch>=0.5
def test(**kwargs):
opt._parse(kwargs)
# configure model
model = getattr(models, opt.model)().eval()
if opt.load_model_path:
model.load(opt.load_model_path)
model.to(opt.device)
# data
train_data = DogCat(opt.test_data_root, test=True)
test_dataloader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers)
results = []
for ii, (data, path) in tqdm(enumerate(test_dataloader)):
input = data.to(opt.device)
score = model(input)
probability = t.nn.functional.softmax(score, dim=1)[:, 0].detach().tolist()
batch_results = [(path_.item(), probability_) for path_, probability_ in zip(path, probability)]
results += batch_results
write_csv(results, opt.result_file)
return results
def write_csv(results, file_name):
import csv
with open(file_name, 'w') as f:
writer = csv.writer(f)
writer.writerow(['id', 'label'])
writer.writerows(results)
幫助函數(shù)
為了方便他人使用, 程序中還應(yīng)當(dāng)提供一個(gè)幫助函數(shù),用于說明函數(shù)是如何使用太闺。程序的命令行接口中有眾多參數(shù)糯景,如果手動(dòng)用字符串表示不僅復(fù)雜,而且后期修改config文件時(shí)省骂,還需要修改對應(yīng)的幫助信息蟀淮,十分不便。這里使用了Python標(biāo)準(zhǔn)庫中的inspect方法钞澳,可以自動(dòng)獲取config的源代碼怠惶。help的代碼如下:
def help():
"""
打印幫助的信息: python file.py help
"""
print("""
usage : python file.py <function> [--args=value]
<function> := train | test | help
example:
python {0} train --env='env0701' --lr=0.01
python {0} test --dataset='path/to/dataset/root/'
python {0} help
avaiable args:""".format(__file__))
from inspect import getsource
source = (getsource(opt.__class__))
print(source)
當(dāng)用戶執(zhí)行python main.py help
的時(shí)候,會打印如下幫助信息:
usage : python main.py <function> [--args=value,]
<function> := train | test | help
example:
python main.py train --env='env0701' --lr=0.01
python main.py test --dataset='path/to/dataset/'
python main.py help
avaiable args:
class DefaultConfig(object):
env = 'default' # visdom 環(huán)境
model = 'AlexNet' # 使用的模型
train_data_root = './data/train/' # 訓(xùn)練集存放路徑
test_data_root = './data/test' # 測試集存放路徑
load_model_path = 'checkpoints/model.pth' # 加載預(yù)訓(xùn)練的模型
batch_size = 128 # batch size
use_gpu = True # user GPU or not
num_workers = 4 # how many workers for loading data
print_freq = 20 # print info every N batch
debug_file = './debug/debug.txt'
result_file = 'result.csv' # 結(jié)果文件
max_epoch = 10
lr = 0.1 # initial learning rate
lr_decay = 0.95 # when val_loss increase, lr = lr*lr_decay
weight_decay = 1e-4 # 損失函數(shù)
6.1.9 使用
正如help
函數(shù)的打印信息所述轧粟,可以通過命令行參數(shù)指定變量名.下面是三個(gè)使用例子策治,fire會將包含-
的命令行參數(shù)自動(dòng)轉(zhuǎn)層下劃線_
,也會將非數(shù)值的值轉(zhuǎn)成字符串兰吟。所以--train-data-root=data/train
和--train_data_root='data/train'
是等價(jià)的通惫。
# 訓(xùn)練模型
python main.py train
--train-data-root=data/train/
--lr=0.005
--batch-size=32
--model='ResNet34'
--max-epoch = 20
# 測試模型
python main.py test
--test-data-root=data/test
--load-model-path='checkpoints/resnet34_00:23:05.pth'
--batch-size=128
--model='ResNet34'
--num-workers=12
# 打印幫助信息
python main.py help
實(shí)驗(yàn)過程
本章程序及數(shù)據(jù)下載:百度網(wǎng)盤,提取碼:aw26混蔼。
首先履腋,在命令行cmd紅啟動(dòng)visdom服務(wù)器:
python -m visdom.server
然后,訓(xùn)練模型:
python main.py train
訓(xùn)練結(jié)果如下:
從上述結(jié)果可以看出,模型的精度可以達(dá)到97%以上遵湖。你也可以手動(dòng)更改模型悔政,通過調(diào)節(jié)參數(shù)來進(jìn)一步提升模型的準(zhǔn)確率。
最后延旧,測試模型:
python main.py test
第二列表示預(yù)測為狗的概率:
我們來看一下測試集圖片:
可以看到谋国,模型能夠正確識別出很多狗和貓了,但是還存在很大的改進(jìn)空間迁沫。
6.1.10 爭議
以上的程序設(shè)計(jì)規(guī)范帶有作者強(qiáng)烈的個(gè)人喜好烹卒,并不想作為一個(gè)標(biāo)準(zhǔn),而是作為一個(gè)提議和一種參考弯洗。上述設(shè)計(jì)在很多地方還有待商榷,例如對于訓(xùn)練過程是否應(yīng)該封裝成一個(gè)trainer
對象逢勾,或者直接封裝到BaiscModule
的train
方法之中牡整。對命令行參數(shù)的處理也有不少值得討論之處。因此不要將本文中的觀點(diǎn)作為一個(gè)必須遵守的規(guī)范溺拱,而應(yīng)該看作一個(gè)參考逃贝。
本章中的設(shè)計(jì)可能會引起不少爭議,其中比較值得商榷的部分主要有以下兩個(gè)方面:
- 命令行參數(shù)的設(shè)置做修。目前大多數(shù)程序都是使用Python標(biāo)準(zhǔn)庫中的
argparse
來處理命令行參數(shù)硝岗,也有些使用比較輕量級的click
涝涤。這種處理相對來說對命令行的支持更完備,但根據(jù)作者的經(jīng)驗(yàn)來看沪摄,這種做法不夠直觀,并且代碼量相對來說也較多纱烘。比如argparse
杨拐,每次增加一個(gè)命令行參數(shù),都必須寫如下代碼:
parser.add_argument('-save-interval', type=int, default=500, help='how many steps to wait before saving [default:500]')
在讀者眼中擂啥,這種實(shí)現(xiàn)方式遠(yuǎn)不如一個(gè)專門的config.py
來的直觀和易用哄陶。尤其是對于使用Jupyter notebook或IPython等交互式調(diào)試的用戶來說,argparse
較難使用哺壶。
- 模型訓(xùn)練屋吨。有不少人喜歡將模型的訓(xùn)練過程集成于模型的定義之中,代碼結(jié)構(gòu)如下所示:
class MyModel(nn.Module):
def __init__(self,opt):
self.dataloader = Dataloader(opt)
self.optimizer = optim.Adam(self.parameters(),lr=0.001)
self.lr = opt.lr
self.model = make_model()
def forward(self,input):
pass
def train_(self):
# 訓(xùn)練模型
for epoch in range(opt.max_epoch)
for ii,data in enumerate(self.dataloader):
train_epoch()
model.save()
def train_epoch(self):
pass
抑或是專門設(shè)計(jì)一個(gè)Trainer
對象山宾,形如:
"""
code simplified from:
https://github.com/pytorch/pytorch/blob/master/torch/utils/trainer/trainer.py
"""
import heapq
from torch.autograd import Variable
class Trainer(object):
def __init__(self, model=None, criterion=None, optimizer=None, dataset=None):
self.model = model
self.criterion = criterion
self.optimizer = optimizer
self.dataset = dataset
self.iterations = 0
def run(self, epochs=1):
for i in range(1, epochs + 1):
self.train()
def train(self):
for i, data in enumerate(self.dataset, self.iterations + 1):
batch_input, batch_target = data
self.call_plugins('batch', i, batch_input, batch_target)
input_var = Variable(batch_input)
target_var = Variable(batch_target)
plugin_data = [None, None]
def closure():
batch_output = self.model(input_var)
loss = self.criterion(batch_output, target_var)
loss.backward()
if plugin_data[0] is None:
plugin_data[0] = batch_output.data
plugin_data[1] = loss.data
return loss
self.optimizer.zero_grad()
self.optimizer.step(closure)
self.iterations += i
還有一些人喜歡模仿keras和scikit-learn的設(shè)計(jì)至扰,設(shè)計(jì)一個(gè)fit
接口。對讀者來說资锰,這些處理方式很難說哪個(gè)更好或更差渊胸,找到最適合自己的方法才是最好的。
BasicModule
的封裝台妆,可多可少翎猛。訓(xùn)練過程中的很多操作都可以移到BasicModule
之中胖翰,比如get_optimizer
方法用來獲取優(yōu)化器,比如train_step
用來執(zhí)行單歩訓(xùn)練切厘。對于不同的模型萨咳,如果對應(yīng)的優(yōu)化器定義不一樣,或者是訓(xùn)練方法不一樣疫稿,可以復(fù)寫這些函數(shù)自定義相應(yīng)的方法培他,取決于自己的喜好和項(xiàng)目的實(shí)際需求。
6.2 PyTorch Debug指南
6.2.1 ipdb介紹
很多初學(xué)者用print或log調(diào)試程序遗座,這在小規(guī)模的程序下很方便舀凛。但是更好的調(diào)試方法是一邊運(yùn)行一邊檢查里面的變量和方法。pdb是一個(gè)交互式的調(diào)試工具途蒋,集成于Python的標(biāo)準(zhǔn)庫之中猛遍,由于其強(qiáng)大的功能,被廣泛應(yīng)用于Python環(huán)境中号坡。pdb能讓你根據(jù)需求跳轉(zhuǎn)到任意的Python代碼斷點(diǎn)懊烤、查看任意變量、單步執(zhí)行代碼宽堆,甚至還能修改代碼的值腌紧,而不必重啟程序。ipdb是一個(gè)增強(qiáng)版的pdb畜隶,可通過pip install ipdb
安裝壁肋。ipdb提供了調(diào)試模式下的代碼補(bǔ)全,還具有更好的語法高亮和代碼溯源籽慢,以及更好的內(nèi)省功能墩划,更關(guān)鍵的是,它與pdb接口完全兼容嗡综。
在本書第2章曾粗略地提到過ipdb的基本使用乙帮,本章將繼續(xù)介紹如何結(jié)合PyTorch和ipdb進(jìn)行調(diào)試。首先看一個(gè)例子极景,要是用ipdb察净,只需在想要進(jìn)行調(diào)試的地方插入ipdb.set_trace()
,當(dāng)代碼運(yùn)行到此處時(shí)盼樟,就會自動(dòng)進(jìn)入交互式調(diào)試模式氢卡。
假設(shè)有如下程序:
try:
import ipdb
except:
import pdb as ipdb
def sum(x):
r = 0
for ii in x:
r += ii
return r
def mul(x):
r = 1
for ii in x:
r *= ii
return r
ipdf.set_trace()
x = [1,2,3,4,5]
r = sum(x)
r = mul(x)
當(dāng)程序運(yùn)行至ipdb.set_trace(),會自動(dòng)進(jìn)入debug模式晨缴,在該模式中译秦,我們可使用調(diào)試命令,如next或縮寫n單步執(zhí)行,也可查看Python變量筑悴,或是運(yùn)行Python代碼们拙。如果Python變量名和調(diào)式命令沖突,需要在變量名前加"!"阁吝,這樣ipdb會執(zhí)行對應(yīng)的Python代碼砚婆,而不是調(diào)試命令。下面舉例說明ipdb的調(diào)試突勇,這里重點(diǎn)講解ipdb的兩大功能装盯。
- 查看:在函數(shù)調(diào)用堆棧中自由跳轉(zhuǎn),并查看函數(shù)的局部變量
- 修改:修改程序中的變量甲馋,并能以此影響程序的運(yùn)行結(jié)果埂奈。
> e:\debug.py(19)<module>()
18 ipdb.set_trace()
---> 19 x = [1,2,3,4,5]
20 r = sum(x)
ipdb> l 1,21 # list 1,21的縮寫,查看第1行到第21行的代碼定躏,光標(biāo)所指的這一行尚未運(yùn)行
1 try:
2 import ipdb
3 except:
4 import pdb as ipdb
5
6 def sum(x):
7 r = 0
8 for ii in x:
9 r += ii
10 return r
11
12 def mul(x):
13 r = 1
14 for ii in x:
15 r *= ii
16 return r
17
18 ipdb.set_trace()
---> 19 x = [1,2,3,4,5]
20 r = sum(x)
21 r = mul(x)
ipdb> n # next的縮寫账磺,執(zhí)行下一步
> e:\debug.py(20)<module>()
19 x = [1,2,3,4,5]
---> 20 r = sum(x)
21 r = mul(x)
ipdb> s # step的縮寫,進(jìn)入sum函數(shù)內(nèi)部
--Call--
> e:\debug.py(6)sum()
5
----> 6 def sum(x):
7 r = 0
ipdb> n # next單步執(zhí)行
> e:\debug.py(7)sum()
6 def sum(x):
----> 7 r = 0
8 for ii in x:
ipdb> n # next單步執(zhí)行
> e:\debug.py(8)sum()
7 r = 0
----> 8 for ii in x:
9 r += ii
ipdb> n # next單步執(zhí)行
> e:\debug.py(9)sum()
8 for ii in x:
----> 9 r += ii
10 return r
ipdb> u # up的縮寫共屈,跳回到上一層的調(diào)用
> e:\debug.py(20)<module>()
19 x = [1,2,3,4,5]
---> 20 r = sum(x)
21 r = mul(x)
ipdb> d # down的縮寫,跳到調(diào)用的下一層
> e:\debug.py(9)sum()
8 for ii in x:
----> 9 r += ii
10 return r
ipdb> !r # !r 查看變量r的值党窜,該變量名與調(diào)試命令`r(eturn)`沖突
0
ipdb> r # return的縮寫拗引,繼續(xù)運(yùn)行直到函數(shù)返回
--Return--
15
> e:\debug.py(10)sum()
9 r += ii
---> 10 return r
11
ipdb> n # 下一步
> e:\debug.py(21)<module>()
19 x = [1,2,3,4,5]
20 r = sum(x)
---> 21 r = mul(x)
ipdb> x # 查看變量x的值
[1, 2, 3, 4, 5]
ipdb> x[0] = 10000 # 修改變量x的值
ipdb> b 13 # break的縮寫,AI第13行設(shè)置斷點(diǎn)
Breakpoint 1 at e:\debug.py:13
ipdb> c # continue的縮寫幌衣,繼續(xù)運(yùn)行矾削,直到遇到斷點(diǎn)
> e:\debug.py(13)mul()
12 def mul(x):
1--> 13 r = 1
14 for ii in x:
ipdb> return # 返回的是修改后x的乘積
--Return--
1200000
> e:\debug.py(16)mul()
15 r *= ii
---> 16 return r
17
ipdb> q # quit的縮寫,退出debug模式
Exiting Debugger.
關(guān)于ipdb的使用還有一些技巧:
- <tab>鍵能夠自動(dòng)補(bǔ)齊豁护,補(bǔ)齊用法與IPython中的類似哼凯。
- j(ump) <lineno>能夠跳過中間某些行代碼的執(zhí)行
- 可以直接在ipdb中修改變量的值
- h(elp)能夠查看調(diào)試命令的用法,比如
h h
可以查看h(elp)命令的用法楚里,h jump
能夠查看j(ump)命令的用法断部。
6.2.2 在PyTorch中Debug
PyTorch作為一個(gè)動(dòng)態(tài)圖框架,與ipdb結(jié)合使用能為調(diào)試過程帶來便捷班缎。對TensorFlow等靜態(tài)圖框架來說蝴光,使用Python接口定義計(jì)算圖,然后使用C++代碼執(zhí)行底層運(yùn)算达址,在定義圖的時(shí)候不進(jìn)行任何計(jì)算蔑祟,而在計(jì)算的時(shí)候又無法使用pdb進(jìn)行調(diào)試,因?yàn)閜db調(diào)試只能調(diào)試Python代碼沉唠,故調(diào)試一直是此類靜態(tài)圖框架的一個(gè)痛點(diǎn)疆虚。與TensorFlow不同,PyTorch可以在執(zhí)行計(jì)算的同時(shí)定義計(jì)算圖,這些計(jì)算定義過程是使用Python完成的径簿。雖然底層的計(jì)算也是用C/C++完成的罢屈,但是我們能夠查看Python定義部分的變量值,這就已經(jīng)足夠了牍帚。下面我們將舉例說明儡遮。
- 如何AIPyTorch中查看神經(jīng)網(wǎng)絡(luò)各個(gè)層的輸出。
- 如何在PyTorch中分析各個(gè)參數(shù)的梯度暗赶。
- 如何動(dòng)態(tài)修改PyTorch的訓(xùn)練過程鄙币。
首先,運(yùn)行第一節(jié)給出的“貓狗大戰(zhàn)”程序:
python main.py train --debug-file='debug/debug.txt'
程序運(yùn)行一段時(shí)間后蹂随,在debug目錄下創(chuàng)建debug.txt標(biāo)識文件十嘿,當(dāng)程序檢測到這個(gè)文件存在時(shí),會自動(dòng)進(jìn)入debug模式岳锁。
99it [00:17, 6.07it/s]loss: 0.22854854568839075
119it [00:21, 5.79it/s]loss: 0.21267264398435753
139it [00:24, 5.99it/s]loss: 0.19839374726372108
> e:\workspace\python\pytorch\chapter6\main.py(80)train()
79 loss_meter.reset()
---> 80 confusion_matrix.reset()
81 for ii, (data, label) in tqdm(enumerate(train_dataloader)):
ipdb> break 88 # 在第88行設(shè)置斷點(diǎn)绩衷,當(dāng)程序運(yùn)行到此處進(jìn)入debug模式
Breakpoint 1 at e:\workspace\python\pytorch\chapter6\main.py:88
ipdb> # 打印所有參數(shù)及其梯度的標(biāo)準(zhǔn)差
for (name,p) in model.named_parameters(): \
print(name,p.data.std(),p.grad.data.std())
model.features.0.weight tensor(0.2615, device='cuda:0') tensor(0.3769, device='cuda:0')
model.features.0.bias tensor(0.4862, device='cuda:0') tensor(0.3368, device='cuda:0')
model.features.3.squeeze.weight tensor(0.2738, device='cuda:0') tensor(0.3023, device='cuda:0')
model.features.3.squeeze.bias tensor(0.5867, device='cuda:0') tensor(0.3753, device='cuda:0')
model.features.3.expand1x1.weight tensor(0.2168, device='cuda:0') tensor(0.2883, device='cuda:0')
model.features.3.expand1x1.bias tensor(0.2256, device='cuda:0') tensor(0.1147, device='cuda:0')
model.features.3.expand3x3.weight tensor(0.0935, device='cuda:0') tensor(0.1605, device='cuda:0')
model.features.3.expand3x3.bias tensor(0.1421, device='cuda:0') tensor(0.0583, device='cuda:0')
model.features.4.squeeze.weight tensor(0.1976, device='cuda:0') tensor(0.2137, device='cuda:0')
model.features.4.squeeze.bias tensor(0.4058, device='cuda:0') tensor(0.1798, device='cuda:0')
model.features.4.expand1x1.weight tensor(0.2144, device='cuda:0') tensor(0.4214, device='cuda:0')
model.features.4.expand1x1.bias tensor(0.4994, device='cuda:0') tensor(0.0958, device='cuda:0')
model.features.4.expand3x3.weight tensor(0.1063, device='cuda:0') tensor(0.2963, device='cuda:0')
model.features.4.expand3x3.bias tensor(0.0489, device='cuda:0') tensor(0.0719, device='cuda:0')
model.features.6.squeeze.weight tensor(0.1736, device='cuda:0') tensor(0.3544, device='cuda:0')
model.features.6.squeeze.bias tensor(0.2420, device='cuda:0') tensor(0.0896, device='cuda:0')
model.features.6.expand1x1.weight tensor(0.1211, device='cuda:0') tensor(0.2428, device='cuda:0')
model.features.6.expand1x1.bias tensor(0.0670, device='cuda:0') tensor(0.0162, device='cuda:0')
model.features.6.expand3x3.weight tensor(0.0593, device='cuda:0') tensor(0.1917, device='cuda:0')
model.features.6.expand3x3.bias tensor(0.0227, device='cuda:0') tensor(0.0160, device='cuda:0')
model.features.7.squeeze.weight tensor(0.1207, device='cuda:0') tensor(0.2179, device='cuda:0')
model.features.7.squeeze.bias tensor(0.1484, device='cuda:0') tensor(0.0381, device='cuda:0')
model.features.7.expand1x1.weight tensor(0.1235, device='cuda:0') tensor(0.2279, device='cuda:0')
model.features.7.expand1x1.bias tensor(0.0450, device='cuda:0') tensor(0.0100, device='cuda:0')
model.features.7.expand3x3.weight tensor(0.0609, device='cuda:0') tensor(0.1628, device='cuda:0')
model.features.7.expand3x3.bias tensor(0.0132, device='cuda:0') tensor(0.0079, device='cuda:0')
model.features.9.squeeze.weight tensor(0.1093, device='cuda:0') tensor(0.2459, device='cuda:0')
model.features.9.squeeze.bias tensor(0.0646, device='cuda:0') tensor(0.0135, device='cuda:0')
model.features.9.expand1x1.weight tensor(0.0840, device='cuda:0') tensor(0.1860, device='cuda:0')
model.features.9.expand1x1.bias tensor(0.0177, device='cuda:0') tensor(0.0033, device='cuda:0')
model.features.9.expand3x3.weight tensor(0.0476, device='cuda:0') tensor(0.1393, device='cuda:0')
model.features.9.expand3x3.bias tensor(0.0058, device='cuda:0') tensor(0.0030, device='cuda:0')
model.features.10.squeeze.weight tensor(0.0872, device='cuda:0') tensor(0.1676, device='cuda:0')
model.features.10.squeeze.bias tensor(0.0484, device='cuda:0') tensor(0.0088, device='cuda:0')
model.features.10.expand1x1.weight tensor(0.0859, device='cuda:0') tensor(0.2145, device='cuda:0')
model.features.10.expand1x1.bias tensor(0.0160, device='cuda:0') tensor(0.0025, device='cuda:0')
model.features.10.expand3x3.weight tensor(0.0456, device='cuda:0') tensor(0.1429, device='cuda:0')
model.features.10.expand3x3.bias tensor(0.0070, device='cuda:0') tensor(0.0021, device='cuda:0')
model.features.11.squeeze.weight tensor(0.0786, device='cuda:0') tensor(0.2003, device='cuda:0')
model.features.11.squeeze.bias tensor(0.0422, device='cuda:0') tensor(0.0069, device='cuda:0')
model.features.11.expand1x1.weight tensor(0.0690, device='cuda:0') tensor(0.1400, device='cuda:0')
model.features.11.expand1x1.bias tensor(0.0138, device='cuda:0') tensor(0.0022, device='cuda:0')
model.features.11.expand3x3.weight tensor(0.0366, device='cuda:0') tensor(0.1517, device='cuda:0')
model.features.11.expand3x3.bias tensor(0.0109, device='cuda:0') tensor(0.0023, device='cuda:0')
model.features.12.squeeze.weight tensor(0.0729, device='cuda:0') tensor(0.1736, device='cuda:0')
model.features.12.squeeze.bias tensor(0.0814, device='cuda:0') tensor(0.0084, device='cuda:0')
model.features.12.expand1x1.weight tensor(0.0977, device='cuda:0') tensor(0.1385, device='cuda:0')
model.features.12.expand1x1.bias tensor(0.0102, device='cuda:0') tensor(0.0032, device='cuda:0')
model.features.12.expand3x3.weight tensor(0.0365, device='cuda:0') tensor(0.1312, device='cuda:0')
model.features.12.expand3x3.bias tensor(0.0038, device='cuda:0') tensor(0.0026, device='cuda:0')
model.classifier.1.weight tensor(0.0285, device='cuda:0') tensor(0.0865, device='cuda:0')
model.classifier.1.bias tensor(0.0362, device='cuda:0') tensor(0.0192, device='cuda:0')
ipdb> opt.lr # 查看學(xué)習(xí)率
0.001
ipdb> opt.lr = 0.002 # 更改學(xué)習(xí)率
ipdb> for p in optimizer.param_groups: \
p['lr'] = opt.lr
ipdb> model.save() # 保存模型
'checkpoints/squeezenet_20191004212249.pth'
ipdb> c # 繼續(xù)運(yùn)行,直到第88行暫停
222it [16:38, 35.62s/it]> e:\workspace\python\pytorch\chapter6\main.py(88)train()
87 optimizer.zero_grad()
1--> 88 score = model(input)
89 loss = criterion(score, target)
ipdb> s # 進(jìn)入model(input)內(nèi)部激率,即model.__call__(input)
--Call--
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(537)__call__()
536
--> 537 def __call__(self, *input, **kwargs):
538 for hook in self._forward_pre_hooks.values():
ipdb> n # 下一步
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(538)__call__()
537 def __call__(self, *input, **kwargs):
--> 538 for hook in self._forward_pre_hooks.values():
539 result = hook(self, input)
ipdb> n # 下一步
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(544)__call__()
543 input = result
--> 544 if torch._C._get_tracing_state():
545 result = self._slow_forward(*input, **kwargs)
ipdb> n # 下一步
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(547)__call__()
546 else:
--> 547 result = self.forward(*input, **kwargs)
548 for hook in self._forward_hooks.values():
ipdb> s # 進(jìn)入forward函數(shù)內(nèi)容
--Call--
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\loss.py(914)forward()
913
--> 914 def forward(self, input, target):
915 return F.cross_entropy(input, target, weight=self.weight,
ipdb> input # 查看input變量值
tensor([[4.5005, 2.0725],
[3.5933, 7.8643],
[2.9086, 3.4209],
[2.7740, 4.4332],
[6.0164, 2.3033],
[5.2261, 3.2189],
[2.6529, 2.0749],
[6.3259, 2.2383],
[3.0629, 3.4832],
[2.7008, 8.2818],
[5.5684, 2.1567],
[3.0689, 6.1022],
[3.4848, 5.3831],
[1.7920, 5.7709],
[6.5032, 2.8080],
[2.3071, 5.2417],
[3.7474, 5.0263],
[4.3682, 3.6707],
[2.2196, 6.9298],
[5.2201, 2.3034],
[6.4315, 1.4970],
[3.4684, 4.0371],
[3.9620, 1.7629],
[1.7069, 7.8898],
[3.0462, 1.6505],
[2.4081, 6.4456],
[2.1932, 7.4614],
[2.3405, 2.7603],
[1.9478, 8.4156],
[2.7935, 7.8331],
[1.8898, 3.8836],
[3.3008, 1.6832]], device='cuda:0', grad_fn=<AsStridedBackward>)
ipdb> input.data.mean() # 查看input的均值和標(biāo)準(zhǔn)差
tensor(3.9630, device='cuda:0')
ipdb> input.data.std()
tensor(1.9513, device='cuda:0')
ipdb> u # 跳回上一層
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(547)__call__()
546 else:
--> 547 result = self.forward(*input, **kwargs)
548 for hook in self._forward_hooks.values():
ipdb> u # 跳回上一層
> e:\workspace\python\pytorch\chapter6\main.py(88)train()
87 optimizer.zero_grad()
1--> 88 score = model(input)
89 loss = criterion(score, target)
ipdb> clear # 清除所有斷點(diǎn)
Clear all breaks? y
Deleted breakpoint 1 at e:\workspace\python\pytorch\chapter6\main.py:88
ipdb> c # 繼續(xù)運(yùn)行咳燕,記得先刪除"debug/debug.txt",否則很快又會進(jìn)入調(diào)試模式
59it [06:21, 5.75it/s]loss: 0.24856307208538073
76it [06:24, 5.91it/s]
當(dāng)我們想要進(jìn)入debug模式乒躺,修改程序中某些參數(shù)值或者想分析程序時(shí)招盲,就可以通過創(chuàng)建debug標(biāo)識文件,此時(shí)程序會進(jìn)入調(diào)試模式嘉冒,調(diào)試完成之后刪除這個(gè)文件并在ipdb調(diào)試接口輸入c繼續(xù)運(yùn)行程序曹货。如果想退出程序,也可以使用這種方式讳推,先創(chuàng)建debug標(biāo)識文件顶籽,然后輸入quit在退出debug的同時(shí)退出程序。這種退出程序的方式银觅,與使用Ctrl+C的方式相比更安全礼饱,因?yàn)檫@能保證數(shù)據(jù)加載的多進(jìn)程程序也能正確地退出,并釋放內(nèi)存究驴、顯存等資源慨仿。
PyTorch和ipdb集合能完成很多其他框架所不能完成或很難完成的功能。根據(jù)筆者日常使用的總結(jié)纳胧,主要有以下幾個(gè)部分:
(1)通過debug暫停程序镰吆。當(dāng)程序進(jìn)入debug模式后,將不再執(zhí)行PCU和GPU運(yùn)算跑慕,但是內(nèi)存和顯存及相應(yīng)的堆椡蛎螅空間不會釋放摧找。
(2)通過debug分析程序,查看每個(gè)層的輸出牢硅,查看網(wǎng)絡(luò)的參數(shù)情況蹬耘。通過u(p)、d(own)减余、s(tep)等命令综苔,能夠進(jìn)入指定的代碼,通過n(ext)可以單步執(zhí)行位岔,從而看到每一層的運(yùn)算結(jié)果如筛,便于分析網(wǎng)絡(luò)的數(shù)值分布等信息。
(3)作為動(dòng)態(tài)圖框架抒抬,PyTorch擁有Python動(dòng)態(tài)語言解釋執(zhí)行的優(yōu)點(diǎn)杨刨,我們能夠在運(yùn)行程序時(shí),用過ipdb修改某些變量的值或?qū)傩圆两#@些修改能夠立即生效妖胀。例如可以在訓(xùn)練開始不久根據(jù)損失函數(shù)調(diào)整學(xué)習(xí)率,不必重啟程序惠勒。
(4)如果在IPython中通過%run魔法方法運(yùn)行程序赚抡,那么在程序異常退出時(shí),可以使用%debug命令纠屋,直接進(jìn)入debug模式涂臣,通過u(p)和d(own)跳到報(bào)錯(cuò)的地方,查看對應(yīng)的變量巾遭,找出原因后修改相應(yīng)的代碼即可肉康。有時(shí)我們的模式訓(xùn)練了好幾個(gè)小時(shí)闯估,卻在將要保存模式之前灼舍,因?yàn)橐粋€(gè)小小的拼寫錯(cuò)誤異常退出。此時(shí)涨薪,如果修改錯(cuò)誤再重新運(yùn)行程序又要花費(fèi)好幾個(gè)小時(shí)骑素,太浪費(fèi)時(shí)間。因此最好的方法就是看利用%debug進(jìn)入調(diào)試模式刚夺,在調(diào)試模式中直接運(yùn)行model.save()保存模型献丑。在IPython中,%pdb魔術(shù)方法能夠使得程序出現(xiàn)問題后侠姑,不用手動(dòng)輸入%debug而自動(dòng)進(jìn)入debug模式创橄,建議使用。
PyTorch調(diào)用CuDNN報(bào)錯(cuò)時(shí)莽红,報(bào)錯(cuò)信息諸如CUDNN_STATUS_BAD_PARAM妥畏,從這些報(bào)錯(cuò)內(nèi)容很難得到有用的幫助信息,最后先利用PCU運(yùn)行代碼醉蚁,此時(shí)一般會得到相對友好的報(bào)錯(cuò)信息燃辖,例如在ipdb中執(zhí)行model.cpu()(input.cpu()),PyTorch底層的TH庫會給出相對比較詳細(xì)的信息网棍。
常見的錯(cuò)誤主要有以下幾種:
- 類型不匹配問題黔龟。例如CrossEntropyLoss的輸入target應(yīng)該是一個(gè)LongTensor,而很多人輸入FloatTensor滥玷。
- 部分?jǐn)?shù)據(jù)忘記從CPU轉(zhuǎn)移到GPU氏身。例如,當(dāng)model存放于GPU時(shí)罗捎,輸入input也需要轉(zhuǎn)移到GPU才能輸入到model中观谦。還有可能就是把多個(gè)model存放于一個(gè)list對象,而在執(zhí)行model.cuda()時(shí)桨菜,這個(gè)list中的對象是不會被轉(zhuǎn)移到CUDA上的豁状,正確的用法是用ModuleList代替。
- Tensor形狀不匹配倒得。此類問題一般是輸入數(shù)據(jù)形狀不對泻红,或是網(wǎng)絡(luò)結(jié)構(gòu)設(shè)計(jì)有問題,一般通過u(p)跳到指定代碼霞掺,查看輸入和模型參數(shù)的形狀即可得知谊路。
此外,可能還會經(jīng)常遇到程序正常運(yùn)行菩彬、沒有報(bào)錯(cuò)缠劝,但是模型無法收斂的問題。例如對于二分類問題骗灶,交叉熵?fù)p失一直徘徊在0.69附近(ln2)惨恭,或者是數(shù)值出現(xiàn)溢出等問題,此時(shí)可以進(jìn)入debug模式耙旦,用單步執(zhí)行查看脱羡,每一層輸出的均值和方差,觀察從哪一層的輸出開始出現(xiàn)數(shù)值異常免都。還要查看每個(gè)參數(shù)梯度的均值和方差锉罐,查看是否出現(xiàn)梯度消失或者梯度爆炸等問題。一般來說绕娘,通過再激活函數(shù)之前增加BatchNorm層脓规、合理的參數(shù)初始化、使用Adam優(yōu)化器险领、學(xué)習(xí)率設(shè)為0.001侨舆,基本就能確保模型在一定程度收斂升酣。
本章帶領(lǐng)讀者從頭實(shí)現(xiàn)了一個(gè)Kaggle上的經(jīng)典競賽,重點(diǎn)講解了如何合理地組合安排程序态罪,同時(shí)介紹了一些在PyTorch中調(diào)試的技巧噩茄。