A Beginner's Self-Redemption
The Dataset Class
In PyTorch, data loading is done through a custom dataset object. A dataset is abstracted as the Dataset class, so implementing a custom dataset means inheriting from Dataset and implementing two Python magic methods:
__len__: makes len(dataset) return the size of the dataset.
__getitem__: returns a single sample, and is where the image is actually read.
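Before the face-landmarks example below, here is a minimal sketch (my own illustrative example, not from the tutorial) of a Dataset that wraps an in-memory list; it only needs these two methods plus an optional __init__:

from torch.utils.data import Dataset

class ListDataset(Dataset):           # hypothetical minimal dataset
    def __init__(self, items):
        self.items = items            # keep the data in memory

    def __len__(self):
        return len(self.items)        # len(dataset) calls this

    def __getitem__(self, idx):
        return self.items[idx]        # dataset[idx] calls this

ds = ListDataset(['a', 'b', 'c'])
print(len(ds), ds[1])                 # -> 3 b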
What do __init__ and self do when they appear in a Python class definition?
The meaning and role of __init__ and self in Python
Since a class acts as a template, we can require that certain attributes we consider mandatory be filled in when an instance is created. Taking a Student class as an example, by defining a special __init__ method we bind attributes such as name and score at creation time:
class Student(object):
    def __init__(self, name, score):
        self.name = name
        self.score = score
Note that the first parameter of __init__ is always self, which refers to the instance being created; inside __init__ we can therefore bind all kinds of attributes to self, because self points to the instance itself. Once __init__ is defined, an instance can no longer be created with empty arguments: the arguments must match __init__'s signature, except for self, which the Python interpreter passes in automatically.
Compared with an ordinary function, a function defined inside a class differs in only one respect: its first parameter is always the instance variable self, and it is not passed explicitly at call time. Apart from that, class methods are no different from ordinary functions, so default parameters, variable arguments, keyword arguments and named keyword arguments can all still be used.
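A short usage sketch (standard Python behaviour, with a hypothetical subclass added for illustration) showing that self is filled in automatically and that methods accept default parameters like ordinary functions:

bart = Student('Bart Simpson', 59)   # self is passed in by the interpreter, not by us
print(bart.name, bart.score)         # -> Bart Simpson 59
# Student()                          # would raise TypeError: name and score are required

class GradedStudent(Student):        # hypothetical subclass for illustration
    def is_passing(self, threshold=60):   # default parameter works as in a normal function
        return self.score >= threshold

print(GradedStudent('Lisa', 87).is_passing())   # -> True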
The dataset class code:
# Imports needed by the code in this section
import os
import torch
import pandas as pd
import numpy as np
from skimage import io, transform
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

class FaceLandmarksDataset(Dataset):  # define the class and inherit from Dataset
    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations (the labels).
            root_dir (string): Directory with all the images (the data).
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        # The attributes are bound to the instance via self, so csv_file,
        # root_dir and transform must be supplied when the class is
        # instantiated; see the Dataset class for further details.

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # The image is only actually read when this method is called;
        # root_dir merely specifies where the data lives.
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample
Once the custom dataset object is defined, we can load the images:
face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                    root_dir='faces/')
fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]  # equivalent to calling face_dataset.__getitem__(i)

    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break
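The show_landmarks helper used above is not defined in this excerpt; a minimal sketch consistent with how it is called here (an image plus an (N, 2) array of landmark coordinates) might look like this:

def show_landmarks(image, landmarks):
    """Show an image with its landmarks overlaid (minimal sketch)."""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause briefly so the plot gets drawn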
Transforms
Real images come in all sizes, but when feeding a neural network we usually want the inputs to have a fixed size, so the images need some preprocessing:
Rescale: scale the image;
RandomCrop: crop the image at a random location (data augmentation);
ToTensor: convert the image from a numpy array to a torch tensor.
These three operations are written as classes rather than functions mainly because a function would need its parameters passed in on every call, while a class stores them once and saves a lot of trouble. All we have to implement in each class is the __call__ method and the __init__ method.
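As a tiny illustration (hypothetical transform, not part of the tutorial): the configuration is stored once in __init__, and __call__ then lets the object be applied to a sample exactly like a function:

class AddOffset(object):             # hypothetical transform
    def __init__(self, offset):
        self.offset = offset         # parameters stored once, at construction time

    def __call__(self, sample):
        return sample + self.offset  # the instance is used like a function

shift = AddOffset(5)
print(shift(10))                     # -> 15; the same object can be reused on every sample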
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)
        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}
class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
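A quick sanity check on dummy data (hypothetical values, just to show the axis swap and the resulting tensor shapes):

dummy = {'image': np.zeros((100, 120, 3)),   # H x W x C numpy image
         'landmarks': np.zeros((68, 2))}     # 68 (x, y) landmark pairs
out = ToTensor()(dummy)
print(out['image'].shape)                    # torch.Size([3, 100, 120]), i.e. C x H x W
print(out['landmarks'].shape)                # torch.Size([68, 2])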
PyTorch provides torchvision.transforms.Compose to combine the Rescale and RandomCrop transforms, which is wonderfully convenient.
For example, to rescale the shorter side of the image to 256 and then randomly crop a 224-pixel square from it:
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])
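Each of these objects is callable, so it can be applied directly to a sample drawn from the dataset (the sample index below is arbitrary):

sample = face_dataset[65]                # any valid index works here
for tsfrm in [scale, crop, composed]:
    transformed_sample = tsfrm(sample)   # the transform is applied like a function
    print(type(tsfrm).__name__, transformed_sample['image'].shape)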
To sum up, data loading and transforms can now be integrated. Previously we might read the data first and then preprocess it (cropping, for example), or preprocess first and then read it in, which is slow and tedious; with torchvision.transforms.Compose the images are read and processed on the fly.
Every time a sample is drawn, the following steps happen:
the image is read from file;
the transforms are applied to the image that was read;
since the crop is taken at a random position, this also acts as data augmentation.
In fact, all we have to do is pass the transform in as an argument to the dataset; everything else stays the same, and a for loop fetches the samples one by one.
transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                           root_dir='faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))
for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]  # each sample has already been rescaled, cropped and converted to tensors
The DataLoader Class
We have now combined the dataset with the transforms and can fetch each sample with a for loop, but we still need to:
1. get batches of data according to batch_size;
2. shuffle the order of the data;
3. load the data in parallel using multiprocessing.
The torch.utils.data.DataLoader class solves all of these problems.
Just set the DataLoader parameters as needed:
the first argument is transformed_dataset, i.e. the Dataset instance with the transform already attached;
the second argument is batch_size, the number of samples in each batch;
the third argument is shuffle, a boolean indicating whether to shuffle the data;
the fourth argument is num_workers, the number of worker processes used to load the data.
The following shows how DataLoader is used and how a batch of samples is displayed.
dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)
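One caveat worth noting (not mentioned in the original text): with num_workers > 0 the data is loaded by worker processes started via multiprocessing, so on platforms that spawn new processes (e.g. Windows) the iteration should be placed under a main guard:

if __name__ == '__main__':                  # required on Windows when num_workers > 0
    for i_batch, sample_batched in enumerate(dataloader):
        pass                                # iterate (and train) inside the guard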
# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
        sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size,
                    landmarks_batch[i, :, 1].numpy(),
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break
To summarize, here is how to load ordinary data:
1. First, define a dataset class (e.g. myDataset) that inherits from the abstract Dataset class and implements __len__ and __getitem__; usually it also has an initializer __init__.
2. Implement the image preprocessing you need and wrap each operation in a class; many common transforms are already available in torchvision (see the sketch after this list). Combine them with torchvision.transforms.Compose into a transform.
3. Pass the transform as an argument when instantiating myDataset, obtaining a transformed_dataset object.
4. Finally, pass transformed_dataset to the torch.utils.data.DataLoader class, setting shuffling, batch size, etc. according to your needs.
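Regarding step 2: for plain image data without landmark annotations, the ready-made transforms in torchvision can be used directly instead of hand-written classes. A sketch (these built-ins operate on PIL images, not on the dict samples used above):

from torchvision import transforms

plain_transform = transforms.Compose([
    transforms.Resize(256),        # built-in counterpart of the Rescale class above
    transforms.RandomCrop(224),    # built-in counterpart of the RandomCrop class above
    transforms.ToTensor(),         # PIL image -> C x H x W float tensor in [0, 1]
])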