Task07 圖預測任務實踐

參考鏈接:https://github.com/datawhalechina/team-learning-nlp/blob/master/GNN

第一部分 超大規(guī)模數(shù)據(jù)集類的創(chuàng)建

當數(shù)據(jù)集規(guī)模超級大時,很難有足夠大的內(nèi)存完全存下所有數(shù)據(jù)。因此需要一個按需加載樣本到內(nèi)存的數(shù)據(jù)集類。

一、Dataset基類

1.1Dataset基類簡介

在PyG中缩功,通過繼承torch_geometric.data.Dataset基類來自定義一個按需加載樣本到內(nèi)存的數(shù)據(jù)集類。
繼承此基類比繼承torch_geometric.data.InMemoryDataset基類要多實現(xiàn)以下方法:

  • len():返回數(shù)據(jù)集中的樣本的數(shù)量。
  • get():實現(xiàn)加載單個圖的操作嗤堰。注意:在內(nèi)部,getitem()返回通過調(diào)用get()來獲取Data對象度宦,并根據(jù)transform參數(shù)對它們進行選擇性轉(zhuǎn)換踢匣。

1.2繼承torch_geometric.data.Dataset基類的代碼實現(xiàn):

import os.path as osp
import torch
from torch_geometric.data import Dataset, download_url

class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.
        path = download_url(url, self.raw_dir)
        ...

    def process(self):
        i = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
            i += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
        return data

1.3其他注意事項

1.download/process步驟可以跳過

  • 對于無需下載數(shù)據(jù)集原文件的情況,不重寫(override)download方法即可跳過下載戈抄。
  • 對于無需對數(shù)據(jù)集做預處理的情況离唬,不重寫process方法即可跳過預處理。

2.有些Dataset類無需定義
如下划鸽,可以不用定義一個Dataset類输莺,而直接生成一個Dataloader對象戚哎,直接用于訓練。

from torch_geometric.data import Data, DataLoader

data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)

二嫂用、圖樣本封裝成批(BATCHING)與DataLoader類

2.1合并小圖組成大圖

PyTorch Geometric中采用的是將多個圖封裝成批的方式型凳,將小圖作為連通組件(connected component)的形式合并,構(gòu)建一個大圖嘱函。于是小圖的鄰接矩陣存儲在大圖鄰接矩陣的對角線上甘畅。

此方法有以下關鍵的優(yōu)勢

  • 依靠消息傳遞方案的GNN運算不需要被修改。
  • 沒有額外的計算或內(nèi)存的開銷往弓。

通過torch_geometric.data.DataLoader類疏唾,多個小圖被封裝成一個大圖。torch_geometric.data.DataLoader是PyTorch的DataLoader的子類亮航,覆蓋了collate()函數(shù)荸实,該函數(shù)定義了一列表的樣本是如何封裝成批的。因此缴淋,所有可以傳遞給PyTorch DataLoader的參數(shù)也可以傳遞給PyTorch Geometric的 DataLoader准给。

2.2小圖的屬性增值與拼接

將小圖存儲到大圖中時需要對小圖的屬性做一些修改,一個最顯著的例子就是要對節(jié)點序號增值重抖。在最一般的形式中露氮,PyTorch Geometric的DataLoader類會自動對edge_index張量增值,增加的值為當前被處理圖的前面的圖的累積節(jié)點數(shù)量钟沛。增值后畔规,對所有圖的edge_index張量(其形狀為[2, num_edges])在第二維中連接起來。

2.2.1圖的匹配(Pairs of Graphs)

不同類型的節(jié)點數(shù)量不一致恨统,edge_index邊的源節(jié)點與目標節(jié)點進行增值操作不同叁扫。

2.2.2二部圖(Bipartite Graphs)

二部圖是圖論中的一種特殊模型。設G=(V,E)是一個無向圖畜埋,如果頂點V可分割為兩個互不相交的子集(A,B)莫绣,并且圖中的每條邊(i,j)所關聯(lián)的兩個頂點i和j分別屬于這兩個不同的頂點集(i in A,j in B)悠鞍,則稱圖G為一個二部圖对室。它的鄰接矩陣定義兩種類型的節(jié)點之間的連接關系。一般來說咖祭,不同類型的節(jié)點數(shù)量不需要一致掩宜,于是二部圖的鄰接矩陣A \in {0,1}^{N \times M}可能為平方矩陣,即可能有N \neq M么翰。

2.2.3在新的維度上做拼接

