理工男的文科夢 —— LSTM深度學(xué)習(xí)寫春聯(lián)

筆者作為一名根正苗紅的理工男吸奴,內(nèi)心卻常常有很多文藝青年才會有的想法,例如寫首詩缠局、做首詞则奥,甚至包括春節(jié)寫副對聯(lián),空有一番愿望卻胸?zé)o點墨狭园。隨著對機器學(xué)習(xí)和深度學(xué)習(xí)的了解读处,逐漸萌生了使用機器幫助筆者完成文藝青年的轉(zhuǎn)型。:)
本文借助遞歸神經(jīng)網(wǎng)絡(luò)RDD的變種之一LSTM算法唱矛,對收集到的6900多條對聯(lián)進行學(xué)習(xí)罚舱,訓(xùn)練好模型后可以由機器寫出對聯(lián)。

遞歸神經(jīng)網(wǎng)絡(luò)與LSTM

故事從人工神經(jīng)網(wǎng)絡(luò)開始绎谦,人工神經(jīng)網(wǎng)絡(luò)誕生已久管闷。如下圖所示,神經(jīng)網(wǎng)絡(luò)的基本結(jié)構(gòu)由輸入層窃肠、輸出層和一個或多個隱含層組成包个。

多層神經(jīng)網(wǎng)絡(luò)

全連接的神經(jīng)網(wǎng)絡(luò)下一層神經(jīng)元的輸入由上一層所有神經(jīng)元的輸出決定,因此帶來了一個嚴(yán)重的問題即參數(shù)數(shù)量過大導(dǎo)致無法訓(xùn)練冤留。因此碧囊,隨時神經(jīng)網(wǎng)絡(luò)的發(fā)展,衍生了一系列的變化纤怒。比較流行的有應(yīng)用于圖像識別領(lǐng)域的卷積神經(jīng)網(wǎng)絡(luò)CNN糯而、應(yīng)用于自然語言處理的遞歸神經(jīng)網(wǎng)絡(luò)RNN。本文應(yīng)用到的LSTM算法即為RNN的一種形態(tài)泊窘。RNN解決了這樣的問題:即樣本出現(xiàn)的時間順序?qū)τ谧匀徽Z言處理熄驼、語音識別、手寫體識別等應(yīng)用非常重要州既,神經(jīng)元的輸出可以在下一個時間戳直接作用到自身谜洽。因此RNN很適合處理時序?qū)Y(jié)果影響較深的領(lǐng)域萝映。
關(guān)于RNN和LSTM原理的說明可以移步 http://www.reibang.com/p/9dc9f41f0b29 吴叶,本文不多加贅言。

RNN
由LSTM作詩引發(fā)

由于LSTM算法非常適用自然語言處理領(lǐng)域序臂,因此網(wǎng)上出現(xiàn)了很多應(yīng)用LSTM做文字領(lǐng)域的嘗試蚌卤,例如: LSTM寫詩 中使用LSTM寫詩实束,LSTM創(chuàng)作歌詞中使用LSTM模仿歌手風(fēng)格寫歌詞,以及使用LSTM算法給小孩起名(是多么不靠譜的粑粑麻麻)逊彭。
因此咸灿,筆者突發(fā)想法,如果給一個足夠的春聯(lián)訓(xùn)練樣本侮叮,一樣可以照貓畫老虎避矢,訓(xùn)練一個可以寫對聯(lián)的文藝“機器模型”。因此囊榜,問題就分解為:找樣本审胸、寫算法、訓(xùn)練卸勺、應(yīng)用模型砂沛。

春聯(lián)樣本搜集和規(guī)整

借助于強大的度娘,費勁九牛之力曙求,從網(wǎng)上搜集了各式春聯(lián)共6900對碍庵,其中上聯(lián)下聯(lián)之間是用","分割區(qū)分上下聯(lián)悟狱,對聯(lián)之間是用"静浴。"區(qū)分一聯(lián)的結(jié)束。樣式如下:
訓(xùn)練樣本

