Pytorch通過torch.utils.data對(duì)一般常用數(shù)據(jù)加載進(jìn)行封裝苦始,可以容易的實(shí)現(xiàn)多線程數(shù)據(jù)預(yù)讀和批量加載恕出,并且torchvision已經(jīng)預(yù)先實(shí)現(xiàn)了常用圖像數(shù)據(jù)集合粒竖。
from torch.utils.data import Dataset
import pandas as pd
#定義一個(gè)數(shù)據(jù)集
class BulldozerDataset(Dataset):
def __init__(self, csv_file):
self.df=pd.read_csv(csv_file)
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
return self.df.iloc[idx].SalePrice
DataLoader主要提供DataSet的讀取操作宏赘,常用參數(shù)有
batch_size, shuffle, num_workers(加載數(shù)據(jù)時(shí)使用幾個(gè)子進(jìn)程)
ds_demo= BulldozerDataset('median_benchmark.csv')
dl = torch.utils.data.DataLoader(ds_demo, batch_size=32, shuffle=True, num_workers=0)
#DataLoader返回的是一個(gè)可迭代對(duì)象蒋荚,我們可以使用迭代器分次獲取數(shù)據(jù)
idata = iter(dl)
print(next(idata))
#使用for循環(huán)對(duì)其遍歷
for i, data in enumerate(dl):
print(i, data)
torchvision包是pytorch專門用來處理圖像的庫疾棵,torchvision.dataset提供了很多圖片數(shù)據(jù)集戈钢,
torchvision.models提供了常用模型結(jié)構(gòu)
torchvision.transforms提供了一般的圖像轉(zhuǎn)換操作類,用作數(shù)據(jù)處理和數(shù)據(jù)增強(qiáng)
from torchvision import transforms as transforms
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4), #先四周填充0是尔,在把圖像隨機(jī)裁剪成32*32
transforms.RandomHorizontalFlip(), #圖像一半的概率翻轉(zhuǎn)殉了,一半的概率不翻轉(zhuǎn)
transforms.RandomRotation((-45,45)), #隨機(jī)旋轉(zhuǎn)
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)), #R,G,B每層的歸一化用到的均值和方差
])