關鍵代碼
第一步
tf.Variable
第二步
saver = tf.train.Saver()
第三步
saver.save或者saver.restore
保存變量
import sys
# Show which Python interpreter this walkthrough runs on. When it was written
# the output was:
#   3.5.3 |Continuum Analytics, Inc.| (default, May 15 2017, 10:43:23) [MSC v.1900 64 bit (AMD64)]
sys.stdout.write(sys.version + "\n")
import tensorflow as tf
import numpy as np
# Save to file.
# Remember to define the same dtype and shape when restoring.
import os

W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')
# tf.initialize_all_variables() is being deprecated; use its replacement:
init = tf.global_variables_initializer()
saver = tf.train.Saver()
# tf.train.Saver.save() does not create missing parent directories, so make
# sure the checkpoint directory exists before saving.
os.makedirs("save", exist_ok=True)
with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "save/save_net.ckpt")
    print("Save to path: ", save_path)
# Expected output:
# Save to path:  save/save_net.ckpt
提取變量
# Start from an empty default graph. If the save step ran earlier in the same
# process, its "weights"/"biases" variables would still be in the graph, the
# ones below would be renamed "weights_1"/"biases_1", and Saver() would try to
# restore keys that are not in the checkpoint.
tf.reset_default_graph()
# First create containers for W and b, with the same names, dtype and shape
# as the saved variables. The np.arange values are replaced by restore().
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
# No initializer step (init = tf.global_variables_initializer()) is needed
# here: restore() assigns the saved values.
saver = tf.train.Saver()
with tf.Session() as sess:
    # Load the saved values into W and b.
    saver.restore(sess, "save/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))
# Expected output (exact spacing depends on the NumPy version):
# weights: [[ 1.  2.  3.]
#  [ 3.  4.  5.]]
# biases: [[ 1.  2.  3.]]
在變量很多的情況下,每個變量都加name很麻煩,可以用下面這種形式
保存
# Save side: W and b are created inside the "regression" variable scope
# without passing name=.
# NOTE(review): with no explicit names, the checkpoint keys come from
# TensorFlow's automatic naming inside this scope. The restore side therefore
# presumably has to create the same variables in the same order (W, then b).
# Confirm before reordering anything here.
with tf.variable_scope("regression"):
    W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32)
    b = tf.Variable([[1,2,3]], dtype=tf.float32)
提取
# Restore side: re-create W and b in the same "regression" scope, with the
# same dtype and shape as the saved variables. The np.arange values are only
# placeholders, presumably overwritten by saver.restore().
# NOTE(review): the variables are unnamed, so matching against the checkpoint
# appears to rely on the creation order (W, then b) matching the save side.
# Confirm before reordering.
with tf.variable_scope("regression"):
    W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32)
    b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32)
疑問:
def regression(x):
    """Build a softmax-regression layer on top of ``x``.

    Creates two float32 variables initialised to zeros, with no explicit
    names and in this order: a [784, 10] weight matrix and a [10] bias.
    ``x`` is presumably a batch of 784-feature rows (it is matmul'ed with the
    [784, 10] weights).

    Returns:
        A tuple ``(y, [W, b])``: ``y`` is the softmax of ``x @ W + b``, and
        the list holds the two variables (e.g. for ``tf.train.Saver``).
    """
    weights = tf.Variable(tf.zeros([784, 10]), dtype=tf.float32)
    bias = tf.Variable(tf.zeros([10]), dtype=tf.float32)
    logits = tf.matmul(x, weights) + bias
    return tf.nn.softmax(logits), [weights, bias]
恢復變量
# Rebuild the regression model inside the "regression" scope and collect its
# variables.
# NOTE(review): `model` (presumably a module providing the regression()
# function shown above) and the input `x` are defined elsewhere, not visible
# here.
with tf.variable_scope("regression"):
    y1, variables = model.regression(x)
# Passing the variable list restricts this Saver to the variables returned by
# regression() ([W, b]) instead of all variables in the graph.
saver = tf.train.Saver(variables)
在恢復(fù)變量時靴跛,w和b必須指定dtype或者name,不然報錯
但是下面這種情況就不用指定
def weight_variable(shape, stddev=0.1, name=None):
    """Create a variable initialised from a truncated normal distribution.

    Args:
        shape: shape of the variable to create.
        stddev: standard deviation of the truncated normal initialiser
            (default 0.1, the value previously hard-coded).
        name: optional variable name. None keeps TensorFlow's default naming,
            exactly as before. Passing a name gives the variable an explicit
            key when it is saved and restored with tf.train.Saver.

    Returns:
        The new tf.Variable.
    """
    initial = tf.truncated_normal(shape, stddev=stddev)
    return tf.Variable(initial, name=name)
def bias_variable(shape, value=0.1, name=None):
    """Create a variable filled with a constant value.

    Args:
        shape: shape of the variable to create.
        value: constant fill value (default 0.1, the value previously
            hard-coded).
        name: optional variable name. None keeps TensorFlow's default naming,
            exactly as before. Passing a name gives the variable an explicit
            key when it is saved and restored with tf.train.Saver.

    Returns:
        The new tf.Variable.
    """
    initial = tf.constant(value, shape=shape)
    return tf.Variable(initial, name=name)