torch.utils.data
模塊提供了一些用于數(shù)據(jù)加載和處理的工具逸雹,其中最常用的類和函數(shù)包括 Dataset, DataLoader, Sampler 以及相關(guān)的輔助工具哀墓。這些工具使得處理大型數(shù)據(jù)集以及在批處理、并行化和數(shù)據(jù)預(yù)處理等方面變得更加簡(jiǎn)便。
Dataset
Dataset
是一個(gè)抽象類繁涂,用戶可以通過(guò)繼承它來(lái)定義自己的數(shù)據(jù)集衡招。需要實(shí)現(xiàn)__len__
和__getitem__
方法。抽象類是一種不能被實(shí)例化的類阔馋,它通常作為其他類的基類玛荞,提供抽象方法的定義,而這些方法需要在具體的子類中實(shí)現(xiàn)呕寝。抽象類的主要作用是定義接口或提供框架勋眯,確保子類實(shí)現(xiàn)特定的方法,從而保證子類具有一致的接口和行為下梢。
示例代碼
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
# 創(chuàng)建數(shù)據(jù)集
dataset = MyDataset(data, labels)
dataset.data
dataset.labels
random_split
random_split
用于將數(shù)據(jù)集按比例隨機(jī)劃分成多個(gè)子集客蹋。示例代碼
from torch.utils.data import random_split
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
DataLoader
DataLoader
是用于將數(shù)據(jù)集分成小批量,并提供自動(dòng)化多線程數(shù)據(jù)加載的工具孽江。常用參數(shù)包括 batch_size, shuffle, num_workers 等讶坯。示例代碼
from torch.utils.data import DataLoader
dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
總結(jié)
先定義Dataset類創(chuàng)建數(shù)據(jù)集,然后random_split劃分?jǐn)?shù)據(jù)集岗屏,最后DataLoader常見(jiàn)train_loader/valid_loader/test_loader