GAT (Graph Attention Networks) is a graph neural network built on a self-attention mechanism. Much like self-attention in the Transformer, it computes how much attention a node pays to each of its neighboring nodes. It then combines the node's own features with the attention-weighted neighbor features to form the node's new representation, which is used for tasks such as node classification.
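In the pytorch-GAT implementation linked below, the node's own features enter through a skip (residual) connection that is added to each attention head's output, and the heads are then concatenated. The following is a minimal sketch of that combination step, assuming the input and per-head output widths match. Plain Python lists stand in for tensors, the helper name `skip_and_concat` is hypothetical, and bias and activation are omitted.

```python
def skip_and_concat(in_feats, head_outs):
    """Hypothetical helper: residual connection followed by head concatenation.

    in_feats:  N x F list of input feature vectors
    head_outs: N x H x F list of per-head attention outputs
    Returns an N x (H*F) list: the node's own features are added to every
    head's output, then the heads are concatenated into one vector per node.
    """
    out = []
    for x, heads in zip(in_feats, head_outs):
        row = []
        for h in heads:
            # residual: add the node's own features to this head's output
            row.extend(hv + xv for hv, xv in zip(h, x))
        out.append(row)
    return out


in_feats = [[1.0, 2.0], [0.0, 1.0]]          # 2 nodes, F = 2
head_outs = [[[0.5, 0.5], [1.0, 0.0]],        # node 0: 2 heads
             [[0.0, 0.0], [2.0, 2.0]]]        # node 1: 2 heads
print(skip_and_concat(in_feats, head_outs))
# [[1.5, 2.5, 2.0, 2.0], [0.0, 1.0, 2.0, 3.0]]
```

When the input and output widths differ, pytorch-GAT first passes the input through a learned projection before adding it; that case is left out of this sketch.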
The figure below illustrates self-attention in the Transformer:
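GAT follows a similar procedure to compute self-attention over nodes. It first computes an attention score between the current node and each neighboring node, then multiplies each neighbor's features by its score, sums the results, and passes the sum through a nonlinearity to obtain the current node's new features.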
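The aggregation step for a single node can be sketched as follows, assuming the normalized attention coefficients and the projected neighbor features W·h_j are already computed. ELU is used as the nonlinearity here because the GAT paper uses it in its experiments; the function name `aggregate` and the toy numbers are illustrative.

```python
import math


def elu(x, alpha=1.0):
    """ELU nonlinearity: x for x > 0, alpha * (exp(x) - 1) otherwise."""
    return x if x > 0 else alpha * (math.exp(x) - 1.0)


def aggregate(coeffs, neighbor_feats):
    """Weighted sum of neighbor features followed by a nonlinearity.

    coeffs:         normalized attention coefficients alpha_ij, one per neighbor j
    neighbor_feats: projected features W*h_j, one vector per neighbor j
    """
    dim = len(neighbor_feats[0])
    summed = [0.0] * dim
    for a, feat in zip(coeffs, neighbor_feats):
        for k in range(dim):
            summed[k] += a * feat[k]
    return [elu(v) for v in summed]


# node i with three neighbors (including itself); coefficients sum to 1
coeffs = [0.5, 0.25, 0.25]
neighbor_feats = [[2.0, -4.0], [4.0, 0.0], [0.0, 0.0]]
print(aggregate(coeffs, neighbor_feats))  # weighted sum is [2.0, -2.0] before ELU
```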
The attention score is computed as follows:
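For reference, in the notation of the GAT paper (Veličković et al., 2018): h_i is node i's input feature vector, W the shared projection matrix, a the learnable attention vector, || concatenation, N_i the neighborhood of node i (including i itself), and σ a nonlinearity. The score, its normalization and the aggregation described in this section are:

```latex
e_{ij} = \mathrm{LeakyReLU}\!\left(\vec{a}^{\top}\left[\mathbf{W}\vec{h}_i \,\Vert\, \mathbf{W}\vec{h}_j\right]\right), \qquad j \in \mathcal{N}_i

\alpha_{ij} = \mathrm{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}_i} \exp(e_{ik})}

\vec{h}_i' = \sigma\!\left(\sum_{j \in \mathcal{N}_i} \alpha_{ij}\,\mathbf{W}\vec{h}_j\right)
```

Writing a = [a_1 || a_2] splits the score into a_1·(W h_i) + a_2·(W h_j), a per-node term for i plus a per-node term for j. The pytorch-GAT forward pass computes these two terms as `scores_source` and `scores_target` and adds them by broadcasting.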
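Here the matrix W maps the original features into a new space, and a denotes the attention computation (compare Figure 2 above). Unlike the Transformer's dot product, in GAT a is a single-layer feed-forward network: a learnable weight vector applied to the concatenation of the two projected feature vectors, followed by LeakyReLU. This yields the attention score e_ij between node i and its neighbor j; applying softmax over the scores of all of node i's neighbors then gives the normalized attention coefficients.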
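A minimal sketch of these two steps for a single node i: a score e_ij = LeakyReLU(a · [W h_i || W h_j]) for each neighbor (with the negative slope of 0.2 used in the GAT paper), then a softmax over the neighborhood. The projected features W h_j are assumed to be precomputed, and the function names are hypothetical.

```python
import math


def leaky_relu(x, slope=0.2):
    """LeakyReLU with the negative slope 0.2 used in the GAT paper."""
    return x if x > 0 else slope * x


def attention_coefficients(a, wh_i, wh_neighbors):
    """Normalized attention of node i over its neighbors.

    a:            attention vector of length 2F (applied to [W h_i || W h_j])
    wh_i:         projected features W h_i (length F)
    wh_neighbors: list of projected neighbor features W h_j (each length F)
    """
    scores = []
    for wh_j in wh_neighbors:
        concat = wh_i + wh_j  # list concatenation = [W h_i || W h_j]
        e_ij = leaky_relu(sum(ak * ck for ak, ck in zip(a, concat)))
        scores.append(e_ij)
    # softmax over the neighborhood (subtract the max for numerical stability)
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [v / total for v in exps]


a = [1.0, 0.0, 1.0, 0.0]                  # F = 2, so a has length 4
wh_i = [0.0, 1.0]
wh_neighbors = [[0.0, 1.0], [1.0, 0.0], [-5.0, 0.0]]
alpha = attention_coefficients(a, wh_i, wh_neighbors)
print(alpha)  # raw scores are [0.0, 1.0, -1.0]; roughly [0.245, 0.665, 0.090]
```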
A reference implementation is available at https://github.com/gordicaleksa/pytorch-GAT
Core code (the layer's forward pass):
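```python
import torch

# forward() of the dense GAT layer in pytorch-GAT.
# N = number of nodes, FIN/FOUT = input/output features per node, NH = number of heads.
def forward(self, data):
    in_nodes_features, connectivity_mask = data  # (N, FIN), (N, N)
    num_of_nodes = in_nodes_features.shape[0]
    assert connectivity_mask.shape == (num_of_nodes, num_of_nodes)

    in_nodes_features = self.dropout(in_nodes_features)

    # "V": project into NH independent heads, (N, FIN) -> (N, NH, FOUT)
    nodes_features_proj = self.linear_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features)
    nodes_features_proj = self.dropout(nodes_features_proj)

    # "Q" and "K": one scalar score per node and head, (N, NH, FOUT) -> (N, NH, 1)
    scores_source = torch.sum((nodes_features_proj * self.scoring_fn_source), dim=-1, keepdim=True)
    scores_target = torch.sum((nodes_features_proj * self.scoring_fn_target), dim=-1, keepdim=True)
    scores_source = scores_source.transpose(0, 1)   # (NH, N, 1)
    scores_target = scores_target.permute(1, 2, 0)  # (NH, 1, N)

    # "Q * K": broadcasting the sum gives all pairwise scores, (NH, N, N)
    all_scores = self.leakyReLU(scores_source + scores_target)
    # the mask holds -inf where there is no edge, so softmax assigns those pairs zero weight
    all_attention_coefficients = self.softmax(all_scores + connectivity_mask)

    # "Q * K * V": weighted sum of neighbor features, (NH, N, N) x (NH, N, FOUT) -> (NH, N, FOUT)
    out_nodes_features = torch.bmm(all_attention_coefficients, nodes_features_proj.transpose(0, 1))
    out_nodes_features = out_nodes_features.permute(1, 0, 2)  # back to (N, NH, FOUT)

    # combine with the input features (skip connection), concatenate or average the heads, add bias
    out_nodes_features = self.skip_concat_bias(all_attention_coefficients, in_nodes_features, out_nodes_features)
    return (out_nodes_features, connectivity_mask)
```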
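To see what the tensor operations in this forward pass compute, here is a single-head, pure-Python sketch of the same dense steps, without dropout or the skip connection: project with W, compute one source score and one target score per node, form the N×N score matrix from their pairwise sums, apply LeakyReLU, add a connectivity mask holding 0 for edges and -inf elsewhere, softmax each row, and aggregate. The names `dense_gat_head`, `a_src`, `a_trg` and the toy graph are illustrative assumptions, not part of pytorch-GAT.

```python
import math


def dense_gat_head(x, w, a_src, a_trg, mask, slope=0.2):
    """Single attention head over a dense N x N connectivity mask.

    x:     N x FIN input features
    w:     FIN x FOUT projection matrix
    a_src: length-FOUT scoring vector for row node i (role of scores_source)
    a_trg: length-FOUT scoring vector for column node j (role of scores_target)
    mask:  N x N, 0.0 where j is a neighbor of i (self-loops included), -inf otherwise
    """
    n, fout = len(x), len(w[0])
    # projection: (N, FIN) @ (FIN, FOUT) -> (N, FOUT)
    proj = [[sum(xi[k] * w[k][c] for k in range(len(w))) for c in range(fout)] for xi in x]
    # one scalar score per node for each role
    s_src = [sum(p * a for p, a in zip(row, a_src)) for row in proj]
    s_trg = [sum(p * a for p, a in zip(row, a_trg)) for row in proj]
    out = []
    for i in range(n):
        # row i of the pairwise score matrix: sum, LeakyReLU, then mask
        row = []
        for j in range(n):
            e = s_src[i] + s_trg[j]
            e = e if e > 0 else slope * e
            row.append(e + mask[i][j])
        # softmax over j; exp(-inf) = 0 removes non-neighbors
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        alpha = [v / total for v in exps]
        # aggregation (the bmm step): out_i = sum_j alpha_ij * proj_j
        out.append([sum(alpha[j] * proj[j][c] for j in range(n)) for c in range(fout)])
    return out


INF = float("inf")
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 nodes, FIN = 2
w = [[1.0, 0.0], [0.0, 1.0]]               # identity projection, FOUT = 2
# edges 0-1 and 1-2 plus self-loops; nodes 0 and 2 are not connected
mask = [[0.0, 0.0, -INF],
        [0.0, 0.0, 0.0],
        [-INF, 0.0, 0.0]]
# zero scoring vectors make all scores equal, so each node averages its neighbors
out = dense_gat_head(x, w, [0.0, 0.0], [0.0, 0.0], mask)
print(out)  # node 0 -> [0.5, 0.5], node 1 -> about [0.667, 0.667], node 2 -> [0.5, 1.0]
```

Setting both scoring vectors to zero makes the effect of the -inf mask easy to check, since each node then simply averages the projected features of its neighbors. The dense N×N form is easy to follow but needs O(N²) memory; the pytorch-GAT repository also includes an implementation that works on edge lists instead.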
A GAT layer is also available in the PyG (PyTorch Geometric) library as `GATConv`; PyG provides implementations of many widely used graph neural network algorithms from the literature.