Added optional outputs for GPT2, CLIP and Roberta Tokenizers (#389)

* Initial optional i/o for roberta

* Small fix

* Added working optional output functionality to RobertaTokenizer with tests

* Added optional outputs to CLIPTokenizer

* Added optional outputs to GPT2Tokenizer

* Use ternary operators

---------

Authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Sayan Shaw 2023-04-06 13:28:59 -07:00, committed by GitHub
Parent 9cd1284da8
Commit 460bd34183
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
9 changed files: 242 additions and 53 deletions
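In short, output 0 (input_ids) stays required, while attention_mask (and, for the CLIP and RoBERTa tokenizers, offset_mapping) may now be omitted from the graph. Below is a minimal usage sketch that is not part of this commit; it mirrors the pattern the updated tests use and assumes onnx, onnxruntime and onnxruntime-extensions are installed. The vocab.json / merges.txt paths and the _read helper are placeholders, not files or APIs introduced by this PR.

# Hedged sketch: build a GPT2Tokenizer graph that declares only the required
# 'input_ids' output and omits the optional 'attention_mask'.
import numpy as np
import onnxruntime as ort
from onnx import helper, TensorProto
from onnxruntime_extensions import get_library_path

def _read(path):  # placeholder helper: load the vocab/merges attribute content
    with open(path, "rb") as f:
        return f.read()

node = [helper.make_node(
    'GPT2Tokenizer', ['string_input'], ['input_ids'],        # optional output omitted
    vocab=_read('vocab.json'), merges=_read('merges.txt'),   # assumed local GPT-2 files
    name='bpetok', padding_length=-1, domain='ai.onnx.contrib')]
graph = helper.make_graph(
    node, 'optional_output_demo',
    [helper.make_tensor_value_info('string_input', TensorProto.STRING, [None])],
    [helper.make_tensor_value_info('input_ids', TensorProto.INT64, [None, None])])
model = helper.make_model(graph, opset_imports=[
    helper.make_opsetid('', 14), helper.make_opsetid('ai.onnx.contrib', 1)])

so = ort.SessionOptions()
so.register_custom_ops_library(get_library_path())
sess = ort.InferenceSession(model.SerializeToString(), so)
# Returns a single array: only input_ids, since no other output was declared.
print(sess.run(None, {'string_input': np.array(["Hello World"])}))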

View file

@@ -148,33 +148,45 @@ void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
   OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
   OrtValue* offset_mapping = ort_.KernelContext_GetOutput(context, 2, offset_dim.data(), offset_dim.size());
   auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
-  auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
-  auto* offset = ort_.GetTensorMutableData<int64_t>(offset_mapping);
+  if (attention_mask != nullptr) {
+    auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
+    int idx = 0;
+    for (auto& res : tokenize_results) {
+      for (int64_t id : res) {
+        mask[idx] = 1;
+        idx++;
+      }
+      for (size_t i = res.size(); i < max_length; i++) {
+        mask[idx] = 0;
+        idx++;
+      }
+    }
+  }
+  if (offset_mapping != nullptr) {
+    auto* offset = ort_.GetTensorMutableData<int64_t>(offset_mapping);
+    int idx2 = 0;
+    for (auto& res : offset_map) {
+      for (auto& mapping : res) {
+        offset[idx2] = mapping.first;
+        idx2++;
+        offset[idx2] = mapping.second;
+        idx2++;
+      }
+    }
+  }
   int idx = 0;
   for (auto& res : tokenize_results) {
     for (int64_t id : res) {
       token[idx] = id;
-      mask[idx] = 1;
       idx++;
     }
     for (size_t i = res.size(); i < max_length; i++) {
       token[idx] = 0;
-      mask[idx] = 0;
       idx++;
     }
   }
-  int idx2 = 0;
-  for (auto& res : offset_map) {
-    for (auto& mapping : res) {
-      offset[idx2] = mapping.first;
-      idx2++;
-      offset[idx2] = mapping.second;
-      idx2++;
-    }
-  }
 }
const char* CustomOpClipBpeTokenizer::GetName() const {
@@ -188,6 +200,11 @@ size_t CustomOpClipBpeTokenizer::GetInputTypeCount() const {
 ONNXTensorElementDataType CustomOpClipBpeTokenizer::GetInputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
 }
+OrtCustomOpInputOutputCharacteristic CustomOpClipBpeTokenizer::GetInputCharacteristic(size_t /*index*/) const {
+  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+}
 size_t CustomOpClipBpeTokenizer::GetOutputTypeCount() const {
   return 3;
 }
@@ -195,3 +212,8 @@ size_t CustomOpClipBpeTokenizer::GetOutputTypeCount() const {
 ONNXTensorElementDataType CustomOpClipBpeTokenizer::GetOutputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
 }
+OrtCustomOpInputOutputCharacteristic CustomOpClipBpeTokenizer::GetOutputCharacteristic(size_t index) const {
+  return index == 0 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED
+                    : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
+}

View file

@@ -21,6 +21,8 @@ struct CustomOpClipBpeTokenizer : OrtW::CustomOpBase<CustomOpClipBpeTokeni
   const char* GetName() const;
   size_t GetInputTypeCount() const;
   ONNXTensorElementDataType GetInputType(size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
   size_t GetOutputTypeCount() const;
   ONNXTensorElementDataType GetOutputType(size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
 };

View file

@@ -106,19 +106,30 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
   OrtValue* tokenize_output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
   OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
   auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
-  auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
+  if (attention_mask != nullptr) {
+    auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
+    int idx = 0;
+    for (auto& res : tokenize_results) {
+      for (int64_t id : res) {
+        mask[idx] = 1;
+        idx++;
+      }
+      for (size_t i = res.size(); i < max_length; i++) {
+        mask[idx] = 0;
+        idx++;
+      }
+    }
+  }
   int idx = 0;
   for (auto& res : tokenize_results) {
     for (int64_t id : res) {
       token[idx] = id;
-      mask[idx] = 1;
       idx++;
     }
     for (size_t i = res.size(); i < max_length; i++) {
       token[idx] = 0;
-      mask[idx] = 0;
       idx++;
     }
   }
@@ -135,6 +146,11 @@ size_t CustomOpBpeTokenizer::GetInputTypeCount() const {
 ONNXTensorElementDataType CustomOpBpeTokenizer::GetInputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
 }
+OrtCustomOpInputOutputCharacteristic CustomOpBpeTokenizer::GetInputCharacteristic(size_t /*index*/) const {
+  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+}
 size_t CustomOpBpeTokenizer::GetOutputTypeCount() const {
   return 2;
 }
@@ -142,3 +158,8 @@ size_t CustomOpBpeTokenizer::GetOutputTypeCount() const {
 ONNXTensorElementDataType CustomOpBpeTokenizer::GetOutputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
 }
+OrtCustomOpInputOutputCharacteristic CustomOpBpeTokenizer::GetOutputCharacteristic(size_t index) const {
+  return index == 0 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED
+                    : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
+}

View file

@@ -20,6 +20,8 @@ struct CustomOpBpeTokenizer : OrtW::CustomOpBase<CustomOpBpeTokenizer, KernelBpe
   const char* GetName() const;
   size_t GetInputTypeCount() const;
   ONNXTensorElementDataType GetInputType(size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
   size_t GetOutputTypeCount() const;
   ONNXTensorElementDataType GetOutputType(size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
 };

View file

@@ -141,33 +141,45 @@ void KernelRobertaBpeTokenizer::Compute(OrtKernelContext* context) {
   OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
   OrtValue* offset_mapping = ort_.KernelContext_GetOutput(context, 2, offset_dim.data(), offset_dim.size());
   auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
-  auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
-  auto* offset = ort_.GetTensorMutableData<int64_t>(offset_mapping);
+  if (attention_mask != nullptr) {
+    auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
+    int idx = 0;
+    for (auto& res : tokenize_results) {
+      for (int64_t id : res) {
+        mask[idx] = 1;
+        idx++;
+      }
+      for (size_t i = res.size(); i < max_length; i++) {
+        mask[idx] = 0;
+        idx++;
+      }
+    }
+  }
+  if (offset_mapping != nullptr) {
+    auto* offset = ort_.GetTensorMutableData<int64_t>(offset_mapping);
+    int idx2 = 0;
+    for (auto& res : offset_map) {
+      for (auto& mapping : res) {
+        offset[idx2] = mapping.first;
+        idx2++;
+        offset[idx2] = mapping.second;
+        idx2++;
+      }
+    }
+  }
   int idx = 0;
   for (auto& res : tokenize_results) {
     for (int64_t id : res) {
       token[idx] = id;
-      mask[idx] = 1;
       idx++;
     }
     for (size_t i = res.size(); i < max_length; i++) {
       token[idx] = 0;
-      mask[idx] = 0;
       idx++;
     }
   }
-  int idx2 = 0;
-  for (auto& res : offset_map) {
-    for (auto& mapping : res) {
-      offset[idx2] = mapping.first;
-      idx2++;
-      offset[idx2] = mapping.second;
-      idx2++;
-    }
-  }
 }
const char* CustomOpRobertaBpeTokenizer::GetName() const {
@@ -181,6 +193,11 @@ size_t CustomOpRobertaBpeTokenizer::GetInputTypeCount() const {
 ONNXTensorElementDataType CustomOpRobertaBpeTokenizer::GetInputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
 }
+OrtCustomOpInputOutputCharacteristic CustomOpRobertaBpeTokenizer::GetInputCharacteristic(size_t /*index*/) const {
+  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+}
 size_t CustomOpRobertaBpeTokenizer::GetOutputTypeCount() const {
   return 3;
 }
@@ -188,3 +205,8 @@ size_t CustomOpRobertaBpeTokenizer::GetOutputTypeCount() const {
 ONNXTensorElementDataType CustomOpRobertaBpeTokenizer::GetOutputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
 }
+OrtCustomOpInputOutputCharacteristic CustomOpRobertaBpeTokenizer::GetOutputCharacteristic(size_t index) const {
+  return index == 0 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED
+                    : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
+}

View file

@@ -21,6 +21,8 @@ struct CustomOpRobertaBpeTokenizer : OrtW::CustomOpBase<CustomOpRobertaBpeTokeni
   const char* GetName() const;
   size_t GetInputTypeCount() const;
   ONNXTensorElementDataType GetInputType(size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
   size_t GetOutputTypeCount() const;
   ONNXTensorElementDataType GetOutputType(size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
 };

View file

@@ -22,10 +22,6 @@ def _create_test_model(**kwargs):
     merges_file = kwargs["merges_file"]
     max_length = kwargs["max_length"]
-    node = [helper.make_node(
-        'CLIPTokenizer', ['string_input'], ['input_ids', 'attention_mask', 'offset_mapping'], vocab=_get_file_content(vocab_file),
-        merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
-        domain='ai.onnx.contrib')]
     input1 = helper.make_tensor_value_info(
         'string_input', onnx_proto.TensorProto.STRING, [None])
     output1 = helper.make_tensor_value_info(
@@ -35,8 +31,32 @@ def _create_test_model(**kwargs):
     output3 = helper.make_tensor_value_info(
         'offset_mapping', onnx_proto.TensorProto.INT64, ["batch_size", "num_offsets", 2])
-    graph = helper.make_graph(node, 'test0', [input1], [output1, output2, output3])
-    model = make_onnx_model(graph)
+    if kwargs["attention_mask"]:
+        if kwargs["offset_map"]:
+            node = [helper.make_node(
+                'CLIPTokenizer', ['string_input'], ['input_ids', 'attention_mask', 'offset_mapping'], vocab=_get_file_content(vocab_file),
+                merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+                domain='ai.onnx.contrib')]
+            graph = helper.make_graph(node, 'test0', [input1], [output1, output2, output3])
+            model = make_onnx_model(graph)
+        else:
+            node = [helper.make_node(
+                'CLIPTokenizer', ['string_input'], ['input_ids', 'attention_mask'], vocab=_get_file_content(vocab_file),
+                merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+                domain='ai.onnx.contrib')]
+            graph = helper.make_graph(node, 'test0', [input1], [output1, output2])
+            model = make_onnx_model(graph)
+    else:
+        node = [helper.make_node(
+            'CLIPTokenizer', ['string_input'], ['input_ids'], vocab=_get_file_content(vocab_file),
+            merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+            domain='ai.onnx.contrib')]
+        graph = helper.make_graph(node, 'test0', [input1], [output1])
+        model = make_onnx_model(graph)
     return model
@@ -53,7 +73,7 @@ class TestCLIPTokenizer(unittest.TestCase):
         cls.tokenizer_cvt = HFTokenizerConverter(cls.slow_tokenizer)
     def _run_tokenizer(self, test_sentence, padding_length=-1):
-        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length)
+        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length, attention_mask=True, offset_map=True)
         so = _ort.SessionOptions()
         so.register_custom_ops_library(_get_library_path())
         sess = _ort.InferenceSession(model.SerializeToString(), so)
@@ -102,5 +122,30 @@ class TestCLIPTokenizer(unittest.TestCase):
         np.testing.assert_array_equal(fn_out[1].reshape((fn_out[1].size,)), expect_attention_mask)
         np.testing.assert_array_equal(fn_out[2].reshape((fn_out[2].shape[1], fn_out[2].shape[2])), expect_offset_mapping)
+    def test_optional_outputs(self):
+        # Test for models without offset mapping and without both attention mask and offset mapping (input id output is always required)
+        model1 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=True, offset_map=False)
+        model2 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=False, offset_map=False)
+        so = _ort.SessionOptions()
+        so.register_custom_ops_library(_get_library_path())
+        sess1 = _ort.InferenceSession(model1.SerializeToString(), so)
+        sess2 = _ort.InferenceSession(model2.SerializeToString(), so)
+        input_text = np.array(["Hello World"])
+        outputs1 = sess1.run(None, {'string_input': input_text})
+        outputs2 = sess2.run(None, {'string_input': input_text})
+        # Test output size
+        np.testing.assert_array_equal(len(outputs1), 2)
+        np.testing.assert_array_equal(len(outputs2), 1)
+        # Test output values
+        clip_out = self.tokenizer(["Hello World"], return_offsets_mapping=True)
+        expect_input_ids = clip_out['input_ids']
+        expect_attention_mask = clip_out['attention_mask']
+        expect_offset_mapping = clip_out['offset_mapping']
+        np.testing.assert_array_equal(expect_input_ids, outputs1[0])
+        np.testing.assert_array_equal(expect_attention_mask, outputs1[1])
+        np.testing.assert_array_equal(expect_input_ids, outputs2[0])
 if __name__ == "__main__":
     unittest.main()

View file

@@ -20,11 +20,7 @@ def _create_test_model(**kwargs):
     vocab_file = kwargs["vocab_file"]
     merges_file = kwargs["merges_file"]
    max_length = kwargs["max_length"]
-    node = [helper.make_node(
-        'GPT2Tokenizer', ['string_input'], ['input_ids', 'attention_mask'], vocab=_get_file_content(vocab_file),
-        merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
-        domain='ai.onnx.contrib')]
     input1 = helper.make_tensor_value_info(
         'string_input', onnx_proto.TensorProto.STRING, [None])
     output1 = helper.make_tensor_value_info(
@@ -32,8 +28,23 @@ def _create_test_model(**kwargs):
     output2 = helper.make_tensor_value_info(
         'attention_mask', onnx_proto.TensorProto.INT64, [None, None])
-    graph = helper.make_graph(node, 'test0', [input1], [output1, output2])
-    model = make_onnx_model(graph)
+    if kwargs["attention_mask"]:
+        node = [helper.make_node(
+            'GPT2Tokenizer', ['string_input'], ['input_ids', 'attention_mask'], vocab=_get_file_content(vocab_file),
+            merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+            domain='ai.onnx.contrib')]
+        graph = helper.make_graph(node, 'test0', [input1], [output1, output2])
+        model = make_onnx_model(graph)
+    else:
+        node = [helper.make_node(
+            'GPT2Tokenizer', ['string_input'], ['input_ids'], vocab=_get_file_content(vocab_file),
+            merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+            domain='ai.onnx.contrib')]
+        graph = helper.make_graph(node, 'test0', [input1], [output1])
+        model = make_onnx_model(graph)
     return model
@@ -74,7 +85,7 @@ class TestGPT2Tokenizer(unittest.TestCase):
         # return input_ids, attention_mask
     def _run_tokenizer(self, test_sentence, padding_length=-1):
-        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length)
+        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length, attention_mask=True)
         so = _ort.SessionOptions()
         so.register_custom_ops_library(_get_library_path())
         sess = _ort.InferenceSession(model.SerializeToString(), so)
@@ -103,6 +114,23 @@ class TestGPT2Tokenizer(unittest.TestCase):
         enable_py_op(True)
+    def test_optional_outputs(self):
+        # Test for model without attention mask (input id output is always required)
+        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=False)
+        so = _ort.SessionOptions()
+        so.register_custom_ops_library(_get_library_path())
+        sess = _ort.InferenceSession(model.SerializeToString(), so)
+        input_text = np.array(["Hello World"])
+        outputs = sess.run(None, {'string_input': input_text})
+        # Test output size
+        np.testing.assert_array_equal(len(outputs), 1)
+        # Test output values
+        gpt2_out = self.tokenizer.tokenizer_sentence(["Hello World"], -1)
+        expect_input_ids = gpt2_out[0]
+        np.testing.assert_array_equal(expect_input_ids, outputs[0])
     # def test_tokenizer_pyop(self):
     #     self._run_tokenizer(["I can feel the magic, can you?"])

View file

@@ -22,10 +22,6 @@ def _create_test_model(**kwargs):
     merges_file = kwargs["merges_file"]
     max_length = kwargs["max_length"]
-    node = [helper.make_node(
-        'RobertaTokenizer', ['string_input'], ['input_ids', 'attention_mask', 'offset_mapping'], vocab=_get_file_content(vocab_file),
-        merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
-        domain='ai.onnx.contrib')]
     input1 = helper.make_tensor_value_info(
         'string_input', onnx_proto.TensorProto.STRING, [None])
     output1 = helper.make_tensor_value_info(
@@ -35,8 +31,32 @@ def _create_test_model(**kwargs):
     output3 = helper.make_tensor_value_info(
         'offset_mapping', onnx_proto.TensorProto.INT64, ["batch_size", "num_offsets", 2])
-    graph = helper.make_graph(node, 'test0', [input1], [output1, output2, output3])
-    model = make_onnx_model(graph)
+    if kwargs["attention_mask"]:
+        if kwargs["offset_map"]:
+            node = [helper.make_node(
+                'RobertaTokenizer', ['string_input'], ['input_ids', 'attention_mask', 'offset_mapping'], vocab=_get_file_content(vocab_file),
+                merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+                domain='ai.onnx.contrib')]
+            graph = helper.make_graph(node, 'test0', [input1], [output1, output2, output3])
+            model = make_onnx_model(graph)
+        else:
+            node = [helper.make_node(
+                'RobertaTokenizer', ['string_input'], ['input_ids', 'attention_mask'], vocab=_get_file_content(vocab_file),
+                merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+                domain='ai.onnx.contrib')]
+            graph = helper.make_graph(node, 'test0', [input1], [output1, output2])
+            model = make_onnx_model(graph)
+    else:
+        node = [helper.make_node(
+            'RobertaTokenizer', ['string_input'], ['input_ids'], vocab=_get_file_content(vocab_file),
+            merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
+            domain='ai.onnx.contrib')]
+        graph = helper.make_graph(node, 'test0', [input1], [output1])
+        model = make_onnx_model(graph)
     return model
@@ -53,7 +73,7 @@ class TestRobertaTokenizer(unittest.TestCase):
         cls.tokenizer_cvt = HFTokenizerConverter(cls.slow_tokenizer)
     def _run_tokenizer(self, test_sentence, padding_length=-1):
-        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length)
+        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length, attention_mask=True, offset_map=True)
         so = _ort.SessionOptions()
         so.register_custom_ops_library(_get_library_path())
         sess = _ort.InferenceSession(model.SerializeToString(), so)
@@ -102,5 +122,30 @@ class TestRobertaTokenizer(unittest.TestCase):
         np.testing.assert_array_equal(fn_out[1].reshape((fn_out[1].size,)), expect_attention_mask)
         np.testing.assert_array_equal(fn_out[2].reshape((fn_out[2].shape[1], fn_out[2].shape[2])), expect_offset_mapping)
+    def test_optional_outputs(self):
+        # Test for models without offset mapping and without both attention mask and offset mapping (input id output is always required)
+        model1 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=True, offset_map=False)
+        model2 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=False, offset_map=False)
+        so = _ort.SessionOptions()
+        so.register_custom_ops_library(_get_library_path())
+        sess1 = _ort.InferenceSession(model1.SerializeToString(), so)
+        sess2 = _ort.InferenceSession(model2.SerializeToString(), so)
+        input_text = np.array(["Hello World"])
+        outputs1 = sess1.run(None, {'string_input': input_text})
+        outputs2 = sess2.run(None, {'string_input': input_text})
+        # Test output size
+        np.testing.assert_array_equal(len(outputs1), 2)
+        np.testing.assert_array_equal(len(outputs2), 1)
+        # Test output values
+        roberta_out = self.tokenizer(["Hello World"], return_offsets_mapping=True)
+        expect_input_ids = roberta_out['input_ids']
+        expect_attention_mask = roberta_out['attention_mask']
+        expect_offset_mapping = roberta_out['offset_mapping']
+        np.testing.assert_array_equal(expect_input_ids, outputs1[0])
+        np.testing.assert_array_equal(expect_attention_mask, outputs1[1])
+        np.testing.assert_array_equal(expect_input_ids, outputs2[0])
 if __name__ == "__main__":
     unittest.main()