<span style="font-size:16px"><div class="image-package"><img src="https://upload-images.jianshu.io/upload_images/26011021-4ae807cb0dd8cf89.jpeg" class="uploaded-img" style="min-height:200px;min-width:200px;" width="auto" height="auto"/>
</div><p>This article first gives a brief overview of the GNN link prediction task, then implements a GNN for link prediction with the Deep Graph Library (DGL) and explains the code in detail. For the complete model code, follow the WeChat account [AI机器学习与知识图谱] and reply: <strong>DGL第二讲完整代码</strong></p></span><strong><font size="4">1. Overview of GNN Link Prediction</font></strong><div><p style="text-indent:0pt"><span style="font-size:16px">The GNN link prediction task is to predict whether an edge exists between two nodes in a graph. Link prediction is needed in applications such as Social Recommendation and Knowledge Graph Completion. In implementation, link prediction is treated as a binary classification task:</span></p><p style="text-indent:0pt"><span style="font-size:16px">1. Take the edges that exist in the graph as positive samples; 2. Negatively sample edges that do not exist in the graph as negative samples; 3. Merge the positive and negative samples and split them into a training set and a test set; 4. Evaluate the model with binary classification metrics, e.g. the AUC value.</span></p><p style="text-indent:0pt"><span style="font-size:16px">In some scenarios, such as large-scale recommender systems or information retrieval, the model is judged on the accuracy of its top-k predictions, so link prediction also needs ranking metrics: 1. MR (Mean Rank); 2. MRR (Mean Reciprocal Rank); 3. Hit@n.</span></p><p style="text-indent:0pt"><span style="font-size:16px">Meaning of MR, MRR and Hit@n: suppose the knowledge graph contains n entities. Before evaluating, do the following: (1) take a correct triple (h, r, t) and replace its head entity h or tail entity t with every other entity in the graph in turn, producing n triples; (2) compute the energy value of each of these n triples, e.g. the value of ||h + r - t|| in TransE, so that each of the n triples has its own energy value; (3) sort the n triples by energy value in ascending order and record the rank of each triple; (4) repeat the three steps above for every correct triple.</span></p><p style="text-indent:0pt"><strong><span style="font-size:16px">MR:</span></strong><span style="font-size:16px"> the average rank of all correct triples in the graph;</span></p><p style="text-indent:0pt"><strong><span style="font-size:16px">MRR:</span></strong><span style="font-size:16px"> the average reciprocal rank of all correct triples in the graph;</span></p><p style="text-indent:0pt"><strong><span style="font-size:16px">Hit@n:</span></strong><span style="font-size:16px"> the proportion of correct triples whose rank is no greater than n.</span></p><p style="text-indent:0pt"><span style="font-size:16px">Hence, for link prediction, the smaller the MR the better the model, while the larger the MRR and Hit@n the better the model. In the rest of this article we work on the Cora citation dataset and predict whether a citation relation exists between two papers.</span></p>
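The three ranking metrics just defined can be computed directly from the recorded ranks; the sketch below is an illustrative numpy helper (the function name and toy ranks are mine, not the post's):

```python
import numpy as np

def ranking_metrics(ranks, n=10):
    """MR, MRR and Hit@n from the 1-based ranks of correct triples."""
    ranks = np.asarray(ranks, dtype=float)
    mr = ranks.mean()           # Mean Rank: smaller is better
    mrr = (1.0 / ranks).mean()  # Mean Reciprocal Rank: larger is better
    hit = (ranks <= n).mean()   # Hit@n: larger is better
    return mr, mrr, hit

# Toy example: four correct triples ranked 1st, 3rd, 12th and 4th
mr, mrr, hit10 = ranking_metrics([1, 3, 12, 4], n=10)
print(mr, mrr, hit10)
```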
<span style="font-size:16px"><span style="font-size:20px"><strong>2. Implementing GNN Link Prediction</strong></span></span><p style="text-indent:0pt"><span style="font-size:16px">Next we implement a GNN model for link prediction with the DGL framework and explain the code in detail. First, as shown below, load the required dgl and pytorch libraries:</span></p>import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp<p style="text-indent:0pt"><strong>Data loading</strong><span>: the code below loads the Cora data object provided by dgl. A Dataset in dgl may contain multiple graphs, so the loaded dataset object is a list in which each element is one graph. The Cora dataset consists of a single graph, so we take it out directly with dataset[0].</span>
</p>import dgl.data
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]<p style="text-indent:0pt"><strong><span style="font-size:16px">Positive/negative split</span></strong><span style="font-size:16px">: randomly take 10% of the edges in the dataset as positive test samples and the remaining 90% as positive training samples; then, for both the training and the test set, negatively sample the same number of non-existent edges, so that positives and negatives are balanced 1:1 in each set.</span></p>
u, v = g.edges()
eids = np.arange(g.number_of_edges())
eids = np.random.permutation(eids)
test_size = int(len(eids) * 0.1)
train_size = g.number_of_edges() - test_size
test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]
# Sample all negative examples and split them into training and test sets
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
adj_neg = 1 - adj.todense() - np.eye(g.number_of_nodes())
neg_u, neg_v = np.where(adj_neg != 0)
neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]<p style="text-indent:0pt"><strong><span style="font-size:16px">Note:</span></strong><span style="font-size:16px"> during training, the 10% of edges that belong to the test set must be removed from the graph to prevent data leakage; this is done with dgl.remove_edges.</span></p>train_g = dgl.remove_edges(g, eids[:test_size])
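The complement-of-adjacency trick used above for negative sampling can be checked on a toy graph; the sketch below is a standalone illustration with made-up edges (not part of the post) that enumerates every candidate negative edge of a 4-node graph:

```python
import numpy as np
import scipy.sparse as sp

# Toy directed graph with 4 nodes and edges 0->1, 1->2, 2->3
u = np.array([0, 1, 2])
v = np.array([1, 2, 3])
num_nodes = 4

# Build the adjacency matrix, then take its complement minus self-loops:
# entries equal to 1 mark node pairs with no edge and distinct endpoints
adj = sp.coo_matrix((np.ones(len(u)), (u, v)), shape=(num_nodes, num_nodes))
adj_neg = 1 - adj.todense() - np.eye(num_nodes)
neg_u, neg_v = np.where(adj_neg != 0)

# Every (neg_u, neg_v) pair is a candidate negative sample
pairs = list(zip(neg_u.tolist(), neg_v.tolist()))
print(pairs)
```

Note that the complement is a dense N×N matrix, which is fine for Cora (about 2.7k nodes) but would not scale to very large graphs, where on-the-fly negative sampling is preferable.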
from dgl.nn import SAGEConv
# Define a two-layer GraphSAGE model
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, 'mean')
        self.conv2 = SAGEConv(h_feats, h_feats, 'mean')

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h<p><span style="font-size:16px"><strong>
</strong></span></p><p><span style="font-size:16px"><strong>Scoring function for node pairs</strong>: the model defines a function that scores a pair of node representations, which measures how likely an edge is to exist between the two nodes. In GNN node classification the model learns a representation for a single node, whereas in link prediction it scores pairs of nodes.</span>
</p><p style="text-indent:0pt">
</p><p style="text-indent:0pt"><strong><span style="font-size:16px">Note:</span></strong><span style="font-size:16px"> before giving the scoring function for node pairs, it helps to understand how DGL frames this: a set of node pairs is itself treated as a graph, since an edge describes exactly one pair of nodes. In link prediction we therefore obtain a positive graph containing all positive examples as edges, and a negative graph containing all negative examples as edges. Both contain the same node set as the original graph, which makes it easy to pass node features across graphs for computation: node representations computed on the full graph can be fed directly to the positive and negative graphs to compute pairwise scores.</span></p><p style="text-indent:0pt">
</p><p style="text-indent:0pt"><span style="font-size:16px">The benefit of treating node pairs as a graph is that we can use the DGLGraph.apply_edges method, which conveniently computes new edge features from the node features and the original edge features. DGL provides a set of optimized built-in functions that compute new edge features directly from the raw node/edge features.</span>
</p>train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes())
test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())<p><span style="font-size:16px"><strong>
</strong></span></p><p><span style="font-size:16px"><strong>Two scoring functions: built-in and custom</strong>: next we define the scoring function between two nodes. We can rely on the building blocks DGL provides, or define our own: below, DotPredictor is the predictor from the official example and MLPPredictor is a custom one.</span></p>import dgl.function as fn
class DotPredictor(nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            # Score an edge as the dot product between the source node
            # feature 'h' and the destination node feature 'h'
            g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            # u_dot_v returns a 1-element vector per edge, so squeeze it
            return g.edata['score'][:, 0]<p>
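The built-in fn.u_dot_v used by DotPredictor scores each edge by the dot product of its endpoint representations; the equivalent computation in plain numpy (toy embeddings, illustrative only) looks like this:

```python
import numpy as np

# Toy node embeddings (4 nodes, 2 dimensions) and an edge list
h = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0],
              [0.5, 2.0]])
src = np.array([0, 1, 2])  # source node of each edge
dst = np.array([2, 2, 3])  # destination node of each edge

# Equivalent of fn.u_dot_v('h', 'h', 'score'): one dot product per edge
scores = (h[src] * h[dst]).sum(axis=1)
print(scores)
```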
</p>class MLPPredictor(nn.Module):
    def __init__(self, h_feats):
        super().__init__()
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linear(h_feats, 1)

    def apply_edges(self, edges):
        """
        Computes a scalar score for each edge of the given graph.

        Parameters
        ----------
        edges :
            Has three members ``src``, ``dst`` and ``data``, each of
            which is a dictionary representing the features of the
            source nodes, the destination nodes, and the edges
            themselves.

        Returns
        -------
        dict
            A dictionary of new edge features.
        """
        h = torch.cat([edges.src['h'], edges.dst['h']], 1)
        return {'score': self.W2(F.relu(self.W1(h))).squeeze(1)}

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            return g.edata['score']<p><span style="font-size:16px">Before training, instantiate the model and the predictor, and define the loss and the evaluation functions: compute_loss applies binary cross entropy over the positive and negative scores, and compute_auc evaluates with sklearn's roc_auc_score (these definitions follow the standard DGL link prediction example). Then train the model directly:</span></p>model = GraphSAGE(train_g.ndata['feat'].shape[1], 16)
pred = DotPredictor()

def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])])
    return F.binary_cross_entropy_with_logits(scores, labels)

def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).numpy()
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).numpy()
    return roc_auc_score(labels, scores)

optimizer = torch.optim.Adam(itertools.chain(model.parameters(), pred.parameters()), lr=0.01)
all_logits = []
for e in range(100):
    # Forward pass
    h = model(train_g, train_g.ndata['feat'])
    pos_score = pred(train_pos_g, h)
    neg_score = pred(train_neg_g, h)
    loss = compute_loss(pos_score, neg_score)
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if e % 5 == 0:
        print('In epoch {}, loss: {}'.format(e, loss))

# Check the accuracy of the results on the test set
from sklearn.metrics import roc_auc_score
with torch.no_grad():
    pos_score = pred(test_pos_g, h)
    neg_score = pred(test_neg_g, h)
    print('AUC', compute_auc(pos_score, neg_score))<p><span style="font-size:16px">
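The AUC reported here can be read as the probability that a randomly chosen positive edge receives a higher score than a randomly chosen negative edge; below is a minimal numpy illustration with made-up scores (not the model's actual outputs):

```python
import numpy as np

def pairwise_auc(pos_score, neg_score):
    # Fraction of (positive, negative) pairs ordered correctly,
    # counting ties as one half -- this equals the ROC AUC
    pos = np.asarray(pos_score)[:, None]
    neg = np.asarray(neg_score)[None, :]
    correct = (pos > neg).sum() + 0.5 * (pos == neg).sum()
    return correct / (pos.size * neg.size)

auc = pairwise_auc([0.9, 0.8, 0.4], [0.7, 0.3])
print(auc)
```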
</span></p><p><span style="font-size:16px">The training loop above is similar to a standard pytorch training procedure; the training log looks as follows:</span></p>In epoch 0, loss: 0.6172636151313782
In epoch 5, loss: 0.6101921796798706
In epoch 10, loss: 0.5864554047584534
In epoch 15, loss: 0.5405876040458679
In epoch 20, loss: 0.4583510458469391
In epoch 25, loss: 0.39045605063438416
In epoch 30, loss: 0.34702828526496887
In epoch 35, loss: 0.3122958838939667
In epoch 40, loss: 0.2834944725036621
In epoch 45, loss: 0.25488677620887756
In epoch 50, loss: 0.22920763492584229
In epoch 55, loss: 0.20638766884803772
In epoch 60, loss: 0.18289318680763245
In epoch 65, loss: 0.16009262204170227
In epoch 70, loss: 0.1381770521402359
In epoch 75, loss: 0.11725720018148422
In epoch 80, loss: 0.09779688715934753
In epoch 85, loss: 0.07947927713394165
In epoch 90, loss: 0.06309689581394196
In epoch 95, loss: 0.048749890178442
AUC 0.8526520069180836<p>
</p><p><span style="font-size:18px"><strong>Previous posts</strong></span></p><p>[Knowledge Graph series] Generative pre-training models for knowledge graphs</p><p>[Knowledge Graph series] Knowledge graph embeddings in real or complex spaces</p><p>[Knowledge Graph series] Multi-hop knowledge graph reasoning via reinforcement learning</p><p>[Knowledge Graph series] EvolveGCN for dynamic temporal knowledge graphs</p><p>[Machine Learning series] The two schools of machine learning</p></div>
"GNN Framework Series" DGL Lecture 2: Implementing GNN Link Prediction
© Copyright belongs to the author. For reprints or content cooperation, please contact the author.