Merge pull request #113 from google/sr
Avoids copy in Python2 Unicode mode.
This commit is contained in:
Коммит
d5293c93e2
|
@ -7,59 +7,63 @@
|
|||
#include <sentencepiece_trainer.h>
|
||||
|
||||
namespace {
|
||||
// Tag values of type PyObject* built by casting small integers to pointers.
// In the code shown they are only stored in / compared against other
// PyObject* values to record what kind of Python string was passed in.
// NOTE(review): kUnicodeInput is defined twice below with different values
// (0x2 and 0x1). This listing appears to show removed and added diff lines
// together without +/- markers — confirm which definition is current.
PyObject* kStringInput = reinterpret_cast<PyObject* >(0x1);
|
||||
PyObject* kUnicodeInput = reinterpret_cast<PyObject* >(0x2);
|
||||
PyObject* kUnicodeInput = reinterpret_cast<PyObject* >(0x1);
|
||||
|
||||
// Drops one reference to `obj`, unless `obj` is null or equal to the
// kUnicodeInput tag. kUnicodeInput is an integer cast to a pointer rather than
// a real Python object, so it is excluded from the Py_XDECREF call.
inline void ReleaseResultObject(PyObject *obj) {
|
||||
if (obj != nullptr && obj != kUnicodeInput)
|
||||
Py_XDECREF(obj);
|
||||
}
|
||||
|
||||
// Adapter that extracts a (char*, length) pair from a Python string-like
// object. The PY_VERSION_HEX switch selects Python 3 handling
// (unicode / bytes) or Python 2 handling (unicode / str); for any other object
// str_ is set to nullptr, which IsAvalable() reports as false.
// NOTE(review): this listing appears to contain both the pre- and post-change
// lines of a diff with no +/- markers (e.g. back-to-back assignments to
// input_type_, and both the utf8_obj_ member and a local utf8_obj in the
// Python 2 unicode branch). Confirm which lines are current before relying
// on any single statement below.
class PyInputString {
|
||||
public:
|
||||
explicit PyInputString(PyObject* obj) {
|
||||
#if PY_VERSION_HEX >= 0x03000000
|
||||
if (PyUnicode_Check(obj)) {
|
||||
// Python3, Unicode
|
||||
// const_cast is needed because str_ is declared as a non-const char*.
str_ = const_cast<char *>(PyUnicode_AsUTF8AndSize(obj, &size_));
|
||||
input_type_ = kUnicodeInput;
|
||||
} else if (PyBytes_Check(obj)) {
|
||||
// Python3, Bytes
|
||||
PyBytes_AsStringAndSize(obj, &str_, &size_);
|
||||
input_type_ = kStringInput;
|
||||
input_type_ = nullptr;
|
||||
}
|
||||
#else
|
||||
if (PyUnicode_Check(obj)) {
|
||||
utf8_obj_ = PyUnicode_AsUTF8String(obj);
|
||||
PyString_AsStringAndSize(utf8_obj_, &str_, &size_);
|
||||
input_type_ = kUnicodeInput;
|
||||
// Python2, Unicode
|
||||
PyObject *utf8_obj = PyUnicode_AsUTF8String(obj);
|
||||
PyString_AsStringAndSize(utf8_obj, &str_, &size_);
|
||||
// Here input_type_ carries the encoded UTF-8 object itself, not a tag value.
input_type_ = utf8_obj;
|
||||
} else if (PyString_Check(obj)) {
|
||||
// Python2, Bytes,
|
||||
PyString_AsStringAndSize(obj, &str_, &size_);
|
||||
input_type_ = kStringInput;
|
||||
input_type_ = nullptr;
|
||||
}
|
||||
#endif
|
||||
// Neither unicode nor bytes/str: mark the input as unavailable.
else {
|
||||
str_ = nullptr;
|
||||
}
|
||||
}
|
||||
// Releases utf8_obj_ if it was set; it is nullptr by default.
virtual ~PyInputString() {
|
||||
Py_XDECREF(utf8_obj_);
|
||||
}
|
||||
const char* data() const { return str_; }
|
||||
Py_ssize_t size() const { return size_; }
|
||||
// True when a character buffer was obtained. The name is misspelled
// ("Avalable"), but other code in this file calls it under this name.
bool IsAvalable() const { return str_ != nullptr; }
|
||||
// True when utf8_obj_ holds an encoded copy of the input.
bool IsCopy() const { return utf8_obj_ != nullptr; }
|
||||
// Tag (or object) recorded by the constructor to describe the input type.
PyObject *input_type() const { return input_type_; }
|
||||
|
||||
private:
|
||||
PyObject* utf8_obj_ = nullptr;
|
||||
PyObject* input_type_ = nullptr;
|
||||
char* str_ = nullptr;
|
||||
Py_ssize_t size_ = 0;
|
||||
};
|
||||
|
||||
// Converts `output` into a Python object whose type is chosen from
// `resultobj`. The typemaps in this file assign
// `resultobj = ustring.input_type()`, so `resultobj` acts here as a tag
// describing the input's string type.
// NOTE(review): the signature and each return statement appear twice below
// (comparisons against the sentinel constants vs. comparisons against
// nullptr). This looks like a diff view showing old and new lines together —
// confirm which version is current.
// NOTE(review): in the Python 3 branch, `resultobj != nullptr ? PyBytes… :
// PyUnicode…` yields bytes for any non-null tag, while PyInputString's
// Python 3 unicode branch sets kUnicodeInput (non-null) — verify that unicode
// input is not being answered with bytes.
PyObject* MakePyOutputString(const std::string& output, PyObject *resultobj) {
|
||||
PyObject* MakePyOutputString(const std::string& output,
|
||||
PyObject *resultobj) {
|
||||
#if PY_VERSION_HEX >= 0x03000000
|
||||
return resultobj == kStringInput ?
|
||||
return resultobj != nullptr ?
|
||||
PyBytes_FromStringAndSize(output.data(), output.size()) :
|
||||
PyUnicode_FromStringAndSize(output.data(), output.size());
|
||||
#else
|
||||
return resultobj == kUnicodeInput ?
|
||||
PyUnicode_FromStringAndSize(output.data(), output.size()) :
|
||||
PyString_FromStringAndSize(output.data(), output.size());
|
||||
return resultobj == nullptr ?
|
||||
PyString_FromStringAndSize(output.data(), output.size()) :
|
||||
PyUnicode_FromStringAndSize(output.data(), output.size());
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -80,7 +84,10 @@ int ToSwigError(sentencepiece::util::error::Code code) {
|
|||
%}
|
||||
|
||||
%exception {
|
||||
try { $action }
|
||||
try {
|
||||
$action
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
|
@ -88,6 +95,7 @@ int ToSwigError(sentencepiece::util::error::Code code) {
|
|||
|
||||
// SWIG %ignore directives: the declarations listed here are left out of the
// generated wrapper.
// NOTE(review): the min_string_view entry presumably corresponds to the
// min_string_view proxy class disappearing from the generated Python module
// shown later in this listing — confirm.
%ignore sentencepiece::util::Status;
|
||||
%ignore sentencepiece::util::error::Code;
|
||||
%ignore sentencepiece::util::min_string_view;
|
||||
%ignore sentencepiece::SentencePieceText;
|
||||
%ignore sentencepiece::NormalizerSpec;
|
||||
%ignore sentencepiece::TrainerSpec;
|
||||
|
@ -277,7 +285,7 @@ int ToSwigError(sentencepiece::util::error::Code code) {
|
|||
const PyInputString ustring($input);
|
||||
if (!ustring.IsAvalable()) {
|
||||
PyErr_SetString(PyExc_TypeError, "not a string");
|
||||
return nullptr;
|
||||
SWIG_fail;
|
||||
}
|
||||
resultobj = ustring.input_type();
|
||||
$1 = new std::string(ustring.data(), ustring.size());
|
||||
|
@ -289,14 +297,10 @@ int ToSwigError(sentencepiece::util::error::Code code) {
|
|||
const PyInputString ustring($input);
|
||||
if (!ustring.IsAvalable()) {
|
||||
PyErr_SetString(PyExc_TypeError, "not a string");
|
||||
return nullptr;
|
||||
SWIG_fail;
|
||||
}
|
||||
resultobj = ustring.input_type();
|
||||
if (ustring.IsCopy()) {
|
||||
$1.copy(ustring.data(), ustring.size());
|
||||
} else {
|
||||
$1.assign(ustring.data(), ustring.size());
|
||||
}
|
||||
$1 = sentencepiece::util::min_string_view(ustring.data(), ustring.size());
|
||||
}
|
||||
|
||||
|
||||
|
@ -311,13 +315,13 @@ int ToSwigError(sentencepiece::util::error::Code code) {
|
|||
(*out)[i] = std::string(ustring.data(), ustring.size());
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError, "list must contain strings");
|
||||
return nullptr;
|
||||
SWIG_fail;
|
||||
}
|
||||
resultobj = ustring.input_type();
|
||||
}
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError, "not a list");
|
||||
return nullptr;
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = out;
|
||||
}
|
||||
|
@ -333,12 +337,12 @@ int ToSwigError(sentencepiece::util::error::Code code) {
|
|||
(*out)[i] = static_cast<int>(PyInt_AsLong(o));
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError,"list must contain integers");
|
||||
return nullptr;
|
||||
SWIG_fail;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError,"not a list");
|
||||
return nullptr;
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = out;
|
||||
}
|
||||
|
|
|
@ -98,36 +98,6 @@ except __builtin__.Exception:
|
|||
pass
|
||||
_newclass = 0
|
||||
|
||||
# Proxy class whose methods each forward to the identically named
# min_string_view_* function in the _sentencepiece extension module.
# NOTE(review): appears to be SWIG-generated code (see the _swig_* helpers);
# indentation has been lost in this listing, so it is not runnable as shown.
class min_string_view(_object):
|
||||
__swig_setmethods__ = {}
|
||||
__setattr__ = lambda self, name, value: _swig_setattr(self, min_string_view, name, value)
|
||||
__swig_getmethods__ = {}
|
||||
__getattr__ = lambda self, name: _swig_getattr(self, min_string_view, name)
|
||||
__repr__ = _swig_repr
|
||||
|
||||
def data(self):
|
||||
return _sentencepiece.min_string_view_data(self)
|
||||
|
||||
def size(self):
|
||||
return _sentencepiece.min_string_view_size(self)
|
||||
|
||||
def assign(self, data, len):
|
||||
return _sentencepiece.min_string_view_assign(self, data, len)
|
||||
|
||||
def copy(self, data, len):
|
||||
return _sentencepiece.min_string_view_copy(self, data, len)
|
||||
|
||||
# Creates the underlying object via new_min_string_view and stores the handle
# in self.this, appending to an existing self.this when that works.
def __init__(self, *args):
|
||||
this = _sentencepiece.new_min_string_view(*args)
|
||||
try:
|
||||
self.this.append(this)
|
||||
except __builtin__.Exception:
|
||||
self.this = this
|
||||
__swig_destroy__ = _sentencepiece.delete_min_string_view
|
||||
__del__ = lambda self: None
|
||||
# Register the proxy class with the extension module.
min_string_view_swigregister = _sentencepiece.min_string_view_swigregister
|
||||
min_string_view_swigregister(min_string_view)
|
||||
|
||||
class SentencePieceProcessor(_object):
|
||||
__swig_setmethods__ = {}
|
||||
__setattr__ = lambda self, name, value: _swig_setattr(self, SentencePieceProcessor, name, value)
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -145,40 +145,18 @@ class Status {
|
|||
// the argument of public APIs.
|
||||
// Holds a (pointer, length) pair describing a character range. In the code
// shown the range is borrowed unless copy() is used, in which case rep_ owns
// a std::string and ptr_/length_ point into it.
// NOTE(review): several constructors below appear twice in slightly different
// form (with/without noexcept, multi-line vs. single-line). This looks like a
// diff view showing old and new lines together — confirm which members exist
// in the current version.
class min_string_view {
|
||||
public:
|
||||
min_string_view() noexcept : ptr_(nullptr), length_(0) {}
|
||||
min_string_view() : ptr_(nullptr), length_(0) {}
|
||||
// Views the contents of `str` without copying.
min_string_view(const std::string &str)
|
||||
: ptr_(str.data()), length_(str.size()) {}
|
||||
// Length is taken with std::strlen, so `str` must be NUL-terminated.
min_string_view(const char *str)
|
||||
: ptr_(str), length_(std::strlen(str)) {}
|
||||
min_string_view(const char *data, size_t len)
|
||||
: ptr_(data), length_(len) {}
|
||||
min_string_view(const char *str) : ptr_(str), length_(std::strlen(str)) {}
|
||||
min_string_view(const char *data, size_t len) : ptr_(data), length_(len) {}
|
||||
|
||||
const char *data() const { return ptr_; }
|
||||
size_t size() const { return length_; }
|
||||
|
||||
// Re-points the view at `data` with length `len`; nothing is copied.
void assign(const char *data, size_t len) {
|
||||
ptr_ = data;
|
||||
length_ = len;
|
||||
}
|
||||
|
||||
// Stores a private std::string copy of the range in rep_ and points the
// view at that copy.
void copy(const char *data, size_t len) {
|
||||
rep_.reset(new std::string(data, len));
|
||||
ptr_ = rep_->data();
|
||||
length_ = rep_->size();
|
||||
}
|
||||
|
||||
// Copy construction duplicates the owned buffer when the source has one;
// otherwise only the pointer and length are copied.
min_string_view(const min_string_view &s) {
|
||||
if (s.rep_) {
|
||||
copy(s.rep_->data(), s.rep_->size());
|
||||
} else {
|
||||
ptr_ = s.data();
|
||||
length_ = s.size();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const char *ptr_ = nullptr;
|
||||
size_t length_ = 0;
|
||||
// Owned storage; in the code shown it is only populated through copy().
std::unique_ptr<std::string> rep_;
|
||||
};
|
||||
} // namespace util
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче