InnerEye-Inference/app.py

207 строки
8.5 KiB
Python

# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from enum import Enum
import logging
from pathlib import Path
import sys
import tempfile
from typing import Any, Dict, Optional
from azureml._restclient.constants import RunStatus
from azureml._restclient.exceptions import ServiceException
from azureml.core import Workspace
from azureml.exceptions import WebserviceException
from flask import Flask, Response, make_response, jsonify, Request, request
from flask_injector import FlaskInjector
from injector import inject
from memory_tempfile import MemoryTempfile
from azure_config import AzureConfig
from configure import configure, API_AUTH_SECRET_HEADER_NAME, API_AUTH_SECRET
from submit_for_inference import DEFAULT_RESULT_IMAGE_NAME, submit_for_inference, SubmitForInferenceConfig
# Flask application object; the @app.route decorators below register the REST handlers on it.
app = Flask(__name__)
# Run statuses for which download_result answers 202 ACCEPTED (running + post-processing).
RUNNING_OR_POST_PROCESSING = RunStatus.get_running_statuses() + RunStatus.get_post_processing_statuses()
# Configure the root logger: DEBUG and above are written to stdout with a timestamped format.
root = logging.getLogger()
root.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
root.addHandler(handler)
# HTTP REST status codes.
class HTTP_STATUS_CODE(Enum):
    """
    HTTP status codes returned by this API.

    The member value is the numeric status sent on the wire. For error responses,
    make_error_response also reports the member name as 'code' and the value as 'status'.
    """
    # Success codes.
    OK = 200
    CREATED = 201
    ACCEPTED = 202
    # Error codes; all except ACCEPTED/OK/CREATED have an entry in ERROR_MESSAGES.
    BAD_REQUEST = 400
    UNAUTHORIZED = 401
    FORBIDDEN = 403
    NOT_FOUND = 404
    INTERNAL_SERVER_ERROR = 500
# HTTP REST error messages, to be formatted as JSON.
# One entry per error status code; each entry holds a human-readable 'detail'
# sentence and a short machine-friendly 'title'.
ERROR_MESSAGES: Dict[HTTP_STATUS_CODE, Any] = {
    HTTP_STATUS_CODE.BAD_REQUEST: dict(
        detail='Input file is not in correct format.',
        title='InvalidInput',
    ),
    # 401: the authentication header is missing altogether.
    HTTP_STATUS_CODE.UNAUTHORIZED: dict(
        detail=(
            'Server failed to authenticate the request. '
            f'Make sure the value of the {API_AUTH_SECRET_HEADER_NAME} header is populated.'
        ),
        title='NoAuthenticationInformation',
    ),
    # 403: the authentication header is present but its value is wrong.
    HTTP_STATUS_CODE.FORBIDDEN: dict(
        detail=(
            'Server failed to authenticate the request. '
            f'Make sure the value of the {API_AUTH_SECRET_HEADER_NAME} header is correct.'
        ),
        title='AuthenticationFailed',
    ),
    HTTP_STATUS_CODE.NOT_FOUND: dict(
        detail='The specified resource does not exist.',
        title='ResourceNotFound',
    ),
    HTTP_STATUS_CODE.INTERNAL_SERVER_ERROR: dict(
        detail='The server encountered an internal error. Please retry the request.',
        title='InternalError',
    ),
}
class ERROR_EXTRA_DETAILS(Enum):
    """
    Optional qualifiers that narrow down the cause of an error response.

    The member value is the string that make_error_response places in the
    'extra_details' field of the JSON error body.
    """
    # start_model: the WebserviceException message started with 'ModelNotFound'.
    INVALID_MODEL_ID = 'InvalidModelId'
    # download_result: the failed run's driver log mentions zipfile.BadZipFile.
    INVALID_ZIP_FILE = 'InvalidZipFile'
    # download_result: the run status is CANCELED.
    RUN_CANCELLED = 'RunCancelled'
    # download_result: looking up the run raised a ServiceException with status 404.
    INVALID_RUN_ID = 'InvalidRunId'
