PyTorch 知識

PyTorch使用總覽

原文鏈接:https://blog.csdn.net/u014380165/article/details/79222243

參考:PyTorch學(xué)習(xí)之路(level1)——訓(xùn)練一個圖像分類模型徙邻、PyTorch學(xué)習(xí)之路(level2)——自定義數(shù)據(jù)讀取奶赠、PyTorch源碼解讀之torchvision.transforms灭抑、PyTorch源碼解讀之torch.utils.data.DataLoader羡滑、PyTorch源碼解讀之torchvision.models

一甸祭、數(shù)據(jù)讀取

官方代碼庫中有一個接口例子:torchvision.ImageFolder -- 針對的數(shù)據(jù)存放方式是每個文件夾包含一個類的圖像呢蔫,但往往實(shí)際應(yīng)用中可能你的數(shù)據(jù)不是這樣維護(hù)的恕曲,此時需要自定義一個數(shù)據(jù)讀取接口(使用PyTorch中數(shù)據(jù)讀取基類:torch.utils.data.Dataset)

數(shù)據(jù)讀取接口
class customData(data.Dataset):

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        """
        提供數(shù)據(jù)地址(data path)鹏氧、每一文件所屬的類別(label),and other Info wanted(transform\loader\...) --> self.(attributes)
        
        :param root(string): Root directory path.
        :param transform (callable, optional): A function/transform that  takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
        :param target_transform(callable, optional): A function/transform that takes in the target and transforms it. 
        :param loader (callable, optional): A function to load an image given its path. 
        
        the data loader where the images are arranged in this way: ::
                root/class_1_xxx.png    
                root/class_2_xxx.png
        ...
                root/class_n_xxx.png    # 此例中佩谣,文件名包含label信息把还,__init__中可不需要額外提供
            
        """
        self.dataset = [os.path.join(root, npy_data) for npy_data in os.listdir(root)]  # 整個數(shù)據(jù)集(圖像)文件的路徑
                
                self.transform = transform  # (optional)
        self.target_transform = target_transform    # (optional)
        self.loader = loader    # (optional)
        
    def __getitem__(self, index):
        """
        :return 相應(yīng)index的data && label

                """
        data = np.load(self.dataset[index])
        
        if self.transform is not None:  # (optional)
            img = self.transform(img)
        if self.target_transform is not None:  # (optional)
            target = self.target_transform(target)
        
        label_txt = self.dataset[index].split('/')[-1][:2]  # (class_n)_xxxx.npy → (class_n)

        if label_txt == 'class_1':
            label = 0
        elif label_txt == 'class_2':
            label = 1
        else:
            raise RuntimeError('Now only support class_1 vs class_2.')

        return data, label

    def __len__(self):
        """
                :return 數(shù)據(jù)集數(shù)量
                
        """
        return len(self.dataset)

上述提到的transforms數(shù)據(jù)預(yù)處理,可以通過torchvision.transforms接口來實(shí)現(xiàn)茸俭。具體請看博客:PyTorch源碼解讀之torchvision.transforms

接口調(diào)用
root_dir = r'xxxxxxxx'  
image_datasets = {x: customData(root=root_dir+x) for x in ['train', 'val', 'test']}

返回的image_datasets(自定義數(shù)據(jù)讀取接口)就和用torchvision.datasets.ImageFolder類(官方提供的數(shù)據(jù)讀取接口)返回的數(shù)據(jù)類型一樣

數(shù)據(jù)迭代器封裝
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4)
                   for x in ['train', 'valid', 'test']}

torch.utils.data.DataLoader接口將每個batch的圖像數(shù)據(jù)和標(biāo)簽都分別封裝成Tensor吊履,方便以batch進(jìn)行模型批訓(xùn)練,具體可以參考博客: PyTorch源碼解讀之torch.utils.data.DataLoader

至此调鬓,從圖像和標(biāo)簽文件就生成了Tensor類型的數(shù)據(jù)迭代器艇炎,后續(xù)僅需將Tensor對象用torch.autograd.Variable接口封裝成Variable類型(比如train_data=torch.autograd.Variable(train_data),如果要在gpu上運(yùn)行則是:train_data=torch.autograd.Variable(train_data.cuda()))就可以作為模型的輸入

二袖迎、網(wǎng)絡(luò)構(gòu)建

