深度學(xué)習(xí)框架PyTorch入門與實(shí)踐:第六章 PyTorch實(shí)戰(zhàn)指南

通過前面幾章的學(xué)習(xí)轻局,我們已經(jīng)掌握了PyTorch中大部分的基礎(chǔ)知識,本章將結(jié)合之前講的內(nèi)容,帶領(lǐng)讀者從頭實(shí)現(xiàn)一個(gè)完整的深度學(xué)習(xí)項(xiàng)目。本章的重點(diǎn)不在于如何使用PyTorch的接口,而在于合理地設(shè)計(jì)程序的結(jié)構(gòu)腋粥,使得程序更具可讀性晦雨、更易用。

6.1 編程實(shí)戰(zhàn):貓和狗二分類

在學(xué)習(xí)某個(gè)深度學(xué)習(xí)框架時(shí)隘冲,掌握其基本知識和接口固然重要闹瞧,但如何合理組織代碼,使得代碼具有良好的可讀性和可擴(kuò)展性也必不可少展辞。本文不會深入講解過多知識性的東西奥邮,更多的則是傳授一些經(jīng)驗(yàn),這些內(nèi)容可能有些爭議,因其受我個(gè)人喜好和coding風(fēng)格影響較大洽腺,讀者可以將這部分當(dāng)成是一種參考或提議脚粟,而不是作為必須遵循的準(zhǔn)則。歸根到底蘸朋,都是希望你能以一種更為合理的方式組織自己的程序核无。

在做深度學(xué)習(xí)實(shí)驗(yàn)或項(xiàng)目時(shí),為了得到最優(yōu)的模型結(jié)果藕坯,中間往往需要很多次的嘗試和修改团南。而合理的文件組織結(jié)構(gòu),以及一些小技巧可以極大地提高代碼的易讀易用性炼彪。根據(jù)筆者的個(gè)人經(jīng)驗(yàn)吐根,在從事大多數(shù)深度學(xué)習(xí)研究時(shí),程序都需要實(shí)現(xiàn)以下幾個(gè)功能:

  • 模型定義
  • 數(shù)據(jù)處理和加載
  • 訓(xùn)練模型(Train&Validate)
  • 訓(xùn)練過程的可視化
  • 測試(Test/Inference)

另外程序還應(yīng)該滿足以下幾個(gè)要求:

  • 模型需具有高度可配置性辐马,便于修改參數(shù)拷橘、修改模型,反復(fù)實(shí)驗(yàn)齐疙。
  • 代碼應(yīng)具有良好的組織結(jié)構(gòu)膜楷,使人一目了然。
  • 代碼應(yīng)具有良好的說明贞奋,使其他人能夠理解赌厅。

在之前的章節(jié)中,我們已經(jīng)講解了PyTorch中的絕大部分內(nèi)容轿塔。本章我們將應(yīng)用這些內(nèi)容特愿,并結(jié)合實(shí)際的例子,來講解如何用PyTorch完成Kaggle上的經(jīng)典比賽:Dogs vs. Cats勾缭。本文所有示例程序均在github上開源 揍障。

6.1.1 比賽介紹

Dogs vs. Cats是一個(gè)傳統(tǒng)的二分類問題,其訓(xùn)練集包含25000張圖片俩由,均放置在同一文件夾下毒嫡,命名格式為<category>.<num>.jpg, 如cat.10000.jpgdog.100.jpg幻梯,測試集包含12500張圖片兜畸,命名為<num>.jpg,如1000.jpg碘梢。參賽者需根據(jù)訓(xùn)練集的圖片訓(xùn)練模型咬摇,并在測試集上進(jìn)行預(yù)測,輸出它是狗的概率煞躬。最后提交的csv文件如下肛鹏,第一列是圖片的<num>逸邦,第二列是圖片為狗的概率。

id,label
10001,0.889
10002,0.01
...

image.png
6.1.2 文件組織架構(gòu)

前面提到過在扰,程序主要包含以下功能:

  • 模型定義
  • 數(shù)據(jù)加載
  • 訓(xùn)練和測試

首先來看程序文件的組織結(jié)構(gòu):

├── checkpoints/
├── data/
│   ├── __init__.py
│   ├── dataset.py
│   └── get_data.sh
├── models/
│   ├── __init__.py
│   ├── AlexNet.py
│   ├── BasicModule.py
│   └── ResNet34.py
└── utils/
│   ├── __init__.py
│   └── visualize.py
├── config.py
├── main.py
├── requirements.txt
├── README.md

其中:

  • checkpoints/: 用于保存訓(xùn)練好的模型缕减,可使程序在異常退出后仍能重新載入模型,恢復(fù)訓(xùn)練健田。
  • data/:數(shù)據(jù)相關(guān)操作烛卧,包括數(shù)據(jù)預(yù)處理、dataset實(shí)現(xiàn)等妓局。
  • models/:模型定義总放,可以有多個(gè)模型,例如上面的AlexNet和ResNet34好爬,一個(gè)模型對應(yīng)一個(gè)文件局雄。
  • utils/:可能用到的工具函數(shù),在本次實(shí)驗(yàn)中主要是封裝了可視化工具存炮。
  • config.py:配置文件炬搭,所有可配置的變量都集中在此,并提供默認(rèn)值穆桂。
  • main.py:主文件宫盔,訓(xùn)練和測試程序的入口,可通過不同的命令來指定不同的操作和參數(shù)享完。
  • requirements.txt:程序依賴的第三方庫灼芭。
  • README.md:提供程序的必要說明。
6.1.3 關(guān)于init.py

可以看到般又,幾乎每個(gè)文件夾下都有__init__.py彼绷,一個(gè)目錄如果包含了__init__.py 文件,那么它就變成了一個(gè)包(package)茴迁。__init__.py可以為空寄悯,也可以定義包的屬性和方法,但其必須存在堕义,其它程序才能從這個(gè)目錄中導(dǎo)入相應(yīng)的模塊或函數(shù)猜旬。例如在data/文件夾下有__init__.py,則在main.py 中就可以from data.dataset import DogCat倦卖。而如果在__init__.py中寫入from .dataset import DogCat洒擦,則在main.py中就可以直接寫為:from data import DogCat,或者import data; dataset = data.DogCat糖耸,相比于from data.dataset import DogCat更加便捷。

6.1.4 數(shù)據(jù)加載

數(shù)據(jù)的相關(guān)處理主要保存在data/dataset.py中丘薛。關(guān)于數(shù)據(jù)加載的相關(guān)操作嘉竟,在上一章中我們已經(jīng)提到過,其基本原理就是使用Dataset提供數(shù)據(jù)集的封裝,再使用Dataloader實(shí)現(xiàn)數(shù)據(jù)并行加載舍扰。Kaggle提供的數(shù)據(jù)包括訓(xùn)練集和測試集倦蚪,而我們在實(shí)際使用中,還需專門從訓(xùn)練集中取出一部分作為驗(yàn)證集边苹。對于這三類數(shù)據(jù)集陵且,其相應(yīng)操作也不太一樣,而如果專門寫三個(gè)Dataset个束,則稍顯復(fù)雜和冗余慕购,因此這里通過加一些判斷來區(qū)分。對于訓(xùn)練集茬底,我們希望做一些數(shù)據(jù)增強(qiáng)處理沪悲,如隨機(jī)裁剪、隨機(jī)翻轉(zhuǎn)阱表、加噪聲等殿如,而驗(yàn)證集和測試集則不需要。下面看dataset.py的代碼:

# coding:utf8
import os

from PIL import Image
from torch.utils import data
from torchvision import transforms as T


