從源碼理解部分transformer

從源碼理解部分transformer

源碼來自

https://github.com/graykode/nlp-tutorial

位置編碼

  • 原理

    image.png

  • 源碼

    class PositionalEncoding(nn.Module):
      def __init__(self, d_model, dropout=0.1, max_len=5000):
          super(PositionalEncoding, self).__init__()
          self.dropout = nn.Dropout(p=dropout)
    
          pe = torch.zeros(max_len, d_model)
          position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
          div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # math.pow(10000, (2i/dmodel))济欢,先取對數(shù)眉厨,再指數(shù)慧起,對數(shù)和指數(shù)是反函數(shù)中和了;指數(shù)上加負號math.pow(2,-4)=1/math.pow(2,4)
          pe[:, 0::2] = torch.sin(position * div_term) # position * div_term  維度:max_len,d_model/2
          pe[:, 1::2] = torch.cos(position * div_term)
          pe = pe.unsqueeze(0).transpose(0, 1) # 轉(zhuǎn)置目的是批量上的字加相關(guān)位置(第二維度變成batsize般又,且pe的batsize是1)族吻,如[[1, 2, 3],[4, 5, 6]]猛铅,1和4,2和5,3和6加的位置相同鉴象,同一個字d_model向量維度加的不同滞欠;[seq_len,1,d_model]
          self.register_buffer('pe', pe)
    
      def forward(self, x):
          '''
          x: [seq_len, batch_size, d_model]
          '''
          x = x + self.pe[:x.size(0), :]
          return self.dropout(x)
    
  • 理解

    • div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

      的公式變換如下

DINGTALK_IM_2417386734.JPG.JPG
  • pe = pe.unsqueeze(0).transpose(0, 1)

    與輸入

    x: [seq_len, batch_size, d_model]

    理解為不同batch句子加位置相同

  • 同一詞的不同向量(d_model)維度加的值不同

    pe[:, 0::2] = torch.sin(position * div_term) # position * div_term 維度:max_len,d_model/2

    pe[:, 1::2] = torch.cos(position * div_term)

    • 打印d_model維度對應的位置值
      import numpy as np
      import matplotlib.pyplot as plt
      
      plt.rcParams['font.sans-serif']=['SimHei']###解決中文亂碼
      plt.rcParams['axes.unicode_minus']=False
      
      pe = np.load('./pe.npy')
      x = np.array(range(0, 30)) # 代表0到30的詞
      y0 = pe[0:30, 0]  # 代表0到30的詞的d_model0
      y1 = pe[0:30, 1]  # 代表0到30的詞的d_model1
      y2 = pe[0:30, 2]  # 代表0到30的詞的d_model2
      y3 = pe[0:30, 3]  # 代表0到30的詞的d_model3
      plt.plot(x, y0, label='d_model0')
      plt.plot(x, y1, label='d_model1')
      plt.plot(x, y2, label='d_model2')
      plt.plot(x, y3, label='d_model3')
      plt.legend()
      plt.show()
      
d_model.png

get_attn_pad_mask(seq_q, seq_k)

  • 個人感覺是因為輸入固定長度的序列字符串古胆,屏蔽沒用的是pad(0)的字符
  • 代碼流程
    • 編碼階段(enc_input=seq_q=seq_k;對seq_k的句子筛璧,是pad(0) 的詞設(shè)置為Ture逸绎,其它詞設(shè)置為False,且拉伸到len(seq_q))

      (1)
      # enc_input           dec_input         dec_output
      ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
      ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
      
      src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
      tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
      
      (2)
      seq_k = seq_q = array([[1, 2, 3, 5, 0],
                             [1, 2, 3, 4, 0]])
      
      (3)
      # eq(zero) is PAD token
      pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], False is masked
      
      pad_attn_mask的值
      array([[[False, False, False, False,  True]],
            [[False, False, False, False,  True]]])  # [batch_size, 1, len_k], False is masked
      
      (4)拉伸到len(seq_q)長度
      pad_attn_mask.expand(batch_size, len_q, len_k)的值
      array([[[False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True]],
      
         [[False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True]]])  # [batch_size, len_q, len_k]
      
    • 解碼階段1(dec_input=seq_q=seq_k隧哮; 對seq_k的句子,是pad(0) 的詞設(shè)置為Ture座舍,其它詞設(shè)置為False沮翔,且拉伸到len(seq_q))

      (1)
      # enc_input           dec_input         dec_output
      ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
      ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
      
      src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
      tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
      
      (2)
      seq_k = seq_q = array([[6, 1, 2, 3, 4, 8],
                            [6, 1, 2, 3, 5, 8]])
      
      (3)
      # eq(zero) is PAD token
      pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], False is masked
      
      pad_attn_mask的值
      array([[[False, False, False, False, False, False]],
            [[False, False, False, False, False, False]]])  # [batch_size, 1, len_k], False is masked
      
      (4)拉伸到len(seq_q)長度
      pad_attn_mask.expand(batch_size, len_q, len_k)的值
      array([[[False, False, False, False, False, False],
          [False, False, False, False, False, False],
          [False, False, False, False, False, False],
          [False, False, False, False, False, False],
          [False, False, False, False, False, False],
          [False, False, False, False, False, False]],
      
        [[False, False, False, False, False, False],
          [False, False, False, False, False, False],
          [False, False, False, False, False, False],
          [False, False, False, False, False, False],
          [False, False, False, False, False, False],
          [False, False, False, False, False, False]]])  # [batch_size, len_q, len_k]
      
    • 解碼階段2(dec_input=seq_q,enc_inputs=seq_k曲秉; 對seq_k的句子采蚀,是pad(0) 的詞設(shè)置為Ture,其它詞設(shè)置為False承二,且拉伸到len(seq_q))

      (1)
      # enc_input           dec_input         dec_output
      ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
      ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
      
      src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
      tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
      
      (2)
      seq_q = array([[6, 1, 2, 3, 4, 8],
                     [6, 1, 2, 3, 5, 8]])
      seq_k = array([[1, 2, 3, 4, 0],
                     [1, 2, 3, 5, 0]])
      
      (3)
      # eq(zero) is PAD token
      pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], False is masked
      
      pad_attn_mask的值
      array([[[False, False, False, False,  True]],
             [[False, False, False, False,  True]]])  # [batch_size, 1, len_k], False is masked
      
      (4)拉伸到len(seq_q)長度
      pad_attn_mask.expand(batch_size, len_q, len_k)的值
      array([[[False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True]],
      
         [[False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True]]])  # [batch_size, len_q, len_k]
      

