Wasserstein GAN: 'discriminator' => 'critic'

This commit is contained in:
eriklindernoren 2018-05-11 19:56:34 +02:00
Родитель ce7dc229be
Коммит 0293a4e6ff
6 изменённых файлов: 98 добавлений и 107 удалений

Просмотреть файл

@ -23,7 +23,6 @@ Collection of Keras implementations of Generative Adversarial Networks (GANs) su
+ [DiscoGAN](#discogan) + [DiscoGAN](#discogan)
+ [DualGAN](#dualgan) + [DualGAN](#dualgan)
+ [Generative Adversarial Network](#gan) + [Generative Adversarial Network](#gan)
+ [Improved Wasserstein GAN](#improved-wgan)
+ [InfoGAN](#infogan) + [InfoGAN](#infogan)
+ [LSGAN](#lsgan) + [LSGAN](#lsgan)
+ [Pix2Pix](#pix2pix) + [Pix2Pix](#pix2pix)
@ -31,6 +30,7 @@ Collection of Keras implementations of Generative Adversarial Networks (GANs) su
+ [Semi-Supervised GAN](#sgan) + [Semi-Supervised GAN](#sgan)
+ [Super-Resolution GAN](#srgan) + [Super-Resolution GAN](#srgan)
+ [Wasserstein GAN](#wgan) + [Wasserstein GAN](#wgan)
+ [Wasserstein GAN GP](#wgan-gp)
## Installation ## Installation
$ git clone https://github.com/eriklindernoren/Keras-GAN $ git clone https://github.com/eriklindernoren/Keras-GAN
@ -268,23 +268,6 @@ $ python3 gan_rgb.py
<img src="gan/etc/adam.gif" width="640"\> <img src="gan/etc/adam.gif" width="640"\>
</p> </p>
### Improved WGAN
Implementation of _Improved Training of Wasserstein GANs_.
[Code](improved_wgan/improved_wgan.py)
Paper: https://arxiv.org/abs/1704.00028
#### Example
```
$ cd improved_wgan/
$ python3 improved_wgan.py
```
<p align="center">
<img src="http://eriklindernoren.se/images/imp_wgan.gif" width="640"\>
</p>
### InfoGAN ### InfoGAN
Implementation of _InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets_. Implementation of _InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets_.
@ -413,3 +396,20 @@ $ python3 wgan.py
<p align="center"> <p align="center">
<img src="http://eriklindernoren.se/images/wgan2.png" width="640"\> <img src="http://eriklindernoren.se/images/wgan2.png" width="640"\>
</p> </p>
### WGAN GP
Implementation of _Improved Training of Wasserstein GANs_.
[Code](wgan_gp/wgan_gp.py)
Paper: https://arxiv.org/abs/1704.00028
#### Example
```
$ cd wgan_gp/
$ python3 wgan_gp.py
```
<p align="center">
<img src="http://eriklindernoren.se/images/imp_wgan.gif" width="640"\>
</p>

Просмотреть файл

@ -54,15 +54,14 @@ class DCGAN():
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,))) model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
model.add(Reshape((7, 7, 128))) model.add(Reshape((7, 7, 128)))
model.add(BatchNormalization(momentum=0.8))
model.add(UpSampling2D()) model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=3, padding="same")) model.add(Conv2D(128, kernel_size=3, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8)) model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(UpSampling2D()) model.add(UpSampling2D())
model.add(Conv2D(64, kernel_size=3, padding="same")) model.add(Conv2D(64, kernel_size=3, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8)) model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=3, padding="same")) model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
model.add(Activation("tanh")) model.add(Activation("tanh"))
@ -82,17 +81,17 @@ class DCGAN():
model.add(Dropout(0.25)) model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
model.add(ZeroPadding2D(padding=((0,1),(0,1)))) model.add(ZeroPadding2D(padding=((0,1),(0,1))))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2)) model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25)) model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(128, kernel_size=3, strides=2, padding="same")) model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8)) model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
model.add(LeakyReLU(alpha=0.2)) model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25)) model.add(Dropout(0.25))
model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten()) model.add(Flatten())
model.add(Dense(1, activation='sigmoid')) model.add(Dense(1, activation='sigmoid'))
@ -109,7 +108,7 @@ class DCGAN():
(X_train, _), (_, _) = mnist.load_data() (X_train, _), (_, _) = mnist.load_data()
# Rescale -1 to 1 # Rescale -1 to 1
X_train = (X_train.astype(np.float32) - 127.5) / 127.5 X_train = X_train.astype(np.float32) / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3) X_train = np.expand_dims(X_train, axis=3)
half_batch = int(batch_size / 2) half_batch = int(batch_size / 2)
@ -125,7 +124,7 @@ class DCGAN():
imgs = X_train[idx] imgs = X_train[idx]
# Sample noise and generate a half batch of new images # Sample noise and generate a half batch of new images
noise = np.random.normal(0, 1, (half_batch, 100)) noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
gen_imgs = self.generator.predict(noise) gen_imgs = self.generator.predict(noise)
# Train the discriminator (real classified as ones and generated as zeros) # Train the discriminator (real classified as ones and generated as zeros)
@ -138,7 +137,7 @@ class DCGAN():
# --------------------- # ---------------------
# Sample generator input # Sample generator input
noise = np.random.normal(0, 1, (batch_size, 100)) noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Train the generator (wants discriminator to mistake images as real) # Train the generator (wants discriminator to mistake images as real)
g_loss = self.combined.train_on_batch(noise, np.ones((batch_size, 1))) g_loss = self.combined.train_on_batch(noise, np.ones((batch_size, 1)))
@ -152,7 +151,7 @@ class DCGAN():
def save_imgs(self, epoch): def save_imgs(self, epoch):
r, c = 5, 5 r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100)) noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise) gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1 # Rescale images 0 - 1

