Merge pull request #113 from google/sr

Avoids copy in Python2 Unicode mode.
This commit is contained in:
Taku Kudo 2018-06-20 00:25:51 +09:00 коммит произвёл GitHub
Родитель 2d9df1e6dc 5a246f5cef
Коммит d5293c93e2
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 234 добавлений и 955 удалений

Просмотреть файл

@ -7,59 +7,63 @@
#include <sentencepiece_trainer.h>
namespace {
PyObject* kStringInput = reinterpret_cast<PyObject* >(0x1);
PyObject* kUnicodeInput = reinterpret_cast<PyObject* >(0x2);
PyObject* kUnicodeInput = reinterpret_cast<PyObject* >(0x1);
inline void ReleaseResultObject(PyObject *obj) {
if (obj != nullptr && obj != kUnicodeInput)
Py_XDECREF(obj);
}
class PyInputString {
public:
explicit PyInputString(PyObject* obj) {
#if PY_VERSION_HEX >= 0x03000000
if (PyUnicode_Check(obj)) {
// Python3, Unicode
str_ = const_cast<char *>(PyUnicode_AsUTF8AndSize(obj, &size_));
input_type_ = kUnicodeInput;
} else if (PyBytes_Check(obj)) {
// Python3, Bytes
PyBytes_AsStringAndSize(obj, &str_, &size_);
input_type_ = kStringInput;
input_type_ = nullptr;
}
#else
if (PyUnicode_Check(obj)) {
utf8_obj_ = PyUnicode_AsUTF8String(obj);
PyString_AsStringAndSize(utf8_obj_, &str_, &size_);
input_type_ = kUnicodeInput;
// Python2, Unicode
PyObject *utf8_obj = PyUnicode_AsUTF8String(obj);
PyString_AsStringAndSize(utf8_obj, &str_, &size_);
input_type_ = utf8_obj;
} else if (PyString_Check(obj)) {
// Python2, Bytes,
PyString_AsStringAndSize(obj, &str_, &size_);
input_type_ = kStringInput;
input_type_ = nullptr;
}
#endif
else {
str_ = nullptr;
}
}
virtual ~PyInputString() {
Py_XDECREF(utf8_obj_);
}
const char* data() const { return str_; }
Py_ssize_t size() const { return size_; }
bool IsAvalable() const { return str_ != nullptr; }
bool IsCopy() const { return utf8_obj_ != nullptr; }
PyObject *input_type() const { return input_type_; }
private:
PyObject* utf8_obj_ = nullptr;
PyObject* input_type_ = nullptr;
char* str_ = nullptr;
Py_ssize_t size_ = 0;
};
PyObject* MakePyOutputString(const std::string& output, PyObject *resultobj) {
PyObject* MakePyOutputString(const std::string& output,
PyObject *resultobj) {
#if PY_VERSION_HEX >= 0x03000000
return resultobj == kStringInput ?
return resultobj != nullptr ?
PyBytes_FromStringAndSize(output.data(), output.size()) :
PyUnicode_FromStringAndSize(output.data(), output.size());
#else
return resultobj == kUnicodeInput ?
PyUnicode_FromStringAndSize(output.data(), output.size()) :
PyString_FromStringAndSize(output.data(), output.size());
return resultobj == nullptr ?
PyString_FromStringAndSize(output.data(), output.size()) :
PyUnicode_FromStringAndSize(output.data(), output.size());
#endif
}
@ -80,7 +84,10 @@ int ToSwigError(sentencepiece::util::error::Code code) {
%}
%exception {
try { $action }
try {
$action
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
@ -88,6 +95,7 @@ int ToSwigError(sentencepiece::util::error::Code code) {
%ignore sentencepiece::util::Status;
%ignore sentencepiece::util::error::Code;
%ignore sentencepiece::util::min_string_view;
%ignore sentencepiece::SentencePieceText;
%ignore sentencepiece::NormalizerSpec;
%ignore sentencepiece::TrainerSpec;
@ -277,7 +285,7 @@ int ToSwigError(sentencepiece::util::error::Code code) {
const PyInputString ustring($input);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
return nullptr;
SWIG_fail;
}
resultobj = ustring.input_type();
$1 = new std::string(ustring.data(), ustring.size());
@ -289,14 +297,10 @@ int ToSwigError(sentencepiece::util::error::Code code) {
const PyInputString ustring($input);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
return nullptr;
SWIG_fail;
}
resultobj = ustring.input_type();
if (ustring.IsCopy()) {
$1.copy(ustring.data(), ustring.size());
} else {
$1.assign(ustring.data(), ustring.size());
}
$1 = sentencepiece::util::min_string_view(ustring.data(), ustring.size());
}
@ -311,13 +315,13 @@ int ToSwigError(sentencepiece::util::error::Code code) {
(*out)[i] = std::string(ustring.data(), ustring.size());
} else {
PyErr_SetString(PyExc_TypeError, "list must contain strings");
return nullptr;
SWIG_fail;
}
resultobj = ustring.input_type();
}
} else {
PyErr_SetString(PyExc_TypeError, "not a list");
return nullptr;
SWIG_fail;
}
$1 = out;
}
@ -333,12 +337,12 @@ int ToSwigError(sentencepiece::util::error::Code code) {
(*out)[i] = static_cast<int>(PyInt_AsLong(o));
} else {
PyErr_SetString(PyExc_TypeError,"list must contain integers");
return nullptr;
SWIG_fail;
}
}
} else {
PyErr_SetString(PyExc_TypeError,"not a list");
return nullptr;
SWIG_fail;
}
$1 = out;
}

