Tensorflow 簡單實現(xiàn)自編碼器

自編碼器簡介

代碼及詳細(xì)注釋

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 31 16:05:38 2017
@author: mml
"""

import numpy as np
# 數(shù)據(jù)預(yù)處理模塊
import sklearn.preprocessing as prep
import tensorflow as tf
# 使用MNIST數(shù)據(jù)集
from tensorflow.examples.tutorials.mnist import input_data

# 參數(shù)初始化方法
# 自動根據(jù)某一層網(wǎng)絡(luò)的輸入模闲,輸出節(jié)點數(shù)量自動調(diào)整最合適分布
# fan_in輸入節(jié)點數(shù)量 fan_out輸出節(jié)點數(shù)量
def xavier_init(fan_in,fan_out,constant = 1):
    low = -constant*np.sqrt(6.0/(fan_in+fan_out))
    high = constant*np.sqrt(6.0/(fan_in+fan_out))
    # tf.random_uniform創(chuàng)建一個low到high之間的均勻分布
    return tf.random_uniform((fan_in,fan_out),minval = low,maxval = high,dtype = tf.float32)

class AdditiveGaussianNoiseAutoencoder(object):
    # 構(gòu)建函數(shù) 輸入變量數(shù)暂筝,隱層節(jié)點數(shù) 激活函數(shù) 優(yōu)化器 scale高斯噪聲系數(shù)
    def __init__(self,n_input,n_hidden,transfer_function=tf.nn.softplus,
                 optimizer=tf.train.AdamOptimizer(),scale=0.1):
        # 輸入變量數(shù)
        self.n_input = n_input
        # 隱層節(jié)點數(shù)
        self.n_hidden = n_hidden
        # 激活函數(shù)
        self.transfer = transfer_function
        self.scale = tf.placeholder(tf.float32)
        self.training_scale = scale
        # 參數(shù)初始化方法
        network_weights = self._initialize_weights()
        self.weights = network_weights
        # 為輸入x創(chuàng)建一個維度為n_input的placeholder
        self.x = tf.placeholder(tf.float32,[None,self.n_input])
        #  隱含層提取特征過程
        #  scale*tf.random_normal((n_input,))產(chǎn)生高斯噪聲
        #  self.x + scale*tf.random_normal((n_input,)) 為輸入加上高斯噪聲
        #  tf.matmul(self.x + scale*tf.random_normal((n_input,)),self.weights['w1']) 加入噪聲后的輸入乘以權(quán)重
        #  tf.add(tf.matmul(self.x + scale*tf.random_normal((n_input,)),self.weights['w1']),self.weights['b1'])) 最后加上偏置
        #  self.transfer() 對結(jié)果進(jìn)行激活函數(shù)處理
        self.hidden = self.transfer(tf.add(tf.matmul(self.x + scale*tf.random_normal
                                                     ((n_input,)),self.weights['w1']),
                                                                self.weights['b1']))
        #  經(jīng)過隱含層提取特征后游沿,我們需要在輸出層進(jìn)行數(shù)據(jù)復(fù)原重建操作
        #  重構(gòu)層直接把隱含層輸出乘以輸出層權(quán)重并加上偏置即可
        self.reconstruction = tf.add(tf.matmul(self.hidden,
                                               self.weights['w2']),self.weights['b2'])

        # 自編碼器的損失函數(shù) 平方誤差作為cost
        # tf.subtract(self.reconstruction,self.x) 重構(gòu)后的輸出和輸入相減
        # tf.pow求差的平方
        # tf.reduce_sum求所有平方誤差和
        self.cost = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(
                            self.reconstruction,self.x),2.0))
        # 定義優(yōu)化方法左电,對cost進(jìn)行優(yōu)化
        self.optimizer = optimizer.minimize(self.cost)
        # 全局參數(shù)初始化
        init = tf.global_variables_initializer()
        # 創(chuàng)建Session
        self.sess = tf.Session()
        self.sess.run(init)
    # 參數(shù)初始化函數(shù)
    def _initialize_weights(self):
        # 創(chuàng)建一個所有參數(shù)的字典
        all_weights = dict()
        # w1使用前面的xavier_init初始化
        all_weights['w1'] = tf.Variable(xavier_init(self.n_input,self.n_hidden))
        # 其它都初始化為0
        all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden],dtype = tf.float32))
        all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden,self.n_input],dtype = tf.float32))
        all_weights['b2'] = tf.Variable(tf.zeros([self.n_input],dtype = tf.float32))
        return all_weights
    # 定義一個batch數(shù)據(jù)進(jìn)行訓(xùn)練并返回當(dāng)前cost    
    def partial_fit(self,X):
        # 讓Session執(zhí)行計算流圖節(jié)點cost和optimizer
        # feed_dict為輸入數(shù)據(jù)X和噪聲系數(shù)
        cost,opt = self.sess.run((self.cost,self.optimizer),
                                 feed_dict = {self.x:X,self.scale:self.training_scale})
        return cost
    # 還需要一個只計算cost不訓(xùn)練的函數(shù)
    def calc_total_cost(self,X):
        # 讓Session只觸發(fā)計算流圖節(jié)點self.cost
        return self.sess.run(self.cost,feed_dict = {self.x:X,self.scale:self.training_scale})
    # 還需函數(shù)返回隱含層輸出結(jié)果(即提取的特征)
    def transform(self,X):
        # Session觸發(fā)計算節(jié)點hidden
        return self.sess.run(self.hidden,feed_dict = {self.x:X,
                                                      self.scale:self.training_scale})
    # 定義函數(shù)進(jìn)行單獨重建(輸入為隱含層輸出)    
    def generate(self,hidden = None):
        if hidden is None:
            hidden = np.random.normal(size = self.weights['b1'])
        return self.sess.run(self.reconstruction,feed_dict = {self.hidden:hidden})
    # 定義完整的重建穴亏,包括前面的transform和reconstruction
    def reconstruct(self,X):
        return self.sess.run(self.reconstruction,feed_dict = {self.x:X,self.scale:self.training_scale})
    # 獲取隱含層參數(shù)    
    def getWeights(self):
        return self.sess.run(self.weights['w1'])
    
    def getBiases(self):
        return self.sess.run(self.weights['b1'])
        
# 載入MINIST數(shù)據(jù)集
mnist = input_data.read_data_sets('MNIST_data',one_hot = True)

# 對訓(xùn)練和測試數(shù)據(jù)進(jìn)行標(biāo)準(zhǔn)化處理
# 標(biāo)準(zhǔn)化讓數(shù)據(jù)變成0均值妄迁,且標(biāo)準(zhǔn)差為1的分布
def standard_scale(X_train,X_test):
    preprocessor = prep.StandardScaler().fit(X_train)
    X_train = preprocessor.transform(X_train)
    X_test = preprocessor.transform(X_test)
    return X_train,X_test

# 定義獲取隨機(jī)block數(shù)據(jù)的方法
def get_random_block_from_data(data,batch_size):
    # 取一個隨機(jī)整數(shù)
    start_index = np.random.randint(0,len(data)-batch_size)
    # 順序取到一個batch_size的數(shù)據(jù)
    return data[start_index:(start_index + batch_size)]

# 使用之前的標(biāo)準(zhǔn)化函數(shù)對訓(xùn)練集和測試集進(jìn)行標(biāo)準(zhǔn)化處理
X_train,X_test = standard_scale(mnist.train.images,mnist.test.images)

# 總訓(xùn)練樣本數(shù)
n_samples = int(mnist.train.num_examples)
# 最大訓(xùn)練輪數(shù)
training_epochs = 20
# batchsize
batch_size = 128
# 每一輪顯示一次cost
display_step = 1

# 創(chuàng)建AGN實例
# 輸入784(mnist數(shù)據(jù)28*28)
autoencoder = AdditiveGaussianNoiseAutoencoder(n_input = 784,
                                               n_hidden = 200,
                                               transfer_function = tf.nn.softplus,
                                               optimizer = tf.train.AdagradOptimizer(learning_rate = 0.001),scale = 0.01)

# 開始真正的訓(xùn)練過程
for epoch in range(training_epochs):
    avg_cost = 0.
    # 計算總共的batch數(shù)
    total_batch = int(n_samples / batch_size)
    for i in range(total_batch):
        # 使用get_random_block_from_data獲取隨機(jī)batch數(shù)據(jù)
        batch_xs = get_random_block_from_data(X_train,batch_size)
        # 使用partial_fit進(jìn)行訓(xùn)練卵渴,并返回cost
        cost = autoencoder.partial_fit(batch_xs)
        avg_cost += cost / n_samples * batch_size
    if epoch % display_step == 0:
        print "Epoch:",'%04d' % (epoch + 1),"cost=","{:.9f}".format(avg_cost)

print "Total cost: " + str(autoencoder.calc_total_cost(X_test))

結(jié)果

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市榆骚,隨后出現(xiàn)的幾起案子片拍,更是在濱河造成了極大的恐慌,老刑警劉巖妓肢,帶你破解...
    沈念sama閱讀 216,919評論 6 502
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件捌省,死亡現(xiàn)場離奇詭異,居然都是意外死亡碉钠,警方通過查閱死者的電腦和手機(jī)纲缓,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,567評論 3 392
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來喊废,“玉大人祝高,你說我怎么就攤上這事∥劭辏” “怎么了工闺?”我有些...
    開封第一講書人閱讀 163,316評論 0 353
  • 文/不壞的土叔 我叫張陵,是天一觀的道長瓣蛀。 經(jīng)常有香客問我陆蟆,道長,這世上最難降的妖魔是什么惋增? 我笑而不...
    開封第一講書人閱讀 58,294評論 1 292
  • 正文 為了忘掉前任遍搞,我火速辦了婚禮,結(jié)果婚禮上器腋,老公的妹妹穿的比我還像新娘溪猿。我一直安慰自己,他們只是感情好纫塌,可當(dāng)我...
    茶點故事閱讀 67,318評論 6 390
  • 文/花漫 我一把揭開白布诊县。 她就那樣靜靜地躺著,像睡著了一般措左。 火紅的嫁衣襯著肌膚如雪依痊。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,245評論 1 299
  • 那天怎披,我揣著相機(jī)與錄音胸嘁,去河邊找鬼。 笑死凉逛,一個胖子當(dāng)著我的面吹牛性宏,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播状飞,決...
    沈念sama閱讀 40,120評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼毫胜,長吁一口氣:“原來是場噩夢啊……” “哼书斜!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起酵使,我...
    開封第一講書人閱讀 38,964評論 0 275
  • 序言:老撾萬榮一對情侶失蹤荐吉,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后口渔,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體样屠,經(jīng)...
    沈念sama閱讀 45,376評論 1 313
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,592評論 2 333
  • 正文 我和宋清朗相戀三年缺脉,在試婚紗的時候發(fā)現(xiàn)自己被綠了痪欲。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 39,764評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡枪向,死狀恐怖勤揩,靈堂內(nèi)的尸體忽然破棺而出咧党,到底是詐尸還是另有隱情秘蛔,我是刑警寧澤,帶...
    沈念sama閱讀 35,460評論 5 344
  • 正文 年R本政府宣布傍衡,位于F島的核電站深员,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏蛙埂。R本人自食惡果不足惜倦畅,卻給世界環(huán)境...
    茶點故事閱讀 41,070評論 3 327
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望绣的。 院中可真熱鬧叠赐,春花似錦、人聲如沸屡江。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,697評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽惩嘉。三九已至罢洲,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間文黎,已是汗流浹背惹苗。 一陣腳步聲響...
    開封第一講書人閱讀 32,846評論 1 269
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點兒被人妖公主榨干…… 1. 我叫王不留耸峭,地道東北人桩蓉。 一個月前我還...
    沈念sama閱讀 47,819評論 2 370
  • 正文 我出身青樓,卻偏偏與公主長得像劳闹,于是被迫代替她去往敵國和親触机。 傳聞我的和親對象是個殘疾皇子帚戳,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,665評論 2 354

推薦閱讀更多精彩內(nèi)容

  • Spring Cloud為開發(fā)人員提供了快速構(gòu)建分布式系統(tǒng)中一些常見模式的工具(例如配置管理,服務(wù)發(fā)現(xiàn)儡首,斷路器片任,智...
    卡卡羅2017閱讀 134,654評論 18 139
  • Android 自定義View的各種姿勢1 Activity的顯示之ViewRootImpl詳解 Activity...
    passiontim閱讀 172,090評論 25 707
  • 發(fā)現(xiàn) 關(guān)注 消息 iOS 第三方庫、插件蔬胯、知名博客總結(jié) 作者大灰狼的小綿羊哥哥關(guān)注 2017.06.26 09:4...
    肇東周閱讀 12,098評論 4 62
  • 01 java 被static修飾的成員變量屬于類不屬于這個類的某個對象 person類的nameeigenper...
    f73d56f67419閱讀 274評論 0 0
  • 一盅粪、 四有文章中的哪一有對我啟發(fā)最大窒悔? 所謂的四有是指:有趣、有用、有料涩金、有力,其中”有力”還分為審美和情感兩種類...
    熱拿鐵閱讀 308評論 1 0