def make_error_response(error_code: HTTP_STATUS_CODE, extra_details: Optional[ERROR_EXTRA_DETAILS] = None) -> Response:
    """
    Format a Response object for an error_code.

    The JSON body contains the 'detail' and 'title' from ERROR_MESSAGES plus 'code' (the
    status name), 'status' (the numeric status) and, when given, 'extra_details'.

    :param error_code: Error code. Must have an entry in ERROR_MESSAGES.
    :param extra_details: Optional, any further information.
    :return: Flask Response object with JSON error message.
    :raises KeyError: If error_code has no entry in ERROR_MESSAGES.
    """
    # Work on a copy: ERROR_MESSAGES is shared module state, and writing into it directly
    # would make 'extra_details' from one request leak into every later response for the
    # same status code.
    error_message = dict(ERROR_MESSAGES[error_code])
    error_message['code'] = error_code.name
    error_message['status'] = error_code.value
    if extra_details is not None:
        error_message['extra_details'] = extra_details.value
    return make_response(jsonify(error_message), error_code.value)
def is_authenticated_request(req: Request) -> Optional[Response]:
    """
    Check request is authenticated.

    If API_AUTH_SECRET_HEADER_NAME is not in request headers then return 401.
    If API_AUTH_SECRET_HEADER_NAME is in request headers but incorrect then return 403.
    Else return none.

    :param req: Flask request object.
    :return: Response if error else None.
    """
    # Local import keeps this block self-contained; hmac is part of the standard library.
    import hmac

    if API_AUTH_SECRET_HEADER_NAME not in req.headers:
        return make_error_response(HTTP_STATUS_CODE.UNAUTHORIZED)

    supplied_secret = req.headers[API_AUTH_SECRET_HEADER_NAME]
    expected_secret = API_AUTH_SECRET
    # A secret that is not a str (e.g. not configured) can never match a header value,
    # which is the same outcome the plain inequality test produced.
    if not isinstance(expected_secret, str):
        return make_error_response(HTTP_STATUS_CODE.FORBIDDEN)
    # Compare in constant time so the response time does not reveal how many leading
    # characters of a guessed secret are correct. Bytes are used because compare_digest
    # rejects non-ASCII str arguments.
    if not hmac.compare_digest(supplied_secret.encode('utf-8'), expected_secret.encode('utf-8')):
        return make_error_response(HTTP_STATUS_CODE.FORBIDDEN)
    return None
@app.route("/v1/ping", methods=['GET'])
def ping() -> Response:
    """
    Authenticated health-check endpoint.

    :return: The authentication error response if the request is not authenticated,
        otherwise an empty 200 response.
    """
    auth_failure = is_authenticated_request(request)
    if auth_failure is None:
        return make_response("", HTTP_STATUS_CODE.OK.value)
    return auth_failure
@inject
@app.route("/v1/model/start/<model_id>", methods=['POST'])
def start_model(model_id: str, workspace: Workspace, azure_config: AzureConfig) -> Response:
    """
    Submit the request body for inference with the given model.

    :param model_id: Model identifier taken from the URL.
    :param workspace: Injected AzureML workspace.
    :param azure_config: Injected Azure configuration; supplies the experiment name.
    :return: 201 with the run id as plain text on success; the authentication error response
        if the request is not authenticated; 404 with InvalidModelId if the model is not found;
        500 for any other failure.
    """
    authentication_response = is_authenticated_request(request)
    if authentication_response is not None:
        return authentication_response

    try:
        # The raw request body is passed through unchanged as the image data.
        image_data: bytes = request.stream.read()
        logging.info(f'Starting {model_id}')
        config = SubmitForInferenceConfig(model_id=model_id, image_data=image_data,
                                          experiment_name=azure_config.experiment_name)
        run_id = submit_for_inference(config, workspace, azure_config)
        response = make_response(run_id, HTTP_STATUS_CODE.CREATED.value)
        response.headers.set('Content-Type', 'text/plain')
        return response
    except WebserviceException as webException:
        if webException.message.startswith('ModelNotFound'):
            return make_error_response(HTTP_STATUS_CODE.NOT_FOUND,
                                       ERROR_EXTRA_DETAILS.INVALID_MODEL_ID)
        # logging.exception records the traceback as well as the message, so the failing
        # call can be located from the logs.
        logging.exception(webException)
        return make_error_response(HTTP_STATUS_CODE.INTERNAL_SERVER_ERROR)
    except Exception as fatal_error:
        # Top-level boundary: never let an exception escape to the client, but keep the traceback.
        logging.exception(fatal_error)
        return make_error_response(HTTP_STATUS_CODE.INTERNAL_SERVER_ERROR)
