Merge pull request #201 from google/sr

Added --use_all_vocab=true flag for WORD/CHAR model
This commit is contained in:
Taku Kudo 2018-09-08 01:06:15 +09:00 committed by GitHub
Parents: 03dad83922 b40cca7d0c
Commit: c14d581837
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files: 82 additions and 23 deletions
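The effect of the flag, visible across the seven diffs below, is to disable the vocab_size cutoff for WORD and CHAR models: every surviving symbol becomes a piece, and vocab_size is rewritten to the actual piece count before the model is saved. A minimal sketch of driving this through the public C++ API, mirroring the new test cases (corpus path and model prefix are placeholders):

#include "sentencepiece_processor.h"
#include "sentencepiece_trainer.h"

int main() {
  // Train a CHAR model that keeps every character in the corpus;
  // --vocab_size stops acting as a cutoff when --use_all_vocab=true.
  const auto status = sentencepiece::SentencePieceTrainer::Train(
      "--input=corpus.txt --model_prefix=m --vocab_size=1000 "
      "--model_type=char --use_all_vocab=true");
  if (!status.ok()) return 1;

  // The saved model's vocab_size now equals its actual piece count.
  sentencepiece::SentencePieceProcessor sp;
  return sp.Load("m.model").ok() ? 0 : 1;
}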

View file

@@ -44,13 +44,20 @@ util::Status Trainer::Train() {
   CHECK_OR_RETURN(final_pieces_.empty());
 
   for (const auto &it : Sorted(required_chars_)) {
-    if (final_pieces_.size() == static_cast<size_t>(vocab_size)) {
+    if (!trainer_spec_.use_all_vocab() &&
+        final_pieces_.size() == static_cast<size_t>(vocab_size)) {
       break;
     }
     final_pieces_.emplace_back(string_util::UnicodeCharToUTF8(it.first),
                                log(it.second) - logsum);
   }
 
+  if (trainer_spec_.use_all_vocab()) {
+    trainer_spec_.set_vocab_size(final_pieces_.size() + meta_pieces_.size());
+  }
+
   LOG(INFO) << trainer_spec_.Utf8DebugString();
 
   return Save();
 }
 
 }  // namespace character

View file

@@ -27,7 +27,7 @@ message TrainerSpec {
   //  B) Bilingual: TSV, source sentence <tab> target sentence
   //     When bilingual data is passed, shared vocabulary model is built.
   // Note that the input file must be raw corpus, not a preprocessed corpus.
-  // Trainer only loads the first |input_sentence_size| sentences specified
+  // Trainer only loads the first `input_sentence_size` sentences specified
   // with this parameter.
   repeated string input = 1;

@@ -62,13 +62,13 @@ message TrainerSpec {
   ///////////////////////////////////////////////////////////////////
   // Training parameters.
   //
-  // Uses characters which cover the corpus with the ratio of |chars_coverage|.
+  // Uses characters which cover the corpus with the ratio of `chars_coverage`.
   // This parameter determines the set of basic Alphabet of sentence piece.
-  // 1.0 - |chars_coverage| characters are treated as UNK.
+  // 1.0 - `chars_coverage` characters are treated as UNK.
   optional float character_coverage = 10 [ default = 0.9995 ];
 
-  // Maximum size of sentences the trainer loads from |input| parameter.
-  // Trainer simply loads the |input| files in sequence.
+  // Maximum size of sentences the trainer loads from `input` parameter.
+  // Trainer simply loads the `input` files in sequence.
   // It is better to shuffle the input corpus randomly.
   optional int32 input_sentence_size = 11 [ default = 10000000 ];

@@ -82,11 +82,11 @@ message TrainerSpec {
   optional int32 training_sentence_size = 13 [ default = 10000000 ];
 
   // The size of seed sentencepieces.
-  // |seed_sentencepiece_size| must be larger than |vocab_size|.
+  // `seed_sentencepiece_size` must be larger than `vocab_size`.
   optional int32 seed_sentencepiece_size = 14 [ default = 1000000 ];
 
   // In every EM sub-iterations, keeps top
-  // |shrinking_factor| * |current sentencepieces size| with respect to
+  // `shrinking_factor` * `current sentencepieces size` with respect to
   // the loss of the sentence piece. This value should be smaller than 1.0.
   optional float shrinking_factor = 15 [ default = 0.75 ];

@@ -103,7 +103,7 @@ message TrainerSpec {
   optional int32 max_sentencepiece_length = 20 [ default = 16 ];
 
   // Uses Unicode script to split sentence pieces.
-  // When |split_by_unicode_script| is true, we do not allow sentence piece to
+  // When `split_by_unicode_script` is true, we do not allow sentence piece to
   // include multiple Unicode scripts, e.g. "F1" is not a valid piece.
   // Exception: CJ characters (Hiragana/Katakana/Han) are all handled
   // as one script type, since Japanese word can consist of multiple scripts.

@@ -112,7 +112,7 @@ message TrainerSpec {
   optional bool split_by_unicode_script = 21 [ default = true ];
 
   // Use a white space to split sentence pieces.
-  // When |split_by_whitespace| is false, we may have the piece containing
+  // When `split_by_whitespace` is false, we may have the piece containing
   // a white space in the middle. e.g., "in_the".
   optional bool split_by_whitespace = 22 [ default = true ];

@@ -142,6 +142,10 @@ message TrainerSpec {
   // always assumes hard_vocab_limit = false.
   optional bool hard_vocab_limit = 33 [ default = true ];
 
+  // use all symbols for vocab extraction. This flag is valid
+  // if model type is either CHAR or WORD
+  optional bool use_all_vocab = 34 [ default = false ];
+
   ///////////////////////////////////////////////////////////////////
   // Reserved special meta tokens.
   //  * -1 is not used.
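Since use_all_vocab is an ordinary TrainerSpec field, it can also be set programmatically instead of through the flag string. A sketch, assuming the TrainerSpec overload of SentencePieceTrainer::Train and a placeholder corpus:

#include "sentencepiece_trainer.h"

sentencepiece::TrainerSpec trainer_spec;
trainer_spec.add_input("corpus.txt");  // placeholder corpus
trainer_spec.set_model_prefix("m");
trainer_spec.set_model_type(sentencepiece::TrainerSpec::WORD);
trainer_spec.set_vocab_size(1000);     // no longer a cutoff when...
trainer_spec.set_use_all_vocab(true);  // ...use_all_vocab is set
const auto status = sentencepiece::SentencePieceTrainer::Train(trainer_spec);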

View file

@@ -177,6 +177,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
 
 // static
 util::Status SentencePieceTrainer::Train(util::min_string_view args) {
+  LOG(INFO) << "Running command: " << args.data();
   TrainerSpec trainer_spec;
   NormalizerSpec normalizer_spec;
   RETURN_IF_ERROR(MergeSpecsFromArgs(args, &trainer_spec, &normalizer_spec));

View file

@@ -23,22 +23,50 @@ DECLARE_string(data_dir);
 namespace sentencepiece {
 namespace {
 
+void CheckVocab(absl::string_view filename, int expected_vocab_size) {
+  SentencePieceProcessor sp;
+  CHECK_OK(sp.Load(filename.data()));
+  EXPECT_EQ(expected_vocab_size, sp.model_proto().trainer_spec().vocab_size());
+  EXPECT_EQ(sp.model_proto().pieces_size(),
+            sp.model_proto().trainer_spec().vocab_size());
+}
+
 TEST(SentencePieceTrainerTest, TrainFromArgsTest) {
   std::string input = util::JoinPath(FLAGS_data_dir, "botchan.txt");
 
-  SentencePieceTrainer::Train(string_util::StrCat(
-      "--input=", input, " --model_prefix=m --vocab_size=1000"));
-  SentencePieceTrainer::Train(string_util::StrCat(
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
+      "--input=", input, " --model_prefix=m --vocab_size=1000")));
+  CheckVocab("m.model", 1000);
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
       "--input=", input,
-      " --model_prefix=m --vocab_size=1000 --self_test_sample_size=100"));
-  SentencePieceTrainer::Train(string_util::StrCat(
+      " --model_prefix=m --vocab_size=1000 --self_test_sample_size=100")));
+  CheckVocab("m.model", 1000);
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
       "--input=", input, " --model_prefix=m --vocab_size=1000 ",
-      "--model_type=bpe"));
-  SentencePieceTrainer::Train(string_util::StrCat(
+      "--model_type=bpe")));
+  CheckVocab("m.model", 1000);
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
       "--input=", input, " --model_prefix=m --vocab_size=1000 ",
-      "--model_type=char"));
-  SentencePieceTrainer::Train(string_util::StrCat(
+      "--model_type=char")));
+  CheckVocab("m.model", 72);
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
       "--input=", input, " --model_prefix=m --vocab_size=1000 ",
-      "--model_type=word"));
+      "--model_type=word")));
+  CheckVocab("m.model", 1000);
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
+      "--input=", input, " --model_prefix=m --vocab_size=1000 ",
+      "--model_type=char --use_all_vocab=true")));
+  CheckVocab("m.model", 86);
+
+  EXPECT_OK(SentencePieceTrainer::Train(string_util::StrCat(
+      "--input=", input, " --model_prefix=m --vocab_size=1000 ",
+      "--model_type=word --use_all_vocab=true")));
+  CheckVocab("m.model", 9186);
 }
 
 TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) {
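The expected sizes encode the new behavior: at the default character_coverage of 0.9995 the CHAR model keeps 72 pieces, while --use_all_vocab=true keeps every character in botchan.txt for 86 pieces, and the WORD model grows from the 1000 cap to all 9186 pieces. In each case the new CheckVocab helper also verifies that the rewritten vocab_size matches the model's actual pieces_size().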

View file

@@ -78,6 +78,9 @@ DEFINE_bool(remove_extra_whitespaces,
             "duplicate internal whitespace");
 DEFINE_bool(hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
             "If set to false, --vocab_size is considered as a soft limit.");
+DEFINE_bool(use_all_vocab, kDefaultTrainerSpec.use_all_vocab(),
+            "If set to true, use all tokens as vocab. "
+            "Valid for word/char models.");
 DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(), "Override UNK (<unk>) id.");
 DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
             "Override BOS (<s>) id. Set -1 to disable BOS.");

@@ -127,6 +130,7 @@ int main(int argc, char *argv[]) {
   SetTrainerSpecFromFlag(split_by_unicode_script);
   SetTrainerSpecFromFlag(split_by_whitespace);
   SetTrainerSpecFromFlag(hard_vocab_limit);
+  SetTrainerSpecFromFlag(use_all_vocab);
   SetTrainerSpecFromFlag(unk_id);
   SetTrainerSpecFromFlag(bos_id);
   SetTrainerSpecFromFlag(eos_id);
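With the flag registered here and forwarded into the spec, the behavior is also reachable from the spm_train command line; for example (corpus path is a placeholder):

  spm_train --input=corpus.txt --model_prefix=m --vocab_size=1000 --model_type=word --use_all_vocab=true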

View file

@@ -48,6 +48,12 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) {
   CHECK_GT_OR_RETURN(trainer_spec.input().size(), 0);
   CHECK_GT_OR_RETURN(trainer_spec.vocab_size(), 0);
 
+  if (trainer_spec.model_type() == TrainerSpec::UNIGRAM ||
+      trainer_spec.model_type() == TrainerSpec::BPE) {
+    CHECK_OR_RETURN(!trainer_spec.use_all_vocab())
+        << "--use_all_vocab=true is valid for WORD/CHAR model.";
+  }
+
 #define CHECK_RANGE(variable, minval, maxval) \
   CHECK_OR_RETURN(variable >= minval && variable <= maxval)

@@ -247,7 +253,8 @@ END:
   int64 accumulated_chars_count = 0;
   for (const auto &w : Sorted(chars_count)) {
     const float coverage = 1.0 * accumulated_chars_count / all_chars_count;
-    if (coverage >= trainer_spec_.character_coverage()) {
+    if (!trainer_spec_.use_all_vocab() &&
+        coverage >= trainer_spec_.character_coverage()) {
       LOG(INFO) << "Done: " << 100.0 * coverage << "% characters are covered.";
       break;
     }

@@ -256,7 +263,10 @@ END:
         << "space must not be included in normalized string.";
     required_chars_.insert(w);
   }
-  LOG(INFO) << "alphabet size=" << required_chars_.size();
+  LOG(INFO) << "Alphabet size=" << required_chars_.size();
+  LOG(INFO) << "Final character coverage="
+            << 1.0 * accumulated_chars_count / all_chars_count;
+
   CHECK_OR_RETURN(!port::ContainsKey(required_chars_, kUNKChar));
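The coverage check above is cumulative: characters are taken in descending frequency until accumulated_chars_count / all_chars_count reaches character_coverage, unless use_all_vocab forces the loop to take everything. A standalone toy version of that arithmetic (illustrative types, not the trainer's own):

#include <algorithm>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

// Toy version of the alphabet-selection loop: take characters in
// descending frequency until the requested coverage is met, or take
// all of them when use_all_vocab is set.
std::vector<char32_t> RequiredChars(const std::map<char32_t, int64_t> &counts,
                                    double coverage, bool use_all_vocab) {
  int64_t all_chars_count = 0;
  for (const auto &kv : counts) all_chars_count += kv.second;

  std::vector<std::pair<char32_t, int64_t>> sorted(counts.begin(),
                                                   counts.end());
  std::sort(sorted.begin(), sorted.end(),
            [](const auto &a, const auto &b) { return a.second > b.second; });

  std::vector<char32_t> required;
  int64_t accumulated = 0;
  for (const auto &kv : sorted) {
    if (!use_all_vocab &&
        1.0 * accumulated / all_chars_count >= coverage) break;
    accumulated += kv.second;
    required.push_back(kv.first);
  }
  return required;
}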

View file

@@ -57,12 +57,17 @@ util::Status Trainer::Train() {
     if (it.first.find(kUNKStr) != std::string::npos) {
       continue;
     }
-    if (final_pieces_.size() == static_cast<size_t>(vocab_size)) {
+    if (!trainer_spec_.use_all_vocab() &&
+        final_pieces_.size() == static_cast<size_t>(vocab_size)) {
       break;
     }
     final_pieces_.emplace_back(it.first, log(it.second) - logsum);
   }
 
+  if (trainer_spec_.use_all_vocab()) {
+    trainer_spec_.set_vocab_size(final_pieces_.size() + meta_pieces_.size());
+  }
+
   return Save();
 }
 
 }  // namespace word
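As in the character trainer, the final bookkeeping is vocab_size = final_pieces_.size() + meta_pieces_.size(), so the 9186 checked in the test is the distinct word pieces plus the default meta pieces (<unk>, <s>, </s>); this rewrite is what lets the test assert pieces_size() == vocab_size().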