[java] Fix for OnnxTensor creation when passing in a ByteBuffer containing elements of a different type (#21774)
### Description Fixes a bug where the buffer offset and position was incorrectly computed if the user supplied a `ByteBuffer` to `createTensor` but set the type of the tensor to something other than `INT8`. This would be more common if the user was trying to load the initializers from a serialized representation and didn't want to bother with the type information (which is the case in #21321). ### Motivation and Context Partial fix for #21321. The remainder of the fix is to add a helper which allows users to load initializers out of an `onnx_data` file, but that will require adding protobuf as a dependency for the Java API to allow the parsing of an ONNX file separately from the native code. It might be nicer to put that functionality into ORT's C API so it can return the lengths & offsets of the initializers when provided with an ONNX file containing external initializers. We hit this kind of thing in Java more often than other languages as in Java models can be supplied as classpath resources which we can easily read, but not materialize on disk for the ORT native library to read.
This commit is contained in:
Родитель
f7bf5a19ba
Коммит
22437b581b
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
|
@ -483,9 +483,12 @@ public final class OrtUtil {
|
|||
if (type == OnnxJavaType.STRING || type == OnnxJavaType.UNKNOWN) {
|
||||
throw new IllegalStateException("Cannot create a " + type + " tensor from a buffer");
|
||||
}
|
||||
// This buffer could be a ByteBuffer which is being used to carry data of another type, if so,
|
||||
// it's type.size should be 1 to compute the correct buffer size and offset.
|
||||
int elementSize = data instanceof ByteBuffer ? 1 : type.size;
|
||||
int bufferPos;
|
||||
long bufferSizeLong = data.remaining() * (long) type.size;
|
||||
if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) {
|
||||
long bufferSizeLong = data.remaining() * (long) elementSize;
|
||||
if (bufferSizeLong > (Integer.MAX_VALUE - (8L * elementSize))) {
|
||||
// The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending
|
||||
// on the JVM, so we check for something 8 elements below the maximum size which
|
||||
// should be allocatable (assuming there is enough memory) on all 64-bit JVMs.
|
||||
|
@ -496,11 +499,11 @@ public final class OrtUtil {
|
|||
+ type);
|
||||
}
|
||||
// Now we know we're in range
|
||||
int bufferSize = data.remaining() * type.size;
|
||||
int bufferSize = data.remaining() * elementSize;
|
||||
Buffer tmp;
|
||||
if (data.isDirect()) {
|
||||
tmp = data;
|
||||
bufferPos = data.position() * type.size;
|
||||
bufferPos = data.position() * elementSize;
|
||||
} else {
|
||||
// Copy the data to a new direct buffer, then restore the state of the input.
|
||||
int origPosition = data.position();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
@ -218,6 +218,30 @@ public class OnnxTensorTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testByteBufferCreation() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
ByteBuffer byteBuf = ByteBuffer.allocateDirect(Float.BYTES * 5).order(ByteOrder.nativeOrder());
|
||||
FloatBuffer floatBuf = byteBuf.asFloatBuffer();
|
||||
floatBuf.put(1.0f);
|
||||
floatBuf.put(2.0f);
|
||||
floatBuf.put(3.0f);
|
||||
floatBuf.put(4.0f);
|
||||
floatBuf.put(5.0f);
|
||||
floatBuf.position(1);
|
||||
float[] expected = new float[floatBuf.remaining()];
|
||||
floatBuf.get(expected);
|
||||
floatBuf.position(1);
|
||||
byteBuf.position(4);
|
||||
try (OnnxTensor t =
|
||||
OnnxTensor.createTensor(
|
||||
env, byteBuf, new long[] {floatBuf.remaining()}, OnnxJavaType.FLOAT)) {
|
||||
Assertions.assertNotNull(t);
|
||||
float[] actual = (float[]) t.getValue();
|
||||
Assertions.assertArrayEquals(expected, actual);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyTensor() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
|
|
Загрузка…
Ссылка в новой задаче