一文詳解圖神經(jīng)網(wǎng)絡(luò)(sa)

5.2 《Semi-Supervised Classification with Graph Convolutional Networks》

這篇論文受到譜圖卷積的局部一階近似可以用于對局部圖結(jié)構(gòu)與節(jié)點的特征進行編碼從而確定卷積網(wǎng)絡(luò)結(jié)構(gòu)的啟發(fā)压怠,提出了一種可擴展的圖卷積的實現(xiàn)方法搪花,可用于具有圖結(jié)構(gòu)數(shù)據(jù)的半監(jiān)督學(xué)習(xí)

5.2.1 GCN定義

按照圖傅里葉變換的性質(zhì)摘投,可以得到如下圖卷積的定義:
(\boldsymbol{f} * \boldsymbol{h})_{\mathcal{G}}=\boldsymbol{\Phi} \operatorname{diag}\left[\hat{h}\left(\lambda_{1}\right), \ldots, \hat{h}\left(\lambda_{n}\right)\right] \mathbf{\Phi}^{T} \boldsymbol{f}\tag{46}
其中:

  • 對于圖\boldsymbol{f}的傅里葉變換為的傅里葉變換為的傅里葉變換為\boldsymbol{\hat{f}}=\mathbf{\Phi}^{T} \boldsymbol{f}
  • 對于卷積核的圖傅里葉變換:\hat{h}=\left(\hat{h}_{1}, \ldots, \hat{h}_{n}\right)其中\hat{h}_{k}=\left\langle h, \phi_{k}\right\rangle, k=1,2 \ldots,按照矩陣形式就是\hat{\boldsymbol{h}}=\mathbf{\Phi}^{T} \boldsymbol{h}
  • 對兩者的傅里葉變換向量\hat{f} \in \mathbb{R}^{N \times 1}\hat{h} \in \mathbb{R}^{N \times 1}element-wise乘積等價于將{h}組織成對角矩陣此改,即\operatorname{diag}\left[\hat{h}\left(\lambda_{k}\right)\right] \in \mathbb{R}^{N \times N}串塑,然后再求\operatorname{diag}\left[\hat{h}\left(\lambda_{k}\right)\right]\boldsymbol{f}矩陣乘法
  • 求上述結(jié)果的傅里葉逆變換战转,即左乘\mathbf{\Phi}

深度學(xué)習(xí)中的卷積就是要設(shè)計trainable的卷積核躬审,從上式可以看出枉证,就是要設(shè)計\operatorname{diag}\left[\hat{h}\left(\lambda_{1}\right), \ldots, \hat{h}\left(\lambda_{n}\right)\right]矮男,由此,可以直接將其變?yōu)榫矸e核\operatorname{diag}\left[\theta_{1}, \ldots, \theta_{n}\right]室谚,而不需要再將卷積核進行傅里葉變換毡鉴,由此,相當(dāng)于直接將變換后的參量進行學(xué)習(xí)

第一代GCN

\boldsymbol{y}_{\text {output}}=\sigma\left(\mathbf{\Phi} \boldsymbol{g}_{\theta} \mathbf{\Phi}^{T} \boldsymbol{x}\right)=\sigma\left(\boldsymbol{\Phi} \operatorname{diag}\left[\theta_{1}, \ldots, \theta_{n}\right] \mathbf{\Phi}^{T} \boldsymbol{x}\right)\tag{47}

其中秒赤,\boldsymbol{x}就是graph上對應(yīng)每個節(jié)點的feature構(gòu)成的向量,x=\left(x_{1}, x_{2}, \ldots, x_{n}\right)入篮,這里暫時對每個節(jié)點都使用標(biāo)量陈瘦,然后經(jīng)過激活之后,得到輸出\boldsymbol{y}_{\text {output}}潮售,之后傳入下一層

一些缺點:

  • 需要對拉普拉斯矩陣進行譜分解來求\mathbf{\Phi}痊项,在graph很大的時候復(fù)雜度很高。另外酥诽,還需要計算矩陣乘積鞍泉,復(fù)雜度為O(n^2)
  • 卷積核參數(shù)為n,當(dāng)graph很大的時候肮帐,n會很大
  • 卷積核的spatial localization不好
第二代GCN

