深度學(xué)習(xí)--Lstm+CNN 文本分類

本文從實(shí)踐的角度钮莲,來講一下如何構(gòu)建LSTM+CNN的模型對文本進(jìn)行分類夹抗。

本文Github

RNN網(wǎng)絡(luò)與CNN網(wǎng)絡(luò)可以分別用來進(jìn)行文本分類蒲祈。RNN網(wǎng)絡(luò)在文本分類中呆瞻,作用是用來提取句子的關(guān)鍵語義信息,根據(jù)提取的語義對文本進(jìn)行區(qū)分慨蛙;CNN的作用是用來提取文本的特征辽聊,根據(jù)特征進(jìn)行分類。LSTM+CNN的作用股淡,就是兩者的結(jié)合,首先抽取文本關(guān)鍵語義廷区,然后對語義提取關(guān)鍵特征唯灵。
需要了解CNN基本原理:https://zhuanlan.zhihu.com/p/28173972
需要了解RNN基本原理:http://www.reibang.com/p/32d3048da5ba隙轻。
個(gè)人認(rèn)為基礎(chǔ)知識講解的還不錯的博客埠帕。

數(shù)據(jù)來源

本實(shí)驗(yàn)是使用THUCNews的一個(gè)子集進(jìn)行訓(xùn)練與測試叁巨,數(shù)據(jù)集請自行到THUCTC:一個(gè)高效的中文文本分類工具包下載呐籽,請遵循數(shù)據(jù)提供方的開源協(xié)議;
文本類別涉及10個(gè)類別:categories = ['體育', '財(cái)經(jīng)', '房產(chǎn)', '家居', '教育', '科技', '時(shí)尚', '時(shí)政', '游戲', '娛樂']锋勺,每個(gè)分類6500條數(shù)據(jù);
cnews.train.txt: 訓(xùn)練集(500010)
cnews.val.txt: 驗(yàn)證集(500
10)
cnews.test.txt: 測試集(1000*10)

文本預(yù)處理

本文的預(yù)處理過程與文本分類--CNN大部分相同,其中有兩處不同。
1.在CNN分類中显蝌,文本的長度padding到了600;本次padding到了300。
2.針對動態(tài)RNN的特點(diǎn)神郊,增加計(jì)算每個(gè)batch中句子的真實(shí)長度夕晓。
代碼如下:

def seq_length(x_batch):
    real_seq_len = []
    for line in x_batch:
        real_seq_len.append(np.sum(np.sign(line)))
return real_seq_len

LSTM模型中的處理

定義占位符

        self.input_x = tf.placeholder(tf.int32, shape=[None, pm.seq_length], name='input_x')
        self.input_y = tf.placeholder(tf.float32, shape=[None, pm.num_classes], name='input_y')
        self.length = tf.placeholder(tf.int32, shape=[None], name='rnn_length')
        self.keep_pro = tf.placeholder(tf.float32, name='dropout')
        self.global_step = tf.Variable(0, trainable=False, name='global_step')

embedding層

使用預(yù)訓(xùn)練詞向量析既。

        with tf.device('/cpu:0'), tf.name_scope('embedding'):
            self.embedding = tf.get_variable("embeddings", shape=[pm.vocab_size, pm.embedding_dim],
                                             initializer=tf.constant_initializer(pm.pre_trianing))
            embedding_input = tf.nn.embedding_lookup(self.embedding, self.input_x)

LSTM層

        with tf.name_scope('LSTM'):
            cell = tf.nn.rnn_cell.LSTMCell(pm.hidden_dim, state_is_tuple=True)
            Cell = tf.contrib.rnn.DropoutWrapper(cell, self.keep_pro)
            output, _ = tf.nn.dynamic_rnn(cell=Cell, inputs=embedding_input, sequence_length=self.length, dtype=tf.float32)

