Added optional outputs for GPT2, CLIP and Roberta Tokenizers (#389)
* Initial optional i/o for Roberta
* Small fix
* Added working optional output functionality to RobertaTokenizer with tests
* Added optional outputs to CLIPTokenizer
* Added optional outputs to GPT2Tokenizer
* Use ternary operators

---------

Authored-by: Sayan Shaw <sayanshaw@microsoft.com>
This commit is contained in:
Parent: 9cd1284da8
Commit: 460bd34183
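With this change, only `input_ids` (output 0) is a required output of the tokenizer ops; `attention_mask` (and, for CLIP and Roberta, `offset_mapping`) are declared optional, so a graph can request only the outputs it needs and the kernel skips filling anything the graph omits. A minimal sketch of the user-visible effect, assuming onnxruntime-extensions is installed; the `vocab.json` / `merges.txt` paths are placeholders for real GPT-2 vocabulary files:

```python
import numpy as np
import onnxruntime as ort
from onnx import helper, TensorProto
from onnxruntime_extensions import get_library_path

# Declare only the required output; 'attention_mask' is simply left out.
node = helper.make_node(
    'GPT2Tokenizer', ['string_input'], ['input_ids'],
    vocab=open('vocab.json', 'rb').read(),    # placeholder path
    merges=open('merges.txt', 'rb').read(),   # placeholder path
    name='bpetok', padding_length=-1, domain='ai.onnx.contrib')
graph = helper.make_graph(
    [node], 'minimal',
    [helper.make_tensor_value_info('string_input', TensorProto.STRING, [None])],
    [helper.make_tensor_value_info('input_ids', TensorProto.INT64, [None, None])])
model = helper.make_model(graph, opset_imports=[
    helper.make_opsetid('', 14), helper.make_opsetid('ai.onnx.contrib', 1)])

so = ort.SessionOptions()
so.register_custom_ops_library(get_library_path())
sess = ort.InferenceSession(model.SerializeToString(), so)
print([o.name for o in sess.get_outputs()])   # only ['input_ids']
print(sess.run(None, {'string_input': np.array(['Hello World'])}))
```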
```diff
@@ -148,33 +148,45 @@ void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
   OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
   OrtValue* offset_mapping = ort_.KernelContext_GetOutput(context, 2, offset_dim.data(), offset_dim.size());
   auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
-  auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
-  auto* offset = ort_.GetTensorMutableData<int64_t>(offset_mapping);
+  if (attention_mask != nullptr) {
+    auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
+    int idx = 0;
+    for (auto& res : tokenize_results) {
+      for (int64_t id : res) {
+        mask[idx] = 1;
+        idx++;
+      }
+
+      for (size_t i = res.size(); i < max_length; i++) {
+        mask[idx] = 0;
+        idx++;
+      }
+    }
+  }
+  if (offset_mapping != nullptr) {
+    auto* offset = ort_.GetTensorMutableData<int64_t>(offset_mapping);
+    int idx2 = 0;
+    for (auto& res : offset_map) {
+      for (auto& mapping : res) {
+        offset[idx2] = mapping.first;
+        idx2++;
+        offset[idx2] = mapping.second;
+        idx2++;
+      }
+    }
+  }
   int idx = 0;
   for (auto& res : tokenize_results) {
     for (int64_t id : res) {
       token[idx] = id;
-      mask[idx] = 1;
       idx++;
     }

     for (size_t i = res.size(); i < max_length; i++) {
       token[idx] = 0;
-      mask[idx] = 0;
       idx++;
     }
   }
-
-  int idx2 = 0;
-  for (auto& res : offset_map) {
-    for (auto& mapping : res) {
-      offset[idx2] = mapping.first;
-      idx2++;
-      offset[idx2] = mapping.second;
-      idx2++;
-    }
-  }
 }

 const char* CustomOpClipBpeTokenizer::GetName() const {
```
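For reference, the per-row layout the loops above write (ids padded with 0 up to `max_length`, mask 1 for real tokens and 0 for padding, offsets flattened to begin/end pairs) can be stated in a few lines. This is a sketch mirroring the kernel, not code from the repo:

```python
def pad_row(ids, max_length):
    """Mirror of the kernel's per-row layout: ids padded with 0, mask 1s then 0s."""
    token = list(ids) + [0] * (max_length - len(ids))
    mask = [1] * len(ids) + [0] * (max_length - len(ids))
    return token, mask

def flatten_offsets(offsets):
    """Mirror of the offset loop: (begin, end) pairs flattened into one int64 row."""
    return [v for pair in offsets for v in pair]

assert pad_row([101, 7592], 4) == ([101, 7592, 0, 0], [1, 1, 0, 0])
assert flatten_offsets([(0, 5), (6, 11)]) == [0, 5, 6, 11]
```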
```diff
@@ -188,6 +200,11 @@ size_t CustomOpClipBpeTokenizer::GetInputTypeCount() const {
 ONNXTensorElementDataType CustomOpClipBpeTokenizer::GetInputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
 }

+OrtCustomOpInputOutputCharacteristic CustomOpClipBpeTokenizer::GetInputCharacteristic(size_t /*index*/) const {
+  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+}
+
 size_t CustomOpClipBpeTokenizer::GetOutputTypeCount() const {
   return 3;
 }
```
```diff
@@ -195,3 +212,8 @@ size_t CustomOpClipBpeTokenizer::GetOutputTypeCount() const {
 ONNXTensorElementDataType CustomOpClipBpeTokenizer::GetOutputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
 }
+
+OrtCustomOpInputOutputCharacteristic CustomOpClipBpeTokenizer::GetOutputCharacteristic(size_t index) const {
+  return index == 0 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED
+                    : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
+}
```
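Because the characteristic is positional (only index 0 is required), a graph that wants `offset_mapping` without `attention_mask` would, per the usual ONNX convention, pass an empty string for the skipped slot. The tests below never exercise this combination, so treat it as an untested sketch (vocab/merges attributes elided):

```python
from onnx import helper

# '' is the standard ONNX placeholder for an omitted optional output,
# letting the later offset_mapping output keep its position.
node = helper.make_node(
    'CLIPTokenizer', ['string_input'], ['input_ids', '', 'offset_mapping'],
    name='bpetok', padding_length=-1, domain='ai.onnx.contrib')
```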
```diff
@@ -21,6 +21,8 @@ struct CustomOpClipBpeTokenizer : OrtW::CustomOpBase<CustomOpClipBpeTokenizer, K
   const char* GetName() const;
   size_t GetInputTypeCount() const;
   ONNXTensorElementDataType GetInputType(size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
   size_t GetOutputTypeCount() const;
   ONNXTensorElementDataType GetOutputType(size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
 };
```
```diff
@@ -106,19 +106,30 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
   OrtValue* tokenize_output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
   OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
   auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
-  auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
+  if (attention_mask != nullptr) {
+    auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
+    int idx = 0;
+    for (auto& res : tokenize_results) {
+      for (int64_t id : res) {
+        mask[idx] = 1;
+        idx++;
+      }
+
+      for (size_t i = res.size(); i < max_length; i++) {
+        mask[idx] = 0;
+        idx++;
+      }
+    }
+  }
   int idx = 0;
   for (auto& res : tokenize_results) {
     for (int64_t id : res) {
       token[idx] = id;
-      mask[idx] = 1;
       idx++;
     }

     for (size_t i = res.size(); i < max_length; i++) {
       token[idx] = 0;
-      mask[idx] = 0;
       idx++;
     }
   }
```
```diff
@@ -135,6 +146,11 @@ size_t CustomOpBpeTokenizer::GetInputTypeCount() const {
 ONNXTensorElementDataType CustomOpBpeTokenizer::GetInputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
 }

+OrtCustomOpInputOutputCharacteristic CustomOpBpeTokenizer::GetInputCharacteristic(size_t /*index*/) const {
+  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+}
+
 size_t CustomOpBpeTokenizer::GetOutputTypeCount() const {
   return 2;
 }
```
```diff
@@ -142,3 +158,8 @@ size_t CustomOpBpeTokenizer::GetOutputTypeCount() const {
 ONNXTensorElementDataType CustomOpBpeTokenizer::GetOutputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
 }
+
+OrtCustomOpInputOutputCharacteristic CustomOpBpeTokenizer::GetOutputCharacteristic(size_t index) const {
+  return index == 0 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED
+                    : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
+}
```
```diff
@@ -20,6 +20,8 @@ struct CustomOpBpeTokenizer : OrtW::CustomOpBase<CustomOpBpeTokenizer, KernelBpe
   const char* GetName() const;
   size_t GetInputTypeCount() const;
   ONNXTensorElementDataType GetInputType(size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
   size_t GetOutputTypeCount() const;
   ONNXTensorElementDataType GetOutputType(size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
 };
```
```diff
@@ -141,33 +141,45 @@ void KernelRobertaBpeTokenizer::Compute(OrtKernelContext* context) {
   OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
   OrtValue* offset_mapping = ort_.KernelContext_GetOutput(context, 2, offset_dim.data(), offset_dim.size());
   auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
-  auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
-  auto* offset = ort_.GetTensorMutableData<int64_t>(offset_mapping);
+  if (attention_mask != nullptr) {
+    auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
+    int idx = 0;
+    for (auto& res : tokenize_results) {
+      for (int64_t id : res) {
+        mask[idx] = 1;
+        idx++;
+      }
+
+      for (size_t i = res.size(); i < max_length; i++) {
+        mask[idx] = 0;
+        idx++;
+      }
+    }
+  }
+  if (offset_mapping != nullptr) {
+    auto* offset = ort_.GetTensorMutableData<int64_t>(offset_mapping);
+    int idx2 = 0;
+    for (auto& res : offset_map) {
+      for (auto& mapping : res) {
+        offset[idx2] = mapping.first;
+        idx2++;
+        offset[idx2] = mapping.second;
+        idx2++;
+      }
+    }
+  }
   int idx = 0;
   for (auto& res : tokenize_results) {
     for (int64_t id : res) {
       token[idx] = id;
-      mask[idx] = 1;
       idx++;
     }

     for (size_t i = res.size(); i < max_length; i++) {
       token[idx] = 0;
-      mask[idx] = 0;
       idx++;
     }
   }
-
-  int idx2 = 0;
-  for (auto& res : offset_map) {
-    for (auto& mapping : res) {
-      offset[idx2] = mapping.first;
-      idx2++;
-      offset[idx2] = mapping.second;
-      idx2++;
-    }
-  }
 }

 const char* CustomOpRobertaBpeTokenizer::GetName() const {
```
```diff
@@ -181,6 +193,11 @@ size_t CustomOpRobertaBpeTokenizer::GetInputTypeCount() const {
 ONNXTensorElementDataType CustomOpRobertaBpeTokenizer::GetInputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
 }

+OrtCustomOpInputOutputCharacteristic CustomOpRobertaBpeTokenizer::GetInputCharacteristic(size_t /*index*/) const {
+  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+}
+
 size_t CustomOpRobertaBpeTokenizer::GetOutputTypeCount() const {
   return 3;
 }
```
```diff
@@ -188,3 +205,8 @@ size_t CustomOpRobertaBpeTokenizer::GetOutputTypeCount() const {
 ONNXTensorElementDataType CustomOpRobertaBpeTokenizer::GetOutputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
 }
+
+OrtCustomOpInputOutputCharacteristic CustomOpRobertaBpeTokenizer::GetOutputCharacteristic(size_t index) const {
+  return index == 0 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED
+                    : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
+}
```
```diff
@@ -21,6 +21,8 @@ struct CustomOpRobertaBpeTokenizer : OrtW::CustomOpBase<CustomOpRobertaBpeTokeni
   const char* GetName() const;
   size_t GetInputTypeCount() const;
   ONNXTensorElementDataType GetInputType(size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
   size_t GetOutputTypeCount() const;
   ONNXTensorElementDataType GetOutputType(size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
 };
```
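Taken together, the C++ changes leave each op's output list unchanged; only the characteristics change, with index 0 required and everything after it optional. As a quick reference, summarized from the diffs above:

```python
# Output signatures after this change; index 0 is required, the rest optional.
TOKENIZER_OUTPUTS = {
    'GPT2Tokenizer': ['input_ids', 'attention_mask'],
    'CLIPTokenizer': ['input_ids', 'attention_mask', 'offset_mapping'],
    'RobertaTokenizer': ['input_ids', 'attention_mask', 'offset_mapping'],
}
```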
```diff
@@ -22,10 +22,6 @@ def _create_test_model(**kwargs):
     merges_file = kwargs["merges_file"]
     max_length = kwargs["max_length"]

-    node = [helper.make_node(
-        'CLIPTokenizer', ['string_input'], ['input_ids', 'attention_mask', 'offset_mapping'], vocab=_get_file_content(vocab_file),
-        merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
-        domain='ai.onnx.contrib')]
     input1 = helper.make_tensor_value_info(
         'string_input', onnx_proto.TensorProto.STRING, [None])
     output1 = helper.make_tensor_value_info(
```
```diff
@@ -35,8 +31,32 @@ def _create_test_model(**kwargs):
     output3 = helper.make_tensor_value_info(
         'offset_mapping', onnx_proto.TensorProto.INT64, ["batch_size", "num_offsets", 2])

-    graph = helper.make_graph(node, 'test0', [input1], [output1, output2, output3])
-    model = make_onnx_model(graph)
+    if kwargs["attention_mask"]:
+        if kwargs["offset_map"]:
+            node = [helper.make_node(
+                'CLIPTokenizer', ['string_input'], ['input_ids', 'attention_mask', 'offset_mapping'], vocab=_get_file_content(vocab_file),
+                merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+                domain='ai.onnx.contrib')]
+
+            graph = helper.make_graph(node, 'test0', [input1], [output1, output2, output3])
+            model = make_onnx_model(graph)
+        else:
+            node = [helper.make_node(
+                'CLIPTokenizer', ['string_input'], ['input_ids', 'attention_mask'], vocab=_get_file_content(vocab_file),
+                merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+                domain='ai.onnx.contrib')]
+
+            graph = helper.make_graph(node, 'test0', [input1], [output1, output2])
+            model = make_onnx_model(graph)
+    else:
+        node = [helper.make_node(
+            'CLIPTokenizer', ['string_input'], ['input_ids'], vocab=_get_file_content(vocab_file),
+            merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+            domain='ai.onnx.contrib')]
+
+        graph = helper.make_graph(node, 'test0', [input1], [output1])
+        model = make_onnx_model(graph)

     return model
```
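The nested branches above make each graph explicit. An equivalent, more compact builder (a hypothetical refactor, not part of the commit) could derive the output list from the two flags, keeping the same nesting of `offset_map` under `attention_mask`:

```python
def _output_names(attention_mask, offset_map):
    # input_ids is always present; trailing optional outputs are appended in order.
    names = ['input_ids']
    if attention_mask:
        names.append('attention_mask')
        if offset_map:
            names.append('offset_mapping')
    return names

assert _output_names(True, True) == ['input_ids', 'attention_mask', 'offset_mapping']
assert _output_names(False, False) == ['input_ids']
```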
```diff
@@ -53,7 +73,7 @@ class TestCLIPTokenizer(unittest.TestCase):
         cls.tokenizer_cvt = HFTokenizerConverter(cls.slow_tokenizer)

     def _run_tokenizer(self, test_sentence, padding_length=-1):
-        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length)
+        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length, attention_mask=True, offset_map=True)
         so = _ort.SessionOptions()
         so.register_custom_ops_library(_get_library_path())
         sess = _ort.InferenceSession(model.SerializeToString(), so)
```
```diff
@@ -102,5 +122,30 @@ class TestCLIPTokenizer(unittest.TestCase):
         np.testing.assert_array_equal(fn_out[1].reshape((fn_out[1].size,)), expect_attention_mask)
         np.testing.assert_array_equal(fn_out[2].reshape((fn_out[2].shape[1], fn_out[2].shape[2])), expect_offset_mapping)

+    def test_optional_outputs(self):
+        # Test for models without offset mapping and without both attention mask and offset mapping (input id output is always required)
+        model1 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=True, offset_map=False)
+        model2 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=False, offset_map=False)
+        so = _ort.SessionOptions()
+        so.register_custom_ops_library(_get_library_path())
+        sess1 = _ort.InferenceSession(model1.SerializeToString(), so)
+        sess2 = _ort.InferenceSession(model2.SerializeToString(), so)
+        input_text = np.array(["Hello World"])
+        outputs1 = sess1.run(None, {'string_input': input_text})
+        outputs2 = sess2.run(None, {'string_input': input_text})
+
+        # Test output size
+        np.testing.assert_array_equal(len(outputs1), 2)
+        np.testing.assert_array_equal(len(outputs2), 1)
+
+        # Test output values
+        clip_out = self.tokenizer(["Hello World"], return_offsets_mapping=True)
+        expect_input_ids = clip_out['input_ids']
+        expect_attention_mask = clip_out['attention_mask']
+        expect_offset_mapping = clip_out['offset_mapping']
+        np.testing.assert_array_equal(expect_input_ids, outputs1[0])
+        np.testing.assert_array_equal(expect_attention_mask, outputs1[1])
+        np.testing.assert_array_equal(expect_input_ids, outputs2[0])
+

 if __name__ == "__main__":
     unittest.main()
```
```diff
@@ -20,11 +20,7 @@ def _create_test_model(**kwargs):
     vocab_file = kwargs["vocab_file"]
     merges_file = kwargs["merges_file"]
     max_length = kwargs["max_length"]

-    node = [helper.make_node(
-        'GPT2Tokenizer', ['string_input'], ['input_ids', 'attention_mask'], vocab=_get_file_content(vocab_file),
-        merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
-        domain='ai.onnx.contrib')]

     input1 = helper.make_tensor_value_info(
         'string_input', onnx_proto.TensorProto.STRING, [None])
     output1 = helper.make_tensor_value_info(
```
```diff
@@ -32,8 +28,23 @@ def _create_test_model(**kwargs):
     output2 = helper.make_tensor_value_info(
         'attention_mask', onnx_proto.TensorProto.INT64, [None, None])

-    graph = helper.make_graph(node, 'test0', [input1], [output1, output2])
-    model = make_onnx_model(graph)
+    if kwargs["attention_mask"]:
+        node = [helper.make_node(
+            'GPT2Tokenizer', ['string_input'], ['input_ids', 'attention_mask'], vocab=_get_file_content(vocab_file),
+            merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+            domain='ai.onnx.contrib')]
+
+        graph = helper.make_graph(node, 'test0', [input1], [output1, output2])
+        model = make_onnx_model(graph)
+    else:
+        node = [helper.make_node(
+            'GPT2Tokenizer', ['string_input'], ['input_ids'], vocab=_get_file_content(vocab_file),
+            merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+            domain='ai.onnx.contrib')]
+
+        graph = helper.make_graph(node, 'test0', [input1], [output1])
+        model = make_onnx_model(graph)

     return model
```
```diff
@@ -74,7 +85,7 @@ class TestGPT2Tokenizer(unittest.TestCase):
         # return input_ids, attention_mask

     def _run_tokenizer(self, test_sentence, padding_length=-1):
-        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length)
+        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length, attention_mask=True)
         so = _ort.SessionOptions()
         so.register_custom_ops_library(_get_library_path())
         sess = _ort.InferenceSession(model.SerializeToString(), so)
```
```diff
@@ -103,6 +114,23 @@ class TestGPT2Tokenizer(unittest.TestCase):

         enable_py_op(True)

+    def test_optional_outputs(self):
+        # Test for model without attention mask (input id output is always required)
+        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=False)
+        so = _ort.SessionOptions()
+        so.register_custom_ops_library(_get_library_path())
+        sess = _ort.InferenceSession(model.SerializeToString(), so)
+        input_text = np.array(["Hello World"])
+        outputs = sess.run(None, {'string_input': input_text})
+
+        # Test output size
+        np.testing.assert_array_equal(len(outputs), 1)
+
+        # Test output values
+        gpt2_out = self.tokenizer.tokenizer_sentence(["Hello World"], -1)
+        expect_input_ids = gpt2_out[0]
+        np.testing.assert_array_equal(expect_input_ids, outputs[0])
+

     # def test_tokenizer_pyop(self):
     #     self._run_tokenizer(["I can feel the magic, can you?"])
```
```diff
@@ -22,10 +22,6 @@ def _create_test_model(**kwargs):
     merges_file = kwargs["merges_file"]
     max_length = kwargs["max_length"]

-    node = [helper.make_node(
-        'RobertaTokenizer', ['string_input'], ['input_ids', 'attention_mask', 'offset_mapping'], vocab=_get_file_content(vocab_file),
-        merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
-        domain='ai.onnx.contrib')]
     input1 = helper.make_tensor_value_info(
         'string_input', onnx_proto.TensorProto.STRING, [None])
     output1 = helper.make_tensor_value_info(
```
```diff
@@ -35,8 +31,32 @@ def _create_test_model(**kwargs):
     output3 = helper.make_tensor_value_info(
         'offset_mapping', onnx_proto.TensorProto.INT64, ["batch_size", "num_offsets", 2])

-    graph = helper.make_graph(node, 'test0', [input1], [output1, output2, output3])
-    model = make_onnx_model(graph)
+    if kwargs["attention_mask"]:
+        if kwargs["offset_map"]:
+            node = [helper.make_node(
+                'RobertaTokenizer', ['string_input'], ['input_ids', 'attention_mask', 'offset_mapping'], vocab=_get_file_content(vocab_file),
+                merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+                domain='ai.onnx.contrib')]
+
+            graph = helper.make_graph(node, 'test0', [input1], [output1, output2, output3])
+            model = make_onnx_model(graph)
+        else:
+            node = [helper.make_node(
+                'RobertaTokenizer', ['string_input'], ['input_ids', 'attention_mask'], vocab=_get_file_content(vocab_file),
+                merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+                domain='ai.onnx.contrib')]
+
+            graph = helper.make_graph(node, 'test0', [input1], [output1, output2])
+            model = make_onnx_model(graph)
+    else:
+        node = [helper.make_node(
+            'RobertaTokenizer', ['string_input'], ['input_ids'], vocab=_get_file_content(vocab_file),
+            merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+            domain='ai.onnx.contrib')]
+
+        graph = helper.make_graph(node, 'test0', [input1], [output1])
+        model = make_onnx_model(graph)

     return model
```
```diff
@@ -53,7 +73,7 @@ class TestRobertaTokenizer(unittest.TestCase):
         cls.tokenizer_cvt = HFTokenizerConverter(cls.slow_tokenizer)

     def _run_tokenizer(self, test_sentence, padding_length=-1):
-        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length)
+        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length, attention_mask=True, offset_map=True)
         so = _ort.SessionOptions()
         so.register_custom_ops_library(_get_library_path())
         sess = _ort.InferenceSession(model.SerializeToString(), so)
```
```diff
@@ -102,5 +122,30 @@ class TestRobertaTokenizer(unittest.TestCase):
         np.testing.assert_array_equal(fn_out[1].reshape((fn_out[1].size,)), expect_attention_mask)
         np.testing.assert_array_equal(fn_out[2].reshape((fn_out[2].shape[1], fn_out[2].shape[2])), expect_offset_mapping)

+    def test_optional_outputs(self):
+        # Test for models without offset mapping and without both attention mask and offset mapping (input id output is always required)
+        model1 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=True, offset_map=False)
+        model2 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=False, offset_map=False)
+        so = _ort.SessionOptions()
+        so.register_custom_ops_library(_get_library_path())
+        sess1 = _ort.InferenceSession(model1.SerializeToString(), so)
+        sess2 = _ort.InferenceSession(model2.SerializeToString(), so)
+        input_text = np.array(["Hello World"])
+        outputs1 = sess1.run(None, {'string_input': input_text})
+        outputs2 = sess2.run(None, {'string_input': input_text})
+
+        # Test output size
+        np.testing.assert_array_equal(len(outputs1), 2)
+        np.testing.assert_array_equal(len(outputs2), 1)
+
+        # Test output values
+        roberta_out = self.tokenizer(["Hello World"], return_offsets_mapping=True)
+        expect_input_ids = roberta_out['input_ids']
+        expect_attention_mask = roberta_out['attention_mask']
+        expect_offset_mapping = roberta_out['offset_mapping']
+        np.testing.assert_array_equal(expect_input_ids, outputs1[0])
+        np.testing.assert_array_equal(expect_attention_mask, outputs1[1])
+        np.testing.assert_array_equal(expect_input_ids, outputs2[0])
+

 if __name__ == "__main__":
     unittest.main()
```