Pytorch 摳圖算法 Deep Image Matting 模型實(shí)現(xiàn)

????????本文旨在實(shí)現(xiàn)摳圖算法 Semantic Human Matting 的第二階段模型 M-Net能扒,也即 Deep Image Matting。值得說明的是,本文實(shí)現(xiàn)的模型與原始論文略有出入纺涤,除了模型的輸入層有細(xì)微差別之外,損失函數(shù)也作了簡化(但無本質(zhì)差別)抠忘。

????????本文完整代碼見 GitHub: deep_image_matting_pytorch撩炊。Pytorch 需要 1.1.0 或后續(xù)版本

????????本文 訓(xùn)練數(shù)據(jù) 來源于 愛分割 公司開源的 數(shù)據(jù)集崎脉,總共包含 34426 張圖片和對(duì)應(yīng)的 alpha 通道拧咳,數(shù)據(jù)量非常大,能公開特別值得點(diǎn)贊囚灼。但同時(shí)因?yàn)闃?biāo)注數(shù)據(jù)的 alpha 通道精度不高骆膝,導(dǎo)致訓(xùn)練后測(cè)試效果較差。建議使用 Deep Image Matting 的數(shù)據(jù)集訓(xùn)練灶体。

Semantic Human Matting 摳圖模型

????????總的來說阅签,Semantic Human Matting 論文提出的自動(dòng)摳圖的思路特別清晰明了(如上圖),對(duì)于一張待摳圖像蝎抽,首先通過語義分割模型(即 T-Net)分割出前景F_s政钟、背景B_s和未知區(qū)域U_sF_s + B_s + U_s = 1),然后廣義的認(rèn)為前景(F_s)+ 未知區(qū)域(U_s)組成一個(gè)三分圖( Trimap)织中,此時(shí)再利用 Deep Image Matting(即 M-Net) 即可高質(zhì)量的完成摳圖锥涕。完整的模型將在接下來的幾篇文章逐步實(shí)現(xiàn)衷戈,本文只關(guān)注該模型的第二階段(M-Net)狭吼。

????????M-Net 接受待摳圖像(前景與背景的 RGB 3 通道合成)以及語義分割模型輸出的 3 通道預(yù)測(cè)(F_s, B_s, U_s)拼接而成的 6 通道輸入,經(jīng)過編碼器提取圖像特征之后殖妇,由解碼器得到預(yù)測(cè) \alpha_r刁笙。如果語義分割模型分割的精度較高,那么可以認(rèn)為 F_s, B_s 對(duì)應(yīng)的區(qū)域已經(jīng)很好的摳出了大部分的前景和背景谦趣,唯一需要提升準(zhǔn)確率的是待摳對(duì)象的邊緣區(qū)域疲吸,所以模型的第二階段 M-Net 的目的就是細(xì)化的預(yù)測(cè)邊緣區(qū)域(這正是 Deep Image Matting 要干的事情),兩部分結(jié)合即得到最終的預(yù)測(cè):
F_s + U_s\alpha_r 前鹅。

這個(gè)公式可以這樣理解:
\textrm{預(yù)測(cè)的前景} = \textrm{確定區(qū)域上的前景} + \textrm{未知區(qū)域上的前景}

也就是:
P(\textrm{前景}) = P(確定區(qū)域上前景) + P(未知區(qū)域上前景)

根據(jù)全概率公式摘悴,用符號(hào)來表示則是:
\begin{align} \alpha &= P(F) \\ &= P(F|\mathrm{known})P(\mathrm{known}) + P(F|\mathrm{unknown})P(\mathrm{unknow}) \\ &= \frac{F_s}{F_s + B_s}(F_s + B_s) + \alpha_rU_s\\ &= F_s + U_s\alpha_r \end{align}

????????但上述公式存在一個(gè)缺陷,即如果待摳目標(biāo)外有大塊噪聲舰绘,則最終的預(yù)測(cè)也消除不了這個(gè)噪聲蹂喻,如下圖:

語義分割之后的前景帶有外部噪聲(衣服左側(cè)的小照片)

為了消除第一階段可能包含的外部噪聲葱椭,本文的在實(shí)現(xiàn) M-Net 的時(shí)候做了一個(gè)小的改動(dòng):第二階段的輸入改為由待摳圖像 + F_s 組成的 4 通道圖片(此時(shí),相當(dāng)于將 F_s 看成是三分圖 trimap)口四,并且將第二階段的預(yù)測(cè)則作為最終的預(yù)測(cè)孵运。另外,第二階段損失函數(shù)簡化為只用 alpha 通道的損失更好蔓彩。

一治笨、模型實(shí)現(xiàn)

????????Deep Image Matting 原文的模型如下:

Deep Image Matting 論文模型

模型先通過一個(gè)編碼器提取特征,之后經(jīng)過一個(gè)解碼器預(yù)測(cè)一個(gè)初始的 alpha 通道赤嚼,這個(gè)預(yù)測(cè)值效果已經(jīng)很好旷赖,但作者為了進(jìn)一步提升摳圖的精度,又額外的接了幾層細(xì)化的小網(wǎng)絡(luò)探膊,然后將細(xì)化后的輸出作為整個(gè)模型的最終輸出杠愧。具體來說,首先將待摳圖像(3 通道)以及事先準(zhǔn)備好的三分圖(trimap)合成一個(gè) 4 通道圖像逞壁,然后經(jīng)過 VGG16 的前 13 個(gè)卷積層以及之后的 1 個(gè)全連接層(看成是 1x1 的卷積層)流济,總共 14 個(gè)卷積層提取圖像特征(此時(shí)已做了 5 個(gè)最大池化,因此圖像分辨率下降了 32 倍腌闯,如果輸入是 320x320绳瘟,那么特征映射的分辨率就變成了 10x10),這是模型的編碼器階段姿骏。接下來對(duì)圖像特征進(jìn)行解碼糖声,即開始解碼器階段。解碼器使用 6 個(gè)卷積層(5x5 的卷積核)和 5 個(gè) 反池化層分瘦,每個(gè)反池化層將特征映射的分辨率提升 2 倍蘸泻,因此解碼器的輸出與模型輸入的大小一樣。這里嘲玫,使用反池化層的效果要比直接使用轉(zhuǎn)置卷積(deconvolution)的效果要好悦施。雖然他們都是為了提升圖像分辨率,但使用轉(zhuǎn)置卷積并不能很好的摳出細(xì)節(jié)去团,而使用反池化層卻可以摳圖頭發(fā)絲等非常細(xì)的前景抡诞。為了最求極致效果,作者又接了一個(gè)小網(wǎng)絡(luò)土陪,將待摳圖像和編碼器預(yù)測(cè)的 alpha 通道合成一個(gè) 4 通道圖像昼汗,然后通過 4 個(gè) 3x3 的卷積層得到細(xì)化后的 alpha 通道預(yù)測(cè),作為最后的輸出鬼雀。

Deep Image Matting 效果圖:(a) 原圖顷窒;(b) 編碼器-解碼器階段結(jié)果;(c)細(xì)化階段結(jié)果

????????損失方面源哩,總共用了 3 個(gè)分損失來合成網(wǎng)絡(luò)的損失:

  • 編碼器階段預(yù)測(cè)的 alpha 通道和真實(shí)的 alpha 通道的損失鞋吉;
  • 編碼器階段使用預(yù)測(cè)的 alpha 合成的圖像和真實(shí)的 alpha 合成的圖像的損失出刷;
  • 細(xì)化階段預(yù)測(cè)的 alpha 通道和真實(shí)的 alpha 通道的損失。

這些損失都是逐點(diǎn)損失坯辩,即平方和誤差:

預(yù)測(cè)與真實(shí) alpha 通道之間的損失
由前景馁龟、背景和 alpha 通道合成圖像之間的損失

三個(gè)損失使用加權(quán)和形成整個(gè)網(wǎng)絡(luò)最后反向傳播的總損失。

????????Deep Image Matting 雖然論文上報(bào)告的效果很驚人漆魔,但實(shí)際實(shí)現(xiàn)時(shí)(在個(gè)人應(yīng)用數(shù)據(jù)集上)泛化性能不夠理想坷檩。

