Wasserstein GAN: 'discriminator' => 'critic'

eriklindernoren 2018-05-11 19:56:34 +02:00
Parent ce7dc229be
Commit 0293a4e6ff
6 changed files with 98 additions and 107 deletions
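Note: the WGAN and WGAN-GP critics in this commit are compiled with `self.wasserstein_loss` rather than binary cross-entropy, which is why real and generated batches are labelled -1/+1 in their training loops below. The body of that helper is not part of this diff; a minimal sketch, assuming the Keras backend API:

```
import keras.backend as K

def wasserstein_loss(y_true, y_pred):
    # The critic outputs an unbounded score (Dense(1) with no sigmoid),
    # so the loss is just the mean score weighted by the -1/+1 labels.
    return K.mean(y_true * y_pred)
```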

View file

@@ -23,7 +23,6 @@ Collection of Keras implementations of Generative Adversarial Networks (GANs) su
+ [DiscoGAN](#discogan)
+ [DualGAN](#dualgan)
+ [Generative Adversarial Network](#gan)
+ [Improved Wasserstein GAN](#improved-wgan)
+ [InfoGAN](#infogan)
+ [LSGAN](#lsgan)
+ [Pix2Pix](#pix2pix)
@@ -31,6 +30,7 @@ Collection of Keras implementations of Generative Adversarial Networks (GANs) su
+ [Semi-Supervised GAN](#sgan)
+ [Super-Resolution GAN](#srgan)
+ [Wasserstein GAN](#wgan)
+ [Wasserstein GAN GP](#wgan-gp)
## Installation
$ git clone https://github.com/eriklindernoren/Keras-GAN
@@ -268,23 +268,6 @@ $ python3 gan_rgb.py
<img src="gan/etc/adam.gif" width="640"\>
</p>
### Improved WGAN
Implementation of _Improved Training of Wasserstein GANs_.
[Code](improved_wgan/improved_wgan.py)
Paper: https://arxiv.org/abs/1704.00028
#### Example
```
$ cd improved_wgan/
$ python3 improved_wgan.py
```
<p align="center">
<img src="http://eriklindernoren.se/images/imp_wgan.gif" width="640"\>
</p>
### InfoGAN
Implementation of _InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets_.
@@ -413,3 +396,20 @@ $ python3 wgan.py
<p align="center">
<img src="http://eriklindernoren.se/images/wgan2.png" width="640"\>
</p>
### WGAN GP
Implementation of _Improved Training of Wasserstein GANs_.
[Code](wgan_gp/wgan_gp.py)
Paper: https://arxiv.org/abs/1704.00028
#### Example
```
$ cd wgan_gp/
$ python3 wgan_gp.py
```
<p align="center">
<img src="http://eriklindernoren.se/images/imp_wgan.gif" width="640"\>
</p>

View file

@@ -54,15 +54,14 @@ class DCGAN():
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
model.add(Reshape((7, 7, 128)))
model.add(BatchNormalization(momentum=0.8))
model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=3, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(UpSampling2D())
model.add(Conv2D(64, kernel_size=3, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
model.add(Activation("tanh"))
@@ -82,17 +81,17 @@ class DCGAN():
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
@@ -109,7 +108,7 @@ class DCGAN():
(X_train, _), (_, _) = mnist.load_data()
# Rescale -1 to 1
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = X_train.astype(np.float32) / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)
half_batch = int(batch_size / 2)
@@ -125,7 +124,7 @@ class DCGAN():
imgs = X_train[idx]
# Sample noise and generate a half batch of new images
noise = np.random.normal(0, 1, (half_batch, 100))
noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
gen_imgs = self.generator.predict(noise)
# Train the discriminator (real classified as ones and generated as zeros)
@@ -138,7 +137,7 @@ class DCGAN():
# ---------------------
# Sample generator input
noise = np.random.normal(0, 1, (batch_size, 100))
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Train the generator (wants discriminator to mistake images as real)
g_loss = self.combined.train_on_batch(noise, np.ones((batch_size, 1)))
@@ -152,7 +151,7 @@ class DCGAN():
def save_imgs(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100))
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1
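A quick sanity check on the rescaling change above: `(X - 127.5) / 127.5` and `X / 127.5 - 1.` are algebraically identical (0 → -1, 127.5 → 0, 255 → 1), so the rewrite only changes the form of the expression, not the training data.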

View file

@@ -21,15 +21,17 @@ class WGAN():
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
# Following parameter and optimizer set as recommended in paper
self.n_critic = 5
self.clip_value = 0.01
optimizer = RMSprop(lr=0.00005)
# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss=self.wasserstein_loss,
# Build and compile the critic
self.critic = self.build_critic()
self.critic.compile(loss=self.wasserstein_loss,
optimizer=optimizer,
metrics=['accuracy'])
@@ -41,13 +43,12 @@ class WGAN():
img = self.generator(z)
# For the combined model we will only train the generator
self.discriminator.trainable = False
self.critic.trainable = False
# The discriminator takes generated images as input and determines validity
valid = self.discriminator(img)
# The critic takes generated images as input and determines validity
valid = self.critic(img)
# The combined model (stacked generator and discriminator) takes
# noise as input => generates images => determines validity
# The combined model (stacked generator and critic)
self.combined = Model(z, valid)
self.combined.compile(loss=self.wasserstein_loss,
optimizer=optimizer,
@@ -58,21 +59,18 @@ class WGAN():
def build_generator(self):
noise_shape = (100,)
model = Sequential()
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=noise_shape))
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
model.add(Reshape((7, 7, 128)))
model.add(BatchNormalization(momentum=0.8))
model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=4, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(UpSampling2D())
model.add(Conv2D(64, kernel_size=4, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
model.add(Activation("tanh"))
@@ -83,37 +81,35 @@ class WGAN():
return Model(noise, img)
def build_discriminator(self):
img_shape = (self.img_rows, self.img_cols, self.channels)
def build_critic(self):
model = Sequential()
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(1))
model.summary()
img = Input(shape=img_shape)
features = model(img)
valid = Dense(1, activation="linear")(features)
validity = model(img)
return Model(img, valid)
return Model(img, validity)
def train(self, epochs, batch_size=128, sample_interval=50):
@@ -138,18 +134,18 @@ class WGAN():
idx = np.random.randint(0, X_train.shape[0], half_batch)
imgs = X_train[idx]
noise = np.random.normal(0, 1, (half_batch, 100))
noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
# Generate a half batch of new images
gen_imgs = self.generator.predict(noise)
# Train the discriminator
d_loss_real = self.discriminator.train_on_batch(imgs, -np.ones((half_batch, 1)))
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.ones((half_batch, 1)))
# Train the critic
d_loss_real = self.critic.train_on_batch(imgs, -np.ones((half_batch, 1)))
d_loss_fake = self.critic.train_on_batch(gen_imgs, np.ones((half_batch, 1)))
d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
# Clip discriminator weights
for l in self.discriminator.layers:
# Clip critic weights
for l in self.critic.layers:
weights = l.get_weights()
weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
l.set_weights(weights)
@@ -159,7 +155,7 @@ class WGAN():
# Train Generator
# ---------------------
noise = np.random.normal(0, 1, (batch_size, 100))
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Train the generator
g_loss = self.combined.train_on_batch(noise, -np.ones((batch_size, 1)))
@@ -173,7 +169,7 @@ class WGAN():
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100))
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1
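As an aside, the per-batch clipping loop in the training step above enforces the Lipschitz constraint by hand. An illustrative alternative sketch (not what this commit does) would attach the same clipping to the critic layers via a Keras weight constraint, using a hypothetical `ClipValue` helper:

```
from keras.constraints import Constraint
import keras.backend as K

class ClipValue(Constraint):
    """Clips layer weights to [-c, c] after every update (illustrative only)."""
    def __init__(self, c=0.01):
        self.c = c

    def __call__(self, w):
        return K.clip(w, -self.c, self.c)

# e.g. Conv2D(64, kernel_size=3, strides=2, padding="same",
#             kernel_constraint=ClipValue(0.01))
```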

View file

View file

View file

@@ -26,30 +26,31 @@ import numpy as np
class RandomWeightedAverage(_Merge):
"""Provides a (random) weighted average between real and generated image samples"""
def _merge_function(self, inputs):
weights = K.random_uniform((32, 1, 1, 1))
return (weights * inputs[0]) + ((1 - weights) * inputs[1])
alpha = K.random_uniform((32, 1, 1, 1))
return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
class ImprovedWGAN():
class WGANGP():
def __init__(self):
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
# Following parameter and optimizer set as recommended in paper
self.n_critic = 5
optimizer = RMSprop(lr=0.00005)
# Build the generator and discriminator
# Build the generator and critic
self.generator = self.build_generator()
self.discriminator = self.build_discriminator()
self.critic = self.build_critic()
#-------------------------------
# Construct Computational Graph
# for Discriminator
# for the Critic
#-------------------------------
# Freeze generator's layers while training discriminator
# Freeze generator's layers while training critic
self.generator.trainable = False
# Image input (real sample)
@@ -61,23 +62,23 @@ class ImprovedWGAN():
fake_img = self.generator(z_disc)
# Discriminator determines validity of the real and fake images
fake = self.discriminator(fake_img)
real = self.discriminator(real_img)
fake = self.critic(fake_img)
valid = self.critic(real_img)
# Construct weighted average between real and fake images
merged_img = RandomWeightedAverage()([real_img, fake_img])
interpolated_img = RandomWeightedAverage()([real_img, fake_img])
# Determine validity of weighted sample
valid_merged = self.discriminator(merged_img)
validity_interpolated = self.critic(interpolated_img)
# Use Python partial to provide loss function with additional
# 'averaged_samples' argument
partial_gp_loss = partial(self.gradient_penalty_loss,
averaged_samples=merged_img)
averaged_samples=interpolated_img)
partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names
self.discriminator_model = Model(inputs=[real_img, z_disc],
outputs=[real, fake, valid_merged])
self.discriminator_model.compile(loss=[self.wasserstein_loss,
self.critic_model = Model(inputs=[real_img, z_disc],
outputs=[valid, fake, validity_interpolated])
self.critic_model.compile(loss=[self.wasserstein_loss,
self.wasserstein_loss,
partial_gp_loss],
optimizer=optimizer,
@@ -87,8 +88,8 @@ class ImprovedWGAN():
# for Generator
#-------------------------------
# For the generator we freeze the discriminator's layers
self.discriminator.trainable = False
# For the generator we freeze the critic's layers
self.critic.trainable = False
self.generator.trainable = True
# Sampled noise for input to generator
@@ -96,7 +97,7 @@ class ImprovedWGAN():
# Generate images based on noise
img = self.generator(z_gen)
# Discriminator determines validity
valid = self.discriminator(img)
valid = self.critic(img)
# Defines generator model
self.generator_model = Model(z_gen, valid)
self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)
@@ -125,21 +126,18 @@ class ImprovedWGAN():
def build_generator(self):
noise_shape = (100,)
model = Sequential()
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=noise_shape))
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
model.add(Reshape((7, 7, 128)))
model.add(BatchNormalization(momentum=0.8))
model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=4, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(UpSampling2D())
model.add(Conv2D(64, kernel_size=4, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
model.add(Activation("tanh"))
@@ -150,37 +148,35 @@ class ImprovedWGAN():
return Model(noise, img)
def build_discriminator(self):
img_shape = (self.img_rows, self.img_cols, self.channels)
def build_critic(self):
model = Sequential()
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(1))
model.summary()
img = Input(shape=img_shape)
features = model(img)
valid = Dense(1, activation="linear")(features)
validity = model(img)
return Model(img, valid)
return Model(img, validity)
def train(self, epochs, batch_size, sample_interval=50):
@@ -207,9 +203,9 @@ class ImprovedWGAN():
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
# Sample generator input
noise = np.random.normal(0, 1, (batch_size, 100))
# Train the discriminator
d_loss = self.discriminator_model.train_on_batch([imgs, noise],
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Train the critic
d_loss = self.critic_model.train_on_batch([imgs, noise],
[valid, fake, dummy])
# ---------------------
@@ -217,12 +213,12 @@ class ImprovedWGAN():
# ---------------------
# Sample generator input
noise = np.random.normal(0, 1, (batch_size, 100))
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Train the generator
g_loss = self.generator_model.train_on_batch(noise, valid)
# Plot the progress
print ("%d [D loss: %f] [G loss: %f]" % (epoch, 1 - d_loss[0], 1 - g_loss))
print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))
# If at save interval => save generated image samples
if epoch % sample_interval == 0:
@@ -230,7 +226,7 @@ class ImprovedWGAN():
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100))
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1
@@ -248,5 +244,5 @@ class ImprovedWGAN():
if __name__ == '__main__':
wgan = ImprovedWGAN()
wgan = WGANGP()
wgan.train(epochs=30000, batch_size=32, sample_interval=100)
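For reference, the `gradient_penalty_loss` bound via `partial(...)` above is not shown in this diff. A minimal sketch of what it typically computes for WGAN-GP (the critic's gradient norm at the interpolated samples, penalized for deviating from 1), assuming the Keras backend API:

```
import numpy as np
import keras.backend as K

def gradient_penalty_loss(y_true, y_pred, averaged_samples):
    # Gradient of the critic score w.r.t. the interpolated images
    gradients = K.gradients(y_pred, averaged_samples)[0]
    # Per-sample L2 norm taken over all non-batch axes
    gradients_sqr = K.square(gradients)
    gradients_sqr_sum = K.sum(gradients_sqr,
                              axis=np.arange(1, len(gradients_sqr.shape)))
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    # Penalize (||grad|| - 1)^2, averaged over the batch
    return K.mean(K.square(1 - gradient_l2_norm))
```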