{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CNTK 201B: Hands On Labs Image Recognition"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This hands-on lab shows how to implement an image recognition task using a [convolution network][] with the CNTK v2 Python API. You will start with a basic feedforward CNN architecture to classify the CIFAR dataset, then keep adding advanced features to your network. Finally, you will implement a VGG net and a residual net similar to the one that won the ImageNet competition, but smaller in size.\n",
"\n",
"[convolution network]:https://en.wikipedia.org/wiki/Convolutional_neural_network\n",
"\n",
"## Introduction\n",
"\n",
"In this hands-on, you will practice the following:\n",
"\n",
"* Understanding the subset of the CNTK Python API needed for an image classification task.\n",
"* Writing a custom convolution network to classify the CIFAR dataset.\n",
"* Modifying the network structure by adding:\n",
"    * [Dropout][] layer.\n",
"    * Batch normalization layer.\n",
"* Implementing a [VGG][]-style network.\n",
"* Introduction to Residual Nets (RESNET).\n",
"* Implementing and training a [RESNET network][].\n",
"\n",
"[RESNET network]:https://github.com/Microsoft/CNTK/wiki/Hands-On-Labs-Image-Recognition\n",
"[VGG]:http://www.robots.ox.ac.uk/~vgg/research/very_deep/\n",
"[Dropout]:https://en.wikipedia.org/wiki/Dropout_(neural_networks)\n",
"\n",
"## Prerequisites\n",
"\n",
"The CNTK 201A hands-on lab, in which you download and prepare the CIFAR dataset, is a prerequisite for this lab. This tutorial depends on CNTK v2, so you will need to install CNTK v2 before starting this lab. Furthermore, all the tutorials in this lab are written in Python, so you will need a basic knowledge of Python.\n",
"\n",
"The CNTK 102 lab is recommended but is not a prerequisite for this tutorial. However, a basic understanding of deep learning is needed.\n",
"\n",
"## Dataset\n",
"\n",
"During this tutorial, you will use the CIFAR-10 dataset from https://www.cs.toronto.edu/~kriz/cifar.html. The dataset contains 50000 training images and 10000 test images; all images are 32x32x3 (32x32 pixels with 3 color channels). Each image belongs to one of the 10 classes shown below:\n",
"\n",
"<img src=\"https://cntk.ai/jup/201/cifar-10.png\" width=\"500\" height=\"500\">\n",
"\n",
"The above image is from: https://www.cs.toronto.edu/~kriz/cifar.html\n",
"\n",
"## Convolution Neural Network (CNN)\n",
"\n",
"A Convolution Neural Network (CNN) is a feedforward network composed of a sequence of layers, where the output of one layer is fed to the next layer (there are more complex architectures that skip layers; we will discuss one of those at the end of this lab). Usually, a CNN starts by alternating between convolution layers and pooling (downsampling) layers, and ends with fully connected layers for the classification part.\n",
"\n",
"### Convolution layer\n",
"\n",
"A convolution layer consists of multiple 2D convolution kernels applied to the input image or to the output of the previous layer; each convolution kernel outputs one feature map.\n",
"\n",
"<img src=\"https://cntk.ai/jup/201/Conv2D.png\">\n",
"\n",
"The stack of output feature maps is the input to the next layer.\n",
"\n",
"<img src=\"https://cntk.ai/jup/201/Conv2DFeatures.png\">\n",
"\n",
"> Gradient-Based Learning Applied to Document Recognition, Proceedings of the IEEE, 86(11):2278-2324, November 1998\n",
"> Y. LeCun, L. Bottou, Y. Bengio and P. Haffner\n",
"\n",
"#### In CNTK:\n",
"\n",
"Here is the [convolution][] layer in Python:\n",
"\n",
"```python\n",
"def Convolution(filter_shape, # e.g. (3,3)\n",
" num_filters, # e.g. 64\n",
" activation, # relu or None...etc.\n",
" init, # Random initialization\n",
" pad, # True or False\n",
" strides) # strides e.g. (1,1)\n",
"```\n",
"\n",
"[convolution]:https://www.cntk.ai/pythondocs/layerref.html#convolution\n",
"\n",
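"A minimal NumPy sketch of what a single convolution kernel computes, for illustration only (it is not CNTK's implementation, and `conv2d_single` is a hypothetical helper): a `3x3` kernel slides over a `5x5` input with a stride of 1 and no padding, and each output value is the dot product of the kernel with the image patch under it. Like most deep learning libraries, it computes cross-correlation, i.e. the kernel is not flipped. Without padding, the feature map shrinks to `3x3`; as we understand CNTK's `pad=True`, the layer instead zero-pads the input so that, with a stride of 1, the output keeps the input size.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def conv2d_single(image, kernel, stride=1):\n",
"    # 'valid' convolution: no padding, so the output is smaller than the input\n",
"    kh, kw = kernel.shape\n",
"    out_h = (image.shape[0] - kh) // stride + 1\n",
"    out_w = (image.shape[1] - kw) // stride + 1\n",
"    out = np.zeros((out_h, out_w))\n",
"    for i in range(out_h):\n",
"        for j in range(out_w):\n",
"            patch = image[i*stride:i*stride+kh, j*stride:j*stride+kw]\n",
"            out[i, j] = np.sum(patch * kernel)  # dot product of patch and kernel\n",
"    return out\n",
"\n",
"image = np.arange(25, dtype=float).reshape(5, 5)\n",
"kernel = np.ones((3, 3)) / 9.0  # 3x3 averaging kernel\n",
"feature_map = conv2d_single(image, kernel)\n",
"print(feature_map)  # 3x3 map with rows [6, 7, 8], [11, 12, 13], [16, 17, 18]\n",
"```\n",
"\n",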
"### Pooling layer\n",
"\n",
"In most CNN vision architectures, each convolution layer is followed by a pooling layer, so the two keep alternating until the fully connected layers.\n",
"\n",
"The purpose of the pooling layer is as follows:\n",
"\n",
"* Reduce the dimensionality of the previous layer's output, which speeds up the network.\n",
"* Provide limited translation invariance.\n",
"\n",
"Here is an example of max pooling with a stride of 2:\n",
"\n",
"<img src=\"https://cntk.ai/jup/201/MaxPooling.png\" width=\"200\" height=\"200\">\n",
"\n",
"#### In CNTK:\n",
"\n",
"Here are the [pooling][] layers in Python:\n",
"\n",
"```python\n",
"\n",
"# Max pooling\n",
"def MaxPooling(filter_shape, # e.g. (3,3)\n",
" strides, # (2,2)\n",
" pad) # True or False\n",
"\n",
"# Average pooling\n",
"def AveragePooling(filter_shape, # e.g. (3,3)\n",
" strides, # (2,2)\n",
" pad) # True or False\n",
"```\n",
"\n",
"[pooling]:https://www.cntk.ai/pythondocs/layerref.html#maxpooling-averagepooling\n",
"\n",
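"The following NumPy sketch illustrates max pooling (it is not CNTK's implementation, and `max_pool2d` is a hypothetical helper): a `2x2` window moves over a `4x4` input with a stride of 2 and keeps the largest value in each window, so each spatial dimension is halved. The models in this lab use overlapping `3x3` windows with a stride of 2, which the same function handles with `size=3`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def max_pool2d(x, size=2, stride=2):\n",
"    # unpadded pooling: only windows that fit entirely inside the input\n",
"    out_h = (x.shape[0] - size) // stride + 1\n",
"    out_w = (x.shape[1] - size) // stride + 1\n",
"    out = np.zeros((out_h, out_w), dtype=x.dtype)\n",
"    for i in range(out_h):\n",
"        for j in range(out_w):\n",
"            window = x[i*stride:i*stride+size, j*stride:j*stride+size]\n",
"            out[i, j] = window.max()  # keep the strongest activation\n",
"    return out\n",
"\n",
"x = np.array([[1, 1, 2, 4],\n",
"              [5, 6, 7, 8],\n",
"              [3, 2, 1, 0],\n",
"              [1, 2, 3, 4]])\n",
"print(max_pool2d(x))  # 2x2 result: [[6, 8], [3, 4]]\n",
"```\n",
"\n",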
"### Dropout layer\n",
"\n",
"The dropout layer takes a probability value as input, called the dropout rate. If the dropout rate is 0.5, then during training this layer picks, at random, 50% of the nodes from the previous layer and drops them out of the network. This behavior helps regularize the network.\n",
"\n",
"> Dropout: A Simple Way to Prevent Neural Networks from Overfitting\n",
"> Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, Ruslan Salakhutdinov\n",
"\n",
"\n",
"#### In CNTK:\n",
"\n",
"Dropout layer in Python:\n",
"\n",
"```python\n",
"\n",
"# Dropout\n",
"def Dropout(prob) # dropout rate e.g. 0.5\n",
"```\n",
"\n",
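"A minimal NumPy sketch of dropout at training time, assuming the common \"inverted dropout\" formulation: each node is kept with probability `1 - rate`, and the kept activations are scaled by `1 / (1 - rate)` so that their expected value is unchanged and nothing needs rescaling at test time, when dropout is disabled. This illustrates the idea and is not necessarily how CNTK implements it; `dropout_train` is a hypothetical helper.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def dropout_train(x, rate, rng):\n",
"    keep_mask = rng.random(x.shape) >= rate  # True with probability 1 - rate\n",
"    return x * keep_mask / (1.0 - rate)      # rescale the surviving nodes\n",
"\n",
"rng = np.random.default_rng(0)\n",
"activations = np.ones(10000)\n",
"dropped = dropout_train(activations, rate=0.5, rng=rng)\n",
"print((dropped == 0).mean())  # fraction of dropped nodes, close to 0.5\n",
"print(dropped.mean())         # close to 1.0 thanks to the rescaling\n",
"```\n",
"\n",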
"### Batch normalization (BN)\n",
"\n",
"Batch normalization is a way to make the input to each layer have zero mean and unit variance. BN helps the network converge faster and keeps the input of each layer around zero. BN has two learnable parameters, called gamma and beta, which scale and shift the normalized input; they let the network decide for itself whether the normalized input or the raw input works best.\n",
"\n",
"> Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift\n",
"> Sergey Ioffe, Christian Szegedy\n",
"\n",
"#### In CNTK:\n",
"\n",
"[Batch normalization][] layer in Python:\n",
"\n",
"```python\n",
"\n",
"# Batch normalization\n",
"def BatchNormalization(map_rank) # For image map_rank=1\n",
"```\n",
"\n",
"[Batch normalization]:https://www.cntk.ai/pythondocs/layerref.html#batchnormalization-layernormalization-stabilizer\n",
"\n",
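"A minimal NumPy sketch of the batch normalization forward pass during training (an illustration, not CNTK's implementation; `batch_norm_train` is a hypothetical helper and `eps` a small constant for numerical stability). For simplicity it normalizes a 2D `(batch, features)` array; for images, the statistics are typically computed per channel over the minibatch and all spatial positions, and at inference time running averages replace the minibatch statistics. The last lines show why gamma and beta let the network fall back to the raw input: setting them to the minibatch standard deviation and mean undoes the normalization.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def batch_norm_train(x, gamma, beta, eps=1e-5):\n",
"    # x: (batch, features); statistics are computed per feature over the minibatch\n",
"    mean = x.mean(axis=0)\n",
"    var = x.var(axis=0)\n",
"    x_hat = (x - mean) / np.sqrt(var + eps)  # zero mean, unit variance\n",
"    return gamma * x_hat + beta              # learnable scale (gamma) and shift (beta)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = rng.normal(loc=5.0, scale=3.0, size=(256, 4))\n",
"y = batch_norm_train(x, gamma=np.ones(4), beta=np.zeros(4))\n",
"print(y.mean(axis=0).round(6), y.std(axis=0).round(3))  # ~0 and ~1 per feature\n",
"\n",
"# gamma = sqrt(var + eps) and beta = mean undo the normalization entirely\n",
"restored = batch_norm_train(x, gamma=np.sqrt(x.var(axis=0) + 1e-5), beta=x.mean(axis=0))\n",
"print(np.allclose(restored, x))  # True\n",
"```\n",
"\n",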
"## Computational Network Toolkit (CNTK)\n",
"\n",
"CNTK is built on highly flexible computation graphs: each node takes tensors as inputs and produces tensors as the result of its computation. Each node is exposed in the Python API, which gives you the flexibility to create any custom graph. You can also define your own nodes in Python or C++, running on CPU, GPU, or both.\n",
"\n",
"For deep learning, you can use the low-level API directly, or you can use the CNTK layered API. In this lab, we will start with the low-level API, then switch to the layered API.\n",
"\n",
"So let's first import the needed modules for this lab."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import math\n",
"\n",
"from cntk.blocks import default_options\n",
"from cntk.layers import Convolution, MaxPooling, AveragePooling, Dropout, BatchNormalization, Dense\n",
"from cntk.models import Sequential, LayerStack\n",
"from cntk.io import MinibatchSource, ImageDeserializer, StreamDef, StreamDefs\n",
"from cntk.initializer import glorot_uniform, he_normal\n",
"from cntk import Trainer\n",
"from cntk.learner import momentum_sgd, learning_rate_schedule, UnitType, momentum_as_time_constant_schedule\n",
"from cntk.ops import cross_entropy_with_softmax, classification_error, relu, input_variable, softmax, element_times\n",
"from cntk.utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have imported the needed modules, let's implement our first CNN, as shown below:\n",
"\n",
"<img src=\"https://cntk.ai/jup/201/CNN.png\">\n",
"\n",
"Let's implement the above network using the CNTK layers API:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_basic_model(input, out_dims):\n",
" \n",
" net = Convolution((5,5), 32, init=glorot_uniform(), activation=relu, pad=True)(input)\n",
" net = MaxPooling((3,3), strides=(2,2))(net)\n",
"\n",
" net = Convolution((5,5), 32, init=glorot_uniform(), activation=relu, pad=True)(net)\n",
" net = MaxPooling((3,3), strides=(2,2))(net)\n",
"\n",
" net = Convolution((5,5), 64, init=glorot_uniform(), activation=relu, pad=True)(net)\n",
" net = MaxPooling((3,3), strides=(2,2))(net)\n",
" \n",
" net = Dense(64, init=glorot_uniform())(net)\n",
" net = Dense(out_dims, init=glorot_uniform(), activation=None)(net)\n",
" \n",
" return net"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To train the above model, we need two things:\n",
"* Read the training images and their corresponding labels.\n",
"* Define a cost function, compute the cost for each mini-batch and update the model weights according to the cost value.\n",
"\n",
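"To read the data in CNTK, we will use CNTK readers, which handle data augmentation and can fetch data in parallel. The image reader is driven by a map file: a text file in which each line contains the path of an image, a tab character, and the image's class label (an integer from 0 to 9 for CIFAR-10).\n",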
"\n",
"Example of a map text file:\n",
"\n",
" S:\\data\\CIFAR-10\\train\\00001.png\t9\n",
" S:\\data\\CIFAR-10\\train\\00002.png\t9\n",
" S:\\data\\CIFAR-10\\train\\00003.png\t4\n",
" S:\\data\\CIFAR-10\\train\\00004.png\t1\n",
" S:\\data\\CIFAR-10\\train\\00005.png\t1\n"
]
},
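{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of how one map-file line can be interpreted, assuming the tab-separated path/label format shown above; `parse_map_line` and the relative path are hypothetical, and the real parsing is done by CNTK's `ImageDeserializer`. As we understand it, the `labels` stream declared with `shape=num_classes` in `create_reader` delivers each label as a one-hot vector of that size, which is what the sketch builds.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def parse_map_line(line, num_classes=10):\n",
"    # each line: <image path> TAB <integer class label>\n",
"    path, label = line.rstrip('\\n').split('\\t')\n",
"    one_hot = np.zeros(num_classes, dtype=np.float32)\n",
"    one_hot[int(label)] = 1.0  # label as a one-hot vector\n",
"    return path, one_hot\n",
"\n",
"path, one_hot = parse_map_line('data/CIFAR-10/train/00003.png\\t4\\n')\n",
"print(path)              # data/CIFAR-10/train/00003.png\n",
"print(one_hot.argmax())  # 4\n",
"```"
]
},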
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# model dimensions\n",
"image_height = 32\n",
"image_width = 32\n",
"num_channels = 3\n",
"num_classes = 10\n",
"\n",
"#\n",
"# Define the reader for both training and evaluation action.\n",
"#\n",
"def create_reader(map_file, mean_file, train):\n",
" if not os.path.exists(map_file) or not os.path.exists(mean_file):\n",
"        raise RuntimeError(\"This tutorial depends on the 201A tutorial, please run 201A first.\")\n",
"\n",
" # transformation pipeline for the features has jitter/crop only when training\n",
" transforms = []\n",
" if train:\n",
" transforms += [\n",
" ImageDeserializer.crop(crop_type='randomside', side_ratio=0.8) # train uses data augmentation (translation only)\n",
" ]\n",
" transforms += [\n",
" ImageDeserializer.scale(width=image_width, height=image_height, channels=num_channels, interpolations='linear'),\n",
" ImageDeserializer.mean(mean_file)\n",
" ]\n",
" # deserializer\n",
" return MinibatchSource(ImageDeserializer(map_file, StreamDefs(\n",
" features = StreamDef(field='image', transforms=transforms), # first column in map file is referred to as 'image'\n",
" labels = StreamDef(field='label', shape=num_classes) # and second as 'label'\n",
" )))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let us write the training and validation loop."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"#\n",
"# Train and evaluate the network.\n",
"#\n",
"def train_and_evaluate(reader_train, reader_test, max_epochs, model_func):\n",
" # Input variables denoting the features and label data\n",
" input_var = input_variable((num_channels, image_height, image_width))\n",
" label_var = input_variable((num_classes))\n",
"\n",
" # Normalize the input\n",
" feature_scale = 1.0 / 256.0\n",
" input_var_norm = element_times(feature_scale, input_var)\n",
" \n",
" # apply model to input\n",
"    z = model_func(input_var_norm, out_dims=num_classes)\n",
"\n",
" #\n",
" # Training action\n",
" #\n",
"\n",
" # loss and metric\n",
" ce = cross_entropy_with_softmax(z, label_var)\n",
" pe = classification_error(z, label_var)\n",
"\n",
" # training config\n",
" epoch_size = 50000\n",
" minibatch_size = 64\n",
"\n",
" # Set training parameters\n",
" lr_per_minibatch = learning_rate_schedule([0.01]*10 + [0.003]*10 + [0.001], UnitType.minibatch, epoch_size)\n",
" momentum_time_constant = momentum_as_time_constant_schedule(-minibatch_size/np.log(0.9))\n",
" l2_reg_weight = 0.001\n",
" \n",
" # trainer object\n",
" learner = momentum_sgd(z.parameters, \n",
" lr = lr_per_minibatch, momentum = momentum_time_constant, \n",
" l2_regularization_weight=l2_reg_weight)\n",
" trainer = Trainer(z, ce, pe, [learner])\n",
"\n",
" # define mapping from reader streams to network inputs\n",
" input_map = {\n",
" input_var: reader_train.streams.features,\n",
" label_var: reader_train.streams.labels\n",
" }\n",
"\n",
" log_number_of_parameters(z) ; print()\n",
" progress_printer = ProgressPrinter(tag='Training')\n",
"\n",
" # perform model training\n",
" batch_index = 0\n",
" plot_data = {'batchindex':[], 'loss':[], 'error':[]}\n",
" for epoch in range(max_epochs): # loop over epochs\n",
" sample_count = 0\n",
" while sample_count < epoch_size: # loop over minibatches in the epoch\n",
" data = reader_train.next_minibatch(min(minibatch_size, epoch_size - sample_count), input_map=input_map) # fetch minibatch.\n",
" trainer.train_minibatch(data) # update model with it\n",
"\n",
" sample_count += data[label_var].num_samples # count samples processed so far\n",
" \n",
" # For visualization... \n",
" plot_data['batchindex'].append(batch_index)\n",
" plot_data['loss'].append(trainer.previous_minibatch_loss_average)\n",
" plot_data['error'].append(trainer.previous_minibatch_evaluation_average)\n",
" \n",
" progress_printer.update_with_trainer(trainer, with_metric=True) # log progress\n",
" batch_index += 1\n",
" progress_printer.epoch_summary(with_metric=True)\n",
" \n",
" #\n",
" # Evaluation action\n",
" #\n",
" epoch_size = 10000\n",
" minibatch_size = 16\n",
"\n",
" # process minibatches and evaluate the model\n",
" metric_numer = 0\n",
" metric_denom = 0\n",
" sample_count = 0\n",
" minibatch_index = 0\n",
"\n",
" while sample_count < epoch_size:\n",
" current_minibatch = min(minibatch_size, epoch_size - sample_count)\n",
"\n",
"        # Fetch the next test minibatch.\n",
" data = reader_test.next_minibatch(current_minibatch, input_map=input_map)\n",
"\n",
"        # Evaluate the model on this minibatch.\n",
" metric_numer += trainer.test_minibatch(data) * current_minibatch\n",
" metric_denom += current_minibatch\n",
"\n",
" # Keep track of the number of samples processed so far.\n",
" sample_count += data[label_var].num_samples\n",
" minibatch_index += 1\n",
"\n",
" print(\"\")\n",
" print(\"Final Results: Minibatch[1-{}]: errs = {:0.1f}% * {}\".format(minibatch_index+1, (metric_numer*100.0)/metric_denom, metric_denom))\n",
" print(\"\")\n",
" \n",
" # Visualize training result:\n",
" window_width = 32\n",
" loss_cumsum = np.cumsum(np.insert(plot_data['loss'], 0, 0)) \n",
" error_cumsum = np.cumsum(np.insert(plot_data['error'], 0, 0)) \n",
"\n",
" # Moving average.\n",
" plot_data['batchindex'] = np.insert(plot_data['batchindex'], 0, 0)[window_width:]\n",
" plot_data['avg_loss'] = (loss_cumsum[window_width:] - loss_cumsum[:-window_width]) / window_width\n",
" plot_data['avg_error'] = (error_cumsum[window_width:] - error_cumsum[:-window_width]) / window_width\n",
" \n",
" plt.figure(1)\n",
" plt.subplot(211)\n",
" plt.plot(plot_data[\"batchindex\"], plot_data[\"avg_loss\"], 'b--')\n",
" plt.xlabel('Minibatch number')\n",
" plt.ylabel('Loss')\n",
" plt.title('Minibatch run vs. Training loss ')\n",
"\n",
" plt.show()\n",
"\n",
" plt.subplot(212)\n",
" plt.plot(plot_data[\"batchindex\"], plot_data[\"avg_error\"], 'r--')\n",
" plt.xlabel('Minibatch number')\n",
" plt.ylabel('Label Prediction Error')\n",
" plt.title('Minibatch run vs. Label Prediction Error ')\n",
" plt.show()\n",
" \n",
" return softmax(z)"
]
},
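{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training parameters in `train_and_evaluate` encode two schedules. The plain-Python sketch below reproduces the arithmetic behind them; it does not call CNTK, the helper `lr_for_epoch` is hypothetical, and the semantics described are our understanding of this CNTK version. `learning_rate_schedule([0.01]*10 + [0.003]*10 + [0.001], UnitType.minibatch, epoch_size)` uses each list entry as a per-minibatch learning rate for `epoch_size` samples (one epoch) and keeps the last entry afterwards; with `max_epochs=5`, only `0.01` is used. `momentum_as_time_constant_schedule` expresses momentum as a time constant: the number of samples after which a gradient's contribution has decayed by a factor of `1/e`. A minibatch of 64 samples therefore sees a momentum of `exp(-64 / time_constant)`, which the chosen time constant makes exactly `0.9`.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"minibatch_size = 64\n",
"lr_values = [0.01]*10 + [0.003]*10 + [0.001]\n",
"\n",
"def lr_for_epoch(epoch):\n",
"    # each entry covers one epoch (epoch_size samples); the last one is reused\n",
"    return lr_values[min(epoch, len(lr_values) - 1)]\n",
"\n",
"# time constant (in samples) chosen in train_and_evaluate\n",
"time_constant = -minibatch_size / math.log(0.9)\n",
"# momentum seen by one minibatch: decay over minibatch_size samples\n",
"per_minibatch_momentum = math.exp(-minibatch_size / time_constant)\n",
"\n",
"print(lr_for_epoch(0), lr_for_epoch(15), lr_for_epoch(40))  # 0.01 0.003 0.001\n",
"print(round(time_constant, 1))           # 607.4\n",
"print(round(per_minibatch_momentum, 6))  # 0.9\n",
"```"
]
},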
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training 116906 parameters in 10 parameter tensors.\n",
"\n",
"Finished Epoch [1]: [Training] loss = 1.996150 * 50000, metric = 72.5% * 50000 5.085s (9833.1 samples per second)\n",
"Finished Epoch [2]: [Training] loss = 1.630085 * 50000, metric = 59.3% * 50000 5.113s (9779.7 samples per second)\n",
"Finished Epoch [3]: [Training] loss = 1.505027 * 50000, metric = 54.7% * 50000 5.147s (9713.6 samples per second)\n",
"Finished Epoch [4]: [Training] loss = 1.414553 * 50000, metric = 50.8% * 50000 5.105s (9795.0 samples per second)\n",
"Finished Epoch [5]: [Training] loss = 1.324908 * 50000, metric = 47.2% * 50000 5.089s (9825.3 samples per second)\n",
"\n",
"Final Results: Minibatch[1-626]: errs = 42.7% * 10000\n",
"\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAADeCAYAAADmUqAlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzt3Xmc1fP+wPHXu1JpUQpljagkiXJtiRCyJLvGzsVtwb3x\n49q3S+jal1y7iMG1ZkvWa8k6ISlFKSmVVNOeat6/P97fr/OdM2dmzixnznfOvJ+Px3mc7/79fM73\n1HnPZxVVxTnnnHMuLuplOwHOOeecc1EenDjnnHMuVjw4cc4551yseHDinHPOuVjx4MQ555xzseLB\niXPOOedixYMT55xzzsWKByfOOeecixUPTpxzzjkXKx6cuJwgIkUiclUlz50hIo9E1k8Lrte9+lJY\neSLSLkjPBdlOiyudiHwqIq9X8tynRWRydacpzXtXOt3OZYoHJy42IkFBkYjsVcoxs4L9o5N2afCq\njKIU51brvA4ikicif6/Oa7rSicijke9SWa9Hyr9a2qryHVTse5gNPoeJi50G2U6AcymsBE4ExkU3\nisi+wObAqhTnrA+sreT9OpH5H4YTgS7AnRm+jzP/Ad6KrG8DXAc8AHwY2T6tGu/Zi8r/0J8MSDWm\nxblazYMTF0evA8eJyPmqGg0aTgS+BDZKPkFV/6jszVR1TWXPra1EpImqrsh2OjJFVT8DPgvXRaQH\n8C/gE1V9Kp1riEhjVU0VCJd2z8oGx6jqusqe61wu8modFzcK5AOtgQPDjSKyHnAs8BQp/sJMbnMi\nItcE27YVkcdEZJGILBaRR0SkcdK5M0op3m8qIveLyAIRKRSRkSLSMuncI0TkVRGZLSKrRORHEblC\nROpFjnkPOAwI244Uicj0yP5GQXqniMhKEZkjIs+LyDYp8nl2cI9VIvK5iOxa3gcaqS7bR0RGiMg8\nYFaw7zER+SnFOdeISFHStiIRuUtE+ovIt0EaJorIweXcfxMRWSMiV6bY1zG47uBgvYGIXC0iU4PP\nYoGIfCgiB5SXz6oQkbki8qyIHCYiBSKyCjg12He2iLwrIvOCNH0rImemuEaxthsicnCQtyOCz3O2\niKwQkTdFpF3SucXanIhIp/BzCV7TgnuPE5FuKe59oohMDo75OshHpduxiEib4LsxP7jmVyKSl+K4\nU0VkvIgsDf59fSMigyL7G4rI9SLyQ3Cd30TkfyKyT2XS5eoOLzlxcTQD+BTIA94Mth0KbAA8DaTT\ndiMsXn8WmA5cAnQHzgLmAZemODZKgHuARcDVWNXPYGArYL/IcacDS4FbgWXA/lj1QXPgn8Ex1wMt\nsCqpfwTXXgYQBDGvBdfMB+4Izj0Q2BGIBg4nAc2wKgsNrv+8iLRP8y/vEcB84FqgSSTvqfJf2vZe\nwNHBtZYC5wPPichWqroo1U1Vdb6I/A84Hiu9iBqAVcc9G6xfiz2rB4AvsGe+K/bs3ik/i5WmwE7A\nSCxv/wG+C/YNDtLyIlb9dyTwkIioqj6adI1UrgZWAzdhQffFwGMU/x6V9nn/FWgM3AvUx575cyLS\nUVUVQESOBkZhpYr/xEoWnwDmlJGmUolIU+Aj7Pt6F/ALcALwpIg0U9UHg+P6Bfl4A7gf+2O3C7An\ncF9wuRuxf6/3AV9h/w52A3YGPqho2lwdoqr+8lcsXsBpwDrsh2gwsBhoFOx7Bng7WP4JGJ10bhFw\nVWT96mDbA0nHPQ/MT9r2E/BIUjqKsGqB+pHt/xek7/DItkYp8nEf9sO9XmTbK8D0FMeeEdzr/DI+\nl3bBMfOBDSLb+wXpOTSNz7UIeB+QpH2PlpKuq4F1KT7jlcDWkW1dg+2Dy0nD2UFad0jaPhF4K7L+\nVfKzrabvVo8gnaeWsv/XIH17p9iX6hm/C3ybtO0T4PXI+sHBPccnfY8uCu7VPrItH5gUWQ/bQc0G\nmka2Hxecu39k2xTgh2g6seC2KHrNMj6b5HT/
M7jHkZFtDbDg53egceR7Preca08Gnq3u5+mv3H95\ntY6Lq2exv+4PF5FmwOHAkxW8hmJ/0UV9CLQOrlmeB7R4icR9BMHAnzdQXR0ui0gzEWmN/dXZBNg+\njXscDfyGldKU52lVXRJZ/xArhWmfxrkKPKiqVe2Z8ZaqzvjzoqrfAkvSSMML2Gd3QrhBRLoAO2Cl\nYaHFQBcR2a6K6ayMyar6UfLGpGfcQkQ2wv7q7ywiDdO47kNJ36OwQW46z+1JVV2edO6fzzyo+usA\nPBpNp6q+hQUslXEIMFNVX4pcby1wN9ASCHvSLQZaiMj+ZVxrMbBTqipK58riwYmLJVVdALyNNYI9\nGvuuPleJS/2ctB5WPWxYXhKAH5PStBz7C3vrcJuI7CAiL4rIYuxH+jesSB2sCLs82wJTtHjD39LM\nSkrP4mCxvLyEZqR5XNppCCwqLw2q+jtWLXN8ZPMAYA1WXRK6CvsBnCoiE0RkuIh0rVqS01ai7Q1Y\nLzEReU9ElmN5nR+kU7Bqp/Ikf2aLgnPTeW6pziVybth2JVWvox9TbEtHO2Bqiu2TsXSH97wbmAm8\nJSIzReRBEemTdM7lQBtgWtAW5kYR2aGS6XJ1iAcnLs6ewkopBgJvqOrSSlyjtLYYVe62KSItsL+g\nuwJXYKU7fUi0Nanuf19VzcvKFNtKK0mpn4E0PA10FJGdgvXjgHdUdeGfiVH9EAvYzgC+xdpcjE/V\nADUDSnw+IrI9MBZoirWdOBR7xmFJVzrPuCqfWca+v1WlqnOw7/5RWLupPsBYEbkvcsy72PP8Kxbc\n/A34WkROqvkUu9rEgxMXZ2EDxN2xQKUmCVZcnthgDQU3JVEC0Rv7C/Y0Vb1HVV8P/jNeTEmlBQHT\ngE4iUlowkGmLsJKKZFtn4F4vYSUlJwQ9TjpibS2KUdXFqjpSVU8CtgQmANdkID3p6I+1tzhUVR9S\n1THBM45L9/OZwXuqarDKVo3NxJ5Nss7Y9zi8J6q6RlVHq+pgrKrpMeAcEdkscsxCVX1UVfOwBuVT\nsDZNzpXKgxMXW0E1ykDsh+mVLCThHBGJ9mgbjJUohN1F12FBTLTbcMPguGTLSV3N8zywMXBudSS4\nEqZh7QZ2DDeIyKZYj5RqpaqFWO+r47EqndXAy9FjRKRV0jkrsOqJRpFjNgi62qZTpVJVYclF9Bm3\nxgZNS0dGR19V1Z+wtiWnS6SLvFj37g6lnli217Fu7/0j12uAfUcXAx8H25KflWINnCF4XimOWYb1\nnmuEc2XwrsQubooVV6vqE6UdWAMaAu+IyLNY49ZBwIeq+mqwfxxW8vC4iNwVbDuZ1D9IBcDxInIr\n1i11WXCdx7HxNG4Tkd2xBo/NgAOAe1W1uoKy0qoBngZuBl4K8tAUCwinYL2mqtszWLfXwcCbSQ18\nASaJyPvY57UQ+As2vs1dkWOOwnoZnY59fpk0BhgGvCEiD2GlTOdgvWhKDAaYQk1Uv1yOfa4ficjj\nwCbYd/U7KvcH6L1Yl/unROQerN3LAOz7MDDS8HaUiDTCeoHNxkpOzgU+C4ImsLYmb2A9lhZh3YwP\nB4ZXIl2uDvHgxMVNOn9pphoToqrzmqS63rnY2CLXAuthvYX+HGNFVReKyGHYGCf/wv7zfQLrZvpm\n0vVGAN2wH9R/YEXjr6pqkYgcgv3AhI1/f8eClG/TyF+6+U55TJCHI4HbsCDlJ2yckY6UDE6qmgaA\n0VjbjqYU76UTuhM4AusK2wj7nC4Dbklxz4oq65yUeVDViSJyHPZ8b8V+hG/HSn1GpHGP0u5Z2ueY\n7rl/7lPV50TkFOBK7BlOwYLkwcBmKa9Qxr1VdbmI9MLGZTkDG3dnMnCSqkaf2WNYW5LBWND2KxYs\nXhs55nZsAMKDsef5E9aV+o400+XqKKl6z0LnnHNxE4wOO1VV+5d7sHMxk/U2JyJyqdgw3EvEhod+\nUURSNcYq
7fyeYkNjj89kOp1zLo7Ehvyvl7StLzaQ23vZSZVzVZP1khOxuSjysdEHG2DDHe8IdFbV\nVF0fo+e2wOqmfwDaqGo
"text/plain": [
"<matplotlib.figure.Figure at 0x6ab9b70>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAADeCAYAAADmUqAlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzsnWeYFFXWgN8DAgoKCkhUERERFQNgQEyYE2JGDKCuATCi\nu4Y14KqfYV0xY1gDYEBZI4KKomBAMQwioohKUkSSIkgOc74fp8qu6eme6e7pnumZOe/z9FNVt27d\nOrdrevr0uSeIquI4juM4jpMv1KhoARzHcRzHcaK4cuI4juM4Tl7hyonjOI7jOHmFKyeO4ziO4+QV\nrpw4juM4jpNXuHLiOI7jOE5e4cqJ4ziO4zh5hSsnjuM4juPkFa6cOI7jOI6TV7hy4qSFiBSKyI0Z\nXjtbRJ6MHPcJxuuYPQkzR0RaBfJcUdGyVGdEZLyITMnymEX+9vKZBJ+TA4O/ywOyeI+MP8eOUx64\nclINiSgFhSKyb5I+PwfnR8ad0uCVCYUJrs1q/QQR6SUil2VzTKd0gr+V+7M0XC5qaqQ0ZuRzUSgi\nG0TkFxEZIyIH5kCmZCSSNe33RESOEpGBJdyj3GuXiMjAuPc4/v1uUt4yOfnJRhUtgFOhrAJOBz6O\nNgb/iFsCqxNcswmwPsP7tcMUlFxyOrAzcF+O7+NUXd4GhgECtAb6A++JyNGqOqa8hVHV90VkE1Vd\nm+alR2Oy/yvBubJ8jsuKAn2BFQnO/VHOsjh5iisn1Zs3gFNE5FJVjSoNpwNfAI3jL8jgH2T02nWZ\nXltZEZG6qrqyouVw0uJ7VX0uPBCRV4EpwOVAQuVERASoraprciFQhp87yfJ42eQlVf09nQtEpA6w\nVhNUq83G58w/q/mFL+tUXxQYDjQCDgsbRaQWcDLwHAn+ucWvVYvITUFbGxEZIiJLROQPEXlSRDaO\nuzbZun89EXlURBaLyFIRGSoim8dde5yIjArM7KtF5EcRuV5EakT6jAOOAULfkUIRmRk5XyeQd7qI\nrBKReSLykoi0TjDP84N7rBaRz0Skc2lvaGS57AARGSwiC4Cfg3NDRGRWgmtuEpHCuLZCEblfRHqI\nyNeBDFNF5IhS7t9ERNaJyA0Jzu0QjNs/ON4oMLF/H7wXi0XkQxE5pLR5ZkoqzzCuf0cRmSAiK0Vk\npohcmKBPbRH5l4j8EIz5k4jcKSK1syW3qk4FFmNWlPC+4TM6XUSmYlbGI4JzIiKXB89slYjMF5FH\n4v+mg77Xiy2hrhCRd0VkpwR9EvqciMjeIvKGiPwuIstF5CsRuSQ49xRmNYkuVW2Ik//GuPH2EJE3\ng8/gnyIyVkT2jusT/o3vKyKDRGRhcO+XRaRR2m9uEiJz7ikit4rIXMzSspmInJ3sc5bBPBKO4VQ8\nbjmp3swGJgK9iP0iPBqoDzwPpOK7Ef6KGQHMBK4BOgLnAQuAaxP0jSLAg8ASYCC29NMf2AboFul3\nNvAncDewHDgYuBnYDLg66HMr0ABbkro8GHs5QPAFODoYczhwb3DtYcAuQFRxOAPYFHgkkPlq4CUR\n2U5VN1A6g4GFmDm9bmTuyXwJErXvD5wYjPUncCnwoohso6pLEt1UVReKyPvAqcAtcadPw8z4I4Lj\nf2HP6jHgc+yZd8ae3bulTzEjzqb0ZxjSEHteIzBF+VTgYRFZo6pD4C9rxevAvsCjwHdAB2AA0BZ7\n/8qMiGwBbAH8EHfqkECuBzHlZXbQ/hjQG3gSW15sDVwC7C4iXcO/IRG5BbgOGAW8ib33bwO1EohR\n5G9ERA7D5j4P+1ueD7QHjgUewN6PFsCh2N9zUitKMN5OwAfAUuAO7G/lQmC8iBygqp/HXfIA8Dtw\nE7At9p4/iP0vSYVGwfOLsl5Vl8a13QCsAe4C6gBrib0X0c9ZvWAeO6c5j2JjOHmCqvqrmr2APsAG\n7J9hf2ydt05w7gVgbLA/CxgZd20hcGPkeGDQ
9lhcv5eAhXFts4An4+QoBD4Fakba/x7Id2ykrU6C\neTyMfdnVirS9DsxM0Pec4F6XlvC+tAr6LATqR9q7B/IcncL7WgiMByTu3FNJ5BoIbEjwHq8Cto20\ndQja+5ciw/mBrDvFtU8F3okcfxn/bMv4N1UI3F9Kn1Sf4bhgDpdF2moBk4Bfw78V4ExgHdAlbswL\nguv3Sfa3V8o8HsMsio2BvYCxCeQpDO7dLu76/YJzPePaDwvaTwuOG2PWltfi+t0a9It+Tg4M7n9A\ncFwD+yEwA9ishLk8EP+3FSd/9HP8SvA31yrS1gz7kh+X4G/8rbjx7sYUh6TyRP7eC5O8vo2bcyGm\nENZO43OW7jyKjeGv/Hj5so4zAvt1f6yIbIr98no2zTEU+6UW5UPs19GmKVz/mBa1SDxMoAz8dYPI\nWr6IbBqYkD8KZN8xhXucCCzCft2VxvOquixy/CH2y3O7FK5V4L8a/AcsA++o6uy/BlX9GliWggwv\nY+9dz7Ah+DW5E2YNC/kD2FlEti+jnCmT5jNcjykJ4bXrsL+xJkCnoPlkYBrwvYg0Cl+YciMUtbyl\nw9+wv5WFmGWxC3C3qsY7WY9X1elxbSdj7+27cTJ9iVmLQpkOwxSuB+KuvzcF+fbArBX3quqfqU0p\nOYFV8TDgFVWdE7ar6nzMarVf3OdYiTybgA+BmpiCXxoKnIBZdaKvcxL0HaKJ/WOKfc4ynEc2PqtO\nDvBlnWqOqi4WkbGYE2w97FfZixkM9VPccbj0sAXB0koyEYAf42RaISK/Yv+Agb/Mzv+H/XOvH3d9\ngxTkawNM16KOv8kosvasqn8EFugtUrgWYub9spBo/XtJaTKo6m8i8i623BCGkZ6G/cp/JdL1RuBV\n7It9KvAW8HSgBOWENJ/hPFVdFdf2PaZ0bAt8hi3d7IgpEvEopshkwmuYEquYVeebBLJA4ufcFtgc\nU2xKkmmbYBv/t79YRBIu20VoE4z1TSn9UmVLTEH8PsG5adj/hK2D/ZD4v8/o5z0VPtTUHGJnp3Eu\nk3mUNL5Tgbhy4oD9qvgv0Bx4M8NfY8l8MUpc604FEWmArSP/AVyPmbRXY7+g7yD7jt1lnUuiL7Jk\nv85q5kCG54EnRWRXVZ0CnAK8G/0yUNUPRaQN0AM4HLMWDBCRC1U168nKcvQMawBfY/4Oid6XTB0c\n56rqeyn0S/Sca2C+VqcnkSmRIlUZydnnPY5E73Eq57IxvlOBuHLigP2ifhTYm8hyQDkh2K/N9/9q\nEKmHKUqjg6aDsF9kPVR1QqRfmwTjJVMCZgB7iUhNTc2pNdsswX5Rx7NtDu71KvY8ewZOhztgFosi\nqOofwFBgqIjUxUzzN2GOnNnmIFJ/hgAtxHJ7RL882mHPN3RengHsqqrjciBvpszAHGU/1pLDisNl\nh7ZEfr2LSGNKtz7MwD43uwAlKVGpLlcsAlZi72887THfjMoQyVJV5uHgocQOtoyCJUW6CXMoLW8u\nEJGootwfsyi8ERxvwP4ZR8OGawf94llB4mWelzCz78XZEDgDZgANRGSXsEFEmgPHZ/tGahEPY7Cl\nndOwaIfXon1EpGHcNSuxJYY6kT71RaSdiESXYDIlnWcI9sOpb6RvLSzqYhHmGAvmL7WViJwff7GI\nbBwoXOXNCEz2YqnhRaRmYEECc7Jdj0XxRBmQwj0mYQra5ZHxErEiuG+Jzy9Y6nwb6CEi4XITItIU\ni775UFVLWprNC6rKPBzDLSfVlyLmV1V9uqIEAWpjDoQjMB+Cftg/klHB+Y8xy8MwiaVIP5PEvwwL\ngFNF5G4sRHZ5MM4wLLxzUJDz4EMsXPgQ4CFVzZZSlsys/TxwJ/BqMId62JfvdCxqKtu8ADyDffmP\niXPwBfhWRMZj79fvwJ6YM2c0Bf0JWJTR2dj7VxqdReS6BO3jSO8ZgkXlXCUi22I+BKcBuwLnRyxf\nTxMLMe4G
TMCU2vbYUtbhxBSZckFVPxCRR4FrRGR37MtyHWa9OhkLCX858C35T9BvFKaI7wEcSeKl\nn7/+rlRVRaQfMBKYLJb
"text/plain": [
"<matplotlib.figure.Figure at 0x93c9a90>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data_path = os.path.join('data', 'CIFAR-10')\n",
"reader_train = create_reader(os.path.join(data_path, 'train_map.txt'), os.path.join(data_path, 'CIFAR-10_mean.xml'), True)\n",
"reader_test = create_reader(os.path.join(data_path, 'test_map.txt'), os.path.join(data_path, 'CIFAR-10_mean.xml'), False)\n",
"\n",
"pred = train_and_evaluate(reader_train, reader_test, max_epochs=5, model_func=create_basic_model)"
]
},
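{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter count reported in the log above can be checked by hand. A minimal sketch, assuming 32x32x3 inputs, zero-padded convolutions (`pad=True` keeps the spatial size) and unpadded `3x3`/stride-2 max pooling, so the spatial size goes 32 -> 15 -> 7 -> 3 before the dense layers; the helper names are hypothetical. Each of the five layers contributes one weight tensor and one bias tensor, which also accounts for the 10 parameter tensors.\n",
"\n",
"```python\n",
"def conv_params(k, c_in, c_out):\n",
"    return k * k * c_in * c_out + c_out  # kxk weights per channel pair, plus one bias per filter\n",
"\n",
"def dense_params(n_in, n_out):\n",
"    return n_in * n_out + n_out  # weight matrix plus biases\n",
"\n",
"def pooled_size(n, window=3, stride=2):\n",
"    return (n - window) // stride + 1  # unpadded pooling\n",
"\n",
"side = 32\n",
"for _ in range(3):\n",
"    side = pooled_size(side)  # 32 -> 15 -> 7 -> 3\n",
"\n",
"total = (conv_params(5, 3, 32) + conv_params(5, 32, 32) + conv_params(5, 32, 64)\n",
"         + dense_params(64 * side * side, 64) + dense_params(64, 10))\n",
"print(side, total)  # 3 116906\n",
"```"
]
},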
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Although this model is very simple, it still has too much code; we can do better. Here is the same model in a more terse format:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_basic_model_terse(input, out_dims):\n",
"\n",
" with default_options(activation=relu):\n",
" model = Sequential([\n",
" LayerStack(3, lambda i: [\n",
" Convolution((5,5), [32,32,64][i], init=glorot_uniform(), pad=True),\n",
" MaxPooling((3,3), strides=(2,2))\n",
" ]),\n",
" Dense(64, init=glorot_uniform()),\n",
" Dense(out_dims, init=glorot_uniform(), activation=None)\n",
" ])\n",
"\n",
" return model(input)"
]
},
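{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the terse model, `LayerStack(3, lambda i: [...])` calls the lambda once for each `i` in `0, 1, 2` and chains the returned layers, which is why `[32,32,64][i]` picks a different number of filters per block. The plain-Python sketch below mimics that expansion with layer descriptions instead of real CNTK layers; the helper `expand_stack` and the tuple format are hypothetical, and this reflects our reading of `LayerStack` rather than its source. It shows that the terse model yields the same layer sequence as `create_basic_model`.\n",
"\n",
"```python\n",
"def expand_stack(n, make_block):\n",
"    # call make_block(i) for i = 0..n-1 and concatenate the returned layer lists\n",
"    layers = []\n",
"    for i in range(n):\n",
"        layers.extend(make_block(i))\n",
"    return layers\n",
"\n",
"blocks = expand_stack(3, lambda i: [\n",
"    ('Convolution', (5, 5), [32, 32, 64][i]),\n",
"    ('MaxPooling', (3, 3), (2, 2)),\n",
"])\n",
"for layer in blocks:\n",
"    print(layer)  # Convolution 32, pool, Convolution 32, pool, Convolution 64, pool\n",
"```"
]
},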
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training 116906 parameters in 10 parameter tensors.\n",
"\n",
"Finished Epoch [1]: [Training] loss = 2.087211 * 50000, metric = 76.8% * 50000 5.140s (9728.0 samples per second)\n",
"Finished Epoch [2]: [Training] loss = 1.703542 * 50000, metric = 63.1% * 50000 5.124s (9757.6 samples per second)\n",
"Finished Epoch [3]: [Training] loss = 1.555579 * 50000, metric = 57.4% * 50000 5.128s (9750.6 samples per second)\n",
"Finished Epoch [4]: [Training] loss = 1.461751 * 50000, metric = 53.3% * 50000 5.176s (9659.5 samples per second)\n",
"Finished Epoch [5]: [Training] loss = 1.382035 * 50000, metric = 49.7% * 50000 5.129s (9747.7 samples per second)\n",
"Finished Epoch [6]: [Training] loss = 1.305139 * 50000, metric = 46.8% * 50000 5.176s (9659.5 samples per second)\n",
"Finished Epoch [7]: [Training] loss = 1.243207 * 50000, metric = 44.0% * 50000 5.264s (9498.1 samples per second)\n",
"Finished Epoch [8]: [Training] loss = 1.183436 * 50000, metric = 41.6% * 50000 5.378s (9297.5 samples per second)\n",
"Finished Epoch [9]: [Training] loss = 1.139780 * 50000, metric = 40.0% * 50000 5.342s (9359.1 samples per second)\n",
"Finished Epoch [10]: [Training] loss = 1.097680 * 50000, metric = 38.7% * 50000 5.328s (9383.8 samples per second)\n",
"\n",
"Final Results: Minibatch[1-626]: errs = 34.2% * 10000\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x9ff6c18>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0xa0ed9b0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pred_basic_model = train_and_evaluate(reader_train, reader_test, max_epochs=10, model_func=create_basic_model_terse)"
]
},
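{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can check the `Training 116906 parameters in 10 parameter tensors` line above by hand. The cell below is a sketch that counts weights and biases layer by layer. It assumes three things: `pad=True` keeps the spatial size of each convolution, `MaxPooling((3,3), strides=(2,2))` does not pad, and each convolution and dense layer has one weight tensor and one bias tensor, which gives 10 tensors for five layers. The helpers `conv_params`, `dense_params` and `pooled_size` are hypothetical names and are not part of the CNTK API. Under these assumptions the 32x32 input shrinks to 15x15, then 7x7, and finally 3x3. The first dense layer therefore sees 3 * 3 * 64 = 576 inputs, and the total is 116906."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def conv_params(k, c_in, c_out):\n",
"    # A k x k kernel for every (input, output) channel pair, plus one bias per output channel\n",
"    return k * k * c_in * c_out + c_out\n",
"\n",
"def dense_params(n_in, n_out):\n",
"    # Full weight matrix plus one bias per output unit\n",
"    return n_in * n_out + n_out\n",
"\n",
"def pooled_size(n, window=3, stride=2):\n",
"    # Output length of an unpadded pooling window along one axis\n",
"    return (n - window) // stride + 1\n",
"\n",
"size, channels, total = 32, 3, 0\n",
"for c_out in [32, 32, 64]:\n",
"    total += conv_params(5, channels, c_out)  # 5x5 convolution, pad=True keeps the size\n",
"    channels = c_out\n",
"    size = pooled_size(size)  # 3x3 max pooling, stride 2\n",
"\n",
"flat = size * size * channels\n",
"total += dense_params(flat, 64) + dense_params(64, 10)\n",
"print('final size: {}, flattened inputs: {}, parameters: {}'.format(size, flat, total))"
]
},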
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have a trained model, let's classify the following image:\n",
"\n",
"<img src=\"https://cntk.ai/jup/201/00014.png\" width=64 height=64>\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"def eval(pred_op, image_path):\n",
"    label_lookup = [\"airplane\", \"automobile\", \"bird\", \"cat\", \"deer\", \"dog\", \"frog\", \"horse\", \"ship\", \"truck\"]\n",
"    image_mean = 133.0\n",
"    image_data = np.array(Image.open(image_path), dtype=np.float32)\n",
"    image_data -= image_mean\n",
"    # Reorder from (height, width, channels) to the (channels, height, width) layout the model expects\n",
"    image_data = np.ascontiguousarray(np.transpose(image_data, (2, 0, 1)))\n",
"\n",
"    result = np.squeeze(pred_op.eval({pred_op.arguments[0]:[image_data]}))\n",
"\n",
"    # Print the top 3 predictions, highest score first:\n",
"    top_count = 3\n",
"    result_indices = (-np.array(result)).argsort()[:top_count]\n",
"\n",
"    print(\"Top 3 predictions:\")\n",
"    for i in range(top_count):\n",
"        print(\"\\tLabel: {:10s}, confidence: {:.2f}%\".format(label_lookup[result_indices[i]], result[result_indices[i]] * 100))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Top 3 predictions:\n",
"\tLabel: truck , confidence: 96.60%\n",
"\tLabel: cat , confidence: 1.11%\n",
"\tLabel: ship , confidence: 0.96%\n"
]
}
],
"source": [
"eval(pred_basic_model, \"data/CIFAR-10/test/00014.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Add a dropout layer, with a drop rate of 0.25, before the last dense layer:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_basic_model_with_dropout(input, out_dims):\n",
"\n",
"    with default_options(activation=relu):\n",
"        model = Sequential([\n",
"            LayerStack(3, lambda i: [\n",
"                Convolution((5,5), [32,32,64][i], init=glorot_uniform(), pad=True),\n",
"                MaxPooling((3,3), strides=(2,2))\n",
"            ]),\n",
"            Dense(64, init=glorot_uniform()),\n",
"            Dropout(0.25),\n",
"            Dense(out_dims, init=glorot_uniform(), activation=None)\n",
"        ])\n",
"\n",
"    return model(input)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training 116906 parameters in 10 parameter tensors.\n",
"\n",
"Finished Epoch [1]: [Training] loss = 2.121308 * 50000, metric = 78.6% * 50000 5.960s (8388.7 samples per second)\n",
"Finished Epoch [2]: [Training] loss = 1.810775 * 50000, metric = 67.2% * 50000 6.054s (8259.0 samples per second)\n",
"Finished Epoch [3]: [Training] loss = 1.656878 * 50000, metric = 61.4% * 50000 6.098s (8199.4 samples per second)\n",
"Finished Epoch [4]: [Training] loss = 1.565552 * 50000, metric = 57.8% * 50000 6.035s (8285.1 samples per second)\n",
"Finished Epoch [5]: [Training] loss = 1.494175 * 50000, metric = 54.6% * 50000 5.960s (8388.8 samples per second)\n",
"\n",
"Final Results: Minibatch[1-626]: errs = 48.7% * 10000\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x69b0710>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0xa1dff98>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pred_basic_model_dropout = train_and_evaluate(reader_train, reader_test, max_epochs=5, model_func=create_basic_model_with_dropout)"
]
},
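{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate roughly what the dropout layer does during training, the cell below gives a NumPy sketch of *inverted* dropout. Each activation is zeroed with probability `drop_rate`, and the surviving activations are scaled by `1/(1 - drop_rate)`, so no rescaling is needed at test time. This is a generic sketch and not CNTK's implementation. The function name `inverted_dropout` is hypothetical. Note also that the dropout model above was trained for 5 epochs rather than 10, so its test error is not directly comparable with the basic model's."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def inverted_dropout(x, drop_rate, rng, training=True):\n",
"    # At inference time dropout is the identity function.\n",
"    if not training or drop_rate == 0.0:\n",
"        return x\n",
"    keep_prob = 1.0 - drop_rate\n",
"    # Keep each activation independently with probability keep_prob.\n",
"    mask = rng.rand(*x.shape) < keep_prob\n",
"    # Rescale the survivors so the expected activation is unchanged.\n",
"    return x * mask / keep_prob\n",
"\n",
"rng = np.random.RandomState(0)\n",
"x = np.ones((1000, 64), dtype=np.float32)\n",
"y = inverted_dropout(x, 0.25, rng)\n",
"print('fraction dropped: {:.3f}'.format(np.mean(y == 0)))\n",
"print('mean activation:  {:.3f}'.format(y.mean()))"
]
},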
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Add batch normalization after each convolution and before the last dense layer:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_basic_model_with_batch_normalization(input, out_dims):\n",
"\n",
"    with default_options(activation=relu):\n",
"        model = Sequential([\n",
"            LayerStack(3, lambda i: [\n",
"                Convolution((5,5), [32,32,64][i], init=glorot_uniform(), pad=True),\n",
"                BatchNormalization(map_rank=1),\n",
"                MaxPooling((3,3), strides=(2,2))\n",
"            ]),\n",
"            Dense(64, init=glorot_uniform()),\n",
"            BatchNormalization(map_rank=1),\n",
"            Dense(out_dims, init=glorot_uniform(), activation=None)\n",
"        ])\n",
"\n",
"    return model(input)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training 117290 parameters in 18 parameter tensors.\n",
"\n",
"Finished Epoch [1]: [Training] loss = 1.572498 * 50000, metric = 56.4% * 50000 5.768s (8667.9 samples per second)\n",
"Finished Epoch [2]: [Training] loss = 1.235720 * 50000, metric = 43.7% * 50000 5.817s (8595.0 samples per second)\n",
"Finished Epoch [3]: [Training] loss = 1.096666 * 50000, metric = 38.5% * 50000 5.811s (8604.3 samples per second)\n",
"Finished Epoch [4]: [Training] loss = 1.021178 * 50000, metric = 35.6% * 50000 5.837s (8566.3 samples per second)\n",
"Finished Epoch [5]: [Training] loss = 0.965618 * 50000, metric = 33.6% * 50000 5.765s (8672.7 samples per second)\n",
"\n",
"Final Results: Minibatch[1-626]: errs = 32.3% * 10000\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0xa143e10>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
2016-11-29 05:56:43 +03:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAADeCAYAAADmUqAlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzsnWeYFFXWgN8DEkQERQQMoKiYI5gwh8+sa1xxFBVdE6Di\nqKvrGnDNa44oRsQwiophTSiIigHDkAyIARCQnJHMzPl+nCq7uqd7prunZ6Zn5rzPU0/VzedWdXed\nvvfcc0VVcRzHcRzHyRca1LQAjuM4juM4UVw5cRzHcRwnr3DlxHEcx3GcvMKVE8dxHMdx8gpXThzH\ncRzHyStcOXEcx3EcJ69w5cRxHMdxnLzClRPHcRzHcfIKV04cx3Ecx8krXDlxMkJESkXkhizLThaR\npyPhs4P6OudOwuwRkc0CeS6vaVnqMyLysYiMy3GdcZ+9fCbJ9+TA4HN5QA7byPp77DjVgSsn9ZCI\nUlAqIvukyDM1SH8rIUmDIxtKk5TN6f4JIlIgIn1yWadTMcFn5cEcVVcVe2qkVWfke1EqIiUi8oeI\nDBGRA6tAplQkkzXjeyIiR4lI33LaqPa9S0Skb8I9TrzfbapbJic/WaumBXBqlOXA6cAX0cjgh3gT\nYEWSMmsDa7JsbxtMQalKTgd2AB6o4nacussHwEBAgI5AL+AjETlaVYdUtzCq+omIrK2qqzIsejQm\n+3+SpFXme1xZFLgIWJokbWE1y+LkKa6c1G/eBf4uIpeqalRpOB34FmidWCCLH8ho2dXZlq2tiEgz\nVV1W03I4GfGzqr4YBkTkDWAccBmQVDkREQEaq+rKqhAoy++d5Li+XPKaqs7PpICINAFWaZLdanPx\nPfPvan7h0zr1FwWKgA2Aw8JIEWkEnAK8SJIft8S5ahG5MYjbUkQGiMgCEVkoIk+LSNOEsqnm/dcR\nkf4iMldEFonIsyKyXkLZv4nI28Ew+woR+VVErhORBpE8w4FjgNB2pFREJkbSmwTyThCR5SIyXURe\nE5GOSfp5ftDGChH5WkR2r+iGRqbLDhCRfiIyC5gapA0QkUlJytwoIqUJcaUi8qCIHC8i3wUyfC8i\nR1TQfhsRWS0i1ydJ2zqot1cQXisYYv85uBdzRWSEiBxaUT+zJZ1nmJC/s4h8LiLLRGSiiFyYJE9j\nEfmPiPwS1DlFRP4rIo1zJbeqfg/MxUZRwnbDZ3S6iHyPjTIeEaSJiFwWPLPlIjJTRB5L/EwHea8T\nm0JdKiLDRGT7JHmS2pyIyF4i8q6IzBeRP0VkrIhcEqQ9g42aRKeqShLkvyGhvt1E5L3gO7hERIaK\nyF4JecLP+D4icq+IzA7aHiwiG2R8c1MQ6XM3EblFRKZhIy3rikiPVN+zLPqRtA6n5vGRk/rNZGAk\nUEDsH+HRQAvgJSAd243wX8wgYCLwL6AzcB4wC7gmSd4oAjwMLAD6YlM/vYAOwMGRfD2AJcA9wJ/A\nIcBNwLrA1UGeW4CW2JTUZUHdfwIEL8B3gjqLgPuDsocBOwJRxeEMoDnwWCDz1cBrIrKFqpZQMf2A\n2dhwerNI31PZEiSL3x84KahrCXAp8KqIdFDVBckaVdXZIvIJcCpwc0Lyadgw/qAg/B/sWT0OfIM9\n892xZzes4i5mRQ8qfoYhrbDnNQhTlE8FHhWRlao6AP4arfgfsA/QH/gJ2AkoBDph96/SiMj6wPrA\nLwlJhwZyPYwpL5OD+MeBs4CnsenFjsAlwK4ism/4GRKRm4FrgbeB97B7/wHQKIkYcZ8RETkM6/t0\n7LM8E9gOOBZ4CLsfGwP/h32eU46iBPVtD3wKLALuwD4rFwIfi8gBqvpNQpGHgPnAjcDm2D1/GPst\nSYcNgucXZY2qLkqIux5YCdwFNAFWEbsX0e/ZOkE/dsiwH2XqcPIEVfWjnh3A2UAJ9mPYC5vnbRKk\nvQwMDa4nAW8llC0FboiE+wZxjyfkew2YnRA3
CXg6QY5S4CugYST+ykC+YyNxTZL041HsZdcoEvc/\nYGKSvOcEbV1azn3ZLMgzG2gRiT8ukOfoNO5rKfAxIAlpz6SQqy9QkuQeLwc2j8TtFMT3qkCG8wNZ\nt0+I/x74MBIenfhsK/mZKgUerCBPus9weNCHPpG4RsAoYEb4WQG6A6uBrgl1XhCU3zvVZ6+CfjyO\njSi2BvYEhiaRpzRoe5uE8vsFad0S4g8L4k8Lwq2x0ZY3E/LdEuSLfk8ODNo/IAg3wP4I/AasW05f\nHkr8bCXIH/0evx585jaLxLXDXvLDk3zG30+o7x5McUgpT+TzXpri+DGhz6WYQtg4g+9Zpv0oU4cf\n+XH4tI4zCPt3f6yINMf+eb2QYR2K/VOLMgL7d9Q8jfKPa/yIxKMEysBfDUTm8kWkeTCE/Fkg+7Zp\ntHESMAf7d1cRL6nq4kh4BPbPc4s0yirwhAa/gJXgQ1Wd/Felqt8Bi9OQYTB277qFEcG/ye2x0bCQ\nhcAOIrJVJeVMmwyf4RpMSQjLrsY+Y22ALkH0KcB44GcR2SA8MOVGiB95y4R/YJ+V2djIYlfgHlVN\nNLL+WFUnJMSdgt3bYQkyjcZGi0KZDsMUrocSyt+fhny7YaMV96vqkvS6lJpgVPEw4HVV/T2MV9WZ\n2KjVfgnfYyXybAJGAA0xBb8iFDgRG9WJHuckyTtAk9vHlPmeZdmPXHxXnSrAp3XqOao6V0SGYkaw\n62D/yl7NoqopCeFw6mF9gqmVVCIAvybItFREZmA/wMBfw863Yj/uLRLKt0xDvi2BCRpv+JuKuLln\nVV0YjECvn0ZZiA3vV4Zk898LKpJBVeeJyDBsuiFcRnoa9i//9UjWG4A3sBf798D7wHOBElQlZPgM\np6vq8oS4nzGlY3Pga2zqZltMkUhEMUUmG97ElFjFRnV+SCILJH/OnYD1MMWmPJk6BOfEz/5cEUk6\nbRdhy6CuHyrIly4bYgriz0nSxmO/Ce2D65DEz2f0+54OIzQ9g9jJGaRl04/y6ndqEFdOHLB/FU8A\nGwHvZflvLJUtRrlz3ekgIi2xeeSFwHXYkPYK7B/0HeTesLuyfUn2Ikv176xhFcjwEvC0iOysquOA\nvwPDoi8DVR0hIlsCxwOHY6MFhSJyoarm3FlZFT3DBsB3mL1DsvuSrYHjNFX9KI18yZ5zA8zW6vQU\nMiVTpGojVfZ9TyDZPU4nLRf1OzWIKycO2D/q/sBeRKYDqgnB/m1+8leEyDqYovROEHUQ9o/seFX9\nPJJvyyT1pVICfgP2FJGGmp5Ra65ZgP2jTmTzKmjrDex5dguMDrfGRiziUNWFwLPAsyLSDBuavxEz\n5Mw1B5H+MwTYWMy3R/TlsQ32fEPj5d+AnVV1eBXImy2/YYayX2j5y4rDaYdORP69i0hrKh59+A37\n3uwIlKdEpTtdMQdYht3fRLbDbDNqw0qWutIPB19K7GDTKJhTpBsxg9Lq5gIRiSrKvbARhXeDcAn2\nYxxdNtw4yJfIUpJP87yGDftenAuBs+A3oKWI7BhGiMhGwAm5bkhtxcMQbGrnNGy1w5vRPCLSKqHM\nMmyKoUkkTwsR2UZEolMw2ZLJMwT743RRJG8jbNXFHMwwFsxealMROT+xsIg0DRSu6mYQJnsZ1/Ai\n0jAYQQIzsl2DreKJUphGG6MwBe2ySH3JWBq0W+7zC6Y6PwCOF5FwugkRaYutvhmhquVNzeYFdaUf\njuEjJ/WXuOFXVX2upgQBGmMGhIMwG4Ke2A/J20H6F9jIw0CJuUjvTvJ/hsXAqSJyD7ZE9s+gnoHY\n8s57A58HI7DlwocCj6hqrpSyVMPaLwH/Bd4I+rAO9vKdgK2ayjUvA89jL/8hCQa+AD+KyMfY/ZoP\n7IEZc0Zd0J+IrTLqgd2/ithdRK5NEj+czJ4h2Kqcq0Rkc8yG4DRgZ+D8yMjXc8SWGB8MfI4ptdth\nU1mHE1Nk
qgVV/VRE+gP/EpFdsZflamz06hRsSfjgwLbk7iDf25givhtwJMmnfv76XKmqikhP4C1g\njJhPkxnYd2d7VT0qyFo
"text/plain": [
"<matplotlib.figure.Figure at 0xa2e2710>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pred_basic_model_bn = train_and_evaluate(reader_train, reader_test, max_epochs=5, model_func=create_basic_model_with_batch_normalization)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's implement a VGG-inspired network using the layers API. Here is the architecture:\n",
"\n",
"| VGG9 |\n",
"| ------------- |\n",
"| conv3-64 |\n",
"| conv3-64 |\n",
"| max3 |\n",
"| |\n",
"| conv3-96 |\n",
"| conv3-96 |\n",
"| max3 |\n",
"| |\n",
"| conv3-128 |\n",
"| conv3-128 |\n",
"| max3 |\n",
"| |\n",
"| FC-1024 |\n",
"| FC-1024 |\n",
"| |\n",
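"| FC-10 |\n",
"\n",
"Here `conv3-N` is a 3x3 convolution with N filters and zero padding, `max3` is a 3x3 max pooling with stride 2, and `FC-N` is a fully connected (dense) layer with N units. All convolution and FC layers except the last one use ReLU activation.\n"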
]
},
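A quick way to see how the feature maps shrink through this architecture is to trace the spatial size by hand. The sketch below assumes a 32x32 input (the Cifar image size), that the padded 3x3 convolutions (`pad=True`) keep the spatial size unchanged, and that the 3x3 max pooling with stride 2 uses no padding; `pool_out` and `vgg9_feature_sizes` are illustrative helpers, not CNTK functions.

```python
def pool_out(size, window=3, stride=2):
    # Output size of an unpadded pooling window: floor((n - k) / s) + 1
    return (size - window) // stride + 1

def vgg9_feature_sizes(input_size=32):
    # Spatial size after each conv3-conv3-max3 block (assumed padded convs)
    sizes = [input_size]
    size = input_size
    for _ in range(3):
        # the two padded 3x3 convolutions leave `size` unchanged,
        # only the max pooling shrinks it
        size = pool_out(size)
        sizes.append(size)
    return sizes

sizes = vgg9_feature_sizes()
print(sizes)
print(sizes[-1] ** 2 * 128)  # number of inputs to the first FC-1024 layer
```

With these assumptions the spatial size goes 32, 15, 7, 3, so the first FC-1024 layer receives 3 x 3 x 128 = 1152 values per image.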
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_vgg9_model(input, out_dims):\n",
"    with default_options(activation=relu):\n",
"        model = Sequential([\n",
"            # 3 blocks of two padded 3x3 convolutions (64, 96, then 128 filters)\n",
"            # followed by 3x3 max pooling with stride 2\n",
"            LayerStack(3, lambda i: [\n",
"                Convolution((3,3), [64,96,128][i], init=glorot_uniform(), pad=True),\n",
"                Convolution((3,3), [64,96,128][i], init=glorot_uniform(), pad=True),\n",
"                MaxPooling((3,3), strides=(2,2))\n",
"            ]),\n",
"            # 2 fully connected layers with 1024 units each\n",
"            LayerStack(2, lambda : [\n",
"                Dense(1024, init=glorot_uniform())\n",
"            ]),\n",
"            # output layer: one score per class, no activation\n",
"            Dense(out_dims, init=glorot_uniform(), activation=None)\n",
"        ])\n",
"\n",
"    return model(input)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training 2675978 parameters in 18 parameter tensors.\n",
"\n",
"Finished Epoch [1]: [Training] loss = 2.258327 * 50000, metric = 83.0% * 50000 11.357s (4402.7 samples per second)\n",
"Finished Epoch [2]: [Training] loss = 1.928672 * 50000, metric = 72.0% * 50000 11.328s (4414.0 samples per second)\n",
"Finished Epoch [3]: [Training] loss = 1.709348 * 50000, metric = 63.5% * 50000 11.514s (4342.5 samples per second)\n",
"Finished Epoch [4]: [Training] loss = 1.571040 * 50000, metric = 57.7% * 50000 12.057s (4147.1 samples per second)\n",
"Finished Epoch [5]: [Training] loss = 1.467748 * 50000, metric = 53.6% * 50000 11.984s (4172.3 samples per second)\n",
"\n",
"Final Results: Minibatch[1-626]: errs = 49.6% * 10000\n",
"\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAADeCAYAAADmUqAlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzt3XeYVOXZx/HvDVIEBRTFktgVxYaCJfYae0nsK0bUGAsx\nKhpN1NjeGFtiF4jGRixrjcYWC7HHDtaIYsMCKkGqNIG93z/uM87ZYXZ3drbM7Ozvc11zzZzntOeZ\ns7D3PtXcHREREZFy0aHUGRARERFJU3AiIiIiZUXBiYiIiJQVBSciIiJSVhSciIiISFlRcCIiIiJl\nRcGJiIiIlBUFJyIiIlJWFJyIiIhIWVFwIhXBzGrM7Jwizx1vZjeltgcn1xvQfDksnpmtkuTnlFLn\nRepmZi+b2aNFnnunmY1t7jwVeO+i8y3SUhScSNlIBQU1ZrZlHcd8kex/MGeXJ69i1OQ5t1nXdTCz\nKjM7qTmvKXUzs5tTP0v1vW5q+GoFa8rPoBM/h6WgNUyk7CxW6gyI5DEHOBR4MZ1oZtsBPwLm5jln\ncWBBkfdbm5b/xXAosB5wVQvfR8JfgSdT26sB/wdcDzyfSv+4Ge+5DcX/oj8MsGbMi0ibpuBEytGj\nwIFmdqK7p4OGQ4HXgWVyT3D374u9mbvPL/bctsrMurn77FLno6W4+yvAK5ltMxsI/BF4yd3vKOQa\nZtbV3fMFwnXds9jgGHdfWOy5IpVIzTpSbhyoBnoDP80kmlkn4ADgDvL8hZnb58TMzkvS1jCzW8xs\nqplNM7ObzKxrzrnj66je725m15nZZDObbmYjzaxXzrn7mNnDZjbBzOaa2Udm9gcz65A65mlgTyDT\nd6TGzD5J7e+S5PcDM5tjZhPN7D4zWy1POX+V3GOumb1qZps09IWmmsu2NbPhZvYN8EWy7xYz+zTP\nOeeZWU1OWo2ZXW1m+5rZO0ke3jWzXRu4fx8zm29mZ+fZ1ze57pBkezEzO9fMxiXfxWQze97Mdmqo\nnE1hZl+b2d1mtqeZjTazucDhyb5fmdlTZvZNkqd3zOyoPNeo1XfDzHZNyrZP8n1OMLPZZva4ma2S\nc26tPidmtnbme0leHyf3ftHM+ue596FmNjY55s2kHEX3YzGz5ZKfjUnJNd8ws6o8xx1uZmPMbGby\n7+stMzs+tb+zmV1gZh8m1/mfmT1rZtsWky9pP1RzIuVoPPAyUAU8nqTtAfQA7gQK6buRqV6/G/gE\n+D0wADga+AY4I8+xaQZcC0wFziWafoYAKwM7pI47ApgJXAZ8B+xINB8sCfwuOeYCoCfRJHVycu3v\nAJIg5pHkmtXAlcm5PwXWB9KBwyBgCaLJwpPr32dmqxf4l/dwYBJwPtAtVfZ85a8rfRtgv+RaM4ET\ngXvNbGV3n5rvpu4+ycyeBQ4iai/SDiGa4+5Ots8nntX1wGvEM9+EeHb/briIRXNgQ2AkUba/Av9N\n9g1J8nI/0fz3M+AGM3N3vznnGvmcC8wDLiaC7tOBW6j9c1TX9/1LoCswDOhIPPN7zayvuzuAme0H\n3EbUKv6OqFm8FZhYT57qZGbdgReIn9ergS+Bg4HbzWwJd/9bctzeSTn+BVxH/LG7HrAFMCK53EXE\nv9cRwBvEv4PNgI2A5xqbN2lH3F0vvcriBQwGFhK/iIYA04Auyb67gFHJ50+BB3POrQHOSW2fm6Rd\nn3PcfcCknLRPgZty8lFDNAt0TKX/NsnfXqm0LnnKMYL4xd0plfYQ8EmeY49M7nViPd/LKskxk4Ae\nqfS9k/zsUcD3WgM8A1jOvpvryNe5wMI83/EcYNVU2gZJ+pAG8vCrJK/r5qS/CzyZ2n4j99k208/W\nwCSfh9ex/6skf1vn2ZfvGT8FvJOT9hLwaGp71+SeY3J+jk5L7rV6Kq0aeC+1nekHNQHonko/MDl3\nx1TaB8CH6XwSwW1N+pr1fDe5+f5dco+fpdIW
I4Kfb4GuqZ/zrxu49ljg7uZ+nnpV/kvNOlKu7ib+\nut/LzJYA9gJub+Q1nPiLLu15oHdyzYZc77VrJEaQBAM/3MB9XuazmS1hZr2Jvzq7AesUcI/9gP8R\ntTQNudPdZ6S2nydqYVYv4FwH/ubuTR2Z8aS7j//hou7vADMKyMM/iO/u4EyCma0HrEvUhmVMA9Yz\nszWbmM9ijHX3F3ITc55xTzNbhvirv5+ZdS7gujfk/BxlOuQW8txud/dZOef+8MyTpr+1gJvT+XT3\nJ4mApRi7A5+5+wOp6y0ArgF6AZmRdNOAnma2Yz3XmgZsmK+JUqQ+Ck6kLLn7ZGAU0Ql2P+Jn9d4i\nLvV5znam6WGphrIAfJSTp1nEX9irZtLMbF0zu9/MphG/pP9HVKlDVGE3ZA3gA6/d8bcuX+TkZ1ry\nsaGyZIwv8LiC85CY2lAe3P1bolnmoFTyIcB8orkk4xziF+A4M3vbzC41sw2aluWCLdL3BmKUmJk9\nbWaziLJOSvJpRLNTQ3K/s6nJuYU8t3znkjo303cl36ijj/KkFWIVYFye9LFEvjP3vAb4DHjSzD4z\ns7+Z2c4555wFLAd8nPSFucjM1i0yX9KOKDiRcnYHUUtxHPAvd59ZxDXq6ovR5GGbZtaT+At6A+AP\nRO3OzmT7mjT3v6+mlmVOnrS6alI6tkAe7gT6mtmGyfaBwL/dfcoPmXF/ngjYjgTeIfpcjMnXAbUF\nLPL9mNk6wBNAd6LvxB7EM87UdBXyjJvynbXYz29TuftE4mf/50S/qZ2BJ8xsROqYp4jn+UsiuDkW\neNPMBrV+jqUtUXAi5SzTAXFzIlBpTUZUl2cToqPgCmRrILYn/oId7O7XuvujyX/G01hUXUHAx8Da\nZlZXMNDSphI1FblWbYF7PUDUlBycjDjpS/S1qMXdp7n7SHcfBKwEvA2c1wL5KcS+RH+LPdz9Bnd/\nLHnG5TL8/LPkPV8zWLFNY58RzyZXP+LnOHNP3H2+uz/o7kOIpqZbgGPMbMXUMVPc/WZ3ryI6lH9A\n9GkSqZOCEylbSTPKccQvpodKkIVjzCw9om0IUaOQGS66kAhi0sOGOyfH5ZpF/mae+4BlgROaI8NF\n+JjoN7B+JsHMViBGpDQrd59OjL46iGjSmQf8M32MmS2dc85sonmiS+qYHslQ20KaVJoqU3ORfsa9\niUnTCtGis6+6+6dE35IjLDVE3mJ491p1nli/R4lh7/umrrcY8TM6DfhPkpb7rJzo4AzJ88pzzHfE\n6LkuiNRDQ4ml3NSqrnb3W+s6sBV0Bv5tZncTnVuPB55394eT/S8SNQ9/N7Ork7TDyP8LaTRwkJld\nRgxL/S65zt+J+TQuN7PNiQ6PSwA7AcPcvbmCsrqaAe4ELgEeSMrQnQgIPyBGTTW3u4hhr0OAx3M6\n+AK8Z2bPEN/XFGBTYn6bq1PH/JwYZXQE8f21pMeAC4F/mdkNRC3TMcQomkUmA8yjNZpfziK+1xfM\n7O9AH+Jn9b8U9wfoMGLI/R1mdi3R7+UQ4ufhuFTH29vMrAsxCmwCUXNyAvBKEjRB9DX5FzFiaSox\nzHgv4NIi8iXtiIITKTeF/KWZb06Ipq5rku96JxBzi5wPdCJGC/0wx4q7TzGzPYk5Tv5I/Od7KzHM\n9PGc6w0H+hO/UE8mqsYfdvcaM9ud+AWT6fz7LRGkvFNA+Qotd95jkjL8DLicCFI+JeYZ6cuiwUlT\n8wDwING3ozu1R+lkXAXsQwyF7UJ8T2cCf8lzz8aq75y8ZXD3d83sQOL5Xkb8Er6CqPUZXsA96rpn\nXd9joef+sM/d7zWzXwBnE8/wAyJIHgKsmPcK9dzb3WeZ2TbEvCxHEvPujAUGuXv6md1C9CUZQgRt\nXxHB4vmpY64gJiDclXienxJDqa8sMF/STlnTRxaKiEi5SWaHHefu+zZ4sEiZKXmfEzM7w2Ia7hkW\n00Pfb2b5
OmPVdf5WFlNjj2nJfIqIlCOLKf875KTtRkzk9nRpciXSNCWvObFYi6KamH1wMWK64/WB\nfu6eb+hj+tyeRNv0h8B
"text/plain": [
"<matplotlib.figure.Figure at 0xa2aa160>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAADeCAYAAADmUqAlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzsnXncVeP2wL+rNEqlokIomshFGa4pFBeZZyWuKZIx7g/X\nFOEiV2aJS8mQzDJGJBGiiSiRypBKI6k0rd8fax97n/Oe877nnPe873ve867v57M/e+9nP8+z13P2\nGdZZz1rrEVXFcRzHcRwnX6hW0QI4juM4juNEceXEcRzHcZy8wpUTx3Ecx3HyCldOHMdxHMfJK1w5\ncRzHcRwnr3DlxHEcx3GcvMKVE8dxHMdx8gpXThzHcRzHyStcOXEcx3EcJ69w5cTJCBHZICLXZ9l2\njog8Fjn/Z9Bfx9xJmD0isk0gz2UVLUtVRkTeF5Evctxn3Hsvn0nyOdk/eF92zuE9sv4cO0554MpJ\nFSSiFGwQkb1T1PkxuD4y4ZIGWzZsSNI2p+sniEh3Ebkkl306JRO8V+7NUXdlsaZGWn1GPhcbRGS9\niPwsIqNEZP8ykCkVyWTN+DURkcNEpF8x9yj3tUtEpF/Ca5z4em9e3jI5+clGFS2AU6GsAnoA46OF\nwRfxlsDqJG3qAOuyvF9bTEEpS3oAOwL3lPF9nMLlbWAYIEBLoA/wnoh0U9VR5S2Mqo4VkTqquibD\npt0w2W9Mcq00n+PSokBv4I8k15aVsyxOnuLKSdXmDeBEEblYVaNKQw/gc6BJYoMsviCjbddm27ay\nIiJ1VXVlRcvhZMRMVX06diIiLwNfAJcCSZUTERGgpqr+WRYCZfm5kxz3l0teUNUlmTQQkVrAGk2y\nWm0uPmf+Wc0vfFqn6qLAcKAxcHCsUERqACcAT5Pkyy1xrlpEbgjKthORoSKyVESWichjIlI7oW2q\nef+NRWSwiCwSkeUi8riINExoe5SIvBaY2VeLyHcicq2IVIvUGQMcDsR8RzaIyPeR67UCeb8RkVUi\nMk9EXhCRlknG2Su4x2oRmSAiu5X0gkamyzqLyIMisgD4Mbg2VERmJ2lzg4hsSCjbICL3isjRIvJl\nIMM0ETmkhPtvLiJrReS6JNfaBP32Cc43CkzsM4PXYpGIjBORriWNM1vSeYYJ9TuKyEcislJEvheR\n85LUqSkiN4rIt0GfP4jI7SJSM1dyq+o0YBFmRYndN/aMeojINMzKeEhwTUTk0uCZrRKR+SLyUOJ7\nOqh7rdgU6h8i8q6I7JCkTlKfExHZU0TeEJElIrJCRKaKyEXBtSGY1SQ6VbU+Qf7rE/rbVUTeDD6D\nv4vIaBHZM6FO7D2+t4gMFJGFwb1fFJHGGb+4KYiM+WQRuVlEfsIsLZuIyBmpPmdZjCNpH07F45aT\nqs0c4BOgO+E/wm5AfeAZIB3fjdi/mGeB74GrgI7AOcAC4N9J6kYR4H5gKdAPm/rpA2wNHBipdwbw\nO3AnsALoAvQHNgGuDOrcDDTApqQuDfpeARD8AL4e9DkcuDtoezDQAYgqDqcC9YCHApmvBF4QkVaq\nup6SeRBYiJnT60bGnsqXIFn5fsBxQV+/AxcDz4vI1qq6NNlNVXWhiIwFTgJuSrh8CmbGfzY4vxF7\nVg8Dn2HPfDfs2b1b8hCz4gxKfoYxGmHP61lMUT4JGCQif6rqUPjLWvEqsDcwGJgB7AT0BVpjr1+p\nEZFNgU2BbxMudQ3kuh9TXuYE5Q8DpwOPYdOLLYGLgF1EZJ/Ye0hEbgKuAV4D3sRe+7eBGknEiHuP\niMjB2NjnYe/l+UB74AjgPuz12AI4CHs/p7SiBP3tAHwALAduw94r5wHvi0hnVf0socl9wBLgBmBb\n7DW/H/suSYfGwfOLsk5VlyeUXQf8CdwB1ALWEL4W0c/ZxsE4dsxwHEX6cPIEVfWtim3AP4H12Jdh\nH2yet1ZwbQQwOjieDYxMaLsBuD5y3i8oezih
3gvAwoSy2cBjCXJsAD4FqkfK/xXId0SkrFaScQzC\nfuxqRMpeBb5PUvfM4F4XF/O6bBPUWQjUj5QfGcjTLY3XdQPwPiAJ14akkKsfsD7Ja7wK2DZStlNQ\n3qcEGXoFsu6QUD4NeCdyPjnx2ZbyPbUBuLeEOuk+wzHBGC6JlNUAJgG/xN4rQE9gLbBXQp/nBu3/\nnuq9V8I4HsYsik2APYDRSeTZENy7bUL7fYNrJyeUHxyUnxKcN8GsLa8k1Ls5qBf9nOwf3L9zcF4N\n+yMwC9ikmLHcl/jeSpA/+jl+KXjPbRMpa4b9yI9J8h5/K6G/OzHFIaU8kff7hhTb1wlj3oAphDUz\n+JxlOo4iffiWH5tP6zjPYv/ujxCRetg/r6cy7EOxf2pRxmH/juql0f5hjbdIDCJQBv66QWQuX0Tq\nBSbkDwPZ26Vxj+OAX7F/dyXxjKr+Fjkfh/3zbJVGWwUe0eAbsBS8o6pz/upU9UvgtzRkeBF77U6O\nFQT/JnfArGExlgE7isj2pZQzbTJ8huswJSHWdi32Htsc6BQUnwBMB2aKSOPYhik3QrzlLRPOxt4r\nCzHL4l7Anaqa6GT9vqp+k1B2Avbavpsg02TMWhST6WBM4bovof3daci3K2atuFtVf09vSKkJrIoH\nAy+p6txYuarOx6xW+yZ8jpXIswkYB1THFPySUOBYzKoT3c5MUneoJvePKfI5y3IcufisOmWAT+tU\ncVR1kYiMxpxgN8b+lT2fRVc/JJzHph42JZhaSSUC8F2CTH+IyC/YFzDwl9n5FuzLvX5C+wZpyLcd\n8I3GO/6mIm7uWVWXBRboTdNoC6F5vzQkm/9eWpIMqrpYRN7FphtiYaSnYP/yX4pUvR54Gfthnwa8\nBTwRKEFlQobPcJ6qrkoom4kpHdsCE7Cpm3aYIpGIYopMNryCKbGKWXW+SiILJH/OrYGGmGJTnExb\nB/vE9/4iEUk6bRdhu6Cvr0qoly6bYQrizCTXpmPfCS2C4xiJ78/o5z0dxml6DrFzMriWzTiK69+p\nQFw5ccD+VTwCNAfezPLfWCpfjGLnutNBRBpg88jLgGsxk/Zq7B/0beTesbu0Y0n2Q5bq31n1MpDh\nGeAxEfmbqn4BnAi8G/0xUNVxIrIdcDTwD8xa0FdEzlPVnCcrK6NnWA34EvN3SPa6ZOvg+JOqvpdG\nvWTPuRrma9UjhUzJFKnKSJl93hNI9hqncy0X/TsViCsnDtg/6sHAnkSmA8oJwf5tjv2rQGRjTFF6\nPSg6APtHdrSqfhSpt12S/lIpAbOAPUSkuqbn1JprlmL/qBPZtgzu9TL2PE8OnA7bYBaLOFR1GfA4\n8LiI1MVM8zdgjpy55gDSf4YAW4jl9oj+eLTFnm/MeXkW8DdVHVMG8mbLLMxRdrwWH1Ycm3ZoTeTf\nu4g0oWTrwyzsc9MBKE6JSne64ldgJfb6JtIe882oDJEshTIOBw8ldrBpFCwp0g2YQ2l5c66IRBXl\nPphF4Y3gfD32ZRwNG64Z1EvkD5JP87yAmX0vzIXAWTALaCAiHWIFItIcOCbXN1KLeBiFTe2cgkU7\nvBKtIyKNEtqsxKYYakXq1BeRtiISnYLJlkyeIdgfp96RujWwqItfMcdYMH+prUSkV2JjEakdKFzl\nzbOY7EVSw4tI9cCCBOZkuw6L4onSN417TMIUtEsj/SXjj+C+xT6/YKrzbeBoEYlNNyEiTbHom3Gq\nWtzUbF5QKONwDLecVF3izK+q+kRFCQLUxBwIn8V8CM7HvkheC66PxywPwyRMkd6T5P8MJwInicid\nWIjsiqCfYVh458Ag58E4LFy4K/CAquZKKUtl1n4GuB14ORjDxtiP7zdY1FSuGQE8if34j0pw8AX4\nWkTex16vJcDumDNnNAX9sViU0RnY61cSu4nINUnKx5DZMwSLyrlCRLbFfAhOAf4G9IpYvp4gDDE+\nEPgIU2rb
Y1NZ/yBUZMoFVf1ARAYDV4nILtiP5VrMenUCFhL+YuBb8t+g3muYIr4rcCjJp37+el+p\nqorI+cBIYIpYTpNfsM/
"text/plain": [
"<matplotlib.figure.Figure at 0xa131ba8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pred_vgg = train_and_evaluate(reader_train, reader_test, max_epochs=5, model_func=create_vgg9_model)"
]
},
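The training log reports `Training 2675978 parameters in 18 parameter tensors`. This count can be reproduced by hand, as sketched below under the assumptions that every convolution and dense layer has a weight tensor plus a bias vector, and that the third max pooling outputs a 3x3x128 map; the helper names are hypothetical.

```python
def conv_params(in_ch, out_ch, k=3):
    # k x k x in_ch x out_ch weights plus one bias per output channel
    return k * k * in_ch * out_ch + out_ch

def dense_params(in_dim, out_dim):
    # weight matrix plus one bias per output unit
    return in_dim * out_dim + out_dim

def vgg9_param_count(num_classes=10):
    layers = []
    in_ch = 3  # RGB input
    for filters in [64, 96, 128]:
        layers.append(conv_params(in_ch, filters))
        layers.append(conv_params(filters, filters))
        in_ch = filters
    flat = 3 * 3 * 128  # assumed feature map after the third max pooling
    layers.append(dense_params(flat, 1024))
    layers.append(dense_params(1024, 1024))
    layers.append(dense_params(1024, num_classes))
    # each layer contributes two tensors: weights and bias
    return sum(layers), 2 * len(layers)

print(vgg9_param_count())
```

The first FC-1024 layer alone holds 1,180,672 of these parameters, roughly 44% of the model.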
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Residual Network (ResNet)\n",
"\n",
"One of the main problems in training a deep neural network is propagating the error all the way back to the first layers. In a deep network the gradient tends to get smaller and smaller as it flows backward, until it has almost no effect on the weights of the early layers. [ResNet](https://arxiv.org/abs/1512.03385) was designed to overcome this problem by defining a block with an identity path, as shown below:\n",
"\n",
"<img src=\"https://cntk.ai/jup/201/ResNetBlock2.png\">\n",
"\n",
"The idea behind this block is twofold:\n",
"\n",
"* During backpropagation, the gradient has a path (the identity shortcut) that does not affect its magnitude.\n",
"* The network only needs to learn a residual mapping, that is, the difference (delta) between the desired output and the input x.\n",
"\n",
"So let's implement the ResNet blocks using CNTK:\n",
"\n",
"          ResNetNode                ResNetNodeInc\n",
"               |                          |\n",
"          +----+----+           +---------+----------+\n",
"          |         |           |                    |\n",
"          V         |           V                    V\n",
"     +----------+   |    +--------------+    +----------------+\n",
"     | Conv, BN |   |    | Conv x 2, BN |    | SubSample, BN  |\n",
"     +----------+   |    +--------------+    +----------------+\n",
"          |         |           |                    |\n",
"          V         |           V                    |\n",
"      +-------+     |       +-------+                |\n",
"      | ReLU  |     |       | ReLU  |                |\n",
"      +-------+     |       +-------+                |\n",
"          |         |           |                    |\n",
"          V         |           V                    |\n",
"     +----------+   |      +----------+              |\n",
"     | Conv, BN |   |      | Conv, BN |              |\n",
"     +----------+   |      +----------+              |\n",
"          |         |           |                    |\n",
"          |  +---+  |           |       +---+        |\n",
"          +->| + |<-+           +------>| + |<-------+\n",
"             +---+                      +---+\n",
"               |                          |\n",
"               V                          V\n",
"           +-------+                  +-------+\n",
"           | ReLU  |                  | ReLU  |\n",
"           +-------+                  +-------+\n",
"               |                          |\n",
"               V                          V\n"
]
},
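The first point can be illustrated with a toy chain of scalar layers. The sketch below assumes each plain layer computes y = w * x and each residual layer computes y = x + w * x, so the identity path adds 1 to every local derivative; it ignores nonlinearities and batch normalization and only shows the qualitative effect. `chain_gradient` is a hypothetical helper.

```python
def chain_gradient(num_layers, w, residual):
    # d(output)/d(input) is the product of the local derivatives of all layers
    grad = 1.0
    for _ in range(num_layers):
        # plain layer: dy/dx = w; residual layer: dy/dx = 1 + w (identity adds the 1)
        local = (1.0 + w) if residual else w
        grad *= local
    return grad

plain = chain_gradient(20, 0.1, residual=False)
res = chain_gradient(20, 0.1, residual=True)
print(f"plain: {plain:.3e}, residual: {res:.3f}")
```

With 20 layers and w = 0.1, the plain chain's gradient is about 1e-20, while the residual chain's gradient stays on the order of 1 (1.1^20 is about 6.7).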
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from cntk.ops import combine, times, element_times, AVG_POOLING\n",
"\n",
"def convolution_bn(input, filter_size, num_filters, strides=(1,1), init=he_normal(), activation=relu):\n",
"    # Convolution (no bias, batch normalization has its own learned bias) -> BN -> optional activation\n",
"    if activation is None:\n",
"        activation = lambda x: x\n",
"\n",
"    r = Convolution(filter_size, num_filters, strides=strides, init=init, activation=None, pad=True, bias=False)(input)\n",
"    r = BatchNormalization(map_rank=1)(r)\n",
"    r = activation(r)\n",
"\n",
"    return r\n",
"\n",
"def resnet_basic(input, num_filters):\n",
"    # Basic block: two 3x3 conv-BN layers plus an identity shortcut\n",
"    c1 = convolution_bn(input, (3,3), num_filters)\n",
"    c2 = convolution_bn(c1, (3,3), num_filters, activation=None)\n",
"    p = c2 + input\n",
"    return relu(p)\n",
"\n",
"def resnet_basic_inc(input, num_filters):\n",
"    # Downsampling block: the first conv uses stride 2, so the shortcut uses a\n",
"    # strided 1x1 conv-BN to match the new spatial size and number of filters\n",
"    c1 = convolution_bn(input, (3,3), num_filters, strides=(2,2))\n",
"    c2 = convolution_bn(c1, (3,3), num_filters, activation=None)\n",
"\n",
"    s = convolution_bn(input, (1,1), num_filters, strides=(2,2), activation=None)\n",
"\n",
"    p = c2 + s\n",
"    return relu(p)\n",
"\n",
"def resnet_basic_stack(input, num_filters, num_stack):\n",
"    # Chain num_stack basic blocks that all use the same number of filters\n",
"    assert (num_stack > 0)\n",
"\n",
"    r = input\n",
"    for _ in range(num_stack):\n",
"        r = resnet_basic(r, num_filters)\n",
"    return r"
]
},
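In `resnet_basic_inc`, the main path halves the spatial size and changes the number of filters, so `input` cannot be added to `c2` directly. The NumPy sketch below uses a channels-first (C, H, W) layout to show how a 1x1 convolution with stride 2 produces a tensor of the matching shape. It assumes the kernel is applied at every second pixel starting from the first, and it leaves out bias and batch normalization; the array sizes follow the first downsampling block of the model (16 to 32 filters, 32x32 to 16x16), and all values are random placeholders.

```python
import numpy as np

def conv1x1_stride2(x, weights):
    # x: (C_in, H, W), weights: (C_out, C_in)
    # A 1x1 kernel with stride 2 only looks at every second pixel in each direction
    sub = x[:, ::2, ::2]                            # (C_in, H/2, W/2)
    return np.einsum("oc,chw->ohw", weights, sub)   # mix channels at each pixel

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 32, 32))          # block input
main_path = rng.standard_normal((32, 16, 16))  # stands in for c2
w = rng.standard_normal((32, 16))              # hypothetical 1x1 kernel weights

shortcut = conv1x1_stride2(x, w)
out = np.maximum(main_path + shortcut, 0.0)    # relu(c2 + s)
print(shortcut.shape, out.shape)
```

Both tensors end up with shape (32, 16, 16), which is what makes the element-wise sum `c2 + s` valid.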
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's write the full model:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_resnet_model(input, out_dims):\n",
"    # Stage 1: 16 filters at full resolution (32x32)\n",
"    conv = convolution_bn(input, (3,3), 16)\n",
"    r1_1 = resnet_basic_stack(conv, 16, 3)\n",
"\n",
"    # Stage 2: downsample to 16x16, 32 filters\n",
"    r2_1 = resnet_basic_inc(r1_1, 32)\n",
"    r2_2 = resnet_basic_stack(r2_1, 32, 2)\n",
"\n",
"    # Stage 3: downsample to 8x8, 64 filters\n",
"    r3_1 = resnet_basic_inc(r2_2, 64)\n",
"    r3_2 = resnet_basic_stack(r3_1, 64, 2)\n",
"\n",
"    # Global average pooling: an 8x8 window over the 8x8 feature map\n",
"    pool = AveragePooling(filter_shape=(8,8), strides=(1,1))(r3_2)\n",
"    net = Dense(out_dims, init=he_normal(), activation=None)(pool)\n",
"\n",
"    return net"
]
},
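Since the last stage works on an 8x8 map (the 32x32 input is halved by each of the two `resnet_basic_inc` blocks), the (8,8) `AveragePooling` window covers the whole map and reduces each of the 64 channels to one number. The NumPy sketch below shows the same reduction on a random channels-first tensor; it illustrates the math, not CNTK's implementation.

```python
import numpy as np

# Random stand-in for the stage 3 output: 64 channels of 8x8
features = np.random.default_rng(1).standard_normal((64, 8, 8))

# Global average pooling: mean over the spatial axes, one value per channel
pooled = features.mean(axis=(1, 2))
print(pooled.shape)  # (64,)
```

The final `Dense` layer then maps these 64 values to the class scores, which for 10 classes needs only 64 x 10 + 10 = 650 parameters.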
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training 272474 parameters in 65 parameter tensors.\n",
"\n",
"Finished Epoch [1]: [Training] loss = 1.888094 * 50000, metric = 70.5% * 50000 15.690s (3186.8 samples per second)\n",
"Finished Epoch [2]: [Training] loss = 1.545802 * 50000, metric = 57.7% * 50000 15.700s (3184.6 samples per second)\n",
"Finished Epoch [3]: [Training] loss = 1.421820 * 50000, metric = 51.9% * 50000 15.715s (3181.6 samples per second)\n",
"Finished Epoch [4]: [Training] loss = 1.333261 * 50000, metric = 48.4% * 50000 15.514s (3223.0 samples per second)\n",
"Finished Epoch [5]: [Training] loss = 1.255839 * 50000, metric = 45.6% * 50000 15.436s (3239.2 samples per second)\n",
"\n",
"Final Results: Minibatch[1-626]: errs = 45.1% * 10000\n",
"\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAADeCAYAAADmUqAlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzt3XeYVOXZx/HvDQhEFKyIRsXeK5gYVOy966sosWs0isSI\nSSyJiklMjF2jYkwRO2LvvWPXxQIKooCChSZFRUBg7/eP+4xzZtjO7s7Z2d/nuuaaOc9pzzNnYe99\nqrk7IiIiIlnRptQZEBEREUlTcCIiIiKZouBEREREMkXBiYiIiGSKghMRERHJFAUnIiIikikKTkRE\nRCRTFJyIiIhIpig4ERERkUxRcCJlwcwqzez8Bp77qZndmNo+Jrlej8bLYcOZWfckP2eUOi9SPTN7\n3cwea+C5d5rZqMbOUx3v3eB8izQVBSeSGamgoNLMtqnmmInJ/oeKdnnyaojKKs5t1HUdzKyvmf22\nMa8p1TOzwamfpZpeN9Z+tTpbnJ9BJ34OS0FrmEjmtCt1BkSqMAf4JfBqOtHMdgB+Csyt4pyfAAsa\neL/1afpfDL8ENgaubuL7SPgX8HRqe03gL8C/gWGp9LGNeM/eNPwX/ZGANWJeRFo0BSeSRY8Bh5rZ\nae6eDhp+CbwNrFB8grv/0NCbufv8hp7bUpnZku7+fanz0VTc/Q3gjdy2mfUE/gq85u531OUaZtbR\n3asKhKu7Z0ODY9x9YUPPFSlHataRrHFgCLA8sFsu0cyWAA4B7qCKvzCL+5yY2QVJ2tpmdpOZzTCz\nmWZ2o5l1LDr302qq9zuZ2Q1mNs3MZpnZzWa2TNG5+5vZI2b2hZnNNbNPzOxcM2uTOuZ5YB8g13ek\n0szGpfZ3SPL7kZnNMbMvzexeM1uzinKemNxjrpm9aWZb1faFpprLtjezQWY2GZiY7LvJzMZXcc4F\nZlZZlFZpZv80swPMbESSh5Fmtkct9+9qZvPN7Lwq9q2XXLdfst3OzAaa2Zjku5hmZsPMbJfayrk4\nzGySmd1lZvuYWYWZzQWOTvadaGbPmdnkJE8jzOz4Kq5R0HfDzPZIyrZ/8n1+YWbfm9mTZta96NyC\nPidmtn7ue0leY5N7v2pmm1dx71+a2ajkmHeTcjS4H4uZrZT8bExJrvmOmfWt4rijzWy4mX2b/Pt6\nz8xOSe1vb2YXmtnHyXWmmtmLZrZ9Q/IlrYdqTiSLPgVeB/oCTyZpewOdgTuBuvTdyFWv3wWMA84G\negC/AiYD51RxbJoB1wIzgIFE008/YHVgp9RxxwLfApcD3wE7E80HSwNnJcdcCHQhmqROT679HUAS\nxDyaXHMIcFVy7m7AJkA6cDgCWIposvDk+vea2Vp1/Mt7EDAF+DOwZKrsVZW/uvTewMHJtb4FTgPu\nMbPV3X1GVTd19ylm9iLQh6i9SDucaI67K9n+M/Gs/g28RTzzrYhn92ztRWwwBzYDbibK9i/gg2Rf\nvyQv9xPNfwcC/zUzd/fBRdeoykBgHvAPIug+E7iJwp+j6r7vE4COwHVAW+KZ32Nm67m7A5jZwcBt\nRK3iWUTN4q3AlzXkqVpm1gl4mfh5/SfwOXAYcLuZLeXu/0mO2y8px+PADcQfuxsDvYDrk8tdRPx7\nvR54h/h38HNgC+Cl+uZNWhF310uvTLyAY4CFxC+ifsBMoEOybyjwTPJ5PPBQ0bmVwPmp7YFJ2r+L\njrsXmFKUNh64sSgflUSzQNtU+u+T/O2bSutQRTmuJ35xL5FKexgYV8WxxyX3Oq2G76V7cswUoHMq\nfb8kP3vX4XutBF4ArGjf4GryNRBYWMV3PAdYI5W2aZLer5Y8nJjkdaOi9JHA06ntd4qfbSP9bPVM\n8nl0Nfu/SvK3XRX7qnrGzwEjitJeAx5Lbe+R3HN40c/RH5J7rZVKGwJ8mNrO9YP6AuiUSj80OXfn\nVNpHwMfpfBLBbWX6mjV8N8X5Piu5x4GptHZE
8PM10DH1cz6plmuPAu5q7OepV/m/1KwjWXUX8df9\nvma2FLAvcHs9r+HEX3Rpw4Dlk2vW5t9eWCNxPUkw8OMN3OflPpvZUma2PPFX55LABnW4x8HAVKKW\npjZ3uvs3qe1hRC3MWnU414H/uPvijsx42t0//fGi7iOAb+qQh/uI7+6wXIKZbQxsRNSG5cwENjaz\ndRYznw0xyt1fLk4sesZdzGwF4q/+Dc2sfR2u+9+in6Nch9y6PLfb3X120bk/PvOk6W9dYHA6n+7+\nNBGwNMRewGfu/kDqeguAa4BlgNxIuplAFzPbuYZrzQQ2q6qJUqQmCk4kk9x9GvAM0Qn2YOJn9Z4G\nXGpC0Xau6WHZ2rIAfFKUp9nEX9hr5NLMbCMzu9/MZhK/pKcSVeoQVdi1WRv4yAs7/lZnYlF+ZiYf\naytLzqd1PK7OeUjMqC0P7v410SzTJ5V8ODCfaC7JOZ/4BTjGzN43s0vMbNPFy3KdLdL3BmKUmJk9\nb2azibJOSfJpRLNTbYq/sxnJuXV5blWdS+rcXN+VqkYdfVJFWl10B8ZUkT6KyHfuntcAnwFPm9ln\nZvYfM9u16Jw/ASsBY5O+MBeZ2UYNzJe0IgpOJMvuIGopTgYed/dvG3CN6vpiLPawTTPrQvwFvSlw\nLlG7syv5viaN/e9rccsyp4q06mpS2jZBHu4E1jOzzZLtQ4Fn3X36j5lxH0YEbMcBI4g+F8Or6oDa\nBBb5fsxsA+ApoBPRd2Jv4hnnarrq8owX5ztrsp/fxeXuXxI/+wcR/aZ2BZ4ys+tTxzxHPM8TiODm\n18C7ZnZE8+dYWhIFJ5JluQ6IWxOBSnMyoro8nxAdBVcmXwOxI/EX7DHufq27P5b8ZzyTRVUXBIwF\n1jez6oKBpjaDqKkotkYT3OsBoqbksGTEyXpEX4sC7j7T3W929yOA1YD3gQuaID91cQDR32Jvd/+v\nuz+RPOOsDD//LHmvqhmsoU1jnxHPptiGxM9x7p64+3x3f8jd+xFNTTcBJ5nZKqljprv7YHfvS3Qo\n/4jo0yRSLQUnkllJM8rJxC+mh0uQhZPMLD2irR9Ro5AbLrqQCGLSw4bbJ8cVm03VzTz3AisC/Rsj\nww0wlug3sEkuwcxWJkakNCp3n0WMvupDNOnMAx5MH2NmyxWd8z3RPNEhdUznZKhtXZpUFleu5iL9\njJcnJk2riyadfdXdxxN9S4611BB5i+Hd61Z7Ys0eI4a9H5C6XjviZ3Qm8EqSVvysnOjgDMnzquKY\n74jRcx0QqYGGEkvWFFRXu/ut1R3YDNoDz5rZXUTn1lOAYe7+SLL/VaLm4RYz+2eSdiRV/0KqAPqY\n2eXEsNTvkuvcQsyncYWZbU10eFwK2AW4zt0bKyirrhngTuBi4IGkDJ2IgPAjYtRUYxtKDHvtBzxZ\n1MEX4EMze4H4vqYDPyPmt/ln6piDiFFGxxLfX1N6Avg78LiZ/ZeoZTqJGEWzyGSAVWiO5pc/Ed/r\ny2Z2C9CV+Fn9gIb9AXodMeT+DjO7luj3cjjx83ByquPtbWbWgRgF9gVRc9IfeCMJmiD6mjxOjFia\nQQwz3he4pAH5klZEwYlkTV3+0qxqTojFXdekquv1J+YW+TOwBDFa6Mc5Vtx9upntQ8xx8lfiP99b\niWGmTxZdbxCwOfEL9XSiavwRd680s72IXzC5zr9fE0HKiDqUr67lrvKYpAwHAlcQQcp4Yp6R9Vg0\nOFncPAA8RPTt6EThKJ2cq4H9iaGwHYjv6Y/AZVXcs75qOqfKMrj7SDM7lHi+lxO/hK8kan0G1eEe\n1d2zuu+xruf+uM/d7zGzo4DziGf4EREk9wNWqfIKNdzb3WebWW9iXpbjiHl3RgFHuHv6md1E9CXp\nRwRtXxHB4p9Tx1xJTEC4B/E8xxNDqa+qY76klbLFH1koIiJZk8wOO8bdD6j1YJGMUZ8TEZEWzGLK\n/zZFaXsS
E7k9X5pciSwe1ZyIiLRgZrY+MRJqCNG0sjExZHcSsFkDh+CLlJT6nIiItGxTieHWJxGd\ndL8hZuQ9R4GJtFSqORE
"text/plain": [
"<matplotlib.figure.Figure at 0x6a18908>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAADeCAYAAADmUqAlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzt3XeYlOXVx/HvD0SMGFExgj0WxG4Ee0djTSK2qGvX2I0x\nxGg09m5ibygmEcSCvRC7iPhiwbJgRyMKNgRUFKRJ2fP+cZ5xnp2d3Z2ZbbO753Ndc808/b7nmd05\nc1eZGSGEEEII5aJDSycghBBCCCEtgpMQQgghlJUITkIIIYRQViI4CSGEEEJZieAkhBBCCGUlgpMQ\nQgghlJUITkIIIYRQViI4CSGEEEJZieAkhBBCCGUlgpNQFElVks4t8diJkm5LLR+enK9346WwdJJW\nTdLzl5ZOS3smaaSktxv5nNU+e+Usz9/J9snncrtGvEbJf8chNIcITtqhVFBQJWmrWvb5PNk+LGeT\nJY9SVOU5tlHnT5BUIemUxjxnqF/yWbm+kU7XFHNqFHTO1N9FlaSFkr6U9LSk7ZsgTbXJl9ai3xNJ\nu0s6r45rNPvcJZLOy3mPc9/v5Zo7TaE8LdLSCQgtag5wEPByemXyj3hFYG6eY34GLCjxer3wAKUp\nHQSsB1zXxNcJbdczwBBAwGrAicAISXuY2dPNnRgze0HSz8xsXpGH7oGn/YI82xryd9xQBhwPzMqz\n7ftmTksoUxGctG9PAL+X9CczSwcNBwFvAMvmHlDCP8j0sfNLPba1krS4mc1u6XSEovzPzO7OLEh6\nBHgb+DOQNziRJGBRM/uxKRJU4t+dGvl8jelBM5tWzAGSOgPzLM9stY3xdxZ/q+UlqnXaLwOGAt2A\nnTMrJXUC9gPuJs8/t9y6aknnJ+vWkDRY0neSvpd0m6TFco6trd6/i6SBkr6RNF3S7ZKWyjl2T0mP\nJcXscyWNl3S2pA6pfZ4HfgNk2o5USfoktb1zkt4PJc2RNEnSg5JWy5PPY5JrzJX0mqRN6ntDU9Vl\n20kaIGkK8HmybbCkCXmOOV9SVc66KknXS+on6Z0kDe9K2rWe6y8nab6kc/JsWys574nJ8iJJEfv/\nkvfiG0mjJO1UXz5LVcg9zNm/t6SXJM2W9Imk4/Lss6ikCyR9lJzzM0n/kLRoY6XbzN4FvsFLUTLX\nzdyjgyS9i5cy7ppsk6Q/J/dsjqTJkm7J/Uwn+54tr0KdJek5Sevm2SdvmxNJm0t6QtI0STMlvSXp\n5GTbILzUJF1VtTAn/efmnG9jSU8mf4M/SBouafOcfTKf8a0kXS1panLthyR1K/rNrUUqzwdIuljS\nF3hJy88lHVHb31kJ+ch7jtDyouSkfZsIjAYqyP4i3ANYErgHKKTtRuZXzH3AJ8AZQG/gaGAKcGae\nfdME3Ah8B5yHV/2cCKwC9E3tdwTwA3AVMBPYEbgQ+Dnwt2Sfi4GueJXUn5NzzwRIvgAfT845FLg2\nOXZnYH0gHTgcDCwB3JKk+W/Ag5JWN7OF1G8AMBUvTl88lffa2hLkW78tsE9yrh+APwEPSFrFzL7L\nd1EzmyrpBWB/4KKczQfixfj3JcsX4PfqVuB1/J5vgt+75+rPYkmOoP57mLEMfr/uwwPl/YGbJf1o\nZoPhp9KK/wJbAQOBD4ANgP5AT/z9azBJSwNLAx/lbNopSdeNePAyMVl/K3AYcBtevbgacDLwK0lb\nZz5Dki4CzgIeA57E3/tngE55klHtMyJpZzzvk/DP8mRgHeC3wA34+7EC8Gv881xrKUpyvnWB/wOm\nA5fjn5XjgJGStjOz13MOuQGYBpwP/BJ/z2/E/5cUolty/9IWmNn0nHXnAD8CVwCdgXlk34v031mX\nJB/rFZmPGucIZcLM4tHOHsDhwEL8n+GJeD1v52TbvcDw5PUEYFjOsVXAuanl85J1t+bs9yAwNWfd\nBOC2nHRUAa8CHVPr/5qk77epdZ3z5ONm/Muu
U2rdf4FP8ux7ZHKtP9Xxvqya7DMVWDK1/ndJevYo\n4H2tAkYCytk2qJZ0nQcszPMezwF+mVq3QbL+xHrScEyS1nVz1r8LPJtaHpt7bxv4maoCrq9nn0Lv\n4fNJHk5JresEjAG+ynxWgEOA+cCWOec8Njl+i9o+e/Xk41a8RHFZYDNgeJ70VCXX7pVz/DbJtgNy\n1u+crD8wWV4WL215NGe/i5P90n8n2yfX3y5Z7oD/EPgY+Hkdebkh97OVk/703/HDyWdu1dS6HviX\n/PN5PuNP5ZzvKjxwqDU9qc97VS2P93PyXIUHhIsW8XdWbD5qnCMe5fGIap1wH/7r/reSlsB/ed1V\n5DkM/6WWNgr/dbREAcffatVLJG4mCQZ+ukCqLl/SEkkR8otJ2tcu4Br7AF/jv+7qc4+ZzUgtj8J/\nea5ewLEG/MuS/4AN8KyZTfzppGbvADMKSMND+Ht3QGZF8mtyXbw0LON7YD1JazYwnQUr8h4uwIOE\nzLHz8c/YckCfZPV+wDjgf5K6ZR54cCOql7wV4w/4Z2UqXrK4JXCVmeU2sh5pZh/mrNsPf2+fy0nT\nWLy0KJOmnfGA64ac468tIH0b46UV15rZD4VlqXZJqeLOwMNm9mlmvZlNxkuttsn5OzZS9yYxCuiI\nB/j1MWBvvFQn/Tgyz76DLX/7mBp/ZyXmozH+VkMTiGqdds7MvpE0HG8E2wX/VfZACaf6LGc5U/Ww\nNEnVSm1JAMbnpGmWpK/wf8DAT8XOl+D/3JfMOb5rAelbA/jQqjf8rU21umcz+z4pgV66gGMhW7zf\nEPnqv7+rLw1m9q2k5/Dqhkw30gPxX/kPp3Y9F3gE/2J/F3gKuCMJgppEkfdwkpnNyVn3Pzzo+CXw\nGl51szYeSOQyPJApxaN4EGt4qc57edIC+e9zT2ApPLCpK02rJM+5n/1vJOWttktZIznXe/XsV6hf\n4AHi//JsG4f/T1g5eZ2R+/lM/70XYpQV1iB2YhHbSslHXecPLSiCkwD+q+JfwPLAkyX+GqutLUad\ndd2FkNQVr0f+HjgbL9Kei/+CvpzGb9jd0Lzk+yKr7ddZxyZIwz3AbZI2NLO3gd8Dz6W/DMxslKQ1\ngH7ALnhpQX9Jx5lZow9W1kT3sAPwDt7eId/7UmoDxy/MbEQB++W7zx3wtlYH1ZKmfIFUa9Rkf+85\n8r3HhWxrjPOHFhTBSQD/RT0Q2JxUdUAzEf5r84WfVkhd8EDp8WTVDvgvsn5m9lJqvzXynK+2IOBj\nYDNJHa2wRq2N7Tv8F3WuXzbBtR7B7+cBSaPDtfASi2rM7HvgduB2SYvjRfPn4w05G9sOFH4PAVaQ\nj+2R/vLohd/fTOPlj4ENzez5JkhvqT7GG8q+bHV3K85UO/Qk9etd0rLUX/rwMf53sz5QVxBVaHXF\n18Bs/P3NtQ7eNqM19GRpK/kIRFfigFej4IMinY83KG1ux0pKB8on4iUKTyTLC/F/xuluw4sm++Wa\nRf5qngfxYt8/NkaCS/Ax0FXS+pkVkpYH9mrsC5n3eHgar9o5EO/t8Gh6H0nL5BwzG69i6JzaZ0lJ\nvSSlq2BKVcw9BP/hdHxq3054r4uv8Yax4O2lVpJ0TO7BkhZLAq7mdh+e9hpDw0vqmJQggTeyXYD3\n4knrX8A1xuAB2p9T58tnVnLdOu9fUtX5DNBPUqa6CUnd8d43o8ysrqrZstBW8hFclJy0X9WKX83s\njpZKCLAo3oDwPrwNwQn4P5LHku0v4yUPQ5QdIv0Q8v8yrAT2l3QV3kV2ZnKeIXj3zquTMQ9G4d2F\ndwJuMrPGCspqK9a+B/gH8EiShy74l++HeK+pxnYvcCf+5f90TgNfgPcljcTfr2nApnhjzvQQ9Hvj\nvYyOwN+/+mwi6aw865+nuHsI3ivndEm/xNsQHAhsCByTKvm6g2wX477AS3hQuw5elbUL2UCmWZjZ\n/0kaCJwh
6Vf4l+V8vPRqP7xL+ENJ25Irk/0ewwPxjYHdyF/189PnysxM0gnAMOBN+ZgmX+F/O+ua\n2e7JrpXJcTdIehrvuXN
"text/plain": [
"<matplotlib.figure.Figure at 0xfde9b38>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pred_resnet = train_and_evaluate(reader_train, reader_test, max_epochs=5, model_func=create_resnet_model)"
]
}
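As with the VGG model, the reported `Training 272474 parameters in 65 parameter tensors` can be checked by hand. The sketch assumes each `convolution_bn` call contributes a convolution kernel without bias (`bias=False`) plus a learned scale and bias per channel in batch normalization (running statistics are not counted as trainable), and that the final `Dense` layer has a weight matrix and a bias; the helper names are hypothetical.

```python
def conv_bn_params(in_ch, out_ch, k):
    # kernel (no bias) + BN scale + BN bias -> (parameter count, 3 tensors)
    return k * k * in_ch * out_ch + 2 * out_ch, 3

def resnet_param_count(num_classes=10):
    layers = [conv_bn_params(3, 16, 3)]                  # initial conv
    layers += [conv_bn_params(16, 16, 3)] * 6            # stage 1: 3 basic blocks
    for in_ch, out_ch in [(16, 32), (32, 64)]:           # stages 2 and 3
        layers.append(conv_bn_params(in_ch, out_ch, 3))  # strided 3x3 conv
        layers.append(conv_bn_params(out_ch, out_ch, 3))
        layers.append(conv_bn_params(in_ch, out_ch, 1))  # 1x1 shortcut
        layers += [conv_bn_params(out_ch, out_ch, 3)] * 4  # 2 basic blocks
    layers.append((64 * num_classes + num_classes, 2))   # Dense weights + bias
    params = sum(p for p, _ in layers)
    tensors = sum(t for _, t in layers)
    return params, tensors

print(resnet_param_count())
```

Although it has about 20 weight layers on its main path versus 9 for the VGG9 model, this ResNet uses roughly a tenth as many parameters, mostly because it ends with global average pooling instead of large fully connected layers.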
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}