diff --git a/README.md b/README.md index 094e9b4..d516410 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,6 @@ Collection of Keras implementations of Generative Adversarial Networks (GANs) su + [DiscoGAN](#discogan) + [DualGAN](#dualgan) + [Generative Adversarial Network](#gan) - + [Improved Wasserstein GAN](#improved-wgan) + [InfoGAN](#infogan) + [LSGAN](#lsgan) + [Pix2Pix](#pix2pix) @@ -31,6 +30,7 @@ Collection of Keras implementations of Generative Adversarial Networks (GANs) su + [Semi-Supervised GAN](#sgan) + [Super-Resolution GAN](#srgan) + [Wasserstein GAN](#wgan) + + [Wasserstein GAN GP](#wgan-gp) ## Installation $ git clone https://github.com/eriklindernoren/Keras-GAN @@ -268,23 +268,6 @@ $ python3 gan_rgb.py
-### Improved WGAN -Implementation of _Improved Training of Wasserstein GANs_. - -[Code](improved_wgan/improved_wgan.py) - -Paper: https://arxiv.org/abs/1704.00028 - -#### Example -``` -$ cd improved_wgan/ -$ python3 improved_wgan.py -``` - -- -
- ### InfoGAN Implementation of _InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets_. @@ -413,3 +396,20 @@ $ python3 wgan.py+ +### WGAN GP +Implementation of _Improved Training of Wasserstein GANs_. + +[Code](wgan_gp/wgan_gp.py) + +Paper: https://arxiv.org/abs/1704.00028 + +#### Example +``` +$ cd wgan_gp/ +$ python3 wgan_gp.py +``` + +
+ +
diff --git a/dcgan/dcgan.py b/dcgan/dcgan.py index 073b518..4a4c91d 100644 --- a/dcgan/dcgan.py +++ b/dcgan/dcgan.py @@ -54,15 +54,14 @@ class DCGAN(): model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,))) model.add(Reshape((7, 7, 128))) - model.add(BatchNormalization(momentum=0.8)) model.add(UpSampling2D()) model.add(Conv2D(128, kernel_size=3, padding="same")) - model.add(Activation("relu")) model.add(BatchNormalization(momentum=0.8)) + model.add(Activation("relu")) model.add(UpSampling2D()) model.add(Conv2D(64, kernel_size=3, padding="same")) - model.add(Activation("relu")) model.add(BatchNormalization(momentum=0.8)) + model.add(Activation("relu")) model.add(Conv2D(self.channels, kernel_size=3, padding="same")) model.add(Activation("tanh")) @@ -82,17 +81,17 @@ class DCGAN(): model.add(Dropout(0.25)) model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) model.add(ZeroPadding2D(padding=((0,1),(0,1)))) + model.add(BatchNormalization(momentum=0.8)) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) - model.add(BatchNormalization(momentum=0.8)) model.add(Conv2D(128, kernel_size=3, strides=2, padding="same")) - model.add(LeakyReLU(alpha=0.2)) - model.add(Dropout(0.25)) model.add(BatchNormalization(momentum=0.8)) - model.add(Conv2D(256, kernel_size=3, strides=1, padding="same")) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) - + model.add(Conv2D(256, kernel_size=3, strides=1, padding="same")) + model.add(BatchNormalization(momentum=0.8)) + model.add(LeakyReLU(alpha=0.2)) + model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(1, activation='sigmoid')) @@ -109,7 +108,7 @@ class DCGAN(): (X_train, _), (_, _) = mnist.load_data() # Rescale -1 to 1 - X_train = (X_train.astype(np.float32) - 127.5) / 127.5 + X_train = X_train.astype(np.float32) / 127.5 - 1. X_train = np.expand_dims(X_train, axis=3) half_batch = int(batch_size / 2) @@ -125,7 +124,7 @@ class DCGAN(): imgs = X_train[idx] # Sample noise and generate a half batch of new images - noise = np.random.normal(0, 1, (half_batch, 100)) + noise = np.random.normal(0, 1, (half_batch, self.latent_dim)) gen_imgs = self.generator.predict(noise) # Train the discriminator (real classified as ones and generated as zeros) @@ -138,7 +137,7 @@ class DCGAN(): # --------------------- # Sample generator input - noise = np.random.normal(0, 1, (batch_size, 100)) + noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) # Train the generator (wants discriminator to mistake images as real) g_loss = self.combined.train_on_batch(noise, np.ones((batch_size, 1))) @@ -152,7 +151,7 @@ class DCGAN(): def save_imgs(self, epoch): r, c = 5, 5 - noise = np.random.normal(0, 1, (r * c, 100)) + noise = np.random.normal(0, 1, (r * c, self.latent_dim)) gen_imgs = self.generator.predict(noise) # Rescale images 0 - 1 diff --git a/wgan/wgan.py b/wgan/wgan.py index 252fd37..6c42284 100644 --- a/wgan/wgan.py +++ b/wgan/wgan.py @@ -21,15 +21,17 @@ class WGAN(): self.img_rows = 28 self.img_cols = 28 self.channels = 1 + self.img_shape = (self.img_rows, self.img_cols, self.channels) + self.latent_dim = 100 # Following parameter and optimizer set as recommended in paper self.n_critic = 5 self.clip_value = 0.01 optimizer = RMSprop(lr=0.00005) - # Build and compile the discriminator - self.discriminator = self.build_discriminator() - self.discriminator.compile(loss=self.wasserstein_loss, + # Build and compile the critic + self.critic = self.build_critic() + self.critic.compile(loss=self.wasserstein_loss, optimizer=optimizer, metrics=['accuracy']) @@ -41,13 +43,12 @@ class WGAN(): img = self.generator(z) # For the combined model we will only train the generator - self.discriminator.trainable = False + self.critic.trainable = False - # The discriminator takes generated images as input and determines validity - valid = self.discriminator(img) + # The critic takes generated images as input and determines validity + valid = self.critic(img) - # The combined model (stacked generator and discriminator) takes - # noise as input => generates images => determines validity + # The combined model (stacked generator and critic) self.combined = Model(z, valid) self.combined.compile(loss=self.wasserstein_loss, optimizer=optimizer, @@ -58,21 +59,18 @@ class WGAN(): def build_generator(self): - noise_shape = (100,) - model = Sequential() - model.add(Dense(128 * 7 * 7, activation="relu", input_shape=noise_shape)) + model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,))) model.add(Reshape((7, 7, 128))) - model.add(BatchNormalization(momentum=0.8)) model.add(UpSampling2D()) model.add(Conv2D(128, kernel_size=4, padding="same")) - model.add(Activation("relu")) model.add(BatchNormalization(momentum=0.8)) + model.add(Activation("relu")) model.add(UpSampling2D()) model.add(Conv2D(64, kernel_size=4, padding="same")) - model.add(Activation("relu")) model.add(BatchNormalization(momentum=0.8)) + model.add(Activation("relu")) model.add(Conv2D(self.channels, kernel_size=4, padding="same")) model.add(Activation("tanh")) @@ -83,37 +81,35 @@ class WGAN(): return Model(noise, img) - def build_discriminator(self): - - img_shape = (self.img_rows, self.img_cols, self.channels) + def build_critic(self): model = Sequential() - model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same")) + model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same")) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(32, kernel_size=3, strides=2, padding="same")) model.add(ZeroPadding2D(padding=((0,1),(0,1)))) + model.add(BatchNormalization(momentum=0.8)) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) - model.add(BatchNormalization(momentum=0.8)) model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) - model.add(LeakyReLU(alpha=0.2)) - model.add(Dropout(0.25)) model.add(BatchNormalization(momentum=0.8)) - model.add(Conv2D(128, kernel_size=3, strides=1, padding="same")) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) - + model.add(Conv2D(128, kernel_size=3, strides=1, padding="same")) + model.add(BatchNormalization(momentum=0.8)) + model.add(LeakyReLU(alpha=0.2)) + model.add(Dropout(0.25)) model.add(Flatten()) + model.add(Dense(1)) model.summary() img = Input(shape=img_shape) - features = model(img) - valid = Dense(1, activation="linear")(features) + validity = model(img) - return Model(img, valid) + return Model(img, validity) def train(self, epochs, batch_size=128, sample_interval=50): @@ -138,18 +134,18 @@ class WGAN(): idx = np.random.randint(0, X_train.shape[0], half_batch) imgs = X_train[idx] - noise = np.random.normal(0, 1, (half_batch, 100)) + noise = np.random.normal(0, 1, (half_batch, self.latent_dim)) # Generate a half batch of new images gen_imgs = self.generator.predict(noise) - # Train the discriminator - d_loss_real = self.discriminator.train_on_batch(imgs, -np.ones((half_batch, 1))) - d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.ones((half_batch, 1))) + # Train the critic + d_loss_real = self.critic.train_on_batch(imgs, -np.ones((half_batch, 1))) + d_loss_fake = self.critic.train_on_batch(gen_imgs, np.ones((half_batch, 1))) d_loss = 0.5 * np.add(d_loss_fake, d_loss_real) - # Clip discriminator weights - for l in self.discriminator.layers: + # Clip critic weights + for l in self.critic.layers: weights = l.get_weights() weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights] l.set_weights(weights) @@ -159,7 +155,7 @@ class WGAN(): # Train Generator # --------------------- - noise = np.random.normal(0, 1, (batch_size, 100)) + noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) # Train the generator g_loss = self.combined.train_on_batch(noise, -np.ones((batch_size, 1))) @@ -173,7 +169,7 @@ class WGAN(): def sample_images(self, epoch): r, c = 5, 5 - noise = np.random.normal(0, 1, (r * c, 100)) + noise = np.random.normal(0, 1, (r * c, self.latent_dim)) gen_imgs = self.generator.predict(noise) # Rescale images 0 - 1 diff --git a/improved_wgan/images/.gitignore b/wgan_gp/images/.gitignore similarity index 100% rename from improved_wgan/images/.gitignore rename to wgan_gp/images/.gitignore diff --git a/improved_wgan/saved_model/.gitignore b/wgan_gp/saved_model/.gitignore similarity index 100% rename from improved_wgan/saved_model/.gitignore rename to wgan_gp/saved_model/.gitignore diff --git a/improved_wgan/improved_wgan.py b/wgan_gp/wgan_gp.py similarity index 81% rename from improved_wgan/improved_wgan.py rename to wgan_gp/wgan_gp.py index 4f01a27..1003f6f 100644 --- a/improved_wgan/improved_wgan.py +++ b/wgan_gp/wgan_gp.py @@ -26,30 +26,31 @@ import numpy as np class RandomWeightedAverage(_Merge): """Provides a (random) weighted average between real and generated image samples""" def _merge_function(self, inputs): - weights = K.random_uniform((32, 1, 1, 1)) - return (weights * inputs[0]) + ((1 - weights) * inputs[1]) + alpha = K.random_uniform((32, 1, 1, 1)) + return (alpha * inputs[0]) + ((1 - alpha) * inputs[1]) -class ImprovedWGAN(): +class WGANGP(): def __init__(self): self.img_rows = 28 self.img_cols = 28 self.channels = 1 self.img_shape = (self.img_rows, self.img_cols, self.channels) + self.latent_dim = 100 # Following parameter and optimizer set as recommended in paper self.n_critic = 5 optimizer = RMSprop(lr=0.00005) - # Build the generator and discriminator + # Build the generator and critic self.generator = self.build_generator() - self.discriminator = self.build_discriminator() + self.critic = self.build_critic() #------------------------------- # Construct Computational Graph - # for Discriminator + # for the Critic #------------------------------- - # Freeze generator's layers while training discriminator + # Freeze generator's layers while training critic self.generator.trainable = False # Image input (real sample) @@ -61,23 +62,23 @@ class ImprovedWGAN(): fake_img = self.generator(z_disc) # Discriminator determines validity of the real and fake images - fake = self.discriminator(fake_img) - real = self.discriminator(real_img) + fake = self.critic(fake_img) + valid = self.critic(real_img) # Construct weighted average between real and fake images - merged_img = RandomWeightedAverage()([real_img, fake_img]) + interpolated_img = RandomWeightedAverage()([real_img, fake_img]) # Determine validity of weighted sample - valid_merged = self.discriminator(merged_img) + validity_interpolated = self.critic(interpolated_img) # Use Python partial to provide loss function with additional # 'averaged_samples' argument partial_gp_loss = partial(self.gradient_penalty_loss, - averaged_samples=merged_img) + averaged_samples=interpolated_img) partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names - self.discriminator_model = Model(inputs=[real_img, z_disc], - outputs=[real, fake, valid_merged]) - self.discriminator_model.compile(loss=[self.wasserstein_loss, + self.critic_model = Model(inputs=[real_img, z_disc], + outputs=[valid, fake, validity_interpolated]) + self.critic_model.compile(loss=[self.wasserstein_loss, self.wasserstein_loss, partial_gp_loss], optimizer=optimizer, @@ -87,8 +88,8 @@ class ImprovedWGAN(): # for Generator #------------------------------- - # For the generator we freeze the discriminator's layers - self.discriminator.trainable = False + # For the generator we freeze the critic's layers + self.critic.trainable = False self.generator.trainable = True # Sampled noise for input to generator @@ -96,7 +97,7 @@ class ImprovedWGAN(): # Generate images based of noise img = self.generator(z_gen) # Discriminator determines validity - valid = self.discriminator(img) + valid = self.critic(img) # Defines generator model self.generator_model = Model(z_gen, valid) self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer) @@ -125,21 +126,18 @@ class ImprovedWGAN(): def build_generator(self): - noise_shape = (100,) - model = Sequential() - model.add(Dense(128 * 7 * 7, activation="relu", input_shape=noise_shape)) + model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,))) model.add(Reshape((7, 7, 128))) - model.add(BatchNormalization(momentum=0.8)) model.add(UpSampling2D()) model.add(Conv2D(128, kernel_size=4, padding="same")) - model.add(Activation("relu")) model.add(BatchNormalization(momentum=0.8)) + model.add(Activation("relu")) model.add(UpSampling2D()) model.add(Conv2D(64, kernel_size=4, padding="same")) - model.add(Activation("relu")) model.add(BatchNormalization(momentum=0.8)) + model.add(Activation("relu")) model.add(Conv2D(self.channels, kernel_size=4, padding="same")) model.add(Activation("tanh")) @@ -150,37 +148,35 @@ class ImprovedWGAN(): return Model(noise, img) - def build_discriminator(self): - - img_shape = (self.img_rows, self.img_cols, self.channels) + def build_critic(self): model = Sequential() - model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same")) + model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same")) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(32, kernel_size=3, strides=2, padding="same")) model.add(ZeroPadding2D(padding=((0,1),(0,1)))) + model.add(BatchNormalization(momentum=0.8)) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) - model.add(BatchNormalization(momentum=0.8)) model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) - model.add(LeakyReLU(alpha=0.2)) - model.add(Dropout(0.25)) model.add(BatchNormalization(momentum=0.8)) - model.add(Conv2D(128, kernel_size=3, strides=1, padding="same")) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) - + model.add(Conv2D(128, kernel_size=3, strides=1, padding="same")) + model.add(BatchNormalization(momentum=0.8)) + model.add(LeakyReLU(alpha=0.2)) + model.add(Dropout(0.25)) model.add(Flatten()) + model.add(Dense(1)) model.summary() img = Input(shape=img_shape) - features = model(img) - valid = Dense(1, activation="linear")(features) + validity = model(img) - return Model(img, valid) + return Model(img, validity) def train(self, epochs, batch_size, sample_interval=50): @@ -207,9 +203,9 @@ class ImprovedWGAN(): idx = np.random.randint(0, X_train.shape[0], batch_size) imgs = X_train[idx] # Sample generator input - noise = np.random.normal(0, 1, (batch_size, 100)) - # Train the discriminator - d_loss = self.discriminator_model.train_on_batch([imgs, noise], + noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) + # Train the critic + d_loss = self.critic_model.train_on_batch([imgs, noise], [valid, fake, dummy]) # --------------------- @@ -217,12 +213,12 @@ class ImprovedWGAN(): # --------------------- # Sample generator input - noise = np.random.normal(0, 1, (batch_size, 100)) + noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) # Train the generator g_loss = self.generator_model.train_on_batch(noise, valid) # Plot the progress - print ("%d [D loss: %f] [G loss: %f]" % (epoch, 1 - d_loss[0], 1 - g_loss)) + print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss)) # If at save interval => save generated image samples if epoch % sample_interval == 0: @@ -230,7 +226,7 @@ class ImprovedWGAN(): def sample_images(self, epoch): r, c = 5, 5 - noise = np.random.normal(0, 1, (r * c, 100)) + noise = np.random.normal(0, 1, (r * c, self.latent_dim)) gen_imgs = self.generator.predict(noise) # Rescale images 0 - 1 @@ -248,5 +244,5 @@ class ImprovedWGAN(): if __name__ == '__main__': - wgan = ImprovedWGAN() + wgan = WGANGP() wgan.train(epochs=30000, batch_size=32, sample_interval=100)