using the huggingface whisper config instead of fixed numbers (#667)

* using the huggingface whisper config instead of fixed numbers

* refactor a little bit
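For context, a minimal sketch (not part of this commit) of where the values now come from: the Hugging Face WhisperFeatureExtractor carries every hyperparameter that _WhisperHParams previously hardcoded. The attribute names follow the diff below; the printed values are the defaults for the checkpoints used in the tests.

# Sketch: the feature-extractor attributes this commit reads instead of
# the fixed _WhisperHParams constants.
from transformers import WhisperProcessor

fe = WhisperProcessor.from_pretrained("openai/whisper-tiny.en").feature_extractor
print(fe.sampling_rate)  # 16000 Hz
print(fe.n_fft)          # 400-sample FFT window
print(fe.hop_length)     # 160 samples between STFT frames
print(fe.feature_size)   # 80 mel bins (128 for whisper-large-v3)
print(fe.n_samples)      # 480000 == 30 s * 16000 Hz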
Wenbing Li 2024-03-06 14:29:49 -08:00 committed by GitHub
Parent 61369fb970
Commit 6ac6fb6fbd
No known key found for this signature
GPG key ID: B5690EEEBB952194
3 changed files with 45 additions and 19 deletions

View file

@@ -13,8 +13,8 @@ from numpy import array as nparray
 from functools import partial
 from collections import namedtuple, OrderedDict
-from ._cuops import CustomOpConverter, SingleOpGraph
 from .util import read_file
+from ._cuops import CustomOpConverter, SingleOpGraph


 class HFTokenizerConverter(CustomOpConverter):

View file

@@ -88,19 +88,21 @@ class CustomOpStftNorm(torch.autograd.Function):


 class WhisperPrePipeline(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, sr=_WhisperHParams.SAMPLE_RATE, n_fft=_WhisperHParams.N_FFT,
+                 hop_length=_WhisperHParams.HOP_LENGTH, n_mels=_WhisperHParams.N_MELS,
+                 n_samples=_WhisperHParams.N_SAMPLES):
         super().__init__()
-        self.window = torch.hann_window(_WhisperHParams.N_FFT)
+        self.n_samples = n_samples
+        self.hop_length = hop_length
+        self.n_fft = n_fft
+        self.window = torch.hann_window(n_fft)
         self.mel_filters = torch.from_numpy(
-            _mel_filterbank(
-                sr=_WhisperHParams.SAMPLE_RATE,
-                n_fft=_WhisperHParams.N_FFT,
-                n_mels=_WhisperHParams.N_MELS))
+            _mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))

     def forward(self, audio_pcm: torch.Tensor):
         stft_norm = CustomOpStftNorm.apply(audio_pcm,
-                                           _WhisperHParams.N_FFT,
-                                           _WhisperHParams.HOP_LENGTH,
+                                           self.n_fft,
+                                           self.hop_length,
                                            self.window)
         magnitudes = stft_norm[:, :, :-1]
         mel_spec = self.mel_filters @ magnitudes
@@ -110,7 +112,7 @@ class WhisperPrePipeline(torch.nn.Module):
         spec_shape = log_spec.shape
         padding_spec = torch.ones(spec_shape[0],
                                   spec_shape[1],
-                                  _WhisperHParams.N_SAMPLES // _WhisperHParams.HOP_LENGTH - spec_shape[2],
+                                  self.n_samples // self.hop_length - spec_shape[2],
                                   dtype=torch.float)
         padding_spec *= spec_min
         log_spec = torch.cat((log_spec, padding_spec), dim=2)
@@ -118,7 +120,7 @@ class WhisperPrePipeline(torch.nn.Module):
         return log_spec


-def _to_onnx_stft(onnx_model):
+def _to_onnx_stft(onnx_model, n_fft):
     """Convert custom-op STFT-Norm to ONNX STFT"""
     node_idx = 0
     new_stft_nodes = []
@@ -136,8 +138,8 @@ def _to_onnx_stft(onnx_model):
         replaced_nodes = [
             make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
                       value=numpy_helper.from_array(np.array([0,
-                                                              _WhisperHParams.N_FFT // 2, 0,
-                                                              _WhisperHParams.N_FFT // 2], dtype='int64'),
+                                                              n_fft // 2, 0,
+                                                              n_fft // 2], dtype='int64'),
                                                     name='const_14')),
             make_node('Pad',
                       inputs=[stft_norm_node.input[0], 'const_14_output_0'],
@@ -192,7 +194,13 @@ class WhisperDataProcGraph:
     def pre_processing(self, **kwargs):
         use_audio_decoder = kwargs.pop('USE_AUDIO_DECODER', True)
         use_onnx_stft = kwargs.pop('USE_ONNX_STFT', True)
-        whisper_processing = WhisperPrePipeline()
+        feature_extractor = self.hf_processor.feature_extractor
+        whisper_processing = WhisperPrePipeline(
+            feature_extractor.sampling_rate,
+            feature_extractor.n_fft,
+            feature_extractor.hop_length,
+            feature_extractor.feature_size,
+            feature_extractor.n_samples)

         audio_pcm = torch.rand((1, 32000), dtype=torch.float32)
         model_args = (audio_pcm,)
@@ -209,13 +217,15 @@ class WhisperDataProcGraph:
             }
         )
         if use_onnx_stft:
-            pre_model = _to_onnx_stft(pre_model)
+            pre_model = _to_onnx_stft(pre_model, feature_extractor.n_fft)
             remove_unused_initializers(pre_model.graph)

         pre_full = pre_model
         if use_audio_decoder:
             audecoder_g = SingleOpGraph.build_graph(
-                "AudioDecoder", downsampling_rate=_WhisperHParams.SAMPLE_RATE, stereo_to_mono=1)
+                "AudioDecoder",
+                downsampling_rate=feature_extractor.sampling_rate,
+                stereo_to_mono=1)
             audecoder_m = make_onnx_model(audecoder_g)
             pre_full = onnx.compose.merge_models(
                 audecoder_m,
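A note on what the parametrization buys: the padded frame count computed in pre_processing, n_samples // hop_length, stays 480000 // 160 = 3000 for all current Whisper checkpoints, but the mel-bin count (feature_size) varies, e.g. 80 for whisper-tiny.en versus 128 for whisper-large-v3, which the old fixed N_MELS could not express. A hypothetical stand-alone use of the parametrized class (gen_processing_models normally drives this internally; the import path is an assumption):

# Hypothetical direct use of WhisperPrePipeline as parametrized above.
import torch
from transformers import WhisperProcessor
from onnxruntime_extensions._torch_cvt import WhisperPrePipeline  # internal module; path may vary

fe = WhisperProcessor.from_pretrained("openai/whisper-tiny.en").feature_extractor
pipeline = WhisperPrePipeline(fe.sampling_rate, fe.n_fft, fe.hop_length,
                              fe.feature_size, fe.n_samples)
log_mel = pipeline(torch.rand((1, 32000), dtype=torch.float32))
print(log_mel.shape)  # torch.Size([1, 80, 3000]); 3000 == n_samples // hop_length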

View file

@@ -1,6 +1,5 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

-import sys
 import unittest
 import numpy as np
@@ -41,7 +40,6 @@ class TestHuggingfaceWhisper(unittest.TestCase):
         self.assertEqual(log_mel.shape, (1, 80, 3000))

-    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
     def test_ort_stft_consistency(self):
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, _ = gen_processing_models(processor,
@@ -63,7 +61,6 @@ class TestHuggingfaceWhisper(unittest.TestCase):
         self.assertTrue(num_mismatched / np.size(expected) < 0.02)
         self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)

-    @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.")
     def test_stft_norm_consistency(self):
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         pre_m, _ = gen_processing_models(processor,
@@ -83,6 +80,25 @@ class TestHuggingfaceWhisper(unittest.TestCase):
         np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
         self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)

+    def test_stft_norm_consistency_large(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
+        pre_m, _ = gen_processing_models(processor,
+                                         pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False})
+
+        test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
+        test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0)
+        raw_audio = OrtPyFunction.from_customop(
+            "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data)
+
+        input_features = processor([raw_audio[0]], sampling_rate=16000)
+        expected = input_features['input_features'][0]
+
+        log_mel = OrtPyFunction.from_model(pre_m)(raw_audio)
+        actual = log_mel[0]
+
+        np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)
+        self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05)
+

 if __name__ == "__main__":
     unittest.main()
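The new large-v3 test is the motivating case: that checkpoint's feature extractor departs from the old hardcoded numbers (128 mel bins instead of 80), so it only passes once the pipeline reads the config. An end-to-end sketch using the same public API as the tests above:

# Sketch: exporting a config-driven pre-processing graph; mirrors the tests.
import numpy as np
from transformers import WhisperProcessor
from onnxruntime_extensions import gen_processing_models, OrtPyFunction

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
pre_m, _ = gen_processing_models(
    processor, pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True})

audio = np.zeros((1, 480000), dtype=np.float32)  # 30 s of silence at 16 kHz
log_mel = OrtPyFunction.from_model(pre_m)(audio)
print(log_mel.shape)  # (1, 128, 3000) for large-v3; (1, 80, 3000) for tiny.en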