djghdfkjMerge branch 'master' of github.com:Unity-Technologies/Keras-GAN

This commit is contained in:
dom 2019-09-12 15:37:26 -04:00
Родитель 295aebfdbc b665b5d09a
Коммит 302caaf86c
2 изменённых файлов: 26 добавлений и 16 удалений

Просмотреть файл

@ -1,4 +1,6 @@
from __future__ import print_function, division
import warnings
warnings.filterwarnings("ignore")
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
@ -9,17 +11,19 @@ from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import sys
import numpy as np
sys.path.append('../unity')
from unity import semantic_maps, semantic_maps_shape
class CGAN():
def __init__(self):
def __init__(self, image_size, channels, num_classes):
# Input shape
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_rows = image_size
self.img_cols = image_size
self.channels = channels
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.num_classes = 10
self.num_classes = num_classes
self.latent_dim = 100
optimizer = Adam(0.0002, 0.5)
@ -109,11 +113,11 @@ class CGAN():
def train(self, epochs, batch_size=128, sample_interval=50):
# Load the dataset
(X_train, y_train), (_, _) = mnist.load_data()
#(X_train, y_train), (_, _) = mnist.load_data()
(X_train, y_train), (_, _) = semantic_maps(resize=self.img_cols)
# Configure input
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
#X_train = np.expand_dims(X_train, axis=3)
y_train = y_train.reshape(-1, 1)
# Adversarial ground truths
@ -161,8 +165,8 @@ class CGAN():
def sample_images(self, epoch):
r, c = 2, 5
noise = np.random.normal(0, 1, (r * c, 100))
sampled_labels = np.arange(0, 10).reshape(-1, 1)
sampled_labels = np.zeros(10).reshape(-1, 1)
print("sampled_labels", sampled_labels)
gen_imgs = self.generator.predict([noise, sampled_labels])
# Rescale images 0 - 1
@ -172,8 +176,11 @@ class CGAN():
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')
axs[i,j].set_title("Digit: %d" % sampled_labels[cnt])
if (self.channels==1):
axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')
else:
axs[i,j].imshow(gen_imgs[cnt,:,:])
axs[i,j].set_title("Trait: %d" % sampled_labels[cnt])
axs[i,j].axis('off')
cnt += 1
fig.savefig("images/%d.png" % epoch)
@ -181,5 +188,8 @@ class CGAN():
if __name__ == '__main__':
cgan = CGAN()
cgan.train(epochs=20000, batch_size=32, sample_interval=200)
image_size, channels, classes = semantic_maps_shape()
if (len(sys.argv)>1):
image_size = int(sys.argv[1])
cgan = CGAN(image_size, channels, 1)
cgan.train(epochs=20000, batch_size=8, sample_interval=200)

Просмотреть файл

@ -257,7 +257,7 @@ class INFOGAN():
save(self.discriminator, "discriminator")
if __name__ == '__main__':
if __name__ == '__main__':
image_size, channels, classes = semantic_maps_shape()
infogan = INFOGAN(image_size, channels, classes)
infogan.train(epochs=50000, batch_size=32, sample_interval=50)