MindSpore入門:使用LSTM進(jìn)行文本情感分析

MindSpore是華為最近開源的深度學(xué)習(xí)框架臣缀,根據(jù)官方的說法魄梯,開發(fā)這款深度學(xué)習(xí)框架主要是為了充分利用華為自研的昇騰AI處理器(Ascend)的硬件能力晚胡,當(dāng)然這款框架除了運(yùn)行在Ascend平臺(tái)也可以運(yùn)行在CPU和GPU上面档叔。由于該框架只開發(fā)到了0.3版本程奠,目前網(wǎng)絡(luò)上相關(guān)的資料比較少逢艘,所以這篇博客想要通過一個(gè)簡(jiǎn)單的小項(xiàng)目旦袋,介紹一下如何使用MindSpore訓(xùn)練一個(gè)深度學(xué)習(xí)模型。想要更深入的學(xué)習(xí)MindSpore可以訪問他的官網(wǎng):https://www.mindspore.cn 和項(xiàng)目代碼倉(cāng)庫(kù):https://gitee.com/mindspore/mindspore

這篇Notebook介紹如何使用MindSpore對(duì)IMDB數(shù)據(jù)集中的電影評(píng)論進(jìn)行情感分析它改。主要思路就是對(duì)電影評(píng)論中的單詞進(jìn)行詞嵌入處理疤孕,然后將處理后的數(shù)據(jù)送入LSTM模型,模型對(duì)評(píng)論進(jìn)行打標(biāo)簽(正面或者負(fù)面)央拖。

整個(gè)處理過程分為三個(gè)部分:

  • 準(zhǔn)備數(shù)據(jù):該教程使用的數(shù)據(jù)采用IMDB影評(píng)數(shù)據(jù)集祭阀,下載地址:http://ai.stanford.edu/~amaas/data/sentiment/ 鹉戚, 如果需要運(yùn)行該notebook你需要把下載之后的數(shù)據(jù)解壓之后放到 ./data/imdb目錄下。由于我們需要對(duì)評(píng)論中的單詞進(jìn)行詞嵌入處理专控,所以我們還需要用到預(yù)訓(xùn)練好的詞向量抹凳,這里我們不再自己去訓(xùn)練詞向量,而是直接采用GloVe伦腐,下載地址為:http://nlp.stanford.edu/data/glove.6B.zip赢底。 該文件解壓之后包含多個(gè)txt文件,多個(gè)txt文件的數(shù)據(jù)都是常用詞匯的詞向量柏蘑,只不過向量的維度不同幸冻,分為50、100辩越、200、300四種信粮,向量維度越高黔攒,詞向量的表達(dá)能力越強(qiáng)。你可以根據(jù)需要選擇一個(gè)文件使用强缘。將文件放到 ./data/glove目錄下面督惰。imdb和glove的下載完之后,我們需要將原始的文本數(shù)據(jù)經(jīng)過切詞旅掂、詞嵌入赏胚、對(duì)齊之后,保存為mindrecord格式商虐。

  • 模型訓(xùn)練:MindSpore為我們定義好了很多常用模型觉阅,我們可以直接從model_zoo中選擇基于LSTM實(shí)現(xiàn)的SentimentNet使用。

  • 模型評(píng)估:使用MindSpore定義好的接口可方便的對(duì)訓(xùn)練好的模型進(jìn)行評(píng)估秘车,比如準(zhǔn)確率等等典勇。

詳細(xì)的處理流程,可以參考下面的代碼叮趴。

準(zhǔn)備數(shù)據(jù)

import os
import math
from itertools import chain
import gensim
import numpy as np
from mindspore.mindrecord import FileWriter
def read_imdb(path, seg='train'):
    labels = ['pos', 'neg']
    data = []
    for label in labels:
        files = os.listdir(os.path.join(path, seg, label))
        for file in files:
            with open(os.path.join(path, seg, label, file), 'r', encoding='utf8') as rf:
                review = rf.read().replace('\n', '')
                if label == 'pos':
                    data.append([review, 1])
                elif label == 'neg':
                    data.append([review, 0])
    return data
def tokenize_samples(raw_data):
    tokenized_data = []
    for review in raw_data:
        tokenized_data.append([tok.lower() for tok in review.split()])
    return tokenized_data
def encode_samples(tokenized_samples, word_to_idx):
    """
    tokenized_samples: [[word, word, ...]]
    word_to_idx: {word:idx, word:idx, ...}
    features: [[idx, idx, ...], [idx, idx, ...], ...]
    """
    features = []
    for sample in tokenized_samples:
        feature = []
        for token in sample:
            feature.append(word_to_idx.get(token, 0))
        features.append(feature)
    return features
