Merge pull request #64 from google/sr

Reimplement Trainer with Proto reflection
This commit is contained in:
Taku Kudo 2018-05-01 00:14:44 +09:00 коммит произвёл GitHub
Родитель 36a3b35e17 f228e556b4
Коммит 4e91816105
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
13 изменённых файлов: 519 добавлений и 235 удалений

Просмотреть файл

@ -112,9 +112,13 @@ int ToSwigError(sentencepiece::util::error::Code code) {
%ignore sentencepiece::SentencePieceProcessor::Load(std::istream *);
%ignore sentencepiece::SentencePieceProcessor::LoadOrDie(std::istream *);
%ignore sentencepiece::SentencePieceProcessor::model_proto();
%ignore sentencepiece::SentencePieceTrainer::Train(int, char **);
%ignore sentencepiece::SentencePieceTrainer::Train(const TrainerSpec &);
%ignore sentencepiece::SentencePieceTrainer::Train(const TrainerSpec &, const NormalizerSpec &);
%ignore sentencepiece::SentencePieceTrainer::MergeSpecsFromArgs(const std::string &,
TrainerSpec *, NormalizerSpec *);
%ignore sentencepiece::SentencePieceTrainer::SetProtoField(const std::string &,
const std::string &,
google::protobuf::Message *message);
%extend sentencepiece::SentencePieceProcessor {
std::vector<std::string> Encode(const std::string& input) const {

Просмотреть файл

@ -201,8 +201,8 @@ class SentencePieceTrainer(_object):
SentencePieceTrainer_swigregister = _sentencepiece.SentencePieceTrainer_swigregister
SentencePieceTrainer_swigregister(SentencePieceTrainer)
def SentencePieceTrainer_Train(arg):
return _sentencepiece.SentencePieceTrainer_Train(arg)
def SentencePieceTrainer_Train(args):
return _sentencepiece.SentencePieceTrainer_Train(args)
SentencePieceTrainer_Train = _sentencepiece.SentencePieceTrainer_Train
# This file is compatible with both classic and new-style classes.

Просмотреть файл

@ -4741,6 +4741,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_Train(PyObject *SWIGUNUSEDPARM(s
PyObject *resultobj = 0;
std::string *arg1 = 0 ;
PyObject * obj0 = 0 ;
sentencepiece::util::Status result;
if (!PyArg_ParseTuple(args,(char *)"O:SentencePieceTrainer_Train",&obj0)) SWIG_fail;
{
@ -4754,13 +4755,18 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_Train(PyObject *SWIGUNUSEDPARM(s
}
{
try {
sentencepiece::SentencePieceTrainer::Train((std::string const &)*arg1);
result = sentencepiece::SentencePieceTrainer::Train((std::string const &)*arg1);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
resultobj = SWIG_Py_Void();
{
if (!(&result)->ok()) {
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
}
resultobj = SWIG_From_bool((&result)->ok());
}
{
delete arg1;
}

Просмотреть файл

@ -25,7 +25,7 @@ setup(name = 'sentencepiece',
author_email='taku@google.com',
description = 'SentencePiece python wrapper',
long_description = long_description,
version='0.0.7',
version='0.0.8',
url = 'https://github.com/google/sentencepiece',
license = 'Apache',
platforms = 'Unix',

Просмотреть файл

@ -14,6 +14,7 @@
#include "flags.h"
#include "common.h"
#include "util.h"
#include <algorithm>
#include <cctype>
@ -44,23 +45,6 @@ FlagMap *GetFlagMap() {
return &flag_map;
}
// Parses `value` as a boolean, ignoring case.
// "1", "t", "true", "y", "yes" yield true; "0", "f", "false", "n", "no" yield
// false. Any other input is reported through LOG(FATAL) (presumably fatal;
// the trailing `return false` keeps the function well-formed either way).
bool IsTrue(const std::string &value) {
  const char *kTrue[] = {"1", "t", "true", "y", "yes"};
  const char *kFalse[] = {"0", "f", "false", "n", "no"};
  std::string lower_value = value;
  // ::tolower is undefined for negative `char` values, so every byte is
  // routed through unsigned char before being lowered.
  std::transform(
      lower_value.begin(), lower_value.end(), lower_value.begin(),
      [](unsigned char c) { return static_cast<char>(::tolower(c)); });
  // The two literal sets are disjoint, so they can be scanned independently.
  for (const char *candidate : kTrue) {
    if (lower_value == candidate) return true;
  }
  for (const char *candidate : kFalse) {
    if (lower_value == candidate) return false;
  }
  LOG(FATAL) << "cannot parse boolean value: " << value;
  return false;
}
bool SetFlag(const std::string &name, const std::string &value) {
auto it = GetFlagMap()->find(name);
if (it == GetFlagMap()->end()) {
@ -85,31 +69,26 @@ bool SetFlag(const std::string &name, const std::string &value) {
}
}
#define DEFINE_ARG(FLAG_TYPE, CPP_TYPE) \
case FLAG_TYPE: { \
CPP_TYPE *s = reinterpret_cast<CPP_TYPE *>(flag->storage); \
CHECK(string_util::lexical_cast<CPP_TYPE>(v, s)); \
break; \
}
switch (flag->type) {
case I:
*reinterpret_cast<int32 *>(flag->storage) = atoi(v.c_str());
break;
case B:
*(reinterpret_cast<bool *>(flag->storage)) = IsTrue(v);
break;
case I64:
*reinterpret_cast<int64 *>(flag->storage) = atoll(v.c_str());
break;
case U64:
*reinterpret_cast<uint64 *>(flag->storage) = atoll(v.c_str());
break;
case D:
*reinterpret_cast<double *>(flag->storage) = strtod(v.c_str(), nullptr);
break;
case S:
*reinterpret_cast<std::string *>(flag->storage) = v;
break;
DEFINE_ARG(I, int32);
DEFINE_ARG(B, bool);
DEFINE_ARG(I64, int64);
DEFINE_ARG(U64, uint64);
DEFINE_ARG(D, double);
DEFINE_ARG(S, std::string);
default:
break;
}
return true;
}
} // namespace
bool CommandLineGetFlag(int argc, char **argv, std::string *key,
std::string *value, int *used_args) {

Просмотреть файл

@ -175,6 +175,12 @@ message NormalizerSpec {
// This field must be true to train sentence piece model.
optional bool escape_whitespaces = 5 [ default = true ];
// Custom normalization rule file in TSV format.
// https://github.com/google/sentencepiece/blob/master/doc/normalization.md
// This field is only used in SentencePieceTrainer::Train() method, which
// compiles the rule into the binary rule stored in `precompiled_charsmap`.
optional string normalization_rule_tsv = 6;
// Customized extensions: the range of field numbers
// are open to third-party extensions.
extensions 200 to max;

Просмотреть файл

@ -31,6 +31,7 @@ const char kSpaceSymbol[] = "\xe2\x96\x81";
// since this character can be useful both for user and
// developer. We can easily figure out that <unk> is emitted.
const char kUnknownSymbol[] = " \xE2\x81\x87 ";
} // namespace
SentencePieceProcessor::SentencePieceProcessor() {}

Просмотреть файл

@ -13,7 +13,6 @@
// limitations under the License.
#include "sentencepiece_trainer.h"
#include <mutex>
#include <string>
#include "builder.h"
@ -25,196 +24,179 @@
#include "trainer_factory.h"
#include "util.h"
namespace {
static const sentencepiece::TrainerSpec kDefaultTrainerSpec;
static const sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
} // namespace
DEFINE_string(input, "", "comma separated list of input sentences");
DEFINE_string(input_format, kDefaultTrainerSpec.input_format(),
"Input format. Supported format is `text` or `tsv`.");
DEFINE_string(model_prefix, "", "output model prefix");
DEFINE_string(model_type, "unigram",
"model algorithm: unigram, bpe, word or char");
DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
DEFINE_string(accept_language, "",
"comma-separated list of languages this model can accept");
DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
"character coverage to determine the minimum symbols");
DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
"maximum size of sentences the trainer loads");
DEFINE_int32(mining_sentence_size, kDefaultTrainerSpec.mining_sentence_size(),
"maximum size of sentences to make seed sentence piece");
DEFINE_int32(training_sentence_size,
kDefaultTrainerSpec.training_sentence_size(),
"maximum size of sentences to train sentence pieces");
DEFINE_int32(seed_sentencepiece_size,
kDefaultTrainerSpec.seed_sentencepiece_size(),
"the size of seed sentencepieces");
DEFINE_double(shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
"Keeps top shrinking_factor pieces with respect to the loss");
DEFINE_int32(num_threads, kDefaultTrainerSpec.num_threads(),
"number of threads for training");
DEFINE_int32(num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
"number of EM sub-iterations");
DEFINE_int32(max_sentencepiece_length,
kDefaultTrainerSpec.max_sentencepiece_length(),
"maximum length of sentence piece");
DEFINE_bool(split_by_unicode_script,
kDefaultTrainerSpec.split_by_unicode_script(),
"use Unicode script to split sentence pieces");
DEFINE_bool(split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
"use a white space to split sentence pieces");
DEFINE_string(control_symbols, "", "comma separated list of control symbols");
DEFINE_string(user_defined_symbols, "",
"comma separated list of user defined symbols");
DEFINE_string(normalization_rule_name, "nfkc",
"Normalization rule name. "
"Choose from nfkc or identity");
DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
DEFINE_bool(add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(),
"Add dummy whitespace at the beginning of text");
DEFINE_bool(remove_extra_whitespaces,
kDefaultNormalizerSpec.remove_extra_whitespaces(),
"Removes leading, trailing, and "
"duplicate internal whitespace");
DEFINE_bool(hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
"If set to false, --vocab_size is considered as a soft limit.");
DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(),
"Override UNK (<unk>) id.");
DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
"Override BOS (<s>) id. Set -1 to disable BOS.");
DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
"Override EOS (</s>) id. Set -1 to disable EOS.");
DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
"Override PAD (<pad>) id. Set -1 to disable PAD.");
using sentencepiece::NormalizerSpec;
using sentencepiece::TrainerSpec;
using sentencepiece::normalizer::Builder;
namespace sentencepiece {
namespace {
static constexpr char kDefaultNormalizerName[] = "nfkc";
// Builds the NormalizerSpec selected by the command-line flags.
// A non-empty --normalization_rule_tsv wins: the TSV file is compiled into a
// spec named "user_defined". Otherwise the spec returned by
// Builder::GetNormalizerSpec(--normalization_rule_name) is used.
NormalizerSpec MakeNormalizerSpecFromFlags() {
  using sentencepiece::normalizer::Builder;

  if (FLAGS_normalization_rule_tsv.empty()) {
    return Builder::GetNormalizerSpec(FLAGS_normalization_rule_name);
  }

  const auto rule_map = Builder::BuildMapFromFile(FLAGS_normalization_rule_tsv);
  sentencepiece::NormalizerSpec user_spec;
  user_spec.set_name("user_defined");
  user_spec.set_precompiled_charsmap(Builder::CompileCharsMap(rule_map));
  return user_spec;
}
// Translates a model type name ("unigram", "bpe", "word" or "char") into the
// matching TrainerSpec::ModelType. The lookup goes through port::FindOrDie,
// so handling of unknown names is whatever that helper does.
TrainerSpec::ModelType GetModelTypeFromString(const std::string &type) {
  std::map<std::string, TrainerSpec::ModelType> name_to_type;
  name_to_type["unigram"] = TrainerSpec::UNIGRAM;
  name_to_type["bpe"] = TrainerSpec::BPE;
  name_to_type["word"] = TrainerSpec::WORD;
  name_to_type["char"] = TrainerSpec::CHAR;
  return port::FindOrDie(name_to_type, type);
}
// Populates the value from flags to spec.
#define SetTrainerSpecFromFlag(name) trainer_spec->set_##name(FLAGS_##name);
#define SetNormalizerSpecFromFlag(name) \
normalizer_spec->set_##name(FLAGS_##name);
#define SetRepeatedTrainerSpecFromFlag(name) \
if (!FLAGS_##name.empty()) { \
for (const auto v : \
sentencepiece::string_util::Split(FLAGS_##name, ",")) { \
trainer_spec->add_##name(v); \
} \
}
// Copies the training-related command-line flags into `trainer_spec` and
// `normalizer_spec`. Both pointers are checked with CHECK_NOTNULL.
void MakeTrainerSpecFromFlags(TrainerSpec *trainer_spec,
NormalizerSpec *normalizer_spec) {
CHECK_NOTNULL(trainer_spec);
CHECK_NOTNULL(normalizer_spec);
// Scalar fields: each line expands to trainer_spec->set_<name>(FLAGS_<name>).
SetTrainerSpecFromFlag(input_format);
SetTrainerSpecFromFlag(model_prefix);
SetTrainerSpecFromFlag(vocab_size);
SetTrainerSpecFromFlag(character_coverage);
SetTrainerSpecFromFlag(input_sentence_size);
SetTrainerSpecFromFlag(mining_sentence_size);
SetTrainerSpecFromFlag(training_sentence_size);
SetTrainerSpecFromFlag(seed_sentencepiece_size);
SetTrainerSpecFromFlag(shrinking_factor);
SetTrainerSpecFromFlag(num_threads);
SetTrainerSpecFromFlag(num_sub_iterations);
SetTrainerSpecFromFlag(max_sentencepiece_length);
SetTrainerSpecFromFlag(split_by_unicode_script);
SetTrainerSpecFromFlag(split_by_whitespace);
SetTrainerSpecFromFlag(hard_vocab_limit);
SetTrainerSpecFromFlag(unk_id);
SetTrainerSpecFromFlag(bos_id);
SetTrainerSpecFromFlag(eos_id);
SetTrainerSpecFromFlag(pad_id);
// Repeated fields: a non-empty flag is split on "," and each piece is added.
SetRepeatedTrainerSpecFromFlag(accept_language);
SetRepeatedTrainerSpecFromFlag(control_symbols);
SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
// The normalizer spec is replaced wholesale first; the two boolean flags are
// then written on top of it.
*normalizer_spec = MakeNormalizerSpecFromFlags();
SetNormalizerSpecFromFlag(add_dummy_prefix);
SetNormalizerSpecFromFlag(remove_extra_whitespaces);
// --input is split on "," and every entry is appended to the `input` field.
for (const auto &filename :
sentencepiece::string_util::Split(FLAGS_input, ",")) {
trainer_spec->add_input(filename);
}
// --model_type is converted from its string name to the enum value.
trainer_spec->set_model_type(GetModelTypeFromString(FLAGS_model_type));
}
} // namespace
// static
void SentencePieceTrainer::Train(int argc, char **argv) {
TrainerSpec trainer_spec;
util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
NormalizerSpec normalizer_spec;
{
static std::mutex flags_mutex;
std::lock_guard<std::mutex> lock(flags_mutex);
sentencepiece::flags::ParseCommandLineFlags(argc, argv);
CHECK_OR_HELP(input);
CHECK_OR_HELP(model_prefix);
MakeTrainerSpecFromFlags(&trainer_spec, &normalizer_spec);
normalizer_spec.set_name(kDefaultNormalizerName);
Train(trainer_spec, normalizer_spec);
return util::OkStatus();
}
// static
util::Status SentencePieceTrainer::Train(
const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec) {
auto copied_normalizer_spec = normalizer_spec;
if (!copied_normalizer_spec.normalization_rule_tsv().empty()) {
if (!copied_normalizer_spec.precompiled_charsmap().empty()) {
return util::InternalError("precompiled_charsmap is already defined.");
}
const auto chars_map = normalizer::Builder::BuildMapFromFile(
copied_normalizer_spec.normalization_rule_tsv());
copied_normalizer_spec.set_precompiled_charsmap(
normalizer::Builder::CompileCharsMap(chars_map));
copied_normalizer_spec.set_name("user_defined");
} else {
if (copied_normalizer_spec.name().empty()) {
copied_normalizer_spec.set_name(kDefaultNormalizerName);
}
if (copied_normalizer_spec.precompiled_charsmap().empty()) {
*(copied_normalizer_spec.mutable_precompiled_charsmap()) =
normalizer::Builder::GetNormalizerSpec(copied_normalizer_spec.name())
.precompiled_charsmap();
}
}
SentencePieceTrainer::Train(trainer_spec, normalizer_spec);
}
// static
void SentencePieceTrainer::Train(const std::string &arg) {
const std::vector<std::string> args =
sentencepiece::string_util::Split(arg, " ");
std::vector<char *> cargs(args.size() + 1);
cargs[0] = const_cast<char *>("");
for (size_t i = 0; i < args.size(); ++i)
cargs[i + 1] = const_cast<char *>(args[i].data());
SentencePieceTrainer::Train(static_cast<int>(cargs.size()), &cargs[0]);
}
// static
void SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
SentencePieceTrainer::Train(
trainer_spec,
normalizer::Builder::GetNormalizerSpec(kDefaultNormalizerName));
}
// static
void SentencePieceTrainer::Train(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec) {
auto trainer = TrainerFactory::Create(trainer_spec, normalizer_spec);
auto trainer = TrainerFactory::Create(trainer_spec, copied_normalizer_spec);
trainer->Train();
return util::OkStatus();
}
// static
// Sets the field called `field_name` of `message` from the textual `value`,
// using protobuf reflection.
//
// - Repeated fields: `value` is split on "," and every piece is appended.
// - Numeric fields: each piece is parsed with string_util::lexical_cast.
// - Bool fields: an empty piece is treated as "true" (bare switch).
// - Enum fields: the piece is upper-cased and looked up by enum value name.
//
// Returns:
//   InternalError        if `message` is null or offers no reflection.
//   NotFoundError        if `message` has no field named `field_name`.
//   InvalidArgumentError if a piece cannot be parsed / is not an enum name.
//   UnimplementedError   if the field's C++ type is not handled below.
//   OkStatus             otherwise.
// Pieces already applied before a failing piece are not rolled back.
util::Status SentencePieceTrainer::SetProtoField(
    const std::string &field_name, const std::string &value,
    google::protobuf::Message *message) {
  if (message == nullptr) {
    return util::InternalError("`message` must not be null.");
  }

  const auto *descriptor = message->GetDescriptor();
  const auto *reflection = message->GetReflection();
  if (descriptor == nullptr || reflection == nullptr) {
    return util::InternalError("Reflection is not supported.");
  }

  const auto *field = descriptor->FindFieldByName(field_name);
  if (field == nullptr) {
    return util::NotFoundError(std::string("Unknown field name \"") +
                               field_name + "\" in " +
                               descriptor->DebugString());
  }

  // A singular field consumes `value` verbatim; a repeated one is split.
  std::vector<std::string> values = {value};
  if (field->is_repeated()) values = string_util::Split(value, ",");

// Adds to a repeated field or overwrites a singular one. Wrapped in
// do/while(0) so the expansion is a single statement.
#define SET_FIELD(METHOD_TYPE, v)                      \
  do {                                                 \
    if (field->is_repeated())                          \
      reflection->Add##METHOD_TYPE(message, field, v); \
    else                                               \
      reflection->Set##METHOD_TYPE(message, field, v); \
  } while (0)

// Emits one `case` that parses `item` as CPP_TYPE (EMPTY is substituted when
// `item` is empty) and stores it. FUNC_PREFIX is unused and kept only so the
// invocations stay unchanged.
#define DEFINE_SET_FIELD(PROTO_TYPE, CPP_TYPE, FUNC_PREFIX, METHOD_TYPE,   \
                         EMPTY)                                            \
  case google::protobuf::FieldDescriptor::CPPTYPE_##PROTO_TYPE: {          \
    CPP_TYPE v;                                                            \
    if (!string_util::lexical_cast(item.empty() ? EMPTY : item, &v))       \
      return util::InvalidArgumentError(std::string("Cannot parse \"") +   \
                                        item + "\" as \"" +                \
                                        field->type_name() + "\".");       \
    SET_FIELD(METHOD_TYPE, v);                                             \
    break;                                                                 \
  }

  for (const auto &item : values) {
    switch (field->cpp_type()) {
      DEFINE_SET_FIELD(INT32, int32, i, Int32, "");
      DEFINE_SET_FIELD(INT64, int64, i, Int64, "");
      DEFINE_SET_FIELD(UINT32, uint32, i, UInt32, "");
      DEFINE_SET_FIELD(UINT64, uint64, i, UInt64, "");
      DEFINE_SET_FIELD(DOUBLE, double, d, Double, "");
      DEFINE_SET_FIELD(FLOAT, float, f, Float, "");
      DEFINE_SET_FIELD(BOOL, bool, b, Bool, "true");
      case google::protobuf::FieldDescriptor::CPPTYPE_STRING:
        SET_FIELD(String, item);
        break;
      case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
        // Enum value names are matched in upper case, so "bpe" finds BPE.
        const auto *enum_value =
            field->enum_type()->FindValueByName(string_util::ToUpper(item));
        if (enum_value == nullptr)
          return util::InvalidArgumentError(
              std::string("Unknown enumeration value of \"") + item +
              "\" for field \"" + field->name() + "\".");
        SET_FIELD(Enum, enum_value);
        break;
      }
      default:
        return util::UnimplementedError(std::string("Proto type \"") +
                                        field->cpp_type_name() +
                                        "\" is not supported.");
    }
  }

// Keep the helper macros local to this function.
#undef DEFINE_SET_FIELD
#undef SET_FIELD

  return util::OkStatus();
}
// static
// Applies the space-separated `key=value` tokens in `args` on top of
// `trainer_spec` and `normalizer_spec`.
//
// A leading "--" on a token is stripped when present. A token without "=" is
// passed on with an empty value. `normalization_rule_name` is special-cased
// and written to NormalizerSpec::name. Every other key is tried against
// `trainer_spec` first and, only when that reports NotFound, against
// `normalizer_spec`.
//
// Returns InternalError for null output pointers, the first non-NotFound
// error from SetProtoField, the trainer-side NotFound status when neither
// message has the key, and OkStatus otherwise.
util::Status SentencePieceTrainer::MergeSpecsFromArgs(
    const std::string &args, TrainerSpec *trainer_spec,
    NormalizerSpec *normalizer_spec) {
  if (trainer_spec == nullptr || normalizer_spec == nullptr) {
    return util::InternalError(
        "`trainer_spec` and `normalizer_spec` must not be null.");
  }

  if (args.empty()) return util::OkStatus();

  for (auto token : string_util::SplitPiece(args, " ")) {
    token.Consume("--");

    const auto eq = token.find("=");
    const bool has_value = (eq != StringPiece::npos);
    const std::string key =
        has_value ? token.substr(0, eq).ToString() : token.ToString();
    const std::string value =
        has_value ? token.substr(eq + 1).ToString() : std::string();

    // Exception.
    if (key == "normalization_rule_name") {
      normalizer_spec->set_name(value);
      continue;
    }

    const auto trainer_status = SetProtoField(key, value, trainer_spec);
    if (trainer_status.ok()) continue;
    if (!util::IsNotFound(trainer_status)) return trainer_status;

    const auto normalizer_status = SetProtoField(key, value, normalizer_spec);
    if (normalizer_status.ok()) continue;
    if (!util::IsNotFound(normalizer_status)) return normalizer_status;

    // Both lookups reported NotFound: the key exists in neither message.
    return trainer_status;
  }

  return util::OkStatus();
}
// static
// Trains a SentencePiece model from a command-line style string, e.g.
// "--input=data --model_prefix=m --vocab_size=8192".
// The NormalizerSpec name starts as kDefaultNormalizerName and may be
// overridden by `args`. Returns the error from MergeSpecsFromArgs when `args`
// cannot be applied, otherwise the status of the two-argument Train().
util::Status SentencePieceTrainer::Train(const std::string &args) {
  TrainerSpec trainer_spec;
  NormalizerSpec normalizer_spec;
  normalizer_spec.set_name(kDefaultNormalizerName);
  // `args` is caller-supplied, so a malformed string is reported through the
  // returned status rather than asserted with CHECK_OK.
  const auto status = MergeSpecsFromArgs(args, &trainer_spec, &normalizer_spec);
  if (!status.ok()) return status;
  return Train(trainer_spec, normalizer_spec);
}
} // namespace sentencepiece

Просмотреть файл

@ -16,6 +16,13 @@
#define SENTENCEPIECE_TRAINER_H_
#include <string>
#include "sentencepiece_processor.h"
namespace google {
namespace protobuf {
class Message;
} // namespace protobuf
} // namespace google
namespace sentencepiece {
@ -24,21 +31,32 @@ class NormalizerSpec;
class SentencePieceTrainer {
public:
// Entry point for main function.
static void Train(int argc, char **argv);
// Train from params with a single line.
// "--input=foo --model_prefix=m --vocab_size=1024"
static void Train(const std::string &arg);
// Trains SentencePiece model with `trainer_spec`.
// Default `normalizer_spec` is used.
static void Train(const TrainerSpec &trainer_spec);
static util::Status Train(const TrainerSpec &trainer_spec);
// Trains SentencePiece model with `trainer_spec` and
// `normalizer_spec`.
static void Train(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec);
static util::Status Train(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec);
// Trains SentencePiece model with command-line string in `args`,
// e.g.,
// '--input=data --model_prefix=m --vocab_size=8192 --model_type=unigram'
static util::Status Train(const std::string &args);
// Overrides `trainer_spec` and `normalizer_spec` with the
// command-line string in `args`.
static util::Status MergeSpecsFromArgs(const std::string &args,
TrainerSpec *trainer_spec,
NormalizerSpec *normalizer_spec);
// Helper function to set `field_name=value` in `message`.
// When `field_name` is repeated, multiple values can be passed
// with comma-separated values. `field_name` must not be a nested message.
static util::Status SetProtoField(const std::string &field_name,
const std::string &value,
google::protobuf::Message *message);
SentencePieceTrainer() = delete;
~SentencePieceTrainer() = delete;

Просмотреть файл

@ -15,6 +15,7 @@
#include "sentencepiece_trainer.h"
#include "sentencepiece_model.pb.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
namespace {
@ -39,12 +40,98 @@ TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) {
"--normalization_rule_tsv=../data/nfkc.tsv");
}
// Expects Train() to return a non-OK status when the NormalizerSpec carries
// both `normalization_rule_tsv` and `precompiled_charsmap`.
TEST(SentencePieceTrainerTest, TrainErrorTest) {
TrainerSpec trainer_spec;
NormalizerSpec normalizer_spec;
normalizer_spec.set_normalization_rule_tsv("foo.tsv");
normalizer_spec.set_precompiled_charsmap("foo");
EXPECT_NOT_OK(SentencePieceTrainer::Train(trainer_spec, normalizer_spec));
}
// Trains on ../data/botchan.txt with model prefix "m" and a vocabulary of
// 1000 pieces, and expects both checked Train() overloads to succeed.
TEST(SentencePieceTrainerTest, TrainTest) {
TrainerSpec trainer_spec;
trainer_spec.add_input("../data/botchan.txt");
trainer_spec.set_model_prefix("m");
trainer_spec.set_vocab_size(1000);
// The result of this call is not checked.
SentencePieceTrainer::Train(trainer_spec);
// Default-constructed NormalizerSpec, passed explicitly.
NormalizerSpec normalizer_spec;
EXPECT_OK(SentencePieceTrainer::Train(trainer_spec, normalizer_spec));
EXPECT_OK(SentencePieceTrainer::Train(trainer_spec));
}
// Exercises SetProtoField() on a TrainerSpec across the supported field
// kinds. The assertions share `spec`, so their order matters.
TEST(SentencePieceTrainerTest, SetProtoFieldTest) {
TrainerSpec spec;
// Unknown field name.
EXPECT_NOT_OK(SentencePieceTrainer::SetProtoField("dummy", "1000", &spec));
// Integer field: valid and non-numeric input.
EXPECT_OK(SentencePieceTrainer::SetProtoField("vocab_size", "1000", &spec));
EXPECT_EQ(1000, spec.vocab_size());
EXPECT_NOT_OK(
SentencePieceTrainer::SetProtoField("vocab_size", "UNK", &spec));
// String field: any text is accepted unchanged.
EXPECT_OK(SentencePieceTrainer::SetProtoField("input_format", "TSV", &spec));
EXPECT_EQ("TSV", spec.input_format());
EXPECT_OK(SentencePieceTrainer::SetProtoField("input_format", "123", &spec));
EXPECT_EQ("123", spec.input_format());
// Bool field: explicit "false", then an empty value which means true.
EXPECT_OK(SentencePieceTrainer::SetProtoField("split_by_whitespace", "false",
&spec));
EXPECT_FALSE(spec.split_by_whitespace());
EXPECT_OK(
SentencePieceTrainer::SetProtoField("split_by_whitespace", "", &spec));
EXPECT_TRUE(spec.split_by_whitespace());
// Floating-point field: valid and non-numeric input.
EXPECT_OK(
SentencePieceTrainer::SetProtoField("character_coverage", "0.5", &spec));
EXPECT_NEAR(spec.character_coverage(), 0.5, 0.001);
EXPECT_NOT_OK(
SentencePieceTrainer::SetProtoField("character_coverage", "UNK", &spec));
// Repeated string field: comma-separated values are added one by one.
EXPECT_OK(SentencePieceTrainer::SetProtoField("input", "foo,bar,buz", &spec));
EXPECT_EQ(3, spec.input_size());
EXPECT_EQ("foo", spec.input(0));
EXPECT_EQ("bar", spec.input(1));
EXPECT_EQ("buz", spec.input(2));
// Enum field: known and unknown value names.
EXPECT_OK(SentencePieceTrainer::SetProtoField("model_type", "BPE", &spec));
EXPECT_NOT_OK(
SentencePieceTrainer::SetProtoField("model_type", "UNK", &spec));
// Nested message is not supported.
ModelProto proto;
EXPECT_NOT_OK(
SentencePieceTrainer::SetProtoField("trainer_spec", "UNK", &proto));
}
// Checks MergeSpecsFromArgs() for null outputs, empty input, unknown keys,
// unparsable values, a bare boolean switch and the normalization_rule_name
// special case.
TEST(SentencePieceTrainerTest, MergeSpecsFromArgs) {
TrainerSpec trainer_spec;
NormalizerSpec normalizer_spec;
// Null output pointers are rejected even when `args` is empty.
EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs("", nullptr, nullptr));
// Empty `args` is accepted.
EXPECT_OK(SentencePieceTrainer::MergeSpecsFromArgs("", &trainer_spec,
&normalizer_spec));
// Key that neither spec declares.
EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
"--unknown=BPE", &trainer_spec, &normalizer_spec));
// Known keys with values that cannot be parsed.
EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
"--vocab_size=UNK", &trainer_spec, &normalizer_spec));
EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
"--model_type=UNK", &trainer_spec, &normalizer_spec));
// Enum value given in lower case.
EXPECT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
"--model_type=bpe", &trainer_spec, &normalizer_spec));
// Boolean switch without "=value".
EXPECT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
"--split_by_whitespace", &trainer_spec, &normalizer_spec));
// normalization_rule_name is stored in NormalizerSpec::name.
EXPECT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
"--normalization_rule_name=foo", &trainer_spec, &normalizer_spec));
EXPECT_EQ("foo", normalizer_spec.name());
EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
"--vocab_size=UNK", &trainer_spec, &normalizer_spec));
}
} // namespace
} // namespace sentencepiece

