下面要实现的功能是：g1 和 g2 串联（g1 的输出作为 g2 的输入），placeholder 输入 x 是 3.0，g1 实现 y=3*x，g2 实现 y+3，最后输出 3*3+3=12。
文件model_b.py如下:
import os

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.tools import saved_model_utils
MODEL_SAVE_PATH = "./models/"  # directory where the model files are saved
# Set to True to additionally export g2 as a frozen, constant-only model_g2.pb.
WRITE_FROZEN_PB = False

# Build g2 (g2_output = g2_input + 3, with the 3 held in a tf.Variable) and
# save it as a checkpoint; model_combined.py later restores it from
# ./models/g2_model.ckpt(.meta).
with tf.Graph().as_default() as g2:
    input1 = tf.placeholder(tf.float32, name='g2_input')
    data = tf.Variable(3.)
    total = tf.add(input1, data)
    tf.identity(total, name='g2_output')
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session(graph=g2) as sess:
        sess.run(init)
        # Saver.save refuses to write into a missing parent directory, so make
        # sure it exists first (no-op if it already does).
        os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
        if WRITE_FROZEN_PB:
            # Fold the variable into a Const node so the .pb can be loaded
            # without a checkpoint.
            g2def = graph_util.convert_variables_to_constants(
                sess, sess.graph_def, ["g2_output"])
            tf.train.write_graph(g2def, MODEL_SAVE_PATH, 'model_g2.pb',
                                 as_text=False)
        saver.save(sess, os.path.join(MODEL_SAVE_PATH, "g2_model.ckpt"))
文件 model_combined.py 如下（运行前需先执行 model_b.py，生成 ./models/g2_model.ckpt 及其 .meta 文件）：
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.tools import saved_model_utils
MODEL_SAVE_PATH = "./models/"  # directory where the model files are saved

# g1 and g2 are chained in series: g1 computes y = 3 * x and g2 computes
# y + 3, so feeding x = 3.0 through both should print 12.
# Step 1: build g1 in memory and freeze it into a constant-only GraphDef.
with tf.Graph().as_default() as g1:
    g1_in = tf.placeholder(tf.float32, name='g1_input')
    factor = tf.Variable(3.)
    tf.identity(tf.multiply(g1_in, factor), name='g1_output')
    init_op = tf.global_variables_initializer()
    with tf.Session(graph=g1) as sess:
        sess.run(init_op)
        # Bake the variable's current value (3.0) into a Const node so g1def
        # can be imported into another graph without any checkpoint.
        g1def = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["g1_output"])
# Step 2: rebuild g2 from the checkpoint written by model_b.py and freeze it.
with tf.Graph().as_default() as g2, tf.Session(graph=g2) as sess:
    # import_meta_graph recreates the saved ops inside the current default
    # graph (g2) and hands back a Saver wired to them.
    saver = tf.train.import_meta_graph('./models/g2_model.ckpt.meta')
    saver.restore(sess, './models/g2_model.ckpt')
    # Swap the restored variable for a Const carrying its checkpointed value.
    g2def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ["g2_output"])
##------------------------------------------------------------
# Optional export of the combined graph:
#   None          -> only run it and print the result
#   "pb"          -> frozen GraphDef at MODEL_SAVE_PATH/my_model.pb
#   "saved_model" -> SavedModel written to ./modelbase
EXPORT_MODE = None

# Step 3: chain the two frozen graphs:
#   my_input -> g1 (y = 3 * x) -> g2 (y + 3) -> my_output
with tf.Graph().as_default() as g_combined:
    with tf.Session(graph=g_combined) as sess:
        x = tf.placeholder(tf.float32, name="my_input")
        # import_graph_def returns a list with one entry per return_elements
        # item. Unpack it so g2_input is wired to g1_output itself; passing
        # the list would make TF auto-pack it into a new rank-1 tensor.
        y, = tf.import_graph_def(g1def, input_map={"g1_input:0": x},
                                 return_elements=["g1_output:0"])
        z, = tf.import_graph_def(g2def, input_map={"g2_input:0": y},
                                 return_elements=["g2_output:0"])
        tf.identity(z, "my_output")
        print(sess.run(z, feed_dict={'my_input:0': 3.}))  # 3*3 + 3 = 12.0

        if EXPORT_MODE == "pb":
            g_combineddef = graph_util.convert_variables_to_constants(
                sess, sess.graph_def, ["my_output"])
            tf.train.write_graph(g_combineddef, MODEL_SAVE_PATH, 'my_model.pb',
                                 as_text=False)
        elif EXPORT_MODE == "saved_model":
            tf.saved_model.simple_save(sess,
                                       "./modelbase",
                                       inputs={"my_input": x},
                                       outputs={"my_output": z})