利用RNN(lstm)生成文本

致謝以及參考

最近在做序列化標(biāo)注項(xiàng)目凰棉,試著理解rnn的設(shè)計(jì)結(jié)構(gòu)以及tensorflow中的具體實(shí)現(xiàn)方法届惋。在知乎中找到這篇文章,具有很大的幫助作用贩据,感謝作者為分享知識(shí)做出的努力荣病。

學(xué)習(xí)目標(biāo)定位

我主要重點(diǎn)在于理解文中連接所提供的在github上的project代碼码撰,一句句理解數(shù)據(jù)的預(yù)處理過程以及rnn網(wǎng)絡(luò)搭建過程(重點(diǎn)在于代碼注釋,代碼改動(dòng)很小个盆,實(shí)用python3)脖岛。(進(jìn)入下面環(huán)節(jié)之前,假設(shè)你已經(jīng)閱讀了知乎上的關(guān)于rnn知識(shí)講解篇幅颊亮,project的readme文檔)

數(shù)據(jù)預(yù)處理

  1. 理解模型大概需要的重要參數(shù):/Char-RNN-TensorFlow-master/train.py
# encoding: utf-8

import tensorflow as tf
from model import CharRNN
import os
import codecs  # 相比自帶的open函數(shù) 讀取寫入進(jìn)行自我轉(zhuǎn)碼
from read_utils import TextConverter, batch_generator

FLAGS = tf.flags.FLAGS
#  變量定義 以及 默認(rèn)值
tf.flags.DEFINE_string('name', 'default', 'name of the model')
tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')    # 一個(gè) batch 可以組成num_seqs個(gè)輸入信號(hào)序列
tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')             # 一個(gè)輸入信號(hào)序列的長度柴梆, rnn網(wǎng)絡(luò)會(huì)更具輸入進(jìn)行自動(dòng)調(diào)整
tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')  # 隱藏層節(jié)點(diǎn)數(shù)量,即lstm 的 cell中state數(shù)量
tf.flags.DEFINE_integer('num_layers', 2, 'number of lstm layers')          # rnn的深度
tf.flags.DEFINE_boolean('use_embedding', False, 'whether to use embedding')  # 如果中文字符則需要一個(gè)word2vec终惑, 字母字符直接采用onehot編碼
tf.flags.DEFINE_integer('embedding_size', 128, 'size of embedding')          # 使用word2vec的 中文字符的嵌入維度選取
tf.flags.DEFINE_float('learning_rate', 0.001, 'learning_rate')
tf.flags.DEFINE_float('train_keep_prob', 0.5, 'dropout rate during training')
tf.flags.DEFINE_string('input_file', '', 'utf8 encoded text file')           # --input_file data/shakespeare.txt
tf.flags.DEFINE_integer('max_steps', 100000, 'max steps to train')  
tf.flags.DEFINE_integer('save_every_n', 1000, 'save the model every n steps')
tf.flags.DEFINE_integer('log_every_n', 10, 'log to the screen every n steps')
# 不同于英文字符比較短幾十個(gè)就能解決绍在,中文字符比較多,word2vec層之前輸入需要進(jìn)行onehot編碼,根據(jù)字符頻數(shù)降序排列取前面的3500個(gè)編碼
tf.flags.DEFINE_integer('max_vocab', 3500, 'max char number')                
  1. 理解main函數(shù)中數(shù)據(jù)預(yù)處理部分, 數(shù)據(jù)預(yù)處理主要采用TextConverter類
