機(jī)器學(xué)習(xí)RRN訓(xùn)練聊天機(jī)器人

前言

上篇寫過(guò)一個(gè)機(jī)器學(xué)習(xí)寫唐詩(shī)的實(shí)驗(yàn),這次我們搞個(gè)稍微復(fù)雜些的,實(shí)現(xiàn)一個(gè)聊天機(jī)器人则酝,也是基于騰訊云實(shí)驗(yàn)室的一篇教程,有些部分做了改動(dòng)闰集,大部分時(shí)間都用在了環(huán)境的適配上面沽讹。開始本地是在Mac環(huán)境般卑,單獨(dú)依靠CPU訓(xùn)練,比較慢爽雄。后來(lái)找了個(gè)配置比較好的機(jī)器蝠检, 6核心12線程,效果好一些挚瘟√舅總結(jié)來(lái)說(shuō),機(jī)器學(xué)習(xí)相關(guān)有兩個(gè)重點(diǎn)刽沾,一個(gè)是基礎(chǔ)的訓(xùn)練資源本慕,包括對(duì)原始數(shù)據(jù)的清洗處理和規(guī)范化,訓(xùn)練中其實(shí)模型是沒(méi)有很大區(qū)別的侧漓。其次锅尘,是好的機(jī)器配置,資源有限布蔗,沒(méi)有上GPU藤违。這次實(shí)驗(yàn),本地訓(xùn)練大概半天到4000步的時(shí)候纵揍,還只是個(gè)復(fù)讀機(jī)顿乒,換了高配機(jī)器1天左右就可以到30萬(wàn)左右,兩天到70萬(wàn)泽谨,基本達(dá)到損失率穩(wěn)定(30萬(wàn)就可以)璧榄。
以下是本地機(jī)器的配置,奈何效果不行

MacBook Pro (13-inch, 2017, Four Thunderbolt 3 Ports)
10.13.6 (17G65)16 GB 2133 MHz LPDDR3
3.1 GHz Intel Core i5

注意事項(xiàng)
強(qiáng)烈建議使用virtualenv配置python吧雹,簡(jiǎn)單而且不會(huì)對(duì)本地運(yùn)行環(huán)境造成影響骨杂。
同時(shí)需要安裝好TensorFlow環(huán)境

過(guò)程步驟

實(shí)驗(yàn)內(nèi)容

  1. 首先進(jìn)行數(shù)據(jù)的清洗,處理雄卷。提取ask和answer數(shù)據(jù)搓蚪,并得到字典,以及做向量化處理丁鹉。訓(xùn)練數(shù)據(jù)可以使用本次實(shí)驗(yàn)鏈接里的妒潭,也可以使用網(wǎng)上的小黃雞等等語(yǔ)料。注意這里的字典之前查的資料是滿足3000左右的常用漢字就可以揣钦,是在語(yǔ)料中找到常用字雳灾。

  2. 模型學(xué)習(xí)部分。
    這里引用了seq2seq的部分冯凹,單獨(dú)有一些修改佑女。之前下載實(shí)驗(yàn)中提供的訓(xùn)練了30萬(wàn)次左右的模型直接進(jìn)行對(duì)話,但是本地一直提示錯(cuò)誤谈竿。最終選擇了自己訓(xùn)練团驱,保存了完整的checkpoint文件,可以啟動(dòng)程序空凸。如圖最終訓(xùn)練在71萬(wàn)次左右嚎花,其實(shí)30萬(wàn)左右損失率基本就已經(jīng)不變了,如果能提供更優(yōu)化的語(yǔ)料應(yīng)該效果會(huì)更好呀洲。后續(xù)有鏈接提供所有資料紊选,可以直接下載。


    訓(xùn)練完畢的模型
  3. 模擬對(duì)話道逗,這部分是最終的成果兵罢,啟動(dòng)本地依賴,加載訓(xùn)練模型之后就可以對(duì)話了滓窍,效果看圖卖词,可以看到有些句子還是可以對(duì)上的,一問(wèn)一答吏夯,有些幼稚此蜈。


    模擬對(duì)話

