A small demo showing how to use SubsetRandomSampler:
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
dataset = TensorDataset(torch.tensor(list(range(20))))  # build a toy dataset containing the integers 0..19
idx = list(range(len(dataset)))  # indices into the dataset; SubsetRandomSampler visits them in random order
# idx = torch.zeros(len(dataset)).long()  # if all indices are identical, SubsetRandomSampler keeps sampling the same element
n = len(dataset)
split = n // 5  # 20% of the samples
train_sampler = SubsetRandomSampler(idx[split:])  # the last 80% of the indices, sampled in random order, for training
test_sampler = SubsetRandomSampler(idx[:split])   # the first 20% of the indices, sampled in random order, for testing
train_loader = DataLoader(dataset, sampler=train_sampler)
test_loader = DataLoader(dataset, sampler=test_sampler)
print('data for training:')
for i in train_loader:
    print(i)
print('data for testing:')
for i in test_loader:
    print(i)
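Because idx is simply range(len(dataset)), the split above is deterministic: the first 20% of the dataset always lands in the test set, and SubsetRandomSampler only randomizes the order in which each subset is visited. A minimal sketch of a genuinely random split, reusing the variables defined above and using torch.randperm to shuffle the indices first (the rand_* names are just illustrative):

perm = torch.randperm(n).tolist()  # shuffle all indices once, so the split itself is random
rand_train_sampler = SubsetRandomSampler(perm[split:])  # a random 80% of the indices
rand_test_sampler = SubsetRandomSampler(perm[:split])   # the remaining random 20%
rand_train_loader = DataLoader(dataset, sampler=rand_train_sampler)
rand_test_loader = DataLoader(dataset, sampler=rand_test_sampler)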
Note that train_loader and test_loader share the same underlying dataset, so to get the number of samples a loader will actually yield, use len(loader.sampler) rather than len(loader.dataset):
print(len(train_loader.sampler))  # 16 training samples
print(len(test_loader.sampler))   # 4 test samples
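A related sketch (the batch size of 4 is an arbitrary choice): once batch_size is set, len(loader) counts batches, while len(loader.sampler) still counts individual samples.

batched_loader = DataLoader(dataset, sampler=train_sampler, batch_size=4)
print(len(batched_loader))          # 4 batches (16 training samples / batch size 4)
print(len(batched_loader.sampler))  # still 16 individual samples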
關(guān)于pytorch的其他sampler的文檔:
https://blog.csdn.net/aiwanghuan5017/article/details/102147825
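As a small sketch related to the samplers covered in that post (this relies only on documented DataLoader behaviour): when no sampler argument is given, DataLoader creates a SequentialSampler by default, and a RandomSampler when shuffle=True.

from torch.utils.data.sampler import SequentialSampler, RandomSampler
seq_loader = DataLoader(dataset)                 # shuffle defaults to False
shuf_loader = DataLoader(dataset, shuffle=True)  # shuffling delegates to RandomSampler
print(isinstance(seq_loader.sampler, SequentialSampler))  # True
print(isinstance(shuf_loader.sampler, RandomSampler))     # True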