Dynamic Routing Between Capsules

Sabour S, Frosst N, Hinton G E, et al. Dynamic Routing Between Capsules[C]. neural information processing systems, 2017: 3856-3866.

雖然11年就提出了capsule的概念, 但是走入人們視線的應(yīng)該還是這篇文章吧. 雖然現(xiàn)階段, capsule沒有體現(xiàn)出什么優(yōu)勢. 不過, capsule相較于傳統(tǒng)的CNN融入了很多先驗(yàn)知識(shí), 更能夠擬合人類的視覺系統(tǒng)(我不知), 或許有一天它會(huì)大放異彩.

主要內(nèi)容

在這里插入圖片描述

直接從這個(gè)結(jié)構(gòu)圖講起吧.

  1. Input: 1 x 28 x 28 的圖片 經(jīng)過 9 x 9的卷積核(stride=1, padding=0, out_channels=256)作用;
  2. 256 x 20 x 20的特征圖, 經(jīng)過primarycaps作用(9 x 9 的卷積核(strde=2, padding=0, out_channels=256);
  3. (32 x 8) x 6 x 6的特征圖, 理解為32 x 6 x 6 x 8 = 1152 x 8, 即1152個(gè)膠囊, 每個(gè)膠囊由一個(gè)8D的向量表示u_{i}; (這個(gè)地方要不要squash, 大部分實(shí)現(xiàn)都是要的.)
  4. 接下來digitcaps中有10個(gè)caps(對應(yīng)10個(gè)類別), 1152caps和10個(gè)caps一一對應(yīng), 分別用i, j表示, 前一層的caps為后一層提供輸入, 輸入為
    \hat{u}_{j|i} = W_{ij}u_i,
    可見, 應(yīng)當(dāng)有1152 x 10個(gè)W_{ij}\in \mathbb{R}^{16\times 8}, 其中16是輸出膠囊的維度. 最后10個(gè)caps的輸出為
    s_j= \sum_{i}c_{ij}\hat{u}_{j|i}, v_j= \frac{\|s\|_j^2}{1 + \|s_j\|^2} \frac{s_j}{\|s_j\|}.

其中c_{ij}是通過一個(gè)路由算法決定的, v_j, 即最后的輸入如此定義是出于一種直覺, 即保持原始輸出(s)的方向, 同時(shí)讓v的長度表示一個(gè)概率(這一步稱為squash).

首先初始化b_{ij}=0 (這里在程序?qū)崿F(xiàn)的時(shí)候有一個(gè)考量, 是每一次都要初始化嗎, 我看大部分的實(shí)現(xiàn)都是如此的).

在這里插入圖片描述

上面的Eq.3就是
\tag{3} c_{ij}=\frac{\exp(b_{ij})}{\sum_{k}\exp(b_{ik})}.

另外\hat{\mu}_{j|i} \cdot v_j=\hat{\mu}_{j|i}^Tv_j是一種cos相似度度量.

損失函數(shù)

損失函數(shù)采用的是margin loss:
\tag{4} L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\|-m^-)^2.

m^+, m^-通常取0.9和0.1, \lambda通常取0.5.

代碼

我的代碼, 在sgd下可以訓(xùn)練(但是準(zhǔn)確率只有98), 在adam下就死翹翹了, 所以代碼肯定是有問題, 但是我實(shí)在是找不出來了, 這里有很多實(shí)現(xiàn)的匯總.



"""
Sabour S., Frosst N., Hinton G. Dynamic Routing Between Capsules.
Neural Information Processing Systems, pp. 3856-3866, 2017.
https://arxiv.org/pdf/1710.09829.pdf
The implement below refers to https://github.com/adambielski/CapsNet-pytorch.
"""


import torch
import torch.nn as nn
import torch.nn.functional as F



def squash(s):
    temp = s.norm(dim=-1, keepdim=True)
    return (temp / (1. + temp ** 2)) * s


class PrimaryCaps(nn.Module):

    def __init__(
        self, in_channel, out_entities, 
        out_dims, kernel_size, stride, padding
    ):
        super(PrimaryCaps, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_entities * out_dims, 
                            kernel_size, stride, padding)
        self.out_entities = out_entities
        self.out_dims = out_dims

    def forward(self, inputs):
        conv_outs = self.conv(inputs).permute(0, 2, 3, 1).contiguous()
        outs = conv_outs.view(conv_outs.size(0), -1, self.out_dims)
        return squash(outs)


class AgreeRouting(nn.Module):

    def __init__(self, in_caps, out_caps, out_dims, iterations=3):
        super(AgreeRouting, self).__init__()

        self.in_caps = in_caps
        self.out_caps = out_caps
        self.out_dims = out_dims
        self.iterations = iterations

    @staticmethod
    def softmax(inputs, dim=-1):
        return F.softmax(inputs, dim=dim)

    def forward(self, inputs):
        # inputs N x in_caps x out_caps x out_dims
        b = torch.zeros(inputs.size(0), self.in_caps, self.out_caps).to(inputs.device)
        for r in range(self.iterations):
            c = self.softmax(b) # N x in_caps x out_caps !!!!!!!!!
            s = (c.unsqueeze(-1) * inputs).sum(dim=1) # N x out_caps x out_dims
            v = squash(s) # N x out_caps x out_dims
            b = b + (v.unsqueeze(dim=1) * inputs).sum(dim=-1)
        return v



