行為序列建模:MIMN系列1——原理初探和源碼解析

關(guān)鍵詞行為序列建模MIMNRNN偶摔,神經(jīng)圖靈機(jī)金闽,Attention

內(nèi)容摘要

  • MIMN原理整體提要解析
  • MIMN源碼速覽
  • MIMN中參數(shù)維護(hù)方式總結(jié)
  • 在風(fēng)控場景下纯露,MIMN的訓(xùn)練,部署代碼實戰(zhàn)

本文主要是MIMN原理迅速掃描代芜,實戰(zhàn)部分見行為序列建模:MIMN系列2——消費Kafka實時預(yù)測代碼實戰(zhàn)


研究背景

本文受到字節(jié)跳動技術(shù)團(tuán)隊的一片博客《行為序列模型在抖音風(fēng)控中的應(yīng)用》的啟發(fā)埠褪,在長序列建模中引入MIMN算法(Multi-channel user Interest Memory Network),進(jìn)一步研究了阿里媽媽MIMN的論文和源碼挤庇,將該算法成功部署到了風(fēng)控業(yè)務(wù)系統(tǒng)钞速,使得模型可以接受任意長度的歷史序列對實體進(jìn)行風(fēng)險預(yù)測,同時引入外部存儲記錄在此之前所有的記憶狀態(tài)罚随,當(dāng)有新的序列元素進(jìn)入時玉工,讀寫記錄實時預(yù)測,簡單而言相比于原始的通過滑窗限制序列長度的LSTM算法淘菩,MIMN具有兩大優(yōu)勢:

  • 歷史長序列建模:輸入給模型的用戶行為序列越長遵班,理論上模型的效果越好,然而傳統(tǒng)的RNN對歷史長序列表征能力有限潮改,而MIMN將歷史信息的表征和Y值解耦狭郑,可以根據(jù)序列本身記錄純粹的歷史所有記憶信息。
  • 實時增量預(yù)測:改變了部署方式汇在,傳統(tǒng)的RNN在實時預(yù)測時面臨推理延遲和存儲占用大的問題翰萨,MIMN采用外部存儲記錄最新的存量記憶狀態(tài),增量部分新來一個行為對接一個MIMN單元糕殉,讀寫修改維護(hù)狀態(tài)亩鬼,大大降低了在線部署實時推理的延遲和記憶存儲的空間占用殖告。

MIMN原理迅速概括

MIMN論文涉及好幾個獨立的知識點,作者的創(chuàng)新是將這些技術(shù)串起來解決了一個實際的問題雳锋,其中設(shè)計的子模塊包括NTM神經(jīng)圖靈機(jī)黄绩,MIU記憶感知單元DIN注意力網(wǎng)絡(luò)三個知識點玷过,本文對于這三塊不做展開爽丹,只在整體層面介紹下幾大模塊的最用,以及內(nèi)部參數(shù)的更新維護(hù)方式辛蚊,原論文地址粤蝎。

(1)模型輸入輸出介紹

下面先從模型的輸入開始了解MIMN,左側(cè)橙色是增存量序列構(gòu)建記憶的過程袋马,右側(cè)是在線部署時的預(yù)測部分初澎。


模型架構(gòu)

對于增存量記憶構(gòu)建部分,輸入是歷史所有序列元素飞蛹,每個元素包括物品id和物品的其他上下文信息拼接的結(jié)果谤狡,序列元素輸入的目的是維護(hù)了一個M矩陣S矩陣,對于每一個用戶都有它對應(yīng)的M和S矩陣卧檐,每來一個新的序列元素墓懂,都會對M和S進(jìn)行更新

  • M矩陣:負(fù)責(zé)對用戶原始?xì)v史行為序列信息的表征,它通過NTM神經(jīng)圖靈機(jī)實現(xiàn)霉囚,通過讀頭和寫頭對NTM的結(jié)果進(jìn)行更新
  • S矩陣:負(fù)責(zé)從M矩陣中提取高階信息捕仔,配合目標(biāo)物品進(jìn)行DIN Attention從記憶中提取對目標(biāo)有益的信息,彌補(bǔ)M矩陣的不足盈罐,它通過MIU模塊實現(xiàn)

對于在線部署部分榜跌,輸入是目標(biāo)物品(Target Ad)歷史記憶的讀輸出(Read Head)盅粪,記憶感知模塊和目標(biāo)物品的Attention輸出钓葫,以及其他上下文信息(Context Feas),四大輸入拼接之后兩層全連接在softmax得到0-1的輸出票顾,預(yù)測用戶是否對目標(biāo)物品有行為交互础浮。

對于序列元素,是由歷史到現(xiàn)在所有商品/廣告形成的序列奠骄,細(xì)分的話有三種豆同,一種是歷史商品,一種是最后一個商品(或者是當(dāng)前最新的一次行為商品)含鳞,一種是目標(biāo)商品(通過召回得到的候選商品)影锈,三個作用如下

  • 歷史商品:用于刷存量構(gòu)成S和M矩陣
  • 最后一個商品:用于在線部署階段,觸發(fā)UIC更新用戶的S和M矩陣
  • 目標(biāo)商品:用于在線部署階段,調(diào)用UIC的M矩陣拿到讀頭輸出鸭廷,以及調(diào)用S矩陣進(jìn)行Attention枣抱,從而輸入全連接進(jìn)行ctr預(yù)測

搞清楚三種元素的區(qū)別基本MIMN大體上吃透一半了。

(2)模型部署介紹

模型部署也分為增存量記憶維護(hù)靴姿,和線上預(yù)測兩個部分

模型部署

虛線下面是增存量記憶維護(hù)沃但,增量和存量的行為序列產(chǎn)出UIC Server的M矩陣和S矩陣以及其他記憶信息,沒來一個新的序列元素就更新UIC的內(nèi)容佛吓,不需要全部從頭開始重新計算記憶信息。虛線上的在線預(yù)測部分垂攘,簡單而言就是根據(jù)目標(biāo)物品信息维雇,用戶靜態(tài)信息,再去UIC中拿到無延遲的記憶信息晒他,預(yù)測得到用戶對目標(biāo)響應(yīng)概率吱型。這兩個流程是完全解耦的,相當(dāng)于UIC對實時預(yù)測部分是無延遲的陨仅,不再像傳統(tǒng)RNN那樣維護(hù)歷史序列id津滞,而是維護(hù)一個歷史到現(xiàn)在為止的記憶矩陣在外部存儲即可。


MIMN源碼速覽

下面進(jìn)一步了解MIMN都從源碼開始灼伤,源碼地址触徐,源碼比較復(fù)雜涉及一些其他算法,挑一些重點記錄一下狐赡。

(1)主模型框架類

模型的主類是Model_MIMN

class Model_MIMN(Model):
    def __init__(self, n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE, BATCH_SIZE, MEMORY_SIZE, SEQ_LEN=400, Mem_Induction=0,
                 Util_Reg=0, use_negsample=False, mask_flag=False):
        super(Model_MIMN, self).__init__(n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE,
                                         BATCH_SIZE, SEQ_LEN, use_negsample, Flag="MIMN")
        self.reg = Util_Reg
...

該類繼承Model類撞鹉,Model類主要包含輸入序列id的embedding映射過程和最后的全連接過程,NTM颖侄,MIU鸟雏,DIN Attention全部在子類Model_MIMN中。

class Model(object):
    def __init__(self, n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE, BATCH_SIZE, SEQ_LEN, use_negsample=False, Flag="DNN"):
        self.model_flag = Flag
        self.reg = False
        self.use_negsample = use_negsample
        with tf.name_scope('Inputs'):
        ...
        # Embedding layer
        with tf.name_scope('Embedding_layer'):
        ...
    # 基于之前網(wǎng)絡(luò)的輸出構(gòu)造最后的全連接層
    def build_fcn_net(self, inp, use_dice=False):
        bn1 = tf.layers.batch_normalization(inputs=inp, name='bn1')
        ...

從功能上來說Model_MIMN的目的就是構(gòu)造出最后一層全連接的輸入inp览祖,inp輸入到全連接層孝鹊,全連接包含batchNorm和兩層全連接,和上圖灰色的在線預(yù)測部分內(nèi)容一致展蒂。

(2)MIMN單元

這是整個代碼的核心又活,先看MIMN單元的實例化

cell = mimn.MIMNCell(controller_units=HIDDEN_SIZE, memory_size=MEMORY_SIZE, memory_vector_dim=2 * EMBEDDING_DIM,
                             read_head_num=1, write_head_num=1,
                             reuse=False, output_dim=HIDDEN_SIZE, clip_value=20, batch_size=BATCH_SIZE,
                             mem_induction=Mem_Induction, util_reg=Util_Reg)

