3.4积蜻、驗證(Validation)
當(dāng)我們在訓(xùn)練集上指標(biāo)表現(xiàn)良好時院崇,需要使用驗證集來檢驗一下訓(xùn)練的結(jié)果是否存在過擬合現(xiàn)象摔寨。
3.4.1迈着、模型與參數(shù)的保存
模型的訓(xùn)練可能是一個漫長的過程竭望,在模型訓(xùn)練過程中,以及模型訓(xùn)練完成準(zhǔn)備發(fā)布時裕菠,我們需要保存模型或模型參數(shù)咬清,以便在此基礎(chǔ)上繼續(xù)訓(xùn)練,或者把訓(xùn)練好的模型發(fā)布上線奴潘。
# 保存模型
torch.save(net, './fcn8s.pth')
# 保存模型參數(shù)
torch.save(net.state_dict(), './fcn8s.pth')
# 加載整個模型
Net = torch.load('./fcn8s.pth')
# 加載模型參數(shù)
net.load_state_dict(torch.load('./fcn8s.pth'))
對于本文旧烧,我們僅保存了模型參數(shù),用于繼續(xù)訓(xùn)練和訓(xùn)練完成后的測試和預(yù)測工作画髓。
3.4.2掘剪、模型驗證
驗證是用來評估訓(xùn)練的參數(shù)是否過存在擬合現(xiàn)象。驗證和測試的過程和代碼幾乎相同奈虾,主要的不同點在于驗證階段不需要進(jìn)行優(yōu)化夺谁,沒有反向傳播,梯度下降等優(yōu)化操作肉微。我們簡單的調(diào)整訓(xùn)練代碼匾鸥,去掉優(yōu)化部分,得到如下的驗證代碼
def validate(self):
training = self.model.training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss()
val_loss = 0.0
val_acc = 0.0
mean_iu = 0.0
self.model.to(device)
self.model.eval()
for batch_index, data in enumerate(self.val_loader):
iteration = batch_index + 1
std_input = data[0].float() / 255
if self.transform:
std_input = self.transform(std_input)
input = Variable(std_input.to(device))
target = data[1].float().to(device)
with torch.no_grad():
score = self.model(input)
# metrics
loss = criterion(score, target)
if np.isnan(loss.item()):
raise ValueError('loss is nan while validating')
val_loss += loss.item()
pred = OneHotEncoder.encode_score(score)
cm = Trainer.confusion_matrix(target, pred)
acc = torch.diag(cm).sum().item() / torch.sum(cm).item()
val_acc += acc
iu = torch.diag(cm) / (cm.sum(dim=1) + cm.sum(dim=0) - torch.diag(cm))
mean_iu += torch.nanmean(iu).item()
data_len = len(self.val_loader)
val_loss /= data_len
val_acc /= data_len
mean_iu /= data_len
print(f'validate loss: {val_loss:.5f}, accuracy:{val_acc:.5f}, mean IU:{mean_iu:.5f}')
if training:
self.model.train()
上面代碼中的model.eval()用來通知pytorch模型碉纳,當(dāng)前處于評估階段扫腺,此時模型中的BatchNormalization, Dropout等算法的行為會發(fā)生改變。torch.no_grad()區(qū)域內(nèi)的模型計算不會計算梯度值村象。在驗證代碼完成后,我們把模型在訓(xùn)練還是評估階段的標(biāo)識還原攒至,這樣方便我們接下來進(jìn)行的混合訓(xùn)練和驗證厚者。
3.4.3、混合訓(xùn)練與驗證
在指標(biāo)中迫吐,我們列出了一些模型輸出結(jié)果的度量方法库菲。如果一個模型訓(xùn)練結(jié)果的指標(biāo)符合要求,并且在驗證集上同樣表現(xiàn)良好志膀,那么我們可以保存模型或模型的參數(shù)熙宇,之后可直接使用保存下來的模型或參數(shù)去做測試和預(yù)測工作。
那么我們在是何種情況下溉浙,保存模型或模型的參數(shù)烫止?這通常依賴于我們要做的具體事情。在語義分割任務(wù)中戳稽,我們通常選擇IOU指標(biāo)馆蠕,作為評估保存模型或模型參數(shù)的指標(biāo)。為了讓程序智能為我們選擇理想的結(jié)果并保存,首先互躬,我們要確保模型的參數(shù)播赁,在訓(xùn)練集上訓(xùn)練的結(jié)果指標(biāo)滿足需求,然后我們使用此參數(shù)進(jìn)行模型驗證吼渡,輸出驗證結(jié)果的指標(biāo)容为,并保存模型參數(shù)。在下一次訓(xùn)練的結(jié)果指標(biāo)滿足需求時寺酪,如果再次驗證的結(jié)果指標(biāo)優(yōu)于上次保存的指標(biāo)坎背,那么保存最新的模型參數(shù)。最終訓(xùn)練和驗證完成后房维,我們保存的模型的參數(shù)沼瘫,在訓(xùn)練集上的表現(xiàn)符合預(yù)期,并且在驗證集上的泛化能力最優(yōu)化咙俩。
在trainer的構(gòu)造函數(shù)中定義準(zhǔn)確率閾值和中間比對的IOU值
class Trainer(object):
def __init__(self, model: torch.nn.Module, transform, train_loader: DataLoader, val_loader: DataLoader, class_names, class_colors):
self.model = model
self.transform = transform
self.visualizer = Visualizer(class_names, class_colors)
self.acc_threshold = 0.95
self.best_mean_iu = 0
self.train_loader = train_loader
self.val_loader = val_loader
在訓(xùn)練代碼中耿戚,加入混合驗證的代碼:
if verbose and iteration % iterations_per_epoch == 0:
mean_acc = train_acc/iterations_per_epoch
mean_iu = train_iu/iterations_per_epoch
print(f'epoch {epoch + 1} / {epochs}: loss: {train_loss/iterations_per_epoch:.5f}, accuracy:{mean_acc:.5f}, mean IU:{mean_iu:.5f}')
if mean_acc > self.acc_threshold:
self.validate()
最后在模型驗證代碼中,加入擇優(yōu)保存模型參數(shù)的代碼
print(f'validate loss: {val_loss:.5f}, accuracy:{val_acc:.5f}, mean IU:{mean_iu:.5f}')
if mean_iu > self.best_mean_iu:
self.save_model_params()
self.best_mean_iu = mean_iu
3.5阿趁、測試(Test)
當(dāng)我們訓(xùn)練好了一個模型膜蛔,我們可以測試模型實際運行的效果. 測試階段是實際預(yù)測的預(yù)演,我們通過測試來評估模型正式運行時的效果脖阵。通常測試使用的數(shù)據(jù)皂股,是在訓(xùn)練和驗證都沒有使用過的數(shù)據(jù),這樣可以保證測試的結(jié)果盡可能接近真實的結(jié)果命黔。在測試階段呜呐,我們增加了兩個指標(biāo):ROC和PR
3.5.1、ROC
ROC(Receiver Operating Characteristic)指標(biāo)悍募,可以直觀地評價分類器的優(yōu)劣蘑辑。ROC指標(biāo)是多個指標(biāo)的組合,橫坐標(biāo)FPR(False Positive Rate)也稱為誤報率坠宴。是所有實際為假的樣本中被錯誤地預(yù)測為陽性的比例洋魂。計算公式為:
FPR = FP / (FP + TN)
FP在混淆矩陣中是分類所在列中除去斜對角線元素之外所有數(shù)值的和, TN在混淆矩陣中是除去分類所在的行和列之外所有的數(shù)值之和喜鼓。
縱坐標(biāo)TPR(True Positive Rate)也稱為召回率副砍,查全率。是所有實際為真的樣本中庄岖,被正確地預(yù)測為陽性的比例豁翎。計算公式為:
TPR = TP / ( TP + FN)
TP 在混淆矩陣中是分類所在的斜對角線元素,F(xiàn)N在混淆矩陣中是分類所在行中除去斜對角線元素之外的所有數(shù)值之和隅忿。
基于預(yù)測結(jié)果的打分或概率启搂,選定若干個閾值,在不同閾值下的混淆矩陣刘陶,對應(yīng)的TPR和FPR胳赌,即構(gòu)成了一幅ROC曲線圖。
ROC曲線圖的左下到右上的對角線是隨機(jī)猜測線匙隔,ROC曲線的區(qū)域越大疑苫,說明預(yù)測準(zhǔn)確率和越高,如果ROC曲線在對角線下方纷责,說明模型預(yù)測的準(zhǔn)確率低于隨機(jī)猜測捍掺。
為了繪制各種圖表和可視化結(jié)果,我們構(gòu)建了一個可視化的類再膳,使用標(biāo)簽數(shù)據(jù)和預(yù)測結(jié)果作為參數(shù)來繪制ROC曲線挺勿。
這里注意如果是多分類,那么y_pred只能使用概率喂柒,否則由于計算某一分類時不瓶,并不會參考其它分類的打分,會導(dǎo)致ROC曲線與實際不符灾杰。
class Visualizer:
def __init__(self, class_names, class_colors):
plt.rcParams['font.sans-serif'] = ['SimHei']
self.class_names = class_names
self.n_classes = len(class_names)
self.class_colors = class_colors
def draw_roc_auc(self, y_true: Tensor, y_pred: Tensor, title, x_label="False Positive Rate", y_label="True Positive Rate"):
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(self.n_classes):
fpr[i], tpr[i], _ = roc_curve(y_true[:, i, :, :].view(-1).numpy(), y_pred[:, i, :, :].view(-1).numpy())
roc_auc[i] = auc(fpr[i], tpr[i])
for i, color in zip(range(self.n_classes), self.class_colors):
plt.plot(
fpr[i],
tpr[i],
color=color,
lw=2,
label="ROC curve of class {0} (area = {1:0.2f})".format(self.class_names[i], roc_auc[i]),
)
plt.plot([0, 1], [0, 1], "k--", lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.title(title)
plt.legend(loc="lower right")
plt.show()
3.5.2蚊丐、PR
PR(Precision Recall)指標(biāo),是精確率(Precision)和召回率(Recall)兩個指標(biāo)的組合艳吠。其中橫坐標(biāo)是召回率(Recall)麦备,和ROC中的TPR的概念是一致的,表示真的樣本中昭娩,預(yù)測為陽性所在的比例凛篙。縱坐標(biāo)是精確率(Precision)栏渺,也稱為查準(zhǔn)率呛梆。是所有預(yù)測為陽性的樣本中,實際為真的比例迈嘹。計算公式為:
Precision = TP /(TP + FP)
基于預(yù)測結(jié)果的打分或概率,選定若干個閾值全庸,在不同閾值下的混淆矩陣秀仲,對應(yīng)的Precision和Recall,即構(gòu)成了一幅PR曲線圖壶笼。
def draw_pr(self, y_true: Tensor, y_pred: Tensor, title, x_label="Recall", y_label="Precision"):
precision = dict()
recall = dict()
aps = dict()
for i in range(self.n_classes):
precision[i], recall[i], thresholds = precision_recall_curve(y_true[:, i, :, :].view(-1).numpy(), y_pred[:, i, :, :].view(-1).numpy())
aps[i] = average_precision_score(y_true[:, i, :, :].view(-1).numpy(), y_pred[:, i, :, :].view(-1).numpy())
for i, color in zip(range(self.n_classes), self.class_colors):
plt.plot(
recall[i],
precision[i],
color=color,
lw=2,
label="PR of class {0} (area = {1:0.2f})".format(self.class_names[i], aps[i]),
)
plt.plot([0, 1], [0, 1], "k--", lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.title(title)
plt.legend(loc="lower right")
plt.show()
3.5.3神僵、繪制測試結(jié)果
我們可以把測試結(jié)果繪制成類似于語義分割標(biāo)簽圖片的圖像,并對比原標(biāo)簽圖像覆劈,直觀觀察分割的結(jié)果和實際標(biāo)簽的匹配程度保礼。為了繪制測試結(jié)果沛励,我們首先為one-hot編碼添加解碼能力,把one-hot編碼解碼成使用不同顏色表示不同分類的圖像炮障。
@staticmethod
def decode(input: Tensor, colors: Tensor):
height, width = input.shape[1:]
mask = torch.zeros([3, height, width], dtype=torch.long)
for label_num in range(0, len(colors)):
index = (input[label_num] == 1)
mask[:, index] = colors[label_num][:, None]
return mask
之后使用新增的方法實現(xiàn)繪制測試結(jié)果的功能目派。在一行中分別繪制原圖,標(biāo)簽圖和預(yù)測圖胁赢。
def draw_result(self, img: Tensor, mask: Tensor, y_pred: Tensor):
mask_img = OneHotEncoder.decode(mask, self.class_colors)
pred_img = OneHotEncoder.decode(y_pred, self.class_colors)
plt.figure(figsize=(12, 5))
plt.subplot(131)
plt.imshow(img.permute(1, 2, 0))
plt.subplot(132)
plt.imshow(mask_img.permute(1, 2, 0))
plt.subplot(133)
plt.imshow(pred_img.permute(1, 2, 0))
plt.show()
3.5.4企蹭、網(wǎng)格化標(biāo)注
有了預(yù)測結(jié)果,我們可以根據(jù)預(yù)測結(jié)果在原圖或者標(biāo)簽圖的基礎(chǔ)上做各種疊加處理智末,用以反饋預(yù)測結(jié)果在原圖上的效果谅摄。這里我們嘗試使用小網(wǎng)格的方式,在原圖之上標(biāo)注分類的網(wǎng)格區(qū)域系馆。
def draw_overlay_grid(self, img: Tensor, overlay_classes, y_pred: Tensor, label):
font = {'color': 'green',
'size': 20,
'family': 'Times New Roman'}
grid = torch.tensor([
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 1, 1, 1, 0],
[0, 1, 1, 1, 1, 1, 1, 0],
[0, 1, 1, 1, 1, 1, 1, 0],
[0, 1, 1, 1, 1, 1, 1, 0],
[0, 1, 1, 1, 1, 1, 1, 0],
[0, 1, 1, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0]
])
w, h = img.shape[1:]
k_size = grid.shape[0]
left, top = 0, 0
while top < h:
left = 0
bottom = min(top + k_size, h)
while left < w:
right = min(left + k_size, w)
sum_pred = torch.sum(y_pred[:, top:bottom, left:right].flatten(1, 2), dim=1)
klass = sum_pred.argmax()
if klass in overlay_classes:
img[:, top:bottom, left:right] = torch.mul(
img[:, top:bottom, left:right], grid[0:bottom-top,0:right-left]) + torch.mul(self.class_colors[klass][:,None, None], grid ^ 1)
plt.figure(figsize=(12, 5))
plt.imshow(img.permute(1, 2, 0))
if label:
plt.text(10, 20, label, fontdict=font)
plt.show()
4送漠、總結(jié)
在本文中,我們介紹了語義分割技術(shù)由蘑,一些機(jī)器學(xué)習(xí)的技術(shù)和概念在語義分割技術(shù)中的應(yīng)用闽寡。最后,我們介紹了幾種評估指標(biāo)以及繪制指標(biāo)圖纵穿,通過指標(biāo)圖和參數(shù)的配合下隧,深入理解語義分割模型,學(xué)習(xí)準(zhǔn)則和優(yōu)化過程中谓媒,各個超參數(shù)的意義和影響淆院。整個實驗涉及到了許多的Scalar,Vector句惯,Matrix土辩,Tensor之間的運算,需要我們熟練使用pyorch抢野,numpy等框架和庫對這些類型的數(shù)據(jù)進(jìn)行處理拷淘。