代碼部分

  1. 數(shù)據(jù)整理和向量化 generate.py
# -*- coding:utf-8 -*-
from io import open
import random
import tensorflow as tf

# version tf 1.12 2018-12-08 22:22:08
PAD = "PAD"
GO = "GO"
EOS = "EOS"
UNK = "UNK"
START_VOCAB = [PAD, GO, EOS, UNK]

PAD_ID = 0  # 填充
GO_ID = 1  # 開始標(biāo)志
EOS_ID = 2  # 結(jié)束標(biāo)志
UNK_ID = 3  # 未知字符
_buckets = [(10, 15), (20, 25), (40, 50), (80, 100)]
units_num = 256
num_layers = 3
max_gradient_norm = 5.0
batch_size = 50
learning_rate = 0.5
learning_rate_decay_factor = 0.97

train_encode_file = "data/train_encode"
train_decode_file = "data/train_decode"
test_encode_file = "data/test_encode"
test_decode_file = "data/test_decode"
vocab_encode_file = "data/vocab_encode"
vocab_decode_file = "data/vocab_decode"
train_encode_vec_file = "data/train_encode_vec"
train_decode_vec_file = "data/train_decode_vec"
test_encode_vec_file = "data/test_encode_vec"
test_decode_vec_file = "data/test_decode_vec"


def is_chinese(sentence):
    flag = True
    if len(sentence) < 2:
        flag = False
        return flag
    for uchar in sentence:
        if (uchar == ',' or uchar == '噪生。' or
                uchar == '~' or uchar == '?' or
                uchar == '裆赵!'):
            flag = True
        elif '一' <= uchar <= '?':
            flag = True
        else:
            flag = False
            break
    return flag


def get_chatbot():
    f = open("data/chat.conv", "r", encoding="utf-8")
    train_encode = open(train_encode_file, "w", encoding="utf-8")
    train_decode = open(train_decode_file, "w", encoding="utf-8")
    test_encode = open(test_encode_file, "w", encoding="utf-8")
    test_decode = open(test_decode_file, "w", encoding="utf-8")
    vocab_encode = open(vocab_encode_file, "w", encoding="utf-8")
    vocab_decode = open(vocab_decode_file, "w", encoding="utf-8")
    encode = list()
    decode = list()

    chat = list()
    print("start load source data...")
    step = 0
    for line in f.readlines():
        line = line.strip('\n').strip()
        if not line:
            continue
        if line[0] == "E":
            if step % 1000 == 0:
                print("step:%d" % step)
            step += 1
            if (len(chat) == 2 and is_chinese(chat[0]) and is_chinese(chat[1]) and
                    not chat[0] in encode and not chat[1] in decode):
                encode.append(chat[0])
                decode.append(chat[1])
            chat = list()
        elif line[0] == "M":
            L = line.split(' ')
            if len(L) > 1:
                chat.append(L[1])
    encode_size = len(encode)
    if encode_size != len(decode):
        raise ValueError("encode size not equal to decode size")
    test_index = random.sample([i for i in range(encode_size)], int(encode_size * 0.2))
    print("divide source into two...")
    step = 0
    for i in range(encode_size):
        if step % 1000 == 0:
            print("%d" % step)
        step += 1
        if i in test_index:
            test_encode.write(encode[i] + "\n")
            test_decode.write(decode[i] + "\n")
        else:
            train_encode.write(encode[i] + "\n")
            train_decode.write(decode[i] + "\n")

    vocab_encode_set = set(''.join(encode))
    vocab_decode_set = set(''.join(decode))
    print("get vocab_encode...")
    step = 0
    for word in vocab_encode_set:
        if step % 1000 == 0:
            print("%d" % step)
        step += 1
        vocab_encode.write(word + "\n")
    print("get vocab_decode...")
    step = 0
    for word in vocab_decode_set:
        print("%d" % step)
        step += 1
        vocab_decode.write(word + "\n")


