Crack dataset
Dataset: https://github.com/cuilimeng/CrackForest-dataset
Layout:
--project
    main.py
    --image
        --train
            --data
            --groundTruth
        --val
            --data
            --groundTruth
I arranged the dataset into this layout by hand: 84 images in train and 34 in val, all saved as .jpg, with each mask sharing its image's filename (the dataset loader below relies on this). A minimal split script follows.
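For reference, here is a minimal sketch (not the author's actual script) of producing this layout, assuming the crack images and their masks have already been exported as same-named .jpg files into two flat folders img/ and gt/; the 84/34 split follows the post:

import os
import random
import shutil

def split_dataset(img_dir, gt_dir, out_root="image", n_train=84, seed=0):
    # shuffle the filenames reproducibly, then copy image/mask pairs
    # into the train/val layout described above
    names = sorted(os.listdir(img_dir))
    random.Random(seed).shuffle(names)
    splits = {"train": names[:n_train], "val": names[n_train:]}
    for split, files in splits.items():
        for sub in ("data", "groundTruth"):
            os.makedirs(os.path.join(out_root, split, sub), exist_ok=True)
        for name in files:
            shutil.copy(os.path.join(img_dir, name),
                        os.path.join(out_root, split, "data", name))
            shutil.copy(os.path.join(gt_dir, name),
                        os.path.join(out_root, split, "groundTruth", name))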
Unet
Paper: http://www.arxiv.org/pdf/1505.04597.pdf
Code source: https://github.com/JavisPeng/u_net_liver
In that repo the author applies U-Net to liver segmentation, which, like crack detection, has only a single mask channel, so we can reuse the code almost as-is. Only dataset.py has to be rewritten for our own data; everything else needs just minor tweaks.
#dataset.py
import os
import PIL.Image as Image
import torch.utils.data as data

def make_dataset(rootdata, roottarget):
    # collect (image path, mask path) pairs; each mask shares its image's filename
    imgs = []
    for name in os.listdir(rootdata):
        img = os.path.join(rootdata, name)
        mask = os.path.join(roottarget, name)
        imgs.append((img, mask))  # store each pair as a tuple
    return imgs

class MyDataset(data.Dataset):
    def __init__(self, rootdata, roottarget, transform=None, target_transform=None):
        self.imgs = make_dataset(rootdata, roottarget)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path).convert('L')  # 'L' gives single-channel grayscale (not binary)
        img_y = Image.open(y_path).convert('L')
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)
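Before wiring this into training, a quick sanity check is worth doing. A small snippet (my addition, assuming the layout above):

from torchvision.transforms import transforms
from dataset import MyDataset

t = transforms.ToTensor()
ds = MyDataset("image/train/data", "image/train/groundTruth",
               transform=t, target_transform=t)
img, mask = ds[0]
print(len(ds), img.shape, mask.shape)  # expect 84 pairs with matching HxW shapes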
#main.py
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from unet import Unet
from dataset import MyDataset

# use CUDA when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x_transforms = transforms.Compose([
    transforms.ToTensor(),
    # single-channel mean/std to match the grayscale input;
    # switching to one channel here fixed the earlier shape-mismatch error
    transforms.Normalize([0.5], [0.5])
])
# the mask only needs to become a tensor
y_transforms = transforms.ToTensor()

def train_model(model, criterion, optimizer, dataload, num_epochs=10):
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        dt_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + update
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d, train_loss: %0.3f" %
                  (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
        print("epoch %d loss: %0.3f" % (epoch, epoch_loss))
        # save weights without calling model.cpu(), which would silently
        # leave all later epochs running on the CPU
        torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
    return model

# train the model
def train():
    batch_size = 1
    crack_dataset = MyDataset(
        "image/train/data", "image/train/groundTruth",
        transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(
        crack_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, dataloaders)

# visualize the model's predictions
def test():
    crack_dataset = MyDataset(
        "image/val/data", "image/val/groundTruth",
        transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(crack_dataset, batch_size=1)
    import matplotlib.pyplot as plt
    plt.ion()
    model.eval()
    with torch.no_grad():
        for x, _ in dataloaders:
            y = model(x.to(device))  # move inputs to the same device as the model
            img_y = torch.squeeze(y).cpu().numpy()
            plt.imshow(img_y)
            plt.pause(0.01)
    plt.show()

if __name__ == '__main__':
    pretrained = False
    model = Unet(1, 1).to(device)
    if pretrained:
        model.load_state_dict(torch.load('./weights_4.pth', map_location=device))
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters())
    train()
    test()
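To keep the predictions instead of just flashing them with matplotlib, one could dump binarized masks to disk. A hedged sketch (my addition; it assumes the network outputs sigmoid probabilities in [0, 1], as BCELoss above requires, and thresholds them at 0.5):

import os
import numpy as np
import torch
from PIL import Image

def save_predictions(model, dataloader, device, out_dir="pred"):
    os.makedirs(out_dir, exist_ok=True)
    model.eval()
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            y = model(x.to(device))                      # probabilities in [0, 1]
            mask = torch.squeeze(y).cpu().numpy() > 0.5  # binarize at 0.5
            Image.fromarray((mask * 255).astype(np.uint8)).save(
                os.path.join(out_dir, "%d.png" % i))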
unet.py needs no changes.
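For readers who don't want to open the repo, here is a condensed sketch of what such a U-Net looks like. This is my approximation, not the repo's exact file (the channel widths in particular are assumptions); the final sigmoid keeps the output in [0, 1] as BCELoss requires, and input height and width must be divisible by 4 for the skip connections in this two-level sketch to line up.

import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    # two 3x3 conv + BN + ReLU blocks, the basic U-Net building unit
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.block(x)

class Unet(nn.Module):
    # two-level encoder/decoder; channel widths are illustrative guesses
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.down1 = DoubleConv(in_ch, 64)
        self.down2 = DoubleConv(64, 128)
        self.bottom = DoubleConv(128, 256)
        self.pool = nn.MaxPool2d(2)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)
        self.out = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        d1 = self.down1(x)                                    # full resolution
        d2 = self.down2(self.pool(d1))                        # 1/2 resolution
        b = self.bottom(self.pool(d2))                        # 1/4 resolution
        u2 = self.dec2(torch.cat([self.up2(b), d2], dim=1))   # skip from d2
        u1 = self.dec1(torch.cat([self.up1(u2), d1], dim=1))  # skip from d1
        return torch.sigmoid(self.out(u1))                    # [0, 1] for BCELoss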
Results
After training for 10 epochs, the accumulated epoch loss comes down to about 3; with 84 training images at batch size 1, that averages to roughly 0.036 BCE loss per image.
The first few prediction images:
For a dataset of only a hundred-odd images, this is a decent result.
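To put a number on "decent" (my addition, not something the post measures), one could compute the mean Dice overlap on the validation set:

import torch

def dice(pred, target, eps=1e-7):
    # pred and target are binary float tensors of the same shape
    inter = (pred * target).sum()
    return (2 * inter + eps) / (pred.sum() + target.sum() + eps)

def evaluate(model, dataloader, device):
    model.eval()
    scores = []
    with torch.no_grad():
        for x, y in dataloader:
            pred = (model(x.to(device)) > 0.5).float().cpu()
            scores.append(dice(pred, (y > 0.5).float()).item())
    return sum(scores) / len(scores)  # mean Dice over the val set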
It also closes out a task I had been putting off for a while.