提示學(xué)習(xí)系列:P-Tuning微調(diào)BERT/GPT2實(shí)現(xiàn)文本多分類

關(guān)鍵詞:提示學(xué)習(xí)杂穷,P-Tuning滓鸠,BERTGPT2

前言

P-Tuning是清華團(tuán)隊(duì)提出的一種使用提示學(xué)習(xí)微調(diào)大模型的方法吐葱,它提出自適應(yīng)學(xué)習(xí)的連續(xù)提示模板街望,來解決人工自然語言模板的不穩(wěn)定性,本文對(duì)該方法進(jìn)行簡(jiǎn)要介紹和實(shí)踐弟跑。


內(nèi)容摘要

  • P-Tuning理論方法簡(jiǎn)介
  • P-Tuning微調(diào)BERT實(shí)踐
  • P-Tuning微調(diào)GPT-2實(shí)踐
  • P-Tuning它匕、PET、Fine-Tuning效果對(duì)比

P-Tuning理論方法簡(jiǎn)介

前文所介紹的《提示學(xué)習(xí)系列:prompt自然語言模板微調(diào)BERT/GPT2實(shí)現(xiàn)文本分類》中窖认,指出用自然語言來誘導(dǎo)預(yù)訓(xùn)練模型完成NLU任務(wù)豫柬,例如在文本分類任務(wù)中,通過自然語言配合BERT的MLM完型填空過程來對(duì)要預(yù)測(cè)的分類做填空扑浸,而GPT-2也是構(gòu)造自然語言讓其進(jìn)行續(xù)寫得出分類類型烧给,提示學(xué)習(xí)不同于額外增加分類層的fine-tuning,做到了訓(xùn)練和預(yù)測(cè)表達(dá)形式的統(tǒng)一喝噪,自然語言模板的提示學(xué)習(xí)示意圖如下础嫡。

自然語言模板提示學(xué)習(xí)

上圖中要分類的文本是“歐聯(lián)杯雙馬對(duì)決”,“下面是一篇關(guān)于MM的新聞”是人工構(gòu)造的prompt,MM是完型填空需要預(yù)測(cè)的目標(biāo)榴鼎。在該類提示學(xué)習(xí)方法中伯诬,人工構(gòu)造prompt的內(nèi)容,以及拼接到原文的位置巫财,都會(huì)影響模型的訓(xùn)練效果盗似,往往改一個(gè)字都可能導(dǎo)致提示學(xué)習(xí)的性能大幅下降,難以求得最優(yōu)的提示文本平项,因此人工模板的提示學(xué)習(xí)效果不穩(wěn)定赫舒。

PET模型精度存在高方差

在論文中,作者舉例通過提示求得城市X所在國(guó)家Y闽瓢,不同的提示文本對(duì)模型影響巨大接癌,體現(xiàn)在預(yù)測(cè)結(jié)果精確度的高方差,從最低19.8最高51.1扣讼,而引入P-Tuning不僅能降低預(yù)測(cè)方差缺猛,還能提升整體準(zhǔn)確性。

P-Tuning的思想是椭符,與其絞盡腦汁構(gòu)造和搜索出最優(yōu)的prompt文本枯夜,不如引入一部分可訓(xùn)練的embedding和人工模板組合,一齊作為prompt的表征艰山,讓其具備一定的自適應(yīng)能力湖雹,從而來適配各種下游NLU任務(wù),增強(qiáng)模型訓(xùn)練的穩(wěn)定性曙搬,具體的摔吏,采用新的未知token來構(gòu)成prompt的可訓(xùn)練部分,和部分人工模板拼接纵装,P-Tuning的提示學(xué)習(xí)示意圖如下征讲。

p-tuning提示學(xué)習(xí)

其中u1,u2橡娄,u3代表未知的token诗箍,在BERT詞表中對(duì)應(yīng)[unused1]~[unused3],“新聞主題分類”是人工模板挽唉,讓模型往有利于預(yù)測(cè)出M位置的方向上迭代token的表征滤祖,其中引入token的數(shù)量和拼接位置可自行調(diào)整設(shè)置。
P-Tuning同樣適合凍結(jié)預(yù)訓(xùn)練模型參數(shù)和放開全參微調(diào)兩種方式瓶籽,當(dāng)標(biāo)注樣本較少時(shí)采用凍結(jié)模型參數(shù)匠童,只優(yōu)化prompt token embeddng的方式,當(dāng)標(biāo)注樣本充足時(shí)建議全參微調(diào)以達(dá)到最優(yōu)的模型效果塑顺。
在原論文中汤求,作者為了增強(qiáng)prompt部分的表征能力俏险,引入LSTM+MLP來刻畫token之間的前后依賴關(guān)系,使得其更加貼近自然語言扬绪。


P-Tuning微調(diào)BERT實(shí)踐

本篇采用和提示學(xué)習(xí)系列:prompt自然語言模板微調(diào)BERT/GPT2實(shí)現(xiàn)文本分類中一樣的數(shù)據(jù)樣本竖独,通過PyTorch快速實(shí)現(xiàn)p-tuning文本多分類,樣本對(duì)新聞文本做15分類預(yù)測(cè)挤牛,對(duì)于每個(gè)預(yù)測(cè)類別都當(dāng)一個(gè)完整的新token加入詞表進(jìn)行預(yù)測(cè)和損失計(jì)算莹痢,詞表拓充如下。

MODEL_PATH = "./model_hub/chinese-roberta-wwm-ext"
PRE_TRAIN = BertForMaskedLM.from_pretrained(MODEL_PATH).to(DEVICE)
PRE_TRAIN_CONFIG = BertConfig.from_pretrained(MODEL_PATH)
TOKENIZER = BertTokenizer.from_pretrained(MODEL_PATH)
# TODO 加入新詞赊颠,用于標(biāo)記prompt占位符
TOKENIZER.add_special_tokens({'additional_special_tokens': ["[PROMPT]"]})
PROMPT_TOKEN_ID = TOKENIZER.get_vocab()["[PROMPT]"]
CONFIG = BertConfig.from_pretrained(MODEL_PATH)
LABELS = ['文化', '娛樂', '體育', '財(cái)經(jīng)', '房產(chǎn)', '汽車', '教育', '科技', '軍事', '旅游', '國(guó)際', '證券', '農(nóng)業(yè)', '電競(jìng)', '民生']
TOKENIZER.add_tokens(LABELS)
PRE_TRAIN.resize_token_embeddings(len(TOKENIZER))
PRE_TRAIN.tie_weights()
TOKENIZER.save_pretrained("./test_add_word_p_tuning")

樣本構(gòu)造格二,將可學(xué)習(xí)的token和人工合并劈彪,拼接到原文的前面竣蹦,完整的樣本樣式為[CLS] + [used1] + ... + [used3] + [MASK] + [used4] + ... + [used_n] + [token1] + .. + [token_n] + sample + [SEP]的形式,其中used為可學(xué)習(xí)token embedding沧奴,token為人工自然語言模板痘括,MASK為預(yù)測(cè)目標(biāo),將非MASK位置的token預(yù)測(cè)label改為-100不記入損失滔吠,sample為原文纲菌。

PROMPT_LEN = (4, 4)

