This article implements the matting model from the paper Deep Image Matting (for a PyTorch version, see Pytorch 抠图算法 Deep Image Matting 模型实现).
All code is available on GitHub: deep_image_matting.
Image matting is a fairly traditional and widely used technique, and a large number of algorithms have been proposed for it (see AlphaMatting). Although classical image-processing approaches still dominate in number, with the rapid progress of deep learning the top spots on the matting benchmark are now occupied by deep-learning-based methods. The matting problem can be described by the following equation:
$$I = \alpha F + (1 - \alpha) B, \qquad \alpha \in [0, 1],$$

where $I$ denotes the given image to be matted, $F$ and $B$ denote the foreground and background respectively, and $\alpha$ denotes the transparency (alpha) channel. A matting algorithm has to solve for $F$, $B$ and $\alpha$ on the right-hand side of this equation; but since the image has three channels, the right-hand side has 7 unknowns per pixel while the left-hand side supplies only 3 known values, so the equation is underdetermined (it lacks constraints). To obtain a definite solution, the usual practice is to add an extra constraint: either a trimap or a scribble is supplied in advance. For example, given an image to be matted:
the corresponding trimap looks something like this:

Here the white region is definitely foreground and the black region definitely background; the remaining gray region is the unknown region that the matting algorithm must solve for. A scribble, by contrast, is much more casual:

It can be regarded as an extremely simplified version of a trimap.
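To make the compositing equation above concrete, here is a minimal numpy sketch (purely illustrative, not part of the repository; the random arrays stand in for real foreground/background/alpha data). This forward composition is also exactly how the preprocess function below synthesizes training images:

import numpy as np

# Stand-in data: a foreground, a background, and an alpha matte scaled
# to [0, 1]; real data would come from a matting dataset.
fg = np.random.randint(0, 256, (320, 320, 3)).astype(np.float32)
bg = np.random.randint(0, 256, (320, 320, 3)).astype(np.float32)
alpha = np.random.rand(320, 320, 1).astype(np.float32)

# I = alpha * F + (1 - alpha) * B, applied per pixel; the single-channel
# alpha broadcasts over the three color channels.
image = alpha * fg + (1 - alpha) * bg
print(image.shape)  # (320, 320, 3)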
Deep Image Matting uses a convolutional neural network to predict the alpha channel from the original image and the trimap. Concretely, the image and the trimap are fed into the network together: a convolutional network first extracts features from them (the encoder), and transposed convolutions then restore the resolution to predict an alpha channel of the same size as the input (the decoder). This encode-decode process forms the first stage of the network (the encoder-decoder stage). Since the network only has to care about the unknown region of the trimap (the gray region; for the definite regions the trimap already supplies the alpha values), there is good reason to believe its prediction is more accurate than the input trimap. If this predicted alpha channel replaced the original trimap and were merged with the image for another encode-decode pass, the new prediction would be more accurate still; the obvious drawback is that the network would become too large. To exploit the more accurate predicted alpha channel without overcomplicating the architecture, the authors instead concatenate the image with the predicted alpha channel and apply 4 convolutions to produce the final alpha prediction; this is called the refinement stage of the network. The whole pipeline is shown below:
1. Model Implementation
Given an image to be matted and its corresponding trimap, the approach of the Deep Image Matting paper is as follows. First, the convolutional layers of VGG-16 plus its first fully connected layer (fc6, also implemented as a convolution) serve as the encoder to extract features; the image to be matted has three channels, so that branch is initialized directly from the pretrained VGG-16 weights, while the single-channel trimap branch is randomly initialized. Next, the first-stage alpha channel is predicted: the encoder performs 5 stride-2 poolings, so the resolution drops by a factor of 32 (a 320 x 320 input becomes 10 x 10); to predict an alpha channel with the same resolution as the input, the resolution has to be expanded 32x, which is achieved with 5 stride-2 transposed convolutions. Finally, the predicted alpha channel is concatenated with the input image and passed through 4 convolutional layers that keep the resolution fixed while gradually shrinking the channel count, yielding the final alpha prediction. The full model is defined as follows (see model.py):
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 8 11:11:59 2018
@author: shirhe-lyh
"""
import tensorflow as tf
from tensorflow.contrib.slim import nets
import preprocessing
slim = tf.contrib.slim
class Model(object):
    """Deep Image Matting model definition."""

    def __init__(self,
                 is_training,
                 default_image_size=320,
                 first_stage_alpha_loss_weight=1.0,
                 first_stage_image_loss_weight=1.0,
                 second_stage_alpha_loss_weight=1.0):
        """Constructor.

        Args:
            is_training: A boolean indicating whether the training version of
                the computation graph should be constructed.
            default_image_size: The input height and width of the network.
            first_stage_alpha_loss_weight: Weight of the first-stage alpha
                prediction loss.
            first_stage_image_loss_weight: Weight of the first-stage
                compositional loss.
            second_stage_alpha_loss_weight: Weight of the second-stage
                (refinement) alpha prediction loss.
        """
        self._is_training = is_training
        self._default_image_size = default_image_size
        self._first_stage_alpha_loss_weight = first_stage_alpha_loss_weight
        self._first_stage_image_loss_weight = first_stage_image_loss_weight
        self._second_stage_alpha_loss_weight = second_stage_alpha_loss_weight
    def preprocess(self, trimaps, images=None, images_foreground=None,
                   images_background=None, alpha_mattes=None):
        """Preprocess the inputs.

        Outputs of this function can be passed to loss or postprocess
        functions.

        Args:
            trimaps: A float32 tensor with shape [batch_size, height, width,
                1] representing a batch of trimaps.
            images: A float32 tensor with shape [batch_size, height, width,
                3] representing a batch of images. Only provided at test time
                (i.e., in the training case images=None).
            images_foreground: A float32 tensor with shape [batch_size,
                height, width, 3] representing a batch of foreground images.
            images_background: A float32 tensor with shape [batch_size,
                height, width, 3] representing a batch of background images.
            alpha_mattes: A float32 tensor with shape [batch_size, height,
                width, 1] representing a batch of groundtruth alpha mattes.

        Returns:
            preprocessed_dict: A dictionary of preprocessed tensors.
        """
        def _random_crop(t):
            num_channels = t.get_shape().as_list()[2]
            return preprocessing.random_crop_background(
                t, output_height=self._default_image_size,
                output_width=self._default_image_size,
                channels=num_channels)

        def _border_expand_and_resize(t):
            return preprocessing.border_expand_and_resize(
                t, output_height=self._default_image_size,
                output_width=self._default_image_size)

        def _border_expand_and_resize_g(t):
            return preprocessing.border_expand_and_resize(
                t, output_height=self._default_image_size,
                output_width=self._default_image_size,
                channels=1)

        preprocessed_images_fg = None
        preprocessed_images_bg = None
        preprocessed_alpha_mattes = None
        preprocessed_trimaps = tf.map_fn(_border_expand_and_resize_g, trimaps)
        preprocessed_trimaps = tf.to_float(preprocessed_trimaps)
        if self._is_training:
            preprocessed_images_fg = tf.map_fn(_border_expand_and_resize,
                                               images_foreground)
            preprocessed_alpha_mattes = tf.map_fn(_border_expand_and_resize_g,
                                                  alpha_mattes)
            images_background = tf.to_float(images_background)
            preprocessed_images_bg = tf.map_fn(_random_crop, images_background)
            preprocessed_images_fg = tf.to_float(preprocessed_images_fg)
            preprocessed_alpha_mattes = tf.to_float(preprocessed_alpha_mattes)
            # Composite the training image: I = alpha * F + (1 - alpha) * B.
            preprocessed_images = (
                tf.multiply(preprocessed_alpha_mattes,
                            preprocessed_images_fg) +
                tf.multiply(1 - preprocessed_alpha_mattes,
                            preprocessed_images_bg))
        else:
            preprocessed_images = tf.map_fn(_border_expand_and_resize, images)
            preprocessed_images = tf.to_float(preprocessed_images)
        preprocessed_dict = {'images_fg': preprocessed_images_fg,
                             'images_bg': preprocessed_images_bg,
                             'alpha_mattes': preprocessed_alpha_mattes,
                             'images': preprocessed_images,
                             'trimaps': preprocessed_trimaps}
        return preprocessed_dict
    def predict(self, preprocessed_dict):
        """Predict prediction tensors from input tensors.

        Outputs of this function can be passed to loss or postprocess
        functions.

        Args:
            preprocessed_dict: See the preprocess function.

        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the loss or postprocess functions.
        """
        # The inputs for the first stage.
        preprocessed_images = preprocessed_dict.get('images')
        preprocessed_trimaps = preprocessed_dict.get('trimaps')
        # VGG-16 encoder for the image branch.
        _, endpoints = nets.vgg.vgg_16(preprocessed_images,
                                       num_classes=1,
                                       spatial_squeeze=False,
                                       is_training=self._is_training)
        # Note: The `padding` method of fc6 of VGG-16 in tf.contrib.slim is
        # `VALID`, but the expected value is `SAME`, so we must replace it.
        net_image = endpoints.get('vgg_16/pool5')
        net_image = slim.conv2d(net_image, num_outputs=4096, kernel_size=7,
                                padding='SAME', scope='fc6_')
        # VGG-16-style encoder for the trimap (alpha) branch.
        net_alpha = slim.repeat(preprocessed_trimaps, 2, slim.conv2d, 64,
                                [3, 3], scope='conv1_alpha')
        net_alpha = slim.max_pool2d(net_alpha, [2, 2], scope='pool1_alpha')
        net_alpha = slim.repeat(net_alpha, 2, slim.conv2d, 128, [3, 3],
                                scope='conv2_alpha')
        net_alpha = slim.max_pool2d(net_alpha, [2, 2], scope='pool2_alpha')
        net_alpha = slim.repeat(net_alpha, 2, slim.conv2d, 256, [3, 3],
                                scope='conv3_alpha')
        net_alpha = slim.max_pool2d(net_alpha, [2, 2], scope='pool3_alpha')
        net_alpha = slim.repeat(net_alpha, 2, slim.conv2d, 512, [3, 3],
                                scope='conv4_alpha')
        net_alpha = slim.max_pool2d(net_alpha, [2, 2], scope='pool4_alpha')
        net_alpha = slim.repeat(net_alpha, 2, slim.conv2d, 512, [3, 3],
                                scope='conv5_alpha')
        net_alpha = slim.max_pool2d(net_alpha, [2, 2], scope='pool5_alpha')
        net_alpha = slim.conv2d(net_alpha, 4096, [7, 7], padding='SAME',
                                scope='fc6_alpha')
        # Concatenate the two branches (4096 + 4096 = 8192 channels).
        net = tf.concat(values=[net_image, net_alpha], axis=3)
        net.set_shape([None, self._default_image_size // 32,
                       self._default_image_size // 32, 8192])
        # Deconvolution: 5 stride-2 transposed convolutions expand the
        # resolution by a factor of 32.
        with slim.arg_scope([slim.conv2d_transpose], stride=2, kernel_size=5):
            # Deconv6
            net = slim.conv2d_transpose(net, num_outputs=512, kernel_size=1,
                                        scope='deconv6')
            # Deconv5
            net = slim.conv2d_transpose(net, num_outputs=512, scope='deconv5')
            # Deconv4
            net = slim.conv2d_transpose(net, num_outputs=256, scope='deconv4')
            # Deconv3
            net = slim.conv2d_transpose(net, num_outputs=128, scope='deconv3')
            # Deconv2
            net = slim.conv2d_transpose(net, num_outputs=64, scope='deconv2')
            # Deconv1
            net = slim.conv2d_transpose(net, num_outputs=64, stride=1,
                                        scope='deconv1')
        # Predict the first-stage alpha matte.
        alpha_matte = slim.conv2d(net, num_outputs=1, kernel_size=[5, 5],
                                  activation_fn=tf.nn.sigmoid,
                                  scope='AlphaMatte')
        # The inputs for the second (refinement) stage.
        alpha_matte_scaled = tf.multiply(alpha_matte, 255.)
        refine_inputs = tf.concat(
            values=[preprocessed_images, alpha_matte_scaled], axis=3)
        refine_inputs.set_shape([None, self._default_image_size,
                                 self._default_image_size, 4])
        # Refinement: 4 convolutions that keep the resolution unchanged.
        net = slim.conv2d(refine_inputs, num_outputs=64, kernel_size=[3, 3],
                          scope='refine_conv1')
        net = slim.conv2d(net, num_outputs=64, kernel_size=[3, 3],
                          scope='refine_conv2')
        net = slim.conv2d(net, num_outputs=64, kernel_size=[3, 3],
                          scope='refine_conv3')
        refined_alpha_matte = slim.conv2d(net, num_outputs=1,
                                          kernel_size=[3, 3],
                                          activation_fn=tf.nn.sigmoid,
                                          scope='RefinedAlphaMatte')
        prediction_dict = {'alpha_matte': alpha_matte,
                           'refined_alpha_matte': refined_alpha_matte,
                           'trimaps': preprocessed_trimaps}
        return prediction_dict
    def postprocess(self, prediction_dict, use_trimap=True):
        """Convert predicted output tensors to final forms.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
            use_trimap: Whether to copy the known foreground/background
                regions of the trimap into the predicted alpha mattes.

        Returns:
            A dictionary containing the postprocessed results.
        """
        alpha_matte = prediction_dict.get('alpha_matte')
        refined_alpha_matte = prediction_dict.get('refined_alpha_matte')
        if use_trimap:
            # Only the unknown (gray, value 128) region needs prediction; the
            # definite regions come straight from the trimap.
            trimaps = prediction_dict.get('trimaps')
            alpha_matte = tf.where(tf.equal(trimaps, 128), alpha_matte,
                                   trimaps / 255.)
            refined_alpha_matte = tf.where(tf.equal(trimaps, 128),
                                           refined_alpha_matte,
                                           trimaps / 255.)
        postprocessed_dict = {'alpha_matte': alpha_matte,
                              'refined_alpha_matte': refined_alpha_matte}
        return postprocessed_dict
    def loss(self, prediction_dict, preprocessed_dict, epsilon=1e-12):
        """Compute scalar loss tensors with respect to provided groundtruth.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
            preprocessed_dict: A dictionary of tensors holding groundtruth
                information, see the preprocess function. The pixel values of
                the trimaps must be in {0, 128, 255}.
            epsilon: A small constant for numerical stability.

        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        gt_images = preprocessed_dict.get('images')
        gt_fg = preprocessed_dict.get('images_fg')
        gt_bg = preprocessed_dict.get('images_bg')
        gt_alpha_matte = preprocessed_dict.get('alpha_mattes')
        alpha_matte = prediction_dict.get('alpha_matte')
        refined_alpha_matte = prediction_dict.get('refined_alpha_matte')
        pred_images = tf.multiply(alpha_matte, gt_fg) + tf.multiply(
            1 - alpha_matte, gt_bg)
        # Losses are computed only over the unknown (gray) trimap region.
        trimaps = prediction_dict.get('trimaps')
        weights = tf.where(tf.equal(trimaps, 128),
                           tf.ones_like(trimaps),
                           tf.zeros_like(trimaps))
        total_weights = tf.reduce_sum(weights) + epsilon
        # First-stage alpha prediction loss.
        first_stage_alpha_losses = tf.sqrt(
            tf.square(alpha_matte - gt_alpha_matte) + epsilon)
        first_stage_alpha_loss = tf.reduce_sum(
            first_stage_alpha_losses * weights) / total_weights
        # First-stage compositional loss.
        first_stage_image_losses = tf.sqrt(
            tf.square(pred_images - gt_images) + epsilon) / 255.
        first_stage_image_loss = tf.reduce_sum(
            first_stage_image_losses * weights) / total_weights
        # Second-stage (refinement) alpha prediction loss.
        second_stage_alpha_losses = tf.sqrt(
            tf.square(refined_alpha_matte - gt_alpha_matte) + epsilon)
        second_stage_alpha_loss = tf.reduce_sum(
            second_stage_alpha_losses * weights) / total_weights
        # Total loss is the weighted sum of the three losses.
        loss = (self._first_stage_alpha_loss_weight * first_stage_alpha_loss +
                self._first_stage_image_loss_weight * first_stage_image_loss +
                self._second_stage_alpha_loss_weight * second_stage_alpha_loss)
        loss_dict = {'loss': loss}
        return loss_dict
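Before moving on to the notes, here is a minimal inference sketch showing how preprocess, predict and postprocess fit together (a hedged smoke test, not the repository's actual evaluation code; the shapes are illustrative, an all-128 trimap simply marks everything as unknown, and in practice you would restore a trained checkpoint instead of using random weights):

import numpy as np
import tensorflow as tf

from model import Model

# Build the inference graph (shapes here are only illustrative).
images = tf.placeholder(tf.float32, [1, 320, 320, 3])
trimaps = tf.placeholder(tf.float32, [1, 320, 320, 1])
model = Model(is_training=False)
preprocessed_dict = model.preprocess(trimaps, images=images)
prediction_dict = model.predict(preprocessed_dict)
postprocessed_dict = model.postprocess(prediction_dict)

with tf.Session() as sess:
    # In practice, restore trained weights here instead.
    sess.run(tf.global_variables_initializer())
    alpha = sess.run(postprocessed_dict['refined_alpha_matte'],
                     feed_dict={images: np.zeros([1, 320, 320, 3]),
                                trimaps: np.full([1, 320, 320, 1], 128.)})
    print(alpha.shape)  # (1, 320, 320, 1)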
Notes:
1. In the definition of VGG-16 in tf.contrib.slim, although fc6 is already a convolution replacing the fully connected layer, its padding mode is VALID, so after fc6 the resolution would become 4 x 4 (10 - 7 + 1 = 4, assuming a 320 x 320 input image), which makes it troublesome to expand the feature-map resolution later. The padding of this layer therefore has to be changed to SAME, keeping the resolution at 10 x 10, so that 5 stride-2 transposed convolutions can expand it back to 320 x 320.
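A quick sanity check of this arithmetic (a standalone sketch; the placeholder shape is illustrative and not part of model.py):

import tensorflow as tf

slim = tf.contrib.slim

# pool5 output for a 320 x 320 input: 320 / 2^5 = 10.
pool5 = tf.placeholder(tf.float32, [None, 10, 10, 512])
# VALID padding with a 7 x 7 kernel shrinks 10 x 10 to 4 x 4.
fc6_valid = slim.conv2d(pool5, 4096, [7, 7], padding='VALID')
# SAME padding keeps the 10 x 10 resolution.
fc6_same = slim.conv2d(pool5, 4096, [7, 7], padding='SAME')
print(fc6_valid.shape)  # (?, 4, 4, 4096)
print(fc6_same.shape)   # (?, 10, 10, 4096)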
2. The pretrained VGG-16 parameters are for 3-channel images, so even though both the image to be matted and the trimap pass through a VGG-16 network, they still have to be split into two independent VGG-16 inputs in order to load the pretrained model. (The VGG-16 branch for the alpha channel in model.py above is written more verbosely than necessary; for a simplified version, see the AlphaResNet part in Note 3 below.)
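For example, the image branch can be initialized from the public VGG-16 checkpoint while the randomly initialized trimap branch (conv*_alpha, fc6_alpha) and the new layers keep their random initialization. This is a minimal sketch, not the repository's actual training code; the checkpoint path is hypothetical and the exact exclude list depends on your setup:

import tensorflow as tf

slim = tf.contrib.slim

# After the graph has been built (e.g. via Model.predict), collect the
# image-branch variables. 'vgg_16/fc8' is excluded because its shape
# (num_classes=1) differs from the 1000-class checkpoint.
variables_to_restore = slim.get_variables_to_restore(
    include=['vgg_16'], exclude=['vgg_16/fc8'])
init_fn = slim.assign_from_checkpoint_fn(
    'vgg_16.ckpt',  # Hypothetical path to the pretrained checkpoint.
    variables_to_restore, ignore_missing_vars=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    init_fn(sess)  # Overwrite the image branch with pretrained weights.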
3. ResNet-50 achieves better classification accuracy on ImageNet than VGG-16 and has fewer parameters in total, so it can be substituted for VGG-16; the input size can then be enlarged to 640 x 640 (though on a 1080Ti the batch size has to be reduced from 4 to 2). The replacement code is as follows (only the predict function needs to be replaced):
    def predict(self, preprocessed_dict):
        """Predict prediction tensors from input tensors.

        Outputs of this function can be passed to loss or postprocess
        functions.

        Args:
            preprocessed_dict: See the preprocess function.

        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the loss or postprocess functions.
        """
        # The inputs for the first stage.
        preprocessed_images = preprocessed_dict.get('images')
        preprocessed_trimaps = preprocessed_dict.get('trimaps')
        # ResNet-50 encoder for the image branch.
        net_image, _ = nets.resnet_v1.resnet_v1_50(
            preprocessed_images, num_classes=None, global_pool=False,
            is_training=self._is_training)
        # ResNet-50 encoder for the trimap (alpha) branch.
        with tf.variable_scope('AlphaResNet'):
            net_alpha, _ = nets.resnet_v1.resnet_v1_50(
                preprocessed_trimaps, num_classes=None, global_pool=False,
                is_training=self._is_training)
        # Concatenate the two branches (2048 + 2048 = 4096 channels).
        net = tf.concat(values=[net_image, net_alpha], axis=3)
        net.set_shape([None, self._default_image_size // 32,
                       self._default_image_size // 32, 4096])
        # Deconvolution
        with slim.arg_scope([slim.conv2d_transpose], stride=2, kernel_size=5):
            # Deconv6
            ...  # (same as above)
4. Since the white region of a trimap is definitely foreground and the black region definitely background, postprocessing (see the postprocess function) simply overwrites the corresponding foreground and background regions of the prediction with the trimap's foreground/background values to produce the model's final output.
Clearly, the structure of the whole model is straightforward; the next step is to define the loss function. The loss consists of three parts: the first stage contributes two losses and the second stage one, and the weighted sum of the three is the total loss of the model. Because the white and black regions of the trimap are definitely foreground and background, no loss is incurred there; the loss only needs to be computed over the gray (unknown) region. The first-stage losses are the alpha prediction loss, i.e. the loss between the predicted alpha channel and the groundtruth alpha channel, and the compositional loss, i.e. the loss between the image composited from the foreground and background using the predicted alpha channel and the image composited using the groundtruth alpha channel. The second stage has only an alpha prediction loss, between the refined alpha prediction and the groundtruth alpha channel. All three losses in the paper are per-pixel absolute differences. See the loss function above for the implementation.
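Concretely, the paper approximates the per-pixel absolute difference with a smooth (differentiable) variant; for the alpha prediction loss at pixel $i$:

$$\mathcal{L}_\alpha^i = \sqrt{\left(\alpha_p^i - \alpha_g^i\right)^2 + \epsilon^2}, \qquad \epsilon = 10^{-6},$$

where $\alpha_p^i$ is the predicted and $\alpha_g^i$ the groundtruth alpha value. This matches the `epsilon=1e-12` (i.e. $\epsilon^2$) added under the square root in the loss function above; the compositional loss takes the same form with the composited RGB values in place of the alpha values.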
2. Code Explanation

3. Training Example

(To be continued)