зеркало из https://github.com/microsoft/archai.git
fix(onnx_loader): Uses default providers if providers are not supplied.
This commit is contained in:
Родитель
36abe8851f
Коммит
bec6cba801
|
@ -48,6 +48,10 @@ class TransformerFlexOnnxLatency(ModelEvaluator):
|
|||
) -> None:
|
||||
"""Initialize the evaluator.
|
||||
|
||||
This evaluator supports measuring in different ONNX Runtime providers. For measuring on
|
||||
GPUs, use `providers=["CUDAExecutionProvider"]` and make sure that `onnxruntime-gpu`
|
||||
package is installed.
|
||||
|
||||
Args:
|
||||
search_space: The search space to use for loading the model.
|
||||
providers: The list of ORT providers to use for benchmarking.
|
||||
|
|
|
@ -37,6 +37,8 @@ def load_from_onnx(onnx_model_path: str, providers: Optional[List[str]] = None)
|
|||
options.intra_op_num_threads = OMP_NUM_THREADS
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
|
||||
providers = providers or ["CPUExecutionProvider"]
|
||||
|
||||
session = InferenceSession(onnx_model_path, sess_options=options, providers=providers)
|
||||
session.disable_fallback()
|
||||
|
||||
|
|
|
@ -107,7 +107,7 @@ if __name__ == "__main__":
|
|||
)
|
||||
search_objectives.add_objective(
|
||||
"onnx_latency",
|
||||
TransformerFlexOnnxLatency(space, providers=["CPUExecutionProvider"], seq_len=1024, n_trials=5, use_past=False),
|
||||
TransformerFlexOnnxLatency(space, seq_len=1024, n_trials=5, use_past=False),
|
||||
higher_is_better=False,
|
||||
compute_intensive=False,
|
||||
)
|
||||
|
|
Загрузка…
Ссылка в новой задаче