Improve recent checkin operators (#144)

* update

* update

* update

* remove tokenizer space

* fix bugs

Co-authored-by: Ze Tao <zetao@microsoft.com>
This commit is contained in:
Mojimi 2021-09-07 13:34:47 +08:00 committed by GitHub
Parent 2842d2208e
Commit cce66310b2
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
12 changed files: 297 additions and 34 deletions

View file

@ -51,6 +51,10 @@ struct OrtTensorDimensions : std::vector<int64_t> {
}
const std::vector<int64_t>& GetDims() const { return *this; }
int64_t Size() const {
if (empty()) {
return 0;
}
int64_t s = 1;
for (auto it = begin(); it != end(); ++it)
s *= *it;
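
Note: the added empty() guard matters because multiplying over an empty dims list yields 1, which would make an absent optional input look like a one-element tensor; downstream kernels now rely on Size() == 0 to detect the empty case. A minimal Python sketch of the pitfall (numpy used only for illustration, not part of this change):

import math
import numpy as np

print(math.prod([]))        # 1 -- the old behavior: an empty dims list multiplies out to 1
print(np.array([[]]).size)  # 0 -- what an empty optional input should actually report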

View file

@ -112,7 +112,7 @@ class BlingFireSentenceBreaker(CustomOp):
class SegmentExtraction(CustomOp):
@classmethod
def get_inputs(cls):
return [cls.io_def("input", onnx.TensorProto.INT64, [None])]
return [cls.io_def("input", onnx.TensorProto.INT64, [None, None])]
@classmethod
def get_outputs(cls):

View file

@ -28,7 +28,7 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
}
// push end position
if (i == input_dim.size() || p_data[i + 1] != p_data[i]) {
if (i == input_dim.Size() || p_data[i + 1] != p_data[i]) {
segment_position.push_back(i + 1);
}
}
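
Together with the [None, None] input declaration above, the kernel now consumes a [1, n] int64 tensor and reports each run of equal, non-zero values. A hedged usage sketch, assuming SegmentExtraction is exposed through PyOrtFunction the same way the other custom ops are in this repo's tests, and assuming the outputs come back in (position, value) order as the updated test helper below suggests:

import numpy as np
from onnxruntime_extensions import PyOrtFunction, SegmentExtraction

op = PyOrtFunction.from_customop(SegmentExtraction)
ids = np.array([[0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3]], dtype=np.int64)
position, value = op(ids)
# expected per the updated test: position == [[2, 4], [4, 7], [7, 11]], value == [1, 2, 3]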

View file

@ -53,21 +53,21 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
continue;
}
if (tokenize_punctuation_ && ::ispunct(c)) {
if (tokenize_punctuation_ && ::iswpunct(c)) {
push_current_token_and_clear();
push_single_char_and_clear(c);
continue;
}
// split by space
if (::isspace(c)) {
if (::iswspace(c)) {
push_current_token_and_clear();
continue;
}
// iswcntrl treats \t \f \n \r as control chars,
// but those have already been filtered out by iswspace(c)
if (remove_control_chars_ && ::iscntrl(c)) {
if (remove_control_chars_ && ::iswcntrl(c)) {
continue;
}
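
The switch from ::ispunct / ::isspace / ::iscntrl to their wide-character counterparts matters because the tokenizer walks UTF-32 code points: the narrow <cctype> functions are only defined for unsigned-char values, so non-ASCII punctuation and whitespace were classified incorrectly. A quick Python illustration of the ASCII-vs-Unicode distinction (a sketch of the classification problem, not of this code):

import string
import unicodedata

ch = "\u3002"  # ideographic full stop, common in Chinese text
print(ch in string.punctuation)                  # False: ASCII-only notion of punctuation
print(unicodedata.category(ch).startswith("P"))  # True: Unicode-aware classification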

View file

@ -163,6 +163,67 @@ int32_t BertTokenizer::FindSpecialToken(ustring token) {
return it->second;
}
TruncateStrategy::TruncateStrategy(std::string strategy_name) {
if (strategy_name == "longest_first") {
strategy_ = TruncateStrategyType::LONGEST_FIRST;
} else if (strategy_name == "only_first") {
strategy_ = TruncateStrategyType::ONLY_FIRST;
} else if (strategy_name == "only_second") {
strategy_ = TruncateStrategyType::ONLY_SECOND;
} else if (strategy_name == "longest_from_back") {
strategy_ = TruncateStrategyType::LONGEST_FROM_BACK;
}
}
void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int64_t max_len) {
if (max_len < 0 || max_len >= ids.size()) {
return;
}
ids.resize(max_len);
}
void TruncateStrategy::Truncate(std::vector<int64_t>& input1, std::vector<int64_t>& input2, int64_t max_len) {
if (max_len < 0 || (input1.size() + input2.size() <= max_len)) {
return;
}
auto input1_keep_len = input1.size();
auto input2_keep_len = input2.size();
auto half_max_len = max_len / 2;
switch (strategy_) {
case TruncateStrategyType::LONGEST_FIRST:
case TruncateStrategyType::LONGEST_FROM_BACK:
if ((input1_keep_len > half_max_len) && (input2_keep_len > half_max_len)) {
input1_keep_len = max_len - half_max_len;
input2_keep_len = half_max_len;
} else if (input2_keep_len > input1_keep_len) {
input2_keep_len = max_len - input1_keep_len;
} else {
input1_keep_len = max_len - input2_keep_len;
}
if (strategy_ == TruncateStrategyType::LONGEST_FIRST) {
input1.resize(input1_keep_len);
input2.resize(input2_keep_len);
} else {
input1.erase(input1.begin(), input1.end() - input1_keep_len);
input2.erase(input2.begin(), input2.end() - input2_keep_len);
}
return;
case TruncateStrategyType::ONLY_FIRST:
return;
case TruncateStrategyType::ONLY_SECOND:
return;
default:
return;
}
}
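
For a sentence pair, longest_first splits the budget: when both inputs exceed half of max_len each keeps roughly half (the first input gets the extra token when max_len is odd); otherwise the shorter input is kept whole and the longer one absorbs the rest. longest_from_back keeps the same lengths but drops tokens from the front instead of the back. A small Python sketch of the length computation (function and variable names are illustrative, not from the source):

def longest_first_keep(len1, len2, max_len):
    # mirrors the two-input TruncateStrategy::Truncate above
    if max_len < 0 or len1 + len2 <= max_len:
        return len1, len2
    half = max_len // 2
    if len1 > half and len2 > half:
        return max_len - half, half
    if len2 > len1:
        return len1, max_len - len1
    return max_len - len2, len2

# matches the new tests below: lengths (8, 5) with max_len 8 -> (4, 4), 9 -> (5, 4), 12 -> (7, 5)
print(longest_first_keep(8, 5, 9))   # (5, 4)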
KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
@ -175,10 +236,15 @@ KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info)
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
max_length_ = TryToGetAttributeWithDefault("max_length", int64_t(-1));
tokenizer_ = std::make_shared<BertTokenizer>(vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
ustring(sep_token), ustring(pad_token), ustring(cls_token),
ustring(mask_token), tokenize_chinese_chars, strip_accents, ustring(suffix_indicator));
truncate_ = std::make_shared<TruncateStrategy>(truncation_strategy_name);
}
void KernelBertTokenizer::Compute(OrtKernelContext* context) {
@ -193,13 +259,20 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
std::vector<int64_t> input_ids;
std::vector<int64_t> token_type_ids;
if (input_data.size() == 1) {
if (input_data.size() == 1 || input_data[1].empty()) {
std::vector<int64_t> encode = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
truncate_->Truncate(encode, (max_length_ > 0 && max_length_ <= 2) ? 0 : max_length_ - 2);
input_ids = tokenizer_->AddSpecialToken(encode);
token_type_ids = tokenizer_->GenerateTypeId(encode);
} else if (input_data[0].empty()) {
std::vector<int64_t> encode = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[1])));
truncate_->Truncate(encode, (max_length_ > 0 && max_length_ <= 2) ? 0 : max_length_ - 2);
input_ids = tokenizer_->AddSpecialToken(encode);
token_type_ids = tokenizer_->GenerateTypeId(encode);
} else {
std::vector<int64_t> encode1 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
std::vector<int64_t> encode2 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[1])));
truncate_->Truncate(encode1, encode2, (max_length_ > 0 && max_length_ <= 3) ? 0 : max_length_ - 3);
input_ids = tokenizer_->AddSpecialToken(encode1, encode2);
token_type_ids = tokenizer_->GenerateTypeId(encode1, encode2);
}
@ -235,3 +308,4 @@ ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /*index*/)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
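
With the new max_length and truncation_strategy_name attributes, BertTokenizer truncates before adding special tokens, which is why the kernel budgets max_length_ - 2 for a single sentence ([CLS], [SEP]) and max_length_ - 3 for a pair ([CLS], [SEP], [SEP]). A hedged usage sketch, assuming BertTokenizer is exported through PyOrtFunction the same way BertTokenizerDecoder is in the tests below; the vocab path is a placeholder:

import numpy as np
from onnxruntime_extensions import PyOrtFunction, BertTokenizer

tok = PyOrtFunction.from_customop(
    BertTokenizer,
    vocab_file="bert_basic_cased_vocab.txt",   # placeholder vocab path
    max_length=128,
    truncation_strategy_name="longest_first")
outputs = tok(np.array(["first sentence", "second sentence"]))
# outputs include input_ids and token_type_ids for the truncated, pair-encoded input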

View file

@ -54,11 +54,31 @@ class BertTokenizer {
int32_t FindSpecialToken(ustring token);
};
class TruncateStrategy {
public:
explicit TruncateStrategy(std::string strategy_name);
void Truncate(std::vector<int64_t>& ids, int64_t max_len);
void Truncate(std::vector<int64_t>& input1, std::vector<int64_t>& input2, int64_t max_len);
private:
enum TruncateStrategyType {
LONGEST_FIRST,
ONLY_FIRST,
ONLY_SECOND,
LONGEST_FROM_BACK
} strategy_;
};
struct KernelBertTokenizer : BaseKernel {
KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
std::shared_ptr<BertTokenizer> tokenizer_;
std::shared_ptr<TruncateStrategy> truncate_;
int64_t max_length_;
};
struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {

View file

@ -1,7 +1,7 @@
#include "bert_tokenizer_decoder.hpp"
BertTokenizerDecoder::BertTokenizerDecoder(std::string vocab, ustring unk_token, ustring sep_token, ustring pad_token,
ustring cls_token, ustring mask_token, ustring suffix_indicator): unk_token_(unk_token), suffix_indicator_(suffix_indicator) {
ustring cls_token, ustring mask_token, ustring suffix_indicator) : unk_token_(unk_token), suffix_indicator_(suffix_indicator) {
auto tokens = SplitString(vocab, "\n", true);
vocab_.reserve(tokens.size());
for (int i = 0; i < tokens.size(); i++) {
@ -29,14 +29,13 @@ BertTokenizerDecoder::BertTokenizerDecoder(std::string vocab, ustring unk_token,
vocab_.push_back(token);
is_substr_.push_back(false);
}
}
}
ustring BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids) {
ustring BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool skip_special_tokens, bool clean_up_tokenization_spaces) {
ustring result;
for (auto id: ids) {
if (id == unk_token_id_ || id == sep_token_id_ || id == pad_token_id_ || id == cls_token_id_ || id == mask_token_id_) {
for (auto id : ids) {
if (skip_special_tokens && (id == unk_token_id_ || id == sep_token_id_ || id == pad_token_id_ || id == cls_token_id_ || id == mask_token_id_)) {
continue;
}
@ -54,7 +53,11 @@ ustring BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids) {
continue;
}
if (!result.empty() && !is_substr_[id]) {
// In the following situations we don't add a space:
//  - at the beginning of the output
//  - when the token is a sub-word piece (such as ##ing)
//  - around punctuation (client-side shouldn't become client - side), when clean_up_tokenization_spaces is true
if (!(result.empty() || is_substr_[id] || (clean_up_tokenization_spaces && RemoveTokenizeSpace(result, id)))) {
result.push_back(U' ');
}
@ -64,6 +67,47 @@ ustring BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids) {
return result;
}
bool BertTokenizerDecoder::RemoveTokenizeSpace(ustring& text, int64_t new_token_id) {
if (text.empty()) {
return true;
}
auto pre_char = text.back();
auto cur_char = vocab_[new_token_id][0];
// normal punctuation
if (cur_char == U'!' || cur_char == U'.' || cur_char == U'?' || cur_char == U',' || cur_char == '~' || cur_char == ':') {
return true;
}
// only remove left side space
if (cur_char == U'}' || cur_char == U']' || cur_char == U'>' || cur_char == ')') {
return true;
}
// only remove right side space
if (pre_char == U'{' || pre_char == U'[' || pre_char == U'<' || pre_char == '(' || pre_char == '$') {
return true;
}
// remove the space on both sides
if (pre_char == U'-' || pre_char == U'\'' || pre_char == U'"' || pre_char == U'/' || pre_char == U'@' || pre_char == U'\\' ||
cur_char == U'-' || cur_char == U'\'' || cur_char == U'"' || cur_char == U'/' || cur_char == U'@' || cur_char == U'\\') {
return true;
}
// remove spaces around Unicode punctuation
if (pre_char > 128 && iswpunct(pre_char)) {
return true;
}
if (cur_char > 128 && iswpunct(cur_char)) {
return true;
}
return false;
}
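
RemoveTokenizeSpace answers "should the space that would normally be inserted before this token be dropped?", so with clean_up_tokenization_spaces enabled the decoder rejoins punctuation instead of spacing it out. A hedged Python sketch of the same rules applied to plain strings (illustrative only; the C++ additionally glues around any non-ASCII punctuation character):

NO_SPACE_BEFORE = set("!.?,~:}]>)")   # trailing punctuation and closing brackets
NO_SPACE_AFTER = set("{[<($")         # opening brackets keep the next token attached
GLUE_BOTH_SIDES = set("-'\"/@\\")     # joined on both sides, e.g. client-side, don't

def needs_space(prev_text, next_token):
    if not prev_text:
        return False
    prev_c, next_c = prev_text[-1], next_token[0]
    if next_c in NO_SPACE_BEFORE or prev_c in NO_SPACE_AFTER:
        return False
    if prev_c in GLUE_BOTH_SIDES or next_c in GLUE_BOTH_SIDES:
        return False
    return True

print(needs_space("client", "-"))   # False -> decodes as "client-side", not "client - side"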
KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
@ -73,8 +117,12 @@ KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(OrtApi api, const OrtKern
std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
decoder_ = std::make_shared<BertTokenizerDecoder>(vocab, ustring(unk_token),ustring(sep_token),ustring(pad_token),
ustring(cls_token),ustring(mask_token), ustring(suffix_indicator));
use_indices_ = TryToGetAttributeWithDefault("use_indices", false);
skip_special_tokens_ = TryToGetAttributeWithDefault("skip_special_tokens", false);
clean_up_tokenization_spaces_ = TryToGetAttributeWithDefault("clean_up_tokenization_spaces", true);
decoder_ = std::make_shared<BertTokenizerDecoder>(vocab, ustring(unk_token), ustring(sep_token), ustring(pad_token),
ustring(cls_token), ustring(mask_token), ustring(suffix_indicator));
}
void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
@ -83,31 +131,36 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
OrtTensorDimensions ids_dim(ort_, ids);
if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n]." , ORT_INVALID_GRAPH);
ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
}
// const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
const OrtValue* positions = ort_.KernelContext_GetInput(context, 1);
OrtTensorDimensions positions_dim(ort_,positions);
if (!(positions_dim.empty() || (positions_dim.size() == 2 && positions_dim[1] == 2))) {
ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix." , ORT_INVALID_GRAPH);
OrtTensorDimensions positions_dim(ort_, positions);
if (use_indices_ &&
(!(positions_dim.empty() ||
(positions_dim.Size() == 0) ||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
}
const int64_t* p_positions = positions_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(positions);
const int64_t* p_positions = positions_dim.Size() == 0 ? nullptr : ort_.GetTensorData<int64_t>(positions);
std::vector<ustring> result;
std::vector<int64_t> output_dim(1);
if (p_positions == nullptr) {
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids_dim.Size())));
if (!use_indices_) {
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids_dim.Size()), skip_special_tokens_, clean_up_tokenization_spaces_));
output_dim[0] = 1;
} else {
for (int i = 0; i < positions_dim[0]; i++) {
int64_t start = p_positions[2 * i];
int64_t end = p_positions[2 * i + 1];
if (p_positions != nullptr) {
for (int i = 0; i < positions_dim[0]; i++) {
int64_t start = p_positions[2 * i];
int64_t end = p_positions[2 * i + 1];
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids + start, p_ids + end)));
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids + start, p_ids + end), skip_special_tokens_, clean_up_tokenization_spaces_));
}
output_dim[0] = positions_dim[0];
}
output_dim[0] = positions_dim[0];
}
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
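
The decoder now takes three attributes: use_indices switches between decoding the whole ids tensor and decoding the [start, end) spans supplied in the positions input, while skip_special_tokens and clean_up_tokenization_spaces are forwarded to Decode. A hedged sketch alongside the Python test below (defaults per the constructor: use_indices=False, skip_special_tokens=False, clean_up_tokenization_spaces=True; the ids and vocab path are illustrative):

import numpy as np
from onnxruntime_extensions import PyOrtFunction, BertTokenizerDecoder

decoder = PyOrtFunction.from_customop(
    BertTokenizerDecoder,
    vocab_file="bert_basic_cased_vocab.txt",   # placeholder vocab path
    use_indices=1,
    skip_special_tokens=1)
ids = np.array([101, 1109, 4962, 102], dtype=np.int64)   # illustrative token ids
spans = np.array([[1, 3]], dtype=np.int64)                # decode tokens 1..2 only
print(decoder(ids, spans))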

View file

@ -15,7 +15,7 @@ class BertTokenizerDecoder {
public:
BertTokenizerDecoder(std::string vocab, ustring unk_token, ustring sep_token, ustring pad_token,
ustring cls_token,ustring mask_token,ustring suffix_indicator);
ustring Decode(const std::vector<int64_t>& ids);
ustring Decode(const std::vector<int64_t>& ids, bool skip_special_tokens, bool clean_up_tokenization_spaces);
private:
ustring unk_token_;
@ -27,6 +27,8 @@ class BertTokenizerDecoder {
ustring suffix_indicator_;
std::vector<ustring> vocab_;
std::vector<bool> is_substr_;
bool RemoveTokenizeSpace(ustring& text, int64_t new_token_id);
};
struct KernelBertTokenizerDecoder : BaseKernel {
@ -34,6 +36,9 @@ struct KernelBertTokenizerDecoder : BaseKernel {
void Compute(OrtKernelContext* context);
private:
std::shared_ptr<BertTokenizerDecoder> decoder_;
bool use_indices_;
bool skip_special_tokens_;
bool clean_up_tokenization_spaces_;
};
struct CustomOpBertTokenizerDecoder : Ort::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {

View file

@ -70,3 +70,4 @@ TEST(strings, regex_split_begin_end_delim) {
EXPECT_EQ(expected_begin_offsets, begin_offsets);
EXPECT_EQ(expected_end_offsets, end_offsets);
}

View file

@ -141,4 +141,80 @@ TEST(tokenizer, basic_tokenizer_russia) {
BasicTokenizer tokenizer(true, true, true, true, true);
auto result = tokenizer.Tokenize(test_case);
EXPECT_EQ(result, expect_result);
}
TEST(tokenizer, basic_tokenizer) {
ustring test_case = ustring("I mean, youll need something to talk about next Sunday, right?");
std::vector<ustring> expect_result = ustring_vector_convertor({"I", "mean", ",", "you", "", "ll", "need", "something", "to", "talk", "about", "next", "Sunday", ",", "right", "?"});
BasicTokenizer tokenizer(false, true, true, true, true);
auto result = tokenizer.Tokenize(test_case);
EXPECT_EQ(result, expect_result);
}
TEST(tokenizer, truncation_one_input) {
TruncateStrategy truncate("longest_first");
std::vector<int64_t> init_vector1({1, 2, 3, 4, 5, 6, 7, 9});
std::vector<int64_t> init_vector2({1, 2, 3, 4, 5});
auto test_input = init_vector1;
truncate.Truncate(test_input, -1);
EXPECT_EQ(test_input, init_vector1);
test_input = init_vector1;
truncate.Truncate(test_input, 5);
EXPECT_EQ(test_input, std::vector<int64_t>({1, 2, 3, 4, 5}));
test_input = init_vector2;
truncate.Truncate(test_input, 6);
EXPECT_EQ(test_input, init_vector2);
}
TEST(tokenizer, truncation_longest_first) {
TruncateStrategy truncate("longest_first");
std::vector<int64_t> init_vector1({1, 2, 3, 4, 5, 6, 7, 9});
std::vector<int64_t> init_vector2({1, 2, 3, 4, 5});
auto test_input1 = init_vector1;
auto test_input2 = init_vector2;
truncate.Truncate(test_input1, test_input2, -1);
EXPECT_EQ(test_input1, init_vector1);
EXPECT_EQ(test_input2, init_vector2);
test_input1 = init_vector1;
test_input2 = init_vector2;
truncate.Truncate(test_input1, test_input2, 15);
EXPECT_EQ(test_input1, init_vector1);
EXPECT_EQ(test_input2, init_vector2);
test_input1 = init_vector1;
test_input2 = init_vector2;
truncate.Truncate(test_input1, test_input2, 14);
EXPECT_EQ(test_input1, init_vector1);
EXPECT_EQ(test_input2, init_vector2);
test_input1 = init_vector1;
test_input2 = init_vector2;
truncate.Truncate(test_input1, test_input2, 8);
EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4}));
EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4}));
test_input1 = init_vector1;
test_input2 = init_vector2;
truncate.Truncate(test_input1, test_input2, 9);
EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4}));
test_input1 = init_vector1;
test_input2 = init_vector2;
truncate.Truncate(test_input1, test_input2, 12);
EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7}));
EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5}));
test_input1 = init_vector2;
test_input2 = init_vector1;
truncate.Truncate(test_input1, test_input2, 12);
EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7}));
}

View file

@ -7,6 +7,7 @@ from onnxruntime_extensions import PyOrtFunction, BertTokenizerDecoder
bert_cased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
bert_uncased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
def _get_test_data_file(*sub_dirs):
test_dir = Path(__file__).parent
return str(test_dir.joinpath(*sub_dirs))
@ -15,10 +16,29 @@ def _get_test_data_file(*sub_dirs):
def _run_basic_case(input, vocab_path):
t2stc = PyOrtFunction.from_customop(BertTokenizerDecoder, vocab_file=vocab_path)
ids = np.array(bert_cased_tokenizer.encode(input), dtype=np.int64)
position = np.array([[0, ids.size]], dtype=np.int64)
position = np.array([[]], dtype=np.int64)
result = t2stc(ids, position)
np.testing.assert_array_equal(result[0], bert_cased_tokenizer.decode(bert_cased_tokenizer.encode(input), True, False))
np.testing.assert_array_equal(result[0],
bert_cased_tokenizer.decode(bert_cased_tokenizer.encode(input)))
def _run_indices_case(input, indices, vocab_path):
t2stc = PyOrtFunction.from_customop(BertTokenizerDecoder, vocab_file=vocab_path, use_indices=1)
ids = np.array(bert_cased_tokenizer.encode(input), dtype=np.int64)
position = np.array(indices, dtype=np.int64)
expect_result = []
for index in indices:
if len(index) > 0:
result = bert_cased_tokenizer.decode(bert_cased_tokenizer.encode(input)[index[0]:index[1]])
result = result.split(' ')
if result[0].startswith('##'):
result.pop(0)
expect_result.append(" ".join(result))
result = t2stc(ids, position)
np.testing.assert_array_equal(result, expect_result, True, False)
class TestBertTokenizerDecoder(unittest.TestCase):
@ -34,6 +54,11 @@ class TestBertTokenizerDecoder(unittest.TestCase):
_run_basic_case(input="cat isnot playing toyssss",
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_indices_case(input="cat isnot playing toyssss", indices=[[]],
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_indices_case(input="cat isnot playing toyssss", indices=[[1, 2], [3, 5]],
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
if __name__ == "__main__":

View file

@ -19,21 +19,26 @@ def _run_segment_extraction(input, expect_position, expect_value):
class TestSegmentExtraction(unittest.TestCase):
def test_text_to_case1(self):
inputs = np.array([0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3], dtype=np.int64)
position = [[2, 4], [4,7], [7, 11]]
inputs = np.array([[0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3]], dtype=np.int64)
position = [[2, 4], [4, 7], [7, 11]]
value = [1, 2, 3]
_run_segment_extraction(inputs, position, value)
inputs = np.array([1, 1, 0, 0, 2, 2, 2, 3, 3, 3, 0, 5], dtype=np.int64)
inputs = np.array([[1, 1, 0, 0, 2, 2, 2, 3, 3, 3, 0, 5]], dtype=np.int64)
position = [[0, 2], [4, 7], [7, 10], [11, 12]]
value = [1, 2, 3, 5]
_run_segment_extraction(inputs, position, value)
inputs = np.array([1, 2, 4, 5], dtype=np.int64)
inputs = np.array([[1, 2, 4, 5]], dtype=np.int64)
position = [[0, 1], [1, 2], [2, 3], [3, 4]]
value = [1, 2, 4, 5]
_run_segment_extraction(inputs, position, value)
inputs = np.array([[0, 0, 1, 1, 1, 0, 0, 0, 0, 3, 3, 3, 0]], dtype=np.int64)
position = [[2, 5], [9, 12]]
value = [1, 3]
_run_segment_extraction(inputs, position, value)
if __name__ == "__main__":
unittest.main()