[TorchText]使用

只是教程的搬運(yùn)工-.-

Field的使用

Torchtext采用聲明式方法加載數(shù)據(jù)薯鼠,需要先聲明一個(gè)Field對(duì)象械蹋,這個(gè)Field對(duì)象指定你想要怎么處理某個(gè)數(shù)據(jù),each Field has its own Vocab class。

  • tokenize傳入一個(gè)函數(shù)郊艘,表示如何將文本str變成token
  • sequential表示是否切分?jǐn)?shù)據(jù)暇仲,如果數(shù)據(jù)已經(jīng)是序列化的了而且是數(shù)字類型的副渴,則應(yīng)該傳遞參數(shù)use_vocab = Falsesequential = False
    除了上面提到的關(guān)鍵字參數(shù)之外全度,Field類還允許用戶指定特殊標(biāo)記(用于標(biāo)記詞典外詞語(yǔ)的unk_token,用于填充的pad_token佑颇,用于句子結(jié)尾的eos_token以及用于句子開(kāi)頭的可選的init_token)草娜。設(shè)置將第一維是batch還是sequence(第一維默認(rèn)是sequence),并選擇是否允許在運(yùn)行時(shí)決定序列長(zhǎng)度還是預(yù)先就決定好茬贵,Field類的文檔
from torchtext.data import Field
tokenize = lambda x: x.split()

TEXT = Field(sequential=True, tokenize=tokenize, lower=True)
LABEL = Field(sequential=False, use_vocab=False)

使用spacy進(jìn)行tokenizer解藻,

import spacy
spacy_en = spacy.load('en')

def tokenizer(text): # create a tokenizer function
    return [tok.text for tok in spacy_en.tokenizer(text)]

TEXT = data.Field(sequential=True, tokenize=tokenizer, lower=True)
LABEL = data.Field(sequential=False, use_vocab=False)

構(gòu)建Dataset

Fields知道怎么處理原始數(shù)據(jù)葡盗,現(xiàn)在我們需要告訴Fields去處理哪些數(shù)據(jù)觅够。這就是我們需要用到Dataset的地方。Torchtext中有各種內(nèi)置Dataset奄妨,用于處理常見(jiàn)的數(shù)據(jù)格式苹祟。 對(duì)于csv/tsv文件,TabularDataset類很方便直焙。 以下是我們?nèi)绾问褂肨abularDataset從csv文件讀取數(shù)據(jù)的示例:

from torchtext.data import TabularDataset

tv_datafields = [("id", None), # 我們不會(huì)需要id奔誓,所以我們傳入的filed是None
                 ("comment_text", TEXT), ("toxic", LABEL),
                 ("severe_toxic", LABEL), ("threat", LABEL),
                 ("obscene", LABEL), ("insult", LABEL),
                 ("identity_hate", LABEL)]
trn, vld = TabularDataset.splits(
               path="data", # 數(shù)據(jù)存放的根目錄
               train='train.csv', validation="valid.csv",
               format='csv',
               skip_header=True, # 如果你的csv有表頭, 確保這個(gè)表頭不會(huì)作為數(shù)據(jù)處理
               fields=tv_datafields)

tst_datafields = [("id", None), # 我們不會(huì)需要id搔涝,所以我們傳入的filed是None
                  ("comment_text", TEXT)]
tst = TabularDataset(
           path="data/test.csv", # 文件路徑
           format='csv',
           skip_header=True, # 如果你的csv有表頭, 確保這個(gè)表頭不會(huì)作為數(shù)據(jù)處理
           fields=tst_datafields)
  • 我們傳入(name庄呈,field)對(duì)的列表作為fields參數(shù)。我們傳入的fields必須與列的順序相同斜纪。對(duì)于我們不使用的列,我們?cè)趂ields的位置傳入一個(gè)None腺劣。
  • splits方法通過(guò)應(yīng)用相同的處理為訓(xùn)練數(shù)據(jù)和驗(yàn)證數(shù)據(jù)創(chuàng)建Dataset因块。 它也可以處理測(cè)試數(shù)據(jù)涡上,但由于測(cè)試數(shù)據(jù)與訓(xùn)練數(shù)據(jù)和驗(yàn)證數(shù)據(jù)有不同的格式,因此我們創(chuàng)建了不同的Dataset歼冰。
  • 數(shù)據(jù)集大多可以和list一樣去處理耻警。 為了理解這一點(diǎn)甘穿,我們看看Dataset內(nèi)部是怎么樣的。 數(shù)據(jù)集可以像list一樣進(jìn)行索引和迭代秸滴,所以讓我們看看第一個(gè)元素是什么樣的:
>>> trn[0]
<torchtext.data.example.Example at 0x10d3ed3c8>

>>> trn[0].__dict__.keys()
dict_keys(['comment_text', 'toxic', 'severe_toxic', 'threat', 'obscene', 'insult', 'identity_hate'])

>>> trn[0].comment_text[:3]
['explanation', 'why', 'the']

在一個(gè)TabularDataset里面直接指定train,test,validation荡含,好像更為方便一些届垫。

from torchtext import data
train, val, test = data.TabularDataset.splits(
        path='./data/', train='train.tsv',
        validation='val.tsv', test='test.tsv', format='tsv',
        fields=[('Text', TEXT), ('Label', LABEL)])

詞表

Torchtext將單詞映射為整數(shù),但必須告訴它應(yīng)該處理的全部單詞误债。 在我們的例子中寝蹈,我們可能只想在訓(xùn)練集上建立詞匯表登淘,所以我們運(yùn)行代碼:TEXT.build_vocab(trn)。這使得torchtext遍歷訓(xùn)練集中的所有元素槽惫,檢查T(mén)EXT字段的內(nèi)容界斜,并將其添加到其詞匯表中合冀。Torchtext有自己的Vocab類來(lái)處理詞匯。Vocab類在stoi屬性中包含從word到id的映射峭判,并在其itos屬性中包含反向映射棕叫。 除此之外俺泣,它可以為word2vec等預(yù)訓(xùn)練的embedding自動(dòng)構(gòu)建embedding矩陣。Vocab類還可以使用像max_size和min_freq這樣的選項(xiàng)來(lái)表示詞匯表中有多少單詞或單詞出現(xiàn)的次數(shù)横漏。未包含在詞匯表中的單詞將被轉(zhuǎn)換成<unk>缎浇。
TEXT.build_vocab(train, vectors="glove.6B.100d"),可以給vectors直接傳入一個(gè)字符串類型的赴肚,那么會(huì)自動(dòng)下載你需要的詞向量,存放的位置是./.vector_cache 的文件夾下亡笑,或者可以使用類vocab.Vectors指定你自己的詞向量仑乌。

Note you can directly pass in a string and it will download pre-trained word vectors and load them for you. You can also use your own vectors by using this class vocab.Vectors. The downloaded word embeddings will stay at ./.vector_cache folder. I have not yet discovered a way to specify a custom location to store the downloaded vectors (there should be a way right?).

將預(yù)訓(xùn)練的詞向量加載到模型中琴锭。

from torchtext import data
TEXT.build_vocab(train, vectors="glove.6B.100d")
vocab = TEXT.vocab
self.embed = nn.Embedding(len(vocab), emb_dim)
self.embed.weight.data.copy_(vocab.vectors)

從詞表變回單詞

先安裝工具,先從 GitHub上git clone代碼然后進(jìn)入目錄安裝决帖,因?yàn)閜ip安裝的不是最新版本

cd revtok/
python setup.py install

使用時(shí)候只需要使用可翻轉(zhuǎn)的Field代替原來(lái)的Field就可以創(chuàng)建雙向的轉(zhuǎn)換的Vocb了。

from torchtext import data
TEXT = data.ReversibleField(sequential=True, lower=True, include_lengths=True)
...
for data in valid_iter:
        (x, x_lengths), y = data.Text, data.Description
        orig_text = TEXT.reverse(x.data)

對(duì)于預(yù)訓(xùn)練單詞表中沒(méi)有的單詞我們可以隨機(jī)進(jìn)行初始化扁远。

  • 在構(gòu)建Field的時(shí)候設(shè)置include_lengths字段為T(mén)rue可以在返回minibatch的時(shí)候同時(shí)返回一個(gè)表示每個(gè)句子長(zhǎng)度的list畅买。

include_lengths: Whether to return a tuple of a padded minibatch and
a list containing the lengths of each examples, or just a padded
minibatch. Default: False.

隨機(jī)初始化未知單詞。

def init_emb(vocab, init="randn", num_special_toks=2):
    emb_vectors = vocab.vectors
    sweep_range = len(vocab)
    running_norm = 0.
    num_non_zero = 0
    total_words = 0
    for i in range(num_special_toks, sweep_range):
        if len(emb_vectors[i, :].nonzero()) == 0:
            # std = 0.05 is based on the norm of average GloVE 100-dim word vectors
            if init == "randn":
                torch.nn.init.normal(emb_vectors[i], mean=0, std=0.05)
        else:
            num_non_zero += 1
            running_norm += torch.norm(emb_vectors[i])
        total_words += 1
    logger.info("average GloVE norm is {}, number of known words are {}, total number of words are {}".format(
        running_norm / num_non_zero, num_non_zero, total_words))

