街景字符識別比賽所用的數(shù)據(jù)集包括圖像和JSON標(biāo)注。訓(xùn)練集數(shù)據(jù)包括3W張照片甲馋,驗證集數(shù)據(jù)包括1W張照片慈迈。數(shù)據(jù)的標(biāo)注使用JSON格式,并使用文件名進(jìn)行索引戳吝。對于賽題和數(shù)據(jù)集的更多信息浩销,可參考街景字符編碼識別-賽題解析。
下面我們將構(gòu)建讀取比賽的數(shù)據(jù)集听哭,首先生成數(shù)據(jù)名列表的csv文件以方便后面dataloader處理:
import os
import csv
DirList = os.listdir(ImgPath)
## write data list
with open(outPath+'train.csv', 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
for row in DirList:
writer.writerow([row])
讀入圖像數(shù)據(jù)和JSON數(shù)據(jù):
## 讀入圖像
from PIL import Image
im = Image.open(mainPath+'mchar_train/mchar_train/000000.png')
## 讀入JSON
## 讀入后為一個字典對象慢洋,key為圖像名,value為對應(yīng)標(biāo)簽
import json
import numpy as np
with open(json_trainDir, 'r') as f:
data = json.load(f)
print(data['000000.png'])
print(data['000000.png']['label'])
輸出
{'height': [219, 219],
'label': [1, 9],
'left': [246, 323],
'top': [77, 81],
'width': [81, 96]}
[1, 9]
讀入數(shù)據(jù)后陆盘,在載入網(wǎng)絡(luò)前為了增加訓(xùn)練集數(shù)據(jù)數(shù)量和類型普筹,我們要進(jìn)行數(shù)據(jù)增廣。
數(shù)據(jù)增廣包括幾何變換類如平移隘马,旋轉(zhuǎn)太防,翻轉(zhuǎn),縮放酸员;圖像色彩分布改變?nèi)缰狈綀D均衡蜒车,亮度色度調(diào)整讳嘱。也有一些針對特定任務(wù)的如加噪等。此處詳細(xì)可參考數(shù)據(jù)增廣之詳細(xì)理解
對于街景字符識別任務(wù)酿愧,可利用torchvision很方便的進(jìn)行數(shù)據(jù)增廣沥潭。關(guān)于torchvision中transforms的使用,可參考pytorch中transform常用的幾個方法
我們將賽題抽象為一個定長字符識別問題嬉挡,在賽題數(shù)據(jù)集中大部分圖像中字符個數(shù)為2-4個钝鸽,最多的字符個數(shù)為6個。因此將問題抽象為6個字符的識別問題棘伴,字符abc填充為abcXXX寞埠,X取"10",“0”~ "9"對于標(biāo)簽0~9焊夸。
import numpy as np
import os
import csv
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
class TrainDataLoader(Dataset):
def __init__(self, root, csvPath, json_Dir):
data = []
self.root = root
with open(csvPath, 'r') as csvfile:
csv_reader = csv.reader(csvfile)
for row in csv_reader:
data.append(row[0])
with open(json_Dir, 'r') as f:
info = json.load(f)
self.dataList = data
self.InfoDict = info
self.num = len(self.dataList)
def __len__(self):
return self.num
def ImgProcess(self, img):
# ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
# RandomRotation(degrees, resample=False, expand=False, center=None) 在(-degrees,+degrees)之間隨機旋轉(zhuǎn)
# transforms.ToTensor, 將PIL Image或者 ndarray 轉(zhuǎn)換為tensor仁连,并且歸一化至[0-1]
transform = transforms.Compose([transforms.Resize((64, 128)),
transforms.ColorJitter(0.3, 0.3, 0.2),
transforms.RandomRotation(30),
transforms.ToTensor()])
imTensor = transform(img) # H,W,C,N
return imTensor
def __getitem__(self, idx):
# print('data path: ', self.root+self.dataList[idx])
imgName = self.dataList[idx]
img = Image.open(self.root+imgName)
imgInfo = self.InfoDict[imgName]
imgTensor = self.ImgProcess(img)
label = imgInfo['label']
label += [10]*(6-len(label)) ## 標(biāo)簽字符填充
sample = {'image':imgTensor, 'label':label}
return sample
定義好DataLoader逐batch載入數(shù)據(jù)
import matplotlib.pyplot as plt
DataRootPath = 'dir-to-your-data'
trainImgPath = dataPath+'mchar_train/mchar_train/'
ValImgPath = dataPath+'mchar_val/mchar_val/'
trainLabelPath = dataPath+'mchar_train.json'
ValLabelPath = dataPath+'mchar_val.json'
BATCH_SIZE = 1
train_dataset = TrainDataLoader(trainImgPath , DataRootPath+'mchar_train/train.csv', trainLabelPath )
train_num = len(train_dataset)
train_loader = DataLoader(dataset = train_dataset, batch_size = BATCH_SIZE, shuffle = True)
for step, sample in enumerate(train_loader):
if(step==10): break #輸出10個樣本觀察
imgData = sample['image']
label = sample['label']
print('label:', label)
print('img size: ',imgData.size())
imgNp = imgData.squeeze_(0).numpy().transpose(1,2,0)
plt.imshow(imgNp)
plt.show()
以上,我們就完成了數(shù)據(jù)的讀取和增廣阱穗,下一步是選取合適的baseline進(jìn)行訓(xùn)練饭冬。