def main(_):
    model_path = os.path.join('model', FLAGS.name)
    print("模型保存位置(根據(jù)模型命名)", model_path)
    if os.path.exists(model_path) is False:
        os.makedirs(model_path)
    with codecs.open(FLAGS.input_file, encoding='utf-8') as f:
        print("建模訓(xùn)練數(shù)據(jù)來源:", FLAGS.input_file)
        text = f.read()
    converter = TextConverter(text,  # string     # 返回一個(gè)整理文本詞典的類
                              FLAGS.max_vocab)
    print("構(gòu)建該文本的字符集合數(shù)量(包含未登錄詞:):", converter.vocab_size)
    print("建模所用字符保存地址位置(list): ", os.path.join(model_path, 'converter.pkl'))  # 用來建模詞匯的 前max_vocab個(gè)
    converter.save_to_file(os.path.join(model_path, 'converter.pkl'))

    arr = converter.text_to_arr(text)
    # batch生成函數(shù):返回一個(gè)生成器
  • TextConverter類:\Char-RNN-TensorFlow-master\read_utils.py
    比如 莎士比亞訓(xùn)練數(shù)據(jù)用vocab組成:{v} {'} {[} {t} {u} {R} {W} {x} {?} { } {F} {I} {G} {O} {E} {$} {y} {e} {:} {L} {s} {c} {g} {Y} {]} {h} {w} {-} {a} {S} {J} {q} {V} {3} {X} {p} {T} {!} {C} {n} {;} {r} {M} {j} {f} {U} zphh1zv {Q} {K} 揣苏 {m} {H} {Z} {o} {i} {P} {D} {.} {l} {&} {N} {z} {A} {,} {
    } {B} {k}
class TextConverter(object):
    def __init__(self, text=None, max_vocab=5000, filename=None):
        """

        :param text: string
        :param max_vocab:
        :param filename:
        """
        if filename is not None:
            # 如果存在 字典文件悯嗓,即將字符集合進(jìn)行編號(hào)的字典
            with open(filename, 'rb') as f:
                self.vocab = pickle.load(f)
        else:
            vocab = set(text)  # 組成text的所有字符,比如卸察, i see you脯厨, 那么就是 i s e y o u
            logging.info('組成文本的字符集合:')
            logging.info("數(shù)量: %d" % len(vocab))
            s = ' '.join(["{%s}" % v for v in vocab])
            logging.info("vocab: %s" % s)

            # max_vocab_process
            vocab_count = defaultdict(int)     # 這里相對(duì)原始代碼做了小小優(yōu)化
            # 統(tǒng)計(jì)所有字符的頻數(shù)
            for word in text:
                vocab_count[word] += 1
            vocab_count_list = list(vocab_count.items())
            vocab_count_list.sort(key=lambda x: x[1], reverse=True)  # 根據(jù)頻數(shù)降序排序
            if len(vocab_count_list) > max_vocab:
                vocab_count_list = vocab_count_list[:max_vocab]  # 截取最大允許部分
            vocab = [x[0] for x in vocab_count_list]  # 截取 前max_vocab
            self.vocab = vocab
        # 對(duì)vocab進(jìn)行編序
        self.word_to_int_table = {c: i for i, c in enumerate(self.vocab)}  # 詞匯進(jìn)行編序號(hào)
        self.int_to_word_table = dict(enumerate(self.vocab))

    @property   # 這個(gè)實(shí)現(xiàn)直接,將vocab_size作為一個(gè)變量成員調(diào)用而不是方法
    def vocab_size(self):
        return len(self.vocab) + 1  # 加上一個(gè)未登錄詞

    def word_to_int(self, word):
        # 更具給定的字符返回index
        if word in self.word_to_int_table:
            return self.word_to_int_table[word]
        else:
            # 未登錄詞 就是最后一個(gè)序號(hào)
            return len(self.vocab)

    def int_to_word(self, index):
        # 根據(jù)給定indx返回字符
        if index == len(self.vocab):
            return '<unk>'  # 未登錄詞
        elif index < len(self.vocab):
            return self.int_to_word_table[index]
        else:
            raise Exception('Unknown index!')

    def text_to_arr(self, text):
        # 將文本序列化:字符轉(zhuǎn)化為index
        arr = []
        for word in text:
            arr.append(self.word_to_int(word))
        return np.array(arr)

    def arr_to_text(self, arr):
        # 反序列化
        words = []
        for index in arr:
            words.append(self.int_to_word(index))
        return "".join(words)

    def save_to_file(self, filename):
        # 存儲(chǔ)詞典
        with open(filename, 'wb') as f:
            pickle.dump(self.vocab, f)
  1. 準(zhǔn)備batch用于訓(xùn)練
    # batch生成函數(shù):返回一個(gè)生成器
    print("訓(xùn)練文本長度:", len(arr))
    print("num_seqs:", FLAGS.num_seqs)
    print("num_steps", FLAGS.num_steps)
    g = batch_generator(arr,   # 輸入信號(hào)文本序列
                        FLAGS.num_seqs,   # batch 信號(hào)序列數(shù)量
                        FLAGS.num_steps)  # 一個(gè)信號(hào)序列的長度
  • 重點(diǎn)在于理解batch_generator函數(shù)坑质, 這個(gè)過程的理解需要理解生成文本的rnn的輸出和輸入是什么樣的(N vs N合武, 輸出和輸入數(shù)目是一致的)

    • 一個(gè)單層的展開如下: 展開后h的節(jié)點(diǎn)個(gè)數(shù)取決于你的輸入序列向量的長度,即輸入文本的長度涡扼,圖片來源于簡書稼跳,這個(gè)鏈接可以幫助你很好從數(shù)學(xué)公式上理解。
      image.png
    • 一個(gè)文本序列輸入展示(這里為了直觀的展示沒有將文本數(shù)字化吃沪, 例如真正的"床"的輸入應(yīng)該為一個(gè)embeding的向量汤善, 而輸出“前”也是一個(gè)與輸入一致的長度向量)


      image.png
  • /read_utils.py 的batch_generator函數(shù)