圖傅里葉變換是關(guān)于特征值(相當(dāng)于普通傅里葉變換的頻率)的函數(shù)咖驮,也就是F\left(\lambda_{1}\right), \ldots, F\left(\lambda_{n}\right),即F(\mathbf{\Lambda})训枢,因此游沿,將卷積核\boldsymbol{g}_{\theta}寫成\boldsymbol{g}_{\theta}(\Lambda),然后肮砾,將\boldsymbol{g}_{\theta}(\Lambda)定義為如下k階多項式:
g_{\theta^{\prime}}(\mathbf{\Lambda}) \approx \sum_{k=0}^{K} \theta_{k}^{\prime} \mathbf{\Lambda}^{k}\tag{48}
將卷積公式帶入,可以得到:
g_{\theta^{\prime}}*x≈\Phi\sum_{k=0}^K\theta^{\prime}_k\mathbf{\Lambda}^{k}\Phi^Tx\\ =\sum_{k=0}^K\theta^{\prime}_k(\Phi\mathbf{\Lambda}^{k}\Phi^T)x\\ =\sum_{k=0}^K\theta^{\prime}_k(\Phi\mathbf{\Lambda}\Phi^T)^{k}x\\ =\sum_{k=0}^K\theta^{\prime}_kL^{k}x\tag{49}
這一代的GCN不需要做特征分解袋坑,可以直接對Laplacian矩陣做變換仗处,通過事先將Laplacian矩陣求出來,以及\boldsymbol{L}^{k}求出來枣宫,前向傳播的時候婆誓,可以直接使用,復(fù)雜度為O(Kn^2)

對于每一次Laplacian矩陣\boldsymbol{L}\mathbf{x}相乘也颤,對于節(jié)點n洋幻,相當(dāng)于從鄰居節(jié)點ne[n]傳遞一次信息給節(jié)點n,由于連續(xù)乘以了k次Laplacian矩陣翅娶,那么相當(dāng)于n節(jié)點的k-hop之內(nèi)的節(jié)點能夠傳遞信息給n文留,因此好唯,實際上只利用了節(jié)點的K-Localized信息

另外,可以使用切比雪夫展開式來近似\boldsymbol{L}^{k}燥翅,任何k次多項式都可以使用切比雪夫展開式來近似骑篙,由此,引入切比雪夫多項式的K階截斷獲得\boldsymbol{L}^{k}近似森书,從而獲得對g_{\theta}(\mathbf{\Lambda})的近似
g_{\theta^{\prime}}(\mathbf{\Lambda}) \approx \sum_{k=0}^{K} \theta_{k}^{\prime} T_{k}(\tilde{\mathbf{\Lambda}})\tag{50}
其中靶端,\tilde{\mathbf{\Lambda}}=\frac{2}{\lambda_{\max }} \mathbf{\Lambda}-\boldsymbol{I}_{n}\boldsymbol{\theta}^{\prime} \in \mathbb{R}^{K}為切比雪夫向量凛膏,\theta_{k}^{\prime}為第k個分量杨名,切比雪夫多項式T_{k}(x)使用遞歸的方式進行定義:T_{k}(x)=2 x T_{k-1}(x)-T_{k-2}(x),其中猖毫,T_{0}(x)=1, T_{1}(x)=x台谍,此時,帶入到卷積公式:
g_{\theta^{\prime}}*x≈\Phi\sum_{k=0}^K\theta^{\prime}_kT_k(\tilde{\mathbf{\Lambda}})\Phi^Tx\\ ≈\sum_{k=0}^K\theta^{\prime}_k\Big(\Phi T_k(\tilde{\mathbf{\Lambda}})\Phi^T\Big)x\\ =\sum_{k=0}^K\theta^{\prime}_k T_k(\tilde{\boldsymbol{L}}){x}\tag{51}
其中鄙麦,\tilde{\boldsymbol{L}}=\frac{2}{\lambda_{\max }} \boldsymbol{L}-\boldsymbol{I}_{n}典唇,因此,可以得到輸出為:
\boldsymbol{y}_{\text {output}}=\sigma\left(\sum_{k=0}^{K} \theta_{k}^{\prime} T_{k}(\tilde{\boldsymbol{L}}) \boldsymbol{x}\right)\tag{52}

第三代GCN

