Transformer系列:Multi-Head Attention網(wǎng)絡(luò)結(jié)構(gòu)和代碼解析

關(guān)鍵詞:Transfomer惠况,self attention

Transformer Self Attention的作用

Transformer引入Self Attention解決NLP任務(wù),相比于傳統(tǒng)的TextCNN, LSTM等模型擁有以下優(yōu)勢

    1. 解決了傳統(tǒng)的RNN無法并行的問題合冀,RNN是自回歸模型怖喻,下一個RNN單元的計算依賴上一個RNN單元的計算結(jié)果,Transformer采用Self Attention每個句子的字/詞可以同時輸入獨立計算
    1. Transformer能夠觀察整個句子的每個元素進行語義理解宪哩,而TextCNN采用一定尺寸的卷積核只能觀察局部上下文,只能通過增加卷積層數(shù)來處理這種長距離的元素依賴
    1. Transformer每一個字/詞不僅包含了自身的embedding信息第晰,還自適應(yīng)地融合和整個句子的上下文的信息锁孟,可以實現(xiàn)相同的字/詞在不同上下文語境下不同表達,尤其擅長對有強語義關(guān)系的數(shù)據(jù)進行建模

Self Attention簡介

Self Attention就是自身和自身進行Attention茁瘦,具體為句子內(nèi)部的每個字/詞之間進行通信品抽,計算出句子中每個字/詞和其中一個目標(biāo)字/詞的注意力權(quán)重,從而得到目標(biāo)字/詞的embedding表征


self attention示意圖

在Transformer中Self Attention采用Scaled Dot-Product Attention(縮放點積注意力)甜熔,采用向量內(nèi)積計算兩兩字/詞的相似度桑包,相似度越大注意力權(quán)重越大,融合這個詞的信息越多


Multi-Head Attention網(wǎng)絡(luò)結(jié)構(gòu)解析

Transformer采用多頭注意力機制纺非,模型網(wǎng)絡(luò)結(jié)構(gòu)如下


Multi-Head Attention

其中h表示頭的個數(shù),每個頭都包含單獨的一個縮放點積注意力以及注意力前的線性映射層赘方,多個頭的結(jié)果concat烧颖,輸入到最后的全連接映射層,縮放點積注意力網(wǎng)絡(luò)結(jié)構(gòu)如下


Scaled Dot-Product Attention

(1) Multi-Head Attention流程

Transformer的Multi-Head Attention包含5個步驟:

  • 1.點乘: 計算Query矩陣Q窄陡、Key矩陣K的乘積炕淮,得到得分矩陣scores
  • 2.縮放: 對得分矩陣scores進行縮放,即將其除以向量維度的平方根(np.sqrt(d_k))
  • 3.mask: 若存在Attention Mask跳夭,則將Attention Mask的值為True的位置對應(yīng)的得分矩陣元素置為負無窮(-inf)
  • 4.softmax: 對得分矩陣scores進行softmax計算涂圆,得到Attention權(quán)重矩陣attn
  • 5.加權(quán)求和: 計算Value矩陣V和Attention權(quán)重矩陣attn的乘積们镜,得到加權(quán)后的Context矩陣

(2) 為啥Q,K,V線性變換

Q,K润歉,V是三個矩陣模狭,對原始的輸入句子的embedding做線性映射(wx+b,沒有激活函數(shù))踩衩,其中Q和K映射后的新矩陣負責(zé)計算相似度嚼鹉,V映射的矩陣負責(zé)和相似度進行加權(quán)求和。在Transformer的decoder層驱富,Q锚赤,K,V對同一個句子進行三次不同的映射褐鸥,目的是提升原始embedding表達的豐富度线脚,如果有多個頭,就有多少套Q叫榕,K浑侥,V矩陣,他們之間不共享翠霍。
如果不引入Q锭吨,K而選擇直接對原始的embedding做self attention,則計算的相似度是個上三角和下三角對稱

字和字的點積結(jié)果

另外如果不引入Q寒匙,K零如,則對角向上的值一定是最大的,因為同一個字相同的embedding是完全重合的锄弱,每個字/詞必定最關(guān)心自己考蕾,這是模型不想看到的,因此要引入Q会宪,K肖卧。而引入V矩陣主要是提升原始embedding的表達能力


(3) 為啥要帶有縮放的Scaled Dot-Product Attention

Scaled是縮放的意思,表現(xiàn)在在點乘之后除以一個分母根號下K向量的維度


Attention公式

