【CV中的Attention機(jī)制】Non-Local Network的理解與實(shí)現(xiàn)

1. Non-local

Non-Local是王小龍?jiān)贑VPR2018年提出的一個(gè)自注意力模型赦邻。Non-Local Neural Network和Non-Local Means非局部均值去燥濾波有點(diǎn)相似的感覺盗温。普通的濾波都是3×3的卷積核闪朱,然后在整個(gè)圖片上進(jìn)行移動(dòng)挟冠,處理的是3×3局部的信息。Non-Local Means操作則是結(jié)合了一個(gè)比較大的搜索范圍弄慰,并進(jìn)行加權(quán)勺远。

在Non-Local NN這篇文章中的Local也與以上有一定關(guān)系,主要是針對(duì)感受野來說的耙替,一般的卷積的感受野都是3×3或5×5的大小亚侠,而使用Non-Local可以讓感受野很大,而不是局限于一個(gè)局部領(lǐng)域俗扇。

與之前介紹的CBAM模塊硝烂,SE模塊,BAM模塊铜幽,SK模塊類似滞谢,Non-Local也是一個(gè)易于集成的模塊串稀,針對(duì)一個(gè)feature map進(jìn)行信息的refine, 也是一種比較好的attention機(jī)制的實(shí)現(xiàn)。不過相比前幾種attention模塊狮杨,Non-Local中的attention擁有更多地理論支撐母截,稍微有點(diǎn)晦澀難懂。

Non-local的通用公式表示:

  • x是輸入信號(hào)橄教,cv中使用的一般是feature map
  • i 代表的是輸出位置清寇,如空間、時(shí)間或者時(shí)空的索引护蝶,他的響應(yīng)應(yīng)該對(duì)j進(jìn)行枚舉然后計(jì)算得到的
  • f 函數(shù)式計(jì)算i和j的相似度
  • g 函數(shù)計(jì)算feature map在j位置的表示
  • 最終的y是通過響應(yīng)因子C(x) 進(jìn)行標(biāo)準(zhǔn)化處理以后得到的

理解:與Non local mean相比颗管,就很容易理解,i 代表的是當(dāng)前位置的響應(yīng)滓走,j 代表全局響應(yīng)垦江,通過加權(quán)得到一個(gè)非局部的響應(yīng)值。

Non-Local的優(yōu)點(diǎn)是什么搅方?

  • 提出的non-local operations通過計(jì)算任意兩個(gè)位置之間的交互直接捕捉遠(yuǎn)程依賴比吭,而不用局限于相鄰點(diǎn),其相當(dāng)于構(gòu)造了一個(gè)和特征圖譜尺寸一樣大的卷積核, 從而可以維持更多信息姨涡。
  • non-local可以作為一個(gè)組件衩藤,和其它網(wǎng)絡(luò)結(jié)構(gòu)結(jié)合,經(jīng)過作者實(shí)驗(yàn)涛漂,證明了其可以應(yīng)用于圖像分類赏表、目標(biāo)檢測(cè)、目標(biāo)分割匈仗、姿態(tài)識(shí)別等視覺任務(wù)中瓢剿,并且效果有不同程度的提升。
  • Non-local在視頻分類上效果很好悠轩,在視頻分類的任務(wù)中效果可觀间狂。

2. 細(xì)節(jié)

論文中給了通用公式,然后分別介紹f函數(shù)g函數(shù)的實(shí)例化表示:

g函數(shù):可以看做一個(gè)線性轉(zhuǎn)化(Linear Embedding)公式如下:

是需要學(xué)習(xí)的權(quán)重矩陣火架,可以通過空間上的1×1卷積實(shí)現(xiàn)(實(shí)現(xiàn)起來比較簡(jiǎn)單)鉴象。


f函數(shù):這是一個(gè)用于計(jì)算i和j相似度的函數(shù),作者提出了四個(gè)具體的函數(shù)可以用作f函數(shù)何鸡。

  • Gaussian function: 具體公式如下:

