Tensorflow是Google開源的一套機器學習框架素挽,支持GPU照卦、CPU式矫、Android等多種計算平臺。本文將介紹在Tensorflow在Android上的使用役耕。
Android使用Tensorflow框架需要引入兩個文件libtensorflow_inference.so采转、libandroid_tensorflow_inference_java.jar。這兩個文件可以使用官方預編譯的文件瞬痘。如果預編譯的so不滿足要求(比如不支持訓練模型中的某些操作符運算)故慈,也可以自己通過bazel編譯生成這兩個文件板熊。
將libandroid_tensorflow_inference_java.jar放在app下的libs目錄下,so文件命名為libtensorflow_jni.so放在src/main/jniLibs目錄下對應的ABI文件夾下察绷。目錄結構如下:
同時在app的build.gradle中的dependencies模塊下添加如下配置:
dependencies {
...
compile files('libs/libandroid_tensorflow_inference_java.jar')
...
}
使用tensorflow框架進行機器學習分為四個步驟:
- 構造神經(jīng)網(wǎng)絡
- 訓練神經(jīng)網(wǎng)絡模型
- 將訓練好的模型輸出為pb文件
- 在Android上加載pb模型進行計算
前三步是模型的構造干签,我們通過python實現(xiàn),下面給出了一個二分類的簡單模型的構造過程克婶,首先是訓練過程:
# -*-coding:utf-8 -*-
from __future__ import print_function
import os
import tensorflow as tf
from numpy.random import RandomState
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
"""
訓練模型
"""
def train():
# 定義訓練數(shù)據(jù)集batch大小為8
batch_size = 8
# 定義神經(jīng)網(wǎng)絡參數(shù)筒严,參數(shù)體現(xiàn)出神經(jīng)網(wǎng)絡結構,一個輸入層情萤,一個輸出層,一個隱藏層
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1), name="w1_val")
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1), name="w2_val")
# 定義輸入輸出格式
x = tf.placeholder(tf.float32, shape=(None, 2), name='x_input')
y_ = tf.placeholder(tf.float32, shape=(None, 1))
# 定義神經(jīng)網(wǎng)絡前向傳播過程
a = tf.matmul(x, w1)
y = tf.matmul(a, w2, name="cal_node")
# 定義交叉熵和反向傳播算法
cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy)
# 生成隨機訓練集
rdm = RandomState(1)
dataset_size = 128
# 定義映射關系
X = rdm.rand(dataset_size, 2)
Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]
with tf.Session() as sess:
# 初始化所有參數(shù)
init_op = tf.global_variables_initializer()
sess.run(init_op)
# print sess.run(w1)
# print sess.run(w2)
STEPS = 500
for i in range(STEPS):
start = (i * batch_size) % dataset_size
end = min(start + batch_size, dataset_size)
# 訓練神經(jīng)網(wǎng)絡摹恨,更新神經(jīng)網(wǎng)絡參數(shù)
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 100 == 0:
total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy))
print(sess.run(w1))
print(sess.run(w2))
# 保存check point
saver = tf.train.Saver(tf.trainable_variables())
saver.save(sess, './model/checpt')
上面的代碼首先定義神經(jīng)網(wǎng)絡筋岛,初始化訓練數(shù)據(jù),進行500次訓練過程晒哄,并將訓練結果checkpoints保存到model文件夾下睁宰,checkpoints包含了訓練模型得到的參數(shù)信息,共生成四個相關的文件寝凌,如下圖:
由于checkpoint文件眾多柒傻,為了方便使用,我們通過下面的代碼將它們生成一個pb文件较木,在android上只需要這個pb文件即可使用這個訓練好的模型:
"""
存儲pb模型
"""
def dump_graph_to_pb(pb_path):
with tf.Session() as sess:
check_point = tf.train.get_checkpoint_state("./model/")
if check_point:
saver = tf.train.import_meta_graph(check_point.model_checkpoint_path + '.meta')
saver.restore(sess, check_point.model_checkpoint_path)
else:
raise ValueError("Model load failed from {}".format(check_point.model_checkpoint_path))
graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), "cal_node".split(","))
with tf.gfile.GFile(pb_path, "wb") as f:
f.write(graph_def.SerializeToString())
拿到生成的pb模型红符,我們可以在android上使用了。將pb文件在這main/assets下:
接下來就可以載入pb伐债,進行計算了:
public class MainActivity extends AppCompatActivity {
private Graph graph_;
private Session session_;
private AssetManager assetManager;
private static ExecutorService executorService;
private static Handler handler;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
executorService = Executors.newFixedThreadPool(5);
// 初始化tensorflow
initTensorFlow("outmodel.pb");
// 使用tensorflow進行計算
runTensorFlow();
}
...
}
通過如下方式載入pb模型预侯,初始化tensorflow:
private boolean initTensorFlow(String modelFile) {
assetManager = getAssets();
// 新建Graph
graph_ = new Graph();
InputStream is = null;
try {
// 讀取Assets pb文件
is = assetManager.open(modelFile);
} catch (IOException e) {
e.printStackTrace();
return false;
}
try {
// 加載pb到Graph
TensorUtil.loadGraph(is, graph_);
is.close();
} catch (IOException e) {
e.printStackTrace();
return false;
}
// 初始化session
session_ = new Session(graph_);
if (session_ == null) {
return false;
}
return true;
}
然后就可以使用tensorflow API進行運算了:
private void runTensorFlow() {
executorService.execute(generatePredictRunnable(handler));
}
private Runnable generatePredictRunnable(Handler handler) {
return new Runnable() {
@Override
public void run() {
float[][] input = new float[1][2];
input[0][0] = 1;
input[0][1] = 2;
// 定義輸入tensor
Tensor inputTensor = Tensor.create(input);
// 指定輸入,輸出節(jié)點峰锁,運行并得到結果
Tensor resultTensor = session_.runner()
.feed("x_input", inputTensor)
.fetch("cal_node")
.run()
.get(0);
float[][] dst = new float[1][1];
resultTensor.copyTo(dst);
// 處理結果
ArrayList<Float> resultList = new ArrayList<>();
for (float val : dst[0]) {
if (val != 0) {
resultList.add(val);
} else {
break;
}
}
}
};
}
上面就是通過python訓練機器學習模型萎馅,并在android平臺進行調用的完整流程。