FCN:Fully Convolutional Networks for Semantic Segmentation的閱讀與pytorch實(shí)現(xiàn)

作 者: 心有寶寶人自圓

聲 明: 歡迎轉(zhuǎn)載本文中的圖片或文字,請(qǐng)說(shuō)明出處

寫在前面

本篇文章介紹的FCN是語(yǔ)義分割(Semantic Segmentation)之中Fully Convolutional Network結(jié)構(gòu)流派的開山鼻祖啊片,以至于之后的語(yǔ)義分割研究基本采取了這種結(jié)構(gòu)僻肖。

語(yǔ)義分割的目標(biāo)是為每圖片中的每一個(gè)pixel進(jìn)行類別的預(yù)測(cè)(Dense Prediction)

本文的主體內(nèi)容十分容易理解涕烧,但是一些作者介紹的tricks讓人看得云里霧里的(關(guān)鍵這些tricks作者最后一般都沒使用??),所以對(duì)應(yīng)FCN的原理理解可以忽略這些tricks(但我還會(huì)分享一些理解??)

最后我給出了我的代碼和結(jié)果(pytorch),訓(xùn)練這個(gè)真是太難了??以下坑點(diǎn)萬(wàn)分注意:

a) resize時(shí),插值的方法一定要選擇NEAREAST而不是默認(rèn)的Bilinear便锨,否則會(huì)對(duì)true label image的pixel進(jìn)行誤標(biāo)

b) 一定要充足的耐心進(jìn)行訓(xùn)練,不然你的分割圖像一直是黑的(沒用非極大抑制大概80 epochs我碟,使用后大概40 epochs)

c) 關(guān)于不同的loss:loss的設(shè)計(jì)可能會(huì)出現(xiàn)梯度爆炸的現(xiàn)象(若非loss的設(shè)計(jì)問(wèn)題)放案,batch size不要設(shè)太大;有時(shí)候loss的設(shè)計(jì)實(shí)際影響了收斂時(shí)間(就是分割圖像一直是黑持續(xù)時(shí)間)

1. Introduction

語(yǔ)義分割的目標(biāo)就是要為每個(gè)像素做出預(yù)測(cè)矫俺,每個(gè)像素要被標(biāo)記為包含它的目標(biāo)的類別吱殉。而FCN是第一個(gè)使用end-to-end掸冤,pixel-to-pixel訓(xùn)練的語(yǔ)義分割方法。FCN能使用任意大小的圖像作為輸入(去除了網(wǎng)絡(luò)中的fully connected layers)考婴,進(jìn)行密集預(yù)測(cè)贩虾。學(xué)習(xí)特征和推斷分別通過(guò)feedforward(下采樣)和backpropagation(上采樣)進(jìn)行催烘,這樣的結(jié)構(gòu)特征使網(wǎng)絡(luò)可以進(jìn)行pixelwise預(yù)測(cè)沥阱。

作者介紹了語(yǔ)義分割一種內(nèi)在矛盾:全局信息(global/semantic information)和位置信息(location information)。他們分別代表network高層和低層的特征伊群,作者形象的稱號(hào)它們?yōu)閣hat和where考杉。高層的信息經(jīng)過(guò)了下采樣所以更能代表類別信息,而低層則包含了目標(biāo)的細(xì)節(jié)信息舰始,而語(yǔ)義分割則需要全局信息和位置信息的共同編碼崇棠,否則在目標(biāo)邊緣的預(yù)測(cè)會(huì)變得很不準(zhǔn)卻。為了解決這一問(wèn)題作者設(shè)計(jì)了一種跳躍結(jié)構(gòu)(Skip Architect)丸卷,進(jìn)而結(jié)合了兩種信息

2. Related Work

在FCN提出之前枕稀,語(yǔ)義分割基本是基于pitch-wise訓(xùn)練的(fine-tuned R-CNN system),以選擇性搜索選取一定大小的proposal region谜嫉,使用CNN提取proposal region的特征傳入分類器萎坷。這樣的操著并非end-to-end,proposal region的大小一般是先驗(yàn)指定(限制了模型感受野的大小沐兰,使其僅對(duì)某些scale的特征敏感)哆档,隨機(jī)選取的proposal region可能高度重疊而造成計(jì)算、存儲(chǔ)資源的過(guò)多消耗住闯。