這里使用的是 一個(gè)點(diǎn)乘來計(jì)算相似度纺弊,之所以點(diǎn)積可以衡量相似度,這是通過余弦相似度簡(jiǎn)化而來的骡男。

  • Embedded Gaussian: 具體公式如下:

  • Dot product: 具體公式如下:

  • Concatenation: 具體公式如下:


以上四個(gè)函數(shù)可能看起來感覺讓人讀起來很吃力淆游,下邊進(jìn)行大概解釋一下上邊符號(hào)的意義,結(jié)合示意圖(以Embeded Gaussian為例,對(duì)原圖進(jìn)行細(xì)節(jié)上加工,具體參見代碼稽犁,地址為文末鏈接中的non_local_embedded_gaussian.py文件):

image
  • x代表feature map, 代表的是當(dāng)前關(guān)注位置的信息焰望; 代表的是全局信息。

  • θ代表的是 ,實(shí)際操作是用一個(gè)1×1卷積進(jìn)行學(xué)習(xí)的已亥。

  • φ代表的是 ,實(shí)際操作是用一個(gè)1×1卷積進(jìn)行學(xué)習(xí)的熊赖。

  • g函數(shù)意義同上。

  • C(x)代表的是歸一化操作虑椎,在embedding gaussian中使用的是Sigmoid實(shí)現(xiàn)的震鹉。

然后可以將上圖(實(shí)現(xiàn)角度)與下圖(比較抽象)進(jìn)行結(jié)合理解:

image

具體解釋如下:(ps: 以下解釋帶上了bs,上圖中由于bs不方便畫圖捆姜,所以沒有添加bs)

X是一個(gè)feature map,形狀為[bs, c, h, w], 經(jīng)過三個(gè)1×1卷積核传趾,將通道縮減為原來一半(c/2)。然后將h,w兩個(gè)維度進(jìn)行flatten泥技,變?yōu)閔×w浆兰,最終形狀為[bs, c/2, h×w]的tensor。對(duì)θ對(duì)應(yīng)的tensor進(jìn)行通道重排珊豹,在線性代數(shù)中也就是轉(zhuǎn)置簸呈,得到形狀為[bs, h×w, c/2]。然后與φ代表的tensor進(jìn)行矩陣乘法店茶,得到一個(gè)形狀為[bs, h×w蜕便,h×w]的矩陣,這個(gè)矩陣計(jì)算的是相似度(或者理解為attention)贩幻。然后經(jīng)過softmax進(jìn)行歸一化轿腺,然后將該得到的矩陣 與g 經(jīng)過flatten和轉(zhuǎn)置的結(jié)果進(jìn)行矩陣相乘,得到的形狀為[bs, h*w, c/2]的結(jié)果y丛楚。然后轉(zhuǎn)置為[bs, c/2, h×w]的tensor, 然后將h×w維度重新伸展為[h, w]族壳,從而得到了形狀為[bs, c/2, h, w]的tensor。然后對(duì)這個(gè)tensor再使用一個(gè)1×1卷積核鸯檬,將通道擴(kuò)展為原來的c决侈,這樣得到了[bs, c, h, w]的tensor,與初始X的形狀是一致的。最終一步操作是將X與得到的tensor進(jìn)行相加(類似resnet中的residual block)喧务。

可能存在的問題

計(jì)算量偏大:在高階語義層引入non local layer, 也可以在具體實(shí)現(xiàn)的過程中添加pooling層來進(jìn)一步減少計(jì)算量。

3. 代碼

代碼來自官方枉圃,修改了一點(diǎn)點(diǎn)以便于理解功茴,推薦將代碼的forward部分與上圖進(jìn)行對(duì)照理解。

import torch
from torch import nn
from torch.nn import functional as F

