[java] Sparse tensor support (#10653)
**Description**: Adds support for creating and receiving sparse tensors in the ORT Java API. CSRC and COO tensors as inputs are tested, but there is no op which accepts a block sparse tensor to test. COO tensors are tested as outputs, but there is no op which emits a CSRC or block sparse tensor to test. **Motivation and Context** - Why is this change required? What problem does it solve? Request to expose ORT sparse tensor support in Java. cc @yuslepukhin
This commit is contained in:
Родитель
8b0e0f4927
Коммит
dd2c031d95
|
@ -0,0 +1,920 @@
|
|||
/*
|
||||
* Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import static ai.onnxruntime.OnnxTensor.fp16ToFloat;
|
||||
|
||||
import java.nio.Buffer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.DoubleBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.IntBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import java.nio.ShortBuffer;
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* A Java object wrapping an OnnxSparseTensor.
|
||||
*
|
||||
* <p>Sparse tensors support a variety of formats, and the {@link #getValue} method returns a
|
||||
* different static inner class representing each type.
|
||||
*/
|
||||
public final class OnnxSparseTensor extends OnnxTensorLike {
|
||||
private final SparseTensorType sparseTensorType;
|
||||
|
||||
// Held to prevent deallocation while used in native code.
|
||||
private final Buffer indices;
|
||||
private final LongBuffer innerIndices;
|
||||
private final Buffer values;
|
||||
|
||||
/**
|
||||
* Construct a sparse tensor from JNI.
|
||||
*
|
||||
* @param nativeHandle The tensor native handle.
|
||||
* @param allocatorHandle The allocator handle.
|
||||
* @param sparseType The sparsity type.
|
||||
* @param info The tensor info.
|
||||
*/
|
||||
OnnxSparseTensor(long nativeHandle, long allocatorHandle, int sparseType, TensorInfo info) {
|
||||
this(
|
||||
nativeHandle,
|
||||
allocatorHandle,
|
||||
SparseTensorType.mapFromInt(sparseType),
|
||||
info,
|
||||
null,
|
||||
null,
|
||||
null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a COO or block sparse tensor.
|
||||
*
|
||||
* @param nativeHandle The tensor native handle.
|
||||
* @param allocatorHandle The allocator handle.
|
||||
* @param sparseType The sparsity type.
|
||||
* @param info The tensor info.
|
||||
* @param indices The indices buffer.
|
||||
* @param values The data buffer.
|
||||
*/
|
||||
OnnxSparseTensor(
|
||||
long nativeHandle,
|
||||
long allocatorHandle,
|
||||
SparseTensorType sparseType,
|
||||
TensorInfo info,
|
||||
Buffer indices,
|
||||
Buffer values) {
|
||||
this(nativeHandle, allocatorHandle, sparseType, info, indices, null, values);
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a sparse tensor.
|
||||
*
|
||||
* <p>If the tensor is COO or block sparse then innerIndices may be null.
|
||||
*
|
||||
* @param nativeHandle The tensor native handle.
|
||||
* @param allocatorHandle The allocator handle.
|
||||
* @param sparseType The sparsity type.
|
||||
* @param info The tensor info.
|
||||
* @param indices The indices buffer.
|
||||
* @param innerIndices The inner indices buffer.
|
||||
* @param values The data buffer.
|
||||
*/
|
||||
OnnxSparseTensor(
|
||||
long nativeHandle,
|
||||
long allocatorHandle,
|
||||
SparseTensorType sparseType,
|
||||
TensorInfo info,
|
||||
Buffer indices,
|
||||
LongBuffer innerIndices,
|
||||
Buffer values) {
|
||||
super(nativeHandle, allocatorHandle, info);
|
||||
this.sparseTensorType = sparseType;
|
||||
this.indices = indices;
|
||||
this.innerIndices = innerIndices;
|
||||
this.values = values;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a Sparse Tensor in ORT from the Java side representation.
|
||||
*
|
||||
* @param env The OrtEnvironment.
|
||||
* @param tensor The Java side representation.
|
||||
* @param <T> The buffer type.
|
||||
* @return The sparse tensor in ORT.
|
||||
* @throws OrtException If the tensor could not be created or was invalid.
|
||||
*/
|
||||
public static <T extends Buffer> OnnxSparseTensor createSparseTensor(
|
||||
OrtEnvironment env, SparseTensor<T> tensor) throws OrtException {
|
||||
return createSparseTensor(env, env.defaultAllocator, tensor);
|
||||
}
|
||||
|
||||
static <T extends Buffer> OnnxSparseTensor createSparseTensor(
|
||||
OrtEnvironment env, OrtAllocator allocator, SparseTensor<T> tensor) throws OrtException {
|
||||
if (!allocator.isClosed()) {
|
||||
TensorInfo info = TensorInfo.constructFromSparseTensor(tensor);
|
||||
OnnxJavaType indicesType = tensor.getIndicesType();
|
||||
OrtUtil.BufferTuple indicesTuple = OrtUtil.prepareBuffer(tensor.getIndices(), indicesType);
|
||||
OrtUtil.BufferTuple valuesTuple = OrtUtil.prepareBuffer(tensor.getValues(), info.type);
|
||||
if (!((indicesTuple.data instanceof LongBuffer)
|
||||
|| (indicesTuple.data instanceof IntBuffer))) {
|
||||
throw new IllegalStateException(
|
||||
"Unexpected type of indices buffer, found "
|
||||
+ indicesTuple.data.getClass()
|
||||
+ ", expected IntBuffer or LongBuffer");
|
||||
}
|
||||
// Replace with a type switch when using JDK 17+.
|
||||
switch (tensor.getSparsityType()) {
|
||||
case COO:
|
||||
case BLOCK_SPARSE:
|
||||
return new OnnxSparseTensor(
|
||||
createSparseTensorFromBuffer(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
allocator.handle,
|
||||
indicesTuple.data,
|
||||
indicesTuple.pos,
|
||||
indicesTuple.size,
|
||||
valuesTuple.data,
|
||||
valuesTuple.pos,
|
||||
info.shape,
|
||||
tensor.getIndicesShape(),
|
||||
tensor.getValuesShape(),
|
||||
info.onnxType.value,
|
||||
tensor.getSparsityType().value),
|
||||
allocator.handle,
|
||||
tensor.getSparsityType(),
|
||||
info,
|
||||
indicesTuple.data,
|
||||
valuesTuple.data);
|
||||
case CSRC:
|
||||
OrtUtil.BufferTuple innerIndicesTuple =
|
||||
OrtUtil.prepareBuffer(((CSRCTensor) tensor).getInnerIndices(), indicesType);
|
||||
return new OnnxSparseTensor(
|
||||
createCSRCSparseTensorFromBuffer(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
allocator.handle,
|
||||
indicesTuple.data,
|
||||
indicesTuple.pos,
|
||||
indicesTuple.size,
|
||||
innerIndicesTuple.data,
|
||||
innerIndicesTuple.pos,
|
||||
innerIndicesTuple.size,
|
||||
valuesTuple.data,
|
||||
valuesTuple.pos,
|
||||
info.shape,
|
||||
tensor.getValuesShape(),
|
||||
info.onnxType.value),
|
||||
allocator.handle,
|
||||
tensor.getSparsityType(),
|
||||
info,
|
||||
indicesTuple.data,
|
||||
(LongBuffer) innerIndicesTuple.data,
|
||||
valuesTuple.data);
|
||||
case UNDEFINED:
|
||||
default:
|
||||
throw new IllegalArgumentException("Cannot create an UNDEFINED sparse tensor.");
|
||||
}
|
||||
} else {
|
||||
throw new IllegalStateException(
|
||||
"Trying to create an OnnxSparseTensor on a closed OrtAllocator.");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public OnnxValueType getType() {
|
||||
return OnnxValueType.ONNX_TYPE_SPARSETENSOR;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SparseTensor<? extends Buffer> getValue() throws OrtException {
|
||||
Buffer buffer = getValuesBuffer();
|
||||
long[] indicesShape = getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
switch (sparseTensorType) {
|
||||
case COO:
|
||||
return new COOTensor(
|
||||
(LongBuffer) getIndicesBuffer(),
|
||||
indicesShape,
|
||||
buffer,
|
||||
info.shape,
|
||||
info.type,
|
||||
buffer.remaining());
|
||||
case CSRC:
|
||||
return new CSRCTensor(
|
||||
(LongBuffer) getIndicesBuffer(),
|
||||
getInnerIndicesBuffer(),
|
||||
buffer,
|
||||
info.shape,
|
||||
info.type,
|
||||
buffer.remaining());
|
||||
case BLOCK_SPARSE:
|
||||
long[] valuesShape = getValuesShape(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
return new BlockSparseTensor(
|
||||
(IntBuffer) getIndicesBuffer(),
|
||||
indicesShape,
|
||||
buffer,
|
||||
valuesShape,
|
||||
info.shape,
|
||||
info.type,
|
||||
buffer.remaining());
|
||||
case UNDEFINED:
|
||||
default:
|
||||
throw new IllegalStateException("Undefined sparsity type in this sparse tensor.");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
close(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the type of this OnnxSparseTensor.
|
||||
*
|
||||
* @return The sparsity type.
|
||||
*/
|
||||
public SparseTensorType getSparseTensorType() {
|
||||
return sparseTensorType;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a copy of the indices.
|
||||
*
|
||||
* <p>These are the outer indices if it's a CSRC sparse tensor.
|
||||
*
|
||||
* <p>It's a {@link LongBuffer} if COO or CSRC, and {@link IntBuffer} if Block Sparse.
|
||||
*
|
||||
* @return The indices.
|
||||
*/
|
||||
public Buffer getIndicesBuffer() {
|
||||
switch (sparseTensorType) {
|
||||
case COO:
|
||||
case CSRC:
|
||||
{
|
||||
LongBuffer longBuf =
|
||||
getIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
.asLongBuffer();
|
||||
LongBuffer output = LongBuffer.allocate(longBuf.capacity());
|
||||
output.put(longBuf);
|
||||
output.rewind();
|
||||
return output;
|
||||
}
|
||||
case BLOCK_SPARSE:
|
||||
{
|
||||
IntBuffer intBuf =
|
||||
getIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
.asIntBuffer();
|
||||
IntBuffer output = IntBuffer.allocate(intBuf.capacity());
|
||||
output.put(intBuf);
|
||||
output.rewind();
|
||||
return output;
|
||||
}
|
||||
case UNDEFINED:
|
||||
default:
|
||||
throw new IllegalStateException("UNDEFINED sparse tensor type.");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a copy of the inner indices in a CSRC sparse tensor.
|
||||
*
|
||||
* <p>Throws {@link IllegalStateException} if called on a different sparse tensor type.
|
||||
*
|
||||
* @return The inner indices.
|
||||
*/
|
||||
public LongBuffer getInnerIndicesBuffer() {
|
||||
if (sparseTensorType == SparseTensorType.CSRC) {
|
||||
LongBuffer buf =
|
||||
getInnerIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle)
|
||||
.order(ByteOrder.nativeOrder())
|
||||
.asLongBuffer();
|
||||
LongBuffer output = LongBuffer.allocate(buf.capacity());
|
||||
output.put(buf);
|
||||
output.rewind();
|
||||
return output;
|
||||
} else {
|
||||
throw new IllegalStateException(
|
||||
"Inner indices are only available for CSRC sparse tensors, this sparse tensor is "
|
||||
+ sparseTensorType);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a copy of the data buffer.
|
||||
*
|
||||
* <p>As with {@link OnnxTensor} fp16 values are upcast into fp32 and returned as a {@link
|
||||
* FloatBuffer}.
|
||||
*
|
||||
* @return The data buffer.
|
||||
*/
|
||||
public Buffer getValuesBuffer() {
|
||||
ByteBuffer buffer =
|
||||
getValuesBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder());
|
||||
switch (info.type) {
|
||||
case FLOAT:
|
||||
if (info.onnxType == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
|
||||
ShortBuffer shortBuffer = buffer.asShortBuffer();
|
||||
int bufferCap = shortBuffer.capacity();
|
||||
FloatBuffer output = FloatBuffer.allocate(bufferCap);
|
||||
for (int i = 0; i < bufferCap; i++) {
|
||||
output.put(fp16ToFloat(shortBuffer.get(i)));
|
||||
}
|
||||
output.rewind();
|
||||
return output;
|
||||
} else if (info.onnxType
|
||||
== TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) {
|
||||
throw new IllegalArgumentException("BFloat16 is not supported.");
|
||||
} else {
|
||||
// regular fp32
|
||||
FloatBuffer floatBuf = buffer.asFloatBuffer();
|
||||
FloatBuffer output = FloatBuffer.allocate(floatBuf.capacity());
|
||||
output.put(floatBuf);
|
||||
output.rewind();
|
||||
return output;
|
||||
}
|
||||
case DOUBLE:
|
||||
{
|
||||
DoubleBuffer doubleBuf = buffer.asDoubleBuffer();
|
||||
DoubleBuffer output = DoubleBuffer.allocate(doubleBuf.capacity());
|
||||
output.put(doubleBuf);
|
||||
output.rewind();
|
||||
return output;
|
||||
}
|
||||
case INT16:
|
||||
{
|
||||
ShortBuffer shortBuf = buffer.asShortBuffer();
|
||||
ShortBuffer output = ShortBuffer.allocate(shortBuf.capacity());
|
||||
output.put(shortBuf);
|
||||
output.rewind();
|
||||
return output;
|
||||
}
|
||||
case INT32:
|
||||
{
|
||||
IntBuffer intBuf = buffer.asIntBuffer();
|
||||
IntBuffer output = IntBuffer.allocate(intBuf.capacity());
|
||||
output.put(intBuf);
|
||||
output.rewind();
|
||||
return output;
|
||||
}
|
||||
case INT64:
|
||||
{
|
||||
LongBuffer longBuf = buffer.asLongBuffer();
|
||||
LongBuffer output = LongBuffer.allocate(longBuf.capacity());
|
||||
output.put(longBuf);
|
||||
output.rewind();
|
||||
return output;
|
||||
}
|
||||
case BOOL:
|
||||
case INT8:
|
||||
case UINT8:
|
||||
{
|
||||
ByteBuffer output = ByteBuffer.allocate(buffer.capacity());
|
||||
output.put(buffer);
|
||||
output.rewind();
|
||||
return output;
|
||||
}
|
||||
case STRING:
|
||||
throw new IllegalStateException("Unsupported data type String");
|
||||
case UNKNOWN:
|
||||
default:
|
||||
throw new IllegalStateException("Unsupported data type");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of the (outer) indices.
|
||||
*
|
||||
* @return The indices shape.
|
||||
*/
|
||||
public long[] getIndicesShape() {
|
||||
return getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of the inner indices in a CSRC sparse tensor.
|
||||
*
|
||||
* @return The indices shape.
|
||||
*/
|
||||
public long[] getInnerIndicesShape() {
|
||||
if (sparseTensorType == SparseTensorType.CSRC) {
|
||||
return getInnerIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
} else {
|
||||
throw new IllegalStateException(
|
||||
"Inner indices are only available for CSRC sparse tensors, this sparse tensor is "
|
||||
+ sparseTensorType);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of the values.
|
||||
*
|
||||
* @return The values shape.
|
||||
*/
|
||||
public long[] getValuesShape() {
|
||||
return getValuesShape(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of the (outer) indices.
|
||||
*
|
||||
* @param apiHandle The OrtApi pointer.
|
||||
* @param nativeHandle The OrtSparseTensor pointer.
|
||||
* @return The indices shape.
|
||||
*/
|
||||
private native long[] getIndicesShape(long apiHandle, long nativeHandle);
|
||||
|
||||
/**
|
||||
* Gets the shape of the inner indices.
|
||||
*
|
||||
* @param apiHandle The OrtApi pointer.
|
||||
* @param nativeHandle The OrtSparseTensor pointer.
|
||||
* @return The inner indices shape.
|
||||
*/
|
||||
private native long[] getInnerIndicesShape(long apiHandle, long nativeHandle);
|
||||
|
||||
/**
|
||||
* Gets the shape of the values.
|
||||
*
|
||||
* @param apiHandle The OrtApi pointer.
|
||||
* @param nativeHandle The OrtSparseTensor pointer.
|
||||
* @return The values shape.
|
||||
*/
|
||||
private native long[] getValuesShape(long apiHandle, long nativeHandle);
|
||||
|
||||
/**
|
||||
* Wraps the indices in a direct byte buffer.
|
||||
*
|
||||
* @param apiHandle The OrtApi pointer.
|
||||
* @param nativeHandle The OrtSparseTensor pointer.
|
||||
* @return A ByteBuffer wrapping the indices.
|
||||
*/
|
||||
private native ByteBuffer getIndicesBuffer(long apiHandle, long nativeHandle);
|
||||
|
||||
/**
|
||||
* Wraps the inner indices in a direct byte buffer.
|
||||
*
|
||||
* @param apiHandle The OrtApi pointer.
|
||||
* @param nativeHandle The OrtSparseTensor pointer.
|
||||
* @return A ByteBuffer wrapping the inner indices.
|
||||
*/
|
||||
private native ByteBuffer getInnerIndicesBuffer(long apiHandle, long nativeHandle);
|
||||
|
||||
/**
|
||||
* Wraps the data in a direct byte buffer.
|
||||
*
|
||||
* @param apiHandle The OrtApi pointer.
|
||||
* @param nativeHandle The OrtSparseTensor pointer.
|
||||
* @return A ByteBuffer wrapping the indices.
|
||||
*/
|
||||
private native ByteBuffer getValuesBuffer(long apiHandle, long nativeHandle);
|
||||
|
||||
/**
|
||||
* Closes the sparse tensor.
|
||||
*
|
||||
* @param apiHandle The OrtApi pointer.
|
||||
* @param nativeHandle The OrtSparseTensor pointer.
|
||||
*/
|
||||
private native void close(long apiHandle, long nativeHandle);
|
||||
|
||||
/**
|
||||
* Creates a sparse CSRC sparse tensor.
|
||||
*
|
||||
* <p>The buffers must be kept alive for the lifetime of the ORT sparse tensor object.
|
||||
*
|
||||
* @param apiHandle The ORT API pointer.
|
||||
* @param allocatorHandle The allocator pointer.
|
||||
* @param indicesData The outer indices.
|
||||
* @param indicesBufferPos The outer indices position in bytes.
|
||||
* @param indicesBufferSize The outer indices buffer size in longs.
|
||||
* @param innerIndicesData The inner indices.
|
||||
* @param innerIndicesBufferPos The inner indices position in bytes.
|
||||
* @param innerIndicesBufferSize The inner indices buffer size in longs.
|
||||
* @param values The data.
|
||||
* @param bufferPos The data position in bytes.
|
||||
* @param denseShape The dense shape of the tensor.
|
||||
* @param valuesShape The shape of the values (should be a vector).
|
||||
* @param onnxType The type of the values.
|
||||
* @return A pointer to an ORT sparse tensor value.
|
||||
* @throws OrtException If the tensor could not be created.
|
||||
*/
|
||||
private static native long createCSRCSparseTensorFromBuffer(
|
||||
long apiHandle,
|
||||
long allocatorHandle,
|
||||
Buffer indicesData,
|
||||
int indicesBufferPos,
|
||||
long indicesBufferSize,
|
||||
Buffer innerIndicesData,
|
||||
int innerIndicesBufferPos,
|
||||
long innerIndicesBufferSize,
|
||||
Buffer values,
|
||||
int bufferPos,
|
||||
long[] denseShape,
|
||||
long[] valuesShape,
|
||||
int onnxType)
|
||||
throws OrtException;
|
||||
|
||||
/**
|
||||
* Creates a sparse COO or block sparse tensor.
|
||||
*
|
||||
* <p>The buffers must be kept alive for the lifetime of the ORT sparse tensor object.
|
||||
*
|
||||
* @param apiHandle The ORT API pointer.
|
||||
* @param allocatorHandle The allocator pointer.
|
||||
* @param indicesData The indices.
|
||||
* @param indicesBufferPos The indices position in bytes.
|
||||
* @param indicesBufferSize The indices buffer size in longs.
|
||||
* @param values The data.
|
||||
* @param bufferPos The data position in bytes.
|
||||
* @param denseShape The dense shape of the tensor.
|
||||
* @param indicesShape The shape of the indices (a vector or matrix for COO, and a matrix for
|
||||
* block sparse).
|
||||
* @param valuesShape The shape of the values (a vector for COO, and a block shape for block
|
||||
* sparse).
|
||||
* @param onnxType The type of the values.
|
||||
* @param sparsityType The sparsity type.
|
||||
* @return A pointer to an ORT sparse tensor value.
|
||||
* @throws OrtException If the tensor could not be created.
|
||||
*/
|
||||
private static native long createSparseTensorFromBuffer(
|
||||
long apiHandle,
|
||||
long allocatorHandle,
|
||||
Buffer indicesData,
|
||||
int indicesBufferPos,
|
||||
long indicesBufferSize,
|
||||
Buffer values,
|
||||
int bufferPos,
|
||||
long[] denseShape,
|
||||
long[] indicesShape,
|
||||
long[] valuesShape,
|
||||
int onnxType,
|
||||
int sparsityType)
|
||||
throws OrtException;
|
||||
|
||||
/**
|
||||
* The type of the sparse tensor.
|
||||
*
|
||||
* <p>Should be synchronized with OrtSparseFormat in the C API.
|
||||
*/
|
||||
public enum SparseTensorType {
|
||||
/** Undefined sparse tensor. */
|
||||
UNDEFINED(0),
|
||||
/** COO sparse tensor. */
|
||||
COO(1),
|
||||
/** CSR or CSC sparse tensor. */
|
||||
CSRC(2),
|
||||
/** Block sparse tensor. */
|
||||
BLOCK_SPARSE(4);
|
||||
|
||||
/** The int value mirroring OrtSparseFormat. */
|
||||
public final int value;
|
||||
|
||||
private static final SparseTensorType[] values = new SparseTensorType[5];
|
||||
|
||||
static {
|
||||
values[0] = UNDEFINED;
|
||||
values[1] = COO;
|
||||
values[2] = CSRC;
|
||||
values[3] = UNDEFINED;
|
||||
values[4] = BLOCK_SPARSE;
|
||||
}
|
||||
|
||||
SparseTensorType(int value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps from an int in native land into a SparseTensorType instance.
|
||||
*
|
||||
* @param value The value to lookup.
|
||||
* @return The enum instance.
|
||||
*/
|
||||
public static SparseTensorType mapFromInt(int value) {
|
||||
if ((value > 0) && (value < values.length)) {
|
||||
return values[value];
|
||||
} else {
|
||||
return UNDEFINED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Abstract base class for Java sparse tensors
|
||||
*
|
||||
* <p>Will be sealed to {@link COOTensor}, {@link CSRCTensor} and {@link BlockSparseTensor} one
|
||||
* day.
|
||||
*/
|
||||
public abstract static class SparseTensor<T extends Buffer> {
|
||||
private final long[] indicesShape;
|
||||
private final long[] valuesShape;
|
||||
private final long[] denseShape;
|
||||
private final OnnxJavaType type;
|
||||
private final long numNonZero;
|
||||
|
||||
final T indices;
|
||||
final Buffer values;
|
||||
|
||||
SparseTensor(
|
||||
T indices,
|
||||
long[] indicesShape,
|
||||
Buffer values,
|
||||
long[] valuesShape,
|
||||
long[] denseShape,
|
||||
OnnxJavaType type,
|
||||
long numNonZero) {
|
||||
this.indices = indices;
|
||||
this.indicesShape = indicesShape;
|
||||
this.values = values;
|
||||
this.valuesShape = valuesShape;
|
||||
this.denseShape = denseShape;
|
||||
this.type = type;
|
||||
this.numNonZero = numNonZero;
|
||||
if (values.remaining() != numNonZero) {
|
||||
throw new IllegalArgumentException(
|
||||
"Expected numNonZero and data.remaining to be equal, found "
|
||||
+ numNonZero
|
||||
+ " and "
|
||||
+ values.remaining()
|
||||
+ " respectively");
|
||||
}
|
||||
if (type == OnnxJavaType.STRING) {
|
||||
throw new IllegalArgumentException("String SparseTensors are not supported.");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the dense shape of the sparse tensor.
|
||||
*
|
||||
* @return The sparse tensor shape.
|
||||
*/
|
||||
public long[] getDenseShape() {
|
||||
return denseShape;
|
||||
}
|
||||
|
||||
/**
|
||||
* The data type of the sparse tensor.
|
||||
*
|
||||
* @return The sparse tensor data type.
|
||||
*/
|
||||
public OnnxJavaType getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
/**
|
||||
* The number of non-zero elements.
|
||||
*
|
||||
* @return The number of non-zero elements.
|
||||
*/
|
||||
public long getNumNonZeroElements() {
|
||||
return numNonZero;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the indices buffer.
|
||||
*
|
||||
* @return The indices buffer.
|
||||
*/
|
||||
public T getIndices() {
|
||||
return indices;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the value buffer.
|
||||
*
|
||||
* @return The value buffer.
|
||||
*/
|
||||
public Buffer getValues() {
|
||||
return values;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of the values of the sparse tensor.
|
||||
*
|
||||
* @return The sparse tensor value shape.
|
||||
*/
|
||||
public long[] getValuesShape() {
|
||||
return valuesShape;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of the indices of the sparse tensor.
|
||||
*
|
||||
* @return The sparse tensor indices shape.
|
||||
*/
|
||||
public long[] getIndicesShape() {
|
||||
return indicesShape;
|
||||
}
|
||||
|
||||
/**
|
||||
* The sparsity type of the sparse tensor.
|
||||
*
|
||||
* @return The sparse tensor sparsity type.
|
||||
*/
|
||||
public abstract SparseTensorType getSparsityType();
|
||||
|
||||
/**
|
||||
* The indices type of the sparse tensor.
|
||||
*
|
||||
* <p>Only {@link OnnxJavaType#INT32} and {@link OnnxJavaType#INT64} are supported.
|
||||
*
|
||||
* @return The sparse tensor indices type.
|
||||
*/
|
||||
public abstract OnnxJavaType getIndicesType();
|
||||
}
|
||||
|
||||
/** The Java side representation of a COO sparse tensor. */
|
||||
public static final class COOTensor extends SparseTensor<LongBuffer> {
|
||||
/**
|
||||
* Creates a COO sparse tensor suitable for constructing an ORT Sparse Tensor.
|
||||
*
|
||||
* @param indices The indices. Should be a 1d vector, or a 2d vector.
|
||||
* @param indicesShape The shape of the indices.
|
||||
* @param values The data.
|
||||
* @param denseShape The dense shape.
|
||||
* @param type The data type.
|
||||
* @param numNonZero The number of non-zero elements.
|
||||
*/
|
||||
public COOTensor(
|
||||
LongBuffer indices,
|
||||
long[] indicesShape,
|
||||
Buffer values,
|
||||
long[] denseShape,
|
||||
OnnxJavaType type,
|
||||
long numNonZero) {
|
||||
super(indices, indicesShape, values, new long[] {numNonZero}, denseShape, type, numNonZero);
|
||||
if ((indicesShape.length > 2)
|
||||
|| (indicesShape.length == 0)
|
||||
|| (indicesShape[0] != numNonZero)) {
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid indices shape, expected [numNonZero, dimension] or [numNonZero] found "
|
||||
+ Arrays.toString(indicesShape));
|
||||
}
|
||||
long elementCount = OrtUtil.elementCount(indicesShape);
|
||||
if (elementCount != indices.remaining()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Unexpected number of indices found in buffer, expected "
|
||||
+ elementCount
|
||||
+ " found "
|
||||
+ indices.remaining());
|
||||
}
|
||||
if (values.remaining() != numNonZero) {
|
||||
throw new IllegalArgumentException(
|
||||
"Expected data.remaining() - "
|
||||
+ values.remaining()
|
||||
+ " to equal numNonZero - "
|
||||
+ numNonZero);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public OnnxJavaType getIndicesType() {
|
||||
return OnnxJavaType.INT64;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SparseTensorType getSparsityType() {
|
||||
return SparseTensorType.COO;
|
||||
}
|
||||
}
|
||||
|
||||
/** The Java side representation of a CSRC sparse tensor. */
|
||||
public static final class CSRCTensor extends SparseTensor<LongBuffer> {
|
||||
private final LongBuffer innerIndices;
|
||||
|
||||
/**
|
||||
* Creates a CSRC sparse tensor suitable for constructing an ORT Sparse Tensor.
|
||||
*
|
||||
* @param outerIndices The outer indices.
|
||||
* @param innerIndices The inner indices.
|
||||
* @param values The data.
|
||||
* @param denseShape The dense shape.
|
||||
* @param type The data type.
|
||||
* @param numNonZero The number of non-zero elements.
|
||||
*/
|
||||
public CSRCTensor(
|
||||
LongBuffer outerIndices,
|
||||
LongBuffer innerIndices,
|
||||
Buffer values,
|
||||
long[] denseShape,
|
||||
OnnxJavaType type,
|
||||
long numNonZero) {
|
||||
super(
|
||||
outerIndices,
|
||||
new long[] {outerIndices.remaining()},
|
||||
values,
|
||||
new long[] {numNonZero},
|
||||
denseShape,
|
||||
type,
|
||||
numNonZero);
|
||||
this.innerIndices = innerIndices;
|
||||
long expectedRows = denseShape[0] + 1;
|
||||
if (outerIndices.remaining() != expectedRows) {
|
||||
throw new IllegalArgumentException(
|
||||
"Outer indices should be equal to the number of rows + 1 in the dense shape, found "
|
||||
+ outerIndices.remaining()
|
||||
+ ", expected "
|
||||
+ expectedRows);
|
||||
}
|
||||
if (innerIndices.remaining() != numNonZero) {
|
||||
throw new IllegalArgumentException(
|
||||
"Inner indices should be equal to the number of non-zero elements, found "
|
||||
+ innerIndices.remaining()
|
||||
+ ", expected "
|
||||
+ numNonZero);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of the inner indices.
|
||||
*
|
||||
* @return The inner indices shape.
|
||||
*/
|
||||
public long[] getInnerIndicesShape() {
|
||||
return new long[] {innerIndices.remaining()};
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the inner indices buffer.
|
||||
*
|
||||
* @return The inner indices buffer.
|
||||
*/
|
||||
public LongBuffer getInnerIndices() {
|
||||
return innerIndices;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OnnxJavaType getIndicesType() {
|
||||
return OnnxJavaType.INT64;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SparseTensorType getSparsityType() {
|
||||
return SparseTensorType.CSRC;
|
||||
}
|
||||
}
|
||||
|
||||
/** The Java side representation of a block sparse tensor. */
|
||||
public static final class BlockSparseTensor extends SparseTensor<IntBuffer> {
|
||||
/**
|
||||
* Construct a block sparse tensor.
|
||||
*
|
||||
* @param indices The indices.
|
||||
* @param indicesShape The shape of the indices.
|
||||
* @param values The data.
|
||||
* @param valuesShape The shape of the data.
|
||||
* @param denseShape The dense shape.
|
||||
* @param type The data type.
|
||||
* @param numNonZero The number of non-zero elements.
|
||||
*/
|
||||
public BlockSparseTensor(
|
||||
IntBuffer indices,
|
||||
long[] indicesShape,
|
||||
Buffer values,
|
||||
long[] valuesShape,
|
||||
long[] denseShape,
|
||||
OnnxJavaType type,
|
||||
long numNonZero) {
|
||||
super(indices, indicesShape, values, valuesShape, denseShape, type, numNonZero);
|
||||
if (OrtUtil.elementCount(valuesShape) != numNonZero) {
|
||||
throw new IllegalArgumentException(
|
||||
"Expected "
|
||||
+ numNonZero
|
||||
+ " entries in the data shape, found "
|
||||
+ Arrays.toString(valuesShape));
|
||||
}
|
||||
if (numNonZero != values.remaining()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Expected " + numNonZero + " elements in the data buffer, found " + values.remaining());
|
||||
}
|
||||
if (OrtUtil.elementCount(indicesShape) != indices.remaining()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Expected "
|
||||
+ OrtUtil.elementCount(indicesShape)
|
||||
+ " elements in the indices buffer, found "
|
||||
+ indices.remaining());
|
||||
}
|
||||
if (valuesShape.length < 3) {
|
||||
throw new IllegalArgumentException(
|
||||
"Expected [numBlocks, blockSize, blockSize] or larger, but data shape was "
|
||||
+ Arrays.toString(valuesShape));
|
||||
}
|
||||
if (indicesShape.length < 2) {
|
||||
throw new IllegalArgumentException(
|
||||
"Expected [numBlocks, co-ordinates] or larger, but indices shape was "
|
||||
+ Arrays.toString(indicesShape));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public OnnxJavaType getIndicesType() {
|
||||
return OnnxJavaType.INT32;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SparseTensorType getSparsityType() {
|
||||
return SparseTensorType.BLOCK_SPARSE;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,10 +1,9 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.Buffer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
|
@ -18,21 +17,7 @@ import java.nio.ShortBuffer;
|
|||
* A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be
|
||||
* returned as outputs.
|
||||
*/
|
||||
public class OnnxTensor implements OnnxValue {
|
||||
static {
|
||||
try {
|
||||
OnnxRuntime.init();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to load onnx-runtime library", e);
|
||||
}
|
||||
}
|
||||
|
||||
private final long nativeHandle;
|
||||
|
||||
private final long allocatorHandle;
|
||||
|
||||
private final TensorInfo info;
|
||||
|
||||
public class OnnxTensor extends OnnxTensorLike {
|
||||
/**
|
||||
* This reference is held for OnnxTensors backed by a Java nio buffer to ensure the buffer does
|
||||
* not go out of scope while the OnnxTensor exists.
|
||||
|
@ -44,9 +29,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
}
|
||||
|
||||
OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer) {
|
||||
this.nativeHandle = nativeHandle;
|
||||
this.allocatorHandle = allocatorHandle;
|
||||
this.info = info;
|
||||
super(nativeHandle, allocatorHandle, info);
|
||||
this.buffer = buffer;
|
||||
}
|
||||
|
||||
|
@ -55,10 +38,6 @@ public class OnnxTensor implements OnnxValue {
|
|||
return OnnxValueType.ONNX_TYPE_TENSOR;
|
||||
}
|
||||
|
||||
long getNativeHandle() {
|
||||
return nativeHandle;
|
||||
}
|
||||
|
||||
/**
|
||||
* Either returns a boxed primitive if the Tensor is a scalar, or a multidimensional array of
|
||||
* primitives if it has multiple dimensions.
|
||||
|
@ -108,11 +87,6 @@ public class OnnxTensor implements OnnxValue {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public TensorInfo getInfo() {
|
||||
return info;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "OnnxTensor(info=" + info.toString() + ")";
|
||||
|
@ -300,7 +274,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
* @param input A uint16_t representing an IEEE half precision float.
|
||||
* @return A float.
|
||||
*/
|
||||
private static float fp16ToFloat(short input) {
|
||||
static float fp16ToFloat(short input) {
|
||||
int output =
|
||||
((input & 0x8000) << 16) | (((input & 0x7c00) + 0x1C000) << 13) | ((input & 0x03FF) << 13);
|
||||
return Float.intBitsToFloat(output);
|
||||
|
@ -715,73 +689,20 @@ public class OnnxTensor implements OnnxValue {
|
|||
*/
|
||||
private static OnnxTensor createTensor(
|
||||
OnnxJavaType type, OrtAllocator allocator, Buffer data, long[] shape) throws OrtException {
|
||||
int bufferPos;
|
||||
long bufferSizeLong = data.remaining() * (long) type.size;
|
||||
if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) {
|
||||
// The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending
|
||||
// on the JVM, so we check for something 8 elements below the maximum size which
|
||||
// should be allocatable (assuming there is enough memory) on all 64-bit JVMs.
|
||||
throw new IllegalStateException(
|
||||
"Cannot allocate a direct buffer of the requested size and type, size "
|
||||
+ data.remaining()
|
||||
+ ", type = "
|
||||
+ type);
|
||||
}
|
||||
// Now we know we're in range
|
||||
int bufferSize = data.remaining() * type.size;
|
||||
Buffer tmp;
|
||||
if (data.isDirect()) {
|
||||
tmp = data;
|
||||
bufferPos = data.position() * type.size;
|
||||
} else {
|
||||
// Copy the data to a new direct buffer, then restore the state of the input.
|
||||
int origPosition = data.position();
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
|
||||
switch (type) {
|
||||
case FLOAT:
|
||||
tmp = buffer.asFloatBuffer().put((FloatBuffer) data);
|
||||
break;
|
||||
case DOUBLE:
|
||||
tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
|
||||
break;
|
||||
case UINT8:
|
||||
case INT8:
|
||||
// buffer is already a ByteBuffer, no cast needed.
|
||||
tmp = buffer.put((ByteBuffer) data);
|
||||
break;
|
||||
case INT16:
|
||||
tmp = buffer.asShortBuffer().put((ShortBuffer) data);
|
||||
break;
|
||||
case INT32:
|
||||
tmp = buffer.asIntBuffer().put((IntBuffer) data);
|
||||
break;
|
||||
case INT64:
|
||||
tmp = buffer.asLongBuffer().put((LongBuffer) data);
|
||||
break;
|
||||
case BOOL:
|
||||
case STRING:
|
||||
case UNKNOWN:
|
||||
default:
|
||||
throw new IllegalStateException(
|
||||
"Impossible to reach here, managed to cast a buffer as an incorrect type");
|
||||
}
|
||||
data.position(origPosition);
|
||||
tmp.rewind();
|
||||
bufferPos = 0;
|
||||
}
|
||||
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
|
||||
OrtUtil.BufferTuple tuple = OrtUtil.prepareBuffer(data, type);
|
||||
TensorInfo info = TensorInfo.constructFromBuffer(tuple.data, shape, type);
|
||||
return new OnnxTensor(
|
||||
createTensorFromBuffer(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
allocator.handle,
|
||||
tmp,
|
||||
bufferPos,
|
||||
bufferSize,
|
||||
tuple.data,
|
||||
tuple.pos,
|
||||
tuple.byteSize,
|
||||
shape,
|
||||
info.onnxType.value),
|
||||
allocator.handle,
|
||||
info,
|
||||
tmp);
|
||||
tuple.data);
|
||||
}
|
||||
|
||||
private static native long createTensor(
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Currently implemented by {@link OnnxTensor}, {@link OnnxSparseTensor}. Will be sealed to these
|
||||
* types one day.
|
||||
*/
|
||||
public abstract class OnnxTensorLike implements OnnxValue {
|
||||
static {
|
||||
try {
|
||||
OnnxRuntime.init();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to load onnx-runtime library", e);
|
||||
}
|
||||
}
|
||||
|
||||
protected final long nativeHandle;
|
||||
|
||||
protected final long allocatorHandle;
|
||||
|
||||
protected final TensorInfo info;
|
||||
|
||||
/**
|
||||
* Constructs a tensor-like (the base class of OnnxTensor and OnnxSparseTensor).
|
||||
*
|
||||
* @param nativeHandle The pointer to the tensor.
|
||||
* @param allocatorHandle The pointer to the memory allocator.
|
||||
* @param info The tensor info.
|
||||
*/
|
||||
OnnxTensorLike(long nativeHandle, long allocatorHandle, TensorInfo info) {
|
||||
this.nativeHandle = nativeHandle;
|
||||
this.allocatorHandle = allocatorHandle;
|
||||
this.info = info;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the native pointer.
|
||||
*
|
||||
* @return The native pointer.
|
||||
*/
|
||||
long getNativeHandle() {
|
||||
return nativeHandle;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a {@link TensorInfo} for this tensor.
|
||||
*
|
||||
* @return The tensor info.
|
||||
*/
|
||||
@Override
|
||||
public TensorInfo getInfo() {
|
||||
return info;
|
||||
}
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
@ -8,9 +8,8 @@ import java.util.Map;
|
|||
|
||||
/**
|
||||
* Top interface for input and output values from ONNX models. Currently implemented by {@link
|
||||
* OnnxTensor}, {@link OnnxSequence} and {@link OnnxMap}. Will be sealed to these types one day.
|
||||
*
|
||||
* <p>Does not support sparse tensors.
|
||||
* OnnxTensor}, {@link OnnxSparseTensor}, {@link OnnxSequence} and {@link OnnxMap}. Will be sealed
|
||||
* to these types one day.
|
||||
*/
|
||||
public interface OnnxValue extends AutoCloseable {
|
||||
|
||||
|
@ -21,7 +20,8 @@ public interface OnnxValue extends AutoCloseable {
|
|||
ONNX_TYPE_SEQUENCE(2),
|
||||
ONNX_TYPE_MAP(3),
|
||||
ONNX_TYPE_OPAQUE(4),
|
||||
ONNX_TYPE_SPARSETENSOR(5);
|
||||
ONNX_TYPE_SPARSETENSOR(5),
|
||||
ONNX_TYPE_OPTIONAL(6);
|
||||
|
||||
/** The id number of this type in the C API. */
|
||||
public final int value;
|
||||
|
|
|
@ -203,7 +203,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code, the input names are invalid, or if
|
||||
* there are zero or too many inputs.
|
||||
*/
|
||||
public Result run(Map<String, OnnxTensor> inputs) throws OrtException {
|
||||
public Result run(Map<String, ? extends OnnxTensorLike> inputs) throws OrtException {
|
||||
return run(inputs, outputNames);
|
||||
}
|
||||
|
||||
|
@ -218,7 +218,8 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code, the input names are invalid, or if
|
||||
* there are zero or too many inputs.
|
||||
*/
|
||||
public Result run(Map<String, OnnxTensor> inputs, RunOptions runOptions) throws OrtException {
|
||||
public Result run(Map<String, ? extends OnnxTensorLike> inputs, RunOptions runOptions)
|
||||
throws OrtException {
|
||||
return run(inputs, outputNames, runOptions);
|
||||
}
|
||||
|
||||
|
@ -233,7 +234,7 @@ public class OrtSession implements AutoCloseable {
|
|||
* @throws OrtException If there was an error in native code, the input or output names are
|
||||
* invalid, or if there are zero or too many inputs or outputs.
|
||||
*/
|
||||
public Result run(Map<String, OnnxTensor> inputs, Set<String> requestedOutputs)
|
||||
public Result run(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs)
|
||||
throws OrtException {
|
||||
return run(inputs, requestedOutputs, null);
|
||||
}
|
||||
|
@ -241,7 +242,7 @@ public class OrtSession implements AutoCloseable {
|
|||
/**
|
||||
* Scores an input feed dict, returning the map of requested inferred outputs.
|
||||
*
|
||||
* <p>The outputs are sorted based on the supplied set traveral order.
|
||||
* <p>The outputs are sorted based on the supplied set traversal order.
|
||||
*
|
||||
* @param inputs The inputs to score.
|
||||
* @param requestedOutputs The requested outputs.
|
||||
|
@ -251,10 +252,12 @@ public class OrtSession implements AutoCloseable {
|
|||
* invalid, or if there are zero or too many inputs or outputs.
|
||||
*/
|
||||
public Result run(
|
||||
Map<String, OnnxTensor> inputs, Set<String> requestedOutputs, RunOptions runOptions)
|
||||
Map<String, ? extends OnnxTensorLike> inputs,
|
||||
Set<String> requestedOutputs,
|
||||
RunOptions runOptions)
|
||||
throws OrtException {
|
||||
if (!closed) {
|
||||
if (inputs.isEmpty() || (inputs.size() > numInputs)) {
|
||||
if ((inputs.isEmpty() && (numInputs != 0)) || (inputs.size() > numInputs)) {
|
||||
throw new OrtException(
|
||||
"Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size());
|
||||
}
|
||||
|
@ -268,7 +271,7 @@ public class OrtSession implements AutoCloseable {
|
|||
String[] inputNamesArray = new String[inputs.size()];
|
||||
long[] inputHandles = new long[inputs.size()];
|
||||
int i = 0;
|
||||
for (Map.Entry<String, OnnxTensor> t : inputs.entrySet()) {
|
||||
for (Map.Entry<String, ? extends OnnxTensorLike> t : inputs.entrySet()) {
|
||||
if (inputNames.contains(t.getKey())) {
|
||||
inputNamesArray[i] = t.getKey();
|
||||
inputHandles[i] = t.getValue().getNativeHandle();
|
||||
|
|
|
@ -1,10 +1,18 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import java.lang.reflect.Array;
|
||||
import java.nio.Buffer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.DoubleBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.IntBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import java.nio.ShortBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
||||
|
@ -472,4 +480,87 @@ public final class OrtUtil {
|
|||
// 0.75 is the default JDK load factor
|
||||
return (int) (size / 0.75 + 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepares a buffer, either copying it if it's not direct, or computing it's size and position if
|
||||
* it is.
|
||||
*
|
||||
* @param data The buffer to prepare.
|
||||
* @param type The Java-side type.
|
||||
* @return The prepared buffer tuple.
|
||||
*/
|
||||
static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) {
|
||||
int bufferPos;
|
||||
long bufferSizeLong = data.remaining() * (long) type.size;
|
||||
if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) {
|
||||
// The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending
|
||||
// on the JVM, so we check for something 8 elements below the maximum size which
|
||||
// should be allocatable (assuming there is enough memory) on all 64-bit JVMs.
|
||||
throw new IllegalStateException(
|
||||
"Cannot allocate a direct buffer of the requested size and type, size "
|
||||
+ data.remaining()
|
||||
+ ", type = "
|
||||
+ type);
|
||||
}
|
||||
// Now we know we're in range
|
||||
int bufferSize = data.remaining() * type.size;
|
||||
Buffer tmp;
|
||||
if (data.isDirect()) {
|
||||
tmp = data;
|
||||
bufferPos = data.position() * type.size;
|
||||
} else {
|
||||
// Copy the data to a new direct buffer, then restore the state of the input.
|
||||
int origPosition = data.position();
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
|
||||
switch (type) {
|
||||
case FLOAT:
|
||||
tmp = buffer.asFloatBuffer().put((FloatBuffer) data);
|
||||
break;
|
||||
case DOUBLE:
|
||||
tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
|
||||
break;
|
||||
case UINT8:
|
||||
case INT8:
|
||||
// buffer is already a ByteBuffer, no cast needed.
|
||||
tmp = buffer.put((ByteBuffer) data);
|
||||
break;
|
||||
case INT16:
|
||||
tmp = buffer.asShortBuffer().put((ShortBuffer) data);
|
||||
break;
|
||||
case INT32:
|
||||
tmp = buffer.asIntBuffer().put((IntBuffer) data);
|
||||
break;
|
||||
case INT64:
|
||||
tmp = buffer.asLongBuffer().put((LongBuffer) data);
|
||||
break;
|
||||
case BOOL:
|
||||
case STRING:
|
||||
case UNKNOWN:
|
||||
default:
|
||||
throw new IllegalStateException(
|
||||
"Impossible to reach here, managed to cast a buffer as an incorrect type");
|
||||
}
|
||||
data.position(origPosition);
|
||||
tmp.rewind();
|
||||
bufferPos = 0;
|
||||
}
|
||||
|
||||
return new BufferTuple(tmp, bufferPos, bufferSize, data.remaining(), tmp != data);
|
||||
}
|
||||
|
||||
static final class BufferTuple {
|
||||
final Buffer data;
|
||||
final int pos;
|
||||
final long byteSize;
|
||||
final long size;
|
||||
final boolean isCopy;
|
||||
|
||||
BufferTuple(Buffer data, int pos, long byteSize, long size, boolean isCopy) {
|
||||
this.data = data;
|
||||
this.pos = pos;
|
||||
this.byteSize = byteSize;
|
||||
this.size = size;
|
||||
this.isCopy = isCopy;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -297,6 +297,39 @@ public class TensorInfo implements ValueInfo {
|
|||
Arrays.copyOf(shape, shape.length), type, OnnxTensorType.mapFromJavaType(type));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a TensorInfo from the supplied {@link OnnxSparseTensor.SparseTensor}.
|
||||
*
|
||||
* @param tensor The sparse tensor.
|
||||
* @param <T> The buffer type.
|
||||
* @return A TensorInfo for a sparse tensor.
|
||||
* @throws OrtException If the supplied tensor has too many elements for it's shape.
|
||||
*/
|
||||
public static <T extends Buffer> TensorInfo constructFromSparseTensor(
|
||||
OnnxSparseTensor.SparseTensor<T> tensor) throws OrtException {
|
||||
long[] shape = tensor.getDenseShape();
|
||||
|
||||
long elementCount = OrtUtil.elementCount(shape);
|
||||
|
||||
long bufferRemaining = tensor.getValues().remaining();
|
||||
|
||||
if (elementCount < bufferRemaining) {
|
||||
throw new OrtException(
|
||||
"Shape "
|
||||
+ Arrays.toString(shape)
|
||||
+ ", has at most "
|
||||
+ elementCount
|
||||
+ " elements but the buffer has "
|
||||
+ bufferRemaining
|
||||
+ " elements.");
|
||||
}
|
||||
|
||||
return new TensorInfo(
|
||||
Arrays.copyOf(shape, shape.length),
|
||||
tensor.getType(),
|
||||
OnnxTensorType.mapFromJavaType(tensor.getType()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the shape from a multidimensional array. Checks to see if the array is ragged or not.
|
||||
*
|
||||
|
|
|
@ -68,6 +68,45 @@ ExecutionMode convertExecutionMode(jint mode) {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Must be kept in sync with OrtSparseFormat and OnnxSparseTensor.SparseTensorType
|
||||
* @param format The Java int.
|
||||
* @return The enum.
|
||||
*/
|
||||
OrtSparseFormat convertToOrtSparseFormat(jint format) {
|
||||
switch (format) {
|
||||
case 0:
|
||||
return ORT_SPARSE_UNDEFINED;
|
||||
case 1:
|
||||
return ORT_SPARSE_COO;
|
||||
case 2:
|
||||
return ORT_SPARSE_CSRC;
|
||||
case 4:
|
||||
return ORT_SPARSE_BLOCK_SPARSE;
|
||||
default:
|
||||
return ORT_SPARSE_UNDEFINED;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Must be kept in sync with OrtSparseFormat and OnnxSparseTensor.SparseTensorType
|
||||
* @param format The enum.
|
||||
* @return The Java int.
|
||||
*/
|
||||
jint convertFromOrtSparseFormat(OrtSparseFormat format) {
|
||||
switch (format) {
|
||||
case ORT_SPARSE_COO:
|
||||
return 1;
|
||||
case ORT_SPARSE_CSRC:
|
||||
return 2;
|
||||
case ORT_SPARSE_BLOCK_SPARSE:
|
||||
return 4;
|
||||
case ORT_SPARSE_UNDEFINED:
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Must be kept in sync with convertToONNXDataFormat
|
||||
*/
|
||||
|
@ -228,7 +267,8 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo
|
|||
}
|
||||
|
||||
switch (type) {
|
||||
case ONNX_TYPE_TENSOR: {
|
||||
case ONNX_TYPE_TENSOR:
|
||||
case ONNX_TYPE_SPARSETENSOR: {
|
||||
const OrtTensorTypeAndShapeInfo* tensorInfo = NULL;
|
||||
code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToTensorInfo(info, &tensorInfo));
|
||||
if (code == ORT_OK) {
|
||||
|
@ -257,7 +297,6 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo
|
|||
}
|
||||
case ONNX_TYPE_UNKNOWN:
|
||||
case ONNX_TYPE_OPAQUE:
|
||||
case ONNX_TYPE_SPARSETENSOR:
|
||||
default: {
|
||||
throwOrtException(jniEnv,convertErrorCode(ORT_NOT_IMPLEMENTED),"Invalid ONNXType found.");
|
||||
return NULL;
|
||||
|
@ -869,6 +908,40 @@ jobject createJavaTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocato
|
|||
return javaTensor;
|
||||
}
|
||||
|
||||
jobject createJavaSparseTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) {
|
||||
// Extract the type information
|
||||
OrtTensorTypeAndShapeInfo* info;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor, &info));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Construct the TensorInfo object
|
||||
jobject tensorInfo = convertToTensorInfo(jniEnv, api, info);
|
||||
|
||||
// Release the info object
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
if (tensorInfo == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Lookup the sparse tensor type enum
|
||||
OrtSparseFormat format;
|
||||
code = checkOrtStatus(jniEnv,api,api->GetSparseTensorFormat(tensor, &format));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
jint sparseTensorInt = convertFromOrtSparseFormat(format);
|
||||
|
||||
// Construct the ONNXTensor object
|
||||
char *tensorClassName = "ai/onnxruntime/OnnxSparseTensor";
|
||||
jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorClassName);
|
||||
jmethodID tensorConstructor = (*jniEnv)->GetMethodID(jniEnv, clazz, "<init>", "(JJILai/onnxruntime/TensorInfo;)V");
|
||||
jobject javaSparseTensor = (*jniEnv)->NewObject(jniEnv, clazz, tensorConstructor, (jlong) tensor, (jlong) allocator, sparseTensorInt, tensorInfo);
|
||||
|
||||
return javaSparseTensor;
|
||||
}
|
||||
|
||||
jobject createJavaSequenceFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* sequence) {
|
||||
// Get the sequence info class
|
||||
static const char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo";
|
||||
|
@ -1026,12 +1099,14 @@ jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca
|
|||
case ONNX_TYPE_MAP: {
|
||||
return createJavaMapFromONNX(jniEnv, api, allocator, onnxValue);
|
||||
}
|
||||
case ONNX_TYPE_SPARSETENSOR: {
|
||||
return createJavaSparseTensorFromONNX(jniEnv, api, allocator, onnxValue);
|
||||
}
|
||||
case ONNX_TYPE_UNKNOWN:
|
||||
case ONNX_TYPE_OPAQUE:
|
||||
case ONNX_TYPE_OPTIONAL:
|
||||
case ONNX_TYPE_SPARSETENSOR:
|
||||
default: {
|
||||
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "These types are unsupported - ONNX_TYPE_UNKNOWN, ONNX_TYPE_OPAQUE, ONNX_TYPE_SPARSETENSOR.");
|
||||
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "These types are unsupported - ONNX_TYPE_UNKNOWN, ONNX_TYPE_OPAQUE, ONNX_TYPE_OPTIONAL.");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,6 +28,10 @@ GraphOptimizationLevel convertOptimizationLevel(jint level);
|
|||
|
||||
ExecutionMode convertExecutionMode(jint mode);
|
||||
|
||||
OrtSparseFormat convertToOrtSparseFormat(jint format);
|
||||
|
||||
jint convertFromOrtSparseFormat(OrtSparseFormat format);
|
||||
|
||||
jint convertFromONNXDataFormat(ONNXTensorElementDataType type);
|
||||
|
||||
ONNXTensorElementDataType convertToONNXDataFormat(jint type);
|
||||
|
@ -68,6 +72,8 @@ jdoubleArray createDoubleArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, Ort
|
|||
|
||||
jobject createJavaTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor);
|
||||
|
||||
jobject createJavaSparseTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor);
|
||||
|
||||
jobject createJavaSequenceFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* sequence);
|
||||
|
||||
jobject createJavaMapFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* map);
|
||||
|
|
|
@ -0,0 +1,534 @@
|
|||
/*
|
||||
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
#include <math.h>
|
||||
#include <stdlib.h>
|
||||
#include "onnxruntime/core/session/onnxruntime_c_api.h"
|
||||
#include "OrtJniUtil.h"
|
||||
#include "ai_onnxruntime_OnnxSparseTensor.h"
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSparseTensor
|
||||
* Method: getIndicesBuffer
|
||||
* Signature: (JJ)Ljava/nio/ByteBuffer;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getIndicesBuffer
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const OrtValue* ortValue = (const OrtValue*) handle;
|
||||
OrtSparseFormat format;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(ortValue, &format));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
enum OrtSparseIndicesFormat indicesFormat;
|
||||
switch (format) {
|
||||
case ORT_SPARSE_COO:
|
||||
indicesFormat = ORT_SPARSE_COO_INDICES;
|
||||
break;
|
||||
case ORT_SPARSE_CSRC:
|
||||
indicesFormat = ORT_SPARSE_CSR_OUTER_INDICES;
|
||||
break;
|
||||
case ORT_SPARSE_BLOCK_SPARSE:
|
||||
indicesFormat = ORT_SPARSE_BLOCK_SPARSE_INDICES;
|
||||
break;
|
||||
case ORT_SPARSE_UNDEFINED:
|
||||
default: {
|
||||
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Sparse format is ORT_SPARSE_UNDEFINED, cannot get indices");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
OrtTensorTypeAndShapeInfo* info = NULL;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(ortValue, indicesFormat, &info));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
size_t arrSize = 0;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &arrSize));
|
||||
if (code != ORT_OK) {
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
return NULL;
|
||||
}
|
||||
ONNXTensorElementDataType onnxTypeEnum;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxTypeEnum));
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
size_t typeSize = onnxTypeSize(onnxTypeEnum);
|
||||
size_t sizeBytes = arrSize * typeSize;
|
||||
|
||||
uint8_t* arr = NULL;
|
||||
size_t indices_size = 0;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndices(ortValue, indicesFormat, &indices_size, (const void**)&arr));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (indices_size != arrSize) {
|
||||
throwOrtException(jniEnv, convertErrorCode(ORT_RUNTIME_EXCEPTION), "Unexpected size");
|
||||
return NULL;
|
||||
} else {
|
||||
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSparseTensor
|
||||
* Method: getInnerIndicesBuffer
|
||||
* Signature: (JJ)Ljava/nio/ByteBuffer;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getInnerIndicesBuffer
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const OrtValue* ortValue = (const OrtValue*) handle;
|
||||
OrtSparseFormat format;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(ortValue, &format));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
enum OrtSparseIndicesFormat indicesFormat;
|
||||
switch (format) {
|
||||
case ORT_SPARSE_CSRC:
|
||||
indicesFormat = ORT_SPARSE_CSR_INNER_INDICES;
|
||||
break;
|
||||
case ORT_SPARSE_COO:
|
||||
case ORT_SPARSE_BLOCK_SPARSE:
|
||||
case ORT_SPARSE_UNDEFINED:
|
||||
default: {
|
||||
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED),
|
||||
"Sparse format is ORT_SPARSE_COO, ORT_SPARSE_BLOCK_SPARSE, or ORT_SPARSE_UNDEFINED, inner indices are not defined.");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
OrtTensorTypeAndShapeInfo* info = NULL;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(ortValue, indicesFormat, &info));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
size_t arrSize = 0;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &arrSize));
|
||||
if (code != ORT_OK) {
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
return NULL;
|
||||
}
|
||||
ONNXTensorElementDataType onnxTypeEnum;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxTypeEnum));
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
size_t typeSize = onnxTypeSize(onnxTypeEnum);
|
||||
size_t sizeBytes = arrSize * typeSize;
|
||||
|
||||
uint8_t* arr;
|
||||
size_t indices_size;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndices(ortValue, indicesFormat, &indices_size, (const void**)&arr));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (indices_size != arrSize) {
|
||||
throwOrtException(jniEnv, convertErrorCode(ORT_RUNTIME_EXCEPTION), "Unexpected size");
|
||||
return NULL;
|
||||
} else {
|
||||
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSparseTensor
|
||||
* Method: getValuesBuffer
|
||||
* Signature: (JJ)Ljava/nio/ByteBuffer;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getValuesBuffer
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const OrtValue* ortValue = (const OrtValue*) handle;
|
||||
OrtSparseFormat format;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(ortValue, &format));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
switch (format) {
|
||||
case ORT_SPARSE_COO:
|
||||
case ORT_SPARSE_CSRC:
|
||||
case ORT_SPARSE_BLOCK_SPARSE: {
|
||||
OrtTensorTypeAndShapeInfo* info = NULL;
|
||||
checkOrtStatus(jniEnv, api, api->GetSparseTensorValuesTypeAndShape(ortValue, &info));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
size_t arrSize = 0;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &arrSize));
|
||||
if (code != ORT_OK) {
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
return NULL;
|
||||
}
|
||||
ONNXTensorElementDataType onnxTypeEnum;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxTypeEnum));
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
size_t typeSize = onnxTypeSize(onnxTypeEnum);
|
||||
size_t sizeBytes = arrSize * typeSize;
|
||||
|
||||
uint8_t* arr = NULL;
|
||||
checkOrtStatus(jniEnv, api, api->GetSparseTensorValues(ortValue, (const void**)&arr));
|
||||
|
||||
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes);
|
||||
}
|
||||
case ORT_SPARSE_UNDEFINED:
|
||||
default: {
|
||||
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED),
|
||||
"Sparse format is ORT_SPARSE_UNDEFINED, cannot get data");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSparseTensor
|
||||
* Method: getInnerIndicesShape
|
||||
* Signature: (JJ)[J;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getInnerIndicesShape
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const OrtValue* value = (const OrtValue*) handle;
|
||||
|
||||
// Extract the info
|
||||
OrtTensorTypeAndShapeInfo* info;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(value, ORT_SPARSE_CSR_INNER_INDICES, &info));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Extract the shape
|
||||
size_t numDim = 0;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &numDim));
|
||||
if (code != ORT_OK) {
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
return NULL;
|
||||
}
|
||||
int64_t* dimensions = malloc(sizeof(int64_t) * numDim);
|
||||
if (dimensions == NULL) {
|
||||
throwOrtException(jniEnv, convertErrorCode(ORT_FAIL), "Out of memory when trying to allocate dimensions array");
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
return NULL;
|
||||
}
|
||||
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
|
||||
// Free the info
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
if (code != ORT_OK) {
|
||||
free((void*)dimensions);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Create the long array for the shape.
|
||||
jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim));
|
||||
(*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions);
|
||||
|
||||
// Free the dimensions array
|
||||
free((void*)dimensions);
|
||||
|
||||
return shape;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSparseTensor
|
||||
* Method: getIndicesShape
|
||||
* Signature: (JJ)[J;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getIndicesShape
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const OrtValue* value = (const OrtValue*) handle;
|
||||
|
||||
// Get the indices format
|
||||
OrtSparseFormat format;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(value, &format));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
enum OrtSparseIndicesFormat indicesFormat;
|
||||
switch (format) {
|
||||
case ORT_SPARSE_CSRC:
|
||||
indicesFormat = ORT_SPARSE_CSR_OUTER_INDICES;
|
||||
break;
|
||||
case ORT_SPARSE_COO:
|
||||
indicesFormat = ORT_SPARSE_COO_INDICES;
|
||||
break;
|
||||
case ORT_SPARSE_BLOCK_SPARSE:
|
||||
indicesFormat = ORT_SPARSE_BLOCK_SPARSE_INDICES;
|
||||
break;
|
||||
case ORT_SPARSE_UNDEFINED:
|
||||
default: {
|
||||
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED),
|
||||
"Sparse format is ORT_SPARSE_UNDEFINED, indices are not defined.");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
// Extract the info
|
||||
OrtTensorTypeAndShapeInfo* info;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(value, indicesFormat, &info));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Extract the shape
|
||||
size_t numDim = 0;
|
||||
code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &numDim));
|
||||
if (code != ORT_OK) {
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
return NULL;
|
||||
}
|
||||
int64_t* dimensions = malloc(sizeof(int64_t) * numDim);
|
||||
if (dimensions == NULL) {
|
||||
throwOrtException(jniEnv, convertErrorCode(ORT_FAIL), "Out of memory when trying to allocate dimensions array");
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
return NULL;
|
||||
}
|
||||
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
|
||||
// Free the info
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
if (code != ORT_OK) {
|
||||
free((void*)dimensions);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Create the long array for the shape.
|
||||
jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim));
|
||||
(*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions);
|
||||
// Free the dimensions array
|
||||
free((void*)dimensions);
|
||||
|
||||
return shape;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSparseTensor
|
||||
* Method: getValuesShape
|
||||
* Signature: (JJ)[J;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getValuesShape
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const OrtValue* value = (const OrtValue*) handle;
|
||||
|
||||
// Extract the info
|
||||
OrtTensorTypeAndShapeInfo* info;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetSparseTensorValuesTypeAndShape(value,&info));
|
||||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Extract the shape
|
||||
size_t numDim = 0;
|
||||
code = checkOrtStatus(jniEnv,api,api->GetDimensionsCount(info,&numDim));
|
||||
if (code != ORT_OK) {
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
return NULL;
|
||||
}
|
||||
int64_t* dimensions = malloc(sizeof(int64_t)*numDim);
|
||||
if (dimensions == NULL) {
|
||||
throwOrtException(jniEnv, convertErrorCode(ORT_FAIL), "Out of memory when trying to allocate dimensions array");
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
return NULL;
|
||||
}
|
||||
code = checkOrtStatus(jniEnv,api,api->GetDimensions(info, dimensions, numDim));
|
||||
// Free the info
|
||||
api->ReleaseTensorTypeAndShapeInfo(info);
|
||||
if (code != ORT_OK) {
|
||||
free((void*)dimensions);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Create the long array for the shape.
|
||||
jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim));
|
||||
(*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions);
|
||||
|
||||
// Free the dimensions array
|
||||
free((void*)dimensions);
|
||||
|
||||
return shape;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSparseTensor
|
||||
* Method: close
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxSparseTensor_close(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
api->ReleaseValue((OrtValue*)handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSparseTensor
|
||||
* Method: createCSRCSparseTensorFromBuffer
|
||||
* Signature: (JJLjava/nio/Buffer;IJLjava/nio/Buffer;IJLjava/nio/Buffer;IJ[J[JI)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxSparseTensor_createCSRCSparseTensorFromBuffer
|
||||
(JNIEnv * jniEnv, jclass cls, jlong apiHandle, jlong allocatorHandle,
|
||||
jobject indicesBuffer, jint indicesBufferPos, jlong indicesBufferSize,
|
||||
jobject innerIndicesBuffer, jint innerIndicesBufferPos, jlong innerIndicesBufferSize,
|
||||
jobject dataBuffer, jint dataBufferPos,
|
||||
jlongArray denseShape, jlongArray valuesShape,
|
||||
jint onnxTypeJava) {
|
||||
(void) cls; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
const OrtMemoryInfo* allocatorInfo;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator, &allocatorInfo));
|
||||
if (code != ORT_OK) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Convert types to ONNX C enums
|
||||
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
|
||||
|
||||
// Extract the buffers
|
||||
char* indicesBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, indicesBuffer);
|
||||
char* innerIndicesBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, innerIndicesBuffer);
|
||||
char* dataBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, dataBuffer);
|
||||
// Increment by bufferPos bytes
|
||||
indicesBufferArr = indicesBufferArr + indicesBufferPos;
|
||||
innerIndicesBufferArr = innerIndicesBufferArr + innerIndicesBufferPos;
|
||||
dataBufferArr = dataBufferArr + dataBufferPos;
|
||||
|
||||
// Extract the dense shape information
|
||||
jboolean mkCopy;
|
||||
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, denseShape, &mkCopy);
|
||||
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, denseShape);
|
||||
|
||||
// Extract the value shape
|
||||
jlong* valuesShapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, valuesShape, &mkCopy);
|
||||
jsize valuesShapeLen = (*jniEnv)->GetArrayLength(jniEnv, valuesShape);
|
||||
|
||||
// Create the OrtValue
|
||||
OrtValue* ortValue = NULL;
|
||||
code = checkOrtStatus(jniEnv, api, api->CreateSparseTensorWithValuesAsOrtValue(allocatorInfo, dataBufferArr,
|
||||
(int64_t*) shapeArr, shapeLen, (int64_t*) valuesShapeArr, valuesShapeLen, onnxType, &ortValue));
|
||||
// Release shapes
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv, denseShape, shapeArr, JNI_ABORT);
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv, valuesShape, valuesShapeArr, JNI_ABORT);
|
||||
if (code != ORT_OK) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Fill it with indices
|
||||
code = checkOrtStatus(jniEnv, api, api->UseCsrIndices(ortValue,
|
||||
(int64_t *) innerIndicesBufferArr, innerIndicesBufferSize,
|
||||
(int64_t *) indicesBufferArr, indicesBufferSize));
|
||||
if (code != ORT_OK) {
|
||||
return 0;
|
||||
} else {
|
||||
// Return the pointer to the OrtValue
|
||||
return (jlong) ortValue;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OnnxSparseTensor
|
||||
* Method: createSparseTensorFromBuffer
|
||||
* Signature: (JJLjava/nio/Buffer;IJLjava/nio/Buffer;IJ[J[J[JII)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxSparseTensor_createSparseTensorFromBuffer
|
||||
(JNIEnv * jniEnv, jclass cls, jlong apiHandle, jlong allocatorHandle,
|
||||
jobject indicesBuffer, jint indicesBufferPos, jlong indicesBufferSize,
|
||||
jobject dataBuffer, jint dataBufferPos,
|
||||
jlongArray denseShape, jlongArray indicesShape, jlongArray valuesShape,
|
||||
jint onnxTypeJava, jint sparsityTypeJava) {
|
||||
(void) cls; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
const OrtMemoryInfo* allocatorInfo;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator, &allocatorInfo));
|
||||
if (code != ORT_OK) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Convert types to ONNX C enums
|
||||
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
|
||||
OrtSparseFormat sparsityType = convertToOrtSparseFormat(sparsityTypeJava);
|
||||
|
||||
// Extract the buffers
|
||||
char* indicesBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, indicesBuffer);
|
||||
char* dataBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, dataBuffer);
|
||||
// Increment by bufferPos bytes
|
||||
indicesBufferArr = indicesBufferArr + indicesBufferPos;
|
||||
dataBufferArr = dataBufferArr + dataBufferPos;
|
||||
|
||||
// Extract the dense shape information
|
||||
jboolean mkCopy;
|
||||
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, denseShape, &mkCopy);
|
||||
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, denseShape);
|
||||
|
||||
// Extract the value shape
|
||||
jlong* valuesShapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, valuesShape, &mkCopy);
|
||||
jsize valuesShapeLen = (*jniEnv)->GetArrayLength(jniEnv, valuesShape);
|
||||
|
||||
// Create the OrtValue
|
||||
OrtValue* ortValue = NULL;
|
||||
code = checkOrtStatus(jniEnv, api, api->CreateSparseTensorWithValuesAsOrtValue(allocatorInfo, dataBufferArr,
|
||||
(int64_t*) shapeArr, shapeLen, (int64_t*) valuesShapeArr, valuesShapeLen, onnxType, &ortValue));
|
||||
|
||||
// Release shapes
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv, denseShape, shapeArr, JNI_ABORT);
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv, valuesShape, valuesShapeArr, JNI_ABORT);
|
||||
if (code != ORT_OK) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Fill it with indices
|
||||
switch (sparsityType) {
|
||||
case ORT_SPARSE_COO: {
|
||||
// The cast is because we compute the offset in bytes in Java.
|
||||
code = checkOrtStatus(jniEnv, api, api->UseCooIndices(ortValue, (int64_t *) indicesBufferArr,
|
||||
indicesBufferSize));
|
||||
break;
|
||||
}
|
||||
case ORT_SPARSE_BLOCK_SPARSE: {
|
||||
// Extract the indices shape
|
||||
jlong* indicesShapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, indicesShape, &mkCopy);
|
||||
jsize indicesShapeLen = (*jniEnv)->GetArrayLength(jniEnv, indicesShape);
|
||||
|
||||
// The cast is because we compute the offset in bytes in Java.
|
||||
code = checkOrtStatus(jniEnv, api, api->UseBlockSparseIndices(ortValue, (int64_t *) indicesShapeArr,
|
||||
indicesShapeLen, (int32_t *) indicesBufferArr));
|
||||
|
||||
// Release the indices shape
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv, indicesShape, indicesShapeArr, JNI_ABORT);
|
||||
break;
|
||||
}
|
||||
case ORT_SPARSE_CSRC:
|
||||
case ORT_SPARSE_UNDEFINED: {
|
||||
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED),
|
||||
"These types are unsupported by this method - ORT_SPARSE_CSRC, ORT_SPARSE_UNDEFINED");
|
||||
code = ORT_NOT_IMPLEMENTED;
|
||||
}
|
||||
}
|
||||
if (code != ORT_OK) {
|
||||
return 0;
|
||||
} else {
|
||||
// Return the pointer to the OrtValue
|
||||
return (jlong) ortValue;
|
||||
}
|
||||
}
|
|
@ -59,6 +59,10 @@ public class InferenceTest {
|
|||
|
||||
private static final OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
|
||||
public static Path getResourcePath(String path) {
|
||||
return new File(InferenceTest.class.getResource(path).getFile()).toPath();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void environmentTest() {
|
||||
// Checks that the environment instance is the same.
|
||||
|
|
|
@ -0,0 +1,450 @@
|
|||
/*
|
||||
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import static ai.onnxruntime.InferenceTest.getResourcePath;
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import java.nio.Buffer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class SparseTensorTest {
|
||||
|
||||
@Test
|
||||
public void testCSRC() throws OrtException {
|
||||
String modelPath = getResourcePath("/generic_sparse_to_dense_matmul.onnx").toString();
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
Map<String, OnnxTensorLike> inputMap = new HashMap<>();
|
||||
|
||||
OnnxTensor denseIdMatrix = makeIdentityMatrix(env, 3);
|
||||
long[] shape = new long[] {3, 3};
|
||||
/*
|
||||
* Sparse matrix:
|
||||
* [
|
||||
* 0 1 0
|
||||
* 1 0 1
|
||||
* 4 0 6
|
||||
* ]
|
||||
*/
|
||||
LongBuffer outerIndices =
|
||||
ByteBuffer.allocateDirect(4 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
|
||||
outerIndices.put(0);
|
||||
outerIndices.put(1);
|
||||
outerIndices.put(3);
|
||||
outerIndices.put(5);
|
||||
outerIndices.rewind();
|
||||
LongBuffer innerIndices =
|
||||
ByteBuffer.allocateDirect(5 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
|
||||
innerIndices.put(1);
|
||||
innerIndices.put(0);
|
||||
innerIndices.put(2);
|
||||
innerIndices.put(0);
|
||||
innerIndices.put(2);
|
||||
innerIndices.rewind();
|
||||
|
||||
FloatBuffer data =
|
||||
ByteBuffer.allocateDirect(5 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
|
||||
data.put(1);
|
||||
data.put(1);
|
||||
data.put(1);
|
||||
data.put(4);
|
||||
data.put(6);
|
||||
data.rewind();
|
||||
|
||||
OnnxSparseTensor.CSRCTensor csrcTensor =
|
||||
new OnnxSparseTensor.CSRCTensor(
|
||||
outerIndices, innerIndices, data, shape, OnnxJavaType.FLOAT, 5);
|
||||
OnnxSparseTensor tensor = OnnxSparseTensor.createSparseTensor(env, csrcTensor);
|
||||
|
||||
inputMap.put("sparse_A", tensor);
|
||||
inputMap.put("dense_B", denseIdMatrix);
|
||||
|
||||
OrtSession.Result result = session.run(inputMap);
|
||||
|
||||
OnnxTensor outputTensor = (OnnxTensor) result.get(0);
|
||||
assertArrayEquals(shape, outputTensor.getInfo().getShape());
|
||||
float[] output = outputTensor.getFloatBuffer().array();
|
||||
float[] expected = new float[] {0, 1, 0, 1, 0, 1, 4, 0, 6};
|
||||
assertArrayEquals(expected, output, 1e-6f);
|
||||
result.close();
|
||||
inputMap.clear();
|
||||
|
||||
// check that the get methods return new buffers which exist past the tensor lifetime.
|
||||
Buffer valuesOne = tensor.getValuesBuffer();
|
||||
Buffer valuesTwo = tensor.getValuesBuffer();
|
||||
Buffer indicesOne = tensor.getIndicesBuffer();
|
||||
Buffer indicesTwo = tensor.getIndicesBuffer();
|
||||
Buffer innerIndicesOne = tensor.getInnerIndicesBuffer();
|
||||
Buffer innerIndicesTwo = tensor.getInnerIndicesBuffer();
|
||||
tensor.close();
|
||||
assertEquals(valuesOne, valuesTwo);
|
||||
assertFalse(valuesOne == valuesTwo);
|
||||
assertEquals(indicesOne, indicesTwo);
|
||||
assertFalse(indicesOne == indicesTwo);
|
||||
assertEquals(innerIndicesOne, innerIndicesTwo);
|
||||
assertFalse(innerIndicesOne == innerIndicesTwo);
|
||||
|
||||
long[] rectangularShape = new long[] {2, 3};
|
||||
/*
|
||||
* Sparse matrix:
|
||||
* [
|
||||
* 1 0 3
|
||||
* 0 5 6
|
||||
* ]
|
||||
*/
|
||||
outerIndices =
|
||||
ByteBuffer.allocateDirect(3 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
|
||||
outerIndices.put(0);
|
||||
outerIndices.put(2);
|
||||
outerIndices.put(4);
|
||||
outerIndices.rewind();
|
||||
innerIndices =
|
||||
ByteBuffer.allocateDirect(4 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
|
||||
innerIndices.put(0);
|
||||
innerIndices.put(2);
|
||||
innerIndices.put(1);
|
||||
innerIndices.put(2);
|
||||
innerIndices.rewind();
|
||||
|
||||
data = ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
|
||||
data.put(1);
|
||||
data.put(3);
|
||||
data.put(5);
|
||||
data.put(6);
|
||||
data.rewind();
|
||||
|
||||
csrcTensor =
|
||||
new OnnxSparseTensor.CSRCTensor(
|
||||
outerIndices, innerIndices, data, rectangularShape, OnnxJavaType.FLOAT, 4);
|
||||
tensor = OnnxSparseTensor.createSparseTensor(env, csrcTensor);
|
||||
|
||||
assertArrayEquals(new long[] {3}, tensor.getIndicesShape());
|
||||
assertArrayEquals(new long[] {4}, tensor.getInnerIndicesShape());
|
||||
assertArrayEquals(new long[] {4}, tensor.getValuesShape());
|
||||
|
||||
inputMap.put("sparse_A", tensor);
|
||||
inputMap.put("dense_B", denseIdMatrix);
|
||||
|
||||
result = session.run(inputMap);
|
||||
|
||||
outputTensor = (OnnxTensor) result.get(0);
|
||||
assertArrayEquals(rectangularShape, outputTensor.getInfo().getShape());
|
||||
output = outputTensor.getFloatBuffer().array();
|
||||
expected = new float[] {1, 0, 3, 0, 5, 6};
|
||||
assertArrayEquals(expected, output, 1e-6f);
|
||||
result.close();
|
||||
tensor.close();
|
||||
inputMap.clear();
|
||||
denseIdMatrix.close();
|
||||
|
||||
denseIdMatrix = makeIdentityMatrix(env, 4);
|
||||
long[] vectorShape = new long[] {1, 4};
|
||||
/*
|
||||
* Sparse matrix:
|
||||
* [
|
||||
* 1 0 0 4
|
||||
* ]
|
||||
*/
|
||||
outerIndices =
|
||||
ByteBuffer.allocateDirect(2 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
|
||||
outerIndices.put(0);
|
||||
outerIndices.put(2);
|
||||
outerIndices.rewind();
|
||||
innerIndices =
|
||||
ByteBuffer.allocateDirect(2 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
|
||||
innerIndices.put(0);
|
||||
innerIndices.put(3);
|
||||
innerIndices.rewind();
|
||||
|
||||
data = ByteBuffer.allocateDirect(2 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
|
||||
data.put(1);
|
||||
data.put(4);
|
||||
data.rewind();
|
||||
|
||||
csrcTensor =
|
||||
new OnnxSparseTensor.CSRCTensor(
|
||||
outerIndices, innerIndices, data, vectorShape, OnnxJavaType.FLOAT, 2);
|
||||
tensor = OnnxSparseTensor.createSparseTensor(env, csrcTensor);
|
||||
|
||||
assertArrayEquals(new long[] {2}, tensor.getIndicesShape());
|
||||
assertArrayEquals(new long[] {2}, tensor.getInnerIndicesShape());
|
||||
assertArrayEquals(new long[] {2}, tensor.getValuesShape());
|
||||
|
||||
inputMap.put("sparse_A", tensor);
|
||||
inputMap.put("dense_B", denseIdMatrix);
|
||||
|
||||
result = session.run(inputMap);
|
||||
|
||||
outputTensor = (OnnxTensor) result.get(0);
|
||||
assertArrayEquals(vectorShape, outputTensor.getInfo().getShape());
|
||||
output = outputTensor.getFloatBuffer().array();
|
||||
expected = new float[] {1, 0, 0, 4};
|
||||
assertArrayEquals(expected, output, 1e-6f);
|
||||
result.close();
|
||||
tensor.close();
|
||||
inputMap.clear();
|
||||
denseIdMatrix.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCOO() throws OrtException {
|
||||
String modelPath = getResourcePath("/generic_sparse_to_dense_matmul.onnx").toString();
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
Map<String, OnnxTensorLike> inputMap = new HashMap<>();
|
||||
|
||||
OnnxTensor denseIdMatrix = makeIdentityMatrix(env, 3);
|
||||
long[] shape = new long[] {3, 3};
|
||||
/*
|
||||
* Sparse matrix:
|
||||
* [
|
||||
* 0 1 0
|
||||
* 1 0 1
|
||||
* 4 0 6
|
||||
* ]
|
||||
*/
|
||||
LongBuffer indices =
|
||||
ByteBuffer.allocateDirect(2 * 5 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
|
||||
indices.put(0);
|
||||
indices.put(1);
|
||||
indices.put(1);
|
||||
indices.put(0);
|
||||
indices.put(1);
|
||||
indices.put(2);
|
||||
indices.put(2);
|
||||
indices.put(0);
|
||||
indices.put(2);
|
||||
indices.put(2);
|
||||
indices.rewind();
|
||||
|
||||
FloatBuffer data =
|
||||
ByteBuffer.allocateDirect(5 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
|
||||
data.put(1);
|
||||
data.put(1);
|
||||
data.put(1);
|
||||
data.put(4);
|
||||
data.put(6);
|
||||
data.rewind();
|
||||
|
||||
OnnxSparseTensor.COOTensor cooTensor =
|
||||
new OnnxSparseTensor.COOTensor(
|
||||
indices, new long[] {5, 2}, data, shape, OnnxJavaType.FLOAT, 5);
|
||||
OnnxSparseTensor tensor = OnnxSparseTensor.createSparseTensor(env, cooTensor);
|
||||
|
||||
inputMap.put("sparse_A", tensor);
|
||||
inputMap.put("dense_B", denseIdMatrix);
|
||||
|
||||
OrtSession.Result result = session.run(inputMap);
|
||||
|
||||
OnnxTensor outputTensor = (OnnxTensor) result.get(0);
|
||||
assertArrayEquals(shape, outputTensor.getInfo().getShape());
|
||||
float[] output = outputTensor.getFloatBuffer().array();
|
||||
float[] expected = new float[] {0, 1, 0, 1, 0, 1, 4, 0, 6};
|
||||
assertArrayEquals(expected, output, 1e-6f);
|
||||
result.close();
|
||||
tensor.close();
|
||||
inputMap.clear();
|
||||
|
||||
/* disabled as sparse_dense_matmul doesn't support COO tensors with 1d indices
|
||||
// Run the same tensor through, but using 1d indexing rather than 2d indexing
|
||||
indices = ByteBuffer.allocateDirect(5 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
|
||||
indices.put(1);
|
||||
indices.put(3);
|
||||
indices.put(5);
|
||||
indices.put(6);
|
||||
indices.put(8);
|
||||
indices.rewind();
|
||||
|
||||
cooTensor = new OnnxSparseTensor.COOTensor(indices, new long[]{5}, data, shape, OnnxJavaType.FLOAT, 5);
|
||||
tensor = OnnxSparseTensor.createSparseTensor(env, cooTensor);
|
||||
|
||||
inputMap.put("sparse_A", tensor);
|
||||
inputMap.put("dense_B", denseIdMatrix);
|
||||
|
||||
result = session.run(inputMap);
|
||||
|
||||
outputTensor = (OnnxTensor) result.get(0);
|
||||
assertArrayEquals(shape, outputTensor.getInfo().getShape());
|
||||
output = outputTensor.getFloatBuffer().array();
|
||||
assertArrayEquals(expected, output, 1e-6f);
|
||||
result.close();
|
||||
tensor.close();
|
||||
inputMap.clear();
|
||||
*/
|
||||
|
||||
long[] rectangularShape = new long[] {2, 3};
|
||||
/*
|
||||
* Sparse matrix:
|
||||
* [
|
||||
* 1 0 3
|
||||
* 0 5 6
|
||||
* ]
|
||||
*/
|
||||
indices =
|
||||
ByteBuffer.allocateDirect(2 * 4 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
|
||||
indices.put(0);
|
||||
indices.put(0);
|
||||
indices.put(0);
|
||||
indices.put(2);
|
||||
indices.put(1);
|
||||
indices.put(1);
|
||||
indices.put(1);
|
||||
indices.put(2);
|
||||
indices.rewind();
|
||||
|
||||
data = ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
|
||||
data.put(1);
|
||||
data.put(3);
|
||||
data.put(5);
|
||||
data.put(6);
|
||||
data.rewind();
|
||||
|
||||
cooTensor =
|
||||
new OnnxSparseTensor.COOTensor(
|
||||
indices, new long[] {4, 2}, data, rectangularShape, OnnxJavaType.FLOAT, 4);
|
||||
tensor = OnnxSparseTensor.createSparseTensor(env, cooTensor);
|
||||
|
||||
assertArrayEquals(new long[] {4, 2}, tensor.getIndicesShape());
|
||||
assertArrayEquals(new long[] {4}, tensor.getValuesShape());
|
||||
|
||||
inputMap.put("sparse_A", tensor);
|
||||
inputMap.put("dense_B", denseIdMatrix);
|
||||
|
||||
result = session.run(inputMap);
|
||||
|
||||
outputTensor = (OnnxTensor) result.get(0);
|
||||
assertArrayEquals(rectangularShape, outputTensor.getInfo().getShape());
|
||||
output = outputTensor.getFloatBuffer().array();
|
||||
expected = new float[] {1, 0, 3, 0, 5, 6};
|
||||
assertArrayEquals(expected, output, 1e-6f);
|
||||
result.close();
|
||||
tensor.close();
|
||||
inputMap.clear();
|
||||
denseIdMatrix.close();
|
||||
|
||||
denseIdMatrix = makeIdentityMatrix(env, 4);
|
||||
long[] vectorShape = new long[] {1, 4};
|
||||
/*
|
||||
* Sparse matrix:
|
||||
* [
|
||||
* 1
|
||||
* 0
|
||||
* 0
|
||||
* 4
|
||||
* ]
|
||||
*/
|
||||
indices = ByteBuffer.allocateDirect(4 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
|
||||
indices.put(0);
|
||||
indices.put(0);
|
||||
indices.put(0);
|
||||
indices.put(3);
|
||||
indices.rewind();
|
||||
|
||||
data = ByteBuffer.allocateDirect(2 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
|
||||
data.put(1);
|
||||
data.put(4);
|
||||
data.rewind();
|
||||
|
||||
cooTensor =
|
||||
new OnnxSparseTensor.COOTensor(
|
||||
indices, new long[] {2, 2}, data, vectorShape, OnnxJavaType.FLOAT, 2);
|
||||
tensor = OnnxSparseTensor.createSparseTensor(env, cooTensor);
|
||||
|
||||
assertArrayEquals(new long[] {2, 2}, tensor.getIndicesShape());
|
||||
assertArrayEquals(new long[] {2}, tensor.getValuesShape());
|
||||
|
||||
inputMap.put("sparse_A", tensor);
|
||||
inputMap.put("dense_B", denseIdMatrix);
|
||||
|
||||
result = session.run(inputMap);
|
||||
|
||||
outputTensor = (OnnxTensor) result.get(0);
|
||||
assertArrayEquals(vectorShape, outputTensor.getInfo().getShape());
|
||||
output = outputTensor.getFloatBuffer().array();
|
||||
expected = new float[] {1, 0, 0, 4};
|
||||
assertArrayEquals(expected, output, 1e-6f);
|
||||
result.close();
|
||||
tensor.close();
|
||||
inputMap.clear();
|
||||
denseIdMatrix.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCOOOutput() throws OrtException {
|
||||
String modelPath = getResourcePath("/sparse_initializer_as_output.onnx").toString();
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
Map<String, NodeInfo> outputs = session.getOutputInfo();
|
||||
assertEquals(1, outputs.size());
|
||||
|
||||
NodeInfo info = outputs.get("values");
|
||||
assertNotNull(info);
|
||||
assertTrue(info.getInfo() instanceof TensorInfo);
|
||||
|
||||
TensorInfo outputInfo = (TensorInfo) info.getInfo();
|
||||
assertArrayEquals(new long[] {3, 3}, outputInfo.getShape());
|
||||
assertEquals(
|
||||
TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, outputInfo.onnxType);
|
||||
assertEquals(OnnxJavaType.FLOAT, outputInfo.type);
|
||||
|
||||
OrtSession.Result result = session.run(Collections.emptyMap());
|
||||
OnnxValue output = result.get("values").get();
|
||||
|
||||
assertTrue(output instanceof OnnxSparseTensor);
|
||||
|
||||
OnnxSparseTensor sparseTensor = (OnnxSparseTensor) output;
|
||||
|
||||
assertEquals(OnnxSparseTensor.SparseTensorType.COO, sparseTensor.getSparseTensorType());
|
||||
|
||||
assertArrayEquals(new long[] {3}, sparseTensor.getIndicesShape());
|
||||
assertArrayEquals(new long[] {3}, sparseTensor.getValuesShape());
|
||||
assertArrayEquals(new long[] {3, 3}, sparseTensor.getInfo().getShape());
|
||||
|
||||
OnnxSparseTensor.SparseTensor<? extends Buffer> javaTensor = sparseTensor.getValue();
|
||||
|
||||
assertTrue(javaTensor instanceof OnnxSparseTensor.COOTensor);
|
||||
|
||||
OnnxSparseTensor.COOTensor cooTensor = (OnnxSparseTensor.COOTensor) javaTensor;
|
||||
|
||||
long[] indices = new long[3];
|
||||
cooTensor.getIndices().get(indices);
|
||||
float[] data = new float[3];
|
||||
((FloatBuffer) cooTensor.getValues()).get(data);
|
||||
|
||||
assertArrayEquals(new long[] {2, 3, 5}, indices);
|
||||
assertArrayEquals(
|
||||
new float[] {1.764052391052246f, 0.40015721321105957f, 0.978738009929657f}, data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static OnnxTensor makeIdentityMatrix(OrtEnvironment env, int size) throws OrtException {
|
||||
float[][] values = new float[size][size];
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
values[i][i] = 1.0f;
|
||||
}
|
||||
|
||||
return OnnxTensor.createTensor(env, values);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
dmitrism:Î
|
||||
F
|
||||
sparse_A
|
||||
dense_Bdense_YSpMM"SparseToDenseMatMul:
com.microsoftSpMMZ*
|
||||
sparse_AB
|
||||
A_dim_1
|
||||
inner_dimZ)
|
||||
dense_B
|
||||
|
||||
inner_dim
|
||||
B_dim_2b'
|
||||
dense_Y
|
||||
|
||||
A_dim_1
|
||||
B_dim_2B
|
||||
com.microsoft
|
Загрузка…
Ссылка в новой задаче