[Java] Add API for appending QNN EP (#22208)
- Add Java API for appending QNN EP - Update Java unit test setup - Fix issues with setting system properties for tests - Unify Windows/non-Windows setup to simplify
This commit is contained in:
Родитель
e2b9ccc44a
Коммит
c24e55b1f1
|
@ -1,23 +1,32 @@
|
|||
# Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# This is a windows only file so we can run gradle tests via ctest
|
||||
# This is a helper script that enables us to run gradle tests via ctest.
|
||||
|
||||
FILE(TO_NATIVE_PATH ${GRADLE_EXECUTABLE} GRADLE_NATIVE_PATH)
|
||||
FILE(TO_NATIVE_PATH ${BIN_DIR} BINDIR_NATIVE_PATH)
|
||||
|
||||
message(STATUS "GRADLE_TEST_EP_FLAGS: ${ORT_PROVIDER_FLAGS}")
|
||||
if (onnxruntime_ENABLE_TRAINING_APIS)
|
||||
message(STATUS "Running ORT Java training tests")
|
||||
execute_process(COMMAND cmd /C ${GRADLE_NATIVE_PATH} --console=plain cmakeCheck -DcmakeBuildDir=${BINDIR_NATIVE_PATH} -Dorg.gradle.daemon=false ${ORT_PROVIDER_FLAGS} -DENABLE_TRAINING_APIS=1
|
||||
WORKING_DIRECTORY ${REPO_ROOT}/java
|
||||
RESULT_VARIABLE HAD_ERROR)
|
||||
else()
|
||||
execute_process(COMMAND cmd /C ${GRADLE_NATIVE_PATH} --console=plain cmakeCheck -DcmakeBuildDir=${BINDIR_NATIVE_PATH} -Dorg.gradle.daemon=false ${ORT_PROVIDER_FLAGS}
|
||||
WORKING_DIRECTORY ${REPO_ROOT}/java
|
||||
RESULT_VARIABLE HAD_ERROR)
|
||||
message(STATUS "gradle additional system property definitions: ${GRADLE_SYSTEM_PROPERTY_DEFINITIONS}")
|
||||
|
||||
set(GRADLE_TEST_ARGS
|
||||
${GRADLE_NATIVE_PATH}
|
||||
test --rerun
|
||||
cmakeCheck
|
||||
--console=plain
|
||||
-DcmakeBuildDir=${BINDIR_NATIVE_PATH}
|
||||
-Dorg.gradle.daemon=false
|
||||
${GRADLE_SYSTEM_PROPERTY_DEFINITIONS})
|
||||
|
||||
if(WIN32)
|
||||
list(PREPEND GRADLE_TEST_ARGS cmd /C)
|
||||
endif()
|
||||
|
||||
message(STATUS "gradle test command args: ${GRADLE_TEST_ARGS}")
|
||||
|
||||
execute_process(COMMAND ${GRADLE_TEST_ARGS}
|
||||
WORKING_DIRECTORY ${REPO_ROOT}/java
|
||||
RESULT_VARIABLE HAD_ERROR)
|
||||
|
||||
if(HAD_ERROR)
|
||||
message(FATAL_ERROR "Java Unitests failed")
|
||||
message(FATAL_ERROR "Java Unitests failed")
|
||||
endif()
|
||||
|
|
|
@ -1590,39 +1590,55 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
|
|||
|
||||
if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
|
||||
if (onnxruntime_BUILD_JAVA AND NOT onnxruntime_ENABLE_STATIC_ANALYSIS)
|
||||
message(STATUS "Running Java tests")
|
||||
block()
|
||||
message(STATUS "Enabling Java tests")
|
||||
|
||||
# native-test is added to resources so custom_op_lib can be loaded
|
||||
# and we want to symlink it there
|
||||
# and we want to copy it there
|
||||
set(JAVA_NATIVE_TEST_DIR ${JAVA_OUTPUT_DIR}/native-test)
|
||||
file(MAKE_DIRECTORY ${JAVA_NATIVE_TEST_DIR})
|
||||
|
||||
# delegate to gradle's test runner
|
||||
if(WIN32)
|
||||
add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:custom_op_library>
|
||||
${JAVA_NATIVE_TEST_DIR}/$<TARGET_FILE_NAME:custom_op_library>)
|
||||
# On windows ctest requires a test to be an .exe(.com) file
|
||||
# With gradle wrapper we get gradlew.bat. We delegate execution to a separate .cmake file
|
||||
# That can handle both .exe and .bat
|
||||
add_test(NAME onnxruntime4j_test COMMAND ${CMAKE_COMMAND}
|
||||
-DGRADLE_EXECUTABLE=${GRADLE_EXECUTABLE}
|
||||
-DBIN_DIR=${CMAKE_CURRENT_BINARY_DIR}
|
||||
-DREPO_ROOT=${REPO_ROOT}
|
||||
${ORT_PROVIDER_FLAGS}
|
||||
-P ${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime_java_unittests.cmake)
|
||||
else()
|
||||
add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:custom_op_library>
|
||||
${JAVA_NATIVE_TEST_DIR}/$<TARGET_LINKER_FILE_NAME:custom_op_library>)
|
||||
if (onnxruntime_ENABLE_TRAINING_APIS)
|
||||
message(STATUS "Running Java inference and training tests")
|
||||
add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR} ${ORT_PROVIDER_FLAGS} -DENABLE_TRAINING_APIS=1
|
||||
WORKING_DIRECTORY ${REPO_ROOT}/java)
|
||||
else()
|
||||
message(STATUS "Running Java inference tests only")
|
||||
add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR} ${ORT_PROVIDER_FLAGS}
|
||||
WORKING_DIRECTORY ${REPO_ROOT}/java)
|
||||
endif()
|
||||
set(CUSTOM_OP_LIBRARY_DST_FILE_NAME
|
||||
$<IF:$<BOOL:${WIN32}>,$<TARGET_FILE_NAME:custom_op_library>,$<TARGET_LINKER_FILE_NAME:custom_op_library>>)
|
||||
|
||||
add_custom_command(TARGET custom_op_library POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
||||
$<TARGET_FILE:custom_op_library>
|
||||
${JAVA_NATIVE_TEST_DIR}/${CUSTOM_OP_LIBRARY_DST_FILE_NAME})
|
||||
|
||||
# also copy other library dependencies that may be required by tests to native-test
|
||||
if(onnxruntime_USE_QNN)
|
||||
add_custom_command(TARGET onnxruntime_providers_qnn POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} ${JAVA_NATIVE_TEST_DIR})
|
||||
endif()
|
||||
|
||||
# delegate to gradle's test runner
|
||||
|
||||
# On Windows, ctest requires a test to be an .exe(.com) file. With gradle wrapper, we get gradlew.bat.
|
||||
# To work around this, we delegate gradle execution to a separate .cmake file that can be run with cmake.
|
||||
# For simplicity, we use this setup for all supported platforms and not just Windows.
|
||||
|
||||
# Note: Here we rely on the values in ORT_PROVIDER_FLAGS to be of the format "-Doption=value".
|
||||
# This happens to also match the gradle command line option for specifying system properties.
|
||||
set(GRADLE_SYSTEM_PROPERTY_DEFINITIONS ${ORT_PROVIDER_FLAGS})
|
||||
|
||||
if(onnxruntime_ENABLE_TRAINING_APIS)
|
||||
message(STATUS "Enabling Java tests for training APIs")
|
||||
|
||||
list(APPEND GRADLE_SYSTEM_PROPERTY_DEFINITIONS "-DENABLE_TRAINING_APIS=1")
|
||||
endif()
|
||||
|
||||
add_test(NAME onnxruntime4j_test COMMAND
|
||||
${CMAKE_COMMAND}
|
||||
-DGRADLE_EXECUTABLE=${GRADLE_EXECUTABLE}
|
||||
-DBIN_DIR=${CMAKE_CURRENT_BINARY_DIR}
|
||||
-DREPO_ROOT=${REPO_ROOT}
|
||||
# Note: Quotes are important here to pass a list of values as a single property.
|
||||
"-DGRADLE_SYSTEM_PROPERTY_DEFINITIONS=${GRADLE_SYSTEM_PROPERTY_DEFINITIONS}"
|
||||
-P ${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime_java_unittests.cmake)
|
||||
|
||||
set_property(TEST onnxruntime4j_test APPEND PROPERTY DEPENDS onnxruntime4j_jni)
|
||||
endblock()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
|
|
@ -199,7 +199,19 @@ test {
|
|||
if (cmakeBuildDir != null) {
|
||||
workingDir cmakeBuildDir
|
||||
}
|
||||
systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'USE_COREML', 'USE_DML', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS'])
|
||||
systemProperties System.getProperties().subMap([
|
||||
'ENABLE_TRAINING_APIS',
|
||||
'JAVA_FULL_TEST',
|
||||
'USE_COREML',
|
||||
'USE_CUDA',
|
||||
'USE_DML',
|
||||
'USE_DNNL',
|
||||
'USE_OPENVINO',
|
||||
'USE_ROCM',
|
||||
'USE_TENSORRT',
|
||||
'USE_QNN',
|
||||
'USE_XNNPACK',
|
||||
])
|
||||
testLogging {
|
||||
events "passed", "skipped", "failed"
|
||||
showStandardStreams = true
|
||||
|
|
|
@ -49,7 +49,7 @@ public enum OrtLoggingLevel {
|
|||
* @return The Java enum.
|
||||
*/
|
||||
public static OrtLoggingLevel mapFromInt(int logLevel) {
|
||||
if ((logLevel > 0) && (logLevel < values.length)) {
|
||||
if ((logLevel >= 0) && (logLevel < values.length)) {
|
||||
return values[logLevel];
|
||||
} else {
|
||||
logger.warning("Unknown logging level " + logLevel + " setting to ORT_LOGGING_LEVEL_VERBOSE");
|
||||
|
|
|
@ -40,7 +40,9 @@ public enum OrtProvider {
|
|||
/** The XNNPACK execution provider. */
|
||||
XNNPACK("XnnpackExecutionProvider"),
|
||||
/** The Azure remote endpoint execution provider. */
|
||||
AZURE("AzureExecutionProvider");
|
||||
AZURE("AzureExecutionProvider"),
|
||||
/** The QNN execution provider. */
|
||||
QNN("QNNExecutionProvider");
|
||||
|
||||
private static final Map<String, OrtProvider> valueMap = new HashMap<>(values().length);
|
||||
|
||||
|
|
|
@ -1271,16 +1271,16 @@ public class OrtSession implements AutoCloseable {
|
|||
}
|
||||
|
||||
/**
|
||||
* Adds Xnnpack as an execution backend. Needs to list all options hereif a new option
|
||||
* supported. current supported options: {} The maximum number of provider options is set to 128
|
||||
* (see addExecutionProvider's comment). This number is controlled by
|
||||
* ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH in ai_onnxruntime_OrtSession_SessionOptions.c. If 128 is
|
||||
* not enough, please increase it or implementing an incremental way to add more options.
|
||||
* Adds the named execution provider (backend) as an execution backend. This generic function
|
||||
* only allows a subset of execution providers.
|
||||
*
|
||||
* @param providerOptions options pass to XNNPACK EP for initialization.
|
||||
* @param providerName The name of the execution provider.
|
||||
* @param providerOptions Configuration options for the execution provider. Refer to the
|
||||
* specific execution provider's documentation.
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addXnnpack(Map<String, String> providerOptions) throws OrtException {
|
||||
private void addExecutionProvider(String providerName, Map<String, String> providerOptions)
|
||||
throws OrtException {
|
||||
checkClosed();
|
||||
String[] providerOptionKey = new String[providerOptions.size()];
|
||||
String[] providerOptionVal = new String[providerOptions.size()];
|
||||
|
@ -1291,7 +1291,35 @@ public class OrtSession implements AutoCloseable {
|
|||
i++;
|
||||
}
|
||||
addExecutionProvider(
|
||||
OnnxRuntime.ortApiHandle, nativeHandle, "XNNPACK", providerOptionKey, providerOptionVal);
|
||||
OnnxRuntime.ortApiHandle,
|
||||
nativeHandle,
|
||||
providerName,
|
||||
providerOptionKey,
|
||||
providerOptionVal);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds XNNPACK as an execution backend.
|
||||
*
|
||||
* @param providerOptions Configuration options for the XNNPACK backend. Refer to the XNNPACK
|
||||
* execution provider's documentation.
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addXnnpack(Map<String, String> providerOptions) throws OrtException {
|
||||
String xnnpackProviderName = "XNNPACK";
|
||||
addExecutionProvider(xnnpackProviderName, providerOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds QNN as an execution backend.
|
||||
*
|
||||
* @param providerOptions Configuration options for the QNN backend. Refer to the QNN execution
|
||||
* provider's documentation.
|
||||
* @throws OrtException If there was an error in native code.
|
||||
*/
|
||||
public void addQnn(Map<String, String> providerOptions) throws OrtException {
|
||||
String qnnProviderName = "QNN";
|
||||
addExecutionProvider(qnnProviderName, providerOptions);
|
||||
}
|
||||
|
||||
private native void setExecutionMode(long apiHandle, long nativeHandle, int mode)
|
||||
|
@ -1416,10 +1444,6 @@ public class OrtSession implements AutoCloseable {
|
|||
private native void addCoreML(long apiHandle, long nativeHandle, int coreMLFlags)
|
||||
throws OrtException;
|
||||
|
||||
/*
|
||||
* The max length of providerOptionKey and providerOptionVal is 128, as specified by
|
||||
* ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH (search ONNXRuntime PR #14067 for its location).
|
||||
*/
|
||||
private native void addExecutionProvider(
|
||||
long apiHandle,
|
||||
long nativeHandle,
|
||||
|
|
|
@ -55,7 +55,9 @@ import java.util.stream.Stream;
|
|||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.DisabledOnOs;
|
||||
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
|
||||
import org.junit.jupiter.api.condition.OS;
|
||||
|
||||
/** Tests for the onnx-runtime Java interface. */
|
||||
public class InferenceTest {
|
||||
|
@ -66,7 +68,7 @@ public class InferenceTest {
|
|||
private static final Pattern inputPBPattern = Pattern.compile("input_*.pb");
|
||||
private static final Pattern outputPBPattern = Pattern.compile("output_*.pb");
|
||||
|
||||
private static final OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
private static final OrtEnvironment env = TestHelpers.getOrtEnvironment();
|
||||
|
||||
@Test
|
||||
public void environmentTest() {
|
||||
|
@ -711,6 +713,14 @@ public class InferenceTest {
|
|||
|
||||
@Test
|
||||
@EnabledIfSystemProperty(named = "USE_DNNL", matches = "1")
|
||||
// TODO see if this can be enabled on Windows.
|
||||
// Error in CI build:
|
||||
// ai.onnxruntime.OrtException: Error code - ORT_RUNTIME_EXCEPTION - message:
|
||||
// D:\a\_work\1\s\onnxruntime\core\session\provider_bridge_ort.cc:1530
|
||||
// onnxruntime::ProviderLibrary::Get [ONNXRuntimeError] : 1 : FAIL : LoadLibrary failed with error
|
||||
// 126 "" when trying to load
|
||||
// "C:\Users\cloudtest\AppData\Local\Temp\onnxruntime-java9085185608411256214\onnxruntime_providers_dnnl.dll"
|
||||
@DisabledOnOs(value = OS.WINDOWS)
|
||||
public void testDNNL() throws OrtException {
|
||||
runProvider(OrtProvider.DNNL);
|
||||
}
|
||||
|
@ -733,6 +743,12 @@ public class InferenceTest {
|
|||
runProvider(OrtProvider.DIRECT_ML);
|
||||
}
|
||||
|
||||
@Test
|
||||
@EnabledIfSystemProperty(named = "USE_QNN", matches = "1")
|
||||
public void testQNN() throws OrtException {
|
||||
runProvider(OrtProvider.QNN);
|
||||
}
|
||||
|
||||
private void runProvider(OrtProvider provider) throws OrtException {
|
||||
EnumSet<OrtProvider> providers = OrtEnvironment.getAvailableProviders();
|
||||
assertTrue(providers.size() > 1);
|
||||
|
@ -2031,6 +2047,14 @@ public class InferenceTest {
|
|||
case XNNPACK:
|
||||
options.addXnnpack(Collections.emptyMap());
|
||||
break;
|
||||
case QNN:
|
||||
{
|
||||
String backendPath = OS.WINDOWS.isCurrentOs() ? "/QnnCpu.dll" : "/libQnnCpu.so";
|
||||
options.addQnn(
|
||||
Collections.singletonMap(
|
||||
"backend_path", TestHelpers.getResourcePath(backendPath).toString()));
|
||||
break;
|
||||
}
|
||||
case VITIS_AI:
|
||||
case RK_NPU:
|
||||
case MI_GRAPH_X:
|
||||
|
|
|
@ -22,10 +22,10 @@ import org.junit.jupiter.api.Assertions;
|
|||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class OnnxTensorTest {
|
||||
private static final OrtEnvironment env = TestHelpers.getOrtEnvironment();
|
||||
|
||||
@Test
|
||||
public void testScalarCreation() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
String[] stringValues = new String[] {"true", "false"};
|
||||
for (String s : stringValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, s)) {
|
||||
|
@ -97,8 +97,6 @@ public class OnnxTensorTest {
|
|||
|
||||
@Test
|
||||
public void testArrayCreation() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
|
||||
// Test creating a value from a single dimensional array
|
||||
float[] arrValues = new float[] {0, 1, 2, 3, 4};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
|
||||
|
@ -192,8 +190,6 @@ public class OnnxTensorTest {
|
|||
|
||||
@Test
|
||||
public void testBufferCreation() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
|
||||
// Test creating a value from a non-direct byte buffer
|
||||
// Non-direct byte buffers are allocated on the Java heap and must be copied into off-heap
|
||||
// direct byte buffers which can be directly passed to ORT
|
||||
|
@ -260,7 +256,6 @@ public class OnnxTensorTest {
|
|||
|
||||
@Test
|
||||
public void testStringCreation() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
String[] arrValues = new String[] {"this", "is", "a", "single", "dimensional", "string"};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
|
||||
Assertions.assertArrayEquals(new long[] {6}, t.getInfo().shape);
|
||||
|
@ -290,7 +285,6 @@ public class OnnxTensorTest {
|
|||
|
||||
@Test
|
||||
public void testUint8Creation() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
byte[] buf = new byte[] {0, 1};
|
||||
ByteBuffer data = ByteBuffer.wrap(buf);
|
||||
long[] shape = new long[] {2};
|
||||
|
@ -301,7 +295,6 @@ public class OnnxTensorTest {
|
|||
|
||||
@Test
|
||||
public void testByteBufferCreation() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
ByteBuffer byteBuf = ByteBuffer.allocateDirect(Float.BYTES * 5).order(ByteOrder.nativeOrder());
|
||||
FloatBuffer floatBuf = byteBuf.asFloatBuffer();
|
||||
floatBuf.put(1.0f);
|
||||
|
@ -325,7 +318,6 @@ public class OnnxTensorTest {
|
|||
|
||||
@Test
|
||||
public void testEmptyTensor() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
FloatBuffer buf = FloatBuffer.allocate(0);
|
||||
long[] shape = new long[] {4, 0};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) {
|
||||
|
@ -346,7 +338,6 @@ public class OnnxTensorTest {
|
|||
|
||||
@Test
|
||||
public void testBf16ToFp32() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
String modelPath = TestHelpers.getResourcePath("/java-bf16-to-fp32.onnx").toString();
|
||||
SplittableRandom rng = new SplittableRandom(1);
|
||||
|
||||
|
@ -379,7 +370,6 @@ public class OnnxTensorTest {
|
|||
|
||||
@Test
|
||||
public void testFp16ToFp32() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
String modelPath = TestHelpers.getResourcePath("/java-fp16-to-fp32.onnx").toString();
|
||||
SplittableRandom rng = new SplittableRandom(1);
|
||||
|
||||
|
@ -412,7 +402,6 @@ public class OnnxTensorTest {
|
|||
|
||||
@Test
|
||||
public void testFp32ToFp16() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
String modelPath = TestHelpers.getResourcePath("/java-fp32-to-fp16.onnx").toString();
|
||||
SplittableRandom rng = new SplittableRandom(1);
|
||||
|
||||
|
@ -469,7 +458,6 @@ public class OnnxTensorTest {
|
|||
|
||||
@Test
|
||||
public void testFp32ToBf16() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
String modelPath = TestHelpers.getResourcePath("/java-fp32-to-bf16.onnx").toString();
|
||||
SplittableRandom rng = new SplittableRandom(1);
|
||||
|
||||
|
@ -560,7 +548,6 @@ public class OnnxTensorTest {
|
|||
|
||||
@Test
|
||||
public void testClose() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
long[] input = new long[] {1, 2, 3, 4, 5};
|
||||
OnnxTensor value = OnnxTensor.createTensor(env, input);
|
||||
assertFalse(value.isClosed());
|
||||
|
|
|
@ -21,13 +21,13 @@ import java.util.Map;
|
|||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class SparseTensorTest {
|
||||
private static final OrtEnvironment env = TestHelpers.getOrtEnvironment();
|
||||
|
||||
@Test
|
||||
public void testCSRC() throws OrtException {
|
||||
String modelPath =
|
||||
TestHelpers.getResourcePath("/generic_sparse_to_dense_matmul.onnx").toString();
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
|
||||
try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
Map<String, OnnxTensorLike> inputMap = new HashMap<>();
|
||||
|
||||
|
@ -207,8 +207,7 @@ public class SparseTensorTest {
|
|||
public void testCOO() throws OrtException {
|
||||
String modelPath =
|
||||
TestHelpers.getResourcePath("/generic_sparse_to_dense_matmul.onnx").toString();
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
|
||||
try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
Map<String, OnnxTensorLike> inputMap = new HashMap<>();
|
||||
|
||||
|
@ -393,8 +392,7 @@ public class SparseTensorTest {
|
|||
@Test
|
||||
public void testCOOOutput() throws OrtException {
|
||||
String modelPath = TestHelpers.getResourcePath("/sparse_initializer_as_output.onnx").toString();
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
|
||||
try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
Map<String, NodeInfo> outputs = session.getOutputInfo();
|
||||
assertEquals(1, outputs.size());
|
||||
|
|
|
@ -27,7 +27,7 @@ import java.util.logging.Logger;
|
|||
import java.util.regex.Pattern;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
|
||||
/** Test helpers for manipulating primitive arrays. */
|
||||
/** Test helpers. */
|
||||
public class TestHelpers {
|
||||
|
||||
private static final Pattern LOAD_PATTERN = Pattern.compile("[,\\[\\] ]");
|
||||
|
@ -469,4 +469,30 @@ public class TestHelpers {
|
|||
this.tensor = tensor;
|
||||
}
|
||||
}
|
||||
|
||||
// Gets an OrtEnvironment instance to use in tests.
|
||||
// Reads a numeric log level from the ORT_JAVA_TEST_LOG_LEVEL environment variable and passes that
|
||||
// to OrtEnvironment.getEnvironment().
|
||||
public static OrtEnvironment getOrtEnvironment() {
|
||||
String logLevelEnvironmentVariableName = "ORT_JAVA_TEST_LOG_LEVEL";
|
||||
String logLevelString = System.getenv(logLevelEnvironmentVariableName);
|
||||
|
||||
if (logLevelString == null || logLevelString.isEmpty()) {
|
||||
return OrtEnvironment.getEnvironment();
|
||||
}
|
||||
|
||||
OrtLoggingLevel logLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING;
|
||||
|
||||
try {
|
||||
int logLevelInt = Integer.parseInt(logLevelString);
|
||||
logLevel = OrtLoggingLevel.mapFromInt(logLevelInt);
|
||||
} catch (NumberFormatException e) {
|
||||
System.err.println(
|
||||
String.format(
|
||||
"Failed to parse environment variable %s value ('%s') as an integer. It will be ignored.",
|
||||
logLevelEnvironmentVariableName, logLevelString));
|
||||
}
|
||||
|
||||
return OrtEnvironment.getEnvironment(logLevel);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
|
|||
@EnabledIfSystemProperty(named = "ENABLE_TRAINING_APIS", matches = "1")
|
||||
public class TrainingTest {
|
||||
|
||||
private static final OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
private static final OrtEnvironment env = TestHelpers.getOrtEnvironment();
|
||||
|
||||
@Test
|
||||
public void testLoadCheckpoint() throws OrtException {
|
||||
|
|
|
@ -30,7 +30,7 @@ import org.junit.jupiter.api.Test;
|
|||
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
|
||||
|
||||
public class ProviderOptionsTest {
|
||||
private static final OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
private static final OrtEnvironment env = TestHelpers.getOrtEnvironment();
|
||||
|
||||
@Test
|
||||
@EnabledIfSystemProperty(named = "USE_CUDA", matches = "1")
|
||||
|
|
|
@ -63,22 +63,24 @@ jobs:
|
|||
python3 tools/ci_build/build.py \
|
||||
--build_dir build \
|
||||
--config Release \
|
||||
--parallel --use_binskim_compliant_compile_flags \
|
||||
--use_binskim_compliant_compile_flags \
|
||||
--build_java \
|
||||
--use_qnn \
|
||||
--qnn_home $(QnnSDKRootDir) \
|
||||
--cmake_generator=Ninja \
|
||||
--skip_tests
|
||||
--update --build --parallel
|
||||
displayName: Build QNN EP
|
||||
|
||||
- script: |
|
||||
python3 tools/ci_build/build.py \
|
||||
--build_dir build \
|
||||
--config Release --use_binskim_compliant_compile_flags \
|
||||
--test \
|
||||
--config Release \
|
||||
--use_binskim_compliant_compile_flags \
|
||||
--build_java \
|
||||
--use_qnn \
|
||||
--qnn_home $(QnnSDKRootDir) \
|
||||
--cmake_generator=Ninja \
|
||||
--skip_submodule_sync \
|
||||
--ctest_path ""
|
||||
--test
|
||||
displayName: Run unit tests
|
||||
|
||||
- task: CmdLine@2
|
||||
|
|
|
@ -75,12 +75,22 @@ jobs:
|
|||
displayName: 'Build'
|
||||
inputs:
|
||||
scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py'
|
||||
arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_tests --cmake_generator "Visual Studio 17 2022" --use_qnn --qnn_home $(QnnSDKRootDir) --parallel'
|
||||
workingDirectory: '$(Build.BinariesDirectory)'
|
||||
arguments: >-
|
||||
--config $(BuildConfig)
|
||||
--build_dir $(Build.BinariesDirectory)
|
||||
--cmake_generator "Visual Studio 17 2022"
|
||||
--use_qnn
|
||||
--qnn_home $(QnnSDKRootDir)
|
||||
--update --build --parallel
|
||||
|
||||
- powershell: |
|
||||
python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --test --cmake_generator "Visual Studio 17 2022" --enable_onnx_tests
|
||||
workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)'
|
||||
- script: |
|
||||
python $(Build.SourcesDirectory)\tools\ci_build\build.py ^
|
||||
--config $(BuildConfig) ^
|
||||
--build_dir $(Build.BinariesDirectory) ^
|
||||
--cmake_generator "Visual Studio 17 2022" ^
|
||||
--use_qnn ^
|
||||
--qnn_home $(QnnSDKRootDir) ^
|
||||
--test --enable_onnx_tests
|
||||
displayName: 'Run unit tests'
|
||||
|
||||
- script: |
|
||||
|
|
|
@ -62,43 +62,43 @@ jobs:
|
|||
parameters:
|
||||
QnnSDKVersion: ${{ parameters.QnnSdk }}
|
||||
|
||||
# TODO: Remove --compile_no_warning_as_error once we update from MSVC Runtime library version 14.32 or we
|
||||
# fix/silence the following warning from the <variant> STL header. This warning halts compilation due to
|
||||
# the /external:templates- option, which allows warnings from external libs for template instantiations.
|
||||
# Warning is not reported on version 14.35.32215 of the Runtime library.
|
||||
#
|
||||
# [warning]C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Tools\MSVC\14.32.31326\include\variant(1586,9)
|
||||
# : Warning C4189: '_Size': local variable is initialized but not referenced
|
||||
# (compiling source file D:\a\_work\1\s\onnxruntime\core\session\IOBinding.cc)
|
||||
#
|
||||
# MSVC\14.32.31326\include\variant(1633): message : see reference to function template instantiation
|
||||
# '_Ret &std::_Visit_strategy<1>::_Visit2<_Ret,_ListOfIndexVectors,_Callable,const
|
||||
# std::variant<onnxruntime::OpSchemaKernelTypeStrResolver,onnxruntime::KernelTypeStrResolver>&>(size_t,_Callable &&,
|
||||
# const std::variant<onnxruntime::OpSchemaKernelTypeStrResolver,onnxruntime::KernelTypeStrResolver> &)'
|
||||
- template: templates/jobs/win-ci-build-steps.yml
|
||||
parameters:
|
||||
WithCache: True
|
||||
Today: $(TODAY)
|
||||
AdditionalKey: "win-qnn | $(BuildConfig)"
|
||||
BuildPyArguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --compile_no_warning_as_error --update --cmake_generator "Visual Studio 17 2022" --use_qnn --qnn_home $(QnnSDKRootDir) --parallel --use_binskim_compliant_compile_flags'
|
||||
BuildPyArguments: >-
|
||||
--config $(BuildConfig)
|
||||
--build_dir $(Build.BinariesDirectory)
|
||||
--cmake_generator "Visual Studio 17 2022"
|
||||
--build_java
|
||||
--use_qnn
|
||||
--qnn_home $(QnnSDKRootDir)
|
||||
--use_binskim_compliant_compile_flags
|
||||
--update --parallel
|
||||
MsbuildArguments: $(MsbuildArguments)
|
||||
BuildArch: $(buildArch)
|
||||
Platform: 'x64'
|
||||
BuildConfig: $(BuildConfig)
|
||||
|
||||
- powershell: |
|
||||
python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --test --cmake_generator "Visual Studio 17 2022" --enable_onnx_tests
|
||||
workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)'
|
||||
- script: |
|
||||
python $(Build.SourcesDirectory)\tools\ci_build\build.py ^
|
||||
--config $(BuildConfig) ^
|
||||
--build_dir $(Build.BinariesDirectory) ^
|
||||
--cmake_generator "Visual Studio 17 2022" ^
|
||||
--build_java ^
|
||||
--use_qnn ^
|
||||
--qnn_home $(QnnSDKRootDir) ^
|
||||
--use_binskim_compliant_compile_flags ^
|
||||
--test --enable_onnx_tests
|
||||
displayName: 'Run unit tests'
|
||||
|
||||
# Comment out QnnCpu tests because QNN SDK 2.22 CPU backend crashes when executing MatMuls.
|
||||
# Does not happen with HTP backend.
|
||||
# - script: |
|
||||
# .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QnnSDKRootDir)\lib\x86_64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node
|
||||
# workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
|
||||
# displayName: 'Run ONNX Tests'
|
||||
#
|
||||
# - script: |
|
||||
# .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QnnSDKRootDir)\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models
|
||||
# workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
|
||||
# displayName: 'Run float32 model tests'
|
||||
- script: |
|
||||
.\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QnnSDKRootDir)\lib\x86_64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node
|
||||
workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
|
||||
displayName: 'Run ONNX Tests'
|
||||
|
||||
- script: |
|
||||
.\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QnnSDKRootDir)\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models
|
||||
workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
|
||||
displayName: 'Run float32 model tests'
|
||||
|
|
Загрузка…
Ссылка в новой задаче