時間序列預(yù)測方法之 Transformer

本文鏈接個人站 | 簡書 | CSDN
版權(quán)聲明:除特別聲明外,本博客文章均采用 BY-NC-SA 許可協(xié)議叶圃。轉(zhuǎn)載請注明出處掺冠。

最近打算分享一些基于深度學(xué)習(xí)的時間序列預(yù)測方法德崭。這是第三篇。

前面介紹的 DeepARDeepState 都是基于 RNN 的模型。RNN 是序列建模的經(jīng)典方法缺猛,它通過遞歸來獲得序列的全局信息荔燎,代價是無法并行有咨。CNN 也可以用來建模序列座享,但由于卷積捕捉的是局部信息渣叛,CNN 模型往往需要通過疊加很多層才能獲得較大的感受野淳衙。后續(xù)我可能會(意思就是未必會)介紹基于 CNN 的時間序列預(yù)測方法箫攀。Google 在 2017 年發(fā)表的大作 Attention Is All You Need 為序列建模提供了另一種思路靴跛,即單純依靠注意力機(jī)制(Attention Mechanism)渡嚣,一步到位獲得全局信息识椰。Google 將這個模型稱為 Transformer裤唠。Transformer 在自然語言處理种蘸、圖像處理等領(lǐng)域都取得了很好的效果。Transformer 的結(jié)構(gòu)如下圖所示(誤

Transformer: Attention Is All You Need

今次要介紹的是一篇 NIPS 2019 的文章 Enhancing the Locality and Breaking the Memory Bottleneck of Transformer on Time Series Forecasting坦辟,該文章將 Transformer 模型應(yīng)用到時間序列預(yù)測中[1]锉走,并提出了一些改進(jìn)方向挪蹭。

我們首先介紹注意力機(jī)制梁厉,然后簡單介紹一下模型词顾,最后給出一個 demo肉盹。

Attention

根據(jù) Google 的方案上忍,Attention 定義為
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
其中 Q\in\mathbb R^{n\times d_k}睡雇, K\in\mathbb R^{m\times d_k}它抱, V\in\mathbb R^{m\times d_v}观蓄。從矩陣的維度信息來看侮穿,可以認(rèn)為 Attention 把一個 n\times d_k 的序列 Q 編碼成一個 n\times d_v 的新序列亲茅。記 Q = [q_1, q_2, \cdots, q_n]^\top克锣,K = [k_1, k_2, \cdots, k_m]^\top袭祟,V = [v_1, v_2, \cdots, v_m]^\top巾乳,可以看到 kv 是一一對應(yīng)的胆绊。單看 Q 中的每一個向量辑舷,有
\mathrm{Attention}(q_t, K, V) = \sum\limits_{s=1}^m\frac 1Z\exp\left(\frac{q_t k_s^\top}{\sqrt{d_k}}\right)v_s\qquad t=0, 1, \cdots, n
其中 Z 是 softmax 函數(shù)的歸一化因子何缓。從上式可以看出碌廓,每一個 q_t 都被編碼成了 v_1, v_2, \cdots, v_m 的加權(quán)和谷婆,v_s 所占的權(quán)重取決于 q_tk_s 的內(nèi)積(點(diǎn)乘)∫彀溃縮放因子 \sqrt{d_k} 起到一定的調(diào)節(jié)作用烤蜕,避免內(nèi)積很大時 softmax 的梯度很小讽营。這種定義下的注意力機(jī)制被稱為縮放點(diǎn)乘注意力(Scaled Dot-Product Attention)。

在 Attention 的基礎(chǔ)上泡徙,Google 又提出了 Multi-Head Attention橱鹏,其定義如下
\begin{aligned} \mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(head_1, head_2, \cdots, head_h)\\ head_i &= \mathrm{Attention} (Q_i,K_i,V_i )\\ Q_i &= QW_i^Q\\ K_i &= KW_i^K\\ V_i &= VW_i^V \end{aligned}
其中 W_i^Q,W_i^K\in\mathbb R^{d_k\times\tilde{d_k}}W_i^V\in\mathbb R^{d_v\times\tilde{d_v}}。簡單來說莉兰,就是把 Q狡蝶、KV 通過線性變換映射到不同的表示空間,然后計算 Attention贮勃;重復(fù) h 次贪惹,把得到的 h 個 Attention 的結(jié)果拼接起來寂嘉,最后輸出一個 n\times (h\tilde{d_v}) 的序列奏瞬。

注意力機(jī)制

在 Transformer 中,大部分的 Attention 都是 Self Attention(“自注意力”或“內(nèi)部注意力”)泉孩,就是在一個序列內(nèi)部做 Attention硼端,亦即 \mathrm{Attention}(X,X,X)。更準(zhǔn)確地說寓搬,是 Multi-Head Self Attention珍昨,即 \mathrm{MultiHead}(X,X,X)。Self Attention 可以理解為尋找序列 X 內(nèi)部不同位置之間的聯(lián)系句喷。

Model

前面講的基本上都是 Google 那篇 Transformer 文章的內(nèi)容镣典,現(xiàn)在我們回到時序預(yù)測這篇文章。為了避免混淆唾琼,我們用 Time Series Transformer 指代后者[2]兄春。

先回顧一下之前介紹的 DeepAR 模型:假設(shè)每個時間步的目標(biāo)值 z_t 服從概率分布 l(z_t|\theta_t);先使用 LSTM 單元計算當(dāng)前時間步的隱態(tài) h_t = \mathrm{LSTM}(h_{t-1}, z_{t-1}, x_t)锡溯,再計算概率分布的參數(shù) \theta_t = \theta(h_t)赶舆,最后通過最大化對數(shù)似然 \sum_t \log l(z_t|\theta_t) 來學(xué)習(xí)網(wǎng)絡(luò)參數(shù)。Time Series Transformer 的網(wǎng)絡(luò)結(jié)構(gòu)與 DeepAR 類似祭饭,只是將 LSTM 層替換為 Multi-Head Self Attention 層芜茵,從而不需要遞歸,而是一次性計算所有時間步的 \theta_t倡蝙。如下圖所示[3]

DeepAR 和 Time Series Transformer 的網(wǎng)絡(luò)結(jié)構(gòu)對比

需要注意的是九串,對當(dāng)前時間步做預(yù)測時,只能利用截止到當(dāng)前時間步的輸入悠咱。因此蒸辆,在計算 Attention 時需要增加一個 Mask,將矩陣 QK^\top 的上三角元素置為 -\infty析既。

在此基礎(chǔ)上躬贡,文章對模型又做了兩個改進(jìn)。

第一個改進(jìn)點(diǎn)叫做 Enhancing the locality of Transformer眼坏,字面意思就是增強(qiáng) Transformer 的局域性拂玻。時間序列中通常會有一些異常點(diǎn)酸些,一個觀測值是否應(yīng)該被視作異常相當(dāng)程度上取決于它所處的上下文環(huán)境。而 Multi-Head Self Attention
head_i = \mathrm{softmax}\left(\frac{Q_iK_i^\top}{\sqrt{\tilde{d_k}}}\cdot mask\right)V_i\\ Q_i = XW_i^Q,\quad K_i = XW_i^K,\quad V_i = XW_i^V\\
在計算序列內(nèi)部不同位置的關(guān)系時檐蚜,并沒有考慮各個位置所處的局域環(huán)境魄懂,這會使預(yù)測容易受異常值的干擾。在博客的開頭我們已經(jīng)提到卷積操作可以用來捕捉局部信息闯第。如果在計算 Q_iK_i 時使用卷積代替線性變換市栗,就可以在 Self Attention 中引入局部信息。注意咳短,由于當(dāng)前時間步不能使用未來的信息填帽,這里使用的是因果卷積(causal convolution)。后續(xù)介紹基于 CNN 的時序預(yù)測時咙好,因果卷積還會扮演重要角色篡腌。

經(jīng)典 Transformer (a, b) 和帶卷積的 Transformer (c, d)

第二個改進(jìn)點(diǎn)叫做 Breaking the memory bottleneck of Transformer. 假設(shè)序列長度為 n,Self Attention 的計算量為 O(n^2)勾效。在時序預(yù)測中嘹悼,往往要考慮長程依賴,這種情況下 memory usage 就會比較可觀了层宫。針對這一問題杨伙,文章提出了 LogSparse Self Attention 結(jié)構(gòu),使計算量減少到 O(n(\log n)^2)卒密,如下圖所示缀台。

幾種注意力機(jī)制

Code

按照慣例,這里給出一個基于 TensorFlow 的簡單 demo哮奇。需要說明的是,本 demo 沒有實(shí)現(xiàn) LogSparse Self Attention睛约。

以下定義了 Attention 層和 Transformer 模型:

import tensorflow as tf

class Attention(tf.keras.layers.Layer):
    """
    Multi-Head Convolutional Self Attention Layer
    """
    def __init__(self, dk, dv, num_heads, filter_size):
        super().__init__()
        self.dk = dk
        self.dv = dv
        self.num_heads = num_heads
        
        self.conv_q = tf.keras.layers.Conv1D(dk * num_heads, filter_size, padding='causal')
        self.conv_k = tf.keras.layers.Conv1D(dk * num_heads, filter_size, padding='causal')
        self.dense_v = tf.keras.layers.Dense(dv * num_heads)
        self.dense1 = tf.keras.layers.Dense(dv, activation='relu')
        self.dense2 = tf.keras.layers.Dense(dv)
        
    def split_heads(self, x, batch_size, dim):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    def call(self, inputs):
        batch_size, time_steps, _ = tf.shape(inputs)
        
        q = self.conv_q(inputs)
        k = self.conv_k(inputs)
        v = self.dense_v(inputs)
        
        q = self.split_heads(q, batch_size, self.dk)
        k = self.split_heads(k, batch_size, self.dk)
        v = self.split_heads(v, batch_size, self.dv)
        
        mask = 1 - tf.linalg.band_part(tf.ones((batch_size, self.num_heads, time_steps, time_steps)), -1, 0)
        
        dk = tf.cast(self.dk, tf.float32)
        
        score = tf.nn.softmax(tf.matmul(q, k, transpose_b=True)/tf.math.sqrt(dk) + mask * -1e9)
        
        outputs = tf.matmul(score, v)
        
        outputs = tf.transpose(outputs, perm=[0, 2, 1, 3])
        outputs = tf.reshape(outputs, (batch_size, time_steps, -1))
        
        outputs = self.dense1(outputs)
        outputs = self.dense2(outputs)
        
        return outputs

class Transformer(tf.keras.models.Model):
    """
    Time Series Transformer Model
    """
    def __init__(self, dk, dv, num_heads, filter_size):
        super().__init__()
        # 注意鼎俘,文章中使用了多層 Attention,為了簡單起見辩涝,本 demo 只使用一層
        self.attention = Attention(dk, dv, num_heads, filter_size)
        self.dense_mu = tf.keras.layers.Dense(1)
        self.dense_sigma = tf.keras.layers.Dense(1, activation='softplus')
    
    def call(self, inputs):
        outputs = self.attention(inputs)
        mu = self.dense_mu(outputs)
        sigma = self.dense_sigma(outputs)
        
        return [mu, sigma]

關(guān)于損失函數(shù)和訓(xùn)練部分贸伐,請參考我們在介紹 DeepAR 時給出的 demo。

為了驗(yàn)證代碼怔揩,我們隨機(jī)生成一個帶有周期的時間序列捉邢。下圖展示了這個序列的一部分?jǐn)?shù)據(jù)點(diǎn)。


