onnxruntime-extensions/operators/tokenizer/bert_tokenizer.hpp

120 lines
4.5 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include "ustring.h"
#include "string_utils.h"
#include "string_tensor.h"
#include "basic_tokenizer.hpp"
#include <unordered_map>
#include <list>
// Vocabulary lookup table used by the BERT / wordpiece tokenizers.
//
// Holds an owned copy of the vocabulary text in raw_vocab_ and an index
// (vocab_) whose keys are non-owning std::string_view values.
// NOTE(review): the keys presumably view slices of raw_vocab_ — confirm in the
// constructor definition.
class BertTokenizerVocab final {
 public:
  // Builds the lookup table from `vocab` (text format is parsed in the .cpp).
  explicit BertTokenizerVocab(std::string_view vocab);

  // Copying/moving is disabled on purpose: vocab_ stores string_view keys.
  // An implicitly generated copy would duplicate those views while they still
  // reference the *source* object's buffer, so the copy dangles once the
  // source is destroyed. A move is only safe if std::string's move preserves
  // the buffer address, which the standard does not guarantee (short-string
  // optimization). Share instances through std::shared_ptr instead, as the
  // tokenizer classes already do.
  BertTokenizerVocab(const BertTokenizerVocab&) = delete;
  BertTokenizerVocab& operator=(const BertTokenizerVocab&) = delete;
  BertTokenizerVocab(BertTokenizerVocab&&) = delete;
  BertTokenizerVocab& operator=(BertTokenizerVocab&&) = delete;

  // Presumably reports whether `token` is present in the vocabulary
  // (definition in the .cpp).
  bool FindToken(const ustring& token);

  // Presumably returns true and stores the token's id in `token_id` when
  // `token` is found (definition in the .cpp).
  bool FindTokenId(const ustring& token, int32_t& token_id);

  // Returns the id for `token`; the result for an unknown token is defined in
  // the .cpp — TODO confirm.
  int32_t FindTokenId(const ustring& token);

 private:
  // Owned vocabulary text.
  std::string raw_vocab_;
  // Token -> id index with non-owning keys (see the copy/move note above).
  std::unordered_map<std::string_view, int32_t> vocab_;
};
// Applies a named truncation policy to one or two sequences of token ids.
class TruncateStrategy final {
 public:
  // Chooses the policy from its name; the name-to-policy mapping lives in the
  // .cpp.
  explicit TruncateStrategy(std::string_view strategy_name);

  // Truncates `ids` in place with respect to the length limit `max_len`.
  void Truncate(std::vector<int64_t>& ids, int32_t max_len);

  // Truncates the pair (`ids1`, `ids2`) in place with respect to `max_len`;
  // how the budget is split between the two depends on the selected policy.
  void Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>& ids2, int32_t max_len);

 private:
  // Supported policies (implicit values 0..3, in this order).
  enum TruncateStrategyType {
    LONGEST_FIRST,
    ONLY_FIRST,
    ONLY_SECOND,
    LONGEST_FROM_BACK
  };

  // Policy selected at construction time.
  TruncateStrategyType strategy_;
};
// TODO: merge with the implementation of word piece tokenizer
class WordpieceTokenizer final {
 public:
  // A list of (size_t, size_t) pairs; presumably (begin, end) offsets of the
  // produced pieces — confirm in the .cpp.
  using OffsetMappingType = std::list<std::pair<size_t, size_t>>;

  // `vocab` is shared (not copied). `unk_token` and `suffix_indicator` are
  // taken by value for storage. `max_input_chars_per_word` defaults to 100 and
  // is kept in max_input_chars_per_word_; its use is defined in the .cpp.
  WordpieceTokenizer(std::shared_ptr<BertTokenizerVocab> vocab,
                     ustring unk_token,
                     ustring suffix_indicator,
                     int max_input_chars_per_word = 100);

  // Tokenizes a single string. `offset_map` is written only as the .cpp
  // decides, presumably when `compute_offset_mapping` is true.
  std::vector<ustring> Tokenize(const ustring& text,
                                std::list<OffsetMappingType>& offset_map,
                                bool compute_offset_mapping);

