More code fixes related to Whisper models (#403)
This commit is contained in:
Parent
26dda4eb74
Commit
997fa892c2
|
@ -92,10 +92,13 @@ class OrtPyFunction:
|
|||
def output_names(self):
|
||||
return [vi_.name for vi_ in self.onnx_model.graph.output]
|
||||
|
||||
def _bind(self, oxml):
|
||||
def _bind(self, oxml, model_path=None):
|
||||
self.inputs = list(oxml.graph.input)
|
||||
self.outputs = list(oxml.graph.output)
|
||||
self._oxml = oxml
|
||||
if model_path is not None:
|
||||
self.ort_session = _ort.InferenceSession(
|
||||
model_path, self.get_ort_session_options())
|
||||
return self
|
||||
|
||||
def _ensure_ort_session(self):
|
||||
|
@ -112,7 +115,13 @@ class OrtPyFunction:
|
|||
|
||||
@classmethod
|
||||
def from_model(cls, path_or_model, *args, **kwargs):
|
||||
return cls()._bind(onnx.load_model(path_or_model) if isinstance(path_or_model, str) else path_or_model)
|
||||
mpath = None
|
||||
if isinstance(path_or_model, str):
|
||||
oxml = onnx.load_model(path_or_model)
|
||||
mpath = path_or_model
|
||||
else:
|
||||
oxml = path_or_model
|
||||
return cls()._bind(oxml, mpath)
|
||||
|
||||
def _argument_map(self, *args, **kwargs):
|
||||
idx = 0
|
||||
|
|
|
@ -47,8 +47,7 @@ struct KernelAudioDecoder : public BaseKernel {
|
|||
if (pos == format_mapping.end()) {
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"[AudioDecoder]: Unknown audio stream format: ", str_format),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
ORT_INVALID_ARGUMENT); }
|
||||
stream_format = pos->second;
|
||||
}
|
||||
|
||||
|
@ -107,11 +106,11 @@ struct KernelAudioDecoder : public BaseKernel {
|
|||
std::list<std::vector<float>> lst_frames;
|
||||
|
||||
if (stream_format == AudioStreamType::kMP3) {
|
||||
drmp3 mp3_obj;
|
||||
if (!drmp3_init_memory(&mp3_obj, p_data, input_dim.Size(), nullptr)) {
|
||||
auto mp3_obj_ptr = std::make_unique<drmp3>();
|
||||
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input_dim.Size(), nullptr)) {
|
||||
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, mp3_obj);
|
||||
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
|
||||
|
||||
} else if (stream_format == AudioStreamType::kFLAC) {
|
||||
drflac* flac_obj = drflac_open_memory(p_data, input_dim.Size(), nullptr);
|
||||
|
|
|
@ -7,9 +7,9 @@ import numpy as np
|
|||
|
||||
from pathlib import Path
|
||||
from onnx import numpy_helper
|
||||
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
||||
from transformers import WhisperProcessor
|
||||
|
||||
from onnxruntime_extensions import PyOrtFunction, util, optimize_model
|
||||
from onnxruntime_extensions import PyOrtFunction, util
|
||||
from onnxruntime_extensions.cvt import HFTokenizerConverter
|
||||
|
||||
|
||||
|
@ -194,8 +194,10 @@ def merge_models(core: str, output_model:str, audio_data):
|
|||
make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
|
||||
bpe_decoder_node
|
||||
])
|
||||
onnx.save_model(m_all, output_model.replace(".onnx", "_.onnx"))
|
||||
optimize_model(m_all, output_model)
|
||||
onnx.save_model(m_all, output_model,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
convert_attribute=True)
|
||||
print(f"The final merged model was saved as: {output_model}")
|
||||
|
||||
print("Verify the final model...")
|
||||
|
|
Loading…
Reference in new issue