GCN輸出的H'矩陣聚假,最后怎么令其作節(jié)點分類块蚌。即,GCN輸出H’如何讓節(jié)點分類的膘格?
以pytorch的GCN模型為例:GCN
GCN已經(jīng)將計算簡化為:
假設(shè)一個圖的頂點數(shù)目為:
import torch.nn as nn
import torch.nn.functional as F
from pygcn.layers import GraphConvolution
class GCN(nn.Module):
def __init__(self, nfeat, nhid, nclass, dropout):
super(GCN, self).__init__()
self.gc1 = GraphConvolution(nfeat, nhid)
self.gc2 = GraphConvolution(nhid, nclass)
self.dropout = dropout
def forward(self, x, adj):
x = F.relu(self.gc1(x, adj)) ###注: X = AXW1 A=[n,n] X[n,nfeat] W=[nfeat,nhid] ==> X=[n,nhid]
x = F.dropout(x, self.dropout, training=self.training)
x = self.gc2(x, adj) ###注:X=AXW2 A=[n,n] X=[n,nhid] W=[nhid,nclass] ==> X=[n,nclass]
return F.log_softmax(x, dim=1)
代碼里的x就是與公式里的H對應(yīng)峭范,x是圖頂點的原始特征矩陣,x輸入gc1層時的維度是:[n,nfeat]瘪贱,n是圖節(jié)點數(shù)纱控,nfeat是圖節(jié)點原始特征的維度;
第一次計算,即
菜秦,A矩陣維度[n,n]甜害,X矩陣就是x維度[n,nfeat],
變量維度[nfeat,nhid]球昨,所以
是新特征矩陣尔店。
第二次計算,即
,A矩陣維度[n,n]嚣州,X矩陣就是x維度[n,nhid]鲫售,
變量維度[nhid,nclass],所以
是新矩陣该肴,它就對應(yīng)nclass分類情竹。
后面return F.log_softmax(x, dim=1),即對分類上分數(shù)進行softmax歸一化處理沙庐,即可以和真實的標簽向量進行對標鲤妥,計算損失值。