這些樣本將會在訓(xùn)練階段進行類型轉(zhuǎn)換并輸入給LSTM模型中挤渐。如果您也想試下本文案例马绝,請私信我這些樣本(畢竟搜集訓(xùn)練樣本是個苦差事(: )

LSTM算法

本文使用TensorFlow進行建模,TensorFlow就無需多言挣菲,是這個領(lǐng)域目前最活躍的框架富稻。寫對聯(lián)的算法主要工作包括:根據(jù)樣本數(shù)據(jù)產(chǎn)生LSTM輸入數(shù)據(jù)和結(jié)果值;定義LSTM的模型以及損失函數(shù)白胀;將訓(xùn)練數(shù)據(jù)喂給TensorFlow用來訓(xùn)練模型椭赋。接下來會逐步列舉本例中使用的方法。

  • 訓(xùn)練數(shù)據(jù)轉(zhuǎn)換
    由于樣本數(shù)據(jù)是一條條漢字組成的對聯(lián)或杠,這樣的數(shù)據(jù)是無法交給模型訓(xùn)練的哪怔,因此需要對樣本數(shù)據(jù)進行轉(zhuǎn)換∠蚯溃基本思想是:
    • 將樣本的所有對聯(lián)加載錄入认境,統(tǒng)計出所有出現(xiàn)的漢字,并將漢字進行編碼挟鸠,例如:一共有10000個漢字出現(xiàn)在樣本中叉信,那么對出現(xiàn)的漢字按 0 - 999 進行編碼,每個漢字對應(yīng)一個編碼艘希。
    • 對原始樣本進行編碼轉(zhuǎn)換硼身,生成用數(shù)字編碼表示的對聯(lián)集硅急。
    • 每條對聯(lián)作為一個輸入序列,每批次訓(xùn)練batch_size條佳遂,生成輸入數(shù)據(jù)xdata营袜,輸出y值為xdata+1。因為文本分析的特點是有時序性丑罪。
couplet_file ="couplet.txt"
#對聯(lián)
couplets = []
with open(couplet_file,'r') as f:
    for line in f:
        try:
            content = line.replace(' ','')
            if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content:
                continue
            if len(content) < 5*3 or len(content) > 79*3:
                continue
            content = '[' + content + ']'
           # print chardet.detect(content)
            content = content.decode('utf-8')
            couplets.append(content)

        except Exception as e:
            pass

# 按字?jǐn)?shù)排序
couplets = sorted(couplets,key=lambda line: len(line))
print('對聯(lián)總數(shù): %d'%(len(couplets)))
# 統(tǒng)計每個字出現(xiàn)次數(shù)
all_words = []
for couplet in couplets:
    all_words += [word for word in couplet]

counter = collections.Counter(all_words)
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
words, _ = zip(*count_pairs)
words = words[:len(words)] + (' ',)
# 每個字映射為一個數(shù)字ID
word_num_map = dict(zip(words, range(len(words))))

to_num = lambda word: word_num_map.get(word, len(words))
couplets_vector = [ list(map(to_num, couplet)) for couplet in couplets]

# 每次取64首對聯(lián)進行訓(xùn)練, 此參數(shù)可以調(diào)整
batch_size = 64
n_chunk = len(couplets_vector) // batch_size
x_batches = []
y_batches = []
for i in range(n_chunk):
    start_index = i * batch_size#起始位置
    end_index = start_index + batch_size#結(jié)束位置

    batches = couplets_vector[start_index:end_index]
    length = max(map(len,batches))#每個batches中句子的最大長度
    xdata = np.full((batch_size,length), word_num_map[' '], np.int32)
    for row in range(batch_size):
        xdata[row,:len(batches[row])] = batches[row]
    ydata = np.copy(xdata)
    ydata[:,:-1] = xdata[:,1:]
    x_batches.append(xdata)
    y_batches.append(ydata)
  • 定義LSTM模型

    • 使用TF api tf.nn.rnn_cell.BasicLSTMCell定義cell為一個128維的ht的cell荚板。并使用MultiRNNCell 定義為兩層的LSTM。
    • 對訓(xùn)練樣本輸入進行embedding化吩屹。
    • 使用tf.nn.dynamic_rnn計算輸出值啸驯。(也可以通過循環(huán)step的方法,依次計算)
    • 加入softmax層祟峦。
def neural_network(rnn_size=128, num_layers=2):
    cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size, state_is_tuple=True)
    cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

    initial_state = cell.zero_state(batch_size, tf.float32)

    with tf.variable_scope('rnnlm'):
        softmax_w = tf.get_variable("softmax_w", [rnn_size, len(words)+1])
        softmax_b = tf.get_variable("softmax_b", [len(words)+1])
        with tf.device("/cpu:0"):
            embedding = tf.get_variable("embedding", [len(words)+1, rnn_size])
            inputs = tf.nn.embedding_lookup(embedding, input_data)

    outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state, scope='rnnlm')
    output = tf.reshape(outputs,[-1, rnn_size])

    logits = tf.matmul(output, softmax_w) + softmax_b
    probs = tf.nn.softmax(logits)
    return logits, last_state, probs, cell, initial_state
  • 訓(xùn)練階段
    • 使用TF sequence_loss_by_example計算所有examples(假設(shè)一句話有n個單詞罚斗,一個單詞及單詞所對應(yīng)的label就是一個example,所有example就是一句話中所有單詞)的加權(quán)交叉熵?fù)p失。
    • tf.gradients 計算梯度宅楞,并使用clip_by_global_norm控制梯度爆炸的問題针姿。梯度爆炸和梯度彌散的原因一樣,都是因為鏈?zhǔn)椒▌t求導(dǎo)的關(guān)系厌衙,導(dǎo)致梯度的指數(shù)級衰減距淫。為了避免梯度爆炸,需要對梯度進行修剪婶希。(來自網(wǎng)上的解釋榕暇,不明覺厲(: )
    • 定義步長,步長過大喻杈,會很可能越過最優(yōu)值彤枢,步長過小則使優(yōu)化的效率過低,長時間無法收斂筒饰。因此learning rate是一個需要適當(dāng)調(diào)整的參數(shù)缴啡。一個小技巧是,隨時訓(xùn)練的進行瓷们,即沿著梯度方向收斂的過程中业栅,適當(dāng)減小步長,不至于錯過最優(yōu)解谬晕。在代碼中 0.01 * (0.97 ** epoch)碘裕,learing rate基數(shù)值為0.01, 系數(shù)為0.97的epoch方,可以看出epoch越大攒钳,learing rate越小帮孔。
    • 分批次將樣本數(shù)據(jù)x_batches和y_batches喂給TF進行訓(xùn)練。
def train_neural_network():
    logits, last_state, _, _, _ = neural_network()
    targets = tf.reshape(output_targets, [-1])
    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([logits], [targets], [tf.ones_like(targets, dtype=tf.float32)], len(words))
    cost = tf.reduce_mean(loss)
    learning_rate = tf.Variable(0.0, trainable=False)
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(zip(grads, tvars))

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver = tf.train.Saver(tf.all_variables())

        for epoch in range(100):
            sess.run(tf.assign(learning_rate, 0.01 * (0.97 ** epoch)))
            n = 0
            for batche in range(n_chunk):
                train_loss, _ , _ = sess.run([cost, last_state, train_op], feed_dict={input_data: x_batches[n], output_targets: y_batches[n]})
                n += 1
                print(epoch, batche, train_loss)
            if epoch % 7 == 0:
                saver.save(sess, './couplet.module', global_step=epoch)
  • 訓(xùn)練結(jié)束 , 詩性大發(fā)

經(jīng)過漫長的訓(xùn)練(取決于樣本數(shù)和迭代次數(shù))夕玩, loss控制在1.5左右你弦。


loss

可以看到,經(jīng)過100次的迭代訓(xùn)練燎孟,每7次保存一次(saver.save(sess, './couplet.module', global_step=epoch)), 最后的模型保存在couplet.module-98里禽作。

modle

在eval階段,使用saver.restore(sess, 'couplet.module-98') 將訓(xùn)練好的模型加載, 因為機器算出來的依舊是上文提到的數(shù)字編碼揩页,因此需要再將數(shù)字轉(zhuǎn)為漢字旷偿。

好啦,來看看機器創(chuàng)作的對聯(lián)吧爆侣, 是不是有點意思呢萍程?

couplet
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市兔仰,隨后出現(xiàn)的幾起案子茫负,更是在濱河造成了極大的恐慌,老刑警劉巖乎赴,帶你破解...
    沈念sama閱讀 216,372評論 6 498
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件忍法,死亡現(xiàn)場離奇詭異,居然都是意外死亡榕吼,警方通過查閱死者的電腦和手機饿序,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,368評論 3 392
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來羹蚣,“玉大人原探,你說我怎么就攤上這事⊥缢兀” “怎么了咽弦?”我有些...
    開封第一講書人閱讀 162,415評論 0 353
  • 文/不壞的土叔 我叫張陵,是天一觀的道長胁出。 經(jīng)常有香客問我离唬,道長,這世上最難降的妖魔是什么划鸽? 我笑而不...
    開封第一講書人閱讀 58,157評論 1 292
  • 正文 為了忘掉前任输莺,我火速辦了婚禮,結(jié)果婚禮上裸诽,老公的妹妹穿的比我還像新娘嫂用。我一直安慰自己,他們只是感情好丈冬,可當(dāng)我...
    茶點故事閱讀 67,171評論 6 388
  • 文/花漫 我一把揭開白布嘱函。 她就那樣靜靜地躺著,像睡著了一般埂蕊。 火紅的嫁衣襯著肌膚如雪往弓。 梳的紋絲不亂的頭發(fā)上疏唾,一...
    開封第一講書人閱讀 51,125評論 1 297
  • 那天,我揣著相機與錄音函似,去河邊找鬼槐脏。 笑死,一個胖子當(dāng)著我的面吹牛撇寞,可吹牛的內(nèi)容都是我干的顿天。 我是一名探鬼主播,決...
    沈念sama閱讀 40,028評論 3 417
  • 文/蒼蘭香墨 我猛地睜開眼蔑担,長吁一口氣:“原來是場噩夢啊……” “哼牌废!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起啤握,我...
    開封第一講書人閱讀 38,887評論 0 274
  • 序言:老撾萬榮一對情侶失蹤鸟缕,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后排抬,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體叁扫,經(jīng)...
    沈念sama閱讀 45,310評論 1 310
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,533評論 2 332
  • 正文 我和宋清朗相戀三年畜埋,在試婚紗的時候發(fā)現(xiàn)自己被綠了莫绣。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 39,690評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡悠鞍,死狀恐怖对室,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情咖祭,我是刑警寧澤掩宜,帶...
    沈念sama閱讀 35,411評論 5 343
  • 正文 年R本政府宣布,位于F島的核電站么翰,受9級特大地震影響牺汤,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜浩嫌,卻給世界環(huán)境...
    茶點故事閱讀 41,004評論 3 325
  • 文/蒙蒙 一檐迟、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧码耐,春花似錦追迟、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,659評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春廓块,著一層夾襖步出監(jiān)牢的瞬間厢绝,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,812評論 1 268
  • 我被黑心中介騙來泰國打工带猴, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留昔汉,地道東北人。 一個月前我還...
    沈念sama閱讀 47,693評論 2 368
  • 正文 我出身青樓浓利,卻偏偏與公主長得像挤庇,于是被迫代替她去往敵國和親钞速。 傳聞我的和親對象是個殘疾皇子贷掖,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,577評論 2 353

推薦閱讀更多精彩內(nèi)容

  • 神經(jīng)結(jié)構(gòu)進步、GPU深度學(xué)習(xí)訓(xùn)練效率突破渴语。RNN苹威,時間序列數(shù)據(jù)有效,每個神經(jīng)元通過內(nèi)部組件保存輸入信息驾凶。 卷積神經(jīng)...
    利炳根閱讀 4,740評論 0 7
  • 作者 | 武維AI前線出品| ID:ai-front 前言 自然語言處理(簡稱NLP)牙甫,是研究計算機處理人類語言的...
    AI前線閱讀 2,571評論 0 8
  • 第二個Topic講深度學(xué)習(xí),承接前面的《淺談機器學(xué)習(xí)基礎(chǔ)》调违。 深度學(xué)習(xí)簡介 前面也提到過窟哺,機器學(xué)習(xí)的本質(zhì)就是尋找最...
    我偏笑_NSNirvana閱讀 15,604評論 7 49
  • 當(dāng)一個人知道,另一個人其實是在乎的技肩, 可是溝通方式不對且轨, 距離也遠(yuǎn), 見面太少虚婿,觸摸不到旋奢。 漸漸,漸漸地然痊, 也就拖...
    我家門口的有條溪閱讀 398評論 0 0
  • 回顧自己走過的烏漆嘛嘿的路至朗,總有人為我點燈照我前行,雖然還是在摸黑前行剧浸,但至少有那么些瞬間讓我感動不已锹引,為我點燈的...
    云中君style閱讀 773評論 0 1