from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import os

import numpy as np

class LSGAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # Trains generator to fool discriminator
        self.combined = Model(z, valid)
        # (!!!) Optimize w.r.t. MSE loss instead of crossentropy
        self.combined.compile(loss='mse', optimizer=optimizer)
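        # LSGAN note (Mao et al., 2017): both models minimise least-squares
        # objectives rather than the usual sigmoid cross-entropy,
        #   L_D = 1/2 E[(D(x) - 1)^2] + 1/2 E[D(G(z))^2]
        #   L_G = 1/2 E[(D(G(z)) - 1)^2]
        # which is exactly what compiling both models with loss='mse' against
        # the 0/1 targets built in train() implements.
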
    def build_generator(self):

        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
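        # tanh keeps the generated pixel values in [-1, 1], matching the
        # rescaling applied to the training images in train()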
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        # (!!!) No sigmoid/softmax on the output
        model.add(Dense(1))
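        # The default linear activation leaves the score unbounded; the MSE
        # loss regresses it toward the 0/1 targets instead of squashing it
        # through a sigmoid as a vanilla GAN would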
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
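        # These correspond to the LSGAN target labels b = 1 (real) and
        # a = 0 (fake); the generator target c used below is also 1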

        for epoch in range(epochs):

            # ---------------------
            # Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise as generator input
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            # Train Generator
            # ---------------------

            g_loss = self.combined.train_on_batch(noise, valid)
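            # The generator is updated through the frozen discriminator,
            # pushing D's score on fake images toward the 'real' target of 1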

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        # Make sure the output directory exists before saving
        if not os.path.exists("images"):
            os.makedirs("images")
        fig.savefig("images/mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = LSGAN()
    gan.train(epochs=30000, batch_size=32, sample_interval=200)