Merge pull request #64 from google/sr
Reimplement Trainer with Proto reflection
This commit is contained in:
Коммит
4e91816105
|
@ -112,9 +112,13 @@ int ToSwigError(sentencepiece::util::error::Code code) {
|
|||
%ignore sentencepiece::SentencePieceProcessor::Load(std::istream *);
|
||||
%ignore sentencepiece::SentencePieceProcessor::LoadOrDie(std::istream *);
|
||||
%ignore sentencepiece::SentencePieceProcessor::model_proto();
|
||||
%ignore sentencepiece::SentencePieceTrainer::Train(int, char **);
|
||||
%ignore sentencepiece::SentencePieceTrainer::Train(const TrainerSpec &);
|
||||
%ignore sentencepiece::SentencePieceTrainer::Train(const TrainerSpec &, const NormalizerSpec &);
|
||||
%ignore sentencepiece::SentencePieceTrainer::MergeSpecsFromArgs(const std::string &,
|
||||
TrainerSpec *, NormalizerSpec *);
|
||||
%ignore sentencepiece::SentencePieceTrainer::SetProtoField(const std::string &,
|
||||
const std::string &,
|
||||
google::protobuf::Message *message);
|
||||
|
||||
%extend sentencepiece::SentencePieceProcessor {
|
||||
std::vector<std::string> Encode(const std::string& input) const {
|
||||
|
|
|
@ -201,8 +201,8 @@ class SentencePieceTrainer(_object):
|
|||
SentencePieceTrainer_swigregister = _sentencepiece.SentencePieceTrainer_swigregister
|
||||
SentencePieceTrainer_swigregister(SentencePieceTrainer)
|
||||
|
||||
def SentencePieceTrainer_Train(arg):
|
||||
return _sentencepiece.SentencePieceTrainer_Train(arg)
|
||||
def SentencePieceTrainer_Train(args):
|
||||
return _sentencepiece.SentencePieceTrainer_Train(args)
|
||||
SentencePieceTrainer_Train = _sentencepiece.SentencePieceTrainer_Train
|
||||
|
||||
# This file is compatible with both classic and new-style classes.
|
||||
|
|
|
@ -4741,6 +4741,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_Train(PyObject *SWIGUNUSEDPARM(s
|
|||
PyObject *resultobj = 0;
|
||||
std::string *arg1 = 0 ;
|
||||
PyObject * obj0 = 0 ;
|
||||
sentencepiece::util::Status result;
|
||||
|
||||
if (!PyArg_ParseTuple(args,(char *)"O:SentencePieceTrainer_Train",&obj0)) SWIG_fail;
|
||||
{
|
||||
|
@ -4754,13 +4755,18 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_Train(PyObject *SWIGUNUSEDPARM(s
|
|||
}
|
||||
{
|
||||
try {
|
||||
sentencepiece::SentencePieceTrainer::Train((std::string const &)*arg1);
|
||||
result = sentencepiece::SentencePieceTrainer::Train((std::string const &)*arg1);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
resultobj = SWIG_Py_Void();
|
||||
{
|
||||
if (!(&result)->ok()) {
|
||||
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
|
||||
}
|
||||
resultobj = SWIG_From_bool((&result)->ok());
|
||||
}
|
||||
{
|
||||
delete arg1;
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ setup(name = 'sentencepiece',
|
|||
author_email='taku@google.com',
|
||||
description = 'SentencePiece python wrapper',
|
||||
long_description = long_description,
|
||||
version='0.0.7',
|
||||
version='0.0.8',
|
||||
url = 'https://github.com/google/sentencepiece',
|
||||
license = 'Apache',
|
||||
platforms = 'Unix',
|
||||
|
|
51
src/flags.cc
51
src/flags.cc
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include "flags.h"
|
||||
#include "common.h"
|
||||
#include "util.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
|
@ -44,23 +45,6 @@ FlagMap *GetFlagMap() {
|
|||
return &flag_map;
|
||||
}
|
||||
|
||||
bool IsTrue(const std::string &value) {
|
||||
const char *kTrue[] = {"1", "t", "true", "y", "yes"};
|
||||
const char *kFalse[] = {"0", "f", "false", "n", "no"};
|
||||
std::string lower_value = value;
|
||||
std::transform(lower_value.begin(), lower_value.end(), lower_value.begin(),
|
||||
::tolower);
|
||||
for (size_t i = 0; i < 5; ++i) {
|
||||
if (lower_value == kTrue[i]) {
|
||||
return true;
|
||||
} else if (lower_value == kFalse[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "cannot parse boolean value: " << value;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool SetFlag(const std::string &name, const std::string &value) {
|
||||
auto it = GetFlagMap()->find(name);
|
||||
if (it == GetFlagMap()->end()) {
|
||||
|
@ -85,31 +69,26 @@ bool SetFlag(const std::string &name, const std::string &value) {
|
|||
}
|
||||
}
|
||||
|
||||
#define DEFINE_ARG(FLAG_TYPE, CPP_TYPE) \
|
||||
case FLAG_TYPE: { \
|
||||
CPP_TYPE *s = reinterpret_cast<CPP_TYPE *>(flag->storage); \
|
||||
CHECK(string_util::lexical_cast<CPP_TYPE>(v, s)); \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (flag->type) {
|
||||
case I:
|
||||
*reinterpret_cast<int32 *>(flag->storage) = atoi(v.c_str());
|
||||
break;
|
||||
case B:
|
||||
*(reinterpret_cast<bool *>(flag->storage)) = IsTrue(v);
|
||||
break;
|
||||
case I64:
|
||||
*reinterpret_cast<int64 *>(flag->storage) = atoll(v.c_str());
|
||||
break;
|
||||
case U64:
|
||||
*reinterpret_cast<uint64 *>(flag->storage) = atoll(v.c_str());
|
||||
break;
|
||||
case D:
|
||||
*reinterpret_cast<double *>(flag->storage) = strtod(v.c_str(), nullptr);
|
||||
break;
|
||||
case S:
|
||||
*reinterpret_cast<std::string *>(flag->storage) = v;
|
||||
break;
|
||||
DEFINE_ARG(I, int32);
|
||||
DEFINE_ARG(B, bool);
|
||||
DEFINE_ARG(I64, int64);
|
||||
DEFINE_ARG(U64, uint64);
|
||||
DEFINE_ARG(D, double);
|
||||
DEFINE_ARG(S, std::string);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool CommandLineGetFlag(int argc, char **argv, std::string *key,
|
||||
std::string *value, int *used_args) {
|
||||
|
|
|
@ -175,6 +175,12 @@ message NormalizerSpec {
|
|||
// This field must be true to train sentence piece model.
|
||||
optional bool escape_whitespaces = 5 [ default = true ];
|
||||
|
||||
// Custom normalization rule file in TSV format.
|
||||
// https://github.com/google/sentencepiece/blob/master/doc/normalization.md
|
||||
// This field is only used in SentencePieceTrainer::Train() method, which
|
||||
// compiles the rule into the binary rule stored in `precompiled_charsmap`.
|
||||
optional string normalization_rule_tsv = 6;
|
||||
|
||||
// Customized extensions: the range of field numbers
|
||||
// are open to third-party extensions.
|
||||
extensions 200 to max;
|
||||
|
|
|
@ -31,6 +31,7 @@ const char kSpaceSymbol[] = "\xe2\x96\x81";
|
|||
// since this character can be useful both for user and
|
||||
// developer. We can easily figure out that <unk> is emitted.
|
||||
const char kUnknownSymbol[] = " \xE2\x81\x87 ";
|
||||
|
||||
} // namespace
|
||||
|
||||
SentencePieceProcessor::SentencePieceProcessor() {}
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
// limitations under the License.!
|
||||
|
||||
#include "sentencepiece_trainer.h"
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
|
||||
#include "builder.h"
|
||||
|
@ -25,196 +24,179 @@
|
|||
#include "trainer_factory.h"
|
||||
#include "util.h"
|
||||
|
||||
namespace {
|
||||
static const sentencepiece::TrainerSpec kDefaultTrainerSpec;
|
||||
static const sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
|
||||
} // namespace
|
||||
|
||||
DEFINE_string(input, "", "comma separated list of input sentences");
|
||||
DEFINE_string(input_format, kDefaultTrainerSpec.input_format(),
|
||||
"Input format. Supported format is `text` or `tsv`.");
|
||||
DEFINE_string(model_prefix, "", "output model prefix");
|
||||
DEFINE_string(model_type, "unigram",
|
||||
"model algorithm: unigram, bpe, word or char");
|
||||
DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
|
||||
DEFINE_string(accept_language, "",
|
||||
"comma-separated list of languages this model can accept");
|
||||
DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
|
||||
"character coverage to determine the minimum symbols");
|
||||
DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
|
||||
"maximum size of sentences the trainer loads");
|
||||
DEFINE_int32(mining_sentence_size, kDefaultTrainerSpec.mining_sentence_size(),
|
||||
"maximum size of sentences to make seed sentence piece");
|
||||
DEFINE_int32(training_sentence_size,
|
||||
kDefaultTrainerSpec.training_sentence_size(),
|
||||
"maximum size of sentences to train sentence pieces");
|
||||
DEFINE_int32(seed_sentencepiece_size,
|
||||
kDefaultTrainerSpec.seed_sentencepiece_size(),
|
||||
"the size of seed sentencepieces");
|
||||
DEFINE_double(shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
|
||||
"Keeps top shrinking_factor pieces with respect to the loss");
|
||||
DEFINE_int32(num_threads, kDefaultTrainerSpec.num_threads(),
|
||||
"number of threads for training");
|
||||
DEFINE_int32(num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
|
||||
"number of EM sub-iterations");
|
||||
DEFINE_int32(max_sentencepiece_length,
|
||||
kDefaultTrainerSpec.max_sentencepiece_length(),
|
||||
"maximum length of sentence piece");
|
||||
DEFINE_bool(split_by_unicode_script,
|
||||
kDefaultTrainerSpec.split_by_unicode_script(),
|
||||
"use Unicode script to split sentence pieces");
|
||||
DEFINE_bool(split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
|
||||
"use a white space to split sentence pieces");
|
||||
DEFINE_string(control_symbols, "", "comma separated list of control symbols");
|
||||
DEFINE_string(user_defined_symbols, "",
|
||||
"comma separated list of user defined symbols");
|
||||
DEFINE_string(normalization_rule_name, "nfkc",
|
||||
"Normalization rule name. "
|
||||
"Choose from nfkc or identity");
|
||||
DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
|
||||
DEFINE_bool(add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(),
|
||||
"Add dummy whitespace at the beginning of text");
|
||||
DEFINE_bool(remove_extra_whitespaces,
|
||||
kDefaultNormalizerSpec.remove_extra_whitespaces(),
|
||||
"Removes leading, trailing, and "
|
||||
"duplicate internal whitespace");
|
||||
DEFINE_bool(hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
|
||||
"If set to false, --vocab_size is considered as a soft limit.");
|
||||
DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(),
|
||||
"Override UNK (<unk>) id.");
|
||||
DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
|
||||
"Override BOS (<s>) id. Set -1 to disable BOS.");
|
||||
DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
|
||||
"Override EOS (</s>) id. Set -1 to disable EOS.");
|
||||
DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
|
||||
"Override PAD (<pad>) id. Set -1 to disable PAD.");
|
||||
|
||||
using sentencepiece::NormalizerSpec;
|
||||
using sentencepiece::TrainerSpec;
|
||||
using sentencepiece::normalizer::Builder;
|
||||
|
||||
namespace sentencepiece {
|
||||
namespace {
|
||||
static constexpr char kDefaultNormalizerName[] = "nfkc";
|
||||
|
||||
NormalizerSpec MakeNormalizerSpecFromFlags() {
|
||||
if (!FLAGS_normalization_rule_tsv.empty()) {
|
||||
const auto chars_map = sentencepiece::normalizer::Builder::BuildMapFromFile(
|
||||
FLAGS_normalization_rule_tsv);
|
||||
sentencepiece::NormalizerSpec spec;
|
||||
spec.set_name("user_defined");
|
||||
spec.set_precompiled_charsmap(
|
||||
sentencepiece::normalizer::Builder::CompileCharsMap(chars_map));
|
||||
return spec;
|
||||
}
|
||||
|
||||
return sentencepiece::normalizer::Builder::GetNormalizerSpec(
|
||||
FLAGS_normalization_rule_name);
|
||||
}
|
||||
|
||||
TrainerSpec::ModelType GetModelTypeFromString(const std::string &type) {
|
||||
const std::map<std::string, TrainerSpec::ModelType> kModelTypeMap = {
|
||||
{"unigram", TrainerSpec::UNIGRAM},
|
||||
{"bpe", TrainerSpec::BPE},
|
||||
{"word", TrainerSpec::WORD},
|
||||
{"char", TrainerSpec::CHAR}};
|
||||
return port::FindOrDie(kModelTypeMap, type);
|
||||
}
|
||||
|
||||
// Populates the value from flags to spec.
|
||||
#define SetTrainerSpecFromFlag(name) trainer_spec->set_##name(FLAGS_##name);
|
||||
|
||||
#define SetNormalizerSpecFromFlag(name) \
|
||||
normalizer_spec->set_##name(FLAGS_##name);
|
||||
|
||||
#define SetRepeatedTrainerSpecFromFlag(name) \
|
||||
if (!FLAGS_##name.empty()) { \
|
||||
for (const auto v : \
|
||||
sentencepiece::string_util::Split(FLAGS_##name, ",")) { \
|
||||
trainer_spec->add_##name(v); \
|
||||
} \
|
||||
}
|
||||
|
||||
void MakeTrainerSpecFromFlags(TrainerSpec *trainer_spec,
|
||||
NormalizerSpec *normalizer_spec) {
|
||||
CHECK_NOTNULL(trainer_spec);
|
||||
CHECK_NOTNULL(normalizer_spec);
|
||||
|
||||
SetTrainerSpecFromFlag(input_format);
|
||||
SetTrainerSpecFromFlag(model_prefix);
|
||||
SetTrainerSpecFromFlag(vocab_size);
|
||||
SetTrainerSpecFromFlag(character_coverage);
|
||||
SetTrainerSpecFromFlag(input_sentence_size);
|
||||
SetTrainerSpecFromFlag(mining_sentence_size);
|
||||
SetTrainerSpecFromFlag(training_sentence_size);
|
||||
SetTrainerSpecFromFlag(seed_sentencepiece_size);
|
||||
SetTrainerSpecFromFlag(shrinking_factor);
|
||||
SetTrainerSpecFromFlag(num_threads);
|
||||
SetTrainerSpecFromFlag(num_sub_iterations);
|
||||
SetTrainerSpecFromFlag(max_sentencepiece_length);
|
||||
SetTrainerSpecFromFlag(split_by_unicode_script);
|
||||
SetTrainerSpecFromFlag(split_by_whitespace);
|
||||
SetTrainerSpecFromFlag(hard_vocab_limit);
|
||||
SetTrainerSpecFromFlag(unk_id);
|
||||
SetTrainerSpecFromFlag(bos_id);
|
||||
SetTrainerSpecFromFlag(eos_id);
|
||||
SetTrainerSpecFromFlag(pad_id);
|
||||
SetRepeatedTrainerSpecFromFlag(accept_language);
|
||||
SetRepeatedTrainerSpecFromFlag(control_symbols);
|
||||
SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
|
||||
|
||||
*normalizer_spec = MakeNormalizerSpecFromFlags();
|
||||
SetNormalizerSpecFromFlag(add_dummy_prefix);
|
||||
SetNormalizerSpecFromFlag(remove_extra_whitespaces);
|
||||
|
||||
for (const auto &filename :
|
||||
sentencepiece::string_util::Split(FLAGS_input, ",")) {
|
||||
trainer_spec->add_input(filename);
|
||||
}
|
||||
|
||||
trainer_spec->set_model_type(GetModelTypeFromString(FLAGS_model_type));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// static
|
||||
void SentencePieceTrainer::Train(int argc, char **argv) {
|
||||
TrainerSpec trainer_spec;
|
||||
util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
|
||||
NormalizerSpec normalizer_spec;
|
||||
{
|
||||
static std::mutex flags_mutex;
|
||||
std::lock_guard<std::mutex> lock(flags_mutex);
|
||||
sentencepiece::flags::ParseCommandLineFlags(argc, argv);
|
||||
CHECK_OR_HELP(input);
|
||||
CHECK_OR_HELP(model_prefix);
|
||||
MakeTrainerSpecFromFlags(&trainer_spec, &normalizer_spec);
|
||||
normalizer_spec.set_name(kDefaultNormalizerName);
|
||||
Train(trainer_spec, normalizer_spec);
|
||||
return util::OkStatus();
|
||||
}
|
||||
|
||||
// static
|
||||
util::Status SentencePieceTrainer::Train(
|
||||
const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec) {
|
||||
auto copied_normalizer_spec = normalizer_spec;
|
||||
|
||||
if (!copied_normalizer_spec.normalization_rule_tsv().empty()) {
|
||||
if (!copied_normalizer_spec.precompiled_charsmap().empty()) {
|
||||
return util::InternalError("precompiled_charsmap is already defined.");
|
||||
}
|
||||
|
||||
const auto chars_map = normalizer::Builder::BuildMapFromFile(
|
||||
copied_normalizer_spec.normalization_rule_tsv());
|
||||
copied_normalizer_spec.set_precompiled_charsmap(
|
||||
normalizer::Builder::CompileCharsMap(chars_map));
|
||||
copied_normalizer_spec.set_name("user_defined");
|
||||
} else {
|
||||
if (copied_normalizer_spec.name().empty()) {
|
||||
copied_normalizer_spec.set_name(kDefaultNormalizerName);
|
||||
}
|
||||
|
||||
if (copied_normalizer_spec.precompiled_charsmap().empty()) {
|
||||
*(copied_normalizer_spec.mutable_precompiled_charsmap()) =
|
||||
normalizer::Builder::GetNormalizerSpec(copied_normalizer_spec.name())
|
||||
.precompiled_charsmap();
|
||||
}
|
||||
}
|
||||
|
||||
SentencePieceTrainer::Train(trainer_spec, normalizer_spec);
|
||||
}
|
||||
|
||||
// static
|
||||
void SentencePieceTrainer::Train(const std::string &arg) {
|
||||
const std::vector<std::string> args =
|
||||
sentencepiece::string_util::Split(arg, " ");
|
||||
std::vector<char *> cargs(args.size() + 1);
|
||||
cargs[0] = const_cast<char *>("");
|
||||
for (size_t i = 0; i < args.size(); ++i)
|
||||
cargs[i + 1] = const_cast<char *>(args[i].data());
|
||||
SentencePieceTrainer::Train(static_cast<int>(cargs.size()), &cargs[0]);
|
||||
}
|
||||
|
||||
// static
|
||||
void SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
|
||||
SentencePieceTrainer::Train(
|
||||
trainer_spec,
|
||||
normalizer::Builder::GetNormalizerSpec(kDefaultNormalizerName));
|
||||
}
|
||||
|
||||
// static
|
||||
void SentencePieceTrainer::Train(const TrainerSpec &trainer_spec,
|
||||
const NormalizerSpec &normalizer_spec) {
|
||||
auto trainer = TrainerFactory::Create(trainer_spec, normalizer_spec);
|
||||
auto trainer = TrainerFactory::Create(trainer_spec, copied_normalizer_spec);
|
||||
trainer->Train();
|
||||
|
||||
return util::OkStatus();
|
||||
}
|
||||
|
||||
// static
|
||||
util::Status SentencePieceTrainer::SetProtoField(
|
||||
const std::string &field_name, const std::string &value,
|
||||
google::protobuf::Message *message) {
|
||||
const auto *descriptor = message->GetDescriptor();
|
||||
const auto *reflection = message->GetReflection();
|
||||
|
||||
if (descriptor == nullptr || reflection == nullptr) {
|
||||
return util::InternalError("Reflection is not supported.");
|
||||
}
|
||||
|
||||
const auto *field = descriptor->FindFieldByName(std::string(field_name));
|
||||
|
||||
if (field == nullptr) {
|
||||
return util::NotFoundError(std::string("Unknown field name \"") +
|
||||
field_name + "\" in " +
|
||||
descriptor->DebugString());
|
||||
}
|
||||
|
||||
std::vector<std::string> values = {value};
|
||||
if (field->is_repeated()) values = string_util::Split(value, ",");
|
||||
|
||||
#define SET_FIELD(METHOD_TYPE, v) \
|
||||
if (field->is_repeated()) \
|
||||
reflection->Add##METHOD_TYPE(message, field, v); \
|
||||
else \
|
||||
reflection->Set##METHOD_TYPE(message, field, v);
|
||||
|
||||
#define DEFINE_SET_FIELD(PROTO_TYPE, CPP_TYPE, FUNC_PREFIX, METHOD_TYPE, \
|
||||
EMPTY) \
|
||||
case google::protobuf::FieldDescriptor::CPPTYPE_##PROTO_TYPE: { \
|
||||
CPP_TYPE v; \
|
||||
if (!string_util::lexical_cast(value.empty() ? EMPTY : value, &v)) \
|
||||
return util::InvalidArgumentError(std::string("Cannot parse \"") + \
|
||||
value + "\" as \"" + \
|
||||
field->type_name() + "\"."); \
|
||||
SET_FIELD(METHOD_TYPE, v); \
|
||||
break; \
|
||||
}
|
||||
|
||||
for (const auto &value : values) {
|
||||
switch (field->cpp_type()) {
|
||||
DEFINE_SET_FIELD(INT32, int32, i, Int32, "");
|
||||
DEFINE_SET_FIELD(INT64, int64, i, Int64, "");
|
||||
DEFINE_SET_FIELD(UINT32, uint32, i, UInt32, "");
|
||||
DEFINE_SET_FIELD(UINT64, uint64, i, UInt64, "");
|
||||
DEFINE_SET_FIELD(DOUBLE, double, d, Double, "");
|
||||
DEFINE_SET_FIELD(FLOAT, float, f, Float, "");
|
||||
DEFINE_SET_FIELD(BOOL, bool, b, Bool, "true");
|
||||
case google::protobuf::FieldDescriptor::CPPTYPE_STRING:
|
||||
SET_FIELD(String, value);
|
||||
break;
|
||||
case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
|
||||
const auto *enum_value =
|
||||
field->enum_type()->FindValueByName(string_util::ToUpper(value));
|
||||
if (enum_value == nullptr)
|
||||
return util::InvalidArgumentError(
|
||||
std::string("Unknown enumeration value of \"") + value +
|
||||
"\" for field \"" + field->name() + "\".");
|
||||
SET_FIELD(Enum, enum_value);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return util::UnimplementedError(std::string("Proto type \"") +
|
||||
field->cpp_type_name() +
|
||||
"\" is not supported.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return util::OkStatus();
|
||||
}
|
||||
|
||||
// static
|
||||
util::Status SentencePieceTrainer::MergeSpecsFromArgs(
|
||||
const std::string &args, TrainerSpec *trainer_spec,
|
||||
NormalizerSpec *normalizer_spec) {
|
||||
if (trainer_spec == nullptr || normalizer_spec == nullptr) {
|
||||
return util::InternalError(
|
||||
"`trainer_spec` and `normalizer_spec` must not be null.");
|
||||
}
|
||||
|
||||
if (args.empty()) return util::OkStatus();
|
||||
|
||||
for (auto arg : string_util::SplitPiece(args, " ")) {
|
||||
arg.Consume("--");
|
||||
std::string key, value;
|
||||
auto pos = arg.find("=");
|
||||
if (pos == StringPiece::npos) {
|
||||
key = arg.ToString();
|
||||
} else {
|
||||
key = arg.substr(0, pos).ToString();
|
||||
value = arg.substr(pos + 1).ToString();
|
||||
}
|
||||
|
||||
// Exception.
|
||||
if (key == "normalization_rule_name") {
|
||||
normalizer_spec->set_name(value);
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto status_train = SetProtoField(key, value, trainer_spec);
|
||||
if (status_train.ok()) continue;
|
||||
if (!util::IsNotFound(status_train)) return status_train;
|
||||
|
||||
const auto status_norm = SetProtoField(key, value, normalizer_spec);
|
||||
if (status_norm.ok()) continue;
|
||||
if (!util::IsNotFound(status_norm)) return status_norm;
|
||||
|
||||
// Not found both in trainer_spec and normalizer_spec.
|
||||
if (util::IsNotFound(status_train) && util::IsNotFound(status_norm)) {
|
||||
return status_train;
|
||||
}
|
||||
}
|
||||
|
||||
return util::OkStatus();
|
||||
}
|
||||
|
||||
// static
|
||||
util::Status SentencePieceTrainer::Train(const std::string &args) {
|
||||
TrainerSpec trainer_spec;
|
||||
NormalizerSpec normalizer_spec;
|
||||
normalizer_spec.set_name(kDefaultNormalizerName);
|
||||
|
||||
CHECK_OK(MergeSpecsFromArgs(args, &trainer_spec, &normalizer_spec));
|
||||
|
||||
return Train(trainer_spec, normalizer_spec);
|
||||
}
|
||||
|
||||
} // namespace sentencepiece
|
||||
|
|
|
@ -16,6 +16,13 @@
|
|||
#define SENTENCEPIECE_TRAINER_H_
|
||||
|
||||
#include <string>
|
||||
#include "sentencepiece_processor.h"
|
||||
|
||||
namespace google {
|
||||
namespace protobuf {
|
||||
class Message;
|
||||
} // namespace protobuf
|
||||
} // namespace google
|
||||
|
||||
namespace sentencepiece {
|
||||
|
||||
|
@ -24,21 +31,32 @@ class NormalizerSpec;
|
|||
|
||||
class SentencePieceTrainer {
|
||||
public:
|
||||
// Entry point for main function.
|
||||
static void Train(int argc, char **argv);
|
||||
|
||||
// Train from params with a single line.
|
||||
// "--input=foo --model_prefix=m --vocab_size=1024"
|
||||
static void Train(const std::string &arg);
|
||||
|
||||
// Trains SentencePiece model with `trainer_spec`.
|
||||
// Default `normalizer_spec` is used.
|
||||
static void Train(const TrainerSpec &trainer_spec);
|
||||
static util::Status Train(const TrainerSpec &trainer_spec);
|
||||
|
||||
// Trains SentencePiece model with `trainer_spec` and
|
||||
// `normalizer_spec`.
|
||||
static void Train(const TrainerSpec &trainer_spec,
|
||||
const NormalizerSpec &normalizer_spec);
|
||||
static util::Status Train(const TrainerSpec &trainer_spec,
|
||||
const NormalizerSpec &normalizer_spec);
|
||||
|
||||
// Trains SentencePiece model with command-line string in `args`,
|
||||
// e.g.,
|
||||
// '--input=data --model_prefix=m --vocab_size=8192 model_type=unigram'
|
||||
static util::Status Train(const std::string &args);
|
||||
|
||||
// Overrides `trainer_spec` and `normalizer_spec` with the
|
||||
// command-line string in `args`.
|
||||
static util::Status MergeSpecsFromArgs(const std::string &args,
|
||||
TrainerSpec *trainer_spec,
|
||||
NormalizerSpec *normalizer_spec);
|
||||
|
||||
// Helper function to set `field_name=value` in `message`.
|
||||
// When `field_name` is repeated, multiple values can be passed
|
||||
// with comma-separated values. `field_name` must not be a nested message.
|
||||
static util::Status SetProtoField(const std::string &field_name,
|
||||
const std::string &value,
|
||||
google::protobuf::Message *message);
|
||||
|
||||
SentencePieceTrainer() = delete;
|
||||
~SentencePieceTrainer() = delete;
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "sentencepiece_trainer.h"
|
||||
#include "sentencepiece_model.pb.h"
|
||||
#include "testharness.h"
|
||||
#include "util.h"
|
||||
|
||||
namespace sentencepiece {
|
||||
namespace {
|
||||
|
@ -39,12 +40,98 @@ TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) {
|
|||
"--normalization_rule_tsv=../data/nfkc.tsv");
|
||||
}
|
||||
|
||||
TEST(SentencePieceTrainerTest, TrainErrorTest) {
|
||||
TrainerSpec trainer_spec;
|
||||
NormalizerSpec normalizer_spec;
|
||||
normalizer_spec.set_normalization_rule_tsv("foo.tsv");
|
||||
normalizer_spec.set_precompiled_charsmap("foo");
|
||||
EXPECT_NOT_OK(SentencePieceTrainer::Train(trainer_spec, normalizer_spec));
|
||||
}
|
||||
|
||||
TEST(SentencePieceTrainerTest, TrainTest) {
|
||||
TrainerSpec trainer_spec;
|
||||
trainer_spec.add_input("../data/botchan.txt");
|
||||
trainer_spec.set_model_prefix("m");
|
||||
trainer_spec.set_vocab_size(1000);
|
||||
SentencePieceTrainer::Train(trainer_spec);
|
||||
NormalizerSpec normalizer_spec;
|
||||
EXPECT_OK(SentencePieceTrainer::Train(trainer_spec, normalizer_spec));
|
||||
EXPECT_OK(SentencePieceTrainer::Train(trainer_spec));
|
||||
}
|
||||
|
||||
TEST(SentencePieceTrainerTest, SetProtoFieldTest) {
|
||||
TrainerSpec spec;
|
||||
|
||||
EXPECT_NOT_OK(SentencePieceTrainer::SetProtoField("dummy", "1000", &spec));
|
||||
|
||||
EXPECT_OK(SentencePieceTrainer::SetProtoField("vocab_size", "1000", &spec));
|
||||
EXPECT_EQ(1000, spec.vocab_size());
|
||||
EXPECT_NOT_OK(
|
||||
SentencePieceTrainer::SetProtoField("vocab_size", "UNK", &spec));
|
||||
|
||||
EXPECT_OK(SentencePieceTrainer::SetProtoField("input_format", "TSV", &spec));
|
||||
EXPECT_EQ("TSV", spec.input_format());
|
||||
EXPECT_OK(SentencePieceTrainer::SetProtoField("input_format", "123", &spec));
|
||||
EXPECT_EQ("123", spec.input_format());
|
||||
|
||||
EXPECT_OK(SentencePieceTrainer::SetProtoField("split_by_whitespace", "false",
|
||||
&spec));
|
||||
EXPECT_FALSE(spec.split_by_whitespace());
|
||||
EXPECT_OK(
|
||||
SentencePieceTrainer::SetProtoField("split_by_whitespace", "", &spec));
|
||||
EXPECT_TRUE(spec.split_by_whitespace());
|
||||
|
||||
EXPECT_OK(
|
||||
SentencePieceTrainer::SetProtoField("character_coverage", "0.5", &spec));
|
||||
EXPECT_NEAR(spec.character_coverage(), 0.5, 0.001);
|
||||
EXPECT_NOT_OK(
|
||||
SentencePieceTrainer::SetProtoField("character_coverage", "UNK", &spec));
|
||||
|
||||
EXPECT_OK(SentencePieceTrainer::SetProtoField("input", "foo,bar,buz", &spec));
|
||||
EXPECT_EQ(3, spec.input_size());
|
||||
EXPECT_EQ("foo", spec.input(0));
|
||||
EXPECT_EQ("bar", spec.input(1));
|
||||
EXPECT_EQ("buz", spec.input(2));
|
||||
|
||||
EXPECT_OK(SentencePieceTrainer::SetProtoField("model_type", "BPE", &spec));
|
||||
EXPECT_NOT_OK(
|
||||
SentencePieceTrainer::SetProtoField("model_type", "UNK", &spec));
|
||||
|
||||
// Nested message is not supported.
|
||||
ModelProto proto;
|
||||
EXPECT_NOT_OK(
|
||||
SentencePieceTrainer::SetProtoField("trainer_spec", "UNK", &proto));
|
||||
}
|
||||
|
||||
TEST(SentencePieceTrainerTest, MergeSpecsFromArgs) {
|
||||
TrainerSpec trainer_spec;
|
||||
NormalizerSpec normalizer_spec;
|
||||
EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs("", nullptr, nullptr));
|
||||
|
||||
EXPECT_OK(SentencePieceTrainer::MergeSpecsFromArgs("", &trainer_spec,
|
||||
&normalizer_spec));
|
||||
|
||||
EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
|
||||
"--unknown=BPE", &trainer_spec, &normalizer_spec));
|
||||
|
||||
EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
|
||||
"--vocab_size=UNK", &trainer_spec, &normalizer_spec));
|
||||
|
||||
EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
|
||||
"--model_type=UNK", &trainer_spec, &normalizer_spec));
|
||||
|
||||
EXPECT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
|
||||
"--model_type=bpe", &trainer_spec, &normalizer_spec));
|
||||
|
||||
EXPECT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
|
||||
"--split_by_whitespace", &trainer_spec, &normalizer_spec));
|
||||
|
||||
EXPECT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
|
||||
"--normalization_rule_name=foo", &trainer_spec, &normalizer_spec));
|
||||
EXPECT_EQ("foo", normalizer_spec.name());
|
||||
|
||||
EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
|
||||
"--vocab_size=UNK", &trainer_spec, &normalizer_spec));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace sentencepiece
|
||||
|
|
|
@ -12,10 +12,140 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include "builder.h"
|
||||
#include "flags.h"
|
||||
#include "sentencepiece_trainer.h"
|
||||
#include "util.h"
|
||||
|
||||
using sentencepiece::NormalizerSpec;
|
||||
using sentencepiece::TrainerSpec;
|
||||
using sentencepiece::normalizer::Builder;
|
||||
|
||||
namespace {
|
||||
static sentencepiece::TrainerSpec kDefaultTrainerSpec;
|
||||
static sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
|
||||
} // namespace
|
||||
|
||||
DEFINE_string(input, "", "comma separated list of input sentences");
|
||||
DEFINE_string(input_format, kDefaultTrainerSpec.input_format(),
|
||||
"Input format. Supported format is `text` or `tsv`.");
|
||||
DEFINE_string(model_prefix, "", "output model prefix");
|
||||
DEFINE_string(model_type, "unigram",
|
||||
"model algorithm: unigram, bpe, word or char");
|
||||
DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
|
||||
DEFINE_string(accept_language, "",
|
||||
"comma-separated list of languages this model can accept");
|
||||
DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
|
||||
"character coverage to determine the minimum symbols");
|
||||
DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
|
||||
"maximum size of sentences the trainer loads");
|
||||
DEFINE_int32(mining_sentence_size, kDefaultTrainerSpec.mining_sentence_size(),
|
||||
"maximum size of sentences to make seed sentence piece");
|
||||
DEFINE_int32(training_sentence_size,
|
||||
kDefaultTrainerSpec.training_sentence_size(),
|
||||
"maximum size of sentences to train sentence pieces");
|
||||
DEFINE_int32(seed_sentencepiece_size,
|
||||
kDefaultTrainerSpec.seed_sentencepiece_size(),
|
||||
"the size of seed sentencepieces");
|
||||
DEFINE_double(shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
|
||||
"Keeps top shrinking_factor pieces with respect to the loss");
|
||||
DEFINE_int32(num_threads, kDefaultTrainerSpec.num_threads(),
|
||||
"number of threads for training");
|
||||
DEFINE_int32(num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
|
||||
"number of EM sub-iterations");
|
||||
DEFINE_int32(max_sentencepiece_length,
|
||||
kDefaultTrainerSpec.max_sentencepiece_length(),
|
||||
"maximum length of sentence piece");
|
||||
DEFINE_bool(split_by_unicode_script,
|
||||
kDefaultTrainerSpec.split_by_unicode_script(),
|
||||
"use Unicode script to split sentence pieces");
|
||||
DEFINE_bool(split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
|
||||
"use a white space to split sentence pieces");
|
||||
DEFINE_string(control_symbols, "", "comma separated list of control symbols");
|
||||
DEFINE_string(user_defined_symbols, "",
|
||||
"comma separated list of user defined symbols");
|
||||
DEFINE_string(normalization_rule_name, "nfkc",
|
||||
"Normalization rule name. "
|
||||
"Choose from nfkc or identity");
|
||||
DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
|
||||
DEFINE_bool(add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(),
|
||||
"Add dummy whitespace at the beginning of text");
|
||||
DEFINE_bool(remove_extra_whitespaces,
|
||||
kDefaultNormalizerSpec.remove_extra_whitespaces(),
|
||||
"Removes leading, trailing, and "
|
||||
"duplicate internal whitespace");
|
||||
DEFINE_bool(hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
|
||||
"If set to false, --vocab_size is considered as a soft limit.");
|
||||
DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(), "Override UNK (<unk>) id.");
|
||||
DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
|
||||
"Override BOS (<s>) id. Set -1 to disable BOS.");
|
||||
DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
|
||||
"Override EOS (</s>) id. Set -1 to disable EOS.");
|
||||
DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
|
||||
"Override PAD (<pad>) id. Set -1 to disable PAD.");
|
||||
|
||||
// Entry point of spm_train.
//
// SentencePieceTrainer::Train(argc, argv) parses the command-line flags and
// fills TrainerSpec / NormalizerSpec internally via proto reflection, then
// runs training once.
//
// The previous implementation re-parsed the flags here, populated the two
// spec protos field-by-field with SetTrainerSpecFromFlag macros, and invoked
// Train(trainer_spec, normalizer_spec) a second time; that duplicated legacy
// path (which would have run training twice) is removed.
int main(int argc, char *argv[]) {
  sentencepiece::SentencePieceTrainer::Train(argc, argv);
  return 0;
}
|
||||
|
|
46
src/util.h
46
src/util.h
|
@ -37,6 +37,52 @@ std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) {
|
|||
// String utilities
|
||||
namespace string_util {
|
||||
|
||||
// Returns a lower-cased copy of `arg`.
inline std::string ToLower(StringPiece arg) {
  std::string lower_value = arg.ToString();
  // Cast through unsigned char: passing a negative char value to
  // std::tolower is undefined behavior (e.g. for non-ASCII bytes on
  // platforms where char is signed).
  std::transform(lower_value.begin(), lower_value.end(), lower_value.begin(),
                 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
  return lower_value;
}
|
||||
|
||||
// Returns an upper-cased copy of `arg`.
inline std::string ToUpper(StringPiece arg) {
  std::string upper_value = arg.ToString();
  // Cast through unsigned char: passing a negative char value to
  // std::toupper is undefined behavior (e.g. for non-ASCII bytes on
  // platforms where char is signed).
  std::transform(upper_value.begin(), upper_value.end(), upper_value.begin(),
                 [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
  return upper_value;
}
|
||||
|
||||
template <typename Target>
|
||||
inline bool lexical_cast(StringPiece arg, Target *result) {
|
||||
std::stringstream ss;
|
||||
return (ss << arg.data() && ss >> *result);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool lexical_cast(StringPiece arg, bool *result) {
|
||||
const char *kTrue[] = {"1", "t", "true", "y", "yes"};
|
||||
const char *kFalse[] = {"0", "f", "false", "n", "no"};
|
||||
std::string lower_value = arg.ToString();
|
||||
std::transform(lower_value.begin(), lower_value.end(), lower_value.begin(),
|
||||
::tolower);
|
||||
for (size_t i = 0; i < 5; ++i) {
|
||||
if (lower_value == kTrue[i]) {
|
||||
*result = true;
|
||||
return true;
|
||||
} else if (lower_value == kFalse[i]) {
|
||||
*result = false;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// std::string specialization: a string trivially "parses" as itself, so no
// stream round-trip is needed and the cast can never fail.
template <>
inline bool lexical_cast(StringPiece arg, std::string *result) {
  *result = arg.ToString();
  return true;
}
|
||||
|
||||
// Splits `str` on `delim` and returns the resulting pieces.
// Declaration only; the definition lives elsewhere in the project.
std::vector<std::string> Split(const std::string &str,
                               const std::string &delim);
|
||||
|
||||
|
|
|
@ -18,6 +18,31 @@
|
|||
|
||||
namespace sentencepiece {
|
||||
|
||||
// Exercises string_util::lexical_cast for bool, int32, double and
// std::string targets, including rejection of unparsable input ("UNK").
TEST(UtilTest, LexicalCastTest) {
  // bool: recognized spellings parse; anything else is rejected.
  bool bool_result = false;
  EXPECT_TRUE(string_util::lexical_cast<bool>("true", &bool_result));
  EXPECT_TRUE(bool_result);
  EXPECT_TRUE(string_util::lexical_cast<bool>("false", &bool_result));
  EXPECT_FALSE(bool_result);
  EXPECT_FALSE(string_util::lexical_cast<bool>("UNK", &bool_result));

  // int32: positive and negative values round-trip.
  int32 int_result = 0;
  EXPECT_TRUE(string_util::lexical_cast<int32>("123", &int_result));
  EXPECT_EQ(123, int_result);
  EXPECT_TRUE(string_util::lexical_cast<int32>("-123", &int_result));
  EXPECT_EQ(-123, int_result);
  EXPECT_FALSE(string_util::lexical_cast<int32>("UNK", &int_result));

  // double: compare with a tolerance, not exact equality.
  double double_result = 0.0;
  EXPECT_TRUE(string_util::lexical_cast<double>("123.4", &double_result));
  EXPECT_NEAR(123.4, double_result, 0.001);
  EXPECT_FALSE(string_util::lexical_cast<double>("UNK", &double_result));

  // std::string: always succeeds and copies the input verbatim.
  std::string string_result;
  EXPECT_TRUE(string_util::lexical_cast<std::string>("123.4", &string_result));
  EXPECT_EQ("123.4", string_result);
}
|
||||
|
||||
TEST(UtilTest, CheckNotNullTest) {
|
||||
int a = 0;
|
||||
CHECK_NOTNULL(&a);
|
||||
|
|
Загрузка…
Ссылка в новой задаче