Merge pull request #201 from google/sr
Added --use_all_vocab=true flag for WORD/CHAR model
Commit c14d581837
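In short: with --use_all_vocab=true, the WORD and CHAR trainers stop truncating the piece list at vocab_size and instead keep every surviving symbol, then overwrite vocab_size with the actual piece count before saving. A minimal usage sketch against the public trainer API; the corpus path is a placeholder, not part of this change:

    #include "sentencepiece_trainer.h"

    int main() {
      // Train a CHAR model that keeps the whole alphabet. vocab_size must
      // still be positive to pass VerifySpec, but the saved model overwrites
      // it with the real piece count when use_all_vocab is set.
      const auto status = sentencepiece::SentencePieceTrainer::Train(
          "--input=corpus.txt --model_prefix=m --vocab_size=1000 "
          "--model_type=char --use_all_vocab=true");
      return status.ok() ? 0 : 1;
    }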
src/char_model_trainer.cc
@@ -44,13 +44,20 @@ util::Status Trainer::Train() {
   CHECK_OR_RETURN(final_pieces_.empty());
   for (const auto &it : Sorted(required_chars_)) {
-    if (final_pieces_.size() == static_cast<size_t>(vocab_size)) {
+    if (!trainer_spec_.use_all_vocab() &&
+        final_pieces_.size() == static_cast<size_t>(vocab_size)) {
       break;
     }
     final_pieces_.emplace_back(string_util::UnicodeCharToUTF8(it.first),
                                log(it.second) - logsum);
   }
 
+  if (trainer_spec_.use_all_vocab()) {
+    trainer_spec_.set_vocab_size(final_pieces_.size() + meta_pieces_.size());
+  }
+
   LOG(INFO) << trainer_spec_.Utf8DebugString();
 
   return Save();
 }
 }  // namespace character
src/sentencepiece_model.proto
@@ -27,7 +27,7 @@ message TrainerSpec {
   //   B) Bilingual: TSV, source sentence <tab> target sentence
   //      When bilingual data is passed, shared vocabulary model is built.
   // Note that the input file must be raw corpus, not a preprocessed corpus.
-  // Trainer only loads the first |input_sentence_size| sentences specified
+  // Trainer only loads the first `input_sentence_size` sentences specified
   // with this parameter.
   repeated string input = 1;
 
@@ -62,13 +62,13 @@ message TrainerSpec {
   ///////////////////////////////////////////////////////////////////
   // Training parameters.
   //
-  // Uses characters which cover the corpus with the ratio of |chars_coverage|.
+  // Uses characters which cover the corpus with the ratio of `chars_coverage`.
   // This parameter determines the set of basic Alphabet of sentence piece.
-  // 1.0 - |chars_coverage| characters are treated as UNK.
+  // 1.0 - `chars_coverage` characters are treated as UNK.
   optional float character_coverage = 10 [ default = 0.9995 ];
 
-  // Maximum size of sentences the trainer loads from |input| parameter.
-  // Trainer simply loads the |input| files in sequence.
+  // Maximum size of sentences the trainer loads from `input` parameter.
+  // Trainer simply loads the `input` files in sequence.
   // It is better to shuffle the input corpus randomly.
   optional int32 input_sentence_size = 11 [ default = 10000000 ];
 
@@ -82,11 +82,11 @@ message TrainerSpec {
   optional int32 training_sentence_size = 13 [ default = 10000000 ];
 
   // The size of seed sentencepieces.
-  // |seed_sentencepiece_size| must be larger than |vocab_size|.
+  // `seed_sentencepiece_size` must be larger than `vocab_size`.
   optional int32 seed_sentencepiece_size = 14 [ default = 1000000 ];
 
   // In every EM sub-iterations, keeps top
-  // |shrinking_factor| * |current sentencepieces size| with respect to
+  // `shrinking_factor` * `current sentencepieces size` with respect to
   // the loss of the sentence piece. This value should be smaller than 1.0.
   optional float shrinking_factor = 15 [ default = 0.75 ];
 
@@ -103,7 +103,7 @@ message TrainerSpec {
   optional int32 max_sentencepiece_length = 20 [ default = 16 ];
 
   // Uses Unicode script to split sentence pieces.
-  // When |split_by_unicode_script| is true, we do not allow sentence piece to
+  // When `split_by_unicode_script` is true, we do not allow sentence piece to
   // include multiple Unicode scripts, e.g. "F1" is not a valid piece.
   // Exception: CJ characters (Hiragana/Katakana/Han) are all handled
   // as one script type, since Japanese word can consist of multiple scripts.
@@ -112,7 +112,7 @@ message TrainerSpec {
   optional bool split_by_unicode_script = 21 [ default = true ];
 
   // Use a white space to split sentence pieces.
-  // When |split_by_whitespace| is false, we may have the piece containing
+  // When `split_by_whitespace` is false, we may have the piece containing
   // a white space in the middle. e.g., "in_the".
   optional bool split_by_whitespace = 22 [ default = true ];
 
@@ -142,6 +142,10 @@ message TrainerSpec {
   // always assumes hard_vocab_limit = false.
   optional bool hard_vocab_limit = 33 [ default = true ];
 
+  // use all symbols for vocab extraction. This flag is valid
+  // if model type is either CHAR or WORD
+  optional bool use_all_vocab = 34 [ default = false ];
+
   ///////////////////////////////////////////////////////////////////
   // Reserved special meta tokens.
   //  * -1 is not used.
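Because use_all_vocab is an ordinary proto field, it can also be set programmatically instead of through flags. A sketch, assuming the TrainerSpec-based Train() overload and a placeholder corpus path:

    // Mirrors what MergeSpecsFromArgs builds from the command line.
    sentencepiece::TrainerSpec spec;
    spec.add_input("corpus.txt");                           // placeholder path
    spec.set_model_prefix("m");
    spec.set_model_type(sentencepiece::TrainerSpec::WORD);  // WORD or CHAR only
    spec.set_vocab_size(1000);     // must be > 0; overwritten after training
    spec.set_use_all_vocab(true);
    CHECK_OK(sentencepiece::SentencePieceTrainer::Train(spec));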
src/sentencepiece_trainer.cc
@@ -177,6 +177,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
 
 // static
 util::Status SentencePieceTrainer::Train(util::min_string_view args) {
+  LOG(INFO) << "Running command: " << args.data();
   TrainerSpec trainer_spec;
   NormalizerSpec normalizer_spec;
   RETURN_IF_ERROR(MergeSpecsFromArgs(args, &trainer_spec, &normalizer_spec));
src/sentencepiece_trainer_test.cc
@@ -23,22 +23,50 @@ DECLARE_string(data_dir);
 namespace sentencepiece {
 namespace {
 
+void CheckVocab(absl::string_view filename, int expected_vocab_size) {
+  SentencePieceProcessor sp;
+  CHECK_OK(sp.Load(filename.data()));
+  EXPECT_EQ(expected_vocab_size, sp.model_proto().trainer_spec().vocab_size());
+  EXPECT_EQ(sp.model_proto().pieces_size(),
+            sp.model_proto().trainer_spec().vocab_size());
+}
+
 TEST(SentencePieceTrainerTest, TrainFromArgsTest) {
   std::string input = util::JoinPath(FLAGS_data_dir, "botchan.txt");
-  SentencePieceTrainer::Train(string_util::StrCat(
-      "--input=", input, " --model_prefix=m --vocab_size=1000"));
-  SentencePieceTrainer::Train(string_util::StrCat(
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
+      "--input=", input, " --model_prefix=m --vocab_size=1000")));
+  CheckVocab("m.model", 1000);
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
       "--input=", input,
-      " --model_prefix=m --vocab_size=1000 --self_test_sample_size=100"));
-  SentencePieceTrainer::Train(string_util::StrCat(
+      " --model_prefix=m --vocab_size=1000 --self_test_sample_size=100")));
+  CheckVocab("m.model", 1000);
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
       "--input=", input, " --model_prefix=m --vocab_size=1000 ",
-      "--model_type=bpe"));
-  SentencePieceTrainer::Train(string_util::StrCat(
+      "--model_type=bpe")));
+  CheckVocab("m.model", 1000);
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
       "--input=", input, " --model_prefix=m --vocab_size=1000 ",
-      "--model_type=char"));
-  SentencePieceTrainer::Train(string_util::StrCat(
+      "--model_type=char")));
+  CheckVocab("m.model", 72);
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
       "--input=", input, " --model_prefix=m --vocab_size=1000 ",
-      "--model_type=word"));
+      "--model_type=word")));
+  CheckVocab("m.model", 1000);
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
+      "--input=", input, " --model_prefix=m --vocab_size=1000 ",
+      "--model_type=char --use_all_vocab=true")));
+  CheckVocab("m.model", 86);
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
+      "--input=", input, " --model_prefix=m --vocab_size=1000 ",
+      "--model_type=word --use_all_vocab=true")));
+  CheckVocab("m.model", 9186);
 }
 
 TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) {
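The expected sizes in the two new assertions follow from the trainer bookkeeping above: with use_all_vocab the saved vocab_size becomes final_pieces_.size() + meta_pieces_.size(), so for botchan.txt the CHAR model grows from 72 pieces at the default 0.9995 character coverage to 86 (presumably the full alphabet plus the meta pieces), and the WORD model records 9186 pieces instead of the requested 1000.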
src/spm_train_main.cc
@@ -78,6 +78,9 @@ DEFINE_bool(remove_extra_whitespaces,
             "duplicate internal whitespace");
 DEFINE_bool(hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
             "If set to false, --vocab_size is considered as a soft limit.");
+DEFINE_bool(use_all_vocab, kDefaultTrainerSpec.use_all_vocab(),
+            "If set to true, use all tokens as vocab. "
+            "Valid for word/char models.");
 DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(), "Override UNK (<unk>) id.");
 DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
             "Override BOS (<s>) id. Set -1 to disable BOS.");
@@ -127,6 +130,7 @@ int main(int argc, char *argv[]) {
   SetTrainerSpecFromFlag(split_by_unicode_script);
   SetTrainerSpecFromFlag(split_by_whitespace);
   SetTrainerSpecFromFlag(hard_vocab_limit);
+  SetTrainerSpecFromFlag(use_all_vocab);
   SetTrainerSpecFromFlag(unk_id);
   SetTrainerSpecFromFlag(bos_id);
   SetTrainerSpecFromFlag(eos_id);
src/trainer_interface.cc
@@ -48,6 +48,12 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) {
   CHECK_GT_OR_RETURN(trainer_spec.input().size(), 0);
   CHECK_GT_OR_RETURN(trainer_spec.vocab_size(), 0);
 
+  if (trainer_spec.model_type() == TrainerSpec::UNIGRAM ||
+      trainer_spec.model_type() == TrainerSpec::BPE) {
+    CHECK_OR_RETURN(!trainer_spec.use_all_vocab())
+        << "--use_all_vocab=true is valid for WORD/CHAR model.";
+  }
+
 #define CHECK_RANGE(variable, minval, maxval) \
   CHECK_OR_RETURN(variable >= minval && variable <= maxval)
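For UNIGRAM and BPE the flag is rejected up front rather than silently ignored. A sketch of the failure mode, again with a placeholder corpus path:

    // CHECK_OR_RETURN above turns this into a non-OK status, so training
    // never starts; the error message names WORD/CHAR as the valid types.
    const auto status = sentencepiece::SentencePieceTrainer::Train(
        "--input=corpus.txt --model_prefix=m --vocab_size=1000 "
        "--model_type=bpe --use_all_vocab=true");
    // status.ok() is expected to be false here.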
@@ -247,7 +253,8 @@ END:
   int64 accumulated_chars_count = 0;
   for (const auto &w : Sorted(chars_count)) {
     const float coverage = 1.0 * accumulated_chars_count / all_chars_count;
-    if (coverage >= trainer_spec_.character_coverage()) {
+    if (!trainer_spec_.use_all_vocab() &&
+        coverage >= trainer_spec_.character_coverage()) {
       LOG(INFO) << "Done: " << 100.0 * coverage << "% characters are covered.";
       break;
     }
@@ -256,7 +263,10 @@ END:
         << "space must not be included in normalized string.";
     required_chars_.insert(w);
   }
-  LOG(INFO) << "alphabet size=" << required_chars_.size();
+
+  LOG(INFO) << "Alphabet size=" << required_chars_.size();
+  LOG(INFO) << "Final character coverage="
+            << 1.0 * accumulated_chars_count / all_chars_count;
 
   CHECK_OR_RETURN(!port::ContainsKey(required_chars_, kUNKChar));
src/word_model_trainer.cc
@@ -57,12 +57,17 @@ util::Status Trainer::Train() {
     if (it.first.find(kUNKStr) != std::string::npos) {
       continue;
     }
-    if (final_pieces_.size() == static_cast<size_t>(vocab_size)) {
+    if (!trainer_spec_.use_all_vocab() &&
+        final_pieces_.size() == static_cast<size_t>(vocab_size)) {
       break;
     }
     final_pieces_.emplace_back(it.first, log(it.second) - logsum);
   }
 
+  if (trainer_spec_.use_all_vocab()) {
+    trainer_spec_.set_vocab_size(final_pieces_.size() + meta_pieces_.size());
+  }
+
   return Save();
 }
 }  // namespace word