Transformer in pytorch

一 Transformer overview

本文結(jié)合pytorch源碼以盡可能簡(jiǎn)潔的方式把Transformer的工作流程講解以及原理講解清楚。全文分為三個(gè)部分

  1. Transformer架構(gòu):這個(gè)模塊的詳細(xì)說明
  2. pytorch中Transformer的api解讀
  3. 實(shí)際運(yùn)用:雖然Transformer的api使用大大簡(jiǎn)化了打碼量,但是還有需要自已實(shí)現(xiàn)一些代碼的

Transformer架構(gòu)

Transformer結(jié)構(gòu)如下:


image.png

Transformer的經(jīng)典應(yīng)用場(chǎng)景就是機(jī)器翻譯。
整體分為Encoder揍诽、Decoder兩大部分,具體實(shí)現(xiàn)細(xì)分為六塊。

  1. 輸入編碼什湘、位置編碼

    Encoder、Decoder都需要將輸入字符進(jìn)行編碼送入網(wǎng)絡(luò)訓(xùn)練晦攒。

    Input Embeding:將一個(gè)字符(或者漢字)進(jìn)行編碼闽撤,比如“我愛中國(guó)”四個(gè)漢字編碼后會(huì)變成(4,d_model)的矩陣脯颜,Transformer中d_model等于512哟旗,那么輸入就變成(4,512)的矩陣,為了方便敘述闸餐,后面都用(4饱亮,512)來當(dāng)成模型的輸入

    positional encoding:在Q舍沙、K近上、V的計(jì)算過程中,輸入單詞的位置信息會(huì)丟失掉拂铡。所以需要額外用一個(gè)位置編碼來表示輸入單詞的順序壹无。編碼公式如下

    PE_{pos,2i}=sin(pos/1000^{2i/d_{model}})

    PE_{pos,2i+1}=cos(pos/1000^{2i/d_{model}})

    其中,pos:表示第幾個(gè)單詞感帅,2i,2i+1表示Input Embeding編碼維度(512)的偶數(shù)位斗锭、奇數(shù)位。
    論文中作者也試過將positional encoding變成可以學(xué)習(xí)的失球,但是發(fā)現(xiàn)效果差不多岖是;而且使用硬位置編碼就不用考慮在推斷環(huán)節(jié)中句子的實(shí)際長(zhǎng)度超過訓(xùn)練環(huán)節(jié)中使用的位置編碼長(zhǎng)度的問題;為什么使用sin她倘、cos呢璧微?可以有效的考慮句中單詞的相對(duì)位置信息

  2. 多頭注意力機(jī)制(Multi-Head Attention)

    多頭注意力機(jī)制是Transformer的核心,屬于Self-Attention(自注意力)硬梁。注意只要是可以注意到自身特征的注意力機(jī)制就叫Self-Attention前硫,并不是只有Transformer有。 示意圖如下

    image.png

    Multi-Head Attention的輸入的Q荧止、K屹电、V就是輸入的(4,512)維矩陣,Q=K=V跃巡。然后用全連接層對(duì)Q危号、K、V做一個(gè)變換素邪,多頭就是指的這里將輸入映射到多個(gè)空間外莲。公式描述如下:

    MultiHead(Q,K,V)=Concat(head_1, head_2,..., head_n)W^o

    其中
    head_i=Attention(QW^Q_i, KW^K_i, VW^V_i)

    其中
    Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V

    其中W^Q_i\in R^{d_{model}*d_k}, W^K_i\in R^{d_{model}*d_k}, W^V_i\in R^{d_{model}*d_v}, W^o\in R^{hd_v*d_{model}}, 論文中h=8, d_k=d_v=d_{model}/h=512/8=64

    QK^T稱為注意力矩陣(attention),表示兩個(gè)句子中的任意兩個(gè)單詞的相關(guān)性。所以attention mask不一定是方陣兔朦。

  3. 前向傳播模塊

    Q偷线、K、V經(jīng)過Multi-Head Attention模塊再加上一個(gè)殘差跳鏈沽甥,維度不變声邦,輸入維度是(4,512)摆舟,輸出維度還是(4,512)亥曹,只不過輸出的矩陣的每一行已經(jīng)融合了其他行的信息(根據(jù)attention mask)邓了。
    這里前向傳播模塊是一個(gè)兩層的全連接。公式如下:

    FFN(x)=max(0, xW_1+b_1)W_2+b_2, 其中輸入輸出維度為d_model=512, 中間維度d_{ff}=2048

  4. 帶Mask的多頭注意力機(jī)制

    這里的Mask Multi-head Attention與步驟2中的稍有不同媳瞪∑“我愛中國(guó)”的英文翻譯為“I love china”。 在翻譯到“l(fā)ove”的時(shí)候材失,其實(shí)我們是不知道“china”的這個(gè)單詞的痕鳍,所以在訓(xùn)練的時(shí)候,就需要來模擬這個(gè)過程龙巨。即用一個(gè)mask來遮住后面信息笼呆。這個(gè)mask在實(shí)際實(shí)現(xiàn)中是一個(gè)三角矩陣(主對(duì)角線及下方為0,上方為-inf)旨别, 定義為attention\_mask大概就長(zhǎng)下面這個(gè)樣子

    attention mask

    Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V, 加上attention\_mask后數(shù)學(xué)表達(dá)為

    Masked\_Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}}+attention_{mask})V诗赌。

    -inf在經(jīng)過softmax后變成0,就相當(dāng)于忽略對(duì)應(yīng)的單詞信息

  5. Decoder中的多頭注意力機(jī)制

    這里的Multi-head Attention和步驟2中的注意力就差不多是一個(gè)意思了秸弛,但是Attention(Q,K,V)中K铭若,V是來自Encoder的輸出,Q是Mask Multi-head Attention后的輸出

  6. 預(yù)測(cè)

    整個(gè)Transformer的輸入維度是(4,512)递览,輸出維度是(4,512)叼屠。那么如果變成最終具體的單詞呢(假如單詞表大小為10000)。那么最后的輸出output必須是(4绞铃,10000)镜雨。max(output, dim=-1)的下標(biāo)就是單詞的序號(hào)。