以上為LSTM+CNN文本分類中囤屹,LSTM的環(huán)節(jié)智厌。針對動態(tài)RNN的情形,一般來說,只需將每個(gè)batch中的句子padding到等長即可睛约,但為了遷就CNN模型棍丐,所以須將所有句子padding到等長潦匈,計(jì)算batch中句子的真實(shí)長度掂为,是動態(tài)RNN部分需要的,告訴動態(tài)RNN真實(shí)句子是多長,這樣可以將填充的部分輸出為0,不會將額外的信息帶到CNN層中馋评。

CNN層

為了將LSTM輸出的結(jié)果是三維的tensor右核,而我們進(jìn)行conv2d的CNN操作,需要四維tensor镰烧,故第一步是擴(kuò)展維度拢军。CNN環(huán)節(jié)參考文本分類--CNN

        with tf.name_scope('CNN'):
            outputs = tf.expand_dims(outputs, -1) #[batch_size, seq_length, hidden_dim, 1]
            pooled_outputs = []
            for i, filter_size in enumerate(pm.filters_size):
                filter_shape = [filter_size, pm.hidden_dim, 1, pm.num_filters]
                w = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name='w')
                b = tf.Variable(tf.constant(0.1, shape=[pm.num_filters]), name='b')
                conv = tf.nn.conv2d(outputs, w, strides=[1, 1, 1, 1], padding='VALID', name='conv')
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name='relu')

                pooled = tf.nn.max_pool(h, ksize=[1, pm.seq_length-filter_size+1, 1, 1],
                                        strides=[1, 1, 1, 1], padding='VALID', name='pool')
                pooled_outputs.append(pooled)
            output_ = tf.concat(pooled_outputs, 3)
            self.output = tf.reshape(output_, shape=[-1, 3*pm.num_filters])

全連接層

將CNN輸出結(jié)果進(jìn)行dropout與全連接進(jìn)行相連怔鳖。

        with tf.name_scope('output'):
            out_final = tf.nn.dropout(self.output, keep_prob=self.keep_pro)
            o_w = tf.Variable(tf.truncated_normal([3*pm.num_filters, pm.num_classes], stddev=0.1), name='o_w')
            o_b = tf.Variable(tf.constant(0.1, shape=[pm.num_classes]), name='o_b')
            self.logits = tf.matmul(out_final, o_w) + o_b
            self.predict = tf.argmax(tf.nn.softmax(self.logits), 1, name='score')

Loss

這里使用softmax交叉熵求loss, logits=self.scores 這里一定用的是未經(jīng)過softmax處理的數(shù)值茉唉。

        with tf.name_scope('loss'):
            cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
            self.loss = tf.reduce_mean(cross_entropy)

optimizer

這里使用了梯度裁剪。首先計(jì)算梯度结执,這個(gè)計(jì)算是類似L2正則化計(jì)算w的值度陆,也就是求平方再平方根钞支。然后與設(shè)定的clip裁剪值進(jìn)行比較咱娶,如果小于等于clip,梯度不變蚯撩;如果大于clip,則梯度*(clip/梯度L2值)卒稳。

        with tf.name_scope('optimizer'):
            # 退化學(xué)習(xí)率 learning_rate = lr*(0.9**(global_step/10);staircase=True表示每decay_steps更新梯度
            # learning_rate = tf.train.exponential_decay(self.config.lr, global_step=self.global_step,
            # decay_steps=10, decay_rate=self.config.lr_decay, staircase=True)
            # optimizer = tf.train.AdamOptimizer(learning_rate)
            # self.optimizer = optimizer.minimize(self.loss, global_step=self.global_step) #global_step 自動+1
            # no.2
            optimizer = tf.train.AdamOptimizer(pm.learning_rate)
            gradients, variables = zip(*optimizer.compute_gradients(self.loss))  # 計(jì)算變量梯度,得到梯度值,變量
            gradients, _ = tf.clip_by_global_norm(gradients, pm.clip)
            # 對g進(jìn)行l(wèi)2正則化計(jì)算阳藻,比較其與clip的值打月,如果l2后的值更大外臂,讓梯度*(clip/l2_g),得到新梯度
            self.optimizer = optimizer.apply_gradients(zip(gradients, variables), global_step=self.global_step)
           # global_step 自動+1

