前言
本文主要參考了幾篇文章输玷,搭建了一個(gè)在iOS上跑Tensorflow MNIST模型的demo队丝,本文會(huì)給出一個(gè)可用的Demo,寫(xiě)出當(dāng)時(shí)我遇到的問(wèn)題欲鹏。想要把項(xiàng)目跑起來(lái)机久,需要詳細(xì)的閱讀我貼出來(lái)的幾篇文章,某些具體步驟我會(huì)給出鏈接和索引赔嚎。
如果你是tensorflow新手膘盖,想要知道如何讀取訓(xùn)練好的MNIST模型并且做預(yù)測(cè),你會(huì)從這篇文章得到幫助并節(jié)約時(shí)間尤误,下載demo侠畔。
注意:這個(gè)demo需要Tensorflow的庫(kù)以及各種環(huán)境,你可以找到這個(gè)感受一下损晤,直接下載在iOS10上的真機(jī)就可以運(yùn)行软棺。
你需要一些Python,Tensorflow和iOS的知識(shí)尤勋。
Reference
1. python腳本喘落,訓(xùn)練MNIST+用自己的圖片做輸入預(yù)測(cè)結(jié)果
Using TensorFlow to create your own handwriting recognition engine
GitHub 下載腳本
2. 工程如何搭建請(qǐng)參考這篇
Getting started with TensorFlow on iOS
3. 在iOS里怎么load模型和讀取數(shù)據(jù)
Getting Started with Deep MNIST and TensorFlow on iOS
4. 深度學(xué)習(xí)德崭,卷積,神經(jīng)網(wǎng)絡(luò)簡(jiǎn)單的解釋看這篇
機(jī)器學(xué)習(xí)原來(lái)這么有趣揖盘!第三章:圖像識(shí)別【鳥(niǎo)or飛機(jī)】眉厨?深度學(xué)習(xí)與卷積神經(jīng)網(wǎng)絡(luò)
5. 刪除iOS不能支持的node
Drop dropout from Tensorflow
跑一個(gè)Tensorflow的例子
MNIST是一個(gè)手寫(xiě)數(shù)字0~9的數(shù)據(jù)集,通常機(jī)器學(xué)期的入門(mén)會(huì)使用這個(gè)數(shù)據(jù)集來(lái)跑一邊例子兽狭,因?yàn)閿?shù)據(jù)量不大憾股,訓(xùn)練的時(shí)間比較短,可以很快看到結(jié)果箕慧。
參考Getting started with TensorFlow on iOS中的Installing TensorFlow服球,在mac上搭建起運(yùn)行tensorflow的環(huán)境。
創(chuàng)建一個(gè)文件夾颠焦,比如名字叫train斩熊,下載train3.py,解壓好下載的MNIST數(shù)據(jù)集在MNIST_data文件夾中伐庭,在terminal中直接
python ./train3.py
這時(shí)Tensorflow會(huì)幫我們進(jìn)行訓(xùn)練粉渠。下載predict_2.py,并隨便的網(wǎng)上找?guī)讖埵謱?xiě)的數(shù)字0~9的圖片圾另,使用我們剛才訓(xùn)練的模型做預(yù)測(cè)
python ./predict_2.py ‘number1.png’
在predict_2.py中霸株,我們讀取了一張圖片,并對(duì)這張圖片做了一些處理集乔,包括使這張圖變?yōu)楹诎咨ゼs放圖片到28*28的大小(也是MNIST數(shù)據(jù)集中圖片的大腥怕贰)尤溜,讀取圖片的每一個(gè)像素并按照tensorflow需要的格式做處理,然后將數(shù)據(jù)輸入到模型中汗唱,獲取結(jié)果宫莱。
為iOS準(zhǔn)備Tensorflow的環(huán)境
這里請(qǐng)?jiān)敿?xì)參考Getting started with TensorFlow on iOS中的TensorFlow on iOS小結(jié),文章里已經(jīng)說(shuō)得非常詳細(xì)了渡嚣。步驟不復(fù)雜梢睛,但是編譯iOS需要的庫(kù)要一些時(shí)間肥印,我的macbook 13' 大概跑了2個(gè)多小時(shí)识椰。
Freezing the graph
這一節(jié)也在Getting started with TensorFlow on iOS有所提及,細(xì)節(jié)問(wèn)題深碱,我在這里說(shuō)明腹鹉。
如果你跟著做到了這里,那我們現(xiàn)在有了訓(xùn)練好的模型敷硅,這一步我們需要對(duì)這個(gè)模型進(jìn)行處理以便它可以用在iOS上面功咒。
上面的截圖顯示了你在運(yùn)行過(guò)train3.py之后會(huì)生成的模型文件愉阎。Freeze graph指的是將這些模型和訓(xùn)練好的網(wǎng)絡(luò)參數(shù)合并成一個(gè)文件,方便工程上的使用力奋。
在terminal中榜旦,進(jìn)入到tensorflow文件夾,復(fù)制粘貼執(zhí)行:
bazel-bin/tensorflow/python/tools/freeze_graph \ --input_graph=/mnist/model/graph.pb --input_checkpoint=/mnist/model/model.ckpt \ --output_node_names=softmax \ --output_graph=/mnist/model/frozen.pb
注意這個(gè)目錄:/mnist/model/是指Macintosh HD下的/mnist/model/景殷,也就是mac硬盤(pán)的根目錄下面溅呢。
這樣我們就把模型和參數(shù)合并到了一起,這里拿到的模型里面猿挚,有一些操作有可能是不能直接在iOS上面運(yùn)行的咐旧,所以我在train3.py中移除了一些node使得這個(gè)模型可以直接放到ios上面。
接下來(lái)我們需要用optimize_for_inference優(yōu)化這個(gè)模型绩蜻,獲得一個(gè)final.pb铣墨,這個(gè)才是最后用在iOS上的文件:
bazel-bin/tensorflow/python/tools/optimize_for_inference --input=/mnist/model/frozen.pb --output=/mnist/model/final.pb --output_names=softmax --frozen_graph=True --input_names=x
你可以在這里下載我訓(xùn)練并處理好的模型文件。
The iOS App
1.創(chuàng)建一個(gè)新的App工程
2.修改ViewController.m為.mm办绝,因?yàn)槲覀冃枰褂胏++
3.在Build Settings中伊约,根據(jù)你編譯好的tensorflow文件夾地址修改other link flags:
/Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/protobuf_ios/lib/
libprotobuf-lite.a
/Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/protobuf_ios/lib/
libprotobuf.a
-force_load /Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/lib/
libtensorflow-core.a
4.同樣修改 library search path:
-force_load
/Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/protobuf_ios/lib/
libprotobuf-lite.a
/Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/protobuf_ios/lib/
libprotobuf.a
/Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/lib/
libtensorflow-core.a
注意這里有 -force_load 不然runtime要出錯(cuò)
5.修改Header Search Paths:
~/tensorflow
~/tensorflow/tensorflow/contrib/makefile/downloads
~/tensorflow/tensorflow/contrib/makefile/downloads/eigen
~/tensorflow/tensorflow/contrib/makefile/downloads/protobuf/src
~/tensorflow/tensorflow/contrib/makefile/gen/proto
6.修改Enable Bitcode: No
7.將final.pb拖入iOS項(xiàng)目中,記得勾選Add to target
0.參考Getting started with TensorFlow on iOS 里面的 The iOS App孕蝉,這里有不懂的對(duì)照著看一下
iOS代碼部分
詳細(xì)代碼請(qǐng)去github下載我的demo碱妆,可以配好環(huán)境運(yùn)行一下
加載model
- (void)viewDidLoad {
[super viewDidLoad];
// Do any additional setup after loading the view, typically from a nib.
NSString *path = [[NSBundle mainBundle] pathForResource:@"final" ofType:@"pb"];
if ([self loadGraphFromPath:path] && [self createSession]) {
NSLog(@"load model and create session");
}
}
-(BOOL)loadGraphFromPath:(NSString *)path {
auto status = ReadBinaryProto(tensorflow::Env::Default(), path.fileSystemRepresentation, &graph);
if (!status.ok()) {
NSLog(@"Error reading graph: %s", status.error_message().c_str());
return NO;
}
auto nodeCount = graph.node_size();
NSLog(@"Node count: %d", nodeCount);
for (auto i = 0; i < nodeCount; ++i) {
auto node = graph.node(i);
NSLog(@"Node %d: %s '%s'", i, node.op().c_str(), node.name().c_str());
}
return YES;
}
-(BOOL)createSession {
tensorflow::SessionOptions options;
auto status = tensorflow::NewSession(options, &session);
if (!status.ok()) {
NSLog(@"Error creating session: %s", status.error_message().c_str());
return NO;
}
status = session->Create(graph);
if (!status.ok()) {
NSLog(@"Error creating session: %s", status.error_message().c_str());
return NO;
}
return YES;
}
做預(yù)測(cè)
- 讀取圖片,將圖片scale昔驱,讀取像素做normalize
- 放入input
- 跑網(wǎng)絡(luò)
- 拿到輸出疹尾,獲得結(jié)果
-(void)predict {
// 1. 讀取圖片,將圖片scale骤肛,讀取像素做normalize
UIImage *orignalImage = [UIImage imageNamed:@"9-1.png"];
UIImage *scaledImage = [self scaleImage:orignalImage];
UIImage *image = [self convertImageToGrayScale:scaledImage];
UIImageView *imageView = [UIImageView new];
imageView.frame = CGRectMake(0, 0, 50, 50);
imageView.image = image;
[self.view addSubview:imageView];
tensorflow::Tensor x(tensorflow::DT_FLOAT, tensorflow::TensorShape({1,kInputLength}));
NSArray *pixel = [self getRGBAsFromImage:image atX:0 andY:0 count:kInputLength];
for (auto i = 0; i < kInputLength; i++) {
UIColor *color = pixel[i];
CGFloat red = 0.0, green = 0.0, blue = 0.0, alpha =0.0;
[color getRed:&red green:&green blue:&blue alpha:&alpha];
x.matrix<float>().operator()(0,i) = (255.0 - red) / 255.0f;
NSLog(@"%f",x.matrix<float>().operator()(0,i));
}
// 2. 放入input
std::vector<std::pair<tensorflow::string, tensorflow::Tensor>> inputs = {
{"x", x}
};
std::vector<std::string> nodes = {
{"softmax"}
};
const auto start = CACurrentMediaTime();
std::vector<tensorflow::Tensor> outputs;
// 3. 跑網(wǎng)絡(luò)
auto status = session->Run(inputs, nodes, {}, &outputs);
if (!status.ok()) {
NSLog(@"Error reading graph: %s", status.error_message().c_str());
return;
}
NSLog(@"Time: %g seconds", CACurrentMediaTime() - start);
// 4. 拿到輸出纳本,獲得結(jié)果
const auto outputMatrix = outputs[0].matrix<float>();
float bestProbability = 0;
int bestIndex = -1;
for (auto i = 0; i < kOutputs; i++) {
const auto probability = outputMatrix(i);
if (probability > bestProbability) {
bestProbability = probability;
bestIndex = i;
}
}
NSLog(@"!!!!!!!!!!! result %d",bestIndex);
}
至此,我們就成功的在iOS上用tensorflow跑起了我們訓(xùn)練好的模型腋颠,并做出預(yù)測(cè)了繁成!
其他遇到的問(wèn)題
當(dāng)我使用create_model_2.py創(chuàng)建了一個(gè)模型,但是在iOS上卻報(bào)這么一個(gè)錯(cuò):
Invalid argument: No OpKernel was registered to support Op 'RandomUniform' with these attrs. Registered devices: [CPU], Registered kernels:
<no registered kernels>
[[Node: dropout/random_uniform/RandomUniform = RandomUniform[T=DT_INT32, dtype=DT_FLOAT, seed=0, seed2=0](dropout/Shape)]]
顯示我的模型里面淑玫,有iOS不能支持的node巾腕。Google了很久,發(fā)現(xiàn)我最開(kāi)始使用的訓(xùn)練模型中絮蒿,使用了dropout來(lái)防止訓(xùn)練過(guò)擬合尊搬,但是dropout中有iOS不能執(zhí)行的node操作,并且freeze_graph和optimize_for_inference也不能刪除iOS不支持的節(jié)點(diǎn)土涝。
目前我知道的解決方式就是手動(dòng)的刪除model中iOS不可以支持的節(jié)點(diǎn)佛寿,在這里我們可以直接干掉dropout相關(guān)的節(jié)點(diǎn)。
具體方法參考:
Drop dropout from Tensorflow
optimize_for_inference.py should remove Dropout operations #5867
主意上面的train3.py這個(gè)腳本但壮,這里使用的train3腳本是我對(duì)谷歌給出例子的修改冀泻,使得每訓(xùn)練1000個(gè)數(shù)據(jù)會(huì)自動(dòng)保存一下模型常侣,我們可以訓(xùn)練一會(huì)并ctrl+c取消訓(xùn)練,用已經(jīng)保存的模型來(lái)做預(yù)測(cè)弹渔,雖然預(yù)測(cè)會(huì)不那么準(zhǔn)胳施。我在這個(gè)腳本中創(chuàng)建了graph.pb并且移除了dropout的操作,所以這里訓(xùn)練出來(lái)的模型不會(huì)遇到有node在iOS上不支持的問(wèn)題肢专。
最后
如果你看到這里還沒(méi)有放棄巾乳,那希望你有一點(diǎn)點(diǎn)收獲:P