class DogCat(data.Dataset):

    def __init__(self, root, transforms=None, train=True, test=False):
        """
        主要目標(biāo): 獲取所有圖片的地址最爬,并根據(jù)訓(xùn)練涉馁,驗(yàn)證,測試劃分?jǐn)?shù)據(jù)
        """
        self.test = test
        imgs = [os.path.join(root, img) for img in os.listdir(root)]

        # test1: data/test1/8973.jpg
        # train: data/train/cat.10004.jpg 
        if self.test:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
        else:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))

        imgs_num = len(imgs)

        if self.test:
            self.imgs = imgs
        elif train:
            self.imgs = imgs[:int(0.7 * imgs_num)]
        else:
            self.imgs = imgs[int(0.7 * imgs_num):]

        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

            if self.test or not train:
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            else:
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])

    def __getitem__(self, index):
        """
        一次返回一張圖片的數(shù)據(jù)
        """
        img_path = self.imgs[index]
        if self.test:
            label = int(self.imgs[index].split('.')[-2].split('/')[-1])
        else:
            label = 1 if 'dog' in img_path.split('/')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        return data, label

    def __len__(self):
        return len(self.imgs)

關(guān)于數(shù)據(jù)集使用的注意事項(xiàng)爱致,在上一章中已經(jīng)提到烤送,將文件讀取等費(fèi)時(shí)操作放在__getitem__函數(shù)中,利用多進(jìn)程加速蒜鸡。避免一次性將所有圖片都讀進(jìn)內(nèi)存胯努,不僅費(fèi)時(shí)也會占用較大內(nèi)存,而且不易進(jìn)行數(shù)據(jù)增強(qiáng)等操作逢防。另外在這里叶沛,我們將訓(xùn)練集中的30%作為驗(yàn)證集,可用來檢查模型的訓(xùn)練效果忘朝,避免過擬合灰署。在使用時(shí),我們可通過dataloader加載數(shù)據(jù)局嘁。

train_dataset = DogCat(opt.train_data_root, train=True)
trainloader = DataLoader(train_dataset,
                        batch_size = opt.batch_size,
                        shuffle = True,
                        num_workers = opt.num_workers)
                  
for ii, (data, label) in enumerate(trainloader):
    train()
6.1.5 模型定義

模型的定義主要保存在models/目錄下溉箕,其中BasicModule是對nn.Module的簡易封裝,提供快速加載和保存模型的接口悦昵。

# coding:utf8
import time

import torch as t


class BasicModule(t.nn.Module):
    """
    封裝了nn.Module,主要是提供了save和load兩個(gè)方法
    """

    def __init__(self):
        super(BasicModule, self).__init__()
        self.model_name = str(type(self))  # 默認(rèn)名字

    def load(self, path):
        """
        可加載指定路徑的模型
        """
        self.load_state_dict(t.load(path))

    def save(self, name=None):
        """
        保存模型肴茄,默認(rèn)使用“模型名字+時(shí)間”作為文件名
        """
        if name is None:
            prefix = 'checkpoints/' + self.model_name + '_'
            name = time.strftime(prefix + '%Y%m%d%H%M%S.pth')
        t.save(self.state_dict(), name)
        return name

    def get_optimizer(self, lr, weight_decay):
        return t.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)


class Flat(t.nn.Module):
    """
    把輸入reshape成(batch_size,dim_length)
    """

    def __init__(self):
        super(Flat, self).__init__()
        # self.size = size

    def forward(self, x):
        return x.view(x.size(0), -1)

在實(shí)際使用中,直接調(diào)用model.save()model.load(opt.load_path)即可但指。

其它自定義模型一般繼承BasicModule寡痰,然后實(shí)現(xiàn)自己的模型抗楔。其中AlexNet.py實(shí)現(xiàn)了AlexNet,ResNet34實(shí)現(xiàn)了ResNet34拦坠。在models/__init__py中连躏,代碼如下:

from .AlexNet import AlexNet
from .ResNet34 import ResNet34

這樣在主函數(shù)中就可以寫成:

from models import AlexNet
或
import models
model = models.AlexNet()
或
import models
model = getattr('models', 'AlexNet')()

其中最后一種寫法最為關(guān)鍵,這意味著我們可以通過字符串直接指定使用的模型贞滨,而不必使用判斷語句入热,也不必在每次新增加模型后都修改代碼。新增模型后只需要在models/__init__.py中加上from .new_module import new_module即可晓铆。

其它關(guān)于模型定義的注意事項(xiàng)勺良,在上一章中已詳細(xì)講解,這里就不再贅述尤蒿,總結(jié)起來就是:

  • 盡量使用nn.Sequential(比如AlexNet)郑气。
  • 將經(jīng)常使用的結(jié)構(gòu)封裝成子Module(比如GoogLeNet的Inception結(jié)構(gòu),ResNet的Residual Block結(jié)構(gòu))腰池。
  • 將重復(fù)且有規(guī)律性的結(jié)構(gòu)尾组,用函數(shù)生成(比如VGG的多種變體,ResNet多種變體都是由多個(gè)重復(fù)卷積層組成)示弓。
6.1.6 工具函數(shù)

在項(xiàng)目中讳侨,我們可能會用到一些helper方法,這些方法可以統(tǒng)一放在utils/文件夾下奏属,需要使用時(shí)再引入跨跨。在本例中主要是封裝了可視化工具visdom的一些操作,其代碼如下囱皿,在本次實(shí)驗(yàn)中只會用到plot方法勇婴,用來統(tǒng)計(jì)損失信息。

# coding:utf8
import time

import numpy as np
import visdom


class Visualizer(object):
    """
    封裝了visdom的基本操作嘱腥,但是你仍然可以通過`self.vis.function`
    調(diào)用原生的visdom接口
    """

    def __init__(self, env='default', **kwargs):
        self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs)

        # 畫的第幾個(gè)數(shù)耕渴,相當(dāng)于橫座標(biāo)
        # 保存(’loss',23) 即loss的第23個(gè)點(diǎn)
        self.index = {}
        self.log_text = ''

    def reinit(self, env='default', **kwargs):
        """
        修改visdom的配置
        """
        self.vis = visdom.Visdom(env=env, **kwargs)
        return self

    def plot_many(self, d):
        """
        一次plot多個(gè)
        @params d: dict (name,value) i.e. ('loss',0.11)
        """
        for k, v in d.items():
            self.plot(k, v)

    def img_many(self, d):
        for k, v in d.items():
            self.img(k, v)

    def plot(self, name, y, **kwargs):
        """
        self.plot('loss',1.00)
        """
        x = self.index.get(name, 0)
        self.vis.line(Y=np.array([y]), X=np.array([x]),
                      win=name,
                      opts=dict(title=name),
                      update=None if x == 0 else 'append',
                      **kwargs
                      )
        self.index[name] = x + 1

    def img(self, name, img_, **kwargs):
        """
        self.img('input_img',t.Tensor(64,64))
        self.img('input_imgs',t.Tensor(3,64,64))
        self.img('input_imgs',t.Tensor(100,1,64,64))
        self.img('input_imgs',t.Tensor(100,3,64,64),nrows=10)

        !3萃谩橱脸!don‘t ~~self.img('input_imgs',t.Tensor(100,64,64),nrows=10)~~!7治添诉!
        """
        self.vis.images(img_.cpu().numpy(),
                        win=name,
                        opts=dict(title=name),
                        **kwargs
                        )

    def log(self, info, win='log_text'):
        """
        self.log({'loss':1,'lr':0.0001})
        """

        self.log_text += ('[{time}] {info} <br>'.format(
            time=time.strftime('%Y%m%d %H:%M:%S'),
            info=info))
        self.vis.text(self.log_text, win)

    def __getattr__(self, name):
        return getattr(self.vis, name)

6.1.7 配置文件

在模型定義、數(shù)據(jù)處理和訓(xùn)練等過程都有很多變量医寿,這些變量應(yīng)提供默認(rèn)值栏赴,并統(tǒng)一放置在配置文件中,這樣在后期調(diào)試靖秩、修改代碼或遷移程序時(shí)會比較方便须眷,在這里我們將所有可配置項(xiàng)放在config.py中乌叶。

# coding:utf8
import warnings

import torch as t


