作者:Mao Chan
著作權(quán)歸作者所有。商業(yè)轉(zhuǎn)載請(qǐng)聯(lián)系作者獲得授權(quán)拗胜,非商業(yè)轉(zhuǎn)載請(qǐng)注明出處歧匈。
2016年,機(jī)器學(xué)習(xí)在 Alpha Go 與李世石的世紀(jì)之戰(zhàn)后變得更加炙手可熱王带。Google也在今年推出了 TensorFlow Serving 又加了一把火淑蔚。
TensorFlow Serving 是一個(gè)用于機(jī)器學(xué)習(xí)模型 serving 的高性能開(kāi)源庫(kù)。它可以將訓(xùn)練好的機(jī)器學(xué)習(xí)模型部署到線上愕撰,使用 gRPC 作為接口接受外部調(diào)用刹衫。更加讓人眼前一亮的是,它支持模型熱更新與自動(dòng)模型版本管理搞挣。這意味著一旦部署 TensorFlow Serving 后带迟,你再也不需要為線上服務(wù)操心,只需要關(guān)心你的線下模型訓(xùn)練囱桨。
今天我就帶大家來(lái)用 TensorFlow Serving 部署一個(gè)簡(jiǎn)單的 Linear Regression 模型仓犬。
以下演示運(yùn)行在 Ubuntu 16.04 LTS 之上。
TensorFlow Serving 處于快速迭代期舍肠。如果本文內(nèi)容與官方文檔矛盾搀继,請(qǐng)以官方文檔為參考窘面。
環(huán)境
TensorFlow Serving 目前依賴 Google 的開(kāi)源編譯工具?Bazel。Bazel 是 Google 內(nèi)部編譯工具 Blaze 的開(kāi)源版本律歼,功能與性能基本一致民镜。具體的安裝可以參考官方文檔。此外還需要安裝?gRPC?(Google 又一個(gè)內(nèi)部工具的開(kāi)源版)险毁。
之后請(qǐng)參考官方安裝指南完成制圈。值得注意的是,最后的 bazel build 將會(huì)需要大約30分鐘時(shí)間并占用約5-10G的空間(時(shí)間取決于機(jī)器性能)畔况。配合使用 -c opt 能一定程度加快 build鲸鹦。
模型訓(xùn)練
接下來(lái)我們用 TensorFlow 寫(xiě)一個(gè)簡(jiǎn)單的測(cè)試用 Linear Regression 模型。數(shù)據(jù)的話我就使用正弦函數(shù)生成 1000 個(gè)點(diǎn)跷跪,嘗試用一條直線去擬合馋嗜。
樣本數(shù)據(jù)生成如下:
# Generate input data
x_data = np.arange(100, step=.1)
y_data = x_data + 20 * np.sin(x_data / 10)
# Reshape data
x_data = np.reshape(x_data, (n_samples, 1))
y_data = np.reshape(y_data, (n_samples, 1))
然后用一個(gè)簡(jiǎn)單的 y = wx + b 來(lái)做一個(gè)訓(xùn)練,使用 Adam 算法吵瞻。簡(jiǎn)單調(diào)整了下參數(shù):
sample = 1000, learning_rate = 0.01, batch_size = 100, n_steps = 500
# Placeholders for batched input
x = tf.placeholder(tf.float32, shape=(batch_size, 1))
y = tf.placeholder(tf.float32, shape=(batch_size, 1))
# Do training
with tf.variable_scope('test'):
? ? w = tf.get_variable('weights', (1, 1), initializer=tf.random_normal_initializer())
? ? b = tf.get_variable('bias', (1,), initializer=tf.constant_initializer(0))
? ? y_pred = tf.matmul(x, w) + b
? ? loss = tf.reduce_sum((y - y_pred) ** 2 / n_samples)
? ? opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
? ? with tf.Session() as sess:
? ? ? ? sess.run(tf.initialize_all_variables())
? ? ? ? for _ in range(n_steps):
? ? ? ? ? ? indices = np.random.choice(n_samples, batch_size)
? ? ? ? ? ? x_batch = x_data[indices]
? ? ? ? ? ? y_batch = y_data[indices]
? ? ? ? ? ? _, loss_val = sess.run([opt, loss], feed_dict={x:x_batch, y:y_batch})
? ? ? ? print w.eval()
? ? ? ? print b.eval()
? ? ? ? print loss_val
大致把 loss 收斂在 15.8 左右葛菇。精度應(yīng)該足夠了,畢竟只是一個(gè)簡(jiǎn)單的測(cè)試用模型橡羞。
模型導(dǎo)出
接下來(lái)的就是本文的重點(diǎn):導(dǎo)出模型眯停。
tf.train.Saver
用于保存和恢復(fù)Variable。它可以非常方便的保存當(dāng)前模型的變量或者倒入之前訓(xùn)練好的變量卿泽。一個(gè)最簡(jiǎn)單的運(yùn)用:
saver - tf.train.Saver()
# Save the variables to disk.
saver.save(sess, "/tmp/test.ckpt")
# Restore variables from disk.
saver.restore(sess, "/tmp/test.ckpt")
tf.contrib.session_bundle.exporter.Exporter
導(dǎo)出模型還需要這個(gè) Exporter 的協(xié)助莺债。令人尷尬的是這個(gè) Exporter 太新了,還沒(méi)有 API 文檔支持签夭,只能參考 Github 的代碼實(shí)現(xiàn)齐邦。
Exporter 的基本使用方式是
傳入 saver 構(gòu)造一個(gè)實(shí)例
調(diào)用?init?定義模型的 graph 和 input/output
使用 export 導(dǎo)出為文件
model_exporter = exporter.Exporter(saver)
model_exporter.init(
? ? sess.graph.as_graph_def(),
? ? named_graph_signatures={
? ? ? ? 'inputs': exporter.generic_signature({'x': x}),
? ? ? ? 'outputs': exporter.generic_signature({'y': y_pred})})
model_exporter.export(FLAGS.work_dir,? ? ? ?
? ? ? ? ? ? ? ? ? ? ? tf.constant(FLAGS.export_version),
? ? ? ? ? ? ? ? ? ? ? sess)
大功告成!編譯第租!我們成功導(dǎo)出了一個(gè)可以部署在 TensorFlow Serving 上的模型措拇。它接受一個(gè) x 值然后返回一個(gè) y 值。導(dǎo)出的文件夾以 version 命名慎宾,包含用于部署的 meta 文件, 模型 checkpoint 文件和序列化的模型 graph:
/tmp/test/00000001checkpoint export-00000-of-00001 export.meta
模型部署
部署的方式非常簡(jiǎn)單儡羔,只需要以下兩步:
$ bazel build //tensorflow_serving/model_servers:tensorflow_model_server$ bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=test --model_base_path=/tmp/test/
我們看到 TensorFlow Serving 成功加載了我們剛剛導(dǎo)出的 model。并且還在不斷嘗試 poll 新的 model:
$ bazel build //tensorflow_serving/model_servers:tensorflow_model_server
$bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=test --model_base_path=/tmp/test/
我們看到 TensorFlow Serving 成功加載了我們剛剛導(dǎo)出的 model璧诵。并且還在不斷嘗試 poll 新的 model:
客戶端
接下來(lái)我們寫(xiě)一個(gè)簡(jiǎn)單的 Client 來(lái)調(diào)用下我們部署好的 Model。這里我們需要用到 TensorFlow Serving 的 Predict API 和 gRPC 的 implementations.insecure_channel 來(lái)construct 一個(gè) request仇冯。特別要注意的是 input 的 signature 和數(shù)據(jù)必須和之前 export 的模型匹配之宿。本例中為 名稱為 x, float32類型苛坚,大小為 [100, 1] 的 Tensor比被。
from grpc.beta import implementations
import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
tf.app.flags.DEFINE_string('server', 'localhost:9000',
? ? ? ? ? ? ? ? ? ? ? ? ? 'PredictionService host:port')
FLAGS = tf.app.flags.FLAGS
n_samples = 100
host, port = FLAGS.server.split(':')
channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
# Generate test data
x_data = np.arange(n_samples, step=1, dtype=np.float32)
x_data = np.reshape(x_data, (n_samples, 1))
# Send request
request = predict_pb2.PredictRequest()
request.model_spec.name = 'test'
? request.inputs['x'].CopyFrom(tf.contrib.util.make_tensor_proto(x_data, shape=[100, 1]))
result = stub.Predict(request, 10.0)? # 10 secs timeout
別忘了配置一下 bazel 的 BUILD 文件:
py_binary(
? ? name = "test_client",
? ? srcs = [
? ? ? ? "test_client.py",
? ? ],
? ? deps = [
? ? ? ? "http://tensorflow_serving/apis:predict_proto_py_pb2",
? ? ? ? "http://tensorflow_serving/apis:prediction_service_proto_py_pb2",
? ? ? ? "@org_tensorflow//tensorflow:tensorflow_py",
? ? ],
)
最后編譯運(yùn)行色难,就能看到在線預(yù)測(cè)結(jié)果啦!
bazel build //tensorflow_serving/test:test_client && ./bazel-bin/tensorflow_serving/test/test_client
延伸
TensorFlow 封裝了眾多常用模型成為?Estimator等缀,幫助用戶避免了冗長(zhǎng)易錯(cuò)的算法實(shí)現(xiàn)部分枷莉。比如以上的例子就可以完全用 LinearRegressor 來(lái)替換。只需要幾行代碼簡(jiǎn)單地調(diào)用 fit() 函數(shù)就能輕松得到收斂的模型尺迂。唯一不足的是目前與 TensorFlow Serving 還不能 100% 兼容笤妙。雖然 Google 還在全力完善 TensorFlow Serving,但是距離完善還需要一定的時(shí)間噪裕。
如果既想要使用方便快捷的的 Estimator 蹲盘,又想線上部署呢?當(dāng)然也是有辦法的膳音,筆者鉆研了一下后召衔,實(shí)現(xiàn)了一個(gè)用 Estimator 訓(xùn)練數(shù)據(jù),導(dǎo)出模型后再部署上線的方法祭陷。最后用這個(gè)線上部署的模型實(shí)現(xiàn)一個(gè)在線評(píng)估房屋價(jià)值的系統(tǒng)苍凛。