當(dāng)需要從一個(gè)圖中抽取某一個(gè)或一系列節(jié)點(diǎn)(目標(biāo)節(jié)點(diǎn))周?chē)膋級(jí)連續(xù)節(jié)點(diǎn)時(shí),可以用PyG的 k_hop_subgraph() 方法荧恍。
k=1:與目標(biāo)節(jié)點(diǎn)直接相連的節(jié)點(diǎn)瓷叫;
k=2:與目標(biāo)節(jié)點(diǎn)隔1個(gè)節(jié)點(diǎn)相連的節(jié)點(diǎn);
……
一块饺、構(gòu)建原始大圖
第一部分內(nèi)容純粹是為了準(zhǔn)備原始大圖(相關(guān)內(nèi)容見(jiàn):PyG構(gòu)建圖對(duì)象并轉(zhuǎn)換成networkx圖對(duì)象)赞辩,與本文的核心內(nèi)容無(wú)關(guān)。
1.1 原始數(shù)據(jù)準(zhǔn)備
import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_networkx, k_hop_subgraph
import matplotlib.pyplot as plt
# 節(jié)點(diǎn)特征矩陣(一行對(duì)應(yīng)一個(gè)節(jié)點(diǎn)的特征,共7個(gè)節(jié)點(diǎn)(=節(jié)點(diǎn)特征矩陣的行數(shù))授艰,每個(gè)節(jié)點(diǎn)有3個(gè)特征)
my_node_features = torch.tensor([[0, 0, 0],
[-1, -1, -1],
[-2, -2, -2],
[-3, -3, -3],
[-4, -4, -4],
[-5, -5, -5],
[-6, -6, -6]],dtype=torch.float)
# 邊的節(jié)點(diǎn)對(duì)辨嗽,共有6條邊(7個(gè)節(jié)點(diǎn):0、1淮腾、2糟需、3屉佳、4、5洲押、6武花,與節(jié)點(diǎn)特征矩陣的行標(biāo)一一對(duì)應(yīng))
my_edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
[2, 2, 4, 4, 6, 6]])
# 邊特征矩陣(一行對(duì)應(yīng)一條邊的特征,每條邊有4個(gè)特征)
my_edge_attr = torch.tensor([[11, 11, 11, 11],
[22, 22, 22, 22],
[33, 33, 33, 33],
[44, 44, 44, 44],
[55, 55, 55, 55],
[66, 66, 66, 66]], dtype=torch.float)
# 邊權(quán)重杈帐,共有6個(gè)邊權(quán)重体箕,一條邊一個(gè)
my_edge_weight = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float)
1.2 根據(jù)原始數(shù)據(jù)構(gòu)建PyG對(duì)象
# 構(gòu)建Pyg對(duì)象
pyg_G = Data(x=my_node_features,
edge_index=my_edge_index,
edge_attr=my_edge_attr,
edge_weight=my_edge_weight)
print(pyg_G)
輸出:Data(x=[7, 3], edge_index=[2, 6], edge_attr=[6, 4], edge_weight=[6])
# 輸出信息,為to_networkx()的參數(shù)提供參考
print(pyg_G.node_attrs())
print(pyg_G.edge_attrs())
輸出:
['x']
['edge_attr', 'edge_index', 'edge_weight']
1.3 將PyG對(duì)象轉(zhuǎn)化成networkx對(duì)象挑童,用于成圖
這里需要注意的是如果原始數(shù)據(jù)準(zhǔn)備不恰當(dāng)累铅,可能會(huì)導(dǎo)致to_networkx()將PyG對(duì)象轉(zhuǎn)化成networkx對(duì)象后,多出來(lái)一些節(jié)點(diǎn)站叼,詳見(jiàn):PyG構(gòu)建圖對(duì)象并轉(zhuǎn)換成networkx圖對(duì)象
# PyG對(duì)象轉(zhuǎn)化成networkx對(duì)象
nx_G = to_networkx(data=pyg_G,
node_attrs=['x'],
edge_attrs=['edge_weight', 'edge_attr'],
to_undirected=False) # 將PyG的Data對(duì)象轉(zhuǎn)化成networkx的數(shù)據(jù)對(duì)象
print(f'節(jié)點(diǎn)名:{nx_G.nodes}')
print(f'邊的節(jié)點(diǎn)對(duì):{nx_G.edges}')
print('每個(gè)節(jié)點(diǎn)的屬性:')
# print(nx_G.nodes(data=True))
for node in nx_G.nodes(data=True):
print(node)
print('每條邊的屬性:')
# print(nx_G.edges(data=True))
for edge in nx_G.edges(data=True):
print(edge)
# 畫(huà)圖
plt.figure(figsize=(4, 4))
pos = nx.spring_layout(nx_G) # 迭代計(jì)算‘可視化圖片’上每個(gè)節(jié)點(diǎn)的坐標(biāo)
nx.draw(nx_G, pos, node_size=400, with_labels=True) # 繪圖
plt.show()
二娃兽、k_hop_subgraph()抽取子圖
2.1 方法解釋
result = k_hop_subgraph(
node_idx=, 目標(biāo)節(jié)點(diǎn)(int,或list int)尽楔;
num_hops=, 待獲取的目標(biāo)節(jié)點(diǎn)的幾級(jí)周?chē)?jié)點(diǎn)(int)投储;
edge_index=, 原圖的邊節(jié)點(diǎn)對(duì)矩陣(tensor),其shape=[2, 邊數(shù)]阔馋;
relabel_nodes=, 是否對(duì)獲取的節(jié)點(diǎn)從0開(kāi)始重新順序編號(hào)(True/False)(要注意★★)玛荞;
flow=, 根據(jù)邊的方向選擇節(jié)點(diǎn)(str = target_to_source(目標(biāo)節(jié)點(diǎn)到其他節(jié)點(diǎn))或source_to_target(其他節(jié)點(diǎn)到目標(biāo)節(jié)點(diǎn))(要注意★★);
directed=, 如果=False垦缅,將包括所有所有采樣節(jié)點(diǎn)之間的邊冲泥。(默認(rèn)=True)
)
k_hop_subgraph()的返回值result是一個(gè)包含四個(gè)元素的tuple:
result[0]:抽取出來(lái)的節(jié)點(diǎn)(包括目標(biāo)節(jié)點(diǎn))list,已經(jīng)按照從小到大順序排列好了壁涎;
result[1]:抽取的節(jié)點(diǎn)的邊對(duì)凡恍,是個(gè)shape=[2, 抽取的邊的條數(shù)]的tensor;
result[2]:每個(gè)目標(biāo)節(jié)點(diǎn)在result[0]中的位置怔球,是一個(gè)長(zhǎng)度與目標(biāo)節(jié)點(diǎn)個(gè)數(shù)相同的一維tensor嚼酝;
result[3]:抽取的每條邊在原圖邊對(duì)矩陣中的位置,一個(gè)由True和False組成的list竟坛,長(zhǎng)度等原圖邊對(duì)矩陣的列數(shù)闽巩。
2.2 抽取子圖并繪圖
下面看具體例子(緊接著前面代碼):
2.2.1 抽取子圖信息
我們希望找到6號(hào)節(jié)點(diǎn)周?chē)膋=2的節(jié)點(diǎn)(即找到6號(hào)節(jié)點(diǎn)的1、2級(jí)節(jié)點(diǎn))担汤。
設(shè)置:relabel_nodes=False涎跨,不對(duì)找到的節(jié)點(diǎn)重新命名;
設(shè)置:flow='source_to_target')崭歧,只要求邊是指向目標(biāo)節(jié)點(diǎn)的節(jié)點(diǎn)隅很;
(上面這兩個(gè)參數(shù)的設(shè)置對(duì)結(jié)果影響很大,特別注意率碾。)
target_node_idx = [6] # 確定目標(biāo)節(jié)點(diǎn)序列
k = 2 # 目標(biāo)節(jié)點(diǎn)往周?chē)S的次數(shù)(即幾級(jí)節(jié)點(diǎn))
# 設(shè)置重要參數(shù)
relabel_nodes=False
flow='source_to_target'
# 抽取節(jié)點(diǎn)
result = k_hop_subgraph(node_idx=target_node_idx,
num_hops=k,
edge_index=pyg_G.edge_index,
relabel_nodes=relabel_nodes,
flow=flow,
directed=False)
sub_nodes_names = result[0]
sub_edge_index = result[1]
target_node_map = result[2]
sub_edge_mask = result[3]
print(f'抽取的節(jié)點(diǎn)序列:{sub_nodes_names}')
print(f'抽取的邊節(jié)點(diǎn)對(duì):{sub_edge_index}')
print(f'目標(biāo)節(jié)點(diǎn)在抽取節(jié)點(diǎn)序列中的位置:{target_node_map}')
print(f'選中的邊在原圖的邊序列中的位置:{sub_edge_mask}')
print(f'抽取目標(biāo)節(jié)點(diǎn):{sub_nodes_names[target_node_map]}')
從上述‘輸出結(jié)果’看relabel_nodes=False時(shí)叔营,和‘我們的目標(biāo)’是完全一致的屋彪。
relabel_nodes=True時(shí),只有邊的節(jié)點(diǎn)對(duì)的序號(hào)被從0開(kāi)始重新命名了绒尊,其他沒(méi)變畜挥。
2.2.2 計(jì)算抽取邊在原圖邊中的序號(hào)
這一步是為了從原圖的一些其他數(shù)據(jù)中抽取跟子圖相關(guān)的數(shù)據(jù)。
# 計(jì)算抽取的邊對(duì)矩陣sub_edge_index的每一條邊對(duì)在原邊對(duì)矩陣my_edge_index中的位置序號(hào)
match_indices_list = []
for row in sub_edge_index.t():
res = torch.where(torch.all(torch.isin(my_edge_index.t(), row), dim=1))
if res[0].numel()!=0:
print(res)
match_indices_list.append(res[0].item())
#match_indices_list
2.2.3 構(gòu)建子圖PyG對(duì)象
# 創(chuàng)建子圖的 Data 對(duì)象
sub_pyg_G = Data(x=my_node_features[sub_nodes_names,:], # 在原節(jié)點(diǎn)特征矩陣中抽取被選中的節(jié)點(diǎn)的特征
edge_index=sub_edge_index, # 在原邊節(jié)點(diǎn)對(duì)矩陣中抽取被選中的邊節(jié)點(diǎn)對(duì)
edge_attr=my_edge_attr[match_indices_list,:], # 在原邊特征矩陣中抽取被選中的邊特征
edge_weight=my_edge_weight[match_indices_list]) # 在原邊權(quán)重矩陣中抽取被選中的邊權(quán)重
sub_pyg_G
輸出:
Data(x=[3, 3], edge_index=[2, 2], edge_attr=[2, 4], edge_weight=[2])
# 輸出信息作為to_networkx()設(shè)置參數(shù)時(shí)的參考婴谱。
print(sub_pyg_G.node_attrs())
print(sub_pyg_G.edge_attrs())
輸出:
['x']
['edge_attr', 'edge_index', 'edge_weight']
2.2.4 將子圖的PyG對(duì)象轉(zhuǎn)換為networkx對(duì)象
# 將PyG對(duì)象轉(zhuǎn)換為networkx對(duì)象
sub_nx_G = to_networkx(data=sub_pyg_G,
node_attrs=['x'],
edge_attrs=['edge_attr', 'edge_weight'],
to_undirected=False)
2.2.5 將冗余節(jié)點(diǎn)從子圖的networkx圖對(duì)象中刪除
這一步是需要的蟹但,這里我們抽取的子圖節(jié)點(diǎn)為‘抽取的節(jié)點(diǎn)序列:tensor([2, 3, 4, 5, 6])’,顯然是不包括0和1號(hào)節(jié)點(diǎn)勘究,但因?yàn)閠o_networkx()方法本身的一些原因矮湘,會(huì)在轉(zhuǎn)化時(shí)把0和1號(hào)節(jié)點(diǎn)也加上(注意轉(zhuǎn)化時(shí)加上的0和1號(hào)節(jié)點(diǎn)與原圖中的0和1號(hào)節(jié)點(diǎn)是完全不同的斟冕,只是名稱(chēng)相同)口糕,稱(chēng)其為冗余節(jié)點(diǎn)。這樣一來(lái)磕蛇,就會(huì)導(dǎo)致networkx對(duì)象的節(jié)點(diǎn)數(shù)比PyG對(duì)象的節(jié)點(diǎn)數(shù)多景描。因此冗余節(jié)點(diǎn)需要?jiǎng)h除。
如果k_hop_subgraph()函數(shù)的 relabel_nodes=Ture秀撇,則無(wú)論被選取的節(jié)點(diǎn)是什么超棺,都會(huì)被從0開(kāi)始重新命名,這時(shí)就不會(huì)出現(xiàn)冗余節(jié)點(diǎn)呵燕,如果執(zhí)行下述代碼棠绘,反而會(huì)刪除有效節(jié)點(diǎn)。
(注意:這一步不是必須的再扭,只有存在冗余節(jié)點(diǎn)時(shí)需要執(zhí)行氧苍。當(dāng)。)
# 將沒(méi)有被選中的節(jié)點(diǎn)從networkx圖對(duì)象中刪除
if not relabel_nodes:
nodes_to_remove = list(set(list(sub_nx_G.nodes)) - set(sub_nodes_names.tolist()))
sub_nx_G.remove_nodes_from(nodes_to_remove)
2.2.6 子圖繪圖
plt.figure(figsize=(4, 4))
pos = nx.spring_layout(sub_nx_G) # 定義節(jié)點(diǎn)的布局
nx.draw(sub_nx_G, pos, with_labels=True, node_color='red', edge_color="green", node_size=100, font_size=10)