class DefaultConfig(object):
    env = 'default'  # visdom 環(huán)境
    vis_port = 8097  # visdom 端口
    model = 'SqueezeNet'  # 使用的模型,名字必須與models/__init__.py中的名字一致

    train_data_root = './data/train/'  # 訓(xùn)練集存放路徑
    test_data_root = './data/test/'  # 測試集存放路徑
    load_model_path = None  # 加載預(yù)訓(xùn)練的模型的路徑柒爸,為None代表不加載

    batch_size = 32  # batch size
    use_gpu = True  # user GPU or not
    num_workers = 0  # how many workers for loading data
    print_freq = 20  # print info every N batch

    debug_file = './debug/debug.txt'  # if os.path.exists(debug_file): enter ipdb
    result_file = 'result.csv'

    max_epoch = 10
    lr = 0.001  # initial learning rate
    lr_decay = 0.5  # when val_loss increase, lr = lr*lr_decay
    weight_decay = 0e-5  # 損失函數(shù)

    def _parse(self, kwargs):
        """
        根據(jù)字典kwargs 更新 config參數(shù)
        """
        for k, v in kwargs.items():
            if not hasattr(self, k):
                warnings.warn("Warning: opt has not attribut %s" % k)
            setattr(self, k, v)

        opt.device = t.device('cuda') if opt.use_gpu else t.device('cpu')

        print('user config:')
        for k, v in self.__class__.__dict__.items():
            if not k.startswith('_'):
                print(k, getattr(self, k))


opt = DefaultConfig()

可配置的參數(shù)主要包括:

  • 數(shù)據(jù)集參數(shù)(文件路徑、batch_size等)
  • 訓(xùn)練參數(shù)(學(xué)習(xí)率事扭、訓(xùn)練epoch等)
  • 模型參數(shù)

這樣我們在程序中就可以這樣使用:

import models
from config import DefaultConfig

opt = DefaultConfig()
lr = opt.lr
model = getattr(models, opt.model)
dataset = DogCat(opt.train_data_root)

這些都只是默認(rèn)參數(shù)捎稚,在這里還提供了更新函數(shù),根據(jù)字典更新配置參數(shù)求橄。

def _parse(self, kwargs):
        """
        根據(jù)字典kwargs 更新 config參數(shù)
        """
        for k, v in kwargs.items():
            if not hasattr(self, k):
                warnings.warn("Warning: opt has not attribut %s" % k)
            setattr(self, k, v)

        opt.device = t.device('cuda') if opt.use_gpu else t.device('cpu')

        print('user config:')
        for k, v in self.__class__.__dict__.items():
            if not k.startswith('_'):
                print(k, getattr(self, k))

這樣我們在實(shí)際使用時(shí)今野,并不需要每次都修改config.py,只需要通過命令行傳入所需參數(shù)罐农,覆蓋默認(rèn)配置即可条霜。

例如:

opt = DefaultConfig()
new_config = {'lr':0.1,'use_gpu':False}
opt.parse(new_config)
opt.lr == 0.1
6.1.8 main.py

在講解主程序main.py之前,我們先來看看2017年3月谷歌開源的一個(gè)命令行工具fire涵亏,通過pip install fire即可安裝宰睡。下面來看看fire的基礎(chǔ)用法,假設(shè)example.py文件內(nèi)容如下:

import fire

def add(x, y):
  return x + y
  
def mul(**kwargs):
    a = kwargs['a']
    b = kwargs['b']
    return a * b

if __name__ == '__main__':
  fire.Fire()

那么我們可以使用:

python example.py add 1 2 # 執(zhí)行add(1, 2)
python example.py mul --a=1 --b=2 # 執(zhí)行mul(a=1, b=2), kwargs={'a':1, 'b':2}
python example.py add --x=1 --y==2 # 執(zhí)行add(x=1, y=2)

可見气筋,只要在程序中運(yùn)行fire.Fire()拆内,即可使用命令行參數(shù)python file <function> [args,] {--kwargs,}。fire還支持更多的高級功能宠默,具體請參考官方指南《The Python Fire Guide》麸恍。

在主程序main.py中,主要包含四個(gè)函數(shù)搀矫,其中三個(gè)需要命令行執(zhí)行抹沪,main.py的代碼組織結(jié)構(gòu)如下:

def train(**kwargs):
    """
    訓(xùn)練
    """
    pass
     
def val(model, dataloader):
    """
    計(jì)算模型在驗(yàn)證集上的準(zhǔn)確率等信息,用以輔助訓(xùn)練
    """
    pass

def test(**kwargs):
    """
    測試(inference)
    """
    pass

def help():
    """
    打印幫助的信息 
    """
    print('help')

if __name__=='__main__':
    import fire
    fire.Fire()

根據(jù)fire的使用方法瓤球,可通過python main.py <function> --args=xx的方式來執(zhí)行訓(xùn)練或者測試融欧。

訓(xùn)練

訓(xùn)練的主要步驟如下:

  • 定義網(wǎng)絡(luò)
  • 定義數(shù)據(jù)
  • 定義損失函數(shù)和優(yōu)化器
  • 計(jì)算重要指標(biāo)
  • 開始訓(xùn)練
    • 訓(xùn)練網(wǎng)絡(luò)
    • 可視化各種指標(biāo)
    • 計(jì)算在驗(yàn)證集上的指標(biāo)

訓(xùn)練函數(shù)的代碼如下:

def train(**kwargs):
    opt._parse(kwargs)
    vis = Visualizer(opt.env, port=opt.vis_port)

    # step1: configure model
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    model.to(opt.device)

    # step2: data
    train_data = DogCat(opt.train_data_root, train=True)
    val_data = DogCat(opt.train_data_root, train=False)
    train_dataloader = DataLoader(train_data, opt.batch_size,
                                  shuffle=True, num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data, opt.batch_size,
                                shuffle=False, num_workers=opt.num_workers)

    # step3: criterion and optimizer
    criterion = t.nn.CrossEntropyLoss()
    lr = opt.lr
    optimizer = model.get_optimizer(lr, opt.weight_decay)

    # step4: meters
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(2)
    previous_loss = 1e10

    # train
    for epoch in range(opt.max_epoch):
        loss_meter.reset()
        confusion_matrix.reset()
        for ii, (data, label) in tqdm(enumerate(train_dataloader)):

            # train model 
            input = data.to(opt.device)
            target = label.to(opt.device)

            optimizer.zero_grad()
            score = model(input)
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()

            # meters update and visualize
            loss_meter.add(loss.item())
            # detach 一下更安全保險(xiǎn)
            confusion_matrix.add(score.detach(), target.detach())

            if (ii + 1) % opt.print_freq == 0:
                vis.plot('loss', loss_meter.value()[0])
                print("loss:", loss_meter.value()[0])

                # 進(jìn)入debug模式
                # if os.path.exists(opt.debug_file):
                #     import ipdb;
                #     ipdb.set_trace()
        print("保存檢查點(diǎn)...")
        model.save()

        cm_value = confusion_matrix.value()
        vis.plot('train_accuracy', 100. * (cm_value[0][0] + cm_value[1][1]) / cm_value.sum())

        # validate and visualize
        val_cm, val_accuracy = val(model, val_dataloader)
        vis.plot('val_accuracy', val_accuracy)
        vis.log("\tepoch:{epoch},\tlr:{lr},\tloss:{loss},\ttrain_cm:{train_cm},\tval_cm:{val_cm}\t".format(
            epoch=epoch, lr=lr, loss=loss_meter.value()[0], train_cm=str(confusion_matrix.value()),
            val_cm=str(val_cm.value())))

        # update learning rate
        if loss_meter.value()[0] > previous_loss:
            lr = lr * opt.lr_decay
            # 第二種降低學(xué)習(xí)率的方法:不會有moment等信息的丟失
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        previous_loss = loss_meter.value()[0]

