這兩天算法同事那邊用keras訓(xùn)練了個二分類的模型返干。
有一個新的需求是把keras模型跑到 tensorflow serving上 (TensorFlow Serving 系統(tǒng)用于在生產(chǎn)環(huán)境中運行模型)激涤。
在這之前我并沒有接觸過keras微王、tensorflow 蔼夜, 官方教程和一堆的博客論壇資料有些過時兼耀,(keras 模型轉(zhuǎn) tensorflow 模型的示例代碼跑不動),過程不太順利求冷,于是花了一天學(xué)習(xí) keras瘤运、 tensorlow, 寫個小demo,再追蹤一下keras和tensorflow源代碼匠题,耗時兩天終于把這個需求實現(xiàn)了拯坟。這里記錄填坑過程。
keras模型轉(zhuǎn) tensorflow模型
我把 keras模型轉(zhuǎn)tensorflow serving模型所使用的方法如下:
1梧躺、要拿到算法訓(xùn)練好的keras模型文件(一個HDF5文件)
該文件應(yīng)該包含:
- 模型的結(jié)構(gòu)似谁,以便重構(gòu)該模型
- 模型的權(quán)重
- 訓(xùn)練配置(損失函數(shù)傲绣,優(yōu)化器等)
- 優(yōu)化器的狀態(tài)掠哥,以便于從上次訓(xùn)練中斷的地方開始
2、編寫 keras模型轉(zhuǎn)tensorflow serving模型的代碼
import tensorflow as tf
from keras import backend as K
from keras.models import Sequential, Model
from os.path import isfile
def build_model():
model = Sequential()
# 省略這部分代碼秃诵,根據(jù)算法實際情況填寫
return model
def save_model_to_serving(model, export_version, export_path='prod_models'):
print(model.input, model.output)
signature = tf.saved_model.signature_def_utils.predict_signature_def(
inputs={'voice': model.input}, outputs={'scores': model.output})
export_path = os.path.join(
tf.compat.as_bytes(export_path),
tf.compat.as_bytes(str(export_version)))
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
sess=K.get_session(),
tags=[tf.saved_model.tag_constants.SERVING],
signature_def_map={
'voice_classification': signature,
},
legacy_init_op=legacy_init_op)
builder.save()
if __name__ == '__main__':
model = build_model()
model.compile(loss='categorical_crossentropy',
optimizer='xxx', # 用實際算法情況替換這里的xxx
metrics=['xxx'])
model.summary()
checkpoint_filepath = 'weights.hdf5'
if (isfile(checkpoint_filepath)):
print('Checkpoint file detected. Loading weights.')
model.load_weights(checkpoint_filepath) # 加載模型
else:
print('No checkpoint file detected. Starting from scratch.')
export_path = "test_model"
save_model_to_serving(model, "1", export_path)
上面的例子將模型保存到 test_model目錄下
test_model目錄結(jié)構(gòu)如下:
test_model/
└── 1
├── saved_model.pb
└── variables
├── variables.data-00000-of-00001
└── variables.index
saved_model.pb 是能在 tensorflow serving跑起來的模型续搀。
3、跑模型
tensorflow_model_server --port=8500 --model_name="voice" --model_base_path="/home/yu/workspace/test/test_model/"
標準輸出如下(算法模型已成功跑起來了):
2018-02-08 16:28:02.641662: I tensorflow_serving/model_servers/main.cc:149] Building single TensorFlow model file config: model_name: voice model_base_path: /home/yu/workspace/test/test_model/
2018-02-08 16:28:02.641917: I tensorflow_serving/model_servers/server_core.cc:439] Adding/updating models.
2018-02-08 16:28:02.641976: I tensorflow_serving/model_servers/server_core.cc:490] (Re-)adding model: voice
2018-02-08 16:28:02.742740: I tensorflow_serving/core/basic_manager.cc:705] Successfully reserved resources to load servable {name: voice version: 1}
2018-02-08 16:28:02.742800: I tensorflow_serving/core/loader_harness.cc:66] Approving load for servable version {name: voice version: 1}
2018-02-08 16:28:02.742815: I tensorflow_serving/core/loader_harness.cc:74] Loading servable version {name: voice version: 1}
2018-02-08 16:28:02.742867: I external/org_tensorflow/tensorflow/contrib/session_bundle/bundle_shim.cc:360] Attempting to load native SavedModelBundle in bundle-shim from: /home/yu/workspace/test/test_model/1
2018-02-08 16:28:02.742906: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:236] Loading SavedModel from: /home/yu/workspace/test/test_model/1
2018-02-08 16:28:02.755299: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2018-02-08 16:28:02.795329: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:155] Restoring SavedModel bundle.
2018-02-08 16:28:02.820146: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running LegacyInitOp on SavedModel bundle.
2018-02-08 16:28:02.832832: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:284] Loading SavedModel: success. Took 89481 microseconds.
2018-02-08 16:28:02.834804: I tensorflow_serving/core/loader_harness.cc:86] Successfully loaded servable version {name: voice version: 1}
2018-02-08 16:28:02.836855: I tensorflow_serving/model_servers/main.cc:290] Running ModelServer at 0.0.0.0:8500 ...
4菠净、客戶端代碼
from __future__ import print_function
from grpc.beta import implementations
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
import numpy as np
tf.app.flags.DEFINE_string('server', 'localhost:8500',
'PredictionService host:port')
tf.app.flags.DEFINE_string('vocie', '', 'path to voice in wav format')
FLAGS = tf.app.flags.FLAGS
def get_melgram(path):
melgram = .... # 這里省略
return melgram
def main(_):
host, port = FLAGS.server.split(':')
channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
# Send request
# See prediction_service.proto for gRPC request/response details.
data = get_melgram("T_1000001.wav")
data = data.astype(np.float32)
request = predict_pb2.PredictRequest()
request.model_spec.name = 'voice' # 這個name跟tensorflow_model_server --model_name="voice" 對應(yīng)
request.model_spec.signature_name = 'voice_classification' # 這個signature_name 跟signature_def_map 對應(yīng)
request.inputs['voice'].CopyFrom(
tf.contrib.util.make_tensor_proto(data, shape=[1, 1, 96, 89])) # shape跟 keras的model.input類型對應(yīng)
result = stub.Predict(request, 10.0) # 10 secs timeout
print(result)
if __name__ == '__main__':
tf.app.run()
客戶端跑出的結(jié)果是:
outputs {
key: "scores"
value {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 1
}
dim {
size: 2
}
}
float_val: 0.0341101661325
float_val: 0.965889811516
}
}
float_val: 0.0341101661325
和float_val: 0.965889811516
就是我們需要的結(jié)果禁舷。
keras模型轉(zhuǎn) tensorflow模型的一些說明
1彪杉、 keras 保存模型
可以使用model.save(filepath)
將Keras模型和權(quán)重保存在一個HDF5文件中,該文件將包含:
- 模型的結(jié)構(gòu)牵咙,以便重構(gòu)該模型
- 模型的權(quán)重
- 訓(xùn)練配置(損失函數(shù)派近,優(yōu)化器等)
- 優(yōu)化器的狀態(tài),以便于從上次訓(xùn)練中斷的地方開始
當(dāng)然這個 HDF5 也可以是用下面的代碼生成
from keras.callbacks import ModelCheckpoint
checkpoint_filepath = 'weights.hdf5'
checkpointer = ModelCheckpoint(filepath=checkpoint_filepath, verbose=1, save_best_only=True)
2洁桌、 keras 加載模型
keras 加載模型像下面這樣子(中間部分代碼省略了):
from keras.models import Sequential, Model
model = Sequential()
.....
model.compile(loss='categorical_crossentropy',
optimizer='xxx', # 用實際算法情況替換這里的xxx
metrics=['xxx'])
model.summary()
model.load_weights("xxx.h5") # 加載keras模型(一個HDF5文件)
keras 模型轉(zhuǎn)tensorflow serving 模型的一些坑
希望能讓新手少走一些彎路
坑1:過時的生成方法
有些方法已經(jīng)過時了(例如下面這種):
from tensorflow_serving.session_bundle import exporter
export_path = ... # where to save the exported graph
export_version = ... # version number (integer)
saver = tf.train.Saver(sharded=True)
model_exporter = exporter.Exporter(saver)
signature = exporter.classification_signature(input_tensor=model.input,
scores_tensor=model.output)
model_exporter.init(sess.graph.as_graph_def(),
default_graph_signature=signature)
model_exporter.export(export_path, tf.constant(export_version), sess)
如果使用這種過時的方法渴丸,用tensorflow serving 跑模型的時候會提示:
WARNING:tensorflow:From test.py:107: Exporter.export (from tensorflow.contrib.session_bundle.exporter) is deprecated and will be removed after 2017-06-30.
Instructions for updating:
No longer supported. Switch to SavedModel immediately.
從warning中 顯然可以知道這種方法要被拋棄了,不再支持這種方法了另凌, 建議我們轉(zhuǎn)用 SaveModel方法谱轨。
填坑大法: 使用 SaveModel
def save_model_to_serving(model, export_version, export_path='prod_models'):
print(model.input, model.output)
signature = tf.saved_model.signature_def_utils.predict_signature_def(
inputs={'voice': model.input}, outputs={'scores': model.output})
export_path = os.path.join(
tf.compat.as_bytes(export_path),
tf.compat.as_bytes(str(export_version)))
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
sess=K.get_session(),
tags=[tf.saved_model.tag_constants.SERVING],
signature_def_map={
'classification': signature,
},
legacy_init_op=legacy_init_op)
builder.save()