Wasserstein GAN: 'discriminator' => 'critic'
This commit is contained in:
Родитель
ce7dc229be
Коммит
0293a4e6ff
36
README.md
36
README.md
|
@ -23,7 +23,6 @@ Collection of Keras implementations of Generative Adversarial Networks (GANs) su
|
||||||
+ [DiscoGAN](#discogan)
|
+ [DiscoGAN](#discogan)
|
||||||
+ [DualGAN](#dualgan)
|
+ [DualGAN](#dualgan)
|
||||||
+ [Generative Adversarial Network](#gan)
|
+ [Generative Adversarial Network](#gan)
|
||||||
+ [Improved Wasserstein GAN](#improved-wgan)
|
|
||||||
+ [InfoGAN](#infogan)
|
+ [InfoGAN](#infogan)
|
||||||
+ [LSGAN](#lsgan)
|
+ [LSGAN](#lsgan)
|
||||||
+ [Pix2Pix](#pix2pix)
|
+ [Pix2Pix](#pix2pix)
|
||||||
|
@ -31,6 +30,7 @@ Collection of Keras implementations of Generative Adversarial Networks (GANs) su
|
||||||
+ [Semi-Supervised GAN](#sgan)
|
+ [Semi-Supervised GAN](#sgan)
|
||||||
+ [Super-Resolution GAN](#srgan)
|
+ [Super-Resolution GAN](#srgan)
|
||||||
+ [Wasserstein GAN](#wgan)
|
+ [Wasserstein GAN](#wgan)
|
||||||
|
+ [Wasserstein GAN GP](#wgan-gp)
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
$ git clone https://github.com/eriklindernoren/Keras-GAN
|
$ git clone https://github.com/eriklindernoren/Keras-GAN
|
||||||
|
@ -268,23 +268,6 @@ $ python3 gan_rgb.py
|
||||||
<img src="gan/etc/adam.gif" width="640"\>
|
<img src="gan/etc/adam.gif" width="640"\>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
### Improved WGAN
|
|
||||||
Implementation of _Improved Training of Wasserstein GANs_.
|
|
||||||
|
|
||||||
[Code](improved_wgan/improved_wgan.py)
|
|
||||||
|
|
||||||
Paper: https://arxiv.org/abs/1704.00028
|
|
||||||
|
|
||||||
#### Example
|
|
||||||
```
|
|
||||||
$ cd improved_wgan/
|
|
||||||
$ python3 improved_wgan.py
|
|
||||||
```
|
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<img src="http://eriklindernoren.se/images/imp_wgan.gif" width="640"\>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
### InfoGAN
|
### InfoGAN
|
||||||
Implementation of _InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets_.
|
Implementation of _InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets_.
|
||||||
|
|
||||||
|
@ -413,3 +396,20 @@ $ python3 wgan.py
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="http://eriklindernoren.se/images/wgan2.png" width="640"\>
|
<img src="http://eriklindernoren.se/images/wgan2.png" width="640"\>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
### WGAN GP
|
||||||
|
Implementation of _Improved Training of Wasserstein GANs_.
|
||||||
|
|
||||||
|
[Code](wgan_gp/wgan_gp.py)
|
||||||
|
|
||||||
|
Paper: https://arxiv.org/abs/1704.00028
|
||||||
|
|
||||||
|
#### Example
|
||||||
|
```
|
||||||
|
$ cd wgan_gp/
|
||||||
|
$ python3 wgan_gp.py
|
||||||
|
```
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="http://eriklindernoren.se/images/imp_wgan.gif" width="640"\>
|
||||||
|
</p>
|
||||||
|
|
|
@ -54,15 +54,14 @@ class DCGAN():
|
||||||
|
|
||||||
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
|
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
|
||||||
model.add(Reshape((7, 7, 128)))
|
model.add(Reshape((7, 7, 128)))
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
|
||||||
model.add(UpSampling2D())
|
model.add(UpSampling2D())
|
||||||
model.add(Conv2D(128, kernel_size=3, padding="same"))
|
model.add(Conv2D(128, kernel_size=3, padding="same"))
|
||||||
model.add(Activation("relu"))
|
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
|
model.add(Activation("relu"))
|
||||||
model.add(UpSampling2D())
|
model.add(UpSampling2D())
|
||||||
model.add(Conv2D(64, kernel_size=3, padding="same"))
|
model.add(Conv2D(64, kernel_size=3, padding="same"))
|
||||||
model.add(Activation("relu"))
|
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
|
model.add(Activation("relu"))
|
||||||
model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
|
model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
|
||||||
model.add(Activation("tanh"))
|
model.add(Activation("tanh"))
|
||||||
|
|
||||||
|
@ -82,17 +81,17 @@ class DCGAN():
|
||||||
model.add(Dropout(0.25))
|
model.add(Dropout(0.25))
|
||||||
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
|
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
|
||||||
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
|
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
|
||||||
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
model.add(LeakyReLU(alpha=0.2))
|
model.add(LeakyReLU(alpha=0.2))
|
||||||
model.add(Dropout(0.25))
|
model.add(Dropout(0.25))
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
|
||||||
model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
|
model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
|
||||||
model.add(LeakyReLU(alpha=0.2))
|
|
||||||
model.add(Dropout(0.25))
|
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
|
|
||||||
model.add(LeakyReLU(alpha=0.2))
|
model.add(LeakyReLU(alpha=0.2))
|
||||||
model.add(Dropout(0.25))
|
model.add(Dropout(0.25))
|
||||||
|
model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
|
||||||
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
|
model.add(LeakyReLU(alpha=0.2))
|
||||||
|
model.add(Dropout(0.25))
|
||||||
model.add(Flatten())
|
model.add(Flatten())
|
||||||
model.add(Dense(1, activation='sigmoid'))
|
model.add(Dense(1, activation='sigmoid'))
|
||||||
|
|
||||||
|
@ -109,7 +108,7 @@ class DCGAN():
|
||||||
(X_train, _), (_, _) = mnist.load_data()
|
(X_train, _), (_, _) = mnist.load_data()
|
||||||
|
|
||||||
# Rescale -1 to 1
|
# Rescale -1 to 1
|
||||||
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
|
X_train = X_train.astype(np.float32) / 127.5 - 1.
|
||||||
X_train = np.expand_dims(X_train, axis=3)
|
X_train = np.expand_dims(X_train, axis=3)
|
||||||
|
|
||||||
half_batch = int(batch_size / 2)
|
half_batch = int(batch_size / 2)
|
||||||
|
@ -125,7 +124,7 @@ class DCGAN():
|
||||||
imgs = X_train[idx]
|
imgs = X_train[idx]
|
||||||
|
|
||||||
# Sample noise and generate a half batch of new images
|
# Sample noise and generate a half batch of new images
|
||||||
noise = np.random.normal(0, 1, (half_batch, 100))
|
noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
|
||||||
gen_imgs = self.generator.predict(noise)
|
gen_imgs = self.generator.predict(noise)
|
||||||
|
|
||||||
# Train the discriminator (real classified as ones and generated as zeros)
|
# Train the discriminator (real classified as ones and generated as zeros)
|
||||||
|
@ -138,7 +137,7 @@ class DCGAN():
|
||||||
# ---------------------
|
# ---------------------
|
||||||
|
|
||||||
# Sample generator input
|
# Sample generator input
|
||||||
noise = np.random.normal(0, 1, (batch_size, 100))
|
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
|
||||||
|
|
||||||
# Train the generator (wants discriminator to mistake images as real)
|
# Train the generator (wants discriminator to mistake images as real)
|
||||||
g_loss = self.combined.train_on_batch(noise, np.ones((batch_size, 1)))
|
g_loss = self.combined.train_on_batch(noise, np.ones((batch_size, 1)))
|
||||||
|
@ -152,7 +151,7 @@ class DCGAN():
|
||||||
|
|
||||||
def save_imgs(self, epoch):
|
def save_imgs(self, epoch):
|
||||||
r, c = 5, 5
|
r, c = 5, 5
|
||||||
noise = np.random.normal(0, 1, (r * c, 100))
|
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
|
||||||
gen_imgs = self.generator.predict(noise)
|
gen_imgs = self.generator.predict(noise)
|
||||||
|
|
||||||
# Rescale images 0 - 1
|
# Rescale images 0 - 1
|
||||||
|
|
64
wgan/wgan.py
64
wgan/wgan.py
|
@ -21,15 +21,17 @@ class WGAN():
|
||||||
self.img_rows = 28
|
self.img_rows = 28
|
||||||
self.img_cols = 28
|
self.img_cols = 28
|
||||||
self.channels = 1
|
self.channels = 1
|
||||||
|
self.img_shape = (self.img_rows, self.img_cols, self.channels)
|
||||||
|
self.latent_dim = 100
|
||||||
|
|
||||||
# Following parameter and optimizer set as recommended in paper
|
# Following parameter and optimizer set as recommended in paper
|
||||||
self.n_critic = 5
|
self.n_critic = 5
|
||||||
self.clip_value = 0.01
|
self.clip_value = 0.01
|
||||||
optimizer = RMSprop(lr=0.00005)
|
optimizer = RMSprop(lr=0.00005)
|
||||||
|
|
||||||
# Build and compile the discriminator
|
# Build and compile the critic
|
||||||
self.discriminator = self.build_discriminator()
|
self.critic = self.build_critic()
|
||||||
self.discriminator.compile(loss=self.wasserstein_loss,
|
self.critic.compile(loss=self.wasserstein_loss,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
metrics=['accuracy'])
|
metrics=['accuracy'])
|
||||||
|
|
||||||
|
@ -41,13 +43,12 @@ class WGAN():
|
||||||
img = self.generator(z)
|
img = self.generator(z)
|
||||||
|
|
||||||
# For the combined model we will only train the generator
|
# For the combined model we will only train the generator
|
||||||
self.discriminator.trainable = False
|
self.critic.trainable = False
|
||||||
|
|
||||||
# The discriminator takes generated images as input and determines validity
|
# The critic takes generated images as input and determines validity
|
||||||
valid = self.discriminator(img)
|
valid = self.critic(img)
|
||||||
|
|
||||||
# The combined model (stacked generator and discriminator) takes
|
# The combined model (stacked generator and critic)
|
||||||
# noise as input => generates images => determines validity
|
|
||||||
self.combined = Model(z, valid)
|
self.combined = Model(z, valid)
|
||||||
self.combined.compile(loss=self.wasserstein_loss,
|
self.combined.compile(loss=self.wasserstein_loss,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
@ -58,21 +59,18 @@ class WGAN():
|
||||||
|
|
||||||
def build_generator(self):
|
def build_generator(self):
|
||||||
|
|
||||||
noise_shape = (100,)
|
|
||||||
|
|
||||||
model = Sequential()
|
model = Sequential()
|
||||||
|
|
||||||
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=noise_shape))
|
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
|
||||||
model.add(Reshape((7, 7, 128)))
|
model.add(Reshape((7, 7, 128)))
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
|
||||||
model.add(UpSampling2D())
|
model.add(UpSampling2D())
|
||||||
model.add(Conv2D(128, kernel_size=4, padding="same"))
|
model.add(Conv2D(128, kernel_size=4, padding="same"))
|
||||||
model.add(Activation("relu"))
|
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
|
model.add(Activation("relu"))
|
||||||
model.add(UpSampling2D())
|
model.add(UpSampling2D())
|
||||||
model.add(Conv2D(64, kernel_size=4, padding="same"))
|
model.add(Conv2D(64, kernel_size=4, padding="same"))
|
||||||
model.add(Activation("relu"))
|
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
|
model.add(Activation("relu"))
|
||||||
model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
|
model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
|
||||||
model.add(Activation("tanh"))
|
model.add(Activation("tanh"))
|
||||||
|
|
||||||
|
@ -83,37 +81,35 @@ class WGAN():
|
||||||
|
|
||||||
return Model(noise, img)
|
return Model(noise, img)
|
||||||
|
|
||||||
def build_discriminator(self):
|
def build_critic(self):
|
||||||
|
|
||||||
img_shape = (self.img_rows, self.img_cols, self.channels)
|
|
||||||
|
|
||||||
model = Sequential()
|
model = Sequential()
|
||||||
|
|
||||||
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
|
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
|
||||||
model.add(LeakyReLU(alpha=0.2))
|
model.add(LeakyReLU(alpha=0.2))
|
||||||
model.add(Dropout(0.25))
|
model.add(Dropout(0.25))
|
||||||
model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
|
model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
|
||||||
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
|
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
|
||||||
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
model.add(LeakyReLU(alpha=0.2))
|
model.add(LeakyReLU(alpha=0.2))
|
||||||
model.add(Dropout(0.25))
|
model.add(Dropout(0.25))
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
|
||||||
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
|
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
|
||||||
model.add(LeakyReLU(alpha=0.2))
|
|
||||||
model.add(Dropout(0.25))
|
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
|
|
||||||
model.add(LeakyReLU(alpha=0.2))
|
model.add(LeakyReLU(alpha=0.2))
|
||||||
model.add(Dropout(0.25))
|
model.add(Dropout(0.25))
|
||||||
|
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
|
||||||
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
|
model.add(LeakyReLU(alpha=0.2))
|
||||||
|
model.add(Dropout(0.25))
|
||||||
model.add(Flatten())
|
model.add(Flatten())
|
||||||
|
model.add(Dense(1))
|
||||||
|
|
||||||
model.summary()
|
model.summary()
|
||||||
|
|
||||||
img = Input(shape=img_shape)
|
img = Input(shape=img_shape)
|
||||||
features = model(img)
|
validity = model(img)
|
||||||
valid = Dense(1, activation="linear")(features)
|
|
||||||
|
|
||||||
return Model(img, valid)
|
return Model(img, validity)
|
||||||
|
|
||||||
def train(self, epochs, batch_size=128, sample_interval=50):
|
def train(self, epochs, batch_size=128, sample_interval=50):
|
||||||
|
|
||||||
|
@ -138,18 +134,18 @@ class WGAN():
|
||||||
idx = np.random.randint(0, X_train.shape[0], half_batch)
|
idx = np.random.randint(0, X_train.shape[0], half_batch)
|
||||||
imgs = X_train[idx]
|
imgs = X_train[idx]
|
||||||
|
|
||||||
noise = np.random.normal(0, 1, (half_batch, 100))
|
noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
|
||||||
|
|
||||||
# Generate a half batch of new images
|
# Generate a half batch of new images
|
||||||
gen_imgs = self.generator.predict(noise)
|
gen_imgs = self.generator.predict(noise)
|
||||||
|
|
||||||
# Train the discriminator
|
# Train the critic
|
||||||
d_loss_real = self.discriminator.train_on_batch(imgs, -np.ones((half_batch, 1)))
|
d_loss_real = self.critic.train_on_batch(imgs, -np.ones((half_batch, 1)))
|
||||||
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.ones((half_batch, 1)))
|
d_loss_fake = self.critic.train_on_batch(gen_imgs, np.ones((half_batch, 1)))
|
||||||
d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
|
d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
|
||||||
|
|
||||||
# Clip discriminator weights
|
# Clip critic weights
|
||||||
for l in self.discriminator.layers:
|
for l in self.critic.layers:
|
||||||
weights = l.get_weights()
|
weights = l.get_weights()
|
||||||
weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
|
weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
|
||||||
l.set_weights(weights)
|
l.set_weights(weights)
|
||||||
|
@ -159,7 +155,7 @@ class WGAN():
|
||||||
# Train Generator
|
# Train Generator
|
||||||
# ---------------------
|
# ---------------------
|
||||||
|
|
||||||
noise = np.random.normal(0, 1, (batch_size, 100))
|
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
|
||||||
|
|
||||||
# Train the generator
|
# Train the generator
|
||||||
g_loss = self.combined.train_on_batch(noise, -np.ones((batch_size, 1)))
|
g_loss = self.combined.train_on_batch(noise, -np.ones((batch_size, 1)))
|
||||||
|
@ -173,7 +169,7 @@ class WGAN():
|
||||||
|
|
||||||
def sample_images(self, epoch):
|
def sample_images(self, epoch):
|
||||||
r, c = 5, 5
|
r, c = 5, 5
|
||||||
noise = np.random.normal(0, 1, (r * c, 100))
|
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
|
||||||
gen_imgs = self.generator.predict(noise)
|
gen_imgs = self.generator.predict(noise)
|
||||||
|
|
||||||
# Rescale images 0 - 1
|
# Rescale images 0 - 1
|
||||||
|
|
|
@ -26,30 +26,31 @@ import numpy as np
|
||||||
class RandomWeightedAverage(_Merge):
|
class RandomWeightedAverage(_Merge):
|
||||||
"""Provides a (random) weighted average between real and generated image samples"""
|
"""Provides a (random) weighted average between real and generated image samples"""
|
||||||
def _merge_function(self, inputs):
|
def _merge_function(self, inputs):
|
||||||
weights = K.random_uniform((32, 1, 1, 1))
|
alpha = K.random_uniform((32, 1, 1, 1))
|
||||||
return (weights * inputs[0]) + ((1 - weights) * inputs[1])
|
return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
|
||||||
|
|
||||||
class ImprovedWGAN():
|
class WGANGP():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.img_rows = 28
|
self.img_rows = 28
|
||||||
self.img_cols = 28
|
self.img_cols = 28
|
||||||
self.channels = 1
|
self.channels = 1
|
||||||
self.img_shape = (self.img_rows, self.img_cols, self.channels)
|
self.img_shape = (self.img_rows, self.img_cols, self.channels)
|
||||||
|
self.latent_dim = 100
|
||||||
|
|
||||||
# Following parameter and optimizer set as recommended in paper
|
# Following parameter and optimizer set as recommended in paper
|
||||||
self.n_critic = 5
|
self.n_critic = 5
|
||||||
optimizer = RMSprop(lr=0.00005)
|
optimizer = RMSprop(lr=0.00005)
|
||||||
|
|
||||||
# Build the generator and discriminator
|
# Build the generator and critic
|
||||||
self.generator = self.build_generator()
|
self.generator = self.build_generator()
|
||||||
self.discriminator = self.build_discriminator()
|
self.critic = self.build_critic()
|
||||||
|
|
||||||
#-------------------------------
|
#-------------------------------
|
||||||
# Construct Computational Graph
|
# Construct Computational Graph
|
||||||
# for Discriminator
|
# for the Critic
|
||||||
#-------------------------------
|
#-------------------------------
|
||||||
|
|
||||||
# Freeze generator's layers while training discriminator
|
# Freeze generator's layers while training critic
|
||||||
self.generator.trainable = False
|
self.generator.trainable = False
|
||||||
|
|
||||||
# Image input (real sample)
|
# Image input (real sample)
|
||||||
|
@ -61,23 +62,23 @@ class ImprovedWGAN():
|
||||||
fake_img = self.generator(z_disc)
|
fake_img = self.generator(z_disc)
|
||||||
|
|
||||||
# Discriminator determines validity of the real and fake images
|
# Discriminator determines validity of the real and fake images
|
||||||
fake = self.discriminator(fake_img)
|
fake = self.critic(fake_img)
|
||||||
real = self.discriminator(real_img)
|
valid = self.critic(real_img)
|
||||||
|
|
||||||
# Construct weighted average between real and fake images
|
# Construct weighted average between real and fake images
|
||||||
merged_img = RandomWeightedAverage()([real_img, fake_img])
|
interpolated_img = RandomWeightedAverage()([real_img, fake_img])
|
||||||
# Determine validity of weighted sample
|
# Determine validity of weighted sample
|
||||||
valid_merged = self.discriminator(merged_img)
|
validity_interpolated = self.critic(interpolated_img)
|
||||||
|
|
||||||
# Use Python partial to provide loss function with additional
|
# Use Python partial to provide loss function with additional
|
||||||
# 'averaged_samples' argument
|
# 'averaged_samples' argument
|
||||||
partial_gp_loss = partial(self.gradient_penalty_loss,
|
partial_gp_loss = partial(self.gradient_penalty_loss,
|
||||||
averaged_samples=merged_img)
|
averaged_samples=interpolated_img)
|
||||||
partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names
|
partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names
|
||||||
|
|
||||||
self.discriminator_model = Model(inputs=[real_img, z_disc],
|
self.critic_model = Model(inputs=[real_img, z_disc],
|
||||||
outputs=[real, fake, valid_merged])
|
outputs=[valid, fake, validity_interpolated])
|
||||||
self.discriminator_model.compile(loss=[self.wasserstein_loss,
|
self.critic_model.compile(loss=[self.wasserstein_loss,
|
||||||
self.wasserstein_loss,
|
self.wasserstein_loss,
|
||||||
partial_gp_loss],
|
partial_gp_loss],
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
@ -87,8 +88,8 @@ class ImprovedWGAN():
|
||||||
# for Generator
|
# for Generator
|
||||||
#-------------------------------
|
#-------------------------------
|
||||||
|
|
||||||
# For the generator we freeze the discriminator's layers
|
# For the generator we freeze the critic's layers
|
||||||
self.discriminator.trainable = False
|
self.critic.trainable = False
|
||||||
self.generator.trainable = True
|
self.generator.trainable = True
|
||||||
|
|
||||||
# Sampled noise for input to generator
|
# Sampled noise for input to generator
|
||||||
|
@ -96,7 +97,7 @@ class ImprovedWGAN():
|
||||||
# Generate images based of noise
|
# Generate images based of noise
|
||||||
img = self.generator(z_gen)
|
img = self.generator(z_gen)
|
||||||
# Discriminator determines validity
|
# Discriminator determines validity
|
||||||
valid = self.discriminator(img)
|
valid = self.critic(img)
|
||||||
# Defines generator model
|
# Defines generator model
|
||||||
self.generator_model = Model(z_gen, valid)
|
self.generator_model = Model(z_gen, valid)
|
||||||
self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)
|
self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)
|
||||||
|
@ -125,21 +126,18 @@ class ImprovedWGAN():
|
||||||
|
|
||||||
def build_generator(self):
|
def build_generator(self):
|
||||||
|
|
||||||
noise_shape = (100,)
|
|
||||||
|
|
||||||
model = Sequential()
|
model = Sequential()
|
||||||
|
|
||||||
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=noise_shape))
|
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
|
||||||
model.add(Reshape((7, 7, 128)))
|
model.add(Reshape((7, 7, 128)))
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
|
||||||
model.add(UpSampling2D())
|
model.add(UpSampling2D())
|
||||||
model.add(Conv2D(128, kernel_size=4, padding="same"))
|
model.add(Conv2D(128, kernel_size=4, padding="same"))
|
||||||
model.add(Activation("relu"))
|
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
|
model.add(Activation("relu"))
|
||||||
model.add(UpSampling2D())
|
model.add(UpSampling2D())
|
||||||
model.add(Conv2D(64, kernel_size=4, padding="same"))
|
model.add(Conv2D(64, kernel_size=4, padding="same"))
|
||||||
model.add(Activation("relu"))
|
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
|
model.add(Activation("relu"))
|
||||||
model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
|
model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
|
||||||
model.add(Activation("tanh"))
|
model.add(Activation("tanh"))
|
||||||
|
|
||||||
|
@ -150,37 +148,35 @@ class ImprovedWGAN():
|
||||||
|
|
||||||
return Model(noise, img)
|
return Model(noise, img)
|
||||||
|
|
||||||
def build_discriminator(self):
|
def build_critic(self):
|
||||||
|
|
||||||
img_shape = (self.img_rows, self.img_cols, self.channels)
|
|
||||||
|
|
||||||
model = Sequential()
|
model = Sequential()
|
||||||
|
|
||||||
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
|
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
|
||||||
model.add(LeakyReLU(alpha=0.2))
|
model.add(LeakyReLU(alpha=0.2))
|
||||||
model.add(Dropout(0.25))
|
model.add(Dropout(0.25))
|
||||||
model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
|
model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
|
||||||
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
|
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
|
||||||
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
model.add(LeakyReLU(alpha=0.2))
|
model.add(LeakyReLU(alpha=0.2))
|
||||||
model.add(Dropout(0.25))
|
model.add(Dropout(0.25))
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
|
||||||
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
|
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
|
||||||
model.add(LeakyReLU(alpha=0.2))
|
|
||||||
model.add(Dropout(0.25))
|
|
||||||
model.add(BatchNormalization(momentum=0.8))
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
|
|
||||||
model.add(LeakyReLU(alpha=0.2))
|
model.add(LeakyReLU(alpha=0.2))
|
||||||
model.add(Dropout(0.25))
|
model.add(Dropout(0.25))
|
||||||
|
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
|
||||||
|
model.add(BatchNormalization(momentum=0.8))
|
||||||
|
model.add(LeakyReLU(alpha=0.2))
|
||||||
|
model.add(Dropout(0.25))
|
||||||
model.add(Flatten())
|
model.add(Flatten())
|
||||||
|
model.add(Dense(1))
|
||||||
|
|
||||||
model.summary()
|
model.summary()
|
||||||
|
|
||||||
img = Input(shape=img_shape)
|
img = Input(shape=img_shape)
|
||||||
features = model(img)
|
validity = model(img)
|
||||||
valid = Dense(1, activation="linear")(features)
|
|
||||||
|
|
||||||
return Model(img, valid)
|
return Model(img, validity)
|
||||||
|
|
||||||
def train(self, epochs, batch_size, sample_interval=50):
|
def train(self, epochs, batch_size, sample_interval=50):
|
||||||
|
|
||||||
|
@ -207,9 +203,9 @@ class ImprovedWGAN():
|
||||||
idx = np.random.randint(0, X_train.shape[0], batch_size)
|
idx = np.random.randint(0, X_train.shape[0], batch_size)
|
||||||
imgs = X_train[idx]
|
imgs = X_train[idx]
|
||||||
# Sample generator input
|
# Sample generator input
|
||||||
noise = np.random.normal(0, 1, (batch_size, 100))
|
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
|
||||||
# Train the discriminator
|
# Train the critic
|
||||||
d_loss = self.discriminator_model.train_on_batch([imgs, noise],
|
d_loss = self.critic_model.train_on_batch([imgs, noise],
|
||||||
[valid, fake, dummy])
|
[valid, fake, dummy])
|
||||||
|
|
||||||
# ---------------------
|
# ---------------------
|
||||||
|
@ -217,12 +213,12 @@ class ImprovedWGAN():
|
||||||
# ---------------------
|
# ---------------------
|
||||||
|
|
||||||
# Sample generator input
|
# Sample generator input
|
||||||
noise = np.random.normal(0, 1, (batch_size, 100))
|
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
|
||||||
# Train the generator
|
# Train the generator
|
||||||
g_loss = self.generator_model.train_on_batch(noise, valid)
|
g_loss = self.generator_model.train_on_batch(noise, valid)
|
||||||
|
|
||||||
# Plot the progress
|
# Plot the progress
|
||||||
print ("%d [D loss: %f] [G loss: %f]" % (epoch, 1 - d_loss[0], 1 - g_loss))
|
print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))
|
||||||
|
|
||||||
# If at save interval => save generated image samples
|
# If at save interval => save generated image samples
|
||||||
if epoch % sample_interval == 0:
|
if epoch % sample_interval == 0:
|
||||||
|
@ -230,7 +226,7 @@ class ImprovedWGAN():
|
||||||
|
|
||||||
def sample_images(self, epoch):
|
def sample_images(self, epoch):
|
||||||
r, c = 5, 5
|
r, c = 5, 5
|
||||||
noise = np.random.normal(0, 1, (r * c, 100))
|
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
|
||||||
gen_imgs = self.generator.predict(noise)
|
gen_imgs = self.generator.predict(noise)
|
||||||
|
|
||||||
# Rescale images 0 - 1
|
# Rescale images 0 - 1
|
||||||
|
@ -248,5 +244,5 @@ class ImprovedWGAN():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
wgan = ImprovedWGAN()
|
wgan = WGANGP()
|
||||||
wgan.train(epochs=30000, batch_size=32, sample_interval=100)
|
wgan.train(epochs=30000, batch_size=32, sample_interval=100)
|
Загрузка…
Ссылка в новой задаче