深度學習框架PyTorch入門與實踐:第五章 PyTorch中常用的工具

在訓練神經(jīng)網(wǎng)絡的過程中需要用到很多工具,其中最重要的三部分是數(shù)據(jù)處理裸违、可視化和GPU加速厌秒。本章主要介紹PyTorch在這幾方面常用的工具村缸,合理使用這些工具能極大地提高編程效率。

5.1 數(shù)據(jù)處理

在解決深度學習問題的過程中痢虹,往往需要花費大量的精力去處理數(shù)據(jù)被去,包括圖像、文本奖唯、語音或其他二進制數(shù)據(jù)等惨缆。數(shù)據(jù)的處理對訓練神經(jīng)網(wǎng)絡來說十分重要,良好的數(shù)據(jù)處理不僅會加速模型訓練丰捷,也會提高模型效果坯墨。考慮到這一點病往,PyTorch提供了幾個高效便捷的工具捣染,以便使用者進行數(shù)據(jù)處理或者增強等操作,同時可通過并行化加速數(shù)據(jù)加載停巷。

(1)數(shù)據(jù)加載

在PyTorch中耍攘,數(shù)據(jù)加載可通過自定義的數(shù)據(jù)集對象實現(xiàn)榕栏。數(shù)據(jù)集對象被抽象為Dataset,實現(xiàn)自定義的數(shù)據(jù)集需要繼承Dataset蕾各,并實現(xiàn)兩個Python魔法方法扒磁。

  • getitem:返回一條數(shù)據(jù)或一個樣本。obj[index]等價于obj.getitem(index)式曲。
  • len:返回樣本的數(shù)量妨托。len(obj)等價于obj.len()。

這里我們以Kaggle經(jīng)典挑戰(zhàn)賽“Dogs vs. Cats”的數(shù)據(jù)為例吝羞,詳細講解如何處理數(shù)據(jù)兰伤。“Dogs vs. Cats”是一個分類問題脆贵,判斷一張圖片是狗還是貓医清,其所有圖片都存放在一個文件夾下,根據(jù)文件名的前綴判斷是狗還是貓卖氨。

import torch as t
from torch.utils import data
import os
from PIL import  Image
import numpy as np

class DogCat(data.Dataset):
    def __init__(self, root):
        imgs = os.listdir(root)
        # 所有圖片的絕對路徑
        # 這里不實際加載圖片会烙,只是指定路徑,當調(diào)用__getitem__時才會真正讀圖片
        self.imgs = [os.path.join(root, img) for img in imgs]
        
    def __getitem__(self, index):
        img_path = self.imgs[index]
        # dog->1筒捺, cat->0
        label = 1 if 'dog' in img_path.split('/')[-1] else 0
        pil_img = Image.open(img_path)
        array = np.asarray(pil_img)
        data = t.from_numpy(array)
        return data, label
    
    def __len__(self):
        return len(self.imgs)

dataset = DogCat('./data/dogcat/')
img, label = dataset[0] # 相當于調(diào)用dataset.__getitem__(0)
for img, label in dataset:
    print(img.size(), img.float().mean(), label)

輸出:

torch.Size([500, 497, 3]) tensor(106.4915) 0
torch.Size([499, 379, 3]) tensor(171.8085) 0
torch.Size([236, 289, 3]) tensor(130.3004) 0
torch.Size([374, 499, 3]) tensor(115.5177) 0
torch.Size([375, 499, 3]) tensor(116.8139) 1
torch.Size([375, 499, 3]) tensor(150.5080) 1
torch.Size([377, 499, 3]) tensor(151.7174) 1
torch.Size([400, 300, 3]) tensor(128.1550) 1

通過上面的代碼柏腻,我們學習了如何自定義自己的數(shù)據(jù)集,并可以依次獲取系吭。但這里返回的數(shù)據(jù)不適合實際使用五嫂,因其具有如下兩方面問題:

  • 返回樣本的形狀不一,因每張圖片的大小不一樣肯尺,這對于需要取batch訓練的神經(jīng)網(wǎng)絡來說很不友好沃缘。
  • 返回樣本的數(shù)值較大,未歸一化至[-1, 1]则吟。

針對上述問題槐臀,PyTorch提供了torchvision。它是一個視覺工具包氓仲,提供了很多視覺圖像處理的工具水慨,其中transforms模塊提供了對PIL Image對象和Tensor對象的常用操作。