直接取切比雪夫多項式中K=1胯府,此時模型是1階近似介衔,將K=1\lambda_{\max }=2帶入可以得到:
g_{\theta^{\prime}} * x≈\theta_{0}^{\prime}x+\theta_{1}^{\prime}(\boldsymbol{L}-\boldsymbol{I}_{n})x\\ =\theta_{0}^{\prime}x+\theta_{1}^{\prime}(\boldsymbol{L}-\boldsymbol{I}_{n})x\\ =\theta_{0}^{\prime}x+\theta_{1}^{\prime})(\boldsymbol{D}^{-1 / 2} \boldsymbol{W} \boldsymbol{D}^{-1 / 2})x\tag{53}
其中骂因,歸一化拉普拉斯矩陣\boldsymbol{L}=\boldsymbol{D}^{-1 / 2}(\boldsymbol{D}-\boldsymbol{W}) \boldsymbol{D}^{-1 / 2}=\boldsymbol{I}_{n}-\boldsymbol{D}^{-1 / 2} \boldsymbol{W} \boldsymbol{D}^{-1 / 2}炎咖。為了進一步簡化,令\theta_{0}^{\prime}=-\theta_{1}^{\prime}寒波,此時只含有一個參數(shù)\theta
g_{\theta^{\prime}} * x=\theta\left(I_{n}+D^{-1 / 2} W D^{-1 / 2}\right)\tag{54}
由于\boldsymbol{I}_{n}+\boldsymbol{D}^{-1 / 2} \boldsymbol{W} \boldsymbol{D}^{-1 / 2}的譜半徑[ 0 , 2 ]太大乘盼,使用歸一化的trick:
\boldsymbol{I}_{n}+\boldsymbol{D}^{-1 / 2} \boldsymbol{W} \boldsymbol{D}^{-1 / 2} \rightarrow \tilde{\boldsymbol{D}}^{-1 / 2} \tilde{\boldsymbol{W}} \tilde{\boldsymbol{D}}^{-1 / 2}\tag{55}
其中,\tilde{\boldsymbol{W}}=\boldsymbol{W}+\boldsymbol{I}_{n}俄烁,\tilde{D}_{i j}=\Sigma_{j} \tilde{W}_{i j}

由此绸栅,帶入卷積公式
\underbrace{g_{\theta^{\prime}} * x}_{\mathbb{R}^{n \times 1}}=\theta\left(\underbrace{\tilde{D}^{-1 / 2} \tilde{W} \tilde{D}^{-1 / 2}}_{\mathbb{R}^{n \times n}}\right) \underbrace{x}_{\mathbb{R}^{n \times 1}}\tag{56}
如果推廣到多通道,相當(dāng)于每一個節(jié)點的信息是向量
x \in \mathbb{R}^{N \times 1} \rightarrow X \in \mathbb{R}^{N \times C}
其中页屠,N是節(jié)點數(shù)量粹胯,C是通道數(shù),或者稱作表示節(jié)點的信息維度數(shù)辰企。\mathbf{X}是節(jié)點的特征矩陣风纠。相應(yīng)的卷積核參數(shù)變化:
\theta \in \mathbb{R} \rightarrow \Theta \in \mathbb{R}^{C \times F}\tag{57}
其中,F為卷積核數(shù)量牢贸,那么卷積結(jié)果寫成矩陣形式為:
\underbrace{Z}_{\mathbb{R}^{N \times F}}=\underbrace{\tilde{D}^{-1 / 2} \tilde{W} \tilde{D}^{-1 / 2}}_{\mathbb{R}^{N \times N}} \underbrace{X}_{\mathbb{R}^{N \times C}} \underbrace{\mathbf{\Theta}}_{\mathbb{R}^{C \times F}}\tag{58}
上述操作可以疊加多層竹观,對上述輸出激活一下,就可以作為下一層節(jié)點的特征矩陣

一些特點:

  • K=1,相當(dāng)于直接取鄰域信息臭增,類似于3\times{3}的卷積核
  • 由于卷積核寬度減小懂酱,可以通過增加卷積層數(shù)來擴大感受野,從而增強網(wǎng)絡(luò)的表達(dá)能力
  • 增加了參數(shù)約束速址,比如\lambda_{\max } \approx 2玩焰,引入歸一化操作
5.2.2 論文模型

