Update acgan.py
This commit is contained in:
Родитель
44d3320e84
Коммит
88591d89a9
|
@ -75,7 +75,7 @@ class ACGAN():
|
|||
|
||||
noise = Input(shape=(self.latent_dim,))
|
||||
label = Input(shape=(1,), dtype='int32')
|
||||
label_embedding = Flatten()(Embedding(self.num_classes, 100)(label))
|
||||
label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
|
||||
|
||||
model_input = multiply([noise, label_embedding])
|
||||
img = model(model_input)
|
||||
|
@ -141,7 +141,7 @@ class ACGAN():
|
|||
imgs = X_train[idx]
|
||||
|
||||
# Sample noise as generator input
|
||||
noise = np.random.normal(0, 1, (batch_size, 100))
|
||||
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
|
||||
|
||||
# The labels of the digits that the generator tries to create an
|
||||
# image representation of
|
||||
|
@ -175,7 +175,7 @@ class ACGAN():
|
|||
|
||||
def sample_images(self, epoch):
|
||||
r, c = 10, 10
|
||||
noise = np.random.normal(0, 1, (r * c, 100))
|
||||
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
|
||||
sampled_labels = np.array([num for _ in range(r) for num in range(c)])
|
||||
gen_imgs = self.generator.predict([noise, sampled_labels])
|
||||
# Rescale images 0 - 1
|
||||
|
|
Загрузка…
Ссылка в новой задаче