Просмотреть файл

@ -21,15 +21,17 @@ class WGAN():
self.img_rows = 28 self.img_rows = 28
self.img_cols = 28 self.img_cols = 28
self.channels = 1 self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
# Following parameter and optimizer set as recommended in paper # Following parameter and optimizer set as recommended in paper
self.n_critic = 5 self.n_critic = 5
self.clip_value = 0.01 self.clip_value = 0.01
optimizer = RMSprop(lr=0.00005) optimizer = RMSprop(lr=0.00005)
# Build and compile the discriminator # Build and compile the critic
self.discriminator = self.build_discriminator() self.critic = self.build_critic()
self.discriminator.compile(loss=self.wasserstein_loss, self.critic.compile(loss=self.wasserstein_loss,
optimizer=optimizer, optimizer=optimizer,
metrics=['accuracy']) metrics=['accuracy'])
@ -41,13 +43,12 @@ class WGAN():
img = self.generator(z) img = self.generator(z)
# For the combined model we will only train the generator # For the combined model we will only train the generator
self.discriminator.trainable = False self.critic.trainable = False
# The discriminator takes generated images as input and determines validity # The critic takes generated images as input and determines validity
valid = self.discriminator(img) valid = self.critic(img)
# The combined model (stacked generator and discriminator) takes # The combined model (stacked generator and critic)
# noise as input => generates images => determines validity
self.combined = Model(z, valid) self.combined = Model(z, valid)
self.combined.compile(loss=self.wasserstein_loss, self.combined.compile(loss=self.wasserstein_loss,
optimizer=optimizer, optimizer=optimizer,
@ -58,21 +59,18 @@ class WGAN():
def build_generator(self): def build_generator(self):
noise_shape = (100,)
model = Sequential() model = Sequential()
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=noise_shape)) model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
model.add(Reshape((7, 7, 128))) model.add(Reshape((7, 7, 128)))
model.add(BatchNormalization(momentum=0.8))
model.add(UpSampling2D()) model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=4, padding="same")) model.add(Conv2D(128, kernel_size=4, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8)) model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(UpSampling2D()) model.add(UpSampling2D())
model.add(Conv2D(64, kernel_size=4, padding="same")) model.add(Conv2D(64, kernel_size=4, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8)) model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=4, padding="same")) model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
model.add(Activation("tanh")) model.add(Activation("tanh"))
@ -83,37 +81,35 @@ class WGAN():
return Model(noise, img) return Model(noise, img)
def build_discriminator(self): def build_critic(self):
img_shape = (self.img_rows, self.img_cols, self.channels)
model = Sequential() model = Sequential()
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same")) model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
model.add(LeakyReLU(alpha=0.2)) model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25)) model.add(Dropout(0.25))
model.add(Conv2D(32, kernel_size=3, strides=2, padding="same")) model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
model.add(ZeroPadding2D(padding=((0,1),(0,1)))) model.add(ZeroPadding2D(padding=((0,1),(0,1))))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2)) model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25)) model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8)) model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
model.add(LeakyReLU(alpha=0.2)) model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25)) model.add(Dropout(0.25))
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten()) model.add(Flatten())
model.add(Dense(1))
model.summary() model.summary()
img = Input(shape=img_shape) img = Input(shape=img_shape)
features = model(img) validity = model(img)
valid = Dense(1, activation="linear")(features)
return Model(img, valid) return Model(img, validity)
def train(self, epochs, batch_size=128, sample_interval=50): def train(self, epochs, batch_size=128, sample_interval=50):
@ -138,18 +134,18 @@ class WGAN():
idx = np.random.randint(0, X_train.shape[0], half_batch) idx = np.random.randint(0, X_train.shape[0], half_batch)
imgs = X_train[idx] imgs = X_train[idx]
noise = np.random.normal(0, 1, (half_batch, 100)) noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
# Generate a half batch of new images # Generate a half batch of new images
gen_imgs = self.generator.predict(noise) gen_imgs = self.generator.predict(noise)
# Train the discriminator # Train the critic
d_loss_real = self.discriminator.train_on_batch(imgs, -np.ones((half_batch, 1))) d_loss_real = self.critic.train_on_batch(imgs, -np.ones((half_batch, 1)))
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.ones((half_batch, 1))) d_loss_fake = self.critic.train_on_batch(gen_imgs, np.ones((half_batch, 1)))
d_loss = 0.5 * np.add(d_loss_fake, d_loss_real) d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
# Clip discriminator weights # Clip critic weights
for l in self.discriminator.layers: for l in self.critic.layers:
weights = l.get_weights() weights = l.get_weights()
weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights] weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
l.set_weights(weights) l.set_weights(weights)
@ -159,7 +155,7 @@ class WGAN():
# Train Generator # Train Generator
# --------------------- # ---------------------
noise = np.random.normal(0, 1, (batch_size, 100)) noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Train the generator # Train the generator
g_loss = self.combined.train_on_batch(noise, -np.ones((batch_size, 1))) g_loss = self.combined.train_on_batch(noise, -np.ones((batch_size, 1)))
@ -173,7 +169,7 @@ class WGAN():
def sample_images(self, epoch): def sample_images(self, epoch):
r, c = 5, 5 r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100)) noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise) gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1 # Rescale images 0 - 1

