{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CNTK 206: Part B - Deep Convolutional GAN with MNIST data\n",
"\n",
"**Prerequisites**: We assume that you have successfully downloaded the MNIST data by completing the tutorial titled CNTK_103A_MNIST_DataLoader.ipynb.\n",
"\n",
"## Introduction\n",
"\n",
"[Generative models](https://en.wikipedia.org/wiki/Generative_model) have gained a [lot of attention](https://openai.com/blog/generative-models/) in deep learning community which has traditionally leveraged [discriminative models](https://en.wikipedia.org/wiki/Discriminative_model) for (semi-supervised) and unsupervised learning. \n",
"\n",
"## Overview\n",
"In the previous tutorial we introduced the original GAN implementation by [Goodfellow et al.](https://arxiv.org/pdf/1406.2661v1.pdf) at NIPS 2014. Many published techniques have since extended this pioneering work. Among them, the Deep Convolutional Generative Adversarial Network (DCGAN) has become the recommended launch pad in the community.\n",
"\n",
"In this tutorial, we introduce an implementation of the DCGAN with some well-tested architectural constraints that improve stability in GAN training:\n",
"\n",
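"- We use [strided convolutions](https://en.wikipedia.org/wiki/Convolutional_neural_network) in the discriminator and [fractionally strided convolutions](https://arxiv.org/pdf/1603.07285v1.pdf) in the generator.\n",
"- We use batch normalization in both the generator and the discriminator.\n",
"- The DCGAN paper removes fully connected hidden layers for deeper architectures. The relatively shallow networks in this tutorial still use a 1024-unit fully connected hidden layer in both the generator and the discriminator.\n",
"- We use ReLU activation in the generator for all layers except the output. The DCGAN paper uses Tanh on the output, but our generator uses a sigmoid. This keeps the generated pixels in [0, 1], the same range as the MNIST images after they are divided by 255.\n",
"- We use LeakyReLU activation in the discriminator for all layers except the output, which uses a sigmoid.\n"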
]
},
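{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the term *fractionally strided convolution* concrete, the next cell gives a minimal 1-D NumPy sketch added for illustration. It is not CNTK's implementation, and the helper names `strided_conv1d` and `transposed_conv1d` are hypothetical.\n",
"\n",
"A transposed (fractionally strided) convolution is the adjoint of a strided convolution. Each input value adds a kernel-scaled copy into an output that is roughly `stride` times longer. This is equivalent to inserting `stride - 1` zeros between the inputs and running an ordinary full convolution. The sketch checks that equivalence with `np.convolve`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def strided_conv1d(x, w, stride):\n",
"    # 'valid' 1-D cross-correlation with a stride and no padding\n",
"    k = len(w)\n",
"    n_out = (len(x) - k) // stride + 1\n",
"    return np.array([np.dot(x[i * stride:i * stride + k], w) for i in range(n_out)])\n",
"\n",
"def transposed_conv1d(y, w, stride):\n",
"    # adjoint of strided_conv1d: each input value adds a scaled copy of the kernel\n",
"    k = len(w)\n",
"    out = np.zeros((len(y) - 1) * stride + k)\n",
"    for i, val in enumerate(y):\n",
"        out[i * stride:i * stride + k] += val * w\n",
"    return out\n",
"\n",
"rng = np.random.RandomState(0)\n",
"x, w, y = rng.randn(11), rng.randn(5), rng.randn(4)\n",
"\n",
"# the strided convolution downsamples 11 -> 4; the transpose upsamples 4 -> 11\n",
"print(strided_conv1d(x, w, 2).shape, transposed_conv1d(y, w, 2).shape)\n",
"\n",
"# adjoint property: <conv(x), y> == <x, conv_transpose(y)>\n",
"print(np.isclose(np.dot(strided_conv1d(x, w, 2), y), np.dot(x, transposed_conv1d(y, w, 2))))\n",
"\n",
"# same result as zero-insertion followed by a full convolution\n",
"z = np.zeros((len(y) - 1) * 2 + 1)\n",
"z[::2] = y\n",
"print(np.allclose(transposed_conv1d(y, w, 2), np.convolve(z, w, mode='full')))"
]
},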
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"\n",
"import cntk as C\n",
"import cntk.tests.test_utils\n",
"cntk.tests.test_utils.set_device_from_pytest_env() # (only needed for our build system)\n",
"C.cntk_py.set_fixed_random_seed(1) # fix a random seed for CNTK components\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are two run modes:\n",
"- *Fast mode*: `isFast` is set to `True`. This is the default mode for the notebooks, which means we train for fewer iterations or train / test on limited data. This ensures functional correctness of the notebook though the models produced are far from what a completed training would produce.\n",
"\n",
"- *Slow mode*: We recommend the user to set this flag to `False` once the user has gained familiarity with the notebook content and wants to gain insight from running the notebooks for a longer period with different parameters for training. \n",
"\n",
"**Note**\n",
"If the `isFlag` is set to `False` the notebook will take a few hours on a GPU enabled machine. You can try fewer iterations by setting the `num_minibatches` to a smaller number say `20,000` which comes at the expense of quality of the generated images."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"isFast = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Reading\n",
"The input to the GAN will be a vector of random numbers. At the end of the traning, the GAN \"learns\" to generate images of hand written digits drawn from the [MNIST database](https://en.wikipedia.org/wiki/MNIST_database). We will be using the same MNIST data generated in tutorial 103A. A more in-depth discussion of the data format and reading methods can be seen in previous tutorials. For our purposes, just know that the following function returns an object that will be used to generate images from the MNIST dataset. Since we are building an unsupervised model, we only need to read in `features` and ignore the `labels`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Data directory is ..\\Examples\\Image\\DataSets\\MNIST\n"
]
}
],
"source": [
"# Ensure the training data is generated and available for this tutorial\n",
"# We search in two locations in the toolkit for the cached MNIST data set.\n",
"\n",
"data_found = False\n",
"for data_dir in [os.path.join(\"..\", \"Examples\", \"Image\", \"DataSets\", \"MNIST\"),\n",
" os.path.join(\"data\", \"MNIST\")]:\n",
" train_file = os.path.join(data_dir, \"Train-28x28_cntk_text.txt\")\n",
" if os.path.isfile(train_file):\n",
" data_found = True\n",
" break\n",
" \n",
"if not data_found:\n",
" raise ValueError(\"Please generate the data by completing CNTK 103 Part A\")\n",
" \n",
"print(\"Data directory is {0}\".format(data_dir))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_reader(path, is_training, input_dim, label_dim):\n",
" deserializer = C.io.CTFDeserializer(\n",
2017-03-17 08:53:39 +03:00
" filename = path,\n",
" streams = C.io.StreamDefs(\n",
" labels_unused = C.io.StreamDef(field = 'labels', shape = label_dim, is_sparse = False),\n",
" features = C.io.StreamDef(field = 'features', shape = input_dim, is_sparse = False\n",
2017-03-17 08:53:39 +03:00
" )\n",
" )\n",
" )\n",
" return C.io.MinibatchSource(\n",
2017-03-17 08:53:39 +03:00
" deserializers = deserializer,\n",
" randomize = is_training,\n",
" max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1\n",
2017-03-17 08:53:39 +03:00
" )"
]
},
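{
"cell_type": "markdown",
"metadata": {},
"source": [
"For readers who skipped CNTK 103A, the next cell is a pure-NumPy sketch of the kind of line the `CTFDeserializer` above consumes. It also shows how the `features` stream relates to the `X_real / 255.0` scaling used later.\n",
"\n",
"The sketch assumes dense streams written as `|labels ... |features ...` on a single line, which is our understanding of how the 103A text files are laid out. The helper `parse_ctf_line` and the sample line are hypothetical. The real deserializer also supports sequence ids and sparse `index:value` entries, which this sketch ignores."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def parse_ctf_line(line):\n",
"    # hypothetical helper: split a dense CTF line into {stream_name: values}\n",
"    streams = {}\n",
"    for field in line.strip().split('|')[1:]:\n",
"        name, *values = field.split()\n",
"        streams[name] = np.array(values, dtype=np.float32)\n",
"    return streams\n",
"\n",
"# made-up line: a 3-class one-hot label and a 4-pixel 'image'\n",
"sample = '|labels 0 1 0 |features 0 128 255 64'\n",
"streams = parse_ctf_line(sample)\n",
"\n",
"# the GAN ignores the labels; pixels are scaled from [0, 255] to [0, 1]\n",
"features = streams['features'] / 255.0\n",
"print(streams['labels'], features)"
]
},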
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The random noise we will use to train the GAN is provided by the `noise_sample` function to generate random noise samples from a uniform distribution within the interval [-1, 1]."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"np.random.seed(123)\n",
"def noise_sample(num_samples):\n",
" return np.random.uniform(\n",
" low = -1.0,\n",
" high = 1.0,\n",
" size = [num_samples, g_input_dim]\n",
" ).astype(np.float32)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Creation\n",
"\n",
"First we provide a brief recap of the basics of GAN. You may skip this block if you are familiar with CNTK 206A. \n",
"\n",
"A GAN network is composed of two sub-networks, one called the Generator ($G$) and the other Discriminator ($D$). \n",
"- The **Generator** takes random noise vector ($z$) as input and strives to output synthetic (fake) image ($x^*$) that is indistinguishable from the real image ($x$) from the MNIST dataset. \n",
"- The **Discriminator** strives to differentiate between the real image ($x$) and the fake ($x^*$) image.\n",
"\n",
"![](https://www.cntk.ai/jup/GAN_basic_flow.png)\n",
"\n",
"In each training iteration, the Generator tries to produce more realistic fake images. In other words, it *minimizes* the difference between real images and their generated counterparts. The Discriminator *maximizes* the probability of assigning the correct label (real vs. fake) to both the real examples from the training set and the generated fake ones. The conflicting objectives of the two sub-networks ($G$ and $D$) lead the GAN, when trained, to converge to an equilibrium. At that point the Generator produces realistic-looking fake MNIST images, and the Discriminator can at best randomly guess whether an image is real or fake. Once trained, the Generator produces realistic MNIST images from random noise input."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model config\n",
"\n",
"First, we establish some of the architectural and training hyper-parameters for our model. \n",
"\n",
"- The generator network is fractional strided convolutional network. The input is a 100-dimensional random vector and the output of the generator is a flattened version of a 28 x 28 fake image. The discriminator is strided-convolution network. It takes as input the 784 dimensional output of the generator or a real MNIST image, reshapes into a 28 x 28 image format and outputs a single scalar - the estimated probability that the input image is a real MNIST image."
2017-03-17 08:53:39 +03:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model components\n",
"We build a computational graph for our model, one each for the generator and the discriminator. First, we establish some of the architectural parameters of our model. "
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# architectural parameters\n",
"img_h, img_w = 28, 28\n",
"kernel_h, kernel_w = 5, 5 \n",
"stride_h, stride_w = 2, 2\n",
"\n",
"# Input / Output parameter of Generator and Discriminator\n",
"g_input_dim = 100\n",
2017-03-18 01:07:25 +03:00
"g_output_dim = d_input_dim = img_h * img_w\n",
"\n",
"# We expect the kernel shapes to be square in this tutorial and\n",
"# the strides to be of the same length along each data dimension\n",
"if kernel_h == kernel_w:\n",
" gkernel = dkernel = kernel_h\n",
"else:\n",
" raise ValueError('This tutorial needs square shaped kernel') \n",
" \n",
"if stride_h == stride_w:\n",
" gstride = dstride = stride_h\n",
"else:\n",
" raise ValueError('This tutorial needs same stride in all dims')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Helper functions\n",
"def bn_with_relu(x, activation=C.relu):\n",
" h = C.layers.BatchNormalization(map_rank=1)(x)\n",
2017-03-17 08:53:39 +03:00
" return C.relu(h)\n",
"\n",
"# We use param-relu function to use a leak=0.2 since CNTK implementation \n",
"# of Leaky ReLU is fixed to 0.01\n",
"def bn_with_leaky_relu(x, leak=0.2):\n",
" h = C.layers.BatchNormalization(map_rank=1)(x)\n",
2017-03-17 08:53:39 +03:00
" r = C.param_relu(C.constant((np.ones(h.shape)*leak).astype(np.float32)), h)\n",
" return r"
]
},
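{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two helpers above combine batch normalization with an activation. The next cell is a simplified, hypothetical NumPy stand-in for what they compute at training time, not CNTK's implementation.\n",
"\n",
"With `map_rank=1`, CNTK's `BatchNormalization` shares statistics across all spatial positions of a channel. The sketch therefore normalizes each channel (axis 1, after a leading batch axis) over the batch and spatial axes. It omits the learned scale and bias and the running statistics used at inference. The leaky ReLU with `leak=0.2` mirrors what `C.param_relu` computes with a constant `alpha`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def spatial_batch_norm(x, eps=1e-5):\n",
"    # normalize every channel (axis 1) over the batch axis and all spatial axes;\n",
"    # no learned scale/bias and no running statistics in this sketch\n",
"    axes = (0,) + tuple(range(2, x.ndim))\n",
"    mean = x.mean(axis=axes, keepdims=True)\n",
"    var = x.var(axis=axes, keepdims=True)\n",
"    return (x - mean) / np.sqrt(var + eps)\n",
"\n",
"def leaky_relu(x, leak=0.2):\n",
"    # identity for positive inputs, slope `leak` for negative inputs\n",
"    return np.where(x > 0, x, leak * x)\n",
"\n",
"rng = np.random.RandomState(0)\n",
"x = rng.normal(loc=3.0, scale=2.0, size=(8, 4, 5, 5))  # (batch, channels, h, w)\n",
"h = spatial_batch_norm(x)\n",
"print(h.mean(axis=(0, 2, 3)).round(4), h.std(axis=(0, 2, 3)).round(4))\n",
"print(leaky_relu(np.array([-1.0, 0.0, 2.0])))"
]
},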
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Generator**\n",
"\n",
2017-03-18 01:07:25 +03:00
"The generator takes a 100-dimensional random vector (for starters) as input ($z$) and the outputs a 784 dimensional vector, corresponding to a flattened version of a 28 x 28 fake (synthetic) image ($x^*$). In this tutorial, we use fractionally strided convolutions (a.k.a ConvolutionTranspose) with ReLU activations except for the last layer. We use a tanh activation on the last layer to make sure that the output of the generator function is confined to the interval [-1, 1]. The use of ReLU and tanh activation functions are key in addition to using the fractionally strided convolutions."
2017-03-17 08:53:39 +03:00
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def convolutional_generator(z):\n",
" with C.layers.default_options(init=C.normal(scale=0.02)):\n",
2017-03-17 08:53:39 +03:00
" print('Generator input shape: ', z.shape)\n",
"\n",
" s_h2, s_w2 = img_h//2, img_w//2 #Input shape (14,14)\n",
" s_h4, s_w4 = img_h//4, img_w//4 # Input shape (7,7)\n",
" gfc_dim = 1024\n",
" gf_dim = 64\n",
"\n",
" h0 = C.layers.Dense(gfc_dim, activation=None)(z)\n",
2017-03-17 08:53:39 +03:00
" h0 = bn_with_relu(h0)\n",
" print('h0 shape', h0.shape)\n",
"\n",
" h1 = C.layers.Dense([gf_dim * 2, s_h4, s_w4], activation=None)(h0)\n",
2017-03-17 08:53:39 +03:00
" h1 = bn_with_relu(h1)\n",
" print('h1 shape', h1.shape)\n",
"\n",
" h2 = C.layers.ConvolutionTranspose2D(gkernel,\n",
2017-03-17 08:53:39 +03:00
" num_filters=gf_dim*2,\n",
" strides=gstride,\n",
" pad=True,\n",
" output_shape=(s_h2, s_w2),\n",
" activation=None)(h1)\n",
" h2 = bn_with_relu(h2)\n",
" print('h2 shape', h2.shape)\n",
"\n",
" h3 = C.layers.ConvolutionTranspose2D(gkernel,\n",
2017-03-17 08:53:39 +03:00
" num_filters=1,\n",
" strides=gstride,\n",
" pad=True,\n",
" output_shape=(img_h, img_w),\n",
" activation=C.sigmoid)(h2)\n",
" print('h3 shape :', h3.shape)\n",
"\n",
" return C.reshape(h3, img_h * img_w)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Discriminator**\n",
"\n",
"The discriminator takes as input ($x^*$) the 784 dimensional output of the generator or a real MNIST image, re-shapes the input to a 28 x 28 image and outputs the estimated probability that the input image is a real MNIST image. The network is modeled using strided convolution with Leaky ReLU activation except for the last layer. We use a sigmoid activation on the last layer to ensure the discriminator output lies in the inteval of [0,1]."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def convolutional_discriminator(x):\n",
" with C.layers.default_options(init=C.normal(scale=0.02)):\n",
2017-03-17 08:53:39 +03:00
"\n",
" dfc_dim = 1024\n",
" df_dim = 64\n",
"\n",
" print('Discriminator convolution input shape', x.shape)\n",
" x = C.reshape(x, (1, img_h, img_w))\n",
"\n",
" h0 = C.layers.Convolution2D(dkernel, 1, strides=dstride)(x)\n",
2017-03-17 08:53:39 +03:00
" h0 = bn_with_leaky_relu(h0, leak=0.2)\n",
" print('h0 shape :', h0.shape)\n",
"\n",
" h1 = C.layers.Convolution2D(dkernel, df_dim, strides=dstride)(h0)\n",
2017-03-17 08:53:39 +03:00
" h1 = bn_with_leaky_relu(h1, leak=0.2)\n",
" print('h1 shape :', h1.shape)\n",
"\n",
" h2 = C.layers.Dense(dfc_dim, activation=None)(h1)\n",
2017-03-17 08:53:39 +03:00
" h2 = bn_with_leaky_relu(h2, leak=0.2)\n",
" print('h2 shape :', h2.shape)\n",
"\n",
" h3 = C.layers.Dense(1, activation=C.sigmoid)(h2)\n",
2017-03-17 08:53:39 +03:00
" print('h3 shape :', h3.shape)\n",
"\n",
" return h3"
]
},
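{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shapes printed when the graph is built can be checked by hand, as the next cell does in pure Python with hypothetical helper names.\n",
"\n",
"- **Discriminator.** The sketch applies the usual output-size formula for an unpadded strided convolution, $\\lfloor (n - k)/s \\rfloor + 1$. CNTK's `Convolution2D` is unpadded by default, and the discriminator does not set `pad`.\n",
"- **Generator.** The sketch assumes that each padded, stride-2 `ConvolutionTranspose2D` doubles the spatial size.\n",
"\n",
"A padded stride-2 convolution maps both 27 and 28 pixels to 14, so the inverse size is ambiguous. That is presumably why the generator passes `output_shape` explicitly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def conv_output_size(n, kernel, stride, pad):\n",
"    # pad=False ('valid'): floor((n - kernel) / stride) + 1\n",
"    # pad=True  ('same'):  ceil(n / stride)\n",
"    if pad:\n",
"        return -(-n // stride)\n",
"    return (n - kernel) // stride + 1\n",
"\n",
"def padded_transpose_output_size(n, stride):\n",
"    # assumption: a padded transposed convolution upsamples by the stride\n",
"    return n * stride\n",
"\n",
"# discriminator: 28 -> 12 -> 4 (kernel 5, stride 2, no padding)\n",
"d_sizes = [28]\n",
"for _ in range(2):\n",
"    d_sizes.append(conv_output_size(d_sizes[-1], kernel=5, stride=2, pad=False))\n",
"\n",
"# generator: 7 -> 14 -> 28 (padded transposed convolutions, stride 2)\n",
"g_sizes = [7]\n",
"for _ in range(2):\n",
"    g_sizes.append(padded_transpose_output_size(g_sizes[-1], stride=2))\n",
"\n",
"print(d_sizes, g_sizes)\n",
"print(conv_output_size(27, 5, 2, pad=True), conv_output_size(28, 5, 2, pad=True))"
]
},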
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use a minibatch size of 128 and a fixed learning rate of 0.0002 for training. In the fast mode (`isFast = True`) we verify only functional correctness with 5000 iterations. \n",
"\n",
"**Note**: In the slow mode, the results look a lot better but it requires in the order of 10 minutes depending on your hardware. In general, the more number of minibatches one trains, the better is the fidelity of the generated images."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# training config\n",
"minibatch_size = 128\n",
"num_minibatches = 5000 if isFast else 10000\n",
"lr = 0.0002\n",
"momentum = 0.5 #equivalent to beta1"
]
},
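{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `momentum` of 0.5 above plays the role of Adam's first-moment decay rate $\\beta_1$. The DCGAN paper lowered it from the usual 0.9 to 0.5 to stabilize training.\n",
"\n",
"The next cell is a textbook NumPy version of one Adam update, written only to illustrate the role of $\\beta_1$. It is not CNTK's `C.adam`, whose internals (for example, how the per-sample learning-rate schedule is scaled) may differ. The values `beta2 = 0.999` and `eps = 1e-8` are assumed textbook defaults, not values taken from this notebook.\n",
"\n",
"With $\\beta_1 = 0.5$, the running gradient average effectively covers about $1/(1-\\beta_1) = 2$ recent steps, instead of about 10 for $\\beta_1 = 0.9$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def adam_step(param, grad, m, v, t, lr=0.0002, beta1=0.5, beta2=0.999, eps=1e-8):\n",
"    # textbook Adam update (Kingma & Ba); t is the 1-based step count\n",
"    m = beta1 * m + (1.0 - beta1) * grad       # first-moment (momentum) estimate\n",
"    v = beta2 * v + (1.0 - beta2) * grad ** 2  # second-moment estimate\n",
"    m_hat = m / (1.0 - beta1 ** t)             # bias corrections\n",
"    v_hat = v / (1.0 - beta2 ** t)\n",
"    return param - lr * m_hat / (np.sqrt(v_hat) + eps), m, v\n",
"\n",
"param = np.array([1.0, -1.0])\n",
"m = np.zeros_like(param)\n",
"v = np.zeros_like(param)\n",
"grad = np.array([0.3, -4.0])\n",
"\n",
"# on the first step each coordinate moves by roughly lr, whatever the gradient scale\n",
"new_param, m, v = adam_step(param, grad, m, v, t=1)\n",
"print(param - new_param)"
]
},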
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Build the graph\n",
"\n",
"The rest of the computational graph is mostly responsible for coordinating the training algorithms and parameter updates. This is particularly tricky with GANs for a couple of reasons. GANs are sensitive to the choice of learner and its parameters, and many of the parameters chosen here are based on hard-learnt lessons from the community. You may go directly to the code if you have read the basic GAN tutorial.\n",
"\n",
"- First, the discriminator must be used on both the real MNIST images and fake images generated by the generator function. One way to represent this in the computational graph is to create a clone of the output of the discriminator function, but with substituted inputs. Setting `method=share` in the `clone` function ensures that both paths through the discriminator model use the same set of parameters.\n",
"\n",
"\n",
"- Second, we need to update the parameters for the generator and discriminator model separately using the gradients from different loss functions. We can get the parameters for a `Function` in the graph with the `parameters` attribute. However, when updating the model parameters, update only the parameters of the respective models while keeping the other parameters unchanged. In other words, when updating the generator we will update only the parameters of the $G$ function while keeping the parameters of the $D$ function fixed and vice versa.\n",
"\n",
"### Training the Model\n",
"The code for training the GAN very closely follows the algorithm as presented in the [original NIPS 2014 paper](https://arxiv.org/pdf/1406.2661v1.pdf). In this implementation, we train $D$ to maximize the probability of assigning the correct label (fake vs. real) to both training examples and the samples from $G$. In other words, $D$ and $G$ play the following two-player minimax game with the value function $V(G,D)$:\n",
"\n",
"$$\n",
" \\min_G \\max_D V(D,G)= \\mathbb{E}_{x}[ log D(x) ] + \\mathbb{E}_{z}[ log(1 - D(G(z))) ]\n",
"$$\n",
"\n",
"At the optimal point of this game the generator will produce realistic looking data while the discriminator will predict that the generated image is indeed fake with a probability of 0.5. The [algorithm referred below](https://arxiv.org/pdf/1406.2661v1.pdf) is implemented in this tutorial.\n",
"\n",
"![](https://www.cntk.ai/jup/GAN_goodfellow_NIPS2014.png)"
]
},
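{
"cell_type": "markdown",
"metadata": {},
"source": [
"The value function above can be evaluated by hand. The next cell is a NumPy sketch (hypothetical helper names, not part of CNTK) that mirrors the per-sample `D_loss` and `G_loss` expressions in `build_graph`.\n",
"\n",
"At the equilibrium, $D$ outputs 0.5 for every input, so the discriminator loss equals $2\\ln 2 \\approx 1.386$. The sketch also compares gradients with respect to the discriminator's pre-sigmoid logit $a$, where $D = \\sigma(a)$. When $D$ confidently rejects a fake, the original $\\log(1 - D(G(z)))$ term gives the generator almost no gradient, whereas $-\\log D(G(z))$ still does."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def discriminator_loss(d_real, d_fake):\n",
"    # per-sample D_loss from build_graph: -(log D(x) + log(1 - D(G(z))))\n",
"    return -(np.log(d_real) + np.log(1.0 - d_fake))\n",
"\n",
"def generator_loss(d_fake):\n",
"    # per-sample G_loss from build_graph: 1 - log D(G(z))\n",
"    return 1.0 - np.log(d_fake)\n",
"\n",
"# equilibrium: D outputs 0.5 for real and fake inputs alike\n",
"d_eq = np.full(4, 0.5)\n",
"print(discriminator_loss(d_eq, d_eq))\n",
"\n",
"# gradients w.r.t. the logit a, where D(G(z)) = sigmoid(a)\n",
"d_fake = np.array([0.01, 0.1, 0.5])\n",
"grad_saturating = -d_fake              # d/da of  log(1 - sigmoid(a))\n",
"grad_non_saturating = -(1.0 - d_fake)  # d/da of -log(sigmoid(a)); the constant 1 drops out\n",
"print(np.abs(grad_saturating), np.abs(grad_non_saturating))"
]
},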
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def build_graph(noise_shape, image_shape, generator, discriminator):\n",
" input_dynamic_axes = [C.Axis.default_batch_axis()]\n",
2017-05-12 03:00:07 +03:00
" Z = C.input_variable(noise_shape, dynamic_axes=input_dynamic_axes)\n",
" X_real = C.input_variable(image_shape, dynamic_axes=input_dynamic_axes)\n",
2017-03-17 08:53:39 +03:00
" X_real_scaled = X_real / 255.0\n",
"\n",
" # Create the model function for the generator and discriminator models\n",
" X_fake = generator(Z)\n",
" D_real = discriminator(X_real_scaled)\n",
" D_fake = D_real.clone(\n",
" method = 'share',\n",
" substitutions = {X_real_scaled.output: X_fake.output}\n",
" )\n",
"\n",
" # Create loss functions and configure optimazation algorithms\n",
" G_loss = 1.0 - C.log(D_fake)\n",
" D_loss = -(C.log(D_real) + C.log(1.0 - D_fake))\n",
"\n",
" G_learner = C.adam(\n",
2017-03-17 08:53:39 +03:00
" parameters = X_fake.parameters,\n",
" lr = C.learning_parameter_schedule_per_sample(lr),\n",
" momentum = C.momentum_schedule(momentum)\n",
2017-03-17 08:53:39 +03:00
" )\n",
" D_learner = C.adam(\n",
2017-03-17 08:53:39 +03:00
" parameters = D_real.parameters,\n",
" lr = C.learning_parameter_schedule_per_sample(lr),\n",
" momentum = C.momentum_schedule(momentum)\n",
2017-03-17 08:53:39 +03:00
" )\n",
"\n",
" # Instantiate the trainers\n",
" G_trainer = C.Trainer(X_fake,\n",
" (G_loss, None),\n",
" G_learner)\n",
" D_trainer = C.Trainer(D_real,\n",
" (D_loss, None),\n",
" D_learner)\n",
2017-03-17 08:53:39 +03:00
"\n",
" return X_real, X_fake, Z, G_trainer, D_trainer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the value functions defined we proceed to iteratively train the GAN model. The training of the model can take significantly long depending on the hardware especially if `isFast` flag is turned off."
2017-03-17 08:53:39 +03:00
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def train(reader_train, generator, discriminator):\n",
" X_real, X_fake, Z, G_trainer, D_trainer = \\\n",
" build_graph(g_input_dim, d_input_dim, generator, discriminator)\n",
"\n",
" # print out loss for each model for upto 25 times\n",
" print_frequency_mbsize = num_minibatches // 25\n",
" \n",
" print(\"First row is Generator loss, second row is Discriminator loss\")\n",
" pp_G = C.logging.ProgressPrinter(print_frequency_mbsize)\n",
" pp_D = C.logging.ProgressPrinter(print_frequency_mbsize)\n",
2017-03-17 08:53:39 +03:00
"\n",
" k = 2\n",
"\n",
" input_map = {X_real: reader_train.streams.features}\n",
" for train_step in range(num_minibatches):\n",
"\n",
" # train the discriminator model for k steps\n",
" for gen_train_step in range(k):\n",
" Z_data = noise_sample(minibatch_size)\n",
" X_data = reader_train.next_minibatch(minibatch_size, input_map)\n",
" if X_data[X_real].num_samples == Z_data.shape[0]:\n",
" batch_inputs = {X_real: X_data[X_real].data, Z: Z_data}\n",
2017-03-17 08:53:39 +03:00
" D_trainer.train_minibatch(batch_inputs)\n",
"\n",
" # train the generator model for a single step\n",
" Z_data = noise_sample(minibatch_size)\n",
" batch_inputs = {Z: Z_data}\n",
"\n",
" G_trainer.train_minibatch(batch_inputs)\n",
" G_trainer.train_minibatch(batch_inputs)\n",
"\n",
" pp_G.update_with_trainer(G_trainer)\n",
" pp_D.update_with_trainer(D_trainer)\n",
"\n",
2017-03-24 23:53:25 +03:00
" G_trainer_loss = G_trainer.previous_minibatch_loss_average\n",
2017-03-17 08:53:39 +03:00
"\n",
" return Z, X_fake, G_trainer_loss"
]
},
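{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on the training budget, the next cell counts in pure Python what the loop in `train` does. Each iteration runs `k = 2` discriminator updates, each reading one real minibatch, followed by two generator updates.\n",
"\n",
"The helper `update_budget` is hypothetical. It assumes the 60,000-image MNIST training file from CNTK 103A and ignores any minibatch skipped by the size check. Under those assumptions, the discriminator sees roughly 21 passes over the training images in fast mode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def update_budget(num_minibatches, k=2, g_steps=2, minibatch_size=128,\n",
"                  train_set_size=60000):\n",
"    # count updates and real images read by the alternating training loop\n",
"    d_updates = num_minibatches * k\n",
"    g_updates = num_minibatches * g_steps\n",
"    real_images = d_updates * minibatch_size\n",
"    sweeps = real_images / train_set_size\n",
"    return d_updates, g_updates, real_images, sweeps\n",
"\n",
"for mode, n in [('fast', 5000), ('slow', 10000)]:\n",
"    d_updates, g_updates, real_images, sweeps = update_budget(n)\n",
"    print('{0}: {1} D updates, {2} G updates, {3} real images (~{4:.1f} sweeps)'.format(\n",
"        mode, d_updates, g_updates, real_images, sweeps))"
]
},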
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generator input shape: (100,)\n",
"h0 shape (1024,)\n",
"h1 shape (128, 7, 7)\n",
"h2 shape (128, 14, 14)\n",
"h3 shape : (1, 28, 28)\n",
"Discriminator convolution input shape (784,)\n",
"h0 shape : (1, 12, 12)\n",
"h1 shape : (64, 4, 4)\n",
"h2 shape : (1024,)\n",
"h3 shape : (1,)\n",
"First row is Generator loss, second row is Discriminator loss\n",
" Minibatch[ 1- 200]: loss = 1.724305 * 25600;\n",
" Minibatch[ 1- 200]: loss = 1.164579 * 25600;\n",
" Minibatch[ 201- 400]: loss = 1.738649 * 25600;\n",
" Minibatch[ 201- 400]: loss = 1.190806 * 25600;\n",
" Minibatch[ 401- 600]: loss = 1.743701 * 25600;\n",
" Minibatch[ 401- 600]: loss = 1.178363 * 25600;\n",
" Minibatch[ 601- 800]: loss = 1.752785 * 25600;\n",
" Minibatch[ 601- 800]: loss = 1.173294 * 25600;\n",
" Minibatch[ 801-1000]: loss = 1.750716 * 25600;\n",
" Minibatch[ 801-1000]: loss = 1.159751 * 25600;\n",
" Minibatch[1001-1200]: loss = 1.752001 * 25600;\n",
" Minibatch[1001-1200]: loss = 1.162093 * 25600;\n",
" Minibatch[1201-1400]: loss = 1.755410 * 25600;\n",
" Minibatch[1201-1400]: loss = 1.162755 * 25600;\n",
" Minibatch[1401-1600]: loss = 1.757578 * 25600;\n",
" Minibatch[1401-1600]: loss = 1.157969 * 25600;\n",
" Minibatch[1601-1800]: loss = 1.760176 * 25600;\n",
" Minibatch[1601-1800]: loss = 1.169943 * 25600;\n",
" Minibatch[1801-2000]: loss = 1.752263 * 25600;\n",
" Minibatch[1801-2000]: loss = 1.174252 * 25600;\n",
" Minibatch[2001-2200]: loss = 1.754073 * 25600;\n",
" Minibatch[2001-2200]: loss = 1.182062 * 25600;\n",
" Minibatch[2201-2400]: loss = 1.753659 * 25600;\n",
" Minibatch[2201-2400]: loss = 1.191013 * 25600;\n",
" Minibatch[2401-2600]: loss = 1.744683 * 25600;\n",
" Minibatch[2401-2600]: loss = 1.204642 * 25600;\n",
" Minibatch[2601-2800]: loss = 1.739423 * 25600;\n",
" Minibatch[2601-2800]: loss = 1.210787 * 25600;\n",
" Minibatch[2801-3000]: loss = 1.741835 * 25600;\n",
" Minibatch[2801-3000]: loss = 1.211336 * 25600;\n",
" Minibatch[3001-3200]: loss = 1.738885 * 25600;\n",
" Minibatch[3001-3200]: loss = 1.214446 * 25600;\n",
" Minibatch[3201-3400]: loss = 1.737622 * 25600;\n",
" Minibatch[3201-3400]: loss = 1.219743 * 25600;\n",
" Minibatch[3401-3600]: loss = 1.738669 * 25600;\n",
" Minibatch[3401-3600]: loss = 1.209985 * 25600;\n",
" Minibatch[3601-3800]: loss = 1.746416 * 25600;\n",
" Minibatch[3601-3800]: loss = 1.203059 * 25600;\n",
" Minibatch[3801-4000]: loss = 1.745153 * 25600;\n",
" Minibatch[3801-4000]: loss = 1.232357 * 25600;\n",
" Minibatch[4001-4200]: loss = 1.725411 * 25600;\n",
" Minibatch[4001-4200]: loss = 1.251152 * 25600;\n",
" Minibatch[4201-4400]: loss = 1.731132 * 25600;\n",
" Minibatch[4201-4400]: loss = 1.241505 * 25600;\n",
" Minibatch[4401-4600]: loss = 1.737252 * 25600;\n",
" Minibatch[4401-4600]: loss = 1.232204 * 25600;\n",
" Minibatch[4601-4800]: loss = 1.735165 * 25600;\n",
" Minibatch[4601-4800]: loss = 1.233294 * 25600;\n",
" Minibatch[4801-5000]: loss = 1.732835 * 25600;\n",
" Minibatch[4801-5000]: loss = 1.228176 * 25600;\n"
]
}
],
"source": [
"reader_train = create_reader(train_file, True, d_input_dim, label_dim=10)\n",
"\n",
"# G_input, G_output, G_trainer_loss = train(reader_train, dense_generator, dense_discriminator)\n",
"G_input, G_output, G_trainer_loss = train(reader_train,\n",
" convolutional_generator,\n",
" convolutional_discriminator)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training loss of the generator is: 1.77\n"
]
}
],
"source": [
"# Print the generator loss \n",
"print(\"Training loss of the generator is: {0:.2f}\".format(G_trainer_loss))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generating Fake (Synthetic) Images\n",
"\n",
"Now that we have trained the model, we can create fake images simply by feeding random noise into the generator and displaying the outputs. Below are a few images generated from random samples. To get a new set of samples, you can re-run the last cell."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWUAAAEECAYAAADwLSVEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvVdw3Od1Pvxs771XAItCsAAEm0RKLBbFJpGSLFlukj1J\nbEdyJmWSS+UiyWTGSSb2hSeTdhWXeBJf2LJlK7asYoqiJVEsEAkCIDqwi+299/K/4HeOFvryxfx/\nGax4sWcGI5EAF7/ze9/3vKc85zmCTqfTQV/60pe+9OW+EOEn/QB96Utf+tKXj6RvlPvSl7705T6S\nvlHuS1/60pf7SPpGuS996Utf7iPpG+W+9KUvfbmPpG+U+9KXvvTlPpK+Ue5LX/rSl/tI+ka5L33p\nS1/uI+kb5b70pS99uY+kb5T70pe+9OU+kr5R7ktf+tKX+0jEvf6FbrcbEokEqVQKNpsNOp0Oq6ur\neOKJJ/Dzn/8c+/btQzweh1AohNfrxZUrV3D27FlcvHgRO3fuRKFQQKFQgM/nw6VLl/DYY4/h6tWr\n8Hg8aDQaiMViOHjwIF5//XWcOnUKN2/exIsvvogjR47g/fffx5/+6Z9CJpNtq44ejwftdhuZTAYu\nlwt6vR4rKys4f/48Xn31Vezfvx/xeBwCgQBerxdXr17FmTNnWMdisYhCoYDh4WFcvnwZZ8+exQcf\nfAC328067t+/H6+99hrOnTuH27dv48UXX8S5c+cwNzeHxx9/HBqNZlt1BACXywWVSoVEIgGr1Qq9\nXo/l5WWcP38ev/zlLzE1NYVUKgUAcDqduHbtGq/l2NgYisUiSqUSr+Xp06cxPT0Np9OJRqOBeDyO\nyclJXLp0CZ/61KewsLCAr371qzh37hyWlpZw7Ngx2Gy2bdVxcHAQQqEQyWQSdrsdBoMBS0tLuHDh\nAn7xi1/wWgJ39/YHH3zw3+7X0dFRXL58GWfOnMH09DRsNhsajQaSyST27t2Lixcv4tFHH8Xs7Cy+\n9KUv4ejRo5ifn8fzzz8PnU63rToCwMjICMrlMtLpNEZGRgAAhUIBe/fuxRtvvIGnn34ab7/9NiYn\nJ1lfo9GI2dlZnDx5Er/4xS9w7NgxLCwswG63QyAQIJFIYHx8HG+99RYee+wxXL58GZOTkwiFQjh5\n8iS++MUv4rXXXsNLL7207WcSALxeLzQaDSKRCKxWKwwGAxYWFvD444/jF7/4BQ4ePIhYLIZ2uw2X\ny4Vr167xc4+NjaFaraJQKMBut+P999/HyZMnce3aNXg8HrRaLSSTSRw4cACvv/46Tp48idu3b+OF\nF17A0aNHMT09jRdeeOGe9Oy5UY5EIpBKpWg0GggEAhCL7z7C7OwsBALB3YcSi9Fut1GpVCCVSlGr\n1aBQKCAUCtHpdNBut9FsNiGRSNBsNvnv6d+2Wi1oNBo4nU4sLy8DAPR6PYaHh1Gr1bZ9A4TDYYjF\nYjSbTQQCAYTDYbTbbUxPT6PT6aBaraLT6aDVaiGfz0MsFqPRaEAmk6FUKqFaraLdbqNQKEAoFCIQ\nCKBcLiOTyaDZbKJarSKZTEIgEEAsFkMgEKBer6NaraLZbKJQKPTEKKfTaeRyOdRqNdRqNcRiMbRa\nLSwuLvIaCgQCtFotlMtlyOVyVKtVSCQSFItF1Ot1vrxEIhGi0ShqtRry+TyazSbq9ToKhQKUSiUk\nEgmkUil/plqtRrFY3HajHAqFIJFI0Gg0sLm5iUgkgna7jdnZWQDgtaL1omeUy+Wo1WpoNBr87AaD\nAe12G3q9HkqlEpVKBQqFAp1OB2azGWq1GgqFAiKRCDqdDh6PB+l0uidGeW1tjffsysoKAEAqlWJ2\ndhZisRjZbBZisRiVSgWNRgNyuRxSqRRSqRS5XA5yuRwikQhqtZrXXiKRoFqtAgDK5TJkMhmvZafT\ngcFgwOTkZE/OJAAEg0GoVCrUajWUy2UEg0G0
Wi3MzMwAuLuWrVYLnU4H9XodarUaMpkMWq0WOp0O\nSqUSIpEIGo0GbrcblUoF7XYbnU4HIpEIUqkUrVYLSqUSGo0GSqUSACCRSKDValEqle5Jz56nL9rt\nNmq1GlqtFvR6PaxWK0QiEcxmM4C7RpUMr0gk4k1dr9fR6XT4izYQbQCSVqsFgUCASqWCUqmERqPB\nv7darUIkEm27jp1OB41Ggw8g6WixWCASiSCTySAQCCAUCiGTyVgP0pUuGZlMhkajAZVKxQaYjDBt\ngHq9jmaziU6ng1KphFgsxhfddkuj0eDfrdPpWD+j0QgAEIlEvGnFYjHq9TobOKlUynoqFAo0m03o\n9XrWjb5Pxq3dbvNaNptN5PN5KBSKbdex3W6jXq/zfrXZbFv2q0Qi4Z8l3YRCIRqNBkQiEYRCIb8j\n4K6hk8lkWxwJ4K5BoPcpEonYuZBKpduuI/DRnu10OrBYLDCbzZDJZDCbzeh0OpDL5VscoGaziVar\nhVarBZlMhnq9DgCo1+usR6PRgEAgQLvd5r3caDT4jLZaLRQKhS3vcLt1LJfLaDabvF/FYjGsViuA\nu+vX7RjSs4rFYmg0GgwMDGBychI7duzgs0vnktZZKBTyZUw2q91uo1wu37Pt+URyyqS4UCiERCKB\nXC6H0WiEWCxGqVRCpVLhRadbq1qtolgssleWyWRQr9eRz+f5+7VaDQKBgG9jOswA+OW0Wq1t16/7\nsAmFQojFYiiVSgwODvICCoVCSKVSqNVqqFQqNq70b7qNAd26arUaGo0GOp0ONpsNBoOBb/dWq4V2\nu90zg0zSvYmlUinkcjnsdjuEwo+2FunfvSnFYjFfUAaDAUqlEk6nE1qtFlqtFnq9HiqVCsBdI0yH\nqVKpIJVKIZ1O90S/7gtfJBLxfnU4HGw8RSIR5HI5dDode4HtdhsKhQIKhYI9ZbFYjLGxMQwMDECt\nVkMgEEAkEkGpVLIHJRAIIBAIOCKii6gX0n0uyet1uVwAwAZIo9Gwk0DPRpepUqmEVCrFwMAAdu3a\nBZPJxJcyvRv6PSKRCJ1OB8Vikfd9L4XWTiaTwel0QigUsqMklUpZR7pAzWYzjEYjVCoV1Go1SqUS\nTCYT5HI57wHylqVSKV9GxWIRqVQKxWJxy176n6Tn6QsAfGDJmJIhIQ+DvOFms8npC9qoFCqWSiX+\nDLp16VYjY6dQKPizpVIptFotarVaT3QUCAScqqAFI6PUarV444vFYqjVavasgLsGSy6XQygUQqfT\nYXx8HI1GA16vl8P9AwcOYGVlhTdEu92GUqmEx+Pp2San5xUIBKjVaryWdLjpZ0hPhUKBRqPBh1en\n00Gv12N8fBzVahUHDhxAu92GwWBApVKBVquFxWKBVqvlg1wul5FIJJDL5ZDP5+F0OrdVRzqoH1/L\n7r8nY006djodjnjo71QqFRwOB44ePQqZTIb19XVkMhlIJBLIZDJoNBpIpVJIJBK0Wi1ew3w+v636\nfVxXiriEQiE0Gs0WPcViMT9vp9NhL5EuJYVCAaPRiAceeABmsxnFYhHRaJT3ukKhgEwm+39dzuVy\nuSfptm4dq9UqX4DAR45Qt04KhQJ6vR5qtZrTFZRKpDQcRfYCgQASiQQikQgqlQoymQxisRgrKytQ\nKpUc1d6L9NwoUwjearVgtVphMpmwurrKeWaHw4FWq4VqtQq5XI5SqQStVstGp9lsQiaTYXBwEIuL\nizCZTADAaYB0Oo1oNIpgMIhYLIZyucxGI5FIQK1Wb7uO5AlTKGg0GrG4uIhUKoVarcapF8qZRiIR\nHDhwAEKhEEqlktMehw8fRi6Xw3PPPYdXX30VDzzwAABgYWEBhw8fxssvv8wGmW53iUTSEx0B8EFt\nt9swGo0wGAxYXl5GLpdDpVIBAI5WyGMgg2a32zEwMIAdO3bg8OHDEAqFOH/+PDQaDUZHRxGPx3Hj\nxg1MTU3h
3XffZYPeaDRQLBY5b7/dQhcKAJhMJhiNRiwvL3PU1ul0UKvVUKlUkE6nOddfLpdRLBZR\nLpdRr9chEolgt9vhcrn
2017-03-17 08:53:39 +03:00
"text/plain": [
"<matplotlib.figure.Figure at 0x299921933c8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def plot_images(images, subplot_shape):\n",
" plt.style.use('ggplot')\n",
" fig, axes = plt.subplots(*subplot_shape)\n",
" for image, ax in zip(images, axes.flatten()):\n",
" ax.imshow(image.reshape(28, 28), vmin=0, vmax=1.0, cmap='gray')\n",
" ax.axis('off')\n",
" plt.show()\n",
"\n",
"\n",
"noise = noise_sample(36)\n",
"images = G_output.eval({G_input: noise})\n",
"plot_images(images, subplot_shape=[6, 6])"
]
},
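{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an alternative to one subplot per digit, the 36 flattened generator outputs can be tiled into a single image array. The next cell is a NumPy sketch with a hypothetical helper, `tile_images`, and uses random data as a stand-in for `G_output.eval(...)`. With the trained model, you could pass `images` instead and display the result with `plt.imshow(grid, cmap='gray', vmin=0, vmax=1)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def tile_images(flat_images, rows, cols, img_h=28, img_w=28):\n",
"    # arrange rows * cols flattened images into one (rows*img_h, cols*img_w) array\n",
"    imgs = np.asarray(flat_images).reshape(rows, cols, img_h, img_w)\n",
"    # (rows, cols, h, w) -> (rows, h, cols, w) -> (rows*h, cols*w)\n",
"    return imgs.transpose(0, 2, 1, 3).reshape(rows * img_h, cols * img_w)\n",
"\n",
"fake_batch = np.random.rand(36, 28 * 28).astype(np.float32)  # stand-in for generator output\n",
"grid = tile_images(fake_batch, 6, 6)\n",
"print(grid.shape)"
]
},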
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Larger number of iterations should generate more realistic looking MNIST images. A sampling of such generated images are shown below.\n",
"\n",
"![](http://www.cntk.ai/jup/cntk206B_dcgan_result.jpg)\n",
"\n",
"**Note**: It takes a large number of iterations to capture a representation of the real world signal. Even simple dense networks can be quite effective in modelling data albeit MNIST is a relatively simple dataset as well."
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"**Suggested Task**\n",
"\n",
"- Please refer to several hacks presented in this [article](https://github.com/soumith/ganhacks) by Soumith Chintala, Facebook Research. While some of the hacks have been incorporated in this notebook, there are several others I would suggest that you try out.\n",
"\n",
"- Performance is a key aspect to deep neural networks training. Study how the changing the minibatch sizes impact the performance both with regards to quality of the generated images and the time it takes to train a model.\n",
"\n",
"- Try generating fake images using the CIFAR-10 data set as the training data. How does the network above performs? There are other variation in GAN, such as [conditional GAN](https://arxiv.org/pdf/1411.1784.pdf) where the network is additionally conditioned on the input label. Try implementing the labels.\n"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.4"
}
},
"nbformat": 4,
"nbformat_minor": 1
}