????????Semantic Human Matting(SHM)這篇論文的 M-Net 在以上基礎(chǔ)上做了一些簡化和修改。首先改抡,為了防止網(wǎng)絡(luò)容量太大造成過擬合矢炼,SHM 只使用 VGG16 的前 13 個(gè)卷積層及 4 個(gè)最大池化層來作為編碼器,相應(yīng)的阿纤,解碼器階段也就少一個(gè)反池化層句灌。另外,為了加速網(wǎng)絡(luò)收斂欠拾,所有的卷積層(編碼器以及解碼器的)都帶批標(biāo)準(zhǔn)化(Batch Normalization)處理胰锌。其次,網(wǎng)絡(luò)的輸入由 4 通道變成了 6 通道藐窄,這樣做一方面沒有影響網(wǎng)絡(luò)性能(論文 4.2 節(jié))资昧,另一方面也是為了方便與 T-Net 對(duì)接,因?yàn)?T-Net 輸出 前景荆忍、背景格带、未知 3 個(gè)預(yù)測(cè)通道,與待摳圖像的 3 通道直接合成即得到 6 通道輸入刹枉。最后叽唱,SHM 直接去掉了 Deep Image Matting 網(wǎng)絡(luò)的細(xì)化階段,因此損失也相應(yīng)的減少為 2 個(gè)分損失微宝。

????????本文基本忠實(shí)的實(shí)現(xiàn)了 SHM 的 M-Net 結(jié)構(gòu)棺亭,但如本文開始時(shí)候說的那樣,將 6 通道的輸入改成了 4 通道芥吟,且為了完全引入 VGG16 的預(yù)訓(xùn)練模型侦铜,直接在 VGG16 的最前面接了一個(gè)輸入為 6 通道专甩、輸出為 4 通道的卷積層钟鸵。此外,本文將 M-Net 的預(yù)測(cè)作為最終的輸出涤躲,以及訓(xùn)練時(shí)不再求合成圖像的損失(以下模型實(shí)現(xiàn)時(shí)棺耍,loss 函數(shù)是支持合成圖像損失的)。

????????總的來說种樱,Deep Image Matting (或 M-Net)網(wǎng)絡(luò)是非常清晰明了的蒙袍,實(shí)現(xiàn)也很簡單俊卤,模型文件 model.py 如下:

# -*- coding: utf-8 -*-
"""
Created on Sun Jul 21 07:08:58 2019

@author: shirhe-lyh

Implementation of paper:
    Deep Image Matting, Ning Xu, eta., arxiv:1703.03872
"""

import torch
import torchvision as tv

VGG16_BN_MODEL_URL = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'

VGG16_BN_CONFIGS = {
    '13conv':
        [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 
         'M', 512, 512, 512],
    '10conv':
        [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
    }


def make_layers(cfg, batch_norm=False):
    """Copy from: torchvision/models/vgg.
    
    Changs retrue_indices in MaxPool2d from False to True.
    """
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [torch.nn.MaxPool2d(kernel_size=2, stride=2, 
                                          return_indices=True)]
        else:
            conv2d = torch.nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, torch.nn.BatchNorm2d(v), 
                           torch.nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, torch.nn.ReLU(inplace=True)]
            in_channels = v
    return torch.nn.Sequential(*layers)


class VGGFeatureExtractor(torch.nn.Module):
    """Feature extractor by VGG network."""
    
    def __init__(self, config=None, batch_norm=True):
        """Constructor.
        
        Args:
            config: The convolutional architecture of VGG network.
            batch_norm: A boolean indicating whether the architecture 
                include Batch Normalization layers or not.
        """
        super(VGGFeatureExtractor, self).__init__()
        self._config = config
        if self._config is None:
            self._config = VGG16_BN_CONFIGS.get('10conv')
        self.features = make_layers(self._config, batch_norm=batch_norm)
        self._indices = None
        
    def forward(self, x):
        self._indices = []
        for layer in self.features:
            if isinstance(layer, torch.nn.modules.pooling.MaxPool2d):
                x, indices = layer(x)
                self._indices.append(indices)
            else:
                x = layer(x)
        return x
    
    
def vgg16_bn_feature_extractor(config=None, pretrained=True, progress=True):
    model = VGGFeatureExtractor(config, batch_norm=True)
    if pretrained:
        state_dict = tv.models.utils.load_state_dict_from_url(
            VGG16_BN_MODEL_URL, progress=progress)
        model.load_state_dict(state_dict, strict=False)
    return model


