Add an HTTP server for hosting of ONNX models (#806)

* Simple integration into CMake build system

* Adds vcpkg as a submodule and updates build.py to install hosting dependencies

* Don't create vcpkg executable if already created

* Fixes how CMake finds toolchain file and quick changes to build.py

* Removes setting the CMAKE_TOOLCHAIN_FILE in build.py

* Adds Boost Beast echo server and Boost program_options

* Fixes spacing problem with program_options

* Adds Microsoft headers to all the beast server headers

* Removes CXX 14 from CMake file

* Adds TODO to create configuration class

* Run clang-format on main

* Better exception handling of program_options

* Remove vcpkg submodule via ssh

* Add vcpkg as https

* Adds onnxruntime namespace to all classes

* Fixed places where namespaces were anonymous

* Adds a TODO to use the logger

* Moves all setting namespace shortnames outside of onnxruntime namespace

* Add onnxruntime session options to force app to link with it

* Set CMAKE_TOOLCHAIN_FILE in build.py

* Remove whitespace

* Adds initial ONNX Hosting tests (#5)

* Add initial test which is failing linking with no main

* Adds test_main to get hosting tests working

* Deletes useless add_executable line

* Merge changes from upstream

* Enable CI build in Vienna environment

* make hosting_run*.sh executable

* Add boost path in unittest

* Add boost to TEST_INC_DIR

* Add component detection task in ci yaml

* Get tests and hosting to compile with re2 (#7)

* Add finding boost packages before using it in unit tests

* Add predict.proto and build

* Ignore unused parameters in generated code

* Removes std::regex in favor of re2 (#8)

* Removes std::regex in favor of re2

* Adds back find_package in unit tests and fixes regexes

* Adds more negative test cases

* Adding more protos

* Fix google protobuf file path in the cmake file

* Ignore unused parameters for pb generated code

* Updates onnx submodule (#10)

* Remove duplicated lib in link

* Follow Google style guide (#11)

* Google style names
* Adds more 
* Adds an additional namespace
* Fixes header guards to match filepaths

* Consume protobuf

* Unit Test setup

* Json deserialization simple test cases

* Split hosting app to lib and exe for testability

* Add more cases

* Clean up

* Add more comments

* Update namespace and format the cmake files

* Update cmake/external/onnx to checkout 1ec81bc6d49ccae23cd7801515feaadd13082903

* Separate h and cc in http folder

* Clean up hosting application cmake file

* Enable logging and properly initialize the session

* Update const position for GetSession()

* Take latest onnx and onnx-tensorrt

* Creates configuration header file for program_options (#15)

* Sets up PredictRequest callback (#16)

* Init version, porting from prototype, e2e works

* More executor implementation

* Adds function on application startup (#17)

* Attempts to pass HostingEnvironment as a shared_ptr

* Removes logging and environment from all http classes

* Passes http details to OnStart function

* Using full protobuf for hosting app build

* MLValue2TensorProto

* Revert back changes in inference_session.cc

* Refactor logger access and predict handler

* Create an error handling callback (#19)

* Creates error callback

* Logs error and returns back as JSON

* Catches exceptions in user functions

* Refactor executor and add some test cases

* Fix build warning

* Add onnx as a dependency and in includes to hosting app (#20)

* Converter for specific types and more UTs

* More unit tests

* Update onnx submodule

* Fix string data test

* Clean up code

* Cleanup code

* Refactor logging to use unique id per request and take logging level from user (#21)

* Removes capturing env by reference in main

* Uses uuid for logging ids

* Take logging_level as a program argument

* Pass logging_level to default_logging_manager

* Change name of logger to HostingApp

* Log if request id is null

* Update GetHttpStatusCode signature

* Fix random result issue and camel-case names

* Rollback accidentally changed pybind_state.cc

* Rollback pybind_state.cc

* Generate protobuf status from onnxruntime status

* Fix function name in error message

* Clean up comments

* Support protobuf byte array as input

* Refactor predict handler and add unit tests

* Add one more test

* update cmake/external/onnx

* Accept more protobuf MIME types

* Update onnx-tensorrt

* Add build instruction and usage doc

* Address PR comments

* Install g++-7 in the Ubuntu 16.04 build image for vcpkg

* Fix onnx-tensorrt version

* Check return value during initialization

* Fix infinite loop when http port is in use (#29)

* Simplify Executor.cc by breaking up Run method (#27)

* Move request id to Executor constructor

* Refactor the logger to respect user verbosity level

* Use Arena allocator instead of device

* Creates initial executor tests

* Merge upstream master (#31)

* Remove all possible shared_ptrs (#30)

* Changes GetLogger to unique_ptr

* Reserve BFloat raw data vector size

* Change HostingEnvironment to be passed by lvalue and rvalue references

* Change routes to be passed by const references

* Enable full protobuf if building hosting (#32)

* Building hosting application no longer needs use_full_protobuf flag

* Improve hosting application docs

* Move server core into separate folder (#34)

* Turn hosting project off by default (#38)

* Remove vcpkg as a submodule and download/install Boost from source (#39)

* Remove vcpkg

* Use CMake script to download and build Boost as part of the project

* Remove std::move for const references

* Remove error_code.proto

* Change wording of executable help description

* Better GenerateProtobufStatus description

* Remove error_code protobuf from CMake files

* Use all outputs if no filter is given

* Pass MLValue by const reference in MLValueToTensorProto

* Rename variables to argc and argv

* Revert "Use all outputs if no filter is given"

This reverts commit 7554190ab8e50ba6947648c2f3e2a3d4d9606ce0.

* Remove all header guards in favor of #pragma once

* Reserve size for output vector and optimize for-loop

* Use static libs by default for Boost

* Improves documentation for GenerateResponseInJson function

* Start Result enum at 0 instead of 1

* Remove g++ from Ubuntu's install.sh

* Update cmake files

* Give explanation for Result enum type

* Remove all program options shortcuts except for -h

* Add comments for predict.proto

* Fix JSON for error codes

* Add notice on hosting application docs that it's in beta

* Change HostingEnvironment back to a shared_ptr

* Handle empty output_filter field

* Fix build break

* Refactor unit tests location and groups

* First end-to-end test

* Add missing log

* Missing req id and client req id in error response

* Add one test case to validate failed resp header

* Add build flag for hosting app end to end tests

* Update pipeline setup to run e2e test for CI build

* Model Zoo data preparation and tests

* Add protobuf tests

* Remove mention of needing g++-7 in BUILD.md

* Make GetAppLogger const

* Make using_raw_data_ match the styling of other fields

* Avoid copy of strings when initializing model

* Escape JSON strings correctly for error messages (#44)

* Escape JSON strings correctly

* Add test examples with lots of carriage returns

* Add result validation

* Remove temporary path

* Optimize model zoo test execution

* Improve reliability of test cases

* Generate _pb2.py during the build time

* README for integration tests

* Pass environment by pointer instead of shared_ptr to executor (#49)

* More Integration tests

* Remove generated files

* Make session private and use a getter instead (#53)

* logging_level to log_level for CLI

* Single model prediction shortcut

* Health endpoint

* Integration tests

* Rename to onnxruntime server

* Build ONNX Server application on Windows (#57)

* Gets Boost compiling on Windows

* Fix integer conversion and comparison problems

* Use size_t in converter_tests instead of int

* Fix hosting integration tests on Windows

* Removes checks for port because it's an unsigned short

* Fixes comparison between signed and unsigned data types

* Pip install protobuf and numpy

* Missing test data from the rename change

* Fix server app path (#58)

* Pass shared_ptr by const reference to avoid ref count increase (#59)

* Download test model during test setup

* Make download into test_util

* Rename ci pipeline for onnx runtime server

* Support HTTP requests up to 10MiB (#61)

* Changes minimum request size to 10MB to support all models in ONNX Model Zoo

Authored by tmccrmck on 2019-04-30 18:21:23 -07:00; committed by Pranav Sharma
Parent 6f5c28fd3a
Commit 1978b3c953
59 changed files: 5046 additions and 3 deletions

.gitignore

@@ -11,6 +11,7 @@ distribute/*
*.bin
cmake_build
.cmake_build
cmake-build-debug
gen
*~
.vs


@@ -53,6 +53,10 @@ The complete list of build options can be found by running `./build.sh (or ./bui
1. For Windows, just add the --x86 argument when launching build.bat
2. For Linux, it must be built on an x86 OS, and the --x86 argument also needs to be specified for build.sh
## Build ONNX Runtime Server on Linux
1. In the ONNX Runtime root folder, run `./build.sh --config RelWithDebInfo --build_server --use_openmp --parallel`
## Build/Test Flavors for CI
### CI Build Environments

Просмотреть файл

@@ -70,6 +70,7 @@ option(onnxruntime_USE_BRAINSLICE "Build with BrainSlice" OFF)
option(onnxruntime_USE_TENSORRT "Build with TensorRT support" OFF)
option(onnxruntime_ENABLE_LTO "Enable link time optimization, which is not stable on older GCCs" OFF)
option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF)
option(onnxruntime_BUILD_SERVER "Build ONNX Runtime Server" OFF)
option(onnxruntime_USE_FULL_PROTOBUF "Use full protobuf" OFF)
option(onnxruntime_DISABLE_CONTRIB_OPS "Disable contrib ops" OFF)
option(onnxruntime_USE_EIGEN_THREADPOOL "Use eigen threadpool. Otherwise OpenMP or a homemade one will be used" OFF)
@@ -607,6 +608,10 @@ if (onnxruntime_BUILD_SHARED_LIB)
include(onnxruntime.cmake)
endif()
if (onnxruntime_BUILD_SERVER)
include(onnxruntime_server.cmake)
endif()
# some of the tests rely on the shared libs to be
# built; hence the ordering
if (onnxruntime_BUILD_UNIT_TESTS)
@@ -633,3 +638,4 @@ if (onnxruntime_BUILD_CSHARP)
# set_property(GLOBAL PROPERTY VS_DOTNET_TARGET_FRAMEWORK_VERSION "netstandard2.0")
include(onnxruntime_csharp.cmake)
endif()

cmake/get_boost.cmake

@@ -0,0 +1,100 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
set(BOOST_REQUESTED_VERSION 1.69.0 CACHE STRING "")
set(BOOST_SHA1 8f32d4617390d1c2d16f26a27ab60d97807b35440d45891fa340fc2648b04406 CACHE STRING "")
set(BOOST_USE_STATIC_LIBS true CACHE BOOL "")
set(BOOST_COMPONENTS program_options system thread)
# These components are only needed for Windows
if(WIN32)
list(APPEND BOOST_COMPONENTS date_time regex)
endif()
# MSVC doesn't set these variables
if(WIN32)
set(CMAKE_STATIC_LIBRARY_PREFIX lib)
set(CMAKE_SHARED_LIBRARY_PREFIX lib)
endif()
# Set lib prefixes and suffixes for linking
if(BOOST_USE_STATIC_LIBS)
set(LIBRARY_PREFIX ${CMAKE_STATIC_LIBRARY_PREFIX})
set(LIBRARY_SUFFIX ${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
set(LIBRARY_PREFIX ${CMAKE_SHARED_LIBRARY_PREFIX})
set(LIBRARY_SUFFIX ${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
# Create list of components in Boost format
foreach(component ${BOOST_COMPONENTS})
list(APPEND BOOST_COMPONENTS_FOR_BUILD --with-${component})
endforeach()
set(BOOST_ROOT_DIR ${CMAKE_BINARY_DIR}/boost CACHE PATH "")
# TODO: let user give their own Boost installation
macro(DOWNLOAD_BOOST)
if(NOT BOOST_REQUESTED_VERSION)
message(FATAL_ERROR "BOOST_REQUESTED_VERSION is not defined.")
endif()
string(REPLACE "." "_" BOOST_REQUESTED_VERSION_UNDERSCORE ${BOOST_REQUESTED_VERSION})
set(BOOST_MAYBE_STATIC)
if(BOOST_USE_STATIC_LIBS)
set(BOOST_MAYBE_STATIC "link=static")
endif()
set(VARIANT "release")
if(CMAKE_BUILD_TYPE MATCHES Debug)
set(VARIANT "debug")
endif()
set(WINDOWS_B2_OPTIONS)
set(WINDOWS_LIB_NAME_SCHEME)
if(WIN32)
set(BOOTSTRAP_FILE_TYPE "bat")
set(WINDOWS_B2_OPTIONS "toolset=msvc-14.1" "architecture=x86" "address-model=64")
set(WINDOWS_LIB_NAME_SCHEME "-vc141-mt-gd-x64-1_69")
else()
set(BOOTSTRAP_FILE_TYPE "sh")
endif()
message(STATUS "Adding Boost components")
include(ExternalProject)
ExternalProject_Add(
Boost
URL http://dl.bintray.com/boostorg/release/${BOOST_REQUESTED_VERSION}/source/boost_${BOOST_REQUESTED_VERSION_UNDERSCORE}.tar.bz2
URL_HASH SHA256=${BOOST_SHA1}
DOWNLOAD_DIR ${BOOST_ROOT_DIR}
SOURCE_DIR ${BOOST_ROOT_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ./bootstrap.${BOOTSTRAP_FILE_TYPE} --prefix=${BOOST_ROOT_DIR}
BUILD_COMMAND ./b2 install ${BOOST_MAYBE_STATIC} --prefix=${BOOST_ROOT_DIR} variant=${VARIANT} ${WINDOWS_B2_OPTIONS} ${BOOST_COMPONENTS_FOR_BUILD}
BUILD_IN_SOURCE true
INSTALL_COMMAND ""
INSTALL_DIR ${BOOST_ROOT_DIR}
)
# Set include folders
ExternalProject_Get_Property(Boost INSTALL_DIR)
set(Boost_INCLUDE_DIR ${INSTALL_DIR}/include)
if(WIN32)
set(Boost_INCLUDE_DIR ${INSTALL_DIR}/include/boost-1_69)
endif()
# Set libraries to link
macro(libraries_to_fullpath varname)
set(${varname})
foreach(component ${BOOST_COMPONENTS})
list(APPEND ${varname} ${INSTALL_DIR}/lib/${LIBRARY_PREFIX}boost_${component}${WINDOWS_LIB_NAME_SCHEME}${LIBRARY_SUFFIX})
endforeach()
endmacro()
libraries_to_fullpath(Boost_LIBRARIES)
mark_as_advanced(Boost_LIBRARIES Boost_INCLUDE_DIR)
endmacro()
DOWNLOAD_BOOST()


@@ -0,0 +1,122 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
set(SERVER_APP_NAME "onnxruntime_server")
# Generate .h and .cc files from protobuf file
add_library(server_proto ${ONNXRUNTIME_ROOT}/server/protobuf/predict.proto)
if(WIN32)
target_compile_options(server_proto PRIVATE "/wd4125" "/wd4456")
endif()
target_include_directories(server_proto PUBLIC $<TARGET_PROPERTY:protobuf::libprotobuf,INTERFACE_INCLUDE_DIRECTORIES> "${CMAKE_CURRENT_BINARY_DIR}/.." ${CMAKE_CURRENT_BINARY_DIR}/onnx)
target_compile_definitions(server_proto PUBLIC $<TARGET_PROPERTY:protobuf::libprotobuf,INTERFACE_COMPILE_DEFINITIONS>)
onnxruntime_protobuf_generate(APPEND_PATH IMPORT_DIRS ${REPO_ROOT}/cmake/external/protobuf/src ${ONNXRUNTIME_ROOT}/server/protobuf ${ONNXRUNTIME_ROOT}/core/protobuf TARGET server_proto)
add_dependencies(server_proto onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES})
if(NOT WIN32)
if(HAS_UNUSED_PARAMETER)
set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/model_metadata.pb.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/model_status.pb.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/predict.pb.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
endif()
endif()
# Setup dependencies
include(get_boost.cmake)
set(re2_src ${REPO_ROOT}/cmake/external/re2)
# Setup source code
set(onnxruntime_server_lib_srcs
"${ONNXRUNTIME_ROOT}/server/http/json_handling.cc"
"${ONNXRUNTIME_ROOT}/server/http/predict_request_handler.cc"
"${ONNXRUNTIME_ROOT}/server/http/util.cc"
"${ONNXRUNTIME_ROOT}/server/environment.cc"
"${ONNXRUNTIME_ROOT}/server/executor.cc"
"${ONNXRUNTIME_ROOT}/server/converter.cc"
"${ONNXRUNTIME_ROOT}/server/util.cc"
)
if(NOT WIN32)
if(HAS_UNUSED_PARAMETER)
set_source_files_properties(${ONNXRUNTIME_ROOT}/server/http/json_handling.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${ONNXRUNTIME_ROOT}/server/http/predict_request_handler.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${ONNXRUNTIME_ROOT}/server/executor.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${ONNXRUNTIME_ROOT}/server/converter.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${ONNXRUNTIME_ROOT}/server/util.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
endif()
endif()
file(GLOB_RECURSE onnxruntime_server_http_core_lib_srcs
"${ONNXRUNTIME_ROOT}/server/http/core/*.cc"
)
file(GLOB_RECURSE onnxruntime_server_srcs
"${ONNXRUNTIME_ROOT}/server/main.cc"
)
# HTTP core library
add_library(onnxruntime_server_http_core_lib STATIC
${onnxruntime_server_http_core_lib_srcs})
target_include_directories(onnxruntime_server_http_core_lib
PUBLIC
${ONNXRUNTIME_ROOT}/server/http/core
${Boost_INCLUDE_DIR}
${re2_src}
)
add_dependencies(onnxruntime_server_http_core_lib Boost)
target_link_libraries(onnxruntime_server_http_core_lib PRIVATE
${Boost_LIBRARIES}
)
# Server library
add_library(onnxruntime_server_lib ${onnxruntime_server_lib_srcs})
onnxruntime_add_include_to_target(onnxruntime_server_lib gsl onnx_proto server_proto)
target_include_directories(onnxruntime_server_lib PRIVATE
${ONNXRUNTIME_ROOT}
${CMAKE_CURRENT_BINARY_DIR}/onnx
${ONNXRUNTIME_ROOT}/server
${ONNXRUNTIME_ROOT}/server/http
PUBLIC
${Boost_INCLUDE_DIR}
${re2_src}
)
target_link_libraries(onnxruntime_server_lib PRIVATE
server_proto
${Boost_LIBRARIES}
onnxruntime_server_http_core_lib
onnxruntime_session
onnxruntime_optimizer
onnxruntime_providers
onnxruntime_util
onnxruntime_framework
onnxruntime_util
onnxruntime_graph
onnxruntime_common
onnxruntime_mlas
${onnxruntime_EXTERNAL_LIBRARIES}
)
# For IDE only
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_server_srcs} ${onnxruntime_server_lib_srcs} ${onnxruntime_server_lib})
# Server Application
add_executable(${SERVER_APP_NAME} ${onnxruntime_server_srcs})
add_dependencies(${SERVER_APP_NAME} onnx server_proto onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES})
if(NOT WIN32)
if(HAS_UNUSED_PARAMETER)
set_source_files_properties("${ONNXRUNTIME_ROOT}/server/main.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
endif()
endif()
onnxruntime_add_include_to_target(${SERVER_APP_NAME} onnxruntime_session onnxruntime_server_lib gsl onnx onnx_proto server_proto)
target_include_directories(${SERVER_APP_NAME} PRIVATE
${ONNXRUNTIME_ROOT}
${ONNXRUNTIME_ROOT}/server/http
)
target_link_libraries(${SERVER_APP_NAME} PRIVATE
onnxruntime_server_http_core_lib
onnxruntime_server_lib
)


@@ -163,13 +163,15 @@ set(onnxruntime_test_framework_libs
onnxruntime_mlas
)
set(onnxruntime_test_server_libs
onnxruntime_test_utils
onnxruntime_test_utils_for_server
)
if(WIN32)
list(APPEND onnxruntime_test_framework_libs Advapi32)
endif()
set (onnxruntime_test_providers_dependencies ${onnxruntime_EXTERNAL_DEPENDENCIES})
if(onnxruntime_USE_CUDA)
@@ -557,6 +559,58 @@ if (onnxruntime_BUILD_SHARED_LIB)
endif()
endif()
if (onnxruntime_BUILD_SERVER)
file(GLOB onnxruntime_test_server_src
"${TEST_SRC_DIR}/server/unit_tests/*.cc"
"${TEST_SRC_DIR}/server/unit_tests/*.h"
)
file(GLOB onnxruntime_integration_test_server_src
"${TEST_SRC_DIR}/server/integration_tests/*.py"
)
if(NOT WIN32)
if(HAS_UNUSED_PARAMETER)
set_source_files_properties("${TEST_SRC_DIR}/server/unit_tests/json_handling_tests.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties("${TEST_SRC_DIR}/server/unit_tests/converter_tests.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties("${TEST_SRC_DIR}/server/unit_tests/util_tests.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
endif()
endif()
add_library(onnxruntime_test_utils_for_server ${onnxruntime_test_server_src})
onnxruntime_add_include_to_target(onnxruntime_test_utils_for_server onnxruntime_test_utils gtest gmock gsl onnx onnx_proto server_proto)
add_dependencies(onnxruntime_test_utils_for_server onnxruntime_server_lib onnxruntime_server_http_core_lib Boost ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_include_directories(onnxruntime_test_utils_for_server PUBLIC ${Boost_INCLUDE_DIR} ${REPO_ROOT}/cmake/external/re2 ${CMAKE_CURRENT_BINARY_DIR}/onnx ${ONNXRUNTIME_ROOT}/server/http ${ONNXRUNTIME_ROOT}/server/http/core PRIVATE ${ONNXRUNTIME_ROOT} )
target_link_libraries(onnxruntime_test_utils_for_server ${Boost_LIBRARIES} ${onnx_test_libs})
AddTest(
TARGET onnxruntime_server_tests
SOURCES ${onnxruntime_test_server_src}
LIBS ${onnxruntime_test_server_libs} server_proto onnxruntime_server_lib ${onnxruntime_test_providers_libs}
DEPENDS ${onnxruntime_EXTERNAL_DEPENDENCIES}
)
onnxruntime_protobuf_generate(
APPEND_PATH IMPORT_DIRS ${REPO_ROOT}/cmake/external/protobuf/src ${ONNXRUNTIME_ROOT}/server/protobuf ${ONNXRUNTIME_ROOT}/core/protobuf
PROTOS ${ONNXRUNTIME_ROOT}/server/protobuf/predict.proto ${ONNXRUNTIME_ROOT}/server/protobuf/onnx-ml.proto
LANGUAGE python
TARGET onnxruntime_server_tests
OUT_VAR server_test_py)
add_custom_command(
TARGET onnxruntime_server_tests POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/server_test
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_integration_test_server_src}
${CMAKE_CURRENT_BINARY_DIR}/server_test/
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_CURRENT_BINARY_DIR}/onnx_ml_pb2.py
${CMAKE_CURRENT_BINARY_DIR}/server_test/
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_CURRENT_BINARY_DIR}/predict_pb2.py
${CMAKE_CURRENT_BINARY_DIR}/server_test/
)
endif()
add_executable(onnxruntime_mlas_test ${TEST_SRC_DIR}/mlas/unittest.cpp)
target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc)


@@ -0,0 +1,140 @@
<h1><span style="color:red">Note: ONNX Runtime Server is still in beta state. It's currently not ready for production environments.</span></h1>
# How to Use ONNX Runtime Server REST API for Prediction
ONNX Runtime Server provides a REST API for prediction. The goal of the project is to make it easy to "host" any ONNX model as a RESTful service. The CLI command to start the service is shown below:
```
$ ./onnxruntime_server
the option '--model_path' is required but missing
Allowed options:
-h [ --help ] Shows a help message and exits
--log_level arg (=info) Logging level. Allowed options (case sensitive):
verbose, info, warning, error, fatal
--model_path arg Path to ONNX model
--address arg (=0.0.0.0) The base HTTP address
--http_port arg (=8001) HTTP port to listen to requests
--num_http_threads arg (=<# of your cpu cores>) Number of http threads
```
Note: The only mandatory argument for the program here is `model_path`
## Start the Server
To host an ONNX model as a REST API server, run:
```
./onnxruntime_server --model_path /<your>/<model>/<path>
```
The prediction URL is in this format:
```
http://<your_ip_address>:<port>/v1/models/<your-model-name>/versions/<your-version>:predict
```
**Note**: Since we currently only support one model, the model name and version can be any non-empty string. In the future, model names and versions will be verified.
## Request and Response Payload
An HTTP request can be a Protobuf message in two formats: binary or JSON. The HTTP request header field `Content-Type` tells the server how to handle the request and thus it is mandatory for all requests. Requests missing `Content-Type` will be rejected as `400 Bad Request`.
* For `"Content-Type: application/json"`, the payload will be deserialized as JSON string in UTF-8 format
* For `"Content-Type: application/vnd.google.protobuf"`, `"Content-Type: application/x-protobuf"` or `"Content-Type: application/octet-stream"`, the payload will be consumed as protobuf message directly.
The Protobuf definition can be found [here](https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/server/protobuf/predict.proto).
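As a rough sketch of how a client could produce the two payload variants, the C++ fragment below fills a `PredictRequest` from the generated `predict.pb.h` and serializes it either as JSON or as binary protobuf; the input name `Input3`, its shape and its values are placeholders rather than anything defined by the server.
```
#include <string>
#include <google/protobuf/util/json_util.h>
#include "predict.pb.h"  // generated from server/protobuf/predict.proto

// Build the request body for either "application/json" or "application/octet-stream".
std::string BuildPayload(bool as_json) {
  onnxruntime::server::PredictRequest request;

  onnx::TensorProto tensor;                        // one named input (name is a placeholder)
  tensor.set_data_type(onnx::TensorProto_DataType_FLOAT);
  tensor.add_dims(1);
  tensor.add_dims(2);
  tensor.add_float_data(0.5f);
  tensor.add_float_data(1.5f);
  (*request.mutable_inputs())["Input3"] = tensor;

  std::string payload;
  if (as_json) {
    google::protobuf::util::MessageToJsonString(request, &payload);  // JSON body
  } else {
    request.SerializeToString(&payload);                             // binary protobuf body
  }
  return payload;
}
```
The resulting string is what the `curl` examples below post as `predict_request_0.json` or `predict_request_0.pb`.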
## Inferencing
To send a request to the server, you can use any tool which supports making HTTP requests. Here is an example using `curl`:
```
curl -X POST -d "@predict_request_0.json" -H "Content-Type: application/json" http://127.0.0.1:8001/v1/models/mymodel/versions/3:predict
```
or
```
curl -X POST --data-binary "@predict_request_0.pb" -H "Content-Type: application/octet-stream" -H "Foo: 1234" http://127.0.0.1:8001/v1/models/mymodel/versions/3:predict
```
Clients can control the response type by setting an `Accept` header field on the request, and the server will serialize the response in the desired format. The choices currently available are the same as for the `Content-Type` header field.
## Advanced Topics
### Number of HTTP Threads
You can change this to optimize server utilization. The default is the number of CPU cores on the host machine.
### Request ID and Client Request ID
For easy tracking of requests, we provide the following header fields:
* `x-ms-request-id`: will be in the response header regardless of the request result. It will be a GUID/UUID with dashes, e.g. `72b68108-18a4-493c-ac75-d0abd82f0a11`. If the request headers contain this field, its value will be ignored.
* `x-ms-client-request-id`: a field for clients to track their requests. The content will be echoed back in the response headers.
Here is an example of a client sending a request:
#### Client Side
```
$ curl -v -X POST --data-binary "@predict_request_0.pb" -H "Content-Type: application/octet-stream" -H "Foo: 1234" -H "x-ms-client-request-id: my-request-001" -H "Accept: application/json" http://127.0.0.1:8001/v1/models/mymodel/versions/3:predict
Note: Unnecessary use of -X or --request, POST is already inferred.
* Trying 127.0.0.1...
* Connected to 127.0.0.1 (127.0.0.1) port 8001 (#0)
> POST /v1/models/mymodel/versions/3:predict HTTP/1.1
> Host: 127.0.0.1:8001
> User-Agent: curl/7.47.0
> Content-Type: application/octet-stream
> x-ms-client-request-id: my-request-001
> Accept: application/json
> Content-Length: 3179
> Expect: 100-continue
>
* Done waiting for 100-continue
* We are completely uploaded and fine
< HTTP/1.1 200 OK
< Content-Type: application/json
< x-ms-request-id: 72b68108-18a4-493c-ac75-d0abd82f0a11
< x-ms-client-request-id: my-request-001
< Content-Length: 159
<
* Connection #0 to host 127.0.0.1 left intact
{"outputs":{"Sample_Output_Name":{"dims":["1","10"],"dataType":1,"rawData":"6OpzRFquGsSFdM1FyAEnRFtRZcRa9NDEUBj0xI4ydsJIS0LE//CzxA==","dataLocation":"DEFAULT"}}}%
```
#### Server Side
And here is what the output on the server side looks like with a logging level of verbose:
```
2019-04-04 23:48:26.395200744 [V:onnxruntime:72b68108-18a4-493c-ac75-d0abd82f0a11, predict_request_handler.cc:40 Predict] Name: mymodel Version: 3 Action: predict
2019-04-04 23:48:26.395289437 [V:onnxruntime:72b68108-18a4-493c-ac75-d0abd82f0a11, predict_request_handler.cc:46 Predict] x-ms-client-request-id: [my-request-001]
2019-04-04 23:48:26.395540707 [I:onnxruntime:InferenceSession, inference_session.cc:736 Run] Running with tag: 72b68108-18a4-493c-ac75-d0abd82f0a11
2019-04-04 23:48:26.395596858 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, inference_session.cc:976 CreateLoggerForRun] Created logger for run with id of 72b68108-18a4-493c-ac75-d0abd82f0a11
2019-04-04 23:48:26.395731391 [I:onnxruntime:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:42 Execute] Begin execution
2019-04-04 23:48:26.395763319 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:45 Execute] Size of execution plan vector: 12
2019-04-04 23:48:26.396228981 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Convolution28
2019-04-04 23:48:26.396580161 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Plus30
2019-04-04 23:48:26.396623732 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 10
2019-04-04 23:48:26.396878822 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: ReLU32
2019-04-04 23:48:26.397091882 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Pooling66
2019-04-04 23:48:26.397126243 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 11
2019-04-04 23:48:26.397772701 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Convolution110
2019-04-04 23:48:26.397818174 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 13
2019-04-04 23:48:26.398060592 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Plus112
2019-04-04 23:48:26.398095300 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 14
2019-04-04 23:48:26.398257563 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: ReLU114
2019-04-04 23:48:26.398426740 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Pooling160
2019-04-04 23:48:26.398466031 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 15
2019-04-04 23:48:26.398542823 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Times212_reshape0
2019-04-04 23:48:26.398599687 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Times212_reshape1
2019-04-04 23:48:26.398692631 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Times212
2019-04-04 23:48:26.398731471 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 17
2019-04-04 23:48:26.398832735 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Plus214
2019-04-04 23:48:26.398873229 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 19
2019-04-04 23:48:26.398922929 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:160 Execute] Fetching output.
2019-04-04 23:48:26.398956560 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:163 Execute] Done with execution.
```


@@ -0,0 +1,261 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <onnx/onnx_pb.h>
#include "core/common/logging/logging.h"
#include "core/framework/data_types.h"
#include "core/framework/environment.h"
#include "core/framework/framework_common.h"
#include "core/framework/mem_buffer.h"
#include "core/framework/ml_value.h"
#include "core/framework/tensor.h"
#include "core/framework/tensorprotoutils.h"
#include "onnx-ml.pb.h"
#include "predict.pb.h"
#include "converter.h"
namespace onnxruntime {
namespace server {
namespace protobufutil = google::protobuf::util;
onnx::TensorProto_DataType MLDataTypeToTensorProtoDataType(const onnxruntime::DataTypeImpl* cpp_type) {
if (cpp_type == onnxruntime::DataTypeImpl::GetType<float>()) {
return onnx::TensorProto_DataType_FLOAT;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<uint8_t>()) {
return onnx::TensorProto_DataType_UINT8;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<int8_t>()) {
return onnx::TensorProto_DataType_INT8;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<uint16_t>()) {
return onnx::TensorProto_DataType_UINT16;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<int16_t>()) {
return onnx::TensorProto_DataType_INT16;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<int32_t>()) {
return onnx::TensorProto_DataType_INT32;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<int64_t>()) {
return onnx::TensorProto_DataType_INT64;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<std::string>()) {
return onnx::TensorProto_DataType_STRING;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<bool>()) {
return onnx::TensorProto_DataType_BOOL;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<onnxruntime::MLFloat16>()) {
return onnx::TensorProto_DataType_FLOAT16;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<onnxruntime::BFloat16>()) {
return onnx::TensorProto_DataType_BFLOAT16;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<double>()) {
return onnx::TensorProto_DataType_DOUBLE;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<uint32_t>()) {
return onnx::TensorProto_DataType_UINT32;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<uint64_t>()) {
return onnx::TensorProto_DataType_UINT64;
} else {
return onnx::TensorProto_DataType_UNDEFINED;
}
}
common::Status MLValueToTensorProto(const onnxruntime::MLValue& ml_value, bool using_raw_data,
std::unique_ptr<onnxruntime::logging::Logger> logger,
/* out */ onnx::TensorProto& tensor_proto) {
// Tensor in MLValue
const auto& tensor = ml_value.Get<onnxruntime::Tensor>();
// dims field
const onnxruntime::TensorShape& tensor_shape = tensor.Shape();
for (const auto& dim : tensor_shape.GetDims()) {
tensor_proto.add_dims(dim);
}
// data_type field
onnx::TensorProto_DataType data_type = MLDataTypeToTensorProtoDataType(tensor.DataType());
tensor_proto.set_data_type(data_type);
// data_location field: Data is stored in raw_data (if set) otherwise in type-specified field.
if (using_raw_data && data_type != onnx::TensorProto_DataType_STRING) {
tensor_proto.set_data_location(onnx::TensorProto_DataLocation_DEFAULT);
}
// *_data field
// According to onnx_ml.proto, depending on the data_type field,
// exactly one of the *_data fields is used to store the elements of the tensor.
switch (data_type) {
case onnx::TensorProto_DataType_FLOAT: { // Target: raw_data or float_data
const auto* data = tensor.Data<float>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
for (size_t i = 0, count = tensor.Shape().Size(); i < count; ++i) {
tensor_proto.add_float_data(data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_INT32: { // Target: raw_data or int32_data
const auto* data = tensor.Data<int32_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
for (size_t i = 0, count = tensor.Shape().Size(); i < count; ++i) {
tensor_proto.add_int32_data(data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_UINT8: { // Target: raw_data or int32_data
const auto* data = tensor.Data<uint8_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto i32data = reinterpret_cast<const int32_t*>(data);
for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_INT8: { // Target: raw_data or int32_data
const auto* data = tensor.Data<int8_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto i32data = reinterpret_cast<const int32_t*>(data);
for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_UINT16: { // Target: raw_data or int32_data
const auto* data = tensor.Data<uint16_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto i32data = reinterpret_cast<const int32_t*>(data);
for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_INT16: { // Target: raw_data or int32_data
const auto* data = tensor.Data<int16_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto i32data = reinterpret_cast<const int32_t*>(data);
for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_BOOL: { // Target: raw_data or int32_data
const auto* data = tensor.Data<bool>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto i32data = reinterpret_cast<const int32_t*>(data);
for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_FLOAT16: { // Target: raw_data or int32_data
const auto* data = tensor.Data<onnxruntime::MLFloat16>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto i32data = reinterpret_cast<const int32_t*>(data);
for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_BFLOAT16: { // Target: raw_data or int32_data
const auto* data = tensor.Data<onnxruntime::BFloat16>();
const auto raw_data_size = tensor.Shape().Size();
std::vector<uint16_t> raw_data;
raw_data.reserve(raw_data_size);
for (int i = 0; i < raw_data_size; ++i) {
raw_data.push_back(data[i].val);
}
if (using_raw_data) {
tensor_proto.set_raw_data(raw_data.data(), raw_data.size() * sizeof(uint16_t));
} else {
auto i32data = reinterpret_cast<const int32_t*>(raw_data.data());
for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_STRING: { // Target: string_data
// string could not be written into "raw_data"
const auto* data = tensor.Data<std::string>();
for (size_t i = 0, count = tensor.Shape().Size(); i < count; ++i) {
tensor_proto.add_string_data(data[i]);
}
break;
}
case onnx::TensorProto_DataType_INT64: { // Target: raw_data or int64_data
const auto* data = tensor.Data<int64_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
for (size_t x = 0, loop_length = tensor.Shape().Size(); x < loop_length; ++x) {
tensor_proto.add_int64_data(data[x]);
}
}
break;
}
case onnx::TensorProto_DataType_UINT32: { // Target: raw_data or uint64_data
const auto* data = tensor.Data<uint32_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto u64data = reinterpret_cast<const uint64_t*>(data);
for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(uint64_t)); i < count; ++i) {
tensor_proto.add_uint64_data(u64data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_UINT64: { // Target: raw_data or uint64_data
const auto* data = tensor.Data<uint64_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
for (size_t x = 0, loop_length = tensor.Shape().Size(); x < loop_length; ++x) {
tensor_proto.add_uint64_data(data[x]);
}
}
break;
}
case onnx::TensorProto_DataType_DOUBLE: { // Target: raw_data or double_data
auto data = tensor.Data<double>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
for (size_t x = 0, loop_length = tensor.Shape().Size(); x < loop_length; ++x) {
tensor_proto.add_double_data(data[x]);
}
}
break;
}
default: {
LOGS(*logger, ERROR) << "Unsupported TensorProto DataType: " << data_type;
return common::Status(common::StatusCategory::ONNXRUNTIME,
common::StatusCode::NOT_IMPLEMENTED,
"Unsupported TensorProto DataType: " + std::to_string(data_type));
}
}
return common::Status::OK();
}
} // namespace server
} // namespace onnxruntime


@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <google/protobuf/stubs/status.h>
#include "core/framework/data_types.h"
#include "environment.h"
#include "predict.pb.h"
namespace onnxruntime {
namespace server {
onnx::TensorProto_DataType MLDataTypeToTensorProtoDataType(const onnxruntime::DataTypeImpl* cpp_type);
// Convert MLValue to TensorProto. Some fields are ignored:
// * name field: could not get from MLValue
// * doc_string: could not get from MLValue
// * segment field: we do not expect very large tensors in the prediction output
// * external_data field: we do not expect very large tensors in the prediction output
// Note: If any input data is in raw_data field, all outputs tensor data will be put into raw_data field.
common::Status MLValueToTensorProto(const onnxruntime::MLValue& ml_value, bool using_raw_data,
std::unique_ptr<onnxruntime::logging::Logger> logger,
/* out */ onnx::TensorProto& tensor_proto);
} // namespace server
} // namespace onnxruntime
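A minimal calling sketch for the converter, assuming an `MLValue` produced by a completed `Run()` call and a `ServerEnvironment*` as used elsewhere in the server code; the helper name `ConvertOutput` is purely illustrative:
```
// Illustrative helper: convert one output MLValue into a TensorProto for the response.
// Assumes "converter.h" and "environment.h" are included.
onnxruntime::common::Status ConvertOutput(onnxruntime::server::ServerEnvironment* env,
                                          const std::string& request_id,
                                          const onnxruntime::MLValue& output_value,
                                          /* out */ onnx::TensorProto& output_tensor) {
  // The per-request logger is moved into the converter, which logs failures itself.
  auto logger = env->GetLogger(request_id);
  return onnxruntime::server::MLValueToTensorProto(output_value, /*using_raw_data=*/true,
                                                   std::move(logger), output_tensor);
}
```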


@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <memory>
#include "core/common/logging/logging.h"
#include "environment.h"
#include "log_sink.h"
namespace onnxruntime {
namespace server {
ServerEnvironment::ServerEnvironment(logging::Severity severity) : severity_(severity),
logger_id_("ServerApp"),
default_logging_manager_(
std::unique_ptr<logging::ISink>{new LogSink{}},
severity,
/* default_filter_user_data */ false,
logging::LoggingManager::InstanceType::Default,
&logger_id_) {
auto status = onnxruntime::Environment::Create(runtime_environment_);
// The session initialization MUST BE AFTER environment creation
session = std::make_unique<onnxruntime::InferenceSession>(options_, &default_logging_manager_);
}
common::Status ServerEnvironment::InitializeModel(const std::string& model_path) {
auto status = session->Load(model_path);
if (!status.IsOK()) {
return status;
}
auto outputs = session->GetModelOutputs();
if (!outputs.first.IsOK()) {
return outputs.first;
}
for (const auto* output_node : *(outputs.second)) {
model_output_names_.push_back(output_node->Name());
}
return common::Status::OK();
}
const std::vector<std::string>& ServerEnvironment::GetModelOutputNames() const {
return model_output_names_;
}
const logging::Logger& ServerEnvironment::GetAppLogger() const {
return default_logging_manager_.DefaultLogger();
}
logging::Severity ServerEnvironment::GetLogSeverity() const {
return severity_;
}
std::unique_ptr<logging::Logger> ServerEnvironment::GetLogger(const std::string& id) {
if (id.empty()) {
LOGS(GetAppLogger(), WARNING) << "Request id is null or empty string";
}
return default_logging_manager_.CreateLogger(id, severity_, false);
}
onnxruntime::InferenceSession* ServerEnvironment::GetSession() const {
return session.get();
}
} // namespace server
} // namespace onnxruntime


@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <memory>
#include <vector>
#include "core/framework/environment.h"
#include "core/common/logging/logging.h"
#include "core/session/inference_session.h"
namespace onnxruntime {
namespace server {
namespace logging = logging;
class ServerEnvironment {
public:
explicit ServerEnvironment(logging::Severity severity);
~ServerEnvironment() = default;
ServerEnvironment(const ServerEnvironment&) = delete;
const logging::Logger& GetAppLogger() const;
std::unique_ptr<logging::Logger> GetLogger(const std::string& id);
logging::Severity GetLogSeverity() const;
onnxruntime::InferenceSession* GetSession() const;
common::Status InitializeModel(const std::string& model_path);
const std::vector<std::string>& GetModelOutputNames() const;
private:
const logging::Severity severity_;
const std::string logger_id_;
logging::LoggingManager default_logging_manager_;
std::unique_ptr<onnxruntime::Environment> runtime_environment_;
onnxruntime::SessionOptions options_;
std::unique_ptr<onnxruntime::InferenceSession> session;
std::vector<std::string> model_output_names_;
};
} // namespace server
} // namespace onnxruntime
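A hedged sketch of the startup sequence the server's main is expected to perform with this class; the severity and model path below are placeholders:
```
#include <memory>
#include "environment.h"

// Create the environment, load the model, and hand back a shared_ptr (or null on failure).
std::shared_ptr<onnxruntime::server::ServerEnvironment> CreateEnvironment() {
  auto env = std::make_shared<onnxruntime::server::ServerEnvironment>(
      onnxruntime::logging::Severity::kINFO);

  auto status = env->InitializeModel("/path/to/model.onnx");  // placeholder path
  if (!status.IsOK()) {
    LOGS(env->GetAppLogger(), ERROR) << "Model load failed: " << status.ErrorMessage();
    return nullptr;
  }
  return env;
}
```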


@@ -0,0 +1,148 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <stdio.h>
#include <onnx/onnx_pb.h>
#include "core/common/logging/logging.h"
#include "core/framework/data_types.h"
#include "core/framework/environment.h"
#include "core/framework/framework_common.h"
#include "core/framework/mem_buffer.h"
#include "core/framework/ml_value.h"
#include "core/framework/tensor.h"
#include "core/framework/tensorprotoutils.h"
#include "onnx-ml.pb.h"
#include "predict.pb.h"
#include "converter.h"
#include "executor.h"
#include "util.h"
namespace onnxruntime {
namespace server {
namespace protobufutil = google::protobuf::util;
protobufutil::Status Executor::SetMLValue(const onnx::TensorProto& input_tensor,
OrtAllocatorInfo* cpu_allocator_info,
/* out */ MLValue& ml_value) {
auto logger = env_->GetLogger(request_id_);
size_t cpu_tensor_length = 0;
auto status = onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(input_tensor, &cpu_tensor_length);
if (!status.IsOK()) {
LOGS(*logger, ERROR) << "GetSizeInBytesFromTensorProto() failed. Error Message: " << status.ToString();
return GenerateProtobufStatus(status, "GetSizeInBytesFromTensorProto() failed: " + status.ToString());
}
std::unique_ptr<char[]> data(new char[cpu_tensor_length]);
memset(data.get(), 0, cpu_tensor_length);
OrtCallback deleter;
status = onnxruntime::utils::TensorProtoToMLValue(onnxruntime::Env::Default(), nullptr, input_tensor,
onnxruntime::MemBuffer(data.release(), cpu_tensor_length, *cpu_allocator_info),
ml_value, deleter);
if (!status.IsOK()) {
LOGS(*logger, ERROR) << "TensorProtoToMLValue() failed. Message: " << status.ToString();
return GenerateProtobufStatus(status, "TensorProtoToMLValue() failed:" + status.ToString());
}
return protobufutil::Status::OK;
}
protobufutil::Status Executor::SetNameMLValueMap(onnxruntime::NameMLValMap& name_value_map, const onnxruntime::server::PredictRequest& request) {
auto logger = env_->GetLogger(request_id_);
OrtAllocatorInfo* cpu_allocator_info = nullptr;
auto ort_status = OrtCreateAllocatorInfo("Cpu", OrtArenaAllocator, 0, OrtMemTypeDefault, &cpu_allocator_info);
if (ort_status != nullptr || cpu_allocator_info == nullptr) {
LOGS(*logger, ERROR) << "OrtCreateAllocatorInfo failed";
return protobufutil::Status(protobufutil::error::Code::RESOURCE_EXHAUSTED, "OrtCreateAllocatorInfo() failed");
}
// Prepare the MLValue object
for (const auto& input : request.inputs()) {
using_raw_data_ = using_raw_data_ && input.second.has_raw_data();
MLValue ml_value;
auto status = SetMLValue(input.second, cpu_allocator_info, ml_value);
if (status != protobufutil::Status::OK) {
LOGS(*logger, ERROR) << "SetMLValue() failed! Input name: " << input.first;
return status;
}
auto insertion_result = name_value_map.insert(std::make_pair(input.first, ml_value));
if (!insertion_result.second) {
LOGS(*logger, ERROR) << "SetNameMLValueMap() failed! Input name: " << input.first << " Trying to overwrite existing input value";
return protobufutil::Status(protobufutil::error::Code::INVALID_ARGUMENT, "SetNameMLValueMap() failed: Cannot have two inputs with the same name");
}
}
return protobufutil::Status::OK;
}
protobufutil::Status Executor::Predict(const std::string& model_name,
const std::string& model_version,
onnxruntime::server::PredictRequest& request,
/* out */ onnxruntime::server::PredictResponse& response) {
auto logger = env_->GetLogger(request_id_);
// Convert PredictRequest to NameMLValMap
onnxruntime::NameMLValMap name_ml_value_map{};
auto conversion_status = SetNameMLValueMap(name_ml_value_map, request);
if (conversion_status != protobufutil::Status::OK) {
return conversion_status;
}
// Prepare the output names and vector
std::vector<std::string> output_names;
if (!request.output_filter().empty()) {
output_names.reserve(request.output_filter_size());
for (const auto& name : request.output_filter()) {
output_names.push_back(name);
}
} else {
output_names = env_->GetModelOutputNames();
}
std::vector<onnxruntime::MLValue> outputs(output_names.size());
// Run
OrtRunOptions run_options{};
run_options.run_log_verbosity_level = static_cast<unsigned int>(env_->GetLogSeverity());
run_options.run_tag = request_id_;
auto status = env_->GetSession()->Run(run_options, name_ml_value_map, output_names, &outputs);
if (!status.IsOK()) {
LOGS(*logger, ERROR) << "Run() failed."
<< ". Error Message: " << status.ToString();
return GenerateProtobufStatus(status, "Run() failed: " + status.ToString());
}
// Build the response
for (size_t i = 0, sz = outputs.size(); i < sz; ++i) {
onnx::TensorProto output_tensor{};
status = MLValueToTensorProto(outputs[i], using_raw_data_, std::move(logger), output_tensor);
logger = env_->GetLogger(request_id_);
if (!status.IsOK()) {
LOGS(*logger, ERROR) << "MLValueToTensorProto() failed. Output name: " << output_names[i] << ". Error Message: " << status.ToString();
return GenerateProtobufStatus(status, "MLValueToTensorProto() failed: " + status.ToString());
}
auto insertion_result = response.mutable_outputs()->insert({output_names[i], output_tensor});
if (!insertion_result.second) {
LOGS(*logger, ERROR) << "SetNameMLValueMap() failed. Output name: " << output_names[i] << " Trying to overwrite existing output value";
return protobufutil::Status(protobufutil::error::Code::INVALID_ARGUMENT, "SetNameMLValueMap() failed: Cannot have two outputs with the same name");
}
}
return protobufutil::Status::OK;
}
} // namespace server
} // namespace onnxruntime


@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <google/protobuf/stubs/status.h>
#include "environment.h"
#include "predict.pb.h"
namespace onnxruntime {
namespace server {
class Executor {
public:
Executor(ServerEnvironment* server_env, std::string request_id) : env_(server_env),
request_id_(std::move(request_id)),
using_raw_data_(true) {}
// Prediction method
google::protobuf::util::Status Predict(const std::string& model_name,
const std::string& model_version,
onnxruntime::server::PredictRequest& request,
/* out */ onnxruntime::server::PredictResponse& response);
private:
ServerEnvironment* env_;
const std::string request_id_;
bool using_raw_data_;
google::protobuf::util::Status SetMLValue(const onnx::TensorProto& input_tensor,
OrtAllocatorInfo* cpu_allocator_info,
/* out */ MLValue& ml_value);
google::protobuf::util::Status SetNameMLValueMap(onnxruntime::NameMLValMap& name_value_map, const onnxruntime::server::PredictRequest& request);
};
} // namespace server
} // namespace onnxruntime
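A short usage sketch matching the declarations above: the handler constructs one `Executor` per request and forwards the parsed model name and version; the values shown here are placeholders.
```
#include "executor.h"

// Run a single prediction; model name and version normally come from the request URL.
google::protobuf::util::Status RunPrediction(onnxruntime::server::ServerEnvironment* env,
                                             const std::string& request_id,
                                             onnxruntime::server::PredictRequest& request,
                                             /* out */ onnxruntime::server::PredictResponse& response) {
  onnxruntime::server::Executor executor(env, request_id);
  return executor.Predict("mymodel", "3", request, response);  // placeholder name/version
}
```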


@@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
// boost random is using a deprecated header in 1.69
// See: https://github.com/boostorg/random/issues/49
#define BOOST_PENDING_INTEGER_LOG2_HPP
#include <boost/integer/integer_log2.hpp>
#include <string>
#include <boost/beast/http.hpp>
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <boost/uuid/uuid_generators.hpp>
namespace onnxruntime {
namespace server {
namespace http = boost::beast::http; // from <boost/beast/http.hpp>
// This class represents the HTTP context given to the user
// Currently, we are just giving the Boost request and response object
// But in the future we should write a wrapper around them
class HttpContext {
public:
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
http::response<http::string_body> response{};
const std::string request_id;
std::string client_request_id;
http::status error_code;
std::string error_message;
HttpContext() : request_id(boost::uuids::to_string(boost::uuids::random_generator()())),
client_request_id(""),
error_code(http::status::internal_server_error),
error_message("An unknown server error has occurred") {}
~HttpContext() = default;
HttpContext(const HttpContext&) = delete;
};
} // namespace server
} // namespace onnxruntime


@@ -0,0 +1,88 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <functional>
#include <iostream>
#include <memory>
#include <thread>
#include <vector>
#include <boost/asio.hpp>
#include <boost/beast/http.hpp>
#include "context.h"
#include "session.h"
#include "listener.h"
#include "http_server.h"
namespace http = boost::beast::http; // from <boost/beast/http.hpp>
namespace net = boost::asio; // from <boost/asio.hpp>
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
namespace onnxruntime {
namespace server {
App::App() {
http_details.address = boost::asio::ip::make_address_v4("0.0.0.0");
http_details.port = 8001;
http_details.threads = std::thread::hardware_concurrency();
}
App& App::Bind(net::ip::address address, unsigned short port) {
http_details.address = std::move(address);
http_details.port = port;
return *this;
}
App& App::NumThreads(int threads) {
http_details.threads = threads;
return *this;
}
App& App::RegisterStartup(const StartFn& on_start) {
on_start_ = on_start;
return *this;
}
App& App::RegisterPost(const std::string& route, const HandlerFn& fn) {
routes_.RegisterController(http::verb::post, route, fn);
return *this;
}
App& App::RegisterError(const ErrorFn& fn) {
routes_.RegisterErrorCallback(fn);
return *this;
}
App& App::Run() {
net::io_context ioc{http_details.threads};
// Create and launch a listening port
auto listener = std::make_shared<Listener>(routes_, ioc, tcp::endpoint{http_details.address, http_details.port});
auto initialized = listener->Init();
if (!initialized) {
exit(EXIT_FAILURE);
}
auto started = listener->Run();
if (!started) {
exit(EXIT_FAILURE);
}
// Run user on_start function
on_start_(http_details);
// Run the I/O service on the requested number of threads
std::vector<std::thread> v;
v.reserve(http_details.threads - 1);
for (auto i = http_details.threads - 1; i > 0; --i) {
v.emplace_back(
[&ioc] {
ioc.run();
});
}
ioc.run();
return *this;
}
} // namespace server
} // namespace onnxruntime


@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <vector>
#include "util.h"
#include "context.h"
#include "routes.h"
#include "session.h"
#include "listener.h"
namespace onnxruntime {
namespace server {
namespace http = beast::http; // from <boost/beast/http.hpp>
namespace net = boost::asio; // from <boost/asio.hpp>
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
struct Details {
net::ip::address address;
unsigned short port;
int threads;
};
using StartFn = std::function<void(Details&)>;
// Accepts incoming connections and launches the sessions
// Each method returns the app itself so methods can be chained
class App {
public:
App();
App& Bind(net::ip::address address, unsigned short port);
App& NumThreads(int threads);
App& RegisterStartup(const StartFn& fn);
App& RegisterPost(const std::string& route, const HandlerFn& fn);
App& RegisterError(const ErrorFn& fn);
App& Run();
private:
Routes routes_{};
StartFn on_start_ = {};
Details http_details{};
};
} // namespace server
} // namespace onnxruntime
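A sketch of how the builder above can be chained, in the spirit of what the server's main does; the port, thread count, route pattern, and handler bodies are all placeholders:
```
#include <iostream>
#include "http_server.h"

void RunServer() {
  namespace server = onnxruntime::server;
  server::App app{};
  app.Bind(boost::asio::ip::make_address_v4("0.0.0.0"), 8001)
      .NumThreads(4)
      .RegisterStartup([](const server::Details& details) {
        std::cout << "Listening on port " << details.port << std::endl;
      })
      .RegisterPost(R"(/v1/models/([^/:]+)/versions/([^/:]+):(predict))",
                    [](const std::string& name, const std::string& version,
                       const std::string& action, server::HttpContext& context) {
                      // The real handler fills context.response with the prediction result.
                      context.response.result(boost::beast::http::status::ok);
                    })
      .RegisterError([](server::HttpContext& context) {
        context.response.result(context.error_code);
        context.response.body() = context.error_message;
      })
      .Run();  // blocks, running the io_context on the configured threads
}
```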


@@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "listener.h"
#include "session.h"
#include "util.h"
namespace onnxruntime {
namespace server {
namespace net = boost::asio; // from <boost/asio.hpp>
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
Listener::Listener(const Routes& routes, net::io_context& ioc, const tcp::endpoint& endpoint)
: routes_(routes), acceptor_(ioc), socket_(ioc), endpoint_(endpoint) {
}
bool Listener::Init() {
beast::error_code ec;
// Open the acceptor
acceptor_.open(endpoint_.protocol(), ec);
if (ec) {
ErrorHandling(ec, "open");
return false;
}
// Allow address reuse
acceptor_.set_option(net::socket_base::reuse_address(true), ec);
if (ec) {
ErrorHandling(ec, "set_option");
return false;
}
// Bind to the routes address
acceptor_.bind(endpoint_, ec);
if (ec) {
ErrorHandling(ec, "bind");
return false;
}
// Start listening for connections
acceptor_.listen(
net::socket_base::max_listen_connections, ec);
if (ec) {
ErrorHandling(ec, "listen");
return false;
}
return true;
}
bool Listener::Run() {
if (!acceptor_.is_open()) {
return false;
}
DoAccept();
return true;
}
void Listener::DoAccept() {
acceptor_.async_accept(
socket_,
std::bind(
&Listener::OnAccept,
shared_from_this(),
std::placeholders::_1));
}
void Listener::OnAccept(beast::error_code ec) {
if (ec) {
ErrorHandling(ec, "accept");
} else {
std::make_shared<HttpSession>(routes_, std::move(socket_))->Run();
}
// Accept another connection
DoAccept();
}
} // namespace server
} // namespace onnxruntime


@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <memory>
#include <boost/asio/ip/tcp.hpp>
#include "routes.h"
#include "util.h"
namespace onnxruntime {
namespace server {
namespace net = boost::asio; // from <boost/asio.hpp>
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
// Listens on a socket and creates an HTTP session
class Listener : public std::enable_shared_from_this<Listener> {
Routes routes_;
tcp::acceptor acceptor_;
tcp::socket socket_;
const tcp::endpoint endpoint_;
public:
Listener(const Routes& routes, net::io_context& ioc, const tcp::endpoint& endpoint);
// Initialize the HTTP server
bool Init();
// Start accepting incoming connections
bool Run();
// Asynchronously accepts the socket
void DoAccept();
// Creates the HTTP session and runs it
void OnAccept(beast::error_code ec);
};
} // namespace server
} // namespace onnxruntime

@@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <iostream>
#include "re2/re2.h"
#include "context.h"
#include "routes.h"
namespace onnxruntime {
namespace server {
namespace http = boost::beast::http; // from <boost/beast/http.hpp>
bool Routes::RegisterController(http::verb method, const std::string& url_pattern, const HandlerFn& controller) {
if (controller == nullptr) {
return false;
}
switch (method) {
case http::verb::get:
this->get_fn_table.emplace_back(url_pattern, controller);
return true;
case http::verb::post:
this->post_fn_table.emplace_back(url_pattern, controller);
return true;
default:
return false;
}
}
bool Routes::RegisterErrorCallback(const ErrorFn& controller) {
if (controller == nullptr) {
return false;
}
on_error = controller;
return true;
}
http::status Routes::ParseUrl(http::verb method,
const std::string& url,
/* out */ std::string& model_name,
/* out */ std::string& model_version,
/* out */ std::string& action,
/* out */ HandlerFn& func) const {
std::vector<std::pair<std::string, HandlerFn>> func_table;
switch (method) {
case http::verb::get:
func_table = this->get_fn_table;
break;
case http::verb::post:
func_table = this->post_fn_table;
break;
default:
return http::status::method_not_allowed;
}
if (func_table.empty()) {
return http::status::method_not_allowed;
}
bool found_match = false;
for (const auto& pattern : func_table) {
if (re2::RE2::FullMatch(url, pattern.first, &model_name, &model_version, &action)) {
func = pattern.second;
found_match = true;
break;
}
}
if (!found_match) {
return http::status::not_found;
}
return http::status::ok;
}
}  // namespace server
} // namespace onnxruntime

@@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <boost/beast/http.hpp>
#include "context.h"
namespace onnxruntime {
namespace server {
namespace http = boost::beast::http; // from <boost/beast/http.hpp>
using HandlerFn = std::function<void(std::string&, std::string&, std::string&, HttpContext&)>;
using ErrorFn = std::function<void(HttpContext&)>;
// This class maintains two lists of regex -> function lists. One for POST requests and one for GET requests
// If the incoming URL could match more than one regex, the first one will win.
class Routes {
public:
Routes() = default;
ErrorFn on_error;
bool RegisterController(http::verb method, const std::string& url_pattern, const HandlerFn& controller);
bool RegisterErrorCallback(const ErrorFn& controller);
http::status ParseUrl(http::verb method,
const std::string& url,
/* out */ std::string& model_name,
/* out */ std::string& model_version,
/* out */ std::string& action,
/* out */ HandlerFn& func) const;
private:
std::vector<std::pair<std::string, HandlerFn>> post_fn_table;
std::vector<std::pair<std::string, HandlerFn>> get_fn_table;
};
}  // namespace server
} // namespace onnxruntime
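A brief, hypothetical sketch of driving Routes directly; the pattern and URL here are illustrative only, and the production regex is registered in main.cc further down.

```cpp
#include "routes.h"   // Routes, HandlerFn
#include "context.h"  // HttpContext

// Hypothetical usage sketch; pattern, URL, and handler body are illustrative.
void RoutesSketch() {
  using namespace onnxruntime::server;
  namespace http = boost::beast::http;

  Routes routes;
  routes.RegisterController(
      http::verb::post,
      R"(/v1/models/([^/:]+)/versions/(\d+):(predict))",
      [](std::string& model, std::string& version, std::string& action,
         HttpContext& context) {
        // run inference for (model, version) and write context.response
      });

  std::string model_name, model_version, action;
  HandlerFn handler;
  auto status = routes.ParseUrl(http::verb::post,
                                "/v1/models/mnist/versions/1:predict",
                                model_name, model_version, action, handler);
  // status == http::status::ok, model_name == "mnist",
  // model_version == "1", action == "predict"; handler is the lambda above.
}
```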

@@ -0,0 +1,153 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "session.h"
namespace onnxruntime {
namespace server {
namespace net = boost::asio; // from <boost/asio.hpp>
namespace beast = boost::beast; // from <boost/beast.hpp>
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
HttpSession::HttpSession(const Routes& routes, tcp::socket socket)
: routes_(routes), socket_(std::move(socket)), strand_(socket_.get_executor()) {
}
void HttpSession::DoRead() {
req_.emplace();
// TODO: make the max request size configurable.
req_->body_limit(10 * 1024 * 1024); // Max request size: 10 MiB
http::async_read(socket_, buffer_, *req_,
net::bind_executor(
strand_,
std::bind(
&HttpSession::OnRead,
shared_from_this(),
std::placeholders::_1,
std::placeholders::_2)));
}
void HttpSession::OnRead(beast::error_code ec, std::size_t bytes_transferred) {
boost::ignore_unused(bytes_transferred);
// This means they closed the connection
if (ec == http::error::end_of_stream) {
return DoClose();
}
if (ec) {
ErrorHandling(ec, "read");
return;
}
// Send the response
HandleRequest(req_->release());
}
void HttpSession::OnWrite(beast::error_code ec, std::size_t bytes_transferred, bool close) {
boost::ignore_unused(bytes_transferred);
if (ec) {
ErrorHandling(ec, "write");
return;
}
if (close) {
// This means we should close the connection, usually because
// the response indicated the "Connection: close" semantic.
return DoClose();
}
// We're done with the response so delete it
res_ = nullptr;
// Read another request
DoRead();
}
void HttpSession::DoClose() {
// Send a TCP shutdown
beast::error_code ec;
socket_.shutdown(tcp::socket::shutdown_send, ec);
// At this point the connection is closed gracefully
}
template <class Msg>
void HttpSession::Send(Msg&& msg) {
using item_type = std::remove_reference_t<decltype(msg)>;
auto ptr = std::make_shared<item_type>(std::move(msg));
auto self_ = shared_from_this();
self_->res_ = ptr;
http::async_write(self_->socket_, *ptr,
net::bind_executor(strand_,
[ self_, close = ptr->need_eof() ](beast::error_code ec, std::size_t bytes) {
self_->OnWrite(ec, bytes, close);
}));
}
template <typename Body, typename Allocator>
void HttpSession::HandleRequest(http::request<Body, http::basic_fields<Allocator> >&& req) {
HttpContext context{};
context.request = std::move(req);
// Special-case the liveness probe endpoint for orchestration systems like Kubernetes.
if (context.request.method() == http::verb::get && context.request.target().to_string() == "/") {
context.response.body() = "Healthy";
} else {
auto status = ExecuteUserFunction(context);
if (status != http::status::ok) {
routes_.on_error(context);
}
}
context.response.keep_alive(context.request.keep_alive());
context.response.prepare_payload();
return Send(std::move(context.response));
}
http::status HttpSession::ExecuteUserFunction(HttpContext& context) {
std::string path = context.request.target().to_string();
std::string model_name, model_version, action;
HandlerFn func;
if (context.request.find("x-ms-client-request-id") != context.request.end()) {
context.client_request_id = context.request["x-ms-client-request-id"].to_string();
}
if (path == "/score") {
// This is a shortcut since we have only one model instance currently.
// This code path will be removed once we start supporting multiple models or multiple versions of one model.
path = "/v1/models/default/versions/1:predict";
}
auto status = routes_.ParseUrl(context.request.method(), path, model_name, model_version, action, func);
if (status != http::status::ok) {
context.error_code = status;
context.error_message = std::string(http::obsolete_reason(status)) +
". For HTTP method: " +
std::string(http::to_string(context.request.method())) +
" and request path: " +
context.request.target().to_string();
return status;
}
try {
func(model_name, model_version, action, context);
} catch (const std::exception& ex) {
context.error_message = std::string(ex.what());
return http::status::internal_server_error;
}
return http::status::ok;
}
} // namespace server
} // namespace onnxruntime

@@ -0,0 +1,78 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <memory>
#include <boost/beast/version.hpp>
#include <boost/asio/bind_executor.hpp>
#include <boost/beast/core/flat_buffer.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/strand.hpp>
#include "context.h"
#include "routes.h"
#include "util.h"
namespace onnxruntime {
namespace server {
namespace net = boost::asio; // from <boost/asio.hpp>
namespace beast = boost::beast; // from <boost/beast.hpp>
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
namespace http = beast::http;
// An implementation of a single HTTP session
// Used by a listener to hand off the work and async write back to a socket
class HttpSession : public std::enable_shared_from_this<HttpSession> {
public:
HttpSession(const Routes& routes, tcp::socket socket);
// Start the asynchronous operation
// The entrypoint for the class
void Run() {
DoRead();
}
private:
const Routes routes_;
tcp::socket socket_;
net::strand<net::io_context::executor_type> strand_;
beast::flat_buffer buffer_;
boost::optional<http::request_parser<http::string_body>> req_;
std::shared_ptr<void> res_{nullptr};
// Writes the message asynchronously back to the socket
// Stores the pointer to the message and the class itself so that
// They do not get destructed before the async process is finished
// If you pass shared_from_this(), the lifetime of the object is guaranteed
// to be extended for as long as the asynchronous operation needs it
// Most examples in boost::asio are based on this logic
template <class Msg>
void Send(Msg&& msg);
// Called after the session is finished reading the message
// Should set the response before calling Send
template <typename Body, typename Allocator>
void HandleRequest(http::request<Body, http::basic_fields<Allocator>>&& req);
// Handle the request and hand it off to the user's function
// Execute user function, handle errors
// HttpContext parameter can be updated here or in HandleRequest
http::status ExecuteUserFunction(HttpContext& context);
// Asynchronously reads the request from the socket
void DoRead();
// Perform error checking before handing off to HandleRequest
void OnRead(beast::error_code ec, std::size_t bytes_transferred);
// After writing, make the session read another request
void OnWrite(beast::error_code ec, std::size_t bytes_transferred, bool close);
// Close the connection
void DoClose();
};
} // namespace server
} // namespace onnxruntime

@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <iostream>
#include <boost/beast/core.hpp>
#include "context.h"
#include "util.h"
namespace onnxruntime {
namespace server {
// Report a failure
void ErrorHandling(beast::error_code ec, char const* what) {
std::cerr << what << " failed: " << ec.value() << " : " << ec.message() << "\n";
}
} // namespace server
} // namespace onnxruntime

@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <boost/beast/core.hpp>
#include <boost/beast/http/status.hpp>
#include "context.h"
namespace onnxruntime {
namespace server {
namespace beast = boost::beast; // from <boost/beast.hpp>
// Report a failure
void ErrorHandling(beast::error_code ec, char const* what);
} // namespace server
} // namespace onnxruntime

@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <sstream>
#include <iomanip>
#include <boost/beast/core.hpp>
#include <google/protobuf/util/json_util.h>
#include "predict.pb.h"
#include "json_handling.h"
namespace protobufutil = google::protobuf::util;
namespace onnxruntime {
namespace server {
protobufutil::Status GetRequestFromJson(const std::string& json_string, /* out */ onnxruntime::server::PredictRequest& request) {
protobufutil::JsonParseOptions options;
options.ignore_unknown_fields = true;
protobufutil::Status result = JsonStringToMessage(json_string, &request, options);
return result;
}
protobufutil::Status GenerateResponseInJson(const onnxruntime::server::PredictResponse& response, /* out */ std::string& json_string) {
protobufutil::JsonPrintOptions options;
options.add_whitespace = false;
options.always_print_primitive_fields = false;
options.always_print_enums_as_ints = false;
options.preserve_proto_field_names = false;
protobufutil::Status result = MessageToJsonString(response, &json_string, options);
return result;
}
std::string CreateJsonError(const http::status error_code, const std::string& error_message) {
auto escaped_message = escape_string(error_message);
return R"({"error_code": )" + std::to_string(int(error_code)) + R"(, "error_message": ")" + escaped_message + R"("})" + "\n";
}
std::string escape_string(const std::string& message) {
std::ostringstream o;
for (char c : message) {
switch (c) {
case '"': o << "\\\""; break;
case '\\': o << "\\\\"; break;
case '\b': o << "\\b"; break;
case '\f': o << "\\f"; break;
case '\n': o << "\\n"; break;
case '\r': o << "\\r"; break;
case '\t': o << "\\t"; break;
default:
if ('\x00' <= c && c <= '\x1f') {
o << "\\u"
<< std::hex << std::setw(4) << std::setfill('0') << (int)c;
} else {
o << c;
}
}
}
return o.str();
}
} // namespace server
} // namespace onnxruntime
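A small, hypothetical illustration of the error helper above; the commented output is approximate, not captured from a run.

```cpp
#include <string>
#include "json_handling.h"

// Hypothetical illustration of CreateJsonError / escape_string above.
std::string ExampleError() {
  // http::status::bad_request serializes as 400; quotes are escaped by
  // escape_string and a trailing newline is appended. Roughly:
  // {"error_code": 400, "error_message": "Missing \"Content-Type\" header"}
  return onnxruntime::server::CreateJsonError(
      boost::beast::http::status::bad_request,
      R"(Missing "Content-Type" header)");
}
```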

@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <google/protobuf/util/json_util.h>
#include <boost/beast/http.hpp>
#include "predict.pb.h"
namespace onnxruntime {
namespace server {
namespace http = boost::beast::http;
// Deserialize Json input to PredictRequest.
// Unknown fields in the json file will be ignored.
google::protobuf::util::Status GetRequestFromJson(const std::string& json_string, /* out */ onnxruntime::server::PredictRequest& request);
// Serialize PredictResponse to json string
// 1. Proto3 primitive fields with default values will be omitted in JSON output. Eg. int32 field with value 0 will be omitted
// 2. Enums will be printed as string, not int, to improve readability
google::protobuf::util::Status GenerateResponseInJson(const onnxruntime::server::PredictResponse& response, /* out */ std::string& json_string);
// Constructs JSON error message from error code object and error message
std::string CreateJsonError(http::status error_code, const std::string& error_message);
// Escapes a string following the JSON standard
// Mostly taken from here: https://stackoverflow.com/questions/7724448/simple-json-string-escape-for-c/33799784#33799784
std::string escape_string(const std::string& message);
} // namespace server
} // namespace onnxruntime
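A short, hypothetical example of the deserialization behaviour documented above; per the proto3 JSON mapping, output_filter appears as outputFilter, and unknown fields are ignored.

```cpp
#include "json_handling.h"
#include "predict.pb.h"

// Hypothetical example: unknown JSON fields are ignored, so this parses cleanly.
bool ParseExample() {
  onnxruntime::server::PredictRequest request;
  auto status = onnxruntime::server::GetRequestFromJson(
      R"({"inputs": {}, "outputFilter": ["Plus214_Output_0"], "someUnknownField": 1})",
      request);
  return status.ok();  // true: "someUnknownField" is silently dropped
}
```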

@@ -0,0 +1,134 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <google/protobuf/stubs/status.h>
#include "environment.h"
#include "http_server.h"
#include "json_handling.h"
#include "executor.h"
#include "util.h"
namespace onnxruntime {
namespace server {
namespace protobufutil = google::protobuf::util;
#define GenerateErrorResponse(logger, error_code, message, context) \
{ \
auto http_error_code = (error_code); \
(context).response.insert("x-ms-request-id", ((context).request_id)); \
if (!(context).client_request_id.empty()) { \
(context).response.insert("x-ms-client-request-id", (context).client_request_id); \
} \
auto json_error_message = CreateJsonError(http_error_code, (message)); \
LOGS((*logger), VERBOSE) << json_error_message; \
(context).response.result(http_error_code); \
(context).response.body() = json_error_message; \
(context).response.set(http::field::content_type, "application/json"); \
}
static bool ParseRequestPayload(const HttpContext& context, SupportedContentType request_type,
/* out */ PredictRequest& predictRequest, /* out */ http::status& error_code, /* out */ std::string& error_message);
void Predict(const std::string& name,
const std::string& version,
const std::string& action,
/* in, out */ HttpContext& context,
const std::shared_ptr<ServerEnvironment>& env) {
auto logger = env->GetLogger(context.request_id);
LOGS(*logger, INFO) << "Model Name: " << name << ", Version: " << version << ", Action: " << action;
if (!context.client_request_id.empty()) {
LOGS(*logger, INFO) << "x-ms-client-request-id: [" << context.client_request_id << "]";
}
// Request and Response content type information
SupportedContentType request_type = GetRequestContentType(context);
SupportedContentType response_type = GetResponseContentType(context);
if (response_type == SupportedContentType::Unknown) {
GenerateErrorResponse(logger, http::status::bad_request, "Unknown 'Accept' header field in the request", context);
return;
}
// Deserialize the payload
auto body = context.request.body();
PredictRequest predict_request{};
http::status error_code;
std::string error_message;
bool parse_succeeded = ParseRequestPayload(context, request_type, predict_request, error_code, error_message);
if (!parse_succeeded) {
GenerateErrorResponse(logger, error_code, error_message, context);
return;
}
// Run Prediction
protobufutil::Status status;
Executor executor(env.get(), context.request_id);
PredictResponse predict_response{};
status = executor.Predict(name, version, predict_request, predict_response);
if (!status.ok()) {
GenerateErrorResponse(logger, GetHttpStatusCode((status)), status.error_message(), context);
return;
}
// Serialize to proper output format
std::string response_body{};
if (response_type == SupportedContentType::Json) {
status = GenerateResponseInJson(predict_response, response_body);
if (!status.ok()) {
GenerateErrorResponse(logger, http::status::internal_server_error, status.error_message(), context);
return;
}
context.response.set(http::field::content_type, "application/json");
} else {
response_body = predict_response.SerializeAsString();
if (context.request.find("Accept") != context.request.end() && context.request["Accept"] != "*/*") {
context.response.set(http::field::content_type, context.request["Accept"].to_string());
} else {
context.response.set(http::field::content_type, "application/octet-stream");
}
}
// Build HTTP response
context.response.insert("x-ms-request-id", context.request_id);
if (!context.client_request_id.empty()) {
context.response.insert("x-ms-client-request-id", context.client_request_id);
}
context.response.body() = response_body;
context.response.result(http::status::ok);
}
static bool ParseRequestPayload(const HttpContext& context, SupportedContentType request_type, PredictRequest& predictRequest, http::status& error_code, std::string& error_message) {
auto body = context.request.body();
protobufutil::Status status;
switch (request_type) {
case SupportedContentType::Json: {
status = GetRequestFromJson(body, predictRequest);
if (!status.ok()) {
error_code = GetHttpStatusCode(status);
error_message = status.error_message();
return false;
}
break;
}
case SupportedContentType::PbByteArray: {
bool parse_succeeded = predictRequest.ParseFromArray(body.data(), static_cast<int>(body.size()));
if (!parse_succeeded) {
error_code = http::status::bad_request;
error_message = "Invalid payload.";
return false;
}
break;
}
default: {
error_code = http::status::bad_request;
error_message = "Missing or unknown 'Content-Type' header field in the request";
return false;
}
}
return true;
}
} // namespace server
} // namespace onnxruntime

@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "http_server.h"
#include "json_handling.h"
namespace onnxruntime {
namespace server {
namespace beast = boost::beast;
namespace http = beast::http;
void BadRequest(HttpContext& context, const std::string& error_message);
// TODO: decide whether this should be a class
void Predict(const std::string& name,
const std::string& version,
const std::string& action,
/* in, out */ HttpContext& context,
const std::shared_ptr<ServerEnvironment>& env);
} // namespace server
} // namespace onnxruntime

@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <unordered_set>
#include <boost/beast/core.hpp>
#include <boost/beast/http/status.hpp>
#include <google/protobuf/stubs/status.h>
#include "context.h"
#include "util.h"
namespace protobufutil = google::protobuf::util;
namespace onnxruntime {
namespace server {
static std::unordered_set<std::string> protobuf_mime_types{
"application/octet-stream",
"application/vnd.google.protobuf",
"application/x-protobuf"};
boost::beast::http::status GetHttpStatusCode(const protobufutil::Status& status) {
switch (status.error_code()) {
case protobufutil::error::Code::OK:
return boost::beast::http::status::ok;
case protobufutil::error::Code::UNKNOWN:
case protobufutil::error::Code::DEADLINE_EXCEEDED:
case protobufutil::error::Code::RESOURCE_EXHAUSTED:
case protobufutil::error::Code::ABORTED:
case protobufutil::error::Code::UNIMPLEMENTED:
case protobufutil::error::Code::INTERNAL:
case protobufutil::error::Code::UNAVAILABLE:
case protobufutil::error::Code::DATA_LOSS:
return boost::beast::http::status::internal_server_error;
case protobufutil::error::Code::CANCELLED:
case protobufutil::error::Code::INVALID_ARGUMENT:
case protobufutil::error::Code::ALREADY_EXISTS:
case protobufutil::error::Code::FAILED_PRECONDITION:
case protobufutil::error::Code::OUT_OF_RANGE:
return boost::beast::http::status::bad_request;
case protobufutil::error::Code::NOT_FOUND:
return boost::beast::http::status::not_found;
case protobufutil::error::Code::PERMISSION_DENIED:
return boost::beast::http::status::forbidden;
case protobufutil::error::Code::UNAUTHENTICATED:
return boost::beast::http::status::unauthorized;
default:
return boost::beast::http::status::internal_server_error;
}
}
SupportedContentType GetRequestContentType(const HttpContext& context) {
if (context.request.find("Content-Type") != context.request.end()) {
if (context.request["Content-Type"] == "application/json") {
return SupportedContentType::Json;
} else if (protobuf_mime_types.find(context.request["Content-Type"].to_string()) != protobuf_mime_types.end()) {
return SupportedContentType::PbByteArray;
}
}
return SupportedContentType::Unknown;
}
SupportedContentType GetResponseContentType(const HttpContext& context) {
if (context.request.find("Accept") != context.request.end()) {
if (context.request["Accept"] == "application/json") {
return SupportedContentType::Json;
} else if (context.request["Accept"] == "*/*" || protobuf_mime_types.find(context.request["Accept"].to_string()) != protobuf_mime_types.end()) {
return SupportedContentType::PbByteArray;
}
} else {
return SupportedContentType::PbByteArray;
}
return SupportedContentType::Unknown;
}
} // namespace server
} // namespace onnxruntime

@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <boost/beast/core.hpp>
#include <boost/beast/http/status.hpp>
#include <google/protobuf/stubs/status.h>
#include "server/http/core/context.h"
namespace onnxruntime {
namespace server {
namespace beast = boost::beast; // from <boost/beast.hpp>
enum class SupportedContentType : int {
Unknown,
Json,
PbByteArray
};
// Mapping protobuf status to http status
boost::beast::http::status GetHttpStatusCode(const google::protobuf::util::Status& status);
// "Content-Type" header field in request is MUST-HAVE.
// Currently we only support two types of input content type: application/json and application/octet-stream
SupportedContentType GetRequestContentType(const HttpContext& context);
// "Accept" header field in request is OPTIONAL.
// Currently we only support three types of response content type: */*, application/json and application/octet-stream
SupportedContentType GetResponseContentType(const HttpContext& context);
} // namespace server
} // namespace onnxruntime
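A hypothetical sketch of these negotiation rules; the header values and include paths are illustrative.

```cpp
#include <boost/beast/http.hpp>
#include "server/http/core/context.h"
#include "util.h"  // the header above declaring the content-type helpers

// Hypothetical sketch of the content negotiation helpers above.
void NegotiationSketch() {
  using namespace onnxruntime::server;
  HttpContext context{};
  context.request.set(boost::beast::http::field::content_type, "application/json");
  // No "Accept" header was set, so the response defaults to the protobuf types.
  auto request_type = GetRequestContentType(context);    // SupportedContentType::Json
  auto response_type = GetResponseContentType(context);  // SupportedContentType::PbByteArray
  (void)request_type;
  (void)response_type;
}
```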

@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <iostream>
#include "core/common/logging/logging.h"
#include "core/common/logging/sinks/ostream_sink.h"
namespace onnxruntime {
namespace server {
class LogSink : public onnxruntime::logging::OStreamSink {
public:
LogSink() : OStreamSink(std::cout, /*flush*/ true) {
}
};
} // namespace server
} // namespace onnxruntime

@@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "environment.h"
#include "http_server.h"
#include "predict_request_handler.h"
#include "server_configuration.h"
namespace beast = boost::beast;
namespace http = beast::http;
namespace server = onnxruntime::server;
int main(int argc, char* argv[]) {
server::ServerConfiguration config{};
auto res = config.ParseInput(argc, argv);
if (res == server::Result::ExitSuccess) {
exit(EXIT_SUCCESS);
} else if (res == server::Result::ExitFailure) {
exit(EXIT_FAILURE);
}
const auto env = std::make_shared<server::ServerEnvironment>(config.logging_level);
auto logger = env->GetAppLogger();
LOGS(logger, VERBOSE) << "Logging manager initialized.";
LOGS(logger, INFO) << "Model path: " << config.model_path;
auto status = env->InitializeModel(config.model_path);
if (!status.IsOK()) {
LOGS(logger, FATAL) << "Initialize Model Failed: " << status.Code() << " ---- Error: [" << status.ErrorMessage() << "]";
exit(EXIT_FAILURE);
} else {
LOGS(logger, VERBOSE) << "Initialize Model Successfully!";
}
status = env->GetSession()->Initialize();
if (!status.IsOK()) {
LOGS(logger, FATAL) << "Session Initialization Failed:" << status.Code() << " ---- Error: [" << status.ErrorMessage() << "]";
exit(EXIT_FAILURE);
} else {
LOGS(logger, VERBOSE) << "Initialize Session Successfully!";
}
auto const boost_address = boost::asio::ip::make_address(config.address);
server::App app{};
app.RegisterStartup(
[&env](const auto& details) -> void {
auto logger = env->GetAppLogger();
LOGS(logger, INFO) << "Listening at: "
<< "http://" << details.address << ":" << details.port;
});
app.RegisterError(
[&env](auto& context) -> void {
auto logger = env->GetLogger(context.request_id);
LOGS(*logger, VERBOSE) << "Error code: " << context.error_code;
LOGS(*logger, VERBOSE) << "Error message: " << context.error_message;
context.response.result(context.error_code);
context.response.insert("Content-Type", "application/json");
context.response.insert("x-ms-request-id", context.request_id);
if (!context.client_request_id.empty()) {
context.response.insert("x-ms-client-request-id", (context).client_request_id);
}
context.response.body() = server::CreateJsonError(context.error_code, context.error_message);
});
app.RegisterPost(
R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))",
[&env](const auto& name, const auto& version, const auto& action, auto& context) -> void {
server::Predict(name, version, action, context, env);
});
app.Bind(boost_address, config.http_port)
.NumThreads(config.num_http_threads)
.Run();
return EXIT_SUCCESS;
}

@@ -0,0 +1 @@
../../core/protobuf/onnx-ml.proto3

@@ -0,0 +1,27 @@
syntax = "proto3";
import "onnx-ml.proto";
package onnxruntime.server;
// PredictRequest specifies how inputs are mapped to tensors
// and how outputs are filtered before returning to user.
message PredictRequest {
reserved 1;
// Input Tensors.
// This is a mapping between output name and tensor.
map<string, onnx.TensorProto> inputs = 2;
// Output Filters.
// This field is to specify which output fields need to be returned.
// If the list is empty, all outputs will be included.
repeated string output_filter = 3;
}
// Response for PredictRequest on successful run.
message PredictResponse {
// Output Tensors.
// This is a mapping between output name and tensor.
map<string, onnx.TensorProto> outputs = 1;
}
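For orientation, a hypothetical C++ sketch of filling this message through the generated protobuf API; the tensor and output names are placeholders borrowed from the MNIST tests, and the generated header names are assumptions.

```cpp
#include "predict.pb.h"
#include "onnx-ml.pb.h"  // assumed name of the generated ONNX tensor header

// Hypothetical sketch: build a PredictRequest for a single-input model.
onnxruntime::server::PredictRequest MakeRequest(const onnx::TensorProto& input_tensor) {
  onnxruntime::server::PredictRequest request;
  // inputs is a map<string, onnx.TensorProto>; keys are the model's input names.
  (*request.mutable_inputs())["Input3"] = input_tensor;
  // output_filter restricts which outputs come back; empty means "all outputs".
  request.add_output_filter("Plus214_Output_0");
  return request;
}
```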

@@ -0,0 +1,128 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <thread>
#include <fstream>
#include <unordered_map>
#include "boost/program_options.hpp"
#include "core/common/logging/logging.h"
namespace onnxruntime {
namespace server {
namespace po = boost::program_options;
// Enumerates the different types of results which can occur
// The three types are:
// 0. ExitSuccess: the program should exit with EXIT_SUCCESS
// 1. ExitFailure: the program should exit with EXIT_FAILURE
// 2. ContinueSuccess: no exit is needed and the program should continue
enum class Result {
ExitSuccess,
ExitFailure,
ContinueSuccess
};
static std::unordered_map<std::string, onnxruntime::logging::Severity> supported_log_levels{
{"verbose", onnxruntime::logging::Severity::kVERBOSE},
{"info", onnxruntime::logging::Severity::kINFO},
{"warning", onnxruntime::logging::Severity::kWARNING},
{"error", onnxruntime::logging::Severity::kERROR},
{"fatal", onnxruntime::logging::Severity::kFATAL}};
// Wrapper around Boost program_options that provides all of the functionality for options parsing
// and supplies sane default values
class ServerConfiguration {
public:
const std::string full_desc = "ONNX Server: host an ONNX model with ONNX Runtime";
std::string model_path;
std::string address = "0.0.0.0";
unsigned short http_port = 8001;
int num_http_threads = std::thread::hardware_concurrency();
onnxruntime::logging::Severity logging_level{};
ServerConfiguration() {
desc.add_options()("help,h", "Shows a help message and exits");
desc.add_options()("log_level", po::value(&log_level_str)->default_value(log_level_str), "Logging level. Allowed options (case sensitive): verbose, info, warning, error, fatal");
desc.add_options()("model_path", po::value(&model_path)->required(), "Path to ONNX model");
desc.add_options()("address", po::value(&address)->default_value(address), "The base HTTP address");
desc.add_options()("http_port", po::value(&http_port)->default_value(http_port), "HTTP port to listen to requests");
desc.add_options()("num_http_threads", po::value(&num_http_threads)->default_value(num_http_threads), "Number of http threads");
}
// Parses argc and argv and sets the values for the class
// Returns an enum with three options: ExitSuccess, ExitFailure, ContinueSuccess
// ExitSuccess and ExitFailure means the program should exit but is left to the caller
Result ParseInput(int argc, char** argv) {
try {
po::store(po::command_line_parser(argc, argv).options(desc).run(), vm); // can throw
if (ContainsHelp()) {
PrintHelp(std::cout, full_desc);
return Result::ExitSuccess;
}
po::notify(vm); // throws on error, so do after help
} catch (const po::error& e) {
PrintHelp(std::cerr, e.what());
return Result::ExitFailure;
} catch (const std::exception& e) {
PrintHelp(std::cerr, e.what());
return Result::ExitFailure;
}
Result result = ValidateOptions();
if (result == Result::ContinueSuccess) {
logging_level = supported_log_levels[log_level_str];
}
return result;
}
private:
po::options_description desc{"Allowed options"};
po::variables_map vm{};
std::string log_level_str = "info";
// Print help and return if there is a bad value
Result ValidateOptions() {
if (vm.count("log_level") &&
supported_log_levels.find(log_level_str) == supported_log_levels.end()) {
PrintHelp(std::cerr, "log_level must be one of verbose, info, warning, error, or fatal");
return Result::ExitFailure;
} else if (num_http_threads <= 0) {
PrintHelp(std::cerr, "num_http_threads must be greater than 0");
return Result::ExitFailure;
} else if (!file_exists(model_path)) {
PrintHelp(std::cerr, "model_path must be the location of a valid file");
return Result::ExitFailure;
} else {
return Result::ContinueSuccess;
}
}
// Checks if program options contains help
bool ContainsHelp() const {
return vm.count("help") || vm.count("h");
}
// Prints a helpful message (param: what) to the user and then the program options
// Example: config.PrintHelp(std::cout, "Non-negative values not allowed")
// Which will print that message and then all publicly available options
void PrintHelp(std::ostream& out, const std::string& what) const {
out << what << std::endl
<< desc << std::endl;
}
inline bool file_exists(const std::string& fileName) {
std::ifstream infile(fileName.c_str());
return infile.good();
}
};
} // namespace server
} // namespace onnxruntime

@@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <sstream>
#include <google/protobuf/stubs/status.h>
#include "core/common/status.h"
#include "util.h"
namespace onnxruntime {
namespace server {
namespace protobufutil = google::protobuf::util;
protobufutil::Status GenerateProtobufStatus(const onnxruntime::common::Status& onnx_status, const std::string& message) {
protobufutil::error::Code code = protobufutil::error::Code::UNKNOWN;
switch (onnx_status.Code()) {
case onnxruntime::common::StatusCode::OK:
case onnxruntime::common::StatusCode::MODEL_LOADED:
code = protobufutil::error::Code::OK;
break;
case onnxruntime::common::StatusCode::INVALID_ARGUMENT:
case onnxruntime::common::StatusCode::INVALID_PROTOBUF:
case onnxruntime::common::StatusCode::INVALID_GRAPH:
case onnxruntime::common::StatusCode::SHAPE_INFERENCE_NOT_REGISTERED:
case onnxruntime::common::StatusCode::REQUIREMENT_NOT_REGISTERED:
case onnxruntime::common::StatusCode::NO_SUCHFILE:
case onnxruntime::common::StatusCode::NO_MODEL:
code = protobufutil::error::Code::INVALID_ARGUMENT;
break;
case onnxruntime::common::StatusCode::NOT_IMPLEMENTED:
code = protobufutil::error::Code::UNIMPLEMENTED;
break;
case onnxruntime::common::StatusCode::FAIL:
case onnxruntime::common::StatusCode::RUNTIME_EXCEPTION:
code = protobufutil::error::Code::INTERNAL;
break;
default:
code = protobufutil::error::Code::UNKNOWN;
}
std::ostringstream oss;
oss << "ONNX Runtime Status Code: " << onnx_status.Code() << ". " << message;
return protobufutil::Status(code, oss.str());
}
} // namespace server
} // namespace onnxruntime

@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <google/protobuf/stubs/status.h>
#include "core/common/status.h"
namespace onnxruntime {
namespace server {
// Generate protobuf status from ONNX Runtime status
google::protobuf::util::Status GenerateProtobufStatus(const onnxruntime::common::Status& onnx_status, const std::string& message);
} // namespace server
} // namespace onnxruntime

@@ -0,0 +1,43 @@
# ONNX Runtime Server Integration Tests
## Preparation
Test validation depends on the protobuf-generated `*_pb2.py` modules, so a successful server application build is required to generate them in the `server_test` subfolder of the build folder. The following instructions assume you are in that folder; otherwise, the tests will fail with a `ModuleNotFoundError`.
## Functional Tests
Functional tests run when building with `--build_server --enable_server_tests`. To run them separately, use the following command line:
```Bash
/usr/bin/python3 ./test_main.py <server_app_path> <mnist_model_path> <test_data_path>
```
## Model Zoo Tests
To run this set of tests, a prepared test data set needs to be downloaded from [Azure Blob Storage](https://onnxserverdev.blob.core.windows.net/testing/server_test_data_20190422.zip) and unzipped to a folder, e.g. /home/foo/bar/model_zoo_test. It contains:
* ONNX models from [ONNX Model Zoo](https://github.com/onnx/models) with opset 7/8/9.
* HTTP request json and protobuf files
* Expected response json and protobuf files
If you only need the request and response data, it can be downloaded from this [link](https://onnxserverdev.blob.core.windows.net/testing/server_test_data_req_resp_only.zip).
To run the full model zoo tests, use the following command line:
```Bash
/usr/bin/python3 ./model_zoo_tests.py <server_app_path> <model_path> <test_data_path>
```
For example:
```Bash
/usr/bin/python3 ./model_zoo_tests.py /some/where/server_app /home/foo/bar/model_zoo_test /home/foo/bar/model_zoo_test
```
If the models are in a different folder but follow the same structure as the test data, you can also run:
```Bash
/usr/bin/python3 ./model_zoo_tests.py /some/where/server_app /home/my/models/ /home/foo/bar/model_zoo_test/
```
All tests run in sequential order.

@@ -0,0 +1,363 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import unittest
import subprocess
import time
import os
import requests
import json
import numpy
import test_util
import onnx_ml_pb2
import predict_pb2
class HttpJsonPayloadTests(unittest.TestCase):
server_ip = '127.0.0.1'
server_port = 54321
url_pattern = 'http://{0}:{1}/v1/models/{2}/versions/{3}:predict'
server_app_path = ''
test_data_path = ''
model_path = ''
log_level = 'verbose'
server_app_proc = None
wait_server_ready_in_seconds = 1
@classmethod
def setUpClass(cls):
onnx_model = os.path.join(cls.model_path, 'mnist.onnx')
test_util.prepare_mnist_model(onnx_model)
cmd = [cls.server_app_path, '--http_port', str(cls.server_port), '--model_path', onnx_model, '--log_level', cls.log_level]
test_util.test_log('Launching server app: [{0}]'.format(' '.join(cmd)))
cls.server_app_proc = subprocess.Popen(cmd)
test_util.test_log('Server app PID: {0}'.format(cls.server_app_proc.pid))
test_util.test_log('Sleep {0} second(s) to wait for server initialization'.format(cls.wait_server_ready_in_seconds))
time.sleep(cls.wait_server_ready_in_seconds)
@classmethod
def tearDownClass(cls):
test_util.test_log('Shutdown server app')
cls.server_app_proc.kill()
test_util.test_log('PID {0} has been killed: {1}'.format(cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid)))
def test_mnist_happy_path(self):
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json')
output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.json')
with open(input_data_file, 'r') as f:
request_payload = f.read()
with open(output_data_file, 'r') as f:
expected_response_json = f.read()
expected_response = json.loads(expected_response_json)
request_headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
'x-ms-client-request-id': 'This~is~my~id'
}
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
test_util.test_log(url)
r = requests.post(url, headers=request_headers, data=request_payload)
self.assertEqual(r.status_code, 200)
self.assertEqual(r.headers.get('Content-Type'), 'application/json')
self.assertTrue(r.headers.get('x-ms-request-id'))
self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id')
actual_response = json.loads(r.content.decode('utf-8'))
# Note:
# The 'dims' field is defined as "repeated int64" in protobuf.
# When it is serialized to JSON, all int64/fixed64/uint64 numbers are converted to string
# Reference: https://developers.google.com/protocol-buffers/docs/proto3#json
self.assertTrue(actual_response['outputs'])
self.assertTrue(actual_response['outputs']['Plus214_Output_0'])
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dims'])
self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dims'], ['1', '10'])
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dataType'])
self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dataType'], 1)
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['rawData'])
actual_data = test_util.decode_base64_string(actual_response['outputs']['Plus214_Output_0']['rawData'], '10f')
expected_data = test_util.decode_base64_string(expected_response['outputs']['Plus214_Output_0']['rawData'], '10f')
for i in range(0, 10):
self.assertTrue(test_util.compare_floats(actual_data[i], expected_data[i]))
def test_mnist_invalid_url(self):
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', -1)
test_util.test_log(url)
request_headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}
r = requests.post(url, headers=request_headers, data={'foo': 'bar'})
self.assertEqual(r.status_code, 404)
self.assertEqual(r.headers.get('Content-Type'), 'application/json')
self.assertTrue(r.headers.get('x-ms-request-id'))
def test_mnist_invalid_content_type(self):
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json')
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
test_util.test_log(url)
request_headers = {
'Content-Type': 'application/abc',
'Accept': 'application/json',
'x-ms-client-request-id': 'This~is~my~id'
}
with open(input_data_file, 'r') as f:
request_payload = f.read()
r = requests.post(url, headers=request_headers, data=request_payload)
self.assertEqual(r.status_code, 400)
self.assertEqual(r.headers.get('Content-Type'), 'application/json')
self.assertTrue(r.headers.get('x-ms-request-id'))
self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id')
self.assertEqual(r.content.decode('utf-8'), '{"error_code": 400, "error_message": "Missing or unknown \'Content-Type\' header field in the request"}\n')
def test_mnist_missing_content_type(self):
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json')
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
test_util.test_log(url)
request_headers = {
'Accept': 'application/json'
}
with open(input_data_file, 'r') as f:
request_payload = f.read()
r = requests.post(url, headers=request_headers, data=request_payload)
self.assertEqual(r.status_code, 400)
self.assertEqual(r.headers.get('Content-Type'), 'application/json')
self.assertTrue(r.headers.get('x-ms-request-id'))
self.assertEqual(r.content.decode('utf-8'), '{"error_code": 400, "error_message": "Missing or unknown \'Content-Type\' header field in the request"}\n')
def test_single_model_shortcut(self):
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json')
output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.json')
with open(input_data_file, 'r') as f:
request_payload = f.read()
with open(output_data_file, 'r') as f:
expected_response_json = f.read()
expected_response = json.loads(expected_response_json)
request_headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
'x-ms-client-request-id': 'This~is~my~id'
}
url = "http://{0}:{1}/score".format(self.server_ip, self.server_port)
test_util.test_log(url)
r = requests.post(url, headers=request_headers, data=request_payload)
self.assertEqual(r.status_code, 200)
self.assertEqual(r.headers.get('Content-Type'), 'application/json')
self.assertTrue(r.headers.get('x-ms-request-id'))
self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id')
actual_response = json.loads(r.content.decode('utf-8'))
# Note:
# The 'dims' field is defined as "repeated int64" in protobuf.
# When it is serialized to JSON, all int64/fixed64/uint64 numbers are converted to string
# Reference: https://developers.google.com/protocol-buffers/docs/proto3#json
self.assertTrue(actual_response['outputs'])
self.assertTrue(actual_response['outputs']['Plus214_Output_0'])
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dims'])
self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dims'], ['1', '10'])
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dataType'])
self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dataType'], 1)
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['rawData'])
actual_data = test_util.decode_base64_string(actual_response['outputs']['Plus214_Output_0']['rawData'], '10f')
expected_data = test_util.decode_base64_string(expected_response['outputs']['Plus214_Output_0']['rawData'], '10f')
for i in range(0, 10):
self.assertTrue(test_util.compare_floats(actual_data[i], expected_data[i]))
class HttpProtobufPayloadTests(unittest.TestCase):
server_ip = '127.0.0.1'
server_port = 54321
url_pattern = 'http://{0}:{1}/v1/models/{2}/versions/{3}:predict'
server_app_path = ''
test_data_path = ''
model_path = ''
log_level = 'verbose'
server_app_proc = None
wait_server_ready_in_seconds = 1
@classmethod
def setUpClass(cls):
onnx_model = os.path.join(cls.model_path, 'mnist.onnx')
test_util.prepare_mnist_model(onnx_model)
cmd = [cls.server_app_path, '--http_port', str(cls.server_port), '--model_path', onnx_model, '--log_level', cls.log_level]
test_util.test_log('Launching server app: [{0}]'.format(' '.join(cmd)))
cls.server_app_proc = subprocess.Popen(cmd)
test_util.test_log('Server app PID: {0}'.format(cls.server_app_proc.pid))
test_util.test_log('Sleep {0} second(s) to wait for server initialization'.format(cls.wait_server_ready_in_seconds))
time.sleep(cls.wait_server_ready_in_seconds)
@classmethod
def tearDownClass(cls):
test_util.test_log('Shutdown server app')
cls.server_app_proc.kill()
test_util.test_log('PID {0} has been killed: {1}'.format(cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid)))
def test_mnist_happy_path(self):
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb')
output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.pb')
with open(input_data_file, 'rb') as f:
request_payload = f.read()
content_type_headers = ['application/x-protobuf', 'application/octet-stream', 'application/vnd.google.protobuf']
for h in content_type_headers:
request_headers = {
'Content-Type': h,
'Accept': 'application/x-protobuf'
}
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
test_util.test_log(url)
r = requests.post(url, headers=request_headers, data=request_payload)
self.assertEqual(r.status_code, 200)
self.assertEqual(r.headers.get('Content-Type'), 'application/x-protobuf')
self.assertTrue(r.headers.get('x-ms-request-id'))
actual_result = predict_pb2.PredictResponse()
actual_result.ParseFromString(r.content)
expected_result = predict_pb2.PredictResponse()
with open(output_data_file, 'rb') as f:
expected_result.ParseFromString(f.read())
for k in expected_result.outputs.keys():
self.assertEqual(actual_result.outputs[k].data_type, expected_result.outputs[k].data_type)
count = 1
for i in range(0, len(expected_result.outputs['Plus214_Output_0'].dims)):
self.assertEqual(actual_result.outputs['Plus214_Output_0'].dims[i], expected_result.outputs['Plus214_Output_0'].dims[i])
count = count * int(actual_result.outputs['Plus214_Output_0'].dims[i])
actual_array = numpy.frombuffer(actual_result.outputs['Plus214_Output_0'].raw_data, dtype=numpy.float32)
expected_array = numpy.frombuffer(expected_result.outputs['Plus214_Output_0'].raw_data, dtype=numpy.float32)
self.assertEqual(len(actual_array), len(expected_array))
self.assertEqual(len(actual_array), count)
for i in range(0, count):
self.assertTrue(test_util.compare_floats(actual_array[i], expected_array[i], rel_tol=0.001))
def test_respect_accept_header(self):
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb')
with open(input_data_file, 'rb') as f:
request_payload = f.read()
accept_headers = ['application/x-protobuf', 'application/octet-stream', 'application/vnd.google.protobuf']
for h in accept_headers:
request_headers = {
'Content-Type': 'application/x-protobuf',
'Accept': h
}
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
test_util.test_log(url)
r = requests.post(url, headers=request_headers, data=request_payload)
self.assertEqual(r.status_code, 200)
self.assertEqual(r.headers.get('Content-Type'), h)
def test_missing_accept_header(self):
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb')
with open(input_data_file, 'rb') as f:
request_payload = f.read()
request_headers = {
'Content-Type': 'application/x-protobuf',
}
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
test_util.test_log(url)
r = requests.post(url, headers=request_headers, data=request_payload)
self.assertEqual(r.status_code, 200)
self.assertEqual(r.headers.get('Content-Type'), 'application/octet-stream')
def test_any_accept_header(self):
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb')
with open(input_data_file, 'rb') as f:
request_payload = f.read()
request_headers = {
'Content-Type': 'application/x-protobuf',
'Accept': '*/*'
}
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
test_util.test_log(url)
r = requests.post(url, headers=request_headers, data=request_payload)
self.assertEqual(r.status_code, 200)
self.assertEqual(r.headers.get('Content-Type'), 'application/octet-stream')
class HttpEndpointTests(unittest.TestCase):
server_ip = '127.0.0.1'
server_port = 54321
server_app_path = ''
test_data_path = ''
model_path = ''
log_level = 'verbose'
server_app_proc = None
wait_server_ready_in_seconds = 1
@classmethod
def setUpClass(cls):
onnx_model = os.path.join(cls.model_path, 'mnist.onnx')
test_util.prepare_mnist_model(onnx_model)
cmd = [cls.server_app_path, '--http_port', str(cls.server_port), '--model_path', onnx_model, '--log_level', cls.log_level]
test_util.test_log('Launching server app: [{0}]'.format(' '.join(cmd)))
cls.server_app_proc = subprocess.Popen(cmd)
test_util.test_log('Server app PID: {0}'.format(cls.server_app_proc.pid))
test_util.test_log('Sleep {0} second(s) to wait for server initialization'.format(cls.wait_server_ready_in_seconds))
time.sleep(cls.wait_server_ready_in_seconds)
@classmethod
def tearDownClass(cls):
test_util.test_log('Shutdown server app')
cls.server_app_proc.kill()
test_util.test_log('PID {0} has been killed: {1}'.format(cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid)))
def test_health_endpoint(self):
url = "http://{0}:{1}/".format(self.server_ip, self.server_port)
test_util.test_log(url)
r = requests.get(url)
self.assertEqual(r.status_code, 200)
self.assertEqual(r.content.decode('utf-8'), 'Healthy')
if __name__ == '__main__':
unittest.main()

@@ -0,0 +1,120 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import sys
import shutil
import onnx
import onnxruntime
import json
from google.protobuf.json_format import MessageToJson
import predict_pb2
import onnx_ml_pb2
# Current models only have one input and one output
def get_io_name(model_file_name):
sess = onnxruntime.InferenceSession(model_file_name)
return sess.get_inputs()[0].name, sess.get_outputs()[0].name
def gen_input_pb(pb_full_path, input_name, output_name, request_file_path):
t = onnx_ml_pb2.TensorProto()
with open(pb_full_path, 'rb') as fin:
t.ParseFromString(fin.read())
predict_request = predict_pb2.PredictRequest()
predict_request.inputs[input_name].CopyFrom(t)
predict_request.output_filter.append(output_name)
with open(request_file_path, "wb") as fout:
fout.write(predict_request.SerializeToString())
def gen_output_pb(pb_full_path, output_name, response_file_path):
t = onnx_ml_pb2.TensorProto()
with open(pb_full_path, 'rb') as fin:
t.ParseFromString(fin.read())
predict_response = predict_pb2.PredictResponse()
predict_response.outputs[output_name].CopyFrom(t)
with open(response_file_path, "wb") as fout:
fout.write(predict_response.SerializeToString())
def tensor2dict(full_path):
t = onnx.TensorProto()
with open(full_path, 'rb') as f:
t.ParseFromString(f.read())
jsonStr = MessageToJson(t, use_integers_for_enums=True)
data = json.loads(jsonStr)
return data
def gen_input_json(pb_full_path, input_name, output_name, json_file_path):
data = tensor2dict(pb_full_path)
inputs = {}
inputs[input_name] = data
output_filters = [ output_name ]
req = {}
req["inputs"] = inputs
req["outputFilter"] = output_filters
with open(json_file_path, 'w') as outfile:
json.dump(req, outfile)
def gen_output_json(pb_full_path, output_name, json_file_path):
data = tensor2dict(pb_full_path)
output = {}
output[output_name] = data
resp = {}
resp["outputs"] = output
with open(json_file_path, 'w') as outfile:
json.dump(resp, outfile)
def gen_req_resp(model_zoo, test_data, copy_model=True):
opsets = [name for name in os.listdir(model_zoo) if os.path.isdir(os.path.join(model_zoo, name))]
for opset in opsets:
os.makedirs(os.path.join(test_data, opset), exist_ok=True)
current_model_folder = os.path.join(model_zoo, opset)
current_data_folder = os.path.join(test_data, opset)
models = [name for name in os.listdir(current_model_folder) if os.path.isdir(os.path.join(current_model_folder, name))]
for model in models:
os.makedirs(os.path.join(current_data_folder, model), exist_ok=True)
src_folder = os.path.join(current_model_folder, model)
dst_folder = os.path.join(current_data_folder, model)
if copy_model:
shutil.copy2(os.path.join(src_folder, 'model.onnx'), dst_folder)
iname, oname = get_io_name(os.path.join(src_folder, 'model.onnx'))
model_test_data = [name for name in os.listdir(src_folder) if os.path.isdir(os.path.join(src_folder, name))]
for test in model_test_data:
src = os.path.join(src_folder, test)
dst = os.path.join(dst_folder, test)
os.makedirs(dst, exist_ok=True)
gen_input_json(os.path.join(src, 'input_0.pb'), iname, oname, os.path.join(dst, 'request.json'))
gen_output_json(os.path.join(src, 'output_0.pb'), oname, os.path.join(dst, 'response.json'))
gen_input_pb(os.path.join(src, 'input_0.pb'), iname, oname, os.path.join(dst, 'request.pb'))
gen_output_pb(os.path.join(src, 'output_0.pb'), oname, os.path.join(dst, 'response.pb'))
if __name__ == '__main__':
model_zoo = os.path.realpath(sys.argv[1])
test_data = os.path.realpath(sys.argv[2])
os.makedirs(test_data, exist_ok=True)
gen_req_resp(model_zoo, test_data)

@@ -0,0 +1,101 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import unittest
import random
import os
import test_util
import sys
class ModelZooTests(unittest.TestCase):
server_ip = '127.0.0.1'
server_port = 54321
url_pattern = 'http://{0}:{1}/v1/models/{2}/versions/{3}:predict'
server_app_path = '' # Required
log_level = 'verbose'
server_ready_in_seconds = 10
server_off_in_seconds = 100
need_data_preparation = False
need_data_cleanup = False
model_zoo_model_path = '' # Required
model_zoo_test_data_path = '' # Required
supported_opsets = ['opset_7', 'opset_8', 'opset_9']
skipped_models = []
def test_models_from_model_zoo(self):
json_request_headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}
pb_request_headers = {
'Content-Type': 'application/octet-stream',
'Accept': 'application/octet-stream'
}
model_data_map = {}
for opset in self.supported_opsets:
test_data_folder = os.path.join(self.model_zoo_test_data_path, opset)
model_file_folder = os.path.join(self.model_zoo_model_path, opset)
if os.path.isdir(test_data_folder):
for name in os.listdir(test_data_folder):
if name in self.skipped_models:
continue
if os.path.isdir(os.path.join(test_data_folder, name)):
current_dir = os.path.join(test_data_folder, name)
model_data_map[os.path.join(model_file_folder, name)] = [os.path.join(current_dir, name) for name in os.listdir(current_dir) if os.path.isdir(os.path.join(current_dir, name))]
test_util.test_log('Planned models and test data:')
for model_data, data_paths in model_data_map.items():
test_util.test_log(model_data)
for data in data_paths:
test_util.test_log('\t\t{0}'.format(data))
test_util.test_log('-----------------------')
self.server_port = random.randint(30000, 40000)
for model_path, data_paths in model_data_map.items():
server_app_proc = None
try:
cmd = [self.server_app_path, '--http_port', str(self.server_port), '--model_path', os.path.join(model_path, 'model.onnx'), '--log_level', self.log_level]
test_util.test_log(cmd)
server_app_proc = test_util.launch_server_app(cmd, self.server_ip, self.server_port, self.server_ready_in_seconds)
test_util.test_log('[{0}] Run tests...'.format(model_path))
for test in data_paths:
test_util.test_log('[{0}] Current: {1}'.format(model_path, test))
test_util.test_log('[{0}] JSON payload testing ....'.format(model_path))
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
with open(os.path.join(test, 'request.json')) as f:
request_payload = f.read()
resp = test_util.make_http_request(url, json_request_headers, request_payload)
test_util.json_response_validation(self, resp, os.path.join(test, 'response.json'))
test_util.test_log('[{0}] Protobuf payload testing ....'.format(model_path))
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 54321)
with open(os.path.join(test, 'request.pb'), 'rb') as f:
request_payload = f.read()
resp = test_util.make_http_request(url, pb_request_headers, request_payload)
test_util.pb_response_validation(self, resp, os.path.join(test, 'response.pb'))
finally:
test_util.shutdown_server_app(server_app_proc, self.server_off_in_seconds)
if __name__ == '__main__':
loader = unittest.TestLoader()
test_classes = [ModelZooTests]
test_suites = []
for tests in test_classes:
tests.server_app_path = sys.argv[1]
tests.model_zoo_model_path = sys.argv[2]
tests.model_zoo_test_data_path = sys.argv[3]
test_suites.append(loader.loadTestsFromTestCase(tests))
suites = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
results = runner.run(suites)
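
Note: the suite above is wired together purely through positional arguments. A minimal invocation sketch (script name and all paths below are illustrative placeholders, not part of this change):

import subprocess

# argv[1] = server binary, argv[2] = model zoo root, argv[3] = generated request/response data root
subprocess.check_call([
    'python3', 'model_zoo_tests.py',
    '/build/Release/onnxruntime_server',
    '/data/model_zoo/models',
    '/data/model_zoo/test_data',
])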


@@ -0,0 +1,26 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import sys
import random
import unittest

import function_tests

if __name__ == '__main__':
    loader = unittest.TestLoader()
    test_classes = [function_tests.HttpJsonPayloadTests, function_tests.HttpProtobufPayloadTests, function_tests.HttpEndpointTests]

    test_suites = []
    for tests in test_classes:
        tests.server_app_path = sys.argv[1]
        tests.model_path = sys.argv[2]
        tests.test_data_path = sys.argv[3]
        tests.server_port = random.randint(30000, 50000)
        test_suites.append(loader.loadTestsFromTestCase(tests))

    suites = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    results = runner.run(suites)


@@ -0,0 +1,179 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import base64
import struct
import math
import subprocess
import time
import requests
import json
import datetime
import socket
import errno
import sys
import urllib.request

import predict_pb2
import onnx_ml_pb2
import numpy


def test_log(message):
    print('[Test Log][{0}] {1}'.format(datetime.datetime.now(), message))


def is_process_killed(pid):
    if sys.platform.startswith("win"):
        # Keep in sync with the server binary name used by build.py.
        process_name = 'onnxruntime_server.exe'
        call = 'TASKLIST', '/FI', 'imagename eq {0}'.format(process_name)
        output = subprocess.check_output(call).decode('utf-8')
        print(output)
        last_line = output.strip().split('\r\n')[-1]
        return not last_line.lower().startswith(process_name)
    else:
        try:
            # Signal 0 does not kill; it raises OSError once the process is gone.
            os.kill(pid, 0)
        except OSError:
            return True
        else:
            return False


def prepare_mnist_model(target_path):
    # TODO: This needs to be replaced by test data on the build machine after merging to upstream master.
    if not os.path.isfile(target_path):
        test_log('Downloading model from blob storage: https://ortsrvdev.blob.core.windows.net/test-data/mnist.onnx to {0}'.format(target_path))
        urllib.request.urlretrieve('https://ortsrvdev.blob.core.windows.net/test-data/mnist.onnx', target_path)
    else:
        test_log('Found mnist model at {0}'.format(target_path))


def decode_base64_string(s, count_and_type):
    b = base64.b64decode(s)
    r = struct.unpack(count_and_type, b)

    return r


def compare_floats(a, b, rel_tol=0.0001, abs_tol=0.0001):
    if not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol):
        test_log('Mismatch with relative tolerance {0} and absolute tolerance {1}: {2} vs {3}'.format(rel_tol, abs_tol, a, b))
        return False

    return True


def wait_service_up(server, port, timeout=1):
    s = socket.socket()
    if timeout:
        end = time.time() + timeout

    while True:
        try:
            if timeout:
                next_timeout = end - time.time()
                if next_timeout < 0:
                    return False
                else:
                    s.settimeout(next_timeout)

            s.connect((server, port))
        except socket.timeout:
            if timeout:
                return False
        except Exception:
            # The port is not accepting connections yet; keep retrying until the deadline passes.
            pass
        else:
            s.close()
            return True


def launch_server_app(cmd, server_ip, server_port, wait_server_ready_in_seconds):
    test_log('Launching server app: [{0}]'.format(' '.join(cmd)))
    server_app_proc = subprocess.Popen(cmd)
    test_log('Server app PID: {0}'.format(server_app_proc.pid))
    test_log('Wait up to {0} second(s) for server initialization'.format(wait_server_ready_in_seconds))
    wait_service_up(server_ip, server_port, wait_server_ready_in_seconds)

    return server_app_proc


def shutdown_server_app(server_app_proc, wait_for_server_off_in_seconds):
    if server_app_proc is not None:
        test_log('Shutdown server app')
        server_app_proc.kill()
        while not is_process_killed(server_app_proc.pid):
            server_app_proc.wait(timeout=wait_for_server_off_in_seconds)
        test_log('PID {0} has been killed: {1}'.format(server_app_proc.pid, is_process_killed(server_app_proc.pid)))

    # Additional sleep to make sure the resource has been freed.
    time.sleep(1)

    return True


def make_http_request(url, request_headers, payload):
    test_log('POST Request Started')
    resp = requests.post(url, headers=request_headers, data=payload)
    test_log('POST Request Done')

    return resp


def json_response_validation(cls, resp, expected_resp_json_file):
    cls.assertEqual(resp.status_code, 200)
    cls.assertTrue(resp.headers.get('x-ms-request-id'))
    cls.assertEqual(resp.headers.get('Content-Type'), 'application/json')

    with open(expected_resp_json_file) as f:
        expected_result = json.loads(f.read())

    actual_response = json.loads(resp.content.decode('utf-8'))
    cls.assertTrue(actual_response['outputs'])
    for output in expected_result['outputs'].keys():
        cls.assertTrue(actual_response['outputs'][output])
        cls.assertTrue(actual_response['outputs'][output]['dataType'])
        cls.assertEqual(actual_response['outputs'][output]['dataType'], expected_result['outputs'][output]['dataType'])
        cls.assertTrue(actual_response['outputs'][output]['dims'])
        cls.assertEqual(actual_response['outputs'][output]['dims'], expected_result['outputs'][output]['dims'])
        cls.assertTrue(actual_response['outputs'][output]['rawData'])

        count = 1
        for x in actual_response['outputs'][output]['dims']:
            count = count * int(x)

        actual_array = decode_base64_string(actual_response['outputs'][output]['rawData'], '{0}f'.format(count))
        expected_array = decode_base64_string(expected_result['outputs'][output]['rawData'], '{0}f'.format(count))
        cls.assertEqual(len(actual_array), len(expected_array))
        cls.assertEqual(len(actual_array), count)
        for i in range(0, count):
            cls.assertTrue(compare_floats(actual_array[i], expected_array[i], rel_tol=0.001))


def pb_response_validation(cls, resp, expected_resp_pb_file):
    cls.assertEqual(resp.status_code, 200)
    cls.assertTrue(resp.headers.get('x-ms-request-id'))
    cls.assertEqual(resp.headers.get('Content-Type'), 'application/octet-stream')

    actual_result = predict_pb2.PredictResponse()
    actual_result.ParseFromString(resp.content)

    expected_result = predict_pb2.PredictResponse()
    with open(expected_resp_pb_file, 'rb') as f:
        expected_result.ParseFromString(f.read())

    for k in expected_result.outputs.keys():
        cls.assertEqual(actual_result.outputs[k].data_type, expected_result.outputs[k].data_type)

        count = 1
        for i in range(0, len(expected_result.outputs[k].dims)):
            cls.assertEqual(actual_result.outputs[k].dims[i], expected_result.outputs[k].dims[i])
            count = count * int(actual_result.outputs[k].dims[i])

        actual_array = numpy.frombuffer(actual_result.outputs[k].raw_data, dtype=numpy.float32)
        expected_array = numpy.frombuffer(expected_result.outputs[k].raw_data, dtype=numpy.float32)
        cls.assertEqual(len(actual_array), len(expected_array))
        cls.assertEqual(len(actual_array), count)
        for i in range(0, count):
            cls.assertTrue(compare_floats(actual_array[i], expected_array[i], rel_tol=0.001))
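
The JSON validation path above assumes rawData is a base64-encoded buffer of float32 values unpacked with struct. A small self-contained sanity check of those helpers (assuming this module is importable as test_util):

import base64
import struct

import test_util

raw = struct.pack('3f', 1.0, 2.0, 3.0)                    # three float32 values, native byte order
encoded = base64.b64encode(raw).decode('ascii')           # what the JSON 'rawData' field carries
decoded = test_util.decode_base64_string(encoded, '3f')   # -> (1.0, 2.0, 3.0)
assert all(test_util.compare_floats(a, b) for a, b in zip(decoded, (1.0, 2.0, 3.0)))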

Diff not shown because of its large size.


@@ -0,0 +1,109 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <iostream>
#include "gtest/gtest.h"
#include "server/http/core/routes.h"
namespace onnxruntime {
namespace server {
namespace test {
using test_data = std::tuple<http::verb, std::string, std::string, std::string, std::string, http::status>;
void do_something(const std::string& name, const std::string& version,
const std::string& action, HttpContext& context) {
auto noop = name + version + action + context.request.body();
}
void run_route(const std::string& pattern, http::verb method, const std::vector<test_data>& data, bool does_validate_data);
TEST(HttpRouteTests, RegisterTest) {
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";
Routes routes;
EXPECT_TRUE(routes.RegisterController(http::verb::post, predict_regex, do_something));
auto status_regex = R"(/v1/models(?:/([^/:]+))?(?:/versions/(\d+))?(?:\/(metadata))?)";
EXPECT_TRUE(routes.RegisterController(http::verb::get, status_regex, do_something));
}
TEST(HttpRouteTests, PostRouteTest) {
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";
std::vector<test_data> actions{
std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/abc:predict", "abc", "", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/models/versions/45:predict", "models", "45", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/??$$%%@@$^^/versions/45:predict", "??$$%%@@$^^", "45", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/versions/versions/45:predict", "versions", "45", "predict", http::status::ok)};
run_route(predict_regex, http::verb::post, actions, true);
}
TEST(HttpRouteTests, PostRouteInvalidURLTest) {
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";
std::vector<test_data> actions{
std::make_tuple(http::verb::post, "", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models:bar", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/abc/versions", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/abc/versions:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/a:bc/versions:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/abc/versions/2.0:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/models/abc/versions/2:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/versions/2:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/foo/versions/:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/foo/versions:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "v1/models/foo/versions/12:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:foo", "", "", "", http::status::not_found)};
run_route(predict_regex, http::verb::post, actions, false);
}
// These tests exist because we currently support only POST and GET.
// Remove the corresponding methods from the test data if more verbs are supported (e.g. PUT).
TEST(HttpRouteTests, PostRouteInvalidMethodTest) {
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";
std::vector<test_data> actions{
std::make_tuple(http::verb::get, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::method_not_allowed),
std::make_tuple(http::verb::put, "/v1/models", "", "", "", http::status::method_not_allowed),
std::make_tuple(http::verb::delete_, "/v1/models", "", "", "", http::status::method_not_allowed),
std::make_tuple(http::verb::head, "/v1/models", "", "", "", http::status::method_not_allowed)};
run_route(predict_regex, http::verb::post, actions, false);
}
void run_route(const std::string& pattern, http::verb method, const std::vector<test_data>& data, bool does_validate_data) {
Routes routes;
EXPECT_TRUE(routes.RegisterController(method, pattern, do_something));
for (const auto& i : data) {
http::verb test_method;
std::string url_string;
std::string name;
std::string version;
std::string action;
HandlerFn fn;
std::string expected_name;
std::string expected_version;
std::string expected_action;
http::status expected_status;
std::tie(test_method, url_string, expected_name, expected_version, expected_action, expected_status) = i;
EXPECT_EQ(expected_status, routes.ParseUrl(test_method, url_string, name, version, action, fn));
if (does_validate_data) {
EXPECT_EQ(name, expected_name);
EXPECT_EQ(version, expected_version);
EXPECT_EQ(action, expected_action);
}
}
}
} // namespace test
} // namespace server
} // namespace onnxruntime
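
As a quick cross-check of the route pattern these tests register, the same expression behaves identically under Python's re module (standalone sketch, not part of the server code):

import re

predict_regex = re.compile(r'/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict)')

assert predict_regex.fullmatch('/v1/models/mnist/versions/1:predict').groups() == ('mnist', '1', 'predict')
# The version segment is optional, so a versionless URL also routes:
assert predict_regex.fullmatch('/v1/models/mnist:predict').groups() == ('mnist', None, 'predict')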


@@ -0,0 +1,128 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <fstream>
#include <google/protobuf/stubs/status.h>
#include "gtest/gtest.h"
#include "predict.pb.h"
#include "server/http/json_handling.h"
namespace onnxruntime {
namespace server {
namespace test {
namespace protobufutil = google::protobuf::util;
TEST(JsonDeserializationTests, HappyPath) {
std::string input_json = R"({"inputs":{"Input3":{"dims":["1","1","28","28"],"dataType":1,"rawData":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACAPwAAQEAAAAAAAAAAAAAAgEAAAABAAAAAAAAAMEEAAAAAAAAAAAAAYEEAAIA/AAAAAAAAmEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQEEAAAAAAAAAAAAA4EAAAAAAAACAPwAAIEEAAAAAAAAAQAAAAEAAAIBBAAAAAAAAQEAAAEBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA4EAAAABBAAAAAAAAAEEAAAAAAAAAAAAAAEEAAAAAAAAAAAAAmEEAAAAAAAAAAAAAgD8AAKhBAAAAAAAAgEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAAAAAAAAAAAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAMEEAAAAAAAAAAAAAIEEAAEBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABQQQAAAAAAAHBBAAAgQQAA0EEAAAhCAACIQQAAmkIAADVDAAAyQwAADEIAAIBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWQwAAfkMAAHpDAAB7QwAAc0MAAHxDAAB8QwAAf0MAADRCAADAQAAAAAAAAKBAAAAAAAAAEEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAOBAAACQQgAATUMAAH9DAABuQwAAc0MAAH9DAAB+QwAAe0MAAHhDAABJQwAARkMAAGRCAAAAAAAAmEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWkMAAH9DAABxQwAAf0MAAHlDAAB6QwAAe0MAAHpDAAB/QwAAf0MAAHJDAABgQwAAREIAAAAAAABAQQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAABAAABAQAAAAEAAAABAAACAPwAAAAAAAIJCAABkQwAAf0MAAH5DAAB0QwAA7kIAAAhCAAAkQgAA3EIAAHpDAAB/QwAAeEMAAPhCAACgQQAAAAAAAAAAAAAAAAAAAAAAAAAAAACAPwAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBBAAAAAAAAeEIAAM5CAADiQgAA6kIAAAhCAAAAAAAAAAAAAAAAAABIQwAAdEMAAH9DAAB/QwAAAAAAAEBBAAAAAAAAAAAAAAAAAAAAAAAAAEAAAIA/AAAAAAAAAAAAAAAAAAAAAAAAgD8AAABAAAAAAAAAAAAAAABAAACAQAAAAAAAADBBAAAAAAAA4EAAAMBAAAAAAAAAlkIAAHRDAAB/QwAAf0MAAIBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AAAAQAAAQEAAAIBAAACAQAAAAAAAAGBBAAAAAAAAAAAAAAAAAAAQQQAAAAAAAABAAAAAAAAAAAAAAAhCAAB/QwAAf0MAAH1DAAAgQQAAIEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AAAAQAAAQEAAAABAAAAAAAAAAAAAAEBAAAAAQAAAAAAAAFBBAAAwQQAAAAAAAAAAAAAAAAAAwEAAAEBBAADGQgAAf0MAAH5DAAB4QwAAcEEAAEBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AACAPwAAgD8AAAAAAAAAAAAAAAAAAAAAAACAPwAAgD8AAAAAAAAAAAAAoEAAAMBAAAAwQQAAAAAAAAAAAACIQQAAOEMAAHdDAAB/QwAAc0MAAFBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAABAAACAQAAAgEAAAAAAAAAwQQAAAAAAAExCAAC8QgAAqkIAAKBAAACgQAAAyEEAAHZDAAB2QwAAf0MAAFBDAAAAAAAAEEEAAAAAAAAAAAAAAAAAAAAAAACAQAAAgD8AAAAAAAAAAAAAgD8AAOBAAABwQQAAmEEAAMZCAADOQgAANkMAAD1DAABtQwAAfUMAAHxDAAA/QwAAPkMAAGNDAABzQwAAfEMAAFJDAACQQQAA4EAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBAAAAAAAAAAAAAAABCAADaQgAAOUMAAHdDAAB/QwAAckMAAH9DAAB0QwAAf0MAAH9DAAByQwAAe0MAAH9DAABwQwAAf0MAAH9DAABaQwAA+EIAABBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAD+QgAAf0MAAGtDAAB/QwAAf0MAAHdDAABlQwAAVEMAAHJDAAB6QwAAf0MAAH9DAAB4QwAAf0MAAH1DAAB5QwAAf0MAAHNDAAAqQwAAQEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMEEAAAAAAAAQQQAAfUMAAH9DAAB/QwAAaUMAAEpDAACqQgAAAAAAAFRCAABEQwAAbkMAAH9DAABjQwAAbkMAAA5DAADaQgAAQUMAAH9DAABwQwAAf0MAADRDAAAAAAAAAAAAAAAAAAAAAAAAwEAAAAAAAACwQQAAgD8AAHVDAABzQwAAfkMAAH9DAABZQwAAa0MAAGJDAABVQwAAdEMAAHtDAAB/QwAAb0MAAJpCAAAAAAAAAAAAAKBBAAA2QwAAd0MAAG9DAABzQwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBAAAAlQwAAe0MAAH9DAAB1QwAAf0MAAHJDAAB9QwAAekMAAH9DAABFQwAA1kIAAGxCAAAAAAAAkEEAAABAAADAQAAAAAAAAFhCAAB/QwAAHkMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwEEAAAAAAAAAAAAAwEAAAAhCAAAnQwAAQkMAADBDAAA3QwAAJEMAADBCAAAAQAAAIEEAAMBAAADAQAAAAAAAAAAAAACgQAAAAAAAAIA/AAAAAAAAYEEAAABAAAAAAAAAAAAAAAAAAAAAAAAAIEEAAAAAAABgQQAAAAAAAEBBAAAAAAAAoEAAAAAAAACAPwAAAAAAAMBAAAAAAAAA4EAAAAAAAAAAAAAAAAAAAABBAAAAAAAAIEEAAAAAAACgQAAAAAAAAAAAAAAgQQAAAAAAAAAAAAAAAAAAAAAAAAAAAABgQQAAAAAAAIB
AAAAAAAAAAAAAAMhBAAAAAAAAAAAAABBBAAAAAAAAAAAAABBBAAAAAAAAMEEAAAAAAACAPwAAAAAAAAAAAAAAQAAAAAAAAAAAAADgQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="}},"outputFilter":["Plus214_Output_0"]})";
onnxruntime::server::PredictRequest request;
protobufutil::Status status = onnxruntime::server::GetRequestFromJson(input_json, request);
EXPECT_EQ(protobufutil::error::OK, status.error_code());
}
TEST(JsonDeserializationTests, WithUnknownField) {
std::string input_json = R"({"foo": "bar","inputs":{"Input3":{"dims":["1","1","28","28"],"dataType":1,"rawData":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACAPwAAQEAAAAAAAAAAAAAAgEAAAABAAAAAAAAAMEEAAAAAAAAAAAAAYEEAAIA/AAAAAAAAmEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQEEAAAAAAAAAAAAA4EAAAAAAAACAPwAAIEEAAAAAAAAAQAAAAEAAAIBBAAAAAAAAQEAAAEBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA4EAAAABBAAAAAAAAAEEAAAAAAAAAAAAAAEEAAAAAAAAAAAAAmEEAAAAAAAAAAAAAgD8AAKhBAAAAAAAAgEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAAAAAAAAAAAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAMEEAAAAAAAAAAAAAIEEAAEBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABQQQAAAAAAAHBBAAAgQQAA0EEAAAhCAACIQQAAmkIAADVDAAAyQwAADEIAAIBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWQwAAfkMAAHpDAAB7QwAAc0MAAHxDAAB8QwAAf0MAADRCAADAQAAAAAAAAKBAAAAAAAAAEEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAOBAAACQQgAATUMAAH9DAABuQwAAc0MAAH9DAAB+QwAAe0MAAHhDAABJQwAARkMAAGRCAAAAAAAAmEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWkMAAH9DAABxQwAAf0MAAHlDAAB6QwAAe0MAAHpDAAB/QwAAf0MAAHJDAABgQwAAREIAAAAAAABAQQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAABAAABAQAAAAEAAAABAAACAPwAAAAAAAIJCAABkQwAAf0MAAH5DAAB0QwAA7kIAAAhCAAAkQgAA3EIAAHpDAAB/QwAAeEMAAPhCAACgQQAAAAAAAAAAAAAAAAAAAAAAAAAAAACAPwAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBBAAAAAAAAeEIAAM5CAADiQgAA6kIAAAhCAAAAAAAAAAAAAAAAAABIQwAAdEMAAH9DAAB/QwAAAAAAAEBBAAAAAAAAAAAAAAAAAAAAAAAAAEAAAIA/AAAAAAAAAAAAAAAAAAAAAAAAgD8AAABAAAAAAAAAAAAAAABAAACAQAAAAAAAADBBAAAAAAAA4EAAAMBAAAAAAAAAlkIAAHRDAAB/QwAAf0MAAIBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AAAAQAAAQEAAAIBAAACAQAAAAAAAAGBBAAAAAAAAAAAAAAAAAAAQQQAAAAAAAABAAAAAAAAAAAAAAAhCAAB/QwAAf0MAAH1DAAAgQQAAIEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AAAAQAAAQEAAAABAAAAAAAAAAAAAAEBAAAAAQAAAAAAAAFBBAAAwQQAAAAAAAAAAAAAAAAAAwEAAAEBBAADGQgAAf0MAAH5DAAB4QwAAcEEAAEBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AACAPwAAgD8AAAAAAAAAAAAAAAAAAAAAAACAPwAAgD8AAAAAAAAAAAAAoEAAAMBAAAAwQQAAAAAAAAAAAACIQQAAOEMAAHdDAAB/QwAAc0MAAFBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAABAAACAQAAAgEAAAAAAAAAwQQAAAAAAAExCAAC8QgAAqkIAAKBAAACgQAAAyEEAAHZDAAB2QwAAf0MAAFBDAAAAAAAAEEEAAAAAAAAAAAAAAAAAAAAAAACAQAAAgD8AAAAAAAAAAAAAgD8AAOBAAABwQQAAmEEAAMZCAADOQgAANkMAAD1DAABtQwAAfUMAAHxDAAA/QwAAPkMAAGNDAABzQwAAfEMAAFJDAACQQQAA4EAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBAAAAAAAAAAAAAAABCAADaQgAAOUMAAHdDAAB/QwAAckMAAH9DAAB0QwAAf0MAAH9DAAByQwAAe0MAAH9DAABwQwAAf0MAAH9DAABaQwAA+EIAABBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAD+QgAAf0MAAGtDAAB/QwAAf0MAAHdDAABlQwAAVEMAAHJDAAB6QwAAf0MAAH9DAAB4QwAAf0MAAH1DAAB5QwAAf0MAAHNDAAAqQwAAQEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMEEAAAAAAAAQQQAAfUMAAH9DAAB/QwAAaUMAAEpDAACqQgAAAAAAAFRCAABEQwAAbkMAAH9DAABjQwAAbkMAAA5DAADaQgAAQUMAAH9DAABwQwAAf0MAADRDAAAAAAAAAAAAAAAAAAAAAAAAwEAAAAAAAACwQQAAgD8AAHVDAABzQwAAfkMAAH9DAABZQwAAa0MAAGJDAABVQwAAdEMAAHtDAAB/QwAAb0MAAJpCAAAAAAAAAAAAAKBBAAA2QwAAd0MAAG9DAABzQwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBAAAAlQwAAe0MAAH9DAAB1QwAAf0MAAHJDAAB9QwAAekMAAH9DAABFQwAA1kIAAGxCAAAAAAAAkEEAAABAAADAQAAAAAAAAFhCAAB/QwAAHkMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwEEAAAAAAAAAAAAAwEAAAAhCAAAnQwAAQkMAADBDAAA3QwAAJEMAADBCAAAAQAAAIEEAAMBAAADAQAAAAAAAAAAAAACgQAAAAAAAAIA/AAAAAAAAYEEAAABAAAAAAAAAAAAAAAAAAAAAAAAAIEEAAAAAAABgQQAAAAAAAEBBAAAAAAAAoEAAAAAAAACAPwAAAAAAAMBAAAAAAAAA4EAAAAAAAAAAAAAAAAAAAABBAAAAAAAAIEEAAAAAAACgQAAAAAAAAAAAAAAgQQAAAAAAAAAAAAAAAAAAAAAAAAAAAA
BgQQAAAAAAAIBAAAAAAAAAAAAAAMhBAAAAAAAAAAAAABBBAAAAAAAAAAAAABBBAAAAAAAAMEEAAAAAAACAPwAAAAAAAAAAAAAAQAAAAAAAAAAAAADgQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="}},"outputFilter":["Plus214_Output_0"]})";
onnxruntime::server::PredictRequest request;
protobufutil::Status status = onnxruntime::server::GetRequestFromJson(input_json, request);
EXPECT_EQ(protobufutil::error::OK, status.error_code());
}
TEST(JsonDeserializationTests, InvalidData) {
std::string input_json = R"({"inputs":{"Input3":{"dims":["1","1","28","28"],"dataType":1,"rawData":"hello"}},"outputFilter":["Plus214_Output_0"]})";
onnxruntime::server::PredictRequest request;
protobufutil::Status status = onnxruntime::server::GetRequestFromJson(input_json, request);
EXPECT_EQ(protobufutil::error::INVALID_ARGUMENT, status.error_code());
EXPECT_EQ("inputs[0].value.raw_data: invalid value \"hello\" for type TYPE_BYTES", status.error_message());
}
TEST(JsonDeserializationTests, InvalidJson) {
std::string input_json = R"({inputs":{"Input3":{"dims":["1","1","28","28"],"dataType":1,"rawData":"hello"}},"outputFilter":["Plus214_Output_0"]})";
onnxruntime::server::PredictRequest request;
protobufutil::Status status = onnxruntime::server::GetRequestFromJson(input_json, request);
EXPECT_EQ(protobufutil::error::INVALID_ARGUMENT, status.error_code());
std::string errmsg = status.error_message();
EXPECT_EQ("Expected : between key:value pair.\n{inputs\":{\"Input3\":{\"dims\":\n ^", status.error_message());
}
TEST(JsonSerializationTests, HappyPath) {
std::string test_data = "testdata/server/response_0.pb";
std::string expected_json_string = R"({"outputs":{"Plus214_Output_0":{"dims":["1","10"],"dataType":1,"rawData":"4+pzRFWuGsSMdM1F2gEnRFdRZcRZ9NDEURj0xBIzdsJOS0LEA/GzxA=="}}})";
onnxruntime::server::PredictResponse response;
std::string json_string;
std::ifstream ifs(test_data, std::ios_base::in | std::ios_base::binary);
ASSERT_TRUE(ifs) << test_data << " Not Found" << std::endl;
bool succeeded = response.ParseFromIstream(&ifs);
ifs.close();
EXPECT_TRUE(succeeded) << test_data << " is invalid" << std::endl;
protobufutil::Status status = onnxruntime::server::GenerateResponseInJson(response, json_string);
EXPECT_EQ(protobufutil::error::OK, status.error_code());
EXPECT_EQ(expected_json_string, json_string);
}
TEST(StringEscapingTests, SimpleString) {
std::string unescaped = "This is an error message \" \n ";
EXPECT_EQ("This is an error message \\\" \\n ", escape_string(unescaped));
}
TEST(StringEscapingTests, SimpleStringWithControlCharacter) {
std::string unescaped = "This is an \x1f error message";
EXPECT_EQ("This is an \\u001f error message", escape_string(unescaped));
}
TEST(StringEscapingTests, SimpleStringWithNullCharacter) {
std::string unescaped = "This is an error message \x00 end";
EXPECT_EQ("This is an error message ", escape_string(unescaped));
}
TEST(JsonErrorMessageTests, SimpleMessage) {
auto status = http::status::bad_request;
std::string error_message = "Incorrect headers";
std::string expected = "{\"error_code\": 400, \"error_message\": \"Incorrect headers\"}\n";
std::string res = CreateJsonError(status, error_message);
EXPECT_EQ(expected, res);
}
TEST(JsonErrorMessageTests, MessageWithNewLine) {
auto status = http::status::internal_server_error;
std::string error_message = "Contains newline \n here";
std::string expected = "{\"error_code\": 500, \"error_message\": \"Contains newline \\n here\"}\n";
std::string res = CreateJsonError(status, error_message);
EXPECT_EQ(expected, res);
}
TEST(JsonErrorMessageTests, MessageWithRealError) {
auto status = http::status::bad_request;
std::string error_message = "Expected , or ] after array value.\n0, 0.0, 0.0, 0.0 } }, \"outputFilter\n ^";
std::string expected = "{\"error_code\": 400, \"error_message\": \"Expected , or ] after array value.\\n0, 0.0, 0.0, 0.0 } }, \\\"outputFilter\\n ^\"}\n";
std::string res = CreateJsonError(status, error_message);
EXPECT_EQ(expected, res);
}
TEST(JsonErrorMessageTests, MessageWithQuotations) {
auto status = http::status::bad_request;
std::string error_message = R"(Error with "{"bleh": [1,2,3]|")";
std::string expected = "{\"error_code\": 400, \"error_message\": \"Error with \\\"{\\\"bleh\\\": [1,2,3]|\\\"\"}\n";
std::string result_t = CreateJsonError(status, error_message);
EXPECT_EQ(expected, result_t);
}
TEST(JsonErrorMessageTests, MessageWithManyCarriageCharacters) {
auto status = http::status::bad_request;
std::string error_message = "\"ab\r\n\b\f\t\\\x1a\"";
std::string expected = "{\"error_code\": 400, \"error_message\": \"\\\"ab\\r\\n\\b\\f\\t\\\\\\u001a\\\"\"}\n";
std::string result_t = CreateJsonError(status, error_message);
EXPECT_EQ(expected, result_t);
}
} // namespace test
} // namespace server
} // namespace onnxruntime


@@ -0,0 +1,97 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "gmock/gmock.h"
#include "server/server_configuration.h"
namespace onnxruntime {
namespace server {
namespace test {
TEST(ConfigParsingTests, AllArgs) {
char* test_argv[] = {
const_cast<char*>("/path/to/binary"),
const_cast<char*>("--model_path"), const_cast<char*>("testdata/mul_1.pb"),
const_cast<char*>("--address"), const_cast<char*>("4.4.4.4"),
const_cast<char*>("--http_port"), const_cast<char*>("80"),
const_cast<char*>("--num_http_threads"), const_cast<char*>("1"),
const_cast<char*>("--log_level"), const_cast<char*>("info")};
onnxruntime::server::ServerConfiguration config{};
Result res = config.ParseInput(11, test_argv);
EXPECT_EQ(res, Result::ContinueSuccess);
EXPECT_EQ(config.model_path, "testdata/mul_1.pb");
EXPECT_EQ(config.address, "4.4.4.4");
EXPECT_EQ(config.http_port, 80);
EXPECT_EQ(config.num_http_threads, 1);
EXPECT_EQ(config.logging_level, onnxruntime::logging::Severity::kINFO);
}
TEST(ConfigParsingTests, Defaults) {
char* test_argv[] = {
const_cast<char*>("/path/to/binary"),
const_cast<char*>("--model"), const_cast<char*>("testdata/mul_1.pb"),
const_cast<char*>("--num_http_threads"), const_cast<char*>("3")};
onnxruntime::server::ServerConfiguration config{};
Result res = config.ParseInput(5, test_argv);
EXPECT_EQ(res, Result::ContinueSuccess);
EXPECT_EQ(config.model_path, "testdata/mul_1.pb");
EXPECT_EQ(config.address, "0.0.0.0");
EXPECT_EQ(config.http_port, 8001);
EXPECT_EQ(config.num_http_threads, 3);
EXPECT_EQ(config.logging_level, onnxruntime::logging::Severity::kINFO);
}
TEST(ConfigParsingTests, Help) {
char* test_argv[] = {
const_cast<char*>("/path/to/binary"),
const_cast<char*>("--help")};
onnxruntime::server::ServerConfiguration config{};
auto res = config.ParseInput(2, test_argv);
EXPECT_EQ(res, Result::ExitSuccess);
}
TEST(ConfigParsingTests, NoModelArg) {
char* test_argv[] = {
const_cast<char*>("/path/to/binary"),
const_cast<char*>("--num_http_threads"), const_cast<char*>("3")};
onnxruntime::server::ServerConfiguration config{};
Result res = config.ParseInput(3, test_argv);
EXPECT_EQ(res, Result::ExitFailure);
}
TEST(ConfigParsingTests, ModelNotFound) {
char* test_argv[] = {
const_cast<char*>("/path/to/binary"),
const_cast<char*>("--model_path"), const_cast<char*>("does/not/exist"),
const_cast<char*>("--address"), const_cast<char*>("4.4.4.4"),
const_cast<char*>("--http_port"), const_cast<char*>("80"),
const_cast<char*>("--num_http_threads"), const_cast<char*>("1")};
onnxruntime::server::ServerConfiguration config{};
Result res = config.ParseInput(9, test_argv);
EXPECT_EQ(res, Result::ExitFailure);
}
TEST(ConfigParsingTests, WrongLoggingLevel) {
char* test_argv[] = {
const_cast<char*>("/path/to/binary"),
const_cast<char*>("--log_level"), const_cast<char*>("not a logging level"),
const_cast<char*>("--model_path"), const_cast<char*>("testdata/mul_1.pb"),
const_cast<char*>("--address"), const_cast<char*>("4.4.4.4"),
const_cast<char*>("--http_port"), const_cast<char*>("80"),
const_cast<char*>("--num_http_threads"), const_cast<char*>("1")};
onnxruntime::server::ServerConfiguration config{};
Result res = config.ParseInput(11, test_argv);
EXPECT_EQ(res, Result::ExitFailure);
}
} // namespace test
} // namespace server
} // namespace onnxruntime
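
Taken together, these cases pin down the command-line surface. A typical launch command assembled from the same flags (values are illustrative; the Defaults case above implies address 0.0.0.0, http_port 8001 and log_level info when omitted):

cmd = [
    './onnxruntime_server',  # hypothetical path to the built binary
    '--model_path', 'testdata/mul_1.pb',
    '--address', '0.0.0.0',
    '--http_port', '8001',
    '--num_http_threads', '3',
    '--log_level', 'info',
]
print(' '.join(cmd))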


@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/test_environment.h"
GTEST_API_ int main(int argc, char** argv) {
int status = 0;
try {
const bool create_default_logger = true;
onnxruntime::test::TestEnvironment environment{argc, argv, create_default_logger};
status = RUN_ALL_TESTS();
} catch (const std::exception& ex) {
std::cerr << ex.what();
status = -1;
}
return status;
}


@@ -0,0 +1,121 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <google/protobuf/stubs/status.h>
#include "gtest/gtest.h"
#include "server/http/core/context.h"
#include "server/http/util.h"
namespace onnxruntime {
namespace server {
namespace test {
namespace protobufutil = google::protobuf::util;
TEST(RequestContentTypeTests, ContentTypeJson) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::content_type, "application/json");
context.request = request;
auto result = GetRequestContentType(context);
EXPECT_EQ(result, SupportedContentType::Json);
}
TEST(RequestContentTypeTests, ContentTypeRawData) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::content_type, "application/octet-stream");
context.request = request;
auto result = GetRequestContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);
context.request.set(http::field::content_type, "application/vnd.google.protobuf");
result = GetRequestContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);
context.request.set(http::field::content_type, "application/x-protobuf");
result = GetRequestContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);
}
TEST(RequestContentTypeTests, ContentTypeUnknown) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::content_type, "text/plain");
context.request = request;
auto result = GetRequestContentType(context);
EXPECT_EQ(result, SupportedContentType::Unknown);
}
TEST(RequestContentTypeTests, ContentTypeMissing) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
context.request = request;
auto result = GetRequestContentType(context);
EXPECT_EQ(result, SupportedContentType::Unknown);
}
TEST(ResponseContentTypeTests, ContentTypeJson) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::accept, "application/json");
context.request = request;
auto result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::Json);
}
TEST(ResponseContentTypeTests, ContentTypeRawData) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::accept, "application/octet-stream");
context.request = request;
auto result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);
context.request.set(http::field::accept, "application/vnd.google.protobuf");
result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);
context.request.set(http::field::accept, "application/x-protobuf");
result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);
}
TEST(ResponseContentTypeTests, ContentTypeAny) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::accept, "*/*");
context.request = request;
auto result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);
}
TEST(ResponseContentTypeTests, ContentTypeUnknown) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
request.set(http::field::accept, "text/plain");
context.request = request;
auto result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::Unknown);
}
TEST(ContentTypeTests, ContentTypeMissing) {
HttpContext context;
http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
context.request = request;
auto result = GetResponseContentType(context);
EXPECT_EQ(result, SupportedContentType::PbByteArray);
}
} // namespace test
} // namespace server
} // namespace onnxruntime
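
From the client's point of view, the mappings these tests check reduce to two header sets (a sketch of what a caller would send, not part of the server code):

# JSON in / JSON out
json_headers = {'Content-Type': 'application/json', 'Accept': 'application/json'}

# Protobuf in / protobuf out; 'application/vnd.google.protobuf' and
# 'application/x-protobuf' are treated the same as 'application/octet-stream',
# and a missing or '*/*' Accept header falls back to the protobuf response.
pb_headers = {'Content-Type': 'application/octet-stream', 'Accept': 'application/octet-stream'}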


@@ -0,0 +1 @@
{"inputs": {"Input3": {"dims": ["1", "1", "28", "28"], "dataType": 1, "rawData": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACAPwAAQEAAAAAAAAAAAAAAgEAAAABAAAAAAAAAMEEAAAAAAAAAAAAAYEEAAIA/AAAAAAAAmEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQEEAAAAAAAAAAAAA4EAAAAAAAACAPwAAIEEAAAAAAAAAQAAAAEAAAIBBAAAAAAAAQEAAAEBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA4EAAAABBAAAAAAAAAEEAAAAAAAAAAAAAAEEAAAAAAAAAAAAAmEEAAAAAAAAAAAAAgD8AAKhBAAAAAAAAgEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAAAAAAAAAAAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAMEEAAAAAAAAAAAAAIEEAAEBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABQQQAAAAAAAHBBAAAgQQAA0EEAAAhCAACIQQAAmkIAADVDAAAyQwAADEIAAIBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWQwAAfkMAAHpDAAB7QwAAc0MAAHxDAAB8QwAAf0MAADRCAADAQAAAAAAAAKBAAAAAAAAAEEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAOBAAACQQgAATUMAAH9DAABuQwAAc0MAAH9DAAB+QwAAe0MAAHhDAABJQwAARkMAAGRCAAAAAAAAmEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWkMAAH9DAABxQwAAf0MAAHlDAAB6QwAAe0MAAHpDAAB/QwAAf0MAAHJDAABgQwAAREIAAAAAAABAQQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAABAAABAQAAAAEAAAABAAACAPwAAAAAAAIJCAABkQwAAf0MAAH5DAAB0QwAA7kIAAAhCAAAkQgAA3EIAAHpDAAB/QwAAeEMAAPhCAACgQQAAAAAAAAAAAAAAAAAAAAAAAAAAAACAPwAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBBAAAAAAAAeEIAAM5CAADiQgAA6kIAAAhCAAAAAAAAAAAAAAAAAABIQwAAdEMAAH9DAAB/QwAAAAAAAEBBAAAAAAAAAAAAAAAAAAAAAAAAAEAAAIA/AAAAAAAAAAAAAAAAAAAAAAAAgD8AAABAAAAAAAAAAAAAAABAAACAQAAAAAAAADBBAAAAAAAA4EAAAMBAAAAAAAAAlkIAAHRDAAB/QwAAf0MAAIBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AAAAQAAAQEAAAIBAAACAQAAAAAAAAGBBAAAAAAAAAAAAAAAAAAAQQQAAAAAAAABAAAAAAAAAAAAAAAhCAAB/QwAAf0MAAH1DAAAgQQAAIEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AAAAQAAAQEAAAABAAAAAAAAAAAAAAEBAAAAAQAAAAAAAAFBBAAAwQQAAAAAAAAAAAAAAAAAAwEAAAEBBAADGQgAAf0MAAH5DAAB4QwAAcEEAAEBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AACAPwAAgD8AAAAAAAAAAAAAAAAAAAAAAACAPwAAgD8AAAAAAAAAAAAAoEAAAMBAAAAwQQAAAAAAAAAAAACIQQAAOEMAAHdDAAB/QwAAc0MAAFBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAABAAACAQAAAgEAAAAAAAAAwQQAAAAAAAExCAAC8QgAAqkIAAKBAAACgQAAAyEEAAHZDAAB2QwAAf0MAAFBDAAAAAAAAEEEAAAAAAAAAAAAAAAAAAAAAAACAQAAAgD8AAAAAAAAAAAAAgD8AAOBAAABwQQAAmEEAAMZCAADOQgAANkMAAD1DAABtQwAAfUMAAHxDAAA/QwAAPkMAAGNDAABzQwAAfEMAAFJDAACQQQAA4EAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBAAAAAAAAAAAAAAABCAADaQgAAOUMAAHdDAAB/QwAAckMAAH9DAAB0QwAAf0MAAH9DAAByQwAAe0MAAH9DAABwQwAAf0MAAH9DAABaQwAA+EIAABBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAD+QgAAf0MAAGtDAAB/QwAAf0MAAHdDAABlQwAAVEMAAHJDAAB6QwAAf0MAAH9DAAB4QwAAf0MAAH1DAAB5QwAAf0MAAHNDAAAqQwAAQEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMEEAAAAAAAAQQQAAfUMAAH9DAAB/QwAAaUMAAEpDAACqQgAAAAAAAFRCAABEQwAAbkMAAH9DAABjQwAAbkMAAA5DAADaQgAAQUMAAH9DAABwQwAAf0MAADRDAAAAAAAAAAAAAAAAAAAAAAAAwEAAAAAAAACwQQAAgD8AAHVDAABzQwAAfkMAAH9DAABZQwAAa0MAAGJDAABVQwAAdEMAAHtDAAB/QwAAb0MAAJpCAAAAAAAAAAAAAKBBAAA2QwAAd0MAAG9DAABzQwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBAAAAlQwAAe0MAAH9DAAB1QwAAf0MAAHJDAAB9QwAAekMAAH9DAABFQwAA1kIAAGxCAAAAAAAAkEEAAABAAADAQAAAAAAAAFhCAAB/QwAAHkMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwEEAAAAAAAAAAAAAwEAAAAhCAAAnQwAAQkMAADBDAAA3QwAAJEMAADBCAAAAQAAAIEEAAMBAAADAQAAAAAAAAAAAAACgQAAAAAAAAIA/AAAAAAAAYEEAAABAAAAAAAAAAAAAAAAAAAAAAAAAIEEAAAAAAABgQQAAAAAAAEBBAAAAAAAAoEAAAAAAAACAPwAAAAAAAMBAAAAAAAAA4EAAAAAAAAAAAAAAAAAAAABBAAAAAAAAIEEAAAAAAACgQAAAAAAAAAAAAAAgQQAAAAAAAAAAAAAAAAAAAAAAAAAAAABgQQAAAAAAAIBAAAAAAAAAAAAAAMhBA
AAAAAAAAAAAABBBAAAAAAAAAAAAABBBAAAAAAAAMEEAAAAAAACAPwAAAAAAAAAAAAAAQAAAAAAAAAAAAADgQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="}}, "outputFilter": ["Plus214_Output_0"]}

Binary data: onnxruntime/test/testdata/server/mnist_test_data_set_0_input.pb (new file)

Binary file not shown.


@@ -0,0 +1 @@
{"outputs": {"Plus214_Output_0": {"dims": ["1", "10"], "dataType": 1, "rawData": "4+pzRFWuGsSMdM1F2gEnRFdRZcRZ9NDEURj0xBIzdsJOS0LEA/GzxA=="}}}


@@ -0,0 +1,5 @@
D
Plus214_Output_00

J(肚sDUョト荊ヘEレ'DWQeトY<EFBE84>トQ<18>3vツNKBト<03>

Binary data: onnxruntime/test/testdata/server/request_0.pb (new file)

Binary file not shown.

onnxruntime/test/testdata/server/response_0.pb (new file)

@@ -0,0 +1,5 @@
D
Plus214_Output_00

J(肚sDUョト荊ヘEレ'DWQeトY<EFBE84>トQ<18>3vツNKBト<03>


@@ -95,6 +95,10 @@ Use the individual flags to only run the specified stages.
# Build a shared lib
parser.add_argument("--build_shared_lib", action='store_true', help="Build a shared library for the ONNXRuntime.")
# Build ONNX Runtime server
parser.add_argument("--build_server", action='store_true', help="Build server application for the ONNXRuntime.")
parser.add_argument("--enable_server_tests", action='store_true', help="Run server application tests.")
# Build options
parser.add_argument("--cmake_extra_defines", nargs="+",
help="Extra definitions to pass to CMake during build system generation. " +
@@ -324,9 +328,10 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home
"-Donnxruntime_TENSORRT_HOME=" + (tensorrt_home if args.use_tensorrt else ""),
# By default - we currently support only cross compiling for ARM/ARM64 (no native compilation supported through this script)
"-Donnxruntime_CROSS_COMPILING=" + ("ON" if args.arm64 or args.arm else "OFF"),
"-Donnxruntime_BUILD_SERVER=" + ("ON" if args.build_server else "OFF"),
"-Donnxruntime_BUILD_x86=" + ("ON" if args.x86 else "OFF"),
# nGraph and TensorRT providers currently only supports full_protobuf option.
"-Donnxruntime_USE_FULL_PROTOBUF=" + ("ON" if args.use_full_protobuf or args.use_ngraph or args.use_tensorrt else "OFF"),
"-Donnxruntime_USE_FULL_PROTOBUF=" + ("ON" if args.use_full_protobuf or args.use_ngraph or args.use_tensorrt or args.build_server else "OFF"),
"-Donnxruntime_DISABLE_CONTRIB_OPS=" + ("ON" if args.disable_contrib_ops else "OFF"),
"-Donnxruntime_MSVC_STATIC_RUNTIME=" + ("ON" if args.enable_msvc_static_runtime else "OFF"),
]
@@ -535,6 +540,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs, enab
if onnxml_test:
run_subprocess([sys.executable, 'onnxruntime_test_python_keras.py'], cwd=cwd, dll_path=dll_path)
def run_onnx_tests(build_dir, configs, onnx_test_data_dir, provider, enable_parallel_executor_test, num_parallel_models):
for config in configs:
cwd = get_config_build_dir(build_dir, config)
@@ -565,6 +571,18 @@ def run_onnx_tests(build_dir, configs, onnx_test_data_dir, provider, enable_para
run_subprocess([exe,'-x'] + cmd, cwd=cwd)
def run_server_tests(build_dir, configs):
run_subprocess([sys.executable, '-m', 'pip', 'install', '--trusted-host', 'files.pythonhosted.org', 'requests', 'protobuf', 'numpy'])
for config in configs:
config_build_dir = get_config_build_dir(build_dir, config)
if is_windows():
server_app_path = os.path.join(config_build_dir, config, 'onnxruntime_server.exe')
else:
server_app_path = os.path.join(config_build_dir, 'onnxruntime_server')
server_test_folder = os.path.join(config_build_dir, 'server_test')
server_test_data_folder = os.path.join(os.path.join(config_build_dir, 'testdata'), 'server')
run_subprocess([sys.executable, 'test_main.py', server_app_path, server_test_data_folder, server_test_data_folder], cwd=server_test_folder, dll_path=None)
def build_python_wheel(source_dir, build_dir, configs, use_cuda, use_ngraph, use_tensorrt, nightly_build = False):
for config in configs:
cwd = get_config_build_dir(build_dir, config)
@@ -766,6 +784,9 @@ def main():
if args.use_mkldnn:
run_onnx_tests(build_dir, configs, onnx_test_data_dir, 'mkldnn', True, 1)
if args.build_server and args.enable_server_tests:
run_server_tests(build_dir, configs)
if args.build:
if args.build_wheel:
nightly_build = bool(os.getenv('NIGHTLY_BUILD') == '1')


@@ -0,0 +1,19 @@
jobs:
- job: Debug_Build
  pool: Hosted Ubuntu 1604
  steps:
  - template: templates/set-test-data-variables-step.yml
  - script: 'tools/ci_build/github/linux/server_run_dockerbuild.sh -o ubuntu16.04 -d cpu -r $(Build.BinariesDirectory) -k $(acr.key) -x "--config Debug --build_server --use_openmp --use_full_protobuf --enable_server_tests"'
    displayName: 'Debug Build'
  - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
    displayName: 'Component Detection'
  - template: templates/clean-agent-build-directory-step.yml

- job: Release_Build
  pool: Hosted Ubuntu 1604
  steps:
  - template: templates/set-test-data-variables-step.yml
  - script: 'tools/ci_build/github/linux/server_run_dockerbuild.sh -o ubuntu16.04 -d cpu -r $(Build.BinariesDirectory) -k $(acr.key) -x "--config Release --build_server --use_openmp --use_full_protobuf --enable_server_tests"'
    displayName: 'Release Build'
  - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
    displayName: 'Component Detection'
  - template: templates/clean-agent-build-directory-step.yml


@@ -0,0 +1,31 @@
#!/bin/bash
set -e -o -x
id
SCRIPT_DIR="$( dirname "${BASH_SOURCE[0]}" )"
while getopts c:d:x: parameter_Option
do case "${parameter_Option}"
in
d) BUILD_DEVICE=${OPTARG};;
x) BUILD_EXTR_PAR=${OPTARG};;
esac
done
if [ $BUILD_DEVICE = "gpu" ]; then
_CUDNN_VERSION=$(echo $CUDNN_VERSION | cut -d. -f1-2)
python3 $SCRIPT_DIR/../../build.py --build_dir /home/onnxruntimedev \
--config Debug Release \
--skip_submodule_sync --enable_onnx_tests \
--parallel --build_shared_lib \
--use_cuda --use_openmp \
--cuda_home /usr/local/cuda \
--cudnn_home /usr/local/cudnn-$_CUDNN_VERSION/cuda --build_shared_lib $BUILD_EXTR_PAR
/home/onnxruntimedev/Release/onnx_test_runner -e cuda /data/onnx
else
python3 $SCRIPT_DIR/../../build.py --build_dir /home/onnxruntimedev \
--skip_submodule_sync \
--parallel $BUILD_EXTR_PAR
# /home/onnxruntimedev/Release/onnx_test_runner /data/onnx
fi


@@ -0,0 +1,77 @@
#!/bin/bash
set -e -o -x
SCRIPT_DIR="$( dirname "${BASH_SOURCE[0]}" )"
SOURCE_ROOT=$(realpath $SCRIPT_DIR/../../../../)
CUDA_VER=cuda10.0-cudnn7.3
while getopts c:o:d:k:r:p:x: parameter_Option
do case "${parameter_Option}"
in
#ubuntu16.04
o) BUILD_OS=${OPTARG};;
#cpu, gpu
d) BUILD_DEVICE=${OPTARG};;
k) ACR_KEY=${OPTARG};;
r) BUILD_DIR=${OPTARG};;
#python version: 3.6 3.7 (absence means default 3.5)
p) PYTHON_VER=${OPTARG};;
# "--build_wheel --use_openblas"
x) BUILD_EXTR_PAR=${OPTARG};;
# "cuda10.0-cudnn7.3, cuda9.1-cudnn7.1"
c) CUDA_VER=${OPTARG};;
esac
done
EXIT_CODE=1
echo "bo=$BUILD_OS bd=$BUILD_DEVICE bdir=$BUILD_DIR pv=$PYTHON_VER bex=$BUILD_EXTR_PAR"
cd $SCRIPT_DIR/docker
if [ $BUILD_DEVICE = "gpu" ]; then
IMAGE="ubuntu16.04-$CUDA_VER"
DOCKER_FILE=Dockerfile.ubuntu_gpu
if [ $CUDA_VER = "cuda9.1-cudnn7.1" ]; then
DOCKER_FILE=Dockerfile.ubuntu_gpu_cuda9
fi
docker build -t "onnxruntime-$IMAGE" --build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER} -f $DOCKER_FILE .
else
IMAGE="ubuntu16.04"
docker login onnxhostingdev.azurecr.io -u onnxhostingdev -p ${ACR_KEY}
docker pull onnxhostingdev.azurecr.io/onnxruntime-ubuntu16.04:latest
docker tag onnxhostingdev.azurecr.io/onnxruntime-ubuntu16.04:latest onnxruntime-ubuntu16.04:latest
docker images
id
fi
set +e
if [ $BUILD_DEVICE = "cpu" ]; then
docker rm -f "onnxruntime-$BUILD_DEVICE" || true
docker run -h $HOSTNAME \
--rm \
--name "onnxruntime-$BUILD_DEVICE" \
--volume "$SOURCE_ROOT:/onnxruntime_src" \
--volume "$BUILD_DIR:/home/onnxruntimedev" \
--volume "$HOME/.cache/onnxruntime:/home/onnxruntimedev/.cache/onnxruntime" \
"onnxruntime-$IMAGE" \
/bin/bash /onnxruntime_src/tools/ci_build/github/linux/server_run_build.sh \
-d $BUILD_DEVICE -x "$BUILD_EXTR_PAR" &
else
docker rm -f "onnxruntime-$BUILD_DEVICE" || true
nvidia-docker run --rm -h $HOSTNAME \
--rm \
--name "onnxruntime-$BUILD_DEVICE" \
--volume "$SOURCE_ROOT:/onnxruntime_src" \
--volume "$BUILD_DIR:/home/onnxruntimedev" \
--volume "$HOME/.cache/onnxruntime:/home/onnxruntimedev/.cache/onnxruntime" \
"onnxruntime-$IMAGE" \
/bin/bash /onnxruntime_src/tools/ci_build/github/linux/server_run_build.sh \
-d $BUILD_DEVICE -x "$BUILD_EXTR_PAR" &
fi
wait -n
EXIT_CODE=$?
set -e
exit $EXIT_CODE