1. 程式人生 > 實用技巧 >Dynamic Routing Between Capsules

Dynamic Routing Between Capsules

目錄

Sabour S, Frosst N, Hinton G E, et al. Dynamic Routing Between Capsules[C]. neural information processing systems, 2017: 3856-3866.

雖然11年就提出了capsule的概念, 但是走入人們視線的應該還是這篇文章吧. 雖然現階段, capsule沒有體現出什麼優勢. 不過, capsule相較於傳統的CNN融入了很多先驗知識, 更能夠擬合人類的視覺系統(我不知), 或許有一天它會大放異彩.

主要內容

直接從這個結構圖講起吧.

  1. Input:
    1 x 28 x 28 的圖片 經過 9 x 9的卷積核(stride=1, padding=0, out_channels=256)作用;
  2. 256 x 20 x 20的特徵圖, 經過primarycaps作用(9 x 9 的卷積核(strde=2, padding=0, out_channels=256);
  3. (32 x 8) x 6 x 6的特徵圖, 理解為32 x 6 x 6 x 8 = 1152 x 8, 即1152個膠囊, 每個膠囊由一個8D的向量表示\(u_{i}\); (這個地方要不要squash, 大部分實現都是要的.)
  4. 接下來digitcaps中有10個caps(對應10個類別), 1152caps和10個caps一一對應, 分別用\(i, j\)
    表示, 前一層的caps為後一層提供輸入, 輸入為

\[\hat{u}_{j|i} = W_{ij}u_i, \]

可見, 應當有1152 x 10個\(W_{ij}\in \mathbb{R}^{16\times 8}\), 其中16是輸出膠囊的維度. 最後10個caps的輸出為

\[s_j= \sum_{i}c_{ij}\hat{u}_{j|i}, v_j= \frac{\|s\|_j^2}{1 + \|s_j\|^2} \frac{s_j}{\|s_j\|}. \]

其中\(c_{ij}\)是通過一個路由演算法決定的, \(v_j\), 即最後的輸入如此定義是出於一種直覺, 即保持原始輸出(\(s\)

)的方向, 同時讓\(v\)的長度表示一個概率(這一步稱為squash).

首先初始化\(b_{ij}=0\) (這裡在程式實現的時候有一個考量, 是每一次都要初始化嗎, 我看大部分的實現都是如此的).

上面的Eq.3就是

\[\tag{3} c_{ij}=\frac{\exp(b_{ij})}{\sum_{k}\exp(b_{ik})}. \]

另外\(\hat{\mu}_{j|i} \cdot v_j=\hat{\mu}_{j|i}^Tv_j\)是一種cos相似度度量.

損失函式

損失函式採用的是margin loss:

\[\tag{4} L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\|-m^-)^2. \]

\(m^+, m^-\)通常取0.9和0.1, \(\lambda\)通常取0.5.

程式碼

我的程式碼, 在sgd下可以訓練(但是準確率只有98), 在adam下就死翹翹了, 所以程式碼肯定是有問題, 但是我實在是找不出來了, 這裡有很多實現的彙總.



"""
Sabour S., Frosst N., Hinton G. Dynamic Routing Between Capsules.
Neural Information Processing Systems, pp. 3856-3866, 2017.
https://arxiv.org/pdf/1710.09829.pdf
The implement below refers to https://github.com/adambielski/CapsNet-pytorch.
"""


import torch
import torch.nn as nn
import torch.nn.functional as F



def squash(s):
    temp = s.norm(dim=-1, keepdim=True)
    return (temp / (1. + temp ** 2)) * s


class PrimaryCaps(nn.Module):

    def __init__(
        self, in_channel, out_entities, 
        out_dims, kernel_size, stride, padding
    ):
        super(PrimaryCaps, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_entities * out_dims, 
                            kernel_size, stride, padding)
        self.out_entities = out_entities
        self.out_dims = out_dims

    def forward(self, inputs):
        conv_outs = self.conv(inputs).permute(0, 2, 3, 1).contiguous()
        outs = conv_outs.view(conv_outs.size(0), -1, self.out_dims)
        return squash(outs)


class AgreeRouting(nn.Module):

    def __init__(self, in_caps, out_caps, out_dims, iterations=3):
        super(AgreeRouting, self).__init__()

        self.in_caps = in_caps
        self.out_caps = out_caps
        self.out_dims = out_dims
        self.iterations = iterations

    @staticmethod
    def softmax(inputs, dim=-1):
        return F.softmax(inputs, dim=dim)

    def forward(self, inputs):
        # inputs N x in_caps x out_caps x out_dims
        b = torch.zeros(inputs.size(0), self.in_caps, self.out_caps).to(inputs.device)
        for r in range(self.iterations):
            c = self.softmax(b) # N x in_caps x out_caps !!!!!!!!!
            s = (c.unsqueeze(-1) * inputs).sum(dim=1) # N x out_caps x out_dims
            v = squash(s) # N x out_caps x out_dims
            b = b + (v.unsqueeze(dim=1) * inputs).sum(dim=-1)
        return v



class CapsLayer(nn.Module):

    def __init__(self, in_caps, in_dims, out_caps, out_dims, routing):
        super(CapsLayer, self).__init__()
        self.in_caps = in_caps
        self.in_dims = in_dims
        self.routing = routing
        self.weights = nn.Parameter(torch.rand(in_caps, out_caps, in_dims, out_dims))
        nn.init.kaiming_uniform_(self.weights)

    def forward(self, inputs):
        # inputs: N x in_caps x in_dims
        inputs = inputs.view(inputs.size(0), self.in_caps, 1, 1, self.in_dims)
        u_pres = (inputs @ self.weights).squeeze() # N x in_caps x out_caps x out_dims
        outs = self.routing(u_pres) # N x out_caps x out_dims

        return outs




class CapsNet(nn.Module):

    def __init__(self):
        super(CapsNet, self).__init__()

        # N x 1 x 28 x 28
        self.conv = nn.Conv2d(1, 256, 9, 1, padding=0) # N x (32 * 8) x 20 x 20
        self.primarycaps = PrimaryCaps(256, 32, 8, 9, 2, 0) # N x (6 x 6 x 32) x 8
        routing = AgreeRouting(32 * 6 * 6, 10, 8, 3)
        self.digitlayer = CapsLayer(32 * 6 * 6, 8, 10, 16, routing)


    def forward(self, inputs):
        conv_outs = F.relu(self.conv(inputs))
        pri_outs = self.primarycaps(conv_outs)
        outs = self.digitlayer(pri_outs)
        probs = outs.norm(dim=-1)
        return probs
        


if __name__ == "__main__":

    x = torch.randn(4, 1, 28 ,28)
    capsnet = CapsNet()
    print(capsnet(x))


def margin_loss(logits, labels, m=0.9, leverage=0.5, adverage=True):
    # outs: N x num_classes x dim
    # labels: N
    temp1 = F.relu(m - logits) ** 2
    temp2 = F.relu(logits + m - 1) ** 2
    T = F.one_hot(labels.long(), logits.size(-1))
    loss = (temp1 * T + leverage * temp2 * (1 - T)).sum()
    if adverage:
        loss = loss / logits.size(0)
    # Another implement is using scatter_
    # T = torch.zero(logits.size()).long()
    # T.scatter_(dim=1, index=labels.view(-1, 1), 1.).cuda() if cuda()
    return loss