class DIM(torch.nn.Module):
    """Deep Image Matting."""
    
    def __init__(self, feature_extractor):
        """Constructor.
        
        Args:
            feature_extractor: Feature extractor, such as VGGFeatureExtractor.
        """
        super(DIM, self).__init__()
        # Head convolution layer, number of channels: 4 -> 3
        self._head_conv = torch.nn.Conv2d(in_channels=4, out_channels=3,
                                          kernel_size=5, padding=2)
        # Encoder
        self._feature_extractor = feature_extractor
        self._feature_extract_config = self._feature_extractor._config
        # Decoder
        self._decode_layers = self.decode_layers()
        # Prediction
        self._final_conv = torch.nn.Conv2d(self._feature_extract_config[0], 1,
                                           kernel_size=5, padding=2)
        self._sigmoid = torch.nn.Sigmoid()
        
    def forward(self, x):
        x = self._head_conv(x)
        x = self._feature_extractor(x)
        indices = self._feature_extractor._indices[::-1]
        index = 0
        for layer in self._decode_layers:
            if isinstance(layer, torch.nn.modules.pooling.MaxUnpool2d):
                x = layer(x, indices[index])
                index += 1
            else:
                x = layer(x)
        x = self._final_conv(x)
        x = self._sigmoid(x)
        return x
    
    def decode_layers(self):
        layers = []
        strides = [1]
        channels = []
        config_reversed = self._feature_extract_config[::-1]
        for i, v in enumerate(config_reversed):
            if v == 'M':
                strides.append(2)
                channels.append(config_reversed[i+1])
        channels.append(channels[-1])
        in_channels = self._feature_extract_config[-1]
        for stride, out_channels in zip(strides, channels):
            if stride == 2:
                layers += [torch.nn.MaxUnpool2d(kernel_size=2, stride=2)]
            layers += [torch.nn.Conv2d(in_channels, out_channels,
                                       kernel_size=5, padding=2),
                       torch.nn.BatchNorm2d(num_features=out_channels),
                       torch.nn.ReLU(inplace=True)]
            in_channels = out_channels
        return torch.nn.Sequential(*layers)
    
    
def loss(alphas_pred, alphas_gt, images=None, epsilon=1e-12):
    losses = torch.sqrt(
        torch.mul(alphas_pred - alphas_gt, alphas_pred - alphas_gt) + 
        epsilon)
    loss = torch.mean(losses)
    if images is not None:
        images_fg_gt = torch.mul(images, alphas_gt)
        images_fg_pred = torch.mul(images, alphas_pred)
        images_fg_error = images_fg_pred - images_fg_gt
        losses_image = torch.sqrt(
            torch.mul(images_fg_error, images_fg_error) + epsilon)
        loss += torch.mean(losses_image)
    return loss

????????Pytorch 的官方是帶有 VGG 系列模型的,使用也很方便害幅,比如使用帶批標(biāo)準(zhǔn)化層的 VGG16 直接寫為:

vgg = torchvision.models.vgg16_bn(pretrained=True)

其中 pretrained=True 表示導(dǎo)入在 ImageNet 上預(yù)訓(xùn)練的參數(shù)消恍。但因?yàn)椋覀兪侵皇褂?VGG16 的前 13 個(gè)卷積層以现,而不需要后面的全連接層狠怨,因此,不會(huì)像上面那樣直接使用邑遏,而是要從 torchvision 的官方實(shí)現(xiàn)中截取卷積層的部分(官方實(shí)現(xiàn)見文件 Python 安裝路徑下的 site-packages/torchvision/models/vgg.py)佣赖。我們主要復(fù)制該文件中的 make_layers 函數(shù),但因?yàn)楹竺娼獯a器階段要用反池化记盒,所以還要做一些修改:要把

layers += [nn.MaxPool2d(kernel_size=2, stride=2)]

改為

layers += [torch.nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)]

之所以要加上 return_indices=True憎蛤,是因?yàn)楹竺娣闯鼗瘜右玫竭@些池化層的池化過程中的最大值的下標(biāo)(從而要記下來)。也正因?yàn)槌鼗瘜佣喾祷亓艘粋€(gè)值(同時(shí)返回特征映射和最大值下標(biāo)張量)纪吮,因此在重載 forward 函數(shù)時(shí)要進(jìn)行如下的區(qū)別對(duì)待:

def forward(self, x):
    self._indices = []
    for layer in self.features:
        if isinstance(layer, torch.nn.modules.pooling.MaxPool2d):
            x, indices = layer(x)
            self._indices.append(indices)
        else:
            x = layer(x)
    return x

除此之外俩檬,編碼器階段都是非常簡單的,無需贅言碾盟。

????????來看解碼器階段豆胸。只需要重點(diǎn)關(guān)注一下反池化層(接后續(xù)的卷積層)的實(shí)現(xiàn)即可。具體也很簡單:

unpool = torch.nn.MaxUnpool2d(kernel_size=2, stride=2)
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2)

即巷疼,先用反池化操作提升 2 倍分辨率晚胡,然后再接一個(gè)普通卷積層(可選操作:批標(biāo)準(zhǔn)化、整流線性單元)嚼沿。實(shí)際前向傳播時(shí)的計(jì)算如下:

x = unpool(x, indices)
x = conv(x)

其中估盘,indices 是編碼器階段對(duì)應(yīng)的池化層返回的最大池化操作返回的最大值下標(biāo)。反池化層與轉(zhuǎn)置卷積層都是可以訓(xùn)練的(都帶有參數(shù))骡尽,作用也幾乎相同(提升分辨率)遣妥,但對(duì)于摳圖這個(gè)任務(wù)來說,最關(guān)心的就是目標(biāo)的邊界區(qū)域攀细,而這些邊界因?yàn)槎际乔熬绑锊取⒈尘暗慕唤鐓^(qū),因此表現(xiàn)在特征映射的響應(yīng)上谭贪,就基本都是局部極大值境钟,從而在池化操作時(shí),返回的最大值下標(biāo)就基本完整的記錄下了待摳目標(biāo)的邊界俭识,反池化操作因?yàn)闀?huì)重點(diǎn)關(guān)注這些區(qū)域慨削,所以效果較好。

????????DIM 類的 decode_layers 函數(shù)就是整個(gè)的解碼器的定義。它看起來有點(diǎn)晦澀缚态,其實(shí)不難理解磁椒。我們看編碼器階段的配置:

[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]

其中的 M 表示最大池化層,其它的數(shù)字就是全部的 13 個(gè)卷積層對(duì)應(yīng)的輸出通道數(shù)玫芦。解碼器執(zhí)行的操作基本上就是以上操作的逆過程:

[512, 'U', 512, 'U', 256, 'U', 128, 'U', 64]

其中的 U 表示 torch.nn.MaxUnpool2d 反池化操作浆熔。

二、訓(xùn)練過程

????????本節(jié)訓(xùn)練數(shù)據(jù)來源于愛分割開源的 數(shù)據(jù)集桥帆。該數(shù)據(jù)集內(nèi)的所有 34426 張圖片都類似于如下的上身模特圖:

愛分割開源數(shù)據(jù)集實(shí)例圖片

可以明顯看到標(biāo)注的 alpha 通道是非常粗陋的蘸拔,遠(yuǎn)達(dá)不到頭發(fā)絲的精度。因?yàn)檫@個(gè)數(shù)據(jù)集直接給出了原始圖像环葵,所以不需要前景调窍、背景圖片的合成。

數(shù)據(jù)準(zhǔn)備

????????當(dāng)你下載好愛分割開源數(shù)據(jù)集(并解壓)之后张遭,我們需要一次性將所有圖片的掩碼(mask)都準(zhǔn)備好邓萨,因此你需要打開 data/retrieve.py 文件,將 root_dir 改成你的 Matting_Human_Half 文件夾的路徑菊卷,然后執(zhí)行 retrieve.py 等待生成所有圖片的 alpha 和 mask(在 Matting_Human_Half 文件夾內(nèi))缔恳,以及用于訓(xùn)練的 train.txtval.txt(在 data 文件夾內(nèi),其中默認(rèn)隨機(jī)選擇 100 張圖像用于驗(yàn)證)洁闰。假如歉甚,你訓(xùn)練時(shí)不再改動(dòng) Matting_Human_Half 文件夾的路徑,那么你不需要再做其它處理了扑眉。如果你訓(xùn)練時(shí)纸泄,Matting_Human_Half 與以上制作 train.txt 和 val.txt 時(shí)指定的 root_dir 路徑不一致了,那么你可以使用諸如 Notepad ++ 之類的工具腰素,將 root_dir 替換為空聘裁,形成如下的形式:

去掉 root_dir 的標(biāo)注文件

????????train.txt 和 val.txt 分別記錄了訓(xùn)練和驗(yàn)證圖像的路徑,每一行對(duì)應(yīng)一張圖像的 4 個(gè)路徑弓千,分別是 原圖像路徑(3 通道)衡便、透明圖路徑(4 通道)、alpha 通道圖像路徑洋访、mask 路徑镣陕,它們通過 @ 符號(hào)分隔。

訓(xùn)練

????????直接在命令行執(zhí)行:

python3 train.py --root_dir "xxx/Matting_Human_Half" [--gpu_indices 0 1 ...]

開始訓(xùn)練姻政,如果你從制作數(shù)據(jù)時(shí)開始呆抑, Matting_Human_Half 這個(gè)文件夾的路徑始終沒有改動(dòng)過,那么 root_dir 這個(gè)參數(shù)也可以不指定(指定也無妨)扶歪。后面的 [--gpu_indices ...] 表示需要根據(jù)實(shí)際情況理肺,可選的指定可用的 GPU 下標(biāo),這里默認(rèn)是使用 0,1,2,3 共 4 塊 GPU善镰,如果你使用一塊 GPU妹萨,則指定

--gpu_indices 0

如果使用多塊,比如使用 第 1 塊和第 3 塊 GPU炫欺,則指定

--gpu_indices 1 3

即可乎完。其它類似。訓(xùn)練過程中的所有超參數(shù)都在 train.py 文件的開頭部分品洛,可以直接修改默認(rèn)值或通過命令行指定树姨。

????????訓(xùn)練開始幾分鐘后,你在項(xiàng)目路徑下執(zhí)行:

tensorboard --logdir ./models/logs

可以打開瀏覽器查看訓(xùn)練的學(xué)習(xí)率桥状、損失曲線帽揪,和訓(xùn)練過程中的分割結(jié)果圖像。這里使用的是 Pytorch 自帶的類:from torch.utils.tensorboard import SummaryWriter 來調(diào)用 tensorboard辅斟,因此需要 Pytorch 1.1.0 以及之后的版本才可以转晰。(但好像瀏覽器刷新不了新結(jié)果,需要不斷重開 tensorboard 才可以觀看訓(xùn)練進(jìn)展)

????????訓(xùn)練結(jié)束后(默認(rèn)訓(xùn)練 30 個(gè) epoch)士飒,在 models 文件夾中保存了訓(xùn)練過程中的模型參數(shù)文件(模型使用參考 predict.py)查邢。直接執(zhí)行:

python3 predict.py

將在 test 文件夾里生成測(cè)試圖片的摳圖結(jié)果。

其它數(shù)據(jù)集上訓(xùn)練

????????訓(xùn)練 Pytorch 模型時(shí)酵幕,需要重載 torch.utils.data.Dataset 類扰藕,用來提供數(shù)據(jù)的批量生成。重載時(shí)芳撒,只需要實(shí)現(xiàn) __ init __, __ getitem __, __ len __ 這三個(gè)函數(shù)邓深。在這個(gè)項(xiàng)目里,我們使用的是 dataset.py 的重載類 MattingDataset笔刹。讀者可以按照自己的方式依據(jù)自己的標(biāo)注格式來重載庐完,也可以依照 MattingDataset 來改寫。

????????對(duì)于只提供前徘熔、背景分離的數(shù)據(jù)门躯,建議先一次性提前合成好合成圖像,和制作好 alpha 通道圖像酷师。此時(shí)你就可以適當(dāng)修改一下 get_image_mask_paths 函數(shù)即可讶凉。這個(gè)函數(shù)需要返回一個(gè)如下格式的列表

[[image_path, alpha_path],
 [image_path, alpha_path],
...
 [image_path, alpha_path]]

另外,__ getitem __ 函數(shù)數(shù)據(jù)增強(qiáng)的方式裁剪山孔、縮放懂讯、水平翻轉(zhuǎn)(以及 alpha 通道隨機(jī)膨脹腐蝕),如果你還有其它的處理方式請(qǐng)自行添加或刪減台颠。另外褐望,這里指定的裁剪尺寸:

crop_sizes = [320, 480, 600, 800]

是根據(jù)愛分割提供的數(shù)據(jù)來劃定的勒庄,里面所有圖片都是 600x800 的分辨率。一般來說瘫里,根據(jù) Deep Image Matting 論文实蔽,是從 320 開始,每隔 160 像素的尺寸裁剪谨读,最后統(tǒng)一縮放到 320 即可局装。

三、Deep Image Matting 數(shù)據(jù)集上的復(fù)現(xiàn)

????????本節(jié)將在 Deep Image Matting 數(shù)據(jù)集上進(jìn)行訓(xùn)練(訓(xùn)練數(shù)據(jù)可聯(lián)系論文作者獲壤椭场)铐尚,因部分參數(shù)未仔細(xì)調(diào)整,訓(xùn)練結(jié)果并非最優(yōu)哆姻。使用的背景圖像集是 COCO/train2017宣增。

數(shù)據(jù)準(zhǔn)備

????????我們將前景圖像背景圖像通過 alpha 通道合成訓(xùn)練集。假設(shè)你已經(jīng)獲取了 DIM 數(shù)據(jù)集矛缨,那么進(jìn)入 data_dim 文件夾碘裕,打開 composition.py肪凛,將 root_dir 替換成你保持 Combined_Dataset 文件夾的路徑森枪,bg_image_root_dir 填寫背景圖像的文件夾路徑柄沮,output_dir 填寫合成圖像的保持文件夾路徑。num_bg_images_per_fg 表示一張前景圖像對(duì)應(yīng)多少張背景圖像盟广,論文里這個(gè)值取 100(這里為了減少訓(xùn)練時(shí)間闷串,我取的是 50,讀者根據(jù)具體情況修改)筋量。當(dāng)這些值都確認(rèn)無誤后烹吵,執(zhí)行 composition.py,將花費(fèi)很長一段時(shí)間來合成圖片桨武。合成圖像結(jié)束后肋拔,在當(dāng)前路徑下生成 train.txtval.txt 兩個(gè)標(biāo)注文件。train.txt 文件里每一行對(duì)應(yīng) 4 個(gè)路徑呀酸,分別是合成圖像路徑凉蜂、前景圖像路徑、alpha 通道路徑性誉、背景圖像路徑窿吩,val.txt 里除了缺少合成圖像路徑之外,其它順序一致错览。

????????訓(xùn)練時(shí)纫雁,我們只需要 train.txt 中的第 1 個(gè)和第 3 個(gè)路徑,因此和訓(xùn)練愛分割開源數(shù)據(jù)集的標(biāo)注文件格式一致倾哺,從而可以共用同一個(gè) dataset.py轧邪,只需要確保 dataset.py 里面的函數(shù) __ getitem __ 中的

crop_sizes = [320, 480, 640]

即可刽脖。

訓(xùn)練過程

????????執(zhí)行

python3 train.py --annotation_path "./data_dim/train.txt" [gpu_indices 0 1 2 ...]

開始訓(xùn)練,期間可通過 tensorboard 查看訓(xùn)練進(jìn)程忌愚。訓(xùn)練的所有超參數(shù)可在 train.py 內(nèi)修改曲管,也可以通過命令行直接指定

訓(xùn)練期間的 tensorboard 摳圖結(jié)果展示
訓(xùn)練期間的 tensorboard 學(xué)習(xí)率和損失曲線

????????本次訓(xùn)練菜循,超參數(shù)采用的都是 train.py 中的默認(rèn)值翘地,由損失曲線可以看到申尤,如果再繼續(xù)訓(xùn)練(已訓(xùn)練 200 epoch)癌幕,損失會(huì)進(jìn)一步下降,摳圖效果會(huì)更好昧穿。

結(jié)果展示

????????訓(xùn)練結(jié)束后勺远,執(zhí)行(如果模型保存路徑是默認(rèn)的 ./models,否則需要修改一下 ckpt_path):

python3 predict_trimap.py

會(huì)在 data_dim/test 文件夾里生成預(yù)測(cè)結(jié)果(見文件夾 preds):

