Extracting text features with a bi-LSTM

This section is explained mainly through comments in the code.

1. The core TextRNN code is shown below; the LSTM class is the part to focus on.

import torch.nn as nn
import torch

class Linear(nn.Module):
    def __init__(self, in_features, out_features):
        super(Linear, self).__init__()

        self.linear = nn.Linear(in_features=in_features,
                                out_features=out_features)
        self.init_params()

    def init_params(self):
        nn.init.kaiming_normal_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x):
        x = self.linear(x)
        return x


class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, bidirectional, dropout):
        """
        Args:
            input_size: x 的特征維度
            hidden_size: 隱層的特征維度
            num_layers: LSTM 層數(shù)
        """
        super(LSTM, self).__init__()

        self.rnn = nn.LSTM(
            input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout
        )

        self.init_params()

    def init_params(self):
        for i in range(self.rnn.num_layers):
            # orthogonal init for the hidden-to-hidden weights, Kaiming init for the input-to-hidden weights
            nn.init.orthogonal_(getattr(self.rnn, 'weight_hh_l{}'.format(i)))
            nn.init.kaiming_normal_(getattr(self.rnn, 'weight_ih_l{}'.format(i)))
            nn.init.constant_(getattr(self.rnn, 'bias_hh_l{}'.format(i)), val=0)
            nn.init.constant_(getattr(self.rnn, 'bias_ih_l{}'.format(i)), val=0)
            # set the forget-gate bias (the second of the four gate chunks) to 1;
            # the in-place fill_ is wrapped in no_grad because it writes into a view of a parameter
            with torch.no_grad():
                getattr(self.rnn, 'bias_hh_l{}'.format(i)).chunk(4)[1].fill_(1)

            if self.rnn.bidirectional:
                nn.init.orthogonal_(
                    getattr(self.rnn, 'weight_hh_l{}_reverse'.format(i)))
                nn.init.kaiming_normal_(
                    getattr(self.rnn, 'weight_ih_l{}_reverse'.format(i)))
                nn.init.constant_(
                    getattr(self.rnn, 'bias_hh_l{}_reverse'.format(i)), val=0)
                nn.init.constant_(
                    getattr(self.rnn, 'bias_ih_l{}_reverse'.format(i)), val=0)
                with torch.no_grad():
                    getattr(self.rnn, 'bias_hh_l{}_reverse'.format(i)).chunk(4)[1].fill_(1)

    def forward(self, x, lengths):
        '''
        See the end of this article for the usage of pack_padded_sequence and pad_packed_sequence.
        '''
        # x: [seq_len, batch_size, input_size]
        # lengths: [batch_size]; must be sorted in descending order (and be a CPU tensor or
        #          list on newer PyTorch versions); BucketIterator's sort_within_batch handles the ordering
        packed_x = nn.utils.rnn.pack_padded_sequence(x, lengths)

        # packed_x, packed_output: PackedSequence objects
        # hidden: [num_layers * num_directions, batch_size, hidden_size]
        # cell:   [num_layers * num_directions, batch_size, hidden_size]
        # Note: hidden and cell are the final hidden / cell states of every layer and direction,
        #       while packed_output holds the top layer's output at every time step.
        packed_output, (hidden, cell) = self.rnn(packed_x)

        # output: [max_seq_len, batch_size, hidden_size * 2]  (2 = num_directions)
        # output_lengths: [batch_size]
        # output holds the per-time-step features; TextRNN below feeds the final hidden
        # states (not output) into the fully connected layer
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)

        return hidden, output


class TextRNN(nn.Module):

    def __init__(self, embedding_dim, output_dim, hidden_size, num_layers, bidirectional, dropout,
                 pretrained_embeddings):
        super(TextRNN, self).__init__()

        self.embedding = nn.Embedding.from_pretrained(
            pretrained_embeddings, freeze=False)
        self.rnn = LSTM(embedding_dim, hidden_size, num_layers, bidirectional, dropout)

        self.fc = Linear(hidden_size * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        text, text_lengths = x
        # text: [sent len, batch size]
        embedded = self.dropout(self.embedding(text))
        # embedded: [sent len, batch size, emb dim]

        hidden, outputs = self.rnn(embedded, text_lengths)

        hidden = self.dropout(
            torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))  # concatenate the last layer's forward and backward hidden states

        return self.fc(hidden)
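
The hidden tensor returned by nn.LSTM stacks layers and directions along its first dimension, which is why TextRNN concatenates hidden[-2] (last layer, forward) and hidden[-1] (last layer, backward). Below is a minimal sketch, not part of the original code, that checks this layout; the tensor sizes are made up purely for illustration.

import torch
import torch.nn as nn

seq_len, batch_size, input_size, hidden_size, num_layers = 7, 4, 10, 16, 2

rnn = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)
x = torch.randn(seq_len, batch_size, input_size)
output, (hidden, cell) = rnn(x)

# hidden: [num_layers * num_directions, batch_size, hidden_size]
print(hidden.shape)                        # torch.Size([4, 4, 16])

# separate layers and directions: [num_layers, num_directions, batch_size, hidden_size]
h = hidden.view(num_layers, 2, batch_size, hidden_size)
print(torch.equal(hidden[-2], h[-1, 0]))   # True: last layer, forward direction
print(torch.equal(hidden[-1], h[-1, 1]))   # True: last layer, backward direction

# output concatenates both directions along the feature dimension
print(output.shape)                        # torch.Size([7, 4, 32])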

二灶芝、完整 demo

import torch.nn as nn
import torch

from torchtext import data
from torchtext import vocab
from tqdm import tqdm


class Linear(nn.Module):
    def __init__(self, in_features, out_features):
        super(Linear, self).__init__()

        self.linear = nn.Linear(in_features=in_features,
                                out_features=out_features)
        self.init_params()

    def init_params(self):
        nn.init.kaiming_normal_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x):
        x = self.linear(x)
        return x


class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, bidirectional, dropout):
        """
        Args:
            input_size: x 的特征維度
            hidden_size: 隱層的特征維度
            num_layers: LSTM 層數(shù)
        """
        super(LSTM, self).__init__()

        self.rnn = nn.LSTM(
            input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout
        )

        self.init_params()

    def init_params(self):
        for i in range(self.rnn.num_layers):
            # orthogonal init for the hidden-to-hidden weights, Kaiming init for the input-to-hidden weights
            nn.init.orthogonal_(getattr(self.rnn, 'weight_hh_l{}'.format(i)))
            nn.init.kaiming_normal_(getattr(self.rnn, 'weight_ih_l{}'.format(i)))
            nn.init.constant_(getattr(self.rnn, 'bias_hh_l{}'.format(i)), val=0)
            nn.init.constant_(getattr(self.rnn, 'bias_ih_l{}'.format(i)), val=0)
            # set the forget-gate bias (the second of the four gate chunks) to 1;
            # the in-place fill_ is wrapped in no_grad because it writes into a view of a parameter
            with torch.no_grad():
                getattr(self.rnn, 'bias_hh_l{}'.format(i)).chunk(4)[1].fill_(1)

            if self.rnn.bidirectional:
                nn.init.orthogonal_(
                    getattr(self.rnn, 'weight_hh_l{}_reverse'.format(i)))
                nn.init.kaiming_normal_(
                    getattr(self.rnn, 'weight_ih_l{}_reverse'.format(i)))
                nn.init.constant_(
                    getattr(self.rnn, 'bias_hh_l{}_reverse'.format(i)), val=0)
                nn.init.constant_(
                    getattr(self.rnn, 'bias_ih_l{}_reverse'.format(i)), val=0)
                with torch.no_grad():
                    getattr(self.rnn, 'bias_hh_l{}_reverse'.format(i)).chunk(4)[1].fill_(1)

    def forward(self, x, lengths):
        # x: [seq_len, batch_size, input_size]
        # lengths: [batch_size]; must be sorted in descending order (and be a CPU tensor or
        #          list on newer PyTorch versions); BucketIterator's sort_within_batch handles the ordering
        packed_x = nn.utils.rnn.pack_padded_sequence(x, lengths)

        # packed_x, packed_output: PackedSequence objects
        # hidden: [num_layers * num_directions, batch_size, hidden_size]
        # cell:   [num_layers * num_directions, batch_size, hidden_size]
        # Note: hidden and cell are the final hidden / cell states of every layer and direction,
        #       while packed_output holds the top layer's output at every time step.
        packed_output, (hidden, cell) = self.rnn(packed_x)

        # output: [max_seq_len, batch_size, hidden_size * 2]  (2 = num_directions)
        # output_lengths: [batch_size]
        # output holds the per-time-step features; TextRNN below feeds the final hidden
        # states (not output) into the fully connected layer
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)

        return hidden, output


class TextRNN(nn.Module):

    def __init__(self, embedding_dim, output_dim, hidden_size, num_layers, bidirectional, dropout,
                 pretrained_embeddings):
        super(TextRNN, self).__init__()

        self.embedding = nn.Embedding.from_pretrained(
            pretrained_embeddings, freeze=False)
        self.rnn = LSTM(embedding_dim, hidden_size, num_layers, bidirectional, dropout)

        self.fc = Linear(hidden_size * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        text, text_lengths = x
        # text: [sent len, batch size]
        embedded = self.dropout(self.embedding(text))
        # embedded: [sent len, batch size, emb dim]

        hidden, outputs = self.rnn(embedded, text_lengths)

        hidden = self.dropout(
            torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))  # concatenate the last layer's forward and backward hidden states

        return self.fc(hidden)


if __name__ == '__main__':
    embedding_file = '/home/jason/Desktop/data/embeddings/glove/glove.840B.300d.txt'
    path = '/home/jason/Desktop/data/SST-2/'

    cache_dir = '.cache/'
    batch_size = 6
    vectors = vocab.Vectors(embedding_file, cache_dir)

    text_field = data.Field(tokenize='spacy',
                            lower=True,
                            include_lengths=True,
                            fix_length=5)
    label_field = data.LabelField(dtype=torch.long)

    train, dev, test = data.TabularDataset.splits(path=path,
                                                  train='train.tsv',
                                                  validation='dev.tsv',
                                                  test='test.tsv',
                                                  format='tsv',
                                                  skip_header=True,
                                                  fields=[('text', text_field), ('label', label_field)])

    text_field.build_vocab(train,
                           dev,
                           test,
                           max_size=25000,
                           vectors=vectors,
                           unk_init=torch.Tensor.normal_)
    label_field.build_vocab(train, dev, test)

    pretrained_embeddings = text_field.vocab.vectors
    labels = label_field.vocab.vectors

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    train_iter, dev_iter, test_iter = data.BucketIterator.splits((train, dev, test),
                                                                 batch_sizes=(batch_size, len(dev), len(test)),
                                                                 sort_key=lambda x: len(x.text),
                                                                 sort_within_batch=True,
                                                                 repeat=False,
                                                                 shuffle=True,
                                                                 device=device
                                                                )

    model = TextRNN(300, 2, 200, 2, True, 0.4, pretrained_embeddings)
    model = model.to(device)  # keep the model on the same device as the batches

    for step, batch in enumerate(tqdm(train_iter, desc="Iteration")):
        logits = model(batch.text)
        break
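
The loop above stops after a single forward pass. As a rough sketch (not part of the original code, assuming the usual cross-entropy setup and that batch.label holds the gold class indices), a full training step on top of the same model and train_iter could look like this:

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

model.train()
for batch in train_iter:
    optimizer.zero_grad()
    logits = model(batch.text)              # [batch_size, output_dim]
    loss = criterion(logits, batch.label)   # batch.label: [batch_size], dtype long
    loss.backward()
    optimizer.step()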

3. A simple example of pack_padded_sequence and pad_packed_sequence

train_x = [torch.tensor([1, 1, 1, 1, 1, 1, 1]),
           torch.tensor([3, 3, 3, 3, 3]),
           torch.tensor([6, 6])]
x = nn.utils.rnn.pad_sequence(train_x, batch_first=True)
print('>1: ', x)
pack_x = nn.utils.rnn.pack_padded_sequence(x, [7, 5, 2], batch_first=True)
print('>2: ', pack_x)
reverse_x = nn.utils.rnn.pad_packed_sequence(pack_x)
print('>3: ', reverse_x)

The output is as follows:

>1:  tensor([[1, 1, 1, 1, 1, 1, 1],
        [3, 3, 3, 3, 3, 0, 0],
        [6, 6, 0, 0, 0, 0, 0]])
>2:  PackedSequence(data=tensor([1, 3, 6, 1, 3, 6, 1, 3, 1, 3, 1, 3, 1, 1]), batch_sizes=tensor([3, 3, 2, 2, 2, 1, 1]), sorted_indices=None, unsorted_indices=None)
>3:  (tensor([[1, 3, 6],
        [1, 3, 6],
        [1, 3, 0],
        [1, 3, 0],
        [1, 3, 0],
        [1, 0, 0],
        [1, 0, 0]]), tensor([7, 5, 2]))
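
Note that pad_packed_sequence defaults to batch_first=False, which is why the tensor recovered at >3 has shape [seq_len, batch] even though the input was padded with batch_first=True. Passing batch_first=True on the way back keeps the original layout, and pack_padded_sequence also accepts enforce_sorted=False for sequences that are not pre-sorted by length (available in recent PyTorch versions). A small follow-up to the example above:

# keep the batch-first layout and allow unsorted lengths
pack_x = nn.utils.rnn.pack_padded_sequence(x, [7, 5, 2], batch_first=True, enforce_sorted=False)
unpacked, unpacked_lengths = nn.utils.rnn.pad_packed_sequence(pack_x, batch_first=True)
print(unpacked.shape)   # torch.Size([3, 7]), same layout as the padded input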