Generating Novel Epigraphs with RNNs and LSTMs

1. Selecting the Corpus

  • Corpus format: one epigraph per line, with a label and the text separated by a colon
  • 題記:此情可待成追憶，只是當時已惘然。
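
The loader keeps only the part after the colon, so every line of story.txt must follow this pattern. Two more lines in the same format (hypothetical examples, not from the original corpus):

題記:人生若只如初見，何事秋風悲畫扇。
題記:十年生死兩茫茫，不思量，自難忘。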

2. Development Environment
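
Judging from the imports in the code below, the environment is Python with TensorFlow 1.x and NumPy; tf.contrib.rnn and tf.contrib.legacy_seq2seq exist only in the 1.x line and were removed in TensorFlow 2.x.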

3. Hands-On Code
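
The script below has three parts: a Data class that parses story.txt into padded batches of character ids, a Model class that builds a two-layer recurrent network over a character embedding, and train/sample functions that fit the model and generate new epigraphs from a checkpoint.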

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import os
import numpy as np
import collections
import tensorflow as tf
# the tf.contrib modules below exist only in TensorFlow 1.x (removed in 2.x)
import tensorflow.contrib.rnn as rnn
import tensorflow.contrib.legacy_seq2seq as seq2seq

BEGIN_CHAR = '^'     # marks the start of every sample
END_CHAR = '$'       # marks the end of every sample
UNKNOWN_CHAR = '*'   # out-of-vocabulary stand-in, also used as padding
MAX_LENGTH = 100     # longer samples are truncated at a sentence boundary
MIN_LENGTH = 10      # shorter samples are discarded
max_words = 3000     # vocabulary size cap
epochs = 50
# corpus file, one epigraph per line
poetry_file = 'story.txt'
# directory where model checkpoints are written
save_dir = 'model'


class Data:
    def __init__(self):
        self.batch_size = 64
        self.poetry_file = poetry_file
        self.load()
        self.create_batches()

    def load(self):
        def handle(line):
            # truncate over-long lines at the last full stop before MAX_LENGTH
            if len(line) > MAX_LENGTH:
                index_end = line.rfind('。', 0, MAX_LENGTH)
                index_end = index_end if index_end > 0 else MAX_LENGTH
                line = line[:index_end + 1]
            return BEGIN_CHAR + line + END_CHAR

        # every corpus line is "label:text"; keep the text and skip malformed lines
        with open(self.poetry_file, encoding='utf-8') as f:
            self.poetrys = [line.strip().replace(' ', '').split(':')[1]
                            for line in f if ':' in line]
        self.poetrys = [handle(line) for line in self.poetrys if len(line) > MIN_LENGTH]
        # count every character in the corpus, most frequent first
        words = []
        for poetry in self.poetrys:
            words += [word for word in poetry]
        counter = collections.Counter(words)
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        words, _ = zip(*count_pairs)

        # keep the max_words most frequent characters as the vocabulary;
        # any character outside it is replaced by '*'
        words_size = min(max_words, len(words))
        self.words = words[:words_size] + (UNKNOWN_CHAR,)
        self.words_size = len(self.words)

        # bidirectional maps between characters and integer ids
        self.char2id_dict = {w: i for i, w in enumerate(self.words)}
        self.id2char_dict = {i: w for i, w in enumerate(self.words)}
        self.unknown_char = self.char2id_dict.get(UNKNOWN_CHAR)
        self.char2id = lambda char: self.char2id_dict.get(char, self.unknown_char)
        self.id2char = lambda num: self.id2char_dict.get(num)
        # sort by length so each batch needs as little padding as possible
        self.poetrys = sorted(self.poetrys, key=lambda line: len(line))
        self.poetrys_vector = [list(map(self.char2id, poetry)) for poetry in self.poetrys]

    def create_batches(self):
        # trim the corpus to a whole number of batches
        self.n_size = len(self.poetrys_vector) // self.batch_size
        self.poetrys_vector = self.poetrys_vector[:self.n_size * self.batch_size]
        self.x_batches = []
        self.y_batches = []
        for i in range(self.n_size):
            batches = self.poetrys_vector[i * self.batch_size: (i + 1) * self.batch_size]
            length = max(map(len, batches))
            # pad every row to the longest sample in the batch
            for row in range(self.batch_size):
                if len(batches[row]) < length:
                    batches[row] += [self.unknown_char] * (length - len(batches[row]))
            xdata = np.array(batches)
            ydata = np.copy(xdata)
            # targets are the inputs shifted one step to the left
            ydata[:, :-1] = xdata[:, 1:]
            self.x_batches.append(xdata)
            self.y_batches.append(ydata)
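
# A minimal sketch of what create_batches produces, using hypothetical ids:
# if "^此情$" maps to [0, 5, 9, 1], then for that row
#     xdata = [0, 5, 9, 1]
#     ydata = [5, 9, 1, 1]   (xdata shifted left; the last entry keeps its value)
# so at every position the model is trained to predict the next character.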


class Model:
    def __init__(self, data, model='lstm', infer=False):
        self.rnn_size = 128
        self.n_layers = 2

        # generation feeds one character at a time
        if infer:
            self.batch_size = 1
        else:
            self.batch_size = data.batch_size

        # choose the recurrent cell; only LSTM cells take state_is_tuple
        if model == 'lstm':
            cells = [rnn.BasicLSTMCell(self.rnn_size, state_is_tuple=False)
                     for _ in range(self.n_layers)]
        elif model == 'gru':
            cells = [rnn.GRUCell(self.rnn_size) for _ in range(self.n_layers)]
        elif model == 'rnn':
            cells = [rnn.BasicRNNCell(self.rnn_size) for _ in range(self.n_layers)]
        else:
            raise ValueError('unknown model type: {}'.format(model))

        # one cell object per layer; reusing a single instance across layers
        # fails on newer TensorFlow 1.x releases
        self.cell = rnn.MultiRNNCell(cells, state_is_tuple=False)

        # inputs and targets: [batch_size, time_steps] arrays of character ids
        self.x_tf = tf.placeholder(tf.int32, [self.batch_size, None])
        self.y_tf = tf.placeholder(tf.int32, [self.batch_size, None])

        self.initial_state = self.cell.zero_state(self.batch_size, tf.float32)

        with tf.variable_scope('rnnlm'):
            softmax_w = tf.get_variable("softmax_w", [self.rnn_size, data.words_size])
            softmax_b = tf.get_variable("softmax_b", [data.words_size])
            with tf.device("/cpu:0"):
                embedding = tf.get_variable(
                    "embedding", [data.words_size, self.rnn_size])
                inputs = tf.nn.embedding_lookup(embedding, self.x_tf)

        outputs, final_state = tf.nn.dynamic_rnn(
            self.cell, inputs, initial_state=self.initial_state, scope='rnnlm')

        # flatten [batch, time, rnn_size] -> [batch*time, rnn_size] for the softmax layer
        self.output = tf.reshape(outputs, [-1, self.rnn_size])
        self.logits = tf.matmul(self.output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)
        self.final_state = final_state
        targets = tf.reshape(self.y_tf, [-1])
        # per-timestep cross-entropy against the shifted targets, uniformly weighted
        loss = seq2seq.sequence_loss_by_example([self.logits],
                                                [targets],
                                                [tf.ones_like(targets, dtype=tf.float32)])

        self.cost = tf.reduce_mean(loss)
        self.learning_rate = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        # clip gradients to a global norm of 5 to keep training stable
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), 5)

        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))


def train(data, model):
    # make sure the checkpoint directory exists before saving into it
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        n = 0
        for epoch in range(epochs):
            # decay the learning rate by 3% per epoch
            sess.run(tf.assign(model.learning_rate, 0.002 * (0.97 ** epoch)))
            for batch in range(data.n_size):
                n += 1
                feed_dict = {model.x_tf: data.x_batches[batch], model.y_tf: data.y_batches[batch]}
                train_loss, _, _ = sess.run([model.cost, model.final_state, model.train_op], feed_dict=feed_dict)
                sys.stdout.write('\r')
                info = "{}/{} (epoch {}) | train_loss {:.3f}" \
                    .format(epoch * data.n_size + batch,
                            epochs * data.n_size, epoch, train_loss)
                sys.stdout.write(info)
                sys.stdout.flush()
                # checkpoint every 1000 global steps and after the very last batch
                if (epoch * data.n_size + batch) % 1000 == 0 \
                        or (epoch == epochs - 1 and batch == data.n_size - 1):
                    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=n)
                    sys.stdout.write('\n')
                    print("model saved to {}".format(checkpoint_path))
            sys.stdout.write('\n')


def sample(data, model, head=u''):
    def to_word(weights):
        # draw one character id from the output distribution (inverse-CDF sampling)
        t = np.cumsum(weights)
        s = np.sum(weights)
        sa = int(np.searchsorted(t, np.random.rand(1) * s))
        return data.id2char(sa)

    for word in head:
        if word not in data.words:
            return u'{} is not in the vocabulary'.format(word)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables())
        model_file = tf.train.latest_checkpoint(save_dir)
        saver.restore(sess, model_file)

        if head:
            print('Generating epigraph ---> ', head)
            poem = BEGIN_CHAR
            for head_word in head:
                # each clause starts with the next character of the prompt
                poem += head_word
                x = np.array([list(map(data.char2id, poem))])
                state = sess.run(model.cell.zero_state(1, tf.float32))
                feed_dict = {model.x_tf: x, model.initial_state: state}
                [probs, state] = sess.run([model.probs, model.final_state], feed_dict)
                word = to_word(probs[-1])
                # extend the clause until the model emits a comma or a full stop
                while word != u'，' and word != u'。':
                    poem += word
                    x = np.zeros((1, 1), dtype=np.int32)
                    x[0, 0] = data.char2id(word)
                    [probs, state] = sess.run([model.probs, model.final_state],
                                              {model.x_tf: x, model.initial_state: state})
                    word = to_word(probs[-1])
                poem += word
            return poem[1:]
        else:
            # no prompt: free-run from BEGIN_CHAR until END_CHAR is emitted
            poem = ''
            head = BEGIN_CHAR
            x = np.array([list(map(data.char2id, head))])
            state = sess.run(model.cell.zero_state(1, tf.float32))
            feed_dict = {model.x_tf: x, model.initial_state: state}
            [probs, state] = sess.run([model.probs, model.final_state], feed_dict)
            word = to_word(probs[-1])
            while word != END_CHAR:
                poem += word
                x = np.zeros((1, 1), dtype=np.int32)
                x[0, 0] = data.char2id(word)
                [probs, state] = sess.run([model.probs, model.final_state],
                                          {model.x_tf: x, model.initial_state: state})
                word = to_word(probs[-1])
            return poem


if __name__ == '__main__':

    # train the model
    data = Data()
    model = Model(data=data, infer=False)
    train(data, model)

    # generate an epigraph (after training, comment out the three lines
    # above and uncomment these)
    # data = Data()
    # model = Model(data=data, infer=True)
    # print(sample(data, model, head='我為秋香'))
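
Model(infer=True) rebuilds the same graph with batch size 1 so that characters can be fed one at a time, and sample() restores the latest checkpoint from the model directory before generating.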

Output:

Generating epigraph --->  我為秋香
我罷性不行，為德勸仙興。秋風暝冰始，香巢深器酒。
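
Each clause opens with one character of the prompt 我為秋香: for every head character, sample() keeps generating until the model emits a comma or a full stop, so the four clauses read as an acrostic.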
Follow my WeChat official account 《漫談人工智能》 for daily technical articles.