word2vec/lstm on mxnet with NCE loss

Softmax是用來實現(xiàn)多類分類問題常見的損失函數(shù)摊求。但如果類別特別多,softmax的效率就是個問題了。比如在word2vec里访锻,每個詞都是一個類別褪尝,在這種情況下可能有100萬類。那么每次都得預(yù)測一個樣本在100萬類上屬于每個類的概率期犬,這個效率是非常低的河哑。

為了解決這個問題,在word2vec里面提出了基于Huffman編碼的層次Softmax(HS)龟虎。HS的結(jié)構(gòu)還是過于復(fù)雜璃谨,因此后來又有人提出了基于采樣的NCE(其實NCE和Negative Sampling是2個不同的paper提出的東西,形式上有所區(qū)別鲤妥,不過我覺得本質(zhì)是沒有區(qū)別的)佳吞。因此我們可以把HS或者NCE作為多類分類問題的Loss Layer。

所有的代碼目前在https://github.com/xlvector/learning-dl/tree/master/mxnet/nce-loss棉安。

為了體驗一下Softmax和NCE的速度差別底扳,我們實現(xiàn)了兩個例子 toy_softmax.py 和 toy_nce.py。我們虛構(gòu)了一個多類分類問題贡耽,他的構(gòu)造方法如下:

def mock_sample(self):
    ret = np.zeros(self.feature_size)
    rn = set()
    while len(rn) < 3:
        rn.add(random.randint(0, self.feature_size - 1))
    s = 0
    for k in rn:
        ret[k] = 1.0
        s *= self.feature_size
        s += k
    return ret, s % self.vocab_size

上面feature_size 是輸入特征的維度衷模,vocab_size是類別的數(shù)目。

toy_softmax.py 用普通的softmax來做多類分類問題蒲赂,網(wǎng)絡(luò)結(jié)構(gòu)如下:

def get_net(vocab_size):
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    pred = mx.sym.FullyConnected(data = data, num_hidden = 100)
    pred = mx.sym.FullyConnected(data = pred, num_hidden = vocab_size)
    sm = mx.sym.SoftmaxOutput(data = pred, label = label)
    return sm

運行速度和類別個數(shù)的關(guān)系如下

類別數(shù) 每秒處理的樣本數(shù)
100 40000
1000 30000
10000 10000
100000 1000

可以看到阱冶,在類別數(shù)從10000提高到100000時,速度直接降為原來的1/10滥嘴。

在看看toy_nce.py木蹬,他的網(wǎng)絡(luò)結(jié)構(gòu)如下:

def get_net(vocab_size, num_label):
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')
    embed_weight = mx.sym.Variable('embed_weight')
    pred = mx.sym.FullyConnected(data = data, num_hidden = 100)
    return nce_loss(data = pred,
                    label = label,
                    label_weight = label_weight,
                    embed_weight = embed_weight,
                    vocab_size = vocab_size,
                    num_hidden = 100,
                    num_label = num_label)

其中,nce_loss的結(jié)構(gòu)如下:

def nce_loss(data, label, label_weight, embed_weight, vocab_size, num_hidden, num_label):
    label_embed = mx.sym.Embedding(data = label, input_dim = vocab_size,
                                   weight = embed_weight,
                                   output_dim = num_hidden, name = 'label_embed')
    label_embed = mx.sym.SliceChannel(data = label_embed,
                                      num_outputs = num_label,
                                      squeeze_axis = 1, name = 'label_slice')
    label_weight = mx.sym.SliceChannel(data = label_weight,
                                       num_outputs = num_label,
                                       squeeze_axis = 1)
    probs = []
    for i in range(num_label):
        vec = label_embed[i]
        vec = vec * data
        vec = mx.sym.sum(vec, axis = 1)
        sm = mx.sym.LogisticRegressionOutput(data = vec,
                                             label = label_weight[i])
        probs.append(sm)
    return mx.sym.Group(probs)

NCE的主要思想是氏涩,對于每一個樣本届囚,除了他自己的label,同時采樣出N個其他的label是尖,從而我們只需要計算樣本在這N+1個label上的概率意系,而不用計算樣本在所有l(wèi)abel上的概率。而樣本在每個label上的概率最終用了Logistic的損失函數(shù)饺汹。再來看看NCE的速度和類別數(shù)之間的關(guān)系:

類別數(shù) 每秒處理的樣本數(shù)
100 30000
1000 30000
10000 30000
100000 20000

可以看到NCE的速度相對于類別數(shù)并不敏感蛔添。

有了NCE Loss后,就可以用mxnet來訓(xùn)練word2vec了兜辞。word2vec的其中一個CBOW模型是用一個詞周圍的N個詞去預(yù)測這個詞迎瞧,我們可以設(shè)計如下的網(wǎng)絡(luò)結(jié)構(gòu):

def get_net(vocab_size, num_input, num_label):
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')
    embed_weight = mx.sym.Variable('embed_weight')
    data_embed = mx.sym.Embedding(data = data, input_dim = vocab_size,
                                  weight = embed_weight,
                                  output_dim = 100, name = 'data_embed')
    datavec = mx.sym.SliceChannel(data = data_embed,
                                     num_outputs = num_input,
                                     squeeze_axis = 1, name = 'data_slice')
    pred = datavec[0]
    for i in range(1, num_input):
        pred = pred + datavec[i]
    return nce_loss(data = pred,
                    label = label,
                    label_weight = label_weight,
                    embed_weight = embed_weight,
                    vocab_size = vocab_size,
                    num_hidden = 100,
                    num_label = num_label)

