## Preparation
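This post is based on the official PyTorch tutorial (the first reference at the end). I have left out the unnecessary and cluttered matplotlib display code and kept only the most practical part, the data loading.
First, download the sample dataset from here and extract it (the code below expects it under `data/faces/`). The rest of the post shows how to create a dataloader that loads the dataset in that folder.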
```python
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import os
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
```
## Dataset class
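`torch.utils.data.Dataset` is an abstract class representing a dataset. A custom dataset should inherit from it and override two methods:

- `__len__`, so that `len(dataset)` returns the size of the dataset;
- `__getitem__`, to support indexing so that `dataset[i]` returns the i-th sample.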
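Here is a toy sketch in plain Python that makes the two methods concrete. `SquaresDataset` is a hypothetical name used only for illustration. It deliberately does not subclass `torch.utils.data.Dataset`, so it runs without PyTorch, but a real dataset would subclass it as described above.

```python
class SquaresDataset:
    """Hypothetical toy dataset whose i-th sample is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # len(dataset) -> number of samples
        return self.n

    def __getitem__(self, i):
        # dataset[i] -> the i-th sample; out-of-range indices raise IndexError
        if not 0 <= i < self.n:
            raise IndexError(i)
        return {'x': i, 'y': i * i}


dataset = SquaresDataset(5)
print(len(dataset))   # 5
print(dataset[3])     # {'x': 3, 'y': 9}
```

`__getitem__` raises `IndexError` past the end, so a plain `for sample in dataset` loop also works through Python's legacy sequence-iteration protocol.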
Below we create a dataset class for our face dataset. The CSV file is read in `__init__`, but reading the images is left to `__getitem__`. This saves memory: each image is read only when it is needed instead of loading all of them into memory at once.
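A small experiment shows the memory argument in action. The sketch below is plain Python under stated assumptions. `LazyTextDataset` and its `reads` counter are hypothetical, and small text files stand in for images. The constructor only records file paths, just as `__init__` below only reads the CSV. A file is opened only when its sample is requested.

```python
import os
import tempfile


class LazyTextDataset:
    """Hypothetical dataset that stores file paths and reads a file only in __getitem__."""

    def __init__(self, paths):
        self.paths = list(paths)  # cheap: only the file names are kept
        self.reads = 0            # counts how many files have actually been opened

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        self.reads += 1
        with open(self.paths[idx]) as f:  # the expensive load happens here, on demand
            return f.read()


# Create three small files that stand in for images.
tmp = tempfile.mkdtemp()
paths = []
for i in range(3):
    p = os.path.join(tmp, f'sample_{i}.txt')
    with open(p, 'w') as f:
        f.write(f'content {i}')
    paths.append(p)

ds = LazyTextDataset(paths)
print(ds.reads)  # 0: nothing has been loaded yet
print(ds[1])     # content 1
print(ds.reads)  # 1
```

Loaded samples are not cached in this sketch, so memory use does not grow with the number of files on disk.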
```python
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample


face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/')
```
Samples can now be read from `face_dataset`. For example, `face_dataset[0]` returns a dict with the keys `'image'` and `'landmarks'`.
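The three lines in `__getitem__` that build `landmarks` turn one CSV row into an `(N, 2)` array of (x, y) points. The sketch below replays them on a tiny in-memory CSV. The file name and column names are made up for illustration. They only mimic the layout the code expects: the image name in the first column, followed by x/y pairs.

```python
from io import StringIO  # stdlib; avoids shadowing skimage's `io` used above

import numpy as np
import pandas as pd

# Hypothetical two-landmark CSV: image name first, then x0, y0, x1, y1.
csv_text = """image_name,part_0_x,part_0_y,part_1_x,part_1_y
person.jpg,10,20,30,40
"""
frame = pd.read_csv(StringIO(csv_text))

idx = 0
img_name = frame.iloc[idx, 0]          # first column: the image file name
landmarks = frame.iloc[idx, 1:]        # pandas Series with the flat coordinates
landmarks = np.array([landmarks])      # shape (1, 4)
landmarks = landmarks.astype('float').reshape(-1, 2)  # shape (2, 2): one (x, y) per row

print(img_name)   # person.jpg
print(landmarks)
```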
## Transforms
The images in the folder are not all the same size, but most neural networks expect inputs of a fixed size. We therefore need some preprocessing, such as rescaling, random cropping, and conversion to tensors.
We write these transforms as callable classes rather than plain functions, so their parameters don't have to be passed in every time the transform is called. A class only needs to implement `__call__` for this, plus `__init__` if it takes parameters.
A transform is then used like this:
```python
tsfm = Transform(params)  # Transform / params: any transform class and its arguments
transformed_sample = tsfm(sample)
```
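A callable object captures its parameters once in `__init__`, so the call site only has to pass the sample. The sketch below uses a hypothetical `AddOffset` transform on a dict sample shaped like the ones above, with plain lists standing in for arrays. It sits next to the equivalent function, which needs the parameter at every call.

```python
class AddOffset:
    """Hypothetical transform: shift every landmark coordinate by a fixed offset."""

    def __init__(self, offset):
        self.offset = offset  # the parameter is stored once

    def __call__(self, sample):
        shifted = [[x + self.offset, y + self.offset] for x, y in sample['landmarks']]
        return {'image': sample['image'], 'landmarks': shifted}


def add_offset(sample, offset):
    # function version: the offset must be supplied on every call
    shifted = [[x + offset, y + offset] for x, y in sample['landmarks']]
    return {'image': sample['image'], 'landmarks': shifted}


sample = {'image': 'img', 'landmarks': [[1, 2], [3, 4]]}
tsfm = AddOffset(10)
print(tsfm(sample)['landmarks'])            # [[11, 12], [13, 14]]
print(add_offset(sample, 10)['landmarks'])  # [[11, 12], [13, 14]]
```

`tsfm` is just a one-argument callable, so it can be handed to anything that expects a function of the sample. That is how the dataset's `transform=` argument uses it.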
The three transforms are defined as follows:
```python
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size.
            If tuple, output is matched to output_size.
            If int, smaller of image edges is matched to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}
```
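Here is a worked example of the size arithmetic in `Rescale`, using a made-up 500x300 (h x w) image and `output_size=256`. Since h > w, the shorter edge w becomes 256, and h becomes 256 * 500 / 300 ≈ 426.67, which is truncated to 426. Landmarks are scaled by (new_w / w, new_h / h) because their first column is x (width) and their second is y (height). `rescaled_size` is a hypothetical extraction of that logic that needs only NumPy, not skimage.

```python
import numpy as np


def rescaled_size(h, w, output_size):
    """Hypothetical helper reproducing Rescale's choice of (new_h, new_w)."""
    if isinstance(output_size, int):
        # match the shorter edge to output_size and keep the aspect ratio
        if h > w:
            new_h, new_w = output_size * h / w, output_size
        else:
            new_h, new_w = output_size, output_size * w / h
    else:
        new_h, new_w = output_size
    return int(new_h), int(new_w)


h, w = 500, 300
new_h, new_w = rescaled_size(h, w, 256)
print(new_h, new_w)  # 426 256

# landmarks are (x, y): x scales with the width, y with the height
landmarks = np.array([[150.0, 250.0]])  # centre of the original image
scaled = landmarks * [new_w / w, new_h / h]
print(scaled)  # approximately the centre of the resized image, (128, 213)
```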
```python
class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # randint excludes its upper bound: the +1 lets the crop reach the
        # bottom/right edge and avoids an error when the image equals the crop size
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h, left: left + new_w]

        # shift landmarks into the cropped image's coordinate frame
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}
```
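`np.random.randint(low, high)` draws from the half-open range [low, high). Valid top offsets for a crop run from 0 to h - new_h inclusive, so the upper bound has to be `h - new_h + 1`. With `h - new_h`, the last offset is never chosen. When the image is exactly as tall as the crop, `randint(0, 0)` is an empty range and NumPy raises `ValueError`. The snippet below checks both points with NumPy alone. The sizes are arbitrary.

```python
import numpy as np

np.random.seed(0)

h, new_h = 5, 3  # arbitrary sizes: top offsets 0, 1 and 2 keep the crop inside the image
tops = {int(np.random.randint(0, h - new_h + 1)) for _ in range(1000)}
print(sorted(tops))  # [0, 1, 2]

# Image exactly as tall as the crop: the only valid offset is 0,
# but randint(0, 0) has no values to draw from and raises.
try:
    np.random.randint(0, 0)
    raised = False
except ValueError:
    raised = True
print(raised)  # True
```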
```python
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image), 'landmarks': torch.from_numpy(landmarks)}
```
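Here is what `transpose((2, 0, 1))` in `ToTensor` does, shown with NumPy only on an arbitrary 4x6 RGB array. Axis 2 (channels) moves to the front, followed by height and width. Pixel values are unchanged, just indexed as [c, y, x] instead of [y, x, c].

```python
import numpy as np

# arbitrary 4 (H) x 6 (W) RGB image with distinct values
image = np.arange(4 * 6 * 3).reshape(4, 6, 3)
chw = image.transpose((2, 0, 1))

print(image.shape)  # (4, 6, 3)  H x W x C
print(chw.shape)    # (3, 4, 6)  C x H x W

# the same pixel indexed both ways: row 1, column 2, green channel
assert image[1, 2, 1] == chw[1, 1, 2]
```

This custom `ToTensor` only reorders axes. Unlike `torchvision.transforms.ToTensor`, it does not rescale uint8 values to [0, 1]. In this pipeline that is not an issue, because `skimage.transform.resize` in `Rescale` already returns floats in [0, 1].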
## Composing transforms
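To apply several transforms in sequence, we chain these classes together with `transforms.Compose`. For example, the code below rescales the shorter side of the image to 256 and then randomly crops a 224x224 square from it: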
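```python
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    # draw the transformed image with its landmarks on top
    ax.imshow(transformed_sample['image'])
    lm = transformed_sample['landmarks']
    ax.scatter(lm[:, 0], lm[:, 1], s=10, marker='.', c='r')
    ax.axis('off')

plt.show()
```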
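`transforms.Compose` calls its transforms one after another, feeding each output into the next. The minimal re-implementation below makes the order explicit. The name `MiniCompose` is hypothetical, and this is a sketch of the idea, not torchvision's source. Plain functions that record their names stand in for `Rescale` and `RandomCrop`.

```python
class MiniCompose:
    """Hypothetical stand-in for transforms.Compose: apply callables in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, sample):
        for t in self.transforms:
            sample = t(sample)  # the output of one transform is the input of the next
        return sample


def fake_rescale(sample):
    return {**sample, 'log': sample['log'] + ['rescale']}


def fake_crop(sample):
    return {**sample, 'log': sample['log'] + ['crop']}


composed = MiniCompose([fake_rescale, fake_crop])
print(composed({'log': []}))  # {'log': ['rescale', 'crop']}
```

Order matters. In the real pipeline, `Rescale(256)` has to come before `RandomCrop(224)`, because a 224-pixel crop needs an image at least that large.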
## Iterating over the dataset
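Now let's put these transforms together and build a dataset with them. Each time a sample is requested, its image is read from file on the fly and the transforms are applied to it. Because `RandomCrop` is random, the data is augmented every time it is sampled.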
```python
transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))
```
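The transform runs inside `__getitem__`, so reading the same index twice gives two differently augmented samples. The NumPy-only sketch below shows this with hypothetical names (`AugmentedDataset`, `random_shift`). The random shift plays the role of `RandomCrop`.

```python
import numpy as np


def random_shift(sample):
    # hypothetical random transform: shift landmarks by a random integer offset
    offset = np.random.randint(0, 100, size=2)
    return {'landmarks': sample['landmarks'] - offset}


class AugmentedDataset:
    """Hypothetical dataset that applies its transform on every access."""

    def __init__(self, landmarks_list, transform=None):
        self.landmarks_list = landmarks_list
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_list)

    def __getitem__(self, idx):
        sample = {'landmarks': self.landmarks_list[idx]}
        if self.transform:
            sample = self.transform(sample)  # applied at read time, not once up front
        return sample


np.random.seed(0)
ds = AugmentedDataset([np.array([[200.0, 150.0]])], transform=random_shift)
first = ds[0]['landmarks']
second = ds[0]['landmarks']
print(first, second)  # same index, different augmented values
```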
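However, reading samples one at a time this way misses several features we usually want: batching the data, shuffling the data, and loading the data in parallel with multiple worker processes. `DataLoader` provides all of these: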
```python
dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=4)
```
We can then iterate over `dataloader` to get the data one batch at a time.
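Below is a sketch of that loop, assuming PyTorch is installed. It runs without the faces data because `ToyDataset` (hypothetical) returns dict samples of fixed-size tensors shaped like the pipeline's output (`image` is C x H x W, `landmarks` is N x 2), with small made-up sizes. The default collate function stacks each dict field along a new first (batch) dimension. `num_workers=0` loads in the main process, which keeps a quick test simple.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for transformed_dataset: 10 samples, 3x8x8 images, 5 landmarks."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return {'image': torch.full((3, 8, 8), float(idx)),
                'landmarks': torch.zeros(5, 2)}


dataloader = DataLoader(ToyDataset(), batch_size=4, shuffle=True, num_workers=0)

batch_sizes = []
for i_batch, sample_batched in enumerate(dataloader):
    # every field gains a leading batch dimension
    print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size())
    batch_sizes.append(sample_batched['image'].size(0))
```

Ten samples with `batch_size=4` give batches of 4, 4 and 2, because the last partial batch is kept by default. With the real `transformed_dataset`, the `image` batches would be 4 x 3 x 224 x 224, and the final batch may be smaller.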
There is still a bit of the tutorial I haven't finished reading.
## Afterword: torchvision
From the official tutorial: [Afterword: torchvision](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#afterword-torchvision)
In this tutorial, we have seen how to write and use datasets, transforms and dataloader. The `torchvision` package provides some common datasets and transforms, so you might not even have to write custom classes. One of the more generic datasets available in torchvision is `ImageFolder`. It assumes that images are organized in the following way:
<pre style="box-sizing: border-box; font-family: IBMPlexMono, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; font-size: 14px; margin-top: 0px; margin-bottom: 2.5rem; overflow: auto; display: block; color: rgb(33, 37, 41); padding: 1.375rem; background-color: rgb(243, 244, 247); white-space: pre-wrap; overflow-wrap: break-word;">root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
</pre>
where 'ants', 'bees' etc. are class labels. Similarly, generic transforms that operate on `PIL.Image`, such as `RandomHorizontalFlip` and `Resize` (called `Scale` in older torchvision releases), are also available. You can use these to write a dataloader like this:
<pre style="box-sizing: border-box; font-family: IBMPlexMono, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; font-size: 14px; margin-top: 0px; margin-bottom: 2.5rem; overflow: auto; display: block; color: rgb(33, 37, 41); padding: 1.375rem; background-color: rgb(243, 244, 247); white-space: pre-wrap; overflow-wrap: break-word;">import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)</pre>
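`ImageFolder` derives each image's label from the name of its subdirectory. The class names are the sorted subdirectory names, and each gets an integer index in that order (exposed as `dataset.class_to_idx`). The sketch below reproduces that convention with the standard library only. `folder_class_to_idx` and `labelled_files` are hypothetical helpers, not torchvision's code, and the empty files only mimic the `root/ants`, `root/bees` layout.

```python
import os
import tempfile


def folder_class_to_idx(root):
    """Hypothetical helper: map each subdirectory name of root to an index, in sorted order."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    return {name: i for i, name in enumerate(classes)}


def labelled_files(root):
    """Hypothetical helper: list (relative path, label) pairs following the folder convention."""
    class_to_idx = folder_class_to_idx(root)
    pairs = []
    for name, idx in class_to_idx.items():
        for fname in sorted(os.listdir(os.path.join(root, name))):
            pairs.append((os.path.join(name, fname), idx))
    return pairs


# Build a tiny fake tree: root/bees/123.jpg, root/ants/xxx.png, root/ants/xxy.jpeg
root = tempfile.mkdtemp()
for rel in ['bees/123.jpg', 'ants/xxx.png', 'ants/xxy.jpeg']:
    path = os.path.join(root, rel)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    open(path, 'w').close()

print(folder_class_to_idx(root))  # {'ants': 0, 'bees': 1}
for rel, label in labelled_files(root):
    print(rel, label)
```

Unlike this sketch, the real `ImageFolder` also filters files by image extension and loads each one as a PIL image before applying `transform`.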