前言
在pytorch中的數(shù)據(jù)加載到模型的操作順序是這樣的
1 創(chuàng)建一個Dataset
對象
2 創(chuàng)建一個Dataloader
對象
3 循環(huán)這個Dataloader
對象嚎于,將img,label加載到模型中進行訓(xùn)練
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for img,label in dataloader :
...........
所以,作為直接對數(shù)據(jù)進入模型的關(guān)鍵一步后专,DataLoader
相當(dāng)重要。
首先介紹下DataLoader
他是pytorch中數(shù)據(jù)讀取的一個重要接口输莺,該接口定義在dataloader.py中戚哎,只要不是用戶進行重寫,一般都要用到該接口嫂用,該接口的目的:將自定義的Dataset根據(jù)batch size大小型凳,是否shuffle等封裝成一個Batch size 大小的Tensor,用于后面的訓(xùn)練嘱函。
官方對DataLoader
的說明是:
數(shù)據(jù)加載由數(shù)據(jù)集和采樣器組成甘畅,基于Python的單、多進程的iterators來處理數(shù)據(jù)
DataLoader
先介紹下DataLoader(object)
的參數(shù):
- dataset(Dataset):傳入的數(shù)據(jù)集
- batch_size(int, optional):每個batch有多少個樣本
- shuffle(bool, optional): 在每個epoch開始的時候,對數(shù)據(jù)進行重新排序
- sampler(Sampler, optional):自定義從數(shù)據(jù)集中取樣本的策略疏唾,如果指定這個參數(shù)蓄氧,那么shuffle必須為False
- batch_sampler(Sampler, optional):與sampler類似,但是一次只返回一個batch的indices(索引)槐脏,需要注意的是喉童,一旦指定了這個參數(shù),那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)
- num_workers (int, optional):這個參數(shù)決定了有幾個進程來處理data loading顿天。0意味著所有的數(shù)據(jù)都會被load進主進程堂氯。(默認為0)
- collate_fn (callable, optional):將一個list的sample組成一個mini-batch的函數(shù)
- pin_memory (bool, optional):如果設(shè)置為True,那么data loader將會在返回它們之前牌废,將tensors拷貝到CUDA中的固定內(nèi)存(CUDA pinned memory)中.
- drop_last (bool, optional):如果設(shè)置為True:這個是對最后的未完成的batch來說的咽白,比如你的batch_size設(shè)置為64,而一個epoch只有100個樣本鸟缕,那么訓(xùn)練的時候后面的36個就被扔掉了…
如果為False(默認)晶框,那么會繼續(xù)正常執(zhí)行,只是最后的batch_size會小一點懂从。- timeout(numeric, optional): 如果是正數(shù)授段,表明等待從worker進程中收集一個batch等待的時間,若超出設(shè)定的時間還沒有收集到莫绣,那就不收集這個內(nèi)容了。這個numeric應(yīng)總是大于等于0悠鞍。默認為0
-worker_init_fn (callable, optional): 每個worker初始化函數(shù) If not None, this will be called on each