3. Fully Convolution Networks

f表示卷積或池化瓜浸,x為輸入,y為輸出比原,k是kernel size插佛,s是stride or subsampling factor;下式則表示連續(xù)的卷積或池化可以合成等效的一步(當(dāng)然非線性的激活函數(shù)也可以代表量窘,但它對(duì)下采樣過(guò)程沒有作用)朗涩。這個(gè)公式也可以說(shuō)明為什么5x5,stride=1的卷積可以轉(zhuǎn)化成2個(gè)3x3绑改,stride=1的卷積

)

  • 關(guān)于損失函數(shù):

    每一圖像的損失是每個(gè)空間空間點(diǎn)的損失之和?l(x;\theta)=\sum_{ij}l^{'}(x_{ij};\theta)

    因此每個(gè)圖像的隨機(jī)梯度下降等于每個(gè)空間點(diǎn)的梯度下降之和

3.1 改編分類器以適應(yīng)密集預(yù)測(cè)

傳統(tǒng)的分類器(全連接層)使用固定大小的輸入谢床,產(chǎn)生非空間性的輸出,因此全連接層被認(rèn)為是固定size厘线、丟棄了空間信息识腿;然而全連接也可以視為kernels覆蓋了整個(gè)輸入的卷積層(這樣就可以將全連接層與卷積層相互轉(zhuǎn)換),而卷積層可接受任意大小的輸入造壮,輸出分類maps渡讼。使用卷積層代替全連接能帶來(lái)更高的計(jì)算效率

2.PNG

然而輸出的分類maps(粗糙輸出)的維度由于經(jīng)過(guò)下采樣而比原始輸入的維度更小

3.2 Shift-and-stitch是濾波稀疏

為了將全卷積網(wǎng)絡(luò)的粗糙輸出轉(zhuǎn)化到原始空間的密集預(yù)測(cè)骂束,作者引入了input shifting(輸入平移)和output interlacing(輸出交織)的trick(然而這并非作者最終選擇使用的上采樣策略??)

給定下采樣因子f(stride),將將原始輸入分別從左上(0成箫,0)開始展箱,分別向右和向下平移[0,f-1]個(gè)像素蹬昌,共得到f^2個(gè)輸入分別通過(guò)全卷積網(wǎng)絡(luò)產(chǎn)生f^2個(gè)output混驰,將這些結(jié)果交織在一起就能得到原始輸入空間大小的輸出,這樣的預(yù)測(cè)結(jié)果與感受野中心像素有關(guān)皂贩∑苷ィ可以看出shift-and-stitch與傳統(tǒng)的上采樣方法(如雙線性插值)是不一樣的。然而這種做法并沒有真正利用到低層更細(xì)節(jié)的信息

之后作者有想出了一個(gè)trick:縮小卷積核(等同于對(duì)原始圖像進(jìn)行上采樣)明刷,同樣可以達(dá)到輸出的維度與輸入的維度相同婴栽。然而這種做法導(dǎo)致卷積層的感受野過(guò)小、更長(zhǎng)的計(jì)算時(shí)間

3.3 上采樣是反向的卷積

在神經(jīng)網(wǎng)絡(luò)里辈末,一個(gè)關(guān)于上采樣的自然想法便是反向傳播愚争,所以作者就采用反卷積(deconvolution)的方法進(jìn)行上采樣。deconvolution中的卷積轉(zhuǎn)置層的參數(shù)是可學(xué)習(xí)的挤聘,然而在作者的倉(cāng)庫(kù)中轰枝,設(shè)定其為固定值(作者實(shí)際使用了雙線性插值的方法)

3.4 Patchwise training是損失的采樣

在隨機(jī)優(yōu)化中,梯度的計(jì)算實(shí)際是由訓(xùn)練的分布驅(qū)動(dòng)的檬洞。Patchwising training和fully-conv training都可以產(chǎn)生任意的分布(即使它們的效率與重疊部分和小批量的大小有關(guān))狸膏。通常來(lái)說(shuō)后者比前者的效率更高(更少的batches)。

對(duì)于patchwise training的采樣可以減少類別不平衡和緩解空間的相關(guān)性添怔;在fully-conv training中湾戳,類別的平衡和緩解空間的相關(guān)性可以通過(guò)對(duì)loss的加權(quán)或下采樣loss得到。然而4.3節(jié)的結(jié)果表明下采樣并沒有對(duì)結(jié)果產(chǎn)生顯著的影響(類別不平衡為對(duì)FCN并不重要)广料,僅加快了收斂的速度

4 分割結(jié)構(gòu)

(遷移學(xué)習(xí)+微調(diào))

從預(yù)訓(xùn)練網(wǎng)絡(luò)的全連接層截?cái)嗟木W(wǎng)絡(luò)砾脑,之前的網(wǎng)絡(luò)直接使用,全連接層轉(zhuǎn)換為卷積層(除GoogLeNet)艾杏。

