Using DataLoader in CV
Import the required modules
import os
import torch
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets, transforms
transforms specifies how the images are preprocessed.
# image_size is the target crop size and is assumed to be defined elsewhere (e.g. 224)
data_transform = {
    'train': transforms.Compose([
        # transforms.RandomResizedCrop(image_size),
        # transforms.Resize(224),
        transforms.RandomResizedCrop(int(image_size * 1.2)),
        # transforms.ToPILImage(),
        transforms.RandomAffine(15),
        # transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.RandomGrayscale(),
        transforms.TenCrop(image_size),
        transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(lambda crops: torch.stack([transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop) for crop in crops])),
        # transforms.FiveCrop(image_size),
        # transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
        # transforms.ToTensor(),
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}
Use datasets.ImageFolder to load the image data
# rootpath is assumed to point to a directory containing train/val/test subfolders
image_datasets = {name: datasets.ImageFolder(os.path.join(rootpath, name), data_transform[name]) for name in ['train', 'val', 'test']}
Create the DataLoaders
dataloaders = {name: torch.utils.data.DataLoader(image_datasets[name], batch_size=batch_size, shuffle=True) for name in ['train', 'val']}
testDataloader = torch.utils.data.DataLoader(image_datasets['test'], batch_size=1, shuffle=False)
Usage: each iteration reads out one batch of batch_size samples.
for index, item in enumerate(dataloaders['train']):
    images, labels = item
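Note that because the train pipeline ends with TenCrop, each image arrives as a stack of 10 crops, so images has shape [batch_size, 10, C, H, W]. A common pattern is to fold the crop dimension into the batch dimension before the forward pass and average the predictions afterwards; a minimal sketch, assuming model is a classification network defined elsewhere:

for index, item in enumerate(dataloaders['train']):
    images, labels = item                       # images: [batch_size, ncrops, C, H, W]
    bs, ncrops, c, h, w = images.size()
    outputs = model(images.view(-1, c, h, w))   # fold crops into the batch dimension
    outputs = outputs.view(bs, ncrops, -1).mean(dim=1)  # average predictions over the 10 crops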
Using DataLoader in NLP
Parameters of torch.utils.data.DataLoader (a small construction sketch follows the list):
- dataset (Dataset) – dataset from which to load the data.
- batch_size (int, optional) – how many samples per batch to load (default: 1).
- shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
- sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, shuffle must be False.
- batch_sampler (Sampler, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
- num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
- collate_fn (callable, optional) – merges a list of samples to form a mini-batch.
- pin_memory (bool, optional) – if True, the data loader will copy tensors into CUDA pinned memory before returning them.
- drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
- timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
- worker_init_fn (callable, optional) – if not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
The two pieces you need to build yourself
The DataLoader works as follows: it first fetches individual samples through the Dataset class's __getitem__ method, groups them into a batch, and then applies the function specified by collate_fn to that batch, for example to do padding.
For NLP you mainly need to implement two things. The first is a dataset, which must inherit from torch.utils.data.Dataset and implement two methods: __len__, which returns the size of the whole dataset, and __getitem__, which returns a single item from the dataset.
import linecache

class Dataset(torch.utils.data.Dataset):
    def __init__(self, filepath=None, dataLen=None):
        self.file = filepath      # path to a tab-separated text file, one sample per line
        self.dataLen = dataLen    # number of lines (samples) in the file

    def __getitem__(self, index):
        # lazily read the (index+1)-th line and split it into its fields
        A, B, path, hop = linecache.getline(self.file, index + 1).split('\t')
        return A, B, path.split(' '), int(hop)

    def __len__(self):
        return self.dataLen
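A minimal usage sketch, assuming a hypothetical tab-separated file data.tsv with one sample per line (fields A, B, path, hop) and a known line count:

dataset = Dataset(filepath='data.tsv', dataLen=10000)  # hypothetical file and size
print(len(dataset))   # 10000, via __len__
print(dataset[0])     # one (A, B, path, hop) tuple, via __getitem__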
Because the dataloader takes a batch_size parameter, we can design how the data is collected by passing a custom collate_fn=myfunction: once batch_size samples have been drawn through the Dataset's __getitem__ method above, they are handed as one bundle to the function specified by collate_fn.
def myfunction(data):
    # data is a list of batch_size items, each the (A, B, path, hop) tuple returned by __getitem__
    A, B, path, hop = zip(*data)
    print('A:', A, ' B:', B, ' path:', path, ' hop:', hop)
    # raise Exception('utils collate_fun 147')  # debugging guard; comment out so the batch is returned
    return A, B, path, hop
for index, item in enumerate(dataloaders['train']):
    A, B, path, hop = item
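For reference, this is how myfunction is attached to the loader; dataset is the Dataset instance from above and the batch size is a placeholder:

nlp_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=myfunction)
for index, item in enumerate(nlp_loader):
    A, B, path, hop = item  # each element is a tuple of batch_size entries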
In NLP tasks, padding is usually done in the function specified by collate_fn, i.e. sentences of different lengths within the same batch are padded to the same length.
def myfunction(data):
    src, tgt, original_src, original_tgt = zip(*data)

    # pad the source sequences to the length of the longest one in the batch
    src_len = [len(s) for s in src]
    src_pad = torch.zeros(len(src), max(src_len)).long()
    for i, s in enumerate(src):
        end = src_len[i]
        src_pad[i, :end] = torch.LongTensor(s[end - 1::-1])  # the source is written in reversed order

    # pad the target sequences the same way (without reversing)
    tgt_len = [len(s) for s in tgt]
    tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long()
    for i, s in enumerate(tgt):
        end = tgt_len[i]
        tgt_pad[i, :end] = torch.LongTensor(s)[:end]

    return src_pad, tgt_pad, \
        torch.LongTensor(src_len), torch.LongTensor(tgt_len), \
        original_src, original_tgt
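A toy run of this collate function, assuming each sample is a (src, tgt, original_src, original_tgt) tuple whose first two fields are token-id lists of different lengths; note how the source rows come out reversed and zero-padded:

batch = [
    ([1, 2, 3], [4, 5], 'a b c', 'd e'),
    ([6, 7], [8, 9, 10], 'f g', 'h i j'),
]
src_pad, tgt_pad, src_len, tgt_len, osrc, otgt = myfunction(batch)
# src_pad -> [[3, 2, 1],      tgt_pad -> [[4, 5, 0],
#             [7, 6, 0]]                  [8, 9, 10]]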