Update NV-ModelOPT INT4 quantization (#1441)

1. Update the ModelOpt integration in Olive
2. Add a Phi-3 example
3. Remove the old BERT example

## Checklist before requesting a review
- [x] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [x] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.
This commit is contained in:
anujj 2024-11-02 00:54:27 +05:30 committed by GitHub
Parent b73e0ac747
Commit 16ffab8314
No known key found for this signature
GPG key ID: B5690EEEBB952194
11 changed files with 431 additions and 159 deletions

View file

@@ -5,7 +5,6 @@ This folder contains examples of BERT optimization using different workflows.
- CPU: [Optimization with Intel® Neural Compressor PTQ](#bert-optimization-with-intel®-neural-compressor-ptq-on-cpu)
- CPU: [Optimization with QAT Customized Training Loop](#bert-optimization-with-qat-customized-training-loop-on-cpu)
- GPU: [Optimization with CUDA/TensorRT](#bert-optimization-with-cudatensorrt-on-gpu)
- GPU: [Optimization with TensorRT-Model-Optimizer](#bert-optimization-with-tensorRT-model-optimizer-on-cpugpu)
Go to [How to run](#how-to-run)
@@ -97,14 +96,6 @@ This workflow performs BERT optimization on GPU with CUDA/TensorRT. It performs
- *PyTorch Model -> Onnx Model -> ONNX Runtime performance tuning with trt_fp16_enable*
Config file: [bert_trt_gpu.json](bert_trt_gpu.json)
### BERT optimization with TensorRT-Model-Optimizer on CPU/GPU
This workflow performs BERT post training quantization (PTQ) on CPU/GPU with TensorRT-Model-Optimizer. It performs the optimization pipeline:
- *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model -> TensorRT-Model-Optimizer Quantized Onnx Model*
Deployment support for TensorRT-Model-Optimizer quantized models is coming soon in ORT, in the meantime try [TensorRT 10.x](https://github.com/NVIDIA/TensorRT/tree/v10.0.1).<br>
Config file: [bert_nvmo_ptq.json](bert_nvmo_ptq.json)
## How to run
### Pip requirements
Install the necessary python packages:

View file

@@ -1,22 +0,0 @@
{
"input_model": { "type": "HfModel", "model_path": "Intel/bert-base-uncased-mrpc", "task": "text-classification" },
"data_configs": [
{
"name": "rotten_tomatoes",
"user_script": "nv_user_script.py",
"load_dataset_config": { "data_name": "rotten_tomatoes", "split": "validation[:10%]" },
"dataloader_config": { "type": "nvmo_calibration_dataloader" },
"pre_process_data_config": { "type": "skip_pre_process" }
}
],
"passes": {
"conversion": { "type": "OnnxConversion", "target_opset": 17 },
"transformers_optimization": { "type": "OrtTransformersOptimization", "model_type": "bert", "opt_level": 0 },
"quantization": {
"type": "NVModelOptQuantization",
"precision": "int4",
"algorithm": "AWQ",
"data_config": "rotten_tomatoes"
}
}
}

View file

@@ -1,25 +0,0 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import torch
from datasets.utils import logging as datasets_logging # type: ignore[import]
from transformers import AutoTokenizer
from olive.data.registry import Registry
datasets_logging.disable_progress_bar()
datasets_logging.set_verbosity_error()
@Registry.register_dataloader("nvmo_calibration_dataloader")
def create_calibration_dataloader(dataset, batch_size, calib_size=64, **kwargs):
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def tokenization(example):
return tokenizer(example["text"], padding="max_length", max_length=128, truncation=True)
dataset = dataset.map(tokenization, batched=True)
dataset.set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask"])
return torch.utils.data.DataLoader(dataset.select(range(calib_size)), batch_size=batch_size, drop_last=True)

View file

@@ -109,6 +109,23 @@ python phi3.py --quarot
Get access to the following resources on Hugging Face Hub:
- [nampdn-ai/tiny-codes](https://huggingface.co/nampdn-ai/tiny-codes)
### Quantize the model using [TensorRT-Model-Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer)
Use the [onnxruntime-genai-directml](https://github.com/microsoft/onnxruntime-genai) package, version >= 0.4.0.
Setup:
```bash
pip install olive-ai[nvmo]
pip install onnxruntime-genai-directml>=0.4.0
pip install onnxruntime-directml
pip install -r requirements-nvmo-awq.txt
```
Install the CUDA version compatible with CuPy, as noted in requirements-nvmo-awq.txt.
For quantization, use the config file phi3_nvmo_ptq.json:
```bash
olive run --config phi3_nvmo_ptq.json
```
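The same workflow can also be launched from Python rather than the CLI; a minimal sketch, assuming `olive.workflows.run` is given the same JSON config from the current directory:
```python
# Sketch: run the Phi-3 NVMO PTQ workflow programmatically
# (equivalent to `olive run --config phi3_nvmo_ptq.json`).
from olive.workflows import run as olive_run

olive_run("phi3_nvmo_ptq.json")
```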
## More Inference Examples
- [Android chat APP with Phi-3 and ONNX Runtime Mobile](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/mobile/examples/phi-3/android)

View file

@@ -0,0 +1,22 @@
{
"input_model": {
"type": "HfModel",
"model_path": "microsoft/Phi-3-mini-4k-instruct",
"task": "text-classification"
},
"systems": {
"local_system": {
"type": "LocalSystem",
"accelerators": [ { "device": "gpu", "execution_providers": [ "DmlExecutionProvider" ] } ]
}
},
"passes": {
"builder": { "type": "ModelBuilder", "precision": "fp16" },
"quantization": {
"type": "NVModelOptQuantization",
"algorithm": "AWQ",
"tokenizer_dir": "microsoft/Phi-3-mini-4k-instruct",
"calibration": "awq_lite"
}
}
}
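For reference, the `quantization` entry above can also be exercised directly through the pass API; a minimal sketch mirroring the updated unit test at the end of this diff, where `onnx_model` is assumed to be an Olive ONNX model handler for the fp16 ModelBuilder output and the output folder is hypothetical:
```python
# Sketch only: create and run the NVModelOptQuantization pass outside a full workflow.
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.nvmo_quantization import NVModelOptQuantization

p = create_pass_from_dict(
    NVModelOptQuantization,
    {
        "algorithm": "AWQ",
        "tokenizer_dir": "microsoft/Phi-3-mini-4k-instruct",  # tokenizer used to build calibration inputs
        "calibration": "awq_lite",
    },
    disable_search=True,
)
quantized_model = p.run(onnx_model, "outputs/phi3-int4-awq")  # hypothetical output folder
```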

View file

@@ -0,0 +1,4 @@
cupy-cuda12x
datasets>=2.14.4
torch==2.4.0
transformers==4.44.0

View file

@@ -196,7 +196,7 @@ TEMPLATE = {
"bnb4": {"type": "OnnxBnb4Quantization", "quant_type": "nf4"},
"matmul4": {"type": "OnnxMatMul4Quantizer", "accuracy_level": 4},
"mnb_to_qdq": {"type": "MatMulNBitsToQDQ"},
"nvmo": {"type": "NVModelOptQuantization", "precision": "int4", "algorithm": "RTN"},
"nvmo": {"type": "NVModelOptQuantization", "precision": "int4", "algorithm": "AWQ"},
"onnx_dynamic": {"type": "OnnxDynamicQuantization", "weight_type": "QInt8"},
"inc_dynamic": {"type": "IncDynamicQuantization", "quant_level": "auto", "algorithm": "RTN"},
# NOTE(all): Not supported yet!
@@ -224,7 +224,7 @@ ALGORITHMS = {
"description": "(HfModel, OnnxModel) WOQ with GPTQ.",
},
"rtn": {
"implementations": ["quarot", "bnb4", "matmul4", "nvmo"],
"implementations": ["quarot", "bnb4", "matmul4"],
"hf_model_defaults": {"implementation": "quarot", "precision": "int16"},
"onnx_model_defaults": {"implementation": "onnx_static", "precision": "int8"},
"description": "(HfModel, OnnxModel) WOQ with RTN.",

View file

@@ -356,7 +356,7 @@
"gpu": [ "onnxruntime-gpu" ],
"inc": [ "neural-compressor" ],
"lora": [ "accelerate>=0.30.0", "peft", "scipy" ],
"nvmo": [ "nvidia-modelopt~=0.11.0", "onnx-graphsurgeon" ],
"nvmo": [ "nvidia-modelopt", "onnx-graphsurgeon", "datasets>=2.14.4" ],
"openvino": [ "openvino==2023.2.0", "nncf==2.7.0", "numpy<2.0" ],
"optimum": [ "optimum" ],
"ort-genai": [ "onnxruntime-genai" ],

View file

@@ -3,13 +3,17 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Union
from olive.common.config_utils import validate_config
import onnx
import torch
from onnx import helper
from onnx.onnx_pb import ModelProto
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer
from olive.common.utils import StrEnumBase
from olive.data.config import DataConfig
from olive.hardware.accelerator import AcceleratorSpec
from olive.model import OliveModelHandler
from olive.model.utils import resolve_onnx_path
@@ -21,16 +25,6 @@ from olive.strategy.search_parameter import Categorical
logger = logging.getLogger(__name__)
# static quantization specific config
_dataloader_config = {
"data_config": PassConfigParam(
type_=Union[DataConfig, Dict],
required=True,
description="Data config to load data for computing latency.",
),
}
class NVModelOptQuantization(Pass):
"""Quantize ONNX model with Nvidia-ModelOpt."""
@@ -40,9 +34,12 @@ class NVModelOptQuantization(Pass):
INT4 = "int4"
class Algorithm(StrEnumBase):
RTN = "RTN"
AWQ = "AWQ"
class Calibration(StrEnumBase):
AWQ_LITE = "awq_lite"
AWQ_CLIP = "awq_clip"
@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
return {
@@ -54,46 +51,360 @@
),
"algorithm": PassConfigParam(
type_=NVModelOptQuantization.Algorithm,
default_value="RTN",
search_defaults=Categorical(["RTN", "AWQ"]),
description="Algorithm of weight only quantization. Support 'RTN' and 'AWQ'.",
default_value="AWQ",
search_defaults=Categorical(["AWQ"]),
description="Algorithm of weight only quantization. Supports 'AWQ'.",
),
"calibration": PassConfigParam(
type_=NVModelOptQuantization.Calibration,
default_value="awq_clip",
search_defaults=Categorical(["awq_lite", "awq_clip"]),
description="Calibration method for weight only quantization. Supports 'awq_lite' and 'awq_clip'.",
),
"tokenizer_dir": PassConfigParam(
type_=str,
default_value="",
description="Tokenizer directory for calibration method.",
),
"random_calib_data": PassConfigParam(
type_=bool,
default_value=False,
description="Whether to use random calibration data instead of actual calibration data.",
),
**deepcopy(_dataloader_config),
}
def validate_search_point(
self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False
self,
search_point: Dict[str, Any],
accelerator_spec: AcceleratorSpec,
with_fixed_value: bool = False,
) -> bool:
if with_fixed_value:
search_point = self.config_at_search_point(search_point or {})
if search_point["precision"] != NVModelOptQuantization.Precision.INT4 or search_point["algorithm"] not in [
NVModelOptQuantization.Algorithm.RTN,
NVModelOptQuantization.Algorithm.AWQ,
]:
logger.error("Only INT4 quantization with RTN and AWQ algorithm is supported.")
# Validate Precision
if search_point.get("precision") != NVModelOptQuantization.Precision.INT4:
logger.error("Only INT4 quantization is supported.")
return False
# Validate Algorithm
if search_point.get("algorithm") not in [
NVModelOptQuantization.Algorithm.AWQ.value,
]:
logger.error("Only 'AWQ' algorithm is supported.")
return False
# Validate Calibration
if search_point.get("calibration") not in [
NVModelOptQuantization.Calibration.AWQ_LITE.value,
NVModelOptQuantization.Calibration.AWQ_CLIP.value,
]:
logger.error("Calibration method must be either 'awq_lite' or 'awq_clip'.")
return False
random_calib = search_point.get("random_calib_data", False)
if not isinstance(random_calib, bool):
logger.error("'random_calib_data' must be a boolean value.")
return False
tokenizer_dir = search_point.get("tokenizer_dir", "")
if not random_calib and not tokenizer_dir:
logger.error("'tokenizer_dir' must be specified when 'random_calib_data' is False.")
return False
# Optional: Validate 'tokenizer_dir' if necessary
if not search_point.get("tokenizer_dir"):
logger.warning("Tokenizer directory 'tokenizer_dir' is not specified.")
return True
def initialize_quant_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
# Check if 'tokenizer_dir' is provided and not empty
random_calib = config.get("random_calib_data", False)
if not random_calib:
# Prepare calibration inputs only if tokenizer_dir is specified
calib_inputs = self.get_calib_inputs(
dataset_name="cnn",
model_name=config["tokenizer_dir"],
cache_dir="./cache",
calib_size=32,
batch_size=1,
block_size=512,
device="cpu",
use_fp16=True,
use_buffer_share=False,
add_past_kv_inputs=True,
max_calib_rows_to_load=128,
add_position_ids=True,
)
else:
# If tokenizer_dir is empty, do not prepare calibration inputs
calib_inputs = None
logger.debug("No tokenizer directory specified. Skipping calibration input preparation.")
# Return a dictionary containing necessary configuration for quantization
return {
"algorithm": config.get("algorithm", self.Algorithm.AWQ.value),
"precision": config.get("precision", self.Precision.INT4.value),
"calibration_method": config.get("calibration", self.Calibration.AWQ_CLIP.value),
"tokenizer_dir": config.get("tokenizer_dir", ""),
"calibration_data_reader": calib_inputs,
}
def make_model_input(
self,
config,
input_ids_arg,
attention_mask_arg,
add_past_kv_inputs,
device,
use_fp16,
use_buffer_share,
add_position_ids,
):
input_ids = input_ids_arg
attention_mask = attention_mask_arg
if isinstance(input_ids_arg, list):
input_ids = torch.tensor(input_ids_arg, device=device, dtype=torch.int64)
attention_mask = torch.tensor(attention_mask_arg, device=device, dtype=torch.int64)
inputs = {
"input_ids": input_ids.contiguous(),
"attention_mask": attention_mask.contiguous(),
}
if add_position_ids:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
inputs["position_ids"] = position_ids.contiguous()
if add_past_kv_inputs:
torch_dtype = torch.float16 if use_fp16 else torch.float32
batch_size, _ = input_ids.shape
max_sequence_length = config.max_position_embeddings
num_heads, head_size = (
config.num_key_value_heads,
config.hidden_size // config.num_attention_heads,
)
for i in range(config.num_hidden_layers):
past_key = torch.zeros(
batch_size,
num_heads,
max_sequence_length if use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
)
past_value = torch.zeros(
batch_size,
num_heads,
max_sequence_length if use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
)
inputs.update(
{
f"past_key_values.{i}.key": past_key.contiguous(),
f"past_key_values.{i}.value": past_value.contiguous(),
}
)
return inputs
def get_calib_inputs(
self,
dataset_name,
model_name,
cache_dir,
calib_size,
batch_size,
block_size,
device,
use_fp16,
use_buffer_share,
add_past_kv_inputs,
max_calib_rows_to_load,
add_position_ids,
):
# Access transformers and datasets from the instance variables
config = AutoConfig.from_pretrained(
model_name, use_auth_token=True, cache_dir=cache_dir, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
model_name, use_auth_token=True, cache_dir=cache_dir, trust_remote_code=True
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.pad_token = tokenizer.eos_token
assert calib_size <= max_calib_rows_to_load, "calib size should be no more than max_calib_rows_to_load"
from datasets import load_dataset
if "cnn" in dataset_name:
dataset2 = load_dataset("cnn_dailymail", name="3.0.0", split="train").select(range(max_calib_rows_to_load))
column = "article"
elif "pile" in dataset_name:
dataset2 = load_dataset("mit-han-lab/pile-val-backup", split="validation")
column = "text"
else:
raise ValueError(f'dataset "{dataset_name}" not supported')
dataset2 = dataset2[column][:calib_size]
batch_encoded = tokenizer.batch_encode_plus(
dataset2, return_tensors="pt", padding=True, truncation=True, max_length=block_size
)
batch_encoded = batch_encoded.to(device)
batch_encoded_input_ids = batch_encoded["input_ids"]
batch_encoded_attention_mask = batch_encoded["attention_mask"]
calib_dataloader_input_ids = DataLoader(batch_encoded_input_ids, batch_size=batch_size, shuffle=False)
calib_dataloader_attention_mask = DataLoader(batch_encoded_attention_mask, batch_size=batch_size, shuffle=False)
if len(calib_dataloader_input_ids.dataset) != len(calib_dataloader_attention_mask.dataset):
raise ValueError(
f"Mismatch in dataset len: calib_dataloader_input_ids has {len(calib_dataloader_input_ids.dataset)} "
f"items, calib_dataloader_attention_mask has {len(calib_dataloader_attention_mask.dataset)} items."
)
if len(calib_dataloader_input_ids) != len(calib_dataloader_attention_mask):
raise ValueError(
f"Mismatch in dataloader lengths: calib_dataloader_input_ids has {len(calib_dataloader_input_ids)} "
f"items, while calib_dataloader_attention_mask has {len(calib_dataloader_attention_mask)} items."
)
number_of_batched_samples = calib_size // batch_size
batched_input_ids = []
for idx, data in enumerate(calib_dataloader_input_ids):
batched_input_ids.append(data)
if idx == (number_of_batched_samples - 1):
break
batched_attention_mask = []
for idx, data in enumerate(calib_dataloader_attention_mask):
batched_attention_mask.append(data)
if idx == (number_of_batched_samples - 1):
break
batched_inputs_list = []
for i in range(number_of_batched_samples):
input_ids = batched_input_ids[i]
attention_mask = batched_attention_mask[i]
inputs = self.make_model_input(
config,
input_ids,
attention_mask,
add_past_kv_inputs,
device,
use_fp16,
use_buffer_share,
add_position_ids,
)
inputs = {input_name: torch_tensor.cpu().numpy() for input_name, torch_tensor in inputs.items()}
batched_inputs_list.append(inputs)
return batched_inputs_list
def quantize_awq(self, model: Union[ModelProto, str], quant_config: Dict[str, Any]) -> ModelProto:
"""Perform nvidia_awq quantization using ModelOpt's int4 quantize function.
Args:
model (ModelProto | str): The ONNX model or path to the model to quantize.
quant_config (Dict[str, Any]): Configuration dictionary for quantization.
Returns:
ModelProto: The quantized ONNX model.
"""
try:
from modelopt.onnx.quantization.int4 import quantize as quantize_int4
except ImportError:
logger.exception(
"Please ensure that 'modelopt' package is installed. Install it with 'pip install nvidia_modelopt'."
)
raise ImportError(
"modelopt is not installed. Please install it using 'pip install nvidia_modelopt'. Exiting."
) from None
logger.debug("Starting nvidia_awq quantization...")
# Prepare calibration inputs
calib_inputs = quant_config["calibration_data_reader"]
# Perform quantization using ModelOpt's int4 quantize function
quantized_model = quantize_int4(
model,
calibration_method=quant_config["calibration_method"],
calibration_data_reader=calib_inputs,
)
logger.debug("Completed nvidia_awq quantization.")
return quantized_model
def convert_opset_to_21_proto(self, model_proto: ModelProto) -> ModelProto:
"""Modify the model's opset to 21 if it's not already, operating on a ModelProto.
Args:
model_proto (ModelProto): The ONNX model proto to modify.
Returns:
ModelProto: The updated ONNX model proto with opset version 21.
"""
current_opset = {opset.domain: opset.version for opset in model_proto.opset_import}
default_domain_version = current_opset.get("", 0)
if default_domain_version >= 21:
logger.debug(
"Model already uses opset version %s for the default domain. Skip conversion.", default_domain_version
)
return model_proto # No conversion needed
new_opset_imports = [
helper.make_opsetid("", 21), # Default domain with opset version 21
helper.make_opsetid("com.microsoft", 1), # Microsoft domain with version 1
]
for domain, version in current_opset.items():
if domain not in ["", "com.microsoft"]:
new_opset_imports.append(helper.make_opsetid(domain, version))
# Update the model's opset imports
model_proto.ClearField("opset_import")
model_proto.opset_import.extend(new_opset_imports)
logger.debug("Model opset successfully converted to 21.")
return model_proto
def _run_for_config(
self, model: OliveModelHandler, config: Dict[str, Any], output_model_path: str
) -> OliveModelHandler:
try:
from modelopt.onnx.quantization.int4 import quantize_int4 # type: ignore[import]
except ImportError as exc:
raise ImportError(
"Please install `olive-ai[nvmo]` or `nvidia-modelopt[onnx]` to use INT4 AWQ quantization!"
) from exc
logger.debug("Loading the original ONNX model from %s.", model.model_path)
quant_config = self.initialize_quant_config(config)
data_config = validate_config(config["data_config"], DataConfig)
calib_dataloader = data_config.to_data_container().create_dataloader()
# Perform quantization
quantized_model_proto = self.quantize_awq(
model=model.model_path,
quant_config=quant_config,
)
quantize_mode = (
"int4_awq_clip" if config["algorithm"] == NVModelOptQuantization.Algorithm.AWQ else "int4_rtn_dq"
)
q_model = quantize_int4(quantize_mode, model.load_model(), calib_dataloader)
# Convert opset to 21 if required
converted_model_proto = self.convert_opset_to_21_proto(quantized_model_proto)
# save the model to the output path and return the model
output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)
return model_proto_to_olive_model(q_model, output_model_path, config)
output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)
onnx.save(converted_model_proto, output_model_path)
logger.debug("Quantized and opset-converted model saved to %s", output_model_path)
return model_proto_to_olive_model(converted_model_proto, output_model_path, config)
except Exception:
logger.exception("An error occurred during quantization and opset conversion")
raise
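For context, `quantize_awq` above is a thin wrapper around ModelOpt's ONNX INT4 entry point; a minimal sketch of the underlying call, assuming `nvidia-modelopt` is installed and `calib_inputs` is the list of batched `{input_name: numpy array}` feeds produced by `get_calib_inputs` (or `None` when `random_calib_data` is used):
```python
# Sketch of the ModelOpt call wrapped by quantize_awq in the pass above.
from modelopt.onnx.quantization.int4 import quantize as quantize_int4

quantized_proto = quantize_int4(
    "model.onnx",                       # ONNX model path, as passed in by the pass
    calibration_method="awq_lite",      # or "awq_clip"
    calibration_data_reader=calib_inputs,
)
```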

View file

@@ -8,6 +8,7 @@ azureml-fsspec
# Pin azureml-metrics[all] greater than 0.0.26 to avoid breaking change in azureml-evaluate-mlflow
azureml-metrics[all]>=0.0.26
coverage
cppimport
datasets
docker>=7.1.0
evaluate
@@ -15,9 +16,10 @@ git+https://github.com/microsoft/TransformerCompression.git ; python_version >=
mlflow>=2.4.0
neural-compressor
nncf==2.7.0
nvidia-modelopt~=0.11.0
nvidia-modelopt
onnx-graphsurgeon
onnxconverter_common
onnxmltools
onnxruntime_extensions
openvino==2023.2.0
optimum>=1.17.0

View file

@@ -1,61 +1,33 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from pathlib import Path
from test.unit_test.utils import get_onnx_model, get_pytorch_model_dummy_input
from typing import Any, Dict, Optional
from onnxruntime.quantization.calibrate import CalibrationDataReader # type: ignore[import]
from olive.data.config import DataComponentConfig, DataConfig
from olive.data.registry import Registry
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.nvmo_quantization import NVModelOptQuantization
class DummyCalibrationDataReader(CalibrationDataReader):
def __init__(self, batch_size: int = 16):
super().__init__()
self.sample_counter = 64
def get_next(self) -> Optional[Dict[Any, Any]]:
if self.sample_counter <= 0:
return None
data = get_pytorch_model_dummy_input()
try:
item = {"input": data.numpy()}
self.sample_counter -= 1
return item
except Exception:
return None
@Registry.register_dataloader()
def _test_nvmo_quat_dataloader(dataset, batch_size, **kwargs):
return DummyCalibrationDataReader(batch_size=batch_size)
def test_nvmo_quantization(tmp_path):
ov_model = get_onnx_model()
data_dir = tmp_path / "data"
data_dir.mkdir(exist_ok=True)
config = {
"data_config": DataConfig(
name="test_nvmo_quant_dc_config",
load_dataset_config=DataComponentConfig(type="simple_dataset", params={"data_dir": str(data_dir)}),
dataloader_config=DataComponentConfig(type="_test_nvmo_quat_dataloader"),
)
}
output_folder = str(tmp_path / "quantized")
# create NVModelOptQuantization pass and run quantization
p = create_pass_from_dict(NVModelOptQuantization, config, disable_search=True)
quantized_model = p.run(ov_model, output_folder)
# assert
assert quantized_model.model_path.endswith(".onnx")
assert Path(quantized_model.model_path).exists()
assert Path(quantized_model.model_path).is_file()
assert "DequantizeLinear" in [node.op_type for node in quantized_model.load_model().graph.node]
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from pathlib import Path
from test.unit_test.utils import get_onnx_model
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.nvmo_quantization import NVModelOptQuantization
def test_nvmo_quantization(tmp_path):
ov_model = get_onnx_model()
data_dir = tmp_path / "data"
data_dir.mkdir(exist_ok=True)
# Configuration with default values and random_calib_data set to True
config = {
"calibration": "awq_lite",
"random_calib_data": True,
}
output_folder = str(tmp_path / "quantized")
# Create NVModelOptQuantization pass and run quantization
p = create_pass_from_dict(NVModelOptQuantization, config, disable_search=True)
quantized_model = p.run(ov_model, output_folder)
# Assertions to check if quantization was successful
assert quantized_model.model_path.endswith(".onnx")
assert Path(quantized_model.model_path).exists()
assert Path(quantized_model.model_path).is_file()
assert "DequantizeLinear" in [node.op_type for node in quantized_model.load_model().graph.node]