def pad_samples(features, maxlen=500, pad=0):
    padded_features = []
    for feature in features:
        if len(feature) >= maxlen:
            padded_feature = feature[:maxlen]
        else:
            padded_feature = feature
            while len(padded_feature) < maxlen:
                padded_feature.append(pad)
        padded_features.append(padded_feature)
    return padded_features
def prepare_data(imdb_data_path='./data/imdb/aclImdb'):      
    raw_data_train = read_imdb(imdb_data_path, seg='train')
    raw_data_test = read_imdb(imdb_data_path, seg='test')
    y_train = np.array([label for _, label in raw_data_train]).astype(np.int32)
    y_test = np.array([label for _, label in raw_data_test]).astype(np.int32)
    tokenized_data_train = tokenize_samples([review for review, _ in raw_data_train])
    tokenized_data_test = tokenize_samples([review for review, _ in raw_data_test])
    vocab = set(chain(*tokenized_data_train))
    word_to_idx = {word: i+1 for i, word in enumerate(vocab)}
    word_to_idx['<unk>'] = 0
    X_train = np.array(pad_samples(encode_samples(tokenized_data_train, word_to_idx))).astype(np.int32)
    X_test = np.array(pad_samples(encode_samples(tokenized_data_test, word_to_idx))).astype(np.int32)
    return X_train, y_train, X_test, y_test, word_to_idx
X_train, y_train, X_test, y_test, word_to_idx = prepare_data()
#!sed -i '1i\400000 50' ./data/glove/glove.6B.50d.txt
def load_embeddings(glove_file_path, word_to_idx, embed_size=50):
    word2vector = gensim.models.KeyedVectors.load_word2vec_format(
        glove_file_path, binary=False, encoding='utf-8')
    assert embed_size == word2vector.vector_size
    embeddings = np.zeros((len(word_to_idx), embed_size)).astype(np.float32)
    for word, idx in word_to_idx.items():
        try:
            embeddings[idx, :] = word2vector.word_vec(word)
        except KeyError:
            continue
    return embeddings
embeddings = load_embeddings('./data/glove/glove.6B.50d.txt', word_to_idx)
def get_json_data_list(X, y):
    data_list = []
    for i, (feature, label) in enumerate(zip(X, y)):
        data_json = {"id": i, "feature": feature.reshape(-1), "label": int(label)}
        data_list.append(data_json)
    return data_list
def convert_np_to_mindrecord(X_train, y_train, X_test, y_test, mindrecord_save_path="./data/mindrecord"):
    schema_json = {"id": {"type": "int32"},
                  "label": {"type": "int32"},
                  "feature": {"type": "int32", "shape": [-1]}}
    writer = FileWriter(os.path.join(mindrecord_save_path, "aclImdb_train.mindrecord"), shard_num=4)
    data_train = get_json_data_list(X_train, y_train)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data_train)
    writer.commit()
    
    writer = FileWriter(os.path.join(mindrecord_save_path, "aclImdb_test.mindrecord"), shard_num=4)
    data_test = get_json_data_list(X_test, y_test)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data_test)
    writer.commit()
!ls ./data/mindrecord
aclImdb_test.mindrecord0     aclImdb_train.mindrecord0
aclImdb_test.mindrecord0.db  aclImdb_train.mindrecord0.db
aclImdb_test.mindrecord1     aclImdb_train.mindrecord1
aclImdb_test.mindrecord1.db  aclImdb_train.mindrecord1.db
aclImdb_test.mindrecord2     aclImdb_train.mindrecord2
aclImdb_test.mindrecord2.db  aclImdb_train.mindrecord2.db
aclImdb_test.mindrecord3     aclImdb_train.mindrecord3
aclImdb_test.mindrecord3.db  aclImdb_train.mindrecord3.db
np.savetxt("./data/mindrecord/weight.txt", embeddings)
convert_np_to_mindrecord(X_train, y_train, X_test, y_test)

創(chuàng)建數(shù)據(jù)集

import mindspore.dataset as mds
def create_dataset(base_path, batch_size, num_epochs, is_train):
    columns_list = ["feature", "label"]
    num_consumer = 4
    if is_train:
        path = os.path.join(base_path, "aclImdb_train.mindrecord0")
    else:
        path = os.path.join(base_path, "aclImdb_test.mindrecord0")
    dataset = mds.MindDataset(path, columns_list=["feature", "label"], num_parallel_workers=4)
    dataset = dataset.shuffle(buffer_size=dataset.get_dataset_size())
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
    dataset = dataset.repeat(count=num_epochs)
    return dataset
