Introduction
This post covers an image inpainting algorithm based on partial convolutions, proposed by NVIDIA.
Image inpainting, i.e., filling in missing regions of an image, can be used in image editing to replace unwanted content. The paper performs inpainting with a partial convolutional network whose mask is updated automatically.
Method
Partial convolutional layer
Let $W$ be the weights of the convolution kernel and $b$ the corresponding bias. Let $X$ be the feature values in the current convolution window and $M$ the corresponding binary mask. The partial convolution is then computed as

$$x' = \begin{cases} W^T (X \odot M)\,\dfrac{\operatorname{sum}(\mathbf{1})}{\operatorname{sum}(M)} + b, & \operatorname{sum}(M) > 0 \\ 0, & \text{otherwise} \end{cases}$$

where $\odot$ denotes element-wise multiplication and $\mathbf{1}$ is an all-ones tensor of the same shape as $M$. As the formula shows, the output is determined only by the unmasked inputs, and the factor $\operatorname{sum}(\mathbf{1})/\operatorname{sum}(M)$ rescales the result to compensate for the varying number of valid inputs per window.

After each partial convolution, the mask is updated:

$$m' = \begin{cases} 1, & \operatorname{sum}(M) > 0 \\ 0, & \text{otherwise} \end{cases}$$

That is, an output location becomes valid as soon as at least one valid pixel falls inside its window, so the hole shrinks layer by layer.
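To make the normalization concrete, here is a minimal NumPy sketch (my own illustration, not code from the paper) applying the two equations above to a single 3×3 window:

import numpy as np

# One 3x3 convolution window, single channel.
X = np.array([[1., 2., 3.],
              [4., 5., 6.],
              [7., 8., 9.]])
# Binary mask: 1 marks valid pixels, 0 marks hole pixels.
M = np.array([[1., 1., 0.],
              [1., 0., 0.],
              [1., 1., 1.]])
W = np.full((3, 3), 0.1)   # toy kernel weights
b = 0.5                    # toy bias

if M.sum() > 0:
    # W^T (X . M) scaled by sum(1)/sum(M), plus bias
    x_out = (W * X * M).sum() * (M.size / M.sum()) + b
    m_out = 1.0              # mask update: window saw valid pixels
else:
    x_out, m_out = 0.0, 0.0  # no valid pixels: output 0, mask stays 0

print(x_out, m_out)  # 5.15 1.0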
Network architecture
The overall network is a UNet in which every convolution is replaced by a partial convolutional layer, with nearest-neighbor upsampling in the decoder stage. Each skip connection concatenates two feature maps and two masks, which together form the input of the next layer; in particular, the input of the last convolutional layer includes the original image with holes and the original mask, which lets the network copy valid pixels directly.
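A minimal Keras sketch of one encoder step and one decoder step under this scheme, using the PConv2D layer analyzed below (sizes and filter counts are illustrative, not the paper's exact configuration):

from keras.layers import Input, UpSampling2D, Concatenate

# Inputs: the image with holes and its binary mask (channels last).
img_in  = Input((256, 256, 3))
mask_in = Input((256, 256, 3))

# Encoder step: a strided partial convolution downsamples the
# features and the mask together.
e_img, e_mask = PConv2D(64, 7, strides=2, padding='same',
                        activation='relu')([img_in, mask_in])

# Decoder step: nearest-neighbor upsampling, then the skip
# connection concatenates two feature maps and two masks.
d_img  = UpSampling2D(size=2)(e_img)
d_mask = UpSampling2D(size=2)(e_mask)
d_img  = Concatenate(axis=3)([d_img, img_in])
d_mask = Concatenate(axis=3)([d_mask, mask_in])

# The last partial convolution sees the original holed input and
# the original mask through this concatenation.
out_img, out_mask = PConv2D(3, 3, strides=1, padding='same')([d_img, d_mask])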
Loss functions
Given the input with holes $I_{in}$ and the initial binary mask $M$ (0 at hole pixels), let the network output be $I_{out}$ and the ground-truth image be $I_{gt}$.
1. Pixel losses, computed separately on the hole region and the valid region:

$$L_{hole} = \frac{1}{N_{I_{gt}}} \left\| (1-M) \odot (I_{out} - I_{gt}) \right\|_1, \qquad L_{valid} = \frac{1}{N_{I_{gt}}} \left\| M \odot (I_{out} - I_{gt}) \right\|_1$$

where $N_{I_{gt}}$ is the number of elements in $I_{gt}$.
2脊奋、感知損失
其中熬北,是未加工的輸出圖像
,
使用vgg16的pool1诚隙,pool2,pool3
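A minimal sketch of the VGG16 feature extractor these losses assume, built with keras.applications (the block*_pool names are Keras's standard layer names for pool1-3):

from keras.applications import VGG16
from keras.models import Model

base = VGG16(weights='imagenet', include_top=False)
# Expose the pool1, pool2 and pool3 outputs as perceptual features
layer_names = ['block1_pool', 'block2_pool', 'block3_pool']
vgg16_extractor = Model(base.input,
                        [base.get_layer(n).output for n in layer_names])
vgg16_extractor.trainable = False  # fixed feature extractor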
3讶隐、風(fēng)格損失
4. Total variation loss, a smoothness penalty on $P$, the 1-pixel dilation of the hole region of $I_{comp}$:

$$L_{tv} = \sum_{(i,j) \in P,\, (i,j+1) \in P} \frac{\left\| I_{comp}^{i,j+1} - I_{comp}^{i,j} \right\|_1}{N_{I_{comp}}} + \sum_{(i,j) \in P,\, (i+1,j) \in P} \frac{\left\| I_{comp}^{i+1,j} - I_{comp}^{i,j} \right\|_1}{N_{I_{comp}}}$$
The total loss is the weighted combination

$$L_{total} = L_{valid} + 6 L_{hole} + 0.05 L_{perceptual} + 120 (L_{style_{out}} + L_{style_{comp}}) + 0.1 L_{tv}$$

with the weights taken from the paper.
Code analysis
1. Partial convolutional layer
from keras.utils import conv_utils
from keras import backend as K
from keras.engine import InputSpec
from keras.layers import Conv2D

class PConv2D(Conv2D):
    def __init__(self, *args, n_channels=3, mono=False, **kwargs):
        super().__init__(*args, **kwargs)
        # The layer is called on two 4D tensors: [img, mask]
        self.input_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]

    def build(self, input_shape):
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[0][channel_axis] is None:
            raise ValueError('The channel dimension of the inputs should be defined. Found `None`.')
        self.input_dim = input_shape[0][channel_axis]

        # Image kernel (trainable)
        kernel_shape = self.kernel_size + (self.input_dim, self.filters)
        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='img_kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        # Mask kernel (fixed, all ones): counts valid pixels per window
        self.kernel_mask = K.ones(shape=self.kernel_size + (self.input_dim, self.filters))

        # Calculate padding size to achieve zero-padding
        self.pconv_padding = (
            (int((self.kernel_size[0] - 1) / 2), int((self.kernel_size[0] - 1) / 2)),
            (int((self.kernel_size[1] - 1) / 2), int((self.kernel_size[1] - 1) / 2)),
        )

        # Window size - used for normalization
        self.window_size = self.kernel_size[0] * self.kernel_size[1]

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, mask=None):
        if type(inputs) is not list or len(inputs) != 2:
            raise Exception('PartialConvolution2D must be called on a list of two tensors [img, mask]. Instead got: ' + str(inputs))

        # Padding done explicitly so that padding becomes part of the masked partial convolution
        images = K.spatial_2d_padding(inputs[0], self.pconv_padding, self.data_format)
        masks = K.spatial_2d_padding(inputs[1], self.pconv_padding, self.data_format)

        # Apply convolutions to mask
        mask_output = K.conv2d(
            masks, self.kernel_mask,
            strides=self.strides,
            padding='valid',
            data_format=self.data_format,
            dilation_rate=self.dilation_rate
        )

        # Apply convolutions to image
        img_output = K.conv2d(
            (images * masks), self.kernel,
            strides=self.strides,
            padding='valid',
            data_format=self.data_format,
            dilation_rate=self.dilation_rate
        )

        # Calculate the mask ratio on each pixel in the output mask
        mask_ratio = self.window_size / (mask_output + 1e-8)

        # Clip output to be between 0 and 1
        mask_output = K.clip(mask_output, 0, 1)

        # Remove ratio values where there are holes
        mask_ratio = mask_ratio * mask_output

        # Normalize image output
        img_output = img_output * mask_ratio

        # Apply bias only to the image (if chosen to do so)
        if self.use_bias:
            img_output = K.bias_add(
                img_output,
                self.bias,
                data_format=self.data_format)

        # Apply activations on the image
        if self.activation is not None:
            img_output = self.activation(img_output)

        return [img_output, mask_output]

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_last':
            space = input_shape[0][1:-1]
            new_space = []
            for i in range(len(space)):
                new_dim = conv_utils.conv_output_length(
                    space[i],
                    self.kernel_size[i],
                    padding='same',
                    stride=self.strides[i],
                    dilation=self.dilation_rate[i])
                new_space.append(new_dim)
            new_shape = (input_shape[0][0],) + tuple(new_space) + (self.filters,)
            return [new_shape, new_shape]
        if self.data_format == 'channels_first':
            space = input_shape[0][2:]
            new_space = []
            for i in range(len(space)):
                new_dim = conv_utils.conv_output_length(
                    space[i],
                    self.kernel_size[i],
                    padding='same',
                    stride=self.strides[i],
                    dilation=self.dilation_rate[i])
                new_space.append(new_dim)
            new_shape = (input_shape[0][0], self.filters) + tuple(new_space)
            return [new_shape, new_shape]
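A quick sanity check of the layer (a hypothetical usage snippet, not part of the original code): after a single 3×3 partial convolution, the hole in the output mask should shrink by one pixel on each side.

import numpy as np
from keras.layers import Input
from keras.models import Model

img_in  = Input((64, 64, 3))
mask_in = Input((64, 64, 3))
img_out, mask_out = PConv2D(8, 3, strides=1, padding='same')([img_in, mask_in])
model = Model([img_in, mask_in], [img_out, mask_out])

img  = np.random.rand(1, 64, 64, 3).astype('float32')
mask = np.ones((1, 64, 64, 3), dtype='float32')
mask[:, 16:32, 16:32, :] = 0           # punch a 16x16 square hole
out_img, out_mask = model.predict([img, mask])
print(out_mask.sum())                  # hole is now 14x14 per channel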
2巫延、損失函數(shù)
def loss_hole(self, mask, y_true, y_pred):
    """Pixel L1 loss within the hole / mask"""
    return self.l1((1-mask) * y_true, (1-mask) * y_pred)

def loss_valid(self, mask, y_true, y_pred):
    """Pixel L1 loss outside the hole / mask"""
    return self.l1(mask * y_true, mask * y_pred)

def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp):
    """Perceptual loss based on VGG16, see eq. 3 in the paper"""
    loss = 0
    for o, c, g in zip(vgg_out, vgg_comp, vgg_gt):
        loss += self.l1(o, g) + self.l1(c, g)
    return loss

def loss_style(self, output, vgg_gt):
    """Style loss based on output/composite, used for both eq. 4 & 5 in the paper"""
    loss = 0
    for o, g in zip(output, vgg_gt):
        loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
    return loss

def loss_tv(self, mask, y_comp):
    """Total variation loss, used for smoothing the hole region, see eq. 6"""
    # Create dilated hole region using a 3x3 kernel of all 1s.
    kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3]))
    dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same')

    # Cast values to be [0., 1.], and compute dilated hole region of y_comp
    dilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32')
    P = dilated_mask * y_comp

    # Calculate total variation loss
    a = self.l1(P[:, 1:, :, :], P[:, :-1, :, :])
    b = self.l1(P[:, :, 1:, :], P[:, :, :-1, :])
    return a + b
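The methods above call self.l1 and self.gram_matrix, which are not shown. Below is a minimal reconstruction consistent with how they are used (assuming channels-last tensors; the Gram matrix comparison happens on 3D tensors, hence the ndim branch):

def l1(self, y_true, y_pred):
    # Mean absolute error over all non-batch axes; feature maps
    # are 4D, Gram matrices are 3D
    if K.ndim(y_true) == 4:
        return K.mean(K.abs(y_pred - y_true), axis=[1, 2, 3])
    return K.mean(K.abs(y_pred - y_true), axis=[1, 2])

def gram_matrix(self, x):
    # Reshape features to (batch, H*W, C), then compute the channel
    # auto-correlation, normalized by C*H*W
    _, H, W, C = K.int_shape(x)
    features = K.reshape(x, (-1, H * W, C))
    gram = K.batch_dot(features, features, axes=[1, 1])
    return gram / (C * H * W)

The individual terms can then be combined with the paper's weights. The vgg argument below, a model returning the pool1-3 feature list, is a hypothetical parameter for illustration:

def loss_total(self, mask, y_true, y_pred, vgg):
    # Composite image: ground truth outside the holes, prediction inside
    y_comp = mask * y_true + (1 - mask) * y_pred
    vgg_out, vgg_gt, vgg_comp = vgg(y_pred), vgg(y_true), vgg(y_comp)
    return (self.loss_valid(mask, y_true, y_pred)
            + 6 * self.loss_hole(mask, y_true, y_pred)
            + 0.05 * self.loss_perceptual(vgg_out, vgg_gt, vgg_comp)
            + 120 * (self.loss_style(vgg_out, vgg_gt)
                     + self.loss_style(vgg_comp, vgg_gt))
            + 0.1 * self.loss_tv(mask, y_comp))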
References
[1] Guilin Liu, Fitsum A. Reda, Kevin J. Shih, Ting-Chun Wang, Andrew Tao, Bryan Catanzaro. Image Inpainting for Irregular Holes Using Partial Convolutions. ECCV 2018. arXiv:1804.07723.