以上就是Transformer的核心流程以及解釋儿捧,那么下面接下來看一下Transformer在pytorch中的具體實(shí)現(xiàn)荚坞。

二、Transformer在pytorch中具體實(shí)現(xiàn)

pytorch version: 1.11

2.1 Transformer類

要自已實(shí)現(xiàn)一個(gè)完整的Transformer菲盾,還是有點(diǎn)難度的颓影,好在pytorch提供了官方實(shí)現(xiàn)。所有的核心細(xì)節(jié)都被封裝Transformer類懒鉴。特別說明诡挂,下面的代碼講解中會(huì)刪除非核心代碼,只保留核心細(xì)節(jié)

下面代碼講解會(huì)對(duì)矩陣的一些維度做一些說明临谱,這里統(tǒng)一下符號(hào)

  • N: batch size
  • S: Encoder輸入序列的長(zhǎng)度
  • T: Decoder輸入序列的長(zhǎng)度
  • E: embeding的維度咆畏,就是上文中d_{model}
class Transformer(Module):
"""   
Args:
        d_model: 單詞維度(default=512).
        nhead: 多頭注意力中的head數(shù)量 (default=8).
        num_encoder_layers: Encoder中子Encoder堆疊的數(shù)量(default=6).
        num_decoder_layers: Decoder中子Decoder堆疊的數(shù)量(default=6).
        dim_feedforward: 前向傳播的中間維度 (default=2048).
        dropout: the dropout value (default=0.1).
        activation:  Default: relu
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).
        layer_norm_eps:  (default=1e-5).
        batch_first: Default: ``False`` (seq, batch, feature).
        norm_first: 
"""
    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(Transformer, self).__init__()

        # 步驟2和3的實(shí)現(xiàn)
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                activation, layer_norm_eps, batch_first, norm_first,
                                                **factory_kwargs)
        encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        # 將TransformerEncoderLayer執(zhí)行6次,簡(jiǎn)單的堆疊而已
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        # 步驟4吴裤、5的實(shí)現(xiàn)
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                activation, layer_norm_eps, batch_first, norm_first,
                                                **factory_kwargs)
        decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)


        self.d_model = d_model
        self.nhead = nhead

        self.batch_first = batch_first

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        
        """
        src, tgt: 這里的src、tgt是已經(jīng)經(jīng)過input embeding和positional emdbeding后的輸入溺健。
        memory_mask:實(shí)際使用的過程中都是None
        memory_key_padding_mask:實(shí)際使用的過程中就是src_key_padding_mask
        """
        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        # 這里的memory就是步驟5中K麦牺,V矩陣
        # memory_key_padding_mask實(shí)際使用過程就是傳入的src_key_padding_mask钮蛛,在步驟5中需要遮住Encoder中padding的位置
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,tgt_key_padding_mask=tgt_key_padding_mask剖膳,memory_key_padding_mask=memory_key_padding_mask)
        return output

    @staticmethod
    def generate_square_subsequent_mask(sz: int) -> Tensor:
        r"""用來生成步驟4中attention mask."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

關(guān)鍵參數(shù)說明

  • batch_first:默認(rèn)False魏颓,在pytorch中,rnn吱晒、Transformer層的輸入維度一般是(seq, batch, feature)甸饱,第一個(gè)維度表示seq的長(zhǎng)度,batch放在第二個(gè)維度
  • src仑濒、tgt:由于batch_first=False叹话, 所以Encoder的輸入src、Decoder的輸入tgt的shape都是(seq, batch, feature)
  • src_mask: shape:(S, S), 含義就是上文講的attenion_mask, Encoder的輸入是不需要遮住后面的單詞的墩瞳,所以該參數(shù)一般是一個(gè)全為False的陣
  • tgt_mask: shape:(T, T), 含義同src_mask驼壶,但是Decoder的輸入是需要遮住后面的單詞的,所以這里的mask是一個(gè)三角矩陣(下三角是0喉酌,上三角是-inf热凹。當(dāng)然也可以用True, False表示)
  • src_key_padding_mask: shape:(N, S) 因?yàn)槟P偷妮斎胍话闶且粋€(gè)batch泪电,實(shí)際場(chǎng)景中輸入的句子或者序列是不等長(zhǎng)的般妙,那么就需要將不等長(zhǎng)的多個(gè)序列通過padding的方式補(bǔ)齊成等長(zhǎng)的。那么padding的位置無意義相速,不需要參與attention的計(jì)算碟渺,所以通src_key_padding_mask來標(biāo)記padding的位置。后面計(jì)算attention的時(shí)候不計(jì)算該位置的權(quán)重
  • tgt_key_padding_mask: shape:(N, T) 含義同src_key_padding_mask

src_mask和蚪,tgt_mask止状,src_key_padding_mask,tgt_key_padding_mask雖然在這里是分開的攒霹,但是在計(jì)算attention時(shí)候怯疤,實(shí)際是合并到同一個(gè)attention矩陣中的;維度不同催束,通過廣播的方式合并

2.2 TransformerEncoder集峦、TransformerDecoder

TransformerEncoder、TransformerDecoder邏輯是類似的抠刺,就是執(zhí)行TransformerEncoderLayer多次塔淤,默認(rèn)是6次,以TransformerEncoder為例

class TransformerEncoder(Module):
    r"""TransformerEncoderLayer 堆疊N次

    Args:
        encoder_layer: TransformerEncoderLayer(子模塊)
        num_layers: 堆疊次數(shù)
        norm: 

    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        # 模塊復(fù)制N次
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
     
        output = src
        # 執(zhí)行N次TransformerEncoderLayer
        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        # 正則
        if self.norm is not None:
            output = self.norm(output)

        return output

