? ?tensorflow的官網(wǎng)上提供很詳細(xì)的python教程,也確實(shí)很好用凡辱。但是,python始終是一個(gè)開發(fā)工具燕少,應(yīng)用軟件產(chǎn)品男窟,大多用c/c++寫的盆赤。所以,我打算用python訓(xùn)練tensorflow模型歉眷,然后用c++調(diào)用tensorflow模型牺六。本人通過收集資料,總結(jié)了方法汗捡。本文主要講解一下內(nèi)容:
- tensorflow c++加載訓(xùn)練好的模型淑际。
注:均在ubuntu上實(shí)現(xiàn)
1.使用python訓(xùn)練模型,并保存模型扇住。
a.訓(xùn)練模型春缕,保存模型
利用tf.train.Saver()進(jìn)行保存模型。
sess = tf.InteractiveSession() ##session
saver = tf.train.Saver() ###需要添加的代碼艘蹋,在初始化變量前锄贼。
sess.run(tf.initialize_all_variables())
#your code
#....訓(xùn)練過程....
#your code
saver.save(sess, "model/model.ckpt") ###保存模型在model目錄下
model目錄下生成的文件:
- checkpoint
- model.ckpt.data-00000-of-00001
- model.ckpt.index
- model.ckpt.meta
b.模型整合
調(diào)用tensorflow自帶的 freeze_graph.py 小工具, 輸入為格式.pb或.pbtxt的protobuf文件和.ckpt的參數(shù)文件女阀,輸出為一個(gè)新的同時(shí)包含圖定義和參數(shù)的.pb文件宅荤;這個(gè)步驟的作用是把checkpoint .ckpt文件中的參數(shù)轉(zhuǎn)化為常量const operator后和之前的tensor定義綁定在一起。
python freeze_graph.py --input_checkpoint=../ckpt/model.ckpt -- \
output_graph=../model/model_frozen.pb --output_node_names=output_node
得到model_frozen.pb最終模型
2.使用c++加載模型浸策。
a.頭文件包含
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
using namespace tensorflow;
b.tensorflow模型初始化
//加載tensorflow模型
Session *session;
cout << "start initalize session" << "\n";
Status status = NewSession(SessionOptions(), &session);
if (!status.ok()) {
cout << status.ToString() << "\n";
return 1;
}
GraphDef graph_def;
status = ReadBinaryProto(Env::Default(),MNIST_MODEL_PATH, &graph_def);
//MNIST_MODEL_PATH為模型的路徑冯键,即model_frozen.pb的路徑
if (!status.ok()) {
cout << status.ToString() << "\n";
return 1;
}
status = session->Create(graph_def);
if (!status.ok()) {
cout << status.ToString() << "\n";
return 1;
}
cout << "tensorflow加載成功" << "\n";
c.模型預(yù)測(cè)
Tensor x(DT_FLOAT, TensorShape({1, 784}));//定義輸入張量,包括數(shù)據(jù)類型和大小的榛。
std::vector<float> mydata; //輸入數(shù)據(jù)琼了,784維向量
auto dst = x.flat<float>().data();
copy_n(mydata.begin(), 784, dst); //復(fù)制mydata到dst
vector<pair<string, Tensor>> inputs = {
{ "input", x}
}; //定義模型輸入
vector<Tensor> outputs; //定義模型輸出
Status status = session->Run(inputs, {"softmax"}, {}, &outputs); //調(diào)用模型,
//輸出節(jié)點(diǎn)名為softmax,結(jié)果保存在output中。
if (!status.ok()) {
cout << status.ToString() << "\n";
return 1;
}
//get the final label by max probablity
Tensor t = outputs[0]; // Fetch the first tensor
int ndim = t.shape().dims(); // Get the dimension of the tensor
auto tmap = t.tensor<float, 2>();
// Tensor Shape: [batch_size, target_class_num]
// int output_dim = t.shape().dim_size(1);
// Get the target_class_num from 1st dimension
//將結(jié)果保存在softmax數(shù)組中(該模型是多輸出模型)
double softmax[9];
for (int j = 1; j < 10; j++) {
softmax[j-1]=tmap(0, j);
}
參考資料:
- Tensorflow C++ 編譯和調(diào)用圖模型 :即講了安裝tensorflow c++過程,又講了使用過程雕薪。
- Tensorflow CPP API demo :第一個(gè)參考資料的源碼昧诱,很詳細(xì)。
- Training a TensorFlow graph in C++ API :一個(gè)使用tensorflow c++案例所袁。
- Loading a TensorFlow graph with the C++ API :一個(gè)非常好的加載tensorflow模型教程盏档,常被引用。