584 строки
18 KiB
Plaintext
584 строки
18 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Develop Model Driver\n",
|
|
"\n",
|
"In this notebook, we will develop the API that will call our model. This module initializes the model, transforms the input so that it is in the appropriate format and defines the scoring method that will produce the predictions. The API will expect the input to be passed as an image. Once a request is received, the API will load the image, preprocess it, and pass it to the model. There are two main functions in the API: init() and run(). The init() function loads the model and creates the scoring function. The run() function processes the images and uses that scoring function to score them.\n",
|
|
"\n",
|
" Note: Always make sure you don't have any lingering notebooks running (shut down previous notebooks); otherwise you may run into GPU memory issues."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%reload_ext autoreload\n",
|
|
"%autoreload 2\n",
|
|
"%matplotlib inline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from azureml.core import Workspace\n",
|
|
"from azureml.core.model import Model\n",
|
|
"from dotenv import set_key, find_dotenv\n",
|
|
"import logging\n",
|
|
"from testing_utilities import get_auth"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"env_path = find_dotenv(raise_error_if_not_found=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Write and save driver script\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 105,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Overwriting driver.py\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%writefile driver.py\n",
|
|
"\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"import torchvision\n",
|
|
"from torchvision.models.detection.mask_rcnn import MaskRCNN\n",
|
|
"from torchvision.models.detection.backbone_utils import resnet_fpn_backbone\n",
|
|
"import numpy as np\n",
|
|
"from pathlib import *\n",
|
|
"from azureml.contrib.services.aml_request import rawhttp\n",
|
|
"from azureml.core.model import Model\n",
|
|
"from toolz import compose\n",
|
|
"import numpy as np\n",
|
|
"import timeit as t\n",
|
|
"from PIL import Image, ImageOps\n",
|
|
"import logging\n",
|
|
"\n",
|
|
"\n",
|
"def _tensor_to_list(image_tensor):\n",
"    \"\"\"Convert a tensor into nested Python lists via ``tolist()``.\"\"\"\n",
"    as_list = image_tensor.tolist()\n",
"    return as_list\n",
|
|
"\n",
|
"def _image_ref_to_pil_image(image_ref):\n",
"    \"\"\"Open ``image_ref`` with PIL and return it converted to RGB mode.\"\"\"\n",
"    opened = Image.open(image_ref)\n",
"    rgb_image = opened.convert(\"RGB\")\n",
"    return rgb_image\n",
|
|
"\n",
|
|
"\n",
|
"def _pil_to_tensor(pil_image):\n",
"    \"\"\"Cast the PIL image to a float32 array, convert it with ``to_tensor`` and divide by 255.\"\"\"\n",
"    pixels = np.array(pil_image, dtype=np.float32)\n",
"    to_tensor = torchvision.transforms.functional.to_tensor\n",
"    return to_tensor(pixels) / 255\n",
|
|
"\n",
|
|
"\n",
|
"def _create_scoring_func():\n",
"    \"\"\"Load the Mask R-CNN (ResNet-50 FPN) model and return a scoring closure.\n",
"\n",
"    The weights file is resolved through ``Model.get_model_path`` and loaded\n",
"    onto the CPU, because the model is never moved to another device here.\n",
"\n",
"    Returns:\n",
"        A function that takes a list of image tensors and returns one dict per\n",
"        image holding every model output except ``masks``, with the tensors\n",
"        converted to plain Python lists.\n",
"    \"\"\"\n",
"    logger = logging.getLogger(\"model_driver\")\n",
"    start = t.default_timer()\n",
"    model_name = \"maskrcnn_resnet50_model\"\n",
"    checkpoint = Model.get_model_path(model_name)\n",
"\n",
"    # The architecture has to be rebuilt before the saved state dict can be applied.\n",
"    num_classes = 91\n",
"    pretrained_backbone = False\n",
"    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)\n",
"    model = MaskRCNN(backbone, num_classes)\n",
"\n",
"    # map_location keeps a checkpoint that was saved from GPU tensors loadable\n",
"    # on a CPU-only host.\n",
"    state_dict = torch.load(checkpoint, map_location=\"cpu\")\n",
"    model.load_state_dict(state_dict)\n",
"    model.eval()\n",
"    end = t.default_timer()\n",
"\n",
"    loadTimeMsg = \"Model loading time: {0} ms\".format(round((end - start) * 1000, 2))\n",
"    logger.info(loadTimeMsg)\n",
"\n",
"    def call_model(img_array_list):\n",
"        \"\"\"Run inference without gradients and return list-valued predictions.\"\"\"\n",
"        with torch.no_grad():\n",
"            preds = model(img_array_list)\n",
"        # Filter instead of deleting in place: the model output is left untouched\n",
"        # and a prediction without a 'masks' entry no longer raises KeyError.\n",
"        return [\n",
"            {\n",
"                key: _tensor_to_list(value)\n",
"                for key, value in pred.items()\n",
"                if key != \"masks\"\n",
"            }\n",
"            for pred in preds\n",
"        ]\n",
"\n",
"    return call_model\n",
|
|
"\n",
|
|
"\n",
|
"def get_model_api():\n",
"    \"\"\"Build the scoring function once and return the request-level scorer.\"\"\"\n",
"    logger = logging.getLogger(\"model_driver\")\n",
"    scoring_func = _create_scoring_func()\n",
"\n",
"    def process_and_score(images_dict):\n",
"        \"\"\"Score every image in ``images_dict``.\n",
"\n",
"        Returns a tuple of the prediction list and a timing message.\n",
"        \"\"\"\n",
"        start = t.default_timer()\n",
"        logger.info(\"Scoring {} images\".format(len(images_dict)))\n",
"        transform_input = compose(_pil_to_tensor, _image_ref_to_pil_image)\n",
"        # Only the values reach the model, in the mapping's iteration order.\n",
"        image_tensors = [transform_input(img_ref) for _, img_ref in images_dict.items()]\n",
"        preds = scoring_func(image_tensors)\n",
"        elapsed_ms = round((t.default_timer() - start) * 1000, 2)\n",
"\n",
"        logger.info(\"Predictions: {0}\".format(preds))\n",
"        logger.info(\"Predictions took {0} ms\".format(elapsed_ms))\n",
"        return (preds, \"Computed in {0} ms\".format(elapsed_ms))\n",
"\n",
"    return process_and_score\n",
|
|
"\n",
|
|
"\n",
|
"def init():\n",
"    \"\"\"Create the scoring function and store it in the module-level ``process_and_score``.\"\"\"\n",
"    global process_and_score\n",
"    process_and_score = get_model_api()\n",
|
|
"\n",
|
|
"\n",
|
"@rawhttp\n",
"def run(request):\n",
"    \"\"\"Score the files attached to ``request`` using the module-level ``process_and_score``.\"\"\"\n",
"    files = request.files\n",
"    return process_and_score(files)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Test the Driver"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 110,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"logging.basicConfig(level=logging.DEBUG)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 106,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%run driver.py"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Get workspace\n",
|
|
"# Load existing workspace from the config file info.\n",
|
|
"\n",
|
|
"ws = Workspace.from_config(auth=get_auth())\n",
|
|
"print(ws.name, ws.resource_group, ws.location, ws.subscription_id, sep=\"\\n\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model_path = Model.get_model_path(\"maskrcnn_resnet50_model\", _workspace=ws)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'azureml-models/maskrcnn_resnet50_model/1/maskrcnn_resnet50.pth'"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model_path"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
"## Test the driver\n",
|
|
"We test the driver by passing data."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(1920, 1080)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# Import Image here so this cell does not rely on names leaked by `%run driver.py`.\n",
"from PIL import Image\n",
"\n",
"img_path = \"./test_image.jpg\"\n",
"# Open the file once and reuse the handle for the size check.\n",
"img = Image.open(img_path)\n",
"print(img.size)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 107,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Always make sure you don't have any lingering notebooks running. Otherwise it may cause GPU memory issue.\n",
|
|
"process_and_score = get_model_api()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 108,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Close the file handle once scoring is done instead of leaking it.\n",
"with open(img_path, \"rb\") as image_file:\n",
"    resp = process_and_score({\"test\": image_file})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 109,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"([{'boxes': [[678.3322143554688,\n",
|
|
" 680.0374145507812,\n",
|
|
" 784.1846923828125,\n",
|
|
" 993.3047485351562],\n",
|
|
" [136.69444274902344, 633.5426635742188, 274.71533203125, 926.4990234375],\n",
|
|
" [233.74720764160156,\n",
|
|
" 581.7813720703125,\n",
|
|
" 382.95440673828125,\n",
|
|
" 861.6013793945312],\n",
|
|
" [1395.167236328125,\n",
|
|
" 440.8146057128906,\n",
|
|
" 1538.195556640625,\n",
|
|
" 639.5924682617188],\n",
|
|
" [1422.93359375, 579.0057373046875, 1476.787353515625, 638.4216918945312],\n",
|
|
" [699.564697265625, 731.795166015625, 745.90185546875, 826.1213989257812],\n",
|
|
" [925.3984375, 340.40972900390625, 1026.978515625, 490.4237060546875],\n",
|
|
" [1003.208740234375,\n",
|
|
" 397.6317443847656,\n",
|
|
" 1058.7530517578125,\n",
|
|
" 498.1560363769531],\n",
|
|
" [529.6865844726562,\n",
|
|
" 275.9363098144531,\n",
|
|
" 913.7566528320312,\n",
|
|
" 472.7071533203125],\n",
|
|
" [298.8797912597656,\n",
|
|
" 584.5956420898438,\n",
|
|
" 382.6814880371094,\n",
|
|
" 807.8283081054688],\n",
|
|
" [933.8543090820312,\n",
|
|
" 439.1072082519531,\n",
|
|
" 1008.7671508789062,\n",
|
|
" 496.53955078125],\n",
|
|
" [4.950125217437744,\n",
|
|
" 567.5128784179688,\n",
|
|
" 73.67479705810547,\n",
|
|
" 749.3277587890625],\n",
|
|
" [945.1486206054688,\n",
|
|
" 391.22332763671875,\n",
|
|
" 1017.757080078125,\n",
|
|
" 496.1236267089844],\n",
|
|
" [947.085693359375,\n",
|
|
" 444.74627685546875,\n",
|
|
" 981.1513061523438,\n",
|
|
" 489.8131408691406],\n",
|
|
" [689.1828002929688,\n",
|
|
" 730.1939086914062,\n",
|
|
" 756.8470458984375,\n",
|
|
" 860.362548828125],\n",
|
|
" [1405.2069091796875,\n",
|
|
" 512.3316650390625,\n",
|
|
" 1485.7989501953125,\n",
|
|
" 635.4192504882812],\n",
|
|
" [686.844482421875, 729.1159057617188, 726.2559814453125, 829.8857421875],\n",
|
|
" [682.5338745117188,\n",
|
|
" 732.0889282226562,\n",
|
|
" 707.0296630859375,\n",
|
|
" 832.9822387695312],\n",
|
|
" [526.6405639648438,\n",
|
|
" 273.91778564453125,\n",
|
|
" 924.1195068359375,\n",
|
|
" 474.7964782714844],\n",
|
|
" [117.77389526367188,\n",
|
|
" 509.304931640625,\n",
|
|
" 161.4928741455078,\n",
|
|
" 613.563232421875],\n",
|
|
" [938.2327270507812,\n",
|
|
" 329.28369140625,\n",
|
|
" 1113.7261962890625,\n",
|
|
" 497.2240295410156],\n",
|
|
" [259.53515625, 442.64373779296875, 385.9668273925781, 530.3612670898438],\n",
|
|
" [247.56719970703125,\n",
|
|
" 616.0457153320312,\n",
|
|
" 381.5702819824219,\n",
|
|
" 851.794189453125],\n",
|
|
" [954.4630737304688,\n",
|
|
" 456.65557861328125,\n",
|
|
" 1008.6592407226562,\n",
|
|
" 496.7553405761719],\n",
|
|
" [125.06944274902344,\n",
|
|
" 539.5031127929688,\n",
|
|
" 162.098388671875,\n",
|
|
" 612.8021240234375],\n",
|
|
" [1333.5140380859375,\n",
|
|
" 478.6513977050781,\n",
|
|
" 1366.9879150390625,\n",
|
|
" 566.387939453125],\n",
|
|
" [535.1708374023438,\n",
|
|
" 278.2537841796875,\n",
|
|
" 910.2845458984375,\n",
|
|
" 466.73895263671875],\n",
|
|
" [7.870540618896484, 544.68701171875, 135.5785369873047, 738.8515625],\n",
|
|
" [1461.7086181640625,\n",
|
|
" 445.4164123535156,\n",
|
|
" 1540.7529296875,\n",
|
|
" 607.0694580078125],\n",
|
|
" [868.0377807617188,\n",
|
|
" 74.55133819580078,\n",
|
|
" 887.7618408203125,\n",
|
|
" 109.45865631103516],\n",
|
|
" [308.97320556640625,\n",
|
|
" 655.6863403320312,\n",
|
|
" 378.05908203125,\n",
|
|
" 831.0111083984375],\n",
|
|
" [1060.7030029296875,\n",
|
|
" 296.0729064941406,\n",
|
|
" 1170.833251953125,\n",
|
|
" 345.16497802734375],\n",
|
|
" [910.4449462890625,\n",
|
|
" 338.066162109375,\n",
|
|
" 985.2666625976562,\n",
|
|
" 471.2206115722656],\n",
|
|
" [831.786865234375, 206.3505401611328, 860.1435546875, 239.9379425048828],\n",
|
|
" [680.18701171875, 908.08837890625, 723.345947265625, 949.937744140625],\n",
|
|
" [0.0, 815.1854858398438, 218.9480438232422, 1063.6282958984375],\n",
|
|
" [691.5162963867188,\n",
|
|
" 736.5338134765625,\n",
|
|
" 759.6444702148438,\n",
|
|
" 898.4368286132812],\n",
|
|
" [281.6488342285156, 644.4691772460938, 366.0682373046875, 779.4921875],\n",
|
|
" [905.4010009765625,\n",
|
|
" 444.86456298828125,\n",
|
|
" 957.723388671875,\n",
|
|
" 469.62005615234375],\n",
|
|
" [957.407958984375,\n",
|
|
" 453.3540344238281,\n",
|
|
" 1008.3375244140625,\n",
|
|
" 493.2886962890625],\n",
|
|
" [246.2247314453125,\n",
|
|
" 616.4740600585938,\n",
|
|
" 387.1805419921875,\n",
|
|
" 850.2962036132812],\n",
|
|
" [200.32151794433594,\n",
|
|
" 660.9036254882812,\n",
|
|
" 213.2830047607422,\n",
|
|
" 685.2865600585938],\n",
|
|
" [684.9480590820312, 927.5275268554688, 764.8531494140625, 994.39892578125],\n",
|
|
" [485.0767822265625,\n",
|
|
" 279.13385009765625,\n",
|
|
" 944.2658081054688,\n",
|
|
" 476.6548156738281]],\n",
|
|
" 'labels': [1,\n",
|
|
" 1,\n",
|
|
" 1,\n",
|
|
" 64,\n",
|
|
" 86,\n",
|
|
" 27,\n",
|
|
" 64,\n",
|
|
" 62,\n",
|
|
" 72,\n",
|
|
" 1,\n",
|
|
" 64,\n",
|
|
" 72,\n",
|
|
" 64,\n",
|
|
" 64,\n",
|
|
" 27,\n",
|
|
" 64,\n",
|
|
" 27,\n",
|
|
" 27,\n",
|
|
" 15,\n",
|
|
" 1,\n",
|
|
" 64,\n",
|
|
" 15,\n",
|
|
" 19,\n",
|
|
" 51,\n",
|
|
" 62,\n",
|
|
" 1,\n",
|
|
" 33,\n",
|
|
" 72,\n",
|
|
" 64,\n",
|
|
" 10,\n",
|
|
" 2,\n",
|
|
" 15,\n",
|
|
" 64,\n",
|
|
" 11,\n",
|
|
" 41,\n",
|
|
" 15,\n",
|
|
" 31,\n",
|
|
" 27,\n",
|
|
" 62,\n",
|
|
" 47,\n",
|
|
" 18,\n",
|
|
" 43,\n",
|
|
" 41,\n",
|
|
" 7],\n",
|
|
" 'scores': [0.9994587302207947,\n",
|
|
" 0.9982571005821228,\n",
|
|
" 0.9812719225883484,\n",
|
|
" 0.9731340408325195,\n",
|
|
" 0.835928201675415,\n",
|
|
" 0.7480331063270569,\n",
|
|
" 0.7479689121246338,\n",
|
|
" 0.5302860140800476,\n",
|
|
" 0.4127267003059387,\n",
|
|
" 0.3182488679885864,\n",
|
|
" 0.2993265688419342,\n",
|
|
" 0.289673775434494,\n",
|
|
" 0.22936685383319855,\n",
|
|
" 0.19749070703983307,\n",
|
|
" 0.19524291157722473,\n",
|
|
" 0.19428908824920654,\n",
|
|
" 0.19411492347717285,\n",
|
|
" 0.1929921954870224,\n",
|
|
" 0.1731988489627838,\n",
|
|
" 0.15991298854351044,\n",
|
|
" 0.14675544202327728,\n",
|
|
" 0.14254669845104218,\n",
|
|
" 0.1270696073770523,\n",
|
|
" 0.11369671672582626,\n",
|
|
" 0.10950600355863571,\n",
|
|
" 0.10708904266357422,\n",
|
|
" 0.10192873328924179,\n",
|
|
" 0.10060916095972061,\n",
|
|
" 0.09780096262693405,\n",
|
|
" 0.09209641069173813,\n",
|
|
" 0.08996262401342392,\n",
|
|
" 0.08869230002164841,\n",
|
|
" 0.07615663856267929,\n",
|
|
" 0.07077847421169281,\n",
|
|
" 0.07028283178806305,\n",
|
|
" 0.06470400840044022,\n",
|
|
" 0.06275784969329834,\n",
|
|
" 0.06206080690026283,\n",
|
|
" 0.06033644452691078,\n",
|
|
" 0.05446040257811546,\n",
|
|
" 0.05442744493484497,\n",
|
|
" 0.05301089957356453,\n",
|
|
" 0.05052192881703377,\n",
|
|
" 0.05008789151906967]}],\n",
|
|
" 'Computed in 3093.76 ms')"
|
|
]
|
|
},
|
|
"execution_count": 109,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"resp"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Next we will proceed with notebook [033_BuildImageForMLModule.ipynb](033_BuildImageForMLModule.ipynb)."
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|