本文代碼參考:
https://github.com/percent4/pytorch_transformer_chinese_text_classification
1. 數(shù)據(jù)集預處理
- 主要通過對數(shù)據(jù)集的預處理得到字庫與標簽庫锡垄,并把字庫與標簽分表序列化存儲到文件咆爽。
- 標簽存儲文件:
labels.pk
- 字庫存儲文件:
chars.pk
- 標簽存儲文件:
import os
import pickle
import pandas as pd
from random import shuffle
from operator import itemgetter
from collections import Counter, defaultdict
1.1. 文件保存的封裝實現(xiàn)
# pickle文件操作
class PickleFileOperator:
def __init__(self, data=None, file_path=''):
self.data = data
self.file_path = file_path
def save(self):
with open(self.file_path, 'wb') as f:
pickle.dump(self.data, f)
def read(self):
with open(self.file_path, "rb") as f:
content = pickle.load(f)
return content
1.2. 預處理數(shù)據(jù)集
- 通過對數(shù)據(jù)集的統(tǒng)計得到標簽庫與字庫。
(1)數(shù)據(jù)集預處理參數(shù)
# 數(shù)據(jù)集文件
DATASETS_DIR = "./datasets"
TRAIN_FILE_PATH = os.path.join(DATASETS_DIR, 'train.csv') # 訓練數(shù)據(jù)集
TEST_FILE_PATH = os.path.join(DATASETS_DIR, 'test.csv') # 測試數(shù)據(jù)集
# 字庫的最大數(shù)量
NUM_WORDS = 5500 # 用來限制字庫的最大容量疑故,
(2)數(shù)據(jù)集預處理
- 打開數(shù)據(jù)集,讀取標簽與內(nèi)容。
- 對標簽唯一化處理昧识,得到標簽庫
- 對內(nèi)容,統(tǒng)計得到字庫盗扒,字庫是打亂后跪楞,隨機取
NUM_WORDS
大的數(shù)量。這個操作是認為產(chǎn)生一些不在統(tǒng)計范圍的字侣灶,降低模型訓練的擬合性甸祭。
class FilePreprossing(object):
def __init__(self, n):
# 保留前n個高頻字
self.__n = n
def _read_train_file(self):
train_pd = pd.read_csv(TRAIN_FILE_PATH)
label_list = train_pd['label'].unique().tolist()
# 統(tǒng)計文字頻數(shù)
character_dict = defaultdict(int)
for content in train_pd['content']:
for key, value in Counter(content).items():
character_dict[key] += value
# 不排序
sort_char_list = [(k, v) for k, v in character_dict.items()]
shuffle(sort_char_list)
# 排序
# sort_char_list = sorted(character_dict.items(), key=itemgetter(1), reverse=True)
print(f'數(shù)據(jù)集共計 {len(character_dict)} 漢字.')
print('隨機打亂后,前10個字的統(tǒng)計: ', sort_char_list[:10])
# 保留前n個文字
top_n_chars = [_[0] for _ in sort_char_list[:self.__n]] # 這里只保留了前n個字褥影。(注意:這里對漢字沒有采用分詞處理淋叶,而是直接處理字)
print("最終字庫的總數(shù):", len(top_n_chars))
return label_list, top_n_chars
def run(self):
label_list, top_n_chars = self._read_train_file()
PickleFileOperator(data=label_list, file_path='labels.pk').save()
PickleFileOperator(data=top_n_chars, file_path='chars.pk').save()
(3)執(zhí)行數(shù)據(jù)集預處理
processor = FilePreprossing(NUM_WORDS)
processor.run()
數(shù)據(jù)集共計 5259 漢字.
隨機打亂后,前10個字的統(tǒng)計: [('蠣', 3), ('娼', 9), ('Ⅱ', 43), ('識', 2534), ('釣', 36), ('座', 825), ('晁', 6), ('按', 1800), ('迢', 2), ('伎', 8)]
最終字庫的總數(shù): 5259
labels = PickleFileOperator(file_path='labels.pk').read()
print("標簽:", labels)
content = PickleFileOperator(file_path='chars.pk').read()
print("字庫的前10個字(隨機打亂的伪阶,沒有按照統(tǒng)計數(shù)量排序):", content[:10])
標簽: ['體育', '健康', '軍事', '教育', '汽車']
字庫的前10個字(隨機打亂的煞檩,沒有按照統(tǒng)計數(shù)量排序): ['蠣', '娼', 'Ⅱ', '識', '釣', '座', '晁', '按', '迢', '伎']
2. 數(shù)據(jù)集特征處理(數(shù)據(jù)集工程)
- 把數(shù)據(jù)集處理成在模型可以使用的格式:向量。
- 模型訓練
- 模型測試
- 模型評估
- 模型推理
- 在PyTorch中需要處理成Dataset與DataLoader栅贴。
- Dataset是數(shù)據(jù)集格式斟湃。
- DataLoader是批次數(shù)據(jù)集格式。
- 數(shù)據(jù)集特征處理主要是文本向量化檐薯,向量化技術(shù)很多凝赛,這里采用的向量化方式:
- 對字庫編號,使用編號代替字坛缕,實現(xiàn)數(shù)值化墓猎,
- 文本字符串就可以輕松轉(zhuǎn)化為向量。
- 為了保證向量的維度形狀一致赚楚,對每個句子都進行了對齊處理:
- 指定一個對齊長度SENT_LENGTH
- 大于SENT_LENGTH的句子截斷處理
- 小于SENT_LENGTH長度的進行補齊毙沾,補齊的字符統(tǒng)一采用PAD定義字符替代,補齊字符的編號采用PAD_NO定義的編號宠页,一般用0
- 如果碰見字庫中沒有的字左胞,則使用UNK定義的字符替代寇仓。UNK的字符編號采用UNK_NO定義的編號,一般使用1烤宙。
import pandas as pd
import numpy as np
import torch as T
from torch.utils.data import Dataset, random_split
2.1. 讀取標簽庫與字庫
- 把上面與處理的字庫與標簽數(shù)值化:
- 使用順序編號遍烦,把編號與字與標簽對應起來。
- 采用字典存儲標簽與編碼的對應關(guān)系躺枕,存儲字與編號對應關(guān)系服猪。
- 后面可以通過字典把預測的編號還原為字與文本標簽。
# 讀取pickle文件
def load_file_file():
labels = PickleFileOperator(file_path='labels.pk').read()
chars = PickleFileOperator(file_path='chars.pk').read()
label_dict = dict(zip(labels, range(len(labels))))
char_dict = dict(zip(chars, range(len(chars))))
return label_dict, char_dict
l_d, c_d = load_file_file()
print(l_d)
print(list(c_d.items())[:5])
{'體育': 0, '健康': 1, '軍事': 2, '教育': 3, '汽車': 4}
[('蠣', 0), ('娼', 1), ('Ⅱ', 2), ('識', 3), ('釣', 4)]
2.2. 讀取數(shù)據(jù)集樣本與標簽
- 讀取csv文件拐云,一行一個樣本蔓姚。
# load csv file
def load_csv_file(file_path):
df = pd.read_csv(file_path)
samples, y_true = [], []
for index, row in df.iterrows():
y_true.append(row['label'])
samples.append(row['content'])
return samples, y_true
s, l = load_csv_file(TRAIN_FILE_PATH)
print(s[:2])
print(l[:2])
['中國“鐵腰”與英超球隊埃弗頓分道揚鑣,閃電般轉(zhuǎn)投謝聯(lián)(本賽季成功升入英超)慨丐,此事運作速度之快令人驚詫坡脐。針對李鐵與埃弗頓“分手”的原因、與埃弗頓主帥莫耶斯矛盾以及鐵子為何選擇謝聯(lián)房揭,記者昨日采訪了李鐵的母親王桂芹备闲,李母道出了李鐵與埃弗頓分開的真實原因。龍菲堅決讓鐵子走人記者在采訪王桂芹時了解到捅暴,李鐵離開埃弗頓主要是妻子龍菲建議恬砂。龍菲平時不太過問李鐵的足球方面事宜,但是蓬痒,因為李鐵長時間不能在埃弗頓踢上球泻骤,龍菲也十分焦急。多次安慰李鐵后梧奢,龍菲想這樣下去也不是個辦法狱掂,于是索性做出決定,讓李鐵離開埃弗頓亲轨,只要能踢上球趋惨,去哪支球隊都行。但前提條件必須是英國的球隊惦蚊。王媽媽告訴記者:“媳婦龍菲一直在英國學習器虾,這孩子特別懂事,一邊學習蹦锋,一邊還要照顧女兒和李鐵的日常生活兆沙。對于李鐵與埃弗頓的前前后后,龍菲一直都了解內(nèi)情莉掂,因此龍菲最后告訴李鐵葛圃,就是埃弗頓再請我們,我們也不去了,只要能離開埃弗頓装悲,去哪支球隊踢球都行昏鹃∩蟹眨”據(jù)悉诀诊,龍菲2001年便在沈陽拿到了留學英國全額獎學金的錄取通知書,而后阅嘶,龍菲便一直在英國求學属瓣。紅牌讓李鐵失去位置“拼命三郎”、“跑不死”讯柔、“體能王”抡蛙,這些溢美之詞都是稱贊李鐵的,不過魂迄,正是因為李鐵防守時的動作過于兇狠粗截,在英超的賽場上屢次領(lǐng)到紅牌。過多的紅牌讓主帥莫耶斯逐漸對李鐵失去了興趣捣炬。對此熊昌,王媽媽向記者表示:“關(guān)于球隊的相關(guān)事宜,我不太清楚湿酸。因為埃弗頓主教練莫耶斯一直都很器重李鐵婿屹,因為莫耶斯的戰(zhàn)術(shù)比較偏重防守。而在防守過程中推溃,李鐵也確實吃到過紅牌昂利,但是,我覺得教練組和俱樂部不能因為紅牌的原因不讓李鐵上場吧铁坎》浼椋”許宏濤挽救李鐵李鐵成功轉(zhuǎn)會埃弗頓的整個過程,國內(nèi)足球著名的經(jīng)紀人許宏濤功不可沒硬萍。而李鐵在埃弗頓后期四處碰壁的危難時刻窝撵,正是許宏濤的左右逢源,令李鐵還能堅持在埃弗頓預備隊踢球襟铭。后來碌奉,在埃弗頓摒棄李鐵后,也正是許宏濤的人脈關(guān)系讓李鐵再次找到了位置寒砖。根據(jù)李鐵與埃弗頓簽訂的合同赐劣,今年6月30日,工作合同才到期哩都。但是魁兼,考慮到李鐵與俱樂部的關(guān)系日益緊張,許宏濤便提早聯(lián)系了英超其他球隊漠嵌,重點便是當時的英甲球隊謝聯(lián)咐汞。由于許宏濤是謝聯(lián)董事會的成員盖呼,多次與俱樂部溝通李鐵的事宜后,俱樂部最終同意了李鐵加盟謝聯(lián)化撕。對李鐵加盟謝聯(lián)的事情几晤,王桂芹不愿多談。她只是表示:“許宏濤一直在幫助李鐵植阴,特別是在英國蟹瘾,許宏濤非常熟悉那里的環(huán)境,李鐵也非常信任他掠手,這下也好憾朴,李鐵可以同郝海東一塊踢球了∨绺耄”首席記者賈瓊', '拉齊奧獲不利排序意甲本周末拉齊奧與帕爾馬之戰(zhàn)為收官階段表現(xiàn)較為突出的兩支球隊之間的較量众雷,兩隊在最近10場比賽中均取得了其中6戰(zhàn)的勝利,主隊因此提前鎖定了聯(lián)盟杯的參賽資格做祝,客隊更是借此早早就擺脫了賽季中段的降組威脅砾省。目前本場比賽的賠率為主隊優(yōu)勢的正向賠率,主勝賠率在全部42個賠率中排在了第4位剖淀,而此前的5輪競猜中纯蛾,該點位全部正路開出,已經(jīng)達到了本賽季的最大值纵隔,本周末該結(jié)果極易走冷衅谷,投注時建議一搏冷門賽果铸本。沙爾克得大莊信任近來狀態(tài)不佳的沙爾克04本周末將在主場對陣斯圖加特,目前本場比賽的主勝賠率處于1.70的賠率區(qū)間內(nèi),此類賠率在過去5個賽季的德甲聯(lián)賽中出現(xiàn)次數(shù)較多屯吊,其戰(zhàn)果統(tǒng)計也相對較為正路吟榴,屬于純實力對比賠率塘幅。不過在對本場比賽的賠率進行比較后可發(fā)現(xiàn)骤竹,奧地利著名博彩公司必贏為此戰(zhàn)開出的1.65-3.35-5.10賠率組合對主隊十分有利,主勝位明顯低于目前的平均賠率俄认,可見該公司極為看好沙爾克周末取勝个少,投注時可一搏主勝。藍黑軍團慎防大冷在聯(lián)賽后半程排名已經(jīng)敲定的情況下眯杏,國際米蘭果然如人們所預料那樣又一次出現(xiàn)了內(nèi)耗問題夜焦,內(nèi)部的爭斗也使得曼齊尼的球隊在近幾輪比賽中連嘗苦果,而且他們在周中又要接受意大利杯的考驗岂贩,這對他們周末客場對陣卡利亞里的比賽必然要造成一定的影響茫经。此外,我們在上期曾重點提示過,足彩競猜第1場位的賽果已連續(xù)5輪正路開出卸伞,而且也13輪沒有出現(xiàn)過賠率末選抹镊,兩項均達到了本賽季的極限值,周末卡利亞里很有可能爆大冷擊敗藍黑軍團荤傲。本人心水(256元)300311033010103311330']
['體育', '體育']
2.3. 樣本特征處理:句子向量化
- 使用上面標簽庫與字庫的編號字段垮耳,把每個樣本的句子數(shù)值化。
(1)文本向量化處理的參數(shù)
PAD = '<PAD>'
PAD_NO = 0
UNK = '<UNK>'
UNK_NO = 1
START_NO = UNK_NO + 1
SENT_LENGTH = 200
(2) 文本向量化
- 對句子進行長度對齊弃酌,并根據(jù)編號字典氨菇,對句子數(shù)值化儡炼。
# 文本預處理
def text_feature(labels, contents, label_dict, char_dict):
samples, y_true = [], []
for s_label, s_content in zip(labels, contents):
y_true.append(label_dict[s_label])
train_sample = []
for char in s_content:
if char in char_dict:
train_sample.append(START_NO + char_dict[char])
else:
train_sample.append(UNK_NO)
# 補充或截斷
if len(train_sample) < SENT_LENGTH:
samples.append(train_sample + ([PAD_NO] * (SENT_LENGTH - len(train_sample))))
else:
samples.append(train_sample[:SENT_LENGTH])
return samples, y_true
digit_s, digit_l = text_feature(l, s, l_d, c_d)
print("對齊后妓湘,并數(shù)值化的句子:\n", digit_s[:2])
print("數(shù)值化后的標簽:\n", digit_l[:2])
對齊后,并數(shù)值化的句子:
[[3436, 3896, 5108, 4825, 2824, 2987, 3767, 1156, 121, 226, 1544, 4917, 1991, 2640, 267, 2148, 4616, 25, 605, 36, 2592, 1219, 1025, 975, 4833, 3894, 4556, 3815, 2014, 713, 5133, 2235, 172, 2819, 1156, 121, 1746, 605, 2411, 2527, 2929, 757, 1733, 3054, 4573, 2902, 3486, 4517, 4907, 1868, 417, 963, 4219, 1116, 4825, 3767, 4917, 1991, 2640, 5108, 267, 2916, 2987, 3050, 3417, 3620, 4894, 3767, 4917, 1991, 2640, 2388, 2364, 1216, 93, 1603, 4283, 2074, 2368, 1885, 4825, 308, 5105, 969, 5070, 779, 4833, 3894, 605, 1037, 408, 2517, 4988, 1988, 489, 1142, 1116, 4825, 3050, 5101, 3624, 1541, 1173, 5212, 605, 1116, 5101, 2148, 4208, 1142, 1116, 4825, 3767, 4917, 1991, 2640, 267, 3626, 3050, 2236, 4003, 3417, 3620, 417, 4191, 3597, 4010, 4532, 4419, 4825, 308, 4398, 4517, 1037, 408, 625, 1988, 489, 1541, 1173, 5212, 3105, 1142, 227, 3748, 605, 1116, 4825, 2574, 3626, 4917, 1991, 2640, 2388, 1118, 1066, 1702, 308, 4191, 3597, 4553, 4950, 417, 4191, 3597, 159, 3105, 1386, 1605, 1106, 5085, 1116, 4825, 3050, 2415, 226, 2498, 149, 2527, 4718, 605, 4852, 1066, 605, 3620, 5105, 1116, 4825, 2522, 3105, 4471, 1386, 2356, 625, 4917, 1991, 2640, 4362, 4266, 226], [880, 4522, 3895, 4507, 1386, 2155, 2675, 2259, 2875, 4366, 3815, 568, 2924, 880, 4522, 3895, 3767, 4426, 3704, 491, 4573, 1513, 5105, 3076, 548, 4569, 2430, 4964, 186, 3318, 5105, 542, 4208, 3050, 2180, 5128, 226, 1544, 4573, 4471, 3050, 3318, 563, 605, 2180, 1544, 625, 125, 1826, 3401, 758, 23, 216, 2014, 3436, 361, 4445, 3985, 1142, 3407, 3436, 3121, 1513, 3050, 1133, 2155, 605, 2388, 1544, 3620, 2411, 187, 4511, 2903, 2269, 1142, 3894, 2379, 3180, 3050, 3760, 2014, 1726, 2648, 605, 4584, 1544, 3667, 1066, 4712, 2411, 3576, 3576, 3963, 1479, 1496, 1142, 2014, 713, 3436, 2430, 3050, 731, 4837, 2606, 3613, 417, 41, 4511, 3815, 23, 216, 2014, 3050, 1922, 1317, 5105, 2388, 1544, 2158, 4233, 3050, 1026, 3863, 1922, 1317, 605, 2388, 1133, 1922, 1317, 625, 1866, 4056, 4140, 194, 2725, 1922, 1317, 3436, 2675, 625, 1142, 20, 4140, 196, 605, 3657, 2411, 4511, 3050, 1788, 4187, 3850, 993, 3436, 605, 3084, 1228, 196, 1866, 4056, 1026, 5071, 3626, 4208, 605, 200, 2424, 1209, 3748, 1142, 3815, 2014, 713, 3050, 125, 3107, 1931, 605, 3815, 568, 2924, 3084, 3580, 3047, 327, 5162, 4398, 2087, 605, 975, 4531, 3105, 4553, 4950, 4804, 3096, 2087, 4535]]
數(shù)值化后的標簽:
[0, 0]
2.4 生成PyTorch的數(shù)據(jù)集格式
- 因為我們使用PyTorch乌询,所以采用PyTorch的Dataset實現(xiàn)數(shù)據(jù)集榜贴。方便后面訓練,驗證妹田,測試使用唬党。
# Dataset類實現(xiàn)
class CSVDataset(Dataset):
# load the dataset
def __init__(self, file_path):
label_dict, char_dict = load_file_file() # 讀取標簽庫與字庫
samples, y_true = load_csv_file(file_path) # 加載數(shù)據(jù)集樣本
x, y = text_feature(y_true, samples, label_dict, char_dict)
# 轉(zhuǎn)換為張量
self.X = T.from_numpy(np.array(x)).long()
self.y = T.from_numpy(np.array(y))
# 數(shù)據(jù)集樣本數(shù)
def __len__(self):
return len(self.X)
# 返回指定索引的數(shù)據(jù)樣本與標簽,這是下標運算符鬼佣。
def __getitem__(self, idx):
return [self.X[idx], self.y[idx]]
# 根據(jù)比例把數(shù)據(jù)集分成訓練集與測試集驶拱。
def get_splits(self, n_test=0.3):
# determine sizes
test_size = round(n_test * len(self.X))
train_size = len(self.X) - test_size
# calculate the split
return random_split(self, [train_size, test_size])
ds = CSVDataset(TRAIN_FILE_PATH)
print(ds[0])
[tensor([3436, 3896, 5108, 4825, 2824, 2987, 3767, 1156, 121, 226, 1544, 4917,
1991, 2640, 267, 2148, 4616, 25, 605, 36, 2592, 1219, 1025, 975,
4833, 3894, 4556, 3815, 2014, 713, 5133, 2235, 172, 2819, 1156, 121,
1746, 605, 2411, 2527, 2929, 757, 1733, 3054, 4573, 2902, 3486, 4517,
4907, 1868, 417, 963, 4219, 1116, 4825, 3767, 4917, 1991, 2640, 5108,
267, 2916, 2987, 3050, 3417, 3620, 4894, 3767, 4917, 1991, 2640, 2388,
2364, 1216, 93, 1603, 4283, 2074, 2368, 1885, 4825, 308, 5105, 969,
5070, 779, 4833, 3894, 605, 1037, 408, 2517, 4988, 1988, 489, 1142,
1116, 4825, 3050, 5101, 3624, 1541, 1173, 5212, 605, 1116, 5101, 2148,
4208, 1142, 1116, 4825, 3767, 4917, 1991, 2640, 267, 3626, 3050, 2236,
4003, 3417, 3620, 417, 4191, 3597, 4010, 4532, 4419, 4825, 308, 4398,
4517, 1037, 408, 625, 1988, 489, 1541, 1173, 5212, 3105, 1142, 227,
3748, 605, 1116, 4825, 2574, 3626, 4917, 1991, 2640, 2388, 1118, 1066,
1702, 308, 4191, 3597, 4553, 4950, 417, 4191, 3597, 159, 3105, 1386,
1605, 1106, 5085, 1116, 4825, 3050, 2415, 226, 2498, 149, 2527, 4718,
605, 4852, 1066, 605, 3620, 5105, 1116, 4825, 2522, 3105, 4471, 1386,
2356, 625, 4917, 1991, 2640, 4362, 4266, 226]), tensor(0, dtype=torch.int32)]
3. 詞嵌入處理
- 一般會直接使用詞嵌入,但是這里對漢字使用預訓練的詞嵌入方式晶衷,對句子進行特征向量化處理蓝纲,可以確保模型訓練效果更好。
- 這里的預訓練模型晌纫,采用維基百科中語料庫訓練的詞向量模型:
sgns.wiki.char.bz2
- 每個字都需要轉(zhuǎn)換為向量税迷。
- 這里的預訓練模型晌纫,采用維基百科中語料庫訓練的詞向量模型:
- 下面使用上面字庫中的字,查詢已經(jīng)預訓練的詞向量模型中訓練的向量锹漱,得到滿足我們這里使用的詞向量
import torch
from gensim.models import KeyedVectors
# 讀取標簽庫與字庫
label_dict, char_dict = load_file_file()
# 加載預訓練的詞向量模型
em_model = KeyedVectors.load_word2vec_format('./datasets/sgns.wiki.char.bz2',
binary=False,
encoding="utf-8",
unicode_errors="ignore")
# 使用gensim載入word2vec詞向量
"""
4是考慮未來加入四個特殊字符:<PAD>箭养,<UNK>,<START>哥牍,<END>
300是預訓練的時候就設(shè)置為300:具體可以參考:https://github.com/Embedding/Chinese-Word-Vectors?tab=readme-ov-file
實際這里數(shù)據(jù)集統(tǒng)計的字庫沒有5500毕泌,下面pretrained_vector的后面行都是0。
"""
pretrained_vector = torch.zeros(NUM_WORDS + 4, 300).float() # 存放字庫中每個字的詞向量
# print(model.index2word)
for char, index in char_dict.items():
if char in em_model.key_to_index:
# 把字轉(zhuǎn)換為向量
vector = em_model.get_vector(char)
# print(vector)
pretrained_vector[index, :] = torch.from_numpy(vector.copy()) # 使用copy是因為get_vector返回的numpy數(shù)組是不可寫的嗅辣。不加會有警告
print(vector.flags['WRITEABLE'])
print(vector.copy().flags['WRITEABLE'])
False
True
pretrained_vector[-1]
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
pretrained_vector[0]
tensor([-2.5166e-01, 2.1260e-03, -7.5505e-01, -5.7399e-02, -4.0988e-02,
-2.5291e-01, 5.3310e-02, -1.3894e-02, 4.3891e-01, -3.8147e-02,
-2.0128e-01, -5.9504e-02, 4.8097e-02, 1.0573e-01, -6.1304e-02,
-2.1859e-01, 4.8095e-01, -2.3189e-01, -4.5559e-01, 3.5048e-01,
2.8622e-01, 1.5197e-01, -4.5313e-02, -5.5626e-02, -8.5551e-02,
-6.2766e-02, 9.7919e-02, 3.8548e-01, 1.8273e-01, 5.4902e-02,
-4.4603e-02, -2.7428e-01, 1.7315e-02, 4.8173e-02, -1.0188e-02,
-1.1564e-01, 1.3562e-01, -8.5669e-02, -1.2031e-01, 3.4792e-01,
1.6377e-01, 1.7365e-01, 4.0493e-01, -2.6211e-01, -2.8300e-01,
-2.4447e-02, -1.7962e-01, -9.1980e-03, 2.4517e-01, 1.4564e-01,
1.8893e-01, 6.1344e-01, 1.4634e-01, -3.7221e-01, 1.3984e-01,
-1.6315e-01, 1.7710e-02, -2.2459e-01, 3.1234e-02, -2.7014e-01,
-2.1277e-01, -9.9185e-02, 1.1965e-01, -3.5157e-02, -3.1400e-04,
1.0341e-01, -4.5918e-01, -1.3590e-01, -1.9133e-01, -2.5318e-01,
1.6349e-01, -3.9125e-01, -7.9844e-02, -3.9014e-02, 4.2832e-01,
5.4695e-02, -3.2569e-01, -1.0863e-01, 1.0948e-01, 5.0902e-01,
6.1270e-01, -1.9650e-01, -4.1784e-02, 9.8486e-02, -2.8914e-01,
1.1830e-01, 2.1662e-01, -2.6285e-01, 2.2754e-01, -3.3230e-01,
-2.9382e-01, -2.1537e-01, -5.5550e-01, -3.0106e-02, 6.3398e-02,
8.9900e-03, 2.1025e-01, 1.2269e-01, -2.0311e-01, -3.8709e-01,
-3.1572e-01, 3.0690e-01, -1.9227e-01, -3.4366e-01, 7.7620e-02,
4.7994e-01, 2.4226e-01, -1.0725e-01, -1.0820e-01, 1.4680e-01,
-1.6433e-01, 1.7356e-01, 1.6682e-02, 2.3170e-01, -1.2936e-01,
1.8013e-01, -1.2464e-01, -2.2828e-01, -2.3223e-01, -5.9250e-03,
2.3588e-01, -3.7569e-01, 5.0721e-01, 7.0246e-01, -2.6877e-01,
2.8580e-03, -5.8815e-01, -2.3668e-01, 1.0971e-01, -8.2170e-03,
1.2551e-01, 1.8670e-02, 4.5151e-01, 6.8175e-02, 2.0498e-01,
4.5140e-01, 5.1324e-01, -4.8228e-02, -1.7520e-03, 6.9598e-02,
-4.7379e-02, -2.3501e-01, -4.1574e-01, 1.1202e-01, -4.1136e-01,
-2.2400e-01, -1.1157e-01, 3.9643e-01, 1.7197e-01, -7.1166e-02,
2.2666e-01, 4.9972e-01, -5.9917e-01, -5.2575e-01, -3.8444e-01,
2.9197e-01, -2.6319e-01, -2.6827e-01, -1.7151e-01, -3.2219e-01,
-1.5482e-01, 4.4596e-01, 1.1041e-01, 3.2358e-01, 1.1809e-01,
7.4830e-03, 3.9770e-01, 2.3340e-01, 6.3971e-01, -7.0496e-01,
-1.2747e-01, 1.5125e-01, 2.0257e-01, 2.9059e-01, 5.4421e-02,
-5.9573e-01, 1.8627e-02, -2.0663e-01, 2.4536e-01, -3.1686e-01,
1.5185e-01, 3.5283e-02, -2.4756e-01, 2.7790e-01, -1.1016e-01,
-1.4018e-01, 2.4151e-01, -7.5792e-02, -4.4470e-01, -3.0382e-01,
8.3656e-02, -1.0520e-01, -6.6970e-03, 2.0030e-01, -2.7011e-01,
1.0509e-01, 2.1204e-01, 1.9944e-01, -2.2444e-01, -1.9029e-01,
-3.3236e-01, -7.9911e-02, -3.7321e-01, 9.8192e-02, -1.9179e-01,
2.6793e-01, 4.5805e-01, -2.5262e-01, -1.1888e-01, -2.9169e-01,
2.9650e-01, 4.0774e-01, -1.3908e-01, 1.6033e-01, -4.0140e-02,
-3.6502e-01, 2.9890e-01, 6.8221e-01, -4.8779e-01, 2.5828e-01,
-2.7593e-01, -1.2254e-01, -3.9470e-02, -2.1260e-01, -2.3199e-02,
-2.7077e-01, 3.8680e-02, -1.8343e-01, -2.1692e-02, -2.4166e-01,
1.1560e-01, 8.0079e-02, 2.1750e-03, -1.9942e-02, -2.9017e-01,
-2.7840e-01, 2.2855e-01, -3.2480e-01, -2.2139e-01, 1.9187e-01,
-4.6475e-01, 5.2336e-01, 5.4522e-01, -1.0142e-01, -3.1336e-01,
1.4690e-01, 8.9748e-02, 2.2159e-01, -7.5918e-01, -2.7461e-01,
-8.2008e-02, -3.2914e-01, -2.8129e-01, 3.4548e-01, 2.3467e-01,
-5.6391e-02, 1.4375e-02, 3.8655e-01, -2.0344e-01, -2.4192e-01,
5.5580e-01, -3.3075e-01, 2.6455e-01, 3.5124e-01, 2.4330e-01,
1.5741e-01, 1.0453e-02, 8.6976e-02, 3.3163e-01, 2.9760e-01,
5.1001e-02, 7.9290e-02, 4.2176e-01, 3.4901e-02, 6.3282e-01,
-3.1701e-01, 3.7667e-01, -1.0663e-01, -2.6375e-01, -5.9062e-01,
2.2802e-01, 1.2913e-01, 5.9333e-01, -1.1817e-02, -1.9145e-02,
9.6389e-02, 2.0213e-01, 2.5641e-01, 5.0276e-01, 3.5181e-02,
-3.3445e-01, -5.2460e-03, -1.3024e-01, 3.0163e-01, -2.7992e-01,
-2.3243e-01, -1.5426e-01, -4.0426e-01, -1.8360e-02, 8.2140e-03])
- 下面簡單演示下撼泛,詞嵌入向量的使用
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# 詞嵌入對象
emb = nn.Embedding.from_pretrained(pretrained_vector, freeze=False, padding_idx=0)
# 對數(shù)據(jù)集進行批次處理
train_dl = DataLoader(ds, batch_size=5, shuffle=True) # 每個批次5個樣本
for x_batch, y_batch in train_dl:
v_emb = emb(x_batch)
print(v_emb.shape)
break
torch.Size([5, 200, 300])
-
詞向量化后,一個句子中的每個詞都轉(zhuǎn)化為一個向量辩诞。
- 5是一個批次中的樣本數(shù):5個句子
- 200是句子長度
- 300是預訓練的詞向量的特征維度坎弯。就是每個詞使用300長的向量表示其特征。
因為采用批次的方式,所以每個句子需要補齊抠忘。這樣才能滿足矩陣運算中對形狀的要求撩炊。
4. Transformer模型實現(xiàn)
- 這里的模型沒有使用PyTroch進行原生實現(xiàn),而是利用PyTorch的封裝實現(xiàn):
- TransformerEncoderLayer:編碼器單元
- TransformerEncoder:編碼器
- 因為位置編碼在PyTorch中沒有實現(xiàn)崎脉,需要自己實現(xiàn)拧咳。
4.1. 位置編碼
-
位置編碼的計算公式如下:
- 偶位置:
- 奇位置:
-
參數(shù)解釋:
- 表示單子在句子中位置
- 表示位置編碼的維度,這個維度必須與詞嵌入的維度一直囚灼。在上面采用的額是預訓練的維度:300骆膝。
-
表示偶數(shù)維度,表示奇數(shù)維度灶体。
- 阅签,,
-
下面的實現(xiàn)來自Pytorch官方文檔:
https://pytorch.org/tutorials/beginner/transformer_tutorial.html
- 關(guān)于位置編碼實際有個發(fā)展過程蝎抽,Pytorch官方文檔的實現(xiàn)與上面原始論文中提出的計算公式不一樣政钟,有微小的變化。
- 這里不糾結(jié)位置編碼的具體計算公式樟结,后面會單獨說明养交。
class PositionalEncoding(nn.Module):
def __init__(self, d_model, vocab_size=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(vocab_size, d_model)
# torch.arange(0, vocab_size, dtype=torch.float):生成0-vocab_size的張量,shape=(vocab_size,)
# unsqueeze(1):增加1維瓢宦,變成2維張量碎连。2維張量的shape=(vocab_size, 1)
position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)
# exp:自然指數(shù)運算
div_term = torch.exp(
torch.arange(0, d_model, 2).float()
* (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # 增加1維,vocab_size所在維變成第二維
self.register_buffer("pe", pe)
def forward(self, x):
"""
X的第一維表示批次數(shù)驮履,每行是一個樣本鱼辙。
位置編碼對每個句子處理一樣。所以X的第一維是n=5疲吸,PE第一維是1
"""
# print("x的形狀", x.shape)
# print("PE的形狀", self.pe.shape)
x = x + self.pe[:, : x.size(1), :] # : x.size(1)限制與x的句子長度一致座每。
# print("截斷后的維數(shù):", self.pe[:, : x.size(1), :].shape)
return self.dropout(x)
4.2. Transformer分類器
- Pytorch已經(jīng)實現(xiàn)編碼器:
- 編碼單元:TransformerEncoderLayer
- 編碼器:編碼器
- 分類器使用Pytorch的邏輯回歸:
- 全連接層,加一個sigmoid運算摘悴,實際這里使用的是softmax函數(shù)峭梳。
EMBEDDING_SIZE = 300
class TextClassifier(nn.Module):
def __init__(
self,
nhead=8, # 多頭自注意力的多頭個數(shù)
dim_feedforward=2048, # 前饋網(wǎng)絡的大小
num_layers=6, # 編碼器中編碼單元的個數(shù)
dropout=0.1,
activation="relu", # 激活函數(shù)
classifier_dropout=0.1):
super().__init__()
vocab_size = NUM_WORDS + 2 # 這個大小不影響運算,實際不同的語料庫蹂喻,計算的vocab_size也不一樣葱椭。
d_model = EMBEDDING_SIZE
# vocab_size, d_model = embeddings.size()
assert d_model % nhead == 0, "nheads 必須整除 d_model"
# Embedding layer definition
# self.emb = nn.Embedding(vocab_size, d_model, padding_idx=0)
# 詞嵌入對象:使用預訓練模型
self.emb = nn.Embedding.from_pretrained(pretrained_vector, freeze=False, padding_idx=0)
# 位置編碼器
self.pos_encoder = PositionalEncoding(
d_model=d_model,
dropout=dropout,
vocab_size=vocab_size
)
# 編碼單元
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=True # 提高性能,否則會出現(xiàn)警告
)
# 編碼器
self.transformer_encoder = nn.TransformerEncoder(
encoder_layer,
num_layers=num_layers
)
# 分類器:5是最后分類的類別數(shù)口四,這里采用一層分類
self.classifier = nn.Linear(d_model, 5)
self.d_model = d_model
def forward(self, x):
# 詞嵌入運算
x = self.emb(x) * math.sqrt(self.d_model) # 對詞嵌入向量做了額外的scaled計算孵运,方式梯度消失
# 位置編碼運算
x = self.pos_encoder(x)
# 編碼器處理
x = self.transformer_encoder(x)
# 使用均值降維
x = x.mean(dim=1)
# 分類計算
x = self.classifier(x)
# 這里沒有直接轉(zhuǎn)換為概率softmax運算,這個對訓練沒有影響蔓彩,主要在分類方便治笨。
return x
- 下面是測試模型的運算:
- 沒有訓練過的模型驳概,只是分類效果差而已,實際已經(jīng)可以使用了旷赖。
model = TextClassifier(
nhead=10, # 多頭數(shù)量顺又,記得與d_model有整除關(guān)系
dim_feedforward=128, # 前饋全連接神經(jīng)網(wǎng)絡的維度
num_layers=1, # 編碼器層數(shù)
dropout=0.0,
classifier_dropout=0.0)
# 上面的批次是5,分類標簽個數(shù)是5等孵,輸出結(jié)果是沒有經(jīng)過概率化的稚照,就是sigmoid或者softmax運算
for x_batch, y_batch in train_dl:
y_ = model(x_batch)
print(y_.shape)
print(y_) # 概率化后,概率最大的下標就是分類的標簽編號俯萌。
break
torch.Size([5, 5])
tensor([[-0.2383, 0.3347, 0.2001, -0.1349, 0.1200],
[-0.1407, 0.3566, 0.3740, -0.2075, 0.0439],
[-0.1911, 0.2526, 0.2672, -0.3290, -0.0506],
[-0.2626, 0.1256, 0.3350, -0.3126, 0.2457],
[-0.1740, 0.2923, 0.2634, -0.2818, -0.0333]],
grad_fn=<AddmmBackward0>)
5. Transformer訓練實現(xiàn)
- Transfoemer的實現(xiàn)與一般深度學習神經(jīng)網(wǎng)絡的實現(xiàn)一樣:
- 對訓練樣本進行迭代開始訓練
- 調(diào)用模型果录,計算模型輸出
- 使用模型輸出與已知標簽計算誤差
- 對誤差求導,得到更新值
- 反向更新所有模型參數(shù)
- 可選:使用模型預測咐熙,并統(tǒng)計分類準確率(評估)
- 繼續(xù)下一次訓練弱恒。
import torch
from torch.optim import Adam # 優(yōu)化器
from torch.nn import CrossEntropyLoss, Softmax # 損失函數(shù),與概率轉(zhuǎn)化函數(shù)
from torch.utils.data import DataLoader # 批次數(shù)據(jù)集
from numpy import vstack, argmax # argmax是預測的常用函數(shù)糖声,得到概率最大下標(就是預測分類結(jié)果)
from sklearn.metrics import accuracy_score # 度量精確度
TRAIN_BATCH_SIZE = 32 # 批次大小斤彼,我們前面使用的是5
TEST_BATCH_SIZE = 16 # 測試批次大小分瘦,可以設(shè)置為1蘸泻,就是一個一個樣本測試
LEARNING_RATE = 0.001 # 學習率
EPOCHS = 10 # 訓練輪次
5.1. 訓練實現(xiàn)
- 深度學習的訓練模式基本上固化了
class ModelTrainer(object):
# 評估
@staticmethod
def evaluate_model(test_dl, model):
# 預測
predictions, actuals = [], []
# 迭代預測
for i, (inputs, targets) in enumerate(test_dl):
# 預測結(jié)果
yhat = model(inputs)
# 轉(zhuǎn)換為numpy數(shù)組
yhat = yhat.detach().numpy()
# 樣本標簽(真實標簽)
actual = targets.numpy()
# 轉(zhuǎn)換為分類標簽編號(不需要使用softxmax,因為這是遞增函數(shù))
yhat = argmax(yhat, axis=1) # 預測標簽
# 對預測結(jié)果進行形狀處理嘲玫,并放入一個列表悦施,并利用numpy的vstack合并成一個預測結(jié)果
actual = actual.reshape((len(actual), 1))
yhat = yhat.reshape((len(yhat), 1))
# store
predictions.append(yhat)
actuals.append(actual)
predictions, actuals = vstack(predictions), vstack(actuals)
# 計算精確度
acc = accuracy_score(actuals, predictions)
return acc
# 訓練,評估去团,訓練參數(shù)
def train(self, model):
# 加載訓練數(shù)據(jù)集與測試數(shù)據(jù)集
train, test = CSVDataset(TRAIN_FILE_PATH), CSVDataset(TEST_FILE_PATH)
# 轉(zhuǎn)換為批次數(shù)據(jù)集
train_dl = DataLoader(train, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test, batch_size=TEST_BATCH_SIZE)
# 定義優(yōu)化器
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
# 開始輪次訓練
for epoch in range(EPOCHS):
# 對訓練樣本進行批次訓練抡诞。
for x_batch, y_batch in train_dl:
y_batch = y_batch.long()
# 梯度置零
optimizer.zero_grad()
# 計算預測值
y_pred = model(x_batch)
# 使用預測值與真實標簽進行計算誤差
loss = CrossEntropyLoss()(y_pred, y_batch)
# 對誤差進行求導,得到梯度土陪。
loss.backward()
# 更新梯度
optimizer.step()
# 評估
test_accuracy = self.evaluate_model(test_dl, model)
print("輪次: %d, 損失值: %.5f, 測試集準確率: %.5f" % (epoch+1, loss.item(), test_accuracy))
5.2. 訓練執(zhí)行
model = TextClassifier(
nhead=10, # 多頭自注意力數(shù)量
dim_feedforward=128, # 解碼器單元的前饋全連接網(wǎng)絡維度
num_layers=4, # 編碼器的層數(shù)
dropout=0.0,
classifier_dropout=0.0)
# 統(tǒng)計參數(shù)量
num_params = sum(param.numel() for param in model.parameters())
print("參數(shù)量:", num_params)
# 訓練
ModelTrainer().train(model)
# 保存訓練模型
torch.save(model, 'model.pth')
參數(shù)量: 3411217
輪次: 1, 損失值: 0.52407, 測試集準確率: 0.87879
輪次: 2, 損失值: 0.63633, 測試集準確率: 0.88485
輪次: 3, 損失值: 0.30823, 測試集準確率: 0.89091
輪次: 4, 損失值: 0.22140, 測試集準確率: 0.87273
輪次: 5, 損失值: 0.18218, 測試集準確率: 0.87071
輪次: 6, 損失值: 0.23782, 測試集準確率: 0.88687
輪次: 7, 損失值: 0.28823, 測試集準確率: 0.90707
輪次: 8, 損失值: 0.20676, 測試集準確率: 0.88081
輪次: 9, 損失值: 0.14725, 測試集準確率: 0.85859
輪次: 10, 損失值: 0.20790, 測試集準確率: 0.88889
6. 模型評估
- 利用sklearn工具昼汗,對測試集預測結(jié)果計算分類報告與混淆矩陣。
import torch as T
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib import rcParams # 顯示漢字
6.1. 計算預測結(jié)果
- 預測結(jié)果需要進行如下處理
- 加載測試數(shù)據(jù)集鬼雀,得到樣本與真實標簽
- 利用樣本計算預測結(jié)果
(1) 計算預測結(jié)果
# 加載模型
model = T.load('model.pth')
# 加載測試數(shù)據(jù)集
test_ds = CSVDataset(TEST_FILE_PATH)
test_dl = DataLoader(test_ds, batch_size=len(test_ds)) # 做成一個批次
# 循環(huán)預測
for x, y in test_dl:
y_ = model(x) # 預測
y_ = y_.detach().numpy()
y_ = argmax(y_, axis=1)
y = y.detach().numpy()
print(y_[:5], y[:5])
[0 1 0 1 0] [0 0 0 0 0]
(2)把預測標簽編號轉(zhuǎn)換為文字
# 記載標簽庫與字庫
label_dict, _ = load_file_file()
# 把key與value交換
label_dict_rev = {v: k for k, v in label_dict.items()}
true_labels = []
pred_labels = []
for true_no, pred_no in zip(y, y_):
true_label = label_dict_rev[true_no]
pred_label = label_dict_rev[pred_no]
true_labels.append(true_label)
pred_labels.append(pred_label)
# 打印5個看看效果
print(true_labels[:5], pred_labels[:5])
['體育', '體育', '體育', '體育', '體育'] ['體育', '健康', '體育', '健康', '體育']
6.2. 計算分類報告
- 調(diào)用
classification_report
輸出分類報告
report = classification_report(true_labels, pred_labels, digits=5) # digits指定輸出的有效小數(shù)位數(shù)
print(report)
precision recall f1-score support
體育 0.94505 0.86869 0.90526 99
健康 0.79464 0.89899 0.84360 99
軍事 0.94737 0.90909 0.92784 99
教育 0.83654 0.87879 0.85714 99
汽車 0.94624 0.88889 0.91667 99
accuracy 0.88889 495
macro avg 0.89397 0.88889 0.89010 495
weighted avg 0.89397 0.88889 0.89010 495
6.3. 計算混淆矩陣
- 調(diào)用
confusion_matrix
輸出混淆矩陣-
confusion_matrix
輸出的矩陣可以使用matplotlib可視化顷窒。
-
label_names = list(label_dict.keys())
C_M = confusion_matrix(true_labels, pred_labels, labels=label_names) # 最后是標簽名,需要類型是list
print(C_M)
[[86 7 2 4 0]
[ 0 89 1 6 3]
[ 1 2 90 6 0]
[ 1 9 0 87 2]
[ 3 5 2 1 88]]
- 使用matplotlib可視化
rcParams['font.family'] = 'SimHei'
plt.matshow(C_M, cmap=plt.cm.Reds) # cmap指定顏色系
# 顯示刻度與標簽
ticks = np.array(range(len(label_names)))
plt.xticks(ticks, label_names, rotation=90) # 將標簽印在x軸坐標上, 旋轉(zhuǎn)90度
plt.yticks(ticks, label_names) # 將標簽印在y軸坐標上
plt.show()
# plt.savefig("./image/confusion_matrix.png") # 直接保存為圖片
7. 推理
- 推理實現(xiàn)也是常見的流程:
- 加載模型
- 預處理需要分類的文本
- 預測計算
- 處理預測結(jié)果
import torch as T
import numpy as np
import torch.nn.functional as F
# 記載模型
model = T.load('model.pth')
# 加載標簽庫與字庫
label_dict, char_dict = load_file_file()
# 交換key與value
label_dict_rev = {v: k for k, v in label_dict.items()}
# 分類預測文本
text = '蓋世汽車訊源哩,特斯拉去年擊敗了寶馬鞋吉,奪得了美國豪華汽車市場的桂冠,并在今年實現(xiàn)了開門紅励烦。1月份谓着,得益于大幅降價和7500美元美國電動汽車稅收抵免,特斯拉再度擊敗寶馬坛掠,蟬聯(lián)了美國豪華車銷冠赊锚,并且注冊量超過了排名第三的梅賽德斯-奔馳和排名第四的雷克薩斯的總和治筒。根據(jù)Experian的數(shù)據(jù),在所有豪華品牌中舷蒲,1月份矢炼,特斯拉在美國的豪華車注冊量為49,917輛阿纤,同比增長34%句灌;寶馬的注冊量為31,070輛欠拾,同比增長2.5%胰锌;奔馳的注冊量為23,345輛藐窄,同比增長7.3%资昧;雷克薩斯的注冊量為23,082輛荆忍,同比下降6.6%格带。奧迪以19,113輛的注冊量排名第五刹枉,同比增長38%叽唱。凱迪拉克注冊量為13,220輛微宝,較去年同期增長36%棺亭,排名第六。排名第七的謳歌的注冊量為10蟋软,833輛镶摘,同比增長32%。沃爾沃汽車排名第八岳守,注冊量為8凄敢,864輛,同比增長1.8%湿痢。路虎以7涝缝,003輛的注冊量排名第九,林肯以6蒙袍,964輛的注冊量排名第十俊卤。'
# 文本向量化,因為text_feature實現(xiàn)的緣故害幅,其中需要一個labels參數(shù)消恍,但實際該參數(shù)在推理沒有意義,所以使用隨意一個標簽代替以现。
labels, contents = ['汽車'], [text]
samples, y_true = text_feature(labels, contents, label_dict, char_dict)
# 轉(zhuǎn)化為張量
x = T.from_numpy(np.array(samples)).long()
print(x.shape)
# 預測狠怨,注意x的形狀按照我們前面說的约啊,需要滿足特定的形狀
y_pred = model(x)
# 轉(zhuǎn)換為概率
y_numpy = F.softmax(y_pred, dim=1).detach().numpy()
# 去最大概率的下標作為預測標簽編號(因為可能存在多個文本預測結(jié)果)
predict_list = np.argmax(y_numpy, axis=1).tolist()
# 查詢輸出預測標簽
for i, predict in enumerate(predict_list):
print(f"第{i+1}個文本,預測標簽為: {label_dict_rev[predict]}")
torch.Size([1, 200])
第1個文本佣赖,預測標簽為: 汽車