def gen_chatbot_vectors(input_file, vocab_file, output_file):
    vocab_f = open(vocab_file, "r", encoding="utf-8")
    output_f = open(output_file, "w")
    input_f = open(input_file, "r", encoding="utf-8")
    words = list()
    for word in vocab_f.readlines():
        word = word.strip('\n').strip()
        words.append(word)
    word_to_id = {word: i for i, word in enumerate(words)}
    to_id = lambda word: word_to_id.get(word, UNK_ID)
    print("get %s vectors" % input_file)
    step = 0
    for line in input_f.readlines():
        if step % 1000 == 0:
            print("step:%d" % step)
        step += 1
        line = line.strip('\n').strip()
        vec = map(to_id, line)
        output_f.write(' '.join([str(n) for n in vec]) + "\n")


def get_vectors():
    gen_chatbot_vectors(train_encode_file, vocab_encode_file, train_encode_vec_file)
    gen_chatbot_vectors(train_decode_file, vocab_decode_file, train_decode_vec_file)
    gen_chatbot_vectors(test_encode_file, vocab_encode_file, test_encode_vec_file)
    gen_chatbot_vectors(test_decode_file, vocab_decode_file, test_decode_vec_file)


def get_vocabs(vocab_file):
    words = list()
    with open(vocab_file, "r", encoding="utf-8") as vocab_f:
        for word in vocab_f:
            words.append(word.strip('\n').strip())
    id_to_word = {i: word for i, word in enumerate(words)}
    word_to_id = {v: k for k, v in id_to_word.items()}
    vocab_size = len(id_to_word)
    return id_to_word, word_to_id, vocab_size


def read_data(source_path, target_path, max_size=None):
    data_set = [[] for _ in _buckets]
    with tf.gfile.GFile(source_path, mode="r") as source_file:
        with tf.gfile.GFile(target_path, mode="r") as target_file:
            source, target = source_file.readline(), target_file.readline()
            counter = 0
            while source and target and (not max_size or counter < max_size):
                counter += 1
                source_ids = [int(x) for x in source.split()]
                target_ids = [int(x) for x in target.split()]
                target_ids.append(EOS_ID)
                for bucket_id, (source_size, target_size) in enumerate(_buckets):
                    if len(source_ids) < source_size and len(target_ids) < target_size:
                        data_set[bucket_id].append([source_ids, target_ids])
                        break
                source, target = source_file.readline(), target_file.readline()
    return data_set


# run
#獲取 ask、answer 數(shù)據(jù)并生成字典
# get_chatbot()
#訓(xùn)練數(shù)據(jù)轉(zhuǎn)化為數(shù)字表示
# get_vectors()
  1. 學(xué)習(xí)模型

簡(jiǎn)書限制太長(zhǎng)無(wú)法發(fā)布跺嗽,只能在最后的鏈接獲取了
seq2seq.py
seq2seq_model.py

  1. 訓(xùn)練模塊

可以改小配置中的step部分战授,簡(jiǎn)單驗(yàn)證下效果。這里有些改動(dòng)桨嫁,加了間隔一定步驟之后植兰,保存checkpoint到本地的功能,防止中間如果有異常瞧甩,比如斷電或者不小心關(guān)閉程序或者其他原因造成程序崩潰钉跷,導(dǎo)致前功盡棄。

train_chat.py

# -*- coding:utf-8 -*-
import generate as generate_chat
import seq2seq_model as seq2seq_model
import tensorflow as tf
import numpy as np
import logging
import logging.handlers

if __name__ == '__main__':

    _, _, source_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_encode_file)
    _, _, target_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_decode_file)
    train_set = generate_chat.read_data(generate_chat.train_encode_vec_file, generate_chat.train_decode_vec_file)
    test_set = generate_chat.read_data(generate_chat.test_encode_vec_file, generate_chat.test_decode_vec_file)
    train_bucket_sizes = [len(train_set[i]) for i in range(len(generate_chat._buckets))]
    train_total_size = float(sum(train_bucket_sizes))
    train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size for i in range(len(train_bucket_sizes))]
    cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
    with tf.Session(config=cpu_config) as sess:
        model = seq2seq_model.Seq2SeqModel(source_vocab_size,
                                           target_vocab_size,
                                           generate_chat._buckets,
                                           generate_chat.units_num,
                                           generate_chat.num_layers,
                                           generate_chat.max_gradient_norm,
                                           generate_chat.batch_size,
                                           generate_chat.learning_rate,
                                           generate_chat.learning_rate_decay_factor,
                                           use_lstm=True)

        ckpt = tf.train.get_checkpoint_state('./mytrain')

        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print("Created model with fresh parameters.")
            sess.run(tf.global_variables_initializer())
        loss = 0.0
        step = 0
        previous_losses = []
        while True:
            random_number_01 = np.random.random_sample()
            bucket_id = min([i for i in range(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01])
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(train_set, bucket_id)
            _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, False)
            print("step:%d,loss:%f" % (step, step_loss))
            loss += step_loss / 2000
            step += 1
            if step % 1000 == 0:
                print("step:%d,per_loss:%f" % (step, loss))
                if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
                    sess.run(model.learning_rate_decay_op)
                previous_losses.append(loss)
                model.saver.save(sess, "mytrain/chatbot.ckpt", global_step=model.global_step)
                loss = 0.0
            if step % 5000 == 0:
                for bucket_id in range(len(generate_chat._buckets)):
                    if len(test_set[bucket_id]) == 0:
                        continue
                        encoder_inputs, decoder_inputs, target_weights = model.get_batch(test_set, bucket_id)
                        _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
                                                     True)
                        print("bucket_id:%d,eval_loss:%f" % (bucket_id, eval_loss))

  1. 對(duì)話模塊
    chat.py
# -*- coding:utf-8 -*-
import generate as generate_chat
import seq2seq_model as seq2seq_model
import tensorflow as tf
import numpy as np
import sys

if __name__ == '__main__':
    source_id_to_word, source_word_to_id, source_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_encode_file)
    target_id_to_word, target_word_to_id, target_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_decode_file)
    to_id = lambda word: source_word_to_id.get(word, generate_chat.UNK_ID)
    cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
    with tf.Session(config=cpu_config) as sess:
        model = seq2seq_model.Seq2SeqModel(source_vocab_size,
                                           target_vocab_size,
                                           generate_chat._buckets,
                                           generate_chat.units_num,
                                           generate_chat.num_layers,
                                           generate_chat.max_gradient_norm,
                                           1,
                                           generate_chat.learning_rate,
                                           generate_chat.learning_rate_decay_factor,
                                           forward_only=True,
                                           use_lstm=True)
        #model.saver.restore(sess, "model/chatbot.ckpt-317000")
        model.saver.restore(sess, "mytrain/chatbot.ckpt-717000")
        while True:
            sys.stdout.write("ask > ")
            sys.stdout.flush()
            sentence = sys.stdin.readline().strip('\n')
            flag = generate_chat.is_chinese(sentence)
            if not sentence or not flag:
                print("請(qǐng)輸入純中文")
                continue
            sentence_vec = list(map(to_id, sentence))
            bucket_id = len(generate_chat._buckets) - 1
            if len(sentence_vec) > generate_chat._buckets[bucket_id][0]:
                print("sentence too long max:%d" % generate_chat._buckets[bucket_id][0])
                exit(0)
            for i, bucket in enumerate(generate_chat._buckets):
                if bucket[0] >= len(sentence_vec):
                    bucket_id = i
                    break
            encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(sentence_vec, [])]},
                                                                             bucket_id)
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            if generate_chat.EOS_ID in outputs:
                outputs = outputs[:outputs.index(generate_chat.EOS_ID)]
            answer = "".join([tf.compat.as_str(target_id_to_word[output]) for output in outputs])
            print("answer > " + answer)

