[mmaction2版本] 視頻分類(二) TIN:Temporal Interlacing Network 原理及代碼講解

接著上一篇文章TSM視頻理解, 今天介紹新的視頻分類網(wǎng)絡(luò)TIN(Temporal Interlacing Network)。相對于TSM综看,TIN可以更靈活的基于交錯網(wǎng)絡(luò)預(yù)測出我們的特征圖的隨著時(shí)間的偏移量值而不是向TSM每次移動一位進(jìn)行特征融合避诽,如果不了解的具體可以通過下面的內(nèi)容進(jìn)行理解信夫。我們?nèi)匀换趍maction框架進(jìn)行講解裕照。
paper: Temporal Interlacing Network
code: mmaction2

一、 原理介紹


該模型是基于時(shí)間交錯網(wǎng)絡(luò)進(jìn)行行為識別壁熄,當(dāng)時(shí)其在速度上實(shí)現(xiàn)了SOTA 6倍的加速耻台,同時(shí)在準(zhǔn)確率上實(shí)現(xiàn)了4%的提升。該方法的思想和TSM的思想一樣都是希望將時(shí)間信息嵌入到空間信息特征去紫皇,以便可以同時(shí)一次同時(shí)聯(lián)合學(xué)習(xí)兩個(gè)域中的信息慰安。作者發(fā)明此網(wǎng)絡(luò)的直覺做出了如下解釋

In order to integrate temporal information at different times, we can provide different frames with a unique interlacing offset. Instead of habitually assigning each channel with a separately learnable offset, we adopt distinctive offsets for different channel groups. As observed in SlowFast (Feichtenhofer et al. 2018), human perception on object motion focuses on different temporal resolutions. To maintain temporal fidelity and recognize spatial semantics jointly, different groups of temporal receptive fields pursuit a thorough separation of expertise convolution. Besides, groups of offsets also reduce the model complexity as well as stabilize the training procedure across heavy backbone architectures.(為了整合不同時(shí)間的時(shí)間信息,我們可以為不同的幀提供獨(dú)特的交錯偏移聪铺。 我們沒有習(xí)慣性地為每個(gè)通道分配一個(gè)可單獨(dú)學(xué)習(xí)的偏移量化焕,而是為不同的通道組采用不同的偏移量。 正如在 SlowFast (Feichtenhofer et al. 2018) 中所觀察到的铃剔,人類對物體運(yùn)動的感知側(cè)重于不同的時(shí)間分辨率撒桨。 為了保持時(shí)間保真度并共同識別空間語義,不同組的時(shí)間感受野追求專業(yè)卷積的徹底分離键兜。 此外凤类,偏移組還降低了模型的復(fù)雜性,并穩(wěn)定了跨重型骨干架構(gòu)的訓(xùn)練過程)

主要原理圖如下所示:

Deformable Shift Module

時(shí)間交錯框架如上圖所示普气,如圖(a) 展示(b)結(jié)構(gòu)位于殘差神經(jīng)網(wǎng)絡(luò)之前的結(jié)構(gòu)谜疤。對于整個(gè)特征圖,會將前\frac{3}{4}的通道保持不變现诀,再會將\frac{1}{4}的通道進(jìn)行分組夷磕,這里我們會分為4組,(兩組沿著T維度的偏移量仔沿,剩下的兩組是偏移量是這前兩組的相反值坐桩, 這樣做可以保證信息在時(shí)序維度上的流動是對稱的,有利于 后續(xù)特征的融合)因?yàn)樽髡邔?shí)驗(yàn)發(fā)現(xiàn)兩組的效果是最好的于未, 這樣每組對應(yīng)不同偏移量撕攒。
Accuracies with different numbers of groups and reverse offsets

