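We define two distances, one for content and one for style. The content distance measures how different the content of two images is, while the style distance measures how different their style is. We then take a third image as input and modify it so that it minimizes both its content distance to the content image and its style distance to the style image.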
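The idea can be sketched as a small optimization loop in which the pixels of the third image are the only thing being optimized. This is a minimal sketch under stated assumptions. One randomly initialized convolution layer (`feature_net`, hypothetical) stands in for the pretrained CNN that real implementations use. The images are random or synthetic tensors rather than photos. `style_weight` is an arbitrary illustrative value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for a pretrained CNN: one fixed, randomly initialized conv layer.
feature_net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
for p in feature_net.parameters():
    p.requires_grad_(False)

def gram(feat):
    # (1, C, H, W) -> (C, C) channel-correlation matrix, normalized by element count.
    _, c, h, w = feat.shape
    f = feat.view(c, h * w)
    return f @ f.t() / (c * h * w)

def content_distance(x, content_img):
    # How different the content is: compare feature maps directly.
    return F.mse_loss(feature_net(x), feature_net(content_img))

def style_distance(x, style_img):
    # How different the style is: compare Gram matrices of the feature maps.
    return F.mse_loss(gram(feature_net(x)), gram(feature_net(style_img)))

content_img = torch.rand(1, 3, 32, 32)
style_img = torch.zeros(1, 3, 32, 32)
style_img[:, :, ::4, :] = 1.0  # horizontal stripes as a toy "style"

style_weight = 1e3  # illustrative weighting between the two distances

def objective(x):
    return content_distance(x, content_img) + style_weight * style_distance(x, style_img)

# The third image: start from a copy of the content image and optimize its pixels.
input_img = content_img.clone().requires_grad_(True)
optimizer = torch.optim.Adam([input_img], lr=0.02)

initial = objective(input_img).item()
for _ in range(200):
    optimizer.zero_grad()
    loss = objective(input_img)
    loss.backward()
    optimizer.step()
final = objective(input_img).item()
print(f"objective: {initial:.6f} -> {final:.6f}")
```

Starting from the content image makes the content distance zero at step 0. The optimizer then trades a small increase in content distance for a larger decrease in style distance.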
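The loss function commonly used for style transfer has three parts. The content loss keeps the synthesized image close to the content image in its content features. The style loss keeps the synthesized image close to the style image in its style features. The total variation loss helps reduce noise in the synthesized image. The parameters being trained are the pixels of the synthesized image itself, so when training finishes, outputting these model parameters gives the final synthesized image.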
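The total variation term and the weighted combination are not implemented in the code below, so here is a minimal sketch, assuming images are `(N, C, H, W)` tensors. The helpers `tv_loss` and `total_loss` are hypothetical names, and the default weights are illustrative only.

```python
import torch

def tv_loss(img):
    # Total variation: mean absolute difference between neighboring pixels,
    # vertically (along H) and horizontally (along W). Noisy images score high.
    return 0.5 * ((img[:, :, 1:, :] - img[:, :, :-1, :]).abs().mean()
                  + (img[:, :, :, 1:] - img[:, :, :, :-1]).abs().mean())

def total_loss(content_l, style_l, tv_l,
               content_weight=1.0, style_weight=1e3, tv_weight=10.0):
    # Weighted sum of the three parts; the weights are illustrative, not prescribed.
    return content_weight * content_l + style_weight * style_l + tv_weight * tv_l

flat = torch.ones(1, 3, 8, 8)
checker = torch.tensor([[[[0.0, 1.0], [1.0, 0.0]]]])
print(tv_loss(flat).item())     # 0.0
print(tv_loss(checker).item())  # 1.0
```

A flat image has zero total variation. In a 2x2 checkerboard every neighboring pair differs by 1, so both directional means are 1 and the loss is 0.5 * (1 + 1) = 1.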
Loss functions
Content loss
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContentLoss(nn.Module):
    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # We 'detach' the target content from the graph used to compute
        # gradients dynamically: it is a fixed value, not a variable.
        # Otherwise the forward method of the criterion will throw an error.
        self.target = target.detach()

    def forward(self, input):
        # Record the loss and return the input unchanged, so this module
        # can sit transparently inside a feature-extraction network.
        self.loss = F.mse_loss(input, self.target)
        return input
```
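Because `ContentLoss` returns its input unchanged, it can be placed directly after a feature layer and its `.loss` attribute read after each forward pass. The following is a small sketch of that usage. The conv layer `conv` is a hypothetical stand-in for one layer of a pretrained CNN, and the class is repeated so the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContentLoss(nn.Module):
    # Same module as above, repeated so this snippet runs standalone.
    def __init__(self, target):
        super().__init__()
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input

torch.manual_seed(0)
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)  # hypothetical feature layer
content_img = torch.rand(1, 3, 16, 16)

# The target is the content image's features at that layer.
content_loss = ContentLoss(conv(content_img))
model = nn.Sequential(conv, content_loss)

model(content_img)
same = content_loss.loss.item()   # ~0: features match the target

other = torch.rand(1, 3, 16, 16)
out = model(other)
diff = content_loss.loss.item()   # > 0 for a different image
print(same, diff, tuple(out.shape))
```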
Style loss
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def gram_matrix(input):
    a, b, c, d = input.size()  # a = batch size (=1)
    # b = number of feature maps
    # (c, d) = dimensions of a feature map (N = c * d)
    features = input.view(a * b, c * d)  # reshape F_XL into \hat F_XL
    G = torch.mm(features, features.t())  # compute the Gram product
    # 'Normalize' the values of the Gram matrix by dividing
    # by the total number of elements in the feature maps.
    return G.div(a * b * c * d)

class StyleLoss(nn.Module):
    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        # Precompute the style image's Gram matrix as a fixed target.
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        # Compare Gram matrices rather than raw features, then pass input through.
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input
```
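To see what the Gram matrix captures, here is a small worked example on a hand-written feature map. The tensor values are arbitrary, and `gram_matrix` is repeated so the snippet runs standalone. With the flattened features [[1, 2], [3, 4]], F·Fᵀ = [[5, 11], [11, 25]]. Dividing by a*b*c*d = 1*2*1*2 = 4 gives the result below.

```python
import torch

def gram_matrix(input):
    # Same computation as above: (a, b, c, d) -> (a*b, a*b), divided by a*b*c*d.
    a, b, c, d = input.size()
    features = input.view(a * b, c * d)
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)

# A tiny feature map: batch 1, 2 channels, each 1x2 pixels.
x = torch.tensor([[[[1.0, 2.0]], [[3.0, 4.0]]]])
print(gram_matrix(x).tolist())  # [[1.25, 2.75], [2.75, 6.25]]

# Swapping the two spatial positions leaves the Gram matrix unchanged:
# it keeps channel correlations but discards where features occur.
x_swapped = x.flip(-1)
print(torch.equal(gram_matrix(x), gram_matrix(x_swapped)))  # True
```

This invariance to spatial layout is why the Gram matrix serves as a style representation, while the raw feature maps compared in the content loss still encode where things are.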