diff --git a/includes/ocos.h b/includes/ocos.h
index a396beaa..a60e8a2c 100644
--- a/includes/ocos.h
+++ b/includes/ocos.h
@@ -51,6 +51,10 @@ struct OrtTensorDimensions : std::vector<int64_t> {
   }
   const std::vector<int64_t>& GetDims() const { return *this; }
   int64_t Size() const {
+    if (empty()) {
+      return 0;
+    }
+
     int64_t s = 1.;
     for (auto it = begin(); it != end(); ++it)
       s *= *it;
diff --git a/onnxruntime_extensions/_cuops.py b/onnxruntime_extensions/_cuops.py
index d9cd9450..f333e8d2 100644
--- a/onnxruntime_extensions/_cuops.py
+++ b/onnxruntime_extensions/_cuops.py
@@ -112,7 +112,7 @@ class BlingFireSentenceBreaker(CustomOp):
 class SegmentExtraction(CustomOp):
     @classmethod
     def get_inputs(cls):
-        return [cls.io_def("input", onnx.TensorProto.INT64, [None])]
+        return [cls.io_def("input", onnx.TensorProto.INT64, [None, None])]
 
     @classmethod
     def get_outputs(cls):
diff --git a/operators/math/segement_extraction.cc b/operators/math/segement_extraction.cc
index 461168bd..fe393a4c 100644
--- a/operators/math/segement_extraction.cc
+++ b/operators/math/segement_extraction.cc
@@ -28,7 +28,7 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
     }
 
     // push end position
-    if (i == input_dim.size() || p_data[i + 1] != p_data[i]) {
+    if (i == input_dim.Size() || p_data[i + 1] != p_data[i]) {
       segment_position.push_back(i + 1);
     }
   }
diff --git a/operators/tokenizer/basic_tokenizer.cc b/operators/tokenizer/basic_tokenizer.cc
index 5814f4e7..6324d82e 100644
--- a/operators/tokenizer/basic_tokenizer.cc
+++ b/operators/tokenizer/basic_tokenizer.cc
@@ -53,21 +53,21 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
       continue;
     }
 
-    if (tokenize_punctuation_ && ::ispunct(c)) {
+    if (tokenize_punctuation_ && ::iswpunct(c)) {
      push_current_token_and_clear();
      push_single_char_and_clear(c);
      continue;
    }

    // split by space
-    if (::isspace(c)) {
+    if (::iswspace(c)) {
      push_current_token_and_clear();
      continue;
    }

    // iscntrl will judge \t\f\n\r as control char
    // but it has been filter by isspace(c)
-    if (remove_control_chars_ && ::iscntrl(c)) {
+    if (remove_control_chars_ && ::iswcntrl(c)) {
      continue;
    }

diff --git a/operators/tokenizer/bert_tokenizer.cc b/operators/tokenizer/bert_tokenizer.cc
index ec4b2d16..3a010b47 100644
--- a/operators/tokenizer/bert_tokenizer.cc
+++ b/operators/tokenizer/bert_tokenizer.cc
@@ -163,6 +163,67 @@ int32_t BertTokenizer::FindSpecialToken(ustring token) {
   return it->second;
 }
 
+TruncateStrategy::TruncateStrategy(std::string strategy_name) {
+  if (strategy_name == "longest_first") {
+    strategy_ = TruncateStrategyType::LONGEST_FIRST;
+  } else if (strategy_name == "only_first") {
+    strategy_ = TruncateStrategyType::ONLY_FIRST;
+  } else if (strategy_name == "only_second") {
+    strategy_ = TruncateStrategyType::ONLY_SECOND;
+  } else if (strategy_name == "longest_from_back") {
+    strategy_ = TruncateStrategyType::LONGEST_FROM_BACK;
+  }
+}
+
+void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int64_t max_len) {
+  if (max_len < 0 || max_len >= ids.size()) {
+    return;
+  }
+
+  ids.resize(max_len);
+}
+
+void TruncateStrategy::Truncate(std::vector<int64_t>& input1, std::vector<int64_t>& input2, int64_t max_len) {
+
+  if (max_len < 0 || (input1.size() + input2.size() <= max_len)) {
+    return;
+  }
+
+  auto input1_keep_len = input1.size();
+  auto input2_keep_len = input2.size();
+  auto half_max_len = max_len / 2;
+
+  switch (strategy_) {
+    case TruncateStrategyType::LONGEST_FIRST:
+    case TruncateStrategyType::LONGEST_FROM_BACK:
+
+      if ((input1_keep_len > half_max_len) && (input2_keep_len > half_max_len)) {
+        input1_keep_len = max_len - half_max_len;
+        input2_keep_len = half_max_len;
+      } else if (input2_keep_len > input1_keep_len) {
+        input2_keep_len = max_len - input1_keep_len;
+      } else {
+        input1_keep_len = max_len - input2_keep_len;
+      }
+
+      if (strategy_ == TruncateStrategyType::LONGEST_FIRST) {
+        input1.resize(input1_keep_len);
+        input2.resize(input2_keep_len);
+      } else {
+        input1.erase(input1.begin(), input1.end() - input1_keep_len);
+        input2.erase(input2.begin(), input2.end() - input2_keep_len);
+      }
+
+      return;
+    case TruncateStrategyType::ONLY_FIRST:
+      return;
+    case TruncateStrategyType::ONLY_SECOND:
+      return;
+    default:
+      return;
+  }
+}
+
 KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
   std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
   bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
@@ -175,10 +236,15 @@ KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info)
   bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
   bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
   std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
+  std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
+  max_length_ = TryToGetAttributeWithDefault("max_length", int64_t(-1));
+
   tokenizer_ = std::make_shared<BertTokenizer>(vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
                                                ustring(sep_token), ustring(pad_token), ustring(cls_token),
                                                ustring(mask_token), tokenize_chinese_chars, strip_accents,
                                                ustring(suffix_indicator));
+
+  truncate_ = std::make_shared<TruncateStrategy>(truncation_strategy_name);
 }
 
 void KernelBertTokenizer::Compute(OrtKernelContext* context) {
@@ -193,13 +259,20 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
   std::vector<int64_t> input_ids;
   std::vector<int64_t> token_type_ids;
 
-  if (input_data.size() == 1) {
+  if (input_data.size() == 1 || input_data[1].empty()) {
     std::vector<int64_t> encode = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
+    truncate_->Truncate(encode, (max_length_ > 0 && max_length_ <= 2) ? 0 : max_length_ - 2);
+    input_ids = tokenizer_->AddSpecialToken(encode);
+    token_type_ids = tokenizer_->GenerateTypeId(encode);
+  } else if (input_data[0].empty()) {
+    std::vector<int64_t> encode = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[1])));
+    truncate_->Truncate(encode, (max_length_ > 0 && max_length_ <= 2) ? 0 : max_length_ - 2);
     input_ids = tokenizer_->AddSpecialToken(encode);
     token_type_ids = tokenizer_->GenerateTypeId(encode);
   } else {
     std::vector<int64_t> encode1 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
     std::vector<int64_t> encode2 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[1])));
+    truncate_->Truncate(encode1, encode2, (max_length_ > 0 && max_length_ <= 3) ? 0 : max_length_ - 3);
     input_ids = tokenizer_->AddSpecialToken(encode1, encode2);
     token_type_ids = tokenizer_->GenerateTypeId(encode1, encode2);
   }
@@ -235,3 +308,4 @@ ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /*index*/
 
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
 };
+
diff --git a/operators/tokenizer/bert_tokenizer.hpp b/operators/tokenizer/bert_tokenizer.hpp
index 7b9a9144..450d5ba9 100644
--- a/operators/tokenizer/bert_tokenizer.hpp
+++ b/operators/tokenizer/bert_tokenizer.hpp
@@ -54,11 +54,31 @@ class BertTokenizer {
   int32_t FindSpecialToken(ustring token);
 };
 
+
+class TruncateStrategy {
+ public:
+  explicit TruncateStrategy(std::string strategy_name);
+  void Truncate(std::vector<int64_t>& ids, int64_t max_len);
+  void Truncate(std::vector<int64_t>& input1, std::vector<int64_t>& input2, int64_t max_len);
+
+ private:
+  enum TruncateStrategyType{
+    LONGEST_FIRST,
+    ONLY_FIRST,
+    ONLY_SECOND,
+    LONGEST_FROM_BACK
+  }strategy_;
+};
+
+
 struct KernelBertTokenizer : BaseKernel {
   KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info);
   void Compute(OrtKernelContext* context);
  private:
   std::shared_ptr<BertTokenizer> tokenizer_;
+  std::shared_ptr<TruncateStrategy> truncate_;
+  int max_length_;
 };
 
 struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
diff --git a/operators/tokenizer/bert_tokenizer_decoder.cc b/operators/tokenizer/bert_tokenizer_decoder.cc
index 882b0a1c..522fe08d 100644
--- a/operators/tokenizer/bert_tokenizer_decoder.cc
+++ b/operators/tokenizer/bert_tokenizer_decoder.cc
@@ -1,7 +1,7 @@
 #include "bert_tokenizer_decoder.hpp"
 
 BertTokenizerDecoder::BertTokenizerDecoder(std::string vocab, ustring unk_token, ustring sep_token, ustring pad_token,
-                                           ustring cls_token, ustring mask_token, ustring suffix_indicator): unk_token_(unk_token), suffix_indicator_(suffix_indicator) {
+                                           ustring cls_token, ustring mask_token, ustring suffix_indicator) : unk_token_(unk_token), suffix_indicator_(suffix_indicator) {
   auto tokens = SplitString(vocab, "\n", true);
   vocab_.reserve(tokens.size());
   for (int i = 0; i < tokens.size(); i++) {
@@ -29,14 +29,13 @@ BertTokenizerDecoder::BertTokenizerDecoder(std::string vocab, ustring unk_token,
       vocab_.push_back(token);
       is_substr_.push_back(false);
     }
-
   }
 }
 
-ustring BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids) {
+ustring BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool skip_special_tokens, bool clean_up_tokenization_spaces) {
   ustring result;
-  for (auto id: ids) {
-    if (id == unk_token_id_ || id == sep_token_id_ || id == pad_token_id_ || id == cls_token_id_ || id == mask_token_id_) {
+  for (auto id : ids) {
+    if (skip_special_tokens && (id == unk_token_id_ || id == sep_token_id_ || id == pad_token_id_ || id == cls_token_id_ || id == mask_token_id_)) {
       continue;
     }
 
@@ -54,7 +53,11 @@
       continue;
     }
 
-    if (!result.empty() && !is_substr_[id]) {
+    // At following situations, we needn't add space
+    // we needn't add a space at the beginning of the output
+    // we needn't add a space when the token is a substr (such as ##ing)
+    // we needn't add a space at the left or right of punctuation (such as client-side shouldn't be client - side), when clean_up_tokenization_spaces is true
+    if (!(result.empty() || is_substr_[id] || (clean_up_tokenization_spaces && RemoveTokenizeSpace(result, id)))) {
       result.push_back(U' ');
     }
 
@@ -64,6 +67,47 @@
   return result;
 }
 
+bool BertTokenizerDecoder::RemoveTokenizeSpace(ustring& text, int64_t new_token_id) {
+  if (text.empty()) {
+    return true;
+  }
+
+  auto pre_char = text.back();
+  auto cur_char = vocab_[new_token_id][0];
+
+  // normal punctuation
+  if (cur_char == U'!' || cur_char == U'.' || cur_char == U'?' || cur_char == U',' || cur_char == '~' || cur_char == ':') {
+    return true;
+  }
+
+  // only remove left side space
+  if (cur_char == U'}' || cur_char == U']' || cur_char == U'>' || cur_char == ')') {
+    return true;
+  }
+
+  // only remove right side space
+  if (pre_char == U'{' || pre_char == U'[' || pre_char == U'<' || pre_char == '(' || pre_char == '$') {
+    return true;
+  }
+
+  // remove both side space
+  if (pre_char == U'-' || pre_char == U'\'' || pre_char == U'"' || pre_char == U'/' || pre_char == U'@' || pre_char == U'\\' ||
+      cur_char == U'-' || cur_char == U'\'' || cur_char == U'"' || cur_char == U'/' || cur_char == U'@' || cur_char == U'\\') {
+    return true;
+  }
+
+  // remove both space beside unicode punctuation
+  if (pre_char > 128 && iswpunct(pre_char)) {
+    return true;
+  }
+
+  if (cur_char > 128 && iswpunct(cur_char)) {
+    return true;
+  }
+
+  return false;
+}
+
 KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
   std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
   std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
@@ -73,8 +117,12 @@ KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(OrtApi api, const OrtKern
   std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
   std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
 
-  decoder_ = std::make_shared<BertTokenizerDecoder>(vocab, ustring(unk_token),ustring(sep_token),ustring(pad_token),
-                                                     ustring(cls_token),ustring(mask_token), ustring(suffix_indicator));
+  use_indices_ = TryToGetAttributeWithDefault("use_indices", false);
+  skip_special_tokens_ = TryToGetAttributeWithDefault("skip_special_tokens", false);
+  clean_up_tokenization_spaces_ = TryToGetAttributeWithDefault("clean_up_tokenization_spaces", true);
+
+  decoder_ = std::make_shared<BertTokenizerDecoder>(vocab, ustring(unk_token), ustring(sep_token), ustring(pad_token),
+                                                     ustring(cls_token), ustring(mask_token), ustring(suffix_indicator));
 }
 
 void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
@@ -83,31 +131,36 @@
   OrtTensorDimensions ids_dim(ort_, ids);
 
   if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
-    ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n]." , ORT_INVALID_GRAPH);
+    ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
   }
 
-//  const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
+  //  const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
   const OrtValue* positions = ort_.KernelContext_GetInput(context, 1);
-  OrtTensorDimensions positions_dim(ort_,positions);
-  if (!(positions_dim.empty() || (positions_dim.size() == 2 && positions_dim[1] == 2))) {
-    ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix." , ORT_INVALID_GRAPH);
+  OrtTensorDimensions positions_dim(ort_, positions);
+  if (use_indices_ &&
+      (!(positions_dim.empty() ||
+         (positions_dim.Size() == 0) ||
+         (positions_dim.size() == 2 && positions_dim[1] == 2)))) {
+    ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
   }
 
-  const int64_t* p_positions = positions_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(positions);
+  const int64_t* p_positions = positions_dim.Size() == 0 ? nullptr : ort_.GetTensorData<int64_t>(positions);
 
   std::vector<ustring> result;
   std::vector<int64_t> output_dim(1);
-  if (p_positions == nullptr) {
-    result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids_dim.Size())));
+  if (!use_indices_) {
+    result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids_dim.Size()), skip_special_tokens_, clean_up_tokenization_spaces_));
     output_dim[0] = 1;
   } else {
-    for (int i = 0; i < positions_dim[0]; i++) {
-      int64_t start = p_positions[2 * i];
-      int64_t end = p_positions[2 * i + 1];
+    if (p_positions != nullptr) {
+      for (int i = 0; i < positions_dim[0]; i++) {
+        int64_t start = p_positions[2 * i];
+        int64_t end = p_positions[2 * i + 1];
 
-      result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids + start, p_ids + end)));
+        result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids + start, p_ids + end), skip_special_tokens_, clean_up_tokenization_spaces_));
+      }
+      output_dim[0] = positions_dim[0];
     }
-    output_dim[0] = positions_dim[0];
   }
 
   OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
diff --git a/operators/tokenizer/bert_tokenizer_decoder.hpp b/operators/tokenizer/bert_tokenizer_decoder.hpp
index 1a1546f0..84340533 100644
--- a/operators/tokenizer/bert_tokenizer_decoder.hpp
+++ b/operators/tokenizer/bert_tokenizer_decoder.hpp
@@ -15,7 +15,7 @@ class BertTokenizerDecoder {
  public:
   BertTokenizerDecoder(std::string vocab, ustring unk_token, ustring sep_token, ustring pad_token,
                        ustring cls_token,ustring mask_token,ustring suffix_indicator);
-  ustring Decode(const std::vector<int64_t>& ids);
+  ustring Decode(const std::vector<int64_t>& ids, bool skip_special_tokens, bool clean_up_tokenization_spaces);
 
  private:
   ustring unk_token_;
@@ -27,6 +27,8 @@ class BertTokenizerDecoder {
   ustring suffix_indicator_;
   std::vector<ustring> vocab_;
   std::vector<bool> is_substr_;
+
+  bool RemoveTokenizeSpace(ustring& text, int64_t new_token_id);
 };
 
 struct KernelBertTokenizerDecoder : BaseKernel {
@@ -34,6 +36,9 @@
   void Compute(OrtKernelContext* context);
  private:
   std::shared_ptr<BertTokenizerDecoder> decoder_;
+  bool use_indices_;
+  bool skip_special_tokens_;
+  bool clean_up_tokenization_spaces_;
 };
 
 struct CustomOpBertTokenizerDecoder : Ort::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
diff --git a/test/static_test/test_strings.cc b/test/static_test/test_strings.cc
index 9805f7dc..db419fb8 100644
--- a/test/static_test/test_strings.cc
+++ b/test/static_test/test_strings.cc
@@ -70,3 +70,4 @@ TEST(strings, regex_split_begin_end_delim) {
   EXPECT_EQ(expected_begin_offsets, begin_offsets);
   EXPECT_EQ(expected_end_offsets, end_offsets);
 }
+
diff --git a/test/static_test/test_tokenizer.cc b/test/static_test/test_tokenizer.cc
index 4b963d80..d277deca 100644
--- a/test/static_test/test_tokenizer.cc
+++ b/test/static_test/test_tokenizer.cc
@@ -141,4 +141,80 @@ TEST(tokenizer, basic_tokenizer_russia) {
   BasicTokenizer tokenizer(true, true, true, true, true);
   auto result = tokenizer.Tokenize(test_case);
   EXPECT_EQ(result, expect_result);
+}
+
+TEST(tokenizer, basic_tokenizer) {
+  ustring test_case = ustring("I mean, you’ll need something to talk about next Sunday, right?");
+  std::vector<ustring> expect_result = ustring_vector_convertor({"I", "mean", ",", "you", "’", "ll", "need", "something", "to", "talk", "about", "next", "Sunday", ",", "right", "?"});
+  BasicTokenizer tokenizer(false, true, true, true, true);
+  auto result = tokenizer.Tokenize(test_case);
+  EXPECT_EQ(result, expect_result);
+}
+
+TEST(tokenizer, truncation_one_input) {
+  TruncateStrategy truncate("longest_first");
+
+  std::vector<int64_t> init_vector1({1, 2, 3, 4, 5, 6, 7, 9});
+  std::vector<int64_t> init_vector2({1, 2, 3, 4, 5});
+
+  auto test_input = init_vector1;
+  truncate.Truncate(test_input, -1);
+  EXPECT_EQ(test_input, init_vector1);
+
+  test_input = init_vector1;
+  truncate.Truncate(test_input, 5);
+  EXPECT_EQ(test_input, std::vector<int64_t>({1, 2, 3, 4, 5}));
+
+  test_input = init_vector2;
+  truncate.Truncate(test_input, 6);
+  EXPECT_EQ(test_input, init_vector2);
+}
+
+TEST(tokenizer, truncation_longest_first) {
+  TruncateStrategy truncate("longest_first");
+
+  std::vector<int64_t> init_vector1({1, 2, 3, 4, 5, 6, 7, 9});
+  std::vector<int64_t> init_vector2({1, 2, 3, 4, 5});
+
+  auto test_input1 = init_vector1;
+  auto test_input2 = init_vector2;
+  truncate.Truncate(test_input1, test_input2, -1);
+  EXPECT_EQ(test_input1, init_vector1);
+  EXPECT_EQ(test_input2, init_vector2);
+
+  test_input1 = init_vector1;
+  test_input2 = init_vector2;
+  truncate.Truncate(test_input1, test_input2, 15);
+  EXPECT_EQ(test_input1, init_vector1);
+  EXPECT_EQ(test_input2, init_vector2);
+
+  test_input1 = init_vector1;
+  test_input2 = init_vector2;
+  truncate.Truncate(test_input1, test_input2, 14);
+  EXPECT_EQ(test_input1, init_vector1);
+  EXPECT_EQ(test_input2, init_vector2);
+
+  test_input1 = init_vector1;
+  test_input2 = init_vector2;
+  truncate.Truncate(test_input1, test_input2, 8);
+  EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4}));
+  EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4}));
+
+  test_input1 = init_vector1;
+  test_input2 = init_vector2;
+  truncate.Truncate(test_input1, test_input2, 9);
+  EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
+  EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4}));
+
+  test_input1 = init_vector1;
+  test_input2 = init_vector2;
+  truncate.Truncate(test_input1, test_input2, 12);
+  EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7}));
+  EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5}));
+
+  test_input1 = init_vector2;
+  test_input2 = init_vector1;
+  truncate.Truncate(test_input1, test_input2, 12);
+  EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
+  EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5, 6 ,7}));
 }
\ No newline at end of file
diff --git a/test/test_bert_tokenizer_decoder.py b/test/test_bert_tokenizer_decoder.py
index 24de6658..5a102d5d 100644
--- a/test/test_bert_tokenizer_decoder.py
+++ b/test/test_bert_tokenizer_decoder.py
@@ -7,6 +7,7 @@ from onnxruntime_extensions import PyOrtFunction, BertTokenizerDecoder
 bert_cased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
 bert_uncased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
 
+
 def _get_test_data_file(*sub_dirs):
     test_dir = Path(__file__).parent
     return str(test_dir.joinpath(*sub_dirs))
@@ -15,10 +16,29 @@
 def _run_basic_case(input, vocab_path):
     t2stc = PyOrtFunction.from_customop(BertTokenizerDecoder, vocab_file=vocab_path)
     ids = np.array(bert_cased_tokenizer.encode(input), dtype=np.int64)
-    position = np.array([[0, ids.size]], dtype=np.int64)
+    position = np.array([[]], dtype=np.int64)
 
     result = t2stc(ids, position)
-    np.testing.assert_array_equal(result[0], bert_cased_tokenizer.decode(bert_cased_tokenizer.encode(input), True, False))
+    np.testing.assert_array_equal(result[0],
+                                  bert_cased_tokenizer.decode(bert_cased_tokenizer.encode(input)))
+
+
+def _run_indices_case(input, indices, vocab_path):
+    t2stc = PyOrtFunction.from_customop(BertTokenizerDecoder, vocab_file=vocab_path, use_indices=1)
+    ids = np.array(bert_cased_tokenizer.encode(input), dtype=np.int64)
+    position = np.array(indices, dtype=np.int64)
+
+    expect_result = []
+    for index in indices:
+        if len(index) > 0:
+            result = bert_cased_tokenizer.decode(bert_cased_tokenizer.encode(input)[index[0]:index[1]])
+            result = result.split(' ')
+            if result[0].startswith('##'):
+                result.pop(0)
+            expect_result.append(" ".join(result))
+
+    result = t2stc(ids, position)
+    np.testing.assert_array_equal(result, expect_result, True, False)
 
 
 class TestBertTokenizerDecoder(unittest.TestCase):
@@ -34,6 +54,11 @@ class TestBertTokenizerDecoder(unittest.TestCase):
         _run_basic_case(input="cat isnot playing toyssss",
                         vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
 
+        _run_indices_case(input="cat isnot playing toyssss", indices=[[]],
+                          vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
+
+        _run_indices_case(input="cat isnot playing toyssss", indices=[[1, 2], [3, 5]],
+                          vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
 
 
 if __name__ == "__main__":
diff --git a/test/test_segment_extraction.py b/test/test_segment_extraction.py
index 1966f2f0..247dc31a 100644
--- a/test/test_segment_extraction.py
+++ b/test/test_segment_extraction.py
@@ -19,21 +19,26 @@ def _run_segment_extraction(input, expect_position, expect_value):
 
 class TestSegmentExtraction(unittest.TestCase):
     def test_text_to_case1(self):
-        inputs = np.array([0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3], dtype=np.int64)
-        position = [[2, 4], [4,7], [7, 11]]
+        inputs = np.array([[0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3]], dtype=np.int64)
+        position = [[2, 4], [4, 7], [7, 11]]
         value = [1, 2, 3]
         _run_segment_extraction(inputs, position, value)
 
-        inputs = np.array([1, 1, 0, 0, 2, 2, 2, 3, 3, 3, 0, 5], dtype=np.int64)
+        inputs = np.array([[1, 1, 0, 0, 2, 2, 2, 3, 3, 3, 0, 5]], dtype=np.int64)
         position = [[0, 2], [4, 7], [7, 10], [11, 12]]
         value = [1, 2, 3, 5]
         _run_segment_extraction(inputs, position, value)
 
-        inputs = np.array([1, 2, 4, 5], dtype=np.int64)
+        inputs = np.array([[1, 2, 4, 5]], dtype=np.int64)
         position = [[0, 1], [1, 2], [2, 3], [3, 4]]
         value = [1, 2, 4, 5]
         _run_segment_extraction(inputs, position, value)
 
+        inputs = np.array([[0, 0, 1, 1, 1, 0, 0, 0, 0, 3, 3, 3, 0]], dtype=np.int64)
+        position = [[2, 5], [9, 12]]
+        value = [1, 3]
+        _run_segment_extraction(inputs, position, value)
+
 
 if __name__ == "__main__":
     unittest.main()
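Not part of the patch itself — below is a minimal usage sketch of the decoder attributes this change introduces, condensed from `_run_indices_case` in `test/test_bert_tokenizer_decoder.py` above. The vocab path assumes the repo's test data layout, and the sentence and index pairs are illustrative only.

```python
import numpy as np
import transformers
from onnxruntime_extensions import PyOrtFunction, BertTokenizerDecoder

# Assumed path: the vocab file shipped under test/data in this repo.
vocab_path = "data/bert_basic_cased_vocab.txt"
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')

# use_indices=1 mirrors how the new attribute is set in the tests;
# with it enabled, the second input selects [start, end) slices to decode.
decoder = PyOrtFunction.from_customop(BertTokenizerDecoder,
                                      vocab_file=vocab_path,
                                      use_indices=1)

ids = np.array(tokenizer.encode("cat isnot playing toyssss"), dtype=np.int64)
positions = np.array([[1, 2], [3, 5]], dtype=np.int64)  # two slices of the id sequence
print(decoder(ids, positions))
```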