如上面的結(jié)構(gòu),輸入是num_input個詞語逸吵。輸出是num_label個詞語凶硅,其中有1個詞語是正樣本,剩下是負樣本扫皱。這里足绅,input的embeding和label的embeding都用了同一個embed矩陣embed_weight捷绑。

執(zhí)行wordvec.py (需要把text8放在./data/下面),就可以看到訓(xùn)練結(jié)果氢妈。

接著word2vec的思路粹污,可以繼續(xù)把lstm也用上NCE loss。網(wǎng)絡(luò)結(jié)構(gòu)如下:

def get_net(vocab_size, seq_len, num_label, num_lstm_layer, num_hidden):
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
        
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')
    embed_weight = mx.sym.Variable('embed_weight')
    label_embed_weight = mx.sym.Variable('label_embed_weight')
    data_embed = mx.sym.Embedding(data = data, input_dim = vocab_size,
                                  weight = embed_weight,
                                  output_dim = 100, name = 'data_embed')
    datavec = mx.sym.SliceChannel(data = data_embed,
                                  num_outputs = seq_len,
                                  squeeze_axis = True, name = 'data_slice')
    labelvec = mx.sym.SliceChannel(data = label,
                                   num_outputs = seq_len,
                                   squeeze_axis = True, name = 'label_slice')
    labelweightvec = mx.sym.SliceChannel(data = label_weight,
                                         num_outputs = seq_len,
                                         squeeze_axis = True, name = 'label_weight_slice')
    probs = []
    for seqidx in range(seq_len):
        hidden = datavec[seqidx]
        
        for i in range(num_lstm_layer):
            next_state = lstm(num_hidden, indata = hidden,
                              prev_state = last_states[i],
                              param = param_cells[i],
                              seqidx = seqidx, layeridx = i)
            hidden = next_state.h
            last_states[i] = next_state
            
        probs += nce_loss(data = hidden,
                          label = labelvec[seqidx],
                          label_weight = labelweightvec[seqidx],
                          embed_weight = label_embed_weight,
                          vocab_size = vocab_size,
                          num_hidden = 100,
                          num_label = num_label)
    return mx.sym.Group(probs)

參考

  1. Tensorflow 關(guān)于nce_loss的實現(xiàn)在 這里
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末首量,一起剝皮案震驚了整個濱河市壮吩,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌加缘,老刑警劉巖鸭叙,帶你破解...
    沈念sama閱讀 221,548評論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異生百,居然都是意外死亡递雀,警方通過查閱死者的電腦和手機柄延,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,497評論 3 399
  • 文/潘曉璐 我一進店門蚀浆,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人搜吧,你說我怎么就攤上這事市俊。” “怎么了滤奈?”我有些...
    開封第一講書人閱讀 167,990評論 0 360
  • 文/不壞的土叔 我叫張陵摆昧,是天一觀的道長。 經(jīng)常有香客問我蜒程,道長绅你,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 59,618評論 1 296
  • 正文 為了忘掉前任昭躺,我火速辦了婚禮忌锯,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘领炫。我一直安慰自己偶垮,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 68,618評論 6 397
  • 文/花漫 我一把揭開白布帝洪。 她就那樣靜靜地躺著似舵,像睡著了一般。 火紅的嫁衣襯著肌膚如雪葱峡。 梳的紋絲不亂的頭發(fā)上砚哗,一...
    開封第一講書人閱讀 52,246評論 1 308
  • 那天,我揣著相機與錄音砰奕,去河邊找鬼蛛芥。 笑死泌参,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的常空。 我是一名探鬼主播沽一,決...
    沈念sama閱讀 40,819評論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼漓糙!你這毒婦竟也來了铣缠?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,725評論 0 276
  • 序言:老撾萬榮一對情侶失蹤昆禽,失蹤者是張志新(化名)和其女友劉穎蝗蛙,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體醉鳖,經(jīng)...
    沈念sama閱讀 46,268評論 1 320
  • 正文 獨居荒郊野嶺守林人離奇死亡捡硅,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,356評論 3 340
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了盗棵。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片壮韭。...
    茶點故事閱讀 40,488評論 1 352
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖纹因,靈堂內(nèi)的尸體忽然破棺而出喷屋,到底是詐尸還是另有隱情,我是刑警寧澤瞭恰,帶...
    沈念sama閱讀 36,181評論 5 350
  • 正文 年R本政府宣布屯曹,位于F島的核電站,受9級特大地震影響惊畏,放射性物質(zhì)發(fā)生泄漏恶耽。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,862評論 3 333
  • 文/蒙蒙 一颜启、第九天 我趴在偏房一處隱蔽的房頂上張望偷俭。 院中可真熱鬧,春花似錦农曲、人聲如沸社搅。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,331評論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽形葬。三九已至,卻和暖如春暮的,著一層夾襖步出監(jiān)牢的瞬間笙以,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,445評論 1 272
  • 我被黑心中介騙來泰國打工冻辩, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留猖腕,地道東北人拆祈。 一個月前我還...
    沈念sama閱讀 48,897評論 3 376
  • 正文 我出身青樓,卻偏偏與公主長得像倘感,于是被迫代替她去往敵國和親放坏。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,500評論 2 359

推薦閱讀更多精彩內(nèi)容