def collate_fn(data):
    prompts, attention_mask, labels, label_no = [], [], [], []
    for d in data:
        token = TOKENIZER.convert_tokens_to_ids(list(d["text"]))
        discrete_token = TOKENIZER.convert_tokens_to_ids(["新", "聞", "主", "題", "分", "類", "。"])
        cls = [TOKENIZER.cls_token_id]
        sep = [TOKENIZER.sep_token_id]
        first_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[0]
        mask = [TOKENIZER.mask_token_id] * 1
        second_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[1]
        # third_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[2]
        # TODO [CLS] + [PROMPT] + [MASK] + [PROMPT]  + token + [PROMPT] + [SEP]
        # prompt = cls + first_prompt + mask + second_prompt + token + third_prompt + sep
        prompt = cls + first_prompt + mask + second_prompt + discrete_token + token + sep
        prompts.append(prompt)
        attention_mask.append([1] * len(prompt))
        labels.append(TOKENIZER.convert_tokens_to_ids(d["label_name"]))
        label_no.append(d["label"])
    # TODO 對(duì)輸入進(jìn)行padding
    batch_max_length = max([len(x) for x in prompts])
    for i in range(len(prompts)):
        one_length = len(prompts[i])
        if one_length < batch_max_length:
            prompts[i] = prompts[i] + [0] * (batch_max_length - one_length)
            attention_mask[i] = attention_mask[i] + [0] * (batch_max_length - one_length)
    prompts = torch.LongTensor(prompts).to(DEVICE)
    attention_mask = torch.LongTensor(attention_mask).to(DEVICE)
    labels = torch.LongTensor(labels).to(DEVICE)
    # TODO 對(duì)labels進(jìn)行進(jìn)行處理疮绷,[MASK]位置為label翰舌,其他位置為-100
    label_ids = torch.empty_like(prompts).fill_(-100).long()
    label_mask = (prompts == TOKENIZER.mask_token_id).nonzero()[:, 1].reshape(prompts.shape[0], 1)
    # TODO 將MASK位置打上真實(shí)的詞,其他置為-100
    label_ids = label_ids.scatter_(1, label_mask, labels.unsqueeze(1))
    return prompts, attention_mask, label_ids, label_no

train = Data("./short_news/train.json")
train_loader = DataLoader(train, collate_fn=collate_fn, batch_size=128, shuffle=True, drop_last=False)

單獨(dú)對(duì)prompt設(shè)置網(wǎng)絡(luò)模塊冬骚,設(shè)置可學(xué)習(xí)的token embedding椅贱,根據(jù)預(yù)設(shè)的位置隨機(jī)初始化embedding,經(jīng)過LSTM和MLP得到最終token表征只冻。

class PromptEncoder(nn.Module):
    def __init__(self, prompt_num, embedding_size):
        super(PromptEncoder, self).__init__()
        self.input_ids = torch.arange(0, prompt_num).long().to(DEVICE)
        self.embedding = nn.Embedding(prompt_num, embedding_size)
        self.lstm = nn.LSTM(input_size=embedding_size, hidden_size=embedding_size // 2, batch_first=True,
                            bidirectional=True, num_layers=2)
        self.mlp = nn.Sequential(nn.Linear(embedding_size, embedding_size), nn.ReLU(),
                                 nn.Linear(embedding_size, embedding_size))
        self.init_weight()

    def init_weight(self):
        nn.init.xavier_normal_(self.embedding.weight.data)
        for name, weight in self.lstm.named_parameters():
            if name.startswith("weight"):
                nn.init.xavier_normal_(weight.data)
        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight.data)

    def forward(self):
        embedding = self.embedding(self.input_ids).unsqueeze(0)
        out = self.mlp(self.lstm(embedding)[0]).squeeze()
        return out

在BERT網(wǎng)絡(luò)中僅需要對(duì)輸入層做改造庇麦,將used位置的token替換為PromptEncoder層輸出即可,LM_FINE_TUNING參數(shù)決定是否凍結(jié)BERT預(yù)訓(xùn)練參數(shù)喜德。

LM_FINE_TUNING = True

