14Seq2Seq原理(1)

先看看seq2seq原理:

encoder通過學(xué)習(xí)將輸入embedding后傳入rnn網(wǎng)絡(luò)形成一個固定大小的狀態(tài)向量S承桥,并將S傳給Decoder, Decoder一樣通過學(xué)習(xí)embedding后傳入RNN網(wǎng)絡(luò),并輸出預(yù)測結(jié)果盏筐。

優(yōu)缺點:

這樣可以解決輸入和輸出不等長的問題庙洼,如文本翻譯泛释。但是因為encoder到decoder都依賴一個固定大小的狀態(tài)向量S淘讥,所以咱們可以想象一下,信息越大悲关,轉(zhuǎn)化成為的S損失越大,隨著序列長度增加娄柳,S損失的信息會越來越大寓辱。這個是Seq2seq的缺陷,所以要引入attention及Bi-directional encoder layer等赤拒。

1.Encoder分兩步:

1.1 首先把輸入進行embedding完成對輸入序列數(shù)據(jù)嵌入工作秫筏,這里用到tf.contrib.layers.embed_sequence。
假如我們有一個batch=2挎挖,sequence_length=5的樣本跳昼,features = [[1,2,3,4,5],[6,7,8,9,10]],使用tf.contrib.layers.embed_sequence(features,vocab_size=n_words, embed_dim=10)
那么我們會得到一個2 x 5 x 10的輸出肋乍,其中features中的每個數(shù)字都被embed成了一個10維向量。

encoder_embed_input = tf.contrib.layers.embed_sequence(input_data, source_vocab_size, encoding_embedding_size)

1.2 然后embedding完的向量傳入RNN進行處理敷存,返回encoder_output, encoder_state

 def get_lstm_cell(rnn_size):
        lstm_cell = tf.contrib.rnn.LSTMCell(rnn_size, initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
        return lstm_cell
    cell = tf.contrib.rnn.MultiRNNCell([get_lstm_cell(rnn_size) for _ in range(num_layers)])
    encoder_output, encoder_state = tf.nn.dynamic_rnn(cell, encoder_embed_input,
                                                     sequence_length=source_sequence_length, dtype=tf.float32)

2.Decoder分三步

2.1 對target數(shù)據(jù)進行預(yù)處理
為什么這一步要做預(yù)處理墓造?

image.png

  • 左邊encoder紅框很簡單,A,B,C融匯成一個輸出
  • 右邊decoder紅框接受一個輸出后锚烦,傳給每個RNN進行解碼
  • <GO>為解碼開始符 <EOS>為解碼結(jié)束符

我們預(yù)處理就要對encoder傳過來的輸出(添加<GO>觅闽,去掉<EOS>),用tf.strided_slice()

def process_decoder_input(data, vocab_to_int, batch_size):
    '''
    補充<GO>涮俄,并移除最后一個字符
    '''
    # cut掉最后一個字符
    ending = tf.strided_slice(data, [0, 0], [batch_size, -1], [1, 1])
    decoder_input = tf.concat([tf.fill([batch_size, 1], vocab_to_int['<GO>']), ending], 1)

    return decoder_input

2.2 對target數(shù)據(jù)進行embedding

 target_vocab_size = len(target_letter_to_int)
    decoder_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))
    decoder_embed_input = tf.nn.embedding_lookup(decoder_embeddings, decoder_input)

2.3 處理完的數(shù)據(jù)傳入RNN蛉拙,返回訓(xùn)練和預(yù)測的output

def get_decoder_cell(rnn_size):
        decoder_cell = tf.contrib.rnn.LSTMCell(rnn_size,
                                           initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
        return decoder_cell
    cell = tf.contrib.rnn.MultiRNNCell([get_decoder_cell(rnn_size) for _ in range(num_layers)])

output_layer = Dense(target_vocab_size,
                         kernel_initializer = tf.truncated_normal_initializer(mean = 0.0, stddev=0.1))

Training Decoder

with tf.variable_scope("decode"):
    # 得到help對象
    training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_embed_input,
                                                        sequence_length=target_sequence_length,
                                                        time_major=False)
    # 構(gòu)造decoder
    training_decoder = tf.contrib.seq2seq.BasicDecoder(cell,
                                                       training_helper,
                                                       encoder_state,
                                                       output_layer) 
    training_decoder_output, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                   impute_finished=True,
                                                                   maximum_iterations=max_target_sequence_length)

Prediction decoder

# 與training共享參數(shù)
    with tf.variable_scope("decode", reuse=True):
        # 創(chuàng)建一個常量tensor并復(fù)制為batch_size的大小
        start_tokens = tf.tile(tf.constant([target_letter_to_int['<GO>']], dtype=tf.int32), [batch_size], 
                               name='start_tokens')
        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embeddings,
                                                                start_tokens,
                                                                target_letter_to_int['<EOS>'])
        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(cell,
                                                        predicting_helper,
                                                        encoder_state,
                                                        output_layer)
        predicting_decoder_output, _ = tf.contrib.seq2seq.dynamic_decode(predicting_decoder,
                                                            impute_finished=True,
                                                            maximum_iterations=max_target_sequence_length)
    
    return training_decoder_output, predicting_decoder_output

