在Android工程中宇色,集成TensorFlow模型。運(yùn)行TensorFlow的默認(rèn)Android工程施禾,請(qǐng)參考脚线。
Android源碼:https://github.com/SpikeKing/TFAndroid/tree/master
庫(kù)及模型的大小
libtensorflow_inference.so 10.2 M
libandroid_tensorflow_inference_java.jar 27 KB
optimized_tfdroid.pb 291 B
如果將so轉(zhuǎn)換為jar庫(kù),參考弥搞,則TF的so由10.2M縮小至4.1M邮绿。
TensorFlow
TF模型源碼:
https://github.com/SpikeKing/MachineLearningTutorial/blob/master/tests/android_test.py
創(chuàng)建TensorFlow模型,簡(jiǎn)單的y=WX+b
攀例,存儲(chǔ)圖信息write_graph
船逮,存儲(chǔ)參數(shù)信息saver.save
。輸入數(shù)據(jù)placeholder是I
肛度,輸出數(shù)據(jù)是O
傻唾。
import tensorflow as tf
I = tf.placeholder(tf.float32, shape=[None, 3], name='I') # input
W = tf.Variable(tf.zeros(shape=[3, 2]), dtype=tf.float32, name='W') # weights
b = tf.Variable(tf.zeros(shape=[2]), dtype=tf.float32, name='b') # biases
O = tf.nn.relu(tf.matmul(I, W) + b, name='O') # activation / output
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
tf.train.write_graph(sess.graph_def, './data/android/', 'tfdroid.pbtxt') # 存儲(chǔ)TensorFlow的圖
# 訓(xùn)練數(shù)據(jù),本例直接賦值
sess.run(tf.assign(W, [[1, 2], [4, 5], [7, 8]]))
sess.run(tf.assign(b, [1, 1]))
# 存儲(chǔ)checkpoint文件承耿,即參數(shù)信息
saver.save(sess, './data/android/tfdroid.ckpt')
創(chuàng)建Freeze的圖冠骄,將圖結(jié)構(gòu)與參數(shù)組合在一起,生成模型加袋,參考凛辣。
def gnr_freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
output_node_names, output_graph, clear_devices):
"""
將輸入圖與參數(shù)結(jié)合在一起
:param input_graph: 輸入圖
:param input_saver: Saver解析器
:param input_binary: 輸入圖的格式,false是文本职烧,true是二進(jìn)制
:param input_checkpoint: checkpoint扁誓,檢查點(diǎn)文件
:param output_node_names: 輸出節(jié)點(diǎn)名稱
:param output_graph: 保存輸出文件
:param clear_devices: 清除訓(xùn)練設(shè)備
:return: NULL
"""
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
freeze_graph.freeze_graph(
input_graph=input_graph, # 輸入圖
input_saver=input_saver, # Saver解析器
input_binary=input_binary, # 輸入圖的格式,false是文本蚀之,true是二進(jìn)制
input_checkpoint=input_checkpoint, # checkpoint蝗敢,檢查點(diǎn)文件
output_node_names=output_node_names, # 輸出節(jié)點(diǎn)名稱
restore_op_name=restore_op_name, # 從模型恢復(fù)節(jié)點(diǎn)的名字
filename_tensor_name=filename_tensor_name, # tensor名稱
output_graph=output_graph, # 保存輸出文件
clear_devices=clear_devices, # 清除訓(xùn)練設(shè)備
initializer_nodes="") # 初始化節(jié)點(diǎn)
優(yōu)化模型,剪切節(jié)點(diǎn)足删,模型只保留輸入輸出的參數(shù)寿谴。
def gnr_optimize_graph(graph_path, optimized_graph_path):
"""
優(yōu)化圖
:param graph_path: 原始圖
:param optimized_graph_path: 優(yōu)化的圖
:return: NULL
"""
input_graph_def = tf.GraphDef() # 讀取原始圖
with tf.gfile.Open(graph_path, "r") as f:
data = f.read()
input_graph_def.ParseFromString(data)
# 設(shè)置輸入輸出節(jié)點(diǎn),剪切分支失受,大約節(jié)省1/4
output_graph_def = optimize_for_inference_lib.optimize_for_inference(
input_graph_def,
["I"], # an array of the input node(s)
["O"], # an array of output nodes
tf.float32.as_datatype_enum)
# 存儲(chǔ)優(yōu)化的圖
f = tf.gfile.FastGFile(optimized_graph_path, "w")
f.write(output_graph_def.SerializeToString())
執(zhí)行函數(shù)讶泰,生成模型咏瑟,frozen_tfdroid.pb
和optimized_tfdroid.pb
。
if __name__ == "__main__":
input_graph_path = MODEL_FOLDER + MODEL_NAME + '.pbtxt' # 輸入圖
checkpoint_path = MODEL_FOLDER + MODEL_NAME + '.ckpt' # 輸入?yún)?shù)
output_path = MODEL_FOLDER + 'frozen_' + MODEL_NAME + '.pb' # Freeze模型
gnr_freeze_graph(input_graph=input_graph_path, input_saver="",
input_binary=False, input_checkpoint=checkpoint_path,
output_node_names="O", output_graph=output_path, clear_devices=True)
optimized_output_graph = MODEL_FOLDER + 'optimized_' + MODEL_NAME + '.pb'
gnr_optimize_graph(output_path, optimized_output_graph)
Android
編譯Android的庫(kù)痪署,參考码泞,或者,直接在Nightly中下載狼犯,參考余寥,archive.zip,大約158M辜王。
創(chuàng)建Android工程劈狐,添加app/libs/
中添加庫(kù)文件。
armeabi-v7a/libtensorflow_inference.so
libandroid_tensorflow_inference_java.jar
在build.gradle中呐馆,添加
android {
sourceSets {
main {
jniLibs.srcDirs = ['libs']
}
}
}
在app/src/main/assets中肥缔,添加模型optimized_tfdroid.pb
文件。
在MainActivity中汹来,添加so庫(kù)续膳。
static {
System.loadLibrary("tensorflow_inference");
}
模型文件在assets中,TF的核心接口類TensorFlowInferenceInterface收班。
private static final String MODEL_FILE = "file:///android_asset/optimized_tfdroid.pb";
private TensorFlowInferenceInterface mInferenceInterface;
初始模型文件
mInferenceInterface = new TensorFlowInferenceInterface();
mInferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);
模型Feed數(shù)據(jù)坟岔,輸入點(diǎn)名稱是INPUT_NODE
,輸入結(jié)構(gòu)INPUT_SIZE
摔桦,輸入數(shù)據(jù)inputFloats社付。
float[] inputFloats = {num1, num2, num3};
mInferenceInterface.fillNodeFloat(INPUT_NODE, INPUT_SIZE, inputFloats);
模型執(zhí)行文件,輸出點(diǎn)名稱是OUTPUT_NODE
邻耕,即"O"
mInferenceInterface.runInference(new String[]{OUTPUT_NODE});
輸出數(shù)據(jù)結(jié)構(gòu)
float[] resu = {0, 0};
mInferenceInterface.readNodeFloat(OUTPUT_NODE, resu);
最后鸥咖,在layout中創(chuàng)建GUI布局。
效果
TensorFlow集成至春雨醫(yī)生
That's all! Enjoy it!