論文采用兩層的GCN,用來在graph上進行半監(jiān)督的節(jié)點分類任務(wù)芍锚,鄰接矩陣為A昔园,首先計算出\hat{A}=\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}},由此并炮,前向網(wǎng)絡(luò)模型形式如下:
Z=f(X, A)=\operatorname{softmax}\left(\hat{A} \operatorname{ReLU}\left(\hat{A} X W^{(0)}\right) W^{(1)}\right)\tag{59}
其中默刚,W^{(0)} \in \mathbb{R}^{C \times H}為輸入層到隱藏層的權(quán)重矩陣,隱藏層的特征維度為H逃魄,W^{(1)} \in \mathbb{R}^{H \times F}為隱藏層到輸出層的權(quán)重矩陣荤西,softmax激活函數(shù)定義為\operatorname{softmax}\left(x_{i}\right)=\frac{1}{\mathcal{Z}} \exp \left(x_{i}\right)\mathcal{Z}=\sum_{i} \exp \left(x_{i}\right)伍俘,相當(dāng)于對每一列做softmax邪锌,由此得到交叉熵?fù)p失函數(shù)為:
\mathcal{L}=-\sum_{l \in \mathcal{Y}_{L}} \sum_{f=1}^{F} Y_{l f} \ln Z_{l f}\tag{60}
其中,\mathcal{Y}_{L}為帶有標(biāo)簽的節(jié)點集合

5.2.3 代碼實現(xiàn)
import torch_geometric.nn as gnn
class GCN(nn.Module):
    def __init__(self, config, in_channels, out_channels):
        '''
            in_channels : num of node features
            out_channels: num of class
        '''
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        self.dropout_rate = config.dropout_rate
        self.conv1 = gnn.GCNConv(in_channels, self.hidden_dim, improved = False, cached=True, bias=True, normalize=True)
        self.conv2 = gnn.GCNConv(self.hidden_dim, out_channels, improved = False, cached=True, bias=True, normalize=True)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        #x = F.dropout(x, p=self.dropout_rate) # If no drop out, accuracy 0.75 --> 0.80
        x = self.conv2(x, edge_index)
        #x = F.dropout(x, p=self.dropout_rate) # there are two dropout.. But performance bad.
        return x  

5.3 《Diffusion-Convolutional Neural Networks》

該模型對每一個節(jié)點(或邊癌瘾、或圖)采用H個hop的矩陣進行表示觅丰,每一個hop都表示該鄰近范圍的鄰近信息,由此妨退,對于局部信息的獲取效果比較好妇萄,得到的節(jié)點的representation的表示能力很強

5.3.1 任務(wù)定義
  • 一個graph數(shù)據(jù)集\mathcal{G}=\left\{G_{t} | t \in 1 \ldots T\right\}
  • graph定義為G_{t}=\left(V_{t}, E_{t}\right),其中咬荷,V_t為節(jié)點集合冠句,E_t為邊集合
  • 所有節(jié)點的特征矩陣定義為X_t,大小為N_t\times{F}幸乒,其中懦底,N_t為圖G_t的節(jié)點個數(shù),F為節(jié)點特征維度
  • 邊信息E_t定義為N_t\times{}N_t的鄰接矩陣A_t罕扎,由此可以計算出節(jié)點度(degree)歸一化的轉(zhuǎn)移概率矩陣P_t基茵,表示從i節(jié)點轉(zhuǎn)移到j節(jié)點的概率

模型的目標(biāo)為預(yù)測Y,也就是預(yù)測每一個圖的節(jié)點標(biāo)簽壳影,或者邊的標(biāo)簽,或者每一個圖的標(biāo)簽弥臼,在每一種情況中宴咧,模型輸入部分帶有標(biāo)簽的數(shù)據(jù)集合,然后預(yù)測剩下的數(shù)據(jù)的標(biāo)簽径缅。DCNN模型輸入圖\mathcal{G}掺栅,返回硬分類預(yù)測值Y或者條件分布概率\mathbb{P}(Y|X)烙肺。該模型將每一個預(yù)測的目標(biāo)對象(節(jié)點、邊或圖)轉(zhuǎn)化為一個diffusion-convolutional representation氧卧,大小為H\times{}F桃笙,H表示擴散的hops,表示為Z_t

  • 對于節(jié)點分類任務(wù)沙绝,表示為Z_t為大小為N_t\times{H}\times{F}的矩陣
  • 對于圖分類任務(wù)搏明,張量Z_t為大小為H\times{F}的矩陣
  • 對于邊分類任務(wù),張量Z_t為大小為M_t\times{H}\times{F}的矩陣