class CapsLayer(nn.Module):

    def __init__(self, in_caps, in_dims, out_caps, out_dims, routing):
        super(CapsLayer, self).__init__()
        self.in_caps = in_caps
        self.in_dims = in_dims
        self.routing = routing
        self.weights = nn.Parameter(torch.rand(in_caps, out_caps, in_dims, out_dims))
        nn.init.kaiming_uniform_(self.weights)

    def forward(self, inputs):
        # inputs: N x in_caps x in_dims
        inputs = inputs.view(inputs.size(0), self.in_caps, 1, 1, self.in_dims)
        u_pres = (inputs @ self.weights).squeeze() # N x in_caps x out_caps x out_dims
        outs = self.routing(u_pres) # N x out_caps x out_dims

        return outs




class CapsNet(nn.Module):

    def __init__(self):
        super(CapsNet, self).__init__()

        # N x 1 x 28 x 28
        self.conv = nn.Conv2d(1, 256, 9, 1, padding=0) # N x (32 * 8) x 20 x 20
        self.primarycaps = PrimaryCaps(256, 32, 8, 9, 2, 0) # N x (6 x 6 x 32) x 8
        routing = AgreeRouting(32 * 6 * 6, 10, 8, 3)
        self.digitlayer = CapsLayer(32 * 6 * 6, 8, 10, 16, routing)


    def forward(self, inputs):
        conv_outs = F.relu(self.conv(inputs))
        pri_outs = self.primarycaps(conv_outs)
        outs = self.digitlayer(pri_outs)
        probs = outs.norm(dim=-1)
        return probs
        


if __name__ == "__main__":

    x = torch.randn(4, 1, 28 ,28)
    capsnet = CapsNet()
    print(capsnet(x))


def margin_loss(logits, labels, m=0.9, leverage=0.5, adverage=True):
    # outs: N x num_classes x dim
    # labels: N
    temp1 = F.relu(m - logits) ** 2
    temp2 = F.relu(logits + m - 1) ** 2
    T = F.one_hot(labels.long(), logits.size(-1))
    loss = (temp1 * T + leverage * temp2 * (1 - T)).sum()
    if adverage:
        loss = loss / logits.size(0)
    # Another implement is using scatter_
    # T = torch.zero(logits.size()).long()
    # T.scatter_(dim=1, index=labels.view(-1, 1), 1.).cuda() if cuda()
    return loss

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末预麸,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子属瓣,更是在濱河造成了極大的恐慌扯躺,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,682評(píng)論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件藕畔,死亡現(xiàn)場離奇詭異马僻,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)注服,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,277評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門韭邓,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人溶弟,你說我怎么就攤上這事女淑。” “怎么了辜御?”我有些...
    開封第一講書人閱讀 165,083評(píng)論 0 355
  • 文/不壞的土叔 我叫張陵鸭你,是天一觀的道長。 經(jīng)常有香客問我擒权,道長苇本,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,763評(píng)論 1 295
  • 正文 為了忘掉前任菜拓,我火速辦了婚禮瓣窄,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘纳鼎。我一直安慰自己俺夕,他們只是感情好裳凸,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,785評(píng)論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著劝贸,像睡著了一般姨谷。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上映九,一...
    開封第一講書人閱讀 51,624評(píng)論 1 305
  • 那天梦湘,我揣著相機(jī)與錄音,去河邊找鬼件甥。 笑死捌议,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的引有。 我是一名探鬼主播瓣颅,決...
    沈念sama閱讀 40,358評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼譬正!你這毒婦竟也來了宫补?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,261評(píng)論 0 276
  • 序言:老撾萬榮一對情侶失蹤曾我,失蹤者是張志新(化名)和其女友劉穎粉怕,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體抒巢,經(jīng)...
    沈念sama閱讀 45,722評(píng)論 1 315
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡贫贝,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,900評(píng)論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了虐秦。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片平酿。...
    茶點(diǎn)故事閱讀 40,030評(píng)論 1 350
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡凤优,死狀恐怖悦陋,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情筑辨,我是刑警寧澤俺驶,帶...
    沈念sama閱讀 35,737評(píng)論 5 346
  • 正文 年R本政府宣布,位于F島的核電站棍辕,受9級(jí)特大地震影響暮现,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜楚昭,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,360評(píng)論 3 330
  • 文/蒙蒙 一栖袋、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧抚太,春花似錦塘幅、人聲如沸昔案。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,941評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽踏揣。三九已至,卻和暖如春匾乓,著一層夾襖步出監(jiān)牢的瞬間捞稿,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,057評(píng)論 1 270
  • 我被黑心中介騙來泰國打工拼缝, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留娱局,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,237評(píng)論 3 371
  • 正文 我出身青樓珍促,卻偏偏與公主長得像铃辖,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個(gè)殘疾皇子猪叙,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,976評(píng)論 2 355