PyTorch自定義數(shù)據(jù)集示例(2019-12-19)

文章結(jié)構(gòu)

  • 自定義Dataset的基本結(jié)構(gòu)

  • 使用Torchvisiom進(jìn)行類型轉(zhuǎn)換

  • 使用Torchvision的另一種方法

  • Incorporating Pandas

  • Incorporating Pandas with More Logic

  • 使用Data Loader

自定義Dataset的基本結(jié)構(gòu)

  • 首先最重要的是要?jiǎng)?chuàng)建dataset類
from torch.utils.data.dataset import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # 填充
        
    def __getitem__(self, index):
        # 填充
        return (img, label)

    def __len__(self):
        return count # 你有多少張圖片
  • 這是必須填充用來獲得自定義數(shù)據(jù)集的框架季眷。數(shù)據(jù)集必須包含以下函數(shù)蛙奖,以便稍后由數(shù)據(jù)加載程序使用被环。
__init__() #函數(shù)是初始邏輯發(fā)生的地方,比如讀取csv局荚、分配轉(zhuǎn)換等
__getitem__()#函數(shù)返回?cái)?shù)據(jù)和標(biāo)簽直撤。這個(gè)函數(shù)是從dataloader中被調(diào)用的眼坏,如下所示:
img, label = MyCustomDataset.__getitem__(99)  # 有99個(gè)數(shù)據(jù)

  • 因此嚷掠,索引參數(shù)(index)是你要返回的第n個(gè)數(shù)據(jù)/圖像(tensor)。
__len__()#返回你的樣本數(shù)量
  • 注意__getitem__()返回一個(gè)特殊的數(shù)據(jù)類型首量,比如tensor壮吩,numpy array等,如果不是這些類型加缘,在data loader將會報(bào)錯(cuò)鸭叙。
    TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.PngImagePlugin.PngImageFile'>

使用Torchvisiom進(jìn)行類型轉(zhuǎn)換

  • 一般在__init__()里面都會寫成transforms = None,這是為了方便在調(diào)用dataset類的時(shí)候傳入自定義的transforms
from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ..., transforms=None):
        # 填充
        #...
        self.transforms = transforms
        
    def __getitem__(self, index):
        # 填充
        #...
        data = # 從文件或者圖像中讀取的數(shù)據(jù)
        if self.transforms is not None:
            data = self.transforms(data)
        # 如果轉(zhuǎn)換變量不是空
        # 按照傳入的轉(zhuǎn)換格式來轉(zhuǎn)換數(shù)據(jù)
        return (img, label)

    def __len__(self):
        return count
        
if __name__ == '__main__':
    # 自定義transforms
    transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])
    # 調(diào)用數(shù)據(jù)集
    custom_dataset = MyCustomDataset(..., transformations)

使用Torchvision的另一種方法

  • 如果不喜歡在外面自定義transforms拣宏,可以在dataset類里面定義好递雀,不過這樣降低了程序的可讀性。
from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # 填充
        #...
        # 單獨(dú)定義轉(zhuǎn)換
        self.center_crop = transforms.CenterCrop(100)
        self.to_tensor = transforms.ToTensor()
        
        # 也可以組合定義
        self.transformations = transforms.Compose([
                                transforms.CenterCrop(100),
                                transforms.ToTensor()])
        
    def __getitem__(self, index):
        # 填充
        #...
        data = # 從文件或者圖像中讀取的數(shù)據(jù)
        
        #對應(yīng)了在__init__()中定義的三個(gè)transforms
        data = self.center_crop(data)  
        data = self.to_tensor(data)  
        data = self.trasnformations(data) 
        
        return (img, label)

    def __len__(self):
        return count 
        
if __name__ == '__main__':
    # 調(diào)用dataset
    custom_dataset = MyCustomDataset(...)

