MXnet的學(xué)習(xí)筆記,這次主要是使用MXnet提供的example模型進行訓(xùn)練時如何加載數(shù)據(jù)集的介紹。步驟基本上按照MXNet Python Data Loading API炬藤。
有關(guān)MXnet在OSX下的編譯安裝晾捏,可以看這里Mac下編譯安裝MXNet。
有關(guān)MXnet提供的example的綜述介紹<-在這里穿撮。
Sample iterator for data loading
在瀏覽完MXnet提供的example后想要在自己的機器上跑一下簡單的數(shù)據(jù)集看看結(jié)果缺脉。因為現(xiàn)在只是裝在自己的MBA上痪欲,沒有裝CUDA和OpenMP,也沒有使用GPU訓(xùn)練攻礼,因此只能跑一跑簡單的數(shù)據(jù)集业踢。MXnet的Image Classification Example中的樣例都比較完整,使用步驟也很詳細礁扮,訓(xùn)練最基本的MNIST數(shù)據(jù)集基本上不需要多余的工作量陨亡,只要能聯(lián)網(wǎng)下載MNIST數(shù)據(jù)集(或者自己有數(shù)據(jù)集的話移動到對應(yīng)文件夾下)就可以直接訓(xùn)練,效果也挺不錯:
→ python train_mnist.py
2016-05-23 08:51:41,616 Node[0] start with arguments Namespace(batch_size=128, data_dir='mnist/', gpus=None, kv_store='local', load_epoch=None, lr=0.1, lr_factor=1, lr_factor_epoch=1, model_prefix=None, network='mlp', num_epochs=10, num_examples=60000, save_model_prefix=None)
[08:51:45] src/io/iter_mnist.cc:91: MNISTIter: load 60000 images, shuffle=1, shape=(128,784)
[08:51:46] src/io/iter_mnist.cc:91: MNISTIter: load 10000 images, shuffle=1, shape=(128,784)
2016-05-23 08:51:46,460 Node[0] Start training with [cpu(0)]
...
2016-05-23 08:52:02,548 Node[0] Epoch[9] Batch [450] Speed: 41054.59 samples/sec Train-top_k_accuracy_20=1.000000
2016-05-23 08:52:02,605 Node[0] Epoch[9] Resetting Data Iterator
2016-05-23 08:52:02,605 Node[0] Epoch[9] Time cost=1.470
2016-05-23 08:52:02,750 Node[0] Epoch[9] Validation-accuracy=0.977464
2016-05-23 08:52:02,750 Node[0] Epoch[9] Validation-top_k_accuracy_5=0.999299
2016-05-23 08:52:02,750 Node[0] Epoch[9] Validation-top_k_accuracy_10=1.000000
2016-05-23 08:52:02,750 Node[0] Epoch[9] Validation-top_k_accuracy_20=1.000000
默認的參數(shù)為:batch-size=128深员,初始學(xué)習(xí)率為0.1(固定學(xué)習(xí)率负蠕,lr_factor_epoch=1),使用最基本的多層感知機MLP進行訓(xùn)練倦畅。每個epoch耗時大約1.5秒左右遮糖,在10次迭代后測試集的accuracy達到0.977464。
其實在測試的時候看到Validation-accuracy時就有在想指的是cross-validation的accuracy還是test的accuracy叠赐,因此這時候就可以先去看看MXnet中到底是怎么讀取數(shù)據(jù)欲账、怎么使用KVstore的。
根據(jù)官方文檔的介紹芭概,MXnet使用iterator將參數(shù)傳遞給訓(xùn)練模型赛不。這里的iterator會做一些數(shù)據(jù)預(yù)處理,并且生成指定大小的batch輸入訓(xùn)練模型罢洲。
由于MNIST的數(shù)據(jù)比較簡單踢故,example里面提供了載入MNIST數(shù)據(jù)集的iterator實現(xiàn),如下:
def get_iterator(data_shape):
def get_iterator_impl(args, kv):
data_dir = args.data_dir
# 若指定位置沒有MNIST數(shù)據(jù)集則會調(diào)用_download()函數(shù)聯(lián)網(wǎng)下載
if '://' not in args.data_dir:
_download(args.data_dir)
# data_shape變量為輸入數(shù)據(jù)的格式惹苗。對于MNIST:
# 若使用MLP進行訓(xùn)練殿较,輸入數(shù)據(jù)為有784個元素的一維向量,data_shape = (784, )
# 若使用LeNet進行訓(xùn)練桩蓉,輸入數(shù)據(jù)為一個28*28的矩陣淋纲,data_shape = (1, 28, 28)
# 因此若len(data_shape)不等于3時,設(shè)置flat變量為True院究,即對MNIST每一個輸入數(shù)據(jù)一維扁平化
flat = False if len(data_shape) == 3 else True
# 訓(xùn)練集的參數(shù)指定
train = mx.io.MNISTIter(
image = data_dir + "train-images-idx3-ubyte",
label = data_dir + "train-labels-idx1-ubyte",
input_shape = data_shape,
batch_size = args.batch_size,
## A commonly mistake is forgetting shuffle the image list during packing.
## This will lead fail of training.
## eg. accuracy keeps 0.001 for several rounds.
shuffle = True,
flat = flat,
num_parts = kv.num_workers,
part_index = kv.rank)
# 測試集的參數(shù)指定
val = mx.io.MNISTIter(
image = data_dir + "t10k-images-idx3-ubyte",
label = data_dir + "t10k-labels-idx1-ubyte",
input_shape = data_shape,
batch_size = args.batch_size,
flat = flat,
num_parts = kv.num_workers,
part_index = kv.rank)
return (train, val)
return get_iterator_impl
在 train_mnist.py 的main函數(shù)里洽瞬,會調(diào)用get_iterator()函數(shù)得到輸入的iterator,傳遞給train_model.fit()函數(shù)執(zhí)行真正的訓(xùn)練過程业汰。
在之前的example介紹里有說到伙窃,Image Classification(包括后面基于CNN的很多其它網(wǎng)絡(luò))的不同網(wǎng)絡(luò)結(jié)構(gòu)運用在不同的數(shù)據(jù)集上,最后都是回到調(diào)用train_model.fit()函數(shù)進行訓(xùn)練蔬胯。因此輸入數(shù)據(jù)的獲取和iterator的定義都在對應(yīng)的 train_{mnist, cifar10, imagenet}.py 中对供,最簡單的定義就如上面的代碼所示。
Build your own iterator
MNIST輸入數(shù)據(jù)的格式類型分為recordio,MNIST和csv。MNIST數(shù)據(jù)集的參數(shù)指定較為簡單产场,上面的例子基本都覆蓋到了鹅髓。有關(guān)csv和MNIST數(shù)據(jù)集的更多參數(shù)指定信息<--點擊鏈接。
對于圖片數(shù)據(jù)集(recordio格式的數(shù)據(jù))京景,在創(chuàng)建iterator時窿冯,一般需要指定的參數(shù)有五類,包括:
- 數(shù)據(jù)集參數(shù) (Dataset Param)确徙,提供了數(shù)據(jù)集的基本信息醒串,如數(shù)據(jù)文件地址、數(shù)據(jù)形狀(即前例中的input_shape)等等鄙皇。
- 批參數(shù) (Batch Param) 提供了形成batch的信息芜赌,比如batch size。
- Augmentation Param 可以設(shè)定對數(shù)據(jù)集預(yù)處理的參數(shù)伴逸,比如mean_image(將圖像中的每個像素減去圖片像素均值)缠沈,rand_crop(隨機對圖像進行部分切割),rand_mirror(隨機對圖像進行水平對稱變換)等等错蝴。
- 后臺參數(shù) (Backend Param) 控制后臺線程來隱藏讀取數(shù)據(jù)的開銷的相關(guān)參數(shù)洲愤,如preprocess_threads設(shè)定后臺預(yù)讀取線程數(shù)量,prefetch_buffer設(shè)定預(yù)讀取buffer的大小顷锰。
- 輔助參數(shù) (Auxiliary Param) 提供用于調(diào)試的參數(shù)設(shè)定柬赐,如verbose設(shè)定是否要輸出parser信息。
具體的參數(shù)定義可以看官方文檔:I/O API官紫。
Use your own data
要使用自己的數(shù)據(jù)集(或者ImageNet數(shù)據(jù)集)肛宋,由于MXnet沒有提供類似MNIST和cifar的自動下載和加載腳本將原始數(shù)據(jù)轉(zhuǎn)換為ImageRecord數(shù)據(jù),因此需要自己進行數(shù)據(jù)格式轉(zhuǎn)換万矾。
不過將數(shù)據(jù)轉(zhuǎn)換為ImageRecord格式也很簡單:
- 首先將圖像存儲為壓縮過的格式(比如.jpg)悼吱,以降低數(shù)據(jù)量慎框。
- 使用MXnet提供的make_list[./mxnet/tools/make_list.py]工具生成lst文件良狈,lst文件的格式為
integer_image_index \t label_index \t path_to_image
make_list接受的參數(shù)包括
- chunks[int]:將原始數(shù)據(jù)集分成chunks塊,得到chunks個數(shù)據(jù)量相同但對應(yīng)數(shù)據(jù)不同的lst文件笨枯,默認值為1薪丁。
- train_ratio[float]:指定每個chunk內(nèi)用于訓(xùn)練的數(shù)據(jù)所占的比例,可以設(shè)置不同的訓(xùn)練集-測試集比馅精,默認值為1(即所有數(shù)據(jù)用于訓(xùn)練)严嗜。
- exts[list]:接受的輸入數(shù)據(jù)格式,默認值為{.jpg,.jpeg}洲敢。
- recursive[bool]:若設(shè)定為TRUE且原始數(shù)據(jù)集已經(jīng)按照label放在了不同的子文件夾中漫玄,則make_list會自動為每個子文件夾內(nèi)的數(shù)據(jù)標記對應(yīng)的label_index,否則所有的數(shù)據(jù)都標注統(tǒng)一 label_index = 0,默認值為FALSE睦优。
- 使用MXnet提供的im2rec[./mxnet/tools/im2rec.{cc,py}]工具(提供C++版本和Python版本)渗常,通過原始數(shù)據(jù)和lst文件得到ImageRecord格式的數(shù)據(jù)供MXnet使用。若不指定lst文件則使用與make_list相同的方法先生成lst文件再生成ImageRecord文件汗盘。im2rec除了有make_list相同的參數(shù)外皱碘,在生成ImageRecord部分的參數(shù)還有
- resize[int, default = 0]:等比例縮放圖片,將圖片短邊設(shè)置為指定大小隐孽。
- center_crop[bool, default = FALSE]:截取圖片中間的方形部分癌椿,方形邊長為短邊長。
- quality[int, default = 80]:設(shè)定圖像的質(zhì)量(.jpg:1-100, .png:1-9)菱阵。
- num_thread[int, default = 1]:若使用多線程進行數(shù)據(jù)格式轉(zhuǎn)換踢俄,則生成圖像順序會與輸入list的不同。
- color[int, default = 1, choice = {-1, 0, 1}]:輸入圖像的color mode晴及,若為1則直接讀如彩色數(shù)據(jù)褪贵,0為灰度模式,-1為使用alpha channel(<--這個應(yīng)該是圖像處理領(lǐng)域的專業(yè)知識抗俄,我也不是很理解)脆丁。
- encoding[str, default = .jpg, choice = {.jpg, .png}]:圖像轉(zhuǎn)換后保存的格式。
這邊會遇到一點問題动雹,如果調(diào)用im2rec.py的時候提示
No module named cv
的話槽卫,網(wǎng)上查詢到的原因是沒安裝openCV(不過其實之前裝了……)
那只要把代碼中
import cv, cv2
中的cv去掉即可,后續(xù)好像只使用到了cv2庫中的內(nèi)容胰蝠,不需要cv歼培。
然后就可以在MXnet中使用自己的數(shù)據(jù)集了。