從LSTM到GRU基于門控的循環(huán)神經(jīng)網(wǎng)絡(luò)總結(jié)

技術(shù)交流QQ群:1027579432揍移,歡迎你的加入绎秒!

1.概述

  • 為了改善基本RNN的長期依賴問題搔涝,一種方法是引入門控機制來控制信息的累積速度,包括有選擇性地加入新的信息曾雕,并有選擇性遺忘之前累積的信息奴烙。下面主要介紹兩種基于門控的循環(huán)神經(jīng)網(wǎng)絡(luò):長短時記憶網(wǎng)絡(luò)和門控循環(huán)單元網(wǎng)絡(luò)。因為基本的RNN即\mathbf{h}_{t}=f\left(U \mathbf{h}_{t-1}+W \mathbf{x}_{t}+\mathbf剖张\right)切诀,每層的隱狀態(tài)都是由前一層的隱狀態(tài)經(jīng)變換和激活函數(shù)得到的,反向傳播求導(dǎo)時搔弄,最終得到的導(dǎo)數(shù)會包含每步梯度的連乘幅虑,會導(dǎo)致梯度爆炸或消失。所以顾犹,基本的RNN很難處理長期依賴問題倒庵,即無法學(xué)習(xí)到序列中蘊含的間隔時間較長的規(guī)律褒墨。

