Merge pull request #105 from google/sr

Splits travis rule for Python wrapper
This commit is contained in:
Taku Kudo 2018-06-16 23:18:53 +09:00 коммит произвёл GitHub
Родитель 68aba64804 0a30c9bac1
Коммит bcdfc037b1
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 2400 добавлений и 682 удалений

Просмотреть файл

@ -16,13 +16,17 @@ matrix:
- os: linux
env: IMAGE=fedora:latest COMMAND=build_linux_gcc_fedora
services: docker
- os: linux
script:
- $TRAVIS_BUILD_DIR/make_py_wheel.sh
services: docker
- os: linux
env: IMAGE=ubuntu:rolling COMMAND=build_linux_clang_ubuntu
services: docker
- os: linux
script:
- $TRAVIS_BUILD_DIR/make_py_wheel.sh x86_64
services: docker
- os: linux
script:
- $TRAVIS_BUILD_DIR/make_py_wheel.sh i686
services: docker
- os: osx
osx_image: xcode9.3
env: IMAGE=native COMMAND=build_osx

Просмотреть файл

@ -83,6 +83,8 @@ make_wheel() {
if [ "$#" -eq 2 ]; then
eval "$1" $2
elif [ "$#" -eq 1 ]; then
run_docker quay.io/pypa/manylinux1_${1} ${1}
else
run_docker quay.io/pypa/manylinux1_i686 i686
run_docker quay.io/pypa/manylinux1_x86_64 x86_64

Просмотреть файл

@ -15,7 +15,7 @@ class PyInputString {
explicit PyInputString(PyObject* obj) {
#if PY_VERSION_HEX >= 0x03000000
if (PyUnicode_Check(obj)) {
str_ = PyUnicode_AsUTF8AndSize(obj, &size_);
str_ = const_cast<char *>(PyUnicode_AsUTF8AndSize(obj, &size_));
input_type_ = kUnicodeInput;
} else if (PyBytes_Check(obj)) {
PyBytes_AsStringAndSize(obj, &str_, &size_);
@ -125,77 +125,135 @@ int ToSwigError(sentencepiece::util::error::Code code) {
const std::string &,
google::protobuf::Message *message);
%extend sentencepiece::SentencePieceTrainer {
static util::Status train(const std::string &args) {
return sentencepiece::SentencePieceTrainer::Train(args);
}
}
%extend sentencepiece::SentencePieceProcessor {
util::Status load(const std::string &filename) {
return $self->Load(filename);
}
util::Status _set_encode_extra_options(const std::string &extra_option) {
return $self->SetEncodeExtraOptions(extra_option);
}
util::Status _set_decode_extra_options(const std::string &extra_option) {
return $self->SetDecodeExtraOptions(extra_option);
}
util::Status set_vocabulary(
const std::vector<std::string> &valid_vocab) {
return $self->SetVocabulary(valid_vocab);
}
util::Status reset_vocabulary() {
return $self->ResetVocabulary();
}
util::Status load_vocabulary(const std::string &filename,
int threshold) {
return $self->LoadVocabulary(filename, threshold);
}
std::vector<std::string> encode(const std::string& input) const {
return $self->EncodeAsPieces(input);
}
std::vector<std::string> encode_as_pieces(const std::string& input) const {
return $self->EncodeAsPieces(input);
}
std::vector<int> encode_as_ids(const std::string& input) const {
return $self->EncodeAsIds(input);
}
std::vector<std::vector<std::string>> nbest_encode(const std::string& input,
int nbest_size) const {
return $self->NBestEncodeAsPieces(input, nbest_size);
}
std::vector<std::vector<std::string>> nbest_encode_as_pieces(
const std::string& input, int nbest_size) const {
return $self->NBestEncodeAsPieces(input, nbest_size);
}
std::vector<std::vector<int>> nbest_encode_as_ids(const std::string& input,
int nbest_size) const {
return $self->NBestEncodeAsIds(input, nbest_size);
}
std::vector<std::string> sample_encode(const std::string& input,
int nbest_size, float alpha) const {
return $self->SampleEncodeAsPieces(input, nbest_size, alpha);
}
std::vector<std::string> sample_encode_as_pieces(const std::string& input,
int nbest_size, float alpha) const {
return $self->SampleEncodeAsPieces(input, nbest_size, alpha);
}
std::vector<int> sample_encode_as_ids(const std::string& input,
int nbest_size, float alpha) const {
return $self->SampleEncodeAsIds(input, nbest_size, alpha);
}
std::string decode(const std::vector<std::string>& input) const {
return $self->DecodePieces(input);
}
std::string decode_pieces(const std::vector<std::string>& input) const {
return $self->DecodePieces(input);
}
std::string decode_ids(const std::vector<int>& input) const {
return $self->DecodeIds(input);
}
std::vector<std::string> Encode(const std::string& input) const {
std::vector<std::string> output;
THROW_IF_ERROR($self->Encode(input, &output));
return output;
return $self->EncodeAsPieces(input);
}
std::vector<std::string> EncodeAsPieces(const std::string& input) const {
std::vector<std::string> output;
THROW_IF_ERROR($self->Encode(input, &output));
return output;
int get_piece_size() const {
return $self->GetPieceSize();
}
std::vector<int> EncodeAsIds(const std::string& input) const {
std::vector<int> output;
THROW_IF_ERROR($self->Encode(input, &output));
return output;
int piece_to_id(const std::string &piece) const {
return $self->PieceToId(piece);
}
std::string id_to_piece(int id) const {
return $self->IdToPiece(id);
}
float get_score(int id) const {
return $self->GetScore(id);
}
bool is_unknown(int id) const {
return $self->IsUnused(id);
}
bool is_control(int id) const {
return $self->IsControl(id);
}
bool is_unused(int id) const {
return $self->IsUnused(id);
}
std::vector<std::vector<std::string>> NBestEncode(const std::string& input, int nbest_size) const {
std::vector<std::vector<std::string>> output;
THROW_IF_ERROR($self->NBestEncode(input, nbest_size, &output));
return output;
}
std::vector<std::vector<std::string>> NBestEncodeAsPieces(const std::string& input, int nbest_size) const {
std::vector<std::vector<std::string>> output;
THROW_IF_ERROR($self->NBestEncode(input, nbest_size, &output));
return output;
}
std::vector<std::vector<int>> NBestEncodeAsIds(const std::string& input, int nbest_size) const {
std::vector<std::vector<int>> output;
THROW_IF_ERROR($self->NBestEncode(input, nbest_size, &output));
return output;
return $self->NBestEncodeAsPieces(input, nbest_size);
}
std::vector<std::string> SampleEncode(const std::string& input, int nbest_size, float alpha) const {
std::vector<std::string> output;
THROW_IF_ERROR($self->SampleEncode(input, nbest_size, alpha, &output));
return output;
}
std::vector<std::string> SampleEncodeAsPieces(const std::string& input, int nbest_size, float alpha) const {
std::vector<std::string> output;
THROW_IF_ERROR($self->SampleEncode(input, nbest_size, alpha, &output));
return output;
}
std::vector<int> SampleEncodeAsIds(const std::string& input, int nbest_size, float alpha) const {
std::vector<int> output;
THROW_IF_ERROR($self->SampleEncode(input, nbest_size, alpha, &output));
return output;
return $self->SampleEncodeAsPieces(input, nbest_size, alpha);
}
std::string Decode(const std::vector<std::string>& input) const {
std::string output;
THROW_IF_ERROR($self->Decode(input, &output));
return output;
}
std::string DecodePieces(const std::vector<std::string>& input) const {
std::string output;
THROW_IF_ERROR($self->Decode(input, &output));
return output;
}
std::string DecodeIds(const std::vector<int>& input) const {
std::string output;
THROW_IF_ERROR($self->Decode(input, &output));
return output;
return $self->DecodePieces(input);
}
int __len__() {

Просмотреть файл

@ -135,6 +135,30 @@ class SentencePieceProcessor(_object):
def LoadVocabulary(self, filename, threshold):
return _sentencepiece.SentencePieceProcessor_LoadVocabulary(self, filename, threshold)
def EncodeAsPieces(self, input):
return _sentencepiece.SentencePieceProcessor_EncodeAsPieces(self, input)
def EncodeAsIds(self, input):
return _sentencepiece.SentencePieceProcessor_EncodeAsIds(self, input)
def NBestEncodeAsPieces(self, input, nbest_size):
return _sentencepiece.SentencePieceProcessor_NBestEncodeAsPieces(self, input, nbest_size)
def NBestEncodeAsIds(self, input, nbest_size):
return _sentencepiece.SentencePieceProcessor_NBestEncodeAsIds(self, input, nbest_size)
def SampleEncodeAsPieces(self, input, nbest_size, alpha):
return _sentencepiece.SentencePieceProcessor_SampleEncodeAsPieces(self, input, nbest_size, alpha)
def SampleEncodeAsIds(self, input, nbest_size, alpha):
return _sentencepiece.SentencePieceProcessor_SampleEncodeAsIds(self, input, nbest_size, alpha)
def DecodePieces(self, pieces):
return _sentencepiece.SentencePieceProcessor_DecodePieces(self, pieces)
def DecodeIds(self, ids):
return _sentencepiece.SentencePieceProcessor_DecodeIds(self, ids)
def GetPieceSize(self):
return _sentencepiece.SentencePieceProcessor_GetPieceSize(self)
@ -156,42 +180,93 @@ class SentencePieceProcessor(_object):
def IsUnused(self, id):
return _sentencepiece.SentencePieceProcessor_IsUnused(self, id)
def load(self, filename):
return _sentencepiece.SentencePieceProcessor_load(self, filename)
def _set_encode_extra_options(self, extra_option):
return _sentencepiece.SentencePieceProcessor__set_encode_extra_options(self, extra_option)
def _set_decode_extra_options(self, extra_option):
return _sentencepiece.SentencePieceProcessor__set_decode_extra_options(self, extra_option)
def set_vocabulary(self, valid_vocab):
return _sentencepiece.SentencePieceProcessor_set_vocabulary(self, valid_vocab)
def reset_vocabulary(self):
return _sentencepiece.SentencePieceProcessor_reset_vocabulary(self)
def load_vocabulary(self, filename, threshold):
return _sentencepiece.SentencePieceProcessor_load_vocabulary(self, filename, threshold)
def encode(self, input):
return _sentencepiece.SentencePieceProcessor_encode(self, input)
def encode_as_pieces(self, input):
return _sentencepiece.SentencePieceProcessor_encode_as_pieces(self, input)
def encode_as_ids(self, input):
return _sentencepiece.SentencePieceProcessor_encode_as_ids(self, input)
def nbest_encode(self, input, nbest_size):
return _sentencepiece.SentencePieceProcessor_nbest_encode(self, input, nbest_size)
def nbest_encode_as_pieces(self, input, nbest_size):
return _sentencepiece.SentencePieceProcessor_nbest_encode_as_pieces(self, input, nbest_size)
def nbest_encode_as_ids(self, input, nbest_size):
return _sentencepiece.SentencePieceProcessor_nbest_encode_as_ids(self, input, nbest_size)
def sample_encode(self, input, nbest_size, alpha):
return _sentencepiece.SentencePieceProcessor_sample_encode(self, input, nbest_size, alpha)
def sample_encode_as_pieces(self, input, nbest_size, alpha):
return _sentencepiece.SentencePieceProcessor_sample_encode_as_pieces(self, input, nbest_size, alpha)
def sample_encode_as_ids(self, input, nbest_size, alpha):
return _sentencepiece.SentencePieceProcessor_sample_encode_as_ids(self, input, nbest_size, alpha)
def decode(self, input):
return _sentencepiece.SentencePieceProcessor_decode(self, input)
def decode_pieces(self, input):
return _sentencepiece.SentencePieceProcessor_decode_pieces(self, input)
def decode_ids(self, input):
return _sentencepiece.SentencePieceProcessor_decode_ids(self, input)
def Encode(self, input):
return _sentencepiece.SentencePieceProcessor_Encode(self, input)
def EncodeAsPieces(self, input):
return _sentencepiece.SentencePieceProcessor_EncodeAsPieces(self, input)
def get_piece_size(self):
return _sentencepiece.SentencePieceProcessor_get_piece_size(self)
def EncodeAsIds(self, input):
return _sentencepiece.SentencePieceProcessor_EncodeAsIds(self, input)
def piece_to_id(self, piece):
return _sentencepiece.SentencePieceProcessor_piece_to_id(self, piece)
def id_to_piece(self, id):
return _sentencepiece.SentencePieceProcessor_id_to_piece(self, id)
def get_score(self, id):
return _sentencepiece.SentencePieceProcessor_get_score(self, id)
def is_unknown(self, id):
return _sentencepiece.SentencePieceProcessor_is_unknown(self, id)
def is_control(self, id):
return _sentencepiece.SentencePieceProcessor_is_control(self, id)
def is_unused(self, id):
return _sentencepiece.SentencePieceProcessor_is_unused(self, id)
def NBestEncode(self, input, nbest_size):
return _sentencepiece.SentencePieceProcessor_NBestEncode(self, input, nbest_size)
def NBestEncodeAsPieces(self, input, nbest_size):
return _sentencepiece.SentencePieceProcessor_NBestEncodeAsPieces(self, input, nbest_size)
def NBestEncodeAsIds(self, input, nbest_size):
return _sentencepiece.SentencePieceProcessor_NBestEncodeAsIds(self, input, nbest_size)
def SampleEncode(self, input, nbest_size, alpha):
return _sentencepiece.SentencePieceProcessor_SampleEncode(self, input, nbest_size, alpha)
def SampleEncodeAsPieces(self, input, nbest_size, alpha):
return _sentencepiece.SentencePieceProcessor_SampleEncodeAsPieces(self, input, nbest_size, alpha)
def SampleEncodeAsIds(self, input, nbest_size, alpha):
return _sentencepiece.SentencePieceProcessor_SampleEncodeAsIds(self, input, nbest_size, alpha)
def Decode(self, input):
return _sentencepiece.SentencePieceProcessor_Decode(self, input)
def DecodePieces(self, input):
return _sentencepiece.SentencePieceProcessor_DecodePieces(self, input)
def DecodeIds(self, input):
return _sentencepiece.SentencePieceProcessor_DecodeIds(self, input)
def __len__(self):
return _sentencepiece.SentencePieceProcessor___len__(self)
@ -213,6 +288,10 @@ class SentencePieceTrainer(_object):
Train = staticmethod(_sentencepiece.SentencePieceTrainer_Train)
else:
Train = _sentencepiece.SentencePieceTrainer_Train
if _newclass:
train = staticmethod(_sentencepiece.SentencePieceTrainer_train)
else:
train = _sentencepiece.SentencePieceTrainer_train
SentencePieceTrainer_swigregister = _sentencepiece.SentencePieceTrainer_swigregister
SentencePieceTrainer_swigregister(SentencePieceTrainer)
@ -220,6 +299,10 @@ def SentencePieceTrainer_Train(args):
return _sentencepiece.SentencePieceTrainer_Train(args)
SentencePieceTrainer_Train = _sentencepiece.SentencePieceTrainer_Train
def SentencePieceTrainer_train(args):
return _sentencepiece.SentencePieceTrainer_train(args)
SentencePieceTrainer_train = _sentencepiece.SentencePieceTrainer_train
# This file is compatible with both classic and new-style classes.

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -6,99 +6,188 @@ import unittest
import sys
class TestSentencepieceProcessor(unittest.TestCase):
"""Test case for SentencePieceProcessor"""
"""Test case for SentencePieceProcessor"""
def setUp(self):
self.sp_ = spm.SentencePieceProcessor()
self.assertTrue(self.sp_.Load('test/test_model.model'))
self.jasp_ = spm.SentencePieceProcessor()
self.assertTrue(self.jasp_.Load('test/test_ja_model.model'))
def setUp(self):
self.sp_ = spm.SentencePieceProcessor()
self.assertTrue(self.sp_.Load('test/test_model.model'))
self.jasp_ = spm.SentencePieceProcessor()
self.assertTrue(self.jasp_.Load('test/test_ja_model.model'))
self.assertTrue(self.sp_.load('test/test_model.model'))
self.jasp_ = spm.SentencePieceProcessor()
self.assertTrue(self.jasp_.load('test/test_ja_model.model'))
def test_load(self):
self.assertEqual(1000, self.sp_.GetPieceSize())
self.assertEqual(0, self.sp_.PieceToId('<unk>'))
self.assertEqual(1, self.sp_.PieceToId('<s>'))
self.assertEqual(2, self.sp_.PieceToId('</s>'))
self.assertEqual('<unk>', self.sp_.IdToPiece(0))
self.assertEqual('<s>', self.sp_.IdToPiece(1))
self.assertEqual('</s>', self.sp_.IdToPiece(2))
for i in range(self.sp_.GetPieceSize()):
piece = self.sp_.IdToPiece(i)
self.assertEqual(i, self.sp_.PieceToId(piece))
def test_load(self):
self.assertEqual(1000, self.sp_.GetPieceSize())
self.assertEqual(0, self.sp_.PieceToId('<unk>'))
self.assertEqual(1, self.sp_.PieceToId('<s>'))
self.assertEqual(2, self.sp_.PieceToId('</s>'))
self.assertEqual('<unk>', self.sp_.IdToPiece(0))
self.assertEqual('<s>', self.sp_.IdToPiece(1))
self.assertEqual('</s>', self.sp_.IdToPiece(2))
for i in range(self.sp_.GetPieceSize()):
piece = self.sp_.IdToPiece(i)
self.assertEqual(i, self.sp_.PieceToId(piece))
def test_roundtrip(self):
text = 'I saw a girl with a telescope.'
ids = self.sp_.EncodeAsIds(text)
pieces1 = self.sp_.EncodeAsPieces(text)
pieces2 = self.sp_.Encode(text)
pieces3 = self.sp_.NBestEncode(text, 10)[0]
self.assertEqual(pieces1, pieces2)
self.assertEqual(pieces1, pieces3)
self.assertEqual(text, self.sp_.Decode(pieces1))
self.assertEqual(text, self.sp_.DecodePieces(pieces2))
self.assertEqual(text, self.sp_.DecodeIds(ids))
for n in range(100):
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, 64, 0.5)))
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, -1, 0.5)))
self.assertEqual(text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, 64, 0.5)))
self.assertEqual(text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, -1, 0.5)))
def test_roundtrip(self):
text = 'I saw a girl with a telescope.'
ids = self.sp_.EncodeAsIds(text)
pieces1 = self.sp_.EncodeAsPieces(text)
pieces2 = self.sp_.Encode(text)
pieces3 = self.sp_.NBestEncode(text, 10)[0]
self.assertEqual(pieces1, pieces2)
self.assertEqual(pieces1, pieces3)
self.assertEqual(text, self.sp_.Decode(pieces1))
self.assertEqual(text, self.sp_.DecodePieces(pieces2))
self.assertEqual(text, self.sp_.DecodeIds(ids))
for n in range(100):
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, 64, 0.5)))
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, -1, 0.5)))
self.assertEqual(text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, 64, 0.5)))
self.assertEqual(text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, -1, 0.5)))
def test_ja_load(self):
self.assertEqual(8000, self.jasp_.GetPieceSize())
self.assertEqual(0, self.jasp_.PieceToId('<unk>'))
self.assertEqual(1, self.jasp_.PieceToId('<s>'))
self.assertEqual(2, self.jasp_.PieceToId('</s>'))
self.assertEqual('<unk>', self.jasp_.IdToPiece(0))
self.assertEqual('<s>', self.jasp_.IdToPiece(1))
self.assertEqual('</s>', self.jasp_.IdToPiece(2))
for i in range(self.jasp_.GetPieceSize()):
piece = self.jasp_.IdToPiece(i)
self.assertEqual(i, self.jasp_.PieceToId(piece))
def test_ja_load(self):
self.assertEqual(8000, self.jasp_.GetPieceSize())
self.assertEqual(0, self.jasp_.PieceToId('<unk>'))
self.assertEqual(1, self.jasp_.PieceToId('<s>'))
self.assertEqual(2, self.jasp_.PieceToId('</s>'))
self.assertEqual('<unk>', self.jasp_.IdToPiece(0))
self.assertEqual('<s>', self.jasp_.IdToPiece(1))
self.assertEqual('</s>', self.jasp_.IdToPiece(2))
for i in range(self.jasp_.GetPieceSize()):
piece = self.jasp_.IdToPiece(i)
self.assertEqual(i, self.jasp_.PieceToId(piece))
def test_ja_roundtrip(self):
text = '清水寺は京都にある。'
ids = self.jasp_.EncodeAsIds(text)
pieces1 = self.jasp_.EncodeAsPieces(text)
pieces2 = self.jasp_.Encode(text)
pieces3 = self.jasp_.NBestEncode(text, 10)[0]
self.assertEqual(pieces1, pieces2)
self.assertEqual(text, self.jasp_.Decode(pieces1))
self.assertEqual(text, self.jasp_.DecodePieces(pieces2))
self.assertEqual(text, self.jasp_.DecodeIds(ids))
for n in range(100):
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, 64, 0.5)))
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, -1, 0.5)))
def test_ja_roundtrip(self):
text = '清水寺は京都にある。'
ids = self.jasp_.EncodeAsIds(text)
pieces1 = self.jasp_.EncodeAsPieces(text)
pieces2 = self.jasp_.Encode(text)
pieces3 = self.jasp_.NBestEncode(text, 10)[0]
self.assertEqual(pieces1, pieces2)
self.assertEqual(text, self.jasp_.Decode(pieces1))
self.assertEqual(text, self.jasp_.DecodePieces(pieces2))
self.assertEqual(text, self.jasp_.DecodeIds(ids))
for n in range(100):
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, 64, 0.5)))
self.assertEqual(text, self.sp_.Decode(self.sp_.SampleEncode(text, -1, 0.5)))
def test_unicode_roundtrip(self):
text = u'I saw a girl with a telescope.'
ids = self.sp_.EncodeAsIds(text)
pieces1 = self.sp_.EncodeAsPieces(text)
pieces2 = self.sp_.Encode(text)
self.assertEqual(pieces1, pieces2)
self.assertEqual(text, self.sp_.Decode(pieces1))
self.assertEqual(text, self.sp_.DecodePieces(pieces2))
# python2 returns `str`.
if sys.version_info < (3,0,0):
text = text.encode('utf-8')
self.assertEqual(text, self.sp_.DecodeIds(ids))
def test_unicode_roundtrip(self):
text = u'I saw a girl with a telescope.'
ids = self.sp_.EncodeAsIds(text)
pieces1 = self.sp_.EncodeAsPieces(text)
pieces2 = self.sp_.Encode(text)
self.assertEqual(pieces1, pieces2)
self.assertEqual(text, self.sp_.Decode(pieces1))
self.assertEqual(text, self.sp_.DecodePieces(pieces2))
# python2 returns `str`.
if sys.version_info < (3,0,0):
text = text.encode('utf-8')
self.assertEqual(text, self.sp_.DecodeIds(ids))
def test_unicode_ja_roundtrip(self):
text = u'清水寺は京都にある。'
ids = self.jasp_.EncodeAsIds(text)
pieces1 = self.jasp_.EncodeAsPieces(text)
pieces2 = self.jasp_.Encode(text)
self.assertEqual(pieces1, pieces2)
self.assertEqual(text, self.jasp_.Decode(pieces1))
self.assertEqual(text, self.jasp_.DecodePieces(pieces2))
# python2 returns `str`.
if sys.version_info < (3,0,0):
text = text.encode('utf-8')
self.assertEqual(text, self.jasp_.DecodeIds(ids))
def test_unicode_ja_roundtrip(self):
text = u'清水寺は京都にある。'
ids = self.jasp_.EncodeAsIds(text)
pieces1 = self.jasp_.EncodeAsPieces(text)
pieces2 = self.jasp_.Encode(text)
self.assertEqual(pieces1, pieces2)
self.assertEqual(text, self.jasp_.Decode(pieces1))
self.assertEqual(text, self.jasp_.DecodePieces(pieces2))
# python2 returns `str`.
if sys.version_info < (3,0,0):
text = text.encode('utf-8')
self.assertEqual(text, self.jasp_.DecodeIds(ids))
def test_train(self):
spm.SentencePieceTrainer.Train(
"--input=test/botchan.txt --model_prefix=m --vocab_size=1000")
def test_train(self):
spm.SentencePieceTrainer.Train(
"--input=test/botchan.txt --model_prefix=m --vocab_size=1000")
# snake case API.
def test_load_snake(self):
self.assertEqual(1000, self.sp_.get_piece_size())
self.assertEqual(0, self.sp_.piece_to_id('<unk>'))
self.assertEqual(1, self.sp_.piece_to_id('<s>'))
self.assertEqual(2, self.sp_.piece_to_id('</s>'))
self.assertEqual('<unk>', self.sp_.id_to_piece(0))
self.assertEqual('<s>', self.sp_.id_to_piece(1))
self.assertEqual('</s>', self.sp_.id_to_piece(2))
for i in range(self.sp_.get_piece_size()):
piece = self.sp_.id_to_piece(i)
self.assertEqual(i, self.sp_.piece_to_id(piece))
def test_roundtrip_snake(self):
text = 'I saw a girl with a telescope.'
ids = self.sp_.encode_as_ids(text)
pieces1 = self.sp_.encode_as_pieces(text)
pieces2 = self.sp_.encode(text)
pieces3 = self.sp_.nbest_encode(text, 10)[0]
self.assertEqual(pieces1, pieces2)
self.assertEqual(pieces1, pieces3)
self.assertEqual(text, self.sp_.decode(pieces1))
self.assertEqual(text, self.sp_.decode_pieces(pieces2))
self.assertEqual(text, self.sp_.decode_ids(ids))
for n in range(100):
self.assertEqual(text, self.sp_.decode(self.sp_.sample_encode(text, 64, 0.5)))
self.assertEqual(text, self.sp_.decode(self.sp_.sample_encode(text, -1, 0.5)))
self.assertEqual(text, self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, 64, 0.5)))
self.assertEqual(text, self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, -1, 0.5)))
def test_ja_load_snake(self):
self.assertEqual(8000, self.jasp_.get_piece_size())
self.assertEqual(0, self.jasp_.piece_to_id('<unk>'))
self.assertEqual(1, self.jasp_.piece_to_id('<s>'))
self.assertEqual(2, self.jasp_.piece_to_id('</s>'))
self.assertEqual('<unk>', self.jasp_.id_to_piece(0))
self.assertEqual('<s>', self.jasp_.id_to_piece(1))
self.assertEqual('</s>', self.jasp_.id_to_piece(2))
for i in range(self.jasp_.get_piece_size()):
piece = self.jasp_.id_to_piece(i)
self.assertEqual(i, self.jasp_.piece_to_id(piece))
def test_ja_roundtrip_snake(self):
text = '清水寺は京都にある。'
ids = self.jasp_.encode_as_ids(text)
pieces1 = self.jasp_.encode_as_pieces(text)
pieces2 = self.jasp_.encode(text)
pieces3 = self.jasp_.nbest_encode(text, 10)[0]
self.assertEqual(pieces1, pieces2)
self.assertEqual(text, self.jasp_.decode(pieces1))
self.assertEqual(text, self.jasp_.decode_pieces(pieces2))
self.assertEqual(text, self.jasp_.decode_ids(ids))
for n in range(100):
self.assertEqual(text, self.sp_.decode(self.sp_.sample_encode(text, 64, 0.5)))
self.assertEqual(text, self.sp_.decode(self.sp_.sample_encode(text, -1, 0.5)))
def test_unicode_roundtrip_snake(self):
text = u'I saw a girl with a telescope.'
ids = self.sp_.encode_as_ids(text)
pieces1 = self.sp_.encode_as_pieces(text)
pieces2 = self.sp_.encode(text)
self.assertEqual(pieces1, pieces2)
self.assertEqual(text, self.sp_.decode(pieces1))
self.assertEqual(text, self.sp_.decode_pieces(pieces2))
# python2 returns `str`.
if sys.version_info < (3,0,0):
text = text.encode('utf-8')
self.assertEqual(text, self.sp_.decode_ids(ids))
def test_unicode_ja_roundtrip_snake(self):
text = u'清水寺は京都にある。'
ids = self.jasp_.encode_as_ids(text)
pieces1 = self.jasp_.encode_as_pieces(text)
pieces2 = self.jasp_.encode(text)
self.assertEqual(pieces1, pieces2)
self.assertEqual(text, self.jasp_.decode(pieces1))
self.assertEqual(text, self.jasp_.decode_pieces(pieces2))
# python2 returns `str`.
if sys.version_info < (3,0,0):
text = text.encode('utf-8')
self.assertEqual(text, self.jasp_.decode_ids(ids))
def test_train_snake(self):
spm.SentencePieceTrainer.train(
"--input=test/botchan.txt --model_prefix=m --vocab_size=1000")
def suite():
@ -106,5 +195,6 @@ def suite():
suite.addTests(unittest.makeSuite(TestSentencepieceProcessor))
return suite
if __name__ == '__main__':
unittest.main()
unittest.main()

