Merge pull request #117 from google/sr

Add LoadFromSerializedProto to Python wrapper
This commit is contained in:
Taku Kudo 2018-06-24 23:12:32 +09:00 коммит произвёл GitHub
Родитель 0eccb9aa0d 293b1863b9
Коммит bfa93cc8fc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 112 добавлений и 2 удалений

Просмотреть файл

@ -129,6 +129,10 @@ int ToSwigError(sentencepiece::util::error::Code code) {
return $self->Load(filename);
}
// snake_case alias exposed on the extended class: hands its argument to
// LoadFromSerializedProto() on the wrapped object and returns that call's Status.
// NOTE(review): the parameter is named `filename`, but it is passed straight to
// LoadFromSerializedProto -- presumably it carries serialized model bytes rather
// than a path; confirm before renaming, since SWIG may surface this name.
util::Status load_from_serialized_proto(
    sentencepiece::util::min_string_view filename) {
  return (*$self).LoadFromSerializedProto(filename);
}
util::Status set_encode_extra_options(
sentencepiece::util::min_string_view extra_option) {
return $self->SetEncodeExtraOptions(extra_option);

Просмотреть файл

@ -120,6 +120,9 @@ class SentencePieceProcessor(_object):
def LoadOrDie(self, filename):
    """Forward ``filename`` to the native ``SentencePieceProcessor_LoadOrDie`` binding.

    Returns whatever the native binding returns.
    """
    native_load_or_die = _sentencepiece.SentencePieceProcessor_LoadOrDie
    return native_load_or_die(self, filename)
def LoadFromSerializedProto(self, serialized):
    """Forward ``serialized`` to the native ``SentencePieceProcessor_LoadFromSerializedProto`` binding.

    Returns whatever the native binding returns.
    """
    native_load = _sentencepiece.SentencePieceProcessor_LoadFromSerializedProto
    return native_load(self, serialized)
def SetEncodeExtraOptions(self, extra_option):
    """Forward ``extra_option`` to the native ``SentencePieceProcessor_SetEncodeExtraOptions`` binding.

    Returns whatever the native binding returns.
    """
    native_set_options = _sentencepiece.SentencePieceProcessor_SetEncodeExtraOptions
    return native_set_options(self, extra_option)
@ -183,6 +186,9 @@ class SentencePieceProcessor(_object):
def load(self, filename):
    """Forward ``filename`` to the native ``SentencePieceProcessor_load`` binding.

    Returns whatever the native binding returns.
    """
    native_load = _sentencepiece.SentencePieceProcessor_load
    return native_load(self, filename)
def load_from_serialized_proto(self, filename):
    """Forward the argument to the native ``SentencePieceProcessor_load_from_serialized_proto`` binding.

    NOTE(review): the parameter is called ``filename`` but, given the method
    name, it presumably carries the serialized model bytes rather than a
    path -- confirm. The name is kept so keyword callers keep working.

    Returns whatever the native binding returns.
    """
    native_load = _sentencepiece.SentencePieceProcessor_load_from_serialized_proto
    return native_load(self, filename)
def set_encode_extra_options(self, extra_option):
    """Forward ``extra_option`` to the native ``SentencePieceProcessor_set_encode_extra_options`` binding.

    Returns whatever the native binding returns.
    """
    native_set_options = _sentencepiece.SentencePieceProcessor_set_encode_extra_options
    return native_set_options(self, extra_option)

Просмотреть файл

@ -3545,6 +3545,9 @@ SWIGINTERNINLINE PyObject*
// Free-function helper: calls Load(filename) on the given processor and
// returns the resulting Status unchanged.
SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor_load(
    sentencepiece::SentencePieceProcessor *self,
    sentencepiece::util::min_string_view filename) {
  return (*self).Load(filename);
}
// Free-function helper: calls LoadFromSerializedProto() on the given processor
// and returns the resulting Status unchanged.
// NOTE(review): the argument is named `filename` but is passed as the
// serialized-proto input -- presumably model bytes, not a path; confirm.
SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor_load_from_serialized_proto(
    sentencepiece::SentencePieceProcessor *self,
    sentencepiece::util::min_string_view filename) {
  return (*self).LoadFromSerializedProto(filename);
}
// Free-function helper: calls SetEncodeExtraOptions(extra_option) on the given
// processor and returns the resulting Status unchanged.
SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor_set_encode_extra_options(
    sentencepiece::SentencePieceProcessor *self,
    sentencepiece::util::min_string_view extra_option) {
  return (*self).SetEncodeExtraOptions(extra_option);
}
@ -3753,6 +3756,52 @@ fail:
}
// Python entry point for SentencePieceProcessor_LoadFromSerializedProto.
// `args` must be a 2-tuple: (wrapped SentencePieceProcessor, string-like input).
// Returns the Python bool built from Status::ok() on the success path, or NULL
// (via the `fail` label) after a Python error has been raised.
SWIGINTERN PyObject *_wrap_SentencePieceProcessor_LoadFromSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
sentencepiece::util::min_string_view arg2 ;
void *argp1 = 0 ;
int res1 = 0 ;
PyObject * obj0 = 0 ;
PyObject * obj1 = 0 ;
sentencepiece::util::Status result;
// Unpack exactly two positional objects: the processor and the input.
if (!PyArg_ParseTuple(args,(char *)"OO:SentencePieceProcessor_LoadFromSerializedProto",&obj0,&obj1)) SWIG_fail;
// Argument 1: recover the C++ processor pointer from its Python proxy.
res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
if (!SWIG_IsOK(res1)) {
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_LoadFromSerializedProto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'");
}
arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
// Argument 2: view the Python input as a (pointer, size) string view.
{
const PyInputString ustring(obj1);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
SWIG_fail;
}
// resultobj temporarily holds the input-type marker; it is handed to
// ReleaseResultObject() after the call and overwritten with the real result below.
resultobj = ustring.input_type();
// NOTE(review): arg2 refers to ustring.data(), yet ustring goes out of scope at
// the end of this block while arg2 is used afterwards -- presumably the bytes
// are owned by obj1 rather than by ustring; confirm against PyInputString.
arg2 = sentencepiece::util::min_string_view(ustring.data(), ustring.size());
}
{
try {
// Direct member call on the processor.
result = (arg1)->LoadFromSerializedProto(arg2);
ReleaseResultObject(resultobj);
}
// A Status thrown as an exception is reported as a SWIG/Python error.
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
{
// A non-OK Status is reported as an error using its code and message.
if (!(&result)->ok()) {
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
}
resultobj = SWIG_From_bool((&result)->ok());
}
return resultobj;
fail:
// Error path: signal failure to the interpreter.
return NULL;
}
SWIGINTERN PyObject *_wrap_SentencePieceProcessor_SetEncodeExtraOptions(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
@ -4778,6 +4827,52 @@ fail:
}
// Python entry point for SentencePieceProcessor_load_from_serialized_proto.
// `args` must be a 2-tuple: (wrapped SentencePieceProcessor, string-like input).
// Returns the Python bool built from Status::ok() on the success path, or NULL
// (via the `fail` label) after a Python error has been raised.
SWIGINTERN PyObject *_wrap_SentencePieceProcessor_load_from_serialized_proto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
sentencepiece::util::min_string_view arg2 ;
void *argp1 = 0 ;
int res1 = 0 ;
PyObject * obj0 = 0 ;
PyObject * obj1 = 0 ;
sentencepiece::util::Status result;
// Unpack exactly two positional objects: the processor and the input.
if (!PyArg_ParseTuple(args,(char *)"OO:SentencePieceProcessor_load_from_serialized_proto",&obj0,&obj1)) SWIG_fail;
// Argument 1: recover the C++ processor pointer from its Python proxy.
res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
if (!SWIG_IsOK(res1)) {
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_load_from_serialized_proto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'");
}
arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
// Argument 2: view the Python input as a (pointer, size) string view.
{
const PyInputString ustring(obj1);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
SWIG_fail;
}
// resultobj temporarily holds the input-type marker; it is handed to
// ReleaseResultObject() after the call and overwritten with the real result below.
resultobj = ustring.input_type();
// NOTE(review): arg2 refers to ustring.data(), yet ustring goes out of scope at
// the end of this block while arg2 is used afterwards -- presumably the bytes
// are owned by obj1 rather than by ustring; confirm against PyInputString.
arg2 = sentencepiece::util::min_string_view(ustring.data(), ustring.size());
}
{
try {
// Goes through the free-function helper of the same name rather than
// calling the member function directly.
result = sentencepiece_SentencePieceProcessor_load_from_serialized_proto(arg1,arg2);
ReleaseResultObject(resultobj);
}
// A Status thrown as an exception is reported as a SWIG/Python error.
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
{
// A non-OK Status is reported as an error using its code and message.
if (!(&result)->ok()) {
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
}
resultobj = SWIG_From_bool((&result)->ok());
}
return resultobj;
fail:
// Error path: signal failure to the interpreter.
return NULL;
}
SWIGINTERN PyObject *_wrap_SentencePieceProcessor_set_encode_extra_options(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
@ -5922,6 +6017,7 @@ static PyMethodDef SwigMethods[] = {
{ (char *)"delete_SentencePieceProcessor", _wrap_delete_SentencePieceProcessor, METH_VARARGS, NULL},
{ (char *)"SentencePieceProcessor_Load", _wrap_SentencePieceProcessor_Load, METH_VARARGS, NULL},
{ (char *)"SentencePieceProcessor_LoadOrDie", _wrap_SentencePieceProcessor_LoadOrDie, METH_VARARGS, NULL},
{ (char *)"SentencePieceProcessor_LoadFromSerializedProto", _wrap_SentencePieceProcessor_LoadFromSerializedProto, METH_VARARGS, NULL},
{ (char *)"SentencePieceProcessor_SetEncodeExtraOptions", _wrap_SentencePieceProcessor_SetEncodeExtraOptions, METH_VARARGS, NULL},
{ (char *)"SentencePieceProcessor_SetDecodeExtraOptions", _wrap_SentencePieceProcessor_SetDecodeExtraOptions, METH_VARARGS, NULL},
{ (char *)"SentencePieceProcessor_SetVocabulary", _wrap_SentencePieceProcessor_SetVocabulary, METH_VARARGS, NULL},
@ -5943,6 +6039,7 @@ static PyMethodDef SwigMethods[] = {
{ (char *)"SentencePieceProcessor_IsControl", _wrap_SentencePieceProcessor_IsControl, METH_VARARGS, NULL},
{ (char *)"SentencePieceProcessor_IsUnused", _wrap_SentencePieceProcessor_IsUnused, METH_VARARGS, NULL},
{ (char *)"SentencePieceProcessor_load", _wrap_SentencePieceProcessor_load, METH_VARARGS, NULL},
{ (char *)"SentencePieceProcessor_load_from_serialized_proto", _wrap_SentencePieceProcessor_load_from_serialized_proto, METH_VARARGS, NULL},
{ (char *)"SentencePieceProcessor_set_encode_extra_options", _wrap_SentencePieceProcessor_set_encode_extra_options, METH_VARARGS, NULL},
{ (char *)"SentencePieceProcessor_set_decode_extra_options", _wrap_SentencePieceProcessor_set_decode_extra_options, METH_VARARGS, NULL},
{ (char *)"SentencePieceProcessor_set_vocabulary", _wrap_SentencePieceProcessor_set_vocabulary, METH_VARARGS, NULL},

Просмотреть файл

@ -13,9 +13,12 @@ class TestSentencepieceProcessor(unittest.TestCase):
self.assertTrue(self.sp_.Load('test/test_model.model'))
self.jasp_ = spm.SentencePieceProcessor()
self.assertTrue(self.jasp_.Load('test/test_ja_model.model'))
self.assertTrue(self.sp_.load('test/test_model.model'))
self.assertTrue(self.sp_.LoadFromSerializedProto(
open('test/test_model.model', 'rb').read()))
self.jasp_ = spm.SentencePieceProcessor()
self.assertTrue(self.jasp_.load('test/test_ja_model.model'))
self.assertTrue(self.jasp_.LoadFromSerializedProto(
open('test/test_ja_model.model', 'rb').read()))
def test_load(self):
self.assertEqual(1000, self.sp_.GetPieceSize())