注意
這里在train_chat.py 和 chat.py中肚逸,tf.session有個(gè)配置改動(dòng)爷辙,限制了使用的CPU數(shù),在Ubuntu下如果沒(méi)有限制朦促,會(huì)造成TF占用所有的CPU資源膝晾,導(dǎo)致系統(tǒng)卡死,具體數(shù)值根據(jù)CPU核心數(shù)設(shè)置务冕。
代碼如下:

cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
    with tf.Session(config=cpu_config) as sess:

結(jié)語(yǔ)

感謝閱讀血当,最后放上實(shí)驗(yàn)的實(shí)際地址和我自己訓(xùn)練的所有資源,本地實(shí)驗(yàn)在mac tf 1.12.0 和 python3.6.7,以及Ubuntu tf.1.12.0 和 python3.5環(huán)境下都正常臊旭,再次建議在virtualenv環(huán)境下落恼。
實(shí)驗(yàn)鏈接(時(shí)間過(guò)久可能失效):https://cloud.tencent.com/developer/labs/lab/10406
本地實(shí)驗(yàn)資源:https://iss.igosh.com/share/201903/tencent-me.tar.gz

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市离熏,隨后出現(xiàn)的幾起案子佳谦,更是在濱河造成了極大的恐慌,老刑警劉巖滋戳,帶你破解...
    沈念sama閱讀 218,036評(píng)論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件钻蔑,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡奸鸯,警方通過(guò)查閱死者的電腦和手機(jī)咪笑,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,046評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)娄涩,“玉大人窗怒,你說(shuō)我怎么就攤上這事《勐” “怎么了兜粘?”我有些...
    開封第一講書人閱讀 164,411評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)弯蚜。 經(jīng)常有香客問(wèn)我孔轴,道長(zhǎng),這世上最難降的妖魔是什么碎捺? 我笑而不...
    開封第一講書人閱讀 58,622評(píng)論 1 293
  • 正文 為了忘掉前任路鹰,我火速辦了婚禮,結(jié)果婚禮上收厨,老公的妹妹穿的比我還像新娘晋柱。我一直安慰自己,他們只是感情好诵叁,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,661評(píng)論 6 392
  • 文/花漫 我一把揭開白布雁竞。 她就那樣靜靜地躺著,像睡著了一般拧额。 火紅的嫁衣襯著肌膚如雪碑诉。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,521評(píng)論 1 304
  • 那天侥锦,我揣著相機(jī)與錄音进栽,去河邊找鬼。 笑死恭垦,一個(gè)胖子當(dāng)著我的面吹牛快毛,可吹牛的內(nèi)容都是我干的格嗅。 我是一名探鬼主播,決...
    沈念sama閱讀 40,288評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼唠帝,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼屯掖!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起没隘,我...
    開封第一講書人閱讀 39,200評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤懂扼,失蹤者是張志新(化名)和其女友劉穎,沒(méi)想到半個(gè)月后右蒲,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,644評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡赶熟,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,837評(píng)論 3 336
  • 正文 我和宋清朗相戀三年瑰妄,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片映砖。...
    茶點(diǎn)故事閱讀 39,953評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡间坐,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出邑退,到底是詐尸還是另有隱情竹宋,我是刑警寧澤,帶...
    沈念sama閱讀 35,673評(píng)論 5 346
  • 正文 年R本政府宣布地技,位于F島的核電站蜈七,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏莫矗。R本人自食惡果不足惜飒硅,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,281評(píng)論 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望作谚。 院中可真熱鬧三娩,春花似錦、人聲如沸妹懒。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,889評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)眨唬。三九已至会前,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間单绑,已是汗流浹背回官。 一陣腳步聲響...
    開封第一講書人閱讀 33,011評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留搂橙,地道東北人歉提。 一個(gè)月前我還...
    沈念sama閱讀 48,119評(píng)論 3 370
  • 正文 我出身青樓笛坦,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親苔巨。 傳聞我的和親對(duì)象是個(gè)殘疾皇子版扩,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,901評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容