3.完成了encoder和decoder之后,再把兩者連接起來形成seq2seq模型

def seq2seq_model(input_data, targets, lr, target_sequence_length, 
                  max_target_sequence_length, source_sequence_length,
                  source_vocab_size, target_vocab_size,
                  encoder_embedding_size, decoder_embedding_size, 
                  rnn_size, num_layers):
    
    # 獲取encoder的狀態(tài)輸出
    _, encoder_state = get_encoder_layer(input_data, 
                                  rnn_size, 
                                  num_layers, 
                                  source_sequence_length,
                                  source_vocab_size, 
                                  encoding_embedding_size)
    
    
    # 預(yù)處理后的decoder輸入
    decoder_input = process_decoder_input(targets, target_letter_to_int, batch_size)
    
    # 將狀態(tài)向量與輸入傳遞給decoder
    training_decoder_output, predicting_decoder_output = decoding_layer(target_letter_to_int, 
                                                                       decoding_embedding_size, 
                                                                       num_layers, 
                                                                       rnn_size,
                                                                       target_sequence_length,
                                                                       max_target_sequence_length,
                                                                       encoder_state, 
                                                                       decoder_input) 
    
    return training_decoder_output, predicting_decoder_output

這是個簡單的seq2seq模型彻亲,只是對單詞的字母進行簡單的排序孕锄,數(shù)據(jù)處理部分也比較簡單,符合本篇的宗旨:講清楚什么是seq2seq模型苞尝。下一篇將應(yīng)用seq2seq模型進行英法兩種語言的文本翻譯實戰(zhàn)畸肆。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市宙址,隨后出現(xiàn)的幾起案子轴脐,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 219,110評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件大咱,死亡現(xiàn)場離奇詭異恬涧,居然都是意外死亡,警方通過查閱死者的電腦和手機碴巾,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,443評論 3 395
  • 文/潘曉璐 我一進店門溯捆,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人餐抢,你說我怎么就攤上這事现使。” “怎么了旷痕?”我有些...
    開封第一講書人閱讀 165,474評論 0 356
  • 文/不壞的土叔 我叫張陵碳锈,是天一觀的道長。 經(jīng)常有香客問我欺抗,道長售碳,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,881評論 1 295
  • 正文 為了忘掉前任绞呈,我火速辦了婚禮贸人,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘佃声。我一直安慰自己艺智,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 67,902評論 6 392
  • 文/花漫 我一把揭開白布圾亏。 她就那樣靜靜地躺著十拣,像睡著了一般。 火紅的嫁衣襯著肌膚如雪志鹃。 梳的紋絲不亂的頭發(fā)上夭问,一...
    開封第一講書人閱讀 51,698評論 1 305
  • 那天,我揣著相機與錄音曹铃,去河邊找鬼缰趋。 笑死,一個胖子當(dāng)著我的面吹牛陕见,可吹牛的內(nèi)容都是我干的秘血。 我是一名探鬼主播,決...
    沈念sama閱讀 40,418評論 3 419
  • 文/蒼蘭香墨 我猛地睜開眼评甜,長吁一口氣:“原來是場噩夢啊……” “哼直撤!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起蜕着,我...
    開封第一講書人閱讀 39,332評論 0 276
  • 序言:老撾萬榮一對情侶失蹤谋竖,失蹤者是張志新(化名)和其女友劉穎红柱,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體蓖乘,經(jīng)...
    沈念sama閱讀 45,796評論 1 316
  • 正文 獨居荒郊野嶺守林人離奇死亡锤悄,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,968評論 3 337
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了嘉抒。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片零聚。...
    茶點故事閱讀 40,110評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖些侍,靈堂內(nèi)的尸體忽然破棺而出隶症,到底是詐尸還是另有隱情,我是刑警寧澤岗宣,帶...
    沈念sama閱讀 35,792評論 5 346
  • 正文 年R本政府宣布蚂会,位于F島的核電站,受9級特大地震影響耗式,放射性物質(zhì)發(fā)生泄漏胁住。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,455評論 3 331
  • 文/蒙蒙 一刊咳、第九天 我趴在偏房一處隱蔽的房頂上張望彪见。 院中可真熱鬧,春花似錦娱挨、人聲如沸余指。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,003評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽浪规。三九已至,卻和暖如春探孝,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背誉裆。 一陣腳步聲響...
    開封第一講書人閱讀 33,130評論 1 272
  • 我被黑心中介騙來泰國打工顿颅, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人足丢。 一個月前我還...
    沈念sama閱讀 48,348評論 3 373
  • 正文 我出身青樓粱腻,卻偏偏與公主長得像,于是被迫代替她去往敵國和親斩跌。 傳聞我的和親對象是個殘疾皇子绍些,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,047評論 2 355

推薦閱讀更多精彩內(nèi)容