U-Net is a network for image semantic segmentation, which has the computer partition an image according to its semantic content. For example, given the image below as input, the network can output the corresponding segmented image.
In the original image, the objects fall into three classes: 1. background, 2. person, 3. bicycle.
Semantic segmentation has many uses; in the satellite-image example above, over multiple rounds of training the Prediction gradually converges to the Ground Truth.
The U-Net architecture is shown below; the whole network is shaped like the letter U. Simply put, it has two halves: the left half performs feature extraction, and as the layers deepen the channel count grows while the "image" (feature map) shrinks; the right half restores the features, so the whole network is in effect an encoder-decoder. Note that the most distinctive part of the network is the gray arrows: during encoding some information is lost (to MaxPooling and Conv2D), so during decoding the features from the corresponding encoder level are mixed back in. In the figure, every decoder level has a "white" block of "image" (encoder features) concatenated onto it.
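To make the "thicker but smaller, then restore" idea concrete, here is a minimal shape trace of one encoder step and one decoder step with a skip connection (a standalone sketch, not part of the model defined later):

import torch
import torch.nn as nn

x = torch.randn(1, 3, 192, 192)            # (batch, channels, H, W)
feat = nn.Conv2d(3, 64, 3, padding=1)(x)   # channels 3 -> 64: "thicker"
down = nn.MaxPool2d(2)(feat)               # 192 -> 96: "smaller", some detail is lost
up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)(down)
skip = torch.cat([up, feat], dim=1)        # the gray arrow: reattach the encoder features
print(feat.shape, down.shape, up.shape, skip.shape)
# [1, 64, 192, 192] [1, 64, 96, 96] [1, 64, 192, 192] [1, 128, 192, 192]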
This raises a question: why build such an elaborate encoder-decoder, when a simple stack of convolutions like the one in the figure above can also produce a semantic segmentation?
The reason is that as the convolution kernels grow larger, the parameter count grows multiplicatively with them: computation slows down dramatically, and training becomes harder to converge. On this point I strongly recommend the paper "Visualizing and Understanding Convolutional Networks".
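As a rough sanity check on that claim: a k×k convolution with c_in input and c_out output channels has k·k·c_in·c_out weights, so the cost grows quadratically in the kernel size. A toy calculation (bias ignored):

def conv_params(k, c_in, c_out):
    return k * k * c_in * c_out   # weights only, bias ignored

print(conv_params(3, 64, 64))   # 36864
print(conv_params(7, 64, 64))   # 200704, about 5.4x more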
Now for how U-Net works. Suppose we have an image, as on the left; depending on what we need, we convert each region to be recognized into a specific "encoding" that serves as its class label.
In practice, each object to be recognized gets its own channel: however many objects there are to recognize, that is how many output channels the network has, and overlaying those channels gives the final segmentation we want.
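A toy sketch of the "one channel per object, then overlay" idea (the argmax overlay here is my simplification, assuming classes do not overlap; the masks_to_colorimg helper below does the actual color overlay):

import numpy as np

masks = np.zeros((6, 192, 192), dtype=np.float32)   # one channel per class
masks[2, 50:80, 50:80] = 1.0                        # class 2 occupies a square region
# collapse the channels into a single label map; 255 marks "no class" pixels
label_map = np.where(masks.max(axis=0) > 0.5, masks.argmax(axis=0), 255)
print(np.unique(label_map))   # [  2 255]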
Below I take a simple example to illustrate how U-Net works; the source code is on GitHub (linked here), and I add some explanatory notes along the way.
1. First, import the required packages
%matplotlib inline
%load_ext autoreload
%autoreload 2
import os, sys
import random
import copy
import itertools
import time
from functools import reduce
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
from torchsummary import summary
2. Generate the simulated data. Don't dwell on this code; copying and pasting it is fine.
def generate_random_data(height, width, count):
    x, y = zip(*[generate_img_and_mask(height, width) for i in range(0, count)])
    X = np.asarray(x) * 255
    X = X.repeat(3, axis=1).transpose([0, 2, 3, 1]).astype(np.uint8)
    Y = np.asarray(y)
    return X, Y

def generate_img_and_mask(height, width):
    shape = (height, width)
    triangle_location = get_random_location(*shape)
    circle_location1 = get_random_location(*shape, zoom=0.7)
    circle_location2 = get_random_location(*shape, zoom=0.5)
    mesh_location = get_random_location(*shape)
    square_location = get_random_location(*shape, zoom=0.8)
    plus_location = get_random_location(*shape, zoom=1.2)
    # Create input image
    arr = np.zeros(shape, dtype=bool)
    arr = add_triangle(arr, *triangle_location)
    arr = add_circle(arr, *circle_location1)
    arr = add_circle(arr, *circle_location2, fill=True)
    arr = add_mesh_square(arr, *mesh_location)
    arr = add_filled_square(arr, *square_location)
    arr = add_plus(arr, *plus_location)
    arr = np.reshape(arr, (1, height, width)).astype(np.float32)
    # Create target masks
    masks = np.asarray([
        add_filled_square(np.zeros(shape, dtype=bool), *square_location),
        add_circle(np.zeros(shape, dtype=bool), *circle_location2, fill=True),
        add_triangle(np.zeros(shape, dtype=bool), *triangle_location),
        add_circle(np.zeros(shape, dtype=bool), *circle_location1),
        add_filled_square(np.zeros(shape, dtype=bool), *mesh_location),
        # add_mesh_square(np.zeros(shape, dtype=bool), *mesh_location),
        add_plus(np.zeros(shape, dtype=bool), *plus_location)
    ]).astype(np.float32)
    return arr, masks

def add_square(arr, x, y, size):
    s = int(size / 2)
    arr[x-s, y-s:y+s] = True
    arr[x+s, y-s:y+s] = True
    arr[x-s:x+s, y-s] = True
    arr[x-s:x+s, y+s] = True
    return arr

def add_filled_square(arr, x, y, size):
    s = int(size / 2)
    xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
    return np.logical_or(arr, logical_and([xx > x - s, xx < x + s, yy > y - s, yy < y + s]))

def logical_and(arrays):
    new_array = np.ones(arrays[0].shape, dtype=bool)
    for a in arrays:
        new_array = np.logical_and(new_array, a)
    return new_array

def add_mesh_square(arr, x, y, size):
    s = int(size / 2)
    xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
    return np.logical_or(arr, logical_and([xx > x - s, xx < x + s, xx % 2 == 1, yy > y - s, yy < y + s, yy % 2 == 1]))

def add_triangle(arr, x, y, size):
    s = int(size / 2)
    triangle = np.tril(np.ones((size, size), dtype=bool))
    arr[x-s:x-s+triangle.shape[0], y-s:y-s+triangle.shape[1]] = triangle
    return arr

def add_circle(arr, x, y, size, fill=False):
    xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
    circle = np.sqrt((xx - x) ** 2 + (yy - y) ** 2)
    new_arr = np.logical_or(arr, np.logical_and(circle < size, circle >= size * 0.7 if not fill else True))
    return new_arr

def add_plus(arr, x, y, size):
    s = int(size / 2)
    arr[x-1:x+1, y-s:y+s] = True
    arr[x-s:x+s, y-1:y+1] = True
    return arr

def get_random_location(width, height, zoom=1.0):
    x = int(width * random.uniform(0.1, 0.9))
    y = int(height * random.uniform(0.1, 0.9))
    size = int(min(width, height) * random.uniform(0.06, 0.12) * zoom)
    return (x, y, size)
def plot_img_array(img_array, ncol=3):
    nrow = len(img_array) // ncol
    # squeeze=False keeps the axes grid 2-D even when nrow == 1
    f, plots = plt.subplots(nrow, ncol, sharex='all', sharey='all', figsize=(ncol * 4, nrow * 4), squeeze=False)
    for i in range(len(img_array)):
        plots[i // ncol, i % ncol].imshow(img_array[i])
def plot_side_by_side(img_arrays):
    flatten_list = reduce(lambda x, y: x + y, zip(*img_arrays))
    plot_img_array(np.array(flatten_list), ncol=len(img_arrays))

def plot_errors(results_dict, title):
    markers = itertools.cycle(('+', 'x', 'o'))
    plt.title('{}'.format(title))
    for label, result in sorted(results_dict.items()):
        plt.plot(result, marker=next(markers), label=label)
    plt.ylabel('dice_coef')
    plt.xlabel('epoch')
    plt.legend(loc=3, bbox_to_anchor=(1, 0))
    plt.show()

def masks_to_colorimg(masks):
    colors = np.asarray([(201, 58, 64), (242, 207, 1), (0, 152, 75), (101, 172, 228), (56, 34, 132), (160, 194, 56)])
    colorimg = np.ones((masks.shape[1], masks.shape[2], 3), dtype=np.float32) * 255
    channels, height, width = masks.shape
    for y in range(height):
        for x in range(width):
            selected_colors = colors[masks[:, y, x] > 0.5]
            if len(selected_colors) > 0:
                colorimg[y, x, :] = np.mean(selected_colors, axis=0)
    return colorimg.astype(np.uint8)
3. Take a look at the input data and the class-label data
# Generate an image and its class labels (192×192, count=1)
input_images, target_masks = generate_random_data(192, 192, count=1)
print(f'Input data shape: {input_images.shape}')
print(f'Target data shape: {target_masks.shape}')
# Cast the dtype for plotting
input_images_rgb = [x.astype(np.uint8) for x in input_images]
# Map the grayscale label masks onto one RGB image (channels: 6 -> 3)
target_masks_rgb = [masks_to_colorimg(x) for x in target_masks]
# Display the simulated data
plot_side_by_side([input_images_rgb, target_masks_rgb])
['out']: Input data shape: (1, 192, 192, 3)
['out']: Target data shape: (1, 6, 192, 192)
The training data is one RGB image of shape (192, 192, 3 (RGB channels)); the class labels are a stack of grayscale images of shape (6, 192, 192): each shape to be recognized gets its own grayscale mask, six shapes in all.
The left image is the input; the right image maps the grayscale label masks to RGB and overlays all six (the model itself only needs to predict the six grayscale masks).
4. The data generator
# A simple PyTorch Dataset over the simulated data
class SimDataset(Dataset):
    def __init__(self, count, transform=None):
        # count: number of samples to generate
        # transform: transform applied to each input image
        self.input_images, self.target_masks = generate_random_data(192, 192, count=count)
        self.transform = transform

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        image = self.input_images[idx]
        mask = self.target_masks[idx]
        if self.transform:
            image = self.transform(image)
        return [image, mask]
# use same transform for train/val for this example
trans = transforms.Compose([
    transforms.ToTensor(),
])
# Generate 2000 simulated samples as the training set and 200 as the validation set
train_set = SimDataset(2000, transform=trans)
val_set = SimDataset(200, transform=trans)
batch_size = 25
dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
}
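A quick sanity check on the loaders: each batch should hold 25 RGB images and 25 six-channel label stacks:

inputs, masks = next(iter(dataloaders['train']))
print(inputs.shape, masks.shape)
# torch.Size([25, 3, 192, 192]) torch.Size([25, 6, 192, 192])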
The U-Net network
# U-Net encoder block: two (3×3 Conv + ReLU) layers, as in the figure above
# In the original U-Net, padding=0 (no fill), so the "image" shrinks:
# 572*572 ---> 570*570 ---> 568*568
def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )
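Because this version uses padding=1 rather than the paper's padding=0, double_conv changes only the channel count and preserves H and W, which lets the later torch.cat calls line up without cropping. A quick check:

block = double_conv(3, 64)
print(block(torch.randn(1, 3, 192, 192)).shape)
# torch.Size([1, 64, 192, 192]): channels change, spatial size does not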
5. Define the network
# Each double_conv doubles the channel count (makes the feature map "thicker"); MaxPool then halves the spatial size (makes it "smaller")
class Unet(nn.Module):
    def __init__(self, n_class):
        super().__init__()
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # bilinear interpolation
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)
        self.conv_last = nn.Conv2d(64, n_class, 1)  # final layer: one output channel per target class (n_class)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)  # corresponds to encoder level 2 in the figure above
        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)
        x = self.dconv_down4(x)  # bottom of the U
        x = self.upsample(x)  # bilinear upsampling restores the "image" size
        # Concatenating decoder features with the matching encoder features grows the channel count
        # and compensates for the information lost to plain upsampling; this is the key step
        # (the "thickening" at decoder level 1 in the figure)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)  # 256 + 128 channels
        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)
        out = self.conv_last(x)
        return out
# Print the network structure
model = Unet(6)
summary(model, input_size=(3, 224, 224), device='cpu')  # torchsummary defaults to CUDA; pass device='cpu' when no GPU is available
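A forward-pass shape check: the output keeps the input's H×W and has one channel of raw logits per class (no sigmoid applied yet):

out = model(torch.randn(1, 3, 192, 192))
print(out.shape)   # torch.Size([1, 6, 192, 192])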
6. The loss function
def dice_loss(pred, target, smooth=1.):
    pred = pred.contiguous()
    target = target.contiguous()
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    return loss.mean()
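A quick sanity check of dice_loss: a perfect prediction gives a loss of 0, a complete miss gives a loss near 1 (the smooth term keeps the ratio finite):

target = torch.zeros(1, 1, 4, 4)
target[..., :2, :2] = 1.
print(dice_loss(target, target).item())      # 0.0: perfect overlap
print(dice_loss(1 - target, target).item())  # ~0.94: no overlap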
# A weighted sum of two losses: BCE and Dice
def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)
    pred = torch.sigmoid(pred)  # F.sigmoid is deprecated; torch.sigmoid is equivalent
    dice = dice_loss(pred, target)
    loss = bce * bce_weight + dice * (1 - bce_weight)
    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)
    return loss
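Usage sketch: calc_loss expects raw logits and accumulates per-sample loss sums into a metrics dict (toy tensors of my own here):

metrics = defaultdict(float)
logits = torch.randn(2, 6, 192, 192)                    # stand-in for raw network output
target = torch.randint(0, 2, (2, 6, 192, 192)).float()  # stand-in binary masks
loss = calc_loss(logits, target, metrics)
print(loss.item(), metrics['bce'] / 2, metrics['dice'] / 2)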
def print_metrics(metrics, epoch_samples, phase):
    outputs = []
    for k in metrics.keys():
        outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples))
    print("{}: {}".format(phase, ", ".join(outputs)))
def train_model(model, optimizer, scheduler, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        since = time.time()
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode
            metrics = defaultdict(float)
            epoch_samples = 0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                epoch_samples += inputs.size(0)
            if phase == 'train':
                scheduler.step()  # in PyTorch >= 1.1, step the LR scheduler once per epoch, after the optimizer updates
            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f}'.format(best_loss))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
7. Train the model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_class = 6
model = Unet(num_class).to(device)
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)
model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=40)
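Finally, a minimal inference sketch on freshly simulated data, reusing masks_to_colorimg and plot_side_by_side from above (my addition; the columns are input / ground truth / prediction):

model.eval()
test_set = SimDataset(3, transform=trans)
test_loader = DataLoader(test_set, batch_size=3, shuffle=False, num_workers=0)
inputs, labels = next(iter(test_loader))
with torch.no_grad():
    pred = torch.sigmoid(model(inputs.to(device))).cpu().numpy()  # logits -> per-channel probabilities
input_images_rgb = [(x.numpy().transpose(1, 2, 0) * 255).astype(np.uint8) for x in inputs]
target_masks_rgb = [masks_to_colorimg(x) for x in labels.numpy()]
pred_rgb = [masks_to_colorimg(x) for x in pred]
plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])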