Image Inpainting for Irregular Holes Using Partial Convolutions

引言

本文是由NVIDIA提出的一種基于局部卷積的圖像修復(fù)算法。
圖像修復(fù),即修復(fù)圖像中缺失的塊,可用于圖像編輯硼讽,替換掉圖像中不想要的內(nèi)容。本文使用自動(dòng)mask更新的局部卷積網(wǎng)絡(luò)進(jìn)行圖像修復(fù)牲阁。

方法

partial convolutional layer

假設(shè)W是卷積核的權(quán)重固阁,b是相應(yīng)的偏差。X是當(dāng)前卷積窗口的特征值城菊,M是相應(yīng)的二進(jìn)制mask备燃。則卷積計(jì)算為
x'=\begin{cases} W^T (X*M) \frac{sum(1)}{sum(M)} + b,\quad sum(M)\leq 0 \\\\ 0,\quad otherwise \end{cases}
由上式可知,輸出是由沒(méi)有mask的輸入決定凌唬。
在局部卷積操作之后并齐,需要更新mask:
x'=\begin{cases} 1,\quad sum(M)\leq 0 \\\\ 0,\quad otherwise \end{cases}

網(wǎng)絡(luò)結(jié)構(gòu)

整體網(wǎng)絡(luò)使用UNet結(jié)構(gòu),將所有的卷積換成局部卷積層客税,在decoder階段使用最近鄰上采樣况褪。skip連接連接兩個(gè)特征圖和兩個(gè)mask作為下一層的輸入,最后一個(gè)卷積層的輸入為有洞的原始輸入和原始mask的組成更耻。

損失函數(shù)