這里用到了PyTorchNet里面的一個(gè)工具: meter。meter提供了一些輕量級的工具冰垄,用于幫助用戶快速統(tǒng)計(jì)訓(xùn)練過程中的一些指標(biāo)蹬癌。AverageValueMeter能夠計(jì)算所有數(shù)的平均值和標(biāo)準(zhǔn)差,這里用來統(tǒng)計(jì)一個(gè)epoch中損失的平均值虹茶。confusionmeter用來統(tǒng)計(jì)分類問題中的分類情況逝薪,是一個(gè)比準(zhǔn)確率更詳細(xì)的統(tǒng)計(jì)指標(biāo)。例如對于表格6-1蝴罪,共有50張狗的圖片董济,其中有35張被正確分類成了狗,還有15張被誤判成貓要门;共有100張貓的圖片虏肾,其中有91張被正確判為了貓廓啊,剩下9張被誤判成狗。相比于準(zhǔn)確率等統(tǒng)計(jì)信息封豪,混淆矩陣更能體現(xiàn)分類的結(jié)果谴轮,尤其是在樣本比例不均衡的情況下。

表6-1 混淆矩陣

樣本 判為狗 判為貓
實(shí)際是狗 35 15
實(shí)際是貓 9 91

PyTorchNet從TorchNet遷移而來吹埠,提供了很多有用的工具第步,但其目前開發(fā)和文檔都還不是很完善,本書不做過多的講解缘琅。

驗(yàn)證

驗(yàn)證相對來說比較簡單粘都,但要注意需將模型置于驗(yàn)證模式(model.eval()),驗(yàn)證完成后還需要將其置回為訓(xùn)練模式(model.train())刷袍,這兩句代碼會影響BatchNormDropout等層的運(yùn)行模式翩隧。驗(yàn)證模型準(zhǔn)確率的代碼如下。

@t.no_grad()
def val(model, dataloader):
    """
    計(jì)算模型在驗(yàn)證集上的準(zhǔn)確率等信息
    """
    model.eval()
    confusion_matrix = meter.ConfusionMeter(2)
    for ii, (val_input, label) in tqdm(enumerate(dataloader)):
        val_input = val_input.to(opt.device)
        score = model(val_input)
        confusion_matrix.add(score.detach().squeeze(), label.type(t.LongTensor))

    model.train()
    cm_value = confusion_matrix.value()
    accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) / (cm_value.sum())
    return confusion_matrix, accuracy

測試

測試時(shí)呻纹,需要計(jì)算每個(gè)樣本屬于狗的概率堆生,并將結(jié)果保存成csv文件。測試的代碼與驗(yàn)證比較相似雷酪,但需要自己加載模型和數(shù)據(jù)顽频。

@t.no_grad()  # pytorch>=0.5
def test(**kwargs):
    opt._parse(kwargs)

    # configure model
    model = getattr(models, opt.model)().eval()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    model.to(opt.device)

    # data
    train_data = DogCat(opt.test_data_root, test=True)
    test_dataloader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers)
    results = []
    for ii, (data, path) in tqdm(enumerate(test_dataloader)):
        input = data.to(opt.device)
        score = model(input)
        probability = t.nn.functional.softmax(score, dim=1)[:, 0].detach().tolist()

        batch_results = [(path_.item(), probability_) for path_, probability_ in zip(path, probability)]

        results += batch_results
    write_csv(results, opt.result_file)

    return results