構(gòu)建迭代器

在torchvision和PyTorch中帝火,數(shù)據(jù)的處理和批處理由DataLoaders處理犀填。 出于某種原因嗓违,torchtext相同的東西又命名成了Iterators蹂季。 基本功能是一樣的,但我們將會(huì)看到佳窑,Iterators具有一些NLP特有的便捷功能父能。

  • 對(duì)于驗(yàn)證集和訓(xùn)練集合使用BucketIterator.splits(),目的是自動(dòng)進(jìn)行shuffle和padding何吝,并且為了訓(xùn)練效率期間,盡量把句子長(zhǎng)度相似的shuffle在一起瓣喊。
  • 對(duì)于測(cè)試集用Iterator黔酥,因?yàn)椴挥?code>sort跪者。
  • sort 是對(duì)全體數(shù)據(jù)按照升序順序進(jìn)行排序,而sort_within_batch僅僅對(duì)一個(gè)batch內(nèi)部的數(shù)據(jù)進(jìn)行排序逗概。
  • sort_within_batch參數(shù)設(shè)置為T(mén)rue時(shí)忘衍,按照sort_key按降序?qū)γ總€(gè)小批次內(nèi)的數(shù)據(jù)進(jìn)行降序排序。當(dāng)你想對(duì)padded序列使用pack_padded_sequence轉(zhuǎn)換為PackedSequence對(duì)象時(shí)瑟押,這是必需的狸吞。
  • 注意sortshuffle默認(rèn)只是對(duì)train=True字段進(jìn)行的蹋偏,但是train字段默認(rèn)是True至壤。所以測(cè)試集合可以這么寫(xiě)testIter = Iterator(tst, batch_size = 64, device =-1, train=False)寫(xiě)法等價(jià)于下面的一長(zhǎng)串寫(xiě)法像街。
  • repeat 是否連續(xù)的訓(xùn)練無(wú)數(shù)個(gè)batch ,默認(rèn)是False
  • device 可以是torch.device
from torchtext.data import Iterator, BucketIterator

train_iter, val_iter = BucketIterator.splits((trn, vld), 
                                             # 我們把Iterator希望抽取的Dataset傳遞進(jìn)去
                                             batch_sizes=(25, 25),
                                             device=-1, 
                                             # 如果要用GPU,這里指定GPU的編號(hào)
                                             sort_key=lambda x: len(x.comment_text), 
                                             # BucketIterator 依據(jù)什么對(duì)數(shù)據(jù)分組
                                             sort_within_batch=False,
                                             repeat=False)
                                             # repeat設(shè)置為False脓斩,因?yàn)槲覀兿胍b這個(gè)迭代器層随静。
test_iter = Iterator(tst, batch_size=64, 
                     device=-1, 
                     sort=False, 
                     sort_within_batch=False, 
                     repeat=False)

BucketIterator是torchtext最強(qiáng)大的功能之一吗讶。它會(huì)自動(dòng)將輸入序列進(jìn)行shuffle并做bucket照皆。這個(gè)功能強(qiáng)大的原因是——正如我前面提到的——我們需要填充輸入序列使得長(zhǎng)度相同才能批處理。 例如昭卓,序列

[ [3, 15, 2, 7], 
  [4, 1], 
  [5, 5, 6, 8, 1] ]

會(huì)需要pad成

[ [3, 15, 2, 7, 0],
  [4, 1, 0, 0, 0],
  [5, 5, 6, 8, 1] ]

填充量由batch中最長(zhǎng)的序列決定葬凳。因此室奏,當(dāng)序列長(zhǎng)度相似時(shí),填充效率最高昌简。BucketIterator會(huì)在在后臺(tái)執(zhí)行這些操作纯赎。需要注意的是,你需要告訴BucketIterator你想在哪個(gè)數(shù)據(jù)屬性上做bucket念恍。在我們的例子中晚顷,我們希望根據(jù)comment_text字段的長(zhǎng)度進(jìn)行bucket處理该默,因此我們將其作為關(guān)鍵字參數(shù)傳入sort_key = lambda x: len(x.comment_text)

train_iter, val_iter, test_iter = data.BucketIterator.splits(
        (train, val, test), sort_key=lambda x: len(x.Text),
        batch_sizes=(32, 256, 1), device=-1)

BucketIteratorIterator的區(qū)別是,BucketIterator盡可能的把長(zhǎng)度相似的句子放在一個(gè)batch里面匣摘。

Defines an iterator that batches examples of similar lengths together

封裝迭代器

目前音榜,迭代器返回一個(gè)名為torchtext.data.Batch的自定義數(shù)據(jù)類型捧弃。Batch類具有與Example類相似的API塔橡,將來(lái)自每個(gè)字段的一批數(shù)據(jù)作為屬性。

>>> train_iter
[torchtext.data.batch.Batch of size 25]
    [.comment_text]:[torch.LongTensor of size 494x25]
    [.toxic]:[torch.LongTensor of size 25]
    [.severe_toxic]:[torch.LongTensor of size 25]
    [.threat]:[torch.LongTensor of size 25]
    [.obscene]:[torch.LongTensor of size 25]
    [.insult]:[torch.LongTensor of size 25]
    [.identity_hate]:[torch.LongTensor of size 25]
>>> train_iter.__dict__.keys()
dict_keys(['batch_size', 'dataset', 'fields', 'comment_text', 'toxic', 'severe_toxic', 'threat', 'obscene', 'insult', 'identity_hate'])
>>> train_iter.comment_text
tensor([[  15,  606,  280,  ...,   15,   63,   15],
        [ 360,  693,   18,  ...,   29,    4,    2],
        [  45,  584,   14,  ...,   21,  664,  645],
        ...,
        [   1,    1,    1,  ...,   84,    1,    1],
        [   1,    1,    1,  ...,  118,    1,    1],
        [   1,    1,    1,  ...,   15,    1,    1]])
>>> train_iter.toxic
tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  1,  0,  1,  0,  1,  0,  0,  0,  0])

不幸的是,這種自定義數(shù)據(jù)類型使得代碼重用變得困難(因?yàn)槊看瘟忻l(fā)生變化時(shí)癞谒,我們都需要修改代碼)弹砚,并且使torchtext在某些情況(如torchsample和fastai)下很難與其他庫(kù)一起使用。
我希望這可以在未來(lái)得到優(yōu)化(我正在考慮提交PR朱沃,如果我可以決定API應(yīng)該是什么樣的話),但同時(shí)搬卒,我們使用簡(jiǎn)單的封裝來(lái)使batch易于使用契邀。

具體來(lái)說(shuō)失暴,我們將把batch轉(zhuǎn)換為形式為(x逗扒,y)的元組,其中x是自變量(模型的輸入)允瞧,y是因變量(標(biāo)簽數(shù)據(jù))蛮拔。 代碼如下:

class BatchWrapper:
    def __init__(self, dl, x_var, y_vars):
        self.dl, self.x_var, self.y_vars = dl, x_var, y_vars # 傳入自變量x列表和因變量y列表

    def __iter__(self):
        for batch in self.dl:
            x = getattr(batch, self.x_var) # 在這個(gè)封裝中只有一個(gè)自變量

            if self.y_vars is not None: # 把所有因變量cat成一個(gè)向量
                temp = [getattr(batch, feat).unsqueeze(1) for feat in self.y_vars]
                y = torch.cat(temp, dim=1).float()
            else:
                y = torch.zeros((1))

            yield (x, y)

    def __len__(self):
        return len(self.dl)