PyTorch框架中提供了一些方便使用的網(wǎng)絡(luò)結(jié)構(gòu)及預(yù)訓(xùn)練模型接口:torchvision.models冕臭,具體可以看博客:PyTorch源碼解讀之torchvision.models。該接口可以直接導(dǎo)入指定的網(wǎng)絡(luò)結(jié)構(gòu)燕锥,并且可以選擇是否用預(yù)訓(xùn)練模型初始化導(dǎo)入的網(wǎng)絡(luò)結(jié)構(gòu)辜贵。示例如下:

import torchvision
model = torchvision.models.resnet50(pretrained=True)  # 導(dǎo)入resnet50的預(yù)訓(xùn)練模型

那么如何自定義網(wǎng)絡(luò)結(jié)構(gòu)呢?在PyTorch中归形,構(gòu)建網(wǎng)絡(luò)結(jié)構(gòu)的類都是基于torch.nn.Module這個基類進(jìn)行的托慨,也就是說所有網(wǎng)絡(luò)結(jié)構(gòu)的構(gòu)建都可以通過繼承該類來實(shí)現(xiàn),包括torchvision.models接口中的模型實(shí)現(xiàn)類也是繼承這個基類進(jìn)行重寫的暇榴。自定義網(wǎng)絡(luò)結(jié)構(gòu)可以參考:1厚棵、https://github.com/miraclewkf/MobileNetV2-PyTorch蕉世。該項(xiàng)目中的MobileNetV2.py腳本自定義了網(wǎng)絡(luò)結(jié)構(gòu)。2婆硬、https://github.com/miraclewkf/SENet-PyTorch狠轻。該項(xiàng)目中的se_resnet.py和se_resnext.py腳本分別自定義了不同的網(wǎng)絡(luò)結(jié)構(gòu)。

如果要用某預(yù)訓(xùn)練模型為自定義的網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行參數(shù)初始化彬犯,可以用torch.load接口導(dǎo)入預(yù)訓(xùn)練模型忍啤,然后調(diào)用自定義的網(wǎng)絡(luò)結(jié)構(gòu)對象的load_state_dict方式進(jìn)行參數(shù)初始化膝但,具體可以看https://github.com/miraclewkf/MobileNetV2-PyTorch項(xiàng)目中的train.py腳本中if args.resume條件語句(如下所示)。

if args.resume:
  if os.path.isfile(args.resume):
    print(("=> loading checkpoint '{}'".format(args.resume)))
    checkpoint = torch.load(args.resume)
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.state_dict().items())}
    model.load_state_dict(base_dict)
    else:
      print(("=> no checkpoint found at '{}'".format(args.resume)))

三空入、其他設(shè)置

優(yōu)化函數(shù)通過torch.optim包實(shí)現(xiàn)孙蒙,比如torch.optim.SGD()接口表示隨機(jī)梯度下降履腋。更多優(yōu)化函數(shù)可以看官方文檔:http://pytorch.org/docs/0.3.0/optim.html薇缅。

學(xué)習(xí)率策略通過torch.optim.lr_scheduler接口實(shí)現(xiàn)瞬矩,比如torch.optim.lr_scheduler.StepLR()接口表示按指定epoch數(shù)減少學(xué)習(xí)率。更多學(xué)習(xí)率變化策略可以看官方文檔:http://pytorch.org/docs/0.3.0/optim.html炼杖。

損失函數(shù)通過torch.nn包實(shí)現(xiàn)灭返,比如torch.nn.CrossEntropyLoss()接口表示交叉熵等。

多GPU訓(xùn)練通過torch.nn.DataParallel接口實(shí)現(xiàn)坤邪,比如:model = torch.nn.DataParallel(model, device_ids=[0,1])表示在gpu0和1上訓(xùn)練模型婆殿。

模塊解讀

torch.utils.data.DataLoader

將數(shù)據(jù)讀取接口的輸入按照batch size封裝成Tensor,后續(xù)只需要再包裝成Variable即可作為模型的輸入罩扇,因此該接口有承上啟下的作用

源碼地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py