dataset_train = create_dataset("./data/mindrecord", batch_size=32, num_epochs=10, is_train=True)

定義模型并訓(xùn)練

from mindspore import Tensor, nn, Model, context, Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.model_zoo.lstm import SentimentNet
embedding_tabel = np.loadtxt(os.path.join("./data/mindrecord", "weight.txt")).astype(np.float32)
network = SentimentNet(vocab_size=embedding_tabel.shape[0],
                embed_size=50,
                num_hiddens=100,
                num_layers=2,
                bidirectional=False,
                num_classes=2,
                weight=Tensor(embedding_tabel),
                batch_size=32)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = nn.Momentum(network.trainable_params(), 0.1, 0.9)
loss_callback = LossMonitor(per_print_times=60)
model = Model(network, loss, opt, {'acc': Accuracy()})
config_ck = CheckpointConfig(save_checkpoint_steps=390, keep_checkpoint_max=10)
checkpoint_cb = ModelCheckpoint(prefix="lstm", directory="./model", config=config_ck)
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
model.train(10, dataset_train, callbacks=[checkpoint_cb, loss_callback], dataset_sink_mode=False)

評(píng)估模型

dataset_test = create_dataset("./data/mindrecord", batch_size=32, num_epochs=10, is_train=False)
acc = model.eval(dataset_test)
print("accuracy:{}".format(acc))
accuracy:{'acc': 0.6604833546734955}

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末割笙,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子眯亦,更是在濱河造成了極大的恐慌伤溉,老刑警劉巖,帶你破解...
    沈念sama閱讀 221,888評(píng)論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件妻率,死亡現(xiàn)場(chǎng)離奇詭異乱顾,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)宫静,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,677評(píng)論 3 399
  • 文/潘曉璐 我一進(jìn)店門糯耍,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)扔字,“玉大人,你說我怎么就攤上這事温技「镂” “怎么了?”我有些...
    開封第一講書人閱讀 168,386評(píng)論 0 360
  • 文/不壞的土叔 我叫張陵舵鳞,是天一觀的道長(zhǎng)震檩。 經(jīng)常有香客問我,道長(zhǎng)蜓堕,這世上最難降的妖魔是什么抛虏? 我笑而不...
    開封第一講書人閱讀 59,726評(píng)論 1 297
  • 正文 為了忘掉前任,我火速辦了婚禮套才,結(jié)果婚禮上迂猴,老公的妹妹穿的比我還像新娘。我一直安慰自己背伴,他們只是感情好沸毁,可當(dāng)我...
    茶點(diǎn)故事閱讀 68,729評(píng)論 6 397
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著傻寂,像睡著了一般息尺。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上疾掰,一...
    開封第一講書人閱讀 52,337評(píng)論 1 310
  • 那天搂誉,我揣著相機(jī)與錄音,去河邊找鬼静檬。 笑死炭懊,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的拂檩。 我是一名探鬼主播凛虽,決...
    沈念sama閱讀 40,902評(píng)論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼广恢!你這毒婦竟也來(lái)了凯旋?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,807評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤钉迷,失蹤者是張志新(化名)和其女友劉穎至非,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體糠聪,經(jīng)...
    沈念sama閱讀 46,349評(píng)論 1 318
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡荒椭,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,439評(píng)論 3 340
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了舰蟆。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片趣惠。...
    茶點(diǎn)故事閱讀 40,567評(píng)論 1 352
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡狸棍,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出味悄,到底是詐尸還是另有隱情草戈,我是刑警寧澤,帶...
    沈念sama閱讀 36,242評(píng)論 5 350
  • 正文 年R本政府宣布侍瑟,位于F島的核電站唐片,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏涨颜。R本人自食惡果不足惜费韭,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,933評(píng)論 3 334
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望庭瑰。 院中可真熱鬧星持,春花似錦、人聲如沸弹灭。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,420評(píng)論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)鲤屡。三九已至损痰,卻和暖如春福侈,著一層夾襖步出監(jiān)牢的瞬間酒来,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,531評(píng)論 1 272
  • 我被黑心中介騙來(lái)泰國(guó)打工肪凛, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留堰汉,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,995評(píng)論 3 377
  • 正文 我出身青樓伟墙,卻偏偏與公主長(zhǎng)得像翘鸭,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子戳葵,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,585評(píng)論 2 359