圖注意力網(wǎng)絡(luò)GAT最全要點梳理(附代碼和參考資料)

Graph Attention Networks[ICLR, 2018]

該paper提出了GAT聪建,利用masked self-attention layers進行圖卷積。該方法可以賦予不同鄰居節(jié)點不同的權(quán)重,可以處理transdutive問題和inductive問題曲秉。

1. introduction

  • GCN[Kipf et al, ICLR, 2017]
  • attention mechanism: 可以處理不同大小的input。self-attention提出:Attention is all you need!

2. Architecture

GAT主要將注意力機制(Attention mechanism)和圖卷積神經(jīng)網(wǎng)絡(luò)結(jié)合起來挺智,在聚合節(jié)點信息的時候航缀,對于每個鄰居節(jié)點賦予不同的權(quán)重(也稱為attention score)。同時醉锅,和transformer提出的self-attention一樣,GAT也可以實現(xiàn)多頭(multi-heads)注意力機制发绢,每個頭單獨更新參數(shù)硬耍,最終將幾個頭的結(jié)果進行串聯(lián)或者取平均得到最終過的節(jié)點表達。

下左為得到鄰居節(jié)點attention score過程边酒,下右為多頭注意力機制更新過程经柴。


gat.PNG

具體步驟描述如下:

  • 步驟一:計算未歸一化的attention acore e_{ij}=LeakyReLU({\vec {\bf{a}}}^T[W {\vec{h}}_i ||W {\vec{h}}_j])。沿著邊將斷點的節(jié)點表示的線性變換串聯(lián)墩朦,并過一個單層的MLP坯认;
  • 步驟二:得到歸一化后attention score \alpha_{ij}=softmax(e_{ij}, dim=1)。對于e_{ij}按行通過softmax函數(shù)進行歸一化氓涣;
  • 步驟三:將節(jié)點的信息沿著邊整合到一起牛哺。機制分為單頭和多頭,多頭又有兩種整合方式劳吠,第一種是將幾個head的hidden vector和attention score相乘之后直接concat起來引润,第二種是將幾個head的vector平均再過一個非線性層(在output-layer使用)。

\alpha_{ij}=\frac{\exp(LeakyReLU({\vec {\bf{a}}}^T[W {\vec{h}}_i ||W {\vec{h}}_j]))}{\Sigma_{k \in N_i}\exp(LeakyReLU({\vec {\bf{a}}}^T[W {\vec{h}}_i ||W {\vec{h}}_k]))} \\ single-head: \quad \vec{h}_i'=\sigma(\Sigma_{j\in N_i} \alpha_{ij} W \vec{h}_j)\\ multi-heads-1: \quad \vec{h}_i'={||}_{k=1}^K\sigma(\Sigma_{j\in N_i} \alpha_{ij}^k W^k \vec{h}_j) \\ multi-heads -2: \quad \vec{h}_i'=\sigma(\frac{1}{K}\Sigma_{k=1}^K\Sigma_{j\in N_i} \alpha_{ij}^k W^k \vec{h}_j)

3. Contributions

  • 計算高效(computation efficient): 可并行計算
  • 對于不同鄰居節(jié)點給予不同重要性痒玩,讓模型解釋性更好
  • 可用于directed graph和inductive learning場景
  • GraphSAGE對每個節(jié)點指定fixed-size的鄰居淳附,并且使用LSTM的聚合器需要random-ordering的操作;但是GAT可以獲得所有鄰居的信息蠢古,并且不需要ordering

4. Experiment

dataset.PNG

4.1 Transductive learnig

Node Classification dataset:

  • citation graph--Cora, Citeseer, Pubmed

set up details:

  • 2-layer GAT
  • 第一層:K=8, activation=ELU, F'=8
  • 第二層:a. Cora, Citeseer K=1, activation=softmax ; b. Pubmed K=8
  • L_2正則化:a. Cora, Citeseer \lambda=0.0005; b. Pubmed \lambda=0.001
  • 加入dropout層: p=0.6

4.2 Inductive learning

protein-protein interaction(PPI) dataset由包含24個graph奴曙。在訓(xùn)練集上訓(xùn)練得到每一層的參數(shù)W^{(l)}, a^{(l)},再利用這個參數(shù)的到val/test set的節(jié)點表示和進行節(jié)點multi-label分類任務(wù)草讶。

set up details:

  • 3-layer GAT
  • layer 1 and layer 2: K=4, F'=256, activation=ELU
  • layer 3: K=6, activation=sigmoid
  • skip connection
  • batch size=2