Incorporating Pandas

  • 假設(shè)蚀浆,我們想通過pandas從csv文件中讀取數(shù)據(jù)。第一個(gè)例子如下的csv文件搜吧,包含文件名和標(biāo)簽市俊,和一個(gè)額外的操作指示器根據(jù)這個(gè)額外的操作標(biāo)志我們對圖像做一些操作。
File Name Label Extra Operation
tr_0.png 5 TRUE
tr_1.png 0 FALSE
tr_2.png 4 FALSE
  • 如果我們想建立一個(gè)自定義數(shù)據(jù)集滤奈,讀取圖像位置從這個(gè)csv文件摆昧,然后我們可以做如下操作
class CustomDatasetFromImages(Dataset):
    def __init__(self, csv_path):
        '''
        Args:
            csv_path (string): csv文件路徑
            img_path (string): 圖片文件路徑
            transform: pytorch變換用于變換和張量轉(zhuǎn)換
        '''
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # 讀取csv文件
        self.data_info = pd.read_csv(csv_path, header=None)
        # 第一列包含圖像路徑
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # 第二列是標(biāo)簽
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # 第三列是操作指示符
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # 計(jì)算整個(gè)數(shù)據(jù)集的長度
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # 從pandas df獲取圖片文件名
        single_image_name = self.image_arr[index]
        # 打開圖片
        img_as_img = Image.open(single_image_name)

        # 檢查是否有操作
        some_operation = self.operation_arr[index]
        # 如果有操作的話
        if some_operation:
            # 對圖像做一些操作
            # ...
            # ...
            pass
        # 把圖像變換成張量
        img_as_tensor = self.to_tensor(img_as_img)

        # 根據(jù)裁剪的panda列獲取圖像的標(biāo)簽
        single_image_label = self.label_arr[index]

        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    # 調(diào)用 dataset
    custom_mnist_from_images = CustomDatasetFromImages('../data/mnist_labels.csv')

Incorporating Pandas with More Logic

  • 另一個(gè)從csv中讀取圖像的例子,其中每個(gè)像素的值都在一個(gè)列中蜒程。這時(shí)绅你,只需要返回張量以及其標(biāo)簽。數(shù)據(jù)被分成像素昭躺。
Lbel pixel_1 pixel_2 ...
1 50 99 ...
0 21 223 ...
9 44 112 ...
... ... ... ...
class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path, height, width, transforms=None):
        '''
        Args:
            csv_path (string): csv文件路徑
            height (int): 圖片高度
            width (int): 圖片寬度
            transform: pytorch transforms for transforms and tensor conversion
        '''
        self.data = pd.read_csv(csv_path)
        self.labels = np.asarray(self.data.iloc[:, 0])
        self.height = height
        self.width = width
        self.transforms = transform

    def __getitem__(self, index):
        single_image_label = self.labels[index]
        # Read each 784 pixels and reshape the 1D array ([784]) to 2D array ([28,28]) 
        img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype('uint8')
    # 將圖像從numpy數(shù)組轉(zhuǎn)換為PIL圖像忌锯,模式“L”為灰度
        img_as_img = Image.fromarray(img_as_np)
        img_as_img = img_as_img.convert('L')
        # 把圖像變換成tensor
        if self.transforms is not None:
            img_as_tensor = self.transforms(img_as_img)
        # 返回圖片和標(biāo)簽
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return len(self.data.index)
        

if __name__ == "__main__":
    transformations = transforms.Compose([transforms.ToTensor()])
    custom_mnist_from_csv = CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations)

使用Data Loader

  • 在pytorch中DataLoader只需要調(diào)用__getitem__()然后把他們打包成一個(gè)批次。我們也可以不使用Dataloader每調(diào)用__getitem()__一次就把數(shù)據(jù)傳入到模型(遠(yuǎn)沒有使用DataLoader方便)领炫。從上面的示例繼續(xù)偶垮,如果我們假設(shè)有一個(gè)名為CustomDatasetFromCSV的自定義數(shù)據(jù)集,那么我們可以像這樣調(diào)用DataLoader
