在使用深度學(xué)習(xí)學(xué)習(xí)圖像文件的時(shí)候,如果圖片文件很多瘪菌,比如大幾千張,或者幾萬(wàn)張嘹朗。如果將一個(gè)圖片再接一個(gè)圖片導(dǎo)入到內(nèi)存中师妙,會(huì)極大地拖慢深度學(xué)習(xí)算法運(yùn)行速度。我們都有過(guò)這樣的經(jīng)驗(yàn)屹培,如果將一個(gè)文件夾里面有幾萬(wàn)個(gè)文件的文件夾進(jìn)行復(fù)制默穴,其速度要比將文件夾打包之后復(fù)制速度要慢很多。
為了不讓IO運(yùn)算(將硬件中的圖片一個(gè)一個(gè)導(dǎo)入到內(nèi)存中)成為深度學(xué)習(xí)訓(xùn)練速度無(wú)法提高的一個(gè)瓶頸褪秀,這里來(lái)介紹一種方法蓄诽,就是將很多的文件打包成一個(gè)HDF5的文件格式,最后用深度學(xué)習(xí)算法學(xué)習(xí)的時(shí)候媒吗,直接HDF5文件中導(dǎo)入數(shù)據(jù)就可以仑氛。
HDF5文件介紹
HDF5是一種數(shù)據(jù)存儲(chǔ)格式,特別適合向磁盤(pán)中存取大數(shù)據(jù)的時(shí)候使用闸英。一個(gè)HDF5文件可以被看成一個(gè)組锯岖,包含了不同的數(shù)據(jù)集,數(shù)據(jù)集可以是圖像表格等等自阱。HDF5組結(jié)構(gòu)類(lèi)似于文件系統(tǒng)的目錄層次結(jié)構(gòu)嚎莉,根目錄再包含其他目錄米酬。節(jié)點(diǎn)目錄里存放相應(yīng)的數(shù)據(jù)集沛豌。
安裝很簡(jiǎn)單,使用pip
pip install h5py
如何將訓(xùn)練數(shù)據(jù)生成HDF5文件
class HDF5DatasetWriter:
def __init__(self, dims, outputPath, dataKey="images", bufSize=1000):
# 如果輸出文件路徑存在,提示異常
if os.path.exists(outputPath):
raise ValueError("The supplied 'outputPath' already exists and cannot be overwritten. Manually delete the file before continuing", outputPath)
# 構(gòu)建兩種數(shù)據(jù)加派,一種用來(lái)存儲(chǔ)圖像特征一種用來(lái)存儲(chǔ)標(biāo)簽
self.db = h5py.File(outputPath, "w")
self.data = self.db.create_dataset(dataKey, dims, dtype="float")
self.labels = self.db.create_dataset("labels", (dims[0],), dtype="int")
# 設(shè)置buffer大小叫确,并初始化buffer
self.bufSize = bufSize
self.buffer = {"data": [], "labels": []}
self.idx = 0 # 用來(lái)進(jìn)行計(jì)數(shù)
def add(self, rows, labels):
self.buffer["data"].extend(rows)
self.buffer["labels"].extend(labels)
# 查看是否需要將緩沖區(qū)的數(shù)據(jù)添加到磁盤(pán)中
if len(self.buffer["data"]) >= self.bufSize:
self.flush()
def flush(self):
# 將buffer中的內(nèi)容寫(xiě)入磁盤(pán)之后重置buffer
i = self.idx + len(self.buffer["data"])
self.data[self.idx:i] = self.buffer["data"]
self.labels[self.idx:i] = self.buffer["labels"]
self.idx = i
self.buffer = {"data": [], "labels": []}
def storeClassLabels(self, classLabels):
# 存儲(chǔ)類(lèi)別標(biāo)簽
dt = h5py.special_dtype(vlen=str) # 表明存儲(chǔ)的數(shù)據(jù)類(lèi)型為字符串類(lèi)型
labelSet = self.db.create_dataset("label_names", (len(classLabels),), dtype=dt)
# 將classLabels賦值給labelSet但二者不指向同一內(nèi)存地址
labelSet[:] = classLabels
def close(self):
if len(self.buffer["data"]) > 0: # 查看是否緩沖區(qū)中還有數(shù)據(jù)
self.flush()
self.db.close()
在這段代碼中,我們定義一個(gè)類(lèi)來(lái)實(shí)現(xiàn)文件的讀取和打包并生成HDF5文件芍锦。
如何讀取HDF5文件用于訓(xùn)練
class HDF5DatasetGenerator:
def __init__(self, dbPath, batchSize, preprocessors = None, aug = None, binarize=True, classes=2):
# 保存參數(shù)列表
self.batchSize = batchSize
self.preprocessors = preprocessors
self.aug = aug
self.binarize = binarize
self.classes = classes
# hdf5數(shù)據(jù)集
self.db = h5py.File(dbPath)
self.numImages = self.db['labels'].shape[0]
def generator(self, passes=np.inf):
epochs = 0
# 默認(rèn)是無(wú)限循環(huán)遍歷竹勉,因?yàn)閚p.inf是無(wú)窮
while epochs < passes:
# 遍歷數(shù)據(jù)
for i in np.arange(0, self.numImages, self.batchSize):
# 從hdf5中提取數(shù)據(jù)集
images = self.db['images'][i: i + self.batchSize]
labels = self.db['labels'][i: i + self.batchSize]
# 檢查是否標(biāo)簽需要二值化處理
if self.binarize:
labels = np_utils.to_categorical(labels, self.classes)
# 預(yù)處理
if self.preprocessors is not None:
proImages = []
for image in images:
for p in self.preprocessors:
image = p.preprocess(image)
proImages.append(image)
images = np.array(proImages)
# 查看是否存在數(shù)據(jù)增強(qiáng),如果存在娄琉,應(yīng)用數(shù)據(jù)增強(qiáng)
if self.aug is not None:
(images, labels) = next(self.aug.flow(images,
labels, batch_size = self.batchSize))
# 返回
yield (images, labels)
epochs += 1
def close(self):
# 關(guān)閉db
self.db.close()
在這段代碼中次乓,我們以生成器的形式來(lái)讀取HDF5文件,返回用于訓(xùn)練孽水。
另外票腰,我錄制了一個(gè)視頻用來(lái)演示如何將貓狗大戰(zhàn)數(shù)據(jù)集生成HDF5文件,然后讀取HDF5文件用于進(jìn)行神經(jīng)網(wǎng)絡(luò)訓(xùn)練女气。對(duì)具體如何操作感興趣的杏慰,可以看我這個(gè)視頻了解一下。