Tensorflow C++ API調(diào)用Python預(yù)訓(xùn)練模型

最近一段時間研究了如何打通tensorflow線下使用python訓(xùn)練深度學(xué)習(xí)模型,然后線上使用c++調(diào)用預(yù)先訓(xùn)練好的模型完成預(yù)測的流程症歇。畢竟深度學(xué)習(xí)模型上線是需要考慮效率的谭梗,目前來說c++的效率還是python所不能比的。
這篇文章基于tensorflow 1.2版本寫的设塔,tensorflow 1.2版本及以上提供了一種更加方便的c++ API調(diào)用python API訓(xùn)練好的模型远舅。但這方面的資料比較少痕钢,自己也踩了不少坑盖喷,于是寫了一個簡單的使用tensorflow c++ API調(diào)用線下python API訓(xùn)練好的模型的demo课梳,以及如何配置環(huán)境和編譯余佃。

大體的流程如下:

  • 1.使用tensorflow python API編寫和訓(xùn)練自己的模型,訓(xùn)練完成后椭懊,使用tensorflow saver 將模型保存下來步势。
  • 2.使用tensorflow c++ API 構(gòu)建新的session,讀取python版本保存的模型盅抚,然后使用session->run()獲得模型的輸出倔矾。
  • 3.編譯和運行基于tensorflow c++ API寫的代碼。

安裝Bazel

Bazel是一個類似于Make的工具丰包,是Google為其內(nèi)部軟件開發(fā)的特點量身定制的工具邑彪,如今Google使用它來構(gòu)建內(nèi)部大多數(shù)的軟件隙笆。
在編譯 tensorflow c++的時候,需要利用bazel來進(jìn)行編譯的瘸爽,理論上是可以使用Cmake等其工具來編譯的铅忿,但是我嘗試了好久沒有成功,所以最后還是使用了google的bazel進(jìn)行編譯柑潦。希望有大神可以把編譯方法告訴我~
安裝方法按照官方教程走就行。我采用的是直接編譯二進(jìn)制文件的方法览露,這個最簡單直接譬胎,首先下載對應(yīng)版本的二進(jìn)制文件,然后執(zhí)行下面的命令即可:

$ chmod +x bazel-version-installer-os.sh
$ ./bazel-version-installer-os.sh --user

下載Tensorflow源碼

我們需要將Tensorflow源碼下載到本地偏化,后續(xù)編譯tensorflow c++ 代碼需要在這個目錄下進(jìn)行侦讨。在這里需要說明的一點是苟翻,本文采用的c++ API載入python 預(yù)訓(xùn)練模型的方法,是基于tensorflow1.2版本怜俐。所以需要下載tensorflow 1.2版本及以上邓尤,直接從github上clone即可:

$ git clone -b r1.2 https://github.com/tensorflow/tensorflow.git

使用Tensorflow Python API線下定義模型和訓(xùn)練

這里的話我寫了一個十分簡單的基于tensorflow的demo:res=a*b+y汞扎,代碼如下:

# -*-coding:utf-8 -*-
import numpy as np
import tensorflow as tf
import sys, os

if __name__ == '__main__':
    train_dir = os.path.join('demo_model/', "demo")
    a = tf.placeholder(dtype=tf.int32, shape=None, name='a')
    b = tf.placeholder(dtype=tf.int32, shape=None, name='b')
    y = tf.Variable(tf.ones(shape=[1], dtype=tf.int32), dtype=tf.int32, name='y')
    res = tf.add(tf.multiply(a, b), y, name='res')
    with tf.Session() as sess:
        feed_dict = dict()
        feed_dict[a] = 2
        feed_dict[b] = 3
        fetch_list = [res]
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        # 訓(xùn)練和保存模型
        res = sess.run(feed_dict=feed_dict, fetches=fetch_list)
        saver.save(sess, train_dir)

        print("result: ", res[0])

運行結(jié)果如下:

result:  [7]

模型保存在了demo_model/下澈魄,里面的包含四個文件:

checkpoint  #模型checkpoint中的一些文件名的信息
demo.data-00000-of-00001  #模型中保存的各個權(quán)重
demo.index  #可能是保存的各個權(quán)重的索引
demo.meta  #模型構(gòu)造的圖的拓?fù)浣Y(jié)構(gòu)

使用python API 載入和運行模型的代碼如下:

# -*- coding:utf-8 -*-
import tensorflow as tf
import os

if __name__ == '__main__':
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('demo_model/demo.meta')
        saver.restore(sess, tf.train.latest_checkpoint('demo_model/'))
        # sess.run()
        graph = tf.get_default_graph()
        a = graph.get_tensor_by_name("a:0")
        b = graph.get_tensor_by_name("b:0")
        feed_dict = {a: 2, b: 3}

        op_to_restore = graph.get_tensor_by_name("res:0")
        print(sess.run(fetches=op_to_restore, feed_dict=feed_dict))

