siamfc-pytorch代碼講解(一):backbone&head

最近才真正開(kāi)始研究目標(biāo)跟蹤領(lǐng)域(好吧,是真的慢)凌外。就先看了一篇論文:
Fully-Convolutional Siamese Networks for Object Tracking【ECCV2016 workshop】
又因?yàn)閷W(xué)的是PyTorch框架,所以找了一份比較clean的代碼涛浙,還是pytorch1.0的:
https://github.com/huanglianghua/siamfc-pytorch
因?yàn)檫@個(gè)作者也是GOT-10k toolkit的主要貢獻(xiàn)者康辑,所以用上這個(gè)工具箱之后顯得training和test會(huì)clean一些,要能跑訓(xùn)練和測(cè)試代碼轿亮,還得去下載GOT-10k數(shù)據(jù)集疮薇,訓(xùn)練數(shù)據(jù)分成了19份,如果只是為了跑一下下一份就行我注。

論文概述

SiamFC這篇論文算是將深度神經(jīng)網(wǎng)絡(luò)較早運(yùn)用于tracking的按咒,比它還早一點(diǎn)的就是SINT了,主要是運(yùn)用了相似度學(xué)習(xí)的思想但骨,采用孿生網(wǎng)絡(luò)励七,把127×127的exemplar image z 和255×255的search image x 輸入同一個(gè)backbone(論文中就是AlexNet)也叫Embedding Network,生成各自的Embedding奔缠,然后這兩個(gè)Embedding經(jīng)過(guò)互相關(guān)計(jì)算的得到score map掠抬,其上大的位置就代表對(duì)應(yīng)位置上的Embedding相似度大,反之亦然校哎。整個(gè)訓(xùn)練流程可以用下圖表示:

SiamFC訓(xùn)練流程

個(gè)人感覺(jué)两波,訓(xùn)練就是為了優(yōu)化Embedding Network,在見(jiàn)到的序列中生成一個(gè)更好embedding,從而使生成的score map和生成的ground truth有更小的logistic loss雨女。更多細(xì)節(jié)在之后的幾篇會(huì)和代碼一起分析谚攒。

backbones.py分析

from __future__ import absolute_import

import torch.nn as nn


__all__ = ['AlexNetV1', 'AlexNetV2', 'AlexNetV3']


class _BatchNorm2d(nn.BatchNorm2d):

    def __init__(self, num_features, *args, **kwargs):
        super(_BatchNorm2d, self).__init__(
            num_features, *args, eps=1e-6, momentum=0.05, **kwargs)


class _AlexNet(nn.Module):
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x


class AlexNetV1(_AlexNet):
    output_stride = 8

    def __init__(self):
        super(AlexNetV1, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 96, 11, 2),
            _BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2))
        self.conv2 = nn.Sequential(
            nn.Conv2d(96, 256, 5, 1, groups=2),
            _BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2))
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 384, 3, 1),
            _BatchNorm2d(384),
            nn.ReLU(inplace=True))
        self.conv4 = nn.Sequential(
            nn.Conv2d(384, 384, 3, 1, groups=2),
            _BatchNorm2d(384),
            nn.ReLU(inplace=True))
        self.conv5 = nn.Sequential(
            nn.Conv2d(384, 256, 3, 1, groups=2))


class AlexNetV2(_AlexNet):
    output_stride = 4

    def __init__(self):
        super(AlexNetV2, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 96, 11, 2),
            _BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2))
        self.conv2 = nn.Sequential(
            nn.Conv2d(96, 256, 5, 1, groups=2),
            _BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 1))
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 384, 3, 1),
            _BatchNorm2d(384),
            nn.ReLU(inplace=True))
        self.conv4 = nn.Sequential(
            nn.Conv2d(384, 384, 3, 1, groups=2),
            _BatchNorm2d(384),
            nn.ReLU(inplace=True))
        self.conv5 = nn.Sequential(
            nn.Conv2d(384, 32, 3, 1, groups=2))


class AlexNetV3(_AlexNet):
    output_stride = 8

    def __init__(self):
        super(AlexNetV3, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 192, 11, 2),
            _BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2))
        self.conv2 = nn.Sequential(
            nn.Conv2d(192, 512, 5, 1),
            _BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2))
        self.conv3 = nn.Sequential(
            nn.Conv2d(512, 768, 3, 1),
            _BatchNorm2d(768),
            nn.ReLU(inplace=True))
        self.conv4 = nn.Sequential(
            nn.Conv2d(768, 768, 3, 1),
            _BatchNorm2d(768),
            nn.ReLU(inplace=True))
        self.conv5 = nn.Sequential(
            nn.Conv2d(768, 512, 3, 1),
            _BatchNorm2d(512))