def write_csv(results, file_name):
    import csv
    with open(file_name, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(['id', 'label'])
        writer.writerows(results)

幫助函數(shù)

為了方便他人使用, 程序中還應(yīng)當(dāng)提供一個(gè)幫助函數(shù),用于說明函數(shù)是如何使用太闺。程序的命令行接口中有眾多參數(shù)糯景,如果手動(dòng)用字符串表示不僅復(fù)雜,而且后期修改config文件時(shí)省骂,還需要修改對應(yīng)的幫助信息蟀淮,十分不便。這里使用了Python標(biāo)準(zhǔn)庫中的inspect方法钞澳,可以自動(dòng)獲取config的源代碼怠惶。help的代碼如下:

def help():
    """
    打印幫助的信息: python file.py help
    """

    print("""
    usage : python file.py <function> [--args=value]
    <function> := train | test | help
    example: 
            python {0} train --env='env0701' --lr=0.01
            python {0} test --dataset='path/to/dataset/root/'
            python {0} help
    avaiable args:""".format(__file__))

    from inspect import getsource
    source = (getsource(opt.__class__))
    print(source)

當(dāng)用戶執(zhí)行python main.py help的時(shí)候,會打印如下幫助信息:

    usage : python main.py <function> [--args=value,]
    <function> := train | test | help
    example: 
            python main.py train --env='env0701' --lr=0.01
            python main.py test --dataset='path/to/dataset/'
            python main.py help
    avaiable args:
class DefaultConfig(object):
    env = 'default' # visdom 環(huán)境
    model = 'AlexNet' # 使用的模型
    
    train_data_root = './data/train/' # 訓(xùn)練集存放路徑
    test_data_root = './data/test' # 測試集存放路徑
    load_model_path = 'checkpoints/model.pth' # 加載預(yù)訓(xùn)練的模型

    batch_size = 128 # batch size
    use_gpu = True # user GPU or not
    num_workers = 4 # how many workers for loading data
    print_freq = 20 # print info every N batch

    debug_file = './debug/debug.txt' 
    result_file = 'result.csv' # 結(jié)果文件
      
    max_epoch = 10
    lr = 0.1 # initial learning rate
    lr_decay = 0.95 # when val_loss increase, lr = lr*lr_decay
    weight_decay = 1e-4 # 損失函數(shù)
6.1.9 使用

正如help函數(shù)的打印信息所述轧粟,可以通過命令行參數(shù)指定變量名.下面是三個(gè)使用例子策治,fire會將包含-的命令行參數(shù)自動(dòng)轉(zhuǎn)層下劃線_,也會將非數(shù)值的值轉(zhuǎn)成字符串兰吟。所以--train-data-root=data/train--train_data_root='data/train'是等價(jià)的通惫。

# 訓(xùn)練模型
python main.py train 
        --train-data-root=data/train/ 
        --lr=0.005 
        --batch-size=32 
        --model='ResNet34'  
        --max-epoch = 20

# 測試模型
python main.py test
       --test-data-root=data/test 
       --load-model-path='checkpoints/resnet34_00:23:05.pth' 
       --batch-size=128 
       --model='ResNet34' 
       --num-workers=12

# 打印幫助信息
python main.py help

實(shí)驗(yàn)過程

本章程序及數(shù)據(jù)下載:百度網(wǎng)盤,提取碼:aw26混蔼。

首先履腋,在命令行cmd紅啟動(dòng)visdom服務(wù)器:

python -m visdom.server

然后,訓(xùn)練模型:

python main.py train

訓(xùn)練結(jié)果如下:

image.png

從上述結(jié)果可以看出,模型的精度可以達(dá)到97%以上遵湖。你也可以手動(dòng)更改模型悔政,通過調(diào)節(jié)參數(shù)來進(jìn)一步提升模型的準(zhǔn)確率。

最后延旧,測試模型:

python main.py test

第二列表示預(yù)測為狗的概率:

image.png

我們來看一下測試集圖片:

image.png

可以看到谋国,模型能夠正確識別出很多狗和貓了,但是還存在很大的改進(jìn)空間迁沫。

6.1.10 爭議

以上的程序設(shè)計(jì)規(guī)范帶有作者強(qiáng)烈的個(gè)人喜好烹卒,并不想作為一個(gè)標(biāo)準(zhǔn),而是作為一個(gè)提議和一種參考弯洗。上述設(shè)計(jì)在很多地方還有待商榷,例如對于訓(xùn)練過程是否應(yīng)該封裝成一個(gè)trainer對象逢勾,或者直接封裝到BaiscModuletrain方法之中牡整。對命令行參數(shù)的處理也有不少值得討論之處。因此不要將本文中的觀點(diǎn)作為一個(gè)必須遵守的規(guī)范溺拱,而應(yīng)該看作一個(gè)參考逃贝。

本章中的設(shè)計(jì)可能會引起不少爭議,其中比較值得商榷的部分主要有以下兩個(gè)方面:

  • 命令行參數(shù)的設(shè)置做修。目前大多數(shù)程序都是使用Python標(biāo)準(zhǔn)庫中的argparse來處理命令行參數(shù)硝岗,也有些使用比較輕量級的click涝涤。這種處理相對來說對命令行的支持更完備,但根據(jù)作者的經(jīng)驗(yàn)來看沪摄,這種做法不夠直觀,并且代碼量相對來說也較多纱烘。比如argparse杨拐,每次增加一個(gè)命令行參數(shù),都必須寫如下代碼:
parser.add_argument('-save-interval', type=int, default=500, help='how many steps to wait before saving [default:500]')

在讀者眼中擂啥,這種實(shí)現(xiàn)方式遠(yuǎn)不如一個(gè)專門的config.py來的直觀和易用哄陶。尤其是對于使用Jupyter notebook或IPython等交互式調(diào)試的用戶來說,argparse較難使用哺壶。

  • 模型訓(xùn)練屋吨。有不少人喜歡將模型的訓(xùn)練過程集成于模型的定義之中,代碼結(jié)構(gòu)如下所示:
  class MyModel(nn.Module):
    
      def __init__(self,opt):
          self.dataloader = Dataloader(opt)
          self.optimizer  = optim.Adam(self.parameters(),lr=0.001)
          self.lr = opt.lr
          self.model = make_model()
      
      def forward(self,input):
          pass
      
      def train_(self):
          # 訓(xùn)練模型
          for epoch in range(opt.max_epoch)
            for ii,data in enumerate(self.dataloader):
                train_epoch()
              
            model.save()
    
      def train_epoch(self):
          pass

抑或是專門設(shè)計(jì)一個(gè)Trainer對象山宾,形如:

    """
  code simplified from:
  https://github.com/pytorch/pytorch/blob/master/torch/utils/trainer/trainer.py
  """
  import heapq
  from torch.autograd import Variable

  class Trainer(object):

      def __init__(self, model=None, criterion=None, optimizer=None, dataset=None):
          self.model = model
          self.criterion = criterion
          self.optimizer = optimizer
          self.dataset = dataset
          self.iterations = 0

      def run(self, epochs=1):
          for i in range(1, epochs + 1):
              self.train()

      def train(self):
          for i, data in enumerate(self.dataset, self.iterations + 1):
              batch_input, batch_target = data
              self.call_plugins('batch', i, batch_input, batch_target)
              input_var = Variable(batch_input)
              target_var = Variable(batch_target)
    
              plugin_data = [None, None]
    
              def closure():
                  batch_output = self.model(input_var)
                  loss = self.criterion(batch_output, target_var)
                  loss.backward()
                  if plugin_data[0] is None:
                      plugin_data[0] = batch_output.data
                      plugin_data[1] = loss.data
                  return loss
    
              self.optimizer.zero_grad()
              self.optimizer.step(closure)
    
          self.iterations += i

還有一些人喜歡模仿keras和scikit-learn的設(shè)計(jì)至扰,設(shè)計(jì)一個(gè)fit接口。對讀者來說资锰,這些處理方式很難說哪個(gè)更好或更差渊胸,找到最適合自己的方法才是最好的。

BasicModule 的封裝台妆,可多可少翎猛。訓(xùn)練過程中的很多操作都可以移到BasicModule之中胖翰,比如get_optimizer方法用來獲取優(yōu)化器,比如train_step用來執(zhí)行單歩訓(xùn)練切厘。對于不同的模型萨咳,如果對應(yīng)的優(yōu)化器定義不一樣,或者是訓(xùn)練方法不一樣疫稿,可以復(fù)寫這些函數(shù)自定義相應(yīng)的方法培他,取決于自己的喜好和項(xiàng)目的實(shí)際需求。

6.2 PyTorch Debug指南

6.2.1 ipdb介紹

很多初學(xué)者用print或log調(diào)試程序遗座,這在小規(guī)模的程序下很方便舀凛。但是更好的調(diào)試方法是一邊運(yùn)行一邊檢查里面的變量和方法。pdb是一個(gè)交互式的調(diào)試工具途蒋,集成于Python的標(biāo)準(zhǔn)庫之中猛遍,由于其強(qiáng)大的功能,被廣泛應(yīng)用于Python環(huán)境中号坡。pdb能讓你根據(jù)需求跳轉(zhuǎn)到任意的Python代碼斷點(diǎn)懊烤、查看任意變量、單步執(zhí)行代碼宽堆,甚至還能修改代碼的值腌紧,而不必重啟程序。ipdb是一個(gè)增強(qiáng)版的pdb畜隶,可通過pip install ipdb安裝壁肋。ipdb提供了調(diào)試模式下的代碼補(bǔ)全,還具有更好的語法高亮和代碼溯源籽慢,以及更好的內(nèi)省功能墩划,更關(guān)鍵的是,它與pdb接口完全兼容嗡综。

在本書第2章曾粗略地提到過ipdb的基本使用乙帮,本章將繼續(xù)介紹如何結(jié)合PyTorch和ipdb進(jìn)行調(diào)試。首先看一個(gè)例子极景,要是用ipdb察净,只需在想要進(jìn)行調(diào)試的地方插入ipdb.set_trace(),當(dāng)代碼運(yùn)行到此處時(shí)盼樟,就會自動(dòng)進(jìn)入交互式調(diào)試模式氢卡。

假設(shè)有如下程序:

try:
    import ipdb
except:
    import pdb as ipdb

def sum(x):
    r = 0
    for ii in x:
        r += ii
    return r
    
def mul(x):
    r = 1
    for ii in x:
        r *= ii
    return r
    
ipdf.set_trace()
x = [1,2,3,4,5]
r = sum(x)
r = mul(x)

當(dāng)程序運(yùn)行至ipdb.set_trace(),會自動(dòng)進(jìn)入debug模式晨缴,在該模式中译秦,我們可使用調(diào)試命令,如next或縮寫n單步執(zhí)行,也可查看Python變量筑悴,或是運(yùn)行Python代碼们拙。如果Python變量名和調(diào)式命令沖突,需要在變量名前加"!"阁吝,這樣ipdb會執(zhí)行對應(yīng)的Python代碼砚婆,而不是調(diào)試命令。下面舉例說明ipdb的調(diào)試突勇,這里重點(diǎn)講解ipdb的兩大功能装盯。

  • 查看:在函數(shù)調(diào)用堆棧中自由跳轉(zhuǎn),并查看函數(shù)的局部變量
  • 修改:修改程序中的變量甲馋,并能以此影響程序的運(yùn)行結(jié)果埂奈。
> e:\debug.py(19)<module>()
     18 ipdb.set_trace()
---> 19 x = [1,2,3,4,5]
     20 r = sum(x)

ipdb> l 1,21    # list 1,21的縮寫,查看第1行到第21行的代碼定躏,光標(biāo)所指的這一行尚未運(yùn)行
      1 try:
      2         import ipdb
      3 except:
      4         import pdb as ipdb
      5
      6 def sum(x):
      7         r = 0
      8         for ii in x:
      9                 r += ii
     10         return r
     11
     12 def mul(x):
     13         r = 1
     14         for ii in x:
     15                 r *= ii
     16         return r
     17
     18 ipdb.set_trace()
---> 19 x = [1,2,3,4,5]
     20 r = sum(x)
     21 r = mul(x)

ipdb> n    # next的縮寫账磺,執(zhí)行下一步
> e:\debug.py(20)<module>()
     19 x = [1,2,3,4,5]
---> 20 r = sum(x)
     21 r = mul(x)

ipdb> s    # step的縮寫,進(jìn)入sum函數(shù)內(nèi)部
--Call--
> e:\debug.py(6)sum()
      5
----> 6 def sum(x):
      7         r = 0

ipdb> n    # next單步執(zhí)行
> e:\debug.py(7)sum()
      6 def sum(x):
----> 7         r = 0
      8         for ii in x:

ipdb> n    # next單步執(zhí)行
> e:\debug.py(8)sum()
      7         r = 0
----> 8         for ii in x:
      9                 r += ii

ipdb> n    # next單步執(zhí)行
> e:\debug.py(9)sum()
      8         for ii in x:
----> 9                 r += ii
     10         return r

ipdb> u    # up的縮寫共屈,跳回到上一層的調(diào)用
> e:\debug.py(20)<module>()
     19 x = [1,2,3,4,5]
---> 20 r = sum(x)
     21 r = mul(x)

ipdb> d    # down的縮寫,跳到調(diào)用的下一層
> e:\debug.py(9)sum()
      8         for ii in x:
----> 9                 r += ii
     10         return r

ipdb> !r    # !r 查看變量r的值党窜,該變量名與調(diào)試命令`r(eturn)`沖突 
0

ipdb> r    # return的縮寫拗引,繼續(xù)運(yùn)行直到函數(shù)返回
--Return--
15
> e:\debug.py(10)sum()
      9                 r += ii
---> 10         return r
     11

ipdb> n    # 下一步
> e:\debug.py(21)<module>()
     19 x = [1,2,3,4,5]
     20 r = sum(x)
---> 21 r = mul(x)

ipdb> x    # 查看變量x的值
[1, 2, 3, 4, 5]

ipdb> x[0] = 10000    # 修改變量x的值

ipdb> b 13    # break的縮寫,AI第13行設(shè)置斷點(diǎn)
Breakpoint 1 at e:\debug.py:13

ipdb> c    # continue的縮寫幌衣,繼續(xù)運(yùn)行矾削,直到遇到斷點(diǎn)
> e:\debug.py(13)mul()
     12 def mul(x):
1--> 13         r = 1
     14         for ii in x:

ipdb> return    # 返回的是修改后x的乘積
--Return--
1200000
> e:\debug.py(16)mul()
     15                 r *= ii
---> 16         return r
     17

ipdb> q    # quit的縮寫,退出debug模式
Exiting Debugger.

關(guān)于ipdb的使用還有一些技巧:

  • <tab>鍵能夠自動(dòng)補(bǔ)齊豁护,補(bǔ)齊用法與IPython中的類似哼凯。
  • j(ump) <lineno>能夠跳過中間某些行代碼的執(zhí)行
  • 可以直接在ipdb中修改變量的值
  • h(elp)能夠查看調(diào)試命令的用法,比如h h可以查看h(elp)命令的用法楚里,h jump能夠查看j(ump)命令的用法断部。
6.2.2 在PyTorch中Debug

PyTorch作為一個(gè)動(dòng)態(tài)圖框架,與ipdb結(jié)合使用能為調(diào)試過程帶來便捷班缎。對TensorFlow等靜態(tài)圖框架來說蝴光,使用Python接口定義計(jì)算圖,然后使用C++代碼執(zhí)行底層運(yùn)算达址,在定義圖的時(shí)候不進(jìn)行任何計(jì)算蔑祟,而在計(jì)算的時(shí)候又無法使用pdb進(jìn)行調(diào)試,因?yàn)閜db調(diào)試只能調(diào)試Python代碼沉唠,故調(diào)試一直是此類靜態(tài)圖框架的一個(gè)痛點(diǎn)疆虚。與TensorFlow不同,PyTorch可以在執(zhí)行計(jì)算的同時(shí)定義計(jì)算圖,這些計(jì)算定義過程是使用Python完成的径簿。雖然底層的計(jì)算也是用C/C++完成的罢屈,但是我們能夠查看Python定義部分的變量值,這就已經(jīng)足夠了牍帚。下面我們將舉例說明儡遮。

  • 如何AIPyTorch中查看神經(jīng)網(wǎng)絡(luò)各個(gè)層的輸出。
  • 如何在PyTorch中分析各個(gè)參數(shù)的梯度暗赶。
  • 如何動(dòng)態(tài)修改PyTorch的訓(xùn)練過程鄙币。

