Introduction
Xception comes from Google and dates from around 2017; it appeared after Google's MobileNet v1 and before MobileNet v2.
Its central idea is much like that of the MobileNet series: promoting the use of Depthwise Conv + Pointwise Conv. The difference is that it takes Inception v3 as its template, replaces the basic Inception modules inside with Depthwise Conv + Pointwise Conv, and adds residual connections on top.
The resulting model achieves better results than Inception v3 and ResNet-152 on datasets such as ImageNet, while keeping a model size and computational cost roughly on par with Inception v3.
The evolution from the Inception module to the Separable Conv
Figure 1 below shows a typical Inception module. The basic assumption behind it is that, as features pass through convolutions, the correlations across feature channels and the spatial correlations within a single channel can be learned separately. To this end, the Inception module uses many 1x1 convs to focus on learning the cross-channel correlations, and then uses 3x3 / 5x5 (or two stacked 3x3) convs to learn the spatial correlations within channels at different scales. If, following this separation assumption, we use only 3x3 convs to model the within-channel spatial correlations, we arrive at the simplified Inception module shown in Figure 2 below.
Essentially, the simplified Inception module of Figure 2 can also be redrawn as the form shown in Figure 3. It is equivalent to first applying a single 1x1 conv to learn the cross-channel correlations of the input feature maps, and then splitting the output feature maps of that 1x1 conv into segments, each handed to its own 3x3 conv that models the spatial correlations among the elements inside it.
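As a concrete illustration of this grouped form (a minimal tf.keras sketch, not code from the paper or the repo; the channel counts and group number are made up), the 1x1 conv output is split along the channel axis and each segment gets its own 3x3 conv:

import tensorflow as tf

# Sketch of the Figure-3 form: one 1x1 conv mixes all channels, then the output
# channels are split into segments, each processed by an independent 3x3 conv.
def grouped_inception_unit(x, out_channels=128, num_groups=4):
    x = tf.keras.layers.Conv2D(out_channels, 1, use_bias=False)(x)      # cross-channel 1x1 conv
    branches = []
    for g in tf.split(x, num_groups, axis=-1):                          # segment the feature maps
        branches.append(tf.keras.layers.Conv2D(out_channels // num_groups, 3,
                                               padding='same', use_bias=False)(g))
    return tf.concat(branches, axis=-1)                                 # one 3x3 conv per segment, then concat

y = grouped_inception_unit(tf.random.normal([1, 32, 32, 64]))           # toy input, eager mode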
Going one step further, why not take it to the extreme and give every single channel its own 3x3 conv for its spatial correlations? Doing so yields the Separable conv shown in Figure 4 below.
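A minimal tf.keras sketch of this "extreme Inception" unit follows (illustrative only, shapes assumed). Note that the SeparableConv2D found in most libraries applies the depthwise 3x3 first and the pointwise 1x1 second; the paper argues this ordering makes little difference once the units are stacked.

import tensorflow as tf

# "Extreme Inception" unit from Figure 4: a pointwise 1x1 conv learns the
# cross-channel correlations, then a depthwise 3x3 conv handles the spatial
# correlations of every channel independently.
def extreme_inception_unit(x, out_channels):
    x = tf.keras.layers.Conv2D(out_channels, 1, use_bias=False)(x)              # pointwise: channel mixing
    x = tf.keras.layers.DepthwiseConv2D(3, padding='same', use_bias=False)(x)   # depthwise: per-channel spatial
    return x

inputs = tf.keras.Input(shape=(32, 32, 64))        # toy shape, purely illustrative
model = tf.keras.Model(inputs, extreme_inception_unit(inputs, out_channels=128))
model.summary()                                    # far fewer parameters than a dense 3x3x64x128 conv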
The Xception architecture
The figure below shows the Xception structure. It evolves directly from Inception v3, with residual learning added (several earlier works, as well as the author's own experiments in this paper, show that residual learning in CNN models speeds up convergence).
Like the rather elaborate Inception-series models, it is organized into Entry/Middle/Exit flows, each built from its own repeated module; the core is the Middle flow, which keeps analyzing and filtering features.
The Entry flow mainly performs repeated downsampling to shrink the spatial dimensions; the Middle flow keeps learning correlations and refining the features; the Exit flow finally aggregates and organizes the features before handing them to the FC layers for the final representation.
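As a rough illustration of the repeated unit (a tf.keras sketch, not the official code; the 728-channel width and 19x19 spatial size follow the paper's Middle flow), one Middle-flow block is three separable convs wrapped in an identity residual connection:

import tensorflow as tf

# One Middle-flow block: three 3x3 separable convs, each preceded by ReLU and
# followed by BatchNorm, plus an identity shortcut (residual learning).
def middle_flow_block(x, channels=728):
    shortcut = x
    for _ in range(3):
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.SeparableConv2D(channels, 3, padding='same', use_bias=False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
    return tf.keras.layers.Add()([shortcut, x])

inputs = tf.keras.Input(shape=(19, 19, 728))       # Middle-flow feature-map size in the paper
outputs = middle_flow_block(inputs)                # repeated 8 times in the full model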
Experimental results
The table below compares the ImageNet classification accuracy of Xception against other models.
The next table compares Xception and Inception v3 in terms of parameter count and computational speed.
Code analysis
In the official TF implementation, the authors made a few changes to the Xception structure, mainly in two places: first, convs with stride 2 replace max pooling for spatial downsampling; second, ReLU and BN are also applied after the depthwise conv (sketched below, after the quoted note).
We made a few more changes on top of MSRA's modifications:
1. Fully convolutional: All the max-pooling layers are replaced with separable
conv2d with stride = 2. This allows us to use atrous convolution to extract
feature maps at any resolution.
2. We support adding ReLU and BatchNorm after depthwise convolution, motivated
by the design of MobileNetv1.
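Put together, the modified separable conv looks roughly like the sketch below (a tf.keras illustration of the two changes, not the repo's actual slim code): BN and ReLU also follow the depthwise conv, and downsampling is done with stride 2 inside the conv instead of max pooling.

import tensorflow as tf

# Separable conv as modified in the TF implementation: ReLU/BatchNorm are applied
# after the depthwise conv too, and stride=2 replaces max pooling for downsampling.
def separable_conv_bn_relu(x, out_channels, stride=1):
    x = tf.keras.layers.DepthwiseConv2D(3, strides=stride, padding='same', use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)                                   # extra activation after the depthwise conv
    x = tf.keras.layers.Conv2D(out_channels, 1, use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    return tf.keras.layers.ReLU()(x)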
Below is the entry function for building an Xception model.
def xception(inputs,
blocks,
num_classes=None,
is_training=True,
global_pool=True,
keep_prob=0.5,
output_stride=None,
reuse=None,
scope=None):
"""Generator for Xception models.
This function generates a family of Xception models. See the xception_*()
methods for specific model instantiations, obtained by selecting different
block instantiations that produce Xception of various depths.
Args:
inputs: A tensor of size [batch, height_in, width_in, channels]. Must be
floating point. If a pretrained checkpoint is used, pixel values should be
the same as during training (see go/slim-classification-models for
specifics).
blocks: A list of length equal to the number of Xception blocks. Each
element is an Xception Block object describing the units in the block.
num_classes: Number of predicted classes for classification tasks.
If 0 or None, we return the features before the logit layer.
is_training: whether batch_norm layers are in training mode.
global_pool: If True, we perform global average pooling before computing the
logits. Set to True for image classification, False for dense prediction.
keep_prob: Keep probability used in the pre-logits dropout layer.
output_stride: If None, then the output will be computed at the nominal
network stride. If output_stride is not None, it specifies the requested
ratio of input to output spatial resolution.
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional variable_scope.
Returns:
net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
If global_pool is False, then height_out and width_out are reduced by a
factor of output_stride compared to the respective height_in and width_in,
else both height_out and width_out equal one. If num_classes is 0 or None,
then net is the output of the last Xception block, potentially after
global average pooling. If num_classes is a non-zero integer, net contains
the pre-softmax activations.
end_points: A dictionary from components of the network to the corresponding
activation.
Raises:
ValueError: If the target output_stride is not valid.
"""
with tf.variable_scope(
scope, 'xception', [inputs], reuse=reuse) as sc:
end_points_collection = sc.original_name_scope + 'end_points'
with slim.arg_scope([slim.conv2d,
slim.separable_conv2d,
xception_module,
stack_blocks_dense],
outputs_collections=end_points_collection):
with slim.arg_scope([slim.batch_norm], is_training=is_training):
net = inputs
if output_stride is not None:
if output_stride % 2 != 0:
raise ValueError('The output_stride needs to be a multiple of 2.')
output_stride /= 2
# Root block function operated on inputs.
net = resnet_utils.conv2d_same(net, 32, 3, stride=2,
scope='entry_flow/conv1_1')
net = resnet_utils.conv2d_same(net, 64, 3, stride=1,
scope='entry_flow/conv1_2')
# Extract features for entry_flow, middle_flow, and exit_flow.
net = stack_blocks_dense(net, blocks, output_stride)
# Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(
end_points_collection, clear_collection=True)
if global_pool:
# Global average pooling.
net = tf.reduce_mean(net, [1, 2], name='global_pool', keepdims=True)
end_points['global_pool'] = net
if num_classes:
net = slim.dropout(net, keep_prob=keep_prob, is_training=is_training,
scope='prelogits_dropout')
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='logits')
end_points[sc.name + '/logits'] = net
end_points['predictions'] = slim.softmax(net, scope='predictions')
return net, end_points
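For completeness, the xception_*() constructors in the same file (e.g. xception_65) assemble the blocks list and call this entry function. A hedged usage sketch, assuming the file is importable as deeplab.core.xception and a TF1 graph-mode setup, might look like:

import tensorflow as tf
from deeplab.core import xception   # import path assumed; see the code file in the references

# Hypothetical call through one of the provided constructors (xception_65),
# requesting output_stride=16 as one would for dense prediction.
images = tf.placeholder(tf.float32, [None, 299, 299, 3])
net, end_points = xception.xception_65(images,
                                       num_classes=1001,
                                       is_training=False,
                                       global_pool=True,
                                       output_stride=16)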
The function below then stacks the concrete blocks. Honestly, Google's implementation feels a bit over-engineered: a fairly simple model is built through many layers of abstraction, and even with a relatively high-level library like slim you still have to chase through the code for quite a while. I'd recommend PyTorch. :)
I'll stop here rather than digging any deeper; readers who want to explore further can go straight to the code file listed in the references.
@slim.add_arg_scope
def stack_blocks_dense(net,
blocks,
output_stride=None,
outputs_collections=None):
"""Stacks Xception blocks and controls output feature density.
First, this function creates scopes for the Xception in the form of
'block_name/unit_1', 'block_name/unit_2', etc.
Second, this function allows the user to explicitly control the output
stride, which is the ratio of the input to output spatial resolution. This
is useful for dense prediction tasks such as semantic segmentation or
object detection.
Control of the output feature density is implemented by atrous convolution.
Args:
net: A tensor of size [batch, height, width, channels].
blocks: A list of length equal to the number of Xception blocks. Each
element is an Xception Block object describing the units in the block.
output_stride: If None, then the output will be computed at the nominal
network stride. If output_stride is not None, it specifies the requested
ratio of input to output spatial resolution, which needs to be equal to
the product of unit strides from the start up to some level of Xception.
For example, if the Xception employs units with strides 1, 2, 1, 3, 4, 1,
then valid values for the output_stride are 1, 2, 6, 24 or None (which
is equivalent to output_stride=24).
outputs_collections: Collection to add the Xception block outputs.
Returns:
net: Output tensor with stride equal to the specified output_stride.
Raises:
ValueError: If the target output_stride is not valid.
"""
# The current_stride variable keeps track of the effective stride of the
# activations. This allows us to invoke atrous convolution whenever applying
# the next residual unit would result in the activations having stride larger
# than the target output_stride.
current_stride = 1
# The atrous convolution rate parameter.
rate = 1
for block in blocks:
with tf.variable_scope(block.scope, 'block', [net]) as sc:
for i, unit in enumerate(block.args):
if output_stride is not None and current_stride > output_stride:
raise ValueError('The target output_stride cannot be reached.')
with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
# If we have reached the target output_stride, then we need to employ
# atrous convolution with stride=1 and multiply the atrous rate by the
# current unit's stride for use in subsequent layers.
if output_stride is not None and current_stride == output_stride:
net = block.unit_fn(net, rate=rate, **dict(unit, stride=1))
rate *= unit.get('stride', 1)
else:
net = block.unit_fn(net, rate=1, **unit)
current_stride *= unit.get('stride', 1)
# Collect activations at the block's end before performing subsampling.
net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
if output_stride is not None and current_stride != output_stride:
raise ValueError('The target output_stride cannot be reached.')
return net
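To see what the current_stride / rate bookkeeping above actually does, here is a small standalone simulation (plain Python, not code from the repo) of that loop for a hypothetical chain of unit strides and a requested output_stride of 8:

# Standalone simulation of the stride/rate logic in stack_blocks_dense: once the
# effective stride reaches output_stride, later strided units run with stride 1
# and their stride is absorbed into the atrous (dilation) rate instead.
unit_strides = [2, 2, 2, 2, 1]        # hypothetical per-unit strides
output_stride = 8

current_stride, rate = 1, 1
for s in unit_strides:
    if current_stride == output_stride:
        print('unit stride %d -> run with stride 1, atrous rate %d' % (s, rate))
        rate *= s
    else:
        print('unit stride %d -> run normally with stride %d' % (s, s))
        current_stride *= s
# Result: activations stay at 1/8 resolution; subsequent units use dilation instead.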
References
- Xception: Deep Learning with Depthwise Separable Convolutions, François Chollet, 2017
- https://github.com/tensorflow/models/blob/master/research/deeplab/core/xception.py