Pytorch之圖像分割(多目標(biāo)分割,Multi Object Segmentation)

示例調(diào)用預(yù)訓(xùn)練模型(deeplabv3_resnet101)對VOCSegmentation數(shù)據(jù)進行圖像分割實驗枪萄。

  • PyTorch的DeepLabv3-ResNet101語義分割模型是在COCO 2017訓(xùn)練集上的一個子集訓(xùn)練得到的隐岛,相當(dāng)于PASCAL VOC數(shù)據(jù)集,支持20個類別瓷翻。
  • Deeplabv3-ResNet101由具有ResNet-101主干的Deeplabv3模型構(gòu)成聚凹。

引入相關(guān)包

%matplotlib inline
import os
import copy
import numpy as np
from skimage.segmentation import mark_boundaries
import matplotlib.pylab as plt
from PIL import Image   

import torch
from torch import nn
from torch import optim
from torchvision.datasets import VOCSegmentation
from torchvision.transforms.functional import to_tensor, to_pil_image
from torch.utils.data import DataLoader
from torchvision.models.segmentation import deeplabv3_resnet101
from torch.optim.lr_scheduler import ReduceLROnPlateau

構(gòu)建數(shù)據(jù) dataset

class DemoVOCSegmentation(VOCSegmentation):
    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        if self.transforms is not None:
            augmented = self.transforms(image=np.array(img), mask=np.array(target))
            img = augmented['image']
            target = augmented['mask']                  
            target[target>20] = 0

        img = to_tensor(img)            
        target = torch.from_numpy(target).type(torch.long)
        return img, target
    
    
from albumentations import (
    HorizontalFlip,
    Compose,
    Resize,
    Normalize)

mean = [0.485, 0.456, 0.406] 
std = [0.229, 0.224, 0.225]
h, w = 520,520

transform_train = Compose([ Resize(h,w),
                HorizontalFlip(p=0.5), 
                Normalize(mean=mean, std=std)])

transform_val = Compose([ Resize(h,w),
                          Normalize(mean=mean, std=std)])

 數(shù)據(jù)地址
path_data = "./data/mos/"    
# 創(chuàng)建dataset
train_ds = DemoVOCSegmentation(path_data, 
                year='2012', 
                image_set='train', 
                download=False, 
                transforms=transform_train) 
print(len(train_ds))
# 1464


val_ds = DemoVOCSegmentation(path_data, 
                year='2012', 
                image_set='val', 
                download=False, 
                transforms=transform_val)
print(len(val_ds)) #1449
  • 數(shù)據(jù)查看(可視化)
np.random.seed(0)
num_classes =21
COLORS = np.random.randint(0, 2, size=(num_classes+1, 3),dtype="uint8")

def show_img_target(img, target):
    if torch.is_tensor(img):
        img = to_pil_image(img)
        target = target.numpy()
    for ll in range(num_classes):
        mask = (target==ll)
        img = mark_boundaries(np.array(img) , 
                            mask,
                            outline_color=COLORS[ll],
                            color=COLORS[ll])
    plt.imshow(img)
    

def re_normalize (x, mean = mean, std= std):
    x_r= x.clone()
    for c, (mean_c, std_c) in enumerate(zip(mean, std)):
        x_r [c] *= std_c
        x_r [c] += mean_c
    return x_r



img, mask = train_ds[6]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))

plt.figure(figsize=(20,20))

img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))

plt.subplot(1, 3, 2) 
plt.imshow(mask)

plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)
"""
torch.Size([3, 520, 520]) torch.FloatTensor tensor(2.6400)
torch.Size([520, 520]) torch.LongTensor tensor(4)
"""
image segmentation

數(shù)據(jù)加載器及加載模型

# dataloader
train_dl = DataLoader(train_ds, batch_size=2, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=8, shuffle=False)
# 加載預(yù)訓(xùn)練模型
model=deeplabv3_resnet101(pretrained=True, num_classes=21)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model=model.to(device)
# print(model)

模型部署

model.eval()
with torch.no_grad():
    for xb, yb in val_dl:
        yb_pred = model(xb.to(device))
        yb_pred = yb_pred["out"].cpu()
        print(yb_pred.shape)    
        yb_pred = torch.argmax(yb_pred,axis=1)
        break
print(yb_pred.shape)

plt.figure(figsize=(20,20))

n=4
img, mask= xb[n], yb_pred[n]
img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))

plt.subplot(1, 3, 2) 
plt.imshow(mask)

plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)
"""
torch.Size([16, 21, 520, 520])
torch.Size([16, 520, 520])
"""
deploy model to predict

模型訓(xùn)練(因為電腦顯卡太低割坠,微調(diào)訓(xùn)練無法實驗測試)

def get_lr(opt):
    for param_group in opt.param_groups:
        return param_group['lr']

def loss_batch(loss_func, output, target, opt=None):   
    loss = loss_func(output, target)
    
    if opt is not None:
        opt.zero_grad()
        loss.backward()
        opt.step()

    return loss.item(), None

# 訓(xùn)練模型
def loss_epoch(model,loss_func,dataset_dl,sanity_check=False,opt=None):
    running_loss = 0.0
    len_data = len(dataset_dl.dataset)

    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)
        
        output = model(xb)["out"]
        loss_b, _ = loss_batch(loss_func, output, yb, opt)
        running_loss += loss_b
        
        if sanity_check is True:
            break
    
    loss = running_loss / float(len_data)
    return loss, None

