from __future__ import print_function, division
import matplotlib.pyplot as plt
import numpy as np
# Note: sklearn's fetch_mldata was removed (mldata.org is offline);
# fetch_openml serves the same MNIST data from openml.org.
from sklearn.datasets import fetch_openml
from mlfromscratch.deep_learning.optimizers import Adam
from mlfromscratch.deep_learning.loss_functions import SquareLoss
from mlfromscratch.deep_learning.layers import Dense, Activation, BatchNormalization
from mlfromscratch.deep_learning import NeuralNetwork


class Autoencoder():
    """An Autoencoder with deep fully-connected neural nets.

    Training Data: MNIST Handwritten Digits (28x28 images)
    """
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.img_dim = self.img_rows * self.img_cols
        self.latent_dim = 128  # The dimension of the data embedding

        optimizer = Adam(learning_rate=0.0002, b1=0.5)
        loss_function = SquareLoss

        self.encoder = self.build_encoder(optimizer, loss_function)
        self.decoder = self.build_decoder(optimizer, loss_function)

        # Stack the encoder and decoder layers into a single network
        self.autoencoder = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        self.autoencoder.layers.extend(self.encoder.layers)
        self.autoencoder.layers.extend(self.decoder.layers)

        print()
        self.autoencoder.summary(name="Autoencoder")
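    # Shapes through the two halves, read off the layer sizes below:
    #   encoder: 784 -> 512 -> 256 -> 128 (latent code)
    #   decoder: 128 -> 256 -> 512 -> 784 (reconstruction)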
    def build_encoder(self, optimizer, loss_function):
        encoder = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        encoder.add(Dense(512, input_shape=(self.img_dim,)))
        encoder.add(Activation('leaky_relu'))
        encoder.add(BatchNormalization(momentum=0.8))
        encoder.add(Dense(256))
        encoder.add(Activation('leaky_relu'))
        encoder.add(BatchNormalization(momentum=0.8))
        encoder.add(Dense(self.latent_dim))
        return encoder
    def build_decoder(self, optimizer, loss_function):
        decoder = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        decoder.add(Dense(256, input_shape=(self.latent_dim,)))
        decoder.add(Activation('leaky_relu'))
        decoder.add(BatchNormalization(momentum=0.8))
        decoder.add(Dense(512))
        decoder.add(Activation('leaky_relu'))
        decoder.add(BatchNormalization(momentum=0.8))
        decoder.add(Dense(self.img_dim))
        # tanh keeps reconstructions in [-1, 1], matching the input scaling
        decoder.add(Activation('tanh'))
        return decoder
    def train(self, n_epochs, batch_size=128, save_interval=50):
        # as_frame=False returns plain numpy arrays instead of a DataFrame
        mnist = fetch_openml('mnist_784', as_frame=False)
        X = mnist.data
        y = mnist.target

        # Rescale pixel values from [0, 255] to [-1, 1]
        X = (X.astype(np.float32) - 127.5) / 127.5
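        # e.g. pixel 0 -> (0 - 127.5) / 127.5 = -1.0; pixel 255 -> +1.0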
        for epoch in range(n_epochs):
            # Select a random batch of images
            idx = np.random.randint(0, X.shape[0], batch_size)
            imgs = X[idx]

            # Train the autoencoder to reconstruct its own input
            loss, _ = self.autoencoder.train_on_batch(imgs, imgs)

            # Display the progress
            print("%d [AE loss: %f]" % (epoch, loss))

            # If at save interval => save reconstructed image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch, X)
    def save_imgs(self, epoch, X):
        r, c = 5, 5  # Grid size
        # Select r*c random images to reconstruct
        idx = np.random.randint(0, X.shape[0], r * c)
        imgs = X[idx]

        # Reconstruct the images and reshape from flat vectors to 28x28
        gen_imgs = self.autoencoder.predict(imgs).reshape((-1, self.img_rows, self.img_cols))

        # Rescale images from [-1, 1] back to [0, 1] for plotting
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        plt.suptitle("Autoencoder")
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("ae_%d.png" % epoch)
        plt.close()

if __name__ == '__main__':
    ae = Autoencoder()
    ae.train(n_epochs=200000, batch_size=64, save_interval=400)
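Once trained, the encoder half can be reused on its own to embed images into the 128-dimensional latent space. A minimal sketch, assuming the definitions above are in scope and that NeuralNetwork.predict runs a plain forward pass (as it does in the mlfromscratch library); the encoder shares its trained layer objects with the full autoencoder, since layers.extend copies references:

ae = Autoencoder()
ae.train(n_epochs=20000, batch_size=64, save_interval=400)

# Embed a batch of images into the latent space; the encoder's layers are
# the same objects the autoencoder trained, so the weights are shared.
mnist = fetch_openml('mnist_784', as_frame=False)
X = (mnist.data.astype(np.float32) - 127.5) / 127.5
codes = ae.encoder.predict(X[:16])  # expected shape: (16, 128)

These codes can then be fed to the decoder, or used as compact features for clustering or visualization.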