補充:ELU激活函數(shù)
ELU(x)= \left\{ \begin{array} xx, & if \quad x\geq 0 \\ \alpha(e^x-1), & if \quad x\leq 0 \end{array} \right.

5. Code

本小節(jié)主要講GAT的實現(xiàn)代碼洽糟。第一部分講GATLayer如何實現(xiàn)的,主要通過dgl的框架看一下大致的整個代碼的實現(xiàn)思路,完整代碼可以看reference的源碼脊框;第二部分講基于GATLayer如何構(gòu)建GATmodel颁督。

參考DGL有關(guān)GAT的詳細說明以及DGL中GAT示例代碼

5.1 GATLayer

==Steps==:

a. 全連接層full connected layer:z_i^{(l)} = W^{(l)} h_i^{(l)}浇雹,將高維轉(zhuǎn)為較低維特征
b. message--計算沒經(jīng)過正則化(un-normalized)的attention score e_{ij}e_{ij}^{(l)}=LeakyReLU({{\bf{a}}^{l}}^T[z_i^{(l)}||z_j^{(l)}])沉御,這個score可以看做edge的特征
c. reduce

  • normalize: 計算attention score \alpha_{ij}: \alpha_{ij}^{(l)}=softmax{e_{ij}^{(l)}, dim=1}
  • aggregate: h_i^{(l+1)} = \sigma(\Sigma_{j \in N_i} \alpha_{ij}^{(l)} z_j^{(l)})

from dgl.nn.pytorch import GATConv : GATConv源碼

# GATConv Layer源碼關(guān)鍵部分(需要注意的地方)
# 主要展示了參數(shù)和residual connection部分
def __init__(self,
             in_feats,
             out_feats,
             num_heads,
             feat_drop=0.,        # dropout
             attn_drop=0.,
             negative_slope=0.2,  # leakyrelu
             residual=False,      # 是否連接residual
             activation=None,
             allow_zero_in_degree=False):
    #...
    if residual:
            if self._in_dst_feats != out_feats:
                self.res_fc = nn.Linear(
                    self._in_dst_feats, num_heads * out_feats, bias=False)
            else:
                self.res_fc = Identity()
    
    def forward(self, graph,...):
        # ...
        # residual
        if self.res_fc is not None:
            resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
            # h_(l+1)' = h_(l+1) + Wh_(l)
            rst = rst + resval

DGL有關(guān)GAT的詳細說明中有關(guān)于GATLayer的簡易實現(xiàn):

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)

    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')
    
    
# multi-heads通過疊加多個GATLayer實現(xiàn)
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average
            return torch.mean(torch.stack(head_outs))

5.2 GAT model

tips:

  • 第一層hidden layer沒有residual connection
  • output layer 沒有activation,其多頭注意力機制采用均值的方法
class GAT(nn.Module):
    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(GATConv(
            in_dim, num_hidden, heads[0],
            feat_drop, attn_drop, negative_slope, False, self.activation))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gat_layers.append(GATConv(
                num_hidden * heads[l-1], num_hidden, heads[l],
                feat_drop, attn_drop, negative_slope, residual, self.activation))
        # output projection
        self.gat_layers.append(GATConv(
            num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop, attn_drop, negative_slope, residual, None))

    def forward(self, inputs):
        h = inputs
        for l in range(self.num_layers):
            h = self.gat_layers[l](self.g, h).flatten(1)
        # output projection
        logits = self.gat_layers[-1](self.g, h).mean(1)  # mean aggregation
        return logits

--end--
如果有講得不清楚的地方昭灵,歡迎提問和提意見~

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末吠裆,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子烂完,更是在濱河造成了極大的恐慌试疙,老刑警劉巖,帶你破解...
    沈念sama閱讀 219,366評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件抠蚣,死亡現(xiàn)場離奇詭異祝旷,居然都是意外死亡,警方通過查閱死者的電腦和手機嘶窄,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,521評論 3 395
  • 文/潘曉璐 我一進店門怀跛,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人柄冲,你說我怎么就攤上這事吻谋。” “怎么了现横?”我有些...
    開封第一講書人閱讀 165,689評論 0 356
  • 文/不壞的土叔 我叫張陵漓拾,是天一觀的道長。 經(jīng)常有香客問我戒祠,道長骇两,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,925評論 1 295
  • 正文 為了忘掉前任姜盈,我火速辦了婚禮脯颜,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘贩据。我一直安慰自己,他們只是感情好闸餐,可當我...
    茶點故事閱讀 67,942評論 6 392
  • 文/花漫 我一把揭開白布饱亮。 她就那樣靜靜地躺著,像睡著了一般舍沙。 火紅的嫁衣襯著肌膚如雪近上。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,727評論 1 305
  • 那天拂铡,我揣著相機與錄音壹无,去河邊找鬼葱绒。 笑死,一個胖子當著我的面吹牛斗锭,可吹牛的內(nèi)容都是我干的地淀。 我是一名探鬼主播,決...
    沈念sama閱讀 40,447評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼岖是,長吁一口氣:“原來是場噩夢啊……” “哼帮毁!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起豺撑,我...
    開封第一講書人閱讀 39,349評論 0 276
  • 序言:老撾萬榮一對情侶失蹤烈疚,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后聪轿,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體爷肝,經(jīng)...
    沈念sama閱讀 45,820評論 1 317
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,990評論 3 337
  • 正文 我和宋清朗相戀三年陆错,在試婚紗的時候發(fā)現(xiàn)自己被綠了灯抛。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,127評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡危号,死狀恐怖牧愁,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情外莲,我是刑警寧澤猪半,帶...
    沈念sama閱讀 35,812評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站偷线,受9級特大地震影響磨确,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜声邦,卻給世界環(huán)境...
    茶點故事閱讀 41,471評論 3 331
  • 文/蒙蒙 一乏奥、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧亥曹,春花似錦邓了、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,017評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至蛇受,卻和暖如春句葵,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,142評論 1 272
  • 我被黑心中介騙來泰國打工乍丈, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留剂碴,地道東北人。 一個月前我還...
    沈念sama閱讀 48,388評論 3 373
  • 正文 我出身青樓轻专,卻偏偏與公主長得像忆矛,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子铭若,可洞房花燭夜當晚...
    茶點故事閱讀 45,066評論 2 355

推薦閱讀更多精彩內(nèi)容