PyTorch-17 使用TorchText進(jìn)行文本分類

要查看圖文并茂版教程宾尚,請(qǐng)移步: http://studyai.com/pytorch-1.4/beginner/text_sentiment_ngrams_tutorial.html

本教程演示如何在 torchtext 中使用文本分類數(shù)據(jù)集,包括

- AG_NEWS,
- SogouNews,
- DBpedia,
- YelpReviewPolarity,
- YelpReviewFull,
- YahooAnswers,
- AmazonReviewPolarity,
- AmazonReviewFull

此示例演示如何使用 TextClassification 數(shù)據(jù)集中的一個(gè)訓(xùn)練用于分類文本數(shù)據(jù)的監(jiān)督學(xué)習(xí)算法串绩。

使用ngrams加載數(shù)據(jù)

一個(gè)ngrams特征包(A bag of ngrams feature)被用來(lái)捕獲一些關(guān)于本地詞序的部分信息帆喇。 在實(shí)際應(yīng)用中乳附,雙字元(bi-gram)或三字元(tri-gram)作為詞組比只使用一個(gè)單詞(word)更有益處辛蚊。例如:

"load data with ngrams"
Bi-grams results: "load data", "data with", "with ngrams"
Tri-grams results: "load data with", "data with ngrams"

TextClassification Dataset支持 ngrams 方法臂外。通過將 ngrams 設(shè)置為2窟扑, 數(shù)據(jù)集中的示例文本將是一個(gè)單字加上bi-grams字符串的列表。

import torch
import torchtext
from torchtext.datasets import text_classification
NGRAMS = 2
import os
if not os.path.isdir('./.data'):
    os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

定義模型

模型由 EmbeddingBag 層和線性層組成(見下圖)漏健。 nn.EmbeddingBag 計(jì)算 embeddings 的 “bag” 的平均值嚎货。這里的文本條目有不同的長(zhǎng)度。 nn.EmbeddingBag 此處不需要填充(padding)蔫浆,因?yàn)槲谋鹃L(zhǎng)度以偏移量形式保存殖属。

此外,由于 nn.EmbeddingBag 在線動(dòng)態(tài)地累積了embeddings的平均值瓦盛,因此 nn.EmbeddingBag 可以提高處理張量序列的性能和內(nèi)存效率洗显。
../_images/text_sentiment_ngrams_model.png

import torch.nn as nn
import torch.nn.functional as F
class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

初始化模型

AG_NEWS 數(shù)據(jù)集有四個(gè)標(biāo)簽外潜,因此類的數(shù)量是四個(gè)。

1 : World
2 : Sports
3 : Business
4 : Sci/Tec

The vocab size is equal to the length of vocab (including single word and ngrams). The number of classes is equal to the number of labels, which is four in AG_NEWS case.

VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUN_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)

用于產(chǎn)生批量數(shù)據(jù)的函數(shù)

由于文本條目的長(zhǎng)度不同挠唆,因此使用自定義函數(shù) generate_batch() 生成數(shù)據(jù)batch和偏移量处窥。 此函數(shù)傳遞給 torch.utils.data.DataLoader.中的 collate_fn 。 collate_fn 的輸入是一個(gè)具有batch_size大小的張量列表玄组, collate_fn 函數(shù)將它們打包成一個(gè) mini-batch 滔驾。注意這里必須確保 collate_fn 被聲明為頂級(jí)定義的函數(shù), 這樣可以確保每個(gè)線程(worker)都可以使用該功能俄讹。

原始數(shù)據(jù)batch輸入中的文本條目被打包成一個(gè)列表哆致,并作為 nn.EmbeddingBag 的輸入連接為單個(gè)張量。 偏移量(offsets)是分隔符的張量患膛,表示文本張量中單個(gè)序列的起始索引摊阀。Label 是保存單個(gè)文本條目標(biāo)簽的張量。

def generate_batch(batch):
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    offsets = [0] + [len(entry) for entry in text]
    # torch.Tensor.cumsum returns the cumulative sum
    # of elements in the dimension dim.
    # torch.Tensor([1.0, 2.0, 3.0]).cumsum(dim=0)

    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
    return text, offsets, label