@inject
@app.route("/v1/model/results/<run_id>", methods=['GET'])
def download_result(run_id: str, workspace: Workspace) -> Response:
    """
    Report on, or return the result of, a previously started inference run.

    :param run_id: Run identifier taken from the URL.
    :param workspace: Injected AzureML workspace.
    :return: The authentication error response if the request is not authenticated;
        202 with an empty body while the run is running or post-processing;
        200 with the result file as application/zip if the run completed;
        400 with InvalidZipFile if the run failed and its driver log mentions zipfile.BadZipFile;
        404 with InvalidRunId if looking up the run raised a ServiceException with status 404;
        500 (with RunCancelled when the run was cancelled) for every other outcome.
    """
    authentication_response = is_authenticated_request(request)
    if authentication_response is not None:
        return authentication_response

    logging.info(f"Checking run_id='{run_id}'")
    try:
        run = workspace.get_run(run_id)
        run_status = run.status
        if run_status in RUNNING_OR_POST_PROCESSING:
            return make_response("", HTTP_STATUS_CODE.ACCEPTED.value)
        logging.info(f"Run has completed with status {run.get_status()}")

        if run_status != RunStatus.COMPLETED:
            # Run cancelled or failed.
            if run_status == RunStatus.FAILED:
                with tempfile.TemporaryDirectory() as tmpdirname:
                    # Download the azureml-log files
                    run.download_files(prefix="azureml-logs", output_directory=tmpdirname,
                                       append_prefix=False)
                    # In particular look for 70_driver_log.txt
                    driver_log_path = Path(tmpdirname) / '70_driver_log.txt'
                    if driver_log_path.exists():
                        driver_log = driver_log_path.read_text()
                        # A BadZipFile in the log means the uploaded input was not a valid
                        # zip, which is the caller's error rather than the server's.
                        if "zipfile.BadZipFile" in driver_log:
                            return make_error_response(HTTP_STATUS_CODE.BAD_REQUEST,
                                                       ERROR_EXTRA_DETAILS.INVALID_ZIP_FILE)
            if run_status == RunStatus.CANCELED:
                return make_error_response(HTTP_STATUS_CODE.INTERNAL_SERVER_ERROR,
                                           ERROR_EXTRA_DETAILS.RUN_CANCELLED)
            return make_error_response(HTTP_STATUS_CODE.INTERNAL_SERVER_ERROR)

        # Download the result into a temporary file and read it back as bytes.
        # NOTE(review): fallback=True presumably falls back to an ordinary temp file when no
        # memory-backed filesystem is available - confirm against memory_tempfile.
        memory_tempfile = MemoryTempfile(fallback=True)
        with memory_tempfile.NamedTemporaryFile() as tf:
            file_name = str(tf.name)
            run.download_file(DEFAULT_RESULT_IMAGE_NAME, file_name)
            tf.seek(0)
            result_bytes = tf.read()
        response = make_response(result_bytes, HTTP_STATUS_CODE.OK.value)
        response.headers.set('Content-Type', 'application/zip')
        return response
    except ServiceException as error:
        if error.status_code == 404:
            return make_error_response(HTTP_STATUS_CODE.NOT_FOUND,
                                       ERROR_EXTRA_DETAILS.INVALID_RUN_ID)
        # logging.exception records the traceback as well as the message, so the failing
        # call can be located from the logs.
        logging.exception(error)
        return make_error_response(HTTP_STATUS_CODE.INTERNAL_SERVER_ERROR)
    except Exception as fatal_error:
        # Top-level boundary: never let an exception escape to the client, but keep the traceback.
        logging.exception(fatal_error)
        return make_error_response(HTTP_STATUS_CODE.INTERNAL_SERVER_ERROR)
# Setup Flask Injector, this has to happen AFTER routes are added
# `configure` (imported from the configure module) is passed as the injector module.
# NOTE(review): presumably it binds the Workspace and AzureConfig arguments that the
# route handlers above declare - confirm against configure.py.
FlaskInjector(app=app, modules=[configure])