這個(gè)module主要實(shí)現(xiàn)了3個(gè)AlexNet版本作為backbone,開(kāi)頭的__all__ = ['AlexNetV1', 'AlexNetV2', 'AlexNetV3']主要是為了讓別的module導(dǎo)入這個(gè)backbones.py的東西時(shí)氛堕,只能導(dǎo)入__all__后面的部分馏臭。
后面就是三個(gè)類AlexNetV1、AlexNetV2讼稚、AlexNetV3括儒,他們都集成了類_AlexNet,所以他們都是使用同樣的forward函數(shù)锐想,依次通過(guò)五個(gè)卷積層帮寻,每個(gè)卷積層使用nn.Sequential()堆疊,只是他們各自的total_stride和具體每層卷積層實(shí)現(xiàn)稍有不同(當(dāng)然跟原本的AlexNet還是有些差別的赠摇,比如通道數(shù)上):

  • AlexNetV1AlexNetV2
    • <font color=blue>共同點(diǎn):</font>conv2固逗、conv4、conv5這幾層都用了groups=2的分組卷積藕帜,這跟原來(lái)的AlexNet會(huì)更接近一點(diǎn)
    • <font color=red>不同點(diǎn):</font>conv2中的MaxPool2d的stride不一樣大烫罩,conv5層的輸出通道數(shù)不一樣
  • AlexNetV1AlexNetV3:前兩層的MaxPool2d是一樣的,但是中間層的卷積層輸入輸出通道都不一樣洽故,最后的輸出通道也不一樣贝攒,AlexNetV3最后輸出經(jīng)過(guò)了BN
  • AlexNetV2AlexNetV3:conv2中的MaxPool2d的stride不一樣,AlexNetV2最后輸出通道數(shù)小很多

其實(shí)感覺(jué)即使有這些區(qū)別时甚,但是這并不是很重要隘弊,這一部分也是整體當(dāng)中容易理解的,所以不必太去糾結(jié)為什么不一樣荒适,最后作者用的是AlexNetV1梨熙,論文中是這樣的結(jié)構(gòu),其實(shí)也就是AlexNetV1:

論文中backbone結(jié)構(gòu)

注意:有些人會(huì)感覺(jué)這里輸入輸出通道對(duì)不上吻贿,這是因?yàn)橄裨続lexNet分成了2個(gè)group串结,所以會(huì)有48->96, 192->384這樣。
也可以在此py文件下面再加一段代碼舅列,測(cè)試一下打印出的tensor的shape:

if __name__ == '__main__':
    alexnetv1 = AlexNetV1()
    import torch
    z = torch.randn(1, 3, 127, 127)
    output = alexnetv1(z)
    print(output.shape)  # torch.Size([1, 256, 6, 6])
    x = torch.randn(1, 3, 256, 256)
    output = alexnetv1(x)
    print(output.shape)  # torch.Size([1, 256, 22, 22])
    # 換成AlexNetV2依次是:
    # torch.Size([1, 32, 17, 17])肌割、torch.Size([1, 32, 49, 49])
    # 換成AlexNetV3依次是:
    # torch.Size([1, 512, 6, 6])、torch.Size([1, 512, 22, 22])

heads.py

先放代碼為敬:

class SiamFC(nn.Module):

    def __init__(self, out_scale=0.001):
        super(SiamFC, self).__init__()
        self.out_scale = out_scale
    
    def forward(self, z, x):
        return self._fast_xcorr(z, x) * self.out_scale
    
    def _fast_xcorr(self, z, x):
        # fast cross correlation
        nz = z.size(0)
        nx, c, h, w = x.size()
        x = x.view(-1, nz * c, h, w)  
        out = F.conv2d(x, z, groups=nz)  # shape:[nx/nz, nz, H, W]
        out = out.view(nx, -1, out.size(-2), out.size(-1)) #[nx, 1, H, W]
        return out
  • 為什么這里會(huì)有個(gè)out_scale帐要,根據(jù)作者說(shuō)是因?yàn)椋?zx互相關(guān)之后的值太大把敞,經(jīng)過(guò)sigmoid函數(shù)之后會(huì)使值處于梯度飽和的那塊,梯度太小榨惠,乘以out_scale就是為了避免這個(gè)奋早。
  • _fast_xcorr函數(shù)中最關(guān)鍵的部分就是F.conv2d函數(shù)了盛霎,可以通過(guò)官網(wǎng)查詢到用法

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor

  • input – input tensor of shape (\text{minibatch},\text{in\_channels} ,iH,iW)
  • weight – filters of shape (\text{out\_channels},\frac{\text{in\_channels}}{\text{groups}},kH,kW)

