WGAN (GP): Resolves #37. + clean up of handling input shapes of latent-to-image models

This commit is contained in:
eriklindernoren 2018-05-14 12:39:31 +02:00
Родитель 0293a4e6ff
Коммит 78333e934f
10 изменённых файлов: 39 добавлений и 47 удалений

Просмотреть файл

@ -58,7 +58,7 @@ $ python3 acgan.py
### Adversarial Autoencoder
Implementation of _Adversarial Autoencoder_.
[Code](aae/adversarial_autoencoder.py)
[Code](aae/aae.py)
Paper: https://arxiv.org/abs/1511.05644

Просмотреть файл

@ -60,7 +60,7 @@ class AdversarialAutoencoder():
img = Input(shape=self.img_shape)
h = Flatten(input_shape=self.img_shape)(img)
h = Flatten()(img)
h = Dense(512)(h)
h = LeakyReLU(alpha=0.2)(h)
h = Dense(512)(h)

Двоичные данные
assets/keras_gan.png

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 10 KiB

После

Ширина:  |  Высота:  |  Размер: 10 KiB

Просмотреть файл

@ -22,6 +22,7 @@ class BGAN():
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
optimizer = Adam(0.0002, 0.5)
@ -51,11 +52,9 @@ class BGAN():
def build_generator(self):
noise_shape = (100,)
model = Sequential()
model.add(Dense(256, input_shape=noise_shape))
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
@ -69,18 +68,16 @@ class BGAN():
model.summary()
noise = Input(shape=noise_shape)
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
def build_discriminator(self):
img_shape = (self.img_rows, self.img_cols, self.channels)
model = Sequential()
model.add(Flatten(input_shape=img_shape))
model.add(Flatten(input_shape=self.img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
@ -88,7 +85,7 @@ class BGAN():
model.add(Dense(1, activation='sigmoid'))
model.summary()
img = Input(shape=img_shape)
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)
@ -121,7 +118,7 @@ class BGAN():
idx = np.random.randint(0, X_train.shape[0], half_batch)
imgs = X_train[idx]
noise = np.random.normal(0, 1, (half_batch, 100))
noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
# Generate a half batch of new images
gen_imgs = self.generator.predict(noise)
@ -136,7 +133,7 @@ class BGAN():
# Train Generator
# ---------------------
noise = np.random.normal(0, 1, (batch_size, 100))
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# The generator wants the discriminator to label the generated samples
# as valid (ones)
@ -154,7 +151,7 @@ class BGAN():
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100))
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1

Просмотреть файл

@ -52,7 +52,7 @@ class DCGAN():
model = Sequential()
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
model.add(Reshape((7, 7, 128)))
model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=3, padding="same"))

Просмотреть файл

@ -20,6 +20,7 @@ class GAN():
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
optimizer = Adam(0.0002, 0.5)
@ -54,7 +55,7 @@ class GAN():
model = Sequential()
model.add(Dense(256, input_shape=noise_shape))
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
@ -68,18 +69,16 @@ class GAN():
model.summary()
noise = Input(shape=noise_shape)
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
def build_discriminator(self):
img_shape = (self.img_rows, self.img_cols, self.channels)
model = Sequential()
model.add(Flatten(input_shape=img_shape))
model.add(Flatten(input_shape=self.img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
@ -87,7 +86,7 @@ class GAN():
model.add(Dense(1, activation='sigmoid'))
model.summary()
img = Input(shape=img_shape)
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)

Просмотреть файл

@ -20,6 +20,7 @@ class LSGAN():
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
optimizer = Adam(0.0002, 0.5)
@ -50,11 +51,9 @@ class LSGAN():
def build_generator(self):
noise_shape = (100,)
model = Sequential()
model.add(Dense(256, input_shape=noise_shape))
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
@ -68,18 +67,16 @@ class LSGAN():
model.summary()
noise = Input(shape=noise_shape)
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
def build_discriminator(self):
img_shape = (self.img_rows, self.img_cols, self.channels)
model = Sequential()
model.add(Flatten(input_shape=img_shape))
model.add(Flatten(input_shape=self.img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
@ -88,7 +85,7 @@ class LSGAN():
model.add(Dense(1))
model.summary()
img = Input(shape=img_shape)
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)
@ -114,7 +111,7 @@ class LSGAN():
idx = np.random.randint(0, X_train.shape[0], half_batch)
imgs = X_train[idx]
noise = np.random.normal(0, 1, (half_batch, 100))
noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
# Generate a half batch of new images
gen_imgs = self.generator.predict(noise)
@ -129,7 +126,7 @@ class LSGAN():
# Train Generator
# ---------------------
noise = np.random.normal(0, 1, (batch_size, 100))
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# The generator wants the discriminator to label the generated samples
# as valid (ones)
@ -147,7 +144,7 @@ class LSGAN():
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100))
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1

Просмотреть файл

@ -21,6 +21,7 @@ class SGAN():
self.img_cols = 28
self.channels = 1
self.num_classes = 10
self.latent_dim = 100
optimizer = Adam(0.0002, 0.5)
@ -55,7 +56,7 @@ class SGAN():
model = Sequential()
model.add(Dense(128 * 7 * 7, activation="relu", input_dim=100))
model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
model.add(Reshape((7, 7, 128)))
model.add(BatchNormalization(momentum=0.8))
model.add(UpSampling2D())
@ -71,18 +72,16 @@ class SGAN():
model.summary()
noise = Input(shape=(100,))
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
def build_discriminator(self):
img_shape = (self.img_rows, self.img_cols, self.channels)
model = Sequential()
model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
@ -101,7 +100,7 @@ class SGAN():
model.add(Flatten())
model.summary()
img = Input(shape=img_shape)
img = Input(shape=self.img_shape)
features = model(img)
valid = Dense(1, activation="sigmoid")(features)
@ -142,7 +141,7 @@ class SGAN():
imgs = X_train[idx]
# Sample noise and generate a half batch of new images
noise = np.random.normal(0, 1, (half_batch, 100))
noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
gen_imgs = self.generator.predict(noise)
valid = np.ones((half_batch, 1))
@ -161,7 +160,7 @@ class SGAN():
# Train Generator
# ---------------------
noise = np.random.normal(0, 1, (batch_size, 100))
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
validity = np.ones((batch_size, 1))
# Train the generator
@ -176,7 +175,7 @@ class SGAN():
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100))
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1

Просмотреть файл

@ -61,7 +61,7 @@ class WGAN():
model = Sequential()
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
model.add(Reshape((7, 7, 128)))
model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=4, padding="same"))
@ -76,7 +76,7 @@ class WGAN():
model.summary()
noise = Input(shape=noise_shape)
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
@ -106,7 +106,7 @@ class WGAN():
model.summary()
img = Input(shape=img_shape)
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)

Просмотреть файл

@ -128,7 +128,7 @@ class WGANGP():
model = Sequential()
model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(self.latent_dim,)))
model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
model.add(Reshape((7, 7, 128)))
model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=4, padding="same"))
@ -143,7 +143,7 @@ class WGANGP():
model.summary()
noise = Input(shape=noise_shape)
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
@ -173,7 +173,7 @@ class WGANGP():
model.summary()
img = Input(shape=img_shape)
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)