5.3.2 論文模型
  1. 對于節(jié)點分類任務(wù)闪檬,假設(shè)P_t^*P_t的power series星著,大小為N_t\times{H}\times{N_t},那么對于圖t的節(jié)點i粗悯,第j個hop虚循,第k維特征值Z_{tijk}計算公式為:
    Z_{t i j k}=f\left(W_{j k}^{c} \cdot \sum_{l=1}^{N_{t}} P_{t i j l}^{*} X_{t l k}\right)\tag{61}
    使用矩陣表示為:
    Z_{t}=f\left(W^{c} \odot P_{t}^{*} X_{t}\right)\tag{62}
    其中\odot表示element-wise multiplication,由于模型只考慮H跳的參數(shù)样傍,即參數(shù)量為O(H\times{F})横缔,使得diffusion-convolutional representation不受輸入大小的限制

    在計算出Z之后,過一層全連接得到輸出Y衫哥,使用\hat{Y}表示硬分類預(yù)測結(jié)果茎刚,使用\mathbb{P}(Y|X)表示預(yù)測概率,計算方式如下:
    \hat{Y}=\arg \max \left(f\left(W^7jp1bfv \odot Z\right)\right)\tag{63}\\ \mathbb{P}(Y | X)=\operatorname{softmax}\left(f\left(W^ogj4sra \odot Z\right)\right)

  2. 對于圖分類任務(wù)炕檩,直接采用所有節(jié)點表示的均值作為graph的representation
    Z_{t}=f\left(W^{c} \odot 1_{N_{t}}^{T} P_{t}^{*} X_{t} / N_{t}\right)\tag{64}
    其中斗蒋,1_{N_t}是全為1的N_t\times{1}的向量

  3. 對于邊分類任務(wù),通過將每一條邊轉(zhuǎn)化為一個節(jié)點來進行訓(xùn)練和預(yù)測笛质,這個節(jié)點與原來的邊對應(yīng)的首尾節(jié)點相連泉沾,轉(zhuǎn)化后的圖的鄰接矩陣A_t'可以直接從原來的鄰接矩陣A_t增加一個incidence matrix得到:
    A_{t}^{\prime}= \begin{matrix} A_t & B_t^T\\ B_t & 0\\ \end{matrix} \tag{65}
    之后,使用A_t'來計算P_t'妇押,并用來替換P_t來進行分類跷究,對于模型訓(xùn)練,使用梯度下降法敲霍,并采用early-stop方式得到最終模型

5.3.3 代碼實現(xiàn)
import lasagne
import lasagne.layers
import theano
import theano.tensor as T
import numpy as np
class DCNNLayer(lasagne.layers.MergeLayer):
    """A node-level DCNN layer.
    This class contains the (symbolic) Lasagne internals for a node-level DCNN layer.  This class should
    be used in conjunction with a user-facing model class.
    """
    def __init__(self, incomings, parameters, layer_num,
                 W=lasagne.init.Normal(0.01),
                 num_features=None,
                 **kwargs):
        super(DCNNLayer, self).__init__(incomings, **kwargs)
        self.parameters = parameters
        if num_features is None:
            self.num_features = self.parameters.num_features
        else:
            self.num_features = num_features
        self.W = T.addbroadcast(
            self.add_param(W,
                           (1, parameters.num_hops + 1, self.num_features), name='DCNN_W_%d' % layer_num), 0)
        self.nonlinearity = params.nonlinearity_map[self.parameters.dcnn_nonlinearity]
    def get_output_for(self, inputs, **kwargs):
        """Compute diffusion convolutional activation of inputs."""
        Apow = inputs[0]
        X = inputs[1]
        Apow_dot_X = T.dot(Apow, X)
        Apow_dot_X_times_W = Apow_dot_X * self.W
        out = self.nonlinearity(Apow_dot_X_times_W)
        return out
    def get_output_shape_for(self, input_shapes):
        """Return the layer output shape."""
        shape = (None, self.parameters.num_hops + 1, self.num_features)
        return shape

5.4 《Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks》

將序列型的LSTM模型擴展到樹型的LSTM模型俊马,簡稱Tree-LSTM,并根據(jù)孩子節(jié)點是否有序肩杈,論文提出了兩個模型變體柴我,Child-Sum Tree-LSTM模型和N-ary Tree-LSTM模型。和序列型的LSTM模型的主要不同點在于扩然,序列型的LSTM從前一時刻獲取隱藏狀態(tài)h_t艘儒,而樹型的LSTM從其所有的孩子節(jié)點獲取隱藏狀態(tài)

5.4.1 論文模型