定義訓(xùn)練和評(píng)估模型的函數(shù)

建議PyTorch用戶使用 torch.utils.data.DataLoader 剩瓶, 它可以輕松地并行加載數(shù)據(jù)(這里有一個(gè)教程: 數(shù)據(jù)加載 )驹溃。 我們?cè)谶@里使用 DataLoader 加載AG_NEWS數(shù)據(jù)集并將其發(fā)送到模型進(jìn)行訓(xùn)練/驗(yàn)證。

from torch.utils.data import DataLoader

def train_func(sub_train_):

    # 訓(xùn)練模型
    train_loss = 0
    train_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        loss = criterion(output, cls)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls).sum().item()

    # 調(diào)整學(xué)習(xí)率
    scheduler.step()

    return train_loss / len(sub_train_), train_acc / len(sub_train_)

def test(data_):
    loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            loss = criterion(output, cls)
            loss += loss.item()
            acc += (output.argmax(1) == cls).sum().item()

    return loss / len(data_), acc / len(data_)

劃分?jǐn)?shù)據(jù)集并運(yùn)行模型

由于原始的 AG_NEWS 沒有有效的數(shù)據(jù)集延曙,我們將訓(xùn)練數(shù)據(jù)集分割為具有0.95(train)和0.05(valid)分割比的train/valid集豌鹤。 這里我們使用PyTorch核心庫(kù)中的 torch.utils.data.dataset.random_split 函數(shù)。

CrossEntropyLoss 準(zhǔn)則把 nn.LogSoftmax() 和 nn.NLLLoss() 組合進(jìn)了一個(gè)類中枝缔。 它在訓(xùn)練C類分類問題時(shí)非常有用布疙。 SGD 作為優(yōu)化器實(shí)現(xiàn)了隨機(jī)梯度下降法。初始學(xué)習(xí)率設(shè)置為4.0愿卸。這里使用 StepLR 來(lái)調(diào)整各個(gè)回合(epoch)的學(xué)習(xí)率灵临。

import time
from torch.utils.data.dataset import random_split
N_EPOCHS = 5
min_valid_loss = float('inf')

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
    random_split(train_dataset, [train_len, len(train_dataset) - train_len])

for epoch in range(N_EPOCHS):

    start_time = time.time()
    train_loss, train_acc = train_func(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)

    secs = int(time.time() - start_time)
    mins = secs / 60
    secs = secs % 60

    print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')

在GPU上運(yùn)行模型并得到以下信息:

Epoch: 1 | time in 0 minutes, 11 seconds

Loss: 0.0263(train)     |       Acc: 84.5%(train)
Loss: 0.0001(valid)     |       Acc: 89.0%(valid)

Epoch: 2 | time in 0 minutes, 10 seconds

Loss: 0.0119(train)     |       Acc: 93.6%(train)
Loss: 0.0000(valid)     |       Acc: 89.6%(valid)

Epoch: 3 | time in 0 minutes, 9 seconds

Loss: 0.0069(train)     |       Acc: 96.4%(train)
Loss: 0.0000(valid)     |       Acc: 90.5%(valid)

Epoch: 4 | time in 0 minutes, 11 seconds

Loss: 0.0038(train)     |       Acc: 98.2%(train)
Loss: 0.0000(valid)     |       Acc: 90.4%(valid)

Epoch: 5 | time in 0 minutes, 11 seconds

Loss: 0.0022(train)     |       Acc: 99.0%(train)
Loss: 0.0000(valid)     |       Acc: 91.0%(valid)

使用測(cè)試數(shù)據(jù)集評(píng)估模型

print('Checking the results of test dataset...')
test_loss, test_acc = test(test_dataset)
print(f'\tLoss: {test_loss:.4f}(test)\t|\tAcc: {test_acc * 100:.1f}%(test)')

檢查測(cè)試數(shù)據(jù)集的結(jié)果

Loss: 0.0237(test)      |       Acc: 90.5%(test)