accuracy

最后,計(jì)算模型的準(zhǔn)確度犀斋。

        with tf.name_scope('accuracy'):
            correct = tf.equal(self.predict, tf.argmax(self.input_y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name='accuracy')

訓(xùn)練模型

global_step為100的倍數(shù)時(shí)贝乎,輸出當(dāng)前batch的訓(xùn)練loss,訓(xùn)練accuracy,在測試batch上的loss,accuracy;并每迭代完一次叽粹,保存一次模型糕非。

    x_train, y_train = process(pm.train_filename, wordid, cat_to_id, max_length=300)
    x_test, y_test = process(pm.test_filename, wordid, cat_to_id, max_length=300)
    for epoch in range(pm.num_epochs):
        print('Epoch:', epoch+1)
        num_batchs = int((len(x_train) - 1) / pm.batch_size) + 1
        batch_train = batch_iter(x_train, y_train, batch_size=pm.batch_size)
        for x_batch, y_batch in batch_train:
            real_seq_len = seq_length(x_batch)
            feed_dict = model.feed_data(x_batch, y_batch, real_seq_len, pm.keep_prob)
            _, global_step, _summary, train_loss, train_accuracy = session.run([model.optimizer, model.global_step, merged_summary,
                                                                                model.loss, model.accuracy], feed_dict=feed_dict)
            if global_step % 100 == 0:
                test_loss, test_accuracy = model.test(session, x_test, y_test)
                print('global_step:', global_step, 'train_loss:', train_loss, 'train_accuracy:', train_accuracy,
                      'test_loss:', test_loss, 'test_accuracy:', test_accuracy)

            if global_step % num_batchs == 0:
                print('Saving Model...')
                saver.save(session, save_path, global_step=global_step)
訓(xùn)練結(jié)果

由于小霸王運(yùn)行非常吃力,因此只進(jìn)行了3次迭代球榆。但從迭代的效果來看,結(jié)果很理想禁筏。在訓(xùn)練集的batch中最好達(dá)到100%持钉,同時(shí)測試集達(dá)到100%準(zhǔn)確。

驗(yàn)證模型

驗(yàn)證集有5000條語句篱昔,我用最后一次保存的模型每强,對5000條句子進(jìn)行預(yù)測,將預(yù)測的結(jié)果與原標(biāo)簽進(jìn)行對比州刽,得到驗(yàn)證集上的準(zhǔn)確率空执,結(jié)果表明在整個(gè)驗(yàn)證集上準(zhǔn)確達(dá)到97.7%,并輸出前10條語句穗椅,將預(yù)測結(jié)果與原結(jié)果進(jìn)行對比辨绊。

def val():

    pre_label = []
    label = []
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    save_path = tf.train.latest_checkpoint('./checkpoints/Lstm_CNN')
    saver = tf.train.Saver()
    saver.restore(sess=session, save_path=save_path)

    val_x, val_y = process(pm.val_filename, wordid, cat_to_id, max_length=pm.seq_length)
    batch_val = batch_iter(val_x, val_y, batch_size=64)
    for x_batch, y_batch in batch_val:
        real_seq_len = seq_length(x_batch)
        feed_dict = model.feed_data(x_batch, y_batch, real_seq_len, 1.0)
        pre_lab = session.run(model.predict, feed_dict=feed_dict)
        pre_label.extend(pre_lab)
        label.extend(y_batch)
    return pre_label, label
驗(yàn)證結(jié)果

整個(gè)模型的流程,分析完畢匹表。因?qū)W識有限门坷,文中難免有描述不對的地方,請各位批評指正袍镀。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末默蚌,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子苇羡,更是在濱河造成了極大的恐慌绸吸,老刑警劉巖,帶你破解...
    沈念sama閱讀 206,968評論 6 482
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件设江,死亡現(xiàn)場離奇詭異锦茁,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)叉存,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,601評論 2 382
  • 文/潘曉璐 我一進(jìn)店門蜻势,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人鹉胖,你說我怎么就攤上這事握玛」话” “怎么了?”我有些...
    開封第一講書人閱讀 153,220評論 0 344
  • 文/不壞的土叔 我叫張陵挠铲,是天一觀的道長冕屯。 經(jīng)常有香客問我,道長拂苹,這世上最難降的妖魔是什么安聘? 我笑而不...
    開封第一講書人閱讀 55,416評論 1 279
  • 正文 為了忘掉前任,我火速辦了婚禮瓢棒,結(jié)果婚禮上浴韭,老公的妹妹穿的比我還像新娘。我一直安慰自己脯宿,他們只是感情好念颈,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,425評論 5 374
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著连霉,像睡著了一般榴芳。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上跺撼,一...
    開封第一講書人閱讀 49,144評論 1 285
  • 那天窟感,我揣著相機(jī)與錄音,去河邊找鬼歉井。 笑死柿祈,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的哩至。 我是一名探鬼主播谍夭,決...
    沈念sama閱讀 38,432評論 3 401
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼憨募!你這毒婦竟也來了紧索?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 37,088評論 0 261
  • 序言:老撾萬榮一對情侶失蹤菜谣,失蹤者是張志新(化名)和其女友劉穎珠漂,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體尾膊,經(jīng)...
    沈念sama閱讀 43,586評論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡媳危,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,028評論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了冈敛。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片待笑。...
    茶點(diǎn)故事閱讀 38,137評論 1 334
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖抓谴,靈堂內(nèi)的尸體忽然破棺而出暮蹂,到底是詐尸還是另有隱情寞缝,我是刑警寧澤,帶...
    沈念sama閱讀 33,783評論 4 324
  • 正文 年R本政府宣布仰泻,位于F島的核電站荆陆,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏集侯。R本人自食惡果不足惜被啼,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,343評論 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望棠枉。 院中可真熱鬧浓体,春花似錦、人聲如沸辈讶。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,333評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽荞估。三九已至,卻和暖如春稚新,著一層夾襖步出監(jiān)牢的瞬間勘伺,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,559評論 1 262
  • 我被黑心中介騙來泰國打工褂删, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留飞醉,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 45,595評論 2 355
  • 正文 我出身青樓屯阀,卻偏偏與公主長得像缅帘,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個(gè)殘疾皇子难衰,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,901評論 2 345