def batch_generator(arr, n_seqs, n_steps):
    """
    生成訓(xùn)練用的batch
    :param arr:
    :param n_seqs:
    :param n_steps:
    :return:
    """
    arr = copy.copy(arr)                    # 序列
    batch_size = n_seqs * n_steps           # 一個(gè)batch需要的字符數(shù)量
    n_batches = int(len(arr) / batch_size)  # 整個(gè)文本可以生成的batch總數(shù)
    arr = arr[:batch_size * n_batches]      # 截取下 以便reshape成array
    arr = arr.reshape((n_seqs, -1))         # 將batch, reshape成n_seqs行票彪, 每行為一輸入信號(hào)序列(序列長度為n_steps)
    while True:
        np.random.shuffle(arr)  # 打亂文本序列順序
        print(arr)
        for n in range(0, arr.shape[1], n_steps):
            x = arr[:, n:(n + n_steps)]
            y = np.zeros_like(x)
            y[:, :-1], y[:, -1] = x[:, 1:], x[:, 0]  # 
            yield x, y
    • 測(cè)試下原來的代碼的結(jié)果
    arr = np.arange(27)
    for x, y in batch_generator(arr, 4, 3):
        print(x)
        print(y)
        break
    • out-put: 以 6 7 8為例红淡, 給出一個(gè)6, 生成文本的長度為3降铸。將6對(duì)應(yīng)的輸出7作為下一個(gè)state的輸入在旱,輸出8, 然后依次這么進(jìn)行下去推掸,y應(yīng)該為7桶蝎,8, 9谅畅。說明一下的是最后一個(gè)輸出為啥為6 登渣,前面一個(gè)鏈接存在解釋。
0-26序列進(jìn)行生成序列操作毡泻,每批訓(xùn)練batch序列總數(shù)為4绍豁, 每個(gè)寫的長度為3
打亂排序的結(jié)果
[[ 6  7  8  9 10 11]
 [ 0  1  2  3  4  5]
 [18 19 20 21 22 23]
 [12 13 14 15 16 17]]
x
[[ 6  7  8]
 [ 0  1  2]
 [18 19 20]
 [12 13 14]]
y
[[ 7  8  6]
 [ 1  2  0]
 [19 20 18]
 [13 14 12]]

rnn 模型搭建

