Preface
In the last post I wrote up a machine-learning experiment that generates Tang poems. This time we try something a bit more involved and build a chatbot, again based on a Tencent Cloud lab tutorial, with some parts changed; most of the time went into adapting the environment. I started locally on a Mac, training on CPU only, which was slow; later I found a better-equipped machine with 6 cores / 12 threads and things improved. To sum up, two things matter most in this kind of work. One is the underlying training data, including cleaning and normalizing the raw corpus; the model itself barely differs from run to run. The other is decent hardware; with limited resources I did not use a GPU. In this experiment, local training took about half a day to reach 4,000 steps and the bot was still just a parrot; on the better machine it reached roughly 300k steps in about a day and 700k in two, by which point the loss had basically stabilized (300k is already enough).
Below is the configuration of the local machine; unfortunately it was not up to the job:
MacBook Pro (13-inch, 2017, Four Thunderbolt 3 Ports)
macOS 10.13.6 (17G65)
16 GB 2133 MHz LPDDR3
3.1 GHz Intel Core i5
Notes
It is strongly recommended to set up Python with virtualenv; it is simple and does not touch the local environment.
You also need a working TensorFlow installation.
Process
Experiment content
- Data cleaning and preprocessing. Extract the ask and answer sentences, build the vocabulary, and turn everything into id vectors. You can use the training data from the lab link, or another corpus such as Xiaohuangji (小黄鸡). On the vocabulary: from what I had read, roughly 3,000 common Chinese characters are enough, and here the vocabulary is simply the set of characters found in the corpus. A sample of the raw corpus format is shown right after this list.
- Model training. This part reuses seq2seq with a few local changes. I first downloaded the model provided in the lab (trained for roughly 300k steps) to chat with it directly, but it kept failing to load locally, so in the end I trained my own model and saved complete checkpoint files that the program can start from. As the figure shows, training ran to about 710k steps, although the loss had essentially stopped improving around 300k; a better-curated corpus should give better results. All the files can be downloaded from the link at the end.
The trained model
- Simulated conversation. This is the final result: start the local dependencies, load the trained model, and you can chat. As the screenshot shows, some exchanges do line up as question and answer, if a bit childish.
Simulated conversation
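For reference, here is a made-up fragment showing the raw corpus layout that get_chatbot() in generate.py expects (the sentences below are invented and only the layout matters: an "E" line separates conversations, and each utterance sits on its own "M " line):
E
M 你好
M 你好,很高兴认识你
E
M 今天天气怎么样
M 还不错,出去走走吧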
Code
- Data preparation and vectorization: generate.py
# -*- coding:utf-8 -*-
from io import open
import random
import tensorflow as tf
# version tf 1.12 2018-12-08 22:22:08
PAD = "PAD"
GO = "GO"
EOS = "EOS"
UNK = "UNK"
START_VOCAB = [PAD, GO, EOS, UNK]
PAD_ID = 0  # padding
GO_ID = 1  # start-of-sequence marker
EOS_ID = 2  # end-of-sequence marker
UNK_ID = 3  # unknown character
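# _buckets holds (max encoder length, max decoder length) pairs; read_data() below places
# each ask/answer pair into the smallest bucket that fits it, so short sentences do not
# have to be padded out to the longest length in the data.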
_buckets = [(10, 15), (20, 25), (40, 50), (80, 100)]
units_num = 256
num_layers = 3
max_gradient_norm = 5.0
batch_size = 50
learning_rate = 0.5
learning_rate_decay_factor = 0.97
train_encode_file = "data/train_encode"
train_decode_file = "data/train_decode"
test_encode_file = "data/test_encode"
test_decode_file = "data/test_decode"
vocab_encode_file = "data/vocab_encode"
vocab_decode_file = "data/vocab_decode"
train_encode_vec_file = "data/train_encode_vec"
train_decode_vec_file = "data/train_decode_vec"
test_encode_vec_file = "data/test_encode_vec"
test_decode_vec_file = "data/test_decode_vec"
def is_chinese(sentence):
flag = True
if len(sentence) < 2:
flag = False
return flag
for uchar in sentence:
        if (uchar == ',' or uchar == '。' or
                uchar == '~' or uchar == '?' or
                uchar == '!'):
            flag = True
        elif '\u4e00' <= uchar <= '\u9fa5':  # the CJK Unified Ideographs range
flag = True
else:
flag = False
break
return flag
def get_chatbot():
f = open("data/chat.conv", "r", encoding="utf-8")
train_encode = open(train_encode_file, "w", encoding="utf-8")
train_decode = open(train_decode_file, "w", encoding="utf-8")
test_encode = open(test_encode_file, "w", encoding="utf-8")
test_decode = open(test_decode_file, "w", encoding="utf-8")
vocab_encode = open(vocab_encode_file, "w", encoding="utf-8")
vocab_decode = open(vocab_decode_file, "w", encoding="utf-8")
encode = list()
decode = list()
chat = list()
print("start load source data...")
step = 0
for line in f.readlines():
line = line.strip('\n').strip()
if not line:
continue
if line[0] == "E":
if step % 1000 == 0:
print("step:%d" % step)
step += 1
if (len(chat) == 2 and is_chinese(chat[0]) and is_chinese(chat[1]) and
not chat[0] in encode and not chat[1] in decode):
encode.append(chat[0])
decode.append(chat[1])
chat = list()
elif line[0] == "M":
L = line.split(' ')
if len(L) > 1:
chat.append(L[1])
encode_size = len(encode)
if encode_size != len(decode):
raise ValueError("encode size not equal to decode size")
test_index = random.sample([i for i in range(encode_size)], int(encode_size * 0.2))
print("divide source into two...")
step = 0
for i in range(encode_size):
if step % 1000 == 0:
print("%d" % step)
step += 1
if i in test_index:
test_encode.write(encode[i] + "\n")
test_decode.write(decode[i] + "\n")
else:
train_encode.write(encode[i] + "\n")
train_decode.write(decode[i] + "\n")
vocab_encode_set = set(''.join(encode))
vocab_decode_set = set(''.join(decode))
print("get vocab_encode...")
step = 0
for word in vocab_encode_set:
if step % 1000 == 0:
print("%d" % step)
step += 1
vocab_encode.write(word + "\n")
print("get vocab_decode...")
step = 0
for word in vocab_decode_set:
print("%d" % step)
step += 1
vocab_decode.write(word + "\n")
def gen_chatbot_vectors(input_file, vocab_file, output_file):
vocab_f = open(vocab_file, "r", encoding="utf-8")
output_f = open(output_file, "w")
input_f = open(input_file, "r", encoding="utf-8")
words = list()
for word in vocab_f.readlines():
word = word.strip('\n').strip()
words.append(word)
word_to_id = {word: i for i, word in enumerate(words)}
to_id = lambda word: word_to_id.get(word, UNK_ID)
print("get %s vectors" % input_file)
step = 0
for line in input_f.readlines():
if step % 1000 == 0:
print("step:%d" % step)
step += 1
line = line.strip('\n').strip()
vec = map(to_id, line)
output_f.write(' '.join([str(n) for n in vec]) + "\n")
def get_vectors():
gen_chatbot_vectors(train_encode_file, vocab_encode_file, train_encode_vec_file)
gen_chatbot_vectors(train_decode_file, vocab_decode_file, train_decode_vec_file)
gen_chatbot_vectors(test_encode_file, vocab_encode_file, test_encode_vec_file)
gen_chatbot_vectors(test_decode_file, vocab_decode_file, test_decode_vec_file)
def get_vocabs(vocab_file):
words = list()
with open(vocab_file, "r", encoding="utf-8") as vocab_f:
for word in vocab_f:
words.append(word.strip('\n').strip())
id_to_word = {i: word for i, word in enumerate(words)}
word_to_id = {v: k for k, v in id_to_word.items()}
vocab_size = len(id_to_word)
return id_to_word, word_to_id, vocab_size
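# read_data pairs up the encoder/decoder id files line by line, appends EOS_ID to each
# target sequence, and keeps only the pairs that fit into one of the _buckets above.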
def read_data(source_path, target_path, max_size=None):
data_set = [[] for _ in _buckets]
with tf.gfile.GFile(source_path, mode="r") as source_file:
with tf.gfile.GFile(target_path, mode="r") as target_file:
source, target = source_file.readline(), target_file.readline()
counter = 0
while source and target and (not max_size or counter < max_size):
counter += 1
source_ids = [int(x) for x in source.split()]
target_ids = [int(x) for x in target.split()]
target_ids.append(EOS_ID)
for bucket_id, (source_size, target_size) in enumerate(_buckets):
if len(source_ids) < source_size and len(target_ids) < target_size:
data_set[bucket_id].append([source_ids, target_ids])
break
source, target = source_file.readline(), target_file.readline()
return data_set
# run
# get the ask/answer data and build the vocabularies
# get_chatbot()
# convert the training data into id sequences
# get_vectors()
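# A minimal one-off run (a sketch; it assumes data/chat.conv is already in place):
# build the ask/answer splits and the vocabularies first, then vectorize them.
if __name__ == "__main__":
    get_chatbot()
    get_vectors()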
- Learning model
Jianshu's length limit makes these two files too long to publish here; get them from the link at the end:
seq2seq.py
seq2seq_model.py
- Training module
You can reduce the step values in the configuration for a quick sanity check. One change I made here: a checkpoint is saved to disk at a fixed step interval, so that a power failure, an accidental exit, or any other crash does not throw away all the progress made so far.
train_chat.py
# -*- coding:utf-8 -*-
import generate as generate_chat
import seq2seq_model as seq2seq_model
import tensorflow as tf
import numpy as np
import logging
import logging.handlers
if __name__ == '__main__':
_, _, source_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_encode_file)
_, _, target_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_decode_file)
train_set = generate_chat.read_data(generate_chat.train_encode_vec_file, generate_chat.train_decode_vec_file)
test_set = generate_chat.read_data(generate_chat.test_encode_vec_file, generate_chat.test_decode_vec_file)
train_bucket_sizes = [len(train_set[i]) for i in range(len(generate_chat._buckets))]
train_total_size = float(sum(train_bucket_sizes))
train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size for i in range(len(train_bucket_sizes))]
cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
with tf.Session(config=cpu_config) as sess:
model = seq2seq_model.Seq2SeqModel(source_vocab_size,
target_vocab_size,
generate_chat._buckets,
generate_chat.units_num,
generate_chat.num_layers,
generate_chat.max_gradient_norm,
generate_chat.batch_size,
generate_chat.learning_rate,
generate_chat.learning_rate_decay_factor,
use_lstm=True)
ckpt = tf.train.get_checkpoint_state('./mytrain')
if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
model.saver.restore(sess, ckpt.model_checkpoint_path)
else:
print("Created model with fresh parameters.")
sess.run(tf.global_variables_initializer())
loss = 0.0
step = 0
previous_losses = []
while True:
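            # pick a bucket at random, weighted by the share of training pairs it contains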
random_number_01 = np.random.random_sample()
bucket_id = min([i for i in range(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01])
encoder_inputs, decoder_inputs, target_weights = model.get_batch(train_set, bucket_id)
_, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, False)
print("step:%d,loss:%f" % (step, step_loss))
            loss += step_loss / 1000  # average over the 1000-step reporting window below
step += 1
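            # every 1000 steps: report the window-averaged loss, decay the learning rate if the
            # loss is higher than in each of the last three windows, and save a checkpoint under
            # mytrain/ so an interrupted run can be resumed (the change mentioned above)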
if step % 1000 == 0:
print("step:%d,per_loss:%f" % (step, loss))
if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
sess.run(model.learning_rate_decay_op)
previous_losses.append(loss)
model.saver.save(sess, "mytrain/chatbot.ckpt", global_step=model.global_step)
loss = 0.0
if step % 5000 == 0:
for bucket_id in range(len(generate_chat._buckets)):
if len(test_set[bucket_id]) == 0:
continue
encoder_inputs, decoder_inputs, target_weights = model.get_batch(test_set, bucket_id)
_, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
True)
print("bucket_id:%d,eval_loss:%f" % (bucket_id, eval_loss))
- Chat module
chat.py
# -*- coding:utf-8 -*-
import generate as generate_chat
import seq2seq_model as seq2seq_model
import tensorflow as tf
import numpy as np
import sys
if __name__ == '__main__':
source_id_to_word, source_word_to_id, source_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_encode_file)
target_id_to_word, target_word_to_id, target_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_decode_file)
to_id = lambda word: source_word_to_id.get(word, generate_chat.UNK_ID)
cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
with tf.Session(config=cpu_config) as sess:
model = seq2seq_model.Seq2SeqModel(source_vocab_size,
target_vocab_size,
generate_chat._buckets,
generate_chat.units_num,
generate_chat.num_layers,
generate_chat.max_gradient_norm,
1,
generate_chat.learning_rate,
generate_chat.learning_rate_decay_factor,
forward_only=True,
use_lstm=True)
#model.saver.restore(sess, "model/chatbot.ckpt-317000")
model.saver.restore(sess, "mytrain/chatbot.ckpt-717000")
while True:
sys.stdout.write("ask > ")
sys.stdout.flush()
sentence = sys.stdin.readline().strip('\n')
flag = generate_chat.is_chinese(sentence)
if not sentence or not flag:
print("請(qǐng)輸入純中文")
continue
sentence_vec = list(map(to_id, sentence))
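            # default to the largest bucket, then pick the smallest bucket whose encoder side
            # can hold the input sentence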
bucket_id = len(generate_chat._buckets) - 1
if len(sentence_vec) > generate_chat._buckets[bucket_id][0]:
print("sentence too long max:%d" % generate_chat._buckets[bucket_id][0])
exit(0)
for i, bucket in enumerate(generate_chat._buckets):
if bucket[0] >= len(sentence_vec):
bucket_id = i
break
encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(sentence_vec, [])]},
bucket_id)
_, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
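            # greedy decoding: take the argmax token at each decoder position and cut the
            # reply at the first EOS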
outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
if generate_chat.EOS_ID in outputs:
outputs = outputs[:outputs.index(generate_chat.EOS_ID)]
answer = "".join([tf.compat.as_str(target_id_to_word[output]) for output in outputs])
print("answer > " + answer)
Note
In both train_chat.py and chat.py the tf.Session configuration was changed to cap the number of CPUs TensorFlow may use. On Ubuntu, without this limit TF grabs every CPU core and the whole system can lock up; set the numbers according to your core count.
The code is as follows:
cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
with tf.Session(config=cpu_config) as sess:
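If you would rather not hard-code the 6, one option (my own variant, not part of the lab code) is to derive the thread count from the machine, for example:
import multiprocessing
import tensorflow as tf

n_threads = max(1, multiprocessing.cpu_count() - 2)  # leave a couple of cores for the system
cpu_config = tf.ConfigProto(intra_op_parallelism_threads=n_threads,
                            inter_op_parallelism_threads=n_threads,
                            device_count={'CPU': n_threads})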
Conclusion
Thanks for reading. Below are the original lab link and all the resources from my own training run. The experiment worked locally both on macOS with TF 1.12.0 / Python 3.6.7 and on Ubuntu with TF 1.12.0 / Python 3.5; once again, I recommend working inside a virtualenv.
Lab link (may have expired by now): https://cloud.tencent.com/developer/labs/lab/10406
Resources from my local experiment: https://iss.igosh.com/share/201903/tencent-me.tar.gz