Python+Android進行TensorFlow開發(fā)

tensorflow

Tensorflow是Google開源的一套機器學習框架素挽,支持GPU照卦、CPU式矫、Android等多種計算平臺。本文將介紹在Tensorflow在Android上的使用役耕。

Android使用Tensorflow框架需要引入兩個文件libtensorflow_inference.so采转、libandroid_tensorflow_inference_java.jar。這兩個文件可以使用官方預編譯的文件瞬痘。如果預編譯的so不滿足要求(比如不支持訓練模型中的某些操作符運算)故慈,也可以自己通過bazel編譯生成這兩個文件板熊。

將libandroid_tensorflow_inference_java.jar放在app下的libs目錄下,so文件命名為libtensorflow_jni.so放在src/main/jniLibs目錄下對應的ABI文件夾下察绷。目錄結構如下:

android目錄結構

同時在app的build.gradle中的dependencies模塊下添加如下配置:

dependencies {
    ...
    compile files('libs/libandroid_tensorflow_inference_java.jar')
    ...
}

使用tensorflow框架進行機器學習分為四個步驟:

  1. 構造神經(jīng)網(wǎng)絡
  2. 訓練神經(jīng)網(wǎng)絡模型
  3. 將訓練好的模型輸出為pb文件
  4. 在Android上加載pb模型進行計算

前三步是模型的構造干签,我們通過python實現(xiàn),下面給出了一個二分類的簡單模型的構造過程克婶,首先是訓練過程:

# -*-coding:utf-8 -*-
from __future__ import print_function
import os
import tensorflow as tf
from numpy.random import RandomState

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

"""
訓練模型
"""
def train():
    # 定義訓練數(shù)據(jù)集batch大小為8
    batch_size = 8

    # 定義神經(jīng)網(wǎng)絡參數(shù)筒严,參數(shù)體現(xiàn)出神經(jīng)網(wǎng)絡結構,一個輸入層情萤,一個輸出層,一個隱藏層
    w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1), name="w1_val")
    w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1), name="w2_val")

    # 定義輸入輸出格式
    x = tf.placeholder(tf.float32, shape=(None, 2), name='x_input')
    y_ = tf.placeholder(tf.float32, shape=(None, 1))

    # 定義神經(jīng)網(wǎng)絡前向傳播過程
    a = tf.matmul(x, w1)
    y = tf.matmul(a, w2, name="cal_node")

    # 定義交叉熵和反向傳播算法
    cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
    train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy)

    # 生成隨機訓練集
    rdm = RandomState(1)
    dataset_size = 128

    # 定義映射關系
    X = rdm.rand(dataset_size, 2)
    Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]

    with tf.Session() as sess:
        # 初始化所有參數(shù)
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # print sess.run(w1)
        # print sess.run(w2)

        STEPS = 500
        for i in range(STEPS):
            start = (i * batch_size) % dataset_size
            end = min(start + batch_size, dataset_size)

            # 訓練神經(jīng)網(wǎng)絡摹恨,更新神經(jīng)網(wǎng)絡參數(shù)
            sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})

            if i % 100 == 0:
                total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
                print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy))

            print(sess.run(w1))
            print(sess.run(w2))

        # 保存check point
        saver = tf.train.Saver(tf.trainable_variables())
        saver.save(sess, './model/checpt')

上面的代碼首先定義神經(jīng)網(wǎng)絡筋岛,初始化訓練數(shù)據(jù),進行500次訓練過程晒哄,并將訓練結果checkpoints保存到model文件夾下睁宰,checkpoints包含了訓練模型得到的參數(shù)信息,共生成四個相關的文件寝凌,如下圖:

checkpoint相關文件

由于checkpoint文件眾多柒傻,為了方便使用,我們通過下面的代碼將它們生成一個pb文件较木,在android上只需要這個pb文件即可使用這個訓練好的模型:

"""
存儲pb模型
"""
def dump_graph_to_pb(pb_path):
    with tf.Session() as sess:
        check_point = tf.train.get_checkpoint_state("./model/")
        if check_point:
            saver = tf.train.import_meta_graph(check_point.model_checkpoint_path + '.meta')
            saver.restore(sess, check_point.model_checkpoint_path)
        else:
            raise ValueError("Model load failed from {}".format(check_point.model_checkpoint_path))

        graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), "cal_node".split(","))

        with tf.gfile.GFile(pb_path, "wb") as f:
            f.write(graph_def.SerializeToString())

拿到生成的pb模型红符,我們可以在android上使用了。將pb文件在這main/assets下:

image.png

接下來就可以載入pb伐债,進行計算了:

public class MainActivity extends AppCompatActivity {
    private Graph graph_;
    private Session session_;
    private AssetManager assetManager;

    private static ExecutorService executorService;
    private static Handler handler;
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        executorService = Executors.newFixedThreadPool(5);

        // 初始化tensorflow
        initTensorFlow("outmodel.pb");

