import torch
from torchvision import datasets, transforms
加載圖像的最簡(jiǎn)單方式是使用 torchvision
的 datasets.ImageFolder
(文檔)壤靶。使用 ImageFolder
丁逝,就是這樣:
dataset = datasets.ImageFolder('path/to/data', transform=transform)
其中 'path/to/data'
是通往數(shù)據(jù)目錄的文件路徑羽资,transform
是用 torchvision
中的 transforms
模塊構(gòu)建的處理步驟列表淘菩。ImageFolder 中的文件和目錄應(yīng)按以下格式構(gòu)建:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
每個(gè)圖像類別都有各自存儲(chǔ)圖像的目錄(cat
和 dog
)。然后使用從目錄名中提取的類別標(biāo)記圖像屠升。圖像 123.png
將采用類別標(biāo)簽 cat
潮改。你可以從此頁(yè)面下載已經(jīng)采用此結(jié)構(gòu)的數(shù)據(jù)集。在其中已被拆分成了訓(xùn)練集和測(cè)試集腹暖。
一进陡、 轉(zhuǎn)換
使用 ImageFolder
加載數(shù)據(jù)時(shí),你需要定義轉(zhuǎn)換微服。例如趾疚,圖像的尺寸不相同,但是我們需要將它們變成統(tǒng)一尺寸以蕴,才能用于訓(xùn)練模型柏锄。你可以使用 transforms.Resize()
調(diào)整尺寸或使用 transforms.CenterCrop()
施逾、transforms.RandomResizedCrop()
等裁剪圖像。我們還需要使用 transforms.ToTensor()
將圖像轉(zhuǎn)換為 PyTorch 張量。通常陌知,你將使用 transforms.Compose()
來(lái)將這些轉(zhuǎn)換結(jié)合到一條流水線中,這條流水線接收包含轉(zhuǎn)換的列表梗醇,并按順序運(yùn)行捉捅。流程大概為縮放、裁剪习劫,然后轉(zhuǎn)換為張量:
transform = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor()])
二咆瘟、 數(shù)據(jù)加載器
加載 ImageFolder
后,你需要將其傳入 DataLoader
诽里。DataLoader
接受數(shù)據(jù)集(例如要從 ImageFolder
獲得的數(shù)據(jù)集)袒餐,并返回批次圖像和相應(yīng)的標(biāo)簽。你可以設(shè)置各種參數(shù),例如批次大小灸眼,或者在每個(gè)周期之后是否重排數(shù)據(jù)卧檐。
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
dataloader
是一個(gè)生成器。要從中獲取數(shù)據(jù)焰宣,你需要遍歷它霉囚,或?qū)⑺D(zhuǎn)換成迭代器并調(diào)用 next()
。
# Looping through it, get a batch on each loop
for images, labels in dataloader:
pass
# Get one batch
images, labels = next(iter(dataloader))
三匕积、 數(shù)據(jù)增強(qiáng)
訓(xùn)練神經(jīng)網(wǎng)絡(luò)的一個(gè)常見策略是在輸入數(shù)據(jù)本身里引入隨機(jī)性佛嬉。例如,你可以在訓(xùn)練過(guò)程中隨機(jī)地旋轉(zhuǎn)闸天、翻轉(zhuǎn)暖呕、縮放和/或裁剪圖像。這樣一來(lái)苞氮,你的神經(jīng)網(wǎng)絡(luò)在處理位置湾揽、大小、方向不同的相同圖像時(shí)笼吟,可以更好地進(jìn)行泛化库物。
要隨機(jī)旋轉(zhuǎn)、縮放贷帮、裁剪圖像戚揭,然后翻轉(zhuǎn)圖像,你需要如下所示地定義轉(zhuǎn)換:
train_transforms = transforms.Compose([transforms.RandomRotation(30),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5],
[0.5, 0.5, 0.5])])
另外撵枢,還需要使用 transforms.Normalize 標(biāo)準(zhǔn)化圖像民晒。傳入均值和標(biāo)準(zhǔn)偏差列表,然后標(biāo)準(zhǔn)化顏色通道锄禽。
減去 mean
使數(shù)據(jù)以 0 居中潜必,除以 std
使值位于 -1 到 1 之間。標(biāo)準(zhǔn)化有助于神經(jīng)網(wǎng)絡(luò)使權(quán)重接近 0沃但,這能使反向傳播更為穩(wěn)定磁滚。不標(biāo)準(zhǔn)化的話,網(wǎng)絡(luò)往往會(huì)學(xué)習(xí)失敗宵晚。
你可以在此處查看可用的轉(zhuǎn)換列表垂攘。測(cè)試時(shí),不能改變圖像(但是需要以同一方式標(biāo)準(zhǔn)化)淤刃。因此晒他,在驗(yàn)證/測(cè)試圖像時(shí),通常只能調(diào)整大小和裁剪圖像钝凶。
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
data_dir = 'Cat_Dog_data'
# TODO: Define transforms for the training data and testing data
train_transforms = transforms.Compose([transforms.RandomRotation(30),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5],
[0.5, 0.5, 0.5])])
# test_transforms = transforms.Compose([transforms.RandomRotation(30),
# transforms.RandomResizedCrop(224),
# transforms.RandomHorizontalFlip(),
# transforms.ToTensor(),
# transforms.Normalize([0.5, 0.5, 0.5],
# [0.5, 0.5, 0.5])])
#測(cè)試時(shí)仪芒,不能改變圖像(但是需要以同一方式標(biāo)準(zhǔn)化)唁影。因此耕陷,在驗(yàn)證/測(cè)試圖像時(shí)掂名,通常只能調(diào)整大小和裁#剪圖像。
test_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor()])
# Pass transforms in here, then run the next cell to see how the transforms look
train_data = datasets.ImageFolder(data_dir + '/train', transform=train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test', transform=test_transforms)
trainloader = torch.utils.data.DataLoader(train_data, batch_size=32)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)
# change this to the trainloader or testloader
data_iter = iter(testloader)
images, labels = next(data_iter)
fig, axes = plt.subplots(figsize=(10,4), ncols=4)
for ii in range(4):
ax = axes[ii]
helper.imshow(images[ii], ax=ax, normalize=False)