為了更好的理解這個(gè)過程下面是實(shí)際整個(gè)rnn的結(jié)構(gòu)展開展示,如有錯(cuò)誤請(qǐng)指出:
代碼中構(gòu)建2層的rnn牙捉,每個(gè)state(方塊)的有兩個(gè)一樣的輸出h竹揍,得到輸出前有個(gè)softmax處理。


image.png
  • train.py中main函數(shù)調(diào)用rnn部分代碼
    model = CharRNN(converter.vocab_size,           # 分類的數(shù)量
                    num_seqs=FLAGS.num_seqs,        # 一個(gè)batch可以組成num_seq個(gè)信號(hào)
                    num_steps=FLAGS.num_steps,      # 一次信號(hào)輸入RNN的字符長度邪铲,與一層的cell 的數(shù)量掛鉤
                    lstm_size=FLAGS.lstm_size,      # 每個(gè)cell的節(jié)點(diǎn)數(shù)量:
                    num_layers=FLAGS.num_layers,    # RNN 的層數(shù)
                    learning_rate=FLAGS.learning_rate,  # 學(xué)習(xí)速率
                    train_keep_prob=FLAGS.train_keep_prob,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)
    model.train(g,
                FLAGS.max_steps,
                model_path,
                FLAGS.save_every_n,
                FLAGS.log_every_n,)
  • 重點(diǎn)在于model.py中的CharRNN類的調(diào)用

    • 搭建rnn隱藏層
      整個(gè)過程的理解在備注帶代碼里面芬位,暫時(shí)不用關(guān)注類里面,sample函數(shù)
# coding: utf-8


from __future__ import print_function
import tensorflow as tf
import numpy as np
import time
import os


def pick_top_n(preds, vocab_size, top_n=5):
    p = np.squeeze(preds)
    # 將除了top_n個(gè)預(yù)測(cè)值的位置都置為0
    p[np.argsort(p)[:-top_n]] = 0
    # 歸一化概率
    p = p / np.sum(p)
    # 隨機(jī)選取一個(gè)字符
    c = np.random.choice(vocab_size, 1, p=p)[0]
    return c


