mxnet RNN簡(jiǎn)單剖析

import mxnet as mx

官方github教程部分代碼

網(wǎng)絡(luò)生成

num_layers = 2
num_hidden = 256
stack = mx.rnn.SequentialRNNCell()
for i in range(num_layers):
    stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_'%i))
  • mx.rnn.SequentialRNNCell():RNN容器嫉鲸,用于組合多個(gè)RNN層
  • mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_'%i):LSTM單元
num_embed = 256
def sym_gen(seq_len):
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=1000,output_dim=num_embed, name='embed')
#   數(shù)據(jù)生成世吨,定義Variable并進(jìn)行詞向量化

    stack.reset()
    outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)
#   按時(shí)間展開輸出和狀態(tài)
    
    pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
    pred = mx.sym.FullyConnected(data=pred, num_hidden=1000, name='pred')
#   變換輸出形式炼彪,將輸出變?yōu)?-1,num_hidden)尺寸

    label = mx.sym.Reshape(label, shape=(-1,))
    pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
#   展平label典挑,并計(jì)算代價(jià)函數(shù)
    
    return pred, ('data',), ('softmax_label',)
sym_gen(1)
(<Symbol softmax>, ('data',), ('softmax_label',))
  • unroll()函數(shù)按時(shí)間展開RNN單元,輸出最終的運(yùn)算結(jié)果
  • 輸出接全連接層,再轉(zhuǎn)換為詞向量

官方API文檔代碼

數(shù)據(jù)轉(zhuǎn)換

step_input = mx.symbol.Variable('step_data')

# First we embed our raw input data to be used as LSTM's input.
embedded_step = mx.symbol.Embedding(data=step_input, \
                                    input_dim=50, \
                                    output_dim=50)
# print(embedded_step.shape)
mx.viz.plot_network(symbol=embedded_step)
# Then we create an LSTM cell.
output_7_0.png

Embedding是一種詞向量化技術(shù),這種技術(shù)可以保持語義(例如相近語義的詞的向量距離會(huì)較近),將尺寸為(d0,d1...dn)的輸入向量進(jìn)行詞向量化技術(shù)后轉(zhuǎn)換為尺寸為(d0,d1,...,dn,out_dim)的向量获诈,多出的一維為詞向量,即使用一個(gè)向量代替原來一個(gè)詞的位置心褐。

  • 參數(shù)input_dim為輸入向量的范圍舔涎,即輸入data的范圍在[0,input_dim)之間
  • 參數(shù)output_dim為詞向量大小
  • 可選參數(shù)weight,可傳入指定的詞向量字典
  • 可選參數(shù)name逗爹,可傳入名稱
vocabulary_size = 26
embed_dim = 16
seq_len, batch_size = (10, 64)
input = mx.sym.Variable('letters')
op = mx.sym.Embedding(data=input, input_dim=vocabulary_size, output_dim=embed_dim,name='embed')
op.infer_shape(letters=(seq_len, batch_size))
([(10, 64), (26, 16)], [(10, 64, 16)], [])

上文的例子可以看出輸入向量尺寸為(10,64),輸出向量尺寸變?yōu)榱耍?0,64,16)

網(wǎng)絡(luò)構(gòu)建

使用了隱層為50的LSTM單元亡嫌,并帶入轉(zhuǎn)換好的數(shù)據(jù),該圖繪制出的lstm圖較經(jīng)典LSTM有一些出入

lstm_cell = mx.rnn.LSTMCell(num_hidden=50)
begin_state = lstm_cell.begin_state()
output, states = lstm_cell(embedded_step, begin_state)
mx.viz.plot_network(symbol=output)
output_11_0.png

LSTM的源碼的構(gòu)造函數(shù)如下:

def __init__(self, num_hidden, prefix='lstm_', params=None, forget_bias=1.0):
        super(LSTMCell, self).__init__(prefix=prefix, params=params)

        self._num_hidden = num_hidden
        self._iW = self.params.get('i2h_weight')
        self._hW = self.params.get('h2h_weight')
        # we add the forget_bias to i2h_bias, this adds the bias to the forget gate activation
        self._iB = self.params.get('i2h_bias', init=init.LSTMBias(forget_bias=forget_bias))
        self._hB = self.params.get('h2h_bias')

其中:self.params.get()方法為嘗試找到傳入名稱對(duì)應(yīng)的Variable掘而,若找不到則新建挟冠,因此該LSTM單元一共僅有兩對(duì)參數(shù):iW和iB,hW和hB

前向傳播函數(shù)如下:

    def __call__(self, inputs, states):
        self._counter += 1
        name = '%st%d_'%(self._prefix, self._counter)
        i2h = symbol.FullyConnected(data=inputs, weight=self._iW, bias=self._iB,
                                    num_hidden=self._num_hidden*4,
                                    name='%si2h'%name)
        h2h = symbol.FullyConnected(data=states[0], weight=self._hW, bias=self._hB,
                                    num_hidden=self._num_hidden*4,
                                    name='%sh2h'%name)
        gates = i2h + h2h
        slice_gates = symbol.SliceChannel(gates, num_outputs=4,name="%sslice"%name)
        in_gate = symbol.Activation(slice_gates[0], act_type="sigmoid",name='%si'%name)
        forget_gate = symbol.Activation(slice_gates[1], act_type="sigmoid",name='%sf'%name)
        in_transform = symbol.Activation(slice_gates[2], act_type="tanh",name='%sc'%name)
        out_gate = symbol.Activation(slice_gates[3], act_type="sigmoid",name='%so'%name)
        next_c = symbol._internal._plus(forget_gate * states[1], in_gate * in_transform,name='%sstate'%name)
        next_h = symbol._internal._mul(out_gate, symbol.Activation(next_c, act_type="tanh"),name='%sout'%name)
        return next_h, [next_h, next_c]

