1. Preface
Please do not repost without permission, thanks~~~
- Earlier I wrote a post that processed this video dataset without any deep learning framework: Processing and Loading the UCF101 Video Dataset (without a deep learning framework).
- That approach is simple and direct, but it leaves plenty of room for improvement, so over the last couple of days I studied PyTorch's support for dataset loading: Getting Started with PyTorch (7): Data Loading and Processing.
- I said before that I would redo the UCF101 processing with PyTorch's tools, so this post records how. It only covers the concrete implementation; for the reasoning behind each step, see the tutorial post above~~
2. Specific Goals
- Use the lists in trainlist (testlist) to decide which videos make up the dataset.
- For each video, randomly sample 16 consecutive frames.
- Subtract the RGB channel means from every frame.
- Resize each frame to (182, 242).
- Then randomly crop the resized frames to (160, 160).
- Each batch returns the video representation x[batch_size, 16, 3, 160, 160] and the labels y[batch_size].
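The crop offsets in the last two steps follow directly from the two sizes; here is a minimal NumPy sketch of that arithmetic (the frame is a dummy array, just to show the shapes):

```python
import numpy as np

frame = np.zeros((182, 242, 3))               # a frame after the resize step
top = np.random.randint(0, 182 - 160)         # vertical offset, 0..21
left = np.random.randint(0, 242 - 160)        # horizontal offset, 0..81
crop = frame[top:top + 160, left:left + 160]  # -> shape (160, 160, 3)
```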
3. Implementation Approach
Since the dataset we need to handle is neither one that PyTorch provides directly nor one that follows the most common `ImageFolder` storage layout, we implement the required functionality step by step.

- The biggest difference from the tutorial example is that we mainly process videos rather than single images, so the frame reading goes into `__getitem__`.
- The remaining transformations go into `transform`.
The concrete steps are as follows:
- First, define the dataset class `UCF101`, which inherits from the abstract class `Dataset` and implements `__init__`, `__len__`, and `__getitem__` (see the minimal skeleton after this list):
    - `__init__`: reads in and parses the info list and handles the other initialization work.
    - `__len__`: returns the size of the dataset.
    - `__getitem__`: reads and returns 16 random consecutive frames of a single video.
    - Other helper methods support the functionality above.
- Then, implement the dataset-specific frame preprocessing and wrap each step as a class:
    - subtract the RGB channel means
    - resize to (182, 242)
    - randomly crop to (160, 160)
    - convert to a Tensor
- Next, combine these with `transforms.Compose` into a single `transform`, pass it as a parameter when instantiating `UCF101`, and obtain a `my_UCF101` object.
- Finally, pass `my_UCF101` to the `torch.utils.data.DataLoader` class, and configure shuffling, batch size, and so on according to your needs.
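For reference, here is a minimal skeleton of the `Dataset` contract described in the first step (the class and attribute names are placeholders, not the final code; the real `UCF101` implementation follows in Section 4):

```python
from torch.utils.data import Dataset

class MyVideoDataset(Dataset):
    """Minimal sketch of the three methods a map-style Dataset must implement."""

    def __init__(self, samples, transform=None):
        # samples: any indexable collection, e.g. a list of sample dicts
        self.samples = samples
        self.transform = transform

    def __len__(self):
        # DataLoader calls this to learn how many indices it may draw
        return len(self.samples)

    def __getitem__(self, idx):
        # load one raw sample here, then apply the transform chain
        sample = self.samples[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample
```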
4. Complete Code
If any of the underlying ideas are unclear, I still recommend going back to this post: Getting Started with PyTorch (7): Data Loading and Processing.
I won't repeat the theory here; the full source code follows.
```python
from __future__ import print_function, division
import os
import random
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

class ClipSubstractMean(object):
    """Subtract the per-channel mean from every frame.

    The defaults (b=104, g=117, r=123) are the ImageNet channel means
    commonly used in Caffe-style preprocessing.
    """

    def __init__(self, b=104, g=117, r=123):
        self.means = np.array((r, g, b))

    def __call__(self, sample):
        video_x, video_label = sample['video_x'], sample['video_label']
        new_video_x = video_x - self.means
        return {'video_x': new_video_x, 'video_label': video_label}

class Rescale(object):
    """Rescale the frames in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size=(182, 242)):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        video_x, video_label = sample['video_x'], sample['video_label']
        h, w = video_x.shape[1], video_x.shape[2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)

        new_video_x = np.zeros((16, new_h, new_w, 3))
        for i in range(16):
            image = video_x[i, :, :, :]
            img = transform.resize(image, (new_h, new_w))
            new_video_x[i, :, :, :] = img
        return {'video_x': new_video_x, 'video_label': video_label}

class RandomCrop(object):
    """Crop the frames in a sample randomly.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size=(160, 160)):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        video_x, video_label = sample['video_x'], sample['video_label']
        h, w = video_x.shape[1], video_x.shape[2]
        new_h, new_w = self.output_size

        # one crop window shared by all 16 frames of the clip
        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        new_video_x = np.zeros((16, new_h, new_w, 3))
        for i in range(16):
            image = video_x[i, :, :, :]
            image = image[top: top + new_h, left: left + new_w]
            new_video_x[i, :, :, :] = image
        return {'video_x': new_video_x, 'video_label': video_label}

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        video_x, video_label = sample['video_x'], sample['video_label']

        # swap color axis because
        # numpy clip:  num_frames x H x W x C
        # torch clip:  num_frames x C x H x W
        video_x = video_x.transpose((0, 3, 1, 2))
        video_x = np.array(video_x)
        # wrap the scalar label so the batched labels come out as [batch_size, 1]
        video_label = [video_label]
        return {'video_x': torch.from_numpy(video_x),
                'video_label': torch.FloatTensor(video_label)}

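# A quick illustration of what the four transforms above do to one sample when
# chained in order (shapes only; the values are dummies):
#   sample = {'video_x': np.zeros((16, 240, 320, 3)), 'video_label': 5}
#   sample = ClipSubstractMean()(sample)   # still (16, 240, 320, 3), means removed
#   sample = Rescale()(sample)             # -> (16, 182, 242, 3)
#   sample = RandomCrop()(sample)          # -> (16, 160, 160, 3)
#   sample = ToTensor()(sample)            # video_x -> torch.Size([16, 3, 160, 160])
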
class UCF101(Dataset):
    """UCF101 video dataset."""

    def __init__(self, info_list, root_dir, transform=None):
        """
        Args:
            info_list (string): Path to the info list file with annotations.
            root_dir (string): Directory with all the video frames.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(info_list, delimiter=' ', header=None)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    # returns video_x of shape (16, 240, 320, 3) plus its label
    def __getitem__(self, idx):
        video_path = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        video_label = self.landmarks_frame.iloc[idx, 1]
        video_x = self.get_single_video_x(video_path)
        sample = {'video_x': video_x, 'video_label': video_label}

        if self.transform:
            sample = self.transform(sample)
        return sample
    def get_single_video_x(self, video_path):
        # video_path already includes root_dir, so dir_name is an absolute
        # path and os.path.join returns it unchanged
        slash_rows = video_path.split('.')
        dir_name = slash_rows[0]
        video_jpgs_path = os.path.join(self.root_dir, dir_name)

        # the n_frames file (written during frame extraction) stores the frame count
        data = pd.read_csv(os.path.join(video_jpgs_path, 'n_frames'), delimiter=' ', header=None)
        frame_count = data[0][0]

        # pick a random start so that frames start..start+15 all exist
        # (frames are 1-indexed, so the last valid start is frame_count - 15)
        video_x = np.zeros((16, 240, 320, 3))
        image_start = random.randint(1, frame_count - 15)
        image_id = image_start
        for i in range(16):
            s = "%05d" % image_id
            image_name = 'image_' + s + '.jpg'
            image_path = os.path.join(video_jpgs_path, image_name)
            tmp_image = io.imread(image_path)
            video_x[i, :, :, :] = tmp_image
            image_id += 1
        return video_x

if __name__ == '__main__':
    # usage
    root_list = '/home/hl/Desktop/lovelyqian/CV_Learning/UCF101_jpg/'
    info_list = '/home/hl/Desktop/lovelyqian/CV_Learning/UCF101_TrainTestlist/trainlist01.txt'
    myUCF101 = UCF101(info_list, root_list,
                      transform=transforms.Compose(
                          [ClipSubstractMean(), Rescale(), RandomCrop(), ToTensor()]))

    dataloader = DataLoader(myUCF101, batch_size=8, shuffle=True, num_workers=8)
    for i_batch, sample_batched in enumerate(dataloader):
        print(i_batch, sample_batched['video_x'].size(), sample_batched['video_label'].size())
```
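For reference, with `batch_size=8` each iteration should print shapes along these lines (the label picks up a trailing dimension because `ToTensor` wraps the scalar label in a list; `.squeeze()` it if you need a flat `y[batch_size]`):

```
0 torch.Size([8, 16, 3, 160, 160]) torch.Size([8, 1])
1 torch.Size([8, 16, 3, 160, 160]) torch.Size([8, 1])
...
```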
The result improves on the previous version a great deal, in both logical clarity and line count, so it really pays to learn from how the framework authors structure things; of course, having implemented it once by hand is also worthwhile.
References
Processing and Loading the UCF101 Video Dataset (without a deep learning framework)
Getting Started with PyTorch (7): Data Loading and Processing