推薦閱讀更多精彩內(nèi)容

  • 激活函數(shù)(Activation Function) 為了讓神經(jīng)網(wǎng)絡(luò)能夠?qū)W習(xí)復(fù)雜的決策邊界(decision bou...
    御風(fēng)之星閱讀 5,110評論 0 8
  • 其實(shí)我和志超學(xué)長只有兩次謀面钦无,幾無交集。但恰恰是這兩次謀面盖袭,讓我見識到一種可能——對夢想的追逐失暂,讓自己所愛真真正正...
    俗人雜文閱讀 455評論 3 6
  • 丫丫進(jìn)入恒潤實(shí)驗(yàn)學(xué)校已經(jīng)半個(gè)學(xué)期了,馬上(11月6號鳄虱、7號)就是半期考試了弟塞。周五的下午,應(yīng)丫丫周四晚上通電...
    建妮閱讀 2,060評論 6 8
  • 1. 信息安全管理的重要性 2. 計(jì)算機(jī)犯罪與攻擊方法 3. 邏輯訪問控制 4. 網(wǎng)絡(luò)基礎(chǔ)設(shè)施安全 5. 加密與密碼學(xué)
    lianzhanshu閱讀 649評論 0 50
  • 概述 常用操作: 樣式:grades = {'Ana':'B', 'John':'A+', 'Denise':'A...
    愁容_騎士閱讀 242評論 0 0