Use the HuggingFace Whisper config instead of fixed numbers (#667)

* use the HuggingFace Whisper config instead of fixed numbers
* refactor a little bit
This commit is contained in:
Parent: 61369fb970
Commit: 6ac6fb6fbd
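The point of the change is that every Whisper hyperparameter the pre-processing graph needs is already recorded on the HuggingFace feature-extractor config, so it no longer has to be hard-coded. A minimal sketch of where the values come from (not part of the commit; the numbers in the comment are the stock Whisper defaults, stated here as an assumption for illustration):

    from transformers import WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
    fe = processor.feature_extractor
    # The five fields the new code reads. For whisper-tiny.en these should be
    # 16000, 400, 160, 80 and 480000; whisper-large-v3 reports feature_size=128.
    print(fe.sampling_rate, fe.n_fft, fe.hop_length, fe.feature_size, fe.n_samples)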
The first hunk is the "refactor a little bit" part: in the tokenizer-converter module, the package-local import of the custom-op helpers moves above the util import.

@@ -13,8 +13,8 @@ from numpy import array as nparray
 from functools import partial
 from collections import namedtuple, OrderedDict
 
+from ._cuops import CustomOpConverter, SingleOpGraph
 from .util import read_file
-from ._cuops import CustomOpConverter, SingleOpGraph
 
 
 class HFTokenizerConverter(CustomOpConverter):
The next file is the Whisper pre-processing module. WhisperPrePipeline previously baked in the fixed _WhisperHParams constants; it now accepts them as constructor arguments, keeping the old constants as defaults.

@@ -88,19 +88,21 @@ class CustomOpStftNorm(torch.autograd.Function):
 
 
 class WhisperPrePipeline(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, sr=_WhisperHParams.SAMPLE_RATE, n_fft=_WhisperHParams.N_FFT,
+                 hop_length=_WhisperHParams.HOP_LENGTH, n_mels=_WhisperHParams.N_MELS,
+                 n_samples=_WhisperHParams.N_SAMPLES):
         super().__init__()
-        self.window = torch.hann_window(_WhisperHParams.N_FFT)
+        self.n_samples = n_samples
+        self.hop_length = hop_length
+        self.n_fft = n_fft
+        self.window = torch.hann_window(n_fft)
         self.mel_filters = torch.from_numpy(
-            _mel_filterbank(
-                sr=_WhisperHParams.SAMPLE_RATE,
-                n_fft=_WhisperHParams.N_FFT,
-                n_mels=_WhisperHParams.N_MELS))
+            _mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))
 
     def forward(self, audio_pcm: torch.Tensor):
         stft_norm = CustomOpStftNorm.apply(audio_pcm,
-                                           _WhisperHParams.N_FFT,
-                                           _WhisperHParams.HOP_LENGTH,
+                                           self.n_fft,
+                                           self.hop_length,
                                            self.window)
         magnitudes = stft_norm[:, :, :-1]
         mel_spec = self.mel_filters @ magnitudes
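Because the new defaults are exactly the constants that were hard-coded before, a no-argument construction behaves as it always did; only callers that pass a config get a different filterbank. An illustrative (hypothetical, not in the diff) pair of calls:

    pipeline_default = WhisperPrePipeline()          # same behavior as before this change
    pipeline_large = WhisperPrePipeline(n_mels=128)  # e.g. a whisper-large-v3-style filterbank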
@@ -110,7 +112,7 @@ class WhisperPrePipeline(torch.nn.Module):
         spec_shape = log_spec.shape
         padding_spec = torch.ones(spec_shape[0],
                                   spec_shape[1],
-                                  _WhisperHParams.N_SAMPLES // _WhisperHParams.HOP_LENGTH - spec_shape[2],
+                                  self.n_samples // self.hop_length - spec_shape[2],
                                   dtype=torch.float)
         padding_spec *= spec_min
         log_spec = torch.cat((log_spec, padding_spec), dim=2)
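The padding target is now derived from the instance fields rather than the module-level constants, but the arithmetic is unchanged for the defaults. Worked out for illustration (480000 and 160 are the assumed stock Whisper values, not taken from the diff):

    # 30 s of audio at 16 kHz with a 160-sample hop yields 3000 mel frames,
    # which matches the (1, 80, 3000) shape asserted in the tests below.
    n_samples, hop_length = 480000, 160
    assert n_samples // hop_length == 3000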
@@ -118,7 +120,7 @@ class WhisperPrePipeline(torch.nn.Module):
         return log_spec
 
 
-def _to_onnx_stft(onnx_model):
+def _to_onnx_stft(onnx_model, n_fft):
     """Convert custom-op STFT-Norm to ONNX STFT"""
    node_idx = 0
    new_stft_nodes = []
@@ -136,8 +138,8 @@ def _to_onnx_stft(onnx_model):
             replaced_nodes = [
                 make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
                           value=numpy_helper.from_array(np.array([0,
-                                                                  _WhisperHParams.N_FFT // 2, 0,
-                                                                  _WhisperHParams.N_FFT // 2], dtype='int64'),
+                                                                  n_fft // 2, 0,
+                                                                  n_fft // 2], dtype='int64'),
                                                         name='const_14')),
                 make_node('Pad',
                           inputs=[stft_norm_node.input[0], 'const_14_output_0'],
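For context: ONNX Pad takes its pad widths as [x1_begin, x2_begin, ..., x1_end, x2_end], so this constant pads the (batch, samples) signal by n_fft // 2 samples at both ends of the time axis, mirroring the center padding of torch.stft(center=True). A small illustration, assuming the default n_fft of 400:

    n_fft = 400
    pads = [0, n_fft // 2, 0, n_fft // 2]  # [batch_begin, time_begin, batch_end, time_end] -> [0, 200, 0, 200]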
@@ -192,7 +194,13 @@ class WhisperDataProcGraph:
     def pre_processing(self, **kwargs):
         use_audio_decoder = kwargs.pop('USE_AUDIO_DECODER', True)
         use_onnx_stft = kwargs.pop('USE_ONNX_STFT', True)
-        whisper_processing = WhisperPrePipeline()
+        feature_extractor = self.hf_processor.feature_extractor
+        whisper_processing = WhisperPrePipeline(
+            feature_extractor.sampling_rate,
+            feature_extractor.n_fft,
+            feature_extractor.hop_length,
+            feature_extractor.feature_size,
+            feature_extractor.n_samples)
 
         audio_pcm = torch.rand((1, 32000), dtype=torch.float32)
         model_args = (audio_pcm,)
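The positional arguments line up with the (sr, n_fft, hop_length, n_mels, n_samples) parameters introduced above; note that the extractor's feature_size field supplies the mel-bin count. An equivalent keyword form, written out for illustration only:

    whisper_processing = WhisperPrePipeline(
        sr=feature_extractor.sampling_rate,
        n_fft=feature_extractor.n_fft,
        hop_length=feature_extractor.hop_length,
        n_mels=feature_extractor.feature_size,   # feature_size == number of mel bins
        n_samples=feature_extractor.n_samples)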
@@ -209,13 +217,15 @@ class WhisperDataProcGraph:
             }
         )
         if use_onnx_stft:
-            pre_model = _to_onnx_stft(pre_model)
+            pre_model = _to_onnx_stft(pre_model, feature_extractor.n_fft)
             remove_unused_initializers(pre_model.graph)
 
         pre_full = pre_model
         if use_audio_decoder:
             audecoder_g = SingleOpGraph.build_graph(
-                "AudioDecoder", downsampling_rate=_WhisperHParams.SAMPLE_RATE, stereo_to_mono=1)
+                "AudioDecoder",
+                downsampling_rate=feature_extractor.sampling_rate,
+                stereo_to_mono=1)
             audecoder_m = make_onnx_model(audecoder_g)
             pre_full = onnx.compose.merge_models(
                 audecoder_m,
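With these hunks in place, the whole generated pre-processing graph (audio decoding, STFT, and mel filterbank) is driven by whichever checkpoint the processor was loaded from. A hypothetical end-to-end use, mirroring the tests below:

    from transformers import WhisperProcessor
    from onnxruntime_extensions import gen_processing_models

    processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
    pre_m, _ = gen_processing_models(
        processor, pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True})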
The last file is the Whisper test module: the Windows skip decorators are dropped (so the sys import goes too) and a whisper-large-v3 case is added.

@@ -1,6 +1,5 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-import sys
 import unittest
 
 import numpy as np
@@ -41,7 +40,6 @@ class TestHuggingfaceWhisper(unittest.TestCase):
 
         self.assertEqual(log_mel.shape, (1, 80, 3000))
 
-    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
     def test_ort_stft_consistency(self):
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, _ = gen_processing_models(processor,
@@ -63,7 +61,6 @@
         self.assertTrue(num_mismatched / np.size(expected) < 0.02)
         self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
 
-    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
     def test_stft_norm_consistency(self):
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, _ = gen_processing_models(processor,
@@ -83,6 +80,25 @@
         np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
         self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
 
+    def test_stft_norm_consistency_large(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
+        pre_m, _ = gen_processing_models(processor,
+                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
+
+        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
+        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
+        raw_audio = OrtPyFunction.from_customop(
+            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
+
+        input_features = processor([raw_audio[0]], sampling_rate=16000)
+        expected = input_features['input_features'][0]
+
+        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
+        actual = log_mel[0]
+
+        np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
+        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
+
 
 if __name__ == "__main__":
     unittest.main()
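The new test exercises the config-driven path end to end: whisper-large-v3's extractor reports 128 mel bins, so the generated graph has to build a 128-bin filterbank to match the HuggingFace features. An illustrative check (not part of the diff):

    from transformers import WhisperProcessor

    fe = WhisperProcessor.from_pretrained("openai/whisper-large-v3").feature_extractor
    assert fe.feature_size == 128  # vs. 80 for whisper-tiny.en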