有一天在做關于物體特征點定位的工作，有一天突發奇想，想要通過pytorch建一個模型進行特征點定位。努力敲了半天代碼，終于實現了，可惜由于自己采的數據集過小(或者是其他原因，歡迎大神賜教)，導致定位結果誤差很大。總的來說還算成功。
我采集的數據集，我在杯子上點了個黑點，然后手工標定得到json文件，想要通過模型定位黑點坐標，奈何。。。不太理想:
先潑代碼:
首先是對應(yīng)標(biāo)簽及樣本的數(shù)據(jù)集以便載入:
#作者:Rayne
#作用:對應json文件中的坐標及文件夾中圖片路徑，以便Dataset模塊載入
import os
import json
def get_img_path(img_path):
    """Return the path of every entry found directly inside ``img_path``.

    Args:
        img_path: Directory to list.

    Returns:
        A list of ``"<img_path>/<entry name>"`` strings, sorted by entry name.
        Every directory entry is included; nothing is filtered by extension.

    Raises:
        OSError: Propagated from ``os.listdir`` when the directory cannot be
            read.
    """
    # os.listdir() yields entries in arbitrary order; sorting makes the
    # sample order reproducible between runs and machines.
    # The explicit '/' is preserved so the returned strings keep the exact
    # form they had before (always a forward slash before the file name).
    return [
        os.path.join(img_path + '/', name)
        for name in sorted(os.listdir(img_path))
    ]
def get_label(label_path):  # e.g. ./label.json
    """Read the annotation file and map each image name to its point.

    The file must hold a JSON list. For every record the first point of the
    first shape (``record['Data']['svgArr'][0]['data'][0]``) is used.

    Args:
        label_path: Path of the UTF-8 encoded JSON annotation file.

    Returns:
        A dict mapping ``record['imageName']`` to ``[x, y]``.

    Raises:
        KeyError / IndexError: If a record lacks one of the expected fields.
    """
    with open(label_path, 'r', encoding='UTF-8') as handle:
        records = json.load(handle)

    coords_by_name = {}
    for record in records:
        first_point = record['Data']['svgArr'][0]['data'][0]
        coords_by_name[record['imageName']] = [first_point['x'], first_point['y']]
    return coords_by_name
def get_all(img_path, label_path):
    """Pair every image path under ``img_path`` with its annotated point.

    Args:
        img_path: Directory containing the images.
        label_path: JSON annotation file understood by ``get_label``.

    Returns:
        A tuple ``(file_path, label)`` of two equally long lists:
        ``file_path[i]`` is an image path and ``label[i]`` the ``[x, y]``
        entry stored for that image's file name.

    Raises:
        KeyError: If an image has no entry in the annotation file.
    """
    file_path = get_img_path(img_path)
    labels = get_label(label_path)
    # Labels are keyed by bare file name. Taking the basename works for any
    # directory depth, whereas the former `split('/')[2]` silently assumed
    # that `img_path` consisted of exactly two path components.
    label = [labels[os.path.basename(file)] for file in file_path]
    return file_path, label
其次是載入文件夾中的數據:ImageLoader.py:
#作者:Rayne
#作用:載入圖片文件及標簽,標簽是[1,2]的list，對應特征點x，y
import torch.utils.data as data
import torch
from PIL import Image
import numpy as np
def default_loader(path):
    """Open the image at ``path`` and return it converted to RGB mode.

    Args:
        path: Location of the image file.

    Returns:
        A new ``PIL.Image.Image`` in ``'RGB'`` mode.
    """
    # convert() produces an independent image, so the source image (and the
    # file handle it may still hold, e.g. for multi-frame formats) can be
    # closed deterministically instead of waiting for garbage collection.
    with Image.open(path) as source:
        return source.convert('RGB')
