pyTorch版OpenNMT的學習筆記

前言

2017年1月18日Touch7的開發(fā)團隊發(fā)布了pyTorch会烙,pyTorch是一個python優(yōu)先的深度學習框架倦卖,能夠在GPU加速的基礎上實現(xiàn)Tensor計算和動態(tài)神經(jīng)網(wǎng)絡帚桩。
是的姨俩,相較于G家以靜態(tài)圖為基礎的tensorFlow球散,pyTorch的動態(tài)神經(jīng)網(wǎng)絡結(jié)構(gòu)更加靈活强缘,其通過一種稱之為「Reverse-mode auto-differentiation(反向模式自動微分)」的技術欣鳖,使你可以零延遲或零成本地任意改變你的網(wǎng)絡的行為察皇。(然而我暫時并沒有領略到這項技術的精髓... -.-!)
關于pyTorch細節(jié)的問題另做討論,這里說一說正題--基于pyTorch實現(xiàn)的OpenNMT泽台。

prepocess.py

preprocess.py相對來說比較好理解什荣,但對于OpenNMT-py環(huán)環(huán)相扣的編程方法感到很新奇,函數(shù)封裝的很細致怀酷,便于后續(xù)的debug或修改稻爬,對自己以后的編程是一個很好的啟發(fā)。此外其代碼很優(yōu)雅(beam search部分除外蜕依,稍后會有介紹)桅锄。
關于這部分代碼中makedata函數(shù)中:

if opt.shuffle == 1:
    print('... shuffling sentences')
    perm = torch.randperm(len(src))
    src = [src[idx] for idx in perm]
    tgt = [tgt[idx] for idx in perm]
    sizes = [sizes[idx] for idx in perm]

print('... sorting sentences by size')
_, perm = torch.sort(torch.Tensor(sizes))
src = [src[idx] for idx in perm]
tgt = [tgt[idx] for idx in perm]

預先shuffle一下琉雳,再根據(jù)句子長度排序,這樣在每一種長度的句子的內(nèi)部友瘤,句子是順序是隨機的咐吼,按照句長排序,使每一個batch中的句長基本相等商佑,以加快訓練速度锯茄。
而以下這部分代碼:

        src += [srcDicts.convertToIdx(srcWords,
                                      onmt.Constants.UNK_WORD)]
        tgt += [tgtDicts.convertToIdx(tgtWords,
                                      onmt.Constants.UNK_WORD,
                                      onmt.Constants.BOS_WORD,
                                      onmt.Constants.EOS_WORD)]

tgt語句中,在句前加了BOS符號茶没,在句末加了EOS符號肌幽。

prepocess.py最后保存了一個.pt文件,其中:

  • dict:字典格式抓半,保存有'src'和'tgt'的兩個字典
  • train:字典格式喂急,保存有'src'和'tgt'兩個Dict類
  • valid:字典格式,保存有'src'和'tgt'兩個Dict類

此外笛求,還對dict字典進行了存儲廊移。


train.py

