Introduction
Long Short-Term Memory (LSTM) is a kind of recurrent neural network (RNN), first published in 1997. Thanks to its gating structure, LSTM is well suited to processing and predicting time series in which important events are separated by long and variable gaps.
Its structure is very similar to an RNN's; it simply replaces the single activation function with a more complex block. The data handling and much of the code are shared with the earlier post 《RNN(循環神經網絡)訓練手寫數字》, so this article keeps those parts brief.
Formulas
The LSTM architecture comes in many variants, but they are broadly similar: most are built around an input gate, an output gate, and a forget gate.
This article implements a popular related structure, the GRU (Gated Recurrent Unit). Its update equations are as follows:
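In the notation used by the code comments below (σ is the sigmoid function, ⊙ is element-wise multiplication), the standard GRU update equations are:

    r(t) = σ(rW·x(t) + rU·h(t-1) + rBias)                (reset gate)
    z(t) = σ(zW·x(t) + zU·h(t-1) + zBias)                (update gate)
    hb(t) = tanh(hW·x(t) + hU·(r(t) ⊙ h(t-1)) + hBias)   (candidate state)
    h(t) = z(t) ⊙ h(t-1) + (1 - z(t)) ⊙ hb(t)            (new hidden state)
    output(t) = outW·h(t) + outBias                      (per-step output)

The backward pass relies on the matching derivatives σ'(x) = f(x)·(1 - f(x)) and tanh'(x) = 1 - f(x)², evaluated on the stored activations.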
Implementation
Compared with a simple RNN, an LSTM has more parameters to train and a more complex cell structure. In this implementation I also back-propagate the error to the input, which makes it easy to build multi-layer LSTM networks or to combine the layer with RNN/CNN networks (see the usage sketch after the listing).
The main code is as follows:
//
//  MLLstm.m
//  LSTM
//
//  Created by Jiao Liu on 11/12/16.
//  Copyright © 2016 ChangHong. All rights reserved.
//
#import "MLLstm.h"
@implementation MLLstm
#pragma mark - Inner Method
+ (double)truncated_normal:(double)mean dev:(double)stddev
{
    // Marsaglia polar method; resample until the value lies within 2 standard deviations.
    double outP = 0.0;
    do {
        static int hasSpare = 0;
        static double spare;
        if (hasSpare) {
            hasSpare = 0;
            outP = mean + stddev * spare;
            continue;
        }
        hasSpare = 1;
        static double u, v, s;
        do {
            u = (rand() / ((double)RAND_MAX)) * 2.0 - 1.0;
            v = (rand() / ((double)RAND_MAX)) * 2.0 - 1.0;
            s = u * u + v * v;
        } while ((s >= 1.0) || (s == 0.0));
        s = sqrt(-2.0 * log(s) / s);
        spare = v * s;
        outP = mean + stddev * u * s;
    } while (fabs(outP) > 2 * stddev);
    return outP;
}
+ (double *)fillVector:(double)num size:(int)size
{
    double *outP = malloc(sizeof(double) * size);
    vDSP_vfillD(&num, outP, 1, size);
    return outP;
}

+ (double *)weight_init:(int)size
{
    // Weights start from a truncated normal distribution (mean 0, stddev 0.1).
    double *outP = malloc(sizeof(double) * size);
    for (int i = 0; i < size; i++) {
        outP[i] = [MLLstm truncated_normal:0 dev:0.1];
    }
    return outP;
}

+ (double *)bias_init:(int)size
{
    // Biases start at a small positive constant.
    return [MLLstm fillVector:0.1f size:size];
}
+ (double *)tanh:(double *)input size:(int)size
{
    // In-place tanh, clamped for large |x| to avoid overflow in exp().
    for (int i = 0; i < size; i++) {
        double num = input[i];
        if (num > 20) {
            input[i] = 1;
        }
        else if (num < -20) {
            input[i] = -1;
        }
        else {
            input[i] = (exp(num) - exp(-num)) / (exp(num) + exp(-num));
        }
    }
    return input;
}

+ (double *)sigmoid:(double *)input size:(int)size
{
    // In-place logistic sigmoid with the same clamping trick.
    for (int i = 0; i < size; i++) {
        double num = input[i];
        if (num > 20) {
            input[i] = 1;
        }
        else if (num < -20) {
            input[i] = 0;
        }
        else {
            input[i] = exp(num) / (exp(num) + 1);
        }
    }
    return input;
}
#pragma mark - Init

- (id)initWithNodeNum:(int)num layerSize:(int)size dataDim:(int)dim
{
    self = [super init];
    if (self) {
        _nodeNum = num;    // hidden units per time step
        _layerSize = size; // number of time steps
        _dataDim = dim;    // input dimension per time step
        [self setupNet];
    }
    return self;
}

- (id)init
{
    self = [super init];
    if (self) {
        [self setupNet];
    }
    return self;
}
- (void)setupNet
{
    // Per-step state buffers: h (hidden), r (reset gate), z (update gate), hb (candidate h).
    _hState = calloc(_layerSize * _nodeNum, sizeof(double));
    _rState = calloc(_layerSize * _nodeNum, sizeof(double));
    _zState = calloc(_layerSize * _nodeNum, sizeof(double));
    _hbState = calloc(_layerSize * _nodeNum, sizeof(double));
    _output = calloc(_layerSize * _dataDim, sizeof(double));
    _backLoss = calloc(_layerSize * _dataDim, sizeof(double));
    // For each gate, W multiplies the input x(t) and U multiplies the previous hidden state h(t-1).
    _rW = [MLLstm weight_init:_nodeNum * _dataDim];
    _rU = [MLLstm weight_init:_nodeNum * _nodeNum];
    _rBias = [MLLstm bias_init:_nodeNum];
    _zW = [MLLstm weight_init:_nodeNum * _dataDim];
    _zU = [MLLstm weight_init:_nodeNum * _nodeNum];
    _zBias = [MLLstm bias_init:_nodeNum];
    _hW = [MLLstm weight_init:_nodeNum * _dataDim];
    _hU = [MLLstm weight_init:_nodeNum * _nodeNum];
    _hBias = [MLLstm bias_init:_nodeNum];
    _outW = [MLLstm weight_init:_dataDim * _nodeNum];
    _outBias = [MLLstm bias_init:_dataDim];
}
- (double *)forwardPropagation:(double *)input
{
    _input = input;
    // Clear the state buffers from the previous pass.
    double zero = 0;
    vDSP_vfillD(&zero, _output, 1, _layerSize * _dataDim);
    vDSP_vfillD(&zero, _hState, 1, _layerSize * _nodeNum);
    vDSP_vfillD(&zero, _rState, 1, _layerSize * _nodeNum);
    vDSP_vfillD(&zero, _zState, 1, _layerSize * _nodeNum);
    vDSP_vfillD(&zero, _hbState, 1, _layerSize * _nodeNum);
    vDSP_vfillD(&zero, _backLoss, 1, _layerSize * _dataDim);
    double *temp1 = calloc(_nodeNum, sizeof(double));
    double *temp2 = calloc(_nodeNum, sizeof(double));
    double *temp3 = calloc(_nodeNum, sizeof(double));
    double *one = [MLLstm fillVector:1 size:_nodeNum];
    for (int i = 0; i < _layerSize; i++) {
        // r(t) = sigmoid(rW * x(t) + rU * h(t-1) + rBias); at i == 0 there is no h(t-1) term.
        if (i == 0) {
            vDSP_mmulD(_rW, 1, (_input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_vaddD(temp1, 1, _rBias, 1, temp1, 1, _nodeNum);
        }
        else {
            vDSP_mmulD(_rW, 1, (_input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_mmulD(_rU, 1, (_hState + (i-1) * _nodeNum), 1, temp2, 1, _nodeNum, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, temp2, 1, temp1, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, _rBias, 1, temp1, 1, _nodeNum);
        }
        [MLLstm sigmoid:temp1 size:_nodeNum];
        vDSP_vaddD((_rState + i * _nodeNum), 1, temp1, 1, (_rState + i * _nodeNum), 1, _nodeNum);
        // z(t) = sigmoid(zW * x(t) + zU * h(t-1) + zBias)
        if (i == 0) {
            vDSP_mmulD(_zW, 1, (_input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_vaddD(temp1, 1, _zBias, 1, temp1, 1, _nodeNum);
        }
        else {
            vDSP_mmulD(_zW, 1, (_input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_mmulD(_zU, 1, (_hState + (i-1) * _nodeNum), 1, temp2, 1, _nodeNum, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, temp2, 1, temp1, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, _zBias, 1, temp1, 1, _nodeNum);
        }
        [MLLstm sigmoid:temp1 size:_nodeNum];
        vDSP_vaddD((_zState + i * _nodeNum), 1, temp1, 1, (_zState + i * _nodeNum), 1, _nodeNum);
        // hb(t) = tanh(hW * x(t) + hU * (r(t) ⊙ h(t-1)) + hBias)
        if (i == 0) {
            vDSP_mmulD(_hW, 1, (_input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_vaddD(temp1, 1, _hBias, 1, temp1, 1, _nodeNum);
        }
        else {
            vDSP_mmulD(_hW, 1, (_input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_vmulD((_rState + i * _nodeNum), 1, (_hState + (i-1) * _nodeNum), 1, temp2, 1, _nodeNum);
            vDSP_mmulD(_hU, 1, temp2, 1, temp3, 1, _nodeNum, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, temp3, 1, temp1, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, _hBias, 1, temp1, 1, _nodeNum);
        }
        [MLLstm tanh:temp1 size:_nodeNum];
        vDSP_vaddD((_hbState + i * _nodeNum), 1, temp1, 1, (_hbState + i * _nodeNum), 1, _nodeNum);
        // h(t) = z(t) ⊙ h(t-1) + (1 - z(t)) ⊙ hb(t)
        if (i == 0) {
            vDSP_vsubD((_zState + i * _nodeNum), 1, one, 1, temp1, 1, _nodeNum);
            vDSP_vmulD((_hbState + i * _nodeNum), 1, temp1, 1, temp1, 1, _nodeNum);
        }
        else {
            vDSP_vsubD((_zState + i * _nodeNum), 1, one, 1, temp1, 1, _nodeNum);
            vDSP_vmulD((_hbState + i * _nodeNum), 1, temp1, 1, temp1, 1, _nodeNum);
            vDSP_vmulD((_zState + i * _nodeNum), 1, (_hState + (i-1) * _nodeNum), 1, temp2, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, temp2, 1, temp1, 1, _nodeNum);
        }
        vDSP_vaddD((_hState + i * _nodeNum), 1, temp1, 1, (_hState + i * _nodeNum), 1, _nodeNum);
        // output(t) = outW * h(t) + outBias
        vDSP_mmulD(_outW, 1, (_hState + i * _nodeNum), 1, (_output + i * _dataDim), 1, _dataDim, 1, _nodeNum);
        vDSP_vaddD(_outBias, 1, (_output + i * _dataDim), 1, (_output + i * _dataDim), 1, _dataDim);
    }
    free(one);
    free(temp1);
    free(temp2);
    free(temp3);
    return _output;
}
- (double *)backPropagation:(double *)loss
{
    double *flowLoss = calloc(_nodeNum, sizeof(double)); // loss flowing back through h(t-1)
    double *outTW = calloc(_nodeNum * _dataDim, sizeof(double));
    double *outLoss = calloc(_nodeNum, sizeof(double));
    double *outWLoss = calloc(_dataDim * _nodeNum, sizeof(double));
    double *temp1 = calloc(_nodeNum, sizeof(double));
    double *one = [MLLstm fillVector:1 size:_nodeNum];
    double *zLoss = calloc(_nodeNum, sizeof(double));
    double *hbLoss = calloc(_nodeNum, sizeof(double));
    double *inWLoss = calloc(_nodeNum * _dataDim, sizeof(double));
    double *rLoss = calloc(_nodeNum, sizeof(double));
    double *tU = calloc(_nodeNum * _nodeNum, sizeof(double));
    double *uLoss = calloc(_nodeNum * _nodeNum, sizeof(double));
    double *tW = calloc(_dataDim * _nodeNum, sizeof(double));
    double *temp2 = calloc(_dataDim, sizeof(double));
    for (int i = _layerSize - 1; i >= 0; i--) {
        // update output parameters
        vDSP_vaddD(_outBias, 1, (loss + i * _dataDim), 1, _outBias, 1, _dataDim);
        vDSP_mtransD(_outW, 1, outTW, 1, _nodeNum, _dataDim);
        vDSP_mmulD(outTW, 1, (loss + i * _dataDim), 1, outLoss, 1, _nodeNum, 1, _dataDim);
        vDSP_mmulD((loss + i * _dataDim), 1, (_hState + i * _nodeNum), 1, outWLoss, 1, _dataDim, _nodeNum, 1);
        vDSP_vaddD(_outW, 1, outWLoss, 1, _outW, 1, _dataDim * _nodeNum);
        // h(t) back loss: add the loss flowing in from the later time step
        if (i != _layerSize - 1) {
            vDSP_vaddD(outLoss, 1, flowLoss, 1, outLoss, 1, _nodeNum);
        }
        if (i > 0) {
            vDSP_vsubD((_hState + (i-1) * _nodeNum), 1, (_hbState + i * _nodeNum), 1, temp1, 1, _nodeNum);
            vDSP_vmulD(outLoss, 1, temp1, 1, zLoss, 1, _nodeNum);
            vDSP_vsubD((_zState + i * _nodeNum), 1, one, 1, temp1, 1, _nodeNum);
            vDSP_vmulD(outLoss, 1, temp1, 1, flowLoss, 1, _nodeNum);
        }
        else {
            vDSP_vmulD(outLoss, 1, (_hbState + i * _nodeNum), 1, zLoss, 1, _nodeNum);
        }
        // sigmoid' = f(x) * (1 - f(x))
        vDSP_vsubD((_zState + i * _nodeNum), 1, one, 1, temp1, 1, _nodeNum);
        vDSP_vmulD(temp1, 1, (_zState + i * _nodeNum), 1, temp1, 1, _nodeNum);
        vDSP_vmulD(temp1, 1, zLoss, 1, zLoss, 1, _nodeNum);
        vDSP_vmulD(outLoss, 1, (_zState + i * _nodeNum), 1, hbLoss, 1, _nodeNum);
        // tanh' = 1 - f(x)^2
        vDSP_vsqD((_hbState + i * _nodeNum), 1, temp1, 1, _nodeNum);
        vDSP_vsubD(temp1, 1, one, 1, temp1, 1, _nodeNum);
        vDSP_vmulD(hbLoss, 1, temp1, 1, hbLoss, 1, _nodeNum);
        // update hb(t) parameters and propagate loss to the input
        vDSP_vaddD(_hBias, 1, hbLoss, 1, _hBias, 1, _nodeNum);
        vDSP_mtransD(_hW, 1, tW, 1, _dataDim, _nodeNum);
        vDSP_mmulD(tW, 1, hbLoss, 1, temp2, 1, _dataDim, 1, _nodeNum);
        vDSP_vaddD((_backLoss + i * _dataDim), 1, temp2, 1, (_backLoss + i * _dataDim), 1, _dataDim);
        vDSP_mmulD(hbLoss, 1, (_input + i * _dataDim), 1, inWLoss, 1, _nodeNum, _dataDim, 1);
        vDSP_vaddD(_hW, 1, inWLoss, 1, _hW, 1, _nodeNum * _dataDim);
        if (i > 0) {
            vDSP_mtransD(_hU, 1, tU, 1, _nodeNum, _nodeNum);
            vDSP_mmulD(tU, 1, hbLoss, 1, rLoss, 1, _nodeNum, 1, _nodeNum);
            vDSP_vmulD(rLoss, 1, (_hState + (i-1) * _nodeNum), 1, rLoss, 1, _nodeNum);
            vDSP_vsubD((_rState + i * _nodeNum), 1, one, 1, temp1, 1, _nodeNum);
            vDSP_vmulD(temp1, 1, (_rState + i * _nodeNum), 1, temp1, 1, _nodeNum);
            vDSP_vmulD(temp1, 1, rLoss, 1, rLoss, 1, _nodeNum);
            vDSP_mmulD(tU, 1, hbLoss, 1, temp1, 1, _nodeNum, 1, _nodeNum);
            vDSP_vmulD(temp1, 1, (_rState + i * _nodeNum), 1, temp1, 1, _nodeNum);
            vDSP_vaddD(flowLoss, 1, temp1, 1, flowLoss, 1, _nodeNum);
            vDSP_vmulD((_rState + i * _nodeNum), 1, (_hState + (i-1) * _nodeNum), 1, temp1, 1, _nodeNum);
            vDSP_mmulD(hbLoss, 1, temp1, 1, uLoss, 1, _nodeNum, _nodeNum, 1);
            vDSP_vaddD(_hU, 1, uLoss, 1, _hU, 1, _nodeNum * _nodeNum);
        }
        // update z(t) parameters
        vDSP_vaddD(_zBias, 1, zLoss, 1, _zBias, 1, _nodeNum);
        vDSP_mtransD(_zW, 1, tW, 1, _dataDim, _nodeNum);
        vDSP_mmulD(tW, 1, zLoss, 1, temp2, 1, _dataDim, 1, _nodeNum);
        vDSP_vaddD((_backLoss + i * _dataDim), 1, temp2, 1, (_backLoss + i * _dataDim), 1, _dataDim);
        vDSP_mmulD(zLoss, 1, (_input + i * _dataDim), 1, inWLoss, 1, _nodeNum, _dataDim, 1);
        vDSP_vaddD(_zW, 1, inWLoss, 1, _zW, 1, _nodeNum * _dataDim);
        if (i > 0) {
            vDSP_mtransD(_zU, 1, tU, 1, _nodeNum, _nodeNum);
            vDSP_mmulD(tU, 1, zLoss, 1, temp1, 1, _nodeNum, 1, _nodeNum);
            vDSP_vaddD(flowLoss, 1, temp1, 1, flowLoss, 1, _nodeNum);
            vDSP_mmulD(zLoss, 1, (_hState + (i-1) * _nodeNum), 1, uLoss, 1, _nodeNum, _nodeNum, 1);
            vDSP_vaddD(_zU, 1, uLoss, 1, _zU, 1, _nodeNum * _nodeNum);
        }
        // update r(t) parameters (the reset gate has no effect at i == 0, so skip it there)
        if (i > 0) {
            vDSP_vaddD(_rBias, 1, rLoss, 1, _rBias, 1, _nodeNum);
            vDSP_mtransD(_rW, 1, tW, 1, _dataDim, _nodeNum);
            vDSP_mmulD(tW, 1, rLoss, 1, temp2, 1, _dataDim, 1, _nodeNum);
            vDSP_vaddD((_backLoss + i * _dataDim), 1, temp2, 1, (_backLoss + i * _dataDim), 1, _dataDim);
            vDSP_mmulD(rLoss, 1, (_input + i * _dataDim), 1, inWLoss, 1, _nodeNum, _dataDim, 1);
            vDSP_vaddD(_rW, 1, inWLoss, 1, _rW, 1, _nodeNum * _dataDim);
            vDSP_mtransD(_rU, 1, tU, 1, _nodeNum, _nodeNum);
            vDSP_mmulD(tU, 1, rLoss, 1, temp1, 1, _nodeNum, 1, _nodeNum);
            vDSP_vaddD(flowLoss, 1, temp1, 1, flowLoss, 1, _nodeNum);
            vDSP_mmulD(rLoss, 1, (_hState + (i-1) * _nodeNum), 1, uLoss, 1, _nodeNum, _nodeNum, 1);
            vDSP_vaddD(_rU, 1, uLoss, 1, _rU, 1, _nodeNum * _nodeNum);
        }
    }
    free(flowLoss);
    free(outTW);
    free(outLoss);
    free(outWLoss);
    free(temp1);
    free(one);
    free(zLoss);
    free(hbLoss);
    free(inWLoss);
    free(rLoss);
    free(tU);
    free(uLoss);
    free(tW);
    free(temp2);
    return _backLoss;
}
@end
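To make the stacking point above concrete, here is a minimal usage sketch. It assumes only the methods shown in the listing (initWithNodeNum:layerSize:dataDim:, forwardPropagation:, backPropagation:); the loss-buffer computation is a hypothetical placeholder, since the loss function and learning-rate handling are not part of this post.

#import "MLLstm.h"

// A minimal sketch (not from the original project): one MNIST image is treated as
// 28 time steps of 28 pixels each; the node count matches the experiment below.
static const int kSteps = 28;   // time steps (image rows)
static const int kDim   = 28;   // input dimension per step (pixels per row)
static const int kNodes = 500;  // hidden nodes

// Example setup:
// MLLstm *lstm = [[MLLstm alloc] initWithNodeNum:kNodes layerSize:kSteps dataDim:kDim];

void singleLayerStep(MLLstm *lstm, double *image, double *lossBuf)
{
    // Forward pass over all time steps; returns kSteps * kDim output values.
    double *output = [lstm forwardPropagation:image];
    // ... fill lossBuf (dLoss/dOutput, already scaled by the learning rate)
    //     from `output` and the label; omitted here ...
    (void)output;
    // Backward pass; updates the weights and returns dLoss/dInput.
    [lstm backPropagation:lossBuf];
}

// Stacking two layers: the lower layer's output feeds the upper layer's input, and
// the upper layer's back-propagated input loss (_backLoss) becomes the lower layer's
// output loss. Both layers must share the same dataDim, since in this implementation
// the per-step output dimension equals the input dimension.
void stackedStep(MLLstm *lower, MLLstm *upper, double *image, double *lossBuf)
{
    double *mid = [lower forwardPropagation:image];
    [upper forwardPropagation:mid];
    // ... fill lossBuf from the upper layer's output ...
    double *midLoss = [upper backPropagation:lossBuf];
    [lower backPropagation:midLoss];
}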
Conclusion
I again used the MNIST data to test a single-layer LSTM, with 500 hidden nodes, 1300 iterations, and 5 images per iteration, reaching roughly 90% accuracy.
Repeated runs showed that the more hidden nodes, the longer each iteration takes to train and the higher the accuracy. I therefore set the node count to 500 and, to speed things up, cut the images per iteration from the 100 used in the RNN network down to 5; even so, the whole run took more than 3 hours. The results still fall short of what CNN and RNN achieve under similar conditions.
If you are interested, you can find the complete code here.