class _NonLocalBlockND(nn.Module):
    """
    調(diào)用過程
    NONLocalBlock2D(in_channels=32),
    super(NONLocalBlock2D, self).__init__(in_channels,
            inter_channels=inter_channels,
            dimension=2, sub_sample=sub_sample,
            bn_layer=bn_layer)
    """
    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 dimension=3,
                 sub_sample=True,
                 bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            # 進(jìn)行壓縮得到channel個(gè)數(shù)
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels,
                         out_channels=self.inter_channels,
                         kernel_size=1,
                         stride=1,
                         padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels,
                        out_channels=self.in_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0), bn(self.in_channels))
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels,
                             out_channels=self.in_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels,
                             out_channels=self.inter_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)
        self.phi = conv_nd(in_channels=self.in_channels,
                           out_channels=self.inter_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c,  h, w)
        :return:
        '''

        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)#[bs, c, w*h]
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        f = torch.matmul(theta_x, phi_x)

        print(f.shape)

        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z

4. 實(shí)驗(yàn)結(jié)論

  • 文中提出了四個(gè)計(jì)算相似度的模型孽亲,實(shí)驗(yàn)對(duì)四個(gè)方法都進(jìn)行了實(shí)驗(yàn)坎穿,發(fā)現(xiàn)了這四個(gè)模型效果相差并不大,于是有一個(gè)結(jié)論:使用non-local對(duì)baseline結(jié)果是有提升的,但是不同相似度計(jì)算方法之間差距并不大玲昧,所以可以采用其中一個(gè)做實(shí)驗(yàn)即可栖茉,文中用embedding gaussian作為默認(rèn)的相似度計(jì)算方法。

  • 作者做了一系列消融實(shí)驗(yàn)來證明non local NN的有效性:

  1. 使用四個(gè)相似度計(jì)算模型孵延,發(fā)現(xiàn)影響不大吕漂,但是都比baseline效果好。
image
  1. 以ResNet50為例尘应,測(cè)試加在不同stage下的結(jié)果惶凝。可以看出在res2,3,4部分得到的結(jié)果相對(duì)baseline提升比較大犬钢,但是res5就一般了苍鲜,這有可能是由于第5個(gè)stage中的feature map的spatial size比較小,信息比較少玷犹,所以提升比較小混滔。
image
  1. 嘗試添加不同數(shù)量的non local block ,結(jié)果如下〈跬牵可以發(fā)現(xiàn)坯屿,添加越多的non local 模塊,其效果越好晴股,但是與此同時(shí)帶來的計(jì)算量也會(huì)比較大愿伴,所以要對(duì)速度和精度進(jìn)行權(quán)衡。
image
  1. Non-local 與3D卷積的對(duì)比电湘,發(fā)現(xiàn)要比3D卷積計(jì)算量小的情況下隔节,準(zhǔn)確率有較為可觀的提升。
image
  1. 作者還將Non-local block應(yīng)用在目標(biāo)檢測(cè)寂呛、實(shí)例分割怎诫、關(guān)鍵點(diǎn)檢測(cè)等領(lǐng)域〈荆可以將non-local block作為一個(gè)trick添加到目標(biāo)檢測(cè)幻妓、實(shí)例分割、關(guān)鍵點(diǎn)檢測(cè)等領(lǐng)域, 可能帶來1-3%的提升劫拢。
image

5. 評(píng)價(jià)

Non local NN從傳統(tǒng)方法Non local means中獲得靈感肉津,然后接著在神經(jīng)網(wǎng)絡(luò)中應(yīng)用了這個(gè)思想,直接融合了全局的信息舱沧,而不僅僅是通過堆疊多個(gè)卷積層獲得較為全局的信息妹沙。這樣可以為后邊的層帶來更為豐富的語義信息。

論文中也通過消融實(shí)驗(yàn)熟吏,完全證明了該模塊在視頻分類距糖,目標(biāo)檢測(cè)玄窝,實(shí)例分割、關(guān)鍵點(diǎn)檢測(cè)等領(lǐng)域的有效性悍引,但是其中并沒有給出其帶來的參數(shù)量上的變化恩脂,或者計(jì)算速度的變化。但是可以猜得到趣斤,參數(shù)量的增加還是有一定的俩块,如果對(duì)速度有要求的實(shí)驗(yàn)可能要進(jìn)行速度和精度上的權(quán)衡,不能盲目添加non local block唬渗。神經(jīng)網(wǎng)絡(luò)中還有一個(gè)常見的操作也是利用的全局信息典阵,那就是Linear層,全連接層將feature map上每一個(gè)點(diǎn)的信息都進(jìn)行了融合镊逝,Linear可以看做一種特殊的Non local操作壮啊。

之后GCNet等工作對(duì)Non-Local Neural Network結(jié)構(gòu)進(jìn)行改進(jìn),能夠大幅降低Non-Local NN的計(jì)算量撑蒜,更具有實(shí)用價(jià)值歹啼。

6. 參考內(nèi)容

論文:https://arxiv.org/abs/1711.07971

video classification 代碼:https://github.com/facebookresearch/video-nonlocal-net

non local官方實(shí)現(xiàn):https://github.com/pprp/SimpleCVReproduction/tree/master/attention/Non-local/Non-Local_pytorch_0.4.1_to_1.1.0/lib

知乎文章:https://zhuanlan.zhihu.com/p/33345791

博客:https://hellozhaozheng.github.io/z_post/計(jì)算機(jī)視覺-NonLocal-CVPR2018/


推薦閱讀:

CV中的Attention機(jī)制-最簡(jiǎn)單最易實(shí)現(xiàn)的SE模塊

CV中的Attention機(jī)制-Selective-Kernel-Networks-SE進(jìn)化版

CV中的Attention機(jī)制-CBAM模塊

CV中的Attention機(jī)制-并行版的CBAM-BAM模塊

CV中的attention機(jī)制-語義分割中的scSE模塊

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市座菠,隨后出現(xiàn)的幾起案子狸眼,更是在濱河造成了極大的恐慌,老刑警劉巖浴滴,帶你破解...
    沈念sama閱讀 217,907評(píng)論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件拓萌,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡升略,警方通過查閱死者的電腦和手機(jī)微王,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,987評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來品嚣,“玉大人炕倘,你說我怎么就攤上這事『渤牛” “怎么了罩旋?”我有些...
    開封第一講書人閱讀 164,298評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)眶诈。 經(jīng)常有香客問我涨醋,道長(zhǎng),這世上最難降的妖魔是什么逝撬? 我笑而不...
    開封第一講書人閱讀 58,586評(píng)論 1 293
  • 正文 為了忘掉前任东帅,我火速辦了婚禮,結(jié)果婚禮上球拦,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好坎炼,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,633評(píng)論 6 392
  • 文/花漫 我一把揭開白布愧膀。 她就那樣靜靜地躺著,像睡著了一般谣光。 火紅的嫁衣襯著肌膚如雪檩淋。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,488評(píng)論 1 302
  • 那天萄金,我揣著相機(jī)與錄音蟀悦,去河邊找鬼。 笑死氧敢,一個(gè)胖子當(dāng)著我的面吹牛日戈,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播孙乖,決...
    沈念sama閱讀 40,275評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼浙炼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了唯袄?” 一聲冷哼從身側(cè)響起弯屈,我...
    開封第一講書人閱讀 39,176評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤了讨,失蹤者是張志新(化名)和其女友劉穎生巡,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體罢吃,經(jīng)...
    沈念sama閱讀 45,619評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡蔬顾,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,819評(píng)論 3 336
  • 正文 我和宋清朗相戀三年宴偿,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片阎抒。...
    茶點(diǎn)故事閱讀 39,932評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡酪我,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出且叁,到底是詐尸還是另有隱情都哭,我是刑警寧澤,帶...
    沈念sama閱讀 35,655評(píng)論 5 346
  • 正文 年R本政府宣布逞带,位于F島的核電站欺矫,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏展氓。R本人自食惡果不足惜穆趴,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,265評(píng)論 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望遇汞。 院中可真熱鬧未妹,春花似錦簿废、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,871評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至化戳,卻和暖如春单料,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背点楼。 一陣腳步聲響...
    開封第一講書人閱讀 32,994評(píng)論 1 269
  • 我被黑心中介騙來泰國打工扫尖, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人掠廓。 一個(gè)月前我還...
    沈念sama閱讀 48,095評(píng)論 3 370
  • 正文 我出身青樓换怖,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國和親却盘。 傳聞我的和親對(duì)象是個(gè)殘疾皇子狰域,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,884評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容