The example dataset is Fetal-Head-Circumference (fetal ultrasound images with annotated head boundaries).
Download: https://zenodo.org/record/1322001#.YTHD2Y4zaUl
(figure: sample Fetal-Head-Circumference image)
Model architecture: U-Net (the implementation below is a simplified encoder-decoder variant without skip connections)
(figure: U-Net architecture diagram)
Further reading: https://github.com/pranjalrai-iitd/Fetal-head-segmentation-and-circumference-measurement-from-ultrasound-images
Imports
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.pylab as plab
from PIL import Image, ImageDraw
import numpy as np
import pandas as pd
import os
import copy
import collections
from sklearn.model_selection import ShuffleSplit
from scipy import ndimage as ndi
from skimage.segmentation import mark_boundaries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchvision.transforms as transforms
from torchvision import models,utils, datasets
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, to_pil_image
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from albumentations import (HorizontalFlip, VerticalFlip, Compose, Resize,)
from torchsummary import summary
# CPU or GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# num_workers controls multi-process data loading in DataLoader (0 on Windows)
workers = 0 if os.name=='nt' else 4
A first look at the data
# data directory
path_train="./data/sos/training_set/"
imgs_list = [pp for pp in os.listdir(path_train) if "Annotation" not in pp and pp.endswith('.png')]
annts_list = [pp for pp in os.listdir(path_train) if "Annotation" in pp and pp.endswith('.png')]
print("number of images:", len(imgs_list))
print("number of annotations:", len(annts_list))
"""
number of images: 999
number of annotations: 999
"""
# sample a few images
np.random.seed(2019)
rnd_imgs = np.random.choice(imgs_list, 4)
print('The random images are: ', rnd_imgs)
# The random images are: ['166_2HC.png' '434_HC.png' '244_HC.png' '826_3HC.png']
# overlay a mask boundary on an image
def show_img_mask(img, mask):
    # convert tensors to PIL images for plotting
    if torch.is_tensor(img):
        img = to_pil_image(img)
        mask = to_pil_image(mask)
    # draw the mask boundary in green on top of the image
    img_mask = mark_boundaries(
        np.array(img),
        np.array(mask),
        outline_color=(0, 1, 0),
        color=(0, 1, 0)
    )
    plt.imshow(img_mask)
# plot a few image/mask pairs
for fn in rnd_imgs:
    img_path = os.path.join(path_train, fn)
    annt_path = img_path.replace(".png", "_Annotation.png")
    img = Image.open(img_path)
    annt_edges = Image.open(annt_path)
    # the annotation marks the head boundary; fill it to obtain a solid mask
    mask = ndi.binary_fill_holes(annt_edges)
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 3, 1)
    plt.imshow(img, cmap="gray")
    plt.subplot(1, 3, 2)
    plt.imshow(mask, cmap="gray")
    plt.subplot(1, 3, 3)
    show_img_mask(img, mask)
(figure: sample images, filled masks, and boundary overlays)
Building the Dataset, transforms, and DataLoader
# transforms
h, w = 128, 192
transform_train = Compose([
    Resize(h, w),
    HorizontalFlip(p=0.5),
    VerticalFlip(p=0.5),
])
transform_val = Resize(h, w)
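Albumentations operates on numpy arrays and returns a dict with 'image' and 'mask' keys, which is why the Dataset below unpacks `augmented['image']` and `augmented['mask']`. A minimal sketch of a single application (assuming the `img` and `mask` variables left over from the exploration loop above, and that albumentations accepts 2-D grayscale arrays):
# sketch: apply the training transform once to an image/mask pair
augmented = transform_train(image=np.array(img), mask=mask.astype('uint8'))
img_aug, mask_aug = augmented['image'], augmented['mask']
print(img_aug.shape, mask_aug.shape)  # both resized to (128, 192)
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.imshow(img_aug, cmap="gray")
plt.subplot(1, 2, 2)
plt.imshow(mask_aug, cmap="gray")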
# define the dataset
class FetalDataset(Dataset):
    def __init__(self, path_data, transform=None):
        # list images under path_data (the original listed the global path_train here by mistake)
        imgs_list = [pp for pp in os.listdir(path_data) if "Annotation" not in pp and pp.endswith('.png')]
        self.path_imgs = [os.path.join(path_data, fn) for fn in imgs_list]
        # annotation paths are derived from the image paths
        self.path_annts = [path_img.replace('.png', '_Annotation.png') for path_img in self.path_imgs]
        self.transform = transform

    def __len__(self):
        return len(self.path_imgs)

    def __getitem__(self, idx):
        path_img = self.path_imgs[idx]
        image = Image.open(path_img)
        path_annt = self.path_annts[idx]
        annt_edges = Image.open(path_annt)
        # fill the annotated outline to obtain a solid binary mask
        mask = ndi.binary_fill_holes(annt_edges)
        image = np.array(image)
        mask = mask.astype('uint8')
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        image = to_tensor(image)
        # to_tensor scales uint8 by 1/255; multiply back to keep the mask in {0, 1}
        mask = 255 * to_tensor(mask)
        return image, mask
# instantiate the datasets
fetal_train_ds = FetalDataset(path_train, transform=transform_train)
fetal_val_ds = FetalDataset(path_train, transform=transform_val)
# print(len(fetal_train_ds))
# print(len(fetal_val_ds))
# split into training and validation sets
sss = ShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
indices = range(len(fetal_train_ds))
for train_index, val_index in sss.split(indices):
    train_ds = Subset(fetal_train_ds, train_index)
    print(len(train_ds))
    val_ds = Subset(fetal_val_ds, val_index)
    print(len(val_ds))
plt.figure(figsize=(5, 5))
for img, mask in train_ds:
    show_img_mask(img, mask)
    break
# build the dataloaders (using the `workers` setting defined above)
train_dl = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=workers)
val_dl = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=workers)
# inspect one batch
for img, mask in train_dl:
    print(img.shape, img.dtype)
    # torch.Size([8, 1, 128, 192]) torch.float32
    print(mask.shape, mask.dtype)
    # torch.Size([8, 1, 128, 192]) torch.float32
    break
"""
799
200
torch.Size([8, 1, 128, 192]) torch.float32
torch.Size([8, 1, 128, 192]) torch.float32
"""
(figure: a transformed image with mask overlay)
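A whole batch can also be inspected at once with torchvision's make_grid (an optional sketch using the `train_dl` and the `utils` module imported above):
# optional sketch: show one training batch as an image grid
imgs_b, masks_b = next(iter(train_dl))
grid = utils.make_grid(imgs_b, nrow=4, pad_value=1.0)
plt.figure(figsize=(10, 5))
plt.imshow(to_pil_image(grid))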
Model definition
# encoder-decoder segmentation model (a simplified U-Net-style network without skip connections)
class SegNet(nn.Module):
    def __init__(self, params):
        super(SegNet, self).__init__()
        C_in, H_in, W_in = params['input_shape']
        init_f = params['initial_filters']
        num_outputs = params['num_outputs']
        # encoder convolutions, doubling the channel count at each stage
        self.conv1 = nn.Conv2d(C_in, init_f, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(init_f, 2*init_f, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(2*init_f, 4*init_f, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(4*init_f, 8*init_f, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(8*init_f, 16*init_f, kernel_size=3, stride=1, padding=1)
        # upsampling and decoder convolutions
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_up1 = nn.Conv2d(16*init_f, 8*init_f, kernel_size=3, stride=1, padding=1)
        self.conv_up2 = nn.Conv2d(8*init_f, 4*init_f, kernel_size=3, stride=1, padding=1)
        self.conv_up3 = nn.Conv2d(4*init_f, 2*init_f, kernel_size=3, stride=1, padding=1)
        self.conv_up4 = nn.Conv2d(2*init_f, init_f, kernel_size=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(init_f, num_outputs, kernel_size=3, padding=1)

    def forward(self, x):
        # encoder: conv + ReLU + 2x2 max-pool, halving the resolution each stage
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv5(x))
        # decoder: bilinear upsample + conv, restoring the input resolution
        x = self.upsample(x)
        x = F.relu(self.conv_up1(x))
        x = self.upsample(x)
        x = F.relu(self.conv_up2(x))
        x = self.upsample(x)
        x = F.relu(self.conv_up3(x))
        x = self.upsample(x)
        x = F.relu(self.conv_up4(x))
        # output raw logits; sigmoid is applied in the loss and at inference
        x = self.conv_out(x)
        return x
params_model = {
    "input_shape": (1, 128, 192),
    "initial_filters": 16,
    "num_outputs": 1,
}
model = SegNet(params_model).to(device)
# print(model)
# """
# SegNet(
# (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (upsample): Upsample(scale_factor=2.0, mode=bilinear)
# (conv_up1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv_up2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv_up3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv_up4): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv_out): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# )
# """
# model summary
summary(model, input_size=(1, 128, 192))
# ----------------------------------------------------------------
# Layer (type) Output Shape Param #
# ================================================================
# Conv2d-1 [-1, 16, 128, 192] 160
# Conv2d-2 [-1, 32, 64, 96] 4,640
# Conv2d-3 [-1, 64, 32, 48] 18,496
# Conv2d-4 [-1, 128, 16, 24] 73,856
# Conv2d-5 [-1, 256, 8, 12] 295,168
# Upsample-6 [-1, 256, 16, 24] 0
# Conv2d-7 [-1, 128, 16, 24] 295,040
# Upsample-8 [-1, 128, 32, 48] 0
# Conv2d-9 [-1, 64, 32, 48] 73,792
# Upsample-10 [-1, 64, 64, 96] 0
# Conv2d-11 [-1, 32, 64, 96] 18,464
# Upsample-12 [-1, 32, 128, 192] 0
# Conv2d-13 [-1, 16, 128, 192] 4,624
# Conv2d-14 [-1, 1, 128, 192] 145
# ================================================================
# Total params: 784,385
# Trainable params: 784,385
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.09
# Forward/backward pass size (MB): 22.88
# Params size (MB): 2.99
# Estimated Total Size (MB): 25.96
# ----------------------------------------------------------------
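Before training, a forward pass on a dummy batch (a quick sanity-check sketch) confirms that the network maps (1, 128, 192) inputs to single-channel logits of the same spatial size:
# sanity check: dummy forward pass
with torch.no_grad():
    dummy = torch.zeros(2, 1, 128, 192, device=device)
    out = model(dummy)
print(out.shape)  # expected: torch.Size([2, 1, 128, 192])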
Loss function: the Dice metric
The Dice coefficient, named after Lee Raymond Dice, is a set-similarity measure commonly used to compute the similarity of two samples (its value lies in [0, 1]):

    Dice(X, Y) = 2|X ∩ Y| / (|X| + |Y|)

|X ∩ Y| is the size of the intersection of X and Y; |X| and |Y| are the numbers of elements in X and Y. The factor 2 in the numerator compensates for the common elements of X and Y being counted twice in the denominator.
The corresponding Dice loss is:

    Dice loss = 1 - Dice(X, Y)
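A tiny worked example with made-up masks: if X and Y each contain 4 foreground pixels and share 3 of them, Dice = 2·3 / (4 + 4) = 0.75, so the Dice loss is 0.25:
# worked example of the Dice formula on toy binary masks
x = torch.tensor([1., 1., 1., 1., 0., 0.])
y = torch.tensor([1., 1., 1., 0., 1., 0.])
dice = 2.0 * (x * y).sum() / (x.sum() + y.sum())
print(dice.item())  # 0.75, i.e. Dice loss = 1 - 0.75 = 0.25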
## loss function
# the Dice coefficient is a set-similarity measure, commonly used to compute the similarity of two samples; its value lies in [0, 1]
# https://blog.csdn.net/JMU_Ma/article/details/97533768 , https://zhuanlan.zhihu.com/p/86704421
def dice_loss(pred, target, smooth=1e-5):
    # per-sample intersection and union over the spatial dimensions
    intersection = (pred * target).sum(dim=(2, 3))
    union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
    dice = 2.0 * (intersection + smooth) / (union + smooth)
    loss = 1.0 - dice
    # both are summed over the batch; loss_epoch divides by the dataset size
    return loss.sum(), dice.sum()
def loss_func(pred, target):
    # combined loss: binary cross-entropy on logits + Dice loss on probabilities
    bce = F.binary_cross_entropy_with_logits(pred, target, reduction='sum')
    pred = torch.sigmoid(pred)
    dlv, _ = dice_loss(pred, target)
    loss = bce + dlv
    return loss
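A quick sanity check of `loss_func` on random tensors (a sketch; note that both terms use sum reduction over the whole batch, which matches the division by the dataset size in `loss_epoch` below):
# sanity check: combined BCE + Dice loss on a random batch
logits = torch.randn(8, 1, 128, 192)
targets = torch.randint(0, 2, (8, 1, 128, 192)).float()
print(loss_func(logits, targets).item())  # a positive scalar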
Model design and training
A few helper functions
# get the current learning rate
def get_lr(opt):
    for param_group in opt.param_groups:
        return param_group['lr']
# evaluation metric: Dice coefficient summed over the batch
def metrics_batch(pred, target):
    pred = torch.sigmoid(pred)
    _, metric = dice_loss(pred, target)
    return metric
# per-batch loss and metric; runs an optimizer step when opt is given
def loss_batch(loss_func, output, target, opt=None):
    loss = loss_func(output, target)
    with torch.no_grad():
        pred = torch.sigmoid(output)
        _, metric_b = dice_loss(pred, target)
    if opt is not None:
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item(), metric_b
# per-epoch loss and metric over a dataloader
def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
    running_loss = 0.0
    running_metric = 0.0
    len_data = len(dataset_dl.dataset)
    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)
        output = model(xb)
        loss_b, metric_b = loss_batch(loss_func, output, yb, opt)
        running_loss += loss_b
        if metric_b is not None:
            running_metric += metric_b
        if sanity_check is True:
            break
    # average over the dataset (loss_batch returns per-batch sums)
    loss = running_loss / float(len_data)
    metric = running_metric / float(len_data)
    return loss, metric
Model training and validation
Main training function
# main train/validation loop
def train_val(model, params):
    num_epochs = params["num_epochs"]
    loss_func = params["loss_func"]
    opt = params["optimizer"]
    train_dl = params["train_dl"]
    val_dl = params["val_dl"]
    sanity_check = params["sanity_check"]
    lr_scheduler = params["lr_scheduler"]
    path2weights = params["path2weights"]

    loss_history = {"train": [], "val": []}
    metric_history = {"train": [], "val": []}

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    for epoch in range(num_epochs):
        current_lr = get_lr(opt)
        print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))

        # training pass
        model.train()
        train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)
        loss_history["train"].append(train_loss)
        metric_history["train"].append(train_metric)

        # validation pass
        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)
        loss_history["val"].append(val_loss)
        metric_history["val"].append(val_metric)

        # keep the best weights (lowest validation loss)
        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), path2weights)
            print("Copied best model weights!")

        # reduce lr on plateau; reload the best weights whenever the lr drops
        lr_scheduler.step(val_loss)
        if current_lr != get_lr(opt):
            print("Loading best model weights!")
            model.load_state_dict(best_model_wts)

        # the metric here is the Dice coefficient, not classification accuracy
        print("train loss: %.6f, dice: %.2f" % (train_loss, 100 * train_metric))
        print("val loss: %.6f, dice: %.2f" % (val_loss, 100 * val_metric))
        print("-" * 10)

    model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history
Training
# optimizer and learning-rate schedule
opt = optim.Adam(model.parameters(), lr=3e-4)
lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=20, verbose=1)
path_models = "./models/sos/"
if not os.path.exists(path_models):
    os.makedirs(path_models)  # makedirs also creates the parent ./models directory
params_train = {
    "num_epochs": 10,
    "optimizer": opt,
    "loss_func": loss_func,
    "train_dl": train_dl,
    "val_dl": val_dl,
    "sanity_check": False,
    "lr_scheduler": lr_scheduler,
    "path2weights": path_models + "weights.pt",
}
model, loss_hist, metric_hist = train_val(model,params_train)
Visualizing the results
# plot loss progress
num_epochs = params_train["num_epochs"]
plt.title("Train-Val Loss")
plt.plot(range(1, num_epochs + 1), loss_hist["train"], label="train")
plt.plot(range(1, num_epochs + 1), loss_hist["val"], label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()
# plot metric progress (the Dice coefficient)
plt.title("Train-Val Dice")
plt.plot(range(1, num_epochs + 1), metric_hist["train"], label="train")
plt.plot(range(1, num_epochs + 1), metric_hist["val"], label="val")
plt.ylabel("Dice")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()
(figure: train-val loss curve)
(figure: train-val Dice curve)
Deployment and testing
# deploy the model and validate it on the test set
# deployment would normally require building the network first; since `model` already exists above, it is reused here without re-instantiation
np.random.seed(2019)
path_test = './data/sos/test_set/'
imgs_list = [pp for pp in os.listdir(path_test) if "Annotation" not in pp]
rnd_imgs = np.random.choice(imgs_list, 4)
print(rnd_imgs)
model_weights_path = './models/sos/weights.pt'
# map_location ensures GPU-trained weights also load on a CPU-only machine
model.load_state_dict(torch.load(model_weights_path, map_location=device))
model.eval()
for fn in rnd_imgs:
    path_img = os.path.join(path_test, fn)
    img = Image.open(path_img)
    img = img.resize((w, h))
    img_t = to_tensor(img).unsqueeze(0).to(device)
    # forward pass, then threshold the sigmoid output at 0.5
    with torch.no_grad():
        pred = model(img_t)
    pred = torch.sigmoid(pred)[0]
    mask_pred = (pred[0] >= 0.5)
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 3, 1)
    plt.imshow(img, cmap="gray")
    plt.subplot(1, 3, 2)
    plt.imshow(mask_pred.cpu(), cmap="gray")
    plt.subplot(1, 3, 3)
    show_img_mask(img, mask_pred.cpu())
(figure: test images with predicted segmentation results)
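The dataset's end goal is head-circumference measurement, so one possible post-processing step is to trace the predicted mask's boundary and sum its length. The sketch below is only illustrative: it assumes the `mask_pred` from the loop above, adds an extra `skimage.measure` import, and reports pixels at the 128×192 working resolution; a real measurement would rescale to the original image size and use the per-image pixel spacing shipped with the dataset.
# illustrative sketch: estimate the circumference of the predicted mask
from skimage import measure
mask_np = mask_pred.cpu().numpy().astype(float)
contours = measure.find_contours(mask_np, 0.5)  # boundary curves of the mask
if contours:
    contour = max(contours, key=len)  # longest contour = head boundary
    steps = np.diff(contour, axis=0)
    perimeter_px = np.sqrt((steps ** 2).sum(axis=1)).sum()
    print("estimated circumference: %.1f px" % perimeter_px)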