class PTuningBert(nn.Module):
    def __init__(self):
        super(PTuningBert, self).__init__()
        self.pre_train = PRE_TRAIN
        # TODO 如果僅微調(diào)prompt則凍結(jié)預(yù)訓(xùn)練模型
        for param in self.pre_train.parameters():
            param.requires_grad = LM_FINE_TUNING
        self.embedding = self.pre_train.bert.get_input_embeddings()  # TODO 單獨(dú)拿到embedding層
        self.prompt_encoder = PromptEncoder(sum(PROMPT_LEN), PRE_TRAIN_CONFIG.hidden_size)

    def replace_embedding(self, prompt_embedding, raw_embedding, block_indices):
        # TODO 矩陣每一行進(jìn)行替換
        for ids in range(block_indices.size()[0]):
            for i in range(sum(PROMPT_LEN)):
                # TODO 將PROMPT位置的embedding 一條樣本一條樣本山橄,一個(gè)位置一個(gè)位置的 替換為隨機(jī)初始化+LSTM+MLP的
                # TODO block_indices[ids, i]:text中實(shí)際的PROMPT位置, i: PROMPT emb表中每個(gè)位置的id
                raw_embedding[ids, block_indices[ids, i], :] = prompt_embedding[i, :]
        return raw_embedding

    def forward(self, input_ids, attention_mask, label_ids=None):
        queries_for_embedding = input_ids.clone()
        queries_for_embedding[(input_ids == PROMPT_TOKEN_ID)] = TOKENIZER.unk_token_id
        raw_embeds = self.embedding(queries_for_embedding)
        # TODO 拿到每個(gè)text中PROMPT的位置索引 [[p1,p2,p3], [], []...]
        blocked_indices = (input_ids == PROMPT_TOKEN_ID).nonzero().reshape((input_ids.size()[0], sum(PROMPT_LEN), 2))[:, :, 1]
        prompt_embeds = self.prompt_encoder()
        # TODO 將原始raw_emb中的PROMPT(unk)位置替換掉
        input_embedding = self.replace_embedding(prompt_embeds, raw_embeds, blocked_indices)
        output = self.pre_train(inputs_embeds=input_embedding.to(attention_mask.device),
                                attention_mask=attention_mask,
                                labels=label_ids)
        loss, logits = output.loss, output.logits
        return loss, logits

P-Tuning微調(diào)GPT-2實(shí)踐

同理,GPT-2的預(yù)測(cè)目標(biāo)是最后一個(gè)token舍悯,將之前的所有token labe設(shè)置為-100不參與loss計(jì)算航棱,網(wǎng)絡(luò)部分和BERT實(shí)現(xiàn)一致,僅需要替換預(yù)訓(xùn)練模型即可萌衬。

def collate_fn(data):
    prompts, attention_mask, label_ids, label_no = [], [], [], []
    for d in data:
        token = TOKENIZER.encode(d["text"])[1:-1]
        discrete_token = TOKENIZER.convert_tokens_to_ids(["新", "聞", "主", "題", "分", "類", "丧诺。"])
        first_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[0]
        target_token = [TOKENIZER.convert_tokens_to_ids(d["label_name"])]
        second_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[1]
        # TODO [PROMPT] + token + [PROMPT] + target
        prompt = first_prompt + discrete_token + token + second_prompt + target_token
        label_id = [-100] * (len(prompt) - 1) + target_token
        prompts.append(prompt)
        attention_mask.append([1] * len(prompt))
        label_ids.append(label_id)
        label_no.append(d["label"])
        ....

P-Tuning、PET奄薇、Fine-Tuning效果對(duì)比

訓(xùn)練樣本數(shù)量分別取1000驳阎,5000,20000,設(shè)置最大驗(yàn)證集10次早停呵晚,采用chinese-bert-wwm-ext和gpt2-chinese-cluecorpussmall作為預(yù)訓(xùn)練模型蜘腌,全部采用全參微調(diào)的方式,對(duì)比Fine-Tuning饵隙,人工模板PET撮珠,P-Tuning的測(cè)試集F1效果。

模型策略 1000 5000 20000
BERT + fine_tuning 0.8324 0.852 0.8623
BERT + pet 0.8283 0.8511 0.8565
GPT-2 + pet 0.7796 0.8329 0.8383
BERT + p_tuning 0.7858 0.82675 0.84995
GPT-2 + p_tuning 0.8134 0.83485 0.85545

