Add an HTTP server for hosting of ONNX models (#806)
* Simple integration into CMake build system
* Adds vcpkg as a submodule and updates build.py to install hosting dependencies
* Don't create vcpkg executable if already created
* Fixes how CMake finds toolchain file and quick changes to build.py
* Removes setting the CMAKE_TOOLCHAIN_FILE in build.py
* Adds Boost Beast echo server and Boost program_options
* Fixes spacing problem with program_options
* Adds Microsoft headers to all the beast server headers
* Removes CXX 14 from CMake file
* Adds TODO to create configuration class
* Run clang-format on main
* Better exception handling of program_options
* Remove vcpkg submodule via ssh
* Add vcpkg as https
* Adds onnxruntime namespace to all classes
* Fixed places where namespaces were anonymous
* Adds a TODO to use the logger
* Moves all setting namespace shortnames outside of onnxruntime namespace
* Add onnxruntime session options to force app to link with it
* Set CMAKE_TOOLCHAIN_FILE in build.py
* Remove whitespace
* Adds initial ONNX Hosting tests (#5)
* Add initial test which is failing linking with no main
* Adds test_main to get hosting tests working
* Deletes useless add_executable line
* Merge changes from upstream
* Enable CI build in Vienna environment
* Make hosting_run*.sh executable
* Add boost path in unittest
* Add boost to TEST_INC_DIR
* Add component detection task in ci yaml
* Get tests and hosting to compile with re2 (#7)
* Add finding boost packages before using it in unit tests
* Add predict.proto and build
* Ignore unused parameters in generated code
* Removes std::regex in favor of re2 (#8)
* Removes std::regex in favor of re2
* Adds back find_package in unit tests and fixes regexes
* Adds more negative test cases
* Adding more protos
* Fix google protobuf file path in the cmake file
* Ignore unused parameters for pb generated code
* Updates onnx submodule (#10)
* Remove duplicated lib in link
* Follow Google style guide (#11)
* Google style names
* Adds more
* Adds an additional namespace
* Fixes header guards to match filepaths
* Consume protobuf
* Unit Test setup
* Json deserialization simple test cases
* Split hosting app to lib and exe for testability
* Add more cases
* Clean up
* Add more comments
* Update namespace and format the cmake files
* Update cmake/external/onnx to checkout 1ec81bc6d49ccae23cd7801515feaadd13082903
* Separate h and cc in http folder
* Clean up hosting application cmake file
* Enable logging and properly initialize the session
* Update const position for GetSession()
* Take latest onnx and onnx-tensorrt
* Creates configuration header file for program_options (#15)
* Sets up PredictRequest callback (#16)
* Init version, porting from prototype, e2e works
* More executor implementation
* Adds function on application startup (#17)
* Attempts to pass HostingEnvironment as a shared_ptr
* Removes logging and environment from all http classes
* Passes http details to OnStart function
* Using full protobuf for hosting app build
* MLValue2TensorProto
* Revert back changes in inference_session.cc
* Refactor logger access and predict handler
* Create an error handling callback (#19)
* Creates error callback
* Logs error and returns back as JSON
* Catches exceptions in user functions
* Refactor executor and add some test cases
* Fix build warning
* Add onnx as a dependency and in includes to hosting app (#20)
* Converter for specific types and more UTs
* More unit tests
* Update onnx submodule
* Fix string data test
* Clean up code
* Cleanup code
* Refactor logging to use unique id per request and take logging level from user (#21)
* Removes capturing env by reference in main
* Uses uuid for logging ids
* Take logging_level as a program argument
* Pass logging_level to default_logging_manager
* Change name of logger to HostingApp
* Log if request id is null
* Update GetHttpStatusCode signature
* Fix random result issue and camel-case names
* Rollback accidentally changed pybind_state.cc
* Rollback pybind_state.cc
* Generate protobuf status from onnxruntime status
* Fix function name in error message
* Clean up comments
* Support protobuf byte array as input
* Refactor predict handler and add unit tests
* Add one more test
* Update cmake/external/onnx
* Accept more protobuf MIME types
* Update onnx-tensorrt
* Add build instruction and usage doc
* Address PR comments
* Install g++-7 in the Ubuntu 16.04 build image for vcpkg
* Fix onnx-tensorrt version
* Check return value during initialization
* Fix infinite loop when http port is in use (#29)
* Simplify Executor.cc by breaking up Run method (#27)
* Move request id to Executor constructor
* Refactor the logger to respect user verbosity level
* Use Arena allocator instead of device
* Creates initial executor tests
* Merge upstream master (#31)
* Remove all possible shared_ptrs (#30)
* Changes GetLogger to unique_ptr
* Reserve BFloat raw data vector size
* Change HostingEnvironment to being passed by lvalue and rvalue references
* Change routes to getting passed by const references
* Enable full protobuf if building hosting (#32)
* Building hosting application no longer needs use_full_protobuf flag
* Improve hosting application docs
* Move server core into separate folder (#34)
* Turn hosting project off by default (#38)
* Remove vcpkg as a submodule and download/install Boost from source (#39)
* Remove vcpkg
* Use CMake script to download and build Boost as part of the project
* Remove std::move for const references
* Remove error_code.proto
* Change wording of executable help description
* Better GenerateProtobufStatus description
* Remove error_code protobuf from CMake files
* Use all outputs if no filter is given
* Pass MLValue by const reference in MLValueToTensorProto
* Rename variables to argc and argv
* Revert "Use all outputs if no filter is given". This reverts commit 7554190ab8e50ba6947648c2f3e2a3d4d9606ce0.
* Remove all header guards in favor of #pragma once
* Reserve size for output vector and optimize for-loop
* Use static libs by default for Boost
* Improves documentation for GenerateResponseInJson function
* Start Result enum at 0 instead of 1
* Remove g++ from Ubuntu's install.sh
* Update cmake files
* Give explanation for Result enum type
* Remove all program options shortcuts except for -h
* Add comments for predict.proto
* Fix JSON for error codes
* Add notice on hosting application docs that it's in beta
* Change HostingEnvironment back to a shared_ptr
* Handle empty output_filter field
* Fix build break
* Refactor unit tests location and groups
* First end-to-end test
* Add missing log
* Missing req id and client req id in error response
* Add one test case to validate failed resp header
* Add build flag for hosting app end to end tests
* Update pipeline setup to run e2e test for CI build
* Model Zoo data preparation and tests
* Add protobuf tests
* Remove mention of needing g++-7 in BUILD.md
* Make GetAppLogger const
* Make using_raw_data_ match the styling of other fields
* Avoid copy of strings when initializing model
* Escape JSON strings correctly for error messages (#44)
* Escape JSON strings correctly
* Add test examples with lots of carriage returns
* Add result validation
* Remove temporary path
* Optimize model zoo test execution
* Improve reliability of test cases
* Generate _pb2.py during the build time
* README for integration tests
* Pass environment by pointer instead of shared_ptr to executor (#49)
* More integration tests
* Remove generated files
* Make session private and use a getter instead (#53)
* logging_level to log_level for CLI
* Single model prediction shortcut
* Health endpoint
* Integration tests
* Rename to onnxruntime server
* Build ONNX Server application on Windows (#57)
* Gets Boost compiling on Windows
* Fix integer conversion and comparison problems
* Use size_t in converter_tests instead of int
* Fix hosting integration tests on Windows
* Removes checks for port because it's an unsigned short
* Fixes comparison between signed and unsigned data types
* Pip install protobuf and numpy
* Missing test data from the rename change
* Fix server app path (#58)
* Pass shared_ptr by const reference to avoid ref count increase (#59)
* Download test model during test setup
* Make download into test_util
* Rename ci pipeline for onnx runtime server
* Support up to 10MiB http request (#61)
* Changes maximum request size to 10 MiB to support all models in the ONNX Model Zoo
This commit is contained in:

Parent: 6f5c28fd3a
Commit: 1978b3c953
@@ -11,6 +11,7 @@ distribute/*
*.bin
cmake_build
.cmake_build
cmake-build-debug
gen
*~
.vs
BUILD.md
@@ -53,6 +53,10 @@ The complete list of build options can be found by running `./build.sh (or ./bui
1. For Windows, just add the --x86 argument when launching build.bat
2. For Linux, it must be built on an x86 OS, and the --x86 argument also needs to be specified to build.sh

## Build ONNX Runtime Server on Linux

1. In the ONNX Runtime root folder, run `./build.sh --config RelWithDebInfo --build_server --use_openmp --parallel`

## Build/Test Flavors for CI

### CI Build Environments
@@ -70,6 +70,7 @@ option(onnxruntime_USE_BRAINSLICE "Build with BrainSlice" OFF)
option(onnxruntime_USE_TENSORRT "Build with TensorRT support" OFF)
option(onnxruntime_ENABLE_LTO "Enable link time optimization, which is not stable on older GCCs" OFF)
option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF)
option(onnxruntime_BUILD_SERVER "Build ONNX Runtime Server" OFF)
option(onnxruntime_USE_FULL_PROTOBUF "Use full protobuf" OFF)
option(onnxruntime_DISABLE_CONTRIB_OPS "Disable contrib ops" OFF)
option(onnxruntime_USE_EIGEN_THREADPOOL "Use eigen threadpool. Otherwise OpenMP or a homemade one will be used" OFF)

@@ -607,6 +608,10 @@ if (onnxruntime_BUILD_SHARED_LIB)
  include(onnxruntime.cmake)
endif()

if (onnxruntime_BUILD_SERVER)
  include(onnxruntime_server.cmake)
endif()

# some of the tests rely on the shared libs to be
# built; hence the ordering
if (onnxruntime_BUILD_UNIT_TESTS)

@@ -633,3 +638,4 @@ if (onnxruntime_BUILD_CSHARP)
  # set_property(GLOBAL PROPERTY VS_DOTNET_TARGET_FRAMEWORK_VERSION "netstandard2.0")
  include(onnxruntime_csharp.cmake)
endif()
@@ -0,0 +1,100 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

set(BOOST_REQUESTED_VERSION 1.69.0 CACHE STRING "")
# SHA256 digest of the Boost source archive.
set(BOOST_SHA256 8f32d4617390d1c2d16f26a27ab60d97807b35440d45891fa340fc2648b04406 CACHE STRING "")
set(BOOST_USE_STATIC_LIBS true CACHE BOOL "")

set(BOOST_COMPONENTS program_options system thread)

# These components are only needed for Windows
if(WIN32)
  list(APPEND BOOST_COMPONENTS date_time regex)
endif()

# MSVC doesn't set these variables
if(WIN32)
  set(CMAKE_STATIC_LIBRARY_PREFIX lib)
  set(CMAKE_SHARED_LIBRARY_PREFIX lib)
endif()

# Set lib prefixes and suffixes for linking
if(BOOST_USE_STATIC_LIBS)
  set(LIBRARY_PREFIX ${CMAKE_STATIC_LIBRARY_PREFIX})
  set(LIBRARY_SUFFIX ${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
  set(LIBRARY_PREFIX ${CMAKE_SHARED_LIBRARY_PREFIX})
  set(LIBRARY_SUFFIX ${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()

# Create list of components in Boost format
foreach(component ${BOOST_COMPONENTS})
  list(APPEND BOOST_COMPONENTS_FOR_BUILD --with-${component})
endforeach()

set(BOOST_ROOT_DIR ${CMAKE_BINARY_DIR}/boost CACHE PATH "")

# TODO: let user give their own Boost installation
macro(DOWNLOAD_BOOST)
  if(NOT BOOST_REQUESTED_VERSION)
    message(FATAL_ERROR "BOOST_REQUESTED_VERSION is not defined.")
  endif()

  string(REPLACE "." "_" BOOST_REQUESTED_VERSION_UNDERSCORE ${BOOST_REQUESTED_VERSION})

  set(BOOST_MAYBE_STATIC)
  if(BOOST_USE_STATIC_LIBS)
    set(BOOST_MAYBE_STATIC "link=static")
  endif()

  set(VARIANT "release")
  if(CMAKE_BUILD_TYPE MATCHES Debug)
    set(VARIANT "debug")
  endif()

  set(WINDOWS_B2_OPTIONS)
  set(WINDOWS_LIB_NAME_SCHEME)
  if(WIN32)
    set(BOOTSTRAP_FILE_TYPE "bat")
    set(WINDOWS_B2_OPTIONS "toolset=msvc-14.1" "architecture=x86" "address-model=64")
    set(WINDOWS_LIB_NAME_SCHEME "-vc141-mt-gd-x64-1_69")
  else()
    set(BOOTSTRAP_FILE_TYPE "sh")
  endif()

  message(STATUS "Adding Boost components")
  include(ExternalProject)
  ExternalProject_Add(
      Boost
      URL http://dl.bintray.com/boostorg/release/${BOOST_REQUESTED_VERSION}/source/boost_${BOOST_REQUESTED_VERSION_UNDERSCORE}.tar.bz2
      URL_HASH SHA256=${BOOST_SHA256}
      DOWNLOAD_DIR ${BOOST_ROOT_DIR}
      SOURCE_DIR ${BOOST_ROOT_DIR}
      UPDATE_COMMAND ""
      CONFIGURE_COMMAND ./bootstrap.${BOOTSTRAP_FILE_TYPE} --prefix=${BOOST_ROOT_DIR}
      BUILD_COMMAND ./b2 install ${BOOST_MAYBE_STATIC} --prefix=${BOOST_ROOT_DIR} variant=${VARIANT} ${WINDOWS_B2_OPTIONS} ${BOOST_COMPONENTS_FOR_BUILD}
      BUILD_IN_SOURCE true
      INSTALL_COMMAND ""
      INSTALL_DIR ${BOOST_ROOT_DIR}
  )

  # Set include folders
  ExternalProject_Get_Property(Boost INSTALL_DIR)
  set(Boost_INCLUDE_DIR ${INSTALL_DIR}/include)
  if(WIN32)
    set(Boost_INCLUDE_DIR ${INSTALL_DIR}/include/boost-1_69)
  endif()

  # Set libraries to link
  macro(libraries_to_fullpath varname)
    set(${varname})
    foreach(component ${BOOST_COMPONENTS})
      list(APPEND ${varname} ${INSTALL_DIR}/lib/${LIBRARY_PREFIX}boost_${component}${WINDOWS_LIB_NAME_SCHEME}${LIBRARY_SUFFIX})
    endforeach()
  endmacro()

  libraries_to_fullpath(Boost_LIBRARIES)
  mark_as_advanced(Boost_LIBRARIES Boost_INCLUDE_DIR)
endmacro()

DOWNLOAD_BOOST()
@@ -0,0 +1,122 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

set(SERVER_APP_NAME "onnxruntime_server")

# Generate .h and .cc files from protobuf file
add_library(server_proto ${ONNXRUNTIME_ROOT}/server/protobuf/predict.proto)
if(WIN32)
  target_compile_options(server_proto PRIVATE "/wd4125" "/wd4456")
endif()
target_include_directories(server_proto PUBLIC $<TARGET_PROPERTY:protobuf::libprotobuf,INTERFACE_INCLUDE_DIRECTORIES> "${CMAKE_CURRENT_BINARY_DIR}/.." ${CMAKE_CURRENT_BINARY_DIR}/onnx)
target_compile_definitions(server_proto PUBLIC $<TARGET_PROPERTY:protobuf::libprotobuf,INTERFACE_COMPILE_DEFINITIONS>)
onnxruntime_protobuf_generate(APPEND_PATH IMPORT_DIRS ${REPO_ROOT}/cmake/external/protobuf/src ${ONNXRUNTIME_ROOT}/server/protobuf ${ONNXRUNTIME_ROOT}/core/protobuf TARGET server_proto)
add_dependencies(server_proto onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES})
if(NOT WIN32)
  if(HAS_UNUSED_PARAMETER)
    set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/model_metadata.pb.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
    set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/model_status.pb.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
    set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/predict.pb.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
  endif()
endif()

# Setup dependencies
include(get_boost.cmake)
set(re2_src ${REPO_ROOT}/cmake/external/re2)

# Setup source code
set(onnxruntime_server_lib_srcs
  "${ONNXRUNTIME_ROOT}/server/http/json_handling.cc"
  "${ONNXRUNTIME_ROOT}/server/http/predict_request_handler.cc"
  "${ONNXRUNTIME_ROOT}/server/http/util.cc"
  "${ONNXRUNTIME_ROOT}/server/environment.cc"
  "${ONNXRUNTIME_ROOT}/server/executor.cc"
  "${ONNXRUNTIME_ROOT}/server/converter.cc"
  "${ONNXRUNTIME_ROOT}/server/util.cc"
  )
if(NOT WIN32)
  if(HAS_UNUSED_PARAMETER)
    set_source_files_properties(${ONNXRUNTIME_ROOT}/server/http/json_handling.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
    set_source_files_properties(${ONNXRUNTIME_ROOT}/server/http/predict_request_handler.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
    set_source_files_properties(${ONNXRUNTIME_ROOT}/server/executor.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
    set_source_files_properties(${ONNXRUNTIME_ROOT}/server/converter.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
    set_source_files_properties(${ONNXRUNTIME_ROOT}/server/util.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
  endif()
endif()

file(GLOB_RECURSE onnxruntime_server_http_core_lib_srcs
  "${ONNXRUNTIME_ROOT}/server/http/core/*.cc"
  )

file(GLOB_RECURSE onnxruntime_server_srcs
  "${ONNXRUNTIME_ROOT}/server/main.cc"
  )

# HTTP core library
add_library(onnxruntime_server_http_core_lib STATIC
  ${onnxruntime_server_http_core_lib_srcs})
target_include_directories(onnxruntime_server_http_core_lib
  PUBLIC
  ${ONNXRUNTIME_ROOT}/server/http/core
  ${Boost_INCLUDE_DIR}
  ${re2_src}
  )
add_dependencies(onnxruntime_server_http_core_lib Boost)
target_link_libraries(onnxruntime_server_http_core_lib PRIVATE
  ${Boost_LIBRARIES}
  )

# Server library
add_library(onnxruntime_server_lib ${onnxruntime_server_lib_srcs})
onnxruntime_add_include_to_target(onnxruntime_server_lib gsl onnx_proto server_proto)
target_include_directories(onnxruntime_server_lib PRIVATE
  ${ONNXRUNTIME_ROOT}
  ${CMAKE_CURRENT_BINARY_DIR}/onnx
  ${ONNXRUNTIME_ROOT}/server
  ${ONNXRUNTIME_ROOT}/server/http
  PUBLIC
  ${Boost_INCLUDE_DIR}
  ${re2_src}
  )

target_link_libraries(onnxruntime_server_lib PRIVATE
  server_proto
  ${Boost_LIBRARIES}
  onnxruntime_server_http_core_lib
  onnxruntime_session
  onnxruntime_optimizer
  onnxruntime_providers
  onnxruntime_util
  onnxruntime_framework
  onnxruntime_graph
  onnxruntime_common
  onnxruntime_mlas
  ${onnxruntime_EXTERNAL_LIBRARIES}
  )

# For IDE only
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_server_srcs} ${onnxruntime_server_lib_srcs} ${onnxruntime_server_lib})

# Server Application
add_executable(${SERVER_APP_NAME} ${onnxruntime_server_srcs})
add_dependencies(${SERVER_APP_NAME} onnx server_proto onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES})

if(NOT WIN32)
  if(HAS_UNUSED_PARAMETER)
    set_source_files_properties("${ONNXRUNTIME_ROOT}/server/main.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
  endif()
endif()

onnxruntime_add_include_to_target(${SERVER_APP_NAME} onnxruntime_session onnxruntime_server_lib gsl onnx onnx_proto server_proto)

target_include_directories(${SERVER_APP_NAME} PRIVATE
  ${ONNXRUNTIME_ROOT}
  ${ONNXRUNTIME_ROOT}/server/http
  )

target_link_libraries(${SERVER_APP_NAME} PRIVATE
  onnxruntime_server_http_core_lib
  onnxruntime_server_lib
  )
|
@ -163,13 +163,15 @@ set(onnxruntime_test_framework_libs
|
|||
onnxruntime_mlas
|
||||
)
|
||||
|
||||
set(onnxruntime_test_server_libs
|
||||
onnxruntime_test_utils
|
||||
onnxruntime_test_utils_for_server
|
||||
)
|
||||
|
||||
if(WIN32)
|
||||
list(APPEND onnxruntime_test_framework_libs Advapi32)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
set (onnxruntime_test_providers_dependencies ${onnxruntime_EXTERNAL_DEPENDENCIES})
|
||||
|
||||
if(onnxruntime_USE_CUDA)
|
||||
|
@ -557,6 +559,58 @@ if (onnxruntime_BUILD_SHARED_LIB)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
if (onnxruntime_BUILD_SERVER)
|
||||
file(GLOB onnxruntime_test_server_src
|
||||
"${TEST_SRC_DIR}/server/unit_tests/*.cc"
|
||||
"${TEST_SRC_DIR}/server/unit_tests/*.h"
|
||||
)
|
||||
|
||||
file(GLOB onnxruntime_integration_test_server_src
|
||||
"${TEST_SRC_DIR}/server/integration_tests/*.py"
|
||||
)
|
||||
if(NOT WIN32)
|
||||
if(HAS_UNUSED_PARAMETER)
|
||||
set_source_files_properties("${TEST_SRC_DIR}/server/unit_tests/json_handling_tests.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
|
||||
set_source_files_properties("${TEST_SRC_DIR}/server/unit_tests/converter_tests.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
|
||||
set_source_files_properties("${TEST_SRC_DIR}/server/unit_tests/util_tests.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_library(onnxruntime_test_utils_for_server ${onnxruntime_test_server_src})
|
||||
onnxruntime_add_include_to_target(onnxruntime_test_utils_for_server onnxruntime_test_utils gtest gmock gsl onnx onnx_proto server_proto)
|
||||
add_dependencies(onnxruntime_test_utils_for_server onnxruntime_server_lib onnxruntime_server_http_core_lib Boost ${onnxruntime_EXTERNAL_DEPENDENCIES})
|
||||
target_include_directories(onnxruntime_test_utils_for_server PUBLIC ${Boost_INCLUDE_DIR} ${REPO_ROOT}/cmake/external/re2 ${CMAKE_CURRENT_BINARY_DIR}/onnx ${ONNXRUNTIME_ROOT}/server/http ${ONNXRUNTIME_ROOT}/server/http/core PRIVATE ${ONNXRUNTIME_ROOT} )
|
||||
target_link_libraries(onnxruntime_test_utils_for_server ${Boost_LIBRARIES} ${onnx_test_libs})
|
||||
|
||||
AddTest(
|
||||
TARGET onnxruntime_server_tests
|
||||
SOURCES ${onnxruntime_test_server_src}
|
||||
LIBS ${onnxruntime_test_server_libs} server_proto onnxruntime_server_lib ${onnxruntime_test_providers_libs}
|
||||
DEPENDS ${onnxruntime_EXTERNAL_DEPENDENCIES}
|
||||
)
|
||||
|
||||
onnxruntime_protobuf_generate(
|
||||
APPEND_PATH IMPORT_DIRS ${REPO_ROOT}/cmake/external/protobuf/src ${ONNXRUNTIME_ROOT}/server/protobuf ${ONNXRUNTIME_ROOT}/core/protobuf
|
||||
PROTOS ${ONNXRUNTIME_ROOT}/server/protobuf/predict.proto ${ONNXRUNTIME_ROOT}/server/protobuf/onnx-ml.proto
|
||||
LANGUAGE python
|
||||
TARGET onnxruntime_server_tests
|
||||
OUT_VAR server_test_py)
|
||||
|
||||
add_custom_command(
|
||||
TARGET onnxruntime_server_tests POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/server_test
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_integration_test_server_src}
|
||||
${CMAKE_CURRENT_BINARY_DIR}/server_test/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_CURRENT_BINARY_DIR}/onnx_ml_pb2.py
|
||||
${CMAKE_CURRENT_BINARY_DIR}/server_test/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_CURRENT_BINARY_DIR}/predict_pb2.py
|
||||
${CMAKE_CURRENT_BINARY_DIR}/server_test/
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
add_executable(onnxruntime_mlas_test ${TEST_SRC_DIR}/mlas/unittest.cpp)
|
||||
target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc)
|
||||
|
|
|
@@ -0,0 +1,140 @@
<h1><span style="color:red">Note: ONNX Runtime Server is still in beta state. It's currently not ready for production environments.</span></h1>

# How to Use ONNX Runtime Server REST API for Prediction

ONNX Runtime Server provides a REST API for prediction. The goal of the project is to make it easy to "host" any ONNX model as a RESTful service. The CLI command to start the service is shown below:

```
$ ./onnxruntime_server
the option '--model_path' is required but missing
Allowed options:
  -h [ --help ]              Shows a help message and exits
  --log_level arg (=info)    Logging level. Allowed options (case sensitive):
                             verbose, info, warning, error, fatal
  --model_path arg           Path to ONNX model
  --address arg (=0.0.0.0)   The base HTTP address
  --http_port arg (=8001)    HTTP port to listen to requests
  --num_http_threads arg (=<# of your cpu cores>)  Number of http threads
```

Note: The only mandatory argument for the program is `model_path`.

## Start the Server

To host an ONNX model as a REST API server, run:

```
./onnxruntime_server --model_path /<your>/<model>/<path>
```

The prediction URL is in this format:

```
http://<your_ip_address>:<port>/v1/models/<your-model-name>/versions/<your-version>:predict
```

**Note**: Since we currently only support one model, the model name and version can be any non-empty string. In the future, model names and versions will be verified.

## Request and Response Payload

An HTTP request can be a Protobuf message in two formats: binary or JSON. The HTTP request header field `Content-Type` tells the server how to handle the request and is therefore mandatory for all requests. Requests missing `Content-Type` will be rejected as `400 Bad Request`.

* For `"Content-Type: application/json"`, the payload will be deserialized as a JSON string in UTF-8 format
* For `"Content-Type: application/vnd.google.protobuf"`, `"Content-Type: application/x-protobuf"` or `"Content-Type: application/octet-stream"`, the payload will be consumed as a binary protobuf message directly.

The Protobuf definition can be found [here](https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/server/protobuf/predict.proto).
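
For example, a minimal JSON request body (written against the proto3 JSON mapping of `predict.proto`; the input name `Input3`, its shape, and the values are placeholders for whatever your model expects) could look like this:

```
{
  "inputs": {
    "Input3": {
      "dims": ["1", "2"],
      "dataType": 1,
      "floatData": [1.0, 2.0]
    }
  },
  "outputFilter": ["Sample_Output_Name"]
}
```

Here `dataType: 1` is `TensorProto.FLOAT`. The `outputFilter` field is optional; when it is omitted, all model outputs are returned.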

## Inferencing

To send a request to the server, you can use any tool which supports making HTTP requests. Here is an example using `curl`:

```
curl -X POST -d "@predict_request_0.json" -H "Content-Type: application/json" http://127.0.0.1:8001/v1/models/mymodel/versions/3:predict
```

or

```
curl -X POST --data-binary "@predict_request_0.pb" -H "Content-Type: application/octet-stream" -H "Foo: 1234" http://127.0.0.1:8001/v1/models/mymodel/versions/3:predict
```

Clients can control the response type by setting an `Accept` header field on the request, and the server will serialize the response in the desired format. The choices currently available are the same as for the `Content-Type` header field.
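
The same request can also be made programmatically. Below is a minimal Python sketch using the `requests` package together with the `predict_pb2` module that is generated from `predict.proto` during the build; the input name `Input3`, its shape, and the model name/version are placeholders for your own model:

```
import requests
import predict_pb2  # generated from predict.proto at build time

# Build the request: one float tensor of shape [1, 2].
request = predict_pb2.PredictRequest()
input_tensor = request.inputs["Input3"]  # model-specific input name
input_tensor.dims.extend([1, 2])
input_tensor.data_type = 1               # TensorProto.FLOAT
input_tensor.float_data.extend([1.0, 2.0])

# POST the serialized protobuf and ask for a protobuf response back.
response = requests.post(
    "http://127.0.0.1:8001/v1/models/mymodel/versions/3:predict",
    data=request.SerializeToString(),
    headers={"Content-Type": "application/octet-stream",
             "Accept": "application/octet-stream"})

# Parse the response and inspect the output tensors.
prediction = predict_pb2.PredictResponse()
prediction.ParseFromString(response.content)
for name, tensor in prediction.outputs.items():
    print(name, list(tensor.dims), tensor.data_type)
```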

## Advanced Topics

### Number of HTTP Threads

You can change the number of HTTP threads to optimize server utilization. The default is the number of CPU cores on the host machine.
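
For example, to run the server with four HTTP threads:

```
./onnxruntime_server --model_path /<your>/<model>/<path> --num_http_threads 4
```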

### Request ID and Client Request ID

For easy tracking of requests, we provide the following header fields:

* `x-ms-request-id`: will be in the response header, no matter the request result. It will be a GUID/uuid with dashes, e.g. `72b68108-18a4-493c-ac75-d0abd82f0a11`. If the request headers contain this field, the value will be ignored.
* `x-ms-client-request-id`: a field for clients to track their requests. The content will persist in the response headers.

Here is an example of a client sending a request:

#### Client Side

```
$ curl -v -X POST --data-binary "@predict_request_0.pb" -H "Content-Type: application/octet-stream" -H "Foo: 1234" -H "x-ms-client-request-id: my-request-001" -H "Accept: application/json" http://127.0.0.1:8001/v1/models/mymodel/versions/3:predict
Note: Unnecessary use of -X or --request, POST is already inferred.
*   Trying 127.0.0.1...
* Connected to 127.0.0.1 (127.0.0.1) port 8001 (#0)
> POST /v1/models/mymodel/versions/3:predict HTTP/1.1
> Host: 127.0.0.1:8001
> User-Agent: curl/7.47.0
> Content-Type: application/octet-stream
> x-ms-client-request-id: my-request-001
> Accept: application/json
> Content-Length: 3179
> Expect: 100-continue
>
* Done waiting for 100-continue
* We are completely uploaded and fine
< HTTP/1.1 200 OK
< Content-Type: application/json
< x-ms-request-id: 72b68108-18a4-493c-ac75-d0abd82f0a11
< x-ms-client-request-id: my-request-001
< Content-Length: 159
<
* Connection #0 to host 127.0.0.1 left intact
{"outputs":{"Sample_Output_Name":{"dims":["1","10"],"dataType":1,"rawData":"6OpzRFquGsSFdM1FyAEnRFtRZcRa9NDEUBj0xI4ydsJIS0LE//CzxA==","dataLocation":"DEFAULT"}}}%
```
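
The `rawData` field in a JSON response is the base64-encoded raw byte buffer of the output tensor. A quick Python sketch for decoding it (assuming a float32 output, i.e. `dataType: 1`, serialized on a little-endian host):

```
import base64
import struct

raw = base64.b64decode("6OpzRFquGsSFdM1FyAEnRFtRZcRa9NDEUBj0xI4ydsJIS0LE//CzxA==")
floats = struct.unpack("<" + "f" * (len(raw) // 4), raw)  # little-endian float32
print(floats)  # the ten scores of the 1x10 output above
```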

#### Server Side

And here is what the output on the server side looks like with a logging level of verbose:

```
2019-04-04 23:48:26.395200744 [V:onnxruntime:72b68108-18a4-493c-ac75-d0abd82f0a11, predict_request_handler.cc:40 Predict] Name: mymodel Version: 3 Action: predict
2019-04-04 23:48:26.395289437 [V:onnxruntime:72b68108-18a4-493c-ac75-d0abd82f0a11, predict_request_handler.cc:46 Predict] x-ms-client-request-id: [my-request-001]
2019-04-04 23:48:26.395540707 [I:onnxruntime:InferenceSession, inference_session.cc:736 Run] Running with tag: 72b68108-18a4-493c-ac75-d0abd82f0a11
2019-04-04 23:48:26.395596858 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, inference_session.cc:976 CreateLoggerForRun] Created logger for run with id of 72b68108-18a4-493c-ac75-d0abd82f0a11
2019-04-04 23:48:26.395731391 [I:onnxruntime:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:42 Execute] Begin execution
2019-04-04 23:48:26.395763319 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:45 Execute] Size of execution plan vector: 12
2019-04-04 23:48:26.396228981 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Convolution28
2019-04-04 23:48:26.396580161 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Plus30
2019-04-04 23:48:26.396623732 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 10
2019-04-04 23:48:26.396878822 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: ReLU32
2019-04-04 23:48:26.397091882 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Pooling66
2019-04-04 23:48:26.397126243 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 11
2019-04-04 23:48:26.397772701 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Convolution110
2019-04-04 23:48:26.397818174 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 13
2019-04-04 23:48:26.398060592 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Plus112
2019-04-04 23:48:26.398095300 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 14
2019-04-04 23:48:26.398257563 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: ReLU114
2019-04-04 23:48:26.398426740 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Pooling160
2019-04-04 23:48:26.398466031 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 15
2019-04-04 23:48:26.398542823 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Times212_reshape0
2019-04-04 23:48:26.398599687 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Times212_reshape1
2019-04-04 23:48:26.398692631 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Times212
2019-04-04 23:48:26.398731471 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 17
2019-04-04 23:48:26.398832735 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Plus214
2019-04-04 23:48:26.398873229 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 19
2019-04-04 23:48:26.398922929 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:160 Execute] Fetching output.
2019-04-04 23:48:26.398956560 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:163 Execute] Done with execution.
```
@@ -0,0 +1,261 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <onnx/onnx_pb.h>
#include "core/common/logging/logging.h"
#include "core/framework/data_types.h"
#include "core/framework/environment.h"
#include "core/framework/framework_common.h"
#include "core/framework/mem_buffer.h"
#include "core/framework/ml_value.h"
#include "core/framework/tensor.h"
#include "core/framework/tensorprotoutils.h"

#include "onnx-ml.pb.h"
#include "predict.pb.h"

#include "converter.h"

namespace onnxruntime {
namespace server {

namespace protobufutil = google::protobuf::util;

onnx::TensorProto_DataType MLDataTypeToTensorProtoDataType(const onnxruntime::DataTypeImpl* cpp_type) {
  if (cpp_type == onnxruntime::DataTypeImpl::GetType<float>()) {
    return onnx::TensorProto_DataType_FLOAT;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<uint8_t>()) {
    return onnx::TensorProto_DataType_UINT8;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<int8_t>()) {
    return onnx::TensorProto_DataType_INT8;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<uint16_t>()) {
    return onnx::TensorProto_DataType_UINT16;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<int16_t>()) {
    return onnx::TensorProto_DataType_INT16;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<int32_t>()) {
    return onnx::TensorProto_DataType_INT32;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<int64_t>()) {
    return onnx::TensorProto_DataType_INT64;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<std::string>()) {
    return onnx::TensorProto_DataType_STRING;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<bool>()) {
    return onnx::TensorProto_DataType_BOOL;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<onnxruntime::MLFloat16>()) {
    return onnx::TensorProto_DataType_FLOAT16;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<onnxruntime::BFloat16>()) {
    return onnx::TensorProto_DataType_BFLOAT16;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<double>()) {
    return onnx::TensorProto_DataType_DOUBLE;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<uint32_t>()) {
    return onnx::TensorProto_DataType_UINT32;
  } else if (cpp_type == onnxruntime::DataTypeImpl::GetType<uint64_t>()) {
    return onnx::TensorProto_DataType_UINT64;
  } else {
    return onnx::TensorProto_DataType_UNDEFINED;
  }
}

common::Status MLValueToTensorProto(const onnxruntime::MLValue& ml_value, bool using_raw_data,
                                    std::unique_ptr<onnxruntime::logging::Logger> logger,
                                    /* out */ onnx::TensorProto& tensor_proto) {
  // Tensor in MLValue
  const auto& tensor = ml_value.Get<onnxruntime::Tensor>();

  // dims field
  const onnxruntime::TensorShape& tensor_shape = tensor.Shape();
  for (const auto& dim : tensor_shape.GetDims()) {
    tensor_proto.add_dims(dim);
  }

  // data_type field
  onnx::TensorProto_DataType data_type = MLDataTypeToTensorProtoDataType(tensor.DataType());
  tensor_proto.set_data_type(data_type);

  // data_location field: data is stored in raw_data (if set), otherwise in the type-specific field.
  if (using_raw_data && data_type != onnx::TensorProto_DataType_STRING) {
    tensor_proto.set_data_location(onnx::TensorProto_DataLocation_DEFAULT);
  }

  // *_data field
  // According to onnx_ml.proto, depending on the data_type field,
  // exactly one of the *_data fields is used to store the elements of the tensor.
  switch (data_type) {
    case onnx::TensorProto_DataType_FLOAT: {  // Target: raw_data or float_data
      const auto* data = tensor.Data<float>();
      if (using_raw_data) {
        tensor_proto.set_raw_data(data, tensor.Size());
      } else {
        for (size_t i = 0, count = tensor.Shape().Size(); i < count; ++i) {
          tensor_proto.add_float_data(data[i]);
        }
      }
      break;
    }
    case onnx::TensorProto_DataType_INT32: {  // Target: raw_data or int32_data
      const auto* data = tensor.Data<int32_t>();
      if (using_raw_data) {
        tensor_proto.set_raw_data(data, tensor.Size());
      } else {
        for (size_t i = 0, count = tensor.Shape().Size(); i < count; ++i) {
          tensor_proto.add_int32_data(data[i]);
        }
      }
      break;
    }
    case onnx::TensorProto_DataType_UINT8: {  // Target: raw_data or int32_data
      const auto* data = tensor.Data<uint8_t>();
      if (using_raw_data) {
        tensor_proto.set_raw_data(data, tensor.Size());
      } else {
        // Types narrower than 32 bits are packed into int32_data by reinterpreting
        // the raw byte buffer as ceil(byte_count / sizeof(int32_t)) 32-bit words.
        auto i32data = reinterpret_cast<const int32_t*>(data);
        for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
          tensor_proto.add_int32_data(i32data[i]);
        }
      }
      break;
    }
    case onnx::TensorProto_DataType_INT8: {  // Target: raw_data or int32_data
      const auto* data = tensor.Data<int8_t>();
      if (using_raw_data) {
        tensor_proto.set_raw_data(data, tensor.Size());
      } else {
        auto i32data = reinterpret_cast<const int32_t*>(data);
        for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
          tensor_proto.add_int32_data(i32data[i]);
        }
      }
      break;
    }
    case onnx::TensorProto_DataType_UINT16: {  // Target: raw_data or int32_data
      const auto* data = tensor.Data<uint16_t>();
      if (using_raw_data) {
        tensor_proto.set_raw_data(data, tensor.Size());
      } else {
        auto i32data = reinterpret_cast<const int32_t*>(data);
        for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
          tensor_proto.add_int32_data(i32data[i]);
        }
      }
      break;
    }
    case onnx::TensorProto_DataType_INT16: {  // Target: raw_data or int32_data
      const auto* data = tensor.Data<int16_t>();
      if (using_raw_data) {
        tensor_proto.set_raw_data(data, tensor.Size());
      } else {
        auto i32data = reinterpret_cast<const int32_t*>(data);
        for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
          tensor_proto.add_int32_data(i32data[i]);
        }
      }
      break;
    }
    case onnx::TensorProto_DataType_BOOL: {  // Target: raw_data or int32_data
      const auto* data = tensor.Data<bool>();
      if (using_raw_data) {
        tensor_proto.set_raw_data(data, tensor.Size());
      } else {
        auto i32data = reinterpret_cast<const int32_t*>(data);
        for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
          tensor_proto.add_int32_data(i32data[i]);
        }
      }
      break;
    }
    case onnx::TensorProto_DataType_FLOAT16: {  // Target: raw_data or int32_data
      const auto* data = tensor.Data<onnxruntime::MLFloat16>();
      if (using_raw_data) {
        tensor_proto.set_raw_data(data, tensor.Size());
      } else {
        auto i32data = reinterpret_cast<const int32_t*>(data);
        for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
          tensor_proto.add_int32_data(i32data[i]);
        }
      }
      break;
    }
    case onnx::TensorProto_DataType_BFLOAT16: {  // Target: raw_data or int32_data
      const auto* data = tensor.Data<onnxruntime::BFloat16>();
      const auto raw_data_size = tensor.Shape().Size();

      std::vector<uint16_t> raw_data;
      raw_data.reserve(raw_data_size);
      for (int i = 0; i < raw_data_size; ++i) {
        raw_data.push_back(data[i].val);
      }

      if (using_raw_data) {
        tensor_proto.set_raw_data(raw_data.data(), raw_data.size() * sizeof(uint16_t));
      } else {
        auto i32data = reinterpret_cast<const int32_t*>(raw_data.data());
        for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
          tensor_proto.add_int32_data(i32data[i]);
        }
      }
      break;
    }
    case onnx::TensorProto_DataType_STRING: {  // Target: string_data
      // Strings cannot be written into "raw_data".
      const auto* data = tensor.Data<std::string>();
      for (size_t i = 0, count = tensor.Shape().Size(); i < count; ++i) {
        tensor_proto.add_string_data(data[i]);
      }
      break;
    }
    case onnx::TensorProto_DataType_INT64: {  // Target: raw_data or int64_data
      const auto* data = tensor.Data<int64_t>();
      if (using_raw_data) {
        tensor_proto.set_raw_data(data, tensor.Size());
      } else {
        for (size_t x = 0, loop_length = tensor.Shape().Size(); x < loop_length; ++x) {
          tensor_proto.add_int64_data(data[x]);
        }
      }
      break;
    }
    case onnx::TensorProto_DataType_UINT32: {  // Target: raw_data or uint64_data
      const auto* data = tensor.Data<uint32_t>();
      if (using_raw_data) {
        tensor_proto.set_raw_data(data, tensor.Size());
      } else {
        auto u64data = reinterpret_cast<const uint64_t*>(data);
        for (size_t i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(uint64_t)); i < count; ++i) {
          tensor_proto.add_uint64_data(u64data[i]);
        }
      }
      break;
    }
    case onnx::TensorProto_DataType_UINT64: {  // Target: raw_data or uint64_data
      const auto* data = tensor.Data<uint64_t>();
      if (using_raw_data) {
        tensor_proto.set_raw_data(data, tensor.Size());
      } else {
        for (size_t x = 0, loop_length = tensor.Shape().Size(); x < loop_length; ++x) {
          tensor_proto.add_uint64_data(data[x]);
        }
      }
      break;
    }
    case onnx::TensorProto_DataType_DOUBLE: {  // Target: raw_data or double_data
      auto data = tensor.Data<double>();
      if (using_raw_data) {
        tensor_proto.set_raw_data(data, tensor.Size());
      } else {
        for (size_t x = 0, loop_length = tensor.Shape().Size(); x < loop_length; ++x) {
          tensor_proto.add_double_data(data[x]);
        }
      }
      break;
    }
    default: {
      LOGS(*logger, ERROR) << "Unsupported TensorProto DataType: " << data_type;
      return common::Status(common::StatusCategory::ONNXRUNTIME,
                            common::StatusCode::NOT_IMPLEMENTED,
                            "Unsupported TensorProto DataType: " + std::to_string(data_type));
    }
  }

  return common::Status::OK();
}
}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <google/protobuf/stubs/status.h>

#include "core/framework/data_types.h"

#include "environment.h"
#include "predict.pb.h"

namespace onnxruntime {
namespace server {

onnx::TensorProto_DataType MLDataTypeToTensorProtoDataType(const onnxruntime::DataTypeImpl* cpp_type);

// Convert MLValue to TensorProto. Some fields are ignored:
//   * name field: cannot be recovered from an MLValue
//   * doc_string: cannot be recovered from an MLValue
//   * segment field: we do not expect very large tensors in the prediction output
//   * external_data field: we do not expect very large tensors in the prediction output
// Note: if any input tensor used the raw_data field, all output tensor data will be put into raw_data fields.
common::Status MLValueToTensorProto(const onnxruntime::MLValue& ml_value, bool using_raw_data,
                                    std::unique_ptr<onnxruntime::logging::Logger> logger,
                                    /* out */ onnx::TensorProto& tensor_proto);

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <memory>
#include "core/common/logging/logging.h"

#include "environment.h"
#include "log_sink.h"

namespace onnxruntime {
namespace server {

ServerEnvironment::ServerEnvironment(logging::Severity severity) : severity_(severity),
                                                                   logger_id_("ServerApp"),
                                                                   default_logging_manager_(
                                                                       std::unique_ptr<logging::ISink>{new LogSink{}},
                                                                       severity,
                                                                       /* default_filter_user_data */ false,
                                                                       logging::LoggingManager::InstanceType::Default,
                                                                       &logger_id_) {
  auto status = onnxruntime::Environment::Create(runtime_environment_);

  // The session initialization MUST BE AFTER environment creation
  session = std::make_unique<onnxruntime::InferenceSession>(options_, &default_logging_manager_);
}

common::Status ServerEnvironment::InitializeModel(const std::string& model_path) {
  auto status = session->Load(model_path);
  if (!status.IsOK()) {
    return status;
  }

  auto outputs = session->GetModelOutputs();
  if (!outputs.first.IsOK()) {
    return outputs.first;
  }

  for (const auto* output_node : *(outputs.second)) {
    model_output_names_.push_back(output_node->Name());
  }

  return common::Status::OK();
}

const std::vector<std::string>& ServerEnvironment::GetModelOutputNames() const {
  return model_output_names_;
}

const logging::Logger& ServerEnvironment::GetAppLogger() const {
  return default_logging_manager_.DefaultLogger();
}

logging::Severity ServerEnvironment::GetLogSeverity() const {
  return severity_;
}

std::unique_ptr<logging::Logger> ServerEnvironment::GetLogger(const std::string& id) {
  if (id.empty()) {
    LOGS(GetAppLogger(), WARNING) << "Request id is null or empty string";
  }

  return default_logging_manager_.CreateLogger(id, severity_, false);
}

onnxruntime::InferenceSession* ServerEnvironment::GetSession() const {
  return session.get();
}

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <memory>
#include <vector>

#include "core/framework/environment.h"
#include "core/common/logging/logging.h"
#include "core/session/inference_session.h"

namespace onnxruntime {
namespace server {

namespace logging = onnxruntime::logging;

class ServerEnvironment {
 public:
  explicit ServerEnvironment(logging::Severity severity);
  ~ServerEnvironment() = default;
  ServerEnvironment(const ServerEnvironment&) = delete;

  const logging::Logger& GetAppLogger() const;
  std::unique_ptr<logging::Logger> GetLogger(const std::string& id);
  logging::Severity GetLogSeverity() const;

  onnxruntime::InferenceSession* GetSession() const;
  common::Status InitializeModel(const std::string& model_path);
  const std::vector<std::string>& GetModelOutputNames() const;

 private:
  const logging::Severity severity_;
  const std::string logger_id_;
  logging::LoggingManager default_logging_manager_;

  std::unique_ptr<onnxruntime::Environment> runtime_environment_;
  onnxruntime::SessionOptions options_;
  std::unique_ptr<onnxruntime::InferenceSession> session;
  std::vector<std::string> model_output_names_;
};

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,148 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <stdio.h>
#include <onnx/onnx_pb.h>
#include "core/common/logging/logging.h"
#include "core/framework/data_types.h"
#include "core/framework/environment.h"
#include "core/framework/framework_common.h"
#include "core/framework/mem_buffer.h"
#include "core/framework/ml_value.h"
#include "core/framework/tensor.h"
#include "core/framework/tensorprotoutils.h"

#include "onnx-ml.pb.h"
#include "predict.pb.h"

#include "converter.h"
#include "executor.h"
#include "util.h"

namespace onnxruntime {
namespace server {

namespace protobufutil = google::protobuf::util;

protobufutil::Status Executor::SetMLValue(const onnx::TensorProto& input_tensor,
                                          OrtAllocatorInfo* cpu_allocator_info,
                                          /* out */ MLValue& ml_value) {
  auto logger = env_->GetLogger(request_id_);

  size_t cpu_tensor_length = 0;
  auto status = onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(input_tensor, &cpu_tensor_length);
  if (!status.IsOK()) {
    LOGS(*logger, ERROR) << "GetSizeInBytesFromTensorProto() failed. Error Message: " << status.ToString();
    return GenerateProtobufStatus(status, "GetSizeInBytesFromTensorProto() failed: " + status.ToString());
  }

  std::unique_ptr<char[]> data(new char[cpu_tensor_length]);
  memset(data.get(), 0, cpu_tensor_length);

  OrtCallback deleter;
  status = onnxruntime::utils::TensorProtoToMLValue(onnxruntime::Env::Default(), nullptr, input_tensor,
                                                    onnxruntime::MemBuffer(data.release(), cpu_tensor_length, *cpu_allocator_info),
                                                    ml_value, deleter);
  if (!status.IsOK()) {
    LOGS(*logger, ERROR) << "TensorProtoToMLValue() failed. Message: " << status.ToString();
    return GenerateProtobufStatus(status, "TensorProtoToMLValue() failed: " + status.ToString());
  }

  return protobufutil::Status::OK;
}

protobufutil::Status Executor::SetNameMLValueMap(onnxruntime::NameMLValMap& name_value_map, const onnxruntime::server::PredictRequest& request) {
  auto logger = env_->GetLogger(request_id_);

  OrtAllocatorInfo* cpu_allocator_info = nullptr;
  auto ort_status = OrtCreateAllocatorInfo("Cpu", OrtArenaAllocator, 0, OrtMemTypeDefault, &cpu_allocator_info);
  if (ort_status != nullptr || cpu_allocator_info == nullptr) {
    LOGS(*logger, ERROR) << "OrtCreateAllocatorInfo failed";
    return protobufutil::Status(protobufutil::error::Code::RESOURCE_EXHAUSTED, "OrtCreateAllocatorInfo() failed");
  }

  // Prepare the MLValue object
  for (const auto& input : request.inputs()) {
    using_raw_data_ = using_raw_data_ && input.second.has_raw_data();

    MLValue ml_value;
    auto status = SetMLValue(input.second, cpu_allocator_info, ml_value);
    if (status != protobufutil::Status::OK) {
      LOGS(*logger, ERROR) << "SetMLValue() failed! Input name: " << input.first;
      return status;
    }

    auto insertion_result = name_value_map.insert(std::make_pair(input.first, ml_value));
    if (!insertion_result.second) {
      LOGS(*logger, ERROR) << "SetNameMLValueMap() failed! Input name: " << input.first << " Trying to overwrite existing input value";
      return protobufutil::Status(protobufutil::error::Code::INVALID_ARGUMENT, "SetNameMLValueMap() failed: Cannot have two inputs with the same name");
    }
  }

  return protobufutil::Status::OK;
}

protobufutil::Status Executor::Predict(const std::string& model_name,
                                       const std::string& model_version,
                                       onnxruntime::server::PredictRequest& request,
                                       /* out */ onnxruntime::server::PredictResponse& response) {
  auto logger = env_->GetLogger(request_id_);

  // Convert PredictRequest to NameMLValMap
  onnxruntime::NameMLValMap name_ml_value_map{};
  auto conversion_status = SetNameMLValueMap(name_ml_value_map, request);
  if (conversion_status != protobufutil::Status::OK) {
    return conversion_status;
  }

  // Prepare the output names and vector
  std::vector<std::string> output_names;

  if (!request.output_filter().empty()) {
    output_names.reserve(request.output_filter_size());
    for (const auto& name : request.output_filter()) {
      output_names.push_back(name);
    }
  } else {
    output_names = env_->GetModelOutputNames();
  }

  std::vector<onnxruntime::MLValue> outputs(output_names.size());

  // Run
  OrtRunOptions run_options{};
  run_options.run_log_verbosity_level = static_cast<unsigned int>(env_->GetLogSeverity());
  run_options.run_tag = request_id_;

  auto status = env_->GetSession()->Run(run_options, name_ml_value_map, output_names, &outputs);

  if (!status.IsOK()) {
    LOGS(*logger, ERROR) << "Run() failed. Error Message: " << status.ToString();
    return GenerateProtobufStatus(status, "Run() failed: " + status.ToString());
  }

  // Build the response
  for (size_t i = 0, sz = outputs.size(); i < sz; ++i) {
    onnx::TensorProto output_tensor{};
    status = MLValueToTensorProto(outputs[i], using_raw_data_, std::move(logger), output_tensor);
    // MLValueToTensorProto takes ownership of the logger, so re-acquire one for this request.
    logger = env_->GetLogger(request_id_);

    if (!status.IsOK()) {
      LOGS(*logger, ERROR) << "MLValueToTensorProto() failed. Output name: " << output_names[i] << ". Error Message: " << status.ToString();
      return GenerateProtobufStatus(status, "MLValueToTensorProto() failed: " + status.ToString());
    }

    auto insertion_result = response.mutable_outputs()->insert({output_names[i], output_tensor});

    if (!insertion_result.second) {
      LOGS(*logger, ERROR) << "SetNameMLValueMap() failed. Output name: " << output_names[i] << " Trying to overwrite existing output value";
      return protobufutil::Status(protobufutil::error::Code::INVALID_ARGUMENT, "SetNameMLValueMap() failed: Cannot have two outputs with the same name");
    }
  }

  return protobufutil::Status::OK;
}

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <google/protobuf/stubs/status.h>

#include "environment.h"
#include "predict.pb.h"

namespace onnxruntime {
namespace server {

class Executor {
 public:
  Executor(ServerEnvironment* server_env, std::string request_id) : env_(server_env),
                                                                    request_id_(std::move(request_id)),
                                                                    using_raw_data_(true) {}

  // Prediction method
  google::protobuf::util::Status Predict(const std::string& model_name,
                                         const std::string& model_version,
                                         onnxruntime::server::PredictRequest& request,
                                         /* out */ onnxruntime::server::PredictResponse& response);

 private:
  ServerEnvironment* env_;
  const std::string request_id_;
  bool using_raw_data_;

  google::protobuf::util::Status SetMLValue(const onnx::TensorProto& input_tensor,
                                            OrtAllocatorInfo* cpu_allocator_info,
                                            /* out */ MLValue& ml_value);

  google::protobuf::util::Status SetNameMLValueMap(onnxruntime::NameMLValMap& name_value_map, const onnxruntime::server::PredictRequest& request);
};

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

// boost random is using a deprecated header in 1.69
// See: https://github.com/boostorg/random/issues/49
#define BOOST_PENDING_INTEGER_LOG2_HPP
#include <boost/integer/integer_log2.hpp>

#include <string>

#include <boost/beast/http.hpp>
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <boost/uuid/uuid_generators.hpp>

namespace onnxruntime {
namespace server {

namespace http = boost::beast::http;  // from <boost/beast/http.hpp>

// This class represents the HTTP context given to the user.
// Currently it exposes the raw Boost request and response objects;
// in the future we should write a wrapper around them.
class HttpContext {
 public:
  http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
  http::response<http::string_body> response{};

  const std::string request_id;
  std::string client_request_id;
  http::status error_code;
  std::string error_message;

  HttpContext() : request_id(boost::uuids::to_string(boost::uuids::random_generator()())),
                  client_request_id(""),
                  error_code(http::status::internal_server_error),
                  error_message("An unknown server error has occurred") {}

  ~HttpContext() = default;
  HttpContext(const HttpContext&) = delete;
};

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,88 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <functional>
#include <iostream>
#include <memory>
#include <thread>
#include <vector>
#include <boost/asio.hpp>
#include <boost/beast/http.hpp>

#include "context.h"
#include "session.h"
#include "listener.h"

#include "http_server.h"

namespace http = boost::beast::http;  // from <boost/beast/http.hpp>
namespace net = boost::asio;          // from <boost/asio.hpp>
using tcp = boost::asio::ip::tcp;     // from <boost/asio/ip/tcp.hpp>

namespace onnxruntime {
namespace server {

App::App() {
  http_details.address = boost::asio::ip::make_address_v4("0.0.0.0");
  http_details.port = 8001;
  http_details.threads = std::thread::hardware_concurrency();
}

App& App::Bind(net::ip::address address, unsigned short port) {
  http_details.address = std::move(address);
  http_details.port = port;
  return *this;
}

App& App::NumThreads(int threads) {
  http_details.threads = threads;
  return *this;
}

App& App::RegisterStartup(const StartFn& on_start) {
  on_start_ = on_start;
  return *this;
}

App& App::RegisterPost(const std::string& route, const HandlerFn& fn) {
  routes_.RegisterController(http::verb::post, route, fn);
  return *this;
}

App& App::RegisterError(const ErrorFn& fn) {
  routes_.RegisterErrorCallback(fn);
  return *this;
}

App& App::Run() {
  net::io_context ioc{http_details.threads};

  // Create and launch a listening port
  auto listener = std::make_shared<Listener>(routes_, ioc, tcp::endpoint{http_details.address, http_details.port});

  auto initialized = listener->Init();
  if (!initialized) {
    exit(EXIT_FAILURE);
  }

  auto started = listener->Run();
  if (!started) {
    exit(EXIT_FAILURE);
  }

  // Run user on_start function
  on_start_(http_details);

  // Run the I/O service on the requested number of threads
  std::vector<std::thread> v;
  v.reserve(http_details.threads - 1);
  for (auto i = http_details.threads - 1; i > 0; --i) {
    v.emplace_back(
        [&ioc] {
          ioc.run();
        });
  }
  ioc.run();
  return *this;
}
}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include "util.h"
#include "context.h"
#include "routes.h"
#include "session.h"
#include "listener.h"

namespace onnxruntime {
namespace server {

namespace http = beast::http;      // from <boost/beast/http.hpp>
namespace net = boost::asio;       // from <boost/asio.hpp>
using tcp = boost::asio::ip::tcp;  // from <boost/asio/ip/tcp.hpp>

struct Details {
  net::ip::address address;
  unsigned short port;
  int threads;
};

using StartFn = std::function<void(Details&)>;

// Accepts incoming connections and launches the sessions.
// Each method returns the app itself so methods can be chained,
// as shown in the sketch after this file.
class App {
 public:
  App();

  App& Bind(net::ip::address address, unsigned short port);
  App& NumThreads(int threads);
  App& RegisterStartup(const StartFn& fn);
  App& RegisterPost(const std::string& route, const HandlerFn& fn);
  App& RegisterError(const ErrorFn& fn);
  App& Run();

 private:
  Routes routes_{};
  StartFn on_start_ = {};
  Details http_details{};
};
}  // namespace server
}  // namespace onnxruntime
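A minimal sketch of the chaining style this API is built for (values are illustrative; not part of this change). Note that the route pattern needs three capture groups because Routes::ParseUrl extracts the model name, version, and action from it; this is the same pattern main() registers later in this change:

```cpp
// Hedged usage sketch of App.
onnxruntime::server::App app{};
app.Bind(boost::asio::ip::make_address_v4("127.0.0.1"), 8001)
    .NumThreads(4)
    .RegisterPost(
        R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))",
        [](std::string& model, std::string& version, std::string& action,
           onnxruntime::server::HttpContext& context) {
          context.response.body() = "model=" + model + " action=" + action;
        })
    .Run();  // blocks, running the io_context on the requested number of threads
```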
@@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "listener.h"
#include "session.h"
#include "util.h"

namespace onnxruntime {
namespace server {

namespace net = boost::asio;       // from <boost/asio.hpp>
using tcp = boost::asio::ip::tcp;  // from <boost/asio/ip/tcp.hpp>

Listener::Listener(const Routes& routes, net::io_context& ioc, const tcp::endpoint& endpoint)
    : routes_(routes), acceptor_(ioc), socket_(ioc), endpoint_(endpoint) {
}

bool Listener::Init() {
  beast::error_code ec;

  // Open the acceptor
  acceptor_.open(endpoint_.protocol(), ec);
  if (ec) {
    ErrorHandling(ec, "open");
    return false;
  }

  // Allow address reuse
  acceptor_.set_option(net::socket_base::reuse_address(true), ec);
  if (ec) {
    ErrorHandling(ec, "set_option");
    return false;
  }

  // Bind to the server address
  acceptor_.bind(endpoint_, ec);
  if (ec) {
    ErrorHandling(ec, "bind");
    return false;
  }

  // Start listening for connections
  acceptor_.listen(net::socket_base::max_listen_connections, ec);
  if (ec) {
    ErrorHandling(ec, "listen");
    return false;
  }

  return true;
}

bool Listener::Run() {
  if (!acceptor_.is_open()) {
    return false;
  }
  DoAccept();

  return true;
}

void Listener::DoAccept() {
  acceptor_.async_accept(
      socket_,
      std::bind(
          &Listener::OnAccept,
          shared_from_this(),
          std::placeholders::_1));
}

void Listener::OnAccept(beast::error_code ec) {
  if (ec) {
    ErrorHandling(ec, "accept");
  } else {
    std::make_shared<HttpSession>(routes_, std::move(socket_))->Run();
  }

  // Accept another connection
  DoAccept();
}
}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <memory>

#include <boost/asio/ip/tcp.hpp>

#include "routes.h"
#include "util.h"

namespace onnxruntime {
namespace server {

namespace net = boost::asio;       // from <boost/asio.hpp>
using tcp = boost::asio::ip::tcp;  // from <boost/asio/ip/tcp.hpp>

// Listens on a socket and creates an HTTP session
class Listener : public std::enable_shared_from_this<Listener> {
  Routes routes_;
  tcp::acceptor acceptor_;
  tcp::socket socket_;
  const tcp::endpoint endpoint_;

 public:
  Listener(const Routes& routes, net::io_context& ioc, const tcp::endpoint& endpoint);

  // Initialize the HTTP server
  bool Init();

  // Start accepting incoming connections
  bool Run();

  // Asynchronously accepts the socket
  void DoAccept();

  // Creates the HTTP session and runs it
  void OnAccept(beast::error_code ec);
};

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <iostream>
#include "re2/re2.h"

#include "context.h"
#include "routes.h"

namespace onnxruntime {
namespace server {

namespace http = boost::beast::http;  // from <boost/beast/http.hpp>

bool Routes::RegisterController(http::verb method, const std::string& url_pattern, const HandlerFn& controller) {
  if (controller == nullptr) {
    return false;
  }

  switch (method) {
    case http::verb::get:
      this->get_fn_table.emplace_back(url_pattern, controller);
      return true;
    case http::verb::post:
      this->post_fn_table.emplace_back(url_pattern, controller);
      return true;
    default:
      return false;
  }
}

bool Routes::RegisterErrorCallback(const ErrorFn& controller) {
  if (controller == nullptr) {
    return false;
  }

  on_error = controller;
  return true;
}

http::status Routes::ParseUrl(http::verb method,
                              const std::string& url,
                              /* out */ std::string& model_name,
                              /* out */ std::string& model_version,
                              /* out */ std::string& action,
                              /* out */ HandlerFn& func) const {
  std::vector<std::pair<std::string, HandlerFn>> func_table;
  switch (method) {
    case http::verb::get:
      func_table = this->get_fn_table;
      break;
    case http::verb::post:
      func_table = this->post_fn_table;
      break;
    default:
      return http::status::method_not_allowed;
  }

  if (func_table.empty()) {
    return http::status::method_not_allowed;
  }

  bool found_match = false;
  for (const auto& pattern : func_table) {
    if (re2::RE2::FullMatch(url, pattern.first, &model_name, &model_version, &action)) {
      func = pattern.second;

      found_match = true;
      break;
    }
  }

  if (!found_match) {
    return http::status::not_found;
  }

  return http::status::ok;
}

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <functional>
#include <string>
#include <utility>
#include <vector>

#include <boost/beast/http.hpp>

#include "context.h"

namespace onnxruntime {
namespace server {

namespace http = boost::beast::http;  // from <boost/beast/http.hpp>

using HandlerFn = std::function<void(std::string&, std::string&, std::string&, HttpContext&)>;
using ErrorFn = std::function<void(HttpContext&)>;

// This class maintains two regex -> function tables, one for POST requests
// and one for GET requests.
// If the incoming URL could match more than one regex, the first one wins.
class Routes {
 public:
  Routes() = default;
  ErrorFn on_error;
  bool RegisterController(http::verb method, const std::string& url_pattern, const HandlerFn& controller);
  bool RegisterErrorCallback(const ErrorFn& controller);

  http::status ParseUrl(http::verb method,
                        const std::string& url,
                        /* out */ std::string& model_name,
                        /* out */ std::string& model_version,
                        /* out */ std::string& action,
                        /* out */ HandlerFn& func) const;

 private:
  std::vector<std::pair<std::string, HandlerFn>> post_fn_table;
  std::vector<std::pair<std::string, HandlerFn>> get_fn_table;
};

}  // namespace server
}  // namespace onnxruntime
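To make the table-matching behavior concrete, here is a hedged sketch of registering the predict pattern and resolving a URL against it (the model name and URL are illustrative; the pattern mirrors the one main() registers later in this change):

```cpp
// Sketch only: ParseUrl fills the three capture groups from the first
// matching pattern in the table for the given HTTP verb.
onnxruntime::server::Routes routes{};
routes.RegisterController(
    http::verb::post,
    R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))",
    [](std::string&, std::string&, std::string&, onnxruntime::server::HttpContext&) {});

std::string model_name, model_version, action;
onnxruntime::server::HandlerFn func;
auto status = routes.ParseUrl(http::verb::post,
                              "/v1/models/mnist/versions/1:predict",
                              model_name, model_version, action, func);
// status == http::status::ok; model_name == "mnist",
// model_version == "1", action == "predict".
```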
@@ -0,0 +1,153 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "session.h"

namespace onnxruntime {
namespace server {

namespace net = boost::asio;       // from <boost/asio.hpp>
namespace beast = boost::beast;    // from <boost/beast.hpp>
using tcp = boost::asio::ip::tcp;  // from <boost/asio/ip/tcp.hpp>

HttpSession::HttpSession(const Routes& routes, tcp::socket socket)
    : routes_(routes), socket_(std::move(socket)), strand_(socket_.get_executor()) {
}

void HttpSession::DoRead() {
  req_.emplace();

  // TODO: make the max request size configurable.
  req_->body_limit(10 * 1024 * 1024);  // Max request size: 10 MiB

  http::async_read(socket_, buffer_, *req_,
                   net::bind_executor(
                       strand_,
                       std::bind(
                           &HttpSession::OnRead,
                           shared_from_this(),
                           std::placeholders::_1,
                           std::placeholders::_2)));
}

void HttpSession::OnRead(beast::error_code ec, std::size_t bytes_transferred) {
  boost::ignore_unused(bytes_transferred);

  // This means they closed the connection
  if (ec == http::error::end_of_stream) {
    return DoClose();
  }

  if (ec) {
    ErrorHandling(ec, "read");
    return;
  }

  // Send the response
  HandleRequest(req_->release());
}

void HttpSession::OnWrite(beast::error_code ec, std::size_t bytes_transferred, bool close) {
  boost::ignore_unused(bytes_transferred);

  if (ec) {
    ErrorHandling(ec, "write");
    return;
  }

  if (close) {
    // This means we should close the connection, usually because
    // the response indicated the "Connection: close" semantic.
    return DoClose();
  }

  // We're done with the response so delete it
  res_ = nullptr;

  // Read another request
  DoRead();
}

void HttpSession::DoClose() {
  // Send a TCP shutdown
  beast::error_code ec;
  socket_.shutdown(tcp::socket::shutdown_send, ec);

  // At this point the connection is closed gracefully
}

template <class Msg>
void HttpSession::Send(Msg&& msg) {
  using item_type = std::remove_reference_t<decltype(msg)>;

  auto ptr = std::make_shared<item_type>(std::move(msg));
  auto self_ = shared_from_this();
  self_->res_ = ptr;

  http::async_write(self_->socket_, *ptr,
                    net::bind_executor(strand_,
                                       [ self_, close = ptr->need_eof() ](beast::error_code ec, std::size_t bytes) {
                                         self_->OnWrite(ec, bytes, close);
                                       }));
}

template <typename Body, typename Allocator>
void HttpSession::HandleRequest(http::request<Body, http::basic_fields<Allocator>>&& req) {
  HttpContext context{};
  context.request = std::move(req);

  // Special handling for the liveness probe endpoint, used by orchestration systems like Kubernetes.
  if (context.request.method() == http::verb::get && context.request.target().to_string() == "/") {
    context.response.body() = "Healthy";
  } else {
    auto status = ExecuteUserFunction(context);

    if (status != http::status::ok) {
      routes_.on_error(context);
    }
  }

  context.response.keep_alive(context.request.keep_alive());
  context.response.prepare_payload();
  return Send(std::move(context.response));
}

http::status HttpSession::ExecuteUserFunction(HttpContext& context) {
  std::string path = context.request.target().to_string();
  std::string model_name, model_version, action;
  HandlerFn func;

  if (context.request.find("x-ms-client-request-id") != context.request.end()) {
    context.client_request_id = context.request["x-ms-client-request-id"].to_string();
  }

  if (path == "/score") {
    // This is a shortcut since we have only one model instance currently.
    // This code path will be removed once we start supporting multiple models or multiple versions of one model.
    path = "/v1/models/default/versions/1:predict";
  }

  auto status = routes_.ParseUrl(context.request.method(), path, model_name, model_version, action, func);

  if (status != http::status::ok) {
    context.error_code = status;
    context.error_message = std::string(http::obsolete_reason(status)) +
                            ". For HTTP method: " +
                            std::string(http::to_string(context.request.method())) +
                            " and request path: " +
                            context.request.target().to_string();
    return status;
  }

  try {
    func(model_name, model_version, action, context);
  } catch (const std::exception& ex) {
    context.error_message = std::string(ex.what());
    return http::status::internal_server_error;
  }

  return http::status::ok;
}

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,78 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <memory>
#include <boost/optional.hpp>
#include <boost/beast/version.hpp>
#include <boost/asio/bind_executor.hpp>
#include <boost/beast/core/flat_buffer.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/strand.hpp>

#include "context.h"
#include "routes.h"
#include "util.h"

namespace onnxruntime {
namespace server {

namespace net = boost::asio;       // from <boost/asio.hpp>
namespace beast = boost::beast;    // from <boost/beast.hpp>
using tcp = boost::asio::ip::tcp;  // from <boost/asio/ip/tcp.hpp>
namespace http = beast::http;

// An implementation of a single HTTP session
// Used by a listener to hand off the work and async write back to a socket
class HttpSession : public std::enable_shared_from_this<HttpSession> {
 public:
  HttpSession(const Routes& routes, tcp::socket socket);

  // Start the asynchronous operation
  // The entrypoint for the class
  void Run() {
    DoRead();
  }

 private:
  const Routes routes_;
  tcp::socket socket_;
  net::strand<net::io_context::executor_type> strand_;
  beast::flat_buffer buffer_;
  boost::optional<http::request_parser<http::string_body>> req_;
  std::shared_ptr<void> res_{nullptr};

  // Writes the message asynchronously back to the socket.
  // Stores a pointer to the message and to the session itself so that
  // neither is destroyed before the async operation finishes.
  // If you pass shared_from_this(), the lifetime of your object is
  // guaranteed to extend for as long as the handler needs it;
  // most examples in boost::asio are based on this idiom.
  template <class Msg>
  void Send(Msg&& msg);

  // Called after the session is finished reading the message
  // Should set the response before calling Send
  template <typename Body, typename Allocator>
  void HandleRequest(http::request<Body, http::basic_fields<Allocator>>&& req);

  // Handle the request and hand it off to the user's function
  // Execute user function, handle errors
  // HttpContext parameter can be updated here or in HandleRequest
  http::status ExecuteUserFunction(HttpContext& context);

  // Asynchronously reads the request from the socket
  void DoRead();

  // Perform error checking before handing off to HandleRequest
  void OnRead(beast::error_code ec, std::size_t bytes_transferred);

  // After writing, make the session read another request
  void OnWrite(beast::error_code ec, std::size_t bytes_transferred, bool close);

  // Close the connection
  void DoClose();
};

}  // namespace server
}  // namespace onnxruntime
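The lifetime idiom that the Send comment above describes, reduced to its generic form (a hedged sketch using a hypothetical `Widget` class, not code from this change):

```cpp
// Sketch of the asio lifetime idiom: capturing `self` and `msg` by value
// in the completion handler keeps both alive until the async write finishes.
void Widget::SendGreeting() {  // Widget : std::enable_shared_from_this<Widget>
  auto self = shared_from_this();
  auto msg = std::make_shared<std::string>("hello\n");
  boost::asio::async_write(socket_, boost::asio::buffer(*msg),
                           [self, msg](boost::system::error_code ec, std::size_t) {
                             // self and msg are released only when this handler runs.
                           });
}
```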
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <iostream>

#include <boost/beast/core.hpp>

#include "context.h"
#include "util.h"

namespace onnxruntime {
namespace server {

// Report a failure
void ErrorHandling(beast::error_code ec, char const* what) {
  std::cerr << what << " failed: " << ec.value() << " : " << ec.message() << "\n";
}

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <boost/beast/core.hpp>
#include <boost/beast/http/status.hpp>

#include "context.h"

namespace onnxruntime {
namespace server {

namespace beast = boost::beast;  // from <boost/beast.hpp>

// Report a failure
void ErrorHandling(beast::error_code ec, char const* what);

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <sstream>
#include <iomanip>

#include <boost/beast/core.hpp>
#include <google/protobuf/util/json_util.h>

#include "predict.pb.h"
#include "json_handling.h"

namespace protobufutil = google::protobuf::util;

namespace onnxruntime {
namespace server {

protobufutil::Status GetRequestFromJson(const std::string& json_string, /* out */ onnxruntime::server::PredictRequest& request) {
  protobufutil::JsonParseOptions options;
  options.ignore_unknown_fields = true;

  protobufutil::Status result = JsonStringToMessage(json_string, &request, options);
  return result;
}

protobufutil::Status GenerateResponseInJson(const onnxruntime::server::PredictResponse& response, /* out */ std::string& json_string) {
  protobufutil::JsonPrintOptions options;
  options.add_whitespace = false;
  options.always_print_primitive_fields = false;
  options.always_print_enums_as_ints = false;
  options.preserve_proto_field_names = false;

  protobufutil::Status result = MessageToJsonString(response, &json_string, options);
  return result;
}

std::string CreateJsonError(const http::status error_code, const std::string& error_message) {
  auto escaped_message = escape_string(error_message);
  return R"({"error_code": )" + std::to_string(int(error_code)) + R"(, "error_message": ")" + escaped_message + R"("})" + "\n";
}

std::string escape_string(const std::string& message) {
  std::ostringstream o;
  for (char c : message) {
    switch (c) {
      case '"': o << "\\\""; break;
      case '\\': o << "\\\\"; break;
      case '\b': o << "\\b"; break;
      case '\f': o << "\\f"; break;
      case '\n': o << "\\n"; break;
      case '\r': o << "\\r"; break;
      case '\t': o << "\\t"; break;
      default:
        if ('\x00' <= c && c <= '\x1f') {
          o << "\\u"
            << std::hex << std::setw(4) << std::setfill('0') << (int)c;
        } else {
          o << c;
        }
    }
  }
  return o.str();
}

}  // namespace server
}  // namespace onnxruntime
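To illustrate what CreateJsonError produces, a hypothetical input/output pair (derived from the code above, not taken from a test):

```cpp
// Assumed behavior of CreateJsonError, shown as a comment-level example:
//   CreateJsonError(http::status::bad_request, "bad \"payload\"")
// yields the single-line body
//   {"error_code": 400, "error_message": "bad \"payload\""}\n
// where escape_string() is what protects the embedded quotes.
```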
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <google/protobuf/util/json_util.h>
#include <boost/beast/http.hpp>

#include "predict.pb.h"

namespace onnxruntime {
namespace server {

namespace http = boost::beast::http;

// Deserialize JSON input to a PredictRequest.
// Unknown fields in the JSON input will be ignored.
google::protobuf::util::Status GetRequestFromJson(const std::string& json_string, /* out */ onnxruntime::server::PredictRequest& request);

// Serialize a PredictResponse to a JSON string.
// 1. Proto3 primitive fields with default values will be omitted in JSON output, e.g. an int32 field with value 0 will be omitted.
// 2. Enums will be printed as strings, not ints, to improve readability.
google::protobuf::util::Status GenerateResponseInJson(const onnxruntime::server::PredictResponse& response, /* out */ std::string& json_string);

// Constructs a JSON error message from an error code and an error message
std::string CreateJsonError(http::status error_code, const std::string& error_message);

// Escapes a string following the JSON standard
// Mostly taken from here: https://stackoverflow.com/questions/7724448/simple-json-string-escape-for-c/33799784#33799784
std::string escape_string(const std::string& message);

}  // namespace server
}  // namespace onnxruntime
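A hedged round-trip sketch of the two helpers above; the JSON body is hypothetical but follows the PredictRequest schema defined later in this change:

```cpp
// Sketch only. The unknown field "extra" is ignored on parse because
// GetRequestFromJson sets ignore_unknown_fields.
onnxruntime::server::PredictRequest request{};
auto parse_status = onnxruntime::server::GetRequestFromJson(
    R"({"output_filter": ["Plus214_Output_0"], "extra": 1})", request);
// parse_status.ok() == true; request.output_filter(0) == "Plus214_Output_0"

onnxruntime::server::PredictResponse response{};
std::string body;
auto print_status = onnxruntime::server::GenerateResponseInJson(response, body);
// body == "{}" for an empty response, since default-valued fields are omitted.
```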
@@ -0,0 +1,134 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <google/protobuf/stubs/status.h>

#include "environment.h"
#include "http_server.h"
#include "json_handling.h"
#include "executor.h"
#include "util.h"

namespace onnxruntime {
namespace server {

namespace protobufutil = google::protobuf::util;

#define GenerateErrorResponse(logger, error_code, message, context)                       \
  {                                                                                       \
    auto http_error_code = (error_code);                                                  \
    (context).response.insert("x-ms-request-id", ((context).request_id));                 \
    if (!(context).client_request_id.empty()) {                                           \
      (context).response.insert("x-ms-client-request-id", (context).client_request_id);   \
    }                                                                                     \
    auto json_error_message = CreateJsonError(http_error_code, (message));                \
    LOGS((*logger), VERBOSE) << json_error_message;                                       \
    (context).response.result(http_error_code);                                          \
    (context).response.body() = json_error_message;                                      \
    (context).response.set(http::field::content_type, "application/json");               \
  }

static bool ParseRequestPayload(const HttpContext& context, SupportedContentType request_type,
                                /* out */ PredictRequest& predictRequest, /* out */ http::status& error_code, /* out */ std::string& error_message);

void Predict(const std::string& name,
             const std::string& version,
             const std::string& action,
             /* in, out */ HttpContext& context,
             const std::shared_ptr<ServerEnvironment>& env) {
  auto logger = env->GetLogger(context.request_id);
  LOGS(*logger, INFO) << "Model Name: " << name << ", Version: " << version << ", Action: " << action;

  if (!context.client_request_id.empty()) {
    LOGS(*logger, INFO) << "x-ms-client-request-id: [" << context.client_request_id << "]";
  }

  // Request and Response content type information
  SupportedContentType request_type = GetRequestContentType(context);
  SupportedContentType response_type = GetResponseContentType(context);
  if (response_type == SupportedContentType::Unknown) {
    GenerateErrorResponse(logger, http::status::bad_request, "Unknown 'Accept' header field in the request", context);
    return;
  }

  // Deserialize the payload
  PredictRequest predict_request{};
  http::status error_code;
  std::string error_message;
  bool parse_succeeded = ParseRequestPayload(context, request_type, predict_request, error_code, error_message);
  if (!parse_succeeded) {
    GenerateErrorResponse(logger, error_code, error_message, context);
    return;
  }

  // Run Prediction
  protobufutil::Status status;
  Executor executor(env.get(), context.request_id);
  PredictResponse predict_response{};
  status = executor.Predict(name, version, predict_request, predict_response);
  if (!status.ok()) {
    GenerateErrorResponse(logger, GetHttpStatusCode(status), status.error_message(), context);
    return;
  }

  // Serialize to proper output format
  std::string response_body{};
  if (response_type == SupportedContentType::Json) {
    status = GenerateResponseInJson(predict_response, response_body);
    if (!status.ok()) {
      GenerateErrorResponse(logger, http::status::internal_server_error, status.error_message(), context);
      return;
    }
    context.response.set(http::field::content_type, "application/json");
  } else {
    response_body = predict_response.SerializeAsString();
    if (context.request.find("Accept") != context.request.end() && context.request["Accept"] != "*/*") {
      context.response.set(http::field::content_type, context.request["Accept"].to_string());
    } else {
      context.response.set(http::field::content_type, "application/octet-stream");
    }
  }

  // Build HTTP response
  context.response.insert("x-ms-request-id", context.request_id);
  if (!context.client_request_id.empty()) {
    context.response.insert("x-ms-client-request-id", context.client_request_id);
  }
  context.response.body() = response_body;
  context.response.result(http::status::ok);
}

static bool ParseRequestPayload(const HttpContext& context, SupportedContentType request_type, PredictRequest& predictRequest, http::status& error_code, std::string& error_message) {
  auto body = context.request.body();
  protobufutil::Status status;
  switch (request_type) {
    case SupportedContentType::Json: {
      status = GetRequestFromJson(body, predictRequest);
      if (!status.ok()) {
        error_code = GetHttpStatusCode(status);
        error_message = status.error_message();
        return false;
      }
      break;
    }
    case SupportedContentType::PbByteArray: {
      bool parse_succeeded = predictRequest.ParseFromArray(body.data(), static_cast<int>(body.size()));
      if (!parse_succeeded) {
        error_code = http::status::bad_request;
        error_message = "Invalid payload.";
        return false;
      }
      break;
    }
    default: {
      error_code = http::status::bad_request;
      error_message = "Missing or unknown 'Content-Type' header field in the request";
      return false;
    }
  }

  return true;
}

}  // namespace server
}  // namespace onnxruntime
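For illustration, the error path above produces a response of roughly this shape on the wire (hypothetical values; the body format comes from CreateJsonError):

```cpp
// Hypothetical wire-level result of GenerateErrorResponse for a bad payload:
//
//   HTTP/1.1 400 Bad Request
//   x-ms-request-id: <uuid generated per request>
//   Content-Type: application/json
//
//   {"error_code": 400, "error_message": "Invalid payload."}
```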
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "http_server.h"
#include "json_handling.h"

namespace onnxruntime {
namespace server {

namespace beast = boost::beast;
namespace http = beast::http;

void BadRequest(HttpContext& context, const std::string& error_message);

// TODO: decide whether this should be a class
void Predict(const std::string& name,
             const std::string& version,
             const std::string& action,
             /* in, out */ HttpContext& context,
             const std::shared_ptr<ServerEnvironment>& env);

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <unordered_set>
#include <boost/beast/core.hpp>
#include <boost/beast/http/status.hpp>
#include <google/protobuf/stubs/status.h>

#include "context.h"
#include "util.h"

namespace protobufutil = google::protobuf::util;
namespace onnxruntime {
namespace server {

static std::unordered_set<std::string> protobuf_mime_types{
    "application/octet-stream",
    "application/vnd.google.protobuf",
    "application/x-protobuf"};

boost::beast::http::status GetHttpStatusCode(const protobufutil::Status& status) {
  switch (status.error_code()) {
    case protobufutil::error::Code::OK:
      return boost::beast::http::status::ok;

    case protobufutil::error::Code::UNKNOWN:
    case protobufutil::error::Code::DEADLINE_EXCEEDED:
    case protobufutil::error::Code::RESOURCE_EXHAUSTED:
    case protobufutil::error::Code::ABORTED:
    case protobufutil::error::Code::UNIMPLEMENTED:
    case protobufutil::error::Code::INTERNAL:
    case protobufutil::error::Code::UNAVAILABLE:
    case protobufutil::error::Code::DATA_LOSS:
      return boost::beast::http::status::internal_server_error;

    case protobufutil::error::Code::CANCELLED:
    case protobufutil::error::Code::INVALID_ARGUMENT:
    case protobufutil::error::Code::ALREADY_EXISTS:
    case protobufutil::error::Code::FAILED_PRECONDITION:
    case protobufutil::error::Code::OUT_OF_RANGE:
      return boost::beast::http::status::bad_request;

    case protobufutil::error::Code::NOT_FOUND:
      return boost::beast::http::status::not_found;

    case protobufutil::error::Code::PERMISSION_DENIED:
      return boost::beast::http::status::forbidden;

    case protobufutil::error::Code::UNAUTHENTICATED:
      return boost::beast::http::status::unauthorized;

    default:
      return boost::beast::http::status::internal_server_error;
  }
}

SupportedContentType GetRequestContentType(const HttpContext& context) {
  if (context.request.find("Content-Type") != context.request.end()) {
    if (context.request["Content-Type"] == "application/json") {
      return SupportedContentType::Json;
    } else if (protobuf_mime_types.find(context.request["Content-Type"].to_string()) != protobuf_mime_types.end()) {
      return SupportedContentType::PbByteArray;
    }
  }

  return SupportedContentType::Unknown;
}

SupportedContentType GetResponseContentType(const HttpContext& context) {
  if (context.request.find("Accept") != context.request.end()) {
    if (context.request["Accept"] == "application/json") {
      return SupportedContentType::Json;
    } else if (context.request["Accept"] == "*/*" || protobuf_mime_types.find(context.request["Accept"].to_string()) != protobuf_mime_types.end()) {
      return SupportedContentType::PbByteArray;
    }
  } else {
    return SupportedContentType::PbByteArray;
  }

  return SupportedContentType::Unknown;
}

}  // namespace server
}  // namespace onnxruntime
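A couple of spot checks of the status mapping above, written as assertions inside a hypothetical check function (assumed behavior, not an existing test):

```cpp
#include <cassert>
#include <google/protobuf/stubs/status.h>

// Sketch only: NOT_FOUND maps to 404 and INVALID_ARGUMENT to 400,
// per the switch in GetHttpStatusCode.
void CheckStatusMapping() {
  namespace pb = google::protobuf::util;
  assert(onnxruntime::server::GetHttpStatusCode(
             pb::Status(pb::error::Code::NOT_FOUND, "")) ==
         boost::beast::http::status::not_found);
  assert(onnxruntime::server::GetHttpStatusCode(
             pb::Status(pb::error::Code::INVALID_ARGUMENT, "")) ==
         boost::beast::http::status::bad_request);
}
```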
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <boost/beast/core.hpp>
#include <boost/beast/http/status.hpp>
#include <google/protobuf/stubs/status.h>

#include "server/http/core/context.h"

namespace onnxruntime {
namespace server {

namespace beast = boost::beast;  // from <boost/beast.hpp>

enum class SupportedContentType : int {
  Unknown,
  Json,
  PbByteArray
};

// Maps a protobuf status to an HTTP status
boost::beast::http::status GetHttpStatusCode(const google::protobuf::util::Status& status);

// The "Content-Type" header field is required in the request.
// Currently we only support two input content types: application/json and application/octet-stream.
SupportedContentType GetRequestContentType(const HttpContext& context);

// The "Accept" header field in the request is optional.
// Currently we only support three response content types: */*, application/json and application/octet-stream.
SupportedContentType GetResponseContentType(const HttpContext& context);

}  // namespace server
}  // namespace onnxruntime
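As a concrete reading of the rules above, a hedged sketch of what the two helpers return for a hand-built context (header values are illustrative):

```cpp
// Sketch only; exercises the documented negotiation rules.
namespace http = boost::beast::http;

onnxruntime::server::HttpContext context{};

context.request.set(http::field::content_type, "application/json");
// GetRequestContentType(context) -> SupportedContentType::Json

context.request.set(http::field::accept, "*/*");
// GetResponseContentType(context) -> SupportedContentType::PbByteArray

// With no Accept header at all, GetResponseContentType also returns
// PbByteArray; with no Content-Type header, GetRequestContentType
// returns Unknown, which the caller rejects as a 400.
```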
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <iostream>
#include "core/common/logging/logging.h"
#include "core/common/logging/sinks/ostream_sink.h"

namespace onnxruntime {
namespace server {

class LogSink : public onnxruntime::logging::OStreamSink {
 public:
  LogSink() : OStreamSink(std::cout, /*flush*/ true) {
  }
};
}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "environment.h"
#include "http_server.h"
#include "predict_request_handler.h"
#include "server_configuration.h"

namespace beast = boost::beast;
namespace http = beast::http;
namespace server = onnxruntime::server;

int main(int argc, char* argv[]) {
  server::ServerConfiguration config{};
  auto res = config.ParseInput(argc, argv);

  if (res == server::Result::ExitSuccess) {
    exit(EXIT_SUCCESS);
  } else if (res == server::Result::ExitFailure) {
    exit(EXIT_FAILURE);
  }

  const auto env = std::make_shared<server::ServerEnvironment>(config.logging_level);
  auto logger = env->GetAppLogger();
  LOGS(logger, VERBOSE) << "Logging manager initialized.";
  LOGS(logger, INFO) << "Model path: " << config.model_path;

  auto status = env->InitializeModel(config.model_path);
  if (!status.IsOK()) {
    LOGS(logger, FATAL) << "Model initialization failed: " << status.Code() << " ---- Error: [" << status.ErrorMessage() << "]";
    exit(EXIT_FAILURE);
  } else {
    LOGS(logger, VERBOSE) << "Model initialized successfully.";
  }

  status = env->GetSession()->Initialize();
  if (!status.IsOK()) {
    LOGS(logger, FATAL) << "Session initialization failed: " << status.Code() << " ---- Error: [" << status.ErrorMessage() << "]";
    exit(EXIT_FAILURE);
  } else {
    LOGS(logger, VERBOSE) << "Session initialized successfully.";
  }

  auto const boost_address = boost::asio::ip::make_address(config.address);
  server::App app{};

  app.RegisterStartup(
      [&env](const auto& details) -> void {
        auto logger = env->GetAppLogger();
        LOGS(logger, INFO) << "Listening at: "
                           << "http://" << details.address << ":" << details.port;
      });

  app.RegisterError(
      [&env](auto& context) -> void {
        auto logger = env->GetLogger(context.request_id);
        LOGS(*logger, VERBOSE) << "Error code: " << context.error_code;
        LOGS(*logger, VERBOSE) << "Error message: " << context.error_message;

        context.response.result(context.error_code);
        context.response.insert("Content-Type", "application/json");
        context.response.insert("x-ms-request-id", context.request_id);
        if (!context.client_request_id.empty()) {
          context.response.insert("x-ms-client-request-id", context.client_request_id);
        }
        context.response.body() = server::CreateJsonError(context.error_code, context.error_message);
      });

  app.RegisterPost(
      R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))",
      [&env](const auto& name, const auto& version, const auto& action, auto& context) -> void {
        server::Predict(name, version, action, context, env);
      });

  app.Bind(boost_address, config.http_port)
      .NumThreads(config.num_http_threads)
      .Run();

  return EXIT_SUCCESS;
}
@@ -0,0 +1 @@
../../core/protobuf/onnx-ml.proto3
@@ -0,0 +1,27 @@
syntax = "proto3";

import "onnx-ml.proto";

package onnxruntime.server;

// PredictRequest specifies how inputs are mapped to tensors
// and how outputs are filtered before returning to the user.
message PredictRequest {
  reserved 1;

  // Input Tensors.
  // This is a mapping between input name and tensor.
  map<string, onnx.TensorProto> inputs = 2;

  // Output Filters.
  // This field specifies which outputs need to be returned.
  // If the list is empty, all outputs will be included.
  repeated string output_filter = 3;
}

// Response for PredictRequest on a successful run.
message PredictResponse {
  // Output Tensors.
  // This is a mapping between output name and tensor.
  map<string, onnx.TensorProto> outputs = 1;
}
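For reference, a hedged C++ sketch of building a PredictRequest against this schema (the input name "Input3", output name, and tensor shape are illustrative):

```cpp
// Sketch only: constructs a request with one input and one output filter.
onnxruntime::server::PredictRequest request{};

onnx::TensorProto tensor{};
tensor.set_data_type(onnx::TensorProto_DataType_FLOAT);
tensor.add_dims(1);
tensor.add_dims(10);
// raw tensor bytes elided

(*request.mutable_inputs())["Input3"] = tensor;  // key is the model input name
request.add_output_filter("Plus214_Output_0");   // empty filter returns all outputs
```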
@@ -0,0 +1,128 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <thread>
#include <fstream>
#include <unordered_map>

#include "boost/program_options.hpp"
#include "core/common/logging/logging.h"

namespace onnxruntime {
namespace server {

namespace po = boost::program_options;

// Enumerates the different types of result which can occur:
// 0. ExitSuccess: the program should exit with EXIT_SUCCESS
// 1. ExitFailure: the program should exit with EXIT_FAILURE
// 2. ContinueSuccess: no need to exit the program, continue
enum class Result {
  ExitSuccess,
  ExitFailure,
  ContinueSuccess
};

static std::unordered_map<std::string, onnxruntime::logging::Severity> supported_log_levels{
    {"verbose", onnxruntime::logging::Severity::kVERBOSE},
    {"info", onnxruntime::logging::Severity::kINFO},
    {"warning", onnxruntime::logging::Severity::kWARNING},
    {"error", onnxruntime::logging::Severity::kERROR},
    {"fatal", onnxruntime::logging::Severity::kFATAL}};

// Wrapper around Boost program_options; should provide all the functionality for options parsing.
// Provides sane default values.
class ServerConfiguration {
 public:
  const std::string full_desc = "ONNX Server: host an ONNX model with ONNX Runtime";
  std::string model_path;
  std::string address = "0.0.0.0";
  unsigned short http_port = 8001;
  int num_http_threads = std::thread::hardware_concurrency();
  onnxruntime::logging::Severity logging_level{};

  ServerConfiguration() {
    desc.add_options()("help,h", "Shows a help message and exits");
    desc.add_options()("log_level", po::value(&log_level_str)->default_value(log_level_str), "Logging level. Allowed options (case sensitive): verbose, info, warning, error, fatal");
    desc.add_options()("model_path", po::value(&model_path)->required(), "Path to ONNX model");
    desc.add_options()("address", po::value(&address)->default_value(address), "The base HTTP address");
    desc.add_options()("http_port", po::value(&http_port)->default_value(http_port), "HTTP port to listen on for requests");
    desc.add_options()("num_http_threads", po::value(&num_http_threads)->default_value(num_http_threads), "Number of HTTP threads");
  }

  // Parses argc and argv and sets the values for the class.
  // Returns an enum with three options: ExitSuccess, ExitFailure, ContinueSuccess.
  // ExitSuccess and ExitFailure mean the program should exit, but this is left to the caller.
  Result ParseInput(int argc, char** argv) {
    try {
      po::store(po::command_line_parser(argc, argv).options(desc).run(), vm);  // can throw

      if (ContainsHelp()) {
        PrintHelp(std::cout, full_desc);
        return Result::ExitSuccess;
      }

      po::notify(vm);  // throws on error, so do it after handling help
    } catch (const po::error& e) {
      PrintHelp(std::cerr, e.what());
      return Result::ExitFailure;
    } catch (const std::exception& e) {
      PrintHelp(std::cerr, e.what());
      return Result::ExitFailure;
    }

    Result result = ValidateOptions();

    if (result == Result::ContinueSuccess) {
      logging_level = supported_log_levels[log_level_str];
    }

    return result;
  }

 private:
  po::options_description desc{"Allowed options"};
  po::variables_map vm{};
  std::string log_level_str = "info";

  // Prints help and returns a failure result if there is a bad value
  Result ValidateOptions() {
    if (vm.count("log_level") &&
        supported_log_levels.find(log_level_str) == supported_log_levels.end()) {
      PrintHelp(std::cerr, "log_level must be one of verbose, info, warning, error, or fatal");
      return Result::ExitFailure;
    } else if (num_http_threads <= 0) {
      PrintHelp(std::cerr, "num_http_threads must be greater than 0");
      return Result::ExitFailure;
    } else if (!file_exists(model_path)) {
      PrintHelp(std::cerr, "model_path must be the location of a valid file");
      return Result::ExitFailure;
    } else {
      return Result::ContinueSuccess;
    }
  }

  // Checks whether the program options contain help
  bool ContainsHelp() const {
    return vm.count("help") || vm.count("h");
  }

  // Prints a helpful message (param: what) to the user, followed by the program options.
  // Example: config.PrintHelp(std::cout, "Non-negative values not allowed")
  // will print that message and then all publicly available options.
  void PrintHelp(std::ostream& out, const std::string& what) const {
    out << what << std::endl
        << desc << std::endl;
  }

  inline bool file_exists(const std::string& fileName) {
    std::ifstream infile(fileName.c_str());
    return infile.good();
  }
};

}  // namespace server
}  // namespace onnxruntime
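Should another option be needed later, the constructor pattern above extends naturally; a hedged sketch of a hypothetical flag (illustrative only, not part of this change):

```cpp
// Hypothetical addition, shown only to illustrate the registration pattern:
// a socket read timeout flag with a default, wired to a member variable.
int read_timeout_seconds = 30;
desc.add_options()("read_timeout",
                   po::value(&read_timeout_seconds)->default_value(read_timeout_seconds),
                   "Socket read timeout in seconds");
```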
@@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <sstream>
#include <google/protobuf/stubs/status.h>

#include "core/common/status.h"
#include "util.h"

namespace onnxruntime {
namespace server {

namespace protobufutil = google::protobuf::util;

protobufutil::Status GenerateProtobufStatus(const onnxruntime::common::Status& onnx_status, const std::string& message) {
  protobufutil::error::Code code = protobufutil::error::Code::UNKNOWN;
  switch (onnx_status.Code()) {
    case onnxruntime::common::StatusCode::OK:
    case onnxruntime::common::StatusCode::MODEL_LOADED:
      code = protobufutil::error::Code::OK;
      break;
    case onnxruntime::common::StatusCode::INVALID_ARGUMENT:
    case onnxruntime::common::StatusCode::INVALID_PROTOBUF:
    case onnxruntime::common::StatusCode::INVALID_GRAPH:
    case onnxruntime::common::StatusCode::SHAPE_INFERENCE_NOT_REGISTERED:
    case onnxruntime::common::StatusCode::REQUIREMENT_NOT_REGISTERED:
    case onnxruntime::common::StatusCode::NO_SUCHFILE:
    case onnxruntime::common::StatusCode::NO_MODEL:
      code = protobufutil::error::Code::INVALID_ARGUMENT;
      break;
    case onnxruntime::common::StatusCode::NOT_IMPLEMENTED:
      code = protobufutil::error::Code::UNIMPLEMENTED;
      break;
    case onnxruntime::common::StatusCode::FAIL:
    case onnxruntime::common::StatusCode::RUNTIME_EXCEPTION:
      code = protobufutil::error::Code::INTERNAL;
      break;
    default:
      code = protobufutil::error::Code::UNKNOWN;
  }

  std::ostringstream oss;
  oss << "ONNX Runtime Status Code: " << onnx_status.Code() << ". " << message;
  return protobufutil::Status(code, oss.str());
}

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <google/protobuf/stubs/status.h>

#include "core/common/status.h"

namespace onnxruntime {
namespace server {

// Generate protobuf status from ONNX Runtime status
google::protobuf::util::Status GenerateProtobufStatus(const onnxruntime::common::Status& onnx_status, const std::string& message);

}  // namespace server
}  // namespace onnxruntime
@@ -0,0 +1,43 @@
# ONNX Runtime Server Integration Tests

## Preparation

Test validation depends on the protobuf-generated *_pb2.py modules, so a successful server application build is required to generate them in the `server_test` subfolder of the build folder. The instructions below assume you are in that folder; otherwise, tests will fail with `ModuleNotFoundError`.

## Functional Tests

Functional tests run as part of a build with `--build_server --enable_server_tests`. To run them separately, use the following command line:

```Bash
/usr/bin/python3 ./test_main.py <server_app_path> <mnist_model_path> <test_data_path>
```

## Model Zoo Tests

To run this set of tests, a prepared test data set needs to be downloaded from [Azure Blob Storage](https://onnxserverdev.blob.core.windows.net/testing/server_test_data_20190422.zip) and unzipped to a folder, e.g. /home/foo/bar/model_zoo_test. It contains:

* ONNX models from the [ONNX Model Zoo](https://github.com/onnx/models) with opset 7/8/9.
* HTTP request JSON and protobuf files
* Expected response JSON and protobuf files

If you only need the request and response data, use this [link](https://onnxserverdev.blob.core.windows.net/testing/server_test_data_req_resp_only.zip) to download it.

To run the full model zoo tests, use the following command line:

```Bash
/usr/bin/python3 ./model_zoo_tests.py <server_app_path> <model_path> <test_data_path>
```

For example:

```Bash
/usr/bin/python3 ./model_zoo_tests.py /some/where/server_app /home/foo/bar/model_zoo_test /home/foo/bar/model_zoo_test
```

If the models are in a different folder but laid out in the same structure as the test data, you could also do

```Bash
/usr/bin/python3 ./model_zoo_tests.py /some/where/server_app /home/my/models/ /home/foo/bar/model_zoo_test/
```

All tests run in sequential order.
@@ -0,0 +1,363 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import unittest
import subprocess
import time
import os
import requests
import json
import numpy

import test_util
import onnx_ml_pb2
import predict_pb2

class HttpJsonPayloadTests(unittest.TestCase):
    server_ip = '127.0.0.1'
    server_port = 54321
    url_pattern = 'http://{0}:{1}/v1/models/{2}/versions/{3}:predict'
    server_app_path = ''
    test_data_path = ''
    model_path = ''
    log_level = 'verbose'
    server_app_proc = None
    wait_server_ready_in_seconds = 1

    @classmethod
    def setUpClass(cls):
        onnx_model = os.path.join(cls.model_path, 'mnist.onnx')
        test_util.prepare_mnist_model(onnx_model)
        cmd = [cls.server_app_path, '--http_port', str(cls.server_port), '--model_path', onnx_model, '--log_level', cls.log_level]
        test_util.test_log('Launching server app: [{0}]'.format(' '.join(cmd)))
        cls.server_app_proc = subprocess.Popen(cmd)
        test_util.test_log('Server app PID: {0}'.format(cls.server_app_proc.pid))
        test_util.test_log('Sleep {0} second(s) to wait for server initialization'.format(cls.wait_server_ready_in_seconds))
        time.sleep(cls.wait_server_ready_in_seconds)

    @classmethod
    def tearDownClass(cls):
        test_util.test_log('Shutdown server app')
        cls.server_app_proc.kill()
        test_util.test_log('PID {0} has been killed: {1}'.format(cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid)))

    def test_mnist_happy_path(self):
        input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json')
        output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.json')

        with open(input_data_file, 'r') as f:
            request_payload = f.read()

        with open(output_data_file, 'r') as f:
            expected_response_json = f.read()
            expected_response = json.loads(expected_response_json)

        request_headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'x-ms-client-request-id': 'This~is~my~id'
        }

        url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
        test_util.test_log(url)
        r = requests.post(url, headers=request_headers, data=request_payload)
        self.assertEqual(r.status_code, 200)
        self.assertEqual(r.headers.get('Content-Type'), 'application/json')
        self.assertTrue(r.headers.get('x-ms-request-id'))
        self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id')

        actual_response = json.loads(r.content.decode('utf-8'))

        # Note:
        # The 'dims' field is defined as "repeated int64" in protobuf.
        # When it is serialized to JSON, all int64/fixed64/uint64 numbers are converted to strings.
        # Reference: https://developers.google.com/protocol-buffers/docs/proto3#json

        self.assertTrue(actual_response['outputs'])
        self.assertTrue(actual_response['outputs']['Plus214_Output_0'])
        self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dims'])
        self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dims'], ['1', '10'])
        self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dataType'])
        self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dataType'], 1)
        self.assertTrue(actual_response['outputs']['Plus214_Output_0']['rawData'])
        actual_data = test_util.decode_base64_string(actual_response['outputs']['Plus214_Output_0']['rawData'], '10f')
        expected_data = test_util.decode_base64_string(expected_response['outputs']['Plus214_Output_0']['rawData'], '10f')

        for i in range(0, 10):
            self.assertTrue(test_util.compare_floats(actual_data[i], expected_data[i]))

    def test_mnist_invalid_url(self):
        url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', -1)
        test_util.test_log(url)

        request_headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }

        r = requests.post(url, headers=request_headers, data={'foo': 'bar'})
        self.assertEqual(r.status_code, 404)
        self.assertEqual(r.headers.get('Content-Type'), 'application/json')
        self.assertTrue(r.headers.get('x-ms-request-id'))

    def test_mnist_invalid_content_type(self):
        input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json')
        url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
        test_util.test_log(url)

        request_headers = {
            'Content-Type': 'application/abc',
            'Accept': 'application/json',
            'x-ms-client-request-id': 'This~is~my~id'
        }

        with open(input_data_file, 'r') as f:
            request_payload = f.read()

        r = requests.post(url, headers=request_headers, data=request_payload)
        self.assertEqual(r.status_code, 400)
        self.assertEqual(r.headers.get('Content-Type'), 'application/json')
        self.assertTrue(r.headers.get('x-ms-request-id'))
        self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id')
        self.assertEqual(r.content.decode('utf-8'), '{"error_code": 400, "error_message": "Missing or unknown \'Content-Type\' header field in the request"}\n')

    def test_mnist_missing_content_type(self):
        input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json')
        url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
        test_util.test_log(url)

        request_headers = {
            'Accept': 'application/json'
        }

        with open(input_data_file, 'r') as f:
            request_payload = f.read()

        r = requests.post(url, headers=request_headers, data=request_payload)
        self.assertEqual(r.status_code, 400)
        self.assertEqual(r.headers.get('Content-Type'), 'application/json')
        self.assertTrue(r.headers.get('x-ms-request-id'))
        self.assertEqual(r.content.decode('utf-8'), '{"error_code": 400, "error_message": "Missing or unknown \'Content-Type\' header field in the request"}\n')

    def test_single_model_shortcut(self):
        input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json')
        output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.json')

        with open(input_data_file, 'r') as f:
            request_payload = f.read()

        with open(output_data_file, 'r') as f:
            expected_response_json = f.read()
            expected_response = json.loads(expected_response_json)

        request_headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'x-ms-client-request-id': 'This~is~my~id'
        }

        url = "http://{0}:{1}/score".format(self.server_ip, self.server_port)
        test_util.test_log(url)
        r = requests.post(url, headers=request_headers, data=request_payload)
|
||||
self.assertEqual(r.status_code, 200)
|
||||
self.assertEqual(r.headers.get('Content-Type'), 'application/json')
|
||||
self.assertTrue(r.headers.get('x-ms-request-id'))
|
||||
self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id')
|
||||
|
||||
actual_response = json.loads(r.content.decode('utf-8'))
|
||||
|
||||
# Note:
|
||||
# The 'dims' field is defined as "repeated int64" in protobuf.
|
||||
# When it is serialized to JSON, all int64/fixed64/uint64 numbers are converted to string
|
||||
# Reference: https://developers.google.com/protocol-buffers/docs/proto3#json
|
||||
|
||||
self.assertTrue(actual_response['outputs'])
|
||||
self.assertTrue(actual_response['outputs']['Plus214_Output_0'])
|
||||
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dims'])
|
||||
self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dims'], ['1', '10'])
|
||||
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dataType'])
|
||||
self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dataType'], 1)
|
||||
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['rawData'])
|
||||
actual_data = test_util.decode_base64_string(actual_response['outputs']['Plus214_Output_0']['rawData'], '10f')
|
||||
expected_data = test_util.decode_base64_string(expected_response['outputs']['Plus214_Output_0']['rawData'], '10f')
|
||||
|
||||
for i in range(0, 10):
|
||||
self.assertTrue(test_util.compare_floats(actual_data[i], expected_data[i]))
|
||||
|
||||
|
||||
class HttpProtobufPayloadTests(unittest.TestCase):
|
||||
server_ip = '127.0.0.1'
|
||||
server_port = 54321
|
||||
url_pattern = 'http://{0}:{1}/v1/models/{2}/versions/{3}:predict'
|
||||
server_app_path = ''
|
||||
test_data_path = ''
|
||||
model_path = ''
|
||||
log_level = 'verbose'
|
||||
server_app_proc = None
|
||||
wait_server_ready_in_seconds = 1
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
onnx_model = os.path.join(cls.model_path, 'mnist.onnx')
|
||||
test_util.prepare_mnist_model(onnx_model)
|
||||
cmd = [cls.server_app_path, '--http_port', str(cls.server_port), '--model_path', onnx_model, '--log_level', cls.log_level]
|
||||
test_util.test_log('Launching server app: [{0}]'.format(' '.join(cmd)))
|
||||
cls.server_app_proc = subprocess.Popen(cmd)
|
||||
test_util.test_log('Server app PID: {0}'.format(cls.server_app_proc.pid))
|
||||
test_util.test_log('Sleep {0} second(s) to wait for server initialization'.format(cls.wait_server_ready_in_seconds))
|
||||
time.sleep(cls.wait_server_ready_in_seconds)
|
||||
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
test_util.test_log('Shutdown server app')
|
||||
cls.server_app_proc.kill()
|
||||
test_util.test_log('PID {0} has been killed: {1}'.format(cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid)))
|
||||
|
||||
|
||||
def test_mnist_happy_path(self):
|
||||
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb')
|
||||
output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.pb')
|
||||
|
||||
with open(input_data_file, 'rb') as f:
|
||||
request_payload = f.read()
|
||||
|
||||
content_type_headers = ['application/x-protobuf', 'application/octet-stream', 'application/vnd.google.protobuf']
|
||||
|
||||
for h in content_type_headers:
|
||||
request_headers = {
|
||||
'Content-Type': h,
|
||||
'Accept': 'application/x-protobuf'
|
||||
}
|
||||
|
||||
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
|
||||
test_util.test_log(url)
|
||||
r = requests.post(url, headers=request_headers, data=request_payload)
|
||||
self.assertEqual(r.status_code, 200)
|
||||
self.assertEqual(r.headers.get('Content-Type'), 'application/x-protobuf')
|
||||
self.assertTrue(r.headers.get('x-ms-request-id'))
|
||||
|
||||
actual_result = predict_pb2.PredictResponse()
|
||||
actual_result.ParseFromString(r.content)
|
||||
|
||||
expected_result = predict_pb2.PredictResponse()
|
||||
with open(output_data_file, 'rb') as f:
|
||||
expected_result.ParseFromString(f.read())
|
||||
|
||||
for k in expected_result.outputs.keys():
|
||||
self.assertEqual(actual_result.outputs[k].data_type, expected_result.outputs[k].data_type)
|
||||
|
||||
count = 1
|
||||
for i in range(0, len(expected_result.outputs['Plus214_Output_0'].dims)):
|
||||
self.assertEqual(actual_result.outputs['Plus214_Output_0'].dims[i], expected_result.outputs['Plus214_Output_0'].dims[i])
|
||||
count = count * int(actual_result.outputs['Plus214_Output_0'].dims[i])
|
||||
|
||||
actual_array = numpy.frombuffer(actual_result.outputs['Plus214_Output_0'].raw_data, dtype=numpy.float32)
|
||||
expected_array = numpy.frombuffer(expected_result.outputs['Plus214_Output_0'].raw_data, dtype=numpy.float32)
|
||||
self.assertEqual(len(actual_array), len(expected_array))
|
||||
self.assertEqual(len(actual_array), count)
|
||||
for i in range(0, count):
|
||||
self.assertTrue(test_util.compare_floats(actual_array[i], expected_array[i], rel_tol=0.001))
|
||||
|
||||
|
||||
def test_respect_accept_header(self):
|
||||
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb')
|
||||
|
||||
with open(input_data_file, 'rb') as f:
|
||||
request_payload = f.read()
|
||||
|
||||
accept_headers = ['application/x-protobuf', 'application/octet-stream', 'application/vnd.google.protobuf']
|
||||
|
||||
for h in accept_headers:
|
||||
request_headers = {
|
||||
'Content-Type': 'application/x-protobuf',
|
||||
'Accept': h
|
||||
}
|
||||
|
||||
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
|
||||
test_util.test_log(url)
|
||||
r = requests.post(url, headers=request_headers, data=request_payload)
|
||||
self.assertEqual(r.status_code, 200)
|
||||
self.assertEqual(r.headers.get('Content-Type'), h)
|
||||
|
||||
|
||||
def test_missing_accept_header(self):
|
||||
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb')
|
||||
|
||||
with open(input_data_file, 'rb') as f:
|
||||
request_payload = f.read()
|
||||
|
||||
request_headers = {
|
||||
'Content-Type': 'application/x-protobuf',
|
||||
}
|
||||
|
||||
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
|
||||
test_util.test_log(url)
|
||||
r = requests.post(url, headers=request_headers, data=request_payload)
|
||||
self.assertEqual(r.status_code, 200)
|
||||
self.assertEqual(r.headers.get('Content-Type'), 'application/octet-stream')
|
||||
|
||||
|
||||
def test_any_accept_header(self):
|
||||
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb')
|
||||
|
||||
with open(input_data_file, 'rb') as f:
|
||||
request_payload = f.read()
|
||||
|
||||
request_headers = {
|
||||
'Content-Type': 'application/x-protobuf',
|
||||
'Accept': '*/*'
|
||||
}
|
||||
|
||||
url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
|
||||
test_util.test_log(url)
|
||||
r = requests.post(url, headers=request_headers, data=request_payload)
|
||||
self.assertEqual(r.status_code, 200)
|
||||
self.assertEqual(r.headers.get('Content-Type'), 'application/octet-stream')
|
||||
|
||||
|
||||
class HttpEndpointTests(unittest.TestCase):
|
||||
server_ip = '127.0.0.1'
|
||||
server_port = 54321
|
||||
server_app_path = ''
|
||||
test_data_path = ''
|
||||
model_path = ''
|
||||
log_level = 'verbose'
|
||||
server_app_proc = None
|
||||
wait_server_ready_in_seconds = 1
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
onnx_model = os.path.join(cls.model_path, 'mnist.onnx')
|
||||
test_util.prepare_mnist_model(onnx_model)
|
||||
cmd = [cls.server_app_path, '--http_port', str(cls.server_port), '--model_path', onnx_model, '--log_level', cls.log_level]
|
||||
test_util.test_log('Launching server app: [{0}]'.format(' '.join(cmd)))
|
||||
cls.server_app_proc = subprocess.Popen(cmd)
|
||||
test_util.test_log('Server app PID: {0}'.format(cls.server_app_proc.pid))
|
||||
test_util.test_log('Sleep {0} second(s) to wait for server initialization'.format(cls.wait_server_ready_in_seconds))
|
||||
time.sleep(cls.wait_server_ready_in_seconds)
|
||||
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
test_util.test_log('Shutdown server app')
|
||||
cls.server_app_proc.kill()
|
||||
test_util.test_log('PID {0} has been killed: {1}'.format(cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid)))
|
||||
|
||||
|
||||
def test_health_endpoint(self):
|
||||
url = url = "http://{0}:{1}/".format(self.server_ip, self.server_port)
|
||||
test_util.test_log(url)
|
||||
r = requests.get(url)
|
||||
self.assertEqual(r.status_code, 200)
|
||||
self.assertEqual(r.content.decode('utf-8'), 'Healthy')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
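The int64-as-string behavior the assertions above rely on can be reproduced directly with the generated protobuf bindings. A minimal sketch (values illustrative), using the same onnx_ml_pb2 module the tests import:

# Minimal sketch of the proto3 JSON mapping the tests depend on: 'dims' is a
# repeated int64 field, so MessageToJson renders its values as JSON strings,
# while the int32 'data_type' field stays numeric.
from google.protobuf.json_format import MessageToJson
import onnx_ml_pb2

t = onnx_ml_pb2.TensorProto()
t.dims.extend([1, 10])
t.data_type = 1  # 1 == FLOAT
print(MessageToJson(t))  # {"dims": ["1", "10"], "dataType": 1}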
@@ -0,0 +1,120 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import sys
import shutil

import onnx
import onnxruntime
import json

from google.protobuf.json_format import MessageToJson

import predict_pb2
import onnx_ml_pb2


# Current models only have one input and one output
def get_io_name(model_file_name):
    sess = onnxruntime.InferenceSession(model_file_name)
    return sess.get_inputs()[0].name, sess.get_outputs()[0].name


def gen_input_pb(pb_full_path, input_name, output_name, request_file_path):
    t = onnx_ml_pb2.TensorProto()
    with open(pb_full_path, 'rb') as fin:
        t.ParseFromString(fin.read())
    predict_request = predict_pb2.PredictRequest()
    predict_request.inputs[input_name].CopyFrom(t)
    predict_request.output_filter.append(output_name)

    with open(request_file_path, 'wb') as fout:
        fout.write(predict_request.SerializeToString())


def gen_output_pb(pb_full_path, output_name, response_file_path):
    t = onnx_ml_pb2.TensorProto()
    with open(pb_full_path, 'rb') as fin:
        t.ParseFromString(fin.read())
    predict_response = predict_pb2.PredictResponse()
    predict_response.outputs[output_name].CopyFrom(t)

    with open(response_file_path, 'wb') as fout:
        fout.write(predict_response.SerializeToString())


def tensor2dict(full_path):
    t = onnx.TensorProto()
    with open(full_path, 'rb') as f:
        t.ParseFromString(f.read())

    json_str = MessageToJson(t, use_integers_for_enums=True)
    data = json.loads(json_str)

    return data


def gen_input_json(pb_full_path, input_name, output_name, json_file_path):
    data = tensor2dict(pb_full_path)

    inputs = {}
    inputs[input_name] = data
    output_filters = [output_name]

    req = {}
    req['inputs'] = inputs
    req['outputFilter'] = output_filters

    with open(json_file_path, 'w') as outfile:
        json.dump(req, outfile)


def gen_output_json(pb_full_path, output_name, json_file_path):
    data = tensor2dict(pb_full_path)

    output = {}
    output[output_name] = data

    resp = {}
    resp['outputs'] = output

    with open(json_file_path, 'w') as outfile:
        json.dump(resp, outfile)


def gen_req_resp(model_zoo, test_data, copy_model=True):
    opsets = [name for name in os.listdir(model_zoo) if os.path.isdir(os.path.join(model_zoo, name))]
    for opset in opsets:
        os.makedirs(os.path.join(test_data, opset), exist_ok=True)

        current_model_folder = os.path.join(model_zoo, opset)
        current_data_folder = os.path.join(test_data, opset)

        models = [name for name in os.listdir(current_model_folder) if os.path.isdir(os.path.join(current_model_folder, name))]
        for model in models:
            os.makedirs(os.path.join(current_data_folder, model), exist_ok=True)

            src_folder = os.path.join(current_model_folder, model)
            dst_folder = os.path.join(current_data_folder, model)

            if copy_model:
                shutil.copy2(os.path.join(src_folder, 'model.onnx'), dst_folder)

            iname, oname = get_io_name(os.path.join(src_folder, 'model.onnx'))
            model_test_data = [name for name in os.listdir(src_folder) if os.path.isdir(os.path.join(src_folder, name))]
            for test in model_test_data:
                src = os.path.join(src_folder, test)
                dst = os.path.join(dst_folder, test)
                os.makedirs(dst, exist_ok=True)
                gen_input_json(os.path.join(src, 'input_0.pb'), iname, oname, os.path.join(dst, 'request.json'))
                gen_output_json(os.path.join(src, 'output_0.pb'), oname, os.path.join(dst, 'response.json'))
                gen_input_pb(os.path.join(src, 'input_0.pb'), iname, oname, os.path.join(dst, 'request.pb'))
                gen_output_pb(os.path.join(src, 'output_0.pb'), oname, os.path.join(dst, 'response.pb'))


if __name__ == '__main__':
    model_zoo = os.path.realpath(sys.argv[1])
    test_data = os.path.realpath(sys.argv[2])

    os.makedirs(test_data, exist_ok=True)
    gen_req_resp(model_zoo, test_data)
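For orientation, the generator walks the standard model-zoo layout and mirrors it into the output folder; a sketch with illustrative model/opset names:

model_zoo/opset_8/mnist/model.onnx
model_zoo/opset_8/mnist/test_data_set_0/input_0.pb
model_zoo/opset_8/mnist/test_data_set_0/output_0.pb

becomes

test_data/opset_8/mnist/model.onnx   (copied when copy_model is True)
test_data/opset_8/mnist/test_data_set_0/request.json
test_data/opset_8/mnist/test_data_set_0/response.json
test_data/opset_8/mnist/test_data_set_0/request.pb
test_data/opset_8/mnist/test_data_set_0/response.pb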
@@ -0,0 +1,101 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import unittest
import random
import os
import test_util
import sys


class ModelZooTests(unittest.TestCase):
    server_ip = '127.0.0.1'
    server_port = 54321
    url_pattern = 'http://{0}:{1}/v1/models/{2}/versions/{3}:predict'
    server_app_path = ''  # Required
    log_level = 'verbose'
    server_ready_in_seconds = 10
    server_off_in_seconds = 100
    need_data_preparation = False
    need_data_cleanup = False
    model_zoo_model_path = ''  # Required
    model_zoo_test_data_path = ''  # Required
    supported_opsets = ['opset_7', 'opset_8', 'opset_9']
    skipped_models = []

    def test_models_from_model_zoo(self):
        json_request_headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        pb_request_headers = {
            'Content-Type': 'application/octet-stream',
            'Accept': 'application/octet-stream'
        }

        model_data_map = {}
        for opset in self.supported_opsets:
            test_data_folder = os.path.join(self.model_zoo_test_data_path, opset)
            model_file_folder = os.path.join(self.model_zoo_model_path, opset)

            if os.path.isdir(test_data_folder):
                for name in os.listdir(test_data_folder):
                    if name in self.skipped_models:
                        continue

                    if os.path.isdir(os.path.join(test_data_folder, name)):
                        current_dir = os.path.join(test_data_folder, name)
                        model_data_map[os.path.join(model_file_folder, name)] = [os.path.join(current_dir, d) for d in os.listdir(current_dir) if os.path.isdir(os.path.join(current_dir, d))]

        test_util.test_log('Planned models and test data:')
        for model_data, data_paths in model_data_map.items():
            test_util.test_log(model_data)
            for data in data_paths:
                test_util.test_log('\t\t{0}'.format(data))
        test_util.test_log('-----------------------')

        self.server_port = random.randint(30000, 40000)
        for model_path, data_paths in model_data_map.items():
            server_app_proc = None
            try:
                cmd = [self.server_app_path, '--http_port', str(self.server_port), '--model_path', os.path.join(model_path, 'model.onnx'), '--log_level', self.log_level]
                test_util.test_log(cmd)
                server_app_proc = test_util.launch_server_app(cmd, self.server_ip, self.server_port, self.server_ready_in_seconds)

                test_util.test_log('[{0}] Run tests...'.format(model_path))
                for test in data_paths:
                    test_util.test_log('[{0}] Current: {1}'.format(model_path, test))

                    test_util.test_log('[{0}] JSON payload testing ....'.format(model_path))
                    url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345)
                    with open(os.path.join(test, 'request.json')) as f:
                        request_payload = f.read()
                    resp = test_util.make_http_request(url, json_request_headers, request_payload)
                    test_util.json_response_validation(self, resp, os.path.join(test, 'response.json'))

                    test_util.test_log('[{0}] Protobuf payload testing ....'.format(model_path))
                    url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 54321)
                    with open(os.path.join(test, 'request.pb'), 'rb') as f:
                        request_payload = f.read()
                    resp = test_util.make_http_request(url, pb_request_headers, request_payload)
                    test_util.pb_response_validation(self, resp, os.path.join(test, 'response.pb'))
            finally:
                test_util.shutdown_server_app(server_app_proc, self.server_off_in_seconds)


if __name__ == '__main__':
    loader = unittest.TestLoader()

    test_classes = [ModelZooTests]

    test_suites = []
    for tests in test_classes:
        tests.server_app_path = sys.argv[1]
        tests.model_zoo_model_path = sys.argv[2]
        tests.model_zoo_test_data_path = sys.argv[3]

        test_suites.append(loader.loadTestsFromTestCase(tests))

    suites = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)

    results = runner.run(suites)
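As the __main__ block above shows, this runner takes three positional arguments: the server binary path, the model-zoo root, and the generated test-data root. A hypothetical invocation (the script's own filename is not visible in this excerpt):

# python <model_zoo_tests script> /build/onnxruntime_hosting /data/model_zoo /data/test_data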
@@ -0,0 +1,26 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import sys
import random
import unittest
import function_tests

if __name__ == '__main__':
    loader = unittest.TestLoader()

    test_classes = [function_tests.HttpJsonPayloadTests, function_tests.HttpProtobufPayloadTests, function_tests.HttpEndpointTests]

    test_suites = []
    for tests in test_classes:
        tests.server_app_path = sys.argv[1]
        tests.model_path = sys.argv[2]
        tests.test_data_path = sys.argv[3]
        tests.server_port = random.randint(30000, 50000)

        test_suites.append(loader.loadTestsFromTestCase(tests))

    suites = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)

    results = runner.run(suites)
@@ -0,0 +1,179 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import base64
import struct
import math
import subprocess
import time
import requests
import json
import datetime
import socket
import sys
import urllib.request

import predict_pb2
import onnx_ml_pb2
import numpy


def test_log(message):
    print('[Test Log][{0}] {1}'.format(datetime.datetime.now(), message))


def is_process_killed(pid):
    if sys.platform.startswith('win'):
        process_name = 'onnxruntime_host.exe'
        call = 'TASKLIST', '/FI', 'imagename eq {0}'.format(process_name)
        output = subprocess.check_output(call).decode('utf-8')
        print(output)
        last_line = output.strip().split('\r\n')[-1]
        return not last_line.lower().startswith(process_name)
    else:
        # Signal 0 does not kill anything; it only checks whether the process still exists.
        try:
            os.kill(pid, 0)
        except OSError:
            return True  # No such process, so it has been killed.
        else:
            return False


def prepare_mnist_model(target_path):
    # TODO: This needs to be replaced by test data on the build machine after merging to upstream master.
    if not os.path.isfile(target_path):
        test_log('Downloading model from blob storage: https://ortsrvdev.blob.core.windows.net/test-data/mnist.onnx to {0}'.format(target_path))
        urllib.request.urlretrieve('https://ortsrvdev.blob.core.windows.net/test-data/mnist.onnx', target_path)
    else:
        test_log('Found mnist model at {0}'.format(target_path))


def decode_base64_string(s, count_and_type):
    b = base64.b64decode(s)
    r = struct.unpack(count_and_type, b)

    return r


def compare_floats(a, b, rel_tol=0.0001, abs_tol=0.0001):
    if not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol):
        test_log('No match with relative tolerance {0} and absolute tolerance {1}: {2} and {3}'.format(rel_tol, abs_tol, a, b))
        return False

    return True


def wait_service_up(server, port, timeout=1):
    s = socket.socket()
    if timeout:
        end = time.time() + timeout

    while True:
        try:
            if timeout:
                next_timeout = end - time.time()
                if next_timeout < 0:
                    return False
                else:
                    s.settimeout(next_timeout)

            s.connect((server, port))
        except socket.timeout:
            if timeout:
                return False
        except Exception:
            pass
        else:
            s.close()
            return True


def launch_server_app(cmd, server_ip, server_port, wait_server_ready_in_seconds):
    test_log('Launching server app: [{0}]'.format(' '.join(cmd)))
    server_app_proc = subprocess.Popen(cmd)
    test_log('Server app PID: {0}'.format(server_app_proc.pid))
    test_log('Wait up to {0} second(s) for server initialization'.format(wait_server_ready_in_seconds))
    wait_service_up(server_ip, server_port, wait_server_ready_in_seconds)

    return server_app_proc


def shutdown_server_app(server_app_proc, wait_for_server_off_in_seconds):
    if server_app_proc is not None:
        test_log('Shutdown server app')
        server_app_proc.kill()

        while not is_process_killed(server_app_proc.pid):
            server_app_proc.wait(timeout=wait_for_server_off_in_seconds)
        test_log('PID {0} has been killed: {1}'.format(server_app_proc.pid, is_process_killed(server_app_proc.pid)))

        # Additional sleep to make sure the resource has been freed.
        time.sleep(1)

    return True


def make_http_request(url, request_headers, payload):
    test_log('POST Request Started')
    resp = requests.post(url, headers=request_headers, data=payload)
    test_log('POST Request Done')
    return resp


def json_response_validation(cls, resp, expected_resp_json_file):
    cls.assertEqual(resp.status_code, 200)
    cls.assertTrue(resp.headers.get('x-ms-request-id'))
    cls.assertEqual(resp.headers.get('Content-Type'), 'application/json')

    with open(expected_resp_json_file) as f:
        expected_result = json.loads(f.read())

    actual_response = json.loads(resp.content.decode('utf-8'))
    cls.assertTrue(actual_response['outputs'])

    for output in expected_result['outputs'].keys():
        cls.assertTrue(actual_response['outputs'][output])
        cls.assertTrue(actual_response['outputs'][output]['dataType'])
        cls.assertEqual(actual_response['outputs'][output]['dataType'], expected_result['outputs'][output]['dataType'])
        cls.assertTrue(actual_response['outputs'][output]['dims'])
        cls.assertEqual(actual_response['outputs'][output]['dims'], expected_result['outputs'][output]['dims'])
        cls.assertTrue(actual_response['outputs'][output]['rawData'])

        count = 1
        for x in actual_response['outputs'][output]['dims']:
            count = count * int(x)

        actual_array = decode_base64_string(actual_response['outputs'][output]['rawData'], '{0}f'.format(count))
        expected_array = decode_base64_string(expected_result['outputs'][output]['rawData'], '{0}f'.format(count))
        cls.assertEqual(len(actual_array), len(expected_array))
        cls.assertEqual(len(actual_array), count)
        for i in range(0, count):
            cls.assertTrue(compare_floats(actual_array[i], expected_array[i], rel_tol=0.001))


def pb_response_validation(cls, resp, expected_resp_pb_file):
    cls.assertEqual(resp.status_code, 200)
    cls.assertTrue(resp.headers.get('x-ms-request-id'))
    cls.assertEqual(resp.headers.get('Content-Type'), 'application/octet-stream')

    actual_result = predict_pb2.PredictResponse()
    actual_result.ParseFromString(resp.content)

    expected_result = predict_pb2.PredictResponse()
    with open(expected_resp_pb_file, 'rb') as f:
        expected_result.ParseFromString(f.read())

    for k in expected_result.outputs.keys():
        cls.assertEqual(actual_result.outputs[k].data_type, expected_result.outputs[k].data_type)

        count = 1
        for i in range(0, len(expected_result.outputs[k].dims)):
            cls.assertEqual(actual_result.outputs[k].dims[i], expected_result.outputs[k].dims[i])
            count = count * int(actual_result.outputs[k].dims[i])

        actual_array = numpy.frombuffer(actual_result.outputs[k].raw_data, dtype=numpy.float32)
        expected_array = numpy.frombuffer(expected_result.outputs[k].raw_data, dtype=numpy.float32)
        cls.assertEqual(len(actual_array), len(expected_array))
        cls.assertEqual(len(actual_array), count)
        for i in range(0, count):
            cls.assertTrue(compare_floats(actual_array[i], expected_array[i], rel_tol=0.001))
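As a quick illustration of the struct format strings used above ('10f' means ten native 32-bit floats, matching the 1x10 MNIST logits in 'rawData'), here is a self-contained round trip through the helpers defined in this file:

# Round-trip check for decode_base64_string: pack ten floats, base64-encode
# them, and unpack with the same '10f' struct format the MNIST tests use.
raw = base64.b64encode(struct.pack('10f', *range(10))).decode('ascii')
print(decode_base64_string(raw, '10f'))  # (0.0, 1.0, ..., 9.0)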
(Diff for one file is not shown here because of its size.)
@@ -0,0 +1,109 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <iostream>

#include "gtest/gtest.h"
#include "server/http/core/routes.h"

namespace onnxruntime {
namespace server {
namespace test {

using test_data = std::tuple<http::verb, std::string, std::string, std::string, std::string, http::status>;

void do_something(const std::string& name, const std::string& version,
                  const std::string& action, HttpContext& context) {
  auto noop = name + version + action + context.request.body();
}

void run_route(const std::string& pattern, http::verb method, const std::vector<test_data>& data, bool does_validate_data);

TEST(HttpRouteTests, RegisterTest) {
  auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";
  Routes routes;
  EXPECT_TRUE(routes.RegisterController(http::verb::post, predict_regex, do_something));

  auto status_regex = R"(/v1/models(?:/([^/:]+))?(?:/versions/(\d+))?(?:\/(metadata))?)";
  EXPECT_TRUE(routes.RegisterController(http::verb::get, status_regex, do_something));
}

TEST(HttpRouteTests, PostRouteTest) {
  auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";

  std::vector<test_data> actions{
      std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::ok),
      std::make_tuple(http::verb::post, "/v1/models/abc:predict", "abc", "", "predict", http::status::ok),
      std::make_tuple(http::verb::post, "/v1/models/models/versions/45:predict", "models", "45", "predict", http::status::ok),
      std::make_tuple(http::verb::post, "/v1/models/??$$%%@@$^^/versions/45:predict", "??$$%%@@$^^", "45", "predict", http::status::ok),
      std::make_tuple(http::verb::post, "/v1/models/versions/versions/45:predict", "versions", "45", "predict", http::status::ok)};

  run_route(predict_regex, http::verb::post, actions, true);
}

TEST(HttpRouteTests, PostRouteInvalidURLTest) {
  auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";

  std::vector<test_data> actions{
      std::make_tuple(http::verb::post, "", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "/v1/models", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "/v1/models:predict", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "/v1/models:bar", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "/v1/models/abc/versions", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "/v1/models/abc/versions:predict", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "/v1/models/a:bc/versions:predict", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "/v1/models/abc/versions/2.0:predict", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "/models/abc/versions/2:predict", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "/v1/models/versions/2:predict", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "/v1/models/foo/versions/:predict", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "/v1/models/foo/versions:predict", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "v1/models/foo/versions/12:predict", "", "", "", http::status::not_found),
      std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:foo", "", "", "", http::status::not_found)};

  run_route(predict_regex, http::verb::post, actions, false);
}

// These tests exist because we currently only support POST and GET.
// Some HTTP methods should be removed from the test data if we support more (e.g. PUT).
TEST(HttpRouteTests, PostRouteInvalidMethodTest) {
  auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";

  std::vector<test_data> actions{
      std::make_tuple(http::verb::get, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::method_not_allowed),
      std::make_tuple(http::verb::put, "/v1/models", "", "", "", http::status::method_not_allowed),
      std::make_tuple(http::verb::delete_, "/v1/models", "", "", "", http::status::method_not_allowed),
      std::make_tuple(http::verb::head, "/v1/models", "", "", "", http::status::method_not_allowed)};

  run_route(predict_regex, http::verb::post, actions, false);
}

void run_route(const std::string& pattern, http::verb method, const std::vector<test_data>& data, bool does_validate_data) {
  Routes routes;
  EXPECT_TRUE(routes.RegisterController(method, pattern, do_something));

  for (const auto& i : data) {
    http::verb test_method;
    std::string url_string;
    std::string name;
    std::string version;
    std::string action;
    HandlerFn fn;

    std::string expected_name;
    std::string expected_version;
    std::string expected_action;
    http::status expected_status;

    std::tie(test_method, url_string, expected_name, expected_version, expected_action, expected_status) = i;
    EXPECT_EQ(expected_status, routes.ParseUrl(test_method, url_string, name, version, action, fn));
    if (does_validate_data) {
      EXPECT_EQ(name, expected_name);
      EXPECT_EQ(version, expected_version);
      EXPECT_EQ(action, expected_action);
    }
  }
}

}  // namespace test
}  // namespace server
}  // namespace onnxruntime
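The route pattern above is an ordinary regular expression, so the capture groups can be sanity-checked quickly outside the server. A sketch in Python purely for illustration (the server itself matches with re2 and reports unmatched groups as empty strings rather than None):

import re

predict = re.compile(r'/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict)')
print(predict.fullmatch('/v1/models/abc/versions/23:predict').groups())
# ('abc', '23', 'predict')
print(predict.fullmatch('/v1/models/abc:predict').groups())
# ('abc', None, 'predict')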
@@ -0,0 +1,128 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <fstream>
#include <google/protobuf/stubs/status.h>

#include "gtest/gtest.h"

#include "predict.pb.h"
#include "server/http/json_handling.h"

namespace onnxruntime {
namespace server {
namespace test {
namespace protobufutil = google::protobuf::util;

TEST(JsonDeserializationTests, HappyPath) {
std::string input_json = R"({"inputs":{"Input3":{"dims":["1","1","28","28"],"dataType":1,"rawData":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACAPwAAQEAAAAAAAAAAAAAAgEAAAABAAAAAAAAAMEEAAAAAAAAAAAAAYEEAAIA/AAAAAAAAmEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQEEAAAAAAAAAAAAA4EAAAAAAAACAPwAAIEEAAAAAAAAAQAAAAEAAAIBBAAAAAAAAQEAAAEBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA4EAAAABBAAAAAAAAAEEAAAAAAAAAAAAAAEEAAAAAAAAAAAAAmEEAAAAAAAAAAAAAgD8AAKhBAAAAAAAAgEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAAAAAAAAAAAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAMEEAAAAAAAAAAAAAIEEAAEBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABQQQAAAAAAAHBBAAAgQQAA0EEAAAhCAACIQQAAmkIAADVDAAAyQwAADEIAAIBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWQwAAfkMAAHpDAAB7QwAAc0MAAHxDAAB8QwAAf0MAADRCAADAQAAAAAAAAKBAAAAAAAAAEEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAOBAAACQQgAATUMAAH9DAABuQwAAc0MAAH9DAAB+QwAAe0MAAHhDAABJQwAARkMAAGRCAAAAAAAAmEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWkMAAH9DAABxQwAAf0MAAHlDAAB6QwAAe0MAAHpDAAB/QwAAf0MAAHJDAABgQwAAREIAAAAAAABAQQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAABAAABAQAAAAEAAAABAAACAPwAAAAAAAIJCAABkQwAAf0MAAH5DAAB0QwAA7kIAAAhCAAAkQgAA3EIAAHpDAAB/QwAAeEMAAPhCAACgQQAAAAAAAAAAAAAAAAAAAAAAAAAAAACAPwAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBBAAAAAAAAeEIAAM5CAADiQgAA6kIAAAhCAAAAAAAAAAAAAAAAAABIQwAAdEMAAH9DAAB/QwAAAAAAAEBBAAAAAAAAAAAAAAAAAAAAAAAAAEAAAIA/AAAAAAAAAAAAAAAAAAAAAAAAgD8AAABAAAAAAAAAAAAAAABAAACAQAAAAAAAADBBAAAAAAAA4EAAAMBAAAAAAAAAlkIAAHRDAAB/QwAAf0MAAIBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AAAAQAAAQEAAAIBAAACAQAAAAAAAAGBBAAAAAAAAAAAAAAAAAAAQQQAAAAAAAABAAAAAAAAAAAAAAAhCAAB/QwAAf0MAAH1DAAAgQQAAIEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AAAAQAAAQEAAAABAAAAAAAAAAAAAAEBAAAAAQAAAAAAAAFBBAAAwQQAAAAAAAAAAAAAAAAAAwEAAAEBBAADGQgAAf0MAAH5DAAB4QwAAcEEAAEBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AACAPwAAgD8AAAAAAAAAAAAAAAAAAAAAAACAPwAAgD8AAAAAAAAAAAAAoEAAAMBAAAAwQQAAAAAAAAAAAACIQQAAOEMAAHdDAAB/QwAAc0MAAFBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAABAAACAQAAAgEAAAAAAAAAwQQAAAAAAAExCAAC8QgAAqkIAAKBAAACgQAAAyEEAAHZDAAB2QwAAf0MAAFBDAAAAAAAAEEEAAAAAAAAAAAAAAAAAAAAAAACAQAAAgD8AAAAAAAAAAAAAgD8AAOBAAABwQQAAmEEAAMZCAADOQgAANkMAAD1DAABtQwAAfUMAAHxDAAA/QwAAPkMAAGNDAABzQwAAfEMAAFJDAACQQQAA4EAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBAAAAAAAAAAAAAAABCAADaQgAAOUMAAHdDAAB/QwAAckMAAH9DAAB0QwAAf0MAAH9DAAByQwAAe0MAAH9DAABwQwAAf0MAAH9DAABaQwAA+EIAABBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAD+QgAAf0MAAGtDAAB/QwAAf0MAAHdDAABlQwAAVEMAAHJDAAB6QwAAf0MAAH9DAAB4QwAAf0MAAH1DAAB5QwAAf0MAAHNDAAAqQwAAQEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMEEAAAAAAAAQQQAAfUMAAH9DAAB/QwAAaUMAAEpDAACqQgAAAAAAAFRCAABEQwAAbkMAAH9DAABjQwAAbkMAAA5DAADaQgAAQUMAAH9DAABwQwAAf0MAADRDAAAAAAAAAAAAAAAAAAAAAAAAwEAAAAAAAACwQQAAgD8AAHVDAABzQwAAfkMAAH9DAABZQwAAa0MAAGJDAABVQwAAdEMAAHtDAAB/QwAAb0MAAJpCAAAAAAAAAAAAAKBBAAA2QwAAd0MAAG9DAABzQwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBAAAAlQwAAe0MAAH9DAAB1QwAAf0MAAHJDAAB9QwAAekMAAH9DAABFQwAA1kIAAGxCAAAAAAAAkEEAAABAAADAQAAAAAAAAFhCAAB/QwAAHkMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwEEAAAAAAAAAAAAAwEAAAAhCAAAnQwAAQkMAADBDAAA3QwAAJEMAADBCAAAAQAAAIEEAAMBAAADAQAAAAAAAAAAAAACgQAAAAAAAAIA/AAAAAAAAYEEAAABAAAAAAAAAAAAAAAAAAAAAAAAAIEEAAAAAAABgQQAAAAAAAEBBAAAAAAAAoEAAAAAAAACAPwAAAAAAAMBAAAAAAAAA4EAAAAAAAAAAAAAAAAAAAABBAAAAAAAAIEEAAAAAAACgQAAAAAAAAAAAAAAgQQAAAAAAAAAAAAAAAAAAAAAAAAAAAABgQQAAAAAAAIB
AAAAAAAAAAAAAAMhBAAAAAAAAAAAAABBBAAAAAAAAAAAAABBBAAAAAAAAMEEAAAAAAACAPwAAAAAAAAAAAAAAQAAAAAAAAAAAAADgQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="}},"outputFilter":["Plus214_Output_0"]})";
  onnxruntime::server::PredictRequest request;
  protobufutil::Status status = onnxruntime::server::GetRequestFromJson(input_json, request);

  EXPECT_EQ(protobufutil::error::OK, status.error_code());
}

TEST(JsonDeserializationTests, WithUnknownField) {
std::string input_json = R"({"foo": "bar","inputs":{"Input3":{"dims":["1","1","28","28"],"dataType":1,"rawData":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACAPwAAQEAAAAAAAAAAAAAAgEAAAABAAAAAAAAAMEEAAAAAAAAAAAAAYEEAAIA/AAAAAAAAmEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQEEAAAAAAAAAAAAA4EAAAAAAAACAPwAAIEEAAAAAAAAAQAAAAEAAAIBBAAAAAAAAQEAAAEBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA4EAAAABBAAAAAAAAAEEAAAAAAAAAAAAAAEEAAAAAAAAAAAAAmEEAAAAAAAAAAAAAgD8AAKhBAAAAAAAAgEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAAAAAAAAAAAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAMEEAAAAAAAAAAAAAIEEAAEBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABQQQAAAAAAAHBBAAAgQQAA0EEAAAhCAACIQQAAmkIAADVDAAAyQwAADEIAAIBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWQwAAfkMAAHpDAAB7QwAAc0MAAHxDAAB8QwAAf0MAADRCAADAQAAAAAAAAKBAAAAAAAAAEEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAOBAAACQQgAATUMAAH9DAABuQwAAc0MAAH9DAAB+QwAAe0MAAHhDAABJQwAARkMAAGRCAAAAAAAAmEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWkMAAH9DAABxQwAAf0MAAHlDAAB6QwAAe0MAAHpDAAB/QwAAf0MAAHJDAABgQwAAREIAAAAAAABAQQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAABAAABAQAAAAEAAAABAAACAPwAAAAAAAIJCAABkQwAAf0MAAH5DAAB0QwAA7kIAAAhCAAAkQgAA3EIAAHpDAAB/QwAAeEMAAPhCAACgQQAAAAAAAAAAAAAAAAAAAAAAAAAAAACAPwAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBBAAAAAAAAeEIAAM5CAADiQgAA6kIAAAhCAAAAAAAAAAAAAAAAAABIQwAAdEMAAH9DAAB/QwAAAAAAAEBBAAAAAAAAAAAAAAAAAAAAAAAAAEAAAIA/AAAAAAAAAAAAAAAAAAAAAAAAgD8AAABAAAAAAAAAAAAAAABAAACAQAAAAAAAADBBAAAAAAAA4EAAAMBAAAAAAAAAlkIAAHRDAAB/QwAAf0MAAIBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AAAAQAAAQEAAAIBAAACAQAAAAAAAAGBBAAAAAAAAAAAAAAAAAAAQQQAAAAAAAABAAAAAAAAAAAAAAAhCAAB/QwAAf0MAAH1DAAAgQQAAIEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AAAAQAAAQEAAAABAAAAAAAAAAAAAAEBAAAAAQAAAAAAAAFBBAAAwQQAAAAAAAAAAAAAAAAAAwEAAAEBBAADGQgAAf0MAAH5DAAB4QwAAcEEAAEBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AACAPwAAgD8AAAAAAAAAAAAAAAAAAAAAAACAPwAAgD8AAAAAAAAAAAAAoEAAAMBAAAAwQQAAAAAAAAAAAACIQQAAOEMAAHdDAAB/QwAAc0MAAFBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAABAAACAQAAAgEAAAAAAAAAwQQAAAAAAAExCAAC8QgAAqkIAAKBAAACgQAAAyEEAAHZDAAB2QwAAf0MAAFBDAAAAAAAAEEEAAAAAAAAAAAAAAAAAAAAAAACAQAAAgD8AAAAAAAAAAAAAgD8AAOBAAABwQQAAmEEAAMZCAADOQgAANkMAAD1DAABtQwAAfUMAAHxDAAA/QwAAPkMAAGNDAABzQwAAfEMAAFJDAACQQQAA4EAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBAAAAAAAAAAAAAAABCAADaQgAAOUMAAHdDAAB/QwAAckMAAH9DAAB0QwAAf0MAAH9DAAByQwAAe0MAAH9DAABwQwAAf0MAAH9DAABaQwAA+EIAABBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAD+QgAAf0MAAGtDAAB/QwAAf0MAAHdDAABlQwAAVEMAAHJDAAB6QwAAf0MAAH9DAAB4QwAAf0MAAH1DAAB5QwAAf0MAAHNDAAAqQwAAQEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMEEAAAAAAAAQQQAAfUMAAH9DAAB/QwAAaUMAAEpDAACqQgAAAAAAAFRCAABEQwAAbkMAAH9DAABjQwAAbkMAAA5DAADaQgAAQUMAAH9DAABwQwAAf0MAADRDAAAAAAAAAAAAAAAAAAAAAAAAwEAAAAAAAACwQQAAgD8AAHVDAABzQwAAfkMAAH9DAABZQwAAa0MAAGJDAABVQwAAdEMAAHtDAAB/QwAAb0MAAJpCAAAAAAAAAAAAAKBBAAA2QwAAd0MAAG9DAABzQwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBAAAAlQwAAe0MAAH9DAAB1QwAAf0MAAHJDAAB9QwAAekMAAH9DAABFQwAA1kIAAGxCAAAAAAAAkEEAAABAAADAQAAAAAAAAFhCAAB/QwAAHkMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwEEAAAAAAAAAAAAAwEAAAAhCAAAnQwAAQkMAADBDAAA3QwAAJEMAADBCAAAAQAAAIEEAAMBAAADAQAAAAAAAAAAAAACgQAAAAAAAAIA/AAAAAAAAYEEAAABAAAAAAAAAAAAAAAAAAAAAAAAAIEEAAAAAAABgQQAAAAAAAEBBAAAAAAAAoEAAAAAAAACAPwAAAAAAAMBAAAAAAAAA4EAAAAAAAAAAAAAAAAAAAABBAAAAAAAAIEEAAAAAAACgQAAAAAAAAAAAAAAgQQAAAAAAAAAAAAAAAAAAAAAAAAAAAA
BgQQAAAAAAAIBAAAAAAAAAAAAAAMhBAAAAAAAAAAAAABBBAAAAAAAAAAAAABBBAAAAAAAAMEEAAAAAAACAPwAAAAAAAAAAAAAAQAAAAAAAAAAAAADgQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="}},"outputFilter":["Plus214_Output_0"]})";
  onnxruntime::server::PredictRequest request;
  protobufutil::Status status = onnxruntime::server::GetRequestFromJson(input_json, request);

  EXPECT_EQ(protobufutil::error::OK, status.error_code());
}

TEST(JsonDeserializationTests, InvalidData) {
  std::string input_json = R"({"inputs":{"Input3":{"dims":["1","1","28","28"],"dataType":1,"rawData":"hello"}},"outputFilter":["Plus214_Output_0"]})";
  onnxruntime::server::PredictRequest request;
  protobufutil::Status status = onnxruntime::server::GetRequestFromJson(input_json, request);

  EXPECT_EQ(protobufutil::error::INVALID_ARGUMENT, status.error_code());
  EXPECT_EQ("inputs[0].value.raw_data: invalid value \"hello\" for type TYPE_BYTES", status.error_message());
}

TEST(JsonDeserializationTests, InvalidJson) {
  std::string input_json = R"({inputs":{"Input3":{"dims":["1","1","28","28"],"dataType":1,"rawData":"hello"}},"outputFilter":["Plus214_Output_0"]})";
  onnxruntime::server::PredictRequest request;
  protobufutil::Status status = onnxruntime::server::GetRequestFromJson(input_json, request);

  EXPECT_EQ(protobufutil::error::INVALID_ARGUMENT, status.error_code());
  std::string errmsg = status.error_message();
  EXPECT_EQ("Expected : between key:value pair.\n{inputs\":{\"Input3\":{\"dims\":\n ^", errmsg);
}

TEST(JsonSerializationTests, HappyPath) {
  std::string test_data = "testdata/server/response_0.pb";
  std::string expected_json_string = R"({"outputs":{"Plus214_Output_0":{"dims":["1","10"],"dataType":1,"rawData":"4+pzRFWuGsSMdM1F2gEnRFdRZcRZ9NDEURj0xBIzdsJOS0LEA/GzxA=="}}})";
  onnxruntime::server::PredictResponse response;
  std::string json_string;

  std::ifstream ifs(test_data, std::ios_base::in | std::ios_base::binary);
  ASSERT_TRUE(ifs) << test_data << " Not Found" << std::endl;

  bool succeeded = response.ParseFromIstream(&ifs);
  ifs.close();
  EXPECT_TRUE(succeeded) << test_data << " is invalid" << std::endl;

  protobufutil::Status status = onnxruntime::server::GenerateResponseInJson(response, json_string);

  EXPECT_EQ(protobufutil::error::OK, status.error_code());
  EXPECT_EQ(expected_json_string, json_string);
}

TEST(StringEscapingTests, SimpleString) {
  std::string unescaped = "This is an error message \" \n ";
  EXPECT_EQ("This is an error message \\\" \\n ", escape_string(unescaped));
}

TEST(StringEscapingTests, SimpleStringWithControlCharacter) {
  std::string unescaped = "This is an \x1f error message";
  EXPECT_EQ("This is an \\u001f error message", escape_string(unescaped));
}

TEST(StringEscapingTests, SimpleStringWithNullCharacter) {
  // The C-string literal is truncated at the embedded \x00, so only the prefix survives.
  std::string unescaped = "This is an error message \x00 end";
  EXPECT_EQ("This is an error message ", escape_string(unescaped));
}

TEST(JsonErrorMessageTests, SimpleMessage) {
  auto status = http::status::bad_request;
  std::string error_message = "Incorrect headers";
  std::string expected = "{\"error_code\": 400, \"error_message\": \"Incorrect headers\"}\n";
  std::string res = CreateJsonError(status, error_message);
  EXPECT_EQ(expected, res);
}

TEST(JsonErrorMessageTests, MessageWithNewLine) {
  auto status = http::status::internal_server_error;
  std::string error_message = "Contains newline \n here";
  std::string expected = "{\"error_code\": 500, \"error_message\": \"Contains newline \\n here\"}\n";
  std::string res = CreateJsonError(status, error_message);
  EXPECT_EQ(expected, res);
}

TEST(JsonErrorMessageTests, MessageWithRealError) {
  auto status = http::status::bad_request;
  std::string error_message = "Expected , or ] after array value.\n0, 0.0, 0.0, 0.0 } }, \"outputFilter\n ^";
  std::string expected = "{\"error_code\": 400, \"error_message\": \"Expected , or ] after array value.\\n0, 0.0, 0.0, 0.0 } }, \\\"outputFilter\\n ^\"}\n";
  std::string res = CreateJsonError(status, error_message);
  EXPECT_EQ(expected, res);
}

TEST(JsonErrorMessageTests, MessageWithQuotations) {
  auto status = http::status::bad_request;
  std::string error_message = R"(Error with "{"bleh": [1,2,3]|")";
  std::string expected = "{\"error_code\": 400, \"error_message\": \"Error with \\\"{\\\"bleh\\\": [1,2,3]|\\\"\"}\n";
  std::string result_t = CreateJsonError(status, error_message);
  EXPECT_EQ(expected, result_t);
}

TEST(JsonErrorMessageTests, MessageWithManyCarriageCharacters) {
  auto status = http::status::bad_request;
  std::string error_message = "\"ab\r\n\b\f\t\\\x1a\"";
  std::string expected = "{\"error_code\": 400, \"error_message\": \"\\\"ab\\r\\n\\b\\f\\t\\\\\\u001a\\\"\"}\n";
  std::string result_t = CreateJsonError(status, error_message);
  EXPECT_EQ(expected, result_t);
}

}  // namespace test
}  // namespace server
}  // namespace onnxruntime
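Since the error body is stable JSON (the format these tests pin down), a client can decode it generically. A minimal sketch in Python; parse_server_error is a hypothetical helper, not part of this change:

import json

def parse_server_error(body: str):
    # Bodies have the shape: {"error_code": 400, "error_message": "..."}\n
    err = json.loads(body)
    return err['error_code'], err['error_message']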
@@ -0,0 +1,97 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "gtest/gtest.h"

#include "gmock/gmock.h"
#include "server/server_configuration.h"

namespace onnxruntime {
namespace server {
namespace test {

TEST(ConfigParsingTests, AllArgs) {
  char* test_argv[] = {
      const_cast<char*>("/path/to/binary"),
      const_cast<char*>("--model_path"), const_cast<char*>("testdata/mul_1.pb"),
      const_cast<char*>("--address"), const_cast<char*>("4.4.4.4"),
      const_cast<char*>("--http_port"), const_cast<char*>("80"),
      const_cast<char*>("--num_http_threads"), const_cast<char*>("1"),
      const_cast<char*>("--log_level"), const_cast<char*>("info")};

  onnxruntime::server::ServerConfiguration config{};
  Result res = config.ParseInput(11, test_argv);
  EXPECT_EQ(res, Result::ContinueSuccess);
  EXPECT_EQ(config.model_path, "testdata/mul_1.pb");
  EXPECT_EQ(config.address, "4.4.4.4");
  EXPECT_EQ(config.http_port, 80);
  EXPECT_EQ(config.num_http_threads, 1);
  EXPECT_EQ(config.logging_level, onnxruntime::logging::Severity::kINFO);
}

TEST(ConfigParsingTests, Defaults) {
  char* test_argv[] = {
      const_cast<char*>("/path/to/binary"),
      const_cast<char*>("--model"), const_cast<char*>("testdata/mul_1.pb"),
      const_cast<char*>("--num_http_threads"), const_cast<char*>("3")};

  onnxruntime::server::ServerConfiguration config{};
  Result res = config.ParseInput(5, test_argv);
  EXPECT_EQ(res, Result::ContinueSuccess);
  EXPECT_EQ(config.model_path, "testdata/mul_1.pb");
  EXPECT_EQ(config.address, "0.0.0.0");
  EXPECT_EQ(config.http_port, 8001);
  EXPECT_EQ(config.num_http_threads, 3);
  EXPECT_EQ(config.logging_level, onnxruntime::logging::Severity::kINFO);
}

TEST(ConfigParsingTests, Help) {
  char* test_argv[] = {
      const_cast<char*>("/path/to/binary"),
      const_cast<char*>("--help")};

  onnxruntime::server::ServerConfiguration config{};
  auto res = config.ParseInput(2, test_argv);
  EXPECT_EQ(res, Result::ExitSuccess);
}

TEST(ConfigParsingTests, NoModelArg) {
  char* test_argv[] = {
      const_cast<char*>("/path/to/binary"),
      const_cast<char*>("--num_http_threads"), const_cast<char*>("3")};

  onnxruntime::server::ServerConfiguration config{};
  Result res = config.ParseInput(3, test_argv);
  EXPECT_EQ(res, Result::ExitFailure);
}

TEST(ConfigParsingTests, ModelNotFound) {
  char* test_argv[] = {
      const_cast<char*>("/path/to/binary"),
      const_cast<char*>("--model_path"), const_cast<char*>("does/not/exist"),
      const_cast<char*>("--address"), const_cast<char*>("4.4.4.4"),
      const_cast<char*>("--http_port"), const_cast<char*>("80"),
      const_cast<char*>("--num_http_threads"), const_cast<char*>("1")};

  onnxruntime::server::ServerConfiguration config{};
  Result res = config.ParseInput(9, test_argv);
  EXPECT_EQ(res, Result::ExitFailure);
}

TEST(ConfigParsingTests, WrongLoggingLevel) {
  char* test_argv[] = {
      const_cast<char*>("/path/to/binary"),
      const_cast<char*>("--log_level"), const_cast<char*>("not a logging level"),
      const_cast<char*>("--model_path"), const_cast<char*>("testdata/mul_1.pb"),
      const_cast<char*>("--address"), const_cast<char*>("4.4.4.4"),
      const_cast<char*>("--http_port"), const_cast<char*>("80"),
      const_cast<char*>("--num_http_threads"), const_cast<char*>("1")};

  onnxruntime::server::ServerConfiguration config{};
  Result res = config.ParseInput(11, test_argv);
  EXPECT_EQ(res, Result::ExitFailure);
}

}  // namespace test
}  // namespace server
}  // namespace onnxruntime
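Tying these flags back to the integration tests above, launching the hosting app looks like the following sketch; the binary path is a placeholder for wherever it was built, and the address/port values shown are the defaults asserted in ConfigParsingTests.Defaults:

import subprocess

# Placeholder binary path; the flags mirror the options parsed by ServerConfiguration.
cmd = ['./onnxruntime_hosting',
       '--model_path', 'testdata/mul_1.pb',
       '--address', '0.0.0.0',        # default address
       '--http_port', '8001',         # default port
       '--num_http_threads', '3',
       '--log_level', 'info']
proc = subprocess.Popen(cmd)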
@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <iostream>

#include "gtest/gtest.h"
#include "test/test_environment.h"

GTEST_API_ int main(int argc, char** argv) {
  int status = 0;

  try {
    const bool create_default_logger = true;
    onnxruntime::test::TestEnvironment environment{argc, argv, create_default_logger};

    status = RUN_ALL_TESTS();
  } catch (const std::exception& ex) {
    std::cerr << ex.what();
    status = -1;
  }

  return status;
}
@@ -0,0 +1,121 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <google/protobuf/stubs/status.h>
#include "gtest/gtest.h"
#include "server/http/core/context.h"
#include "server/http/util.h"

namespace onnxruntime {
namespace server {
namespace test {

namespace protobufutil = google::protobuf::util;

TEST(RequestContentTypeTests, ContentTypeJson) {
  HttpContext context;
  http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
  request.set(http::field::content_type, "application/json");
  context.request = request;

  auto result = GetRequestContentType(context);
  EXPECT_EQ(result, SupportedContentType::Json);
}

TEST(RequestContentTypeTests, ContentTypeRawData) {
  HttpContext context;
  http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
  request.set(http::field::content_type, "application/octet-stream");
  context.request = request;

  auto result = GetRequestContentType(context);
  EXPECT_EQ(result, SupportedContentType::PbByteArray);

  context.request.set(http::field::content_type, "application/vnd.google.protobuf");
  result = GetRequestContentType(context);
  EXPECT_EQ(result, SupportedContentType::PbByteArray);

  context.request.set(http::field::content_type, "application/x-protobuf");
  result = GetRequestContentType(context);
  EXPECT_EQ(result, SupportedContentType::PbByteArray);
}

TEST(RequestContentTypeTests, ContentTypeUnknown) {
  HttpContext context;
  http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
  request.set(http::field::content_type, "text/plain");
  context.request = request;

  auto result = GetRequestContentType(context);
  EXPECT_EQ(result, SupportedContentType::Unknown);
}

TEST(RequestContentTypeTests, ContentTypeMissing) {
  HttpContext context;
  http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
  context.request = request;

  auto result = GetRequestContentType(context);
  EXPECT_EQ(result, SupportedContentType::Unknown);
}

TEST(ResponseContentTypeTests, ContentTypeJson) {
  HttpContext context;
  http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
  request.set(http::field::accept, "application/json");
  context.request = request;

  auto result = GetResponseContentType(context);
  EXPECT_EQ(result, SupportedContentType::Json);
}

TEST(ResponseContentTypeTests, ContentTypeRawData) {
  HttpContext context;
  http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
  request.set(http::field::accept, "application/octet-stream");
  context.request = request;

  auto result = GetResponseContentType(context);
  EXPECT_EQ(result, SupportedContentType::PbByteArray);

  context.request.set(http::field::accept, "application/vnd.google.protobuf");
  result = GetResponseContentType(context);
  EXPECT_EQ(result, SupportedContentType::PbByteArray);

  context.request.set(http::field::accept, "application/x-protobuf");
  result = GetResponseContentType(context);
  EXPECT_EQ(result, SupportedContentType::PbByteArray);
}

TEST(ResponseContentTypeTests, ContentTypeAny) {
  HttpContext context;
  http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
  request.set(http::field::accept, "*/*");
  context.request = request;

  auto result = GetResponseContentType(context);
  EXPECT_EQ(result, SupportedContentType::PbByteArray);
}

TEST(ResponseContentTypeTests, ContentTypeUnknown) {
  HttpContext context;
  http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
  request.set(http::field::accept, "text/plain");
  context.request = request;

  auto result = GetResponseContentType(context);
  EXPECT_EQ(result, SupportedContentType::Unknown);
}

TEST(ContentTypeTests, ContentTypeMissing) {
  HttpContext context;
  http::request<http::string_body, http::basic_fields<std::allocator<char>>> request{};
  context.request = request;

  auto result = GetResponseContentType(context);
  EXPECT_EQ(result, SupportedContentType::PbByteArray);
}

}  // namespace test
}  // namespace server
}  // namespace onnxruntime
@ -0,0 +1 @@
|
|||
{"inputs": {"Input3": {"dims": ["1", "1", "28", "28"], "dataType": 1, "rawData": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACAPwAAQEAAAAAAAAAAAAAAgEAAAABAAAAAAAAAMEEAAAAAAAAAAAAAYEEAAIA/AAAAAAAAmEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQEEAAAAAAAAAAAAA4EAAAAAAAACAPwAAIEEAAAAAAAAAQAAAAEAAAIBBAAAAAAAAQEAAAEBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA4EAAAABBAAAAAAAAAEEAAAAAAAAAAAAAAEEAAAAAAAAAAAAAmEEAAAAAAAAAAAAAgD8AAKhBAAAAAAAAgEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAAAAAAAAAAAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAMEEAAAAAAAAAAAAAIEEAAEBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABQQQAAAAAAAHBBAAAgQQAA0EEAAAhCAACIQQAAmkIAADVDAAAyQwAADEIAAIBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWQwAAfkMAAHpDAAB7QwAAc0MAAHxDAAB8QwAAf0MAADRCAADAQAAAAAAAAKBAAAAAAAAAEEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAOBAAACQQgAATUMAAH9DAABuQwAAc0MAAH9DAAB+QwAAe0MAAHhDAABJQwAARkMAAGRCAAAAAAAAmEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWkMAAH9DAABxQwAAf0MAAHlDAAB6QwAAe0MAAHpDAAB/QwAAf0MAAHJDAABgQwAAREIAAAAAAABAQQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgD8AAABAAABAQAAAAEAAAABAAACAPwAAAAAAAIJCAABkQwAAf0MAAH5DAAB0QwAA7kIAAAhCAAAkQgAA3EIAAHpDAAB/QwAAeEMAAPhCAACgQQAAAAAAAAAAAAAAAAAAAAAAAAAAAACAPwAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBBAAAAAAAAeEIAAM5CAADiQgAA6kIAAAhCAAAAAAAAAAAAAAAAAABIQwAAdEMAAH9DAAB/QwAAAAAAAEBBAAAAAAAAAAAAAAAAAAAAAAAAAEAAAIA/AAAAAAAAAAAAAAAAAAAAAAAAgD8AAABAAAAAAAAAAAAAAABAAACAQAAAAAAAADBBAAAAAAAA4EAAAMBAAAAAAAAAlkIAAHRDAAB/QwAAf0MAAIBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AAAAQAAAQEAAAIBAAACAQAAAAAAAAGBBAAAAAAAAAAAAAAAAAAAQQQAAAAAAAABAAAAAAAAAAAAAAAhCAAB/QwAAf0MAAH1DAAAgQQAAIEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AAAAQAAAQEAAAABAAAAAAAAAAAAAAEBAAAAAQAAAAAAAAFBBAAAwQQAAAAAAAAAAAAAAAAAAwEAAAEBBAADGQgAAf0MAAH5DAAB4QwAAcEEAAEBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIA/AACAPwAAgD8AAAAAAAAAAAAAAAAAAAAAAACAPwAAgD8AAAAAAAAAAAAAoEAAAMBAAAAwQQAAAAAAAAAAAACIQQAAOEMAAHdDAAB/QwAAc0MAAFBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAABAAACAQAAAgEAAAAAAAAAwQQAAAAAAAExCAAC8QgAAqkIAAKBAAACgQAAAyEEAAHZDAAB2QwAAf0MAAFBDAAAAAAAAEEEAAAAAAAAAAAAAAAAAAAAAAACAQAAAgD8AAAAAAAAAAAAAgD8AAOBAAABwQQAAmEEAAMZCAADOQgAANkMAAD1DAABtQwAAfUMAAHxDAAA/QwAAPkMAAGNDAABzQwAAfEMAAFJDAACQQQAA4EAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBAAAAAAAAAAAAAAABCAADaQgAAOUMAAHdDAAB/QwAAckMAAH9DAAB0QwAAf0MAAH9DAAByQwAAe0MAAH9DAABwQwAAf0MAAH9DAABaQwAA+EIAABBBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAD+QgAAf0MAAGtDAAB/QwAAf0MAAHdDAABlQwAAVEMAAHJDAAB6QwAAf0MAAH9DAAB4QwAAf0MAAH1DAAB5QwAAf0MAAHNDAAAqQwAAQEEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMEEAAAAAAAAQQQAAfUMAAH9DAAB/QwAAaUMAAEpDAACqQgAAAAAAAFRCAABEQwAAbkMAAH9DAABjQwAAbkMAAA5DAADaQgAAQUMAAH9DAABwQwAAf0MAADRDAAAAAAAAAAAAAAAAAAAAAAAAwEAAAAAAAACwQQAAgD8AAHVDAABzQwAAfkMAAH9DAABZQwAAa0MAAGJDAABVQwAAdEMAAHtDAAB/QwAAb0MAAJpCAAAAAAAAAAAAAKBBAAA2QwAAd0MAAG9DAABzQwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIBAAAAlQwAAe0MAAH9DAAB1QwAAf0MAAHJDAAB9QwAAekMAAH9DAABFQwAA1kIAAGxCAAAAAAAAkEEAAABAAADAQAAAAAAAAFhCAAB/QwAAHkMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwEEAAAAAAAAAAAAAwEAAAAhCAAAnQwAAQkMAADBDAAA3QwAAJEMAADBCAAAAQAAAIEEAAMBAAADAQAAAAAAAAAAAAACgQAAAAAAAAIA/AAAAAAAAYEEAAABAAAAAAAAAAAAAAAAAAAAAAAAAIEEAAAAAAABgQQAAAAAAAEBBAAAAAAAAoEAAAAAAAACAPwAAAAAAAMBAAAAAAAAA4EAAAAAAAAAAAAAAAAAAAABBAAAAAAAAIEEAAAAAAACgQAAAAAAAAAAAAAAgQQAAAAAAAAAAAAAAAAAAAAAAAAAAAABgQQAAAAAAAIBAAAAAAAAAAAAAAMhBA
AAAAAAAAAAAABBBAAAAAAAAAAAAABBBAAAAAAAAMEEAAAAAAACAPwAAAAAAAAAAAAAAQAAAAAAAAAAAAADgQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="}}, "outputFilter": ["Plus214_Output_0"]}
|
Двоичные данные
onnxruntime/test/testdata/server/mnist_test_data_set_0_input.pb
поставляемый
Normal file
Двоичные данные
onnxruntime/test/testdata/server/mnist_test_data_set_0_input.pb
поставляемый
Normal file
Двоичный файл не отображается.
|
@ -0,0 +1 @@
|
|||
{"outputs": {"Plus214_Output_0": {"dims": ["1", "10"], "dataType": 1, "rawData": "4+pzRFWuGsSMdM1F2gEnRFdRZcRZ9NDEURj0xBIzdsJOS0LEA/GzxA=="}}}
|
|
@ -0,0 +1,5 @@
|
|||
|
||||
D
|
||||
Plus214_Output_00
|
||||
|
||||
J(肚sDUョト荊ヘEレ'DWQeトY<EFBE84>トQ<18>3vツNKBト<03>ト
|
Двоичный файл не отображается.
|
@ -0,0 +1,5 @@
|
|||
|
||||
D
|
||||
Plus214_Output_00
|
||||
|
||||
J(肚sDUョト荊ヘEレ'DWQeトY<EFBE84>トQ<18>3vツNKBト<03>ト
|
|
@ -95,6 +95,10 @@ Use the individual flags to only run the specified stages.
|
|||
# Build a shared lib
|
||||
parser.add_argument("--build_shared_lib", action='store_true', help="Build a shared library for the ONNXRuntime.")
|
||||
|
||||
# Build ONNX Runtime server
|
||||
parser.add_argument("--build_server", action='store_true', help="Build server application for the ONNXRuntime.")
|
||||
parser.add_argument("--enable_server_tests", action='store_true', help="Run server application tests.")
|
||||
|
||||
# Build options
|
||||
parser.add_argument("--cmake_extra_defines", nargs="+",
|
||||
help="Extra definitions to pass to CMake during build system generation. " +
|
||||
|
@ -324,9 +328,10 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home
|
|||
"-Donnxruntime_TENSORRT_HOME=" + (tensorrt_home if args.use_tensorrt else ""),
|
||||
# By default - we currently support only cross compiling for ARM/ARM64 (no native compilation supported through this script)
|
||||
"-Donnxruntime_CROSS_COMPILING=" + ("ON" if args.arm64 or args.arm else "OFF"),
|
||||
"-Donnxruntime_BUILD_SERVER=" + ("ON" if args.build_server else "OFF"),
|
||||
"-Donnxruntime_BUILD_x86=" + ("ON" if args.x86 else "OFF"),
|
||||
# nGraph and TensorRT providers currently only supports full_protobuf option.
|
||||
"-Donnxruntime_USE_FULL_PROTOBUF=" + ("ON" if args.use_full_protobuf or args.use_ngraph or args.use_tensorrt else "OFF"),
|
||||
"-Donnxruntime_USE_FULL_PROTOBUF=" + ("ON" if args.use_full_protobuf or args.use_ngraph or args.use_tensorrt or args.build_server else "OFF"),
|
||||
"-Donnxruntime_DISABLE_CONTRIB_OPS=" + ("ON" if args.disable_contrib_ops else "OFF"),
|
||||
"-Donnxruntime_MSVC_STATIC_RUNTIME=" + ("ON" if args.enable_msvc_static_runtime else "OFF"),
|
||||
]
|
||||
|
@ -535,6 +540,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs, enab
|
|||
if onnxml_test:
|
||||
run_subprocess([sys.executable, 'onnxruntime_test_python_keras.py'], cwd=cwd, dll_path=dll_path)
|
||||
|
||||
|
||||
def run_onnx_tests(build_dir, configs, onnx_test_data_dir, provider, enable_parallel_executor_test, num_parallel_models):
|
||||
for config in configs:
|
||||
cwd = get_config_build_dir(build_dir, config)
|
||||
|
@ -565,6 +571,18 @@ def run_onnx_tests(build_dir, configs, onnx_test_data_dir, provider, enable_para
|
|||
run_subprocess([exe,'-x'] + cmd, cwd=cwd)
|
||||
|
||||
|
||||
def run_server_tests(build_dir, configs):
|
||||
run_subprocess([sys.executable, '-m', 'pip', 'install', '--trusted-host', 'files.pythonhosted.org', 'requests', 'protobuf', 'numpy'])
|
||||
for config in configs:
|
||||
config_build_dir = get_config_build_dir(build_dir, config)
|
||||
if is_windows():
|
||||
server_app_path = os.path.join(config_build_dir, config, 'onnxruntime_server.exe')
|
||||
else:
|
||||
server_app_path = os.path.join(config_build_dir, 'onnxruntime_server')
|
||||
server_test_folder = os.path.join(config_build_dir, 'server_test')
|
||||
server_test_data_folder = os.path.join(os.path.join(config_build_dir, 'testdata'), 'server')
|
||||
run_subprocess([sys.executable, 'test_main.py', server_app_path, server_test_data_folder, server_test_data_folder], cwd=server_test_folder, dll_path=None)
|
||||
|
||||
def build_python_wheel(source_dir, build_dir, configs, use_cuda, use_ngraph, use_tensorrt, nightly_build = False):
|
||||
for config in configs:
|
||||
cwd = get_config_build_dir(build_dir, config)
|
||||
|
@ -766,6 +784,9 @@ def main():
|
|||
if args.use_mkldnn:
|
||||
run_onnx_tests(build_dir, configs, onnx_test_data_dir, 'mkldnn', True, 1)
|
||||
|
||||
if args.build_server and args.enable_server_tests:
|
||||
run_server_tests(build_dir, configs)
|
||||
|
||||
if args.build:
|
||||
if args.build_wheel:
|
||||
nightly_build = bool(os.getenv('NIGHTLY_BUILD') == '1')
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
jobs:
|
||||
- job: Debug_Build
|
||||
pool: Hosted Ubuntu 1604
|
||||
steps:
|
||||
- template: templates/set-test-data-variables-step.yml
|
||||
- script: 'tools/ci_build/github/linux/server_run_dockerbuild.sh -o ubuntu16.04 -d cpu -r $(Build.BinariesDirectory) -k $(acr.key) -x "--config Debug --build_server --use_openmp --use_full_protobuf --enable_server_tests"'
|
||||
displayName: 'Debug Build'
|
||||
- task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
|
||||
displayName: 'Component Detection'
|
||||
- template: templates/clean-agent-build-directory-step.yml
|
||||
- job: Release_Build
|
||||
pool: Hosted Ubuntu 1604
|
||||
steps:
|
||||
- template: templates/set-test-data-variables-step.yml
|
||||
- script: 'tools/ci_build/github/linux/server_run_dockerbuild.sh -o ubuntu16.04 -d cpu -r $(Build.BinariesDirectory) -k $(acr.key) -x "--config Release --build_server --use_openmp --use_full_protobuf --enable_server_tests"'
|
||||
displayName: 'Release Build'
|
||||
- task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
|
||||
displayName: 'Component Detection'
|
||||
- template: templates/clean-agent-build-directory-step.yml
|
|
@ -0,0 +1,31 @@
|
|||
#!/bin/bash
|
||||
set -e -o -x
|
||||
|
||||
id
|
||||
|
||||
SCRIPT_DIR="$( dirname "${BASH_SOURCE[0]}" )"
|
||||
|
||||
while getopts c:d:x: parameter_Option
|
||||
do case "${parameter_Option}"
|
||||
in
|
||||
d) BUILD_DEVICE=${OPTARG};;
|
||||
x) BUILD_EXTR_PAR=${OPTARG};;
|
||||
esac
|
||||
done
|
||||
|
||||
if [ $BUILD_DEVICE = "gpu" ]; then
|
||||
_CUDNN_VERSION=$(echo $CUDNN_VERSION | cut -d. -f1-2)
|
||||
python3 $SCRIPT_DIR/../../build.py --build_dir /home/onnxruntimedev \
|
||||
--config Debug Release \
|
||||
--skip_submodule_sync --enable_onnx_tests \
|
||||
--parallel --build_shared_lib \
|
||||
--use_cuda --use_openmp \
|
||||
--cuda_home /usr/local/cuda \
|
||||
--cudnn_home /usr/local/cudnn-$_CUDNN_VERSION/cuda --build_shared_lib $BUILD_EXTR_PAR
|
||||
/home/onnxruntimedev/Release/onnx_test_runner -e cuda /data/onnx
|
||||
else
|
||||
python3 $SCRIPT_DIR/../../build.py --build_dir /home/onnxruntimedev \
|
||||
--skip_submodule_sync \
|
||||
--parallel $BUILD_EXTR_PAR
|
||||
# /home/onnxruntimedev/Release/onnx_test_runner /data/onnx
|
||||
fi
|
|
@ -0,0 +1,77 @@
|
|||
#!/bin/bash
|
||||
set -e -o -x
|
||||
|
||||
SCRIPT_DIR="$( dirname "${BASH_SOURCE[0]}" )"
|
||||
SOURCE_ROOT=$(realpath $SCRIPT_DIR/../../../../)
|
||||
CUDA_VER=cuda10.0-cudnn7.3
|
||||
|
||||
while getopts c:o:d:k:r:p:x: parameter_Option
|
||||
do case "${parameter_Option}"
|
||||
in
|
||||
#ubuntu16.04
|
||||
o) BUILD_OS=${OPTARG};;
|
||||
#cpu, gpu
|
||||
d) BUILD_DEVICE=${OPTARG};;
|
||||
k) ACR_KEY=${OPTARG};;
|
||||
r) BUILD_DIR=${OPTARG};;
|
||||
#python version: 3.6 3.7 (absence means default 3.5)
|
||||
p) PYTHON_VER=${OPTARG};;
|
||||
# "--build_wheel --use_openblas"
|
||||
x) BUILD_EXTR_PAR=${OPTARG};;
|
||||
# "cuda10.0-cudnn7.3, cuda9.1-cudnn7.1"
|
||||
c) CUDA_VER=${OPTARG};;
|
||||
esac
|
||||
done
|
||||
|
||||
EXIT_CODE=1
|
||||
|
||||
echo "bo=$BUILD_OS bd=$BUILD_DEVICE bdir=$BUILD_DIR pv=$PYTHON_VER bex=$BUILD_EXTR_PAR"
|
||||
|
||||
cd $SCRIPT_DIR/docker
|
||||
if [ $BUILD_DEVICE = "gpu" ]; then
|
||||
IMAGE="ubuntu16.04-$CUDA_VER"
|
||||
DOCKER_FILE=Dockerfile.ubuntu_gpu
|
||||
if [ $CUDA_VER = "cuda9.1-cudnn7.1" ]; then
|
||||
DOCKER_FILE=Dockerfile.ubuntu_gpu_cuda9
|
||||
fi
|
||||
docker build -t "onnxruntime-$IMAGE" --build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER} -f $DOCKER_FILE .
|
||||
else
|
||||
IMAGE="ubuntu16.04"
|
||||
docker login onnxhostingdev.azurecr.io -u onnxhostingdev -p ${ACR_KEY}
|
||||
docker pull onnxhostingdev.azurecr.io/onnxruntime-ubuntu16.04:latest
|
||||
docker tag onnxhostingdev.azurecr.io/onnxruntime-ubuntu16.04:latest onnxruntime-ubuntu16.04:latest
|
||||
docker images
|
||||
id
|
||||
fi
|
||||
|
||||
set +e
|
||||
|
||||
if [ $BUILD_DEVICE = "cpu" ]; then
|
||||
docker rm -f "onnxruntime-$BUILD_DEVICE" || true
|
||||
docker run -h $HOSTNAME \
|
||||
--rm \
|
||||
--name "onnxruntime-$BUILD_DEVICE" \
|
||||
--volume "$SOURCE_ROOT:/onnxruntime_src" \
|
||||
--volume "$BUILD_DIR:/home/onnxruntimedev" \
|
||||
--volume "$HOME/.cache/onnxruntime:/home/onnxruntimedev/.cache/onnxruntime" \
|
||||
"onnxruntime-$IMAGE" \
|
||||
/bin/bash /onnxruntime_src/tools/ci_build/github/linux/server_run_build.sh \
|
||||
-d $BUILD_DEVICE -x "$BUILD_EXTR_PAR" &
|
||||
else
|
||||
docker rm -f "onnxruntime-$BUILD_DEVICE" || true
|
||||
nvidia-docker run --rm -h $HOSTNAME \
|
||||
--rm \
|
||||
--name "onnxruntime-$BUILD_DEVICE" \
|
||||
--volume "$SOURCE_ROOT:/onnxruntime_src" \
|
||||
--volume "$BUILD_DIR:/home/onnxruntimedev" \
|
||||
--volume "$HOME/.cache/onnxruntime:/home/onnxruntimedev/.cache/onnxruntime" \
|
||||
"onnxruntime-$IMAGE" \
|
||||
/bin/bash /onnxruntime_src/tools/ci_build/github/linux/server_run_build.sh \
|
||||
-d $BUILD_DEVICE -x "$BUILD_EXTR_PAR" &
|
||||
fi
|
||||
wait -n
|
||||
|
||||
EXIT_CODE=$?
|
||||
|
||||
set -e
|
||||
exit $EXIT_CODE
|
Загрузка…
Ссылка в новой задаче