TensorFlow Basic Text Classification

Import the libraries

import tensorflow as tf
from tensorflow import keras

import numpy as np

print(tf.__version__)
1.10.0

Import the data

Load the dataset. As usual, because of local network restrictions, download the file manually first and then load it from a local path.

imdb = keras.datasets.imdb

(train_data, train_labels), (test_data, test_labels) = imdb.load_data(path = 'H:/tf_project/imdb.npz',num_words=10000)

Data in .npz format can also be loaded directly with np.load(); the result is a dict-like object, which can be converted to a plain dictionary with dict().

Data in .npy format, once loaded with np.load(), is returned directly as an array.
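
As a minimal sketch (assuming the same local path as above, with numpy already imported as np), the downloaded file can be inspected directly:

# An .npz archive behaves like a read-only dict of arrays
with np.load('H:/tf_project/imdb.npz') as npz:
    print(npz.files)        # e.g. ['x_test', 'x_train', 'y_train', 'y_test']
    data = dict(npz)        # materialize it as an ordinary dict of arrays
# (newer NumPy versions may require allow_pickle=True for this dataset)

# A .npy file, by contrast, loads straight into a single ndarray:
# arr = np.load('some_array.npy')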

print("Training entries: {}, labels: {}".format(len(train_data), len(train_labels)))
Training entries: 25000, labels: 25000
print(train_data[0])
[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32]
len(train_data[0]), len(train_data[1])
(218, 189)

Converting the integer arrays back to words

# A dictionary mapping words to an integer index
# Download it directly from:
# https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb_word_index.json
word_index = imdb.get_word_index(r"H:\tf_project\imdb_word_index.json")
word_index
{'fawn': 34701,
 'tsukino': 52006,
 'nunnery': 52007,
 'sonja': 16816,
 'vani': 63951,
 'woods': 1408,
 ...}
# 所有word的編碼往后移三位
word_index = {k:(v+3) for k,v in word_index.items()} 
# Add the special tokens
word_index["<PAD>"] = 0
word_index["<START>"] = 1
word_index["<UNK>"] = 2  # unknown
word_index["<UNUSED>"] = 3

# Reverse the mapping: integer index -> word
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])

def decode_review(text):
    return ' '.join([reverse_word_index.get(i, '?') for i in text])

Decode a training example back into text

decode_review(train_data[0])
"<START> this film was just brilliant casting location scenery story direction everyone's really suited the part they played and you could just imagine being there robert <UNK> is an amazing actor and now the same being director <UNK> father came from the same scottish island as myself so i loved the fact there was a real connection with this film the witty remarks throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for <UNK> and would recommend it to everyone to watch and the fly fishing was amazing really cried at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also <UNK> to the two little boy's that played the <UNK> of norman and paul they were just brilliant children are often left out of the <UNK> list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all"

The reviews, stored as integer arrays, must be converted to tensors before they can be fed into the network. There are two common ways to do this:

  1. One-hot (multi-hot) encode them into vectors containing only 0s and 1s. For example, the sequence [3, 5] becomes a 10,000-dimensional vector that is all zeros except at indices 3 and 5, which are ones. This approach is memory-intensive (see the sketch after this list).

  2. Pad the arrays so they all have the same length, and feed them into the network as-is.
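
For illustration, here is a minimal sketch of the first (multi-hot) approach; the helper name is hypothetical, and this tutorial actually uses the second approach, padding, as shown next:

def multi_hot_encode(sequences, dimension=10000):
    # One row per review, one column per word index
    results = np.zeros((len(sequences), dimension))
    for i, seq in enumerate(sequences):
        results[i, seq] = 1.0   # set the positions of the words that occur to 1
    return results

# multi_hot_encode([[3, 5]])[0] is all zeros except at indices 3 and 5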

train_data = keras.preprocessing.sequence.pad_sequences(train_data,
                                                        value=word_index["<PAD>"],
                                                        padding='post',
                                                        maxlen=256)
test_data = keras.preprocessing.sequence.pad_sequences(test_data,
                                                       value=word_index["<PAD>"],
                                                       padding='post',
                                                       maxlen=256)
len(train_data[0]), len(train_data[1])
(256, 256)
print(train_data[0])
[   1   14   22   16   43  530  973 1622 1385   65  458 4468   66 3941    4
  173   36  256    5   25  100   43  838  112   50  670    2    9   35  480
  284    5  150    4  172  112  167    2  336  385   39    4  172 4536 1111
   17  546   38   13  447    4  192   50   16    6  147 2025   19   14   22
    4 1920 4613  469    4   22   71   87   12   16   43  530   38   76   15
   13 1247    4   22   17  515   17   12   16  626   18    2    5   62  386
   12    8  316    8  106    5    4 2223 5244   16  480   66 3785   33    4
  130   12   16   38  619    5   25  124   51   36  135   48   25 1415   33
    6   22   12  215   28   77   52    5   14  407   16   82    2    8    4
  107  117 5952   15  256    4    2    7 3766    5  723   36   71   43  530
  476   26  400  317   46    7    4    2 1029   13  104   88    4  381   15
  297   98   32 2071   56   26  141    6  194 7486   18    4  226   22   21
  134  476   26  480    5  144   30 5535   18   51   36   28  224   92   25
  104    4  226   65   16   38 1334   88   12   16  283    5   16 4472  113
  103   32   15   16 5345   19  178   32    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0]

Build the model

vocab_size = 10000

model = keras.Sequential()
model.add(keras.layers.Embedding(vocab_size, 16))
model.add(keras.layers.GlobalAveragePooling1D())
model.add(keras.layers.Dense(16, activation=tf.nn.relu))
model.add(keras.layers.Dense(1, activation=tf.nn.sigmoid))

model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, None, 16)          160000    
_________________________________________________________________
global_average_pooling1d (Gl (None, 16)                0         
_________________________________________________________________
dense (Dense)                (None, 16)                272       
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 160,289
Trainable params: 160,289
Non-trainable params: 0
_________________________________________________________________
  1. The first layer is an Embedding layer, which looks up a 16-dimensional vector for each word index.

  2. The second layer is a GlobalAveragePooling1D layer, which averages over the sequence dimension so every review yields a fixed-length vector (see the sketch after this list).

  3. The third and fourth layers are fully connected (Dense) layers.

  4. The output layer has a single node with a sigmoid activation, which squashes the result into the 0-1 range.
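
As a toy illustration (a sketch, not part of the original run) of what GlobalAveragePooling1D computes: it simply averages the embedding vectors across the sequence axis, so reviews of different lengths all map to a vector of the embedding size.

# Toy "embedded" batch: 1 example, 3 time steps, embedding dimension 2
x = np.array([[[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]]])
print(x.mean(axis=1))   # [[3. 4.]] -- one fixed-length vector per example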

Loss function and optimizer

model.compile(optimizer=tf.train.AdamOptimizer(),
              loss='binary_crossentropy',
              metrics=['accuracy'])

Create a validation set

x_val = train_data[:10000]
partial_x_train = train_data[10000:]

y_val = train_labels[:10000]
partial_y_train = train_labels[10000:]

Train the model

Here history is the return value of fit(); it records how the loss and metrics evolved during training.

history = model.fit(partial_x_train,
                    partial_y_train,
                    epochs=40,
                    batch_size=512,
                    validation_data=(x_val, y_val),
                    verbose=1)
Train on 15000 samples, validate on 10000 samples
Epoch 1/40
15000/15000 [==============================] - 4s 249us/step - loss: 0.7391 - acc: 0.5035 - val_loss: 0.7010 - val_acc: 0.4947
Epoch 2/40
15000/15000 [==============================] - 1s 52us/step - loss: 0.6931 - acc: 0.5251 - val_loss: 0.6912 - val_acc: 0.5338
Epoch 3/40
15000/15000 [==============================] - 1s 51us/step - loss: 0.6903 - acc: 0.5801 - val_loss: 0.6897 - val_acc: 0.5656
Epoch 4/40
15000/15000 [==============================] - 1s 50us/step - loss: 0.6884 - acc: 0.6543 - val_loss: 0.6879 - val_acc: 0.6747
Epoch 5/40
15000/15000 [==============================] - 1s 48us/step - loss: 0.6864 - acc: 0.6421 - val_loss: 0.6860 - val_acc: 0.7004
Epoch 6/40
15000/15000 [==============================] - 1s 51us/step - loss: 0.6841 - acc: 0.7283 - val_loss: 0.6837 - val_acc: 0.7259
Epoch 7/40
15000/15000 [==============================] - 1s 53us/step - loss: 0.6810 - acc: 0.7203 - val_loss: 0.6805 - val_acc: 0.6978
Epoch 8/40
15000/15000 [==============================] - 1s 53us/step - loss: 0.6769 - acc: 0.7057 - val_loss: 0.6759 - val_acc: 0.6885
Epoch 9/40
15000/15000 [==============================] - 1s 51us/step - loss: 0.6707 - acc: 0.7150 - val_loss: 0.6695 - val_acc: 0.7142
Epoch 10/40
15000/15000 [==============================] - 1s 56us/step - loss: 0.6628 - acc: 0.7443 - val_loss: 0.6610 - val_acc: 0.7356
Epoch 11/40
15000/15000 [==============================] - 1s 51us/step - loss: 0.6529 - acc: 0.7487 - val_loss: 0.6503 - val_acc: 0.7497
Epoch 12/40
15000/15000 [==============================] - 1s 50us/step - loss: 0.6387 - acc: 0.7843 - val_loss: 0.6345 - val_acc: 0.7720
Epoch 13/40
15000/15000 [==============================] - 1s 56us/step - loss: 0.6182 - acc: 0.7861 - val_loss: 0.6157 - val_acc: 0.7727
Epoch 14/40
15000/15000 [==============================] - 1s 50us/step - loss: 0.5933 - acc: 0.7986 - val_loss: 0.5889 - val_acc: 0.7900
Epoch 15/40
15000/15000 [==============================] - 1s 49us/step - loss: 0.5614 - acc: 0.8103 - val_loss: 0.5584 - val_acc: 0.7956
Epoch 16/40
15000/15000 [==============================] - 1s 49us/step - loss: 0.5295 - acc: 0.8157 - val_loss: 0.5293 - val_acc: 0.8052
Epoch 17/40
15000/15000 [==============================] - 1s 50us/step - loss: 0.4963 - acc: 0.8327 - val_loss: 0.5008 - val_acc: 0.8192
Epoch 18/40
15000/15000 [==============================] - 1s 52us/step - loss: 0.4647 - acc: 0.8423 - val_loss: 0.4726 - val_acc: 0.8273
Epoch 19/40
15000/15000 [==============================] - 1s 49us/step - loss: 0.4349 - acc: 0.8519 - val_loss: 0.4471 - val_acc: 0.8363
Epoch 20/40
15000/15000 [==============================] - 1s 50us/step - loss: 0.4076 - acc: 0.8607 - val_loss: 0.4243 - val_acc: 0.8434
Epoch 21/40
15000/15000 [==============================] - 1s 49us/step - loss: 0.3829 - acc: 0.8707 - val_loss: 0.4043 - val_acc: 0.8489
Epoch 22/40
15000/15000 [==============================] - 1s 50us/step - loss: 0.3612 - acc: 0.8773 - val_loss: 0.3872 - val_acc: 0.8547
Epoch 23/40
15000/15000 [==============================] - 1s 49us/step - loss: 0.3424 - acc: 0.8833 - val_loss: 0.3729 - val_acc: 0.8587
Epoch 24/40
15000/15000 [==============================] - 1s 49us/step - loss: 0.3256 - acc: 0.8885 - val_loss: 0.3605 - val_acc: 0.8643
Epoch 25/40
15000/15000 [==============================] - 1s 48us/step - loss: 0.3111 - acc: 0.8935 - val_loss: 0.3500 - val_acc: 0.8673
Epoch 26/40
15000/15000 [==============================] - 1s 49us/step - loss: 0.2980 - acc: 0.8960 - val_loss: 0.3415 - val_acc: 0.8698
Epoch 27/40
15000/15000 [==============================] - 1s 50us/step - loss: 0.2868 - acc: 0.8989 - val_loss: 0.3338 - val_acc: 0.8711
Epoch 28/40
15000/15000 [==============================] - 1s 49us/step - loss: 0.2758 - acc: 0.9039 - val_loss: 0.3268 - val_acc: 0.8746
Epoch 29/40
15000/15000 [==============================] - 1s 49us/step - loss: 0.2666 - acc: 0.9058 - val_loss: 0.3218 - val_acc: 0.8751
Epoch 30/40
15000/15000 [==============================] - 1s 53us/step - loss: 0.2588 - acc: 0.9079 - val_loss: 0.3164 - val_acc: 0.8768
Epoch 31/40
15000/15000 [==============================] - 1s 51us/step - loss: 0.2498 - acc: 0.9125 - val_loss: 0.3124 - val_acc: 0.8769
Epoch 32/40
15000/15000 [==============================] - 1s 50us/step - loss: 0.2431 - acc: 0.9135 - val_loss: 0.3086 - val_acc: 0.8793
Epoch 33/40
15000/15000 [==============================] - 1s 50us/step - loss: 0.2352 - acc: 0.9170 - val_loss: 0.3052 - val_acc: 0.8805
Epoch 34/40
15000/15000 [==============================] - 1s 47us/step - loss: 0.2288 - acc: 0.9183 - val_loss: 0.3030 - val_acc: 0.8807
Epoch 35/40
15000/15000 [==============================] - 1s 51us/step - loss: 0.2231 - acc: 0.9195 - val_loss: 0.2998 - val_acc: 0.8802
Epoch 36/40
15000/15000 [==============================] - 1s 51us/step - loss: 0.2166 - acc: 0.9220 - val_loss: 0.2975 - val_acc: 0.8825
Epoch 37/40
15000/15000 [==============================] - 1s 49us/step - loss: 0.2111 - acc: 0.9247 - val_loss: 0.2956 - val_acc: 0.8831
Epoch 38/40
15000/15000 [==============================] - 1s 49us/step - loss: 0.2058 - acc: 0.9259 - val_loss: 0.2940 - val_acc: 0.8834
Epoch 39/40
15000/15000 [==============================] - 1s 49us/step - loss: 0.2003 - acc: 0.9294 - val_loss: 0.2922 - val_acc: 0.8846
Epoch 40/40
15000/15000 [==============================] - 1s 52us/step - loss: 0.1953 - acc: 0.9307 - val_loss: 0.2908 - val_acc: 0.8848

Evaluate the model

results = model.evaluate(test_data, test_labels)

print(results)
25000/25000 [==============================] - 2s 86us/step
[0.3060342230606079, 0.87492000000000003]

Visualize the results

history_dict = history.history
history_dict.keys()
dict_keys(['val_loss', 'val_acc', 'loss', 'acc'])

Loss

import matplotlib.pyplot as plt
%matplotlib inline
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

# "bo" is for "blue dot"
plt.plot(epochs, loss, 'bo', label='Training loss')
# b is for "solid blue line"
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.show()
[Figure: training and validation loss over epochs]

In the plot above, after roughly 20 epochs the validation loss stops decreasing much, and the model begins to overfit the training data.
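
One common remedy (sketched here as an assumption, not part of the original run) is to stop training automatically once val_loss stops improving, using an EarlyStopping callback:

early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)

# history = model.fit(partial_x_train, partial_y_train,
#                     epochs=40, batch_size=512,
#                     validation_data=(x_val, y_val),
#                     callbacks=[early_stop], verbose=1)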

Accuracy

plt.clf()   # clear the previous figure
acc_values = history_dict['acc']
val_acc_values = history_dict['val_acc']

plt.plot(epochs, acc_values, 'bo', label='Training acc')
plt.plot(epochs, val_acc_values, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

plt.show()
[Figure: training and validation accuracy over epochs]