提示學習系列:P-tuning v2微調(diào)BERT實現(xiàn)文本多分類

關鍵詞:提示學習东跪,P-Tuning止喷,BERT呛凶,GPT2

前言

P-tuning v2是清華團隊在P-tuning基礎上提出的一種提示微調(diào)大模型方法彪杉,它旨在解決提示學習在小尺寸模型上效果不佳谣光,以及無法對NLU下游任務通用的問題檩淋,本文對該方法進行簡要介紹和實踐。


內(nèi)容摘要

  • P-tuning v2理論方法簡介
  • P-tuning v2微調(diào)BERT實踐
  • P-tuning v2萄金、PET蟀悦、Fine-Tuning效果對比

P-tuning v2理論方法簡介

相比于現(xiàn)有的Prompt tuning方式,P-tuning v2的調(diào)整主要體現(xiàn)在:

  • 1.為了增強對下游任務的通用性氧敢,使用類似Fine-tuning的[CLS]為作為任務的預測表征
  • 2.引入Deep Prompt Tuning日戈,在Transformer的每一層Block中的輸入層,對輸入添加一定長度的前綴Prompt Embedding孙乖,讓模型自適應學習Prompt的表征

模型的結(jié)構(gòu)圖如下

P-tuning v2模型結(jié)構(gòu)

以330M的BERT預訓練模型為例浙炼,Transformer一共12層,token的維度表征為768唯袄,設置提示長度為20弯屈,則要學習的連續(xù)提示Embedding表征為12 * [20, 768],相比于P-tuning v1可學習的參數(shù)數(shù)量明顯增多恋拷,同時這些參數(shù)嵌入在模型網(wǎng)絡的每一層资厉,相比于P-tuning v1不改變模型僅改變輸入而言,參數(shù)對模型結(jié)果的影響更直接蔬顾。


P-tuning v2微調(diào)BERT實踐

論文團隊只提供了P-tuning v2在BERT結(jié)構(gòu)上的方案和源碼宴偿,在源碼中作者并沒有改造Bert的代碼結(jié)構(gòu)來給每一層創(chuàng)建隨機Embedidng再做自注意力湘捎,而是采用了類似交叉注意力的方式,對每一層的Key和Value額外拼接了一定長度的Embedding窄刘,讓Key和拼接后的Key消痛、Value做交叉注意力,采用HuggingFace的模型類源碼能夠很容易的實現(xiàn)都哭,代碼如下

batch_size = input_ids.shape[0]
past_key_values = self.get_prompt(batch_size=batch_size)
prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

其中past_key_values是額外給Key秩伞,Value添加的Prompt Embedding,attention_mask也同步增加前綴欺矫。
在Bert內(nèi)部纱新,會past_key_values把拼接在Key、Value原始向量的前面穆趴,代碼如下脸爱,原始輸入分別經(jīng)過Key、Value線性映射后直接在頭部拼接可學習的參數(shù)Embedidng未妹,來達到P-tuning v2的效果簿废。

        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

除此之外P-tuning v2和Fine-tuing的實現(xiàn)無明顯區(qū)別,取[CLS]的池化輸出計算損失络它。本文采用和前文提示學習系列:P-Tuning微調(diào)BERT/GPT2實現(xiàn)文本多分類 同樣的數(shù)據(jù)集族檬,在新聞數(shù)據(jù)上通過P-tuning v2提示微調(diào)來實現(xiàn)文本多分類,模型網(wǎng)絡代碼如下

class Model(nn.Module):
    def __init__(self, num_labels, pre_seq_len=40, hidden_size=PRE_TRAIN_CONFIG.hidden_size, hidden_dropout_prob=0.1):
        super(Model, self).__init__()
        self.num_labels = num_labels
        self.pre_seq_len = pre_seq_len
        self.n_layer = PRE_TRAIN_CONFIG.num_hidden_layers
        self.n_head = PRE_TRAIN_CONFIG.num_attention_heads
        self.n_embd = PRE_TRAIN_CONFIG.hidden_size // PRE_TRAIN_CONFIG.num_attention_heads
        self.bert = PRE_TRAIN
        self.dropout = torch.nn.Dropout(hidden_dropout_prob)
        self.classifier = torch.nn.Linear(hidden_size, num_labels)
        for param in self.bert.parameters():
            param.requires_grad = False
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(self.pre_seq_len, PRE_TRAIN_CONFIG.num_hidden_layers,
                                            PRE_TRAIN_CONFIG.hidden_size)
        requires_grad_param = 0
        total_param = 0
        for name, param in self.named_parameters():
            total_param += param.numel()
            if param.requires_grad:
                requires_grad_param += param.numel()

        print('total param: {}, trainable param: {}, trainable/total: {}'.format(total_param, requires_grad_param,
                                                                                 requires_grad_param / total_param))

    def get_prompt(self, batch_size):
        # TODO 統(tǒng)一構(gòu)造embedding并且改造為對應的維度
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        past_key_values = past_key_values.view(
            batch_size,  # 128
            self.pre_seq_len,  # 40
            self.n_layer * 2,  # 24
            self.n_head,  #
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        # TODO 根據(jù)n_layer * 2分為2個
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None
    ):
        batch_size = input_ids.shape[0]
        past_key_values = self.get_prompt(batch_size=batch_size)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

