轉(zhuǎn)自Tensorflow slim庫使用小記
看fensorflow的書發(fā)現(xiàn)使用的是slim庫,那就要研究slim的常用函數(shù)柳譬,這個(gè)文章寫的很好喳张,轉(zhuǎn)一下哈。
slim庫的導(dǎo)入:
import tensorflow as tf
import tensorflow.contrib.slim as slim
常用函數(shù):
與tensorflow自帶的函數(shù)相比美澳,slim能夠讓我們不用重復(fù)寫函數(shù)的參數(shù)销部。那么函數(shù)的參數(shù)寫在哪里呢?核心方法就是slim.arg_scope制跟。
slim.arg_scope
def arg_scope(list_ops_or_scope, **kwargs)
list_ops_or_scope:要用的函數(shù)的作用域柴墩,可以在需要使用的地方用@add_arg_scope 聲明
**kwargs: keyword=value 定義了list_ops中要使用的變量
也就是說可以通過這個(gè)函數(shù)將不想重復(fù)寫的參數(shù)通過這個(gè)函數(shù)自動(dòng)賦值。
示例:
import tensorflow.contrib.slim as slim
@slim.add_arg_scope
def hh(name, add_arg):
print("name:", name)
print("add_arg:", add_arg)
with slim.arg_scope([hh], add_arg='this is add'):
hh('test')
#結(jié)果:
#name: test
#add_arg: this is add
進(jìn)入add_arg_scope函數(shù)查看代碼可知:
def add_arg_scope(func):
"""Decorates a function with args so it can be used within an arg_scope.
Args:
func: function to decorate.
Returns:
A tuple with the decorated function func_with_args().
"""
def func_with_args(*args, **kwargs):
current_scope = _current_arg_scope()
current_args = kwargs
key_func = _key_op(func)
if key_func in current_scope:
current_args = current_scope[key_func].copy()
current_args.update(kwargs)
return func(*args, **current_args)
_add_op(func)
setattr(func_with_args, '_key_op', _key_op(func))
return tf_decorator.make_decorator(func, func_with_args)
其實(shí)就是看看你調(diào)用的是那個(gè)函數(shù)凫岖,給參數(shù)中添加你之前賦值的參數(shù)江咳。
之后是使用slim構(gòu)建神經(jīng)網(wǎng)絡(luò)常用的函數(shù)。
slim.conv2d
slim.conv2d是對tf.conv2d的進(jìn)一步封裝哥放。常見調(diào)用方式:
net = slim.conv2d(inputs, 256, [3, 3], stride=1, scope='conv1_1')
源代碼:
@add_arg_scope
def convolution(inputs,num_outputs,
kernel_size,
stride=1,
padding='SAME',
data_format=None,
rate=1,
activation_fn=nn.relu,
normalizer_fn=None,
normalizer_params=None,
weights_initializer=initializers.xavier_initializer(),
weights_regularizer=None,
biases_initializer=init_ops.zeros_initializer(),
biases_regularizer=None,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None)
常用的有:
padding : 補(bǔ)零的方式歼指,例如'SAME'
activation_fn : 激活函數(shù)爹土,默認(rèn)是nn.relu
normalizer_fn : 正則化函數(shù),默認(rèn)為None踩身,這里可以設(shè)置為batch normalization胀茵,函數(shù)用slim.batch_norm
normalizer_params : slim.batch_norm中的參數(shù),以字典形式表示
weights_initializer : 權(quán)重的初始化器挟阻,initializers.xavier_initializer()
weights_regularizer : 權(quán)重的正則化器琼娘,一般不怎么用到
biases_initializer : 如果之前有batch norm,那么這個(gè)及下面一個(gè)就不用管了
biases_regularizer :
trainable : 參數(shù)是否可訓(xùn)練附鸽,默認(rèn)為True
scope:你繪制的網(wǎng)絡(luò)結(jié)構(gòu)圖中它屬于那個(gè)范圍內(nèi)
slim.max_pool2d
net = slim.max_pool2d(net, [2, 2], scope='pool1')
前兩個(gè)參數(shù)分別為網(wǎng)絡(luò)輸入脱拼、輸出的神經(jīng)元數(shù)量,第三個(gè)同上坷备。
slim.batch_norm
def batch_norm(inputs,
decay=0.999,
center=True,
scale=False,
epsilon=0.001,
activation_fn=None,
param_initializers=None,
param_regularizers=None,
updates_collections=ops.GraphKeys.UPDATE_OPS,
is_training=True,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
batch_weights=None,
fused=False,
data_format=DATA_FORMAT_NHWC,
zero_debias_moving_mean=False,
scope=None,
renorm=False,
renorm_clipping=None,
renorm_decay=0.99):
這個(gè)我沒有理解熄浓。以下是原博客說的。
接下來說我在用slim.batch_norm時(shí)踩到的坑省撑。slim.batch_norm里有moving_mean和moving_variance兩個(gè)量赌蔑,分別表示每個(gè)批次的均值和方差。在訓(xùn)練時(shí)還好理解竟秫,但在測試時(shí)娃惯,moving_mean和moving_variance的含義變了。在訓(xùn)練時(shí)肥败,有一些語句是必不可少的:
# 定義占位符石景,X表示網(wǎng)絡(luò)的輸入,Y表示真實(shí)值label
X = tf.placeholder("float", [None, 224, 224, 3])
Y = tf.placeholder("float", [None, 100])
#調(diào)用含batch_norm的resnet網(wǎng)絡(luò)拙吉,其中記得is_training=True
logits = model.resnet(X, 100, is_training=True)
cross_entropy = -tf.reduce_sum(Y*tf.log(logits))
#訓(xùn)練的op一定要用slim的slim.learning.create_train_op潮孽,只用tf.train.MomentumOptimizer.minimize()是不行的
opt = tf.train.MomentumOptimizer(lr_rate, 0.9)
train_op = slim.learning.create_train_op(cross_entropy, opt, global_step=global_step)
#更新操作,具體含義不是很明白筷黔,直接套用即可
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
updates = tf.group(*update_ops)
cross_entropy = control_flow_ops.with_dependencies([updates], cross_entropy)
之后的訓(xùn)練都和往常一樣了往史,導(dǎo)出模型后,在測試階段調(diào)用相同的網(wǎng)絡(luò)佛舱,參數(shù)is_training一定要設(shè)置成False椎例。```
logits = model.resnet(X, 100, is_training=False)
否則,可能會(huì)出現(xiàn)這種情況:所有的單個(gè)圖像分類请祖,最后幾乎全被歸為同一類订歪。這可能就是訓(xùn)練模式設(shè)置反了的問題。