RNN(循環(huán)神經(jīng)網(wǎng)絡(luò))訓(xùn)練手寫數(shù)字

簡(jiǎn)介


RNN(recurrent neural network )循環(huán)(遞歸)神經(jīng)網(wǎng)絡(luò)主要用來(lái)處理序列數(shù)據(jù)首繁。因?yàn)閭鹘y(tǒng)的神經(jīng)網(wǎng)絡(luò)從輸入-隱含層-輸出是全連接的,層中的神經(jīng)元是沒有連接的翎卓,所以對(duì)于輸入數(shù)據(jù)本身具有時(shí)序性(例如輸入的文本數(shù)據(jù)抵乓,每個(gè)單詞之間有一定聯(lián)系)的處理表現(xiàn)并不理想店归。而RNN每一個(gè)輸出與前面的輸出建立起關(guān)聯(lián)刃麸,這樣就能夠很好的處理序列化的數(shù)據(jù)醒叁。
單純循環(huán)神經(jīng)網(wǎng)絡(luò)也面臨一些問題,如無(wú)法處理隨著遞歸泊业,權(quán)重指數(shù)級(jí)爆炸或消失的問題把沼,難以捕捉長(zhǎng)期時(shí)間關(guān)聯(lián)。這些可以結(jié)合不同的LSTM很好的解決這個(gè)問題吁伺。
本文主要介紹簡(jiǎn)單的RNN用OC的實(shí)現(xiàn)饮睬,并通過訓(xùn)練MNIST數(shù)據(jù)來(lái)檢測(cè)模型。后面有時(shí)間再介紹LSTM的實(shí)現(xiàn)篮奄。

公式


簡(jiǎn)單的RNN就三層捆愁,輸入-隱含層-輸出,如下:

將其展開的模型如下:

其中窟却,A這個(gè)隱含層的操作就是將當(dāng)前輸入與前面的輸出相結(jié)合昼丑,然后激活就得到當(dāng)前狀態(tài)信號(hào)。如下:

計(jì)算公式如下:

其中Xt是輸入數(shù)據(jù)序列间校,St是的狀態(tài)序列矾克,V*St就是圖中Ot輸出页慷,softmax運(yùn)算并沒有畫出來(lái)憔足。

由于RNN結(jié)構(gòu)簡(jiǎn)單,反向傳播的公式結(jié)合一點(diǎn)數(shù)理知識(shí)就可以求得酒繁,這里就不列出滓彰,詳見代碼實(shí)現(xiàn)。

數(shù)據(jù)處理


由于沒找到比較好的訓(xùn)練數(shù)據(jù)州袒,這里用的是前面《OC實(shí)現(xiàn)Softmax識(shí)別手寫數(shù)字》文章里面的MNIST數(shù)據(jù)源揭绑。輸入數(shù)據(jù)處理、softmax實(shí)現(xiàn)也都是復(fù)用的。
圖片數(shù)據(jù)本質(zhì)上并非是序列化的他匪,我這里將圖片的每行的的像素?cái)?shù)據(jù)當(dāng)作一個(gè)信號(hào)輸入菇存,如果一共N行,序列長(zhǎng)度就是N邦蜜。訓(xùn)練數(shù)據(jù)是28*28維的圖片依鸥,那么就是每個(gè)信號(hào)是28*1,一共時(shí)間長(zhǎng)度是28悼沈。

RNN實(shí)現(xiàn)


簡(jiǎn)單的RNN實(shí)現(xiàn)流程并不復(fù)雜贱迟,需要訓(xùn)練的參數(shù)就5個(gè):輸入的權(quán)值、神經(jīng)元間轉(zhuǎn)移的權(quán)值絮供、輸出的權(quán)值衣吠、以及兩個(gè)轉(zhuǎn)移和輸出的偏置量。直接看代碼:

//
//  MLRnn.m
//  LSTM
//
//  Created by Jiao Liu on 11/9/16.
//  Copyright ? 2016 ChangHong. All rights reserved.
//

#import "MLRnn.h"

@implementation MLRnn

#pragma mark - Inner Method

+ (double)truncated_normal:(double)mean dev:(double)stddev
{
    double outP = 0.0;
    do {
        static int hasSpare = 0;
        static double spare;
        if (hasSpare) {
            hasSpare = 0;
            outP = mean + stddev * spare;
            continue;
        }
        
        hasSpare = 1;
        static double u,v,s;
        do {
            u = (rand() / ((double) RAND_MAX)) * 2.0 - 1.0;
            v = (rand() / ((double) RAND_MAX)) * 2.0 - 1.0;
            s = u * u + v * v;
        } while ((s >= 1.0) || (s == 0.0));
        s = sqrt(-2.0 * log(s) / s);
        spare = v * s;
        outP = mean + stddev * u * s;
    } while (fabsl(outP) > 2*stddev);
    return outP;
}

+ (double *)fillVector:(double)num size:(int)size
{
    double *outP = malloc(sizeof(double) * size);
    vDSP_vfillD(&num, outP, 1, size);
    return outP;
    
}