        // 使用tensorflow進行計算
        runTensorFlow();
    }
    ...
}

通過如下方式載入pb模型预侯,初始化tensorflow:

private boolean initTensorFlow(String modelFile) {
        assetManager = getAssets();
        // 新建Graph
        graph_ = new Graph();

        InputStream is = null;
        try {
            // 讀取Assets pb文件
            is = assetManager.open(modelFile);
        } catch (IOException e) {
            e.printStackTrace();
            return false;
        }

        try {
            // 加載pb到Graph
            TensorUtil.loadGraph(is, graph_);
            is.close();
        } catch (IOException e) {
            e.printStackTrace();
            return false;
        }
        // 初始化session
        session_ = new Session(graph_);
        if (session_ == null) {
            return false;
        }

        return true;
    }

然后就可以使用tensorflow API進行運算了:

private void runTensorFlow() {
        executorService.execute(generatePredictRunnable(handler));
    }

    private Runnable generatePredictRunnable(Handler handler) {
        return new Runnable() {
            @Override
            public void run() {
                float[][] input = new float[1][2];

                input[0][0] = 1;
                input[0][1] = 2;

                // 定義輸入tensor
                Tensor inputTensor = Tensor.create(input);

                // 指定輸入,輸出節(jié)點峰锁,運行并得到結果
                Tensor resultTensor = session_.runner()
                        .feed("x_input", inputTensor)
                        .fetch("cal_node")
                        .run()
                        .get(0);

                float[][] dst = new float[1][1];
                resultTensor.copyTo(dst);

                // 處理結果
                ArrayList<Float> resultList = new ArrayList<>();
                for (float val : dst[0]) {
                    if (val != 0) {
                        resultList.add(val);
                    } else {
                        break;
                    }
                }
            }
        };
    }

上面就是通過python訓練機器學習模型萎馅,并在android平臺進行調用的完整流程。

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
  • 序言:七十年代末虹蒋,一起剝皮案震驚了整個濱河市糜芳,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌魄衅,老刑警劉巖峭竣,帶你破解...
    沈念sama閱讀 217,277評論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異徐绑,居然都是意外死亡,警方通過查閱死者的電腦和手機傲茄,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,689評論 3 393
  • 文/潘曉璐 我一進店門毅访,熙熙樓的掌柜王于貴愁眉苦臉地迎上來沮榜,“玉大人,你說我怎么就攤上這事喻粹◇∪冢” “怎么了?”我有些...
    開封第一講書人閱讀 163,624評論 0 353
  • 文/不壞的土叔 我叫張陵守呜,是天一觀的道長型酥。 經(jīng)常有香客問我,道長查乒,這世上最難降的妖魔是什么弥喉? 我笑而不...
    開封第一講書人閱讀 58,356評論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮玛迄,結果婚禮上由境,老公的妹妹穿的比我還像新娘。我一直安慰自己蓖议,他們只是感情好虏杰,可當我...
    茶點故事閱讀 67,402評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著勒虾,像睡著了一般纺阔。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上修然,一...
    開封第一講書人閱讀 51,292評論 1 301
  • 那天笛钝,我揣著相機與錄音,去河邊找鬼低零。 笑死婆翔,一個胖子當著我的面吹牛,可吹牛的內容都是我干的掏婶。 我是一名探鬼主播啃奴,決...
    沈念sama閱讀 40,135評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼雄妥!你這毒婦竟也來了最蕾?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 38,992評論 0 275
  • 序言:老撾萬榮一對情侶失蹤老厌,失蹤者是張志新(化名)和其女友劉穎瘟则,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體枝秤,經(jīng)...
    沈念sama閱讀 45,429評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡醋拧,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 37,636評論 3 334
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片丹壕。...
    茶點故事閱讀 39,785評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡庆械,死狀恐怖,靈堂內的尸體忽然破棺而出菌赖,到底是詐尸還是另有隱情缭乘,我是刑警寧澤,帶...
    沈念sama閱讀 35,492評論 5 345
  • 正文 年R本政府宣布琉用,位于F島的核電站堕绩,受9級特大地震影響,放射性物質發(fā)生泄漏邑时。R本人自食惡果不足惜奴紧,卻給世界環(huán)境...
    茶點故事閱讀 41,092評論 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望晶丘。 院中可真熱鬧绰寞,春花似錦、人聲如沸铣口。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,723評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽脑题。三九已至,卻和暖如春铜靶,著一層夾襖步出監(jiān)牢的瞬間叔遂,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,858評論 1 269
  • 我被黑心中介騙來泰國打工争剿, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留已艰,地道東北人。 一個月前我還...
    沈念sama閱讀 47,891評論 2 370
  • 正文 我出身青樓蚕苇,卻偏偏與公主長得像哩掺,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子涩笤,可洞房花燭夜當晚...
    茶點故事閱讀 44,713評論 2 354

推薦閱讀更多精彩內容