關鍵詞:提示學習
东跪,P-Tuning
止喷,BERT
呛凶,GPT2
前言
P-tuning v2是清華團隊在P-tuning基礎上提出的一種提示微調(diào)大模型方法彪杉,它旨在解決提示學習在小尺寸模型上效果不佳谣光,以及無法對NLU下游任務通用的問題檩淋,本文對該方法進行簡要介紹和實踐。
內(nèi)容摘要
- P-tuning v2理論方法簡介
- P-tuning v2微調(diào)BERT實踐
- P-tuning v2萄金、PET蟀悦、Fine-Tuning效果對比
P-tuning v2理論方法簡介
相比于現(xiàn)有的Prompt tuning方式,P-tuning v2的調(diào)整主要體現(xiàn)在:
- 1.為了增強對下游任務的通用性氧敢,使用類似Fine-tuning的[CLS]為作為任務的預測表征
- 2.引入Deep Prompt Tuning日戈,在Transformer的每一層Block中的輸入層,對輸入添加一定長度的前綴Prompt Embedding孙乖,讓模型自適應學習Prompt的表征
模型的結(jié)構(gòu)圖如下
以330M的BERT預訓練模型為例浙炼,Transformer一共12層,token的維度表征為768唯袄,設置提示長度為20弯屈,則要學習的連續(xù)提示Embedding表征為12 * [20, 768],相比于P-tuning v1可學習的參數(shù)數(shù)量明顯增多恋拷,同時這些參數(shù)嵌入在模型網(wǎng)絡的每一層资厉,相比于P-tuning v1不改變模型僅改變輸入而言,參數(shù)對模型結(jié)果的影響更直接蔬顾。
P-tuning v2微調(diào)BERT實踐
論文團隊只提供了P-tuning v2在BERT結(jié)構(gòu)上的方案和源碼宴偿,在源碼中作者并沒有改造Bert的代碼結(jié)構(gòu)來給每一層創(chuàng)建隨機Embedidng再做自注意力湘捎,而是采用了類似交叉注意力的方式,對每一層的Key和Value額外拼接了一定長度的Embedding窄刘,讓Key和拼接后的Key消痛、Value做交叉注意力,采用HuggingFace的模型類源碼能夠很容易的實現(xiàn)都哭,代碼如下
batch_size = input_ids.shape[0]
past_key_values = self.get_prompt(batch_size=batch_size)
prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
past_key_values=past_key_values,
)
其中past_key_values是額外給Key秩伞,Value添加的Prompt Embedding,attention_mask也同步增加前綴欺矫。
在Bert內(nèi)部纱新,會past_key_values把拼接在Key、Value原始向量的前面穆趴,代碼如下脸爱,原始輸入分別經(jīng)過Key、Value線性映射后直接在頭部拼接可學習的參數(shù)Embedidng未妹,來達到P-tuning v2的效果簿废。
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
除此之外P-tuning v2和Fine-tuing的實現(xiàn)無明顯區(qū)別,取[CLS]的池化輸出計算損失络它。本文采用和前文提示學習系列:P-Tuning微調(diào)BERT/GPT2實現(xiàn)文本多分類 同樣的數(shù)據(jù)集族檬,在新聞數(shù)據(jù)上通過P-tuning v2提示微調(diào)來實現(xiàn)文本多分類,模型網(wǎng)絡代碼如下
class Model(nn.Module):
def __init__(self, num_labels, pre_seq_len=40, hidden_size=PRE_TRAIN_CONFIG.hidden_size, hidden_dropout_prob=0.1):
super(Model, self).__init__()
self.num_labels = num_labels
self.pre_seq_len = pre_seq_len
self.n_layer = PRE_TRAIN_CONFIG.num_hidden_layers
self.n_head = PRE_TRAIN_CONFIG.num_attention_heads
self.n_embd = PRE_TRAIN_CONFIG.hidden_size // PRE_TRAIN_CONFIG.num_attention_heads
self.bert = PRE_TRAIN
self.dropout = torch.nn.Dropout(hidden_dropout_prob)
self.classifier = torch.nn.Linear(hidden_size, num_labels)
for param in self.bert.parameters():
param.requires_grad = False
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
self.prefix_encoder = PrefixEncoder(self.pre_seq_len, PRE_TRAIN_CONFIG.num_hidden_layers,
PRE_TRAIN_CONFIG.hidden_size)
requires_grad_param = 0
total_param = 0
for name, param in self.named_parameters():
total_param += param.numel()
if param.requires_grad:
requires_grad_param += param.numel()
print('total param: {}, trainable param: {}, trainable/total: {}'.format(total_param, requires_grad_param,
requires_grad_param / total_param))
def get_prompt(self, batch_size):
# TODO 統(tǒng)一構(gòu)造embedding并且改造為對應的維度
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
past_key_values = self.prefix_encoder(prefix_tokens)
past_key_values = past_key_values.view(
batch_size, # 128
self.pre_seq_len, # 40
self.n_layer * 2, # 24
self.n_head, #
self.n_embd
)
past_key_values = self.dropout(past_key_values)
# TODO 根據(jù)n_layer * 2分為2個
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
return past_key_values
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None
):
batch_size = input_ids.shape[0]
past_key_values = self.get_prompt(batch_size=batch_size)
prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
past_key_values=past_key_values,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
其中PrefixEncoder為創(chuàng)建的隨機初始化的Prompt Embedding化戳,實現(xiàn)如下
class PrefixEncoder(nn.Module):
def __init__(self, pre_seq_len, num_hidden_layers, hidden_size, prefix_projection=False):
super().__init__()
self.prefix_projection = prefix_projection # false
if self.prefix_projection:
# Use a two-layer MLP to encode the prefix
self.embedding = torch.nn.Embedding(pre_seq_len, hidden_size)
self.trans = torch.nn.Sequential(
torch.nn.Linear(hidden_size, prefix_hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(prefix_hidden_size, num_hidden_layers * 2 * hidden_size)
)
else:
self.embedding = torch.nn.Embedding(pre_seq_len, num_hidden_layers * 2 * hidden_size)
def forward(self, prefix: torch.Tensor):
if self.prefix_projection:
prefix_tokens = self.embedding(prefix)
past_key_values = self.trans(prefix_tokens)
else:
past_key_values = self.embedding(prefix)
return past_key_values
根據(jù)Bert的層數(shù)单料、隱藏層維度,Prompt長度來確定需要訓練的參數(shù)量点楼,初始化之后為了和Bert注意力源碼中的Key和Value拼接扫尖,需要額外做注意力頭維度分割和轉(zhuǎn)置。
P-tuning v2掠廓、PET换怖、Fine-Tuning效果對比
筆者在不同樣本數(shù)量下對Bert采用P-tuning v2,PET和Fine-Tuning微調(diào)蟀瞧,其中P-tuning v2凍結(jié)大模型沉颂,僅微調(diào)Prompt Embedding,PET和Fine-Tuning采用全參微調(diào)黄橘,以20000條樣本為例兆览,F(xiàn)1和模型訓練參量對比如下
其中P-tuning v2的預測精度略高于Fine-Tuning抬探,明顯高于PET,同時訓練參數(shù)量為74萬,而其他兩種全參微調(diào)參數(shù)量達到1億小压,P-tuning v2僅需要約0.1%的參數(shù)微調(diào)量就能達到全參微調(diào)的效果线梗。筆者在不同樣本量的多次訓練測試下,P-tuning v2的F1值接近Fine-Tuning怠益,仍普遍低于Fine-Tuning仪搔,但是明顯優(yōu)秀于PET和P-tuning v1。