[java] Fix for OnnxTensor creation when passing in a ByteBuffer containing elements of a different type (#21774)

### Description
Fixes a bug where the buffer size and offset were incorrectly
computed when the user supplied a `ByteBuffer` to `createTensor` but set
the type of the tensor to something other than `INT8`. This is most
likely to happen when the user loads initializers from a serialized
representation and doesn't want to bother with the type information
(which is the case in #21321).
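
To make the arithmetic concrete, here is a minimal sketch that uses only `java.nio`; it is not the ORT implementation. The class `ByteBufferOffsetSketch` and its helper methods are hypothetical names, and `typeSize` stands in for the per-element byte size of the tensor type (4 for a float). A `ByteBuffer` already counts its position and remaining length in bytes, so multiplying by the element size a second time overshoots.

```java
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

// Hypothetical illustration class, not part of the ORT Java API.
final class ByteBufferOffsetSketch {
  // A ByteBuffer measures position/remaining in bytes; typed buffers measure in elements.
  static int elementSize(Buffer data, int typeSize) {
    return data instanceof ByteBuffer ? 1 : typeSize;
  }

  // Byte offset of the first element to read, following the idea of the fix.
  static int byteOffset(Buffer data, int typeSize) {
    return data.position() * elementSize(data, typeSize);
  }

  // Number of bytes left to read, following the idea of the fix.
  static int byteLength(Buffer data, int typeSize) {
    return data.remaining() * elementSize(data, typeSize);
  }

  // The pre-fix arithmetic: always scale by the tensor type's size.
  static int naiveByteOffset(Buffer data, int typeSize) {
    return data.position() * typeSize;
  }

  public static void main(String[] args) {
    ByteBuffer byteBuf = ByteBuffer.allocateDirect(Float.BYTES * 5).order(ByteOrder.nativeOrder());
    FloatBuffer floatBuf = byteBuf.asFloatBuffer();
    // Skip the first float in both views of the same memory.
    byteBuf.position(Float.BYTES);
    floatBuf.position(1);

    System.out.println(naiveByteOffset(byteBuf, Float.BYTES)); // 16, points at the fifth float
    System.out.println(byteOffset(byteBuf, Float.BYTES)); // 4
    System.out.println(byteLength(byteBuf, Float.BYTES)); // 16
    System.out.println(byteOffset(floatBuf, Float.BYTES)); // 4
    System.out.println(byteLength(floatBuf, Float.BYTES)); // 16
  }
}
```

With the corrected arithmetic both views of the same memory agree on a 4-byte offset and a 16-byte length, which is what the new unit test in this commit checks through `OnnxTensor.createTensor`.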

### Motivation and Context
Partial fix for #21321. The remainder of the fix is to add a helper
which allows users to load initializers out of an `onnx_data` file, but
that will require adding protobuf as a dependency for the Java API to
allow the parsing of an ONNX file separately from the native code. It
might be nicer to put that functionality into ORT's C API so it can
return the lengths & offsets of the initializers when provided with an
ONNX file containing external initializers. We hit this kind of problem
more often in Java than in other languages because Java models can be
supplied as classpath resources, which we can easily read but cannot
materialize on disk for the ORT native library to read.
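
As a rough sketch of that use case, the following reads a stream into a direct `ByteBuffer` and carves out a byte window for one initializer. Everything here is an assumption for illustration: `ResourceBufferSketch`, `readToDirectBuffer`, and `window` are hypothetical names, the offset and length would have to come from the model's external-data metadata, native byte order is assumed, and `InputStream.readAllBytes` requires Java 9 or later. No such helper exists in the ORT Java API in this commit.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical illustration class, not part of the ORT Java API.
final class ResourceBufferSketch {
  // Reads the whole stream into a direct buffer positioned at 0.
  static ByteBuffer readToDirectBuffer(InputStream in) {
    try {
      byte[] bytes = in.readAllBytes();
      ByteBuffer buf = ByteBuffer.allocateDirect(bytes.length).order(ByteOrder.nativeOrder());
      buf.put(bytes);
      buf.flip();
      return buf;
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  // Returns a view over [offset, offset + length) without copying or moving the source buffer.
  static ByteBuffer window(ByteBuffer data, int offset, int length) {
    // duplicate() resets the byte order to big-endian, so copy it from the source.
    ByteBuffer view = data.duplicate().order(data.order());
    view.limit(offset + length);
    view.position(offset);
    return view;
  }

  public static void main(String[] args) {
    // Stand-in for a classpath stream such as one from Class.getResourceAsStream.
    byte[] raw = new byte[Float.BYTES * 5];
    ByteBuffer.wrap(raw).order(ByteOrder.nativeOrder()).asFloatBuffer()
        .put(new float[] {1f, 2f, 3f, 4f, 5f});

    ByteBuffer all = readToDirectBuffer(new ByteArrayInputStream(raw));
    ByteBuffer lastFour = window(all, Float.BYTES, Float.BYTES * 4);
    System.out.println(lastFour.position() + " " + lastFour.remaining()); // 4 16
    System.out.println(lastFour.asFloatBuffer().get(0)); // 2.0
  }
}
```

A window like `lastFour` has a non-zero position and carries non-byte data, which is exactly the kind of `ByteBuffer` whose size and offset this commit now computes correctly.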
Adam Pocock 2024-09-12 22:38:17 -04:00 committed by GitHub
Parent f7bf5a19ba
Commit 22437b581b
No known key found for this signature
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 6 deletions


@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved.
  * Copyright (c) Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License.
  */
@@ -483,9 +483,12 @@ public final class OrtUtil {
     if (type == OnnxJavaType.STRING || type == OnnxJavaType.UNKNOWN) {
       throw new IllegalStateException("Cannot create a " + type + " tensor from a buffer");
     }
+    // This buffer could be a ByteBuffer which is being used to carry data of another type, if so,
+    // it's type.size should be 1 to compute the correct buffer size and offset.
+    int elementSize = data instanceof ByteBuffer ? 1 : type.size;
     int bufferPos;
-    long bufferSizeLong = data.remaining() * (long) type.size;
-    if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) {
+    long bufferSizeLong = data.remaining() * (long) elementSize;
+    if (bufferSizeLong > (Integer.MAX_VALUE - (8L * elementSize))) {
       // The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending
       // on the JVM, so we check for something 8 elements below the maximum size which
       // should be allocatable (assuming there is enough memory) on all 64-bit JVMs.
@@ -496,11 +499,11 @@ public final class OrtUtil {
               + type);
     }
     // Now we know we're in range
-    int bufferSize = data.remaining() * type.size;
+    int bufferSize = data.remaining() * elementSize;
     Buffer tmp;
     if (data.isDirect()) {
       tmp = data;
-      bufferPos = data.position() * type.size;
+      bufferPos = data.position() * elementSize;
     } else {
       // Copy the data to a new direct buffer, then restore the state of the input.
       int origPosition = data.position();


@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved.
  * Licensed under the MIT License.
  */
 package ai.onnxruntime;
@@ -218,6 +218,30 @@ public class OnnxTensorTest {
     }
   }

+  @Test
+  public void testByteBufferCreation() throws OrtException {
+    OrtEnvironment env = OrtEnvironment.getEnvironment();
+    ByteBuffer byteBuf = ByteBuffer.allocateDirect(Float.BYTES * 5).order(ByteOrder.nativeOrder());
+    FloatBuffer floatBuf = byteBuf.asFloatBuffer();
+    floatBuf.put(1.0f);
+    floatBuf.put(2.0f);
+    floatBuf.put(3.0f);
+    floatBuf.put(4.0f);
+    floatBuf.put(5.0f);
+    floatBuf.position(1);
+    float[] expected = new float[floatBuf.remaining()];
+    floatBuf.get(expected);
+    floatBuf.position(1);
+    byteBuf.position(4);
+    try (OnnxTensor t =
+        OnnxTensor.createTensor(
+            env, byteBuf, new long[] {floatBuf.remaining()}, OnnxJavaType.FLOAT)) {
+      Assertions.assertNotNull(t);
+      float[] actual = (float[]) t.getValue();
+      Assertions.assertArrayEquals(expected, actual);
+    }
+  }
+
   @Test
   public void testEmptyTensor() throws OrtException {
     OrtEnvironment env = OrtEnvironment.getEnvironment();