Improve recent checkin operators (#144)
* update
* update
* update
* remove tokenizer space
* fix bugs

Co-authored-by: Ze Tao <zetao@microsoft.com>
Parent: 2842d2208e
Commit: cce66310b2
@@ -51,6 +51,10 @@ struct OrtTensorDimensions : std::vector<int64_t> {
   }
+  const std::vector<int64_t>& GetDims() const { return *this; }
   int64_t Size() const {
+    if (empty()) {
+      return 0;
+    }

     int64_t s = 1.;
     for (auto it = begin(); it != end(); ++it)
       s *= *it;
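The new empty() guard makes Size() return 0 when no dimensions are present instead of falling through to the product loop; a bare product over zero dims would be 1, which would make an absent optional input look like a one-element tensor. A minimal Python sketch of that pitfall (math.prod is used only for illustration, it is not part of this repo):

import math

dims_empty = []        # no dimensions reported for the input
dims_scalar = [1]      # a genuine single-element tensor

# A naive product treats "no dims" the same as a one-element shape:
print(math.prod(dims_empty))   # 1
print(math.prod(dims_scalar))  # 1

# The guard in OrtTensorDimensions::Size() returns 0 for the empty case,
# so callers such as the decoder kernel below can use Size() == 0 to
# detect an empty/optional input.
print(0 if not dims_empty else math.prod(dims_empty))  # 0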
@@ -112,7 +112,7 @@ class BlingFireSentenceBreaker(CustomOp):
 class SegmentExtraction(CustomOp):
     @classmethod
     def get_inputs(cls):
-        return [cls.io_def("input", onnx.TensorProto.INT64, [None])]
+        return [cls.io_def("input", onnx.TensorProto.INT64, [None, None])]

     @classmethod
     def get_outputs(cls):
@@ -28,7 +28,7 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
     }

     // push end position
-    if (i == input_dim.size() || p_data[i + 1] != p_data[i]) {
+    if (i == input_dim.Size() || p_data[i + 1] != p_data[i]) {
       segment_position.push_back(i + 1);
     }
   }
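The fix matters because OrtTensorDimensions derives from std::vector<int64_t>, so size() is the number of dimensions (now 2 for the [1, n] input declared above) while the new Size() is the element count the loop actually walks. A small plain-Python illustration of the difference (not the repo's code):

dims = [1, 11]          # shape of the segment-id input after the change
num_dims = len(dims)    # what size() returns: the rank, 2
num_elems = 1
for d in dims:
    num_elems *= d      # what Size() returns: the element count, 11

# The end-of-input check in the kernel has to compare the running index
# against the element count, not the rank, once the input became 2-D.
print(num_dims, num_elems)  # 2 11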
@@ -53,21 +53,21 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
       continue;
     }

-    if (tokenize_punctuation_ && ::ispunct(c)) {
+    if (tokenize_punctuation_ && ::iswpunct(c)) {
       push_current_token_and_clear();
       push_single_char_and_clear(c);
       continue;
     }

     // split by space
-    if (::isspace(c)) {
+    if (::iswspace(c)) {
       push_current_token_and_clear();
       continue;
     }

     // iscntrl will judge \t\f\n\r as control char
     // but it has been filter by isspace(c)
-    if (remove_control_chars_ && ::iscntrl(c)) {
+    if (remove_control_chars_ && ::iswcntrl(c)) {
       continue;
     }

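The switch from the narrow ::ispunct/::isspace/::iscntrl to their wide-character counterparts matters because the tokenizer walks Unicode code points: passing values above the unsigned-char range to the narrow functions is undefined behavior, and non-ASCII punctuation would never be classified, so it would stay glued to the surrounding token. A short Python illustration of the classification gap (unicodedata is used here only for the example):

import unicodedata

ch = "？"  # full-width question mark, common in Chinese/Japanese text

print(ch.isascii())              # False: outside the narrow-char range
print(unicodedata.category(ch))  # 'Po': Unicode "Punctuation, other"

# A narrow ispunct()-style check only covers ASCII punctuation, so this
# character would not be split off; a wide iswpunct()-style check is,
# which is what the BasicTokenizer change above enables.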
@@ -163,6 +163,67 @@ int32_t BertTokenizer::FindSpecialToken(ustring token) {
   return it->second;
 }

+TruncateStrategy::TruncateStrategy(std::string strategy_name) {
+  if (strategy_name == "longest_first") {
+    strategy_ = TruncateStrategyType::LONGEST_FIRST;
+  } else if (strategy_name == "only_first") {
+    strategy_ = TruncateStrategyType::ONLY_FIRST;
+  } else if (strategy_name == "only_second") {
+    strategy_ = TruncateStrategyType::ONLY_SECOND;
+  } else if (strategy_name == "longest_from_back") {
+    strategy_ = TruncateStrategyType::LONGEST_FROM_BACK;
+  }
+}
+
+void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int64_t max_len) {
+  if (max_len < 0 || max_len >= ids.size()) {
+    return;
+  }
+
+  ids.resize(max_len);
+}
+
+void TruncateStrategy::Truncate(std::vector<int64_t>& input1, std::vector<int64_t>& input2, int64_t max_len) {
+  if (max_len < 0 || (input1.size() + input2.size() <= max_len)) {
+    return;
+  }
+
+  auto input1_keep_len = input1.size();
+  auto input2_keep_len = input2.size();
+  auto half_max_len = max_len / 2;
+
+  switch (strategy_) {
+    case TruncateStrategyType::LONGEST_FIRST:
+    case TruncateStrategyType::LONGEST_FROM_BACK:
+      if ((input1_keep_len > half_max_len) && (input2_keep_len > half_max_len)) {
+        input1_keep_len = max_len - half_max_len;
+        input2_keep_len = half_max_len;
+      } else if (input2_keep_len > input1_keep_len) {
+        input2_keep_len = max_len - input1_keep_len;
+      } else {
+        input1_keep_len = max_len - input2_keep_len;
+      }
+
+      if (strategy_ == TruncateStrategyType::LONGEST_FIRST) {
+        input1.resize(input1_keep_len);
+        input2.resize(input2_keep_len);
+      } else {
+        input1.erase(input1.begin(), input1.end() - input1_keep_len);
+        input2.erase(input2.begin(), input2.end() - input2_keep_len);
+      }
+
+      return;
+    case TruncateStrategyType::ONLY_FIRST:
+      return;
+    case TruncateStrategyType::ONLY_SECOND:
+      return;
+    default:
+      return;
+  }
+}
+
 KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
   std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
   bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
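For the shared longest_first/longest_from_back branch, the pair is trimmed toward a roughly even split: when both sequences exceed half of max_len, the first keeps max_len - max_len/2 and the second keeps max_len/2; otherwise only the longer one is cut. A small Python sketch of that arithmetic, mirroring the C++ above (not taken from the repo):

def pair_keep_lengths(len1, len2, max_len):
    # Mirrors the LONGEST_FIRST / LONGEST_FROM_BACK branch above.
    if max_len < 0 or len1 + len2 <= max_len:
        return len1, len2
    half = max_len // 2
    if len1 > half and len2 > half:
        return max_len - half, half
    if len2 > len1:
        return len1, max_len - len1
    return max_len - len2, len2

# With sequences of 8 and 5 ids and a budget of 9 tokens, the first keeps
# 5 and the second keeps 4, matching the truncation_longest_first test below.
print(pair_keep_lengths(8, 5, 9))   # (5, 4)
print(pair_keep_lengths(8, 5, 12))  # (7, 5)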
@@ -175,10 +236,15 @@ KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info)
   bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
   bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
   std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
+  std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
+  max_length_ = TryToGetAttributeWithDefault("max_length", int64_t(-1));
+

   tokenizer_ = std::make_shared<BertTokenizer>(vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
                                                ustring(sep_token), ustring(pad_token),ustring(cls_token),
                                                ustring(mask_token), tokenize_chinese_chars, strip_accents, ustring(suffix_indicator));
+
+  truncate_ = std::make_shared<TruncateStrategy>(truncation_strategy_name);
 }

 void KernelBertTokenizer::Compute(OrtKernelContext* context) {
@@ -193,13 +259,20 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
   std::vector<int64_t> input_ids;
   std::vector<int64_t> token_type_ids;

-  if (input_data.size() == 1) {
+  if (input_data.size() == 1 || input_data[1].empty()) {
     std::vector<int64_t> encode = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
+    truncate_->Truncate(encode, (max_length_ > 0 && max_length_ <= 2) ? 0 : max_length_ - 2);
     input_ids = tokenizer_->AddSpecialToken(encode);
     token_type_ids = tokenizer_->GenerateTypeId(encode);
+  } else if (input_data[0].empty()) {
+    std::vector<int64_t> encode = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[1])));
+    truncate_->Truncate(encode, (max_length_ > 0 && max_length_ <= 2) ? 0 : max_length_ - 2);
+    input_ids = tokenizer_->AddSpecialToken(encode);
+    token_type_ids = tokenizer_->GenerateTypeId(encode);
   } else {
     std::vector<int64_t> encode1 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
     std::vector<int64_t> encode2 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[1])));
+    truncate_->Truncate(encode1, encode2, (max_length_ > 0 && max_length_ <= 3) ? 0 : max_length_ - 3);
     input_ids = tokenizer_->AddSpecialToken(encode1, encode2);
     token_type_ids = tokenizer_->GenerateTypeId(encode1, encode2);
   }
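The truncation budget subtracts the special tokens that AddSpecialToken inserts afterwards: 2 for a single sequence ([CLS] ... [SEP]) and 3 for a pair ([CLS] ... [SEP] ... [SEP]); when max_length_ cannot even hold the special tokens, the budget collapses to 0, and a non-positive max_length_ leaves the input untruncated. A quick hedged sketch of that accounting (the helper name is made up for illustration):

def truncation_budget(max_length, is_pair):
    # Reserve room for the special tokens that are appended later:
    # [CLS] x [SEP] for one sequence, [CLS] a [SEP] b [SEP] for a pair.
    reserved = 3 if is_pair else 2
    if max_length <= 0:
        return -1            # negative budget: Truncate() is a no-op
    if max_length <= reserved:
        return 0             # no room left for real tokens
    return max_length - reserved

print(truncation_budget(-1, False))  # -1 -> keep everything
print(truncation_budget(2, False))   # 0
print(truncation_budget(128, True))  # 125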
@@ -235,3 +308,4 @@ ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /*index*/)
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
 };

+
@@ -54,11 +54,31 @@ class BertTokenizer {
   int32_t FindSpecialToken(ustring token);
 };

+
+class TruncateStrategy {
+ public:
+  explicit TruncateStrategy(std::string strategy_name);
+  void Truncate(std::vector<int64_t>& ids, int64_t max_len);
+  void Truncate(std::vector<int64_t>& input1, std::vector<int64_t>& input2, int64_t max_len);
+
+ private:
+  enum TruncateStrategyType{
+    LONGEST_FIRST,
+    ONLY_FIRST,
+    ONLY_SECOND,
+    LONGEST_FROM_BACK
+  }strategy_;
+};
+
+
 struct KernelBertTokenizer : BaseKernel {
   KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info);
   void Compute(OrtKernelContext* context);
  private:
   std::shared_ptr<BertTokenizer> tokenizer_;
+  std::shared_ptr<TruncateStrategy> truncate_;
+  int max_length_;
 };

 struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
@@ -1,7 +1,7 @@
 #include "bert_tokenizer_decoder.hpp"

 BertTokenizerDecoder::BertTokenizerDecoder(std::string vocab, ustring unk_token, ustring sep_token, ustring pad_token,
-                                           ustring cls_token, ustring mask_token, ustring suffix_indicator): unk_token_(unk_token), suffix_indicator_(suffix_indicator) {
+                                           ustring cls_token, ustring mask_token, ustring suffix_indicator) : unk_token_(unk_token), suffix_indicator_(suffix_indicator) {
   auto tokens = SplitString(vocab, "\n", true);
   vocab_.reserve(tokens.size());
   for (int i = 0; i < tokens.size(); i++) {
@@ -29,14 +29,13 @@ BertTokenizerDecoder::BertTokenizerDecoder(std::string vocab, ustring unk_token,
       vocab_.push_back(token);
       is_substr_.push_back(false);
     }
-
   }
 }

-ustring BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids) {
+ustring BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool skip_special_tokens, bool clean_up_tokenization_spaces) {
   ustring result;
-  for (auto id: ids) {
-    if (id == unk_token_id_ || id == sep_token_id_ || id == pad_token_id_ || id == cls_token_id_ || id == mask_token_id_) {
+  for (auto id : ids) {
+    if (skip_special_tokens && (id == unk_token_id_ || id == sep_token_id_ || id == pad_token_id_ || id == cls_token_id_ || id == mask_token_id_)) {
       continue;
     }

@@ -54,7 +53,11 @@ ustring BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids) {
       continue;
     }

-    if (!result.empty() && !is_substr_[id]) {
+    // At following situations, we needn't add space
+    // we needn't add a space at the beginning of the output
+    // we needn't add a space when the token is a substr (such as ##ing)
+    // we needn't add a space at the left or right of punctuation (such as client-side shouldn't be client - side), when clean_up_tokenization_spaces is true
+    if (!(result.empty() || is_substr_[id] || (clean_up_tokenization_spaces && RemoveTokenizeSpace(result, id)))) {
       result.push_back(U' ');
     }

@@ -64,6 +67,47 @@ ustring BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids) {
   return result;
 }

+bool BertTokenizerDecoder::RemoveTokenizeSpace(ustring& text, int64_t new_token_id) {
+  if (text.empty()) {
+    return true;
+  }
+
+  auto pre_char = text.back();
+  auto cur_char = vocab_[new_token_id][0];
+
+  // normal punctuation
+  if (cur_char == U'!' || cur_char == U'.' || cur_char == U'?' || cur_char == U',' || cur_char == '~' || cur_char == ':') {
+    return true;
+  }
+
+  // only remove left side space
+  if (cur_char == U'}' || cur_char == U']' || cur_char == U'>' || cur_char == ')') {
+    return true;
+  }
+
+  // only remove right side space
+  if (pre_char == U'{' || pre_char == U'[' || pre_char == U'<' || pre_char == '(' || pre_char == '$') {
+    return true;
+  }
+
+  // remove both side space
+  if (pre_char == U'-' || pre_char == U'\'' || pre_char == U'"' || pre_char == U'/' || pre_char == U'@' || pre_char == U'\\' ||
+      cur_char == U'-' || cur_char == U'\'' || cur_char == U'"' || cur_char == U'/' || cur_char == U'@' || cur_char == U'\\') {
+    return true;
+  }
+
+  // remove both space beside unicode punctuation
+  if (pre_char > 128 && iswpunct(pre_char)) {
+    return true;
+  }
+
+  if (cur_char > 128 && iswpunct(cur_char)) {
+    return true;
+  }
+
+  return false;
+}
+
 KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
   std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
   std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
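Together with the new comments in Decode(), RemoveTokenizeSpace() suppresses the space that would otherwise be inserted before a token whenever punctuation sits on either side of the join, so a sequence like client - side is stitched back to client-side when clean_up_tokenization_spaces is enabled. A rough Python sketch of the joining rule (simplified, not the repo's code):

def join_wordpieces(tokens, clean_up_tokenization_spaces=True):
    # Simplified sketch: drop the space before a token when it is a
    # "##" continuation or when punctuation sits on either side of the join.
    no_space_around = set("-'\"/@\\")
    out = ""
    for tok in tokens:
        if tok.startswith("##"):
            out += tok[2:]
            continue
        if out and not (clean_up_tokenization_spaces and
                        (tok[0] in ".,!?~:)]}" or
                         out[-1] in "([{<$" or
                         tok[0] in no_space_around or
                         out[-1] in no_space_around)):
            out += " "
        out += tok
    return out

print(join_wordpieces(["client", "-", "side", "play", "##ing"]))  # client-side playing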
@@ -73,8 +117,12 @@ KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(OrtApi api, const OrtKern
   std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
   std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));

-  decoder_ = std::make_shared<BertTokenizerDecoder>(vocab, ustring(unk_token),ustring(sep_token),ustring(pad_token),
-                                                     ustring(cls_token),ustring(mask_token), ustring(suffix_indicator));
+  use_indices_ = TryToGetAttributeWithDefault("use_indices", false);
+  skip_special_tokens_ = TryToGetAttributeWithDefault("skip_special_tokens", false);
+  clean_up_tokenization_spaces_ = TryToGetAttributeWithDefault("clean_up_tokenization_spaces", true);
+
+  decoder_ = std::make_shared<BertTokenizerDecoder>(vocab, ustring(unk_token), ustring(sep_token), ustring(pad_token),
+                                                     ustring(cls_token), ustring(mask_token), ustring(suffix_indicator));
 }

 void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
@@ -83,31 +131,36 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
   OrtTensorDimensions ids_dim(ort_, ids);

   if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
-    ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n]." , ORT_INVALID_GRAPH);
+    ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
   }

   // const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
   const OrtValue* positions = ort_.KernelContext_GetInput(context, 1);
-  OrtTensorDimensions positions_dim(ort_,positions);
-  if (!(positions_dim.empty() || (positions_dim.size() == 2 && positions_dim[1] == 2))) {
-    ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix." , ORT_INVALID_GRAPH);
+  OrtTensorDimensions positions_dim(ort_, positions);
+  if (use_indices_ &&
+      (!(positions_dim.empty() ||
+         (positions_dim.Size() == 0) ||
+         (positions_dim.size() == 2 && positions_dim[1] == 2)))) {
+    ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
   }

-  const int64_t* p_positions = positions_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(positions);
+  const int64_t* p_positions = positions_dim.Size() == 0 ? nullptr : ort_.GetTensorData<int64_t>(positions);

   std::vector<ustring> result;
   std::vector<int64_t> output_dim(1);
-  if (p_positions == nullptr) {
-    result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids_dim.Size())));
+  if (!use_indices_) {
+    result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids_dim.Size()), skip_special_tokens_, clean_up_tokenization_spaces_));
     output_dim[0] = 1;
   } else {
-    for (int i = 0; i < positions_dim[0]; i++) {
-      int64_t start = p_positions[2 * i];
-      int64_t end = p_positions[2 * i + 1];
+    if (p_positions != nullptr) {
+      for (int i = 0; i < positions_dim[0]; i++) {
+        int64_t start = p_positions[2 * i];
+        int64_t end = p_positions[2 * i + 1];

-      result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids + start, p_ids + end)));
+        result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids + start, p_ids + end), skip_special_tokens_, clean_up_tokenization_spaces_));
+      }
+      output_dim[0] = positions_dim[0];
     }
-    output_dim[0] = positions_dim[0];
   }
   OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
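With the new attributes, the decoder runs either in whole-sequence mode (use_indices off, the positions input is ignored) or in span mode (use_indices on, positions given as an [n, 2] matrix of [start, end) rows). A hedged usage sketch modeled on the Python test further below; the vocab path and the id values are placeholders, and only vocab_file/use_indices are attributes actually exercised by the tests:

import numpy as np
from onnxruntime_extensions import PyOrtFunction, BertTokenizerDecoder

# Whole-sequence decoding: positions may be empty.
decode_all = PyOrtFunction.from_customop(BertTokenizerDecoder,
                                         vocab_file="bert_basic_cased_vocab.txt")  # placeholder path
ids = np.array([101, 8572, 1110, 1136, 1773, 102], dtype=np.int64)  # hypothetical token ids
print(decode_all(ids, np.array([[]], dtype=np.int64)))

# Span decoding: each [start, end) row yields one decoded string.
decode_spans = PyOrtFunction.from_customop(BertTokenizerDecoder,
                                           vocab_file="bert_basic_cased_vocab.txt",
                                           use_indices=1)
print(decode_spans(ids, np.array([[1, 3], [3, 5]], dtype=np.int64)))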
@@ -15,7 +15,7 @@ class BertTokenizerDecoder {
  public:
   BertTokenizerDecoder(std::string vocab, ustring unk_token, ustring sep_token, ustring pad_token,
                        ustring cls_token,ustring mask_token,ustring suffix_indicator);
-  ustring Decode(const std::vector<int64_t>& ids);
+  ustring Decode(const std::vector<int64_t>& ids, bool skip_special_tokens, bool clean_up_tokenization_spaces);

  private:
   ustring unk_token_;
@@ -27,6 +27,8 @@ class BertTokenizerDecoder {
   ustring suffix_indicator_;
   std::vector<ustring> vocab_;
   std::vector<bool> is_substr_;
+
+  bool RemoveTokenizeSpace(ustring& text, int64_t new_token_id);
 };

 struct KernelBertTokenizerDecoder : BaseKernel {
@@ -34,6 +36,9 @@ struct KernelBertTokenizerDecoder : BaseKernel {
   void Compute(OrtKernelContext* context);
  private:
   std::shared_ptr<BertTokenizerDecoder> decoder_;
+  bool use_indices_;
+  bool skip_special_tokens_;
+  bool clean_up_tokenization_spaces_;
 };

 struct CustomOpBertTokenizerDecoder : Ort::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
@@ -70,3 +70,4 @@ TEST(strings, regex_split_begin_end_delim) {
   EXPECT_EQ(expected_begin_offsets, begin_offsets);
   EXPECT_EQ(expected_end_offsets, end_offsets);
 }
+
@@ -141,4 +141,80 @@ TEST(tokenizer, basic_tokenizer_russia) {
   BasicTokenizer tokenizer(true, true, true, true, true);
   auto result = tokenizer.Tokenize(test_case);
   EXPECT_EQ(result, expect_result);
 }
+
+TEST(tokenizer, basic_tokenizer) {
+  ustring test_case = ustring("I mean, you’ll need something to talk about next Sunday, right?");
+  std::vector<ustring> expect_result = ustring_vector_convertor({"I", "mean", ",", "you", "’", "ll", "need", "something", "to", "talk", "about", "next", "Sunday", ",", "right", "?"});
+  BasicTokenizer tokenizer(false, true, true, true, true);
+  auto result = tokenizer.Tokenize(test_case);
+  EXPECT_EQ(result, expect_result);
+}
+
+TEST(tokenizer, truncation_one_input) {
+  TruncateStrategy truncate("longest_first");
+
+  std::vector<int64_t> init_vector1({1, 2, 3, 4, 5, 6, 7, 9});
+  std::vector<int64_t> init_vector2({1, 2, 3, 4, 5});
+
+  auto test_input = init_vector1;
+  truncate.Truncate(test_input, -1);
+  EXPECT_EQ(test_input, init_vector1);
+
+  test_input = init_vector1;
+  truncate.Truncate(test_input, 5);
+  EXPECT_EQ(test_input, std::vector<int64_t>({1, 2, 3, 4, 5}));
+
+  test_input = init_vector2;
+  truncate.Truncate(test_input, 6);
+  EXPECT_EQ(test_input, init_vector2);
+}
+
+TEST(tokenizer, truncation_longest_first) {
+  TruncateStrategy truncate("longest_first");
+
+  std::vector<int64_t> init_vector1({1, 2, 3, 4, 5, 6, 7, 9});
+  std::vector<int64_t> init_vector2({1, 2, 3, 4, 5});
+
+  auto test_input1 = init_vector1;
+  auto test_input2 = init_vector2;
+  truncate.Truncate(test_input1, test_input2, -1);
+  EXPECT_EQ(test_input1, init_vector1);
+  EXPECT_EQ(test_input2, init_vector2);
+
+  test_input1 = init_vector1;
+  test_input2 = init_vector2;
+  truncate.Truncate(test_input1, test_input2, 15);
+  EXPECT_EQ(test_input1, init_vector1);
+  EXPECT_EQ(test_input2, init_vector2);
+
+  test_input1 = init_vector1;
+  test_input2 = init_vector2;
+  truncate.Truncate(test_input1, test_input2, 14);
+  EXPECT_EQ(test_input1, init_vector1);
+  EXPECT_EQ(test_input2, init_vector2);
+
+  test_input1 = init_vector1;
+  test_input2 = init_vector2;
+  truncate.Truncate(test_input1, test_input2, 8);
+  EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4}));
+  EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4}));
+
+  test_input1 = init_vector1;
+  test_input2 = init_vector2;
+  truncate.Truncate(test_input1, test_input2, 9);
+  EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
+  EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4}));
+
+  test_input1 = init_vector1;
+  test_input2 = init_vector2;
+  truncate.Truncate(test_input1, test_input2, 12);
+  EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7}));
+  EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5}));
+
+  test_input1 = init_vector2;
+  test_input2 = init_vector1;
+  truncate.Truncate(test_input1, test_input2, 12);
+  EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
+  EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5, 6 ,7}));
+}
@@ -7,6 +7,7 @@ from onnxruntime_extensions import PyOrtFunction, BertTokenizerDecoder
 bert_cased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
+bert_uncased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')


 def _get_test_data_file(*sub_dirs):
     test_dir = Path(__file__).parent
     return str(test_dir.joinpath(*sub_dirs))
@@ -15,10 +16,29 @@ def _get_test_data_file(*sub_dirs):
 def _run_basic_case(input, vocab_path):
     t2stc = PyOrtFunction.from_customop(BertTokenizerDecoder, vocab_file=vocab_path)
     ids = np.array(bert_cased_tokenizer.encode(input), dtype=np.int64)
-    position = np.array([[0, ids.size]], dtype=np.int64)
+    position = np.array([[]], dtype=np.int64)

     result = t2stc(ids, position)
-    np.testing.assert_array_equal(result[0], bert_cased_tokenizer.decode(bert_cased_tokenizer.encode(input), True, False))
+    np.testing.assert_array_equal(result[0],
+                                  bert_cased_tokenizer.decode(bert_cased_tokenizer.encode(input)))
+
+
+def _run_indices_case(input, indices, vocab_path):
+    t2stc = PyOrtFunction.from_customop(BertTokenizerDecoder, vocab_file=vocab_path, use_indices=1)
+    ids = np.array(bert_cased_tokenizer.encode(input), dtype=np.int64)
+    position = np.array(indices, dtype=np.int64)
+
+    expect_result = []
+    for index in indices:
+        if len(index) > 0:
+            result = bert_cased_tokenizer.decode(bert_cased_tokenizer.encode(input)[index[0]:index[1]])
+            result = result.split(' ')
+            if result[0].startswith('##'):
+                result.pop(0)
+            expect_result.append(" ".join(result))
+
+    result = t2stc(ids, position)
+    np.testing.assert_array_equal(result, expect_result, True, False)


 class TestBertTokenizerDecoder(unittest.TestCase):
@@ -34,6 +54,11 @@ class TestBertTokenizerDecoder(unittest.TestCase):
         _run_basic_case(input="cat isnot playing toyssss",
                         vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))

+        _run_indices_case(input="cat isnot playing toyssss", indices=[[]],
+                          vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
+
+        _run_indices_case(input="cat isnot playing toyssss", indices=[[1, 2], [3, 5]],
+                          vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))


 if __name__ == "__main__":
@@ -19,21 +19,26 @@ def _run_segment_extraction(input, expect_position, expect_value):
 class TestSegmentExtraction(unittest.TestCase):

     def test_text_to_case1(self):
-        inputs = np.array([0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3], dtype=np.int64)
-        position = [[2, 4], [4,7], [7, 11]]
+        inputs = np.array([[0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3]], dtype=np.int64)
+        position = [[2, 4], [4, 7], [7, 11]]
         value = [1, 2, 3]
         _run_segment_extraction(inputs, position, value)

-        inputs = np.array([1, 1, 0, 0, 2, 2, 2, 3, 3, 3, 0, 5], dtype=np.int64)
+        inputs = np.array([[1, 1, 0, 0, 2, 2, 2, 3, 3, 3, 0, 5]], dtype=np.int64)
         position = [[0, 2], [4, 7], [7, 10], [11, 12]]
         value = [1, 2, 3, 5]
         _run_segment_extraction(inputs, position, value)

-        inputs = np.array([1, 2, 4, 5], dtype=np.int64)
+        inputs = np.array([[1, 2, 4, 5]], dtype=np.int64)
         position = [[0, 1], [1, 2], [2, 3], [3, 4]]
         value = [1, 2, 4, 5]
         _run_segment_extraction(inputs, position, value)

+        inputs = np.array([[0, 0, 1, 1, 1, 0, 0, 0, 0, 3, 3, 3, 0]], dtype=np.int64)
+        position = [[2, 5], [9, 12]]
+        value = [1, 3]
+        _run_segment_extraction(inputs, position, value)
+

 if __name__ == "__main__":
     unittest.main()