把最后的輸出層換為輸出通道為類別數(shù)的1x1卷積層

(輸入在以上結(jié)構(gòu)的向前傳播的結(jié)果稱為coarse output)

網(wǎng)絡(luò)中加入反卷積層進(jìn)行上采樣(實(shí)際是固定的雙線性插值)

作者提出了一種新穎的skip architect(跳躍結(jié)構(gòu))韧衣,結(jié)合了高層的位置信息和低層的細(xì)節(jié)信息

3.PNG
  • FCN-32s:將coarse out通過(guò)deconv(雙線性插值)直接上采樣32倍

  • FCN-16s:將coarse out通過(guò)deconv(雙線性插值)上采樣2倍;使用1x1卷積層處理pool4的輸出使其輸出通道為類別數(shù)(額外的預(yù)測(cè)器)购桑;將前兩步結(jié)果相加(為方便記作coarse out 2x)后通過(guò)deconv(雙線性插值)上采樣16倍

  • FCN-8s:將coarse out 2x通過(guò)deconv(雙線性插值)上采樣2倍畅铭;使用1x1卷積層處理pool3的輸出使其輸出通道為類別數(shù)(額外的預(yù)測(cè)器);將前兩步結(jié)果相加后通過(guò)deconv(雙線性插值)上采樣8倍

當(dāng)繼續(xù)采用更低層輸出的跳躍結(jié)構(gòu)后勃蜘,模型遇到了衰減回饋(diminishing returns)硕噩,不能對(duì)meanIoU等指標(biāo)產(chǎn)生明顯的改善,因此跳躍結(jié)構(gòu)僅到8s就截止了缭贡。

實(shí)驗(yàn)框架

  • 優(yōu)化器:SGD with momentum=0.9炉擅,weight decay=5^{-4} or 2^{-4}(盡管訓(xùn)練對(duì)這些參數(shù)不敏感辉懒,但對(duì)learning rate敏感),10^{-3},10^{-4},5^{-5} for FCN-AlexNet, Vgg-16, GoogLeNet,原分類器中轉(zhuǎn)化來(lái)的卷積層使用Dropout

  • 微調(diào):需花費(fèi)很長(zhǎng)時(shí)間谍失,由FCN-32s(微調(diào)時(shí)作者用了3天......)向16s和8s微調(diào)

  • Patch sampling:使用整個(gè)圖像進(jìn)行訓(xùn)練的效果和sampling patches的效果差不多眶俩,且整個(gè)圖像進(jìn)行訓(xùn)練需要的收斂時(shí)間更短,所以直接使用完整圖像進(jìn)行訓(xùn)練

  • Class Balancing:正負(fù)類的不平衡(背景類為負(fù)類)對(duì)訓(xùn)練的效果沒有顯著影響(所以作者直接使用了所有像素計(jì)算loss快鱼,而沒有進(jìn)行hard negative mining)

  • Dense Prediction:采用deconv(雙線性插值)進(jìn)行上采樣颠印,而未使用3節(jié)中其他trick

  • 數(shù)據(jù)增強(qiáng):隨機(jī)鏡像和縮小輸入的scale(增強(qiáng)網(wǎng)絡(luò)對(duì)小尺度目標(biāo)的能力)并未產(chǎn)生顯著的效果提升

  • 更多的訓(xùn)練數(shù)據(jù):更好的效果

