Overview
Although the idea of capsules was already proposed back in 2011, it was this paper that really brought them into people's view. So far capsules have not shown any clear advantage, but compared with ordinary CNNs they build in a lot more prior knowledge and arguably fit the human visual system better (I cannot verify this), so perhaps one day they will shine.
Main Content
[Figure: CapsNet architecture]
Let us walk through this architecture diagram directly.
- Input: a 1 x 28 x 28 image is passed through a 9 x 9 convolution (stride=1, padding=0, out_channels=256);
- the resulting 256 x 20 x 20 feature map goes through PrimaryCaps (a 9 x 9 convolution with stride=2, padding=0, out_channels=256);
- this yields a (32 x 8) x 6 x 6 feature map, interpreted as 32 x 6 x 6 x 8 = 1152 x 8, i.e. 1152 capsules, each represented by an 8-D vector (whether squash should already be applied at this point is debatable; most implementations do apply it);
- next, DigitCaps contains 10 capsules (one per class). Each of the 1152 primary capsules is connected to each of the 10 digit capsules; denote them by $u_i$ and $v_j$ respectively. A capsule in the lower layer provides an input ("prediction") to each capsule in the layer above, $\hat{u}_{j|i} = W_{ij} u_i$. Evidently there should be 1152 x 10 matrices $W_{ij}$, each of size 8 x 16, where 16 is the dimension of the output capsules. The final output of the 10 digit capsules is $v_j = \frac{\|s_j\|^2}{1 + \|s_j\|^2} \frac{s_j}{\|s_j\|}$ with $s_j = \sum_i c_{ij} \hat{u}_{j|i}$, where the coupling coefficients $c_{ij}$ are determined by a routing algorithm. Defining the final output this way follows an intuition: keep the direction of the raw output $s_j$ while letting the length of $v_j$ represent a probability (this step is called squash). A shape-check sketch follows right after this list.
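To make the bookkeeping concrete, here is a minimal shape-check sketch of the DigitCaps computation with uniform coupling coefficients (i.e. before any routing); the tensor names `u`, `W`, `u_hat`, `c` are mine for illustration, not the paper's:

import torch

u = torch.randn(1152, 8)             # 1152 primary capsules, 8-D each
W = torch.randn(1152, 10, 8, 16)     # one 8 x 16 matrix W_ij per (i, j) pair

# predictions u_hat_{j|i} = W_ij u_i  ->  1152 x 10 x 16
u_hat = torch.einsum('id,ijdk->ijk', u, W)

c = torch.full((1152, 10), 1. / 10)          # uniform coupling coefficients
s = (c.unsqueeze(-1) * u_hat).sum(dim=0)     # s_j: 10 x 16

norm = s.norm(dim=-1, keepdim=True)
v = (norm ** 2 / (1 + norm ** 2)) * (s / norm)   # squash: 10 x 16
print(u_hat.shape, s.shape, v.shape)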
First, initialize $b_{ij} = 0$ (one implementation consideration here: does $b_{ij}$ have to be re-initialized on every forward pass? Most implementations I have seen do so).
[Figure: the dynamic routing algorithm (Procedure 1 in the paper)]
Eq. 3 above is the routing softmax $c_{ij} = \frac{\exp(b_{ij})}{\sum_k \exp(b_{ik})}$.
In addition, the agreement term $\hat{u}_{j|i} \cdot v_j$ used to update $b_{ij}$ is essentially a cosine-similarity-style measure.
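For intuition, here is a toy sketch of the routing loop on made-up sizes (3 input capsules, 2 output capsules, 4-D outputs); all names here are illustrative, and the full version used by the network is in the code section below:

import torch
import torch.nn.functional as F

u_hat = torch.randn(3, 2, 4)   # predictions u_hat_{j|i}: in_caps x out_caps x dim
b = torch.zeros(3, 2)          # routing logits b_ij, initialized to 0

for _ in range(3):             # 3 routing iterations, as in the paper
    c = F.softmax(b, dim=-1)                          # Eq. 3: softmax over output capsules j
    s = (c.unsqueeze(-1) * u_hat).sum(dim=0)          # s_j = sum_i c_ij u_hat_{j|i}
    norm = s.norm(dim=-1, keepdim=True)
    v = (norm ** 2 / (1 + norm ** 2)) * (s / norm)    # squash
    b = b + (u_hat * v.unsqueeze(0)).sum(dim=-1)      # agreement: b_ij += u_hat_{j|i} . v_j
print(v.shape)  # 2 x 4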
Loss Function
The loss used is the margin loss:
$$L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\| - m^-)^2,$$
where $T_k = 1$ if and only if class $k$ is present; $m^+$ and $m^-$ are usually set to 0.9 and 0.1, and $\lambda$ is usually set to 0.5.
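As a quick sanity check (my own numbers, not from the paper): if the capsule of the true class has length 0.8 and one wrong-class capsule has length 0.3, their contributions are $(0.9 - 0.8)^2 = 0.01$ and $0.5 \cdot (0.3 - 0.1)^2 = 0.02$ respectively.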
Code
My code: it trains under SGD (but only reaches about 98% accuracy) and falls apart under Adam, so there must be a bug somewhere, although I really could not find it. A collection of many implementations can be found here.
"""
Sabour S., Frosst N., Hinton G. Dynamic Routing Between Capsules.
Neural Information Processing Systems, pp. 3856-3866, 2017.
https://arxiv.org/pdf/1710.09829.pdf
The implementation below is based on https://github.com/adambielski/CapsNet-pytorch.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def squash(s):
    # squash(s) = (||s||^2 / (1 + ||s||^2)) * (s / ||s||) = (||s|| / (1 + ||s||^2)) * s
    temp = s.norm(dim=-1, keepdim=True)
    return (temp / (1. + temp ** 2)) * s
class PrimaryCaps(nn.Module):
def __init__(
self, in_channel, out_entities,
out_dims, kernel_size, stride, padding
):
super(PrimaryCaps, self).__init__()
self.conv = nn.Conv2d(in_channel, out_entities * out_dims,
kernel_size, stride, padding)
self.out_entities = out_entities
self.out_dims = out_dims
def forward(self, inputs):
conv_outs = self.conv(inputs).permute(0, 2, 3, 1).contiguous()
outs = conv_outs.view(conv_outs.size(0), -1, self.out_dims)
return squash(outs)
class AgreeRouting(nn.Module):
def __init__(self, in_caps, out_caps, out_dims, iterations=3):
super(AgreeRouting, self).__init__()
self.in_caps = in_caps
self.out_caps = out_caps
self.out_dims = out_dims
self.iterations = iterations
@staticmethod
def softmax(inputs, dim=-1):
return F.softmax(inputs, dim=dim)
    def forward(self, inputs):
        # inputs: N x in_caps x out_caps x out_dims (the prediction vectors u_hat_{j|i})
        b = torch.zeros(inputs.size(0), self.in_caps, self.out_caps).to(inputs.device)
        for r in range(self.iterations):
            c = self.softmax(b)                        # Eq. 3: softmax over the output capsules
            s = (c.unsqueeze(-1) * inputs).sum(dim=1)  # s_j = sum_i c_ij u_hat_{j|i}: N x out_caps x out_dims
            v = squash(s)                              # N x out_caps x out_dims
            b = b + (v.unsqueeze(dim=1) * inputs).sum(dim=-1)  # agreement: b_ij += u_hat_{j|i} . v_j
        return v
class CapsLayer(nn.Module):
def __init__(self, in_caps, in_dims, out_caps, out_dims, routing):
super(CapsLayer, self).__init__()
self.in_caps = in_caps
self.in_dims = in_dims
self.routing = routing
self.weights = nn.Parameter(torch.rand(in_caps, out_caps, in_dims, out_dims))
nn.init.kaiming_uniform_(self.weights)
    def forward(self, inputs):
        # inputs: N x in_caps x in_dims
        inputs = inputs.view(inputs.size(0), self.in_caps, 1, 1, self.in_dims)
        # (N x in_caps x 1 x 1 x in_dims) @ (in_caps x out_caps x in_dims x out_dims)
        # -> N x in_caps x out_caps x 1 x out_dims; squeeze only the singleton dim
        u_pres = (inputs @ self.weights).squeeze(dim=-2)  # N x in_caps x out_caps x out_dims
        outs = self.routing(u_pres)  # N x out_caps x out_dims
        return outs
class CapsNet(nn.Module):
def __init__(self):
super(CapsNet, self).__init__()
# N x 1 x 28 x 28
        self.conv = nn.Conv2d(1, 256, 9, 1, padding=0)        # -> N x 256 x 20 x 20
        self.primarycaps = PrimaryCaps(256, 32, 8, 9, 2, 0)   # -> N x (6 x 6 x 32) x 8
        routing = AgreeRouting(32 * 6 * 6, 10, 16, 3)         # out_dims = 16, the digit capsule dimension
        self.digitlayer = CapsLayer(32 * 6 * 6, 8, 10, 16, routing)
def forward(self, inputs):
conv_outs = F.relu(self.conv(inputs))
pri_outs = self.primarycaps(conv_outs)
outs = self.digitlayer(pri_outs)
probs = outs.norm(dim=-1)
return probs
if __name__ == "__main__":
    x = torch.randn(4, 1, 28, 28)
capsnet = CapsNet()
print(capsnet(x))
def margin_loss(logits, labels, m=0.9, leverage=0.5, average=True):
    # logits: N x num_classes (the lengths ||v_k|| of the digit capsules)
    # labels: N
    # m^+ = m = 0.9, m^- = 1 - m = 0.1, lambda = leverage = 0.5
    temp1 = F.relu(m - logits) ** 2         # max(0, m^+ - ||v_k||)^2
    temp2 = F.relu(logits + m - 1) ** 2     # max(0, ||v_k|| - m^-)^2
    T = F.one_hot(labels.long(), logits.size(-1))
    loss = (temp1 * T + leverage * temp2 * (1 - T)).sum()
    if average:
        loss = loss / logits.size(0)
    # An alternative is to build T with scatter_:
    # T = torch.zeros(logits.size()).long()
    # T.scatter_(1, labels.view(-1, 1), 1)
    return loss
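For completeness, here is a minimal sketch of how these pieces could be wired together for one training step on MNIST; the data loading via torchvision and the optimizer settings are assumptions of mine, not the paper's setup:

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# assumes CapsNet and margin_loss as defined above
train_set = datasets.MNIST("./data", train=True, download=True,
                           transform=transforms.ToTensor())
loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)

model = CapsNet()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

for images, labels in loader:
    optimizer.zero_grad()
    probs = model(images)               # N x 10 capsule lengths
    loss = margin_loss(probs, labels)   # margin loss from above
    loss.backward()
    optimizer.step()
    break  # one step only, as a smoke test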