關(guān)鍵詞:提示學(xué)習(xí)
杂穷,P-Tuning
滓鸠,BERT
,GPT2
前言
P-Tuning是清華團(tuán)隊(duì)提出的一種使用提示學(xué)習(xí)微調(diào)大模型的方法吐葱,它提出自適應(yīng)學(xué)習(xí)的連續(xù)提示模板街望,來解決人工自然語言模板的不穩(wěn)定性,本文對(duì)該方法進(jìn)行簡(jiǎn)要介紹和實(shí)踐弟跑。
內(nèi)容摘要
- P-Tuning理論方法簡(jiǎn)介
- P-Tuning微調(diào)BERT實(shí)踐
- P-Tuning微調(diào)GPT-2實(shí)踐
- P-Tuning它匕、PET、Fine-Tuning效果對(duì)比
P-Tuning理論方法簡(jiǎn)介
前文所介紹的《提示學(xué)習(xí)系列:prompt自然語言模板微調(diào)BERT/GPT2實(shí)現(xiàn)文本分類》中窖认,指出用自然語言來誘導(dǎo)預(yù)訓(xùn)練模型完成NLU任務(wù)豫柬,例如在文本分類任務(wù)中,通過自然語言配合BERT的MLM完型填空過程來對(duì)要預(yù)測(cè)的分類做填空扑浸,而GPT-2也是構(gòu)造自然語言讓其進(jìn)行續(xù)寫得出分類類型烧给,提示學(xué)習(xí)不同于額外增加分類層的fine-tuning,做到了訓(xùn)練和預(yù)測(cè)表達(dá)形式的統(tǒng)一喝噪,自然語言模板的提示學(xué)習(xí)示意圖如下础嫡。
上圖中要分類的文本是“歐聯(lián)杯雙馬對(duì)決”,“下面是一篇關(guān)于MM的新聞”是人工構(gòu)造的prompt,MM是完型填空需要預(yù)測(cè)的目標(biāo)榴鼎。在該類提示學(xué)習(xí)方法中伯诬,人工構(gòu)造prompt的內(nèi)容,以及拼接到原文的位置巫财,都會(huì)影響模型的訓(xùn)練效果盗似,往往改一個(gè)字都可能導(dǎo)致提示學(xué)習(xí)的性能大幅下降,難以求得最優(yōu)的提示文本平项,因此人工模板的提示學(xué)習(xí)效果不穩(wěn)定赫舒。
在論文中,作者舉例通過提示求得城市X所在國(guó)家Y闽瓢,不同的提示文本對(duì)模型影響巨大接癌,體現(xiàn)在預(yù)測(cè)結(jié)果精確度的高方差,從最低19.8最高51.1扣讼,而引入P-Tuning不僅能降低預(yù)測(cè)方差缺猛,還能提升整體準(zhǔn)確性。
P-Tuning的思想是椭符,與其絞盡腦汁構(gòu)造和搜索出最優(yōu)的prompt文本枯夜,不如引入一部分可訓(xùn)練的embedding和人工模板組合,一齊作為prompt的表征艰山,讓其具備一定的自適應(yīng)能力湖雹,從而來適配各種下游NLU任務(wù),增強(qiáng)模型訓(xùn)練的穩(wěn)定性曙搬,具體的摔吏,采用新的未知token來構(gòu)成prompt的可訓(xùn)練部分,和部分人工模板拼接纵装,P-Tuning的提示學(xué)習(xí)示意圖如下征讲。
其中u1,u2橡娄,u3代表未知的token诗箍,在BERT詞表中對(duì)應(yīng)[unused1]~[unused3],“新聞主題分類”是人工模板挽唉,讓模型往有利于預(yù)測(cè)出M位置的方向上迭代token的表征滤祖,其中引入token的數(shù)量和拼接位置可自行調(diào)整設(shè)置。
P-Tuning同樣適合凍結(jié)預(yù)訓(xùn)練模型參數(shù)和放開全參微調(diào)兩種方式瓶籽,當(dāng)標(biāo)注樣本較少時(shí)采用凍結(jié)模型參數(shù)匠童,只優(yōu)化prompt token embeddng的方式,當(dāng)標(biāo)注樣本充足時(shí)建議全參微調(diào)以達(dá)到最優(yōu)的模型效果塑顺。
在原論文中汤求,作者為了增強(qiáng)prompt部分的表征能力俏险,引入LSTM+MLP來刻畫token之間的前后依賴關(guān)系,使得其更加貼近自然語言扬绪。
P-Tuning微調(diào)BERT實(shí)踐
本篇采用和提示學(xué)習(xí)系列:prompt自然語言模板微調(diào)BERT/GPT2實(shí)現(xiàn)文本分類中一樣的數(shù)據(jù)樣本竖独,通過PyTorch快速實(shí)現(xiàn)p-tuning文本多分類,樣本對(duì)新聞文本做15分類預(yù)測(cè)挤牛,對(duì)于每個(gè)預(yù)測(cè)類別都當(dāng)一個(gè)完整的新token加入詞表進(jìn)行預(yù)測(cè)和損失計(jì)算莹痢,詞表拓充如下。
MODEL_PATH = "./model_hub/chinese-roberta-wwm-ext"
PRE_TRAIN = BertForMaskedLM.from_pretrained(MODEL_PATH).to(DEVICE)
PRE_TRAIN_CONFIG = BertConfig.from_pretrained(MODEL_PATH)
TOKENIZER = BertTokenizer.from_pretrained(MODEL_PATH)
# TODO 加入新詞赊颠,用于標(biāo)記prompt占位符
TOKENIZER.add_special_tokens({'additional_special_tokens': ["[PROMPT]"]})
PROMPT_TOKEN_ID = TOKENIZER.get_vocab()["[PROMPT]"]
CONFIG = BertConfig.from_pretrained(MODEL_PATH)
LABELS = ['文化', '娛樂', '體育', '財(cái)經(jīng)', '房產(chǎn)', '汽車', '教育', '科技', '軍事', '旅游', '國(guó)際', '證券', '農(nóng)業(yè)', '電競(jìng)', '民生']
TOKENIZER.add_tokens(LABELS)
PRE_TRAIN.resize_token_embeddings(len(TOKENIZER))
PRE_TRAIN.tie_weights()
TOKENIZER.save_pretrained("./test_add_word_p_tuning")
樣本構(gòu)造格二,將可學(xué)習(xí)的token和人工合并劈彪,拼接到原文的前面竣蹦,完整的樣本樣式為[CLS] + [used1] + ... + [used3] + [MASK] + [used4] + ... + [used_n] + [token1] + .. + [token_n] + sample + [SEP]的形式,其中used為可學(xué)習(xí)token embedding沧奴,token為人工自然語言模板痘括,MASK為預(yù)測(cè)目標(biāo),將非MASK位置的token預(yù)測(cè)label改為-100不記入損失滔吠,sample為原文纲菌。
PROMPT_LEN = (4, 4)
def collate_fn(data):
prompts, attention_mask, labels, label_no = [], [], [], []
for d in data:
token = TOKENIZER.convert_tokens_to_ids(list(d["text"]))
discrete_token = TOKENIZER.convert_tokens_to_ids(["新", "聞", "主", "題", "分", "類", "。"])
cls = [TOKENIZER.cls_token_id]
sep = [TOKENIZER.sep_token_id]
first_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[0]
mask = [TOKENIZER.mask_token_id] * 1
second_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[1]
# third_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[2]
# TODO [CLS] + [PROMPT] + [MASK] + [PROMPT] + token + [PROMPT] + [SEP]
# prompt = cls + first_prompt + mask + second_prompt + token + third_prompt + sep
prompt = cls + first_prompt + mask + second_prompt + discrete_token + token + sep
prompts.append(prompt)
attention_mask.append([1] * len(prompt))
labels.append(TOKENIZER.convert_tokens_to_ids(d["label_name"]))
label_no.append(d["label"])
# TODO 對(duì)輸入進(jìn)行padding
batch_max_length = max([len(x) for x in prompts])
for i in range(len(prompts)):
one_length = len(prompts[i])
if one_length < batch_max_length:
prompts[i] = prompts[i] + [0] * (batch_max_length - one_length)
attention_mask[i] = attention_mask[i] + [0] * (batch_max_length - one_length)
prompts = torch.LongTensor(prompts).to(DEVICE)
attention_mask = torch.LongTensor(attention_mask).to(DEVICE)
labels = torch.LongTensor(labels).to(DEVICE)
# TODO 對(duì)labels進(jìn)行進(jìn)行處理疮绷,[MASK]位置為label翰舌,其他位置為-100
label_ids = torch.empty_like(prompts).fill_(-100).long()
label_mask = (prompts == TOKENIZER.mask_token_id).nonzero()[:, 1].reshape(prompts.shape[0], 1)
# TODO 將MASK位置打上真實(shí)的詞,其他置為-100
label_ids = label_ids.scatter_(1, label_mask, labels.unsqueeze(1))
return prompts, attention_mask, label_ids, label_no
train = Data("./short_news/train.json")
train_loader = DataLoader(train, collate_fn=collate_fn, batch_size=128, shuffle=True, drop_last=False)
單獨(dú)對(duì)prompt設(shè)置網(wǎng)絡(luò)模塊冬骚,設(shè)置可學(xué)習(xí)的token embedding椅贱,根據(jù)預(yù)設(shè)的位置隨機(jī)初始化embedding,經(jīng)過LSTM和MLP得到最終token表征只冻。
class PromptEncoder(nn.Module):
def __init__(self, prompt_num, embedding_size):
super(PromptEncoder, self).__init__()
self.input_ids = torch.arange(0, prompt_num).long().to(DEVICE)
self.embedding = nn.Embedding(prompt_num, embedding_size)
self.lstm = nn.LSTM(input_size=embedding_size, hidden_size=embedding_size // 2, batch_first=True,
bidirectional=True, num_layers=2)
self.mlp = nn.Sequential(nn.Linear(embedding_size, embedding_size), nn.ReLU(),
nn.Linear(embedding_size, embedding_size))
self.init_weight()
def init_weight(self):
nn.init.xavier_normal_(self.embedding.weight.data)
for name, weight in self.lstm.named_parameters():
if name.startswith("weight"):
nn.init.xavier_normal_(weight.data)
for layer in self.mlp:
if isinstance(layer, nn.Linear):
nn.init.xavier_normal_(layer.weight.data)
def forward(self):
embedding = self.embedding(self.input_ids).unsqueeze(0)
out = self.mlp(self.lstm(embedding)[0]).squeeze()
return out
在BERT網(wǎng)絡(luò)中僅需要對(duì)輸入層做改造庇麦,將used位置的token替換為PromptEncoder層輸出即可,LM_FINE_TUNING參數(shù)決定是否凍結(jié)BERT預(yù)訓(xùn)練參數(shù)喜德。
LM_FINE_TUNING = True
class PTuningBert(nn.Module):
def __init__(self):
super(PTuningBert, self).__init__()
self.pre_train = PRE_TRAIN
# TODO 如果僅微調(diào)prompt則凍結(jié)預(yù)訓(xùn)練模型
for param in self.pre_train.parameters():
param.requires_grad = LM_FINE_TUNING
self.embedding = self.pre_train.bert.get_input_embeddings() # TODO 單獨(dú)拿到embedding層
self.prompt_encoder = PromptEncoder(sum(PROMPT_LEN), PRE_TRAIN_CONFIG.hidden_size)
def replace_embedding(self, prompt_embedding, raw_embedding, block_indices):
# TODO 矩陣每一行進(jìn)行替換
for ids in range(block_indices.size()[0]):
for i in range(sum(PROMPT_LEN)):
# TODO 將PROMPT位置的embedding 一條樣本一條樣本山橄,一個(gè)位置一個(gè)位置的 替換為隨機(jī)初始化+LSTM+MLP的
# TODO block_indices[ids, i]:text中實(shí)際的PROMPT位置, i: PROMPT emb表中每個(gè)位置的id
raw_embedding[ids, block_indices[ids, i], :] = prompt_embedding[i, :]
return raw_embedding
def forward(self, input_ids, attention_mask, label_ids=None):
queries_for_embedding = input_ids.clone()
queries_for_embedding[(input_ids == PROMPT_TOKEN_ID)] = TOKENIZER.unk_token_id
raw_embeds = self.embedding(queries_for_embedding)
# TODO 拿到每個(gè)text中PROMPT的位置索引 [[p1,p2,p3], [], []...]
blocked_indices = (input_ids == PROMPT_TOKEN_ID).nonzero().reshape((input_ids.size()[0], sum(PROMPT_LEN), 2))[:, :, 1]
prompt_embeds = self.prompt_encoder()
# TODO 將原始raw_emb中的PROMPT(unk)位置替換掉
input_embedding = self.replace_embedding(prompt_embeds, raw_embeds, blocked_indices)
output = self.pre_train(inputs_embeds=input_embedding.to(attention_mask.device),
attention_mask=attention_mask,
labels=label_ids)
loss, logits = output.loss, output.logits
return loss, logits
P-Tuning微調(diào)GPT-2實(shí)踐
同理,GPT-2的預(yù)測(cè)目標(biāo)是最后一個(gè)token舍悯,將之前的所有token labe設(shè)置為-100不參與loss計(jì)算航棱,網(wǎng)絡(luò)部分和BERT實(shí)現(xiàn)一致,僅需要替換預(yù)訓(xùn)練模型即可萌衬。
def collate_fn(data):
prompts, attention_mask, label_ids, label_no = [], [], [], []
for d in data:
token = TOKENIZER.encode(d["text"])[1:-1]
discrete_token = TOKENIZER.convert_tokens_to_ids(["新", "聞", "主", "題", "分", "類", "丧诺。"])
first_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[0]
target_token = [TOKENIZER.convert_tokens_to_ids(d["label_name"])]
second_prompt = [PROMPT_TOKEN_ID] * PROMPT_LEN[1]
# TODO [PROMPT] + token + [PROMPT] + target
prompt = first_prompt + discrete_token + token + second_prompt + target_token
label_id = [-100] * (len(prompt) - 1) + target_token
prompts.append(prompt)
attention_mask.append([1] * len(prompt))
label_ids.append(label_id)
label_no.append(d["label"])
....
P-Tuning、PET奄薇、Fine-Tuning效果對(duì)比
訓(xùn)練樣本數(shù)量分別取1000驳阎,5000,20000,設(shè)置最大驗(yàn)證集10次早停呵晚,采用chinese-bert-wwm-ext和gpt2-chinese-cluecorpussmall作為預(yù)訓(xùn)練模型蜘腌,全部采用全參微調(diào)的方式,對(duì)比Fine-Tuning饵隙,人工模板PET撮珠,P-Tuning的測(cè)試集F1效果。
模型策略 | 1000 | 5000 | 20000 |
---|---|---|---|
BERT + fine_tuning | 0.8324 | 0.852 | 0.8623 |
BERT + pet | 0.8283 | 0.8511 | 0.8565 |
GPT-2 + pet | 0.7796 | 0.8329 | 0.8383 |
BERT + p_tuning | 0.7858 | 0.82675 | 0.84995 |
GPT-2 + p_tuning | 0.8134 | 0.83485 | 0.85545 |
統(tǒng)計(jì)結(jié)果可視化如下
結(jié)果顯示金矛,在小尺寸的BERT和GPT-2上芯急,不論P(yáng)ET還是P-Tuning這些提示學(xué)習(xí)微調(diào)的方法,微調(diào)結(jié)果都不如Fine-Tuning驶俊,至少有1個(gè)百分點(diǎn)的差距娶耍。在BERT上,P-Tuning的似乎明顯不如PET饼酿,該結(jié)論和作者論文的結(jié)論相悖滴肿。在GPT-2上绎巨,P-Tuning明顯提升了PET的效果,并且接近了BERT Fine-Tuning的效果,這點(diǎn)和作者的論文題目《GPT Understands, Too》這一結(jié)論一致录平。