HfModelHandler: Use `optimum` for automatic `io_config` and `dummy_inputs` (#1317)

Jambay Kinley 2024-08-15 22:44:41 -07:00, committed by GitHub
Parent a1c351ae92
Commit 2f592507dc
No key found matching this signature
GPG key ID: B5690EEEBB952194
31 changed files with 274 additions and 525 deletions

View file

@ -56,7 +56,7 @@ Olive can automatically retrieve model configurations from Huggingface hub:
- Olive retrieves model [configuration](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoConfig) from transformers for future usage.
- Olive simplifies the process by automatically fetching configurations such as IO config and dummy input required for the `OnnxConversion` pass from [OnnxConfig](https://huggingface.co/docs/transformers/main_classes/onnx#onnx-configurations). This means there's no need for you to manually specify the IO config when using the `OnnxConversion` pass.
- Olive simplifies the process by automatically fetching configurations such as IO config and dummy input required for the `OnnxConversion` pass from [OnnxConfig](https://huggingface.co/docs/optimum/main/en/exporters/onnx/package_reference/configuration#optimum.exporters.onnx.OnnxConfig) if `optimum` is installed and the `model_type` and `task` are supported. This means there's no need for you to manually specify the IO config when using the `OnnxConversion` pass.
You can also provide your own IO config which will override the automatically fetched IO config and dummy inputs:
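A minimal sketch of such an override, written here as a Python dict that mirrors the JSON `input_model` entries removed elsewhere in this change; the shapes, dtypes, and axis names are illustrative only:

# Illustrative only: an explicit io_config overrides whatever Olive would
# otherwise derive from optimum's OnnxConfig for this model/task.
input_model = {
    "type": "HfModel",
    "model_path": "meta-llama/Llama-2-7b-hf",
    "io_config": {
        "input_names": ["input_ids", "attention_mask", "position_ids"],
        "output_names": ["logits"],
        "input_shapes": [[2, 8], [2, 8], [2, 8]],
        "input_types": ["int64", "int64", "int64"],
        "dynamic_axes": {
            "input_ids": {"0": "batch_size", "1": "sequence_length"},
            "attention_mask": {"0": "batch_size", "1": "total_sequence_length"},
            "position_ids": {"0": "batch_size", "1": "sequence_length"},
        },
    },
}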

View file

@ -7,6 +7,7 @@ dependencies:
- pip:
- datasets
- evaluate
- optimum
- psutil
- scipy
- scikit-learn

View file

@ -7,6 +7,7 @@ dependencies:
- pip:
- datasets
- evaluate
- optimum
- psutil
- scipy
- scikit-learn

View file

@ -19,6 +19,7 @@ RUN pip install pandas \
psutil \
datasets \
transformers \
optimum \
onnxruntime-openvino \
"numpy<2.0" \
evaluate \

View file

@ -4,6 +4,7 @@ datasets
docker>=7.1.0
evaluate
neural-compressor
optimum
scikit-learn
scipy
tabulate

View file

@ -2,17 +2,6 @@
"input_model": {
"type": "HfModel",
"load_kwargs": { "attn_implementation": "eager" },
"io_config": {
"input_names": [ "input_ids", "attention_mask", "position_ids" ],
"output_names": [ "logits" ],
"input_shapes": [ [ 2, 8 ], [ 2, 8 ], [ 2, 8 ] ],
"input_types": [ "int64", "int64", "int64" ],
"dynamic_axes": {
"input_ids": { "0": "batch_size", "1": "sequence_length" },
"attention_mask": { "0": "batch_size", "1": "total_sequence_length" },
"position_ids": { "0": "batch_size", "1": "sequence_length" }
}
},
"model_path": "meta-llama/Llama-2-7b-hf"
},
"systems": {

View file

@ -2,18 +2,7 @@
"input_model": {
"type": "HfModel",
"model_path": "<model_name_placeholder>",
"load_kwargs": { "attn_implementation": "eager" },
"io_config": {
"input_names": [ "input_ids", "attention_mask", "position_ids" ],
"output_names": [ "logits" ],
"input_shapes": [ [ 2, 8 ], [ 2, 8 ], [ 2, 8 ] ],
"input_types": [ "int32", "int32", "int32" ],
"dynamic_axes": {
"input_ids": { "0": "batch_size", "1": "sequence_length" },
"attention_mask": { "0": "batch_size", "1": "total_sequence_length" },
"position_ids": { "0": "batch_size", "1": "sequence_length" }
}
}
"load_kwargs": { "attn_implementation": "eager" }
},
"data_configs": [
{

View file

@ -1,19 +1,5 @@
{
"input_model": {
"type": "HfModel",
"model_path": "meta-llama/Llama-2-7b-hf",
"io_config": {
"input_names": [ "input_ids", "attention_mask", "position_ids" ],
"output_names": [ "logits" ],
"input_shapes": [ [ 2, 8 ], [ 2, 8 ], [ 2, 8 ] ],
"input_types": [ "int32", "int32", "int32" ],
"dynamic_axes": {
"input_ids": { "0": "batch_size", "1": "sequence_length" },
"attention_mask": { "0": "batch_size", "1": "total_sequence_length" },
"position_ids": { "0": "batch_size", "1": "sequence_length" }
}
}
},
"input_model": { "type": "HfModel", "model_path": "meta-llama/Llama-2-7b-hf" },
"systems": {
"local_system": {
"type": "LocalSystem",

View file

@ -11,6 +11,7 @@ dependencies:
- bitsandbytes
- datasets
- huggingface_hub
- optimum
- peft
- scipy
- sentencepiece

View file

@ -12,17 +12,6 @@
"name": "Llama-2-7b",
"registry_name": "azureml-meta",
"version": "13"
},
"io_config": {
"input_names": [ "input_ids", "attention_mask", "position_ids" ],
"output_names": [ "logits" ],
"input_shapes": [ [ 2, 8 ], [ 2, 8 ], [ 2, 8 ] ],
"input_types": [ "int32", "int32", "int32" ],
"dynamic_axes": {
"input_ids": { "0": "batch_size", "1": "sequence_length" },
"attention_mask": { "0": "batch_size", "1": "total_sequence_length" },
"position_ids": { "0": "batch_size", "1": "sequence_length" }
}
}
},
"systems": {

View file

@ -2,18 +2,7 @@
"input_model": {
"type": "HfModel",
"model_path": "meta-llama/Llama-2-7b-hf",
"load_kwargs": { "attn_implementation": "eager" },
"io_config": {
"input_names": [ "input_ids", "attention_mask", "position_ids" ],
"output_names": [ "logits" ],
"input_shapes": [ [ 2, 8 ], [ 2, 8 ], [ 2, 8 ] ],
"input_types": [ "int32", "int32", "int32" ],
"dynamic_axes": {
"input_ids": { "0": "batch_size", "1": "sequence_length" },
"attention_mask": { "0": "batch_size", "1": "total_sequence_length" },
"position_ids": { "0": "batch_size", "1": "sequence_length" }
}
}
"load_kwargs": { "attn_implementation": "eager" }
},
"data_configs": [
{ "name": "transformer_token_dummy_data", "type": "TransformersTokenDummyDataContainer" },

View file

@ -2,18 +2,7 @@
"input_model": {
"type": "HfModel",
"model_path": "meta-llama/Llama-2-7b-hf",
"load_kwargs": { "attn_implementation": "eager" },
"io_config": {
"input_names": [ "input_ids", "attention_mask", "position_ids" ],
"output_names": [ "logits" ],
"input_shapes": [ [ 2, 8 ], [ 2, 8 ], [ 2, 8 ] ],
"input_types": [ "int32", "int32", "int32" ],
"dynamic_axes": {
"input_ids": { "0": "batch_size", "1": "sequence_length" },
"attention_mask": { "0": "batch_size", "1": "total_sequence_length" },
"position_ids": { "0": "batch_size", "1": "sequence_length" }
}
}
"load_kwargs": { "attn_implementation": "eager" }
},
"data_configs": [
{ "name": "transformer_token_dummy_data", "type": "TransformersTokenDummyDataContainer" },

View file

@ -2,18 +2,7 @@
"input_model": {
"type": "HfModel",
"model_path": "meta-llama/Llama-2-7b-hf",
"load_kwargs": { "attn_implementation": "eager" },
"io_config": {
"input_names": [ "input_ids", "attention_mask", "position_ids" ],
"output_names": [ "logits" ],
"input_shapes": [ [ 2, 8 ], [ 2, 8 ], [ 2, 8 ] ],
"input_types": [ "int32", "int32", "int32" ],
"dynamic_axes": {
"input_ids": { "0": "batch_size", "1": "sequence_length" },
"attention_mask": { "0": "batch_size", "1": "total_sequence_length" },
"position_ids": { "0": "batch_size", "1": "sequence_length" }
}
}
"load_kwargs": { "attn_implementation": "eager" }
},
"data_configs": [
{

View file

@ -6,6 +6,7 @@ dependencies:
- pip=22.3.1
- pip:
- datasets
- optimum
- sentencepiece
- transformers
- git+https://github.com/microsoft/Olive#egg=olive-ai[gpu]

View file

@ -4,21 +4,7 @@
"resource_group": "<resource_group>",
"workspace_name": "<workspace_name>"
},
"input_model": {
"type": "HfModel",
"model_path": "openlm-research/open_llama_3b",
"io_config": {
"input_names": [ "input_ids", "attention_mask", "position_ids" ],
"output_names": [ "logits" ],
"input_shapes": [ [ 2, 8 ], [ 2, 8 ], [ 2, 8 ] ],
"input_types": [ "int32", "int32", "int32" ],
"dynamic_axes": {
"input_ids": { "0": "batch_size", "1": "sequence_length" },
"attention_mask": { "0": "batch_size", "1": "total_sequence_length" },
"position_ids": { "0": "batch_size", "1": "sequence_length" }
}
}
},
"input_model": { "type": "HfModel", "model_path": "openlm-research/open_llama_3b" },
"systems": {
"aml": {
"type": "AzureML",

View file

@ -1,19 +1,5 @@
{
"input_model": {
"type": "HfModel",
"model_path": "openlm-research/open_llama_3b",
"io_config": {
"input_names": [ "input_ids", "attention_mask", "position_ids" ],
"output_names": [ "logits" ],
"input_shapes": [ [ 2, 8 ], [ 2, 8 ], [ 2, 8 ] ],
"input_types": [ "int32", "int32", "int32" ],
"dynamic_axes": {
"input_ids": { "0": "batch_size", "1": "sequence_length" },
"attention_mask": { "0": "batch_size", "1": "total_sequence_length" },
"position_ids": { "0": "batch_size", "1": "sequence_length" }
}
}
},
"input_model": { "type": "HfModel", "model_path": "openlm-research/open_llama_3b" },
"data_configs": [ { "name": "transformer_token_dummy_data", "type": "TransformersTokenDummyDataContainer" } ],
"evaluators": {
"common_evaluator": {

View file

@ -1,19 +1,5 @@
{
"input_model": {
"type": "HfModel",
"model_path": "openlm-research/open_llama_3b",
"io_config": {
"input_names": [ "input_ids", "attention_mask", "position_ids" ],
"output_names": [ "logits" ],
"input_shapes": [ [ 2, 8 ], [ 2, 8 ], [ 2, 8 ] ],
"input_types": [ "int32", "int32", "int32" ],
"dynamic_axes": {
"input_ids": { "0": "batch_size", "1": "sequence_length" },
"attention_mask": { "0": "batch_size", "1": "total_sequence_length" },
"position_ids": { "0": "batch_size", "1": "sequence_length" }
}
}
},
"input_model": { "type": "HfModel", "model_path": "openlm-research/open_llama_3b" },
"data_configs": [
{
"name": "quant_data_config",

View file

@ -3,5 +3,6 @@ intel-extension-for-transformers
lm-eval==0.4.2
neural-compressor>=2.3
onnxruntime
optimum
sentencepiece
transformers

View file

@ -1,2 +1,3 @@
datasets
optimum
sentencepiece

View file

@ -1,26 +1,5 @@
{
"input_model": {
"type": "HfModel",
"io_config": {
"model_path": "facebook/opt_125m",
"task": "text-generation",
"input_names": [ "input_ids", "attention_mask" ],
"output_names": [ "logits" ],
"input_shapes": [ [ 2, 8 ], [ 2, 8 ] ],
"input_types": [ "int32", "int32" ],
"dynamic_axes": {
"input_ids": { "0": "batch_size", "1": "sequence_length" },
"attention_mask": { "0": "batch_size", "1": "total_sequence_length" }
},
"kv_cache": {
"ort_past_key_name": "past_key_<id>",
"ort_past_value_name": "past_value_<id>",
"ort_present_key_name": "present_key_<id>",
"ort_present_value_name": "present_value_<id>",
"dtype": "float16"
}
}
},
"input_model": { "type": "HfModel", "model_path": "facebook/opt_125m" },
"systems": {
"local_system": {
"type": "LocalSystem",
@ -31,12 +10,7 @@
{
"name": "transformer_prompt_dummy_data",
"type": "TransformersPromptDummyDataContainer",
"load_dataset_config": {
"ignore_input_fields": [ "position_ids" ],
"use_step": true,
"ort_past_key_name": "past_key_<id>",
"ort_past_value_name": "past_value_<id>"
}
"load_dataset_config": { "ignore_input_fields": [ "position_ids" ], "use_step": true }
}
],
"evaluators": {

View file

@ -1 +1,2 @@
autoawq
optimum

View file

@ -289,21 +289,7 @@ class FineTuneCommand(BaseOliveCLICommand):
TEMPLATE = {
"input_model": {
"type": "HfModel",
"load_kwargs": {"attn_implementation": "eager"},
"io_config": {
"input_names": ["input_ids", "attention_mask", "position_ids"],
"output_names": ["logits"],
"input_shapes": [[2, 8], [2, 8], [2, 8]],
"input_types": ["int64", "int64", "int64"],
"dynamic_axes": {
"input_ids": {"0": "batch_size", "1": "sequence_length"},
"attention_mask": {"0": "batch_size", "1": "total_sequence_length"},
"position_ids": {"0": "batch_size", "1": "sequence_length"},
},
},
},
"input_model": {"type": "HfModel", "load_kwargs": {"attn_implementation": "eager"}},
"systems": {
"local_system": {
"type": "LocalSystem",

View file

@ -3,35 +3,12 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# mapping from task to feature
TASK_TO_FEATURE = {
"automatic-speech-recognition": "speech2seq-lm",
"fill-mask": "masked-lm",
"image-classification": "image-classification",
"image-segmentation": "image-segmentation",
"image-to-text": "vision2seq-lm",
"multiple-choice": "multiple-choice",
"ner": "token-classification",
"object-detection": "object-detection",
"question-answering": "question-answering",
"sentiment-analysis": "sequence-classification",
"summarization": "seq2seq-lm",
"text2text-generation": "seq2seq-lm",
"text-classification": "sequence-classification",
"text-generation": "causal-lm",
"token-classification": "token-classification",
"translation": "seq2seq-lm",
}
# mapping from feature to peft task type
# mapping from task to peft task type
# refer to peft.utils.peft_types.TaskType for all possible values
FEATURE_TO_PEFT_TASK_TYPE = {
"sequence-classification": "SEQ_CLS",
"seq2seq-lm": "SEQ_2_SEQ_LM",
"causal-lm": "CAUSAL_LM",
"token-classification": "TOKEN_CLS",
"question-answering": "QUESTION_ANS",
# TODO(jambayk): see if we need feature extraction
TASK_TO_PEFT_TASK_TYPE = {
"text-classification": "SEQ_CLS",
"text-generation": "CAUSAL_LM",
# TODO(jambayk): see if we need more task types
}
# model_type -> name for layers

View file

@ -3,118 +3,136 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Callable, Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional
from olive.common.hf.utils import get_feature_from_task, get_model_config, get_tokenizer
from olive.common.utils import get_attr
if TYPE_CHECKING:
from transformers.onnx import OnnxConfig
from olive.common.hf.mlflow import get_pretrained_name_or_path
from olive.common.hf.utils import get_model_config, get_tokenizer, is_peft_model
logger = logging.getLogger(__name__)
# patched version of transformers.onnx.features.supported_features_mapping
# to support additional models in olive
def patched_supported_features_mapping(
*supported_features: str, onnx_config_cls: Optional[str] = None
) -> Dict[str, Callable]:
"""Generate the mapping between supported the features and their corresponding OnnxConfig for a given model.
Args:
*supported_features: The names of the supported features.
onnx_config_cls: The OnnxConfig full name corresponding to the model.
Returns:
The dictionary mapping a feature to an OnnxConfig constructor.
"""
if onnx_config_cls is None:
raise ValueError("A OnnxConfig class must be provided")
from olive.common.hf import onnx_config
config_cls = get_attr(onnx_config, onnx_config_cls)
mapping = {}
for feature in supported_features:
if "-with-past" in feature:
mapping[feature] = partial(config_cls.with_past, task=feature.replace("-with-past", ""))
else:
mapping[feature] = partial(config_cls.from_model_config, task=feature)
return mapping
if TYPE_CHECKING:
from optimum.exporters.onnx import OnnxConfig
from transformers import PreTrainedModel
# TODO(jambayk): switch to optimum backend and make this an optional feature
# remove "feature" entirely from the codebase
def get_onnx_config(model_name: str, task: str, feature: Optional[str] = None, **kwargs) -> "OnnxConfig":
# pylint: disable=protected-access
from transformers.onnx import FeaturesManager
def get_preprocessors(model_name: str, **kwargs) -> Optional[Dict]:
"""Get the preprocessors for the model_name."""
from optimum.utils.save_utils import maybe_load_preprocessors
from olive.common.hf.onnx_config import ADDITIONAL_MODEL_TYPES
# patch FeaturesManager._SUPPORTED_MODEL_TYPE to support additional models in olive
for model_type, feature_list in ADDITIONAL_MODEL_TYPES.items():
if model_type in FeaturesManager._SUPPORTED_MODEL_TYPE:
continue
# TODO(trajep): remove the need for unpacking feature_list
features, onnx_config_cls = feature_list
FeaturesManager._SUPPORTED_MODEL_TYPE[model_type] = patched_supported_features_mapping(
*features, onnx_config_cls=onnx_config_cls
)
# if feature is not provided, try to get it from task
# else use "default"
feature = feature or get_feature_from_task(task) or "default"
# don't want to load the model here since all we need is the config
# model loading is expensive computationally and memory-wise for large models
config = get_model_config(model_name, **kwargs)
# recreate the logic for FeaturesManager.check_supported_model_or_raise to get the model_onnx_config
# https://github.com/huggingface/transformers/blob/main/src/transformers/onnx/features.py#L712
model_type = config.model_type.replace("_", "-")
onnx_config = None
# get tokenizer separately to support mlflow models
tokenizer = None
try:
model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name)
if feature in model_features:
onnx_config = FeaturesManager.get_config(model_type, feature)(config)
else:
logger.debug(
"%s doesn't support feature %s. Supported features are: %s", model_type, feature, model_features
)
tokenizer = get_tokenizer(model_name, **kwargs)
except Exception:
# there is no tokenizer for the model_name
pass
model_name = get_pretrained_name_or_path(model_name, "model")
preprocessors = maybe_load_preprocessors(
model_name, subfolder=kwargs.get("subfolder", ""), trust_remote_code=kwargs.get("trust_remote_code", False)
)
if tokenizer:
for i, preprocessor in enumerate(preprocessors):
if isinstance(preprocessor, type(tokenizer)):
preprocessors[i] = tokenizer
break
return preprocessors
def get_export_config(model_name: str, task: str, **kwargs) -> Optional["OnnxConfig"]:
"""Get the export config for the model_name and task."""
try:
from optimum.exporters.tasks import TasksManager
except ImportError:
logger.debug("optimum is not installed. Cannot get export config")
return None
model_config = get_model_config(model_name, **kwargs)
model_type = model_config.model_type.replace("_", "-")
task = TasksManager.map_from_synonym(task)
# use try except block since we don't want to access private class attributes like
# TasksManager._SUPPORTED_MODEL_TYPE
try:
supported_tasks = TasksManager.get_supported_tasks_for_model_type(
model_type, exporter="onnx", library_name="transformers"
)
if task not in supported_tasks:
logger.debug("Task %s is not supported for model type %s", task, model_type)
return None
except KeyError:
logger.debug("Model type %s is not supported", model_type)
return None
return onnx_config
# TODO(jambayk): ask caller for dtype?
dtype = getattr(model_config, "torch_dtype", "float32")
if "bfloat16" in str(dtype):
float_dtype = "bf16"
elif "float16" in str(dtype):
float_dtype = "fp16"
else:
float_dtype = "fp32"
export_config_constructor = TasksManager.get_exporter_config_constructor(
exporter="onnx", task=task, model_type=model_type, library_name="transformers"
)
export_config = export_config_constructor(
model_config,
int_dtype="int64",
float_dtype=float_dtype,
# TODO(jambayk): other preprocessors needed?
preprocessors=get_preprocessors(model_name, **kwargs),
)
if task.startswith("text-generation"):
# need kv cache for both input and output
export_config = export_config.__class__(
model_config,
use_past=export_config.use_past,
use_past_in_inputs=export_config.use_past,
# text-generation-with-past doesn't return position_ids
task="text-generation",
float_dtype=float_dtype,
int_dtype="int64",
)
return export_config
def get_model_io_config(model_name: str, task: str, feature: Optional[str] = None, **kwargs):
def get_model_io_config(model_name: str, task: str, model: "PreTrainedModel", **kwargs) -> Optional[Dict]:
"""Get the input/output config for the model_name and task."""
# just log a debug message if io_config is not found
# this is not a critical error and the caller may not need the io_config
model_config = get_onnx_config(model_name, task, feature, **kwargs)
if not model_config:
export_config = get_export_config(model_name, task, **kwargs)
if not export_config:
return None
inputs = model_config.inputs
outputs = model_config.outputs
if not inputs or not outputs:
# just log a warning and return None, since this is not a critical error
# and following pass may not use the io_config, like OptimumConversion
logger.debug("No inputs or outputs found from hf onnx_config %s. Won't use it to get io config", model_config)
if is_peft_model(model):
# if pytorch_model is PeftModel, we need to get the base model
# otherwise, the model forward has signature (*args, **kwargs)
model = model.get_base_model()
inputs = export_config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(export_config.outputs.keys())
dynamic_axes = dict(chain(inputs.items(), export_config.outputs.items()))
# optimum has the total sequence length as "past_sequence_length + 1" but that is not always the case
# change it to "past_sequence_length + sequence_length" if past is used
for value in dynamic_axes.values():
for axis, axis_name in value.items():
if axis_name == "past_sequence_length + 1":
value[axis] = "past_sequence_length + sequence_length"
return {"input_names": input_names, "output_names": output_names, "dynamic_axes": dynamic_axes}
def get_model_dummy_input(model_name: str, task: str, **kwargs) -> Optional[Dict]:
"""Get dummy inputs for the model_name and task."""
export_config = get_export_config(model_name, task, **kwargs)
if not export_config:
return None
io_config = {}
io_config["input_names"] = list(inputs.keys())
io_config["output_names"] = list(outputs.keys())
io_config["dynamic_axes"] = dict(chain(inputs.items(), outputs.items()))
return io_config
from optimum.utils import DEFAULT_DUMMY_SHAPES
def get_model_dummy_input(model_name: str, task: str, feature: Optional[str] = None, **kwargs):
model_config = get_onnx_config(model_name, task, feature, **kwargs)
if not model_config:
return None
tokenizer = get_tokenizer(model_name)
return model_config.generate_dummy_inputs(tokenizer, framework="pt")
dummy_inputs = export_config.generate_dummy_inputs(framework="pt", **DEFAULT_DUMMY_SHAPES)
return export_config.rename_ambiguous_inputs(dummy_inputs)
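A rough usage sketch, not part of this diff, showing how the new helpers in olive.common.hf.model_io fit together; the model id and task are the ones used by the tests added in this commit:

from olive.common.hf.model_io import get_export_config, get_model_dummy_input, get_model_io_config
from olive.common.hf.utils import load_model_from_task

model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
task = "text-generation"

# returns None if optimum is not installed or the model_type/task is unsupported
export_config = get_export_config(model_name, task)

# dummy inputs for export, generated with optimum's default dummy shapes
dummy_inputs = get_model_dummy_input(model_name, task)

# the io config needs the loaded model so optimum can order the inputs correctly
model = load_model_from_task(task, model_name)
io_config = get_model_io_config(model_name, task, model)
print(io_config["input_names"], io_config["output_names"])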

View file

@ -1,114 +0,0 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from collections import OrderedDict
from typing import Any, Mapping, Optional
from transformers import PreTrainedTokenizer, TensorType
from transformers.onnx import OnnxConfigWithPast
# dictionary of model types and their (supported features list, config class name)
# the supported features list is the list of features that are supported by the model type
# similar to the tasks supported by the model type. Refer to `task_to_feature` under `get_onnx_config` in
# `hf_utils.py` for a mapping from task to feature.
ADDITIONAL_MODEL_TYPES = {
"gpt-neox": (
[
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"sequence-classification",
"token-classification",
],
"TextDecoderOnnxConfig",
),
"llama": (
[
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"sequence-classification",
],
"TextDecoderOnnxConfig",
),
"opt": (
[
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"question-answering",
"sequence-classification",
],
"TextDecoderOnnxConfig",
),
}
class TextDecoderOnnxConfig(OnnxConfigWithPast):
# in OnnxConfigWithPast.fill_with_past_key_values_
# there is a bug in the name for the present sequence length dimension
# it should be `past_sequence` instead of `past_sequence + sequence`
def fill_with_past_key_values_(
self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str, inverted_values_shape: bool = False
):
"""Fill the input_or_outputs mapping with past_key_values dynamic axes considering.
Args:
inputs_or_outputs: The mapping to fill.
direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the
output mapping, this is important for axes naming.
inverted_values_shape:
If `True`, store values on dynamic axis 1, else on axis 2.
"""
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
name = "past_key_values" if direction == "inputs" else "present"
sequence_length_name = "past_sequence" if direction == "inputs" else "past_sequence + sequence"
for i in range(self.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: sequence_length_name}
if inverted_values_shape:
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 1: sequence_length_name}
else:
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: sequence_length_name}
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
# there seems to be a bug in the size of the past_key_values dim 2
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
else:
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
return common_inputs
def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
common_inputs = super().generate_dummy_inputs(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
# We need to order the input in the way they appears in the forward()
input_order = ["input_ids"]
if self.use_past:
input_order += ["past_key_values"]
input_order += ["attention_mask"]
return OrderedDict({k: common_inputs[k] for k in input_order})
@property
def num_layers(self) -> int:
return self._config.num_hidden_layers

View file

@ -2,17 +2,19 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import importlib
import logging
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from transformers import AutoConfig, AutoModel, AutoTokenizer, GenerationConfig
from olive.common.hf.mappings import FEATURE_TO_PEFT_TASK_TYPE, MODELS_TO_MAX_LENGTH_MAPPING, TASK_TO_FEATURE
from olive.common.hf.mappings import MODELS_TO_MAX_LENGTH_MAPPING, TASK_TO_PEFT_TASK_TYPE
from olive.common.hf.mlflow import get_pretrained_name_or_path
from olive.common.utils import hardlink_copy_file
if TYPE_CHECKING:
import torch
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
logger = logging.getLogger(__name__)
@ -175,24 +177,9 @@ def save_tokenizer(
return tokenizer.save_pretrained(output_dir, **kwargs)
# TODO(jambayk): Remove this once we transition away from using "feature"
def get_feature_from_task(task: str, fail_on_not_found=False) -> str:
"""Get feature from task."""
feature = TASK_TO_FEATURE.get(task.replace("-with-past", ""), None)
not_found_msg = f"There is no feature for task {task}"
if feature is None and fail_on_not_found:
raise ValueError(not_found_msg)
elif feature is None:
logger.warning(not_found_msg)
elif task.endswith("-with-past"):
feature += "-with-past"
return feature
def get_peft_task_type_from_task(task: str, fail_on_not_found=False) -> str:
"""Get peft task type from feature."""
feature = get_feature_from_task(task)
peft_task_type = FEATURE_TO_PEFT_TASK_TYPE.get(feature.replace("-with-past", ""), None) if feature else None
"""Get peft task type from task."""
peft_task_type = TASK_TO_PEFT_TASK_TYPE.get(task.replace("-with-past", ""), None)
not_found_msg = f"There is no peft task type for task {task}"
if peft_task_type is None and fail_on_not_found:
raise ValueError(not_found_msg)
@ -226,3 +213,12 @@ def get_model_max_length(model_name_or_path: str, fail_on_not_found=False) -> in
else:
logger.warning(not_found_msg)
return None
def is_peft_model(model: "torch.nn.Module") -> bool:
"""Check if the model is a PeftModel."""
if importlib.util.find_spec("peft"):
from peft import PeftModel
return isinstance(model, PeftModel)
return False
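A small usage sketch, not part of this diff, for the two helpers touched above; the plain torch module is only there to show the negative case:

import torch

from olive.common.hf.utils import get_peft_task_type_from_task, is_peft_model

# tasks now map directly to peft task types, with any "-with-past" suffix stripped
assert get_peft_task_type_from_task("text-generation") == "CAUSAL_LM"
assert get_peft_task_type_from_task("text-generation-with-past") == "CAUSAL_LM"

# False for a plain torch module, and also False whenever peft is not installed
assert not is_peft_model(torch.nn.Linear(2, 2))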

View file

@ -99,10 +99,10 @@ class HfModelHandler(PyTorchModelHandlerBase, MLFlowTransformersMixin, HfMixin):
self._io_config, force_kv_cache=self.task.endswith("-with-past"), model_attributes=self.model_attributes
)
else:
logger.debug("Trying hf onnx_config to get io_config")
logger.debug("Trying hf optimum export config to get io_config")
io_config = self.get_hf_io_config()
if io_config:
logger.debug("Got io_config from hf onnx_config")
logger.debug("Got io_config from hf optimum export config")
return io_config
@ -120,13 +120,16 @@ class HfModelHandler(PyTorchModelHandlerBase, MLFlowTransformersMixin, HfMixin):
if dummy_inputs:
return dummy_inputs
logger.debug("Trying hf onnx_config to get dummy inputs")
logger.debug("Trying hf optimum export config to get dummy inputs")
dummy_inputs = self.get_hf_dummy_inputs()
if dummy_inputs:
logger.debug("Got dummy inputs from hf onnx_config")
logger.debug("Got dummy inputs from hf optimum export config")
if dummy_inputs is None:
raise ValueError("Unable to get dummy inputs for the model.")
raise ValueError(
"Unable to get dummy inputs for the model. Please provide io_config or install an optimum version that"
" supports the model for export."
)
return dummy_inputs
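A hedged sketch, not part of this diff, of the same fallback seen from the handler side: HfModelHandler and the mixin methods are the ones changed above, while the import path and standalone construction are assumptions for illustration.

from olive.model import HfModelHandler  # import path assumed

handler = HfModelHandler(
    model_path="hf-internal-testing/tiny-random-LlamaForCausalLM",
    task="text-generation",
)
# no io_config was provided, so this falls back to the optimum export config
io_config = handler.io_config
# the mixin helper returns None (rather than raising) when optimum cannot handle the model
dummy_inputs = handler.get_hf_dummy_inputs()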

View file

@ -101,7 +101,7 @@ class HfMixin:
def get_hf_io_config(self) -> Optional[Dict[str, Any]]:
"""Get Io config for the model."""
return get_model_io_config(self.model_path, self.task, **self.get_load_kwargs())
return get_model_io_config(self.model_path, self.task, self.load_model(), **self.get_load_kwargs())
def get_hf_dummy_inputs(self) -> Optional[Dict[str, Any]]:
"""Get dummy inputs for the model."""

View file

@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import importlib
import logging
import multiprocessing
import tempfile
@ -15,6 +14,7 @@ import torch
from packaging import version
from olive.common.config_utils import validate_config
from olive.common.hf.utils import is_peft_model
from olive.common.utils import find_submodules, resolve_torch_dtype, tensor_data_to_device
from olive.hardware import AcceleratorSpec
from olive.model import (
@ -25,7 +25,6 @@ from olive.model import (
PyTorchModelHandler,
)
from olive.model.config import IoConfig
from olive.model.config.hf_config import HfLoadKwargs
from olive.model.utils import resolve_onnx_path
from olive.passes import Pass
from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model
@ -45,16 +44,6 @@ class TraceModelWrapper(torch.nn.Module):
return self.model(*input_data, **input_dict)
# TODO(jambayk): Consider conditional import of PeftModel so that we can type hint it.
def is_peft_model(model: torch.nn.Module) -> bool:
"""Check if the model is a PeftModel."""
if importlib.util.find_spec("peft"):
from peft import PeftModel
return isinstance(model, PeftModel)
return False
class OnnxConversion(Pass):
"""Convert a PyTorch model to ONNX model using torch.onnx.export on CPU."""
@ -277,14 +266,14 @@ class OnnxConversion(Pass):
return onnx_model
@staticmethod
def _load_pytorch_model(
model: Union[HfModelHandler, PyTorchModelHandler], device: str, torch_dtype: Optional[torch.dtype] = None
) -> Tuple[torch.nn.Module, Optional[Dict]]:
"""Load the model and return the model and the model attributes.
def _prepare_hf_model(
model: HfModelHandler, device: str, torch_dtype: Optional[torch.dtype] = None
) -> HfModelHandler:
"""Prepare the HfModelHandler for conversion.
This method handles the following cases:
1. PyTorchModelHandler or HfModelHandler with no load kwargs
- load the model directly
1. HfModelHandler with no load kwargs
- no need to change the model
2. HfModelHandler with load kwargs
- update load_kwargs.torch_dtype if torch_dtype is specified
- if torch_dtype not specified, make sure the load kwargs specify a dtype that is supported for
@ -293,77 +282,66 @@ class OnnxConversion(Pass):
- remove quantization config from the load kwargs
- find quantized modules and add them to the model attributes
- the onnx model must be quantized using OnnxBnb4Quantization pass after conversion
Model attributes is None if the output model should inherit the model attributes from the input model.
"""
pytorch_model = None
if not model.load_kwargs:
return model
model_attributes = deepcopy(model.model_attributes or {})
if isinstance(model, PyTorchModelHandler) or not model.load_kwargs:
# if the model is a PyTorchModelHandler or HfModelHandler with no load kwargs,
# we can load the model directly
pytorch_model = model.load_model()
else:
load_kwargs = model.load_kwargs
model_dtype = load_kwargs.get_torch_dtype()
new_load_kwargs = deepcopy(load_kwargs.dict())
if torch_dtype and torch_dtype != model_dtype:
# if the load kwargs specify a different dtype, update the load kwargs
logger.debug(
"Changing torch_dtype in load kwargs from %s to %s.",
load_kwargs.get_torch_dtype(),
torch_dtype,
)
new_load_kwargs["torch_dtype"] = torch_dtype
model_attributes["torch_dtype"] = str(torch_dtype).replace("torch.", "")
elif model_dtype == torch.float16 and device == "cpu":
logger.warning(
"Loading model on CPU, but the load kwargs specify dtype float16 which is not supported for"
" conversion on CPU. The dtype is changed to float32. If float16 model is desired, please specify"
" device as 'cuda' or use OrtTransformerOptimization/OnnxFloatToFloat16 pass after conversion to"
" convert the model to float16."
)
new_load_kwargs["torch_dtype"] = torch.float32
model_attributes["torch_dtype"] = "float32"
load_kwargs = model.load_kwargs
model_dtype = load_kwargs.get_torch_dtype()
new_load_kwargs = deepcopy(load_kwargs.dict())
if load_kwargs.quantization_method == "bitsandbytes" and load_kwargs.quantization_config["load_in_4bit"]:
logger.debug(
"Bitsandbytes 4bit quantization is not supported for conversion. The quantization config is removed"
" from the load kwargs. Use OnnxBnb4Quantization pass after conversion to quantize the"
" model."
)
new_load_kwargs["quantization_method"] = None
new_load_kwargs["quantization_config"] = None
model_attributes["quantization_config"] = load_kwargs.quantization_config
if "quantized_modules" not in model_attributes:
# find and add quantized modules to the model attributes
# the QLoRA pass already adds quantized_modules to the model attributes, so this will not be
# executed if the model was generated by QLoRA
quantized_model = model.load_model()
if torch_dtype and torch_dtype != model_dtype:
# if the load kwargs specify a different dtype, update the load kwargs
logger.debug(
"Changing torch_dtype in load kwargs from %s to %s.",
load_kwargs.get_torch_dtype(),
torch_dtype,
)
new_load_kwargs["torch_dtype"] = torch_dtype
model_attributes["torch_dtype"] = str(torch_dtype).replace("torch.", "")
elif model_dtype == torch.float16 and device == "cpu":
logger.warning(
"Loading model on CPU, but the load kwargs specify dtype float16 which is not supported for"
" conversion on CPU. The dtype is changed to float32. If float16 model is desired, please specify"
" device as 'cuda' or use OrtTransformerOptimization/OnnxFloatToFloat16 pass after conversion to"
" convert the model to float16."
)
new_load_kwargs["torch_dtype"] = torch.float32
model_attributes["torch_dtype"] = "float32"
# if PeftModel, need to unload adapter before finding quantized modules
if is_peft_model(quantized_model):
quantized_model = quantized_model.unload()
if load_kwargs.quantization_method == "bitsandbytes" and load_kwargs.quantization_config["load_in_4bit"]:
logger.debug(
"Bitsandbytes 4bit quantization is not supported for conversion. The quantization config is removed"
" from the load kwargs. Use OnnxBnb4Quantization pass after conversion to quantize the"
" model."
)
new_load_kwargs["quantization_method"] = None
new_load_kwargs["quantization_config"] = None
model_attributes["quantization_config"] = load_kwargs.quantization_config
if "quantized_modules" not in model_attributes:
# find and add quantized modules to the model attributes
# the QLoRA pass already adds quantized_modules to the model attributes, so this will not be
# executed if the model was generated by QLoRA
quantized_model = model.load_model()
import bitsandbytes as bnb
# if PeftModel, need to unload adapter before finding quantized modules
if is_peft_model(quantized_model):
quantized_model = quantized_model.unload()
model_attributes["quantized_modules"] = find_submodules(quantized_model, bnb.nn.Linear4bit)
import bitsandbytes as bnb
# required for peft models since unloading changes the model
# for others, do this to free gpu memory as quantized model is always on gpu
del quantized_model
model.model = None
model_attributes["quantized_modules"] = find_submodules(quantized_model, bnb.nn.Linear4bit)
# load the model with the updated load kwargs
pytorch_model = HfModelHandler(
model_path=model.model_path,
task=model.task,
adapter_path=model.adapter_path,
load_kwargs=HfLoadKwargs(**new_load_kwargs),
).load_model()
# required for peft models since unloading changes the model
# for others, do this to free gpu memory as quantized model is always on gpu
del quantized_model
model.model = None
if is_peft_model(pytorch_model):
model_attributes["lora_modules"] = list(pytorch_model.peft_config["default"].target_modules)
return pytorch_model, model_attributes
model_config = model.to_json()["config"]
model_config["load_kwargs"] = new_load_kwargs
model_config["model_attributes"] = model_attributes
return HfModelHandler(**model_config)
def _convert_model_on_device(
self,
@ -374,8 +352,14 @@ class OnnxConversion(Pass):
torch_dtype: Optional[torch.dtype] = None,
) -> ONNXModelHandler:
"""Convert an HfModelHandler or PyTorchModelHandler to an ONNXModelHandler."""
# prepare the model for conversion
if isinstance(model, HfModelHandler):
# optimum export config needs the loaded model to get io_config so we create a new model handler
# which will be used to load the model and get the io_config
model = self._prepare_hf_model(model, device, torch_dtype)
# load the model
pytorch_model, model_attributes = self._load_pytorch_model(model, device, torch_dtype)
pytorch_model = model.load_model()
if config["merge_adapter_weights"] and is_peft_model(pytorch_model):
logger.debug("Merging adapter weights into base model. This is specific to PeftModel.")
pytorch_model = pytorch_model.merge_and_unload()
@ -392,7 +376,9 @@ class OnnxConversion(Pass):
# save the model to the output path and return the model
output_model_path = resolve_onnx_path(output_model_path)
output_model = model_proto_to_olive_model(converted_onnx_model, output_model_path, config)
output_model.model_attributes = model_attributes
output_model.model_attributes = model_attributes = deepcopy(model.model_attributes or {})
if is_peft_model(pytorch_model):
model_attributes["lora_modules"] = list(pytorch_model.peft_config["default"].target_modules)
return output_model
@staticmethod

View file

@ -6,9 +6,8 @@ from unittest.mock import MagicMock, patch
import pytest
import torch
from transformers.onnx import OnnxConfig
from olive.common.hf.model_io import get_onnx_config
from olive.common.hf.model_io import get_export_config, get_model_dummy_input, get_model_io_config
from olive.common.hf.utils import load_model_from_task
@ -63,12 +62,49 @@ def test_load_model_from_task_exception_handling(exceptions, expected_exception,
@pytest.mark.parametrize(
("model_name", "task", "feature"),
("model_name", "task"),
[
("hf-internal-testing/tiny-random-BertForSequenceClassification", "text-classification", "default"),
("hf-internal-testing/tiny-random-LlamaForCausalLM", "text-generation", "default"),
("hf-internal-testing/tiny-random-BertForSequenceClassification", "text-classification"),
("hf-internal-testing/tiny-random-LlamaForCausalLM", "text-generation"),
],
)
def test_get_onnx_config(model_name, task, feature):
onnx_config = get_onnx_config(model_name, task, feature)
assert isinstance(onnx_config, OnnxConfig)
def test_get_export_config(model_name, task):
from optimum.exporters.onnx import OnnxConfig
export_config = get_export_config(model_name, task)
assert isinstance(export_config, OnnxConfig)
def get_model_name_task(with_past: bool):
model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
task = "text-generation"
if with_past:
task = "text-generation-with-past"
return model_name, task
@pytest.mark.parametrize("with_past", [True, False])
def test_get_model_dummy_input(with_past):
dummy_input = get_model_dummy_input(*get_model_name_task(with_past))
expected_keys = ["input_ids", "attention_mask", "position_ids"]
if with_past:
expected_keys.append("past_key_values")
assert set(dummy_input.keys()) == set(expected_keys)
@pytest.mark.parametrize("with_past", [True, False])
def test_get_model_io_config(with_past):
model_name, task = get_model_name_task(with_past)
model = load_model_from_task(task, model_name)
io_config = get_model_io_config(model_name, task, model)
expected_keys = ["input_names", "output_names", "dynamic_axes"]
assert set(io_config.keys()) == set(expected_keys)
expected_input_names = ["input_ids", "attention_mask", "position_ids"]
expected_output_names = ["logits"]
if with_past:
for layer_id in range(model.config.num_hidden_layers):
expected_input_names.extend([f"past_key_values.{layer_id}.key", f"past_key_values.{layer_id}.value"])
expected_output_names.extend([f"present.{layer_id}.key", f"present.{layer_id}.value"])
assert io_config["input_names"] == expected_input_names
assert io_config["output_names"] == expected_output_names
assert set(io_config["dynamic_axes"].keys()) == set(expected_input_names + expected_output_names)

View file

@ -188,7 +188,7 @@ class TestHFDummyInput(unittest.TestCase):
# get io config
io_config = olive_model.io_config
assert io_config == self.io_config
get_model_io_config.assert_called_once_with(self.model_name, self.task)
get_model_io_config.assert_called_once_with(self.model_name, self.task, olive_model.load_model())
@patch("olive.data.template.dummy_data_config_template")
def test_input_shapes_dummy_inputs(self, dummy_data_config_template):