可以看出袍睡,LSTM的實(shí)現(xiàn)過程如下所示

  1. 計(jì)算隱層輸入與狀態(tài)知染,隱層的channel數(shù)量是配置的hidden_num的四倍
  2. 將隱層輸入結(jié)果和隱層狀態(tài)相加,并按channel數(shù)量切分為4份
    • 第一份作為輸入門層斑胜,經(jīng)過sigmoid函數(shù)
    • 第二份作為忘記門層控淡,經(jīng)過sigmoid函數(shù)
    • 第三份作為輸入轉(zhuǎn)換層嫌吠,經(jīng)過tanh函數(shù)
    • 第四份作為輸出門層,經(jīng)過sigmoid函數(shù)
  3. 產(chǎn)生輸出
    • 輸出狀態(tài)為忘記門層乘狀態(tài)的一部分加輸入門層乘輸入轉(zhuǎn)換層
    • 輸出結(jié)果為輸出狀態(tài)經(jīng)過tanh乘輸出門層

結(jié)果生成

sequence_length = 10
input_dim = 10
seq_input = mx.symbol.Variable('seq_data')
embedded_seq = mx.symbol.Embedding(data=seq_input, \
                                   input_dim=input_dim, \
                                   output_dim=embed_dim)
outputs, states = lstm_cell.unroll(length=sequence_length, \
                                   inputs=embedded_seq, \
                                   layout='NTC', \
                                   merge_outputs=True)

使用unroll方法按時(shí)間展平運(yùn)算掺炭,輸入數(shù)據(jù)為(batch_size,lenght,...)(layout="NTC)或(lenght,batch,...)(layout="TNC)

該函數(shù)的源碼為:

def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
        self.reset()

        inputs, _ = _normalize_sequence(length, inputs, layout, False)
        if begin_state is None:
            begin_state = self.begin_state()
        states = begin_state
        outputs = []
        for i in range(length):
            output, states = self(inputs[i], states)
            outputs.append(output)
        outputs, _ = _normalize_sequence(length, outputs, layout, merge_outputs)
        return outputs, states

方法_normalize_sequence是對(duì)輸入做一些處理辫诅,由一個(gè)for循環(huán)可以看出該方法循環(huán)了網(wǎng)絡(luò)運(yùn)算

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市竹伸,隨后出現(xiàn)的幾起案子泥栖,更是在濱河造成了極大的恐慌簇宽,老刑警劉巖勋篓,帶你破解...
    沈念sama閱讀 221,548評(píng)論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異魏割,居然都是意外死亡譬嚣,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,497評(píng)論 3 399
  • 文/潘曉璐 我一進(jìn)店門钞它,熙熙樓的掌柜王于貴愁眉苦臉地迎上來拜银,“玉大人,你說我怎么就攤上這事遭垛∧嵬埃” “怎么了?”我有些...
    開封第一講書人閱讀 167,990評(píng)論 0 360
  • 文/不壞的土叔 我叫張陵锯仪,是天一觀的道長(zhǎng)泵督。 經(jīng)常有香客問我,道長(zhǎng)庶喜,這世上最難降的妖魔是什么小腊? 我笑而不...
    開封第一講書人閱讀 59,618評(píng)論 1 296
  • 正文 為了忘掉前任,我火速辦了婚禮久窟,結(jié)果婚禮上秩冈,老公的妹妹穿的比我還像新娘。我一直安慰自己斥扛,他們只是感情好入问,可當(dāng)我...
    茶點(diǎn)故事閱讀 68,618評(píng)論 6 397
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著稀颁,像睡著了一般芬失。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上峻村,一...
    開封第一講書人閱讀 52,246評(píng)論 1 308
  • 那天麸折,我揣著相機(jī)與錄音,去河邊找鬼粘昨。 笑死垢啼,一個(gè)胖子當(dāng)著我的面吹牛窜锯,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播芭析,決...
    沈念sama閱讀 40,819評(píng)論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼锚扎,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了馁启?” 一聲冷哼從身側(cè)響起驾孔,我...
    開封第一講書人閱讀 39,725評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎惯疙,沒想到半個(gè)月后翠勉,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 46,268評(píng)論 1 320
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡霉颠,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,356評(píng)論 3 340
  • 正文 我和宋清朗相戀三年对碌,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片蒿偎。...
    茶點(diǎn)故事閱讀 40,488評(píng)論 1 352
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡朽们,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出诉位,到底是詐尸還是另有隱情骑脱,我是刑警寧澤,帶...
    沈念sama閱讀 36,181評(píng)論 5 350
  • 正文 年R本政府宣布苍糠,位于F島的核電站叁丧,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏椿息。R本人自食惡果不足惜歹袁,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,862評(píng)論 3 333
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望寝优。 院中可真熱鬧条舔,春花似錦、人聲如沸乏矾。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,331評(píng)論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽钻心。三九已至凄硼,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間捷沸,已是汗流浹背摊沉。 一陣腳步聲響...
    開封第一講書人閱讀 33,445評(píng)論 1 272
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留痒给,地道東北人说墨。 一個(gè)月前我還...
    沈念sama閱讀 48,897評(píng)論 3 376
  • 正文 我出身青樓骏全,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國和親尼斧。 傳聞我的和親對(duì)象是個(gè)殘疾皇子姜贡,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,500評(píng)論 2 359

推薦閱讀更多精彩內(nèi)容