關(guān)鍵詞
:行為序列建模
,MIMN
,RNN
偶摔,神經(jīng)圖靈機(jī)
金闽,Attention
內(nèi)容摘要
- MIMN原理整體提要解析
- MIMN源碼速覽
- MIMN中參數(shù)維護(hù)方式總結(jié)
- 在風(fēng)控場景下纯露,MIMN的訓(xùn)練,部署代碼實戰(zhàn)
本文主要是MIMN原理迅速掃描代芜,實戰(zhàn)部分見行為序列建模:MIMN系列2——消費Kafka實時預(yù)測代碼實戰(zhàn)
研究背景
本文受到字節(jié)跳動技術(shù)團(tuán)隊的一片博客《行為序列模型在抖音風(fēng)控中的應(yīng)用》的啟發(fā)埠褪,在長序列建模中引入MIMN
算法(Multi-channel user Interest Memory Network),進(jìn)一步研究了阿里媽媽MIMN的論文和源碼挤庇,將該算法成功部署到了風(fēng)控業(yè)務(wù)系統(tǒng)钞速,使得模型可以接受任意長度的歷史序列對實體進(jìn)行風(fēng)險預(yù)測,同時引入外部存儲記錄在此之前所有的記憶狀態(tài)罚随,當(dāng)有新的序列元素進(jìn)入時玉工,讀寫記錄實時預(yù)測,簡單而言相比于原始的通過滑窗限制序列長度的LSTM算法淘菩,MIMN具有兩大優(yōu)勢:
-
歷史長序列建模
:輸入給模型的用戶行為序列越長遵班,理論上模型的效果越好,然而傳統(tǒng)的RNN對歷史長序列表征能力有限潮改,而MIMN將歷史信息的表征和Y值解耦狭郑,可以根據(jù)序列本身記錄純粹的歷史所有記憶信息。 -
實時增量預(yù)測
:改變了部署方式汇在,傳統(tǒng)的RNN在實時預(yù)測時面臨推理延遲和存儲占用大的問題翰萨,MIMN采用外部存儲記錄最新的存量記憶狀態(tài),增量部分新來一個行為對接一個MIMN單元糕殉,讀寫修改維護(hù)狀態(tài)亩鬼,大大降低了在線部署實時推理的延遲和記憶存儲的空間占用殖告。
MIMN原理迅速概括
MIMN論文涉及好幾個獨立的知識點,作者的創(chuàng)新是將這些技術(shù)串起來解決了一個實際的問題雳锋,其中設(shè)計的子模塊包括NTM神經(jīng)圖靈機(jī)
黄绩,MIU記憶感知單元
,DIN注意力網(wǎng)絡(luò)
三個知識點玷过,本文對于這三塊不做展開爽丹,只在整體層面介紹下幾大模塊的最用,以及內(nèi)部參數(shù)的更新維護(hù)方式辛蚊,原論文地址粤蝎。
(1)模型輸入輸出介紹
下面先從模型的輸入開始了解MIMN,左側(cè)橙色是增存量序列構(gòu)建記憶的過程袋马,右側(cè)是在線部署時的預(yù)測部分初澎。
對于增存量記憶構(gòu)建部分,輸入是歷史所有序列元素飞蛹,每個元素包括物品id和物品的其他上下文信息拼接的結(jié)果谤狡,序列元素輸入的目的是維護(hù)了一個M矩陣和S矩陣,對于每一個用戶都有它對應(yīng)的M和S矩陣卧檐,每來一個新的序列元素墓懂,都會對M和S進(jìn)行更新
-
M矩陣
:負(fù)責(zé)對用戶原始?xì)v史行為序列信息的表征,它通過NTM神經(jīng)圖靈機(jī)實現(xiàn)霉囚,通過讀頭和寫頭對NTM的結(jié)果進(jìn)行更新 -
S矩陣
:負(fù)責(zé)從M矩陣中提取高階信息捕仔,配合目標(biāo)物品進(jìn)行DIN Attention從記憶中提取對目標(biāo)有益的信息,彌補(bǔ)M矩陣的不足盈罐,它通過MIU模塊實現(xiàn)
對于在線部署部分榜跌,輸入是目標(biāo)物品(Target Ad),歷史記憶的讀輸出(Read Head)盅粪,記憶感知模塊和目標(biāo)物品的Attention輸出钓葫,以及其他上下文信息(Context Feas),四大輸入拼接之后兩層全連接在softmax得到0-1的輸出票顾,預(yù)測用戶是否對目標(biāo)物品有行為交互础浮。
對于序列元素,是由歷史到現(xiàn)在所有商品/廣告形成的序列奠骄,細(xì)分的話有三種豆同,一種是歷史商品,一種是最后一個商品(或者是當(dāng)前最新的一次行為商品)含鳞,一種是目標(biāo)商品(通過召回得到的候選商品)影锈,三個作用如下
-
歷史商品
:用于刷存量構(gòu)成S和M矩陣 -
最后一個商品
:用于在線部署階段,觸發(fā)UIC更新用戶的S和M矩陣 -
目標(biāo)商品
:用于在線部署階段,調(diào)用UIC的M矩陣拿到讀頭輸出鸭廷,以及調(diào)用S矩陣進(jìn)行Attention枣抱,從而輸入全連接進(jìn)行ctr預(yù)測
搞清楚三種元素的區(qū)別基本MIMN大體上吃透一半了。
(2)模型部署介紹
模型部署也分為增存量記憶維護(hù)靴姿,和線上預(yù)測兩個部分
虛線下面是增存量記憶維護(hù)沃但,增量和存量的行為序列產(chǎn)出UIC Server的M矩陣和S矩陣以及其他記憶信息,沒來一個新的序列元素就更新UIC的內(nèi)容佛吓,不需要全部從頭開始重新計算記憶信息。虛線上的在線預(yù)測部分垂攘,簡單而言就是根據(jù)目標(biāo)物品信息维雇,用戶靜態(tài)信息,再去UIC中拿到無延遲的記憶信息晒他,預(yù)測得到用戶對目標(biāo)響應(yīng)概率吱型。這兩個流程是完全解耦的,相當(dāng)于UIC對實時預(yù)測部分是無延遲的陨仅,不再像傳統(tǒng)RNN那樣維護(hù)歷史序列id津滞,而是維護(hù)一個歷史到現(xiàn)在為止的記憶矩陣在外部存儲即可。
MIMN源碼速覽
下面進(jìn)一步了解MIMN都從源碼開始灼伤,源碼地址触徐,源碼比較復(fù)雜涉及一些其他算法,挑一些重點記錄一下狐赡。
(1)主模型框架類
模型的主類是Model_MIMN
class Model_MIMN(Model):
def __init__(self, n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE, BATCH_SIZE, MEMORY_SIZE, SEQ_LEN=400, Mem_Induction=0,
Util_Reg=0, use_negsample=False, mask_flag=False):
super(Model_MIMN, self).__init__(n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE,
BATCH_SIZE, SEQ_LEN, use_negsample, Flag="MIMN")
self.reg = Util_Reg
...
該類繼承Model類撞鹉,Model類主要包含輸入序列id的embedding映射過程和最后的全連接過程,NTM颖侄,MIU鸟雏,DIN Attention全部在子類Model_MIMN中。
class Model(object):
def __init__(self, n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE, BATCH_SIZE, SEQ_LEN, use_negsample=False, Flag="DNN"):
self.model_flag = Flag
self.reg = False
self.use_negsample = use_negsample
with tf.name_scope('Inputs'):
...
# Embedding layer
with tf.name_scope('Embedding_layer'):
...
# 基于之前網(wǎng)絡(luò)的輸出構(gòu)造最后的全連接層
def build_fcn_net(self, inp, use_dice=False):
bn1 = tf.layers.batch_normalization(inputs=inp, name='bn1')
...
從功能上來說Model_MIMN的目的就是構(gòu)造出最后一層全連接的輸入inp览祖,inp輸入到全連接層孝鹊,全連接包含batchNorm和兩層全連接,和上圖灰色的在線預(yù)測部分內(nèi)容一致展蒂。
(2)MIMN單元
這是整個代碼的核心又活,先看MIMN單元的實例化
cell = mimn.MIMNCell(controller_units=HIDDEN_SIZE, memory_size=MEMORY_SIZE, memory_vector_dim=2 * EMBEDDING_DIM,
read_head_num=1, write_head_num=1,
reuse=False, output_dim=HIDDEN_SIZE, clip_value=20, batch_size=BATCH_SIZE,
mem_induction=Mem_Induction, util_reg=Util_Reg)
在Model_MIMN中實例化了一個MIMN單元,而每一個序列的輸入都會進(jìn)這個MIMN單元玄货,全局共享這個MIMN單元的模型參數(shù)皇钞,比如控制器和MIU中的GRU部分。在實例化MIMN單元的時候松捉,這一段代碼初始化了S矩陣
if self.mem_induction > 0:
self.channel_rnn = single_cell(self.memory_vector_dim)
# TODO channel_rnn_state是S矩陣 [[256, 32], [256, 32], [256, 32], [256, 32]]
self.channel_rnn_state = [self.channel_rnn.zero_state(batch_size, tf.float32) for i in range(memory_size)]
self.channel_rnn_output = [tf.zeros(((batch_size, self.memory_vector_dim))) for i in range(memory_size)]
S矩陣為全0初始化夹界,維度是[memory_size, batch_size, memory_dim],memory_size是記憶矩陣的高,memory_dim是記憶矩陣的寬可柿,每個輸入進(jìn)來的樣本都會有有一個自己的S矩陣鸠踪。
下面初始化M矩陣的狀態(tài),當(dāng)模型才開始訓(xùn)練和用戶處于冷啟動的時候复斥,狀態(tài)需要初始化营密,M矩陣比S矩陣復(fù)雜,會多一些相關(guān)的變量
state = cell.zero_state(BATCH_SIZE, tf.float32)
注意zero_state將BATCH_SIZE傳進(jìn)去目锭,說明初始化和輸入訓(xùn)練的用戶數(shù)量有關(guān)评汰,實際是每個用戶都分配了一個初始化狀態(tài)。舉個例子看M矩陣的初始化
M = expand(
tf.tanh(tf.get_variable('init_M', [self.memory_size, self.memory_vector_dim],
initializer=tf.random_normal_initializer(mean=0.0, stddev=1e-5),
trainable=False)),
dim=0, N=batch_size)
def expand(x, dim, N):
return tf.concat([tf.expand_dims(x, dim) for _ in range(N)], axis=dim)
對于每一個輸入的用戶痢虹,給他一個均值是0標(biāo)準(zhǔn)差是1e-5的隨機(jī)(4,32)的初始化被去,然后復(fù)制batch_size(比如256)的份數(shù),拼接成(256,4,32)的該batch下的init_M矩陣奖唯。由此可見雖然每個用戶都給到一個單獨的初始化M惨缆,但是他們初始化的結(jié)果是一模一樣的,注意該變量trainable=False丰捷,不隨著損失函數(shù)優(yōu)化迭代坯墨。同理創(chuàng)建controller_state
,read_vector病往,w_list捣染,M,key_M荣恐,w_aggre其他NTM需維護(hù)的變量液斜,其中w_list包含了讀頭和寫頭。
(3)歷史序列刷存量構(gòu)建M和S矩陣
在MIMN單元實例化和MIMN的state初始化后叠穆,作者開始將歷史200長度的序列灌入MIMN單元少漆,代碼如下
for t in range(SEQ_LEN):
output, state, temp_output_list = cell(self.item_his_eb[:, t, :], state)
if mask_flag:
# TODO mask的作用是修正狀態(tài),排除prepare階段由于padding導(dǎo)致的state變動
state = clear_mask_state(state, begin_state, begin_channel_rnn_output, self.mask, cell, t)
# 記錄下每個序列元素輸出的output和status
self.mimn_o.append(output)
self.state_list.append(state)
代碼里面通過item_his_eb[:, t, :]切片拿到了對應(yīng)步長的序列元素硼被,和當(dāng)前的state一起輸入MIMN單元示损,第一個元素對應(yīng)的state是cell.zero_state得到的狀態(tài),后面的都是在循環(huán)中更新最新的state給下一個序列元素使用嚷硫。注意這個for循環(huán)構(gòu)造了一張tensorflow長圖检访,及從第一個MIMN走到最后一個MIMN的路徑,每一個樣本仔掸,每一個批次進(jìn)來的時候脆贵,都要經(jīng)過這條路徑,互不干擾起暮,代碼里面的self.state_list可以打印出來看一下卖氨,每一個樣本的第一次state都是0初始化,不會存在參數(shù)繼承的情況。
clear_mask_state函數(shù)是避免左邊padding為0給state帶來影響筒捺,代碼如下
def clear_mask_state(state, begin_state, begin_channel_rnn_state, mask, cell, t):
# TODO mask[:, t] = [256, 1] => [256, 1]
# TODO 如果mask是0相當(dāng)于將controller_state重新置為begin_state,全0初始化,否則保持原樣不變
state["controller_state"] = (1 - tf.reshape(mask[:, t], (batch_size, 1))) * begin_state[
"controller_state"] + tf.reshape(mask[:, t], (batch_size, 1)) * state["controller_state"]
...
以controller_state的計算為例柏腻,如果mask是0(代表padding了0),則左式保留controller_state打回原樣成為begin_state系吭,否則mask是1(代表不padding五嫂,是實際的序列元素),則左式刪除肯尺,右式和state["controller_state"]沒有差異保留模型對controller_state的更改沃缘。
(4)看看MIMN在做什么
下面深入這個cell(self.item_his_eb[:, t, :], state),看看MIMN在做什么则吟,代碼較長孩灯,挑提綱挈領(lǐng)的說。先看看這東西輸入輸出啥
def __call__(self, x, prev_state):
return read_output, {
"controller_state": controller_state,
"read_vector_list": read_vector_list,
"w_list": w_list,
"M": M,
"key_M": key_M, # TODO key_M用完了之后沒有修改
"w_aggre": w_aggre,
"sum_aggre": sum_aggre
}, output_list
輸入是當(dāng)前步長的元素embedding和當(dāng)前最新的state逾滥,輸出是讀M矩陣的輸出,最新的狀態(tài)败匹,以及讀S矩陣的輸出寨昙,簡單說一下三個輸出的代碼鏈路
- 讀M矩陣的輸出:基于當(dāng)前輸入的序列元素,和上一個狀態(tài)的讀輸出掀亩,經(jīng)過NTM的控制器GRU舔哪,得到控制器輸出,進(jìn)一步計算得到讀寫之前記憶矩陣的w權(quán)重槽棍,通過該權(quán)重得到最新的讀輸出捉蚤,和控制器輸出拼接得到最終的read_output
- 最新的狀態(tài):在讀M矩陣的輸出的計算過程中,同步記錄下變動的state
- 讀S矩陣的輸出:通過將當(dāng)前步長的元素和上一個記憶矩陣輸入多通道GRU炼七,得到當(dāng)前步長的讀S矩陣的輸出缆巧,同時更新S矩陣狀態(tài)。
總結(jié)數(shù)據(jù)輸入MIMN單元之后豌拙,輸出讀M和S矩陣的輸出陕悬,以及更新M和S矩陣的參數(shù)狀態(tài),其中讀M和S矩陣的輸出要輸入最后的全連接模型進(jìn)行ctr預(yù)測按傅,更新M和S矩陣的參數(shù)狀態(tài)需要輸入給下一個序列元素進(jìn)行記憶更新來表征用戶的行為捉超。
(5)MIMN單元的后處理,構(gòu)造主模型輸入
MIMN的輸出需要準(zhǔn)備構(gòu)造為最終主模型的輸入的唯绍,首先用擁有最新的state的MIMN單元將目標(biāo)商品灌進(jìn)來走一邊拼岳,拿到讀輸出,來表征原始記憶信息况芒,第二第三全部不要惜纸,只要read_out
read_out, _, _ = cell(self.item_eb, state)
然后拿到現(xiàn)在最新的讀S矩陣的輸出,和目標(biāo)商品一起輸入給DIN Attention,提取高階特征
if Mem_Induction == 1:
channel_memory_tensor = tf.concat(temp_output_list, 1)
multi_channel_hist = din_attention(self.item_eb, channel_memory_tensor, HIDDEN_SIZE, None, stag='pal')
# TODO read_out是讀取M矩陣輸出的結(jié)果堪簿,multi_channel_hist是讀取S矩陣輸出的結(jié)果痊乾,其他都是目標(biāo)商品自身特征和上下文特征
inp = tf.concat([self.item_eb, self.item_his_eb_sum, read_out, tf.squeeze(multi_channel_hist),
mean_memory * self.item_eb], 1)
最終的inp包含read_out, tf.squeeze(multi_channel_hist)這兩大主要特征椭更,以及其他上下文特征哪审。
在回過頭來看圖示,很清楚了呀虑瀑,Target Ad拿到M的Read Head湿滓,同時和最新的S一起輸入Attention。inp最終輸入全連接進(jìn)行ctr預(yù)測舌狗。整個代碼的概覽結(jié)束叽奥,里面復(fù)雜的NTM和DIN Attention先不展開研究。
MIMN參數(shù)維護(hù)方式總結(jié)
作者的代碼是訓(xùn)練部分痛侍,該代碼的目的僅僅是訓(xùn)練出控制器GRU朝氓,MIU的GRU,DIN以及其他幾個全連接的參數(shù)主届,保存在tensorflow網(wǎng)絡(luò)中赵哲,而S和M矩陣雖然在里面也產(chǎn)出了,但是真正部署上線肯定是重新刷歷史存量所有序列得到的君丁,而不是采用padding和截取200的方式枫夺,示意圖如下
其中NTM的讀寫w權(quán)重直接基于cos相似度計算得到,得到后直接更新M矩陣绘闷,不需要保存橡庞,其他記憶部分都是保存到外部存儲自行維護(hù),而右側(cè)部分全部是tensorflow圖來維護(hù)印蔗,不需要手動維護(hù)扒最,在線上環(huán)節(jié),讀取外部存儲拿到記憶參數(shù)喻鳄,輸入給tensorflow圖即可完成預(yù)測扼倘。
另外看一下記憶參數(shù)是如何初始化,以及如何更新的
其中有的初始化是需要模型學(xué)習(xí)的除呵,在部署的時候需要在訓(xùn)練的網(wǎng)絡(luò)中將它恢復(fù)出來再菊,否則初始化不一樣,有些初始化是0初始化是寫死的颜曾,相對而言方便一點纠拔。