在Model_MIMN中實例化了一個MIMN單元,而每一個序列的輸入都會進(jìn)這個MIMN單元玄货,全局共享這個MIMN單元的模型參數(shù)皇钞,比如控制器和MIU中的GRU部分。在實例化MIMN單元的時候松捉,這一段代碼初始化了S矩陣

        if self.mem_induction > 0:
            self.channel_rnn = single_cell(self.memory_vector_dim)
            # TODO channel_rnn_state是S矩陣 [[256, 32], [256, 32], [256, 32], [256, 32]]
            self.channel_rnn_state = [self.channel_rnn.zero_state(batch_size, tf.float32) for i in range(memory_size)]
            self.channel_rnn_output = [tf.zeros(((batch_size, self.memory_vector_dim))) for i in range(memory_size)]

S矩陣為全0初始化夹界,維度是[memory_size, batch_size, memory_dim],memory_size是記憶矩陣的高,memory_dim是記憶矩陣的寬可柿,每個輸入進(jìn)來的樣本都會有有一個自己的S矩陣鸠踪。

下面初始化M矩陣的狀態(tài),當(dāng)模型才開始訓(xùn)練和用戶處于冷啟動的時候复斥,狀態(tài)需要初始化营密,M矩陣比S矩陣復(fù)雜,會多一些相關(guān)的變量

state = cell.zero_state(BATCH_SIZE, tf.float32)

注意zero_state將BATCH_SIZE傳進(jìn)去目锭,說明初始化和輸入訓(xùn)練的用戶數(shù)量有關(guān)评汰,實際是每個用戶都分配了一個初始化狀態(tài)。舉個例子看M矩陣的初始化

M = expand(
                tf.tanh(tf.get_variable('init_M', [self.memory_size, self.memory_vector_dim],
                                        initializer=tf.random_normal_initializer(mean=0.0, stddev=1e-5),
                                        trainable=False)),
                dim=0, N=batch_size)
def expand(x, dim, N):
    return tf.concat([tf.expand_dims(x, dim) for _ in range(N)], axis=dim)

對于每一個輸入的用戶痢虹,給他一個均值是0標(biāo)準(zhǔn)差是1e-5的隨機(jī)(4,32)的初始化被去,然后復(fù)制batch_size(比如256)的份數(shù),拼接成(256,4,32)的該batch下的init_M矩陣奖唯。由此可見雖然每個用戶都給到一個單獨的初始化M惨缆,但是他們初始化的結(jié)果是一模一樣的,注意該變量trainable=False丰捷,不隨著損失函數(shù)優(yōu)化迭代坯墨。同理創(chuàng)建controller_state
,read_vector病往,w_list捣染,M,key_M荣恐,w_aggre其他NTM需維護(hù)的變量液斜,其中w_list包含了讀頭和寫頭。

(3)歷史序列刷存量構(gòu)建M和S矩陣

在MIMN單元實例化和MIMN的state初始化后叠穆,作者開始將歷史200長度的序列灌入MIMN單元少漆,代碼如下

        for t in range(SEQ_LEN):
            output, state, temp_output_list = cell(self.item_his_eb[:, t, :], state)
            if mask_flag:
                # TODO mask的作用是修正狀態(tài),排除prepare階段由于padding導(dǎo)致的state變動
                state = clear_mask_state(state, begin_state, begin_channel_rnn_output, self.mask, cell, t)
            # 記錄下每個序列元素輸出的output和status
            self.mimn_o.append(output)
            self.state_list.append(state)

代碼里面通過item_his_eb[:, t, :]切片拿到了對應(yīng)步長的序列元素硼被,和當(dāng)前的state一起輸入MIMN單元示损,第一個元素對應(yīng)的state是cell.zero_state得到的狀態(tài),后面的都是在循環(huán)中更新最新的state給下一個序列元素使用嚷硫。注意這個for循環(huán)構(gòu)造了一張tensorflow長圖检访,及從第一個MIMN走到最后一個MIMN的路徑,每一個樣本仔掸,每一個批次進(jìn)來的時候脆贵,都要經(jīng)過這條路徑,互不干擾起暮,代碼里面的self.state_list可以打印出來看一下卖氨,每一個樣本的第一次state都是0初始化,不會存在參數(shù)繼承的情況。
clear_mask_state函數(shù)是避免左邊padding為0給state帶來影響筒捺,代碼如下

        def clear_mask_state(state, begin_state, begin_channel_rnn_state, mask, cell, t):
            # TODO mask[:, t] = [256, 1] => [256, 1]
            # TODO 如果mask是0相當(dāng)于將controller_state重新置為begin_state,全0初始化,否則保持原樣不變
            state["controller_state"] = (1 - tf.reshape(mask[:, t], (batch_size, 1))) * begin_state[
                "controller_state"] + tf.reshape(mask[:, t], (batch_size, 1)) * state["controller_state"]
            ...