5. My codes

我使用的是PASCAL VOC2012的數(shù)據(jù)集,按其劃分好的trainval來(lái)進(jìn)行訓(xùn)練

為每個(gè)分割圖像進(jìn)行標(biāo)注:每個(gè)pixel表為對(duì)應(yīng)的類別(0-20攒巍,0代表背景)

# 每個(gè)RGB顏色的值及其標(biāo)注的類別
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]

VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
               'diningtable', 'dog', 'horse', 'motorbike', 'person',
               'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']

# CLASSES_LABEL = {k: v for k, v in zip(VOC_COLORMAP, VOC_CLASSES)}

# 為每個(gè)(R, G, B)組合分配類別
colormap2label = torch.zeros((256, 256, 256), dtype=torch.long)
for i, color in enumerate(VOC_COLORMAP):
    colormap2label[color[0], color[1], color[2]] = i


def get_pixel_label(segmentation_image):
    """
    為分割標(biāo)記圖像的每個(gè)像素分配類別標(biāo)簽
    :param segmentation_image: 標(biāo)記圖像嗽仪,a PIL image
    :return: a tensor of (image.height, image.width)荒勇,為每個(gè)像素分配了類別標(biāo)簽
    """
    cmap = np.array(segmentation_image.convert('RGB'), dtype=np.uint8)

    cmap = colormap2label[
        cmap[:, :, 0].flatten().tolist(), cmap[:, :, 1].flatten().tolist(), cmap[:, :, 2].flatten().tolist()].reshape(
        cmap.shape[0], cmap.shape[1])
    return cmap

網(wǎng)絡(luò)的結(jié)構(gòu)

這里只列出了FCN32s和FCN8s柒莉,使用的是Vgg-16預(yù)訓(xùn)練模型(注意deconv的權(quán)重初始化雙線性插值,不再對(duì)其權(quán)重進(jìn)行學(xué)習(xí))

import torch
from torch import nn
import torchvision
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_bilinear_weights(in_channels, out_channels, kernel_size):
    """
    構(gòu)造雙線性插值的上采樣的權(quán)重
    :param in_channels: 轉(zhuǎn)置卷積層的輸入通道數(shù)
    :param out_channels: 轉(zhuǎn)置卷積層的輸出通道數(shù)
    :param kernel_size: 轉(zhuǎn)置卷積層的卷積核大小
    :return: 權(quán)重, a tensor in shape of (in_channels, out_channels , kernel_size, kernel_size)
    """
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1  # array從0開始以需要-1
    else:
        center = factor - 0.5  # center = factor + 0.5 - 1
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float32)
    weight[range(in_channels), range(out_channels), :, :] = filt  # 只對(duì)對(duì)角線上核的值進(jìn)行替換
    return torch.from_numpy(weight)