if __name__ == "__main__":
    # 定義 transforms
    transformations = transforms.Compose([transforms.ToTensor()])
    # 定義dataset
    custom_mnist_from_csv = CustomDatasetFromCSV('../data/mnist_in_csv.csv',28, 28,transformations)
    # 定義data loader
    mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
                                                    batch_size=10,
                                                    shuffle=False)
    
    for images, labels in mn_dataset_loader:
        # 將數(shù)據(jù)送入模型
  • DataLoader的第一個(gè)參數(shù)是數(shù)據(jù)集,從那里它調(diào)用該數(shù)據(jù)集的__getitem__().batch_size確定一個(gè)批次傳入的數(shù)據(jù)量似舵,如果我們假設(shè)一張圖片的tensor是[1*28*28] ---> [D:1,H:28,W:28]那么用這個(gè)DataLoader返回的tensor是[10*1*28*28]
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末脚猾,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子砚哗,更是在濱河造成了極大的恐慌龙助,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,451評論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件蛛芥,死亡現(xiàn)場離奇詭異提鸟,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)常空,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,172評論 3 394
  • 文/潘曉璐 我一進(jìn)店門沽一,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人漓糙,你說我怎么就攤上這事铣缠。” “怎么了昆禽?”我有些...
    開封第一講書人閱讀 164,782評論 0 354
  • 文/不壞的土叔 我叫張陵蝗蛙,是天一觀的道長。 經(jīng)常有香客問我醉鳖,道長捡硅,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,709評論 1 294
  • 正文 為了忘掉前任盗棵,我火速辦了婚禮壮韭,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘纹因。我一直安慰自己喷屋,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,733評論 6 392
  • 文/花漫 我一把揭開白布瞭恰。 她就那樣靜靜地躺著屯曹,像睡著了一般。 火紅的嫁衣襯著肌膚如雪惊畏。 梳的紋絲不亂的頭發(fā)上恶耽,一...
    開封第一講書人閱讀 51,578評論 1 305
  • 那天,我揣著相機(jī)與錄音颜启,去河邊找鬼偷俭。 笑死,一個(gè)胖子當(dāng)著我的面吹牛农曲,可吹牛的內(nèi)容都是我干的社搅。 我是一名探鬼主播驻债,決...
    沈念sama閱讀 40,320評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼形葬!你這毒婦竟也來了合呐?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,241評論 0 276
  • 序言:老撾萬榮一對情侶失蹤笙以,失蹤者是張志新(化名)和其女友劉穎淌实,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體猖腕,經(jīng)...
    沈念sama閱讀 45,686評論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡拆祈,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,878評論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了倘感。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片放坏。...
    茶點(diǎn)故事閱讀 39,992評論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖老玛,靈堂內(nèi)的尸體忽然破棺而出淤年,到底是詐尸還是另有隱情,我是刑警寧澤蜡豹,帶...
    沈念sama閱讀 35,715評論 5 346
  • 正文 年R本政府宣布麸粮,位于F島的核電站,受9級特大地震影響镜廉,放射性物質(zhì)發(fā)生泄漏弄诲。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,336評論 3 330
  • 文/蒙蒙 一娇唯、第九天 我趴在偏房一處隱蔽的房頂上張望齐遵。 院中可真熱鬧,春花似錦塔插、人聲如沸洛搀。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,912評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至彰檬,卻和暖如春伸刃,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背逢倍。 一陣腳步聲響...
    開封第一講書人閱讀 33,040評論 1 270
  • 我被黑心中介騙來泰國打工捧颅, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人较雕。 一個(gè)月前我還...
    沈念sama閱讀 48,173評論 3 370
  • 正文 我出身青樓碉哑,卻偏偏與公主長得像挚币,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個(gè)殘疾皇子扣典,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,947評論 2 355