直接從main()函數(shù)的'Building model'開始說起吧,中間串聯(lián)對各個函數(shù)的理解探入。
這里的encoder直接調(diào)用了pyTorch封裝好的nn.LSTM()類狡孔,其初始化參數(shù)包括:

  • input_size : input的Embedding_size
  • hidden_size : 隱狀態(tài)的數(shù)量
  • num_layers : 層數(shù)
  • bias : 默認為True,如果設置為False,網(wǎng)絡將不使用 b_ih,b_hh蜂嗽。(詳見鏈接中LSTM中的計算公式)
  • batch_fisrt : 如果設置為True,輸入和輸出的形狀將變?yōu)椋╞atch x seq_length x embedding_size)
  • dropout : 如果非0苗膝,除了最后一層,縱向?qū)又g植旧,丟棄(1-dropout)比例的隱藏神經(jīng)元
  • bidirectional : 默認為False辱揭,如果為True,成為雙向的RNN
    LTSM的輸入為:input,(h_0,c_0)
  • input : seq_len x batch x enbedding_size
  • h_0 : num_layers * num_directions x batch x hidden_size
  • c_0 : num_layers * num_directions x batch x hidden_size
    輸出為:
  • output :seq_len x batch x hidden_size * num_directions
  • h_n : num_layers * num_directions x batch x hidden_size
  • c_n : num_layers * num_directions x batch x hidden_size
    decoder中self.rnn卻是用LSTMCell()堆疊出來的病附,然而為什么要這么做呢问窃?-.-!
    LSTMCell()的輸入輸出維度為:
    輸入:
  • input : batch x embedding_size
  • h_0 : batch x hidden_size
  • c_0 : batch x hidden_size
    輸出:
  • h_1 : batch x hidden_size
  • c_1 : batch x hidden_size
    在decoder中引入了attention機制,類似于于pytorch tutorials中seq2seq模型中的attention機制完沪,
    圖片截自pytorch tutorials

    但又略有不同域庇,如圖在bmm的到attn_applied之后,OpenNMT-py代碼沒有選擇將attn_applied與embedd相結(jié)合丽焊,而是經(jīng)過一次softmax后變形為batch x 1 x src_sent_length(attn3) 较剃,再和context 矩陣相乘(weightedContext)后與input連接(contextCombined),最后經(jīng)過線性變化再取tanh后返回技健。
    (其實對attention機制這樣的處理方式并沒有一個直觀理解写穴,求大神講解)
    模型部分說明完畢接下來看看trainModel函數(shù),這里首先需要注意的一點是,在Dataset.py中重寫了getitem方法雌贱,每次給trainData一個一個batchIdx去的是一個batch的數(shù)據(jù)啊送,也重寫了len方法偿短,用len(trainData)返回的是numBatchs。
    然后將batch輸入進model馋没,batch輸入進model之后將tgt切掉最后一維EOS符號的昔逗,然后默認是以Teacher forcing的方式進行訓練。Teacher forcing 就是將tgt的值作為decoder每次的輸入篷朵,而不是使用其產(chǎn)生的預測值勾怒,這樣做的好處就是可以使模型更快的收斂,但是對沒有見到過的句子效果可能欠佳声旺。

translata.py

這里面的重點是:Translator.py文件中的translateBatch()函數(shù)笔链。

    #  (2) if a target is specified, compute the 'goldScore'
    #  (i.e. log likelihood) of the target under the model
    goldScores = context.data.new(batchSize).zero_()
    if tgtBatch is not None:
        decStates = encStates
        decOut = self.model.make_init_decoder_output(context)
        self.model.decoder.apply(applyContextMask)
        initOutput = self.model.make_init_decoder_output(context)

        decOut, decStates, attn = self.model.decoder(
            tgtBatch[:-1], decStates, context, initOutput)
        for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
            gen_t = self.model.generator.forward(dec_t)
            tgt_t = tgt_t.unsqueeze(1)
            scores = gen_t.data.gather(1, tgt_t)
            scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
            goldScores += scores

其中這部分代碼,是計算model翻譯的結(jié)果與標準答案對比后獲得分數(shù)腮猖,分數(shù)由翻譯正確的詞的概率取和得到鉴扫。
接下來重點說明一下,OpenNMT-py優(yōu)雅的代碼中的一個槽點澈缺,beam-search部分坪创,實在寫的略難理解。
首先:

    context = Variable(context.data.repeat(1, beamSize, 1))
    decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)),
                 Variable(encStates[1].data.repeat(1, beamSize, 1)))

    beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]

將encoder 輸出的context,decStates各沿第二維方向重復beamsize遍姐赡,其中context維度由seq_len x batch x hidden_size * num_directions變?yōu)閟eq_len x batch*beamsize x hidden_size * num_directions莱预,并將beam初始化為一個含有batch個Beam類的列表。

        input = torch.stack([b.getCurrentState() for b in beam
                           if not b.done]).t().contiguous().view(1, -1)

這行代碼將每個beam中上一時間步的預測值取出來雏吭,再將得到的batch x beam_size 轉(zhuǎn)置成beam_size x batch 后在view成一行锁施,沒隔batch個數(shù)據(jù)屬于同一個beam,形成beam_size個batch恰好與context和decStates的seq_len x batch*beam_size x rnn_size相對應。而model計算之后的out與input相對應杖们,故

