Updates model driver notebook
This commit is contained in:
Родитель
45de6b6382
Коммит
7c183698f1
|
@ -3,43 +3,31 @@
|
|||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Develop Model Driver"
|
||||
]
|
||||
"source": "# Develop Model Driver"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this notebook, we will develop the API that will call our model. This module initializes the model, transforms the input so that it is in the appropriate format and defines the scoring method that will produce the predictions. The API will expect the input to be in JSON format. Once a request is received, the API will convert the json encoded request body into the image format. There are two main functions in the API. The first function loads the model and returns a scoring function. The second function process the images and uses the first function to score them."
|
||||
]
|
||||
"source": "In this notebook, we will develop the API that will call our model. This module initializes the model, transforms the input so that it is in the appropriate format and defines the scoring method that will produce the predictions. The API will expect the input to be in JSON format. Once a request is received, the API will convert the json encoded request body into the image format. There are two main functions in the API. The first function loads the model and returns a scoring function. The second function process the images and uses the first function to score them."
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"from testing_utilities import img_url_to_json\n",
|
||||
"from pprint import pprint"
|
||||
]
|
||||
"source": "import logging\nfrom testing_utilities import img_url_to_json\nfrom pprint import pprint"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"logging.basicConfig(level=logging.DEBUG)"
|
||||
]
|
||||
"source": "logging.basicConfig(level=logging.DEBUG)"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We use the writefile magic to write the contents of the below cell to driver.py which includes the driver methods."
|
||||
]
|
||||
"source": "We use the writefile magic to write the contents of the below cell to driver.py which includes the driver methods."
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
@ -51,168 +39,39 @@
|
|||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Overwriting driver.py\n"
|
||||
]
|
||||
"text": "Overwriting driver.py\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%writefile driver.py \n",
|
||||
"import base64\n",
|
||||
"import json\n",
|
||||
"import logging\n",
|
||||
"import os\n",
|
||||
"import timeit as t\n",
|
||||
"from io import BytesIO\n",
|
||||
"\n",
|
||||
"import PIL\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torchvision\n",
|
||||
"from PIL import Image\n",
|
||||
"from torchvision import models, transforms\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"_LABEL_FILE = os.getenv(\"LABEL_FILE\", \"synset.txt\")\n",
|
||||
"_NUMBER_RESULTS = 3\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _create_label_lookup(label_path):\n",
|
||||
" with open(label_path, \"r\") as f:\n",
|
||||
" label_list = [l.rstrip() for l in f]\n",
|
||||
"\n",
|
||||
" def _label_lookup(*label_locks):\n",
|
||||
" return [label_list[l] for l in label_locks]\n",
|
||||
"\n",
|
||||
" return _label_lookup\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _load_model():\n",
|
||||
" # Load the model\n",
|
||||
" model = models.resnet152(pretrained=True)\n",
|
||||
" model = model.cuda()\n",
|
||||
" softmax = nn.Softmax(dim=1).cuda()\n",
|
||||
" model = model.eval()\n",
|
||||
"\n",
|
||||
" preprocess_input = transforms.Compose(\n",
|
||||
" [\n",
|
||||
" torchvision.transforms.Resize((224, 224), interpolation=PIL.Image.BICUBIC),\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
|
||||
" ]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def predict_for(image):\n",
|
||||
" image = preprocess_input(image)\n",
|
||||
" with torch.no_grad():\n",
|
||||
" image = image.unsqueeze(0)\n",
|
||||
" image_gpu = image.type(torch.float).cuda()\n",
|
||||
" outputs = model(image_gpu)\n",
|
||||
" pred_proba = softmax(outputs)\n",
|
||||
" return pred_proba.cpu().numpy().squeeze()\n",
|
||||
"\n",
|
||||
" return predict_for\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _base64img_to_pil_image(base64_img_string):\n",
|
||||
" if base64_img_string.startswith(\"b'\"):\n",
|
||||
" base64_img_string = base64_img_string[2:-1]\n",
|
||||
" base64Img = base64_img_string.encode(\"utf-8\")\n",
|
||||
"\n",
|
||||
" # Preprocess the input data\n",
|
||||
" decoded_img = base64.b64decode(base64Img)\n",
|
||||
" img_buffer = BytesIO(decoded_img)\n",
|
||||
"\n",
|
||||
" # Load image with PIL (RGB)\n",
|
||||
" pil_img = Image.open(img_buffer).convert(\"RGB\")\n",
|
||||
" return pil_img\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_scoring_func(label_path=_LABEL_FILE):\n",
|
||||
" logger = logging.getLogger(\"model_driver\")\n",
|
||||
"\n",
|
||||
" start = t.default_timer()\n",
|
||||
" labels_for = _create_label_lookup(label_path)\n",
|
||||
" predict_for = _load_model()\n",
|
||||
" end = t.default_timer()\n",
|
||||
"\n",
|
||||
" loadTimeMsg = \"Model loading time: {0} ms\".format(round((end - start) * 1000, 2))\n",
|
||||
" logger.info(loadTimeMsg)\n",
|
||||
"\n",
|
||||
" def call_model(image, number_results=_NUMBER_RESULTS):\n",
|
||||
" pred_proba = predict_for(image).squeeze()\n",
|
||||
" selected_results = np.flip(np.argsort(pred_proba), 0)[:number_results]\n",
|
||||
" labels = labels_for(*selected_results)\n",
|
||||
" return list(zip(labels, pred_proba[selected_results].astype(np.float64)))\n",
|
||||
"\n",
|
||||
" return call_model\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_model_api():\n",
|
||||
" logger = logging.getLogger(\"model_driver\")\n",
|
||||
" scoring_func = create_scoring_func()\n",
|
||||
"\n",
|
||||
" def process_and_score(images_dict, number_results=_NUMBER_RESULTS):\n",
|
||||
" start = t.default_timer()\n",
|
||||
"\n",
|
||||
" results = {}\n",
|
||||
" for key, base64_img_string in images_dict.items():\n",
|
||||
" rgb_image = _base64img_to_pil_image(base64_img_string)\n",
|
||||
" results[key] = scoring_func(rgb_image, number_results=number_results)\n",
|
||||
"\n",
|
||||
" end = t.default_timer()\n",
|
||||
"\n",
|
||||
" logger.info(\"Predictions: {0}\".format(results))\n",
|
||||
" logger.info(\"Predictions took {0} ms\".format(round((end - start) * 1000, 2)))\n",
|
||||
" return (results, \"Computed in {0} ms\".format(round((end - start) * 1000, 2)))\n",
|
||||
"\n",
|
||||
" return process_and_score\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def version():\n",
|
||||
" return torch.__version__"
|
||||
]
|
||||
"source": "%%writefile driver.py \nimport base64\nimport json\nimport logging\nimport os\nimport timeit as t\nfrom io import BytesIO\n\nimport PIL\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torchvision\nfrom PIL import Image\nfrom torchvision import models, transforms\n\n\n\n_LABEL_FILE = os.getenv(\"LABEL_FILE\", \"synset.txt\")\n_NUMBER_RESULTS = 3\n\n\ndef _create_label_lookup(label_path):\n with open(label_path, \"r\") as f:\n label_list = [l.rstrip() for l in f]\n\n def _label_lookup(*label_locks):\n return [label_list[l] for l in label_locks]\n\n return _label_lookup\n\n\ndef _load_model():\n # Load the model\n model = models.resnet152(pretrained=True)\n model = model.cuda()\n softmax = nn.Softmax(dim=1).cuda()\n model = model.eval()\n\n preprocess_input = transforms.Compose(\n [\n torchvision.transforms.Resize((224, 224), interpolation=PIL.Image.BICUBIC),\n transforms.ToTensor(),\n transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n ]\n )\n\n def predict_for(image):\n image = preprocess_input(image)\n with torch.no_grad():\n image = image.unsqueeze(0)\n image_gpu = image.type(torch.float).cuda()\n outputs = model(image_gpu)\n pred_proba = softmax(outputs)\n return pred_proba.cpu().numpy().squeeze()\n\n return predict_for\n\n\ndef _base64img_to_pil_image(base64_img_string):\n if base64_img_string.startswith(\"b'\"):\n base64_img_string = base64_img_string[2:-1]\n base64Img = base64_img_string.encode(\"utf-8\")\n\n # Preprocess the input data\n decoded_img = base64.b64decode(base64Img)\n img_buffer = BytesIO(decoded_img)\n\n # Load image with PIL (RGB)\n pil_img = Image.open(img_buffer).convert(\"RGB\")\n return pil_img\n\n\ndef create_scoring_func(label_path=_LABEL_FILE):\n logger = logging.getLogger(\"model_driver\")\n\n start = t.default_timer()\n labels_for = _create_label_lookup(label_path)\n predict_for = _load_model()\n end = t.default_timer()\n\n loadTimeMsg = \"Model loading time: {0} ms\".format(round((end - start) * 1000, 2))\n logger.info(loadTimeMsg)\n\n def call_model(image, number_results=_NUMBER_RESULTS):\n pred_proba = predict_for(image).squeeze()\n selected_results = np.flip(np.argsort(pred_proba), 0)[:number_results]\n labels = labels_for(*selected_results)\n return list(zip(labels, pred_proba[selected_results].astype(np.float64)))\n\n return call_model\n\n\ndef get_model_api():\n logger = logging.getLogger(\"model_driver\")\n scoring_func = create_scoring_func()\n\n def process_and_score(images_dict, number_results=_NUMBER_RESULTS):\n start = t.default_timer()\n\n results = {}\n for key, base64_img_string in images_dict.items():\n rgb_image = _base64img_to_pil_image(base64_img_string)\n results[key] = scoring_func(rgb_image, number_results=number_results)\n\n end = t.default_timer()\n\n logger.info(\"Predictions: {0}\".format(results))\n logger.info(\"Predictions took {0} ms\".format(round((end - start) * 1000, 2)))\n return (results, \"Computed in {0} ms\".format(round((end - start) * 1000, 2)))\n\n return process_and_score\n\n\ndef version():\n return torch.__version__"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's test the module."
|
||||
]
|
||||
"source": "Let's test the module."
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We run the file driver.py which will bring everything into the context of the notebook."
|
||||
]
|
||||
"source": "We run the file driver.py which will bring everything into the context of the notebook."
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run driver.py"
|
||||
]
|
||||
"source": "%run driver.py"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We will use the same Lynx image we used ealier to check that our driver works as expected."
|
||||
]
|
||||
"source": "We will use the same Lynx image we used ealier to check that our driver works as expected."
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"IMAGEURL = \"https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Lynx_lynx_poing.jpg/220px-Lynx_lynx_poing.jpg\""
|
||||
]
|
||||
"source": "IMAGEURL = \"https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Lynx_lynx_poing.jpg/220px-Lynx_lynx_poing.jpg\""
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
@ -222,14 +81,10 @@
|
|||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:model_driver:Model loading time: 4070.21 ms\n"
|
||||
]
|
||||
"text": "INFO:model_driver:Model loading time: 3972.62 ms\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"predict_for = get_model_api()"
|
||||
]
|
||||
"source": "predict_for = get_model_api()"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
@ -239,23 +94,10 @@
|
|||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13\n",
|
||||
"DEBUG:PIL.PngImagePlugin:STREAM b'iCCP' 41 292\n",
|
||||
"DEBUG:PIL.PngImagePlugin:iCCP profile name b'ICC Profile'\n",
|
||||
"DEBUG:PIL.PngImagePlugin:Compression method 0\n",
|
||||
"DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 345 65536\n",
|
||||
"INFO:model_driver:Predictions: {'image': [('n02127052 lynx, catamount', 0.9965722560882568), ('n02128757 snow leopard, ounce, Panthera uncia', 0.0013256857637315989), ('n02128385 leopard, Panthera pardus', 0.0009192737634293735)]}\n",
|
||||
"INFO:model_driver:Predictions took 92.28 ms\n"
|
||||
]
|
||||
"text": "DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13\nDEBUG:PIL.PngImagePlugin:STREAM b'iCCP' 41 292\nDEBUG:PIL.PngImagePlugin:iCCP profile name b'ICC Profile'\nDEBUG:PIL.PngImagePlugin:Compression method 0\nDEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 345 65536\nINFO:model_driver:Predictions: {'image': [('n02127052 lynx, catamount', 0.9965722560882568), ('n02128757 snow leopard, ounce, Panthera uncia', 0.0013256857637315989), ('n02128385 leopard, Panthera pardus', 0.0009192737634293735)]}\nINFO:model_driver:Predictions took 84.51 ms\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"jsonimg = img_url_to_json(IMAGEURL)\n",
|
||||
"json_load_img = json.loads(jsonimg)\n",
|
||||
"body = json_load_img[\"input\"]\n",
|
||||
"resp = predict_for(body)"
|
||||
]
|
||||
"source": "jsonimg = img_url_to_json(IMAGEURL)\njson_load_img = json.loads(jsonimg)\nbody = json_load_img[\"input\"]\nresp = predict_for(body)"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
@ -265,33 +107,24 @@
|
|||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'image': [('n02127052 lynx, catamount', 0.9965722560882568),\n",
|
||||
" ('n02128757 snow leopard, ounce, Panthera uncia',\n",
|
||||
" 0.0013256857637315989),\n",
|
||||
" ('n02128385 leopard, Panthera pardus', 0.0009192737634293735)]}\n"
|
||||
]
|
||||
"text": "{'image': [('n02127052 lynx, catamount', 0.9965722560882568),\n ('n02128757 snow leopard, ounce, Panthera uncia',\n 0.0013256857637315989),\n ('n02128385 leopard, Panthera pardus', 0.0009192737634293735)]}\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pprint(resp[0])"
|
||||
]
|
||||
"source": "pprint(resp[0])"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, we can move on to [building our docker image](02_BuildImage.ipynb)."
|
||||
]
|
||||
"source": "Next, we can move on to [building our docker image](02_BuildImage.ipynb)."
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"jupytext_format_version": "1.3",
|
||||
"jupytext_formats": "py:light",
|
||||
"kernelspec": {
|
||||
"display_name": "Python [conda env:AKSDeploymentPytorch]",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "conda-env-AKSDeploymentPytorch-py"
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
|
Загрузка…
Ссылка в новой задаче