Просмотреть файл

@ -12,10 +12,140 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "builder.h"
#include "flags.h"
#include "sentencepiece_trainer.h"
#include "util.h"
using sentencepiece::NormalizerSpec;
using sentencepiece::TrainerSpec;
using sentencepiece::normalizer::Builder;
namespace {
// Default-constructed specs whose accessors supply the default values of the
// flags defined in this file. They are never meant to be modified, hence
// `const` (matching their k-prefixed constant naming).
static const sentencepiece::TrainerSpec kDefaultTrainerSpec;
static const sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
}  // namespace
DEFINE_string(input, "", "comma separated list of input sentences");
DEFINE_string(input_format, kDefaultTrainerSpec.input_format(),
"Input format. Supported format is `text` or `tsv`.");
DEFINE_string(model_prefix, "", "output model prefix");
DEFINE_string(model_type, "unigram",
"model algorithm: unigram, bpe, word or char");
DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
DEFINE_string(accept_language, "",
"comma-separated list of languages this model can accept");
DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
"character coverage to determine the minimum symbols");
DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
"maximum size of sentences the trainer loads");
DEFINE_int32(mining_sentence_size, kDefaultTrainerSpec.mining_sentence_size(),
"maximum size of sentences to make seed sentence piece");
DEFINE_int32(training_sentence_size,
kDefaultTrainerSpec.training_sentence_size(),
"maximum size of sentences to train sentence pieces");
DEFINE_int32(seed_sentencepiece_size,
kDefaultTrainerSpec.seed_sentencepiece_size(),
"the size of seed sentencepieces");
DEFINE_double(shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
"Keeps top shrinking_factor pieces with respect to the loss");
DEFINE_int32(num_threads, kDefaultTrainerSpec.num_threads(),
"number of threads for training");
DEFINE_int32(num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
"number of EM sub-iterations");
DEFINE_int32(max_sentencepiece_length,
kDefaultTrainerSpec.max_sentencepiece_length(),
"maximum length of sentence piece");
DEFINE_bool(split_by_unicode_script,
kDefaultTrainerSpec.split_by_unicode_script(),
"use Unicode script to split sentence pieces");
DEFINE_bool(split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
"use a white space to split sentence pieces");
DEFINE_string(control_symbols, "", "comma separated list of control symbols");
DEFINE_string(user_defined_symbols, "",
"comma separated list of user defined symbols");
DEFINE_string(normalization_rule_name, "nfkc",
"Normalization rule name. "
"Choose from nfkc or identity");
DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
DEFINE_bool(add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(),
"Add dummy whitespace at the beginning of text");
DEFINE_bool(remove_extra_whitespaces,
kDefaultNormalizerSpec.remove_extra_whitespaces(),
"Removes leading, trailing, and "
"duplicate internal whitespace");
DEFINE_bool(hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
"If set to false, --vocab_size is considered as a soft limit.");
DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(), "Override UNK (<unk>) id.");
DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
"Override BOS (<s>) id. Set -1 to disable BOS.");
DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
"Override EOS (</s>) id. Set -1 to disable EOS.");
DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
"Override PAD (<pad>) id. Set -1 to disable PAD.");
// Entry point of the training binary: parses the command-line flags, copies
// them into a TrainerSpec and a NormalizerSpec, and runs the trainer.
// Returns 0 after CHECK_OK on the training status.
int main(int argc, char *argv[]) {
// NOTE(review): this call runs before flag parsing and duplicates the
// Train(trainer_spec, normalizer_spec) call at the end of main — it looks
// like a leftover of the previous implementation; confirm and remove.
sentencepiece::SentencePieceTrainer::Train(argc, argv);
sentencepiece::flags::ParseCommandLineFlags(argc, argv);
sentencepiece::TrainerSpec trainer_spec;
sentencepiece::NormalizerSpec normalizer_spec;
// --input and --model_prefix are checked before anything else is done.
CHECK_OR_HELP(input);
CHECK_OR_HELP(model_prefix);
// Populates the value from flags to spec.
// SetTrainerSpecFromFlag(x)    -> trainer_spec.set_x(FLAGS_x)
// SetNormalizerSpecFromFlag(x) -> normalizer_spec.set_x(FLAGS_x)
// SetRepeatedTrainerSpecFromFlag(x) splits a non-empty FLAGS_x on "," and
// adds every piece to the repeated field x.
#define SetTrainerSpecFromFlag(name) trainer_spec.set_##name(FLAGS_##name);
#define SetNormalizerSpecFromFlag(name) \
normalizer_spec.set_##name(FLAGS_##name);
#define SetRepeatedTrainerSpecFromFlag(name) \
if (!FLAGS_##name.empty()) { \
for (const auto v : \
sentencepiece::string_util::Split(FLAGS_##name, ",")) { \
trainer_spec.add_##name(v); \
} \
}
// Scalar trainer fields.
SetTrainerSpecFromFlag(input_format);
SetTrainerSpecFromFlag(model_prefix);
SetTrainerSpecFromFlag(vocab_size);
SetTrainerSpecFromFlag(character_coverage);
SetTrainerSpecFromFlag(input_sentence_size);
SetTrainerSpecFromFlag(mining_sentence_size);
SetTrainerSpecFromFlag(training_sentence_size);
SetTrainerSpecFromFlag(seed_sentencepiece_size);
SetTrainerSpecFromFlag(shrinking_factor);
SetTrainerSpecFromFlag(num_threads);
SetTrainerSpecFromFlag(num_sub_iterations);
SetTrainerSpecFromFlag(max_sentencepiece_length);
SetTrainerSpecFromFlag(split_by_unicode_script);
SetTrainerSpecFromFlag(split_by_whitespace);
SetTrainerSpecFromFlag(hard_vocab_limit);
SetTrainerSpecFromFlag(unk_id);
SetTrainerSpecFromFlag(bos_id);
SetTrainerSpecFromFlag(eos_id);
SetTrainerSpecFromFlag(pad_id);
// Repeated trainer fields given as comma-separated flags.
SetRepeatedTrainerSpecFromFlag(input);
SetRepeatedTrainerSpecFromFlag(accept_language);
SetRepeatedTrainerSpecFromFlag(control_symbols);
SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
// The flag is called normalization_rule_name but the proto field is `name`,
// so it cannot go through SetNormalizerSpecFromFlag.
normalizer_spec.set_name(FLAGS_normalization_rule_name);
SetNormalizerSpecFromFlag(normalization_rule_tsv);
SetNormalizerSpecFromFlag(add_dummy_prefix);
SetNormalizerSpecFromFlag(remove_extra_whitespaces);
// --model_type is given by name and mapped to the enum via FindOrDie.
const std::map<std::string, TrainerSpec::ModelType> kModelTypeMap = {
{"unigram", TrainerSpec::UNIGRAM},
{"bpe", TrainerSpec::BPE},
{"word", TrainerSpec::WORD},
{"char", TrainerSpec::CHAR}};
trainer_spec.set_model_type(
sentencepiece::port::FindOrDie(kModelTypeMap, FLAGS_model_type));
CHECK_OK(sentencepiece::SentencePieceTrainer::Train(trainer_spec,
normalizer_spec));
return 0;
}

Просмотреть файл

@ -37,6 +37,52 @@ std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) {
// String utilities
namespace string_util {
// Returns a copy of `arg` in which every byte has been passed through
// ::tolower.
inline std::string ToLower(StringPiece arg) {
  std::string lower_value = arg.ToString();
  // ::tolower is undefined for negative `char` values (bytes >= 0x80 when
  // char is signed), so each byte is converted to unsigned char first.
  std::transform(
      lower_value.begin(), lower_value.end(), lower_value.begin(),
      [](unsigned char c) { return static_cast<char>(::tolower(c)); });
  return lower_value;
}
// Returns a copy of `arg` in which every byte has been passed through
// ::toupper.
inline std::string ToUpper(StringPiece arg) {
  std::string upper_value = arg.ToString();
  // ::toupper is undefined for negative `char` values (bytes >= 0x80 when
  // char is signed), so each byte is converted to unsigned char first.
  std::transform(
      upper_value.begin(), upper_value.end(), upper_value.begin(),
      [](unsigned char c) { return static_cast<char>(::toupper(c)); });
  return upper_value;
}
// Parses `arg` into `*result` with operator>> of `Target`.
// Returns true when the extraction succeeds, false otherwise (including for
// empty input with the numeric types). Trailing characters after a valid
// prefix are not rejected.
template <typename Target>
inline bool lexical_cast(StringPiece arg, Target *result) {
  std::stringstream ss;
  // Feed exactly arg.size() bytes. A StringPiece may view the middle of a
  // larger buffer, so treating arg.data() as a NUL-terminated C string could
  // read past the end of the view.
  ss.write(arg.data(), static_cast<std::streamsize>(arg.size()));
  return ss && (ss >> *result);
}
// Boolean specialization: case-insensitive match against
// "1", "t", "true", "y", "yes" (true) and "0", "f", "false", "n", "no"
// (false). Returns false and leaves `*result` untouched for any other input.
template <>
inline bool lexical_cast(StringPiece arg, bool *result) {
  const char *kTrue[] = {"1", "t", "true", "y", "yes"};
  const char *kFalse[] = {"0", "f", "false", "n", "no"};
  std::string lower_value = arg.ToString();
  // ::tolower is undefined for negative `char` values, so each byte is
  // converted to unsigned char first.
  std::transform(
      lower_value.begin(), lower_value.end(), lower_value.begin(),
      [](unsigned char c) { return static_cast<char>(::tolower(c)); });
  // The two literal sets are disjoint, so they can be scanned independently.
  for (const char *candidate : kTrue) {
    if (lower_value == candidate) {
      *result = true;
      return true;
    }
  }
  for (const char *candidate : kFalse) {
    if (lower_value == candidate) {
      *result = false;
      return true;
    }
  }
  return false;
}
// String specialization: copies `arg` verbatim into `*result`.
// Always succeeds.
template <>
inline bool lexical_cast(StringPiece arg, std::string *result) {
  result->assign(arg.ToString());
  return true;
}
std::vector<std::string> Split(const std::string &str,
const std::string &delim);

Просмотреть файл

@ -18,6 +18,31 @@
namespace sentencepiece {
// Covers string_util::lexical_cast for bool, int32, double and std::string,
// each with valid input and (except std::string) the unparsable text "UNK".
TEST(UtilTest, LexicalCastTest) {
// bool
bool b = false;
EXPECT_TRUE(string_util::lexical_cast<bool>("true", &b));
EXPECT_TRUE(b);
EXPECT_TRUE(string_util::lexical_cast<bool>("false", &b));
EXPECT_FALSE(b);
EXPECT_FALSE(string_util::lexical_cast<bool>("UNK", &b));
// int32, positive and negative
int32 n = 0;
EXPECT_TRUE(string_util::lexical_cast<int32>("123", &n));
EXPECT_EQ(123, n);
EXPECT_TRUE(string_util::lexical_cast<int32>("-123", &n));
EXPECT_EQ(-123, n);
EXPECT_FALSE(string_util::lexical_cast<int32>("UNK", &n));
// double
double d = 0.0;
EXPECT_TRUE(string_util::lexical_cast<double>("123.4", &d));
EXPECT_NEAR(123.4, d, 0.001);
EXPECT_FALSE(string_util::lexical_cast<double>("UNK", &d));
// std::string: the text is copied unchanged
std::string s;
EXPECT_TRUE(string_util::lexical_cast<std::string>("123.4", &s));
EXPECT_EQ("123.4", s);
}
TEST(UtilTest, CheckNotNullTest) {
int a = 0;
CHECK_NOTNULL(&a);