more code fixing related whisper models (#403)

This commit is contained in:
Wenbing Li 2023-04-21 09:26:44 -07:00 коммит произвёл GitHub
Родитель 26dda4eb74
Коммит 997fa892c2
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 21 добавлений и 11 удалений

Просмотреть файл

@ -92,10 +92,13 @@ class OrtPyFunction:
def output_names(self):
return [vi_.name for vi_ in self.onnx_model.graph.output]
def _bind(self, oxml):
def _bind(self, oxml, model_path=None):
self.inputs = list(oxml.graph.input)
self.outputs = list(oxml.graph.output)
self._oxml = oxml
if model_path is not None:
self.ort_session = _ort.InferenceSession(
model_path, self.get_ort_session_options())
return self
def _ensure_ort_session(self):
@ -112,7 +115,13 @@ class OrtPyFunction:
@classmethod
def from_model(cls, path_or_model, *args, **kwargs):
return cls()._bind(onnx.load_model(path_or_model) if isinstance(path_or_model, str) else path_or_model)
mpath = None
if isinstance(path_or_model, str):
oxml = onnx.load_model(path_or_model)
mpath = path_or_model
else:
oxml = path_or_model
return cls()._bind(oxml, mpath)
def _argument_map(self, *args, **kwargs):
idx = 0

Просмотреть файл

@ -47,8 +47,7 @@ struct KernelAudioDecoder : public BaseKernel {
if (pos == format_mapping.end()) {
ORTX_CXX_API_THROW(MakeString(
"[AudioDecoder]: Unknown audio stream format: ", str_format),
ORT_INVALID_ARGUMENT);
}
ORT_INVALID_ARGUMENT); }
stream_format = pos->second;
}
@ -107,11 +106,11 @@ struct KernelAudioDecoder : public BaseKernel {
std::list<std::vector<float>> lst_frames;
if (stream_format == AudioStreamType::kMP3) {
drmp3 mp3_obj;
if (!drmp3_init_memory(&mp3_obj, p_data, input_dim.Size(), nullptr)) {
auto mp3_obj_ptr = std::make_unique<drmp3>();
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input_dim.Size(), nullptr)) {
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION);
}
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, mp3_obj);
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
} else if (stream_format == AudioStreamType::kFLAC) {
drflac* flac_obj = drflac_open_memory(p_data, input_dim.Size(), nullptr);

Просмотреть файл

@ -7,9 +7,9 @@ import numpy as np
from pathlib import Path
from onnx import numpy_helper
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import WhisperProcessor
from onnxruntime_extensions import PyOrtFunction, util, optimize_model
from onnxruntime_extensions import PyOrtFunction, util
from onnxruntime_extensions.cvt import HFTokenizerConverter
@ -194,8 +194,10 @@ def merge_models(core: str, output_model:str, audio_data):
make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
bpe_decoder_node
])
onnx.save_model(m_all, output_model.replace(".onnx", "_.onnx"))
optimize_model(m_all, output_model)
onnx.save_model(m_all, output_model,
save_as_external_data=True,
all_tensors_to_one_file=True,
convert_attribute=True)
print(f"The final merged model was saved as: {output_model}")
print("Verify the final model...")