class FCN32s(nn.Module):
    def __init__(self, n_classes):
        super(FCN32s, self).__init__()

        # 直接使用Vgg-16預(yù)訓(xùn)練網(wǎng)絡(luò)沽翔,拋棄classifier層兢孝,并把fc層轉(zhuǎn)換為卷積層
        # fc6轉(zhuǎn)化為conv6,使用的卷積核大小為7x7仅偎,該層輸出長(zhǎng)度有6個(gè)像素的損失跨蟹,
        # 向上采樣32倍即原始空間192個(gè)像素的損失,因而小于192x192的輸入會(huì)導(dǎo)致報(bào)錯(cuò)
        # 同時(shí)這些像素?fù)p失必需通過(guò)padding使上采樣的空間大小與原輸入空間一致
        # 其實(shí)這個(gè)值可以屬于(96,112)都能達(dá)到以上效果

        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=100)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2)

        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool5 = nn.MaxPool2d(2, 2)

        self.conv6 = nn.Conv2d(512, 4096, kernel_size=7)
        self.dropout6 = nn.Dropout2d()

        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)
        self.dropout7 = nn.Dropout2d()

        self.load_pretrained_layers()

        self.score = nn.Conv2d(4096, n_classes, 1)

        # 此處的kernel_size我認(rèn)為是作者主觀選擇的橘沥,默認(rèn)是下采樣率的2倍
        self.upsample = nn.ConvTranspose2d(n_classes, n_classes, kernel_size=64, stride=32, bias=False)

        self.upsample.weight.data = get_bilinear_weights(n_classes, n_classes, kernel_size=64)
        self.upsample.weight.requires_grad = False

    def forward(self, x):
        # 我們假設(shè)輸入圖片的height, width均為能被32整除
        out = torch.relu(self.conv1_1(x))  # (b, 64, h+198, w+198)
        out = torch.relu(self.conv1_2(out))  # (b, 64, h+198, w+198)
        out = self.pool1(out)  # (b, 64, h/2 + 99, w/2 +99)

        out = torch.relu(self.conv2_1(out))  # (b, 128, h/2+99, w+99)
        out = torch.relu(self.conv2_2(out))  # (b, 128, h/2+99, w+99)
        out = self.pool2(out)  # (b, 128, h/4 + 49, w/4 + 49)

        out = torch.relu(self.conv3_1(out))
        out = torch.relu(self.conv3_2(out))
        out = torch.relu(self.conv3_3(out))
        out = self.pool3(out)  # (b, 256, h/8 + 24, w/18 + 24)

        out = torch.relu(self.conv4_1(out))
        out = torch.relu(self.conv4_2(out))
        out = torch.relu(self.conv4_3(out))
        out = self.pool4(out)  # (b, 512, h/16 + 12, w/16 + 12)

        out = torch.relu(self.conv5_1(out))
        out = torch.relu(self.conv5_2(out))
        out = torch.relu(self.conv5_3(out))
        out = self.pool5(out)  # (b, 512, h/32 + 6, w/32 + 6)

        out = torch.relu(self.conv6(out))  # (b, 512, h/32, w/32)
        out = self.dropout6(out)

        out = torch.relu(self.conv7(out))
        out = self.dropout7(out)

        out = self.score(out)

        # 由于轉(zhuǎn)置卷積的卷積核大小使上采樣32倍后比原始size大了(kernel_size - stride)
        out = self.upsample(out)  # (b, n_classes, h+32, w+32)

        return out[:, :, 16:16 + x.shape[2], 16:16 + x.shape[3]].contiguous()

def load_pretrained_layers(self):
        state_dict = self.state_dict()
        param_names = list(state_dict.keys())

        pretrained_state_dict = torchvision.models.vgg16(pretrained=True).state_dict()
        pretrained_param_names = list(pretrained_state_dict.keys())

        for i, param in enumerate(param_names[:-4]):
            state_dict[param] = pretrained_state_dict[pretrained_param_names[i]]

        state_dict['conv6.weight'] = pretrained_state_dict['classifier.0.weight'].view(4096, 512, 7, 7)
        state_dict['conv6.bias'] = pretrained_state_dict['classifier.0.bias']

        state_dict['conv7.weight'] = pretrained_state_dict['classifier.3.weight'].view(4096, 4096, 1, 1)
        state_dict['conv6.bias'] = pretrained_state_dict['classifier.3.bias']
        self.load_state_dict(state_dict)
    