2.3 TransformerEncoderLayer

class TransformerEncoderLayer(Module):
    r"""步驟2和3的具體實(shí)現(xiàn).

    Args:
        基本上見名知意

    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerEncoderLayer, self).__init__()

        # 步驟2的實(shí)現(xiàn)
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)

        # 步驟3的實(shí)現(xiàn)速妖,兩個(gè)全連接層
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)


    

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        """
        x = src
        if self.norm_first:  # norm操作放在哪里執(zhí)行
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        return x

    # 步驟2
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # 步驟3
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

上面除了MultiheadAttention的實(shí)現(xiàn)細(xì)節(jié)高蜂,其他邏輯是很清楚的

2.4 TransformerDecoderLayer

TransformerDecoderLayer的實(shí)現(xiàn)邏輯與TransformerEncoderLayer實(shí)現(xiàn)差不多,有兩點(diǎn)不同

  1. 第一個(gè)注意力(步驟4)的mask罕容,需要遮住后續(xù)單詞信息
  2. 第二個(gè)注意力(不足5)的K备恤,V來自Encoder的輸出(代碼中叫memory)
class TransformerDecoderLayer(Module):
    r"""步驟4和5的具體實(shí)現(xiàn).
    Args:
        見名知意

    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        
        # 步驟4
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # 步驟5                                    
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                 **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)


    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

       
        """
        
        x = tgt
        # 無論是_sa_block()還是_mha_block()的具體實(shí)現(xiàn)都在MultiheadAttention中
        if self.norm_first:
        # _sa_block對(duì)應(yīng)步驟4稿饰,
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
        # _mha_block對(duì)應(yīng)步驟5,只不過Q來自于Decoder本身露泊,K喉镰,V來自于Encoder,就是這里memory
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
            x = self.norm3(x + self._ff_block(x))

        return x

    # # 步驟4
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:

        # Q=K=V=x, 
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # # 步驟5
    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        # Q=x(來自于_sa_block()的輸出), K=V=mem(來自Encoder的輸出)惭笑,
        # attn_mask=None, key_padding_mask等于src_key_padding_mask
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                need_weights=False)[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)

