Merge pull request #179 from google/sr

Sr
This commit is contained in:
Taku Kudo 2018-08-15 17:40:27 +09:00 коммит произвёл GitHub
Родитель 33146ef56f 58ca64bf0e
Коммит 0b40e830a2
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 41 добавлений и 47 удалений

Просмотреть файл

@ -3,44 +3,51 @@
import itertools as it
import os
import sys
import unittest
import tensorflow as tf
import sentencepiece as spm
import tf_sentencepiece as tfspm
class SentencePieceProcssorOpTest(unittest.TestCase):
def _getSentencePieceModelFile(self):
return '../python/test/test_ja_model.model'
return os.path.join('..', 'python', 'test', 'test_model.model')
def _getExpected(self, processor, reverse=False, add_bos=False,
def _getPieceSize(self):
return 1000
def _getExpected(self, reverse=False, add_bos=False,
add_eos=False, padding=''):
options = []
# TF uses str(bytes) as a string representation.
padding = padding.encode('utf8')
sentences = [b'Hello world.', b'I have a pen.',
b'I saw a girl with a telescope.']
pieces = [[b'\xe2\x96\x81He', b'll', b'o', b'\xe2\x96\x81world', b'.'],
[b'\xe2\x96\x81I', b'\xe2\x96\x81have', b'\xe2\x96\x81a',
b'\xe2\x96\x81p', b'en', b'.'],
[b'\xe2\x96\x81I', b'\xe2\x96\x81saw', b'\xe2\x96\x81a',
b'\xe2\x96\x81girl', b'\xe2\x96\x81with',
b'\xe2\x96\x81a', b'\xe2\x96\x81',
b'te', b'le', b's', b'c', b'o', b'pe', b'.']]
ids = [[151, 88, 21, 887, 6],
[9, 76, 11, 68, 98, 6],
[9, 459, 11, 939, 44, 11, 4, 142, 82, 8, 28, 21, 132, 6]]
seq_len = [5, 6, 14]
if reverse:
options.append('reverse')
ids = [x[::-1] for x in ids]
pieces = [x[::-1] for x in pieces]
if add_bos:
options.append('bos')
ids = [[1] + x for x in ids]
pieces = [[b'<s>'] + x for x in pieces]
seq_len = [x + 1 for x in seq_len]
if add_eos:
options.append('eos')
ids = [x + [2] for x in ids]
pieces = [x + [b'</s>'] for x in pieces]
seq_len = [x + 1 for x in seq_len]
processor.SetEncodeExtraOptions(':'.join(options))
processor.SetDecodeExtraOptions(':'.join(options))
sentences = ['Hello world.', 'I have a pen.',
'I saw a girl with a telescope.']
pieces = []
ids = []
seq_len = []
for s in sentences:
x = processor.EncodeAsPieces(s)
y = processor.EncodeAsIds(s)
pieces.append(x)
ids.append(y)
seq_len.append(len(x))
self.assertEqual(len(x), len(y))
# padding
max_len = max(seq_len)
pieces = [x + [padding] * (max_len - len(x)) for x in pieces]
ids = [x + [0] * (max_len - len(x)) for x in ids]
@ -49,21 +56,16 @@ class SentencePieceProcssorOpTest(unittest.TestCase):
def testGetPieceSize(self):
sentencepiece_model_file = self._getSentencePieceModelFile()
processor = spm.SentencePieceProcessor()
processor.Load(sentencepiece_model_file)
with tf.Session():
s = tfspm.piece_size(
model_file=sentencepiece_model_file)
self.assertEqual(s.eval(), processor.GetPieceSize())
self.assertEqual(s.eval(), self._getPieceSize())
def testConvertPiece(self):
sentencepiece_model_file = self._getSentencePieceModelFile()
processor = spm.SentencePieceProcessor()
processor.Load(sentencepiece_model_file)
(sentences, expected_pieces,
expected_ids, expected_seq_len) = self._getExpected(processor,
padding='<unk>')
expected_ids, expected_seq_len) = self._getExpected(padding='<unk>')
with tf.Session():
ids_matrix = tfspm.piece_to_id(
@ -97,15 +99,13 @@ class SentencePieceProcssorOpTest(unittest.TestCase):
def testEncodeAndDecode(self):
sentencepiece_model_file = self._getSentencePieceModelFile()
processor = spm.SentencePieceProcessor()
processor.Load(sentencepiece_model_file)
with tf.Session():
for reverse, add_bos, add_eos in list(it.product(
(True, False), repeat=3)):
(sentences, expected_pieces,
expected_ids, expected_seq_len) = self._getExpected(
processor, reverse, add_bos, add_eos)
reverse=reverse, add_bos=add_bos, add_eos=add_eos)
# Encode sentences into pieces/ids.
s = tf.constant(sentences)
@ -138,9 +138,7 @@ class SentencePieceProcssorOpTest(unittest.TestCase):
def testSampleEncodeAndDecode(self):
sentencepiece_model_file = self._getSentencePieceModelFile()
processor = spm.SentencePieceProcessor()
processor.Load(sentencepiece_model_file)
sentences, _, _, _ = self._getExpected(processor)
sentences, _, _, _ = self._getExpected()
with tf.Session():
for n, a in [(-1, 0.1), (64, 0.1), (0, 0.0)]:
@ -165,14 +163,12 @@ class SentencePieceProcssorOpTest(unittest.TestCase):
def testEncodeAndDecodeSparse(self):
sentencepiece_model_file = self._getSentencePieceModelFile()
processor = spm.SentencePieceProcessor()
processor.Load(sentencepiece_model_file)
with tf.Session():
for reverse, add_bos, add_eos in list(it.product(
(True, False), repeat=3)):
(sentences, expected_pieces, expected_ids,
_) = self._getExpected(processor, reverse, add_bos, add_eos)
_) = self._getExpected(reverse, add_bos, add_eos)
# Encode sentences into sparse pieces/ids.
s = tf.constant(sentences)
@ -191,18 +187,16 @@ class SentencePieceProcssorOpTest(unittest.TestCase):
def testGetPieceType(self):
sentencepiece_model_file = self._getSentencePieceModelFile()
processor = spm.SentencePieceProcessor()
processor.Load(sentencepiece_model_file)
expected_is_unknown = []
expected_is_control = []
expected_is_unused = []
ids = []
for i in range(processor.GetPieceSize()):
for i in range(self._getPieceSize()):
ids.append(i)
expected_is_unknown.append(processor.IsUnknown(i))
expected_is_control.append(processor.IsControl(i))
expected_is_unused.append(processor.IsUnused(i))
expected_is_unknown.append(i == 0)
expected_is_control.append(i == 1 or i == 2)
expected_is_unused.append(False)
with tf.Session():
s = tf.constant(ids)