This commit is contained in:
Taku Kudo 2018-08-13 02:16:37 +09:00
Родитель 382bae0d7b
Коммит 083275a89f
2 изменённых файла: 10 добавлений и 12 удалений

Просмотреть файл

@@ -27,7 +27,7 @@ with codecs.open('README.md', 'r', 'utf-8') as f:
long_description = f.read() long_description = f.read()
with codecs.open(os.path.join('..', 'VERSION'), 'r', 'utf-8') as f: with codecs.open(os.path.join('..', 'VERSION'), 'r', 'utf-8') as f:
version = f.read() version = f.read().rstrip()
def cmd(line): def cmd(line):
try: try:

Просмотреть файл

@@ -30,15 +30,13 @@ class TestSentencepieceProcessor(unittest.TestCase):
def setUp(self): def setUp(self):
self.sp_ = spm.SentencePieceProcessor() self.sp_ = spm.SentencePieceProcessor()
self.jasp_ = spm.SentencePieceProcessor()
self.assertTrue(self.sp_.Load(os.path.join('test', 'test_model.model'))) self.assertTrue(self.sp_.Load(os.path.join('test', 'test_model.model')))
self.jasp_ = spm.SentencePieceProcessor()
self.assertTrue(self.jasp_.Load(os.path.join('test', 'test_ja_model.model'))) self.assertTrue(self.jasp_.Load(os.path.join('test', 'test_ja_model.model')))
self.assertTrue(self.sp_.LoadFromSerializedProto( with open(os.path.join('test', 'test_model.model'), 'rb') as f:
open(os.path.join('test', 'test_model.model'), 'rb').read())) self.assertTrue(self.sp_.LoadFromSerializedProto(f.read()))
self.jasp_ = spm.SentencePieceProcessor() with open(os.path.join('test', 'test_ja_model.model'), 'rb') as f:
self.assertTrue(self.jasp_.LoadFromSerializedProto( self.assertTrue(self.jasp_.LoadFromSerializedProto(f.read()))
open(os.path.join('test', 'test_ja_model.model'), 'rb').read()))
def test_load(self): def test_load(self):
self.assertEqual(1000, self.sp_.GetPieceSize()) self.assertEqual(1000, self.sp_.GetPieceSize())
@@ -91,8 +89,8 @@ class TestSentencepieceProcessor(unittest.TestCase):
self.assertEqual(text, self.jasp_.DecodePieces(pieces1)) self.assertEqual(text, self.jasp_.DecodePieces(pieces1))
self.assertEqual(text, self.jasp_.DecodeIds(ids)) self.assertEqual(text, self.jasp_.DecodeIds(ids))
for n in range(100): for n in range(100):
self.assertEqual(text, self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, 64, 0.5))) self.assertEqual(text, self.jasp_.DecodePieces(self.jasp_.SampleEncodeAsPieces(text, 64, 0.5)))
self.assertEqual(text, self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, -1, 0.5))) self.assertEqual(text, self.jasp_.DecodePieces(self.jasp_.SampleEncodeAsPieces(text, -1, 0.5)))
def test_unicode_roundtrip(self): def test_unicode_roundtrip(self):
@@ -176,8 +174,8 @@ class TestSentencepieceProcessor(unittest.TestCase):
self.assertEqual(text, self.jasp_.decode_pieces(pieces1)) self.assertEqual(text, self.jasp_.decode_pieces(pieces1))
self.assertEqual(text, self.jasp_.decode_ids(ids)) self.assertEqual(text, self.jasp_.decode_ids(ids))
for n in range(100): for n in range(100):
self.assertEqual(text, self.sp_.decode_pieces(self.sp_.sample_encode_as_pieces(text, 64, 0.5))) self.assertEqual(text, self.jasp_.decode_pieces(self.jasp_.sample_encode_as_pieces(text, 64, 0.5)))
self.assertEqual(text, self.sp_.decode_pieces(self.sp_.sample_encode_as_pieces(text, -1, 0.5))) self.assertEqual(text, self.jasp_.decode_pieces(self.jasp_.sample_encode_as_pieces(text, -1, 0.5)))
def test_unicode_roundtrip_snake(self): def test_unicode_roundtrip_snake(self):
text = u'I saw a girl with a telescope.' text = u'I saw a girl with a telescope.'