Deep Residual Shrinkage Networks: Algorithm Principles and a TFLearn Implementation

The deep residual shrinkage network is a new neural network architecture. It is essentially an upgraded version of the deep residual network, and it improves, to some extent, the feature-learning performance of deep learning methods on noisy data.

First, a brief review of the deep residual network, whose basic block is shown in the figure below. Compared with a conventional convolutional neural network, the deep residual network uses identity shortcuts that skip several layers, which eases training and improves accuracy.


Figure: basic block of the deep residual network
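
In code terms, the identity shortcut simply adds the block input to the output of the stacked layers. A minimal sketch in plain Python/NumPy (illustrative only, not the author's code; residual_branch here stands in for the BN-ReLU-conv stack):

import numpy as np

def residual_block_sketch(x, residual_branch):
    # Identity shortcut: block output = residual branch output + unmodified input
    return residual_branch(x) + x

x = np.array([1.0, 2.0, 3.0])
print(residual_block_sketch(x, lambda v: 0.1 * v))  # [1.1 2.2 3.3]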

Unlike the deep residual network, the deep residual shrinkage network introduces a small subnetwork that learns a set of thresholds and applies soft thresholding to each channel of the feature map. This process can be viewed as a trainable feature-selection step. Specifically, the preceding convolutional layers transform important features into values with large absolute values, and features corresponding to redundant information into values with small absolute values; the subnetwork learns the boundary between the two, and soft thresholding then sets the redundant features to zero while letting the important features keep non-zero outputs.


Figure: basic block of the deep residual shrinkage network
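
Soft thresholding itself is the simple elementwise operation y = sign(x) * max(|x| - tau, 0): entries whose magnitude falls below the threshold tau are zeroed, and the rest are shrunk toward zero by tau. A minimal NumPy sketch (illustrative, not part of the original scripts):

import numpy as np

def soft_threshold(x, tau):
    # Zero out entries with |x| <= tau; shrink the rest toward zero by tau
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

x = np.array([-1.5, -0.2, 0.1, 0.8])
print(soft_threshold(x, 0.3))  # [-1.2 -0.   0.   0.5]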

The deep residual shrinkage network is in fact a general-purpose method: it can be applied not only to noisy data but also to noise-free data. This is because its thresholds are determined adaptively from the samples themselves. In other words, if the samples contain no redundant information and soft thresholding is unnecessary, the thresholds can be trained to be very close to zero, so the soft thresholding effectively disappears.
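
To make this adaptivity concrete: in the block implemented below, each channel's threshold is computed as (mean of |features|) * sigmoid(learned score), so it always lies between 0 and that channel's mean absolute activation, and a strongly negative learned score drives it toward zero. A toy NumPy sketch, where the scores are hypothetical stand-ins for the output of the small fully connected subnetwork:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

fmap = np.random.randn(4, 4, 3)            # toy feature map, shape (H, W, C)
scores = np.array([-5.0, 0.0, 5.0])        # hypothetical learned per-channel scores
abs_mean = np.abs(fmap).mean(axis=(0, 1))  # per-channel mean |activation|
tau = abs_mean * sigmoid(scores)           # each threshold lies in (0, abs_mean)
print(tau)  # channel 0: near 0, thresholding nearly vanishes; channel 2: near abs_mean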

最后吩谦,堆疊一定數(shù)量的基本模塊鸳谜,就得到了完整的網(wǎng)絡(luò)結(jié)構(gòu)。


Figure: overall architecture of the deep residual shrinkage network

Applying the deep residual shrinkage network to MNIST image classification gives quite good results. The code for the deep residual shrinkage network follows:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 26 07:46:00 2019

Implemented using TensorFlow 1.0 and TFLearn 0.3.2
 
M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, 
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

@author: me
"""

import tflearn
import tensorflow as tf
from tflearn.layers.conv import conv_2d

# Data loading
from tflearn.datasets import mnist
X, Y, testX, testY = mnist.load_data(one_hot=True)
X = X.reshape([-1,28,28,1])
testX = testX.reshape([-1,28,28,1])

def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                   downsample_strides=2, activation='relu', batch_norm=True,
                   bias=True, weights_init='variance_scaling',
                   bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                   trainable=True, restore=True, reuse=False, scope=None,
                   name="ResidualBlock"):
    
    # residual shrinkage blocks with channel-wise thresholds

    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)

    with vscope as scope:
        name = scope.name

        for i in range(nb_blocks):

            identity = residual

            if not downsample:
                downsample_strides = 1

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                             downsample_strides, 'same', 'linear',
                             bias, weights_init, bias_init,
                             regularizer, weight_decay, trainable,
                             restore)

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                             'linear', bias, weights_init,
                             bias_init, regularizer, weight_decay,
                             trainable, restore)
            
            # Threshold subnetwork: estimate one threshold per channel
            # abs_mean has shape [N, 1, 1, C]: per-channel mean of |features|
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual), axis=2, keep_dims=True), axis=1, keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear', regularizer='L2', weight_decay=0.0001, weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear', regularizer='L2', weight_decay=0.0001, weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales, axis=1), axis=1)
            # Threshold = abs_mean * sigmoid(scales), so 0 < threshold < abs_mean
            thres = tf.multiply(abs_mean, tflearn.activations.sigmoid(scales))
            # Soft thresholding: sign(x) * max(|x| - threshold, 0)
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres, 0))
            

            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)

            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                in_channels = out_channels

            residual = residual + identity

    return residual


# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Building a Deep Residual Shrinkage Network
net = tflearn.input_data(shape=[None, 28, 28, 1], data_preprocessing=img_prep)  # wire in the preprocessing defined above
net = tflearn.conv_2d(net, 8, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1,  8, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=40000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_mnist',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)

model.fit(X, Y, n_epoch=200, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_mnist')

training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]

Next, for comparison, is the code for the plain deep residual network (ResNet):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 26 07:46:00 2019

Implemented using TensorFlow 1.0 and TFLearn 0.3.2
K. He, X. Zhang, S. Ren, J. Sun, Deep Residual Learning for Image Recognition, CVPR, 2016.

@author: me
"""

import tflearn

# Data loading
from tflearn.datasets import mnist
X, Y, testX, testY = mnist.load_data(one_hot=True)
X = X.reshape([-1,28,28,1])
testX = testX.reshape([-1,28,28,1])

# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Building a deep residual network
net = tflearn.input_data(shape=[None, 28, 28, 1], data_preprocessing=img_prep)  # wire in the preprocessing defined above
net = tflearn.conv_2d(net, 8, 3, regularizer='L2', weight_decay=0.0001)
net = tflearn.residual_block(net, 1,  8, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=40000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_mnist',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)

model.fit(X, Y, n_epoch=200, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_mnist')

training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]

Both programs above build small networks containing only one basic block; no noise was added to the MNIST dataset, and the results vary slightly from run to run. The accuracies are shown in the table below. Even on noise-free data, the deep residual shrinkage network performs well:


Table: experimental results

Reposted from:

https://my.oschina.net/u/4223274/blog/3148949

References:

M. Zhao, S. Zhong, X. Fu, et al., "Deep residual shrinkage networks for fault diagnosis," IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898. Available: https://ieeexplore.ieee.org/document/8850096