wordLk = out.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()

此處對view的計算方法論存疑,理解上out應該是batch * beam x num_words ,

wordLk = out.view(remainingSents, beamsize, -1).contiguous()

就可以直接得到batch x beam x num_words 肩狂。
然后關注Beam.advance()方法摘完,
其中的prevKs是后指針,即記錄的是這一步結(jié)果對應來自上一步nextYs的第幾個值傻谁,nextYs記錄的是每一時間步產(chǎn)生的beam_size個最佳結(jié)果的idx孝治。
因為每次傳進beam_size x num_words個值,展成一個列表之后選取的最佳beam_size個值在整除num_words后得到的是這個最佳值來自那個beam审磁,而bestScoresId - prevK * numWords得到的是最佳結(jié)果的idx谈飒。

(另有細節(jié)問題,會不定時更新态蒂。)

最后編輯于
?著作權歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末杭措,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子钾恢,更是在濱河造成了極大的恐慌手素,老刑警劉巖鸳址,帶你破解...
    沈念sama閱讀 211,561評論 6 492
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異泉懦,居然都是意外死亡稿黍,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,218評論 3 385
  • 文/潘曉璐 我一進店門崩哩,熙熙樓的掌柜王于貴愁眉苦臉地迎上來巡球,“玉大人,你說我怎么就攤上這事邓嘹『ㄕ唬” “怎么了?”我有些...
    開封第一講書人閱讀 157,162評論 0 348
  • 文/不壞的土叔 我叫張陵吴超,是天一觀的道長钉嘹。 經(jīng)常有香客問我,道長鲸阻,這世上最難降的妖魔是什么跋涣? 我笑而不...
    開封第一講書人閱讀 56,470評論 1 283
  • 正文 為了忘掉前任,我火速辦了婚禮鸟悴,結(jié)果婚禮上陈辱,老公的妹妹穿的比我還像新娘。我一直安慰自己细诸,他們只是感情好沛贪,可當我...
    茶點故事閱讀 65,550評論 6 385
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著震贵,像睡著了一般利赋。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上猩系,一...
    開封第一講書人閱讀 49,806評論 1 290
  • 那天媚送,我揣著相機與錄音,去河邊找鬼寇甸。 笑死塘偎,一個胖子當著我的面吹牛,可吹牛的內(nèi)容都是我干的拿霉。 我是一名探鬼主播吟秩,決...
    沈念sama閱讀 38,951評論 3 407
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼绽淘!你這毒婦竟也來了涵防?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 37,712評論 0 266
  • 序言:老撾萬榮一對情侶失蹤收恢,失蹤者是張志新(化名)和其女友劉穎武学,沒想到半個月后祭往,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,166評論 1 303
  • 正文 獨居荒郊野嶺守林人離奇死亡火窒,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,510評論 2 327
  • 正文 我和宋清朗相戀三年硼补,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片熏矿。...
    茶點故事閱讀 38,643評論 1 340
  • 序言:一個原本活蹦亂跳的男人離奇死亡已骇,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出票编,到底是詐尸還是另有隱情褪储,我是刑警寧澤,帶...
    沈念sama閱讀 34,306評論 4 330
  • 正文 年R本政府宣布慧域,位于F島的核電站鲤竹,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏昔榴。R本人自食惡果不足惜辛藻,卻給世界環(huán)境...
    茶點故事閱讀 39,930評論 3 313
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望互订。 院中可真熱鬧吱肌,春花似錦、人聲如沸仰禽。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,745評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽吐葵。三九已至规揪,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間温峭,已是汗流浹背粒褒。 一陣腳步聲響...
    開封第一講書人閱讀 31,983評論 1 266
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留诚镰,地道東北人。 一個月前我還...
    沈念sama閱讀 46,351評論 2 360
  • 正文 我出身青樓祥款,卻偏偏與公主長得像清笨,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子刃跛,可洞房花燭夜當晚...
    茶點故事閱讀 43,509評論 2 348

推薦閱讀更多精彩內(nèi)容