+ (double *)weight_init:(int)size
{
    double *outP = malloc(sizeof(double) * size);
    for (int i = 0; i < size; i++) {
        outP[i] = [MLRnn truncated_normal:0 dev:0.1];
    }
    return outP;
}

+ (double *)bias_init:(int)size
{
    return [MLRnn fillVector:0.1f size:size];
}

+ (double *)tanh:(double *)input size:(int)size
{
    for (int i = 0; i < size; i++) {
        double num = input[i];
        if (num > 20) {
            input[i] = 1;
        }
        else if (num < -20)
        {
            input[i] = -1;
        }
        else
        {
            input[i] = (exp(num) - exp(-num)) / (exp(num) + exp(-num));
        }
    }
    return input;
}

#pragma mark - Init

- (id)initWithNodeNum:(int)num layerSize:(int)size dataDim:(int)dim
{
    self = [super init];
    if (self) {
        _nodeNum = num;
        _layerSize = size;
        _dataDim = dim;
        [self setupNet];
    }
    return self;
}

- (id)init
{
    self = [super init];
    if (self) {
        [self setupNet];
    }
    return self;
}

- (void)setupNet
{
    _inWeight = [MLRnn weight_init:_nodeNum * _dataDim];
    _outWeight = [MLRnn weight_init:_nodeNum * _dataDim];
    _flowWeight = [MLRnn weight_init:_nodeNum * _nodeNum];
    _outBias = calloc(_dataDim, sizeof(double));
    _flowBias = calloc(_nodeNum, sizeof(double));
    _output = calloc(_layerSize * _dataDim, sizeof(double));
    _state = calloc(_layerSize * _nodeNum, sizeof(double));
}

#pragma mark - Main Method