所以根據(jù)上面條件,可以得到:x shape:[nx/nz, nz*c, h, w] 和 z shape:[nz, c, hz, wz]耽装,最后out shape:[nx, 1, H, W]
其實(shí)最后真實(shí)喂入此函數(shù)的z embedding shape:[8, 256, 6, 6], x embedding shape:[8, 256, 20, 20], output shape:[8, 1, 15, 15]【這個(gè)之后再回過(guò)來(lái)看也行】

同樣的愤炸,也可以用下面一段代碼測(cè)試一下:

if __name__ == '__main__':
    import torch
    z = torch.randn(8, 256, 6, 6)
    x = torch.randn(8, 256, 20, 20)
    siamfc = SiamFC()
    output = siamfc(z, x)
    print(output.shape)  # torch.Size([8, 1, 15, 15])

好了,這部分先講到這里掉奄,這一塊還是算簡(jiǎn)單的规个,一般看一下應(yīng)該就能理解,之后的代碼會(huì)更具挑戰(zhàn)性姓建,嘻嘻诞仓,放一個(gè)輔助鏈接,下面這個(gè)版本中有一些動(dòng)圖速兔,還是會(huì)幫助理解的:

還有下面是GOT-10k的toolkit墅拭,可以先看一下,但是訓(xùn)練部分代碼還不是涉及很多:

下一篇

siamfc-pytorch代碼講解(二):train&siamfc

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末涣狗,一起剝皮案震驚了整個(gè)濱河市谍婉,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌镀钓,老刑警劉巖屡萤,帶你破解...
    沈念sama閱讀 210,978評(píng)論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異掸宛,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)招拙,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 89,954評(píng)論 2 384
  • 文/潘曉璐 我一進(jìn)店門唧瘾,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人别凤,你說(shuō)我怎么就攤上這事饰序。” “怎么了规哪?”我有些...
    開(kāi)封第一講書(shū)人閱讀 156,623評(píng)論 0 345
  • 文/不壞的土叔 我叫張陵求豫,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我诉稍,道長(zhǎng)蝠嘉,這世上最難降的妖魔是什么? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 56,324評(píng)論 1 282
  • 正文 為了忘掉前任杯巨,我火速辦了婚禮蚤告,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘服爷。我一直安慰自己杜恰,他們只是感情好获诈,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,390評(píng)論 5 384
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著心褐,像睡著了一般舔涎。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上逗爹,一...
    開(kāi)封第一講書(shū)人閱讀 49,741評(píng)論 1 289
  • 那天亡嫌,我揣著相機(jī)與錄音,去河邊找鬼桶至。 笑死昼伴,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的镣屹。 我是一名探鬼主播圃郊,決...
    沈念sama閱讀 38,892評(píng)論 3 405
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼女蜈!你這毒婦竟也來(lái)了持舆?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書(shū)人閱讀 37,655評(píng)論 0 266
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤伪窖,失蹤者是張志新(化名)和其女友劉穎逸寓,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體覆山,經(jīng)...
    沈念sama閱讀 44,104評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,451評(píng)論 2 325
  • 正文 我和宋清朗相戀三年簇宽,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了勋篓。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 38,569評(píng)論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡魏割,死狀恐怖譬嚣,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情钞它,我是刑警寧澤拜银,帶...
    沈念sama閱讀 34,254評(píng)論 4 328
  • 正文 年R本政府宣布,位于F島的核電站遭垛,受9級(jí)特大地震影響尼桶,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜锯仪,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,834評(píng)論 3 312
  • 文/蒙蒙 一疯汁、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧卵酪,春花似錦幌蚊、人聲如沸谤碳。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 30,725評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)蜒简。三九已至,卻和暖如春漩仙,著一層夾襖步出監(jiān)牢的瞬間搓茬,已是汗流浹背。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 31,950評(píng)論 1 264
  • 我被黑心中介騙來(lái)泰國(guó)打工队他, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留卷仑,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 46,260評(píng)論 2 360
  • 正文 我出身青樓麸折,卻偏偏與公主長(zhǎng)得像锡凝,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子垢啼,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,446評(píng)論 2 348

推薦閱讀更多精彩內(nèi)容