有時牺汤,Data對象的屬性需要在一個新的維度上做拼接(如經(jīng)典的封裝成批),例如浩嫌,圖級別屬性或預測目標慧瘤。具體來說戴已,形狀為[num_features]的屬性列表應該被返回為[num_examples, num_features],而不是[num_examples * num_features]锅减。PyTorch Geometric通過在__cat_dim__()中返回一個None的連接維度來實現(xiàn)這一點。

 class MyData(Data):
     def __cat_dim__(self, key, item):
         if key == 'foo':
             return None
         else:
             return super().__cat_dim__(key, item)

edge_index = torch.tensor([
   [0, 1, 1, 2],
   [1, 0, 2, 1],
])
foo = torch.randn(16)

data = MyData(edge_index=edge_index, foo=foo)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
# Batch(edge_index=[2, 8], foo=[2, 16])

正如期望的伐坏,batch.foo現(xiàn)在由兩個維度來表示怔匣,一個批維度,一個特征維度桦沉。

三每瞒、創(chuàng)建超大規(guī)模數(shù)據(jù)集類實踐

PCQM4M-LSC是一個分子圖的量子特性回歸數(shù)據(jù)集,它包含了3,803,453個圖纯露。
定義的數(shù)據(jù)集類如下:

import os
import os.path as osp

import pandas as pd
import torch
from ogb.utils import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import download_url, extract_zip
from rdkit import RDLogger
from torch_geometric.data import Data, Dataset
import shutil

RDLogger.DisableLog('rdApp.*')

class MyPCQM4MDataset(Dataset):

    def __init__(self, root):
        self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip'
        super(MyPCQM4MDataset, self).__init__(root)

        filepath = osp.join(root, 'raw/data.csv.gz')
        data_df = pd.read_csv(filepath)
        self.smiles_list = data_df['smiles']
        self.homolumogap_list = data_df['homolumogap']

    @property
    def raw_file_names(self):
        return 'data.csv.gz'

    def download(self):
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        os.unlink(path)
        shutil.move(osp.join(self.root, 'pcqm4m_kddcup2021/raw/data.csv.gz'), osp.join(self.root, 'raw/data.csv.gz'))

    def len(self):
        return len(self.smiles_list)

    def get(self, idx):
        smiles, homolumogap = self.smiles_list[idx], self.homolumogap_list[idx]
        graph = smiles2graph(smiles)
        assert(len(graph['edge_feat']) == graph['edge_index'].shape[1])
        assert(len(graph['node_feat']) == graph['num_nodes'])

        x = torch.from_numpy(graph['node_feat']).to(torch.int64)
        edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
        edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
        y = torch.Tensor([homolumogap])
        num_nodes = int(graph['num_nodes'])
        data = Data(x, edge_index, edge_attr, y, num_nodes=num_nodes)
        return data

    # 獲取數(shù)據(jù)集劃分
    def get_idx_split(self):
        split_dict = replace_numpy_with_torchtensor(torch.load(osp.join(self.root, 'pcqm4m_kddcup2021/split_dict.pt')))
        return split_dict

if __name__ == "__main__":
    dataset = MyPCQM4MDataset('dataset2')
    from torch_geometric.data import DataLoader
    from tqdm import tqdm
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)
    for batch in tqdm(dataloader):
        pass

在生成一個該數(shù)據(jù)集類的對象時剿骨,程序

  • 首先會檢查指定的文件夾下是否存在data.csv.gz文件,如果不在埠褪,則會執(zhí)行download方法浓利,這一過程是在運行super類的__init__方法中發(fā)生的。
  • 然后程序繼續(xù)執(zhí)行__init__方法的剩余部分钞速,讀取data.csv.gz文件贷掖,獲取存儲圖信息的smiles格式的字符串,以及回歸預測的目標homolumogap渴语。由smiles格式的字符串轉(zhuǎn)成圖的過程在get()方法中實現(xiàn)苹威,這樣在生成一個DataLoader變量時,通過指定num_workers可以實現(xiàn)并行執(zhí)行生成多個圖驾凶。

第二部分 圖預測任務實踐

1.通過試驗尋找最佳超參數(shù)

通過運行以下的命令即可運行一次試驗:

#!/bin/sh

python main.py  --task_name GINGraphPooling\    # 為當前試驗取名
                --device 0\                     
                --num_layers 5\                 # 使用GINConv層數(shù)
                --graph_pooling sum\            # 圖讀出方法
                --emb_dim 256\                  # 節(jié)點嵌入維度
                --drop_ratio 0.\
                --save_test\                    # 是否對測試集做預測并保留預測結(jié)果
                --batch_size 512\
                --epochs 100\
                --weight_decay 0.00001\
                --early_stop 10\                # 當有`early_stop`個epoches驗證集結(jié)果沒有提升牙甫,則停止訓練
                --num_workers 4\
                --dataset_root dataset          # 存放數(shù)據(jù)集的根目錄

