djghdfkjMerge branch 'master' of github.com:Unity-Technologies/Keras-GAN
This commit is contained in:
Коммит
302caaf86c
40
cgan/cgan.py
40
cgan/cgan.py
|
@ -1,4 +1,6 @@
|
|||
from __future__ import print_function, division
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
from keras.datasets import mnist
|
||||
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
|
||||
|
@ -9,17 +11,19 @@ from keras.models import Sequential, Model
|
|||
from keras.optimizers import Adam
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
sys.path.append('../unity')
|
||||
from unity import semantic_maps, semantic_maps_shape
|
||||
|
||||
class CGAN():
|
||||
def __init__(self):
|
||||
def __init__(self, image_size, channels, num_classes):
|
||||
# Input shape
|
||||
self.img_rows = 28
|
||||
self.img_cols = 28
|
||||
self.channels = 1
|
||||
self.img_rows = image_size
|
||||
self.img_cols = image_size
|
||||
self.channels = channels
|
||||
self.img_shape = (self.img_rows, self.img_cols, self.channels)
|
||||
self.num_classes = 10
|
||||
self.num_classes = num_classes
|
||||
self.latent_dim = 100
|
||||
|
||||
optimizer = Adam(0.0002, 0.5)
|
||||
|
@ -109,11 +113,11 @@ class CGAN():
|
|||
def train(self, epochs, batch_size=128, sample_interval=50):
|
||||
|
||||
# Load the dataset
|
||||
(X_train, y_train), (_, _) = mnist.load_data()
|
||||
|
||||
#(X_train, y_train), (_, _) = mnist.load_data()
|
||||
(X_train, y_train), (_, _) = semantic_maps(resize=self.img_cols)
|
||||
# Configure input
|
||||
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
|
||||
X_train = np.expand_dims(X_train, axis=3)
|
||||
#X_train = np.expand_dims(X_train, axis=3)
|
||||
y_train = y_train.reshape(-1, 1)
|
||||
|
||||
# Adversarial ground truths
|
||||
|
@ -161,8 +165,8 @@ class CGAN():
|
|||
def sample_images(self, epoch):
|
||||
r, c = 2, 5
|
||||
noise = np.random.normal(0, 1, (r * c, 100))
|
||||
sampled_labels = np.arange(0, 10).reshape(-1, 1)
|
||||
|
||||
sampled_labels = np.zeros(10).reshape(-1, 1)
|
||||
print("sampled_labels", sampled_labels)
|
||||
gen_imgs = self.generator.predict([noise, sampled_labels])
|
||||
|
||||
# Rescale images 0 - 1
|
||||
|
@ -172,8 +176,11 @@ class CGAN():
|
|||
cnt = 0
|
||||
for i in range(r):
|
||||
for j in range(c):
|
||||
axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')
|
||||
axs[i,j].set_title("Digit: %d" % sampled_labels[cnt])
|
||||
if (self.channels==1):
|
||||
axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')
|
||||
else:
|
||||
axs[i,j].imshow(gen_imgs[cnt,:,:])
|
||||
axs[i,j].set_title("Trait: %d" % sampled_labels[cnt])
|
||||
axs[i,j].axis('off')
|
||||
cnt += 1
|
||||
fig.savefig("images/%d.png" % epoch)
|
||||
|
@ -181,5 +188,8 @@ class CGAN():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cgan = CGAN()
|
||||
cgan.train(epochs=20000, batch_size=32, sample_interval=200)
|
||||
image_size, channels, classes = semantic_maps_shape()
|
||||
if (len(sys.argv)>1):
|
||||
image_size = int(sys.argv[1])
|
||||
cgan = CGAN(image_size, channels, 1)
|
||||
cgan.train(epochs=20000, batch_size=8, sample_interval=200)
|
||||
|
|
|
@ -257,7 +257,7 @@ class INFOGAN():
|
|||
save(self.discriminator, "discriminator")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == '__main__':
|
||||
image_size, channels, classes = semantic_maps_shape()
|
||||
infogan = INFOGAN(image_size, channels, classes)
|
||||
infogan.train(epochs=50000, batch_size=32, sample_interval=50)
|
||||
|
|
Загрузка…
Ссылка в новой задаче