自定義數(shù)據(jù)集
在訓(xùn)練深度學(xué)習(xí)模型之前凤瘦,樣本集的制作非常重要。在pytorch中拂封,提供了一些接口和類茬射,方便我們定義自己的數(shù)據(jù)集合,下面完整的試驗(yàn)自定義樣本集的整個(gè)流程烘苹。
開發(fā)環(huán)境
- Ubuntu 18.04
- pytorch 1.0
- pycharm
實(shí)驗(yàn)?zāi)康?/h1>
- 掌握pytorch中數(shù)據(jù)集相關(guān)的API接口和類
- 熟悉數(shù)據(jù)集制作的整個(gè)流程
實(shí)驗(yàn)過(guò)程
1.收集圖像樣本
以簡(jiǎn)單的貓狗二分類為例躲株,可以在網(wǎng)上下載一些貓狗圖片。創(chuàng)建以下目錄:
- data-------------根目錄
- data/test-------測(cè)試集
- data/train------訓(xùn)練集
- data/val--------驗(yàn)證集
在test/train/val之下在校分別創(chuàng)建2個(gè)文件夾镣衡,dog, cat
cat, dog文件夾下分別存放2類圖像:
標(biāo)簽
種類 | 標(biāo)簽 |
---|---|
cat | 0 |
dog | 1 |
之后寫一個(gè)簡(jiǎn)單的python腳本霜定,生成txt文件档悠,用于指明每個(gè)圖像和標(biāo)簽的對(duì)應(yīng)關(guān)系。
格式: /cat/1.jpg 0 \n dog/1.jpg 1 \n .....
如圖:
至此望浩,樣本集的收集以及簡(jiǎn)單歸類完成辖所,下面將開始采用pytorch的數(shù)據(jù)集相關(guān)API和類。
2. 使用pytorch相關(guān)類磨德,API對(duì)數(shù)據(jù)集進(jìn)行封裝
2.1 pytorch中數(shù)據(jù)集相關(guān)的類缘回,接口
pytorch中數(shù)據(jù)集相關(guān)的類位于torch.utils.data
package中。
https://pytorch.org/docs/stable/data.html
本次實(shí)驗(yàn)典挑,主要使用以下類:
torch.utils.data.Dataset
torch.utils.data.DataLoader
Dataset
類的使用: 所有的類都應(yīng)該是此類的子類(也就是說(shuō)應(yīng)該繼承該類)酥宴。 所有的子類都要重寫(override) __len()__
, __getitem()__
這兩個(gè)方法。
方法 | 作用 |
---|---|
__len()__ |
此方法應(yīng)該提供數(shù)據(jù)集的大小(容量) |
__getitem()__ |
此方法應(yīng)該提供支持下標(biāo)索方式引訪問(wèn)數(shù)據(jù)集 |
這里和Java抽象類很相似您觉,在抽象類abstract class
中拙寡,一般會(huì)定義一些抽象方法abstract method
,抽象方法:只有方法名沒(méi)有方法的具體實(shí)現(xiàn)。如果一個(gè)子類繼承于該抽象類琳水,要重寫(overrode)父類的抽象方法肆糕。
DataLoader
類的使用:
2.2 實(shí)現(xiàn)
- 使用到的python package
python package | 目的 |
---|---|
numpy |
矩陣操作,對(duì)圖像進(jìn)行轉(zhuǎn)置 |
skimage |
圖像處理在孝,圖像I/O,圖像變換 |
matplotlib |
圖像的顯示诚啃,可視化 |
os |
一些文件查找操作 |
torch |
pytorch |
torvision |
pytorch |
- 源碼
導(dǎo)入python包
import numpy as np
from skimage import io
from skimage import transform
import matplotlib.pyplot as plt
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid
第一步:
定義一個(gè)子類,繼承Dataset
類私沮, 重寫 __len()__
, __getitem()__
方法始赎。
細(xì)節(jié):
1.數(shù)據(jù)集中一個(gè)一樣的表示:采用字典的形式sample = {'image': image, 'label': label}
。
圖像的讀茸醒唷:采用
skimage.io
進(jìn)行讀取极阅,讀取之后的結(jié)果為numpy.ndarray
形式。圖像變換:transform參數(shù)
# step1: 定義MyDataset類涨享, 繼承Dataset, 重寫抽象方法:__len()__, __getitem()__
class MyDataset(Dataset):
def __init__(self, root_dir, names_file, transform=None):
self.root_dir = root_dir
self.names_file = names_file
self.transform = transform
self.size = 0
self.names_list = []
if not os.path.isfile(self.names_file):
print(self.names_file + 'does not exist!')
file = open(self.names_file)
for f in file:
self.names_list.append(f)
self.size += 1
def __len__(self):
return self.size
def __getitem__(self, idx):
image_path = self.root_dir + self.names_list[idx].split(' ')[0]
if not os.path.isfile(image_path):
print(image_path + 'does not exist!')
return None
image = io.imread(image_path) # use skitimage
label = int(self.names_list[idx].split(' ')[1])
sample = {'image': image, 'label': label}
if self.transform:
sample = self.transform(sample)
return sample
第二步
實(shí)例化一個(gè)對(duì)象,并讀取和顯示數(shù)據(jù)集
train_dataset = MyDataset(root_dir='./data/train',
names_file='./data/train/train.txt',
transform=None)
plt.figure()
for (cnt,i) in enumerate(train_dataset):
image = i['image']
label = i['label']
ax = plt.subplot(4, 4, cnt+1)
ax.axis('off')
ax.imshow(image)
ax.set_title('label {}'.format(label))
plt.pause(0.001)
if cnt == 15:
break
只顯示了部分?jǐn)?shù)據(jù)仆百,前部分全是cat
第三步(可選 optional)
對(duì)數(shù)據(jù)集進(jìn)行變換:一般收集到的圖像大小尺寸厕隧,亮度等存在差異,變換的目的就是使得數(shù)據(jù)歸一化俄周。另一方面吁讨,可以通過(guò)變換進(jìn)行數(shù)據(jù)增加data argument
關(guān)于pytorch中的變換transforms,請(qǐng)參考該系列之前的文章
由于數(shù)據(jù)集中樣本采用字典dicts
形式表示。 因此不能直接調(diào)用torchvision.transofrms中的方法峦朗。
本實(shí)驗(yàn)只進(jìn)行尺寸歸一化Resize, 數(shù)據(jù)類型變換ToTensor操作建丧。
Resize
# # 變換Resize
class Resize(object):
def __init__(self, output_size: tuple):
self.output_size = output_size
def __call__(self, sample):
# 圖像
image = sample['image']
# 使用skitimage.transform對(duì)圖像進(jìn)行縮放
image_new = transform.resize(image, self.output_size)
return {'image': image_new, 'label': sample['label']}
ToTensor
# # 變換ToTensor
class ToTensor(object):
def __call__(self, sample):
image = sample['image']
image_new = np.transpose(image, (2, 0, 1))
return {'image': torch.from_numpy(image_new),
'label': sample['label']}
第四步: 對(duì)整個(gè)數(shù)據(jù)集應(yīng)用變換
細(xì)節(jié): transformers.Compose()
將不同的幾個(gè)組合起來(lái)。先進(jìn)行Resize, 再進(jìn)行ToTensor
# 對(duì)原始的訓(xùn)練數(shù)據(jù)集進(jìn)行變換
transformed_trainset = MyDataset(root_dir='./data/train',
names_file='./data/train/train.txt',
transform=transforms.Compose(
[Resize((224,224)),
ToTensor()]
))
第五步: 使用DataLoader進(jìn)行包裝
為何要使用DataLoader?
① 深度學(xué)習(xí)的輸入是mini_batch形式
② 樣本加載時(shí)候可能需要隨機(jī)打亂順序波势,shuffle操作
③ 樣本加載需要采用多線程
pytorch提供的DataLoader
封裝了上述的功能翎朱,這樣使用起來(lái)更方便橄维。
# 使用DataLoader可以利用多線程,batch,shuffle等
trainset_dataloader = DataLoader(dataset=transformed_trainset,
batch_size=4,
shuffle=True,
num_workers=4)
可視化:
def show_images_batch(sample_batched):
images_batch, labels_batch = \
sample_batched['image'], sample_batched['label']
grid = make_grid(images_batch)
plt.imshow(grid.numpy().transpose(1, 2, 0))
# sample_batch: Tensor , NxCxHxW
plt.figure()
for i_batch, sample_batch in enumerate(trainset_dataloader):
show_images_batch(sample_batch)
plt.axis('off')
plt.ioff()
plt.show()
plt.show()
通過(guò)DataLoader包裝之后拴曲,樣本以min_batch形式輸出争舞,而且進(jìn)行了隨機(jī)打亂順序。
至此澈灼,自定義數(shù)據(jù)集的完整流程已實(shí)現(xiàn)竞川,test, val集只需要改路徑即可。
補(bǔ)充
更簡(jiǎn)單的方法
上述繼承Dataset
, 重寫 __len()__
, __getitem()
是通用的方法叁熔,過(guò)程相對(duì)繁瑣委乌。對(duì)于簡(jiǎn)單的分類數(shù)據(jù)集,pytorch中提供了更簡(jiǎn)便的方式——ImageFolder
荣回。
如果每種類別的樣本放在各自的文件夾中遭贸,則可以直接使用ImageFolder
。
仍然以cat, dog 二分類數(shù)據(jù)集為例:
文件結(jié)構(gòu):
Code
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np
# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
# data_transform = transforms.Compose([
# transforms.RandomResizedCrop(224),
# transforms.RandomHorizontalFlip(),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225])
# ])
data_transform = transforms.Compose([
transforms.Resize((224,224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
train_dataset = datasets.ImageFolder(root='./data/train',transform=data_transform)
train_dataloader = DataLoader(dataset=train_dataset,
batch_size=4,
shuffle=True,
num_workers=4)
def show_batch_images(sample_batch):
labels_batch = sample_batch[1]
images_batch = sample_batch[0]
for i in range(4):
label_ = labels_batch[i].item()
image_ = np.transpose(images_batch[i], (1, 2, 0))
ax = plt.subplot(1, 4, i + 1)
ax.imshow(image_)
ax.set_title(str(label_))
ax.axis('off')
plt.pause(0.01)
plt.figure()
for i_batch, sample_batch in enumerate(train_dataloader):
show_batch_images(sample_batch)
plt.show()
由于 train 目錄下只有2個(gè)文件夾驹马,分別為cat, dog, 因此ImageFolder
安裝順序?qū)at使用標(biāo)簽0, dog使用標(biāo)簽1革砸。
End
參考:
https://pytorch.org/docs/stable/data.html
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html