運行結(jié)果如下:

[7]

使用Tensorflow c++ API讀入預(yù)訓(xùn)練模型

關(guān)于tensorflow c++ API的教程網(wǎng)上資料真的很少痹扇,我只能一邊看著官方文檔一邊查著Stack Overflow慢慢寫了溯香,有些API我現(xiàn)在也不是很清楚怎么用,直接上代碼吧:

//
// Created by MoMo on 17-8-10.
//
#include <iostream>
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"

using namespace std;
using namespace tensorflow;

int main()
{
    const string pathToGraph = "demo_model/demo.meta";
    const string checkpointPath = "demo_model/demo";
    auto session = NewSession(SessionOptions());
    if (session == nullptr)
    {
        throw runtime_error("Could not create Tensorflow session.");
    }

    Status status;

// 讀入我們預(yù)先定義好的模型的計算圖的拓?fù)浣Y(jié)構(gòu)
    MetaGraphDef graph_def;
    status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
    if (!status.ok())
    {
        throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
    }

// 利用讀入的模型的圖的拓?fù)浣Y(jié)構(gòu)構(gòu)建一個session
    status = session->Create(graph_def.graph_def());
    if (!status.ok())
    {
        throw runtime_error("Error creating graph: " + status.ToString());
    }

// 讀入預(yù)先訓(xùn)練好的模型的權(quán)重
    Tensor checkpointPathTensor(DT_STRING, TensorShape());
    checkpointPathTensor.scalar<std::string>()() = checkpointPath;
    status = session->Run(
            {{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
            {},
            {graph_def.saver_def().restore_op_name()},
            nullptr);
    if (!status.ok())
    {
        throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
    }

//  構(gòu)造模型的輸入,相當(dāng)與python版本中的feed
    std::vector<std::pair<string, Tensor>> input;
    tensorflow::TensorShape inputshape;
    inputshape.InsertDim(0,1);
    Tensor a(tensorflow::DT_INT32,inputshape);
    Tensor b(tensorflow::DT_INT32,inputshape);
    auto a_map = a.tensor<int,1>();
    a_map(0) = 2;
    auto b_map = b.tensor<int,1>();
    b_map(0) = 3;
    input.emplace_back(std::string("a"), a);
    input.emplace_back(std::string("b"), b);

//   運行模型炕吸,并獲取輸出
    std::vector<tensorflow::Tensor> answer;
    status = session->Run(input, {"res"}, {}, &answer);

    Tensor result = answer[0];
    auto result_map = result.tensor<int,1>();
    cout<<"result: "<<result_map(0)<<endl;

    return 0;
}

使用tensorflow c++ API讀入預(yù)先訓(xùn)練的模型的大體的流程就是這樣 ,復(fù)雜的模型树肃,可能會需要構(gòu)造更加復(fù)雜的輸入和輸出瀑罗,讀入部分一樣。

編譯和運行

代碼寫了筛谚,最后一步就是編譯和運行了停忿。在這里我采用的是bazel進(jìn)行編譯運行,這里需要寫一個BUILD文件吮铭,內(nèi)容如下:

cc_binary(
    name = "demo",#目標(biāo)文件名
    srcs = ["demo.cc"],#源代碼文件名
    deps = [
        "http://tensorflow/cc:cc_ops",
        "http://tensorflow/cc:client_session",
        "http://tensorflow/core:tensorflow"
        ],
)

然后將代碼颅停,BUILD文件一起放在我們下載下來的tensorflow的源碼的tensorflow/tensorflow/demo目錄下,demo目錄為自己新建的。執(zhí)行如下命令進(jìn)行編譯運行:

bazel build -c opt --copt=-mavx --copt="-ggdb" --copt="-g3" demo/...

經(jīng)過漫長的編譯過程纸肉,大概30分鐘喊熟。會在tensorflow/bazel-bin/tensorflow/demo生成可執(zhí)行文件demo,之后將我們預(yù)先訓(xùn)練好的模型放入相同的目錄,運行即可烦味,下面是運行結(jié)果:

result: 7

總結(jié)

整個tensorflow線下使用python訓(xùn)練深度學(xué)習(xí)模型壁拉,然后線上使用c++調(diào)用預(yù)先訓(xùn)練好的模型完成預(yù)測的流程,基本介紹完了溃论。從這個過程可以看出tensorflow的強(qiáng)大之處案铺,開發(fā)者在開發(fā)之處考慮到了落地工業(yè)界梆靖,提供了這樣一套線上和線下打通的流程笔诵,十分方便。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末测僵,一起剝皮案震驚了整個濱河市捍靠,隨后出現(xiàn)的幾起案子森逮,更是在濱河造成了極大的恐慌,老刑警劉巖良风,帶你破解...
    沈念sama閱讀 222,807評論 6 518
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件闷供,死亡現(xiàn)場離奇詭異,居然都是意外死亡疑俭,警方通過查閱死者的電腦和手機(jī)婿失,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 95,284評論 3 399
  • 文/潘曉璐 我一進(jìn)店門豪硅,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人舟误,你說我怎么就攤上這事嵌溢√Q遥” “怎么了?”我有些...
    開封第一講書人閱讀 169,589評論 0 363
  • 文/不壞的土叔 我叫張陵秧骑,是天一觀的道長乎折。 經(jīng)常有香客問我,道長骂澄,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 60,188評論 1 300
  • 正文 為了忘掉前任磨镶,我火速辦了婚禮琳猫,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘脐嫂。我一直安慰自己侄榴,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 69,185評論 6 398
  • 文/花漫 我一把揭開白布蕊爵。 她就那樣靜靜地躺著桦山,像睡著了一般。 火紅的嫁衣襯著肌膚如雪会放。 梳的紋絲不亂的頭發(fā)上钉凌,一...
    開封第一講書人閱讀 52,785評論 1 314
  • 那天,我揣著相機(jī)與錄音矢沿,去河邊找鬼酸纲。 笑死,一個胖子當(dāng)著我的面吹牛栽惶,可吹牛的內(nèi)容都是我干的愁溜。 我是一名探鬼主播冕象,決...
    沈念sama閱讀 41,220評論 3 423
  • 文/蒼蘭香墨 我猛地睜開眼交惯,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了席爽?” 一聲冷哼從身側(cè)響起啊片,我...
    開封第一講書人閱讀 40,167評論 0 277
  • 序言:老撾萬榮一對情侶失蹤紫谷,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后笤昨,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 46,698評論 1 320
  • 正文 獨居荒郊野嶺守林人離奇死亡捺僻,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,767評論 3 343
  • 正文 我和宋清朗相戀三年匕坯,在試婚紗的時候發(fā)現(xiàn)自己被綠了拔稳。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,912評論 1 353
  • 序言:一個原本活蹦亂跳的男人離奇死亡术奖,死狀恐怖轻绞,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情铲球,我是刑警寧澤稼病,帶...
    沈念sama閱讀 36,572評論 5 351
  • 正文 年R本政府宣布,位于F島的核電站然走,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏晨仑。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 42,254評論 3 336
  • 文/蒙蒙 一洪己、第九天 我趴在偏房一處隱蔽的房頂上張望答捕。 院中可真熱鬧,春花似錦拱镐、人聲如沸持际。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,746評論 0 25
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽芒填。三九已至,卻和暖如春殿衰,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背闷祥。 一陣腳步聲響...
    開封第一講書人閱讀 33,859評論 1 274
  • 我被黑心中介騙來泰國打工凯砍, 沒想到剛下飛機(jī)就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人悟衩。 一個月前我還...
    沈念sama閱讀 49,359評論 3 379
  • 正文 我出身青樓座泳,卻偏偏與公主長得像幕与,于是被迫代替她去往敵國和親镇防。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,922評論 2 361

推薦閱讀更多精彩內(nèi)容

  • 1. 介紹 首先讓我們來看看TensorFlow诫给! 但是在我們開始之前啦扬,我們先來看看Python API中的Ten...
    JasonJe閱讀 11,756評論 1 32
  • 簡介 由于生產(chǎn)環(huán)境使用windows、C++吃型,而tensorflow模型訓(xùn)練使用python更為方便僚楞,因此存在需求...
    菜鳥游俠k2閱讀 6,035評論 0 2
  • 愛上民謠第一首歌就是火了很久的南山南,可惜是因為滿是套路的某選秀節(jié)目才被人們熟悉赐写,在節(jié)目開播之前,這首歌我循環(huán)了不...
    壞人王閱讀 828評論 0 0
  • 2015年31號凌晨我們坐上了去往桂林的火車,買的是臥鋪换淆。這是大黃與二哈的第一次結(jié)伴旅游。我跟柜子在8號車廂倍试,大黃...
    鐘無意閱讀 396評論 4 4
  • 小果和果爸去列士公園玩到嗨翻天蛋哭,果媽獨自在家自得意趣,做做手工花藝畫會畫兒。 小果回來哈蝇,直奔媽媽:媽媽味赃,我下次帶你...
    羚羊漫步閱讀 224評論 0 1