{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from IPython.display import Image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CNTK 201B: Hands On Labs Image Recognition"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This hands-on lab shows how to implement an image recognition task using a [convolution network][] with the CNTK v2 Python API. You will start with a basic feedforward CNN architecture to classify the [CIFAR dataset](https://www.cs.toronto.edu/~kriz/cifar.html), then keep adding advanced features to your network. Finally, you will implement a VGG net and a residual net like the one that won the ImageNet competition, but smaller in size.\n",
"\n",
"[convolution network]:https://en.wikipedia.org/wiki/Convolutional_neural_network\n",
"\n",
"## Introduction\n",
"\n",
"In this hands-on lab, you will practice the following:\n",
"\n",
"* Understand the subset of the CNTK Python API needed for an image classification task.\n",
"* Write a custom convolution network to classify the CIFAR dataset.\n",
"* Modify the network structure by adding:\n",
"    * A [Dropout][] layer.\n",
"    * A batch normalization layer.\n",
"* Implement a [VGG][] style network.\n",
"* Learn about Residual Nets (RESNET).\n",
"* Implement and train a [RESNET network][].\n",
"\n",
"[RESNET network]:https://docs.microsoft.com/en-us/cognitive-toolkit/Hands-On-Labs-Image-Recognition \n",
"[VGG]:http://www.robots.ox.ac.uk/~vgg/research/very_deep/\n",
"[Dropout]:https://en.wikipedia.org/wiki/Dropout_(neural_networks)\n",
"\n",
"## Prerequisites\n",
"\n",
"The CNTK 201A hands-on lab, in which you download and prepare the CIFAR dataset, is a prerequisite for this lab. This tutorial depends on CNTK v2, so you will need to install CNTK v2 before starting this lab. Furthermore, all the tutorials in this lab are written in Python, so you will need a basic knowledge of Python.\n",
"\n",
"The CNTK 102 lab is recommended but not a prerequisite for this tutorial. However, a basic understanding of deep learning is needed. Familiarity with basic convolution operations is highly desirable (refer to CNTK tutorial 103D).\n",
"\n",
"## Dataset\n",
"\n",
"In this tutorial, you will use the CIFAR-10 dataset from https://www.cs.toronto.edu/~kriz/cifar.html. The dataset contains 50000 training images and 10000 test images; every image is 32 x 32 x 3 (height x width x RGB channels). Each image belongs to one of 10 classes, as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src=\"https://cntk.ai/jup/201/cifar-10.png\" width=\"500\" height=\"500\"/>"
],
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Figure 1\n",
"Image(url=\"https://cntk.ai/jup/201/cifar-10.png\", width=500, height=500)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above image is from: https://www.cs.toronto.edu/~kriz/cifar.html\n",
"\n",
"## Convolution Neural Network (CNN)\n",
"\n",
"We recommend completing the CNTK 103D tutorial before proceeding. Here is a brief recap of Convolution Neural Networks (CNNs). A CNN is a feedforward network composed of a stack of layers, where the output of one layer is fed to the next (there are more complex architectures that skip layers; we will discuss one of those at the end of this lab). Usually, a CNN starts by alternating between convolution layers and pooling (downsampling) layers, and ends with fully connected layers for the classification part.\n",
"\n",
"### Convolution layer\n",
"\n",
"A convolution layer consists of multiple 2D convolution kernels applied to the input image or to the output of the previous layer; each convolution kernel outputs one feature map."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src=\"https://cntk.ai/jup/201/Conv2D.png\"/>"
],
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Figure 2\n",
"Image(url=\"https://cntk.ai/jup/201/Conv2D.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The stack of output feature maps is the input to the next layer."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src=\"https://cntk.ai/jup/201/Conv2DFeatures.png\"/>"
],
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Figure 3\n",
"Image(url=\"https://cntk.ai/jup/201/Conv2DFeatures.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Gradient-Based Learning Applied to Document Recognition, Proceedings of the IEEE, 86(11):2278-2324, November 1998\n",
"> Y. LeCun, L. Bottou, Y. Bengio and P. Haffner\n",
"\n",
"#### In CNTK:\n",
"\n",
"Here is the [convolution][] layer in Python, with its main parameters:\n",
"\n",
"```python\n",
"# assumes: import cntk as C\n",
"C.layers.Convolution(filter_shape=(3,3),       # e.g. (3,3)\n",
"                     num_filters=64,           # e.g. 64\n",
"                     activation=C.relu,        # relu or None...etc.\n",
"                     init=C.glorot_uniform(),  # random initialization\n",
"                     pad=True,                 # True or False\n",
"                     strides=(1,1))            # strides e.g. (1,1)\n",
"```\n",
"\n",
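"A minimal NumPy sketch of what a convolution layer computes may help here. It is an illustration, not CNTK's implementation: each kernel slides over all input channels and produces one feature map. The helper `conv2d` is hypothetical. It assumes stride 1, odd kernel sizes, and zero padding for `pad=True`, and, like many deep learning libraries, it computes a cross-correlation (the kernel is not flipped).\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def conv2d(image, kernels, pad=True):\n",
"    # Hypothetical helper. image: (C, H, W); kernels: (F, C, kh, kw).\n",
"    # Returns F feature maps; stride is fixed to 1 for simplicity.\n",
"    c = image.shape[0]\n",
"    f, kc, kh, kw = kernels.shape\n",
"    assert kc == c, 'kernel depth must match the number of input channels'\n",
"    if pad:  # zero padding keeps the spatial size unchanged (odd kernel sizes)\n",
"        image = np.pad(image, ((0, 0), (kh // 2, kh // 2), (kw // 2, kw // 2)), mode='constant')\n",
"    out_h = image.shape[1] - kh + 1\n",
"    out_w = image.shape[2] - kw + 1\n",
"    out = np.zeros((f, out_h, out_w))\n",
"    for k in range(f):  # one output feature map per kernel\n",
"        for i in range(out_h):\n",
"            for j in range(out_w):\n",
"                out[k, i, j] = np.sum(image[:, i:i + kh, j:j + kw] * kernels[k])\n",
"    return out\n",
"\n",
"x = np.random.rand(3, 32, 32)    # one CIFAR-sized RGB image\n",
"w = np.random.rand(32, 3, 5, 5)  # 32 kernels of size 5x5\n",
"print(conv2d(x, w).shape)             # (32, 32, 32)\n",
"print(conv2d(x, w, pad=False).shape)  # (32, 28, 28)\n",
"```\n",
"\n",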
"[convolution]:https://www.cntk.ai/pythondocs/layerref.html#convolution\n",
"\n",
"### Pooling layer\n",
"\n",
"In most CNN vision architectures, each convolution layer is followed by a pooling layer, and the two keep alternating until the fully connected layers.\n",
"\n",
"The purpose of the pooling layer is as follows:\n",
"\n",
"* Reduce the dimensionality of the previous layer's output, which speeds up the network.\n",
"* Provide limited translation invariance.\n",
"\n",
"Here is an example of max pooling with a stride of 2:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src=\"https://cntk.ai/jup/201/MaxPooling.png\" width=\"400\" height=\"400\"/>"
],
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Figure 4\n",
"Image(url=\"https://cntk.ai/jup/201/MaxPooling.png\", width=400, height=400)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### In CNTK:\n",
"\n",
"Here are the [pooling][] layers in Python, with their main parameters:\n",
"\n",
"```python\n",
"# assumes: import cntk as C\n",
"\n",
"# Max pooling\n",
"C.layers.MaxPooling(filter_shape=(3,3),  # e.g. (3,3)\n",
"                    strides=(2,2),       # e.g. (2,2)\n",
"                    pad=False)           # True or False\n",
"\n",
"# Average pooling\n",
"C.layers.AveragePooling(filter_shape=(3,3),  # e.g. (3,3)\n",
"                        strides=(2,2),       # e.g. (2,2)\n",
"                        pad=False)           # True or False\n",
"```\n",
"\n",
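"The pooling arithmetic can also be sketched in plain NumPy. The hypothetical `pool2d` helper below works on a single feature map with no padding (as with `pad=False`). Passing `np.mean` instead of `np.max` turns it into average pooling. The 4x4 input is made-up illustrative data.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def pool2d(fmap, size=2, stride=2, op=np.max):\n",
"    # Hypothetical helper: unpadded pooling over one (H, W) feature map.\n",
"    h, w = fmap.shape\n",
"    out_h = (h - size) // stride + 1\n",
"    out_w = (w - size) // stride + 1\n",
"    out = np.empty((out_h, out_w))\n",
"    for i in range(out_h):\n",
"        for j in range(out_w):\n",
"            window = fmap[i * stride:i * stride + size, j * stride:j * stride + size]\n",
"            out[i, j] = op(window)  # np.max: max pooling, np.mean: average pooling\n",
"    return out\n",
"\n",
"fmap = np.array([[1, 1, 2, 4],\n",
"                 [5, 6, 7, 8],\n",
"                 [3, 2, 1, 0],\n",
"                 [1, 2, 3, 4]])\n",
"print(pool2d(fmap))                              # 2x2 result: [[6, 8], [3, 4]]\n",
"print(pool2d(np.zeros((32, 32)), size=3).shape)  # (15, 15): a 3x3 window with stride 2\n",
"```\n",
"\n",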
"[pooling]:https://www.cntk.ai/pythondocs/layerref.html#maxpooling-averagepooling\n",
"\n",
"### Dropout layer\n",
"\n",
"The dropout layer takes a probability value as input, called the dropout rate. For example, with a dropout rate of 0.5, this layer randomly drops each node of the previous layer with probability 0.5 during training, so about half of the nodes are removed from the network on every forward pass. This behavior helps regularize the network.\n",
"\n",
"> Dropout: A Simple Way to Prevent Neural Networks from Overfitting\n",
"> Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, Ruslan Salakhutdinov\n",
"\n",
"\n",
"#### In CNTK:\n",
"\n",
"Dropout layer in Python:\n",
"\n",
"```python\n",
"# Dropout (assumes: import cntk as C)\n",
"C.layers.Dropout(0.5)  # dropout rate e.g. 0.5\n",
"```\n",
"\n",
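"A minimal NumPy sketch of dropout at training time follows. It uses the common inverted-dropout formulation, which rescales the surviving units by 1 / (1 - rate) so that their expected value is unchanged, and it passes the input through untouched outside training. The helper name and signature are hypothetical; this illustrates the idea rather than CNTK's internals.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def dropout(x, rate, rng, training=True):\n",
"    # Hypothetical helper: inverted dropout.\n",
"    if not training or rate == 0.0:\n",
"        return x                          # nothing is dropped outside training\n",
"    keep = rng.random(x.shape) >= rate    # each unit survives with probability 1 - rate\n",
"    return x * keep / (1.0 - rate)        # rescale the surviving units\n",
"\n",
"rng = np.random.default_rng(0)\n",
"print(dropout(np.ones(10), 0.5, rng))  # a mix of 0.0 and 2.0 values\n",
"```\n",
"\n",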
"### Batch normalization (BN)\n",
"\n",
"Batch normalization is a way to make the input to each layer have zero mean and unit variance, using statistics computed over the current minibatch. BN helps the network converge faster and keeps the input of each layer around zero. BN has two learnable parameters, gamma (a scale) and beta (a shift), which let the network decide for itself whether the normalized input or something closer to the raw input works best.\n",
"\n",
"> Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift\n",
"> Sergey Ioffe, Christian Szegedy\n",
"\n",
"#### In CNTK:\n",
"\n",
"[Batch normalization][] layer in Python:\n",
"\n",
"```python\n",
"# Batch normalization (assumes: import cntk as C)\n",
"C.layers.BatchNormalization(map_rank=1)  # for images, map_rank=1\n",
"```\n",
"\n",
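"The training-time computation of batch normalization can be sketched in NumPy as below. CNTK documents `map_rank=1` as spatial batch normalization, where the statistics are shared across all pixel positions of a channel. Following that setting, the sketch computes one mean and variance per channel over the batch, height, and width axes. The helper name, the `eps` value, and the (N, C, H, W) layout are assumptions. At inference time, running averages of the statistics are normally used instead of minibatch statistics.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def batch_norm(x, gamma, beta, eps=1e-5):\n",
"    # Hypothetical helper. x: (N, C, H, W); gamma, beta: (C,) learnable per-channel parameters.\n",
"    mean = x.mean(axis=(0, 2, 3), keepdims=True)  # per-channel mean over batch and pixels\n",
"    var = x.var(axis=(0, 2, 3), keepdims=True)    # per-channel variance\n",
"    x_hat = (x - mean) / np.sqrt(var + eps)        # zero mean, unit variance\n",
"    return gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)\n",
"\n",
"x = np.random.randn(64, 32, 8, 8) * 5.0 + 3.0  # a minibatch of 64 stacks of 32 feature maps\n",
"y = batch_norm(x, gamma=np.ones(32), beta=np.zeros(32))\n",
"print(y.mean(), y.std())  # close to 0 and 1\n",
"```\n",
"\n",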
"[Batch normalization]:https://www.cntk.ai/pythondocs/layerref.html#batchnormalization-layernormalization-stabilizer\n",
"\n",
"## Microsoft Cognitive Toolkit (CNTK)\n",
"\n",
"CNTK represents computations as highly flexible computation graphs: each node takes tensors as inputs and produces tensors as the result of its computation. Each node is exposed in the Python API, which gives you the flexibility to create any custom graph. You can also define your own nodes in Python or C++, running on CPU, GPU, or both.\n",
"\n",
"For deep learning, you can use the low-level API directly or use the CNTK layers API. In this lab, we will start with the low-level API, then switch to the layers API.\n",
"\n",
"So let's first import the needed modules for this lab."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from __future__ import print_function # Use a function definition from future version (say 3.x from 2.7 interpreter)\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import math\n",
"import numpy as np\n",
"import os\n",
"import PIL\n",
"import sys\n",
"try:\n",
"    from urllib.request import urlopen\n",
"except ImportError:\n",
"    from urllib import urlopen\n",
"\n",
"import cntk as C"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the block below, we check whether we are running this notebook on the CNTK internal test machines by looking for environment variables defined there. If so, we select the right target device (GPU vs CPU) for testing this notebook. Otherwise, we use CNTK's default policy, which picks the best available device (GPU if available, else CPU)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"if 'TEST_DEVICE' in os.environ:\n",
"    if os.environ['TEST_DEVICE'] == 'cpu':\n",
"        C.device.try_set_default_device(C.device.cpu())\n",
"    else:\n",
"        C.device.try_set_default_device(C.device.gpu(0))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src=\"https://cntk.ai/jup/201/CNN.png\"/>"
],
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Figure 5\n",
"Image(url=\"https://cntk.ai/jup/201/CNN.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have imported the needed modules, let's implement our first CNN, shown in Figure 5 above.\n",
"\n",
"We implement this network using the CNTK layers API:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_basic_model(input, out_dims):\n",
"    with C.layers.default_options(init=C.glorot_uniform(), activation=C.relu):\n",
"        net = C.layers.Convolution((5,5), 32, pad=True)(input)\n",
"        net = C.layers.MaxPooling((3,3), strides=(2,2))(net)\n",
"\n",
"        net = C.layers.Convolution((5,5), 32, pad=True)(net)\n",
"        net = C.layers.MaxPooling((3,3), strides=(2,2))(net)\n",
"\n",
"        net = C.layers.Convolution((5,5), 64, pad=True)(net)\n",
"        net = C.layers.MaxPooling((3,3), strides=(2,2))(net)\n",
"\n",
"        net = C.layers.Dense(64)(net)\n",
"        net = C.layers.Dense(out_dims, activation=None)(net)\n",
"\n",
"    return net"
]
},
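{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the model above, you can work out its spatial sizes and number of learnable parameters by hand. The sketch below is plain Python with hypothetical helper names. It makes three assumptions: `Convolution(..., pad=True)` keeps the spatial size, `MaxPooling((3,3), strides=(2,2))` is unpadded (no `pad` argument is passed), and every convolution and dense layer has one bias per output.\n",
"\n",
"```python\n",
"def conv_params(kernel, c_in, c_out):\n",
"    # kernel x kernel x c_in weights per filter, plus one bias per filter\n",
"    return kernel * kernel * c_in * c_out + c_out\n",
"\n",
"def dense_params(n_in, n_out):\n",
"    # full weight matrix plus one bias per output\n",
"    return n_in * n_out + n_out\n",
"\n",
"def pooled_size(side, window=3, stride=2):\n",
"    # output size of unpadded pooling\n",
"    return (side - window) // stride + 1\n",
"\n",
"side = 32\n",
"for _ in range(3):            # padded convolutions keep the size; each pooling shrinks it\n",
"    side = pooled_size(side)  # 32 -> 15 -> 7 -> 3\n",
"\n",
"total = (conv_params(5, 3, 32) + conv_params(5, 32, 32) + conv_params(5, 32, 64)\n",
"         + dense_params(64 * side * side, 64)  # Dense(64) on the flattened 64 x 3 x 3 output\n",
"         + dense_params(64, 10))               # Dense(out_dims) with out_dims = 10\n",
"print(side, total)  # 3 116906\n",
"```\n",
"\n",
"Under these assumptions the total is 116,906 parameters in 10 tensors (5 weight tensors and 5 bias tensors), which agrees with the count that CNTK logs when training starts."
]
},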
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To train the above model we need two things:\n",
"* Read the training images and their corresponding labels.\n",
"* Define a cost function, compute the cost for each mini-batch and update the model weights according to the cost value.\n",
"\n",
"To read the data in CNTK, we will use CNTK readers, which handle data augmentation and can fetch data in parallel. The image reader is driven by a map file: a text file in which each line holds the path of an image and its integer class label, separated by a tab.\n",
"\n",
"Example of a map text file:\n",
"\n",
"    S:\\data\\CIFAR-10\\train\\00001.png\t9\n",
"    S:\\data\\CIFAR-10\\train\\00002.png\t9\n",
"    S:\\data\\CIFAR-10\\train\\00003.png\t4\n",
"    S:\\data\\CIFAR-10\\train\\00004.png\t1\n",
"    S:\\data\\CIFAR-10\\train\\00005.png\t1\n"
]
},
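{
"cell_type": "markdown",
"metadata": {},
"source": [
"The map file format is simple enough to inspect by hand, which can help when you check the output of the 201A preparation step. The sketch below uses a hypothetical helper (not part of CNTK) that splits each line at the tab. It then converts the labels into one-hot vectors of length `num_classes`, which is the label shape the reader below is configured with (`shape=num_classes`). The sample lines are made up for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def parse_map_lines(lines, num_classes=10):\n",
"    # Hypothetical helper: each line is '<image path><TAB><integer class label>'.\n",
"    paths, labels = [], []\n",
"    for line in lines:\n",
"        path, label = line.rstrip('\\n').split('\\t')\n",
"        paths.append(path)\n",
"        labels.append(int(label))\n",
"    one_hot = np.eye(num_classes)[labels]  # shape: (number of images, num_classes)\n",
"    return paths, one_hot\n",
"\n",
"sample = ['train/00001.png\\t9\\n', 'train/00003.png\\t4\\n']\n",
"paths, one_hot = parse_map_lines(sample)\n",
"print(paths)                   # ['train/00001.png', 'train/00003.png']\n",
"print(one_hot.argmax(axis=1))  # [9 4]\n",
"```"
]
},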
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Determine the data path for testing\n",
"# Check for an environment variable defined in CNTK's test infrastructure\n",
"envvar = 'CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY'\n",
"def is_test(): return envvar in os.environ\n",
"\n",
"if is_test():\n",
"    data_path = os.path.join(os.environ[envvar],'Image','CIFAR','v0','tutorial201')\n",
"    data_path = os.path.normpath(data_path)\n",
"else:\n",
"    data_path = os.path.join('data', 'CIFAR-10')\n",
"\n",
"# model dimensions\n",
"image_height = 32\n",
"image_width = 32\n",
"num_channels = 3\n",
"num_classes = 10\n",
"\n",
"import cntk.io.transforms as xforms\n",
"#\n",
"# Define the reader for both training and evaluation action.\n",
"#\n",
"def create_reader(map_file, mean_file, train):\n",
"    print(\"Reading map file:\", map_file)\n",
"    print(\"Reading mean file:\", mean_file)\n",
"\n",
"    if not os.path.exists(map_file) or not os.path.exists(mean_file):\n",
"        raise RuntimeError(\"This tutorial depends on the 201A tutorial; please run 201A first.\")\n",
"\n",
"    # transformation pipeline for the features has jitter/crop only when training\n",
"    transforms = []\n",
"    # train uses data augmentation (translation only)\n",
"    if train:\n",
"        transforms += [\n",
"            xforms.crop(crop_type='randomside', side_ratio=0.8)\n",
"        ]\n",
"    transforms += [\n",
"        xforms.scale(width=image_width, height=image_height, channels=num_channels, interpolations='linear'),\n",
"        xforms.mean(mean_file)\n",
"    ]\n",
"    # deserializer\n",
"    return C.io.MinibatchSource(C.io.ImageDeserializer(map_file, C.io.StreamDefs(\n",
"        features = C.io.StreamDef(field='image', transforms=transforms), # first column in map file is referred to as 'image'\n",
"        labels   = C.io.StreamDef(field='label', shape=num_classes)      # and second as 'label'\n",
"    )))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading map file: c:\\Data\\CNTKTestData\\Image\\CIFAR\\v0\\tutorial201\\train_map.txt\n",
"Reading mean file: c:\\Data\\CNTKTestData\\Image\\CIFAR\\v0\\tutorial201\\CIFAR-10_mean.xml\n",
"Reading map file: c:\\Data\\CNTKTestData\\Image\\CIFAR\\v0\\tutorial201\\test_map.txt\n",
"Reading mean file: c:\\Data\\CNTKTestData\\Image\\CIFAR\\v0\\tutorial201\\CIFAR-10_mean.xml\n"
]
}
],
"source": [
"# Create the train and test readers\n",
"reader_train = create_reader(os.path.join(data_path, 'train_map.txt'), \n",
" os.path.join(data_path, 'CIFAR-10_mean.xml'), True)\n",
"reader_test = create_reader(os.path.join(data_path, 'test_map.txt'), \n",
" os.path.join(data_path, 'CIFAR-10_mean.xml'), False) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let us write the training and validation loop."
]
},
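{
"cell_type": "markdown",
"metadata": {},
"source": [
"One detail of the loop below is worth unpacking: momentum is given as a time constant. This assumes CNTK's definition, in which the per-sample momentum equals exp(-1 / time_constant). Under that definition, the value `-minibatch_size / np.log(0.9)` corresponds to a momentum of 0.9 per 64-sample minibatch. The arithmetic in plain Python:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"minibatch_size = 64\n",
"momentum_per_minibatch = 0.9\n",
"\n",
"# value passed to C.momentum_as_time_constant_schedule in train_and_evaluate\n",
"time_constant = -minibatch_size / math.log(momentum_per_minibatch)\n",
"momentum_per_sample = math.exp(-1.0 / time_constant)\n",
"\n",
"print(round(time_constant, 1))                          # 607.4\n",
"print(round(momentum_per_sample, 6))                    # 0.998355\n",
"print(round(momentum_per_sample ** minibatch_size, 6))  # 0.9\n",
"```\n",
"\n",
"The per-sample value agrees with the `Momentum per sample` line printed when training starts."
]
},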
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"#\n",
"# Train and evaluate the network.\n",
"#\n",
"def train_and_evaluate(reader_train, reader_test, max_epochs, model_func):\n",
"    # Input variables denoting the features and label data\n",
"    input_var = C.input_variable((num_channels, image_height, image_width))\n",
"    label_var = C.input_variable((num_classes))\n",
"\n",
"    # Normalize the input\n",
"    feature_scale = 1.0 / 256.0\n",
"    input_var_norm = C.element_times(feature_scale, input_var)\n",
"\n",
"    # apply model to input\n",
"    z = model_func(input_var_norm, out_dims=num_classes)\n",
"\n",
"    #\n",
"    # Training action\n",
"    #\n",
"\n",
"    # loss and metric\n",
"    ce = C.cross_entropy_with_softmax(z, label_var)\n",
"    pe = C.classification_error(z, label_var)\n",
"\n",
"    # training config\n",
"    epoch_size = 50000\n",
"    minibatch_size = 64\n",
"\n",
"    # Set training parameters\n",
"    lr_per_minibatch = C.learning_rate_schedule([0.01]*10 + [0.003]*10 + [0.001],\n",
"                                                C.UnitType.minibatch, epoch_size)\n",
"    momentum_time_constant = C.momentum_as_time_constant_schedule(-minibatch_size/np.log(0.9))\n",
"    l2_reg_weight = 0.001\n",
"\n",
"    # trainer object\n",
"    learner = C.momentum_sgd(z.parameters,\n",
"                             lr = lr_per_minibatch,\n",
"                             momentum = momentum_time_constant,\n",
"                             l2_regularization_weight=l2_reg_weight)\n",
"    progress_printer = C.logging.ProgressPrinter(tag='Training', num_epochs=max_epochs)\n",
"    trainer = C.Trainer(z, (ce, pe), [learner], [progress_printer])\n",
"\n",
"    # define mapping from reader streams to network inputs\n",
"    input_map = {\n",
"        input_var: reader_train.streams.features,\n",
"        label_var: reader_train.streams.labels\n",
"    }\n",
"\n",
"    C.logging.log_number_of_parameters(z) ; print()\n",
"\n",
"    # perform model training\n",
"    batch_index = 0\n",
"    plot_data = {'batchindex':[], 'loss':[], 'error':[]}\n",
"    for epoch in range(max_epochs):       # loop over epochs\n",
"        sample_count = 0\n",
"        while sample_count < epoch_size:  # loop over minibatches in the epoch\n",
"            data = reader_train.next_minibatch(min(minibatch_size, epoch_size - sample_count),\n",
"                                               input_map=input_map) # fetch minibatch.\n",
"            trainer.train_minibatch(data)                        # update model with it\n",
"\n",
"            sample_count += data[label_var].num_samples          # count samples processed so far\n",
"\n",
"            # For visualization...\n",
"            plot_data['batchindex'].append(batch_index)\n",
"            plot_data['loss'].append(trainer.previous_minibatch_loss_average)\n",
"            plot_data['error'].append(trainer.previous_minibatch_evaluation_average)\n",
"\n",
"            batch_index += 1\n",
"        trainer.summarize_training_progress()\n",
"\n",
"    #\n",
"    # Evaluation action\n",
"    #\n",
"    epoch_size = 10000\n",
"    minibatch_size = 16\n",
"\n",
"    # process minibatches and evaluate the model\n",
"    metric_numer = 0\n",
"    metric_denom = 0\n",
"    sample_count = 0\n",
"    minibatch_index = 0\n",
"\n",
"    while sample_count < epoch_size:\n",
"        current_minibatch = min(minibatch_size, epoch_size - sample_count)\n",
"\n",
"        # Fetch the next test minibatch.\n",
"        data = reader_test.next_minibatch(current_minibatch, input_map=input_map)\n",
"\n",
"        # Evaluate the model on this minibatch and accumulate the weighted error.\n",
"        metric_numer += trainer.test_minibatch(data) * current_minibatch\n",
"        metric_denom += current_minibatch\n",
"\n",
"        # Keep track of the number of samples processed so far.\n",
"        sample_count += data[label_var].num_samples\n",
"        minibatch_index += 1\n",
"\n",
"    print(\"\")\n",
"    print(\"Final Results: Minibatch[1-{}]: errs = {:0.1f}% * {}\".format(minibatch_index, (metric_numer*100.0)/metric_denom, metric_denom))\n",
"    print(\"\")\n",
"\n",
"    # Visualize training result:\n",
"    window_width = 32\n",
"    loss_cumsum = np.cumsum(np.insert(plot_data['loss'], 0, 0))\n",
"    error_cumsum = np.cumsum(np.insert(plot_data['error'], 0, 0))\n",
"\n",
"    # Moving average.\n",
"    plot_data['batchindex'] = np.insert(plot_data['batchindex'], 0, 0)[window_width:]\n",
"    plot_data['avg_loss'] = (loss_cumsum[window_width:] - loss_cumsum[:-window_width]) / window_width\n",
"    plot_data['avg_error'] = (error_cumsum[window_width:] - error_cumsum[:-window_width]) / window_width\n",
"\n",
"    plt.figure(1)\n",
"    plt.subplot(211)\n",
"    plt.plot(plot_data[\"batchindex\"], plot_data[\"avg_loss\"], 'b--')\n",
"    plt.xlabel('Minibatch number')\n",
"    plt.ylabel('Loss')\n",
"    plt.title('Minibatch run vs. Training loss ')\n",
"\n",
"    plt.show()\n",
"\n",
"    plt.subplot(212)\n",
"    plt.plot(plot_data[\"batchindex\"], plot_data[\"avg_error\"], 'r--')\n",
"    plt.xlabel('Minibatch number')\n",
"    plt.ylabel('Label Prediction Error')\n",
"    plt.title('Minibatch run vs. Label Prediction Error ')\n",
"    plt.show()\n",
"\n",
"    return C.softmax(z)"
]
},
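{
"cell_type": "markdown",
"metadata": {},
"source": [
"The plotting code in `train_and_evaluate` smooths the per-minibatch loss and error with a moving average computed from cumulative sums. This takes linear time regardless of the window width. Here is the trick in isolation, as a hypothetical helper checked against a direct `np.convolve` computation:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def moving_average(values, window):\n",
"    # Prepend 0 so that csum[k] is the sum of the first k values.\n",
"    csum = np.cumsum(np.insert(values, 0, 0))\n",
"    # The sum of values[k - window:k] is csum[k] - csum[k - window].\n",
"    return (csum[window:] - csum[:-window]) / window\n",
"\n",
"values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])\n",
"print(moving_average(values, 2))  # [1.5 2.5 3.5 4.5]\n",
"direct = np.convolve(values, np.ones(2) / 2, mode='valid')\n",
"print(np.allclose(moving_average(values, 2), direct))  # True\n",
"```"
]
},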
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training 116906 parameters in 10 parameter tensors.\n",
"\n",
"Learning rate per minibatch: 0.01\n",
"Momentum per sample: 0.9983550962823424\n",
"Finished Epoch[1 of 5]: [Training] loss = 2.127933 * 50000, metric = 77.96% * 50000 15.851s (3154.4 samples/s);\n",
"Finished Epoch[2 of 5]: [Training] loss = 1.768975 * 50000, metric = 64.90% * 50000 11.970s (4177.1 samples/s);\n",
"Finished Epoch[3 of 5]: [Training] loss = 1.589650 * 50000, metric = 58.28% * 50000 11.958s (4181.3 samples/s);\n",
"Finished Epoch[4 of 5]: [Training] loss = 1.495402 * 50000, metric = 54.32% * 50000 11.960s (4180.6 samples/s);\n",
"Finished Epoch[5 of 5]: [Training] loss = 1.421285 * 50000, metric = 51.41% * 50000 11.965s (4178.9 samples/s);\n",
"\n",
"Final Results: Minibatch[1-626]: errs = 46.4% * 10000\n",
"\n"
]
},
{
"data": {
2017-05-16 22:15:54 +03:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXmYVNW1t98fIIgREIegOKPJNY7g54DRKEJMFCMiRhNn\njVOMRqM316i50cQhMdcYBzRGcUjEgag4ICpKEJUhOAGCAgYFEVAmZZ4b1vfHOpWqbqq6q5uqrqp2\nvc9TzzlnT2ed3V1n1d5r7b1kZgRBEARBXTQrtQBBEARBZRAKIwiCIMiLUBhBEARBXoTCCIIgCPIi\nFEYQBEGQF6EwgiAIgrwIhfEVRdI9kn5d37KSjpA0s7jS/ee+0yV1b4x7NQWSv83EQpdtgBwjJJ1Z\njLaD0tKi1AIEhUXSJ8C2QEcz+zIjfRywH7CLmX1qZhfl22aWsg1avCNpZ2A60MLM1jekjaaCpMOA\nl/C+bAZsBiwDlKTtaWaz6tOmmb0O7FPoskGQIkYYTQ/DX8qnpBIk7Q20poEv+gKSehmq6DeSmhf7\nHhuDmY00szZm1hbYC++Xdqm0mspCCSURNggSQmE0TfoDZ2VcnwX8PbOApIckXZ+cHyFppqQrJM2V\nNFvS2dnKppN0taT5kqZJOjUjo6eksZIWS5oh6bqMeq8nx0WSlkg6OKlzvqRJSdr7kjpn1Oki6T1J\nCyU9LqlltgeWdJakkZL+LGk+cJ2k6yT1zyizs6T1kpol18MlXZ/UWyJpiKQtc7Q/SVLPjOvmkuZJ\n6iyplaT+khYkcr4paZts7dRBNYWQTO1cL2k0PvrYUdK5GX01VdK5GeV7SJqecT1T0uWSJiRyPSpp\nk/qWTfKvlvR5Uu68pB93qvOBnGslfSJpjqQHJbVJ8lon90n125hU/yfPOT15zo8kndyA/gwKTCiM\npskYoI2k/0pejj8CHqH2X/bbAm2AjsB5wN2S2tVSdsuk7NnAfZK+keQtA84ws3bAscBPJfVK8g5P\njm2TX9FvSjoJuBY4Pfm13Qv4IuNeJwHfA3bFp9TOruUZDgY+AjoANyVpNUdVNa9PwRXqNkAr4Jc5\n2n4MODXj+mhgvpmNT+q3BbbH++WnwMpa5KwPp+PP3BaYDcwBjkn66nygbzKCTFHz+U4CegCdgAOA\nM+pbVtIPgIuBI4BvAt2z1M3F+Xi/HQ7shvfP7UneOfjIt2OS/jNgVaJQbgV6JM95KDAhz/sFRSQU\nRtMlNco4CpgMfFZH+TXADWa2zsxewl/8/5WjrAG/MbO1ZvYG8AJwMoCZvWFmHyTn7wMD8BdNJpmK\n61zg/8xsbFJnmpllGtXvMLO5ZrYIeB7IHH3UZLaZ/cXM1pvZ6jqeN8VDZvZxUv6JWtp/HOgladPk\n+pQkDWAtsBXwTXPGmdmyPO9fFw+a2b+Tv8s6M3vBzGYAmNlrwDDgO7XUv83M5pvZQmAwtfdfrrIn\nAQ8kcqwEflcP+U8F/pTYzZYD15BWvGuBrUn321gzW5HkrQf2kdQq+ftPqcc9gyIRCqPp8gj+xTwb\neDiP8l/UMESvADbPUXahma3KuJ6B/0pE0sGSXk2maxYBF+IvhVzsCHxcS/7cPGUCaIj31px82jez\nj4FJwHGSWuMjoceS7P7Ay8AASbMk3azC2VCqPZOkHyRTN19IWoj/IKitf+vTf7nKdqwhx0zyt0N1\nxP8/UswAWiVTdn8D/gk8kUx1/V5SMzNbiivkS4A5kgZljGCDEhIKo4liZp/ixu9jgKcL3Hz75KWZ\nYifSI5hHgWeB7c1sC+Be0i+XbNMYM/GpikJQs/3luPdRiu02sv0BuBI+HvjAzKYBmFmVmd1gZnsB\n3waOAwrlVvqfZ0pGN0/i023bmFl7YCjFdyL4HNgh43on8p+S+gzYOeN6Z2B1MpJZa2bXm9mewGFA\nH+A0ADN72cyOwqc/P8b/j4ISEwqjafMToHsy
jVBIBPxO0iaSvoPbKp5I8jbHRyBrJR1E9Xn/+fhU\nQ6aCuB/4paT9ASTtJmnHAsk5Hjhc0o6JPeaqjWxvAG5PuYj06AJJ3STtndiLluFTLfV1G87npd8K\n2ARYAFhiW+hRz/s0hCeAcyV9U9JmwP/Wo+7jwBWJw0Eb4EaSvpN0pKS9JImMfpO0bTKSag1U4Yp/\nXSEfKGgYoTCaHv/55Wdm01O2gZp59WknC58DC/Ffj/2BC81sapL3M+AGSYvxF8s/MuRZif86HiXp\nS0kHmdlTSdpjkpYAz+AG0PrKu+EDmP0zuf8E4G3cBlKtSD3bmwP8C+hKxnPhv4KfAhYDHwDD8X5J\nLXr8Sz7N15VmZouBy/ER3Bf4L/Kaz1RXm/Uua2aDgXuAN4APgZFJVi47UWZb/fC+GoE7JCwGfpHk\ndcRHv4uBicAruDJpDvwP/v81HzgEN7oHJUbFDKAkaQd8/rwD/ourn5ndmaPsgcBo4EdmVugplCAI\nCkTilfWumbUqtSxB41LsEUYVcEUyt3sIcLGkPWoWSobyN+OGwyAIygxJvZMpyC3x7+qzpZYpaHyK\nqjDMbE7ip07iZjgZ91Wvyc/xIf28YsoTBEGDuRi3nfwb96C6pLTiBKWg0faSkrQL7tf9Zo30jkBv\nMzsyMZIGQVBmJB5LwVecRlEYkjbHRxCXZVnQdDvwq8ziOdoo9T5IQRAEFYmZFcT1uuheUpJa4Mqi\nv5k9l6XIAfiCp+nAD/EtKXplKYeZlf3nuuuuK7kMIWfIWakyhpyF/xSSxhhhPAhMMrM7smWaWafU\nuaSHgOfNbFAjyBUEQRDUg6IqDEmH4is3J8rjMRi+l8zOgJnZfTWqxLRTEARBmVJUhWFmo/BFOPmW\n/0kRxWkUunXrVmoR8iLkLCyVIGclyAghZzlT1IV7hUSSVYqsQRAE5YIkrFKM3kEQBEHTIBRGEARB\nkBeNtnCvULzzDvTqBV26wLe+BbvvDnvsAXvtBds0JChmEARBkBcVZ8NYvx5mzYJ334V//xumToXJ\nk2HnneGxxzasV1XlxxYVpxqDIAg2nkLaMCpOYdSX++6Dyy+H006DE06A73wHNq8t5lgQBEETIoze\n9eCCC+DDD6FTJ7jhBthuO/je92DSpFJLFgRBUFk0+RFGTVasgIEDoUcP6NixAIIFQRCUMRUzJZVP\nACVJp5LefHApcJGZTczSVqzDCIIgqCeVNCWVTwClacDhZrYfHu+3X5FlysqaNfDjH8Pbb5fi7kEQ\nBOVPyQMomdkY81jFAGNq5jcWLVvCwQdDz54wenQpJAiCIChvSh5AqQbnAS81hjzZuPxyX9dxxhlu\nKA9X3CAIgjTlEEApVeZI4BzgsFzt/Pa3v/3Pebdu3Yqy+ddxx8Ett8BDD8H55xe8+SAIgqLy2muv\n8dprrxWl7aJ7SSUBlAYDL+WKiSFpX2AgcLSZfZyjTKMZvd98E7p2dY+q1q0b5ZZBEARFoWK8pAAk\nPQwsMLMrcuTvBAwDzjCzMbW006heUqNGwaGHNtrtgiAIikLFKIwkgNIbwEQ8ONIGAZQk9QP6ADPw\neN5rzeygLG2FW20QBEE9qRiFUUhKqTDeftu9qPbbryS3D4IgaDCVtA6jSfCvf8Gf/1xqKYIgCEpL\njDDy4OOP3d120CD3ogqCIKgUYkqqkTGD00+H996DCROgWYzLgiCoEGJKqpGR4OGHYc4cuPtuT+vU\nydPnzCmtbEEQBI1FKIw8ad7cY2scdpiPOGbP9vQirB0MgiAoS2JKaiMYNcoVyMqVsOmmpZYmCIJg\nQ8KGUUasWeMut0EQBOVI2DDKiJSymDIFHn/cFUiKTz4piUhBEARFoagKQ9IOkl6V9IGkiZIuzVHu\nTklTJY2X1LmYMhWLkSPh1FOhbVuforr1Vth1V3ipZHvvBkEQFJaSB1CSdAywm5l9A7gQ+GuRZSoK\nP/mJK4zV
q+HKK2H//T29Z09YsKC0sgVBEBSCkgdQAo7Hw7hiZm8C7SR1KKZcxaBZM3j0UfjgA7jx\nRjjySFi71vNGj4ZVq3z
"text/plain": [
"<matplotlib.figure.Figure at 0x2ec6cb8f908>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXecFeX1/98fERELYC/YEmNDYyyxJDbEht1oLGgsmHzt\nMbaomIKx/NRYsWtiwd4bauyK2AsoSBNFUBBRlKYgbc/vjzOXO3v37t3Z5e7du3Ler9e8ZuaZZ57n\nzLN358x5yjkyM4IgCIKgIRZpaQGCIAiC1kEojCAIgiAToTCCIAiCTITCCIIgCDIRCiMIgiDIRCiM\nIAiCIBOhMKoMSTdI+ltj80raQdIXzSvd/Ho/k9StEnW1NiTVSPp5pe9dUNJ1N+Y3WKSc6ZLWKqds\nQfUQCqNCSBoj6UdJyxakD0r+WdcAMLPjzezCLGUWydukRTWS1kxkiN8DC9weC7Kwqd57Jb0iaaak\naZK+lvSwpJUWoK566876G5T0sqSjaxVitrSZjSmjXLm6xkiakTz/9GR/dbnrCUoTL4jKYcBnQI9c\ngqSNgPYs2EumHCiRQc1ekdSmuesoE039myxIG5a614ATzKwDsC7QCbiyaCFNU3TN/rdfQAzY08w6\nJEqpg5mdXCxjsd9YY393reh3WlFCYVSWO4EjU+dHAn3TGSTdJum85HgHSV9IOk3SREnjJR1VLG8+\nSb0kfSNptKRDUxf2kDRQ0lRJYyX1Tt3XP9lPSb7ctkru+T9Jw5K0jyRtkrpnU0kfSpos6V5JixV7\nYElHSnpN0hWSvgF6S+ot6c5Unlpf9MmX63nJfdMkPVNomaXuHSZpj9R5m+QLfBNJ7STdKWlSIufb\nklYoVk5WJG0h6Y2kvPGSrpG0aEG2PSV9msjx74L7j05k/lbS/3KWZdbqAcxsCvAwsFFS5m2Srpf0\nlKTpQFdJi0m6LPlbT0iut0vJ8VdJX0oaJ6knKQVZ+LuStG9iCU+VNErSrpIuALYDrk1/7at211YH\nSXck7fCZUt1cye9igKRLJX2XtFf3LM9fJ7H2b2wS/hsrliZJf5dbK19Jul1Sh6SM3G/waEljgRez\n/lEWJkJhVJa3gKUlrZe8HA8G7qL0193KwNLAqsCfgOskdSyRd9kk71HAzZLWSa59DxxuZh2BPYHj\nJO2TXNs+2XdIvtzelnQg8E/gD8lX7T7At6m6DgR2BX4G/Cqprz62Aj4BVgJyXR2FX/CF5z1whboC\n0A44o56y7wEOTZ13B74xsw+S+zsAnfF2OQ6YWULOLMwDTknK+w3QDTihIM9+wGbJtq+SbhtJ+wJn\nJ9dXAAYA9zZWAEnLAwcAA1PJPYDzzWxp4HXgEuAXwMbJvjP+9yR5MZ8G7ASsA+xcoq4t8Y+a05Pf\nzvbAGDP7eyL/SQVf++m/47X4b3ctoCtwRKKccmwJDAeWAy4FbmlMOxSQ+42tSP43VpjWEzgC2AH4\neSLbtQXlbA+sD+y2ALL8ZAmFUXlyVsYu+D/Llw3kn42/COaZ2f/wF/969eQ14B9mNsfMXgWeAg4C\nMLNXzWxocvwRcB/+j5Mmrbj+CPzbzAYm94w2s/Sgeh8zm5h87fYD0tZHIePN7HozqzGzWQ08b47b\nzOzTJP8DJcq/F9hH0uLJeQ/yL+E5+MtoXXMGmdn3GesvipkNNLN3kvI+B26mbjtebGZTzWwccBX5\nbshjgYvM7GMzqwEuBjaRtHrG6q+R9B0wCP/dnJ669riZvZXIOAv4P+DURI4fkrpychyIt+9wM5sJ\nnFuizqOBW8zspaTsCWb2cYn8gvndYgcDZ5vZDDMbC1wOHJ7KO9bMbjV3aNcXWFnSiiXKfiyxRiYn\n+z+mrhX7jRWmHQpcYWZjzWwG0As4RPkuPAN6m9nMRvxOFyoKTemg+bkLeBX/Mr8jQ/5vk5dLjhnA\nUvXknWxmP6bOx+LWBvJupovwbozFku3BEvWu
Dnxa4vrEAplWKZG3KbO3vioov+gzm9mnkoYBe0t6\nEreE/plcvhNYDbgvscruAv5mZvOaIA8AicV2BfBrfPxpUeD9gmzjUsfz/wbAmkAfSZfnisNfUp3J\n1kZ/NrNb67k2//6k220J4H1p/jfAIuQ/CFYF3iuQsT4rd3X8w6OxLI+3zecF9XROnc//G5vZTLmw\nSwFf11Pmvmb2cj3XirVfYdqqiQxpeRbFLd8c4wjqJSyMCpN8lX4G7A48Uubil5HUPnW+BnkL5m7g\nMaCzmXUCbiL/kig2wPsFsHaZ5Cos/wf8hZajlLLJwn341+O+wFAzGw1gZnPN7Hwz2xD4LbA33iWx\nINyAW4ZrJ+34N+q+bNMWw5rk/wZfAMea2bLJtoyZLZWzDBaQdBtPwpXshqm6OiVdSgATishY3yB/\nqd9BqYkBk3ALb82CesaXuKchGpoU0FDal0XkmUPtj5+WnoBS1YTCaBmOBrol3QHlRMC/JLWVtB0+\nVvFAcm0p3AKZk/RLp/v9vwFqqP1i+C9whqTNACSt3Yiuk4b4ANhe0urJl//ZC1jeffh4yvH4mAYA\nkrpK2ijpcvgefznUFC+iDgIWTwbOc5vwfu9pZjZD0vpJnYX8VVKnpL1OTuQDuBE4R1KXRL6Okn7f\n+MctTdLF8x/gqsTaQFJnSbsmWR4AjpK0gaQlyFtkxbgF6Clpx2TQeFVJuS7RifhYQDEZapJ6LpS0\nlKQ1gVNxq6+luBc4VdJakpbCxzXuS1nw1T5TrMUJhVE50vPcP8uNDRRea0w5RZgATMa/pO7Ev2ZH\nJddOAM6XNBX4O3B/Sp6Z+D/P60nf8JZm9lCSdo+kacCj+EBvY+Wt+wBmLyT1DwbexcdAamVpZHlf\nAW8CW5N6LnwSwEPAVGAo8DLJC0u+OO36UsUC0/Ev9ZnJfkd83OCwpE1uIq8M0vc9jndTDUye7dZE\nzsfwsYT7JE3Bn797wb2l5GnMtbPwAd+3krqew6fjYmbP4GMrLwEfU2JGkJm9iw8WX4W34yu45QrQ\nBzhQPuPrqiKynIy322i8G/YuM7utkc+Rpl8yIyu3PdxA/kJuxf/+r+LdrTMSGbPWv9Cj5gygJOkW\nYC9gopltXE+eq/HumR+Ao5LZLUEQBEGV0dwWxm2UmJ4maXe8L3gdfAbJjc0sTxAEQdBEmlVhmNlr\neBdJfexLMlPIzN4GOqq87g6CIAiCMlFSYchXzdY3ja0cFE4nHE/taXdBEARBlVByHYaZzUuWy3c0\ns6mVEqoYkmJAKgiCoAmYWVlmgGXpkvoeGCLpFklX57ZyVI5bFOmpmqtRYp62mVX91rt37xaXIeQM\nOVurjCFn+bdykmWl9yMs2AIzUf/85ieAE4H7JW0NTDGzifXkDYIgCFqQBhWGmfWVeyJdN0kaaWZz\nshQu6R7c6dhykj4HeuMuKczMbjazp+VeVD/Bp9X2rL+0IAiCoCVpUGFI6oo7BhuDWwqrSzrS3Lld\nSczs0Ax5TmpYzNZD165dW1qETISc5aU1yNkaZISQs5ppcOGepPeBQ81sZHK+LnCvmW1eAfnScli5\n++OCIAh+6kjCKjjo3TanLADMXRu3LUflQRAEQeshy6D3e5L+i7uGBjiM2q6RgyAIgoWALF1S7fCZ\nTNsmSQOA663CAUaiSyoIgqDxlLNLqqTCkAdCv8PMDitHZQtCKIwgCILGU7ExDPPIZGsm02qDIAiC\nhZgsYxij8TgJT+BrJQAwsyuaTaogCIKg6siiMD5NtkXwaGNBEATBQkhJhZGMYSxtZmc0tQJJ3fFo\nXYsAt5jZJQXXO+GRsNbGI5sdbWbDmlpfEARB0DxkGcPYpqmFJ7GUr8WDKG0I9EjiIKc5BxhkZr8C\njgTK5dgwCIIgKCNZFu59IOkJSYdL2j+3ZSx/S2CUmY1N/E/dhwdNStMFjy1MskBwrVzg+qIMHw6n\nnZax+iAI
gqBcZFEYiwPfAt2AvZNtr4zlFwZIGkfdAEkfAvsDSNoSDzC/Wr0ljh4NV14JQ4dmFCEI\ngiAoB1m81Ta3B9mLgT6
"text/plain": [
"<matplotlib.figure.Figure at 0x2ec09a99e10>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pred = train_and_evaluate(reader_train, \n",
" reader_test, \n",
" max_epochs=5, \n",
" model_func=create_basic_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Although this model is very simple, it is still more verbose than it needs to be. Here is the same model written more tersely with `C.layers.Sequential` and `C.layers.For`:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_basic_model_terse(input, out_dims):\n",
"\n",
" with C.layers.default_options(init=C.glorot_uniform(), activation=C.relu):\n",
" model = C.layers.Sequential([\n",
" C.layers.For(range(3), lambda i: [\n",
" C.layers.Convolution((5,5), [32,32,64][i], pad=True),\n",
" C.layers.MaxPooling((3,3), strides=(2,2))\n",
" ]),\n",
" C.layers.Dense(64),\n",
" C.layers.Dense(out_dims, activation=None)\n",
" ])\n",
"\n",
" return model(input)"
]
},
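{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, you can work out the parameter count of `create_basic_model_terse` by hand. The sketch below is plain Python, not CNTK, and rests on two assumptions. First, `MaxPooling((3,3), strides=(2,2))` uses no padding (`pad=False`, which we believe is the CNTK default). Second, `BatchNormalization(map_rank=1)` learns one scale and one bias per channel. Under those assumptions, the sketch reproduces the 116906 parameters in 10 parameter tensors reported when this model is trained. It also reproduces the 117290 parameters in 18 tensors reported for the batch-normalized variant later in this notebook.\n",
"\n",
"```python\n",
"# Hand count of learnable parameters for create_basic_model_terse\n",
"# (plain Python sketch; layer sizes copied from the model definition).\n",
"\n",
"def conv_params(kernel, in_ch, out_ch):\n",
"    kh, kw = kernel\n",
"    return kh * kw * in_ch * out_ch + out_ch  # weights + bias\n",
"\n",
"def dense_params(in_dim, out_dim):\n",
"    return in_dim * out_dim + out_dim  # weights + bias\n",
"\n",
"def pool_out(size, window=3, stride=2):\n",
"    # unpadded pooling: number of window positions along one axis\n",
"    return (size - window) // stride + 1\n",
"\n",
"def count_params(batch_norm=False, image_size=32, in_ch=3, num_classes=10):\n",
"    total, tensors = 0, 0\n",
"    size, ch = image_size, in_ch\n",
"    for filters in [32, 32, 64]:\n",
"        total += conv_params((5, 5), ch, filters)\n",
"        tensors += 2\n",
"        if batch_norm:\n",
"            total += 2 * filters  # per-channel scale and bias\n",
"            tensors += 2\n",
"        ch = filters  # pad=True keeps the spatial size through the convolution\n",
"        size = pool_out(size)  # 32 -> 15 -> 7 -> 3\n",
"    flat = size * size * ch  # 3 * 3 * 64 = 576 inputs to the first Dense layer\n",
"    total += dense_params(flat, 64)\n",
"    tensors += 2\n",
"    if batch_norm:\n",
"        total += 2 * 64\n",
"        tensors += 2\n",
"    total += dense_params(64, num_classes)\n",
"    tensors += 2\n",
"    return total, tensors\n",
"\n",
"print(count_params())                 # (116906, 10)\n",
"print(count_params(batch_norm=True))  # (117290, 18)\n",
"```\n",
"\n",
"The three 5x5 convolutions account for 79328 of the 116906 parameters, about two thirds of the total. The first `Dense(64)` layer holds most of the remainder (36928)."
]
},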
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training 116906 parameters in 10 parameter tensors.\n",
"\n",
"Learning rate per minibatch: 0.01\n",
"Momentum per sample: 0.9983550962823424\n",
"Finished Epoch[1 of 10]: [Training] loss = 2.064481 * 50000, metric = 75.90% * 50000 12.260s (4078.3 samples/s);\n",
"Finished Epoch[2 of 10]: [Training] loss = 1.703777 * 50000, metric = 63.17% * 50000 12.127s (4123.0 samples/s);\n",
"Finished Epoch[3 of 10]: [Training] loss = 1.561847 * 50000, metric = 57.11% * 50000 12.116s (4126.8 samples/s);\n",
"Finished Epoch[4 of 10]: [Training] loss = 1.463862 * 50000, metric = 53.17% * 50000 12.078s (4139.8 samples/s);\n",
"Finished Epoch[5 of 10]: [Training] loss = 1.374724 * 50000, metric = 49.49% * 50000 12.117s (4126.4 samples/s);\n",
"Finished Epoch[6 of 10]: [Training] loss = 1.294328 * 50000, metric = 46.14% * 50000 12.158s (4112.5 samples/s);\n",
"Finished Epoch[7 of 10]: [Training] loss = 1.231594 * 50000, metric = 43.62% * 50000 12.084s (4137.7 samples/s);\n",
"Finished Epoch[8 of 10]: [Training] loss = 1.179700 * 50000, metric = 41.84% * 50000 12.156s (4113.2 samples/s);\n",
"Finished Epoch[9 of 10]: [Training] loss = 1.136541 * 50000, metric = 39.93% * 50000 12.065s (4144.2 samples/s);\n",
"Finished Epoch[10 of 10]: [Training] loss = 1.096253 * 50000, metric = 38.56% * 50000 12.148s (4115.9 samples/s);\n",
"\n",
"Final Results: Minibatch[1-626]: errs = 34.4% * 10000\n",
"\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXm4FNXR/z9fRBHcUDQgIKgQ5XVBwD0SRYyKGtHoqwbj\nrvmp0WjM5vZGoiYuCcYYjSbGJUZFxC2Ca1DRqLixI7iDC6uIAgLKWr8/qsfpO8xc5sLMvTOX+jzP\nPN3n9OnT1Q23q+vUqToyM4IgCIJgVTRpaAGCIAiC6iAURhAEQVAUoTCCIAiCogiFEQRBEBRFKIwg\nCIKgKEJhBEEQBEURCmMtRdItki6ta1tJ+0n6pLzSfXPdKZJ618e1GgPJv82EUrddDTlelHRSOfoO\nGpamDS1AUFokfQi0Adqa2eep+jHALsDWZvaxmZ1dbJ952q5W8I6kjsAUoKmZrVidPhoLknoCT+LP\nsgnQAlgAKKnbwcym1qVPM3sB2LnUbYMgQ1gYjQ/DX8r9MhWSdgKas5ov+hKSeRmq7BeS1in3NdYE\nM3vJzDYys42BHfHnskmmLldZKKFBhA2ChFAYjZO7gZNT5ZOBu9INJN0p6Ypkfz9Jn0j6uaRZkqZJ\nOiVf22yVLpY0W9JkScenDhwqabSkeZI+ktQ/dd4LyXaupPmS9kzO+bGkSUndm5K6pc7pLmmcpC8k\n3SdpvXw3LOlkSS9J+pOk2UB/Sf0l3Z1q01HSCklNkvJwSVck582X9JSkzQr0P0nSoanyOpI+ldRN\nUjNJd0v6LJHzNUlb5OtnFdRQCMnQzhWSRuDWx1aSTk89q/cknZ5qf4CkKanyJ5IukDQ+keteSevW\ntW1y/GJJM5J2ZyTPscMqb8i5TNKHkmZKukPSRsmx5sl1Ms/t1czzT+5zSnKf70s6djWeZ1BiQmE0\nTl4FNpK0ffJyPA64h9q/7NsAGwFtgTOAv0rapJa2myVtTwFulfTt5NgC4EQz2wQ4DDhLUt/k2L7J\nduPkK/o1SccAlwEnJF/bfYE5qWsdAxwEbIMPqZ1Syz3sCbwPtAZ+n9TlWlW55X64Qt0CaAb8skDf\nA4HjU+U+wGwzG5ucvzHQDn8uZwFf1SJnXTgBv+eNgWnATOCQ5Fn9GLgxsSAz5N7fMcABwLbAbsCJ\ndW0r6fvAOcB+wHZA7zznFuLH+HPbF+iEP58/J8dOxS3ftkn9T4CvE4VyHXBAcp/7AOOLvF5QRkJh\nNF4yVsaBwFvA9FW0XwJcaWbLzexJ/MW/fYG2BvzGzJaa2X+Bx4FjAczsv2Y2Mdl/ExiEv2jSpBXX\n6cAfzGx0cs5kM0s71W8ws1lmNhcYCqStj1ymmdnNZrbCzBav4n4z3GlmHyTtB9fS/31AX0nrJ+V+\nSR3AUqAVsJ05Y8xsQZHXXxV3mNm7yb/LcjN73Mw+AjCz54Fnge/Wcv71ZjbbzL4AHqP251eo7THA\n7YkcXwGX10H+44EBid9sIXAJWcW7FNic7HMbbWaLkmMrgJ0lNUv+/d+uwzWDMhEKo/FyD/6HeQrw\nryLaz8lxRC8CNizQ9gsz+zpV/gj/SkTSnpKeS4Zr5gJn4i+FQmwFfFDL8VlFygSwOrO3ZhbTv5l9\nAEwCDpfUHLeEBiaH7waeBgZJmirpGpXOh1LjniR9Pxm6mSPpC/yDoLbnW5fnV6ht2xw5PqF4P1Rb\n/P9Hho+AZsmQ3T+BZ4DByVDXVZKamNmXuEI+F5gpaUjKgg0akFAYjRQz+xh3fh8CPFzi7jdNXpoZ\nOpC1YO4F/g20M7OWwN/JvlzyDWN8gg9VlILc/hfis48ybLmG/Q/ClfARwEQzmwxgZsvM7Eoz2xH4\nDnA4UKpppd/cU2LdPIAPt21hZpsCwyj/JIIZQPtUuQPFD0lNBzqmyh2BxYkls9TMrjCzHYCewFHA\njwDM7GkzOxAf/vwA/38UNDChMBo3pwG9k2GE
UiLgcknrSvou7qsYnBzbELdAlkrag5rj/rPxoYa0\ngrgN+KWkHgCSOknaqkRyjgX2lbRV4o+5aA37G4T7U84ma10gqZeknRJ/0QJ8qKWu04aLeek3A9YF\nPgMs8S0cUMfrrA6DgdMlbSepBfB/dTj3PuDnyYSDjYDfkTw7SftL2lGSSD03SW0SS6o5sAxX/MtL\neUPB6hEKo/HxzZefmU3J+AZyj9WlnzzMAL7Avx7vBs40s/eSYz8BrpQ0D3+x3J+S5yv86/hlSZ9L\n2sPMHkzqBkqaDzyCO0DrKu/KN2D2THL98cAbuA+kRpM69jcTeAXYi9R94V/BDwLzgInAcPy5ZIIe\nby6m+1XVmdk84ALcgpuDf5Hn3tOq+qxzWzN7DLgF+C/wDvBScqiQnyjd1z/wZ/UiPiFhHvCz5Fhb\n3PqdB0wA/oMrk3WAX+H/v2YDe+NO96CBUTkXUJLUHh8/b41/cf3DzP5SoO3uwAjgODMr9RBKEAQl\nIpmVNcrMmjW0LEH9Um4LYxnw82Rsd2/gHEldchslpvw1uOMwCIIKQ9KRyRDkZvjf6r8bWqag/imr\nwjCzmck8dZJphm/hc9Vz+Slu0n9aTnmCIFhtzsF9J+/iM6jObVhxgoag3nJJSdoan9f9Wk59W+BI\nM9s/cZIGQVBhJDOWgrWcelEYkjbELYjz8wQ0/Rm4MN28QB8NnQcpCIKgKjGzkky9LvssKUlNcWVx\nt5k9mqfJbnjA0xTgf/GUFH3ztMPMKv7Xv3//Bpch5Aw5q1XGkLP0v1JSHxbGHcAkM7sh30Ez2zaz\nL+lOYKiZDakHuYIgCII6UFaFIWkfPHJzgnw9BsNzyXQEzMxuzTklhp2CIAgqlLIqDDN7GQ/CKbb9\naWUUp17o1atXQ4tQFCFnaakGOatBRgg5K5myBu6VEklWLbIGQRBUCpKwanF6B0EQBI2DqlrT+x//\ngI8+gr33ht69oXnzVZ8TBEEQlIaqsjB69YIFC2DAAGjfHs47DyZPbmipgiAI1g6q1ofxySfw17/C\n7bfDmDGuQIIgCIKalNKHUbUKI8PixdAscmYGQRDkpWqc3pLaJ8t1TpQ0QdJ5edocL2lc8ntJ0s51\nuUYoiyAIgvqh3OthtAHamNnYJJ/UKOAISy3oLmkv4C0zmyepD/BbM9srT18xrTYIgqCOVI2FYUWk\nNzezV81XEgN4Nfd4XVm6FJ5/fk16CIIgCPJRb7OkCqU3z+EM4Mk1uc6ECbD//jB//pr0EgRBEORS\nCenNM232B04Fehbq57e//e03+7169cobmt+jB5xwAlx/PfTvv2ZyB0EQVBvPP/88z5dpmKXss6SS\n9OaPAU8WylgrqSvwENDHzD4o0KZoH8aUKbD77jBxIrRuvZqCB0EQNAKqxoeRUGt6c0kdcGVxYiFl\nUVe22QaOOcZjNIIgCILSUAnpzX8DbAbcLEnAUjNb46Va990XHnhgTXsJgiAIMlR94F4hZs+GqVOh\ne/cyChUEQVDhRKR3EARBUBTV5sMIgiAIGgGhMIIgCIKiaPQKY8kSuOoqiNGsIAiCNaPRK4z58+HS\nS/0XBEEQrD4Nnq02afcXSe9JGiupWyll2HxzuPJKuPrqUvYaBEGw9lEJ2WoPAc41s8Mk7QncUOps\ntQsXwoYb+lTbzTdfvXsJgiCoRqpmllQx2WqBI4B/JW1eAzaRVNKEHhtsAC1bwuWXe3ny5PBpBEEQ\n1JV6ST4ItWarbQd8kipPS+pmlfL6c+ZAkybQtCksXw5vvAG77VbKKwRBEDRuKiZbbTEUk622EE0S\nW2r5ct9uu+3qShEEQVC5NOpstZL+Bgw3s/uT8tvAfmY2K6ddySK9zUAlGdELgiCobKrGh5FQa7Za\nYAhwEnyzXOvcXGVRajLKwix8GUEQBMVS7llS+wD/BSbgmWrzZatF0k1AH2AhcKqZjc7TV0lzSZll\nh6lCaQRB
0FiJ5IMlYPJk6NTJ9597Dm65BQYPLln3QRAEFUEojBKwaBH84x9wzjmw7rpet2JF+DaC\nIGhchMIoMbvtBqNGwWu
"text/plain": [
"<matplotlib.figure.Figure at 0x2ec09b867b8>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXecFEX2wL8PBAEJZkwkUcwJFT0Dcnh6eCqcZ0D0zPDz\nDOeZ03mip57hzGLCiHoGjIAeigGUMyKgoCIiWUUEyVFg3++PV830zM7M9i6zs7Pyvp9Pf7q6urrq\nde9sv656Ve+JquI4juM4FVGnpgVwHMdxageuMBzHcZxEuMJwHMdxEuEKw3Ecx0mEKwzHcRwnEa4w\nHMdxnES4wigxROR+Efl7ZcuKyEEiMr16pVvd7mQR6VyMtmobIlImIlsX+9o1Jd52ZX6DWepZKCKt\nCymbUzq4wigSIjJFRJaJyIYZ+aPDP2tLAFU9S1VvSFJnlrJVWlQjIq2CDP57YI2fx5osbMp5rYgM\nE5GlIrJARH4SkRdFpPkatJWz7aS/QREZKiKnp1Wi2kRVpxRQrqitKSKyJNz/wrC/u9DtOPnxF0Tx\nUGAy0CPKEJGdgYas2UumEEiQQaq9IZG61d1Ggajq32RNnmG+axU4W1WbAu2A9YE7slZSNUVX7X/7\nNUSBw1W1aVBKTVX1vGwFs/3GKvu7q0W/06LiCqO4PAmcEjs+BegXLyAij4nIP0P6IBGZLiIXishM\nEfleRE7NVjaVJVeIyCwRmSQiJ8RO/EFERonIfBGZKiK9Y9e9G/bzwpfbPuGaXiLyVcj7QkR2j12z\nh4h8LiJzReQZEamf7YZF5BQR+Z+I3C4is4DeItJbRJ6MlUn7og9frv8M1y0Qkdcze2axa78SkT/E\njuuGL/DdRWRdEXlSRGYHOT8WkU2y1ZMUEdlbRD4I9X0vIveIyDoZxQ4XkYlBjlsyrj89yPyziAyO\nepZJmwdQ1XnAi8DOoc7HROQ+EXlNRBYCnUSkvojcGv7WM8L5dWNyXCIiP4jIdyJyGjEFmfm7EpFu\noSc8X0QmiMihInI9cCDQJ/61L+lDW01F5InwHCZLbJgr/C6Gi8i/RWROeF5dktx/ucz039hs7DeW\nLU9E5Cqx3sqPIvK4iDQNdUS/wdNFZCrwdtI/ytqEK4zi8hHQRES2Cy/H7sBT5P+62wxoAmwB9ATu\nFZFmecpuGMqeCvQVkW3DuUXASaraDDgc+IuIdA3nOoZ90/Dl9rGIHAtcDfw5fNV2BX6OtXUscCjQ\nBtgttJeLfYBvgeZANNSR+QWfedwDU6ibAOsCF+eo+2nghNhxF2CWqn4Wrm8KbIk9l78AS/PImYRV\nwPmhvt8AnYGzM8r8EWgftm4Shm1EpBtweTi/CTAceKayAojIxsDRwKhYdg/gOlVtArwP3AxsA+wa\n9ltif0/Ci/lC4GBgW+B3edrqgH3UXBR+Ox2BKap6VZD/3Iyv/fjfsQ/2220NdAJODsopogMwDtgI\n+DfwSGWeQwbRb2xTUr+xzLzTgJOBg4Ctg2x9MurpCGwP/H4NZPnV4gqj+ES9jEOwf5YfKij/C/Yi\nWKWqg7EX/3Y5yirwD1VdoarvAa8BxwGo6nuq+mVIfwE8i/3jxIkrrjOAW1R1VLhmkqrGjep3qerM\n8LU7CIj3PjL5XlXvU9UyVV1ewf1GPKaqE0P5/nnqfwboKiINwnEPUi/hFdjLqJ0ao1V1UcL2s6Kq\no1T1k1DfNKAv5Z/jTao6X1W/A+4kNQx5JnCjqn6jqmXATcDuItIiYfP3iMgcYDT2u7kodm6Aqn4U\nZFwO9AIuCHIsDm1FchyLPd9xqroUuCZPm6cDj6jqO6HuGar6TZ7yAquHxboDl6vqElWdCtwGnBQr\nO1VVH1VzaNcP2ExENs1T9yuhNzI37M+Incv2G8vMOwG4XVWnquoS4ArgeEkN4SnQW1WXVuJ3ulaR\n2ZV2qp+ngPewL/MnEpT/ObxcIpYAjXOUnauq
y2LHU7HeBmLDTDdiwxj1w/Z8nnZbABPznJ+ZIdPm\necpWZfbWjxn1Z71nVZ0oIl8BR4rIq1hP6Opw+klgK+DZ0Ct7Cvi7qq6qgjwAhB7b7cBemP1pHWBk\nRrHvYunVfwOgFXCXiNwWVYe9pLYk2TP6q6o+muPc6uvDsFsjYKTI6m+AOqQ+CLYAPs2QMVcvtwX2\n4VFZNsaezbSMdraMHa/+G6vqUjFhGwM/5aizm6oOzXEu2/PLzNsiyBCXZx2s5xvxHU5OvIdRZMJX\n6WTgMOClAle/gYg0jB23JNWD+Q/wCrClqq4PPEjqJZHNwDsdaFsguTLrX4y90CLyKZskPIt9PXYD\nvlTVSQCqulJVr1PVnYD9gCOxIYk14X6sZ9g2PMe/U/5lG+8xtCL1N5gOnKmqG4ZtA1VtHPUM1pD4\nM56NKdmdYm2tH4aUAGZkkTGXkT/f7yDfxIDZWA+vVUY73+e5piIqmhRQUd4PWeRZQfrHT01PQClp\nXGHUDKcDncNwQCER4FoRqSciB2K2iv7hXGOsB7IijEvHx/1nAWWkvxgeBi4WkfYAItK2EkMnFfEZ\n0FFEWoQv/8vXsL5nMXvKWZhNAwAR6SQiO4chh0XYy6EsexXlEKBBMJxHm2Dj3gtUdYmIbB/azOQS\nEVk/PK/zgnwADwBXisiOQb5mInJM5W83P2GI5yHgztDbQES2FJFDQ5H+wKkisoOINCLVI8vGI8Bp\nIvLbYDTeQkSiIdGZmC0gmwxloZ0bRKSxiLQCLsB6fTXFM8AFItJaRBpjdo1nYz34Up8pVuO4wige\n8XnukyPbQOa5ytSThRnAXOxL6knsa3ZCOHc2cJ2IzAeuAp6LybMU++d5P4wNd1DVF0Le0yKyAHgZ\nM/RWVt7yN6D6Vmh/DDACs4GkFalkfT8CHwL7ErsvbBLAC8B84EtgKOGFJbY47b581QILsS/1pWH/\nW8xucGJ4Jg+SUgbx6wZgw1Sjwr09GuR8BbMlPCsi87D775JxbT55KnPuMszg+1Foawg2HRdVfR2z\nrbwDfEOeGUGqOgIzFt+JPcdhWM8V4C7gWLEZX3dmkeU87LlNwoZhn1LVxyp5H3EGhRlZ0fZiBeUz\neRT7+7+HDbcuCTImbX+tR6o7gFKYkXEnppweUdWbM86vj/0h22L/mKer6lfVKpTjOI5Taaq1hxGG\nAvpgU9R2AnqEbnycK4HRqrobNnvIV286juOUINU9JNUBmBCmsa3Auu/dMsrsiHWNUdXxQGtZw8VV\njuM4TuHJqzDEVs3mmsaWhMzpgt+RPq0O4HPgT6G9Dtj46FZr0KbjOI5TDeRdh6Gqq8Jy+WaqOr+a\nZLgJm5s+ChiLLUoqN09eRNwg5TiOUwVUtSAzwJIMSS0CxorIIyJyd7QlrP97UjMqwHoOafOwVXWh\nqp6uqu1V9RRsGf+kbJVpt27ovHmoasluvXv3rnEZXE6Xs7bK6HIWfiskSRTGS8A/sKloI2NbEkYA\n2wTHXvWB44GB8QJhLnq9kO4FvKu53DcMGAD335+wacdxHKeQVOgaRFX7hZd9u5A1Xs2AXSFqQ1rn\nYnPAo2m140TkTDutfYEdgH4iUobNlT8jZ4UnnwxLliRp2nEcxykwFSoMEemEOQabgq2EbCEip6g5\nt6sQtUVC22XkPRhLf5R5PicdO8Lw4YmK1hSdOnWqaRES4XIWltogZ22QEVzOUqbChXsiMhI4QW3K\nKyLSDnhGVfcsgnxxOVSHDoV//KPklYbjOE6pICJoEY3e9SJlAaDm2rheIRqvNG3bgri7F8dxnJog\nSQ/jUcxh21Mh60SgrqqenvuqwiMiWmiLv+M4zq+dQvYwkiiMdYFzgANC1nDgPi1ygBFXGI7jOJWn\naApDLBD6E6p6YiEaWxNcYTiO41Seotkw1CKTRWsoSgNV+M6DYjmO4xSbJCFaJ2FxEgZikdIAUNXb\nq02qfJSV
QYsWMGoUtG9vx24IdxzHqXaSzJKaCLwayjaJbTVD3bq2339/2w8ZUmOiOI7jrE0ksWHc\nrKoXV7mBigMoNcVmYLU
"text/plain": [
"<matplotlib.figure.Figure at 0x2ec10332128>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pred_basic_model = train_and_evaluate(reader_train, \n",
" reader_test, \n",
" max_epochs=10, \n",
" model_func=create_basic_model_terse)"
]
},
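{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training log also prints `Momentum per sample: 0.9983550962823424`. That value is consistent with a momentum of 0.9 per minibatch of 64 samples. A per-sample momentum applied 64 times compounds to its 64th power, so the per-sample value is the 64th root of the per-minibatch one. The numbers 0.9 and 64 are inferred from the log, not read from `train_and_evaluate`, so treat them as assumptions. The snippet only checks the arithmetic.\n",
"\n",
"```python\n",
"# Assumed settings, inferred from the logged value rather than from train_and_evaluate:\n",
"momentum_per_minibatch = 0.9\n",
"minibatch_size = 64\n",
"\n",
"# Per-sample momentum whose minibatch_size-th power equals the per-minibatch momentum\n",
"momentum_per_sample = momentum_per_minibatch ** (1.0 / minibatch_size)\n",
"print(momentum_per_sample)  # about 0.998355096, matching the log\n",
"print(momentum_per_sample ** minibatch_size)  # recovers about 0.9\n",
"```"
]
},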
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have a trained model, let us classify the following image of a truck. We use PIL to read the image."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src=\"https://cntk.ai/jup/201/00014.png\" width=\"64\" height=\"64\"/>"
],
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Figure 6\n",
"Image(url=\"https://cntk.ai/jup/201/00014.png\", width=64, height=64)"
]
},
{
"cell_type": "code",
2017-05-15 23:57:56 +03:00
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Download a sample image\n",
"# (this is 00014.png from the test dataset)\n",
"# Any 32x32 image with 3 color channels can be evaluated\n",
"\n",
"url = \"https://cntk.ai/jup/201/00014.png\"\n",
"myimg = np.array(PIL.Image.open(urlopen(url)), dtype=np.float32)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"During training, the mean was subtracted from the input images. Here we subtract an approximate value of that mean, a single scalar of 133.0, from the image before evaluating it."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def eval(pred_op, image_data):\n",
"    label_lookup = [\"airplane\", \"automobile\", \"bird\", \"cat\", \"deer\", \"dog\", \"frog\", \"horse\", \"ship\", \"truck\"]\n",
"    image_mean = 133.0\n",
"    # Subtract the mean into a new array so the caller's image is not modified in place\n",
"    image_data = image_data - image_mean\n",
"    # Reorder from HWC (as read by PIL) to CHW (channels first)\n",
"    image_data = np.ascontiguousarray(np.transpose(image_data, (2, 0, 1)))\n",
"\n",
"    result = np.squeeze(pred_op.eval({pred_op.arguments[0]:[image_data]}))\n",
"\n",
"    # Return top 3 results:\n",
"    top_count = 3\n",
"    result_indices = (-np.array(result)).argsort()[:top_count]\n",
"\n",
"    print(\"Top 3 predictions:\")\n",
"    for i in range(top_count):\n",
"        print(\"\\tLabel: {:10s}, confidence: {:.2f}%\".format(label_lookup[result_indices[i]], result[result_indices[i]] * 100))"
]
},
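{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `eval` function above does three things around the model call:\n",
"\n",
"1. It subtracts a scalar mean.\n",
"2. It reorders the image from height x width x channel (the layout PIL and NumPy produce) to channel x height x width.\n",
"3. It picks the three highest scores with `argsort`.\n",
"\n",
"The NumPy-only sketch below reproduces those steps, so you can check the shapes and index logic without a trained model. It runs on a random stand-in image and a made-up score vector; the data, the scores and the helper names `preprocess` and `top_k` are purely illustrative.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"label_lookup = ['airplane', 'automobile', 'bird', 'cat', 'deer',\n",
"                'dog', 'frog', 'horse', 'ship', 'truck']\n",
"\n",
"def preprocess(image_hwc, image_mean=133.0):\n",
"    # Subtract the mean without modifying the caller's array,\n",
"    # then convert HWC (32, 32, 3) to CHW (3, 32, 32).\n",
"    centered = image_hwc.astype(np.float32) - image_mean\n",
"    return np.ascontiguousarray(np.transpose(centered, (2, 0, 1)))\n",
"\n",
"def top_k(scores, k=3):\n",
"    # Indices of the k largest scores, highest first.\n",
"    return np.argsort(-np.asarray(scores))[:k]\n",
"\n",
"rng = np.random.default_rng(0)\n",
"fake_image = rng.integers(0, 256, size=(32, 32, 3)).astype(np.float32)\n",
"chw = preprocess(fake_image)\n",
"print(chw.shape)  # (3, 32, 32)\n",
"\n",
"fake_scores = np.array([0.01, 0.02, 0.01, 0.05, 0.01, 0.01, 0.01, 0.01, 0.10, 0.77])\n",
"for i in top_k(fake_scores):\n",
"    print('{:10s} {:.2f}%'.format(label_lookup[i], fake_scores[i] * 100))\n",
"```"
]
},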
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Top 3 predictions:\n",
"\tLabel: truck , confidence: 96.59%\n",
"\tLabel: ship , confidence: 2.31%\n",
"\tLabel: cat , confidence: 0.43%\n"
]
}
],
"source": [
"# Run the evaluation on the downloaded image\n",
"eval(pred_basic_model, myimg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Add a dropout layer, with a drop rate of 0.25, before the last dense layer:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_basic_model_with_dropout(input, out_dims):\n",
"\n",
" with C.layers.default_options(activation=C.relu, init=C.glorot_uniform()):\n",
" model = C.layers.Sequential([\n",
" C.layers.For(range(3), lambda i: [\n",
" C.layers.Convolution((5,5), [32,32,64][i], pad=True),\n",
" C.layers.MaxPooling((3,3), strides=(2,2))\n",
" ]),\n",
" C.layers.Dense(64),\n",
" C.layers.Dropout(0.25),\n",
" C.layers.Dense(out_dims, activation=None)\n",
" ])\n",
"\n",
" return model(input)"
]
},
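{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dropout has no learnable parameters, which is why this model trains the same 116906 parameters as the one without dropout. During training, it zeroes each activation with probability equal to the drop rate. At evaluation time, it passes the input through unchanged.\n",
"\n",
"The NumPy sketch below shows one common formulation, 'inverted' dropout. It rescales the surviving activations by 1 / (1 - rate) during training, so their expected value is unchanged. We believe CNTK's `Dropout` behaves this way, but treat the sketch as an illustration rather than CNTK's implementation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def dropout(x, rate, training, rng):\n",
"    # Inverted dropout: drop with probability `rate`, scale survivors by 1 / (1 - rate).\n",
"    if not training or rate == 0.0:\n",
"        return x  # evaluation: identity\n",
"    keep = rng.random(x.shape) >= rate  # True with probability 1 - rate\n",
"    return np.where(keep, x / (1.0 - rate), 0.0)\n",
"\n",
"rng = np.random.default_rng(42)\n",
"activations = np.ones((1000, 64))\n",
"train_out = dropout(activations, 0.25, training=True, rng=rng)\n",
"eval_out = dropout(activations, 0.25, training=False, rng=rng)\n",
"\n",
"print((train_out == 0).mean())  # close to 0.25\n",
"print(train_out.mean())  # close to 1.0\n",
"print(np.array_equal(eval_out, activations))  # True\n",
"```"
]
},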
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training 116906 parameters in 10 parameter tensors.\n",
"\n",
"Learning rate per minibatch: 0.01\n",
"Momentum per sample: 0.9983550962823424\n",
"Finished Epoch[1 of 5]: [Training] loss = 2.107245 * 50000, metric = 79.08% * 50000 12.977s (3853.0 samples/s);\n",
"Finished Epoch[2 of 5]: [Training] loss = 1.795581 * 50000, metric = 67.10% * 50000 12.244s (4083.6 samples/s);\n",
"Finished Epoch[3 of 5]: [Training] loss = 1.657041 * 50000, metric = 61.52% * 50000 12.265s (4076.6 samples/s);\n",
"Finished Epoch[4 of 5]: [Training] loss = 1.567592 * 50000, metric = 57.72% * 50000 12.251s (4081.3 samples/s);\n",
"Finished Epoch[5 of 5]: [Training] loss = 1.500142 * 50000, metric = 54.97% * 50000 12.228s (4089.0 samples/s);\n",
"\n",
"Final Results: Minibatch[1-626]: errs = 47.1% * 10000\n",
"\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXeYVOX1xz9fQAgqoNhQELA3VEBFY2NFjb38NJhgiT1q\nbIkldkFNsEWNNREjqChixYJdBDs2QBBRsVBEUUSqgFLO749zxxmW2d3ZZWZndjmf55nn3vve9773\n3Hd37pnznnPeV2ZGEARBEFRFg2ILEARBENQNQmEEQRAEOREKIwiCIMiJUBhBEARBToTCCIIgCHIi\nFEYQBEGQE6EwVlAk/UfSJdWtK6mrpMmFle7X+34lqVtt3Ks+kPxtxuS7bg3keF3SnwrRdlBcGhVb\ngCC/SJoAtALWM7MfM8pHAtsC7c1skpmdlmubWerWKHlHUjvgK6CRmS2pSRv1BUm7As/hfdkAWBmY\nCygp29LMvq5Om2b2KrB1vusGQYqwMOofhr+Ue6QKJHUAmlLDF30eSb0MVfAbSQ0LfY/lwczeMLNm\nZtYc2ArvlxapsvLKQglFETYIEkJh1E/6A8dmHB8L3JtZQVI/SVcm+10lTZZ0jqTvJE2RdFy2uuki\nXSRpmqQvJR2ZcWJ/SSMkzZI0UVLPjOteTbYzJc2WtGNyzcmSPk7KPpLUMeOaTpI+lDRD0oOSGmd7\nYEnHSnpD0o2SpgE9JfWU1D+jTjtJSyQ1SI6HSroyuW62pOcltayg/Y8l7Z9x3FDS95I6Smoiqb+k\nHxI535G0VrZ2qmAphZAM7Vwp6S3c+lhf0okZfTVe0okZ9feU9FXG8WRJf5M0OpHrAUkrVbducv4i\nSd8m9U5K+rFtlQ/kXC5pgqSpkvpKapaca5rcJ9Vvw1P9nzznV8lzfi7piBr0Z5BnQmHUT4YDzSRt\nlrwc/wDcT+W/7FsBzYD1gJOA2yW1qKRuy6TucUAfSZsk5+YCx5hZC+AA4FRJByfndk+2zZNf0e9I\n6g5cDhyd/No+GJieca/uwO+ADfAhteMqeYYdgc+BdYB/JmXlraryxz1whboW0AQ4r4K2BwBHZhzv\nC0wzs1HJ9c2B1ni/nArMr0TO6nA0/szNgSnAVGC/pK9OBm5NLMgU5Z+vO7AnsCGwPXBMdetKOhA4\nHegKbAp0y3JtRZyM99vuwEZ4//w7OXc8bvmul5T/BViQKJQbgD2T59wFGJ3j/YICEgqj/pKyMvYG\nxgHfVFH/F+AqM1tsZs/hL/7NKqhrwGVmttDMXgOeAY4AMLPXzGxssv8RMBB/0WSSqbhOBK4zsxHJ\nNV+aWaZT/WYz+87MZgJPA5nWR3mmmNkdZrbEzH6u4nlT9DOzL5L6D1fS/oPAwZJ+kxz3SMoAFgJr\nAJuaM9LM5uZ4/6roa2afJX+XxWb2jJlNBDCzYcAQYLdKrr/JzKaZ2QxgMJX3X0V1uwN3J3LMB66o\nhvxHAv9K/GY/AReTVrwLgTVJ99sIM5uXnFsCbC2pSfL3/6Qa9wwKRCiM+sv9+BfzOOC+HOpPL+eI\nngesWkHdGWa2ION4Iv4rEUk7SnolGa6ZCZyCvxQqYn3gi0rOf5ejTAA1id6amkv7ZvYF8DFwkKSm\nuCU0IDndH3gBGCjpa0nXKH8+lKWeSdKBydDNdEkz8B8ElfVvdfqvorrrlZNjMrn7odbD/z9STASa\nJEN29wAvAw8nQ129JTUwszm4Qj4DmCrpqQwLNigioTDqKWY2CXd+7wc8nufmV09eminakrZgHgCe\nAFqb2WrAnaRfLtmGMSbjQxX5oHz7P+HRRynWXc72B+JK+BBgrJl9CWBmi8zsKjPbCtgZOAjIV1jp\nr8+UWDeP4MNta5nZ6sBLFD6I4FugTcZxW3IfkvoGaJdx3A74ObFkFprZlWa2JbArcBhwFICZvWBm\ne+PDn1/g/0dBkQmFUb85AeiWDCPkEwFXSFpJ
0m64r+Lh5NyquAWyUFIXlh73n4YPNWQqiP8B50nq\nDCBpI0nr50nOUcDuktZP/DEXLmd7A3F/ymmkrQsklUnqkPiL5uJDLdUNG87lpd8EWAn4AbDEt7Bn\nNe9TEx4GTpS0qaSVgUurce2DwDlJwEEz4B8kfSdpD0lbSRIZ/SapVWJJNQUW4Yp/cT4fKKgZoTDq\nH7/+8jOzr1K+gfLnqtNOFr4FZuC/HvsDp5jZ+OTcX4CrJM3CXywPZcgzH/91/KakHyV1MbNHk7IB\nkmYDg3AHaHXlXfYBzF5O7j8aeA/3gSxVpZrtTQXeBnYi47nwX8GPArOAscBQvF9SSY935NJ8VWVm\nNgv4G27BTcd/kZd/pqrarHZdMxsM/Ad4DfgUeCM5VZGfKLOtu/C+eh0PSJgF/DU5tx5u/c4CxgAv\n4sqkIXA+/v81Dfgt7nQPioxiAaUgCKpDEpX1gZk1KbYsQe0SFkYQBFUi6dBkCLIlcA1u5QQrGKEw\ngiDIhdNx38lneATVGcUVJygGMSQVBEEQ5ESdmXxQUmi2IAiCGmBmeQm9rlNDUmbG3LnG9dcbnToZ\nW21lnHqqce+9xtixhlnxPz179iy6DCFnyFlXZQw58//JJ3VKYQCssgqcdx588AFccQWsuSY8/jj8\n9FOxJQuCIKjf1JkhqfJIcPjh/qkIM68XBEEQLD91zsLIlQULoEEDePZZmJ/vPOdKKCsrq72bLQch\nZ36pC3LWBRkh5Cxl6kyUlCSrjqxmcNppcOed0K0bvPACNKqz9lQQBEHNkITlyeldbxVGijlzoGtX\naNIEXnkFmjat+pogCIL6Qj4VRkGHpCS1Saa6HitpjKSzstQ5Ur6i2ofylc/yus5ws2buIN9sM/jw\nw3y2HARBsGJRUAtDUiuglZmNkrQq8AFwiGUshiJpJ2Ccmc2StC/Qy8x2ytJWjSyMIAiCFZl8WhgF\nHdU3n91zarI/V9I4fBnLTzLqDM+4ZHhyPgiCICgxai1KSlJ7fMnHdyqpdhLwXG3I849/wN/+VrsR\nVEEQBHWZWlEYyXDUo8DZVsFax5L2wBeFv6DQ8ixZAiNHwr//DW3bFvpuQRAE9YOCB5pKaoQri/5m\n9mQFdbYB+gD7mi9An5VevXr9ul9WVlbjOOgGDeCxx2DuXHeKSx6GGwRBUNcZNmwYw4YNK0jbBQ+r\nlXQf8IOZnVPB+bbAEOCYcv6M8vUK4vR+/XWYMAGOOSbvTQdBEBSdOpOHIWkXfFnHMfiyjQZcjC8E\nb2bWR9Jd+FKTE/F1jReaWZcsbdVKlNRll8GsWXDVVdCiRcFvFwRBUFDqjMLIJ7WlMNq3h4kTfX/m\nzFAaQRDUbepM4l5dZMIEGDcOzjjDs8SDIAgCJyyMIAiCekxYGLXML78UW4IgCILiEwqjCj79FDp1\nKrYUQRAExScURhVstBF8/DFst12s6hcEwYpNKIwqaNQI7roLRoyAVVdNl4c7JQiCFY1QGDlw0knw\nzTdw5JF+PHWqZ4s/9FBx5QqCIKhNQmHkyLrrwgMP+P7VV/v29tuLJ08QBEFtEwqjBpx/Pgwd6tOK\nXHddsaUJgiCoHWKV6xrQpo1/XnstfBlBEKw4ROJeEARBPSYS90qMvn1hlVViKpEgCOo3oTDywEcf\nwbx50Lx5DFEFQVB/CYWRB2680Rdk6tzZF2MKgiCojxRUYUhqI+kVSWMljZF0VgX1bpE0XtIoSR0L\nKVOhOOww+OADWLTIw23D0giCoL5RaAtjEXCOmW0F/BY4XdLmmRUk7QdsZGabAKcA/y2wTAXl3Xd9\navQCrZAYBEFQNAqqMMxsqpmNSvbnAuOA1uWqHQLcl9R5B2ghaZ1CylVIdt4ZdtwRunWLWW6DIKhf\n1JoPQ1J7oCPwTrlTrYHJGcdTWFap1Cnuu8+3Awemt+PGFU+eIAiCfFAriXuSVgUeBc5OLI0a0atX\nr1/3y8rK
KCsrW27ZCsGmm/rKfW3auC+jRw8vnzcPmjZ1x/gXX8CGGxZVzCAI6iHDhg1jWIHGxAue\nuCepETAYeM7Mbs5y/r/
"text/plain": [
"<matplotlib.figure.Figure at 0x2ec1f9b5630>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAACfCAYAAAA8qTSuAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXecFdX1wL9HQAQpdhQLKIpEsKGisSIRxVijwYCx19h7\nIfEnscRuLDEaVFQswRgbNhSNHRFQQDqIFBWQJiCdBc7vj3OHN+/t27ez+97uvoXz/XzmMzN3bjkz\n+3bO3HvPPUdUFcdxHMfJhw1qWgDHcRyn9uPKxHEcx8kbVyaO4zhO3rgycRzHcfLGlYnjOI6TN65M\nHMdxnLxxZVJLEJHHROQvFc0rIoeJyA9VK93adqeISKfqaKu2ISJrRGSn6i6bL/G2K/IbzFLPIhFp\nWUjZnOLClUkNIyJTRWS5iGyWkT48/CPvAKCqF6nq35LUmSVvpRYTiUiLIIP/Tsj7eeSzoKvMsiLy\nsYgsE5FfRGS2iLwiIs3yaKvMtpP+BkXkIxE5J60S1caqOrWAckVtTRWRpeH+F4X9w4Vuxykff0nU\nPApMAbpHCSLSDmhAfi+gQiBBBqnyhkTqVHUbBaKyf5N8nmGusgpcrKpNgNbAJsADWSupnBKs8r99\nnihwjKo2CQqriapeni1jtt9YRX93teh3Wu24MikOngPOjJ2fCfSJZxCRp0Xk1nB8mIj8ICJXi8gs\nEZkuImdly5tKkh4iMkdEJovIqbELvxWRYSKyUESmiUjPWLlPwn5B+OLbP5Q5X0TGhrTRIrJXrMze\nIvKNiMwXkb4ismG2GxaRM0XkcxH5u4jMAXqKSE8ReS6WJ60nEL54bw3lfhGRdzN7dLGyY0Xkt7Hz\nOuHLfS8RqS8iz4nI3CDnYBHZMls9SRGR/UTki1DfdBH5h4jUzch2jIh8F+S4J6P8OUHmeSLSP+qR\nJm0eQFUXAK8A7UKdT4vIoyLytogsAjqKyIYicl/4W88M1+vH5LhORGaIyI8icjYx5Zn5uxKRE0IP\neqGIfCsiR4rI7cAhwCPxXoKkD5c1EZFnw3OYIrGhs/C7+ExE7hWRn8Pz6pLk/kslpv/G5mK/sWxp\nIiI3ifVyfhKRZ0SkSagj+g2eIyLTgP8l/aOsb7gyKQ6+BBqLyK7hxfkH4HlyfxVuDTQGmgPnAf8U\nkaY58m4W8p4FPC4iu4Rri4HTVbUpcAzwJxE5Plw7NOybhC++wSLSFbgZOC18DR8PzIu11RU4EtgR\n2DO0Vxb7A5OAZkA0fJL55Z953h1TtlsC9YFry6j738CpsfMuwBxVHRHKNwG2xZ7Ln4BlOeRMwmrg\nylDfr4FOwMUZeU4E2oftBAlDQSJyAnBjuL4l8BnQt6ICiMgWwMnAsFhyd+A2VW0MDATuBnYG9gj7\nbbG/J+GlfTXwG2AX4IgcbXXAPniuCb+dQ4GpqnpTkP/SjF5C/O/4CPbbbQl0BM4IiiuiAzAO2By4\nF+hdkeeQQfQb24rUbywz7WzgDOAwYKcg2yMZ9RwKtAGOykOWdRpXJsVD1DvpjP0jzSgn/0rsJbFa\nVftjSmHXMvIq8H+qWqKqnwJvA6cAqOqnqjomHI8GXsT+qeLEldq5wD2qOiyUmayq8Qn+h1R1VvhK\nfhOI91oyma6qj6rqGlVdUc79Rjytqt+F/C/lqL8vcLyIbBTOu5N6QZdgL6rWagxX1cUJ28+Kqg5T\n1SGhvu+Bxyn9HO9S1YWq+iPwIKmhzQuBO1V1oqquAe4C9hKR7RM2/w8R+RkYjv1urold66eqXwYZ\nVwDnA1cFOZaEtiI5umLPd5yqLgP+mqPNc4DeqvphqHumqk7MkV9g7VDbH4AbVXWpqk4D7gdOj+Wd\npqpPqTkO7ANsLSJb5aj79dCLmR/258auZfuNZaadCvxdVaep6lKgB9BNUsOCCvRU1WUV+J2ud2R2\nw52a43ngU+yL/tkE+eeFF0/EUqBRGXnnq+ry
2Pk0rJeC2NDVndjQyIZh+2+OdrcHvstxfVaGTNvk\nyFsZK7OfMurPes+q+p2IjAWOE5G3sB7UzeHyc8B2wIuhN/c88BdVXV0JeQAIPb2/A/ti8111ga8z\nsv0YO177NwBaAA+JyP1RddgLbFuSPaPLVPWpMq6tLR+G8hoCX4us/T7YgNTHQnPgqwwZy+odb499\nlFSULbBn831GO9vGztf+jVV1mZiwjYDZZdR5gqp+VMa1bM8vM615kCEuT12sxxzxI05OvGdSJISv\n2SnA0cCrBa5+UxFpEDvfgVTP5wXgdWBbVd0E6EXqBZJtsvkHoFWB5Mqsfwn2sovIpYiS8CL21XkC\nMEZVJwOo6ipVvU1V2wIHAsdhwxz58BjWo2wVnuNfKP0ijvc0WpD6G/wAXKiqm4VtU1VtFPUo8iT+\njOdiCrhtrK1NwjAVwMwsMpZlcJDrd5DLSGEu1jNskdHO9BxlyqM8A4Xy0mZkkaeE9A+jmjaGKXpc\nmRQX5wCdwhBDIRHgFhGpJyKHYHMjL4VrjbCeS0kYB4/PM8wB1pD+0ngSuFZE2gOISKsKDMeUxwjg\nUBHZPvQYbsyzvhex+ZuLsDkUAESko4i0C8MYi7EXx5rsVZRCgI3CJH60CTbO/ouqLhWRNqHNTK4T\nkU3C87o8yAfwL+DPIrJbkK+piPy+4rebmzBs9ATwYOilICLbisiRIctLwFki8isRaUiqJ5eN3sDZ\nInJ4mMBuLiLRMOssbO4hmwxrQjt/E5FGItICuArrLdYUfYGrRKSliDTC5lFejPX8i92irShwZVLz\nxO34p0RzEZnXKlJPFmYC87EvsOewr+Bvw7WLgdtEZCFwE/CfmDzLsH+sgWEsuoOqvhzS/i0ivwCv\nYZPOFZW39A2ofhDaHwkMxeZc0rJUsL6fgEHAAcTuCzNIeBlYCIwBPiK8zMQW5j2aq1pgEfaFvyzs\nD8fmKf4YnkkvUooiXq4fNvQ1LNzbU0HO17G5ixdFZAF2/10yyuaSpyLXbsAmn78MbQ3ATIpR1Xex\nuZwPgYnksFxS1aHYxPWD2HP8GOvxAjwEdBWzTHswiyyXY89tMja0+7yqPl3B+4jzZrAci7ZXysmf\nyVPY3/9TbAh3aZAxafsOIFUdHCtYiDyIKa7eqnp3xvVNsD9mK+yf8xxVHZukrOM4jlMcVKkyCcMI\nEzFTwxnY12Y3VR0fy3MPsEhVbwvd5H+q6hFJyjqO4zjFQVUPc3UAvg0mdyVY1/+EjDy7Yd1qVHUC\n0DKM5yYp6ziO4xQBOZWJ2KrhskzukpBp2vgj6SaAAN8AJ4X2OmDjrtslLOs4juMUATnXmajq6uBK\noKmqLqwiGe7CbOyHAaOwhVcVsvcXEZ8gcxzHqSCqWjBLtSTDXIuBUSLSW0QejraE9U8nZeEB1uNI\nsydX1UWqeo6qtlfVMzEXB5OTlM2op6i3nj171rgMLqfL6XK6nNFWaJKsgH+Vyi+iGwrsHGzJZwLd\niHnHBbOpB5aqrXM4H/hEVReLSLllHcdxnOKgXGWiqn3EPL+2DkkT1CbEy0VtmOxSzJY9Mu8dJyIX\n2mV9HPgV0EdE1mA2/+fmKltmY6tWwbRp0KpQi7Mdx3GcpJSrTESkI+ZsbSq2EnR7ETlTzWFguagt\nhNo1I61X7PjLzOu5ypbJgAFwzDGwejVsUHxrMTt27FjTIiTC5SwsLmdhcTmLl3LXmYjI18Cpama7\niEhroK+q7lMN8iVCRFRXroQNN4RXX4Xf/a6mRXIcxylqRAQt4AR8EmUyUlX3KC+tJhERVVWIPKFW\nweSS4zjOukShlUmS8aCvROTJ4Byvo4g8Qbqb6uLhX/+CbX0piuM4TnWTpGdSH7gEODgkfQY8qkUU\nJGZtzwSsVxINdRXh3InjOE4xUK3DXCJSB3hWVf9Y6QbKd/TYBAtOtANQB7hfVZ8J16ZiHknXACWq\n2qGMNlLK
ZOlS2HhjO16zJjX0FbFkSeq64zjOekq1DnOpRZ5rEUyDK0xw1vgIFje5LdA9xHqIcwkW\nuGgvzJX3/SISWZmtATq
"text/plain": [
"<matplotlib.figure.Figure at 0x2ec1f9ccb38>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pred_basic_model_dropout = train_and_evaluate(reader_train, \n",
" reader_test, \n",
" max_epochs=5, \n",
" model_func=create_basic_model_with_dropout)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Add batch normalization after each convolution and before the last dense layer:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_basic_model_with_batch_normalization(input, out_dims):\n",
"\n",
" with C.layers.default_options(activation=C.relu, init=C.glorot_uniform()):\n",
" model = C.layers.Sequential([\n",
" C.layers.For(range(3), lambda i: [\n",
" C.layers.Convolution((5,5), [32,32,64][i], pad=True),\n",
" C.layers.BatchNormalization(map_rank=1),\n",
" C.layers.MaxPooling((3,3), strides=(2,2))\n",
" ]),\n",
" C.layers.Dense(64),\n",
" C.layers.BatchNormalization(map_rank=1),\n",
" C.layers.Dense(out_dims, activation=None)\n",
" ])\n",
"\n",
" return model(input)"
]
},
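{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `map_rank=1`, `BatchNormalization` keeps one mean, variance, scale and bias per channel, shared across all spatial positions of a feature map. This is our reading of the CNTK layer. The 117290 parameters in 18 tensors reported in the training log below are consistent with one learnable scale and bias per channel.\n",
"\n",
"The NumPy sketch shows the training-time computation for a minibatch laid out as (N, C, H, W): normalize each channel with the minibatch statistics, then apply the learned scale and bias. It omits the running averages used at evaluation time, and `eps` is a small constant chosen here for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def batch_norm_train(x, scale, bias, eps=1e-5):\n",
"    # x: (N, C, H, W). Statistics are computed per channel over N, H and W.\n",
"    mean = x.mean(axis=(0, 2, 3), keepdims=True)\n",
"    var = x.var(axis=(0, 2, 3), keepdims=True)\n",
"    x_hat = (x - mean) / np.sqrt(var + eps)\n",
"    # One learned scale and bias per channel, broadcast over N, H and W.\n",
"    return scale.reshape(1, -1, 1, 1) * x_hat + bias.reshape(1, -1, 1, 1)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = rng.normal(loc=5.0, scale=3.0, size=(16, 32, 8, 8))\n",
"y = batch_norm_train(x, scale=np.ones(32), bias=np.zeros(32))\n",
"\n",
"print(np.round(y.mean(axis=(0, 2, 3))[:4], 6))  # approximately zero per channel\n",
"print(np.round(y.std(axis=(0, 2, 3))[:4], 3))  # approximately one per channel\n",
"```"
]
},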
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training 117290 parameters in 18 parameter tensors.\n",
"\n",
"Learning rate per minibatch: 0.01\n",
"Momentum per sample: 0.9983550962823424\n",
"Finished Epoch[1 of 5]: [Training] loss = 1.536584 * 50000, metric = 55.22% * 50000 12.978s (3852.7 samples/s);\n",
"Finished Epoch[2 of 5]: [Training] loss = 1.215455 * 50000, metric = 43.35% * 50000 12.196s (4099.7 samples/s);\n",
"Finished Epoch[3 of 5]: [Training] loss = 1.092067 * 50000, metric = 38.66% * 50000 12.260s (4078.3 samples/s);\n",
"Finished Epoch[4 of 5]: [Training] loss = 1.011021 * 50000, metric = 35.57% * 50000 12.330s (4055.2 samples/s);\n",
"Finished Epoch[5 of 5]: [Training] loss = 0.952613 * 50000, metric = 33.38% * 50000 12.286s (4069.7 samples/s);\n",
"\n",
"Final Results: Minibatch[1-626]: errs = 30.1% * 10000\n",
"\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXmYVNW1t9+fqIgBBYMTouAQTBSNccIZJc4aNSbG6Cdx\niMYYFaNXY9BcUHON93rVkGgSr3PEoHEWNYojKhpwAASZkRYQBUERZRTo9f2xTlnV3VXVVU1VV3Wz\n3uepp87ZZ52919nQZ9Xaa++1ZWYEQRAEQWOsU2kFgiAIgpZBGIwgCIKgIMJgBEEQBAURBiMIgiAo\niDAYQRAEQUGEwQiCIAgKIgzGWoqkv0m6slhZSb0lzS6vdl+3WyOpT3O01RpI/m3Gl1q2CXq8Juln\n5ag7qCzrVlqBoLRI+gDYAuhiZp9llI8Bvgt0N7NZZnZeoXVmkW3S4h1J3YAaYF0zq21KHa0FSQcA\nz+B9uQ6wIbAYUFK2k5l9WEydZvYKsEupZYMgRXgYrQ/DX8qnpAok9QTa0cQXfQlJvQxV9oakNuVu\nY00wsxFm1sHMNgJ2xvtl41RZfWOhhIooGwQJYTBaJ4OB0zPOTwf+nikg6W5J1yTHvSXNlnSJpHmS\n5kg6I5tsukj9Jc2XNEPSqRkXjpY0WtIiSTMlDcy475Xk+3NJX0jqldxzjqSJSdl7knbLuOd7kt6V\ntFDS/ZLWz/bAkk6XNELSTZLmAwMlDZQ0OEOmm6RaSesk5y9Luia57wtJz0raJEf9EyUdnXHeRtIn\nknaT1FbSYEkLEj1HSdo0Wz2NUMcgJEM710h6A/c+tpb084y+mibp5xny35dUk3E+W9LFksYlev1D\n0nrFyibX+0v6OJE7O+nHbRp9IGeApA8kzZV0l6QOybV2STupfhuZ6v/kOWuS55wu6SdN6M+gxITB\naJ2MBDpI2jF5OZ4M3Ef+X/ZbAB2ALsDZwF8kbZxHdpNE9gzgNknfSq4tBvqa2cbAMcAvJR2XXDso\n+d4o+RU9StJJwADgtOTX9nHApxltnQQcDmyLD6mdkecZegHTgc2Ba5Oy+l5V/fNTcIO6KdAWuDRH\n3UOAUzPOjwTmm9nY5P6NgK3wfvklsCyPnsVwGv7MGwFzgLnAUUlfnQPcnHiQKeo/30nA94HtgD2B\nvsXKSjoWOB/oDfQA+mS5Nxfn4P12ELA93j+Dkmtn4p5vl6T8V8DyxKDcCHw/ec79gXEFtheUkTAY\nrZeUl3EYMAn4qBH5r4Dfm9lqM3sGf/HvmEPWgP80s5Vm9irwNPATADN71cwmJMfvAQ/gL5pMMg3X\nz4HrzWx0cs8MM8sMqv/JzOaZ2efAk0Cm91GfOWb2VzOrNbMVjTxvirvN7P1E/sE89d8PHCdpg+T8\nlKQMYCXwTaCHOWPMbHGB7TfGXWY2Nfl3WW1mT5vZTAAzGw68CByY5/4/mtl8M1sIPEX+/sslexJw\nZ6LHMuDqIvQ/FbghiZstAa4gbXhXAp1J99toM1uaXKsFdpHUNvn3n1xEm0GZCIPRerkP/8M8A7i3\nAPlP6wWilwLtc8guNLPlGecz8V+JSOol6aVkuOZz4Fz8pZCLrYH381yfV6BOAE2ZvTW3kPrN7H1g\nIvADSe1wT2hIcnkwMAx4QNKHkv5bpYuh1HkmSccmQzefSlqI/yDI17/F9F8u2S719JhN4XGoLvj/\njxQzgbbJkN09wAvAg8lQ1x8krWNmX+IG+QJgrqShGR5sUEHCYLRSzGwWHvw+Cni0xNV3Sl6aKbYh\n7cH8A3gc2MrMOgL/R/rlkm0YYzY+VFEK6te/BJ99lGLLNaz/AdwIHw9MMLMZAGa2ysx+b2Y7A/sB\nPwBKNa3062dKvJuH8OG2Tc2sE/A85Z9E8DHQNeN8GwofkvoI6JZx3g1YkXgyK83sGjPbCTgAOBH4\nfwBmNszMDsOHP9/H/x8FFSYMRuvmLKBPMoxQ
SgRcLWk9SQfisYoHk2vtcQ9kpaS9qTvuPx8fasg0\nEHcAl0raHUDS9pK2LpGeY4GDJG2dxGN+u4b1PYDHU84j7V0g6WBJPZN40WJ8qKXYacOFvPTbAusB\nCwBLYgvfL7KdpvAg8HNJPSRtCPyuiHvvBy5JJhx0AP6LpO8kHSJpZ0kio98kbZF4Uu2AVbjhX13K\nBwqaRhiM1sfXv/zMrCYVG6h/rZh6svAxsBD/9TgYONfMpiXXfgX8XtIi/MXyzwx9luG/jl+X9Jmk\nvc3s4aRsiKQvgMfwAGix+jZ8ALMXkvbHAW/hMZA6IkXWNxf4N7APGc+F/wp+GFgETABexvsltejx\nr4VU31iZmS0CLsY9uE/xX+T1n6mxOouWNbOngL8BrwJTgBHJpVxxosy6bsf76jV8QsIi4NfJtS64\n97sIGA88hxuTNsBl+P+v+cC+eNA9qDAq5wZKkrri4+eb47+4bjezP+eQ3Qt4AzjZzEo9hBIEQYlI\nZmW9Y2ZtK61L0LyU28NYBVySjO3uC5wv6dv1hRJX/r/xwGEQBFWGpBOSIchN8L/VxyutU9D8lNVg\nmNncZJ46yTTDSfhc9fpciLv0n5RTnyAImsz5eOxkKj6D6oLKqhNUgmbLJSWpOz6ve1S98i7ACWZ2\nSBIkDYKgykhmLAVrOc1iMCS1xz2Ii7IsaBoEXJ4pnqOOSudBCoIgaJGYWUmmXpd9lpSkdXFjMdjM\nnsgisie+4KkG+DGekuK4LHKYWdV/Bg4cWHEdQs/Qs6XqGHqW/lNKmsPDuAuYaGZ/ynbRzLZLHUu6\nG3jSzIY2g15BEARBEZTVYEjaH1+5OV6+H4PhuWS6AWZmt9W7JYadgiAIqpSyGgwzex1fhFOo/Fll\nVKdZOPjggyutQkGEnqWlJejZEnSE0LOaKevCvVIiyVqKrkEQBNWCJKylBL2DIAiC1kGLMxhTp8Kn\nnzYuFwRBEJSWFmcwfvc7ePHFSmsRBEGw9lFWgyGpa7KZzgRJ4yX1yyJzqnzP5nfleyvvkq/Ojh1h\n4cLy6RwEQRBkp9zrMFLJB8cmq73fkfSc1d1ucQZwkJktknQkng55n1wVduoEn39eXqWDIAiChpR7\nWu1cki0wzWyxpFTywckZMiMzbhlJ9uSEX9OxYxiMIAiCStBsMYxcyQfrcTbwTL56wmAEQRBUhmpI\nPpiSOQQ4E9/bNytXXXUV06fDnDkwfPjBa+XCmSAIgnwMHz6c4cOHl6Xusi/cS5IPPgU8kyuflKRd\ngUeAI83s/RwysXAvCIKgSFrawr28yQclbYMbi765jEUQBEFQecq9p/f++Mbx4/HEgg2SD0q6Hd/M\nfia+F8ZKM2uwkVJ4GEEQBMVTSg8jckkFQRC0YlrakFQQBEHQCmiRBuONN2DVqkprEQRBsHbRIg3G\nySf71NogCIKg+WiRBmOrrcJgBEEQNDcVTz6YyP1Z0jRJYyXt1li9XbqEwQiCIGhuKp58UNJRwPZm\n9i1JvYBbyZN8EMLDCIIgqARl9TDMbK6ZjU2OFwOp5IOZHA/cm8iMAjaWtHm+erffHqZPL4PCQRAE\nQU6qIfngVsDsjPM5NJKxdq+9oHPnUmoXBEEQNEbVJB8shKuuuurrY088ePAaahYEQdC6aNXJByXd\nCrxsZv9MzicDvc1sXj25WOkdBEFQJC1tpXfe5IPAUOBnAJL2AT6vbyyCIAiCylPx5IOJ3C3AkcAS\n4EwzG52lrvAwgiAIiiSSDwZBEAQF0dKGpMrCypVwxx2V1iIIgmDtocV6GKtXw0Ybwdy50KFDBRUL\ngiCoYsLDANq0gaVL4Z57Kq1JEATB2kG5c0ndKWmepHE5rm8kaWiSQ2q8pDOKqf/oo+H++0uiahAE\nQdAI5Z4ldQCwGLjXzHbNcr0/sJGZ9ZfUGZgCbG5mDXa7yBb0njkTuneH+fNj5XcQBEE2WsyQlJmN\nABbmEwFS
EYgOwKfZjEUuunWDc87x4akuXaBv3zVQNgiCIMhLc6z07gY8mcPDaI8v3Ps20B442cye\nyVFP3mm1ffrAyy9DbS2
"text/plain": [
"<matplotlib.figure.Figure at 0x2ec25a34e80>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXfYFNX1xz9fCygKNlSiYm9Rf/aWaASNNcaaWDCxRmM3\nGmPsYjSJmsRujLFhS+y9Y0ONBRVEUUFUQKWoiIgoiMh7fn+cGXbeZXffWdh3333hfJ5nnpm5c+fe\ns3d358y995x7ZGYEQRAEQUvM09YCBEEQBO2DUBhBEARBLkJhBEEQBLkIhREEQRDkIhRGEARBkItQ\nGEEQBEEuQmE0GJL+Jen0avNK6iHp49aVbka9IyRtU4+62huSmiStXO97Z5ds3dX8BkuUM0nSirWU\nLWgcQmHUCUkjJX0rafGi9NeTP+vyAGZ2pJn9JU+ZJfLOklONpBUSGeL3wGy3x+w4NpW9V1I/SVMk\nfSXpM0l3S1p6NuoqW3fe36CkZyQd0qwQs85mNrKGcqV1jZQ0Ofn8k5L9ZbWuJ6hMPCDqhwEjgF5p\ngqR1gAWZvYdMLVAig1q9Imne1q6jRszqdzI7bVjpXgOOMrMuwOrAosDFJQuZNUXX6t/9bGLAzmbW\nJVFKXczsuFIZS/3Gqv3dtaPfaV0JhVFfbgYOzJwfCNyYzSCpj6RzkuMekj6W9HtJn0oaLemgUnkL\nSTpV0jhJwyXtl7nwM0kDJU2U9KGk3pn7nk32XyZvbpsl9xwm6Z0k7S1J62fu2UDSG5ImSLpVUodS\nH1jSgZL+J+kiSeOA3pJ6S7o5k6fZG33y5npOct9Xkh4r7pll7n1H0s8y5/Mmb+DrS+oo6WZJnydy\n9pe0ZKly8iJpE0kvJuWNlnS5pPmKsu0s6YNEjr8V3X9IIvN4SY+mPcu81QOY2ZfA3cA6SZl9JF0p\n6WFJk4CekjpI+kfyXY9NrnfMyHGSpDGSRkk6mIyCLP5dSdot6QlPlPSepO0l/Rn4CXBF9m1fzYe2\nuki6KWmHEcoMcyW/i+cl/V3SF0l77Zjn88+U2Pw39jn+GyuVJklnyHsrn0i6QVKXpIz0N3iIpA+B\np/J+KXMToTDqy8tAZ0lrJA/HfYBbqPx21w3oDCwDHAr8U9IiFfIunuQ9CLha0mrJta+B/c1sEWBn\n4AhJuybXtkr2XZI3t/6S9gLOAn6dvNXuCozP1LUXsD2wErBeUl85NgPeB5YG0qGO4jf44vNeuEJd\nEugI/KFM2f8F9suc7wiMM7NByf1dgGXxdjkCmFJBzjxMB45PyvsRsA1wVFGe3YENk203JcM2knYD\nTkmuLwk8D9xarQCSugK/AAZmknsB55pZZ+AF4AJgVWDdZL8s/n2SPJh/D/wUWA3YtkJdm+IvNScm\nv52tgJFmdkYi/zFFb/vZ7/EK/Le7ItATOCBRTimbAkOAJYC/A9dV0w5FpL+xpSj8xorTDgYOAHoA\nKyeyXVFUzlbAmsAOsyHLHEsojPqT9jK2w/8sY1rI/x3+IJhuZo/iD/41yuQ14Ewzm2ZmzwEPA3sD\nmNlzZvZ2cvwWcBv+x8mSVVy/Af5mZgOTe4abWXZS/VIz+zR5230QyPY+ihltZleaWZOZTW3h86b0\nMbMPkvx3VCj/VmBXSQsk570oPISn4Q+j1c153cy+zll/ScxsoJm9kpT3EXA1M7fj+WY20cxGAZdQ\nGIY8HDjPzIaZWRNwPrC+pO45q79c0hfA6/jv5sTMtfvN7OVExqnAYcAJiRzfJHWlcuyFt+8QM5sC\nnF2hzkOA68zs6aTssWY2rEJ+wYxhsX2AU8xsspl9CFwI7J/J+6GZXW++oN2NQDdJS1Uo+76kNzIh\n2f8mc63Ub6w4bT/gIjP70MwmA6cC+6owhGdAbzObUsXvdK6iuCsdtD63AM/hb+Y35cg/Pnm4pEwG\nFi6Td4KZfZs5/xDvbSAfZjoPH8bokGx3Vqi3
O/BBheufFsn0gwp5Z8V665Oi8kt+ZjP7QNI7wC6S\nHsJ7Qmcll28GlgNuS3pltwCnm9n0WZAHgKTHdhGwMT7/NB8woCjbqMzxjO8AWAG4VNKFaXH4Q2pZ\n8rXRsWZ2fZlrM+5Pht06AQOkGe8A81B4IVgGeK1IxnK93O74i0e1dMXb5qOiepbNnM/4js1silzY\nhYHPypS5m5k9U+ZaqfYrTlsmkSErz3x4zzdlFEFZoodRZ5K30hHATsA9NS5+MUkLZs6Xp9CD+Q9w\nH7CsmS0K/JvCQ6LUBO/HwCo1kqu4/G/wB1pKJWWTh9vwt8fdgLfNbDiAmX1vZuea2drAj4Fd8CGJ\n2eFfeM9wlaQdT2fmh222x7AChe/gY+BwM1s82RYzs4XTnsFskm3jz3Elu3amrkWTISWAsSVkLDfJ\nX+l3UMkw4HO8h7dCUT2jK9zTEi0ZBbSUNqaEPNNo/vLT1gYoDU0ojLbhEGCbZDiglgj4k6T5Jf0E\nn6u4I7m2MN4DmZaMS2fH/ccBTTR/MFwL/EHShgCSVqli6KQlBgFbSeqevPmfMpvl3YbPpxyJz2kA\nIKmnpHWSIYev8YdDU+kiZkLAAsnEeboJH/f+yswmS1ozqbOYkyQtmrTXcYl8AFcBp0laK5FvEUm/\nrP7jViYZ4rkGuCTpbSBpWUnbJ1nuAA6S9ENJnSj0yEpxHXCwpK2TSeNlJKVDop/icwGlZGhK6vmL\npIUlrQCcgPf62opbgRMkrShpYXxe47ZMD77RLcXanFAY9SNr5z4inRsovlZNOSUYC0zA36Ruxt9m\n30uuHQWcK2kicAZwe0aeKfif54VkbHhTM7srSfuvpK+Ae/GJ3mrlnfkDmD2Z1P8m8Co+B9IsS5Xl\nfQK8BGxO5nPhRgB3AROBt4FnSB5Ycue0KysVC0zC39SnJPut8XmDXyVt8m8KyiB73/34MNXA5LNd\nn8h5Hz6XcJukL/HPv2PRvZXkqebayfiE78tJXX1xc1zM7DF8buVpYBgVLILM7FV8svgSvB374T1X\ngEuBveQWX5eUkOU4vN2G48Owt5hZnyo/R5YHE4usdLu7hfzFXI9//8/hw62TExnz1j/Xo9YOoJRY\nZFyCK6frzOyCouuL4l/kKvgf8xAze6dVhQqCIAiqplV7GMlQwBW4idraQK+kG5/lNOB1M1sPtx4K\n780gCIIGpLWHpDYF3kvM2Kbh3ffdivKshXeNMbN3gRU1m85VQRAEQe2pqDDkXrPlzNjyUGwuOIrm\nZnUAbwB7JvVtio+PLjcbdQZBEAStQEU/DDObnrjLL2JmE1tJhvNx2/SBwGDcKWkmO3lJMSEVBEEw\nC5hZTSzA8gxJfQ0MlnSdpMvSLWf5oylYVID3HJrZYZvZJDM7xMw2NLMDcTf+4aUKM7OG33r37t3m\nMoScIWd7lTHkrP1WS/J4et/DrDuYvQqsmthgjwX2JbNaK7gtOjDZ3D/gMOBZm83lG4IgCILa06LC\nMLMb5SuRrp4kvWs+gd0i5kNax+A24KlZ7RBJh/tluxr4IXCjpCbcVv435UsMgiAI2ooWFYaknvjC\nYCNxT8jukg40X9yuRcydhNYoSvt35vjl4uvtmZ49e7a1CLkIOWtLe5CzPcgIIWcj06LjnqQBwH7m\nJq9IWh241cw2qoN8WTms1uNxQRAEczqSsDpOes+fKgsA86WN569F5UEQBEH7Ic+k92uSrsWXhgb4\nFc2XRg6CIAjmAvIMSXUEjga2TJKeB660OgcYiSGpIAiC6qnlkFRFhSEPhH6Tmf2qFpXNDqEwgiAI\nqqducxjmkclWSMxqgyAIgrmYPHMYw/E4CQ/gkdIAMLOLWk2qIAiCoOHIozA+SLZ58GhjQRAEwVxI\nRYWRzGF0NrM/zGoFOQIodcEtsJYH5gUuNLMbZrW+IAiCoHXIYyX1kpn9aJYK9wBKw4Cf4mFDXwX2\nNbOhmTyn
Al3M7FRJXYF3gaXN7PuismLSOwiCoEpqOemdZ0hqUDJ/cSfN5zDyLEg4I4ASgKQ0gNLQ\nTB6jMNTVGRhfrCyCIAi
"text/plain": [
"<matplotlib.figure.Figure at 0x2ec09a89860>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pred_basic_model_bn = train_and_evaluate(reader_train, \n",
" reader_test, \n",
" max_epochs=5, \n",
" model_func=create_basic_model_with_batch_normalization)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's implement a VGG-inspired network using the Layers API. Here is the architecture:\n",
"\n",
"| VGG9 |\n",
"| ------------- |\n",
"| conv3-64 |\n",
"| conv3-64 |\n",
"| max3 |\n",
"| |\n",
"| conv3-96 |\n",
"| conv3-96 |\n",
"| max3 |\n",
"| |\n",
"| conv3-128 |\n",
"| conv3-128 |\n",
"| max3 |\n",
"| |\n",
"| FC-1024 |\n",
"| FC-1024 |\n",
"| |\n",
"| FC-10 |\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_vgg9_model(input, out_dims):\n",
"    # ReLU activations and Glorot-uniform initialization for every layer below\n",
"    with C.layers.default_options(activation=C.relu, init=C.glorot_uniform()):\n",
"        model = C.layers.Sequential([\n",
"            # 3 stages: two padded 3x3 convolutions (64, 96, then 128 filters)\n",
"            # followed by 3x3 max pooling with stride 2\n",
"            C.layers.For(range(3), lambda i: [\n",
"                C.layers.Convolution((3,3), [64,96,128][i], pad=True),\n",
"                C.layers.Convolution((3,3), [64,96,128][i], pad=True),\n",
"                C.layers.MaxPooling((3,3), strides=(2,2))\n",
"            ]),\n",
"            # two fully connected hidden layers (FC-1024)\n",
"            C.layers.For(range(2), lambda: [\n",
"                C.layers.Dense(1024)\n",
"            ]),\n",
"            # output layer: unnormalized scores (logits) for the out_dims classes\n",
"            C.layers.Dense(out_dims, activation=None)\n",
"        ])\n",
"\n",
"    return model(input)"
]
},
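{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, it can be useful to sanity-check the model size. The following pure-Python sketch (not part of the original lab and not a CNTK API) recomputes the number of learnable parameters of `create_vgg9_model` by hand. It assumes 32x32x3 CIFAR-10 inputs, padded 3x3 convolutions with a bias, and unpadded 3x3 max pooling with stride 2, so each pooling step maps a side length `n` to `(n - 3) // 2 + 1` (32 -> 15 -> 7 -> 3). Under these assumptions it prints 2675978, which agrees with the `Training 2675978 parameters in 18 parameter tensors` line in the training log below (weights and bias for each of the 9 layers)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hand count of VGG9 parameters: an illustrative sketch, not a CNTK API.\n",
"# Assumes 32x32x3 inputs, padded 3x3 convolutions with bias, and unpadded\n",
"# 3x3 max pooling with stride 2.\n",
"\n",
"def conv_params(k, c_in, c_out):\n",
"    # k x k x c_in weights per filter, plus one bias per output channel\n",
"    return k * k * c_in * c_out + c_out\n",
"\n",
"def dense_params(n_in, n_out):\n",
"    # weight matrix plus bias vector\n",
"    return n_in * n_out + n_out\n",
"\n",
"def vgg9_param_count(size=32, channels=3, num_classes=10):\n",
"    total = 0\n",
"    for filters in [64, 96, 128]:\n",
"        total += conv_params(3, channels, filters)\n",
"        total += conv_params(3, filters, filters)\n",
"        channels = filters\n",
"        size = (size - 3) // 2 + 1  # unpadded 3x3 max pooling, stride 2\n",
"    n_in = size * size * channels  # flattened feature map fed to FC-1024\n",
"    for _ in range(2):\n",
"        total += dense_params(n_in, 1024)\n",
"        n_in = 1024\n",
"    total += dense_params(n_in, num_classes)\n",
"    return total\n",
"\n",
"print(vgg9_param_count())"
]
},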
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training 2675978 parameters in 18 parameter tensors.\n",
"\n",
"Learning rate per minibatch: 0.01\n",
"Momentum per sample: 0.9983550962823424\n",
"Finished Epoch[1 of 5]: [Training] loss = 2.267064 * 50000, metric = 84.67% * 50000 18.672s (2677.8 samples/s);\n",
"Finished Epoch[2 of 5]: [Training] loss = 1.877782 * 50000, metric = 69.81% * 50000 12.578s (3975.2 samples/s);\n",
"Finished Epoch[3 of 5]: [Training] loss = 1.689757 * 50000, metric = 63.07% * 50000 12.729s (3928.0 samples/s);\n",
"Finished Epoch[4 of 5]: [Training] loss = 1.564912 * 50000, metric = 57.57% * 50000 12.536s (3988.5 samples/s);\n",
"Finished Epoch[5 of 5]: [Training] loss = 1.475126 * 50000, metric = 53.79% * 50000 13.171s (3796.2 samples/s);\n",
"\n",
"Final Results: Minibatch[1-626]: errs = 50.1% * 10000\n",
"\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXmYVNW1t99fgwyGQXBGGRyDM2JEDEYBL05xuhqNc5wS\noxiNJnGIX0Ickhg196pcNcEBFa8ah6jo1RhnBRWNgIziwKgCggqGUaDX98c6ZVUP1V3dVHVVda/3\neeo55+yzzz7r7O46q9Zee68lMyMIgiAI6qOi2AIEQRAE5UEojCAIgiAnQmEEQRAEOREKIwiCIMiJ\nUBhBEARBToTCCIIgCHIiFEYLRdJtkq5oaF1JB0iaV1jpvrnvLEmDm+JezYHkbzM533UbIcdrkk4r\nRNtBcWldbAGC/CJpNrAF0M3MvsgonwDsAfQys7lmdm6ubdZSt1GLdyT1BGYBrc2ssjFtNBck7Qc8\ng/dlBbAhsAxQUrazmX3ckDbN7BVgt3zXDYIUYWE0Pwx/KZ+YKpC0K9CeRr7o80jqZaiC30hqVeh7\nrA9mNsbMOppZJ2AXvF86p8qqKwslFEXYIEgIhdE8GQX8KOP4R8A9mRUkjZR0VbJ/gKR5ki6WtFDS\nJ5JOr61uukiXS1okaaakkzJOHCZpvKSlkuZIGpZx3SvJdomkryTtk1zzY0nTkrIpkvpkXLOnpHcl\nfSnpAUltantgST+SNEbSf0laBAyTNEzSqIw6PSVVSqpIjl+SdFVy3VeS/iGpa5b2p0k6LOO4laTP\nJPWR1FbSKEmLEznHSdq0tnbqoYpCSIZ2rpL0Om59dJd0VkZffSDprIz6B0qalXE8T9JFkiYlcv2v\npA0aWjc5f7mk+Um9s5N+7FHvAzm/lTRb0gJJd0nqmJxrn9wn1W9vpvo/ec5ZyXN+KOn4RvRnkGdC\nYTRP3gQ6Svp28nL8IXAfdf+y3wLoCHQDzgZukdS5jrpdk7qnAyMk7ZCcWwacamadge8DP5V0ZHJu\n/2TbKfkVPU7SccBvgVOSX9tHAp9n3Os44CBgG3xI7fQ6nmEf4ENgc+D3SVl1q6r68Ym4Qt0UaAv8\nMkvb9wMnZRwfAiwys4nJ9Z2ArfB++Smwsg45G8Ip+DN3Aj4BFgCHJn31Y2B4YkGmqP58xwEHAtsC\n3wFObWhdSYcDQ4EDgB2BwbVcm40f4/22P7Ad3j83JufOwC3fbkn5ecCqRKH8GTgwec4BwKQc7xcU\nkFAYzZeUlTEEmA58Wk/9r4GrzWydmT2Dv/i/naWuAb8xszVm9irwf8DxAGb2qplNTfanAA/iL5pM\nMhXXWcB1ZjY+uWammWU61W8ys4VmtgR4Esi0PqrziZndamaVZra6nudNMdLMPkrqP1RH+w8AR0pq\nlxyfmJQBrAE2BnY0Z4KZLcvx/vVxl5m9n/xd1pnZ/5nZHAAzexl4AfheHdf/t5ktMrMvgaeou/+y\n1T0OuDORYyVwZQPkPwm4IfGbLQd+TVrxrgE2Id1v481sRXKuEthNUtvk7/9eA+4ZFIhQGM2X+/Av\n5unAvTnU/7yaI3oF0CFL3S/NbFXG8Rz8VyKS9pH0YjJcswQ4B38pZKM78FEd5xfmKBNAY2ZvLcil\nfTP7CJgGHCGpPW4J3Z+cHgU8Czwo6WNJ1yp/PpQqzyTp8GTo5nNJX+I/COrq34b0X7a63arJMY/c\n/VDd8P+PFHOAtsmQ3d3A88BDyVDXHyRVmNm/cYV8PrBA0ugMCzYoIqEwmilmNhd3fh8K/D3PzXdJ\nXpopepC2YP4XeBzYysw2Av5K+uVS2zDGPHyoIh9Ub385PvsoxZbr2f6DuBI+CphqZjMBzGytmV1t\nZrsA3wWOAPI1rfSbZ0qsm4fx4bZNzawL8ByFn0QwH9g647gHuQ9JfQr0zDjuCaxOLJk1ZnaVme0M\n7AccA5wMYGbPmtkQfPjzI/z/KCgyoTCaN2cC
g5NhhHwi4EpJG0j6Hu6reCg51wG3QNZI6kfVcf9F\n+FBDpoK4A/ilpL4AkraT1D1Pck4E9pfUPfHHXLae7T2I+1POJW1dIGmgpF0Tf9EyfKilodOGc3np\ntwU2ABYDlvgWDmzgfRrDQ8BZknaUtCHw/xpw7QPAxcmEg47ANSR9J2mQpF0kiYx+k7RFYkm1B9bi\nin9dPh8oaByhMJof3/zyM7NZKd9A9XMNaacW5gNf4r8eRwHnmNkHybnzgKslLcVfLH/LkGcl/ut4\nrKQvJPUzs0eSsvslfQU8hjtAGypvzQcwez65/yTgbdwHUqVKA9tbALwB9CfjufBfwY8AS4GpwEt4\nv6QWPd6aS/P1lZnZUuAi3IL7HP9FXv2Z6muzwXXN7CngNuBVYAYwJjmVzU+U2dbteF+9hk9IWAr8\nPDnXDbd+lwKTgX/iyqQV8Cv8/2sRsC/udA+KjAqZQEnS1vj4+eb4L67bzezmLHX3Bl4Hfmhm+R5C\nCYIgTySzst4xs7bFliVoWgptYawFLk7GdvcFhkrqXb1SYspfizsOgyAoMSQdnQxBdsW/q48XW6ag\n6SmowjCzBck8dZJphtPxuerV+Rlu0n9WSHmCIGg0Q3Hfyfv4DKrziytOUAyaLJaUpF74vO5x1cq7\nAUeb2aDESRoEQYmRzFgKWjhNojAkdcAtiAtrWdB0I3BpZvUsbRQ7DlIQBEFZYmZ5mXpd8FlSklrj\nymKUmT1RS5Xv4AueZgE/wENSHFlLPcys5D/Dhg0rugwhZ8hZrjKGnPn/5JOmsDDuAqaZ2U21nTSz\nbVP7kkYCT5rZ6CaQKwiCIGgABVUYkgbgKzcny/MxGB5LpidgZjai2iUx7BQEQVCiFFRhmNlYfBFO\nrvXPLKA4TcLAgQOLLUJOhJz5pRzkLAcZIeQsZQq6cC+fSLJykTUIgqBUkISVi9M7CIIgaB6EwgiC\nIAhyoskW7uWLqVPh1FNBgooKaNsW2rWDXXeFG2+sWf/jj+G226B9e2jTBjbYwD9bbgnHHluz/sqV\n8OmnsNFG0KWLl1WEWg2CICg/hdGrF9x+O5hBZSWsXg2rVsGGG9Zev6LClcXKlbB0KaxZ459//7v2\n+jNmwDHHwJIl/jFzpXTAAfBsLZGuZs+Gl16C3XZzBTN3Lmy/PXTPV4DuIAiCEiGc3nWQut3q1a5k\nOnasWWfKFLjmGpg2zRVMjx6udM46C669tknFDYIgqEE+nd6hMArAunVuwWy0Uc1zo0bBHnvA7rs3\nvVxBELQ8YpZUidOqVe3KAmDmTDjwQDj3XLdOgiAIyoWCKgxJW0t6UdJUSZMlXVBLnZMkvZt8xkja\nrZAyFZthw+Ctt2DjjWHwYDj5ZHj33WJLFQRBUD+lkEBpJrC/me2B5/u9vcAyFZ1ttnG/x9tvQ+fO\noTCCICgPmtSHIelxYLiZvZDl/EbAZDOrMceonHwYQRAEpUJZ+jCyJVCqxtnAM00hTxAEQdAwSiGB\nUqrOIOAMYL9s7fzud7/7Zn/gwIHNMviXGRx3HAwf7osLgyAIGsLLL7/Myy+/XJC2Cz4klSRQegp4\nJltODEm7A48Ch5jZR1nqtJghqfPO83UdTz+dfUFiEARBLpTbkFSdCZQk9cCVxanZlEVLY/hwX8Xe\nqZMvAgyCICgFCmphJAmUXgUm48mRaiRQknQ7cAwwB8/nvcbM+tXSVouxMMAVxoAB8OabsHAhbLZZ\nsSUKgqAciZXeLQQzGDkSjjrK120EQRA0lFAYQRAEQU6Umw8jCIIgaAaEwigTKivhX/8qthRBELRk\nQmGUCWvXwmGHwfvvF1uSIAhaKqEwyoQ2beCnP4Xrriu2JEEQtFTC6V1GLFzomf2eeQb22qvY0gRB\nUA6E07uFsvnmcPHFvj4jCIKgqQkLo8yYOxd69vSUsa3LLiN7EARNTdlYGLkkUErq3SzpA0kTJfUp\npEzlTo8e
vqAvlEUQBE1N0RMoSToU2M7MdgDOAf5SYJmaDWvWwJlnwiuvFFuSIAhaAkVPoCTpL8BL\nZva35Hg6MNDMFla7Noa
"text/plain": [
"<matplotlib.figure.Figure at 0x2ec0f89f588>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXecFdX5/98fARWlKLFFFBQR6xcVFYkaWUmMJsYSicbe\nEo0tltijRmLvPcbuzxJFDRaMJRojdmyAgFRFVJoFqSL9+f3xzOXevdy9O7vcvXsXnvfrNa+Zc+bM\nOc/M3p1nznnOeR6ZGUEQBEFQGys1tgBBEARB0yAURhAEQZCKUBhBEARBKkJhBEEQBKkIhREEQRCk\nIhRGEARBkIpQGBWGpH9IuqCuZSX1lPRlw0q3pN3PJPUqR1tNDUmLJXUq97XLSm7bdfkNFqhnlqSN\nSilbUDmEwigTksZLmiupXV7+4OSftQOAmZ1oZpenqbNA2XotqpHUMZEhfg8s8/NYloVNNV4raYCk\nHyTNlPS1pH6S1l2GtmpsO+1vUNKrko6tVolZazMbX0K5Mm2NlzQnuf9Zyf6WUrcTFCdeEOXDgM+A\nQzIZkrYGWrJsL5lSoEQGNXhDUrOGbqNE1PdvsizPsNi1BpxkZm2ALsAawI0FK6mfomvwv/0yYsDe\nZtYmUUptzOzUQgUL/cbq+rtrQr/TshIKo7w8BByVkz4KeCC3gKT7JV2SHPeU9KWkP0v6StJESUcX\nKpvN0vmSvpE0TtKhOSd+JWmQpBmSPpd0cc51ryX76cmX207JNcdJGpHkDZe0bc4120n6SNI0SY9K\nWrnQDUs6StKbkm6Q9A1wsaSLJT2UU6baF33y5XpJct1MSS/m98xyrh0h6Vc56WbJF/i2klaR9JCk\nbxM535W0dqF60iJpR0lvJ/VNlHSrpOZ5xfaW9GkixzV51x+byDxV0guZnmXa5gHMbDrQD9g6qfN+\nSbdLek7SLKBK0sqSrkv+1pOT86vkyHG2pEmSJkg6hhwFmf+7krRf0hOeIWmspF9Iugz4KXBb7te+\nqg9ttZH0YPIcPlPOMFfyu3hD0rWSvkue115p7n+pzOq/sW/x31ihPEm6UN5bmSLp/0lqk9SR+Q0e\nK+lz4JW0f5QViVAY5WUg0FrSZsnL8XfAwxT/ulsPaA2sD/wB+LuktkXKtkvKHg3cJWnT5Nxs4Agz\nawvsDZwgad/k3G7Jvk3y5faupAOBvwKHJ1+1+wJTc9o6EPgFsDGwTdJeTewEfAKsC2SGOvK/4PPT\nh+AKdW1gFeCsGup+BDg0J70X8I2ZDUmubwO0x5/LCcAPReRMwyLg9KS+nwC9gJPyyuwPdEu2/ZQM\n20jaDzgvOb828AbwaF0FkLQW0BsYlJN9CHCpmbUG3gKuBjoDXZN9e/zvSfJi/jPwM2BT4OdF2uqO\nf9Scmfx2dgPGm9mFifyn5H3t5/4db8N/uxsBVcCRiXLK0B0YCfwIuBa4ty7PIY/Mb2wdsr+x/Lxj\ngCOBnkCnRLbb8urZDdgc2HMZZFluCYVRfjK9jD3wf5ZJtZSfj78IFpnZC/iLf7MayhpwkZktMLPX\ngeeAgwDM7HUz+zg5Hg70xf9xcslVXL8HrjGzQck148ws16h+s5l9lXztPgvk9j7ymWhmt5vZYjOb\nV8v9ZrjfzD5Nyj9epP5HgX0lrZqkDyH7El6Av4y6mDPYzGanbL8gZjbIzN5L6vsCuIuln+NVZjbD\nzCYAN5EdhvwjcKWZjTGzxcBVwLaSNkzZ/K2SvgMG47+bM3POPWNmAxMZ5wHHAWckcnyftJWR40D8\n+Y40sx+APkXaPBa418z+l9Q92czGFCkvWDIs9jvgPDObY2afA9cDR+SU/dzM7jN3aPcAsJ6kdYrU\n/XTSG5mW7H+fc67Qbyw/71DgBjP73MzmAOcDBys7hGfAxWb2Qx1+pysU+V3poOF5GHgd/zJ/MEX5\nqcnLJcMcoFUNZaeZ2dyc9Od4bwP5MNOV+DDG
ysn2RJF2NwQ+LXL+qzyZflykbH1mb03Jq7/gPZvZ\np5JGAPtI+jfeE/prcvohYAOgb9Irexi4wMwW1UMeAJIe2w3ADrj9qTnwYV6xCTnHS/4GQEfgZknX\nZ6rDX1LtSfeM/mRm99Vwbsn1ybDbasCH0pJvgJXIfhCsD3yQJ2NNvdwN8Q+PurIW/my+yGunfU56\nyd/YzH6QC9sK+LqGOvczs1drOFfo+eXnrZ/IkCtPc7znm2ECQY1ED6PMJF+lnwG/BJ4scfVrSmqZ\nk+5AtgfzT+BpoL2ZrQHcSfYlUcjA+yWwSYnkyq//e/yFlqGYsklDX/zrcT/gYzMbB2BmC83sUjPb\nCtgZ2AcfklgW/oH3DDdJnuMFLP2yze0xdCT7N/gS+KOZtUu2Nc2sVaZnsIzkPuNvcSW7VU5bayRD\nSgCTC8hYk5G/2O+g2MSAb/EeXse8diYWuaY2apsUUFvepALyLKD6x09jT0CpaEJhNA7HAr2S4YBS\nIuBvklpI+iluq3g8OdcK74EsSMalc8f9vwEWU/3FcA9wlqRuAJI2qcPQSW0MAXaTtGHy5X/eMtbX\nF7ennIjbNACQVCVp62TIYTb+clhcuIqlELBqYjjPbMLHvWea2RxJmydt5nO2pDWS53VqIh/AHcBf\nJG2ZyNdW0m/rfrvFSYZ47gZuSnobSGov6RdJkceBoyVtIWk1sj2yQtwLHCNp98RovL6kzJDoV7gt\noJAMi5N2LpfUSlJH4Ay819dYPAqcIWkjSa1wu0bfnB58pc8Ua3RCYZSP3Hnun2VsA/nn6lJPASYD\n0/AvqYfwr9mxybmTgEslzQAuBB7LkecH/J/nrWRsuLuZ/SvJe0TSTOAp3NBbV3mXvgGz/ybtDwXe\nx20g1YrUsb4pwDtAD3LuC58E8C9gBvAx8CrJC0u+OO32YtUCs/Av9R+S/e643eCw5JncSVYZ5F73\nDD5MNSi5t/sSOZ/GbQl9JU3H73+vvGuLyVOXc+fiBt+BSVsv4dNxMbMXcdvK/4AxFJkRZGbv48bi\nm/DnOADvuQLcDBwon/F1UwFZTsWf2zh8GPZhM7u/jveRy7PJjKzM1q+W8vnch//9X8eHW+ckMqZt\nf4VHEUApCIIgSEP0MIIgCIJUhMIIgiAIUlFUYchXzdY0jS0IgiBYgSi6DsPMFiXL5dua2YxyCVUI\nSWFsCYIgqAdmVpIZYGmGpGYDwyTdK+mWzFaKxuuKmVX8dvHFFze6DCFnyNlUZQw5S7+VkjQrvZ+k\n9AvMgiAIgiZGrQrDzB6QeyLtkmSNNrMFDStWEARBUGnUqjAkVeGOwcbjKyE3lHSUuXO7II+qqqrG\nFiEVIWdpaQpyNgUZIeSsZGpduCfpQ+BQMxudpLsAj5rZ9mWQL1cOK/V4XBAEwfKOJKyMRu8WGWUB\nYO7auEUpGg+CIAiaDmkUxgeS7kkcuVVJupvqrpGLImkvSaMkjZF0boHza0h6Uh69bWDGMVsQBEFQ\nWaQZkloFOBnYNcl6A7jdUgQYSbyEjsEje03CHc0dbGajcspcA8wys0sTL5h/N7OlIoDFkFQQBEHd\nKduQlDwQ+n1mdoOZHZBsN6ZRFgndgbHmEa4W4J4998srsyXuNZNk6GsjpY27fMEFoPBIHARBUA6K\nKgzzyGQdk2m19SE/ktgEqkfcAvgIOACWxA/ugEdJq53hwzOC1lO8IAiCIC1pFu6Nw+Mk9McjpQFg\nZjeUSIar8LCVg4BheLzidCE0OybBsyZNgvb5eqgO/PADrLYafPQRdO1a/3qCIAiWY9IojE+TbSU8\n2lhdmEg22Ap4z6FaiEYzm4VHoANA0me4klqKPn36QN++0KIFVbfeStUtt8Ann8CcOXUUK4977/V9\nu3bFywVBEFQ4AwYMYMCAAQ1Sd1Gjd2LDuNrMzqpX5X79aNzoPRl4DzjEzEbmlGkLzDEPHXocsIuZ\nHV2gLjd6
n3EGPP44TFyW0MB5XHstTJ4MN9wAbdrACy/ALruUrv4gCIJGomxG78SGUe83Z3L9KXh4\nyI/x+LkjJf1R0vFJsS2
"text/plain": [
"<matplotlib.figure.Figure at 0x2ec10398ef0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pred_vgg = train_and_evaluate(reader_train, \n",
" reader_test, \n",
" max_epochs=5, \n",
" model_func=create_vgg9_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Residual Network (ResNet)\n",
"\n",
"One of the main problems when training a deep neural network is propagating the error all the way back to the first layer. In a deep network, the gradients can keep getting smaller as they flow backwards, until they have almost no effect on the weights of the early layers (the vanishing gradient problem). [ResNet](https://arxiv.org/abs/1512.03385) was designed to overcome this problem by defining a block with an identity path, as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src=\"https://cntk.ai/jup/201/ResNetBlock2.png\"/>"
],
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Figure 7\n",
"Image(url=\"https://cntk.ai/jup/201/ResNetBlock2.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
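"The idea behind the above block is twofold:\n",
"\n",
"* During backpropagation, the gradients have a path (the identity shortcut) that does not change their magnitude.\n",
"* The network only needs to learn the residual mapping (the delta relative to x).\n",
"\n",
"So let's implement the ResNet blocks using CNTK. In the diagram below, `ResNetNode` corresponds to `resnet_basic`, and `ResNetNodeInc` corresponds to `resnet_basic_inc`, which downsamples with stride-2 convolutions and increases the number of filters:\n",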
"\n",
" ResNetNode ResNetNodeInc\n",
" | |\n",
" +------+------+ +---------+----------+\n",
" | | | |\n",
" V | V V\n",
" +----------+ | +--------------+ +----------------+\n",
" | Conv, BN | | | Conv x 2, BN | | SubSample, BN |\n",
" +----------+ | +--------------+ +----------------+\n",
" | | | |\n",
" V | V |\n",
" +-------+ | +-------+ |\n",
" | ReLU | | | ReLU | |\n",
" +-------+ | +-------+ |\n",
" | | | |\n",
" V | V |\n",
" +----------+ | +----------+ |\n",
" | Conv, BN | | | Conv, BN | |\n",
" +----------+ | +----------+ |\n",
" | | | |\n",
" | +---+ | | +---+ |\n",
" +--->| + |<---+ +------>+ + +<-------+\n",
" +---+ +---+\n",
" | |\n",
" V V\n",
" +-------+ +-------+\n",
" | ReLU | | ReLU |\n",
" +-------+ +-------+\n",
" | |\n",
" V V\n"
]
},
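{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the first point concrete, here is a toy scalar sketch (an illustration of the chain rule only, not CNTK code and not taken from the ResNet paper). Each hypothetical layer computes `f(x) = w * x` with the same small weight `w`, so its local derivative is `w`. Stacking plain layers multiplies these derivatives, so the gradient shrinks like `w ** depth`. Each residual block instead computes `x + f(x)`, whose local derivative is `1 + w`, so for small `w` the product stays close to 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy scalar illustration of vanishing gradients vs. identity shortcuts.\n",
"# Hypothetical setup: every layer is f(x) = w * x with the same weight w.\n",
"\n",
"def plain_chain_grad(w, depth):\n",
"    # d/dx of f(f(...f(x))): the chain rule multiplies depth copies of w\n",
"    grad = 1.0\n",
"    for _ in range(depth):\n",
"        grad *= w\n",
"    return grad\n",
"\n",
"def residual_chain_grad(w, depth):\n",
"    # each block computes x + f(x), whose local derivative is 1 + w\n",
"    grad = 1.0\n",
"    for _ in range(depth):\n",
"        grad *= 1.0 + w\n",
"    return grad\n",
"\n",
"for depth in (5, 10, 20):\n",
"    print(depth, plain_chain_grad(0.01, depth), residual_chain_grad(0.01, depth))"
]
},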
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def convolution_bn(input, filter_size, num_filters, strides=(1,1), init=C.he_normal(), activation=C.relu):\n",
" if activation is None:\n",
" activation = lambda x: x\n",
" \n",
" r = C.layers.Convolution(filter_size, \n",
" num_filters, \n",
" strides=strides, \n",
" init=init, \n",
" activation=None, \n",
" pad=True, bias=False)(input)\n",
" r = C.layers.BatchNormalization(map_rank=1)(r)\n",
" r = activation(r)\n",
" \n",
" return r\n",
"\n",
"def resnet_basic(input, num_filters):\n",
"    # two 3x3 conv + BN layers; the second has no activation before the sum\n",
"    c1 = convolution_bn(input, (3,3), num_filters)\n",
"    c2 = convolution_bn(c1, (3,3), num_filters, activation=None)\n",
"    # identity shortcut: add the block input back before the final ReLU\n",
"    p = c2 + input\n",
"    return C.relu(p)\n",
"\n",
"def resnet_basic_inc(input, num_filters):\n",
"    # the first convolution uses stride 2, halving the spatial resolution\n",
"    c1 = convolution_bn(input, (3,3), num_filters, strides=(2,2))\n",
"    c2 = convolution_bn(c1, (3,3), num_filters, activation=None)\n",
"\n",
"    # projection shortcut: 1x1 stride-2 convolution so the shapes match for the sum\n",
"    s = convolution_bn(input, (1,1), num_filters, strides=(2,2), activation=None)\n",
"\n",
"    p = c2 + s\n",
"    return C.relu(p)\n",
"\n",
"def resnet_basic_stack(input, num_filters, num_stack):\n",
"    assert (num_stack > 0)\n",
"\n",
"    # chain num_stack identity blocks with the same number of filters\n",
"    r = input\n",
"    for _ in range(num_stack):\n",
"        r = resnet_basic(r, num_filters)\n",
"    return r"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
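"Let's write the full model: an initial 3x3 convolution, three stages of residual blocks with 16, 32, and 64 filters, global average pooling, and a dense output layer:"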
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_resnet_model(input, out_dims):\n",
"    # stage 1: 16 filters at the full 32x32 resolution\n",
"    conv = convolution_bn(input, (3,3), 16)\n",
"    r1_1 = resnet_basic_stack(conv, 16, 3)\n",
"\n",
"    # stage 2: downsample to 16x16, 32 filters\n",
"    r2_1 = resnet_basic_inc(r1_1, 32)\n",
"    r2_2 = resnet_basic_stack(r2_1, 32, 2)\n",
"\n",
"    # stage 3: downsample to 8x8, 64 filters\n",
"    r3_1 = resnet_basic_inc(r2_2, 64)\n",
"    r3_2 = resnet_basic_stack(r3_1, 64, 2)\n",
"\n",
"    # Global average pooling: an 8x8 window over the 8x8 feature map\n",
"    pool = C.layers.AveragePooling(filter_shape=(8,8), strides=(1,1))(r3_2)\n",
"    net = C.layers.Dense(out_dims, init=C.he_normal(), activation=None)(pool)\n",
"\n",
"    return net"
]
},
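{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on `create_resnet_model`, the following pure-Python sketch (illustrative only, not a CNTK API) counts the learnable tensors and parameters by hand. It assumes bias-free convolutions and that each `BatchNormalization` layer contributes a learnable per-channel scale and bias, with its running statistics not counted as parameters. With 32x32 inputs, the two stride-2 stages reduce the feature map to 8x8, so the 8x8 average pooling leaves 64 features for the dense layer. Under these assumptions it prints `(272474, 65)`, which agrees with the `Training 272474 parameters in 65 parameter tensors` line in the training log below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hand count of (parameters, tensors) for create_resnet_model: an illustrative sketch.\n",
"# Assumes bias-free convolutions and a learnable per-channel BN scale and bias.\n",
"\n",
"def conv_bn(k, c_in, c_out):\n",
"    # one convolution_bn call: conv weights + BN scale + BN bias (3 tensors)\n",
"    return k * k * c_in * c_out + 2 * c_out, 3\n",
"\n",
"def basic_block(c):\n",
"    # resnet_basic: two 3x3 convolution_bn layers with c filters\n",
"    p1, t1 = conv_bn(3, c, c)\n",
"    p2, t2 = conv_bn(3, c, c)\n",
"    return p1 + p2, t1 + t2\n",
"\n",
"def inc_block(c_in, c_out):\n",
"    # resnet_basic_inc: two 3x3 convolution_bn layers plus the 1x1 projection shortcut\n",
"    parts = [conv_bn(3, c_in, c_out), conv_bn(3, c_out, c_out), conv_bn(1, c_in, c_out)]\n",
"    return sum(p for p, _ in parts), sum(t for _, t in parts)\n",
"\n",
"def resnet_counts(num_classes=10):\n",
"    blocks = [conv_bn(3, 3, 16)]\n",
"    blocks += [basic_block(16)] * 3\n",
"    blocks += [inc_block(16, 32)] + [basic_block(32)] * 2\n",
"    blocks += [inc_block(32, 64)] + [basic_block(64)] * 2\n",
"    blocks.append((64 * num_classes + num_classes, 2))  # Dense: weights + bias\n",
"    return sum(p for p, _ in blocks), sum(t for _, t in blocks)\n",
"\n",
"print(resnet_counts())"
]
},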
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training 272474 parameters in 65 parameter tensors.\n",
"\n",
"Learning rate per minibatch: 0.01\n",
"Momentum per sample: 0.9983550962823424\n",
"Finished Epoch[1 of 5]: [Training] loss = 1.895607 * 50000, metric = 70.00% * 50000 24.547s (2036.9 samples/s);\n",
"Finished Epoch[2 of 5]: [Training] loss = 1.594962 * 50000, metric = 59.18% * 50000 21.075s (2372.5 samples/s);\n",
"Finished Epoch[3 of 5]: [Training] loss = 1.456406 * 50000, metric = 53.31% * 50000 21.631s (2311.5 samples/s);\n",
"Finished Epoch[4 of 5]: [Training] loss = 1.354717 * 50000, metric = 49.36% * 50000 20.848s (2398.3 samples/s);\n",
"Finished Epoch[5 of 5]: [Training] loss = 1.275108 * 50000, metric = 45.98% * 50000 21.164s (2362.5 samples/s);\n",
"\n",
"Final Results: Minibatch[1-626]: errs = 43.9% * 10000\n",
"\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXe4lNW1/z9fkapYCHZs0VgQVPSKDRUxWKKAMTEGjaKJ\n0URjfRIN5v7EknajKbbkCooFry0aFSwxKmAXCyAIFhQFBKlSVepZvz/WO85wmJkzAzNnZg7r8zzz\nnF3Wu9/17nPOu2bvtffaMjOCIAiCoCE2qLQCQRAEQW0QBiMIgiAoiDAYQRAEQUGEwQiCIAgKIgxG\nEARBUBBhMIIgCIKCCIOxniLpH5J+U6yspCMkTSuvdl/f92NJPRrjXk2B5HczvtSya6HHi5LOKEfb\nQWXZsNIKBKVF0ifA1sC2ZvZ5RvkYYB9gJzObamY/L7TNLLJrtXlH0o7Ax8CGZla3Nm00FSR1A57C\n+3IDoA2wBFBS1tHMPi2mTTN7HuhcatkgSBEjjKaH4S/lvqkCSZ2A1qzli76EpF6GKvuNpGblvse6\nYGYvmVlbM9sE2Avvl01TZfWNhRIqomwQJITBaJoMAfpl5PsBd2UKSLpD0jVJ+ghJ0yRdKmmWpOmS\nzswmmy5Sf0lzJE2WdGpGxXckjZa0UNIUSQMyrns++blA0iJJBybX/FTSxKTsHUn7ZlzTRdLbkuZL\nuk9Si2wPLKmfpJck/UXSHGCApAGShmTI7CipTtIGSX6EpGuS6xZJ+rekdjnanyjpOxn5ZpJmS9pX\nUktJQyTNTfQcJWmLbO00wGoGIZnauUbSK/joY3tJP8noq0mSfpIhf5SkjzPy0yRdImlcotf/SWpe\nrGxS31/SZ4nc2Uk/7tDgAzlXSvpE0kxJgyW1TepaJ/dJ9dtrqf5PnvPj5Dk/lPSDtejPoMSEwWia\nvAa0lbR78nI8BbiH/N/stwbaAtsCZwO3SNo0j2y7RPZMYKCkbyV1S4DTzWxT4HjgZ5J6J3WHJz83\nSb5Fj5J0MnAl8KPk23ZvYF7GvU4GjgZ2xqfUzszzDAcCHwJbAb9LyuqPqurn++IGdQugJfDLHG3f\nC5yakT8WmGNmY5PrNwG2w/vlZ8BXefQshh/hz7wJMB2YCRyX9NVPgZuSEWSK+s93MnAU8E3gv4DT\ni5WVdAJwPnAEsBvQI8u1ufgp3m+HA7vg/fO3pO4sfOS7bVJ+HrA0MSh/Bo5KnvNQYFyB9wvKSBiM\npktqlNETeBeY0YD8cuBaM1tlZk/hL/7dc8ga8P/MbIWZvQA8AfwAwMxeMLMJSfod4H78RZNJpuH6\nCfAnMxudXDPZzDKd6jeY2SwzWwAMAzJHH/WZbmZ/N7M6M1vWwPOmuMPMPkrkH8zT/n1Ab0mtknzf\npAxgBfANYDdzxpjZkgLv3xCDzeyD5PeyysyeMLMpAGY2EngOOCzP9X81szlmNh94nPz9l0v2ZOD2\nRI+vgKuL0P9U4PrEb/YFcAVpw7sCaE+630ab2ZdJXR3QWVLL5Pf/XhH3DMpEGIymyz34P+aZwN0F\nyM+r54j+Etg4h+x8M1uakZ+Cf0tE0oGShifTNQuAc/GXQi62Bz7KUz+rQJ0A1mb11sxC2jezj4CJ\nQC9JrfGR0L1J9RDgaeB+SZ9K+qNK50NZ7ZkknZBM3cyTNB//QpCvf4vpv1yy29bTYxqF+6G2xf8+\nUkwBWiZTdncCzwIPJlNdv5e0gZktxg3yL4CZkoZmjGCDChIGo4liZlNx5/dxwL9K3PzmyUszxQ6k\nRzD/BzwKbGdmmwG3kn65ZJvGmIZPVZSC+u1/ga8+SrHNOrZ/P26E+wATzGwygJmtNLNrzWwv4BCg\nF1CqZaVfP1MyuvknPt22hZltDjxD+RcRfAZ0yMjvQOFTUjOAHTPyOwLLkpHMCjO7xsw6At2Ak4DT\nAMzsaTPriU9/foT/HQUVJgxG0+bHQI9kGqGU
CLhaUnNJh+G+igeTuo3xEcgKSV1Zfd5/Dj7VkGkg\nbgN+KWk/AEm7SNq+RHqOBQ6XtH3ij/n1OrZ3P+5P+Tnp0QWSukvqlPiLluBTLcUuGy7kpd8SaA7M\nBSzxLRxV5H3WhgeBn0jaTVIb4L+LuPY+4NJkwUFb4LckfSfpSEl7SRIZ/SZp62Qk1RpYiRv+VaV8\noGDtCIPR9Pj6m5+ZfZzyDdSvK6adLHwGzMe/PQ4BzjWzSUndecC1khbiL5YHMvT5Cv92/LKkzyV1\nNbOHkrJ7JS0CHsEdoMXqu+YDmD2b3H8c8AbuA1lNpMj2ZgKvAgeR8Vz4t+CHgIXABGAE3i+pTY9/\nL6T5hsrMbCFwCT6Cm4d/I6//TA21WbSsmT0O/AN4AXgfeCmpyuUnymxrEN5XL+ILEhYCFyd12+Kj\n34XAeOA/uDFpBvwK//uaAxyMO92DCqNyHqAkqSX+R9YC3yT4kJmt4TCTdCM+dfIFcGay8iQIgiok\nWZX1lpm1rLQuQeNS1hFGsvLkSDPrgq+4OC6ZpvgaSccBu5jZt3AH6f+WU6cgCIpH0onJFGQ74I/4\nKCdYzyj7lFTGMrmW+Cij/pCmD8kqHjMbBWwqaaty6xUEQVGcj/tOPsBXUP2isuoElaDssaQSR+Bb\nuKPzFjN7o57Idqy+ZG96UjaLIAiqgmTFUrCeU3aDkazt7yJpE+BRSR3NbGKx7UiqdBykIAiCmsTM\nSrL0utFWSZnZInz1yLH1qqbjm7dSdEjKsrVBnz7GI48YZtX5GTBgQMV1CD1Dz1rVMfQs/aeUlNVg\nSGqfikeUrKnuCdTf4j+UZJOTpIOABWaWczqqVStYujRXbRAEQVAuyj0ltQ1wV+LH2AB4wMyelHQu\nYGY2MMl/R9KH+LLas/I1GAYjCIKgMpTVYJjZeGC/LOW31ssXvOKi2g1G9+7dK61CQYSepaUW9KwF\nHSH0rGbKunGvlEgyM+Oii2DnneHiixu+JgiCYH1HElZrTu9SseWW0DL2lwZBEDQ65Q4N0gHflLcV\nHoxtkJndWE9mEzwU9w54DJk/m9mdWdqyWhkNBUEQVAulHGGU22BsDWxtZmMlbYxv4OtjGYehSOqP\nn8DWX1J7PLjZVma2sl5bYTCCIAiKpGampMxspiWBBM1PIHsX38W9mhh+NCjJz3n1jUUQBEFQecq+\n0zuFpJ3wAISj6lXdDAyVNAM/S+GUxtIpCIIgKJxGcXon01EPARfZmmcdHwOMMbNtgS7ALYl8EARB\nUEU0RvDBDXFjMcTMHssichbwB/BzkyV9DOwBvFlf8KqrrmLFCli+HI4/vvt6uQ46CIIgHyNHjmTk\nyJFlabvs+zAk3Q3MNbNLc9TfAsw2s6uTsOZvAvuY2ef15MzMGDYMbr0VHn+8rGoHQRA0CUrp9C7r\nCEPSofih7uMljcEd3FfgB8GbmQ3Ez/i9U9K45LLL6huLTNq0gS+/zFUbBEEQlItyhwZ5Gd9bkU/m\nM9yPURBhMIIgCCpDze30DoMRBEFQGcJgBEEQBAVRcwajbVto167SWgRBEKx/VDyWVCLXHfgr0ByY\nY2ZHZpGJ0CBBEARF0tRiSW0KvAIcbWbTJbU3s7lZ2gqDEQRBUCRNLZbUqcDDZjY9kVvDWARBEASV\np9F8GHliSe0GtJM0QtIbkk5vLJ2CIAiCwmmU4IMNxJLaED/GtQewEfCqpFfN7MP67Vx11VVfp7t3\nj9AgQRAE9an10CAbAo8DT5nZDVnqLwdamdnVSf62RPbhenJf+zBmzoRNNvEltkEQBEFuasaHkTAY\nmJjNWCQ8BnST1ExSG+BA3NeRk3794IUXSqxlEARBkJeKx5Iys/ckPQ2MA1YBA81sYr5227WDz3NG\nmwqCIAjKQcVjSSVy1wPXF9puGIwgCILGp+Z2ekMYjCAIgkoQBiMIgiAoiLIaDEkdJA2XNEHSeEkX\n5pE9QNIK
SSc11G6HDtCqVWl1DYIgCPJT8dAgidwGwDPAV8BgM/tXlrYiNEgQBEGR1Myy2gJDgwBc\ngG/sm11OfYIgCIK1p+K
"text/plain": [
"<matplotlib.figure.Figure at 0x2ec0f8674e0>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXe4FdXVh98fCiKiYO+9i7GLGpOI5LPEEqPG2BJrbDGx\npKhoEvQzRjTG8lmiRCyxEWKLaCyJiC0qShGUIkoTRERUBBG9wPr+WHM4c88959y5l3PPPRfW+zzz\nzMyePXuv2ffcWbP3XmttmRlBEARB0BjtWluAIAiCoG0QCiMIgiDIRCiMIAiCIBOhMIIgCIJMhMII\ngiAIMhEKIwiCIMhEKIwaQ9JfJF3S1LyS9pH0fstKt7jeiZJ6VqOutoakRZI2q/a9S0q67qb8BouU\nM0fSJpWULagdQmFUCUmTJM2XtFpB+vDkn3UjADM7y8yuyFJmkbzNcqqRtHEiQ/weWOL2WBLHppL3\nShos6UtJn0v6SNJDktZegrpK1p31NyjpOUmn1CvEbGUzm1RBuXJ1TZI0L3n+Ocn+/ypdT1CeeEFU\nDwMmAsfmEiRtD6zIkr1kKoESGdTiFUnLtXQdFaK5f5MlacNy9xrwMzNbBdgK6ApcV7SQ5im6Fv/b\nLyEGHGxmqyRKaRUzO6dYxmK/sab+7trQ77SqhMKoLvcAJ6bOTwTuTmeQdKek/02O95H0vqRfSpoh\naZqkk4rlzSepl6SZkiZIOi514SBJwyTNljRZUu/Ufc8n+8+SL7c9kntOkzQ6SXtL0k6pe3aW9Kak\nTyU9IKlDsQeWdKKklyRdK2km0FtSb0n3pPLU+6JPvlz/N7nvc0lPFfbMUveOlnRQ6ny55At8J0kr\nSLpH0seJnK9JWrNYOVmRtLuk/yblTZN0o6TlC7IdLOm9RI6rC+4/JZF5lqQncz3LrNUDmNlnwEPA\n9kmZd0q6RdITkuYAPSR1kHRN8reenlxfISXHbyR9IGmqpJNJKcjC35Wkw5Ke8GxJ4yXtL+kPwLeB\nm9Jf+6o/tLWKpL8l7TBRqWGu5HfxoqQ/Sfokaa8Dszx/g8T6v7GP8d9YsTRJ+q28t/KhpLskrZKU\nkfsNniJpMvBs1j/KskQojOryKrCypK2Tl+PRwL2U/7pbB1gZWA/4KXCzpC5l8q6W5D0J6Ctpy+Ta\nXOAnZtYFOBg4U9L3k2vfSfarJF9ur0k6Cvg98OPkq/b7wKxUXUcB+wObAjsm9ZViD+BdYG0gN9RR\n+AVfeH4srlDXBFYAfl2i7PuB41LnBwIzzWxEcv8qwPp4u5wJfFlGziwsBM5LytsL6An8rCDPD4Bd\nku0wJcM2kg4DLkqurwm8CDzQVAEkrQEcCQxLJR8LXG5mKwMvA1cBWwA7JPv18b8nyYv5l8B3gS2B\n/ylTV3f8o+ZXyW/nO8AkM/ttIv/PC77203/Hm/Df7iZAD+CERDnl6A6MAVYH/gT0a0o7FJD7ja1F\n/jdWmHYycAKwD7BZIttNBeV8B9gGOGAJZFlqCYVRfXK9jP3wf5YPGsn/Nf4iWGhmT+Iv/q1L5DXg\nd2ZWZ2YvAE8APwIwsxfM7O3k+C2gP/6PkyatuE4FrjazYck9E8wsPal+g5nNSL52BwLp3kch08zs\nFjNbZGZfNfK8Oe40s/eS/APKlP8A8H1JHZPzY8m/hOvwl9FW5gw3s7kZ6y+KmQ0zsyFJeVOAvjRs\nxz5mNtvMpgLXkx+GPAO40szeMbNFQB9gJ0kbZqz+RkmfAMPx382vUtf+aWavJjJ+BZwGnJ/I8UVS\nV06Oo/D2HWNmXwKXlqnzFKCfmQ1Kyp5uZu+UyS9YPCx2NHCRmc0zs8nAn4GfpPJONrM7zAPa3Q2s\nI2mtMmU/mvRGPk32p6auFfuNFaYdB1xrZpPNbB7QCzhG+SE8A3qb2ZdN+J0uUxR2pYOW517gBfzL\n/G8Z8s9KXi455gGdS+T91Mzmp84n470N5MNM
V+LDGB2S7R9l6t0QeK/M9RkFMq1bJm9zrLc+LCi/\n6DOb2XuSRgOHSnoc7wn9Prl8D7AB0D/pld0LXGJmC5shDwBJj+1aYDd8/ml5YGhBtqmp48V/A2Bj\n4AZJf84Vh7+k1idbG/3CzO4ocW3x/cmwWydgqLT4G6Ad+Q+C9YA3CmQs1cvdEP/waCpr4G0zpaCe\n9VPni//GZvalXNjOwEclyjzMzJ4rca1Y+xWmrZfIkJZnebznm2MqQUmih1Flkq/SicD3gIcrXPyq\nklZMnW9EvgdzH/AosL6ZdQVuI/+SKDbB+z6weYXkKiz/C/yFlqOcsslCf/zr8TDgbTObAGBmC8zs\ncjPrBnwTOBQfklgS/oL3DDdP2vESGr5s0z2Gjcn/Dd4HzjCz1ZJtVTPrnOsZLCHpNv4YV7LdUnV1\nTYaUAKYXkbHUJH+530E5w4CP8R7exgX1TCtzT2M0ZhTQWNoHReSpo/7HT2sboNQ0oTBah1OAnslw\nQCURcJmk9pK+jc9VDEiudcZ7IHXJuHR63H8msIj6L4bbgV9L2gVA0uZNGDppjBHAdyRtmHz5X7SE\n5fXH51POwuc0AJDUQ9L2yZDDXPzlsKh4EQ0Q0DGZOM9twse9PzezeZK2Seos5DeSuibtdU4iH8Ct\nwMWStkvk6yLph01/3PIkQzx/Ba5PehtIWl/S/kmWAcBJkraV1Il8j6wY/YCTJe2bTBqvJyk3JDoD\nnwsoJsOipJ4rJHWWtDFwPt7ray0eAM6XtImkzvi8Rv9UD77WLcVanVAY1SNt5z4xNzdQeK0p5RRh\nOvAp/iV1D/41Oz659jPgckmzgd8Cf0/J8yX+z/NyMjbc3cweTNLul/Q58Ag+0dtUeRs+gNl/kvpH\nAq/jcyD1sjSxvA+BV4A9ST0XbgTwIDAbeBt4juSFJXdOu6VcscAc/Ev9y2S/Lz5vcHzSJreRVwbp\n+/6JD1MNS57tjkTOR/G5hP6SPsOf/8CCe8vJ05RrF+ITvq8mdT2Dm+NiZk/hcyuDgHcoYxFkZq/j\nk8XX4+04GO+5AtwAHCW3+Lq+iCzn4O02AR+GvdfM7mzic6QZmFhk5baHGslfyB343/8FfLh1XiJj\n1vqXedTSCyglFhnX48qpn5ldVXC9K/6H3Bz/xzzFzEa3qFBBEARBk2nRHkYyFHATbqLWDTg26can\nuRgYbmY74tZD4b0ZBEFQg7T0kFR3YHxixlaHd98PK8izHd41xszGAZtoCZ2rgiAIgspTVmHIvWZL\nmbFlodBccCr1zeoA3gSOSOrrjo+PbrAEdQZBEAQtQFk/DDNbmLjLdzGz2S0kQx/cNn0YMAp3Smpg\nJy8pJqSCIAiagZlVxAIsy5DUXGCUpH6S/i+3ZSx/GnmLCvCeQz07bDObY2anmNkuZnYi7sY/oVhh\nZlbzW+/evVtdhpAz5GyrMoacld8qSRZP74dpvoPZ68AWiQ32dOAYUtFawW3RgXnm/gGnAc/bEoZv\nCIIgCCpPowrDzO6WRyLdKkkaZz6B3SjmQ1o/x23Ac2a1YySd4ZetL7AtcLekRbit/KmlSwQuuwzW\nXRdOPz2LCEEQBEGFaFRhSOqBBwabhHtCbijpRPPgdo1i7iS0dUHabanjVwuvl6VLFxg1KnP2atOj\nR4/WFiETIWdlaQtytgUZIeSsZRp13JM0FDjO3OQVSVsBD5jZrlWQLy2HmRk8+ij06wcDC52DgyAI\ngkIkYVWc9G6fUxYA5qGN21ei8max0Ubw+OOwsNkBR4MgCIJmkEVhvCHp9iSQWw9Jf6V+aOTqsk3i\nKD46oocEQRBUkyxWUmcBZ5MP0vUiUC5oW8vSqRP06QPLx1IeQRAE1aTsHIZ8IfS/mdnx1ROppCxW\naZviIAiCpZ2qzWGYr0y2cWJWGwRBECzDZBnXmYCvk/AYvlIaAGZ2bYtJFQRBENQcWRTGe8nWDl9t\nLAiCIFgG
KaswkjmMlc3s11WSJwiCIKhRssxh7F0lWZrGU0+F814QBEEVyeLp/Rd8DYt/UH8Oo7kB\nCZtFAyupww93r++vv4b
"text/plain": [
"<matplotlib.figure.Figure at 0x2ec2b341b00>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pred_resnet = train_and_evaluate(reader_train, reader_test, max_epochs=5, model_func=create_resnet_model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}