###############################################
class myImageFloder(data.Dataset):
    """Dataset yielding ``(image, point)`` tensor pairs for point regression.

    Args:
        img: Sequence of image paths.
        label: Sequence of ``[x, y]`` coordinates, aligned with ``img``.
        loader: Callable turning a path into an image. ``None`` (the default)
            selects the module-level ``default_loader``.
        device: Device the returned tensors are moved to. ``None`` (the
            default) selects CUDA when it is available and the CPU otherwise,
            so existing GPU setups behave as before while CPU-only hosts no
            longer crash.
    """

    def __init__(self, img, label, loader=None, device=None):
        self.img = img
        self.label = label
        # Resolved here rather than in the signature; the effective default
        # is still `default_loader`.
        self.loader = default_loader if loader is None else loader
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

    def __getitem__(self, index):
        """Return the ``index``-th sample.

        Returns:
            ``(image, point)`` where ``image`` is a float32 tensor in
            channel-first layout ``(C, H, W)`` with values divided by 256, and
            ``point`` is a float32 tensor ``[x / 540, y / 384]``.
        """
        path = self.img[index]
        point = self.label[index]
        # Open the sample.
        picture = self.loader(path)
        # An RGB image converts to a (height, width, channel) array.
        pixels = np.ascontiguousarray(picture, dtype=np.float32) / 256
        # NOTE(review): 540 and 384 are presumably the image width and
        # height in pixels -- confirm against the data set.
        target = np.array([point[0] / 540.0, point[1] / 384.0], dtype=np.float32)
        # Move the channel axis to the front. The former `.view(3, 540, 384)`
        # only relabelled the flat buffer, which interleaves channels and
        # pixels instead of transposing them.
        image = torch.from_numpy(pixels).permute(2, 0, 1).contiguous().to(self.device)
        target = torch.from_numpy(target).to(self.device)
        return image, target

    def __len__(self):
        """Return the number of samples."""
        return len(self.img)