class FCN8s(nn.Module):
    def __init__(self, n_classes):
        super(FCN8s, self).__init__()

        # 直接使用Vgg-16預(yù)訓(xùn)練網(wǎng)絡(luò)窗轩,拋棄classifier層,并把fc層轉(zhuǎn)換為卷積層
        # fc6轉(zhuǎn)化為conv6座咆,使用的卷積核大小為7x7痢艺,該層輸出長(zhǎng)度有6個(gè)像素的損失,
        # 向上采樣32倍即原始空間192個(gè)像素的損失介陶,因而小于192x192的輸入會(huì)導(dǎo)致報(bào)錯(cuò)
        # 同時(shí)這些像素?fù)p失必需通過(guò)padding使上采樣的空間大小與原輸入空間一致
        # 其實(shí)這個(gè)值可以屬于(96,112)都能達(dá)到以上效果

        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=100)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2)

        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool5 = nn.MaxPool2d(2, 2)

        self.conv6 = nn.Conv2d(512, 4096, kernel_size=7)
        self.dropout6 = nn.Dropout2d()

        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)
        self.dropout7 = nn.Dropout2d()

        self.load_pretrained_layers()

        self.score = nn.Conv2d(4096, n_classes, 1)
        self.score_pool4 = nn.Conv2d(512, n_classes, 1)
        self.score_pool3 = nn.Conv2d(256, n_classes, 1)

        # 此處的kernel_size我認(rèn)為是作者主觀選擇的堤舒,默認(rèn)是下采樣率的2倍
        self.upsample_2x = nn.ConvTranspose2d(n_classes, n_classes, kernel_size=4, stride=2, bias=False)
        self.upsample_8x = nn.ConvTranspose2d(n_classes, n_classes, kernel_size=16, stride=8, bias=False)

        self.upsample_2x.weight.data = get_bilinear_weights(n_classes, n_classes, kernel_size=4)
        self.upsample_2x.weight.requires_grad = False
        self.upsample_8x.weight.data = get_bilinear_weights(n_classes, n_classes, kernel_size=16)
        self.upsample_8x.weight.requires_grad = False

    def forward(self, x):
        # 我們假設(shè)輸入圖片的height, width均為能被32整除
        out = torch.relu(self.conv1_1(x))  # (b, 64, h+198, w+198)
        out = torch.relu(self.conv1_2(out))  # (b, 64, h+198, w+198)
        out = self.pool1(out)  # (b, 64, h/2 + 99, w/2 +99)

        out = torch.relu(self.conv2_1(out))  # (b, 128, h/2+99, w+99)
        out = torch.relu(self.conv2_2(out))  # (b, 128, h/2+99, w+99)
        out = self.pool2(out)  # (b, 128, h/4 + 49, w/4 + 49)

        out = torch.relu(self.conv3_1(out))
        out = torch.relu(self.conv3_2(out))
        out = torch.relu(self.conv3_3(out))
        out = self.pool3(out)  # (b, 256, h/8 + 24, w/8 + 24)
        pool3 = out

        out = torch.relu(self.conv4_1(out))
        out = torch.relu(self.conv4_2(out))
        out = torch.relu(self.conv4_3(out))
        out = self.pool4(out)  # (b, 512, h/16 + 12, w/16 + 12)
        pool4 = out

        out = torch.relu(self.conv5_1(out))
        out = torch.relu(self.conv5_2(out))
        out = torch.relu(self.conv5_3(out))
        out = self.pool5(out)  # (b, 512, h/32 + 6, w/32 + 6)

        out = torch.relu(self.conv6(out))  # (b, 512, h/32, w/32)
        out = self.dropout6(out)

        out = torch.relu(self.conv7(out))
        out = self.dropout7(out)

        out = self.score(out)

        # 由于轉(zhuǎn)置卷積的卷積核大小使上采樣32倍后比原始size大了(kernel_size - stride)
        out = self.upsample_2x(out)  # (b, n_classes, h/16 + 2, w/16 + 2)
        pool4 = self.score_pool4(pool4)  # (b, n_classes, h/16 + 12, w/16 + 12)
        out = out + pool4[:, :, 5:5 + out.size(2), 5:5 + out.size(3)]  # (b, n_classes, h/16 + 2, w/16 + 2)

        out = self.upsample_2x(out)  # (b, n_classes, h/8 + 4 + 2, w/8 + 4 + 2)
        pool3 = self.score_pool3(pool3)  # (b, 256, h/8 + 24, w/8 + 24)
        out = out + pool3[:, :, 9:9 + out.size(2), 9:9 + out.size(3)]  # (b, n_classes, h/8 + 6, w/8 + 6)

        out = self.upsample_8x(out)  # (b, n_classes, h + 48 + 8, w + 48 + 8)

        return out[:, :, 28:28 + x.shape[2], 28:28 + x.shape[3]].contiguous()

    def load_pretrained_layers(self):
        state_dict = self.state_dict()
        param_names = list(state_dict.keys())

        pretrained_state_dict = torchvision.models.vgg16(pretrained=True).state_dict()
        pretrained_param_names = list(pretrained_state_dict.keys())

        for i, param in enumerate(param_names[:-4]):
            state_dict[param] = pretrained_state_dict[pretrained_param_names[i]]

        state_dict['conv6.weight'] = pretrained_state_dict['classifier.0.weight'].view(4096, 512, 7, 7)
        state_dict['conv6.bias'] = pretrained_state_dict['classifier.0.bias']

        state_dict['conv7.weight'] = pretrained_state_dict['classifier.3.weight'].view(4096, 4096, 1, 1)
        state_dict['conv6.bias'] = pretrained_state_dict['classifier.3.bias']
        self.load_state_dict(state_dict)

