References:
- Tensorflow 模型線上部署 (deploying TensorFlow models online)
- 構建 TensorFlow Serving Java 客戶端 (building a TensorFlow Serving Java client)
Docker installation and deployment

Installing Docker on Windows

tf-serving
Download TensorFlow Serving and deploy it with Docker. If the image takes up too much space on the C: drive, you can use the Hyper-V tools to move the downloaded image to another drive.
```shell
# Run the following in cmd
docker pull tensorflow/serving                # pull the image
docker run -itd -p 5000:5000 --name tfserving tensorflow/serving  # run a container with a fixed name
docker ps                                     # look up the container id (dockerID)
docker cp ./mnist dockerID:/models            # copy the pb model folder into the container (training is covered below)
docker exec -it dockerID /bin/bash            # open a shell inside the container
tensorflow_model_server --port=5000 --model_name=mnist --model_base_path=/models/mnist  # start the server inside the container
```
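Before touching any gRPC code, it can be worth confirming the server is actually listening on the forwarded port. A minimal sketch using only the standard library (the host/port match the `docker run` mapping above; `port_open` is a hypothetical helper, not part of any library here):

```python
import socket

def port_open(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False

if __name__ == "__main__":
    # True once tensorflow_model_server is running inside the container
    print(port_open("localhost", 5000))
```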
訓(xùn)練模型
Train with the official mnist example; you only need to adjust the paths in the code. Training produces the pb model directory. Use

```shell
saved_model_cli show --dir ./mnist/1 --all
```

to inspect the signature and tensor names (the client will need them later), then copy the model into the container with `docker cp ./mnist dockerID:/models`. Pay attention to the folder hierarchy here.
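On the folder hierarchy: TF Serving expects a numeric version subdirectory containing `saved_model.pb` under the model base path (e.g. `mnist/1/saved_model.pb`). A quick stdlib-only sanity check of that layout might look like this (`looks_like_servable` is a hypothetical helper):

```python
from pathlib import Path

def looks_like_servable(model_dir: str) -> bool:
    """Check for the mnist/<version>/saved_model.pb layout TF Serving expects."""
    base = Path(model_dir)
    if not base.is_dir():
        return False
    for version in base.iterdir():
        if version.name.isdigit() and (version / "saved_model.pb").is_file():
            return True
    return False

if __name__ == "__main__":
    print(looks_like_servable("./mnist"))
```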
Python client
Write the prediction code modeled on the official mnist_client.py:
```python
import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import mnist_input_data  # from the official tensorflow_serving mnist example

server = 'localhost:5000'
channel = grpc.insecure_channel(server)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.name = 'mnist'
request.model_spec.signature_name = 'predict_images'

test_data_set = mnist_input_data.read_data_sets('./data').test
image, label = test_data_set.next_batch(1)
request.inputs['images'].CopyFrom(
    tf.make_tensor_proto(image[0], shape=[1, image[0].size]))

pred = stub.Predict(request, 5.0)  # 5 second timeout
score = pred.outputs['scores'].float_val
print(score)
# [1.6178478001727115e-10, 1.6928293322847278e-15, 1.6151154341059737e-05,
#  0.000658366538118571, 8.010060947860609e-10, 2.2359495588375466e-08,
#  3.5608297452131843e-13, 0.9993133544921875, 5.620326870570125e-09,
#  1.1990837265329901e-05]
```
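The `scores` output is a list of 10 class probabilities (in the run above, class 7 dominates). Turning it into a predicted digit is a plain argmax; a stdlib-only sketch (`predicted_digit` is a hypothetical helper, and the scores below are the values from the run above, rounded):

```python
def predicted_digit(scores):
    """Index of the highest probability = the predicted class."""
    return max(range(len(scores)), key=lambda i: scores[i])

scores = [1.62e-10, 1.69e-15, 1.62e-05, 6.58e-04, 8.01e-10,
          2.24e-08, 3.56e-13, 0.99931, 5.62e-09, 1.20e-05]
print(predicted_digit(scores))  # → 7
```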
Java client
The Java flow is similar; the main hassle is compiling the proto files.
Installing protoc
For installing protoc on Windows, see "windows之google protobuf安裝與使用": download protoc-3.4.0 and unzip it. Make sure the directory path contains no spaces, otherwise compilation will fail later. Locate `protoc.exe`; mine is at D:\protoc-3.4.0-win32\bin.
Configuring the pom to compile the protos
This part mainly follows "構建 TensorFlow Serving Java 客戶端"; the proto file list it gives is excellent (I never figured out why it is exactly these files; I am not familiar with java-grpc). Following its flow, download the tensorflow and tensorflow-serving projects and copy the corresponding proto files out:

```
src/main/proto
├── tensorflow
│   └── core
│       ├── example
│       │   ├── example.proto
│       │   └── feature.proto
│       ├── framework
│       │   ├── attr_value.proto
│       │   ├── function.proto
│       │   ├── graph.proto
│       │   ├── node_def.proto
│       │   ├── op_def.proto
│       │   ├── resource_handle.proto
│       │   ├── tensor.proto
│       │   ├── tensor_shape.proto
│       │   ├── types.proto
│       │   └── versions.proto
│       └── protobuf
│           ├── meta_graph.proto
│           └── saver.proto
└── tensorflow_serving
    └── apis
        ├── classification.proto
        ├── get_model_metadata.proto
        ├── inference.proto
        ├── input.proto
        ├── model.proto
        ├── predict.proto
        ├── prediction_service.proto
        └── regression.proto
```
??創(chuàng)建Maven工程,將上面的proto文件放在src/main下面七咧,在pom中添加以下信息跃惫,此處額外添加了編譯文件的輸入及輸出目錄,否則會(huì)報(bào)錯(cuò) protoc did not exit cleanly
```xml
<build>
    <plugins>
        <plugin>
            <groupId>org.xolstice.maven.plugins</groupId>
            <artifactId>protobuf-maven-plugin</artifactId>
            <version>0.5.0</version>
            <configuration>
                <protocExecutable>D:\protoc-3.4.0-win32\bin\protoc.exe</protocExecutable>
                <protoSourceRoot>${project.basedir}/src/main/proto/</protoSourceRoot>
                <outputDirectory>${project.basedir}/src/main/resources/</outputDirectory>
            </configuration>
            <executions>
                <execution>
                    <goals>
                        <goal>compile</goal>
                        <goal>compile-custom</goal>
                    </goals>
                </execution>
            </executions>
        </plugin>
    </plugins>
</build>
<dependencies>
    <dependency>
        <groupId>com.google.protobuf</groupId>
        <artifactId>protobuf-java</artifactId>
        <version>3.11.4</version>
    </dependency>
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-protobuf</artifactId>
        <version>1.28.0</version>
    </dependency>
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-stub</artifactId>
        <version>1.28.0</version>
    </dependency>
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-netty-shaded</artifactId>
        <version>1.28.0</version>
    </dependency>
</dependencies>
```
After configuring, run `maven -> protobuf:compile`. The build generates org and tensorflow folders under the resources directory; copy both into src/main/java.
預(yù)測(cè)
Write a Java program to do the prediction. Along the way I found that the file `tensorflow/serving/PredictionServiceGrpc.java` had not been generated; I tried many ways to get it compiled without luck, and in the end simply copied someone else's PredictionServiceGrpc. After copying it over, the `@java.lang.Override` lines were flagged as errors, so I just commented those annotations out. Create the package and class under `src/main/java` and write the prediction code. The full code is below; running it yields the prediction:

```java
package SimpleAdd;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import tensorflow.serving.Model;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

public class MnistPredict {
    public static void main(String[] args) throws Exception {
        // Create a channel for gRPC
        ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 5000)
                .usePlaintext().build();
        PredictionServiceGrpc.PredictionServiceBlockingStub stub =
                PredictionServiceGrpc.newBlockingStub(channel);

        // Create a model spec naming the model and its signature
        Model.ModelSpec.Builder modelSpec = Model.ModelSpec.newBuilder();
        modelSpec.setName("mnist");
        modelSpec.setSignatureName("predict_images");

        Predict.PredictRequest.Builder request = Predict.PredictRequest.newBuilder();
        request.setModelSpec(modelSpec);

        // Declare the input shape [1, 784] and fill the tensor with dummy data
        TensorShapeProto.Builder shape = TensorShapeProto.newBuilder();
        shape.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
        shape.addDim(TensorShapeProto.Dim.newBuilder().setSize(784));

        TensorProto.Builder tensor = TensorProto.newBuilder();
        tensor.setTensorShape(shape);
        tensor.setDtype(DataType.DT_FLOAT);
        for (int i = 0; i < 784; i++) {
            tensor.addFloatVal(0);  // an all-zero "image", just to exercise the service
        }
        request.putInputs("images", tensor.build());
        tensor.clear();

        // Predict
        Predict.PredictResponse response = stub.predict(request.build());
        System.out.println(response);

        TensorProto result = response.toBuilder().getOutputsOrThrow("scores");
        System.out.println("predict: " + result.getFloatValList());
        System.out.println("predict: " + response.getOutputsMap().get("scores").getFloatValList());
        // predict: [0.032191742, 0.09621494, 0.06525445, 0.039610844, 0.05699038,
        //           0.46822935, 0.040578533, 0.1338098, 0.009549928, 0.057570033]
    }
}
```