這些偏移量是怎么預(yù)測出來的呢陡鹃?還是要對應(yīng)上圖的原理圖, 對于輸入的特征圖我們首先會輸入到3D平均池化網(wǎng)絡(luò),接著分別輸入到OffsetNet網(wǎng)絡(luò)以及WeightNet在將兩者結(jié)合即可得到我們的偏移網(wǎng)絡(luò)的特征圖抖坪。OffsetNet主要負(fù)責(zé)預(yù)測偏移量而WeightNet主要負(fù)責(zé)預(yù)測融合后的時(shí)序維度上的特征權(quán)重萍鲸。

如果原始輸入是8幀,該網(wǎng)絡(luò)便會為每組輸出8個(gè)值分別代表每一幀的權(quán)重然后會直接用此值來加權(quán)融合過后每一幀的feature擦俐。我們也同時(shí)發(fā)現(xiàn)位于兩端的幀所預(yù)測的權(quán)重大多會比較低脊阴,這里我們的猜想是兩端的幀的特征在沿著時(shí)序移動時(shí)由于一邊沒有其他幀會損失掉一部分,因此導(dǎo)致了網(wǎng)絡(luò)給他們一個(gè)較低的權(quán)重來彌補(bǔ)信息損失帶來的影響蚯瞧。

可微模塊的具體框架如下所示:


它可以將各組按channel維度切分出來的特征沿著時(shí)間維度移動任意個(gè)單位嘿期。其實(shí)現(xiàn)方式主要是通過一維線性差值實(shí)現(xiàn)的。其中我們還采用了時(shí)序擴(kuò)展技術(shù)埋合,以保證偏移之后位于視頻之外的特征不為空备徐。舉個(gè)例子,原本位于T=0的特征在向前偏移0.5個(gè)單位后便位于T=-0.5的位置甚颂,該位置理論上是不存在特征的蜜猾,但我們通過假設(shè)T=-1位置的特征全為0使位于-0.5的位置取到了特征,也即Feature(T=-0.5) = (Feature(T=-1) + Feature(T=0))振诬。

1.1. Temporal-wise Frame Sampling

這里需要好好講解Temporal-wise Frame Sampling蹭睡, 該過程是一個(gè)線性插值的過程。

針對上面的描述赶么,這邊用一張圖片來進(jìn)行解釋說明肩豁。


Temporal-wise Frame Sampling

1.2. Temporal Extension
Temporal Extension

部分特征可能被移出而變?yōu)?,進(jìn)而在訓(xùn)練階段損失梯度辫呻。輸入范圍是[1, T]清钥,為了減輕這個(gè)現(xiàn)象帶來的影響,設(shè)置一個(gè)buffer來存儲處于(0,1)與(T,T+1)間隔中被移出的特征印屁。超出T+1與小于0的部分會被置0

1.3. Temporal Attention

關(guān)于這里的Temporal Attentation則是基于WeightNet生成的權(quán)重進(jìn)行, 與OffsetNet進(jìn)行組合循捺。

二、 代碼介紹


這里有關(guān)于數(shù)據(jù)及數(shù)據(jù)預(yù)處理可以參考前面的[mmaction2版本] 視頻分類(一) TSM:Temporal Shift Module for Efficient Video Understanding 原理及代碼講解這篇博客進(jìn)行理解雄人。

2.1. 特征提取網(wǎng)絡(luò)

本代碼的特征提取網(wǎng)絡(luò)是基于Resnet模型基礎(chǔ)上進(jìn)行改進(jìn)从橘。代碼如下所示:

blocks = list(stage.children())
for i, b in enumerate(blocks):
    if i % n_round == 0:
        tds = TemporalInterlace(
                   b.conv1.in_channels,
                   num_segments=num_segments,
                   shift_div=shift_div)
        blocks[i].conv1.conv = CombineNet(tds, blocks[i].conv1.conv)
return nn.Sequential(*blocks)

self.layer1 = make_block_interlace(self.layer1, num_segment_list[0], self.shift_div)
self.layer2 = make_block_interlace(self.layer2, num_segment_list[1], self.shift_div)
self.layer3 = make_block_interlace(self.layer3, num_segment_list[2], self.shift_div)
self.layer4 = make_block_interlace(self.layer4, num_segment_list[3], self.shift_div)

