Implementing a Semantic Segmentation Model in TensorFlow: DeepLab V3+ (placeholder post: TensorFlow 2.0 is a major overhaul and many of the old APIs will be removed, so the author has stopped updating this post, but discussion is still welcome)

This post implements the DeepLab v3+ model (reference: the official open-source DeepLab code).

# -*- coding: utf-8 -*-
"""
Created on Mon Dec  3 17:57:46 2018

@author: shirhe-lyh


Implementation of DeepLab V3+:
    Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation, Liang-Chieh Chen, et al., arXiv:1802.02611v3.
"""

import tensorflow as tf

from tensorflow.contrib.slim import nets

import preprocessing  # Companion module of the original repo (not shown here).
import resnet_v1_beta  # ResNet-v1 variant from the official DeepLab code.

slim = tf.contrib.slim
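
# Note: this implementation targets TensorFlow 1.x; tf.contrib (including
# slim) is removed in TensorFlow 2.0.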


class DeepLab(object):
    """Implementation of DeepLab V3+."""
    
    def __init__(self,
                 is_training,
                 num_classes=3,
                 output_stride=16,
                 atrous_rates=[6, 12, 18],  # [12, 24, 36] for output_stride=8
                 decoder_output_stride=4,
                 default_image_size=513,
                 fine_tune_batch_norm=False):
        """Constructor.
        
        Args:
            is_training: A boolean indicating whether the training version of
                computation graph should be constructed.
            num_classes: The number of classes.
            defualt_image_size: The input size of the model.
        """
        self._is_training = is_training
        self._num_classes = num_classes
        self._output_stride = output_stride
        self._atrous_rates = atrous_rates
        self._decoder_output_stride = decoder_output_stride
        self._default_image_size = default_image_size
        
        # When fine_tune_batch_norm=True, use a batch size larger than 12
        # (more than 16 is better). Otherwise, use a smaller batch size and
        # set fine_tune_batch_norm=False.
        _is_training = is_training and fine_tune_batch_norm
        self._batch_norm_params = {'is_training': _is_training,
                                   'epsilon': 1e-5,
                                   'decay': 0.9997,
                                   'scale': True}
        
    @property
    def default_image_size(self):
        return self._default_image_size
        
    def preprocess(self, images=None, masks=None):
        """Preprocessing.
        
        Args:
            images: A float32 tensor with shape [batch_size, height, width,
                3] representing a batch of images. Only passed at test time
                (i.e., at training time images=None and the inputs come from
                the input pipeline).
            masks: A float32 tensor with shape [batch_size, height, width, 1] 
                representing a batch of groundtruth masks.
            
        Returns:
            A dict mapping 'images' and 'masks' to the preprocessed tensors.
        """
        # The companion `preprocessing` module presumably implements the full
        # pipeline; since its API is not shown in this post, a minimal
        # stand-in is used here: map the image values to [-1, 1] and pass the
        # masks through unchanged.
        images_preprocessed = None
        if images is not None:
            images_preprocessed = self._preprocess_zero_mean_unit_range(images)
        preprocessed_dict = {'images': images_preprocessed,
                             'masks': masks}
        return preprocessed_dict
    
    def _preprocess_zero_mean_unit_range(self, inputs):
        """Map image values from [0, 255] to [-1, 1].
        
        Only for beta version.
        """
        return (2.0 / 255.0) * tf.to_float(inputs) - 1.0
    
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A 4-D float32 tensor with shape [batch_size, 
                height, width, channels].
            
        Returns:
            The prediction tensors to be passed to the Loss or Postprocess 
            functions.
        """
        # ResNet-50
        with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
            net, end_points = resnet_v1_beta.resnet_v1_50_beta(
                preprocessed_inputs, num_classes=None,
                is_training=self._is_training,
                multi_grid=[1, 2, 4],
                global_pool=False,
                output_stride=self._output_stride)
            
        # Use the same variable scope as ResNet-50.
        scope = 'resnet_v1_50'
        
        # Atrous spatial pyramid pooling
        net = self._atrous_spatial_pyramid_pooling(
            net, atrous_rates=self._atrous_rates, scope=scope)
        
        # Refine by decoder
        decoder_height = self.default_image_size // self._decoder_output_stride
        decoder_width = self.default_image_size // self._decoder_output_stride
        net = self._refine_by_decoder(
            net,
            end_points,
            decoder_height=decoder_height,
            decoder_width=decoder_width,
            decoder_use_separable_conv=True,
            is_training=self._is_training)
        
        # Convolution
        net = self._get_branch_logits(net, self._num_classes,
                                      self._atrous_rates, kernel_size=1)
        net = tf.image.resize_bilinear(net, size=[self._default_image_size,
                                                  self._default_image_size],
                                       align_corners=True,
                                       name='upsampling_logits')
        return net
    
    def split_separable_conv2d(self,
                               inputs,
                               filters,
                               kernel_size=3,
                               rate=1,
                               weight_decay=0.00004,
                               depthwise_weights_initializer_stddev=0.33,
                               pointwise_weights_initializer_stddev=0.06,
                               scope=None):
        """Splits a seperable conv2d into depthwise and pointwise conv2d.
        
        This operation differs from `tf.layers.separable_conv2d` as this 
        operation applies activation function between depthwise and pointwise 
        conv2d.
        
        Copy from:
            https://github.com/tensorflow/models/blob/master/research/deeplab/
            core/utils.py
            
        Args:
            inputs: Input tensor with shape [batch, height, width, channels].
            filters: Number of filters in the 1x1 pointwise convolution.
            kernel_size: A list of length 2: [kernel_height, kernel_width] of
                the filters. Can be an int if both values are the same.
            rate: Atrous convolution rate for the depthwise convolution.
            weight_decay: The weight decay to use for regularizing the model.
            depthwise_weights_initializer_stddev: The standard deviation of the
                truncated normal weight initializer for depthwise convolution.
            pointwise_weights_initializer_stddev: The standard deviation of the
                truncated normal weight initializer for pointwise convolution.
            scope: Optional scope for the operation.
            
        Returns:
            Computed features after split separable conv2d.
        """
        outputs = slim.separable_conv2d(
            inputs,
            None,
            kernel_size=kernel_size,
            depth_multiplier=1,
            rate=rate,
            weights_initializer=tf.truncated_normal_initializer(
                stddev=depthwise_weights_initializer_stddev),
            weights_regularizer=None,
            scope=scope + '_depthwise')
        return slim.conv2d(
            outputs,
            filters,
            1,
            weights_initializer=tf.truncated_normal_initializer(
                stddev=pointwise_weights_initializer_stddev),
            weights_regularizer=slim.l2_regularizer(weight_decay),
            scope=scope + '_pointwise')
              
    def _atrous_spatial_pyramid_pooling(self, feature_map, weight_decay=0.0001,
                                        atrous_rates=[12, 24, 36],
                                        scope='resnet_v1_50'):
        """Atrous spatial pyramid pooling for DeepLab v3."""
        branch_nets = []
        # Convolution
        with tf.variable_scope(scope):
            with slim.arg_scope([slim.conv2d, slim.separable_conv2d], 
                                weights_regularizer=slim.l2_regularizer(
                                    weight_decay),
                                normalizer_fn=slim.batch_norm,
                                normalizer_params=self._batch_norm_params):
                depth = 256
                
                # Image pooling feature
                shape = tf.shape(feature_map)[1:3]
                image_feature = tf.reduce_mean(feature_map, axis=[1, 2],
                                               keep_dims=True)
                image_feature = slim.conv2d(image_feature, kernel_size=1,
                                            num_outputs=depth,
                                            scope='global_pool')
                image_feature = tf.image.resize_bilinear(image_feature, 
                                                         size=shape,
                                                         align_corners=True)
                branch_nets.append(image_feature)
                
                # Employ a 1x1 convolution
                branch_nets.append(slim.conv2d(feature_map, kernel_size=1,
                                               num_outputs=depth,
                                               scope='aspp' + str(0)))          
                
                # Employ 3x3 convolutions with different atrous rates, one
                # ASPP branch per rate.
                for i, rate in enumerate(atrous_rates, 1):
                    aspp_net = self.split_separable_conv2d(
                        feature_map,
                        filters=depth,
                        rate=rate,
                        weight_decay=weight_decay,
                        scope='aspp' + str(i))
                    branch_nets.append(aspp_net)
                
                # Concatenate the branches and project back to `depth`
                # channels.
                net = tf.concat(branch_nets, axis=3, name='aspp_concat')
                net = slim.conv2d(net, depth, kernel_size=1,
                                  scope='concat_projection')
                net = slim.dropout(net, keep_prob=0.9,
                                   is_training=self._is_training,
                                   scope='concat_projection_dropout')
        return net
    
    def _refine_by_decoder(self,
                           feature_map,
                           end_points,
                           decoder_height,
                           decoder_width,
                           decoder_use_separable_conv=False,
                           weight_decay=0.0001,
                           reuse=None,
                           is_training=False,
                           scope='resnet_v1_50'):
        """Adds the decoder to obtain sharper segmentation results.
        
        Args:
            feature_map: A tensor with shape [batch_size, height, width, depth].
            end_points: A dictionary from components of the network to the 
                corresponding activation.
            decoder_height: The height of decoder feature maps.
            decoder_width: The width of decoder feature maps.
            decoder_use_separable_conv: Whether to employ separable
                convolution in the decoder.
            weight_decay: The weight decay for model variables.
            reuse: Reuse the model variables or not.
            is_training: Is training or not.
            
        Returns:
            Decoder output with shape [batch_size, decoder_height,
            decoder_width, decoder_depth].
        """
        with tf.variable_scope(scope):
            with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                                weights_regularizer=slim.l2_regularizer(
                                    weight_decay),
                                normalizer_fn=slim.batch_norm,
                                normalizer_params=self._batch_norm_params,
                                reuse=reuse):
                feature_list = ['block1/unit_2/bottleneck_v1/conv3']
                decoder_features = feature_map
                for i, name in enumerate(feature_list):
                    decoder_features_list = [decoder_features]
                    feature_name = '{}/{}'.format('resnet_v1_50', name)
                    decoder_features_list.append(
                        slim.conv2d(end_points[feature_name], 48, 1,
                                    scope='feature_project' + str(i)))
                    for j, feature in enumerate(decoder_features_list):
                        decoder_features_list[j] = tf.image.resize_bilinear(
                            feature, [decoder_height, decoder_width], 
                            align_corners=True)
                        h = (None if isinstance(decoder_height, tf.Tensor)
                             else decoder_height)
                        w = (None if isinstance(decoder_width, tf.Tensor)
                             else decoder_width)
                        decoder_features_list[j].set_shape([None, h, w, None])
                    decoder_depth = 256
                    if decoder_use_separable_conv:
                        decoder_features = self.split_separable_conv2d(
                            tf.concat(decoder_features_list, axis=3),
                            filters=decoder_depth,
                            rate=1,
                            weight_decay=weight_decay,
                            scope='decoder_conv0')
                        decoder_features = self.split_separable_conv2d(
                            decoder_features,
                            filters=decoder_depth,
                            rate=1,
                            weight_decay=weight_decay,
                            scope='decoder_conv1')
                    else:
                        num_convs = 2
                        decoder_features = slim.repeat(
                            tf.concat(decoder_features_list, axis=3),
                            num_convs,
                            slim.conv2d,
                            decoder_depth,
                            3,
                            scope='decoder_conv' + str(i))
                return decoder_features
            
    def _get_branch_logits(self,
                           feature_map,
                           num_classes,
                           atrous_rates=[12, 24, 36],
                           kernel_size=1,
                           weight_decay=0.0001,
                           reuse=None,
                           scope_suffix='seg_logits',
                           scope='resnet_v1_50'):
        """Gets the logits from each model's branch.
        
        The underlying model is branched out in the last layer when atrous
        spatial pyramid pooling is employed, and all branches are sum-merged
        to form the final logits.
        
        Args:
            feature_map: A float32 tensor with shape [batch_size, height,
                width, channels].
            num_classes: Number of classes to predict.
            atrous_rates: A list of atrous convolution rates for last layer.
            kernel_size: Kernel size for convolution.
            weight_decay: Weight decay for the model variables.
            reuse: Reuse model variables or not.
            scope_suffix: Scope suffix for the model variables.
            
        Returns:
            Merged logits with shape [batch_size, height, width, num_classes].
        """
        with tf.variable_scope(scope):
            with slim.arg_scope(
                [slim.conv2d],
                weights_regularizer=slim.l2_regularizer(
                    weight_decay),
                weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                reuse=reuse):
                branch_logits = []
                for i, rate in enumerate(atrous_rates):
                    scope = scope_suffix
                    if i:
                        scope += '_%d' % i
                        
                    branch_logits.append(
                        slim.conv2d(feature_map,
                                    num_classes,
                                    kernel_size=kernel_size,
                                    rate=rate,
                                    activation_fn=None,
                                    normalizer_fn=None,
                                    scope=scope))
            return tf.add_n(branch_logits)
    
    def postprocess(self, prediction_tensors):
        """Convert predicted output tensors to final forms.
        
        Args:
            prediction_tensors: The prediction tensors.
                
        Returns:
            The postprocessed results.
        """
        # Per-pixel class probabilities; a label map can be obtained with
        # tf.argmax(probabilities, axis=3).
        probabilities = tf.nn.softmax(prediction_tensors, axis=3)
        return probabilities
    
    def loss(self, prediction_tensors, groundtruth_tensors):
        """Computes scalar loss tensors with respect to provided groundtruth."""
        logits = tf.reshape(prediction_tensors, shape=[-1, self._num_classes])
        labels = tf.reshape(groundtruth_tensors, shape=[-1])
        # Map the float mask values to integer class ids (a trimap-style
        # mapping): values > 0.8 -> class 1 (foreground); values in
        # (0.0, 0.8] -> class 2 (unknown/boundary); 0.0 -> class 0
        # (background).
        labels = tf.where(tf.greater(labels, 0.8),
                          tf.ones_like(labels),
                          labels)
        labels = tf.where(tf.logical_and(tf.less_equal(labels, 0.8),
                                         tf.greater(labels, 0.0)),
                          2 * tf.ones_like(labels),
                          labels)
        labels = tf.cast(labels, dtype=tf.int32)
        tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        loss = tf.losses.get_total_loss()
        return loss
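
For reference, here is a minimal usage sketch under the same TensorFlow 1.x assumptions. The placeholders, optimizer, and learning rate below are illustrative choices, not part of the original post:

# Illustrative training-graph construction (TensorFlow 1.x assumed).
import tensorflow as tf

model = DeepLab(is_training=True, num_classes=3)
size = model.default_image_size

# Any input pipeline producing tensors of these shapes would work.
images = tf.placeholder(tf.float32, [None, size, size, 3], name='images')
masks = tf.placeholder(tf.float32, [None, size, size, 1], name='masks')

inputs = model._preprocess_zero_mean_unit_range(images)  # map to [-1, 1]
logits = model.predict(inputs)    # [batch, size, size, num_classes]
loss = model.loss(logits, masks)  # scalar total loss (incl. regularization)

# Batch-norm moving averages are updated via UPDATE_OPS when
# fine_tune_batch_norm=True.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)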