class CharRNN:
    def __init__(self, num_classes, num_seqs=64, num_steps=50,
                 lstm_size=128, num_layers=2, learning_rate=0.001,
                 grad_clip=5, sampling=False, train_keep_prob=0.5,
                 use_embedding=False, embedding_size=128):
        if sampling is True:
            # 用于 預(yù)測(cè)
            num_seqs, num_steps = 1, 1   # 僅僅根據(jù)前面一個(gè)字符預(yù)測(cè)后面一個(gè)字符
        else:
            num_seqs, num_steps = num_seqs, num_steps

        self.num_classes = num_classes   # 分類結(jié)果數(shù)量带到,與字典容量一致包含未登錄字
        self.num_seqs = num_seqs
        self.num_steps = num_steps
        self.lstm_size = lstm_size
        self.num_layers = num_layers
        self.learning_rate = learning_rate
        self.grad_clip = grad_clip
        self.train_keep_prob = train_keep_prob
        self.use_embedding = use_embedding
        self.embedding_size = embedding_size

        tf.reset_default_graph()
        self.build_inputs()
        self.build_lstm()
        self.build_loss()
        self.build_optimizer()
        self.saver = tf.train.Saver()

    def build_inputs(self):
        # 定義下輸入昧碉,輸出等,占位
        with tf.name_scope('inputs'):
            # 輸入是一個(gè)3維度矩陣,但是這里并不要過多關(guān)注每個(gè)節(jié)點(diǎn)輸入特征的維度被饿,中文字符額embeding或者因?yàn)樽址膐nehot編碼四康。
            # 模型會(huì)自動(dòng)識(shí)別和調(diào)整,暫時(shí)考慮每一個(gè)batch被reshape成 num_seqs * num_steps, 每一行為一個(gè)序列輸入信號(hào)
            self.inputs = tf.placeholder(tf.int32, shape=(
                self.num_seqs, self.num_steps), name='inputs')

            # N vs N: 輸出與輸入一致
            self.targets = tf.placeholder(tf.int32, shape=(
                self.num_seqs, self.num_steps), name='targets')  # N vs N

            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

            # 對(duì)于中文狭握,需要使用embedding層: ???
            # 英文字母沒有必要用embedding層: ???
            if self.use_embedding is False:
                # 對(duì)字字符進(jìn)行onehot編號(hào)
                self.lstm_inputs = tf.one_hot(self.inputs, self.num_classes)
            else:
                with tf.device("/cpu:0"):
                    #  嵌入維度層word2vec和RNN連接器闪金;起來同時(shí)訓(xùn)練 作為模型的第一層
                    # 先進(jìn)行onehot編碼然后, word2vec 所以額輸入信號(hào)維度為num_classes
                    embedding = tf.get_variable('embedding', [self.num_classes, self.embedding_size])
                    self.lstm_inputs = tf.nn.embedding_lookup(embedding, self.inputs)

    def build_lstm(self):
        # 創(chuàng)建單個(gè)cell并堆疊多層
        def get_a_cell(lstm_size, keep_prob):
            """
            返回一個(gè)cell
            :param lstm_size: cell的states數(shù)量
            :param keep_prob: 節(jié)點(diǎn)保留率
            :return:
            """
            lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)    # state并不是采用普通rnn 而是lstm
            drop = tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)   # 對(duì)每個(gè)state的節(jié)點(diǎn)數(shù)量進(jìn)行dropout
            return drop

        with tf.name_scope('lstm'):
            # 構(gòu)建多層
            cell = tf.nn.rnn_cell.MultiRNNCell(
                [get_a_cell(self.lstm_size, self.keep_prob) for _ in range(self.num_layers)]
            )
            # 定義h_0
            self.initial_state = cell.zero_state(self.num_seqs, tf.float32)

            # 通過dynamic_rnn對(duì)cell展開時(shí)間維度
            self.lstm_outputs, self.final_state = tf.nn.dynamic_rnn(cell, self.lstm_inputs,
                                                                    initial_state=self.initial_state)

            # 通過lstm_outputs得到概率
            # 每個(gè)batch的輸出為lstm_outputs: num_seqs * num_steps * state_node_size(中文字符嵌入維度或英文的onehot編碼維度)
            # 將輸出進(jìn)行拼接 dim=1  # seq out應(yīng)該為 num_steps * (num_seqs * state_node_size), 即沒每個(gè)輸入信號(hào)對(duì)應(yīng)state輸出進(jìn)行拼接论颅。
            # 但是在train里面查看發(fā)現(xiàn)哎垦, dim沒有任何改變
            seq_output = tf.concat(self.lstm_outputs, 1)

            self.seq_output = seq_output  # just for  output in train method

            # 將每個(gè)batch的每個(gè)state拼接成 一個(gè)二維的batch_size * state_node_size(lstm_size) 列矩陣
            x = tf.reshape(seq_output, [-1, self.lstm_size])

            # 構(gòu)建一個(gè)輸出層:softmax
            with tf.variable_scope('softmax'):
                # 初始化 輸出的權(quán)重, 共享
                softmax_w = tf.Variable(tf.truncated_normal([self.lstm_size, self.num_classes], stddev=0.1))
                softmax_b = tf.Variable(tf.zeros(self.num_classes))

            # 定義輸出:softmax 歸一化
            self.logits = tf.matmul(x, softmax_w) + softmax_b
            self.proba_prediction = tf.nn.softmax(self.logits, name='predictions')

    def build_loss(self):
        with tf.name_scope('loss'):
            # 統(tǒng)一第輸出進(jìn)行non hot編碼
            y_one_hot = tf.one_hot(self.targets, self.num_classes)
            y_reshaped = tf.reshape(y_one_hot, self.logits.get_shape())
            # 計(jì)算交叉信息熵
            loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=y_reshaped)
            # 計(jì)算平均損失
            self.loss = tf.reduce_mean(loss)

    def build_optimizer(self):
        # 使用clipping gradients:避免梯度計(jì)算迭代過程變化過大導(dǎo)致梯度爆炸現(xiàn)象
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars), self.grad_clip)  # 恃疯?漏设??
        train_op = tf.train.AdamOptimizer(self.learning_rate)
        self.optimizer = train_op.apply_gradients(zip(grads, tvars))

    def train(self, batch_generator, max_steps, save_path, save_every_n, log_every_n):
        self.session = tf.Session()
        with self.session as sess:
            sess.run(tf.global_variables_initializer())
            # Train network
            step = 0
            new_state = sess.run(self.initial_state)
            for x, y in batch_generator:
                step += 1
                start = time.time()
                feed = {self.inputs: x,
                        self.targets: y,
                        self.keep_prob: self.train_keep_prob,
                        self.initial_state: new_state}   # 下一輪batch的初始h狀態(tài)采用上一輪的final_state

                batch_loss, new_state, _ , lstm_outputs, seq_output, prp = sess.run([self.loss,
                                                     self.final_state,
                                                     self.optimizer,
                                                     self.lstm_outputs,
                                                     self.seq_output,
                                                     self.proba_prediction
                                                     ],
                                                    feed_dict=feed)
                print('lstm outpts: ', lstm_outputs.shape, self.num_seqs)
                print('lstm outpts: ', seq_output.shape)   # ??? 為啥一直沒有改變
                print(prp.shape)
                end = time.time()
                # control the print lines
                if step % log_every_n == 0:
                    print('step: {}/{}... '.format(step, max_steps),
                          'loss: {:.4f}... '.format(batch_loss),
                          '{:.4f} sec/batch'.format((end - start)))
                if (step % save_every_n == 0):
                    self.saver.save(sess, os.path.join(save_path, 'model'), global_step=step)
                if step >= max_steps:
                    break
            self.saver.save(sess, os.path.join(save_path, 'model'), global_step=step)

    def sample(self, n_samples, prime, vocab_size):
        samples = [c for c in prime]
        sess = self.session
        new_state = sess.run(self.initial_state)
        preds = np.ones((vocab_size,))  # for prime=[]
        for c in prime:
            x = np.zeros((1, 1))
            # 輸入單個(gè)字符
            x[0, 0] = c
            feed = {self.inputs: x,
                    self.keep_prob: 1.,
                    self.initial_state: new_state}
            preds, new_state = sess.run([self.proba_prediction, self.final_state],
                                        feed_dict=feed)

        c = pick_top_n(preds, vocab_size)
        # 添加字符到samples中
        samples.append(c)

        # 不斷生成字符今妄,直到達(dá)到指定數(shù)目
        for i in range(n_samples):
            x = np.zeros((1, 1))
            x[0, 0] = c
            feed = {self.inputs: x,
                    self.keep_prob: 1.,
                    self.initial_state: new_state}
            preds, new_state = sess.run([self.proba_prediction, self.final_state],
                                        feed_dict=feed)

            c = pick_top_n(preds, vocab_size)
            samples.append(c)

        return np.array(samples)

    def load(self, checkpoint):
        """
        :param checkpoint: 命名
        :return:
        """
        # 存儲(chǔ) 訓(xùn)練好的神經(jīng)網(wǎng)絡(luò)模型
        self.session = tf.Session()
        self.saver.restore(self.session, checkpoint)
        print('Restored from: {}'.format(checkpoint))

