【網(wǎng)絡(luò)設(shè)計(jì)】
采用全連接網(wǎng)絡(luò):
3層編碼权悟,784->256->128
3層解碼氓鄙,128->256->784
輸入:mnist手寫圖片
輸出:由網(wǎng)絡(luò)還原出來(lái)的圖片
目標(biāo):還原度越高越好
因此我們可以總結(jié)出宫峦,最簡(jiǎn)單的Auto-encoder和decoder其實(shí)就是特殊結(jié)構(gòu)的全連接神經(jīng)網(wǎng)絡(luò)
【代碼展示】
#定義數(shù)據(jù)
mnist = input_data.read_data_sets('./mnist', one_hot=True)
n_input=784
n_hidden_1=256
n_hidden_2=128
#定義批個(gè)數(shù)和學(xué)習(xí)速率禁灼,這些決定了學(xué)習(xí)成果
batch_size=100
lr=0.001
training_epoches=200
display_epoches=10
total_batch=mnist.count()/batch_size
#輸入谤绳,一個(gè)batch的圖片
tf_x=tf.placeholder(tf.float32,shape=[None,28*28])
examples_to_show=7
#定義網(wǎng)絡(luò)參數(shù)
weights={
'encoder_w1':tf.Variable(tf.random_normal([n_input,n_hidden_1])),
'encoder_w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2])),
'decoder_w1':tf.Variable(tf.random_normal([n_hidden_2,n_hidden_1])),
'decoder_w2': tf.Variable(tf.random_normal([n_hidden_1,n_input]))
}
biases={
'encoder_b1':tf.Variable(tf.random_normal([n_hidden_1])),
'encoder_b2':tf.Variable(tf.random_normal([n_hidden_2])),
'decoder_b1':tf.Variable(tf.random_normal([n_hidden_2])),
'decoder_b2': tf.Variable(tf.random_normal([n_hidden_1,n_input]))
}
#定義網(wǎng)絡(luò)的運(yùn)算和連接方式
def encoder(x):
layer_1=tf.nn.sigmoid(tf.add(tf.matmul(x,weights['encoder_w1']),biases['encoder_b1']))
layer_2=tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['encoder_w2']),biases['encoder_b2']))
return layer_2
def decoder(x):
layer_1=tf.nn.sigmoid(tf.add(tf.matmul(x,weights['decoder_w1']),biases['decoder_b1']))
layer_2=tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['decoder_w2']),biases['decoder_b2']))
return layer_2
encoder_op=encoder(tf_x)
decoder_op=decoder(encoder_op)
y_pred=decoder_op
y_true=tf_x
#定義學(xué)習(xí)方式
cost=tf.reduce_mean(tf.pow(y_true-y_pred,2))
optimizer=tf.train.AdamOptimizer(lr).minimize(cost)
init=tf.initialize_all_variables()
#訓(xùn)練
with tf.Session()as sess:
sess.run(init)
total_batch
for i in range(training_epoches):
for j in range(total_batch):
batch_x, batch_y = mnist.train.nextbatch(batch_size)
_,c=sess.run([cost,optimizer],feed_dict={tf_x:batch_x})
if(j%display_epoches==0):
print("Epoch:%04d"%(j+1),"cost=","{:,%.9f}".format(c))
print("Optimize Finished!")
encode_decode=sess.run(y_pred,feed_dict={tf_x:mnist.test.images[:examples_to_show]})
f,a=plt.subplots(2,10,figsize=(10,2))
for i in range(examples_to_show):
a[0][i].imshow(np.reshape(mnist.test.images[i],(28,28)))
a[1][i].imshow(np.reshape(encode_decode[i],(28,28)))
plt.show()
【注意】
1、采用AdamOptimizer塌计,效果最好
2挺身、解碼和編碼網(wǎng)絡(luò)架構(gòu)是對(duì)稱的
3、learningRate(lr)是個(gè)很重要的參數(shù)