合成圖片时鸵、trimap胶逢、摳圖結(jié)果、GT alpha
合成圖片饰潜、trimap初坠、摳圖結(jié)果、GT alpha

????????從以上摳圖結(jié)果可看出彭雾,在某些細(xì)節(jié)上效果還不理想碟刺。如果你想獲得更好的結(jié)果,一方面可以在合成訓(xùn)練圖片時(shí)提升 1 張前景對(duì)應(yīng)的背景圖片數(shù)(我取的是 1:50薯酝,論文是 1:100)半沽;其次,仔細(xì)調(diào)整學(xué)習(xí)率及其衰減吴菠;再次者填,增加訓(xùn)練的輪次(epoch,我訓(xùn)練了 200 個(gè) epoch)等做葵。

附錄
快速上手 Pytorch 的資料:PyTorch Tutorial for Deep Learning Researchers占哟。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市酿矢,隨后出現(xiàn)的幾起案子重挑,更是在濱河造成了極大的恐慌,老刑警劉巖棠涮,帶你破解...
    沈念sama閱讀 206,482評(píng)論 6 481
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件谬哀,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡严肪,警方通過查閱死者的電腦和手機(jī)史煎,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,377評(píng)論 2 382
  • 文/潘曉璐 我一進(jìn)店門谦屑,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人篇梭,你說我怎么就攤上這事氢橙。” “怎么了恬偷?”我有些...
    開封第一講書人閱讀 152,762評(píng)論 0 342
  • 文/不壞的土叔 我叫張陵悍手,是天一觀的道長。 經(jīng)常有香客問我袍患,道長坦康,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 55,273評(píng)論 1 279
  • 正文 為了忘掉前任诡延,我火速辦了婚禮滞欠,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘肆良。我一直安慰自己筛璧,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,289評(píng)論 5 373
  • 文/花漫 我一把揭開白布惹恃。 她就那樣靜靜地躺著夭谤,像睡著了一般。 火紅的嫁衣襯著肌膚如雪巫糙。 梳的紋絲不亂的頭發(fā)上朗儒,一...
    開封第一講書人閱讀 49,046評(píng)論 1 285
  • 那天,我揣著相機(jī)與錄音曲秉,去河邊找鬼采蚀。 笑死,一個(gè)胖子當(dāng)著我的面吹牛承二,可吹牛的內(nèi)容都是我干的榆鼠。 我是一名探鬼主播,決...
    沈念sama閱讀 38,351評(píng)論 3 400
  • 文/蒼蘭香墨 我猛地睜開眼亥鸠,長吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼妆够!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起负蚊,我...
    開封第一講書人閱讀 36,988評(píng)論 0 259
  • 序言:老撾萬榮一對(duì)情侶失蹤神妹,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后家妆,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體鸵荠,經(jīng)...
    沈念sama閱讀 43,476評(píng)論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 35,948評(píng)論 2 324
  • 正文 我和宋清朗相戀三年伤极,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了蛹找。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片姨伤。...
    茶點(diǎn)故事閱讀 38,064評(píng)論 1 333
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖庸疾,靈堂內(nèi)的尸體忽然破棺而出乍楚,到底是詐尸還是另有隱情,我是刑警寧澤届慈,帶...
    沈念sama閱讀 33,712評(píng)論 4 323
  • 正文 年R本政府宣布徒溪,位于F島的核電站,受9級(jí)特大地震影響金顿,放射性物質(zhì)發(fā)生泄漏臊泌。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,261評(píng)論 3 307
  • 文/蒙蒙 一串绩、第九天 我趴在偏房一處隱蔽的房頂上張望缺虐。 院中可真熱鬧芜壁,春花似錦礁凡、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,264評(píng)論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至塞淹,卻和暖如春窟蓝,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背饱普。 一陣腳步聲響...
    開封第一講書人閱讀 31,486評(píng)論 1 262
  • 我被黑心中介騙來泰國打工运挫, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人套耕。 一個(gè)月前我還...
    沈念sama閱讀 45,511評(píng)論 2 354
  • 正文 我出身青樓谁帕,卻偏偏與公主長得像,于是被迫代替她去往敵國和親冯袍。 傳聞我的和親對(duì)象是個(gè)殘疾皇子匈挖,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,802評(píng)論 2 345