其中PrefixEncoder為創(chuàng)建的隨機初始化的Prompt Embedding化戳,實現(xiàn)如下

class PrefixEncoder(nn.Module):
    def __init__(self, pre_seq_len, num_hidden_layers, hidden_size, prefix_projection=False):
        super().__init__()
        self.prefix_projection = prefix_projection  # false
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(pre_seq_len, hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(hidden_size, prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(prefix_hidden_size, num_hidden_layers * 2 * hidden_size)
            )
        else:
            self.embedding = torch.nn.Embedding(pre_seq_len, num_hidden_layers * 2 * hidden_size)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values

根據(jù)Bert的層數(shù)单料、隱藏層維度,Prompt長度來確定需要訓練的參數(shù)量点楼,初始化之后為了和Bert注意力源碼中的Key和Value拼接扫尖,需要額外做注意力頭維度分割和轉(zhuǎn)置。


P-tuning v2掠廓、PET换怖、Fine-Tuning效果對比

筆者在不同樣本數(shù)量下對Bert采用P-tuning v2,PET和Fine-Tuning微調(diào)蟀瞧,其中P-tuning v2凍結(jié)大模型沉颂,僅微調(diào)Prompt Embedding,PET和Fine-Tuning采用全參微調(diào)黄橘,以20000條樣本為例兆览,F(xiàn)1和模型訓練參量對比如下

P-tuning v2、PET塞关、Fine-Tuning效果對比

其中P-tuning v2的預測精度略高于Fine-Tuning抬探,明顯高于PET,同時訓練參數(shù)量為74萬,而其他兩種全參微調(diào)參數(shù)量達到1億小压,P-tuning v2僅需要約0.1%的參數(shù)微調(diào)量就能達到全參微調(diào)的效果线梗。筆者在不同樣本量的多次訓練測試下,P-tuning v2的F1值接近Fine-Tuning怠益,仍普遍低于Fine-Tuning仪搔,但是明顯優(yōu)秀于PET和P-tuning v1。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末蜻牢,一起剝皮案震驚了整個濱河市烤咧,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌抢呆,老刑警劉巖煮嫌,帶你破解...
    沈念sama閱讀 206,839評論 6 482
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異抱虐,居然都是意外死亡昌阿,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,543評論 2 382
  • 文/潘曉璐 我一進店門恳邀,熙熙樓的掌柜王于貴愁眉苦臉地迎上來懦冰,“玉大人,你說我怎么就攤上這事谣沸∷⒏郑” “怎么了?”我有些...
    開封第一講書人閱讀 153,116評論 0 344
  • 文/不壞的土叔 我叫張陵鳄抒,是天一觀的道長闯捎。 經(jīng)常有香客問我椰弊,道長许溅,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 55,371評論 1 279
  • 正文 為了忘掉前任秉版,我火速辦了婚禮贤重,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘清焕。我一直安慰自己并蝗,他們只是感情好,可當我...
    茶點故事閱讀 64,384評論 5 374
  • 文/花漫 我一把揭開白布秸妥。 她就那樣靜靜地躺著滚停,像睡著了一般。 火紅的嫁衣襯著肌膚如雪粥惧。 梳的紋絲不亂的頭發(fā)上键畴,一...
    開封第一講書人閱讀 49,111評論 1 285
  • 那天,我揣著相機與錄音,去河邊找鬼起惕。 笑死涡贱,一個胖子當著我的面吹牛,可吹牛的內(nèi)容都是我干的惹想。 我是一名探鬼主播问词,決...
    沈念sama閱讀 38,416評論 3 400
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼嘀粱!你這毒婦竟也來了激挪?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 37,053評論 0 259
  • 序言:老撾萬榮一對情侶失蹤锋叨,失蹤者是張志新(化名)和其女友劉穎灌灾,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體悲柱,經(jīng)...
    沈念sama閱讀 43,558評論 1 300
  • 正文 獨居荒郊野嶺守林人離奇死亡锋喜,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,007評論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了豌鸡。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片嘿般。...
    茶點故事閱讀 38,117評論 1 334
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖涯冠,靈堂內(nèi)的尸體忽然破棺而出炉奴,到底是詐尸還是另有隱情,我是刑警寧澤蛇更,帶...
    沈念sama閱讀 33,756評論 4 324
  • 正文 年R本政府宣布瞻赶,位于F島的核電站,受9級特大地震影響派任,放射性物質(zhì)發(fā)生泄漏砸逊。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 39,324評論 3 307
  • 文/蒙蒙 一掌逛、第九天 我趴在偏房一處隱蔽的房頂上張望师逸。 院中可真熱鬧,春花似錦豆混、人聲如沸篓像。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,315評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽员辩。三九已至,卻和暖如春鸵鸥,著一層夾襖步出監(jiān)牢的瞬間奠滑,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,539評論 1 262
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留养叛,地道東北人种呐。 一個月前我還...
    沈念sama閱讀 45,578評論 2 355
  • 正文 我出身青樓,卻偏偏與公主長得像弃甥,于是被迫代替她去往敵國和親爽室。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 42,877評論 2 345

推薦閱讀更多精彩內(nèi)容