以controller_state的計算為例柏腻,如果mask是0(代表padding了0),則左式保留controller_state打回原樣成為begin_state系吭,否則mask是1(代表不padding五嫂,是實際的序列元素),則左式刪除肯尺,右式和state["controller_state"]沒有差異保留模型對controller_state的更改沃缘。

(4)看看MIMN在做什么

下面深入這個cell(self.item_his_eb[:, t, :], state),看看MIMN在做什么则吟,代碼較長孩灯,挑提綱挈領(lǐng)的說。先看看這東西輸入輸出啥

def __call__(self, x, prev_state):
    return read_output, {
                "controller_state": controller_state,
                "read_vector_list": read_vector_list,
                "w_list": w_list,
                "M": M,
                "key_M": key_M,  # TODO key_M用完了之后沒有修改
                "w_aggre": w_aggre,
                "sum_aggre": sum_aggre
             }, output_list

輸入是當(dāng)前步長的元素embedding和當(dāng)前最新的state逾滥,輸出是讀M矩陣的輸出,最新的狀態(tài)败匹,以及讀S矩陣的輸出寨昙,簡單說一下三個輸出的代碼鏈路

  • 讀M矩陣的輸出:基于當(dāng)前輸入的序列元素,和上一個狀態(tài)的讀輸出掀亩,經(jīng)過NTM的控制器GRU舔哪,得到控制器輸出,進(jìn)一步計算得到讀寫之前記憶矩陣的w權(quán)重槽棍,通過該權(quán)重得到最新的讀輸出捉蚤,和控制器輸出拼接得到最終的read_output
  • 最新的狀態(tài):在讀M矩陣的輸出的計算過程中,同步記錄下變動的state
  • 讀S矩陣的輸出:通過將當(dāng)前步長的元素和上一個記憶矩陣輸入多通道GRU炼七,得到當(dāng)前步長的讀S矩陣的輸出缆巧,同時更新S矩陣狀態(tài)。

總結(jié)數(shù)據(jù)輸入MIMN單元之后豌拙,輸出讀M和S矩陣的輸出陕悬,以及更新M和S矩陣的參數(shù)狀態(tài),其中讀M和S矩陣的輸出要輸入最后的全連接模型進(jìn)行ctr預(yù)測按傅,更新M和S矩陣的參數(shù)狀態(tài)需要輸入給下一個序列元素進(jìn)行記憶更新來表征用戶的行為捉超。

(5)MIMN單元的后處理,構(gòu)造主模型輸入

MIMN的輸出需要準(zhǔn)備構(gòu)造為最終主模型的輸入的唯绍,首先用擁有最新的state的MIMN單元將目標(biāo)商品灌進(jìn)來走一邊拼岳,拿到讀輸出,來表征原始記憶信息况芒,第二第三全部不要惜纸,只要read_out

read_out, _, _ = cell(self.item_eb, state)

然后拿到現(xiàn)在最新的讀S矩陣的輸出,和目標(biāo)商品一起輸入給DIN Attention,提取高階特征

        if Mem_Induction == 1:
            channel_memory_tensor = tf.concat(temp_output_list, 1)
            multi_channel_hist = din_attention(self.item_eb, channel_memory_tensor, HIDDEN_SIZE, None, stag='pal')
            # TODO read_out是讀取M矩陣輸出的結(jié)果堪簿,multi_channel_hist是讀取S矩陣輸出的結(jié)果痊乾,其他都是目標(biāo)商品自身特征和上下文特征
            inp = tf.concat([self.item_eb, self.item_his_eb_sum, read_out, tf.squeeze(multi_channel_hist),
                             mean_memory * self.item_eb], 1)

最終的inp包含read_out, tf.squeeze(multi_channel_hist)這兩大主要特征椭更,以及其他上下文特征哪审。


最終輸入構(gòu)造

在回過頭來看圖示,很清楚了呀虑瀑,Target Ad拿到M的Read Head湿滓,同時和最新的S一起輸入Attention。inp最終輸入全連接進(jìn)行ctr預(yù)測舌狗。整個代碼的概覽結(jié)束叽奥,里面復(fù)雜的NTM和DIN Attention先不展開研究。