由于正負(fù)類不平衡對(duì)于FCN無(wú)影響(見第4節(jié)),直接使用交叉熵的計(jì)算方法來(lái)計(jì)算pixel loss(注意是2D版)

(其實(shí)也可以進(jìn)行Hard Negative Mining來(lái)加快收斂哺呜,這里簡(jiǎn)單起見使用這種方法)

class LossFunction(nn.Module):
    def __init__(self):
        super(LossFunction, self).__init__()
        self.loss = nn.NLLLoss()

     def forward(self, pred, target):
         pred = nn.functional.log_softmax(pred, dim=1)
         loss = self.loss(pred, target)
         return loss

接下來(lái)的Dataset舌缤、DataLoader的構(gòu)建、train和valid的具體函數(shù)不再詳細(xì)寫了(所有項(xiàng)目都差不多??)

注意

  • 在進(jìn)行數(shù)據(jù)增廣時(shí)(resize)某残,插值的方法一定要選擇NEAREAST而不是默認(rèn)的Bilinear国撵,否則會(huì)對(duì)true label image的pixel進(jìn)行誤標(biāo),導(dǎo)致問(wèn)題的出現(xiàn)
  • 訓(xùn)練要有足夠的耐心玻墅,作者的32s都訓(xùn)練了3天
  • 關(guān)于batch_size屉来,如果選擇不進(jìn)行resize,可以將batch_size設(shè)為1

一些衡量的Metrics見:wkentaro/pytorch-fcn玫芦,它的算法方法非常巧妙

結(jié)果:

6.我的問(wèn)題

從上面的分割結(jié)果來(lái)看效果還可以...但那些Metrics的值一直上不去...可能是我訓(xùn)練時(shí)間的問(wèn)題吧(我只訓(xùn)練了大概一天,可能這是最大的問(wèn)題了吧旨指,對(duì)復(fù)雜的圖像的分割能力有待加強(qiáng)??),但mIoU只達(dá)到了0.28...而且難以再升上去喳整,這個(gè)地方使我很苦惱(可能真得訓(xùn)練個(gè)3天吧??)

這里更新一下:終于找到mIoU上不去的原因了
這個(gè)問(wèn)題所在其實(shí)很傻谆构,就是在模型的load_pretrained_layer()中,最后忘記加上了self.load_state_dict()了框都,等于是預(yù)訓(xùn)練的網(wǎng)絡(luò)參數(shù)沒有用上搬素,而是重新直接訓(xùn)練了??
其實(shí)就這點(diǎn)問(wèn)題導(dǎo)致訓(xùn)練時(shí)間拉了極長(zhǎng)、輸出為黑的情況出現(xiàn)很長(zhǎng)時(shí)間魏保。FCN32s的精度太差熬尺,收斂的時(shí)間還是會(huì)稍久一點(diǎn)的,但也不會(huì)像重新訓(xùn)練一樣那么慢
心碎了一地??

