Unet圖像分割

Unet網(wǎng)絡(luò)是一種圖像語義分割網(wǎng)絡(luò),圖像語義分割網(wǎng)絡(luò)讓計算機根據(jù)圖像的語義來進行分割,例如讓計算機在輸入下面下圖,能夠輸出指定分割的圖片规惰。
基本圖片分割

原圖中文狱,物體被分為三類壕曼,1.背景衅鹿, 2.人泵三, 3.自行車

地理信息

語義分割的用處很多较曼,比如說上圖中分割衛(wèi)星圖伏恐,通過多倫迭代,Prediction逐漸與Grond Truth一致斗幼。

Unet網(wǎng)絡(luò)結(jié)構(gòu)

Unet網(wǎng)絡(luò)結(jié)構(gòu)如下毁兆,整個網(wǎng)絡(luò)形如字母U畔咧。簡單的來說,整個網(wǎng)絡(luò)分為兩個部分虹蓄,左邊部分負責特征提取律胀,隨著網(wǎng)絡(luò)層加深,網(wǎng)絡(luò)的channel逐漸變大,"圖片"逐漸變小。右邊的網(wǎng)絡(luò)負責特征的還原旷坦,整個網(wǎng)絡(luò)實際上就是一個編碼-解碼器默勾。需要注意的是辆琅,整個網(wǎng)絡(luò)最出彩的地方是灰色箭頭的部分婉烟。在編碼的過程中似袁,部分信息丟失了(Maxpooling和Conv2D)扬霜。在解碼時啼县,加入與之對應的編碼層信息季眷。從圖上來看的話就是右邊每一層網(wǎng)絡(luò)都加入了一部分"白"色的"圖片"(特征)。

全連接語義分割

那么這里就有個問題兼搏,為什么要這么復雜的做一個編碼-解碼器佛呻?上圖的一個簡單的多層卷積就可以完成圖像語義分割。


編碼解碼器

原因就在于隨著卷積核的越大,伴隨著參數(shù)就會成倍增長诫肠,一是運算效率會大大下降,其次不利于收斂丧鸯。這里強烈推薦看一篇文章“看懂”卷積神經(jīng)網(wǎng)(Visualizing and Understanding Convolutional Networks)

工作原理1

這里講一下,Unet工作原理,假設(shè)我們有一張圖片派敷,如左圖所示,我們會根據(jù)實際需要將需要識別的區(qū)域轉(zhuǎn)化為特定的"編碼"作為類標簽试躏。


工作原理2

工作原理3

實際上每個需要識別的物體需要一個channel,有多少個需要識別的物體,就有多少個輸出channel寡键,最后再做一個疊加就是最終我們想分割的結(jié)果。

下面哪一個簡單的實例代碼來說明Unet的工作原理,源代碼Github在這里,下面我做一些解釋性說明

1.首先引入必要包
%matplotlib inline
%load_ext autoreload
%autoreload 2
import os, sys
import random
import copy
import itertools
import time
from functools import reduce
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
from torchsummary import summary
2.生成模擬數(shù)據(jù),這一部分不用太糾結(jié)代碼韭邓,復制粘貼就可以
def generate_random_data(height, width, count):
    x, y = zip(*[generate_img_and_mask(height, width) for i in range(0, count)])
    X = np.asarray(x) * 255
    X = X.repeat(3, axis=1).transpose([0, 2, 3, 1]).astype(np.uint8)
    Y = np.asarray(y)
    return X, Y

def generate_img_and_mask(height, width):
    shape = (height, width)
    triangle_location = get_random_location(*shape)
    circle_location1 = get_random_location(*shape, zoom=0.7)
    circle_location2 = get_random_location(*shape, zoom=0.5)
    mesh_location = get_random_location(*shape)
    square_location = get_random_location(*shape, zoom=0.8)
    plus_location = get_random_location(*shape, zoom=1.2)

    # Create input image
    arr = np.zeros(shape, dtype=bool)
    arr = add_triangle(arr, *triangle_location)
    arr = add_circle(arr, *circle_location1)
    arr = add_circle(arr, *circle_location2, fill=True)
    arr = add_mesh_square(arr, *mesh_location)
    arr = add_filled_square(arr, *square_location)
    arr = add_plus(arr, *plus_location)
    arr = np.reshape(arr, (1, height, width)).astype(np.float32)

    # Create target masks
    masks = np.asarray([
        add_filled_square(np.zeros(shape, dtype=bool), *square_location),
        add_circle(np.zeros(shape, dtype=bool), *circle_location2, fill=True),
        add_triangle(np.zeros(shape, dtype=bool), *triangle_location),
        add_circle(np.zeros(shape, dtype=bool), *circle_location1),
         add_filled_square(np.zeros(shape, dtype=bool), *mesh_location),
        # add_mesh_square(np.zeros(shape, dtype=bool), *mesh_location),
        add_plus(np.zeros(shape, dtype=bool), *plus_location)
    ]).astype(np.float32)
    return arr, masks

def add_square(arr, x, y, size):
    s = int(size / 2)
    arr[x-s,y-s:y+s] = True
    arr[x+s,y-s:y+s] = True
    arr[x-s:x+s,y-s] = True
    arr[x-s:x+s,y+s] = True
    return arr

def add_filled_square(arr, x, y, size):
    s = int(size / 2)
    xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
    return np.logical_or(arr, logical_and([xx > x - s, xx < x + s, yy > y - s, yy < y + s]))

def logical_and(arrays):
    new_array = np.ones(arrays[0].shape, dtype=bool)
    for a in arrays:
        new_array = np.logical_and(new_array, a)
    return new_array

def add_mesh_square(arr, x, y, size):
    s = int(size / 2)
    xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
    return np.logical_or(arr, logical_and([xx > x - s, xx < x + s, xx % 2 == 1, yy > y - s, yy < y + s, yy % 2 == 1]))

def add_triangle(arr, x, y, size):
    s = int(size / 2)
    triangle = np.tril(np.ones((size, size), dtype=bool))
    arr[x-s:x-s+triangle.shape[0],y-s:y-s+triangle.shape[1]] = triangle
    return arr

def add_circle(arr, x, y, size, fill=False):
    xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
    circle = np.sqrt((xx - x) ** 2 + (yy - y) ** 2)
    new_arr = np.logical_or(arr, np.logical_and(circle < size, circle >= size * 0.7 if not fill else True))
    return new_arr

def add_plus(arr, x, y, size):
    s = int(size / 2)
    arr[x-1:x+1,y-s:y+s] = True
    arr[x-s:x+s,y-1:y+1] = True
    return arr

def get_random_location(width, height, zoom=1.0):
    x = int(width * random.uniform(0.1, 0.9))
    y = int(height * random.uniform(0.1, 0.9))
    size = int(min(width, height) * random.uniform(0.06, 0.12) * zoom)
    return (x, y, size)

def plot_img_array(img_array, ncol=3):
    nrow = len(img_array) // ncol
    f, plots = plt.subplots(nrow, ncol, sharex='all', sharey='all', figsize=(ncol * 4, nrow * 4))
    for i in range(len(img_array)):
        plots[i // ncol, i % ncol]
        plots[i // ncol, i % ncol].imshow(img_array[i])

def plot_side_by_side(img_arrays):
    flatten_list = reduce(lambda x,y: x+y, zip(*img_arrays))
    plot_img_array(np.array(flatten_list), ncol=len(img_arrays))

def plot_errors(results_dict, title):
    markers = itertools.cycle(('+', 'x', 'o'))
    plt.title('{}'.format(title))
    for label, result in sorted(results_dict.items()):
        plt.plot(result, marker=next(markers), label=label)
        plt.ylabel('dice_coef')
        plt.xlabel('epoch')
        plt.legend(loc=3, bbox_to_anchor=(1, 0))
    plt.show()

def masks_to_colorimg(masks):
    colors = np.asarray([(201, 58, 64), (242, 207, 1), (0, 152, 75), (101, 172, 228),(56, 34, 132), (160, 194, 56)])
    colorimg = np.ones((masks.shape[1], masks.shape[2], 3), dtype=np.float32) * 255
    channels, height, width = masks.shape
    for y in range(height):
        for x in range(width):
            selected_colors = colors[masks[:,y,x] > 0.5]
            if len(selected_colors) > 0:
                colorimg[y,x,:] = np.mean(selected_colors, axis=0)
    return colorimg.astype(np.uint8)
3.看一下輸入數(shù)據(jù)和類標簽數(shù)據(jù)
# 生成圖片與類標簽(192*192, 3張)
input_images, target_masks = generate_random_data(192, 192, count=1)
print(f'輸入數(shù)據(jù)維度:{input_images.shape}')
print(f'輸出數(shù)據(jù)維度:{target_masks.shape}')
# 修改數(shù)據(jù)類型,方便畫圖
input_images_rgb = [x.astype(np.uint8) for x in input_images]
# 將灰度圖片(channel=1)變?yōu)镽GB圖片(channel=3)
target_masks_rgb = [masks_to_colorimg(x) for x in target_masks]
# 顯示模擬圖片
plot_side_by_side([input_images_rgb, target_masks_rgb])

['out']:輸入數(shù)據(jù)維度:(1, 192, 192, 3)
['out']:輸出數(shù)據(jù)維度:(1, 6, 192, 192)

訓練數(shù)據(jù)一個(192袜茧,192纳鼎,3(RGB通道))的RGB圖片, 類標簽數(shù)據(jù)是一組灰度圖片(6逗宁,192件甥,192)引有,每個需要識別的圖形是一個灰度圖片一共6個圖形倦逐。
模擬數(shù)據(jù)

左圖為輸入數(shù)據(jù),右圖中將類標簽灰度圖片加了RBG通道您单,然后6張圖疊加的效果圖(我們只需預測6張灰度圖即可)。

4.數(shù)據(jù)生成器
# 一個簡單的pytorch數(shù)據(jù)迭代器
class SimDataset(Dataset):
    def __init__(self, count, transform=None):
        # count:每次需要生成的數(shù)據(jù)量
        # transform指定數(shù)據(jù)轉(zhuǎn)化器
        self.input_images, self.target_masks = generate_random_data(192, 192, count=count)        
        self.transform = transform

    def __len__(self):
        return len(self.input_images)
    
    def __getitem__(self, idx):
        image = self.input_images[idx]
        mask = self.target_masks[idx]
        if self.transform:
            image = self.transform(image)
        return [image, mask]
# use same transform for train/val for this example
trans = transforms.Compose([
    transforms.ToTensor(),
])
# 這里生成2000組模擬數(shù)據(jù)作為訓練集, 200組模擬數(shù)據(jù)作為測試集
train_set = SimDataset(2000, transform = trans)
val_set = SimDataset(200, transform = trans)
batch_size = 25
dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
}

Unet網(wǎng)絡(luò)

Unet編碼層
# Unet編碼層, 如上圖所示,包含兩個(卷積+Relu)
# 原始Unet網(wǎng)絡(luò)中padding=0(填充)俺驶,所以"圖片"會變小
# 572*572--->570*570--->568*568
def double_conv(in_channels, out_channels):
  return nn.Sequential(
      nn.Conv2d(in_channels, out_channels, 3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_channels, out_channels, 3, padding=1),
      nn.ReLU(inplace=True)
  )
Unet編碼層2

Unet解碼層1
5.定義網(wǎng)絡(luò)
# Unet經(jīng)過一次double_conv通道數(shù)加倍(變厚)楚昭,然后使用Maxpool, "圖片"維度/2(變小)
class Unet(nn.Module):
  def __init__(self, n_class):
    super().__init__()
    self.dconv_down1 = double_conv(3, 64)
    self.dconv_down2 = double_conv(64, 128)
    self.dconv_down3 = double_conv(128, 256)
    self.dconv_down4 = double_conv(256, 512)
    self.maxpool = nn.MaxPool2d(2)
    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # 這里使用雙線性插值
    self.dconv_up3 = double_conv(256 + 512, 256)
    self.dconv_up2 = double_conv(128 + 256, 128)
    self.dconv_up1 = double_conv(128 + 64, 64)
    self.conv_last = nn.Conv2d(64, n_class, 1) # 最后一層, 需要識別多少種目標,則輸出多少個channel(n_class)

  def forward(self, x):
    conv1 = self.dconv_down1(x)
    x = self.maxpool(conv1) # 對應上圖Unet編碼層2
    conv2 = self.dconv_down2(x)
    x = self.maxpool(conv2)
    conv3 = self.dconv_down3(x)
    x = self.maxpool(conv3)
    x = self.dconv_down4(x) #到底了
    x = self.upsample(x) # 雙線性插值电媳,還原"圖片"
    # 解碼數(shù)據(jù)與對應編碼數(shù)據(jù)concat使channel數(shù)增加, 彌補了單純上采樣導致的信息還原不足
    # 這一步很關(guān)鍵(也就是圖Unet解碼層1中數(shù)據(jù)變"厚")
    x = torch.cat([x, conv3], dim=1) 
    x = self.dconv_up3(x)
    x = self.upsample(x)        
    x = torch.cat([x, conv2], dim=1) # 256+128
    x = self.dconv_up2(x)# 
    x = self.upsample(x)        
    x = torch.cat([x, conv1], dim=1)
    x = self.dconv_up1(x)
    out = self.conv_last(x)
    return out
# 這里打印一下網(wǎng)絡(luò)結(jié)構(gòu)
model = Unet(6)
summary(model, input_size=(3, 224, 224))
數(shù)值化網(wǎng)絡(luò)結(jié)構(gòu)
6.損失函數(shù)
def dice_loss(pred, target, smooth = 1.):
    pred = pred.contiguous()
    target = target.contiguous()    
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    return loss.mean()
# 這里使用兩種損失函數(shù)加權(quán)
def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target) 
    pred = F.sigmoid(pred)
    dice = dice_loss(pred, target)
    loss = bce * bce_weight + dice * (1 - bce_weight)
    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)
    return loss

def print_metrics(metrics, epoch_samples, phase):    
    outputs = []
    for k in metrics.keys():
        outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples))
    print("{}: {}".format(phase, ", ".join(outputs))) 

def train_model(model, optimizer, scheduler, num_epochs=25):
  best_model_wts = copy.deepcopy(model.state_dict())
  best_loss = 1e10
  for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-'*10)
    since = time.time()
    for phase in ['train', 'val']:
      if phase == 'train':
        scheduler.step()
        model.train()  # Set model to training mode
      else:
        model.eval()   # Set model to evaluate mode
      metrics = defaultdict(float)
      epoch_samples = 0

      for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        with torch.set_grad_enabled(phase == 'train'):
          outputs = model(inputs)
          loss = calc_loss(outputs, labels, metrics)
          if phase == 'train':
            loss.backward()
            optimizer.step()
        epoch_samples += inputs.size(0)
      print_metrics(metrics, epoch_samples, phase)
      epoch_loss = metrics['loss'] / epoch_samples

      if phase == 'val' and epoch_loss < best_loss:
        print("saving best model")
        best_loss = epoch_loss
        best_model_wts = copy.deepcopy(model.state_dict())
    time_elapsed = time.time() - since
    print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
  print('Best val loss: {:4f}'.format(best_loss))
  # load best model weights
  model.load_state_dict(best_model_wts)
  return model
7.訓練模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_class = 6
model = Unet(num_class).to(device)
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)
model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=40)
訓練結(jié)果
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末铃辖,一起剝皮案震驚了整個濱河市娇斩,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌歉嗓,老刑警劉巖,帶你破解...
    沈念sama閱讀 212,454評論 6 493
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件志珍,死亡現(xiàn)場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機还棱,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,553評論 3 385
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人稚补,你說我怎么就攤上這事课幕。” “怎么了杜秸?”我有些...
    開封第一講書人閱讀 157,921評論 0 348
  • 文/不壞的土叔 我叫張陵润绎,是天一觀的道長。 經(jīng)常有香客問我呢蛤,道長,這世上最難降的妖魔是什么其障? 我笑而不...
    開封第一講書人閱讀 56,648評論 1 284
  • 正文 為了忘掉前任坝撑,我火速辦了婚禮,結(jié)果婚禮上抚笔,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好狱从,可當我...
    茶點故事閱讀 65,770評論 6 386
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著敞葛,像睡著了一般与涡。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上驼卖,一...
    開封第一講書人閱讀 49,950評論 1 291
  • 那天,我揣著相機與錄音怎囚,去河邊找鬼桥胞。 笑死考婴,一個胖子當著我的面吹牛井誉,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播颗圣,決...
    沈念sama閱讀 39,090評論 3 410
  • 文/蒼蘭香墨 我猛地睜開眼在岂,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了蔽午?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 37,817評論 0 268
  • 序言:老撾萬榮一對情侶失蹤抽莱,失蹤者是張志新(化名)和其女友劉穎骄恶,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體僧鲁,經(jīng)...
    沈念sama閱讀 44,275評論 1 303
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,592評論 2 327
  • 正文 我和宋清朗相戀三年斟叼,在試婚紗的時候發(fā)現(xiàn)自己被綠了春寿。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 38,724評論 1 341
  • 序言:一個原本活蹦亂跳的男人離奇死亡馋缅,死狀恐怖绢淀,靈堂內(nèi)的尸體忽然破棺而出瘾腰,到底是詐尸還是另有隱情,我是刑警寧澤蹋盆,帶...
    沈念sama閱讀 34,409評論 4 333
  • 正文 年R本政府宣布硝全,位于F島的核電站伟众,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏凳厢。R本人自食惡果不足惜竞慢,卻給世界環(huán)境...
    茶點故事閱讀 40,052評論 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望筹煮。 院中可真熱鬧,春花似錦本冲、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,815評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至闷板,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間遮晚,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,043評論 1 266
  • 我被黑心中介騙來泰國打工糜颠, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留萧求,地道東北人。 一個月前我還...
    沈念sama閱讀 46,503評論 2 361
  • 正文 我出身青樓夸政,卻偏偏與公主長得像,于是被迫代替她去往敵國和親匀归。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 43,627評論 2 350

推薦閱讀更多精彩內(nèi)容

  • 主要內(nèi)容包括: 1袱贮、基于boosting級聯(lián)學習的遙感目標檢測 Adaboost算法、前向分步算法攒巍、提升樹、梯度提...
    大概是只翻車魚的莫方閱讀 2,932評論 0 2
  • 這些年計算機視覺識別和搜索這個領(lǐng)域非常熱鬧窑业,后期出現(xiàn)了很多的創(chuàng)業(yè)公司枕屉,大公司也在這方面也花了很多力氣在做。做視覺搜...
    方弟閱讀 6,466評論 6 24
  • 月輪國西潘,地處西域,崇尚佛法喷市,共轄七府四十二縣一百六十八鎮(zhèn)。府縣鎮(zhèn)各級都設(shè)有寺廟品姓。佛在月輪國為尊箫措,連帶著和尚的地位也...
    三小山閑話閱讀 638評論 0 2
  • 昨天晚上十點下班,挺累的斤蔓,只因昨天周三,店里搞活動弦牡,人比較多。 天氣逐漸轉(zhuǎn)冷驾锰,騎車走在下班的路上,雖是十點椭豫,在上海...
    愛佳蓉閱讀 120評論 0 0
  • 自己從建庫開始完成按key1鍵實現(xiàn)紅綠藍燈轉(zhuǎn)換买喧,按key2蜂鳴器響,松手停今缚,用中斷形式,其中包括中斷子函數(shù)的編寫
    李欣l閱讀 207評論 0 0