引言
本文是datawhale開源社區(qū)GNN組隊學(xué)習(xí)的筆記,絕大部分內(nèi)容出自其中困檩。
torch_geometric.data.InMemoryDataset類的使用是讓數(shù)據(jù)可全部儲存于內(nèi)存的數(shù)據(jù)集购啄,這些數(shù)據(jù)集對應(yīng)的數(shù)據(jù)集類在創(chuàng)建對象時就將所有數(shù)據(jù)都加載到內(nèi)存。然而在一些應(yīng)用場景中扇苞,數(shù)據(jù)集規(guī)模超級大,我們很難有足夠大的內(nèi)存完全存下所有數(shù)據(jù)不从。因此需要一個按需加載樣本到內(nèi)存的數(shù)據(jù)集類牙甫。
Dataset基類
在我們將學(xué)習(xí)為一個包含上千萬個圖樣本的數(shù)據(jù)集構(gòu)建一個數(shù)據(jù)集類。
在PyG中虽界,我們通過繼承torch_geometric.data.Dataset
基類來自定義一個按需加載樣本到內(nèi)存的數(shù)據(jù)集類。
繼承torch_geometric.data.InMemoryDataset基類要實現(xiàn)的方法(raw_file_names(), processed_file_names(),download(),process())涛菠,繼承此基類同樣要實現(xiàn)莉御,此外還需要實現(xiàn)以下方法:
- len():返回數(shù)據(jù)集中的樣本的數(shù)量。
- get():實現(xiàn)加載單個圖的操作俗冻。注意:在內(nèi)部礁叔,getitem()返回通過調(diào)用
get()
來獲取Data
對象,并根據(jù)transform
參數(shù)對它們進(jìn)行選擇性轉(zhuǎn)換言疗。
對于無需下載數(shù)據(jù)集原文件的情況晴圾,我們不重寫download方法即可跳過下載。對于無需對數(shù)據(jù)集做預(yù)處理的情況噪奄,我們不重寫process方法即可跳過預(yù)處理死姚。
合并小圖組成大圖
圖可以有任意數(shù)量的節(jié)點和邊,它不是規(guī)整的數(shù)據(jù)結(jié)構(gòu)勤篮,因此對圖數(shù)據(jù)封裝成批的操作與對圖像和序列等數(shù)據(jù)封裝成批的操作不同都毒。PyTorch Geometric中采用的將多個圖封裝成批的方式是,將小圖作為連通組件(connected component)的形式合并碰缔,構(gòu)建一個大圖账劲。于是小圖的鄰接矩陣存儲在大圖鄰接矩陣的對角線上。大圖的鄰接矩陣、屬性矩陣瀑焦、預(yù)測目標(biāo)矩陣分別為:
此方法有以下關(guān)鍵的優(yōu)勢:
- 依靠消息傳遞方案的GNN運(yùn)算不需要被修改腌且,因為消息仍然不能在屬于不同圖的兩個節(jié)點之間交換。
- 沒有額外的計算或內(nèi)存的開銷榛瓮。例如铺董,這個批處理程序的工作完全不需要對節(jié)點或邊緣特征進(jìn)行任何填充。請注意禀晓,鄰接矩陣沒有額外的內(nèi)存開銷精续,因為它們是以稀疏的方式保存的,只保留非零項粹懒,即邊重付。
小圖中的屬性拼接:
將小圖存儲到大圖中時需要對小圖的屬性做一些修改,一個最顯著的例子就是要對節(jié)點序號增值凫乖。在最一般的形式中确垫,PyTorch Geometric的DataLoader
類會自動對edge_index
張量增值,增加的值為當(dāng)前被處理圖的前面的圖的累積節(jié)點數(shù)量拣凹。比方說森爽,現(xiàn)在對第個圖的edge_index
張量做增值恨豁,前面?zhèn)€圖的累積節(jié)點數(shù)量為嚣镜,那么對第個圖的edge_index
張量的增值。增值后橘蜜,對所有圖的edge_index
張量(其形狀為[2, num_edges]
)在第二維中連接起來菊匿。
然而,有一些特殊的場景中(如下所述)计福,基于需求我們希望能修改這一行為跌捆。PyTorch Geometric允許我們通過覆蓋torch_geometric.data.__inc__()
和torch_geometric.data.__cat_dim__()
函數(shù)來實現(xiàn)我們希望的行為。
from torch_geometric.data import Data, DataLoader
import torch
class PairData(Data):
#將兩個圖象颖,一個源圖G_s和一個目標(biāo)圖G_t佩厚,存儲在一個Data類中
def __init__(self, edge_index_s, x_s, edge_index_t, x_t):
super(PairData, self).__init__()
self.edge_index_s = edge_index_s
self.x_s = x_s
self.edge_index_t = edge_index_t
self.x_t = x_t
#c重寫__inc__()兩個連續(xù)的圖的屬性之間的增量大小
def __inc__(self, key, value):
if key == 'edge_index_s':
return self.x_s.size(0)
if key == 'edge_index_t':
return self.x_t.size(0)
else:
return super().__inc__(key, value)
# 定義邊索引矩陣
edge_index_s = torch.tensor([
[0, 0, 0, 0],
[1, 2, 3, 4],
])
#5個節(jié)點,16個特征
x_s = torch.randn(5, 16)
edge_index_t = torch.tensor([
[0, 0, 0],
[1, 2, 3],
])
#4個節(jié)點 16個特征
x_t = torch.randn(4, 16)
data = PairData(edge_index_s, x_s, edge_index_t, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))
print(batch)
#對應(yīng)小圖 未成功將bacth映射成小圖
print(batch.edge_index_s)
print(batch.edge_index_t)
'''
Batch(edge_index_s=[2, 8], edge_index_t=[2, 6], x_s=[10, 16], x_t=[8, 16])
tensor([[0, 0, 0, 0, 5, 5, 5, 5],
[1, 2, 3, 4, 6, 7, 8, 9]])
tensor([[0, 0, 0, 4, 4, 4],
[1, 2, 3, 5, 6, 7]])
'''
# 利用follow_batch屬性
loader = DataLoader(data_list, batch_size=2, follow_batch=['x_s', 'x_t'])
batch = next(iter(loader))
print(batch)
# Batch(edge_index_s=[2, 8], x_s=[10, 16], x_s_batch=[10], edge_index_t=[2, 6], x_t=[8, 16], x_t_batch=[8])
print(batch.x_s_batch)
# tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
print(batch.x_t_batch)
# tensor([0, 0, 0, 0, 1, 1, 1, 1])
二部圖的增值:
一般來說说订,不同類型的節(jié)點數(shù)量不需要一致抄瓦,于是二部圖的鄰接矩陣可能不是平方矩陣,即可能有
陶冷。對二部圖的封裝成批過程中钙姊,
edge_index
中邊的源節(jié)點與目標(biāo)節(jié)點做的增值操作應(yīng)是不同的。
class BipartiteData(Data):
def __init__(self, edge_index, x_s, x_t):
super(BipartiteData, self).__init__()
self.edge_index = edge_index
self.x_s = x_s
self.x_t = x_t
def __inc__(self, key, value):
if key == 'edge_index':
#源埂伦、目標(biāo)節(jié)點增值
return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])
else:
return super().__inc__(key, value)
edge_index = torch.tensor([
[0, 0, 1, 1],
[0, 1, 1, 2],
])
x_s = torch.randn(2, 16) # 2 nodes.
x_t = torch.randn(3, 16) # 3 nodes.
data = BipartiteData(edge_index, x_s, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))
print(batch)
# Batch(edge_index=[2, 8], x_s=[4, 16], x_t=[6, 16])
#邊的源節(jié)點增值為2煞额,目標(biāo)節(jié)點增值為3
print(batch.edge_index)
# tensor([[0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 1, 2, 3, 4, 4, 5]])
超大圖數(shù)據(jù)實踐
import os
import os.path as osp
import pandas as pd
import torch
from ogb.utils import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import download_url, extract_zip
from rdkit import RDLogger
from torch_geometric.data import Data, Dataset
import shutil
from torch_geometric.data import DataLoader
from tqdm import tqdm
RDLogger.DisableLog('rdApp.*')
class MyPCQM4MDataset(Dataset):
def __init__(self, root):
self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip'
super(MyPCQM4MDataset, self).__init__(root)
filepath = osp.join(root, 'raw/data.csv.gz')
data_df = pd.read_csv(filepath)
self.smiles_list = data_df['smiles']
self.homolumogap_list = data_df['homolumogap']
@property
def raw_file_names(self):
return 'data.csv.gz'
def download(self):
path = download_url(self.url, self.root)
extract_zip(path, self.root)
os.unlink(path)
shutil.move(osp.join(self.root, 'pcqm4m_kddcup2021/raw/data.csv.gz'), osp.join(self.root, 'raw/data.csv.gz'))
def len(self):
return len(self.smiles_list)
def get(self, idx):
smiles, homolumogap = self.smiles_list[idx], self.homolumogap_list[idx]
graph = smiles2graph(smiles)
assert(len(graph['edge_feat']) == graph['edge_index'].shape[1])
assert(len(graph['node_feat']) == graph['num_nodes'])
x = torch.from_numpy(graph['node_feat']).to(torch.int64)
edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
y = torch.Tensor([homolumogap])
num_nodes = int(graph['num_nodes'])
data = Data(x, edge_index, edge_attr, y, num_nodes=num_nodes)
return data
# 獲取數(shù)據(jù)集劃分
def get_idx_split(self):
split_dict = replace_numpy_with_torchtensor(torch.load(osp.join(self.root, 'pcqm4m_kddcup2021/split_dict.pt')))
return split_dict
dataset = MyPCQM4MDataset('dataset2')
dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)
for batch in tqdm(dataloader):
pass