引入這個分母的作用的防止在Softmax計算中值和值存在過大的差異掸鹅,導(dǎo)致計算結(jié)果為OneHot導(dǎo)致梯度消失塞帐。
容易理解除以分母之后整個點乘的結(jié)果會變小,可以緩解值和值之間的差異大小巍沙,而為什么是除以根號下K向量維度(K,V,Q三個向量維度一樣)葵姥,原因是除以根號下K維度后數(shù)據(jù)的分布期望和原來一致。舉例假設(shè)key和query服從均值為0句携,方差為1的均勻分布, 即D(query)=D(key)=1, 維度大小為64榔幸,那么點積后的,我們可以計算他的方差變化

方差變化

因此所有計算出的點積值都除以根號下64似的最終的結(jié)果還是符合均值0方差1的分布。
在計算Attention的時候多種策略比如第一種以全連接計算相似性比如GAT中所使用,和第二種類似Transformer的向量內(nèi)積

兩種注意力權(quán)重計算方式

其中由于第一種有全連接參數(shù)進行學(xué)習(xí)削咆,還有tanh激活函數(shù)壓縮牍疏,到Softmax的輸入是可控的,而第二種隨著向量維度的增大拨齐,點乘結(jié)果的上限越來越高鳞陨,點乘結(jié)果的差異越來越大,因此采用第二種計算Attention權(quán)重需要加入scaled


(4) 為啥要多頭

多個頭的結(jié)果拼接融合奏黑,提升特征表征和泛化能力


tensorflow代碼實現(xiàn)

代碼參考attention-is-all-you-need-keras
作者基于tensorflow2和tf.keras炊邦,關(guān)于Self Attention的代碼在MultiHeadAttention類

