self-attention實現(xiàn)

cnn中實現(xiàn)attention主要是有Sparial Domain和Channel Domain
soft-attention是可微的晦炊,可以通過梯度來實現(xiàn)

import torch
import torch.nn as nn
import numpy as np
import math
class SelfAttention(nn.Module):
    
    def __init__(self, hidden_size, num_attention_heads, dropout_prob):   
        """
        假設 hidden_size = 128, num_attention_heads = 8, dropout_prob = 0.2
        即隱層維度為128逛绵,注意力頭設置為8個
        """
        super(SelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:   # 整除
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        # 參數(shù)定義
        self.num_attention_heads = num_attention_heads    # 8
        self.attention_head_size = int(hidden_size / num_attention_heads)  # 16  每個注意力頭的維度
        self.all_head_size = int(self.num_attention_heads * self.attention_head_size)   
        # all_head_size = 128 即等于hidden_size, 一般自注意力輸入輸出前后維度不變
        
        # query, key, value 的線性變換(上述公式2)
        self.query = nn.Linear(hidden_size, self.all_head_size)    # 128, 128
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        
        # dropout
        self.dropout = nn.Dropout(dropout_prob)

    def transpose_for_scores(self, x):
        # INPUT:  x'shape = [bs, seqlen, hid_size]  假設hid_size=128
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # [bs, seqlen, 8, 16]
        x = x.view(*new_x_shape)   # 
        return x.permute(0, 2, 1, 3)   # [bs, 8, seqlen, 16]

    def forward(self, hidden_states, attention_mask):
        # eg: attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])  shape=[bs, seqlen]
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)   # [bs, 1, 1, seqlen] 增加維度
        attention_mask = (1.0 - attention_mask) * -10000.0   # padding的token置為-10000忱辅,exp(-1w)=0
        
        # 線性變換
        mixed_query_layer = self.query(hidden_states)   # [bs, seqlen, hid_size]
        mixed_key_layer = self.key(hidden_states)       # [bs, seqlen, hid_size]
        mixed_value_layer = self.value(hidden_states)   # [bs, seqlen, hid_size]

        query_layer = self.transpose_for_scores(mixed_query_layer)    # [bs, 8, seqlen, 16]
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)   # [bs, 8, seqlen, 16]

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # 計算query與title之間的點積注意力分數(shù)途样,還不是權重(個人認為權重應該是和為1的概率分布)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # [bs, 8, seqlen, 16]*[bs, 8, 16, seqlen]  ==> [bs, 8, seqlen, seqlen]
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)   # [bs, 8, seqlen, seqlen]
        # 除以根號注意力頭的數(shù)量蝎毡,可看原論文公式银还,防止分數(shù)過大,過大會導致softmax之后非0即1
        attention_scores = attention_scores + attention_mask
        # 加上mask,將padding所在的表示直接-10000

        # 將注意力轉(zhuǎn)化為概率分布趁冈,即注意力權重
        attention_probs = nn.Softmax(dim=-1)(attention_scores)    # [bs, 8, seqlen, seqlen]

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        
        # 矩陣相乘,[bs, 8, seqlen, seqlen]*[bs, 8, seqlen, 16] = [bs, 8, seqlen, 16]
        context_layer = torch.matmul(attention_probs, value_layer)   # [bs, 8, seqlen, 16]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()   # [bs, seqlen, 8, 16]
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)   # [bs, seqlen, 128]
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer    # [bs, seqlen, 128] 得到輸出
attention=SelfAttention(4,2,0.2)
x_in=torch.randn(3,5,4)
x_mask=torch.Tensor([[1,1,1,0,0],
                    [1,1,0,0,0],
                    [1,1,1,1,1],])
print(x_mask.shape)
x_out=attention(x_in,x_mask)
最后編輯于
?著作權歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末拜马,一起剝皮案震驚了整個濱河市渗勘,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌俩莽,老刑警劉巖旺坠,帶你破解...
    沈念sama閱讀 219,039評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異豹绪,居然都是意外死亡价淌,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,426評論 3 395
  • 文/潘曉璐 我一進店門瞒津,熙熙樓的掌柜王于貴愁眉苦臉地迎上來蝉衣,“玉大人,你說我怎么就攤上這事巷蚪〔≌保” “怎么了?”我有些...
    開封第一講書人閱讀 165,417評論 0 356
  • 文/不壞的土叔 我叫張陵屁柏,是天一觀的道長啦膜。 經(jīng)常有香客問我有送,道長,這世上最難降的妖魔是什么僧家? 我笑而不...
    開封第一講書人閱讀 58,868評論 1 295
  • 正文 為了忘掉前任雀摘,我火速辦了婚禮,結(jié)果婚禮上八拱,老公的妹妹穿的比我還像新娘阵赠。我一直安慰自己,他們只是感情好肌稻,可當我...
    茶點故事閱讀 67,892評論 6 392
  • 文/花漫 我一把揭開白布清蚀。 她就那樣靜靜地躺著,像睡著了一般爹谭。 火紅的嫁衣襯著肌膚如雪枷邪。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,692評論 1 305
  • 那天诺凡,我揣著相機與錄音东揣,去河邊找鬼。 笑死绑洛,一個胖子當著我的面吹牛救斑,可吹牛的內(nèi)容都是我干的童本。 我是一名探鬼主播真屯,決...
    沈念sama閱讀 40,416評論 3 419
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼穷娱!你這毒婦竟也來了绑蔫?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,326評論 0 276
  • 序言:老撾萬榮一對情侶失蹤泵额,失蹤者是張志新(化名)和其女友劉穎配深,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體嫁盲,經(jīng)...
    沈念sama閱讀 45,782評論 1 316
  • 正文 獨居荒郊野嶺守林人離奇死亡篓叶,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,957評論 3 337
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了羞秤。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片缸托。...
    茶點故事閱讀 40,102評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖瘾蛋,靈堂內(nèi)的尸體忽然破棺而出俐镐,到底是詐尸還是另有隱情,我是刑警寧澤哺哼,帶...
    沈念sama閱讀 35,790評論 5 346
  • 正文 年R本政府宣布佩抹,位于F島的核電站叼风,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏棍苹。R本人自食惡果不足惜无宿,卻給世界環(huán)境...
    茶點故事閱讀 41,442評論 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望枢里。 院中可真熱鬧懈贺,春花似錦、人聲如沸坡垫。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,996評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽冰悠。三九已至堡妒,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間溉卓,已是汗流浹背皮迟。 一陣腳步聲響...
    開封第一講書人閱讀 33,113評論 1 272
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留桑寨,地道東北人伏尼。 一個月前我還...
    沈念sama閱讀 48,332評論 3 373
  • 正文 我出身青樓,卻偏偏與公主長得像尉尾,于是被迫代替她去往敵國和親爆阶。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 45,044評論 2 355