MIMN參數(shù)維護(hù)方式總結(jié)

作者的代碼是訓(xùn)練部分痛侍,該代碼的目的僅僅是訓(xùn)練出控制器GRU朝氓,MIU的GRU,DIN以及其他幾個全連接的參數(shù)主届,保存在tensorflow網(wǎng)絡(luò)中赵哲,而S和M矩陣雖然在里面也產(chǎn)出了,但是真正部署上線肯定是重新刷歷史存量所有序列得到的君丁,而不是采用padding和截取200的方式枫夺,示意圖如下

模型參數(shù)如何保存

其中NTM的讀寫w權(quán)重直接基于cos相似度計算得到,得到后直接更新M矩陣绘闷,不需要保存橡庞,其他記憶部分都是保存到外部存儲自行維護(hù),而右側(cè)部分全部是tensorflow圖來維護(hù)印蔗,不需要手動維護(hù)扒最,在線上環(huán)節(jié),讀取外部存儲拿到記憶參數(shù)喻鳄,輸入給tensorflow圖即可完成預(yù)測扼倘。
另外看一下記憶參數(shù)是如何初始化,以及如何更新的


參數(shù)的初始化和更新方式

其中有的初始化是需要模型學(xué)習(xí)的除呵,在部署的時候需要在訓(xùn)練的網(wǎng)絡(luò)中將它恢復(fù)出來再菊,否則初始化不一樣,有些初始化是0初始化是寫死的颜曾,相對而言方便一點纠拔。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市泛豪,隨后出現(xiàn)的幾起案子稠诲,更是在濱河造成了極大的恐慌侦鹏,老刑警劉巖,帶你破解...
    沈念sama閱讀 219,039評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件臀叙,死亡現(xiàn)場離奇詭異略水,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)劝萤,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,426評論 3 395
  • 文/潘曉璐 我一進(jìn)店門渊涝,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人床嫌,你說我怎么就攤上這事跨释。” “怎么了厌处?”我有些...
    開封第一講書人閱讀 165,417評論 0 356
  • 文/不壞的土叔 我叫張陵鳖谈,是天一觀的道長。 經(jīng)常有香客問我阔涉,道長缆娃,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,868評論 1 295
  • 正文 為了忘掉前任瑰排,我火速辦了婚禮龄恋,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘凶伙。我一直安慰自己,他們只是感情好它碎,可當(dāng)我...
    茶點故事閱讀 67,892評論 6 392
  • 文/花漫 我一把揭開白布函荣。 她就那樣靜靜地躺著,像睡著了一般扳肛。 火紅的嫁衣襯著肌膚如雪傻挂。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,692評論 1 305
  • 那天挖息,我揣著相機(jī)與錄音金拒,去河邊找鬼。 笑死套腹,一個胖子當(dāng)著我的面吹牛绪抛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播电禀,決...
    沈念sama閱讀 40,416評論 3 419
  • 文/蒼蘭香墨 我猛地睜開眼幢码,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了尖飞?” 一聲冷哼從身側(cè)響起症副,我...
    開封第一講書人閱讀 39,326評論 0 276
  • 序言:老撾萬榮一對情侶失蹤店雅,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后贞铣,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體闹啦,經(jīng)...
    沈念sama閱讀 45,782評論 1 316
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,957評論 3 337
  • 正文 我和宋清朗相戀三年辕坝,在試婚紗的時候發(fā)現(xiàn)自己被綠了窍奋。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,102評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡圣勒,死狀恐怖费变,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情圣贸,我是刑警寧澤挚歧,帶...
    沈念sama閱讀 35,790評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站吁峻,受9級特大地震影響滑负,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜用含,卻給世界環(huán)境...
    茶點故事閱讀 41,442評論 3 331
  • 文/蒙蒙 一矮慕、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧啄骇,春花似錦痴鳄、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,996評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至虽惭,卻和暖如春橡类,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背芽唇。 一陣腳步聲響...
    開封第一講書人閱讀 33,113評論 1 272
  • 我被黑心中介騙來泰國打工顾画, 沒想到剛下飛機(jī)就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人匆笤。 一個月前我還...
    沈念sama閱讀 48,332評論 3 373
  • 正文 我出身青樓研侣,卻偏偏與公主長得像,于是被迫代替她去往敵國和親炮捧。 傳聞我的和親對象是個殘疾皇子义辕,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,044評論 2 355

推薦閱讀更多精彩內(nèi)容