2.長短時記憶網(wǎng)絡(luò)LSTM

  • 2.1長短時記憶網(wǎng)絡(luò)是基本的循環(huán)神經(jīng)網(wǎng)絡(luò)的一種變體,可以有效的解決簡單RNN的梯度爆炸或消失問題。LSTM網(wǎng)絡(luò)主要改進在下面兩個方面

    • 1.新的內(nèi)部狀態(tài)\mathbf{c}_{t}:LSTM網(wǎng)絡(luò)引入一個新的內(nèi)部狀態(tài)\mathbf{c}_{t},專門進行線性的循環(huán)信息傳遞屏歹,同時輸出信息給隱藏層的外部狀態(tài)\mathbf{h}_{t}.
      \begin{aligned} \mathbf{c}_{t} &=\mathbf{f}_{t} \odot \mathbf{c}_{t-1}+\mathbf{i}_{t} \odot \tilde{\mathbf{c}}_{t} \\ \mathbf{h}_{t} &=\mathbf{o}_{t} \odot \tanh \left(\mathbf{c}_{t}\right) \end{aligned}
      符號說明:\mathbf{f}_{t}寨辩、\mathbf{i}_{t}\mathbf{o}_{t}分別代表遺忘門锄奢、輸入門失晴、輸出門用來控制信息傳遞的路徑;⊙表示向量元素的點乘拘央;\mathbf{c}_{t-1}表示上一時刻的記憶單元涂屁;\tilde{\mathbf{c}}_{t}表示通過非線性函數(shù)得到的候選狀態(tài)。
      \tilde{\mathbf{c}}_{t}=\tanh \left(W_{c} \mathbf{x}_{t}+U_{c} \mathbf{h}_{t-1}+\mathbf灰伟_{c}\right)
      在每個時刻t拆又,LSTM網(wǎng)絡(luò)的內(nèi)部狀態(tài)\mathbf{c}_{t}記錄了到當(dāng)前時刻為止的歷史信息。
    • 2.門控機制:LSTM網(wǎng)絡(luò)引入了門控機制栏账,用來控制信息傳遞的路徑帖族,\mathbf{f}_{t}\mathbf{i}_{t}挡爵、\mathbf{o}_{t}分別代表遺忘門竖般、輸入門、輸出門茶鹃。這里的門概念類似于電路中的邏輯門概念涣雕,1表示開放狀態(tài),允許信息通過闭翩;0表示關(guān)閉狀態(tài)挣郭,阻止信息通過。LSTM網(wǎng)絡(luò)中的門是一個抽象的概念疗韵,借助sigmiod函數(shù)丈屹,使得輸出值在(0,1)之間,表示以一定的比例運行信息通過伶棒。三個門的作用如下:
      • 遺忘門\mathbf{f}_{t}控制上一時刻的內(nèi)部狀態(tài)\mathbf{c}_{t-1}需要遺忘多少信息
      • 輸入門\mathbf{i}_{t}控制當(dāng)前時刻的候選狀態(tài)\tilde{\mathbf{c}}_{t}有多少信息需要保存
      • 輸出門\mathbf{o}_{t}控制當(dāng)前時刻的內(nèi)部狀態(tài)\mathbf{c}_{t}有多少信息需要輸出給外部狀態(tài)\mathbf{h}_{t}
        當(dāng)\mathbf{f}_{t}=0, \mathbf{i}_{t}=1時旺垒,記憶單元\mathbf{c}_{t}將歷史信息清空,并將候選狀態(tài)向量\tilde{\mathbf{c}}_{t}寫入肤无。但此時記憶單元\mathbf{c}_{t}依然和上一時刻的歷史信息相關(guān)先蒋。當(dāng)\mathbf{f}_{t}=1, \mathbf{i}_{t}=0時,記憶單元將復(fù)制上一時刻的內(nèi)容宛渐,不寫入新的信息竞漾。三個門的計算公式如下:
        \begin{aligned} \mathbf{i}_{t} &=\sigma\left(W_{i} \mathbf{x}_{t}+U_{i} \mathbf{h}_{t-1}+\mathbf眯搭_{i}\right) \\ \mathbf{f}_{t} &=\sigma\left(W_{f} \mathbf{x}_{t}+U_{f} \mathbf{h}_{t-1}+\mathbf_{f}\right) \\ \mathbf{o}_{t} &=\sigma\left(W_{o} \mathbf{x}_{t}+U_{o} \mathbf{h}_{t-1}+\mathbf业岁_{o}\right) \end{aligned}
        其中鳞仙,激活函數(shù)使用sigmoid函數(shù),其輸出區(qū)間是(0,1)笔时,\mathbf{x}_{t}表示當(dāng)前時刻的輸入棍好,\mathbf{h}_{t-1}表示上一時刻的外部狀態(tài)。
  • 2.2 LSTM網(wǎng)絡(luò)的循環(huán)單元結(jié)構(gòu)如下圖所示允耿,計算過程如下:

    • a.利用上一時刻的外部狀態(tài)\mathbf{h}_{t-1}和當(dāng)前時刻的輸入\mathbf{x}_{t}借笙,計算出三個門,已經(jīng)候選狀態(tài)\tilde{\mathbf{c}}_{t}
      \begin{aligned} \mathbf{i}_{t} &=\sigma\left(W_{i} \mathbf{x}_{t}+U_{i} \mathbf{h}_{t-1}+\mathbf较锡_{i}\right) \\ \mathbf{f}_{t} &=\sigma\left(W_{f} \mathbf{x}_{t}+U_{f} \mathbf{h}_{t-1}+\mathbf业稼_{f}\right) \\ \mathbf{o}_{t} &=\sigma\left(W_{o} \mathbf{x}_{t}+U_{o} \mathbf{h}_{t-1}+\mathbf_{o}\right) \end{aligned}

    \tilde{\mathbf{c}}_{t}=\tanh \left(W_{c} \mathbf{x}_{t}+U_{c} \mathbf{h}_{t-1}+\mathbf蚂蕴_{c}\right)

    • b.結(jié)合遺忘門\mathbf{f}_{t}和輸入門\mathbf{i}_{t}來更新記憶單元\mathbf{c}_{t}
      \mathbf{c}_{t}=\mathbf{f}_{t} \odot \mathbf{c}_{t-1}+\mathbf{i}_{t} \odot \tilde{\mathbf{c}}_{t}
    • c.結(jié)合輸出門\mathbf{o}_{t}低散,將內(nèi)部狀態(tài)的信息傳遞給外部狀態(tài)\mathbf{h}_{t}
      \mathbf{h}_{t}=\mathbf{o}_{t} \odot \tanh \left(\mathbf{c}_{t}\right)
LSTM Cell

