作為深度學(xué)習(xí)界的“hello world!”码俩,學(xué)習(xí)起來(lái)真沒(méi)那么容易。
接觸深度學(xué)習(xí)歼捏,第一個(gè)接觸的就是mnist稿存。但是初次接觸就只跑了三個(gè)腳本
get_mnist.sh
create_mnist.sh
train_lenet.sh
然后就結(jié)束了,對(duì)此我蒙逼了許久瞳秽。因?yàn)閷?duì)于caffe的整體框架不熟悉瓣履,對(duì)CNN不深入,因此感覺(jué)舉步維艱练俐。經(jīng)過(guò)1個(gè)多月的沉淀終于能完整的走一遍MNIST袖迎。
對(duì)于初學(xué)者,深度學(xué)習(xí)分為三步:1.數(shù)據(jù)準(zhǔn)備 2.訓(xùn)練 3.預(yù)測(cè)
一.數(shù)據(jù)準(zhǔn)備
官方例程推薦的數(shù)據(jù)集為
t10k-images-idx3-ubyte
t10k-labels-idx1-ubyte
train-images-idx3-ubyte
train-labels-idx1-ubyte
相信許多人和我一樣會(huì)問(wèn):這是什么啊腺晾,打開(kāi)還是一推二進(jìn)制數(shù)燕锥。確實(shí),官方的數(shù)據(jù)集可視化不好悯蝉,但是可以借助matlab或者python解析出來(lái)归形。但是,對(duì)于普通人拿到的數(shù)據(jù)往往都是圖片格式鼻由,而且是很多暇榴。
這該進(jìn)行怎么加載訓(xùn)練呢。
先粗略的看下蕉世,官方的數(shù)據(jù)集蔼紧。可以看出images對(duì)應(yīng)一個(gè)labels狠轻,所以我們準(zhǔn)備的數(shù)據(jù)包括圖片和標(biāo)簽奸例。
1)基礎(chǔ)準(zhǔn)備
在data文件夾下創(chuàng)建如下文件夾,準(zhǔn)備訓(xùn)練集,驗(yàn)證集和測(cè)試集
創(chuàng)建 train test文件夾和對(duì)應(yīng)的txt將你的訓(xùn)練集放到train中哈误,將驗(yàn)證集放到test中哩至。(這里應(yīng)該多建一個(gè)valid文件夾躏嚎,里面存放的是驗(yàn)證集,而test中放測(cè)試集菩貌,這里偷工減料了)
接著要制作標(biāo)簽卢佣,如果量少可以考慮手敲,但是大數(shù)據(jù)就只能借助代碼了箭阶。
創(chuàng)建make_list.py
#coding=utf-8
#caffe and opencv test mnist
#test by yuzefan
import os
from os.path import join, isdir
def gen_listfile(dir):
cwd=os.getcwd() # 獲取當(dāng)前目錄
os.chdir(dir) # 改變當(dāng)前的目錄
sd=[d for d in os.listdir('.') if isdir(d)] # 列出當(dāng)前目錄下的所有文件和目錄名,os.listdir可以列出文件和目錄
sd.sort()
class_id=0
with open(join(dir,'listfile.txt'),'w') as f : #join():connect string,"with...as"is used for safety,without it,you must write by"file = open("/tmp/foo.txt") file.close()
for d in sd :
fs=[join(d,x) for x in os.listdir(d)]
for img in fs:
f.write(img + ' '+str(class_id)+'\n')
class_id+=1
os.chdir(cwd)
if __name__ == "__main__":
root_dir = raw_input('image root dir: ')
while not isdir(root_dir):
raw_input('not exist, re-input please: ')
gen_listfile(root_dir)
運(yùn)行后可以得到標(biāo)簽虚茶,如下:
list已經(jīng)準(zhǔn)備好了,接著要把數(shù)據(jù)轉(zhuǎn)成lmdb仇参。caffe之所以速度快嘹叫,得益于lmdb數(shù)據(jù)格式。
創(chuàng)建creat_lmdb.sh腳本
#coding=utf-8
#!/usr/bin/env sh
#指定腳本的解釋程序
#by yuzefan
set -e #如果任何語(yǔ)句的執(zhí)行結(jié)果不是true則應(yīng)該退出
# CAFFEIMAGEPATH is the txt file path
# DATA is the txt file path
CAFFEDATAPATH=mytest/chinese/data
DATA=mytest/chinese/data/mnist
TOOLS=~/caffe-master/build/tools
# TRAIN_DATA_PATH & VAL_DATA_ROOT is root path of your images path, so your train.txt file must do not contain
# this line again!!
TRAIN_DATA_ROOT=/home/ubuntu/caffe-master/mytest/chinese/data/mnist/train/
VAL_DATA_ROOT=/home/ubuntu/caffe-master/mytest/chinese/data/mnist/test/
# Set RESIZE=true to resize the images to 28x28. Leave as false if images have
# already been resized using another tool.
RESIZE=true
if $RESIZE;then
RESIZE_HEIGHT=28
RESIZE_WIDTH=28
else
RESIZE_HEIGHT=0
RESIZE_WIDTH=0
fi
if [ ! -d "$TRAIN_DATA_ROOT" ]; then
echo "Error: TRAIN_DATA_ROOT is not a path to a directory: $TRAIN_DATA_ROOT"
echo "Set the TRAIN_DATA_ROOT variable in create_imagenet.sh to the path" \
"where the ImageNet training data is stored."
exit 1
fi
if [ ! -d "$VAL_DATA_ROOT" ]; then
echo "Error: VAL_DATA_ROOT is not a path to a directory: $VAL_DATA_ROOT"
echo "Set the VAL_DATA_ROOT variable in create_imagenet.sh to the path" \
"where the ImageNet validation data is stored."
exit 1
fi
echo "Creating train lmdb..."
GLOG_logtostderr=1 $TOOLS/convert_imageset \
--resize_height=$RESIZE_HEIGHT \
--resize_width=$RESIZE_WIDTH \
--shuffle \
--gray=true\
$TRAIN_DATA_ROOT \
$DATA/train.txt \
$CAFFEDATAPATH/caffe_train_lmdb
echo "Creating val lmdb..."
GLOG_logtostderr=1 $TOOLS/convert_imageset \
--resize_height=$RESIZE_HEIGHT \
--resize_width=$RESIZE_WIDTH \
--shuffle \
--gray=true\
$VAL_DATA_ROOT \
$DATA/test.txt \
$CAFFEDATAPATH/caffe_val_lmdb
echo "Done."
運(yùn)行完后在data目錄下出現(xiàn)
caffe_train_lmdb
caffe_val_lmdb
這里使用了caffe的tools中的convert_imageset诈乒。使用方法:
convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME
其中
參數(shù):ROOTFOLDER 表示輸入的文件夾
參數(shù):LISTFILE 表示輸入文件列表罩扇,其每一行為:類(lèi)似 subfolder1/file1.JPEG 7
可選參數(shù):[FLAGS] 可以指示是否使用shuffle,顏色空間怕磨,編碼等喂饥。
--gray=true \-------------------------------------------->如果灰度圖的話(huà)加上即可
還調(diào)用了opencv,對(duì)輸入圖像進(jìn)行尺寸變換肠鲫,滿(mǎn)足網(wǎng)絡(luò)的要求员帮。
注意:
TRAIN_DATA_PATH & VAL_DATA_ROOT is root path of your images path, so your train.txt file must do not contain
到此,數(shù)據(jù)準(zhǔn)備就結(jié)束了导饲。
二.訓(xùn)練
訓(xùn)練需要模型描述文件和模型求解文件捞高。
lenet_train_test.prototxt
lenet_solver.prototxt
對(duì)于lenet_train_test.prototxt,需要改的地方只有數(shù)據(jù)層
name: "LeNet"
layer {
name: "mnist" #名字隨便
type: "Data"
top: "data"
top: "label"
include {
phase: TRAIN
}
transform_param {
scale: 0.00390625
}
data_param {
source: "mytest/chinese/data/caffe_train_lmdb" #這里是上一步生成的lmdb
batch_size: 64#一次壓入網(wǎng)絡(luò)的數(shù)量
backend: LMDB
}
}
layer {
name: "mnist"
type: "Data"
top: "data"
top: "label"
include {
phase: TEST
}
transform_param {
scale: 0.00390625
}
data_param {
source: "mytest/chinese/data/caffe_val_lmdb"
batch_size: 100
backend: LMDB
}
}
對(duì)于lenet_solver.prototxt
# The train/test net protocol buffer definition
net: "mytest/chinese/lenet_train_test.prototxt"#這里可以把訓(xùn)練和驗(yàn)證放到一起渣锦,實(shí)際可以分開(kāi)
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100 #test_iter * batch_size= 10000(test集的大邢醺凇)
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 20
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "mytest/chinese/lenet"
# solver mode: CPU or GPU
solver_mode: GPU
訓(xùn)練可以執(zhí)行train_lenet.sh,實(shí)際上還是調(diào)用了tools
#!/usr/bin/env sh
set -e
./build/tools/caffe train --solver=mytest/chinese/lenet_solver.prototxt $@
沒(méi)有意外的話(huà)就能正常開(kāi)始訓(xùn)練了泡挺。
三.預(yù)測(cè)
預(yù)測(cè)可以參考我之前寫(xiě)的
Caffe學(xué)習(xí)筆記1:用訓(xùn)練好的mnist模型進(jìn)行預(yù)測(cè)(兩種方法)
http://www.reibang.com/p/6fcdefbacf5b
小筆記:均值計(jì)算
減均值預(yù)處理能提高訓(xùn)練和預(yù)測(cè)的速度辈讶,利用tools
二進(jìn)制格式的均值計(jì)算
build/tools/compute_image_mean examples/mnist/mnist_train_lmdb examples/mnist/mean.binaryproto
帶兩個(gè)參數(shù):
第一個(gè)參數(shù):examples/mnist/mnist_train_lmdb, 表示需要計(jì)算均值的數(shù)據(jù)娄猫,格式為lmdb的訓(xùn)練數(shù)據(jù)贱除。
第二個(gè)參數(shù):examples/mnist/mean.binaryproto, 計(jì)算出來(lái)的結(jié)果保存文件媳溺。
接下來(lái)的計(jì)劃:現(xiàn)在說(shuō)白了是個(gè)10類(lèi)的分類(lèi)器月幌,接下來(lái)增強(qiáng)網(wǎng)絡(luò)使其能夠訓(xùn)練并預(yù)測(cè)出0~9 and ‘a(chǎn)’~‘z’