為什么要用筛谚?
習慣于自己實現(xiàn)業(yè)務邏輯的每一步碧磅,以至于沒有意識去尋找框架本身自有的數(shù)據(jù)預處理方法俊犯,Pytorch的Dataset 和 DataLoader便于加載和迭代處理數(shù)據(jù)闪彼,并且可以傻瓜式實現(xiàn)各種常見的數(shù)據(jù)預處理敬拓,以供訓練使用啼辣。
調(diào)包俠
from torch.utils.data.dataset import Dataset, DataLoader
from torchvision import transforms ##可方便指定各種transformer啊研,直接傳入DataLoader
Dataset 和 DataLoader是什么?
Dataset是一個包裝類,可對數(shù)據(jù)進行張量(tensor)的封裝党远,其可作為DataLoader的參數(shù)傳入削解,進一步實現(xiàn)基于tensor的數(shù)據(jù)預處理。
如何處理自己的數(shù)據(jù)集沟娱?
很多教程里分兩種情況:數(shù)據(jù)同在一個文件夾氛驮;數(shù)據(jù)按類別分布在不同文件夾。其實剛開始我是一頭霧水济似,后來總結(jié)后發(fā)現(xiàn)矫废,兩種情況均可用一種方法來處理,即:只要有一份文件砰蠢,記錄圖像數(shù)據(jù)路徑及對應的標簽即可磷脯,如下所示:
record.txt 示例:
pic_path label
./pic_01/aaa.bmp 1
./pic_22/bbb.bmp 0
./pic_03/ccc.bmp 3
./pic_01/ddd.bmp 1
...
其實有了上面的一份數(shù)據(jù)對照表文件,即可不用管是否在同一文件夾或是不同文件夾的情況娩脾,我自己感覺是要方便一些赵誓。下面就按照這種方法來介紹如何使用。
第一步:實現(xiàn)MyDataset類
既然是要處理自己的數(shù)據(jù)集柿赊,那么一般情況下還是寫一個自己的Dataset類俩功,該類要繼承Dataset,并重寫 __ init __() 和 __ getitem __() 兩個方法碰声。
例如:
class MyDataset(Dataset):
def __init__(self, record_path, is_train=True):
## record_path:記錄圖片路徑及對應label的文件
self.data = []
self.is_train = is_train
with open(record_path) as fp:
for line in fp.readlines():
if line == '\n':
break
else:
tmp = line.split("\t")
## tmp[0]:某圖片的路徑诡蜓,tmp[1]:該圖片對應的label
self.data.append([tmp[0], tmp[1]])
# 定義transform,將數(shù)據(jù)封裝為Tensor
self.transformations = transforms.Compose([transforms.ToTensor()])
# 獲取單條數(shù)據(jù)
def __getitem__(self, index):
img = self.transformations (Image.open(self.data[index][0]).resize((256,256)).convert('RGB'))
label = int(self.data[index][1])
return img, label
# 數(shù)據(jù)集長度
def __len__(self):
return len(self.data)
上面是一個簡單的MyDataset類胰挑,僅依賴記錄了圖像位置以及相應label的record文件蔓罚,實現(xiàn)對數(shù)據(jù)集的讀取和Tensor的轉(zhuǎn)換
當然,根據(jù)個人對數(shù)據(jù)預處理的需求不同瞻颂,該類的實現(xiàn)可進一步完善豺谈,例如:
class MyDataset(Dataset):
def __init__(self, base_path, is_train=True):
self.data = []
self.is_train = is_train
with open(base_path) as fp:
for line in fp.readlines():
if line == '\n':
break
else:
tmp = line.split("\t")
self.data.append([tmp[0], tmp[1]])
## transforms.Normalize:對R G B三通道數(shù)據(jù)做均值方差歸一化,因此給出下方三個均值和方差
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]
normTransform = transforms.Normalize(normMean, normStd)
## 可由 transforms.Compose([transformer_01, transformer_02, ...])實現(xiàn)一些數(shù)據(jù)的處理和增強
self.trainTransform = transforms.Compose([ ## train訓練集處理
transforms.RandomCrop(32, padding=4), ## 圖像裁剪的transforms
transforms.RandomHorizontalFlip(p=0.5), ## 以50%概率水平翻轉(zhuǎn)
transforms.ToTensor(), ## 轉(zhuǎn)為Tensor形式
normTransform ## 進行 R G B數(shù)據(jù)歸一化
])
## 測試集的transforms數(shù)據(jù)處理
self.testTransform = transforms.Compose([
transforms.ToTensor(),
normTransform
])
# 獲取單條數(shù)據(jù)
def __getitem__(self, index):
img = self.trainTransform(Image.open(self.data[index][0]).resize((256,256)).convert('RGB'))
if not self.is_train:
img = self.testTransform(Image.open(self.data[index][0]).resize((256, 256)).convert('RGB'))
label = int(self.data[index][1])
return img, label
# 數(shù)據(jù)集長度
def __len__(self):
return len(self.data)
或許已經(jīng)看出來了贡这,所有可能的數(shù)據(jù)處理或數(shù)據(jù)增強操作茬末,都可通過transforms來進行調(diào)用與封裝,是不是一下變得很方便呢盖矫!
第二步:將MyDataset裝入DataLoader中
MyDataset類中的init方法要求傳入記錄數(shù)據(jù)路徑及l(fā)abel的文件丽惭,因此可如下所示進行操作:
import MyDataset
train_data = MyDataset.MyDataset("./train_record.txt")
test_data = myDataset.MyDataset("./test_record.txt")
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
trainLoader = DataLoader(dataset=train_data,batch_size=64,shuffle=True,**kwargs)
testLoader = DataLoader(dataset=test_data,batch_size=64,shuffle=False, **kwargs)
這樣,便生成了trainLoader 和testLoader
第三步:在訓練中使用DataLoader
for epoch in range(1, args.nEpochs + 1):
## 定義好的train方法
train(args, epoch, model, trainLoader, optimizer)
## 定義好的val方法辈双,用于測試或驗證
val(args, epoch, model, testLoader, optimizer)
最后
以上便是使用 Dataset和DataLoader處理自己數(shù)據(jù)集的通用方法责掏,當然本次僅記錄了圖片數(shù)據(jù)的使用方法,后續(xù)記錄文本數(shù)據(jù)處理方法湃望。
彩蛋
ooh~~ 那么對于Pytorch自帶數(shù)據(jù)集如果處理呢换衬?
若直接使用 CIFAR10
數(shù)據(jù)集局义,可以如下處理:
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normTransform
])
testTransform = transforms.Compose([
transforms.ToTensor(),
normTransform
])
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
trainLoader = DataLoader(dset.CIFAR10(root='cifar', train=True, download=True,
transform=trainTransform),batch_size=64, shuffle=True, **kwargs)
testLoader = DataLoader(dset.CIFAR10(root='cifar', train=False, download=True,
transform=testTransform),batch_size=64, shuffle=False, **kwargs)
其實也就是 torchvision.datasets
將這些共用數(shù)據(jù)集本身就做了 Dataset類的封裝,因此直接調(diào)用冗疮,傳入你想要的transforms萄唇,再丟給DataLoader即可。
轉(zhuǎn)載注明出處:http://www.reibang.com/p/b558c538eac2