Building CIFAR-10 binary-format data from your own images
Summary
- This post shows how to build data in the CIFAR-10 binary version format with Python.
- Environment: Windows, Python 3.6
- Last updated 2018-04-13; apologies for any problems caused by later version differences.
Main text
The convolutional neural network chapter of the TensorFlow tutorial classifies CIFAR-10 images and gets fairly good results. Can the same model be trained to classify other images?
TensorFlow CNN tutorial (Chinese)
TensorFlow CNN tutorial (English)
Yes, but the tutorial's model reads its input in the CIFAR-10 binary format, so our own images first have to be converted into that format.
Model code provided by the tutorial
CIFAR-10 dataset website
The CIFAR-10 binary version is described as follows:
[Figure: layout of the CIFAR-10 binary version]
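CIFAR-10 images are 32×32, so each image has 1024 pixels; split into R, G and B channels, that is 3072 bytes. Each image is therefore stored as a 1 + 1024 + 1024 + 1024 = 3073-byte record: the first byte is the label, the next 1024 bytes are the red channel, then 1024 bytes of green, then 1024 bytes of blue, with each channel stored row by row (row-major order). Writing every image as such a record and concatenating the records back to back, with no header or separator, gives a CIFAR-10 binary file. The official data_batch files each hold 10,000 of these records.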
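A minimal sketch of the 3073-byte record layout described above, using only NumPy and assuming the image is already a 32×32×3 `uint8` array in RGB order. The helper names `encode_record` and `decode_record` are illustrative, not part of any library. The single `transpose(2, 0, 1)` does the same job as splitting the image into three channel planes one at a time.

```python
import numpy as np

RECORD_BYTES = 1 + 32 * 32 * 3  # 1 label byte + 3072 pixel bytes = 3073


def encode_record(image, label):
    """Pack one 32x32x3 RGB uint8 image and its label into 3073 bytes."""
    assert image.shape == (32, 32, 3) and image.dtype == np.uint8
    assert 0 <= label <= 255  # the label must fit in a single byte
    # (H, W, C) -> (C, H, W): all R bytes, then all G, then all B, each row-major
    planes = image.transpose(2, 0, 1).reshape(-1)
    return bytes([label]) + planes.tobytes()


def decode_record(record):
    """Inverse of encode_record: return (image, label)."""
    assert len(record) == RECORD_BYTES
    buf = np.frombuffer(record, dtype=np.uint8)
    label = int(buf[0])
    # (3, 32, 32) channel planes back to (32, 32, 3)
    image = buf[1:].reshape(3, 32, 32).transpose(1, 2, 0)
    return image, label


if __name__ == "__main__":
    img = (np.arange(32 * 32 * 3) % 256).astype(np.uint8).reshape(32, 32, 3)
    rec = encode_record(img, 7)
    print(len(rec))  # 3073
    out, lab = decode_record(rec)
    print(lab, np.array_equal(out, img))  # 7 True
```

Round-tripping a record like this is a cheap way to confirm that the channel order matches before converting a whole folder of images.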
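To do this I adapted this blog post on building your own CIFAR-10-style dataset in the python version format. That post produces the python version; with a few changes to its code it produces the binary version instead.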
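For comparison, the CIFAR-10 python version stores each batch as a pickled dict whose `data` entry is an N×3072 `uint8` array with the same R, G, B planar order per row, and whose `labels` entry is a list of N integers. Converting such a batch to the binary version therefore only means prepending each label byte to its row. The sketch below assumes the dict is already loaded and uses str keys (when unpickled in Python 3 with `pickle.load(f, encoding='bytes')`, the official files give `b'data'` and `b'labels'` instead); `python_batch_to_bin` is a hypothetical helper name.

```python
import numpy as np


def python_batch_to_bin(batch, out_path):
    """Write a CIFAR-style python batch dict as a CIFAR-10 binary file.

    batch: dict with 'data' (N x 3072 uint8, R/G/B planes per row)
           and 'labels' (N ints in the range 0-255).
    Returns the number of records written.
    """
    data = np.asarray(batch["data"], dtype=np.uint8)
    labels = np.asarray(batch["labels"], dtype=np.uint8).reshape(-1, 1)
    assert data.ndim == 2 and data.shape[1] == 3072
    assert labels.shape[0] == data.shape[0]
    records = np.hstack((labels, data))  # shape (N, 3073): label byte + pixels
    with open(out_path, "wb") as f:
        f.write(records.tobytes())  # rows are written back to back, no header
    return records.shape[0]
```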
```python
# -*- coding: UTF-8 -*-
import os

import cv2
import numpy as np

DATA_LEN = 3072     # 32 * 32 * 3 pixel bytes per image
CHANNEL_LEN = 1024  # 32 * 32 bytes per colour channel
SHAPE = 32          # CIFAR-10 images are 32x32


def imread(im_path, shape=None, color="RGB", mode=cv2.IMREAD_COLOR):
    # IMREAD_COLOR always returns a 3-channel BGR image, even for grayscale or RGBA files
    im = cv2.imread(im_path, mode)
    if im is None:
        raise IOError("Cannot read image: " + im_path)
    if color == "RGB":
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    if shape is not None:
        assert isinstance(shape, int)
        im = cv2.resize(im, (shape, shape))
    return im


def read_data(filename, data_path, shape=SHAPE, color="RGB"):
    """
    filename (str): list file, one "image_name label" entry per line
    data_path (str): folder containing the images
    return: uint8 array of shape (N, 3072), list of labels, list of file names
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError("Can't find data file: " + filename)
    with open(filename) as f:
        lines = [ln for ln in f.read().splitlines() if ln.strip()]
    count = len(lines)
    data = np.zeros((count, DATA_LEN), dtype=np.uint8)
    lst, label = [], []
    c = CHANNEL_LEN
    for idx, ln in enumerate(lines):
        fname, lab = ln.rsplit(" ", 1)  # the label is the last field
        im = imread(os.path.join(data_path, fname), shape=shape, color=color)
        # Planar layout: all R bytes, then all G, then all B, each row-major
        data[idx, :c] = np.reshape(im[:, :, 0], c)
        data[idx, c:2 * c] = np.reshape(im[:, :, 1], c)
        data[idx, 2 * c:] = np.reshape(im[:, :, 2], c)
        lst.append(fname)
        label.append(int(lab))
    return data, label, lst


def py2bin(data, label, out_file="./bin/train_batch"):
    # Labels must fit in one byte (0-255)
    label_uint8 = np.asarray(label, dtype=np.uint8).reshape(-1, 1)
    # Each row: 1 label byte followed by 3072 pixel bytes
    arr = np.hstack((label_uint8, data))
    os.makedirs(os.path.dirname(out_file) or ".", exist_ok=True)
    with open(out_file, "wb") as f:
        f.write(arr.tobytes())  # one write instead of one call per byte


def imagelist(directory_normal="data/normal",
              file_train_list="data/image_train_list.txt"):
    # directory_normal: folder with the original images (32*32 pixels)
    # file_train_list: output path of the generated image list
    # "w" rather than "a", so re-running does not duplicate entries
    with open(file_train_list, "w") as f:
        for filename in sorted(os.listdir(directory_normal)):
            f.write(filename + " " + "0" + "\n")  # every image gets class 0 here


if __name__ == '__main__':
    data_path = './data/normal'
    file_list = './data/image_train_list.txt'
    save_path = './bin'
    imagelist(data_path, file_list)  # build the image list
    data, label, lst = read_data(file_list, data_path, shape=32)  # pixels -> matrix + label list
    # The tutorial's input code may expect names such as data_batch_1.bin; rename if needed
    py2bin(data, label, os.path.join(save_path, 'train_batch'))  # write the CIFAR-10 binary file
```
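After running the script, a quick sanity check is to read the file back: its size must be a multiple of 3073 bytes, and each record splits into a label byte and a 32×32×3 image. A minimal NumPy-only sketch; `load_cifar10_bin` is a hypothetical helper name, not part of the tutorial code.

```python
import numpy as np

RECORD_BYTES = 3073  # 1 label byte + 3 * 1024 pixel bytes


def load_cifar10_bin(path):
    """Read a CIFAR-10 binary file into (images, labels).

    images: uint8 array of shape (N, 32, 32, 3) in RGB order
    labels: uint8 array of shape (N,)
    """
    raw = np.fromfile(path, dtype=np.uint8)
    if raw.size % RECORD_BYTES != 0:
        raise ValueError("file size %d is not a multiple of %d" % (raw.size, RECORD_BYTES))
    records = raw.reshape(-1, RECORD_BYTES)
    labels = records[:, 0]
    # Each row holds R, G, B planes: (N, 3, 32, 32) -> (N, 32, 32, 3)
    images = records[:, 1:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    return images, labels
```

For example, `images, labels = load_cifar10_bin("./bin/train_batch")` should return arrays of shape (N, 32, 32, 3) and (N,), where N is the number of images in the list file; displaying one image (after converting RGB back to BGR for OpenCV) is an easy way to spot a swapped channel order.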