在一條隨機(jī)新聞上測(cè)試

使用目前為止最好的模型,測(cè)試一個(gè)高爾夫(golf)新聞趴荸。 標(biāo)簽信息在 此處 提供儒溉。

import re
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer

ag_news_label = {1 : "World",
                 2 : "Sports",
                 3 : "Business",
                 4 : "Sci/Tec"}

def predict(text, model, vocab, ngrams):
    tokenizer = get_tokenizer("basic_english")
    with torch.no_grad():
        text = torch.tensor([vocab[token]
                            for token in ngrams_iterator(tokenizer(text), ngrams)])
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item() + 1

ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
    enduring the season’s worst weather conditions on Sunday at The \
    Open on his way to a closing 75 at Royal Portrush, which \
    considering the wind and the rain was a respectable showing. \
    Thursday’s first round at the WGC-FedEx St. Jude Invitational \
    was another story. With temperatures in the mid-80s and hardly any \
    wind, the Spaniard was 13 strokes better in a flawless round. \
    Thanks to his best putting performance on the PGA Tour, Rahm \
    finished with an 8-under 62 for a three-stroke lead, which \
    was even more impressive considering he’d never played the \
    front nine at TPC Southwind."

vocab = train_dataset.get_vocab()
model = model.to("cpu")

print("This is a %s news" %ag_news_label[predict(ex_text_str, model, vocab, 2)])

This is a Sports news

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市发钝,隨后出現(xiàn)的幾起案子顿涣,更是在濱河造成了極大的恐慌,老刑警劉巖酝豪,帶你破解...
    沈念sama閱讀 218,451評(píng)論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件涛碑,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡孵淘,警方通過查閱死者的電腦和手機(jī)蒲障,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,172評(píng)論 3 394
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人揉阎,你說(shuō)我怎么就攤上這事庄撮。” “怎么了余黎?”我有些...
    開封第一講書人閱讀 164,782評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵重窟,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我惧财,道長(zhǎng)巡扇,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,709評(píng)論 1 294
  • 正文 為了忘掉前任垮衷,我火速辦了婚禮厅翔,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘搀突。我一直安慰自己刀闷,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,733評(píng)論 6 392
  • 文/花漫 我一把揭開白布仰迁。 她就那樣靜靜地躺著甸昏,像睡著了一般。 火紅的嫁衣襯著肌膚如雪徐许。 梳的紋絲不亂的頭發(fā)上施蜜,一...
    開封第一講書人閱讀 51,578評(píng)論 1 305
  • 那天,我揣著相機(jī)與錄音雌隅,去河邊找鬼翻默。 笑死,一個(gè)胖子當(dāng)著我的面吹牛恰起,可吹牛的內(nèi)容都是我干的修械。 我是一名探鬼主播,決...
    沈念sama閱讀 40,320評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼检盼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼肯污!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起吨枉,我...
    開封第一講書人閱讀 39,241評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤蹦渣,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后东羹,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體剂桥,經(jīng)...
    沈念sama閱讀 45,686評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡忠烛,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,878評(píng)論 3 336
  • 正文 我和宋清朗相戀三年属提,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 39,992評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡冤议,死狀恐怖斟薇,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情恕酸,我是刑警寧澤堪滨,帶...
    沈念sama閱讀 35,715評(píng)論 5 346
  • 正文 年R本政府宣布,位于F島的核電站蕊温,受9級(jí)特大地震影響袱箱,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜义矛,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,336評(píng)論 3 330
  • 文/蒙蒙 一发笔、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧凉翻,春花似錦了讨、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,912評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至垃杖,卻和暖如春男杈,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背缩滨。 一陣腳步聲響...
    開封第一講書人閱讀 33,040評(píng)論 1 270
  • 我被黑心中介騙來(lái)泰國(guó)打工势就, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人脉漏。 一個(gè)月前我還...
    沈念sama閱讀 48,173評(píng)論 3 370
  • 正文 我出身青樓苞冯,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親侧巨。 傳聞我的和親對(duì)象是個(gè)殘疾皇子舅锄,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,947評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容