Add tests for mapping-related operators (#179)
* init * finish vector_to_string * add more test Co-authored-by: Ze Tao <zetao@microsoft.com>
This commit is contained in:
Родитель
c1e9fdcb08
Коммит
537c492219
|
@ -56,6 +56,10 @@ struct OrtTensorDimensions : std::vector<int64_t> {
|
|||
s *= *it;
|
||||
return s;
|
||||
}
|
||||
|
||||
bool IsScalar() const{
|
||||
return empty();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ class VectorToString(CustomOp):
|
|||
|
||||
@classmethod
|
||||
def get_outputs(cls):
|
||||
return [cls.io_def('text', onnx_proto.TensorProto.STRING, [])]
|
||||
return [cls.io_def('text', onnx_proto.TensorProto.STRING, [None])]
|
||||
|
||||
@classmethod
|
||||
def serialize_attr(cls, attrs):
|
||||
|
|
|
@ -24,11 +24,11 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
|
|||
|
||||
const int64_t* ptr = static_cast<const int64_t*>(input);
|
||||
|
||||
if (vector_len_ == 1 && input_dim.size() == 1) {
|
||||
if (vector_len_ == 1 && (input_dim.size() == 1 || input_dim.IsScalar())) {
|
||||
// only hit when the key is a scalar and the input is a vector
|
||||
output_dim = input_dim;
|
||||
} else {
|
||||
if (input_dim[input_dim.size() - 1] != vector_len_) {
|
||||
if (input_dim.IsScalar() || input_dim[input_dim.size() - 1] != vector_len_) {
|
||||
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
|
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -402,3 +402,159 @@ TEST(utils, test_string_join_dims_empty_values_scalar) {
|
|||
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
|
||||
TEST(string_operator, test_vector_to_string) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
|
||||
std::vector<TestValue> inputs(1);
|
||||
inputs[0].name = "token_ids";
|
||||
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
inputs[0].dims = {5};
|
||||
inputs[0].values_int64 = {0, 1, 2, 3, 4};
|
||||
|
||||
std::vector<TestValue> outputs(1);
|
||||
outputs[0].name = "text";
|
||||
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
outputs[0].dims = {5};
|
||||
outputs[0].values_string = {"a", "unk", "b", "c", "unk"};
|
||||
|
||||
std::filesystem::path model_path = __FILE__;
|
||||
model_path = model_path.parent_path();
|
||||
model_path /= "..";
|
||||
model_path /= "data";
|
||||
model_path /= "test_vector_to_string_scalar_map.onnx";
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
|
||||
inputs[0].dims = {5};
|
||||
inputs[0].values_int64 = {1000, 111, 2323, 444, 555};
|
||||
|
||||
outputs[0].dims = {5};
|
||||
outputs[0].values_string = {"unk", "unk", "unk", "unk", "unk"};
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
|
||||
inputs[0].dims = {0};
|
||||
inputs[0].values_int64 = {};
|
||||
|
||||
outputs[0].dims = {0};
|
||||
outputs[0].values_string = {};
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
|
||||
inputs[0].dims = {5, 3};
|
||||
inputs[0].values_int64 = {0, 0, 0, 0, 0, 1, 3, 0, 1, 100, 0, 1, 43, 23, 11};
|
||||
|
||||
outputs[0].dims = {5};
|
||||
outputs[0].values_string = {"a", "b", "c", "unk", "unk"};
|
||||
|
||||
model_path = model_path.parent_path();
|
||||
model_path /= "..";
|
||||
model_path /= "data";
|
||||
model_path /= "test_vector_to_string_vector_map.onnx";
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
|
||||
inputs[0].dims = {0, 3};
|
||||
inputs[0].values_int64 = {};
|
||||
|
||||
outputs[0].dims = {0};
|
||||
outputs[0].values_string = {};
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
|
||||
inputs[0].dims = {};
|
||||
inputs[0].values_int64 = {111};
|
||||
|
||||
outputs[0].dims = {};
|
||||
outputs[0].values_string = {"unk"};
|
||||
|
||||
model_path = model_path.parent_path();
|
||||
model_path /= "..";
|
||||
model_path /= "data";
|
||||
model_path /= "test_vector_to_string_scalar_input.onnx";
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
|
||||
TEST(string_operator, test_string_to_vector) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
|
||||
std::vector<TestValue> inputs(1);
|
||||
inputs[0].name = "text";
|
||||
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
inputs[0].dims = {5};
|
||||
inputs[0].values_string = {"black", "white", "black", "中文", "英文"};
|
||||
|
||||
std::vector<TestValue> outputs(1);
|
||||
outputs[0].name = "token_ids";
|
||||
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
outputs[0].dims = {5,3};
|
||||
outputs[0].values_int64 = {0, 1, 2, 2, 3, 4, 0, 1, 2, 3, 4, 4, -1, -1, -1};
|
||||
|
||||
|
||||
std::filesystem::path model_path = __FILE__;
|
||||
model_path = model_path.parent_path();
|
||||
model_path /= "..";
|
||||
model_path /= "data";
|
||||
model_path /= "test_string_to_vector.onnx";
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
|
||||
inputs[0].dims = {0};
|
||||
inputs[0].values_string = {};
|
||||
|
||||
outputs[0].dims = {0, 3};
|
||||
outputs[0].values_int64 = {};
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
|
||||
inputs[0].dims = {1};
|
||||
inputs[0].values_string = {""};
|
||||
|
||||
outputs[0].dims = {1, 3};
|
||||
outputs[0].values_int64 = {-1, -1, -1};
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
|
||||
inputs[0].dims = {};
|
||||
inputs[0].values_string = {""};
|
||||
|
||||
outputs[0].dims = {3};
|
||||
outputs[0].values_int64 = {-1, -1, -1};
|
||||
|
||||
model_path = model_path.parent_path();
|
||||
model_path /= "..";
|
||||
model_path /= "data";
|
||||
model_path /= "test_string_to_vector_scalar.onnx";
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
|
||||
TEST(string_operator, test_string_mapping) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
|
||||
std::vector<TestValue> inputs(1);
|
||||
inputs[0].name = "input";
|
||||
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
inputs[0].dims = {5};
|
||||
inputs[0].values_string = {"Orange and Yellow", "不知道啥颜色", "No color", "black", "white"};
|
||||
|
||||
std::vector<TestValue> outputs(1);
|
||||
outputs[0].name = "output";
|
||||
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
outputs[0].dims = {5};
|
||||
outputs[0].values_string = {"Maybe", "也不知道可不可以", "No color", "OK", "Not OK"};
|
||||
|
||||
|
||||
std::filesystem::path model_path = __FILE__;
|
||||
model_path = model_path.parent_path();
|
||||
model_path /= "..";
|
||||
model_path /= "data";
|
||||
model_path /= "test_string_mapping.onnx";
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
|
||||
inputs[0].dims = {};
|
||||
inputs[0].values_string = {"不知道啥颜色"};
|
||||
|
||||
outputs[0].dims = {};
|
||||
outputs[0].values_string = {"也不知道可不可以"};
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
|
||||
inputs[0].dims = {};
|
||||
inputs[0].values_string = {""};
|
||||
|
||||
outputs[0].dims = {};
|
||||
outputs[0].values_string = {""};
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
Загрузка…
Ссылка в новой задаче