網(wǎng)絡(luò)模型已被證明在解決分割問題方面非常有效货矮,達到了最先進的準確性甜紫。它們導致各種應(yīng)用的顯著改進份蝴,包括醫(yī)學圖像分析、自動駕駛趣避、機器人技術(shù)、衛(wèi)星圖像新翎、視頻監(jiān)控等等程帕。然而,構(gòu)建這些模型通常需要很長時間地啰,但在閱讀本指南后愁拭,您只需幾行代碼就可以構(gòu)建一個模型。
主要內(nèi)容
- 介紹
- 建筑模塊
- 建立一個模型
- 訓練模型
介紹
分割是根據(jù)某些特征或?qū)傩詫D像分成多個片段或區(qū)域的任務(wù)亏吝。分割模型將圖像作為輸入并返回分割掩碼:
分割神經(jīng)網(wǎng)絡(luò)模型由兩部分組成:
- 編碼器:獲取輸入圖像并提取特征岭埠。編碼器的例子有 ResNet、EfficentNet 和 ViT蔚鸥。
- 解碼器:獲取提取的特征并生成分割掩碼惜论。解碼器因架構(gòu)而異。架構(gòu)的例子有 U-Net止喷、FPN 和 DeepLab馆类。
因此,在為特定應(yīng)用構(gòu)建分割模型時启盛,您需要選擇架構(gòu)和編碼器蹦掐。但是,如果不測試幾個僵闯,很難選擇最佳組合卧抗。這通常需要很長時間,因為更改模型需要編寫大量樣板代碼鳖粟。Segmentation Models庫解決了這個問題社裆。它允許您通過指定架構(gòu)和編碼器在一行中創(chuàng)建模型。然后您只需修改該行即可更改其中任何一個向图。
要從 PyPI 安裝最新版本的分段模型泳秀,請使用:
pip install segmentation-models-pytorch
建筑模塊
該庫為大多數(shù)分段架構(gòu)提供了一個類,并且它們中的每一個都可以與任何可用的編碼器一起使用榄攀。在下一節(jié)中嗜傅,您將看到要構(gòu)建模型,您需要實例化所選架構(gòu)的類并將所選編碼器的字符串作為參數(shù)傳遞檩赢。下圖展示了庫提供的各個架構(gòu)的類名:
編碼器有 400 多種吕嘀,因此無法全部顯示,但您可以在此處找到完整列表。
https://github.com/qubvel/segmentation_models.pytorch#encoders
建立一個模型
一旦從上圖中選擇了架構(gòu)和編碼器偶房,構(gòu)建模型就非常簡單:
import segmentation_models_pytorch as smp
model = smp.Unet(
encoder_name="resnet50", # choose encoder
encoder_weights="imagenet", # choose pretrained (not required)
in_channels=3, # model input channels
classes=10, # model output channels
activation="None" # None|"sigmoid"|"softmax"; default is None
)
參數(shù):
-
encoder_name
是所選編碼器的名稱(例如 resnet50趁曼、efficentnet-b7、mit_b5)棕洋。 -
encoder_weights
是預訓練的數(shù)據(jù)集挡闰。如果encoder_weights
等于"imagenet"
編碼器權(quán)重,則使用預訓練的 ImageNet 進行初始化掰盘。所有的編碼器都至少有一個預訓練的摄悯,這里有一個完整的列表。 -
in_channels
是輸入圖像的通道數(shù)(如果是 RGB庆杜,則為 3)射众。
即使in_channels
不是 3,也可以使用預訓練的 ImageNet:第一層將通過重新使用預訓練的第一個卷積層的權(quán)重來初始化(過程在此處描述)晃财。 -
out_classes
是數(shù)據(jù)集中的類數(shù)叨橱。 -
activation
是輸出層的激活函數(shù)《鲜ⅲ可能的選擇是None
(默認)sigmoid
和softmax
罗洗。
注意:當使用期望 logits 作為輸入的損失函數(shù)時,激活函數(shù)必須為 None钢猛。例如伙菜,使用CrossEntropyLoss
函數(shù)時,activation
必須是None
.
訓練模型
本節(jié)顯示執(zhí)行培訓所需的所有代碼命迈。但是贩绕,這個庫不會改變通常用于訓練和驗證模型的管道。為了簡化流程壶愤,該庫提供了許多損失函數(shù)的實現(xiàn)淑倾,例如Jaccard Loss、Dice Loss征椒、Dice Cross-Entropy Loss娇哆、Focal Loss,以及Accuracy勃救、Precision碍讨、Recall、F1Score 和 IOUScore 等指標蒙秒。有關(guān)它們及其參數(shù)的完整列表勃黍,請查看損失和指標部分中的文檔。
提議的訓練示例是使用Oxford-IIIT Pet Dataset 的二進制分割(它將通過代碼下載)晕讲。這是數(shù)據(jù)集中的兩個樣本:
最后覆获,這些是執(zhí)行此類分割任務(wù)的所有步驟:
1.建立模型榜田。
import os
from pprint import pprint
import torch
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.datasets import SimpleOxfordPetDataset
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# I don't use any activation function on the last layer
# because I set from_logits=True on the DiceLoss
model = smp.FPN(
encoder_name='efficientnet-b0',
encoder_weights='imagenet',
in_channels=3,
classes=1,
activation=None
)
model.to(device)
根據(jù)您要使用的損失函數(shù)設(shè)置最后一層的激活函數(shù)。
2. 定義參數(shù)锻梳。
# get_processing_params returns mean and std you should use to normalize the input
params = smp.encoders.get_preprocessing_params('efficientnet-b0')
mean = torch.tensor(params["mean"]).view(1, 3, 1, 1).to(device)
std = torch.tensor(params["std"]).view(1, 3, 1, 1).to(device)
num_epochs = 50
loss_fn = smp.losses.DiceLoss('binary', from_logits=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, verbose=True)
root = 'data'
SimpleOxfordPetDataset.download(root)
train_dataset = SimpleOxfordPetDataset(root, 'train')
val_dataset = SimpleOxfordPetDataset(root, 'valid')
n_cpu = os.cpu_count()
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=n_cpu)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=n_cpu)
請記住,在使用預訓練時净捅,應(yīng)使用用于訓練預訓練的數(shù)據(jù)的均值和標準差對輸入進行歸一化疑枯。
3.定義train函數(shù)。
def train():
best_accuracy = 0.0
for epoch in range(num_epochs):
mean_loss = 0.0
for i, batch in enumerate(train_dataloader):
image = batch["image"].to(device)
mask = batch["mask"].to(device)
# normalize input
image = (image - mean) / std
optimizer.zero_grad()
logits_mask = model(image)
loss = loss_fn(logits_mask, mask)
loss.backward()
optimizer.step()
mean_loss += loss.item()
print(f'[epoch {epoch + 1}, batch {i + 1}/{len(train_dataloader)}] step_loss: {loss.item():.4f}, mean_loss: {(mean_loss / (i + 1)):.4f}')
scheduler.step()
# compute validation metrics of this epoch
metrics = validate()
epoch_accuracy = metrics["accuracy"]
# save the model if accuracy has improved
if epoch_accuracy > best_accuracy:
torch.save(model.state_dict(), 'best_model.pth')
best_accuracy = epoch_accuracy
print(f'For epoch {epoch + 1} the validation metrics are:')
pprint(metrics)
與您在不使用庫的情況下為訓練模型而編寫的訓練函數(shù)相比蛔六,此處沒有任何變化荆永。
4. 定義驗證函數(shù)。
def validate():
with torch.no_grad():
# total true positives, false positives, true negatives and false negatives
total_tp, total_fp, total_fn, total_tn = None, None, None, None
for batch in val_dataloader:
image = batch["image"].to(device)
mask = batch["mask"].to(device).long()
image = (image - mean) / std
logits_mask = model(image)
loss = loss_fn(logits_mask, mask)
# we need to convert the logits to classes to compute metrics
prob_mask = logits_mask.sigmoid()
pred_mask = (prob_mask > 0.5).long()
# computing true positives, false positives, true negatives and false negatives of the batch
tp, fp, fn, tn = smp.metrics.get_stats(pred_mask, mask, mode="binary")
total_tp = torch.cat([total_tp, tp]) if total_tp != None else tp
total_fp = torch.cat([total_fp, fp]) if total_fp != None else fp
total_fn = torch.cat([total_fn, fn]) if total_fn != None else fn
total_tn = torch.cat([total_tn, tn]) if total_tn != None else tn
# metrics are computed using tp, fp, tn, fn values
metrics = {
"loss": loss,
"accuracy": smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro"),
"precision": smp.metrics.precision(tp, fp, fn, tn, reduction="micro"),
"recall": smp.metrics.recall(tp, fp, fn, tn, reduction="micro"),
"f1_score": smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
}
return metrics
批次中的真陽性国章、假陽性具钥、假陰性和真陰性全部加在一起,僅在批次結(jié)束時計算指標液兽。請注意骂删,必須先將 logits 轉(zhuǎn)換為類,然后才能計算指標四啰。調(diào)用訓練函數(shù)開始訓練宁玫。
5.使用模型。
test_dataset = SimpleOxfordPetDataset(root, 'test')
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=n_cpu)
# take a single batch
batch = next(iter(test_dataloader))
model.load_state_dict(torch.load("best_model.pth"))
with torch.no_grad():
model.eval()
image = batch["image"].to(device)
mask = batch["mask"].to(device).long()
image_norm = (image - mean) / std
logits = model(image_norm)
pred_mask = logits.sigmoid()
for i, (im, pr, gt) in enumerate(zip(image, pred_mask, mask)):
fig, axes = plt.subplots(1, 3, figsize=(9, 3))
# show input
axes[0].imshow(im.cpu().numpy().transpose(1, 2, 0))
axes[0].set_title("Image")
axes[0].get_xaxis().set_visible(False)
axes[0].get_yaxis().set_visible(False)
# show prediction
axes[1].imshow(pr.cpu().numpy().squeeze())
axes[1].set_title("Prediction")
axes[1].get_xaxis().set_visible(False)
axes[1].get_yaxis().set_visible(False)
# show target
axes[2].imshow(gt.cpu().numpy().squeeze())
axes[2].set_title("Ground truth")
axes[2].get_xaxis().set_visible(False)
axes[2].get_yaxis().set_visible(False)
plt.tight_layout()
plt.savefig(f"pred_{i}.png")
這些是一些細分:
結(jié)束語
這個庫擁有你進行分割實驗所需的一切柑晒。構(gòu)建模型和應(yīng)用更改非常容易欧瘪,并且提供了大多數(shù)損失函數(shù)和指標。此外匙赞,使用這個庫不會改變我們習慣的管道佛掖。有關(guān)詳細信息,請參閱官方文檔涌庭。我還在參考資料中包含了一些最常見的編碼器和架構(gòu)芥被。
項目參考文獻
[1] O. Ronneberger, P. Fischer and T. Brox, U-Net: Convolutional Networks for Biomedical Image Segmentation (2015)
[2] Z. Zhou, Md. M. R. Siddiquee, N. Tajbakhsh and J. Liang, UNet++: A Nested U-Net Architecture for Medical Image Segmentation (2018)
[3] L. Chen, G. Papandreou, F. Schroff, H. Adam, Rethinking Atrous Convolution for Semantic Image Segmentation (2017)
[4] L. Chen, Y. Zhu, G. Papandreou, F. Schroff, H. Adam, Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation (2018)
[5] R. Li, S. Zheng, C. Duan, C. Zhang, J. Su, P.M. Atkinson, Multi-Attention-Network for Semantic Segmentation of Fine Resolution Remote Sensing Images (2020)
[6] A. Chaurasia, E. Culurciello, LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation (2017)
[7] T. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan, S. Belongie, Feature Pyramid Networks for Object Detection (2017)
[8] H. Zhao, J. Shi, X. Qi, X. Wang, J. Jia, Pyramid Scene Parsing Network (2016)
[9] H. Li, P. Xiong, J. An, L. Wang, Pyramid Attention Network for Semantic Segmentation (2018)
[10] K. Simonyan, A. Zisserman, Very Deep Convolutional Networks for Large-Scale Image Recognition (2014)
[11] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun, Deep Residual Learning for Image Recognition (2015)
[12] S. Xie, R. Girshick, P. Dollár, Z. Tu, K. He, Aggregated Residual Transformations for Deep Neural Networks (2016)
[13] J. Hu, L. Shen, S. Albanie, G. Sun, E. Wu, Squeeze-and-Excitation Networks (2017)
[14] G. Huang, Z. Liu, L. van der Maaten, K. Q. Weinberger, Densely Connected Convolutional Networks (2016)
[15] M. Tan, Q. V. Le, EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks (2019)
[16] E. Xie, W. Wang, Z. Yu, A. Anandkumar, J. M. Alvarez, P. Luo, SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (2021)