A BiLSTM + attention text classification demo in TensorFlow 2.0

The full code is given below. A few points to note:
1) When creating Keras layers or variables, give each one an explicit name; otherwise, once several identical layers appear, the run will fail with an error.
2) Bidirectional must be given one forward and one backward layer.
3) Do not swap the input arguments of the CategoricalCrossentropy loss function. If they are swapped the function does not raise an error, it simply fails to train, because there is a type conversion inside the function; this slip is very hard to spot (see the short example right after this list).
4) @tf.function traces the function into a graph; the wrapped code runs in graph mode (on the GPU when one is available), which speeds things up.
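
As a minimal illustration of point 3 (the numbers below are made up for the example): CategoricalCrossentropy expects y_true first and y_pred second; with the arguments swapped it still returns a finite loss, so nothing crashes, but training goes nowhere.

import tensorflow as tf

loss_fn = tf.keras.losses.CategoricalCrossentropy()
y_true = tf.constant([[0.0, 1.0, 0.0]])      # one-hot label
y_pred = tf.constant([[0.1, 0.8, 0.1]])      # model output (softmax probabilities)
print(loss_fn(y_true, y_pred).numpy())       # correct order: roughly 0.22
print(loss_fn(y_pred, y_true).numpy())       # swapped order: no error, but the value is meaningless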

A note on how attention is used here.
For the general formulation, see the usual description of attention:

[figure: general attention formulation]

In actual use:
the query is taken to be a learned parameter; it can be read as the question "which class does this document belong to?".
key = value = the Bi-LSTM hidden state at each time step. The final summary is a weighted sum of the hidden states
with respect to that query, followed by a dense layer that produces the final classification.

query => atten_u
The attention layer scores each hidden state h_t against the query with

u_t = \tanh(W h_t + b), \quad \alpha_t = \mathrm{softmax}(u_t^\top u)

where W, b, and u correspond to atten_w, atten_b, and atten_u in the code. The final weighted result is

s = \sum_t \alpha_t h_t
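
To make the shapes concrete, here is a small standalone sketch of that computation with arbitrary sizes (illustration only; the actual layer is RnnAttentionLayer in the code below):

import tensorflow as tf

batch, seq_len, hidden, attn = 2, 5, 8, 4                  # arbitrary sizes
h = tf.random.normal((batch, seq_len, hidden))             # stands in for the BiLSTM outputs
w = tf.random.normal((hidden, attn))                       # atten_w
b = tf.zeros((attn,))                                      # atten_b
u = tf.random.normal((attn,))                              # atten_u, i.e. the query

scores = tf.tensordot(tf.tanh(tf.tensordot(h, w, axes=1) + b), u, axes=1)  # (batch, seq_len)
alpha = tf.nn.softmax(scores, axis=-1)                                     # attention weights
s = tf.reduce_sum(tf.expand_dims(alpha, -1) * h, axis=1)                   # (batch, hidden)
print(s.shape)                                                             # (2, 8)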
#!/usr/bin/env python 
#-*- coding:utf-8 -*-

import os
import sys
import warnings
import pickle
import datetime
import tensorflow as tf 
import pandas as pd
import traceback
import time 
import json
import numpy as np 
from tensorflow import keras 
from tensorflow.keras import layers
from tensorflow.keras import Input 
from tensorflow.keras.layers import Dense 
from tensorflow.keras.layers import LSTM 
from tensorflow.keras.layers import Bidirectional 
from tensorflow.keras.layers import Dropout 
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import BatchNormalization
warnings.filterwarnings("ignore")


####################  helper function #########################
def one_hot_encode(raw_y, num_classes):
  # turn a list of integer labels into a one-hot matrix of shape (n, num_classes)
  index = np.array(raw_y)
  class_cnt = num_classes #np.max(index) + 1 
  out = np.zeros((index.shape[0], class_cnt))
  out[np.arange(index.shape[0]), index] = 1
  return out 

def load_sample(fn, max_seq_len, word_dict, num_classes):
  text_df = pd.read_csv(fn)
  raw_y = []
  raw_x = []
  for i in range(len(text_df)):
    label = text_df['label'][i]
    raw_y.append(int(label))

    text = text_df['text'][i]
    text_len = len(text)
    # map each character to its index; pad with 0, and if the text is longer
    # than max_seq_len keep only the last max_seq_len characters
    x = np.zeros(max_seq_len, dtype = np.int32)
    if text_len <= max_seq_len:
      for j in range(text_len):
        x[j] = word_dict[text[j]]
    else:
      for j in range(text_len - max_seq_len, text_len):
        x[j - text_len + max_seq_len] = word_dict[text[j]]
    raw_x.append(x)

  all_x = np.array(raw_x)
  all_y = one_hot_encode(raw_y, num_classes)
  return all_x, all_y 

def batch_iter(x, y, batch_size = 16):
  # shuffle the samples once and yield them in mini-batches
  data_len = len(x)
  num_batch = (data_len + batch_size - 1) // batch_size
  indices = np.random.permutation(np.arange(data_len))
  x_shuff = x[indices]
  y_shuff = y[indices]
  for i in range(num_batch):
    start_offset = i*batch_size 
    end_offset = min(start_offset + batch_size, data_len)
    yield i, num_batch, x_shuff[start_offset:end_offset], y_shuff[start_offset:end_offset]


######################### model start #####################
class RnnAttentionLayer(layers.Layer):
  def __init__(self, attention_size, drop_rate):
    super().__init__()
    self.attention_size = attention_size
    self.dropout = Dropout(drop_rate, name = "rnn_attention_dropout")

  def build(self, input_shape):
    self.attention_w = self.add_weight(name = "atten_w", shape = (input_shape[-1], self.attention_size), initializer = tf.random_uniform_initializer(), dtype = "float32", trainable = True)
    self.attention_u = self.add_weight(name = "atten_u", shape = (self.attention_size,), initializer = tf.random_uniform_initializer(), dtype = "float32", trainable = True)
    self.attention_b = self.add_weight(name = "atten_b", shape = (self.attention_size,), initializer = tf.constant_initializer(0.1), dtype = "float32", trainable = True)    
    super().build(input_shape)

  def call(self, inputs, training):
    # inputs: (batch, seq_len, hidden) BiLSTM outputs; score each step against atten_u
    x = tf.tanh(tf.add(tf.tensordot(inputs, self.attention_w, axes = 1), self.attention_b))
    x = tf.tensordot(x, self.attention_u, axes = 1)           # (batch, seq_len)
    x = tf.nn.softmax(x)                                      # attention weights
    weight_out = tf.multiply(tf.expand_dims(x, -1), inputs)   # weight each hidden state
    final_out = tf.reduce_sum(weight_out, axis = 1)           # (batch, hidden) weighted sum
    drop_out = self.dropout(final_out, training = training)
    return drop_out

class RnnLayer(layers.Layer):
  def __init__(self, rnn_size, drop_rate):
    super().__init__()
    fwd_lstm = LSTM(rnn_size, return_sequences = True, go_backwards= False, dropout = drop_rate, name = "fwd_lstm")
    bwd_lstm = LSTM(rnn_size, return_sequences = True, go_backwards = True, dropout = drop_rate, name = "bwd_lstm")
    self.bilstm = Bidirectional(merge_mode = "concat", layer = fwd_lstm, backward_layer = bwd_lstm, name = "bilstm")
    #self.bilstm = Bidirectional(LSTM(rnn_size, activation= "relu", return_sequences = True, dropout = drop_rate))

  def call(self, inputs, training):
    outputs = self.bilstm(inputs, training = training)
    return outputs
 
class Model(tf.keras.Model):
  def __init__(self, num_classes, drop_rate, vocab_size, embedding_size, rnn_size, attention_size):
    super().__init__()
    self.embedding_layer = Embedding(vocab_size, embedding_size, embeddings_initializer = "uniform", name = "embeding_0")
    self.rnn_layer = RnnLayer(rnn_size, drop_rate)
    self.attention_layer = RnnAttentionLayer(attention_size, drop_rate)
    self.dense_layer = Dense(num_classes, activation = "softmax", kernel_regularizer=keras.regularizers.l2(0.001), name = "dense_1")

  def call(self, input_x, training):
    x = self.embedding_layer(input_x)
    x = self.rnn_layer(x, training = training)
    x = self.attention_layer(x, training = training)
    x = self.dense_layer(x)
    return x

def train(xy_train, xy_val, num_classes, vocab_size, nbr_epoches, batch_size):
  uniq_cfg_name = datetime.datetime.now().strftime("%Y")
  model_prefix = os.path.join(os.getcwd(), "model")
  if not os.path.exists(model_prefix):
    print("create model dir: %s" % model_prefix)
    os.mkdir(model_prefix)

  model_path = os.path.join(model_prefix, uniq_cfg_name)
  model = Model(num_classes, drop_rate = 0.05, vocab_size = vocab_size, embedding_size = 256, rnn_size = 128, attention_size = 128)
  if os.path.exists(model_path):
    model.load_weights(model_path)
    print("load weight from: %s" % model_path)
  
  optimizer = tf.keras.optimizers.Adam(0.01)
  loss_fn = tf.keras.losses.CategoricalCrossentropy()

  loss_metric = tf.keras.metrics.Mean(name='train_loss')
  accuracy_metric = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
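
  # Note: the same train_step below is reused for evaluation by passing
  # training = False; the gradients are then still computed inside the tape,
  # but are not applied to the model.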

  @tf.function 
  def train_step(input_x, input_y, training = True):
    with tf.GradientTape() as tape:
      raw_prob = model(input_x, training)
      #tf.print("raw_prob", raw_prob)
      pred_loss = loss_fn(input_y, raw_prob)
    gradients = tape.gradient(pred_loss, model.trainable_variables)
    if training:
      optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Update the metrics
    loss_metric.update_state(pred_loss)
    accuracy_metric.update_state(input_y, raw_prob)
    return raw_prob 

  for i in range(nbr_epoches):
    t0 = time.time()
    batch_train = batch_iter(xy_train[0], xy_train[1], batch_size = batch_size)
    loss_metric.reset_states()
    accuracy_metric.reset_states()

    for batch_no, batch_tot, data_x, data_y in batch_train:
      predict_prob = train_step(data_x, data_y, True)  
      #if batch_no % 10 == 0:
      #  print("[%d of %d]: loss: %0.3f acc %0.3f" % (batch_no, batch_tot, loss_metric.result(), accuracy_metric.result()))

    print("[train ep %d] [%s]: %0.3f  [%s]: %0.3f" %  (i, "loss", loss_metric.result() , "acc", accuracy_metric.result()))
    model.save_weights(model_path, overwrite=True)

    if (i + 1) % 5 == 0:
      loss_metric.reset_states()
      accuracy_metric.reset_states()
      batch_test = batch_iter(xy_val[0], xy_val[1], batch_size = batch_size)
      for _, _, data_x, data_y in batch_test:
        train_step(data_x, data_y, False)
      print("[***** ep %d] [%s]: %0.3f  [%s]: %0.3f" %  (i, "loss", loss_metric.result() , "acc", accuracy_metric.result()))

if __name__ == "__main__":
  try:
    cur_dir=os.getcwd()
    corps_meta_path = os.path.join(cur_dir, "corps_meta")
    corps_meta = pickle.load(open(corps_meta_path, "rb"))
    max_seq_len = min(64, corps_meta["max_seq_len"])
    num_classes = corps_meta["num_classes"] 
    word_dict = corps_meta["word_dict"] 
    index_dict = corps_meta["index_dict"]
    train_sample_path = os.path.join(cur_dir, "train.csv")
    test_sample_path = os.path.join(cur_dir, "test.csv")

    ### gen samples ###
    train_x, train_y = load_sample(train_sample_path, max_seq_len, word_dict, num_classes)
    test_x, test_y = load_sample(test_sample_path, max_seq_len, word_dict, num_classes)
    key, freq = np.unique(np.argmax(train_y, axis = 1), return_counts = True)
    train([train_x, train_y], [test_x, test_y], num_classes, len(word_dict), nbr_epoches = 100, batch_size = 256)
  except:
    traceback.print_exc()
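
For reference, the corps_meta pickle read in __main__ is assumed to contain the keys max_seq_len, num_classes, word_dict, and index_dict. A rough, hypothetical preprocessing sketch (not part of the original demo) that would produce such a file from a train.csv with label/text columns might look like this:

import pickle
import pandas as pd

df = pd.read_csv("train.csv")                            # assumed columns: label, text
chars = sorted({c for text in df["text"] for c in text})
# 0-based indices so every index stays below len(word_dict), which is the
# vocab_size passed to train(); note that index 0 then coincides with the padding value
word_dict = {c: i for i, c in enumerate(chars)}
index_dict = {i: c for c, i in word_dict.items()}
corps_meta = {
    "max_seq_len": int(df["text"].str.len().max()),
    "num_classes": int(df["label"].nunique()),
    "word_dict": word_dict,
    "index_dict": index_dict,
}
pickle.dump(corps_meta, open("corps_meta", "wb"))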
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末焦履,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子雏逾,更是在濱河造成了極大的恐慌嘉裤,老刑警劉巖,帶你破解...
    沈念sama閱讀 217,542評(píng)論 6 504
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件栖博,死亡現(xiàn)場(chǎng)離奇詭異屑宠,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)仇让,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,822評(píng)論 3 394
  • 文/潘曉璐 我一進(jìn)店門典奉,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人丧叽,你說(shuō)我怎么就攤上這事卫玖。” “怎么了踊淳?”我有些...
    開封第一講書人閱讀 163,912評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵假瞬,是天一觀的道長(zhǎng)陕靠。 經(jīng)常有香客問(wèn)我,道長(zhǎng)脱茉,這世上最難降的妖魔是什么剪芥? 我笑而不...
    開封第一講書人閱讀 58,449評(píng)論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮琴许,結(jié)果婚禮上税肪,老公的妹妹穿的比我還像新娘。我一直安慰自己榜田,他們只是感情好益兄,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,500評(píng)論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著串慰,像睡著了一般偏塞。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上邦鲫,一...
    開封第一講書人閱讀 51,370評(píng)論 1 302
  • 那天灸叼,我揣著相機(jī)與錄音,去河邊找鬼庆捺。 笑死古今,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的滔以。 我是一名探鬼主播捉腥,決...
    沈念sama閱讀 40,193評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼你画!你這毒婦竟也來(lái)了抵碟?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,074評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤坏匪,失蹤者是張志新(化名)和其女友劉穎拟逮,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體适滓,經(jīng)...
    沈念sama閱讀 45,505評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡敦迄,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,722評(píng)論 3 335
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了凭迹。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片罚屋。...
    茶點(diǎn)故事閱讀 39,841評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖嗅绸,靈堂內(nèi)的尸體忽然破棺而出脾猛,到底是詐尸還是另有隱情,我是刑警寧澤鱼鸠,帶...
    沈念sama閱讀 35,569評(píng)論 5 345
  • 正文 年R本政府宣布尖滚,位于F島的核電站喉刘,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏漆弄。R本人自食惡果不足惜睦裳,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,168評(píng)論 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望撼唾。 院中可真熱鬧廉邑,春花似錦、人聲如沸倒谷。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,783評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)渤愁。三九已至牵祟,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間抖格,已是汗流浹背茬末。 一陣腳步聲響...
    開封第一講書人閱讀 32,918評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工蛛碌, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 47,962評(píng)論 2 370
  • 正文 我出身青樓痰驱,卻偏偏與公主長(zhǎng)得像皿桑,于是被迫代替她去往敵國(guó)和親详拙。 傳聞我的和親對(duì)象是個(gè)殘疾皇子叹阔,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,781評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容