iOS相比Android移植TensorFlow沒那么方便,要用C++來編寫,接下來講一下iOS調(diào)用TensorFlow的過程洽沟。
- 引入依賴
在Podfile中加入pod 'TensorFlow-experimental'攘宙,再在terminal中cd進(jìn)項(xiàng)目目錄輸入pod install即可安裝依賴。
- 復(fù)制PB文件
快速開發(fā)的話直接把PB文件放在data文件夾里就行划煮,如果正式上線的時(shí)候覺得PB文件一起打包較大的話可以放在服務(wù)器送丰,打開APP的時(shí)候提示下載再?gòu)?fù)制進(jìn)去就好。
- 引入頭文件弛秋、命名空間
#import <opencv2/imgcodecs/ios.h>
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/public/session.h"
#include <tensorflow/core/kernels/reshape_op.h>
#include <tensorflow/core/kernels/argmax_op.h>
using namespace tensorflow;
using namespace tensorflow::core;
- 處理數(shù)據(jù)
圖像處理相比于Android的bitmap操作還是較為麻煩器躏,iOS需要用到opencv,所以還需要引入opencv的依賴蟹略,通過cv的UIImageToMat方法吧UIImage轉(zhuǎn)成cv::Mat再進(jìn)行矩陣操作(類似:灰度處理登失、歸一化、平展)
UIImage *image = [UIImage imageNamed:@"OOLU8095571.jpg"];
self.preImageView.contentMode = UIViewContentModeRedraw;
UIImageToMat(image,cvMatImage);
cvMatImage.convertTo(cvMatImage, CV_32F, 1.0/255., 0);//歸一化
cv::Mat reshapeMat= cvMatImage.reshape(0,1);//reshape
NSString* inference_result = RunModel(reshapeMat);
self.urlContentTextView.text = inference_result;
RunModel(reshapeMat)就是把處理過的數(shù)據(jù)傳遞給TensorFlow去運(yùn)算了挖炬。
- 定義常量
這里跟Android差不多揽浙,定義一些必要的常量,輸入輸出節(jié)點(diǎn)茅茂,輸出輸出節(jié)點(diǎn)數(shù)據(jù)捏萍,圖像尺寸、通道等
std::string input_layer = "inputs/X";
std::string output_layer = "output/predict";
tensorflow::Tensor x(
tensorflow::DT_FLOAT,
tensorflow::TensorShape({wanted_height*wanted_width}));
std::vector<tensorflow::Tensor> outputs;
const int wanted_width = 256;
const int wanted_height = 64;
const int wanted_channels = 1;
- 創(chuàng)建session
這里跟Android不同空闲,需要手動(dòng)創(chuàng)建session
tensorflow::SessionOptions options;
tensorflow::Session* session_pointer = nullptr;
tensorflow::Status session_status = tensorflow::NewSession(options, &session_pointer);
if (!session_status.ok()) {
std::string status_string = session_status.ToString();
return [NSString stringWithFormat: @"Session create failed - %s",
status_string.c_str()];
}
std::unique_ptr<tensorflow::Session> session(session_pointer);
- 載入graph
tensorflow::GraphDef tensorflow_graph;
NSString* network_path = FilePathForResourceName(@"rounded_graph", @"pb");
PortableReadFileToProto([network_path UTF8String], &tensorflow_graph);
tensorflow::Status s = session->Create(tensorflow_graph);
if (!s.ok()) {
LOG(ERROR) << "Could not create TensorFlow Graph: " << s;
return @"";
}
其中FilePathForResourceName是返回graph的地址
NSString* FilePathForResourceName(NSString* name, NSString* extension) {
NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension];
if (file_path == NULL) {
LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "."
<< [extension UTF8String] << "' in bundle.";
}
return file_path;
}
PortableReadFileToProto是把graph讀到內(nèi)存中并賦值給tensorflow_graph令杈,并使用session->Create(tensorflow_graph)把graph載入到session中。
- 輸入數(shù)據(jù)的類型轉(zhuǎn)換
輸入到TensorFlow的數(shù)據(jù)不能是mat類型的所以進(jìn)行mat轉(zhuǎn)vector操作
vector<float> Vmat;
Vmat.assign ( ( float* )ImageMat.datastart, ( float* )ImageMat.dataend );
auto dst = x.flat<float>().data();
auto img = Vmat;
std::copy_n(img.begin(), wanted_width*wanted_height, dst);
- run session
tensorflow::Status run_status = session->Run({{input_layer, x}},
{output_layer}, {}, &outputs);
if (!run_status.ok()) {
LOG(ERROR) << "Running model failed: " << run_status;
tensorflow::LogAllRegisteredKernels();
result = @"Error running model";
return result;
}
- 數(shù)據(jù)變換
使用operator方法獲取到tensor中的每一個(gè)元素值碴倾,重新賦值給array逗噩。
auto outputMatrix = outputs[0].flat<int64>();
array<long,11> outputArray;
for(int i=0;i<11;i++){
outputArray[i]=outputMatrix.operator()(i);
}
NSString *predictionstr = vec2text(outputArray);
獲取完之后需要對(duì)數(shù)據(jù)進(jìn)行處理掉丽,比如我們做的vector轉(zhuǎn)text。
NSString* vec2text(array<long,11> outputArray) {
std::stringstream ss;
ss.precision(12);
ss <<"Prediction:";
for(int i=0;i<11;i++){
long char_idx=outputArray[i];
long char_code = 0;
if (char_idx<10){
char_code = char_idx + int('0');
}
else if (char_idx<36){
char_code = char_idx-10 + int('A');
}
else if (char_idx<62){
char_code = char_idx + int('a');
}
ss << char(char_code);
}
tensorflow::string predictions = ss.str();
NSString* result = [NSString stringWithFormat: @"%s",
predictions.c_str()];
return result;
}
iOS調(diào)用TensorFlow的基礎(chǔ)運(yùn)用就這樣异雁,高級(jí)用法可以使用MemoryMappedModel捶障,這種方法會(huì)比較節(jié)省內(nèi)存,更加優(yōu)雅纲刀。