Basic DirectML latency example (#148)

This change adds an example of optimizing the SqueezeNet model (a simple
"hello world" style network) for latency with DirectML. It's nice to
have a basic sample that runs quickly and doesn't involve huge models
with external weights.

Other changes:
- Introduces an ONNX pass for basic float-to-float16 conversion outside
of the standard model conversion and transformer-specific optimizer
passes. It sometimes matters where/when this conversion occurs, but for
simple GPU inference on non-transformer models it is handy to have as a
standalone pass (see the trimmed config snippet after this list).
- Updates the OrtPerfTuning pass to filter execution providers (EPs)
based on the configured device. Previously, the pass would create tuning
combos for _all_ EPs available in the user's onnxruntime package.
- Adds DirectML references to the docs.
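
For reference, here is a trimmed excerpt of the SqueezeNet workflow config added in this change, showing how the new float16 pass and the device-aware perf tuning are wired up (other fields omitted; see the full squeezenet_config.json below):

```json
"passes": {
    "float16_conversion": {
        "type": "OnnxFloatToFloat16"
    },
    "perf_tuning": {
        "type": "OrtPerfTuning",
        "config": {
            "device": "gpu",
            "providers_list": [ "DmlExecutionProvider" ]
        }
    }
}
```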
This commit is contained in:
Justin Stoecker 2023-04-03 14:14:26 -07:00 committed by GitHub
Parent 47598cab6b
Commit b311be67e1
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 215 additions and 5 deletions


@@ -44,6 +44,10 @@ With onnxruntime-gpu:
```
pip install olive-ai[gpu]
```
With onnxruntime-directml:
```
pip install olive-ai[directml]
```
### Optional Dependencies
Olive has optional dependencies that can be installed to enable additional features. These dependencies can be installed as extras:


@@ -12,3 +12,6 @@
# ResNet optimization with QAT PyTorch Lightning Module on CPU
[resnet_qat_lightning_module_cpu](https://github.com/microsoft/Olive/tree/main/examples/quantization_aware_training/resnet_qat_lightning_module_cpu)
# SqueezeNet latency optimization with DirectML
[directml/squeezenet](https://github.com/microsoft/Olive/tree/main/examples/directml/squeezenet)


@@ -23,6 +23,10 @@ With onnxruntime-gpu:
```
pip install olive-ai[gpu]
```
With onnxruntime-directml:
```
pip install olive-ai[directml]
```
## Install from source
Install the latest `main` version of Olive from source. Please note that this is a development version and may not be stable.
@@ -40,6 +44,11 @@ With onnxruntime-gpu:
```
pip install git+https://github.com/microsoft/Olive#egg=olive-ai[gpu]
```
With onnxruntime-directml:
```
pip install git+https://github.com/microsoft/Olive#egg=olive-ai[directml]
```
## Editable install


@@ -13,4 +13,5 @@ dependencies:
- datasets
- scipy
- transformers
- onnxconverter_common
- olive-ai==0.1.0


@@ -5,6 +5,7 @@ RUN pip install azure-ai-ml \
onnxruntime \
datasets \
transformers \
onnxconverter_common \
olive-ai==0.1.0
WORKDIR /olive


@@ -0,0 +1,20 @@
# SqueezeNet Latency Optimization with DirectML
This folder contains a sample use case of Olive that optimizes the [SqueezeNet](https://pytorch.org/hub/pytorch_vision_squeezenet/) model using ONNX conversion, float16 conversion, and general ONNX performance tuning.

The optimization pipeline is:

PyTorch Model -> [Convert to ONNX] -> [FP16 Conversion] -> [Tune performance] -> Optimized FP16 ONNX Model

The workflow outputs the best metrics, the best model, and the corresponding Olive config.
## Optimize SqueezeNet
```
python -m olive.workflows.run --config squeezenet_config.json
```
or run it directly from Python:
```python
from olive.workflows import run as olive_run
olive_run("squeezenet_config.json")
```
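
Once the workflow finishes, the optimized FP16 model can be run with the DirectML execution provider from onnxruntime-directml. Below is a minimal sketch; the model path is hypothetical (use the output path reported by Olive), and the float16 input follows from the config leaving `keep_io_types` at its default of false:

```python
import numpy as np
import onnxruntime as ort

# The DirectML EP does not support ONNX Runtime's memory pattern optimization.
sess_options = ort.SessionOptions()
sess_options.enable_mem_pattern = False

# Hypothetical path to the optimized model produced by this workflow.
session = ort.InferenceSession(
    "optimized_squeezenet_fp16.onnx",
    sess_options,
    providers=["DmlExecutionProvider"],
)

# The config names the model input "input_image" with shape (1, 3, 224, 224).
image = np.random.rand(1, 3, 224, 224).astype(np.float16)
outputs = session.run(None, {"input_image": image})
print(outputs[0].shape)  # SqueezeNet class logits, e.g. (1, 1000)
```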


@@ -0,0 +1,72 @@
{
    "verbose": true,
    "input_model": {
        "type": "PyTorchModel",
        "config": {
            "model_path": null,
            "is_file": false,
            "model_loader": "load_pytorch_origin_model",
            "model_script": "user_script.py"
        }
    },
    "systems": {
        "local_system": {
            "type": "LocalSystem",
            "config": {
                "device": "gpu"
            }
        }
    },
    "evaluators": {
        "common_evaluator": {
            "metrics": [
                {
                    "name": "latency",
                    "type": "latency",
                    "sub_type": "avg",
                    "user_config": {
                        "user_script": "user_script.py",
                        "dataloader_func": "create_dataloader",
                        "batch_size": 1
                    }
                }
            ],
            "target": "local_system"
        }
    },
    "passes": {
        "torch_to_onnx": {
            "type": "OnnxConversion",
            "config": {
                "input_names": [ "input_image" ],
                "input_shapes": [ [ 1, 3, 224, 224 ] ],
                "output_names": [ "output" ],
                "target_opset": 13
            }
        },
        "float16_conversion": {
            "type": "OnnxFloatToFloat16"
        },
        "perf_tuning": {
            "type": "OrtPerfTuning",
            "config": {
                "user_script": "user_script.py",
                "dataloader_func": "create_dataloader",
                "device": "gpu",
                "batch_size": 1,
                "execution_mode_list": [ "ORT_SEQUENTIAL" ],
                "providers_list": [ "DmlExecutionProvider" ]
            }
        }
    },
    "engine": {
        "search_strategy": {
            "execution_order": "joint",
            "search_algorithm": "exhaustive"
        },
        "evaluator": "common_evaluator",
        "host": "local_system",
        "clean_cache": true,
        "cache_dir": "cache"
    }
}


@@ -0,0 +1,23 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import torch


def load_pytorch_origin_model(torch_hub_model_path):
    return torch.hub.load("pytorch/vision:v0.10.0", "squeezenet1_1", pretrained=True)


class DataLoader:
    def __init__(self, batchsize):
        self.batchsize = batchsize

    def __getitem__(self, idx):
        input_data = torch.rand((self.batchsize, 3, 224, 224), dtype=torch.float16)
        label = None
        return input_data, label


def create_dataloader(data_dir, batchsize):
    return DataLoader(batchsize)


@@ -17,4 +17,5 @@ dependencies:
- onnx
- scipy
- onnxruntime
- onnxconverter_common
- olive-ai==0.1.0


@@ -112,6 +112,7 @@ class ONNXModel(OliveModel):
    EXECUTION_PROVIDERS = {
        "cpu": ["CPUExecutionProvider", "OpenVINOExecutionProvider"],
        "gpu": [
            "DmlExecutionProvider",
            "CUDAExecutionProvider",
            "OpenVINOExecutionProvider",
            "TensorrtExecutionProvider",
@@ -171,11 +172,15 @@ class ONNXModel(OliveModel):
        if not execution_provider:
            execution_provider = self.get_execution_providers(device)
        elif isinstance(execution_provider, tuple):
            execution_provider = execution_provider
        elif isinstance(execution_provider, list):
            # execution_provider may be a list of tuples where the first item in each tuple is the EP name
            execution_provider = [i[0] if isinstance(i, tuple) else i for i in execution_provider]
        elif isinstance(execution_provider, str):
            execution_provider = [execution_provider]

        if len(execution_provider) >= 1 and execution_provider[0] == "DmlExecutionProvider":
            sess_options.enable_mem_pattern = False

        return ort.InferenceSession(self.model_path, sess_options, providers=execution_provider)

    def to_json(self, check_object: bool = False):


@@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from olive.passes.onnx.conversion import OnnxConversion
from olive.passes.onnx.float16_conversion import OnnxFloatToFloat16
from olive.passes.onnx.model_optimizer import OnnxModelOptimizer
from olive.passes.onnx.perf_tuning import OrtPerfTuning
from olive.passes.onnx.quantization import OnnxDynamicQuantization, OnnxQuantization, OnnxStaticQuantization
@@ -16,4 +17,5 @@ __all__ = [
    "OrtPerfTuning",
    "OrtTransformersOptimization",
    "OnnxModelOptimizer",
    "OnnxFloatToFloat16",
]


@@ -0,0 +1,63 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Dict, List

import onnx
from onnxconverter_common import float16

from olive.model import ONNXModel
from olive.passes import Pass
from olive.passes.pass_config import PassConfigParam


class OnnxFloatToFloat16(Pass):
    """Converts a model to float16.

    It is based on onnxconverter-common.convert_float_to_float16.
    See https://onnxruntime.ai/docs/performance/model-optimizations/float16.html#float16-conversion
    """

    @staticmethod
    def _default_config() -> Dict[str, PassConfigParam]:
        return {
            "min_positive_val": PassConfigParam(
                type_=float, default=1e-7, description=("Constant values will be clipped against this value")
            ),
            "max_finite_val": PassConfigParam(
                type_=float, default=1e4, description=("Constant values will be clipped against this value")
            ),
            "keep_io_types": PassConfigParam(
                type_=bool, default=False, description=("Whether model inputs/outputs should be left as float32")
            ),
            "disable_shape_infer": PassConfigParam(
                type_=bool, default=False, description=("Skips running onnx shape/type inference.")
            ),
            "op_block_list": PassConfigParam(
                type_=List[str], default=[], description=("List of op types to leave as float32")
            ),
            "node_block_list": PassConfigParam(
                type_=List[str], default=[], description=("List of node names to leave as float32")
            ),
        }

    def _run_for_config(self, model: ONNXModel, config: Dict[str, Any], output_model_path: str) -> ONNXModel:
        if Path(output_model_path).suffix != ".onnx":
            output_model_path += ".onnx"

        config = self._config_class(**config)
        model_fp32 = onnx.load(str(model.model_path))
        model_fp16 = float16.convert_float_to_float16(
            model_fp32,
            min_positive_val=config.min_positive_val,
            max_finite_val=config.max_finite_val,
            keep_io_types=config.keep_io_types,
            disable_shape_infer=config.disable_shape_infer,
            op_block_list=config.op_block_list,
            node_block_list=config.node_block_list,
        )
        onnx.save(model_fp16, output_model_path)
        return ONNXModel(output_model_path, model.name)


@@ -20,8 +20,8 @@ from olive.passes.pass_config import PassConfigParam
logger = logging.getLogger(__name__)

-def generate_tuning_combos(config):
-    providers_list = config.providers_list if config.providers_list else ort.get_available_providers()
+def generate_tuning_combos(model, config):
+    providers_list = config.providers_list if config.providers_list else model.get_execution_providers(config.device)
    execution_mode_list = (
        config.execution_mode_list
        if config.execution_mode_list
@@ -51,7 +51,7 @@ def tune_onnx_model(model, config):
    pretuning_inference_result = get_benchmark(model, latency_metric, config)

    tuning_results = []
-    for tuning_combo in generate_tuning_combos(config):
+    for tuning_combo in generate_tuning_combos(model, config):
        logger.info("Run tuning for: {}".format(tuning_combo))
        if tuning_combo[0] == "CPUExecutionProvider" and tuning_combo[3]:
            continue


@@ -12,6 +12,7 @@ RUN pip install azure-ai-ml \
openvino \
openvino-dev[tensorflow,onnx] \
tensorflow \
onnxconverter_common \
olive-ai==0.1.0
ADD requirements.txt requirements.txt


@@ -1,5 +1,6 @@
numpy <= 1.23.4
onnx
onnxconverter_common
optuna
pandas
pydantic


@@ -29,6 +29,7 @@ EXTRAS = {
"docker": ["docker"],
"cpu": ["onnxruntime"],
"gpu": ["onnxruntime-gpu"],
"directml": ["onnxruntime-directml"],
"openvino": ["openvino==2022.3.0", "openvino-dev[tensorflow,onnx]==2022.3.0"],
"tf": ["tensorflow==1.15.0"],
}


@@ -13,4 +13,5 @@ dependencies:
- datasets
- scipy
- transformers
- onnxconverter_common
- olive-ai==0.1.0


@@ -11,4 +11,5 @@ dependencies:
- datasets
- transformers
- torchvision
- onnxconverter_common
- olive-ai==0.1.0


@@ -10,6 +10,7 @@ RUN pip install onnxruntime \
transformers \
openvino \
openvino-dev \
onnxconverter_common \
olive-ai==0.1.0
WORKDIR /olive