我們先看下self.layer1

Sequential(
  (0): Bottleneck(
    (conv1): ConvModule(
      (conv): CombineNet(
        (net1): TemporalInterlace(
          (offset_net): OffsetNet(
            (sigmoid): Sigmoid()
            (conv): Conv1d(16, 1, kernel_size=(3,), stride=(1,), padding=(1,))
            (fc1): Linear(in_features=8, out_features=8, bias=True)
            (relu): ReLU()
            (fc2): Linear(in_features=8, out_features=2, bias=True)
          )
          (weight_net): WeightNet(
            (sigmoid): Sigmoid()
            (conv): Conv1d(16, 2, kernel_size=(3,), stride=(1,), padding=(1,))
          )
        )
        (net2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
    (downsample): ConvModule(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): ConvModule(
      (conv): CombineNet(
        (net1): TemporalInterlace(
          (offset_net): OffsetNet(
            (sigmoid): Sigmoid()
            (conv): Conv1d(64, 1, kernel_size=(3,), stride=(1,), padding=(1,))
            (fc1): Linear(in_features=8, out_features=8, bias=True)
            (relu): ReLU()
            (fc2): Linear(in_features=8, out_features=2, bias=True)
          )
          (weight_net): WeightNet(
            (sigmoid): Sigmoid()
            (conv): Conv1d(64, 2, kernel_size=(3,), stride=(1,), padding=(1,))
          )
        )
        (net2): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
  (2): Bottleneck(
    (conv1): ConvModule(
      (conv): CombineNet(
        (net1): TemporalInterlace(
          (offset_net): OffsetNet(
            (sigmoid): Sigmoid()
            (conv): Conv1d(64, 1, kernel_size=(3,), stride=(1,), padding=(1,))
            (fc1): Linear(in_features=8, out_features=8, bias=True)
            (relu): ReLU()
            (fc2): Linear(in_features=8, out_features=2, bias=True)
          )
          (weight_net): WeightNet(
            (sigmoid): Sigmoid()
            (conv): Conv1d(64, 2, kernel_size=(3,), stride=(1,), padding=(1,))
          )
        )
        (net2): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
)
2.2 TemproalInterplace

執(zhí)行代碼如下:

class TemporalInterlace(nn.Module):
    """Temporal interlace module.

    This module is proposed in `Temporal Interlacing Network
    <https://arxiv.org/abs/2001.06499>`_

    Args:
        in_channels (int): Channel num of input features.
        num_segments (int): Number of frame segments. Default: 3.
        shift_div (int): Number of division parts for shift. Default: 1.
    """

    def __init__(self, in_channels, num_segments=3, shift_div=1):
        super().__init__()
        self.num_segments = num_segments
        self.shift_div = shift_div
        self.in_channels = in_channels
        # hard code ``deform_groups`` according to original repo.
        self.deform_groups = 2

        self.offset_net = OffsetNet(in_channels // shift_div,
                                    self.deform_groups, num_segments)
        self.weight_net = WeightNet(in_channels // shift_div,
                                    self.deform_groups)

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        # x: [N, C, H, W],
        # where N = num_batches x num_segments, C = shift_div * num_folds
        n, c, h, w = x.size() # n=48 c=64, h=56, w=56
        #print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        #print(x.size())
        #print("#####################################################")
        num_batches = n // self.num_segments
        num_folds = c // self.shift_div # self.shift_div=4

        # x_out: [num_batches x num_segments, C, H, W]
        x_out = torch.zeros((n, c, h, w), device=x.device) # x_out shape=[48, 64, 56, 56]
        # x_descriptor: [num_batches, num_segments, num_folds, H, W]
        # num_folders=16
        x_descriptor = x[:, :num_folds, :, :].view(num_batches,
                                                   self.num_segments,
                                                   num_folds, h, w)
        # x_descriptor shape [6, 8, 16, 56, 56]

        # x should only obtain information on temporal and channel dimensions
        # x_pooled: [num_batches, num_segments, num_folds, W]
        x_pooled = torch.mean(x_descriptor, 3)
        # x_pooled: [num_batches, num_segments, num_folds]
        x_pooled = torch.mean(x_pooled, 3)
        # x_pooled: [num_batches, num_folds, num_segments]
        x_pooled = x_pooled.permute(0, 2, 1).contiguous()# x_pooled shape=[6, 16, 8]

        # Calculate weight and bias, here groups = 2
        # x_offset: [num_batches, groups]
        x_offset = self.offset_net(x_pooled).view(num_batches, -1) # x_offset shape [6, 2]
        # x_weight: [num_batches, num_segments, groups]
        x_weight = self.weight_net(x_pooled)

        # x_offset: [num_batches, 2 * groups]
        x_offset = torch.cat([x_offset, -x_offset], 1) # x_offset shape [6, 4]
        # x_shift: [num_batches, num_segments, num_folds, H, W]
        x_shift = linear_sampler(x_descriptor, x_offset)

        # x_weight: [num_batches, num_segments, groups, 1]
        x_weight = x_weight[:, :, :, None]
        # x_weight:
        # [num_batches, num_segments, groups * 2, c // self.shift_div // 4]
        x_weight = x_weight.repeat(1, 1, 2, num_folds // 2 // 2)
        # x_weight:
        # [num_batches, num_segments, c // self.shift_div = num_folds]
        x_weight = x_weight.view(x_weight.size(0), x_weight.size(1), -1)

        # x_weight: [num_batches, num_segments, num_folds, 1, 1]
        x_weight = x_weight[:, :, :, None, None]
        # x_shift: [num_batches, num_segments, num_folds, H, W]
        x_shift = x_shift * x_weight
        # x_shift: [num_batches, num_segments, num_folds, H, W]
        x_shift = x_shift.contiguous().view(n, num_folds, h, w)

        # x_out: [num_batches x num_segments, C, H, W]
        x_out[:, :num_folds, :] = x_shift
        x_out[:, num_folds:, :] = x[:, num_folds:, :]

        return x_out

首先輸入x shape為[48, 3, 224, 224], 對應(yīng)的含義分別是[batch_size, channel, height, width]。在進(jìn)行convmax pool得到特征圖大小為[48, 64, 56, 56]在輸入到上述模型代碼中础钠。
n, c, h, w = x.size(), 這里的n=48, c=64, h=56, w=56, num_batches=6, num_folders=16, 再通過x_descriptor shape 為 [num_batches, num_segments, C, H, W](這里shape為[6, 8, 16, 56, 56])恰力。x_out shape為[48, 64, 56, 56]
根據(jù)論文中提到的公式

后面在通過求平均的方式torch.mean對空間信息進(jìn)行平均信息壓縮旗吁,如下面代碼所示:

x_pooled = torch.mean(x_descriptor, 3)
x_pooled = torch.mean(x_pooled, 3)

我們得到x_pooledshape為[6, 16, 8]踩萎, 之后將該結(jié)果輸入到Offset Net

2.2.1 Offset Net

先給出代碼

class OffsetNet(nn.Module):
    """OffsetNet in Temporal interlace module.

    The OffsetNet consists of one convolution layer and two fc layers
    with a relu activation following with a sigmoid function. Following
    the convolution layer, two fc layers and relu are applied to the output.
    Then, apply the sigmoid function with a multiply factor and a minus 0.5
    to transform the output to (-4, 4).

    Args:
        in_channels (int): Channel num of input features.
        groups (int): Number of groups for fc layer outputs.
        num_segments (int): Number of frame segments.
    """

    def __init__(self, in_channels, groups, num_segments):
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        # hard code ``kernel_size`` and ``padding`` according to original repo.
        kernel_size = 3
        padding = 1

        self.conv = nn.Conv1d(in_channels, 1, kernel_size, padding=padding)
        self.fc1 = nn.Linear(num_segments, num_segments)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(num_segments, groups)

        self.init_weights()

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        # The bias of the last fc layer is initialized to
        # make the post-sigmoid output start from 1
        self.fc2.bias.data[...] = 0.5108

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        # calculate offset
        # [N, C, T]
        # x shape=[6, 16, 8]
        n, _, t = x.shape # n=6, t=8
        # [N, 1, T]
        x = self.conv(x) # conv1d[16, 1], kernel_size=3  x shape=[6, 1, 8] 相當(dāng)于在通道維度降維
        # [N, T]
        x = x.view(n, t) # x shape [6, 8]
        # [N, T]
        x = self.relu(self.fc1(x)) # fc1 [8, 8] x shape[6,8]
        # [N, groups]
        x = self.fc2(x) # fc2 shape [8, 2] x shape [6,2]
        # [N, 1, groups]
        x = x.view(n, 1, -1) # x shape [6, 1, 2]

        # to make sure the output is in (-t/2, t/2)
        # where t = num_segments = 8
        x = 4 * (self.sigmoid(x) - 0.5) # x shape [6, 1, 2]  t=8 so T= 8/2=4
        # [N, 1, groups]
        return x

根據(jù)論文中的公式:


首先對于輸入x shape為[6 ,16, 8]通過fc1以及relu得到輸出shape為[6, 8], 再將其輸入到fc2網(wǎng)絡(luò)中很钓,這里的fc2的輸出通道為2, 因?yàn)檫@里的group設(shè)置為2

        self.deform_groups = 2

所以輸出shape為[6, 1, 2]香府。再經(jīng)過如下公式:


這里我們設(shè)置T為4董栽,即T=t/2(t=num_segments), 輸出x的范圍為[-2, 2]并且shape為[6, 1, 2] 對應(yīng)的代碼如下所示:

# to make sure the output is in (-t/2, t/2)
# where t = num_segments = 8
x = 4 * (self.sigmoid(x) - 0.5) # x shape [6, 1, 2]  t=8 so T= 8/2=4
# [N, 1, groups]
return x
2.2.2 Weight Net

同時(shí)我們將x并行輸入到Weight Net, 首先先給出代碼

class WeightNet(nn.Module):
    """WeightNet in Temporal interlace module.

    The WeightNet consists of two parts: one convolution layer
    and a sigmoid function. Following the convolution layer, the sigmoid
    function and rescale module can scale our output to the range (0, 2).
    Here we set the initial bias of the convolution layer to 0, and the
    final initial output will be 1.0.

    Args:
        in_channels (int): Channel num of input features.
        groups (int): Number of groups for fc layer outputs.
    """

    def __init__(self, in_channels, groups):
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        self.groups = groups

        self.conv = nn.Conv1d(in_channels, groups, 3, padding=1)

        self.init_weights()

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        # we set the initial bias of the convolution
        # layer to 0, and the final initial output will be 1.0
        self.conv.bias.data[...] = 0

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        # calculate weight
        # [N, C, T]
        # x shape=[6, 16, 8]
        n, _, t = x.shape
        # [N, groups, T]
        x = self.conv(x) # x shape [6, 2, 8]
        x = x.view(n, self.groups, t) # x shape [6, 2, 8]
        # [N, T, groups]
        x = x.permute(0, 2, 1) # x shape [6, 8, 2]

        # scale the output to range (0, 2)
        x = 2 * self.sigmoid(x)
        # [N, T, groups]
        return x

對應(yīng)的最后x輸出的shape為[6, 8, 2]范圍是(0, 2)

2.2.3 Offset Net與 Weight Net結(jié)合
  1. 首先 x_offset = torch.cat([x_offset, -x_offset], 1)offset做對稱企孩。
  2. 再去取其權(quán)重及對應(yīng)的特征
    x_shift = linear_sampler(x_descriptor, x_offset)
    具體代碼如下
def linear_sampler(data, offset):
    """Differentiable Temporal-wise Frame Sampling, which is essentially a
    linear interpolation process.

    It gets the feature map which has been split into several groups
    and shift them by different offsets according to their groups.
    Then compute the weighted sum along with the temporal dimension.

    Args:
        data (torch.Tensor): Split data for certain group in shape
            [N, num_segments, C, H, W].
        offset (torch.Tensor): Data offsets for this group data in shape
            [N, num_segments].
    """
    # [N, num_segments, C, H, W]
    n, t, c, h, w = data.shape

    # offset0, offset1: [N, num_segments]
    offset0 = torch.floor(offset).int() # offset range [-2, 1]
    offset1 = offset0 + 1 # offset1 rang e[-1, 2] # 可以看出offset0 與offset1是對稱左右移動

    # data, data0, data1: [N, num_segments, C, H * W]
    data = data.view(n, t, c, h * w).contiguous() # data shape [6, 8, 16, 3136]

    try:
        from mmcv.ops import tin_shift
    except (ImportError, ModuleNotFoundError):
        raise ImportError('Failed to import `tin_shift` from `mmcv.ops`. You '
                          'will be unable to use TIN. ')
    # data shape [6, 8, 16, 3136]
    data0 = tin_shift(data, offset0) # data0 shape [6, 8, 16, 3136]
    data1 = tin_shift(data, offset1)

    # weight0, weight1: [N, num_segments]
    weight0 = 1 - (offset - offset0.float())
    weight1 = 1 - weight0

    # weight0, weight1:
    # [N, num_segments] -> [N, num_segments, C // num_segments] -> [N, C]
    group_size = offset.shape[1]
    weight0 = weight0[:, :, None].repeat(1, 1, c // group_size)
    weight0 = weight0.view(weight0.size(0), -1)
    weight1 = weight1[:, :, None].repeat(1, 1, c // group_size)
    weight1 = weight1.view(weight1.size(0), -1)

    # weight0, weight1: [N, C] -> [N, 1, C, 1]
    weight0 = weight0[:, None, :, None]
    weight1 = weight1[:, None, :, None]

    # output: [N, num_segments, C, H * W] -> [N, num_segments, C, H, W]
    output = weight0 * data0 + weight1 * data1
    output = output.view(n, t, c, h, w)

    return output

代碼中offset0對應(yīng)圖片中O_g-n_0
上式output = weight0 * data0 + weight1 * data1 即反映了文章的精華锭碳。我們繼續(xù)拿上面的這張圖來解釋這里的代碼中weight0對應(yīng)圖中n_0+1-O_g, weight1對應(yīng)圖中O_g-n_0
weight0 * data0 + weight1 * data1對應(yīng)論文中


最終得到x_shift, 再加上weight Net得到的權(quán)重注意力相乘得到其結(jié)果x_shift = x_shift * x_weight勿璃。 這里的tin_shift原理可以看.cuh代碼如下所示:

template <typename T>
__global__ void tin_shift_forward_cuda_kernel(
    const int nthreads, const T* input, const int* shift, T* output,
    const int batch_size, const int channels, const int t_size,
    const int hw_size, const int group_size, const int group_channel) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const int hw_index = index % hw_size;
    const int j = (index / hw_size) % channels;

    const int n_index = (index / hw_size / channels) % batch_size;
    int group_id = j / group_channel;
    int t_shift = shift[n_index * group_size + group_id];
    int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index;
    for (int i = 0; i < t_size; i++) {
      int now_t = i + t_shift;
      int data_id = i * hw_size * channels + offset;
      if (now_t < 0 || now_t >= t_size) {
        continue;
      }
      int out_id = now_t * hw_size * channels + offset;
      output[out_id] = input[data_id];
    }
  }
}

剩下的部分就很簡單了擒抛,和TSM原理類似, 這里就不作多余解釋了补疑。
總結(jié)下歧沪,據(jù)我的理解是相對于TSM,在時(shí)間上基于OffsetNet偏移量是可以訓(xùn)練的莲组,再通過WeightNet可以給偏移量加權(quán)重诊胞,給更合適的偏移量更高的權(quán)重。有一個(gè)疑問就是為什么這邊的偏移量范圍是[-2, 2]的范圍锹杈,我這里的理解是相對于TSM增大了時(shí)間維度的感受野厢钧,如果更大則很多信息溢出了T,導(dǎo)致無法獲取嬉橙,所以這邊進(jìn)行了平衡,如果有其他不同的觀點(diǎn)歡迎提出寥假。

參考資料

【1】MMIT冠軍方案|用于行為識別的時(shí)間交錯網(wǎng)絡(luò)市框,商湯公開視頻理解代碼庫

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市糕韧,隨后出現(xiàn)的幾起案子枫振,更是在濱河造成了極大的恐慌,老刑警劉巖萤彩,帶你破解...
    沈念sama閱讀 211,561評論 6 492
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件粪滤,死亡現(xiàn)場離奇詭異,居然都是意外死亡雀扶,警方通過查閱死者的電腦和手機(jī)杖小,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,218評論 3 385
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來愚墓,“玉大人予权,你說我怎么就攤上這事±瞬幔” “怎么了扫腺?”我有些...
    開封第一講書人閱讀 157,162評論 0 348
  • 文/不壞的土叔 我叫張陵,是天一觀的道長村象。 經(jīng)常有香客問我笆环,道長攒至,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 56,470評論 1 283
  • 正文 為了忘掉前任躁劣,我火速辦了婚禮迫吐,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘习绢。我一直安慰自己渠抹,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,550評論 6 385
  • 文/花漫 我一把揭開白布闪萄。 她就那樣靜靜地躺著梧却,像睡著了一般。 火紅的嫁衣襯著肌膚如雪败去。 梳的紋絲不亂的頭發(fā)上放航,一...
    開封第一講書人閱讀 49,806評論 1 290
  • 那天,我揣著相機(jī)與錄音圆裕,去河邊找鬼广鳍。 笑死,一個(gè)胖子當(dāng)著我的面吹牛吓妆,可吹牛的內(nèi)容都是我干的赊时。 我是一名探鬼主播,決...
    沈念sama閱讀 38,951評論 3 407
  • 文/蒼蘭香墨 我猛地睜開眼行拢,長吁一口氣:“原來是場噩夢啊……” “哼祖秒!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起舟奠,我...
    開封第一講書人閱讀 37,712評論 0 266
  • 序言:老撾萬榮一對情侶失蹤竭缝,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后沼瘫,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體抬纸,經(jīng)...
    沈念sama閱讀 44,166評論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,510評論 2 327
  • 正文 我和宋清朗相戀三年耿戚,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了湿故。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 38,643評論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡溅话,死狀恐怖晓锻,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情飞几,我是刑警寧澤砚哆,帶...
    沈念sama閱讀 34,306評論 4 330
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響躁锁,放射性物質(zhì)發(fā)生泄漏纷铣。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,930評論 3 313
  • 文/蒙蒙 一战转、第九天 我趴在偏房一處隱蔽的房頂上張望搜立。 院中可真熱鬧,春花似錦槐秧、人聲如沸啄踊。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,745評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽颠通。三九已至,卻和暖如春膀懈,著一層夾襖步出監(jiān)牢的瞬間顿锰,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,983評論 1 266
  • 我被黑心中介騙來泰國打工启搂, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留硼控,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 46,351評論 2 360
  • 正文 我出身青樓胳赌,卻偏偏與公主長得像牢撼,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個(gè)殘疾皇子疑苫,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,509評論 2 348

推薦閱讀更多精彩內(nèi)容