在訓練神經(jīng)網(wǎng)絡的過程中需要用到很多工具,其中最重要的三部分是數(shù)據(jù)處理裸违、可視化和GPU加速厌秒。本章主要介紹PyTorch在這幾方面常用的工具村缸,合理使用這些工具能極大地提高編程效率。
5.1 數(shù)據(jù)處理
在解決深度學習問題的過程中痢虹,往往需要花費大量的精力去處理數(shù)據(jù)被去,包括圖像、文本奖唯、語音或其他二進制數(shù)據(jù)等惨缆。數(shù)據(jù)的處理對訓練神經(jīng)網(wǎng)絡來說十分重要,良好的數(shù)據(jù)處理不僅會加速模型訓練丰捷,也會提高模型效果坯墨。考慮到這一點病往,PyTorch提供了幾個高效便捷的工具捣染,以便使用者進行數(shù)據(jù)處理或者增強等操作,同時可通過并行化加速數(shù)據(jù)加載停巷。
(1)數(shù)據(jù)加載
在PyTorch中耍攘,數(shù)據(jù)加載可通過自定義的數(shù)據(jù)集對象實現(xiàn)榕栏。數(shù)據(jù)集對象被抽象為Dataset,實現(xiàn)自定義的數(shù)據(jù)集需要繼承Dataset蕾各,并實現(xiàn)兩個Python魔法方法扒磁。
- getitem:返回一條數(shù)據(jù)或一個樣本。obj[index]等價于obj.getitem(index)式曲。
- len:返回樣本的數(shù)量妨托。len(obj)等價于obj.len()。
這里我們以Kaggle經(jīng)典挑戰(zhàn)賽“Dogs vs. Cats”的數(shù)據(jù)為例吝羞,詳細講解如何處理數(shù)據(jù)兰伤。“Dogs vs. Cats”是一個分類問題脆贵,判斷一張圖片是狗還是貓医清,其所有圖片都存放在一個文件夾下,根據(jù)文件名的前綴判斷是狗還是貓卖氨。
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np
class DogCat(data.Dataset):
def __init__(self, root):
imgs = os.listdir(root)
# 所有圖片的絕對路徑
# 這里不實際加載圖片会烙,只是指定路徑,當調(diào)用__getitem__時才會真正讀圖片
self.imgs = [os.path.join(root, img) for img in imgs]
def __getitem__(self, index):
img_path = self.imgs[index]
# dog->1筒捺, cat->0
label = 1 if 'dog' in img_path.split('/')[-1] else 0
pil_img = Image.open(img_path)
array = np.asarray(pil_img)
data = t.from_numpy(array)
return data, label
def __len__(self):
return len(self.imgs)
dataset = DogCat('./data/dogcat/')
img, label = dataset[0] # 相當于調(diào)用dataset.__getitem__(0)
for img, label in dataset:
print(img.size(), img.float().mean(), label)
輸出:
torch.Size([500, 497, 3]) tensor(106.4915) 0
torch.Size([499, 379, 3]) tensor(171.8085) 0
torch.Size([236, 289, 3]) tensor(130.3004) 0
torch.Size([374, 499, 3]) tensor(115.5177) 0
torch.Size([375, 499, 3]) tensor(116.8139) 1
torch.Size([375, 499, 3]) tensor(150.5080) 1
torch.Size([377, 499, 3]) tensor(151.7174) 1
torch.Size([400, 300, 3]) tensor(128.1550) 1
通過上面的代碼柏腻,我們學習了如何自定義自己的數(shù)據(jù)集,并可以依次獲取系吭。但這里返回的數(shù)據(jù)不適合實際使用五嫂,因其具有如下兩方面問題:
- 返回樣本的形狀不一,因每張圖片的大小不一樣肯尺,這對于需要取batch訓練的神經(jīng)網(wǎng)絡來說很不友好沃缘。
- 返回樣本的數(shù)值較大,未歸一化至[-1, 1]则吟。
針對上述問題槐臀,PyTorch提供了torchvision。它是一個視覺工具包氓仲,提供了很多視覺圖像處理的工具水慨,其中transforms模塊提供了對PIL Image對象和Tensor對象的常用操作。
對PIL Image的操作包括:
- Scale:調(diào)整圖片尺寸敬扛,長寬比保持不變
- CenterCrop晰洒、`RandomCrop、RandomResizedCrop: 裁剪圖片
- Pad:填充
- ToTensor:將PIL Image對象轉(zhuǎn)成Tensor啥箭,會自動將[0, 255]歸一化至[0, 1]
對Tensor的操作包括:
- Normalize:標準化谍珊,即減均值,除以標準差
- ToPILImage:將Tensor轉(zhuǎn)為PIL Image對象
如果要對圖片進行多個操作急侥,可通過Compose函數(shù)將這些操作拼接起來抬驴,類似于nn.Sequential炼七。注意,這些操作定義后是以函數(shù)的形式存在布持,真正使用時需調(diào)用它的call方法,這點類似于nn.Module陕悬。例如要將圖片調(diào)整為224×224题暖,首先應構(gòu)建這個操作trans = Resize((224, 224)),然后調(diào)用trans(img)捉超。下面我們就用transforms的這些操作來優(yōu)化上面實現(xiàn)的dataset胧卤。
import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
transform = T.Compose([
T.Resize(224), # 縮放圖片(Image),保持長寬比不變拼岳,最短邊為224像素
T.CenterCrop(224), # 從圖片中間切出224*224的圖片
T.ToTensor(), # 將圖片(Image)轉(zhuǎn)成Tensor枝誊,歸一化至[0, 1]
T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 標準化至[-1, 1],規(guī)定均值和標準差
])
class DogCat(data.Dataset):
def __init__(self, root, transforms=None):
imgs = os.listdir(root)
self.imgs = [os.path.join(root, img) for img in imgs]
self.transforms=transforms
def __getitem__(self, index):
img_path = self.imgs[index]
label = 0 if 'dog' in img_path.split('/')[-1] else 1
data = Image.open(img_path)
if self.transforms:
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.imgs)
dataset = DogCat('./data/dogcat/', transforms=transform)
img, label = dataset[0]
for img, label in dataset:
print(img.size(), label)
輸出:
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
除了上述操作之外惜纸,transforms還可通過Lambda封裝自定義的轉(zhuǎn)換策略叶撒。例如想對PIL Image進行隨機旋轉(zhuǎn),則可寫成這樣trans=T.Lambda(lambda img: img.rotate(random()*360))耐版。
torchvision已經(jīng)預先實現(xiàn)了常用的Dataset祠够,包括前面使用過的CIFAR-10,以及ImageNet粪牲、COCO古瓤、MNIST、LSUN等數(shù)據(jù)集腺阳,可通過諸如torchvision.datasets.CIFAR10來調(diào)用落君,具體使用方法請參看官方文檔。在這里介紹一個會經(jīng)常使用到的Dataset——ImageFolder亭引,它的實現(xiàn)和上述的DogCat很相似绎速。ImageFolder假設所有的文件按文件夾保存,每個文件夾下存儲同一個類別的圖片痛侍,文件夾名為類名朝氓,其構(gòu)造函數(shù)如下:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
它主要有四個參數(shù):
- root:在root指定的路徑下尋找圖片。
- transform:對PIL Image進行的轉(zhuǎn)換操作主届,transform的輸入是使用loader讀取圖片的返回對象赵哲。
- target_transform:對label的轉(zhuǎn)換。
- loader:給定路徑后如何讀取圖片君丁,默認讀取為RGB格式的PIL Image對象枫夺。
label是按照文件夾名順序排序后存成字典,即{類名:類序號(從0開始)}绘闷,一般來說最好直接將文件夾命名為從0開始的數(shù)字橡庞,這樣會和ImageFolder實際的label一致较坛,如果不是這種命名規(guī)范,建議看看self.class_to_idx屬性以了解label和文件夾名的映射關系扒最。
from torchvision.datasets import ImageFolder
dataset = ImageFolder('data/dogcat_2/')
# cat文件夾的圖片對應label 0丑勤,dog對應1
dataset.class_to_idx
輸出:
{'cat': 0, 'dog': 1}
# 所有圖片的路徑和對應的label
dataset.imgs
輸出:
[('data/dogcat_2/cat\\cat.12484.jpg', 0),
('data/dogcat_2/cat\\cat.12485.jpg', 0),
('data/dogcat_2/cat\\cat.12486.jpg', 0),
('data/dogcat_2/cat\\cat.12487.jpg', 0),
('data/dogcat_2/dog\\dog.12496.jpg', 1),
('data/dogcat_2/dog\\dog.12497.jpg', 1),
('data/dogcat_2/dog\\dog.12498.jpg', 1),
('data/dogcat_2/dog\\dog.12499.jpg', 1)]
# 沒有任何的transform,所以返回的還是PIL Image對象
dataset[0][1] # 第一維是第幾張圖吧趣,第二維為1返回label
dataset[0][0] # 第一維是第幾張圖法竞,第二維為0返回圖片數(shù)據(jù)
輸出:
# 加上transform
normalize = T.Normalize(mean=[0.4, 0.4, 0.4], std=[0.2, 0.2, 0.2])
transform = T.Compose([
T.RandomResizedCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize,
])
dataset = ImageFolder('data/dogcat_2/', transform=transform)
# 深度學習中圖片數(shù)據(jù)一般保存成CxHxW,即通道數(shù)x圖片高x圖片寬
dataset[0][0].size()
輸出:
torch.Size([3, 224, 224])
to_img = T.ToPILImage()
# 0.2和0.4是標準差和均值的近似
to_img(dataset[0][0]*0.2+0.4)
輸出:
Dataset只負責數(shù)據(jù)的抽象强挫,一次調(diào)用getitem只返回一個樣本岔霸。前面提到過,在訓練神經(jīng)網(wǎng)絡時俯渤,最好是對一個batch的數(shù)據(jù)進行操作呆细,同時還需要對數(shù)據(jù)進行shuffle和并行加速等。對此八匠,PyTorch提供了DataLoader幫助我們實現(xiàn)這些功能絮爷。
DataLoader的函數(shù)定義如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
- dataset:加載的數(shù)據(jù)集(Dataset對象)
- batch_size:batch size
- shuffle::是否將數(shù)據(jù)打亂
- sampler: 樣本抽樣,后續(xù)會詳細介紹
- num_workers:使用多進程加載的進程數(shù)臀叙,0代表不使用多進程
- collate_fn: 如何將多個樣本數(shù)據(jù)拼接成一個batch略水,一般使用默認的拼接方式即可
- pin_memory:是否將數(shù)據(jù)保存在pin memory區(qū),pin memory中的數(shù)據(jù)轉(zhuǎn)到GPU會快一些
- drop_last:dataset中的數(shù)據(jù)個數(shù)可能不是batch_size的整數(shù)倍劝萤,drop_last為True會將多出來不足一個batch的數(shù)據(jù)丟棄
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataiter = iter(dataloader)
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight
輸出:
torch.Size([3, 3, 224, 224])
dataloader是一個可迭代的對象渊涝,意味著我們可以像使用迭代器一樣使用它,例如:
for batch_datas, batch_labels in dataloader:
train()
或
dataiter = iter(dataloader)
batch_datas, batch_labesl = next(dataiter)
在數(shù)據(jù)處理中床嫌,有時會出現(xiàn)某個樣本無法讀取等問題跨释,比如某張圖片損壞。這時在getitem函數(shù)中將出現(xiàn)異常厌处,此時最好的解決方案即是將出錯的樣本剔除鳖谈。如果實在是遇到這種情況無法處理,則可以返回None對象阔涉,然后在Dataloader中實現(xiàn)自定義的collate_fn缆娃,將空對象過濾掉。但要注意瑰排,在這種情況下dataloader返回的batch數(shù)目會少于batch_size贯要。
class NewDogCat(DogCat): # 繼承前面實現(xiàn)的DogCat數(shù)據(jù)集
def __getitem__(self, index):
try:
# 調(diào)用父類的獲取函數(shù),即 DogCat.__getitem__(self, index)
return super(NewDogCat,self).__getitem__(index)
except:
return None, None
from torch.utils.data.dataloader import default_collate # 導入默認的拼接方式
def my_collate_fn(batch):
'''
batch中每個元素形如(data, label)
'''
# 過濾為None的數(shù)據(jù)
batch = list(filter(lambda x:x[0] is not None, batch))
if len(batch) == 0: return t.Tensor()
return default_collate(batch) # 用默認方式拼接過濾后的batch數(shù)據(jù)
dataset = NewDogCat('data/dogcat_wrong/', transforms=transform)
dataset[5]
輸出:
(tensor([[[ 1.2549, 1.2549, 1.2549, ..., -0.0980, -0.0980, -0.1569],
[ 1.2941, 1.3137, 1.3137, ..., -0.0784, -0.0980, -0.0980],
[ 1.3137, 1.3137, 1.3137, ..., -0.0588, -0.0784, -0.0588],
...,
[ 1.3725, 1.3529, 1.3529, ..., -1.4314, -1.4314, -1.4118],
[ 1.3725, 1.3529, 1.3529, ..., -1.4314, -1.4314, -1.4314],
[ 1.3922, 1.3725, 1.3725, ..., -1.3922, -1.3922, -1.3922]],
[[ 1.0588, 0.9216, 0.8824, ..., 0.4314, 0.4510, 0.3922],
[ 1.0980, 0.9804, 0.9412, ..., 0.4314, 0.4118, 0.4118],
[ 1.1176, 0.9804, 0.9412, ..., 0.4314, 0.4118, 0.4314],
...,
[ 1.3725, 1.3137, 1.3137, ..., -1.1569, -1.1373, -1.1176],
[ 1.3725, 1.3137, 1.3137, ..., -1.1569, -1.1373, -1.1176],
[ 1.3922, 1.3137, 1.3137, ..., -1.0784, -1.0784, -1.0784]],
[[ 0.0392, -0.1765, -0.2549, ..., 1.4706, 1.4510, 1.3725],
[ 0.0784, -0.1176, -0.1961, ..., 1.5098, 1.4706, 1.4510],
[ 0.0980, -0.1176, -0.1765, ..., 1.5294, 1.5098, 1.5294],
...,
[ 0.3922, 0.3922, 0.3922, ..., -1.1373, -1.1569, -1.1569],
[ 0.3922, 0.3922, 0.4118, ..., -1.2353, -1.2745, -1.2745],
[ 0.4118, 0.4118, 0.4118, ..., -1.2157, -1.2745, -1.2745]]]), 0)
dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=0,shuffle=True)
for batch_datas, batch_labels in dataloader:
print(batch_datas.size(),batch_labels.size())
輸出:
torch.Size([1, 3, 224, 224]) torch.Size([1])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])
來看一下上述batch_size的大小椭住。其中第1個的batch_size為1崇渗,這是因為有一張圖片損壞,導致其無法正常返回。而最后1個的batch_size也為1宅广,這是因為共有9張(包括損壞的文件)圖片葫掉,無法整除2(batch_size),因此最后一個batch的數(shù)據(jù)會少于batch_szie跟狱,可通過指定drop_last=True來丟棄最后一個不足batch_size的batch俭厚。
對于諸如樣本損壞或數(shù)據(jù)集加載異常等情況,還可以通過其它方式解決驶臊。例如但凡遇到異常情況套腹,就隨機取一張圖片代替:
class NewDogCat(DogCat):
def __getitem__(self, index):
try:
return super(NewDogCat, self).__getitem__(index)
except:
new_index = random.randint(0, len(self)-1)
return self[new_index]
相比較丟棄異常圖片而言,這種做法會更好一些资铡,因為它能保證每個batch的數(shù)目仍是batch_size。但在大多數(shù)情況下幢码,最好的方式還是對數(shù)據(jù)進行徹底清洗笤休。
DataLoader里面并沒有太多的魔法方法,它封裝了Python的標準庫multiprocessing症副,使其能夠?qū)崿F(xiàn)多進程加速店雅。在此提幾點關于Dataset和DataLoader使用方面的建議:
- 高負載的操作放在getitem中,如加載圖片等贞铣。
- dataset中應盡量只包含只讀對象闹啦,避免修改任何可變對象,利用多線程進行操作辕坝。
第一點是因為多進程會并行的調(diào)用getitem函數(shù)窍奋,將負載高的放在getitem函數(shù)中能夠?qū)崿F(xiàn)并行加速。 第二點是因為dataloader使用多進程加載酱畅,如果在Dataset實現(xiàn)中使用了可變對象琳袄,可能會有意想不到的沖突。在多線程/多進程中纺酸,修改一個可變對象窖逗,需要加鎖,但是dataloader的設計使得其很難加鎖(在實際使用中也應盡量避免鎖的存在)餐蔬,因此最好避免在dataset中修改可變對象碎紊。例如下面就是一個不好的例子,在多進程處理中self.num可能與預期不符樊诺,這種問題不會報錯仗考,因此難以發(fā)現(xiàn)。如果一定要修改可變對象啄骇,建議使用Python標準庫Queue中的相關數(shù)據(jù)結(jié)構(gòu)痴鳄。
class BadDataset(Dataset):
def __init__(self):
self.datas = range(100)
self.num = 0 # 取數(shù)據(jù)的次數(shù)
def __getitem__(self, index):
self.num += 1
return self.datas[index]
使用Python multiprocessing庫的另一個問題是,在使用多進程時缸夹,如果主程序異常終止(比如用Ctrl+C強行退出)痪寻,相應的數(shù)據(jù)加載進程可能無法正常退出螺句。這時你可能會發(fā)現(xiàn)程序已經(jīng)退出了,但GPU顯存和內(nèi)存依舊被占用著橡类,或通過top蛇尚、ps aux依舊能夠看到已經(jīng)退出的程序,這時就需要手動強行殺掉進程顾画。建議使用如下命令:
ps x | grep <cmdline> | awk '{print $1}' | xargs kill
- ps x:獲取當前用戶的所有進程
- grep <cmdline>:找到已經(jīng)停止的PyTorch程序的進程取劫,例如你是通過python train.py啟動的,那你就需要寫grep 'python train.py'
- awk '{print $1}':獲取進程的pid
- xargs kill:殺掉進程研侣,根據(jù)需要可能要寫成xargs kill -9強制殺掉進程
在執(zhí)行這句命令之前谱邪,建議先打印確認一下是否會誤殺其它進程
ps x | grep <cmdline> | ps x
PyTorch中還單獨提供了一個sampler模塊,用來對數(shù)據(jù)進行采樣庶诡。常用的有隨機采樣器:RandomSampler惦银,當dataloader的shuffle參數(shù)為True時,系統(tǒng)會自動調(diào)用這個采樣器末誓,實現(xiàn)打亂數(shù)據(jù)扯俱。默認的是采用SequentialSampler,它會按順序一個一個進行采樣喇澡。這里介紹另外一個很有用的采樣方法: WeightedRandomSampler迅栅,它會根據(jù)每個樣本的權(quán)重選取數(shù)據(jù),在樣本比例不均衡的問題中晴玖,可用它來進行重采樣读存。
構(gòu)建WeightedRandomSampler時需提供兩個參數(shù):每個樣本的權(quán)重weights、共選取的樣本總數(shù)num_samples窜醉,以及一個可選參數(shù)replacement宪萄。權(quán)重越大的樣本被選中的概率越大,待選取的樣本數(shù)目一般小于全部的樣本數(shù)目榨惰。replacement用于指定是否可以重復選取某一個樣本拜英,默認為True,即允許在一個epoch中重復采樣某一個數(shù)據(jù)琅催。如果設為False居凶,則當某一類的樣本被全部選取完,但其樣本數(shù)目仍未達到num_samples時藤抡,sampler將不會再從該類中選擇數(shù)據(jù)侠碧,此時可能導致weights參數(shù)失效。下面舉例說明缠黍。
dataset = DogCat('data/dogcat/', transforms=transform)
# 狗的圖片被取出的概率是貓的概率的兩倍
# 兩類圖片被取出的概率與weights的絕對大小無關弄兜,只和比值有關
weights = [2 if label == 1 else 1 for data, label in dataset]
weights
輸出:
[2, 2, 2, 2, 1, 1, 1, 1]
from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
num_samples=9,\
replacement=True)
dataloader = DataLoader(dataset,
batch_size=3,
sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
輸出:
[1, 1, 1]
[0, 0, 1]
[1, 1, 1]
可見貓狗樣本比例約為1:2,另外一共只有8個樣本,但是卻返回了9個替饿,說明肯定有被重復返回的语泽,這就是replacement參數(shù)的作用,下面將replacement設為False試試视卢。
sampler = WeightedRandomSampler(weights, 8, replacement=False)
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
輸出:
[1, 1, 1, 1]
[0, 0, 0, 0]
在這種情況下踱卵,num_samples等于dataset的樣本總數(shù),為了不重復選取据过,sampler會將每個樣本都返回惋砂,這樣就失去weight參數(shù)的意義了。
從上面的例子可見sampler在樣本采樣中的作用:如果指定了sampler绳锅,shuffle將不再生效西饵,并且sampler.num_samples會覆蓋dataset的實際大小,即一個epoch返回的圖片總數(shù)取決于sampler.num_samples鳞芙。
5.2 計算機視覺工具包:torchvision
計算機視覺是深度學習中最重要的一類應用罗标,為了方便研究者使用,PyTorch團隊專門開發(fā)了一個視覺工具包torchvision积蜻,這個包獨立于PyTorch挖函,需通過pip instal torchvision安裝迈着。在之前的例子中我們已經(jīng)見識到了它的部分功能,這里再做一個系統(tǒng)性的介紹笨触。torchvision主要包含三部分:
- models:提供深度學習中各種經(jīng)典網(wǎng)絡的網(wǎng)絡結(jié)構(gòu)以及預訓練好的模型宾尚,包括AlexNet丙笋、VGG系列、ResNet系列煌贴、Inception系列等御板。
- datasets: 提供常用的數(shù)據(jù)集加載,設計上都是繼承torhc.utils.data.Dataset牛郑,主要包括MNIST怠肋、CIFAR10/100、ImageNet淹朋、COCO等笙各。
- transforms:提供常用的數(shù)據(jù)預處理操作,主要包括對Tensor以及PIL Image對象的操作础芍。
from torchvision import models
from torch import nn
from torchvision import datasets
# 加載預訓練好的模型杈抢,如果不存在會進行下載
# 預訓練好的模型保存在 ~/.torch/models/下面
resnet34 = models.squeezenet1_1(pretrained=True, num_classes=1000)
# 修改最后的全連接層為10分類問題(默認是ImageNet上的1000分類)
resnet34.fc=nn.Linear(512, 10)
# 加上transform
transform = T.Compose([
T.ToTensor(),
T.Normalize(mean=[0.4,], std=[0.2,]),
])
# 指定數(shù)據(jù)集路徑為data,如果數(shù)據(jù)集不存在則進行下載
# 通過train=False獲取測試集
dataset = datasets.MNIST('data/', download=True, train=False, transform=transform)
Transforms中涵蓋了大部分對Tensor和PIL Image的常用處理仑性,這些已在上文提到惶楼,這里就不再詳細介紹。需要注意的是轉(zhuǎn)換分為兩步,第一步:構(gòu)建轉(zhuǎn)換操作歼捐,例如transf = transforms.Normalize(mean=x, std=y)何陆,第二步:執(zhí)行轉(zhuǎn)換操作,例如output = transf(input)窥岩。另外還可將多個處理操作用Compose拼接起來甲献,形成一個處理轉(zhuǎn)換流程。
from torchvision import transforms
to_pil = transforms.ToPILImage()
to_pil(t.randn(3, 64, 64))
輸出:隨機噪聲
torchvision還提供了兩個常用的函數(shù)颂翼。一個是make_grid晃洒,它能將多張圖片拼接成一個網(wǎng)格中;另一個是save_img朦乏,它能將Tensor保存成圖片球及。
len(dataset)
輸出:
10000
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
from torchvision.utils import make_grid, save_image
dataiter = iter(dataloader)
img = make_grid(next(dataiter)[0], 4) # 拼成4*4網(wǎng)格圖片,且會轉(zhuǎn)成3通道
to_img(img)
輸出:
save_image(img, 'a.png')
Image.open('a.png')
5.3 可視化工具
在訓練神經(jīng)網(wǎng)絡時呻疹,我們希望能更直觀地了解訓練情況吃引,包括損失曲線、輸入圖片刽锤、輸出圖片镊尺、卷積核的參數(shù)分布等信息。這些信息能幫助我們更好地監(jiān)督網(wǎng)絡的訓練過程并思,并為參數(shù)優(yōu)化提供方向和依據(jù)庐氮。最簡單的辦法就是打印輸出,但其只能打印數(shù)值信息宋彼,不夠直觀弄砍,同時無法查看分布、圖片输涕、聲音等音婶。在本節(jié),我們將介紹兩個深度學習中常用的可視化工具:Tensorboard和Visdom莱坎。
5.3.1 Tensorboard
Tensorboard最初是作為TensorFlow的可視化工具迅速流行開來衣式。作為和TensorFlow深度集成的工具,Tensorboard能夠展現(xiàn)你的TensorFlow網(wǎng)絡計算圖檐什,繪制圖像生成的定量指標圖以及附加數(shù)據(jù)瞳收。但同時Tensorboard也是一個相對獨立的工具,只要用戶保存的數(shù)據(jù)遵循相應的格式厢汹,tensorboard就能讀取這些數(shù)據(jù)并進行可視化螟深。這里我們將主要介紹如何在PyTorch中使用tensorboardX進行訓練損失的可視化。 TensorboardX是將Tensorboard的功能抽取出來烫葬,使得非TensorFlow用戶也能使用它進行可視化界弧,幾乎支持原生TensorBoard的全部功能凡蜻。
tensorboard的安裝主要分為以下兩步:
(1)安裝TensorFlow:如果電腦中已經(jīng)安裝完TensorFlow可以跳過這一步,如果電腦中尚未安裝垢箕,建議安裝CPU-Only的版本划栓,具體安裝教程參見TensorFlow官網(wǎng),或使用pip命令直接安裝条获。
- 安裝tensorboard:
pip install tensorboard
- 安裝tensorboardX:可通過
pip install tensorboardX
命令直接安裝忠荞。
(2)安裝tensorboard_logger:可通過pip install tensorboard_logger命令直接安裝。
tensorboardX的使用非常簡單帅掘。首先用如下命令啟動tensorboard:
tensorboard --logdir <your/running/dir> --port <your_bind_port>
下面舉例說明tensorboardX的使用委煤。
打開一個新的終端,進入到你指定的tensorboard日志的上一層目錄修档,如:這里把生成的tensorboard日志放在E:/log目錄下碧绞,就需要進入E:/開啟tensorboard服務:
C:\Users\Mloong>E:
E:\>tensorboard --logdir log
TensorBoard 1.9.0 at http://DESKTOP-1LO98I2:6006 (Press CTRL+C to quit)
下面模擬訓練過程生成loss和accuracy,寫入E:/log目錄:
from tensorboardX import SummaryWriter
writer = SummaryWriter(log_dir='E:/log', flush_secs=2)
for ii in range(100):
writer.add_scalar('data/loss', 100-ii**0.5, ii)
writer.add_scalar('data/accuracy', ii/100, ii)
writer.close()
注意:tensorboard是谷歌的產(chǎn)品讥邻,最好使用谷歌瀏覽器(像360瀏覽器、IE瀏覽器都有問題)院峡。
打開谷歌瀏覽器輸入啟動日志中的網(wǎng)址http://DESKTOP-1LO98I2:6006
(6006是默認端口號兴使,可以使用--port選項指定端口號),即可看到如圖所示的結(jié)果照激。
左側(cè)的Horizontal Axis下有三個選項鲫惶,分別是:
- Step:根據(jù)步長來記錄,如果有步長实抡,則將其作為x軸坐標描點畫線。
- Relative:用前后相對順序描點畫線欢策,每調(diào)用一次就自動加1吆寨。
- Wall:按時間排序描點畫線。
左側(cè)的Smoothing條可以左右拖動踩寇,用來調(diào)節(jié)平滑的幅度啄清。點擊右上角的刷新按鈕可立即刷新結(jié)果,默認是每30s自動刷新數(shù)據(jù)俺孙±弊洌可見tensorboard_logger的使用十分簡單,但它只能統(tǒng)計簡單的數(shù)值信息睛榄,不支持其它功能荣茫。
感興趣的讀者可以從github項目主頁獲取更多信息,本節(jié)將把更多的內(nèi)容留給另一個可視化工具:Visdom场靴。
5.3.2 visdom
Visdom是Facebook專門為PyTorch開發(fā)的一款可視化工具啡莉,其開源于2017年3月港准。Visdom十分輕量級,但卻支持非常豐富的功能咧欣,能勝任大多數(shù)的科學運算可視化任務浅缸。其可視化界面如圖所示。
Visdom可以創(chuàng)造魄咕、組織和共享多種數(shù)據(jù)的可視化衩椒,包括數(shù)值、圖像哮兰、文本毛萌,甚至是視頻,其支持PyTorch奠蹬、Torch及Numpy朝聋。用戶可通過編程組織可視化空間,或通過用戶接口為生動數(shù)據(jù)打造儀表板囤躁,檢查實驗結(jié)果或調(diào)試代碼冀痕。
Visdom中有兩個重要概念:
- env:環(huán)境。不同環(huán)境的可視化結(jié)果相互隔離狸演,互不影響言蛇,在使用時如果不指定env,默認使用
main
宵距。不同用戶腊尚、不同程序一般使用不同的env。 - pane:窗格满哪。窗格可用于可視化圖像婿斥、數(shù)值或打印文本等,其可以拖動哨鸭、縮放民宿、保存和關閉。一個程序中可使用同一個env中的不同pane像鸡,每個pane可視化或記錄某一信息活鹰。
如圖所示,當前env='test'共有6個pane只估,分別展示不同的結(jié)果志群。點擊clear按鈕可以清空當前env的所有pane,點擊save按鈕可將當前env保存成json文件蛔钙,保存路徑位于~/.visdom/
目錄下锌云。也可修改env的名字后點擊fork,保存當前env的狀態(tài)至更名后的env吁脱。
Visdom的安裝可通過命令pip install visdom
宾抓。安裝完成后子漩,需通過python -m visdom.server
命令啟動visdom服務,或通過nohup python -m visdom.server &
命令將服務放至后臺運行石洗。Visdom服務是一個web server服務幢泼,默認綁定8097端口,客戶端與服務器間通過tornado進行非阻塞交互讲衫。
python -m visdom.server
輸出下面日志則表示啟動成功缕棵,拷貝網(wǎng)址到瀏覽器進行可視化。
It's Alive!
INFO:root:Application Started
You can navigate to http://localhost:8097
Visdom的使用有兩點需要注意的地方:
- 需手動指定保存env涉兽,可在web界面點擊save按鈕或在程序中調(diào)用save方法招驴,否則visdom服務重啟后,env等信息會丟失枷畏。
- 客戶端與服務器之間的交互采用tornado異步框架别厘,可視化操作不會阻塞當前程序,網(wǎng)絡異常也不會導致程序退出拥诡。
Visdom以Plotly為基礎触趴,支持豐富的可視化操作,下面舉例說明一些最常用的操作渴肉。
import visdom
# 新建一個連接客戶端
# 指定server(默認為‘localhost')指定端口(默認為'8097')冗懦,指定環(huán)境(默認為'main')
vis = visdom.Visdom(server='localhost',port='8097',env=u'test1')
x = t.arange(1, 30, 0.01)
y = t.sin(x)
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'})
下面逐一分析這幾行代碼:
vis = visdom.Visdom(server='localhost',port='8097',env=u'test1'),用于構(gòu)建一個連接客戶端仇祭,可以指定server披蕉、port、env等參數(shù)乌奇。
vis作為一個客戶端對象没讲,可以使用常見的畫圖函數(shù),包括:
- line:類似Matlab中的plot操作礁苗,用于記錄某些標量的變化爬凑,如損失、準確率等
- image:可視化圖片寂屏,可以是輸入的圖片,也可以是GAN生成的圖片娜搂,還可以是卷積核的信息
- text:用于記錄日志等文字信息迁霎,支持html格式
- histgram:可視化分布,主要是查看數(shù)據(jù)百宇、參數(shù)的分布
- scatter:繪制散點圖
- bar:繪制柱狀圖
- pie:繪制餅狀圖
更多操作可參考visdom的github主頁考廉。這里主要介紹深度學習中常見的line、image和text操作携御。
Visdom同時支持PyTorch的tensor和Numpy的ndarray兩種數(shù)據(jù)結(jié)構(gòu)昌粤,但不支持Python的int既绕、float等類型,因此每次傳入時都需先將數(shù)據(jù)轉(zhuǎn)成ndarray或tensor涮坐。上述操作的參數(shù)一般不同凄贩,但有兩個參數(shù)是絕大多數(shù)操作都具備的:
- win:用于指定pane的名字,如果不指定袱讹,visdom將自動分配一個新的pane疲扎。如果兩次操作指定的win名字一樣,新的操作將覆蓋當前pane的內(nèi)容捷雕,因此建議每次操作都重新指定win椒丧。
- opts:選項,接收一個字典救巷,常見的option包括title壶熏、xlabel、ylabel浦译、width等棒假,主要用于設置pane的顯示格式。
之前提到過管怠,每次操作都會覆蓋之前的數(shù)值淆衷,但往往我們在訓練網(wǎng)絡的過程中需不斷更新數(shù)值,如損失值等渤弛,這時就需要指定參數(shù)update='append'來避免覆蓋之前的數(shù)值祝拯。而除了使用update參數(shù)以外,還可以使用vis.updateTrace方法來更新圖她肯,但updateTrace不僅能在指定pane上新增一個和已有數(shù)據(jù)相互獨立的Trace佳头,還能像update='append'那樣在同一條trace上追加數(shù)據(jù)。
# append 追加數(shù)據(jù)
for ii in range(0, 10):
x = t.Tensor([ii])
y = x
vis.line(X=x, Y=y, win='polynomial', update='append' if ii>0 else None)
# updateTrace 新增一條線
x = t.arange(0, 9, 0.1)
y = (x ** 2) / 9
vis.line(X=x, Y=y, win='polynomial', name='this is a new Trace',update='new')
結(jié)果如下圖所示晴氨。
image的畫圖功能可分為如下兩類:
- image接收一個二維或三維向量康嘉,H×W 或 3×H×W,前者是黑白圖像籽前,后者是彩色圖像亭珍。
- images接收一個四維向量 N×C×H×W,C 可以是1或3枝哄,分別代表黑白和彩色圖像肄梨。可實現(xiàn)類似torchvision中make_grid的功能挠锥,將多張圖片拼接在一起众羡。images也可以接收一個二維或三維的向量,此時它所實現(xiàn)的功能與image一致蓖租。
# 可視化一個隨機的黑白圖片
vis.image(t.randn(64, 64).numpy())
# 隨機可視化一張彩色圖片
vis.image(t.randn(3, 64, 64).numpy(), win='random2')
# 可視化36張隨機的彩色圖片粱侣,每一行6張
vis.images(t.randn(36, 3, 64, 64).numpy(), nrow=6, win='random3', opts={'title':'random_imgs'})
結(jié)果如下:
vis.text用于可視化文本羊壹,支持所有的html標簽,同時也遵循著html的語法標準齐婴。例如油猫,換行需使用
標簽,\r\n無法實現(xiàn)換行尔店。下面舉例說明眨攘。
vis.text(u'''<h1>Hello Visdom</h1><br>Visdom是Facebook專門為<b>PyTorch</b>開發(fā)的一個可視化工具,
在內(nèi)部使用了很久嚣州,在2017年3月份開源了它鲫售。
Visdom十分輕量級,但是卻有十分強大的功能该肴,支持幾乎所有的科學運算可視化任務''',
win='visdom',
opts={'title': u'visdom簡介' }
)
結(jié)果如下:
5.4 使用GPU加速:cuda
這部分內(nèi)容在前面介紹Tensor情竹、Module時大都提到過,這里將做一個總結(jié)匀哄,并深入介紹相關應用秦效。
在PyTorch中以下數(shù)據(jù)結(jié)構(gòu)分為CPU和GPU兩個版本:
- Tensor
- nn.Module(包括常用的layer、loss function涎嚼,以及容器Sequential等)
它們都帶有一個.cuda方法阱州,調(diào)用此方法即可將其轉(zhuǎn)為對應的GPU對象。注意法梯,tensor.cuda會返回一個新對象苔货,這個新對象的數(shù)據(jù)已轉(zhuǎn)移至GPU,而之前的tensor還在原來的設備上(CPU)立哑。而module.cuda則會將所有的數(shù)據(jù)都遷移至GPU夜惭,并返回自己。所以module = module.cuda()和module.cuda()所起的作用一致铛绰。
nn.Module在GPU與CPU之間的轉(zhuǎn)換诈茧,本質(zhì)上還是利用了Tensor在GPU和CPU之間的轉(zhuǎn)換。nn.Module的cuda方法是將nn.Module下的所有parameter(包括子module的parameter)都轉(zhuǎn)移至GPU捂掰,而Parameter本質(zhì)上也是tensor(Tensor的子類)敢会。
下面將舉例說明,這部分代碼需要你具有兩塊GPU設備这嚣。>>>呵呵鸥昏,反正我沒有<<<
P.S. 為什么將數(shù)據(jù)轉(zhuǎn)移至GPU的方法叫做.cuda而不是.gpu,就像將數(shù)據(jù)轉(zhuǎn)移至CPU調(diào)用的方法是.cpu疤苹?這是因為GPU的編程接口采用CUDA互广,而目前并不是所有的GPU都支持CUDA敛腌,只有部分Nvidia的GPU才支持卧土。PyTorch未來可能會支持AMD的GPU惫皱,而AMD GPU的編程接口采用OpenCL,因此PyTorch還預留著.cl方法尤莺,用于以后支持AMD等的GPU旅敷。
tensor = t.Tensor(3, 4)
# 返回一個新的tensor,但原來的tensor并沒有改變
tensor.cuda(0)
tensor.is_cuda
輸出:
False
# 重新賦給自己颤霎,tensor指向cuda上的數(shù)據(jù)媳谁,不再執(zhí)行原數(shù)據(jù)
tensor = tensor.cuda()
tensor.is_cuda
輸出:
True
# 將nn.Module模型放到cuda上,其子模型也都自動放到cuda上
from torch import nn
module = nn.Linear(3, 4)
module.cuda(device = 0)
module.weight.is_cuda
輸出:
True
class VeryBigModule(nn.Module):
def __init__(self):
super(VeryBigModule, self).__init__()
self.GiantParameter1 = t.nn.Parameter(t.randn(100000, 20000)).cuda(0)
self.GiantParameter2 = t.nn.Parameter(t.randn(20000, 100000)).cuda(1)
def forward(self, x):
x = self.GiantParameter1.mm(x.cuda(0))
x = self.GiantParameter2.mm(x.cuda(1))
return x
上面最后一部分中友酱,兩個Parameter所占用的內(nèi)存空間都非常大晴音,大概是8個G,如果將這兩個都同時放在一塊GPU上幾乎會將顯存占滿缔杉,無法再進行任何其它運算锤躁。此時可通過這種方式將不同的計算分布到不同的GPU中。
關于使用GPU的一些建議:
- GPU運算很快或详,但對于很小的運算量來說系羞,并不能體現(xiàn)出它的優(yōu)勢,因此對于一些簡單的操作可直接利用CPU完成
- 數(shù)據(jù)在CPU和GPU之間霸琴,以及GPU與GPU之間的傳遞會比較耗時椒振,應當盡量避免
- 在進行低精度的計算時,可以考慮HalfTensor梧乘,它相比于FloatTensor能節(jié)省一半的顯存澎迎,但需千萬注意數(shù)值溢出的情況。
另外這里需要專門提一下宋下,大部分的損失函數(shù)也都屬于nn.Moudle嗡善,但在使用GPU時,很多時候我們都忘記使用它的.cuda方法学歧,這在大多數(shù)情況下不會報錯罩引,因為損失函數(shù)本身沒有可學習的參數(shù)(learnable parameters)。但在某些情況下會出現(xiàn)問題枝笨,為了保險起見同時也為了代碼更規(guī)范袁铐,應記得調(diào)用criterion.cuda。下面舉例說明横浑。
# 交叉熵損失函數(shù)剔桨,帶權(quán)重
criterion = t.nn.CrossEntropyLoss(weight=t.Tensor([1, 3]))
input = t.randn(4, 2).cuda()
target = t.Tensor([1, 0, 0, 1]).long().cuda()
# 下面這行會報錯,因weight未被轉(zhuǎn)移至GPU
# loss = criterion(input, target)
# 這行則不會報錯
criterion.cuda()
loss = criterion(input, target)
criterion._buffers
輸出:
OrderedDict([('weight', tensor([1., 3.], device='cuda:0'))])
而除了調(diào)用對象的.cuda方法之外徙融,還可以使用torch.cuda.device洒缀,來指定默認使用哪一塊GPU,或使用torch.set_default_tensor_type使程序默認使用GPU,不需要手動調(diào)用cuda树绩。
# 如果未指定使用哪塊GPU萨脑,默認使用GPU 0
x = t.cuda.FloatTensor(2, 3)
# x.get_device() == 0
y = t.FloatTensor(2, 3).cuda()
# y.get_device() == 0
# 指定默認使用GPU 0
with t.cuda.device(0):
# 在GPU 0上構(gòu)建tensor
a = t.cuda.FloatTensor(2, 3)
# 將tensor轉(zhuǎn)移至GPU 0
b = t.FloatTensor(2, 3).cuda()
print(a.get_device() == b.get_device() == 0 )
c = a + b
print(c.get_device() == 0)
z = x + y
print(z.get_device() == 0)
# 手動指定使用GPU 0
d = t.randn(2, 3).cuda(0)
print(d.get_device() == 2)
輸出:
True
True
True
False
t.set_default_tensor_type('torch.cuda.FloatTensor') # 指定默認tensor的類型為GPU上的FloatTensor
a = t.ones(2, 3)
a.is_cuda
輸出:
True
如果服務器具有多個GPU,tensor.cuda()
方法會將tensor保存到第一塊GPU上饺饭,等價于tensor.cuda(0)
渤早。此時如果想使用第二塊GPU,需手動指定tensor.cuda(1)
瘫俊,而這需要修改大量代碼鹊杖,很是繁瑣。這里有兩種替代方法:
- 一種是先調(diào)用
t.cuda.set_device(1)
指定使用第二塊GPU扛芽,后續(xù)的.cuda()
都無需更改骂蓖,切換GPU只需修改這一行代碼。 - 更推薦的方法是設置環(huán)境變量
CUDA_VISIBLE_DEVICES
川尖,例如當export CUDA_VISIBLE_DEVICE=1
(下標是從0開始涯竟,1代表第二塊GPU),只使用第二塊物理GPU空厌,但在程序中這塊GPU會被看成是第一塊邏輯GPU庐船,因此此時調(diào)用tensor.cuda()
會將Tensor轉(zhuǎn)移至第二塊物理GPU。CUDA_VISIBLE_DEVICES
還可以指定多個GPU嘲更,如export CUDA_VISIBLE_DEVICES=0,2,3
筐钟,那么第一、三赋朦、四塊物理GPU會被映射成第一篓冲、二、三塊邏輯GPU宠哄,tensor.cuda(1)
會將Tensor轉(zhuǎn)移到第三塊物理GPU上壹将。
設置CUDA_VISIBLE_DEVICES
有兩種方法,一種是在命令行中CUDA_VISIBLE_DEVICES=0,1 python main.py
毛嫉,一種是在程序中import os;os.environ["CUDA_VISIBLE_DEVICES"] = "2"
诽俯。如果使用IPython或者Jupyter notebook,還可以使用%env CUDA_VISIBLE_DEVICES=1,2
來設置環(huán)境變量承粤。
從 0.4 版本開始暴区,pytorch新增了tensor.to(device)
方法,能夠?qū)崿F(xiàn)設備透明辛臊,便于實現(xiàn)CPU/GPU兼容仙粱。這部份內(nèi)容已經(jīng)在第三章講解過了。
從PyTorch 0.2版本中彻舰,PyTorch新增分布式GPU支持伐割。分布式是指有多個GPU在多臺服務器上候味,而并行一般指的是一臺服務器上的多個GPU。分布式涉及到了服務器之間的通信隔心,因此比較復雜负溪,PyTorch封裝了相應的接口,可以用幾句簡單的代碼實現(xiàn)分布式訓練济炎。分布式對普通用戶來說比較遙遠,因為搭建一個分布式集群的代價十分大辐真,使用也比較復雜须尚。相比之下一機多卡更加現(xiàn)實。對于分布式訓練侍咱,這里不做太多的介紹耐床,感興趣的讀者可參考文檔distributed。
5.5 持久化
在PyTorch中楔脯,以下對象可以持久化到硬盤撩轰,并能通過相應的方法加載到內(nèi)存中:
- Tensor
- Variable
- nn.Module
- Optimizer
本質(zhì)上上述這些信息最終都是保存成Tensor。Tensor的保存和加載十分的簡單昧廷,使用t.save和t.load即可完成相應的功能堪嫂。在save/load時可指定使用的pickle模塊,在load時還可將GPU tensor映射到CPU或其它GPU上木柬。
我們可以通過t.save(obj, file_name)等方法保存任意可序列化的對象皆串,然后通過obj = t.load(file_name)方法加載保存的數(shù)據(jù)。對于Module和Optimizer對象眉枕,這里建議保存對應的state_dict恶复,而不是直接保存整個Module/Optimizer對象。Optimizer對象保存的主要是參數(shù)速挑,以及動量信息谤牡,通過加載之前的動量信息,能夠有效地減少模型震蕩姥宝,下面舉例說明翅萤。
a = t.Tensor(3, 4)
if t.cuda.is_available():
a = a.cuda(0) # 把a轉(zhuǎn)為GPU0上的tensor,
t.save(a,'a.pth')
# 加載為b, 存儲于GPU0上(因為保存時tensor就在GPU0上)
b = t.load('a.pth')
# 加載為c, 存儲于CPU
c = t.load('a.pth', map_location=lambda storage, loc: storage)
# 加載為d, 存儲于GPU0上
d = t.load('a.pth', map_location={'cuda:1':'cuda:0'})
t.set_default_tensor_type('torch.FloatTensor')
from torchvision.models import SqueezeNet
model = SqueezeNet()
# module的state_dict是一個字典
model.state_dict().keys()
輸出:
odict_keys(['features.0.weight', 'features.0.bias', 'features.3.squeeze.weight', 'features.3.squeeze.bias', 'features.3.expand1x1.weight', 'features.3.expand1x1.bias', 'features.3.expand3x3.weight', 'features.3.expand3x3.bias', 'features.4.squeeze.weight', 'features.4.squeeze.bias', 'features.4.expand1x1.weight', 'features.4.expand1x1.bias', 'features.4.expand3x3.weight', 'features.4.expand3x3.bias', 'features.5.squeeze.weight', 'features.5.squeeze.bias', 'features.5.expand1x1.weight', 'features.5.expand1x1.bias', 'features.5.expand3x3.weight', 'features.5.expand3x3.bias', 'features.7.squeeze.weight', 'features.7.squeeze.bias', 'features.7.expand1x1.weight', 'features.7.expand1x1.bias', 'features.7.expand3x3.weight', 'features.7.expand3x3.bias', 'features.8.squeeze.weight', 'features.8.squeeze.bias', 'features.8.expand1x1.weight', 'features.8.expand1x1.bias', 'features.8.expand3x3.weight', 'features.8.expand3x3.bias', 'features.9.squeeze.weight', 'features.9.squeeze.bias', 'features.9.expand1x1.weight', 'features.9.expand1x1.bias', 'features.9.expand3x3.weight', 'features.9.expand3x3.bias', 'features.10.squeeze.weight', 'features.10.squeeze.bias', 'features.10.expand1x1.weight', 'features.10.expand1x1.bias', 'features.10.expand3x3.weight', 'features.10.expand3x3.bias', 'features.12.squeeze.weight', 'features.12.squeeze.bias', 'features.12.expand1x1.weight', 'features.12.expand1x1.bias', 'features.12.expand3x3.weight', 'features.12.expand3x3.bias', 'classifier.1.weight', 'classifier.1.bias'])
# Module對象的保存與加載
t.save(model.state_dict(), 'squeezenet.pth')
model.load_state_dict(t.load('squeezenet.pth'))
輸出:
<All keys matched successfully>
optimizer = t.optim.Adam(model.parameters(), lr=0.1)
t.save(optimizer.state_dict(), 'optimizer.pth')
optimizer.load_state_dict(t.load('optimizer.pth'))
all_data = dict(
optimizer = optimizer.state_dict(),
model = model.state_dict(),
info = u'模型和優(yōu)化器的所有參數(shù)'
)
t.save(all_data, 'all.pth')
all_data = t.load('all.pth')
all_data.keys()
輸出:
dict_keys(['optimizer', 'model', 'info'])
本章介紹了一些工具模塊,這些工具有些位于PyTorch中腊满,有些是獨立于PyTorch的第三方模塊断序。這些模塊主要設計數(shù)據(jù)加載、可視化和GPU加速相關的內(nèi)容糜烹,合理地使用這些模塊能極大地提升我們的編程效率违诗。