From 293b1863b9ce64efa42289287d50bfac703b298e Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Sun, 24 Jun 2018 22:33:25 +0900 Subject: [PATCH] Add LoadFromSerialiedProto to Python wrapper --- python/sentencepiece.i | 4 ++ python/sentencepiece.py | 6 ++ python/sentencepiece_wrap.cxx | 97 +++++++++++++++++++++++++++++++ python/test/sentencepiece_test.py | 7 ++- 4 files changed, 112 insertions(+), 2 deletions(-) diff --git a/python/sentencepiece.i b/python/sentencepiece.i index 23eaada..5ad0809 100644 --- a/python/sentencepiece.i +++ b/python/sentencepiece.i @@ -129,6 +129,10 @@ int ToSwigError(sentencepiece::util::error::Code code) { return $self->Load(filename); } + util::Status load_from_serialized_proto(sentencepiece::util::min_string_view filename) { + return $self->LoadFromSerializedProto(filename); + } + util::Status set_encode_extra_options( sentencepiece::util::min_string_view extra_option) { return $self->SetEncodeExtraOptions(extra_option); diff --git a/python/sentencepiece.py b/python/sentencepiece.py index 5b15e57..d2a6580 100644 --- a/python/sentencepiece.py +++ b/python/sentencepiece.py @@ -120,6 +120,9 @@ class SentencePieceProcessor(_object): def LoadOrDie(self, filename): return _sentencepiece.SentencePieceProcessor_LoadOrDie(self, filename) + def LoadFromSerializedProto(self, serialized): + return _sentencepiece.SentencePieceProcessor_LoadFromSerializedProto(self, serialized) + def SetEncodeExtraOptions(self, extra_option): return _sentencepiece.SentencePieceProcessor_SetEncodeExtraOptions(self, extra_option) @@ -183,6 +186,9 @@ class SentencePieceProcessor(_object): def load(self, filename): return _sentencepiece.SentencePieceProcessor_load(self, filename) + def load_from_serialized_proto(self, filename): + return _sentencepiece.SentencePieceProcessor_load_from_serialized_proto(self, filename) + def set_encode_extra_options(self, extra_option): return _sentencepiece.SentencePieceProcessor_set_encode_extra_options(self, extra_option) diff --git a/python/sentencepiece_wrap.cxx b/python/sentencepiece_wrap.cxx index 4997356..38b8572 100644 --- a/python/sentencepiece_wrap.cxx +++ b/python/sentencepiece_wrap.cxx @@ -3545,6 +3545,9 @@ SWIGINTERNINLINE PyObject* SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor_load(sentencepiece::SentencePieceProcessor *self,sentencepiece::util::min_string_view filename){ return self->Load(filename); } +SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor_load_from_serialized_proto(sentencepiece::SentencePieceProcessor *self,sentencepiece::util::min_string_view filename){ + return self->LoadFromSerializedProto(filename); + } SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor_set_encode_extra_options(sentencepiece::SentencePieceProcessor *self,sentencepiece::util::min_string_view extra_option){ return self->SetEncodeExtraOptions(extra_option); } @@ -3753,6 +3756,52 @@ fail: } +SWIGINTERN PyObject *_wrap_SentencePieceProcessor_LoadFromSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + sentencepiece::util::min_string_view arg2 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject * obj0 = 0 ; + PyObject * obj1 = 0 ; + sentencepiece::util::Status result; + + if (!PyArg_ParseTuple(args,(char *)"OO:SentencePieceProcessor_LoadFromSerializedProto",&obj0,&obj1)) SWIG_fail; + res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_LoadFromSerializedProto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + const PyInputString ustring(obj1); + if (!ustring.IsAvalable()) { + PyErr_SetString(PyExc_TypeError, "not a string"); + SWIG_fail; + } + resultobj = ustring.input_type(); + arg2 = sentencepiece::util::min_string_view(ustring.data(), ustring.size()); + } + { + try { + result = (arg1)->LoadFromSerializedProto(arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + if (!(&result)->ok()) { + SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str()); + } + resultobj = SWIG_From_bool((&result)->ok()); + } + return resultobj; +fail: + return NULL; +} + + SWIGINTERN PyObject *_wrap_SentencePieceProcessor_SetEncodeExtraOptions(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; @@ -4778,6 +4827,52 @@ fail: } +SWIGINTERN PyObject *_wrap_SentencePieceProcessor_load_from_serialized_proto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + sentencepiece::util::min_string_view arg2 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject * obj0 = 0 ; + PyObject * obj1 = 0 ; + sentencepiece::util::Status result; + + if (!PyArg_ParseTuple(args,(char *)"OO:SentencePieceProcessor_load_from_serialized_proto",&obj0,&obj1)) SWIG_fail; + res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_load_from_serialized_proto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + const PyInputString ustring(obj1); + if (!ustring.IsAvalable()) { + PyErr_SetString(PyExc_TypeError, "not a string"); + SWIG_fail; + } + resultobj = ustring.input_type(); + arg2 = sentencepiece::util::min_string_view(ustring.data(), ustring.size()); + } + { + try { + result = sentencepiece_SentencePieceProcessor_load_from_serialized_proto(arg1,arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + if (!(&result)->ok()) { + SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str()); + } + resultobj = SWIG_From_bool((&result)->ok()); + } + return resultobj; +fail: + return NULL; +} + + SWIGINTERN PyObject *_wrap_SentencePieceProcessor_set_encode_extra_options(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; @@ -5922,6 +6017,7 @@ static PyMethodDef SwigMethods[] = { { (char *)"delete_SentencePieceProcessor", _wrap_delete_SentencePieceProcessor, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_Load", _wrap_SentencePieceProcessor_Load, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_LoadOrDie", _wrap_SentencePieceProcessor_LoadOrDie, METH_VARARGS, NULL}, + { (char *)"SentencePieceProcessor_LoadFromSerializedProto", _wrap_SentencePieceProcessor_LoadFromSerializedProto, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_SetEncodeExtraOptions", _wrap_SentencePieceProcessor_SetEncodeExtraOptions, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_SetDecodeExtraOptions", _wrap_SentencePieceProcessor_SetDecodeExtraOptions, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_SetVocabulary", _wrap_SentencePieceProcessor_SetVocabulary, METH_VARARGS, NULL}, @@ -5943,6 +6039,7 @@ static PyMethodDef SwigMethods[] = { { (char *)"SentencePieceProcessor_IsControl", _wrap_SentencePieceProcessor_IsControl, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_IsUnused", _wrap_SentencePieceProcessor_IsUnused, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_load", _wrap_SentencePieceProcessor_load, METH_VARARGS, NULL}, + { (char *)"SentencePieceProcessor_load_from_serialized_proto", _wrap_SentencePieceProcessor_load_from_serialized_proto, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_set_encode_extra_options", _wrap_SentencePieceProcessor_set_encode_extra_options, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_set_decode_extra_options", _wrap_SentencePieceProcessor_set_decode_extra_options, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_set_vocabulary", _wrap_SentencePieceProcessor_set_vocabulary, METH_VARARGS, NULL}, diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py index 91dcfb2..6df38c8 100755 --- a/python/test/sentencepiece_test.py +++ b/python/test/sentencepiece_test.py @@ -13,9 +13,12 @@ class TestSentencepieceProcessor(unittest.TestCase): self.assertTrue(self.sp_.Load('test/test_model.model')) self.jasp_ = spm.SentencePieceProcessor() self.assertTrue(self.jasp_.Load('test/test_ja_model.model')) - self.assertTrue(self.sp_.load('test/test_model.model')) + self.assertTrue(self.sp_.LoadFromSerializedProto( + open('test/test_model.model', 'rb').read())) self.jasp_ = spm.SentencePieceProcessor() - self.assertTrue(self.jasp_.load('test/test_ja_model.model')) + self.assertTrue(self.jasp_.LoadFromSerializedProto( + open('test/test_ja_model.model', 'rb').read())) + def test_load(self): self.assertEqual(1000, self.sp_.GetPieceSize())