試驗運行開始后,程序會在saves目錄下創(chuàng)建一個task_name參數(shù)指定名稱的文件夾用于記錄試驗過程调违,當saves目錄下已經(jīng)有一個同名的文件夾時窟哺,程序會在task_name參數(shù)末尾增加一個后綴作為文件夾名稱。試驗運行過程中翰萨,所有的print輸出都會寫入到試驗文件夾下的output文件脏答,tensorboard.SummaryWriter記錄的信息也存儲在試驗文件夾下的文件中。

修改上方的命令再執(zhí)行亩鬼,即可試驗不同的超參數(shù)殖告,所有試驗的過程與結(jié)果信息都存儲于saves文件夾下。啟動TensorBoard會話雳锋,選擇saves文件夾黄绩,即可查看所有試驗的過程與結(jié)果信息。

2.總結(jié)

在此圖預測任務實踐中:

  • 此次將前面所學的基于GIN的圖表示學習神經(jīng)網(wǎng)絡和超大規(guī)模數(shù)據(jù)集類的創(chuàng)建方法付諸于實際應用玷过;
  • 構(gòu)建了一種很方便的設置不同參數(shù)進行試驗的方法爽丹,不同試驗的過程與結(jié)果信息通過簡單的操作即可進行比較分析筑煮。
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市粤蝎,隨后出現(xiàn)的幾起案子真仲,更是在濱河造成了極大的恐慌,老刑警劉巖初澎,帶你破解...
    沈念sama閱讀 218,858評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件秸应,死亡現(xiàn)場離奇詭異,居然都是意外死亡碑宴,警方通過查閱死者的電腦和手機软啼,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,372評論 3 395
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來延柠,“玉大人祸挪,你說我怎么就攤上這事≌昙洌” “怎么了贿条?”我有些...
    開封第一講書人閱讀 165,282評論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長榜跌。 經(jīng)常有香客問我闪唆,道長,這世上最難降的妖魔是什么钓葫? 我笑而不...
    開封第一講書人閱讀 58,842評論 1 295
  • 正文 為了忘掉前任悄蕾,我火速辦了婚禮,結(jié)果婚禮上础浮,老公的妹妹穿的比我還像新娘帆调。我一直安慰自己,他們只是感情好豆同,可當我...
    茶點故事閱讀 67,857評論 6 392
  • 文/花漫 我一把揭開白布番刊。 她就那樣靜靜地躺著,像睡著了一般影锈。 火紅的嫁衣襯著肌膚如雪芹务。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,679評論 1 305
  • 那天鸭廷,我揣著相機與錄音枣抱,去河邊找鬼。 笑死辆床,一個胖子當著我的面吹牛佳晶,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播讼载,決...
    沈念sama閱讀 40,406評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼轿秧,長吁一口氣:“原來是場噩夢啊……” “哼中跌!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起菇篡,我...
    開封第一講書人閱讀 39,311評論 0 276
  • 序言:老撾萬榮一對情侶失蹤漩符,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后驱还,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體陨仅,經(jīng)...
    沈念sama閱讀 45,767評論 1 315
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,945評論 3 336
  • 正文 我和宋清朗相戀三年铝侵,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片触徐。...
    茶點故事閱讀 40,090評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡咪鲜,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出撞鹉,到底是詐尸還是另有隱情疟丙,我是刑警寧澤,帶...
    沈念sama閱讀 35,785評論 5 346
  • 正文 年R本政府宣布鸟雏,位于F島的核電站享郊,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏孝鹊。R本人自食惡果不足惜炊琉,卻給世界環(huán)境...
    茶點故事閱讀 41,420評論 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望又活。 院中可真熱鬧苔咪,春花似錦、人聲如沸柳骄。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,988評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽耐薯。三九已至舔清,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間曲初,已是汗流浹背体谒。 一陣腳步聲響...
    開封第一講書人閱讀 33,101評論 1 271
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留复斥,地道東北人营密。 一個月前我還...
    沈念sama閱讀 48,298評論 3 372
  • 正文 我出身青樓,卻偏偏與公主長得像目锭,于是被迫代替她去往敵國和親评汰。 傳聞我的和親對象是個殘疾皇子纷捞,可洞房花燭夜當晚...
    茶點故事閱讀 45,033評論 2 355

推薦閱讀更多精彩內(nèi)容