Просмотреть файл

@ -98,36 +98,6 @@ except __builtin__.Exception:
pass
_newclass = 0
class min_string_view(_object):
__swig_setmethods__ = {}
__setattr__ = lambda self, name, value: _swig_setattr(self, min_string_view, name, value)
__swig_getmethods__ = {}
__getattr__ = lambda self, name: _swig_getattr(self, min_string_view, name)
__repr__ = _swig_repr
def data(self):
return _sentencepiece.min_string_view_data(self)
def size(self):
return _sentencepiece.min_string_view_size(self)
def assign(self, data, len):
return _sentencepiece.min_string_view_assign(self, data, len)
def copy(self, data, len):
return _sentencepiece.min_string_view_copy(self, data, len)
def __init__(self, *args):
this = _sentencepiece.new_min_string_view(*args)
try:
self.this.append(this)
except __builtin__.Exception:
self.this = this
__swig_destroy__ = _sentencepiece.delete_min_string_view
__del__ = lambda self: None
min_string_view_swigregister = _sentencepiece.min_string_view_swigregister
min_string_view_swigregister(min_string_view)
class SentencePieceProcessor(_object):
__swig_setmethods__ = {}
__setattr__ = lambda self, name, value: _swig_setattr(self, SentencePieceProcessor, name, value)

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -145,40 +145,18 @@ class Status {
// the argument of public APIs.
class min_string_view {
public:
min_string_view() noexcept : ptr_(nullptr), length_(0) {}
min_string_view() : ptr_(nullptr), length_(0) {}
min_string_view(const std::string &str)
: ptr_(str.data()), length_(str.size()) {}
min_string_view(const char *str)
: ptr_(str), length_(std::strlen(str)) {}
min_string_view(const char *data, size_t len)
: ptr_(data), length_(len) {}
min_string_view(const char *str) : ptr_(str), length_(std::strlen(str)) {}
min_string_view(const char *data, size_t len) : ptr_(data), length_(len) {}
const char *data() const { return ptr_; }
size_t size() const { return length_; }
void assign(const char *data, size_t len) {
ptr_ = data;
length_ = len;
}
void copy(const char *data, size_t len) {
rep_.reset(new std::string(data, len));
ptr_ = rep_->data();
length_ = rep_->size();
}
min_string_view(const min_string_view &s) {
if (s.rep_) {
copy(s.rep_->data(), s.rep_->size());
} else {
ptr_ = s.data();
length_ = s.size();
}
}
private:
const char *ptr_ = nullptr;
size_t length_ = 0;
std::unique_ptr<std::string> rep_;
};
} // namespace util