Following the previous post on TSM for video understanding, this article introduces another video classification network, TIN (Temporal Interlacing Network). Compared with TSM, TIN is more flexible: instead of always shifting features by exactly one position for fusion as TSM does, its interlacing network predicts how far the feature maps should be shifted along the temporal dimension. If this is not clear yet, the content below should help. As before, the walkthrough is based on the mmaction framework.
paper: Temporal Interlacing Network
code: mmaction2
1. Method Overview
The model performs action recognition with a temporal interlacing network. At the time of publication it was reported to run 6x faster than the state of the art while improving accuracy by 4%. Like TSM, the idea is to embed temporal information into the spatial features so that information from both domains can be learned jointly in a single pass. The authors explain the intuition behind the network as follows:
In order to integrate temporal information at different times, we can provide different frames with a unique interlacing offset. Instead of habitually assigning each channel with a separately learnable offset, we adopt distinctive offsets for different channel groups. As observed in SlowFast (Feichtenhofer et al. 2018), human perception on object motion focuses on different temporal resolutions. To maintain temporal fidelity and recognize spatial semantics jointly, different groups of temporal receptive fields pursuit a thorough separation of expertise convolution. Besides, groups of offsets also reduce the model complexity as well as stabilize the training procedure across heavy backbone architectures.
The main schematic is shown below:
The temporal interlacing framework is shown above: panel (a) illustrates how the structure in panel (b) is placed in front of the convolution of each residual block. For the whole feature map, only the first C/shift_div channels (one quarter with the shift_div = 4 used here) are interlaced, while the remaining channels are kept unchanged. The interlaced channels are divided into 4 groups: two groups are shifted along the T dimension by learned offsets, and the other two groups use the negated offsets. This keeps the information flow along the temporal dimension symmetric, which helps the subsequent feature fusion; the authors found experimentally that two offset groups work best, so each group ends up with its own offset. How are these offsets predicted? Referring again to the schematic, the input feature map first goes through 3D average pooling and is then fed into two branches, OffsetNet and WeightNet; combining the two gives the shifted feature map. OffsetNet is responsible for predicting the offsets, while WeightNet predicts the weights applied along the temporal dimension after fusion.
If the raw input has 8 frames, the network outputs 8 values per group, one weight per frame, and these values are used directly to reweight each frame's features after fusion. We also observe that the frames at the two ends usually receive lower predicted weights. Our guess is that when features of boundary frames are shifted along time, part of them is lost because there is no neighboring frame on one side, so the network assigns them a lower weight to compensate for that information loss.
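To make the grouping and weighting idea concrete, here is a minimal, self-contained sketch (toy tensors and hand-picked integer offsets for illustration only, not mmaction2 code; the real offsets are fractional and learned):
import torch

T, C, shift_div = 8, 64, 4           # frames, channels, fraction of channels to interlace
x = torch.randn(T, C)                # toy per-frame features (spatial dims omitted)

num_folds = C // shift_div           # channels that get interlaced (16 here)
fold = x[:, :num_folds].clone()      # the remaining channels, x[:, num_folds:], stay untouched

offsets = [1, 2, -1, -2]             # two offsets and their negations -> 4 channel groups
group_c = num_folds // 4
for g, off in enumerate(offsets):
    part = fold[:, g * group_c:(g + 1) * group_c]
    shifted = torch.zeros_like(part)             # zeros fill the positions shifted past the ends
    if off > 0:
        shifted[off:] = part[:-off]
    elif off < 0:
        shifted[:off] = part[-off:]
    else:
        shifted = part
    fold[:, g * group_c:(g + 1) * group_c] = shifted

frame_weight = 2 * torch.sigmoid(torch.randn(T, 1))   # stand-in for WeightNet's per-frame weights
x_out = torch.cat([fold * frame_weight, x[:, num_folds:]], dim=1)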
The detailed structure of the differentiable module is shown below:
It can shift the features of each channel group by an arbitrary (fractional) number of positions along the temporal dimension, implemented mainly through one-dimensional linear interpolation. A temporal-extension technique is also used to guarantee that features shifted outside the video are not left empty. For example, a feature originally at T=0 lands at T=-0.5 after being shifted forward by 0.5 units, a position where no feature exists; by assuming that the features at T=-1 are all zero, the position -0.5 still obtains a value, i.e. Feature(T=-0.5) = (Feature(T=-1) + Feature(T=0)) / 2.
1.1. Temporal-wise Frame Sampling
Temporal-wise frame sampling deserves a closer look; it is essentially a linear interpolation process along the time axis.
To make the description above concrete, a figure is used for illustration, and a small code sketch follows below.
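Below is a minimal, hedged sketch of this sampling step for a single channel group (plain PyTorch, not the mmaction2 implementation): the fractional offset is split into its two neighboring integer shifts, each integer shift is applied with zero padding outside the valid range (the temporal extension of Section 1.2), and the two results are blended with linear-interpolation weights:
import math
import torch

def fractional_temporal_shift(x, offset):
    # x: [T, C] per-frame features; offset: float shift in frames, may be fractional
    t = x.size(0)
    off0 = math.floor(offset)        # lower neighboring integer shift
    off1 = off0 + 1                  # upper neighboring integer shift
    w1 = offset - off0               # linear-interpolation weight for off1
    w0 = 1.0 - w1                    # linear-interpolation weight for off0

    def int_shift(seq, s):
        out = torch.zeros_like(seq)  # zero padding outside [0, T-1]: the temporal extension
        for i in range(t):
            if 0 <= i + s < t:
                out[i + s] = seq[i]
        return out

    return w0 * int_shift(x, off0) + w1 * int_shift(x, off1)

x = torch.arange(8, dtype=torch.float32).view(8, 1)   # toy sequence 0, 1, ..., 7
print(fractional_temporal_shift(x, 0.5).squeeze())    # tensor([0.0, 0.5, 1.5, ..., 6.5])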
1.2. Temporal Extension
Some of the features may be shifted out of range and become zero, and would therefore lose their gradients during training. The input range is [1, T]; to mitigate this, a buffer is used to store the features that are shifted into the intervals (0, 1) and (T, T+1), while anything beyond T+1 or below 0 is set to zero.
1.3. Temporal Attention
The temporal attention here uses the weights generated by WeightNet and combines them with the features shifted according to OffsetNet, as the small sketch below illustrates.
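A minimal sketch of this combination, using the shapes from the later code walkthrough (the tensors are random placeholders, and for simplicity one weight per frame is used, whereas the real code predicts one weight per frame per channel group):
import torch

num_batches, num_segments, num_folds, h, w = 6, 8, 16, 56, 56
x_shift = torch.randn(num_batches, num_segments, num_folds, h, w)     # shifted features
x_weight = 2 * torch.sigmoid(torch.randn(num_batches, num_segments))  # per-frame weights in (0, 2)

# broadcast one weight per frame over the channel and spatial dimensions
x_shift = x_shift * x_weight[:, :, None, None, None]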
2. Code Walkthrough
For the dataset and data preprocessing, refer to the earlier post "[mmaction2 version] Video Classification (1) TSM: Temporal Shift Module for Efficient Video Understanding, principle and code walkthrough".
2.1. Feature Extraction Network
The feature extraction network in this code is built on top of ResNet. The relevant snippet (from mmaction2's ResNetTIN backbone) wraps conv1 of each selected block, as shown below:
def make_block_interlace(stage, num_segments, shift_div):
    """Wrap conv1 of selected blocks in a ResNet stage with TemporalInterlace."""
    blocks = list(stage.children())
    for i, b in enumerate(blocks):
        # n_round comes from the enclosing scope in mmaction2
        # (1 by default, 2 for very deep backbones), so not every block is interlaced
        if i % n_round == 0:
            tds = TemporalInterlace(
                b.conv1.in_channels,
                num_segments=num_segments,
                shift_div=shift_div)
            blocks[i].conv1.conv = CombineNet(tds, blocks[i].conv1.conv)
    return nn.Sequential(*blocks)
self.layer1 = make_block_interlace(self.layer1, num_segment_list[0], self.shift_div)
self.layer2 = make_block_interlace(self.layer2, num_segment_list[1], self.shift_div)
self.layer3 = make_block_interlace(self.layer3, num_segment_list[2], self.shift_div)
self.layer4 = make_block_interlace(self.layer4, num_segment_list[3], self.shift_div)
Let's first take a look at self.layer1:
Sequential(
(0): Bottleneck(
(conv1): ConvModule(
(conv): CombineNet(
(net1): TemporalInterlace(
(offset_net): OffsetNet(
(sigmoid): Sigmoid()
(conv): Conv1d(16, 1, kernel_size=(3,), stride=(1,), padding=(1,))
(fc1): Linear(in_features=8, out_features=8, bias=True)
(relu): ReLU()
(fc2): Linear(in_features=8, out_features=2, bias=True)
)
(weight_net): WeightNet(
(sigmoid): Sigmoid()
(conv): Conv1d(16, 2, kernel_size=(3,), stride=(1,), padding=(1,))
)
)
(net2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
(downsample): ConvModule(
(conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): ConvModule(
(conv): CombineNet(
(net1): TemporalInterlace(
(offset_net): OffsetNet(
(sigmoid): Sigmoid()
(conv): Conv1d(64, 1, kernel_size=(3,), stride=(1,), padding=(1,))
(fc1): Linear(in_features=8, out_features=8, bias=True)
(relu): ReLU()
(fc2): Linear(in_features=8, out_features=2, bias=True)
)
(weight_net): WeightNet(
(sigmoid): Sigmoid()
(conv): Conv1d(64, 2, kernel_size=(3,), stride=(1,), padding=(1,))
)
)
(net2): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): ConvModule(
(conv): CombineNet(
(net1): TemporalInterlace(
(offset_net): OffsetNet(
(sigmoid): Sigmoid()
(conv): Conv1d(64, 1, kernel_size=(3,), stride=(1,), padding=(1,))
(fc1): Linear(in_features=8, out_features=8, bias=True)
(relu): ReLU()
(fc2): Linear(in_features=8, out_features=2, bias=True)
)
(weight_net): WeightNet(
(sigmoid): Sigmoid()
(conv): Conv1d(64, 2, kernel_size=(3,), stride=(1,), padding=(1,))
)
)
(net2): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
)
2.2 TemporalInterlace
Its implementation is as follows:
class TemporalInterlace(nn.Module):
"""Temporal interlace module.
This module is proposed in `Temporal Interlacing Network
<https://arxiv.org/abs/2001.06499>`_
Args:
in_channels (int): Channel num of input features.
num_segments (int): Number of frame segments. Default: 3.
shift_div (int): Number of division parts for shift. Default: 1.
"""
def __init__(self, in_channels, num_segments=3, shift_div=1):
super().__init__()
self.num_segments = num_segments
self.shift_div = shift_div
self.in_channels = in_channels
# hard code ``deform_groups`` according to original repo.
self.deform_groups = 2
self.offset_net = OffsetNet(in_channels // shift_div,
self.deform_groups, num_segments)
self.weight_net = WeightNet(in_channels // shift_div,
self.deform_groups)
def forward(self, x):
"""Defines the computation performed at every call.
Args:
x (torch.Tensor): The input data.
Returns:
torch.Tensor: The output of the module.
"""
# x: [N, C, H, W],
# where N = num_batches x num_segments, C = shift_div * num_folds
n, c, h, w = x.size() # n=48 c=64, h=56, w=56
num_batches = n // self.num_segments
num_folds = c // self.shift_div # self.shift_div=4
# x_out: [num_batches x num_segments, C, H, W]
x_out = torch.zeros((n, c, h, w), device=x.device) # x_out shape=[48, 64, 56, 56]
# x_descriptor: [num_batches, num_segments, num_folds, H, W]
        # num_folds = 16
x_descriptor = x[:, :num_folds, :, :].view(num_batches,
self.num_segments,
num_folds, h, w)
# x_descriptor shape [6, 8, 16, 56, 56]
# x should only obtain information on temporal and channel dimensions
# x_pooled: [num_batches, num_segments, num_folds, W]
x_pooled = torch.mean(x_descriptor, 3)
# x_pooled: [num_batches, num_segments, num_folds]
x_pooled = torch.mean(x_pooled, 3)
# x_pooled: [num_batches, num_folds, num_segments]
x_pooled = x_pooled.permute(0, 2, 1).contiguous()# x_pooled shape=[6, 16, 8]
# Calculate weight and bias, here groups = 2
# x_offset: [num_batches, groups]
x_offset = self.offset_net(x_pooled).view(num_batches, -1) # x_offset shape [6, 2]
# x_weight: [num_batches, num_segments, groups]
x_weight = self.weight_net(x_pooled)
# x_offset: [num_batches, 2 * groups]
x_offset = torch.cat([x_offset, -x_offset], 1) # x_offset shape [6, 4]
# x_shift: [num_batches, num_segments, num_folds, H, W]
x_shift = linear_sampler(x_descriptor, x_offset)
# x_weight: [num_batches, num_segments, groups, 1]
x_weight = x_weight[:, :, :, None]
# x_weight:
# [num_batches, num_segments, groups * 2, c // self.shift_div // 4]
x_weight = x_weight.repeat(1, 1, 2, num_folds // 2 // 2)
# x_weight:
# [num_batches, num_segments, c // self.shift_div = num_folds]
x_weight = x_weight.view(x_weight.size(0), x_weight.size(1), -1)
# x_weight: [num_batches, num_segments, num_folds, 1, 1]
x_weight = x_weight[:, :, :, None, None]
# x_shift: [num_batches, num_segments, num_folds, H, W]
x_shift = x_shift * x_weight
# x_shift: [num_batches, num_segments, num_folds, H, W]
x_shift = x_shift.contiguous().view(n, num_folds, h, w)
# x_out: [num_batches x num_segments, C, H, W]
x_out[:, :num_folds, :] = x_shift
x_out[:, num_folds:, :] = x[:, num_folds:, :]
return x_out
The initial input x has shape [48, 3, 224, 224], i.e. [batch_size x num_segments, channel, height, width] (here 48 = 6 clips x 8 segments). After the stem conv and max pool, the feature map has shape [48, 64, 56, 56], and this is what enters the module above.
From n, c, h, w = x.size() we get n=48, c=64, h=56, w=56, hence num_batches=6 and num_folds=16. The descriptor x_descriptor has shape [num_batches, num_segments, num_folds, H, W], here [6, 8, 16, 56, 56], while x_out has shape [48, 64, 56, 56].
Following the formula mentioned in the paper, the spatial information is then compressed by averaging over the spatial dimensions with torch.mean, as shown in the code below:
x_pooled = torch.mean(x_descriptor, 3)
x_pooled = torch.mean(x_pooled, 3)
We obtain x_pooled with shape [6, 16, 8], which is then fed into OffsetNet.
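Note that the two successive torch.mean calls plus the later permute amount to a spatial global average pool followed by a transpose; an equivalent one-liner (a readability rewrite, not the mmaction2 code) would be:
x_pooled = x_descriptor.mean(dim=(3, 4)).permute(0, 2, 1).contiguous()  # [num_batches, num_folds, num_segments]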
2.2.1 Offset Net
Here is the code:
class OffsetNet(nn.Module):
"""OffsetNet in Temporal interlace module.
The OffsetNet consists of one convolution layer and two fc layers
with a relu activation following with a sigmoid function. Following
the convolution layer, two fc layers and relu are applied to the output.
Then, apply the sigmoid function with a multiply factor and a minus 0.5
to transform the output to (-4, 4).
Args:
in_channels (int): Channel num of input features.
groups (int): Number of groups for fc layer outputs.
num_segments (int): Number of frame segments.
"""
def __init__(self, in_channels, groups, num_segments):
super().__init__()
self.sigmoid = nn.Sigmoid()
# hard code ``kernel_size`` and ``padding`` according to original repo.
kernel_size = 3
padding = 1
self.conv = nn.Conv1d(in_channels, 1, kernel_size, padding=padding)
self.fc1 = nn.Linear(num_segments, num_segments)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(num_segments, groups)
self.init_weights()
def init_weights(self):
"""Initiate the parameters either from existing checkpoint or from
scratch."""
# The bias of the last fc layer is initialized to
# make the post-sigmoid output start from 1
self.fc2.bias.data[...] = 0.5108
def forward(self, x):
"""Defines the computation performed at every call.
Args:
x (torch.Tensor): The input data.
Returns:
torch.Tensor: The output of the module.
"""
# calculate offset
# [N, C, T]
# x shape=[6, 16, 8]
n, _, t = x.shape # n=6, t=8
# [N, 1, T]
        x = self.conv(x)  # Conv1d(16, 1), kernel_size=3; x shape [6, 1, 8]; this effectively collapses the channel dimension
# [N, T]
x = x.view(n, t) # x shape [6, 8]
# [N, T]
x = self.relu(self.fc1(x)) # fc1 [8, 8] x shape[6,8]
# [N, groups]
x = self.fc2(x) # fc2 shape [8, 2] x shape [6,2]
# [N, 1, groups]
x = x.view(n, 1, -1) # x shape [6, 1, 2]
# to make sure the output is in (-t/2, t/2)
# where t = num_segments = 8
x = 4 * (self.sigmoid(x) - 0.5) # x shape [6, 1, 2] t=8 so T= 8/2=4
# [N, 1, groups]
return x
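As a quick usage check (a sketch that assumes the OffsetNet class above has already been defined or imported):
import torch

net = OffsetNet(in_channels=16, groups=2, num_segments=8)
out = net(torch.randn(6, 16, 8))              # pooled descriptor [N, C, T]
print(out.shape)                              # torch.Size([6, 1, 2])
print(out.min().item(), out.max().item())     # always within (-2, 2)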
Following the formula in the paper: the input x of shape [6, 16, 8] is first reduced by the Conv1d (16 channels to 1) and reshaped to [6, 8], then passed through fc1 and ReLU, which keep the shape at [6, 8]. It is then fed into fc2, whose output dimension is 2 because the number of groups is set to 2 (self.deform_groups = 2), and reshaped to [6, 1, 2]. Finally it goes through the formula offset = T * (sigmoid(x) - 0.5) with T set to 4, i.e. T = t/2 where t = num_segments = 8, so the output x lies in the range (-2, 2) while keeping the shape [6, 1, 2]. The corresponding code is shown below:
# to make sure the output is in (-t/2, t/2)
# where t = num_segments = 8
x = 4 * (self.sigmoid(x) - 0.5) # x shape [6, 1, 2] t=8 so T= 8/2=4
# [N, 1, groups]
return x
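As a quick sanity check of this range and of the bias initialisation in init_weights (plain arithmetic, not part of the repository code):
import torch

b = torch.tensor(0.5108)                              # the fc2 bias set in init_weights
print(torch.sigmoid(b))                               # ~0.625
print(4 * (torch.sigmoid(b) - 0.5))                   # ~0.5: roughly the initial offset (ignoring the small random fc2 weights)
print(4 * (torch.sigmoid(torch.tensor(1e3)) - 0.5))   # saturates at the upper bound 2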
2.2.2 Weight Net
In parallel, x is also fed into WeightNet. Here is the code:
class WeightNet(nn.Module):
"""WeightNet in Temporal interlace module.
The WeightNet consists of two parts: one convolution layer
and a sigmoid function. Following the convolution layer, the sigmoid
function and rescale module can scale our output to the range (0, 2).
Here we set the initial bias of the convolution layer to 0, and the
final initial output will be 1.0.
Args:
in_channels (int): Channel num of input features.
groups (int): Number of groups for fc layer outputs.
"""
def __init__(self, in_channels, groups):
super().__init__()
self.sigmoid = nn.Sigmoid()
self.groups = groups
self.conv = nn.Conv1d(in_channels, groups, 3, padding=1)
self.init_weights()
def init_weights(self):
"""Initiate the parameters either from existing checkpoint or from
scratch."""
# we set the initial bias of the convolution
# layer to 0, and the final initial output will be 1.0
self.conv.bias.data[...] = 0
def forward(self, x):
"""Defines the computation performed at every call.
Args:
x (torch.Tensor): The input data.
Returns:
torch.Tensor: The output of the module.
"""
# calculate weight
# [N, C, T]
# x shape=[6, 16, 8]
n, _, t = x.shape
# [N, groups, T]
x = self.conv(x) # x shape [6, 2, 8]
x = x.view(n, self.groups, t) # x shape [6, 2, 8]
# [N, T, groups]
x = x.permute(0, 2, 1) # x shape [6, 8, 2]
# scale the output to range (0, 2)
x = 2 * self.sigmoid(x)
# [N, T, groups]
return x
The final output x has shape [6, 8, 2] and lies in the range (0, 2); since the conv bias is initialised to 0, the weights start out around 2 * sigmoid(0) = 1.0, as noted in the docstring.
2.2.3 Combining OffsetNet and WeightNet
- First, x_offset = torch.cat([x_offset, -x_offset], 1) mirrors the offsets, so that two channel groups shift by the predicted offsets and the other two by their negation.
- Then the interpolation weights and the corresponding features are gathered with x_shift = linear_sampler(x_descriptor, x_offset).
The detailed code is as follows:
def linear_sampler(data, offset):
"""Differentiable Temporal-wise Frame Sampling, which is essentially a
linear interpolation process.
It gets the feature map which has been split into several groups
and shift them by different offsets according to their groups.
Then compute the weighted sum along with the temporal dimension.
Args:
data (torch.Tensor): Split data for certain group in shape
[N, num_segments, C, H, W].
offset (torch.Tensor): Data offsets for this group data in shape
[N, num_segments].
"""
# [N, num_segments, C, H, W]
n, t, c, h, w = data.shape
# offset0, offset1: [N, num_segments]
offset0 = torch.floor(offset).int() # offset range [-2, 1]
    offset1 = offset0 + 1  # offset1 range [-1, 2]; offset0 and offset1 are the two integer shifts bracketing the fractional offset
# data, data0, data1: [N, num_segments, C, H * W]
data = data.view(n, t, c, h * w).contiguous() # data shape [6, 8, 16, 3136]
try:
from mmcv.ops import tin_shift
except (ImportError, ModuleNotFoundError):
raise ImportError('Failed to import `tin_shift` from `mmcv.ops`. You '
'will be unable to use TIN. ')
# data shape [6, 8, 16, 3136]
data0 = tin_shift(data, offset0) # data0 shape [6, 8, 16, 3136]
data1 = tin_shift(data, offset1)
# weight0, weight1: [N, num_segments]
weight0 = 1 - (offset - offset0.float())
weight1 = 1 - weight0
# weight0, weight1:
# [N, num_segments] -> [N, num_segments, C // num_segments] -> [N, C]
group_size = offset.shape[1]
weight0 = weight0[:, :, None].repeat(1, 1, c // group_size)
weight0 = weight0.view(weight0.size(0), -1)
weight1 = weight1[:, :, None].repeat(1, 1, c // group_size)
weight1 = weight1.view(weight1.size(0), -1)
# weight0, weight1: [N, C] -> [N, 1, C, 1]
weight0 = weight0[:, None, :, None]
weight1 = weight1[:, None, :, None]
# output: [N, num_segments, C, H * W] -> [N, num_segments, C, H, W]
output = weight0 * data0 + weight1 * data1
output = output.view(n, t, c, h, w)
return output
In the code, offset0 and offset1 correspond to the two integer positions in the interpolation figure above: the floor of the fractional offset and the next integer. The line output = weight0 * data0 + weight1 * data1 is the essence of the paper. Continuing with the same figure, weight0 corresponds to 1 - (offset - floor(offset)) and weight1 to offset - floor(offset); for an offset of 0.3, for instance, the result is 0.7 times the floor-shifted features plus 0.3 times the features shifted by the next integer. This weighted sum matches the temporal-wise frame sampling formula in the paper and finally gives x_shift, which is then multiplied by the attention weights from WeightNet: x_shift = x_shift * x_weight. The underlying tin_shift operator is implemented as a CUDA kernel (a .cuh file in mmcv), shown below:
template <typename T>
__global__ void tin_shift_forward_cuda_kernel(
const int nthreads, const T* input, const int* shift, T* output,
const int batch_size, const int channels, const int t_size,
const int hw_size, const int group_size, const int group_channel) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int hw_index = index % hw_size;
const int j = (index / hw_size) % channels;
const int n_index = (index / hw_size / channels) % batch_size;
int group_id = j / group_channel;
int t_shift = shift[n_index * group_size + group_id];
int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index;
for (int i = 0; i < t_size; i++) {
int now_t = i + t_shift;
int data_id = i * hw_size * channels + offset;
if (now_t < 0 || now_t >= t_size) {
continue;
}
int out_id = now_t * hw_size * channels + offset;
output[out_id] = input[data_id];
}
}
}
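To make the kernel easier to read, here is a hedged pure-PyTorch reference of what tin_shift computes, namely a per-sample, per-channel-group integer shift along T with zero padding (for understanding only; it is far slower than the CUDA op):
import torch

def tin_shift_reference(data, shift):
    # data: [N, T, C, H*W]; shift: [N, groups] integer offsets, one per channel group
    n, t, c, hw = data.shape
    groups = shift.shape[1]
    group_channel = c // groups
    out = torch.zeros_like(data)          # frames shifted outside [0, T) simply stay zero
    for ni in range(n):
        for g in range(groups):
            s = int(shift[ni, g])
            cs, ce = g * group_channel, (g + 1) * group_channel
            for i in range(t):
                if 0 <= i + s < t:
                    out[ni, i + s, cs:ce] = data[ni, i, cs:ce]
    return out

data = torch.randn(6, 8, 16, 4)           # toy [N, T, C, H*W]
shift = torch.tensor([[-2, 1]] * 6)       # one integer offset per group for each sample
shifted = tin_shift_reference(data, shift)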
The remaining parts are straightforward and similar in spirit to TSM, so no further explanation is given here.
To sum up, my understanding is that, compared with TSM, the temporal offsets here are learnable through OffsetNet, and WeightNet then weights the shifted features so that more suitable offsets receive higher weights. One remaining question is why the offset range is limited to [-2, 2]. My interpretation is that this already enlarges the temporal receptive field relative to TSM, while a larger range would push much of the information beyond T where it can no longer be retrieved, so the range is a trade-off; other viewpoints are welcome.