onnxruntime-extensions/test/test_audio_codec.py

60 строки
2.9 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import numpy as np
from onnx import checker, helper, onnx_pb as onnx_proto
from onnxruntime_extensions import PyOrtFunction, util
class TestAudioCodec(unittest.TestCase):
    """Tests for the ``AudioDecoder`` custom op on WAV, FLAC and MP3 inputs.

    Each test feeds the raw bytes of an encoded audio file to the decoder and
    checks the returned PCM tensor, in most cases against the reference
    samples loaded from the ``.npy`` file in the test data directory.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Resolve the test data files, load the reference PCM and build the shared decoder."""
        cls.test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
        cls.test_wav_file = util.get_test_data_file('data', '1272-141231-0002.wav')
        cls.test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac')
        cls.raw_data = np.load(util.get_test_data_file('data', '1272-141231-0002.npy'))
        # Shared by the tests that use the decoder unmodified; tests that need
        # a different graph or different attributes must build their own.
        cls.decoder = PyOrtFunction.from_customop('AudioDecoder')
        return super().setUpClass()

    @staticmethod
    def _read_as_tensor(audio_file):
        """Return the bytes of *audio_file* as an array with a leading axis of size 1.

        ``np.asarray`` is applied to a ``bytearray`` (not ``bytes``) so that
        the result is a 1-D array with one element per byte.
        """
        blob = bytearray(util.read_file(audio_file, mode='rb'))
        return np.expand_dims(np.asarray(blob), axis=(0,))

    def test_wav_decoder(self):
        """Decoded WAV matches the reference PCM within tight tolerances."""
        pcm_tensor = self.decoder(self._read_as_tensor(self.test_wav_file))
        np.testing.assert_allclose(pcm_tensor, self.raw_data, rtol=1e-05, atol=1e-08)

    def test_wav_format_decoder(self):
        """Decoded WAV matches the reference when a ``format`` string input is supplied."""
        # Use a dedicated decoder here: the graph is edited in place below, and
        # editing the model of the shared ``cls.decoder`` would leak the extra
        # required input into the other tests depending on execution order.
        fmt_onnx_model = PyOrtFunction.from_customop('AudioDecoder').onnx_model
        fmt_onnx_model.graph.input.extend([
            helper.make_tensor_value_info('format', onnx_proto.TensorProto.STRING, [])])
        fmt_onnx_model.graph.node[0].input.extend(['format'])
        checker.check_model(fmt_onnx_model)
        new_decoder = PyOrtFunction.from_model(fmt_onnx_model)
        pcm_tensor = new_decoder(self._read_as_tensor(self.test_wav_file), ["wav"])
        np.testing.assert_allclose(pcm_tensor, self.raw_data, rtol=1e-05, atol=1e-08)

    def test_flac_decoder(self):
        """Decoded FLAC matches the reference PCM within an absolute tolerance of 1e-3."""
        pcm_tensor = self.decoder(self._read_as_tensor(self.test_flac_file))
        np.testing.assert_allclose(pcm_tensor, self.raw_data, atol=1e-03)

    def test_mp3_decoder(self):
        """Decoded MP3 is longer than its encoded input and has a similar value range."""
        encoded = self._read_as_tensor(self.test_mp3_file)
        pcm_tensor = self.decoder(encoded)
        # The decoded sample count must exceed the number of encoded bytes.
        self.assertGreater(pcm_tensor.shape[1], encoded.shape[1])
        # lossy compression, so we can only check the range
        np.testing.assert_allclose(
            np.asarray([np.max(pcm_tensor), np.average(pcm_tensor), np.min(pcm_tensor)]),
            np.asarray([np.max(self.raw_data), np.average(self.raw_data), np.min(self.raw_data)]),
            atol=1e-01)

    def test_decoder_resampling(self):
        """Decoder built with resampling/mono attributes yields the expected output shape."""
        test_file = util.get_test_data_file('data', 'jfk.flac')
        decoder = PyOrtFunction.from_customop(
            'AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)
        pcm_tensor = decoder(self._read_as_tensor(test_file))
        # NOTE(review): 176000 presumably corresponds to the clip length at
        # 16 kHz mono -- confirm against the test data file.
        self.assertEqual(pcm_tensor.shape, (1, 176000))
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()