我思考了一下問(wèn)題在哪里谓罗,可能是數(shù)據(jù)集過(guò)少的問(wèn)題粱哼,也跟可能是某種類別難以識(shí)別(有些類的IoU明顯較差),訓(xùn)練數(shù)據(jù)本身不平衡檩咱、標(biāo)注本來(lái)就不準(zhǔn)確什么的...也可能是FCN模型的真實(shí)能力并非想象中那么好...可以試一下讓網(wǎng)絡(luò)學(xué)習(xí)deconv層的參數(shù)揭措,亦或直接按照encoder-decoder的做法重新構(gòu)建一下網(wǎng)絡(luò)(雖然更耗時(shí),但肯定能提高細(xì)節(jié)的預(yù)測(cè))

其實(shí)大家有功夫可以多訓(xùn)練一下看看效果刻蚯,我看那種自動(dòng)駕駛的訓(xùn)練集(Cityscapes)的訓(xùn)練效果會(huì)更好一點(diǎn)(數(shù)據(jù)集里沒有背景類)

Reference

[1] Long, J., Shelhamer, E., & Darrell, T. (2015). Fully convolutional networks for semantic segmentation. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 3431-3440)

[2] 《動(dòng)手學(xué)深度學(xué)習(xí)》

[3] wkentaro/pytorch-fcn

轉(zhuǎn)載請(qǐng)說(shuō)明出處绊含。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市炊汹,隨后出現(xiàn)的幾起案子躬充,更是在濱河造成了極大的恐慌,老刑警劉巖讨便,帶你破解...
    沈念sama閱讀 216,591評(píng)論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件充甚,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡器钟,警方通過(guò)查閱死者的電腦和手機(jī)津坑,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,448評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)傲霸,“玉大人疆瑰,你說(shuō)我怎么就攤上這事£甲模” “怎么了穆役?”我有些...
    開封第一講書人閱讀 162,823評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)梳凛。 經(jīng)常有香客問(wèn)我耿币,道長(zhǎng),這世上最難降的妖魔是什么韧拒? 我笑而不...
    開封第一講書人閱讀 58,204評(píng)論 1 292
  • 正文 為了忘掉前任淹接,我火速辦了婚禮十性,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘塑悼。我一直安慰自己劲适,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,228評(píng)論 6 388
  • 文/花漫 我一把揭開白布厢蒜。 她就那樣靜靜地躺著霞势,像睡著了一般。 火紅的嫁衣襯著肌膚如雪斑鸦。 梳的紋絲不亂的頭發(fā)上愕贡,一...
    開封第一講書人閱讀 51,190評(píng)論 1 299
  • 那天,我揣著相機(jī)與錄音巷屿,去河邊找鬼固以。 笑死,一個(gè)胖子當(dāng)著我的面吹牛攒庵,可吹牛的內(nèi)容都是我干的嘴纺。 我是一名探鬼主播败晴,決...
    沈念sama閱讀 40,078評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼浓冒,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了尖坤?” 一聲冷哼從身側(cè)響起稳懒,我...
    開封第一講書人閱讀 38,923評(píng)論 0 274
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎慢味,沒想到半個(gè)月后场梆,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,334評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡纯路,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,550評(píng)論 2 333
  • 正文 我和宋清朗相戀三年或油,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片驰唬。...
    茶點(diǎn)故事閱讀 39,727評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡顶岸,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出叫编,到底是詐尸還是另有隱情辖佣,我是刑警寧澤,帶...
    沈念sama閱讀 35,428評(píng)論 5 343
  • 正文 年R本政府宣布搓逾,位于F島的核電站卷谈,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏霞篡。R本人自食惡果不足惜世蔗,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,022評(píng)論 3 326
  • 文/蒙蒙 一端逼、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧污淋,春花似錦裳食、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,672評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至而昨,卻和暖如春救氯,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背歌憨。 一陣腳步聲響...
    開封第一講書人閱讀 32,826評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工着憨, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人务嫡。 一個(gè)月前我還...
    沈念sama閱讀 47,734評(píng)論 2 368
  • 正文 我出身青樓甲抖,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親心铃。 傳聞我的和親對(duì)象是個(gè)殘疾皇子准谚,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,619評(píng)論 2 354