Tree-LSTM模型對于每一個孩子節(jié)點都會產(chǎn)生一個遺忘門¥f_{jk}¥,這個使得模型能夠從所有的孩子節(jié)點選擇性地獲取信息和結(jié)合信息

  1. Child-Sum Tree-LSTMs
    該模型的更新方程如下:
    \widetilde{h_j}=\sum_{k\in C(j)}h_k\\ i_j=\sigma({W}^{(i)}x_j+{U}^{(i)}\widetilde{h_j}+b^{(i)})\\ f_{ik}=\sigma({W}^{(f)}x_j+{U}^{(f)}{h_k}+b^{(f)})\\ o_{j}=\sigma({W}^{(o)}x_j+{U}^{(o)}\widetilde{h_j}+b^{(o)})\\ u_{j}=tanh({W}^{(u)}x_j+{U}^{(u)}\widetilde{h_j}+b^{(u)})\\ c_j=i_j⊙u_j+\sum_{k\in C(j)}f_{ij}⊙c_k\\ h_j=o_j⊙tanh(c_j)\tag{66}
    其中,C(j)表示j節(jié)點的鄰居節(jié)點的個數(shù)界睁,h_k表示節(jié)點k的隱藏狀態(tài)觉增,i_j表示節(jié)點j輸入門f_{jk}表示節(jié)點j的鄰居節(jié)點k遺忘門翻斟,o_j表示節(jié)點j輸出門

    這里的關(guān)鍵點在于第三個公式的f_{jk}薯嗤,這個模型對節(jié)點j的每個鄰居節(jié)點k都計算了對應(yīng)的遺忘門向量哼凯,然后在第六行中計算c_j時對鄰居節(jié)點的信息進行遺忘組合

    由于該模型是對所有的孩子節(jié)點求和,所以這個模型對于節(jié)點順序不敏感的,適合于孩子節(jié)點無序的情況

  2. N-ary Tree-LSTMs
    假如一個樹的最大分支數(shù)為N(即孩子節(jié)點最多為N個)同规,而且孩子節(jié)點是有序的盏浇,對于節(jié)點j嬉愧,對于該節(jié)點的第k個孩子節(jié)點的隱藏狀態(tài)和記憶單元分別用h_{jk}c_{jk}表示曼氛,模型的方程如下:
    i_j=\sigma({W}^{(i)}x_j+\sum^N_{?=1}{U}_?^{(i)}{h_{j?}}+b^{(i)})\\ f_{ik}=\sigma({W}^{(f)}x_j+\sum^N_{?=1}{U}_{k?}^{(f)}{h_{j?}}+b^{(f)})\\ o_{j}=\sigma({W}^{(o)}x_j+\sum^N_{?=1}{U}_?^{(o)}{h_{j?}}+b^{(o)})\\ u_{j}=tanh({W}^{(u)}x_j+\sum^N_{?=1}{U}_?^{(a)}{h_{j?}}+b^{(a)})\\ c_j=i_j⊙u_j+\sum^N_{?=1}f_{j?}⊙c_{j?}\\ h_j=o_j⊙tanh(c_j)\tag{67}
    該模型為每個孩子節(jié)點都單獨地設(shè)置了參數(shù)U_{l}

5.4.2 訓(xùn)練策略
  • 分類任務(wù):

分類任務(wù)定義為在類別集\mathcal{Y}中預(yù)測出正確的標(biāo)簽\hat{y},對于每一個節(jié)點j阳柔,使用一個softmax分類器來預(yù)測節(jié)點標(biāo)簽\hat{y}_j焰枢,分類器取每個節(jié)點的隱藏狀態(tài)h_j作為輸入:
\hat p_\theta(y|\{x\}_j)=softmax(W^{(s)}h_j+b^{(s)})\\ \hat y_j=\underset{y}{\operatorname{argmax}} \hat p_\theta(y|\{x\}_j)\tag{68}
損失函數(shù)使用negative log-likelihood
J(\theta)=-\frac{1}{m} \sum_{k=1}^{m} \log \hat{p}_{\theta}\left(y^{(k)} |\{x\}^{(k)}\right)+\frac{\lambda}{2}\|\theta\|_{2}^{2}\tag{69}

其中,m是帶有標(biāo)簽的節(jié)點數(shù)量舌剂,\lambdaL2是正則化超參

  • 語義相關(guān)性任務(wù):

該任務(wù)給定一個句子對(sentence pair)济锄,模型需要預(yù)測出一個范圍在[ 1 , K ]之間的實數(shù)值,這個值越高霍转,表示相似度越高荐绝。首先對每一個句子產(chǎn)生一個representation,兩個句子的表示分別用h_Lh_R