利用模型生成文本

這個(gè)過程依靠調(diào)用sample.py腳本

import tensorflow as tf
from read_utils import TextConverter
from model import CharRNN
import os
from IPython import embed

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')    # 這里為什么還需要郑口??盾鳞?
tf.flags.DEFINE_integer('num_layers', 2, 'number of lstm layers')            # 犬性??雁仲?
tf.flags.DEFINE_boolean('use_embedding', False, 'whether to use embedding')
tf.flags.DEFINE_integer('embedding_size', 128, 'size of embedding')
tf.flags.DEFINE_string('converter_path', '', 'model/name/converter.pkl')
tf.flags.DEFINE_string('checkpoint_path', '', 'checkpoint path')                   # 模型保存路徑
tf.flags.DEFINE_string('start_string', '', 'use this string to start generating')  # 給出一個(gè)字符開始生成
tf.flags.DEFINE_integer('max_length', 30, 'max length to generate')                # 最大字符


def main(_):
    FLAGS.start_string = FLAGS.start_string.decode('utf-8')
    converter = TextConverter(filename=FLAGS.converter_path)  # 調(diào)用前面訓(xùn)練建立的voctab即可
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path =tf.train.latest_checkpoint(FLAGS.checkpoint_path)

    model = CharRNN(converter.vocab_size, sampling=True,      # 調(diào)模型的保存不能保存節(jié)點(diǎn)等相關(guān)參數(shù)
                    lstm_size=FLAGS.lstm_size, num_layers=FLAGS.num_layers,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)

    model.load(FLAGS.checkpoint_path)   # 載入訓(xùn)練好的模型

    start = converter.text_to_arr(FLAGS.start_string)  # 字符轉(zhuǎn)化為idnex
    arr = model.sample(FLAGS.max_length, start, converter.vocab_size)
    print(converter.arr_to_text(arr))  # 反序列化


if __name__ == '__main__':
    tf.app.run()
  • 主要調(diào)用rnn模型類的sample方法,很簡單琐脏,注釋即可看懂
    def sample(self, n_samples, prime, vocab_size):
        """
        用一個(gè)字符生成一段文本
        :param n_samples:
        :param prime:
        :param vocab_size:
        :return:
        """
        print("初始字符為:", prime)
        samples = [c for c in prime]
        sess = self.session
        new_state = sess.run(self.initial_state)
        preds = np.ones((vocab_size,))  # for prime=[] 
        # 對(duì)給定的初始字符串來攒砖,一次feed
        for c in prime:
            x = np.zeros((1, 1))
            # 輸入單個(gè)字符
            x[0, 0] = c
            feed = {self.inputs: x,
                    self.keep_prob: 1.,
                    self.initial_state: new_state}
            preds, new_state = sess.run([self.proba_prediction, self.final_state],
                                        feed_dict=feed)
            print(preds)  # 最后一個(gè)字符的輸出

        c = pick_top_n(preds, vocab_size)   # 
        # 添加字符到samples中
        samples.append(c)  # 根據(jù)概率所及選取

        # 不斷生成字符,直到達(dá)到指定數(shù)目
        for i in range(n_samples):
            x = np.zeros((1, 1))
            x[0, 0] = c
            feed = {self.inputs: x,
                    self.keep_prob: 1.,
                    self.initial_state: new_state}
            preds, new_state = sess.run([self.proba_prediction, self.final_state],
                                        feed_dict=feed)
            # 上一次的輸入作為下一次的輸出日裙, 直到達(dá)到指定長度
            c = pick_top_n(preds, vocab_size)   
            samples.append(c)

        return np.array(samples)

輸出展示:

同樣采用莎士比亞文集訓(xùn)練模型:

python sample.py --converter_path model/shakespeare/converter.pkl --checkpoint_path model/shakespeare/ --max_length 30 --start_string He