get_attn_subsequence_mask(seq)

  • 屏蔽序列下文

  • 返回三角矩陣 [batch_size, tgt_len, tgt_len]

  • 返回內(nèi)容(只有解碼流程 seq=dec_inputs)

    tensor([[[0, 1, 1, 1, 1, 1],
           [0, 0, 1, 1, 1, 1],
           [0, 0, 0, 1, 1, 1],
           [0, 0, 0, 0, 1, 1],
           [0, 0, 0, 0, 0, 1],
           [0, 0, 0, 0, 0, 0]],
    
          [[0, 1, 1, 1, 1, 1],
           [0, 0, 1, 1, 1, 1],
           [0, 0, 0, 1, 1, 1],
           [0, 0, 0, 0, 1, 1],
           [0, 0, 0, 0, 0, 1],
           [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
    

    class ScaledDotProductAttention(nn.Module)

    • 固定參數(shù)說明
      • d_k == d_v 預先設(shè)置配置好
      • n_heads 預先設(shè)置配置好
      • d_model 預先設(shè)置配置好榆鼠,詞生成的Embedding Size
      • 一般 d_k * n_heads == d_model
    • Q、K亥鸠、V 參數(shù)由來
      self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
      self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
      self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
      
      Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # Q: [batch_size, n_heads, len_q, d_k]
      K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # K: [batch_size, n_heads, len_k, d_k]
      V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]
      
    • attn_mask 參數(shù)
      • 編碼來自 get_attn_pad_mask(seq_q, seq_k)妆够,
      • 解碼來自 get_attn_pad_mask(seq_q, seq_k),get_attn_subsequence_mask(seq)
    • 返回參數(shù) attn pad(0)的字符為0
    tensor([[[[0.2763, 0.1676, 0.3137, 0.2424, 0.0000],
            [0.1488, 0.2446, 0.2804, 0.3262, 0.0000],
            [0.1457, 0.2651, 0.2659, 0.3232, 0.0000],
            [0.1611, 0.3513, 0.2042, 0.2834, 0.0000],
            [0.2473, 0.2574, 0.2483, 0.2470, 0.0000]],
    
           [[0.2334, 0.3264, 0.2237, 0.2165, 0.0000],
            [0.1732, 0.4632, 0.1530, 0.2107, 0.0000],
            [0.1774, 0.4091, 0.1821, 0.2315, 0.0000],
            [0.2029, 0.2564, 0.2512, 0.2895, 0.0000],
            [0.1930, 0.3559, 0.2416, 0.2095, 0.0000]]],
    
    
          [[[0.2988, 0.1964, 0.2787, 0.2261, 0.0000],
            [0.1587, 0.3660, 0.3240, 0.1514, 0.0000],
            [0.1618, 0.3244, 0.2693, 0.2445, 0.0000],
            [0.2168, 0.3195, 0.2736, 0.1901, 0.0000],
            [0.1520, 0.3054, 0.2551, 0.2874, 0.0000]],
    
           [[0.2571, 0.3385, 0.2586, 0.1458, 0.0000],
            [0.2371, 0.4731, 0.1560, 0.1339, 0.0000],
            [0.2473, 0.3912, 0.1984, 0.1632, 0.0000],
            [0.2586, 0.2994, 0.2078, 0.2342, 0.0000],
            [0.2261, 0.3706, 0.1942, 0.2091, 0.0000]]]], device='cuda:0',
         grad_fn=<SoftmaxBackward0>)
    
    • 返回參數(shù) attn 上下文解碼
    tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.5283, 0.4717, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.3347, 0.1790, 0.4864, 0.0000, 0.0000, 0.0000],
            [0.3931, 0.3598, 0.1147, 0.1324, 0.0000, 0.0000],
            [0.2269, 0.2316, 0.0421, 0.0928, 0.4066, 0.0000],
            [0.1697, 0.1361, 0.1696, 0.2551, 0.1201, 0.1493]],
    
           [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.4688, 0.5312, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.2173, 0.4401, 0.3426, 0.0000, 0.0000, 0.0000],
            [0.0493, 0.1505, 0.0469, 0.7534, 0.0000, 0.0000],
            [0.0564, 0.1002, 0.0642, 0.4745, 0.3047, 0.0000],
            [0.1956, 0.1259, 0.1935, 0.0706, 0.0597, 0.3547]]],
    
    
          [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.4690, 0.5310, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.3483, 0.1474, 0.5043, 0.0000, 0.0000, 0.0000],
            [0.3429, 0.4056, 0.0692, 0.1823, 0.0000, 0.0000],
            [0.3588, 0.2282, 0.0887, 0.1288, 0.1956, 0.0000],
            [0.1614, 0.1630, 0.1179, 0.2421, 0.1809, 0.1348]],
    
           [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.4932, 0.5068, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.2127, 0.4520, 0.3353, 0.0000, 0.0000, 0.0000],
            [0.0308, 0.0730, 0.0270, 0.8691, 0.0000, 0.0000],
            [0.0171, 0.0461, 0.0254, 0.4332, 0.4783, 0.0000],
            [0.2427, 0.1489, 0.2457, 0.0840, 0.0753, 0.2035]]]], device='cuda:0',
         grad_fn=<SoftmaxBackward0>)
    
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末负蚊,一起剝皮案震驚了整個濱河市神妹,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌家妆,老刑警劉巖鸵荠,帶你破解...
    沈念sama閱讀 219,188評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異伤极,居然都是意外死亡蛹找,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,464評論 3 395
  • 文/潘曉璐 我一進店門哨坪,熙熙樓的掌柜王于貴愁眉苦臉地迎上來庸疾,“玉大人,你說我怎么就攤上這事当编”肆颍” “怎么了?”我有些...
    開封第一講書人閱讀 165,562評論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長拧篮。 經(jīng)常有香客問我词渤,道長,這世上最難降的妖魔是什么串绩? 我笑而不...
    開封第一講書人閱讀 58,893評論 1 295
  • 正文 為了忘掉前任缺虐,我火速辦了婚禮,結(jié)果婚禮上礁凡,老公的妹妹穿的比我還像新娘高氮。我一直安慰自己,他們只是感情好顷牌,可當我...
    茶點故事閱讀 67,917評論 6 392
  • 文/花漫 我一把揭開白布剪芍。 她就那樣靜靜地躺著,像睡著了一般窟蓝。 火紅的嫁衣襯著肌膚如雪罪裹。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,708評論 1 305
  • 那天运挫,我揣著相機與錄音状共,去河邊找鬼。 笑死谁帕,一個胖子當著我的面吹牛峡继,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播匈挖,決...
    沈念sama閱讀 40,430評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼碾牌,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了儡循?” 一聲冷哼從身側(cè)響起小染,我...
    開封第一講書人閱讀 39,342評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎贮折,沒想到半個月后裤翩,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,801評論 1 317
  • 正文 獨居荒郊野嶺守林人離奇死亡调榄,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,976評論 3 337
  • 正文 我和宋清朗相戀三年踊赠,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片每庆。...
    茶點故事閱讀 40,115評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡筐带,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出缤灵,到底是詐尸還是另有隱情伦籍,我是刑警寧澤蓝晒,帶...
    沈念sama閱讀 35,804評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站帖鸦,受9級特大地震影響芝薇,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜作儿,卻給世界環(huán)境...
    茶點故事閱讀 41,458評論 3 331
  • 文/蒙蒙 一洛二、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧攻锰,春花似錦晾嘶、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,008評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至妒蛇,卻和暖如春机断,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背材部。 一陣腳步聲響...
    開封第一講書人閱讀 33,135評論 1 272
  • 我被黑心中介騙來泰國打工毫缆, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留唯竹,地道東北人乐导。 一個月前我還...
    沈念sama閱讀 48,365評論 3 373
  • 正文 我出身青樓,卻偏偏與公主長得像浸颓,于是被迫代替她去往敵國和親物臂。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 45,055評論 2 355

推薦閱讀更多精彩內(nèi)容