文章結(jié)構(gòu)
-
自定義Dataset的基本結(jié)構(gòu)
-
使用Torchvisiom進(jìn)行類型轉(zhuǎn)換
-
使用Torchvision的另一種方法
-
Incorporating Pandas
-
Incorporating Pandas with More Logic
-
使用Data Loader
自定義Dataset的基本結(jié)構(gòu)
- 首先最重要的是要?jiǎng)?chuàng)建dataset類
from torch.utils.data.dataset import Dataset
class MyCustomDataset(Dataset):
def __init__(self, ...):
# 填充
def __getitem__(self, index):
# 填充
return (img, label)
def __len__(self):
return count # 你有多少張圖片
- 這是必須填充用來獲得自定義數(shù)據(jù)集的框架季眷。數(shù)據(jù)集必須包含以下函數(shù)蛙奖,以便稍后由數(shù)據(jù)加載程序使用被环。
__init__() #函數(shù)是初始邏輯發(fā)生的地方,比如讀取csv局荚、分配轉(zhuǎn)換等
__getitem__()#函數(shù)返回?cái)?shù)據(jù)和標(biāo)簽直撤。這個(gè)函數(shù)是從dataloader中被調(diào)用的眼坏,如下所示:
img, label = MyCustomDataset.__getitem__(99) # 有99個(gè)數(shù)據(jù)
- 因此嚷掠,索引參數(shù)(index)是你要返回的第n個(gè)數(shù)據(jù)/圖像(tensor)。
__len__()#返回你的樣本數(shù)量
- 注意
__getitem__()
返回一個(gè)特殊的數(shù)據(jù)類型首量,比如tensor壮吩,numpy array等,如果不是這些類型加缘,在data loader將會報(bào)錯(cuò)鸭叙。
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.PngImagePlugin.PngImageFile'>
使用Torchvisiom進(jìn)行類型轉(zhuǎn)換
- 一般在
__init__()
里面都會寫成transforms = None
,這是為了方便在調(diào)用dataset類的時(shí)候傳入自定義的transforms
from torch.utils.data.dataset import Dataset
from torchvision import transforms
class MyCustomDataset(Dataset):
def __init__(self, ..., transforms=None):
# 填充
#...
self.transforms = transforms
def __getitem__(self, index):
# 填充
#...
data = # 從文件或者圖像中讀取的數(shù)據(jù)
if self.transforms is not None:
data = self.transforms(data)
# 如果轉(zhuǎn)換變量不是空
# 按照傳入的轉(zhuǎn)換格式來轉(zhuǎn)換數(shù)據(jù)
return (img, label)
def __len__(self):
return count
if __name__ == '__main__':
# 自定義transforms
transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])
# 調(diào)用數(shù)據(jù)集
custom_dataset = MyCustomDataset(..., transformations)
使用Torchvision的另一種方法
- 如果不喜歡在外面自定義transforms拣宏,可以在dataset類里面定義好递雀,不過這樣降低了程序的可讀性。
from torch.utils.data.dataset import Dataset
from torchvision import transforms
class MyCustomDataset(Dataset):
def __init__(self, ...):
# 填充
#...
# 單獨(dú)定義轉(zhuǎn)換
self.center_crop = transforms.CenterCrop(100)
self.to_tensor = transforms.ToTensor()
# 也可以組合定義
self.transformations = transforms.Compose([
transforms.CenterCrop(100),
transforms.ToTensor()])
def __getitem__(self, index):
# 填充
#...
data = # 從文件或者圖像中讀取的數(shù)據(jù)
#對應(yīng)了在__init__()中定義的三個(gè)transforms
data = self.center_crop(data)
data = self.to_tensor(data)
data = self.trasnformations(data)
return (img, label)
def __len__(self):
return count
if __name__ == '__main__':
# 調(diào)用dataset
custom_dataset = MyCustomDataset(...)
Incorporating Pandas
- 假設(shè)蚀浆,我們想通過pandas從csv文件中讀取數(shù)據(jù)。第一個(gè)例子如下的csv文件搜吧,包含文件名和標(biāo)簽市俊,和一個(gè)額外的操作指示器根據(jù)這個(gè)額外的操作標(biāo)志我們對圖像做一些操作。
File Name |
Label |
Extra Operation |
tr_0.png |
5 |
TRUE |
tr_1.png |
0 |
FALSE |
tr_2.png |
4 |
FALSE |
- 如果我們想建立一個(gè)自定義數(shù)據(jù)集滤奈,讀取圖像位置從這個(gè)csv文件摆昧,然后我們可以做如下操作
class CustomDatasetFromImages(Dataset):
def __init__(self, csv_path):
'''
Args:
csv_path (string): csv文件路徑
img_path (string): 圖片文件路徑
transform: pytorch變換用于變換和張量轉(zhuǎn)換
'''
# Transforms
self.to_tensor = transforms.ToTensor()
# 讀取csv文件
self.data_info = pd.read_csv(csv_path, header=None)
# 第一列包含圖像路徑
self.image_arr = np.asarray(self.data_info.iloc[:, 0])
# 第二列是標(biāo)簽
self.label_arr = np.asarray(self.data_info.iloc[:, 1])
# 第三列是操作指示符
self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
# 計(jì)算整個(gè)數(shù)據(jù)集的長度
self.data_len = len(self.data_info.index)
def __getitem__(self, index):
# 從pandas df獲取圖片文件名
single_image_name = self.image_arr[index]
# 打開圖片
img_as_img = Image.open(single_image_name)
# 檢查是否有操作
some_operation = self.operation_arr[index]
# 如果有操作的話
if some_operation:
# 對圖像做一些操作
# ...
# ...
pass
# 把圖像變換成張量
img_as_tensor = self.to_tensor(img_as_img)
# 根據(jù)裁剪的panda列獲取圖像的標(biāo)簽
single_image_label = self.label_arr[index]
return (img_as_tensor, single_image_label)
def __len__(self):
return self.data_len
if __name__ == "__main__":
# 調(diào)用 dataset
custom_mnist_from_images = CustomDatasetFromImages('../data/mnist_labels.csv')
Incorporating Pandas with More Logic
- 另一個(gè)從csv中讀取圖像的例子,其中每個(gè)像素的值都在一個(gè)列中蜒程。這時(shí)绅你,只需要返回張量以及其標(biāo)簽。數(shù)據(jù)被分成像素昭躺。
Lbel |
pixel_1 |
pixel_2 |
... |
1 |
50 |
99 |
... |
0 |
21 |
223 |
... |
9 |
44 |
112 |
... |
... |
... |
... |
... |
class CustomDatasetFromCSV(Dataset):
def __init__(self, csv_path, height, width, transforms=None):
'''
Args:
csv_path (string): csv文件路徑
height (int): 圖片高度
width (int): 圖片寬度
transform: pytorch transforms for transforms and tensor conversion
'''
self.data = pd.read_csv(csv_path)
self.labels = np.asarray(self.data.iloc[:, 0])
self.height = height
self.width = width
self.transforms = transform
def __getitem__(self, index):
single_image_label = self.labels[index]
# Read each 784 pixels and reshape the 1D array ([784]) to 2D array ([28,28])
img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype('uint8')
# 將圖像從numpy數(shù)組轉(zhuǎn)換為PIL圖像忌锯,模式“L”為灰度
img_as_img = Image.fromarray(img_as_np)
img_as_img = img_as_img.convert('L')
# 把圖像變換成tensor
if self.transforms is not None:
img_as_tensor = self.transforms(img_as_img)
# 返回圖片和標(biāo)簽
return (img_as_tensor, single_image_label)
def __len__(self):
return len(self.data.index)
if __name__ == "__main__":
transformations = transforms.Compose([transforms.ToTensor()])
custom_mnist_from_csv = CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations)
使用Data Loader
- 在pytorch中DataLoader只需要調(diào)用
__getitem__()
然后把他們打包成一個(gè)批次。我們也可以不使用Dataloader每調(diào)用__getitem()__
一次就把數(shù)據(jù)傳入到模型(遠(yuǎn)沒有使用DataLoader方便)领炫。從上面的示例繼續(xù)偶垮,如果我們假設(shè)有一個(gè)名為CustomDatasetFromCSV的自定義數(shù)據(jù)集,那么我們可以像這樣調(diào)用DataLoader
if __name__ == "__main__":
# 定義 transforms
transformations = transforms.Compose([transforms.ToTensor()])
# 定義dataset
custom_mnist_from_csv = CustomDatasetFromCSV('../data/mnist_in_csv.csv',28, 28,transformations)
# 定義data loader
mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
batch_size=10,
shuffle=False)
for images, labels in mn_dataset_loader:
# 將數(shù)據(jù)送入模型
- DataLoader的第一個(gè)參數(shù)是數(shù)據(jù)集,從那里它調(diào)用該數(shù)據(jù)集的
__getitem__()
.batch_size確定一個(gè)批次傳入的數(shù)據(jù)量似舵,如果我們假設(shè)一張圖片的tensor是[1*28*28] ---> [D:1,H:28,W:28]
那么用這個(gè)DataLoader返回的tensor是[10*1*28*28]