統(tǒng)計(jì)結(jié)果可視化如下

測(cè)試集F1值對(duì)比

結(jié)果顯示金矛,在小尺寸的BERT和GPT-2上芯急,不論P(yáng)ET還是P-Tuning這些提示學(xué)習(xí)微調(diào)的方法,微調(diào)結(jié)果都不如Fine-Tuning驶俊,至少有1個(gè)百分點(diǎn)的差距娶耍。在BERT上,P-Tuning的似乎明顯不如PET饼酿,該結(jié)論和作者論文的結(jié)論相悖滴肿。在GPT-2上绎巨,P-Tuning明顯提升了PET的效果,并且接近了BERT Fine-Tuning的效果,這點(diǎn)和作者的論文題目《GPT Understands, Too》這一結(jié)論一致录平。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末聪轿,一起剝皮案震驚了整個(gè)濱河市韭畸,隨后出現(xiàn)的幾起案子白魂,更是在濱河造成了極大的恐慌,老刑警劉巖槽片,帶你破解...
    沈念sama閱讀 219,427評(píng)論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件何缓,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡筐乳,警方通過查閱死者的電腦和手機(jī)歌殃,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,551評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來蝙云,“玉大人氓皱,你說我怎么就攤上這事〔伲” “怎么了波材?”我有些...
    開封第一講書人閱讀 165,747評(píng)論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)身隐。 經(jīng)常有香客問我廷区,道長(zhǎng),這世上最難降的妖魔是什么贾铝? 我笑而不...
    開封第一講書人閱讀 58,939評(píng)論 1 295
  • 正文 為了忘掉前任隙轻,我火速辦了婚禮埠帕,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘玖绿。我一直安慰自己敛瓷,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,955評(píng)論 6 392
  • 文/花漫 我一把揭開白布斑匪。 她就那樣靜靜地躺著呐籽,像睡著了一般。 火紅的嫁衣襯著肌膚如雪蚀瘸。 梳的紋絲不亂的頭發(fā)上狡蝶,一...
    開封第一講書人閱讀 51,737評(píng)論 1 305
  • 那天,我揣著相機(jī)與錄音贮勃,去河邊找鬼贪惹。 笑死,一個(gè)胖子當(dāng)著我的面吹牛衙猪,可吹牛的內(nèi)容都是我干的馍乙。 我是一名探鬼主播布近,決...
    沈念sama閱讀 40,448評(píng)論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼垫释,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了撑瞧?” 一聲冷哼從身側(cè)響起棵譬,我...
    開封第一講書人閱讀 39,352評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎预伺,沒想到半個(gè)月后订咸,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,834評(píng)論 1 317
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡酬诀,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,992評(píng)論 3 338
  • 正文 我和宋清朗相戀三年脏嚷,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片瞒御。...
    茶點(diǎn)故事閱讀 40,133評(píng)論 1 351
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡父叙,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出肴裙,到底是詐尸還是另有隱情趾唱,我是刑警寧澤,帶...
    沈念sama閱讀 35,815評(píng)論 5 346
  • 正文 年R本政府宣布蜻懦,位于F島的核電站甜癞,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏宛乃。R本人自食惡果不足惜悠咱,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,477評(píng)論 3 331
  • 文/蒙蒙 一蒸辆、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧析既,春花似錦吁朦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,022評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至空骚,卻和暖如春纺讲,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背囤屹。 一陣腳步聲響...
    開封第一講書人閱讀 33,147評(píng)論 1 272
  • 我被黑心中介騙來泰國(guó)打工熬甚, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人肋坚。 一個(gè)月前我還...
    沈念sama閱讀 48,398評(píng)論 3 373
  • 正文 我出身青樓乡括,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親智厌。 傳聞我的和親對(duì)象是個(gè)殘疾皇子诲泌,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,077評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容