首先,運(yùn)行第一節(jié)給出的“貓狗大戰(zhàn)”程序:

python main.py train --debug-file='debug/debug.txt'

程序運(yùn)行一段時(shí)間后蹂随,在debug目錄下創(chuàng)建debug.txt標(biāo)識文件十嘿,當(dāng)程序檢測到這個(gè)文件存在時(shí),會自動(dòng)進(jìn)入debug模式岳锁。

99it [00:17,  6.07it/s]loss: 0.22854854568839075
119it [00:21,  5.79it/s]loss: 0.21267264398435753
139it [00:24,  5.99it/s]loss: 0.19839374726372108
> e:\workspace\python\pytorch\chapter6\main.py(80)train()
     79         loss_meter.reset()
---> 80         confusion_matrix.reset()
     81         for ii, (data, label) in tqdm(enumerate(train_dataloader)):

ipdb> break 88    # 在第88行設(shè)置斷點(diǎn)绩衷,當(dāng)程序運(yùn)行到此處進(jìn)入debug模式
Breakpoint 1 at e:\workspace\python\pytorch\chapter6\main.py:88

ipdb> # 打印所有參數(shù)及其梯度的標(biāo)準(zhǔn)差
for (name,p) in model.named_parameters(): \
    print(name,p.data.std(),p.grad.data.std())
model.features.0.weight tensor(0.2615, device='cuda:0') tensor(0.3769, device='cuda:0')
model.features.0.bias tensor(0.4862, device='cuda:0') tensor(0.3368, device='cuda:0')
model.features.3.squeeze.weight tensor(0.2738, device='cuda:0') tensor(0.3023, device='cuda:0')
model.features.3.squeeze.bias tensor(0.5867, device='cuda:0') tensor(0.3753, device='cuda:0')
model.features.3.expand1x1.weight tensor(0.2168, device='cuda:0') tensor(0.2883, device='cuda:0')
model.features.3.expand1x1.bias tensor(0.2256, device='cuda:0') tensor(0.1147, device='cuda:0')
model.features.3.expand3x3.weight tensor(0.0935, device='cuda:0') tensor(0.1605, device='cuda:0')
model.features.3.expand3x3.bias tensor(0.1421, device='cuda:0') tensor(0.0583, device='cuda:0')
model.features.4.squeeze.weight tensor(0.1976, device='cuda:0') tensor(0.2137, device='cuda:0')
model.features.4.squeeze.bias tensor(0.4058, device='cuda:0') tensor(0.1798, device='cuda:0')
model.features.4.expand1x1.weight tensor(0.2144, device='cuda:0') tensor(0.4214, device='cuda:0')
model.features.4.expand1x1.bias tensor(0.4994, device='cuda:0') tensor(0.0958, device='cuda:0')
model.features.4.expand3x3.weight tensor(0.1063, device='cuda:0') tensor(0.2963, device='cuda:0')
model.features.4.expand3x3.bias tensor(0.0489, device='cuda:0') tensor(0.0719, device='cuda:0')
model.features.6.squeeze.weight tensor(0.1736, device='cuda:0') tensor(0.3544, device='cuda:0')
model.features.6.squeeze.bias tensor(0.2420, device='cuda:0') tensor(0.0896, device='cuda:0')
model.features.6.expand1x1.weight tensor(0.1211, device='cuda:0') tensor(0.2428, device='cuda:0')
model.features.6.expand1x1.bias tensor(0.0670, device='cuda:0') tensor(0.0162, device='cuda:0')
model.features.6.expand3x3.weight tensor(0.0593, device='cuda:0') tensor(0.1917, device='cuda:0')
model.features.6.expand3x3.bias tensor(0.0227, device='cuda:0') tensor(0.0160, device='cuda:0')
model.features.7.squeeze.weight tensor(0.1207, device='cuda:0') tensor(0.2179, device='cuda:0')
model.features.7.squeeze.bias tensor(0.1484, device='cuda:0') tensor(0.0381, device='cuda:0')
model.features.7.expand1x1.weight tensor(0.1235, device='cuda:0') tensor(0.2279, device='cuda:0')
model.features.7.expand1x1.bias tensor(0.0450, device='cuda:0') tensor(0.0100, device='cuda:0')
model.features.7.expand3x3.weight tensor(0.0609, device='cuda:0') tensor(0.1628, device='cuda:0')
model.features.7.expand3x3.bias tensor(0.0132, device='cuda:0') tensor(0.0079, device='cuda:0')
model.features.9.squeeze.weight tensor(0.1093, device='cuda:0') tensor(0.2459, device='cuda:0')
model.features.9.squeeze.bias tensor(0.0646, device='cuda:0') tensor(0.0135, device='cuda:0')
model.features.9.expand1x1.weight tensor(0.0840, device='cuda:0') tensor(0.1860, device='cuda:0')
model.features.9.expand1x1.bias tensor(0.0177, device='cuda:0') tensor(0.0033, device='cuda:0')
model.features.9.expand3x3.weight tensor(0.0476, device='cuda:0') tensor(0.1393, device='cuda:0')
model.features.9.expand3x3.bias tensor(0.0058, device='cuda:0') tensor(0.0030, device='cuda:0')
model.features.10.squeeze.weight tensor(0.0872, device='cuda:0') tensor(0.1676, device='cuda:0')
model.features.10.squeeze.bias tensor(0.0484, device='cuda:0') tensor(0.0088, device='cuda:0')
model.features.10.expand1x1.weight tensor(0.0859, device='cuda:0') tensor(0.2145, device='cuda:0')
model.features.10.expand1x1.bias tensor(0.0160, device='cuda:0') tensor(0.0025, device='cuda:0')
model.features.10.expand3x3.weight tensor(0.0456, device='cuda:0') tensor(0.1429, device='cuda:0')
model.features.10.expand3x3.bias tensor(0.0070, device='cuda:0') tensor(0.0021, device='cuda:0')
model.features.11.squeeze.weight tensor(0.0786, device='cuda:0') tensor(0.2003, device='cuda:0')
model.features.11.squeeze.bias tensor(0.0422, device='cuda:0') tensor(0.0069, device='cuda:0')
model.features.11.expand1x1.weight tensor(0.0690, device='cuda:0') tensor(0.1400, device='cuda:0')
model.features.11.expand1x1.bias tensor(0.0138, device='cuda:0') tensor(0.0022, device='cuda:0')
model.features.11.expand3x3.weight tensor(0.0366, device='cuda:0') tensor(0.1517, device='cuda:0')
model.features.11.expand3x3.bias tensor(0.0109, device='cuda:0') tensor(0.0023, device='cuda:0')
model.features.12.squeeze.weight tensor(0.0729, device='cuda:0') tensor(0.1736, device='cuda:0')
model.features.12.squeeze.bias tensor(0.0814, device='cuda:0') tensor(0.0084, device='cuda:0')
model.features.12.expand1x1.weight tensor(0.0977, device='cuda:0') tensor(0.1385, device='cuda:0')
model.features.12.expand1x1.bias tensor(0.0102, device='cuda:0') tensor(0.0032, device='cuda:0')
model.features.12.expand3x3.weight tensor(0.0365, device='cuda:0') tensor(0.1312, device='cuda:0')
model.features.12.expand3x3.bias tensor(0.0038, device='cuda:0') tensor(0.0026, device='cuda:0')
model.classifier.1.weight tensor(0.0285, device='cuda:0') tensor(0.0865, device='cuda:0')
model.classifier.1.bias tensor(0.0362, device='cuda:0') tensor(0.0192, device='cuda:0')

