Understanding parts of the Transformer from its source code
The source code comes from
https://github.com/graykode/nlp-tutorial
Positional encoding
- Principle
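The principle is the sinusoidal encoding from "Attention Is All You Need", which the code below implements:

$$
PE_{(pos,\;2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\;2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$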
- Source code
```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Equivalent to 1 / math.pow(10000, 2i / d_model): take the log first, then exponentiate
        # (log and exp are inverse functions, so they cancel); the minus sign in the exponent gives
        # the reciprocal, just as math.pow(2, -4) == 1 / math.pow(2, 4).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # position * div_term has shape [max_len, d_model/2]
        pe[:, 1::2] = torch.cos(position * div_term)
        # The transpose makes the second dimension the batch dimension (of size 1), so the same
        # positions get the same encoding across the batch: for [[1, 2, 3], [4, 5, 6]], the pairs
        # 1/4, 2/5 and 3/6 receive identical positional values, while the different d_model
        # dimensions of one token receive different values. Final shape: [max_len, 1, d_model].
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
```
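A minimal usage sketch, assuming the class above has been defined (the d_model, sequence length and batch size here are illustrative assumptions, not values from the tutorial):

```python
pos_enc = PositionalEncoding(d_model=512)
x = torch.zeros(5, 2, 512)   # stand-in for token embeddings: [seq_len=5, batch_size=2, d_model=512]
out = pos_enc(x)             # positional values added, shape unchanged
print(out.shape)             # torch.Size([5, 2, 512])
```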
- Understanding
  - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    is obtained from the following transformation:
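Writing $2i$ for the even indices produced by torch.arange(0, d_model, 2):

$$
\exp\!\left(2i \cdot \frac{-\ln 10000}{d_{\text{model}}}\right)
= \exp\!\left(\ln 10000^{-2i/d_{\text{model}}}\right)
= 10000^{-2i/d_{\text{model}}}
= \frac{1}{10000^{2i/d_{\text{model}}}},
$$

which is exactly the $1/10000^{2i/d_{\text{model}}}$ factor in the positional-encoding formula above.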
  - pe = pe.unsqueeze(0).transpose(0, 1) gives pe the shape [max_len, 1, d_model]. Together with the input x: [seq_len, batch_size, d_model], this means sentences at different batch positions have the same positional values added: pe's size-1 batch dimension broadcasts over batch_size, as the small check below shows.
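A small broadcasting check (the tensor sizes are illustrative assumptions):

```python
import torch

pe_demo = torch.arange(6, dtype=torch.float).reshape(6, 1, 1).expand(6, 1, 4)  # [max_len=6, 1, d_model=4]
x = torch.zeros(5, 2, 4)                                                       # [seq_len=5, batch_size=2, d_model=4]
out = x + pe_demo[:x.size(0), :]            # size-1 batch dim broadcasts over batch_size
print(torch.equal(out[:, 0], out[:, 1]))    # True: both sentences receive identical positional values
```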
  - Different dimensions of the same token's d_model-sized vector have different values added:

    pe[:, 0::2] = torch.sin(position * div_term)  # position * div_term has shape [max_len, d_model/2]
    pe[:, 1::2] = torch.cos(position * div_term)
  - Plotting the positional values of the first four d_model dimensions (pe.npy is the pe buffer saved to disk; one way to produce it is sketched after the plot):
```python
import numpy as np
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']   # render CJK characters in labels correctly
plt.rcParams['axes.unicode_minus'] = False

pe = np.load('./pe.npy')
x = np.array(range(0, 30))   # positions 0..29
y0 = pe[0:30, 0]             # d_model dimension 0 at positions 0..29
y1 = pe[0:30, 1]             # d_model dimension 1
y2 = pe[0:30, 2]             # d_model dimension 2
y3 = pe[0:30, 3]             # d_model dimension 3
plt.plot(x, y0, label='d_model0')
plt.plot(x, y1, label='d_model1')
plt.plot(x, y2, label='d_model2')
plt.plot(x, y3, label='d_model3')
plt.legend()
plt.show()
```
  - Resulting plot (d_model.png): the positional values of d_model dimensions 0-3 over positions 0-29.
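The pe.npy file is not produced by the class itself; a possible way to dump it (the d_model value and file name are assumptions) is:

```python
import numpy as np

pos_enc = PositionalEncoding(d_model=512)   # assumed d_model
pe = pos_enc.pe.squeeze(1).cpu().numpy()    # [max_len, 1, d_model] -> [max_len, d_model]
np.save('./pe.npy', pe)
```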
get_attn_pad_mask(seq_q, seq_k)
- My take: because the model is fed fixed-length sequences, the useless pad(0) tokens have to be masked out (a sketch of the function itself follows the walkthrough below)
- Code walkthrough
  - Encoder stage (enc_input = seq_q = seq_k; positions in seq_k that are pad(0) are set to True, all other tokens to False, and the mask is then expanded to len(seq_q)):
```
# (1) training data and vocabularies
#        enc_input               dec_input              dec_output
# ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
# ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
src_vocab = {'P': 0, 'ich': 1, 'mochte': 2, 'ein': 3, 'bier': 4, 'cola': 5}
tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'coke': 5, 'S': 6, 'E': 7, '.': 8}

# (2) encoder inputs
seq_k = seq_q = array([[1, 2, 3, 5, 0],
                       [1, 2, 3, 4, 0]])

# (3) eq(zero) marks the PAD token
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)   # [batch_size, 1, len_k], True is masked
# value of pad_attn_mask:
array([[[False, False, False, False,  True]],

       [[False, False, False, False,  True]]])

# (4) expand to length len(seq_q)
pad_attn_mask.expand(batch_size, len_q, len_k)
# value:
array([[[False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True]],

       [[False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True]]])   # [batch_size, len_q, len_k]
```
  - Decoder stage 1, self-attention (dec_input = seq_q = seq_k; positions in seq_k that are pad(0) are set to True, all other tokens to False, and the mask is then expanded to len(seq_q)):
```
# (1) training data and vocabularies: same as in the encoder stage above
# (2) decoder inputs
seq_k = seq_q = array([[6, 1, 2, 3, 4, 8],
                       [6, 1, 2, 3, 5, 8]])

# (3) eq(zero) marks the PAD token
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)   # [batch_size, 1, len_k], True is masked
# value of pad_attn_mask: no position is a pad token, so nothing gets masked
array([[[False, False, False, False, False, False]],

       [[False, False, False, False, False, False]]])

# (4) expand to length len(seq_q)
pad_attn_mask.expand(batch_size, len_q, len_k)
# value:
array([[[False, False, False, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, False, False]],

       [[False, False, False, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, False, False]]])   # [batch_size, len_q, len_k]
```
  - Decoder stage 2, encoder-decoder attention (dec_input = seq_q, enc_inputs = seq_k; positions in seq_k that are pad(0) are set to True, all other tokens to False, and the mask is then expanded to len(seq_q)):
```
# (1) training data and vocabularies: same as in the encoder stage above
# (2) decoder queries attend over encoder keys
seq_q = array([[6, 1, 2, 3, 4, 8],
               [6, 1, 2, 3, 5, 8]])
seq_k = array([[1, 2, 3, 4, 0],
               [1, 2, 3, 5, 0]])

# (3) eq(zero) marks the PAD token
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)   # [batch_size, 1, len_k], True is masked
# value of pad_attn_mask:
array([[[False, False, False, False,  True]],

       [[False, False, False, False,  True]]])

# (4) expand to length len(seq_q)
pad_attn_mask.expand(batch_size, len_q, len_k)
# value:
array([[[False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True]],

       [[False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True]]])   # [batch_size, len_q=6, len_k=5]
```
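A sketch of the function that produces the masks above, consistent with the nlp-tutorial code (treat it as a reconstruction rather than a verbatim copy):

```python
import torch

def get_attn_pad_mask(seq_q, seq_k):
    """Mark the pad(0) positions of seq_k with True so attention can ignore them."""
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)           # [batch_size, 1, len_k], True is masked
    return pad_attn_mask.expand(batch_size, len_q, len_k)   # [batch_size, len_q, len_k]
```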
get_attn_subsequence_mask(seq)
- Masks the later (future) positions of the sequence
- Returns an upper-triangular matrix of shape [batch_size, tgt_len, tgt_len]
- Returned value (used only in the decoder, where seq = dec_inputs):
```
tensor([[[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]],

        [[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
```
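A sketch of the function, consistent with how the nlp-tutorial builds this mask (a reconstruction, not a verbatim copy):

```python
import numpy as np
import torch

def get_attn_subsequence_mask(seq):
    """1 marks a future position that must not be attended to; 0 allows attention."""
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]    # [batch_size, tgt_len, tgt_len]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)    # strictly above the diagonal
    return torch.from_numpy(subsequence_mask).byte()        # uint8 tensor, as printed above
```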
class ScaledDotProductAttention(nn.Module)
- Fixed hyperparameters
  - d_k == d_v, set in advance in the config
  - n_heads, set in advance in the config
  - d_model, set in advance in the config; the embedding size generated for each token
  - Usually d_k * n_heads == d_model
- Where the Q, K, V arguments come from (the attention computation itself is sketched at the end of this section):
```python
self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)

Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # Q: [batch_size, n_heads, len_q, d_k]
K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # K: [batch_size, n_heads, len_k, d_k]
V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]
```
- The attn_mask argument
  - In the encoder it comes from get_attn_pad_mask(seq_q, seq_k)
  - In the decoder it comes from get_attn_pad_mask(seq_q, seq_k) combined with get_attn_subsequence_mask(seq)
- Returned attn with the pad mask: the pad(0) positions get an attention weight of 0
```
tensor([[[[0.2763, 0.1676, 0.3137, 0.2424, 0.0000],
          [0.1488, 0.2446, 0.2804, 0.3262, 0.0000],
          [0.1457, 0.2651, 0.2659, 0.3232, 0.0000],
          [0.1611, 0.3513, 0.2042, 0.2834, 0.0000],
          [0.2473, 0.2574, 0.2483, 0.2470, 0.0000]],

         [[0.2334, 0.3264, 0.2237, 0.2165, 0.0000],
          [0.1732, 0.4632, 0.1530, 0.2107, 0.0000],
          [0.1774, 0.4091, 0.1821, 0.2315, 0.0000],
          [0.2029, 0.2564, 0.2512, 0.2895, 0.0000],
          [0.1930, 0.3559, 0.2416, 0.2095, 0.0000]]],

        [[[0.2988, 0.1964, 0.2787, 0.2261, 0.0000],
          [0.1587, 0.3660, 0.3240, 0.1514, 0.0000],
          [0.1618, 0.3244, 0.2693, 0.2445, 0.0000],
          [0.2168, 0.3195, 0.2736, 0.1901, 0.0000],
          [0.1520, 0.3054, 0.2551, 0.2874, 0.0000]],

         [[0.2571, 0.3385, 0.2586, 0.1458, 0.0000],
          [0.2371, 0.4731, 0.1560, 0.1339, 0.0000],
          [0.2473, 0.3912, 0.1984, 0.1632, 0.0000],
          [0.2586, 0.2994, 0.2078, 0.2342, 0.0000],
          [0.2261, 0.3706, 0.1942, 0.2091, 0.0000]]]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)
```
- Returned attn with the subsequence mask in decoder self-attention: each position only attends to itself and earlier positions
```
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5283, 0.4717, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3347, 0.1790, 0.4864, 0.0000, 0.0000, 0.0000],
          [0.3931, 0.3598, 0.1147, 0.1324, 0.0000, 0.0000],
          [0.2269, 0.2316, 0.0421, 0.0928, 0.4066, 0.0000],
          [0.1697, 0.1361, 0.1696, 0.2551, 0.1201, 0.1493]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4688, 0.5312, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2173, 0.4401, 0.3426, 0.0000, 0.0000, 0.0000],
          [0.0493, 0.1505, 0.0469, 0.7534, 0.0000, 0.0000],
          [0.0564, 0.1002, 0.0642, 0.4745, 0.3047, 0.0000],
          [0.1956, 0.1259, 0.1935, 0.0706, 0.0597, 0.3547]]],

        [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4690, 0.5310, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3483, 0.1474, 0.5043, 0.0000, 0.0000, 0.0000],
          [0.3429, 0.4056, 0.0692, 0.1823, 0.0000, 0.0000],
          [0.3588, 0.2282, 0.0887, 0.1288, 0.1956, 0.0000],
          [0.1614, 0.1630, 0.1179, 0.2421, 0.1809, 0.1348]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4932, 0.5068, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2127, 0.4520, 0.3353, 0.0000, 0.0000, 0.0000],
          [0.0308, 0.0730, 0.0270, 0.8691, 0.0000, 0.0000],
          [0.0171, 0.0461, 0.0254, 0.4332, 0.4783, 0.0000],
          [0.2427, 0.1489, 0.2457, 0.0840, 0.0753, 0.2035]]]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)
```
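For reference, a sketch of the scaled dot-product attention itself, matching how the tutorial uses Q, K, V and attn_mask (here d_k is passed as an argument to keep the sketch self-contained, whereas the original reads it from a module-level config):

```python
import numpy as np
import torch
import torch.nn as nn

class ScaledDotProductAttention(nn.Module):
    def forward(self, Q, K, V, attn_mask, d_k):
        # Q: [batch_size, n_heads, len_q, d_k], K: [batch_size, n_heads, len_k, d_k], V: [batch_size, n_heads, len_k, d_v]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)   # [batch_size, n_heads, len_q, len_k]
        scores.masked_fill_(attn_mask, -1e9)   # True positions (pad / future tokens) get a large negative score
        attn = nn.Softmax(dim=-1)(scores)      # masked positions end up with weight ~0, as in the tensors above
        context = torch.matmul(attn, V)        # [batch_size, n_heads, len_q, d_v]
        return context, attn
```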