將tensorflow訓(xùn)練好的模型移植到android上
說(shuō)明
本文將描述如何將一個(gè)訓(xùn)練好的模型植入到android設(shè)備上旷余,并且在android設(shè)備上輸入待處理數(shù)據(jù),通過(guò)模型,獲取輸出數(shù)據(jù)壮不。
通過(guò)一個(gè)例子,講述整個(gè)移植的過(guò)程。(demo的源碼訪問(wèn)github上了https://github.com/CrystalChen1017/TSFOnAndroid)
整體的思路如下:
- 使用python在PC上訓(xùn)練好你的模型尘喝,保存為pb文件
- 新建android project,把pb文件放到assets文件夾下
- 將tensorflow的so文件以及jar包放到libs下
- 加載庫(kù)文件斋陪,讓tensorflow在app中運(yùn)行起來(lái)
準(zhǔn)備
- tensorflow的環(huán)境朽褪,參閱http://blog.csdn.net/cxq234843654/article/details/70857562
- libtensorflow_inference.so
- libandroid_tensorflow_inference_java.jar
- 如果要自己編譯得到以上兩個(gè)文件,需要安裝bazel无虚。參閱http://blog.csdn.net/cxq234843654/article/details/70861155 的第2步
以上兩個(gè)文件通過(guò)以下兩個(gè)網(wǎng)址進(jìn)行下載:
https://github.com/CrystalChen1017/TSFOnAndroid/tree/master/app/libs
或者
http://download.csdn.net/detail/cxq234843654/9833372
PC端模型的準(zhǔn)備
這是一個(gè)很簡(jiǎn)單的模型缔赠,輸入是一個(gè)數(shù)組matrix1,經(jīng)過(guò)操作后友题,得到這個(gè)數(shù)組乘以2*matrix1嗤堰。
- 給輸入數(shù)據(jù)命名為
input
,在android端需要用這個(gè)input
來(lái)為輸入數(shù)據(jù)賦值 - 給輸輸數(shù)據(jù)命名為
output
,在android端需要用這個(gè)output
來(lái)為獲取輸出的值 - 不能使用 tf.train.write_graph()保存模型,因?yàn)樗皇潜4媪四P偷慕Y(jié)構(gòu)度宦,并不保存訓(xùn)練完畢的參數(shù)值
- 不能使用 tf.train.saver()保存模型踢匣,因?yàn)樗皇潜4媪司W(wǎng)絡(luò)中的參數(shù)值,并不保存模型的結(jié)構(gòu)戈抄。
-
graph_util.convert_variables_to_constants
可以把整個(gè)sesion當(dāng)作常量都保存下來(lái)离唬,通過(guò)output_node_names
參數(shù)來(lái)指定輸出 -
tf.gfile.FastGFile('model/cxq.pb', mode='wb')
指定保存文件的路徑以及讀寫(xiě)方式 -
f.write(output_graph_def.SerializeToString())
將固化的模型寫(xiě)入到文件
# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.python.client import graph_util
session = tf.Session()
matrix1 = tf.constant([[3., 3.]], name='input')
add2Mat = tf.add(matrix1, matrix1, name='output')
session.run(add2Mat)
output_graph_def = graph_util.convert_variables_to_constants(session, session.graph_def,output_node_names=['output'])
with tf.gfile.FastGFile('model/cxq.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())
session.close()
運(yùn)行后就會(huì)在model文件夾下產(chǎn)生一個(gè)cxq.pb文件,現(xiàn)在這個(gè)文件將剛才一系列的操作固化了划鸽,因此下次需要計(jì)算變量乘2時(shí)输莺,我們可以直接拿到pb文件,指定輸入裸诽,再獲取輸出嫂用。
(可選的)bazel編譯出so和jar文件
如果希望自己通過(guò)tensorflow的源碼編譯出so和jar文件,則需要通過(guò)終端進(jìn)入到tensorflow的目錄下崭捍,進(jìn)行如下操作:
- 編譯so庫(kù)
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
-- crosstool_top=//external:android/crosstool \
-- host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
-- cpu=armeabi-v7a
編譯完畢后尸折,libtensorflow_inference.so的路徑為:
/tensorflow/bazel-bin/tensorflow/contrib/android
- 編譯jar包
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
編譯完畢后,android_tensorflow_inference_java.jar的路徑為:
/tensorflow/bazel-bin/tensorflow/contrib/android
android端的準(zhǔn)備
- 新建一個(gè)Android Project
- 把剛才的pb文件存放到assets文件夾下
- 將libandroid_tensorflow_inference_java.jar存放到/app/libs目錄下殷蛇,并且右鍵“add as Libary”
- 在/app/libs下新建armeabi文件夾实夹,并將libtensorflow_inference.so放進(jìn)去
配置app:gradle以及gradle.properties
- 在android節(jié)點(diǎn)下添加soureSets,用于制定jniLibs的路徑
sourceSets {
main {
jniLibs.srcDirs = ['libs']
}
}
- 在defaultConfig節(jié)點(diǎn)下添加
defaultConfig {
ndk {
abiFilters "armeabi"
}
}
- 在gradle.properties中添加下面一行
android.useDeprecatedNdk=true
通過(guò)以上3步操作粒梦,tensorflow的環(huán)境已經(jīng)部署好了亮航。
模型的調(diào)用
我們先新建一個(gè)MyTSF類,在這個(gè)類里面進(jìn)行模型的調(diào)用匀们,并且獲取輸出
package com.learn.tsfonandroid;
import android.content.res.AssetManager;
import android.os.Trace;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
public class MyTSF {
private static final String MODEL_FILE = "file:///android_asset/cxq.pb"; //模型存放路徑
//數(shù)據(jù)的維度
private static final int HEIGHT = 1;
private static final int WIDTH = 2;
//模型中輸出變量的名稱
private static final String inputName = "input";
//用于存儲(chǔ)的模型輸入數(shù)據(jù)
private float[] inputs = new float[HEIGHT * WIDTH];
//模型中輸出變量的名稱
private static final String outputName = "output";
//用于存儲(chǔ)模型的輸出數(shù)據(jù)
private float[] outputs = new float[HEIGHT * WIDTH];
TensorFlowInferenceInterface inferenceInterface;
static {
//加載庫(kù)文件
System.loadLibrary("tensorflow_inference");
}
MyTSF(AssetManager assetManager) {
//接口定義
inferenceInterface = new TensorFlowInferenceInterface(assetManager,MODEL_FILE);
}
public float[] getAddResult() {
//為輸入數(shù)據(jù)賦值
inputs[0]=1;
inputs[1]=3;
//將數(shù)據(jù)feed給tensorflow
Trace.beginSection("feed");
inferenceInterface.feed(inputName, inputs, WIDTH, HEIGHT);
Trace.endSection();
//運(yùn)行乘2的操作
Trace.beginSection("run");
String[] outputNames = new String[] {outputName};
inferenceInterface.run(outputNames);
Trace.endSection();
//將輸出存放到outputs中
Trace.beginSection("fetch");
inferenceInterface.fetch(outputName, outputs);
Trace.endSection();
return outputs;
}
}
在Activity中使用MyTSF類
public void click01(View v){
Log.i(TAG, "click01: ");
MyTSF mytsf=new MyTSF(getAssets());
float[] result=mytsf.getAddResult();
for (int i=0;i<result.length;i++){
Log.i(TAG, "click01: "+result[i] );
}
}