ipdb> opt.lr    # 查看學(xué)習(xí)率
0.001

ipdb> opt.lr = 0.002    # 更改學(xué)習(xí)率

ipdb> for p in optimizer.param_groups: \
    p['lr'] = opt.lr

ipdb> model.save()    # 保存模型
'checkpoints/squeezenet_20191004212249.pth'

ipdb> c    # 繼續(xù)運(yùn)行,直到第88行暫停
222it [16:38, 35.62s/it]> e:\workspace\python\pytorch\chapter6\main.py(88)train()
     87             optimizer.zero_grad()
1--> 88             score = model(input)
     89             loss = criterion(score, target)

ipdb> s    # 進(jìn)入model(input)內(nèi)部激率,即model.__call__(input)
--Call--
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(537)__call__()
    536 
--> 537     def __call__(self, *input, **kwargs):
    538         for hook in self._forward_pre_hooks.values():

ipdb> n    # 下一步
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(538)__call__()
    537     def __call__(self, *input, **kwargs):
--> 538         for hook in self._forward_pre_hooks.values():
    539             result = hook(self, input)

ipdb> n    # 下一步
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(544)__call__()
    543                 input = result
--> 544         if torch._C._get_tracing_state():
    545             result = self._slow_forward(*input, **kwargs)

ipdb> n    # 下一步
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(547)__call__()
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():

ipdb> s    # 進(jìn)入forward函數(shù)內(nèi)容
--Call--
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\loss.py(914)forward()
    913 
--> 914     def forward(self, input, target):
    915         return F.cross_entropy(input, target, weight=self.weight,

ipdb> input    # 查看input變量值
tensor([[4.5005, 2.0725],
        [3.5933, 7.8643],
        [2.9086, 3.4209],
        [2.7740, 4.4332],
        [6.0164, 2.3033],
        [5.2261, 3.2189],
        [2.6529, 2.0749],
        [6.3259, 2.2383],
        [3.0629, 3.4832],
        [2.7008, 8.2818],
        [5.5684, 2.1567],
        [3.0689, 6.1022],
        [3.4848, 5.3831],
        [1.7920, 5.7709],
        [6.5032, 2.8080],
        [2.3071, 5.2417],
        [3.7474, 5.0263],
        [4.3682, 3.6707],
        [2.2196, 6.9298],
        [5.2201, 2.3034],
        [6.4315, 1.4970],
        [3.4684, 4.0371],
        [3.9620, 1.7629],
        [1.7069, 7.8898],
        [3.0462, 1.6505],
        [2.4081, 6.4456],
        [2.1932, 7.4614],
        [2.3405, 2.7603],
        [1.9478, 8.4156],
        [2.7935, 7.8331],
        [1.8898, 3.8836],
        [3.3008, 1.6832]], device='cuda:0', grad_fn=<AsStridedBackward>)

ipdb> input.data.mean()    # 查看input的均值和標(biāo)準(zhǔn)差
tensor(3.9630, device='cuda:0')
ipdb> input.data.std()
tensor(1.9513, device='cuda:0')

ipdb> u    # 跳回上一層
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(547)__call__()
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():

ipdb> u    # 跳回上一層
> e:\workspace\python\pytorch\chapter6\main.py(88)train()
     87             optimizer.zero_grad()
1--> 88             score = model(input)
     89             loss = criterion(score, target)

ipdb> clear    # 清除所有斷點(diǎn)
Clear all breaks? y
Deleted breakpoint 1 at e:\workspace\python\pytorch\chapter6\main.py:88

ipdb> c    # 繼續(xù)運(yùn)行咳燕,記得先刪除"debug/debug.txt",否則很快又會進(jìn)入調(diào)試模式
59it [06:21,  5.75it/s]loss: 0.24856307208538073
76it [06:24,  5.91it/s]

當(dāng)我們想要進(jìn)入debug模式乒躺,修改程序中某些參數(shù)值或者想分析程序時(shí)招盲,就可以通過創(chuàng)建debug標(biāo)識文件,此時(shí)程序會進(jìn)入調(diào)試模式嘉冒,調(diào)試完成之后刪除這個(gè)文件并在ipdb調(diào)試接口輸入c繼續(xù)運(yùn)行程序曹货。如果想退出程序,也可以使用這種方式讳推,先創(chuàng)建debug標(biāo)識文件顶籽,然后輸入quit在退出debug的同時(shí)退出程序。這種退出程序的方式银觅,與使用Ctrl+C的方式相比更安全礼饱,因?yàn)檫@能保證數(shù)據(jù)加載的多進(jìn)程程序也能正確地退出,并釋放內(nèi)存究驴、顯存等資源慨仿。

