FCN實現(xiàn)語義分割-Pytorch(三)

3.4积蜻、驗證(Validation)

當(dāng)我們在訓(xùn)練集上指標(biāo)表現(xiàn)良好時院崇,需要使用驗證集來檢驗一下訓(xùn)練的結(jié)果是否存在過擬合現(xiàn)象摔寨。

3.4.1迈着、模型與參數(shù)的保存

模型的訓(xùn)練可能是一個漫長的過程竭望,在模型訓(xùn)練過程中,以及模型訓(xùn)練完成準(zhǔn)備發(fā)布時裕菠,我們需要保存模型或模型參數(shù)咬清,以便在此基礎(chǔ)上繼續(xù)訓(xùn)練,或者把訓(xùn)練好的模型發(fā)布上線奴潘。

# 保存模型
torch.save(net, './fcn8s.pth')
# 保存模型參數(shù)
torch.save(net.state_dict(), './fcn8s.pth')
# 加載整個模型
Net = torch.load('./fcn8s.pth')
# 加載模型參數(shù)
net.load_state_dict(torch.load('./fcn8s.pth'))

對于本文旧烧,我們僅保存了模型參數(shù),用于繼續(xù)訓(xùn)練和訓(xùn)練完成后的測試和預(yù)測工作画髓。

3.4.2掘剪、模型驗證

驗證是用來評估訓(xùn)練的參數(shù)是否過存在擬合現(xiàn)象。驗證和測試的過程和代碼幾乎相同奈虾,主要的不同點在于驗證階段不需要進(jìn)行優(yōu)化夺谁,沒有反向傳播,梯度下降等優(yōu)化操作肉微。我們簡單的調(diào)整訓(xùn)練代碼匾鸥,去掉優(yōu)化部分,得到如下的驗證代碼

def validate(self):
    training = self.model.training
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = nn.CrossEntropyLoss()
    val_loss = 0.0
    val_acc = 0.0
    mean_iu = 0.0

    self.model.to(device)
    self.model.eval()
    for batch_index, data in enumerate(self.val_loader):
        iteration = batch_index + 1
        std_input = data[0].float() / 255
        if self.transform:
            std_input = self.transform(std_input)
        input = Variable(std_input.to(device))
        target = data[1].float().to(device)
        with torch.no_grad():
            score = self.model(input)

        # metrics
        loss = criterion(score, target)
        if np.isnan(loss.item()):
            raise ValueError('loss is nan while validating')
        val_loss += loss.item()
        pred = OneHotEncoder.encode_score(score)
        cm = Trainer.confusion_matrix(target, pred)
        acc = torch.diag(cm).sum().item() / torch.sum(cm).item()
        val_acc += acc
        iu = torch.diag(cm) / (cm.sum(dim=1) + cm.sum(dim=0) - torch.diag(cm))
        mean_iu += torch.nanmean(iu).item()

    data_len = len(self.val_loader)
    val_loss /= data_len
    val_acc /= data_len
    mean_iu /= data_len

    print(f'validate loss: {val_loss:.5f}, accuracy:{val_acc:.5f}, mean IU:{mean_iu:.5f}')

    if training:
        self.model.train()

上面代碼中的model.eval()用來通知pytorch模型碉纳,當(dāng)前處于評估階段扫腺,此時模型中的BatchNormalization, Dropout等算法的行為會發(fā)生改變。torch.no_grad()區(qū)域內(nèi)的模型計算不會計算梯度值村象。在驗證代碼完成后,我們把模型在訓(xùn)練還是評估階段的標(biāo)識還原攒至,這樣方便我們接下來進(jìn)行的混合訓(xùn)練和驗證厚者。

3.4.3、混合訓(xùn)練與驗證

在指標(biāo)中迫吐,我們列出了一些模型輸出結(jié)果的度量方法库菲。如果一個模型訓(xùn)練結(jié)果的指標(biāo)符合要求,并且在驗證集上同樣表現(xiàn)良好志膀,那么我們可以保存模型或模型的參數(shù)熙宇,之后可直接使用保存下來的模型或參數(shù)去做測試和預(yù)測工作。
那么我們在是何種情況下溉浙,保存模型或模型的參數(shù)烫止?這通常依賴于我們要做的具體事情。在語義分割任務(wù)中戳稽,我們通常選擇IOU指標(biāo)馆蠕,作為評估保存模型或模型參數(shù)的指標(biāo)。為了讓程序智能為我們選擇理想的結(jié)果并保存,首先互躬,我們要確保模型的參數(shù)播赁,在訓(xùn)練集上訓(xùn)練的結(jié)果指標(biāo)滿足需求,然后我們使用此參數(shù)進(jìn)行模型驗證吼渡,輸出驗證結(jié)果的指標(biāo)容为,并保存模型參數(shù)。在下一次訓(xùn)練的結(jié)果指標(biāo)滿足需求時寺酪,如果再次驗證的結(jié)果指標(biāo)優(yōu)于上次保存的指標(biāo)坎背,那么保存最新的模型參數(shù)。最終訓(xùn)練和驗證完成后房维,我們保存的模型的參數(shù)沼瘫,在訓(xùn)練集上的表現(xiàn)符合預(yù)期,并且在驗證集上的泛化能力最優(yōu)化咙俩。

在trainer的構(gòu)造函數(shù)中定義準(zhǔn)確率閾值和中間比對的IOU值

class Trainer(object):
    def __init__(self, model: torch.nn.Module, transform, train_loader: DataLoader, val_loader: DataLoader, class_names, class_colors):
        self.model = model
        self.transform = transform
        self.visualizer = Visualizer(class_names, class_colors)
        self.acc_threshold = 0.95
        self.best_mean_iu = 0
        self.train_loader = train_loader
        self.val_loader = val_loader

在訓(xùn)練代碼中耿戚,加入混合驗證的代碼:

if verbose and iteration % iterations_per_epoch == 0:
    mean_acc = train_acc/iterations_per_epoch
    mean_iu = train_iu/iterations_per_epoch
    print(f'epoch {epoch + 1} / {epochs}: loss: {train_loss/iterations_per_epoch:.5f}, accuracy:{mean_acc:.5f}, mean IU:{mean_iu:.5f}')
    if mean_acc > self.acc_threshold:
        self.validate()

最后在模型驗證代碼中,加入擇優(yōu)保存模型參數(shù)的代碼

print(f'validate loss: {val_loss:.5f}, accuracy:{val_acc:.5f}, mean IU:{mean_iu:.5f}')
if mean_iu > self.best_mean_iu:
    self.save_model_params()
    self.best_mean_iu = mean_iu

3.5阿趁、測試(Test)

當(dāng)我們訓(xùn)練好了一個模型膜蛔,我們可以測試模型實際運行的效果. 測試階段是實際預(yù)測的預(yù)演,我們通過測試來評估模型正式運行時的效果脖阵。通常測試使用的數(shù)據(jù)皂股,是在訓(xùn)練和驗證都沒有使用過的數(shù)據(jù),這樣可以保證測試的結(jié)果盡可能接近真實的結(jié)果命黔。在測試階段呜呐,我們增加了兩個指標(biāo):ROC和PR

3.5.1、ROC

ROC曲線

ROC(Receiver Operating Characteristic)指標(biāo)悍募,可以直觀地評價分類器的優(yōu)劣蘑辑。ROC指標(biāo)是多個指標(biāo)的組合,橫坐標(biāo)FPR(False Positive Rate)也稱為誤報率坠宴。是所有實際為假的樣本中被錯誤地預(yù)測為陽性的比例洋魂。計算公式為:

FPR = FP / (FP + TN)

FP在混淆矩陣中是分類所在列中除去斜對角線元素之外所有數(shù)值的和, TN在混淆矩陣中是除去分類所在的行和列之外所有的數(shù)值之和喜鼓。

縱坐標(biāo)TPR(True Positive Rate)也稱為召回率副砍,查全率。是所有實際為真的樣本中庄岖,被正確地預(yù)測為陽性的比例豁翎。計算公式為:

TPR = TP / ( TP + FN)

TP 在混淆矩陣中是分類所在的斜對角線元素,F(xiàn)N在混淆矩陣中是分類所在行中除去斜對角線元素之外的所有數(shù)值之和隅忿。


混淆矩陣中的TPR谨垃,F(xiàn)PR

基于預(yù)測結(jié)果的打分或概率启搂,選定若干個閾值,在不同閾值下的混淆矩陣刘陶,對應(yīng)的TPR和FPR胳赌,即構(gòu)成了一幅ROC曲線圖。
ROC曲線圖的左下到右上的對角線是隨機(jī)猜測線匙隔,ROC曲線的區(qū)域越大疑苫,說明預(yù)測準(zhǔn)確率和越高,如果ROC曲線在對角線下方纷责,說明模型預(yù)測的準(zhǔn)確率低于隨機(jī)猜測捍掺。
為了繪制各種圖表和可視化結(jié)果,我們構(gòu)建了一個可視化的類再膳,使用標(biāo)簽數(shù)據(jù)和預(yù)測結(jié)果作為參數(shù)來繪制ROC曲線挺勿。
這里注意如果是多分類,那么y_pred只能使用概率喂柒,否則由于計算某一分類時不瓶,并不會參考其它分類的打分,會導(dǎo)致ROC曲線與實際不符灾杰。

