[java] Make the backing byte buffer in an OrtValue accessible (#16578)
### Description Adds a method to access the backing direct byte buffer from a Java `OnnxTensor` object, assuming it is backed by a direct byte buffer (tensors created by ORT's run call or ones created in Java from multidimensional arrays are not). Also adds a method to check if the backing byte buffer was copied from the user's buffer supplied on creation (this could be tested via a pointer comparison from the output of `getBufferRef` and the user's input buffer, so I'm not sure if it's necessary). ### Motivation and Context This is the first part of changes necessary to support output pinning in Java OrtSession.run/OrtTrainingSession.run calls. I split it out from the rest of the work as it's useful by itself (e.g. to allow users to keep a single input tensor and rewrite it each time with new inputs rather than allocate a fresh one) and the other change will be much more involved so splitting it makes it easier to review. cc @yuslepukhin
This commit is contained in:
Родитель
57c8736596
Коммит
3456831413
|
@ -13,6 +13,7 @@ import java.nio.FloatBuffer;
|
|||
import java.nio.IntBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import java.nio.ShortBuffer;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be
|
||||
|
@ -21,18 +22,60 @@ import java.nio.ShortBuffer;
|
|||
public class OnnxTensor extends OnnxTensorLike {
|
||||
|
||||
/**
|
||||
* This reference is held for OnnxTensors backed by a Java nio buffer to ensure the buffer does
|
||||
* This reference is held for OnnxTensors backed by a java.nio.Buffer to ensure the buffer does
|
||||
* not go out of scope while the OnnxTensor exists.
|
||||
*/
|
||||
private final Buffer buffer;
|
||||
|
||||
/**
|
||||
* Denotes if the OnnxTensor made a copy of the buffer on construction (i.e. it may have the only
|
||||
* reference).
|
||||
*/
|
||||
private final boolean ownsBuffer;
|
||||
|
||||
OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info) {
|
||||
this(nativeHandle, allocatorHandle, info, null);
|
||||
this(nativeHandle, allocatorHandle, info, null, false);
|
||||
}
|
||||
|
||||
OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer) {
|
||||
OnnxTensor(
|
||||
long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer, boolean ownsBuffer) {
|
||||
super(nativeHandle, allocatorHandle, info);
|
||||
this.buffer = buffer;
|
||||
this.ownsBuffer = ownsBuffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the buffer in this OnnxTensor was created on construction of this tensor, i.e.,
|
||||
* it is a copy of a user supplied buffer or array and may hold the only reference to that buffer.
|
||||
*
|
||||
* <p>When this is true the backing buffer was copied from the user input, so users cannot mutate
|
||||
* the state of this buffer without first getting the reference via {@link #getBufferRef()}.
|
||||
*
|
||||
* @return True if the buffer in this OnnxTensor was allocated by it on construction (i.e., it is
|
||||
* a copy of a user buffer.)
|
||||
*/
|
||||
public boolean ownsBuffer() {
|
||||
return this.ownsBuffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to the buffer which backs this {@code OnnxTensor}. If the tensor is not
|
||||
* backed by a buffer (i.e., it was created from a Java array, or is backed by memory allocated by
|
||||
* ORT) this method returns an empty {@link Optional}.
|
||||
*
|
||||
* <p>Changes to the buffer elements will be reflected in the native {@code OrtValue}, this can be
|
||||
* used to repeatedly update a single tensor for multiple different inferences without allocating
|
||||
* new tensors, though the inputs <b>must</b> remain the same size and shape.
|
||||
*
|
||||
* <p>Note: the tensor could refer to a contiguous range of elements in this buffer, not the whole
|
||||
* buffer. It is up to the user to manage this information by respecting the position and limit.
|
||||
* As a consequence, accessing this reference should be considered problematic when multiple
|
||||
* threads hold references to the buffer.
|
||||
*
|
||||
* @return A reference to the buffer.
|
||||
*/
|
||||
public Optional<Buffer> getBufferRef() {
|
||||
return Optional.ofNullable(buffer);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -45,7 +88,8 @@ public class OnnxTensor extends OnnxTensorLike {
|
|||
* primitives if it has multiple dimensions.
|
||||
*
|
||||
* <p>Java multidimensional arrays are quite slow for more than 2 dimensions, in that case it is
|
||||
* recommended you use the java.nio.Buffer extractors below (e.g. {@link #getFloatBuffer}).
|
||||
* recommended you use the {@link java.nio.Buffer} extractors below (e.g., {@link
|
||||
* #getFloatBuffer}).
|
||||
*
|
||||
* @return A Java value.
|
||||
* @throws OrtException If the value could not be extracted as the Tensor is invalid, or if the
|
||||
|
@ -283,6 +327,12 @@ public class OnnxTensor extends OnnxTensorLike {
|
|||
* multidimensional array. The shape is inferred from the object using reflection. The default
|
||||
* allocator is used.
|
||||
*
|
||||
* <p>Note: Java multidimensional arrays are not dense and this method requires traversing a large
|
||||
* number of pointers for high dimensional arrays. For types other than Strings it is recommended
|
||||
* to use one of the {@code createTensor} methods which accepts a {@link java.nio.Buffer}, e.g.
|
||||
* {@link #createTensor(OrtEnvironment, FloatBuffer, long[])} as those methods are zero copy to
|
||||
* transfer data into ORT when using direct buffers.
|
||||
*
|
||||
* @param env The current OrtEnvironment.
|
||||
* @param data The data to store in a tensor.
|
||||
* @return An OnnxTensor storing the data.
|
||||
|
@ -700,7 +750,8 @@ public class OnnxTensor extends OnnxTensorLike {
|
|||
info.onnxType.value),
|
||||
allocator.handle,
|
||||
info,
|
||||
tuple.data);
|
||||
tuple.data,
|
||||
tuple.isCopy);
|
||||
}
|
||||
|
||||
private static native long createTensor(
|
||||
|
|
|
@ -88,6 +88,91 @@ public class OnnxTensorTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBufferCreation() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
|
||||
// Test creating a value from an array
|
||||
// Arrays result in tensors allocated by ORT, so they do not have a backing java.nio.Buffer
|
||||
float[] arrValues = new float[] {0, 1, 2, 3, 4};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
|
||||
// array creation isn't backed by buffers
|
||||
Assertions.assertFalse(t.ownsBuffer());
|
||||
Assertions.assertFalse(t.getBufferRef().isPresent());
|
||||
FloatBuffer buf = t.getFloatBuffer();
|
||||
float[] output = new float[arrValues.length];
|
||||
buf.get(output);
|
||||
Assertions.assertArrayEquals(arrValues, output);
|
||||
|
||||
// Can't modify the tensor through this buffer.
|
||||
buf.put(0, 25);
|
||||
Assertions.assertArrayEquals(arrValues, output);
|
||||
}
|
||||
|
||||
// Test creating a value from a non-direct byte buffer
|
||||
// Non-direct byte buffers are allocated on the Java heap and must be copied into off-heap
|
||||
// direct byte buffers
|
||||
// which can be directly passed to ORT
|
||||
FloatBuffer nonDirectBuffer = FloatBuffer.allocate(5);
|
||||
nonDirectBuffer.put(arrValues);
|
||||
nonDirectBuffer.rewind();
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, nonDirectBuffer, new long[] {1, 5})) {
|
||||
// non-direct buffers trigger a copy
|
||||
Assertions.assertTrue(t.ownsBuffer());
|
||||
// tensors backed by buffers can get the buffer ref back out
|
||||
Assertions.assertTrue(t.getBufferRef().isPresent());
|
||||
FloatBuffer buf = t.getFloatBuffer();
|
||||
float[] output = new float[arrValues.length];
|
||||
buf.get(output);
|
||||
Assertions.assertArrayEquals(arrValues, output);
|
||||
|
||||
// Can't modify the tensor through getFloatBuffer.
|
||||
buf.put(0, 25);
|
||||
Assertions.assertArrayEquals(arrValues, output);
|
||||
|
||||
// Can modify the tensor through getBufferRef.
|
||||
FloatBuffer ref = (FloatBuffer) t.getBufferRef().get();
|
||||
ref.put(0, 25);
|
||||
buf = t.getFloatBuffer();
|
||||
buf.get(output);
|
||||
Assertions.assertEquals(25, output[0]);
|
||||
}
|
||||
|
||||
// Test creating a value from a direct byte buffer
|
||||
// Direct byte buffers can be passed into ORT without additional copies or processing
|
||||
FloatBuffer directBuffer =
|
||||
ByteBuffer.allocateDirect(5 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();
|
||||
directBuffer.put(arrValues);
|
||||
directBuffer.rewind();
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, directBuffer, new long[] {1, 5})) {
|
||||
// direct buffers don't trigger a copy
|
||||
Assertions.assertFalse(t.ownsBuffer());
|
||||
// tensors backed by buffers can get the buffer ref back out
|
||||
Assertions.assertTrue(t.getBufferRef().isPresent());
|
||||
FloatBuffer buf = t.getFloatBuffer();
|
||||
float[] output = new float[arrValues.length];
|
||||
buf.get(output);
|
||||
Assertions.assertArrayEquals(arrValues, output);
|
||||
|
||||
// Can't modify the tensor through getFloatBuffer.
|
||||
buf.put(0, 25);
|
||||
Assertions.assertArrayEquals(arrValues, output);
|
||||
|
||||
// Can modify the tensor through getBufferRef.
|
||||
FloatBuffer ref = (FloatBuffer) t.getBufferRef().get();
|
||||
ref.put(0, 25);
|
||||
buf = t.getFloatBuffer();
|
||||
buf.get(output);
|
||||
Assertions.assertEquals(25, output[0]);
|
||||
|
||||
// Can modify the tensor through our original ref to the direct byte buffer
|
||||
directBuffer.put(1, 15);
|
||||
buf = t.getFloatBuffer();
|
||||
buf.get(output);
|
||||
Assertions.assertEquals(15, output[1]);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStringCreation() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
|
|
Загрузка…
Ссылка в новой задаче