From 083275a89f2bae2010245e56eef02b06aed86a66 Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Mon, 13 Aug 2018 02:16:37 +0900 Subject: [PATCH] Fixed test error on Windows --- python/setup.py | 2 +- python/test/sentencepiece_test.py | 20 +++++++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/python/setup.py b/python/setup.py index 045150d..62d1d2c 100755 --- a/python/setup.py +++ b/python/setup.py @@ -27,7 +27,7 @@ with codecs.open('README.md', 'r', 'utf-8') as f: long_description = f.read() with codecs.open(os.path.join('..', 'VERSION'), 'r', 'utf-8') as f: - version = f.read() + version = f.read().rstrip() def cmd(line): try: diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py index 620025d..39d8505 100755 --- a/python/test/sentencepiece_test.py +++ b/python/test/sentencepiece_test.py @@ -30,15 +30,13 @@ class TestSentencepieceProcessor(unittest.TestCase): def setUp(self): self.sp_ = spm.SentencePieceProcessor() + self.jasp_ = spm.SentencePieceProcessor() self.assertTrue(self.sp_.Load(os.path.join('test', 'test_model.model'))) - self.jasp_ = spm.SentencePieceProcessor() self.assertTrue(self.jasp_.Load(os.path.join('test', 'test_ja_model.model'))) - self.assertTrue(self.sp_.LoadFromSerializedProto( - open(os.path.join('test', 'test_model.model'), 'rb').read())) - self.jasp_ = spm.SentencePieceProcessor() - self.assertTrue(self.jasp_.LoadFromSerializedProto( - open(os.path.join('test', 'test_ja_model.model'), 'rb').read())) - + with open(os.path.join('test', 'test_model.model'), 'rb') as f: + self.assertTrue(self.sp_.LoadFromSerializedProto(f.read())) + with open(os.path.join('test', 'test_ja_model.model'), 'rb') as f: + self.assertTrue(self.jasp_.LoadFromSerializedProto(f.read())) def test_load(self): self.assertEqual(1000, self.sp_.GetPieceSize()) @@ -91,8 +89,8 @@ class TestSentencepieceProcessor(unittest.TestCase): self.assertEqual(text, self.jasp_.DecodePieces(pieces1)) self.assertEqual(text, self.jasp_.DecodeIds(ids)) for n in range(100): - self.assertEqual(text, self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, 64, 0.5))) - self.assertEqual(text, self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, -1, 0.5))) + self.assertEqual(text, self.jasp_.DecodePieces(self.jasp_.SampleEncodeAsPieces(text, 64, 0.5))) + self.assertEqual(text, self.jasp_.DecodePieces(self.jasp_.SampleEncodeAsPieces(text, -1, 0.5))) def test_unicode_roundtrip(self): @@ -176,8 +174,8 @@ class TestSentencepieceProcessor(unittest.TestCase): self.assertEqual(text, self.jasp_.decode_pieces(pieces1)) self.assertEqual(text, self.jasp_.decode_ids(ids)) for n in range(100): - self.assertEqual(text, self.sp_.decode_pieces(self.sp_.sample_encode_as_pieces(text, 64, 0.5))) - self.assertEqual(text, self.sp_.decode_pieces(self.sp_.sample_encode_as_pieces(text, -1, 0.5))) + self.assertEqual(text, self.jasp_.decode_pieces(self.jasp_.sample_encode_as_pieces(text, 64, 0.5))) + self.assertEqual(text, self.jasp_.decode_pieces(self.jasp_.sample_encode_as_pieces(text, -1, 0.5))) def test_unicode_roundtrip_snake(self): text = u'I saw a girl with a telescope.'