在loss函數寫法上做改進,代碼更簡單;
import tensorflow as tf
from tensorflow.keras.datasets.fashion_mnist import load_data
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Reshape,Conv2DTranspose,Conv2D,MaxPool2D,Flatten,BatchNormalization
import numpy as np
import matplotlib.pyplot as plt
# Load Fashion-MNIST and keep only a small subset of training images so each
# adversarial round stays fast. Raise NUM_REAL_SAMPLES for better results at
# the cost of speed. The class labels are not used by the GAN.
NUM_REAL_SAMPLES = 150
(train_x, train_y), (test_x, test_y) = load_data()
# Scale uint8 pixels from [0, 255] to [-1, 1]. This matches the range of the
# generator's tanh output. With [0, 1] inputs the discriminator could tell real
# from fake simply by spotting negative pixel values.
train_x = train_x[:NUM_REAL_SAMPLES] / 127.5 - 1.0
# Real images as a tuple of per-image arrays, each paired with label 1.0 ("real").
x_real = tuple(train_x)
y_real = tuple(np.ones(train_x.shape[0]))
# Generator: maps a 10-dimensional noise vector to a single-channel image
# whose pixels come from a tanh activation.
# Spatial size grows 4x4 -> 7x7 -> 14x14 -> 28x28.
g = Sequential()
g.add(Dense(4 * 4 * 128, input_shape=(10,)))
g.add(Reshape((4, 4, 128)))
# A 4x4 kernel with "valid" padding turns 4x4 into (4 - 1) + 4 = 7x7.
g.add(Conv2DTranspose(64, (4, 4), padding="valid", activation="relu"))
g.add(BatchNormalization())
g.add(Conv2DTranspose(32, (2, 2), strides=(2, 2), padding="same", activation="relu"))
g.add(BatchNormalization())
g.add(Conv2DTranspose(1, (2, 2), strides=(2, 2), padding="same", activation="tanh"))
# Drop the trailing channel axis so samples come out as plain 28x28 images.
g.add(Reshape((28, 28)))
# Discriminator: takes a 28x28 image and outputs a single sigmoid score.
# Three conv + 2x2 max-pool stages shrink the feature map
# 28 -> 14 -> 7 -> (valid 2x2 conv) 6 -> 3 before the dense head.
d = Sequential()
# Conv2D expects an explicit channel axis.
d.add(Reshape((28, 28, 1), input_shape=(28, 28)))
for conv_filters, conv_padding in ((32, "same"), (64, "same"), (64, "valid")):
    d.add(Conv2D(conv_filters, (2, 2), padding=conv_padding, activation="relu"))
    d.add(MaxPool2D((2, 2)))
d.add(Flatten())
d.add(Dense(1, activation="sigmoid"))
# The discriminator is trained on its own with binary cross-entropy.
d.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
# Stacked model: noise -> generator -> discriminator score.
gan = Sequential([g, d])
# The combined model should update only the generator. Freeze the
# discriminator *before* compiling `gan`, and compile just once so the
# generator's Adam optimizer state carries over between rounds instead of
# being reset every iteration.
d.trainable = False
gan.compile(optimizer="adam", loss="binary_crossentropy")

for i in range(50):
    print(f"===============判別器第{i+1}輪訓練================")
    # Discriminator step: real images labelled 1, freshly generated fakes labelled 0.
    d.trainable = True
    real_images = np.asarray(x_real, dtype="float32")
    real_labels = np.asarray(y_real, dtype="float32")
    noise = tf.random.uniform((real_images.shape[0], 10), minval=0.0, maxval=1.0)
    fake_images = g(noise).numpy()
    fake_labels = np.zeros(fake_images.shape[0], dtype="float32")
    x = np.concatenate([real_images, fake_images])
    y = np.concatenate([real_labels, fake_labels])
    # Use a shuffle buffer as large as the dataset so that every batch mixes
    # real and fake samples. The data is ordered all-real then all-fake, so a
    # smaller buffer would front-load real images.
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(len(x)).batch(20)
    d.fit(dataset, epochs=2)

    print(f"===============生成器第{i+1}輪訓練================")
    # Generator step: train g through the frozen d so that d scores its
    # samples as real. The target is therefore 1 for every generated image.
    # A target of 1 - d(g(x)) would instead push already-convincing samples
    # back toward "fake".
    d.trainable = False
    x = tf.random.uniform((100, 10), minval=0.0, maxval=1.0)
    y = np.ones((100, 1), dtype="float32")
    gan.fit(x, y, epochs=50)
# Draw one 10-dimensional uniform noise vector in [0, 1), generate an image
# from it, and display it. The generator produces single-channel images, so
# use a grayscale colormap rather than matplotlib's default false-colour map.
noise = tf.random.uniform((1, 10), minval=0.0, maxval=1.0)
img = g(noise, training=False)[0].numpy()
plt.imshow(img, cmap="gray")
plt.show()
訓練20個來回,能看到可識別的效果。
可以調節加載數據量,當然越多越慢!