筆者作為一名根正苗紅的理工男吸奴,內(nèi)心卻常常有很多文藝青年才會有的想法,例如寫首詩缠局、做首詞则奥,甚至包括春節(jié)寫副對聯(lián),空有一番愿望卻胸?zé)o點墨狭园。隨著對機器學(xué)習(xí)和深度學(xué)習(xí)的了解读处,逐漸萌生了使用機器幫助筆者完成文藝青年的轉(zhuǎn)型。:)
本文借助遞歸神經(jīng)網(wǎng)絡(luò)RDD的變種之一LSTM算法唱矛,對收集到的6900多條對聯(lián)進行學(xué)習(xí)罚舱,訓(xùn)練好模型后可以由機器寫出對聯(lián)。
遞歸神經(jīng)網(wǎng)絡(luò)與LSTM
故事從人工神經(jīng)網(wǎng)絡(luò)開始绎谦,人工神經(jīng)網(wǎng)絡(luò)誕生已久管闷。如下圖所示,神經(jīng)網(wǎng)絡(luò)的基本結(jié)構(gòu)由輸入層窃肠、輸出層和一個或多個隱含層組成包个。
全連接的神經(jīng)網(wǎng)絡(luò)下一層神經(jīng)元的輸入由上一層所有神經(jīng)元的輸出決定,因此帶來了一個嚴(yán)重的問題即參數(shù)數(shù)量過大導(dǎo)致無法訓(xùn)練冤留。因此碧囊,隨時神經(jīng)網(wǎng)絡(luò)的發(fā)展,衍生了一系列的變化纤怒。比較流行的有應(yīng)用于圖像識別領(lǐng)域的卷積神經(jīng)網(wǎng)絡(luò)CNN糯而、應(yīng)用于自然語言處理的遞歸神經(jīng)網(wǎng)絡(luò)RNN。本文應(yīng)用到的LSTM算法即為RNN的一種形態(tài)泊窘。RNN解決了這樣的問題:即樣本出現(xiàn)的時間順序?qū)τ谧匀徽Z言處理熄驼、語音識別、手寫體識別等應(yīng)用非常重要州既,神經(jīng)元的輸出可以在下一個時間戳直接作用到自身谜洽。因此RNN很適合處理時序?qū)Y(jié)果影響較深的領(lǐng)域萝映。
關(guān)于RNN和LSTM原理的說明可以移步 http://www.reibang.com/p/9dc9f41f0b29 吴叶,本文不多加贅言。
由LSTM作詩引發(fā)
由于LSTM算法非常適用自然語言處理領(lǐng)域序臂,因此網(wǎng)上出現(xiàn)了很多應(yīng)用LSTM做文字領(lǐng)域的嘗試蚌卤,例如: LSTM寫詩 中使用LSTM寫詩实束,LSTM創(chuàng)作歌詞中使用LSTM模仿歌手風(fēng)格寫歌詞,以及使用LSTM算法給小孩起名(是多么不靠譜的粑粑麻麻)逊彭。
因此咸灿,筆者突發(fā)想法,如果給一個足夠的春聯(lián)訓(xùn)練樣本侮叮,一樣可以照貓畫老虎避矢,訓(xùn)練一個可以寫對聯(lián)的文藝“機器模型”。因此囊榜,問題就分解為:找樣本审胸、寫算法、訓(xùn)練卸勺、應(yīng)用模型砂沛。
春聯(lián)樣本搜集和規(guī)整
借助于強大的度娘,費勁九牛之力曙求,從網(wǎng)上搜集了各式春聯(lián)共6900對碍庵,其中上聯(lián)下聯(lián)之間是用","分割區(qū)分上下聯(lián)悟狱,對聯(lián)之間是用"静浴。"區(qū)分一聯(lián)的結(jié)束。樣式如下:這些樣本將會在訓(xùn)練階段進行類型轉(zhuǎn)換并輸入給LSTM模型中挤渐。如果您也想試下本文案例马绝,請私信我這些樣本(畢竟搜集訓(xùn)練樣本是個苦差事(: )
LSTM算法
本文使用TensorFlow進行建模,TensorFlow就無需多言挣菲,是這個領(lǐng)域目前最活躍的框架富稻。寫對聯(lián)的算法主要工作包括:根據(jù)樣本數(shù)據(jù)產(chǎn)生LSTM輸入數(shù)據(jù)和結(jié)果值;定義LSTM的模型以及損失函數(shù)白胀;將訓(xùn)練數(shù)據(jù)喂給TensorFlow用來訓(xùn)練模型椭赋。接下來會逐步列舉本例中使用的方法。
- 訓(xùn)練數(shù)據(jù)轉(zhuǎn)換
由于樣本數(shù)據(jù)是一條條漢字組成的對聯(lián)或杠,這樣的數(shù)據(jù)是無法交給模型訓(xùn)練的哪怔,因此需要對樣本數(shù)據(jù)進行轉(zhuǎn)換∠蚯溃基本思想是:- 將樣本的所有對聯(lián)加載錄入认境,統(tǒng)計出所有出現(xiàn)的漢字,并將漢字進行編碼挟鸠,例如:一共有10000個漢字出現(xiàn)在樣本中叉信,那么對出現(xiàn)的漢字按 0 - 999 進行編碼,每個漢字對應(yīng)一個編碼艘希。
- 對原始樣本進行編碼轉(zhuǎn)換硼身,生成用數(shù)字編碼表示的對聯(lián)集硅急。
- 每條對聯(lián)作為一個輸入序列,每批次訓(xùn)練batch_size條佳遂,生成輸入數(shù)據(jù)xdata营袜,輸出y值為xdata+1。因為文本分析的特點是有時序性丑罪。
couplet_file ="couplet.txt"
#對聯(lián)
couplets = []
with open(couplet_file,'r') as f:
for line in f:
try:
content = line.replace(' ','')
if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content:
continue
if len(content) < 5*3 or len(content) > 79*3:
continue
content = '[' + content + ']'
# print chardet.detect(content)
content = content.decode('utf-8')
couplets.append(content)
except Exception as e:
pass
# 按字?jǐn)?shù)排序
couplets = sorted(couplets,key=lambda line: len(line))
print('對聯(lián)總數(shù): %d'%(len(couplets)))
# 統(tǒng)計每個字出現(xiàn)次數(shù)
all_words = []
for couplet in couplets:
all_words += [word for word in couplet]
counter = collections.Counter(all_words)
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
words, _ = zip(*count_pairs)
words = words[:len(words)] + (' ',)
# 每個字映射為一個數(shù)字ID
word_num_map = dict(zip(words, range(len(words))))
to_num = lambda word: word_num_map.get(word, len(words))
couplets_vector = [ list(map(to_num, couplet)) for couplet in couplets]
# 每次取64首對聯(lián)進行訓(xùn)練, 此參數(shù)可以調(diào)整
batch_size = 64
n_chunk = len(couplets_vector) // batch_size
x_batches = []
y_batches = []
for i in range(n_chunk):
start_index = i * batch_size#起始位置
end_index = start_index + batch_size#結(jié)束位置
batches = couplets_vector[start_index:end_index]
length = max(map(len,batches))#每個batches中句子的最大長度
xdata = np.full((batch_size,length), word_num_map[' '], np.int32)
for row in range(batch_size):
xdata[row,:len(batches[row])] = batches[row]
ydata = np.copy(xdata)
ydata[:,:-1] = xdata[:,1:]
x_batches.append(xdata)
y_batches.append(ydata)
-
定義LSTM模型
- 使用TF api tf.nn.rnn_cell.BasicLSTMCell定義cell為一個128維的ht的cell荚板。并使用MultiRNNCell 定義為兩層的LSTM。
- 對訓(xùn)練樣本輸入進行embedding化吩屹。
- 使用tf.nn.dynamic_rnn計算輸出值啸驯。(也可以通過循環(huán)step的方法,依次計算)
- 加入softmax層祟峦。
def neural_network(rnn_size=128, num_layers=2):
cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
initial_state = cell.zero_state(batch_size, tf.float32)
with tf.variable_scope('rnnlm'):
softmax_w = tf.get_variable("softmax_w", [rnn_size, len(words)+1])
softmax_b = tf.get_variable("softmax_b", [len(words)+1])
with tf.device("/cpu:0"):
embedding = tf.get_variable("embedding", [len(words)+1, rnn_size])
inputs = tf.nn.embedding_lookup(embedding, input_data)
outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state, scope='rnnlm')
output = tf.reshape(outputs,[-1, rnn_size])
logits = tf.matmul(output, softmax_w) + softmax_b
probs = tf.nn.softmax(logits)
return logits, last_state, probs, cell, initial_state
- 訓(xùn)練階段
- 使用TF sequence_loss_by_example計算所有examples(假設(shè)一句話有n個單詞罚斗,一個單詞及單詞所對應(yīng)的label就是一個example,所有example就是一句話中所有單詞)的加權(quán)交叉熵?fù)p失。
- tf.gradients 計算梯度宅楞,并使用clip_by_global_norm控制梯度爆炸的問題针姿。梯度爆炸和梯度彌散的原因一樣,都是因為鏈?zhǔn)椒▌t求導(dǎo)的關(guān)系厌衙,導(dǎo)致梯度的指數(shù)級衰減距淫。為了避免梯度爆炸,需要對梯度進行修剪婶希。(來自網(wǎng)上的解釋榕暇,不明覺厲(: )
- 定義步長,步長過大喻杈,會很可能越過最優(yōu)值彤枢,步長過小則使優(yōu)化的效率過低,長時間無法收斂筒饰。因此learning rate是一個需要適當(dāng)調(diào)整的參數(shù)缴啡。一個小技巧是,隨時訓(xùn)練的進行瓷们,即沿著梯度方向收斂的過程中业栅,適當(dāng)減小步長,不至于錯過最優(yōu)解谬晕。在代碼中 0.01 * (0.97 ** epoch)碘裕,learing rate基數(shù)值為0.01, 系數(shù)為0.97的epoch方,可以看出epoch越大攒钳,learing rate越小帮孔。
- 分批次將樣本數(shù)據(jù)x_batches和y_batches喂給TF進行訓(xùn)練。
def train_neural_network():
logits, last_state, _, _, _ = neural_network()
targets = tf.reshape(output_targets, [-1])
loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([logits], [targets], [tf.ones_like(targets, dtype=tf.float32)], len(words))
cost = tf.reduce_mean(loss)
learning_rate = tf.Variable(0.0, trainable=False)
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5)
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.apply_gradients(zip(grads, tvars))
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver(tf.all_variables())
for epoch in range(100):
sess.run(tf.assign(learning_rate, 0.01 * (0.97 ** epoch)))
n = 0
for batche in range(n_chunk):
train_loss, _ , _ = sess.run([cost, last_state, train_op], feed_dict={input_data: x_batches[n], output_targets: y_batches[n]})
n += 1
print(epoch, batche, train_loss)
if epoch % 7 == 0:
saver.save(sess, './couplet.module', global_step=epoch)
- 訓(xùn)練結(jié)束 , 詩性大發(fā)
經(jīng)過漫長的訓(xùn)練(取決于樣本數(shù)和迭代次數(shù))夕玩, loss控制在1.5左右你弦。
可以看到,經(jīng)過100次的迭代訓(xùn)練燎孟,每7次保存一次(saver.save(sess, './couplet.module', global_step=epoch)), 最后的模型保存在couplet.module-98里禽作。
在eval階段,使用saver.restore(sess, 'couplet.module-98') 將訓(xùn)練好的模型加載, 因為機器算出來的依舊是上文提到的數(shù)字編碼揩页,因此需要再將數(shù)字轉(zhuǎn)為漢字旷偿。
好啦,來看看機器創(chuàng)作的對聯(lián)吧爆侣, 是不是有點意思呢萍程?