{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# CNTK 103 Part A: MNIST Data Loader\n",
    "\n",
    "This tutorial is aimed at readers who are new to CNTK and to machine learning. We assume you have completed or are familiar with CNTK 101 and 102. In this tutorial, you will train a simple feedforward network to recognize handwritten digits. This is the first example where we train and evaluate a neural network model on real-world data.\n",
    "\n",
    "The CNTK 103 tutorial is divided into two parts:\n",
    "- Part A: Familiarize yourself with the [MNIST][] database that will be used later in the tutorial.\n",
    "- [Part B](https://github.com/Microsoft/CNTK/blob/v2.0.beta9.0/Tutorials/CNTK_103B_MNIST_FeedForwardNetwork.ipynb): Use the feedforward classifier from CNTK 102 to classify the digits in the MNIST data set.\n",
    "\n",
    "[MNIST]: http://yann.lecun.com/exdb/mnist/\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Import the relevant modules to be used later\n",
    "from __future__ import print_function\n",
    "import gzip\n",
    "import matplotlib.image as mpimg\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import shutil\n",
    "import struct\n",
    "import sys\n",
    "\n",
    "try:\n",
    "    from urllib.request import urlretrieve\n",
    "except ImportError:\n",
    "    from urllib import urlretrieve\n",
    "\n",
    "# Configure matplotlib for inline plotting\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data download\n",
    "\n",
    "We will download the data to the local machine. The MNIST database is a standard set of handwritten digits that has been widely used for training and testing machine learning algorithms. It has a training set of 60,000 images and a test set of 10,000 images, each 28 x 28 pixels. The set is easy to visualize and to train on with any computer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Functions to load MNIST images and unpack them into train and test sets.\n",
    "# - loadData reads the image data and formats each image as a flat array of 28x28 = 784 pixels\n",
    "# - loadLabels reads the corresponding labels data, 1 for each image\n",
    "# - try_download packs the downloaded image and labels data into a combined format to be read later by\n",
    "#   the CNTK text reader\n",
    "\n",
    "def loadData(src, cimg):\n",
    "    print ('Downloading ' + src)\n",
    "    gzfname, h = urlretrieve(src, './delete.me')\n",
    "    print ('Done.')\n",
    "    try:\n",
    "        with gzip.open(gzfname) as gz:\n",
    "            # Read magic number (IDX headers are big-endian; 0x00000803 marks image data).\n",
    "            n = struct.unpack('>I', gz.read(4))[0]\n",
    "            if n != 0x00000803:\n",
    "                raise Exception('Invalid file: unexpected magic number.')\n",
    "            # Read number of entries.\n",
    "            n = struct.unpack('>I', gz.read(4))[0]\n",
    "            if n != cimg:\n",
    "                raise Exception('Invalid file: expected {0} entries.'.format(cimg))\n",
    "            crow = struct.unpack('>I', gz.read(4))[0]\n",
    "            ccol = struct.unpack('>I', gz.read(4))[0]\n",
    "            if crow != 28 or ccol != 28:\n",
    "                raise Exception('Invalid file: expected 28 rows/cols per image.')\n",
    "            # Read data.\n",
    "            res = np.frombuffer(gz.read(cimg * crow * ccol), dtype=np.uint8)\n",
    "    finally:\n",
    "        os.remove(gzfname)\n",
    "    return res.reshape((cimg, crow * ccol))\n",
    "\n",
    "def loadLabels(src, cimg):\n",
    "    print ('Downloading ' + src)\n",
    "    gzfname, h = urlretrieve(src, './delete.me')\n",
    "    print ('Done.')\n",
    "    try:\n",
    "        with gzip.open(gzfname) as gz:\n",
    "            # Read magic number (big-endian; 0x00000801 marks label data).\n",
    "            n = struct.unpack('>I', gz.read(4))[0]\n",
    "            if n != 0x00000801:\n",
    "                raise Exception('Invalid file: unexpected magic number.')\n",
    "            # Read number of entries.\n",
    "            n = struct.unpack('>I', gz.read(4))[0]\n",
    "            if n != cimg:\n",
    "                raise Exception('Invalid file: expected {0} rows.'.format(cimg))\n",
    "            # Read labels.\n",
    "            res = np.frombuffer(gz.read(cimg), dtype=np.uint8)\n",
    "    finally:\n",
    "        os.remove(gzfname)\n",
    "    return res.reshape((cimg, 1))\n",
    "\n",
    "def try_download(dataSrc, labelsSrc, cimg):\n",
    "    data = loadData(dataSrc, cimg)\n",
    "    labels = loadLabels(labelsSrc, cimg)\n",
    "    return np.hstack((data, labels))\n"
   ]
  },
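  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The IDX files start with a big-endian header: a 4-byte magic number (`0x00000803` for images, `0x00000801` for labels), the number of items and, for images, the number of rows and columns. The cell below is a minimal sketch (assuming Python 3) that checks this layout without a download: it builds a tiny synthetic image file in memory (two made-up 28 x 28 images, not MNIST data) and parses it the way `loadData` does. The helper `parse_idx_images` is hypothetical and not part of CNTK."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gzip\n",
    "import io\n",
    "import struct\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def parse_idx_images(gz_bytes):\n",
    "    # Hypothetical helper: parse a gzipped IDX image file held in memory.\n",
    "    with gzip.GzipFile(fileobj=io.BytesIO(gz_bytes)) as gz:\n",
    "        # All four header fields are 4-byte big-endian unsigned integers.\n",
    "        magic, count, rows, cols = struct.unpack('>IIII', gz.read(16))\n",
    "        if magic != 0x00000803:\n",
    "            raise ValueError('unexpected magic number: {0:#x}'.format(magic))\n",
    "        pixels = np.frombuffer(gz.read(count * rows * cols), dtype=np.uint8)\n",
    "    # One flattened row of rows * cols pixels per image.\n",
    "    return pixels.reshape((count, rows * cols))\n",
    "\n",
    "# Synthetic file with two 28 x 28 images: all zeros, then all 255.\n",
    "header = struct.pack('>IIII', 0x00000803, 2, 28, 28)\n",
    "body = bytes(784) + bytes([255]) * 784\n",
    "synthetic = gzip.compress(header + body)\n",
    "\n",
    "images = parse_idx_images(synthetic)\n",
    "print(images.shape)                      # (2, 784)\n",
    "print(images[0].max(), images[1].min())  # 0 255"
   ]
  },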
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download the data\n",
    "\n",
    "The MNIST data is provided as separate training and test sets. The training set has 60,000 images and the test set has 10,000 images. Let us download the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# URLs for the train image and labels data\n",
    "url_train_image = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'\n",
    "url_train_labels = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'\n",
    "num_train_samples = 60000\n",
    "\n",
    "print(\"Downloading train data\")\n",
    "train = try_download(url_train_image, url_train_labels, num_train_samples)\n",
    "\n",
    "# URLs for the test image and labels data\n",
    "url_test_image = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'\n",
    "url_test_labels = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'\n",
    "num_test_samples = 10000\n",
    "\n",
    "print(\"Downloading test data\")\n",
    "test = try_download(url_test_image, url_test_labels, num_test_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Plot one image from the training set (change sample_number to see others)\n",
    "sample_number = 5001\n",
    "plt.imshow(train[sample_number,:-1].reshape(28,28), cmap=\"gray_r\")\n",
    "plt.axis('off')\n",
    "print(\"Image Label: \", train[sample_number,-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the images\n",
    "\n",
    "Save the images in a local directory. While saving the data, we flatten each image to a vector (the 28x28 image pixels become an array of 784 values), and the labels are written in [1-hot][] encoding (with 10 digits, the label 3 becomes `0001000000`).\n",
    "\n",
    "[1-hot]: https://en.wikipedia.org/wiki/One-hot"
   ]
  },
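  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the encoding concrete, the sketch below applies the same identity-matrix lookup that `savetxt` (defined in the next cell) uses to a made-up row of four pixel values with label 3, and formats it as one line of CNTK text format. Real rows carry 784 pixel values; the toy row is only for readability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Row i of the 10x10 identity matrix is the one-hot vector for digit i.\n",
    "one_hot = np.eye(10, dtype=np.uint8)\n",
    "print(''.join(one_hot[3].astype(str)))  # 0001000000\n",
    "\n",
    "# Toy row: 4 pixel values followed by the label (3).\n",
    "row = np.array([0, 128, 255, 7, 3], dtype=np.uint8)\n",
    "label_str = ' '.join(one_hot[row[-1]].astype(str))\n",
    "feature_str = ' '.join(row[:-1].astype(str))\n",
    "print('|labels {} |features {}'.format(label_str, feature_str))\n",
    "# prints: |labels 0 0 0 1 0 0 0 0 0 0 |features 0 128 255 7"
   ]
  },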
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Save the data files in a format compatible with the CNTK text reader\n",
    "def savetxt(filename, ndarray):\n",
    "    dir_name = os.path.dirname(filename)\n",
    "\n",
    "    if not os.path.exists(dir_name):\n",
    "        os.makedirs(dir_name)\n",
    "\n",
    "    if not os.path.isfile(filename):\n",
    "        print(\"Saving\", filename)\n",
    "        with open(filename, 'w') as f:\n",
    "            # Row i of the 10x10 identity matrix is the one-hot string for digit i\n",
    "            labels = list(map(' '.join, np.eye(10, dtype=np.uint).astype(str)))\n",
    "            for row in ndarray:\n",
    "                row_str = row.astype(str)\n",
    "                label_str = labels[row[-1]]\n",
    "                feature_str = ' '.join(row_str[:-1])\n",
    "                f.write('|labels {} |features {}\\n'.format(label_str, feature_str))\n",
    "    else:\n",
    "        print(\"File already exists\", filename)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Save the train and test files (prefer our default path for the data)\n",
    "data_dir = os.path.join(\"..\", \"Examples\", \"Image\", \"DataSets\", \"MNIST\")\n",
    "if not os.path.exists(data_dir):\n",
    "    data_dir = os.path.join(\"data\", \"MNIST\")\n",
    "\n",
    "print ('Writing train text file...')\n",
    "savetxt(os.path.join(data_dir, \"Train-28x28_cntk_text.txt\"), train)\n",
    "\n",
    "print ('Writing test text file...')\n",
    "savetxt(os.path.join(data_dir, \"Test-28x28_cntk_text.txt\"), test)\n",
    "\n",
    "print('Done')"
   ]
  },
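  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick way to sanity-check the written files is to parse a line back. The sketch below uses a hypothetical helper, `parse_ctf_line` (not a CNTK API), which assumes exactly the `|labels ... |features ...` layout written by `savetxt` and recovers the digit and the pixel values. It runs on a hard-coded example line; to check real output, pass it a line read from `Train-28x28_cntk_text.txt` instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def parse_ctf_line(line):\n",
    "    # Assumes the layout '|labels <10 values> |features <pixel values>'.\n",
    "    label_part, feature_part = line.split('|features')\n",
    "    label_values = label_part.replace('|labels', '').split()\n",
    "    labels = np.array([int(v) for v in label_values], dtype=np.uint8)\n",
    "    features = np.array([int(v) for v in feature_part.split()], dtype=np.uint8)\n",
    "    # The position of the single 1 in the one-hot vector is the digit.\n",
    "    return int(np.argmax(labels)), features\n",
    "\n",
    "line = '|labels 0 0 0 1 0 0 0 0 0 0 |features 0 128 255 7\\n'\n",
    "digit, pixels = parse_ctf_line(line)\n",
    "print(digit, pixels.tolist())  # 3 [0, 128, 255, 7]"
   ]
  },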
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Suggested Explorations**\n",
    "\n",
    "Data manipulations can improve the performance of a machine learning system. I suggest you first use the data generated so far and run the classifier in CNTK 103 Part B. Once you have a baseline from classifying the data in its original form, try different data manipulation techniques to further improve the model.\n",
    "\n",
    "There are several ways to alter the data. CNTK readers automate many of these actions for you. However, to get a feel for how these transforms affect training and test accuracy, I strongly encourage you to try one or more of the following data perturbations:\n",
    "\n",
    "- Shuffle the training data (reorder the rows to create a different permutation). Hint: use `permute_indices = np.random.permutation(train.shape[0])` and index with `train[permute_indices]`. Then run Part B of the tutorial with the newly permuted data.\n",
    "- Adding noise to the data can often reduce [generalization error][]. You can augment the training set by adding noise (generated with NumPy; hint: use `numpy.random`) to the training images.\n",
    "- Distort the images with an [affine transformation][] (translations or rotations).\n",
    "\n",
    "[generalization error]: https://en.wikipedia.org/wiki/Generalization_error\n",
    "[affine transformation]: https://en.wikipedia.org/wiki/Affine_transformation\n"
   ]
  },
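  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three perturbations above can be sketched in NumPy as follows. This is one possible implementation under stated assumptions, not the approach used in Part B: Gaussian noise with `sigma=20` and a one-pixel shift are arbitrary illustrative choices; `shift_images` covers only translation (rotation is left out) and uses `np.roll`, which wraps pixels around the image border; and the functions run on a small random stand-in array so the cell is self-contained. Pass `train` instead of `data` to transform the real training set. No function changes the label column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "\n",
    "# Stand-in for `train`: 5 rows of 784 pixels plus a label column.\n",
    "data = np.hstack((rng.randint(0, 256, size=(5, 784)),\n",
    "                  rng.randint(0, 10, size=(5, 1)))).astype(np.uint8)\n",
    "\n",
    "def shuffle_rows(arr):\n",
    "    # Reorder whole rows so every image keeps its label.\n",
    "    permute_indices = rng.permutation(arr.shape[0])\n",
    "    return arr[permute_indices]\n",
    "\n",
    "def add_noise(arr, sigma=20.0):\n",
    "    # Add Gaussian noise to the pixels only, then clip back to 0..255.\n",
    "    pixels = arr[:, :-1].astype(np.float32)\n",
    "    noisy = np.clip(pixels + rng.normal(0.0, sigma, pixels.shape), 0, 255)\n",
    "    return np.hstack((noisy.astype(np.uint8), arr[:, -1:]))\n",
    "\n",
    "def shift_images(arr, dx=1, dy=0):\n",
    "    # Translate each 28x28 image by (dx, dy); np.roll wraps at the edges.\n",
    "    images = arr[:, :-1].reshape(-1, 28, 28)\n",
    "    shifted = np.roll(np.roll(images, dy, axis=1), dx, axis=2)\n",
    "    return np.hstack((shifted.reshape(-1, 784), arr[:, -1:]))\n",
    "\n",
    "shuffled = shuffle_rows(data)\n",
    "noisy = add_noise(data)\n",
    "shifted = shift_images(data)\n",
    "print(shuffled.shape, noisy.shape, shifted.shape)  # (5, 785) (5, 785) (5, 785)"
   ]
  },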
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}