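The deep residual shrinkage network (DRSN) is a new neural network architecture. It is essentially an upgraded version of the deep residual network (ResNet) and can, to some extent, improve how well deep learning methods learn features from noisy data.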
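First, a brief review of the deep residual network, whose basic block is shown in the figure below. Compared with a conventional convolutional neural network, a deep residual network uses identity mappings (shortcut connections) that span several layers; this eases training and improves accuracy.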
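To make the identity shortcut concrete, the following is a minimal pure-Python sketch (not the TFLearn code used later) of a residual unit y = x + F(x). The residual branch F here is a hypothetical single linear map written with plain lists; the only point is that when F's weights are all zero, the unit reduces exactly to the identity mapping, so the input passes through unchanged.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(weights: Matrix, x: Vector) -> Vector:
    """Hypothetical residual branch F(x): a plain matrix-vector product."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def residual_unit(weights: Matrix, x: Vector) -> Vector:
    """y = x + F(x): the identity shortcut adds the input back to the branch output."""
    fx = linear(weights, x)
    return [xi + fi for xi, fi in zip(x, fx)]


if __name__ == "__main__":
    x = [1.0, -2.0, 3.0]
    zero_w = [[0.0] * 3 for _ in range(3)]
    print(residual_unit(zero_w, x))  # [1.0, -2.0, 3.0]  (zero branch: output equals input)
    half_w = [[0.5, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 0.5]]
    print(residual_unit(half_w, x))  # [1.5, -3.0, 4.5]
```

With a non-zero branch the output is the input plus a learned correction, which is the sense in which each block learns a residual.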
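Next, and unlike the deep residual network, the deep residual shrinkage network introduces a small subnetwork that learns a set of thresholds, one per channel, and uses them to apply soft thresholding to each channel of the feature map (soft thresholding maps a value x to sign(x)·max(|x| − τ, 0) for a threshold τ). This can be viewed as a trainable feature-selection process. Specifically, the preceding convolutional layers transform important features into values with large absolute values and features corresponding to redundant information into values with small absolute values; the subnetwork learns the boundary between the two, and soft thresholding then sets the redundant features to zero while the important features keep non-zero outputs.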
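The channel-wise thresholding described above can be sketched in plain Python. It mirrors the formulas in the TFLearn code later in this post (threshold = per-channel mean of |x| multiplied by the sigmoid of a learned scale, followed by sign(x)·max(|x| − τ, 0)), but the learned subnetwork is replaced by a hypothetical, hand-supplied `scale` per channel, and each channel is simplified to a flat list of values.

```python
import math
from typing import List


def soft_threshold(x: float, tau: float) -> float:
    """sign(x) * max(|x| - tau, 0): shrink toward zero and zero out |x| <= tau."""
    if x > tau:
        return x - tau
    if x < -tau:
        return x + tau
    return 0.0


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def channel_threshold(channel: List[float], scale: float) -> float:
    """tau = mean(|x|) over the channel * sigmoid(scale), so 0 <= tau <= mean(|x|)."""
    abs_mean = sum(abs(v) for v in channel) / len(channel)
    return abs_mean * sigmoid(scale)


def shrink_feature_map(fmap: List[List[float]], scales: List[float]) -> List[List[float]]:
    """Channel-wise soft thresholding: each channel gets its own threshold."""
    out = []
    for channel, scale in zip(fmap, scales):
        tau = channel_threshold(channel, scale)
        out.append([soft_threshold(v, tau) for v in channel])
    return out


if __name__ == "__main__":
    fmap = [[4.0, -0.5, 0.5, -3.0],  # channel 0: mean |x| = 2.0
            [0.5, -0.5, 0.5, -0.5]]  # channel 1: mean |x| = 0.5
    scales = [0.0, 0.0]              # hypothetical subnetwork outputs; sigmoid(0) = 0.5
    print(shrink_feature_map(fmap, scales))
    # [[3.0, 0.0, 0.0, -2.0], [0.25, -0.25, 0.25, -0.25]]
```

In channel 0 the two small values fall below τ = 1.0 and are set to zero, while the large values survive (shrunk by τ); because channel 1 has a smaller mean magnitude, it gets a smaller threshold (0.25) and keeps all of its values.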
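The deep residual shrinkage network is in fact a general-purpose method: it can be used not only on noisy data but also on noise-free data. This is because its thresholds are determined adaptively for each sample. In other words, if a sample contains no redundant information and needs no soft thresholding, the thresholds can be trained to be very close to zero, in which case soft thresholding effectively disappears.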
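The claim that soft thresholding "disappears" when the threshold is trained toward zero can be checked numerically. With the threshold formula used in the code, τ = mean(|x|)·sigmoid(z), a strongly negative learned scale z drives τ toward 0, and sign(x)·max(|x| − τ, 0) then differs from x by at most τ. In this plain-Python sketch, `z` is a hand-picked stand-in for the subnetwork output, and `max_deviation` is a hypothetical helper that measures the largest change soft thresholding makes to the input.

```python
import math
from typing import List


def soft_threshold(x: float, tau: float) -> float:
    """sign(x) * max(|x| - tau, 0)."""
    if x > tau:
        return x - tau
    if x < -tau:
        return x + tau
    return 0.0


def max_deviation(values: List[float], z: float) -> float:
    """Largest |soft_threshold(x) - x| when tau = mean(|x|) * sigmoid(z)."""
    abs_mean = sum(abs(v) for v in values) / len(values)
    tau = abs_mean / (1.0 + math.exp(-z))  # mean(|x|) * sigmoid(z)
    return max(abs(soft_threshold(v, tau) - v) for v in values)


if __name__ == "__main__":
    x = [1.0, -2.0, 0.5, 3.0]
    # The more negative z is, the closer soft thresholding is to the identity.
    for z in (0.0, -5.0, -10.0):
        print(z, max_deviation(x, z))
```

At z = 0 the threshold is half the mean magnitude and changes the values noticeably; by z = −10 the largest change is below 10⁻⁴, so the block behaves essentially like a plain residual block.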
Finally, stacking a number of these basic blocks yields the complete network architecture.
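Stacking blocks mostly comes down to tracking how the feature-map shape changes from one block to the next. The helper below is a hypothetical bookkeeping sketch (not part of TFLearn) that follows the shape logic of `residual_shrinkage_block` as written in this post. It assumes 'same' padding, so a stride-2 unit maps a side of length n to ceil(n/2). When `downsample=True`, every one of the `nb_blocks` inner units uses the stride, because the stride is only reset to 1 when `downsample=False`. The identity shortcut is zero-padded on the channel axis, with the extra channel on the right when the difference is odd. The deeper stack in the example is made up for illustration and is not taken from the paper.

```python
import math
from typing import List, Tuple

Shape = Tuple[int, int, int]  # (height, width, channels)


def channel_pad(in_channels: int, out_channels: int) -> Tuple[int, int]:
    """(left, right) zero-padding applied to the identity shortcut's channel axis."""
    diff = out_channels - in_channels
    if diff < 0:
        raise ValueError("a padding shortcut can only increase the channel count")
    ch = diff // 2
    return (ch, ch) if diff % 2 == 0 else (ch, ch + 1)


def block_shape(shape: Shape, nb_blocks: int, out_channels: int,
                downsample: bool, strides: int = 2) -> Shape:
    """Output shape of one residual_shrinkage_block call under the assumptions above."""
    h, w, c = shape
    s = strides if downsample else 1
    for _ in range(nb_blocks):
        h, w = math.ceil(h / s), math.ceil(w / s)  # 'same' padding
        if c != out_channels:
            left, right = channel_pad(c, out_channels)
            c = c + left + right
    return (h, w, c)


def stack_shapes(shape: Shape, config: List[Tuple[int, int, bool]]) -> List[Shape]:
    """Shapes before and after each (nb_blocks, out_channels, downsample) entry."""
    shapes = [shape]
    for nb_blocks, out_channels, downsample in config:
        shape = block_shape(shape, nb_blocks, out_channels, downsample)
        shapes.append(shape)
    return shapes


if __name__ == "__main__":
    # The post's network: 8-filter conv on 28x28x1, then one downsampling block.
    print(stack_shapes((28, 28, 8), [(1, 8, True)]))
    # [(28, 28, 8), (14, 14, 8)]
    # A hypothetical deeper stack, for illustration only:
    print(stack_shapes((28, 28, 8), [(1, 8, True), (1, 16, True), (1, 32, True)]))
    # [(28, 28, 8), (14, 14, 8), (7, 7, 16), (4, 4, 32)]
```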
We applied the deep residual shrinkage network to MNIST image recognition, and the results are fairly good. The code for the deep residual shrinkage network is as follows:
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 26 07:46:00 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2
M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: me
"""
import tflearn
import tensorflow as tf
from tflearn.layers.conv import conv_2d

# Data loading
from tflearn.datasets import mnist
X, Y, testX, testY = mnist.load_data(one_hot=True)
X = X.reshape([-1, 28, 28, 1])
testX = testX.reshape([-1, 28, 28, 1])


def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2, activation='relu', batch_norm=True,
                             bias=True, weights_init='variance_scaling',
                             bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                             trainable=True, restore=True, reuse=False, scope=None,
                             name="ResidualBlock"):
    # Residual shrinkage blocks with channel-wise thresholds
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    # Variable scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)

    with vscope as scope:
        name = scope.name  # TODO

        for i in range(nb_blocks):
            identity = residual

            if not downsample:
                downsample_strides = 1

            # BN -> activation -> 3x3 conv (strided when downsampling)
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                               downsample_strides, 'same', 'linear',
                               bias, weights_init, bias_init,
                               regularizer, weight_decay, trainable,
                               restore)

            # BN -> activation -> 3x3 conv
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                               'linear', bias, weights_init,
                               bias_init, regularizer, weight_decay,
                               trainable, restore)

            # Get thresholds and apply thresholding.
            # abs_mean: per-channel mean of |x| over height and width, shape [N, 1, 1, C]
            # (keep_dims is the TF 1.0 spelling; later 1.x releases use keepdims)
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual), axis=2, keep_dims=True),
                                      axis=1, keep_dims=True)
            # Small subnetwork: FC (C/4) -> BN -> ReLU -> FC (C)
            scales = tflearn.fully_connected(abs_mean, out_channels // 4, activation='linear',
                                             regularizer='L2', weight_decay=0.0001,
                                             weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear',
                                             regularizer='L2', weight_decay=0.0001,
                                             weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales, axis=1), axis=1)
            # Threshold = mean(|x|) * sigmoid(scale): one value per channel,
            # between 0 and mean(|x|)
            thres = tf.multiply(abs_mean, tflearn.activations.sigmoid(scales))
            # Soft thresholding: sign(x) * max(|x| - thres, 0)
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual) - thres, 0))

            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)

            # Projection to new dimension (zero-pad the channel axis)
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels) // 2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels) // 2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch + 1]])
                in_channels = out_channels

            residual = residual + identity

    return residual


# Real-time data preprocessing
# Note: img_prep is never passed to input_data below, so as written it has no effect.
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Building a deep residual shrinkage network
net = tflearn.input_data(shape=[None, 28, 28, 1])
net = tflearn.conv_2d(net, 8, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)

# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=40000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')

# Training
model = tflearn.DNN(net, checkpoint_path='model_mnist',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)
model.fit(X, Y, n_epoch=200, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_mnist')
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]
```
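The training settings above imply a specific learning-rate schedule. This plain-Python sketch rests on two assumptions: that TFLearn's `Momentum` applies `lr_decay` as staircase exponential decay (lr = 0.1 × 0.1^⌊step / 40000⌋, with one step per batch), and that the loader returns 55,000 training images, the train split of the standard TensorFlow-tutorial MNIST reader. Adjust `n_train` if your loader differs. The sketch works out the steps per epoch and the learning rate at a few points of the 200-epoch run.

```python
import math
from typing import Tuple


def learning_rate(step: int, base_lr: float = 0.1, lr_decay: float = 0.1,
                  decay_step: int = 40000) -> float:
    """Staircase exponential decay: lr = base_lr * lr_decay ** floor(step / decay_step)."""
    return base_lr * lr_decay ** (step // decay_step)


def schedule_summary(n_train: int = 55000, batch_size: int = 100,
                     n_epoch: int = 200) -> Tuple[int, int]:
    """Return (steps per epoch, total steps).

    Assumption: a partial last batch would still count as one step
    (irrelevant here, since 55,000 is divisible by 100).
    """
    steps_per_epoch = math.ceil(n_train / batch_size)
    return steps_per_epoch, steps_per_epoch * n_epoch


if __name__ == "__main__":
    per_epoch, total = schedule_summary()
    print(per_epoch, total)  # 550 110000
    for step in (0, 39999, 40000, 80000, total - 1):
        print(step, f"{learning_rate(step):.4g}")
```

Under these assumptions the rate drops to 0.01 after step 40,000 (during epoch 73) and to 0.001 after step 80,000 (during epoch 146), out of 110,000 steps in total.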
Next is the code for the deep residual network (ResNet):
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 26 07:46:00 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2
K. He, X. Zhang, S. Ren, J. Sun, Deep Residual Learning for Image Recognition, CVPR, 2016.
@author: me
"""
import tflearn

# Data loading
from tflearn.datasets import mnist
X, Y, testX, testY = mnist.load_data(one_hot=True)
X = X.reshape([-1, 28, 28, 1])
testX = testX.reshape([-1, 28, 28, 1])

# Real-time data preprocessing
# Note: img_prep is never passed to input_data below, so as written it has no effect.
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Building a deep residual network (same layout as the program above, but with
# TFLearn's built-in residual_block in place of residual_shrinkage_block)
net = tflearn.input_data(shape=[None, 28, 28, 1])
net = tflearn.conv_2d(net, 8, 3, regularizer='L2', weight_decay=0.0001)
net = tflearn.residual_block(net, 1, 8, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)

# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=40000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')

# Training
model = tflearn.DNN(net, checkpoint_path='model_mnist',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)
model.fit(X, Y, n_epoch=200, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_mnist')
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]
```
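Both programs above build a small network with only one basic block, and no noise is added to the MNIST data; the results differ somewhat from run to run. The accuracies are shown in the table below. Even on noise-free data, the deep residual shrinkage network performs well: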
Reposted from:
https://my.oschina.net/u/4223274/blog/3148949
References:
M. Zhao, S. Zhong, X. Fu, et al., Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898