class Visualizer:
    def __init__(self, class_names, class_colors):
        plt.rcParams['font.sans-serif'] = ['SimHei']
        self.class_names = class_names
        self.n_classes = len(class_names)
        self.class_colors = class_colors

    def draw_roc_auc(self, y_true: Tensor, y_pred: Tensor, title, x_label="False Positive Rate", y_label="True Positive Rate"):
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for i in range(self.n_classes):
            fpr[i], tpr[i], _ = roc_curve(y_true[:, i, :, :].view(-1).numpy(), y_pred[:, i, :, :].view(-1).numpy())
            roc_auc[i] = auc(fpr[i], tpr[i])

        for i, color in zip(range(self.n_classes), self.class_colors):
            plt.plot(
                fpr[i],
                tpr[i],
                color=color,
                lw=2,
                label="ROC curve of class {0} (area = {1:0.2f})".format(self.class_names[i], roc_auc[i]),
            )

        plt.plot([0, 1], [0, 1], "k--", lw=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend(loc="lower right")
        plt.show()

3.5.2蚊丐、PR

PR曲線

PR(Precision Recall)指標(biāo),是精確率(Precision)和召回率(Recall)兩個指標(biāo)的組合艳吠。其中橫坐標(biāo)是召回率(Recall)麦备,和ROC中的TPR的概念是一致的,表示真的樣本中昭娩,預(yù)測為陽性所在的比例凛篙。縱坐標(biāo)是精確率(Precision)栏渺,也稱為查準(zhǔn)率呛梆。是所有預(yù)測為陽性的樣本中,實際為真的比例迈嘹。計算公式為:

Precision = TP /(TP + FP)

基于預(yù)測結(jié)果的打分或概率,選定若干個閾值全庸,在不同閾值下的混淆矩陣秀仲,對應(yīng)的Precision和Recall,即構(gòu)成了一幅PR曲線圖壶笼。

def draw_pr(self, y_true: Tensor, y_pred: Tensor, title, x_label="Recall", y_label="Precision"):
    precision = dict()
    recall = dict()
    aps = dict()
    for i in range(self.n_classes):
        precision[i], recall[i], thresholds = precision_recall_curve(y_true[:, i, :, :].view(-1).numpy(), y_pred[:, i, :, :].view(-1).numpy())
        aps[i] = average_precision_score(y_true[:, i, :, :].view(-1).numpy(), y_pred[:, i, :, :].view(-1).numpy())

    for i, color in zip(range(self.n_classes), self.class_colors):
        plt.plot(
            recall[i],
            precision[i],
            color=color,
            lw=2,
            label="PR of class {0} (area = {1:0.2f})".format(self.class_names[i], aps[i]),
        )

    plt.plot([0, 1], [0, 1], "k--", lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend(loc="lower right")
    plt.show()

3.5.3神僵、繪制測試結(jié)果

我們可以把測試結(jié)果繪制成類似于語義分割標(biāo)簽圖片的圖像,并對比原標(biāo)簽圖像覆劈,直觀觀察分割的結(jié)果和實際標(biāo)簽的匹配程度保礼。為了繪制測試結(jié)果沛励,我們首先為one-hot編碼添加解碼能力,把one-hot編碼解碼成使用不同顏色表示不同分類的圖像炮障。

@staticmethod
def decode(input: Tensor, colors: Tensor):
    height, width = input.shape[1:]
    mask = torch.zeros([3, height, width], dtype=torch.long)
    for label_num in range(0, len(colors)):
        index = (input[label_num] == 1)
        mask[:, index] = colors[label_num][:, None]
    return mask

之后使用新增的方法實現(xiàn)繪制測試結(jié)果的功能目派。在一行中分別繪制原圖,標(biāo)簽圖和預(yù)測圖胁赢。

def draw_result(self, img: Tensor, mask: Tensor, y_pred: Tensor):
    mask_img = OneHotEncoder.decode(mask, self.class_colors)
    pred_img = OneHotEncoder.decode(y_pred, self.class_colors)
    plt.figure(figsize=(12, 5))
    plt.subplot(131)
    plt.imshow(img.permute(1, 2, 0))
    plt.subplot(132)
    plt.imshow(mask_img.permute(1, 2, 0))
    plt.subplot(133)
    plt.imshow(pred_img.permute(1, 2, 0))
    plt.show()

3.5.4企蹭、網(wǎng)格化標(biāo)注

網(wǎng)格化標(biāo)注

有了預(yù)測結(jié)果,我們可以根據(jù)預(yù)測結(jié)果在原圖或者標(biāo)簽圖的基礎(chǔ)上做各種疊加處理智末,用以反饋預(yù)測結(jié)果在原圖上的效果谅摄。這里我們嘗試使用小網(wǎng)格的方式,在原圖之上標(biāo)注分類的網(wǎng)格區(qū)域系馆。

def draw_overlay_grid(self, img: Tensor, overlay_classes, y_pred: Tensor, label):
    font = {'color': 'green',
            'size': 20,
            'family': 'Times New Roman'}
    grid = torch.tensor([
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]
    ])
    w, h = img.shape[1:]
    k_size = grid.shape[0]
    left, top = 0, 0
    while top < h:
        left = 0
        bottom = min(top + k_size, h)
        while left < w:
            right = min(left + k_size, w)
            sum_pred = torch.sum(y_pred[:, top:bottom, left:right].flatten(1, 2), dim=1)
            klass = sum_pred.argmax()
            if klass in overlay_classes:
                img[:, top:bottom, left:right] = torch.mul(
                    img[:, top:bottom, left:right], grid[0:bottom-top,0:right-left]) + torch.mul(self.class_colors[klass][:,None, None], grid ^ 1)

    plt.figure(figsize=(12, 5))
    plt.imshow(img.permute(1, 2, 0))
    if label:
        plt.text(10, 20, label, fontdict=font)
    plt.show()

4送漠、總結(jié)

在本文中,我們介紹了語義分割技術(shù)由蘑,一些機(jī)器學(xué)習(xí)的技術(shù)和概念在語義分割技術(shù)中的應(yīng)用闽寡。最后,我們介紹了幾種評估指標(biāo)以及繪制指標(biāo)圖纵穿,通過指標(biāo)圖和參數(shù)的配合下隧,深入理解語義分割模型,學(xué)習(xí)準(zhǔn)則和優(yōu)化過程中谓媒,各個超參數(shù)的意義和影響淆院。整個實驗涉及到了許多的Scalar,Vector句惯,Matrix土辩,Tensor之間的運算,需要我們熟練使用pyorch抢野,numpy等框架和庫對這些類型的數(shù)據(jù)進(jìn)行處理拷淘。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市指孤,隨后出現(xiàn)的幾起案子启涯,更是在濱河造成了極大的恐慌,老刑警劉巖恃轩,帶你破解...
    沈念sama閱讀 216,744評論 6 502
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件结洼,死亡現(xiàn)場離奇詭異,居然都是意外死亡叉跛,警方通過查閱死者的電腦和手機(jī)松忍,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,505評論 3 392
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來筷厘,“玉大人鸣峭,你說我怎么就攤上這事宏所。” “怎么了摊溶?”我有些...
    開封第一講書人閱讀 163,105評論 0 353
  • 文/不壞的土叔 我叫張陵爬骤,是天一觀的道長。 經(jīng)常有香客問我更扁,道長盖腕,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,242評論 1 292
  • 正文 為了忘掉前任浓镜,我火速辦了婚禮溃列,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘膛薛。我一直安慰自己听隐,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 67,269評論 6 389
  • 文/花漫 我一把揭開白布哄啄。 她就那樣靜靜地躺著雅任,像睡著了一般。 火紅的嫁衣襯著肌膚如雪咨跌。 梳的紋絲不亂的頭發(fā)上沪么,一...
    開封第一講書人閱讀 51,215評論 1 299
  • 那天,我揣著相機(jī)與錄音锌半,去河邊找鬼禽车。 笑死,一個胖子當(dāng)著我的面吹牛刊殉,可吹牛的內(nèi)容都是我干的殉摔。 我是一名探鬼主播,決...
    沈念sama閱讀 40,096評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼记焊,長吁一口氣:“原來是場噩夢啊……” “哼逸月!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起遍膜,我...
    開封第一講書人閱讀 38,939評論 0 274
  • 序言:老撾萬榮一對情侶失蹤碗硬,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后瓢颅,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體恩尾,經(jīng)...
    沈念sama閱讀 45,354評論 1 311
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,573評論 2 333
  • 正文 我和宋清朗相戀三年惜索,在試婚紗的時候發(fā)現(xiàn)自己被綠了特笋。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片剃浇。...
    茶點故事閱讀 39,745評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡巾兆,死狀恐怖猎物,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情角塑,我是刑警寧澤蔫磨,帶...
    沈念sama閱讀 35,448評論 5 344
  • 正文 年R本政府宣布,位于F島的核電站圃伶,受9級特大地震影響堤如,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜窒朋,卻給世界環(huán)境...
    茶點故事閱讀 41,048評論 3 327
  • 文/蒙蒙 一搀罢、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧侥猩,春花似錦榔至、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,683評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至划提,卻和暖如春枫弟,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背鹏往。 一陣腳步聲響...
    開封第一講書人閱讀 32,838評論 1 269
  • 我被黑心中介騙來泰國打工淡诗, 沒想到剛下飛機(jī)就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人掸犬。 一個月前我還...
    沈念sama閱讀 47,776評論 2 369
  • 正文 我出身青樓袜漩,卻偏偏與公主長得像,于是被迫代替她去往敵國和親湾碎。 傳聞我的和親對象是個殘疾皇子宙攻,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,652評論 2 354

推薦閱讀更多精彩內(nèi)容