[java] Make the backing byte buffer in an OrtValue accessible (#16578)

### Description
Adds a method to access the backing direct byte buffer from a Java
`OnnxTensor` object, assuming it is backed by a direct byte buffer
(tensors created by ORT's run call or ones created in Java from
multidimensional arrays are not). Also adds a method to check if the
backing byte buffer was copied from the user's buffer supplied on
creation (this could be tested via a pointer comparison between the output
of `getBufferRef` and the user's input buffer, so I'm not sure if it's
necessary).

### Motivation and Context
This is the first part of changes necessary to support output pinning in
Java OrtSession.run/OrtTrainingSession.run calls. I split it out from
the rest of the work as it's useful by itself (e.g. to allow users to
keep a single input tensor and rewrite it each time with new inputs
rather than allocate a fresh one) and the other change will be much more
involved so splitting it makes it easier to review.

cc @yuslepukhin
This commit is contained in:
Adam Pocock 2023-10-17 13:03:49 -04:00 коммит произвёл GitHub
Родитель 57c8736596
Коммит 3456831413
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 141 добавлений и 5 удалений

Просмотреть файл

@ -13,6 +13,7 @@ import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Optional;
/**
* A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be
@ -21,18 +22,60 @@ import java.nio.ShortBuffer;
public class OnnxTensor extends OnnxTensorLike {
/**
* This reference is held for OnnxTensors backed by a Java nio buffer to ensure the buffer does
* This reference is held for OnnxTensors backed by a java.nio.Buffer to ensure the buffer does
* not go out of scope while the OnnxTensor exists.
*/
private final Buffer buffer;
/**
* Denotes if the OnnxTensor made a copy of the buffer on construction (i.e. it may have the only
* reference).
*/
private final boolean ownsBuffer;
/**
* Creates a tensor that has no backing Java buffer, by delegating to the buffer-accepting
* constructor with a {@code null} buffer and {@code ownsBuffer} set to {@code false}.
*
* <p>NOTE(review): the two delegation lines below look like the before and after versions of the
* same call as rendered by the diff view; only the five-argument form supplies an ownership flag.
*
* @param nativeHandle Passed through unchanged to the delegated constructor.
* @param allocatorHandle Passed through unchanged to the delegated constructor.
* @param info Passed through unchanged to the delegated constructor.
*/
OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info) {
this(nativeHandle, allocatorHandle, info, null);
this(nativeHandle, allocatorHandle, info, null, false);
}
/**
* Creates a tensor, forwarding the native handle, allocator handle and tensor info to the
* superclass constructor and recording the backing buffer together with its ownership flag.
*
* <p>NOTE(review): the first signature line below looks like the pre-change four-argument
* signature as rendered by the diff view; the body assigns {@code ownsBuffer}, which only the
* five-argument signature declares.
*
* @param nativeHandle Forwarded to the superclass constructor.
* @param allocatorHandle Forwarded to the superclass constructor.
* @param info Forwarded to the superclass constructor.
* @param buffer The buffer this tensor keeps a reference to; may be {@code null} (the
*     three-argument constructor passes {@code null}).
* @param ownsBuffer Stored as-is and reported by {@link #ownsBuffer()}.
*/
OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer) {
OnnxTensor(
long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer, boolean ownsBuffer) {
super(nativeHandle, allocatorHandle, info);
this.buffer = buffer;
this.ownsBuffer = ownsBuffer;
}
/**
 * Reports whether this tensor created its backing buffer itself during construction, i.e., the
 * buffer is a copy of a user supplied buffer or array and this tensor may hold the only reference
 * to it.
 *
 * <p>If this returns {@code true} the backing buffer was copied from the user input, so users
 * must first obtain the reference via {@link #getBufferRef()} before they can mutate the state of
 * that buffer.
 *
 * @return {@code true} if the buffer in this OnnxTensor was allocated by it on construction
 *     (i.e., it is a copy of a user buffer), {@code false} otherwise.
 */
public boolean ownsBuffer() {
  return ownsBuffer;
}
/**
 * Gives access to the buffer backing this {@code OnnxTensor}, if there is one. Tensors which are
 * not backed by a buffer (i.e., those created from a Java array, or backed by memory allocated by
 * ORT) yield an empty {@link Optional}.
 *
 * <p>Writes to the buffer's elements are reflected in the native {@code OrtValue}. This makes it
 * possible to reuse a single tensor across several inferences by overwriting its contents rather
 * than allocating new tensors, though the inputs <b>must</b> keep the same size and shape.
 *
 * <p>Note: the tensor may refer to a contiguous range of elements in this buffer rather than the
 * whole buffer. Callers are responsible for managing that by respecting the buffer's position and
 * limit. As a consequence, accessing this reference should be considered problematic when
 * multiple threads hold references to the buffer.
 *
 * @return An {@link Optional} containing the backing buffer, or an empty {@link Optional} if this
 *     tensor is not backed by a buffer.
 */
public Optional<Buffer> getBufferRef() {
  if (buffer == null) {
    return Optional.empty();
  }
  return Optional.of(buffer);
}
@Override
@ -45,7 +88,8 @@ public class OnnxTensor extends OnnxTensorLike {
* primitives if it has multiple dimensions.
*
* <p>Java multidimensional arrays are quite slow for more than 2 dimensions, in that case it is
* recommended you use the java.nio.Buffer extractors below (e.g. {@link #getFloatBuffer}).
* recommended you use the {@link java.nio.Buffer} extractors below (e.g., {@link
* #getFloatBuffer}).
*
* @return A Java value.
* @throws OrtException If the value could not be extracted as the Tensor is invalid, or if the
@ -283,6 +327,12 @@ public class OnnxTensor extends OnnxTensorLike {
* multidimensional array. The shape is inferred from the object using reflection. The default
* allocator is used.
*
* <p>Note: Java multidimensional arrays are not dense and this method requires traversing a large
* number of pointers for high dimensional arrays. For types other than Strings it is recommended
* to use one of the {@code createTensor} methods which accept a {@link java.nio.Buffer}, e.g.
* {@link #createTensor(OrtEnvironment, FloatBuffer, long[])}, as those methods transfer data into
* ORT without copying when given direct buffers.
*
* @param env The current OrtEnvironment.
* @param data The data to store in a tensor.
* @return An OnnxTensor storing the data.
@ -700,7 +750,8 @@ public class OnnxTensor extends OnnxTensorLike {
info.onnxType.value),
allocator.handle,
info,
tuple.data);
tuple.data,
tuple.isCopy);
}
private static native long createTensor(

Просмотреть файл

@ -88,6 +88,91 @@ public class OnnxTensorTest {
}
}
/**
 * Checks the buffer ownership and buffer access behaviour of {@code OnnxTensor} for the three
 * creation paths: from a Java array, from a non-direct buffer and from a direct buffer.
 *
 * <p>For each path it verifies {@code ownsBuffer()} and {@code getBufferRef()}, that writes to the
 * buffer returned by {@code getFloatBuffer()} do not reach the tensor, and (for buffer backed
 * tensors) that writes through {@code getBufferRef()} do.
 *
 * @throws OrtException If tensor creation fails.
 */
@Test
public void testBufferCreation() throws OrtException {
  OrtEnvironment env = OrtEnvironment.getEnvironment();

  // Test creating a value from an array
  // Arrays result in tensors allocated by ORT, so they do not have a backing java.nio.Buffer
  float[] arrValues = new float[] {0, 1, 2, 3, 4};
  try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
    // array creation isn't backed by buffers
    Assertions.assertFalse(t.ownsBuffer());
    Assertions.assertFalse(t.getBufferRef().isPresent());
    FloatBuffer buf = t.getFloatBuffer();
    float[] output = new float[arrValues.length];
    buf.get(output);
    Assertions.assertArrayEquals(arrValues, output);

    // Can't modify the tensor through this buffer. Re-read the tensor after the write so the
    // assertion actually observes the tensor's contents rather than the array filled earlier.
    buf.put(0, 25);
    t.getFloatBuffer().get(output);
    Assertions.assertArrayEquals(arrValues, output);
  }

  // Test creating a value from a non-direct byte buffer
  // Non-direct byte buffers are allocated on the Java heap and must be copied into off-heap
  // direct byte buffers which can be directly passed to ORT
  FloatBuffer nonDirectBuffer = FloatBuffer.allocate(5);
  nonDirectBuffer.put(arrValues);
  nonDirectBuffer.rewind();
  try (OnnxTensor t = OnnxTensor.createTensor(env, nonDirectBuffer, new long[] {1, 5})) {
    // non-direct buffers trigger a copy
    Assertions.assertTrue(t.ownsBuffer());
    // tensors backed by buffers can get the buffer ref back out
    Assertions.assertTrue(t.getBufferRef().isPresent());
    FloatBuffer buf = t.getFloatBuffer();
    float[] output = new float[arrValues.length];
    buf.get(output);
    Assertions.assertArrayEquals(arrValues, output);

    // Can't modify the tensor through getFloatBuffer; re-read the tensor to prove it.
    buf.put(0, 25);
    t.getFloatBuffer().get(output);
    Assertions.assertArrayEquals(arrValues, output);

    // Can modify the tensor through getBufferRef.
    FloatBuffer ref = (FloatBuffer) t.getBufferRef().get();
    ref.put(0, 25);
    buf = t.getFloatBuffer();
    buf.get(output);
    Assertions.assertEquals(25, output[0]);
  }

  // Test creating a value from a direct byte buffer
  // Direct byte buffers can be passed into ORT without additional copies or processing
  FloatBuffer directBuffer =
      ByteBuffer.allocateDirect(5 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();
  directBuffer.put(arrValues);
  directBuffer.rewind();
  try (OnnxTensor t = OnnxTensor.createTensor(env, directBuffer, new long[] {1, 5})) {
    // direct buffers don't trigger a copy
    Assertions.assertFalse(t.ownsBuffer());
    // tensors backed by buffers can get the buffer ref back out
    Assertions.assertTrue(t.getBufferRef().isPresent());
    FloatBuffer buf = t.getFloatBuffer();
    float[] output = new float[arrValues.length];
    buf.get(output);
    Assertions.assertArrayEquals(arrValues, output);

    // Can't modify the tensor through getFloatBuffer; re-read the tensor to prove it.
    buf.put(0, 25);
    t.getFloatBuffer().get(output);
    Assertions.assertArrayEquals(arrValues, output);

    // Can modify the tensor through getBufferRef.
    FloatBuffer ref = (FloatBuffer) t.getBufferRef().get();
    ref.put(0, 25);
    buf = t.getFloatBuffer();
    buf.get(output);
    Assertions.assertEquals(25, output[0]);

    // Can modify the tensor through our original ref to the direct byte buffer
    directBuffer.put(1, 15);
    buf = t.getFloatBuffer();
    buf.get(output);
    Assertions.assertEquals(15, output[1]);
  }
}
@Test
public void testStringCreation() throws OrtException {
OrtEnvironment env = OrtEnvironment.getEnvironment();