3.門控循環(huán)單元網(wǎng)絡(luò)GRU

  • GRU與LSTM的不同之處在于:GRU不引入額外的記憶單元\mathbf{c}_{t},GRU網(wǎng)絡(luò)引入一個更新門來控制當(dāng)前狀態(tài)需要從歷史狀態(tài)中保留多少信息(不經(jīng)過非線性變換)骡楼,以及需要從候選狀態(tài)中接收多少新的信息熔号。
    \mathbf{h}_{t}=\mathbf{z}_{t} \odot \mathbf{h}_{t-1}+\left(1-\mathbf{z}_{t}\right) \odot g\left(\mathbf{x}_{t}, \mathbf{h}_{t-1} ; \theta\right)
    其中,\mathbf{z}_{t} \in[0,1]為更新門
    \mathbf{z}_{t}=\sigma\left(\mathbf{W}_{z} \mathbf{x}_{t}+\mathbf{U}_{z} \mathbf{h}_{t-1}+\mathbf君编_{z}\right)
    在GRU網(wǎng)絡(luò)中跨嘉,函數(shù)g\left(\mathbf{x}_{t}, \mathbf{h}_{t-1} ; \theta\right)定義為:
    \tilde{\mathbf{h}}_{t}=\tanh \left(W_{h} \mathbf{x}_{t}+U_{h}\left(\mathbf{r}_{t} \odot \mathbf{h}_{t-1}\right)+\mathbf_{h}\right)
    上式中的符號說明:\tilde{\mathbf{h}}_{t}表示當(dāng)前時刻的候選狀態(tài)吃嘿,\mathbf{r}_{t} \in[0,1]為重置門祠乃,用來控制候選狀態(tài)\tilde{\mathbf{h}}_{t}的計算是否依賴上一時刻的狀態(tài)\mathbf{h}_{t-1}
    \mathbf{r}_{t}=\sigma\left(W_{r} \mathbf{x}_{t}+U_{r} \mathbf{h}_{t-1}+\mathbf兑燥_{r}\right)
    當(dāng)\mathbf{r}_{t}=0時亮瓷,候選狀態(tài)\tilde{\mathbf{h}}_{t}=\tanh \left(W_{c} \mathbf{x}_{t}+\mathbf\right)只和當(dāng)前輸入\mathbf{x}_{t}相關(guān)而與歷史狀態(tài)無關(guān)降瞳。當(dāng)\mathbf{r}_{t}=1時嘱支,候選狀態(tài)\tilde{\mathbf{h}}_{t}=\tanh \left(W_{h} \mathbf{x}_{t}+U_{h} \mathbf{h}_{t-1}+\mathbf_{h}\right)和當(dāng)前輸入\mathbf{x}_{t}相關(guān)挣饥,也和歷史狀態(tài)\mathbf{h}_{t-1}相關(guān)除师,此時和簡單的RNN是一樣的。
    綜合上述各式扔枫,GRU網(wǎng)絡(luò)的狀態(tài)更新方式為:
    \mathbf{h}_{t}=\mathbf{z}_{t} \odot \mathbf{h}_{t-1}+\left(1-\mathbf{z}_{t}\right) \odot \tilde{\mathbf{h}}_{t}
  • 總結(jié):當(dāng)\mathbf{z}_{t}=0, \mathbf{r}=1時汛聚,GRU網(wǎng)絡(luò)退化為簡單的RNN;若\mathbf{z}_{t}=0, \mathbf{r}=0時短荐,當(dāng)前狀態(tài)\mathbf{h}_{t}只和當(dāng)前輸入\mathbf{x}_{t}相關(guān)倚舀,和歷史狀態(tài)\mathbf{h}_{t-1}無關(guān)叹哭。當(dāng)\mathbf{z}_{t}=1時,當(dāng)前狀態(tài)\mathbf{h}_{t}等于上一時刻狀態(tài)\mathbf{h}_{t-1}和當(dāng)前輸入\mathbf{x}_{t}無關(guān)痕貌。
    GRU Cell

3.實戰(zhàn):基于Keras的LSTM和GRU的文本分類

    import random
    import jieba
    import pandas as pd
    import numpy as np
    
    stopwords = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\stopwords.txt", index_col=False, quoting=3, sep="\t", names=["stopword"], encoding="utf-8")
    stopwords = stopwords["stopword"].values
    
    # 加載語料
    laogong_df = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\beilaogongda.csv", encoding="utf-8", sep=",")
    laopo_df = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\beilaopoda.csv", encoding="utf-8", sep=",")
    erzi_df = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\beierzida.csv", encoding="utf-8", sep=",")
    nver_df = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\beinverda.csv", encoding="utf-8", sep=",")
    
    # 刪除語料的nan行
    laogong_df.dropna(inplace=True)
    laopo_df.dropna(inplace=True)
    erzi_df.dropna(inplace=True)
    nver_df.dropna(inplace=True)
    
    # 轉(zhuǎn)換
    laogong = laogong_df.segment.values.tolist()
    laopo = laopo_df.segment.values.tolist()
    erzi = erzi_df.segment.values.tolist()
    nver = nver_df.segment.values.tolist()
    
    # 分詞和去掉停用詞
    
    ## 定義分詞和打標(biāo)簽函數(shù)preprocess_text
    def preprocess_text(content_lines, sentences, category):
        # content_lines是上面轉(zhuǎn)換得到的list
        # sentences是空的list风罩,用來存儲打上標(biāo)簽后的數(shù)據(jù)
        # category是類型標(biāo)簽
        for line in content_lines:
            try:
                segs = jieba.lcut(line)
                segs = [v for v in segs if not str(v).isdigit()]  # 除去數(shù)字
                segs = list(filter(lambda x: x.strip(), segs))  # 除去左右空格
                segs = list(filter(lambda x: len(x) > 1, segs))  # 除去長度為1的字符
                segs = list(filter(lambda x: x not in stopwords, segs))  # 除去停用詞
                sentences.append((" ".join(segs), category))  # 打標(biāo)簽
            except Exception:
                print(line)
                continue
    
    # 調(diào)用上面函數(shù),生成訓(xùn)練數(shù)據(jù)
    sentences = []
    preprocess_text(laogong, sentences, 0)
    preprocess_text(laopo, sentences, 1)
    preprocess_text(erzi, sentences, 2)
    preprocess_text(nver, sentences, 3)
    
    # 先打亂數(shù)據(jù)舵稠,使得數(shù)據(jù)分布均勻超升,然后獲取特征和標(biāo)簽列表
    random.shuffle(sentences)  # 打亂數(shù)據(jù),生成更可靠的訓(xùn)練集
    for sentence in sentences[:10]:    # 輸出前10條數(shù)據(jù)柱查,觀察一下
        print(sentence[0], sentence[1])
    
    # 所有特征和對應(yīng)標(biāo)簽
    all_texts = [sentence[0] for sentence in sentences]
    all_labels = [sentence[1] for sentence in sentences]
    
    
    # 使用LSTM對數(shù)據(jù)進行分類
    from keras.preprocessing.text import Tokenizer
    from keras.preprocessing.sequence import pad_sequences
    from keras.utils import to_categorical
    from keras.layers import Dense, Input, Flatten, Dropout
    from keras.layers import LSTM, Embedding, GRU
    from keras.models import Sequential
    
    
    # 預(yù)定義變量
    MAX_SEQENCE_LENGTH = 100   # 最大序列長度
    EMBEDDING_DIM = 200   # 詞嵌入維度
    VALIDATION_SPLIT = 0.16   # 驗證集比例
    TEST_SPLIT = 0.2  # 測試集比例
    
    # 使用keras的sequence模塊文本序列填充
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(all_texts)
    sequences = tokenizer.texts_to_sequences(all_texts)
    word_index = tokenizer.word_index
    print("Found %s unique tokens." % len(word_index))
    
    
    data = pad_sequences(sequences, maxlen=MAX_SEQENCE_LENGTH)
    labels = to_categorical(np.asarray(all_labels))
    print("data shape:", data.shape)
    print("labels shape:", labels.shape)
    
    # 數(shù)據(jù)切分
    p1 = int(len(data) * (1 - VALIDATION_SPLIT - TEST_SPLIT))
    p2 = int(len(data) * (1 - TEST_SPLIT))
    
    # 訓(xùn)練集
    x_train = data[:p1]
    y_train = labels[:p1]
    
    # 驗證集
    x_val = data[p1:p2]
    y_val = labels[p1:p2]
    
    # 測試集
    x_test = data[p2:]
    y_test = labels[p2:]
    
    # LSTM訓(xùn)練模型
    model = Sequential()
    model.add(Embedding(len(word_index) + 1, EMBEDDING_DIM, input_length=MAX_SEQENCE_LENGTH))
    model.add(LSTM(200, dropout=0.2, recurrent_dropout=0.2))
    model.add(Dropout(0.2))
    model.add(Dense(64, activation="relu"))
    model.add(Dense(labels.shape[1], activation="softmax"))
    model.summary()
    
    # 模型編譯
    model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=["acc"])
    print(model.metrics_names)
    
    model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10, batch_size=128)
    model.save("lstm.h5")
    # 模型評估
    print(model.evaluate(x_test, y_test))
    
    
    
    # 使用GRU模型
    model = Sequential()
    model.add(Embedding(len(word_index) + 1, EMBEDDING_DIM, input_length=MAX_SEQENCE_LENGTH))
    model.add(GRU(200, dropout=0.2, recurrent_dropout=0.2))
    model.add(Dropout(0.2))
    model.add(Dense(64, activation="relu"))
    model.add(Dense(labels.shape[1], activation="softmax"))
    model.summary()
    
    model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=["acc"])
    print(model.metrics_names)
    
    model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10, batch_size=128)
    model.save("gru.h5")
    
    print(model.evaluate(x_test, y_test))