表示避消,得到這兩個representation之后低滩,從distance和angle兩個方面考慮,使用神經(jīng)網(wǎng)絡(luò)來得到(h_L,h_R)相似度:
h_× = h_L⊙h_R\\ h_+ = ∣h_L?h_R∣\\ h_s = σ(W^{(×)}h_×+W^{(+)}h_+ +b^{(h)})\\ \hat p_θ=softmax?(W^{(p)}h_s+b^{(p)})\\ \hat y=r^T\hat p_θ\tag{70}
其中岩喷,r^{T}=\left[1,2…K\right]恕沫,模型期望根據(jù)訓(xùn)練得到的參數(shù)\theta得到的結(jié)果:\hat{y}=r^{T} \hat{p}_{\theta} \approx y。由此纱意,定義一個目標(biāo)分布p
y= \begin{cases} y??y?, \ \ \ \ \ \ \ \quad i=?y?+1\\ ?y??y+1, \quad i=?y?\\ 0 \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \quad otherwise \end{cases} \tag{71}
其中婶溯,1\le{i}\le{K},損失函數(shù)為p\hat{p}_{\theta}之間的KL散度:
J(\theta)=\frac{1}{m} \sum_{k=1}^{m} \mathrm{KL}\left(p^{(k)} \| \hat{p}_{\theta}^{(k)}\right)+\frac{\lambda}{2}\|\theta\|_{2}^{2}\tag{72}

