Merge pull request #105 from google/sr
Splits travis rule for Python wrapper
This commit is contained in:
Коммит
bcdfc037b1
12
.travis.yml
12
.travis.yml
|
@ -16,13 +16,17 @@ matrix:
|
|||
- os: linux
|
||||
env: IMAGE=fedora:latest COMMAND=build_linux_gcc_fedora
|
||||
services: docker
|
||||
- os: linux
|
||||
script:
|
||||
- $TRAVIS_BUILD_DIR/make_py_wheel.sh
|
||||
services: docker
|
||||
- os: linux
|
||||
env: IMAGE=ubuntu:rolling COMMAND=build_linux_clang_ubuntu
|
||||
services: docker
|
||||
- os: linux
|
||||
script:
|
||||
- $TRAVIS_BUILD_DIR/make_py_wheel.sh x86_64
|
||||
services: docker
|
||||
- os: linux
|
||||
script:
|
||||
- $TRAVIS_BUILD_DIR/make_py_wheel.sh i686
|
||||
services: docker
|
||||
- os: osx
|
||||
osx_image: xcode9.3
|
||||
env: IMAGE=native COMMAND=build_osx
|
||||
|
|
|
@ -83,6 +83,8 @@ make_wheel() {
|
|||
|
||||
if [ "$#" -eq 2 ]; then
|
||||
eval "$1" $2
|
||||
elif [ "$#" -eq 1 ]; then
|
||||
run_docker quay.io/pypa/manylinux1_${1} ${1}
|
||||
else
|
||||
run_docker quay.io/pypa/manylinux1_i686 i686
|
||||
run_docker quay.io/pypa/manylinux1_x86_64 x86_64
|
||||
|
|
|
@ -15,7 +15,7 @@ class PyInputString {
|
|||
explicit PyInputString(PyObject* obj) {
|
||||
#if PY_VERSION_HEX >= 0x03000000
|
||||
if (PyUnicode_Check(obj)) {
|
||||
str_ = PyUnicode_AsUTF8AndSize(obj, &size_);
|
||||
str_ = const_cast<char *>(PyUnicode_AsUTF8AndSize(obj, &size_));
|
||||
input_type_ = kUnicodeInput;
|
||||
} else if (PyBytes_Check(obj)) {
|
||||
PyBytes_AsStringAndSize(obj, &str_, &size_);
|
||||
|
@ -125,77 +125,135 @@ int ToSwigError(sentencepiece::util::error::Code code) {
|
|||
const std::string &,
|
||||
google::protobuf::Message *message);
|
||||
|
||||
%extend sentencepiece::SentencePieceTrainer {
|
||||
static util::Status train(const std::string &args) {
|
||||
return sentencepiece::SentencePieceTrainer::Train(args);
|
||||
}
|
||||
}
|
||||
|
||||
%extend sentencepiece::SentencePieceProcessor {
|
||||
util::Status load(const std::string &filename) {
|
||||
return $self->Load(filename);
|
||||
}
|
||||
|
||||
util::Status _set_encode_extra_options(const std::string &extra_option) {
|
||||
return $self->SetEncodeExtraOptions(extra_option);
|
||||
}
|
||||
|
||||
util::Status _set_decode_extra_options(const std::string &extra_option) {
|
||||
return $self->SetDecodeExtraOptions(extra_option);
|
||||
}
|
||||
|
||||
util::Status set_vocabulary(
|
||||
const std::vector<std::string> &valid_vocab) {
|
||||
return $self->SetVocabulary(valid_vocab);
|
||||
}
|
||||
|
||||
util::Status reset_vocabulary() {
|
||||
return $self->ResetVocabulary();
|
||||
}
|
||||
|
||||
util::Status load_vocabulary(const std::string &filename,
|
||||
int threshold) {
|
||||
return $self->LoadVocabulary(filename, threshold);
|
||||
}
|
||||
|
||||
std::vector<std::string> encode(const std::string& input) const {
|
||||
return $self->EncodeAsPieces(input);
|
||||
}
|
||||
|
||||
std::vector<std::string> encode_as_pieces(const std::string& input) const {
|
||||
return $self->EncodeAsPieces(input);
|
||||
}
|
||||
|
||||
std::vector<int> encode_as_ids(const std::string& input) const {
|
||||
return $self->EncodeAsIds(input);
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::string>> nbest_encode(const std::string& input,
|
||||
int nbest_size) const {
|
||||
return $self->NBestEncodeAsPieces(input, nbest_size);
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::string>> nbest_encode_as_pieces(
|
||||
const std::string& input, int nbest_size) const {
|
||||
return $self->NBestEncodeAsPieces(input, nbest_size);
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> nbest_encode_as_ids(const std::string& input,
|
||||
int nbest_size) const {
|
||||
return $self->NBestEncodeAsIds(input, nbest_size);
|
||||
}
|
||||
|
||||
std::vector<std::string> sample_encode(const std::string& input,
|
||||
int nbest_size, float alpha) const {
|
||||
return $self->SampleEncodeAsPieces(input, nbest_size, alpha);
|
||||
}
|
||||
|
||||
std::vector<std::string> sample_encode_as_pieces(const std::string& input,
|
||||
int nbest_size, float alpha) const {
|
||||
return $self->SampleEncodeAsPieces(input, nbest_size, alpha);
|
||||
}
|
||||
|
||||
std::vector<int> sample_encode_as_ids(const std::string& input,
|
||||
int nbest_size, float alpha) const {
|
||||
return $self->SampleEncodeAsIds(input, nbest_size, alpha);
|
||||
}
|
||||
|
||||
std::string decode(const std::vector<std::string>& input) const {
|
||||
return $self->DecodePieces(input);
|
||||
}
|
||||
|
||||
std::string decode_pieces(const std::vector<std::string>& input) const {
|
||||
return $self->DecodePieces(input);
|
||||
}
|
||||
|
||||
std::string decode_ids(const std::vector<int>& input) const {
|
||||
return $self->DecodeIds(input);
|
||||
}
|
||||
|
||||
std::vector<std::string> Encode(const std::string& input) const {
|
||||
std::vector<std::string> output;
|
||||
THROW_IF_ERROR($self->Encode(input, &output));
|
||||
return output;
|
||||
return $self->EncodeAsPieces(input);
|
||||
}
|
||||
|
||||
std::vector<std::string> EncodeAsPieces(const std::string& input) const {
|
||||
std::vector<std::string> output;
|
||||
THROW_IF_ERROR($self->Encode(input, &output));
|
||||
return output;
|
||||
int get_piece_size() const {
|
||||
return $self->GetPieceSize();
|
||||
}
|
||||
|
||||
std::vector<int> EncodeAsIds(const std::string& input) const {
|
||||
std::vector<int> output;
|
||||
THROW_IF_ERROR($self->Encode(input, &output));
|
||||
return output;
|
||||
int piece_to_id(const std::string &piece) const {
|
||||
return $self->PieceToId(piece);
|
||||
}
|
||||
|
||||
std::string id_to_piece(int id) const {
|
||||
return $self->IdToPiece(id);
|
||||
}
|
||||
|
||||
float get_score(int id) const {
|
||||
return $self->GetScore(id);
|
||||
}
|
||||
|
||||
// snake_case alias for IsUnknown(): forwards `id` to the wrapped processor
// and returns its answer unchanged.
// Fix: this previously forwarded to IsUnused(), so is_unknown() reported the
// "unused" flag instead of the "unknown" flag.
bool is_unknown(int id) const {
  return $self->IsUnknown(id);
}
|
||||
|
||||
bool is_control(int id) const {
|
||||
return $self->IsControl(id);
|
||||
}
|
||||
|
||||
bool is_unused(int id) const {
|
||||
return $self->IsUnused(id);
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::string>> NBestEncode(const std::string& input, int nbest_size) const {
|
||||
std::vector<std::vector<std::string>> output;
|
||||
THROW_IF_ERROR($self->NBestEncode(input, nbest_size, &output));
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::string>> NBestEncodeAsPieces(const std::string& input, int nbest_size) const {
|
||||
std::vector<std::vector<std::string>> output;
|
||||
THROW_IF_ERROR($self->NBestEncode(input, nbest_size, &output));
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> NBestEncodeAsIds(const std::string& input, int nbest_size) const {
|
||||
std::vector<std::vector<int>> output;
|
||||
THROW_IF_ERROR($self->NBestEncode(input, nbest_size, &output));
|
||||
return output;
|
||||
return $self->NBestEncodeAsPieces(input, nbest_size);
|
||||
}
|
||||
|
||||
std::vector<std::string> SampleEncode(const std::string& input, int nbest_size, float alpha) const {
|
||||
std::vector<std::string> output;
|
||||
THROW_IF_ERROR($self->SampleEncode(input, nbest_size, alpha, &output));
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<std::string> SampleEncodeAsPieces(const std::string& input, int nbest_size, float alpha) const {
|
||||
std::vector<std::string> output;
|
||||
THROW_IF_ERROR($self->SampleEncode(input, nbest_size, alpha, &output));
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<int> SampleEncodeAsIds(const std::string& input, int nbest_size, float alpha) const {
|
||||
std::vector<int> output;
|
||||
THROW_IF_ERROR($self->SampleEncode(input, nbest_size, alpha, &output));
|
||||
return output;
|
||||
return $self->SampleEncodeAsPieces(input, nbest_size, alpha);
|
||||
}
|
||||
|
||||
std::string Decode(const std::vector<std::string>& input) const {
|
||||
std::string output;
|
||||
THROW_IF_ERROR($self->Decode(input, &output));
|
||||
return output;
|
||||
}
|
||||
|
||||
std::string DecodePieces(const std::vector<std::string>& input) const {
|
||||
std::string output;
|
||||
THROW_IF_ERROR($self->Decode(input, &output));
|
||||
return output;
|
||||
}
|
||||
|
||||
std::string DecodeIds(const std::vector<int>& input) const {
|
||||
std::string output;
|
||||
THROW_IF_ERROR($self->Decode(input, &output));
|
||||
return output;
|
||||
return $self->DecodePieces(input);
|
||||
}
|
||||
|
||||
int __len__() {
|
||||
|
|
|
@ -135,6 +135,30 @@ class SentencePieceProcessor(_object):
|
|||
def LoadVocabulary(self, filename, threshold):
|
||||
return _sentencepiece.SentencePieceProcessor_LoadVocabulary(self, filename, threshold)
|
||||
|
||||
def EncodeAsPieces(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_EncodeAsPieces(self, input)
|
||||
|
||||
def EncodeAsIds(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_EncodeAsIds(self, input)
|
||||
|
||||
def NBestEncodeAsPieces(self, input, nbest_size):
|
||||
return _sentencepiece.SentencePieceProcessor_NBestEncodeAsPieces(self, input, nbest_size)
|
||||
|
||||
def NBestEncodeAsIds(self, input, nbest_size):
|
||||
return _sentencepiece.SentencePieceProcessor_NBestEncodeAsIds(self, input, nbest_size)
|
||||
|
||||
def SampleEncodeAsPieces(self, input, nbest_size, alpha):
|
||||
return _sentencepiece.SentencePieceProcessor_SampleEncodeAsPieces(self, input, nbest_size, alpha)
|
||||
|
||||
def SampleEncodeAsIds(self, input, nbest_size, alpha):
|
||||
return _sentencepiece.SentencePieceProcessor_SampleEncodeAsIds(self, input, nbest_size, alpha)
|
||||
|
||||
def DecodePieces(self, pieces):
|
||||
return _sentencepiece.SentencePieceProcessor_DecodePieces(self, pieces)
|
||||
|
||||
def DecodeIds(self, ids):
|
||||
return _sentencepiece.SentencePieceProcessor_DecodeIds(self, ids)
|
||||
|
||||
def GetPieceSize(self):
|
||||
return _sentencepiece.SentencePieceProcessor_GetPieceSize(self)
|
||||
|
||||
|
@ -156,42 +180,93 @@ class SentencePieceProcessor(_object):
|
|||
def IsUnused(self, id):
|
||||
return _sentencepiece.SentencePieceProcessor_IsUnused(self, id)
|
||||
|
||||
def load(self, filename):
|
||||
return _sentencepiece.SentencePieceProcessor_load(self, filename)
|
||||
|
||||
def _set_encode_extra_options(self, extra_option):
|
||||
return _sentencepiece.SentencePieceProcessor__set_encode_extra_options(self, extra_option)
|
||||
|
||||
def _set_decode_extra_options(self, extra_option):
|
||||
return _sentencepiece.SentencePieceProcessor__set_decode_extra_options(self, extra_option)
|
||||
|
||||
def set_vocabulary(self, valid_vocab):
|
||||
return _sentencepiece.SentencePieceProcessor_set_vocabulary(self, valid_vocab)
|
||||
|
||||
def reset_vocabulary(self):
|
||||
return _sentencepiece.SentencePieceProcessor_reset_vocabulary(self)
|
||||
|
||||
def load_vocabulary(self, filename, threshold):
|
||||
return _sentencepiece.SentencePieceProcessor_load_vocabulary(self, filename, threshold)
|
||||
|
||||
def encode(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_encode(self, input)
|
||||
|
||||
def encode_as_pieces(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_encode_as_pieces(self, input)
|
||||
|
||||
def encode_as_ids(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_encode_as_ids(self, input)
|
||||
|
||||
def nbest_encode(self, input, nbest_size):
|
||||
return _sentencepiece.SentencePieceProcessor_nbest_encode(self, input, nbest_size)
|
||||
|
||||
def nbest_encode_as_pieces(self, input, nbest_size):
|
||||
return _sentencepiece.SentencePieceProcessor_nbest_encode_as_pieces(self, input, nbest_size)
|
||||
|
||||
def nbest_encode_as_ids(self, input, nbest_size):
|
||||
return _sentencepiece.SentencePieceProcessor_nbest_encode_as_ids(self, input, nbest_size)
|
||||
|
||||
def sample_encode(self, input, nbest_size, alpha):
|
||||
return _sentencepiece.SentencePieceProcessor_sample_encode(self, input, nbest_size, alpha)
|
||||
|
||||
def sample_encode_as_pieces(self, input, nbest_size, alpha):
|
||||
return _sentencepiece.SentencePieceProcessor_sample_encode_as_pieces(self, input, nbest_size, alpha)
|
||||
|
||||
def sample_encode_as_ids(self, input, nbest_size, alpha):
|
||||
return _sentencepiece.SentencePieceProcessor_sample_encode_as_ids(self, input, nbest_size, alpha)
|
||||
|
||||
def decode(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_decode(self, input)
|
||||
|
||||
def decode_pieces(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_decode_pieces(self, input)
|
||||
|
||||
def decode_ids(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_decode_ids(self, input)
|
||||
|
||||
def Encode(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_Encode(self, input)
|
||||
|
||||
def EncodeAsPieces(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_EncodeAsPieces(self, input)
|
||||
def get_piece_size(self):
|
||||
return _sentencepiece.SentencePieceProcessor_get_piece_size(self)
|
||||
|
||||
def EncodeAsIds(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_EncodeAsIds(self, input)
|
||||
def piece_to_id(self, piece):
|
||||
return _sentencepiece.SentencePieceProcessor_piece_to_id(self, piece)
|
||||
|
||||
def id_to_piece(self, id):
|
||||
return _sentencepiece.SentencePieceProcessor_id_to_piece(self, id)
|
||||
|
||||
def get_score(self, id):
|
||||
return _sentencepiece.SentencePieceProcessor_get_score(self, id)
|
||||
|
||||
def is_unknown(self, id):
|
||||
return _sentencepiece.SentencePieceProcessor_is_unknown(self, id)
|
||||
|
||||
def is_control(self, id):
|
||||
return _sentencepiece.SentencePieceProcessor_is_control(self, id)
|
||||
|
||||
def is_unused(self, id):
|
||||
return _sentencepiece.SentencePieceProcessor_is_unused(self, id)
|
||||
|
||||
def NBestEncode(self, input, nbest_size):
|
||||
return _sentencepiece.SentencePieceProcessor_NBestEncode(self, input, nbest_size)
|
||||
|
||||
def NBestEncodeAsPieces(self, input, nbest_size):
|
||||
return _sentencepiece.SentencePieceProcessor_NBestEncodeAsPieces(self, input, nbest_size)
|
||||
|
||||
def NBestEncodeAsIds(self, input, nbest_size):
|
||||
return _sentencepiece.SentencePieceProcessor_NBestEncodeAsIds(self, input, nbest_size)
|
||||
|
||||
def SampleEncode(self, input, nbest_size, alpha):
|
||||
return _sentencepiece.SentencePieceProcessor_SampleEncode(self, input, nbest_size, alpha)
|
||||
|
||||
def SampleEncodeAsPieces(self, input, nbest_size, alpha):
|
||||
return _sentencepiece.SentencePieceProcessor_SampleEncodeAsPieces(self, input, nbest_size, alpha)
|
||||
|
||||
def SampleEncodeAsIds(self, input, nbest_size, alpha):
|
||||
return _sentencepiece.SentencePieceProcessor_SampleEncodeAsIds(self, input, nbest_size, alpha)
|
||||
|
||||
def Decode(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_Decode(self, input)
|
||||
|
||||
def DecodePieces(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_DecodePieces(self, input)
|
||||
|
||||
def DecodeIds(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_DecodeIds(self, input)
|
||||
|
||||
def __len__(self):
|
||||
return _sentencepiece.SentencePieceProcessor___len__(self)
|
||||
|
||||
|
@ -213,6 +288,10 @@ class SentencePieceTrainer(_object):
|
|||
Train = staticmethod(_sentencepiece.SentencePieceTrainer_Train)
|
||||
else:
|
||||
Train = _sentencepiece.SentencePieceTrainer_Train
|
||||
if _newclass:
|
||||
train = staticmethod(_sentencepiece.SentencePieceTrainer_train)
|
||||
else:
|
||||
train = _sentencepiece.SentencePieceTrainer_train
|
||||
SentencePieceTrainer_swigregister = _sentencepiece.SentencePieceTrainer_swigregister
|
||||
SentencePieceTrainer_swigregister(SentencePieceTrainer)
|
||||
|
||||
|
@ -220,6 +299,10 @@ def SentencePieceTrainer_Train(args):
|
|||
return _sentencepiece.SentencePieceTrainer_Train(args)
|
||||
SentencePieceTrainer_Train = _sentencepiece.SentencePieceTrainer_Train
|
||||
|
||||
def SentencePieceTrainer_train(args):
|
||||
return _sentencepiece.SentencePieceTrainer_train(args)
|
||||
SentencePieceTrainer_train = _sentencepiece.SentencePieceTrainer_train
|
||||
|
||||
# This file is compatible with both classic and new-style classes.
|
||||
|
||||
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -6,99 +6,188 @@ import unittest
|
|||
import sys
|
||||
|
||||
class TestSentencepieceProcessor(unittest.TestCase):
|
||||
"""Test case for SentencePieceProcessor"""
|
||||
"""Test case for SentencePieceProcessor"""
|
||||
|
||||
def setUp(self):
|
||||
self.sp_ = spm.SentencePieceProcessor()
|
||||
self.assertTrue(self.sp_.Load('test/test_model.model'))
|
||||
self.jasp_ = spm.SentencePieceProcessor()
|
||||
self.assertTrue(self.jasp_.Load('test/test_ja_model.model'))
|
||||
def setUp(self):
|
||||
self.sp_ = spm.SentencePieceProcessor()
|
||||
self.assertTrue(self.sp_.Load('test/test_model.model'))
|
||||
self.jasp_ = spm.SentencePieceProcessor()
|
||||
self.assertTrue(self.jasp_.Load('test/test_ja_model.model'))
|
||||
self.assertTrue(self.sp_.load('test/test_model.model'))
|
||||
self.jasp_ = spm.SentencePieceProcessor()
|
||||
self.assertTrue(self.jasp_.load('test/test_ja_model.model'))
|
||||
|
||||
def test_load(self):
|
||||
self.assertEqual(1000, self.sp_.GetPieceSize())
|
||||
self.assertEqual(0, self.sp_.PieceToId('<unk>'))
|
||||
self.assertEqual(1, self.sp_.PieceToId('<s>'))
|
||||
self.assertEqual(2, self.sp_.PieceToId('</s>'))
|
||||
self.assertEqual('<unk>', self.sp_.IdToPiece(0))
|
||||
self.assertEqual('<s>', self.sp_.IdToPiece(1))
|
||||
self.assertEqual('</s>', self.sp_.IdToPiece(2))
|
||||
for i in range(self.sp_.GetPieceSize()):
|
||||
piece = self.sp_.IdToPiece(i)
|
||||
self.assertEqual(i, self.sp_.PieceToId(piece))
|
||||
def test_load(self):
|
||||
self.assertEqual(1000, self.sp_.GetPieceSize())
|
||||
self.assertEqual(0, self.sp_.PieceToId('<unk>'))
|
||||
self.assertEqual(1, self.sp_.PieceToId('<s>'))
|
||||
self.assertEqual(2, self.sp_.PieceToId('</s>'))
|
||||
self.assertEqual('<unk>', self.sp_.IdToPiece(0))
|
||||
self.assertEqual('<s>', self.sp_.IdToPiece(1))
|
||||
self.assertEqual('</s>', self.sp_.IdToPiece(2))
|
||||
for i in range(self.sp_.GetPieceSize()):
|
||||
piece = self.sp_.IdToPiece(i)
|
||||
self.assertEqual(i, self.sp_.PieceToId(piece))
|
||||
|
||||
def test_roundtrip(self):
|
||||
text = 'I saw a girl with a telescope.'
|
||||
ids = self.sp_.EncodeAsIds(text)
|
||||
pieces1 = self.sp_.EncodeAsPieces(text)
|
||||
pieces2 = self.sp_.Encode(text)
|
||||
pieces3 = self.sp_.NBestEncode(text, 10)[0]
|
||||
self.assertEqual(pieces1, pieces2)
|
||||
self.assertEqual(pieces1, pieces3)
|
||||
self.assertEqual(text, self.sp_.Decode(pieces1))
|
||||
self.assertEqual(text, self.sp_.DecodePieces(pieces2))
|
||||
self.assertEqual(text, self.sp_.DecodeIds(ids))
|
||||
for n in range(100):
|
||||
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, 64, 0.5)))
|
||||
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, -1, 0.5)))
|
||||
self.assertEqual(text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, 64, 0.5)))
|
||||
self.assertEqual(text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, -1, 0.5)))
|
||||
def test_roundtrip(self):
|
||||
text = 'I saw a girl with a telescope.'
|
||||
ids = self.sp_.EncodeAsIds(text)
|
||||
pieces1 = self.sp_.EncodeAsPieces(text)
|
||||
pieces2 = self.sp_.Encode(text)
|
||||
pieces3 = self.sp_.NBestEncode(text, 10)[0]
|
||||
self.assertEqual(pieces1, pieces2)
|
||||
self.assertEqual(pieces1, pieces3)
|
||||
self.assertEqual(text, self.sp_.Decode(pieces1))
|
||||
self.assertEqual(text, self.sp_.DecodePieces(pieces2))
|
||||
self.assertEqual(text, self.sp_.DecodeIds(ids))
|
||||
for n in range(100):
|
||||
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, 64, 0.5)))
|
||||
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, -1, 0.5)))
|
||||
self.assertEqual(text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, 64, 0.5)))
|
||||
self.assertEqual(text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, -1, 0.5)))
|
||||
|
||||
def test_ja_load(self):
|
||||
self.assertEqual(8000, self.jasp_.GetPieceSize())
|
||||
self.assertEqual(0, self.jasp_.PieceToId('<unk>'))
|
||||
self.assertEqual(1, self.jasp_.PieceToId('<s>'))
|
||||
self.assertEqual(2, self.jasp_.PieceToId('</s>'))
|
||||
self.assertEqual('<unk>', self.jasp_.IdToPiece(0))
|
||||
self.assertEqual('<s>', self.jasp_.IdToPiece(1))
|
||||
self.assertEqual('</s>', self.jasp_.IdToPiece(2))
|
||||
for i in range(self.jasp_.GetPieceSize()):
|
||||
piece = self.jasp_.IdToPiece(i)
|
||||
self.assertEqual(i, self.jasp_.PieceToId(piece))
|
||||
def test_ja_load(self):
|
||||
self.assertEqual(8000, self.jasp_.GetPieceSize())
|
||||
self.assertEqual(0, self.jasp_.PieceToId('<unk>'))
|
||||
self.assertEqual(1, self.jasp_.PieceToId('<s>'))
|
||||
self.assertEqual(2, self.jasp_.PieceToId('</s>'))
|
||||
self.assertEqual('<unk>', self.jasp_.IdToPiece(0))
|
||||
self.assertEqual('<s>', self.jasp_.IdToPiece(1))
|
||||
self.assertEqual('</s>', self.jasp_.IdToPiece(2))
|
||||
for i in range(self.jasp_.GetPieceSize()):
|
||||
piece = self.jasp_.IdToPiece(i)
|
||||
self.assertEqual(i, self.jasp_.PieceToId(piece))
|
||||
|
||||
def test_ja_roundtrip(self):
|
||||
text = '清水寺は京都にある。'
|
||||
ids = self.jasp_.EncodeAsIds(text)
|
||||
pieces1 = self.jasp_.EncodeAsPieces(text)
|
||||
pieces2 = self.jasp_.Encode(text)
|
||||
pieces3 = self.jasp_.NBestEncode(text, 10)[0]
|
||||
self.assertEqual(pieces1, pieces2)
|
||||
self.assertEqual(text, self.jasp_.Decode(pieces1))
|
||||
self.assertEqual(text, self.jasp_.DecodePieces(pieces2))
|
||||
self.assertEqual(text, self.jasp_.DecodeIds(ids))
|
||||
for n in range(100):
|
||||
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, 64, 0.5)))
|
||||
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, -1, 0.5)))
|
||||
def test_ja_roundtrip(self):
|
||||
text = '清水寺は京都にある。'
|
||||
ids = self.jasp_.EncodeAsIds(text)
|
||||
pieces1 = self.jasp_.EncodeAsPieces(text)
|
||||
pieces2 = self.jasp_.Encode(text)
|
||||
pieces3 = self.jasp_.NBestEncode(text, 10)[0]
|
||||
self.assertEqual(pieces1, pieces2)
|
||||
self.assertEqual(text, self.jasp_.Decode(pieces1))
|
||||
self.assertEqual(text, self.jasp_.DecodePieces(pieces2))
|
||||
self.assertEqual(text, self.jasp_.DecodeIds(ids))
|
||||
for n in range(100):
|
||||
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, 64, 0.5)))
|
||||
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, -1, 0.5)))
|
||||
|
||||
|
||||
def test_unicode_roundtrip(self):
|
||||
text = u'I saw a girl with a telescope.'
|
||||
ids = self.sp_.EncodeAsIds(text)
|
||||
pieces1 = self.sp_.EncodeAsPieces(text)
|
||||
pieces2 = self.sp_.Encode(text)
|
||||
self.assertEqual(pieces1, pieces2)
|
||||
self.assertEqual(text, self.sp_.Decode(pieces1))
|
||||
self.assertEqual(text, self.sp_.DecodePieces(pieces2))
|
||||
# python2 returns `str`.
|
||||
if sys.version_info < (3,0,0):
|
||||
text = text.encode('utf-8')
|
||||
self.assertEqual(text, self.sp_.DecodeIds(ids))
|
||||
def test_unicode_roundtrip(self):
|
||||
text = u'I saw a girl with a telescope.'
|
||||
ids = self.sp_.EncodeAsIds(text)
|
||||
pieces1 = self.sp_.EncodeAsPieces(text)
|
||||
pieces2 = self.sp_.Encode(text)
|
||||
self.assertEqual(pieces1, pieces2)
|
||||
self.assertEqual(text, self.sp_.Decode(pieces1))
|
||||
self.assertEqual(text, self.sp_.DecodePieces(pieces2))
|
||||
# python2 returns `str`.
|
||||
if sys.version_info < (3,0,0):
|
||||
text = text.encode('utf-8')
|
||||
self.assertEqual(text, self.sp_.DecodeIds(ids))
|
||||
|
||||
def test_unicode_ja_roundtrip(self):
|
||||
text = u'清水寺は京都にある。'
|
||||
ids = self.jasp_.EncodeAsIds(text)
|
||||
pieces1 = self.jasp_.EncodeAsPieces(text)
|
||||
pieces2 = self.jasp_.Encode(text)
|
||||
self.assertEqual(pieces1, pieces2)
|
||||
self.assertEqual(text, self.jasp_.Decode(pieces1))
|
||||
self.assertEqual(text, self.jasp_.DecodePieces(pieces2))
|
||||
# python2 returns `str`.
|
||||
if sys.version_info < (3,0,0):
|
||||
text = text.encode('utf-8')
|
||||
self.assertEqual(text, self.jasp_.DecodeIds(ids))
|
||||
def test_unicode_ja_roundtrip(self):
|
||||
text = u'清水寺は京都にある。'
|
||||
ids = self.jasp_.EncodeAsIds(text)
|
||||
pieces1 = self.jasp_.EncodeAsPieces(text)
|
||||
pieces2 = self.jasp_.Encode(text)
|
||||
self.assertEqual(pieces1, pieces2)
|
||||
self.assertEqual(text, self.jasp_.Decode(pieces1))
|
||||
self.assertEqual(text, self.jasp_.DecodePieces(pieces2))
|
||||
# python2 returns `str`.
|
||||
if sys.version_info < (3,0,0):
|
||||
text = text.encode('utf-8')
|
||||
self.assertEqual(text, self.jasp_.DecodeIds(ids))
|
||||
|
||||
def test_train(self):
|
||||
spm.SentencePieceTrainer.Train(
|
||||
"--input=test/botchan.txt --model_prefix=m --vocab_size=1000")
|
||||
def test_train(self):
|
||||
spm.SentencePieceTrainer.Train(
|
||||
"--input=test/botchan.txt --model_prefix=m --vocab_size=1000")
|
||||
|
||||
# snake case API.
|
||||
def test_load_snake(self):
|
||||
self.assertEqual(1000, self.sp_.get_piece_size())
|
||||
self.assertEqual(0, self.sp_.piece_to_id('<unk>'))
|
||||
self.assertEqual(1, self.sp_.piece_to_id('<s>'))
|
||||
self.assertEqual(2, self.sp_.piece_to_id('</s>'))
|
||||
self.assertEqual('<unk>', self.sp_.id_to_piece(0))
|
||||
self.assertEqual('<s>', self.sp_.id_to_piece(1))
|
||||
self.assertEqual('</s>', self.sp_.id_to_piece(2))
|
||||
for i in range(self.sp_.get_piece_size()):
|
||||
piece = self.sp_.id_to_piece(i)
|
||||
self.assertEqual(i, self.sp_.piece_to_id(piece))
|
||||
|
||||
def test_roundtrip_snake(self):
|
||||
text = 'I saw a girl with a telescope.'
|
||||
ids = self.sp_.encode_as_ids(text)
|
||||
pieces1 = self.sp_.encode_as_pieces(text)
|
||||
pieces2 = self.sp_.encode(text)
|
||||
pieces3 = self.sp_.nbest_encode(text, 10)[0]
|
||||
self.assertEqual(pieces1, pieces2)
|
||||
self.assertEqual(pieces1, pieces3)
|
||||
self.assertEqual(text, self.sp_.decode(pieces1))
|
||||
self.assertEqual(text, self.sp_.decode_pieces(pieces2))
|
||||
self.assertEqual(text, self.sp_.decode_ids(ids))
|
||||
for n in range(100):
|
||||
self.assertEqual(text, self.sp_.decode(self.sp_.sample_encode(text, 64, 0.5)))
|
||||
self.assertEqual(text, self.sp_.decode(self.sp_.sample_encode(text, -1, 0.5)))
|
||||
self.assertEqual(text, self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, 64, 0.5)))
|
||||
self.assertEqual(text, self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, -1, 0.5)))
|
||||
|
||||
def test_ja_load_snake(self):
|
||||
self.assertEqual(8000, self.jasp_.get_piece_size())
|
||||
self.assertEqual(0, self.jasp_.piece_to_id('<unk>'))
|
||||
self.assertEqual(1, self.jasp_.piece_to_id('<s>'))
|
||||
self.assertEqual(2, self.jasp_.piece_to_id('</s>'))
|
||||
self.assertEqual('<unk>', self.jasp_.id_to_piece(0))
|
||||
self.assertEqual('<s>', self.jasp_.id_to_piece(1))
|
||||
self.assertEqual('</s>', self.jasp_.id_to_piece(2))
|
||||
for i in range(self.jasp_.get_piece_size()):
|
||||
piece = self.jasp_.id_to_piece(i)
|
||||
self.assertEqual(i, self.jasp_.piece_to_id(piece))
|
||||
|
||||
def test_ja_roundtrip_snake(self):
|
||||
text = '清水寺は京都にある。'
|
||||
ids = self.jasp_.encode_as_ids(text)
|
||||
pieces1 = self.jasp_.encode_as_pieces(text)
|
||||
pieces2 = self.jasp_.encode(text)
|
||||
pieces3 = self.jasp_.nbest_encode(text, 10)[0]
|
||||
self.assertEqual(pieces1, pieces2)
|
||||
self.assertEqual(text, self.jasp_.decode(pieces1))
|
||||
self.assertEqual(text, self.jasp_.decode_pieces(pieces2))
|
||||
self.assertEqual(text, self.jasp_.decode_ids(ids))
|
||||
for n in range(100):
|
||||
self.assertEqual(text, self.sp_.decode(self.sp_.sample_encode(text, 64, 0.5)))
|
||||
self.assertEqual(text, self.sp_.decode(self.sp_.sample_encode(text, -1, 0.5)))
|
||||
|
||||
def test_unicode_roundtrip_snake(self):
|
||||
text = u'I saw a girl with a telescope.'
|
||||
ids = self.sp_.encode_as_ids(text)
|
||||
pieces1 = self.sp_.encode_as_pieces(text)
|
||||
pieces2 = self.sp_.encode(text)
|
||||
self.assertEqual(pieces1, pieces2)
|
||||
self.assertEqual(text, self.sp_.decode(pieces1))
|
||||
self.assertEqual(text, self.sp_.decode_pieces(pieces2))
|
||||
# python2 returns `str`.
|
||||
if sys.version_info < (3,0,0):
|
||||
text = text.encode('utf-8')
|
||||
self.assertEqual(text, self.sp_.decode_ids(ids))
|
||||
|
||||
def test_unicode_ja_roundtrip_snake(self):
|
||||
text = u'清水寺は京都にある。'
|
||||
ids = self.jasp_.encode_as_ids(text)
|
||||
pieces1 = self.jasp_.encode_as_pieces(text)
|
||||
pieces2 = self.jasp_.encode(text)
|
||||
self.assertEqual(pieces1, pieces2)
|
||||
self.assertEqual(text, self.jasp_.decode(pieces1))
|
||||
self.assertEqual(text, self.jasp_.decode_pieces(pieces2))
|
||||
# python2 returns `str`.
|
||||
if sys.version_info < (3,0,0):
|
||||
text = text.encode('utf-8')
|
||||
self.assertEqual(text, self.jasp_.decode_ids(ids))
|
||||
|
||||
def test_train_snake(self):
|
||||
spm.SentencePieceTrainer.train(
|
||||
"--input=test/botchan.txt --model_prefix=m --vocab_size=1000")
|
||||
|
||||
|
||||
def suite():
|
||||
|
@ -106,5 +195,6 @@ def suite():
|
|||
suite.addTests(unittest.makeSuite(TestSentencepieceProcessor))
|
||||
return suite
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
|
|
@ -267,6 +267,71 @@ class SentencePieceProcessor {
|
|||
virtual util::Status Decode(const std::vector<int> &ids,
|
||||
SentencePieceText *spt) const;
|
||||
|
||||
//////////////////////////////////////////////////////////////
|
||||
// Handy methods that return the result directly.
|
||||
// In regular builds these functions ignore internal errors; when compiled
// with SWIG defined they throw the failing util::Status instead.
|
||||
#ifdef SWIG
|
||||
#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
|
||||
OutType output; \
|
||||
const auto _status = FuncName(__VA_ARGS__, &output); \
|
||||
if (!_status.ok()) throw _status; \
|
||||
return output;
|
||||
#else
|
||||
#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
|
||||
OutType output; \
|
||||
FuncName(__VA_ARGS__, &output).IgnoreError(); \
|
||||
return output;
|
||||
#endif
|
||||
|
||||
// Encode
|
||||
virtual std::vector<std::string> EncodeAsPieces(
|
||||
const std::string &input) const {
|
||||
DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<std::string>, input);
|
||||
}
|
||||
|
||||
virtual std::vector<int> EncodeAsIds(const std::string &input) const {
|
||||
DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<int>, input);
|
||||
}
|
||||
|
||||
// NBestEncode
|
||||
virtual std::vector<std::vector<std::string>> NBestEncodeAsPieces(
|
||||
const std::string &input, int nbest_size) const {
|
||||
DEFINE_SPP_DIRECT_FUNC_IMPL(
|
||||
NBestEncode, std::vector<std::vector<std::string>>, input, nbest_size);
|
||||
}
|
||||
|
||||
virtual std::vector<std::vector<int>> NBestEncodeAsIds(
|
||||
const std::string &input, int nbest_size) const {
|
||||
DEFINE_SPP_DIRECT_FUNC_IMPL(NBestEncode, std::vector<std::vector<int>>,
|
||||
input, nbest_size);
|
||||
}
|
||||
|
||||
// SampleEncode
|
||||
virtual std::vector<std::string> SampleEncodeAsPieces(
|
||||
const std::string &input, int nbest_size, float alpha) const {
|
||||
DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<std::string>, input,
|
||||
nbest_size, alpha);
|
||||
}
|
||||
|
||||
virtual std::vector<int> SampleEncodeAsIds(const std::string &input,
|
||||
int nbest_size,
|
||||
float alpha) const {
|
||||
DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<int>, input,
|
||||
nbest_size, alpha);
|
||||
}
|
||||
|
||||
// Decode
|
||||
virtual std::string DecodePieces(
|
||||
const std::vector<std::string> &pieces) const {
|
||||
DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
|
||||
}
|
||||
|
||||
virtual std::string DecodeIds(const std::vector<int> &ids) const {
|
||||
DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, ids);
|
||||
}
|
||||
|
||||
#undef DEFINE_SPP_DIRECT_FUNC_IMPL
|
||||
|
||||
//////////////////////////////////////////////////////////////
|
||||
// Vocabulary management methods.
|
||||
//
|
||||
|
|
Загрузка…
Ссылка в новой задаче