train_dl = BatchWrapper(train_iter, "comment_text", ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"])
valid_dl = BatchWrapper(val_iter, "comment_text", ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"])
test_dl = BatchWrapper(test_iter, "comment_text", None)

我們?cè)谶@里所做的是將Batch對(duì)象轉(zhuǎn)換為輸入和輸出的元組建炫。

>>> next(train_dl.__iter__())
(tensor([[  15,   15,   15,  ...,  375,  354,   44],
         [ 601,  657,  360,  ...,   27,   63,  739],
         [ 242,   22,   45,  ...,  526,    4,    3],
         ...,
         [   1,    1,    1,  ...,    1,    1,    1],
         [   1,    1,    1,  ...,    1,    1,    1],
         [   1,    1,    1,  ...,    1,    1,    1]]),
 tensor([[ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 1.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 1.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.],
         [ 1.,  1.,  0.,  1.,  1.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.]]))

訓(xùn)練模型

我們將使用一個(gè)簡(jiǎn)單的LSTM來(lái)演示如何根據(jù)我們構(gòu)建的數(shù)據(jù)來(lái)訓(xùn)練文本分類器:

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

class SimpleLSTMBaseline(nn.Module):
    def __init__(self, hidden_dim, emb_dim=300, num_linear=1):
        super().__init__() 
        # 詞匯量是 len(TEXT.vocab)
        self.embedding = nn.Embedding(len(TEXT.vocab), emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim, num_layers=1)
        self.linear_layers = []
        # 中間fc層
        for _ in range(num_linear - 1):
            self.linear_layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.linear_layers = nn.ModuleList(self.linear_layers)
        # 輸出層
        self.predictor = nn.Linear(hidden_dim, 6)

    def forward(self, seq):
        hdn, _ = self.encoder(self.embedding(seq))
        feature = hdn[-1, :, :]  # 選擇最后一個(gè)output
        for layer in self.linear_layers:
          feature = layer(feature)
        preds = self.predictor(feature)
        return preds

em_sz = 100
nh = 500
model = SimpleBiLSTMBaseline(nh, emb_dim=em_sz) 

現(xiàn)在,我們將編寫(xiě)訓(xùn)練循環(huán)衍慎。 多虧我們所有的預(yù)處理皮钠,讓這變得非常簡(jiǎn)單非常簡(jiǎn)單麦轰。我們可以使用我們包裝的Iterator進(jìn)行迭代,并且數(shù)據(jù)在移動(dòng)到GPU和適當(dāng)數(shù)字化后將自動(dòng)傳遞給我們末荐。

import tqdm

opt = optim.Adam(model.parameters(), lr=1e-2)
loss_func = nn.BCEWithLogitsLoss()

epochs = 2

for epoch in range(1, epochs + 1):
    running_loss = 0.0
    running_corrects = 0
    model.train() # 訓(xùn)練模式
    for x, y in tqdm.tqdm(train_dl): # 由于我們的封裝甲脏,我們可以直接對(duì)數(shù)據(jù)進(jìn)行迭代
        opt.zero_grad()
        preds = model(x)
        loss = loss_func(y, preds)
        loss.backward()
        opt.step()

        running_loss += loss.data[0] * x.size(0)

    epoch_loss = running_loss / len(trn)

    # 計(jì)算驗(yàn)證數(shù)據(jù)的誤差
    val_loss = 0.0
    model.eval() # 評(píng)估模式
    for x, y in valid_dl:
        preds = model(x)
        loss = loss_func(y, preds)
        val_loss += loss.data[0] * x.size(0)

    val_loss /= len(vld)
    print('Epoch: {}, Training Loss: {:.4f}, Validation Loss: {:.4f}'.format(epoch, epoch_loss, val_loss))

這就只是一個(gè)標(biāo)準(zhǔn)的訓(xùn)練循環(huán)块请。 現(xiàn)在來(lái)產(chǎn)生我們的預(yù)測(cè)

test_preds = []
for x, y in tqdm.tqdm(test_dl):
    preds = model(x)
    preds = preds.data.numpy()
    # 模型的實(shí)際輸出是logit,所以再經(jīng)過(guò)一個(gè)sigmoid函數(shù)
    preds = 1 / (1 + np.exp(-preds))
    test_preds.append(preds)
    test_preds = np.hstack(test_preds)

最后牍白,我們可以將我們的預(yù)測(cè)寫(xiě)入一個(gè)csv文件抖棘。

import pandas as pd
df = pd.read_csv("data/test.csv")
for i, col in enumerate(["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]):
    df[col] = test_preds[:, i]

df.drop("comment_text", axis=1).to_csv("submission.csv", index=False)

對(duì)于mask的數(shù)據(jù)的支持

RNN需要將PyTorch 變量打包成一個(gè)padded 序列切省。

  • sort_within_batch=True ,如果想使用PyTorch里面的pack_padded sequence時(shí)候必須要設(shè)置為T(mén)RUE
  • sort_key 是對(duì)數(shù)據(jù)集合在一個(gè)batch的的排序函數(shù)般渡。
  • repeat 是否使用多個(gè)epoch驯用。

repeat – Whether to repeat the iterator for multiple epochs. Default: False.

TEXT = data.ReversibleField(sequential=True, lower=True, include_lengths=True)
train, val, test = data.TabularDataset.splits(
        path='./data/', train='train.tsv',
        validation='val.tsv', test='test.tsv', format='tsv',
        fields=[('Text', TEXT), ('Label', LABEL)])
        
train_iter, val_iter, test_iter = data.Iterator.splits(
        (train, val, test), sort_key=lambda x: len(x.Text), 
        batch_sizes=(32, 256, 256), device=args.gpu, 
        sort_within_batch=True, repeat=False)

在模型中這么使用

def forward(self, input, lengths=None):
        embed_input = self.embed(input)

        packed_emb = embed_input
        if lengths is not None:
            lengths = lengths.view(-1).tolist()
            packed_emb = nn.utils.rnn.pack_padded_sequence(embed_input, lengths)

        output, hidden = self.encoder(packed_emb)  # embed_input

        if lengths is not None:
            output,_=nn.utils.rnn.pad_packed_sequence(output)

在主循環(huán)中可以以這樣的方式傳遞數(shù)據(jù)蝴乔,未打包之前驮樊。

# Note this loop will go on FOREVER
for val_i, data in enumerate(train_iter):
  (x, x_lengths), y = data.Text, data.Description
  output = model(x, x_lengths)
  
  # terminate condition, when loss converges or it reaches 50000 iterations
  if loss converges or val_i == 50000:
    break

如果我們想訓(xùn)練一定的epoch囚衔,我們可以設(shè)置,repeat=False,這樣我們的數(shù)據(jù)會(huì)訓(xùn)練10個(gè)epoch然后停止下來(lái)猴仑。

# Note this loop will stop when training data is traversed once
epochs = 10

for epoch in range(epochs):
  for data in train_iter:
    (x, x_lengths), y = data.Text, data.Description
    # model running...

如果我們提前不知道應(yīng)該在多少個(gè)epoch的時(shí)候停止宁脊,比如當(dāng)精度達(dá)到某個(gè)值的時(shí)候停止贤姆,我們就可以設(shè)置repeat = True,然后在代碼里面break.

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末霞捡,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子赊琳,更是在濱河造成了極大的恐慌躏筏,老刑警劉巖,帶你破解...
    沈念sama閱讀 216,496評(píng)論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件埃碱,死亡現(xiàn)場(chǎng)離奇詭異砚殿,居然都是意外死亡芝囤,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,407評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門(mén),熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)挠轴,“玉大人岸晦,你說(shuō)我怎么就攤上這事睛藻〉暧。” “怎么了?”我有些...
    開(kāi)封第一講書(shū)人閱讀 162,632評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵包券,是天一觀的道長(zhǎng)溅固。 經(jīng)常有香客問(wèn)我兰珍,道長(zhǎng),這世上最難降的妖魔是什么亮元? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 58,180評(píng)論 1 292
  • 正文 為了忘掉前任爆捞,我火速辦了婚禮煮甥,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘肌访。我一直安慰自己艇劫,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,198評(píng)論 6 388
  • 文/花漫 我一把揭開(kāi)白布蟹演。 她就那樣靜靜地躺著,像睡著了一般顷蟀。 火紅的嫁衣襯著肌膚如雪酒请。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書(shū)人閱讀 51,165評(píng)論 1 299
  • 那天鸣个,我揣著相機(jī)與錄音羞反,去河邊找鬼。 笑死囤萤,一個(gè)胖子當(dāng)著我的面吹牛昼窗,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播澄惊,決...
    沈念sama閱讀 40,052評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼富雅!你這毒婦竟也來(lái)了掸驱?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書(shū)人閱讀 38,910評(píng)論 0 274
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤没佑,失蹤者是張志新(化名)和其女友劉穎毕贼,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體图筹,經(jīng)...
    沈念sama閱讀 45,324評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡帅刀,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,542評(píng)論 2 332
  • 正文 我和宋清朗相戀三年让腹,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片扣溺。...
    茶點(diǎn)故事閱讀 39,711評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡骇窍,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出锥余,到底是詐尸還是另有隱情腹纳,我是刑警寧澤,帶...
    沈念sama閱讀 35,424評(píng)論 5 343
  • 正文 年R本政府宣布驱犹,位于F島的核電站嘲恍,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏雄驹。R本人自食惡果不足惜佃牛,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,017評(píng)論 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望医舆。 院中可真熱鬧俘侠,春花似錦、人聲如沸蔬将。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 31,668評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)霞怀。三九已至惫东,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間毙石,已是汗流浹背廉沮。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 32,823評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留胁黑,地道東北人废封。 一個(gè)月前我還...
    沈念sama閱讀 47,722評(píng)論 2 368
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像丧蘸,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子遥皂,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,611評(píng)論 2 353

推薦閱讀更多精彩內(nèi)容