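Original article: Implementing LSTM with Keras
This post builds on the original article, adding some comments and run results and modifying a small amount of the code.
1. Introduction
LSTM (Long Short-Term Memory) is a special kind of recurrent neural network. On many tasks, LSTM performs far better than a standard RNN.
For an introduction to LSTM, see references 1 and 2. This post focuses on using an LSTM to implement a classifier.
2. How to use LSTM in keras
This post mainly tests using Word Embeddings with keras for classification. The code is adapted from an example in the official keras documentation. IPython code link
2.1 Word Embeddings dataset
Stanford's GloVe is used as the word vector set: the pretrained word vector file is downloaded directly, and a plain dictionary lookup then gives the word vector for each word in the text. GloVe dataset download. The text test data is 20_newsgroup: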
This data set is a collection of 20,000 messages, collected from 20 different netnews newsgroups. One thousand messages from each of the twenty newsgroups were chosen at random and partitioned by newsgroup name. The list of newsgroups from which the messages were chose is as follows:
alt.atheism
talk.politics.guns
talk.politics.mideast
talk.politics.misc
talk.religion.misc
soc.religion.christian
comp.sys.ibm.pc.hardware
comp.graphics
comp.os.ms-windows.misc
comp.sys.mac.hardware
comp.windows.x
rec.autos
rec.motorcycles
rec.sport.baseball
rec.sport.hockey
sci.crypt
sci.electronics
sci.space
sci.med
misc.forsale
We use labels to divide the messages into 20 different categories. Each newsgroup is mapped to a numeric label.
Required modules
import numpy as np
import os
import sys
import random
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils.np_utils import to_categorical
from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense, Activation
2.2 Data preprocessing
This part sets the training-related parameters and reads in the pretrained GloVe word vector file. The texts are read into a list, with each document stored as one str, giving a [str].
BASE_DIR = '/home/lich/Workspace/Learning'
GLOVE_DIR = BASE_DIR + '/glove.6B/'
TEXT_DATA_DIR = BASE_DIR + '/20_newsgroup/'
MAX_SEQUENCE_LENGTH = 1000
MAX_NB_WORDS = 20000
EMBEDDING_DIM = 100
VALIDATION_SPLIT = 0.2
batch_size = 32
# first, build index mapping words in the embeddings set
# to their embedding vector
embeddings_index = {}
f = open(os.path.join(GLOVE_DIR, 'glove.6B.100d.txt'))
for line in f:
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()
print('Found %s word vectors.' % len(embeddings_index))
#Found 400000 word vectors.
# second, prepare text samples and their labels
print('Processing text dataset')
texts = [] # list of text samples
labels_index = {} # dictionary mapping label name to numeric id
labels = [] # list of label ids
for name in sorted(os.listdir(TEXT_DATA_DIR)):
    path = os.path.join(TEXT_DATA_DIR, name)
    if os.path.isdir(path):
        label_id = len(labels_index)
        labels_index[name] = label_id
        for fname in sorted(os.listdir(path)):
            if fname.isdigit():
                fpath = os.path.join(path, fname)
                if sys.version_info < (3,):
                    f = open(fpath)
                else:
                    f = open(fpath, encoding='latin-1')
                texts.append(f.read())
                f.close()
                labels.append(label_id)
print('Found %s texts.' % len(texts))
#Found 19997 texts.
embeddings_index looks like this:
embeddings_index['hi']
"""
array([ 0.1444 , 0.23978999, 0.96692997, 0.31628999, -0.36063999,
-0.87673998, 0.098512 , 0.31077999, 0.47929001, 0.27175 ,
0.30004999, -0.23732001, -0.31516999, 0.17925 , 0.61773002,
0.59820998, 0.49489 , 0.3423 , -0.078034 , 0.60211998,
0.18683 , 0.52069998, -0.12331 , 0.48313001, -0.24117 ,
0.59696001, 0.61078 , -0.84413999, 0.27660999, 0.068767 ,
-1.13880002, 0.089544 , 0.89841998, 0.53788 , 0.10841 ,
-0.10038 , 0.12921 , 0.11476 , -0.47400001, -0.80489999,
0.95999998, -0.36601999, -0.43019 , -0.39807999, -0.096782 ,
-0.71183997, -0.31494001, 0.82345998, 0.42179 , -0.69204998,
-1.48640001, 0.29497999, -0.30875 , -0.49994999, -0.46489999,
-0.44523999, 0.81059998, 1.47570002, 0.53781998, -0.28270999,
-0.045796 , 0.14454 , -0.74484998, 0.35495001, -0.40961 ,
0.35778999, 0.40061 , 0.37338999, 0.72162998, 0.40812999,
0.26155001, -0.14239 , -0.020514 , -1.11059999, -0.47670001,
0.37832001, 0.89612001, -0.17323001, -0.50137001, 0.22991 ,
1.53240001, -0.82032001, -0.10096 , 0.45201999, -0.88638997,
0.089056 , -0.19347 , -0.42253 , 0.022429 , 0.29444 ,
0.020747 , 0.48934999, 0.35991001, 0.092758 , -0.22428 ,
0.60038 , -0.31850001, -0.72424001, -0.22632 , -0.030972 ], dtype=float32)
"""
embeddings_index['hi'].shape
# (100,)
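To make the file layout behind embeddings_index concrete, here is a minimal sketch that parses two made-up lines in the same "word followed by coefficients" layout. The sample words, the values, the 3-dimensional vectors and the name toy_embeddings are illustrative assumptions, not data from the real glove.6B.100d.txt file.

```python
import io

import numpy as np

# Two made-up lines in the GloVe text layout: a word followed by its
# coefficients (3 dimensions here instead of 100, to keep it short).
sample_file = io.StringIO(
    "hi 0.1 0.2 0.3\n"
    "he -0.5 0.0 0.25\n"
)

toy_embeddings = {}
for line in sample_file:
    values = line.split()
    # First field is the word, the rest are the vector components.
    toy_embeddings[values[0]] = np.asarray(values[1:], dtype="float32")

hi_vector = toy_embeddings["hi"]            # array of shape (3,)
missing = toy_embeddings.get("zzzunknown")  # None: the word is not in the file
```

Using .get rather than indexing means a word that is absent from the file yields None instead of raising KeyError, which is what the embedding matrix code later in this post relies on.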
labels_index corresponds one-to-one with the 20 categories of 20_newsgroup
labels_index['alt.atheism']
#0
labels_index['comp.sys.ibm.pc.hardware']
#3
labels[:10]
#[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
labels[1000:1010]
#[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
labels[2000:2010]
#[2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
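The label ids above follow from walking the directory names in sorted order and assigning the next free integer. A minimal sketch of that mapping, using a hypothetical list of four names instead of a real directory listing (toy_dirs, toy_labels_index and id_to_name are names introduced here for illustration):

```python
# Hypothetical directory names, deliberately given out of order.
toy_dirs = ["comp.graphics", "alt.atheism",
            "comp.sys.ibm.pc.hardware", "comp.os.ms-windows.misc"]

toy_labels_index = {}
for name in sorted(toy_dirs):
    # The next id is simply the number of labels assigned so far.
    toy_labels_index[name] = len(toy_labels_index)

# Inverse mapping, handy for turning a predicted id back into a name.
id_to_name = {label_id: name for name, label_id in toy_labels_index.items()}
```

Because the ids depend only on the sorted order of the names, the mapping is reproducible across runs as long as the set of directories does not change.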
Open one of the texts to take a look
len(texts[2])
#4550
texts[2]
"""
Organization: Technical University Braunschweig, Germany
References: <16BA1E197.I3150101@dbstu1.rz.tu-bs.de> <65974@mimsy.umd.edu>
Date: Mon, 5 Apr 1993 19:08:25 GMT
Lines: 93
In article <65974@mimsy.umd.edu>
mangoe@cs.umd.edu (Charley Wingate) writes:
Well, John has a quite different, not necessarily more elaborated theology.
There is some evidence that he must have known Luke, and that the content
of Q was known to him, but not in a 'canonized' form.
This is a new argument to me. Could you elaborate a little?
The argument goes as follows: Q-oid quotes appear in John, but not in
the almost codified way they were in Matthew or Luke. However, they are
considered to be similar enough to point to knowledge of Q as such, and
not an entirely different source.
Assuming that he knew Luke would obviously put him after Luke, and would
give evidence for the latter assumption.
I don't think this follows. If you take the most traditional attributions,
then Luke might have known John, but John is an elder figure in either case.
We're talking spans of time here which are well within the range of
lifetimes.
We are talking date of texts here, not the age of the authors. The usual
explanation for the time order of Mark, Matthew and Luke does not consider
their respective ages. It says Matthew has read the text of Mark, and Luke
that of Matthew (and probably that of Mark).
As it is assumed that John knew the content of Luke's text. The evidence
for that is not overwhelming, admittedly.
(1) Earlier manuscripts of John have been discovered.
Interesting, where and which? How are they dated? How old are they?
Unfortunately, I haven't got the info at hand. It was (I think) in the late
'70s or early '80s, and it was possibly as old as CE 200.
When they are from about 200, why do they shed doubt on the order on
putting John after the rest of the three?
I don't see your point, it is exactly what James Felder said. They had no
first hand knowledge of the events, and it obvious that at least two of them
used older texts as the base of their account. And even the association of
Luke to Paul or Mark to Peter are not generally accepted.
Well, a genuine letter of Peter would be close enough, wouldn't it?
Sure, an original together with Id card of sender and receiver would be
fine. So what's that supposed to say? Am I missing something?
And I don't think a "one step removed" source is that bad. If Luke and Mark
and Matthew learned their stories directly from diciples, then I really
cannot believe in the sort of "big transformation from Jesus to gospel" that
some people posit. In news reports, one generally gets no better
information than this.
And if John IS a diciple, then there's nothing more to be said.
That John was a disciple is not generally accepted. The style and language
together with the theology are usually used as counterargument.
The argument that John was a disciple relies on the claim in the gospel
of John itself. Is there any other evidence for it?
One step and one generation removed is bad even in our times. Compare that
to reports of similar events in our century in almost illiterate societies.
Not even to speak off that believers are not necessarily the best sources.
It is also obvious that Mark has been edited. How old are the oldest
manuscripts? To my knowledge (which can be antiquated) the oldest is
quite after any of these estimates, and it is not even complete.
The only clear "editing" is problem of the ending, and it's basically a
hopeless mess. The oldest versions give a strong sense of incompleteness,
to the point where the shortest versions seem to break off in midsentence.
The most obvious solution is that at some point part of the text was lost.
The material from verse 9 on is pretty clearly later and seems to represent
a synopsys of the end of Luke.
In other words, one does not know what the original of Mark did look like
and arguments based on Mark are pretty weak.
But how is that connected to a redating of John?
Benedikt
"""
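2.3 Tokenize
Tokenize all the texts: each str in texts is first split into tokens, and each token is then mapped to its index. Here is an example (it only illustrates the form; the index values are not real):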
“he is a professor”
becomes:
[143, 12, 1, 23]
# finally, vectorize the text samples into a 2D integer tensor
tokenizer = Tokenizer(nb_words=MAX_NB_WORDS)
tokenizer.fit_on_texts(texts)
sequences = tokenizer.texts_to_sequences(texts)
word_index = tokenizer.word_index
print('Found %s unique tokens.' % len(word_index))
#Found 214909 unique tokens.
The code above converts all the words into numbers
word_index['newsgroups']
# 43
sequences[2][:20]
"""
[43,
127,
357,
44,
29,
24,
16,
12,
2,
160,
24,
16,
12,
2,
195,
185,
12,
2,
182,
144]
"""
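The idea behind the word-to-index step can be sketched in plain Python. This is a simplified stand-in, not the keras Tokenizer itself: it assumes lowercasing and whitespace splitting only (no punctuation filtering), ranks words by frequency so that index 1 is the most frequent word, and breaks ties alphabetically. The function names and toy sentences are illustrative.

```python
from collections import Counter


def fit_word_index(texts):
    """Rank words by descending frequency; the most frequent word gets index 1."""
    counts = Counter(word for text in texts for word in text.lower().split())
    # Ties are broken alphabetically so the result is deterministic.
    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    return {word: rank for rank, (word, _) in enumerate(ranked, start=1)}


def texts_to_sequences(texts, word_index, num_words=None):
    """Replace each known word by its index, dropping indices >= num_words."""
    sequences = []
    for text in texts:
        seq = []
        for word in text.lower().split():
            idx = word_index.get(word)
            if idx is None:
                continue  # unseen word
            if num_words is not None and idx >= num_words:
                continue  # too rare to keep
            seq.append(idx)
        sequences.append(seq)
    return sequences


toy_texts = ["he is a professor", "she is a student", "a professor is a teacher"]
toy_index = fit_word_index(toy_texts)
toy_sequences = texts_to_sequences(toy_texts, toy_index)
toy_truncated = texts_to_sequences(toy_texts, toy_index, num_words=4)
```

In this toy corpus "a" occurs four times and therefore gets index 1, so the first sentence becomes [4, 2, 1, 3]; with num_words=4 the rare word "he" is dropped and it becomes [2, 1, 3].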
2.4 Generating the Train and Validate datasets
Use np.random.shuffle to randomly split the dataset and generate the corresponding training and validation sets.
data = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)
labels = to_categorical(np.asarray(labels))
print('Shape of data tensor:', data.shape)
print('Shape of label tensor:', labels.shape)
# ('Shape of data tensor:', (19997, 1000))
# ('Shape of label tensor:', (19997, 20))
# split the data into a training set and a validation set
indices = np.arange(data.shape[0])
np.random.shuffle(indices)
data = data[indices]
labels = labels[indices]
nb_validation_samples = int(VALIDATION_SPLIT * data.shape[0])
x_train = data[:-nb_validation_samples]
y_train = labels[:-nb_validation_samples]
x_train.shape
#(15998, 1000)
y_train.shape
#(15998, 20)
x_val = data[-nb_validation_samples:]
y_val = labels[-nb_validation_samples:]
print('Preparing embedding matrix.')
Each row of data is an array of length 1000; sequences that are too short are padded with 0 at the front.
labels has been converted to one-hot encoding.
len(data[2])
#1000
data[2]
"""
array([ 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
...
...
93, 6, 1818, 480, 19, 471, 25, 668, 2797,
35, 111, 9, 10, 2425, 3, 5, 4, 370, 5271], dtype=int32)
"""
labels[0]
"""
array([ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.])
"""
labels[1000]
"""
array([ 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.])
"""
labels[2000]
"""
array([ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.])
"""
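What padding and one-hot encoding do to the arrays can be sketched with numpy alone. The helpers below are simplified stand-ins written for this illustration (pad_pre and one_hot are hypothetical names): zeros are added at the front and over-long sequences keep their last maxlen items, which matches the zero-prefixed row printed above.

```python
import numpy as np


def pad_pre(sequences, maxlen):
    """Left-pad with zeros; sequences longer than maxlen keep their last items."""
    out = np.zeros((len(sequences), maxlen), dtype="int32")
    for row, seq in enumerate(sequences):
        if not seq:
            continue  # an empty sequence stays all zeros
        trimmed = seq[-maxlen:]
        out[row, -len(trimmed):] = trimmed
    return out


def one_hot(label_ids, num_classes):
    """Turn integer class ids into rows with a single 1."""
    out = np.zeros((len(label_ids), num_classes), dtype="float32")
    out[np.arange(len(label_ids)), label_ids] = 1.0
    return out


toy_data = pad_pre([[1, 2], [3, 4, 5, 6, 7]], maxlen=4)
toy_onehot = one_hot([0, 2], num_classes=3)
```

Here toy_data has shape (2, 4): the short sequence becomes [0, 0, 1, 2] and the long one is cut to [4, 5, 6, 7].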
2.5 Generating the Embedding Matrix
Compare the most frequent tokens (those ranked within MAX_NB_WORDS) against the dictionary built from GloVe, to obtain the word vector for each word that appears in the training set.
nb_words = min(MAX_NB_WORDS, len(word_index))
#20000
embedding_matrix = np.zeros((nb_words + 1, EMBEDDING_DIM))
for word, i in word_index.items():
    if i > MAX_NB_WORDS:
        continue
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        # words not found in embedding index will be all-zeros.
        embedding_matrix[i] = embedding_vector
print(embedding_matrix.shape)
#(20001, 100)
embedding_matrix and embeddings_index look like this:
embedding_matrix[76]
'''
array([ 0.1225 , -0.058833 , 0.23658 , -0.28876999, -0.028181 ,
0.31524 , 0.070229 , 0.16447 , -0.027623 , 0.25213999,
0.21174 , -0.059674 , 0.36133 , 0.13607 , 0.18754999,
-0.1487 , 0.31314999, 0.13368 , -0.59702998, -0.030161 ,
0.080656 , 0.26161999, -0.055924 , -0.35350999, 0.34722 ,
-0.0055801 , -0.57934999, -0.88006997, 0.42930999, -0.15695 ,
-0.51256001, 1.26839995, -0.25228 , 0.35264999, -0.46419001,
0.55647999, -0.57555997, 0.32574001, -0.21893001, -0.13178 ,
-1.1027 , -0.039591 , 0.89643002, -0.98449999, -0.47393 ,
-0.12854999, 0.63506001, -0.94888002, 0.40088001, -0.77542001,
-0.35152999, -0.27788001, 0.68747002, 1.45799994, -0.38474 ,
-2.89369988, -0.29523 , -0.38835999, 0.94880998, 1.38909996,
0.054591 , 0.70485997, -0.65698999, 0.075648 , 0.76550001,
-0.63365 , 0.86556 , 0.42440999, 0.14796001, 0.4156 ,
0.29354 , -0.51295 , 0.19634999, -0.45568001, 0.0080246 ,
0.14528 , -0.15395001, 0.11406 , -1.21669996, -0.1111 ,
0.82639998, 0.21738 , -0.63775998, -0.074874 , -1.71300006,
-0.88270003, -0.0073058 , -0.37623 , -0.50208998, -0.58844 ,
-0.24943 , -1.04250002, 0.27678001, 0.64142001, -0.64604998,
0.43559 , -0.37276 , -0.0032068 , 0.18743999, 0.30702001])
'''
embeddings_index.get('he')
'''
array([ 0.1225 , -0.058833 , 0.23658 , -0.28876999, -0.028181 ,
0.31524 , 0.070229 , 0.16447 , -0.027623 , 0.25213999,
0.21174 , -0.059674 , 0.36133 , 0.13607 , 0.18754999,
-0.1487 , 0.31314999, 0.13368 , -0.59702998, -0.030161 ,
0.080656 , 0.26161999, -0.055924 , -0.35350999, 0.34722 ,
-0.0055801 , -0.57934999, -0.88006997, 0.42930999, -0.15695 ,
-0.51256001, 1.26839995, -0.25228 , 0.35264999, -0.46419001,
0.55647999, -0.57555997, 0.32574001, -0.21893001, -0.13178 ,
-1.1027 , -0.039591 , 0.89643002, -0.98449999, -0.47393 ,
-0.12854999, 0.63506001, -0.94888002, 0.40088001, -0.77542001,
-0.35152999, -0.27788001, 0.68747002, 1.45799994, -0.38474 ,
-2.89369988, -0.29523 , -0.38835999, 0.94880998, 1.38909996,
0.054591 , 0.70485997, -0.65698999, 0.075648 , 0.76550001,
-0.63365 , 0.86556 , 0.42440999, 0.14796001, 0.4156 ,
0.29354 , -0.51295 , 0.19634999, -0.45568001, 0.0080246 ,
0.14528 , -0.15395001, 0.11406 , -1.21669996, -0.1111 ,
0.82639998, 0.21738 , -0.63775998, -0.074874 , -1.71300006,
-0.88270003, -0.0073058 , -0.37623 , -0.50208998, -0.58844 ,
-0.24943 , -1.04250002, 0.27678001, 0.64142001, -0.64604998,
0.43559 , -0.37276 , -0.0032068 , 0.18743999, 0.30702001], dtype=float32)
'''
embeddings_index.get('he') == embedding_matrix[76]
'''
array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True], dtype=bool)
'''
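The same construction on a toy vocabulary shows the three cases in one place: a word with a pretrained vector, a word without one (its row stays all zeros), and a word whose index exceeds the limit (it gets no row at all). All names and values below are made up for illustration, and np.array_equal is used as a compact alternative to the element-wise == comparison above.

```python
import numpy as np

# Hypothetical word index (1 = most frequent) and 2-dimensional vectors.
toy_word_index = {"he": 1, "qwertyzz": 2, "professor": 3, "rare": 4}
toy_vectors = {
    "he": np.array([0.1, 0.2], dtype="float32"),
    "professor": np.array([0.5, 0.6], dtype="float32"),
    "rare": np.array([0.7, 0.8], dtype="float32"),
}
max_words = 3  # keep only indices 1..3
dim = 2

# Row 0 is reserved for padding, hence the +1.
toy_matrix = np.zeros((min(max_words, len(toy_word_index)) + 1, dim))
for word, i in toy_word_index.items():
    if i > max_words:
        continue  # too rare: no row in the matrix
    vec = toy_vectors.get(word)
    if vec is not None:
        toy_matrix[i] = vec  # words without a vector stay all zeros

he_row_matches = np.array_equal(toy_matrix[1], toy_vectors["he"])
unknown_row_is_zero = not toy_matrix[2].any()
```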
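2.6 LSTM training
Note that the shape of the training data is (N_SAMPLES, MAX_SEQUENCE_LENGTH) and that 100 is the word vector length; the Embedding layer then turns the data into a 3D matrix of shape (N_SAMPLES, MAX_SEQUENCE_LENGTH, 100).
If Word Embedding is unclear, see "Using pre-trained word embeddings in a Keras model"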
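The 2D-to-3D step can be reproduced with plain numpy indexing, which is a reasonable mental model of what an embedding lookup does (it is a sketch of the idea, not of the keras layer's internals). The sizes and random weights below are arbitrary assumptions.

```python
import numpy as np

vocab_size, dim = 7, 3
rng = np.random.RandomState(0)
toy_embedding = rng.rand(vocab_size, dim)  # one row per word index

# Two padded sequences of length 5 (0 is the padding index).
toy_batch = np.array([[0, 0, 1, 4, 2],
                      [0, 3, 3, 6, 5]])

# Indexing with an integer array replaces every index by its row,
# so a (2, 5) batch becomes a (2, 5, 3) tensor.
embedded = toy_embedding[toy_batch]
```

Each position in the sequence now carries a vector, for example embedded[0, 3] is the row of toy_embedding for word index 4.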
Because of a keras version issue, running the original article's code raises an error, so this post changes it as described here. In:
embedding_layer = Embedding(nb_words + 1,
EMBEDDING_DIM,
weights=[embedding_matrix],
input_length=MAX_SEQUENCE_LENGTH,
trainable=False,
dropout=0.2)
remove trainable=False, and add model.layers[1].trainable=False afterwards:
embedding_layer = Embedding(nb_words + 1,
EMBEDDING_DIM,
weights=[embedding_matrix],
input_length=MAX_SEQUENCE_LENGTH,
dropout=0.2)
print('Build model...')
# sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
# embedded_sequences = embedding_layer()
model = Sequential()
model.add(embedding_layer)
model.add(LSTM(100, dropout_W=0.2, dropout_U=0.2)) # try using a GRU instead, for fun
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.add(Dense(len(labels_index), activation='softmax'))
model.layers[1].trainable=False
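The network model looks like this. Note that in this Sequential model layers[0] is the Embedding layer and layers[1] is the LSTM, so the summary reports the LSTM's 80,400 parameters as non-trainable while the embedding weights remain trainable; freezing the pretrained embeddings would take model.layers[0].trainable=False instead.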
model.summary()
"""
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
embedding_1 (Embedding) (None, 1000, 100) 2000100 embedding_input_1[0][0]
____________________________________________________________________________________________________
lstm_1 (LSTM) (None, 100) 80400 embedding_1[0][0]
____________________________________________________________________________________________________
dense_1 (Dense) (None, 1) 101 lstm_1[0][0]
____________________________________________________________________________________________________
activation_1 (Activation) (None, 1) 0 dense_1[0][0]
____________________________________________________________________________________________________
dense_2 (Dense) (None, 20) 40 activation_1[0][0]
====================================================================================================
Total params: 2,080,641
Trainable params: 2,000,241
Non-trainable params: 80,400
____________________________________________________________________________________________________
"""
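The parameter counts in the summary can be checked by hand. The sketch below uses the standard counting formulas (an LSTM has four gates, each with input weights, recurrent weights and a bias); the helper names are introduced here for illustration.

```python
def embedding_params(vocab_rows, dim):
    # One dim-sized vector per row of the embedding matrix.
    return vocab_rows * dim


def lstm_params(input_dim, units):
    # Four gates, each with input weights, recurrent weights and a bias.
    return 4 * (input_dim * units + units * units + units)


def dense_params(input_dim, units):
    # Weight matrix plus one bias per output unit.
    return input_dim * units + units


counts = {
    "embedding_1": embedding_params(20001, 100),
    "lstm_1": lstm_params(100, 100),
    "dense_1": dense_params(100, 1),
    "dense_2": dense_params(1, 20),
}
total = sum(counts.values())
```

These come out to 2,000,100, 80,400, 101 and 40, for a total of 2,080,641, matching the table.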
# try using different optimizers and different optimizer configs
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
print('Train...')
model.fit(x_train, y_train, batch_size=batch_size, nb_epoch=5,
validation_data=(x_val, y_val))
score, acc = model.evaluate(x_val, y_val,
batch_size=batch_size)
print('Test score:', score)
print('Test accuracy:', acc)
"""
Train on 15998 samples, validate on 3999 samples
Epoch 1/5
608/15998 [>.............................] - ETA: 833s - loss: 0.1992 - acc: 0.9500
"""
I didn't bother training any further; as you can see, my lousy computer is too slow.
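For reference when reading loss values from a 20-class run, here is a minimal numpy sketch of softmax followed by categorical cross-entropy on one-hot labels. It is a hand-written illustration of the loss named in model.compile, not the keras implementation, and the clipping constant eps is an assumption.

```python
import numpy as np


def softmax(logits):
    # Subtract the row maximum for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Mean over samples of -sum(one_hot * log(prob)).
    y_pred = np.clip(y_pred, eps, 1.0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))


num_classes = 20
y_true = np.eye(num_classes)[[0, 1]]           # two one-hot labels
uniform = softmax(np.zeros((2, num_classes)))  # every class gets 1/20
chance_loss = categorical_crossentropy(y_true, uniform)
```

A model that predicts every class with equal probability scores ln(20), roughly 3.0, which gives a baseline for judging whether training is making progress.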
References
[1] Understanding LSTM: http://colah.github.io/posts/2015-08-Understanding-LSTMs/
[2] Understanding LSTM Networks (in Chinese): https://www.yunaitong.cn/understanding-lstm-networks.html
[3] GloVe: Global Vectors for Word Representation: http://nlp.stanford.edu/projects/glove