{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nbpresent": {
     "id": "29b9bd1d-766f-4422-ad96-de0accc1ce58"
    }
   },
   "source": [
    "# CNTK 103: Part B - Logistic Regression with MNIST\n",
    "\n",
    "We assume that you have successfully completed CNTK 103 Part A.\n",
    "\n",
    "In this tutorial we will build and train a Multinomial Logistic Regression model using the MNIST data. This notebook provides the recipe using the Python API. If you are looking for this example in BrainScript, please look [here](https://github.com/Microsoft/CNTK/tree/v2.0/Examples/Image/GettingStarted).\n",
    "\n",
    "## Introduction\n",
    "\n",
    "**Problem**:\n",
    "Optical Character Recognition (OCR) is an active research area, and there is great demand for automation. The MNIST data consists of handwritten digits with little background noise, which makes it a convenient dataset for creating, experimenting with, and learning deep learning models with modest computing resources."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"http://3.bp.blogspot.com/_UpN7DfJA0j4/TJtUBWPk0SI/AAAAAAAAABY/oWPMtmqJn3k/s1600/mnist_originals.png\" width=\"200\" height=\"200\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Figure 1\n",
    "Image(url=\"http://3.bp.blogspot.com/_UpN7DfJA0j4/TJtUBWPk0SI/AAAAAAAAABY/oWPMtmqJn3k/s1600/mnist_originals.png\", width=200, height=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Goal**:\n",
    "Our goal is to train a classifier that will identify the digits in the MNIST dataset.\n",
    "\n",
    "**Approach**:\n",
    "The same five stages used in the previous tutorial apply: data reading, data preprocessing, creating a model, learning the model parameters, and evaluating (a.k.a. testing/prediction) the model.\n",
    "- Data reading: We will use the CNTK text reader.\n",
    "- Data preprocessing: Covered in Part A (suggested extension section).\n",
    "\n",
    "The remaining steps are identical to CNTK 102."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logistic Regression\n",
    "[Logistic Regression](https://en.wikipedia.org/wiki/Logistic_regression) (LR) is a fundamental machine learning technique that uses a linear weighted combination of features and generates probability-based predictions of different classes.\n",
    "\n",
    "There are two basic forms of LR: **Binary LR** (with a single output that can predict two classes) and **multinomial LR** (with multiple outputs, each of which is used to predict a single class).\n",
    "\n",
    "![LR-forms](http://www.cntk.ai/jup/cntk103b_TwoFormsOfLR-v3.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In **Binary Logistic Regression** (see top of figure above), the input features are each scaled by an associated weight and summed together. The sum is passed through a squashing (a.k.a. activation) function that generates an output in [0,1]. This output value (which can be thought of as a probability) is then compared with a threshold (such as 0.5) to produce a binary label (0 or 1). This technique supports only classification problems with two output classes, hence the name binary LR. In the binary LR example shown above, the [sigmoid][] function is used as the squashing function.\n",
    "\n",
    "[sigmoid]: https://en.wikipedia.org/wiki/Sigmoid_function"
   ]
  },
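  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The binary LR computation described above can be sketched in a few lines of NumPy. This is a minimal illustration under our own assumptions, not CNTK code: the helper names `sigmoid` and `binary_lr_predict` and all numeric values are hypothetical, and we add a bias term `b` to the weighted sum (as the multinomial model later in this tutorial does). With the example values the weighted sum is 0.88, so the sigmoid output is about 0.707 and the predicted label is 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def sigmoid(t):\n",
    "    # Squashing function: maps any real number into (0, 1)\n",
    "    return 1.0 / (1.0 + np.exp(-t))\n",
    "\n",
    "def binary_lr_predict(x, w, b, threshold=0.5):\n",
    "    # Scale each input feature by its weight, sum, and add the bias\n",
    "    s = np.dot(w, x) + b\n",
    "    # Squash the sum into a probability-like value\n",
    "    p = sigmoid(s)\n",
    "    # Compare with the threshold to obtain a binary label (0 or 1)\n",
    "    return p, int(p >= threshold)\n",
    "\n",
    "# Hypothetical example with 3 features\n",
    "demo_x = np.array([0.5, -1.2, 2.0])\n",
    "demo_w = np.array([0.8, 0.1, 0.4])\n",
    "demo_b = -0.2\n",
    "demo_p, demo_label = binary_lr_predict(demo_x, demo_w, demo_b)\n",
    "print(demo_p, demo_label)"
   ]
  },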
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In **Multinomial Logistic Regression** (see bottom of figure above), 2 or more output nodes are used, one for each output class to be predicted. Each summation node uses its own set of weights to scale the input features and sum them together. Instead of passing the summed output of the weighted input features through a sigmoid squashing function, the output is often passed through a [softmax][] function. Like the sigmoid, the softmax squashes each value into [0,1]; in addition, it normalizes across the nodes: each node's exponentiated output is divided by the sum of the exponentiated outputs of all nodes, so the outputs sum to 1. (Details in the context of MNIST images follow.)\n",
    "\n",
    "In this tutorial, we will use multinomial LR to classify the MNIST digits (0-9) using 10 output nodes (1 for each of our output classes).\n",
    "\n",
    "[softmax]: https://en.wikipedia.org/wiki/Softmax_function"
   ]
  },
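  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of the softmax normalization described above follows. It is our own illustration, not CNTK's `C.softmax`; the evidence values are hypothetical. Subtracting the maximum before exponentiating is a common numerical-stability trick that leaves the result unchanged. The probabilities sum to 1 (up to floating-point rounding), and the largest evidence (3.0, at index 4) receives the largest probability, so the predicted digit is 4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def softmax(v):\n",
    "    # Subtracting the max does not change the result but avoids overflow in exp\n",
    "    e = np.exp(v - np.max(v))\n",
    "    # Divide each exponentiated value by the sum over all nodes\n",
    "    return e / np.sum(e)\n",
    "\n",
    "# Hypothetical outputs of the 10 summation nodes (one per digit 0-9)\n",
    "demo_evidence = np.array([1.0, 2.0, 0.5, 0.1, 3.0, 0.0, -1.0, 0.3, 0.2, 1.5])\n",
    "demo_probs = softmax(demo_evidence)\n",
    "print(np.round(demo_probs, 3))\n",
    "print('sum of probabilities:', demo_probs.sum())\n",
    "print('predicted digit:', int(np.argmax(demo_probs)))"
   ]
  },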
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true,
    "nbpresent": {
     "id": "138d1a78-02e2-4bd6-a20e-07b83f303563"
    }
   },
   "outputs": [],
   "source": [
    "# Import the relevant components\n",
    "from __future__ import print_function # Use a function definition from future version (say 3.x from 2.7 interpreter)\n",
    "import matplotlib.image as mpimg\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import sys\n",
    "import os\n",
    "\n",
    "import cntk as C\n",
    "import cntk.tests.test_utils\n",
    "cntk.tests.test_utils.set_device_from_pytest_env() # (only needed for our build system)\n",
    "C.cntk_py.set_fixed_random_seed(1) # fix the random seed so that LR examples are repeatable\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Define the data dimensions\n",
    "input_dim = 784\n",
    "num_output_classes = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data reading\n",
    "\n",
    "In this tutorial we are using the MNIST data you have downloaded using the CNTK_103A_MNIST_DataLoader notebook. The dataset has 60,000 training images and 10,000 test images, with each image being 28 x 28 pixels. Thus the number of features is equal to 784 (= 28 x 28 pixels), 1 per pixel. The variable `num_output_classes` is set to 10, corresponding to the number of digits (0-9) in the dataset.\n",
    "\n",
    "The data is in the following format:\n",
    "\n",
    "    |labels 0 0 0 1 0 0 0 0 0 0 |features 0 0 0 0 ...\n",
    "                                (784 integers each representing a pixel)\n",
    "\n",
    "In this tutorial we are going to use the image pixels corresponding to the integer stream named \"features\". We define a `create_reader` function to read the training and test data using the [CTF deserializer](https://cntk.ai/pythondocs/cntk.io.html?highlight=ctfdeserializer#cntk.io.CTFDeserializer). The labels are [1-hot encoded](https://en.wikipedia.org/wiki/One-hot). Refer to the CNTK 103A tutorial for data format visualizations."
   ]
  },
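  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the text format above concrete, here is a hypothetical helper, `parse_ctf_line`, that decodes one line of the simple dense layout shown. It is only an illustration and is not part of CNTK: it ignores other features of the full CTF format (for example sequence IDs and sparse `index:value` entries), and the example line is shortened. The 1-hot label in the example decodes to digit 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def parse_ctf_line(line):\n",
    "    # Each '|'-separated chunk starts with a stream name followed by its values\n",
    "    streams = {}\n",
    "    for chunk in line.strip().split('|'):\n",
    "        tokens = chunk.split()\n",
    "        if not tokens:\n",
    "            continue\n",
    "        streams[tokens[0]] = np.array([float(v) for v in tokens[1:]])\n",
    "    return streams\n",
    "\n",
    "# Shortened, hypothetical line: a real MNIST line carries 784 feature values\n",
    "demo_line = '|labels 0 0 0 1 0 0 0 0 0 0 |features 0 0 255 128'\n",
    "demo_streams = parse_ctf_line(demo_line)\n",
    "print(demo_streams['labels'])\n",
    "# The digit is the position of the single 1 in the 1-hot label vector\n",
    "print('digit:', int(np.argmax(demo_streams['labels'])))\n",
    "print(demo_streams['features'])"
   ]
  },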
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Read a CTF formatted text (as mentioned above) using the CTF deserializer from a file\n",
    "def create_reader(path, is_training, input_dim, num_label_classes):\n",
    "    # Each stream maps a field name in the CTF file to a dense vector of the given shape\n",
    "    labelStream = C.io.StreamDef(field='labels', shape=num_label_classes, is_sparse=False)\n",
    "    featureStream = C.io.StreamDef(field='features', shape=input_dim, is_sparse=False)\n",
    "\n",
    "    deserializer = C.io.CTFDeserializer(path, C.io.StreamDefs(labels = labelStream, features = featureStream))\n",
    "\n",
    "    # Training data: randomized and repeated indefinitely; test data: read once\n",
    "    return C.io.MinibatchSource(deserializer,\n",
    "        randomize = is_training, max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data directory is ..\\Examples\\Image\\DataSets\\MNIST\n"
     ]
    }
   ],
   "source": [
    "# Ensure the training and test data is generated and available for this tutorial.\n",
    "# We search in two locations in the toolkit for the cached MNIST data set.\n",
    "data_found = False\n",
    "\n",
    "for data_dir in [os.path.join(\"..\", \"Examples\", \"Image\", \"DataSets\", \"MNIST\"),\n",
    "                 os.path.join(\"data\", \"MNIST\")]:\n",
    "    train_file = os.path.join(data_dir, \"Train-28x28_cntk_text.txt\")\n",
    "    test_file = os.path.join(data_dir, \"Test-28x28_cntk_text.txt\")\n",
    "    if os.path.isfile(train_file) and os.path.isfile(test_file):\n",
    "        data_found = True\n",
    "        break\n",
    "\n",
    "if not data_found:\n",
    "    raise ValueError(\"Please generate the data by completing CNTK 103 Part A\")\n",
    "\n",
    "print(\"Data directory is {0}\".format(data_dir))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Creation\n",
    "\n",
    "A logistic regression (LR) network is a simple building block that has been effectively powering many ML applications in the past decade. The figure below summarizes the model in the context of the MNIST data.\n",
    "\n",
    "![mnist-LR](https://www.cntk.ai/jup/cntk103b_MNIST_LR.png)\n",
    "\n",
    "LR is a simple linear model that takes as input a vector of numbers describing the properties of what we are classifying (also known as a feature vector, $\\bf \\vec{x}$, the pixels in the input MNIST digit image) and emits the *evidence* ($\\vec{z}$). For each of the 10 digits, there is a vector of weights corresponding to the input pixels, as shown in the figure. These 10 weight vectors define the weight matrix ($\\bf {W}$) with dimension 10 x 784. Each feature in the input layer is connected to a summation node by a corresponding weight $w$ (an individual weight value from the $\\bf{W}$ matrix). Note that there are 10 such nodes, 1 corresponding to each digit to be classified."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first step is to compute the evidence for an observation.\n",
    "\n",
    "$$\\vec{z} = \\mathbf{W} \\vec{x}^T + \\vec{b}$$\n",
    "\n",
    "where $\\bf{W}$ is the weight matrix of dimension 10 x 784, $\\vec{x}$ is the (row) feature vector of length 784, and $\\vec{b}$ is known as the *bias* vector with length 10, one for each digit.\n",
    "\n",
    "The evidence ($\\vec{z}$) is not squashed (hence no activation). Instead the output is normalized using a [softmax](https://en.wikipedia.org/wiki/Softmax_function) function such that all the outputs add up to a value of 1, thus lending a probabilistic interpretation to the prediction. In CNTK, we use the softmax operation combined with the cross-entropy error function (`cross_entropy_with_softmax`)."
   ]
  },
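  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evidence equation can be checked numerically. The NumPy sketch below uses random, hypothetical values for W, x and b (the `demo_` names are ours, chosen so no notebook variables are overwritten) and shows that a 10 x 784 weight matrix applied to a 784-pixel image yields 10 evidence values, which softmax turns into probabilities summing to 1. For a 1-D NumPy array, `np.dot(W, x)` treats `x` as a column, which matches W x^T for the row vector x. It should print `(10,)` and then `(10,) 1.0`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "demo_rng = np.random.RandomState(0)\n",
    "# Hypothetical parameters with the shapes described above\n",
    "demo_W = demo_rng.normal(scale=0.01, size=(10, 784))  # one row of 784 weights per digit\n",
    "demo_b = np.zeros(10)                                 # one bias per digit\n",
    "demo_x = demo_rng.rand(784)                           # a flattened 28 x 28 image scaled to [0, 1)\n",
    "\n",
    "# Evidence: z = W x^T + b, one unnormalized score per digit\n",
    "demo_z = np.dot(demo_W, demo_x) + demo_b\n",
    "print(demo_z.shape)\n",
    "\n",
    "# Softmax turns the evidence into probabilities that sum to 1\n",
    "demo_p = np.exp(demo_z - demo_z.max())\n",
    "demo_p /= demo_p.sum()\n",
    "print(demo_p.shape, round(float(demo_p.sum()), 6))"
   ]
  },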
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Network input and output:\n",
    "- **input** variable (a key CNTK concept):\n",
    ">An **input** variable is a container into which we feed different observations (in this case, image pixels) during model learning (a.k.a. training) and model evaluation (a.k.a. testing). Thus, the shape of the `input` must match the shape of the data that will be provided. For example, when data are images each of height 10 pixels and width 5 pixels, the input feature dimension will be 50 (representing the total number of image pixels). More on data and their dimensions will appear in separate tutorials.\n",
    "\n",
    "\n",
    "**Question:** What is the input dimension of your chosen model? This is fundamental to our understanding of variables in a network or model representation in CNTK."
   ]
  },
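  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to see where the input dimension comes from: flattening an image into a vector gives one feature per pixel. This NumPy sketch uses hypothetical all-zero images; it should print `(50,)` for the 10 x 5 example above and `(784,)` for a 28 x 28 MNIST digit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# A hypothetical 10 x 5 (height x width) image flattens to 50 input features\n",
    "demo_small_image = np.zeros((10, 5))\n",
    "print(demo_small_image.reshape(-1).shape)\n",
    "\n",
    "# An MNIST digit is 28 x 28, so it flattens to 784 features (our input_dim)\n",
    "demo_digit = np.zeros((28, 28))\n",
    "print(demo_digit.reshape(-1).shape)"
   ]
  },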
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "input = C.input_variable(input_dim)\n",
    "label = C.input_variable(num_output_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logistic Regression network setup\n",
    "\n",
    "The CNTK Layers module provides a `Dense` function that creates a fully connected layer, which performs the weighted summation of the inputs and the bias addition described above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def create_model(features):\n",
    "    # Use the Glorot uniform initializer for the layer weights\n",
    "    with C.layers.default_options(init = C.glorot_uniform()):\n",
    "        # A single dense layer with no activation: it outputs the evidence z\n",
    "        r = C.layers.Dense(num_output_classes, activation = None)(features)\n",
    "        return r"
   ]
  },
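  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, `Dense(num_output_classes, activation=None)` holds a weight matrix and a bias vector and computes an affine map of its input. The NumPy sketch below mirrors that computation under stated assumptions; it is not CNTK's implementation. We assume the textbook Glorot-uniform range of plus or minus sqrt(6 / (fan_in + fan_out)) (CNTK's `glorot_uniform` has further options not modeled here) and a zero-initialized bias. Here the weight is stored as 784 x 10 and applied as `x W + b`, the transposed form of `W x^T` above, giving the same 10 evidence values. The helper names are hypothetical. It should print `(784, 10) (10,)` and `True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def glorot_uniform_sketch(fan_in, fan_out, rng):\n",
    "    # Textbook Glorot/Xavier uniform range: +/- sqrt(6 / (fan_in + fan_out))\n",
    "    limit = np.sqrt(6.0 / (fan_in + fan_out))\n",
    "    return rng.uniform(-limit, limit, size=(fan_in, fan_out))\n",
    "\n",
    "def dense_sketch(x, W, b):\n",
    "    # Affine map with no activation: one evidence value per output class\n",
    "    return np.dot(x, W) + b\n",
    "\n",
    "demo_rng = np.random.RandomState(1)\n",
    "demo_W = glorot_uniform_sketch(784, 10, demo_rng)     # shape (784, 10)\n",
    "demo_b = np.zeros(10)                                 # assumed zero initial bias\n",
    "demo_x = demo_rng.randint(0, 256, size=784) / 255.0  # hypothetical pixels scaled to [0, 1]\n",
    "demo_z = dense_sketch(demo_x, demo_W, demo_b)\n",
    "print(demo_W.shape, demo_z.shape)\n",
    "print(np.abs(demo_W).max() <= np.sqrt(6.0 / (784 + 10)))"
   ]
  },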
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`z` will be used to represent the output of the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Scale the input to 0-1 range by dividing each pixel by 255.\n",
    "z = create_model(input/255.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Learning model parameters\n",
    "\n",
    "As in the previous tutorial, we use the `softmax` function to map the accumulated evidences or activations to a probability distribution over the classes (see details of the [softmax function][] and other [activation][] functions).\n",
    "\n",
    "[softmax function]: http://cntk.ai/pythondocs/cntk.ops.html#cntk.ops.softmax\n",
    "\n",
    "[activation]: https://docs.microsoft.com/en-us/cognitive-toolkit/Brainscript-Activation-Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "As in CNTK 102, we minimize the cross-entropy between the label and the probability predicted by the network. If this terminology sounds unfamiliar, please refer to CNTK 102 for a refresher."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "loss = C.cross_entropy_with_softmax(z, label)"
   ]
  },
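  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fusing softmax with cross-entropy lets the loss be computed directly from the evidence in a numerically stable way, via the log-sum-exp trick. The NumPy sketch below shows the math for a single observation; `cross_entropy_with_softmax_sketch` is a hypothetical helper, not CNTK's implementation (which operates on whole minibatches). For evidence [2.0, 1.0, 0.1] with the first class as the label, the loss is about 0.417; equal evidence for all 10 digits gives log(10), about 2.303."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def cross_entropy_with_softmax_sketch(evidence, one_hot_label):\n",
    "    # Stable log(sum(exp(evidence))): factor out the largest value first\n",
    "    m = np.max(evidence)\n",
    "    log_sum_exp = m + np.log(np.sum(np.exp(evidence - m)))\n",
    "    # Log of the softmax probability of every class\n",
    "    log_probs = evidence - log_sum_exp\n",
    "    # Cross-entropy: minus the log-probability assigned to the true class\n",
    "    return -np.sum(one_hot_label * log_probs)\n",
    "\n",
    "demo_evidence = np.array([2.0, 1.0, 0.1])\n",
    "demo_label = np.array([1.0, 0.0, 0.0])\n",
    "print(cross_entropy_with_softmax_sketch(demo_evidence, demo_label))\n",
    "\n",
    "# Equal evidence for all 10 digits gives a loss of log(10)\n",
    "print(cross_entropy_with_softmax_sketch(np.zeros(10), np.eye(10)[3]), np.log(10))"
   ]
  },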
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluation\n",
    "\n",
    "To evaluate the classification, we compare the output of the network, which for each observation emits a vector of evidences (convertible into probabilities using the `softmax` function) with dimension equal to the number of classes, against the 1-hot encoded label. An observation counts as misclassified when the class with the largest evidence differs from the labeled class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "label_error = C.classification_error(z, label)"
   ]
  },
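  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because softmax preserves the ordering of the evidence, the predicted class can be read off the evidence directly, without computing probabilities. The NumPy sketch below computes a minibatch error rate that way; it is our own illustration with hypothetical data, not CNTK's `classification_error` code. Only the second observation is misclassified, so it should print an error rate of one third (about 0.333)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def classification_error_sketch(evidence, one_hot_labels):\n",
    "    # Predicted class: index of the largest evidence in each row\n",
    "    predicted = np.argmax(evidence, axis=1)\n",
    "    # True class: index of the 1 in each 1-hot label row\n",
    "    actual = np.argmax(one_hot_labels, axis=1)\n",
    "    # Fraction of observations where the two disagree\n",
    "    return np.mean(predicted != actual)\n",
    "\n",
    "demo_evidence = np.array([[0.1, 2.0, 0.3],\n",
    "                          [1.5, 0.2, 0.1],\n",
    "                          [0.0, 0.4, 0.9]])\n",
    "demo_labels = np.array([[0, 1, 0],\n",
    "                        [0, 0, 1],\n",
    "                        [0, 0, 1]])\n",
    "print(classification_error_sketch(demo_evidence, demo_labels))"
   ]
  },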
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configure training\n",
    "\n",
    "The trainer strives to reduce the `loss` function by different optimization approaches, [Stochastic Gradient Descent][] (`sgd`) being one of the most popular ones. Typically, one would start with random initialization of the model parameters. The `sgd` optimizer calculates the `loss` or error between the predicted label and the corresponding ground-truth label and uses [gradient descent][] to generate a new set of model parameters in a single iteration.\n",
    "\n",
    "The aforementioned model parameter update using a single observation at a time is attractive since it does not require the entire data set (all observations) to be loaded in memory and also requires gradient computation over fewer data points, thus allowing for training on large data sets. However, the updates generated using a single observation sample at a time can vary wildly between iterations. A middle ground is to load a small set of observations and use an average of the `loss` or error from that set to update the model parameters. This subset is called a *minibatch*.\n",
    "\n",
    "With minibatches, we sample observations from the larger training dataset. We repeat the model parameter update using different combinations of training samples and, over time, minimize the `loss` (and the error). When the error rate is no longer changing significantly, or after a preset maximum number of minibatches, we consider the model trained.\n",
    "\n",
    "One of the key optimization parameters is the `learning_rate`. For now, we can think of it as a scaling factor that modulates how much we change the parameters in any iteration. We will cover more details in a later tutorial.\n",
    "With this information, we are ready to create our trainer.\n",
    "\n",
    "[optimization]: https://en.wikipedia.org/wiki/Category:Convex_optimization\n",
    "[Stochastic Gradient Descent]: https://en.wikipedia.org/wiki/Stochastic_gradient_descent\n",
    "[gradient descent]: http://www.statisticsviews.com/details/feature/5722691/Getting-to-the-Bottom-of-Regression-with-Gradient-Descent.html"
   ]
  },
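  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The minibatch update described above can be illustrated without CNTK. This NumPy sketch runs plain minibatch SGD for a multinomial LR model on a small synthetic problem; the data, sizes, learning rate of 0.5 and helper names are all hypothetical, and this is not how CNTK's `sgd` learner is implemented internally. Each step samples 32 observations, averages the gradient of the cross-entropy over them, and moves the parameters against it. With all-zero initial parameters the first printed loss is log(3), about 1.0986; later values should generally be lower, with some minibatch-to-minibatch noise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def softmax_rows(Z):\n",
    "    # Row-wise softmax with max subtraction for numerical stability\n",
    "    E = np.exp(Z - Z.max(axis=1, keepdims=True))\n",
    "    return E / E.sum(axis=1, keepdims=True)\n",
    "\n",
    "def sgd_minibatch_step(W, b, X, Y, lr):\n",
    "    # Forward pass for the whole minibatch\n",
    "    P = softmax_rows(np.dot(X, W) + b)\n",
    "    # Gradient of the average cross-entropy w.r.t. the evidence is (P - Y) / n\n",
    "    G = (P - Y) / X.shape[0]\n",
    "    # Update the parameters in place, one small step against the gradient\n",
    "    W -= lr * np.dot(X.T, G)\n",
    "    b -= lr * G.sum(axis=0)\n",
    "    # Average loss on this minibatch, measured before the update\n",
    "    return -np.mean(np.sum(Y * np.log(P + 1e-12), axis=1))\n",
    "\n",
    "# Hypothetical toy problem: 4 features, 3 classes, 256 samples\n",
    "demo_rng = np.random.RandomState(0)\n",
    "demo_X = demo_rng.normal(size=(256, 4))\n",
    "demo_Y = np.eye(3)[np.argmax(np.dot(demo_X, demo_rng.normal(size=(4, 3))), axis=1)]\n",
    "\n",
    "demo_W = np.zeros((4, 3))\n",
    "demo_b = np.zeros(3)\n",
    "for step in range(201):\n",
    "    batch = demo_rng.choice(256, size=32, replace=False)  # sample a minibatch\n",
    "    batch_loss = sgd_minibatch_step(demo_W, demo_b, demo_X[batch], demo_Y[batch], lr=0.5)\n",
    "    if step % 50 == 0:\n",
    "        print(step, round(float(batch_loss), 4))"
   ]
  },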
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Instantiate the trainer object to drive the model training\n",
    "learning_rate = 0.2\n",
    "lr_schedule = C.learning_rate_schedule(learning_rate, C.UnitType.minibatch)\n",
    "learner = C.sgd(z.parameters, lr_schedule)\n",
    "trainer = C.Trainer(z, (loss, label_error), [learner])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let us create some helper functions that will be needed to visualize different quantities associated with training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Define a utility function to compute the moving average.\n",
    "# A more efficient implementation is possible with np.cumsum() function\n",
    "def moving_average(a, w=5):\n",
    "    if len(a) < w:\n",
    "        return a[:]    # Need to send a copy of the array\n",
    "    return [val if idx < w else sum(a[(idx-w):idx])/w for idx, val in enumerate(a)]\n",
    "\n",
    "\n",
    "# Defines a utility that prints the training progress\n",
    "def print_training_progress(trainer, mb, frequency, verbose=1):\n",
    "    training_loss = \"NA\"\n",
    "    eval_error = \"NA\"\n",
    "\n",
    "    if mb%frequency == 0:\n",
    "        training_loss = trainer.previous_minibatch_loss_average\n",
    "        eval_error = trainer.previous_minibatch_evaluation_average\n",
    "        if verbose:\n",
    "            print(\"Minibatch: {0}, Loss: {1:.4f}, Error: {2:.2f}%\".format(mb, training_loss, eval_error*100))\n",
    "\n",
    "    return mb, training_loss, eval_error"
   ]
  },
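  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comment above notes that `np.cumsum()` allows a more efficient moving average. One possible version is sketched below (the name `moving_average_cumsum` is ours). It keeps the semantics of `moving_average`, where each position from `w` on is replaced by the mean of the `w` values before it, but it returns a NumPy array and computes every window sum from a single cumulative sum instead of re-summing each window. For the example input, both versions should give 1, 2, 3, 2, 3, 4, 5, 6."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def moving_average_cumsum(a, w=5):\n",
    "    # Same values as moving_average above, computed with a cumulative sum\n",
    "    a = np.asarray(a, dtype=float)\n",
    "    if len(a) < w:\n",
    "        return a.copy()\n",
    "    # c[k] holds sum(a[:k]), so sum(a[k-w:k]) == c[k] - c[k-w]\n",
    "    c = np.concatenate(([0.0], np.cumsum(a)))\n",
    "    out = a.copy()\n",
    "    k = np.arange(w, len(a))\n",
    "    out[w:] = (c[k] - c[k - w]) / w\n",
    "    return out\n",
    "\n",
    "demo_values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]\n",
    "print(moving_average_cumsum(demo_values, w=3))\n",
    "print(moving_average(demo_values, w=3))"
   ]
  },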
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='#Run the trainer'></a>\n",
    "### Run the trainer\n",
    "\n",
    "We are now ready to train our fully connected neural net. First, we decide what data we need to feed into the training engine.\n",
    "\n",
    "In this example, each iteration of the optimizer will work on a minibatch of `minibatch_size` samples. We would like to train on all 60,000 observations. Additionally, we will make multiple passes through the data, as specified by the variable `num_sweeps_to_train_with`. With these parameters we can proceed with training our simple feed-forward network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Initialize the parameters for the trainer\n",
    "minibatch_size = 64\n",
    "num_samples_per_sweep = 60000\n",
    "num_sweeps_to_train_with = 10\n",
    "num_minibatches_to_train = (num_samples_per_sweep * num_sweeps_to_train_with) / minibatch_size"
   ]
  },
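  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the arithmetic: 10 sweeps over 60,000 samples is 600,000 samples, which divides evenly into 9,375 minibatches of 64. Since progress is printed every 500 minibatches and the loop index runs from 0 to 9,374, the last progress line during training is minibatch 9000. The `demo_` names below are ours, chosen so the notebook's own variables are not overwritten; the cell should print 600000, 9375, 0 and 9000."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same values as in the cell above\n",
    "demo_minibatch_size = 64\n",
    "demo_samples_per_sweep = 60000\n",
    "demo_sweeps = 10\n",
    "\n",
    "demo_total_samples = demo_samples_per_sweep * demo_sweeps\n",
    "demo_num_minibatches = demo_total_samples // demo_minibatch_size\n",
    "print(demo_total_samples)                       # samples seen during training\n",
    "print(demo_num_minibatches)                     # full minibatches\n",
    "print(demo_total_samples % demo_minibatch_size) # leftover samples\n",
    "# Progress is printed when the minibatch index is a multiple of 500\n",
    "print(max(i for i in range(demo_num_minibatches) if i % 500 == 0))"
   ]
  },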
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Minibatch: 0, Loss: 2.2132, Error: 221.32%\n",
      "Minibatch: 500, Loss: 0.5081, Error: 50.81%\n",
      "Minibatch: 1000, Loss: 0.2309, Error: 23.09%\n",
      "Minibatch: 1500, Loss: 0.4300, Error: 43.00%\n",
      "Minibatch: 2000, Loss: 0.1918, Error: 19.18%\n",
      "Minibatch: 2500, Loss: 0.1843, Error: 18.43%\n",
      "Minibatch: 3000, Loss: 0.1117, Error: 11.17%\n",
      "Minibatch: 3500, Loss: 0.3094, Error: 30.94%\n",
      "Minibatch: 4000, Loss: 0.3554, Error: 35.54%\n",
      "Minibatch: 4500, Loss: 0.2465, Error: 24.65%\n",
      "Minibatch: 5000, Loss: 0.1988, Error: 19.88%\n",
      "Minibatch: 5500, Loss: 0.1277, Error: 12.77%\n",
      "Minibatch: 6000, Loss: 0.1448, Error: 14.48%\n",
      "Minibatch: 6500, Loss: 0.2789, Error: 27.89%\n",
      "Minibatch: 7000, Loss: 0.1692, Error: 16.92%\n",
      "Minibatch: 7500, Loss: 0.3069, Error: 30.69%\n",
      "Minibatch: 8000, Loss: 0.1194, Error: 11.94%\n",
      "Minibatch: 8500, Loss: 0.1464, Error: 14.64%\n",
      "Minibatch: 9000, Loss: 0.1096, Error: 10.96%\n"
     ]
    }
   ],
   "source": [
    "# Create the reader to training data set\n",
    "reader_train = create_reader(train_file, True, input_dim, num_output_classes)\n",
    "\n",
    "# Map the data streams to the input and labels.\n",
    "input_map = {\n",
    "    label  : reader_train.streams.labels,\n",
    "    input  : reader_train.streams.features\n",
    "}\n",
    "\n",
    "# Run the trainer and perform model training\n",
    "training_progress_output_freq = 500\n",
    "\n",
    "plotdata = {\"batchsize\":[], \"loss\":[], \"error\":[]}\n",
    "\n",
    "for i in range(0, int(num_minibatches_to_train)):\n",
    "\n",
    "    # Read a mini batch from the training data file\n",
    "    data = reader_train.next_minibatch(minibatch_size, input_map = input_map)\n",
    "\n",
    "    trainer.train_minibatch(data)\n",
    "    # Use names that do not shadow the `loss` function defined earlier\n",
    "    batchsize, mb_loss, mb_error = print_training_progress(trainer, i, training_progress_output_freq, verbose=1)\n",
    "\n",
    "    if not (mb_loss == \"NA\" or mb_error == \"NA\"):\n",
    "        plotdata[\"batchsize\"].append(batchsize)\n",
    "        plotdata[\"loss\"].append(mb_loss)\n",
    "        plotdata[\"error\"].append(mb_error)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us plot the errors over the different training minibatches. Note that as we iterate, the training loss decreases, though we do see some intermediate bumps.\n",
    "\n",
    "The bumps reflect the noise of estimating the gradient from a small minibatch at each step; in exchange, using small minibatches with `sgd` scales well and stays efficient on large data sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAHhtJREFUeJzt3XmcFNW5//HPdwQRr4igBnEB93jVRDSKGBfm4q5RY1zJ\npmaRa8zLLYlR4y9ETXKzea+7RpOYQKLGbO5GlohbrktYLsiiKKgIQpB9GwPM8/vjnJaip2emeuia\n6p553q9Xv6ar+vSpp2tm+qk6deocmRnOOedca+ryDsA551xt8IThnHMuFU8YzjnnUvGE4ZxzLhVP\nGM4551LxhOGccy4VTxidnKQ7JX2n3LKSBkuak210H253tqQh7bGtjiD+bqZUumwb4nhO0hezqNvl\no0veAbhsSHoL2AHY0cwWJ9ZPBA4AdjWzd8zsorR1lijbppt4JPUHZgNdzKyxLXV0FJKOAJ4k7Ms6\nYEtgJaC4bl8ze7ecOs3sGeBjlS7rnJ9hdFxG+FIeWlghaX+gO238oq+gwpehMt+QtFnW29gUZva8\nmfUws62B/Qj7pWdhXXGyUJRLsK7T84TRsY0Ezkssnwf8JllA0r2Sro/PB0uaI+kKSQskzZV0fqmy\nG1bpakkLJc2S9NnECydJmiBpmaS3JQ1PvO+Z+HOppOWSDo3v+aqkaXHdq5IGJN5zoKT/k7RE0v2S\nNi/1gSWdJ+l5Sf8taSEwXNJwSSMTZfpLapRUF5eflnR9fN9ySX+V1LuZ+qdJOimxvJmkf0oaIKmb\npJGS3o9xviRp+1L1tGKjhBCbdq6X9HfC2ccukr6c2FczJX05Uf5oSbMTy3MkXS5pcozrd5K6lls2\nvn61pPdiua/E/div1Q8UfFfSW5LmS/qVpB7xte5xO4X99mJh/8fPOTt+zjcknd2G/ekqxBNGx/Yi\n0EPSR+OX4znAb2n5yH4HoAewI/AV4HZJPVso2zuWPR+4W9Je8bWVwBfMrCdwMvCfkk6Nrx0Vf24d\nj6JfknQW8F3g8/Fo+1RgUWJbZwHHAbsRmtTOb+EzHAq8AfQBfhDXFZ9VFS8PJSTU7YFuwDebqfs+\n4LOJ5ROAhWY2Kb5/a2Anwn75T2BNC3GW4/OEz7w1MBeYD5wY99VXgVvjGWRB8ec7Czga2B04GPhC\nuWUlfQq4GBgM7A0MKfHe5nyVsN+OAvYg7J+b4msXEM58d4zrvwY0xIRyI3B0/JyHA5NTbs9lwBNG\nx1c4yzgWmA7Ma6X8v4AbzGy9mT1J+OL/aDNlDfh/ZrbWzJ4FHgfOBjCzZ81sanz+KvAA4YsmKZm4\nvgz8xMwmxPfMMrPkRfWbzWyBmS0FHgWSZx/F5prZHWbWaGYftPJ5C+41szdj+QdbqP9+4FRJW8Tl\noXEdwFpgW2BvCyaa2cqU22/Nr8zs9fh7WW9mj5vZ2wBmNg4YCxzZwvv/x8wWmtkS4DFa3n/NlT0L\n+GWMYw1wXRnxfxb4Wbxutgq4hg2Jdy2wHRv22wQzWx1fawQ+Jqlb/P3PKGObrsI8YXR8vyX8Y54P\njEhRflHRhejVwFbNlF1iZg2J5bcJR4lIOlTS32JzzVJgGOFLoTm7AG+28PqClDEBtKX31vw09ZvZ\nm8A04BRJ3QlnQvfFl0cCTwEPSHpX0o9UuWsoG30mSZ+KTTeLJC0hHBC0tH/L2X/Nld2xKI45pL8O\ntSPh76PgbaBbbLL7NTAGeDA2df1QUp2ZrSAk5K8D8yU9kjiDdTnwhNHBmdk7hIvfJwJ/rnD1veKX\nZkE/NpzB/A54CNjJzLYBfs6GL5dSzRhzCE0VlVBc/ypC76OCvptY/wOEJHwaMNXMZgGY2Tozu8HM\n9gM+CZwCVKpb6YefKZ7d/IHQ3La9mfUCRpN9J4L3gJ0Ty/1I3yQ1D+ifWO4PfBDPZNaa2fVmti9w\nBPAZ4HMAZvaUmR1LaP58k/B35HLiCaNz+BIw
JDYjVJKA6yR1lXQk4VrFg/G1rQhnIGslDWTjdv+F\nhKaGZIL4BfBNSQcBSNpD0i4VinMScJSkXeL1mKs2sb4HCNdTLmLD2QWS6iXtH68XrSQ0tZTbbTjN\nl343oCvwPmDx2sLRZW6nLR4Evixpb0lbAteW8d77gStih4MewPeJ+07Sf0jaT5JI7DdJO8Qzqe7A\nOkLiX1/JD+TK4wmj4/rwyM/MZheuDRS/Vk49JbwHLCEcPY4EhpnZzPja14AbJC0jfLH8PhHPGsLR\n8QuSFksaaGZ/jOvuk7Qc+AvhAmi58Tb9AGZj4vYnA68QroFsVKTM+uYD/wsMIvG5CEfBfwSWAVOB\npwn7pXDT4x1pqm9tnZktAy4nnMEtIhyRF3+m1uosu6yZPQbcCTwLvAY8H19q7jpRsq57CPvqOUKH\nhGXAZfG1HQlnv8uAKcAoQjLZDPgW4e9rIXAY4aK7y4mynEBJ0s6EdvM+hCOte8zslqIyg4GHgVlx\n1Z/N7PuZBeWcq4jYK2u8mXXLOxbXPrK+03sdcIWZTZK0FTBe0qgSPR2eNbNTS7zfOVdFJH2a0Buu\nB/AjwlmO6yQybZIys/mxfzqxe+F0Qh/1Yn7nqnO14WLCtZPXCT2ovp5vOK49tdtYUpJ2JfTnfqnE\ny4dJmkS4IelbZjatveJyzqUXeyy5TqpdEkZsjvojcGmJG5nGA/3MbLWkEwmnuHuXqCPv8Y+cc64m\nmVlFWnEy7yUlqQshWYw0s4eLXzezlYW7OuOdxV3VzDg+ZlZVj+HDh+ceQ63E5TF5TJ0hrmqMqZLa\no1vtr4BpZnZzqRcl9Uk8H0joubW4VFnnnHP5ybRJStLhhDs2pyjMw2CEMWT6A2ZmdwNnSrqIcLPO\nGsIAec4556pMpgnDzF4g3HzTUpnbgduzjCMr9fX1eYdQUjXG5TGl4zGlV41xVWNMlZTpjXuVJMlq\nJVbnnKsWkrBauejtnHOuY/CE4ZxzLhVPGM4551LxhOGccy4VTxjOOedS8YThnHMuFU8YzjnnUqmp\nhPHWW3lH4JxznVdNJYzRo/OOwDnnOi9PGM4551LJNGFI2lnS3yRNlTRF0iXNlLtF0kxJkyQNaK6+\nsWNh/frs4nXOOde8rM8wCnN67wccBlwsaZ9kgThp0h5mthcwDLirucoOOwzeey/LcJ1zzjUn69Fq\n5wPz4/OVkgpzes9IFDsNGBHLvCSpp6Q+ZraguL7HHssyWueccy1pt2sYLczpvRMwJ7E8N65zzjlX\nRdolYbQyp7dzzrkakGmTFLQ+pzfhjGKXxPLOcV0T3/ve9z58Xl9f3+EnK3HOuXKNGzeOcePGZVJ3\n5hMoSRoBvG9mVzTz+knAxWZ2sqRBwE1mNqhEOZ9AyTnnylQzEygl5vQeImmipAmSTpA0TNKFAGb2\nBDBb0hvAz4GvtVTnqlXw859nGbVzzrlSam6K1nXrYLvt4LXXoE+fvKNyzrnqVjNnGFno0gXq68NN\nfM4559pPzSUMgGOPhTFj8o7COec6l5pNGKNHQ420pjnnXIdQkwljr71ACtcxnHPOtY+aTBgS3HUX\nbLNN3pE451znUXO9pJxzzqXXqXtJOeecy4cnDOecc6l4wnDOOZeKJwznnHOp1HTCaGyEAw+EZcvy\njsQ55zq+mk4YdXVhXKmMRvJ1zjmXkPVotb+UtEDS5GZeHyxpaRzFdoKka8vdhg8T4pxz7SPrM4x7\ngeNbKfOsmR0UH98vdwPHHBOGCXHOOZetTBOGmT0PLGml2CbdUDJgACxaBHPmtF7WOedc21XDNYzD\nJE2S9Likfct9c11dOMt4+eUsQnPOOVeQ+ZzerRgP9DOz1ZJOBB4C9m6ucHNzeo8cGebJcM65zq7W\n5/TuDzxqZh9PUXY28AkzW1ziNR9LyjnnytTuY0lJ2kNSt/i8XtIlktKOFSuauU4hqU/i+UBCAmuS\nLJxzzuUv
bUPOn4CDJe0J3A08DNwHnNTSmyTdB9QD20p6BxgObA6Ymd0NnCnpImAtsAY4py0fwjnn\nXPZSNUlJmmBmB0n6FtBgZrdKmmhmB2Yf4ocxeJOUc86VKY/hzddKGgqcBzwW13WtRACV9Pzz0NCQ\ndxTOOdcxpU0YFwCHAT8ws9mSdgNGZhdW21x5JbzwQt5ROOdcx1R2LylJvYBdzKzkcB9ZSdMkNXw4\nfPAB/OhH7RSUc85VuTx6SY2TtLWk3sAE4B5J/12JACrJhwlxzrnspG2S6mlmy4HPACPM7FDgmOzC\naptBg+CNN+D99/OOxDnnOp60CaOLpL7A2Wy46F11unaFo46CsWPzjsQ55zqetAnjeuAp4E0ze0XS\n7sDM7MJquwsvhG23zTsK55zreDIfGqRS/D4M55wrXx4XvXeW9BdJ/4yPP0nauRIBOOecqw1pm6Tu\nBR4BdoyPR+M655xznUTaoUEmmdmA1tZlyZuknHOufHkMDbJI0uclbRYfnwcWtfam1ub0jmVukTQz\nTqLUbgnIOedcedImjC8RutTOB94DzgTOT/G+Fuf0jpMm7WFmewHDgLtSxtOq666D8eMrVZtzzrlU\nCcPM3jazU81sezP7iJl9Gjgjxftam9P7NGBELPsS0DM5R8amWL4cHn+8EjU555yDTZvT+4oKbH8n\nYE5ieW5ct8mOPRbGjKlETc4552DT5vSuyEWUcjQ3p3cpRx0FZ50FK1ZAjx7Zx+acc9WgKuf0lvSO\nmfVLUa7ZOb0l3QU8bWa/j8szgMFmtqBE2bJ7SQ0ZAldcAZ/6VFlvc865DqPdeklJWiFpeYnHCsL9\nGGk0O6c34d6OL8ZtDQKWlkoWbeXNUs45VzmZDg2SnNMbWEDTOb2RdBtwArAKuMDMJjRTV9lnGIsX\nh5+9e7cpfOecq3mVPMPwsaScc64Dy+PGPeecc52cJwznnHOpeMJwzjmXSqdIGGvWhPsxnHPOtV2n\nSBjf+Abcc0/eUTjnXG3rFAnj6KNh9Oi8o3DOudrWKbrVLlkC/fvDwoXQrVuFA3POuSrm3WrL1KsX\n/Pu/w9//nncktWH6dLjwQqiRYwnnXDvpFAkDwjAh3izVuvnz4aST4PDDQe0+vKRzrpp1moRx8sl+\nxNyalSvDfvrSl+C885q+3tDQ/jE556pHp7iG4Vq3bh2cdhr07Rt6lBWfXbz/Phx8MNx/Pxx2WD4x\nOufKV1PXMCSdIGmGpNclfbvE64MlLZU0IT6uzTom19Qtt8D69XDnnaWborbbLrx26qnw0EPtH59z\nLn9Zj1ZbB7wOHA3MA14BzjWzGYkyg4FvmNmprdTlZxgZamiAtWtbn2zqH/8ISeM734GLL26f2Jxz\nbVfJM4xNmXEvjYHATDN7G0DSA4R5vGcUlfPLqznbYovwaM3BB8Pzz8OJJ4ZuyolJEJ1zHVzWTVLF\nc3a/S+k5uw+TNEnS45L2zTgmGhuz3kLHtvvu8MILcMwxeUfinGtP1dBLajzQz8wGALcBmbaQNzTA\nAQfAvfd64tgU220HRxyRdxTOufaUdZPUXCA57/fOcd2HzGxl4vmTku6Q1NvMFhdX9r1E+0d9fT31\n9fVlB7TFFvDrX8PXvga/+AXcfjsMGFB2NTVt3jy49trQG2qzzfKOxjlXSePGjWPcuHGZ1J31Re/N\ngNcIF73fA14GhprZ9ESZPoV5vCUNBB40s11L1FXRi96NjSFhXHstDB0K118PPXtWrPqyrF4deiZ1\n7579tlasgMGD4cwz4ZprKl//ypWw1VaVr9c51zY1063WzNYDXwdGAVOBB8xsuqRhki6Mxc6U9Kqk\nicBNwDlZxlRQVxeGv5g2DVatCsNhtLe33oJvfhN23BHGjs1+e+vWwdlnhwvXV19d+foXLICPfhSe\neabydTvn8uc37rUzM3j2Wbj55vDzggtC99Rdd81+uxdeCO++C48+Cl0yao
wcOzacsd16K5zTLqnf\nOdeSWupW64q89BIMGwaXXAIjRrTcfNPQAF27VuY6w/33w/jx4eg/q2QBYSj5MWPCECNz58IVV2S3\nLedc+/IzjBYMHw777x/a+ys1EJ9ZeNSlaAy86y647bZwfeX00zcthrVrYdmy0LupPcyZE+7VOP10\nuOGG9tmmc66pSp5heMJowXPPhd5UffuGL+69907/3ldegX79oE+ftm/fDJ58MtxV3aULfP/7cNxx\ntTOK7NKl8Oab8IlP5B2Jc51XzVz0rnVHHgkTJsAJJ8AnPxl6VK1e3Xz5devgwQfD0OBnngmvv75p\n25fCUOPjx8OVV8Jll0F9fcsxVJNttmk+WdxzT2i6mj/fRxF2rlb4GUZK8+aFucH79YMf/3jj15Ys\ngbvvDmchu+0Gl14aRn6t9LWCdevg6afD3B61rLEx7KMpU8Kjri40/X3sY3DTTema65xz6XiTVI7W\nr296EXratJBELr0UDjoon7iSli8P3XVvugm23DLvaFpmFs4yXn0VZs8OPbmKrVoFf/5zSCp77hkS\n55o14b07lRhoZskSePjhUKahIfxcswZ694bLL29aft48uPHGcFNn9+4bxtXq2xfOOKNp+YaGEGuh\nfLdu4e/CDLbfvmn55ctDZ4d//WvD44MPwn0/p51WOp477gj1NTaGh1nofn3ZZU3Lv/MO/PCHTcvv\nsgtcd13T8rNmhTPWZNnGxjDky803Ny3/5ptw1VVh//XuDdtuG37uuisMGdK0vKsu3ksqR6V6LO27\nL/zmN+0fS9KIEaHZrH9/OOus8M/fHjcCbiopfDH37dt8meXL4YknQlKeNQs23zx8WX/84zBqVNPy\nK1eGM7Hu3TckgO7dm78xs0uXsP2GhvBYsiT8XLSodMKYPTtczC+Ub2gIdRx4YOlZHRcsgJ/8JMSd\nfOyxR+mEIYUkVFcXHlL42bt36fi7dw/D3STLS6WTF4Qpi889d+O66+pCE2IpvXuH+3cWLYLFi0OC\nnzo1NLmWShgTJ4ZJuJLJZdttw//J5z7XtPzy5TBz5sa/qy22CAc73bqVjqnamIUDgfXrw7IUvis2\n37x02UKZWuNnGB3Ej38MP/1pOKrcaacwZ0WW3Weda87q1TBjRkguhSSzeHHooTdsWNPy48fDV7+6\n8RlhQwMceig89VTT8q+8Ahdd1DTBHHhg6RtS584NnUcKZQvl+/SB/fZrWv7dd+GRR8KBx4oV4efK\nlbDXXuHMrNioUaEbuRT+5wo9IY8/PpzpFnviiVC+oJDgTzwRHnusafknnyx9YHHCCSHOUuVPOWXD\n8vr13iTlSli8GO67D84/34fncB3XihUhIRUnmF69Qi/CYtOnw89+tnHzZENDuGZ2yy1Ny0+bFm48\n7dEj/B8VHrvvXnqE5nXrQpNeqbOJlhQSS+EBpQ/yGhvDNopJ4T6tUuWTA6t27eoJwznnXArerdY5\n51y784ThnHMulcwThqQTJM2Q9LqkbzdT5hZJM+OsezUzO0VWY85vqmqMy2NKx2NKrxrjqsaYKinT\nhCGpjjCL3vHAfsBQSfsUlTkR2MPM9gKGAXdlGVMlVesfRzXG5TGl4zGlV41xVWNMlZT1GcZAYKaZ\nvW1ma4EHgOIOYqcBIwDM7CWgp6RNGIHJOedcFrJOGDsBcxLL78Z1LZWZW6KMc865nGU9ResZwPFm\ndmFc/jww0MwuSZR5FPgvM/t7XB4DXGlmE4rq8j61zjnXBrUyNMhcoF9ieee4rrjMLq2UqdgHds45\n1zZZN0m9Auwpqb+kzYFzgeKb2R8BvgggaRCw1MwWZByXc865MmV6hmFm6yV9HRhFSE6/NLPpkoaF\nl+1uM3tC0kmS3gBWARdkGZNzzrm2qZmhQZxzzuWrJu70TnPzXwW39UtJCyRNTqzrJWmUpNckPSWp\nZ+K1q+NNh9MlHZdYf5CkyTHmmzYxpp
0l/U3SVElTJF2Sd1ySukl6SdLEGNPwvGNK1FcnaYKkR6oh\nJklvSfq/uK9erpKYekr6Q9zGVEmHVkFMe8d9NCH+XCbpkiqI63JJr8b6fidp8yqI6dL4f9e+3wdm\nVtUPQlJ7A+gPdAUmAftkuL0jgAHA5MS6HxN6bgF8G/hRfL4vMJHQtLdrjLNw1vYScEh8/gSht1hb\nY9oBGBCfbwW8BuxTBXFtGX9uBrxIuO8m15hiHZcDvwUeqZLf3yygV9G6vGP6NXBBfN4F6Jl3TEXx\n1QHzCB1icosL2DH+/jaPy78Hzss5pv2AyUA3wv/eKGCP9ohpk3+xWT+AQcCTieWrgG9nvM3+bJww\nZgB94vMdgBmlYgGeBA6NZaYl1p8L3FnB+B4CjqmWuIAtgX8Ah+QdE6GX3Wigng0JI++YZgPbFq3L\nLSZga+DNEuur4u8p1nUc8FzecRESxttAL8IX7iN5/+8BZwL3JJavBb4FTM86plpokkpz81/WPmKx\n55aZzQc+0kxshZsOdyLEWVCxmCXtSjgDepHwx5FbXLHpZyIwHxhtZq/kHRPwP4R/nuTFubxjMmC0\npFckfaUKYtoNeF/SvbH5525JW+YcU7FzgPvi89ziMrN5wI3AO7H+ZWY2Js+YgFeBI2MT1JbASYQz\nscxjqoWEUY1y6SkgaSvgj8ClZrayRBztGpeZNZrZgYSj+oGS9sszJkknAwvMbBLQ0n077f37O9zM\nDiL8Y18s6cgSMbRnTF2Ag4DbY1yrCEehuf49FUjqCpwK/KGZONrzb2obwvBF/QlnG/8m6XN5xmRm\nMwjNT6MJzUgTgfWlilZ627WQMNLc/Je1BYrjW0naAfhnIrZSNx2muhmxHJK6EJLFSDMrTPyYe1wA\nZrYcGAeckHNMhwOnSpoF3A8MkTQSmJ/nfjKz9+LPhYTmxIHku5/eBeaY2T/i8p8ICaQq/p6AE4Hx\nZvZ+XM4zrmOAWWa22MzWA38BPplzTJjZvWZ2sJnVA0sJ1zUzj6kWEkaam/8qTWx8hPoIcH58fh7w\ncGL9ubHXxG7AnsDL8XRwmaSBkkS4MbHE7L5l+RWhvfHmaohL0naFXhiSugPHEtpQc4vJzK4xs35m\ntjvh7+RvZvYF4NG8YpK0ZTwzRNK/Edrmp5DvfloAzJG0d1x1NDA1z5iKDCUk/II843oHGCRpi1jX\n0cC0nGNC0vbxZz/gdELzXfYxVeICVdYPwpHra8BM4KqMt3UfoXfGB4Q/lgsIF7zGxBhGAdskyl9N\n6HUwHTgusf4ThC+GmcDNmxjT4YRTzkmE088JcZ/0zisu4GMxjkmEHhvfietzi6kovsFsuOid537a\nLfF7m1L4+817PwEHEA7GJgF/JvSSyv13R+hAsRDokViX974aHuufDPyG0Fsz75ieJVzLmAjUt9d+\n8hv3nHPOpVILTVLOOeeqgCcM55xzqXjCcM45l4onDOecc6l4wnDOOZeKJwznnHOpeMJwVUdSo6QR\nieXNJC3UhuHKT5F0ZSt19JX0YHx+nqRby4zh6hRl7pX0mXLqrSRJT0s6KK/tu87HE4arRquA/SV1\ni8vHkhg8zcweNbOftFSBmb1nZmcnV5UZwzVllq8pkjbLOwZXezxhuGr1BHByfL7RUBHJM4Z4lH+z\npBckvVE44o9DyUxJ1NcvHpG/Jum7ibr+EkeRnVIYSVbSfwHd40iuI+O6L2rDJEi/SdQ7uHjbSTGO\naQojwr4q6a+FRJg8Q5C0raTZic/3F4XJcGZJulhhEp8Jkv6uMCBewRdjTJMlHRLfv6XCRGAvShov\n6ZREvQ9LGku4I9i5snjCcNXIgAeAofHL9eOEiV6KyxTsYGaHA6cQRvEsVeYQwpg7BwBnJZpyLjCz\nQ+Lrl0rqZWZXA6vN7CAz+4KkfQlnHPUWRue9NMW2k/YEbjWz/YFlwBktfO6C/YBPEwYq/AGw0sLI\nsi
8Sxvwp6B5jupgw3hjAd4CxZjYIGAL8LI73BXAg8Bkz+49mYnCuWZ4wXFUys1cJs4MNBR6n5eHK\nH4rvmc6GOQCKjTazpWbWQBg76Yi4/jJJkwhfxDsDe8X1ye0NAf5gZkvidpaWue3ZZlY42xkfP1dr\nnjaz1RZGbF0KPBbXTyl6//1x+88BPSRtTRjg8CqFuUrGAZuzYcTn0Wa2LMX2nWuiS94BONeCR4Cf\nEmbP266Fch8knjeXWJrMXyBpMCEZHGpmH0h6GtiizBjTbDtZZn1iG+vYcNBWvN3keyyx3MjG/7el\n5mUQcIaZzUy+IGkQ4fqQc23iZxiuGhW+eH8FXGdmU9vw3mLHStomNs18GniBMELrkpgs9iFMB1zw\nr8SF4b8RmrF6A0jqVea2m1v/FnBwfH5WM2Vac06M6QjCbHArgKeASz7cuDSgjXU7txFPGK4aGYCZ\nzTWz29KUbWG54GVCU9QkQvPSBOCvQFdJU4EfAv+bKH83MEXSSDObFl9/Jjbz3Fjmtptb/zPgIknj\nCUNTN6elehskTQDuAL4U199A+FyTJb0KXN9C3c6l5sObO+ecS8XPMJxzzqXiCcM551wqnjCcc86l\n4gnDOedcKp4wnHPOpeIJwznnXCqeMJxzzqXy/wEK3Iz4BW7IRAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x26a51844128>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAACfCAYAAADqDO7LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXm8XdPd/9+fyCAqQtAgiCkoipjVkFvEWLRFTa3paeVB\nfzHW/GuKtkppzZSihqLaqlkbU4yPKZFHyEBlEEEMmSVXM3yfP9Y6ufuee865+9x7ztnn5n7fr9d+\nnb3XXnutz9lnn/1d67smmRmO4ziO0xpdshbgOI7jdAzcYDiO4zipcIPhOI7jpMINhuM4jpMKNxiO\n4zhOKtxgOI7jOKlwg1GnSLpR0gXlxpU0SNLU6qpbmu8kSXvUIq+OhqQlkjao9bXtJZl3Oc9ggXTm\nSlqvktqc7HGDUWMkTZbUKKlPXvib8c+6LoCZnWRmv0qTZoG4bRpcI6l/1ODPBe2+H+0Z4FT0Wkkj\nJC2QNEfSp5L+LqlvO/IqmnfaZ1DSs5JOaJaIWS8zm1xBXbm8JkuaH7//3Ph5TaXzcQrjL4baY8Ak\n4MhcgKQtgJ607yVTCRQ1qOoZSctVO48K0dbfpD33sNS1BpxsZisBGwMrA78vmEjbDF3Vf/t2YsAB\nZrZSNEormdnQQhELPWPlPncd6DmtCW4wsuEu4NjE8bHAHckIkm6XdHHcHyRpqqQzJE2XNE3ScYXi\nNgXpPEmfSZoo6ajEif0ljZI0W9IUScMS1z0XP2fFktuO8ZqfSBobw96WtHXimoGS/lfSTEn3Supe\n6AtLOlbSi5J+J+kzYJikYZLuSsRpVqKPJdeL43VzJP0zv2aWuHaspP0Tx8vFEvjWknpIukvS51Hn\nq5JWL5ROWiRtL+nlmN40SddK6poX7QBJ70cdl+ddf0LU/IWkJ3I1y7TZA5jZLODvwBYxzdsl3SDp\nMUlzgQZJ3SVdEX/rj+P5HgkdP5P0kaQPJR1PwkDmP1eSDo414dmS3pO0t6RfArsB1yVL+2ru2lpJ\n0p3xPkxSws0Vn4sXJP1W0ox4v/ZN8/1bBDZ/xj4nPGOFwiTpQoXayieS/iRppZhG7hk8QdIU4Om0\nP0pnwA1GNrwC9JK0SXw5Hg7cTenS3RpAL2At4MfA9ZJ6l4jbJ8Y9DrhZ0oB4bh7wIzPrDRwA/Lek\ng+K53ePnSrHk9qqkw4CfAz+MpdqDgC8SeR0G7A2sD2wV8yvGjsC/gb5AztWRX4LPPz6SYFBXB3oA\nZxVJ+x7gqMTxvsBnZjY6Xr8S0I9wX/4bWFBCZxoWA6fF9HYG9gBOzovzXWCbuB2s6LaRdDBwbjy/\nOvACcG+5AiStBhwCjEoEHwlcYma9gJeAy4CNgC3jZz/C70l8MZ8B7AkMAPYqkdcOhELNmfHZ2R2Y\nbGYXRv0/zSvtJ3/H6wjP7npAA3BMNE45dgDGAasCvwVuLec+5JF7xr5O0zOWH3Y8cAwwCNggarsu\nL53dgU2BfdqhZZnDDUZ25GoZgwl/lo9aif8fwotgsZk9QXjxb1IkrgH/38wWmtnzwGPADwDM7Hkz\neyfuvw3cR/jjJEkarv8CLjezUfGaiWaWbFS/2symx9LuI0Cy9pHPNDO7wcyWmNlXrXzfHLeb2fsx\n/v0l0r8XOEjS8vH4SJpewgsJL6ONLfCmmc1LmX9BzGyUmb0W0/sAuJmW9/E3ZjbbzD4ErqLJDTkE\nuNTM3jWzJcBvgK0lrZMy+2slzQDeJDw3ZybOPWRmr0SNXwE/AU6POr6MeeV0HEa4v+PMbAHwixJ5\nngDcambPxLQ/NrN3S8QXLHWLHQ6ca2bzzWwKcCXwo0TcKWZ2m4WJ7e4A1pD09RJpPxhrIzPj538l\nzhV6xvLDjgJ+Z2ZTzGw+cB5whJpceAYMM7MFZTynnYL8KrRTO+4GnieUzO9MEf+L+HLJMR9YsUjc\nmWbWmDieQqhtoOBmupTgxuget7+WyHcd4P0S
56fnaVqzRNy29N76JC/9gt/ZzN6XNBY4UNKjhJrQ\nz+Ppu4C1gftirexu4AIzW9wGPQDEGtvvgO0I7U9dgZF50T5M7C/9DYD+wNWSrswlR3hJ9SPdPfp/\nZnZbkXNLr49utxWAkdLSMkAXmgoEawFv5GksVstdh1DwKJfVCPfmg7x8+iWOl/7GZrZAQeyKwKdF\n0jzYzJ4tcq7Q/csPWytqSOrpSqj55vgQpwVew8iIWCqdBOwHPFDh5FeR1DNxvC5NNZg/Aw8C/cxs\nZeAPNL0kCjXwTgU2rJCu/PS/JLzQcpQyNmm4j1B6PBh4x8wmApjZIjO7xMw2B74FHEhwSbSHGwk1\nww3jfbyAli/bZI2hP02/wVRgiJn1idsqZrZirmbQTpL3+HOCkd08kdfK0aUE8HEBjcUa+Us9B6U6\nBnxOqOH1z8tnWolrWqO1TgGthX1UQM9Cmhd+su6AUpe4wciWE4A9ojugkgi4SFI3SbsR2iruj+dW\nJNRAFka/dNLv/xmwhOYvhj8CZ0naBkDShmW4TlpjNLC7pHViyf/cdqZ3H6E95SRCmwYAkhokbRFd\nDvMIL4clhZNogYDlY8N5bhPB7z3HzOZL2jTmmc/PJK0c79fQqA/gJuB8SZtFfb0lHVr+1y1NdPHc\nAlwVaxtI6idp7xjlfuA4Sd+QtAJNNbJC3AocL+nbsdF4LUk5l+h0QltAIQ1LYj6/krSipP7A6YRa\nX1bcC5wuaT1JKxLaNe5L1ODrvadYZrjBqD3Jfu6Tcm0D+efKSacAHwMzCSWpuwil2ffiuZOBSyTN\nBi4E/pLQs4Dw53kp+oZ3MLO/xbB7JM0B/kFo6C1Xb8svYPZUzP8t4HVCG0izKGWm9wnwP8BOJL4X\noRPA34DZwDvAs8QXlsLgtBtKJQvMJZTUF8TPbxPaDY6O9+QPNBmD5HUPEdxUo+J3uy3qfJDQlnCf\npFmE779v3rWl9JRz7hxCg+8rMa/hhO64mNk/CW0rzwDvUqJHkJm9TmgsvopwH0cQaq4AVwOHKfT4\nuqqAlqGE+zaR4Ia928xuL/N7JHkk9sjKbX9vJX4+txF+/+cJ7tb5UWPa/DstquYCSpLWJvjn+xJK\ndLeY2TV5cQYR/lgTY9ADZvbLqolyHMdx2kS1G70XAWeY2ehY9RspabiZjc+L97yZHVTgesdxHKdO\nqKpLysw+if3gid0Yx9G8d0QO9xk6juPUOSUNhsJo2WLd18pCYSKyrYFXC5zeWdJohRGqm1UiP8dx\nHKeylHRJmdniOEy+t5nNbmsm0R31N+DUAgOmRgLrxt4m+xG6fG5cIA1viHIcx2kDZlYRL04al9Q8\nYIykWyVdk9vSZqAwv87fgLvM7KH882Y2L462JI5g7qYi8wWZWV1tw4YNy1xDR9HlmlxTZ9BVj5oq\nSZpG7wdo38Cy24CxZnZ1oZOS+prZ9Li/A6Hn1ox25Oc4juNUgVYNhpndoTADac5NNMHMFqZJXNIu\nwNGEGsqbhP7N5xNHlJrZzcChkk4iDKZaQJh3xnEcx6kzWjUYkhoIE4JNJvRmWkfSsRYmtSuJmb0E\nlJxP3syuB65PI7beaGhoyFpCQepRl2tKh2tKTz3qqkdNlaTVgXuSRgJHmdmEeLwxcK+ZbVsDfUkd\nVml/nOM4zrKOJKyGjd7dcsYCwMKUxt0qkbnjOI7TcUjT6P2GpD8SpoSG0CbxRon4juM4zjJIGpdU\nD+AUYNcY9AJwg9V4YRF3STmO45RPJV1SJQ2GwgLod5rZ0ZXIrD24wXAcxymfmrVhWFiRrH/sVus4\njuN0YtK0YUwkrI/wMGGFNADM7HdVU+U4juPUHWkMxvtx60JYZcxxHMfphJQ0GLENo5eZnVUjPaWZ\nPBnWWy9rFY7jOJ2SNG0Yu9RIS+s8+WTWChzHcTotaQbujZb0sKQfSfp+bqu6skK4wXAcx8mMNOMw\nCi3WbmZ2
QquJp1jTO8a7BtiP0Kh+nMVV+vLimPXpA59+CsuVnJ7KcRzHiVSyW22a2WqPb0f6ra7p\nHRdN2tDMBkjaEbgJ2KlgajvvDB9/DGuv3Q5JjuM4Tlso6pKSdH9i/7K8c8PTJG7p1vQ+mFALwcxe\nBXpL6lswwUcfdWPhOI6TEaXaMAYk9gfnnVu93IxKrOndD5iaOJ5GS6PiOI7jZEwpg1GqcaOsOTpa\nWdPbcRzH6QCUasNYQdJAglHpGfcVt55pM2htTW9CjWKdxPHaMawFv/jFL5buNzQ0LPOLlTiO45TL\niBEjGDFiRFXSLtpLStKzpS40s2+nykC6E/jczM4ocn5/4BQzO0DSTsBVZtai0dsnH3Qcxymfms1W\n2+7Ew5rezwNjCG6sQmt6I+k6YF9Ct9rjzWxUgbSCwfjyS7j7bhgypGq6HcdxlhU6jMGoJEsNxqJF\nsNpqMGEC9C3cmcpxHMcJ1HqJ1vqia1doaICnn85aieM4Tqei4xkMgMGD4amnslbhOI7TqUjlkpLU\nj9DusLRXlZk9X0VdhTQ0NXq/+y7suSd88AGoIjUtx3GcZZKaTg0SR3kfDowFFsdgIzRmZ8OAAcFQ\nTJgAm26amQzHcZzORJoFlL4LbGJmX1VbTGokuOkmWHnlrJU4juN0GtLMVvsEcFjWI7R9HIbjOE75\n1NQlBcwnrInxNLC0lmFmQyshwHEcx+kYpDEYD8fNcRzH6cSk7SXVHdg4Hk4ws4VVVVVYg7ukHMdx\nyqTWvaQagDuAyYSJB9eRdGytu9U6juM42ZJm4N6VwN5mNsjMdgf2AX5fXVkpWbIEBg6E2bOzVuI4\njrPMk8ZgdDOzCbkDM3sX6FY9SWXQpUuYV6pKU/k6juM4TaQxGG9I+qOkhrjdAryRJnFJt0qaLumt\nIucHSZolaVTcLixHPODThDiO49SINOMwegCnALvGoBeAG9IM5JO0KzAPuNPMtixwfhBwppkdlCKt\nwo3eo0bBUUfB+PGtJeE4jtPpqGmjdzQMv4tbWZjZi5L6txKtfV9k663hiy9g6lRYZ53W4zuO4zht\noqhLStL98XOMpLfytwpq2FnSaEmPSdqs7Ku7dIG99oLXXqugJMdxHCefUjWMU+Pnd6qY/0hgXTOb\nL2k/4EGaxnu0oOia3nfdFdbJcBzH6eRksqb30gjSZWZ2TmthJa7vDzxSqA2jQNxJwLZmNqPAOR+4\n5ziOUya1XnFvcIGw/crIQxRpp5DUN7G/A8GAtTAWjuM4TvYU9eNIOgk4Gdgwr82iF/BymsQl3QM0\nAKtK+gAYBnQHzMxuBg6N+SwEFhDW3XAcx3HqkKIuKUm9gVWAS4FzE6fmZlELcJeU4zhO+dTEJWVm\ns81sMnA1MMPMppjZFGCRpB0rkXnFefFFaGzMWoXjOM4ySZo2jBsJg+9yzIth9cfZZ8NLL2WtwnEc\nZ5kkjcFo5gsysyWkW0ej9gweDE8+mbUKx3GcZZI0BmOipKGSusXtVGBitYW1ib32coPhOI5TJdKM\nw/g6cA2wB2DA08BpZvZp9eU109F6o/fChWH22vffD5+O4zidnEo2eqdaca8eSN1L6sAD4Yc/hMO9\nh67jOE5NJh+UdLaZXS7pWkLNohlmNrQSAirOiSdCz55Zq3Acx1nmKNV4PS5+plr7om448MCsFTiO\n4yyTLHsuKcdxHGcptXJJPUIBV1SONIseOY7jOMsOpVxSV8TP7wNrAHfH4yOB6dUU5TiO49QfabrV\nvmFm27UWVuTaWwnraUwvNr25pGsIs99+CRxnZqOLxHOXlOM4TpnUenrzr0naIJH5+sDXUqZ/O7BP\nsZNx0aQNzWwAMAS4KWW6rXPRRTByZMWScxzH6eykMRinAyMkjZD0HPAscFqaxM3sRWBmiSgHA3fG\nuK8CvZNrZLSLOXPgsccqkpTjOI6TwmCY2T+BAYQlW4cCm5jZvyqUfz9gau
J4WgxrP4MHw1NPVSQp\nx3EcJ8UkgpJWAM4A+pvZTyQNkLSJmT1afXnNKbqmdyF23x0OOwzmzoVevaquzXEcpx7Iek3vvwAj\ngWPMbItoQF42s61TZVBiTW9JNwHPmtlf4vF4YJCZteiF1aZG7z32gDPOgO98p7zrHMdxlhFq3ei9\noZldTlhGFTObT5E1uotQdE1v4GHgGABJOwGzChmLNuNuKcdxnIqRZl2L/0jqSRzEJ2lD4Ks0ibe2\npreZPS5pf0n/JnSrPb4N36E4Q4ZUNDnHcZzOTBqX1GDgQmAzYDiwC2G8xIiqq2uuw8dhOI7jlEnN\npjeXJGBtYD6wE8G19IqZfV6JzMvBDYbjOE751HQ9DEljzOyblcisPbjBcBzHKZ9aN3qPkrR9JTJz\nHMdxOi5pahjjCQP3JhMapkVotC44N1S1aFcNY8ECWLTIx2M4jtPpqLVLqn+hcDObUgkBaWmXwTj5\nZNhoozAmw3EcpxNRE5eUpOUlnQb8DNgXmGZmU3JbJTKvGXvuCU8+mbUKx3GcDk3RGkYc4b0QeIEw\n/fgUMzu1htry9bS9hjFzJvTvD599Bj16VFaY4zhOHVOrRu/NzOyHZvYH4FBgt0pkmAmrrALf+Aa8\n/HLWSjoG48bBiSeC90pzHCdBKYOxMLdjZotqoKW6DB7sbqk0fPIJ7L8/7LILqCKFEsdxlhFKGYyt\nJM2J21xgy9y+pDm1ElgxDjjAS8ytMW9euE8nnADHHtvyfGNj7TU5jlM3tNpLql7wgXtVZtEiOPhg\nWHNNuOWWlrWLzz+H7baDe++FnXfORqPjOGVT64F77ULSvpLGS3pX0jkFzg+SNEvSqLhdWG1NTgGu\nuQYWL4YbbyzsilpttXDuoIPgwQdrr89xnMypag1DUhfgXWBP4CPgdeAIMxufiDMIONPMDmolLa9h\nVJPGRli4sPXBjW+8EYzGBRfAKafURpvjOG2mkjWMNNObt4cdgPdy4zYk3UdYx3t8XjxvXc2a5ZcP\nW2tstx28+CLst1/oppxYBdFxnGWbaruk8tfs/pDCa3bvLGm0pMckbVZlTbBkSdWzWKbZYAN46SXY\na6+slTiOU0NKjfSem+wllTiudC+pkcC6ccnX64DqOsgbG2GrreD2291wtIfVVoNdd81aheM4NaSo\nS8rMKjFT3zRg3cTx2jEsmc+8xP4Tkm6Q1MfMZuQn9ouE+6OhoYGGhobyFS2/PPzpT2F+qT/+Ea6/\nHrZOtTz5ssNHH8GFF4beUMstl7Uax3EqyIgRIxgxYkRV0k7V6C1pV2CAmd0uaTWgl5lNSnHdcsAE\nQqP3x8BrwJFmNi4Rp29uHW9JOwD3m9l6BdKqbKP3kiXBYFx4IRx5JFx8MfTuXbn0y2H+/NAzqWfP\n6uc1dy4MGgSHHgrnn1/59OfNgxVXrHy6juO0iZp2q5U0DDgHOC8GdQfuTpO4mS0GfkpY2vUd4D4z\nGydpiKQTY7RDJb0t6U3gKuDwMr9D2+jSJUx/MXYsfPllmA6j1kyeDGedBWutBU8/Xf38Fi2CH/wg\nNFyfd17r8ctl+nTYZBN47rnKp+04Tuakmd58NDAQGGVmA2PYWx1qPYx6wgyefx6uvjp8Hn986J66\n3nrVz/fEE+HDD+GRR6BrlTrIPf10qLFdey0cXhvb7zhOcWrdrfY/ZmaSLGb+tUpk3Gl59VUYMgSG\nDoU77yztvmlshG7dKtPOcO+9MHJkKP1Xy1hAmEr+qafCFCPTpvkaJI6zDJGmhnEWYcW9wcClwAnA\nPWZ2bfXlNdNR+xrGsGGwxRbB31+pifjMwtYlRY/mm26C664L7Svf+177NCxcCLNnh95NtWDq1DBW\n43vfg0suqU2ejuO0oKYr7sUMBwN7x8PhZlbzaV8zMRgvvBB6U625Znhxb7xx+mtffx3WXRf69m17\n/mbwxBNhVHXXrvDLX8Lee3ecWWRnzY
L334dtt81aieN0WrKYS2oMYSGl5+N+52C33WDUKNh3X/jW\nt0KPqvnzi8dftAjuvz9MDX7oofDuu+3LXwpTjY8cCWefDaedBg0NpTXUEyuvXNxY3HJLcF198onP\nIuw4HYQ0LqkfAz8HniFM4TEIuNjMbqu+vGY6sm30/ugjOPPMUGu47LLm52bOhJtvDrWQ9deHU08N\nM79Wuq1g0SJ49tmwtkdHZsmScI/GjAlbly7B9ffNb8JVV6Vz1zmOk4qauqQkTQC+ZWZfxONVgZfN\nbJNKCEhL5gYjx+LFLRuhx44NRuTUU2GbbbLRlWTOnNBd96qrYIUVslZTGrNQy3j7bZg0KfTkyufL\nL+GBB4JR2WijYDgXLAjX9isw08zMmfDQQyFOY2P4XLAA+vSB009vGf+jj+DKK8Ogzp49m+bVWnNN\nOOSQlvEbG4PWXPwePcJzYQarr94y/pw5obPDf/7TtH31VRj3c/DBhfXccENIb8mSsJmF7tenndYy\n/gcfwK9/3TL+OuvARRe1jD9xYqixJuMuWRKmfLn66pbx338fzj033L8+fWDVVcPneuvBHnu0jO/U\nFbXuJfUFMDdxPDeGdU4K9VjabDO4447aa0ly553Bbda/Pxx2WPjz12IgYHuRwot5zTWLx5kzBx5/\nPBjliROhe/fwst5ySxg+vGX8efNCTaxnzyYD0LNn8YGZXbuG/BsbwzZzZvj84ovCBmPSpNCYn4vf\n2BjSGDiw8KqO06fD5ZcH3cltww0LGwwpGKEuXcImhc8+fQrr79kzTHeTjC8VNl4Qliw+4ojmaXfp\nElyIhejTJ4zf+eILmDEjGPh33gku10IG4803wyJcSeOy6qrhf3L00S3jz5kD773X/LdafvlQ2OnR\no7CmesMsFAQWLw7HUnhXdO9eOG4uTgejaA1DUq4/5NbAN4GHACPMNvuWmR1XC4EJPfVRw6hXLrsM\nfvvbUKrs1y+sWVHN7rOOU4z582H8+GBcckZmxozQQ2/IkJbxR46En/ykeY2wsRF23BH+9a+W8V9/\nHU46qaWBGTiw8IDUadNC55Fc3Fz8vn1h881bxv/wQ3j44VDwmDs3fM6bBwMGhJpZPsOHh27kUvjP\n5XpC7rNPqOnm8/jjIX6OnIHfbz949NGW8Z94onDBYt99g85C8Q88sCn5xYur75KKI7yLYmYF6rrV\nww1GCmbMgHvugeOO8+k5nGWXuXODQco3MKusEnoR5jNuHFxxRXP3ZGNjaDO75pqW8ceODQNPe/UK\n/6PctsEGhWdoXrQouPQK1SZKkTMsuQ0KF/KWLAl55COFcVqF4icmVlW3brXtVlsPuMFwHMcpn5q2\nYUhaHTgb2BxYusKOmXlrl+M4TiciTf/FPxNWyFsfuAiYTFhq1XEcx+lEpDEYq5rZrcBCM3vOzE4A\nUtcuJO0rabykdyWdUyTONZLei6vudZjFKao153x7qUddrikdrik99airHjVVkjQGY2H8/FjSAZIG\nAkX69zVHUhfCKnr7EFxaR0raNC/OfsCGZjYAGALclFZ81tTrw1GPulxTOlxTeupRVz1qqiRp+l3+\nUlJv4EzgWmAloMDooYLsALxnZlMAJN1H6JY7PhHnYOBOADN7VVLv5KJKjuM4Tn3Qag3DzB41s9lm\n9raZfdvMtgU2TJl+P2Bq4vjDGFYqzrQCcRzHcZyMaVO3WkkfmNm6KeIdAuxjZifG4x8CO5jZ0ESc\nR4BLzezlePwUcLaZjcpLy/vUOo7jtIFaTg1SiLSZTwOShmXtGJYfZ51W4lTsCzuO4zhto63TgqYt\n7b8ObCSpv6TuwBFA/lj2h4FjACTtBMzy9gvHcZz6o2gNQ9JcChsGAalmtTOzxZJ+CgwnGKdbzWyc\npCHhtN1sZo9L2l/Sv4EvgePL/haO4zhO1ekwU4M4juM42dIhVqpJM/ivgnndKmm6pLcSYatIGi5p\ngq
R/xW7GuXPnxUGH4yTtnQjfRtJbUfNV7dS0tqRnJL0jaYykoVnrktRD0quS3oyahmWtKZFeF0mj\nJD1cD5okTZb0v/FevVYnmnpL+mvM4x1JO9aBpo3jPRoVP2dLGloHuk6X9HZM78+SuteBplPj/662\n7wMzq+uNYNT+DfQHugGjgU2rmN+uhCnd30qEXUbouQVwDvCbuL8Z8CbBtbde1Jmrtb0KbB/3Hyf0\nFmurpjWAreP+isAEYNM60LVC/FwOeIUw7iZTTTGN04G7gYfr5PebCKySF5a1pj8Bx8f9rkDvrDXl\n6esCfEToEJOZLmCt+Pt1j8d/AY7NWNPmwFtAD8J/bzhhqEPVNbX7h632BuwEPJE4Phc4p8p59qe5\nwRgP9I37awDjC2kBngB2jHHGJsKPAG6soL4Hgb3qRRewAvAGsH3Wmgi97J4EGmgyGFlrmkSYYicZ\nlpkmwuDb9wuE18XzFNPaG3gha10EgzEFWIXwwn046/8ecChwS+L4QuBnwLhqa+oILqk0g/+qzdct\n9twys0+ArxfRlht02I+gM0fFNEtaj1ADeoXwcGSmK7p+3gQ+AZ40s9ez1gT8nvDnSTbOZa3JgCcl\nvS7px3WgaX3gc0m3R/fPzZJWyFhTPocD98T9zHSZ2UfAlcAHMf3ZZvZUlpqAt4HdogtqBWB/Qk2s\n6po6gsGoRzLpKSBpReBvwKlmNq+AjprqMrMlZjaQUKrfQdLmWWqSdAAw3cxGU3qsUK1/v13MbBvC\nH/sUSbsV0FBLTV2BbYDro64vCaXQTJ+nHJK6AQcBfy2io5bP1MqE6Yv6E2obX5N0dJaazGw8wf30\nJMGN9CawuFDUSufdEQxGmsF/1Wa6pL4AktYAPk1oKzToMNVgxHKQ1JVgLO4ys9y6j5nrAjCzOcAI\nYN+MNe0CHCRpInAvsIeku4BPsrxPZvZx/PyM4E7cgWzv04fAVDN7Ix7/nWBA6uJ5AvYDRprZ5/E4\nS117ARPNbIaZLQb+AXwrY02Y2e1mtp2ZNQCzCO2aVdfUEQxGmsF/lUY0L6E+DBwX948lrG+eCz8i\n9ppYH9gIeC1WB2dL2kGSCAMTCyzuWxa3EfyNV9eDLkmr5XphSOoJDCb4UDPTZGbnm9m6ZrYB4Tl5\nxsx+BDySlSZJK8SaIZK+RvDNjyHb+zQdmCpp4xi0J/BOlpryOJJg8HNkqesDYCdJy8e09gTGZqwp\nt7AdktYFvkdw31VfUyUaqKq9EUquE4D3gHOrnNc9hN4ZXxEeluMJDV5PRQ3DgZUT8c8j9DoYB+yd\nCN+W8GJGo6JjAAAD40lEQVR4D7i6nZp2IVQ5RxOqn6PiPemTlS7gm1HHaEKPjQtieGaa8vQNoqnR\nO8v7tH7idxuTe36zvk/AVoTC2GjgAUIvqcx/O0IHis+AXomwrO/VsJj+W8AdhN6aWWt6ntCW8SbQ\nUKv75AP3HMdxnFR0BJeU4ziOUwe4wXAcx3FS4QbDcRzHSYUbDMdxHCcVbjAcx3GcVLjBcBzHcVLh\nBsOpOyQtkXRn4ng5SZ+pabryAyWd3Uoaa0q6P+4fK+naMjWclyLO7ZK+X066lUTSs5K2ySp/p/Ph\nBsOpR74EtpDUIx4PJjF5mpk9YmaXl0rAzD42sx8kg8rUcH6Z8TsUkpbLWoPT8XCD4dQrjwMHxP1m\nU0UkawyxlH+1pJck/TtX4o9TyYxJpLduLJFPkPTzRFr/iLPIjsnNJCvpUqBnnMn1rhh2jJoWQboj\nke6g/LyTRB1jFWaEfVvSP3OGMFlDkLSqpEmJ7/cPhcVwJko6RWERn1GSXlaYEC/HMVHTW5K2j9ev\noLAQ2CuSRko6MJHuQ5KeJowIdpyycIPh1CMG3AccGV+uWxIWesmPk2MNM9sFOJAwi2ehONsT5tzZ\nCjgs4co53sy2j+dPlbSKmZ0HzDezbczsR5I2I9Q4GizMzntqiryT
bARca2ZbALOBQ0p87xybA98l\nTFT4K2CehZllXyHM+ZOjZ9R0CmG+MYALgKfNbCdgD+CKON8XwEDg+2b27SIaHKcobjCcusTM3ias\nDnYk8Bilpyt/MF4zjqY1APJ50sxmmVkjYe6kXWP4aZJGE17EawMDYngyvz2Av5rZzJjPrDLznmRm\nudrOyPi9WuNZM5tvYcbWWcCjMXxM3vX3xvxfAHpJWokwweG5CmuVjAC60zTj85NmNjtF/o7Tgq5Z\nC3CcEjwM/Jawet5qJeJ9ldgvZlharF8gaRDBGOxoZl9JehZYvkyNafJOxlmcyGMRTYW2/HyT11ji\neAnN/7eF1mUQcIiZvZc8IWknQvuQ47QJr2E49UjuxXsbcJGZvdOGa/MZLGnl6Jr5LvASYYbWmdFY\nbEpYDjjHfxINw88Q3Fh9ACStUmbexcInA9vF/cOKxGmNw6OmXQmrwc0F/gUMXZq5tHUb03acZrjB\ncOoRAzCzaWZ2XZq4JY5zvEZwRY0muJdGAf8Eukl6B/g18D+J+DcDYyTdZWZj4/nnopvnyjLzLhZ+\nBXCSpJGEqamLUSrdRkmjgBuAE2L4JYTv9Zakt4GLS6TtOKnx6c0dx3GcVHgNw3Ecx0mFGwzHcRwn\nFW4wHMdxnFS4wXAcx3FS4QbDcRzHSYUbDMdxHCcVbjAcx3GcVPwfdgRMuQOf1c4AAAAASUVORK5C\nYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x26a51af63c8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Compute the moving average loss to smooth out the noise in SGD\n",
"plotdata[\"avgloss\"] = moving_average(plotdata[\"loss\"])\n",
"plotdata[\"avgerror\"] = moving_average(plotdata[\"error\"])\n",
"\n",
"# Plot the training loss and the training error\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.figure(1)\n",
"plt.subplot(211)\n",
"plt.plot(plotdata[\"batchsize\"], plotdata[\"avgloss\"], 'b--')\n",
"plt.xlabel('Minibatch number')\n",
"plt.ylabel('Loss')\n",
"plt.title('Minibatch run vs. Training loss')\n",
"\n",
"plt.show()\n",
"\n",
"plt.subplot(212)\n",
"plt.plot(plotdata[\"batchsize\"], plotdata[\"avgerror\"], 'r--')\n",
"plt.xlabel('Minibatch number')\n",
"plt.ylabel('Label Prediction Error')\n",
"plt.title('Minibatch run vs. Label Prediction Error')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation / Testing \n",
"\n",
"Now that we have trained the network, let us evaluate it on the test data. This is done using `trainer.test_minibatch`, which returns the evaluation error (here, the label prediction error) for a test minibatch."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average test error: 7.41%\n"
]
}
],
"source": [
"# Read the test data\n",
"reader_test = create_reader(test_file, False, input_dim, num_output_classes)\n",
"\n",
"test_input_map = {\n",
" label : reader_test.streams.labels,\n",
" input : reader_test.streams.features,\n",
"}\n",
"\n",
"# Test data for trained model\n",
"test_minibatch_size = 512\n",
"num_samples = 10000\n",
"num_minibatches_to_test = num_samples // test_minibatch_size\n",
"test_result = 0.0\n",
"\n",
"for i in range(num_minibatches_to_test):\n",
" \n",
" # We are loading test data in batches specified by test_minibatch_size\n",
" # Each data point in the minibatch is an MNIST digit image of 784 dimensions \n",
" # with one pixel per dimension that we will classify with the \n",
" # trained model.\n",
" data = reader_test.next_minibatch(test_minibatch_size,\n",
" input_map = test_input_map)\n",
"\n",
" eval_error = trainer.test_minibatch(data)\n",
" test_result = test_result + eval_error\n",
"\n",
"# Average of evaluation errors of all test minibatches\n",
"print(\"Average test error: {0:.2f}%\".format(test_result*100 / num_minibatches_to_test))"
]
},
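{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above gives every minibatch the same weight, and because `num_samples // test_minibatch_size` is 19, it scores only 19 * 512 = 9,728 of the 10,000 test samples. The cell below is a minimal NumPy sketch, not part of the original tutorial, of a sample-weighted average that also covers a final, partial minibatch. It assumes that `trainer.test_minibatch` returns the mean error over the samples of the minibatch it is given; the per-minibatch error values used here are hypothetical placeholders, and `weighted_test_error` is a helper introduced only for this illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def weighted_test_error(errors, sizes):\n",
"    # Hypothetical helper: weight each minibatch's mean error by its sample count.\n",
"    errors = np.asarray(errors, dtype=float)\n",
"    sizes = np.asarray(sizes, dtype=float)\n",
"    return float(np.sum(errors * sizes) / np.sum(sizes))\n",
"\n",
"num_samples = 10000\n",
"test_minibatch_size = 512\n",
"# Ceiling division, so that the last (partial) minibatch is included.\n",
"num_minibatches = -(-num_samples // test_minibatch_size)\n",
"sizes = [min(test_minibatch_size, num_samples - i * test_minibatch_size)\n",
"         for i in range(num_minibatches)]\n",
"\n",
"# Hypothetical per-minibatch errors, for illustration only.\n",
"errors = [0.07] * (num_minibatches - 1) + [0.10]\n",
"\n",
"print(num_minibatches, 'minibatches, last one has', sizes[-1], 'samples')\n",
"print('Weighted test error: {0:.2f}%'.format(weighted_test_error(errors, sizes) * 100))"
]
},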
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that this error is comparable to our training error, which indicates that our model has a good \"out of sample\" error, a.k.a. generalization error. In other words, the model performs about as well on observations it did not see during training as on the ones it was trained on. A small gap between training and test error is a key sign that the model is not overfitting."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So far we have been dealing with aggregate measures of error. Let us now get the probabilities associated with individual data points. The classifier is trained to recognize digits, hence it has 10 classes. First, let us route the network output `z` through a `softmax` function, which maps the raw activations of the network to a probability distribution across the 10 classes. For each observation, calling `eval` on the resulting function then returns the probabilities of all 10 classes."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"out = C.softmax(z)"
]
},
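{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of what `C.softmax` computes, the cell below is a minimal NumPy sketch, not CNTK's actual implementation: it exponentiates the raw scores (after subtracting their maximum, which leaves the result unchanged but avoids overflow) and divides by their sum, so the outputs are positive and add up to 1. The 10-element score vector is hypothetical, and `softmax` here is a local helper, not a CNTK function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def softmax(scores):\n",
"    # Shift by the maximum score for numerical stability; the ratios are unchanged.\n",
"    shifted = scores - np.max(scores)\n",
"    exps = np.exp(shifted)\n",
"    return exps / np.sum(exps)\n",
"\n",
"# Hypothetical raw network outputs (logits) for the 10 digit classes.\n",
"logits = np.array([1.0, 2.0, 0.5, -1.0, 3.0, 0.0, 0.2, -0.5, 1.5, 0.8])\n",
"probs = softmax(logits)\n",
"\n",
"print(np.round(probs, 3))\n",
"print('sum =', probs.sum(), ', most likely class =', int(np.argmax(probs)))"
]
},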
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us evaluate a small minibatch sample from the test data."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Read the data for evaluation\n",
"reader_eval = create_reader(test_file, False, input_dim, num_output_classes)\n",
"\n",
"eval_minibatch_size = 25\n",
"eval_input_map = {input: reader_eval.streams.features} \n",
"\n",
"# Note: this minibatch is drawn from reader_test with test_input_map (which maps both\n",
"# labels and features), so it continues from wherever the test loop above left off;\n",
"# reader_eval and eval_input_map are not used below.\n",
"data = reader_test.next_minibatch(eval_minibatch_size, input_map = test_input_map)\n",
"\n",
"img_label = data[label].asarray()\n",
"img_data = data[input].asarray()\n",
"\n",
"# Evaluate the softmax output on each image to get its 10 class probabilities\n",
"predicted_label_prob = [out.eval(img_data[i]) for i in range(len(img_data))]"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Find the index with the maximum value for both predicted as well as the ground truth\n",
"pred = [np.argmax(predicted_label_prob[i]) for i in range(len(predicted_label_prob))]\n",
"gtlabel = [np.argmax(img_label[i]) for i in range(len(img_label))]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Label : [4, 5, 6, 7, 8, 9, 7, 4, 6, 1, 4, 0, 9, 9, 3, 7, 8, 4, 7, 5, 8, 5, 3, 2, 2]\n",
"Predicted: [4, 6, 6, 7, 5, 8, 7, 4, 6, 1, 6, 0, 4, 9, 3, 7, 1, 2, 7, 5, 8, 6, 3, 2, 2]\n"
]
}
],
"source": [
"print(\"Label :\", gtlabel[:25])\n",
"print(\"Predicted:\", pred)"
]
},
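{
"cell_type": "markdown",
"metadata": {},
"source": [
"To quantify the comparison above, the cell below is a small sketch that copies the two printed lists (so that it runs on its own; in the notebook, `gtlabel` and `pred` could be used directly) and finds the positions where the prediction differs from the ground truth. With only 25 images, this minibatch error rate is a noisy estimate and need not match the average test error reported earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Ground-truth and predicted labels, copied from the printed output above.\n",
"labels = np.array([4, 5, 6, 7, 8, 9, 7, 4, 6, 1, 4, 0, 9, 9, 3, 7, 8, 4, 7, 5, 8, 5, 3, 2, 2])\n",
"predicted = np.array([4, 6, 6, 7, 5, 8, 7, 4, 6, 1, 6, 0, 4, 9, 3, 7, 1, 2, 7, 5, 8, 6, 3, 2, 2])\n",
"\n",
"# Indices where the prediction disagrees with the ground truth.\n",
"wrong = np.flatnonzero(labels != predicted)\n",
"\n",
"print('Misclassified positions:', wrong.tolist())\n",
"print('Minibatch error: {0:.2f}%'.format(100.0 * len(wrong) / len(labels)))"
]
},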
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us visualize some of the results."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Image Label: 8\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztXdly4loSLAwYDBivfSdi/v/vpr2wY8BmHm6kOlWqI7AN\nMrQyI05I0DYWNKmqU0tWY7vdmiAI9cLFT1+AIAjVQ8QXhBpCxBeEGkLEF4QaQsQXhBpCxBeEGkLE\nF4QaolXB31ChgCD8HBrRk7L4glBDiPiCUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiCUEOI\n+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiC\nUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBD\niPiCUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4\nglBDiPiCUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiCUEO0fvoChGqx3W6P8rOffU0++uca\njUb28zjn5/a5xug1dv1+dG1l18qLf+7i4sIajUa2/GP/nqLnjg0RX8iwDym/ezPAa358fOQWnjOz\nkDA496/jz/n3o9fhn4+Iy9fir43X+/t7eGw0GtZqtZKr2Wxm1xStqiDi1xyeyBEpIpJ8FSAJE4bX\ndru1ZrOZIwM/bjQaGRlT1+bJxETDe2RC+6O/Jr82m42t12vbbDaF1Wg0rNPpZOvy8jJ33m63rdls\nWrPZzG4EzWbTzPKW/9gQ8WuMlJubIgQs8nfw8fERkgbPmVlGBk8OWEt/TXyd/Pt802CCeavtLXnq\n2nC+Wq2yI6/1em0XFxd2dXVlvV4vPHY6ncz6f3x8WKv1LwVxc6oKIn5NsYv0qfVdvL+/ZyTxJFqv\n12Zmpa4yiF92fbhZ8E0DRMM1RJ4Hbkq8/LWuVitbLpf29vaWHfm82WzaYDAorPV6nf2dy8tLu7y8\nLHgoVULErzlS+/rUfva7rj6Iv1qtMtLgfLVa2Xa7tXa7nVyNRiNnqffdY8O6NhoN22w2Bdcd5+v1\nunBNfFwul7ZYLMI1n8+t1WrZzc1Ntt7e3jLSNxqNgncCS99sNo8STE1BxK8x9iW934N/B5vNpmA1\neZlZZhHb7XZ2zsSO9t241kajYe1221qtVnazYKLhGrC8a88Ex+LHi8XCZrOZzWYzm8/n2TlWu922\nu7s7m8/ntlwuc6THjceTnr2RqiDi1xBlbj6vKKj1XeLDZX57ewut5na7zQXHmNjb7TYjPltpPm80\nGtlNg2MADL935+NqtUpadFj1yWRi0+k0PLbbbZvP5wVL32w27fLyMhfIA+lxQ5XFF5IoyztHz5VZ\ndU90EKXMIh6C+BGZcG5mhWg4L++q8zUy8b3HgHMzS5J+s9kkb0ie+BHpJ5NJ9rfa7bZ1Oh3rdrt2\ndXVlV1dX1u/37fLyMnPtW61Wdt2y+MJOlBF31zGVj45I76PZm83m21/QzWZTcO/Zld5utxlZo5Vy\n9bFgRVMxAlxDmasfXRcew71fLBaZK8+fC/9fsBeB1+12u9n18e9GnskxIeKfGXZF3X3AKyo08e6x\nt5wc1faPP0P86Iv8/v6eC5b5czPbGdz7SgEN9v24htRWIRXcwzkH9/Acfh/v16cF8Zq4gbTbbVut\nVrmtTJWkNxPxzxK+0MQH4SJrFuWno0KUKMXG5+/v79+69mOn87Cf9qk8HHENqZsjrH7q+rxHAOKz\nxeb/H45p4PewjZHFFz4FT3wmt7fS0Tl/sVP5anyp/drX4qe+xOwCRzei7XZbKNph8vL7j4p4OD3m\ni3eQK9/lLZUVGPnCncjV55sI/zzI3+l0shuG
9xaqgoh/ZogCcZFL7r+cEYkjgqfccBy/Y/FT2QJe\nZlYo0+XH/DpRjCOqg/9Mye6uct1UyS6TNyI/SN/tdnP/LxxjkcUXShGVl3rLwuT1JOaKM195tuv5\nfYhf9gWOrLT/4vvmGt/dxqmvqIegrDPO/05ZqbJfuDFEsYVdFp9dfaT6cPOQxRd2gr+gZe5kKjLN\nC5Hp1IpeC/X0n73m6HHqGLWt8jEiiSe+P/fNL1EK1F9L2XnZzcf//0TE56Ag1ylUBRH/xOC/QP45\nuI1sjX3N+HdXZPnx5f0K8fm9mX2/53wXOauA76HHQvZgV2uu34ZU1ZUHiPgnhl3puqiyjFNMUeNI\nisTRY07ZNZvNLPftG132fS+78NkvPFvTKDBXBfmjPnoQ+OrqKmvMub6+tuFwaMPh0G5ubnLHwWBg\nV1dX1u12s1bdKskv4p8g2E30Ry4i8cfZbGbL5XJnAC+VquM9p5llxTAXFxdZ5VvVe9Gyz8anKXFT\nAI5xrT5r4FOHqNAD8a+vr7OGndvbWxsOh3Z9fW39ft96vZ51Oh0RX/hj0VIpucViYdPptLBQOopq\nshSxU2k0LCbOodpFUwT8yhcd6TZe+BuI6vOW4tDkZ+KzW4/zfr9v/X4/Iz2s/O3tbUZ8/Az680V8\nISz+YKs9n89tPB7bZDKx8XhcOF8sFskmFCZ3qohlu80r4Pi0WtV7UQ+u/PPpOcQfjk1+eEKtVqvQ\nE9Dr9UJX//b21u7u7uz6+jqr3RfxhQxRnTcH3WDdX19fbTQaZQuPF4tFcpvA5aG+3h/nIBOIztas\n3W5XLhjhsdlsCjl9fF4XFxfZ+zgW+b3FRzMOFqw5E59d/cFgUGg8QteeiF9zMPG9+MN0OrXxeGyj\n0cheXl7s5eXFXl9fs3MQP1Wjv4sA2NOb/Qnucaccqud+CpC3YlEL3Aw4yu7Jfygw8WHx0YXX7Xat\n1+uFxIfF7/f72Q2Dm4lQjlwVRPwTA1ePscWHAIQn/vPzsz0/P9vT05M9Pz/bfD4PC2O4yCRSr2Uy\nIZIPi9/pdDLdONS7/xRWq1WO9NxYwwU+xyR/ivggfZnF7/f7ST1AEf8vRqqYBfBFH4jYI4jHLj72\n9bzPR2trKteN7jX+0nFUutPpWK/XyxYIj4WbwrE/H3/E+dvbW3ZzxOfkifNdsnt5bp+n588E0Xmc\ng+SI3g8Gg+xmgM+zrCqxKoj4P4CyAhRYL1h4DtzB0r++vtp4PLbpdFpQewG5U4MboAQTiVXAeiG/\njAAUnx/b4kels3yEYg28IbjJ3w08+s8oavLBjRFEjhYCeTc3NzYYDKzX6xVy9V7nn/9+VRDxfwgc\nWOMve0R8kB3WHjcDCEJExE81qbTb7YzI2Jf6x3Bdo2MVxOfiHB+raDabuS3Qd4nvyccR+2gv3u12\nMyuO6D2fDwaDnMX3ufpoqs5PZEpE/B+AJztbtRTxsZ8fjUY5cUcmPgpvoiKTyGLxnpRdUbb+fGTN\nuGMhJRbC0lrYBnlCfRZRHT+IH71/bIOY5HDno4XPtNvtZpJbp0B6MxG/ckSkZ4tWRvynpycbjUYF\niWfU0HOjB+/fudAEJaWwSLyGw6FdXV0VVG9YR+7YxPcdh/7czLLPB9f1FYufagDC5xZte7rdbq44\nhxduAnwDjXL1uEFFtf5VQsT/AfgOu88QfzweF/rkUZmHqH2UawZJer1eRnJUk2FPent7m0XuU+vY\neXzev0eVh9vt1haLRZZe9C70ZxCRn2sXQHgO4vFNE8U5OGfX3o/Q8kU6qa7BqiDi/xDY4nuRh13E\nT0loQWwSFoRTTt5Vvbm5sfv7+8Lq9XrhFqGqlBNX5kW9Bh8fHzabzTL3GRb/M55Iqm2XiZ9y7UF2\n33yDBY8p0vs7haAeIOL/AHy02iu7pIj/+/dvm0wmSREL1KqbFV19zjWjceTu7s4eHx/t8fHRfv36\nZb9+/bJe
r5cMDFbhkuLGl+ow3Gw22b75Oxbfky8iPiw+3HvOyfPi57vdbvLzO4WSZ0DErxiRiAYv\naMzP5/PcQhfefD7PXodfE2AXn/epcFkjy8Vf4n6/X5rH/u4Xd5cIx8XFRfbZcEEOKvS4K4/7C1Kf\nh4cnol9l1XdMdHbzeZ/f6XSyv1N2/GmI+BUD6SqWZOKFdB2IDjcXX3Ig5S4y6WGteHGHGAJPl5eX\nSVf0GEgJZW63W1utVrmbnL/pjUajLLsxmUzCOgZ8znzkzyrVWYfiHN87z0eQvKy77tRIHkHErxhM\n/GhqCxPfp+pg2cq+YJFr73vDOfqM4hIfGT9WysnHNnx34HK5zFUq+jUej+3l5cVGo1FWwLRcLgtK\nt3xkoHLRp+mwer1ewarz4jQd3zj9duMU9vFlEPErRkR8zsuniO975aN0EFfmcX09iM+uKivAeIvv\nXx+PD/0ZRNr/HNtItR7zc7ssvj9n4nPUHtshdu+jyH2/389VMsLic5PNqZPeTMSvHCniQ0gDX2g/\npinVWeeLQuDGwuJjr8r7VG/xU8Tn4yHfP1t5LxDCvQlR2zE+J/yMJz7+Bv89/3nh5shqOV41J6px\nQC+9T9elAoynSnozEb9yMPG56w4WLbL42ON7y+Wj0bBmID5bfK4hx5ccOecy4vvzQ7z/SHMAR/95\ncMvxy8uLTSaTgs5gVMCEv+XhLT5H7Nmd94U5WNga+XJen+o8ZdKbifiVo8ziw7qVufp+j++j1N7i\nM/Fh8dm1hcWv0mKVSYOnSpWfnp7s6enJptNpOBjEu/opMPE5VceFTKmS3Ovr60yTwK+qAqOHgohf\nMb5K/H0sPled4YvNe3xE9Lk5JxXcO+b7j4aBIGcPV58t/vPzs/3+/dt+//5ts9ksOfByHwVgdvVh\n8eEN3d/fZz3zqfp7TOw9RqqzSoj4FYMDW5G77/f30TRWbsBhrfZWq1VovIk6yXwpKbuqh8jT+1Zj\nfhypCvGKAni85vN5chKPWXkvfaPRyKU4vTYeAp8+BQoPCd7R3wARv2JEYprQ1EPxDh77oYzQxOMy\nXL+Gw6Hd3d1l1h1k56AUk/3QbqrvQfDnEBdJrfF4XMjTc4DTtzMDKPTx4iL+OBgM7P7+PvuMEOz0\n3XS7AnfnDhG/YkSBLbj8+PJHs9fZorGr6rvHQHxO24H4XN/ONe6HktHG+2M33i8IhnIKkxcLieLn\ncBOMiO/JzzMAIsGRwWBgd3d3uc8IWQ58Tr6wqaptUJUQ8X8A3uJ74vvBivyFZ4sfVeZ5i4+OsejL\nzMQ/5B7V39h81D5VnINUnXftfUrTE94s31YLncBIZAT1DBzMww2SXXrOz3+13/+UIeJXDF937l39\n+Xyei1azFn6K+L7gZJcLm4pIH+r9RcMiUZLMsYxowcpzqa539fF3PNjiI53p9+ogPu/p/efkPQW5\n+sK3Ee3xvcX3E3S8q+/z0CzhzF/mlKt/zEEZfGNjhWDELlCkhJoFH8yDa+8n+kb9ChHY4vtyZb5B\n+sf4nDhP772ivwkifsXYh/jRIAyu0/eVZ0jV3d/f5yw9R/I5uOeLfvj8EO/PW3y8N5+29FV5o9Eo\n83h4m4DHUUrTn3uL74uXcDOMFryiKCh4bum6XRDxfwC7XP3UpFyWx+aSXBD/4eEhq8zjtB7vXdvt\ndkiYY1h8X6uA/Tvy837xQBDOz/sbYKrIyJcs4zPgPD0KmPwWAOcpj0gWX/gWODjFX2yuWY9kpWHp\noqKdsj27l/nah+D+Z/zjKKqO87e3t0KQjgN30UwA5Osx9DN10+M6Bn7vfPQluBzIQ8CT5wXgyAE9\n9oB+Svf+2BDxfxBl5bHIS+MIpCr/8MU1s3BCLuvQ73NN/suOc75hRcM3ka7DilpseT/vqxMjojOi\ndB0/7vf7WVYDY6s42Il9POfqfcrulCSyjgUR/4fhv1ye6B6e+NgiwFpBzM
ITH8/tQ/yo+YevzU/h\n5SOX3aYW3xAi4pfl6Tm+wcpC3FbrtfB8TUNZd90pSWAfEyL+CSCyLpG1Nysq+EBxFl9eED8i/Xq9\n3pv4Kd04VN9FQpi4nrLKPJYWw9GrDKWsvdm/Fh/E5+65qKEmOmJuAHsJpzbsogqI+BWj7IvEVpVJ\nz+ccNYfFn81mWeT54+MjJD7Od03C8QEttoTNZtM2m00u3eYX192D3PyYc/p8nsrTl1n8wWBQkAhn\nrYHoiGlA3FLLrv6p6N4fGyL+D2DXl6jM3Y8sPkiPf/Oa9F8hfqr1FGIZKcvOvfJ8xHlqG+KJ798z\nXx+08TibgYXuQ1/K7FuQuckpCooeOttxahDxfwhlHWRR8wng9/hMei6ciaz9Pnt83/nnz9frdSFg\nx49BdLboID5b9ihll1IZYsDiozgHxP/Pf/5j//zzjw2Hw3D0FY5s2VPy139zUA8Q8U8UqUAfCM4z\n4VOKNl7d5jPET03SWa1WWS7eH8fjcRasw2IlYZQg++BdKpAXHXmPj8EgDw8P9s8//9h///tfGw6H\nhUGX7NJj8EbKlf9bie4h4v8AYGWigRdvb285C4g0GYOFLFarVSHN5gd08A1gH1ff9/jzWq/XoQAm\nztmdjxaGfkTZA/z9ssyC76FnZSHk6cs8lr+tEOerEPErRlRdxiq4ZlaaLmOLxKW/3mJ5pRuQf9eo\nqVQ/O87X63XBzfetxDzskiP1/P598U30XHQcDof2+PiYtdXyvDq+Vp+eE/IQ8SsGE9+PaALxfZoM\nVorLdqNqPHaZvVeAvfYu4uP6osAeovqpoB6781FBDl6fJcL8UAvfMuzPr6+vM+J7mXAvLuILcoQ/\nEPErBqekuHX0+vo6G4+1XC6t3W7bcrnMkd6nu3gb4GWouBfg8vLSlstlVodehrJU3sXFRaZ9H6Xx\nOB+fGnPFgqCRgpAf0e3XYDCwh4eHzOKzWjBrB0YBO5H/D0T8itFoFCWwQfzlcmlmlvsCm/1p6rm4\nuMjlt6ObgBexjCa27rq+KNqN51AnEI3qhsy1jzWkLD5ai1k0AxF4nvDLjxHJj4jvhUXOTfm2Soj4\nFcPv8UF8iE3gxuBJv16vrdlsFsiOI77ccMfx837tU0MQleoy8X3cIdIOiBpsuK3Yj6L2zTKpI7rt\neMJNNMMueg/CH4j4FYNdXSY+SGP2J0Dn21ubzWZSfspb1LI89S54svgagyjrELn1/hxgV5/jHNxC\nzIU3/BjBUFbKZVefif83V959FyJ+xYgsflS1tt1uc6T301pSBINV/U6OelcBS1ke3pfa+mPk6qOc\nFvX07AFEy98QcFPAtib1XoQ/EPErRkR8r5vvCc0y1a1Wq2BtvYQ1w1tbPMd/h5879Hvlwhszyyw9\n3HevGwjiw/r75WfXcWut8vT7Q8SvGJHF85NwzfL5dC7wWSwWSelqPz8uInY0iIIfH/I9RjECnukH\nwvNAC6Tn2OVnq86ttF4eXNZ9f4j4PwAObjHpfSotGoCJyjhfjstBtjI3PNqTc0rwu5Y/ygZwOpBH\nU/OEWk98P+aLo/6RGKZI/zmI+BWDXX3e18ML4Co5donh6nK+nGvisdjq++0Ct/TyMrODWXu8R9/h\nh/dWRnzMrQPBo8WKOVEfvbAfRPyK4V19sz+k900lTPrZbGaDwSDZ7ooja/ZFC2277Bp/fHxkjw9l\n8VONPvtY/NR4MBbN8PX32tt/DiJ+xeA8tlkxvYcvN5N+sVhkpI/64Dudji0WC2u321kjTGp23Xq9\nzhGF03OHsphRExLWPsQvq9xDjYOvKpS1/xxE/IoBUrTb7ZylBznh0oL0vred9eym02ku4NVsNrOS\n2VSe3RcHcVXgod5fRHzc0LBX98RHh91gMEi2BPN+PqrOE/n3h4hfMXx3mo+8c1ONb2mFqi5kqznY\nhddcLpehHj3WarUyM8t5A4ceGOHFPL
gufx+LH1UccvMNPsfoKOwHEb9i7CIYB/hAGJap8tJRPojW\n6XQKk3j4HOIdZsVy4EO9P1wTByexeLpPlKu/urpKNghpL384iPgnCCaPV53tdDqFqTKcBVgul2F+\nH6k+3/EHD+NQefCo+5ALcqLptN5ridx3ufKHhYh/guBcOLYDeN635/rgIHTtogYaeAxmedKzdt8h\nrj0SGuG9PGvc++0K7+FVmHM8iPgnBh8cY9Kz4IaZ5fbQ2DtHKra82u12YZqt7wP47vX7a+KOutvb\n20JNvu+lV9Du+BDxTxBM/larlbOAXPvO+X4/WTZ1jrZdDhYy4Q5x7b4XIZpYC/37yNVXW+3xIeKf\nGPhLz49RZIPFpOd9vB8v7bMDjUYjN7eei2KOscfnmnwQ34+l9q4+v2+R/jgQ8U8QbO0gV8UluK1W\nKxxYybLbkbT129ubmVlu2CYr1xza4sPV53l2t7e3uWBf5OrjdaKjcBiI+CcIFr0wy7fMpnrfsRC5\n5/p9frzdbnMTdtniHwI+2BhZfN915+fX4TMQjgcR/8SwjwhGGTabTSEQ6It1fF6crequ1/cuuE+5\nYV/PQhoo1MHiJhxusVWevjqI+GeOaLwW6+jD4vPwSh5n5cdT7wJc+VQBEarwELlHoQ5uBkx2n8IT\nqoOIf6ZISVtFQzQ8+UF8aOB/hfi++QYrIj728qjeQ0Win1Ir8lcHEf+MESntlBGfp9n6iTefJT72\n8H6xAi5H730EP+qnF6qDiH+GSAX2UmOz2NrzjHoeT71vHz53F/J+HhF8ztN7i4+f8/30h24SEnZD\nxD9jlElrsULvMVx9LwmGvTwTH0U62OPD4nPjjYQ0fgYi/pkiRXo/My/l6n81uGdmOeJ3u92caCb6\n6r3FZ1ffl+VKOqt6iPhniCivz9p60R4/iup7i78P2OJjTj3LY3tXn2vyYfXxOtFRqAYi/pnBu/Re\nT8+79V6xB2OtfYAvsvgRKVlYwyvpREE9RPI5ii/8PPS/cGZgjbyo7342m9l4PC5d0+k0R34/0CNq\nkMHiQRi+QIfde5YEO2QTkHAYiPhnCNTke2391Wpls9ksk+aaTCYF0k8mk0Jaj7X9fSeg18f3U3Ai\n4nPenqfciPinAxH/zMDVeRy4Q/R+Op0mLf1oNLLpdJqba48gX2TxuSoP50x6T3yQv6zPXjgNiPhn\nBuzvWUiDA3fT6TRn7b3Vn81mhVZdL+XlVXI57+5VgDmVB4vvx13hNYTTgYh/ZogsPgfxIrLD2oP4\nfr69D+55a4+AHoiPgF3K1eebg/b4pwkR/8yQsviI2jPp/Q1gNBrZYrFIau6bpXXxsVgiO3L1MQmH\ntfRF/NODiH9mgMVHcI9FNTioF1n98Xhsy+WydKimJ77Xxmdrjhw+W/zr6+twfJaCe6cFEf/EkKqe\n8004ID0q8tjiT6fTXM7eF+6UgWW9PNE7nU6uMIer8rjtNhqEIdKfFkT8E0Wq7Rak9xV5sPggvK/O\n+2yBDu/h+Xhzc2N3d3dZWS5y9hzEi2bWi/inBRH/BFHWfZdqt4XFR3EOxmlz8A6IymVxDqFM1OD7\n/bsnPtfg++m1kUKPcBoQ8U8UqT041+EjsOeJP5vNMuJHlXlmsYSWWZ74vV4vq8HHQhMOW3zo5jHx\nfSGQSH9aEPFPEN7aoxb/sxZ/uVwma/FT2nns6oP4UMe9u7vLtdymXH3p4p8+RPwThSc9N+Ls2uPD\nE0jt8b2LzwT1Fh/u/f39vd3f39vNzU2uaAd7f5+vly7+aUPEPzGkeuyjXvuUxWcdfXb1GZ6YIKzf\n48Pi393d2ePjow2Hw5x+Hlfp+bSd9venCxH/ROF77MsktTzx/bw8tvgpFx/k964+B/RAfMy6R7oP\n5740V2Q/XYj4FWMflRsmup91DyuPdB3OOVfvW3W9wo4v0uFinahIh137fr+fq+TjAh1JaJ0PRPwT\nA+
/jo+ULc7yMlid7FNTjclxfXecbbFj7vmxuvaz7eUHEPzFwLT6q83i/jqg95+p5Lw/SY/lxW2bl\n2vjYu6em3ESkF84PIv6JIdVvjxVZfB6Djf28t/i+155183ih7NYTH0E7Feb8HRDxTwgcyWeLz/p5\nXKDjXX0/HIOzAQAH8bx+HuSykZf3xN+VqtMN4Hwg4p8IuGgnsvgpwUw/HGOz2RTSgSmLD+LDwrMS\nLtRzyojPEOnPCyL+CcDLZXO/vRfa8K4+q+mA+HidaMSWmeWID8vOPfZePWeXxRfpzw8i/g/Dp/fK\n6vGjdlvv6u8ah5Wy+FFBThTVj8gvnB9E/B+Ct8g4+l57VteBqg7v81kfHwG9VKqt0WjkVHI5P89i\nmei1j6ryWHFX5D9fiPg/AL/35vOI8JPJxEajkb2+vtpoNLLJZJLrwGOV3EgSm1e3282p5fBYa36M\n51gjP9VvL/KfH0T8iuGj7f6x39ezjt7r62s2EAPufiSP7WWz+Pzq6io39YbHXnGrLaw+LL+Xyo7I\nL5wPRPwfQDQGCwU3XkATXXej0ShbLKedsvjI03ORTqvVyunjoa/+9vbWbm9vM+Kz4g4W7/cjd99M\nkf1zgoj/A/CNN1yPz65+yuL7IRqr1SokPkftsfyAS/TZY0EeG+49nyPCz4U8XmJLOA+I+BXDW3s/\nA8933fk9/ng8Dmv4PfG9YCbIyxZ/OBxmxH94eLD7+/usCQcWnmW1YfGjdl4R/7wg4v8AUq22iOaz\nZLa3+JPJJFeLz7X52+02NwyDq/JYDht7fFh8iGw8PDxYv98vxAZ4qV7/74CI/wNg/Ty2+Kk+ewzD\neH19tel0WsgEeF187+pzVR5cfS+pdX9/b4+Pj9br9cKafFXs/V0Q8Q+MqCCH4ctx/WJd/JQ2fhnY\n1WfSI2fPuvisj886+cLfDxH/CCiTx+ZqPL/m87mNRiN7fn7OJtsiZYca/F1Ayy0q8kB6n7Pn/Dzv\n3YV6QMQ/AlJNMphr73XyuAHHE3+xWBRm2JchIj5r593c3OQq80B8RebrBRH/wGBL74UyuUAHEXse\nbono/cvLS8Hir9frL1t8uPIo1OEZ9j5aL9QDIv4R4KP2XhobFp+j9SjHxThrDL5Ekc5XXH0Ia/jc\nfWTxRfx6QcQ/Arw0NqffELWHxYeFf3l5yVx8bAO4LPere3wvk826+ClZbOHvh4h/BKSKdLzFh6v/\n+vpqT09P9vv372yUNS8Q/6t7fJ6Iw7r43uIL9YGIf2BEgzCY/NEe/+XlxZ6enux///tfrjIP7bas\nrLMLZXtJkxiPAAAGzUlEQVR8uPosoa09fj0h4n8Su8jnNfNYHQcqucjVYx+Phc47VsuNdPG5uMaf\ns/59tHq9XliGq6h+vSDiHxjI1Xt1XOTux+OxPT8/Z+W3SNl5FR2fDWDSsxY+D7VotVo2HA7t4eGh\nMMPet9ZGI62F+kDEPzBg7bnengN1nK4bj8ehoEaZLn6j8e9EW7jqvpNuOBza4+NjgfhQzvXEj+rv\nhb8fIv6BAGIy8X1bLXfZweIz8VkeO7L2Zv9afJ5th+Adlp9sm8rZi/T1hoh/ADAxUZ23Wq1ssVhk\nKTvO1Y/H42x5i48gnic+N+HA4mOopZfNuru7ywlrRDn7SEVHqA9E/G8iasrxFh/u/dPTk41Go9xk\nW1+dx4E8v8z+7PFB/Ovr65yCDlJ2XkqLXf1Ii09Wv14Q8b+BiPR+j+9TdijQQccd6+P7tF1q7h1S\ndUx89NOjMs9337HFT7XdCvWBiH8glO3x0Xjz+/dvG41GuRHXfOSBGP51AVh87PFB/IeHB/v165fd\n3t4W9v08FgvEN7PcUcSvF0R8h33y9L7zjl1zWHHuvuMA33g8znL7vLxSboqUfvoN99nDzY+GY/Ao\nLIYIX0+I+J8AyB0V2OCIElzo32MPH8242zXDPtK2Q7Udp/O8oCYK
dLwGvodIX1+I+DvgPQCea8eV\neXjMarhI1/k5dyyu6avydg3E8OW2eOxJj1x9Kngn0tcbIn4J/DBLszzxfSPNcrnMtdb6ybYshc0i\nmWUTbf1i0peRv6xAR6QXRPwEItL7SbbosuPlR13B1S+rzvMW3wtm4og0HhM95epHZbl4fUEQ8XfA\nD7Zk4rMENhbr37Orzxbfd+7xPp9dfVhsdt8j0keuvt/jK3IvMET8EkTTbDld55V0WAabJ9v6PX5Z\ngY5ZPMqah2OkSO9n2rO7L9ILDBE/QMrNj1x9FsmM5LMii+9vKHweufqcwttnj+9r8VWWK3iI+AlE\nFhlDMHjwhR9zldrbc66+DKyJz1NwsPz8el+cw/t6tvYK7AkMET9AWZEOBDZSwy254265XGaEj2Sz\nIjJiT8/FOXy8ubmxh4cHu729teFwmKvDj3rstbcXIoj4CaTm2MPNj0ZdpaL4SN35fXxUnddsNjNL\nz+Ou/NgrJj5abjlvH5FfNwABEPEdvLX3EfjI4vNwy8lkklPf4WEYfh/viY+WW5bF5uGW6L7zE3HQ\ngMNqueq6E8og4geIyB8JZnqLj5Ser8OPLD7grXJk8Xmi7XA4LMy7g8VPufkiv+Ah4ieQ0sVnBdzU\nHp+n37KqTmTxfT1+GfF//fqVyWPz8nt8kV3YBRE/gA/qeYsfufq8x0dRDlfm+Wh+RHqO6LOrz223\nw+EwzOWzq4/Xj46CYCbiF7DLzfcjriOL74tyoiIdMyuQfl+LX6ayq9JcYR+I+CXwOXx/I2C3HzeD\n1WqV/b7vhovKcbnQptlsZuIZ0MGHig6i+oPBoPA7u9pvBcFDxD8CUnt4nnLD1pqPiOAjiNfr9ZJ5\nekXtha9CxD8CWDjDk5Q77KLuupubG7u7uwuJH5XhqlBH+ApE/AMjtW9nIQ1fhssLwTyIZnqLL2sv\nHAIi/hHg03PcV89aedFADATzvMXnOnyp5ArfhYh/BOwS0uA6fF+Mg0o9VOZFe/xUua8g7AsR/8CI\novccwGOLH03B4WPK1Y9IL/ILn4GIfwRwBJ+tPQZhQPqaZbG5Dp+9gMji898RhK9AxC9BqqQWajhM\n4MFgYDc3NznpKy+F1W63rd/v54gOt56Jj30/ynHhLShPLxwKIr5D5KpzxR3KaRGEg5RWo/HvaKvl\nclmoquMjfjdag8Egq7/3MlpS0BEOCRE/AJN/u93m3OvLy8ss+s6kx1irt7e3nN4dl9MiuIeqvGiB\n9Jh+w9NtBeFQEPEDMPH9c2zx39/fc6Tv9/u2Wq1y0Xy/8PupBdceBT1crSeLLxwKIr5DRHp+Hhaf\nLT2s+HA4tM1mUyja4bp6BPhSi6W0ZfGFY0HET4DJj3O06l5dXeVIPxgMbLFY2GKxsPf391xlnT9n\nQkeLtwi+AUcWXzgURPwAPk+O7jyuloPl7vf7uS49eAIpFZzUaKyUJLbksYVjoLFrLPQBcPQ/cEj4\nz8Nr30ciHbwA35LL5/su/nl/Lgh7IvzCiPiC8HcjJL4iRoJQQ4j4glBDiPiCUENUEdVXNEoQTgyy\n+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiC\nUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBDiPiCUEOI+IJQQ4j4glBD/B98ADVKMyzD/gAAAABJ\nRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x26a51b65be0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot one image from the evaluation minibatch\n",
"sample_number = 5\n",
"plt.imshow(img_data[sample_number].reshape(28,28), cmap=\"gray_r\")\n",
"plt.axis('off')\n",
"\n",
"img_gt, img_pred = gtlabel[sample_number], pred[sample_number]\n",
"# Note: img_pred is the model's predicted label; the ground-truth label is img_gt\n",
"print(\"Image Label: \", img_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"**Exploration Suggestion**\n",
"- Try exploring how the classifier behaves with different parameters, e.g. changing the `minibatch_size` parameter from 25 to, say, 64 or 128. What happens to the error rate? How does the error compare to the logistic regression classifier?\n",
"- Try increasing the number of sweeps.\n",
"- Try changing the network to reduce the training error rate. When do you see *overfitting* happening?"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}