def train_val(model, params):
    num_epochs=params["num_epochs"]
    loss_func=params["loss_func"]
    opt=params["optimizer"]
    train_dl=params["train_dl"]
    val_dl=params["val_dl"]
    sanity_check=params["sanity_check"]
    lr_scheduler=params["lr_scheduler"]
    path2weights=params["path2weights"]
    
    loss_history={
        "train": [],
        "val": []}
    
    metric_history={
        "train": [],
        "val": []}    
    
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss=float('inf')    
    
    for epoch in range(num_epochs):
        current_lr=get_lr(opt)
        print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))   

        model.train()
        train_loss, train_metric=loss_epoch(model,loss_func,train_dl,sanity_check,opt)

        loss_history["train"].append(train_loss)
        metric_history["train"].append(train_metric)
        
        model.eval()
        with torch.no_grad():
            val_loss, val_metric=loss_epoch(model,loss_func,val_dl,sanity_check)
       
        loss_history["val"].append(val_loss)
        metric_history["val"].append(val_metric)   
        
        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            
            torch.save(model.state_dict(), path2weights)
            print("Copied best model weights!")
            
        lr_scheduler.step(val_loss)
        if current_lr != get_lr(opt):
            print("Loading best model weights!")
            model.load_state_dict(best_model_wts) 
            
        print("train loss: %.6f" %(train_loss))
        print("val loss: %.6f" %(val_loss))
        print("-"*10) 
    model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history
  • 訓(xùn)練模型
criterion = nn.CrossEntropyLoss(reduction="sum")
opt = optim.Adam(model.parameters(), lr=1e-6)
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)

path2models= "./models/mos/"
if not os.path.exists(path2models):
        os.mkdir(path2models)

params_train={
    "num_epochs": 10,
    "optimizer": opt,
    "loss_func": criterion,
    "train_dl": train_dl,
    "val_dl": val_dl,
    "sanity_check": True,
    "lr_scheduler": lr_scheduler,
    "path2weights": path2models+"sanity_weights.pt",
}

model, loss_hist, _ = train_val(model, params_train)
  • 可視化結(jié)果
num_epochs=params_train["num_epochs"]

plt.title("Train-Val Loss")
plt.plot(range(1,num_epochs+1),loss_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),loss_hist["val"],label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()
image.png
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市妒牙,隨后出現(xiàn)的幾起案子彼哼,更是在濱河造成了極大的恐慌,老刑警劉巖湘今,帶你破解...
    沈念sama閱讀 207,113評論 6 481
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件敢朱,死亡現(xiàn)場離奇詭異,居然都是意外死亡摩瞎,警方通過查閱死者的電腦和手機拴签,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,644評論 2 381
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來旗们,“玉大人蚓哩,你說我怎么就攤上這事∩峡剩” “怎么了岸梨?”我有些...
    開封第一講書人閱讀 153,340評論 0 344
  • 文/不壞的土叔 我叫張陵,是天一觀的道長稠氮。 經(jīng)常有香客問我盛嘿,道長,這世上最難降的妖魔是什么括袒? 我笑而不...
    開封第一講書人閱讀 55,449評論 1 279
  • 正文 為了忘掉前任,我火速辦了婚禮稿茉,結(jié)果婚禮上锹锰,老公的妹妹穿的比我還像新娘。我一直安慰自己漓库,他們只是感情好恃慧,可當(dāng)我...
    茶點故事閱讀 64,445評論 5 374
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著渺蒿,像睡著了一般痢士。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上茂装,一...
    開封第一講書人閱讀 49,166評論 1 284
  • 那天怠蹂,我揣著相機與錄音,去河邊找鬼少态。 笑死城侧,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的彼妻。 我是一名探鬼主播嫌佑,決...
    沈念sama閱讀 38,442評論 3 401
  • 文/蒼蘭香墨 我猛地睜開眼豆茫,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了屋摇?” 一聲冷哼從身側(cè)響起揩魂,我...
    開封第一講書人閱讀 37,105評論 0 261
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎炮温,沒想到半個月后火脉,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 43,601評論 1 300
  • 正文 獨居荒郊野嶺守林人離奇死亡茅特,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,066評論 2 325
  • 正文 我和宋清朗相戀三年忘分,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片白修。...
    茶點故事閱讀 38,161評論 1 334
  • 序言:一個原本活蹦亂跳的男人離奇死亡妒峦,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出兵睛,到底是詐尸還是另有隱情肯骇,我是刑警寧澤,帶...
    沈念sama閱讀 33,792評論 4 323
  • 正文 年R本政府宣布祖很,位于F島的核電站笛丙,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏假颇。R本人自食惡果不足惜胚鸯,卻給世界環(huán)境...
    茶點故事閱讀 39,351評論 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望笨鸡。 院中可真熱鬧姜钳,春花似錦、人聲如沸形耗。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,352評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽激涤。三九已至拟糕,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間倦踢,已是汗流浹背送滞。 一陣腳步聲響...
    開封第一講書人閱讀 31,584評論 1 261
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留辱挥,地道東北人累澡。 一個月前我還...
    沈念sama閱讀 45,618評論 2 355
  • 正文 我出身青樓,卻偏偏與公主長得像般贼,于是被迫代替她去往敵國和親愧哟。 傳聞我的和親對象是個殘疾皇子奥吩,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 42,916評論 2 344

推薦閱讀更多精彩內(nèi)容