對PIL Image的操作包括:

  • Scale:調(diào)整圖片尺寸敬扛,長寬比保持不變
  • CenterCrop晰洒、`RandomCrop、RandomResizedCrop: 裁剪圖片
  • Pad:填充
  • ToTensor:將PIL Image對象轉(zhuǎn)成Tensor啥箭,會自動將[0, 255]歸一化至[0, 1]

對Tensor的操作包括:

  • Normalize:標準化谍珊,即減均值,除以標準差
  • ToPILImage:將Tensor轉(zhuǎn)為PIL Image對象

如果要對圖片進行多個操作急侥,可通過Compose函數(shù)將這些操作拼接起來抬驴,類似于nn.Sequential炼七。注意,這些操作定義后是以函數(shù)的形式存在布持,真正使用時需調(diào)用它的call方法,這點類似于nn.Module陕悬。例如要將圖片調(diào)整為224×224题暖,首先應構(gòu)建這個操作trans = Resize((224, 224)),然后調(diào)用trans(img)捉超。下面我們就用transforms的這些操作來優(yōu)化上面實現(xiàn)的dataset胧卤。

import os
from PIL import  Image
import numpy as np
from torchvision import transforms as T

transform = T.Compose([
    T.Resize(224), # 縮放圖片(Image),保持長寬比不變拼岳,最短邊為224像素
    T.CenterCrop(224), # 從圖片中間切出224*224的圖片
    T.ToTensor(), # 將圖片(Image)轉(zhuǎn)成Tensor枝誊,歸一化至[0, 1]
    T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 標準化至[-1, 1],規(guī)定均值和標準差
])

class DogCat(data.Dataset):
    def __init__(self, root, transforms=None):
        imgs = os.listdir(root)
        self.imgs = [os.path.join(root, img) for img in imgs]
        self.transforms=transforms
        
    def __getitem__(self, index):
        img_path = self.imgs[index]
        label = 0 if 'dog' in img_path.split('/')[-1] else 1
        data = Image.open(img_path)
        if self.transforms:
            data = self.transforms(data)
        return data, label
    
    def __len__(self):
        return len(self.imgs)

dataset = DogCat('./data/dogcat/', transforms=transform)
img, label = dataset[0]
for img, label in dataset:
    print(img.size(), label)

輸出:

torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0

除了上述操作之外惜纸,transforms還可通過Lambda封裝自定義的轉(zhuǎn)換策略叶撒。例如想對PIL Image進行隨機旋轉(zhuǎn),則可寫成這樣trans=T.Lambda(lambda img: img.rotate(random()*360))耐版。

torchvision已經(jīng)預先實現(xiàn)了常用的Dataset祠够,包括前面使用過的CIFAR-10,以及ImageNet粪牲、COCO古瓤、MNIST、LSUN等數(shù)據(jù)集腺阳,可通過諸如torchvision.datasets.CIFAR10來調(diào)用落君,具體使用方法請參看官方文檔。在這里介紹一個會經(jīng)常使用到的Dataset——ImageFolder亭引,它的實現(xiàn)和上述的DogCat很相似绎速。ImageFolder假設所有的文件按文件夾保存,每個文件夾下存儲同一個類別的圖片痛侍,文件夾名為類名朝氓,其構(gòu)造函數(shù)如下:

ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

它主要有四個參數(shù):

  • root:在root指定的路徑下尋找圖片。
  • transform:對PIL Image進行的轉(zhuǎn)換操作主届,transform的輸入是使用loader讀取圖片的返回對象赵哲。
  • target_transform:對label的轉(zhuǎn)換。
  • loader:給定路徑后如何讀取圖片君丁,默認讀取為RGB格式的PIL Image對象枫夺。

label是按照文件夾名順序排序后存成字典,即{類名:類序號(從0開始)}绘闷,一般來說最好直接將文件夾命名為從0開始的數(shù)字橡庞,這樣會和ImageFolder實際的label一致较坛,如果不是這種命名規(guī)范,建議看看self.class_to_idx屬性以了解label和文件夾名的映射關系扒最。

from torchvision.datasets import ImageFolder
dataset = ImageFolder('data/dogcat_2/')

# cat文件夾的圖片對應label 0丑勤,dog對應1
dataset.class_to_idx

輸出:

{'cat': 0, 'dog': 1}
# 所有圖片的路徑和對應的label
dataset.imgs

輸出:

[('data/dogcat_2/cat\\cat.12484.jpg', 0),
 ('data/dogcat_2/cat\\cat.12485.jpg', 0),
 ('data/dogcat_2/cat\\cat.12486.jpg', 0),
 ('data/dogcat_2/cat\\cat.12487.jpg', 0),
 ('data/dogcat_2/dog\\dog.12496.jpg', 1),
 ('data/dogcat_2/dog\\dog.12497.jpg', 1),
 ('data/dogcat_2/dog\\dog.12498.jpg', 1),
 ('data/dogcat_2/dog\\dog.12499.jpg', 1)]
# 沒有任何的transform,所以返回的還是PIL Image對象
dataset[0][1] # 第一維是第幾張圖吧趣,第二維為1返回label
dataset[0][0] # 第一維是第幾張圖法竞,第二維為0返回圖片數(shù)據(jù)

輸出:

image.png
# 加上transform
normalize = T.Normalize(mean=[0.4, 0.4, 0.4], std=[0.2, 0.2, 0.2])
transform  = T.Compose([
         T.RandomResizedCrop(224),
         T.RandomHorizontalFlip(),
         T.ToTensor(),
         normalize,
])

dataset = ImageFolder('data/dogcat_2/', transform=transform)

# 深度學習中圖片數(shù)據(jù)一般保存成CxHxW,即通道數(shù)x圖片高x圖片寬
dataset[0][0].size()

輸出:

torch.Size([3, 224, 224])
to_img = T.ToPILImage()
# 0.2和0.4是標準差和均值的近似
to_img(dataset[0][0]*0.2+0.4)

輸出:

image.png

Dataset只負責數(shù)據(jù)的抽象强挫,一次調(diào)用getitem只返回一個樣本岔霸。前面提到過,在訓練神經(jīng)網(wǎng)絡時俯渤,最好是對一個batch的數(shù)據(jù)進行操作呆细,同時還需要對數(shù)據(jù)進行shuffle和并行加速等。對此八匠,PyTorch提供了DataLoader幫助我們實現(xiàn)這些功能絮爷。

DataLoader的函數(shù)定義如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
  • dataset:加載的數(shù)據(jù)集(Dataset對象)
  • batch_size:batch size
  • shuffle::是否將數(shù)據(jù)打亂
  • sampler: 樣本抽樣,后續(xù)會詳細介紹
  • num_workers:使用多進程加載的進程數(shù)臀叙,0代表不使用多進程
  • collate_fn: 如何將多個樣本數(shù)據(jù)拼接成一個batch略水,一般使用默認的拼接方式即可
  • pin_memory:是否將數(shù)據(jù)保存在pin memory區(qū),pin memory中的數(shù)據(jù)轉(zhuǎn)到GPU會快一些
  • drop_last:dataset中的數(shù)據(jù)個數(shù)可能不是batch_size的整數(shù)倍劝萤,drop_last為True會將多出來不足一個batch的數(shù)據(jù)丟棄
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)

dataiter = iter(dataloader)
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight

輸出:

torch.Size([3, 3, 224, 224])

dataloader是一個可迭代的對象渊涝,意味著我們可以像使用迭代器一樣使用它,例如:

for batch_datas, batch_labels in dataloader:
    train()

dataiter = iter(dataloader)
batch_datas, batch_labesl = next(dataiter)

在數(shù)據(jù)處理中床嫌,有時會出現(xiàn)某個樣本無法讀取等問題跨释,比如某張圖片損壞。這時在getitem函數(shù)中將出現(xiàn)異常厌处,此時最好的解決方案即是將出錯的樣本剔除鳖谈。如果實在是遇到這種情況無法處理,則可以返回None對象阔涉,然后在Dataloader中實現(xiàn)自定義的collate_fn缆娃,將空對象過濾掉。但要注意瑰排,在這種情況下dataloader返回的batch數(shù)目會少于batch_size贯要。

class NewDogCat(DogCat): # 繼承前面實現(xiàn)的DogCat數(shù)據(jù)集
    def __getitem__(self, index):
        try:
            # 調(diào)用父類的獲取函數(shù),即 DogCat.__getitem__(self, index)
            return super(NewDogCat,self).__getitem__(index)
        except:
            return None, None

from torch.utils.data.dataloader import default_collate # 導入默認的拼接方式
def my_collate_fn(batch):
    '''
    batch中每個元素形如(data, label)
    '''
    # 過濾為None的數(shù)據(jù)
    batch = list(filter(lambda x:x[0] is not None, batch))
    if len(batch) == 0: return t.Tensor()
    return default_collate(batch) # 用默認方式拼接過濾后的batch數(shù)據(jù)

dataset = NewDogCat('data/dogcat_wrong/', transforms=transform)

dataset[5]

輸出:

(tensor([[[ 1.2549,  1.2549,  1.2549,  ..., -0.0980, -0.0980, -0.1569],
          [ 1.2941,  1.3137,  1.3137,  ..., -0.0784, -0.0980, -0.0980],
          [ 1.3137,  1.3137,  1.3137,  ..., -0.0588, -0.0784, -0.0588],
          ...,
          [ 1.3725,  1.3529,  1.3529,  ..., -1.4314, -1.4314, -1.4118],
          [ 1.3725,  1.3529,  1.3529,  ..., -1.4314, -1.4314, -1.4314],
          [ 1.3922,  1.3725,  1.3725,  ..., -1.3922, -1.3922, -1.3922]],
 
         [[ 1.0588,  0.9216,  0.8824,  ...,  0.4314,  0.4510,  0.3922],
          [ 1.0980,  0.9804,  0.9412,  ...,  0.4314,  0.4118,  0.4118],
          [ 1.1176,  0.9804,  0.9412,  ...,  0.4314,  0.4118,  0.4314],
          ...,
          [ 1.3725,  1.3137,  1.3137,  ..., -1.1569, -1.1373, -1.1176],
          [ 1.3725,  1.3137,  1.3137,  ..., -1.1569, -1.1373, -1.1176],
          [ 1.3922,  1.3137,  1.3137,  ..., -1.0784, -1.0784, -1.0784]],
 
         [[ 0.0392, -0.1765, -0.2549,  ...,  1.4706,  1.4510,  1.3725],
          [ 0.0784, -0.1176, -0.1961,  ...,  1.5098,  1.4706,  1.4510],
          [ 0.0980, -0.1176, -0.1765,  ...,  1.5294,  1.5098,  1.5294],
          ...,
          [ 0.3922,  0.3922,  0.3922,  ..., -1.1373, -1.1569, -1.1569],
          [ 0.3922,  0.3922,  0.4118,  ..., -1.2353, -1.2745, -1.2745],
          [ 0.4118,  0.4118,  0.4118,  ..., -1.2157, -1.2745, -1.2745]]]), 0)
dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=0,shuffle=True)
for batch_datas, batch_labels in dataloader:
    print(batch_datas.size(),batch_labels.size())

輸出:

torch.Size([1, 3, 224, 224]) torch.Size([1])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])

來看一下上述batch_size的大小椭住。其中第1個的batch_size為1崇渗,這是因為有一張圖片損壞,導致其無法正常返回。而最后1個的batch_size也為1宅广,這是因為共有9張(包括損壞的文件)圖片葫掉,無法整除2(batch_size),因此最后一個batch的數(shù)據(jù)會少于batch_szie跟狱,可通過指定drop_last=True來丟棄最后一個不足batch_size的batch俭厚。

對于諸如樣本損壞或數(shù)據(jù)集加載異常等情況,還可以通過其它方式解決驶臊。例如但凡遇到異常情況套腹,就隨機取一張圖片代替:

class NewDogCat(DogCat):
    def __getitem__(self, index):
        try:
            return super(NewDogCat, self).__getitem__(index)
        except:
            new_index = random.randint(0, len(self)-1)
            return self[new_index]

相比較丟棄異常圖片而言,這種做法會更好一些资铡,因為它能保證每個batch的數(shù)目仍是batch_size。但在大多數(shù)情況下幢码,最好的方式還是對數(shù)據(jù)進行徹底清洗笤休。

DataLoader里面并沒有太多的魔法方法,它封裝了Python的標準庫multiprocessing症副,使其能夠?qū)崿F(xiàn)多進程加速店雅。在此提幾點關于Dataset和DataLoader使用方面的建議:

  • 高負載的操作放在getitem中,如加載圖片等贞铣。
  • dataset中應盡量只包含只讀對象闹啦,避免修改任何可變對象,利用多線程進行操作辕坝。

第一點是因為多進程會并行的調(diào)用getitem函數(shù)窍奋,將負載高的放在getitem函數(shù)中能夠?qū)崿F(xiàn)并行加速。 第二點是因為dataloader使用多進程加載酱畅,如果在Dataset實現(xiàn)中使用了可變對象琳袄,可能會有意想不到的沖突。在多線程/多進程中纺酸,修改一個可變對象窖逗,需要加鎖,但是dataloader的設計使得其很難加鎖(在實際使用中也應盡量避免鎖的存在)餐蔬,因此最好避免在dataset中修改可變對象碎紊。例如下面就是一個不好的例子,在多進程處理中self.num可能與預期不符樊诺,這種問題不會報錯仗考,因此難以發(fā)現(xiàn)。如果一定要修改可變對象啄骇,建議使用Python標準庫Queue中的相關數(shù)據(jù)結(jié)構(gòu)痴鳄。

class BadDataset(Dataset):
    def __init__(self):
        self.datas = range(100)
        self.num = 0 # 取數(shù)據(jù)的次數(shù)
    def __getitem__(self, index):
        self.num += 1
        return self.datas[index]

使用Python multiprocessing庫的另一個問題是,在使用多進程時缸夹,如果主程序異常終止(比如用Ctrl+C強行退出)痪寻,相應的數(shù)據(jù)加載進程可能無法正常退出螺句。這時你可能會發(fā)現(xiàn)程序已經(jīng)退出了,但GPU顯存和內(nèi)存依舊被占用著橡类,或通過top蛇尚、ps aux依舊能夠看到已經(jīng)退出的程序,這時就需要手動強行殺掉進程顾画。建議使用如下命令:

ps x | grep <cmdline> | awk '{print $1}' | xargs kill
  • ps x:獲取當前用戶的所有進程
  • grep <cmdline>:找到已經(jīng)停止的PyTorch程序的進程取劫,例如你是通過python train.py啟動的,那你就需要寫grep 'python train.py'
  • awk '{print $1}':獲取進程的pid
  • xargs kill:殺掉進程研侣,根據(jù)需要可能要寫成xargs kill -9強制殺掉進程

在執(zhí)行這句命令之前谱邪,建議先打印確認一下是否會誤殺其它進程

ps x | grep <cmdline> | ps x

PyTorch中還單獨提供了一個sampler模塊,用來對數(shù)據(jù)進行采樣庶诡。常用的有隨機采樣器:RandomSampler惦银,當dataloader的shuffle參數(shù)為True時,系統(tǒng)會自動調(diào)用這個采樣器末誓,實現(xiàn)打亂數(shù)據(jù)扯俱。默認的是采用SequentialSampler,它會按順序一個一個進行采樣喇澡。這里介紹另外一個很有用的采樣方法: WeightedRandomSampler迅栅,它會根據(jù)每個樣本的權(quán)重選取數(shù)據(jù),在樣本比例不均衡的問題中晴玖,可用它來進行重采樣读存。

構(gòu)建WeightedRandomSampler時需提供兩個參數(shù):每個樣本的權(quán)重weights、共選取的樣本總數(shù)num_samples窜醉,以及一個可選參數(shù)replacement宪萄。權(quán)重越大的樣本被選中的概率越大,待選取的樣本數(shù)目一般小于全部的樣本數(shù)目榨惰。replacement用于指定是否可以重復選取某一個樣本拜英,默認為True,即允許在一個epoch中重復采樣某一個數(shù)據(jù)琅催。如果設為False居凶,則當某一類的樣本被全部選取完,但其樣本數(shù)目仍未達到num_samples時藤抡,sampler將不會再從該類中選擇數(shù)據(jù)侠碧,此時可能導致weights參數(shù)失效。下面舉例說明缠黍。

dataset = DogCat('data/dogcat/', transforms=transform)
# 狗的圖片被取出的概率是貓的概率的兩倍
# 兩類圖片被取出的概率與weights的絕對大小無關弄兜,只和比值有關
weights = [2 if label == 1 else 1 for data, label in dataset]
weights

輸出:

[2, 2, 2, 2, 1, 1, 1, 1]
from torch.utils.data.sampler import  WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
                                num_samples=9,\
                                replacement=True)
dataloader = DataLoader(dataset,
                        batch_size=3,
                        sampler=sampler)
for datas, labels in dataloader:
    print(labels.tolist())

輸出:

[1, 1, 1]
[0, 0, 1]
[1, 1, 1]

可見貓狗樣本比例約為1:2,另外一共只有8個樣本,但是卻返回了9個替饿,說明肯定有被重復返回的语泽,這就是replacement參數(shù)的作用,下面將replacement設為False試試视卢。

sampler = WeightedRandomSampler(weights, 8, replacement=False)
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
for datas, labels in dataloader:
    print(labels.tolist())

輸出:

[1, 1, 1, 1]
[0, 0, 0, 0]

在這種情況下踱卵,num_samples等于dataset的樣本總數(shù),為了不重復選取据过,sampler會將每個樣本都返回惋砂,這樣就失去weight參數(shù)的意義了。

從上面的例子可見sampler在樣本采樣中的作用:如果指定了sampler绳锅,shuffle將不再生效西饵,并且sampler.num_samples會覆蓋dataset的實際大小,即一個epoch返回的圖片總數(shù)取決于sampler.num_samples鳞芙。

5.2 計算機視覺工具包:torchvision

計算機視覺是深度學習中最重要的一類應用罗标,為了方便研究者使用,PyTorch團隊專門開發(fā)了一個視覺工具包torchvision积蜻,這個包獨立于PyTorch挖函,需通過pip instal torchvision安裝迈着。在之前的例子中我們已經(jīng)見識到了它的部分功能,這里再做一個系統(tǒng)性的介紹笨触。torchvision主要包含三部分:

  • models:提供深度學習中各種經(jīng)典網(wǎng)絡的網(wǎng)絡結(jié)構(gòu)以及預訓練好的模型宾尚,包括AlexNet丙笋、VGG系列、ResNet系列煌贴、Inception系列等御板。
  • datasets: 提供常用的數(shù)據(jù)集加載,設計上都是繼承torhc.utils.data.Dataset牛郑,主要包括MNIST怠肋、CIFAR10/100、ImageNet淹朋、COCO等笙各。
  • transforms:提供常用的數(shù)據(jù)預處理操作,主要包括對Tensor以及PIL Image對象的操作础芍。
from torchvision import models
from torch import nn
from torchvision import datasets

# 加載預訓練好的模型杈抢,如果不存在會進行下載
# 預訓練好的模型保存在 ~/.torch/models/下面
resnet34 = models.squeezenet1_1(pretrained=True, num_classes=1000)

# 修改最后的全連接層為10分類問題(默認是ImageNet上的1000分類)
resnet34.fc=nn.Linear(512, 10)

# 加上transform
transform  = T.Compose([
         T.ToTensor(),
         T.Normalize(mean=[0.4,], std=[0.2,]),
])

# 指定數(shù)據(jù)集路徑為data,如果數(shù)據(jù)集不存在則進行下載
# 通過train=False獲取測試集
dataset = datasets.MNIST('data/', download=True, train=False, transform=transform)

Transforms中涵蓋了大部分對Tensor和PIL Image的常用處理仑性,這些已在上文提到惶楼,這里就不再詳細介紹。需要注意的是轉(zhuǎn)換分為兩步,第一步:構(gòu)建轉(zhuǎn)換操作歼捐,例如transf = transforms.Normalize(mean=x, std=y)何陆,第二步:執(zhí)行轉(zhuǎn)換操作,例如output = transf(input)窥岩。另外還可將多個處理操作用Compose拼接起來甲献,形成一個處理轉(zhuǎn)換流程。

from torchvision import transforms 
to_pil = transforms.ToPILImage()
to_pil(t.randn(3, 64, 64))

輸出:隨機噪聲

image.png

torchvision還提供了兩個常用的函數(shù)颂翼。一個是make_grid晃洒,它能將多張圖片拼接成一個網(wǎng)格中;另一個是save_img朦乏,它能將Tensor保存成圖片球及。

len(dataset)

輸出:

10000
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
from torchvision.utils import make_grid, save_image
dataiter = iter(dataloader)
img = make_grid(next(dataiter)[0], 4) # 拼成4*4網(wǎng)格圖片,且會轉(zhuǎn)成3通道
to_img(img)

輸出:

image.png
save_image(img, 'a.png')
Image.open('a.png')
image.png

5.3 可視化工具

在訓練神經(jīng)網(wǎng)絡時呻疹,我們希望能更直觀地了解訓練情況吃引,包括損失曲線、輸入圖片刽锤、輸出圖片镊尺、卷積核的參數(shù)分布等信息。這些信息能幫助我們更好地監(jiān)督網(wǎng)絡的訓練過程并思,并為參數(shù)優(yōu)化提供方向和依據(jù)庐氮。最簡單的辦法就是打印輸出,但其只能打印數(shù)值信息宋彼,不夠直觀弄砍,同時無法查看分布、圖片输涕、聲音等音婶。在本節(jié),我們將介紹兩個深度學習中常用的可視化工具:Tensorboard和Visdom莱坎。

5.3.1 Tensorboard

Tensorboard最初是作為TensorFlow的可視化工具迅速流行開來衣式。作為和TensorFlow深度集成的工具,Tensorboard能夠展現(xiàn)你的TensorFlow網(wǎng)絡計算圖檐什,繪制圖像生成的定量指標圖以及附加數(shù)據(jù)瞳收。但同時Tensorboard也是一個相對獨立的工具,只要用戶保存的數(shù)據(jù)遵循相應的格式厢汹,tensorboard就能讀取這些數(shù)據(jù)并進行可視化螟深。這里我們將主要介紹如何在PyTorch中使用tensorboardX進行訓練損失的可視化。 TensorboardX是將Tensorboard的功能抽取出來烫葬,使得非TensorFlow用戶也能使用它進行可視化界弧,幾乎支持原生TensorBoard的全部功能凡蜻。

image.png

tensorboard的安裝主要分為以下兩步:
(1)安裝TensorFlow:如果電腦中已經(jīng)安裝完TensorFlow可以跳過這一步,如果電腦中尚未安裝垢箕,建議安裝CPU-Only的版本划栓,具體安裝教程參見TensorFlow官網(wǎng),或使用pip命令直接安裝条获。

  • 安裝tensorboard: pip install tensorboard
  • 安裝tensorboardX:可通過pip install tensorboardX命令直接安裝忠荞。
    (2)安裝tensorboard_logger:可通過pip install tensorboard_logger命令直接安裝。

tensorboardX的使用非常簡單帅掘。首先用如下命令啟動tensorboard:

tensorboard --logdir <your/running/dir> --port <your_bind_port>

下面舉例說明tensorboardX的使用委煤。
打開一個新的終端,進入到你指定的tensorboard日志的上一層目錄修档,如:這里把生成的tensorboard日志放在E:/log目錄下碧绞,就需要進入E:/開啟tensorboard服務:

C:\Users\Mloong>E:
E:\>tensorboard --logdir log
TensorBoard 1.9.0 at http://DESKTOP-1LO98I2:6006 (Press CTRL+C to quit)

下面模擬訓練過程生成loss和accuracy,寫入E:/log目錄:

from tensorboardX import SummaryWriter

writer = SummaryWriter(log_dir='E:/log', flush_secs=2)

for ii in range(100):
    writer.add_scalar('data/loss', 100-ii**0.5, ii)
    writer.add_scalar('data/accuracy', ii/100, ii)
    
writer.close()

注意:tensorboard是谷歌的產(chǎn)品讥邻,最好使用谷歌瀏覽器(像360瀏覽器、IE瀏覽器都有問題)院峡。
打開谷歌瀏覽器輸入啟動日志中的網(wǎng)址http://DESKTOP-1LO98I2:6006(6006是默認端口號兴使,可以使用--port選項指定端口號),即可看到如圖所示的結(jié)果照激。

image.png

左側(cè)的Horizontal Axis下有三個選項鲫惶,分別是:

  • Step:根據(jù)步長來記錄,如果有步長实抡,則將其作為x軸坐標描點畫線。
  • Relative:用前后相對順序描點畫線欢策,每調(diào)用一次就自動加1吆寨。
  • Wall:按時間排序描點畫線。

左側(cè)的Smoothing條可以左右拖動踩寇,用來調(diào)節(jié)平滑的幅度啄清。點擊右上角的刷新按鈕可立即刷新結(jié)果,默認是每30s自動刷新數(shù)據(jù)俺孙±弊洌可見tensorboard_logger的使用十分簡單,但它只能統(tǒng)計簡單的數(shù)值信息睛榄,不支持其它功能荣茫。

感興趣的讀者可以從github項目主頁獲取更多信息,本節(jié)將把更多的內(nèi)容留給另一個可視化工具:Visdom场靴。

5.3.2 visdom

Visdom是Facebook專門為PyTorch開發(fā)的一款可視化工具啡莉,其開源于2017年3月港准。Visdom十分輕量級,但卻支持非常豐富的功能咧欣,能勝任大多數(shù)的科學運算可視化任務浅缸。其可視化界面如圖所示。

image.png

Visdom可以創(chuàng)造魄咕、組織和共享多種數(shù)據(jù)的可視化衩椒,包括數(shù)值、圖像哮兰、文本毛萌,甚至是視頻,其支持PyTorch奠蹬、Torch及Numpy朝聋。用戶可通過編程組織可視化空間,或通過用戶接口為生動數(shù)據(jù)打造儀表板囤躁,檢查實驗結(jié)果或調(diào)試代碼冀痕。

Visdom中有兩個重要概念:

  • env:環(huán)境。不同環(huán)境的可視化結(jié)果相互隔離狸演,互不影響言蛇,在使用時如果不指定env,默認使用main宵距。不同用戶腊尚、不同程序一般使用不同的env。
  • pane:窗格满哪。窗格可用于可視化圖像婿斥、數(shù)值或打印文本等,其可以拖動哨鸭、縮放民宿、保存和關閉。一個程序中可使用同一個env中的不同pane像鸡,每個pane可視化或記錄某一信息活鹰。

如圖所示,當前env='test'共有6個pane只估,分別展示不同的結(jié)果志群。點擊clear按鈕可以清空當前env的所有pane,點擊save按鈕可將當前env保存成json文件蛔钙,保存路徑位于~/.visdom/目錄下锌云。也可修改env的名字后點擊fork,保存當前env的狀態(tài)至更名后的env吁脱。

image.png

Visdom的安裝可通過命令pip install visdom宾抓。安裝完成后子漩,需通過python -m visdom.server命令啟動visdom服務,或通過nohup python -m visdom.server &命令將服務放至后臺運行石洗。Visdom服務是一個web server服務幢泼,默認綁定8097端口,客戶端與服務器間通過tornado進行非阻塞交互讲衫。

python -m visdom.server

輸出下面日志則表示啟動成功缕棵,拷貝網(wǎng)址到瀏覽器進行可視化。

It's Alive!
INFO:root:Application Started
You can navigate to http://localhost:8097

Visdom的使用有兩點需要注意的地方:

  • 需手動指定保存env涉兽,可在web界面點擊save按鈕或在程序中調(diào)用save方法招驴,否則visdom服務重啟后,env等信息會丟失枷畏。
  • 客戶端與服務器之間的交互采用tornado異步框架别厘,可視化操作不會阻塞當前程序,網(wǎng)絡異常也不會導致程序退出拥诡。

Visdom以Plotly為基礎触趴,支持豐富的可視化操作,下面舉例說明一些最常用的操作渴肉。

import visdom

# 新建一個連接客戶端
# 指定server(默認為‘localhost')指定端口(默認為'8097')冗懦,指定環(huán)境(默認為'main')
vis = visdom.Visdom(server='localhost',port='8097',env=u'test1')
x = t.arange(1, 30, 0.01)
y = t.sin(x)
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'})
image.png

下面逐一分析這幾行代碼:

vis = visdom.Visdom(server='localhost',port='8097',env=u'test1'),用于構(gòu)建一個連接客戶端仇祭,可以指定server披蕉、port、env等參數(shù)乌奇。

vis作為一個客戶端對象没讲,可以使用常見的畫圖函數(shù),包括:

  • line:類似Matlab中的plot操作礁苗,用于記錄某些標量的變化爬凑,如損失、準確率等
  • image:可視化圖片寂屏,可以是輸入的圖片,也可以是GAN生成的圖片娜搂,還可以是卷積核的信息
  • text:用于記錄日志等文字信息迁霎,支持html格式
  • histgram:可視化分布,主要是查看數(shù)據(jù)百宇、參數(shù)的分布
  • scatter:繪制散點圖
  • bar:繪制柱狀圖
  • pie:繪制餅狀圖

更多操作可參考visdom的github主頁考廉。這里主要介紹深度學習中常見的line、image和text操作携御。

Visdom同時支持PyTorch的tensor和Numpy的ndarray兩種數(shù)據(jù)結(jié)構(gòu)昌粤,但不支持Python的int既绕、float等類型,因此每次傳入時都需先將數(shù)據(jù)轉(zhuǎn)成ndarray或tensor涮坐。上述操作的參數(shù)一般不同凄贩,但有兩個參數(shù)是絕大多數(shù)操作都具備的:

  • win:用于指定pane的名字,如果不指定袱讹,visdom將自動分配一個新的pane疲扎。如果兩次操作指定的win名字一樣,新的操作將覆蓋當前pane的內(nèi)容捷雕,因此建議每次操作都重新指定win椒丧。
  • opts:選項,接收一個字典救巷,常見的option包括title壶熏、xlabel、ylabel浦译、width等棒假,主要用于設置pane的顯示格式。

之前提到過管怠,每次操作都會覆蓋之前的數(shù)值淆衷,但往往我們在訓練網(wǎng)絡的過程中需不斷更新數(shù)值,如損失值等渤弛,這時就需要指定參數(shù)update='append'來避免覆蓋之前的數(shù)值祝拯。而除了使用update參數(shù)以外,還可以使用vis.updateTrace方法來更新圖她肯,但updateTrace不僅能在指定pane上新增一個和已有數(shù)據(jù)相互獨立的Trace佳头,還能像update='append'那樣在同一條trace上追加數(shù)據(jù)。

# append 追加數(shù)據(jù)
for ii in range(0, 10):
    x = t.Tensor([ii])
    y = x
    vis.line(X=x, Y=y, win='polynomial', update='append' if ii>0 else None)
    
# updateTrace 新增一條線
x = t.arange(0, 9, 0.1)
y = (x ** 2) / 9
vis.line(X=x, Y=y, win='polynomial', name='this is a new Trace',update='new')

結(jié)果如下圖所示晴氨。

image.png

image的畫圖功能可分為如下兩類:

  • image接收一個二維或三維向量康嘉,H×W 或 3×H×W,前者是黑白圖像籽前,后者是彩色圖像亭珍。
  • images接收一個四維向量 N×C×H×W,C 可以是1或3枝哄,分別代表黑白和彩色圖像肄梨。可實現(xiàn)類似torchvision中make_grid的功能挠锥,將多張圖片拼接在一起众羡。images也可以接收一個二維或三維的向量,此時它所實現(xiàn)的功能與image一致蓖租。
# 可視化一個隨機的黑白圖片
vis.image(t.randn(64, 64).numpy())

# 隨機可視化一張彩色圖片
vis.image(t.randn(3, 64, 64).numpy(), win='random2')

# 可視化36張隨機的彩色圖片粱侣,每一行6張
vis.images(t.randn(36, 3, 64, 64).numpy(), nrow=6, win='random3', opts={'title':'random_imgs'})

結(jié)果如下:

image.png

vis.text用于可視化文本羊壹,支持所有的html標簽,同時也遵循著html的語法標準齐婴。例如油猫,換行需使用
標簽,\r\n無法實現(xiàn)換行尔店。下面舉例說明眨攘。

vis.text(u'''<h1>Hello Visdom</h1><br>Visdom是Facebook專門為<b>PyTorch</b>開發(fā)的一個可視化工具,
         在內(nèi)部使用了很久嚣州,在2017年3月份開源了它鲫售。
         
         Visdom十分輕量級,但是卻有十分強大的功能该肴,支持幾乎所有的科學運算可視化任務''',
         win='visdom',
         opts={'title': u'visdom簡介' }
        )

結(jié)果如下:

image.png

5.4 使用GPU加速:cuda

這部分內(nèi)容在前面介紹Tensor情竹、Module時大都提到過,這里將做一個總結(jié)匀哄,并深入介紹相關應用秦效。

在PyTorch中以下數(shù)據(jù)結(jié)構(gòu)分為CPU和GPU兩個版本:

  • Tensor
  • nn.Module(包括常用的layer、loss function涎嚼,以及容器Sequential等)

它們都帶有一個.cuda方法阱州,調(diào)用此方法即可將其轉(zhuǎn)為對應的GPU對象。注意法梯,tensor.cuda會返回一個新對象苔货,這個新對象的數(shù)據(jù)已轉(zhuǎn)移至GPU,而之前的tensor還在原來的設備上(CPU)立哑。而module.cuda則會將所有的數(shù)據(jù)都遷移至GPU夜惭,并返回自己。所以module = module.cuda()和module.cuda()所起的作用一致铛绰。

nn.Module在GPU與CPU之間的轉(zhuǎn)換诈茧,本質(zhì)上還是利用了Tensor在GPU和CPU之間的轉(zhuǎn)換。nn.Module的cuda方法是將nn.Module下的所有parameter(包括子module的parameter)都轉(zhuǎn)移至GPU捂掰,而Parameter本質(zhì)上也是tensor(Tensor的子類)敢会。

下面將舉例說明,這部分代碼需要你具有兩塊GPU設備这嚣。>>>呵呵鸥昏,反正我沒有<<<

P.S. 為什么將數(shù)據(jù)轉(zhuǎn)移至GPU的方法叫做.cuda而不是.gpu,就像將數(shù)據(jù)轉(zhuǎn)移至CPU調(diào)用的方法是.cpu疤苹?這是因為GPU的編程接口采用CUDA互广,而目前并不是所有的GPU都支持CUDA敛腌,只有部分Nvidia的GPU才支持卧土。PyTorch未來可能會支持AMD的GPU惫皱,而AMD GPU的編程接口采用OpenCL,因此PyTorch還預留著.cl方法尤莺,用于以后支持AMD等的GPU旅敷。

tensor = t.Tensor(3, 4)
# 返回一個新的tensor,但原來的tensor并沒有改變
tensor.cuda(0)
tensor.is_cuda

輸出:

False
# 重新賦給自己颤霎,tensor指向cuda上的數(shù)據(jù)媳谁,不再執(zhí)行原數(shù)據(jù)
tensor = tensor.cuda()
tensor.is_cuda

輸出:

True
# 將nn.Module模型放到cuda上,其子模型也都自動放到cuda上
from torch import nn
module = nn.Linear(3, 4)
module.cuda(device = 0)
module.weight.is_cuda

輸出:

True
class VeryBigModule(nn.Module):
    def __init__(self):
        super(VeryBigModule, self).__init__()
        self.GiantParameter1 = t.nn.Parameter(t.randn(100000, 20000)).cuda(0)
        self.GiantParameter2 = t.nn.Parameter(t.randn(20000, 100000)).cuda(1)
    
    def forward(self, x):
        x = self.GiantParameter1.mm(x.cuda(0))
        x = self.GiantParameter2.mm(x.cuda(1))
        return x

上面最后一部分中友酱,兩個Parameter所占用的內(nèi)存空間都非常大晴音,大概是8個G,如果將這兩個都同時放在一塊GPU上幾乎會將顯存占滿缔杉,無法再進行任何其它運算锤躁。此時可通過這種方式將不同的計算分布到不同的GPU中。

關于使用GPU的一些建議:

  • GPU運算很快或详,但對于很小的運算量來說系羞,并不能體現(xiàn)出它的優(yōu)勢,因此對于一些簡單的操作可直接利用CPU完成
  • 數(shù)據(jù)在CPU和GPU之間霸琴,以及GPU與GPU之間的傳遞會比較耗時椒振,應當盡量避免
  • 在進行低精度的計算時,可以考慮HalfTensor梧乘,它相比于FloatTensor能節(jié)省一半的顯存澎迎,但需千萬注意數(shù)值溢出的情況。

另外這里需要專門提一下宋下,大部分的損失函數(shù)也都屬于nn.Moudle嗡善,但在使用GPU時,很多時候我們都忘記使用它的.cuda方法学歧,這在大多數(shù)情況下不會報錯罩引,因為損失函數(shù)本身沒有可學習的參數(shù)(learnable parameters)。但在某些情況下會出現(xiàn)問題枝笨,為了保險起見同時也為了代碼更規(guī)范袁铐,應記得調(diào)用criterion.cuda。下面舉例說明横浑。

# 交叉熵損失函數(shù)剔桨,帶權(quán)重
criterion = t.nn.CrossEntropyLoss(weight=t.Tensor([1, 3]))
input = t.randn(4, 2).cuda()
target = t.Tensor([1, 0, 0, 1]).long().cuda()

# 下面這行會報錯,因weight未被轉(zhuǎn)移至GPU
# loss = criterion(input, target)

# 這行則不會報錯
criterion.cuda()
loss = criterion(input, target)

criterion._buffers

輸出:

OrderedDict([('weight', tensor([1., 3.], device='cuda:0'))])

而除了調(diào)用對象的.cuda方法之外徙融,還可以使用torch.cuda.device洒缀,來指定默認使用哪一塊GPU,或使用torch.set_default_tensor_type使程序默認使用GPU,不需要手動調(diào)用cuda树绩。

# 如果未指定使用哪塊GPU萨脑,默認使用GPU 0
x = t.cuda.FloatTensor(2, 3)
# x.get_device() == 0
y = t.FloatTensor(2, 3).cuda()
# y.get_device() == 0

# 指定默認使用GPU 0
with t.cuda.device(0):    
    # 在GPU 0上構(gòu)建tensor
    a = t.cuda.FloatTensor(2, 3)

    # 將tensor轉(zhuǎn)移至GPU 0
    b = t.FloatTensor(2, 3).cuda()
    print(a.get_device() == b.get_device() == 0 )

    c = a + b
    print(c.get_device() == 0)

    z = x + y
    print(z.get_device() == 0)

    # 手動指定使用GPU 0
    d = t.randn(2, 3).cuda(0)
    print(d.get_device() == 2)

輸出:

True
True
True
False
t.set_default_tensor_type('torch.cuda.FloatTensor') # 指定默認tensor的類型為GPU上的FloatTensor
a = t.ones(2, 3)
a.is_cuda

輸出:

True

如果服務器具有多個GPU,tensor.cuda()方法會將tensor保存到第一塊GPU上饺饭,等價于tensor.cuda(0)渤早。此時如果想使用第二塊GPU,需手動指定tensor.cuda(1)瘫俊,而這需要修改大量代碼鹊杖,很是繁瑣。這里有兩種替代方法:

  • 一種是先調(diào)用t.cuda.set_device(1)指定使用第二塊GPU扛芽,后續(xù)的.cuda()都無需更改骂蓖,切換GPU只需修改這一行代碼。
  • 更推薦的方法是設置環(huán)境變量CUDA_VISIBLE_DEVICES川尖,例如當export CUDA_VISIBLE_DEVICE=1(下標是從0開始涯竟,1代表第二塊GPU),只使用第二塊物理GPU空厌,但在程序中這塊GPU會被看成是第一塊邏輯GPU庐船,因此此時調(diào)用tensor.cuda()會將Tensor轉(zhuǎn)移至第二塊物理GPU。CUDA_VISIBLE_DEVICES還可以指定多個GPU嘲更,如export CUDA_VISIBLE_DEVICES=0,2,3筐钟,那么第一、三赋朦、四塊物理GPU會被映射成第一篓冲、二、三塊邏輯GPU宠哄,tensor.cuda(1)會將Tensor轉(zhuǎn)移到第三塊物理GPU上壹将。

設置CUDA_VISIBLE_DEVICES有兩種方法,一種是在命令行中CUDA_VISIBLE_DEVICES=0,1 python main.py毛嫉,一種是在程序中import os;os.environ["CUDA_VISIBLE_DEVICES"] = "2"诽俯。如果使用IPython或者Jupyter notebook,還可以使用%env CUDA_VISIBLE_DEVICES=1,2來設置環(huán)境變量承粤。

從 0.4 版本開始暴区,pytorch新增了tensor.to(device)方法,能夠?qū)崿F(xiàn)設備透明辛臊,便于實現(xiàn)CPU/GPU兼容仙粱。這部份內(nèi)容已經(jīng)在第三章講解過了。

從PyTorch 0.2版本中彻舰,PyTorch新增分布式GPU支持伐割。分布式是指有多個GPU在多臺服務器上候味,而并行一般指的是一臺服務器上的多個GPU。分布式涉及到了服務器之間的通信隔心,因此比較復雜负溪,PyTorch封裝了相應的接口,可以用幾句簡單的代碼實現(xiàn)分布式訓練济炎。分布式對普通用戶來說比較遙遠,因為搭建一個分布式集群的代價十分大辐真,使用也比較復雜须尚。相比之下一機多卡更加現(xiàn)實。對于分布式訓練侍咱,這里不做太多的介紹耐床,感興趣的讀者可參考文檔distributed

5.5 持久化

在PyTorch中楔脯,以下對象可以持久化到硬盤撩轰,并能通過相應的方法加載到內(nèi)存中:

  • Tensor
  • Variable
  • nn.Module
  • Optimizer

本質(zhì)上上述這些信息最終都是保存成Tensor。Tensor的保存和加載十分的簡單昧廷,使用t.save和t.load即可完成相應的功能堪嫂。在save/load時可指定使用的pickle模塊,在load時還可將GPU tensor映射到CPU或其它GPU上木柬。

我們可以通過t.save(obj, file_name)等方法保存任意可序列化的對象皆串,然后通過obj = t.load(file_name)方法加載保存的數(shù)據(jù)。對于Module和Optimizer對象眉枕,這里建議保存對應的state_dict恶复,而不是直接保存整個Module/Optimizer對象。Optimizer對象保存的主要是參數(shù)速挑,以及動量信息谤牡,通過加載之前的動量信息,能夠有效地減少模型震蕩姥宝,下面舉例說明翅萤。

a = t.Tensor(3, 4)
if t.cuda.is_available():
        a = a.cuda(0) # 把a轉(zhuǎn)為GPU0上的tensor,
        t.save(a,'a.pth')
        
        # 加載為b, 存儲于GPU0上(因為保存時tensor就在GPU0上)
        b = t.load('a.pth')
        
        # 加載為c, 存儲于CPU
        c = t.load('a.pth', map_location=lambda storage, loc: storage)
        
        # 加載為d, 存儲于GPU0上
        d = t.load('a.pth', map_location={'cuda:1':'cuda:0'})

t.set_default_tensor_type('torch.FloatTensor')
from torchvision.models import SqueezeNet
model = SqueezeNet()
# module的state_dict是一個字典
model.state_dict().keys()

輸出:

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.squeeze.weight', 'features.3.squeeze.bias', 'features.3.expand1x1.weight', 'features.3.expand1x1.bias', 'features.3.expand3x3.weight', 'features.3.expand3x3.bias', 'features.4.squeeze.weight', 'features.4.squeeze.bias', 'features.4.expand1x1.weight', 'features.4.expand1x1.bias', 'features.4.expand3x3.weight', 'features.4.expand3x3.bias', 'features.5.squeeze.weight', 'features.5.squeeze.bias', 'features.5.expand1x1.weight', 'features.5.expand1x1.bias', 'features.5.expand3x3.weight', 'features.5.expand3x3.bias', 'features.7.squeeze.weight', 'features.7.squeeze.bias', 'features.7.expand1x1.weight', 'features.7.expand1x1.bias', 'features.7.expand3x3.weight', 'features.7.expand3x3.bias', 'features.8.squeeze.weight', 'features.8.squeeze.bias', 'features.8.expand1x1.weight', 'features.8.expand1x1.bias', 'features.8.expand3x3.weight', 'features.8.expand3x3.bias', 'features.9.squeeze.weight', 'features.9.squeeze.bias', 'features.9.expand1x1.weight', 'features.9.expand1x1.bias', 'features.9.expand3x3.weight', 'features.9.expand3x3.bias', 'features.10.squeeze.weight', 'features.10.squeeze.bias', 'features.10.expand1x1.weight', 'features.10.expand1x1.bias', 'features.10.expand3x3.weight', 'features.10.expand3x3.bias', 'features.12.squeeze.weight', 'features.12.squeeze.bias', 'features.12.expand1x1.weight', 'features.12.expand1x1.bias', 'features.12.expand3x3.weight', 'features.12.expand3x3.bias', 'classifier.1.weight', 'classifier.1.bias'])
# Module對象的保存與加載
t.save(model.state_dict(), 'squeezenet.pth')
model.load_state_dict(t.load('squeezenet.pth'))

輸出:

<All keys matched successfully>
optimizer = t.optim.Adam(model.parameters(), lr=0.1)
t.save(optimizer.state_dict(), 'optimizer.pth')
optimizer.load_state_dict(t.load('optimizer.pth'))

all_data = dict(
    optimizer = optimizer.state_dict(),
    model = model.state_dict(),
    info = u'模型和優(yōu)化器的所有參數(shù)'
)
t.save(all_data, 'all.pth')

all_data = t.load('all.pth')
all_data.keys()

輸出:

dict_keys(['optimizer', 'model', 'info'])

本章介紹了一些工具模塊,這些工具有些位于PyTorch中腊满,有些是獨立于PyTorch的第三方模塊断序。這些模塊主要設計數(shù)據(jù)加載、可視化和GPU加速相關的內(nèi)容糜烹,合理地使用這些模塊能極大地提升我們的編程效率违诗。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市疮蹦,隨后出現(xiàn)的幾起案子诸迟,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 217,277評論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件阵苇,死亡現(xiàn)場離奇詭異壁公,居然都是意外死亡,警方通過查閱死者的電腦和手機绅项,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,689評論 3 393
  • 文/潘曉璐 我一進店門紊册,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人快耿,你說我怎么就攤上這事囊陡。” “怎么了掀亥?”我有些...
    開封第一講書人閱讀 163,624評論 0 353
  • 文/不壞的土叔 我叫張陵撞反,是天一觀的道長。 經(jīng)常有香客問我搪花,道長遏片,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,356評論 1 293
  • 正文 為了忘掉前任撮竿,我火速辦了婚禮吮便,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘幢踏。我一直安慰自己线衫,他們只是感情好,可當我...
    茶點故事閱讀 67,402評論 6 392
  • 文/花漫 我一把揭開白布惑折。 她就那樣靜靜地躺著授账,像睡著了一般。 火紅的嫁衣襯著肌膚如雪惨驶。 梳的紋絲不亂的頭發(fā)上白热,一...
    開封第一講書人閱讀 51,292評論 1 301
  • 那天,我揣著相機與錄音粗卜,去河邊找鬼令野。 笑死踱启,一個胖子當著我的面吹牛捆蜀,可吹牛的內(nèi)容都是我干的钠四。 我是一名探鬼主播,決...
    沈念sama閱讀 40,135評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼纱昧,長吁一口氣:“原來是場噩夢啊……” “哼刨啸!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起识脆,我...
    開封第一講書人閱讀 38,992評論 0 275
  • 序言:老撾萬榮一對情侶失蹤设联,失蹤者是張志新(化名)和其女友劉穎善已,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體离例,經(jīng)...
    沈念sama閱讀 45,429評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡换团,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,636評論 3 334
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了宫蛆。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片艘包。...
    茶點故事閱讀 39,785評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖耀盗,靈堂內(nèi)的尸體忽然破棺而出想虎,到底是詐尸還是另有隱情,我是刑警寧澤袍冷,帶...
    沈念sama閱讀 35,492評論 5 345
  • 正文 年R本政府宣布,位于F島的核電站猫牡,受9級特大地震影響胡诗,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜淌友,卻給世界環(huán)境...
    茶點故事閱讀 41,092評論 3 328
  • 文/蒙蒙 一煌恢、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧震庭,春花似錦瑰抵、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,723評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至拨拓,卻和暖如春肴颊,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背渣磷。 一陣腳步聲響...
    開封第一講書人閱讀 32,858評論 1 269
  • 我被黑心中介騙來泰國打工婿着, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人醋界。 一個月前我還...
    沈念sama閱讀 47,891評論 2 370
  • 正文 我出身青樓竟宋,卻偏偏與公主長得像,于是被迫代替她去往敵國和親形纺。 傳聞我的和親對象是個殘疾皇子丘侠,可洞房花燭夜當晚...
    茶點故事閱讀 44,713評論 2 354

推薦閱讀更多精彩內(nèi)容