{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Develop Model Driver\n",
    "In this notebook we will develop the API that will call our model. We need it to initialise the model and transform the input from the Flask app so that it is in the appropriate format to call the model. We expect the input to be JSON with the image encoded as a base64 string. The code below uses the `%%writefile` magic to write the contents of the cell to the file `driver.py`."
   ]
  },
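  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration (a sketch added here, not part of the original pipeline), the driver below expects a payload whose `input` field maps image names to base64-encoded image strings; the exact string format produced by `img_url_to_json` in `testing_utilities` may differ slightly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of the expected payload shape (illustrative assumption only).\n",
    "# A dummy in-memory image is used so the example has no external dependencies.\n",
    "import base64\n",
    "import json\n",
    "from io import BytesIO\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "buffer = BytesIO()\n",
    "Image.new('RGB', (224, 224)).save(buffer, format='JPEG')\n",
    "encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')\n",
    "\n",
    "example_payload = json.dumps({'input': {'image': encoded}})\n",
    "print(example_payload[:80] + '...')"
   ]
  },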
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "from testing_utilities import img_url_to_json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting driver.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile driver.py\n",
    "import base64\n",
    "import json\n",
    "import logging\n",
    "import os\n",
    "import timeit as t\n",
    "from io import BytesIO\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from PIL import Image, ImageOps\n",
    "from tensorflow.contrib.slim.nets import resnet_v1\n",
    "\n",
    "_MODEL_FILE = os.getenv('MODEL_FILE', \"resnet_v1_152.ckpt\")\n",
    "_LABEL_FILE = os.getenv('LABEL_FILE', \"synset.txt\")\n",
    "_NUMBER_RESULTS = 3\n",
    "\n",
    "\n",
    "def _create_label_lookup(label_path):\n",
    "    # Build a lookup from class indices to the human-readable labels in the synset file\n",
    "    with open(label_path, 'r') as f:\n",
    "        label_list = [l.rstrip() for l in f]\n",
    "\n",
    "    def _label_lookup(*label_indices):\n",
    "        return [label_list[l] for l in label_indices]\n",
    "\n",
    "    return _label_lookup\n",
    "\n",
    "\n",
    "def _load_tf_model(checkpoint_file):\n",
    "    # Placeholder for a batch of 224x224 RGB images\n",
    "    input_tensor = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input_image')\n",
    "\n",
    "    # Load the model\n",
    "    sess = tf.Session()\n",
    "    arg_scope = resnet_v1.resnet_arg_scope()\n",
    "    with tf.contrib.slim.arg_scope(arg_scope):\n",
    "        logits, _ = resnet_v1.resnet_v1_152(input_tensor, num_classes=1000, is_training=False, reuse=tf.AUTO_REUSE)\n",
    "    probabilities = tf.nn.softmax(logits)\n",
    "\n",
    "    saver = tf.train.Saver()\n",
    "    saver.restore(sess, checkpoint_file)\n",
    "\n",
    "    def predict_for(image):\n",
    "        pred_proba = sess.run(probabilities, feed_dict={input_tensor: image})\n",
    "        return pred_proba\n",
    "\n",
    "    return predict_for\n",
    "\n",
    "\n",
    "def _base64img_to_numpy(base64_img_string):\n",
    "    if base64_img_string.startswith('b\\''):\n",
    "        base64_img_string = base64_img_string[2:-1]\n",
    "    base64Img = base64_img_string.encode('utf-8')\n",
    "\n",
    "    # Decode the base64 string into image bytes\n",
    "    decoded_img = base64.b64decode(base64Img)\n",
    "    img_buffer = BytesIO(decoded_img)\n",
    "\n",
    "    # Load image with PIL (RGB) and resize to the network input size\n",
    "    pil_img = Image.open(img_buffer).convert('RGB')\n",
    "    pil_img = ImageOps.fit(pil_img, (224, 224), Image.ANTIALIAS)\n",
    "    return np.array(pil_img, dtype=np.float32)\n",
    "\n",
    "\n",
    "def create_scoring_func(model_path=_MODEL_FILE, label_path=_LABEL_FILE):\n",
    "    logger = logging.getLogger(\"model_driver\")\n",
    "\n",
    "    start = t.default_timer()\n",
    "    labels_for = _create_label_lookup(label_path)\n",
    "    predict_for = _load_tf_model(model_path)\n",
    "    end = t.default_timer()\n",
    "\n",
    "    loadTimeMsg = \"Model loading time: {0} ms\".format(round((end-start)*1000, 2))\n",
    "    logger.info(loadTimeMsg)\n",
    "\n",
    "    def call_model(image_array, number_results=_NUMBER_RESULTS):\n",
    "        pred_proba = predict_for(image_array).squeeze()\n",
    "        selected_results = np.flip(np.argsort(pred_proba), 0)[:number_results]\n",
    "        labels = labels_for(*selected_results)\n",
    "        return list(zip(labels, pred_proba[selected_results].astype(np.float64)))\n",
    "    return call_model\n",
    "\n",
    "\n",
    "def get_model_api():\n",
    "    logger = logging.getLogger(\"model_driver\")\n",
    "    scoring_func = create_scoring_func()\n",
    "\n",
    "    def process_and_score(images_dict, number_results=_NUMBER_RESULTS):\n",
    "        start = t.default_timer()\n",
    "\n",
    "        results = {}\n",
    "        for key, base64_img_string in images_dict.items():\n",
    "            rgb_image = _base64img_to_numpy(base64_img_string)\n",
    "            batch_image = np.expand_dims(rgb_image, 0)\n",
    "            results[key] = scoring_func(batch_image, number_results=number_results)\n",
    "\n",
    "        end = t.default_timer()\n",
    "\n",
    "        logger.info(\"Predictions: {0}\".format(results))\n",
    "        logger.info(\"Predictions took {0} ms\".format(round((end-start)*1000, 2)))\n",
    "        return (results, 'Computed in {0} ms'.format(round((end-start)*1000, 2)))\n",
    "    return process_and_score\n",
    "\n",
    "\n",
    "def version():\n",
    "    return tf.__version__\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "logging.basicConfig(level=logging.DEBUG)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We run the file `driver.py`, which will bring everything it defines into the context of the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run driver.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use the same Lynx image we used earlier to check that our driver works as expected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGEURL = \"https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Lynx_lynx_poing.jpg/220px-Lynx_lynx_poing.jpg\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "jsonimg = img_url_to_json(IMAGEURL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "json_lod = json.loads(jsonimg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from resnet_v1_152.ckpt\n",
      "INFO:model_driver:Model loading time: 17208.69 ms\n"
     ]
    }
   ],
   "source": [
    "predict_for = get_model_api()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13\n",
      "DEBUG:PIL.PngImagePlugin:STREAM b'iCCP' 41 292\n",
      "DEBUG:PIL.PngImagePlugin:iCCP profile name b'ICC Profile'\n",
      "DEBUG:PIL.PngImagePlugin:Compression method 0\n",
      "DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 345 65536\n",
      "INFO:model_driver:Predictions: {'image': [('n02127052 lynx, catamount', 0.9974517226219177), ('n02128385 leopard, Panthera pardus', 0.0015077503630891442), ('n02128757 snow leopard, ounce, Panthera uncia', 0.0005164773901924491)]}\n",
      "INFO:model_driver:Predictions took 1916.85 ms\n"
     ]
    }
   ],
   "source": [
    "output = predict_for(json_lod['input'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output of our prediction function serialises to the JSON that will be returned to our Flask app. It looks like our model predicted Lynx with over 99% probability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[{\"image\": [[\"n02127052 lynx, catamount\", 0.9974517226219177], [\"n02128385 leopard, Panthera pardus\", 0.0015077503630891442], [\"n02128757 snow leopard, ounce, Panthera uncia\", 0.0005164773901924491]]}, \"Computed in 1916.85 ms\"]'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "json.dumps(output)"
   ]
  },
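  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustrative sketch only (the real Flask app is developed separately, and the route name and request shape here are assumptions), a minimal endpoint wrapping the driver could look like the following."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch of a Flask route wrapping the driver; not part of the original app.\n",
    "# Assumes Flask is installed and the payload is {'input': {<name>: <base64 image string>}}.\n",
    "from flask import Flask, jsonify, request\n",
    "\n",
    "app = Flask(__name__)\n",
    "scoring = get_model_api()  # load the model once, when the app starts\n",
    "\n",
    "\n",
    "@app.route('/score', methods=['POST'])\n",
    "def score():\n",
    "    images_dict = request.get_json()['input']\n",
    "    results, timing = scoring(images_dict)\n",
    "    return jsonify({'result': results, 'timing': timing})"
   ]
  },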
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can move on to [building our Docker image](02_BuildImage.ipynb)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:AKSDeployment]",
   "language": "python",
   "name": "conda-env-AKSDeployment-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}