adds support for Uint8ClampedArray (#21985)

Fixes https://github.com/microsoft/onnxruntime/issues/21753
This commit is contained in:
Prathik Rao 2024-09-11 22:02:30 -07:00 коммит произвёл GitHub
Родитель d8e64bb529
Коммит d495e6cf1c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
4 изменённых файлов: 60 добавлений и 3 удалений

Просмотреть файл

@ -51,13 +51,16 @@ export class Tensor implements TensorInterface {
*/
constructor(
type: TensorType,
data: TensorDataType | readonly string[] | readonly number[] | readonly boolean[],
data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly number[] | readonly boolean[],
dims?: readonly number[],
);
/**
* Construct a new CPU tensor object from the given data and dims. Type is inferred from data.
*/
constructor(data: TensorDataType | readonly string[] | readonly boolean[], dims?: readonly number[]);
constructor(
data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly boolean[],
dims?: readonly number[],
);
/**
* Construct a new tensor object from the pinned CPU data with the given type and dims.
*
@ -90,12 +93,13 @@ export class Tensor implements TensorInterface {
arg0:
| TensorType
| TensorDataType
| Uint8ClampedArray
| readonly string[]
| readonly boolean[]
| CpuPinnedConstructorParameters
| TextureConstructorParameters
| GpuBufferConstructorParameters,
arg1?: TensorDataType | readonly number[] | readonly string[] | readonly boolean[],
arg1?: TensorDataType | Uint8ClampedArray | readonly number[] | readonly string[] | readonly boolean[],
arg2?: readonly number[],
) {
// perform one-time check for BigInt/Float16Array support
@ -216,6 +220,12 @@ export class Tensor implements TensorInterface {
}
} else if (arg1 instanceof typedArrayConstructor) {
data = arg1;
} else if (arg1 instanceof Uint8ClampedArray) {
if (arg0 === 'uint8') {
data = Uint8Array.from(arg1);
} else {
throw new TypeError(`A Uint8ClampedArray tensor's data must be type of uint8`);
}
} else {
throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`);
}
@ -243,6 +253,9 @@ export class Tensor implements TensorInterface {
} else {
throw new TypeError(`Invalid element type of data array: ${firstElementType}.`);
}
} else if (arg0 instanceof Uint8ClampedArray) {
type = 'uint8';
data = Uint8Array.from(arg0);
} else {
// get tensor type from TypedArray
const mappedType = NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get(

Просмотреть файл

@ -192,6 +192,15 @@ export interface TensorConstructor extends TensorFactory {
dims?: readonly number[],
): TypedTensor<'bool'>;
/**
* Construct a new uint8 tensor object from a Uint8ClampedArray, data and dims.
*
* @param type - Specify the element type.
* @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new (type: 'uint8', data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>;
/**
* Construct a new 64-bit integer typed tensor object from the given type, data and dims.
*
@ -245,6 +254,14 @@ export interface TensorConstructor extends TensorFactory {
*/
new (data: Uint8Array, dims?: readonly number[]): TypedTensor<'uint8'>;
/**
* Construct a new uint8 tensor object from the given data and dims.
*
* @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new (data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>;
/**
* Construct a new uint16 tensor object from the given data and dims.
*

Просмотреть файл

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import * as ort from 'onnxruntime-common';
// construct from Uint8Array
//
// {type-tests}|pass
new ort.Tensor(new Uint8Array(1));
// construct from Uint8ClampedArray
//
// {type-tests}|pass
new ort.Tensor(new Uint8ClampedArray(1));
// construct from type (bool), data (Uint8ClampedArray) and shape (number array)
//
// {type-tests}|fail|1|2769
new ort.Tensor('bool', new Uint8ClampedArray([255, 256]), [2]);

Просмотреть файл

@ -82,6 +82,14 @@ describe('Tensor Constructor Tests - check types', () => {
assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'");
});
it('[uint8] new Tensor(uint8ClampedArray, dims): uint8 tensor can be constructed from Uint8ClampedArray', () => {
const uint8ClampedArray = new Uint8ClampedArray(2);
uint8ClampedArray[0] = 0;
uint8ClampedArray[1] = 256; // clamped
const tensor = new Tensor('uint8', uint8ClampedArray, [2]);
assert.equal(tensor.type, 'uint8', "tensor.type should be 'uint8'");
});
it("[bool] new Tensor('bool', uint8Array, dims): tensor can be constructed from Uint8Array", () => {
const tensor = new Tensor('bool', new Uint8Array([1, 0, 1, 0]), [2, 2]);
assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'");