Graph attention network

Authors: Hao Zhang, Mufei Li, Minjie Wang, Zheng Zhang

In this tutorial, you learn about a graph attention network (GAT) and how it can be implemented in PyTorch. You also learn how to visualize and understand what the attention mechanism has learned.

The research described in the paper Graph Convolutional Network (GCN) indicates that combining local graph structure and node-level features yields good performance on node classification tasks. However, the way GCN aggregates is structure-dependent, which can hurt its generalizability.

One workaround is to simply average over all neighbor node features, as described in the research paper GraphSAGE. However, Graph Attention Network proposes a different type of aggregation. GAT weights neighbor features with feature-dependent and structure-free normalization, in the style of attention.

Introducing attention to GCN

The key difference between GAT and GCN is how the information from the one-hop neighborhood is aggregated.

For GCN, a graph convolution operation produces the normalized sum of the node features of neighbors.

h_i^{(l+1)}=\sigma(\sum_{j\in N(i)}\frac{1}{c_{ij}}W^{(l)}h_j^{(l)})

where N(i) is the set of one-hop neighbors of node i (to include v_i in the set, simply add a self-loop to each node), c_{ij}=\sqrt{|N(i)|}\sqrt{|N(j)|} is a normalization constant based on graph structure, \sigma is an activation function (GCN uses ReLU), and W^{(l)} is a shared weight matrix for node-wise feature transformation. Another model proposed in GraphSAGE employs the same update rule except that it sets c_{ij}=|N(i)|.
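To make the two normalization constants concrete, here is a minimal plain-Python sketch on a made-up three-node graph with one scalar feature per node standing in for W^{(l)}h_j^{(l)}. The names `neighbors`, `wh`, `gcn_aggregate` and `sage_mean_aggregate` are hypothetical and only illustrate the update rule above; self-loops are assumed to be already included in N(i).

```python
import math

# Hypothetical toy graph: adjacency lists that already include self-loops,
# so i is a member of N(i).
neighbors = {
    0: [0, 1, 2],
    1: [0, 1],
    2: [0, 2],
}
# One scalar per node, standing in for W^{(l)} h_j^{(l)}.
wh = {0: 1.0, 1: 2.0, 2: 4.0}


def gcn_aggregate(i):
    """Sum over j in N(i) of W h_j / c_ij, with c_ij = sqrt(|N(i)|) * sqrt(|N(j)|)."""
    return sum(
        wh[j] / (math.sqrt(len(neighbors[i])) * math.sqrt(len(neighbors[j])))
        for j in neighbors[i]
    )


def sage_mean_aggregate(i):
    """Same rule with c_ij = |N(i)|, which is a plain mean over N(i)."""
    return sum(wh[j] for j in neighbors[i]) / len(neighbors[i])


def relu(x):
    return max(0.0, x)


for i in neighbors:
    print(i, round(relu(gcn_aggregate(i)), 4), round(relu(sage_mean_aggregate(i)), 4))
```

In both cases the weight given to neighbor j depends only on node degrees, never on the features themselves; this is the static normalization that GAT replaces.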

GAT introduces the attention mechanism as a substitute for the statically normalized convolution operation. Below are the equations to compute the node embedding h^{(l+1)}_i of layer l+1 from the embeddings of layer l.

z_i^{(l)}=W^{(l)}h_i^{(l)}\tag{1}
e_{ij}^{(l)}=LeakyReLU(\vec{a}^{(l)^T}(z_i^{(l)}||z_j^{(l)}))\tag{2}
\alpha_{ij}^{(l)}=\frac{exp(e_{ij}^{(l)})}{\sum_{k\in N(i)}exp(e_{ik}^{(l)})}\tag{3}
h_i^{(l+1)}=\sigma(\sum_{j\in N(i)}\alpha_{ij}^{(l)}z_j^{(l)})\tag{4}

Explanations:

  • Equation (1) is a linear transformation of the lower layer embedding h^{(l)}_i and W^{(l)} is its learnable weight matrix.
  • Equation (2) computes a pair-wise un-normalized attention score between two neighbors. Here, it first concatenates the z embeddings of the two nodes, where || denotes concatenation, then takes the dot product of the result with a learnable weight vector \vec{a}^{(l)}, and finally applies a LeakyReLU. This form of attention is usually called additive attention, in contrast to the dot-product attention in the Transformer model.
  • Equation (3) applies a softmax to normalize the attention scores on each node’s incoming edges.
  • Equation (4) is similar to GCN. The embeddings from neighbors are aggregated together, scaled by the attention scores.
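The four equations can be traced by hand for a single node. The following is a minimal plain-Python sketch, assuming made-up values for W, \vec{a} and the input features of a node 0 with N(0) = {0, 1, 2}; none of these numbers come from the tutorial. The LeakyReLU slope is set to 0.01 to match the default of torch.nn.functional.leaky_relu used in the GATLayer code.

```python
import math

# Hypothetical toy values: 2-d inputs, 2-d outputs, node 0 with N(0) = {0, 1, 2}.
W = [[0.5, -0.2],
     [0.1, 0.3]]               # W^{(l)}, shape out_dim x in_dim
a = [0.4, -0.1, 0.2, 0.3]      # \vec{a}^{(l)}, length 2 * out_dim
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbors_of_0 = [0, 1, 2]     # includes a self-loop on node 0


def matvec(m, v):
    """Matrix-vector product for lists of lists."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def leaky_relu(x, negative_slope=0.01):
    """Same default slope as torch.nn.functional.leaky_relu."""
    return x if x >= 0 else negative_slope * x


# Equation (1): z_j = W h_j
z = {j: matvec(W, hj) for j, hj in h.items()}

# Equation (2): e_0j = LeakyReLU(a^T (z_0 || z_j)); list "+" is concatenation
e = {j: leaky_relu(sum(ak * ck for ak, ck in zip(a, z[0] + z[j])))
     for j in neighbors_of_0}

# Equation (3): softmax over node 0's incoming edges
denom = sum(math.exp(e[k]) for k in neighbors_of_0)
alpha = {j: math.exp(e[j]) / denom for j in neighbors_of_0}

# Equation (4): attention-weighted sum of z_j (sigma left out, as in GATLayer)
h0_next = [sum(alpha[j] * z[j][d] for j in neighbors_of_0) for d in range(len(W))]
print(alpha, h0_next)
```

Unlike the GCN constant 1/c_{ij}, each alpha here depends on the embeddings z, so changing a neighbor's features changes how much weight it receives.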

There are other details from the paper, such as dropout and skip connections. For simplicity, those details are left out of this tutorial. To see more details, download the full example. In essence, GAT is just a different aggregation function with attention over features of neighbors, instead of a simple mean aggregation.

GAT in DGL

DGL provides an off-the-shelf implementation of the GAT layer under the dgl.nn.<backend> subpackage. Simply import GATConv as follows.

from dgl.nn.pytorch import GATConv

Readers can skip the following step-by-step explanation of the implementation and jump to the Put everything together section for training and visualization results.

To begin, get an overall impression of how a GATLayer module is implemented in DGL. In this section, the four equations above are broken down one at a time.

import torch
import torch.nn as nn
import torch.nn.functional as F


class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)

    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

Equation (1)

z_i^{(l)}=W^{(l)}h_i^{(l)}\tag{1}

Equation (1) is a linear transformation. It's common and can be easily implemented in PyTorch using torch.nn.Linear.

Equation (2)

e_{ij}^{(l)}=LeakyReLU(\vec{a}^{(l)^T}(z_i^{(l)}||z_j^{(l)}))\tag{2}

The un-normalized attention score e_{ij} is calculated using the embeddings of adjacent nodes i and j. This suggests that the attention scores can be viewed as edge data, which can be calculated by the apply_edges API. The argument to apply_edges is an edge UDF, which is defined as below:

def edge_attention(self, edges):
    # edge UDF for equation (2)
    z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
    a = self.attn_fc(z2)
    return {'e' : F.leaky_relu(a)}

Here, the dot product with the learnable weight vector \vec{a}^{(l)} is implemented again using PyTorch's linear transformation attn_fc. Note that apply_edges batches all the edge data in one tensor, so the cat and attn_fc here are applied on all the edges in parallel.
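Because \vec{a}^{(l)} is applied to a concatenation, the pre-activation score splits into a source term plus a destination term: with \vec{a} cut into halves a_src and a_dst, a^T(z_src || z_dst) = a_src^T z_src + a_dst^T z_dst. The plain-Python sketch below only checks this algebraic identity on random, hypothetical values (the names `a_src`, `a_dst`, `edges` are illustrative); it makes no claim about how DGL's built-in GATConv is implemented. The halves follow the [src, dst] order used by torch.cat in edge_attention.

```python
import random

random.seed(0)
out_dim = 4
a = [random.uniform(-1, 1) for _ in range(2 * out_dim)]  # plays the role of attn_fc.weight
a_src, a_dst = a[:out_dim], a[out_dim:]                  # halves matching cat([src, dst])


def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


# Hypothetical node embeddings z and a directed edge list of (src, dst) pairs.
z = [[random.uniform(-1, 1) for _ in range(out_dim)] for _ in range(3)]
edges = [(0, 1), (1, 2), (2, 0), (1, 0)]

# Per-edge form, as in edge_attention: concatenate, then one dot product per edge.
scores_concat = [dot(a, z[s] + z[d]) for s, d in edges]

# Split form: one dot product per node and half, then one addition per edge.
src_term = [dot(a_src, zi) for zi in z]
dst_term = [dot(a_dst, zi) for zi in z]
scores_split = [src_term[s] + dst_term[d] for s, d in edges]
```

The split form does the expensive dot products once per node instead of once per edge, which can matter on graphs with many more edges than nodes; LeakyReLU is then applied to either result in the same way.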

Equation (3) & (4)

\alpha_{ij}^{(l)}=\frac{exp(e_{ij}^{(l)})}{\sum_{k\in N(i)}exp(e_{ik}^{(l)})}\tag{3}
h_i^{(l+1)}=\sigma(\sum_{j\in N(i)}\alpha_{ij}^{(l)}z_j^{(l)})\tag{4}

As with GCN, the update_all API is used to trigger message passing on all the nodes. The message function sends out two tensors: the transformed z embedding of the source node and the un-normalized attention score e on each edge. The reduce function then performs two tasks:

  • Normalize the attention scores using softmax (equation (3)).
  • Aggregate neighbor embeddings weighted by the attention scores (equation (4)).

Both tasks first fetch data from the mailbox and then manipulate it on the second dimension (dim=1), on which the messages are batched.

def reduce_func(self, nodes):
    # reduce UDF for equation (3) & (4)
    # equation (3)
    alpha = F.softmax(nodes.mailbox['e'], dim=1)
    # equation (4)
    h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
    return {'h' : h}
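The role of dim=1 in reduce_func can be mimicked with nested lists. This is a minimal plain-Python sketch, assuming a hypothetical mailbox for two destination nodes that both have in-degree 3, so mailbox['e'] has shape (2, 3, 1) and mailbox['z'] has shape (2, 3, 2); the values and the helper `softmax_dim1` are made up for illustration.

```python
import math

# Hypothetical mailbox for 2 destination nodes with in-degree 3 each.
# mailbox_e[n][k] is the 1-element score e on the k-th incoming edge of node n;
# mailbox_z[n][k] is the source embedding z carried by that edge.
mailbox_e = [[[0.5], [1.0], [-0.3]],
             [[2.0], [0.0], [0.0]]]
mailbox_z = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
             [[2.0, 2.0], [0.0, 4.0], [4.0, 0.0]]]


def softmax_dim1(batch):
    """Softmax over the message axis (dim=1), separately for each node and channel."""
    out = []
    for msgs in batch:                                   # one destination node
        exps = [[math.exp(v) for v in m] for m in msgs]
        totals = [sum(col) for col in zip(*exps)]        # sum over messages per channel
        out.append([[v / t for v, t in zip(m, totals)] for m in exps])
    return out


# Equation (3)
alpha = softmax_dim1(mailbox_e)

# Equation (4): alpha's trailing size-1 axis scales every feature of z, which is
# what broadcasting does in alpha * mailbox['z']; then sum over messages (dim=1).
h = [[sum(al[k][0] * zs[k][d] for k in range(len(zs))) for d in range(len(zs[0]))]
     for al, zs in zip(alpha, mailbox_z)]
print(alpha, h)
```

Each node's weights sum to 1 over its own incoming messages only; normalizing over dim=0 instead would mix scores from unrelated nodes.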

Multi-head attention

Analogous to multiple channels in a ConvNet, GAT introduces multi-head attention to enrich the model capacity and to stabilize the learning process. Each attention head has its own parameters and their outputs can be merged in two ways:

concatenation: h_i^{(l+1)}=\Vert_{k=1}^{K}\sigma(\sum_{j\in N(i)}\alpha_{ij}^{k}W^{k}h_j^{(l)})
or
average: h_i^{(l+1)}=\sigma(\frac{1}{K}\sum_{k=1}^{K}\sum_{j\in N(i)}\alpha_{ij}^{k}W^{k}h_j^{(l)})

where K is the number of heads. You can use concatenation for intermediate layers and averaging for the final layer.
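The two merge rules differ mainly in the output width: concatenation gives K * out_dim features per node, averaging keeps out_dim. The minimal plain-Python sketch below uses hypothetical outputs of K = 3 heads for 2 nodes and leaves out the activation \sigma. In PyTorch terms, the two helpers correspond to torch.cat(head_outs, dim=1) and torch.stack(head_outs).mean(dim=0).

```python
# Hypothetical outputs of K = 3 heads for 2 nodes, each head producing out_dim = 2.
head_outs = [
    [[1.0, 2.0], [3.0, 4.0]],   # head 1: one row per node
    [[0.0, 2.0], [1.0, 0.0]],   # head 2
    [[2.0, 2.0], [2.0, 2.0]],   # head 3
]


def merge_cat(outs):
    """Concatenate head outputs per node; width becomes K * out_dim."""
    return [sum((head[n] for head in outs), []) for n in range(len(outs[0]))]


def merge_mean(outs):
    """Average head outputs per node, element-wise; width stays out_dim."""
    k = len(outs)
    return [[sum(vals) / k for vals in zip(*(head[n] for head in outs))]
            for n in range(len(outs[0]))]


print(merge_cat(head_outs))
print(merge_mean(head_outs))
```

Averaging must be taken across heads only (one result row per node); averaging over every element would collapse all nodes into a single number.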

Use the single-head GATLayer defined above as the building block for the MultiHeadGATLayer below:

class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge by averaging over the head dimension (dim=0) of the
            # stacked (num_heads, num_nodes, out_dim) tensor
            return torch.mean(torch.stack(head_outs), dim=0)

Put everything together

Now, you can define a two-layer GAT model.

class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        # multiple head outputs are concatenated together. Also, only
        # one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

Then, load the Cora dataset using DGL's built-in data module.

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx

def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.BoolTensor(data.train_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, mask

The training loop is exactly the same as in the GCN tutorial.

import time
import numpy as np

g, features, labels, mask = load_cora_data()

# create the model, 2 heads, each head has hidden size 8
net = GAT(g,
          in_dim=features.size()[1],
          hidden_dim=8,
          out_dim=7,
          num_heads=2)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# main loop
dur = []
for epoch in range(30):
    if epoch >= 3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), np.mean(dur)))

# Out:
Epoch 00000 | Loss 1.9495 | Time(s) nan
Epoch 00001 | Loss 1.9473 | Time(s) nan
Epoch 00002 | Loss 1.9451 | Time(s) nan
Epoch 00003 | Loss 1.9429 | Time(s) 0.1162
Epoch 00004 | Loss 1.9408 | Time(s) 0.1123
Epoch 00005 | Loss 1.9386 | Time(s) 0.1437
Epoch 00006 | Loss 1.9364 | Time(s) 0.1400
Epoch 00007 | Loss 1.9342 | Time(s) 0.1393
Epoch 00008 | Loss 1.9320 | Time(s) 0.1366
Epoch 00009 | Loss 1.9298 | Time(s) 0.1406
Epoch 00010 | Loss 1.9276 | Time(s) 0.1410
Epoch 00011 | Loss 1.9254 | Time(s) 0.1401
Epoch 00012 | Loss 1.9232 | Time(s) 0.1385
Epoch 00013 | Loss 1.9210 | Time(s) 0.1360
Epoch 00014 | Loss 1.9187 | Time(s) 0.1352
Epoch 00015 | Loss 1.9165 | Time(s) 0.1332
Epoch 00016 | Loss 1.9142 | Time(s) 0.1311
Epoch 00017 | Loss 1.9120 | Time(s) 0.1291
Epoch 00018 | Loss 1.9097 | Time(s) 0.1273
Epoch 00019 | Loss 1.9074 | Time(s) 0.1256
Epoch 00020 | Loss 1.9051 | Time(s) 0.1242
Epoch 00021 | Loss 1.9028 | Time(s) 0.1229
Epoch 00022 | Loss 1.9005 | Time(s) 0.1219
Epoch 00023 | Loss 1.8982 | Time(s) 0.1209
Epoch 00024 | Loss 1.8958 | Time(s) 0.1207
Epoch 00025 | Loss 1.8935 | Time(s) 0.1204
Epoch 00026 | Loss 1.8911 | Time(s) 0.1198
Epoch 00027 | Loss 1.8887 | Time(s) 0.1192
Epoch 00028 | Loss 1.8863 | Time(s) 0.1191
Epoch 00029 | Loss 1.8839 | Time(s) 0.1187

Visualizing and understanding attention learned

Cora

The following table summarizes the model performance on Cora, as reported in the GAT paper and as obtained with DGL implementations.

Model          Accuracy (%)
GCN (paper)    81.4 ± 0.5
GCN (dgl)      82.05 ± 0.33
GAT (paper)    83.0 ± 0.7
GAT (dgl)      83.69 ± 0.529

What kind of attention distribution has our model learned?
Because the attention weight \alpha_{ij} is associated with edges, you can visualize it by coloring edges. Below, a subgraph of Cora is picked and the attention weights of the last GATLayer are plotted. The nodes are colored according to their labels, whereas the edges are colored according to the magnitude of the attention weights, as indicated by the colorbar on the right.

You can see that the model seems to learn different attention weights. To understand the distribution more thoroughly, measure the entropy of the attention distribution. For any node i, \{\alpha_{ij}\}_{j\in N(i)} forms a discrete probability distribution over all its neighbors with the entropy given by

H(\{\alpha_{ij}\}_{j\in N(i)})=-\sum_{j\in N(i)}\alpha_{ij}\log\alpha_{ij}

A low entropy means a high degree of concentration, and vice versa. An entropy of 0 means all attention is on one source node. The uniform distribution has the highest entropy, \log|N(i)|. Ideally, you want to see the model learn a distribution of lower entropy (i.e., one or two neighbors are much more important than the others).
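The entropy formula can be checked on a few hand-made attention distributions for a node with four neighbors. This is a minimal plain-Python sketch; `attention_entropy` and the example distributions are hypothetical, and terms with \alpha_{ij} = 0 are dropped, following the usual convention 0 log 0 = 0.

```python
import math


def attention_entropy(alpha):
    """H = -sum_j alpha_ij * log(alpha_ij) for one node; 0 * log 0 is taken as 0."""
    return -sum(p * math.log(p) for p in alpha if p > 0)


uniform = [0.25] * 4                 # all 4 neighbors equally important
one_hot = [1.0, 0.0, 0.0, 0.0]       # all attention on a single neighbor
peaked = [0.85, 0.05, 0.05, 0.05]    # one neighbor clearly dominates

for name, dist in [("uniform", uniform), ("one_hot", one_hot), ("peaked", peaked)]:
    print(name, round(attention_entropy(dist), 4))
```

The uniform case reaches the maximum log 4, the one-hot case reaches 0, and the peaked case falls in between, closer to 0 the more concentrated the attention is.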

Note that since nodes can have different degrees, the maximum entropy also differs from node to node. Therefore, you plot the aggregated histogram of entropy values of all nodes in the entire graph. Below are the attention histograms learned by each attention head.

As a reference, here is the histogram if all the nodes have uniform attention weight distribution.
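The uniform reference follows directly from the degrees: under uniform attention every node's entropy equals log of its in-degree, its maximum. The plain-Python sketch below builds such a reference histogram from hypothetical in-degrees (self-loops included); the degrees, the bin width of 0.5 and the helper `histogram` are illustrative assumptions, not values from Cora.

```python
import math
from collections import Counter

# Hypothetical in-degrees (self-loops included) for a handful of nodes.
in_degrees = [1, 2, 2, 3, 4, 4, 4, 7, 12]

# Under uniform attention each node's entropy is log(deg), its maximum.
uniform_entropy = [math.log(d) for d in in_degrees]


def histogram(values, bin_width=0.5):
    """Count values per bin [k * w, (k + 1) * w); returns {bin_start: count}."""
    counts = Counter(int(v // bin_width) for v in values)
    return {k * bin_width: c for k, c in sorted(counts.items())}


hist = histogram(uniform_entropy)
print(hist)
```

Comparing a learned-attention histogram against this reference shows whether entropies sit near their degree-determined maxima (close to uniform) or well below them (sharper attention).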

One can see that the attention values learned are quite similar to a uniform distribution (i.e., all neighbors are equally important). This partially explains why the performance of GAT is close to that of GCN on Cora (according to the authors' reported results, the accuracy difference averaged over 100 runs is less than 2 percent). Attention matters little here because it does not differentiate much between neighbors.

Does that mean the attention mechanism is not useful? No! A different dataset exhibits an entirely different pattern, as you can see next.

Protein-protein interaction (PPI) networks

The PPI dataset used here consists of 24 graphs corresponding to different human tissues. Nodes can have up to 121 kinds of labels, so the label of a node is represented as a binary tensor of size 121. The task is to predict node labels.

Use 20 graphs for training, 2 for validation and 2 for test. The average number of nodes per graph is 2372. Each node has 50 features that are composed of positional gene sets, motif gene sets, and immunological signatures. Critically, test graphs remain completely unobserved during training, a setting called “inductive learning”.

Compare the performance of GAT and GCN over 10 random runs on this task, using hyperparameter search on the validation set to find the best model.

Model    F1 score (micro)
GAT      0.975 ± 0.006
GCN      0.509 ± 0.025
Paper    0.973 ± 0.002

The table above shows the results of this experiment, where the micro F1 score is used to evaluate model performance.
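For multi-label prediction such as PPI (121 binary labels per node), micro F1 pools true positives, false positives and false negatives over every (node, label) pair before computing F1. The plain-Python sketch below shows that computation on a tiny hypothetical example with 3 labels instead of 121; how y_pred is obtained (for instance by thresholding per-label probabilities) is left as an assumption. scikit-learn's f1_score with average='micro' computes the same quantity.

```python
def micro_f1(y_true, y_pred):
    """Micro-averaged F1 for multi-label 0/1 matrices (rows = nodes, columns = labels)."""
    tp = fp = fn = 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            if t and p:
                tp += 1          # label present and predicted
            elif p:
                fp += 1          # predicted but absent
            elif t:
                fn += 1          # present but missed
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


# Hypothetical 2 nodes x 3 labels.
y_true = [[1, 0, 1], [0, 1, 0]]
y_pred = [[1, 0, 0], [0, 1, 1]]
print(micro_f1(y_true, y_pred))
```

Because counts are pooled, frequent labels contribute more to the score than rare ones, unlike macro F1, which averages per-label scores.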

During training, use BCEWithLogitsLoss as the loss function. The learning curves of GAT and GCN are presented below; what is evident is the dramatic performance advantage of GAT over GCN.
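BCEWithLogitsLoss treats each of the 121 labels as an independent binary problem and works directly on raw logits. The sketch below reproduces that arithmetic in plain Python with the default mean reduction over all entries, using the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)); the helper name `bce_with_logits` and the example logits are hypothetical.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over all (node, label) entries, from raw logits.

    max(x, 0) - x*y + log(1 + exp(-|x|)) equals
    -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))] but never overflows.
    """
    total, count = 0.0, 0
    for x_row, y_row in zip(logits, targets):
        for x, y in zip(x_row, y_row):
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count


# Hypothetical logits for 2 nodes x 2 labels and their multi-hot targets.
logits = [[2.0, -1.0], [0.0, 3.0]]
targets = [[1, 0], [1, 0]]
print(bce_with_logits(logits, targets))
```

Working from logits avoids computing log(sigmoid(x)) directly, which underflows for large |x|; a confidently correct logit such as 1000 with target 1 simply yields a loss of 0.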

As before, you can have a statistical understanding of the attentions learned by showing the histogram plot for the node-wise attention entropy. Below are the attention histograms learned by different attention layers.

Attention learned in layer 1:

Attention learned in layer 2:

Attention learned in final layer:

Again, comparing with uniform distribution:

Clearly, GAT does learn sharp attention weights! There is a clear pattern over the layers as well: the attention gets sharper in higher layers.

Unlike the Cora dataset, where GAT's gain is minimal at best, for PPI there is a significant performance gap between GAT and the other GNN variants compared in the GAT paper (at least 20 percent), and the attention distributions on the two datasets clearly differ. While this deserves further research, one immediate conclusion is that GAT's advantage perhaps lies more in its ability to handle graphs with more complex neighborhood structure.

Reference: https://docs.dgl.ai/tutorials/models/1_gnn/9_gat.html#sphx-glr-tutorials-models-1-gnn-9-gat-py