4.本文代碼及數(shù)據(jù)集下載

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末廓俭,一起剝皮案震驚了整個濱河市云石,隨后出現(xiàn)的幾起案子唉工,更是在濱河造成了極大的恐慌,老刑警劉巖汹忠,帶你破解...
    沈念sama閱讀 216,651評論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件淋硝,死亡現(xiàn)場離奇詭異,居然都是意外死亡宽菜,警方通過查閱死者的電腦和手機谣膳,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,468評論 3 392
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來铅乡,“玉大人继谚,你說我怎么就攤上這事≌笮遥” “怎么了花履?”我有些...
    開封第一講書人閱讀 162,931評論 0 353
  • 文/不壞的土叔 我叫張陵,是天一觀的道長挚赊。 經(jīng)常有香客問我诡壁,道長,這世上最難降的妖魔是什么荠割? 我笑而不...
    開封第一講書人閱讀 58,218評論 1 292
  • 正文 為了忘掉前任妹卿,我火速辦了婚禮,結(jié)果婚禮上蔑鹦,老公的妹妹穿的比我還像新娘夺克。我一直安慰自己,他們只是感情好嚎朽,可當(dāng)我...
    茶點故事閱讀 67,234評論 6 388
  • 文/花漫 我一把揭開白布铺纽。 她就那樣靜靜地躺著,像睡著了一般火鼻。 火紅的嫁衣襯著肌膚如雪室囊。 梳的紋絲不亂的頭發(fā)上雕崩,一...
    開封第一講書人閱讀 51,198評論 1 299
  • 那天,我揣著相機與錄音融撞,去河邊找鬼盼铁。 笑死,一個胖子當(dāng)著我的面吹牛尝偎,可吹牛的內(nèi)容都是我干的饶火。 我是一名探鬼主播,決...
    沈念sama閱讀 40,084評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼致扯,長吁一口氣:“原來是場噩夢啊……” “哼肤寝!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起抖僵,我...
    開封第一講書人閱讀 38,926評論 0 274
  • 序言:老撾萬榮一對情侶失蹤鲤看,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后耍群,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體义桂,經(jīng)...
    沈念sama閱讀 45,341評論 1 311
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,563評論 2 333
  • 正文 我和宋清朗相戀三年蹈垢,在試婚紗的時候發(fā)現(xiàn)自己被綠了慷吊。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 39,731評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡曹抬,死狀恐怖溉瓶,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情谤民,我是刑警寧澤堰酿,帶...
    沈念sama閱讀 35,430評論 5 343
  • 正文 年R本政府宣布,位于F島的核電站赖临,受9級特大地震影響胞锰,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜兢榨,卻給世界環(huán)境...
    茶點故事閱讀 41,036評論 3 326
  • 文/蒙蒙 一嗅榕、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧吵聪,春花似錦凌那、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,676評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至块攒,卻和暖如春励稳,著一層夾襖步出監(jiān)牢的瞬間佃乘,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,829評論 1 269
  • 我被黑心中介騙來泰國打工驹尼, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留趣避,地道東北人。 一個月前我還...
    沈念sama閱讀 47,743評論 2 368
  • 正文 我出身青樓新翎,卻偏偏與公主長得像程帕,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子地啰,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,629評論 2 354

推薦閱讀更多精彩內(nèi)容