bert多標簽分類實驗

好久沒來更新呛讲,好慚愧凯亮,現(xiàn)在也沒了當(dāng)初做這個的心情,就大概記錄一下吧磷斧。
首先BERT模型是一個像word2vec這種的預(yù)訓(xùn)練模型,word2vec結(jié)構(gòu)比較簡單就是一個最簡單的神經(jīng)網(wǎng)絡(luò)并且只取中間那個隱藏的weights作為詞向量诗芜,而BERT復(fù)雜一點瞳抓,用的是很多層(BASE是12層,也是我實驗用到的)的transformer網(wǎng)絡(luò)結(jié)構(gòu)伏恐,transfomer細節(jié)https://jalammar.github.io/illustrated-transformer/這里講的比較形象好理解孩哑,或者直接去tensorflow看源碼。
BERT的開源代碼里是這樣寫的(在modeling.py里):

class BertModel(object):
  """BERT model ("Bidirectional Encoder Representations from Transformers").
  Example usage:
  # Already been converted into WordPiece token ids
  input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
  input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
  token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
  config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
    num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
  model = modeling.BertModel(config=config, is_training=True,
    input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
  label_embeddings = tf.get_variable(...)
  pooled_output = model.get_pooled_output()
  logits = tf.matmul(pooled_output, label_embeddings)
  ...

  """

  def __init__(self,
               config,
               is_training,
               input_ids,
               input_mask=None,
               token_type_ids=None,
               use_one_hot_embeddings=False,
               scope=None):
    """Constructor for BertModel.
    Args:
      config: `BertConfig` instance.
      is_training: bool. true for training model, false for eval model. Controls
        whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
      token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings.
      scope: (optional) variable scope. Defaults to "bert".
    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
    config = copy.deepcopy(config)
    if not is_training:
      config.hidden_dropout_prob = 0.0
      config.attention_probs_dropout_prob = 0.0

    input_shape = get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    if input_mask is None:
      input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

    if token_type_ids is None:
      token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
      with tf.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        (self.embedding_output, self.embedding_table) = embedding_lookup(
            input_ids=input_ids,
            vocab_size=config.vocab_size,
            embedding_size=config.hidden_size,
            initializer_range=config.initializer_range,
            word_embedding_name="word_embeddings",
            use_one_hot_embeddings=use_one_hot_embeddings)

        # Add positional embeddings and token type embeddings, then layer
        # normalize and perform dropout.
        self.embedding_output = embedding_postprocessor(
            input_tensor=self.embedding_output,
            use_token_type=True,
            token_type_ids=token_type_ids,
            token_type_vocab_size=config.type_vocab_size,
            token_type_embedding_name="token_type_embeddings",
            use_position_embeddings=True,
            position_embedding_name="position_embeddings",
            initializer_range=config.initializer_range,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)

      with tf.variable_scope("encoder"):
        # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
        # mask of shape [batch_size, seq_length, seq_length] which is used
        # for the attention scores.
        attention_mask = create_attention_mask_from_input_mask(
            input_ids, input_mask)

        # Run the stacked transformer.
        # `sequence_output` shape = [batch_size, seq_length, hidden_size].
        self.all_encoder_layers = transformer_model(
            input_tensor=self.embedding_output,
            attention_mask=attention_mask,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            intermediate_act_fn=get_activation(config.hidden_act),
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            initializer_range=config.initializer_range,
            do_return_all_layers=True)

      self.sequence_output = self.all_encoder_layers[-1]
      # The "pooler" converts the encoded sequence tensor of shape
      # [batch_size, seq_length, hidden_size] to a tensor of shape
      # [batch_size, hidden_size]. This is necessary for segment-level
      # (or segment-pair-level) classification tasks where we need a fixed
      # dimensional representation of the segment.
      with tf.variable_scope("pooler"):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. We assume that this has been pre-trained
        first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
        self.pooled_output = tf.layers.dense(
            first_token_tensor,
            config.hidden_size,
            activation=tf.tanh,
            kernel_initializer=create_initializer(config.initializer_range))

可以看到先進行embedding對應(yīng)各個id翠桦,然后encoder就是用的transformer_model横蜒,sequence_output是最后一個隱層(這個在閱讀理解任務(wù)會直接拿出來用到),pooled_output就是我們要做的分類任務(wù)拿出來用到的销凑,也就是在[CLS]這里輸出的結(jié)果丛晌。
實驗細節(jié)其實記不太清了,分類用到的是run_classifier.py斗幼。說兩點要改的地方澎蛛。一個是用自己的數(shù)據(jù)跑的話需要把讀文件那部分處理一下,一個是原代碼是處理分類的蜕窿,如果做多標簽分類谋逻,模型的最后一步輸出要從softmax改為多個sigmoid,相應(yīng)代碼是從create_model函數(shù)的

probabilities = tf.nn.softmax(logits, axis=-1)
log_probs = tf.nn.log_softmax(logits, axis=-1)

one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
loss = tf.reduce_mean(per_example_loss)

改為

probabilities = tf.nn.sigmoid(logits)
per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
loss_batch = tf.reduce_mean(per_example_loss, axis=1)
loss = tf.reduce_mean(loss_batch)

另外說一下桐经,他的源代碼回調(diào)函數(shù)比較多毁兆,也不好拆解模型加載和預(yù)測部分∫跽酰可以改寫一下气堕。

gpu_config = tf.ConfigProto()
gpu_config.gpu_options.allow_growth = True
sess = tf.Session(config=gpu_config)
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
if os.path.exists(MODEL_PATH):
    saver.restore(sess, tf.train.latest_checkpoint(MODEL_PATH))
feed_dict = {input_ids:batch_input_ids, input_mask:batch_input_mask, segment_ids:batch_segment_ids}
sess.run([probabilities], feed_dict)

大概這樣寫- -,很傳統(tǒng)易讀的tensorflow加載和預(yù)測方式畔咧,MODEL_PATH是我們自己訓(xùn)練好的模型茎芭,probabilities這里是create_model那里的定義的probabilities,入?yún)nput_ids那些是我們用tf.placeholder定義好的參數(shù)盒卸,batch_input_ids那些是我們feature那里拿到一批批對應(yīng)名字的數(shù)據(jù)骗爆。這樣就可以脫離他的各種fn回調(diào)函數(shù),分離加載和預(yù)測部分了蔽介≌叮或者用將ckpt轉(zhuǎn)換為pb和variables煮寡,用Tensorflow modeling serving的方式去部署也行。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末犀呼,一起剝皮案震驚了整個濱河市幸撕,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌外臂,老刑警劉巖坐儿,帶你破解...
    沈念sama閱讀 206,839評論 6 482
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異宋光,居然都是意外死亡貌矿,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,543評論 2 382
  • 文/潘曉璐 我一進店門罪佳,熙熙樓的掌柜王于貴愁眉苦臉地迎上來逛漫,“玉大人,你說我怎么就攤上這事赘艳∽谜保” “怎么了?”我有些...
    開封第一講書人閱讀 153,116評論 0 344
  • 文/不壞的土叔 我叫張陵蕾管,是天一觀的道長枷踏。 經(jīng)常有香客問我,道長掰曾,這世上最難降的妖魔是什么旭蠕? 我笑而不...
    開封第一講書人閱讀 55,371評論 1 279
  • 正文 為了忘掉前任,我火速辦了婚禮旷坦,結(jié)果婚禮上下梢,老公的妹妹穿的比我還像新娘。我一直安慰自己塞蹭,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 64,384評論 5 374
  • 文/花漫 我一把揭開白布讶坯。 她就那樣靜靜地躺著番电,像睡著了一般。 火紅的嫁衣襯著肌膚如雪辆琅。 梳的紋絲不亂的頭發(fā)上漱办,一...
    開封第一講書人閱讀 49,111評論 1 285
  • 那天,我揣著相機與錄音婉烟,去河邊找鬼娩井。 笑死,一個胖子當(dāng)著我的面吹牛似袁,可吹牛的內(nèi)容都是我干的洞辣。 我是一名探鬼主播咐刨,決...
    沈念sama閱讀 38,416評論 3 400
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼扬霜!你這毒婦竟也來了定鸟?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 37,053評論 0 259
  • 序言:老撾萬榮一對情侶失蹤著瓶,失蹤者是張志新(化名)和其女友劉穎联予,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體材原,經(jīng)...
    沈念sama閱讀 43,558評論 1 300
  • 正文 獨居荒郊野嶺守林人離奇死亡沸久,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,007評論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了余蟹。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片卷胯。...
    茶點故事閱讀 38,117評論 1 334
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖客叉,靈堂內(nèi)的尸體忽然破棺而出诵竭,到底是詐尸還是另有隱情,我是刑警寧澤兼搏,帶...
    沈念sama閱讀 33,756評論 4 324
  • 正文 年R本政府宣布卵慰,位于F島的核電站,受9級特大地震影響佛呻,放射性物質(zhì)發(fā)生泄漏裳朋。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 39,324評論 3 307
  • 文/蒙蒙 一吓著、第九天 我趴在偏房一處隱蔽的房頂上張望鲤嫡。 院中可真熱鬧,春花似錦绑莺、人聲如沸暖眼。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,315評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽诫肠。三九已至,卻和暖如春欺缘,著一層夾襖步出監(jiān)牢的瞬間栋豫,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,539評論 1 262
  • 我被黑心中介騙來泰國打工谚殊, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留丧鸯,地道東北人。 一個月前我還...
    沈念sama閱讀 45,578評論 2 355
  • 正文 我出身青樓嫩絮,卻偏偏與公主長得像丛肢,于是被迫代替她去往敵國和親围肥。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 42,877評論 2 345

推薦閱讀更多精彩內(nèi)容