  // Tokenizes an already pre-split sequence of tokens.
  std::vector<ustring> Tokenize(const std::vector<ustring>& tokens,
                                std::list<OffsetMappingType>& offset_map,
                                bool compute_offset_mapping);

  // Maps tokens to int64 vocabulary ids.
  std::vector<int64_t> Encode(const std::vector<ustring>& tokens);

 private:
  // Splits one `token` into pieces, appending to `tokenized_result`.
  void GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result);

  // Data members (declaration order preserved).
  int64_t max_input_chars_per_word_;
  ustring suffix_indicator_;
  ustring unk_token_;
  int32_t unk_token_id_;
  std::shared_ptr<BertTokenizerVocab> vocab_;
};
// End-to-end BERT tokenizer built on a shared vocabulary, an optional basic
// tokenizer, a wordpiece tokenizer and a truncation policy.
class BertTokenizer final {
 public:
  // A list of (size_t, size_t) pairs; presumably (begin, end) offsets — confirm
  // in the .cpp.
  using OffsetMappingType = std::list<std::pair<size_t, size_t>>;

  BertTokenizer(const std::string& vocab,
                bool do_lower_case,
                bool do_basic_tokenize,
                ustring unk_token,
                ustring sep_token,
                ustring pad_token,
                ustring cls_token,
                ustring mask_token,
                bool tokenize_chinese_chars,
                bool strip_accents,
                ustring suffix_indicator,
                int32_t max_len,
                const std::string& truncation_strategy);

  // ---- text -> tokens -> ids ----
  std::vector<ustring> Tokenize(const ustring& text,
                                std::list<OffsetMappingType>& offset_map,
                                bool compute_offset_mapping);
  std::vector<int64_t> Encode(const std::vector<ustring>& tokens);

  // ---- length limiting (mutates the passed sequences) ----
  void Truncate(std::vector<int64_t>& ids);
  void Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>& ids2);

  // ---- special tokens and token-type ids, single sequence / pair ----
  std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids);
  std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids1,
                                       const std::vector<int64_t>& ids2);
  std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids);
  std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids1,
                                      const std::vector<int64_t>& ids2);

 private:
  // Resolved ids of the special tokens (zero until set by the constructor).
  int32_t unk_token_id_{0};
  int32_t sep_token_id_{0};
  int32_t pad_token_id_{0};
  int32_t cls_token_id_{0};
  int32_t mask_token_id_{0};
  int32_t max_length_{0};
  bool do_basic_tokenize_{false};

  // Collaborators (declaration order preserved).
  std::unique_ptr<TruncateStrategy> truncate_;
  std::shared_ptr<BertTokenizerVocab> vocab_;
  std::unique_ptr<BasicTokenizer> basic_tokenizer_;
  std::unique_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
};
// ONNX Runtime extensions kernel (derives from BaseKernel, created from an
// OrtApi / OrtKernelInfo pair) that delegates to an owned BertTokenizer.
struct KernelBertTokenizer : BaseKernel {
  // Public alias, also visible to derived kernels.
  using OffsetMappingType = std::list<std::pair<size_t, size_t>>;

  KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);

  // Tokenizes the string tensor `input` into three int64 output tensors
  // (presumably ids, token-type ids and attention mask — confirm in the .cpp);
  // `offset_mapping` is an optional fourth output.
  void Compute(const ortc::Tensor<std::string>& input,
               ortc::Tensor<int64_t>& output,
               ortc::Tensor<int64_t>& output1,
               ortc::Tensor<int64_t>& output2,
               std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;

 protected:
  // Tokenizer instance shared with derived kernels.
  std::unique_ptr<BertTokenizer> tokenizer_;
};
// Kernel variant that reuses KernelBertTokenizer's state (tokenizer_) but
// supplies its own Compute; the differing behavior is defined in the .cpp.
// OffsetMappingType is inherited publicly from KernelBertTokenizer.
struct KernelHfBertTokenizer : KernelBertTokenizer {
  KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);

  // Same signature as KernelBertTokenizer::Compute, which is not declared
  // virtual there, so this hides rather than overrides the base version.
  void Compute(const ortc::Tensor<std::string>& input,
               ortc::Tensor<int64_t>& output,
               ortc::Tensor<int64_t>& output1,
               ortc::Tensor<int64_t>& output2,
               std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
};