BERT的那些類
1. BertConfig
這是一個(gè)配置類并村,繼承PretrainedConfig類,用于model的配置腾啥,構(gòu)造函數(shù)參數(shù)如下:
vocab_size?(int, optional, defaults to 30522) :BERT模型的字典大小东涡,默認(rèn)30522冯吓,每個(gè)token可以由input_ids表示。
hidden_size?(int, optional, defaults to 768) :是encoder和pooler層的維度软啼,其中encoder層就是bert的主體結(jié)構(gòu)桑谍,pooler層是將encoder層的輸出接一個(gè)全連接層,將整個(gè)句子的信息表示為第一個(gè)token對(duì)應(yīng)的隱含狀態(tài)祸挪。
num_hidden_layers?(int, optional, defaults to 12) :隱含層數(shù)锣披,默認(rèn)是12層。
num_attention_heads?(int, optional, defaults to 12) :每個(gè)attention層的attention頭數(shù)贿条,默認(rèn)是12個(gè)雹仿。
intermediate_size?(int, optional, defaults to 3072) :encoder中的中間層的維度,如前向傳播層整以,默認(rèn)是3072.
hidden_act?(str?or?function, optional, defaults to “gelu”):encoder和pooler部分中非線性層的激活函數(shù)胧辽,默認(rèn)是gelu
hidden_dropout_prob?(float, optional, defaults to 0.1) :embedding, encoder, pooler部分里全連接層的dropout概率,默認(rèn)為0.1.
attention_probs_dropout_prob?(float, optional, defaults to 0.1):attention過程中softmax后的概率計(jì)算時(shí)的dropout概率公黑,默認(rèn)0.1.
max_position_embeddings?(int, optional, defaults to 512) :模型允許的最大序列長(zhǎng)度邑商,默認(rèn)512。
函數(shù):
from_dict:由一個(gè)參數(shù)字典構(gòu)建Config凡蚜;
from_json_file:由一個(gè)參數(shù)json文件構(gòu)建Config人断;
from_pretrained:由一個(gè)預(yù)訓(xùn)練的模型配置實(shí)例化一個(gè)配置
2.?BertTokenizer
以字分割,繼承PreTrainedTokenizer朝蜘,前面介紹過恶迈,構(gòu)造函數(shù)參數(shù);
vocab_file?(string):字典文件,每行一個(gè)wordpiece
do_lower_case?(bool,?optional, defaults to?True) :是否將輸入轉(zhuǎn)換成小寫
do_basic_tokenize?(bool,?optional, defaults to?True):是否在字分割之前使用BasicTokenize
never_split?(Iterable,?optional, defaults to?None)?:可選谱醇。輸入一個(gè)列表暇仲,列表內(nèi)容為不進(jìn)行 tokenization 的單詞
unk_token?(string,?optional, defaults to “[UNK]”) :字典里沒有的字可以用這個(gè)token代替,默認(rèn)使用[UNK]
sep_token?(string,?optional, defaults to “[SEP]”):分隔句子的token符號(hào)副渴,默認(rèn)[SEP]
pad_token?(string,?optional, defaults to “[PAD]”)?
cls_token?(string,?optional, defaults to “[CLS]”)
mask_token?(string,?optional, defaults to “[MASK]”)
tokenize_chinese_chars?(bool,?optional, defaults to?True) :是否將中文字分割開
返回的就是input_ids奈附,token_type_ds,attention mask等煮剧。
3.?BertModel
Bert模型類桅狠,繼承torch.nn.Module,實(shí)例化對(duì)象時(shí)使用from_pretrained()函數(shù)初始化模型權(quán)重轿秧,參數(shù)config用于配置模型參數(shù)
模型輸入是:
input_ids,token_type_ids(可選)咨堤,attention_mask(可選)菇篡,position_ids(可選),
head_mask(可選):0表示head無效,1表示head有效一喘。
inputs_embeds?(可選)如果不使用input_ids驱还,可以直接輸入token的embedding表示嗜暴。
encoder_hidden_states(可選):encoder最后一層的隱含狀態(tài)序列,模型配置為decoder時(shí)议蟆,需要此輸入闷沥。
encoder_attention_mask(可選):encoder最后一層隱含狀態(tài)序列是否參與attention計(jì)算,防止padding部分參與咐容,模型配置為decoder時(shí)舆逃,需要此輸入.
返回類型tuple(torch.FloatTensor):
last_hidden_state:模型最后一層輸出的隱含層狀態(tài)序列
pooler_output :最后一層隱含層狀態(tài)序列經(jīng)過一層全連接和Tanh激活后,第一個(gè)toekn對(duì)應(yīng)位置的輸出戳粒。
hidden_states(可選路狮,當(dāng)output_hidden_states=True或者config.output_hidden_states=True):每一層和初始embedding層輸出的隱含狀態(tài)
attentions(可選,當(dāng)output_attentions=True或者config.output_attentions=True):attention softmax后的attention權(quán)重蔚约,用于在自注意力頭中計(jì)算權(quán)重平均值奄妨。
4.?BertForPreTraining
這個(gè)類是論文中做pre_train時(shí)的兩個(gè)任務(wù),a?masked language modeling and a?next sentence prediction 苹祟,模型主體與BertModel一樣砸抛,只是輸入輸出上稍有不同。
輸入中增加了labels?树枫,next_sentence_label直焙,分別用于兩個(gè)任務(wù)計(jì)算loss時(shí)用。
輸出主要是loss团赏,prediction_scores箕般,seq_relationship_scores分別表示兩個(gè)任務(wù)的總loss,MLM任務(wù)的loss和NSP任務(wù)的loss舔清。
5. BertForMaskedLM
6. BertForNextSentencePrediction
這兩個(gè)類就是把兩個(gè)任務(wù)分開了丝里,單獨(dú)進(jìn)行
7. BertForSequenceClassification
這個(gè)類用于句子分類或回歸任務(wù),繼承torch.nn.Module体谒,實(shí)例化依然使用from_pretrained+ config配置杯聚。
輸入相比BertModel多了一個(gè)label
輸出主要是loss,logits(softmax之前的分類分?jǐn)?shù))等tuple(torch.FloatTensor)
8. BertForMultipleChoice
這個(gè)是用于多項(xiàng)選擇分類抒痒,例如幌绍,RocStories/SWAG tasks,這個(gè)分支我不了解故响,簡(jiǎn)單搜了一下就是給出前面幾句話傀广,讓你從后面幾個(gè)選項(xiàng)中選出接下來的話是哪個(gè),感覺是知識(shí)推理彩届。
輸入輸出與BertForSequenceClassification一樣伪冰。
9. BertForTokenClassification
這個(gè)類用于對(duì)token分類,如命名實(shí)體識(shí)別任務(wù)樟蠕,從給定的輸入中識(shí)別出命名實(shí)體贮聂,所以是對(duì)最小單位toekn的分類靠柑。
輸入輸出同BertForSequenceClassification
10. BertForQuestionAnswering
這個(gè)類適用于問答系統(tǒng)。
輸入中將上面幾個(gè)模型中的label改成了start_position和end_position吓懈,即答案在原文中起始和結(jié)束位置歼冰。
輸出是將預(yù)測(cè)分?jǐn)?shù)改成了對(duì)答案起始位置和結(jié)束位置的預(yù)測(cè)分?jǐn)?shù)。