示例:

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4)
                   for x in ['train', 'valid', 'test']}
  • dataset (Dataset): dataset from which to load the data.
  • batch_size (int, optional): how many samples per batch to load (default: 1).
  • shuffle (bool, optional): set to True to have the data reshuffled at every epoch (default: False).
  • num_workers (int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
  • ... ...

從torch.utils.data.DataLoader類生成的對象中取數(shù)據(jù):

train_data=torch.utils.data.DataLoader(...)
for i, (input, target) in enumerate(train_data):
    # ...
    pass

此時,調(diào)用DataLoader類的__iter__方法 ??:

    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

使用隊(duì)列queue對象怕磨,完成多線程調(diào)度喂饥;通過迭代器iter,完成batch更替(詳情讀源碼)

torchvision.transforms

基本上PyTorch中的data augmentation操作都可以通過該接口實(shí)現(xiàn)肠鲫,包含resize员帮、crop等常見的data augmentation操作

示例:

import torchvision
import torch
train_augmentation = torchvision.transforms.Compose([torchvision.transforms.Resize(256),
                                                     torchvision.transforms.RandomCrop(224),                                                                            
                                                     torchvision.transofrms.RandomHorizontalFlip(),
                                                     torchvision.transforms.ToTensor(),
                                                     torch vision.Normalize([0.485, 0.456, -.406],[0.229, 0.224, 0.225])
                                                     ])

class custom_dataread(torch.utils.data.Dataset):  # 數(shù)據(jù)讀取接口
    def __init__():
        ...
    def __getitem__():
        # use self.transform for input image
    def __len__():
        ...

train_loader = torch.utils.data.DataLoader(  # 數(shù)據(jù)迭代器
    custom_dataread(transform=train_augmentation),
    batch_size = batch_size, shuffle = True,
    num_workers = workers, pin_memory = True)

這里定義了resize、crop导饲、normalize等數(shù)據(jù)預(yù)處理操作捞高,并最終作為數(shù)據(jù)讀取類custom_dataread的一個參數(shù)傳入,可以在內(nèi)部方法__getitem__中實(shí)現(xiàn)數(shù)據(jù)增強(qiáng)操作渣锦。

源碼地址:transformas.py --- 定義各種data augmentation的類硝岗、functional.py --- 提供transformas.py中所需功能函數(shù)的實(shí)現(xiàn)

  • Compose類:Composes several transforms together. 對輸入圖像img逐次應(yīng)用輸入的[transform_1, transform_2, ...]操作

  • ToTensor類:Convert a PIL Image or numpy.ndarray to tensor. 要強(qiáng)調(diào)的是在做數(shù)據(jù)歸一化之前必須要把PIL Image轉(zhuǎn)成Tensor,而其他resize或crop操作則不需要.

  • ToPILImage類:Convert a tensor or an ndarray to PIL Image.

  • Normalize類:Normalize a tensor image with mean and standard deviation.一般都會對輸入數(shù)據(jù)做歸一化操作

  • Resize類:Resize the input PIL Image to the given size. 幾乎都要用到袋毙,這里輸入可以是int型檀,此時表示將輸入圖像的短邊resize到這個int數(shù),長邊則根據(jù)對應(yīng)比例調(diào)整听盖,圖像的長寬比不變胀溺。如果輸入是個(h,w)的序列裂七,h和w都是int,則直接將輸入圖像resize到這個(h,w)尺寸仓坞,相當(dāng)于force resize背零,所以一般最后圖像的長寬比會變化,也就是圖像內(nèi)容被拉長或縮短无埃。若輸入是PIL Image徙瓶,則將調(diào)用Image的各種方法;若輸入是Tensor录语,則對應(yīng)函數(shù)基本是在調(diào)用Tensor的各種方法倍啥。

  • CenterCrop類:Crops the given PIL Image at the center. 一般數(shù)據(jù)增強(qiáng)不會采用這個,因?yàn)楫?dāng)size固定的時候澎埠,在相同輸入圖像的情況下虽缕,N次CenterCrop的結(jié)果都是一樣的

  • RandomCrop類:Crop the given PIL Image at a random location. 相較CenterCrop,隨機(jī)裁剪更常用

  • RandomResizedCrop類:Crop the given PIL Image to random size and aspect ratio. 根據(jù)隨機(jī)生成的scale蒲稳、aspect ratio(縮放比例氮趋、長寬比)、中心點(diǎn)裁剪原圖江耀,(為可正常訓(xùn)練)再縮放為輸入的size大小

  • RandomHorizontalFlip類:Horizontally flip the given PIL Image randomly with a given probability. 隨機(jī)的圖像水平翻轉(zhuǎn)剩胁,通俗講就是圖像的左右對調(diào),較常用祥国。 probability of the image being flipped. Default value is 0.5 (水平翻轉(zhuǎn)的概率是0.5)

  • RandomVerticalFlip類:Vertically flip the given PIL Image randomly with a given probability. 隨機(jī)的圖像豎直翻轉(zhuǎn)昵观,通俗講就是圖像的上下對調(diào),較常用舌稀。probability of the image being flipped. Default value is 0.5(豎直翻轉(zhuǎn)的概率是0.5)

  • FiveCrop類:Crop the given PIL Image into four corners and the central crop. 曾在TSN算法的看到過這種用法啊犬。

  • TenCrop類:Crop the given PIL Image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default) 將輸入圖像進(jìn)行水平或豎直翻轉(zhuǎn),然后再進(jìn)行FiveCrop操作壁查;加上原始的FiveCrop操作觉至,這樣一張輸入圖像就能得到10張crop結(jié)果。

  • LinearTransformation類:Transform a tensor image with a square transformation matrix and a mean_vector computed offline. 用一個變換矩陣去乘輸入圖像得到輸出結(jié)果睡腿。

  • ColorJitter類:Randomly change the brightness, contrast, saturation and hue (即亮度语御,對比度,飽和度和色調(diào))of an image席怪,可以根據(jù)注釋來合理設(shè)置這4個參數(shù)应闯。(較常用)

  • RandomRotation類:隨機(jī)旋轉(zhuǎn)輸入圖像,具體參數(shù)可以看注釋何恶,在F.rotate()中主要是調(diào)用PIL Image的rotate方法孽锥。(較常用)

  • Grayscale類:用來將輸入圖像轉(zhuǎn)成灰度圖的,這里根據(jù)參數(shù)num_output_channels的不同有兩種轉(zhuǎn)換方式

  • RandomGrayscale類:Randomly convert image to grayscale with a probability of p (default 0.1).

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市惜辑,隨后出現(xiàn)的幾起案子唬涧,更是在濱河造成了極大的恐慌,老刑警劉巖盛撑,帶你破解...
    沈念sama閱讀 212,718評論 6 492
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件碎节,死亡現(xiàn)場離奇詭異,居然都是意外死亡抵卫,警方通過查閱死者的電腦和手機(jī)狮荔,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,683評論 3 385
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來介粘,“玉大人殖氏,你說我怎么就攤上這事∫霾桑” “怎么了雅采?”我有些...
    開封第一講書人閱讀 158,207評論 0 348
  • 文/不壞的土叔 我叫張陵,是天一觀的道長慨亲。 經(jīng)常有香客問我婚瓜,道長,這世上最難降的妖魔是什么刑棵? 我笑而不...
    開封第一講書人閱讀 56,755評論 1 284
  • 正文 為了忘掉前任巴刻,我火速辦了婚禮,結(jié)果婚禮上蛉签,老公的妹妹穿的比我還像新娘胡陪。我一直安慰自己,他們只是感情好碍舍,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,862評論 6 386
  • 文/花漫 我一把揭開白布督弓。 她就那樣靜靜地躺著,像睡著了一般乒验。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上蒂阱,一...
    開封第一講書人閱讀 50,050評論 1 291
  • 那天锻全,我揣著相機(jī)與錄音,去河邊找鬼录煤。 笑死鳄厌,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的妈踊。 我是一名探鬼主播了嚎,決...
    沈念sama閱讀 39,136評論 3 410
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了歪泳?” 一聲冷哼從身側(cè)響起萝勤,我...
    開封第一講書人閱讀 37,882評論 0 268
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎呐伞,沒想到半個月后敌卓,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,330評論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡伶氢,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,651評論 2 327
  • 正文 我和宋清朗相戀三年趟径,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片癣防。...
    茶點(diǎn)故事閱讀 38,789評論 1 341
  • 序言:一個原本活蹦亂跳的男人離奇死亡蜗巧,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出蕾盯,到底是詐尸還是另有隱情幕屹,我是刑警寧澤,帶...
    沈念sama閱讀 34,477評論 4 333
  • 正文 年R本政府宣布刑枝,位于F島的核電站香嗓,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏装畅。R本人自食惡果不足惜靠娱,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 40,135評論 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望掠兄。 院中可真熱鬧像云,春花似錦、人聲如沸蚂夕。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,864評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽婿牍。三九已至侈贷,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間等脂,已是汗流浹背俏蛮。 一陣腳步聲響...
    開封第一講書人閱讀 32,099評論 1 267
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留上遥,地道東北人搏屑。 一個月前我還...
    沈念sama閱讀 46,598評論 2 362
  • 正文 我出身青樓,卻偏偏與公主長得像粉楚,于是被迫代替她去往敵國和親辣恋。 傳聞我的和親對象是個殘疾皇子亮垫,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,697評論 2 351