PyTorch實現(xiàn)卷積LSTM核(ConvLSTMCell)

簡單RNN與LSTM對比
LSTM計算示意
LSTM計算示意
import torch
from torch import nn
import torch.nn.functional as f
from torch.autograd import Variable


# Define some constants
KERNEL_SIZE = 3
PADDING = KERNEL_SIZE // 2


class ConvLSTMCell(nn.Module):
    """
    Generate a convolutional LSTM cell
    """

    def __init__(self, input_size, hidden_size):
        super(ConvLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, KERNEL_SIZE, padding=PADDING)

    def forward(self, input_, prev_state):

        # get batch and spatial sizes
        batch_size = input_.data.size()[0]
        spatial_size = input_.data.size()[2:]

        # generate empty prev_state, if None is provided
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            prev_state = (
                Variable(torch.zeros(state_size)),
                Variable(torch.zeros(state_size))
            )

        prev_hidden, prev_cell = prev_state

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat((input_, prev_hidden), 1)
        gates = self.Gates(stacked_inputs)

        # chunk across channel dimension
        in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)

        # apply sigmoid non linearity
        in_gate = f.sigmoid(in_gate)
        remember_gate = f.sigmoid(remember_gate)
        out_gate = f.sigmoid(out_gate)

        # apply tanh non linearity
        cell_gate = f.tanh(cell_gate)

        # compute current cell and hidden state
        cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
        hidden = out_gate * f.tanh(cell)

        return hidden, cell


def _main():
    """
    Run some basic tests on the API
    """

    # define batch_size, channels, height, width
    b, c, h, w = 1, 3, 4, 8
    d = 5           # hidden state size
    lr = 1e-1       # learning rate
    T = 6           # sequence length
    max_epoch = 20  # number of epochs

    # set manual seed
    torch.manual_seed(0)

    print('Instantiate model')
    model = ConvLSTMCell(c, d)
    print(repr(model))

    print('Create input and target Variables')
    x = Variable(torch.rand(T, b, c, h, w))
    y = Variable(torch.randn(T, b, d, h, w))

    print('Create a MSE criterion')
    loss_fn = nn.MSELoss()

    print('Run for', max_epoch, 'iterations')
    for epoch in range(0, max_epoch):
        state = None
        loss = 0
        for t in range(0, T):
            state = model(x[t], state)
            loss += loss_fn(state[0], y[t])

        print(' > Epoch {:2d} loss: {:.3f}'.format((epoch+1), loss.data[0]))

        # zero grad parameters
        model.zero_grad()

        # compute new grad parameters through time!
        loss.backward()

        # learning_rate step against the gradient
        for p in model.parameters():
            p.data.sub_(p.grad.data * lr)

    print('Input size:', list(x.data.size()))
    print('Target size:', list(y.data.size()))
    print('Last hidden state size:', list(state[0].size()))


if __name__ == '__main__':
    _main()
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
  • 序言:七十年代末竟宋,一起剝皮案震驚了整個濱河市提完,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌丘侠,老刑警劉巖徒欣,帶你破解...
    沈念sama閱讀 222,104評論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異蜗字,居然都是意外死亡打肝,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,816評論 3 399
  • 文/潘曉璐 我一進店門挪捕,熙熙樓的掌柜王于貴愁眉苦臉地迎上來粗梭,“玉大人,你說我怎么就攤上這事担神÷コ裕” “怎么了?”我有些...
    開封第一講書人閱讀 168,697評論 0 360
  • 文/不壞的土叔 我叫張陵,是天一觀的道長孩锡。 經(jīng)常有香客問我酷宵,道長,這世上最難降的妖魔是什么躬窜? 我笑而不...
    開封第一講書人閱讀 59,836評論 1 298
  • 正文 為了忘掉前任浇垦,我火速辦了婚禮,結果婚禮上荣挨,老公的妹妹穿的比我還像新娘男韧。我一直安慰自己,他們只是感情好默垄,可當我...
    茶點故事閱讀 68,851評論 6 397
  • 文/花漫 我一把揭開白布此虑。 她就那樣靜靜地躺著,像睡著了一般口锭。 火紅的嫁衣襯著肌膚如雪朦前。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 52,441評論 1 310
  • 那天鹃操,我揣著相機與錄音韭寸,去河邊找鬼。 笑死荆隘,一個胖子當著我的面吹牛恩伺,可吹牛的內容都是我干的。 我是一名探鬼主播椰拒,決...
    沈念sama閱讀 40,992評論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼晶渠,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了耸三?” 一聲冷哼從身側響起乱陡,我...
    開封第一講書人閱讀 39,899評論 0 276
  • 序言:老撾萬榮一對情侶失蹤浇揩,失蹤者是張志新(化名)和其女友劉穎仪壮,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體胳徽,經(jīng)...
    沈念sama閱讀 46,457評論 1 318
  • 正文 獨居荒郊野嶺守林人離奇死亡积锅,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 38,529評論 3 341
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了养盗。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片缚陷。...
    茶點故事閱讀 40,664評論 1 352
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖往核,靈堂內的尸體忽然破棺而出箫爷,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 36,346評論 5 350
  • 正文 年R本政府宣布虎锚,位于F島的核電站硫痰,受9級特大地震影響,放射性物質發(fā)生泄漏窜护。R本人自食惡果不足惜效斑,卻給世界環(huán)境...
    茶點故事閱讀 42,025評論 3 334
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望柱徙。 院中可真熱鬧缓屠,春花似錦、人聲如沸护侮。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,511評論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽羊初。三九已至蠢挡,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間凳忙,已是汗流浹背业踏。 一陣腳步聲響...
    開封第一講書人閱讀 33,611評論 1 272
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留涧卵,地道東北人勤家。 一個月前我還...
    沈念sama閱讀 49,081評論 3 377
  • 正文 我出身青樓,卻偏偏與公主長得像柳恐,于是被迫代替她去往敵國和親伐脖。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 45,675評論 2 359

推薦閱讀更多精彩內容