本文目錄:
- Introduction
- Related work
- Methods
- Gram 矩陣
- Batch Normalization
Introduction
不久前麻削,一個(gè)名叫Prisma的APP在微博和朋友圈火了起來。Prisma是個(gè)能夠?qū)D像風(fēng)格轉(zhuǎn)換為藝術(shù)風(fēng)格的APP嗽仪,它能夠?qū)崿F(xiàn)如下轉(zhuǎn)換:
除了引起大眾的好奇心外,業(yè)內(nèi)人士也紛紛猜測Prisma是如何做到實(shí)現(xiàn)快速的圖像風(fēng)格轉(zhuǎn)換肿嘲。此前晋南,在Gatys的論文<Image Style Transfer Using Convolutional Neural Networks>中嵌削,實(shí)現(xiàn)一張圖片的圖像風(fēng)格轉(zhuǎn)換需要較長時(shí)間磷雇。
在文中我將講解Prisma是如何實(shí)現(xiàn)實(shí)時(shí)風(fēng)格轉(zhuǎn)換的偿警。本文內(nèi)容基于Fei Fei Li團(tuán)隊(duì)的<Perceptual Losses for Real-Time Style Transfer
and Super-Resolution>一文。
系列文章目錄如下:
- 梵高眼中的世界(一)實(shí)時(shí)圖像風(fēng)格轉(zhuǎn)換簡介
- 梵高眼中的世界(二)基于perceptual損失的網(wǎng)絡(luò)
- 梵高眼中的世界(三)實(shí)現(xiàn)與改進(jìn)
Related work
在進(jìn)行圖像風(fēng)格轉(zhuǎn)換時(shí)唯笙,我們需要一張風(fēng)格圖像style image和一張內(nèi)容圖像content image螟蒸。我們構(gòu)造一個(gè)網(wǎng)絡(luò)衡量生成圖像與style image以及content image的loss盒使,再通過訓(xùn)練減小loss得到最終圖像。
在Gatys的方法中七嫌,他使用了如下圖所示的方法:
上圖最左邊是風(fēng)格圖像少办,梵高的《星夜》;最右邊是內(nèi)容圖像诵原。
算法步驟如下:
生成了一張白噪聲圖像作為初始圖像英妓。
將風(fēng)格圖像,內(nèi)容圖像绍赛,初始圖像分別通過一個(gè)預(yù)訓(xùn)練的VGG-19網(wǎng)絡(luò)鞋拟,得到某些層的輸出。這里的“某些層”是經(jīng)過實(shí)驗(yàn)得出的惹资,是使得輸出圖像最佳的層數(shù)。
-
計(jì)算內(nèi)容損失函數(shù):
內(nèi)容損失函數(shù)
其中Pl_ij是原始圖像在第l層位置j與第i個(gè)filter卷積后的輸出航闺,F(xiàn)l_ij是相應(yīng)的生成圖像的輸出褪测。
計(jì)算風(fēng)格損失函數(shù):
風(fēng)格損失函數(shù)與圖像有些不同,在這里我們不直接使用某些層卷積后的輸出潦刃,而是計(jì)算輸出的Gram矩陣侮措,再用于上式風(fēng)格損失的計(jì)算:
5.計(jì)算總損失
此時(shí)我們可以通過梯度下降算法對初始化的白噪聲圖像進(jìn)行訓(xùn)練,得到最終的風(fēng)格轉(zhuǎn)換圖像乖杠。
Gatys的算法缺點(diǎn)是一次只能訓(xùn)練出一張圖分扎。我們希望得到一個(gè)前饋的神經(jīng)網(wǎng)絡(luò),對于每一張內(nèi)容圖像胧洒,只需要通過這個(gè)前饋神經(jīng)網(wǎng)絡(luò)畏吓,就能快速得到風(fēng)格轉(zhuǎn)換圖像。
Methods
在這里只對Gram matrix以及Batch Normalization進(jìn)行講解卫漫,具體實(shí)現(xiàn)細(xì)節(jié)請閱讀原文菲饼。
Gram matrix
Gram matrix 計(jì)算如下:
上式的意思為,G^l_i,j意味著第l層特征圖i和j的內(nèi)積列赎。同理可表示為:
在論文中宏悦,作者用高維的特征圖相關(guān)性來表示圖像風(fēng)格。上式矩陣的對角線表示每一個(gè)特征圖自身的信息包吝,其余元素表示了不同特征圖之間的信息饼煞。
Gram matrix的tensorflow實(shí)現(xiàn)如下:
def gram_matrix(x):
'''
Args:
x: Tensor with shape [batch size, length, width, channels]
Return:
Tensor with shape [channels, channels]
'''
bs, l, w, c = x.get_shape()
size = l*w*c
x = tf.reshape(x, (bs, l*w, c))
x_t = tf.transpose(x, perm=[0,2,1])
return tf.matmul(x_t, x)/size
Batch Normalization
Batch Normalization 最早由Google在ICML2015的論文<Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift>提出。
其算法如下:
這個(gè)算法看上去有點(diǎn)復(fù)雜诗越,但直觀上很好理解:
對于一個(gè)mini-batch里面的值x_i砖瞧,我們計(jì)算平均值 μ和方差σ。對于每一個(gè)x_i掺喻,我們對其進(jìn)行z-score歸一化芭届,得到平均值為0储矩,標(biāo)準(zhǔn)差為1的數(shù)據(jù)。式子中的ε是一個(gè)很小的偏差值褂乍,防止出現(xiàn)除以0的情況持隧。實(shí)現(xiàn)中可以取ε=1e-3。在對數(shù)據(jù)進(jìn)行歸一化后逃片,BN算法再進(jìn)行“scale and shift”屡拨,將數(shù)據(jù)還原成原來的輸入。
Batch Normalization是為了解決Internal Covariate Shift問題而提出褥实。
Batch Normalization在Tensorflow下的實(shí)現(xiàn):
from tensorflow.contrib.layers import batch_norm
def batch_norm_layer(x, is_training, scope):
bn_train = batch_norm(x, decay=0.999, center=True, scale=True,
updates_collections=None,
is_training=True,
reuse=None,
trainable=True,
scope=scope)
bn_test = batch_norm(x, decay=0.999, center=True, scale=True,
updates_collections=None,
is_training=False,
reuse=True,
trainable=True,
scope=scope)
bn = tf.cond(is_training, lambda: bn_train, lambda: bn_test)
return bn
注意其中is_training是一個(gè)placeholder呀狼。