diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index 5b2e9b2efa..4f3dee3c00 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ @@ -483,9 +483,12 @@ public final class OrtUtil { if (type == OnnxJavaType.STRING || type == OnnxJavaType.UNKNOWN) { throw new IllegalStateException("Cannot create a " + type + " tensor from a buffer"); } + // This buffer could be a ByteBuffer which is being used to carry data of another type, if so, + // it's type.size should be 1 to compute the correct buffer size and offset. + int elementSize = data instanceof ByteBuffer ? 1 : type.size; int bufferPos; - long bufferSizeLong = data.remaining() * (long) type.size; - if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) { + long bufferSizeLong = data.remaining() * (long) elementSize; + if (bufferSizeLong > (Integer.MAX_VALUE - (8L * elementSize))) { // The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending // on the JVM, so we check for something 8 elements below the maximum size which // should be allocatable (assuming there is enough memory) on all 64-bit JVMs. @@ -496,11 +499,11 @@ public final class OrtUtil { + type); } // Now we know we're in range - int bufferSize = data.remaining() * type.size; + int bufferSize = data.remaining() * elementSize; Buffer tmp; if (data.isDirect()) { tmp = data; - bufferPos = data.position() * type.size; + bufferPos = data.position() * elementSize; } else { // Copy the data to a new direct buffer, then restore the state of the input. int origPosition = data.position(); diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index c060cf73ec..ea210d96c1 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -218,6 +218,30 @@ public class OnnxTensorTest { } } + @Test + public void testByteBufferCreation() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + ByteBuffer byteBuf = ByteBuffer.allocateDirect(Float.BYTES * 5).order(ByteOrder.nativeOrder()); + FloatBuffer floatBuf = byteBuf.asFloatBuffer(); + floatBuf.put(1.0f); + floatBuf.put(2.0f); + floatBuf.put(3.0f); + floatBuf.put(4.0f); + floatBuf.put(5.0f); + floatBuf.position(1); + float[] expected = new float[floatBuf.remaining()]; + floatBuf.get(expected); + floatBuf.position(1); + byteBuf.position(4); + try (OnnxTensor t = + OnnxTensor.createTensor( + env, byteBuf, new long[] {floatBuf.remaining()}, OnnxJavaType.FLOAT)) { + Assertions.assertNotNull(t); + float[] actual = (float[]) t.getValue(); + Assertions.assertArrayEquals(expected, actual); + } + } + @Test public void testEmptyTensor() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment();