2021-12-06 bert model

attention mask如何使用

  • attention_mask List[int] 0-mask,1-attention
    forward(,attention_mask,):
encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
  • extend_attention_mask
extended_attention_mask: torch.Tensor = \
self.get_extended_attention_mask(attention_mask, input_shape, device)

 def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.
        Returns:
            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )

                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

  • get_extended_attention_mask

attention_mask=extend_attention_mask

  • is_decoder中encoder_attention_mask: encoder_extend_attention_mask=self.invert_attention_mask()

形成一個(gè)下三角矩陣

最終mask在BertSelfAttention里起作用。

  • 在forward函數(shù)里求出attention score之后且轨,通過(guò)運(yùn)行
if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

在BertModel傳入attention_mask

這是attention已經(jīng)在BertModel的forward的get_extended_attention_mask處轉(zhuǎn)變
其中g(shù)et_extended_attention_mask

其中g(shù)et_extended_attention_mask來(lái)自modeling_utils.py文件

        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

目的帆赢,將attention_mask中為0的變?yōu)榇筘?fù)數(shù)码倦,1的為0

  • 此時(shí)傳給encoder的attention_mask已經(jīng)改變比然,(encoder_attention_mask根據(jù)是否decoder傳值)。
  • encoder來(lái)自 BertEncoder(config)
  • BertEncoder封裝了num_hidden_layer個(gè)BertLayer
  • BertLayer封裝了BertAttention和BertIntermediate和BertOutput
    *BertAttention封裝了BertSelfAttention忍疾,和BertSelfOutput

一個(gè)疑惑:BertModel的init具體初始化了那些東西

*Bert的init函數(shù)里有

super().__init__(config)
self.post_init()

在QA中沸毁,tokenizer之后的inputs的attention_mask仍然保持全1狀態(tài),需要手動(dòng)調(diào)整

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末亥鬓,一起剝皮案震驚了整個(gè)濱河市完沪,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌嵌戈,老刑警劉巖覆积,帶你破解...
    沈念sama閱讀 218,755評(píng)論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異熟呛,居然都是意外死亡宽档,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,305評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門(mén)庵朝,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)吗冤,“玉大人,你說(shuō)我怎么就攤上這事九府∽滴粒” “怎么了?”我有些...
    開(kāi)封第一講書(shū)人閱讀 165,138評(píng)論 0 355
  • 文/不壞的土叔 我叫張陵侄旬,是天一觀的道長(zhǎng)肺蔚。 經(jīng)常有香客問(wèn)我,道長(zhǎng)儡羔,這世上最難降的妖魔是什么宣羊? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 58,791評(píng)論 1 295
  • 正文 為了忘掉前任,我火速辦了婚禮笔链,結(jié)果婚禮上段只,老公的妹妹穿的比我還像新娘。我一直安慰自己鉴扫,他們只是感情好赞枕,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,794評(píng)論 6 392
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著,像睡著了一般炕婶。 火紅的嫁衣襯著肌膚如雪姐赡。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書(shū)人閱讀 51,631評(píng)論 1 305
  • 那天柠掂,我揣著相機(jī)與錄音项滑,去河邊找鬼。 笑死涯贞,一個(gè)胖子當(dāng)著我的面吹牛枪狂,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播宋渔,決...
    沈念sama閱讀 40,362評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼州疾,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了皇拣?” 一聲冷哼從身側(cè)響起严蓖,我...
    開(kāi)封第一講書(shū)人閱讀 39,264評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎氧急,沒(méi)想到半個(gè)月后颗胡,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,724評(píng)論 1 315
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡吩坝,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,900評(píng)論 3 336
  • 正文 我和宋清朗相戀三年毒姨,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片钾恢。...
    茶點(diǎn)故事閱讀 40,040評(píng)論 1 350
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡手素,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出瘩蚪,到底是詐尸還是另有隱情泉懦,我是刑警寧澤,帶...
    沈念sama閱讀 35,742評(píng)論 5 346
  • 正文 年R本政府宣布疹瘦,位于F島的核電站崩哩,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏言沐。R本人自食惡果不足惜邓嘹,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,364評(píng)論 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望险胰。 院中可真熱鬧汹押,春花似錦、人聲如沸起便。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 31,944評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至妙痹,卻和暖如春铸史,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背怯伊。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 33,060評(píng)論 1 270
  • 我被黑心中介騙來(lái)泰國(guó)打工琳轿, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人耿芹。 一個(gè)月前我還...
    沈念sama閱讀 48,247評(píng)論 3 371
  • 正文 我出身青樓崭篡,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親猩系。 傳聞我的和親對(duì)象是個(gè)殘疾皇子媚送,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,979評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容

  • 1 BertTokenizer(Tokenization分詞) 組成結(jié)構(gòu):BasicTokenizer和WordP...
    def1037aab9e閱讀 1,223評(píng)論 0 0
  • 1 BertTokenizer(Tokenization分詞) 組成結(jié)構(gòu):BasicTokenizer和WordP...
    def1037aab9e閱讀 1,036評(píng)論 0 0
  • ![Flask](...
    極客學(xué)院Wiki閱讀 7,249評(píng)論 0 3
  • 不知不覺(jué)易趣客已經(jīng)在路上走了快一年了,感覺(jué)也該讓更多朋友認(rèn)識(shí)知道易趣客疗涉,所以就謝了這篇簡(jiǎn)介拿霉,已做創(chuàng)業(yè)記事。 易趣客...
    Physher閱讀 3,420評(píng)論 1 2
  • 雙胎妊娠有家族遺傳傾向咱扣,隨母系遺傳绽淘。有研究表明,如果孕婦本人是雙胎之一闹伪,她生雙胎的機(jī)率為1/58沪铭;若孕婦的父親或母...
    鄴水芙蓉hibiscus閱讀 3,702評(píng)論 0 2