- (double *)forwardPropagation:(double *)input
{
    _input = input;
    // clean data
    double zero = 0;
    vDSP_vfillD(&zero, _output, 1, _layerSize * _dataDim);
    vDSP_vfillD(&zero, _state, 1, _layerSize * _nodeNum);
    
    for (int i = 0; i < _layerSize; i++) {
        double *temp1 = calloc(_nodeNum, sizeof(double));
        double *temp2 = calloc(_nodeNum, sizeof(double));
        if (i == 0) {
            vDSP_mmulD(_inWeight, 1, (input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_vaddD(temp1, 1,_flowBias, 1, temp1, 1, _nodeNum);
        }
        else
        {
            vDSP_mmulD(_inWeight, 1, (input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_mmulD(_flowWeight, 1, (_state + (i-1) * _nodeNum), 1, temp2, 1, _nodeNum, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, temp2, 1, temp1, 1, _nodeNum);
            vDSP_vaddD(temp1, 1,_flowBias, 1, temp1, 1, _nodeNum);
        }
        [MLRnn tanh:temp1 size:_nodeNum];
        vDSP_vaddD((_state + i * _nodeNum), 1, temp1, 1, (_state + i * _nodeNum), 1, _nodeNum);
        vDSP_mmulD(_outWeight, 1, temp1, 1, (_output + i * _dataDim), 1, _dataDim, 1, _nodeNum);
        vDSP_vaddD((_output + i * _dataDim), 1, _outBias, 1,  (_output + i * _dataDim), 1, _dataDim);
        
        free(temp1);
        free(temp2);
    }
    
    return _output;
}

- (void)backPropagation:(double *)loss
{
    double *flowLoss = calloc(_nodeNum, sizeof(double));
    for (int i = _layerSize - 1; i >= 0 ; i--) {
        vDSP_vaddD(_outBias, 1, (loss + i * _dataDim), 1, _outBias, 1, _dataDim);
        double *transWeight = calloc(_nodeNum * _dataDim, sizeof(double));
        vDSP_mtransD(_outWeight, 1, transWeight, 1, _nodeNum, _dataDim);
        double *tanhLoss = calloc(_nodeNum, sizeof(double));
        vDSP_mmulD(transWeight, 1, (loss + i * _dataDim), 1, tanhLoss, 1, _nodeNum, 1, _dataDim);
        double *outWeightLoss = calloc(_nodeNum * _dataDim, sizeof(double));
        vDSP_mmulD((loss + i * _dataDim), 1, (_state + i * _nodeNum), 1, outWeightLoss, 1, _dataDim, _nodeNum, 1);
        vDSP_vaddD(_outWeight, 1, outWeightLoss, 1, _outWeight, 1, _nodeNum * _dataDim);
        
        double *tanhIn = calloc(_nodeNum, sizeof(double));
        vDSP_vsqD((_state + i * _nodeNum), 1, tanhIn, 1, _nodeNum);
        double *one = [MLRnn fillVector:1 size:_nodeNum];
        vDSP_vsubD(tanhIn, 1, one, 1, tanhIn, 1, _nodeNum);
        if (i != _layerSize - 1) {
            vDSP_vaddD(tanhLoss, 1, flowLoss, 1, tanhLoss, 1, _nodeNum);
        }
        vDSP_vmulD(tanhLoss, 1, tanhIn, 1, tanhLoss, 1, _nodeNum);
        
        vDSP_vaddD(_flowBias, 1, tanhLoss, 1, _flowBias, 1, _nodeNum);
        if (i != 0) {
            double *transFlow = calloc(_nodeNum * _nodeNum, sizeof(double));
            vDSP_mtransD(_flowWeight, 1, transFlow, 1, _nodeNum, _nodeNum);
            vDSP_mmulD(transFlow, 1, tanhLoss, 1, flowLoss, 1, _nodeNum, 1, _nodeNum);
            free(transFlow);
            double *flowWeightLoss = calloc(_nodeNum * _nodeNum, sizeof(double));
            vDSP_mmulD(tanhLoss, 1, (_state + (i-1) * _nodeNum), 1, flowWeightLoss, 1, _nodeNum, _nodeNum, 1);
            vDSP_vaddD(_flowWeight, 1, flowWeightLoss, 1, _flowWeight, 1, _nodeNum * _nodeNum);
            free(flowWeightLoss);
        }

        double *inWeightLoss = calloc(_nodeNum * _dataDim, sizeof(double));
        vDSP_mmulD(tanhLoss, 1, (_input + i * _dataDim), 1, inWeightLoss, 1, _nodeNum, _dataDim, 1);
        vDSP_vaddD(_inWeight, 1, inWeightLoss, 1, _inWeight, 1, _nodeNum * _dataDim);
        
        free(transWeight);
        free(tanhLoss);
        free(outWeightLoss);
        free(tanhIn);
        free(one);
        free(inWeightLoss);
    }
    free(flowLoss);
    free(loss);
}

@end

很多初始化方法以及內(nèi)部函數(shù)直接是復(fù)用《OC實(shí)現(xiàn)(CNN)卷積神經(jīng)網(wǎng)絡(luò)》中相關(guān)的方法壤靶。

結(jié)語(yǔ)


我這里使用RNN缚俏,迭代2500次,每次訓(xùn)練100張圖片贮乳,單個(gè)神經(jīng)元節(jié)點(diǎn)個(gè)數(shù)選擇50袍榆,得到的正確率94%左右。

有興趣的朋友可以點(diǎn)這里看完整代碼塘揣。

本文參考:

  1. Understanding LSTM Networks
  2. recurrent-neural-networks-tutorial
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末包雀,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子亲铡,更是在濱河造成了極大的恐慌才写,老刑警劉巖,帶你破解...
    沈念sama閱讀 207,248評(píng)論 6 481
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件奖蔓,死亡現(xiàn)場(chǎng)離奇詭異赞草,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)吆鹤,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,681評(píng)論 2 381
  • 文/潘曉璐 我一進(jìn)店門厨疙,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人疑务,你說我怎么就攤上這事沾凄。” “怎么了知允?”我有些...
    開封第一講書人閱讀 153,443評(píng)論 0 344
  • 文/不壞的土叔 我叫張陵撒蟀,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我温鸽,道長(zhǎng)保屯,這世上最難降的妖魔是什么手负? 我笑而不...
    開封第一講書人閱讀 55,475評(píng)論 1 279
  • 正文 為了忘掉前任,我火速辦了婚禮姑尺,結(jié)果婚禮上竟终,老公的妹妹穿的比我還像新娘。我一直安慰自己切蟋,他們只是感情好衡楞,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,458評(píng)論 5 374
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著敦姻,像睡著了一般瘾境。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上镰惦,一...
    開封第一講書人閱讀 49,185評(píng)論 1 284
  • 那天迷守,我揣著相機(jī)與錄音,去河邊找鬼旺入。 笑死兑凿,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的茵瘾。 我是一名探鬼主播礼华,決...
    沈念sama閱讀 38,451評(píng)論 3 401
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼拗秘!你這毒婦竟也來(lái)了圣絮?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 37,112評(píng)論 0 261
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤雕旨,失蹤者是張志新(化名)和其女友劉穎扮匠,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體凡涩,經(jīng)...
    沈念sama閱讀 43,609評(píng)論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡棒搜,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,083評(píng)論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了活箕。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片力麸。...
    茶點(diǎn)故事閱讀 38,163評(píng)論 1 334
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖育韩,靈堂內(nèi)的尸體忽然破棺而出克蚂,到底是詐尸還是另有隱情,我是刑警寧澤座慰,帶...
    沈念sama閱讀 33,803評(píng)論 4 323
  • 正文 年R本政府宣布陨舱,位于F島的核電站,受9級(jí)特大地震影響版仔,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,357評(píng)論 3 307
  • 文/蒙蒙 一蛮粮、第九天 我趴在偏房一處隱蔽的房頂上張望益缎。 院中可真熱鬧,春花似錦然想、人聲如沸莺奔。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,357評(píng)論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)令哟。三九已至,卻和暖如春妨蛹,著一層夾襖步出監(jiān)牢的瞬間屏富,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,590評(píng)論 1 261
  • 我被黑心中介騙來(lái)泰國(guó)打工蛙卤, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留狠半,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 45,636評(píng)論 2 355
  • 正文 我出身青樓颤难,卻偏偏與公主長(zhǎng)得像神年,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子行嗤,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,925評(píng)論 2 344

推薦閱讀更多精彩內(nèi)容