5.4.3 代碼實現(xiàn)
class TreeLSTM(torch.nn.Module):
    '''PyTorch TreeLSTM model that implements efficient batching.
    '''
    def __init__(self, in_features, out_features):
        '''TreeLSTM class initializer
        Takes in int sizes of in_features and out_features and sets up model Linear network layers.
        '''
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # bias terms are only on the W layers for efficiency
        self.W_iou = torch.nn.Linear(self.in_features, 3 * self.out_features)
        self.U_iou = torch.nn.Linear(self.out_features, 3 * self.out_features, bias=False)
        # f terms are maintained seperate from the iou terms because they involve sums over child nodes
        # while the iou terms do not
        self.W_f = torch.nn.Linear(self.in_features, self.out_features)
        self.U_f = torch.nn.Linear(self.out_features, self.out_features, bias=False)
    def forward(self, features, node_order, adjacency_list, edge_order):
        '''Run TreeLSTM model on a tree data structure with node features
        Takes Tensors encoding node features, a tree node adjacency_list, and the order in which 
        the tree processing should proceed in node_order and edge_order.
        '''
        # Total number of nodes in every tree in the batch
        batch_size = node_order.shape[0]
        # Retrive device the model is currently loaded on to generate h, c, and h_sum result buffers
        device = next(self.parameters()).device
        # h and c states for every node in the batch
        h = torch.zeros(batch_size, self.out_features, device=device)
        c = torch.zeros(batch_size, self.out_features, device=device)
        # populate the h and c states respecting computation order
        for n in range(node_order.max() + 1):
            self._run_lstm(n, h, c, features, node_order, adjacency_list, edge_order)
        return h, c
    def _run_lstm(self, iteration, h, c, features, node_order, adjacency_list, edge_order):
        '''Helper function to evaluate all tree nodes currently able to be evaluated.
        '''
        # N is the number of nodes in the tree
        # n is the number of nodes to be evaluated on in the current iteration
        # E is the number of edges in the tree
        # e is the number of edges to be evaluated on in the current iteration
        # F is the number of features in each node
        # M is the number of hidden neurons in the network
        # node_order is a tensor of size N x 1
        # edge_order is a tensor of size E x 1
        # features is a tensor of size N x F
        # adjacency_list is a tensor of size E x 2
        # node_mask is a tensor of size N x 1
        node_mask = node_order == iteration
        # edge_mask is a tensor of size E x 1
        edge_mask = edge_order == iteration
        # x is a tensor of size n x F
        x = features[node_mask, :]
        # At iteration 0 none of the nodes should have children
        # Otherwise, select the child nodes needed for current iteration
        # and sum over their hidden states
        if iteration == 0:
            iou = self.W_iou(x)
        else:
            # adjacency_list is a tensor of size e x 2
            adjacency_list = adjacency_list[edge_mask, :]
            # parent_indexes and child_indexes are tensors of size e x 1
            # parent_indexes and child_indexes contain the integer indexes needed to index into
            # the feature and hidden state arrays to retrieve the data for those parent/child nodes.
            parent_indexes = adjacency_list[:, 0]
            child_indexes = adjacency_list[:, 1]
            # child_h and child_c are tensors of size e x 1
            child_h = h[child_indexes, :]
            child_c = c[child_indexes, :]
            # Add child hidden states to parent offset locations
            _, child_counts = torch.unique_consecutive(parent_indexes, return_counts=True)
            child_counts = tuple(child_counts)
            parent_children = torch.split(child_h, child_counts)
            parent_list = [item.sum(0) for item in parent_children]
            h_sum = torch.stack(parent_list)
            iou = self.W_iou(x) + self.U_iou(h_sum)
        # i, o and u are tensors of size n x M
        i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
        i = torch.sigmoid(i)
        o = torch.sigmoid(o)
        u = torch.tanh(u)
        # At iteration 0 none of the nodes should have children
        # Otherwise, calculate the forget states for each parent node and child node
        # and sum over the child memory cell states
        if iteration == 0:
            c[node_mask, :] = i * u
        else:
            # f is a tensor of size e x M
            f = self.W_f(features[parent_indexes, :]) + self.U_f(child_h)
            f = torch.sigmoid(f)
            # fc is a tensor of size e x M
            fc = f * child_c
            # Add the calculated f values to the parent's memory cell state
            parent_children = torch.split(fc, child_counts)
            parent_list = [item.sum(0) for item in parent_children]
            c_sum = torch.stack(parent_list)
            c[node_mask, :] = i * u + c_sum
        h[node_mask, :] = o * torch.tanh(c[node_mask])
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市信轿,隨后出現(xiàn)的幾起案子定罢,更是在濱河造成了極大的恐慌酬凳,老刑警劉巖,帶你破解...
    沈念sama閱讀 221,548評論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件满俗,死亡現(xiàn)場離奇詭異谤辜,居然都是意外死亡脯倚,警方通過查閱死者的電腦和手機渔彰,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,497評論 3 399
  • 文/潘曉璐 我一進店門嵌屎,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人恍涂,你說我怎么就攤上這事宝惰。” “怎么了再沧?”我有些...
    開封第一講書人閱讀 167,990評論 0 360
  • 文/不壞的土叔 我叫張陵尼夺,是天一觀的道長。 經(jīng)常有香客問我炒瘸,道長淤堵,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 59,618評論 1 296
  • 正文 為了忘掉前任顷扩,我火速辦了婚禮拐邪,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘隘截。我一直安慰自己扎阶,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 68,618評論 6 397
  • 文/花漫 我一把揭開白布技俐。 她就那樣靜靜地躺著乘陪,像睡著了一般。 火紅的嫁衣襯著肌膚如雪雕擂。 梳的紋絲不亂的頭發(fā)上啡邑,一...
    開封第一講書人閱讀 52,246評論 1 308
  • 那天,我揣著相機與錄音井赌,去河邊找鬼谤逼。 笑死,一個胖子當(dāng)著我的面吹牛仇穗,可吹牛的內(nèi)容都是我干的流部。 我是一名探鬼主播,決...
    沈念sama閱讀 40,819評論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼纹坐,長吁一口氣:“原來是場噩夢啊……” “哼枝冀!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起耘子,我...
    開封第一講書人閱讀 39,725評論 0 276
  • 序言:老撾萬榮一對情侶失蹤果漾,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后谷誓,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體绒障,經(jīng)...
    沈念sama閱讀 46,268評論 1 320
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,356評論 3 340
  • 正文 我和宋清朗相戀三年捍歪,在試婚紗的時候發(fā)現(xiàn)自己被綠了户辱。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片鸵钝。...
    茶點故事閱讀 40,488評論 1 352
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖庐镐,靈堂內(nèi)的尸體忽然破棺而出恩商,到底是詐尸還是另有隱情,我是刑警寧澤必逆,帶...
    沈念sama閱讀 36,181評論 5 350
  • 正文 年R本政府宣布痕届,位于F島的核電站,受9級特大地震影響末患,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜锤窑,卻給世界環(huán)境...
    茶點故事閱讀 41,862評論 3 333
  • 文/蒙蒙 一璧针、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧渊啰,春花似錦探橱、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,331評論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至嚷那,卻和暖如春胞枕,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背魏宽。 一陣腳步聲響...
    開封第一講書人閱讀 33,445評論 1 272
  • 我被黑心中介騙來泰國打工腐泻, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人队询。 一個月前我還...
    沈念sama閱讀 48,897評論 3 376
  • 正文 我出身青樓派桩,卻偏偏與公主長得像,于是被迫代替她去往敵國和親蚌斩。 傳聞我的和親對象是個殘疾皇子铆惑,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,500評論 2 359

推薦閱讀更多精彩內(nèi)容