class MultiHeadAttention():
    # mode 0 - big martixes, faster; mode 1 - more clear implementation
    def __init__(self, n_head, d_model, dropout, mode=0):
        self.mode = mode
        self.n_head = n_head  # 8
        # k的維度,v的維度, q的維度和k一致,因為k,q要計算內(nèi)積,256/8
        self.d_k = self.d_v = d_k = d_v = d_model // n_head  # 32, d_model為詞向量的emb維度
        self.dropout = dropout
        if mode == 0:
            # q,k,v => [None, seq_len, 256]
            # 這個是大矩陣的方案熟史,這個快馁害,這個256包含所有頭的線性變換參數(shù)w,沒有激活函數(shù),8個頭統(tǒng)一在一個大矩陣進行線性變換
            self.qs_layer = Dense(n_head * d_k, use_bias=False)
            self.ks_layer = Dense(n_head * d_k, use_bias=False)
            self.vs_layer = Dense(n_head * d_v, use_bias=False)
        elif mode == 1:
            self.qs_layers = []
            self.ks_layers = []
            self.vs_layers = []
            for _ in range(n_head):  # 8個頭
                # 保證每個頭dense之后的結(jié)果拼接和d_model一致
                self.qs_layers.append(TimeDistributed(Dense(d_k, use_bias=False)))
                self.ks_layers.append(TimeDistributed(Dense(d_k, use_bias=False)))
                self.vs_layers.append(TimeDistributed(Dense(d_v, use_bias=False)))
        # 縮放點積注意力
        self.attention = ScaledDotProductAttention()
        # TimeDistributed這個實際上就是一個全連接
        self.w_o = TimeDistributed(Dense(d_model))
        # self.w_o = Dense(d_model)

    def __call__(self, q, k, v, mask=None):
        # 在encoder,q=enc_input,k=enc_input,v=enc_input
        # 在decoder的第一層,q=dec_input, k=dec_last_state, v=dec_last_state
        # 在decoder的第二層,q=decoder第一層的輸出, k=enc_output, v=enc_output
        d_k, d_v = self.d_k, self.d_v
        n_head = self.n_head

        if self.mode == 0:
            # [None, seq_len, 256] => [None, seq_len, 256]
            qs = self.qs_layer(q)  # [batch_size, len_q, n_head*d_k]
            ks = self.ks_layer(k)
            vs = self.vs_layer(v)

            def reshape1(x):
                s = tf.shape(x)  # [batch_size, len_q, n_head * d_k]
                # [None, seq_len, 8, 32]
                x = tf.reshape(x, [s[0], s[1], n_head, s[2] // n_head])
                # [8, None, seq_len, 32]
                x = tf.transpose(x, [2, 0, 1, 3])
                # 連續(xù)的8個都是同一個原始語句的
                # [8 * batch_size, seq_len, 32]
                x = tf.reshape(x, [-1, s[1], s[2] // n_head])  # [n_head * batch_size, len_q, d_k]
                return x

            # 相當(dāng)于將for循環(huán)頭拼接,轉(zhuǎn)化為將for循環(huán)放到batch_size里面再整合最后的結(jié)果
            qs = Lambda(reshape1)(qs)  # [batch_size, seq_len, 256] => [8 * batch_size, seq_len, 32]
            ks = Lambda(reshape1)(ks)
            vs = Lambda(reshape1)(vs)

            if mask is not None:
                mask = Lambda(lambda x: K.repeat_elements(x, n_head, 0))(mask)
            # head是注意力的輸出蹂匹,attn是注意力權(quán)重
            # 如果是大矩陣 [8 * batch_size, seq_len, 32]
            head, attn = self.attention(qs, ks, vs, mask=mask)

            def reshape2(x):
                # 對結(jié)果再做整理
                s = tf.shape(x)  # [n_head * batch_size, len_v, d_v]
                # [8, batch_size, seq_len, 32]
                x = tf.reshape(x, [n_head, -1, s[1], s[2]])
                # [batch_size, seq_len, 8, 32]
                x = tf.transpose(x, [1, 2, 0, 3])
                # [batch_size, seq_len, 8 * 32]
                x = tf.reshape(x, [-1, s[1], n_head * d_v])  # [batch_size, len_v, n_head * d_v]
                return x

            head = Lambda(reshape2)(head)
        elif self.mode == 1:
            # 每個頭的結(jié)果
            heads = []
            # 每個頭的注意力權(quán)重
            attns = []
            for i in range(n_head):
                # 拿到對應(yīng)下標(biāo)的網(wǎng)絡(luò)
                qs = self.qs_layers[i](q)  # q線性變換  [None, None, 256] => [None, None, 32]
                ks = self.ks_layers[i](k)  # k線性變換  [None, None, 256] => [None, None, 32]
                vs = self.vs_layers[i](v)  # v線性變換  [None, None, 256] => [None, None, 32]
                head, attn = self.attention(qs, ks, vs, mask)
                heads.append(head)
                attns.append(attn)
            # concat [[None, seq_len, 32], [None, seq_len, 32 ...]], Concatenate默認axis=-1,最里面一維合并
            # [None, seq_len, 32 * 8] = [None, seq_len, 256], 最終子注意力產(chǎn)出每個詞維度emb是256,和原始的emb維度是一致的
            head = Concatenate()(heads) if n_head > 1 else heads[0]
            attn = Concatenate()(attns) if n_head > 1 else attns[0]

        # 加權(quán)求和的結(jié)果在做一層全連接,[None, None, 256] => [None. None, 256]
        outputs = self.w_o(head)
        outputs = Dropout(self.dropout)(outputs)
        return outputs, attn

以詞的embedding是256為例碘菜,其中調(diào)用該類的目的是使得輸入[batch_size, seq_len, 256]注意力映射為[batch_size, seq_len, 256]的新向量,其中第2位置上的256是8個頭的拼接的結(jié)果限寞,每個頭的embedding維度是32忍啸。
其中有兩種模式有mode參數(shù)控制,默認mode=0走大矩陣方式履植,該種方式將8個注意頭全部平鋪在三維輸入矩陣的第0維batch_size上计雌,一起進行點乘操作,結(jié)果在通過reshape和轉(zhuǎn)置整理為8個頭在第2維上的拼接玫霎,這種方式計算快凿滤。
第二種mode=1是傳統(tǒng)的for循環(huán)一個一個計算頭,再將結(jié)果列表進行concat庶近,代碼上更清晰一點翁脆。
其中點乘計算相似度的ScaledDotProductAttention如下

class ScaledDotProductAttention():
    def __init__(self, attn_dropout=0.1):
        self.dropout = Dropout(attn_dropout)

    def __call__(self, q, k, v, mask):  # mask_k or mask_qk
        # 根號32,向量維度平方根,np.sqrt(d_k)
        # 如果是大矩陣的話,還是32
        temper = tf.sqrt(tf.cast(tf.shape(k)[-1], dtype='float32'))
        # 計算點乘 [None, None, 32] * [None, None, 32]
        # 這個K.batch_dot就是batch0位置不動鼻种,1和2位置點乘,相當(dāng)于tf.matmul(q, tf.transpose(k, [0, 2, 1]))
        # 每個句子內(nèi)部反番,每個字和其他字計算一個內(nèi)積[None, seq_len, 32] * [None, seq_len, 32] => [None, seq_len, seq_len]
        attn = Lambda(lambda x: K.batch_dot(x[0], x[1], axes=[2, 2]) / x[2])([q, k, temper])  # shape=(batch, q, k)
        if mask is not None:
            # K.cast(K.greater(src_seq, 0), 'float32') pad=0,非pad=1
            # 將<pad>的置為一個極負的數(shù),使地softmax位置上為0,不把他的特征向量用于加權(quán)求和
            mmask = Lambda(lambda x: (-1e+9) * (1. - K.cast(x, 'float32')))(mask)
            attn = Add()([attn, mmask])
        attn = Activation('softmax')(attn)
        attn = self.dropout(attn)
        # 這個地方加權(quán)求和
        # [None, seq_len, seq_len] * [None, seq_len, 32] => [None, seq_len, 32] 每個詞/句子的最終表達
        # K.batch_dot可以直接改成tf.matmul
        output = Lambda(lambda x: K.batch_dot(x[0], x[1]))([attn, v])
        return output, attn

代碼中有一些用到了Keras算子叉钥,記錄一下

  • TimeDistributed

這個就是把一個網(wǎng)絡(luò)層應(yīng)用在一個有步長輸入矩陣的每一個步長上面罢缸,TimeDistributed(Dense(d_k, use_bias=False))相當(dāng)于原始三維([batch_size, seq_len, emb_size])的[seq_len, emb_size]去做一個Dense全連接,實際上三維可以直接和二維進行全連接投队,改行代表代表在構(gòu)建三個線性映射矩陣

  • K.batch_dot

代表一個帶有batch_size和另一個帶有batch_size的矩陣相乘枫疆,batch_size不參與計算,axes代表要進行矩陣運算需要匹配的對應(yīng)維度蛾洛,axes=[2, 2]表示前一個矩陣的第2維要和后一個矩陣的第2維匹配相等,然后進行相乘,實際上
attn = Lambda(lambda x: K.batch_dot(x[0], x[1], axes=[2, 2]) / x[2])([q, k, temper])完全可以替換為一個普通的矩陣相乘轧膘,先把矩陣轉(zhuǎn)置一下再矩陣相乘即可
attn1 = tf.matmul(q, tf.transpose(k, [0, 2, 1])) / temper

其他代碼解讀詳情見注釋

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末钞螟,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子谎碍,更是在濱河造成了極大的恐慌鳞滨,老刑警劉巖,帶你破解...
    沈念sama閱讀 219,427評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件蟆淀,死亡現(xiàn)場離奇詭異拯啦,居然都是意外死亡,警方通過查閱死者的電腦和手機熔任,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,551評論 3 395
  • 文/潘曉璐 我一進店門褒链,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人疑苔,你說我怎么就攤上這事甫匹。” “怎么了惦费?”我有些...
    開封第一講書人閱讀 165,747評論 0 356
  • 文/不壞的土叔 我叫張陵兵迅,是天一觀的道長。 經(jīng)常有香客問我薪贫,道長恍箭,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,939評論 1 295
  • 正文 為了忘掉前任瞧省,我火速辦了婚禮扯夭,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘臀突。我一直安慰自己勉抓,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 67,955評論 6 392
  • 文/花漫 我一把揭開白布候学。 她就那樣靜靜地躺著藕筋,像睡著了一般。 火紅的嫁衣襯著肌膚如雪梳码。 梳的紋絲不亂的頭發(fā)上隐圾,一...
    開封第一講書人閱讀 51,737評論 1 305
  • 那天,我揣著相機與錄音掰茶,去河邊找鬼暇藏。 笑死,一個胖子當(dāng)著我的面吹牛濒蒋,可吹牛的內(nèi)容都是我干的盐碱。 我是一名探鬼主播把兔,決...
    沈念sama閱讀 40,448評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼瓮顽!你這毒婦竟也來了县好?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,352評論 0 276
  • 序言:老撾萬榮一對情侶失蹤暖混,失蹤者是張志新(化名)和其女友劉穎缕贡,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體拣播,經(jīng)...
    沈念sama閱讀 45,834評論 1 317
  • 正文 獨居荒郊野嶺守林人離奇死亡晾咪,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,992評論 3 338
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了贮配。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片谍倦。...
    茶點故事閱讀 40,133評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖牧嫉,靈堂內(nèi)的尸體忽然破棺而出剂跟,到底是詐尸還是另有隱情,我是刑警寧澤酣藻,帶...
    沈念sama閱讀 35,815評論 5 346
  • 正文 年R本政府宣布曹洽,位于F島的核電站,受9級特大地震影響辽剧,放射性物質(zhì)發(fā)生泄漏送淆。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,477評論 3 331
  • 文/蒙蒙 一怕轿、第九天 我趴在偏房一處隱蔽的房頂上張望偷崩。 院中可真熱鬧,春花似錦撞羽、人聲如沸阐斜。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,022評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽谒出。三九已至,卻和暖如春邻奠,著一層夾襖步出監(jiān)牢的瞬間笤喳,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,147評論 1 272
  • 我被黑心中介騙來泰國打工碌宴, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留杀狡,地道東北人。 一個月前我還...
    沈念sama閱讀 48,398評論 3 373
  • 正文 我出身青樓贰镣,卻偏偏與公主長得像呜象,于是被迫代替她去往敵國和親膳凝。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,077評論 2 355

推薦閱讀更多精彩內(nèi)容