然后是模型搭建,后來我用了遷移學習:
#作者:Rayne
#作用:博主花費20分鐘搭建的模型
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """Small CNN regressing one normalised ``(x, y)`` point per image.

    Six unpadded 3x3 convolutions, each followed by ReLU and 2x2 max-pooling,
    feed three fully connected layers. ``fc1`` expects ``192 * 24`` features,
    which is what a ``3 x 540 x 384`` (or ``3 x 384 x 540``) input produces
    after the six conv/pool stages.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.conv2 = nn.Conv2d(6, 12, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(12, 24, 3)
        self.conv4 = nn.Conv2d(24, 48, 3)
        self.conv5 = nn.Conv2d(48, 96, 3)
        self.conv6 = nn.Conv2d(96, 192, 3)
        self.fc1 = nn.Linear(192 * 24, 48)
        self.fc2 = nn.Linear(48, 12)
        self.fc3 = nn.Linear(12, 2)

    def forward(self, x):
        """Map a ``(batch, 3, H, W)`` tensor to ``(batch, 2)`` values in (0, 1).

        Raises:
            RuntimeError: If the input size does not yield ``192 * 24``
                features per sample.
        """
        stages = (self.conv1, self.conv2, self.conv3,
                  self.conv4, self.conv5, self.conv6)
        for conv in stages:
            x = self.pool(F.relu(conv(x)))
        # Flatten everything except the batch axis. The former
        # `flatten(x)` + `view(-1, 192*24)` also collapsed the batch axis, so
        # a wrong-sized input could be silently reshaped into extra "samples".
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))
然后是訓練函數啦:
#作者:Rayne
#作用:載入圖片及標簽，定義訓練函數，打印訓練結果。
import model
import torch
import torch.nn as nn
from data import dir_xy, ImageLoader
import torch.optim as optim
# Mean losses recorded once per logging window; filled in by train().
train_hist = []
test_hist = []


def train(net=None, criterion=None, optimizer=None, TrainImgLoader=None,
          TestImgLoader=None, epochs=20, log_every=10):
    """Train ``net``, evaluate it after every epoch and plot the loss history.

    Args:
        net: Model to optimise.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimiser bound to ``net``'s parameters.
        TrainImgLoader: Iterable of ``(inputs, labels)`` training batches.
        TestImgLoader: Iterable of ``(images, labels)`` evaluation batches.
        epochs: Number of passes over ``TrainImgLoader``.
        log_every: Number of batches per logging window (default 10).

    Side effects:
        Appends the mean loss of every completed window to the module-level
        ``train_hist`` / ``test_hist`` lists, prints progress, calls
        ``plter`` once at the end and leaves ``net`` in evaluation mode.

    Note:
        The loss accumulators are only reset when a window completes, so an
        unfinished window carries over into the next epoch.
    """
    running_loss = 0.0
    test_loss = 0.0
    for epoch in range(epochs):  # loop over the dataset multiple times
        # Training mode: layers such as BatchNorm/Dropout use batch behaviour.
        net.train()
        for i, (inputs, labels) in enumerate(TrainImgLoader):
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)  # output shape is [-1, 2]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                mean_loss = running_loss / log_every
                print('第{}圈train Loss: {}'.format(epoch, mean_loss))
                train_hist.append(mean_loss)
                running_loss = 0.0

        # Evaluation mode: without it BatchNorm would normalise with (and
        # update its running statistics from) the evaluation batches.
        net.eval()
        with torch.no_grad():
            for i, (images, labels) in enumerate(TestImgLoader):
                outputs = net(images)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
                if i % log_every == log_every - 1:
                    mean_loss = test_loss / log_every
                    print('第{}圈test Loss: {}'.format(epoch, mean_loss))
                    # Store the window mean (previously the raw sum), so the
                    # history matches the printed value and train_hist.
                    test_hist.append(mean_loss)
                    test_loss = 0.0
    plter(epochs=epochs, train_loss=train_hist, test_loss=test_hist)
    print('Finished Training')
def plter(train_loss, test_loss, epochs):
    """Plot the recorded training and test losses and show the figure.

    Args:
        train_loss: Sequence of training-loss values.
        test_loss: Sequence of test-loss values.
        epochs: Accepted for interface compatibility; each curve is drawn
            against the index of its own entries, so this value is not used.
    """
    import matplotlib.pyplot as plt

    # The unused `x = range(0, epochs)` local and `fig` binding were dropped.
    _, ax = plt.subplots()
    ax.plot(range(len(train_loss)), train_loss, label='train')
    ax.plot(range(len(test_loss)), test_loss, label='test')
    ax.set_xlabel(xlabel='epoch')
    ax.set_ylabel(ylabel='MSE')
    ax.set_title('Epochs VS MSE')
    ax.legend()
    plt.show()
最后是主函數:
#作者:Rayne
#作用:定義優化器，模型,損失函數等并進行訓練
import model
import torch
import torch.nn as nn
from data import dir_xy, ImageLoader
import torch.optim as optim
import train_test
import torchvision
def train():
    """Fine-tune a pretrained ResNet-18 to regress a 2-value point.

    Builds the model, loss and optimiser, loads the train/test image lists
    with their labels, and hands everything to ``train_test.train``.
    """
    torch.set_default_tensor_type(torch.FloatTensor)

    # The hand-written model.Net() can be used here instead of ResNet-18.
    backbone = torchvision.models.resnet18(pretrained=True)
    # Replace the classification head with a 2-output regression layer.
    backbone.fc = nn.Linear(backbone.fc.in_features, 2)
    backbone = backbone.cuda()

    loss_fn = nn.MSELoss()
    adam = optim.Adam(backbone.parameters(), lr=0.00001)

    train_paths, train_points = dir_xy.get_all(
        img_path='data/train', label_path='data/label/train.json')
    train_set = ImageLoader.myImageFloder(train_paths, train_points)
    train_batches = torch.utils.data.DataLoader(
        train_set, batch_size=10, shuffle=True)

    test_paths, test_points = dir_xy.get_all(
        img_path='data/test', label_path='data/label/test.json')
    test_set = ImageLoader.myImageFloder(test_paths, test_points)
    # Default DataLoader settings: one sample per batch, no shuffling.
    test_batches = torch.utils.data.DataLoader(test_set)

    train_test.train(
        net=backbone.cuda(),
        criterion=loss_fn,
        optimizer=adam,
        TrainImgLoader=train_batches,
        TestImgLoader=test_batches,
    )


train()
結果:
下降到后面，尤其是10個循環后不太明顯，我相信擁有更多的數據后會得到更好的結果。希望可以幫到你~