Heds since
I that what that when
以上為輸出結(jié)果吹艇,并不能成為一句語句, 個(gè)人覺得利用word并進(jìn)行word2vec來生成語句昂拂,可能會(huì)更好受神,利用字符 維度太低了。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末格侯,一起剝皮案震驚了整個(gè)濱河市鼻听,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌联四,老刑警劉巖撑碴,帶你破解...
    沈念sama閱讀 211,194評(píng)論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異朝墩,居然都是意外死亡醉拓,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,058評(píng)論 2 385
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來亿卤,“玉大人愤兵,你說我怎么就攤上這事∨盼猓” “怎么了秆乳?”我有些...
    開封第一講書人閱讀 156,780評(píng)論 0 346
  • 文/不壞的土叔 我叫張陵,是天一觀的道長傍念。 經(jīng)常有香客問我矫夷,道長,這世上最難降的妖魔是什么憋槐? 我笑而不...
    開封第一講書人閱讀 56,388評(píng)論 1 283
  • 正文 為了忘掉前任双藕,我火速辦了婚禮,結(jié)果婚禮上阳仔,老公的妹妹穿的比我還像新娘忧陪。我一直安慰自己,他們只是感情好近范,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,430評(píng)論 5 384
  • 文/花漫 我一把揭開白布嘶摊。 她就那樣靜靜地躺著,像睡著了一般评矩。 火紅的嫁衣襯著肌膚如雪叶堆。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 49,764評(píng)論 1 290
  • 那天斥杜,我揣著相機(jī)與錄音虱颗,去河邊找鬼。 笑死蔗喂,一個(gè)胖子當(dāng)著我的面吹牛忘渔,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播缰儿,決...
    沈念sama閱讀 38,907評(píng)論 3 406
  • 文/蒼蘭香墨 我猛地睜開眼畦粮,長吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了乖阵?” 一聲冷哼從身側(cè)響起宣赔,我...
    開封第一講書人閱讀 37,679評(píng)論 0 266
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎瞪浸,沒想到半個(gè)月后拉背,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,122評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡默终,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,459評(píng)論 2 325
  • 正文 我和宋清朗相戀三年椅棺,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了犁罩。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 38,605評(píng)論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡两疚,死狀恐怖床估,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情诱渤,我是刑警寧澤丐巫,帶...
    沈念sama閱讀 34,270評(píng)論 4 329
  • 正文 年R本政府宣布,位于F島的核電站勺美,受9級(jí)特大地震影響递胧,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜赡茸,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,867評(píng)論 3 312
  • 文/蒙蒙 一缎脾、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧占卧,春花似錦遗菠、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,734評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至叭喜,卻和暖如春贺拣,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背捂蕴。 一陣腳步聲響...
    開封第一講書人閱讀 31,961評(píng)論 1 265
  • 我被黑心中介騙來泰國打工譬涡, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人启绰。 一個(gè)月前我還...
    沈念sama閱讀 46,297評(píng)論 2 360
  • 正文 我出身青樓昂儒,卻偏偏與公主長得像沟使,于是被迫代替她去往敵國和親委可。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,472評(píng)論 2 348

推薦閱讀更多精彩內(nèi)容

  • 作者 | 武維AI前線出品| ID:ai-front 前言 自然語言處理(簡稱NLP)腊嗡,是研究計(jì)算機(jī)處理人類語言的...
    AI前線閱讀 2,564評(píng)論 0 8
  • 近日着倾,谷歌官方在 Github開放了一份神經(jīng)機(jī)器翻譯教程,該教程從基本概念實(shí)現(xiàn)開始燕少,首先搭建了一個(gè)簡單的NMT模型...
    MiracleJQ閱讀 6,352評(píng)論 1 11
  • 黑暗中無數(shù)雙慘白的手臂在向她揮舞著卡者。 過來呀,過來呀客们。 陰郁但輕柔的聲音一波一波的鉆入她的心房崇决,像深秋的風(fēng)拂過枯草...
    一平London閱讀 3,638評(píng)論 5 8
  • 早安,親愛的朋友盈厘!在人生路上睁枕,很多時(shí)候看不清未來,回不到過去沸手。那么外遇,就別讓時(shí)間在眺望未來中流逝,莫讓年華在回憶過去...
    js1314閱讀 300評(píng)論 0 0
  • 認(rèn)識(shí)柴靜快三年了契吉,但總是會(huì)隔一段時(shí)間看一下她的文章或者視頻跳仿。我覺得自己必須要寫一下柴靜。與其寫《看見》帶給我的感悟...
    邢彩燕閱讀 345評(píng)論 0 6