fix(onnx_loader): Use default provider if providers are not supplied.

This commit is contained in:
Gustavo Rosa 2023-02-01 15:26:35 -03:00
Родитель 36abe8851f
Коммит bec6cba801
3 изменённых файлов: 7 добавлений и 1 удалений

Просмотреть файл

@ -48,6 +48,10 @@ class TransformerFlexOnnxLatency(ModelEvaluator):
) -> None:
"""Initialize the evaluator.
This evaluator supports measuring in different ONNX Runtime providers. For measuring on
GPUs, use `providers=["CUDAExecutionProvider"]` and make sure that `onnxruntime-gpu`
package is installed.
Args:
search_space: The search space to use for loading the model.
providers: The list of ORT providers to use for benchmarking.

Просмотреть файл

@ -37,6 +37,8 @@ def load_from_onnx(onnx_model_path: str, providers: Optional[List[str]] = None)
options.intra_op_num_threads = OMP_NUM_THREADS
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
providers = providers or ["CPUExecutionProvider"]
session = InferenceSession(onnx_model_path, sess_options=options, providers=providers)
session.disable_fallback()

Просмотреть файл

@ -107,7 +107,7 @@ if __name__ == "__main__":
)
search_objectives.add_objective(
"onnx_latency",
TransformerFlexOnnxLatency(space, providers=["CPUExecutionProvider"], seq_len=1024, n_trials=5, use_past=False),
TransformerFlexOnnxLatency(space, seq_len=1024, n_trials=5, use_past=False),
higher_is_better=False,
compute_intensive=False,
)