Просмотреть файл

@ -267,6 +267,71 @@ class SentencePieceProcessor {
virtual util::Status Decode(const std::vector<int> &ids,
SentencePieceText *spt) const;
//////////////////////////////////////////////////////////////
// Handy methods that return the result directly.
// These functions ignore internal errors.
#ifdef SWIG
#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
OutType output; \
const auto _status = FuncName(__VA_ARGS__, &output); \
if (!_status.ok()) throw _status; \
return output;
#else
#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
OutType output; \
FuncName(__VA_ARGS__, &output).IgnoreError(); \
return output;
#endif
// Encode
virtual std::vector<std::string> EncodeAsPieces(
const std::string &input) const {
DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<std::string>, input);
}
virtual std::vector<int> EncodeAsIds(const std::string &input) const {
DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<int>, input);
}
// NBestEncode
virtual std::vector<std::vector<std::string>> NBestEncodeAsPieces(
const std::string &input, int nbest_size) const {
DEFINE_SPP_DIRECT_FUNC_IMPL(
NBestEncode, std::vector<std::vector<std::string>>, input, nbest_size);
}
virtual std::vector<std::vector<int>> NBestEncodeAsIds(
const std::string &input, int nbest_size) const {
DEFINE_SPP_DIRECT_FUNC_IMPL(NBestEncode, std::vector<std::vector<int>>,
input, nbest_size);
}
// SampleEncode
virtual std::vector<std::string> SampleEncodeAsPieces(
const std::string &input, int nbest_size, float alpha) const {
DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<std::string>, input,
nbest_size, alpha);
}
virtual std::vector<int> SampleEncodeAsIds(const std::string &input,
int nbest_size,
float alpha) const {
DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<int>, input,
nbest_size, alpha);
}
// Decode
virtual std::string DecodePieces(
const std::vector<std::string> &pieces) const {
DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
}
virtual std::string DecodeIds(const std::vector<int> &ids) const {
DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, ids);
}
#undef DEFINE_SPP_DIRECT_FUNC_IMPL
//////////////////////////////////////////////////////////////
// Vocabulary management methods.
//