A good explanation of LSTM can be found here: https://zhuanlan.zhihu.com/p/32085405
See also: https://blog.csdn.net/weixin_42769131/article/details/104728842
import torch
import torch.nn as nn


class ConvLSTMCell(nn.Module):
    """
    Generate a convolutional LSTM cell.
    """
    def __init__(self, input_size, hidden_size):
        super(ConvLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # one convolution produces all four gates at once (4 * hidden_size output channels)
        self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size,
                               kernel_size=3, stride=1, padding=1)

    def forward(self, input_, prev_state):
        # get batch and spatial sizes
        batch_size = input_.size(0)
        spatial_size = input_.size()[2:]

        # generate an empty prev_state if None is provided
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            prev_state = (
                torch.zeros(state_size, device=input_.device),
                torch.zeros(state_size, device=input_.device),
            )
        prev_hidden, prev_cell = prev_state

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat((input_, prev_hidden), 1)
        gates = self.Gates(stacked_inputs)

        # chunk across the channel dimension
        in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)

        # cell_gate: candidate values computed from the current input x_t and the previous
        #   hidden state, squashed to [-1, 1] by tanh; this is input data, not a gate signal
        # forget stage: remember_gate (forget gate) controls how much of the previous cell
        #   state is kept
        # selective memory: in_gate (input gate) decides how much of the current input
        #   information x_t is written into the cell
        # output stage: out_gate decides which parts of the cell state become the current output

        # apply sigmoid non-linearity to the gate signals
        in_gate = torch.sigmoid(in_gate)
        remember_gate = torch.sigmoid(remember_gate)
        out_gate = torch.sigmoid(out_gate)

        # apply tanh non-linearity: values in [-1, 1], used as input data rather than a gate signal
        cell_gate = torch.tanh(cell_gate)

        # compute the current cell and hidden state
        cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
        hidden = out_gate * torch.tanh(cell)

        return hidden, cell
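A minimal usage sketch for the cell above, assuming 3-channel input frames, a 16-channel hidden state, and a dummy 5-step sequence (all of these shapes are illustrative, not from the original):

    import torch

    conv_lstm = ConvLSTMCell(input_size=3, hidden_size=16)   # illustrative channel counts
    seq = torch.randn(5, 2, 3, 32, 32)                       # [time, batch, channel, height, width], dummy data

    state = None
    for t in range(seq.size(0)):
        hidden, cell_state = conv_lstm(seq[t], state)         # each is [2, 16, 32, 32]
        state = (hidden, cell_state)

Because the gate convolution uses kernel_size=3 with padding=1 and stride=1, the hidden and cell states keep the same spatial size as the input, so the state can be fed back in at every time step without resizing.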