PyTorch和ipdb集合能完成很多其他框架所不能完成或很難完成的功能。根據(jù)筆者日常使用的總結(jié)纳胧,主要有以下幾個(gè)部分:
(1)通過debug暫停程序镰吆。當(dāng)程序進(jìn)入debug模式后,將不再執(zhí)行PCU和GPU運(yùn)算跑慕,但是內(nèi)存和顯存及相應(yīng)的堆椡蛎螅空間不會釋放摧找。
(2)通過debug分析程序,查看每個(gè)層的輸出牢硅,查看網(wǎng)絡(luò)的參數(shù)情況蹬耘。通過u(p)、d(own)减余、s(tep)等命令综苔,能夠進(jìn)入指定的代碼,通過n(ext)可以單步執(zhí)行位岔,從而看到每一層的運(yùn)算結(jié)果如筛,便于分析網(wǎng)絡(luò)的數(shù)值分布等信息。
(3)作為動(dòng)態(tài)圖框架抒抬,PyTorch擁有Python動(dòng)態(tài)語言解釋執(zhí)行的優(yōu)點(diǎn)杨刨,我們能夠在運(yùn)行程序時(shí),用過ipdb修改某些變量的值或?qū)傩圆两#@些修改能夠立即生效妖胀。例如可以在訓(xùn)練開始不久根據(jù)損失函數(shù)調(diào)整學(xué)習(xí)率,不必重啟程序惠勒。
(4)如果在IPython中通過%run魔法方法運(yùn)行程序赚抡,那么在程序異常退出時(shí),可以使用%debug命令纠屋,直接進(jìn)入debug模式涂臣,通過u(p)和d(own)跳到報(bào)錯(cuò)的地方,查看對應(yīng)的變量巾遭,找出原因后修改相應(yīng)的代碼即可肉康。有時(shí)我們的模式訓(xùn)練了好幾個(gè)小時(shí)闯估,卻在將要保存模式之前灼舍,因?yàn)橐粋€(gè)小小的拼寫錯(cuò)誤異常退出。此時(shí)涨薪,如果修改錯(cuò)誤再重新運(yùn)行程序又要花費(fèi)好幾個(gè)小時(shí)骑素,太浪費(fèi)時(shí)間。因此最好的方法就是看利用%debug進(jìn)入調(diào)試模式刚夺,在調(diào)試模式中直接運(yùn)行model.save()保存模型献丑。在IPython中,%pdb魔術(shù)方法能夠使得程序出現(xiàn)問題后侠姑,不用手動(dòng)輸入%debug而自動(dòng)進(jìn)入debug模式创橄,建議使用。

PyTorch調(diào)用CuDNN報(bào)錯(cuò)時(shí)莽红,報(bào)錯(cuò)信息諸如CUDNN_STATUS_BAD_PARAM妥畏,從這些報(bào)錯(cuò)內(nèi)容很難得到有用的幫助信息,最后先利用PCU運(yùn)行代碼醉蚁,此時(shí)一般會得到相對友好的報(bào)錯(cuò)信息燃辖,例如在ipdb中執(zhí)行model.cpu()(input.cpu()),PyTorch底層的TH庫會給出相對比較詳細(xì)的信息网棍。

常見的錯(cuò)誤主要有以下幾種:

  • 類型不匹配問題黔龟。例如CrossEntropyLoss的輸入target應(yīng)該是一個(gè)LongTensor,而很多人輸入FloatTensor滥玷。
  • 部分?jǐn)?shù)據(jù)忘記從CPU轉(zhuǎn)移到GPU氏身。例如,當(dāng)model存放于GPU時(shí)罗捎,輸入input也需要轉(zhuǎn)移到GPU才能輸入到model中观谦。還有可能就是把多個(gè)model存放于一個(gè)list對象,而在執(zhí)行model.cuda()時(shí)桨菜,這個(gè)list中的對象是不會被轉(zhuǎn)移到CUDA上的豁状,正確的用法是用ModuleList代替。
  • Tensor形狀不匹配倒得。此類問題一般是輸入數(shù)據(jù)形狀不對泻红,或是網(wǎng)絡(luò)結(jié)構(gòu)設(shè)計(jì)有問題,一般通過u(p)跳到指定代碼霞掺,查看輸入和模型參數(shù)的形狀即可得知谊路。

此外,可能還會經(jīng)常遇到程序正常運(yùn)行菩彬、沒有報(bào)錯(cuò)缠劝,但是模型無法收斂的問題。例如對于二分類問題骗灶,交叉熵?fù)p失一直徘徊在0.69附近(ln2)惨恭,或者是數(shù)值出現(xiàn)溢出等問題,此時(shí)可以進(jìn)入debug模式耙旦,用單步執(zhí)行查看脱羡,每一層輸出的均值和方差,觀察從哪一層的輸出開始出現(xiàn)數(shù)值異常免都。還要查看每個(gè)參數(shù)梯度的均值和方差锉罐,查看是否出現(xiàn)梯度消失或者梯度爆炸等問題。一般來說绕娘,通過再激活函數(shù)之前增加BatchNorm層脓规、合理的參數(shù)初始化、使用Adam優(yōu)化器险领、學(xué)習(xí)率設(shè)為0.001侨舆,基本就能確保模型在一定程度收斂升酣。

本章帶領(lǐng)讀者從頭實(shí)現(xiàn)了一個(gè)Kaggle上的經(jīng)典競賽,重點(diǎn)講解了如何合理地組合安排程序态罪,同時(shí)介紹了一些在PyTorch中調(diào)試的技巧噩茄。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市复颈,隨后出現(xiàn)的幾起案子绩聘,更是在濱河造成了極大的恐慌,老刑警劉巖耗啦,帶你破解...
    沈念sama閱讀 217,509評論 6 504
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件凿菩,死亡現(xiàn)場離奇詭異,居然都是意外死亡帜讲,警方通過查閱死者的電腦和手機(jī)衅谷,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,806評論 3 394
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來似将,“玉大人获黔,你說我怎么就攤上這事≡谘椋” “怎么了玷氏?”我有些...
    開封第一講書人閱讀 163,875評論 0 354
  • 文/不壞的土叔 我叫張陵,是天一觀的道長腋舌。 經(jīng)常有香客問我盏触,道長,這世上最難降的妖魔是什么块饺? 我笑而不...
    開封第一講書人閱讀 58,441評論 1 293
  • 正文 為了忘掉前任赞辩,我火速辦了婚禮,結(jié)果婚禮上授艰,老公的妹妹穿的比我還像新娘辨嗽。我一直安慰自己,他們只是感情好想诅,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,488評論 6 392
  • 文/花漫 我一把揭開白布召庞。 她就那樣靜靜地躺著岛心,像睡著了一般来破。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上忘古,一...
    開封第一講書人閱讀 51,365評論 1 302
  • 那天徘禁,我揣著相機(jī)與錄音,去河邊找鬼髓堪。 笑死送朱,一個(gè)胖子當(dāng)著我的面吹牛娘荡,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播驶沼,決...
    沈念sama閱讀 40,190評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼炮沐,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了回怜?” 一聲冷哼從身側(cè)響起大年,我...
    開封第一講書人閱讀 39,062評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎玉雾,沒想到半個(gè)月后翔试,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,500評論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡复旬,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,706評論 3 335
  • 正文 我和宋清朗相戀三年垦缅,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片驹碍。...
    茶點(diǎn)故事閱讀 39,834評論 1 347
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡壁涎,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出志秃,到底是詐尸還是另有隱情粹庞,我是刑警寧澤,帶...
    沈念sama閱讀 35,559評論 5 345
  • 正文 年R本政府宣布洽损,位于F島的核電站庞溜,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏碑定。R本人自食惡果不足惜流码,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,167評論 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望延刘。 院中可真熱鬧漫试,春花似錦、人聲如沸碘赖。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,779評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽普泡。三九已至播掷,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間撼班,已是汗流浹背歧匈。 一陣腳步聲響...
    開封第一講書人閱讀 32,912評論 1 269
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留砰嘁,地道東北人件炉。 一個(gè)月前我還...
    沈念sama閱讀 47,958評論 2 370
  • 正文 我出身青樓勘究,卻偏偏與公主長得像,于是被迫代替她去往敵國和親斟冕。 傳聞我的和親對象是個(gè)殘疾皇子口糕,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,779評論 2 354

推薦閱讀更多精彩內(nèi)容