整個(gè)流程以及參數(shù)已經(jīng)說明了侣姆,TransformerEncoderLayer、TransformerDecoderLayer的代碼實(shí)現(xiàn)中就剩MultiheadAttention實(shí)現(xiàn)沒有講了

2.5 MultiheadAttention

MultiheadAttention的核心實(shí)現(xiàn)在multi_head_attention_forward方法中沉噩,我們直接看multi_head_attention_forward方法捺宗。
以下代碼刪除了非核心參數(shù)和非核心代碼

def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    key_padding_mask: Optional[Tensor] = None,
    attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    Args:
        multi_head_attention 實(shí)現(xiàn)
    Shape:
       
    """
    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    match value shape {value.shape}"

    # prep attention mask
    # 將attn_mask變成3維的,方便后面與key_padding_mask合并
    if attn_mask is not None:
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # prep key padding mask
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        key_padding_mask = key_padding_mask.to(torch.bool)

    
    #
    # reshape q, k, v for multihead attention and make em batch first
    #
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
   

   
    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    # key_padding_mask 和 attn_mask合并屁擅,
    if key_padding_mask is not None:
        # key_padding_mask 做一些維度的變換
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        elif attn_mask.dtype == torch.bool:
            # 合并attention mask
            attn_mask = attn_mask.logical_or(key_padding_mask)
        else:
            attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))

    # 將attn_mask變成float類型偿凭,True變成負(fù)無窮,F(xiàn)alse變成0
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
        new_attn_mask.masked_fill_(attn_mask, float("-inf"))
        attn_mask = new_attn_mask

    #
    # (deep breath) calculate attention and out projection
    # QKV計(jì)算
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)

    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    return attn_output, None

上面的代碼實(shí)現(xiàn)的核心邏輯是Attention(Q派歌,K弯囊,V)的計(jì)算,還有一個(gè)就是Transformer的輸入?yún)?shù)*_mask, *_key_padding_mask是這么影響最終的注意力權(quán)重的胶果;就是將兩個(gè)mask合并為attn_mask,最后加到QK^T

三匾嘱、實(shí)際應(yīng)用

實(shí)際應(yīng)用其實(shí)官方有一個(gè)翻譯的例子LANGUAGE TRANSLATION WITH NN.TRANSFORMER AND TORCHTEXT,說的還是很清楚的早抠■樱可以參考。

通過上面Transformer類的詳細(xì)說明蕊连,我們是否可以訓(xùn)練自已的seq2seq模型了呢悬垃?其實(shí)沒有,還有幾件事要做

  1. 將輸入字符變成一個(gè)個(gè)數(shù)字甘苍,需要自已按照使用場(chǎng)景實(shí)現(xiàn)一個(gè)類尝蠕。以中文為例,是將一個(gè)漢字mapping到一個(gè)索引载庭,還是將一個(gè)詞mapping到一個(gè)索引看彼。可以仿照pytorch的torchtext.vocab.build_vocab_from_iterator去實(shí)現(xiàn)

  2. 將上面轉(zhuǎn)化后的索引list轉(zhuǎn)化為囚聚,Transformer類的輸入靖榕。就需要實(shí)現(xiàn)一個(gè)映射詞向量表和位置編碼。比較“我愛中國(guó)”轉(zhuǎn)化為索引后可能是[300, 250, 10, 888],詞向量表需要將這個(gè)list變成(4顽铸, 512)的向量茁计,即用一個(gè)512位的向量來表示一個(gè)單詞
    這里給出一個(gè)實(shí)現(xiàn), 其實(shí)就將pytorch的Transformer類加上輸入編碼和位置編碼部分。

class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """

        :param src:
        :param trg:
        :param src_mask: 用于遮擋句子的下文谓松,shape(S, S)
        :param tgt_mask:
        :param src_padding_mask: 用于指定pad位置簸淀,shape(B, S)
        :param tgt_padding_mask:
        :param memory_key_padding_mask:
        :return:
        """
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

# 單詞編碼
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# 位置編碼
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
  1. 實(shí)際的翻譯過程

這一部分等我有空且想完善的時(shí)候瓶蝴,會(huì)繼續(xù)完善,在這之前強(qiáng)烈建議大家直接參考官方的例子租幕,見參考文獻(xiàn)一

參考文獻(xiàn)

  1. LANGUAGE TRANSLATION WITH NN.TRANSFORMER AND TORCHTEXT
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市拧簸,隨后出現(xiàn)的幾起案子劲绪,更是在濱河造成了極大的恐慌,老刑警劉巖盆赤,帶你破解...
    沈念sama閱讀 218,122評(píng)論 6 505
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件贾富,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡牺六,警方通過查閱死者的電腦和手機(jī)颤枪,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,070評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來淑际,“玉大人畏纲,你說我怎么就攤上這事〈郝疲” “怎么了盗胀?”我有些...
    開封第一講書人閱讀 164,491評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)锄贼。 經(jīng)常有香客問我票灰,道長(zhǎng),這世上最難降的妖魔是什么宅荤? 我笑而不...
    開封第一講書人閱讀 58,636評(píng)論 1 293
  • 正文 為了忘掉前任屑迂,我火速辦了婚禮,結(jié)果婚禮上冯键,老公的妹妹穿的比我還像新娘惹盼。我一直安慰自己,他們只是感情好琼了,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,676評(píng)論 6 392
  • 文/花漫 我一把揭開白布逻锐。 她就那樣靜靜地躺著,像睡著了一般雕薪。 火紅的嫁衣襯著肌膚如雪昧诱。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,541評(píng)論 1 305
  • 那天所袁,我揣著相機(jī)與錄音盏档,去河邊找鬼。 笑死燥爷,一個(gè)胖子當(dāng)著我的面吹牛蜈亩,可吹牛的內(nèi)容都是我干的懦窘。 我是一名探鬼主播,決...
    沈念sama閱讀 40,292評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼稚配,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼畅涂!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起道川,我...
    開封第一講書人閱讀 39,211評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤午衰,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后冒萄,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體臊岸,經(jīng)...
    沈念sama閱讀 45,655評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,846評(píng)論 3 336
  • 正文 我和宋清朗相戀三年尊流,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了帅戒。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 39,965評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡崖技,死狀恐怖逻住,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情响疚,我是刑警寧澤鄙信,帶...
    沈念sama閱讀 35,684評(píng)論 5 347
  • 正文 年R本政府宣布,位于F島的核電站忿晕,受9級(jí)特大地震影響装诡,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜践盼,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,295評(píng)論 3 329
  • 文/蒙蒙 一鸦采、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧咕幻,春花似錦渔伯、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,894評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至蓝厌,卻和暖如春玄叠,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背拓提。 一陣腳步聲響...
    開封第一講書人閱讀 33,012評(píng)論 1 269
  • 我被黑心中介騙來泰國(guó)打工读恃, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,126評(píng)論 3 370
  • 正文 我出身青樓寺惫,卻偏偏與公主長(zhǎng)得像疹吃,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子西雀,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,914評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容