**Description**:

Adds support for creating and receiving sparse tensors in the ORT Java
API.

CSRC and COO tensors are tested as inputs, but there is no op which
accepts a block sparse tensor to test against. COO tensors are tested as
outputs, but there is no op which emits a CSRC or block sparse tensor to
test against.
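
As a rough illustration of the new API (not part of this diff), the sketch below builds a COO sparse tensor and feeds it to a session; the model path `model-with-sparse-input.onnx` and the input name `input` are placeholders for whatever the model actually declares.

```java
import ai.onnxruntime.*;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.util.Collections;

public final class SparseInputExample {
  public static void main(String[] args) throws OrtException {
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    try (OrtSession session = env.createSession("model-with-sparse-input.onnx")) {
      // A 3x3 matrix with two non-zero entries at (0,1) and (2,2).
      long[] denseShape = new long[] {3, 3};
      LongBuffer indices = LongBuffer.wrap(new long[] {0, 1, 2, 2}); // [numNonZero, 2] matrix
      long[] indicesShape = new long[] {2, 2};
      FloatBuffer values = FloatBuffer.wrap(new float[] {1.5f, -2.0f});
      OnnxSparseTensor.COOTensor coo =
          new OnnxSparseTensor.COOTensor(
              indices, indicesShape, values, denseShape, OnnxJavaType.FLOAT, 2);
      try (OnnxSparseTensor sparse = OnnxSparseTensor.createSparseTensor(env, coo);
          OrtSession.Result result = session.run(Collections.singletonMap("input", sparse))) {
        System.out.println(result.get(0).getInfo());
      }
    }
  }
}
```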

**Motivation and Context**
- Why is this change required? What problem does it solve? It addresses a
request to expose ORT's sparse tensor support in the Java API.

cc @yuslepukhin
This commit is contained in:
Adam Pocock 2022-11-22 13:29:24 -05:00 committed by GitHub
Parent 8b0e0f4927
Commit dd2c031d95
13 changed files with 2,218 additions and 106 deletions


@ -0,0 +1,920 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import static ai.onnxruntime.OnnxTensor.fp16ToFloat;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Arrays;
/**
* A Java object wrapping an OnnxSparseTensor.
*
* <p>Sparse tensors support a variety of formats, and the {@link #getValue} method returns a
* different static inner class representing each type.
*/
public final class OnnxSparseTensor extends OnnxTensorLike {
private final SparseTensorType sparseTensorType;
// Held to prevent deallocation while used in native code.
private final Buffer indices;
private final LongBuffer innerIndices;
private final Buffer values;
/**
* Construct a sparse tensor from JNI.
*
* @param nativeHandle The tensor native handle.
* @param allocatorHandle The allocator handle.
* @param sparseType The sparsity type.
* @param info The tensor info.
*/
OnnxSparseTensor(long nativeHandle, long allocatorHandle, int sparseType, TensorInfo info) {
this(
nativeHandle,
allocatorHandle,
SparseTensorType.mapFromInt(sparseType),
info,
null,
null,
null);
}
/**
* Construct a COO or block sparse tensor.
*
* @param nativeHandle The tensor native handle.
* @param allocatorHandle The allocator handle.
* @param sparseType The sparsity type.
* @param info The tensor info.
* @param indices The indices buffer.
* @param values The data buffer.
*/
OnnxSparseTensor(
long nativeHandle,
long allocatorHandle,
SparseTensorType sparseType,
TensorInfo info,
Buffer indices,
Buffer values) {
this(nativeHandle, allocatorHandle, sparseType, info, indices, null, values);
}
/**
* Construct a sparse tensor.
*
* <p>If the tensor is COO or block sparse then innerIndices may be null.
*
* @param nativeHandle The tensor native handle.
* @param allocatorHandle The allocator handle.
* @param sparseType The sparsity type.
* @param info The tensor info.
* @param indices The indices buffer.
* @param innerIndices The inner indices buffer.
* @param values The data buffer.
*/
OnnxSparseTensor(
long nativeHandle,
long allocatorHandle,
SparseTensorType sparseType,
TensorInfo info,
Buffer indices,
LongBuffer innerIndices,
Buffer values) {
super(nativeHandle, allocatorHandle, info);
this.sparseTensorType = sparseType;
this.indices = indices;
this.innerIndices = innerIndices;
this.values = values;
}
/**
* Creates a Sparse Tensor in ORT from the Java side representation.
*
* @param env The OrtEnvironment.
* @param tensor The Java side representation.
* @param <T> The buffer type.
* @return The sparse tensor in ORT.
* @throws OrtException If the tensor could not be created or was invalid.
*/
public static <T extends Buffer> OnnxSparseTensor createSparseTensor(
OrtEnvironment env, SparseTensor<T> tensor) throws OrtException {
return createSparseTensor(env, env.defaultAllocator, tensor);
}
static <T extends Buffer> OnnxSparseTensor createSparseTensor(
OrtEnvironment env, OrtAllocator allocator, SparseTensor<T> tensor) throws OrtException {
if (!allocator.isClosed()) {
TensorInfo info = TensorInfo.constructFromSparseTensor(tensor);
OnnxJavaType indicesType = tensor.getIndicesType();
OrtUtil.BufferTuple indicesTuple = OrtUtil.prepareBuffer(tensor.getIndices(), indicesType);
OrtUtil.BufferTuple valuesTuple = OrtUtil.prepareBuffer(tensor.getValues(), info.type);
if (!((indicesTuple.data instanceof LongBuffer)
|| (indicesTuple.data instanceof IntBuffer))) {
throw new IllegalStateException(
"Unexpected type of indices buffer, found "
+ indicesTuple.data.getClass()
+ ", expected IntBuffer or LongBuffer");
}
// Replace with a type switch when using JDK 17+.
switch (tensor.getSparsityType()) {
case COO:
case BLOCK_SPARSE:
return new OnnxSparseTensor(
createSparseTensorFromBuffer(
OnnxRuntime.ortApiHandle,
allocator.handle,
indicesTuple.data,
indicesTuple.pos,
indicesTuple.size,
valuesTuple.data,
valuesTuple.pos,
info.shape,
tensor.getIndicesShape(),
tensor.getValuesShape(),
info.onnxType.value,
tensor.getSparsityType().value),
allocator.handle,
tensor.getSparsityType(),
info,
indicesTuple.data,
valuesTuple.data);
case CSRC:
OrtUtil.BufferTuple innerIndicesTuple =
OrtUtil.prepareBuffer(((CSRCTensor) tensor).getInnerIndices(), indicesType);
return new OnnxSparseTensor(
createCSRCSparseTensorFromBuffer(
OnnxRuntime.ortApiHandle,
allocator.handle,
indicesTuple.data,
indicesTuple.pos,
indicesTuple.size,
innerIndicesTuple.data,
innerIndicesTuple.pos,
innerIndicesTuple.size,
valuesTuple.data,
valuesTuple.pos,
info.shape,
tensor.getValuesShape(),
info.onnxType.value),
allocator.handle,
tensor.getSparsityType(),
info,
indicesTuple.data,
(LongBuffer) innerIndicesTuple.data,
valuesTuple.data);
case UNDEFINED:
default:
throw new IllegalArgumentException("Cannot create an UNDEFINED sparse tensor.");
}
} else {
throw new IllegalStateException(
"Trying to create an OnnxSparseTensor on a closed OrtAllocator.");
}
}
@Override
public OnnxValueType getType() {
return OnnxValueType.ONNX_TYPE_SPARSETENSOR;
}
@Override
public SparseTensor<? extends Buffer> getValue() throws OrtException {
Buffer buffer = getValuesBuffer();
long[] indicesShape = getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
switch (sparseTensorType) {
case COO:
return new COOTensor(
(LongBuffer) getIndicesBuffer(),
indicesShape,
buffer,
info.shape,
info.type,
buffer.remaining());
case CSRC:
return new CSRCTensor(
(LongBuffer) getIndicesBuffer(),
getInnerIndicesBuffer(),
buffer,
info.shape,
info.type,
buffer.remaining());
case BLOCK_SPARSE:
long[] valuesShape = getValuesShape(OnnxRuntime.ortApiHandle, nativeHandle);
return new BlockSparseTensor(
(IntBuffer) getIndicesBuffer(),
indicesShape,
buffer,
valuesShape,
info.shape,
info.type,
buffer.remaining());
case UNDEFINED:
default:
throw new IllegalStateException("Undefined sparsity type in this sparse tensor.");
}
}
@Override
public void close() {
close(OnnxRuntime.ortApiHandle, nativeHandle);
}
/**
* Returns the type of this OnnxSparseTensor.
*
* @return The sparsity type.
*/
public SparseTensorType getSparseTensorType() {
return sparseTensorType;
}
/**
* Gets a copy of the indices.
*
* <p>These are the outer indices if it's a CSRC sparse tensor.
*
* <p>It's a {@link LongBuffer} if COO or CSRC, and {@link IntBuffer} if Block Sparse.
*
* @return The indices.
*/
public Buffer getIndicesBuffer() {
switch (sparseTensorType) {
case COO:
case CSRC:
{
LongBuffer longBuf =
getIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle)
.order(ByteOrder.nativeOrder())
.asLongBuffer();
LongBuffer output = LongBuffer.allocate(longBuf.capacity());
output.put(longBuf);
output.rewind();
return output;
}
case BLOCK_SPARSE:
{
IntBuffer intBuf =
getIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle)
.order(ByteOrder.nativeOrder())
.asIntBuffer();
IntBuffer output = IntBuffer.allocate(intBuf.capacity());
output.put(intBuf);
output.rewind();
return output;
}
case UNDEFINED:
default:
throw new IllegalStateException("UNDEFINED sparse tensor type.");
}
}
/**
* Gets a copy of the inner indices in a CSRC sparse tensor.
*
* <p>Throws {@link IllegalStateException} if called on a different sparse tensor type.
*
* @return The inner indices.
*/
public LongBuffer getInnerIndicesBuffer() {
if (sparseTensorType == SparseTensorType.CSRC) {
LongBuffer buf =
getInnerIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle)
.order(ByteOrder.nativeOrder())
.asLongBuffer();
LongBuffer output = LongBuffer.allocate(buf.capacity());
output.put(buf);
output.rewind();
return output;
} else {
throw new IllegalStateException(
"Inner indices are only available for CSRC sparse tensors, this sparse tensor is "
+ sparseTensorType);
}
}
/**
* Gets a copy of the data buffer.
*
* <p>As with {@link OnnxTensor} fp16 values are upcast into fp32 and returned as a {@link
* FloatBuffer}.
*
* @return The data buffer.
*/
public Buffer getValuesBuffer() {
ByteBuffer buffer =
getValuesBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder());
switch (info.type) {
case FLOAT:
if (info.onnxType == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
ShortBuffer shortBuffer = buffer.asShortBuffer();
int bufferCap = shortBuffer.capacity();
FloatBuffer output = FloatBuffer.allocate(bufferCap);
for (int i = 0; i < bufferCap; i++) {
output.put(fp16ToFloat(shortBuffer.get(i)));
}
output.rewind();
return output;
} else if (info.onnxType
== TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) {
throw new IllegalArgumentException("BFloat16 is not supported.");
} else {
// regular fp32
FloatBuffer floatBuf = buffer.asFloatBuffer();
FloatBuffer output = FloatBuffer.allocate(floatBuf.capacity());
output.put(floatBuf);
output.rewind();
return output;
}
case DOUBLE:
{
DoubleBuffer doubleBuf = buffer.asDoubleBuffer();
DoubleBuffer output = DoubleBuffer.allocate(doubleBuf.capacity());
output.put(doubleBuf);
output.rewind();
return output;
}
case INT16:
{
ShortBuffer shortBuf = buffer.asShortBuffer();
ShortBuffer output = ShortBuffer.allocate(shortBuf.capacity());
output.put(shortBuf);
output.rewind();
return output;
}
case INT32:
{
IntBuffer intBuf = buffer.asIntBuffer();
IntBuffer output = IntBuffer.allocate(intBuf.capacity());
output.put(intBuf);
output.rewind();
return output;
}
case INT64:
{
LongBuffer longBuf = buffer.asLongBuffer();
LongBuffer output = LongBuffer.allocate(longBuf.capacity());
output.put(longBuf);
output.rewind();
return output;
}
case BOOL:
case INT8:
case UINT8:
{
ByteBuffer output = ByteBuffer.allocate(buffer.capacity());
output.put(buffer);
output.rewind();
return output;
}
case STRING:
throw new IllegalStateException("Unsupported data type String");
case UNKNOWN:
default:
throw new IllegalStateException("Unsupported data type");
}
}
/**
* Gets the shape of the (outer) indices.
*
* @return The indices shape.
*/
public long[] getIndicesShape() {
return getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
}
/**
* Gets the shape of the inner indices in a CSRC sparse tensor.
*
* @return The indices shape.
*/
public long[] getInnerIndicesShape() {
if (sparseTensorType == SparseTensorType.CSRC) {
return getInnerIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
} else {
throw new IllegalStateException(
"Inner indices are only available for CSRC sparse tensors, this sparse tensor is "
+ sparseTensorType);
}
}
/**
* Gets the shape of the values.
*
* @return The values shape.
*/
public long[] getValuesShape() {
return getValuesShape(OnnxRuntime.ortApiHandle, nativeHandle);
}
/**
* Gets the shape of the (outer) indices.
*
* @param apiHandle The OrtApi pointer.
* @param nativeHandle The OrtSparseTensor pointer.
* @return The indices shape.
*/
private native long[] getIndicesShape(long apiHandle, long nativeHandle);
/**
* Gets the shape of the inner indices.
*
* @param apiHandle The OrtApi pointer.
* @param nativeHandle The OrtSparseTensor pointer.
* @return The inner indices shape.
*/
private native long[] getInnerIndicesShape(long apiHandle, long nativeHandle);
/**
* Gets the shape of the values.
*
* @param apiHandle The OrtApi pointer.
* @param nativeHandle The OrtSparseTensor pointer.
* @return The values shape.
*/
private native long[] getValuesShape(long apiHandle, long nativeHandle);
/**
* Wraps the indices in a direct byte buffer.
*
* @param apiHandle The OrtApi pointer.
* @param nativeHandle The OrtSparseTensor pointer.
* @return A ByteBuffer wrapping the indices.
*/
private native ByteBuffer getIndicesBuffer(long apiHandle, long nativeHandle);
/**
* Wraps the inner indices in a direct byte buffer.
*
* @param apiHandle The OrtApi pointer.
* @param nativeHandle The OrtSparseTensor pointer.
* @return A ByteBuffer wrapping the inner indices.
*/
private native ByteBuffer getInnerIndicesBuffer(long apiHandle, long nativeHandle);
/**
* Wraps the data in a direct byte buffer.
*
* @param apiHandle The OrtApi pointer.
* @param nativeHandle The OrtSparseTensor pointer.
* @return A ByteBuffer wrapping the data.
*/
private native ByteBuffer getValuesBuffer(long apiHandle, long nativeHandle);
/**
* Closes the sparse tensor.
*
* @param apiHandle The OrtApi pointer.
* @param nativeHandle The OrtSparseTensor pointer.
*/
private native void close(long apiHandle, long nativeHandle);
/**
* Creates a CSRC sparse tensor.
*
* <p>The buffers must be kept alive for the lifetime of the ORT sparse tensor object.
*
* @param apiHandle The ORT API pointer.
* @param allocatorHandle The allocator pointer.
* @param indicesData The outer indices.
* @param indicesBufferPos The outer indices position in bytes.
* @param indicesBufferSize The outer indices buffer size in longs.
* @param innerIndicesData The inner indices.
* @param innerIndicesBufferPos The inner indices position in bytes.
* @param innerIndicesBufferSize The inner indices buffer size in longs.
* @param values The data.
* @param bufferPos The data position in bytes.
* @param denseShape The dense shape of the tensor.
* @param valuesShape The shape of the values (should be a vector).
* @param onnxType The type of the values.
* @return A pointer to an ORT sparse tensor value.
* @throws OrtException If the tensor could not be created.
*/
private static native long createCSRCSparseTensorFromBuffer(
long apiHandle,
long allocatorHandle,
Buffer indicesData,
int indicesBufferPos,
long indicesBufferSize,
Buffer innerIndicesData,
int innerIndicesBufferPos,
long innerIndicesBufferSize,
Buffer values,
int bufferPos,
long[] denseShape,
long[] valuesShape,
int onnxType)
throws OrtException;
/**
* Creates a COO or block sparse tensor.
*
* <p>The buffers must be kept alive for the lifetime of the ORT sparse tensor object.
*
* @param apiHandle The ORT API pointer.
* @param allocatorHandle The allocator pointer.
* @param indicesData The indices.
* @param indicesBufferPos The indices position in bytes.
* @param indicesBufferSize The indices buffer size in longs.
* @param values The data.
* @param bufferPos The data position in bytes.
* @param denseShape The dense shape of the tensor.
* @param indicesShape The shape of the indices (a vector or matrix for COO, and a matrix for
* block sparse).
* @param valuesShape The shape of the values (a vector for COO, and a block shape for block
* sparse).
* @param onnxType The type of the values.
* @param sparsityType The sparsity type.
* @return A pointer to an ORT sparse tensor value.
* @throws OrtException If the tensor could not be created.
*/
private static native long createSparseTensorFromBuffer(
long apiHandle,
long allocatorHandle,
Buffer indicesData,
int indicesBufferPos,
long indicesBufferSize,
Buffer values,
int bufferPos,
long[] denseShape,
long[] indicesShape,
long[] valuesShape,
int onnxType,
int sparsityType)
throws OrtException;
/**
* The type of the sparse tensor.
*
* <p>Should be synchronized with OrtSparseFormat in the C API.
*/
public enum SparseTensorType {
/** Undefined sparse tensor. */
UNDEFINED(0),
/** COO sparse tensor. */
COO(1),
/** CSR or CSC sparse tensor. */
CSRC(2),
/** Block sparse tensor. */
BLOCK_SPARSE(4);
/** The int value mirroring OrtSparseFormat. */
public final int value;
private static final SparseTensorType[] values = new SparseTensorType[5];
static {
values[0] = UNDEFINED;
values[1] = COO;
values[2] = CSRC;
values[3] = UNDEFINED;
values[4] = BLOCK_SPARSE;
}
SparseTensorType(int value) {
this.value = value;
}
/**
* Maps from an int in native land into a SparseTensorType instance.
*
* @param value The value to lookup.
* @return The enum instance.
*/
public static SparseTensorType mapFromInt(int value) {
if ((value > 0) && (value < values.length)) {
return values[value];
} else {
return UNDEFINED;
}
}
}
/**
* Abstract base class for Java sparse tensors.
*
* <p>Will be sealed to {@link COOTensor}, {@link CSRCTensor} and {@link BlockSparseTensor} one
* day.
*/
public abstract static class SparseTensor<T extends Buffer> {
private final long[] indicesShape;
private final long[] valuesShape;
private final long[] denseShape;
private final OnnxJavaType type;
private final long numNonZero;
final T indices;
final Buffer values;
SparseTensor(
T indices,
long[] indicesShape,
Buffer values,
long[] valuesShape,
long[] denseShape,
OnnxJavaType type,
long numNonZero) {
this.indices = indices;
this.indicesShape = indicesShape;
this.values = values;
this.valuesShape = valuesShape;
this.denseShape = denseShape;
this.type = type;
this.numNonZero = numNonZero;
if (values.remaining() != numNonZero) {
throw new IllegalArgumentException(
"Expected numNonZero and data.remaining to be equal, found "
+ numNonZero
+ " and "
+ values.remaining()
+ " respectively");
}
if (type == OnnxJavaType.STRING) {
throw new IllegalArgumentException("String SparseTensors are not supported.");
}
}
/**
* Gets the dense shape of the sparse tensor.
*
* @return The sparse tensor shape.
*/
public long[] getDenseShape() {
return denseShape;
}
/**
* The data type of the sparse tensor.
*
* @return The sparse tensor data type.
*/
public OnnxJavaType getType() {
return type;
}
/**
* The number of non-zero elements.
*
* @return The number of non-zero elements.
*/
public long getNumNonZeroElements() {
return numNonZero;
}
/**
* Get the indices buffer.
*
* @return The indices buffer.
*/
public T getIndices() {
return indices;
}
/**
* Get the value buffer.
*
* @return The value buffer.
*/
public Buffer getValues() {
return values;
}
/**
* Gets the shape of the values of the sparse tensor.
*
* @return The sparse tensor value shape.
*/
public long[] getValuesShape() {
return valuesShape;
}
/**
* Gets the shape of the indices of the sparse tensor.
*
* @return The sparse tensor indices shape.
*/
public long[] getIndicesShape() {
return indicesShape;
}
/**
* The sparsity type of the sparse tensor.
*
* @return The sparse tensor sparsity type.
*/
public abstract SparseTensorType getSparsityType();
/**
* The indices type of the sparse tensor.
*
* <p>Only {@link OnnxJavaType#INT32} and {@link OnnxJavaType#INT64} are supported.
*
* @return The sparse tensor indices type.
*/
public abstract OnnxJavaType getIndicesType();
}
/** The Java side representation of a COO sparse tensor. */
public static final class COOTensor extends SparseTensor<LongBuffer> {
/**
* Creates a COO sparse tensor suitable for constructing an ORT Sparse Tensor.
*
* @param indices The indices. Should be a 1d vector or a 2d [numNonZero, dimensions] matrix.
* @param indicesShape The shape of the indices.
* @param values The data.
* @param denseShape The dense shape.
* @param type The data type.
* @param numNonZero The number of non-zero elements.
*/
public COOTensor(
LongBuffer indices,
long[] indicesShape,
Buffer values,
long[] denseShape,
OnnxJavaType type,
long numNonZero) {
super(indices, indicesShape, values, new long[] {numNonZero}, denseShape, type, numNonZero);
if ((indicesShape.length > 2)
|| (indicesShape.length == 0)
|| (indicesShape[0] != numNonZero)) {
throw new IllegalArgumentException(
"Invalid indices shape, expected [numNonZero, dimension] or [numNonZero] found "
+ Arrays.toString(indicesShape));
}
long elementCount = OrtUtil.elementCount(indicesShape);
if (elementCount != indices.remaining()) {
throw new IllegalArgumentException(
"Unexpected number of indices found in buffer, expected "
+ elementCount
+ " found "
+ indices.remaining());
}
if (values.remaining() != numNonZero) {
throw new IllegalArgumentException(
"Expected data.remaining() - "
+ values.remaining()
+ " to equal numNonZero - "
+ numNonZero);
}
}
@Override
public OnnxJavaType getIndicesType() {
return OnnxJavaType.INT64;
}
@Override
public SparseTensorType getSparsityType() {
return SparseTensorType.COO;
}
}
/** The Java side representation of a CSRC sparse tensor. */
public static final class CSRCTensor extends SparseTensor<LongBuffer> {
private final LongBuffer innerIndices;
/**
* Creates a CSRC sparse tensor suitable for constructing an ORT Sparse Tensor.
*
* @param outerIndices The outer indices.
* @param innerIndices The inner indices.
* @param values The data.
* @param denseShape The dense shape.
* @param type The data type.
* @param numNonZero The number of non-zero elements.
*/
public CSRCTensor(
LongBuffer outerIndices,
LongBuffer innerIndices,
Buffer values,
long[] denseShape,
OnnxJavaType type,
long numNonZero) {
super(
outerIndices,
new long[] {outerIndices.remaining()},
values,
new long[] {numNonZero},
denseShape,
type,
numNonZero);
this.innerIndices = innerIndices;
long expectedRows = denseShape[0] + 1;
if (outerIndices.remaining() != expectedRows) {
throw new IllegalArgumentException(
"Outer indices should be equal to the number of rows + 1 in the dense shape, found "
+ outerIndices.remaining()
+ ", expected "
+ expectedRows);
}
if (innerIndices.remaining() != numNonZero) {
throw new IllegalArgumentException(
"Inner indices should be equal to the number of non-zero elements, found "
+ innerIndices.remaining()
+ ", expected "
+ numNonZero);
}
}
/**
* Gets the shape of the inner indices.
*
* @return The inner indices shape.
*/
public long[] getInnerIndicesShape() {
return new long[] {innerIndices.remaining()};
}
/**
* Gets the inner indices buffer.
*
* @return The inner indices buffer.
*/
public LongBuffer getInnerIndices() {
return innerIndices;
}
@Override
public OnnxJavaType getIndicesType() {
return OnnxJavaType.INT64;
}
@Override
public SparseTensorType getSparsityType() {
return SparseTensorType.CSRC;
}
}
/** The Java side representation of a block sparse tensor. */
public static final class BlockSparseTensor extends SparseTensor<IntBuffer> {
/**
* Construct a block sparse tensor.
*
* @param indices The indices.
* @param indicesShape The shape of the indices.
* @param values The data.
* @param valuesShape The shape of the data.
* @param denseShape The dense shape.
* @param type The data type.
* @param numNonZero The number of non-zero elements.
*/
public BlockSparseTensor(
IntBuffer indices,
long[] indicesShape,
Buffer values,
long[] valuesShape,
long[] denseShape,
OnnxJavaType type,
long numNonZero) {
super(indices, indicesShape, values, valuesShape, denseShape, type, numNonZero);
if (OrtUtil.elementCount(valuesShape) != numNonZero) {
throw new IllegalArgumentException(
"Expected "
+ numNonZero
+ " entries in the data shape, found "
+ Arrays.toString(valuesShape));
}
if (numNonZero != values.remaining()) {
throw new IllegalArgumentException(
"Expected " + numNonZero + " elements in the data buffer, found " + values.remaining());
}
if (OrtUtil.elementCount(indicesShape) != indices.remaining()) {
throw new IllegalArgumentException(
"Expected "
+ OrtUtil.elementCount(indicesShape)
+ " elements in the indices buffer, found "
+ indices.remaining());
}
if (valuesShape.length < 3) {
throw new IllegalArgumentException(
"Expected [numBlocks, blockSize, blockSize] or larger, but data shape was "
+ Arrays.toString(valuesShape));
}
if (indicesShape.length < 2) {
throw new IllegalArgumentException(
"Expected [numBlocks, co-ordinates] or larger, but indices shape was "
+ Arrays.toString(indicesShape));
}
}
@Override
public OnnxJavaType getIndicesType() {
return OnnxJavaType.INT32;
}
@Override
public SparseTensorType getSparsityType() {
return SparseTensorType.BLOCK_SPARSE;
}
}
}
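
As a rough illustration (not part of the diff), the sketch below constructs a CSR(C) tensor consistent with the `CSRCTensor` validation above: the outer indices hold rows + 1 entries and the inner indices hold one column index per non-zero. The `OrtEnvironment` argument is assumed to come from `OrtEnvironment.getEnvironment()`.

```java
import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxSparseTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;

final class CsrcExample {
  // Builds the ORT sparse tensor for this 3x4 matrix:
  // [[1, 0, 0, 2],
  //  [0, 0, 0, 0],
  //  [0, 3, 0, 0]]
  static OnnxSparseTensor buildCsr(OrtEnvironment env) throws OrtException {
    long[] denseShape = new long[] {3, 4};
    LongBuffer outerIndices = LongBuffer.wrap(new long[] {0, 2, 2, 3}); // rows + 1 entries
    LongBuffer innerIndices = LongBuffer.wrap(new long[] {0, 3, 1}); // column index per non-zero
    FloatBuffer values = FloatBuffer.wrap(new float[] {1f, 2f, 3f});
    OnnxSparseTensor.CSRCTensor csr =
        new OnnxSparseTensor.CSRCTensor(
            outerIndices, innerIndices, values, denseShape, OnnxJavaType.FLOAT, 3);
    return OnnxSparseTensor.createSparseTensor(env, csr);
  }
}
```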


@ -1,10 +1,9 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@ -18,21 +17,7 @@ import java.nio.ShortBuffer;
* A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be
* returned as outputs.
*/
public class OnnxTensor implements OnnxValue {
static {
try {
OnnxRuntime.init();
} catch (IOException e) {
throw new RuntimeException("Failed to load onnx-runtime library", e);
}
}
private final long nativeHandle;
private final long allocatorHandle;
private final TensorInfo info;
public class OnnxTensor extends OnnxTensorLike {
/**
* This reference is held for OnnxTensors backed by a Java nio buffer to ensure the buffer does
* not go out of scope while the OnnxTensor exists.
@ -44,9 +29,7 @@ public class OnnxTensor implements OnnxValue {
}
OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer) {
this.nativeHandle = nativeHandle;
this.allocatorHandle = allocatorHandle;
this.info = info;
super(nativeHandle, allocatorHandle, info);
this.buffer = buffer;
}
@ -55,10 +38,6 @@ public class OnnxTensor implements OnnxValue {
return OnnxValueType.ONNX_TYPE_TENSOR;
}
long getNativeHandle() {
return nativeHandle;
}
/**
* Either returns a boxed primitive if the Tensor is a scalar, or a multidimensional array of
* primitives if it has multiple dimensions.
@ -108,11 +87,6 @@ public class OnnxTensor implements OnnxValue {
}
}
@Override
public TensorInfo getInfo() {
return info;
}
@Override
public String toString() {
return "OnnxTensor(info=" + info.toString() + ")";
@ -300,7 +274,7 @@ public class OnnxTensor implements OnnxValue {
* @param input A uint16_t representing an IEEE half precision float.
* @return A float.
*/
private static float fp16ToFloat(short input) {
static float fp16ToFloat(short input) {
int output =
((input & 0x8000) << 16) | (((input & 0x7c00) + 0x1C000) << 13) | ((input & 0x03FF) << 13);
return Float.intBitsToFloat(output);
@ -715,73 +689,20 @@ public class OnnxTensor implements OnnxValue {
*/
private static OnnxTensor createTensor(
OnnxJavaType type, OrtAllocator allocator, Buffer data, long[] shape) throws OrtException {
int bufferPos;
long bufferSizeLong = data.remaining() * (long) type.size;
if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) {
// The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending
// on the JVM, so we check for something 8 elements below the maximum size which
// should be allocatable (assuming there is enough memory) on all 64-bit JVMs.
throw new IllegalStateException(
"Cannot allocate a direct buffer of the requested size and type, size "
+ data.remaining()
+ ", type = "
+ type);
}
// Now we know we're in range
int bufferSize = data.remaining() * type.size;
Buffer tmp;
if (data.isDirect()) {
tmp = data;
bufferPos = data.position() * type.size;
} else {
// Copy the data to a new direct buffer, then restore the state of the input.
int origPosition = data.position();
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
switch (type) {
case FLOAT:
tmp = buffer.asFloatBuffer().put((FloatBuffer) data);
break;
case DOUBLE:
tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
break;
case UINT8:
case INT8:
// buffer is already a ByteBuffer, no cast needed.
tmp = buffer.put((ByteBuffer) data);
break;
case INT16:
tmp = buffer.asShortBuffer().put((ShortBuffer) data);
break;
case INT32:
tmp = buffer.asIntBuffer().put((IntBuffer) data);
break;
case INT64:
tmp = buffer.asLongBuffer().put((LongBuffer) data);
break;
case BOOL:
case STRING:
case UNKNOWN:
default:
throw new IllegalStateException(
"Impossible to reach here, managed to cast a buffer as an incorrect type");
}
data.position(origPosition);
tmp.rewind();
bufferPos = 0;
}
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
OrtUtil.BufferTuple tuple = OrtUtil.prepareBuffer(data, type);
TensorInfo info = TensorInfo.constructFromBuffer(tuple.data, shape, type);
return new OnnxTensor(
createTensorFromBuffer(
OnnxRuntime.ortApiHandle,
allocator.handle,
tmp,
bufferPos,
bufferSize,
tuple.data,
tuple.pos,
tuple.byteSize,
shape,
info.onnxType.value),
allocator.handle,
info,
tmp);
tuple.data);
}
private static native long createTensor(


@ -0,0 +1,59 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import java.io.IOException;
/**
* Currently implemented by {@link OnnxTensor} and {@link OnnxSparseTensor}. Will be sealed to these
* types one day.
*/
public abstract class OnnxTensorLike implements OnnxValue {
static {
try {
OnnxRuntime.init();
} catch (IOException e) {
throw new RuntimeException("Failed to load onnx-runtime library", e);
}
}
protected final long nativeHandle;
protected final long allocatorHandle;
protected final TensorInfo info;
/**
* Constructs a tensor-like (the base class of OnnxTensor and OnnxSparseTensor).
*
* @param nativeHandle The pointer to the tensor.
* @param allocatorHandle The pointer to the memory allocator.
* @param info The tensor info.
*/
OnnxTensorLike(long nativeHandle, long allocatorHandle, TensorInfo info) {
this.nativeHandle = nativeHandle;
this.allocatorHandle = allocatorHandle;
this.info = info;
}
/**
* Returns the native pointer.
*
* @return The native pointer.
*/
long getNativeHandle() {
return nativeHandle;
}
/**
* Returns a {@link TensorInfo} for this tensor.
*
* @return The tensor info.
*/
@Override
public TensorInfo getInfo() {
return info;
}
}
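
Since `OnnxTensor` and `OnnxSparseTensor` now share this base class, a single feed map can mix dense and sparse inputs. A minimal sketch (not part of the diff), with placeholder input names:

```java
import ai.onnxruntime.*;
import java.util.HashMap;
import java.util.Map;

final class MixedInputsExample {
  // Dense and sparse tensors both extend OnnxTensorLike, so one feed map can carry both.
  // The input names "dense_input" and "sparse_input" are placeholders.
  static OrtSession.Result run(OrtSession session, OnnxTensor dense, OnnxSparseTensor sparse)
      throws OrtException {
    Map<String, OnnxTensorLike> inputs = new HashMap<>();
    inputs.put("dense_input", dense);
    inputs.put("sparse_input", sparse);
    return session.run(inputs);
  }
}
```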


@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -8,9 +8,8 @@ import java.util.Map;
/**
* Top interface for input and output values from ONNX models. Currently implemented by {@link
* OnnxTensor}, {@link OnnxSequence} and {@link OnnxMap}. Will be sealed to these types one day.
*
* <p>Does not support sparse tensors.
* OnnxTensor}, {@link OnnxSparseTensor}, {@link OnnxSequence} and {@link OnnxMap}. Will be sealed
* to these types one day.
*/
public interface OnnxValue extends AutoCloseable {
@ -21,7 +20,8 @@ public interface OnnxValue extends AutoCloseable {
ONNX_TYPE_SEQUENCE(2),
ONNX_TYPE_MAP(3),
ONNX_TYPE_OPAQUE(4),
ONNX_TYPE_SPARSETENSOR(5);
ONNX_TYPE_SPARSETENSOR(5),
ONNX_TYPE_OPTIONAL(6);
/** The id number of this type in the C API. */
public final int value;


@ -203,7 +203,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code, the input names are invalid, or if
* there are zero or too many inputs.
*/
public Result run(Map<String, OnnxTensor> inputs) throws OrtException {
public Result run(Map<String, ? extends OnnxTensorLike> inputs) throws OrtException {
return run(inputs, outputNames);
}
@ -218,7 +218,8 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code, the input names are invalid, or if
* there are zero or too many inputs.
*/
public Result run(Map<String, OnnxTensor> inputs, RunOptions runOptions) throws OrtException {
public Result run(Map<String, ? extends OnnxTensorLike> inputs, RunOptions runOptions)
throws OrtException {
return run(inputs, outputNames, runOptions);
}
@ -233,7 +234,7 @@ public class OrtSession implements AutoCloseable {
* @throws OrtException If there was an error in native code, the input or output names are
* invalid, or if there are zero or too many inputs or outputs.
*/
public Result run(Map<String, OnnxTensor> inputs, Set<String> requestedOutputs)
public Result run(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs)
throws OrtException {
return run(inputs, requestedOutputs, null);
}
@ -241,7 +242,7 @@ public class OrtSession implements AutoCloseable {
/**
* Scores an input feed dict, returning the map of requested inferred outputs.
*
* <p>The outputs are sorted based on the supplied set traveral order.
* <p>The outputs are sorted based on the supplied set traversal order.
*
* @param inputs The inputs to score.
* @param requestedOutputs The requested outputs.
@ -251,10 +252,12 @@ public class OrtSession implements AutoCloseable {
* invalid, or if there are zero or too many inputs or outputs.
*/
public Result run(
Map<String, OnnxTensor> inputs, Set<String> requestedOutputs, RunOptions runOptions)
Map<String, ? extends OnnxTensorLike> inputs,
Set<String> requestedOutputs,
RunOptions runOptions)
throws OrtException {
if (!closed) {
if (inputs.isEmpty() || (inputs.size() > numInputs)) {
if ((inputs.isEmpty() && (numInputs != 0)) || (inputs.size() > numInputs)) {
throw new OrtException(
"Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size());
}
@ -268,7 +271,7 @@ public class OrtSession implements AutoCloseable {
String[] inputNamesArray = new String[inputs.size()];
long[] inputHandles = new long[inputs.size()];
int i = 0;
for (Map.Entry<String, OnnxTensor> t : inputs.entrySet()) {
for (Map.Entry<String, ? extends OnnxTensorLike> t : inputs.entrySet()) {
if (inputNames.contains(t.getKey())) {
inputNamesArray[i] = t.getKey();
inputHandles[i] = t.getValue().getNativeHandle();


@ -1,10 +1,18 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import java.lang.reflect.Array;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.ArrayList;
import java.util.Arrays;
@ -472,4 +480,87 @@ public final class OrtUtil {
// 0.75 is the default JDK load factor
return (int) (size / 0.75 + 1);
}
/**
* Prepares a buffer, either copying it if it's not direct, or computing its size and position if
* it is.
*
* @param data The buffer to prepare.
* @param type The Java-side type.
* @return The prepared buffer tuple.
*/
static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) {
int bufferPos;
long bufferSizeLong = data.remaining() * (long) type.size;
if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) {
// The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending
// on the JVM, so we check for something 8 elements below the maximum size which
// should be allocatable (assuming there is enough memory) on all 64-bit JVMs.
throw new IllegalStateException(
"Cannot allocate a direct buffer of the requested size and type, size "
+ data.remaining()
+ ", type = "
+ type);
}
// Now we know we're in range
int bufferSize = data.remaining() * type.size;
Buffer tmp;
if (data.isDirect()) {
tmp = data;
bufferPos = data.position() * type.size;
} else {
// Copy the data to a new direct buffer, then restore the state of the input.
int origPosition = data.position();
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
switch (type) {
case FLOAT:
tmp = buffer.asFloatBuffer().put((FloatBuffer) data);
break;
case DOUBLE:
tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
break;
case UINT8:
case INT8:
// buffer is already a ByteBuffer, no cast needed.
tmp = buffer.put((ByteBuffer) data);
break;
case INT16:
tmp = buffer.asShortBuffer().put((ShortBuffer) data);
break;
case INT32:
tmp = buffer.asIntBuffer().put((IntBuffer) data);
break;
case INT64:
tmp = buffer.asLongBuffer().put((LongBuffer) data);
break;
case BOOL:
case STRING:
case UNKNOWN:
default:
throw new IllegalStateException(
"Impossible to reach here, managed to cast a buffer as an incorrect type");
}
data.position(origPosition);
tmp.rewind();
bufferPos = 0;
}
return new BufferTuple(tmp, bufferPos, bufferSize, data.remaining(), tmp != data);
}
static final class BufferTuple {
final Buffer data;
final int pos;
final long byteSize;
final long size;
final boolean isCopy;
BufferTuple(Buffer data, int pos, long byteSize, long size, boolean isCopy) {
this.data = data;
this.pos = pos;
this.byteSize = byteSize;
this.size = size;
this.isCopy = isCopy;
}
}
}
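
`prepareBuffer` centralizes the handling that `OnnxTensor.createTensor` previously did inline: a direct buffer is passed to native code as-is with its position converted to bytes, while a heap buffer is first copied into a fresh direct buffer. A minimal sketch of the user-visible upshot (not part of the diff), assuming a 2x2 float input:

```java
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

final class DirectBufferExample {
  static OnnxTensor wrapWithoutCopy(OrtEnvironment env) throws OrtException {
    // A direct, natively-ordered buffer is handed to ORT without a copy; a heap buffer
    // (e.g. FloatBuffer.wrap(...)) would be copied into a fresh direct buffer instead.
    FloatBuffer direct =
        ByteBuffer.allocateDirect(4 * Float.BYTES).order(ByteOrder.nativeOrder()).asFloatBuffer();
    direct.put(new float[] {1f, 2f, 3f, 4f});
    direct.rewind();
    return OnnxTensor.createTensor(env, direct, new long[] {2, 2});
  }
}
```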


@ -297,6 +297,39 @@ public class TensorInfo implements ValueInfo {
Arrays.copyOf(shape, shape.length), type, OnnxTensorType.mapFromJavaType(type));
}
/**
* Constructs a TensorInfo from the supplied {@link OnnxSparseTensor.SparseTensor}.
*
* @param tensor The sparse tensor.
* @param <T> The buffer type.
* @return A TensorInfo for a sparse tensor.
* @throws OrtException If the supplied tensor has too many elements for its shape.
*/
public static <T extends Buffer> TensorInfo constructFromSparseTensor(
OnnxSparseTensor.SparseTensor<T> tensor) throws OrtException {
long[] shape = tensor.getDenseShape();
long elementCount = OrtUtil.elementCount(shape);
long bufferRemaining = tensor.getValues().remaining();
if (elementCount < bufferRemaining) {
throw new OrtException(
"Shape "
+ Arrays.toString(shape)
+ ", has at most "
+ elementCount
+ " elements but the buffer has "
+ bufferRemaining
+ " elements.");
}
return new TensorInfo(
Arrays.copyOf(shape, shape.length),
tensor.getType(),
OnnxTensorType.mapFromJavaType(tensor.getType()));
}
/**
* Extracts the shape from a multidimensional array. Checks to see if the array is ragged or not.
*

Просмотреть файл

@ -68,6 +68,45 @@ ExecutionMode convertExecutionMode(jint mode) {
}
}
/**
* Must be kept in sync with OrtSparseFormat and OnnxSparseTensor.SparseTensorType
* @param format The Java int.
* @return The enum.
*/
OrtSparseFormat convertToOrtSparseFormat(jint format) {
switch (format) {
case 0:
return ORT_SPARSE_UNDEFINED;
case 1:
return ORT_SPARSE_COO;
case 2:
return ORT_SPARSE_CSRC;
case 4:
return ORT_SPARSE_BLOCK_SPARSE;
default:
return ORT_SPARSE_UNDEFINED;
}
}
/**
* Must be kept in sync with OrtSparseFormat and OnnxSparseTensor.SparseTensorType
* @param format The enum.
* @return The Java int.
*/
jint convertFromOrtSparseFormat(OrtSparseFormat format) {
switch (format) {
case ORT_SPARSE_COO:
return 1;
case ORT_SPARSE_CSRC:
return 2;
case ORT_SPARSE_BLOCK_SPARSE:
return 4;
case ORT_SPARSE_UNDEFINED:
default:
return 0;
}
}
/**
* Must be kept in sync with convertToONNXDataFormat
*/
@ -228,7 +267,8 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo
}
switch (type) {
case ONNX_TYPE_TENSOR: {
case ONNX_TYPE_TENSOR:
case ONNX_TYPE_SPARSETENSOR: {
const OrtTensorTypeAndShapeInfo* tensorInfo = NULL;
code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToTensorInfo(info, &tensorInfo));
if (code == ORT_OK) {
@ -257,7 +297,6 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo
}
case ONNX_TYPE_UNKNOWN:
case ONNX_TYPE_OPAQUE:
case ONNX_TYPE_SPARSETENSOR:
default: {
throwOrtException(jniEnv,convertErrorCode(ORT_NOT_IMPLEMENTED),"Invalid ONNXType found.");
return NULL;
@ -869,6 +908,40 @@ jobject createJavaTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocato
return javaTensor;
}
jobject createJavaSparseTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) {
// Extract the type information
OrtTensorTypeAndShapeInfo* info;
OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor, &info));
if (code != ORT_OK) {
return NULL;
}
// Construct the TensorInfo object
jobject tensorInfo = convertToTensorInfo(jniEnv, api, info);
// Release the info object
api->ReleaseTensorTypeAndShapeInfo(info);
if (tensorInfo == NULL) {
return NULL;
}
// Lookup the sparse tensor type enum
OrtSparseFormat format;
code = checkOrtStatus(jniEnv,api,api->GetSparseTensorFormat(tensor, &format));
if (code != ORT_OK) {
return NULL;
}
jint sparseTensorInt = convertFromOrtSparseFormat(format);
// Construct the OnnxSparseTensor object
static const char *tensorClassName = "ai/onnxruntime/OnnxSparseTensor";
jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorClassName);
jmethodID tensorConstructor = (*jniEnv)->GetMethodID(jniEnv, clazz, "<init>", "(JJILai/onnxruntime/TensorInfo;)V");
jobject javaSparseTensor = (*jniEnv)->NewObject(jniEnv, clazz, tensorConstructor, (jlong) tensor, (jlong) allocator, sparseTensorInt, tensorInfo);
return javaSparseTensor;
}
jobject createJavaSequenceFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* sequence) {
// Get the sequence info class
static const char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo";
@ -1026,12 +1099,14 @@ jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca
case ONNX_TYPE_MAP: {
return createJavaMapFromONNX(jniEnv, api, allocator, onnxValue);
}
case ONNX_TYPE_SPARSETENSOR: {
return createJavaSparseTensorFromONNX(jniEnv, api, allocator, onnxValue);
}
case ONNX_TYPE_UNKNOWN:
case ONNX_TYPE_OPAQUE:
case ONNX_TYPE_OPTIONAL:
case ONNX_TYPE_SPARSETENSOR:
default: {
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "These types are unsupported - ONNX_TYPE_UNKNOWN, ONNX_TYPE_OPAQUE, ONNX_TYPE_SPARSETENSOR.");
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "These types are unsupported - ONNX_TYPE_UNKNOWN, ONNX_TYPE_OPAQUE, ONNX_TYPE_OPTIONAL.");
return NULL;
}
}
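
With this dispatch in place, an `ONNX_TYPE_SPARSETENSOR` output surfaces in Java as an `OnnxSparseTensor`, and `getValue()` returns the matching `SparseTensor` subclass (a `COOTensor` for COO outputs). A minimal sketch of reading one back (not part of the diff); the output name `sparse_output` is a placeholder:

```java
import ai.onnxruntime.OnnxSparseTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.Arrays;

final class SparseOutputExample {
  // Reads back a COO sparse output; the output name "sparse_output" is a placeholder.
  static void inspect(OrtSession.Result result) throws OrtException {
    OnnxValue value = result.get("sparse_output").get();
    if (value instanceof OnnxSparseTensor) {
      OnnxSparseTensor sparse = (OnnxSparseTensor) value;
      OnnxSparseTensor.COOTensor coo = (OnnxSparseTensor.COOTensor) sparse.getValue();
      System.out.println(
          "non-zeros = " + coo.getNumNonZeroElements()
              + ", indices shape = " + Arrays.toString(coo.getIndicesShape()));
    }
  }
}
```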


@ -28,6 +28,10 @@ GraphOptimizationLevel convertOptimizationLevel(jint level);
ExecutionMode convertExecutionMode(jint mode);
OrtSparseFormat convertToOrtSparseFormat(jint format);
jint convertFromOrtSparseFormat(OrtSparseFormat format);
jint convertFromONNXDataFormat(ONNXTensorElementDataType type);
ONNXTensorElementDataType convertToONNXDataFormat(jint type);
@ -68,6 +72,8 @@ jdoubleArray createDoubleArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, Ort
jobject createJavaTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor);
jobject createJavaSparseTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor);
jobject createJavaSequenceFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* sequence);
jobject createJavaMapFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* map);


@ -0,0 +1,534 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
#include <math.h>
#include <stdlib.h>
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OnnxSparseTensor.h"
/*
* Class: ai_onnxruntime_OnnxSparseTensor
* Method: getIndicesBuffer
* Signature: (JJ)Ljava/nio/ByteBuffer;
*/
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getIndicesBuffer
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const OrtValue* ortValue = (const OrtValue*) handle;
OrtSparseFormat format;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(ortValue, &format));
if (code != ORT_OK) {
return NULL;
}
enum OrtSparseIndicesFormat indicesFormat;
switch (format) {
case ORT_SPARSE_COO:
indicesFormat = ORT_SPARSE_COO_INDICES;
break;
case ORT_SPARSE_CSRC:
indicesFormat = ORT_SPARSE_CSR_OUTER_INDICES;
break;
case ORT_SPARSE_BLOCK_SPARSE:
indicesFormat = ORT_SPARSE_BLOCK_SPARSE_INDICES;
break;
case ORT_SPARSE_UNDEFINED:
default: {
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Sparse format is ORT_SPARSE_UNDEFINED, cannot get indices");
return NULL;
}
}
OrtTensorTypeAndShapeInfo* info = NULL;
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(ortValue, indicesFormat, &info));
if (code != ORT_OK) {
return NULL;
}
size_t arrSize = 0;
code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &arrSize));
if (code != ORT_OK) {
api->ReleaseTensorTypeAndShapeInfo(info);
return NULL;
}
ONNXTensorElementDataType onnxTypeEnum;
code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxTypeEnum));
api->ReleaseTensorTypeAndShapeInfo(info);
if (code != ORT_OK) {
return NULL;
}
size_t typeSize = onnxTypeSize(onnxTypeEnum);
size_t sizeBytes = arrSize * typeSize;
uint8_t* arr = NULL;
size_t indices_size = 0;
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndices(ortValue, indicesFormat, &indices_size, (const void**)&arr));
if (code != ORT_OK) {
return NULL;
}
if (indices_size != arrSize) {
throwOrtException(jniEnv, convertErrorCode(ORT_RUNTIME_EXCEPTION), "Unexpected size");
return NULL;
} else {
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes);
}
}
/*
* Class: ai_onnxruntime_OnnxSparseTensor
* Method: getInnerIndicesBuffer
* Signature: (JJ)Ljava/nio/ByteBuffer;
*/
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getInnerIndicesBuffer
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const OrtValue* ortValue = (const OrtValue*) handle;
OrtSparseFormat format;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(ortValue, &format));
if (code != ORT_OK) {
return NULL;
}
enum OrtSparseIndicesFormat indicesFormat;
switch (format) {
case ORT_SPARSE_CSRC:
indicesFormat = ORT_SPARSE_CSR_INNER_INDICES;
break;
case ORT_SPARSE_COO:
case ORT_SPARSE_BLOCK_SPARSE:
case ORT_SPARSE_UNDEFINED:
default: {
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED),
"Sparse format is ORT_SPARSE_COO, ORT_SPARSE_BLOCK_SPARSE, or ORT_SPARSE_UNDEFINED, inner indices are not defined.");
return NULL;
}
}
OrtTensorTypeAndShapeInfo* info = NULL;
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(ortValue, indicesFormat, &info));
if (code != ORT_OK) {
return NULL;
}
size_t arrSize = 0;
code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &arrSize));
if (code != ORT_OK) {
api->ReleaseTensorTypeAndShapeInfo(info);
return NULL;
}
ONNXTensorElementDataType onnxTypeEnum;
code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxTypeEnum));
api->ReleaseTensorTypeAndShapeInfo(info);
if (code != ORT_OK) {
return NULL;
}
size_t typeSize = onnxTypeSize(onnxTypeEnum);
size_t sizeBytes = arrSize * typeSize;
uint8_t* arr = NULL;
size_t indices_size = 0;
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndices(ortValue, indicesFormat, &indices_size, (const void**)&arr));
if (code != ORT_OK) {
return NULL;
}
if (indices_size != arrSize) {
throwOrtException(jniEnv, convertErrorCode(ORT_RUNTIME_EXCEPTION), "Unexpected size");
return NULL;
} else {
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes);
}
}
/*
* Class: ai_onnxruntime_OnnxSparseTensor
* Method: getValuesBuffer
* Signature: (JJ)Ljava/nio/ByteBuffer;
*/
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getValuesBuffer
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const OrtValue* ortValue = (const OrtValue*) handle;
OrtSparseFormat format;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(ortValue, &format));
if (code != ORT_OK) {
return NULL;
}
switch (format) {
case ORT_SPARSE_COO:
case ORT_SPARSE_CSRC:
case ORT_SPARSE_BLOCK_SPARSE: {
OrtTensorTypeAndShapeInfo* info = NULL;
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorValuesTypeAndShape(ortValue, &info));
if (code != ORT_OK) {
return NULL;
}
size_t arrSize = 0;
code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &arrSize));
if (code != ORT_OK) {
api->ReleaseTensorTypeAndShapeInfo(info);
return NULL;
}
ONNXTensorElementDataType onnxTypeEnum;
code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxTypeEnum));
api->ReleaseTensorTypeAndShapeInfo(info);
if (code != ORT_OK) {
return NULL;
}
size_t typeSize = onnxTypeSize(onnxTypeEnum);
size_t sizeBytes = arrSize * typeSize;
uint8_t* arr = NULL;
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorValues(ortValue, (const void**)&arr));
if (code != ORT_OK) {
return NULL;
}
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes);
}
case ORT_SPARSE_UNDEFINED:
default: {
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED),
"Sparse format is ORT_SPARSE_UNDEFINED, cannot get data");
return NULL;
}
}
}
/*
* Class: ai_onnxruntime_OnnxSparseTensor
* Method: getInnerIndicesShape
* Signature: (JJ)[J;
*/
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getInnerIndicesShape
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const OrtValue* value = (const OrtValue*) handle;
// Extract the info
OrtTensorTypeAndShapeInfo* info;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(value, ORT_SPARSE_CSR_INNER_INDICES, &info));
if (code != ORT_OK) {
return NULL;
}
// Extract the shape
size_t numDim = 0;
code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &numDim));
if (code != ORT_OK) {
api->ReleaseTensorTypeAndShapeInfo(info);
return NULL;
}
int64_t* dimensions = malloc(sizeof(int64_t) * numDim);
if (dimensions == NULL) {
throwOrtException(jniEnv, convertErrorCode(ORT_FAIL), "Out of memory when trying to allocate dimensions array");
api->ReleaseTensorTypeAndShapeInfo(info);
return NULL;
}
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
// Free the info
api->ReleaseTensorTypeAndShapeInfo(info);
if (code != ORT_OK) {
free((void*)dimensions);
return NULL;
}
// Create the long array for the shape.
jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim));
(*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions);
// Free the dimensions array
free((void*)dimensions);
return shape;
}
/*
* Class: ai_onnxruntime_OnnxSparseTensor
* Method: getIndicesShape
* Signature: (JJ)[J;
*/
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getIndicesShape
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const OrtValue* value = (const OrtValue*) handle;
// Get the indices format
OrtSparseFormat format;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(value, &format));
if (code != ORT_OK) {
return NULL;
}
enum OrtSparseIndicesFormat indicesFormat;
switch (format) {
case ORT_SPARSE_CSRC:
indicesFormat = ORT_SPARSE_CSR_OUTER_INDICES;
break;
case ORT_SPARSE_COO:
indicesFormat = ORT_SPARSE_COO_INDICES;
break;
case ORT_SPARSE_BLOCK_SPARSE:
indicesFormat = ORT_SPARSE_BLOCK_SPARSE_INDICES;
break;
case ORT_SPARSE_UNDEFINED:
default: {
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED),
"Sparse format is ORT_SPARSE_UNDEFINED, indices are not defined.");
return NULL;
}
}
// Extract the info
OrtTensorTypeAndShapeInfo* info;
code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(value, indicesFormat, &info));
if (code != ORT_OK) {
return NULL;
}
// Extract the shape
size_t numDim = 0;
code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &numDim));
if (code != ORT_OK) {
api->ReleaseTensorTypeAndShapeInfo(info);
return NULL;
}
int64_t* dimensions = malloc(sizeof(int64_t) * numDim);
if (dimensions == NULL) {
throwOrtException(jniEnv, convertErrorCode(ORT_FAIL), "Out of memory when trying to allocate dimensions array");
api->ReleaseTensorTypeAndShapeInfo(info);
return NULL;
}
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
// Free the info
api->ReleaseTensorTypeAndShapeInfo(info);
if (code != ORT_OK) {
free((void*)dimensions);
return NULL;
}
// Create the long array for the shape.
jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim));
(*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions);
// Free the dimensions array
free((void*)dimensions);
return shape;
}
/*
* Class: ai_onnxruntime_OnnxSparseTensor
* Method: getValuesShape
* Signature: (JJ)[J;
*/
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getValuesShape
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const OrtValue* value = (const OrtValue*) handle;
// Extract the info
OrtTensorTypeAndShapeInfo* info;
OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetSparseTensorValuesTypeAndShape(value,&info));
if (code != ORT_OK) {
return NULL;
}
// Extract the shape
size_t numDim = 0;
code = checkOrtStatus(jniEnv,api,api->GetDimensionsCount(info,&numDim));
if (code != ORT_OK) {
api->ReleaseTensorTypeAndShapeInfo(info);
return NULL;
}
int64_t* dimensions = malloc(sizeof(int64_t)*numDim);
if (dimensions == NULL) {
throwOrtException(jniEnv, convertErrorCode(ORT_FAIL), "Out of memory when trying to allocate dimensions array");
api->ReleaseTensorTypeAndShapeInfo(info);
return NULL;
}
code = checkOrtStatus(jniEnv,api,api->GetDimensions(info, dimensions, numDim));
// Free the info
api->ReleaseTensorTypeAndShapeInfo(info);
if (code != ORT_OK) {
free((void*)dimensions);
return NULL;
}
// Create the long array for the shape.
jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim));
(*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions);
// Free the dimensions array
free((void*)dimensions);
return shape;
}
/*
* Class: ai_onnxruntime_OnnxSparseTensor
* Method: close
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxSparseTensor_close(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
api->ReleaseValue((OrtValue*)handle);
}
/*
* Class: ai_onnxruntime_OnnxSparseTensor
* Method: createCSRCSparseTensorFromBuffer
* Signature: (JJLjava/nio/Buffer;IJLjava/nio/Buffer;IJLjava/nio/Buffer;IJ[J[JI)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxSparseTensor_createCSRCSparseTensorFromBuffer
(JNIEnv * jniEnv, jclass cls, jlong apiHandle, jlong allocatorHandle,
jobject indicesBuffer, jint indicesBufferPos, jlong indicesBufferSize,
jobject innerIndicesBuffer, jint innerIndicesBufferPos, jlong innerIndicesBufferSize,
jobject dataBuffer, jint dataBufferPos,
jlongArray denseShape, jlongArray valuesShape,
jint onnxTypeJava) {
(void) cls; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
const OrtMemoryInfo* allocatorInfo;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator, &allocatorInfo));
if (code != ORT_OK) {
return 0;
}
// Convert types to ONNX C enums
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
// Extract the buffers
char* indicesBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, indicesBuffer);
char* innerIndicesBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, innerIndicesBuffer);
char* dataBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, dataBuffer);
// Increment by bufferPos bytes
indicesBufferArr = indicesBufferArr + indicesBufferPos;
innerIndicesBufferArr = innerIndicesBufferArr + innerIndicesBufferPos;
dataBufferArr = dataBufferArr + dataBufferPos;
// Extract the dense shape information
jboolean mkCopy;
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, denseShape, &mkCopy);
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, denseShape);
// Extract the value shape
jlong* valuesShapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, valuesShape, &mkCopy);
jsize valuesShapeLen = (*jniEnv)->GetArrayLength(jniEnv, valuesShape);
// Create the OrtValue
OrtValue* ortValue = NULL;
code = checkOrtStatus(jniEnv, api, api->CreateSparseTensorWithValuesAsOrtValue(allocatorInfo, dataBufferArr,
(int64_t*) shapeArr, shapeLen, (int64_t*) valuesShapeArr, valuesShapeLen, onnxType, &ortValue));
// Release shapes
(*jniEnv)->ReleaseLongArrayElements(jniEnv, denseShape, shapeArr, JNI_ABORT);
(*jniEnv)->ReleaseLongArrayElements(jniEnv, valuesShape, valuesShapeArr, JNI_ABORT);
if (code != ORT_OK) {
return 0;
}
// Fill it with indices
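  // CSR layout: the inner indices hold the column of each non-zero value, and the outer indices
  // hold the offset at which each row's values start (length = number of rows + 1).
  // UseCsrIndices takes the inner indices first, then the outer indices.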
code = checkOrtStatus(jniEnv, api, api->UseCsrIndices(ortValue,
(int64_t *) innerIndicesBufferArr, innerIndicesBufferSize,
(int64_t *) indicesBufferArr, indicesBufferSize));
  if (code != ORT_OK) {
    // Release the partially constructed value so it doesn't leak when attaching the indices fails.
    api->ReleaseValue(ortValue);
    return 0;
} else {
// Return the pointer to the OrtValue
return (jlong) ortValue;
}
}
/*
* Class: ai_onnxruntime_OnnxSparseTensor
* Method: createSparseTensorFromBuffer
* Signature: (JJLjava/nio/Buffer;IJLjava/nio/Buffer;IJ[J[J[JII)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxSparseTensor_createSparseTensorFromBuffer
(JNIEnv * jniEnv, jclass cls, jlong apiHandle, jlong allocatorHandle,
jobject indicesBuffer, jint indicesBufferPos, jlong indicesBufferSize,
jobject dataBuffer, jint dataBufferPos,
jlongArray denseShape, jlongArray indicesShape, jlongArray valuesShape,
jint onnxTypeJava, jint sparsityTypeJava) {
(void) cls; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
const OrtMemoryInfo* allocatorInfo;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator, &allocatorInfo));
if (code != ORT_OK) {
return 0;
}
// Convert types to ONNX C enums
ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
OrtSparseFormat sparsityType = convertToOrtSparseFormat(sparsityTypeJava);
// Extract the buffers
char* indicesBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, indicesBuffer);
char* dataBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, dataBuffer);
// Increment by bufferPos bytes
indicesBufferArr = indicesBufferArr + indicesBufferPos;
dataBufferArr = dataBufferArr + dataBufferPos;
// Extract the dense shape information
jboolean mkCopy;
jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, denseShape, &mkCopy);
jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, denseShape);
// Extract the value shape
jlong* valuesShapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, valuesShape, &mkCopy);
jsize valuesShapeLen = (*jniEnv)->GetArrayLength(jniEnv, valuesShape);
// Create the OrtValue
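  // As in the CSRC path, the values buffer is used directly by ORT without a copy, so the Java
  // wrapper must keep the direct buffer referenced while the tensor is alive.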
OrtValue* ortValue = NULL;
code = checkOrtStatus(jniEnv, api, api->CreateSparseTensorWithValuesAsOrtValue(allocatorInfo, dataBufferArr,
(int64_t*) shapeArr, shapeLen, (int64_t*) valuesShapeArr, valuesShapeLen, onnxType, &ortValue));
// Release shapes
(*jniEnv)->ReleaseLongArrayElements(jniEnv, denseShape, shapeArr, JNI_ABORT);
(*jniEnv)->ReleaseLongArrayElements(jniEnv, valuesShape, valuesShapeArr, JNI_ABORT);
if (code != ORT_OK) {
return 0;
}
// Fill it with indices
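  // COO indices are int64 and may be either 1-D (flat indices into the dense shape) or 2-D with
  // one coordinate tuple (one entry per dense dimension) per value; block sparse indices are
  // int32 block coordinates with their own shape.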
switch (sparsityType) {
case ORT_SPARSE_COO: {
// The cast is because we compute the offset in bytes in Java.
code = checkOrtStatus(jniEnv, api, api->UseCooIndices(ortValue, (int64_t *) indicesBufferArr,
indicesBufferSize));
break;
}
case ORT_SPARSE_BLOCK_SPARSE: {
// Extract the indices shape
jlong* indicesShapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, indicesShape, &mkCopy);
jsize indicesShapeLen = (*jniEnv)->GetArrayLength(jniEnv, indicesShape);
// The cast is because we compute the offset in bytes in Java.
code = checkOrtStatus(jniEnv, api, api->UseBlockSparseIndices(ortValue, (int64_t *) indicesShapeArr,
indicesShapeLen, (int32_t *) indicesBufferArr));
// Release the indices shape
(*jniEnv)->ReleaseLongArrayElements(jniEnv, indicesShape, indicesShapeArr, JNI_ABORT);
break;
}
case ORT_SPARSE_CSRC:
case ORT_SPARSE_UNDEFINED: {
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED),
"These types are unsupported by this method - ORT_SPARSE_CSRC, ORT_SPARSE_UNDEFINED");
code = ORT_NOT_IMPLEMENTED;
}
}
  if (code != ORT_OK) {
    // Release the partially constructed value so it doesn't leak when attaching the indices fails.
    api->ReleaseValue(ortValue);
    return 0;
} else {
// Return the pointer to the OrtValue
return (jlong) ortValue;
}
}


@@ -59,6 +59,10 @@ public class InferenceTest {
private static final OrtEnvironment env = OrtEnvironment.getEnvironment();
public static Path getResourcePath(String path) {
return new File(InferenceTest.class.getResource(path).getFile()).toPath();
}
@Test
public void environmentTest() {
// Checks that the environment instance is the same.


@@ -0,0 +1,450 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import static ai.onnxruntime.InferenceTest.getResourcePath;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
public class SparseTensorTest {
@Test
public void testCSRC() throws OrtException {
String modelPath = getResourcePath("/generic_sparse_to_dense_matmul.onnx").toString();
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
try (OrtSession session = env.createSession(modelPath, options)) {
Map<String, OnnxTensorLike> inputMap = new HashMap<>();
OnnxTensor denseIdMatrix = makeIdentityMatrix(env, 3);
long[] shape = new long[] {3, 3};
/*
* Sparse matrix:
* [
* 0 1 0
* 1 0 1
* 4 0 6
* ]
*/
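        // CSR encoding of the matrix above: the outer indices are the row offsets {0, 1, 3, 5}
        // (row i's non-zeros sit at positions [outer[i], outer[i + 1]) in the values), the inner
        // indices are the column of each non-zero {1, 0, 2, 0, 2}, and the values are {1, 1, 1, 4, 6}.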
LongBuffer outerIndices =
ByteBuffer.allocateDirect(4 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
outerIndices.put(0);
outerIndices.put(1);
outerIndices.put(3);
outerIndices.put(5);
outerIndices.rewind();
LongBuffer innerIndices =
ByteBuffer.allocateDirect(5 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
innerIndices.put(1);
innerIndices.put(0);
innerIndices.put(2);
innerIndices.put(0);
innerIndices.put(2);
innerIndices.rewind();
FloatBuffer data =
ByteBuffer.allocateDirect(5 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
data.put(1);
data.put(1);
data.put(1);
data.put(4);
data.put(6);
data.rewind();
OnnxSparseTensor.CSRCTensor csrcTensor =
new OnnxSparseTensor.CSRCTensor(
outerIndices, innerIndices, data, shape, OnnxJavaType.FLOAT, 5);
OnnxSparseTensor tensor = OnnxSparseTensor.createSparseTensor(env, csrcTensor);
inputMap.put("sparse_A", tensor);
inputMap.put("dense_B", denseIdMatrix);
OrtSession.Result result = session.run(inputMap);
OnnxTensor outputTensor = (OnnxTensor) result.get(0);
assertArrayEquals(shape, outputTensor.getInfo().getShape());
float[] output = outputTensor.getFloatBuffer().array();
float[] expected = new float[] {0, 1, 0, 1, 0, 1, 4, 0, 6};
assertArrayEquals(expected, output, 1e-6f);
result.close();
inputMap.clear();
// check that the get methods return new buffers which exist past the tensor lifetime.
Buffer valuesOne = tensor.getValuesBuffer();
Buffer valuesTwo = tensor.getValuesBuffer();
Buffer indicesOne = tensor.getIndicesBuffer();
Buffer indicesTwo = tensor.getIndicesBuffer();
Buffer innerIndicesOne = tensor.getInnerIndicesBuffer();
Buffer innerIndicesTwo = tensor.getInnerIndicesBuffer();
tensor.close();
assertEquals(valuesOne, valuesTwo);
assertFalse(valuesOne == valuesTwo);
assertEquals(indicesOne, indicesTwo);
assertFalse(indicesOne == indicesTwo);
assertEquals(innerIndicesOne, innerIndicesTwo);
assertFalse(innerIndicesOne == innerIndicesTwo);
long[] rectangularShape = new long[] {2, 3};
/*
* Sparse matrix:
* [
* 1 0 3
* 0 5 6
* ]
*/
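        // CSR encoding: row offsets {0, 2, 4}, column indices {0, 2, 1, 2}, values {1, 3, 5, 6}.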
outerIndices =
ByteBuffer.allocateDirect(3 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
outerIndices.put(0);
outerIndices.put(2);
outerIndices.put(4);
outerIndices.rewind();
innerIndices =
ByteBuffer.allocateDirect(4 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
innerIndices.put(0);
innerIndices.put(2);
innerIndices.put(1);
innerIndices.put(2);
innerIndices.rewind();
data = ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
data.put(1);
data.put(3);
data.put(5);
data.put(6);
data.rewind();
csrcTensor =
new OnnxSparseTensor.CSRCTensor(
outerIndices, innerIndices, data, rectangularShape, OnnxJavaType.FLOAT, 4);
tensor = OnnxSparseTensor.createSparseTensor(env, csrcTensor);
assertArrayEquals(new long[] {3}, tensor.getIndicesShape());
assertArrayEquals(new long[] {4}, tensor.getInnerIndicesShape());
assertArrayEquals(new long[] {4}, tensor.getValuesShape());
inputMap.put("sparse_A", tensor);
inputMap.put("dense_B", denseIdMatrix);
result = session.run(inputMap);
outputTensor = (OnnxTensor) result.get(0);
assertArrayEquals(rectangularShape, outputTensor.getInfo().getShape());
output = outputTensor.getFloatBuffer().array();
expected = new float[] {1, 0, 3, 0, 5, 6};
assertArrayEquals(expected, output, 1e-6f);
result.close();
tensor.close();
inputMap.clear();
denseIdMatrix.close();
denseIdMatrix = makeIdentityMatrix(env, 4);
long[] vectorShape = new long[] {1, 4};
/*
* Sparse matrix:
* [
* 1 0 0 4
* ]
*/
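        // CSR encoding: row offsets {0, 2}, column indices {0, 3}, values {1, 4}.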
outerIndices =
ByteBuffer.allocateDirect(2 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
outerIndices.put(0);
outerIndices.put(2);
outerIndices.rewind();
innerIndices =
ByteBuffer.allocateDirect(2 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
innerIndices.put(0);
innerIndices.put(3);
innerIndices.rewind();
data = ByteBuffer.allocateDirect(2 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
data.put(1);
data.put(4);
data.rewind();
csrcTensor =
new OnnxSparseTensor.CSRCTensor(
outerIndices, innerIndices, data, vectorShape, OnnxJavaType.FLOAT, 2);
tensor = OnnxSparseTensor.createSparseTensor(env, csrcTensor);
assertArrayEquals(new long[] {2}, tensor.getIndicesShape());
assertArrayEquals(new long[] {2}, tensor.getInnerIndicesShape());
assertArrayEquals(new long[] {2}, tensor.getValuesShape());
inputMap.put("sparse_A", tensor);
inputMap.put("dense_B", denseIdMatrix);
result = session.run(inputMap);
outputTensor = (OnnxTensor) result.get(0);
assertArrayEquals(vectorShape, outputTensor.getInfo().getShape());
output = outputTensor.getFloatBuffer().array();
expected = new float[] {1, 0, 0, 4};
assertArrayEquals(expected, output, 1e-6f);
result.close();
tensor.close();
inputMap.clear();
denseIdMatrix.close();
}
}
}
@Test
public void testCOO() throws OrtException {
String modelPath = getResourcePath("/generic_sparse_to_dense_matmul.onnx").toString();
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
try (OrtSession session = env.createSession(modelPath, options)) {
Map<String, OnnxTensorLike> inputMap = new HashMap<>();
OnnxTensor denseIdMatrix = makeIdentityMatrix(env, 3);
long[] shape = new long[] {3, 3};
/*
* Sparse matrix:
* [
* 0 1 0
* 1 0 1
* 4 0 6
* ]
*/
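        // COO encoding of the matrix above: 2-D indices of shape {5, 2}, one (row, column) pair
        // per non-zero value in row-major order: (0, 1), (1, 0), (1, 2), (2, 0), (2, 2), with
        // values {1, 1, 1, 4, 6}.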
LongBuffer indices =
ByteBuffer.allocateDirect(2 * 5 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
indices.put(0);
indices.put(1);
indices.put(1);
indices.put(0);
indices.put(1);
indices.put(2);
indices.put(2);
indices.put(0);
indices.put(2);
indices.put(2);
indices.rewind();
FloatBuffer data =
ByteBuffer.allocateDirect(5 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
data.put(1);
data.put(1);
data.put(1);
data.put(4);
data.put(6);
data.rewind();
OnnxSparseTensor.COOTensor cooTensor =
new OnnxSparseTensor.COOTensor(
indices, new long[] {5, 2}, data, shape, OnnxJavaType.FLOAT, 5);
OnnxSparseTensor tensor = OnnxSparseTensor.createSparseTensor(env, cooTensor);
inputMap.put("sparse_A", tensor);
inputMap.put("dense_B", denseIdMatrix);
OrtSession.Result result = session.run(inputMap);
OnnxTensor outputTensor = (OnnxTensor) result.get(0);
assertArrayEquals(shape, outputTensor.getInfo().getShape());
float[] output = outputTensor.getFloatBuffer().array();
float[] expected = new float[] {0, 1, 0, 1, 0, 1, 4, 0, 6};
assertArrayEquals(expected, output, 1e-6f);
result.close();
tensor.close();
inputMap.clear();
/* disabled as sparse_dense_matmul doesn't support COO tensors with 1d indices
// Run the same tensor through, but using 1d indexing rather than 2d indexing
indices = ByteBuffer.allocateDirect(5 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
indices.put(1);
indices.put(3);
indices.put(5);
indices.put(6);
indices.put(8);
indices.rewind();
cooTensor = new OnnxSparseTensor.COOTensor(indices, new long[]{5}, data, shape, OnnxJavaType.FLOAT, 5);
tensor = OnnxSparseTensor.createSparseTensor(env, cooTensor);
inputMap.put("sparse_A", tensor);
inputMap.put("dense_B", denseIdMatrix);
result = session.run(inputMap);
outputTensor = (OnnxTensor) result.get(0);
assertArrayEquals(shape, outputTensor.getInfo().getShape());
output = outputTensor.getFloatBuffer().array();
assertArrayEquals(expected, output, 1e-6f);
result.close();
tensor.close();
inputMap.clear();
*/
long[] rectangularShape = new long[] {2, 3};
/*
* Sparse matrix:
* [
* 1 0 3
* 0 5 6
* ]
*/
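        // COO encoding: 2-D indices of shape {4, 2} holding the coordinate pairs (0, 0), (0, 2),
        // (1, 1) and (1, 2), with values {1, 3, 5, 6}.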
indices =
ByteBuffer.allocateDirect(2 * 4 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
indices.put(0);
indices.put(0);
indices.put(0);
indices.put(2);
indices.put(1);
indices.put(1);
indices.put(1);
indices.put(2);
indices.rewind();
data = ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
data.put(1);
data.put(3);
data.put(5);
data.put(6);
data.rewind();
cooTensor =
new OnnxSparseTensor.COOTensor(
indices, new long[] {4, 2}, data, rectangularShape, OnnxJavaType.FLOAT, 4);
tensor = OnnxSparseTensor.createSparseTensor(env, cooTensor);
assertArrayEquals(new long[] {4, 2}, tensor.getIndicesShape());
assertArrayEquals(new long[] {4}, tensor.getValuesShape());
inputMap.put("sparse_A", tensor);
inputMap.put("dense_B", denseIdMatrix);
result = session.run(inputMap);
outputTensor = (OnnxTensor) result.get(0);
assertArrayEquals(rectangularShape, outputTensor.getInfo().getShape());
output = outputTensor.getFloatBuffer().array();
expected = new float[] {1, 0, 3, 0, 5, 6};
assertArrayEquals(expected, output, 1e-6f);
result.close();
tensor.close();
inputMap.clear();
denseIdMatrix.close();
denseIdMatrix = makeIdentityMatrix(env, 4);
long[] vectorShape = new long[] {1, 4};
        /*
         * Sparse matrix (a 1 x 4 row vector):
         * [
         *   1 0 0 4
         * ]
         */
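        // COO encoding: 2-D indices of shape {2, 2} holding the coordinate pairs (0, 0) and
        // (0, 3), with values {1, 4}.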
indices = ByteBuffer.allocateDirect(4 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
indices.put(0);
indices.put(0);
indices.put(0);
indices.put(3);
indices.rewind();
data = ByteBuffer.allocateDirect(2 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
data.put(1);
data.put(4);
data.rewind();
cooTensor =
new OnnxSparseTensor.COOTensor(
indices, new long[] {2, 2}, data, vectorShape, OnnxJavaType.FLOAT, 2);
tensor = OnnxSparseTensor.createSparseTensor(env, cooTensor);
assertArrayEquals(new long[] {2, 2}, tensor.getIndicesShape());
assertArrayEquals(new long[] {2}, tensor.getValuesShape());
inputMap.put("sparse_A", tensor);
inputMap.put("dense_B", denseIdMatrix);
result = session.run(inputMap);
outputTensor = (OnnxTensor) result.get(0);
assertArrayEquals(vectorShape, outputTensor.getInfo().getShape());
output = outputTensor.getFloatBuffer().array();
expected = new float[] {1, 0, 0, 4};
assertArrayEquals(expected, output, 1e-6f);
result.close();
tensor.close();
inputMap.clear();
denseIdMatrix.close();
}
}
}
@Test
public void testCOOOutput() throws OrtException {
String modelPath = getResourcePath("/sparse_initializer_as_output.onnx").toString();
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
try (OrtSession session = env.createSession(modelPath, options)) {
Map<String, NodeInfo> outputs = session.getOutputInfo();
assertEquals(1, outputs.size());
NodeInfo info = outputs.get("values");
assertNotNull(info);
assertTrue(info.getInfo() instanceof TensorInfo);
TensorInfo outputInfo = (TensorInfo) info.getInfo();
assertArrayEquals(new long[] {3, 3}, outputInfo.getShape());
assertEquals(
TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, outputInfo.onnxType);
assertEquals(OnnxJavaType.FLOAT, outputInfo.type);
OrtSession.Result result = session.run(Collections.emptyMap());
OnnxValue output = result.get("values").get();
assertTrue(output instanceof OnnxSparseTensor);
OnnxSparseTensor sparseTensor = (OnnxSparseTensor) output;
assertEquals(OnnxSparseTensor.SparseTensorType.COO, sparseTensor.getSparseTensorType());
assertArrayEquals(new long[] {3}, sparseTensor.getIndicesShape());
assertArrayEquals(new long[] {3}, sparseTensor.getValuesShape());
assertArrayEquals(new long[] {3, 3}, sparseTensor.getInfo().getShape());
OnnxSparseTensor.SparseTensor<? extends Buffer> javaTensor = sparseTensor.getValue();
assertTrue(javaTensor instanceof OnnxSparseTensor.COOTensor);
OnnxSparseTensor.COOTensor cooTensor = (OnnxSparseTensor.COOTensor) javaTensor;
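        // This model emits 1-D (flat) COO indices: 2, 3 and 5 index into the row-major
        // flattening of the {3, 3} dense shape, i.e. positions (0, 2), (1, 0) and (1, 2).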
long[] indices = new long[3];
cooTensor.getIndices().get(indices);
float[] data = new float[3];
((FloatBuffer) cooTensor.getValues()).get(data);
assertArrayEquals(new long[] {2, 3, 5}, indices);
assertArrayEquals(
new float[] {1.764052391052246f, 0.40015721321105957f, 0.978738009929657f}, data);
}
}
}
private static OnnxTensor makeIdentityMatrix(OrtEnvironment env, int size) throws OrtException {
float[][] values = new float[size][size];
for (int i = 0; i < values.length; i++) {
values[i][i] = 1.0f;
}
return OnnxTensor.createTensor(env, values);
}
}

java/testdata/generic_sparse_to_dense_matmul.onnx (new file: binary ONNX test model used by SparseTensorTest; contents not shown)