一 Transformer overview
本文結(jié)合pytorch源碼以盡可能簡(jiǎn)潔的方式把Transformer的工作流程講解以及原理講解清楚。全文分為三個(gè)部分
- Transformer架構(gòu):這個(gè)模塊的詳細(xì)說明
- pytorch中Transformer的api解讀
- 實(shí)際運(yùn)用:雖然Transformer的api使用大大簡(jiǎn)化了打碼量,但是還有需要自已實(shí)現(xiàn)一些代碼的
Transformer架構(gòu)
Transformer結(jié)構(gòu)如下:
Transformer的經(jīng)典應(yīng)用場(chǎng)景就是機(jī)器翻譯。
整體分為Encoder揍诽、Decoder兩大部分,具體實(shí)現(xiàn)細(xì)分為六塊。
-
輸入編碼什湘、位置編碼
Encoder、Decoder都需要將輸入字符進(jìn)行編碼送入網(wǎng)絡(luò)訓(xùn)練晦攒。
Input Embeding:將一個(gè)字符(或者漢字)進(jìn)行編碼闽撤,比如“我愛中國(guó)”四個(gè)漢字編碼后會(huì)變成(4,d_model)的矩陣脯颜,Transformer中d_model等于512哟旗,那么輸入就變成(4,512)的矩陣,為了方便敘述闸餐,后面都用(4饱亮,512)來當(dāng)成模型的輸入。
positional encoding:在Q舍沙、K近上、V的計(jì)算過程中,輸入單詞的位置信息會(huì)丟失掉拂铡。所以需要額外用一個(gè)位置編碼來表示輸入單詞的順序壹无。編碼公式如下
其中,pos:表示第幾個(gè)單詞感帅,2i,2i+1表示Input Embeding編碼維度(512)的偶數(shù)位斗锭、奇數(shù)位。
論文中作者也試過將positional encoding變成可以學(xué)習(xí)的失球,但是發(fā)現(xiàn)效果差不多岖是;而且使用硬位置編碼就不用考慮在推斷環(huán)節(jié)中句子的實(shí)際長(zhǎng)度超過訓(xùn)練環(huán)節(jié)中使用的位置編碼長(zhǎng)度的問題;為什么使用sin她倘、cos呢璧微?可以有效的考慮句中單詞的相對(duì)位置信息 -
多頭注意力機(jī)制(Multi-Head Attention)
多頭注意力機(jī)制是Transformer的核心,屬于Self-Attention(自注意力)硬梁。注意只要是可以注意到自身特征的注意力機(jī)制就叫Self-Attention前硫,并不是只有Transformer有。 示意圖如下
image.pngMulti-Head Attention的輸入的Q荧止、K屹电、V就是輸入的(4,512)維矩陣,Q=K=V跃巡。然后用全連接層對(duì)Q危号、K、V做一個(gè)變換素邪,多頭就是指的這里將輸入映射到多個(gè)空間外莲。公式描述如下:
其中
其中
其中
, 論文中h=8,
稱為注意力矩陣(attention),表示兩個(gè)句子中的任意兩個(gè)單詞的相關(guān)性。所以attention mask不一定是方陣兔朦。
-
前向傳播模塊
Q偷线、K、V經(jīng)過Multi-Head Attention模塊再加上一個(gè)殘差跳鏈沽甥,維度不變声邦,輸入維度是(4,512)摆舟,輸出維度還是(4,512)亥曹,只不過輸出的矩陣的每一行已經(jīng)融合了其他行的信息(根據(jù)attention mask)邓了。
這里前向傳播模塊是一個(gè)兩層的全連接。公式如下:, 其中輸入輸出維度為
, 中間維度
-
帶Mask的多頭注意力機(jī)制
這里的Mask Multi-head Attention與步驟2中的稍有不同媳瞪∑“我愛中國(guó)”的英文翻譯為“I love china”。 在翻譯到“l(fā)ove”的時(shí)候材失,其實(shí)我們是不知道“china”的這個(gè)單詞的痕鳍,所以在訓(xùn)練的時(shí)候,就需要來模擬這個(gè)過程龙巨。即用一個(gè)mask來遮住后面信息笼呆。這個(gè)mask在實(shí)際實(shí)現(xiàn)中是一個(gè)三角矩陣(主對(duì)角線及下方為0,上方為-inf)旨别, 定義為
大概就長(zhǎng)下面這個(gè)樣子
attention mask, 加上
后數(shù)學(xué)表達(dá)為
诗赌。
-inf在經(jīng)過softmax后變成0,就相當(dāng)于忽略對(duì)應(yīng)的單詞信息
-
Decoder中的多頭注意力機(jī)制
這里的Multi-head Attention和步驟2中的注意力就差不多是一個(gè)意思了秸弛,但是
中K铭若,V是來自Encoder的輸出,Q是Mask Multi-head Attention后的輸出
-
預(yù)測(cè)
整個(gè)Transformer的輸入維度是(4,512)递览,輸出維度是(4,512)叼屠。那么如果變成最終具體的單詞呢(假如單詞表大小為10000)。那么最后的輸出output必須是(4绞铃,10000)镜雨。max(output, dim=-1)的下標(biāo)就是單詞的序號(hào)。
以上就是Transformer的核心流程以及解釋儿捧,那么下面接下來看一下Transformer在pytorch中的具體實(shí)現(xiàn)荚坞。
二、Transformer在pytorch中具體實(shí)現(xiàn)
pytorch version: 1.11
2.1 Transformer類
要自已實(shí)現(xiàn)一個(gè)完整的Transformer菲盾,還是有點(diǎn)難度的颓影,好在pytorch提供了官方實(shí)現(xiàn)。所有的核心細(xì)節(jié)都被封裝Transformer類懒鉴。特別說明诡挂,下面的代碼講解中會(huì)刪除非核心代碼,只保留核心細(xì)節(jié)
下面代碼講解會(huì)對(duì)矩陣的一些維度做一些說明临谱,這里統(tǒng)一下符號(hào)
- N: batch size
- S: Encoder輸入序列的長(zhǎng)度
- T: Decoder輸入序列的長(zhǎng)度
- E: embeding的維度咆畏,就是上文中
class Transformer(Module):
"""
Args:
d_model: 單詞維度(default=512).
nhead: 多頭注意力中的head數(shù)量 (default=8).
num_encoder_layers: Encoder中子Encoder堆疊的數(shù)量(default=6).
num_decoder_layers: Decoder中子Decoder堆疊的數(shù)量(default=6).
dim_feedforward: 前向傳播的中間維度 (default=2048).
dropout: the dropout value (default=0.1).
activation: Default: relu
custom_encoder: custom encoder (default=None).
custom_decoder: custom decoder (default=None).
layer_norm_eps: (default=1e-5).
batch_first: Default: ``False`` (seq, batch, feature).
norm_first:
"""
def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(Transformer, self).__init__()
# 步驟2和3的實(shí)現(xiàn)
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, layer_norm_eps, batch_first, norm_first,
**factory_kwargs)
encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
# 將TransformerEncoderLayer執(zhí)行6次,簡(jiǎn)單的堆疊而已
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
# 步驟4吴裤、5的實(shí)現(xiàn)
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, layer_norm_eps, batch_first, norm_first,
**factory_kwargs)
decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
self.d_model = d_model
self.nhead = nhead
self.batch_first = batch_first
def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
"""
src, tgt: 這里的src、tgt是已經(jīng)經(jīng)過input embeding和positional emdbeding后的輸入溺健。
memory_mask:實(shí)際使用的過程中都是None
memory_key_padding_mask:實(shí)際使用的過程中就是src_key_padding_mask
"""
memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
# 這里的memory就是步驟5中K麦牺,V矩陣
# memory_key_padding_mask實(shí)際使用過程就是傳入的src_key_padding_mask钮蛛,在步驟5中需要遮住Encoder中padding的位置
output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,tgt_key_padding_mask=tgt_key_padding_mask剖膳,memory_key_padding_mask=memory_key_padding_mask)
return output
@staticmethod
def generate_square_subsequent_mask(sz: int) -> Tensor:
r"""用來生成步驟4中attention mask."""
return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
關(guān)鍵參數(shù)說明
- batch_first:默認(rèn)False魏颓,在pytorch中,rnn吱晒、Transformer層的輸入維度一般是(seq, batch, feature)甸饱,第一個(gè)維度表示seq的長(zhǎng)度,batch放在第二個(gè)維度
- src仑濒、tgt:由于batch_first=False叹话, 所以Encoder的輸入src、Decoder的輸入tgt的shape都是(seq, batch, feature)
- src_mask: shape:(S, S), 含義就是上文講的attenion_mask, Encoder的輸入是不需要遮住后面的單詞的墩瞳,所以該參數(shù)一般是一個(gè)全為False的陣
- tgt_mask: shape:(T, T), 含義同src_mask驼壶,但是Decoder的輸入是需要遮住后面的單詞的,所以這里的mask是一個(gè)三角矩陣(下三角是0喉酌,上三角是-inf热凹。當(dāng)然也可以用True, False表示)
- src_key_padding_mask: shape:(N, S) 因?yàn)槟P偷妮斎胍话闶且粋€(gè)batch泪电,實(shí)際場(chǎng)景中輸入的句子或者序列是不等長(zhǎng)的般妙,那么就需要將不等長(zhǎng)的多個(gè)序列通過padding的方式補(bǔ)齊成等長(zhǎng)的。那么padding的位置無意義相速,不需要參與attention的計(jì)算碟渺,所以通src_key_padding_mask來標(biāo)記padding的位置。后面計(jì)算attention的時(shí)候不計(jì)算該位置的權(quán)重
- tgt_key_padding_mask: shape:(N, T) 含義同src_key_padding_mask
src_mask和蚪,tgt_mask止状,src_key_padding_mask,tgt_key_padding_mask雖然在這里是分開的攒霹,但是在計(jì)算attention時(shí)候怯疤,實(shí)際是合并到同一個(gè)attention矩陣中的;維度不同催束,通過廣播的方式合并
2.2 TransformerEncoder集峦、TransformerDecoder
TransformerEncoder、TransformerDecoder邏輯是類似的抠刺,就是執(zhí)行TransformerEncoderLayer多次塔淤,默認(rèn)是6次,以TransformerEncoder為例
class TransformerEncoder(Module):
r"""TransformerEncoderLayer 堆疊N次
Args:
encoder_layer: TransformerEncoderLayer(子模塊)
num_layers: 堆疊次數(shù)
norm:
"""
def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoder, self).__init__()
# 模塊復(fù)制N次
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
output = src
# 執(zhí)行N次TransformerEncoderLayer
for mod in self.layers:
output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
# 正則
if self.norm is not None:
output = self.norm(output)
return output
2.3 TransformerEncoderLayer
class TransformerEncoderLayer(Module):
r"""步驟2和3的具體實(shí)現(xiàn).
Args:
基本上見名知意
"""
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(TransformerEncoderLayer, self).__init__()
# 步驟2的實(shí)現(xiàn)
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
**factory_kwargs)
# 步驟3的實(shí)現(xiàn)速妖,兩個(gè)全連接層
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
"""
x = src
if self.norm_first: # norm操作放在哪里執(zhí)行
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
x = x + self._ff_block(self.norm2(x))
else:
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
x = self.norm2(x + self._ff_block(x))
return x
# 步驟2
def _sa_block(self, x: Tensor,
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
x = self.self_attn(x, x, x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False)[0]
return self.dropout1(x)
# 步驟3
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
上面除了MultiheadAttention的實(shí)現(xiàn)細(xì)節(jié)高蜂,其他邏輯是很清楚的
2.4 TransformerDecoderLayer
TransformerDecoderLayer的實(shí)現(xiàn)邏輯與TransformerEncoderLayer實(shí)現(xiàn)差不多,有兩點(diǎn)不同
- 第一個(gè)注意力(步驟4)的mask罕容,需要遮住后續(xù)單詞信息
- 第二個(gè)注意力(不足5)的K备恤,V來自Encoder的輸出(代碼中叫memory)
class TransformerDecoderLayer(Module):
r"""步驟4和5的具體實(shí)現(xiàn).
Args:
見名知意
"""
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
# 步驟4
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
**factory_kwargs)
# 步驟5
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
**factory_kwargs)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
r"""Pass the inputs (and mask) through the decoder layer.
"""
x = tgt
# 無論是_sa_block()還是_mha_block()的具體實(shí)現(xiàn)都在MultiheadAttention中
if self.norm_first:
# _sa_block對(duì)應(yīng)步驟4稿饰,
x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
# _mha_block對(duì)應(yīng)步驟5,只不過Q來自于Decoder本身露泊,K喉镰,V來自于Encoder,就是這里memory
x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)
x = x + self._ff_block(self.norm3(x))
else:
x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
x = self.norm3(x + self._ff_block(x))
return x
# # 步驟4
def _sa_block(self, x: Tensor,
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
# Q=K=V=x,
x = self.self_attn(x, x, x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False)[0]
return self.dropout1(x)
# # 步驟5
def _mha_block(self, x: Tensor, mem: Tensor,
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
# Q=x(來自于_sa_block()的輸出), K=V=mem(來自Encoder的輸出)惭笑,
# attn_mask=None, key_padding_mask等于src_key_padding_mask
x = self.multihead_attn(x, mem, mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False)[0]
return self.dropout2(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout3(x)
整個(gè)流程以及參數(shù)已經(jīng)說明了侣姆,TransformerEncoderLayer、TransformerDecoderLayer的代碼實(shí)現(xiàn)中就剩MultiheadAttention實(shí)現(xiàn)沒有講了
2.5 MultiheadAttention
MultiheadAttention的核心實(shí)現(xiàn)在multi_head_attention_forward方法中沉噩,我們直接看multi_head_attention_forward方法捺宗。
以下代碼刪除了非核心參數(shù)和非核心代碼
def multi_head_attention_forward(
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
multi_head_attention 實(shí)現(xiàn)
Shape:
"""
# set up shape vars
tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape
match value shape {value.shape}"
# prep attention mask
# 將attn_mask變成3維的,方便后面與key_padding_mask合并
if attn_mask is not None:
if attn_mask.dtype == torch.uint8:
warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
attn_mask = attn_mask.to(torch.bool)
# ensure attn_mask's dim is 3
if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, src_len)
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
# prep key padding mask
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
key_padding_mask = key_padding_mask.to(torch.bool)
#
# reshape q, k, v for multihead attention and make em batch first
#
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
# update source sequence length after adjustments
src_len = k.size(1)
# merge key padding and attention masks
# key_padding_mask 和 attn_mask合并屁擅,
if key_padding_mask is not None:
# key_padding_mask 做一些維度的變換
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
if attn_mask is None:
attn_mask = key_padding_mask
elif attn_mask.dtype == torch.bool:
# 合并attention mask
attn_mask = attn_mask.logical_or(key_padding_mask)
else:
attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
# 將attn_mask變成float類型偿凭,True變成負(fù)無窮,F(xiàn)alse變成0
if attn_mask is not None and attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
attn_mask = new_attn_mask
#
# (deep breath) calculate attention and out projection
# QKV計(jì)算
attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
return attn_output, None
上面的代碼實(shí)現(xiàn)的核心邏輯是Attention(Q派歌,K弯囊,V)的計(jì)算,還有一個(gè)就是Transformer的輸入?yún)?shù)*_mask, *_key_padding_mask是這么影響最終的注意力權(quán)重的胶果;就是將兩個(gè)mask合并為attn_mask,最后加到上
三匾嘱、實(shí)際應(yīng)用
實(shí)際應(yīng)用其實(shí)官方有一個(gè)翻譯的例子LANGUAGE TRANSLATION WITH NN.TRANSFORMER AND TORCHTEXT,說的還是很清楚的早抠■樱可以參考。
通過上面Transformer類的詳細(xì)說明蕊连,我們是否可以訓(xùn)練自已的seq2seq模型了呢悬垃?其實(shí)沒有,還有幾件事要做
將輸入字符變成一個(gè)個(gè)數(shù)字甘苍,需要自已按照使用場(chǎng)景實(shí)現(xiàn)一個(gè)類尝蠕。以中文為例,是將一個(gè)漢字mapping到一個(gè)索引载庭,還是將一個(gè)詞mapping到一個(gè)索引看彼。可以仿照pytorch的torchtext.vocab.build_vocab_from_iterator去實(shí)現(xiàn)
將上面轉(zhuǎn)化后的索引list轉(zhuǎn)化為囚聚,Transformer類的輸入靖榕。就需要實(shí)現(xiàn)一個(gè)映射詞向量表和位置編碼。比較“我愛中國(guó)”轉(zhuǎn)化為索引后可能是[300, 250, 10, 888],詞向量表需要將這個(gè)list變成(4顽铸, 512)的向量茁计,即用一個(gè)512位的向量來表示一個(gè)單詞
這里給出一個(gè)實(shí)現(xiàn), 其實(shí)就將pytorch的Transformer類加上輸入編碼和位置編碼部分。
class Seq2SeqTransformer(nn.Module):
def __init__(self,
num_encoder_layers: int,
num_decoder_layers: int,
emb_size: int,
nhead: int,
src_vocab_size: int,
tgt_vocab_size: int,
dim_feedforward: int = 512,
dropout: float = 0.1):
super(Seq2SeqTransformer, self).__init__()
self.transformer = Transformer(d_model=emb_size,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout)
self.generator = nn.Linear(emb_size, tgt_vocab_size)
self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
self.positional_encoding = PositionalEncoding(
emb_size, dropout=dropout)
def forward(self,
src: Tensor,
trg: Tensor,
src_mask: Tensor,
tgt_mask: Tensor,
src_padding_mask: Tensor,
tgt_padding_mask: Tensor,
memory_key_padding_mask: Tensor):
"""
:param src:
:param trg:
:param src_mask: 用于遮擋句子的下文谓松,shape(S, S)
:param tgt_mask:
:param src_padding_mask: 用于指定pad位置簸淀,shape(B, S)
:param tgt_padding_mask:
:param memory_key_padding_mask:
:return:
"""
src_emb = self.positional_encoding(self.src_tok_emb(src))
tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
return self.generator(outs)
def encode(self, src: Tensor, src_mask: Tensor):
return self.transformer.encoder(self.positional_encoding(
self.src_tok_emb(src)), src_mask)
def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
return self.transformer.decoder(self.positional_encoding(
self.tgt_tok_emb(tgt)), memory,
tgt_mask)
# 單詞編碼
class TokenEmbedding(nn.Module):
def __init__(self, vocab_size: int, emb_size):
super(TokenEmbedding, self).__init__()
self.embedding = nn.Embedding(vocab_size, emb_size)
self.emb_size = emb_size
def forward(self, tokens: Tensor):
return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
# 位置編碼
class PositionalEncoding(nn.Module):
def __init__(self,
emb_size: int,
dropout: float,
maxlen: int = 5000):
super(PositionalEncoding, self).__init__()
den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
pos = torch.arange(0, maxlen).reshape(maxlen, 1)
pos_embedding = torch.zeros((maxlen, emb_size))
pos_embedding[:, 0::2] = torch.sin(pos * den)
pos_embedding[:, 1::2] = torch.cos(pos * den)
pos_embedding = pos_embedding.unsqueeze(-2)
self.dropout = nn.Dropout(dropout)
self.register_buffer('pos_embedding', pos_embedding)
def forward(self, token_embedding: Tensor):
return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
- 實(shí)際的翻譯過程
這一部分等我有空且想完善的時(shí)候瓶蝴,會(huì)繼續(xù)完善,在這之前強(qiáng)烈建議大家直接參考官方的例子租幕,見參考文獻(xiàn)一