Просмотреть файл

Просмотреть файл

Просмотреть файл

@ -26,30 +26,31 @@ import numpy as np
class RandomWeightedAverage(_Merge): class RandomWeightedAverage(_Merge):
"""Provides a (random) weighted average between real and generated image samples""" """Provides a (random) weighted average between real and generated image samples"""
def _merge_function(self, inputs): def _merge_function(self, inputs):
weights = K.random_uniform((32, 1, 1, 1)) alpha = K.random_uniform((32, 1, 1, 1))
return (weights * inputs[0]) + ((1 - weights) * inputs[1]) return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
class ImprovedWGAN(): class WGANGP():
def __init__(self): def __init__(self):
self.img_rows = 28 self.img_rows = 28
self.img_cols = 28 self.img_cols = 28
self.channels = 1 self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels) self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
# Following parameter and optimizer set as recommended in paper # Following parameter and optimizer set as recommended in paper
self.n_critic = 5 self.n_critic = 5
optimizer = RMSprop(lr=0.00005) optimizer = RMSprop(lr=0.00005)
# Build the generator and discriminator # Build the generator and critic
self.generator = self.build_generator() self.generator = self.build_generator()
self.discriminator = self.build_discriminator() self.critic = self.build_critic()
#------------------------------- #-------------------------------
# Construct Computational Graph # Construct Computational Graph
# for Discriminator # for the Critic
#------------------------------- #-------------------------------
# Freeze generator's layers while training discriminator # Freeze generator's layers while training critic
self.generator.trainable = False self.generator.trainable = False
# Image input (real sample) # Image input (real sample)
@ -61,23 +62,23 @@ class ImprovedWGAN():
fake_img = self.generator(z_disc) fake_img = self.generator(z_disc)
# Discriminator determines validity of the real and fake images # Discriminator determines validity of the real and fake images
fake = self.discriminator(fake_img) fake = self.critic(fake_img)
real = self.discriminator(real_img) valid = self.critic(real_img)
# Construct weighted average between real and fake images # Construct weighted average between real and fake images
merged_img = RandomWeightedAverage()([real_img, fake_img]) interpolated_img = RandomWeightedAverage()([real_img, fake_img])
# Determine validity of weighted sample # Determine validity of weighted sample
valid_merged = self.discriminator(merged_img) validity_interpolated = self.critic(interpolated_img)
# Use Python partial to provide loss function with additional # Use Python partial to provide loss function with additional
# 'averaged_samples' argument # 'averaged_samples' argument
partial_gp_loss = partial(self.gradient_penalty_loss, partial_gp_loss = partial(self.gradient_penalty_loss,
averaged_samples=merged_img) averaged_samples=interpolated_img)
partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names
self.discriminator_model = Model(inputs=[real_img, z_disc], self.critic_model = Model(inputs=[real_img, z_disc],
outputs=[real, fake, valid_merged]) outputs=[valid, fake, validity_interpolated])
self.discriminator_model.compile(loss=[self.wasserstein_loss, self.critic_model.compile(loss=[self.wasserstein_loss,
self.wasserstein_loss, self.wasserstein_loss,
partial_gp_loss], partial_gp_loss],
optimizer=optimizer, optimizer=optimizer,
@ -87,8 +88,8 @@ class ImprovedWGAN():
# for Generator # for Generator
#------------------------------- #-------------------------------
# For the generator we freeze the discriminator's layers # For the generator we freeze the critic's layers
self.discriminator.trainable = False self.critic.trainable = False
self.generator.trainable = True self.generator.trainable = True
# Sampled noise for input to generator # Sampled noise for input to generator
@ -96,7 +97,7 @@ class ImprovedWGAN():
# Generate images based of noise # Generate images based of noise
img = self.generator(z_gen) img = self.generator(z_gen)
# Discriminator determines validity # Discriminator determines validity
valid = self.discriminator(img) valid = self.critic(img)
# Defines generator model # Defines generator model
self.generator_model = Model(z_gen, valid) self.generator_model = Model(z_gen, valid)
self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer) self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)
@ -125,21 +126,18 @@ class ImprovedWGAN():
def build_generator(self): def build_generator(self):
noise_shape = (100,)
model = Sequential() model = Sequential()
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=noise_shape)) model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
model.add(Reshape((7, 7, 128))) model.add(Reshape((7, 7, 128)))
model.add(BatchNormalization(momentum=0.8))
model.add(UpSampling2D()) model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=4, padding="same")) model.add(Conv2D(128, kernel_size=4, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8)) model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(UpSampling2D()) model.add(UpSampling2D())
model.add(Conv2D(64, kernel_size=4, padding="same")) model.add(Conv2D(64, kernel_size=4, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8)) model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=4, padding="same")) model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
model.add(Activation("tanh")) model.add(Activation("tanh"))
@ -150,37 +148,35 @@ class ImprovedWGAN():
return Model(noise, img) return Model(noise, img)
def build_discriminator(self): def build_critic(self):
img_shape = (self.img_rows, self.img_cols, self.channels)
model = Sequential() model = Sequential()
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same")) model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
model.add(LeakyReLU(alpha=0.2)) model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25)) model.add(Dropout(0.25))
model.add(Conv2D(32, kernel_size=3, strides=2, padding="same")) model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
model.add(ZeroPadding2D(padding=((0,1),(0,1)))) model.add(ZeroPadding2D(padding=((0,1),(0,1))))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2)) model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25)) model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8)) model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
model.add(LeakyReLU(alpha=0.2)) model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25)) model.add(Dropout(0.25))
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten()) model.add(Flatten())
model.add(Dense(1))
model.summary() model.summary()
img = Input(shape=img_shape) img = Input(shape=img_shape)
features = model(img) validity = model(img)
valid = Dense(1, activation="linear")(features)
return Model(img, valid) return Model(img, validity)
def train(self, epochs, batch_size, sample_interval=50): def train(self, epochs, batch_size, sample_interval=50):
@ -207,9 +203,9 @@ class ImprovedWGAN():
idx = np.random.randint(0, X_train.shape[0], batch_size) idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx] imgs = X_train[idx]
# Sample generator input # Sample generator input
noise = np.random.normal(0, 1, (batch_size, 100)) noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Train the discriminator # Train the critic
d_loss = self.discriminator_model.train_on_batch([imgs, noise], d_loss = self.critic_model.train_on_batch([imgs, noise],
[valid, fake, dummy]) [valid, fake, dummy])
# --------------------- # ---------------------
@ -217,12 +213,12 @@ class ImprovedWGAN():
# --------------------- # ---------------------
# Sample generator input # Sample generator input
noise = np.random.normal(0, 1, (batch_size, 100)) noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Train the generator # Train the generator
g_loss = self.generator_model.train_on_batch(noise, valid) g_loss = self.generator_model.train_on_batch(noise, valid)
# Plot the progress # Plot the progress
print ("%d [D loss: %f] [G loss: %f]" % (epoch, 1 - d_loss[0], 1 - g_loss)) print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))
# If at save interval => save generated image samples # If at save interval => save generated image samples
if epoch % sample_interval == 0: if epoch % sample_interval == 0:
@ -230,7 +226,7 @@ class ImprovedWGAN():
def sample_images(self, epoch): def sample_images(self, epoch):
r, c = 5, 5 r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100)) noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise) gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1 # Rescale images 0 - 1
@ -248,5 +244,5 @@ class ImprovedWGAN():
if __name__ == '__main__': if __name__ == '__main__':
wgan = ImprovedWGAN() wgan = WGANGP()
wgan.train(epochs=30000, batch_size=32, sample_interval=100) wgan.train(epochs=30000, batch_size=32, sample_interval=100)