Reference:
Calling a trained TensorFlow 2.0 model from spark-scala
1. Train and save the model with TF2:
import tensorflow as tf
from tensorflow.keras import models,layers,optimizers
## Number of samples
n = 800
## Generate a test dataset
X = tf.random.uniform([n,2],minval=-10,maxval=10)
w0 = tf.constant([[2.0],[-1.0]])
b0 = tf.constant(3.0)
Y = X@w0 + b0 + tf.random.normal([n,1],mean = 0.0,stddev= 2.0) # @ is matrix multiplication; the last term adds Gaussian noise
## Build the model
tf.keras.backend.clear_session()
inputs = layers.Input(shape = (2,),name ="inputs") # name the input "inputs"
outputs = layers.Dense(1, name = "outputs")(inputs) # name the output "outputs"
linear = models.Model(inputs = inputs,outputs = outputs)
linear.summary()
## Train with the fit method
linear.compile(optimizer="rmsprop",loss="mse",metrics=["mae"])
linear.fit(X,Y,batch_size = 8,epochs = 100)
tf.print("w = ",linear.layers[1].kernel)
tf.print("b = ",linear.layers[1].bias)
## Save the model as a pb-format file (TensorFlow SavedModel)
export_path = "/your_path/tf2_linear"
linear.save(export_path, save_format="tf")
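As a rough cross-check on the weights printed by `tf.print` above, the same linear model can be fitted without TensorFlow. The following is a minimal sketch in plain Python, not the method used in this post: it swaps RMSprop for full-batch gradient descent, and the helper names (`make_data`, `fit_linear`), the seed, the learning rate, and the step count are all assumptions chosen for illustration. It should recover values near w0 = [2, -1] and b0 = 3.

```python
import random

def make_data(n=800, seed=0):
    """Synthetic data mirroring the post: y = 2*x1 - x2 + 3 + N(0, 2)."""
    rng = random.Random(seed)  # the seed is an arbitrary choice for reproducibility
    xs, ys = [], []
    for _ in range(n):
        x1, x2 = rng.uniform(-10, 10), rng.uniform(-10, 10)
        xs.append((x1, x2))
        ys.append(2.0 * x1 - 1.0 * x2 + 3.0 + rng.gauss(0.0, 2.0))
    return xs, ys

def fit_linear(xs, ys, lr=0.01, steps=1000):
    """Full-batch gradient descent on mean squared error (not RMSprop)."""
    w1 = w2 = b = 0.0
    n = len(xs)
    for _ in range(steps):
        g1 = g2 = gb = 0.0
        for (x1, x2), y in zip(xs, ys):
            err = w1 * x1 + w2 * x2 + b - y
            g1 += err * x1
            g2 += err * x2
            gb += err
        # The gradient of the MSE is (2/n) * sum(err * feature)
        w1 -= lr * 2.0 * g1 / n
        w2 -= lr * 2.0 * g2 / n
        b -= lr * 2.0 * gb / n
    return w1, w2, b

xs, ys = make_data()
w1, w2, b = fit_linear(xs, ys)
print("w =", [round(w1, 3), round(w2, 3)], "b =", round(b, 3))
```

If the Keras model prints weights far from these, the training setup is the first thing to check before moving on to the Java side.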
Saved model directory:
~/demo/your_path tree
.
└── tf2_linear
    ├── assets
    ├── saved_model.pb
    └── variables
        ├── variables.data-00000-of-00001
        └── variables.index

3 directories, 3 files
2. Load the model in Java and predict
Inspect the model details (needed for loading the model and predicting in Java)
~/demo/your_path saved_model_cli show --dir ./tf2_linear --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: serving_default_inputs:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['outputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict
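For the Java code, only two values from this output matter: the `name` fields of the `serving_default` input and output, which are the graph tensor names passed to `feed` and `fetch`. As an illustration, the sketch below pulls them out of the CLI text with a regular expression. The helper name `tensor_names` and the abbreviated `CLI_OUTPUT` string are assumptions made for this example; this is a plain-text parse, not a TensorFlow API.

```python
import re

# Abbreviated copy of the serving_default section printed by saved_model_cli.
CLI_OUTPUT = """
signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: serving_default_inputs:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['outputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict
"""

def tensor_names(cli_text, signature="serving_default"):
    """Return {'inputs': {...}, 'outputs': {...}} mapping keys to tensor names."""
    # Keep only the text that follows the requested signature header.
    header = "signature_def['%s']:" % signature
    section = cli_text.split(header, 1)[1]
    # Stop at the next signature header, if there is one.
    section = section.split("signature_def[", 1)[0]
    pattern = re.compile(
        r"(inputs|outputs)\['([^']+)'\] tensor_info:.*?name: (\S+)", re.DOTALL
    )
    names = {"inputs": {}, "outputs": {}}
    for kind, key, name in pattern.findall(section):
        names[kind][key] = name
    return names

names = tensor_names(CLI_OUTPUT)
print(names)
```

Note that the tensor names (`serving_default_inputs:0`, `StatefulPartitionedCall:0`) differ from the signature keys (`inputs`, `outputs`); the Java `Session.Runner` works with the former.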
Maven dependencies
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.15.0</version>
</dependency>
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>fastjson</artifactId>
    <version>1.2.73</version>
</dependency>
Java code
package com.ml.demo.tf;

import com.alibaba.fastjson.JSON;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class PredictNN {
    public static void main(String[] args) {
        // Load the SavedModel with the "serve" tag; closing the bundle also closes its session.
        try (SavedModelBundle bundle = SavedModelBundle.load("/your_path/tf2_linear", "serve")) {
            Session session = bundle.session();
            float[][] input = {
                {2.6327686f, -9.201903f},
                {-1.3209248f, 8.569574f},
                {-5.6642127f, 3.3681698f},
                {9.604832f, 5.9664965f},
                {-0.8812313f, -6.76733f}
            };
            System.out.println("input: \n" + JSON.toJSONString(input));
            // Feed and fetch by the tensor names reported by saved_model_cli.
            try (Tensor<?> inputTensor = Tensor.create(input);
                 Tensor<?> resultTensor = session.runner()
                         .feed("serving_default_inputs:0", inputTensor)
                         .fetch("StatefulPartitionedCall:0")
                         .run().get(0)) {
                // One prediction per input row, matching the output shape (-1, 1).
                float[][] result = new float[input.length][1];
                resultTensor.copyTo(result);
                System.out.println("result: \n" + JSON.toJSONString(result));
            }
        }
    }
}
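As a rough sanity check on the Java result: the training data was generated from w0 = [2, -1] and b0 = 3, so a well-fitted model should predict values near 2*x1 - x2 + 3 for each input row. The sketch below computes those noise-free reference values for the same five inputs. The helper name `reference_prediction` is an assumption for illustration, and the model's actual outputs will differ slightly because the learned weights only approximate w0 and b0.

```python
# Ground-truth parameters used to generate the training data in step 1.
W0 = (2.0, -1.0)
B0 = 3.0

# The same five rows fed to the model in the Java code.
INPUTS = [
    (2.6327686, -9.201903),
    (-1.3209248, 8.569574),
    (-5.6642127, 3.3681698),
    (9.604832, 5.9664965),
    (-0.8812313, -6.76733),
]

def reference_prediction(x1, x2):
    """Noise-free target y = 2*x1 - x2 + 3 (not the trained model's output)."""
    return W0[0] * x1 + W0[1] * x2 + B0

reference = [reference_prediction(x1, x2) for x1, x2 in INPUTS]
for row, y in zip(INPUTS, reference):
    print(row, "->", round(y, 4))
```

A Java result that is far from these values usually points to a wrong tensor name, a wrong input shape, or a model that did not converge.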
Output log