給定有洞的輸入I_{in}窝剖,初始化二進(jìn)制maskM(有洞的地方為0),網(wǎng)絡(luò)的輸出為I_{out}酥夭,原始的圖像為I_{gt}
1、像素?fù)p失:
L_{hole} = ||(1-M)*(I_{out}-I_{gt})||_1
L_{valid} = ||M*(I_{out}-I_{gt})||_1
2脊奋、感知損失
L_{perceptual} = \sum_{n=0}^{N-1}||\psi(I_{out}) - \psi(I_{gt}) ||_1 + \sum_{n=0}^{N-1}||\psi(I_{comp}) - \psi(I_{gt}) ||_1
其中熬北,I_{comp}是未加工的輸出圖像I_{output}\psi使用vgg16的pool1诚隙,pool2,pool3
3讶隐、風(fēng)格損失
L_{style_{out}} = \sum_{n=0}^{N-1}||K_n ((\psi(I_{out}))^T(\psi(I_{out}) - (\psi(I_{gt}))^T(\psi(I_{gt}) ||_1 +
L_{style_{comp}} = sum_{n=0}^{N-1}||K_n ((\psi(I_{comp}))^T(\psi(I_{comp}) - (\psi(I_{gt}))^T(\psi(I_{gt}) ||_1

4、全變差損失
L_{tv} = \sum ||I_{comp}^{i,j+1} - I_{comp}^{i,j}||_1 + \sum ||I_{comp}^{i+1,j} - I_{comp}^{i,j}||_1
總的損失函數(shù)為:
L_{total} = L_{valid} + 6L_{hole} + 0.05 L_{perceptual} + 120 (L_{style_{out}} + L_{style_{comp}}) + 0.1L_{tv}

代碼分析

1久又、局部卷積層

from keras.utils import conv_utils
from keras import backend as K
from keras.engine import InputSpec
from keras.layers import Conv2D

class PConv2D(Conv2D):
    def __init__(self, *args, n_channels=3, mono=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]
    def build(self, input_shape):        
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
            
        if input_shape[0][channel_axis] is None:
            raise ValueError('The channel dimension of the inputs should be defined. Found `None`.')
            
        self.input_dim = input_shape[0][channel_axis]
  
        kernel_shape = self.kernel_size + (self.input_dim, self.filters)
        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='img_kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        self.kernel_mask = K.ones(shape=self.kernel_size + (self.input_dim, self.filters))

        # Calculate padding size to achieve zero-padding
        self.pconv_padding = (
            (int((self.kernel_size[0]-1)/2), int((self.kernel_size[0]-1)/2)), 
            (int((self.kernel_size[0]-1)/2), int((self.kernel_size[0]-1)/2)), 
        )

        # Window size - used for normalization
        self.window_size = self.kernel_size[0] * self.kernel_size[1]
        
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, mask=None):

        if type(inputs) is not list or len(inputs) != 2:
            raise Exception('PartialConvolution2D must be called on a list of two tensors [img, mask]. Instead got: ' + str(inputs))

        # Padding done explicitly so that padding becomes part of the masked partial convolution
        images = K.spatial_2d_padding(inputs[0], self.pconv_padding, self.data_format)
        masks = K.spatial_2d_padding(inputs[1], self.pconv_padding, self.data_format)
        # Apply convolutions to mask
        mask_output = K.conv2d(
            masks, self.kernel_mask, 
            strides=self.strides,
            padding='valid',
            data_format=self.data_format,
            dilation_rate=self.dilation_rate
        )
        # Apply convolutions to image
        img_output = K.conv2d(
            (images*masks), self.kernel, 
            strides=self.strides,
            padding='valid',
            data_format=self.data_format,
            dilation_rate=self.dilation_rate
        )        
        # Calculate the mask ratio on each pixel in the output mask
        mask_ratio = self.window_size / (mask_output + 1e-8)
        # Clip output to be between 0 and 1
        mask_output = K.clip(mask_output, 0, 1)
        # Remove ratio values where there are holes
        mask_ratio = mask_ratio * mask_output
        # Normalize iamge output
        img_output = img_output * mask_ratio
        # Apply bias only to the image (if chosen to do so)
        if self.use_bias:
            img_output = K.bias_add(
                img_output,
                self.bias,
                data_format=self.data_format)
        
        # Apply activations on the image
        if self.activation is not None:
            img_output = self.activation(img_output)
         
        return [img_output, mask_output]
    
    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_last':
            space = input_shape[0][1:-1]
            new_space = []
            for i in range(len(space)):
                new_dim = conv_utils.conv_output_length(
                    space[i],
                    self.kernel_size[i],
                    padding='same',
                    stride=self.strides[i],
                    dilation=self.dilation_rate[i])
                new_space.append(new_dim)
            new_shape = (input_shape[0][0],) + tuple(new_space) + (self.filters,)
            return [new_shape, new_shape]
        if self.data_format == 'channels_first':
            space = input_shape[2:]
            new_space = []
            for i in range(len(space)):
                new_dim = conv_utils.conv_output_length(
                    space[i],
                    self.kernel_size[i],
                    padding='same',
                    stride=self.strides[i],
                    dilation=self.dilation_rate[i])
                new_space.append(new_dim)
            new_shape = (input_shape[0], self.filters) + tuple(new_space)
            return [new_shape, new_shape]

2巫延、損失函數(shù)

    def loss_hole(self, mask, y_true, y_pred):
        """Pixel L1 loss within the hole / mask"""
        return self.l1((1-mask) * y_true, (1-mask) * y_pred)
    
    def loss_valid(self, mask, y_true, y_pred):
        """Pixel L1 loss outside the hole / mask"""
        return self.l1(mask * y_true, mask * y_pred)
    
    def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp): 
        """Perceptual loss based on VGG16, see. eq. 3 in paper"""       
        loss = 0
        for o, c, g in zip(vgg_out, vgg_comp, vgg_gt):
            loss += self.l1(o, g) + self.l1(c, g)
        return loss
        
    def loss_style(self, output, vgg_gt):
        """Style loss based on output/computation, used for both eq. 4 & 5 in paper"""
        loss = 0
        for o, g in zip(output, vgg_gt):
            loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
        return loss
    
    def loss_tv(self, mask, y_comp):
        """Total variation loss, used for smoothing the hole region, see. eq. 6"""

        # Create dilated hole region using a 3x3 kernel of all 1s.
        kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3]))
        dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same')

        # Cast values to be [0., 1.], and compute dilated hole region of y_comp
        dilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32')
        P = dilated_mask * y_comp

        # Calculate total variation loss
        a = self.l1(P[:,1:,:,:], P[:,:-1,:,:])
        b = self.l1(P[:,:,1:,:], P[:,:,:-1,:])        
        return a+b

參考文獻(xiàn)

[1]Image Inpainting for Irregular Holes Using
Partial Convolutions

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市地消,隨后出現(xiàn)的幾起案子炉峰,更是在濱河造成了極大的恐慌,老刑警劉巖脉执,帶你破解...
    沈念sama閱讀 218,755評(píng)論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件疼阔,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)婆廊,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,305評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門迅细,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人淘邻,你說(shuō)我怎么就攤上這事茵典。” “怎么了宾舅?”我有些...
    開封第一講書人閱讀 165,138評(píng)論 0 355
  • 文/不壞的土叔 我叫張陵统阿,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我贴浙,道長(zhǎng)砂吞,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,791評(píng)論 1 295
  • 正文 為了忘掉前任崎溃,我火速辦了婚禮蜻直,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘袁串。我一直安慰自己概而,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,794評(píng)論 6 392
  • 文/花漫 我一把揭開白布囱修。 她就那樣靜靜地躺著赎瑰,像睡著了一般。 火紅的嫁衣襯著肌膚如雪破镰。 梳的紋絲不亂的頭發(fā)上餐曼,一...
    開封第一講書人閱讀 51,631評(píng)論 1 305
  • 那天,我揣著相機(jī)與錄音鲜漩,去河邊找鬼源譬。 笑死,一個(gè)胖子當(dāng)著我的面吹牛孕似,可吹牛的內(nèi)容都是我干的踩娘。 我是一名探鬼主播,決...
    沈念sama閱讀 40,362評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼喉祭,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼养渴!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起泛烙,我...
    開封第一講書人閱讀 39,264評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤理卑,失蹤者是張志新(化名)和其女友劉穎,沒(méi)想到半個(gè)月后蔽氨,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體傻工,經(jīng)...
    沈念sama閱讀 45,724評(píng)論 1 315
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡霞溪,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,900評(píng)論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了中捆。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片鸯匹。...
    茶點(diǎn)故事閱讀 40,040評(píng)論 1 350
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖泄伪,靈堂內(nèi)的尸體忽然破棺而出殴蓬,到底是詐尸還是另有隱情,我是刑警寧澤蟋滴,帶...
    沈念sama閱讀 35,742評(píng)論 5 346
  • 正文 年R本政府宣布染厅,位于F島的核電站,受9級(jí)特大地震影響津函,放射性物質(zhì)發(fā)生泄漏肖粮。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,364評(píng)論 3 330
  • 文/蒙蒙 一尔苦、第九天 我趴在偏房一處隱蔽的房頂上張望涩馆。 院中可真熱鬧,春花似錦允坚、人聲如沸魂那。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,944評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)涯雅。三九已至,卻和暖如春展运,著一層夾襖步出監(jiān)牢的瞬間活逆,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,060評(píng)論 1 270
  • 我被黑心中介騙來(lái)泰國(guó)打工拗胜, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留蔗候,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,247評(píng)論 3 371
  • 正文 我出身青樓挤土,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親误算。 傳聞我的和親對(duì)象是個(gè)殘疾皇子仰美,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,979評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容