make tensorflow be optional for unittest (#394)
* make tensorflow be optional for unitest. * typo
This commit is contained in:
Родитель
b5dce955f0
Коммит
2202e3e19b
|
@ -4,3 +4,6 @@ requires = ["setuptools", "wheel", "numpy>=1.18.5", "cmake"] # PEP 508 specific
|
|||
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
|
|
|
@ -43,8 +43,9 @@ class TestBpeTokenizer(unittest.TestCase):
|
|||
pcm_tensor = self.decoder(np.expand_dims(np.asarray(blob), axis=(0,)))
|
||||
self.assertTrue(pcm_tensor.shape[1] > len(blob))
|
||||
# lossy compression, so we can only check the range
|
||||
np.testing.assert_allclose(np.asarray([np.max(pcm_tensor), np.average(pcm_tensor), np.min(pcm_tensor)]),
|
||||
np.asarray([np.max(self.raw_data), np.average(self.raw_data), np.min(self.raw_data)]), atol=1e-01)
|
||||
np.testing.assert_allclose(
|
||||
np.asarray([np.max(pcm_tensor), np.average(pcm_tensor), np.min(pcm_tensor)]),
|
||||
np.asarray([np.max(self.raw_data), np.average(self.raw_data), np.min(self.raw_data)]), atol=1e-01)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -13,8 +13,15 @@ from onnxruntime_extensions import (
|
|||
OrtPyFunction,
|
||||
make_onnx_model,
|
||||
get_library_path as _get_library_path)
|
||||
import tensorflow as tf
|
||||
from tensorflow_text import SentencepieceTokenizer
|
||||
|
||||
|
||||
_is_tensorflow_avaliable = False
|
||||
try:
|
||||
import tensorflow as tf
|
||||
from tensorflow_text import SentencepieceTokenizer
|
||||
_is_tensorflow_avaliable = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def load_piece(name):
|
||||
|
@ -235,6 +242,7 @@ def _create_test_model_ragged_to_dense(
|
|||
return model
|
||||
|
||||
|
||||
@unittest.skipIf(not _is_tensorflow_avaliable, "tensorflow/tensorflow-text is unavailable")
|
||||
class TestPythonOpSentencePiece(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
|
@ -430,6 +438,8 @@ class TestPythonOpSentencePiece(unittest.TestCase):
|
|||
assert_almost_equal(exp[i], py_txout[i])
|
||||
assert_almost_equal(exp[i], cc_txout[i])
|
||||
|
||||
|
||||
class TestOrtXSentencePiece(unittest.TestCase):
|
||||
def test_external_pretrained_model(self):
|
||||
fullname = util.get_test_data_file('data', 'en.wiki.bpe.vs100000.model')
|
||||
ofunc = OrtPyFunction.from_customop('SentencepieceTokenizer', model=open(fullname, 'rb').read())
|
||||
|
|
|
@ -160,7 +160,8 @@ if __name__ == '__main__':
|
|||
model_name = "openai/whisper-base.en"
|
||||
onnx_model_name = "whisper-base.en_beamsearch.onnx"
|
||||
if not Path(onnx_model_name).is_file():
|
||||
raise RuntimeError("Please run the script from where Whisper ONNX model was exported. like */onnx_models/openai")
|
||||
raise RuntimeError(
|
||||
"Please run the script from where Whisper ONNX model was exported. like */onnx_models/openai")
|
||||
|
||||
_processor = WhisperProcessor.from_pretrained(model_name)
|
||||
if USE_ONNX_COREMODEL:
|
||||
|
|
Загрузка…
Ссылка в новой задаче