一碘勉、寫在前面
PyG 是一款基于PyTorch 的圖神經(jīng)網(wǎng)絡(luò)庫,它提供了很多經(jīng)典的圖神經(jīng)網(wǎng)絡(luò)模型和圖數(shù)據(jù)集栋艳。
在使用 PyG 框架來構(gòu)建和訓(xùn)練圖網(wǎng)絡(luò)模型時(shí)恰聘,需要事先將圖數(shù)據(jù)換成PyG定義的“圖對(duì)象”句各。
PyG 提供多種類型的圖對(duì)象(在torch_geometric.data下)吸占,常用的包括:Data(同構(gòu)圖)和HeteroData(異構(gòu)圖)。
二凿宾、基本用法(以Data對(duì)象為例)
2.1) 構(gòu)建圖對(duì)象
構(gòu)建一張圖的Data對(duì)象時(shí)矾屯,通常需要提供以下基本數(shù)據(jù):
from torch_geometric.data import Data
Data ( x: Optional[torch.Tensor] = None,
?? edge_index: Optional[torch.Tensor] = None,
?? edge_attr: Optional[torch.Tensor] = None,
?? y: Optional[torch.Tensor] = None,
?? pos: Optional[torch.Tensor] = None,
?? **kwargs)
節(jié)點(diǎn):節(jié)點(diǎn)名稱用數(shù)字序號(hào)表示:0,1初厚,2件蚕,... ,num _ node-1(共num_nodes個(gè)節(jié)點(diǎn))产禾,這是默認(rèn)且固定的排作,不需要指定。
x:節(jié)點(diǎn)特征矩陣:shape為[num_nodes, num_node_features]
一張圖的所有節(jié)點(diǎn)的特征存儲(chǔ)于該二維矩陣中亚情,即一行表示一個(gè)節(jié)點(diǎn)妄痪,一列表示一個(gè)特征(一個(gè)節(jié)點(diǎn)可以有多個(gè)特征),行序號(hào)對(duì)應(yīng)節(jié)點(diǎn)序號(hào)(0楞件,1衫生,2,... 土浸,num _ node-1)罪针。edge_index:邊矩陣:shape為[2, num_edges]
一張圖的所有邊存儲(chǔ)于該二維矩陣中,其中第一行表示所有邊的起始節(jié)點(diǎn)編號(hào)黄伊,第二行表示所有邊的目標(biāo)節(jié)點(diǎn)編號(hào)泪酱,類型為 torch.long。
(注意:edge _ index 中的元素必須在{0,1墓阀,2愈腾,... ,num _ node-1})edge_attr:邊特征矩陣:shape為[num_edges, num_edge_features]
一張圖的所有邊的特征存儲(chǔ)于該二維矩陣中岂津,即一行表示一條邊虱黄,一列表示一個(gè)特征(一條邊可以有多個(gè)特征)。y:訓(xùn)練標(biāo)簽:(可能具有任意形狀)吮成。例如橱乱,如果是節(jié)點(diǎn)級(jí)別的標(biāo)簽,其形狀為 [num_nodes, *]粱甫;如果是圖級(jí)別的標(biāo)簽泳叠, 其形狀為為 [1,*]。
節(jié)點(diǎn)位置(pos):記錄每個(gè)節(jié)點(diǎn)的具體位置茶宵,存儲(chǔ)于shape為[num_nodes, num_dimensions]的二維矩陣中危纫。
上述信息通常需要用戶提前準(zhǔn)備好,才能構(gòu)建一個(gè)Data對(duì)象乌庶,但都不是必須要提供的种蝶。一般對(duì)于一張圖而言,最重要的是節(jié)點(diǎn)特征矩陣瞒大、邊矩陣螃征、邊特征矩陣。
Data 對(duì)象有點(diǎn)類似 Python 中的字典透敌,屬性和數(shù)據(jù)用鍵值對(duì)表示盯滚,因此可以用點(diǎn)“.”或方括號(hào)“[]”來訪問、修改酗电、增加其內(nèi)部的數(shù)據(jù)魄藕,就跟字典的操作方式一樣。
2.2) 圖對(duì)象的方法
見‘舉例1’一節(jié)的3.3)
三撵术、舉例1(簡單例子)
目標(biāo):為下圖創(chuàng)建Data對(duì)象:
import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
3.1)圖的原始圖數(shù)據(jù)準(zhǔn)備:
首先將上圖中的圖(Graph)轉(zhuǎn)化成對(duì)應(yīng)的tensor(非常重要背率,決定了后面的圖對(duì)象是否能正確構(gòu)建)。
# 節(jié)點(diǎn)特征矩陣(一行對(duì)應(yīng)一個(gè)節(jié)點(diǎn)的特征荷荤,每個(gè)節(jié)點(diǎn)有3個(gè)特征)
>>my_node_features = torch.tensor([[-1, -1, -1],
[-2, -2, -2],
[-3, -3, -3],
[-4, -4, -4]],dtype=torch.float)
# 邊的節(jié)點(diǎn)對(duì)退渗,共有7條邊(四個(gè)節(jié)點(diǎn):0、1蕴纳、2会油、3),必須用7組節(jié)點(diǎn)對(duì)來表示
>>my_edge_index = torch.tensor([[0, 1, 2, 1, 3, 2, 3],
[1, 2, 1, 3, 1, 3, 2]], dtype=torch.long)
# 邊特征矩陣(一行對(duì)應(yīng)一條邊的特征古毛,每條邊有4個(gè)特征)
>>my_edge_attr = torch.tensor([[11, 11, 11, 11],
[22, 22, 22, 22],
[33, 33, 33, 33],
[44, 44, 44, 44],
[55, 55, 55, 55],
[66, 66, 66, 66],
[77, 77, 77, 77]], dtype=torch.float)
# 邊權(quán)重翻翩,共有7個(gè)邊權(quán)重都许,一條邊一個(gè)
>>my_edge_weight = torch.tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.float)
3.2)根據(jù)圖的原始數(shù)據(jù)構(gòu)建PyG圖對(duì)象(Data對(duì)象):
>>pyg_G = Data(x=my_node_features,
edge_index=my_edge_index,
edge_attr=my_edge_attr,
edge_weight=my_edge_weight)
>>print(pyg_G)
輸出:
Data(x=[4, 3], edge_index=[2, 7], edge_attr=[7, 4], edge_weight=[7])
★對(duì)PyG對(duì)象輸出信息的解讀很重要(特別是對(duì)于無法圖像化的大圖)!
3.3)圖對(duì)象(Data對(duì)象)提供的幾種常用方法(其他方法使用‘dir(圖對(duì)象)’獲壬┒场):
- .num_nodes:返回節(jié)點(diǎn)個(gè)數(shù)(int)
- .num_node_types:返回節(jié)點(diǎn)種類數(shù)(int)
- .num_node_features:返回節(jié)點(diǎn)特征數(shù)(int)
- .node_attrs():返回與節(jié)點(diǎn)相關(guān)的屬性名列表(str list)
>>pyg_G.node_attrs()
['x']
- .x:★返回節(jié)點(diǎn)特征矩陣(tensor array)
>>pyg_G.x
tensor([[-1., -1., -1.],
[-2., -2., -2.],
[-3., -3., -3.],
[-4., -4., -4.]])
.num_edges:返回邊條數(shù)(int)
.num_edge_types:返回邊種類數(shù)(int)
.num_edge_features:返回邊特征數(shù)(int)
.edge_index:★返回邊的節(jié)點(diǎn)對(duì)(tensor array)
>>pyg_G.edge_index
tensor([[0, 1, 2, 1, 3, 2, 3],
[1, 2, 1, 3, 1, 3, 2]])
- edge_attrs():返回與邊相關(guān)的屬性名列表(str list)
pyg_G.edge_attrs()
['edge_weight', 'edge_attr', 'edge_index']
- pyg_G.edge_weight:返回邊權(quán)重(tensor array)
>>pyg_G.edge_weight
tensor([1., 2., 3., 4., 5., 6., 7.])
- .edge_attr:返回邊的特征矩陣(tensor array)
>>pyg_G.edge_attr
tensor([[11., 11., 11., 11.],
[22., 22., 22., 22.],
[33., 33., 33., 33.],
[44., 44., 44., 44.],
[55., 55., 55., 55.],
[66., 66., 66., 66.],
[77., 77., 77., 77.]])
- .edge_stores 和.node_stores:返回存儲(chǔ)了整個(gè)圖的信息(dict list)
>>pyg_G.node_stores
[{'x': tensor([[-1., -1., -1.],
[-2., -2., -2.],
[-3., -3., -3.],
[-4., -4., -4.]]), 'edge_index': tensor([[0, 1, 2, 1, 3, 2, 3],
[1, 2, 1, 3, 1, 3, 2]]), 'edge_attr': tensor([[11., 11., 11., 11.],
[22., 22., 22., 22.],
[33., 33., 33., 33.],
[44., 44., 44., 44.],
[55., 55., 55., 55.],
[66., 66., 66., 66.],
[77., 77., 77., 77.]]), 'edge_weight': tensor([1., 2., 3., 4., 5., 6., 7.])}]
3.4)PyG圖對(duì)象與networkx圖對(duì)象的轉(zhuǎn)換(檢查我們創(chuàng)建的PyG對(duì)象是否與原圖一致)
(https://blog.csdn.net/zzy_NIC/article/details/127996911)
(https://zhuanlan.zhihu.com/p/92482339)
PyG主要用于圖網(wǎng)絡(luò)計(jì)算胶征,本身沒有可視化功能〗胺拢可利用PyG的to_networkx()方法將PyG同構(gòu)圖對(duì)象轉(zhuǎn)化成networkx對(duì)象睛低,然后可視化。
to_networkx(
?? data: PyG的Data或HeteroData對(duì)象,
?? node_attrs: 節(jié)點(diǎn)屬性名(可迭代str對(duì)象服傍,默認(rèn)None),
?? edge_attrs: 邊屬性名(可迭代str對(duì)象钱雷,默認(rèn)None),
?? graph_attrs: 圖屬性名(可迭代str對(duì)象,默認(rèn)None),
?? to_undirected: 轉(zhuǎn)換成無向圖還是有向圖(True/False吹零,默認(rèn)False),
?? remove_self_loops: 是否將圖中的loop移除(True/False罩抗,默認(rèn)False),
)
■■Case1:轉(zhuǎn)換時(shí),不指定 node_attrs灿椅、edge_attrs套蒂、graph_attrs參數(shù)。
從輸出結(jié)果來看茫蛹,這種情況to_networkx()只會(huì)把PyG對(duì)象的節(jié)點(diǎn)(nodes)和邊(edges)轉(zhuǎn)換到networkx對(duì)象中操刀,其他屬性信息不會(huì)包含(下圖中全是空{(diào) })。其次麻惶,從輸出的節(jié)點(diǎn)名馍刮、邊的節(jié)點(diǎn)對(duì)以及圖像來看,與最前面的‘原圖’是相同的窃蹋,說明我們構(gòu)建的PyG是對(duì)的。
# Case1
>>nx_G = to_networkx(data=pyg_G, to_undirected=False) # 將PyG的Data對(duì)象轉(zhuǎn)化成networkx的數(shù)據(jù)對(duì)象
>>print(f'節(jié)點(diǎn)名:{nx_G.nodes}')
>>print(f'邊的節(jié)點(diǎn)對(duì):{nx_G.edges}')
>>print('每個(gè)節(jié)點(diǎn)的屬性:')
# print(nx_G.nodes(data=True))
>>for node in nx_G.nodes(data=True):
print(node)
>>print('每條邊的屬性:')
# print(nx_G.edges(data=True))
>>for edge in nx_G.edges(data=True):
print(edge)
# 畫圖
>>pos = nx.spring_layout(nx_G) # 迭代計(jì)算‘可視化圖片’上每個(gè)節(jié)點(diǎn)的坐標(biāo)
>>nx.draw(nx_G, pos, node_size=800, with_labels=True, font_size=20) # 繪圖
>>plt.show()
輸出:如下圖所示
■■Case2:轉(zhuǎn)換時(shí)静稻,指定 node_attrs警没、edge_attrs、graph_attrs參數(shù)振湾。
這種情況杀迹,首先得查看原PyG對(duì)象有哪些屬性:
>>print(pyg_G.node_attrs())
>>print(pyg_G.edge_attrs())
輸出:
['x']
['edge_weight', 'edge_attr', 'edge_index']
可見,該P(yáng)yG對(duì)象有節(jié)點(diǎn)屬性有['x']押搪,邊屬性有['edge_weight', 'edge_attr', 'edge_index']树酪,
于是可以在to_networkx()轉(zhuǎn)換時(shí)進(jìn)行指定(特別注意:'edge_index'這個(gè)屬性不能寫在to_networkx()的edge_attrs變量中,否則出錯(cuò))大州,見下面代碼:
# Case2
>>nx_G = to_networkx(data=pyg_G,
node_attrs=['x'],
edge_attrs=['edge_weight', 'edge_attr'],
to_undirected=True) # 將PyG的Data對(duì)象轉(zhuǎn)化成networkx的數(shù)據(jù)對(duì)象
>>print(f'節(jié)點(diǎn)名:{nx_G.nodes}')
>>print(f'邊的節(jié)點(diǎn)對(duì):{nx_G.edges}')
>>print('每個(gè)節(jié)點(diǎn)的屬性:')
# print(nx_G.nodes(data=True))
>>for node in nx_G.nodes(data=True):
print(node)
>>print('每條邊的屬性:')
# print(nx_G.edges(data=True))
>>for edge in nx_G.edges(data=True):
print(edge)
# 畫圖
>>pos = nx.spring_layout(nx_G) # 迭代計(jì)算‘可視化圖片’上每個(gè)節(jié)點(diǎn)的坐標(biāo)
>>nx.draw(nx_G, pos, node_size=400, with_labels=True) # 繪圖
>>plt.show()
從上圖的輸出結(jié)果看续语,已經(jīng)把PyG對(duì)象的節(jié)點(diǎn)和邊的各種屬性同時(shí)轉(zhuǎn)化成networkx對(duì)象的屬性了。
四厦画、舉例2(PyG對(duì)象節(jié)點(diǎn)疮茄、邊滥朱、節(jié)點(diǎn)特征、邊特征之間的對(duì)應(yīng)關(guān)系剖析)
import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
4.1)原始圖數(shù)據(jù)準(zhǔn)備
與前面不同的是力试,此例中事先并不知道圖的結(jié)構(gòu)徙邻,只有數(shù)據(jù)。
而且注意:
my_node_features的shape=[5,3]畸裳,即節(jié)點(diǎn)序號(hào)為:0缰犁、1、2怖糊、3民鼓、4;
但邊的節(jié)點(diǎn)對(duì)my_edge_index 指定的節(jié)點(diǎn)為:10蓬抄、11丰嘉、12、13嚷缭。
# 節(jié)點(diǎn)特征矩陣(一行對(duì)應(yīng)一個(gè)節(jié)點(diǎn)的特征饮亏,每個(gè)節(jié)點(diǎn)有3個(gè)特征)
>>my_node_features = torch.tensor([[-1, -1, -1],
[-2, -2, -2],
[-3, -3, -3],
[-4, -4, -4],
[-5, -5, -5]],
dtype=torch.float)
# 邊矩陣(這里共有7條邊,必須用7組節(jié)點(diǎn)對(duì)來表示阅爽,節(jié)點(diǎn)對(duì)的前后位置可以任意調(diào)換路幸,對(duì)結(jié)果沒有影響)
>>my_edge_index = torch.tensor([[10, 11, 12, 11, 13, 13, 12],
[11, 12, 11, 13, 11, 12, 13]], dtype=torch.long)
# 邊特征矩陣(一行對(duì)應(yīng)一條邊的特征,每條邊有4個(gè)特征)
>>my_edge_attr = torch.tensor([[11, 11, 11, 11],
[22, 22, 22, 22],
[33, 33, 33, 33],
[44, 44, 44, 44],
[55, 55, 55, 55],
[66, 66, 66, 66],
[77, 77, 77, 77]], dtype=torch.float)
# 邊權(quán)重付翁,共設(shè)置了7個(gè)邊權(quán)重
>>my_edge_weight = torch.tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.float)
4.2)根據(jù)原始數(shù)據(jù)構(gòu)建PyG圖對(duì)象
>>pyg_G = Data(x=my_node_features,
edge_index=my_edge_index,
edge_attr=my_edge_attr,
edge_weight=my_edge_weight)
>>print(pyg_G)
輸出:
Data(x=[5, 3], edge_index=[2, 7], edge_attr=[7, 4], edge_weight=[7])
從PyG對(duì)象的輸出結(jié)果看简肴,該圖有5個(gè)節(jié)點(diǎn),每個(gè)節(jié)點(diǎn)3個(gè)特征百侧;共有7條邊砰识,每條邊4個(gè)特征,1個(gè)權(quán)重佣渴。
輸出節(jié)點(diǎn)和邊的屬性名列表:
>>print(pyg_G.node_attrs())
>>print(pyg_G.edge_attrs())
輸出:
['x']
['edge_index', 'edge_weight', 'edge_attr']
4.3)將PyG對(duì)象轉(zhuǎn)換成networkx對(duì)象辫狼,并成圖
>>nx_G = to_networkx(data=pyg_G,
node_attrs=['x'],
edge_attrs=['edge_weight', 'edge_attr'],
to_undirected=False) # 將PyG的Data對(duì)象轉(zhuǎn)化成networkx的數(shù)據(jù)對(duì)象
>>print(f'節(jié)點(diǎn)名:{nx_G.nodes}')
>>print(f'邊的節(jié)點(diǎn)對(duì):{nx_G.edges}')
>>print('每個(gè)節(jié)點(diǎn)的屬性:')
# print(nx_G.nodes(data=True))
>>for node in nx_G.nodes(data=True):
print(node)
>>print('每條邊的屬性:')
# print(nx_G.edges(data=True))
>>for edge in nx_G.edges(data=True):
print(edge)
# 畫圖
>>pos = nx.circular_layout(nx_G) # 迭代計(jì)算‘可視化圖片’上每個(gè)節(jié)點(diǎn)的坐標(biāo)
>>nx.draw(nx_G, pos, node_size=800, with_labels=True, font_size=20) # 繪圖
>>plt.show()
Case1:參數(shù)to_undirected=False,即有向圖
從輸出結(jié)果的節(jié)點(diǎn)名來看辛润,該圖共有9個(gè)節(jié)點(diǎn)膨处,前面的[0,1,2,3,4]五個(gè)節(jié)點(diǎn)(注意,代碼中我們并沒有指定這些節(jié)點(diǎn)名)是to_networkx()根據(jù)節(jié)點(diǎn)特征矩陣my_node_features的行數(shù)按0,1,2……順序自動(dòng)分配的(這是PyG固定的)砂竖;后面四個(gè)節(jié)點(diǎn)[10,11,12,13]是to_networkx()根據(jù)用戶給的邊的節(jié)點(diǎn)對(duì)矩陣my_edge_index中自動(dòng)抽取并生成的真椿。
【★★可見,在利用to_networkx()將PyG對(duì)象轉(zhuǎn)換成networkx對(duì)象時(shí)乎澄,to_networkx會(huì)自動(dòng)補(bǔ)充一些節(jié)點(diǎn)突硝,比如這里的[0,1,2,3,4],我們將其稱為冗余節(jié)點(diǎn)三圆!可以寫額外的代碼來將這些冗余節(jié)點(diǎn)刪除狞换,見子圖抽取的‘2.2.5 將冗余節(jié)點(diǎn)從子圖的networkx圖對(duì)象中刪除’】
關(guān)于邊的特征和權(quán)重避咆,PyG會(huì)自動(dòng)將邊特征矩陣my_edge_attr的
第1行作為第1條邊【這里是(10,11)】的特征;
第2行作為第2條邊【這里是(11,12)】的特征修噪;
第3行作為第3條邊【這里是(12,11)】的特征查库;
……
同理,PyG會(huì)自動(dòng)將邊權(quán)重向量my_edge_weight的
第1個(gè)值作為第1條邊【這里是(10,11)】的權(quán)重黄琼;
第2個(gè)值作為第2條邊【這里是(11,12)】的權(quán)重樊销;
第3個(gè)值作為第3條邊【這里是(12,11)】的權(quán)重;
……
特別注意:邊特征矩陣(my_edge_attr)的行數(shù)脏款、邊權(quán)重向量(my_edge_weight)的元素個(gè)數(shù)都必須和邊節(jié)點(diǎn)對(duì)矩陣(my_edge_index )的列數(shù)相同围苫,否則結(jié)果會(huì)出錯(cuò)。
Case2:參數(shù)to_undirected=True撤师,即無向圖
Case2除了邊有所變化以外剂府,其他都與Cas1一樣。
Case2主要為了說明to_networkx()這個(gè)函數(shù)的參數(shù)to_undirected=False/True(有向圖和無向圖)的區(qū)別剃盾。
Cas1是有向圖腺占,根據(jù)給定的節(jié)點(diǎn)對(duì)矩陣my_edge_index從起點(diǎn)到終點(diǎn)畫圖即可,這個(gè)沒啥疑問痒谴。
Cas2是無向圖:
- 如果兩個(gè)節(jié)點(diǎn)之間只有1條邊衰伯,則有向圖和無向圖都用這條邊,比如這里的(10,11)积蔚;
- 如果兩個(gè)節(jié)點(diǎn)之間有2條邊意鲸,則使用小節(jié)點(diǎn)序號(hào)到大節(jié)點(diǎn)序號(hào)的邊作為無向邊,比如這里的(11,12)和(12,11)尽爆,選擇(11,12)作為無向邊怎顾,(11,13)和(13,11)戏罢,選擇(11,13)作為無向邊,(13,12)和(12,13)溶诞,選擇(12,13)作為無向邊驳遵。
參考:
https://zhuanlan.zhihu.com/p/599104296
https://blog.csdn.net/ARPOSPF/article/details/128398393