時間序列

與 DeepAR 有所不同的是商膊,由于 Attention 結(jié)構(gòu)并不能很好地捕捉序列的順序伏伐,我們加入了相對位置作為特征。

經(jīng)過訓(xùn)練后用于預(yù)測晕拆,效果如下圖所示藐翎,其中陰影部分表示 0.05 分位數(shù) ~ 0.95 分位數(shù)的區(qū)間。


預(yù)測效果

與 DeepAR 對比

  • 從某種意義上來說,兩者的網(wǎng)絡(luò)結(jié)構(gòu)很像吝镣,學(xué)習(xí)的也都是概率分布的參數(shù)堤器。
  • Attention 結(jié)構(gòu)本身不能很好地捕捉序列的順序,當(dāng)然這個不是大問題末贾,因?yàn)橥ǔ碚f時序預(yù)測任務(wù)都會有時間特征闸溃,不需要像自然語言處理時那樣加入額外的位置編碼。
  • 該文章中給出的實(shí)驗(yàn)結(jié)果表明 Time Series Transformer 在捕捉長程依賴方面比 DeepAR 更有優(yōu)勢拱撵。
  • Time Series Transformer 在訓(xùn)練的時候可以并行計算圈暗,這是優(yōu)于 DeepAR 的。不過因?yàn)楹?DeepAR 一樣采用了自回歸結(jié)構(gòu)裕膀,預(yù)測的時候無法并行员串。不僅如此,DeepAR 預(yù)測單個時間步時只需要使用當(dāng)前輸入和上一步輸出的隱態(tài)即可昼扛;而 Transformer 卻需要計算全局的 Attention寸齐。因此在預(yù)測的時候,Transformer 的計算效率很可能不如 DeepAR抄谐。

  1. 嚴(yán)格來說渺鹦,該文章使用的并不是 Transformer。Google 的 Transformer 是一個 Encoder-Decoder 的結(jié)構(gòu)蛹含,而該文章所使用的網(wǎng)絡(luò)結(jié)構(gòu)實(shí)際上是把 DeepAR 中的 LSTM 層替換為 Multi-Head Self Attention 層毅厚。這個結(jié)構(gòu)其實(shí)一個(不完整的) Transformer Decoder 部分。But, whatever... ?

  2. 原文沒有這種說法浦箱。 ?

  3. 這個圖是我自己畫的吸耿,原文沒有。 ?

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
禁止轉(zhuǎn)載酷窥,如需轉(zhuǎn)載請通過簡信或評論聯(lián)系作者咽安。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市蓬推,隨后出現(xiàn)的幾起案子妆棒,更是在濱河造成了極大的恐慌,老刑警劉巖沸伏,帶你破解...
    沈念sama閱讀 206,378評論 6 481
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件糕珊,死亡現(xiàn)場離奇詭異,居然都是意外死亡毅糟,警方通過查閱死者的電腦和手機(jī)红选,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,356評論 2 382
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來留特,“玉大人纠脾,你說我怎么就攤上這事玛瘸。” “怎么了苟蹈?”我有些...
    開封第一講書人閱讀 152,702評論 0 342
  • 文/不壞的土叔 我叫張陵糊渊,是天一觀的道長。 經(jīng)常有香客問我慧脱,道長渺绒,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 55,259評論 1 279
  • 正文 為了忘掉前任菱鸥,我火速辦了婚禮宗兼,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘氮采。我一直安慰自己殷绍,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,263評論 5 371
  • 文/花漫 我一把揭開白布鹊漠。 她就那樣靜靜地躺著主到,像睡著了一般。 火紅的嫁衣襯著肌膚如雪躯概。 梳的紋絲不亂的頭發(fā)上登钥,一...
    開封第一講書人閱讀 49,036評論 1 285
  • 那天,我揣著相機(jī)與錄音娶靡,去河邊找鬼牧牢。 笑死,一個胖子當(dāng)著我的面吹牛姿锭,可吹牛的內(nèi)容都是我干的塔鳍。 我是一名探鬼主播,決...
    沈念sama閱讀 38,349評論 3 400
  • 文/蒼蘭香墨 我猛地睜開眼艾凯,長吁一口氣:“原來是場噩夢啊……” “哼献幔!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起趾诗,我...
    開封第一講書人閱讀 36,979評論 0 259
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎蹬蚁,沒想到半個月后恃泪,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 43,469評論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡犀斋,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 35,938評論 2 323
  • 正文 我和宋清朗相戀三年贝乎,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片叽粹。...
    茶點(diǎn)故事閱讀 38,059評論 1 333
  • 序言:一個原本活蹦亂跳的男人離奇死亡览效,死狀恐怖却舀,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情锤灿,我是刑警寧澤挽拔,帶...
    沈念sama閱讀 33,703評論 4 323
  • 正文 年R本政府宣布,位于F島的核電站但校,受9級特大地震影響螃诅,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜状囱,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,257評論 3 307
  • 文/蒙蒙 一术裸、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧亭枷,春花似錦袭艺、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,262評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至宣鄙,卻和暖如春袍镀,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背冻晤。 一陣腳步聲響...
    開封第一講書人閱讀 31,485評論 1 262
  • 我被黑心中介騙來泰國打工苇羡, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人鼻弧。 一個月前我還...
    沈念sama閱讀 45,501評論 2 354
  • 正文 我出身青樓设江,卻偏偏與公主長得像,于是被迫代替她去往敵國和親攘轩。 傳聞我的和親對象是個殘疾皇子叉存,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,792評論 2 345